diff --git a/server/polar/checkout/service.py b/server/polar/checkout/service.py index effc085a36..568564b6db 100644 --- a/server/polar/checkout/service.py +++ b/server/polar/checkout/service.py @@ -815,6 +815,7 @@ async def handle_stripe_success( customer=stripe_customer_id, currency=checkout.currency or "usd", price=stripe_price_id, + automatic_tax=checkout.product.is_tax_applicable, metadata=metadata, invoice_metadata={ "payment_intent_id": payment_intent.id, @@ -833,6 +834,7 @@ async def handle_stripe_success( subscription_id=subscription.stripe_subscription_id, old_price=subscription.price.stripe_price_id, new_price=stripe_price_id, + automatic_tax=checkout.product.is_tax_applicable, metadata=metadata, invoice_metadata={ "payment_intent_id": payment_intent.id, @@ -850,6 +852,7 @@ async def handle_stripe_success( customer=stripe_customer_id, currency=checkout.currency or "usd", price=stripe_price_id, + automatic_tax=checkout.product.is_tax_applicable, metadata={ **metadata, "payment_intent_id": payment_intent.id, @@ -1188,6 +1191,10 @@ async def _update_checkout( async def _update_checkout_tax( self, session: AsyncSession, checkout: Checkout ) -> Checkout: + if not checkout.product.is_tax_applicable: + checkout.tax_amount = 0 + return checkout + if ( checkout.currency is not None and checkout.amount is not None diff --git a/server/polar/integrations/stripe/service.py b/server/polar/integrations/stripe/service.py index 151cf99dc6..7831873c7c 100644 --- a/server/polar/integrations/stripe/service.py +++ b/server/polar/integrations/stripe/service.py @@ -634,6 +634,7 @@ async def update_out_of_band_subscription( subscription_id: str, old_price: str, new_price: str, + automatic_tax: bool = True, metadata: dict[str, str] | None = None, invoice_metadata: dict[str, str] | None = None, idempotency_key: str | None = None, @@ -643,7 +644,7 @@ async def update_out_of_band_subscription( modify_params: stripe_lib.Subscription.ModifyParams = { "collection_method": "send_invoice", "days_until_due": 0, - "automatic_tax": {"enabled": True}, + "automatic_tax": {"enabled": automatic_tax}, } if metadata is not None: modify_params["metadata"] = metadata diff --git a/server/polar/product/service/product.py b/server/polar/product/service/product.py index 02745d3e9a..6b571a695a 100644 --- a/server/polar/product/service/product.py +++ b/server/polar/product/service/product.py @@ -4,7 +4,6 @@ import stripe from sqlalchemy import Select, UnaryExpression, asc, case, desc, func, select -from sqlalchemy.exc import InvalidRequestError from sqlalchemy.orm import contains_eager, joinedload, selectinload from polar.auth.models import AuthSubject, is_organization, is_user @@ -364,7 +363,6 @@ async def update( update_schema: ProductUpdate, auth_subject: AuthSubject[User | Organization], ) -> Product: - product = await self.with_organization(session, product) subject = auth_subject.subject if not await authz.can(subject, AccessType.write, product): @@ -523,8 +521,6 @@ async def update_benefits( benefits: List[uuid.UUID], # noqa: UP006 auth_subject: AuthSubject[User | Organization], ) -> tuple[Product, set[Benefit], set[Benefit]]: - product = await self.with_organization(session, product) - subject = auth_subject.subject if not await authz.can(subject, AccessType.write, product): raise NotPermitted() @@ -532,15 +528,10 @@ async def update_benefits( previous_benefits = set(product.benefits) new_benefits: set[Benefit] = set() - nested = await session.begin_nested() - - product.product_benefits = [] - await session.flush() - + new_product_benefits: list[ProductBenefit] = [] for order, benefit_id in enumerate(benefits): benefit = await benefit_service.get_by_id(session, auth_subject, benefit_id) if benefit is None: - await nested.rollback() raise PolarRequestValidationError( [ { @@ -563,9 +554,15 @@ async def update_benefits( ] ) new_benefits.add(benefit) - product.product_benefits.append( - ProductBenefit(benefit=benefit, order=order) - ) + new_product_benefits.append(ProductBenefit(benefit=benefit, order=order)) + + # Remove all previous benefits: flush to actually remove them + product.product_benefits = [] + session.add(product) + await session.flush() + + # Set the new benefits + product.product_benefits = new_product_benefits added_benefits = new_benefits - previous_benefits deleted_benefits = previous_benefits - new_benefits @@ -597,15 +594,6 @@ async def update_benefits( return product, added_benefits, deleted_benefits - async def with_organization( - self, session: AsyncSession, product: Product - ) -> Product: - try: - product.organization - except InvalidRequestError: - await session.refresh(product, {"organization"}) - return product - async def _archive(self, product: Product) -> Product: if product.stripe_product_id is not None: await stripe_service.archive_product(product.stripe_product_id) @@ -675,19 +663,15 @@ async def _send_webhook( event_type: Literal[WebhookEventType.product_created] | Literal[WebhookEventType.product_updated], ) -> None: - # load full tier with relations - full_product = await self.get_loaded(session, product.id, allow_deleted=True) - assert full_product - # mypy 1.9 is does not allow us to do # event = (event_type, subscription) # directly, even if it could have... event: WebhookTypeObject | None = None match event_type: case WebhookEventType.product_created: - event = (event_type, full_product) + event = (event_type, product) case WebhookEventType.product_updated: - event = (event_type, full_product) + event = (event_type, product) if managing_org := await organization_service.get( session, product.organization_id diff --git a/server/tests/benefit/service/test_benefit_grant.py b/server/tests/benefit/service/test_benefit_grant.py index 5b777e5f26..10ab24006b 100644 --- a/server/tests/benefit/service/test_benefit_grant.py +++ b/server/tests/benefit/service/test_benefit_grant.py @@ -21,10 +21,10 @@ from polar.subscription.service import subscription as subscription_service from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import ( - add_product_benefits, create_benefit_grant, create_order, create_subscription, + set_product_benefits, ) @@ -348,7 +348,7 @@ async def test_subscription_scope( "polar.benefit.service.benefit_grant.enqueue_job" ) - product = await add_product_benefits( + product = await set_product_benefits( save_fixture, product=product, benefits=benefits ) @@ -388,7 +388,7 @@ async def test_outdated_grants( grant.set_granted() await save_fixture(grant) - product = await add_product_benefits( + product = await set_product_benefits( save_fixture, product=product, benefits=benefits[1:] ) diff --git a/server/tests/checkout/legacy/test_service.py b/server/tests/checkout/legacy/test_service.py index 8b0f35943b..fa1e4b76b0 100644 --- a/server/tests/checkout/legacy/test_service.py +++ b/server/tests/checkout/legacy/test_service.py @@ -12,13 +12,12 @@ from polar.exceptions import PolarRequestValidationError, ResourceNotFound from polar.integrations.stripe.schemas import ProductType from polar.integrations.stripe.service import StripeService -from polar.models import Organization, Product, User +from polar.models import Product, User from polar.postgres import AsyncSession from tests.fixtures.auth import AuthSubjectFixture from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import ( - add_product_benefits, - create_benefit, + set_product_benefits, ) SUCCESS_URL = Url("https://example.com/success") @@ -128,7 +127,7 @@ async def test_valid_anonymous( price.stripe_price_id, str(SUCCESS_URL), is_subscription=True, - is_tax_applicable=False, + is_tax_applicable=True, metadata={ "type": ProductType.product, "product_id": str(product.id), @@ -180,7 +179,7 @@ async def test_valid_user_cookie( price.stripe_price_id, str(SUCCESS_URL), is_subscription=True, - is_tax_applicable=False, + is_tax_applicable=True, customer="STRIPE_CUSTOMER_ID", metadata={ "type": ProductType.product, @@ -235,7 +234,7 @@ async def test_valid_token( price.stripe_price_id, str(SUCCESS_URL), is_subscription=True, - is_tax_applicable=False, + is_tax_applicable=True, metadata={ "type": ProductType.product, "product_id": str(product.id), @@ -289,7 +288,7 @@ async def test_valid_token_customer_email( price.stripe_price_id, str(SUCCESS_URL), is_subscription=True, - is_tax_applicable=False, + is_tax_applicable=True, customer_email="backer@example.com", metadata={ "type": ProductType.product, @@ -303,22 +302,18 @@ async def test_valid_token_customer_email( }, ) - async def test_valid_tax_applicable( + async def test_valid_tax_not_applicable( self, auth_subject: AuthSubject[Anonymous], session: AsyncSession, save_fixture: SaveFixture, product: Product, stripe_service_mock: MagicMock, - organization: Organization, ) -> None: - applicable_tax_benefit = await create_benefit( - save_fixture, is_tax_applicable=True, organization=organization - ) - product = await add_product_benefits( + product = await set_product_benefits( save_fixture, product=product, - benefits=[applicable_tax_benefit], + benefits=[], ) create_checkout_session_mock: MagicMock = ( @@ -349,7 +344,7 @@ async def test_valid_tax_applicable( price.stripe_price_id, str(SUCCESS_URL), is_subscription=True, - is_tax_applicable=True, + is_tax_applicable=False, metadata={ "type": ProductType.product, "product_id": str(product.id), diff --git a/server/tests/checkout/test_service.py b/server/tests/checkout/test_service.py index fe911a4421..68c67d4d4c 100644 --- a/server/tests/checkout/test_service.py +++ b/server/tests/checkout/test_service.py @@ -180,6 +180,26 @@ async def checkout_custom_fields( return await create_checkout(save_fixture, price=product_custom_fields.prices[0]) +@pytest_asyncio.fixture +async def product_tax_not_applicable( + save_fixture: SaveFixture, organization: Organization +) -> Product: + return await create_product( + save_fixture, + organization=organization, + tax_applicable=False, + ) + + +@pytest_asyncio.fixture +async def checkout_tax_not_applicable( + save_fixture: SaveFixture, product_tax_not_applicable: Product +) -> Checkout: + return await create_checkout( + save_fixture, price=product_tax_not_applicable.prices[0] + ) + + @pytest.mark.asyncio @pytest.mark.skip_db_asserts class TestCreate: @@ -749,6 +769,30 @@ async def test_valid_embed_origin( assert checkout.embed_origin == "https://example.com" + async def test_valid_tax_not_applicable( + self, + session: AsyncSession, + auth_subject: AuthSubject[User | Organization], + user_organization: UserOrganization, + product_tax_not_applicable: Product, + ) -> None: + price = product_tax_not_applicable.prices[0] + assert isinstance(price, ProductPriceFixed) + + checkout = await checkout_service.create( + session, + CheckoutCreate( + payment_processor=PaymentProcessor.stripe, + product_price_id=price.id, + customer_billing_address=Address.model_validate({"country": "FR"}), + ), + auth_subject, + ) + + assert checkout.tax_amount == 0 + assert checkout.customer_billing_address is not None + assert checkout.customer_billing_address.country == "FR" + @pytest.mark.asyncio @pytest.mark.skip_db_asserts @@ -1345,6 +1389,21 @@ async def test_valid_embed_origin( assert checkout.embed_origin == "https://example.com" + async def test_valid_tax_not_applicable( + self, session: AsyncSession, checkout_tax_not_applicable: Checkout + ) -> None: + checkout = await checkout_service.update( + session, + checkout_tax_not_applicable, + CheckoutUpdate( + customer_billing_address=Address.model_validate({"country": "FR"}), + ), + ) + + assert checkout.tax_amount == 0 + assert checkout.customer_billing_address is not None + assert checkout.customer_billing_address.country == "FR" + @pytest.mark.asyncio @pytest.mark.skip_db_asserts diff --git a/server/tests/fixtures/random_objects.py b/server/tests/fixtures/random_objects.py index 01e4ee4d57..5175cf85fc 100644 --- a/server/tests/fixtures/random_objects.py +++ b/server/tests/fixtures/random_objects.py @@ -581,12 +581,13 @@ async def create_product( | tuple[ProductPriceType, SubscriptionRecurringInterval | None] ] = [(1000, ProductPriceType.recurring, SubscriptionRecurringInterval.month)], attached_custom_fields: Sequence[tuple[CustomField, bool]] = [], + tax_applicable: bool = True, ) -> Product: product = Product( name=name, description="Description", is_archived=is_archived, - organization_id=organization.id, + organization=organization, stripe_product_id=rstr("PRODUCT_ID"), all_prices=[], prices=[], @@ -638,6 +639,12 @@ async def create_product( product.prices.append(product_price) product.all_prices.append(product_price) + if tax_applicable: + benefit = await create_benefit( + save_fixture, organization=organization, is_tax_applicable=True + ) + product.product_benefits.append(ProductBenefit(benefit=benefit, order=0)) + return product @@ -762,13 +769,14 @@ async def create_benefit( return benefit -async def add_product_benefits( +async def set_product_benefits( save_fixture: SaveFixture, *, product: Product, benefits: list[Benefit], ) -> Product: product.product_benefits = [] + await save_fixture(product) for order, benefit in enumerate(benefits): product.product_benefits.append(ProductBenefit(benefit=benefit, order=order)) await save_fixture(product) diff --git a/server/tests/product/service/test_product.py b/server/tests/product/service/test_product.py index c705d0072f..de14320367 100644 --- a/server/tests/product/service/test_product.py +++ b/server/tests/product/service/test_product.py @@ -37,9 +37,9 @@ from tests.fixtures.auth import AuthSubjectFixture from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import ( - add_product_benefits, create_benefit, create_product, + set_product_benefits, ) @@ -238,7 +238,7 @@ async def test_filter_benefit_id( user_organization: UserOrganization, ) -> None: for product in products[:2]: - await add_product_benefits( + await set_product_benefits( save_fixture, product=product, benefits=[benefit_organization, benefit_organization_second], @@ -1220,6 +1220,7 @@ async def test_valid_media( @pytest.mark.asyncio +@pytest.mark.skip_db_asserts class TestUpdateBenefits: @pytest.mark.auth async def test_not_writable_product( @@ -1229,16 +1230,9 @@ async def test_not_writable_product( authz: Authz, product: Product, ) -> None: - # then - session.expunge_all() - - # load - product_organization_loaded = await product_service.get(session, product.id) - assert product_organization_loaded - with pytest.raises(NotPermitted): await product_service.update_benefits( - session, authz, product_organization_loaded, [], auth_subject + session, authz, product, [], auth_subject ) @pytest.mark.auth( @@ -1255,31 +1249,21 @@ async def test_not_existing_benefit( product: Product, benefits: list[Benefit], ) -> None: - product = await add_product_benefits( - save_fixture, - product=product, - benefits=benefits, + product = await set_product_benefits( + save_fixture, product=product, benefits=benefits ) - - # then - session.expunge_all() - - # load - product_organization_loaded = await product_service.get(session, product.id) - assert product_organization_loaded + assert len(product.product_benefits) == len(benefits) with pytest.raises(PolarRequestValidationError): await product_service.update_benefits( session, authz, - product_organization_loaded, + product, [uuid.uuid4()], auth_subject, ) - await session.refresh(product_organization_loaded) - - assert len(product_organization_loaded.product_benefits) == len(benefits) + assert len(product.product_benefits) == len(benefits) @pytest.mark.auth( AuthSubjectFixture(subject="user"), @@ -1287,6 +1271,7 @@ async def test_not_existing_benefit( ) async def test_added_benefits( self, + save_fixture: SaveFixture, session: AsyncSession, enqueue_job_mock: AsyncMock, authz: Authz, @@ -1295,12 +1280,7 @@ async def test_added_benefits( product: Product, benefits: list[Benefit], ) -> None: - # then - session.expunge_all() - - # load - product_organization_loaded = await product_service.get(session, product.id) - assert product_organization_loaded + await set_product_benefits(save_fixture, product=product, benefits=[]) ( product, @@ -1309,7 +1289,7 @@ async def test_added_benefits( ) = await product_service.update_benefits( session, authz, - product_organization_loaded, + product, [benefit.id for benefit in benefits], auth_subject, ) @@ -1339,6 +1319,7 @@ async def test_added_benefits( ) async def test_order( self, + save_fixture: SaveFixture, session: AsyncSession, enqueue_job_mock: AsyncMock, authz: Authz, @@ -1347,12 +1328,7 @@ async def test_order( product: Product, benefits: list[Benefit], ) -> None: - # then - session.expunge_all() - - # load - product_organization_loaded = await product_service.get(session, product.id) - assert product_organization_loaded + await set_product_benefits(save_fixture, product=product, benefits=[]) ( product, @@ -1361,7 +1337,7 @@ async def test_order( ) = await product_service.update_benefits( session, authz, - product_organization_loaded, + product, [benefit.id for benefit in benefits[::-1]], auth_subject, ) @@ -1400,25 +1376,18 @@ async def test_deleted( product: Product, benefits: list[Benefit], ) -> None: - product = await add_product_benefits( + product = await set_product_benefits( save_fixture, product=product, benefits=benefits, ) - # then - session.expunge_all() - - # load - product_organization_loaded = await product_service.get(session, product.id) - assert product_organization_loaded - ( product, added, deleted, ) = await product_service.update_benefits( - session, authz, product_organization_loaded, [], auth_subject + session, authz, product, [], auth_subject ) await session.flush() @@ -1451,19 +1420,12 @@ async def test_reordering( product: Product, benefits: list[Benefit], ) -> None: - product = await add_product_benefits( + product = await set_product_benefits( save_fixture, product=product, benefits=benefits, ) - # then - session.expunge_all() - - # load - product_organization_loaded = await product_service.get(session, product.id) - assert product_organization_loaded - ( product, added, @@ -1471,11 +1433,10 @@ async def test_reordering( ) = await product_service.update_benefits( session, authz, - product_organization_loaded, + product, [benefit.id for benefit in benefits[::-1]], auth_subject, ) - await session.flush() assert len(product.product_benefits) == len(benefits) for i, product_benefit in enumerate(product.product_benefits): @@ -1517,18 +1478,11 @@ async def test_add_not_selectable( selectable=False, ) - # then - session.expunge_all() - - # load - product_organization_loaded = await product_service.get(session, product.id) - assert product_organization_loaded - with pytest.raises(PolarRequestValidationError): await product_service.update_benefits( session, authz, - product_organization_loaded, + product, [not_selectable_benefit.id], auth_subject, ) @@ -1555,24 +1509,17 @@ async def test_remove_not_selectable( selectable=False, ) - product = await add_product_benefits( + product = await set_product_benefits( save_fixture, product=product, benefits=[not_selectable_benefit], ) - # then - session.expunge_all() - - # load - product_organization_loaded = await product_service.get(session, product.id) - assert product_organization_loaded - with pytest.raises(PolarRequestValidationError): await product_service.update_benefits( session, authz, - product_organization_loaded, + product, [], auth_subject, ) @@ -1607,19 +1554,12 @@ async def test_add_with_existing_not_selectable( organization=organization, description="SELECTABLE", ) - product = await add_product_benefits( + product = await set_product_benefits( save_fixture, product=product, benefits=[not_selectable_benefit], ) - # then - session.expunge_all() - - # load - product_organization_loaded = await product_service.get(session, product.id) - assert product_organization_loaded - ( _, added, @@ -1627,7 +1567,7 @@ async def test_add_with_existing_not_selectable( ) = await product_service.update_benefits( session, authz, - product_organization_loaded, + product, [not_selectable_benefit.id, selectable_benefit.id], auth_subject, ) diff --git a/server/tests/product/test_endpoints.py b/server/tests/product/test_endpoints.py index 76e65ec900..a7f0f54c76 100644 --- a/server/tests/product/test_endpoints.py +++ b/server/tests/product/test_endpoints.py @@ -16,7 +16,7 @@ from polar.postgres import AsyncSession from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import ( - add_product_benefits, + set_product_benefits, ) @@ -44,7 +44,7 @@ async def test_with_benefits( product: Product, benefits: list[Benefit], ) -> None: - product = await add_product_benefits( + product = await set_product_benefits( save_fixture, product=product, benefits=benefits, @@ -110,7 +110,7 @@ async def test_valid_with_benefits( benefits: list[Benefit], user_organization: UserOrganization, ) -> None: - product = await add_product_benefits( + product = await set_product_benefits( save_fixture, product=product, benefits=benefits, diff --git a/server/tests/subscription/test_service.py b/server/tests/subscription/test_service.py index 1f0634130b..cf9165d6c1 100644 --- a/server/tests/subscription/test_service.py +++ b/server/tests/subscription/test_service.py @@ -30,9 +30,9 @@ from tests.fixtures.auth import AuthSubjectFixture from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import ( - add_product_benefits, create_active_subscription, create_subscription, + set_product_benefits, ) @@ -577,7 +577,7 @@ async def test_incomplete_subscription( ) -> None: enqueue_job_mock = mocker.patch("polar.subscription.service.enqueue_job") - product = await add_product_benefits( + product = await set_product_benefits( save_fixture, product=product, benefits=benefits, @@ -606,7 +606,7 @@ async def test_active_subscription( ) -> None: enqueue_job_mock = mocker.patch("polar.subscription.service.enqueue_job") - product = await add_product_benefits( + product = await set_product_benefits( save_fixture, product=product, benefits=benefits, @@ -650,7 +650,7 @@ async def test_canceled_subscription( ) -> None: enqueue_job_mock = mocker.patch("polar.subscription.service.enqueue_job") - product = await add_product_benefits( + product = await set_product_benefits( save_fixture, product=product, benefits=benefits,