Skip to content

Commit

Permalink
fix: cluster configuration validation for bool type
Browse files Browse the repository at this point in the history
As bool is a subtype of int, True/False was considered as 1/0
  • Loading branch information
Ygnas authored and openshift-merge-bot[bot] committed Dec 5, 2024
1 parent e666e0a commit be9763a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/codeflare_sdk/common/utils/unit_test_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def createClusterWrongType():
config = ClusterConfiguration(
name="unit-test-cluster",
namespace="ns",
num_workers=2,
num_workers=True,
worker_cpu_requests=[],
worker_cpu_limits=4,
worker_memory_requests=5,
Expand Down
12 changes: 9 additions & 3 deletions src/codeflare_sdk/ray/cluster/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,13 +242,15 @@ def _memory_to_resource(self):

def _validate_types(self):
"""Validate the types of all fields in the ClusterConfiguration dataclass."""
errors = []
for field_info in fields(self):
value = getattr(self, field_info.name)
expected_type = field_info.type
if not self._is_type(value, expected_type):
raise TypeError(
f"'{field_info.name}' should be of type {expected_type}"
)
errors.append(f"'{field_info.name}' should be of type {expected_type}.")

if errors:
raise TypeError("Type validation failed:\n" + "\n".join(errors))

@staticmethod
def _is_type(value, expected_type):
Expand All @@ -268,6 +270,10 @@ def check_type(value, expected_type):
)
if origin_type is tuple:
return all(check_type(elem, etype) for elem, etype in zip(value, args))
if expected_type is int:
return isinstance(value, int) and not isinstance(value, bool)
if expected_type is bool:
return isinstance(value, bool)
return isinstance(value, expected_type)

return check_type(value, expected_type)
4 changes: 3 additions & 1 deletion src/codeflare_sdk/ray/cluster/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,11 @@ def test_all_config_params_aw(mocker):


def test_config_creation_wrong_type():
with pytest.raises(TypeError):
with pytest.raises(TypeError) as error_info:
createClusterWrongType()

assert len(str(error_info.value).splitlines()) == 4


def test_cluster_config_deprecation_conversion(mocker):
config = ClusterConfiguration(
Expand Down

0 comments on commit be9763a

Please sign in to comment.