Skip to content

Commit

Permalink
various updates
Browse files Browse the repository at this point in the history
  • Loading branch information
edisj committed Jul 25, 2024
1 parent f3807d8 commit a68588b
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 199 deletions.
2 changes: 1 addition & 1 deletion mdaadb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from importlib.metadata import version

from .query import Query
from .database import Database, Table
from .database import Database, Table, Schema
from .analysis import DBAnalysisManager


Expand Down
111 changes: 63 additions & 48 deletions mdaadb/analysis.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,44 @@
import inspect
from typing import List, NamedTuple, Optional, Callable
import pathlib
from typing import Optional, Callable

import MDAnalysis as mda
from MDAnalysis.analysis.base import Results
from MDAnalysis.analysis.base import AnalysisBase, Results

from .database import Database, Table


class DBAnalysisManager:
"""Class that connects database IO with analysis running.
"""

def __init__(self, Analysis, dbfile, hooks=None):
"""
Parameters
----------
Analysis :
dbfile : path-like
hooks : dict
Parameters
----------
Analysis : mda.analysis.base.AnalysisBase
dbfile : path-like
hooks : dict
"""

"""
def __init__(
self,
Analysis, # : class that inherits from AnalysisBase
dbfile: pathlib.Path | str,
hooks: Optional[dict] = None,
):

self.Analysis = Analysis
self.db = Database(dbfile)
if isinstance(dbfile, Database):
self.db = dbfile
else:
self.db = Database(dbfile)
self._analysis = None

self.hooks = {
"pre_run": None,
"post_run": None,
"get_universe": None,
"pre_save": None,
"post_save": None,
"get_universe": None,
}
if hooks is not None:
self.hooks.update(hooks)
Expand All @@ -41,55 +48,54 @@ def __init__(self, Analysis, dbfile, hooks=None):
except AttributeError:
self._name = self.Analysis.__name__
try:
self._notes = self.Analysis.notes
self._desc = self.Analysis.description
except AttributeError:
self._notes = None
self._path = inspect.getfile(self.Analysis)
self._desc = "none"
try:
self._creator = inspect.getfile(self.Analysis)
except OSError:
self._creator = "unknown"

try:
self.obsv = self.db.get_table("Observables")
self.observables = self.db.get_table("Observables")
except ValueError:
self.obsv = self.db.create_table(
self.observables = self.db.create_table(
"""
Observables (
obsName TEXT,
notes TEXT,
creator TEXT,
timestamp DATETIME DEFAULT (strftime('%m-%d-%Y %H:%M', 'now', 'localtime'))
name TEXT PRIMARY KEY,
description TEXT,
creator TEXT
)
""",
STRICT=False
STRICT=False,
)

if self._name not in self.obsv.get_column("obsName").data:
self.obsv.insert_row(
(self._name, self._notes, self._path),
columns=["obsName, notes, creator"],
if self._name not in self.observables.get_column("Name").data:
# print(self._name, self._desc, self._path)
self.observables.insert_row(
row=(self._name, self._desc, self._creator),
columns=["name", "description", "creator"],
)

@property
def results(self) -> Results:
"""Analysis results"""

if self._analysis is None:
raise ValueError("Must call run() for results to exist.")
return self._analysis.results

def _get_universe(self, simID: int):
def _get_universe(self, simID: int) -> mda.Universe:

if self.hooks["get_universe"] is not None:
#if self.hooks["get_universe"]:
return self.hooks["get_universe"](self.db, simID)
universe = self.hooks["get_universe"](self.db, simID)

return universe

row = self.db.get_table("Simulations").get_row(simID)
return mda.Universe(row.topology, row.trajectory)
universe = mda.Universe(row.Topology, row.Trajectory)

def run(self, simID: int, **kwargs: dict) -> None:
return universe

def run(self, simID: int, **kwargs: dict) -> None:
"""Run the analysis for a simulation given by `simID`.
Parameters
----------
simID : int
**kwargs : dict
additional keyword arguments to be passed to the Analysis class
Expand All @@ -101,12 +107,12 @@ def run(self, simID: int, **kwargs: dict) -> None:
self._analysis._simID = simID

if self.hooks["pre_run"] is not None:
self.hooks["pre_run"](simID, self.db)
self.hooks["pre_run"](self.db, simID)

self._analysis.run()

if self.hooks["post_run"] is not None:
self.hooks["post_run"](simID, self.db)
self.hooks["post_run"](self.db, simID)

def save(self) -> None:
"""Save the results of the analysis to the database."""
Expand All @@ -120,20 +126,29 @@ def save(self) -> None:
analysis_table = self.db.get_table(self._name)
except ValueError:
analysis_table = self.db.create_table(self.Analysis.schema)
else:
assert analysis_table.schema == self.Analysis.schema
# else:
# assert analysis_table.schema == self.Analysis.schema

simID = self._analysis._simID
if simID in analysis_table._get_rowids():
if simID in [row[0] for row in analysis_table.SELECT("SimulationID").execute().fetchall()]:
raise ValueError(
f"{self._name} table already has data for simID {simID}"
f"'{self._name}' table already has data for simID {simID}"
)

if self.hooks["pre_save"] is not None:
self.hooks["pre_save"](self.db, simID)

rows = self.results[self.Analysis.results_key]
analysis_table.insert_rows(rows)

if self.hooks["post_save"] is not None:
self.hooks["post_save"](simID, self.db)
self.hooks["post_save"](self.db, simID)

@property
def results(self) -> Results:
if self._analysis is None:
raise ValueError("Must call run() for results to exist.")
return self._analysis.results

def __enter__(self):
return self
Expand Down
Loading

0 comments on commit a68588b

Please sign in to comment.