diff --git a/neps/search_spaces/hyperparameters/categorical.py b/neps/search_spaces/hyperparameters/categorical.py index db5566f6..9fa75253 100644 --- a/neps/search_spaces/hyperparameters/categorical.py +++ b/neps/search_spaces/hyperparameters/categorical.py @@ -39,6 +39,19 @@ def __init__( ) self.value: None | float | int | str = None + # Check if 'default' is in 'choices' + if default is not None and default not in self.choices: + raise ValueError( + f"Default value {default} is not in the provided choices {self.choices}" + ) + + # Check if 'is_fidelity' is a boolean + if not isinstance(is_fidelity, bool): + raise TypeError( + f"Expected 'is_fidelity' to be a boolean, but got type: " + f"{type(is_fidelity).__name__}" + ) + @property def id(self): return self.value @@ -46,12 +59,13 @@ def id(self): def __eq__(self, other): if not isinstance(other, self.__class__): return False - return (self.choices == other.choices - and self.value == other.value - and self.is_fidelity == other.is_fidelity - and self.default == other.default - and self.default_confidence_score == other.default_confidence_score - ) + return ( + self.choices == other.choices + and self.value == other.value + and self.is_fidelity == other.is_fidelity + and self.default == other.default + and self.default_confidence_score == other.default_confidence_score + ) def __repr__(self): return f"" diff --git a/neps/search_spaces/hyperparameters/float.py b/neps/search_spaces/hyperparameters/float.py index 0501ecf7..7dc37042 100644 --- a/neps/search_spaces/hyperparameters/float.py +++ b/neps/search_spaces/hyperparameters/float.py @@ -49,17 +49,15 @@ def __init__( f" upper={self.upper}" ) - if not isinstance(log, bool): - raise TypeError( - f"Expected 'log' to be a boolean, but got type: {type(log).__name__}" - ) - - if not isinstance(log, bool): - raise TypeError( - "Expected 'self.log' to be a boolean, but got type: {}".format( - type(log).__name__ + # Validate 'log' and 'is_fidelity' types to prevent configuration errors + # from the YAML input + for param, value in {"log": log, "is_fidelity": is_fidelity}.items(): + if not isinstance(value, bool): + raise TypeError( + f"Expected '{param}' to be a boolean, but got type: " + f"{type(value).__name__}" ) - ) + self.log = log if self.log: diff --git a/tests/test_yaml_search_space/default_value_not_in_choices_config.yaml b/tests/test_yaml_search_space/default_value_not_in_choices_config.yaml new file mode 100644 index 00000000..3a1aa4c1 --- /dev/null +++ b/tests/test_yaml_search_space/default_value_not_in_choices_config.yaml @@ -0,0 +1,4 @@ +search_space: + cat1: + choices: ["a", "b", "c"] + default: "d" diff --git a/tests/test_yaml_search_space/not_boolean_type_is_fidelity_cat_config.yaml b/tests/test_yaml_search_space/not_boolean_type_is_fidelity_cat_config.yaml new file mode 100644 index 00000000..434bd8a1 --- /dev/null +++ b/tests/test_yaml_search_space/not_boolean_type_is_fidelity_cat_config.yaml @@ -0,0 +1,5 @@ +search_space: + cat1: + choices: ["a", "b", "c"] + is_fidelity: fals + default: "c" diff --git a/tests/test_yaml_search_space/not_boolean_type_is_fidelity_float_config.yaml b/tests/test_yaml_search_space/not_boolean_type_is_fidelity_float_config.yaml new file mode 100644 index 00000000..ac527c09 --- /dev/null +++ b/tests/test_yaml_search_space/not_boolean_type_is_fidelity_float_config.yaml @@ -0,0 +1,7 @@ +search_space: + param_float1: + lower: 0.00001 + upper: 0.1 + default: 0.001 + log: false + is_fidelity: truee diff --git a/tests/test_yaml_search_space/not_boolean_type_log_config.yaml b/tests/test_yaml_search_space/not_boolean_type_log_config.yaml new file mode 100644 index 00000000..5f860194 --- /dev/null +++ b/tests/test_yaml_search_space/not_boolean_type_log_config.yaml @@ -0,0 +1,7 @@ +search_space: + param_float1: + lower: 0.00001 + upper: 0.1 + default: 0.001 + log: falsee + is_fidelity: true diff --git a/tests/test_yaml_search_space/test_search_space.py b/tests/test_yaml_search_space/test_search_space.py index 23cd0573..961e0713 100644 --- a/tests/test_yaml_search_space/test_search_space.py +++ b/tests/test_yaml_search_space/test_search_space.py @@ -109,9 +109,48 @@ def test_yaml_file_including_not_allowed_parameter_keys(): @pytest.mark.neps_api -def test_yaml_file_default_parameter_in_range(): +def test_yaml_file_default_parameter_not_in_range(): """Test if the default value outside the specified range is correctly identified and handled.""" with pytest.raises(SearchSpaceFromYamlFileError) as excinfo: pipeline_space_from_yaml(BASE_PATH + "default_not_in_range_config.yaml") assert excinfo.value.exception_type == "ValueError" + + +@pytest.mark.neps_api +def test_float_log_not_boolean(): + """Test if an exception is raised when the 'log' attribute is not a boolean.""" + with pytest.raises(SearchSpaceFromYamlFileError) as excinfo: + pipeline_space_from_yaml(BASE_PATH + "not_boolean_type_log_config.yaml") + assert excinfo.value.exception_type == "TypeError" + + +@pytest.mark.neps_api +def test_float_is_fidelity_not_boolean(): + """Test if an exception is raised when for FloatParameter the 'is_fidelity' + attribute is not a boolean.""" + with pytest.raises(SearchSpaceFromYamlFileError) as excinfo: + pipeline_space_from_yaml( + BASE_PATH + "not_boolean_type_is_fidelity_float_config.yaml" + ) + assert excinfo.value.exception_type == "TypeError" + + +@pytest.mark.neps_api +def test_cat_is_fidelity_not_boolean(): + """Test if an exception is raised when for CategoricalParameter the 'is_fidelity' + attribute is not boolean.""" + with pytest.raises(SearchSpaceFromYamlFileError) as excinfo: + pipeline_space_from_yaml( + BASE_PATH + "not_boolean_type_is_fidelity_cat_config.yaml" + ) + assert excinfo.value.exception_type == "TypeError" + + +@pytest.mark.neps_api +def test_categorical_default_value_not_in_choices(): + """Test if a ValueError is raised when the default value is not in the choices + for a CategoricalParameter.""" + with pytest.raises(SearchSpaceFromYamlFileError) as excinfo: + pipeline_space_from_yaml(BASE_PATH + "default_value_not_in_choices_config.yaml") + assert excinfo.value.exception_type == "ValueError"