Skip to content

Commit

Permalink
refactor: unify some train/test code
Browse files Browse the repository at this point in the history
  • Loading branch information
carrascomj committed Oct 13, 2023
1 parent 44cd50a commit 231a875
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 205 deletions.
134 changes: 39 additions & 95 deletions maud/data_model/maud_parameter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Provides model MaudParameter, and subclasses for all parameters Maud uses."""

from copy import deepcopy
from typing import List, Optional, Union

from pydantic import BaseModel, computed_field, field_validator, model_validator
Expand Down Expand Up @@ -78,6 +79,29 @@ def split_ids_exist_if_needed(self):
return self


class TrainTestParameter(MaudParameter):
"""Mark parameter to have different priors between train and test.
The class must be filled with `shape_names` and `name` for train.
The test version is created by calling the `test` method.
"""

prior_in_test_model: bool = False
prior_in_train_model: bool = True

def test(self):
"""Generate the test counterpart."""
test_self = deepcopy(self)
test_self.shape_names = [
shape_name.replace("train", "test")
for shape_name in self.shape_names
]
test_self.name = self.name.replace("train", "test")
test_self.prior_in_test_model = True
test_self.prior_in_train_model = False
return test_self


class Km(MaudParameter):
"""Parameter representing a model's Michaelis constants."""

Expand Down Expand Up @@ -238,9 +262,11 @@ class KcatPme(MaudParameter):
prior_in_train_model: bool = True


class Drain(MaudParameter):
class Drain(TrainTestParameter):
"""Stan variable type for drain parameters."""

name: str = "drain_train"
shape_names: List[str] = ["N_experiment_train", "N_drain"]
id_components: List[List[IdComponent]] = [
[IdComponent.EXPERIMENT],
[IdComponent.REACTION],
Expand All @@ -250,27 +276,11 @@ class Drain(MaudParameter):
default_scale: float = 1


class DrainTrain(Drain):
"""Stan variable for drain parameters of training experiments."""

name: str = "drain_train"
shape_names: List[str] = ["N_experiment_train", "N_drain"]
prior_in_test_model: bool = False
prior_in_train_model: bool = True


class DrainTest(Drain):
"""Stan variable for drain parameters of test experiments."""

name: str = "drain_test"
shape_names: List[str] = ["N_experiment_test", "N_drain"]
prior_in_test_model: bool = True
prior_in_train_model: bool = False


class ConcEnzyme(MaudParameter):
class ConcEnzyme(TrainTestParameter):
"""Parent class for enzyme concentration parameters."""

name: str = "conc_enzyme_train"
shape_names: List[str] = ["N_experiment_train", "N_enzyme"]
id_components: List[List[IdComponent]] = [
[IdComponent.EXPERIMENT],
[IdComponent.ENZYME],
Expand All @@ -282,27 +292,11 @@ class ConcEnzyme(MaudParameter):
default_scale: float = 1


class ConcEnzymeTrain(ConcEnzyme):
"""Enzyme concentration parameters in training experiments."""

name: str = "conc_enzyme_train"
shape_names: List[str] = ["N_experiment_train", "N_enzyme"]
prior_in_test_model: bool = False
prior_in_train_model: bool = True


class ConcEnzymeTest(ConcEnzyme):
"""Enzyme concentration parameters in test experiments."""

name: str = "conc_enzyme_test"
shape_names: List[str] = ["N_experiment_test", "N_enzyme"]
prior_in_test_model: bool = True
prior_in_train_model: bool = False


class ConcUnbalanced(MaudParameter):
class ConcUnbalanced(TrainTestParameter):
"""Parent class for unbalanced mic concentration parameters."""

name: str = "conc_unbalanced_train"
shape_names: List[str] = ["N_experiment_train", "N_unbalanced"]
id_components: List[List[IdComponent]] = [
[IdComponent.EXPERIMENT],
[IdComponent.METABOLITE, IdComponent.COMPARTMENT],
Expand All @@ -312,27 +306,11 @@ class ConcUnbalanced(MaudParameter):
default_scale: float = 2.0


class ConcUnbalancedTrain(ConcUnbalanced):
"""Unbalanced mic concentration parameters in training experiments."""

name: str = "conc_unbalanced_train"
shape_names: List[str] = ["N_experiment_train", "N_unbalanced"]
prior_in_test_model: bool = False
prior_in_train_model: bool = True


class ConcUnbalancedTest(ConcUnbalanced):
"""Unbalanced mic concentration parameters in test experiments."""

name: str = "conc_unbalanced_test"
shape_names: List[str] = ["N_experiment_test", "N_enzyme"]
prior_in_test_model: bool = True
prior_in_train_model: bool = False


class ConcPme(MaudParameter):
class ConcPme(TrainTestParameter):
"""Parent class for pme concentration parameters."""

name: str = "conc_pme_train"
shape_names: List[str] = ["N_experiment_train", "N_pme"]
id_components: List[List[IdComponent]] = [
[IdComponent.EXPERIMENT],
[IdComponent.PHOSPHORYLATION_MODIFYING_ENZYME],
Expand All @@ -342,46 +320,12 @@ class ConcPme(MaudParameter):
default_scale: float = 2.0


class ConcPmeTrain(ConcPme):
"""Pme concentration parameters in training experiments."""

name: str = "conc_pme_train"
shape_names: List[str] = ["N_experiment_train", "N_pme"]
prior_in_test_model: bool = False
prior_in_train_model: bool = True


class ConcPmeTest(ConcPme):
"""Pme concentration parameters in test experiments."""

name: str = "conc_pme_test"
shape_names: List[str] = ["N_experiment_test", "N_pme"]
prior_in_test_model: bool = True
prior_in_train_model: bool = False


class Psi(MaudParameter):
class Psi(TrainTestParameter):
"""Stan variable representing per-experiment membrane potentials."""

name: str = "psi_train"
shape_names: List[str] = ["N_experiment_train"]
id_components: List[List[IdComponent]] = [[IdComponent.EXPERIMENT]]
non_negative: bool = False
default_loc: float = 0
default_scale: float = 2


class PsiTrain(Psi):
"""Pme concentration parameters in training experiments."""

name: str = "psi_train"
shape_names: List[str] = ["N_experiment_train"]
prior_in_test_model: bool = False
prior_in_train_model: bool = True


class PsiTest(Psi):
"""Pme concentration parameters in test experiments."""

name: str = "psi_test"
shape_names: List[str] = ["N_experiment_test"]
prior_in_test_model: bool = True
prior_in_train_model: bool = False
Loading

0 comments on commit 231a875

Please sign in to comment.