diff --git a/mdaadb/analysis.py b/mdaadb/analysis.py index e977ef0..9303a32 100644 --- a/mdaadb/analysis.py +++ b/mdaadb/analysis.py @@ -1,78 +1,125 @@ -from typing import List, NamedTuple -from pathlib import Path +import inspect +from typing import List, NamedTuple, Optional, Callable -from napalib.system.universe import NapAUniverse import MDAnalysis as mda +from MDAnalysis.analysis.base import Results -from . import Database +from .database import Database, Table -def get_NapA_universe_by_simID(db: Database, simID: int) -> NapAUniverse: - row = db.get_table("Simulations").get_row(simID) - topology = row.topology - trajectory = row.trajectory - u = NapAUniverse(topology) - u.load_new(trajectory) +class DBAnalysisManager: + """Class that connects database IO with analysis running. - return u + """ + def __init__(self, Analysis, dbfile, hooks=None): + """ -def get_universe_by_simID(db: Database, simID: int) -> mda.Universe: - row = db.get_table("Simulations").get_row(simID) - topology = row.topology - trajectory = row.trajectory + Parameters + ---------- + Analysis : + dbfile : + hooks : - return mda.Universe(topology, trajectory) + """ - -class DBAnalysisRunner: - - def __init__(self, db: Database, Analysis): - - self.db = db self.Analysis = Analysis - self.name = self.Analysis.name - self._analysis = None + self.db = Database(dbfile) + + try: + self.analysis_name = self.Analysis.name + except AttributeError: + self.analysis_name = self.Analysis.__name__ + try: + self.analysis_notes = self.Analysis.notes + except AttributeError: + self.analysis_notes = None + self.analysis_path = inspect.getfile(self.Analysis) try: - self.observables = self.db.get_table("Observables") + self.obsv = self.db.get_table("Observables") except ValueError: - self.observables = self.db.create_table( - "Observables (name TEXT, progenitor TEXT)" + self.obsv = self.db.create_table( + """ + Observables ( + obsName TEXT, + notes TEXT, + creator TEXT, + timestamp DATETIME DEFAULT (strftime('%m-%d-%Y %H:%M', 'now', 'localtime')) + ) + """, + STRICT=False ) - finally: - self.observables.insert_array([ - (self.name, self.Analysis._path), - ]) - def __enter__(self): - self.db.open() - return self + if self.analysis_name not in self.obsv.get_column("obsName").data: + self.obsv.insert_row( + (self.analysis_name, self.analysis_notes, self.analysis_path), + columns=["obsName, notes, creator"], + ) - def __exit__(self, *args): - self.db.close() + self._analysis = None @property - def results(self): - if self._analysis is not None: - return self._analysis.results - - def run_for_simID(self, simID: int, **kwargs) -> None: - """""" - universe = get_NapA_universe_by_simID(self.db, simID) - self._analysis = self.Analysis(universe, **kwargs) + def results(self) -> Results: + """Analysis results.""" + + if self._analysis is None: + raise ValueError("Must call run() for results to exist.") + return self._analysis.results + + def _get_universe(self, simID: int, get_universe: Optional[Callable]): + + if get_universe is not None: + #if self.hooks["get_universe"]: + return get_universe(self.db, simID) + + row = self.db.get_table("Simulations").get_row(simID) + return mda.Universe(row.topology, row.trajecory) + + def run( + self, + simID: int, + get_universe: Optional[Callable] = None, + **kwargs: dict, + ) -> None: + """ + Parameters + ---------- + simID : int + get_universe : Callable[Database, int] + **kwargs : dict + additional keyword arguments to be passed to the Analysis class + + """ + u = self._get_universe(simID, get_universe) + + self._analysis = self.Analysis(u, **kwargs) self._analysis._simID = simID self._analysis.run() def save(self) -> None: + """Save the results of the analysis to the database.""" + if not self.results: raise ValueError("no results") - if self.name not in self.db._get_table_names(): + analysis_table = Table(self.analysis_name, self.db) + + if analysis_table not in self.db: self.db.create_table(self.Analysis.schema) + else: + simID = self._analysis._simID + if simID in analysis_table._get_rowids(): + raise ValueError( + f"{self.analysis_name} table already has data for simID {simID}" + ) rows = self.results[self.Analysis.results_key] - table = self.db.get_table(self.name) - table.insert_array(rows) + self.db.get_table(self.analysis_name).insert_rows(rows) + def __enter__(self): + return self + + def __exit__(self, *args): + self.db.close()