Skip to content

Commit

Permalink
fix for non-types
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Nov 7, 2024
1 parent c76a95f commit 400ef6b
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
1 change: 0 additions & 1 deletion cobaya/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ def get_external_function(string_or_function, name=None):
if isinstance(string_or_function, str):
try:
scope = globals()
import scipy.stats as stats # provide default scope for eval
scope['stats'] = stats
scope['np'] = np
string_or_function = replace_optimizations(string_or_function)
Expand Down
27 changes: 14 additions & 13 deletions cobaya/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
"\n".join(f"- {e}" for e in path_errors)
)

if origin is typing.ClassVar:
if not isinstance(origin, type):
return validate_type(args[0], value, path)

if isinstance(value, Mapping) != issubclass(origin, Mapping):
Expand Down Expand Up @@ -245,18 +245,19 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
validate_type(t, v, f"{path}[{i}]" if path else f"[{i}]")
return

if not (isinstance(value, expected_type) or
expected_type is Sequence and isinstance(value, np.ndarray)):
if not isinstance(expected_type, type) or isinstance(value, expected_type) \
or expected_type is Sequence and isinstance(value, np.ndarray):
return

type_name = getattr(expected_type, "__name__", repr(expected_type))
type_name = getattr(expected_type, "__name__", repr(expected_type))

# special case for Cobaya's NumberWithUnits, if not instance yet
if type_name == 'NumberWithUnits':
if not isinstance(value, (numbers.Real, str)):
raise TypeError(
f"{curr_path} must be a number or string for NumberWithUnits,"
f" got {type(value).__name__}")
return
# special case for Cobaya's NumberWithUnits, if not instance yet
if type_name == 'NumberWithUnits':
if not isinstance(value, (numbers.Real, str)):
raise TypeError(
f"{curr_path} must be a number or string for NumberWithUnits,"
f" got {type(value).__name__}")
return

raise TypeError(f"{curr_path} must be of type {type_name}, "
f"got {type(value).__name__}")
raise TypeError(f"{curr_path} must be of type {type_name}, "
f"got {type(value).__name__}")
2 changes: 2 additions & 0 deletions tests/test_type_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class GenericComponent(CobayaComponent):
map: Mapping[float, str]
deferred: 'ParamDict'
unset = 1
install_options: ClassVar

_enforce_types = True

Expand All @@ -47,6 +48,7 @@ def test_component_types():
"array2": [1, 2],
"map": {1.0: "a", 2.0: "b"},
"deferred": {'value': lambda x: x},
"install_options": {}
}
GenericComponent(correct_kwargs)

Expand Down

0 comments on commit 400ef6b

Please sign in to comment.