From dd818c2b00dcb50ca5fc85ccef5dd8e418d900ea Mon Sep 17 00:00:00 2001 From: pesap Date: Wed, 22 May 2024 10:57:20 -0600 Subject: [PATCH] fix(pint_units): Fixed `BaseQuantity` to allow quantities multiplications (#25) List of changes: - Removed `__new__` from BaseQuantity which was causing the problem - Renamed `__compatible_unit__` to `__base_unit__` - Added `ureg.check` to default units. This will check if the unit is compatible, but will not check the input value. - Added testing for BaseQuantity class and units. - Added a serialization property so that we can call `component.model_dump(mode="json")` when using pint.Quantities, Closes #24 --- src/infrasys/base_quantity.py | 70 +++++++++++++++++++++++++----- src/infrasys/quantities.py | 29 +++++++++---- src/infrasys/system.py | 2 +- src/infrasys/time_series_models.py | 3 +- tests/test_base_quantity.py | 48 ++++++++++++++++++++ tests/test_serialization.py | 3 +- 6 files changed, 131 insertions(+), 24 deletions(-) create mode 100644 tests/test_base_quantity.py diff --git a/src/infrasys/base_quantity.py b/src/infrasys/base_quantity.py index 22c4fac..5a3d9b4 100644 --- a/src/infrasys/base_quantity.py +++ b/src/infrasys/base_quantity.py @@ -1,25 +1,71 @@ -""" This module contains base class for handling pint quantity.""" +"""This module contains base class for handling pint quantity.""" -from abc import ABC -from typing import Any +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from __main__ import BaseQuantity import numpy as np import pint +from pydantic import GetCoreSchemaHandler +from pydantic_core import core_schema +from typing_extensions import Annotated ureg = pint.UnitRegistry() -class BaseQuantity(ureg.Quantity, ABC): # type: ignore +class BaseQuantity(ureg.Quantity): # type: ignore """Interface for base quantity.""" - def __new__(cls, value, units, **kwargs): - instance = super().__new__(cls, value, units, **kwargs) - if not hasattr(cls, "__compatible_unit__"): - raise ValueError("You should define __compatible_unit__ attribute in your class.") - if not instance.is_compatible_with(cls.__compatible_unit__): - message = f"{__class__} must be compatible with {cls.__compatible_unit__}, not {units}" - raise ValueError(message) - return instance + __base_unit__ = None + + def __init_subclass__(cls, **kwargs): + if not cls.__base_unit__: + raise TypeError("__base_unit__ should be defined") + super().__init_subclass__(**kwargs) + + # NOTE: This creates a type hint for the unit. + def __class_getitem__(cls): + return Annotated.__class_getitem__((cls, cls.__base_unit__)) # type: ignore + + # Required for pydantic validation + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + return core_schema.with_info_after_validator_function( + cls.validate, + handler(pint.Quantity), + field_name=handler.field_name, + serialization=core_schema.plain_serializer_function_ser_schema( + cls.serialize, info_arg=False, return_schema=core_schema.str_schema() + ), + ) + + # Required for pydantic validation + @classmethod + def validate(cls, value, *_) -> "BaseQuantity": + if isinstance(value, BaseQuantity): + if cls.__base_unit__: + assert value.check( + cls.__base_unit__ + ), f"Unit must be compatible with {cls.__base_unit__}" + return cls(value) + if isinstance(value, pint.Quantity): + if cls.__base_unit__: + assert value.check( + cls.__base_unit__ + ), f"Unit must be compatible with {cls.__base_unit__}" + return cls(value.magnitude, value.units) + else: + raise ValueError(f"Invalid type for BaseQuantity: {type(value)}") + if isinstance(value, cls): + return cls(value) + return value + + @staticmethod + def serialize(value): + return str(value) def to_dict(self) -> dict[str, Any]: """Convert a quantity to a dictionary for serialization.""" diff --git a/src/infrasys/quantities.py b/src/infrasys/quantities.py index ca75ed8..4be259d 100644 --- a/src/infrasys/quantities.py +++ b/src/infrasys/quantities.py @@ -1,22 +1,33 @@ -"""This module defines basic unit quantities.""" +"""This module defines basic unit quantities. + +To create new Quantities for a given base unit, we just need to specify the +base unit as the second argument of `ureg.check`. +""" from infrasys.base_quantity import BaseQuantity +from infrasys.component import Component # ruff:noqa # fmt: off -class Distance(BaseQuantity): __compatible_unit__ = "meter" +class Distance(BaseQuantity): __base_unit__ = "meter" + +class Voltage(BaseQuantity): __base_unit__ = "volt" + +class Current(BaseQuantity): __base_unit__ = "ampere" + +class Angle(BaseQuantity): __base_unit__ = "degree" -class Voltage(BaseQuantity): __compatible_unit__ = "volt" +class ActivePower(BaseQuantity): __base_unit__ = "watt" -class Current(BaseQuantity): __compatible_unit__ = "ampere" +class Energy(BaseQuantity): __base_unit__ = "watthour" -class Angle(BaseQuantity): __compatible_unit__ = "degree" +class Time(BaseQuantity): __base_unit__ = "minute" -class ActivePower(BaseQuantity): __compatible_unit__ = "watt" +class Resistance(BaseQuantity): __base_unit__ = "ohm" -class Energy(BaseQuantity): __compatible_unit__ = "watthour" -class Time(BaseQuantity): __compatible_unit__ = "minute" +class Test(Component): + voltage: Voltage -class Resistance(BaseQuantity): __compatible_unit__ = "ohm" +Test(name="test", voltage=Voltage(100, "kV")) diff --git a/src/infrasys/system.py b/src/infrasys/system.py index 6202d51..0e1a1fc 100644 --- a/src/infrasys/system.py +++ b/src/infrasys/system.py @@ -1064,7 +1064,7 @@ def _deserialize_fields( values[field] = composed_value elif isinstance(metadata.fields, SerializedQuantityType): quantity_type = cached_types.get_type(metadata.fields) - values[field] = quantity_type.from_dict(value) + values[field] = quantity_type(value=value["value"], units=value["units"]) else: msg = f"Bug: unhandled type: {field=} {value=}" raise NotImplementedError(msg) diff --git a/src/infrasys/time_series_models.py b/src/infrasys/time_series_models.py index a374df0..3fccbe4 100644 --- a/src/infrasys/time_series_models.py +++ b/src/infrasys/time_series_models.py @@ -9,6 +9,7 @@ import numpy as np import pyarrow as pa +import pint from pydantic import ( Field, WithJsonSchema, @@ -62,7 +63,7 @@ def get_time_series_metadata_type() -> Type: class SingleTimeSeries(TimeSeriesData): """Defines a time array with a single dimension of floats.""" - data: pa.Array | BaseQuantity + data: pa.Array | pint.Quantity resolution: timedelta initial_time: datetime diff --git a/tests/test_base_quantity.py b/tests/test_base_quantity.py new file mode 100644 index 0000000..8107aee --- /dev/null +++ b/tests/test_base_quantity.py @@ -0,0 +1,48 @@ +from infrasys.base_quantity import ureg, BaseQuantity +from infrasys.quantities import ActivePower, Time +from pint.errors import DimensionalityError +import pytest +import numpy as np + + +def test_base_quantity(): + distance_quantity = ureg.check(None, "meter")(BaseQuantity) + + unit = distance_quantity(100, "meter") + assert isinstance(unit, BaseQuantity) + + # Check that we can not assign units that are not-related. + with pytest.raises(DimensionalityError): + _ = distance_quantity(100, "kWh") + + # Check unit multiplication + active_power_quantity = ActivePower(100, "kW") + hours = Time(2, "h") + + result_quantity = active_power_quantity * hours + assert result_quantity.check("[energy]") + assert result_quantity.magnitude == 200 + + # Check to dict + assert result_quantity.to_dict() == { + "value": result_quantity.magnitude, + "units": str(result_quantity.units), + } + + +def test_base_quantity_numpy(): + array = np.arange(0, 10) + measurements = ActivePower(array, "kW") + assert isinstance(measurements, BaseQuantity) + assert measurements.to_dict()["value"] == array.tolist() + + +def test_unit_deserialization(): + test_units = { + "value": 100, + "units": "kilowatt", # The unit name should be the pint default name + } + active_power = BaseQuantity.from_dict(test_units) + assert isinstance(active_power, BaseQuantity) + assert active_power.magnitude == 100 + assert str(active_power.units) == "kilowatt" diff --git a/tests/test_serialization.py b/tests/test_serialization.py index b7cb98c..5043795 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -8,7 +8,8 @@ from pydantic import WithJsonSchema from typing_extensions import Annotated -from infrasys import Component, Location, SingleTimeSeries +from infrasys import Location, SingleTimeSeries +from infrasys.component import Component from infrasys.quantities import Distance, ActivePower from infrasys.exceptions import ISOperationNotAllowed from infrasys.normalization import NormalizationMax