From ac8b3d2937770f99d88c769004d052530884f9a6 Mon Sep 17 00:00:00 2001 From: Paul Natsuo Kishimoto Date: Tue, 27 Sep 2016 12:21:56 -0400 Subject: [PATCH 1/4] Convert to xarray accessor - add new open_dataset() method - override xr.Dataset.__getitem__ - minimum changes to pass tests --- gdx/__init__.py | 184 ++++++++++++++++++++++++------------------- gdx/pycompat.py | 3 +- gdx/test/test_gdx.py | 47 ++++++----- 3 files changed, 127 insertions(+), 107 deletions(-) diff --git a/gdx/__init__.py b/gdx/__init__.py index 1ee50ed..3083870 100644 --- a/gdx/__init__.py +++ b/gdx/__init__.py @@ -3,29 +3,29 @@ unicode_literals) from itertools import cycle from logging import debug, info -# commented: for debugging +# # commented: for debugging # import logging # logging.basicConfig(level=logging.DEBUG) -import numpy +import numpy as np import pandas import xarray as xr - -from .pycompat import install_aliases, filter, raise_from, range, super, zip -install_aliases() +from xarray.core.utils import is_dict_like, hashable from .api import GDX, gdxcc, type_str, vartype_str +from .pycompat import install_aliases, filter, range, zip +install_aliases() -__version__ = '2' +__version__ = '4-dev' __all__ = [ - 'File', + 'open_dataset', ] -class File(xr.Dataset): +def open_dataset(filename, lazy=True, implicit=True, skip=set()): """Load the file at *filename* into memory. If *lazy* is ``True`` (default), then the data for GDX Parameters is not @@ -47,28 +47,61 @@ class File(xr.Dataset): otherwise, loading ``foo`` as declared raises :py:class:`MemoryError`. """ - # For the benefit of xr.Dataset.__getattr__ - _api = None - _index = [] - _state = {} - _alias = {} - _implicit = False - - def __init__(self, filename='', lazy=True, implicit=True, skip=set()): - """Constructor.""" - super(File, self).__init__() # Invoke Dataset constructor - - # load the GDX API + ds = xr.Dataset() + ds.gdx._initialize(filename, lazy, implicit, skip) + + return ds + + +# Override xarray.Dataset.__getitem__ to add GDX lazy-loading +def _dataset_getitem(self, key): + """DERP Access variables or coordinates this dataset as a + :py:class:`~xarray.DataArray`. + + Indexing with a list of names will return a new ``Dataset`` object. 
+ """ + if is_dict_like(key): + return self.isel(**key) + + # GDX lazy-loading + self.gdx._lazy_load(key) + + if hashable(key): + return self._construct_dataarray(key) + else: + return self._copy_listed(np.asarray(key)) + + +xr.Dataset.__getitem__ = _dataset_getitem + + +@xr.register_dataset_accessor('gdx') +class GDXAccessor(object): + def __init__(self, xarray_obj): + self._obj = xarray_obj + self._initialized = False + + def _lazy_load(self, key): + if not self._initialized: + return + keys = [key] if hashable(key) else key + + for k in keys: + if k in self._state and isinstance(self._state[k], dict): + debug('Lazy-loading {}'.format(k)) + self._load_symbol_data(k) + + def _initialize(self, filename, lazy, implicit, skip): self._api = GDX() self._api.open_read(str(filename)) # Basic information about the GDX file v, p = self._api.file_version() sc, ec = self._api.system_info() - self.attrs['version'] = v.strip() - self.attrs['producer'] = p.strip() - self.attrs['symbol_count'] = sc - self.attrs['element_count'] = ec + self._obj.attrs['version'] = v.strip() + self._obj.attrs['producer'] = p.strip() + self._obj.attrs['symbol_count'] = sc + self._obj.attrs['element_count'] = ec # Initialize private variables self._index = [None for _ in range(sc + 1)] @@ -87,6 +120,8 @@ def __init__(self, filename='', lazy=True, implicit=True, skip=set()): if name not in skip: self._load_symbol_data(name) + self._initialized = True + def _load_symbol(self, index): """Load the *index*-th Symbol in the GDX file.""" # Load basic information @@ -129,11 +164,12 @@ def _load_symbol(self, index): elif type_code == gdxcc.GMS_DT_ALIAS: parent = desc.replace('Aliased with ', '') self._alias[name] = parent - assert self[parent].attrs['_gdx_type_code'] == gdxcc.GMS_DT_SET + assert (self._obj[parent].attrs['_gdx_type_code'] == + gdxcc.GMS_DT_SET) # Duplicate the variable - self._variables[name] = self._variables[parent] + self._obj._variables[name] = self._obj._variables[parent] self._state[name] = True - super(File, self).set_coords(name, inplace=True) + self._obj.set_coords(name, inplace=True) return name, type_code # The Symbol is either a Set, Parameter or Variable @@ -233,7 +269,7 @@ def _infer_domain(self, name, domain, elements): debug('guessing a better domain for {}: {}'.format(name, domain)) # Domain as a list of references to Variables in the File/xr.Dataset - domain_ = [self[d] for d in domain] + domain_ = [self._obj[d] for d in domain] for i, d in enumerate(domain_): # Iterate over dimensions e = set(elements[i]) @@ -246,13 +282,13 @@ def _infer_domain(self, name, domain, elements): d = '_{}_{}'.format(name, i) debug(('Constructing implicit set {} for dimension {} of {}\n' ' {} instead of {} elements') - .format(d, name, i, len(e), len(self['*']))) - self.coords[d] = elements[i] - d = self[d] + .format(d, name, i, len(e), len(self._obj['*']))) + self._obj.coords[d] = elements[i] + d = self._obj[d] else: # try to find a smaller domain for this dimension # Iterate over every Set/Coordinate - for s in self.coords.values(): + for s in self._obj.coords.values(): if s.ndim == 1 and set(s.values).issuperset(e) and \ len(s) < len(d): d = s # Found a smaller Set; use this instead @@ -272,7 +308,7 @@ def _infer_domain(self, name, domain, elements): def _root_dim(self, dim): """Return the ultimate ancestor of the 1-D Set *dim*.""" - parent = self[dim].dims[0] + parent = self._obj[dim].dims[0] return dim if parent == dim else self._root_dim(parent) def _empty(self, *dims, **kwargs): @@ -280,11 +316,11 @@ def 
_empty(self, *dims, **kwargs): size = [] dtypes = [] for d in dims: - size.append(len(self[d])) - dtypes.append(self[d].dtype) - dtype = kwargs.pop('dtype', numpy.result_type(*dtypes)) + size.append(len(self._obj[d])) + dtypes.append(self._obj[d].dtype) + dtype = kwargs.pop('dtype', np.result_type(*dtypes)) fv = kwargs.pop('fill_value') - return numpy.full(size, fill_value=fv, dtype=dtype) + return np.full(size, fill_value=fv, dtype=dtype) def _add_symbol(self, name, dim, domain, attrs): """Add a xray.DataArray with the data from Symbol *name*.""" @@ -300,14 +336,13 @@ def _add_symbol(self, name, dim, domain, attrs): kwargs = {} # Arguments to xr.Dataset.__setitem__() if dim == 0: # 0-D Variable or scalar Parameter - super(File, self).__setitem__(name, ([], data.popitem()[1], - gdx_attrs)) + self._obj.__setitem__(name, ([], data.popitem()[1], gdx_attrs)) return elif attrs['type_code'] == gdxcc.GMS_DT_SET: # GAMS Set if dim == 1: # One-dimensional Set - self.coords[name] = elements[0] - self.coords[name].attrs = gdx_attrs + self._obj.coords[name] = elements[0] + self._obj.coords[name].attrs = gdx_attrs else: # Multi-dimensional Sets are mappings indexed by other Sets; # elements are either 'on'/True or 'off'/False @@ -319,47 +354,48 @@ def _add_symbol(self, name, dim, domain, attrs): dims = [self._root_dim(d) for d in domain] # Update coords - self.coords.__setitem__(name, (dims, self._empty(*domain, - **kwargs), - gdx_attrs)) + self._obj.coords.__setitem__(name, (dims, self._empty(*domain, + **kwargs), gdx_attrs)) # Store the elements for k in data.keys(): - self[name].loc[k] = k if dim == 1 else True + self._obj[name].loc[k] = k if dim == 1 else True else: # 1+-dimensional GAMS Parameters kwargs['dtype'] = float - kwargs['fill_value'] = numpy.nan + kwargs['fill_value'] = np.nan dims = [self._root_dim(d) for d in domain] # Same as above # Create an empty xr.DataArray; this ensures that the data # read in below has the proper form and indices - super(File, self).__setitem__(name, (dims, self._empty(*domain, - **kwargs), - gdx_attrs)) + self._obj.__setitem__(name, (dims, self._empty(*domain, **kwargs), + gdx_attrs)) # Fill in extra keys - longest = numpy.argmax(self[name].values.shape) + longest = np.argmax(self._obj[name].values.shape) iters = [] for i, d in enumerate(dims): if i == longest: - iters.append(self[d].to_index()) + iters.append(self._obj[d].to_index()) else: - iters.append(cycle(self[d].to_index())) - data.update({k: numpy.nan for k in set(zip(*iters)) - + iters.append(cycle(self._obj[d].to_index())) + data.update({k: np.nan for k in set(zip(*iters)) - set(data.keys())}) # Use pandas and xarray IO methods to convert data, a dict, to a # xr.DataArray of the correct shape, then extract its values tmp = pandas.Series(data) tmp.index.names = dims - tmp = xr.DataArray.from_series(tmp).reindex_like(self[name]) - self[name].values = tmp.values + tmp = xr.DataArray.from_series(tmp).reindex_like(self._obj[name]) + self._obj[name].values = tmp.values def dealias(self, name): """Identify the GDX Symbol that *name* refers to, and return the corresponding :py:class:`xarray.DataArray`.""" - return self[self._alias[name]] if name in self._alias else self[name] + if name in self._alias: + return self._obj[self._alias[name]] + else: + return self._obj[name] def extract(self, name): """Extract the GAMS Symbol *name* from the dataset. @@ -374,7 +410,7 @@ def extract(self, name): dimensions), which does not make reference to the :class:`File`. 
""" # Copy the Symbol, triggering lazy-loading if needed - result = self[name].copy() + result = self._obj[name].copy() # Declared dimensions of the Symbol, and their parents try: @@ -397,7 +433,7 @@ def extract(self, name): # Dimension is indexed by 'p', but declared 'c'. First drop # the elements which do not appear in the sub-Set c;, then # rename 'p' to 'c' - drop = set(self[p].values) - set(self[c].values) - set('') + drop = set(self._obj[p].values) - set(self._obj[c].values) result = result.drop(drop, dim=p).swap_dims({p: c}) # Add the old coord to the set of coords to drop drop_coords.add(p) @@ -412,14 +448,14 @@ def info(self, name): attrs['type_str'], name, ','.join(attrs['domain']), attrs['records'], attrs['description']) else: - return repr(self[name]) + return repr(self._obj[name]) def _loaded_and_cached(self, type_code): """Return a list of loaded and not-loaded Symbols of *type_code*.""" names = set() for name, state in self._state.items(): if state is True: - tc = self._variables[name].attrs['_gdx_type_code'] + tc = self._obj._variables[name].attrs['_gdx_type_code'] elif isinstance(state, dict): tc = state['attrs']['type_code'] else: # pragma: no cover @@ -437,19 +473,19 @@ def set(self, name, as_dict=False): :func:`set()` returns the elements without these placeholders. """ - assert self[name].attrs['_gdx_type_code'] == gdxcc.GMS_DT_SET, \ + assert self._obj[name].attrs['_gdx_type_code'] == gdxcc.GMS_DT_SET, \ 'Variable {} is not a GAMS Set'.format(name) - if len(self[name].dims) > 1: - return self[name] + if len(self._obj[name].dims) > 1: + return self._obj[name] elif as_dict: from collections import OrderedDict result = OrderedDict() - parent = self[name].attrs['_gdx_domain'][0] - for label in self[parent].values: - result[label] = label in self[name].values + parent = self._obj[name].attrs['_gdx_domain'][0] + for label in self._obj[parent].values: + result._obj[label] = label in self[name].values return result else: - return list(self[name].values) + return list(self._obj[name].values) def sets(self): """Return a list of all GDX Sets.""" @@ -462,16 +498,4 @@ def parameters(self): def get_symbol_by_index(self, index): """Retrieve the GAMS Symbol from the *index*-th position of the :class:`File`.""" - return self[self._index[index]] - - def __getitem__(self, key): - """Set element access.""" - try: - return super(File, self).__getitem__(key) - except KeyError as e: - if isinstance(self._state[key], dict): - debug('Lazy-loading {}'.format(key)) - self._load_symbol_data(key) - return super(File, self).__getitem__(key) - else: - raise raise_from(KeyError(key), e) + return self._obj[self._index[index]] diff --git a/gdx/pycompat.py b/gdx/pycompat.py index 601acfa..7fad6e8 100644 --- a/gdx/pycompat.py +++ b/gdx/pycompat.py @@ -1,8 +1,7 @@ import sys -from builtins import filter, range, object, super, zip +from builtins import filter, range, object, zip from future.standard_library import install_aliases -from future.utils import raise_from PY3 = sys.version_info[0] >= 3 diff --git a/gdx/test/test_gdx.py b/gdx/test/test_gdx.py index c14e25c..4617d2b 100644 --- a/gdx/test/test_gdx.py +++ b/gdx/test/test_gdx.py @@ -35,13 +35,13 @@ def finalize(): @pytest.fixture(scope='class') def gdxfile(rawgdx): """A gdx.File fixture.""" - return gdx.File(rawgdx) + return gdx.open_dataset(rawgdx) @pytest.fixture(scope='class') def gdxfile_explicit(rawgdx): """A gdx.File fixture, instantiated with implicit=False.""" - return gdx.File(rawgdx, implicit=False) + return gdx.open_dataset(rawgdx, 
implicit=False) @pytest.fixture(scope='session') @@ -127,34 +127,29 @@ def test_bad_method(self): class TestFile: def test_init(self, rawgdx): - gdx.File(rawgdx) - gdx.File(rawgdx, lazy=False) + gdx.open_dataset(rawgdx) + gdx.open_dataset(rawgdx, lazy=False) with pytest.raises(FileNotFoundError): - gdx.File('nonexistent.gdx') + gdx.open_dataset('nonexistent.gdx') def test_num_parameters(self, gdxfile, actual): - print(gdxfile.parameters()) - assert len(gdxfile.parameters()) == len(actual.data_vars) + assert len(gdxfile.gdx.parameters()) == len(actual.data_vars) def test_num_sets(self, gdxfile, actual): - assert len(gdxfile.sets()) == len(actual.coords) + assert len(gdxfile.gdx.sets()) == len(actual.coords) def test_get_symbol(self, gdxfile): gdxfile['s'] + gdxfile['p1'] def test_get_symbol_by_index(self, gdxfile, actual): for name in actual: - sym = gdxfile.get_symbol_by_index(actual[name].attrs['_gdx_index']) + sym = gdxfile.gdx.get_symbol_by_index( + actual[name].attrs['_gdx_index']) assert sym.name == name # Giving too high an index results in IndexError with pytest.raises(IndexError): - gdxfile.get_symbol_by_index(gdxfile.attrs['symbol_count'] + 1) - - def test_getattr(self, gdxfile, actual): - for name in actual: - getattr(gdxfile, name) - with pytest.raises(AttributeError): - gdxfile.notasymbolname + gdxfile.gdx.get_symbol_by_index(gdxfile.attrs['symbol_count'] + 1) def test_getitem(self, gdxfile, actual): for name in actual: @@ -165,15 +160,17 @@ def test_getitem(self, gdxfile, actual): gdxfile['e1'] def test_info1(self, gdxfile): - assert gdxfile.info('s1').startswith("") + assert gdxfile.gdx.info('s1') \ + .startswith("") def test_info2(self, rawgdx): # Use a File where p1 is guaranteed to not have been loaded: - assert (gdx.File(rawgdx).info('p1') == 'unknown parameter p1(s), 1 ' - 'records: Example parameter with animal data') + assert (gdx.open_dataset(rawgdx).gdx.info('p1') == + 'unknown parameter p1(s), 1 records: Example parameter with ' + 'animal data') def test_dealias(self, gdxfile): - assert gdxfile.dealias('s_').equals(gdxfile['s']) + assert gdxfile.gdx.dealias('s_').equals(gdxfile['s']) def test_domain(self, gdxfile, actual): assert gdxfile['p6'].dims == actual['p6'].dims @@ -181,12 +178,12 @@ def test_domain(self, gdxfile, actual): def test_extract(self, gdxfile, gdxfile_explicit, actual): # TODO add p5, p7 for name in ['p1', 'p2', 'p3', 'p4', 'p6']: - assert gdxfile.extract(name).equals(actual[name]) + assert gdxfile.gdx.extract(name).equals(actual[name]) - gdxfile_explicit.extract('p5') + gdxfile_explicit.gdx.extract('p5') with pytest.raises(KeyError): - gdxfile.extract('notasymbolname') + gdxfile.gdx.extract('notasymbolname') def test_implicit(self, gdxfile): assert gdxfile['p7'].shape == (3, 3) @@ -195,8 +192,8 @@ def test_implicit(self, gdxfile): class TestSet: def test_len(self, gdxfile, actual): assert len(gdxfile.s) == len(actual['s']) - assert len(gdxfile.set('s1')) == len(actual['s1']) - assert len(gdxfile.set('s2')) == len(actual['s2']) + assert len(gdxfile.gdx.set('s1')) == len(actual['s1']) + assert len(gdxfile.gdx.set('s2')) == len(actual['s2']) def test_getitem(self, gdxfile): for i in range(len(gdxfile.s)): From 5109ffe2df09cbfbba928d55f95afe23a6a1d318 Mon Sep 17 00:00:00 2001 From: Paul Natsuo Kishimoto Date: Tue, 27 Sep 2016 12:29:33 -0400 Subject: [PATCH 2/4] reorder methods of GDXAccessor in alpha order --- gdx/__init__.py | 390 ++++++++++++++++++++++++------------------------ 1 file changed, 195 insertions(+), 195 deletions(-) diff --git 
a/gdx/__init__.py b/gdx/__init__.py index 3083870..1c63677 100644 --- a/gdx/__init__.py +++ b/gdx/__init__.py @@ -55,7 +55,7 @@ def open_dataset(filename, lazy=True, implicit=True, skip=set()): # Override xarray.Dataset.__getitem__ to add GDX lazy-loading def _dataset_getitem(self, key): - """DERP Access variables or coordinates this dataset as a + """Access variables or coordinates this dataset as a :py:class:`~xarray.DataArray`. Indexing with a list of names will return a new ``Dataset`` object. @@ -81,15 +81,176 @@ def __init__(self, xarray_obj): self._obj = xarray_obj self._initialized = False - def _lazy_load(self, key): - if not self._initialized: + def _add_symbol(self, name, dim, domain, attrs): + """Add a xray.DataArray with the data from Symbol *name*.""" + # Transform the attrs for storage, unpack data + gdx_attrs = {'_gdx_{}'.format(k): v for k, v in attrs.items()} + data = self._state[name]['data'] + elements = self._state[name]['elements'] + + # Erase the cache; this also prevents __getitem__ from triggering lazy- + # loading, which is still in progress + self._state[name] = True + + kwargs = {} # Arguments to xr.Dataset.__setitem__() + if dim == 0: + # 0-D Variable or scalar Parameter + self._obj.__setitem__(name, ([], data.popitem()[1], gdx_attrs)) return - keys = [key] if hashable(key) else key + elif attrs['type_code'] == gdxcc.GMS_DT_SET: # GAMS Set + if dim == 1: + # One-dimensional Set + self._obj.coords[name] = elements[0] + self._obj.coords[name].attrs = gdx_attrs + else: + # Multi-dimensional Sets are mappings indexed by other Sets; + # elements are either 'on'/True or 'off'/False + kwargs['dtype'] = bool + kwargs['fill_value'] = False - for k in keys: - if k in self._state and isinstance(self._state[k], dict): - debug('Lazy-loading {}'.format(k)) - self._load_symbol_data(k) + # Don't define over the actual domain dimensions, but over the + # parent Set/xr.Coordinates for each dimension + dims = [self._root_dim(d) for d in domain] + + # Update coords + self._obj.coords.__setitem__(name, (dims, self._empty(*domain, + **kwargs), gdx_attrs)) + + # Store the elements + for k in data.keys(): + self._obj[name].loc[k] = k if dim == 1 else True + else: # 1+-dimensional GAMS Parameters + kwargs['dtype'] = float + kwargs['fill_value'] = np.nan + + dims = [self._root_dim(d) for d in domain] # Same as above + + # Create an empty xr.DataArray; this ensures that the data + # read in below has the proper form and indices + self._obj.__setitem__(name, (dims, self._empty(*domain, **kwargs), + gdx_attrs)) + + # Fill in extra keys + longest = np.argmax(self._obj[name].values.shape) + iters = [] + for i, d in enumerate(dims): + if i == longest: + iters.append(self._obj[d].to_index()) + else: + iters.append(cycle(self._obj[d].to_index())) + data.update({k: np.nan for k in set(zip(*iters)) - + set(data.keys())}) + + # Use pandas and xarray IO methods to convert data, a dict, to a + # xr.DataArray of the correct shape, then extract its values + tmp = pandas.Series(data) + tmp.index.names = dims + tmp = xr.DataArray.from_series(tmp).reindex_like(self._obj[name]) + self._obj[name].values = tmp.values + + def _cache_data(self, name, index, dim, records): + """Read data for the Symbol *name* from the GDX file.""" + # Initiate the data read. 
The API method returns a number of records, + # which should match that given by gdxSymbolInfoX in _load_symbol() + records2 = self._api.data_read_str_start(index) + assert records == records2, \ + ('{}: gdxSymbolInfoX ({}) and gdxDataReadStrStart ({}) disagree on' + ' number of records.').format(name, records, records2) + + # Indices of data records, one list per dimension + elements = [list() for _ in range(dim)] + # Data points. Keys are index tuples, values are data. For a 1-D Set, + # the data is the GDX 'string number' of the text associated with the + # element + data = {} + try: + while True: # Loop over all records + labels, value, _ = self._api.data_read_str() # Next record + # Update elements with the indices + for j, label in enumerate(labels): + if label not in elements[j]: + elements[j].append(label) + # Convert a 1-D index from a tuple to a bare string + key = labels[0] if dim == 1 else tuple(labels) + # The value is a sequence, containing the level, marginal, + # lower & upper bounds, etc. Store only the value (first + # element). + data[key] = value[gdxcc.GMS_VAL_LEVEL] + except Exception: + if len(data) == records: + pass # All data has been read + else: # pragma: no cover + raise # Some other read error + + # Cache the read data + self._state[name].update({ + 'data': data, + 'elements': elements, + }) + + def _empty(self, *dims, **kwargs): + """Return an empty numpy.ndarray for a GAMS Set or Parameter.""" + size = [] + dtypes = [] + for d in dims: + size.append(len(self._obj[d])) + dtypes.append(self._obj[d].dtype) + dtype = kwargs.pop('dtype', np.result_type(*dtypes)) + fv = kwargs.pop('fill_value') + return np.full(size, fill_value=fv, dtype=dtype) + + def _infer_domain(self, name, domain, elements): + """Infer the domain of the Symbol *name*. + + Lazy GAMS modellers may create variables like myvar(*,*,*,*). If the + size of the universal set * is large, then attempting to instantiate a + xr.DataArray with this many elements may cause a MemoryError. For every + dimension of *name* defined on the domain '*' this method tries to find + a Set from the file which contains all the labels appearing in *name*'s + data. 
+ + """ + if '*' not in domain: + return domain + debug('guessing a better domain for {}: {}'.format(name, domain)) + + # Domain as a list of references to Variables in the File/xr.Dataset + domain_ = [self._obj[d] for d in domain] + + for i, d in enumerate(domain_): # Iterate over dimensions + e = set(elements[i]) + if d.name != '*' or len(e) == 0: # pragma: no cover + assert set(d.values).issuperset(e) + continue # The stated domain matches the data; or no data + # '*' is given + if (self._state[name]['attrs']['type_code'] == gdxcc.GMS_DT_PAR and + self._implicit): + d = '_{}_{}'.format(name, i) + debug(('Constructing implicit set {} for dimension {} of {}\n' + ' {} instead of {} elements') + .format(d, name, i, len(e), len(self._obj['*']))) + self._obj.coords[d] = elements[i] + d = self._obj[d] + else: + # try to find a smaller domain for this dimension + # Iterate over every Set/Coordinate + for s in self._obj.coords.values(): + if s.ndim == 1 and set(s.values).issuperset(e) and \ + len(s) < len(d): + d = s # Found a smaller Set; use this instead + domain_[i] = d + + # Convert the references to names + inferred = [d.name for d in domain_] + + if domain != inferred: + # Store the result + self._state[name]['attrs']['domain_inferred'] = inferred + debug('…inferred {}.'.format(inferred)) + else: + debug('…failed.') + + return inferred def _initialize(self, filename, lazy, implicit, skip): self._api = GDX() @@ -122,6 +283,16 @@ def _initialize(self, filename, lazy, implicit, skip): self._initialized = True + def _lazy_load(self, key): + if not self._initialized: + return + keys = [key] if hashable(key) else key + + for k in keys: + if k in self._state and isinstance(self._state[k], dict): + debug('Lazy-loading {}'.format(k)) + self._load_symbol_data(k) + def _load_symbol(self, index): """Load the *index*-th Symbol in the GDX file.""" # Load basic information @@ -213,182 +384,25 @@ def _load_symbol_data(self, name): # Create an xr.DataArray with the Symbol's data self._add_symbol(name, dim, domain, attrs) - def _cache_data(self, name, index, dim, records): - """Read data for the Symbol *name* from the GDX file.""" - # Initiate the data read. The API method returns a number of records, - # which should match that given by gdxSymbolInfoX in _load_symbol() - records2 = self._api.data_read_str_start(index) - assert records == records2, \ - ('{}: gdxSymbolInfoX ({}) and gdxDataReadStrStart ({}) disagree on' - ' number of records.').format(name, records, records2) - - # Indices of data records, one list per dimension - elements = [list() for _ in range(dim)] - # Data points. Keys are index tuples, values are data. For a 1-D Set, - # the data is the GDX 'string number' of the text associated with the - # element - data = {} - try: - while True: # Loop over all records - labels, value, _ = self._api.data_read_str() # Next record - # Update elements with the indices - for j, label in enumerate(labels): - if label not in elements[j]: - elements[j].append(label) - # Convert a 1-D index from a tuple to a bare string - key = labels[0] if dim == 1 else tuple(labels) - # The value is a sequence, containing the level, marginal, - # lower & upper bounds, etc. Store only the value (first - # element). 
- data[key] = value[gdxcc.GMS_VAL_LEVEL] - except Exception: - if len(data) == records: - pass # All data has been read + def _loaded_and_cached(self, type_code): + """Return a list of loaded and not-loaded Symbols of *type_code*.""" + names = set() + for name, state in self._state.items(): + if state is True: + tc = self._obj._variables[name].attrs['_gdx_type_code'] + elif isinstance(state, dict): + tc = state['attrs']['type_code'] else: # pragma: no cover - raise # Some other read error - - # Cache the read data - self._state[name].update({ - 'data': data, - 'elements': elements, - }) - - def _infer_domain(self, name, domain, elements): - """Infer the domain of the Symbol *name*. - - Lazy GAMS modellers may create variables like myvar(*,*,*,*). If the - size of the universal set * is large, then attempting to instantiate a - xr.DataArray with this many elements may cause a MemoryError. For every - dimension of *name* defined on the domain '*' this method tries to find - a Set from the file which contains all the labels appearing in *name*'s - data. - - """ - if '*' not in domain: - return domain - debug('guessing a better domain for {}: {}'.format(name, domain)) - - # Domain as a list of references to Variables in the File/xr.Dataset - domain_ = [self._obj[d] for d in domain] - - for i, d in enumerate(domain_): # Iterate over dimensions - e = set(elements[i]) - if d.name != '*' or len(e) == 0: # pragma: no cover - assert set(d.values).issuperset(e) - continue # The stated domain matches the data; or no data - # '*' is given - if (self._state[name]['attrs']['type_code'] == gdxcc.GMS_DT_PAR and - self._implicit): - d = '_{}_{}'.format(name, i) - debug(('Constructing implicit set {} for dimension {} of {}\n' - ' {} instead of {} elements') - .format(d, name, i, len(e), len(self._obj['*']))) - self._obj.coords[d] = elements[i] - d = self._obj[d] - else: - # try to find a smaller domain for this dimension - # Iterate over every Set/Coordinate - for s in self._obj.coords.values(): - if s.ndim == 1 and set(s.values).issuperset(e) and \ - len(s) < len(d): - d = s # Found a smaller Set; use this instead - domain_[i] = d - - # Convert the references to names - inferred = [d.name for d in domain_] - - if domain != inferred: - # Store the result - self._state[name]['attrs']['domain_inferred'] = inferred - debug('…inferred {}.'.format(inferred)) - else: - debug('…failed.') - - return inferred + continue + if tc == type_code: + names.add(name) + return names def _root_dim(self, dim): """Return the ultimate ancestor of the 1-D Set *dim*.""" parent = self._obj[dim].dims[0] return dim if parent == dim else self._root_dim(parent) - def _empty(self, *dims, **kwargs): - """Return an empty numpy.ndarray for a GAMS Set or Parameter.""" - size = [] - dtypes = [] - for d in dims: - size.append(len(self._obj[d])) - dtypes.append(self._obj[d].dtype) - dtype = kwargs.pop('dtype', np.result_type(*dtypes)) - fv = kwargs.pop('fill_value') - return np.full(size, fill_value=fv, dtype=dtype) - - def _add_symbol(self, name, dim, domain, attrs): - """Add a xray.DataArray with the data from Symbol *name*.""" - # Transform the attrs for storage, unpack data - gdx_attrs = {'_gdx_{}'.format(k): v for k, v in attrs.items()} - data = self._state[name]['data'] - elements = self._state[name]['elements'] - - # Erase the cache; this also prevents __getitem__ from triggering lazy- - # loading, which is still in progress - self._state[name] = True - - kwargs = {} # Arguments to xr.Dataset.__setitem__() - if dim == 0: - # 0-D Variable 
or scalar Parameter - self._obj.__setitem__(name, ([], data.popitem()[1], gdx_attrs)) - return - elif attrs['type_code'] == gdxcc.GMS_DT_SET: # GAMS Set - if dim == 1: - # One-dimensional Set - self._obj.coords[name] = elements[0] - self._obj.coords[name].attrs = gdx_attrs - else: - # Multi-dimensional Sets are mappings indexed by other Sets; - # elements are either 'on'/True or 'off'/False - kwargs['dtype'] = bool - kwargs['fill_value'] = False - - # Don't define over the actual domain dimensions, but over the - # parent Set/xr.Coordinates for each dimension - dims = [self._root_dim(d) for d in domain] - - # Update coords - self._obj.coords.__setitem__(name, (dims, self._empty(*domain, - **kwargs), gdx_attrs)) - - # Store the elements - for k in data.keys(): - self._obj[name].loc[k] = k if dim == 1 else True - else: # 1+-dimensional GAMS Parameters - kwargs['dtype'] = float - kwargs['fill_value'] = np.nan - - dims = [self._root_dim(d) for d in domain] # Same as above - - # Create an empty xr.DataArray; this ensures that the data - # read in below has the proper form and indices - self._obj.__setitem__(name, (dims, self._empty(*domain, **kwargs), - gdx_attrs)) - - # Fill in extra keys - longest = np.argmax(self._obj[name].values.shape) - iters = [] - for i, d in enumerate(dims): - if i == longest: - iters.append(self._obj[d].to_index()) - else: - iters.append(cycle(self._obj[d].to_index())) - data.update({k: np.nan for k in set(zip(*iters)) - - set(data.keys())}) - - # Use pandas and xarray IO methods to convert data, a dict, to a - # xr.DataArray of the correct shape, then extract its values - tmp = pandas.Series(data) - tmp.index.names = dims - tmp = xr.DataArray.from_series(tmp).reindex_like(self._obj[name]) - self._obj[name].values = tmp.values - def dealias(self, name): """Identify the GDX Symbol that *name* refers to, and return the corresponding :py:class:`xarray.DataArray`.""" @@ -450,19 +464,10 @@ def info(self, name): else: return repr(self._obj[name]) - def _loaded_and_cached(self, type_code): - """Return a list of loaded and not-loaded Symbols of *type_code*.""" - names = set() - for name, state in self._state.items(): - if state is True: - tc = self._obj._variables[name].attrs['_gdx_type_code'] - elif isinstance(state, dict): - tc = state['attrs']['type_code'] - else: # pragma: no cover - continue - if tc == type_code: - names.add(name) - return names + def get_symbol_by_index(self, index): + """Retrieve the GAMS Symbol from the *index*-th position of the + :class:`File`.""" + return self._obj[self._index[index]] def set(self, name, as_dict=False): """Return the elements of GAMS Set *name*. @@ -494,8 +499,3 @@ def sets(self): def parameters(self): """Return a list of all GDX Parameters.""" return self._loaded_and_cached(gdxcc.GMS_DT_PAR) - - def get_symbol_by_index(self, index): - """Retrieve the GAMS Symbol from the *index*-th position of the - :class:`File`.""" - return self._obj[self._index[index]] From 5245281d52b7a2c7e71799fd905d6e456624e304 Mon Sep 17 00:00:00 2001 From: Paul Natsuo Kishimoto Date: Tue, 27 Sep 2016 12:36:14 -0400 Subject: [PATCH 3/4] simplify override of xarray.Dataset.__getitem__ This approach uses the existing code, instead of repeating it with modifications (which would make maintenance harder). 
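
For illustration only (the file name is hypothetical; the symbol names
follow the test suite), delegating to the stock __getitem__ means every
key type xarray already supports keeps working, with the GDX hook simply
running first:

    import gdx

    ds = gdx.open_dataset('tests.gdx')
    ds['p1']          # hashable key -> DataArray; lazy-loads 'p1'
    ds[['p1', 'p2']]  # list of names -> new Dataset; lazy-loads both
    ds[{'s': 0}]      # dict-like key -> skipped by _lazy_load, handled
                      # by the base method via Dataset.isel()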
--- gdx/__init__.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/gdx/__init__.py b/gdx/__init__.py index 1c63677..fbc2891 100644 --- a/gdx/__init__.py +++ b/gdx/__init__.py @@ -55,24 +55,12 @@ def open_dataset(filename, lazy=True, implicit=True, skip=set()): # Override xarray.Dataset.__getitem__ to add GDX lazy-loading def _dataset_getitem(self, key): - """Access variables or coordinates this dataset as a - :py:class:`~xarray.DataArray`. - - Indexing with a list of names will return a new ``Dataset`` object. - """ - if is_dict_like(key): - return self.isel(**key) - - # GDX lazy-loading + """Override :py:class:`~xarray.Dataset` __getitem__ method.""" self.gdx._lazy_load(key) + return self._base_getitem(key) - if hashable(key): - return self._construct_dataarray(key) - else: - return self._copy_listed(np.asarray(key)) - - -xr.Dataset.__getitem__ = _dataset_getitem +setattr(xr.Dataset, '_base_getitem', xr.Dataset.__getitem__) +setattr(xr.Dataset, '__getitem__', _dataset_getitem) @xr.register_dataset_accessor('gdx') @@ -284,7 +272,7 @@ def _initialize(self, filename, lazy, implicit, skip): self._initialized = True def _lazy_load(self, key): - if not self._initialized: + if not self._initialized or is_dict_like(key): return keys = [key] if hashable(key) else key From cd3ce0906cc8b91463b10f227233cd2071ca00ac Mon Sep 17 00:00:00 2001 From: Paul Natsuo Kishimoto Date: Tue, 27 Sep 2016 12:42:27 -0400 Subject: [PATCH 4/4] override xarray.Dataset.__contains__ + test --- gdx/__init__.py | 9 ++++++++- gdx/test/test_gdx.py | 3 +++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/gdx/__init__.py b/gdx/__init__.py index fbc2891..65dcea6 100644 --- a/gdx/__init__.py +++ b/gdx/__init__.py @@ -63,11 +63,19 @@ def _dataset_getitem(self, key): setattr(xr.Dataset, '__getitem__', _dataset_getitem) +def _dataset_contains(self, key): + return self._base_contains(key) or key in self.gdx._state + +setattr(xr.Dataset, '_base_contains', xr.Dataset.__contains__) +setattr(xr.Dataset, '__contains__', _dataset_contains) + + @xr.register_dataset_accessor('gdx') class GDXAccessor(object): def __init__(self, xarray_obj): self._obj = xarray_obj self._initialized = False + self._state = {} def _add_symbol(self, name, dim, domain, attrs): """Add a xray.DataArray with the data from Symbol *name*.""" @@ -254,7 +262,6 @@ def _initialize(self, filename, lazy, implicit, skip): # Initialize private variables self._index = [None for _ in range(sc + 1)] - self._state = {} self._alias = {} self._implicit = implicit diff --git a/gdx/test/test_gdx.py b/gdx/test/test_gdx.py index 4617d2b..e4e54b0 100644 --- a/gdx/test/test_gdx.py +++ b/gdx/test/test_gdx.py @@ -132,6 +132,9 @@ def test_init(self, rawgdx): with pytest.raises(FileNotFoundError): gdx.open_dataset('nonexistent.gdx') + def test_contains(self, rawgdx): + assert 'p5' in gdx.open_dataset(rawgdx) + def test_num_parameters(self, gdxfile, actual): assert len(gdxfile.gdx.parameters()) == len(actual.data_vars)
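
A rough end-to-end sketch of the API after this series (illustrative
only; the file name is hypothetical and the symbol names follow the test
fixtures): the gdx.File subclass of Dataset is replaced by open_dataset()
returning a plain xarray.Dataset, with GDX behaviour provided by the
'gdx' accessor plus the two overridden dunders.

    import gdx

    ds = gdx.open_dataset('tests.gdx')  # plain xarray.Dataset; Parameter
                                        # data is not read yet when lazy=True
    'p5' in ds                          # True via __contains__ (PATCH 4/4),
                                        # even before the data is loaded
    p1 = ds['p1']                       # __getitem__ override triggers
                                        # lazy-loading (PATCH 1/4, 3/4)
    ds.gdx.sets()                       # former File methods now live on
    ds.gdx.parameters()                 # the accessor
    ds.gdx.info('p1')
    ds.gdx.extract('p1')                # standalone DataArray on the
                                        # declared dimensions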