Skip to content

Commit

Permalink
Merge pull request #761 from PowerGridModel/refactor/move-dataset-type-to-internal-utils
Browse files Browse the repository at this point in the history

move dataset type from dataset to _utils
  • Loading branch information
nitbharambe authored Oct 7, 2024
2 parents 677c078 + def967b commit 937b6f3
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 84 deletions.
46 changes: 46 additions & 0 deletions src/power_grid_model/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
SparseBatchData,
)
from power_grid_model.enum import ComponentAttributeFilterOptions
from power_grid_model.errors import PowerGridError
from power_grid_model.typing import ComponentAttributeMapping, _ComponentAttributeMappingDict


Expand Down Expand Up @@ -630,3 +631,48 @@ def _extract_row_based_data(data: ComponentData, is_batch: bool | None = None) -

def _extract_data_from_component_data(data: ComponentData, is_batch: bool | None = None):
    """Dispatch to the columnar or row-based extractor depending on the layout of ``data``."""
    if is_columnar(data):
        return _extract_columnar_data(data, is_batch)
    return _extract_row_based_data(data, is_batch)


def get_dataset_type(data: Dataset) -> DatasetType:
    """
    Deduce the dataset type from the provided dataset.

    The deduction compares the numpy dtype of each row-based component against the
    metadata of every known dataset type and discards types that do not match.

    Args:
        data: the dataset

    Raises:
        ValueError
            if the dataset type cannot be deduced because multiple dataset types match the format
            (probably because the data contained no supported components, e.g. was empty)
        PowerGridError
            if no dataset type matches the format of the data
            (probably because the data contained conflicting data formats)

    Returns:
        The dataset type.
    """
    # Purely columnar data carries no row-based dtype to compare against the metadata.
    if all(is_columnar(component_values) for component_values in data.values()):
        raise ValueError("The dataset type could not be deduced. At least one component should have row based data.")

    remaining = set(power_grid_meta_data.keys())
    for candidate_type, components_metadata in power_grid_meta_data.items():
        for component_name, component_metadata in components_metadata.items():
            if component_name not in data or is_columnar(data[component_name]):
                continue
            candidate_data = data[component_name]

            # Sparse batch data stores the actual array under the "data" key.
            actual_dtype = candidate_data["data"].dtype if is_sparse(candidate_data) else candidate_data.dtype
            if actual_dtype is not component_metadata.dtype:
                remaining.discard(candidate_type)
                break

    if not remaining:
        raise PowerGridError(
            "The dataset type could not be deduced because no type matches the data. "
            "This usually means inconsistent data was provided."
        )
    if len(remaining) > 1:
        raise ValueError("The dataset type could not be deduced because multiple dataset types match the data.")

    return next(iter(remaining))
48 changes: 1 addition & 47 deletions src/power_grid_model/core/power_grid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from typing import Any, Mapping, Optional

from power_grid_model._utils import is_columnar, is_nan_or_equivalent, is_sparse, process_data_filter
from power_grid_model._utils import get_dataset_type, is_columnar, is_nan_or_equivalent, process_data_filter
from power_grid_model.core.buffer_handling import (
BufferProperties,
CAttributeBuffer,
Expand All @@ -29,7 +29,6 @@
from power_grid_model.core.power_grid_meta import ComponentMetaData, DatasetMetaData, power_grid_meta_data
from power_grid_model.data_types import AttributeType, ComponentData, Dataset
from power_grid_model.enum import ComponentAttributeFilterOptions
from power_grid_model.errors import PowerGridError
from power_grid_model.typing import ComponentAttributeMapping, _ComponentAttributeMappingDict


Expand Down Expand Up @@ -126,51 +125,6 @@ def total_elements(self) -> Mapping[ComponentType, int]:
}


def get_dataset_type(data: Dataset) -> DatasetType:
    """
    Deduce the dataset type from the provided dataset.
    Args:
        data: the dataset
    Raises:
        ValueError
            if the dataset type cannot be deduced because multiple dataset types match the format
            (probably because the data contained no supported components, e.g. was empty)
        PowerGridError
            if no dataset type matches the format of the data
            (probably because the data contained conflicting data formats)
    Returns:
        The dataset type.
    """
    # Start from every known dataset type; types whose metadata conflicts with the
    # data are discarded below until (ideally) exactly one candidate remains.
    candidates = set(power_grid_meta_data.keys())

    # Deduction relies on row-based numpy dtypes; purely columnar data has none to compare.
    if all(is_columnar(v) for v in data.values()):
        raise ValueError("The dataset type could not be deduced. At least one component should have row based data.")

    for dataset_type, dataset_metadatas in power_grid_meta_data.items():
        for component, dataset_metadata in dataset_metadatas.items():
            if component not in data or is_columnar(data[component]):
                continue
            component_data = data[component]

            # Sparse batch data stores the actual array under the "data" key.
            component_dtype = component_data["data"].dtype if is_sparse(component_data) else component_data.dtype
            if component_dtype is not dataset_metadata.dtype:
                candidates.discard(dataset_type)
                break

    if not candidates:
        raise PowerGridError(
            "The dataset type could not be deduced because no type matches the data. "
            "This usually means inconsistent data was provided."
        )
    if len(candidates) > 1:
        raise ValueError("The dataset type could not be deduced because multiple dataset types match the data.")

    return next(iter(candidates))


class CMutableDataset:
"""
A view of a user-owned dataset.
Expand Down
2 changes: 1 addition & 1 deletion src/power_grid_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
_extract_indptr,
get_and_verify_batch_sizes as _get_and_verify_batch_sizes,
get_batch_size as _get_batch_size,
get_dataset_type,
is_columnar,
is_sparse,
)
from power_grid_model.core.dataset_definitions import DatasetType, _map_to_component_types
from power_grid_model.core.power_grid_dataset import get_dataset_type
from power_grid_model.core.serialization import ( # pylint: disable=unused-import
json_deserialize,
json_serialize,
Expand Down
34 changes: 1 addition & 33 deletions tests/unit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
#
# SPDX-License-Identifier: MPL-2.0

import itertools

import numpy as np
import pytest

from power_grid_model.core.dataset_definitions import ComponentType, DatasetType
from power_grid_model.core.power_grid_dataset import CConstDataset, get_dataset_type
from power_grid_model.core.power_grid_dataset import CConstDataset
from power_grid_model.core.power_grid_meta import power_grid_meta_data
from power_grid_model.errors import PowerGridError

Expand All @@ -34,36 +32,6 @@ def dataset_type(request):
return request.param


def test_get_dataset_type(dataset_type):
    """Deduction returns the dataset type whose metadata dtypes the components were built from."""
    data = {
        ComponentType.node: np.zeros(1, dtype=power_grid_meta_data[dataset_type]["node"]),
        ComponentType.sym_load: np.zeros(1, dtype=power_grid_meta_data[dataset_type]["sym_load"]),
    }
    assert get_dataset_type(data=data) == dataset_type


def test_get_dataset_type__empty_data():
    # An empty dataset has no row-based component to compare, so deduction must fail.
    with pytest.raises(ValueError):
        get_dataset_type(data={})


def test_get_dataset_type__conflicting_data():
    """Mixing component dtypes from two different dataset types must raise; matching types deduce."""
    for first, second in itertools.product(all_dataset_types(), all_dataset_types()):
        data = {
            "node": np.zeros(1, dtype=power_grid_meta_data[first]["node"]),
            "sym_load": np.zeros(1, dtype=power_grid_meta_data[second]["sym_load"]),
        }
        if first != second:
            with pytest.raises(PowerGridError):
                get_dataset_type(data=data)
        else:
            assert get_dataset_type(data=data) == first


def test_const_dataset__empty_dataset(dataset_type):
dataset = CConstDataset(data={}, dataset_type=dataset_type)
info = dataset.get_info()
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_internal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: MPL-2.0

import itertools
from unittest.mock import MagicMock, patch

import numpy as np
Expand All @@ -14,14 +15,17 @@
convert_dataset_to_python_dataset,
get_and_verify_batch_sizes,
get_batch_size,
get_dataset_type,
is_nan,
process_data_filter,
split_dense_batch_data_in_batches,
split_sparse_batch_data_in_batches,
)
from power_grid_model.core.dataset_definitions import ComponentType as CT, DatasetType as DT
from power_grid_model.core.power_grid_meta import power_grid_meta_data
from power_grid_model.data_types import BatchDataset, BatchList
from power_grid_model.enum import ComponentAttributeFilterOptions
from power_grid_model.errors import PowerGridError

from .utils import convert_python_to_numpy

Expand Down Expand Up @@ -581,3 +585,37 @@ def test_get_batch_size(data, expected_size):
def test_get_batch_size__single_dataset_is_not_supported(data):
    # A single (non-batch) dataset has no batch dimension, so it must be rejected.
    with pytest.raises(ValueError):
        get_batch_size(data)


@pytest.mark.parametrize("dataset_type", [DT.input, DT.update, DT.sym_output, DT.asym_output, DT.sc_output])
def test_get_dataset_type(dataset_type):
    """Deduction returns the dataset type whose metadata dtypes the components were built from."""
    data = {
        CT.node: np.zeros(1, dtype=power_grid_meta_data[dataset_type]["node"]),
        CT.sym_load: np.zeros(1, dtype=power_grid_meta_data[dataset_type]["sym_load"]),
    }
    assert get_dataset_type(data=data) == dataset_type


def test_get_dataset_type__empty_data():
    # An empty dataset has no row-based component to compare, so deduction must fail.
    with pytest.raises(ValueError):
        get_dataset_type(data={})


def test_get_dataset_type__conflicting_data():
    """Mixing component dtypes from two different dataset types must raise; matching types deduce.

    The dataset-type list was previously duplicated verbatim as both ``product``
    arguments; hoist it into one local so the two iterables cannot drift apart.
    """
    dataset_types = [DT.input, DT.update, DT.sym_output, DT.asym_output, DT.sc_output]
    for first, second in itertools.product(dataset_types, repeat=2):
        data = {
            "node": np.zeros(1, dtype=power_grid_meta_data[first]["node"]),
            "sym_load": np.zeros(1, dtype=power_grid_meta_data[second]["sym_load"]),
        }
        if first == second:
            assert get_dataset_type(data=data) == first
        else:
            with pytest.raises(PowerGridError):
                get_dataset_type(data=data)
4 changes: 1 addition & 3 deletions tests/unit/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
import pytest

from power_grid_model import DatasetType
from power_grid_model._utils import is_columnar, is_sparse
from power_grid_model.core.dataset_definitions import ComponentType
from power_grid_model.core.power_grid_dataset import get_dataset_type
from power_grid_model._utils import get_dataset_type, is_columnar, is_sparse
from power_grid_model.data_types import BatchDataset, Dataset, SingleDataset
from power_grid_model.enum import ComponentAttributeFilterOptions
from power_grid_model.utils import json_deserialize, json_serialize, msgpack_deserialize, msgpack_serialize
Expand Down

0 comments on commit 937b6f3

Please sign in to comment.