Skip to content

Commit

Permalink
fix: Add correct representation of Value curves and remove name from …
Browse files Browse the repository at this point in the history
…cost and function data components (#45)

List of changes
- Changed data handling when we create an instance of
  `SingleTimeSeries`.
- Updated the test to check that the data inside a `SingleTimeSeries` is consistent, since
  serializing and deserializing can return two different instances.
- Updated arrow storage to convert to `pa.Array` only when serializing,
  and improved its type hints
  • Loading branch information
pesap authored Oct 22, 2024
1 parent 9dc0e3e commit ae517df
Show file tree
Hide file tree
Showing 11 changed files with 138 additions and 71 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "infrasys"
version = "0.1.0"
version = "0.1.1"
description = ''
readme = "README.md"
requires-python = ">=3.10, <3.13"
Expand Down
10 changes: 7 additions & 3 deletions src/infrasys/arrow_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from uuid import UUID

import pyarrow as pa
import pint
from loguru import logger

from infrasys.exceptions import ISNotStored
from infrasys.base_quantity import BaseQuantity
from infrasys.time_series_models import (
SingleTimeSeries,
SingleTimeSeriesMetadata,
Expand Down Expand Up @@ -122,9 +122,13 @@ def _get_single_time_series(
normalization=metadata.normalization,
)

def _convert_to_record_batch(self, array: SingleTimeSeries, variable_name: str):
def _convert_to_record_batch(
self, time_series: SingleTimeSeries, variable_name: str
) -> pa.RecordBatch:
"""Create record batch to save array to disk."""
pa_array = array.data.magnitude if isinstance(array.data, BaseQuantity) else array.data
pa_array = time_series.data
if not isinstance(pa_array, pa.Array) and isinstance(pa_array, pint.Quantity):
pa_array = pa.array(pa_array.magnitude)
assert isinstance(pa_array, pa.Array)
schema = pa.schema([pa.field(variable_name, pa_array.type)])
return pa.record_batch([pa_array], schema=schema)
Expand Down
26 changes: 11 additions & 15 deletions src/infrasys/cost_curves.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing_extensions import Annotated
from infrasys.component import Component
from pydantic import Field
from infrasys.value_curves import InputOutputCurve, IncrementalCurve, AverageRateCurve
from infrasys.function_data import LinearFunctionData
from infrasys.models import InfraSysBaseModelWithIdentifers
from infrasys.value_curves import InputOutputCurve, IncrementalCurve, AverageRateCurve, LinearCurve
import pint


class ProductionVariableCostCurve(Component):
name: Annotated[str, Field(frozen=True)] = ""
class ProductionVariableCostCurve(InfraSysBaseModelWithIdentifers):
...


class CostCurve(ProductionVariableCostCurve):
Expand All @@ -23,12 +23,10 @@ class CostCurve(ProductionVariableCostCurve):
description="The underlying `ValueCurve` representation of this `ProductionVariableCostCurve`"
),
]
vom_units: Annotated[
vom_cost: Annotated[
InputOutputCurve,
Field(description="(default: natural units (MW)) The units for the x-axis of the curve"),
] = InputOutputCurve(
function_data=LinearFunctionData(proportional_term=0.0, constant_term=0.0)
)
] = LinearCurve(0.0)


class FuelCurve(ProductionVariableCostCurve):
Expand All @@ -45,15 +43,13 @@ class FuelCurve(ProductionVariableCostCurve):
description="The underlying `ValueCurve` representation of this `ProductionVariableCostCurve`"
),
]
vom_units: Annotated[
vom_cost: Annotated[
InputOutputCurve,
Field(description="(default: natural units (MW)) The units for the x-axis of the curve"),
] = InputOutputCurve(
function_data=LinearFunctionData(proportional_term=0.0, constant_term=0.0)
)
] = LinearCurve(0.0)
fuel_cost: Annotated[
float,
pint.Quantity | float,
Field(
description="Either a fixed value for fuel cost or the key to a fuel cost time series"
),
]
] = 0.0
18 changes: 10 additions & 8 deletions src/infrasys/function_data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Defines models for cost functions"""

from infrasys import Component
from typing_extensions import Annotated
from typing import List, NamedTuple

import numpy as np
import pint
from pydantic import Field, model_validator
from pydantic.functional_validators import AfterValidator
from typing import NamedTuple, List
import numpy as np
from typing_extensions import Annotated

from infrasys.models import InfraSysBaseModelWithIdentifers


class XYCoords(NamedTuple):
Expand All @@ -15,11 +18,9 @@ class XYCoords(NamedTuple):
y: float


class FunctionData(Component):
class FunctionData(InfraSysBaseModelWithIdentifers):
"""BaseClass of FunctionData"""

name: Annotated[str, Field(frozen=True)] = ""


class LinearFunctionData(FunctionData):
"""Data representation for linear cost function.
Expand All @@ -32,7 +33,8 @@ class LinearFunctionData(FunctionData):
"""

proportional_term: Annotated[
float, Field(description="the proportional term in the represented function.")
pint.Quantity | float,
Field(description="the proportional term in the represented function."),
]
constant_term: Annotated[
float, Field(description="the constant term in the represented function.")
Expand Down
12 changes: 5 additions & 7 deletions src/infrasys/time_series_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
VALUE_COLUMN = "value"


ISArray: TypeAlias = Sequence | pa.Array | np.ndarray | BaseQuantity
ISArray: TypeAlias = Sequence | pa.Array | np.ndarray | pint.Quantity


class TimeSeriesStorageType(str, Enum):
Expand All @@ -45,7 +45,6 @@ class TimeSeriesStorageType(str, Enum):
class TimeSeriesData(InfraSysBaseModelWithIdentifers, abc.ABC):
"""Base class for all time series models"""

units: Optional[str] = None
variable_name: str
normalization: NormalizationModel = None

Expand Down Expand Up @@ -74,16 +73,15 @@ def length(self) -> int:

@field_validator("data", mode="before")
@classmethod
def check_data(cls, data) -> pa.Array | BaseQuantity: # Standarize what object we receive.
def check_data(
cls, data
) -> pa.Array | pa.ChunkedArray | pint.Quantity: # Standarize what object we receive.
"""Check time series data."""
if len(data) < 2:
msg = f"SingleTimeSeries length must be at least 2: {len(data)}"
raise ValueError(msg)

if isinstance(data, BaseQuantity):
if not isinstance(data.magnitude, pa.Array):
cls = type(data)
return cls(pa.array(data.magnitude), data.units)
if isinstance(data, pint.Quantity):
return data

if not isinstance(data, pa.Array):
Expand Down
46 changes: 43 additions & 3 deletions src/infrasys/value_curves.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Defines classes for value curves using cost functions"""

from typing_extensions import Annotated
from infrasys.component import Component
from infrasys.exceptions import ISOperationNotAllowed
from infrasys.function_data import (
LinearFunctionData,
Expand All @@ -13,9 +12,10 @@
from pydantic import Field
import numpy as np

from infrasys.models import InfraSysBaseModelWithIdentifers

class ValueCurve(Component):
name: Annotated[str, Field(frozen=True)] = ""

class ValueCurve(InfraSysBaseModelWithIdentifers):
input_at_zero: Annotated[
float | None,
Field(
Expand Down Expand Up @@ -209,3 +209,43 @@ def to_input_output(self) -> InputOutputCurve:
case _:
msg = "Function is not valid for the type of data provided."
raise ISOperationNotAllowed(msg)


def LinearCurve(proportional_term: float = 0.0, constant_term: float = 0.0) -> InputOutputCurve:
    """Build an `InputOutputCurve` backed by linear function data.

    Convenience constructor: the two coefficients are wrapped in a
    `LinearFunctionData`, which is then passed as the `function_data` of a new
    `InputOutputCurve`. Called without arguments, both terms are 0.0.

    Parameters
    ----------
    proportional_term : float, optional
        The slope of the linear curve. Defaults to 0.0.
    constant_term : float, optional
        The y-intercept of the linear curve. Defaults to 0.0.

    Returns
    -------
    InputOutputCurve
        A curve whose `function_data` is a `LinearFunctionData` holding the
        given terms.

    Examples
    --------
    >>> LinearCurve()
    InputOutputCurve(function_data=LinearFunctionData(proportional_term=0.0, constant_term=0.0))
    >>> LinearCurve(10)
    InputOutputCurve(function_data=LinearFunctionData(proportional_term=10.0, constant_term=0.0))
    >>> LinearCurve(10, 20)
    InputOutputCurve(function_data=LinearFunctionData(proportional_term=10.0, constant_term=20.0))
    >>> LinearCurve(proportional_term=5.0, constant_term=15.0)
    InputOutputCurve(function_data=LinearFunctionData(proportional_term=5.0, constant_term=15.0))
    """
    linear_data = LinearFunctionData(
        proportional_term=proportional_term,
        constant_term=constant_term,
    )
    return InputOutputCurve(function_data=linear_data)
3 changes: 2 additions & 1 deletion tests/test_base_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from infrasys.base_quantity import ureg, BaseQuantity
from infrasys.component import Component
from infrasys.quantities import ActivePower, Time, Voltage
from pint import Quantity
from pint.errors import DimensionalityError
import pytest
import numpy as np
Expand All @@ -16,7 +17,7 @@ def test_base_quantity():

unit = distance_quantity(100, "meter")
assert isinstance(unit, BaseQuantity)

assert isinstance(unit, Quantity)
# Check that we can not assign units that are not-related.
with pytest.raises(DimensionalityError):
_ = distance_quantity(100, "kWh")
Expand Down
16 changes: 10 additions & 6 deletions tests/test_cost_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ def test_cost_curve():
value_curve=InputOutputCurve(
function_data=LinearFunctionData(proportional_term=1.0, constant_term=2.0)
),
vom_units=InputOutputCurve(
vom_cost=InputOutputCurve(
function_data=LinearFunctionData(proportional_term=2.0, constant_term=1.0)
),
)

assert isinstance(cost_curve.value_curve.function_data, LinearFunctionData)
assert cost_curve.value_curve.function_data.proportional_term == 1.0
assert cost_curve.vom_units.function_data.proportional_term == 2.0
assert isinstance(cost_curve.vom_cost.function_data, LinearFunctionData)
assert cost_curve.vom_cost.function_data.proportional_term == 2.0


def test_fuel_curve():
Expand All @@ -30,14 +32,16 @@ def test_fuel_curve():
value_curve=InputOutputCurve(
function_data=LinearFunctionData(proportional_term=1.0, constant_term=2.0)
),
vom_units=InputOutputCurve(
vom_cost=InputOutputCurve(
function_data=LinearFunctionData(proportional_term=2.0, constant_term=1.0)
),
fuel_cost=2.5,
)

assert isinstance(fuel_curve.value_curve.function_data, LinearFunctionData)
assert fuel_curve.value_curve.function_data.proportional_term == 1.0
assert fuel_curve.vom_units.function_data.proportional_term == 2.0
assert isinstance(fuel_curve.vom_cost.function_data, LinearFunctionData)
assert fuel_curve.vom_cost.function_data.proportional_term == 2.0
assert fuel_curve.fuel_cost == 2.5


Expand All @@ -48,7 +52,7 @@ def test_value_curve_custom_serialization():
value_curve=InputOutputCurve(
function_data=LinearFunctionData(proportional_term=1.0, constant_term=2.0)
),
vom_units=InputOutputCurve(
vom_cost=InputOutputCurve(
function_data=LinearFunctionData(proportional_term=2.0, constant_term=1.0)
),
),
Expand All @@ -73,7 +77,7 @@ def test_value_curve_serialization(tmp_path):
value_curve=InputOutputCurve(
function_data=LinearFunctionData(proportional_term=1.0, constant_term=2.0)
),
vom_units=InputOutputCurve(
vom_cost=InputOutputCurve(
function_data=LinearFunctionData(proportional_term=2.0, constant_term=1.0)
),
),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_single_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def test_with_quantity():
assert ts.length == length
assert ts.resolution == resolution
assert ts.initial_time == initial_time
assert isinstance(ts.data.magnitude, pa.Array)
assert ts.data[-1].as_py() == length - 1
assert isinstance(ts.data, ActivePower)
assert ts.data[-1].magnitude == length - 1


def test_normalization():
Expand Down
48 changes: 23 additions & 25 deletions tests/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,10 @@ def test_time_series():


@pytest.mark.parametrize(
"params", list(itertools.product([True, False], [True, False], [True, False]))
"in_memory,use_quantity,sql_json",
list(itertools.product([True, False], [True, False], [True, False])),
)
def test_time_series_retrieval(params):
in_memory, use_quantity, sql_json = params
def test_time_series_retrieval(in_memory, use_quantity, sql_json):
try:
if not sql_json:
os.environ["__INFRASYS_NON_JSON_SQLITE__"] = "1"
Expand Down Expand Up @@ -290,29 +290,33 @@ def test_time_series_retrieval(params):
for metadata in system.list_time_series_metadata(gen, scenario="high"):
assert metadata.user_attributes["scenario"] == "high"

assert (
system.get_time_series(
gen, variable_name=variable_name, scenario="high", model_year="2030"
assert all(
np.equal(
system.get_time_series(
gen, variable_name, scenario="high", model_year="2030"
).data,
ts1.data,
)
== ts1
)
assert (
system.get_time_series(
gen, variable_name=variable_name, scenario="high", model_year="2035"
assert all(
np.equal(
system.get_time_series(
gen, variable_name, scenario="high", model_year="2035"
).data,
ts2.data,
)
== ts2
)
assert (
system.get_time_series(
gen, variable_name=variable_name, scenario="low", model_year="2030"
assert all(
np.equal(
system.get_time_series(gen, variable_name, scenario="low", model_year="2030").data,
ts3.data,
)
== ts3
)
assert (
system.get_time_series(
gen, variable_name=variable_name, scenario="low", model_year="2035"
assert all(
np.equal(
system.get_time_series(gen, variable_name, scenario="low", model_year="2035").data,
ts4.data,
)
== ts4
)

with pytest.raises(ISAlreadyAttached):
Expand All @@ -324,12 +328,6 @@ def test_time_series_retrieval(params):
gen, variable_name=variable_name, scenario="high", model_year="2030"
)
assert not system.has_time_series(gen, variable_name=variable_name, model_year="2036")
assert (
system.get_time_series(
gen, variable_name=variable_name, scenario="high", model_year="2030"
)
== ts1
)
with pytest.raises(ISOperationNotAllowed):
system.get_time_series(gen, variable_name=variable_name, scenario="high")
with pytest.raises(ISNotStored):
Expand Down
Loading

0 comments on commit ae517df

Please sign in to comment.