Skip to content

Commit

Permalink
typ: remove Experiment type
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Oct 24, 2024
1 parent ba6152a commit c2e5d84
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 26 deletions.
6 changes: 3 additions & 3 deletions mitosis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,16 +220,16 @@ def _verify_variant_name(trial_db: Path, step: str, param: Parameter) -> None:
eng = create_engine("sqlite:///" + str(trial_db))
md = MetaData()
tb = Table(f"{step}_variant_{param.arg_name}", md, *variant_types())
vals: Collection[Any]
vals: Collection
if isinstance(param.vals, Mapping):
vals = StrictlyReproduceableDict({k: v for k, v in sorted(param.vals.items())})
elif isinstance(param.vals, Collection) and not isinstance(param.vals, str):
try:
vals = StrictlyReproduceableList(sorted(param.vals))
except (ValueError, TypeError):
vals = param.vals
vals = str(param.vals)
else:
vals = param.vals
vals = str(param.vals)
df = pd.read_sql(select(tb), eng)
ind_equal = df.loc[:, "name"] == param.var_name
if ind_equal.sum() == 0:
Expand Down
21 changes: 8 additions & 13 deletions mitosis/_typing.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
from abc import ABCMeta
from collections.abc import Mapping
from dataclasses import dataclass
from dataclasses import field
from types import ModuleType
from typing import Any
from typing import Callable
from typing import NamedTuple
from typing import ParamSpec
from typing import TypedDict


P = ParamSpec("P")
ExpRun = Callable[P, dict]
class ExpResults(TypedDict):
main: object


class Experiment(ModuleType, metaclass=ABCMeta):
__name__: str
__file__: str
name: str
lookup_dict: dict[str, dict[str, Any]]
run: ExpRun
P = ParamSpec("P")
ExpRun = Callable[P, ExpResults]


@dataclass
Expand All @@ -33,7 +28,7 @@ class Parameter:

var_name: str
arg_name: str
vals: Any
vals: object
# > 3.10 only: https://stackoverflow.com/a/49911616/534674
evaluate: bool = field(default=False, kw_only=True)

Expand All @@ -42,7 +37,7 @@ class ExpStep(NamedTuple):
name: str
action: ExpRun
action_ref: str
lookup: dict[str, Any]
lookup: Mapping[str, Mapping[str, object]]
lookup_ref: str
group: str | None
args: list[Parameter]
Expand Down
4 changes: 3 additions & 1 deletion mitosis/tests/mock_paper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@
meth_config = {"metric": {"test": "len"}}

# lookup any parameter, any variant: always none
lookup_default = defaultdict(lambda: defaultdict(lambda: None))
lookup_default: dict[str, dict[str, None]] = defaultdict(
lambda: defaultdict(lambda: None)
)
16 changes: 9 additions & 7 deletions mitosis/tests/mock_part1.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from logging import getLogger

import numpy as np
from numpy.typing import NBitBase
from numpy.typing import NDArray

from mitosis._typing import ExpResults


class Klass:
@staticmethod
def gen_data(
length: int, extra: bool = False
) -> dict[str, NDArray[np.floating[NBitBase]] | bool]:
def gen_data(length: int, extra: bool = False) -> ExpResults:
getLogger(__name__).info("This is run every time")
getLogger(__name__).debug("This is run in debug mode only")

return {"data": np.ones(length, dtype=np.float_), "extra": extra}
return {
"data": np.ones(length, dtype=np.float_),
"extra": extra,
"main": None,
} # type: ignore


def do_nothing(*args, **kwargs) -> dict[str, None]:
def do_nothing(*args, **kwargs) -> ExpResults:
"""An experiment step that accepts anything and produces nothing"""
return {"main": None}
6 changes: 4 additions & 2 deletions mitosis/tests/mock_part2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
from numpy.typing import NBitBase
from numpy.typing import NDArray

from mitosis._typing import ExpResults


def fit_and_score(
data: NDArray[np.floating[NBitBase]], metric: Literal["len"] | Literal["zero"]
) -> dict[str, float]:
) -> ExpResults:
if metric == "len":
return {"main": len(data)}
elif metric == "zero":
return {"main": 0}


def bad_runnable(*args: Any, **kwargs: Any):
def bad_runnable(*args: Any, **kwargs: Any) -> int:
return 1 # not a dict with key "main"

0 comments on commit c2e5d84

Please sign in to comment.