Skip to content

Commit

Permalink
server/product: require organization_id filter on list only if not or…
Browse files Browse the repository at this point in the history
…ganization token
  • Loading branch information
frankie567 committed Oct 25, 2024
1 parent 7159d44 commit 6f5e856
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 46 deletions.
9 changes: 7 additions & 2 deletions server/polar/product/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,13 @@ async def list(
pagination: PaginationParamsQuery,
sorting: ListSorting,
auth_subject: auth.CreatorProductsReadOrAnonymous,
organization_id: MultipleQueryFilter[OrganizationID] = Query(
title="OrganizationID Filter", description="Filter by organization ID."
organization_id: MultipleQueryFilter[OrganizationID] | None = Query(
None,
title="OrganizationID Filter",
description=(
"Filter by organization ID. "
"**Required unless you use an organization token.**"
),
),
query: str | None = Query(None, description="Filter by product name."),
is_archived: bool | None = Query(None, description="Filter on archived products."),
Expand Down
23 changes: 15 additions & 8 deletions server/polar/product/service/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import contains_eager, joinedload, selectinload

from polar.auth.models import (
AuthSubject,
Subject,
is_organization,
is_user,
)
from polar.auth.models import AuthSubject, Subject, is_organization, is_user
from polar.authz.service import AccessType, Authz
from polar.benefit.service.benefit import benefit as benefit_service
from polar.exceptions import NotPermitted, PolarError, PolarRequestValidationError
Expand Down Expand Up @@ -66,7 +61,7 @@ async def list(
session: AsyncSession,
auth_subject: AuthSubject[Subject],
*,
organization_id: Sequence[uuid.UUID],
organization_id: Sequence[uuid.UUID] | None = None,
query: str | None = None,
is_archived: bool | None = None,
is_recurring: bool | None = None,
Expand Down Expand Up @@ -95,7 +90,19 @@ async def list(
isouter=True,
)

statement = statement.where(Product.organization_id.in_(organization_id))
if organization_id is not None:
statement = statement.where(Product.organization_id.in_(organization_id))
elif not is_organization(auth_subject):
raise PolarRequestValidationError(
[
{
"type": "missing",
"loc": ("query", "organization_id"),
"msg": "Field is required.",
"input": None,
}
]
)

if query is not None:
statement = statement.where(Product.name.ilike(f"%{query}%"))
Expand Down
45 changes: 21 additions & 24 deletions server/tests/product/service/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ async def test_anonymous(
results, count = await product_service.list(
session,
auth_subject,
organization_id=[p.organization_id for p in products],
pagination=PaginationParams(1, 10),
organization_id=[products[0].organization_id],
)

assert count == 3
assert len(results) == 3
assert count == 2
assert len(results) == 2

@pytest.mark.auth
async def test_user(
Expand All @@ -100,12 +100,12 @@ async def test_user(
results, count = await product_service.list(
session,
auth_subject,
organization_id=[p.organization_id for p in products],
pagination=PaginationParams(1, 10),
organization_id=[products[0].organization_id],
)

assert count == 3
assert len(results) == 3
assert count == 2
assert len(results) == 2

@pytest.mark.auth
async def test_user_organization(
Expand All @@ -121,12 +121,12 @@ async def test_user_organization(
results, count = await product_service.list(
session,
auth_subject,
organization_id=[p.organization_id for p in products],
pagination=PaginationParams(1, 10),
organization_id=[products[0].organization_id],
)

assert count == 3
assert len(results) == 3
assert count == 2
assert len(results) == 2

@pytest.mark.auth(AuthSubjectFixture(subject="organization"))
async def test_organization(
Expand All @@ -139,10 +139,7 @@ async def test_organization(
session.expunge_all()

results, count = await product_service.list(
session,
auth_subject,
organization_id=[p.organization_id for p in products],
pagination=PaginationParams(1, 10),
session, auth_subject, pagination=PaginationParams(1, 10)
)

assert count == 2
Expand Down Expand Up @@ -174,7 +171,7 @@ async def test_filter_is_recurring(
results, count = await product_service.list(
session,
auth_subject,
organization_id=[organization.id],
organization_id=[recurring_product.organization_id],
is_recurring=True,
pagination=PaginationParams(1, 10),
)
Expand All @@ -186,7 +183,7 @@ async def test_filter_is_recurring(
results, count = await product_service.list(
session,
auth_subject,
organization_id=[organization.id],
organization_id=[recurring_product.organization_id],
is_recurring=False,
pagination=PaginationParams(1, 10),
)
Expand Down Expand Up @@ -241,15 +238,15 @@ async def test_filter_is_archived(
session,
get_auth_subject(Anonymous()),
is_archived=False,
organization_id=[organization.id],
organization_id=[archived_product.organization_id],
pagination=PaginationParams(1, 10),
)
assert count == 0
assert len(results) == 0
results, count = await product_service.list(
session,
get_auth_subject(Anonymous()),
organization_id=[organization.id],
organization_id=[archived_product.organization_id],
pagination=PaginationParams(1, 10),
)
assert count == 0
Expand All @@ -260,7 +257,7 @@ async def test_filter_is_archived(
results, count = await product_service.list(
session,
auth_subject,
organization_id=[user_organization.organization_id],
organization_id=[archived_product.organization_id],
is_archived=False,
pagination=PaginationParams(1, 10),
)
Expand All @@ -269,7 +266,7 @@ async def test_filter_is_archived(
results, count = await product_service.list(
session,
auth_subject,
organization_id=[user_organization.organization_id],
organization_id=[archived_product.organization_id],
pagination=PaginationParams(1, 10),
)
assert count == 1
Expand All @@ -295,7 +292,7 @@ async def test_filter_include_archived_authed_non_member(
results, count = await product_service.list(
session,
auth_subject,
organization_id=[organization.id],
organization_id=[archived_product.organization_id],
is_archived=False,
pagination=PaginationParams(1, 10),
)
Expand All @@ -304,7 +301,7 @@ async def test_filter_include_archived_authed_non_member(
results, count = await product_service.list(
session,
auth_subject,
organization_id=[organization.id],
organization_id=[archived_product.organization_id],
pagination=PaginationParams(1, 10),
)
assert count == 0
Expand Down Expand Up @@ -413,23 +410,23 @@ async def test_pagination(
results, count = await product_service.list(
session,
auth_subject,
organization_id=[user_organization.organization_id],
organization_id=[organization.id],
pagination=PaginationParams(1, 8), # page 1, limit 8
)
assert 20 == count
assert 8 == len(results)
results, count = await product_service.list(
session,
auth_subject,
organization_id=[user_organization.organization_id],
organization_id=[organization.id],
pagination=PaginationParams(2, 8), # page 2, limit 8
)
assert 20 == count
assert 8 == len(results)
results, count = await product_service.list(
session,
auth_subject,
organization_id=[user_organization.organization_id],
organization_id=[organization.id],
pagination=PaginationParams(3, 8), # page 3, limit 8
)
assert 20 == count
Expand Down
12 changes: 0 additions & 12 deletions server/tests/product/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,6 @@ async def test_anonymous(
json = response.json()
assert json["pagination"]["total_count"] == 2

async def test_anonymous_without_organization_filter(
self,
client: AsyncClient,
organization: Organization,
products: list[Product],
) -> None:
response = await client.get(
"/v1/products/",
)

assert response.status_code == 422

async def test_with_benefits(
self,
session: AsyncSession,
Expand Down

0 comments on commit 6f5e856

Please sign in to comment.