Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
zegl committed Nov 2, 2023
1 parent 7bee0d3 commit 757727b
Show file tree
Hide file tree
Showing 15 changed files with 205 additions and 29 deletions.
34 changes: 34 additions & 0 deletions server/polar/kit/db/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,37 @@ class RecordModel(TimestampedModel, MappedAsDataclass):
# default=generate_uuid,
insert_default=generate_uuid,
)


# same as above, but without dataclass
class RecordModelNoDataClass(
DeclarativeBase,
ActiveRecordMixin,
SerializeMixin,
):
__abstract__ = True

metadata = my_metadata

created_at: Mapped[datetime] = mapped_column(
TIMESTAMP(timezone=True),
nullable=False,
default=utc_now,
)

modified_at: Mapped[datetime | None] = mapped_column(
TIMESTAMP(timezone=True),
onupdate=utc_now,
nullable=True,
default=None,
)

deleted_at: Mapped[datetime | None] = mapped_column(
TIMESTAMP(timezone=True), nullable=True, default=None
)

id: MappedColumn[UUID] = mapped_column(
PostgresUUID,
primary_key=True,
default=generate_uuid,
)
150 changes: 140 additions & 10 deletions server/polar/kit/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
from sqlalchemy.orm import InstrumentedAttribute
from typing_extensions import deprecated

from polar.kit.db.models.base import RecordModelNoDataClass
from polar.kit.utils import utc_now

from .db.models import RecordModel
from .db.postgres import AsyncSession, sql
from .schemas import Schema

ModelType = TypeVar("ModelType", bound=RecordModel)
ModelTypeNoDataClass = TypeVar("ModelTypeNoDataClass", bound=RecordModelNoDataClass)
CreateSchemaType = TypeVar("CreateSchemaType", bound=Schema)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=Schema)
SchemaType = TypeVar("SchemaType", bound=Schema)
Expand Down Expand Up @@ -59,16 +61,6 @@ class ResourceService(
# no state to retain. Unable to achieve this with mapping the model
# and schema as class attributes though without breaking typing.

# async def create(
# self,
# session: AsyncSession,
# create_schema: CreateSchemaType,
# autocommit: bool = True,
# ) -> ModelType:
# return await self.model.create(
# session, **create_schema.dict(), autocommit=autocommit
# )

# TODO: Investigate new bulk methods in SQLALchemy 2.0 for upsert_many
async def upsert_many(
self,
Expand Down Expand Up @@ -172,3 +164,141 @@ async def update(
include=include, exclude=exclude, exclude_unset=exclude_unset
),
)


class ResourceServiceNoDataClass(
Generic[ModelTypeNoDataClass, CreateSchemaType, UpdateSchemaType],
):
def __init__(self, model: type[ModelTypeNoDataClass]) -> None:
self.model = model

async def get(
self, session: AsyncSession, id: UUID, allow_deleted: bool = False
) -> ModelTypeNoDataClass | None:
query = sql.select(self.model).where(self.model.id == id)
if not allow_deleted:
query = query.where(self.model.deleted_at.is_(None))
res = await session.execute(query)
return res.scalars().unique().one_or_none()

async def get_by(
self, session: AsyncSession, **clauses: Any
) -> ModelTypeNoDataClass | None:
query = sql.select(self.model).filter_by(**clauses)
res = await session.execute(query)
return res.scalars().unique().one_or_none()

async def soft_delete(self, session: AsyncSession, id: UUID) -> None:
stmt = (
sql.update(self.model)
.where(self.model.id == id, self.model.deleted_at.is_(None))
.values(
deleted_at=utc_now(),
)
)
await session.execute(stmt)
await session.commit()

# TODO: Investigate new bulk methods in SQLALchemy 2.0 for upsert_many
async def upsert_many(
self,
session: AsyncSession,
create_schemas: list[CreateSchemaType],
constraints: list[InstrumentedAttribute[Any]],
mutable_keys: set[str],
autocommit: bool = True,
) -> Sequence[ModelTypeNoDataClass]:
return await self._db_upsert_many(
session,
create_schemas,
constraints=constraints,
mutable_keys=mutable_keys,
autocommit=autocommit,
)

async def upsert(
self,
session: AsyncSession,
create_schema: CreateSchemaType,
constraints: list[InstrumentedAttribute[Any]],
mutable_keys: set[str],
autocommit: bool = True,
) -> ModelTypeNoDataClass:
return await self._db_upsert(
session,
create_schema,
constraints=constraints,
mutable_keys=mutable_keys,
autocommit=autocommit,
)

async def _db_upsert_many(
self,
session: AsyncSession,
objects: list[CreateSchemaType],
constraints: list[InstrumentedAttribute[Any]],
mutable_keys: set[str],
autocommit: bool = True,
) -> Sequence[ModelTypeNoDataClass]:
values = [obj.dict() for obj in objects]
if not values:
raise ValueError("Zero values provided")

insert_stmt = sql.insert(self.model).values(values)

# Update the insert statement with what to update on conflict, i.e mutable keys.
upsert_stmt = (
insert_stmt.on_conflict_do_update(
index_elements=constraints,
set_={k: getattr(insert_stmt.excluded, k) for k in mutable_keys},
)
.returning(self.model)
.execution_options(populate_existing=True)
)

res = await session.execute(upsert_stmt)
instances = res.scalars().all()
if autocommit:
await session.commit()
return instances

async def _db_upsert(
self,
session: AsyncSession,
obj: CreateSchemaType,
constraints: list[InstrumentedAttribute[Any]],
mutable_keys: set[str],
autocommit: bool = True,
) -> ModelTypeNoDataClass:
"""
Usage of upsert is deprecated.
If you need an upsert, add the functionality in the service instead of relying
active record.
"""

upserted: Sequence[ModelTypeNoDataClass] = await self._db_upsert_many(
session,
[obj],
constraints=constraints,
mutable_keys=mutable_keys,
autocommit=autocommit,
)
return upserted[0]

async def update(
self,
session: AsyncSession,
source: ModelTypeNoDataClass,
update_schema: UpdateSchemaType,
include: set[str] | None = None,
exclude: set[str] | None = None,
exclude_unset: bool = False,
autocommit: bool = True,
) -> ModelTypeNoDataClass:
return await source.update(
session,
autocommit=autocommit,
**update_schema.dict(
include=include, exclude=exclude, exclude_unset=exclude_unset
),
)
5 changes: 4 additions & 1 deletion server/polar/models/magic_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ class MagicLink(RecordModel, MappedAsDataclass, kw_only=True):

token_hash: Mapped[str] = mapped_column(String, index=True, nullable=False)
expires_at: Mapped[datetime] = mapped_column(
TIMESTAMP(timezone=True), nullable=False, default=get_expires_at
TIMESTAMP(timezone=True),
nullable=False,
default=None,
insert_default=get_expires_at,
)

user_email: Mapped[str] = mapped_column(String, nullable=False)
Expand Down
3 changes: 2 additions & 1 deletion server/polar/models/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)

from polar.kit.db.models import RecordModel
from polar.kit.db.models.base import RecordModelNoDataClass
from polar.kit.extensions.sqlalchemy import PostgresUUID

if TYPE_CHECKING:
Expand All @@ -29,7 +30,7 @@ class SubscriptionStatus(StrEnum):
unpaid = "unpaid"


class Subscription(RecordModel, MappedAsDataclass, kw_only=True):
class Subscription(RecordModelNoDataClass):
__tablename__ = "subscriptions"

stripe_subscription_id: Mapped[str] = mapped_column(
Expand Down
3 changes: 2 additions & 1 deletion server/polar/models/subscription_benefit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from polar.exceptions import PolarError
from polar.kit.db.models import RecordModel
from polar.kit.db.models.base import RecordModelNoDataClass
from polar.kit.extensions.sqlalchemy import PostgresUUID

if TYPE_CHECKING:
Expand Down Expand Up @@ -56,7 +57,7 @@ class SubscriptionBenefitBuiltinProperties(SubscriptionBenefitProperties):
M = TypeVar("M", bound=SubscriptionBenefitProperties)


class SubscriptionBenefit(RecordModel, MappedAsDataclass, kw_only=True):
class SubscriptionBenefit(RecordModelNoDataClass):
__tablename__ = "subscription_benefits"

type: Mapped[SubscriptionBenefitType] = mapped_column(
Expand Down
3 changes: 2 additions & 1 deletion server/polar/models/subscription_benefit_grant.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
)

from polar.kit.db.models import RecordModel
from polar.kit.db.models.base import RecordModelNoDataClass
from polar.kit.extensions.sqlalchemy import PostgresUUID

if TYPE_CHECKING:
from polar.models import Subscription, SubscriptionBenefit


class SubscriptionBenefitGrant(RecordModel, MappedAsDataclass, kw_only=True):
class SubscriptionBenefitGrant(RecordModelNoDataClass):
__tablename__ = "subscription_benefit_grants"

granted_at: Mapped[datetime | None] = mapped_column(
Expand Down
4 changes: 2 additions & 2 deletions server/polar/models/subscription_tier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)

from polar.kit.db.models import RecordModel
from polar.kit.db.models.base import RecordModelNoDataClass
from polar.kit.extensions.sqlalchemy import PostgresUUID

if TYPE_CHECKING:
Expand All @@ -30,7 +31,7 @@ class SubscriptionTierType(StrEnum):
business = "business"


class SubscriptionTier(RecordModel, MappedAsDataclass, kw_only=True):
class SubscriptionTier(RecordModelNoDataClass):
__tablename__ = "subscription_tiers"

type: Mapped[SubscriptionTierType] = mapped_column(
Expand Down Expand Up @@ -84,7 +85,6 @@ def subscription_tier_benefits(cls) -> Mapped[list["SubscriptionTierBenefit"]]:
benefits: AssociationProxy[list["SubscriptionBenefit"]] = association_proxy(
"subscription_tier_benefits",
"subscription_benefit",
init=False,
)

@property
Expand Down
5 changes: 3 additions & 2 deletions server/polar/models/subscription_tier_benefit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
)

from polar.kit.db.models import RecordModel
from polar.kit.db.models.base import RecordModelNoDataClass
from polar.kit.extensions.sqlalchemy import PostgresUUID

if TYPE_CHECKING:
from polar.models import SubscriptionBenefit


class SubscriptionTierBenefit(RecordModel, MappedAsDataclass, kw_only=True):
class SubscriptionTierBenefit(RecordModelNoDataClass):
__tablename__ = "subscription_tier_benefits"
__table_args__ = (UniqueConstraint("subscription_tier_id", "order"),)

Expand All @@ -36,4 +37,4 @@ class SubscriptionTierBenefit(RecordModel, MappedAsDataclass, kw_only=True):

@declared_attr
def subscription_benefit(cls) -> Mapped["SubscriptionBenefit"]:
return relationship("SubscriptionBenefit", lazy="joined")
return relationship("SubscriptionBenefit", lazy="raise")
5 changes: 4 additions & 1 deletion server/polar/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ class User(RecordModel, MappedAsDataclass, kw_only=True):
)

profile: Mapped[dict[str, Any] | None] = mapped_column(
JSONB, default=None, nullable=True, insert_default={}
JSONB,
nullable=True,
default=None,
insert_default=dict,
)

@declared_attr
Expand Down
6 changes: 4 additions & 2 deletions server/polar/subscription/service/subscription_tier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from polar.integrations.stripe.service import stripe as stripe_service
from polar.kit.db.postgres import AsyncSession
from polar.kit.pagination import PaginationParams, paginate
from polar.kit.services import ResourceService
from polar.kit.services import ResourceService, ResourceServiceNoDataClass
from polar.models import (
Account,
Organization,
Expand Down Expand Up @@ -85,7 +85,9 @@ def __init__(self, organization_id: uuid.UUID) -> None:


class SubscriptionTierService(
ResourceService[SubscriptionTier, SubscriptionTierCreate, SubscriptionTierUpdate]
ResourceServiceNoDataClass[
SubscriptionTier, SubscriptionTierCreate, SubscriptionTierUpdate
]
):
async def create(
self,
Expand Down
3 changes: 1 addition & 2 deletions server/polar/user/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class UserBase(Schema):
username: str = Field(..., max_length=50)
email: EmailStr
avatar_url: str | None
profile: dict[str, Any]

class Config:
orm_mode = True
Expand All @@ -47,7 +46,7 @@ class UserRead(UserBase, TimestampedSchema):
accepted_terms_of_service: bool
email_newsletters_and_changelogs: bool
email_promotions_and_events: bool
oauth_accounts: list[OAuthAccountRead]
oauth_accounts: list[OAuthAccountRead] = []


# TODO: remove
Expand Down
2 changes: 1 addition & 1 deletion server/polar/user/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def get_by_email_or_signup(
return user

async def signup_by_email(self, session: AsyncSession, email: str) -> User:
user = User(username=email, email=email)
user = User(username=email, email=email, profile={})
session.add(user)
await session.commit()

Expand Down
2 changes: 2 additions & 0 deletions server/tests/fixtures/random_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ async def create_user(
username=rstr("testuser"),
email=rstr("test") + "@example.com",
avatar_url="https://avatars.githubusercontent.com/u/47952?v=4",
profile={},
).save(session=session)

await session.commit()
Expand All @@ -160,6 +161,7 @@ async def user_second(
username=rstr("testuser"),
email=rstr("test") + "@example.com",
avatar_url="https://avatars.githubusercontent.com/u/47952?v=4",
profile={},
).save(
session=session,
)
Expand Down
Loading

0 comments on commit 757727b

Please sign in to comment.