diff --git a/ixmp/backend/jdbc.py b/ixmp/backend/jdbc.py index 88d1897fe..66adfb06d 100644 --- a/ixmp/backend/jdbc.py +++ b/ixmp/backend/jdbc.py @@ -14,7 +14,6 @@ from jpype import JClass import numpy as np import pandas as pd -from pandas.api.types import CategoricalDtype from ixmp import config from ixmp.core import Scenario @@ -70,23 +69,30 @@ def _temp_dbprops(driver=None, path=None, url=None, user=None, password=None): ] # Handle arguments - if driver == 'oracle': + if 'oracle' in driver.lower(): driver = 'oracle.jdbc.driver.OracleDriver' if url is None or path is not None: raise ValueError("use JDBCBackend(driver='oracle', url=…)") full_url = 'jdbc:oracle:thin:@{}'.format(url) - elif driver == 'hsqldb': + elif 'hsqldb' in driver.lower(): driver = 'org.hsqldb.jdbcDriver' - if path is None or url is not None: - raise ValueError("use JDBCBackend(driver='hsqldb', path=…)") - - # Convert Windows paths to use forward slashes per HyperSQL JDBC URL - # spec - url_path = str(PurePosixPath(Path(path).resolve())).replace('\\', '') - full_url = 'jdbc:hsqldb:file:{}'.format(url_path) + if path is None: + if url is None: + raise ValueError("use JDBCBackend(driver='hsqldb', path=…)") + elif re.match('^jdbc:hsqldb:(file|mem):.+$', url) is None: + raise ValueError('invalid JDBC URL provided: '.format(url)) + + if path is not None: + # Convert Windows paths to use forward slashes + # per HyperSQL JDBC URL spec + url_path = (str(PurePosixPath(Path(path).resolve())) + .replace('\\', '')) + full_url = 'jdbc:hsqldb:file:{}'.format(url_path) + else: + full_url = url user = user or 'ixmp' password = password or 'ixmp' else: @@ -166,7 +172,7 @@ def __init__(self, jvmargs=None, **kwargs): if 'dbprops' in kwargs: # Use an existing file - dbprops = Path(kwargs.pop('dbprops')) + dbprops = Path(kwargs.pop('dbprops')).resolve() if dbprops.exists(): # Existing properties file properties_file, info = _read_dbprops(dbprops) @@ -494,7 +500,23 @@ def item_index(self, s, name, sets_or_names): 
jitem = self._get_item(s, 'item', name, load=False) return list(getattr(jitem, f'getIdx{sets_or_names.title()}')()) - def item_get_elements(self, s, type, name, filters=None): + def item_get_elements(self, s, type, name, filters=None, + dtypes_map=None, default_dtype='category'): + """ Get item elements (GAMS symbol records) as a dataframe + + :param s: scenario + :param type: type of item (parameter, variable, equation or set) + :param name: name of item + :param filters: filters to limit item content + :param dtypes_map: dict of dtype -> predicate on the dimension name + :param default_dtype: default dtype to use for dimensions not defined + in dtypes_map + :return: a dataframe + """ + if dtypes_map is None: + dtypes_map = { + 'category': lambda dim_name: 'year' in dim_name + } try: # Retrieve the cached value with this exact set of filters return self.cache_get(s, type, name, filters) @@ -529,36 +551,33 @@ def item_get_elements(self, s, type, name, filters=None): columns = list(item.getIdxNames()) idx_sets = list(item.getIdxSets()) - # Prepare dtypes for index columns - dtypes = {} - for idx_name, idx_set in zip(columns, idx_sets): - dtypes[idx_name] = CategoricalDtype( - self.item_get_elements(s, 'set', idx_set)) - - # Prepare dtypes for additional columns + data = {} + # Prepare arrays with column values + # NB [:] causes JPype to use a faster code path + for i, (idx_name, idx_set) in enumerate(zip(columns, idx_sets)): + for dtype in dtypes_map: + if dtypes_map[dtype](idx_name): + data[idx_name] = pd.Series(item.getCol(i, jList)[:], + dtype=dtype) + else: + data[idx_name] = pd.Series(item.getCol(i, jList)[:], + dtype=default_dtype) + + # Add type-specific columns if type == 'par': - columns.extend(['value', 'unit']) - dtypes['value'] = float - dtypes['unit'] = CategoricalDtype(self.jobj.getUnitList()) + data['value'] = pd.Series(item.getValues(jList)[:], + dtype='float') + data['unit'] = pd.Series(item.getUnits(jList)[:], + dtype='category') elif 
type in ('equ', 'var'): - columns.extend(['lvl', 'mrg']) - dtypes.update({'lvl': float, 'mrg': float}) + data['lvl'] = pd.Series(item.getLevels(jList)[:], + dtype='float') + data['mrg'] = pd.Series(item.getMarginals(jList)[:], + dtype='float') - # Prepare empty DataFrame - result = pd.DataFrame(index=pd.RangeIndex(len(jList)), - columns=columns) \ - .astype(dtypes) + # Construct DataFrame + result = pd.DataFrame.from_dict(data, orient='columns') - # Copy vectors from Java into DataFrame columns - # NB [:] causes JPype to use a faster code path - for i in range(len(idx_sets)): - result.iloc[:, i] = item.getCol(i, jList)[:] - if type == 'par': - result.loc[:, 'value'] = item.getValues(jList)[:] - result.loc[:, 'unit'] = item.getUnits(jList)[:] - elif type in ('equ', 'var'): - result.loc[:, 'lvl'] = item.getLevels(jList)[:] - result.loc[:, 'mrg'] = item.getMarginals(jList)[:] elif type == 'set': # Index sets result = pd.Series(item.getCol(0, jList)) diff --git a/ixmp/testing.py b/ixmp/testing.py index 3f345a288..f08c0303a 100644 --- a/ixmp/testing.py +++ b/ixmp/testing.py @@ -87,6 +87,14 @@ def test_mp(request, tmp_env, test_data_path): yield from create_test_mp(request, test_data_path, 'ixmptest') +@pytest.fixture(scope='function') +def test_mp_mem(): + """An ixmp.Platform connected to an in-memory database.""" + return Platform(backend='jdbc', driver='org.hsqldb.jdbcDriver', + url='jdbc:hsqldb:mem:ixmptest', + user='ixmp', password='ixmp') + + def create_test_mp(request, path, name): # Name of the test function, without the preceding 'test_' dirname = request.node.name.split('test_', 1)[1] diff --git a/setup.py b/setup.py index 57e6c357a..6b6e1adc0 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ EXTRAS_REQUIRE = { 'tests': ['codecov', 'jupyter', 'pretenders>=1.4.4', 'pytest-cov', - 'pytest>=3.9'], + 'pytest>=3.9', 'pytest-benchmark'], 'docs': ['numpydoc', 'sphinx', 'sphinx_rtd_theme', 'sphinxcontrib-bibtex'], 'tutorial': ['jupyter'], } diff --git 
a/tests/backend/test_jdbc_perf.py b/tests/backend/test_jdbc_perf.py new file mode 100644 index 000000000..87b88745f --- /dev/null +++ b/tests/backend/test_jdbc_perf.py @@ -0,0 +1,27 @@ +import ixmp + + +def _save_par(test_mp_mem, size=1000): + scen = ixmp.Scenario(test_mp_mem, + model='test', + scenario='scenario', + version='new') + scen.init_set('ii') + ii = range(1, size, 2) + scen.add_set('ii', ii) + scen.init_par('new_par', idx_sets='ii') + scen.add_par('new_par', ii, [1.2] * len(ii)) + scen.commit('init') + + +def test_save_par_1000(benchmark, test_mp_mem): + benchmark(_save_par, test_mp_mem, 1000) + + +def test_read_par_100000(benchmark, test_mp_mem): + def read_par(test_mp_mem): + scen = ixmp.Scenario(test_mp_mem, model='test', scenario='scenario') + return scen.par('new_par') + + _save_par(test_mp_mem, 100000) + benchmark(read_par, test_mp_mem)