Skip to content

Commit

Permalink
feat(signal schema): serialize base classes for custom types
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein committed Jan 4, 2025
1 parent 8dfa4ff commit 7d188a6
Show file tree
Hide file tree
Showing 2 changed files with 369 additions and 59 deletions.
116 changes: 89 additions & 27 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
Final,
List,
Literal,
Mapping,
Optional,
Union,
get_args,
get_origin,
)

from pydantic import BaseModel, create_model
from pydantic import BaseModel, Field, create_model
from sqlalchemy import ColumnElement
from typing_extensions import Literal as LiteralEx

Expand Down Expand Up @@ -85,8 +86,31 @@ def __init__(self, method: str, field):
)


class CustomType(BaseModel):
schema_version: int = Field(ge=1, le=2, strict=True)
name: str
fields: dict[str, str]
bases: list[tuple[str, str, Optional[str]]]

@classmethod
def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType":
version = data.get("schema_version", 1)

if version == 1:
data = {
"schema_version": 1,
"name": type_name,
"fields": data,
"bases": [],
}

return cls(**data)


def create_feature_model(
name: str, fields: dict[str, Union[type, tuple[type, Any]]]
name: str,
fields: Mapping[str, Union[type, None, tuple[type, Any]]],
base: Optional[type] = None,
) -> type[BaseModel]:
"""
This gets or returns a dynamic feature model for use in restoring a model
Expand All @@ -98,7 +122,7 @@ def create_feature_model(
name = name.replace("@", "_")
return create_model(
name,
__base__=DataModel, # type: ignore[call-overload]
__base__=base or DataModel, # type: ignore[call-overload]
# These are tuples for each field of: annotation, default (if any)
**{
field_name: anno if isinstance(anno, tuple) else (anno, None)
Expand Down Expand Up @@ -156,7 +180,7 @@ def from_column_types(col_types: dict[str, Any]) -> "SignalSchema":
return SignalSchema(signals)

@staticmethod
def _serialize_custom_model_fields(
def _serialize_custom_model(
version_name: str, fr: type[BaseModel], custom_types: dict[str, Any]
) -> str:
"""This serializes any custom type information to the provided custom_types
Expand All @@ -165,12 +189,23 @@ def _serialize_custom_model_fields(
# This type is already stored in custom_types.
return version_name
fields = {}

for field_name, info in fr.model_fields.items():
field_type = info.annotation
# All fields should be typed.
assert field_type
fields[field_name] = SignalSchema._serialize_type(field_type, custom_types)
custom_types[version_name] = fields

bases: list[tuple[str, str, Optional[str]]] = []
for type_ in fr.__mro__:
model_store_name = (
ModelStore.get_name(type_) if issubclass(type_, DataModel) else None
)
bases.append((type_.__name__, type_.__module__, model_store_name))

ct = CustomType(schema_version=2, name=version_name, fields=fields, bases=bases)
custom_types[version_name] = ct.model_dump()

return version_name

@staticmethod
Expand All @@ -184,15 +219,12 @@ def _serialize_type(fr: type, custom_types: dict[str, Any]) -> str:
if st is None or not ModelStore.is_pydantic(st):
continue
# Register and save feature types.
ModelStore.register(st)
st_version_name = ModelStore.get_name(st)
if st is fr:
# If the main type is Pydantic, then use the ModelStore version name.
type_name = st_version_name
# Save this type to custom_types.
SignalSchema._serialize_custom_model_fields(
st_version_name, st, custom_types
)
SignalSchema._serialize_custom_model(st_version_name, st, custom_types)
return type_name

def serialize(self) -> dict[str, Any]:
Expand All @@ -215,39 +247,74 @@ def _split_subtypes(type_name: str) -> list[str]:
depth += 1
elif c == "]":
if depth == 0:
raise TypeError(
raise ValueError(
"Extra closing square bracket when parsing subtype list"
)
depth -= 1
elif c == "," and depth == 0:
subtypes.append(type_name[start:i].strip())
start = i + 1
if depth > 0:
raise TypeError("Unclosed square bracket when parsing subtype list")
raise ValueError("Unclosed square bracket when parsing subtype list")
subtypes.append(type_name[start:].strip())
return subtypes

@staticmethod
def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]: # noqa: PLR0911
def _deserialize_custom_type(
type_name: str, custom_types: dict[str, Any]
) -> Optional[type]:
"""Given a type name like MyType@v1 gets a type from ModelStore or recreates
it based on the information from the custom types dict that includes fields and
bases."""
model_name, version = ModelStore.parse_name_version(type_name)
fr = ModelStore.get(model_name, version)
if fr:
return fr

if type_name in custom_types:
ct = CustomType.deserialize(custom_types[type_name], type_name)

fields = {
field_name: SignalSchema._resolve_type(field_type_str, custom_types)
for field_name, field_type_str in ct.fields.items()
}

base_model = None
for base in ct.bases:
_, _, model_store_name = base
if model_store_name:
model_name, version = ModelStore.parse_name_version(
model_store_name
)
base_model = ModelStore.get(model_name, version)
if base_model:
break

return create_feature_model(type_name, fields, base=base_model)

return None

@staticmethod
def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:
"""Convert a string-based type back into a python type."""
type_name = type_name.strip()
if not type_name:
raise TypeError("Type cannot be empty")
raise ValueError("Type cannot be empty")
if type_name == "NoneType":
return None

bracket_idx = type_name.find("[")
subtypes: Optional[tuple[Optional[type], ...]] = None
if bracket_idx > -1:
if bracket_idx == 0:
raise TypeError("Type cannot start with '['")
raise ValueError("Type cannot start with '['")
close_bracket_idx = type_name.rfind("]")
if close_bracket_idx == -1:
raise TypeError("Unclosed square bracket when parsing type")
raise ValueError("Unclosed square bracket when parsing type")
if close_bracket_idx < bracket_idx:
raise TypeError("Square brackets are out of order when parsing type")
raise ValueError("Square brackets are out of order when parsing type")
if close_bracket_idx == bracket_idx + 1:
raise TypeError("Empty square brackets when parsing type")
raise ValueError("Empty square brackets when parsing type")
subtype_names = SignalSchema._split_subtypes(
type_name[bracket_idx + 1 : close_bracket_idx]
)
Expand All @@ -267,18 +334,10 @@ def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type
return fr[subtypes] # type: ignore[index]
return fr # type: ignore[return-value]

model_name, version = ModelStore.parse_name_version(type_name)
fr = ModelStore.get(model_name, version)
fr = SignalSchema._deserialize_custom_type(type_name, custom_types)
if fr:
return fr

if type_name in custom_types:
fields = custom_types[type_name]
fields = {
field_name: SignalSchema._resolve_type(field_type_str, custom_types)
for field_name, field_type_str in fields.items()
}
return create_feature_model(type_name, fields)
# This can occur if a third-party or custom type is used, which is not available
# when deserializing.
warnings.warn(
Expand Down Expand Up @@ -317,7 +376,7 @@ def deserialize(schema: dict[str, Any]) -> "SignalSchema":
stacklevel=2,
)
continue
except TypeError as err:
except ValueError as err:
raise SignalSchemaError(
f"cannot deserialize '{signal}': {err}"
) from err
Expand Down Expand Up @@ -662,6 +721,9 @@ def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str:
stacklevel=2,
)
return "Any"
if ModelStore.is_pydantic(type_):
ModelStore.register(type_)
return ModelStore.get_name(type_)
return type_.__name__

@staticmethod
Expand Down
Loading

0 comments on commit 7d188a6

Please sign in to comment.