Skip to content

Commit

Permalink
fix transaction method to work on instance and class (#263)
Browse files Browse the repository at this point in the history
Changes:
- fix transaction method
- add helper TransactionCallProtocol
  • Loading branch information
devkral authored Jan 18, 2025
1 parent 6b80a02 commit fdb6ccf
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 20 deletions.
1 change: 1 addition & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ hide:
- Invalidating caused schema errors.
- ManyToMany and ForeignKey fields didn't worked when referencing tenant models.
- ManyToMany fields didn't worked when specified on tenant models.
- Fix transaction method to work on instance and class.

### BREAKING

Expand Down
6 changes: 3 additions & 3 deletions edgy/contrib/multi_tenancy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def real_add_to_registry(cls, **kwargs: Any) -> type["BaseModelType"]:
and not cls.meta.abstract
and not cls.__is_proxy_model__
):
assert cls.__reflected__ is False, (
"Reflected models are not compatible with multi_tenancy"
)
assert (
cls.__reflected__ is False
), "Reflected models are not compatible with multi_tenancy"

if not cls.meta.register_default:
# remove from models
Expand Down
6 changes: 3 additions & 3 deletions edgy/core/connection/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,9 @@ def callback(model_class: type["BaseModelType"]) -> None:
if "content_type" in model_class.meta.fields:
return
related_name = f"reverse_{model_class.__name__.lower()}"
assert related_name not in real_content_type.meta.fields, (
f"duplicate model name: {model_class.__name__}"
)
assert (
related_name not in real_content_type.meta.fields
), f"duplicate model name: {model_class.__name__}"

field_args: dict[str, Any] = {
"name": "content_type",
Expand Down
8 changes: 8 additions & 0 deletions edgy/core/db/models/metaclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from edgy.exceptions import ImproperlyConfigured, TableBuildError

if TYPE_CHECKING:
from databasez.core.transaction import Transaction

from edgy.core.connection import Database
from edgy.core.db.models import Model
from edgy.core.db.models.types import BaseModelType
Expand Down Expand Up @@ -879,6 +881,12 @@ def signals(cls) -> signals_module.Broadcaster:
meta: MetaInfo = cls.meta
return meta.signals

def transaction(cls, *, force_rollback: bool = False, **kwargs: Any) -> Transaction:
"""Return database transaction for the assigned database"""
return cast(
"Transaction", cls.database.transaction(force_rollback=force_rollback, **kwargs)
)

def table_schema(
cls,
schema: Union[str, None] = None,
Expand Down
14 changes: 12 additions & 2 deletions edgy/core/db/models/mixins/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def _set_related_name_for_foreign_keys(
class DatabaseMixin:
_removed_copy_keys: ClassVar[set[str]] = _removed_copy_keys

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.__dict__["transaction"] = self.not_set_transaction

@classmethod
def real_add_to_registry(
cls: type[BaseModelType],
Expand Down Expand Up @@ -809,8 +813,14 @@ def _get_indexes(cls, index: Index) -> Optional[sqlalchemy.Index]:
),
)

def transaction(self, *, force_rollback: bool = False, **kwargs: Any) -> Transaction:
"""Return database transaction for the assigned database"""
def not_set_transaction(
self=None, *, force_rollback: bool = False, **kwargs: Any
) -> Transaction:
"""
Return database transaction for the assigned database.
This method is automatically assigned to transaction masking the metaclass transaction for instances.
"""
return cast(
"Transaction", self.database.transaction(force_rollback=force_rollback, **kwargs)
)
15 changes: 3 additions & 12 deletions edgy/core/db/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,17 @@

from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Optional,
Union,
)
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union

if TYPE_CHECKING:
import sqlalchemy
from databasez.core.transaction import Transaction

from edgy.core.connection.database import Database
from edgy.core.db.models.base import BaseModel
from edgy.core.db.models.managers import BaseManager
from edgy.core.db.models.metaclasses import MetaInfo
from edgy.core.db.querysets.base import QuerySet
from edgy.protocols.transaction_call import TransactionCallProtocol


class DescriptiveMeta:
Expand Down Expand Up @@ -59,6 +53,7 @@ class BaseModelType(ABC):
query_related: ClassVar[BaseManager]
meta: ClassVar[MetaInfo]
Meta: ClassVar[DescriptiveMeta] = DescriptiveMeta()
transaction: ClassVar[TransactionCallProtocol]

__parent__: ClassVar[Union[type[BaseModelType], None]] = None
__is_proxy_model__: ClassVar[bool] = False
Expand All @@ -80,10 +75,6 @@ def identifying_db_fields(self) -> Any:
def can_load(self) -> bool:
"""identifying_db_fields are completely specified."""

@abstractmethod
def transaction(self, *, force_rollback: bool = False, **kwargs: Any) -> Transaction:
"""Return database transaction for the assigned database."""

@abstractmethod
def get_columns_for_name(self, name: str) -> Sequence[sqlalchemy.Column]:
"""Helper for retrieving columns from field name."""
Expand Down
10 changes: 10 additions & 0 deletions edgy/protocols/transaction_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol

if TYPE_CHECKING:
from databasez.core.transaction import Transaction


class TransactionCallProtocol(Protocol):
def __call__(instance: Any, *, force_rollback: bool = False, **kwargs: Any) -> Transaction: ...
6 changes: 6 additions & 0 deletions tests/models/test_model_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def test_model_class():
assert isinstance(User.query.meta.fields["name"], Field)


def test_transactions():
user = User(id=1)
User.transaction()
user.transaction()


def test_model_pk():
user = User(pk=1)
assert user.pk == 1
Expand Down

0 comments on commit fdb6ccf

Please sign in to comment.