From 549757e0fb17a575bef6ec34afecdecb2ae57ac7 Mon Sep 17 00:00:00 2001 From: "Martin K. Scherer" Date: Mon, 4 Jul 2016 14:00:56 +0200 Subject: [PATCH] Sqlite db backend for TrajInfo (len, dim, offsets) (#798) * [coordinates/TrajectoryInfoCache] implemented sqlite backend Sqlite provides database locking for parallel filesystems, so we can use it on clusters to cache trajectory information. A LRU pattern has been implemented to clean least recently used entries. The entries are spread over several databases to avoid locking timeouts. The hash function is now computed via python stdlib hashlib and uses MD5. --- pyemma/_base/progress/reporter.py | 2 +- pyemma/coordinates/data/_base/datasource.py | 2 +- pyemma/coordinates/data/numpy_filereader.py | 2 +- .../data/util/traj_info_backends.py | 354 ++++++++++++++++++ .../coordinates/data/util/traj_info_cache.py | 174 ++++----- .../coordinates/tests/test_traj_info_cache.py | 113 ++++-- pyemma/pyemma.cfg | 5 +- pyemma/thermo/estimators/_callback.py | 1 + pyemma/util/config.py | 18 + pyemma/util/debug.py | 2 +- pyemma/util/tests/test_config.py | 6 + 11 files changed, 549 insertions(+), 130 deletions(-) create mode 100644 pyemma/coordinates/data/util/traj_info_backends.py diff --git a/pyemma/_base/progress/reporter.py b/pyemma/_base/progress/reporter.py index 8de68aaca..a8d230ad3 100644 --- a/pyemma/_base/progress/reporter.py +++ b/pyemma/_base/progress/reporter.py @@ -42,7 +42,7 @@ def show_progress(self): if not hasattr(self, "_show_progress"): from pyemma import config val = config.show_progress_bars - self._show_progress = val + self._show_progress = val return self._show_progress @show_progress.setter diff --git a/pyemma/coordinates/data/_base/datasource.py b/pyemma/coordinates/data/_base/datasource.py index 91feb1948..560e103cb 100644 --- a/pyemma/coordinates/data/_base/datasource.py +++ b/pyemma/coordinates/data/_base/datasource.py @@ -142,7 +142,7 @@ def filenames(self, filename_list): self._offsets = offsets else: - # propate this until we finally have a a reader? + # propagate this until we finally have a a reader self.data_producer.filenames = filename_list @property diff --git a/pyemma/coordinates/data/numpy_filereader.py b/pyemma/coordinates/data/numpy_filereader.py index abcb34662..1934cb1a6 100644 --- a/pyemma/coordinates/data/numpy_filereader.py +++ b/pyemma/coordinates/data/numpy_filereader.py @@ -87,7 +87,7 @@ def _reshape(self, array): def _load_file(self, itraj): filename = self._filenames[itraj] - self._logger.debug("opening file %s" % filename) + #self._logger.debug("opening file %s" % filename) if filename.endswith('.npy'): x = np.load(filename, mmap_mode=self.mmap_mode) diff --git a/pyemma/coordinates/data/util/traj_info_backends.py b/pyemma/coordinates/data/util/traj_info_backends.py new file mode 100644 index 000000000..438dbc10a --- /dev/null +++ b/pyemma/coordinates/data/util/traj_info_backends.py @@ -0,0 +1,354 @@ + +# This file is part of PyEMMA. +# +# Copyright (c) 2016 Computational Molecular Biology Group, Freie Universitaet Berlin (GER) +# +# PyEMMA is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . +''' +Created on 25.05.2016 + +@author: marscher +''' + +import itertools +import os +import time +import warnings +from io import BytesIO +from operator import itemgetter + +import numpy as np +from six import StringIO + +from pyemma.coordinates.data.util.traj_info_cache import (UnknownDBFormatException, + TrajInfo, + TrajectoryInfoCache, + logger) +from pyemma.util import config + + +class AbstractDB(object): + def set(self, value): + # value: TrajInfo + pass + + def update(self, value): + pass + + def close(self): + pass + + def sync(self): + pass + + def get(self, key): + # should raise KeyError in case of non existent key + pass + + @property + def db_version(self): + pass + + @db_version.setter + def db_version(self, val): + pass + + +class DictDB(AbstractDB): + def __init__(self): + self._db = {} + self.db_version = TrajectoryInfoCache.DB_VERSION + + def set(self, value): + self._db[value.hash_value] = value + + def update(self, value): + self._db[value.hash_value] = value + + @property + def db_version(self): + return self._db['version'] + + @db_version.setter + def db_version(self, version): + self._db['version'] = version + + @property + def num_entries(self): + return len(self._db) - 1 # substract field for db_version + + +class SqliteDB(AbstractDB): + def __init__(self, filename=None, clean_n_entries=30): + """ + :param filename: path to database file + :param clean_n_entries: during cleaning delete n % entries. + """ + self.clean_n_entries = clean_n_entries + import sqlite3 + + # register numpy array conversion functions + # uses "np_array" type in sql tables + def adapt_array(arr): + out = BytesIO() + np.savez_compressed(out, offsets=arr) + out.seek(0) + return out.read() + + def convert_array(text): + out = BytesIO(text) + out.seek(0) + npz = np.load(out) + arr = npz['offsets'] + npz.close() + return arr + # Converts np.array to TEXT when inserting + sqlite3.register_adapter(np.ndarray, adapt_array) + + # Converts TEXT to np.array when selecting + sqlite3.register_converter("NPARRAY", convert_array) + self._database = sqlite3.connect(filename if filename is not None else ":memory:", + detect_types=sqlite3.PARSE_DECLTYPES, timeout=1000*1000, + isolation_level=None) + self.filename = filename + + try: + cursor = self._database.execute("select num from version") + row = cursor.fetchone() + if not row: + self.db_version = TrajectoryInfoCache.DB_VERSION + version = self.db_version + else: + version = row[0] + if version != TrajectoryInfoCache.DB_VERSION: + # drop old db? or try to convert? + self._create_new_db() + except sqlite3.OperationalError as e: + if "no such table" in str(e): + self._create_new_db() + self.db_version = TrajectoryInfoCache.DB_VERSION + except sqlite3.DatabaseError: + bak = filename + ".bak" + warnings.warn("TrajInfo database corrupted. Backing up file to %s and start with new one." % bak) + self._database.close() + import shutil + shutil.move(filename, bak) + SqliteDB.__init__(self, filename) + + def _create_new_db(self): + # assumes self.database is a sqlite3.Connection + create_version_table = "CREATE TABLE version (num INTEGER PRIMARY KEY);" + create_info_table = """CREATE TABLE traj_info( + hash VARCHAR(64) PRIMARY KEY, + length INTEGER, + ndim INTEGER, + offsets NPARRAY, + abs_path VARCHAR(4096) null, + version INTEGER, + lru_db INTEGER + ); + """ + self._database.execute(create_version_table) + self._database.execute(create_info_table) + self._database.commit() + + def close(self): + self._database.close() + + @property + def db_version(self): + cursor = self._database.execute("select num from version") + row = cursor.fetchone() + if not row: + raise RuntimeError("unknown db version") + return row[0] + + @db_version.setter + def db_version(self, val): + self._database.execute("insert into version VALUES (?)", [val]) + self._database.commit() + + @property + def num_entries(self): + # cursor = self._database.execute("SELECT hash FROM traj_info;") + # return len(cursor.fetchall()) + c = self._database.execute("SELECT COUNT(hash) from traj_info;").fetchone() + return int(c[0]) + + def set(self, traj_info): + import sqlite3 + values = ( + traj_info.hash_value, traj_info.length, traj_info.ndim, + np.array(traj_info.offsets), traj_info.abs_path, TrajectoryInfoCache.DB_VERSION, + # lru db + self._database_from_key(traj_info.hash_value) + ) + statement = ("INSERT INTO traj_info (hash, length, ndim, offsets, abs_path, version, lru_db)" + "VALUES (?, ?, ?, ?, ?, ?, ?)", values) + try: + self._database.execute(*statement) + except sqlite3.IntegrityError: + logger.exception() + return + self._database.commit() + + self._update_time_stamp(hash_value=traj_info.hash_value) + + current_size = os.stat(self.filename).st_size + if (self.num_entries >= config.traj_info_max_entries or + # current_size is in bytes, while traj_info_max_size is in MB + 1.*current_size / 1024**2 >= config.traj_info_max_size): + logger.info("Cleaning database because it has too much entries or is too large.\n" + "Entries: %s. Size: %.2fMB. Configured max_entires: %s. Max_size: %sMB" + % (self.num_entries, (current_size*1.0 / 1024**2), + config.traj_info_max_entries, config.traj_info_max_size)) + self._clean(n=self.clean_n_entries) + + def get(self, key): + cursor = self._database.execute("SELECT * FROM traj_info WHERE hash=?", (key,)) + row = cursor.fetchone() + if not row: + raise KeyError() + info = self._create_traj_info(row) + self._update_time_stamp(key) + return info + + def _database_from_key(self, key): + """ + gets the database name for the given key. Should ensure a uniform spread + of keys over the databases in order to minimize waiting times. Since the + database has to be locked for updates and multiple processes want to write, + each process has to wait until the lock has been released. + + By default the LRU databases will be stored in a sub directory "tra_info_usage" + lying next to the main database. + + :param key: hash of the TrajInfo instance + :return: str, database path + """ + from pyemma.util.files import mkdir_p + hash_value_long = int(key, 16) + # bin hash to one of either 10 different databases + # TODO: make a configuration parameter out of this number + db_name = str(hash_value_long)[-1] + '.db' + directory = os.path.dirname(self.filename) + os.path.sep + 'traj_info_usage' + mkdir_p(directory) + return os.path.join(directory, db_name) + + def _update_time_stamp(self, hash_value): + """ timestamps are being stored distributed over several lru databases. + The timestamp is a time.time() snapshot (float), which are seconds since epoch.""" + db_name = self._database_from_key(hash_value) + import sqlite3 + + with sqlite3.connect(db_name) as conn: + """ last_read is a result of time.time()""" + conn.execute('CREATE TABLE IF NOT EXISTS usage ' + '(hash VARCHAR(32), last_read FLOAT)') + conn.commit() + cur = conn.execute('select * from usage where hash=?', (hash_value,)) + row = cur.fetchone() + if not row: + conn.execute("insert into usage(hash, last_read) values(?, ?)", (hash_value, time.time())) + else: + conn.execute("update usage set last_read=? where hash=?", (time.time(), hash_value)) + conn.commit() + + @staticmethod + def _create_traj_info(row): + # convert a database row to a TrajInfo object + try: + hash = row[0] + length = row[1] + ndim = row[2] + offsets = row[3] + assert isinstance(offsets, np.ndarray) + abs_path = row[4] + version = row[5] + + info = TrajInfo() + info._version = version + if version == 2: + info._hash = hash + info._ndim = ndim + info._length = length + info._offsets = offsets + info._abs_path = abs_path + else: + raise ValueError("unknown version %s" % version) + return info + except Exception as ex: + logger.exception(ex) + raise UnknownDBFormatException(ex) + + @staticmethod + def _format_tuple_for_sql(value): + value = tuple(str(v) for v in value) + return repr(value)[1:-2 if len(value) == 1 else -1] + + def _clean(self, n): + """ + obtain n% oldest entries by looking into the usage databases. Then these entries + are deleted first from the traj_info db and afterwards from the associated LRU dbs. + + :param n: delete n% entries in traj_info db [and associated LRU (usage) dbs]. + """ + # delete the n % oldest entries in the database + import sqlite3 + num_delete = int(self.num_entries / 100.0 * n) + logger.debug("removing %i entries from db" % num_delete) + lru_dbs = self._database.execute("select hash, lru_db from traj_info").fetchall() + lru_dbs.sort(key=itemgetter(1)) + hashs_by_db = {} + age_by_hash = [] + for k, v in itertools.groupby(lru_dbs, key=itemgetter(1)): + hashs_by_db[k] = list(x[0] for x in v) + + # debug: distribution + len_by_db = {os.path.basename(db): len(hashs_by_db[db]) for db in hashs_by_db.keys()} + logger.debug("distribution of lru: %s" % str(len_by_db)) + ### end dbg + + self.lru_timeout = 1000 #1 sec + + # collect timestamps from databases + for db in hashs_by_db.keys(): + with sqlite3.connect(db, timeout=self.lru_timeout) as conn: + rows = conn.execute("select hash, last_read from usage").fetchall() + for r in rows: + age_by_hash.append((r[0], float(r[1]), db)) + + # sort by age + age_by_hash.sort(key=itemgetter(1)) + if len(age_by_hash)>=2: + assert[age_by_hash[-1] > age_by_hash[-2]] + ids = map(itemgetter(0), age_by_hash[:num_delete]) + ids = tuple(map(str, ids)) + + sql_compatible_ids = SqliteDB._format_tuple_for_sql(ids) + + stmnt = "DELETE FROM traj_info WHERE hash in (%s)" % sql_compatible_ids + cur = self._database.execute(stmnt) + self._database.commit() + assert cur.rowcount == len(ids), "deleted not as many rows(%s) as desired(%s)" %(cur.rowcount, len(ids)) + + # iterate over all LRU databases and delete those ids, we've just deleted from the main db. + age_by_hash.sort(key=itemgetter(2)) + for db, values in itertools.groupby(age_by_hash, key=itemgetter(2)): + values = tuple(v[0] for v in values) + with sqlite3.connect(db, timeout=self.lru_timeout) as conn: + stmnt = "DELETE FROM usage WHERE hash IN (%s)" \ + % SqliteDB._format_tuple_for_sql(values) + curr = conn.execute(stmnt) + assert curr.rowcount == len(values), curr.rowcount diff --git a/pyemma/coordinates/data/util/traj_info_cache.py b/pyemma/coordinates/data/util/traj_info_cache.py index c9c2acf3d..921b1e998 100644 --- a/pyemma/coordinates/data/util/traj_info_cache.py +++ b/pyemma/coordinates/data/util/traj_info_cache.py @@ -22,18 +22,16 @@ from __future__ import absolute_import +import hashlib +import os +import sys +import warnings from io import BytesIO from logging import getLogger -import os -from threading import Semaphore -from pyemma.util import config -import six import numpy as np -if six.PY2: - import dumbdbm -else: - from dbm import dumb as dumbdbm + +from pyemma.util import config logger = getLogger(__name__) @@ -57,6 +55,7 @@ def __init__(self, ndim=0, length=0, offsets=None): self._version = 1 self._hash = -1 + self._abs_path = None @property def version(self): @@ -86,6 +85,21 @@ def hash_value(self): def hash_value(self, val): self._hash = val + @property + def abs_path(self): + return self._abs_path + + @abs_path.setter + def abs_path(self, val): + self._abs_path = val + + def offsets_to_bytes(self): + assert self.hash_value != -1 + fh = BytesIO() + np.savez_compressed(fh, offsets=self.offsets) + fh.seek(0) + return fh.read() + def __eq__(self, other): return (isinstance(other, self.__class__) and self.version == other.version @@ -95,30 +109,9 @@ def __eq__(self, other): and np.all(self.offsets == other.offsets) ) - -def create_traj_info(db_val): - assert isinstance(db_val, (six.string_types, bytes)) - if six.PY3 and isinstance(db_val, six.string_types): - db_val = bytes(db_val.encode('utf-8', errors='ignore')) - fh = BytesIO(db_val) - - try: - arr = np.load(fh)['data'] - info = TrajInfo() - header = arr[0] - - version = header['data_format_version'] - info._version = version - if version == 1: - info._hash = header['filehash'] - info._ndim = arr[1] - info._length = arr[2] - info._offsets = arr[3] - else: - raise ValueError("unknown version %s" % version) - return info - except Exception as ex: - raise UnknownDBFormatException(ex) + def __str__(self): + return "[TrajInfo hash={hash}, len={len}, dim={dim}, path={path}". \ + format(hash=self.hash_value, len=self.length, dim=self.ndim, path=self.abs_path) class TrajectoryInfoCache(object): @@ -138,56 +131,41 @@ class TrajectoryInfoCache(object): """ _instance = None - DB_VERSION = '1' + DB_VERSION = 2 @staticmethod def instance(): + """ :returns the TrajectoryInfoCache singleton instance""" if TrajectoryInfoCache._instance is None: # singleton pattern - filename = os.path.join(config.cfg_dir, "trajlen_cache") + filename = os.path.join(config.cfg_dir, "traj_info.sqlite3") TrajectoryInfoCache._instance = TrajectoryInfoCache(filename) - # sync db to hard drive at exit. - if hasattr(TrajectoryInfoCache._instance._database, 'sync'): - import atexit - @atexit.register - def write_at_exit(): - TrajectoryInfoCache._instance._database.sync() - return TrajectoryInfoCache._instance def __init__(self, database_filename=None): - # for now we disable traj info cache persistence! - database_filename = None self.database_filename = database_filename - if database_filename is not None: - try: - self._database = dumbdbm.open(database_filename, flag="c") - except dumbdbm.error as e: - try: - os.unlink(database_filename) - self._database = dumbdbm.open(database_filename, flag="n") - # persist file right now, since it was broken - self._set_curr_db_version(TrajectoryInfoCache.DB_VERSION) - # close and re-open to ensure file exists - self._database.close() - self._database = dumbdbm.open(database_filename, flag="w") - except OSError: - raise RuntimeError('corrupted database in "%s" could not be deleted' - % os.path.abspath(database_filename)) - else: - self._database = {} - self._set_curr_db_version(TrajectoryInfoCache.DB_VERSION) - self._write_protector = Semaphore() + # have no filename, use in memory sqlite db + # have no sqlite module, use dict + # have sqlite and file, create db with given filename + + try: + import sqlite3 + from pyemma.coordinates.data.util.traj_info_backends import SqliteDB + self._database = SqliteDB(self.database_filename) + except ImportError: + warnings.warn("sqlite3 package not available, persistant storage of trajectory info not possible!") + from pyemma.coordinates.data.util.traj_info_backends import DictDB + self._database = DictDB() @property def current_db_version(self): - return self._current_db_version + return self._database.db_version - def _set_curr_db_version(self, val): - self._database['db_version'] = val - self._current_db_version = val + @property + def num_entries(self): + return self._database.num_entries def _handle_csv(self, reader, filename, length): # this is maybe a bit ugly, but so far we do not store the dialect of csv files in @@ -200,48 +178,32 @@ def _handle_csv(self, reader, filename, length): def __getitem__(self, filename_reader_tuple): filename, reader = filename_reader_tuple - key = self._get_file_hash(filename) - result = None + abs_path = os.path.abspath(filename) + key = self._get_file_hash_v2(filename) try: - result = str(self._database[key]) - info = create_traj_info(result) - + info = self._database.get(key) + if not isinstance(info, TrajInfo): + raise KeyError() self._handle_csv(reader, filename, info.length) - + # if path has changed, update it + if not info.abs_path == abs_path: + info.abs_path = abs_path + self._database.update(info) # handle cache misses and not interpretable results by re-computation. # Note: this also handles UnknownDBFormatExceptions! except KeyError: info = reader._get_traj_info(filename) info.hash_value = key + info.abs_path = abs_path # store info in db - result = self.__setitem__(filename, info) + self.__setitem__(info) # save forcefully now if hasattr(self._database, 'sync'): - logger.debug("sync db after adding new entry") self._database.sync() return info - def __format_value(self, traj_info): - assert traj_info.hash_value != -1 - fh = BytesIO() - - header = {'data_format_version': 1, - 'filehash': traj_info.hash_value, # back reference to file by hash - } - - array = np.empty(4, dtype=object) - - array[0] = header - array[1] = traj_info.ndim - array[2] = traj_info.length - array[3] = traj_info.offsets - - np.savez_compressed(fh, data=array) - fh.seek(0) - return fh.read() - def _get_file_hash(self, filename): statinfo = os.stat(filename) @@ -258,14 +220,28 @@ def _get_file_hash(self, filename): hash_value ^= hash(data) return str(hash_value) - def __setitem__(self, filename, traj_info): - dbval = self.__format_value(traj_info) + def _get_file_hash_v2(self, filename): + statinfo = os.stat(filename) + # now read the first megabyte and hash it + with open(filename, mode='rb') as fh: + data = fh.read(1024) - self._write_protector.acquire() - self._database[str(traj_info.hash_value)] = dbval - self._write_protector.release() + if sys.version_info > (3,): + long = int - return dbval + hasher = hashlib.md5() + hasher.update(os.path.basename(filename).encode('utf-8')) + hasher.update(str(statinfo.st_mtime).encode('ascii')) + hasher.update(str(statinfo.st_size).encode('ascii')) + hasher.update(data) + return hasher.hexdigest() + + def __setitem__(self, traj_info): + self._database.set(traj_info) def clear(self): self._database.clear() + + def close(self): + """ you most likely never want to call this! """ + self._database.close() diff --git a/pyemma/coordinates/tests/test_traj_info_cache.py b/pyemma/coordinates/tests/test_traj_info_cache.py index c475b47bc..9b7a3423d 100644 --- a/pyemma/coordinates/tests/test_traj_info_cache.py +++ b/pyemma/coordinates/tests/test_traj_info_cache.py @@ -22,19 +22,16 @@ from __future__ import absolute_import +from contextlib import contextmanager from tempfile import NamedTemporaryFile -try: - import bsddb - have_bsddb = True -except ImportError: - have_bsddb = False import os -import six import tempfile import unittest +import mock + from pyemma.coordinates import api from pyemma.coordinates.data.feature_reader import FeatureReader from pyemma.coordinates.data.numpy_filereader import NumPyFileReader @@ -49,12 +46,6 @@ import pyemma import numpy as np -if six.PY2: - import dumbdbm - import mock -else: - from dbm import dumb as dumbdbm - from unittest import mock xtcfiles = get_bpti_test_data()['trajs'] pdbfile = get_bpti_test_data()['top'] @@ -64,30 +55,36 @@ class TestTrajectoryInfoCache(unittest.TestCase): @classmethod def setUpClass(cls): - cls.work_dir = tempfile.mkdtemp("traj_cache_test") + cls.old_instance = TrajectoryInfoCache.instance() + cls.old_show_pg = config.show_progress_bars + config.show_progress_bars = False def setUp(self): + self.work_dir = tempfile.mkdtemp("traj_cache_test") self.tmpfile = tempfile.mktemp(dir=self.work_dir) self.db = TrajectoryInfoCache(self.tmpfile) - assert len(self.db._database) == 1, len(self.db._database) - assert 'db_version' in self.db._database - assert int(self.db._database['db_version']) >= 1 + # overwrite TrajectoryInfoCache._instance with self.db... + TrajectoryInfoCache._instance = self.db def tearDown(self): - del self.db + self.db.close() + os.unlink(self.tmpfile) + + import shutil + shutil.rmtree(self.work_dir, ignore_errors=True) @classmethod def tearDownClass(cls): - import shutil - shutil.rmtree(cls.work_dir, ignore_errors=True) + TrajectoryInfoCache._instance = cls.old_instance + config.show_progress_bars = cls.old_show_pg def test_get_instance(self): # test for exceptions in singleton creation inst = TrajectoryInfoCache.instance() inst.current_db_version + self.assertIs(inst, self.db) - @unittest.skip("persistence currently disabled.") def test_store_load_traj_info(self): x = np.random.random((10, 3)) my_conf = config() @@ -97,8 +94,8 @@ def test_store_load_traj_info(self): np.savetxt(fh.name, x) reader = api.source(fh.name) info = self.db[fh.name, reader] - self.db._database.close() - self.db._database = dumbdbm.open(self.db.database_filename, 'r') + self.db.close() + self.db.__init__(self.db._database.filename) info2 = self.db[fh.name, reader] self.assertEqual(info2, info) @@ -106,7 +103,7 @@ def test_exceptions(self): # in accessible files not_existant = ''.join( chr(i) for i in np.random.random_integers(65, 90, size=10)) + '.npy' - bad = [not_existant] # should be unaccessible or non existant + bad = [not_existant] # should be unaccessible or non existent with self.assertRaises(ValueError) as cm: api.source(bad) assert bad[0] in cm.exception.message @@ -207,7 +204,7 @@ def test_data_in_mem(self): # make sure cache is not used for data in memory! data = [np.empty((3, 3))] * 3 api.source(data) - assert len(self.db._database) == 1 + self.assertEqual(self.db.num_entries, 0) def test_old_db_conversion(self): # prior 2.1, database only contained lengths (int as string) entries @@ -219,7 +216,9 @@ def test_old_db_conversion(self): f.close() # windows sucks reader = api.source(fn) hash = db._get_file_hash(fn) - db._database = {hash: str(3)} + from pyemma.coordinates.data.util.traj_info_backends import DictDB + db._database = DictDB() + db._database.db_version = 0 info = db[fn, reader] assert info.length == 3 @@ -231,11 +230,73 @@ def test_corrupted_db(self): f.write("makes no sense!!!!") f.close() name = f.name - db = TrajectoryInfoCache(name) + import warnings + with warnings.catch_warnings(record=True) as cm: + warnings.simplefilter('always') + db = TrajectoryInfoCache(name) + assert len(cm) == 1 + assert "corrupted" in str(cm[-1].message) # ensure we can perform lookups on the broken db without exception. r = api.source(xtcfiles[0], top=pdbfile) db[xtcfiles[0], r] + def test_n_entries(self): + self.assertEqual(self.db.num_entries, 0) + assert TrajectoryInfoCache._instance is self.db + pyemma.coordinates.source(xtcfiles, top=pdbfile) + self.assertEqual(self.db.num_entries, len(xtcfiles)) + + def test_max_n_entries(self): + data = [np.random.random((10, 3)) for _ in range(20)] + max_entries = 10 + config.traj_info_max_entries = max_entries + files = [] + with TemporaryDirectory() as td: + for i, arr in enumerate(data): + f = os.path.join(td, "%s.npy" % i) + np.save(f, arr) + files.append(f) + pyemma.coordinates.source(files) + self.assertLessEqual(self.db.num_entries, max_entries) + self.assertGreater(self.db.num_entries, 0) + + def test_max_size(self): + data = [np.random.random((150, 10)) for _ in range(150)] + max_size = 1 + + @contextmanager + def size_ctx(new_size): + old_size = config.traj_info_max_size + config.traj_info_max_size = new_size + yield + config.traj_info_max_size = old_size + + files = [] + config.show_progress_bars=False + with TemporaryDirectory() as td, size_ctx(max_size): + for i, arr in enumerate(data): + f = os.path.join(td, "%s.txt" % i) + # save as txt to enforce creation of offsets + np.savetxt(f, arr) + files.append(f) + pyemma.coordinates.source(files) + + self.assertLessEqual(os.stat(self.db.database_filename).st_size / 1024, config.traj_info_max_size) + self.assertGreater(self.db.num_entries, 0) + + @unittest.skip("not yet functional") + def test_no_sqlite(self): + def import_mock(name, *args): + if name == 'sqlite3': + raise ImportError("we pretend not to have this") + return __import__(name, *args) + + from pyemma.coordinates.data.util import traj_info_cache + with mock.patch('pyemma.coordinates.data.util.traj_info_cache', '__import__', + side_effect=import_mock, create=True): + TrajectoryInfoCache._instance = None + TrajectoryInfoCache(self.tempfile) + if __name__ == "__main__": unittest.main() diff --git a/pyemma/pyemma.cfg b/pyemma/pyemma.cfg index 8e7339b46..78370d411 100644 --- a/pyemma/pyemma.cfg +++ b/pyemma/pyemma.cfg @@ -21,4 +21,7 @@ show_progress_bars = True # useful for trajectory formats, for which one has to read the whole file to get len # eg. XTC format. use_trajectory_lengths_cache = True - +# maximum entries in database +traj_info_max_entries = 50000 +# max size in MB +traj_info_max_size = 500 \ No newline at end of file diff --git a/pyemma/thermo/estimators/_callback.py b/pyemma/thermo/estimators/_callback.py index cd3f841a9..83a264bc4 100644 --- a/pyemma/thermo/estimators/_callback.py +++ b/pyemma/thermo/estimators/_callback.py @@ -22,6 +22,7 @@ class _ProgressIndicatorCallBack(object): def __init__(self): self.time = 0.0 + # TODO: unify this concept in ProgressReporter (but make it adaptive) def waiting(self): now = time.time() if now - self.time < .2: diff --git a/pyemma/util/config.py b/pyemma/util/config.py index f4fbd3589..e331693b4 100644 --- a/pyemma/util/config.py +++ b/pyemma/util/config.py @@ -43,6 +43,8 @@ 'show_progress_bars', 'used_filenames', 'use_trajectory_lengths_cache', + 'traj_info_max_entries', + 'traj_info_max_size', ) if six.PY2: @@ -322,6 +324,22 @@ def logging_config(self): # #config['incremental'] = True # setup_logging(self, config) + @property + def traj_info_max_entries(self): + return self._conf_values.getint('pyemma', 'traj_info_max_entries') + + @traj_info_max_entries.setter + def traj_info_max_entries(self, val): + self._conf_values.set('pyemma', 'traj_info_max_entries', str(val)) + + @property + def traj_info_max_size(self): + return self._conf_values.getint('pyemma', 'traj_info_max_size') + + @traj_info_max_size.setter + def traj_info_max_size(self, val): + val = str(int(val)) + self._conf_values.set('pyemma', 'traj_info_max_size', val) @property def show_progress_bars(self): return self._conf_values.getboolean('pyemma', 'show_progress_bars') diff --git a/pyemma/util/debug.py b/pyemma/util/debug.py index 8076520e5..db91173e1 100644 --- a/pyemma/util/debug.py +++ b/pyemma/util/debug.py @@ -33,7 +33,7 @@ _logger = None -SIGNAL_STACKTRACE = 42 +SIGNAL_STACKTRACE = 23 SIGNAL_PDB = 43 diff --git a/pyemma/util/tests/test_config.py b/pyemma/util/tests/test_config.py index 580b7ce77..98a981a5c 100644 --- a/pyemma/util/tests/test_config.py +++ b/pyemma/util/tests/test_config.py @@ -147,5 +147,11 @@ def test_interpolation_from_multiple_files(self): # TODO: impl pass + def test_traj_info_max_entries(self): + assert isinstance(self.config_inst.traj_info_max_entries, int) + self.config_inst.traj_info_max_entries = 1 + self.assertEqual(self.config_inst.traj_info_max_entries, 1) + + if __name__ == "__main__": unittest.main()