Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parameterize treatment of item dimensions' types (#218) #219

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 57 additions & 38 deletions ixmp/backend/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from jpype import JClass
import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype

from ixmp import config
from ixmp.core import Scenario
Expand Down Expand Up @@ -70,23 +69,30 @@ def _temp_dbprops(driver=None, path=None, url=None, user=None, password=None):
]

# Handle arguments
if driver == 'oracle':
if 'oracle' in driver.lower():
driver = 'oracle.jdbc.driver.OracleDriver'

if url is None or path is not None:
raise ValueError("use JDBCBackend(driver='oracle', url=…)")

full_url = 'jdbc:oracle:thin:@{}'.format(url)
elif driver == 'hsqldb':
elif 'hsqldb' in driver.lower():
driver = 'org.hsqldb.jdbcDriver'

if path is None or url is not None:
raise ValueError("use JDBCBackend(driver='hsqldb', path=…)")

# Convert Windows paths to use forward slashes per HyperSQL JDBC URL
# spec
url_path = str(PurePosixPath(Path(path).resolve())).replace('\\', '')
full_url = 'jdbc:hsqldb:file:{}'.format(url_path)
if path is None:
if url is None:
raise ValueError("use JDBCBackend(driver='hsqldb', path=…)")
elif re.match('^jdbc:hsqldb:(file|mem):.+$', url) is None:
raise ValueError('invalid JDBC URL provided: '.format(url))

if path is not None:
# Convert Windows paths to use forward slashes
# per HyperSQL JDBC URL spec
url_path = (str(PurePosixPath(Path(path).resolve()))
.replace('\\', ''))
full_url = 'jdbc:hsqldb:file:{}'.format(url_path)
else:
full_url = url
user = user or 'ixmp'
password = password or 'ixmp'
else:
Expand Down Expand Up @@ -166,7 +172,7 @@ def __init__(self, jvmargs=None, **kwargs):

if 'dbprops' in kwargs:
# Use an existing file
dbprops = Path(kwargs.pop('dbprops'))
dbprops = Path(kwargs.pop('dbprops')).resolve()
if dbprops.exists():
# Existing properties file
properties_file, info = _read_dbprops(dbprops)
Expand Down Expand Up @@ -494,7 +500,23 @@ def item_index(self, s, name, sets_or_names):
jitem = self._get_item(s, 'item', name, load=False)
return list(getattr(jitem, f'getIdx{sets_or_names.title()}')())

def item_get_elements(self, s, type, name, filters=None):
def item_get_elements(self, s, type, name, filters=None,
dtypes_map=None, default_dtype='category'):
""" Get item elements (GAMS symbol records) as a dataframe

:param s: scenario
:param type: type of item (parameter, variable, equation or set)
:param name: name of item
:param filters: filters to limit item content
:param dtypes_map: post-processing of raw symbol dimension/column data
:param default_dtype: default dtype to use for dimensions not defined
in dtypes_map
:return: a dataframe
"""
if dtypes_map is None:
dtypes_map = {
'category': lambda dim_name: 'year' in dim_name
}
try:
# Retrieve the cached value with this exact set of filters
return self.cache_get(s, type, name, filters)
Expand Down Expand Up @@ -529,36 +551,33 @@ def item_get_elements(self, s, type, name, filters=None):
columns = list(item.getIdxNames())
idx_sets = list(item.getIdxSets())

# Prepare dtypes for index columns
dtypes = {}
for idx_name, idx_set in zip(columns, idx_sets):
dtypes[idx_name] = CategoricalDtype(
self.item_get_elements(s, 'set', idx_set))

# Prepare dtypes for additional columns
data = {}
# Prepare arrays with column values
# NB [:] causes JPype to use a faster code path
for i, (idx_name, idx_set) in enumerate(zip(columns, idx_sets)):
for dtype in dtypes_map:
if dtypes_map[dtype](idx_name):
data[idx_name] = pd.Series(item.getCol(i, jList)[:],
dtype=dtype)
else:
data[idx_name] = pd.Series(item.getCol(i, jList)[:],
dtype=default_dtype)

# Add type-specific columns
if type == 'par':
columns.extend(['value', 'unit'])
dtypes['value'] = float
dtypes['unit'] = CategoricalDtype(self.jobj.getUnitList())
data['value'] = pd.Series(item.getValues(jList)[:],
dtype='float')
data['unit'] = pd.Series(item.getUnits(jList)[:],
dtype='category')
elif type in ('equ', 'var'):
columns.extend(['lvl', 'mrg'])
dtypes.update({'lvl': float, 'mrg': float})
data['lvl'] = pd.Series(item.getLevels(jList)[:],
dtype='float')
data['mrg'] = pd.Series(item.getMarginals(jList)[:],
dtype='float')

# Prepare empty DataFrame
result = pd.DataFrame(index=pd.RangeIndex(len(jList)),
columns=columns) \
.astype(dtypes)
# Construct DataFrame
result = pd.DataFrame.from_dict(data, orient='columns')

# Copy vectors from Java into DataFrame columns
# NB [:] causes JPype to use a faster code path
for i in range(len(idx_sets)):
result.iloc[:, i] = item.getCol(i, jList)[:]
if type == 'par':
result.loc[:, 'value'] = item.getValues(jList)[:]
result.loc[:, 'unit'] = item.getUnits(jList)[:]
elif type in ('equ', 'var'):
result.loc[:, 'lvl'] = item.getLevels(jList)[:]
result.loc[:, 'mrg'] = item.getMarginals(jList)[:]
elif type == 'set':
# Index sets
result = pd.Series(item.getCol(0, jList))
Expand Down
8 changes: 8 additions & 0 deletions ixmp/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ def test_mp(request, tmp_env, test_data_path):
yield from create_test_mp(request, test_data_path, 'ixmptest')


@pytest.fixture(scope='function')
def test_mp_mem():
    """Provide an ixmp.Platform connected to an in-memory database.

    The platform uses the 'jdbc' backend with the HyperSQL driver class and
    the URL 'jdbc:hsqldb:mem:ixmptest'.
    """
    # Connection settings handed to the backend as keyword arguments
    connection = {
        'driver': 'org.hsqldb.jdbcDriver',
        'url': 'jdbc:hsqldb:mem:ixmptest',
        'user': 'ixmp',
        'password': 'ixmp',
    }
    return Platform(backend='jdbc', **connection)


def create_test_mp(request, path, name):
# Name of the test function, without the preceding 'test_'
dirname = request.node.name.split('test_', 1)[1]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

EXTRAS_REQUIRE = {
'tests': ['codecov', 'jupyter', 'pretenders>=1.4.4', 'pytest-cov',
'pytest>=3.9'],
'pytest>=3.9', 'pytest-benchmark'],
'docs': ['numpydoc', 'sphinx', 'sphinx_rtd_theme', 'sphinxcontrib-bibtex'],
'tutorial': ['jupyter'],
}
Expand Down
27 changes: 27 additions & 0 deletions tests/backend/test_jdbc_perf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import ixmp


def _save_par(test_mp_mem, size=1000):
    """Create and commit a new scenario holding one indexed parameter.

    The set 'ii' receives every second integer starting at 1 and below
    *size*; the parameter 'new_par', indexed by 'ii', stores the value 1.2
    for each of those elements.

    :param test_mp_mem: platform on which the new scenario is created
    :param size: exclusive upper bound of the generated set elements
    """
    scenario = ixmp.Scenario(test_mp_mem, model='test', scenario='scenario',
                             version='new')

    keys = range(1, size, 2)
    values = [1.2] * len(keys)

    scenario.init_set('ii')
    scenario.add_set('ii', keys)
    scenario.init_par('new_par', idx_sets='ii')
    scenario.add_par('new_par', keys, values)
    scenario.commit('init')


def test_save_par_1000(benchmark, test_mp_mem):
    """Benchmark ``_save_par`` called with ``size=1000``."""
    benchmark(_save_par, test_mp_mem, size=1000)


def test_read_par_100000(benchmark, test_mp_mem):
    """Benchmark reading 'new_par' after ``_save_par`` has stored it."""
    def _load(platform):
        # Open the 'test'/'scenario' scenario and fetch the parameter
        return ixmp.Scenario(platform, model='test',
                             scenario='scenario').par('new_par')

    # Store the data before the benchmarked call so only the read is timed
    _save_par(test_mp_mem, 100000)
    benchmark(_load, test_mp_mem)