Skip to content

Commit

Permalink
add more tests to validate MLM fields disallowed under Asset
Browse files Browse the repository at this point in the history
  • Loading branch information
fmigneault committed Nov 6, 2024
1 parent b3b7d63 commit 596e5d4
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- Use JSON `$schema` version `2019-09` to allow use of `unevaluatedProperties` for stricter validation of MLM fields.
- Explicitly disallow `mlm:name`, `mlm:input`, `mlm:output` and `mlm:hyperparameters` at the Asset level.
These fields describe the model as a whole and should therefore be defined in Item properties.
- Moved `norm_type` to `value_scaling` object to better reflect the expected operation, which could be another
operation than what is typically known as "normalization" or "standardization" techniques in machine learning.
- Moved `statistics` to `value_scaling` object to better reflect their mutual `type` and additional
Expand Down
29 changes: 26 additions & 3 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,27 +68,50 @@ def test_mlm_no_undefined_prefixed_field_item_properties(
["item_raster_bands.json"],
indirect=True,
)
@pytest.mark.parametrize(
["test_field", "test_value"],
[
("mlm:unknown", "random"),
("mlm:name", "test-model"),
("mlm:input", []),
("mlm:output", []),
("mlm:hyperparameters", {}),
]
)
def test_mlm_no_undefined_prefixed_field_asset_properties(
mlm_validator: STACValidator,
mlm_example: Dict[str, JSON],
test_field: str,
test_value: Any,
) -> None:
mlm_data = copy.deepcopy(mlm_example)
mlm_item = pystac.Item.from_dict(mlm_data)
pystac.validation.validate(mlm_item, validator=mlm_validator) # ensure original is valid

assert mlm_data["assets"]["weights"] # type: ignore
mlm_data["assets"]["weights"]["mlm:unknown"] = "random" # type: ignore

mlm_data = copy.deepcopy(mlm_example)
mlm_data["assets"]["weights"][test_field] = test_value # type: ignore
with pytest.raises(pystac.errors.STACValidationError) as exc:
mlm_item = pystac.Item.from_dict(mlm_data)
pystac.validation.validate(mlm_item, validator=mlm_validator)
assert len(exc.value.source) == 1 # type: ignore
schema_error = exc.value.source[0] # type: ignore
assert "mlm:unknown" in schema_error.instance
assert test_field in schema_error.instance
assert schema_error.schema["description"] in [
"Fields that apply only within an Asset.",
"Schema to validate the MLM fields permitted only under Assets properties."
]


@pytest.mark.parametrize(
"mlm_example",
["item_raster_bands.json"],
indirect=True,
)
def test_mlm_allowed_field_asset_properties_override(
mlm_validator: STACValidator,
mlm_example: Dict[str, JSON],
) -> None:
# defined property allowed both at the Item at the Asset level
mlm_data = copy.deepcopy(mlm_example)
mlm_data["assets"]["weights"]["mlm:accelerator"] = "cuda" # type: ignore
Expand Down

0 comments on commit 596e5d4

Please sign in to comment.