Skip to content

Commit

Permalink
atdpy: evaluate default field values for each object creation instead…
Browse files Browse the repository at this point in the history
… of (#341)

sharing the same physical value across all objects of the same class.
Fixes #339
  • Loading branch information
mjambon authored May 11, 2023
1 parent 12ea5c5 commit 38fb495
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 24 deletions.
28 changes: 14 additions & 14 deletions atdpy/src/lib/Codegen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ methods and functions to convert data from/to JSON.

# Import annotations to allow forward references
from __future__ import annotations
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, NoReturn, Optional, Tuple, Union

import json
Expand Down Expand Up @@ -544,19 +544,16 @@ let rec type_name_of_expr env (e : type_expr) : string =
| Name (loc, (_, name, _::_), _) -> assert false
| Tvar (loc, _) -> not_implemented loc "type variables"

let rec get_default_default
?(mutable_ok = true) (e : type_expr) : string option =
let rec get_default_default (e : type_expr) : string option =
match e with
| Sum _
| Record _
| Tuple _ (* a default tuple could be possible but we're lazy *) -> None
| List _ ->
if mutable_ok then Some "[]"
else None
| List _ -> Some "[]"
| Option _
| Nullable _ -> Some "None"
| Shared (loc, e, an) -> get_default_default ~mutable_ok e
| Wrap (loc, e, an) -> get_default_default ~mutable_ok e
| Shared (loc, e, an) -> get_default_default e
| Wrap (loc, e, an) -> get_default_default e
| Name (loc, (loc2, name, []), an) ->
(match name with
| "unit" -> Some "None"
Expand All @@ -570,12 +567,11 @@ let rec get_default_default
| Name _ -> None
| Tvar _ -> None

let get_python_default
?mutable_ok (e : type_expr) (an : annot) : string option =
let get_python_default (e : type_expr) (an : annot) : string option =
let user_default = Python_annot.get_python_default an in
match user_default with
| Some s -> Some s
| None -> get_default_default ?mutable_ok e
| None -> get_default_default e

(* see explanation where this function is used *)
let has_no_class_inst_prop_default
Expand All @@ -584,7 +580,7 @@ let has_no_class_inst_prop_default
| Required -> true
| Optional -> (* default is None *) false
| With_default ->
match get_python_default ~mutable_ok:false e an with
match get_python_default e an with
| Some _ -> false
| None ->
(* There's either no default at all which is an error,
Expand Down Expand Up @@ -795,9 +791,13 @@ let inst_var_declaration
| Required -> ""
| Optional -> " = None"
| With_default ->
match get_python_default ~mutable_ok:false unwrapped_e an with
match get_python_default unwrapped_e an with
| None -> ""
| Some value -> sprintf " = %s" value
| Some x ->
(* This constructs ensures that a fresh default value is
evaluated for each class instanciation. It's important for
default lists since Python lists are mutable. *)
sprintf " = field(default_factory=lambda: %s)" x
in
[
Line (sprintf "%s: %s%s" var_name type_name default)
Expand Down
4 changes: 4 additions & 0 deletions atdpy/test/atd-input/everything.atd
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,7 @@ type recursive_class = {
flag: bool;
children: recursive_class list;
}

type default_list = {
~items: int list;
}
40 changes: 34 additions & 6 deletions atdpy/test/python-expected/everything.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

# Import annotations to allow forward references
from __future__ import annotations
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, NoReturn, Optional, Tuple, Union

import json
Expand Down Expand Up @@ -455,7 +455,7 @@ class IntFloatParametrizedRecord:
"""Original type: _int_float_parametrized_record = { ... }"""

field_a: int
field_b: List[float]
field_b: List[float] = field(default_factory=lambda: [])

@classmethod
def from_json(cls, x: Any) -> 'IntFloatParametrizedRecord':
Expand Down Expand Up @@ -489,7 +489,6 @@ class Root:
await_: bool
x___init__: float
items: List[List[int]]
extras: List[int]
aliased: Alias
point: Tuple[float, float]
kinds: List[Kind]
Expand All @@ -503,7 +502,8 @@ class Root:
parametrized_record: IntFloatParametrizedRecord
parametrized_tuple: KindParametrizedTuple
maybe: Optional[int] = None
answer: int = 42
extras: List[int] = field(default_factory=lambda: [])
answer: int = field(default_factory=lambda: 42)

@classmethod
def from_json(cls, x: Any) -> 'Root':
Expand All @@ -513,7 +513,6 @@ def from_json(cls, x: Any) -> 'Root':
await_=_atd_read_bool(x['await']) if 'await' in x else _atd_missing_json_field('Root', 'await'),
x___init__=_atd_read_float(x['__init__']) if '__init__' in x else _atd_missing_json_field('Root', '__init__'),
items=_atd_read_list(_atd_read_list(_atd_read_int))(x['items']) if 'items' in x else _atd_missing_json_field('Root', 'items'),
extras=_atd_read_list(_atd_read_int)(x['extras']) if 'extras' in x else [],
aliased=Alias.from_json(x['aliased']) if 'aliased' in x else _atd_missing_json_field('Root', 'aliased'),
point=(lambda x: (_atd_read_float(x[0]), _atd_read_float(x[1])) if isinstance(x, list) and len(x) == 2 else _atd_bad_json('array of length 2', x))(x['point']) if 'point' in x else _atd_missing_json_field('Root', 'point'),
kinds=_atd_read_list(Kind.from_json)(x['kinds']) if 'kinds' in x else _atd_missing_json_field('Root', 'kinds'),
Expand All @@ -527,6 +526,7 @@ def from_json(cls, x: Any) -> 'Root':
parametrized_record=IntFloatParametrizedRecord.from_json(x['parametrized_record']) if 'parametrized_record' in x else _atd_missing_json_field('Root', 'parametrized_record'),
parametrized_tuple=KindParametrizedTuple.from_json(x['parametrized_tuple']) if 'parametrized_tuple' in x else _atd_missing_json_field('Root', 'parametrized_tuple'),
maybe=_atd_read_int(x['maybe']) if 'maybe' in x else None,
extras=_atd_read_list(_atd_read_int)(x['extras']) if 'extras' in x else [],
answer=_atd_read_int(x['answer']) if 'answer' in x else 42,
)
else:
Expand All @@ -538,7 +538,6 @@ def to_json(self) -> Any:
res['await'] = _atd_write_bool(self.await_)
res['__init__'] = _atd_write_float(self.x___init__)
res['items'] = _atd_write_list(_atd_write_list(_atd_write_int))(self.items)
res['extras'] = _atd_write_list(_atd_write_int)(self.extras)
res['aliased'] = (lambda x: x.to_json())(self.aliased)
res['point'] = (lambda x: [_atd_write_float(x[0]), _atd_write_float(x[1])] if isinstance(x, tuple) and len(x) == 2 else _atd_bad_python('tuple of length 2', x))(self.point)
res['kinds'] = _atd_write_list((lambda x: x.to_json()))(self.kinds)
Expand All @@ -553,6 +552,7 @@ def to_json(self) -> Any:
res['parametrized_tuple'] = (lambda x: x.to_json())(self.parametrized_tuple)
if self.maybe is not None:
res['maybe'] = _atd_write_int(self.maybe)
res['extras'] = _atd_write_list(_atd_write_int)(self.extras)
res['answer'] = _atd_write_int(self.answer)
return res

Expand Down Expand Up @@ -683,3 +683,31 @@ def from_json_string(cls, x: str) -> 'Frozen':

def to_json_string(self, **kw: Any) -> str:
return json.dumps(self.to_json(), **kw)


@dataclass
class DefaultList:
"""Original type: default_list = { ... }"""

items: List[int] = field(default_factory=lambda: [])

@classmethod
def from_json(cls, x: Any) -> 'DefaultList':
if isinstance(x, dict):
return cls(
items=_atd_read_list(_atd_read_int)(x['items']) if 'items' in x else [],
)
else:
_atd_bad_json('DefaultList', x)

def to_json(self) -> Any:
res: Dict[str, Any] = {}
res['items'] = _atd_write_list(_atd_write_int)(self.items)
return res

@classmethod
def from_json_string(cls, x: str) -> 'DefaultList':
return cls.from_json(json.loads(x))

def to_json_string(self, **kw: Any) -> str:
return json.dumps(self.to_json(), **kw)
21 changes: 17 additions & 4 deletions atdpy/test/python-tests/test_atdpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,6 @@ def test_everything_to_json() -> None:
2
]
],
"extras": [
17,
53
],
"aliased": [
8,
9,
Expand Down Expand Up @@ -200,6 +196,10 @@ def test_everything_to_json() -> None:
"wow",
100
],
"extras": [
17,
53
],
"answer": 42
}"""
b_obj = e.Root.from_json_string(a_str)
Expand Down Expand Up @@ -256,5 +256,18 @@ def test_recursive_class() -> None:
assert b_str2 == a_str


def test_default_list() -> None:
a = e.DefaultList(items=[])
assert a.items == []
b = e.DefaultList()
assert b.items == []
c = e.DefaultList.from_json_string("{}")
assert c.items == []
# We could emit '{}' instead of '{"items": []}' but it's more complicated
# and not always desired.
j = b.to_json_string()
assert j == '{"items": []}'


# print updated json
test_everything_to_json()

0 comments on commit 38fb495

Please sign in to comment.