diff --git a/CHANGELOG.md b/CHANGELOG.md index ddd646e..2726743 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Use JSON `$schema` version `2019-09` to allow use of `unevaluatedProperties` for stricter validation of MLM fields. +- Explicitly disallow `mlm:name`, `mlm:input`, `mlm:output` and `mlm:hyperparameters` at the Asset level. + These fields describe the model as a whole and should therefore be defined in Item properties. - Moved `norm_type` to `value_scaling` object to better reflect the expected operation, which could be another operation than what is typically known as "normalization" or "standardization" techniques in machine learning. - Moved `statistics` to `value_scaling` object to better reflect their mutual `type` and additional diff --git a/tests/test_schema.py b/tests/test_schema.py index e5a1ded..d0232af 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -68,27 +68,50 @@ def test_mlm_no_undefined_prefixed_field_item_properties( ["item_raster_bands.json"], indirect=True, ) +@pytest.mark.parametrize( + ["test_field", "test_value"], + [ + ("mlm:unknown", "random"), + ("mlm:name", "test-model"), + ("mlm:input", []), + ("mlm:output", []), + ("mlm:hyperparameters", {}), + ] +) def test_mlm_no_undefined_prefixed_field_asset_properties( mlm_validator: STACValidator, mlm_example: Dict[str, JSON], + test_field: str, + test_value: Any, ) -> None: mlm_data = copy.deepcopy(mlm_example) mlm_item = pystac.Item.from_dict(mlm_data) pystac.validation.validate(mlm_item, validator=mlm_validator) # ensure original is valid - assert mlm_data["assets"]["weights"] # type: ignore - mlm_data["assets"]["weights"]["mlm:unknown"] = "random" # type: ignore + + mlm_data = copy.deepcopy(mlm_example) + mlm_data["assets"]["weights"][test_field] = test_value # type: ignore with pytest.raises(pystac.errors.STACValidationError) as exc: mlm_item = pystac.Item.from_dict(mlm_data) pystac.validation.validate(mlm_item, validator=mlm_validator) assert len(exc.value.source) == 1 # type: ignore schema_error = exc.value.source[0] # type: ignore - assert "mlm:unknown" in schema_error.instance + assert test_field in schema_error.instance assert schema_error.schema["description"] in [ "Fields that apply only within an Asset.", "Schema to validate the MLM fields permitted only under Assets properties." ] + +@pytest.mark.parametrize( + "mlm_example", + ["item_raster_bands.json"], + indirect=True, +) +def test_mlm_allowed_field_asset_properties_override( + mlm_validator: STACValidator, + mlm_example: Dict[str, JSON], +) -> None: # defined property allowed both at the Item at the Asset level mlm_data = copy.deepcopy(mlm_example) mlm_data["assets"]["weights"]["mlm:accelerator"] = "cuda" # type: ignore