Skip to content

Commit

Permalink
fix(pint_units): Fixed BaseQuantity to allow quantities multiplicat…
Browse files Browse the repository at this point in the history
…ions (#25)

List of changes:
- Removed `__new__` from BaseQuantity which was causing the problem
- Renamed `__compatible_unit__` to `__base_unit__`
- Added `ureg.check` to default units. This will check if the unit is compatible, but will not check the input value.
- Added testing for BaseQuantity class and units.
- Added a serialization property so that we can call `component.model_dump(mode="json")` when using pint.Quantities,

Closes #24
  • Loading branch information
pesap authored May 22, 2024
1 parent fe2295d commit dd818c2
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 24 deletions.
70 changes: 58 additions & 12 deletions src/infrasys/base_quantity.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,71 @@
""" This module contains base class for handling pint quantity."""
"""This module contains base class for handling pint quantity."""

from abc import ABC
from typing import Any
from typing import Any, TYPE_CHECKING

if TYPE_CHECKING:
from __main__ import BaseQuantity

import numpy as np
import pint
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
from typing_extensions import Annotated

ureg = pint.UnitRegistry()


class BaseQuantity(ureg.Quantity, ABC): # type: ignore
class BaseQuantity(ureg.Quantity): # type: ignore
"""Interface for base quantity."""

def __new__(cls, value, units, **kwargs):
instance = super().__new__(cls, value, units, **kwargs)
if not hasattr(cls, "__compatible_unit__"):
raise ValueError("You should define __compatible_unit__ attribute in your class.")
if not instance.is_compatible_with(cls.__compatible_unit__):
message = f"{__class__} must be compatible with {cls.__compatible_unit__}, not {units}"
raise ValueError(message)
return instance
__base_unit__ = None

def __init_subclass__(cls, **kwargs):
if not cls.__base_unit__:
raise TypeError("__base_unit__ should be defined")
super().__init_subclass__(**kwargs)

# NOTE: This creates a type hint for the unit.
def __class_getitem__(cls):
return Annotated.__class_getitem__((cls, cls.__base_unit__)) # type: ignore

# Required for pydantic validation
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.with_info_after_validator_function(
cls.validate,
handler(pint.Quantity),
field_name=handler.field_name,
serialization=core_schema.plain_serializer_function_ser_schema(
cls.serialize, info_arg=False, return_schema=core_schema.str_schema()
),
)

# Required for pydantic validation
@classmethod
def validate(cls, value, *_) -> "BaseQuantity":
if isinstance(value, BaseQuantity):
if cls.__base_unit__:
assert value.check(
cls.__base_unit__
), f"Unit must be compatible with {cls.__base_unit__}"
return cls(value)
if isinstance(value, pint.Quantity):
if cls.__base_unit__:
assert value.check(
cls.__base_unit__
), f"Unit must be compatible with {cls.__base_unit__}"
return cls(value.magnitude, value.units)
else:
raise ValueError(f"Invalid type for BaseQuantity: {type(value)}")
if isinstance(value, cls):
return cls(value)
return value

@staticmethod
def serialize(value):
return str(value)

def to_dict(self) -> dict[str, Any]:
"""Convert a quantity to a dictionary for serialization."""
Expand Down
29 changes: 20 additions & 9 deletions src/infrasys/quantities.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
"""This module defines basic unit quantities."""
"""This module defines basic unit quantities.
To create new Quantities for a given base unit, we just need to specify the
base unit as the second argument of `ureg.check`.
"""

from infrasys.base_quantity import BaseQuantity
from infrasys.component import Component

# ruff:noqa
# fmt: off

class Distance(BaseQuantity): __compatible_unit__ = "meter"
class Distance(BaseQuantity): __base_unit__ = "meter"

class Voltage(BaseQuantity): __base_unit__ = "volt"

class Current(BaseQuantity): __base_unit__ = "ampere"

class Angle(BaseQuantity): __base_unit__ = "degree"

class Voltage(BaseQuantity): __compatible_unit__ = "volt"
class ActivePower(BaseQuantity): __base_unit__ = "watt"

class Current(BaseQuantity): __compatible_unit__ = "ampere"
class Energy(BaseQuantity): __base_unit__ = "watthour"

class Angle(BaseQuantity): __compatible_unit__ = "degree"
class Time(BaseQuantity): __base_unit__ = "minute"

class ActivePower(BaseQuantity): __compatible_unit__ = "watt"
class Resistance(BaseQuantity): __base_unit__ = "ohm"

class Energy(BaseQuantity): __compatible_unit__ = "watthour"

class Time(BaseQuantity): __compatible_unit__ = "minute"
class Test(Component):
voltage: Voltage

class Resistance(BaseQuantity): __compatible_unit__ = "ohm"
Test(name="test", voltage=Voltage(100, "kV"))
2 changes: 1 addition & 1 deletion src/infrasys/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ def _deserialize_fields(
values[field] = composed_value
elif isinstance(metadata.fields, SerializedQuantityType):
quantity_type = cached_types.get_type(metadata.fields)
values[field] = quantity_type.from_dict(value)
values[field] = quantity_type(value=value["value"], units=value["units"])
else:
msg = f"Bug: unhandled type: {field=} {value=}"
raise NotImplementedError(msg)
Expand Down
3 changes: 2 additions & 1 deletion src/infrasys/time_series_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import pyarrow as pa
import pint
from pydantic import (
Field,
WithJsonSchema,
Expand Down Expand Up @@ -62,7 +63,7 @@ def get_time_series_metadata_type() -> Type:
class SingleTimeSeries(TimeSeriesData):
"""Defines a time array with a single dimension of floats."""

data: pa.Array | BaseQuantity
data: pa.Array | pint.Quantity
resolution: timedelta
initial_time: datetime

Expand Down
48 changes: 48 additions & 0 deletions tests/test_base_quantity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from infrasys.base_quantity import ureg, BaseQuantity
from infrasys.quantities import ActivePower, Time
from pint.errors import DimensionalityError
import pytest
import numpy as np


def test_base_quantity():
distance_quantity = ureg.check(None, "meter")(BaseQuantity)

unit = distance_quantity(100, "meter")
assert isinstance(unit, BaseQuantity)

# Check that we can not assign units that are not-related.
with pytest.raises(DimensionalityError):
_ = distance_quantity(100, "kWh")

# Check unit multiplication
active_power_quantity = ActivePower(100, "kW")
hours = Time(2, "h")

result_quantity = active_power_quantity * hours
assert result_quantity.check("[energy]")
assert result_quantity.magnitude == 200

# Check to dict
assert result_quantity.to_dict() == {
"value": result_quantity.magnitude,
"units": str(result_quantity.units),
}


def test_base_quantity_numpy():
array = np.arange(0, 10)
measurements = ActivePower(array, "kW")
assert isinstance(measurements, BaseQuantity)
assert measurements.to_dict()["value"] == array.tolist()


def test_unit_deserialization():
test_units = {
"value": 100,
"units": "kilowatt", # The unit name should be the pint default name
}
active_power = BaseQuantity.from_dict(test_units)
assert isinstance(active_power, BaseQuantity)
assert active_power.magnitude == 100
assert str(active_power.units) == "kilowatt"
3 changes: 2 additions & 1 deletion tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from pydantic import WithJsonSchema
from typing_extensions import Annotated

from infrasys import Component, Location, SingleTimeSeries
from infrasys import Location, SingleTimeSeries
from infrasys.component import Component
from infrasys.quantities import Distance, ActivePower
from infrasys.exceptions import ISOperationNotAllowed
from infrasys.normalization import NormalizationMax
Expand Down

0 comments on commit dd818c2

Please sign in to comment.