From 6b80a027b5675cab2ed572d68d779b3150971e15 Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 13 Jan 2025 17:52:45 +0100 Subject: [PATCH] fix copying registry & models. Fix relationships among tenancy objects. (#259) Changes: - fix copying registry+tests - M2M-Field `create_through_model` allows now the keyword only argument `replace_related_field`. - `add_to_registry` has now an additional keyword `replace_related_field_m2m` for seperate controlling the `create_through_model` registration logic. - `add_to_registry` has at most one positional argument. It was intended this way but not enforced. - `create_edgy_model` passes through additional keyword arguments to the edgy model class. - add on_conflict for handling model conflicts - fix invalidation causing _db_schemas removed - instead of passing down keyword arguments from create_edgy_model to type add an argument matching the other behaviour - fix foreign keys with tenancy --- docs/fields.md | 4 +- docs/models.md | 14 +- docs/release-notes.md | 35 +++ docs_src/models/on_conflict.py | 19 ++ edgy/__init__.py | 2 +- edgy/contrib/autoreflection/metaclasses.py | 12 +- edgy/contrib/autoreflection/models.py | 11 +- edgy/contrib/multi_tenancy/base.py | 26 +- edgy/contrib/multi_tenancy/metaclasses.py | 42 ++- edgy/core/connection/registry.py | 81 ++--- edgy/core/connection/schemas.py | 23 +- edgy/core/db/context_vars.py | 3 + edgy/core/db/fields/base.py | 8 +- edgy/core/db/fields/many_to_many.py | 97 +++++- edgy/core/db/fields/types.py | 5 +- edgy/core/db/models/metaclasses.py | 32 +- edgy/core/db/models/mixins/db.py | 294 ++++++++++++++---- edgy/core/db/models/mixins/reflection.py | 10 +- edgy/core/db/models/types.py | 4 +- edgy/core/db/relationships/related_field.py | 1 + edgy/core/db/relationships/relation.py | 10 +- edgy/core/utils/models.py | 5 +- edgy/exceptions.py | 3 + .../cli/custom_multidb_copied_registry/README | 1 + .../alembic.ini.mako | 50 +++ .../cli/custom_multidb_copied_registry/env.py | 142 +++++++++ .../script.py.mako | 44 +++ tests/cli/test_multidb_templates.py | 11 +- .../autoreflection/test_reflecting_models.py | 63 ++++ tests/contrib/multi_tenancy/test_migrate.py | 35 +++ tests/contrib/multi_tenancy/test_mt_models.py | 24 ++ .../multi_tenancy/test_tenant_models_using.py | 53 +++- tests/metaclass/test_meta_errors.py | 43 ++- tests/models/test_model_copying.py | 175 +++++++++++ tests/registry/test_registry_copying.py | 98 ++++++ tests/test_migrate.py | 21 +- 36 files changed, 1311 insertions(+), 190 deletions(-) create mode 100644 docs_src/models/on_conflict.py create mode 100644 tests/cli/custom_multidb_copied_registry/README create mode 100644 tests/cli/custom_multidb_copied_registry/alembic.ini.mako create mode 100644 tests/cli/custom_multidb_copied_registry/env.py create mode 100644 tests/cli/custom_multidb_copied_registry/script.py.mako create mode 100644 tests/models/test_model_copying.py create mode 100644 tests/registry/test_registry_copying.py diff --git a/docs/fields.md b/docs/fields.md index 8341c020..13c45e2a 100644 --- a/docs/fields.md +++ b/docs/fields.md @@ -505,7 +505,7 @@ The reverse end of a `ForeignKey` is a [Many to one relation](./queries/many-to- ##### Parameters * `to` - A string [model](./models.md) name or a class object of that same model. -* `target_registry` - Registry where the model callback is installed if `to` is a string. +* `target_registry` - Registry where the model callback is installed if `to` is a string. Defaults to the field owner registry. * `related_name` - The name to use for the relation from the related object back to this one. Can be set to `False` to disable a reverse connection. Note: Setting to `False` will also prevent prefetching and reversing via `__`. See also [related_name](./queries/related-name.md) for defaults @@ -586,11 +586,13 @@ class MyModel(edgy.Model): ##### Parameters * `to` - A string [model](./models.md) name or a class object of that same model. +* `target_registry` - Registry where the model callback is installed if `to` is a string. Defaults to the field owner registry. * `from_fields` - Provide the `related_fields` for the implicitly generated ForeignKey to the owner model. * `to_fields` - Provide the `related_fields` for the implicitly generated ForeignKey to the child model. * `related_name` - The name to use for the relation from the related object back to this one. * `through` - The model to be used for the relationship. Edgy generates the model by default if None is provided or `through` is an abstract model. +* `through_registry` - Registry where the model callback is installed if `through` is a string or empty. Defaults to the field owner registry. * `through_tablename` - Custom tablename for `through`. E.g. when special characters are used in model names. * `embed_through` - When traversing, embed the through object in this attribute. Otherwise it is not accessable from the result. if an empty string was provided, the old behaviour is used to query from the through model as base (default). diff --git a/docs/models.md b/docs/models.md index 20863254..69d05bc8 100644 --- a/docs/models.md +++ b/docs/models.md @@ -60,6 +60,7 @@ For this the `StrictModel` model can be used. Otherwise it behaves like a normal There is no strict version of a `ReflectModel` because the laxness is required. + ### Loading models You may have the models distributed among multiple files and packages. @@ -97,6 +98,17 @@ If no `id` is declared in the model, **Edgy** will automatically generate an `id Earlier there were many restrictions. Now they were lifted +### Controlling collision behaviour + +Earlier models were simply replaced when defining a model with the same name or adding such. + +Now the default is to error when a collision was detected, or in case the `on_conflict` parameter was set, either +a `replace` or `keep` executed. + +``` python +{!> ../docs_src/models/on_conflict.py !} +``` + #### What you should not do ##### Declaring an IntegerField as primary key without autoincrement set @@ -171,7 +183,7 @@ to copy a model class and optionally add it to an other registry. You can add it to a registry later by using: -`model_class.add_to_registry(registry, name="", database=None, replace_related_field=False)` +`model_class.add_to_registry(registry, name="", database=None, replace_related_field=...)` In fact the last method is called when the registry parameter of `copy_edgy_model` is not `None`. diff --git a/docs/release-notes.md b/docs/release-notes.md index ec32c6bf..21c10bb3 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -6,6 +6,41 @@ hide: # Release Notes +## 0.24.3 + +### Added + +- ManyToManyField `create_through_model` method allows now the keyword only argument `replace_related_field`. +- `add_to_registry` and models have now an additional keyword-only argument `on_conflict` for controlling what happens when a same named model already exists. + For models this can be passed : `class Foo(edgy.Model, on_conflict="keep"): ...`. +- Passing a tuple or list of types to `replace_related_field` is now allowed. +- Add `through_registry` to ManyToMany. +- Add `no_copy` to models MetaInfo. +- Add `ModelCollisionError` exception. +- Add keyword only hook function `real_add_to_registry`. It can be used to customize the `add_to_registry` behaviour. + +### Changed + +- `create_edgy_model` has now `__type_kwargs__` which contains a dict of keyword arguments provided to `__new__` of type. +- RelatedField uses now `no_copy`. +- `add_to_registry` returns the type which was actually added to registry instead of None. +- Through models use now `no_copy` when autogenerated. This way they don't land in copied registries but are autogenerated again. +- Instead of silent replacing models with the same `__name__` now an error is raised. +- `skip_registry` has now also an allowed literal value: `"allow_search"`. It enables the search of the registry but doesn't register the model. + +### Fixed + +- Copying registries and models is working now. +- Fix deleting (clearing cache) of BaseForeignKey target. +- Creating two models with the same name did lead to silent replacements. +- Invalidating caused schema errors. +- ManyToMany and ForeignKey fields didn't worked when referencing tenant models. +- ManyToMany fields didn't worked when specified on tenant models. + +### BREAKING + +- Instead of silent replacing models with the same `__name__` now an error is raised. +- The return value of `add_to_registry` changed. If you customize the function you need to return now the actual model added to the registry. ## 0.24.2 diff --git a/docs_src/models/on_conflict.py b/docs_src/models/on_conflict.py new file mode 100644 index 00000000..1d0f3129 --- /dev/null +++ b/docs_src/models/on_conflict.py @@ -0,0 +1,19 @@ +import edgy + +models = ... + + +class Foo(edgy.Model, on_conflict="keep"): + class Meta: + registry = models + + +# or + + +class Foo2(edgy.Model): + class Meta: + registry = False + + +Foo2.add_to_registry(models, name="Foo", on_conflict="replace") diff --git a/edgy/__init__.py b/edgy/__init__.py index b7e3513c..be77123e 100644 --- a/edgy/__init__.py +++ b/edgy/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -__version__ = "0.24.2" +__version__ = "0.24.3" from typing import TYPE_CHECKING from ._monkay import Instance, create_monkay diff --git a/edgy/contrib/autoreflection/metaclasses.py b/edgy/contrib/autoreflection/metaclasses.py index 3a3b94f8..14267d53 100644 --- a/edgy/contrib/autoreflection/metaclasses.py +++ b/edgy/contrib/autoreflection/metaclasses.py @@ -69,24 +69,14 @@ def __new__( name: str, bases: tuple[type, ...], attrs: dict[str, Any], - skip_registry: bool = False, meta_info_class: type[AutoReflectionMetaInfo] = AutoReflectionMetaInfo, **kwargs: Any, ) -> Any: - new_model = super().__new__( + return super().__new__( cls, name, bases, attrs, meta_info_class=meta_info_class, - skip_registry=True, **kwargs, ) - if ( - not skip_registry - and isinstance(new_model.meta, AutoReflectionMetaInfo) - and not new_model.meta.abstract - and new_model.meta.registry - ): - new_model.meta.registry.pattern_models[new_model.__name__] = new_model - return new_model diff --git a/edgy/contrib/autoreflection/models.py b/edgy/contrib/autoreflection/models.py index b9d75810..15c9aa27 100644 --- a/edgy/contrib/autoreflection/models.py +++ b/edgy/contrib/autoreflection/models.py @@ -1,9 +1,18 @@ -from typing import ClassVar +from typing import TYPE_CHECKING, Any, ClassVar import edgy from .metaclasses import AutoReflectionMeta, AutoReflectionMetaInfo +if TYPE_CHECKING: + from edgy.core.db.models.types import BaseModelType + class AutoReflectModel(edgy.ReflectModel, metaclass=AutoReflectionMeta): meta: ClassVar[AutoReflectionMetaInfo] + + @classmethod + def real_add_to_registry(cls, **kwargs: Any) -> type["BaseModelType"]: + if isinstance(cls.meta, AutoReflectionMetaInfo): + kwargs.setdefault("registry_type_name", "pattern_models") + return super().real_add_to_registry(**kwargs) diff --git a/edgy/contrib/multi_tenancy/base.py b/edgy/contrib/multi_tenancy/base.py index 37d6ebe1..9463c6dd 100644 --- a/edgy/contrib/multi_tenancy/base.py +++ b/edgy/contrib/multi_tenancy/base.py @@ -1,8 +1,11 @@ -from typing import ClassVar +from typing import TYPE_CHECKING, Any, ClassVar from edgy.contrib.multi_tenancy.metaclasses import BaseTenantMeta, TenantMeta from edgy.core.db.models.model import Model +if TYPE_CHECKING: + from edgy.core.db.models.types import BaseModelType + class TenantModel(Model, metaclass=BaseTenantMeta): """ @@ -16,3 +19,24 @@ class TenantModel(Model, metaclass=BaseTenantMeta): """ meta: ClassVar[TenantMeta] = TenantMeta(None, abstract=True) + + @classmethod + def real_add_to_registry(cls, **kwargs: Any) -> type["BaseModelType"]: + result = super().real_add_to_registry(**kwargs) + + if ( + cls.meta.registry + and cls.meta.is_tenant + and not cls.meta.abstract + and not cls.__is_proxy_model__ + ): + assert cls.__reflected__ is False, ( + "Reflected models are not compatible with multi_tenancy" + ) + + if not cls.meta.register_default: + # remove from models + cls.meta.registry.models.pop(cls.__name__, None) + cls.meta.registry.tenant_models[cls.__name__] = cls + + return result diff --git a/edgy/contrib/multi_tenancy/metaclasses.py b/edgy/contrib/multi_tenancy/metaclasses.py index c94aca36..8e58d05a 100644 --- a/edgy/contrib/multi_tenancy/metaclasses.py +++ b/edgy/contrib/multi_tenancy/metaclasses.py @@ -1,10 +1,13 @@ -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast from edgy.core.db.models.metaclasses import ( BaseModelMeta, MetaInfo, ) +if TYPE_CHECKING: + from edgy.core.connection.database import Database + def _check_model_inherited_tenancy(bases: tuple[type, ...]) -> bool: for base in bases: @@ -42,23 +45,36 @@ class BaseTenantMeta(BaseModelMeta): your own tenant model using the `is_tenant` inside the `Meta` object. """ - def __new__(cls, name: str, bases: tuple[type, ...], attrs: Any, **kwargs: Any) -> Any: - new_model = super().__new__(cls, name, bases, attrs, meta_info_class=TenantMeta, **kwargs) + def __new__( + cls, + name: str, + bases: tuple[type, ...], + attrs: Any, + on_conflict: Literal["error", "replace", "keep"] = "error", + skip_registry: Union[bool, Literal["allow_search"]] = False, + meta_info_class: type[TenantMeta] = TenantMeta, + **kwargs: Any, + ) -> Any: + database: Union[Literal["keep"], None, Database, bool] = attrs.get("database", "keep") + new_model = super().__new__( + cls, + name, + bases, + attrs, + skip_registry="allow_search", + meta_info_class=meta_info_class, + **kwargs, + ) if new_model.meta.is_tenant is None: new_model.meta.is_tenant = _check_model_inherited_tenancy(bases) if ( - new_model.meta.registry - and new_model.meta.is_tenant + not skip_registry + and new_model.meta.registry and not new_model.meta.abstract and not new_model.__is_proxy_model__ ): - assert ( - new_model.__reflected__ is False - ), "Reflected models are not compatible with multi_tenancy" - - if not new_model.meta.register_default: - # remove from models - new_model.meta.registry.models.pop(new_model.__name__, None) - new_model.meta.registry.tenant_models[new_model.__name__] = new_model + new_model.add_to_registry( + new_model.meta.registry, on_conflict=on_conflict, database=database + ) return new_model diff --git a/edgy/core/connection/registry.py b/edgy/core/connection/registry.py index 3153b6e2..ee4d9e37 100644 --- a/edgy/core/connection/registry.py +++ b/edgy/core/connection/registry.py @@ -3,20 +3,11 @@ import re import warnings from collections import defaultdict -from collections.abc import Sequence +from collections.abc import Container, Sequence from copy import copy as shallow_copy from functools import cached_property, partial from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Literal, - Optional, - Union, - cast, - overload, -) +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, Union, cast, overload import sqlalchemy from loguru import logger @@ -103,6 +94,13 @@ class Registry: The command center for the models of Edgy. """ + model_registry_types: ClassVar[tuple[str, ...]] = ( + "models", + "reflected", + "tenant_models", + "pattern_models", + ) + db_schema: Union[str, None] = None content_type: Union[type["BaseModelType"], None] = None dbs_reflected: set[Union[str, None]] @@ -168,25 +166,24 @@ def __copy__(self) -> "Registry": content_type: Union[bool, type[BaseModelType]] = False if self.content_type is not None: try: - content_type2 = content_type = self.get_model( + content_type = self.get_model( "ContentType", include_content_type_attr=False ).copy_edgy_model() - # cleanup content_type copy - for field_name in list(content_type2.meta.fields.keys()): - if field_name.startswith("reverse_"): - del content_type2.meta.fields[field_name] except LookupError: content_type = self.content_type _copy = Registry( self.database, with_content_type=content_type, schema=self.db_schema, extra=self.extra ) - for i in ["models", "reflected", "tenant_models", "pattern_models"]: - dict_models = getattr(_copy, i) + for registry_type in self.model_registry_types: + dict_models = getattr(_copy, registry_type) dict_models.update( ( - (key, val.copy_edgy_model(_copy)) - for key, val in getattr(self, i).items() - if key not in dict_models + ( + key, + val.copy_edgy_model(registry=_copy), + ) + for key, val in getattr(self, registry_type).items() + if not val.meta.no_copy and key not in dict_models ) ) _copy.dbs_reflected = set(self.dbs_reflected) @@ -223,7 +220,7 @@ def _set_content_type( __bases__=(with_content_type,), ) elif real_content_type.meta.registry is None: - real_content_type.add_to_registry(self, "ContentType") + real_content_type.add_to_registry(self, name="ContentType") self.content_type = real_content_type def callback(model_class: type["BaseModelType"]) -> None: @@ -252,9 +249,9 @@ def callback(model_class: type["BaseModelType"]) -> None: if "content_type" in model_class.meta.fields: return related_name = f"reverse_{model_class.__name__.lower()}" - assert ( - related_name not in real_content_type.meta.fields - ), f"duplicate model name: {model_class.__name__}" + assert related_name not in real_content_type.meta.fields, ( + f"duplicate model name: {model_class.__name__}" + ) field_args: dict[str, Any] = { "name": "content_type", @@ -303,7 +300,11 @@ def metadata(self) -> sqlalchemy.MetaData: return self.metadata_by_name[None] def get_model( - self, model_name: str, *, include_content_type_attr: bool = True + self, + model_name: str, + *, + include_content_type_attr: bool = True, + exclude: Container[str] = (), ) -> type["BaseModelType"]: if ( include_content_type_attr @@ -311,14 +312,21 @@ def get_model( and self.content_type is not None ): return self.content_type - if model_name in self.models: - return self.models[model_name] - elif model_name in self.reflected: - return self.reflected[model_name] - elif model_name in self.tenant_models: - return self.tenant_models[model_name] - else: - raise LookupError(f"Registry doesn't have a {model_name} model.") from None + for model_dict_name in self.model_registry_types: + if model_dict_name in exclude: + continue + model_dict: dict = getattr(self, model_dict_name) + if model_name in model_dict: + return cast(type["BaseModelType"], model_dict[model_name]) + raise LookupError(f'Registry doesn\'t have a "{model_name}" model.') from None + + def delete_model(self, model_name: str) -> bool: + for model_dict_name in self.model_registry_types: + model_dict: dict = getattr(self, model_dict_name) + if model_name in model_dict: + del model_dict[model_name] + return True + return False def refresh_metadata( self, @@ -521,7 +529,9 @@ async def _connect_and_init(self, name: Union[str, None], database: "Database") new_name = pattern_model.meta.template(table) old_model: Optional[type[BaseModelType]] = None with contextlib.suppress(LookupError): - old_model = self.get_model(new_name) + old_model = self.get_model( + new_name, include_content_type_attr=False, exclude=("pattern_models",) + ) if old_model is not None: raise Exception( f"Conflicting model: {old_model.__name__} with pattern model: {pattern_model.__name__}" @@ -529,6 +539,7 @@ async def _connect_and_init(self, name: Union[str, None], database: "Database") concrete_reflect_model = pattern_model.copy_edgy_model( name=new_name, meta_info_class=MetaInfo ) + concrete_reflect_model.meta.no_copy = True concrete_reflect_model.meta.tablename = table.name concrete_reflect_model.__using_schema__ = table.schema concrete_reflect_model.add_to_registry(self, database=database) diff --git a/edgy/core/connection/schemas.py b/edgy/core/connection/schemas.py index d333dc35..5975680a 100644 --- a/edgy/core/connection/schemas.py +++ b/edgy/core/connection/schemas.py @@ -7,6 +7,7 @@ from sqlalchemy.exc import DBAPIError, ProgrammingError from edgy.core.connection.database import Database +from edgy.core.db.context_vars import NO_GLOBAL_FIELD_CONSTRAINTS from edgy.exceptions import SchemaError if TYPE_CHECKING: @@ -71,14 +72,16 @@ async def create_schema( if init_models: for model_class in self.registry.models.values(): model_class.table_schema(schema=schema, update_cache=update_cache) - if init_tenant_models and init_models: - for model_class in self.registry.tenant_models.values(): - model_class.table_schema(schema=schema, update_cache=update_cache) - elif init_tenant_models: + if init_tenant_models: + token = NO_GLOBAL_FIELD_CONSTRAINTS.set(True) + try: + for model_class in self.registry.tenant_models.values(): + tenant_tables.append(model_class.build(schema=schema)) + finally: + NO_GLOBAL_FIELD_CONSTRAINTS.reset(token) + # we need two passes for model_class in self.registry.tenant_models.values(): - tenant_tables.append( - model_class.table_schema(schema=schema, update_cache=update_cache) - ) + model_class.add_global_field_constraints(schema=schema) def execute_create(connection: sqlalchemy.Connection, name: Optional[str]) -> None: try: @@ -87,8 +90,10 @@ def execute_create(connection: sqlalchemy.Connection, name: Optional[str]) -> No ) except ProgrammingError as e: raise SchemaError(detail=e.orig.args[0]) from e - for table in tenant_tables: - table.create(connection, checkfirst=if_not_exists) + if tenant_tables: + self.registry.metadata_by_name[name].create_all( + connection, checkfirst=if_not_exists, tables=tenant_tables + ) if init_models: self.registry.metadata_by_name[name].create_all( connection, checkfirst=if_not_exists diff --git a/edgy/core/db/context_vars.py b/edgy/core/db/context_vars.py index 954cc381..0a7bf57a 100644 --- a/edgy/core/db/context_vars.py +++ b/edgy/core/db/context_vars.py @@ -14,6 +14,9 @@ "CURRENT_MODEL_INSTANCE", default=None ) CURRENT_PHASE: ContextVar[str] = ContextVar("CURRENT_PHASE", default="") +NO_GLOBAL_FIELD_CONSTRAINTS: ContextVar[bool] = ContextVar( + "NO_GLOBAL_FIELD_CONSTRAINTS", default=False +) EXPLICIT_SPECIFIED_VALUES: ContextVar[Optional[set[str]]] = ContextVar( "EXPLICIT_SPECIFIED_VALUES", default=None ) diff --git a/edgy/core/db/fields/base.py b/edgy/core/db/fields/base.py index d98880c7..b394bbde 100644 --- a/edgy/core/db/fields/base.py +++ b/edgy/core/db/fields/base.py @@ -432,6 +432,11 @@ def target_registry(self) -> "Registry": def target_registry(self, value: Any) -> None: self._target_registry = value + @target_registry.deleter + def target_registry(self) -> None: + with contextlib.suppress(AttributeError): + delattr(self, "_target_registry") + @property def target(self) -> Any: """ @@ -451,7 +456,8 @@ def target(self, value: Any) -> None: self.to = value @target.deleter - def target(self, value: Any) -> None: + def target(self) -> None: + # clear cache with contextlib.suppress(AttributeError): delattr(self, "_target") diff --git a/edgy/core/db/fields/many_to_many.py b/edgy/core/db/fields/many_to_many.py index ed308368..4e61f754 100644 --- a/edgy/core/db/fields/many_to_many.py +++ b/edgy/core/db/fields/many_to_many.py @@ -1,3 +1,4 @@ +import contextlib from collections.abc import Sequence from functools import cached_property from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast @@ -14,6 +15,7 @@ from edgy.protocols.many_relationship import ManyRelationProtocol if TYPE_CHECKING: + from edgy.core.connection.registry import Registry from edgy.core.db.fields.types import BaseFieldType from edgy.core.db.models.types import BaseModelType @@ -38,7 +40,7 @@ def __init__( self.to_foreign_key = to_foreign_key self.from_fields = from_fields self.from_foreign_key = from_foreign_key - self.through = through + self.through_original = self.through = through self.through_tablename = through_tablename self.embed_through = embed_through @@ -58,6 +60,24 @@ def reverse_embed_through_prefix(self) -> str: return self.reverse_name return f"{self.reverse_name}__{self.embed_through}" + @property + def through_registry(self) -> "Registry": + """Registry searched in case through is a string""" + + if not hasattr(self, "_through_registry"): + assert self.owner.meta.registry, "no registry found neither 'through_registry' set" + return self.owner.meta.registry + return cast("Registry", self._through_registry) + + @through_registry.setter + def through_registry(self, value: Any) -> None: + self._through_registry = value + + @through_registry.deleter + def through_registry(self) -> None: + with contextlib.suppress(AttributeError): + delattr(self, "_through_registry") + def clean(self, name: str, value: Any, for_query: bool = False) -> dict[str, Any]: if not for_query: return {} @@ -106,7 +126,7 @@ def traverse_field(self, path: str) -> tuple[Any, str, str]: return ( self.target, self.reverse_name, - f'{path.removeprefix(self.name).removeprefix("__")}', + f"{path.removeprefix(self.name).removeprefix('__')}", ) return self.target, self.reverse_name, path.removeprefix(self.name).removeprefix("__") @@ -128,30 +148,57 @@ def reverse_traverse_field_fk(self, path: str) -> tuple[Any, str, str]: return ( self.owner, self.name, - f'{path.removeprefix(self.reverse_name).removeprefix("__")}', + f"{path.removeprefix(self.reverse_name).removeprefix('__')}", ) return self.owner, self.name, path.removeprefix(self.reverse_name).removeprefix("__") - def create_through_model(self) -> None: + def create_through_model( + self, + *, + replace_related_field: Union[ + bool, + type["BaseModelType"], + tuple[type["BaseModelType"], ...], + list[type["BaseModelType"]], + ] = False, + ) -> None: """ Creates the default empty through model. Generates a middle model based on the owner of the field and the field itself and adds it to the main registry to make sure it generates the proper models and migrations. """ + from edgy.contrib.multi_tenancy.base import TenantModel + from edgy.contrib.multi_tenancy.metaclasses import TenantMeta from edgy.core.db.models.metaclasses import MetaInfo - __bases__: tuple[type[BaseModelType], ...] = () + __bases__: tuple[type[BaseModelType], ...] = ( + (TenantModel,) + if getattr(self.owner.meta, "is_tenant", False) + or getattr(self.target.meta, "is_tenant", False) + else () + ) pknames = set() if self.through: - if isinstance(self.through, str): - assert self.owner.meta.registry, "no registry found" - self.through = self.owner.meta.registry.models[self.through] through = self.through - if through.meta.abstract: - pknames = set(through.pknames) - __bases__ = (through,) - else: + if isinstance(through, str): + + def callback(model_class: type["BaseModelType"]) -> None: + self.through = model_class + self.create_through_model(replace_related_field=replace_related_field) + + self.through_registry.register_callback(through, callback, one_time=True) + return + if not through.meta.abstract: + if not through.meta.registry: + through = cast( + "type[BaseModelType]", + through.add_to_registry( + self.through_registry, + replace_related_field=replace_related_field, + on_conflict="keep", + ), + ) if not self.from_foreign_key: candidate = None for field_name in through.meta.foreign_key_fields: @@ -162,7 +209,7 @@ def create_through_model(self) -> None: else: candidate = field_name if not candidate: - raise ValueError("no foreign key fo owner found") + raise ValueError("no foreign key to owner found") self.from_foreign_key = candidate if not self.to_foreign_key: candidate = None @@ -174,10 +221,14 @@ def create_through_model(self) -> None: else: candidate = field_name if not candidate: - raise ValueError("no foreign key fo target found") + raise ValueError("no foreign key to target found") self.to_foreign_key = candidate through.meta.multi_related.add((self.from_foreign_key, self.to_foreign_key)) + self.through = through return + pknames = set(through.pknames) + __bases__ = (through,) + del through assert self.owner.meta.registry, "no registry set" owner_name = self.owner.__name__ to_name = self.target.__name__ @@ -199,7 +250,16 @@ def create_through_model(self) -> None: if has_pknames: meta_args["unique_together"] = [(self.from_foreign_key, self.to_foreign_key)] - new_meta: MetaInfo = MetaInfo(None, **meta_args) + # TenantMeta is compatible to normal meta + new_meta: MetaInfo = TenantMeta( + None, + registry=False, + no_copy=True, + is_tenant=getattr(self.owner.meta, "is_tenant", False) + or getattr(self.target.meta, "is_tenant", False), + register_default=getattr(self.owner.meta, "register_default", False), + **meta_args, + ) to_related_name: Union[str, Literal[False]] if self.related_name is False: @@ -252,8 +312,11 @@ def create_through_model(self) -> None: through_model.meta.fields["content_type"] = ExcludeField( name="content_type", owner=through_model ) - through_model.add_to_registry(self.owner.meta.registry) - self.through = through_model + self.through = through_model.add_to_registry( + self.through_registry, + replace_related_field=replace_related_field, + on_conflict="keep", + ) def to_model( self, diff --git a/edgy/core/db/fields/types.py b/edgy/core/db/fields/types.py index d05bca45..bdc9bdcc 100644 --- a/edgy/core/db/fields/types.py +++ b/edgy/core/db/fields/types.py @@ -130,7 +130,10 @@ def to_model( return {field_name: value} def get_global_constraints( - self, name: str, columns: Sequence[sqlalchemy.Column], schemes: Sequence[str] = () + self, + name: str, + columns: Sequence[sqlalchemy.Column], + schemes: Sequence[str] = (), ) -> Sequence[Union[sqlalchemy.Constraint, sqlalchemy.Index]]: """Return global constraints and indexes. Useful for multicolumn fields diff --git a/edgy/core/db/models/metaclasses.py b/edgy/core/db/models/metaclasses.py index 129563b0..b29e5f37 100644 --- a/edgy/core/db/models/metaclasses.py +++ b/edgy/core/db/models/metaclasses.py @@ -211,6 +211,7 @@ class MetaInfo: "inherit", "fields", "registry", + "no_copy", "tablename", "unique_together", "indexes", @@ -264,6 +265,7 @@ def __init__(self, meta: Any = None, **kwargs: Any) -> None: self.model: Optional[type[BaseModelType]] = None # Difference between meta extraction and kwargs: meta attributes are copied self.abstract: bool = getattr(meta, "abstract", False) + self.no_copy: bool = getattr(meta, "no_copy", False) # for embedding self.inherit: bool = getattr(meta, "inherit", True) self.registry: Union[Registry, Literal[False], None] = getattr(meta, "registry", None) @@ -388,9 +390,11 @@ def invalidate( if self.model is None: return if clear_class_attrs: - for attr in ("_table", "_pknames", "_pkcolumns", "_db_schemas", "__proxy_model__"): + for attr in ("_table", "_pknames", "_pkcolumns", "__proxy_model__"): with contextlib.suppress(AttributeError): delattr(self.model, attr) + # needs an extra invalidation + self.model._db_schemas = {} def full_init(self, init_column_mappers: bool = True, init_class_attrs: bool = True) -> None: if not self._fields_are_initialized: @@ -582,7 +586,8 @@ def __new__( bases: tuple[type, ...], attrs: dict[str, Any], meta_info_class: type[MetaInfo] = MetaInfo, - skip_registry: bool = False, + skip_registry: Union[bool, Literal["allow_search"]] = False, + on_conflict: Literal["error", "replace", "keep"] = "error", **kwargs: Any, ) -> Any: fields: dict[str, BaseFieldType] = {} @@ -681,9 +686,9 @@ def __new__( for field_name, field_value in fields.items(): attrs.pop(field_name, None) - # clear cached target + # clear cached target, target is property if isinstance(field_value, BaseForeignKey): - field_value.__dict__.pop("_target", None) + del field_value.target for manager_name in managers: attrs.pop(manager_name, None) @@ -727,6 +732,7 @@ def __new__( # (excluding the edgy.Model class itself). if not has_parents: return new_class + new_class._db_schemas = {} # Ensure the model_fields are updated to the latest # required since pydantic 2.10 @@ -734,7 +740,6 @@ def __new__( # error since pydantic 2.10 with contextlib.suppress(AttributeError): new_class.model_fields = model_fields - new_class._db_schemas = {} # Set the owner of the field, must be done as early as possible # don't use meta.fields to not trigger the lazy evaluation @@ -779,19 +784,16 @@ def __new__( tablename = f"{name.lower()}s" meta.tablename = tablename meta.model = new_class - if skip_registry: - # don't add automatically to registry. Useful for subclasses which modify the registry itself. - new_class.model_rebuild(force=True) - return new_class - - # Now set the registry of models - if meta.registry is None: + # Now find a registry and add it to the meta. + if meta.registry is None and skip_registry is not True: registry: Union[Registry, None, Literal[False]] = get_model_registry(bases, meta_class) meta.registry = registry or None - if not meta.registry: + # don't add automatically to registry. Useful for subclasses which modify the registry itself. + # `skip_registry="allow_search"` is trueish so it works. + if not meta.registry or skip_registry: new_class.model_rebuild(force=True) return new_class - new_class.add_to_registry(meta.registry, database=database) + new_class.add_to_registry(meta.registry, database=database, on_conflict=on_conflict) return new_class def get_db_schema(cls) -> Union[str, None]: @@ -834,7 +836,7 @@ def table(cls) -> sqlalchemy.Table: if not cls.meta.registry: # we cannot set the table without a registry # raising is required - raise AttributeError() + raise AttributeError("No registry.") table = getattr(cls, "_table", None) # assert table.name.lower() == cls.meta.tablename, f"{table.name.lower()} != {cls.meta.tablename}" # fix assigned table diff --git a/edgy/core/db/models/mixins/db.py b/edgy/core/db/models/mixins/db.py index a5c6a4d0..4b58acb3 100644 --- a/edgy/core/db/models/mixins/db.py +++ b/edgy/core/db/models/mixins/db.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import copy import inspect @@ -15,15 +17,19 @@ CURRENT_INSTANCE, EXPLICIT_SPECIFIED_VALUES, MODEL_GETATTR_BEHAVIOR, + NO_GLOBAL_FIELD_CONSTRAINTS, get_schema, ) from edgy.core.db.datastructures import Index, UniqueConstraint +from edgy.core.db.fields.base import BaseForeignKey from edgy.core.db.fields.many_to_many import BaseManyToManyForeignKeyField from edgy.core.db.models.metaclasses import MetaInfo +from edgy.core.db.models.types import BaseModelType from edgy.core.db.models.utils import build_pkcolumns, build_pknames from edgy.core.db.relationships.related_field import RelatedField from edgy.core.utils.db import check_db_connection -from edgy.exceptions import ForeignKeyBadConfigured, ObjectNotFound +from edgy.core.utils.models import create_edgy_model +from edgy.exceptions import ForeignKeyBadConfigured, ModelCollisionError, ObjectNotFound from edgy.types import Undefined if TYPE_CHECKING: @@ -33,7 +39,6 @@ from edgy.core.connection.registry import Registry from edgy.core.db.fields.types import BaseFieldType from edgy.core.db.models.model import Model - from edgy.core.db.models.types import BaseModelType _empty = cast(set[str], frozenset()) @@ -49,6 +54,7 @@ class _EmptyClass: ... "_pknames", "_table", "_db_schemas", + "__proxy_model__", "meta", } _removed_copy_keys.difference_update( @@ -56,13 +62,28 @@ class _EmptyClass: ... ) +def _check_replace_related_field( + replace_related_field: Union[ + bool, type[BaseModelType], tuple[type[BaseModelType], ...], list[type[BaseModelType]] + ], + model: type[BaseModelType], +) -> bool: + if isinstance(replace_related_field, bool): + return replace_related_field + if not isinstance(replace_related_field, (tuple, list)): + replace_related_field = (replace_related_field,) + return any(refmodel is model for refmodel in replace_related_field) + + def _set_related_field( - target: type["BaseModelType"], + target: type[BaseModelType], *, foreign_key_name: str, related_name: str, - source: type["BaseModelType"], - replace_related_field: Union[bool, type["BaseModelType"]], + source: type[BaseModelType], + replace_related_field: Union[ + bool, type[BaseModelType], tuple[type[BaseModelType], ...], list[type[BaseModelType]] + ], ) -> None: if replace_related_field is not True and related_name in target.meta.fields: # is already correctly set, required for migrate of model_apps with registry set @@ -73,9 +94,8 @@ def _set_related_field( ): return # required for copying - if ( - related_field.related_from is not replace_related_field - or related_field.foreign_key_name != foreign_key_name + if related_field.foreign_key_name != foreign_key_name or _check_replace_related_field( + replace_related_field, related_field.related_from ): raise ForeignKeyBadConfigured( f"Multiple related_name with the same value '{related_name}' found to the same target. Related names must be different." @@ -101,9 +121,11 @@ def _set_related_field( def _set_related_name_for_foreign_keys( - meta: "MetaInfo", - model_class: type["BaseModelType"], - replace_related_field: Union[bool, type["BaseModelType"]] = False, + meta: MetaInfo, + model_class: type[BaseModelType], + replace_related_field: Union[ + bool, type[BaseModelType], tuple[type[BaseModelType], ...], list[type[BaseModelType]] + ] = False, ) -> None: """ Sets the related name for the foreign keys. @@ -148,13 +170,22 @@ class DatabaseMixin: _removed_copy_keys: ClassVar[set[str]] = _removed_copy_keys @classmethod - def add_to_registry( - cls: type["BaseModelType"], - registry: "Registry", + def real_add_to_registry( + cls: type[BaseModelType], + *, + registry: Registry, + registry_type_name: str = "models", name: str = "", - database: Union[bool, "Database", Literal["keep"]] = "keep", - replace_related_field: Union[bool, type["BaseModelType"]] = False, - ) -> None: + database: Union[bool, Database, Literal["keep"]] = "keep", + replace_related_field: Union[ + bool, + type[BaseModelType], + tuple[type[BaseModelType], ...], + list[type[BaseModelType]], + ] = False, + on_conflict: Literal["keep", "replace", "error"] = "error", + ) -> type[BaseModelType]: + """For customizations.""" # when called if registry is not set cls.meta.registry = registry if database is True: @@ -172,31 +203,73 @@ def add_to_registry( # Making sure it does not generate models if abstract or a proxy if not meta.abstract and not cls.__is_proxy_model__: - if getattr(cls, "__reflected__", False): - registry.reflected[cls.__name__] = cls + if on_conflict == "replace": + registry.delete_model(cls.__name__) else: - registry.models[cls.__name__] = cls - # after registrating the own model - for value in list(meta.fields.values()): - if isinstance(value, BaseManyToManyForeignKeyField): - m2m_registry: Registry = value.target_registry - with contextlib.suppress(Exception): - m2m_registry = cast("Registry", value.target.registry) - - def create_through_model(x: Any, field: "BaseFieldType" = value) -> None: - # we capture with field = ... the variable - field.create_through_model() - - m2m_registry.register_callback(value.to, create_through_model, one_time=True) - # Sets the foreign key fields - if meta.foreign_key_fields: - _set_related_name_for_foreign_keys( - meta, cls, replace_related_field=replace_related_field - ) - registry.execute_model_callbacks(cls) + with contextlib.suppress(LookupError): + original_model = registry.get_model( + cls.__name__, include_content_type_attr=False, exclude=("tenant_models",) + ) + if on_conflict == "keep": + return original_model + else: + raise ModelCollisionError( + detail=( + f'A model with the same name is already registered: "{cls.__name__}".\n' + "If this is not a bug, define the behaviour by " + 'setting "on_conflict" to either "keep" or "replace".' + ) + ) + if registry_type_name: + registry_dict = getattr(registry, registry_type_name) + registry_dict[cls.__name__] = cls + # after registrating the own model + for value in list(meta.fields.values()): + if isinstance(value, BaseManyToManyForeignKeyField): + m2m_registry: Registry = value.target_registry + with contextlib.suppress(Exception): + m2m_registry = cast("Registry", value.target.registry) + + def create_through_model(x: Any, field: BaseFieldType = value) -> None: + # we capture with field = ... the variable + field.create_through_model(replace_related_field=replace_related_field) + + m2m_registry.register_callback( + value.to, create_through_model, one_time=True + ) + # Sets the foreign key fields + if meta.foreign_key_fields: + _set_related_name_for_foreign_keys( + meta, cls, replace_related_field=replace_related_field + ) + registry.execute_model_callbacks(cls) # finalize cls.model_rebuild(force=True) + return cls + + @classmethod + def add_to_registry( + cls, + registry: Registry, + name: str = "", + database: Union[bool, Database, Literal["keep"]] = "keep", + *, + replace_related_field: Union[ + bool, + type[BaseModelType], + tuple[type[BaseModelType], ...], + list[type[BaseModelType]], + ] = False, + on_conflict: Literal["keep", "replace", "error"] = "error", + ) -> type[BaseModelType]: + return cls.real_add_to_registry( + registry=registry, + name=name, + database=database, + replace_related_field=replace_related_field, + on_conflict=on_conflict, + ) def get_active_instance_schema( self, check_schema: bool = True, check_tenant: bool = True @@ -221,8 +294,13 @@ def get_active_class_schema(cls, check_schema: bool = True, check_tenant: bool = @classmethod def copy_edgy_model( - cls: type["Model"], registry: Optional["Registry"] = None, name: str = "", **kwargs: Any - ) -> type["Model"]: + cls: type[Model], + registry: Optional[Registry] = None, + name: str = "", + unlink_same_registry: bool = True, + on_conflict: Literal["keep", "replace", "error"] = "error", + **kwargs: Any, + ) -> type[Model]: """Copy the model class and optionally add it to another registry.""" # removes private pydantic stuff, except the prefixed ones attrs = { @@ -237,27 +315,75 @@ def copy_edgy_model( ) ) attrs.update(cls.meta.managers) - _copy = cast( - type["Model"], - type(cls.__name__, cls.__bases__, attrs, skip_registry=True, **kwargs), + _copy = create_edgy_model( + __name__=name or cls.__name__, + __module__=cls.__module__, + __definitions__=attrs, + __metadata__=cls.meta, + __bases__=cls.__bases__, + __type_kwargs__={**kwargs, "skip_registry": True}, ) - for field_name in _copy.meta.foreign_key_fields: - # we need to unreference and check if both models are in the same registry - if cls.meta.fields[field_name].target.meta.registry is cls.meta.registry: - _copy.meta.fields[field_name].target = cls.meta.fields[field_name].target.__name__ + # should also allow masking database with None + if hasattr(cls, "database"): + _copy.database = cls.database + replaceable_models: list[type[BaseModelType]] = [cls] + for field_name in list(_copy.meta.fields): + src_field = cls.meta.fields.get(field_name) + if not isinstance(src_field, BaseForeignKey): + continue + # we use the target of source + replaceable_models.append(src_field.target) + + if src_field.target_registry is cls.meta.registry: + # clear target_registry, for obvious registries + del _copy.meta.fields[field_name].target_registry + if unlink_same_registry and src_field.target_registry is cls.meta.registry: + # we need to unreference so the target is retrieved from the new registry + + _copy.meta.fields[field_name].target = src_field.target.__name__ else: # otherwise we need to disable backrefs - _copy.meta.fields[field_name].target.related_name = False - if name: - _copy.__name__ = name + _copy.meta.fields[field_name].related_name = False + + if isinstance(src_field, BaseManyToManyForeignKeyField): + _copy.meta.fields[field_name].through = src_field.through_original + # clear through registry, we need a copy in the new registry + del _copy.meta.fields[field_name].through_registry + if ( + isinstance(_copy.meta.fields[field_name].through, type) + and issubclass(_copy.meta.fields[field_name].through, BaseModelType) + and not _copy.meta.fields[field_name].through.meta.abstract + ): + # unreference + _copy.meta.fields[field_name].through = through_model = _copy.meta.fields[ + field_name + ].through.copy_edgy_model() + # we want to set the registry explicit + through_model.meta.registry = False + if src_field.from_foreign_key in through_model.meta.fields: + # explicit set + through_model.meta.fields[src_field.from_foreign_key].target = _copy + through_model.meta.fields[src_field.from_foreign_key].related_name = cast( + BaseManyToManyForeignKeyField, + cast(type[BaseModelType], src_field.through).meta.fields[ + src_field.from_foreign_key + ], + ).related_name if registry is not None: # replace when old class otherwise old references can lead to issues - _copy.add_to_registry(registry, replace_related_field=cls) + _copy.add_to_registry( + registry, + replace_related_field=replaceable_models, + on_conflict=on_conflict, + database="keep" + if cls.meta.registry is False or cls.database is not cls.meta.registry.database + else True, + ) return _copy @property def table(self) -> sqlalchemy.Table: - if getattr(self, "_table", None) is None: + if self.__dict__.get("_table", None) is None: schema = self.get_active_instance_schema() return cast( "sqlalchemy.Table", @@ -266,13 +392,22 @@ def table(self) -> sqlalchemy.Table: return self._table @table.setter - def table(self, value: sqlalchemy.Table) -> None: + def table(self, value: Optional[sqlalchemy.Table]) -> None: with contextlib.suppress(AttributeError): del self._pknames with contextlib.suppress(AttributeError): del self._pkcolumns self._table = value + @table.deleter + def table(self) -> None: + with contextlib.suppress(AttributeError): + del self._pknames + with contextlib.suppress(AttributeError): + del self._pkcolumns + with contextlib.suppress(AttributeError): + del self._table + @property def pkcolumns(self) -> Sequence[str]: if self.__dict__.get("_pkcolumns", None) is None: @@ -291,7 +426,7 @@ def pknames(self) -> Sequence[str]: build_pknames(self) return self._pknames - def get_columns_for_name(self: "Model", name: str) -> Sequence["sqlalchemy.Column"]: + def get_columns_for_name(self: Model, name: str) -> Sequence[sqlalchemy.Column]: table = self.table meta = self.meta if name in meta.field_to_columns: @@ -320,7 +455,7 @@ def identifying_clauses(self, prefix: str = "") -> list[Any]: ) return clauses - async def _update(self: "Model", **kwargs: Any) -> Any: + async def _update(self: Model, **kwargs: Any) -> Any: """ Update operation of the database fields. """ @@ -367,7 +502,7 @@ async def _update(self: "Model", **kwargs: Any) -> Any: self._loaded_or_deleted = False await self.meta.signals.post_update.send_async(self.__class__, instance=self) - async def update(self: "Model", **kwargs: Any) -> "Model": + async def update(self: Model, **kwargs: Any) -> Model: token = EXPLICIT_SPECIFIED_VALUES.set(set(kwargs.keys())) try: await self._update(**kwargs) @@ -433,7 +568,7 @@ async def load(self, only_needed: bool = False) -> None: self.__dict__.update(self.transform_input(dict(row._mapping), phase="load", instance=self)) self._loaded_or_deleted = True - async def _insert(self: "Model", **kwargs: Any) -> "Model": + async def _insert(self: Model, **kwargs: Any) -> Model: """ Performs the save instruction. """ @@ -479,11 +614,11 @@ async def _insert(self: "Model", **kwargs: Any) -> "Model": return self async def save( - self: "Model", + self: Model, force_insert: bool = False, values: Union[dict[str, Any], set[str], None] = None, force_save: Optional[bool] = None, - ) -> "Model": + ) -> Model: """ Performs a save of a given model instance. When creating a user it will make sure it can update existing or @@ -544,7 +679,9 @@ async def save( @classmethod def build( - cls, schema: Optional[str] = None, metadata: Optional[sqlalchemy.MetaData] = None + cls, + schema: Optional[str] = None, + metadata: Optional[sqlalchemy.MetaData] = None, ) -> sqlalchemy.Table: """ Builds the SQLAlchemy table representation from the loaded fields. @@ -572,7 +709,10 @@ def build( for name, field in cls.meta.fields.items(): current_columns = field.get_columns(name) columns.extend(current_columns) - global_constraints.extend(field.get_global_constraints(name, current_columns, schemes)) + if not NO_GLOBAL_FIELD_CONSTRAINTS.get(): + global_constraints.extend( + field.get_global_constraints(name, current_columns, schemes) + ) # Handle the uniqueness together uniques = [] @@ -598,6 +738,36 @@ def build( else cls.get_active_class_schema(check_schema=False, check_tenant=False), ) + @classmethod + def add_global_field_constraints( + cls, + schema: Optional[str] = None, + metadata: Optional[sqlalchemy.MetaData] = None, + ) -> sqlalchemy.Table: + """ + Add global constraints to table. Required for tenants. + """ + tablename: str = cls.meta.tablename + registry = cls.meta.registry + assert registry, "registry is not set" + if metadata is None: + metadata = registry.metadata_by_url[str(cls.database.url)] + schemes: list[str] = [] + if schema: + schemes.append(schema) + if cls.__using_schema__ is not Undefined: + schemes.append(cls.__using_schema__) + db_schema = cls.get_db_schema() or "" + schemes.append(db_schema) + table = metadata.tables[tablename if not schema else f"{schema}.{tablename}"] + for name, field in cls.meta.fields.items(): + current_columns: list[sqlalchemy.Column] = [] + for column_name in cls.meta.field_to_column_names[name]: + current_columns.append(table.columns[column_name]) + for constraint in field.get_global_constraints(name, current_columns, schemes): + table.append_constraint(constraint) + return table + @classmethod def _get_unique_constraints( cls, fields: Union[Sequence, str, sqlalchemy.UniqueConstraint] @@ -639,7 +809,7 @@ def _get_indexes(cls, index: Index) -> Optional[sqlalchemy.Index]: ), ) - def transaction(self, *, force_rollback: bool = False, **kwargs: Any) -> "Transaction": + def transaction(self, *, force_rollback: bool = False, **kwargs: Any) -> Transaction: """Return database transaction for the assigned database""" return cast( "Transaction", self.database.transaction(force_rollback=force_rollback, **kwargs) diff --git a/edgy/core/db/models/mixins/reflection.py b/edgy/core/db/models/mixins/reflection.py index 3c00d592..522397da 100644 --- a/edgy/core/db/models/mixins/reflection.py +++ b/edgy/core/db/models/mixins/reflection.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from edgy import Registry from edgy.core.connection.database import Database + from edgy.core.db.models.types import BaseModelType class ReflectedModelMixin: @@ -19,9 +20,16 @@ class ReflectedModelMixin: __reflected__: ClassVar[bool] = True + @classmethod + def real_add_to_registry(cls: type["BaseModelType"], **kwargs: Any) -> type["BaseModelType"]: + kwargs.setdefault("registry_type_name", "reflected") + return cast(type["BaseModelType"], super().real_add_to_registry(**kwargs)) + @classmethod def build( - cls, schema: Optional[str] = None, metadata: Optional[sqlalchemy.MetaData] = None + cls, + schema: Optional[str] = None, + metadata: Optional[sqlalchemy.MetaData] = None, ) -> Any: """ The inspect is done in an async manner and reflects the objects from the database. diff --git a/edgy/core/db/models/types.py b/edgy/core/db/models/types.py index a39951fe..d32c0394 100644 --- a/edgy/core/db/models/types.py +++ b/edgy/core/db/models/types.py @@ -144,7 +144,9 @@ def model_dump(self, show_pk: Union[bool, None] = None, **kwargs: Any) -> dict[s @classmethod @abstractmethod def build( - cls, schema: Optional[str] = None, metadata: Optional[sqlalchemy.MetaData] = None + cls, + schema: Optional[str] = None, + metadata: Optional[sqlalchemy.MetaData] = None, ) -> sqlalchemy.Table: """ Builds the SQLAlchemy table representation from the loaded fields. diff --git a/edgy/core/db/relationships/related_field.py b/edgy/core/db/relationships/related_field.py index 53a6eaa9..19a69879 100644 --- a/edgy/core/db/relationships/related_field.py +++ b/edgy/core/db/relationships/related_field.py @@ -34,6 +34,7 @@ def __init__( annotation=Any, column_type=None, null=True, + no_copy=True, **kwargs, ) if self.foreign_key.relation_has_post_delete_callback: diff --git a/edgy/core/db/relationships/relation.py b/edgy/core/db/relationships/relation.py index e19c0262..886c394a 100644 --- a/edgy/core/db/relationships/relation.py +++ b/edgy/core/db/relationships/relation.py @@ -95,6 +95,8 @@ def expand_relationship(self, value: Any) -> Any: **{self.from_foreign_key: self.instance, self.to_foreign_key: value} ) instance.identifying_db_fields = [self.from_foreign_key, self.to_foreign_key] # type: ignore + if getattr(through.meta, "is_tenant", False): + instance.__using_schema__ = self.instance.get_active_instance_schema() # type: ignore return instance def stage(self, *children: "BaseModelType") -> None: @@ -150,9 +152,9 @@ async def remove(self, child: Optional["BaseModelType"] = None) -> None: try: child = await self.get() except ObjectNotFound: - raise RelationshipNotFound(detail="no child found") from None + raise RelationshipNotFound(detail="No child found.") from None else: - raise RelationshipNotFound(detail="no child specified") + raise RelationshipNotFound(detail="No child specified.") if not isinstance( child, (self.to, self.to.proxy_model, self.through, self.through.proxy_model), # type: ignore @@ -164,7 +166,7 @@ async def remove(self, child: Optional["BaseModelType"] = None) -> None: count = await child.query.filter(*child.identifying_clauses()).count() if count == 0: raise RelationshipNotFound( - detail=f"There is no relationship between '{self.from_foreign_key}' and '{self.to_foreign_key}: {getattr(child,self.to_foreign_key).pk}'." + detail=f"There is no relationship between '{self.from_foreign_key}' and '{self.to_foreign_key}: {getattr(child, self.to_foreign_key).pk}'." ) else: await child.delete() @@ -242,6 +244,8 @@ def expand_relationship(self, value: Any) -> Any: value = {next(iter(related_columns)): value} instance = target.proxy_model(**value) instance.identifying_db_fields = related_columns # type: ignore + if getattr(target.meta, "is_tenant", False): + instance.__using_schema__ = self.instance.get_active_instance_schema() # type: ignore return instance def stage(self, *children: "BaseModelType") -> None: diff --git a/edgy/core/utils/models.py b/edgy/core/utils/models.py index fff2db7f..91016d04 100644 --- a/edgy/core/utils/models.py +++ b/edgy/core/utils/models.py @@ -20,6 +20,7 @@ def create_edgy_model( __bases__: Optional[tuple[type["BaseModelType"], ...]] = None, __proxy__: bool = False, __pydantic_extra__: Any = None, + __type_kwargs__: Optional[dict] = None, ) -> type["Model"]: """ Generates an `edgy.Model` with all the required definitions to generate the pydantic @@ -47,8 +48,10 @@ def create_edgy_model( core_definitions.update(**{"Meta": __metadata__}) if __pydantic_extra__: core_definitions.update(**{"__pydantic_extra__": __pydantic_extra__}) + if not __type_kwargs__: + __type_kwargs__ = {} - model: type[Model] = type(__name__, __bases__, core_definitions) + model: type[Model] = type(__name__, __bases__, core_definitions, **__type_kwargs__) return model diff --git a/edgy/exceptions.py b/edgy/exceptions.py index b6b2f5d9..4bd151fe 100644 --- a/edgy/exceptions.py +++ b/edgy/exceptions.py @@ -49,6 +49,9 @@ class RelationshipIncompatible(EdgyException): ... class DuplicateRecordError(EdgyException): ... +class ModelCollisionError(EdgyException): ... + + class RelationshipNotFound(EdgyException): ... diff --git a/tests/cli/custom_multidb_copied_registry/README b/tests/cli/custom_multidb_copied_registry/README new file mode 100644 index 00000000..58c93de4 --- /dev/null +++ b/tests/cli/custom_multidb_copied_registry/README @@ -0,0 +1 @@ +Custom template diff --git a/tests/cli/custom_multidb_copied_registry/alembic.ini.mako b/tests/cli/custom_multidb_copied_registry/alembic.ini.mako new file mode 100644 index 00000000..57ba8a58 --- /dev/null +++ b/tests/cli/custom_multidb_copied_registry/alembic.ini.mako @@ -0,0 +1,50 @@ +# A custom generic database configuration. + +[alembic] +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic,saffier + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[logger_saffier] +level = INFO +handlers = +qualname = saffier + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/tests/cli/custom_multidb_copied_registry/env.py b/tests/cli/custom_multidb_copied_registry/env.py new file mode 100644 index 00000000..86f32067 --- /dev/null +++ b/tests/cli/custom_multidb_copied_registry/env.py @@ -0,0 +1,142 @@ +# Custom env template +import asyncio +import logging +import os +from collections.abc import Generator +from logging.config import fileConfig +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + +from alembic import context +from rich.console import Console + +import edgy +from edgy.core.connection import Database, Registry + +if TYPE_CHECKING: + import sqlalchemy + +# The console used for the outputs +console = Console() + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config: Any = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) +logger = logging.getLogger("alembic.env") +MAIN_DATABASE_NAME: str = " " + + +def iter_databases(registry: Registry) -> Generator[tuple[str, Database, "sqlalchemy.MetaData"]]: + url: Optional[str] = os.environ.get("EDGY_DATABASE_URL") + name: Union[str, Literal[False], None] = os.environ.get("EDGY_DATABASE") or False + if url and not name: + try: + name = registry.metadata_by_url.get_name(url) + except KeyError: + name = None + if name is False: + db_names = edgy.monkay.settings.migrate_databases + for name in db_names: + if name is None: + yield (None, registry.database, registry.metadata_by_name[None]) + else: + yield (name, registry.extra[name], registry.metadata_by_name[name]) + else: + if name == MAIN_DATABASE_NAME: + name = None + if url: + database = Database(url) + elif name is None: + database = registry.database + else: + database = registry.extra[name] + yield ( + name, + database, + registry.metadata_by_name[name], + ) + + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> Any: + """ + Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + """ + registry = edgy.get_migration_prepared_registry(edgy.monkay.instance.registry.__copy__()) + for name, db, metadata in iter_databases(registry): + context.configure( + url=str(db.url), + target_metadata=metadata, + literal_binds=True, + ) + + with context.begin_transaction(): + context.run_migrations(edgy_dbname=name or "") + + +def do_run_migrations(connection: Any, name: str, metadata: "sqlalchemy.Metadata") -> Any: + # this callback is used to prevent an auto-migration from being generated + # when there are no changes to the schema + # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html + def process_revision_directives(context, revision, directives) -> Any: # type: ignore + if getattr(config.cmd_opts, "autogenerate", False): + script = directives[0] + empty = True + for upgrade_ops in script.upgrade_ops_list: + if not upgrade_ops.is_empty(): + empty = False + break + if empty: + directives[:] = [] + console.print("[bright_red]No changes in schema detected.") + + context.configure( + connection=connection, + target_metadata=metadata, + upgrade_token=f"{name or ''}_upgrades", + downgrade_token=f"{name or ''}_downgrades", + process_revision_directives=process_revision_directives, + **edgy.monkay.settings.alembic_ctx_kwargs, + ) + + with context.begin_transaction(): + context.run_migrations(edgy_dbname=name or "") + + +async def run_migrations_online() -> Any: + """ + Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + # the original script checked for the async compatibility + # we are only compatible with async drivers so just use Database + registry = edgy.get_migration_prepared_registry(edgy.monkay.instance.registry.__copy__()) + async with registry: + for name, db, metadata in iter_databases(registry): + async with db as database: + await database.run_sync(do_run_migrations, name, metadata) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + asyncio.run(run_migrations_online()) diff --git a/tests/cli/custom_multidb_copied_registry/script.py.mako b/tests/cli/custom_multidb_copied_registry/script.py.mako new file mode 100644 index 00000000..f1d3a11e --- /dev/null +++ b/tests/cli/custom_multidb_copied_registry/script.py.mako @@ -0,0 +1,44 @@ +# Custom mako template +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + +def upgrade(edgy_dbname: str = "") -> None: + globals()[f"upgrade_{edgy_dbname}"]() + + +def downgrade(edgy_dbname: str = "") -> None: + globals()[f"downgrade_{edgy_dbname}"]() + + +<% + from edgy import monkay + db_names = monkay.settings.migrate_databases +%> + +## generate an "upgrade_() / downgrade_()" function +## according to edgy migrate settings + +% for db_name in db_names: + +def ${f"upgrade_{db_name or ''}"}(): + ${context.get(f"{db_name or ''}_upgrades", "pass")} + + +def ${f"downgrade_{db_name or ''}"}(): + ${context.get(f"{db_name or ''}_downgrades", "pass")} + +% endfor diff --git a/tests/cli/test_multidb_templates.py b/tests/cli/test_multidb_templates.py index 9ce8da13..13371c6e 100644 --- a/tests/cli/test_multidb_templates.py +++ b/tests/cli/test_multidb_templates.py @@ -66,8 +66,15 @@ async def cleanup_db(): @pytest.mark.parametrize("app_flag", ["explicit", "explicit_env"]) @pytest.mark.parametrize( "template_param", - ["", " -t default", " -t plain", " -t url", " -t ./custom_multidb"], - ids=["default_empty", "default", "plain", "url", "custom"], + [ + "", + " -t default", + " -t plain", + " -t url", + " -t ./custom_multidb", + " -t ./custom_multidb_copied_registry", + ], + ids=["default_empty", "default", "plain", "url", "custom", "custom_copied"], ) async def test_migrate_upgrade_multidb(app_flag, template_param): os.chdir(base_path) diff --git a/tests/contrib/autoreflection/test_reflecting_models.py b/tests/contrib/autoreflection/test_reflecting_models.py index aa6cc7af..f983eb43 100644 --- a/tests/contrib/autoreflection/test_reflecting_models.py +++ b/tests/contrib/autoreflection/test_reflecting_models.py @@ -101,6 +101,69 @@ class Meta: ) +async def test_basic_reflection_after_copy(): + reflected = edgy.Registry(DATABASE_URL) + + class AutoAll(AutoReflectModel): + class Meta: + registry = reflected + + class AutoNever(AutoReflectModel): + non_matching = edgy.CharField(max_length=40) + + class Meta: + registry = reflected + template = r"AutoNever" + + class AutoNever2(AutoReflectModel): + id = edgy.CharField(max_length=40, primary_key=True) + + class Meta: + registry = reflected + template = r"AutoNever2" + + class AutoNever3(AutoReflectModel): + class Meta: + registry = reflected + template = r"AutoNever3" + exclude_pattern = r".*" + + class AutoFoo(AutoReflectModel): + class Meta: + registry = reflected + include_pattern = r"^foos$" + + class AutoBar(AutoReflectModel): + class Meta: + registry = reflected + include_pattern = r"^bars" + template = r"{tablename}_{tablename}" + + assert AutoBar.meta.template + + reflected = reflected.__copy__() + + assert len(reflected.reflected) == 0 + async with reflected: + assert ( + sum( + 1 for model in reflected.reflected.values() if model.__name__.startswith("AutoAll") + ) + == 3 + ) + assert "bars_bars" in reflected.reflected + assert "AutoNever" not in reflected.reflected + assert "AutoNever2" not in reflected.reflected + assert "AutoNever3" not in reflected.reflected + + assert ( + sum( + 1 for model in reflected.reflected.values() if model.__name__.startswith("AutoFoo") + ) + == 1 + ) + + async def test_extra_reflection(): reflected = edgy.Registry(DATABASE_ALTERNATIVE_URL, extra={"another": DATABASE_URL}) diff --git a/tests/contrib/multi_tenancy/test_migrate.py b/tests/contrib/multi_tenancy/test_migrate.py index 1d65f6bb..d205947a 100644 --- a/tests/contrib/multi_tenancy/test_migrate.py +++ b/tests/contrib/multi_tenancy/test_migrate.py @@ -56,6 +56,20 @@ async def test_migrate_objs_main_only(): assert len(registry.metadata_by_name[None].tables.keys()) == 2 +async def test_migrate_objs_main_only_after_copy(): + tenant = await Tenant.query.create( + schema_name="migrate_edgy", + domain_url="https://edgy.dymmond.com", + tenant_name="migrate_edgy", + ) + + assert tenant.schema_name == "migrate_edgy" + assert tenant.tenant_name == "migrate_edgy" + + registry = edgy.get_migration_prepared_registry(models.__copy__()) + assert len(registry.metadata_by_name[None].tables.keys()) == 2 + + async def test_migrate_objs_all(): tenant = await Tenant.query.create( schema_name="migrate_edgy", @@ -79,6 +93,27 @@ async def test_migrate_objs_all(): } +async def test_migrate_objs_all_after_copy(): + tenant = await Tenant.query.create( + schema_name="migrate_edgy", + domain_url="https://edgy.dymmond.com", + tenant_name="migrate_edgy", + ) + + assert tenant.schema_name == "migrate_edgy" + assert tenant.tenant_name == "migrate_edgy" + + edgy.monkay.set_instance(Instance(registry=models.__copy__())) + with edgy.monkay.with_settings(edgy.monkay.settings.model_copy(update={"multi_schema": True})): + registry = edgy.get_migration_prepared_registry() + + assert set(registry.metadata_by_name[None].tables.keys()) == { + "tenants", + "migrate_edgy.products", + "products", + } + + async def test_migrate_objs_namespace_only(): tenant = await Tenant.query.create( schema_name="migrate_edgy", diff --git a/tests/contrib/multi_tenancy/test_mt_models.py b/tests/contrib/multi_tenancy/test_mt_models.py index 387f3bee..099bc95a 100644 --- a/tests/contrib/multi_tenancy/test_mt_models.py +++ b/tests/contrib/multi_tenancy/test_mt_models.py @@ -71,6 +71,14 @@ class Meta: is_tenant = True +class Cart(TenantModel): + products = fields.ManyToMany(Product) + + class Meta: + registry = models + is_tenant = True + + async def test_create_a_tenant_schema(): tenant = await Tenant.query.create( schema_name="edgy", domain_url="https://edgy.dymmond.com", tenant_name="edgy" @@ -80,6 +88,22 @@ async def test_create_a_tenant_schema(): assert tenant.tenant_name == "edgy" +async def test_create_a_tenant_schema_copy(): + copied = models.__copy__() + tenant = await copied.get_model("Tenant").query.create( + schema_name="edgy", domain_url="https://edgy.dymmond.com", tenant_name="edgy" + ) + + assert tenant.schema_name == "edgy" + assert tenant.tenant_name == "edgy" + NewProduct = copied.get_model("Product") + NewCart = copied.get_model("Cart") + assert NewCart.meta.fields["products"].target is NewProduct + assert NewCart.meta.fields["products"].through is not Cart.meta.fields["products"].through + assert hasattr(Cart.meta.fields["products"].through, "_db_schemas") + assert hasattr(NewCart.meta.fields["products"].through, "_db_schemas") + + async def test_raises_ModelSchemaError_on_public_schema(): with pytest.raises(ModelSchemaError) as raised: await Tenant.query.create( diff --git a/tests/contrib/multi_tenancy/test_tenant_models_using.py b/tests/contrib/multi_tenancy/test_tenant_models_using.py index 249b7cb9..816ff821 100644 --- a/tests/contrib/multi_tenancy/test_tenant_models_using.py +++ b/tests/contrib/multi_tenancy/test_tenant_models_using.py @@ -52,29 +52,66 @@ class Meta: is_tenant = True -async def test_schema_with_using_in_different_place(): - tenant = await Tenant.query.create( +class Cart(TenantModel): + products = fields.ManyToMany(Product) + + class Meta: + registry = models + is_tenant = True + + +@pytest.mark.parametrize("use_copy", ["false", "instant", "after"]) +async def test_schema_with_using_in_different_place(use_copy): + if use_copy == "instant": + copied = models.__copy__() + NewTenant = copied.get_model("Tenant") + NewProduct = copied.get_model("Product") + NewCart = copied.get_model("Cart") + else: + NewTenant = Tenant + NewProduct = Product + NewCart = Cart + tenant = await NewTenant.query.create( schema_name="edgy", domain_url="https://edgy.dymmond.com", tenant_name="edgy" ) + if use_copy == "after": + copied = models.__copy__() + NewTenant = copied.get_model("Tenant") + NewProduct = copied.get_model("Product") + NewCart = copied.get_model("Cart") + cart = await NewCart.query.using(schema=tenant.schema_name).create() + assert cart.__using_schema__ == tenant.schema_name for i in range(5): - await Product.query.using(schema=tenant.schema_name).create(name=f"product-{i}") + product = await NewProduct.query.using(schema=tenant.schema_name).create( + name=f"product-{i}" + ) + if i % 2 == 0: + product_through = cart.products.through(cart=cart, product=product) + product_through.__using_schema__ = tenant.schema_name + assert await cart.products.add(product_through) + else: + assert await cart.products.add(product) + + total = await NewProduct.query.filter().using(schema=tenant.schema_name).all() + + assert len(total) == 5 - total = await Product.query.filter().using(schema=tenant.schema_name).all() + total = await cart.products.filter().using(schema=tenant.schema_name).all() assert len(total) == 5 - total = await Product.query.all() + total = await NewProduct.query.all() assert len(total) == 0 for i in range(15): - await Product.query.create(name=f"product-{i}") + await NewProduct.query.create(name=f"product-{i}") - total = await Product.query.all() + total = await NewProduct.query.all() assert len(total) == 15 - total = await Product.query.filter().using(schema=tenant.schema_name).all() + total = await NewProduct.query.filter().using(schema=tenant.schema_name).all() assert len(total) == 5 diff --git a/tests/metaclass/test_meta_errors.py b/tests/metaclass/test_meta_errors.py index 25cf48e4..5d20c1ba 100644 --- a/tests/metaclass/test_meta_errors.py +++ b/tests/metaclass/test_meta_errors.py @@ -4,7 +4,7 @@ import edgy from edgy import Manager, QuerySet -from edgy.exceptions import ForeignKeyBadConfigured, ImproperlyConfigured +from edgy.exceptions import ForeignKeyBadConfigured, ImproperlyConfigured, ModelCollisionError from edgy.testclient import DatabaseTestClient from tests.settings import DATABASE_URL @@ -120,21 +120,56 @@ class Meta: assert raised.value.args[0] == "Meta.indexes must be a list of Index types." +def test_raises_ModelCollisionError(): + with pytest.raises(ModelCollisionError) as raised: + + class User(edgy.StrictModel): + name = edgy.CharField(max_length=255) + + class Meta: + registry = models + + assert raised.value.args[0] == ( + 'A model with the same name is already registered: "User".\n' + 'If this is not a bug, define the behaviour by setting "on_conflict" to either "keep" or "replace".' + ) + + +@pytest.mark.parametrize("value", [True, "allow_search"]) +def test_no_raises_ModelCollisionError_and_set_correctly(value): + class BaseUser(edgy.StrictModel): + name = edgy.CharField(max_length=255) + + class Meta: + registry = models + abstract = True + + class User(BaseUser, skip_registry=value): + pass + + if value is True: + assert BaseUser.meta.registry is models + assert User.meta.registry is None + else: + assert BaseUser.meta.registry is models + assert User.meta.registry is models + + def test_raises_ForeignKeyBadConfigured(): name = "profiles" with pytest.raises(ForeignKeyBadConfigured) as raised: - class User(edgy.StrictModel): + class User2(edgy.StrictModel): name = edgy.CharField(max_length=255) class Meta: registry = models class Profile(edgy.StrictModel): - user = edgy.ForeignKey(User, null=False, on_delete=edgy.CASCADE, related_name=name) + user = edgy.ForeignKey(User2, null=False, on_delete=edgy.CASCADE, related_name=name) another_user = edgy.ForeignKey( - User, null=False, on_delete=edgy.CASCADE, related_name=name + User2, null=False, on_delete=edgy.CASCADE, related_name=name ) class Meta: diff --git a/tests/models/test_model_copying.py b/tests/models/test_model_copying.py new file mode 100644 index 00000000..cdd1cab0 --- /dev/null +++ b/tests/models/test_model_copying.py @@ -0,0 +1,175 @@ +import pytest + +import edgy +from edgy.testclient import DatabaseTestClient +from tests.settings import DATABASE_URL + +database = DatabaseTestClient(DATABASE_URL, drop_database=True) + +pytestmark = pytest.mark.anyio + + +@pytest.mark.parametrize("unlink_same_registry", [True, False]) +async def test_copy_model_abstract(unlink_same_registry): + models = edgy.Registry(database=database) + models2 = edgy.Registry(database=database, schema="another") + + class Product(edgy.StrictModel): + class Meta: + registry = models + + class ThroughModel(edgy.StrictModel): + class Meta: + abstract = True + + class Cart(edgy.StrictModel): + products = edgy.fields.ManyToMany(to=Product, through=ThroughModel) + + class Meta: + registry = models + + assert len(models.models) == 3 + + assert models.get_model("Cart").meta.fields["products"].target is Product + through = models.get_model("Cart").meta.fields["products"].through + assert through is models.get_model(through.__name__) + assert "ThroughModel" not in models.models + + NewCart = Cart.copy_edgy_model(registry=models2, unlink_same_registry=unlink_same_registry) + assert "ThroughModel" not in models2.models + + # nothing changed + assert len(models.models) == 3 + # but the copy has new models + assert NewCart is models2.get_model("Cart") + if not unlink_same_registry: + # cart, through could be added because of different registry + assert len(models2.models) == 2 + assert models2.get_model("Cart").meta.fields["products"].target is models.get_model( + "Product" + ) + else: + # cart, through couldn't be added yet + assert len(models2.models) == 1 + Product.copy_edgy_model(registry=models2) + # cart, through could be added now + assert len(models2.models) == 3 + + through = models2.get_model("Cart").meta.fields["products"].through + assert "_db_schemas" in through.__dict__ + assert through is models2.get_model(through.__name__) + assert through is not models.get_model(through.__name__) + for reg in [models, models2]: + assert "ThroughModel" not in reg.models + + +@pytest.mark.parametrize("unlink_same_registry", [True, False]) +async def test_copy_model_concrete_same(unlink_same_registry): + models = edgy.Registry(database=database) + models2 = edgy.Registry(database=database, schema="another") + + class Product(edgy.StrictModel): + class Meta: + registry = models + + class ThroughModel(edgy.StrictModel): + p = edgy.fields.ForeignKey(Product) + c = edgy.fields.ForeignKey("Cart", target_registry=models) + + class Meta: + registry = models + + class Cart(edgy.StrictModel): + products = edgy.fields.ManyToMany(to=Product, through=ThroughModel) + + class Meta: + registry = models + + assert len(models.models) == 3 + assert models.get_model("Cart").meta.fields["products"].target is Product + through = models.get_model("Cart").meta.fields["products"].through + assert through is models.get_model(through.__name__) + # try copying + + NewCart = Cart.copy_edgy_model(registry=models2, unlink_same_registry=unlink_same_registry) + + # nothing changed + assert len(models.models) == 3 + # but the copy has new models + assert NewCart is models2.get_model("Cart") + if not unlink_same_registry: + # cart, through could be added because of different registry + assert len(models2.models) == 2 + assert models2.get_model("Cart").meta.fields["products"].target is models.get_model( + "Product" + ) + else: + # cart, through couldn't be added yet + assert len(models2.models) == 1 + Product.copy_edgy_model(registry=models2) + # cart, through could be added now + assert len(models2.models) == 3 + + through = models2.get_model("Cart").meta.fields["products"].through + assert "_db_schemas" in through.__dict__ + assert through is models2.get_model(through.__name__) + assert through is not models.get_model(through.__name__) + assert through.__name__ == "ThroughModel" + + +@pytest.mark.parametrize("unlink_same_registry", [True, False]) +async def test_copy_model_concrete_other(unlink_same_registry): + models = edgy.Registry(database=database) + models2 = edgy.Registry(database=database, schema="another") + models3 = edgy.Registry(database=database, schema="another2") + + class Product(edgy.StrictModel): + class Meta: + registry = models + + class ThroughModel(edgy.StrictModel): + p = edgy.fields.ForeignKey(Product) + c = edgy.fields.ForeignKey("Cart", target_registry=models) + + class Meta: + registry = models3 + + class Cart(edgy.StrictModel): + products = edgy.fields.ManyToMany(to=Product, through=ThroughModel) + + class Meta: + registry = models + + assert len(models.models) == 2 + assert len(models3.models) == 1 + + assert models.get_model("Cart").meta.fields["products"].target is Product + through = models.get_model("Cart").meta.fields["products"].through + assert through is models3.get_model(through.__name__) + + # try copying + + NewCart = Cart.copy_edgy_model(registry=models2, unlink_same_registry=unlink_same_registry) + + # nothing changed + assert len(models.models) == 2 + # but the copy has new models + assert NewCart is models2.get_model("Cart") + if not unlink_same_registry: + # cart, through could be added because of different registry + assert len(models2.models) == 2 + assert models2.get_model("Cart").meta.fields["products"].target is models.get_model( + "Product" + ) + else: + # cart, through couldn't be added yet + assert len(models2.models) == 1 + Product.copy_edgy_model(registry=models2) + # cart, through could be added now + assert len(models2.models) == 3 + + through = models2.get_model("Cart").meta.fields["products"].through + assert "_db_schemas" in through.__dict__ + assert through is models2.get_model(through.__name__) + assert through is not models3.get_model(through.__name__) + assert through.__name__ == "ThroughModel" diff --git a/tests/registry/test_registry_copying.py b/tests/registry/test_registry_copying.py new file mode 100644 index 00000000..6bfe0910 --- /dev/null +++ b/tests/registry/test_registry_copying.py @@ -0,0 +1,98 @@ +import pytest + +import edgy +from edgy.testclient import DatabaseTestClient +from tests.settings import DATABASE_URL + +database = DatabaseTestClient(DATABASE_URL, drop_database=True) + +pytestmark = pytest.mark.anyio + + +async def test_copy_registry_abstract(): + models = edgy.Registry(database=database) + + class Product(edgy.StrictModel): + class Meta: + registry = models + + class ThroughModel(edgy.StrictModel): + class Meta: + abstract = True + + class Cart(edgy.StrictModel): + products = edgy.fields.ManyToMany(to=Product, through=ThroughModel) + + class Meta: + registry = models + + assert len(models.models) == 3 + + assert models.get_model("Cart").meta.fields["products"].target is Product + through = models.get_model("Cart").meta.fields["products"].through + assert through is models.get_model(through.__name__) + + # try copying + models_copy = edgy.get_migration_prepared_registry(models.__copy__()) + assert len(models_copy.models) == 3 + assert models_copy.get_model("Cart").meta.fields["products"].target is models_copy.get_model( + "Product" + ) + through = models_copy.get_model("Cart").meta.fields["products"].through + assert "_db_schemas" in through.__dict__ + assert through is models_copy.get_model(through.__name__) + + +@pytest.mark.parametrize("registry_used", ["same", "other", "none", "false"]) +async def test_copy_registry_concrete(registry_used): + models = edgy.Registry(database=database) + models2 = edgy.Registry(database=database, schema="another") + + class Product(edgy.StrictModel): + class Meta: + registry = models + + class ThroughModel(edgy.StrictModel): + p = edgy.fields.ForeignKey(Product) + c = edgy.fields.ForeignKey("Cart", target_registry=models) + + class Meta: + if registry_used == "same": + registry = models + elif registry_used == "other": + registry = models2 + elif registry_used == "none": + registry = None + elif registry_used == "false": + registry = False + + class Cart(edgy.StrictModel): + products = edgy.fields.ManyToMany(to=Product, through=ThroughModel) + + class Meta: + registry = models + + if registry_used == "other": + assert len(models.models) == 2 + assert len(models2.models) == 1 + else: + assert len(models.models) == 3 + + assert models.get_model("Cart").meta.fields["products"].target is Product + if registry_used == "other": + through = models.get_model("Cart").meta.fields["products"].through + assert through is models2.get_model(through.__name__) + else: + through = models.get_model("Cart").meta.fields["products"].through + assert through is models.get_model(through.__name__) + + # try copying + models_copy = edgy.get_migration_prepared_registry(models.__copy__()) + # the through model is copied for having a new one. It is added to the new registry + assert len(models_copy.models) == 3 + assert models_copy.get_model("Cart").meta.fields["products"].target is models_copy.get_model( + "Product" + ) + through = models_copy.get_model("Cart").meta.fields["products"].through + assert "_db_schemas" in through.__dict__ + assert through is models_copy.get_model(through.__name__) diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 698a9fb0..a0024bf8 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -4,9 +4,10 @@ import edgy from edgy import Registry from edgy.testclient import DatabaseTestClient -from tests.settings import DATABASE_URL +from tests.settings import DATABASE_ALTERNATIVE_URL, DATABASE_URL database = DatabaseTestClient(url=DATABASE_URL) +database2 = DatabaseTestClient(url=DATABASE_ALTERNATIVE_URL) models = Registry(database=database) nother = Registry(database=database) @@ -34,6 +35,7 @@ class Meta: class Contact(Profile): age = edgy.CharField(max_length=255) address = edgy.CharField(max_length=255) + database = database2 class Meta: registry = models @@ -56,6 +58,23 @@ def test_migrate_without_model_apps(instance_wrapper, deprecated): registry = edgy.get_migration_prepared_registry() assert len(registry.models) == 3 + assert registry.get_model("Profile").meta.fields["related"].target is registry.get_model( + "Profile" + ) + through = registry.get_model("Profile").meta.fields["related"].through + assert through is registry.get_model(through.__name__) + assert registry.get_model("Contact").database is database2 + + # try copying + registry = edgy.get_migration_prepared_registry(registry.__copy__()) + assert len(registry.models) == 3 + assert registry.get_model("Profile").meta.fields["related"].target is registry.get_model( + "Profile" + ) + through = registry.get_model("Profile").meta.fields["related"].through + assert through is registry.get_model(through.__name__) + assert registry.get_model("Contact").database is database2 + def test_migrate_without_model_apps_and_app(): migrate = edgy.Instance(registry=models)