Skip to content

Commit

Permalink
Tests for pydantic support
Browse files Browse the repository at this point in the history
  • Loading branch information
kalaspuff committed Nov 23, 2023
1 parent f0726d3 commit 56d494d
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 20 deletions.
27 changes: 19 additions & 8 deletions stockholm/currency.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys
from decimal import Decimal
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Type, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Type, Union, cast


class DefaultCurrencyValue(type):
Expand Down Expand Up @@ -144,7 +144,7 @@ def __instancecheck__(self, instance: Any) -> bool:
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: Callable,
_handler: Any,
) -> Any:
def validate_currency_code(value: Any) -> BaseCurrency:
return get_currency(str(value))
Expand Down Expand Up @@ -173,11 +173,20 @@ def serialize(value: Any) -> str:
)
]

def json_schema(schema: Any) -> Any:
if isinstance(schema, dict):
if schema.get("type") == "is-instance":
return None
return {k: json_schema(v) for k, v in schema.items() if json_schema(v) is not None}
elif isinstance(schema, list):
return [json_schema(v) for v in schema if json_schema(v) is not None]
return schema

return {
"type": "json-or-python",
"json_schema": {
"type": "union",
"choices": schemas,
"choices": json_schema(schemas),
},
"python_schema": {
"type": "union",
Expand All @@ -192,7 +201,7 @@ def serialize(value: Any) -> str:
}

@classmethod
def _validate(cls, value: Any, handler: Callable[..., str]) -> str:
def _validate(cls, value: Any, handler: Callable[..., BaseCurrency]) -> BaseCurrency:
return handler(value)


Expand Down Expand Up @@ -1934,11 +1943,13 @@ class Currency(BaseCurrency):
ZWN = ZWN
ZWR = ZWR

def __get__(self, instance: Any, owner: Any) -> BaseCurrency:
return cast(BaseCurrency, ...)
if TYPE_CHECKING: # pragma: no cover

def __get__(self, instance: Any, owner: Any) -> BaseCurrency:
return cast(BaseCurrency, ...)

def __set__(self, instance: Any, value: CurrencyValue) -> None:
...
def __set__(self, instance: Any, value: CurrencyValue) -> None:
...


from stockholm.money import Money # noqa isort:skip
15 changes: 12 additions & 3 deletions stockholm/money.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,7 @@ def __deepcopy__(self, memo: Dict) -> MoneyModel[MoneyType]:
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: Callable,
_handler: Any,
) -> Any:
def validate_money(value: Any) -> MoneyModel[MoneyType]:
return cls(value)
Expand Down Expand Up @@ -972,8 +972,8 @@ def serialize(value: MoneyModel[MoneyType]) -> Dict:
"schema": {
"type": "union",
"choices": [
currency_regex_str_schema,
is_currency_instance_schema,
currency_regex_str_schema,
],
},
},
Expand Down Expand Up @@ -1027,11 +1027,20 @@ def serialize(value: MoneyModel[MoneyType]) -> Dict:
)
]

def json_schema(schema: Any) -> Any:
if isinstance(schema, dict):
if schema.get("type") == "is-instance":
return None
return {k: json_schema(v) for k, v in schema.items() if json_schema(v) is not None}
elif isinstance(schema, list):
return [json_schema(v) for v in schema if json_schema(v) is not None]
return schema

return {
"type": "json-or-python",
"json_schema": {
"type": "union",
"choices": schemas,
"choices": json_schema(schemas),
},
"python_schema": {
"type": "union",
Expand Down
16 changes: 9 additions & 7 deletions stockholm/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import sys
from abc import abstractmethod
from decimal import Decimal
from typing import Any, Callable, Generic, NotRequired, Required, TypeAlias, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Generic, NotRequired, Required, TypeAlias, TypeVar

from .currency import BaseCurrency, CurrencyValue, MetaCurrency
from .money import Money, MoneyModel
from .rate import Number, NumericType

if sys.version_info < (3, 9):
from typing_extensions import TypedDict # isort: skip
from typing_extensions import TypedDict # pragma: no cover
else:
from typing import TypedDict

Expand All @@ -22,12 +22,14 @@
class ConvertibleTypeDescriptor(Generic[SchemaT, GetT, SetT]):
__args__: tuple[SchemaT, GetT, SetT]

@abstractmethod
def __get__(self, instance: Any, owner: Any) -> GetT:
...
if TYPE_CHECKING: # pragma: no cover

def __set__(self, instance: Any, value: SetT) -> None:
...
@abstractmethod
def __get__(self, instance: Any, owner: Any) -> GetT:
...

def __set__(self, instance: Any, value: SetT) -> None:
...

@classmethod
def __get_pydantic_core_schema__(
Expand Down
50 changes: 48 additions & 2 deletions tests/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
from decimal import Decimal
from typing import Any

from pydantic import BaseModel

from stockholm import Currency, Money
from stockholm.currency import JPY, USD
from stockholm import Currency, Money, get_currency
from stockholm.currency import JPY, USD, BaseCurrency
from stockholm.types import (
ConvertibleToCurrency,
ConvertibleToMoney,
Expand Down Expand Up @@ -69,3 +71,47 @@ class TestConvertibleModel(BaseModel):
"units": 42,
"nanos": 0,
}

assert json.loads(m.model_dump_json()) == {
"money": {"value": "100.45", "units": 100, "nanos": 450000000, "currency_code": None},
"money_with_currency": {"value": "42.999 SEK", "units": 42, "nanos": 999000000, "currency_code": "SEK"},
"number": {"value": "42", "units": 42, "nanos": 0},
"currency": "JPY",
}

assert m == TestConvertibleModel.model_validate_json(m.model_dump_json())
assert m.model_dump_json() == TestConvertibleModel.model_validate_json(m.model_dump_json()).model_dump_json()


def test_validate_money() -> None:
def validate_money(value: Any) -> Money:
return Money(value)

assert Money._validate(-0.01, validate_money).as_dict() == {
"value": "-0.01",
"units": 0,
"nanos": -10000000,
"currency_code": None,
}
assert Money._validate("4711.00499 EUR", validate_money).as_dict() == {
"value": "4711.00499 EUR",
"units": 4711,
"nanos": 4990000,
"currency_code": "EUR",
}
assert Money._validate({"units": 42, "nanos": 15000000, "currency": "USD"}, validate_money).as_dict() == {
"value": "42.015 USD",
"units": 42,
"nanos": 15000000,
"currency_code": "USD",
}


def test_validate_currency() -> None:
def validate_currency_code(value: Any) -> BaseCurrency:
return get_currency(str(value))

assert Currency._validate(Currency.USD, validate_currency_code).ticker == "USD"
assert Currency._validate("JPY", validate_currency_code).ticker == "JPY"
assert Currency._validate(Currency.USD, validate_currency_code).decimal_digits == 2
assert Currency._validate("JPY", validate_currency_code).decimal_digits == 0
11 changes: 11 additions & 0 deletions tests/types/pydantic_field.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from decimal import Decimal

from pydantic import BaseModel
Expand Down Expand Up @@ -60,3 +61,13 @@ class TestConvertibleModel(BaseModel):
"units": 42,
"nanos": 0,
}

assert json.loads(m1.model_dump_json()) == {
"money": {"value": "100.45", "units": 100, "nanos": 450000000, "currency_code": None},
"money_with_currency": {"value": "42.999 SEK", "units": 42, "nanos": 999000000, "currency_code": "SEK"},
"number": {"value": "42", "units": 42, "nanos": 0},
"currency": "JPY",
}

assert m1 == TestConvertibleModel.model_validate_json(m1.model_dump_json())
assert m1.model_dump_json() == TestConvertibleModel.model_validate_json(m1.model_dump_json()).model_dump_json()

0 comments on commit 56d494d

Please sign in to comment.