Merge pull request #55 from NREL/fix/time-series-as-numpy
Use numpy arrays in time series
daniel-thom authored Nov 18, 2024
2 parents 169accf + 00c9dfe commit f3703dd
Showing 13 changed files with 106 additions and 65 deletions.
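
In short: `SingleTimeSeries.data` is now a `numpy.ndarray` (or a `pint.Quantity` wrapping one) instead of a `pyarrow.Array`; pyarrow remains only at the storage layer. A minimal sketch of the caller-facing change (field names taken from the diff below; `uuid` and `normalization` are assumed to take their defaults):

    from datetime import datetime, timedelta

    import numpy as np

    from infrasys.time_series_models import SingleTimeSeries

    ts = SingleTimeSeries(
        variable_name="active_power",
        resolution=timedelta(hours=1),
        initial_time=datetime(2020, 1, 1),
        data=np.arange(24, dtype=np.float64),  # was pa.array(range(24)) before this change
    )
    assert isinstance(ts.data, np.ndarray)  # no longer pa.Array
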
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -9,6 +9,6 @@ repos:
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
rev: v1.13.0
hooks:
- id: mypy
14 changes: 12 additions & 2 deletions pyproject.toml
@@ -27,7 +27,10 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"numpy~=1.26.4",
# This is what we want, but we are currently limited by pyarrow.
# Leave unbounded until pyarrow upgrades numpy.
#"numpy >= 2, < 3",
"numpy",
"pyarrow~=15.0.2",
"pint~=0.23",
"loguru~=0.7.2",
@@ -37,7 +40,7 @@ dependencies = [
[project.optional-dependencies]
dev = [
"furo",
"mypy",
"mypy >=1.13, < 2",
"myst_parser",
"pre-commit",
"pyarrow-stubs",
@@ -56,6 +59,13 @@ Documentation = "https://github.com/NREL/infrasys#readme"
Issues = "https://github.com/NREL/infrasys/issues"
Source = "https://github.com/NREL/infrasys"

[tool.mypy]
check_untyped_defs = true
files = [
"src",
"tests",
]

[tool.pytest.ini_options]
pythonpath = "src"
minversion = "6.0"
17 changes: 11 additions & 6 deletions src/infrasys/arrow_storage.py
@@ -8,6 +8,7 @@
from typing import Any, Optional
from uuid import UUID

import numpy as np
import pyarrow as pa
import pint
from loguru import logger
@@ -112,24 +113,28 @@ def _get_single_time_series(
index, length = metadata.get_range(start_time=start_time, length=length)
data = base_ts[metadata.variable_name][index : index + length]
if metadata.quantity_metadata is not None:
data = metadata.quantity_metadata.quantity_type(data, metadata.quantity_metadata.units)
np_array = metadata.quantity_metadata.quantity_type(
data, metadata.quantity_metadata.units
)
else:
np_array = np.array(data)
return SingleTimeSeries(
uuid=metadata.time_series_uuid,
variable_name=metadata.variable_name,
resolution=metadata.resolution,
initial_time=start_time or metadata.initial_time,
data=data,
data=np_array,
normalization=metadata.normalization,
)

def _convert_to_record_batch(
self, time_series: SingleTimeSeries, variable_name: str
) -> pa.RecordBatch:
"""Create record batch to save array to disk."""
pa_array = time_series.data
if not isinstance(pa_array, pa.Array) and isinstance(pa_array, pint.Quantity):
pa_array = pa.array(pa_array.magnitude)
assert isinstance(pa_array, pa.Array)
if isinstance(time_series.data, pint.Quantity):
pa_array = pa.array(time_series.data.magnitude)
else:
pa_array = pa.array(time_series.data)
schema = pa.schema([pa.field(variable_name, pa_array.type)])
return pa.record_batch([pa_array], schema=schema)

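For context, the new conversion path can be exercised standalone (a sketch mirroring `_convert_to_record_batch`, not part of the diff; assumes a default pint registry):

    import numpy as np
    import pint
    import pyarrow as pa

    ureg = pint.UnitRegistry()

    def to_record_batch(data, variable_name: str) -> pa.RecordBatch:
        # Strip pint units first; pa.array() accepts the numpy payload directly.
        if isinstance(data, pint.Quantity):
            pa_array = pa.array(data.magnitude)
        else:
            pa_array = pa.array(data)
        schema = pa.schema([pa.field(variable_name, pa_array.type)])
        return pa.record_batch([pa_array], schema=schema)

    batch = to_record_batch(np.arange(4.0) * ureg.watt, "active_power")
    assert batch.column(0).to_pylist() == [0.0, 1.0, 2.0, 3.0]
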
3 changes: 2 additions & 1 deletion src/infrasys/function_data.py
@@ -4,6 +4,7 @@

import numpy as np
import pint
from numpy.typing import NDArray
from pydantic import Field, model_validator
from pydantic.functional_validators import AfterValidator
from typing_extensions import Annotated
@@ -186,7 +187,7 @@ def validate_piecewise_xy(self):
return self


def get_x_lengths(x_coords: List[float]) -> List[float]:
def get_x_lengths(x_coords: List[float]) -> NDArray[np.float64]:
return np.subtract(x_coords[1:], x_coords[:-1])


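Only the return annotation changes here; `np.subtract` on list slices already yields an `ndarray` (quick sketch):

    import numpy as np

    x_coords = [1.0, 2.0, 4.0, 7.0]
    lengths = np.subtract(x_coords[1:], x_coords[:-1])  # what get_x_lengths returns
    assert isinstance(lengths, np.ndarray)
    assert lengths.tolist() == [1.0, 2.0, 3.0]
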
5 changes: 3 additions & 2 deletions src/infrasys/normalization.py
@@ -5,6 +5,7 @@
from typing import Literal, Optional, Annotated, Union

import numpy as np
from numpy.typing import NDArray
from pydantic import Field

from infrasys.models import InfraSysBaseModel
@@ -29,7 +30,7 @@ class NormalizationMax(NormalizationBase):
max_value: Optional[float] = None
normalization_type: Literal[NormalizationType.MAX] = NormalizationType.MAX

def normalize_array(self, data: np.ndarray) -> np.ndarray:
def normalize_array(self, data: NDArray) -> NDArray:
self.max_value = np.max(data)
return data / self.max_value

@@ -40,7 +41,7 @@ class NormalizationByValue(NormalizationBase):
value: float
normalization_type: Literal[NormalizationType.BY_VALUE] = NormalizationType.BY_VALUE

def normalize_array(self, data: np.ndarray) -> np.ndarray:
def normalize_array(self, data: NDArray) -> NDArray:
return data / self.value


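`NDArray` is `numpy.typing.NDArray`; runtime behavior is unchanged. A usage sketch:

    import numpy as np

    from infrasys.normalization import NormalizationByValue, NormalizationMax

    data = np.array([2.0, 4.0, 8.0])
    assert NormalizationMax().normalize_array(data).tolist() == [0.25, 0.5, 1.0]
    assert NormalizationByValue(value=2.0).normalize_array(data).tolist() == [1.0, 2.0, 4.0]
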
45 changes: 28 additions & 17 deletions src/infrasys/time_series_models.py
@@ -17,8 +17,8 @@
from uuid import UUID

import numpy as np
import pyarrow as pa
import pint
from numpy.typing import NDArray
from pydantic import (
Field,
WithJsonSchema,
@@ -29,7 +29,6 @@
)
from typing_extensions import Annotated

from infrasys.base_quantity import BaseQuantity
from infrasys.exceptions import (
ISConflictingArguments,
InconsistentTimeseriesAggregation,
@@ -42,7 +41,7 @@
VALUE_COLUMN = "value"


ISArray: TypeAlias = Sequence | pa.Array | np.ndarray | pint.Quantity
ISArray: TypeAlias = Sequence | NDArray | pint.Quantity


class TimeSeriesStorageType(str, Enum):
@@ -74,7 +73,7 @@ def get_time_series_metadata_type() -> Type:
class SingleTimeSeries(TimeSeriesData):
"""Defines a time array with a single dimension of floats."""

data: pa.Array | pint.Quantity
data: NDArray | pint.Quantity
resolution: timedelta
initial_time: datetime

@@ -83,21 +82,36 @@ def length(self) -> int:
"""Return the length of the data."""
return len(self.data)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, SingleTimeSeries):
raise NotImplementedError
is_equal = True
for field in self.model_fields_set:
if field == "data":
if not (self.data == other.data).all():
is_equal = False
break
else:
if not getattr(self, field) == getattr(other, field):
is_equal = False
break
return is_equal

@field_validator("data", mode="before")
@classmethod
def check_data(
cls, data
) -> pa.Array | pa.ChunkedArray | pint.Quantity: # Standardize what object we receive.
def check_data(cls, data) -> NDArray | pint.Quantity: # Standardize what object we receive.
"""Check time series data."""
if len(data) < 2:
msg = f"SingleTimeSeries length must be at least 2: {len(data)}"
raise ValueError(msg)

if isinstance(data, pint.Quantity):
if not isinstance(data.magnitude, np.ndarray):
return type(data)(np.array(data.magnitude), units=data.units)
return data

if not isinstance(data, pa.Array):
return pa.array(data)
if not isinstance(data, np.ndarray):
return np.array(data)

return data
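
The validator standardizes every accepted input onto numpy: plain sequences become arrays, and a `pint.Quantity` keeps its units while its magnitude is coerced to an `ndarray`. A sketch of the same coercions (assuming a default pint registry):

    import numpy as np
    import pint

    ureg = pint.UnitRegistry()

    # Sequence input -> ndarray (the np.array(data) branch above).
    assert isinstance(np.array([1.0, 2.0, 3.0]), np.ndarray)

    # Quantity input -> same Quantity type, ndarray magnitude, units preserved.
    q = ureg.Quantity([1.0, 2.0, 3.0], "watt")
    assert isinstance(q.magnitude, np.ndarray)
    assert str(q.units) == "watt"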

@@ -136,20 +150,17 @@ def aggregate(cls, ts_data: list[Self]) -> Self:
raise InconsistentTimeseriesAggregation(msg)

# Aggregate data
is_quantity = issubclass(next(iter(unique_props["data_type"])), BaseQuantity)
is_quantity = issubclass(next(iter(unique_props["data_type"])), pint.Quantity)
magnitude_type = (
type(ts_data[0].data.magnitude)
if is_quantity
else next(iter(unique_props["data_type"]))
)

# Aggregate data based on magnitude type
if issubclass(magnitude_type, pa.Array):
if issubclass(magnitude_type, np.ndarray):
new_data = sum(
[
data.data.to_numpy() * (data.data.units if is_quantity else 1)
for data in ts_data
]
[data.data * (data.data.units if is_quantity else 1) for data in ts_data]
)
elif issubclass(magnitude_type, np.ndarray):
new_data = sum([data.data for data in ts_data])
@@ -274,7 +285,7 @@ class SingleTimeSeriesScalingFactor(SingleTimeSeries):


class QuantityMetadata(InfraSysBaseModel):
"""Contains the metadata needed to de-serialize time series stored within a BaseQuantity."""
"""Contains the metadata needed to de-serialize time series stored within a pint.Quantity."""

module: str
quantity_type: Annotated[Type, WithJsonSchema({"type": "string"})]
@@ -341,7 +352,7 @@ def from_data(cls, time_series: SingleTimeSeries, **user_attributes) -> Any:
quantity_type=type(time_series.data),
units=str(time_series.data.units),
)
if isinstance(time_series.data, BaseQuantity)
if isinstance(time_series.data, pint.Quantity)
else None
)
return cls(
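The metadata check now keys off `pint.Quantity` rather than `BaseQuantity`, so plain registry quantities get units metadata too (sketch, assuming a default pint registry):

    import numpy as np
    import pint

    ureg = pint.UnitRegistry()

    # Previously only BaseQuantity (an infrasys subclass) matched the isinstance
    # check in from_data; now any pint quantity does.
    q = ureg.Quantity(np.arange(4.0), "watt")
    assert isinstance(q, pint.Quantity)
    assert str(q.units) == "watt"
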
14 changes: 11 additions & 3 deletions src/infrasys/value_curves.py
@@ -7,6 +7,7 @@
QuadraticFunctionData,
PiecewiseLinearData,
PiecewiseStepData,
XYCoords,
running_sum,
)
from pydantic import Field
@@ -117,7 +118,9 @@ def to_input_output(self) -> InputOutputCurve:
points = running_sum(self.function_data)

return InputOutputCurve(
function_data=PiecewiseLinearData(points=[(p.x, p.y + c) for p in points]),
function_data=PiecewiseLinearData(
points=[XYCoords(p.x, p.y + c) for p in points]
),
input_at_zero=self.input_at_zero,
)

@@ -188,7 +191,10 @@ def to_input_output(self) -> InputOutputCurve:
else:
return InputOutputCurve(
function_data=QuadraticFunctionData(
quadratic_term=p, proportional_term=m, constant_term=c
# issue 53
quadratic_term=p,
proportional_term=m,
constant_term=c, # type: ignore
),
input_at_zero=self.input_at_zero,
)
@@ -203,7 +209,9 @@ def to_input_output(self) -> InputOutputCurve:
ys.insert(0, c)

return InputOutputCurve(
function_data=PiecewiseLinearData(points=list(zip(xs, ys))),
function_data=PiecewiseLinearData(
points=[XYCoords(x, y) for x, y in zip(xs, ys)]
),
input_at_zero=self.input_at_zero,
)
case _:
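Constructing the `XYCoords` named tuple explicitly (rather than bare tuples) is what these hunks change to satisfy mypy; a usage sketch:

    from infrasys.function_data import PiecewiseLinearData, XYCoords

    points = [XYCoords(0.0, 1.0), XYCoords(2.0, 3.0), XYCoords(4.0, 6.0)]
    curve_data = PiecewiseLinearData(points=points)
    assert curve_data.points[1].y == 3.0
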
8 changes: 5 additions & 3 deletions tests/test_arrow_storage.py
@@ -3,7 +3,7 @@
from datetime import datetime, timedelta
from pathlib import Path

import pyarrow as pa
import numpy as np
from loguru import logger

from infrasys.arrow_storage import ArrowTimeSeriesStorage
@@ -88,5 +88,7 @@ def test_read_deserialize_time_series(tmp_path):
assert isinstance(deserialize_ts, SingleTimeSeries)
assert deserialize_ts.resolution == ts.resolution
assert deserialize_ts.initial_time == ts.initial_time
assert isinstance(deserialize_ts.data, pa.Array)
assert deserialize_ts.data[-1].as_py() == ts.length - 1
assert isinstance(deserialize_ts.data, np.ndarray)
length = ts.length
assert isinstance(length, int)
assert np.array_equal(deserialize_ts.data, np.array(range(length)))
4 changes: 2 additions & 2 deletions tests/test_base_quantity.py
@@ -76,7 +76,7 @@ class _(BaseQuantity):
with pytest.raises(ValidationError):
BaseQuantityComponent(name="test", voltage=Voltage(test_magnitude, "meter"))

test_component = BaseQuantityComponent(name="test", voltage=[0, 1])
test_component = BaseQuantityComponent(name="test", voltage=Voltage([0, 1], units="volt"))
assert type(test_component.voltage) is Voltage
assert test_component.voltage.magnitude.tolist() == [0, 1]
assert test_component.voltage.units == test_unit
@@ -91,7 +91,7 @@ def test_different_validate(input_unit):


def test_custom_serialization():
component = BaseQuantityComponent(name="test", voltage=10.0)
component = BaseQuantityComponent(name="test", voltage=Voltage(10.0, units="volt"))

model_dump = component.model_dump(mode="json")

9 changes: 4 additions & 5 deletions tests/test_serialization.py
@@ -4,7 +4,6 @@
from datetime import datetime, timedelta

import numpy as np
import pyarrow as pa
import pytest
from pydantic import WithJsonSchema
from typing_extensions import Annotated
@@ -98,7 +97,7 @@ def test_serialize_time_series(tmp_path, time_series_in_memory):
system2.remove_time_series(gen1b, variable_name=variable_name)

ts2 = system.get_time_series(gen1b, variable_name=variable_name)
assert ts2.data == ts.data
assert np.array_equal(ts2.data, ts.data)

system3 = SimpleSystem.from_json(filename, time_series_read_only=False)
assert system3.get_time_series_directory() != SimpleSystem._make_time_series_directory(
@@ -121,6 +120,7 @@ def test_serialize_quantity(tmp_path, distance):
system = SimpleSystem()
gen = SimpleGenerator.example()
component = ComponentWithPintQuantity(name="test", distance=distance)
assert gen.bus.coordinates is not None
system.add_components(gen.bus.coordinates, gen.bus, gen, component)
sys_file = tmp_path / "system.json"
system.to_json(sys_file)
@@ -156,9 +156,8 @@ def test_with_time_series_quantity(tmp_path):
assert ts.length == length
assert ts.resolution == resolution
assert ts.initial_time == initial_time
assert isinstance(ts2.data.magnitude, pa.Array)
assert ts2.data[-1].as_py() == length - 1
assert ts2.data.magnitude == pa.array(range(length))
assert isinstance(ts2.data.magnitude, np.ndarray)
assert np.array_equal(ts2.data.magnitude, np.array(range(length)))


@pytest.mark.parametrize("in_memory", [True, False])