Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into pp/multi_time_from_list
Browse files Browse the repository at this point in the history
  • Loading branch information
ppinchuk committed Aug 14, 2024
2 parents 7b5dcc7 + dbcf643 commit 20e64c8
Show file tree
Hide file tree
Showing 13 changed files with 127 additions and 98 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pull_request_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ jobs:
pip install --upgrade pip
pip install pytest
pip install pytest-cov
pip install pytest-timeout
pip install -e .
- name: Run pytest and Generate coverage report
run: |
Expand Down
20 changes: 2 additions & 18 deletions rex/multi_file_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from rex.renewable_resource import (NSRDB, SolarResource, GeothermalResource,
WindResource, WaveResource,
AbstractInterpolatedResource)
from rex.resource import Resource
from rex.resource import Resource, BaseDatasetIterable
from rex.utilities.exceptions import FileInputError, ResourceRuntimeError
from rex.utilities.utilities import unstupify_path


class MultiH5:
class MultiH5(BaseDatasetIterable):
"""
Class to handle multiple h5 file Resources
"""
Expand All @@ -32,8 +32,6 @@ def __init__(self, h5_files, check_files=False):
self._dset_map = self._map_file_dsets(h5_files)
self._h5_map = self._map_file_instances(set(self._dset_map.values()))

self._i = 0

if check_files:
self._preflight_check()

Expand Down Expand Up @@ -66,19 +64,6 @@ def __getitem__(self, dset):

return ds

def __next__(self):
if self._i >= len(self.datasets):
self._i = 0
raise StopIteration

dset = self.datasets[self._i]
self._i += 1

return dset

def __iter__(self):
return self

def __contains__(self, dset):
return dset in self.datasets

Expand Down Expand Up @@ -405,7 +390,6 @@ def __init__(self, h5_source, unscale=True, str_decode=True,
self._shapes = None
self._chunks = None
self._dtypes = None
self._i = 0

self._interp_var = None
self._use_lapse = use_lapse_rate
Expand Down
13 changes: 1 addition & 12 deletions rex/multi_res_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def __init__(self, h5_hr, h5_lr, handler_class=Resource,
self._lr_res = handler_class(h5_lr, **handle_kwargs)
self._nn_map = nn_map
self._nn_d = nn_d
self._i = 0

if self._nn_map is None:
self._nn_d, self._nn_map = self.make_nn_map(self._hr_res,
Expand Down Expand Up @@ -237,17 +236,7 @@ def __getitem__(self, keys):
return out

def __iter__(self):
return self

def __next__(self):
if self._i >= len(self.datasets):
self._i = 0
raise StopIteration

dset = self.datasets[self._i]
self._i += 1

return dset
return iter(self.datasets)

def __contains__(self, dset):
return dset in self.datasets
Expand Down
19 changes: 2 additions & 17 deletions rex/multi_time_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
WaveResource,
WindResource,
)
from rex.resource import Resource
from rex.resource import Resource, BaseDatasetIterable
from rex.utilities.exceptions import FileInputError
from rex.utilities.parse_keys import parse_keys, parse_slice

Expand Down Expand Up @@ -60,7 +60,6 @@ def __init__(self, h5_path, res_cls=Resource, hsds=False, hsds_kwargs=None,
self._shape = None
self._time_index = None
self._time_slice_map = []
self._i = 0

def __repr__(self):
msg = ("{} for {}:\n Contains data from {} files"
Expand Down Expand Up @@ -421,7 +420,7 @@ def close(self):
f.close()


class MultiTimeResource:
class MultiTimeResource(BaseDatasetIterable):
"""
Class to handle resource data stored temporally across multiple
.h5 files
Expand Down Expand Up @@ -522,7 +521,6 @@ def __init__(self, h5_path, unscale=True, str_decode=True,
self._h5 = MultiTimeH5(self.h5_path, res_cls=res_cls, **cls_kwargs)
self.h5_files = self._h5.h5_files
self.h5_file = self.h5_files[0]
self._i = 0

def __repr__(self):
msg = "{} for {}".format(self.__class__.__name__, self.h5_path)
Expand All @@ -540,19 +538,6 @@ def __exit__(self, type, value, traceback):
def __len__(self):
return len(self.h5.time_index)

def __iter__(self):
return self

def __next__(self):
if self._i >= len(self.datasets):
self._i = 0
raise StopIteration

dset = self.datasets[self._i]
self._i += 1

return dset

def __getitem__(self, keys):
ds, ds_slice = parse_keys(keys)

Expand Down
14 changes: 1 addition & 13 deletions rex/multi_year_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def __init__(self, h5_path, years=None, res_cls=Resource, hsds=False,
self._datasets = None
self._shape = None
self._time_index = None
self._i = 0

def __repr__(self):
msg = ("{} for {}:\n Contains data for {} years"
Expand All @@ -82,17 +81,7 @@ def __getitem__(self, year):
return h5

def __iter__(self):
return self

def __next__(self):
if self._i >= len(self.years):
self._i = 0
raise StopIteration

year = self.years[self._i]
self._i += 1

return year
return iter(self.years)

def __contains__(self, year):
return year in self.years
Expand Down Expand Up @@ -451,7 +440,6 @@ def __init__(self, h5_path, years=None, unscale=True, str_decode=True,
**cls_kwargs)
self.h5_files = self._h5.h5_files
self.h5_file = self.h5_files[0]
self._i = 0

@property
def years(self):
Expand Down
30 changes: 14 additions & 16 deletions rex/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Classes to handle resource data
"""
import os
from abc import ABC
from abc import ABC, abstractmethod
from warnings import warn

import dateutil
Expand All @@ -17,6 +17,18 @@
from rex.utilities.utilities import check_tz, get_lat_lon_cols


class BaseDatasetIterable(ABC):
    """Base class for a file handler that is iterable over its datasets.

    Subclasses must implement the ``datasets`` property. Iterating over an
    instance is delegated to ``iter(self.datasets)``, so each call to
    ``__iter__`` builds a new iterator and no cursor is kept on the
    instance itself.
    """

    @property
    @abstractmethod
    def datasets(self):
        """iterable: Datasets available in file. """

    def __iter__(self):
        # Hand back a fresh iterator on every call so that separate (and
        # nested) loops over the same instance do not share position,
        # assuming ``datasets`` returns a re-iterable collection.
        return iter(self.datasets)


class ResourceDataset:
"""
h5py.Dataset wrapper for Resource .h5 files
Expand Down Expand Up @@ -583,7 +595,7 @@ def extract(cls, ds, ds_slice, scale_attr='scale_factor',
return dset[ds_slice]


class BaseResource(ABC):
class BaseResource(BaseDatasetIterable):
"""
Abstract Base class to handle resource .h5 files
"""
Expand Down Expand Up @@ -646,7 +658,6 @@ def __init__(self, h5_file, mode='r', unscale=True, str_decode=True,
self._shapes = None
self._chunks = None
self._dtypes = None
self._i = 0

def __repr__(self):
msg = "{} for {}".format(self.__class__.__name__, self.h5_file)
Expand Down Expand Up @@ -691,19 +702,6 @@ def __getitem__(self, keys):

return out

def __iter__(self):
return self

def __next__(self):
if self._i >= len(self.datasets):
self._i = 0
raise StopIteration

dset = self.datasets[self._i]
self._i += 1

return dset

def __contains__(self, dset):
return dset in self.datasets

Expand Down
21 changes: 2 additions & 19 deletions rex/resource_extraction/resource_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
WaveResource,
WindResource,
)
from rex.resource import Resource, ResourceDataset
from rex.resource import Resource, ResourceDataset, BaseDatasetIterable
from rex.temporal_stats.temporal_stats import TemporalStats
from rex.utilities.exceptions import ResourceValueError, ResourceWarning
from rex.utilities.execution import SpawnProcessPool
Expand All @@ -39,7 +39,7 @@
logger = logging.getLogger(__name__)


class ResourceX:
class ResourceX(BaseDatasetIterable):
"""
Resource data extraction tool
"""
Expand Down Expand Up @@ -88,7 +88,6 @@ def __init__(self, res_h5, res_cls=None, tree=None, unscale=True,
group=group, hsds=hsds, hsds_kwargs=hsds_kwargs)
self._dist_thresh = None
self._tree = tree
self._i = 0

def __repr__(self):
msg = "{} extractor for {}".format(self._res.__class__.__name__,
Expand All @@ -114,19 +113,6 @@ def __getitem__(self, keys):
def __contains__(self, dset):
return dset in self.datasets

def __iter__(self):
return self

def __next__(self):
if self._i >= len(self.datasets):
self._i = 0
raise StopIteration

dset = self.datasets[self._i]
self._i += 1

return dset

@property
def resource(self):
"""
Expand Down Expand Up @@ -1543,7 +1529,6 @@ def __init__(self, resource_path, res_cls=None, tree=None,
str_decode=str_decode, check_files=check_files)
self._dist_thresh = None
self._tree = tree
self._i = 0


class MultiYearResourceX(ResourceX):
Expand Down Expand Up @@ -1590,7 +1575,6 @@ def __init__(self, resource_path, years=None, tree=None, unscale=True,
hsds_kwargs=hsds_kwargs)
self._dist_thresh = None
self._tree = tree
self._i = 0

def get_means_map(self, ds_name, year=None, region=None,
region_col='state', max_workers=None,
Expand Down Expand Up @@ -1676,7 +1660,6 @@ def __init__(self, resource_path, tree=None, unscale=True,
hsds=hsds, hsds_kwargs=hsds_kwargs)
self._dist_thresh = None
self._tree = tree
self._i = 0


class SolarX(ResourceX):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def run(self):
with open("requirements.txt") as f:
install_requires = f.readlines()

test_requires = ["pytest>=5.2", ]
test_requires = ["pytest>=5.2", "pytest-timeout>=2.3.1"]
dev_requires = ["flake8", "pre-commit", "pylint", "hsds>=0.8.4"]
description = ("National Renewable Energy Laboratory's (NREL's) REsource "
"eXtraction tool: rex")
Expand Down
17 changes: 17 additions & 0 deletions tests/test_multi_res_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,20 @@ def test_preload_sam():
assert np.allclose(true, test)

mrr.close()


@pytest.mark.timeout(10)
def test_multi_res_resource_iterator():
    """
    Test the MultiResolutionResource iterator.

    The handler is iterated inside a second iteration over the same
    handler. An incorrect iterator implementation can cause an infinite
    loop here, hence the timeout marker on this test.
    """
    with tempfile.TemporaryDirectory() as td:
        fp_hr, fp_lr = make_multi_res_files(td)
        mrr = MultiResolutionResource(fp_hr, fp_lr, handler_class=WindResource)
        try:
            # Nested loops over the *same* object are the point of the test
            dsets_permutation = {(a, b) for a in mrr for b in mrr}
            num_dsets = len(mrr.datasets)
        finally:
            # Close the handler even if iteration raises
            mrr.close()

    # Every (outer, inner) dataset pairing must have been produced
    assert len(dsets_permutation) == num_dsets ** 2
19 changes: 17 additions & 2 deletions tests/test_multi_time_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import pytest

from rex import TESTDATADIR
from rex.multi_time_resource import (MultiTimeH5, MultiTimeNSRDB,
MultiTimeWindResource)
from rex.multi_time_resource import (MultiTimeH5, MultiTimeResource,
MultiTimeNSRDB, MultiTimeWindResource)
from rex.resource import Resource


Expand Down Expand Up @@ -391,6 +391,21 @@ def test_multi_time_resource_acts_like_resource_single_file():
assert np.allclose(res[ds], mt_res[ds])


@pytest.mark.timeout(10)
def test_mt_iterator():
    """
    Test the MultiTimeResource iterator.

    The handler is iterated inside a second iteration over the same
    handler. An incorrect iterator implementation can cause an infinite
    loop here, hence the timeout marker on this test.
    """
    path = os.path.join(TESTDATADIR, 'wtk/ri_100_wtk_*.h5')

    with MultiTimeResource(path) as res:
        # Nested loops over the *same* object are the point of the test
        dsets_permutation = {(a, b) for a in res for b in res}
        num_dsets = len(res.datasets)

    # Every (outer, inner) dataset pairing must have been produced
    assert len(dsets_permutation) == num_dsets ** 2


def execute_pytest(capture='all', flags='-rapP'):
"""Execute module as pytest with detailed summary report.
Expand Down
Loading

0 comments on commit 20e64c8

Please sign in to comment.