Skip to content

Commit

Permalink
Merge pull request #23 from khaeru/issue/21
Browse files Browse the repository at this point in the history
Tests for #21
  • Loading branch information
khaeru authored Sep 26, 2016
2 parents 9c5fb83 + a283653 commit 2cd39e7
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 62 deletions.
13 changes: 7 additions & 6 deletions data/tests.gms
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ sets
;

parameters
p1(s) 'Example parameter with animal data' / /
p2(t) 'Example parameter with color data' / set.t 0 /
p3(s,t) 'Two-dimensional parameter' / set.s.y 1 /
p4(s1) 'Parameter defined over a subset' / set.s1 1 /
p5(*) 'Empty parameter defined over the universal set'
p1(s) 'Example parameter with animal data' / /
p2(t) 'Example parameter with color data' / set.t 0.1 /
p3(s,t) 'Two-dimensional parameter' / set.s.y 1 /
p4(s1) 'Parameter defined over a subset' / set.s1 1 /
p5(*) 'Empty parameter defined over the universal set'
p6(s,s1,t) 'Parameter defined over a set and its subset' / /
;

parameter p6(*,*) 'Parameter defined over the universal set' /
parameter p7(*,*) 'Parameter defined over the universal set' /
a.o 1
r.US 2
CA.b 3
Expand Down
154 changes: 98 additions & 56 deletions gdx/test/test_gdx.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections import OrderedDict

import numpy as np
import pytest
import xarray as xr

import gdx
from gdx.pycompat import FileNotFoundError
Expand Down Expand Up @@ -45,35 +44,73 @@ def gdxfile_explicit(rawgdx):
return gdx.File(rawgdx, implicit=False)


actual = OrderedDict([
('*', None),
('pi', 3.14),
('s', ['a', 'b', 'c', 'd', 'e', 'f', 'g']),
('t', ['r', 'o', 'y', 'g', 'b', 'i', 'v']),
('u', ['CA', 'US', 'CN', 'JP']),
('s1', ['a', 'b', 'c', 'd']),
('s2', ['e', 'f', 'g']),
('s3', None),
('s4', None),
('s5', ['b', 'd', 'f']),
('s6', ['b', 'd', 'f']),
('s7', None),
('p1', None),
('p2', None),
('p3', None),
('p4', None),
('p5', None),
('p6', None),
])
actual_info = {
'N sets': 12,
'N parameters': 7,
}
actual_info['N symbols'] = sum(actual_info.values()) + 1


def list_cmp(l1, l2):
return all([i1 == i2 for i1, i2 in zip(l1, l2)])
@pytest.fixture(scope='session')
def actual():
"""Return an xarray.Dataset with actual data.
The returned Dataset has the contents expected when tests.gms is compiled
to tests.gdx and loaded using gdx.File.
"""

# Sets, in the order expected in the GDX file
s = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
t = ['b', 'g', 'r', 'o', 'y', 'i', 'v']
u = ['CA', 'US', 'CN', 'JP']
star = s + t[2:] + u

# Create the dataset
ds = xr.Dataset({
'p1': ('s', np.full(len(s), np.nan)),
'p2': ('t', np.full(len(t), 0.1)),
'p3': (['s', 't'], np.full([len(s), len(t)], np.nan)),
'p5': ('*', np.full(len(star), np.nan)),
},
coords={
's': s,
't': t,
'u': u,
'*': star,
's1': ['a', 'b', 'c', 'd'],
's2': ['e', 'f', 'g'],
's3': [],
's4': [],
's5': ['b', 'd', 'f'],
's6': ['b', 'd', 'f'],
's7': [],
's_': s,
})

# Contents of parameters
ds['pi'] = 3.14
ds['p1'].loc['a'] = 1
ds['p3'].loc[:, 'y'] = 1
ds['p4'] = (['s1'], np.ones(ds['s1'].size))
ds['p6'] = (['s', 's1', 't'],
np.full([len(s), ds['s1'].size, len(t)], np.nan))

ds['p7'] = (['*', '*'],
np.full([ds['*'].size] * 2, np.nan))
ds['p7'].loc['a', 'o'] = 1
ds['p7'].loc['r', 'US'] = 2
ds['p7'].loc['CA', 'b'] = 3

# Set the _gdx_index attribute on each variable
order = ['*', 'pi', 's', 't', 'u', 's1', 's2', 's3', 's4', 's5', 's6',
's7', 'p1', 'p2', 'p3', 'p4', 'p5', 'p6', 'p7', 'e1', 'v1', 'v2',
's_', ]
for num, name in enumerate(order):
try:
ds[name].attrs['_gdx_index'] = num
except KeyError:
# These names do not appear in the loaded gdx.File object
assert name in ['e1', 'v1', 'v2']

return ds


def test_implicit(gdxfile_explicit):
N = len(gdxfile_explicit['*'])
assert gdxfile_explicit['p7'].shape == (N, N)


class TestAPI:
Expand All @@ -95,33 +132,32 @@ def test_init(self, rawgdx):
with pytest.raises(FileNotFoundError):
gdx.File('nonexistent.gdx')

def test_parameters(self, gdxfile):
params = gdxfile.parameters()
assert len(params) == actual_info['N parameters']
def test_num_parameters(self, gdxfile, actual):
print(gdxfile.parameters())
assert len(gdxfile.parameters()) == len(actual.data_vars)

def test_sets(self, gdxfile):
sets = gdxfile.sets()
assert len(sets) == actual_info['N sets']
def test_num_sets(self, gdxfile, actual):
assert len(gdxfile.sets()) == len(actual.coords)

def test_get_symbol(self, gdxfile):
gdxfile['s']

def test_get_symbol_by_index(self, gdxfile):
for i, name in enumerate(actual.keys()):
sym = gdxfile.get_symbol_by_index(i)
def test_get_symbol_by_index(self, gdxfile, actual):
for name in actual:
sym = gdxfile.get_symbol_by_index(actual[name].attrs['_gdx_index'])
assert sym.name == name
# Giving too high an index results in IndexError
with pytest.raises(IndexError):
gdxfile.get_symbol_by_index(gdxfile.attrs['symbol_count'] + 1)

def test_getattr(self, gdxfile):
for name in actual.keys():
def test_getattr(self, gdxfile, actual):
for name in actual:
getattr(gdxfile, name)
with pytest.raises(AttributeError):
gdxfile.notasymbolname

def test_getitem(self, gdxfile):
for name in actual.keys():
def test_getitem(self, gdxfile, actual):
for name in actual:
gdxfile[name]
with pytest.raises(KeyError):
gdxfile['notasymbolname']
Expand All @@ -139,24 +175,25 @@ def test_info2(self, rawgdx):
def test_dealias(self, gdxfile):
assert gdxfile.dealias('s_').equals(gdxfile['s'])

def test_extract(self, gdxfile, gdxfile_explicit):
for name in ['p1', 'p2', 'p3', 'p4']:
gdxfile.extract(name)
def test_domain(self, gdxfile, actual):
assert gdxfile['p6'].dims == actual['p6'].dims

def test_extract(self, gdxfile, gdxfile_explicit, actual):
# TODO add p5, p7
for name in ['p1', 'p2', 'p3', 'p4', 'p6']:
assert gdxfile.extract(name).equals(actual[name])

gdxfile_explicit.extract('p5')

with pytest.raises(KeyError):
gdxfile.extract('notasymbolname')

def test_implicit(self, gdxfile):
assert gdxfile['p6'].shape == (3, 3)


def test_implicit(gdxfile_explicit):
N = len(gdxfile_explicit['*'])
assert gdxfile_explicit['p6'].shape == (N, N)
assert gdxfile['p7'].shape == (3, 3)


class TestSet:
def test_len(self, gdxfile):
def test_len(self, gdxfile, actual):
assert len(gdxfile.s) == len(actual['s'])
assert len(gdxfile.set('s1')) == len(actual['s1'])
assert len(gdxfile.set('s2')) == len(actual['s2'])
Expand All @@ -170,12 +207,17 @@ def test_getitem(self, gdxfile):
def test_index(self, gdxfile):
assert np.argwhere(gdxfile.s.values == 'd') == 3

def test_iter(self, gdxfile):
def test_iter(self, gdxfile, actual):
for i, elem in enumerate(gdxfile.s):
assert actual['s'][i] == elem

def test_domain(self, gdxfile):
def domain(name): return gdxfile[name].attrs['_gdx_domain']
def domain(name):
return gdxfile[name].attrs['_gdx_domain']

def list_cmp(l1, l2):
return all([i1 == i2 for i1, i2 in zip(l1, l2)])

assert list_cmp(domain('s'), ['*'])
assert list_cmp(domain('t'), ['*'])
assert list_cmp(domain('u'), ['*'])
Expand Down

0 comments on commit 2cd39e7

Please sign in to comment.