diff --git a/data/tests.gms b/data/tests.gms index 06fa2cf..49fc46e 100644 --- a/data/tests.gms +++ b/data/tests.gms @@ -40,14 +40,15 @@ sets ; parameters - p1(s) 'Example parameter with animal data' / / - p2(t) 'Example parameter with color data' / set.t 0 / - p3(s,t) 'Two-dimensional parameter' / set.s.y 1 / - p4(s1) 'Parameter defined over a subset' / set.s1 1 / - p5(*) 'Empty parameter defined over the universal set' + p1(s) 'Example parameter with animal data' / / + p2(t) 'Example parameter with color data' / set.t 0.1 / + p3(s,t) 'Two-dimensional parameter' / set.s.y 1 / + p4(s1) 'Parameter defined over a subset' / set.s1 1 / + p5(*) 'Empty parameter defined over the universal set' + p6(s,s1,t) 'Parameter defined over a set and its subset' / / ; -parameter p6(*,*) 'Parameter defined over the universal set' / +parameter p7(*,*) 'Parameter defined over the universal set' / a.o 1 r.US 2 CA.b 3 diff --git a/gdx/test/test_gdx.py b/gdx/test/test_gdx.py index 807aa84..c14e25c 100644 --- a/gdx/test/test_gdx.py +++ b/gdx/test/test_gdx.py @@ -1,7 +1,6 @@ -from collections import OrderedDict - import numpy as np import pytest +import xarray as xr import gdx from gdx.pycompat import FileNotFoundError @@ -45,35 +44,73 @@ def gdxfile_explicit(rawgdx): return gdx.File(rawgdx, implicit=False) -actual = OrderedDict([ - ('*', None), - ('pi', 3.14), - ('s', ['a', 'b', 'c', 'd', 'e', 'f', 'g']), - ('t', ['r', 'o', 'y', 'g', 'b', 'i', 'v']), - ('u', ['CA', 'US', 'CN', 'JP']), - ('s1', ['a', 'b', 'c', 'd']), - ('s2', ['e', 'f', 'g']), - ('s3', None), - ('s4', None), - ('s5', ['b', 'd', 'f']), - ('s6', ['b', 'd', 'f']), - ('s7', None), - ('p1', None), - ('p2', None), - ('p3', None), - ('p4', None), - ('p5', None), - ('p6', None), - ]) -actual_info = { - 'N sets': 12, - 'N parameters': 7, - } -actual_info['N symbols'] = sum(actual_info.values()) + 1 - - -def list_cmp(l1, l2): - return all([i1 == i2 for i1, i2 in zip(l1, l2)]) +@pytest.fixture(scope='session') +def actual(): + """Return an xarray.Dataset with actual data. + + The returned Dataset has the contents expected when tests.gms is compiled + to tests.gdx and loaded using gdx.File. + """ + + # Sets, in the order expected in the GDX file + s = ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + t = ['b', 'g', 'r', 'o', 'y', 'i', 'v'] + u = ['CA', 'US', 'CN', 'JP'] + star = s + t[2:] + u + + # Create the dataset + ds = xr.Dataset({ + 'p1': ('s', np.full(len(s), np.nan)), + 'p2': ('t', np.full(len(t), 0.1)), + 'p3': (['s', 't'], np.full([len(s), len(t)], np.nan)), + 'p5': ('*', np.full(len(star), np.nan)), + }, + coords={ + 's': s, + 't': t, + 'u': u, + '*': star, + 's1': ['a', 'b', 'c', 'd'], + 's2': ['e', 'f', 'g'], + 's3': [], + 's4': [], + 's5': ['b', 'd', 'f'], + 's6': ['b', 'd', 'f'], + 's7': [], + 's_': s, + }) + + # Contents of parameters + ds['pi'] = 3.14 + ds['p1'].loc['a'] = 1 + ds['p3'].loc[:, 'y'] = 1 + ds['p4'] = (['s1'], np.ones(ds['s1'].size)) + ds['p6'] = (['s', 's1', 't'], + np.full([len(s), ds['s1'].size, len(t)], np.nan)) + + ds['p7'] = (['*', '*'], + np.full([ds['*'].size] * 2, np.nan)) + ds['p7'].loc['a', 'o'] = 1 + ds['p7'].loc['r', 'US'] = 2 + ds['p7'].loc['CA', 'b'] = 3 + + # Set the _gdx_index attribute on each variable + order = ['*', 'pi', 's', 't', 'u', 's1', 's2', 's3', 's4', 's5', 's6', + 's7', 'p1', 'p2', 'p3', 'p4', 'p5', 'p6', 'p7', 'e1', 'v1', 'v2', + 's_', ] + for num, name in enumerate(order): + try: + ds[name].attrs['_gdx_index'] = num + except KeyError: + # These names do not appear in the loaded gdx.File object + assert name in ['e1', 'v1', 'v2'] + + return ds + + +def test_implicit(gdxfile_explicit): + N = len(gdxfile_explicit['*']) + assert gdxfile_explicit['p7'].shape == (N, N) class TestAPI: @@ -95,33 +132,32 @@ def test_init(self, rawgdx): with pytest.raises(FileNotFoundError): gdx.File('nonexistent.gdx') - def test_parameters(self, gdxfile): - params = gdxfile.parameters() - assert len(params) == actual_info['N parameters'] + def test_num_parameters(self, gdxfile, actual): + print(gdxfile.parameters()) + assert len(gdxfile.parameters()) == len(actual.data_vars) - def test_sets(self, gdxfile): - sets = gdxfile.sets() - assert len(sets) == actual_info['N sets'] + def test_num_sets(self, gdxfile, actual): + assert len(gdxfile.sets()) == len(actual.coords) def test_get_symbol(self, gdxfile): gdxfile['s'] - def test_get_symbol_by_index(self, gdxfile): - for i, name in enumerate(actual.keys()): - sym = gdxfile.get_symbol_by_index(i) + def test_get_symbol_by_index(self, gdxfile, actual): + for name in actual: + sym = gdxfile.get_symbol_by_index(actual[name].attrs['_gdx_index']) assert sym.name == name # Giving too high an index results in IndexError with pytest.raises(IndexError): gdxfile.get_symbol_by_index(gdxfile.attrs['symbol_count'] + 1) - def test_getattr(self, gdxfile): - for name in actual.keys(): + def test_getattr(self, gdxfile, actual): + for name in actual: getattr(gdxfile, name) with pytest.raises(AttributeError): gdxfile.notasymbolname - def test_getitem(self, gdxfile): - for name in actual.keys(): + def test_getitem(self, gdxfile, actual): + for name in actual: gdxfile[name] with pytest.raises(KeyError): gdxfile['notasymbolname'] @@ -139,24 +175,25 @@ def test_info2(self, rawgdx): def test_dealias(self, gdxfile): assert gdxfile.dealias('s_').equals(gdxfile['s']) - def test_extract(self, gdxfile, gdxfile_explicit): - for name in ['p1', 'p2', 'p3', 'p4']: - gdxfile.extract(name) + def test_domain(self, gdxfile, actual): + assert gdxfile['p6'].dims == actual['p6'].dims + + def test_extract(self, gdxfile, gdxfile_explicit, actual): + # TODO add p5, p7 + for name in ['p1', 'p2', 'p3', 'p4', 'p6']: + assert gdxfile.extract(name).equals(actual[name]) + gdxfile_explicit.extract('p5') + with pytest.raises(KeyError): gdxfile.extract('notasymbolname') def test_implicit(self, gdxfile): - assert gdxfile['p6'].shape == (3, 3) - - -def test_implicit(gdxfile_explicit): - N = len(gdxfile_explicit['*']) - assert gdxfile_explicit['p6'].shape == (N, N) + assert gdxfile['p7'].shape == (3, 3) class TestSet: - def test_len(self, gdxfile): + def test_len(self, gdxfile, actual): assert len(gdxfile.s) == len(actual['s']) assert len(gdxfile.set('s1')) == len(actual['s1']) assert len(gdxfile.set('s2')) == len(actual['s2']) @@ -170,12 +207,17 @@ def test_getitem(self, gdxfile): def test_index(self, gdxfile): assert np.argwhere(gdxfile.s.values == 'd') == 3 - def test_iter(self, gdxfile): + def test_iter(self, gdxfile, actual): for i, elem in enumerate(gdxfile.s): assert actual['s'][i] == elem def test_domain(self, gdxfile): - def domain(name): return gdxfile[name].attrs['_gdx_domain'] + def domain(name): + return gdxfile[name].attrs['_gdx_domain'] + + def list_cmp(l1, l2): + return all([i1 == i2 for i1, i2 in zip(l1, l2)]) + assert list_cmp(domain('s'), ['*']) assert list_cmp(domain('t'), ['*']) assert list_cmp(domain('u'), ['*'])