From 2700bf3ded756c51773b127c2bce22b67ba57417 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 2 Jan 2025 17:11:32 +0100 Subject: [PATCH] server/customer: harden the rules for linking customers with users --- ...56_make_usercustomer_customer_id_unique.py | 42 +++++++++++++++++++ server/polar/checkout/service.py | 6 ++- server/polar/customer/service.py | 2 +- server/polar/models/user_customer.py | 8 ++-- server/polar/user/service/user.py | 4 +- server/polar/user/tasks.py | 3 +- server/tests/checkout/test_service.py | 38 ++++++++++++++++- server/tests/fixtures/random_objects.py | 25 +++++------ 8 files changed, 104 insertions(+), 24 deletions(-) create mode 100644 server/migrations/versions/2025-01-02-1656_make_usercustomer_customer_id_unique.py diff --git a/server/migrations/versions/2025-01-02-1656_make_usercustomer_customer_id_unique.py b/server/migrations/versions/2025-01-02-1656_make_usercustomer_customer_id_unique.py new file mode 100644 index 0000000000..863c5cb235 --- /dev/null +++ b/server/migrations/versions/2025-01-02-1656_make_usercustomer_customer_id_unique.py @@ -0,0 +1,42 @@ +"""Make UserCustomer.customer_id unique + +Revision ID: 58d5e316549f +Revises: 6a6e872cbea5 +Create Date: 2025-01-02 16:56:16.433658 + +""" + +import sqlalchemy as sa +from alembic import op + +# Polar Custom Imports + +# revision identifiers, used by Alembic. +revision = "58d5e316549f" +down_revision = "6a6e872cbea5" +branch_labels: tuple[str] | None = None +depends_on: tuple[str] | None = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint( + "user_customers_user_id_customer_id_key", "user_customers", type_="unique" + ) + op.create_unique_constraint( + op.f("user_customers_customer_id_key"), "user_customers", ["customer_id"] + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint( + op.f("user_customers_customer_id_key"), "user_customers", type_="unique" + ) + op.create_unique_constraint( + "user_customers_user_id_customer_id_key", + "user_customers", + ["user_id", "customer_id"], + ) + # ### end Alembic commands ### diff --git a/server/polar/checkout/service.py b/server/polar/checkout/service.py index 888ae11db9..2f773c9dfe 100644 --- a/server/polar/checkout/service.py +++ b/server/polar/checkout/service.py @@ -1665,7 +1665,11 @@ async def _create_or_update_customer( await session.flush() if is_direct_user(auth_subject): - await customer_service.link_user(session, customer, auth_subject.subject) + user = auth_subject.subject + if user.email_verified and user.email.lower() == customer.email.lower(): + await customer_service.link_user( + session, customer, auth_subject.subject + ) return customer diff --git a/server/polar/customer/service.py b/server/polar/customer/service.py index 5dc38d7d51..798cfc6802 100644 --- a/server/polar/customer/service.py +++ b/server/polar/customer/service.py @@ -279,7 +279,7 @@ async def link_user( user_id=user.id, customer_id=customer.id ) insert_statement = insert_statement.on_conflict_do_nothing( - index_elements=["user_id", "customer_id"] + index_elements=["customer_id"] ) await session.execute(insert_statement) diff --git a/server/polar/models/user_customer.py b/server/polar/models/user_customer.py index d25dbe53be..668a37e867 100644 --- a/server/polar/models/user_customer.py +++ b/server/polar/models/user_customer.py @@ -1,6 +1,6 @@ from uuid import UUID -from sqlalchemy import ForeignKey, UniqueConstraint, Uuid +from sqlalchemy import ForeignKey, Uuid from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship from polar.kit.db.models.base import RecordModel @@ -11,13 +11,15 @@ class UserCustomer(RecordModel): __tablename__ = "user_customers" - __table_args__ = (UniqueConstraint("user_id", "customer_id"),) user_id: Mapped[UUID] = mapped_column( Uuid, ForeignKey("users.id", ondelete="cascade"), nullable=False ) customer_id: Mapped[UUID] = mapped_column( - Uuid, ForeignKey("customers.id", ondelete="cascade"), nullable=False + Uuid, + ForeignKey("customers.id", ondelete="cascade"), + nullable=False, + unique=True, # A customer can only be associated with one user ) @declared_attr diff --git a/server/polar/user/service/user.py b/server/polar/user/service/user.py index 71b7a7b532..535b1e5b34 100644 --- a/server/polar/user/service/user.py +++ b/server/polar/user/service/user.py @@ -137,9 +137,9 @@ async def link_customers(self, session: AsyncSession, user: User) -> None: Customer.id, func.uuid_generate_v4(), func.now(), - ).where(Customer.email == user.email), + ).where(func.lower(Customer.email) == user.email.lower()), ) - .on_conflict_do_nothing(index_elements=["user_id", "customer_id"]) + .on_conflict_do_nothing(index_elements=["customer_id"]) ) await session.execute(statement) diff --git a/server/polar/user/tasks.py b/server/polar/user/tasks.py index be3db75b85..4af8bfa4f7 100644 --- a/server/polar/user/tasks.py +++ b/server/polar/user/tasks.py @@ -28,4 +28,5 @@ async def user_on_after_signup( if user is None: raise UserDoesNotExist(user_id) - await user_service.link_customers(session, user) + if user.email_verified: + await user_service.link_customers(session, user) diff --git a/server/tests/checkout/test_service.py b/server/tests/checkout/test_service.py index f4d46a8c8b..2aeafb52db 100644 --- a/server/tests/checkout/test_service.py +++ b/server/tests/checkout/test_service.py @@ -2219,7 +2219,7 @@ async def test_valid_stripe_existing_customer_email( stripe_service_mock.update_customer.assert_called_once() @pytest.mark.auth(AuthSubjectFixture(subject="user_second")) - async def test_link_customer_to_authenticated_user( + async def test_link_customer_to_authenticated_user_different_email( self, stripe_service_mock: MagicMock, session: AsyncSession, @@ -2248,6 +2248,42 @@ async def test_link_customer_to_authenticated_user( ), ) + assert checkout.customer is not None + linked_customer = await customer_service.get_by_id_and_user( + session, checkout.customer.id, auth_subject.subject + ) + assert linked_customer is None + + @pytest.mark.auth(AuthSubjectFixture(subject="user_second")) + async def test_link_customer_to_authenticated_same_email( + self, + stripe_service_mock: MagicMock, + session: AsyncSession, + locker: Locker, + auth_subject: AuthSubject[User], + checkout_one_time_fixed: Checkout, + ) -> None: + stripe_service_mock.create_customer.return_value = SimpleNamespace( + id="STRIPE_CUSTOMER_ID" + ) + stripe_service_mock.create_payment_intent.return_value = SimpleNamespace( + client_secret="CLIENT_SECRET", status="succeeded" + ) + checkout = await checkout_service.confirm( + session, + locker, + auth_subject, + checkout_one_time_fixed, + CheckoutConfirmStripe.model_validate( + { + "confirmation_token_id": "CONFIRMATION_TOKEN_ID", + "customer_name": "Customer Name", + "customer_email": auth_subject.subject.email, + "customer_billing_address": {"country": "FR"}, + } + ), + ) + assert checkout.customer is not None linked_customer = await customer_service.get_by_id_and_user( session, checkout.customer.id, auth_subject.subject diff --git a/server/tests/fixtures/random_objects.py b/server/tests/fixtures/random_objects.py index e728b8fa4d..66f0173492 100644 --- a/server/tests/fixtures/random_objects.py +++ b/server/tests/fixtures/random_objects.py @@ -289,19 +289,15 @@ async def user_github_oauth( return await create_user_github_oauth(save_fixture, user) -@pytest_asyncio.fixture -async def user( - save_fixture: SaveFixture, -) -> User: - return await create_user(save_fixture) - - async def create_user( - save_fixture: SaveFixture, stripe_customer_id: str | None = None + save_fixture: SaveFixture, + stripe_customer_id: str | None = None, + email_verified: bool = True, ) -> User: user = User( id=uuid.uuid4(), email=rstr("test") + "@example.com", + email_verified=email_verified, avatar_url="https://avatars.githubusercontent.com/u/47952?v=4", oauth_accounts=[], stripe_customer_id=stripe_customer_id, @@ -310,15 +306,14 @@ async def create_user( return user +@pytest_asyncio.fixture +async def user(save_fixture: SaveFixture) -> User: + return await create_user(save_fixture) + + @pytest_asyncio.fixture async def user_second(save_fixture: SaveFixture) -> User: - user = User( - id=uuid.uuid4(), - email=rstr("test") + "@example.com", - avatar_url="https://avatars.githubusercontent.com/u/47952?v=4", - ) - await save_fixture(user) - return user + return await create_user(save_fixture) @pytest_asyncio.fixture