Skip to content

Commit

Permalink
fix: Explicitly begin(), not connect(), to sql engine
Browse files Browse the repository at this point in the history
As of SQLAlchemy 2.0, a connection obtained with connect() no longer autocommits:
pending work is rolled back when the connection closes unless it is explicitly
committed or runs inside a transaction (begun with begin()).

closes #56
Add a test to ensure each variant name has a unique definition

remove Experiment type
  • Loading branch information
Jacob-Stevens-Haas authored Oct 24, 2024
1 parent 3fba4a6 commit b1e6eed
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 39 deletions.
12 changes: 6 additions & 6 deletions mitosis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def __init__(
self.log_table = Table(table_name, md, *cols)
url = "sqlite:///" + str(self.db)
self.eng = create_engine(url)
with self.eng.connect() as conn:
with self.eng.begin() as conn:
if not inspection.inspect(conn).has_table(table_name):
md.create_all(conn)

Expand All @@ -182,7 +182,7 @@ def emit(self, record: logging.LogRecord):
stmt = stmt.values({col: vals[i + 1]})
else:
raise ValueError("Cannot parse db message")
with self.eng.connect() as conn:
with self.eng.begin() as conn:
conn.execute(stmt)

def parse_record(self, msg: str) -> List[str]:
Expand Down Expand Up @@ -220,21 +220,21 @@ def _verify_variant_name(trial_db: Path, step: str, param: Parameter) -> None:
eng = create_engine("sqlite:///" + str(trial_db))
md = MetaData()
tb = Table(f"{step}_variant_{param.arg_name}", md, *variant_types())
vals: Collection[Any]
vals: Collection
if isinstance(param.vals, Mapping):
vals = StrictlyReproduceableDict({k: v for k, v in sorted(param.vals.items())})
elif isinstance(param.vals, Collection) and not isinstance(param.vals, str):
try:
vals = StrictlyReproduceableList(sorted(param.vals))
except (ValueError, TypeError):
vals = param.vals
vals = str(param.vals)
else:
vals = param.vals
vals = str(param.vals)
df = pd.read_sql(select(tb), eng)
ind_equal = df.loc[:, "name"] == param.var_name
if ind_equal.sum() == 0:
stmt = tb.insert().values({"name": param.var_name, "params": str(vals)})
with eng.connect() as conn:
with eng.begin() as conn:
conn.execute(stmt)
elif df.loc[ind_equal, "params"].iloc[0] != str(vals):
raise RuntimeError(
Expand Down
21 changes: 8 additions & 13 deletions mitosis/_typing.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
from abc import ABCMeta
from collections.abc import Mapping
from dataclasses import dataclass
from dataclasses import field
from types import ModuleType
from typing import Any
from typing import Callable
from typing import NamedTuple
from typing import ParamSpec
from typing import TypedDict


P = ParamSpec("P")
ExpRun = Callable[P, dict]
class ExpResults(TypedDict):
    """Dictionary returned by an experiment step.

    Only the ``main`` key is declared; its value may be any object.
    """

    main: object


class Experiment(ModuleType, metaclass=ABCMeta):
__name__: str
__file__: str
name: str
lookup_dict: dict[str, dict[str, Any]]
run: ExpRun
P = ParamSpec("P")
ExpRun = Callable[P, ExpResults]


@dataclass
Expand All @@ -33,7 +28,7 @@ class Parameter:

var_name: str
arg_name: str
vals: Any
vals: object
# >= 3.10 only: https://stackoverflow.com/a/49911616/534674
evaluate: bool = field(default=False, kw_only=True)

Expand All @@ -42,7 +37,7 @@ class ExpStep(NamedTuple):
name: str
action: ExpRun
action_ref: str
lookup: dict[str, Any]
lookup: Mapping[str, Mapping[str, object]]
lookup_ref: str
group: str | None
args: list[Parameter]
Expand Down
7 changes: 7 additions & 0 deletions mitosis/tests/mock_paper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
from collections import defaultdict

data_config = {"length": {"test": 5}}

meth_config = {"metric": {"test": "len"}}

# Lookup table for tests: indexing with any parameter name yields a table in
# which indexing with any variant name yields None (nested defaultdicts).
lookup_default: dict[str, dict[str, None]] = defaultdict(
lambda: defaultdict(lambda: None)
)
19 changes: 13 additions & 6 deletions mitosis/tests/mock_part1.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
from logging import getLogger

import numpy as np
from numpy.typing import NBitBase
from numpy.typing import NDArray

from mitosis._typing import ExpResults


class Klass:
@staticmethod
def gen_data(
length: int, extra: bool = False
) -> dict[str, NDArray[np.floating[NBitBase]] | bool]:
def gen_data(length: int, extra: bool = False) -> ExpResults:
getLogger(__name__).info("This is run every time")
getLogger(__name__).debug("This is run in debug mode only")

return {"data": np.ones(length, dtype=np.float_), "extra": extra}
return {
"data": np.ones(length, dtype=np.float_),
"extra": extra,
"main": None,
} # type: ignore


def do_nothing(*args, **kwargs) -> ExpResults:
    """Accept any arguments, ignore them, and return ``{"main": None}``."""
    results: ExpResults = {"main": None}
    return results
6 changes: 4 additions & 2 deletions mitosis/tests/mock_part2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
from numpy.typing import NBitBase
from numpy.typing import NDArray

from mitosis._typing import ExpResults


def fit_and_score(
data: NDArray[np.floating[NBitBase]], metric: Literal["len"] | Literal["zero"]
) -> dict[str, float]:
) -> ExpResults:
if metric == "len":
return {"main": len(data)}
elif metric == "zero":
return {"main": 0}


def bad_runnable(*args: Any, **kwargs: Any):
def bad_runnable(*args: Any, **kwargs: Any) -> int:
return 1 # not a dict with key "main"
43 changes: 32 additions & 11 deletions mitosis/tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,33 @@ def test_mock_experiment(mock_steps, tmp_path):
assert (metadata / "experiment").resolve().exists()


@pytest.fixture
def nothing_step():
    """Provide an ``ExpStep`` named "nothing" built from the mock modules.

    The step pairs ``mock_part1.do_nothing`` with ``mock_paper.lookup_default``
    and starts with no group and an empty argument list.
    """
    action_ref = "mitosis.tests.mock_part1:do_nothing"
    lookup_ref = "mitosis.tests.mock_paper:lookup_default"
    return ExpStep(
        "nothing",
        mock_part1.do_nothing,
        action_ref,
        mock_paper.lookup_default,
        lookup_ref,
        None,
        [],
        [],
    )


@pytest.mark.clean
def test_variant_redefinition_disallowed(nothing_step, tmp_path):
    """Re-running with the same variant name but a new value must raise.

    The first run uses variant name "foo_a" with value "a"; the second run
    reuses the name with value "b" and is expected to fail with RuntimeError.
    """
    # GH 56
    first_definition = Parameter("foo_a", "foo", "a", evaluate=False)
    conflicting_definition = Parameter("foo_a", "foo", "b", evaluate=False)

    nothing_step.args.append(first_definition)
    mitosis.run([nothing_step], trials_folder=tmp_path)

    # Swap in the conflicting definition under the same variant name.
    del nothing_step.args[0]
    nothing_step.args.append(conflicting_definition)
    with pytest.raises(RuntimeError, match="stored with different values"):
        mitosis.run([nothing_step], trials_folder=tmp_path)


def test_load_results_order(tmp_path):
exp_key = "test_results"
(tmp_path / exp_key).mkdir()
Expand Down Expand Up @@ -224,17 +251,11 @@ def test_malfored_return_experiment(mock_steps, tmp_path):
def test_load_toml():
parent = Path(__file__).resolve().parent
tomlfile = parent / "test_pyproject.toml"
result = _disk.load_mitosis_steps(tomlfile)
expected = {
"data": (
"mitosis.tests.mock_part1:Klass.gen_data",
"mitosis.tests.mock_paper:data_config",
),
"fit_eval": (
"mitosis.tests.mock_part2:fit_and_score",
"mitosis.tests.mock_paper:meth_config",
),
}
result = _disk.load_mitosis_steps(tomlfile)["nothing"]
expected = (
"mitosis.tests.mock_part1:do_nothing",
"mitosis.tests.mock_paper:lookup_default",
)
assert result == expected


Expand Down
3 changes: 3 additions & 0 deletions mitosis/tests/test_pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
[tool.mitosis.steps]
nothing = [
"mitosis.tests.mock_part1:do_nothing",
"mitosis.tests.mock_paper:lookup_default"]
data = [
"mitosis.tests.mock_part1:Klass.gen_data",
"mitosis.tests.mock_paper:data_config"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ dependencies = [
"nbclient",
"nbformat",
"pandas<2.2",
"sqlalchemy",
"sqlalchemy>=1.4",
"toml",
"types-toml",
]
Expand Down

0 comments on commit b1e6eed

Please sign in to comment.