Skip to content

Commit

Permalink
server/checkout: fix custom checkout always adding VAT even on produc…
Browse files Browse the repository at this point in the history
…t without VAT applicable
  • Loading branch information
frankie567 committed Nov 7, 2024
1 parent 2fa4422 commit 61cae5f
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 140 deletions.
7 changes: 7 additions & 0 deletions server/polar/checkout/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,7 @@ async def handle_stripe_success(
customer=stripe_customer_id,
currency=checkout.currency or "usd",
price=stripe_price_id,
automatic_tax=checkout.product.is_tax_applicable,
metadata=metadata,
invoice_metadata={
"payment_intent_id": payment_intent.id,
Expand All @@ -833,6 +834,7 @@ async def handle_stripe_success(
subscription_id=subscription.stripe_subscription_id,
old_price=subscription.price.stripe_price_id,
new_price=stripe_price_id,
automatic_tax=checkout.product.is_tax_applicable,
metadata=metadata,
invoice_metadata={
"payment_intent_id": payment_intent.id,
Expand All @@ -850,6 +852,7 @@ async def handle_stripe_success(
customer=stripe_customer_id,
currency=checkout.currency or "usd",
price=stripe_price_id,
automatic_tax=checkout.product.is_tax_applicable,
metadata={
**metadata,
"payment_intent_id": payment_intent.id,
Expand Down Expand Up @@ -1188,6 +1191,10 @@ async def _update_checkout(
async def _update_checkout_tax(
self, session: AsyncSession, checkout: Checkout
) -> Checkout:
if not checkout.product.is_tax_applicable:
checkout.tax_amount = 0
return checkout

if (
checkout.currency is not None
and checkout.amount is not None
Expand Down
3 changes: 2 additions & 1 deletion server/polar/integrations/stripe/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,7 @@ async def update_out_of_band_subscription(
subscription_id: str,
old_price: str,
new_price: str,
automatic_tax: bool = True,
metadata: dict[str, str] | None = None,
invoice_metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
Expand All @@ -643,7 +644,7 @@ async def update_out_of_band_subscription(
modify_params: stripe_lib.Subscription.ModifyParams = {
"collection_method": "send_invoice",
"days_until_due": 0,
"automatic_tax": {"enabled": True},
"automatic_tax": {"enabled": automatic_tax},
}
if metadata is not None:
modify_params["metadata"] = metadata
Expand Down
40 changes: 12 additions & 28 deletions server/polar/product/service/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import stripe
from sqlalchemy import Select, UnaryExpression, asc, case, desc, func, select
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import contains_eager, joinedload, selectinload

from polar.auth.models import AuthSubject, is_organization, is_user
Expand Down Expand Up @@ -364,7 +363,6 @@ async def update(
update_schema: ProductUpdate,
auth_subject: AuthSubject[User | Organization],
) -> Product:
product = await self.with_organization(session, product)
subject = auth_subject.subject

if not await authz.can(subject, AccessType.write, product):
Expand Down Expand Up @@ -523,24 +521,17 @@ async def update_benefits(
benefits: List[uuid.UUID], # noqa: UP006
auth_subject: AuthSubject[User | Organization],
) -> tuple[Product, set[Benefit], set[Benefit]]:
product = await self.with_organization(session, product)

subject = auth_subject.subject
if not await authz.can(subject, AccessType.write, product):
raise NotPermitted()

previous_benefits = set(product.benefits)
new_benefits: set[Benefit] = set()

nested = await session.begin_nested()

product.product_benefits = []
await session.flush()

new_product_benefits: list[ProductBenefit] = []
for order, benefit_id in enumerate(benefits):
benefit = await benefit_service.get_by_id(session, auth_subject, benefit_id)
if benefit is None:
await nested.rollback()
raise PolarRequestValidationError(
[
{
Expand All @@ -563,9 +554,15 @@ async def update_benefits(
]
)
new_benefits.add(benefit)
product.product_benefits.append(
ProductBenefit(benefit=benefit, order=order)
)
new_product_benefits.append(ProductBenefit(benefit=benefit, order=order))

# Remove all previous benefits: flush to actually remove them
product.product_benefits = []
session.add(product)
await session.flush()

# Set the new benefits
product.product_benefits = new_product_benefits

added_benefits = new_benefits - previous_benefits
deleted_benefits = previous_benefits - new_benefits
Expand Down Expand Up @@ -597,15 +594,6 @@ async def update_benefits(

return product, added_benefits, deleted_benefits

async def with_organization(
self, session: AsyncSession, product: Product
) -> Product:
try:
product.organization
except InvalidRequestError:
await session.refresh(product, {"organization"})
return product

async def _archive(self, product: Product) -> Product:
if product.stripe_product_id is not None:
await stripe_service.archive_product(product.stripe_product_id)
Expand Down Expand Up @@ -675,19 +663,15 @@ async def _send_webhook(
event_type: Literal[WebhookEventType.product_created]
| Literal[WebhookEventType.product_updated],
) -> None:
# load full tier with relations
full_product = await self.get_loaded(session, product.id, allow_deleted=True)
assert full_product

# mypy 1.9 is does not allow us to do
# event = (event_type, subscription)
# directly, even if it could have...
event: WebhookTypeObject | None = None
match event_type:
case WebhookEventType.product_created:
event = (event_type, full_product)
event = (event_type, product)
case WebhookEventType.product_updated:
event = (event_type, full_product)
event = (event_type, product)

if managing_org := await organization_service.get(
session, product.organization_id
Expand Down
6 changes: 3 additions & 3 deletions server/tests/benefit/service/test_benefit_grant.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from polar.subscription.service import subscription as subscription_service
from tests.fixtures.database import SaveFixture
from tests.fixtures.random_objects import (
add_product_benefits,
create_benefit_grant,
create_order,
create_subscription,
set_product_benefits,
)


Expand Down Expand Up @@ -348,7 +348,7 @@ async def test_subscription_scope(
"polar.benefit.service.benefit_grant.enqueue_job"
)

product = await add_product_benefits(
product = await set_product_benefits(
save_fixture, product=product, benefits=benefits
)

Expand Down Expand Up @@ -388,7 +388,7 @@ async def test_outdated_grants(
grant.set_granted()
await save_fixture(grant)

product = await add_product_benefits(
product = await set_product_benefits(
save_fixture, product=product, benefits=benefits[1:]
)

Expand Down
25 changes: 10 additions & 15 deletions server/tests/checkout/legacy/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
from polar.exceptions import PolarRequestValidationError, ResourceNotFound
from polar.integrations.stripe.schemas import ProductType
from polar.integrations.stripe.service import StripeService
from polar.models import Organization, Product, User
from polar.models import Product, User
from polar.postgres import AsyncSession
from tests.fixtures.auth import AuthSubjectFixture
from tests.fixtures.database import SaveFixture
from tests.fixtures.random_objects import (
add_product_benefits,
create_benefit,
set_product_benefits,
)

SUCCESS_URL = Url("https://example.com/success")
Expand Down Expand Up @@ -128,7 +127,7 @@ async def test_valid_anonymous(
price.stripe_price_id,
str(SUCCESS_URL),
is_subscription=True,
is_tax_applicable=False,
is_tax_applicable=True,
metadata={
"type": ProductType.product,
"product_id": str(product.id),
Expand Down Expand Up @@ -180,7 +179,7 @@ async def test_valid_user_cookie(
price.stripe_price_id,
str(SUCCESS_URL),
is_subscription=True,
is_tax_applicable=False,
is_tax_applicable=True,
customer="STRIPE_CUSTOMER_ID",
metadata={
"type": ProductType.product,
Expand Down Expand Up @@ -235,7 +234,7 @@ async def test_valid_token(
price.stripe_price_id,
str(SUCCESS_URL),
is_subscription=True,
is_tax_applicable=False,
is_tax_applicable=True,
metadata={
"type": ProductType.product,
"product_id": str(product.id),
Expand Down Expand Up @@ -289,7 +288,7 @@ async def test_valid_token_customer_email(
price.stripe_price_id,
str(SUCCESS_URL),
is_subscription=True,
is_tax_applicable=False,
is_tax_applicable=True,
customer_email="[email protected]",
metadata={
"type": ProductType.product,
Expand All @@ -303,22 +302,18 @@ async def test_valid_token_customer_email(
},
)

async def test_valid_tax_applicable(
async def test_valid_tax_not_applicable(
self,
auth_subject: AuthSubject[Anonymous],
session: AsyncSession,
save_fixture: SaveFixture,
product: Product,
stripe_service_mock: MagicMock,
organization: Organization,
) -> None:
applicable_tax_benefit = await create_benefit(
save_fixture, is_tax_applicable=True, organization=organization
)
product = await add_product_benefits(
product = await set_product_benefits(
save_fixture,
product=product,
benefits=[applicable_tax_benefit],
benefits=[],
)

create_checkout_session_mock: MagicMock = (
Expand Down Expand Up @@ -349,7 +344,7 @@ async def test_valid_tax_applicable(
price.stripe_price_id,
str(SUCCESS_URL),
is_subscription=True,
is_tax_applicable=True,
is_tax_applicable=False,
metadata={
"type": ProductType.product,
"product_id": str(product.id),
Expand Down
59 changes: 59 additions & 0 deletions server/tests/checkout/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,26 @@ async def checkout_custom_fields(
return await create_checkout(save_fixture, price=product_custom_fields.prices[0])


@pytest_asyncio.fixture
async def product_tax_not_applicable(
save_fixture: SaveFixture, organization: Organization
) -> Product:
return await create_product(
save_fixture,
organization=organization,
tax_applicable=False,
)


@pytest_asyncio.fixture
async def checkout_tax_not_applicable(
save_fixture: SaveFixture, product_tax_not_applicable: Product
) -> Checkout:
return await create_checkout(
save_fixture, price=product_tax_not_applicable.prices[0]
)


@pytest.mark.asyncio
@pytest.mark.skip_db_asserts
class TestCreate:
Expand Down Expand Up @@ -749,6 +769,30 @@ async def test_valid_embed_origin(

assert checkout.embed_origin == "https://example.com"

async def test_valid_tax_not_applicable(
self,
session: AsyncSession,
auth_subject: AuthSubject[User | Organization],
user_organization: UserOrganization,
product_tax_not_applicable: Product,
) -> None:
price = product_tax_not_applicable.prices[0]
assert isinstance(price, ProductPriceFixed)

checkout = await checkout_service.create(
session,
CheckoutCreate(
payment_processor=PaymentProcessor.stripe,
product_price_id=price.id,
customer_billing_address=Address.model_validate({"country": "FR"}),
),
auth_subject,
)

assert checkout.tax_amount == 0
assert checkout.customer_billing_address is not None
assert checkout.customer_billing_address.country == "FR"


@pytest.mark.asyncio
@pytest.mark.skip_db_asserts
Expand Down Expand Up @@ -1345,6 +1389,21 @@ async def test_valid_embed_origin(

assert checkout.embed_origin == "https://example.com"

async def test_valid_tax_not_applicable(
self, session: AsyncSession, checkout_tax_not_applicable: Checkout
) -> None:
checkout = await checkout_service.update(
session,
checkout_tax_not_applicable,
CheckoutUpdate(
customer_billing_address=Address.model_validate({"country": "FR"}),
),
)

assert checkout.tax_amount == 0
assert checkout.customer_billing_address is not None
assert checkout.customer_billing_address.country == "FR"


@pytest.mark.asyncio
@pytest.mark.skip_db_asserts
Expand Down
12 changes: 10 additions & 2 deletions server/tests/fixtures/random_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,12 +581,13 @@ async def create_product(
| tuple[ProductPriceType, SubscriptionRecurringInterval | None]
] = [(1000, ProductPriceType.recurring, SubscriptionRecurringInterval.month)],
attached_custom_fields: Sequence[tuple[CustomField, bool]] = [],
tax_applicable: bool = True,
) -> Product:
product = Product(
name=name,
description="Description",
is_archived=is_archived,
organization_id=organization.id,
organization=organization,
stripe_product_id=rstr("PRODUCT_ID"),
all_prices=[],
prices=[],
Expand Down Expand Up @@ -638,6 +639,12 @@ async def create_product(
product.prices.append(product_price)
product.all_prices.append(product_price)

if tax_applicable:
benefit = await create_benefit(
save_fixture, organization=organization, is_tax_applicable=True
)
product.product_benefits.append(ProductBenefit(benefit=benefit, order=0))

return product


Expand Down Expand Up @@ -762,13 +769,14 @@ async def create_benefit(
return benefit


async def add_product_benefits(
async def set_product_benefits(
save_fixture: SaveFixture,
*,
product: Product,
benefits: list[Benefit],
) -> Product:
product.product_benefits = []
await save_fixture(product)
for order, benefit in enumerate(benefits):
product.product_benefits.append(ProductBenefit(benefit=benefit, order=order))
await save_fixture(product)
Expand Down
Loading

0 comments on commit 61cae5f

Please sign in to comment.