Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: unify some train/test code #448

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 39 additions & 95 deletions maud/data_model/maud_parameter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Provides model MaudParameter, and subclasses for all parameters Maud uses."""

from copy import deepcopy
from typing import List, Optional, Union

from pydantic import BaseModel, computed_field, field_validator, model_validator
Expand Down Expand Up @@ -78,6 +79,29 @@ def split_ids_exist_if_needed(self):
return self


class TrainTestParameter(MaudParameter):
"""Mark parameter to have different priors between train and test.

The class must be filled with `shape_names` and `name` for train.
The test version is created by calling the `test` method.
"""

prior_in_test_model: bool = False
prior_in_train_model: bool = True

def test(self):
"""Generate the test counterpart."""
test_self = deepcopy(self)
test_self.shape_names = [
shape_name.replace("train", "test")
for shape_name in self.shape_names
]
test_self.name = self.name.replace("train", "test")
test_self.prior_in_test_model = True
test_self.prior_in_train_model = False
return test_self


class Km(MaudParameter):
"""Parameter representing a model's Michaelis constants."""

Expand Down Expand Up @@ -238,9 +262,11 @@ class KcatPme(MaudParameter):
prior_in_train_model: bool = True


class Drain(MaudParameter):
class Drain(TrainTestParameter):
"""Stan variable type for drain parameters."""

name: str = "drain_train"
shape_names: List[str] = ["N_experiment_train", "N_drain"]
id_components: List[List[IdComponent]] = [
[IdComponent.EXPERIMENT],
[IdComponent.REACTION],
Expand All @@ -250,27 +276,11 @@ class Drain(MaudParameter):
default_scale: float = 1


class DrainTrain(Drain):
"""Stan variable for drain parameters of training experiments."""

name: str = "drain_train"
shape_names: List[str] = ["N_experiment_train", "N_drain"]
prior_in_test_model: bool = False
prior_in_train_model: bool = True


class DrainTest(Drain):
"""Stan variable for drain parameters of test experiments."""

name: str = "drain_test"
shape_names: List[str] = ["N_experiment_test", "N_drain"]
prior_in_test_model: bool = True
prior_in_train_model: bool = False


class ConcEnzyme(MaudParameter):
class ConcEnzyme(TrainTestParameter):
"""Parent class for enzyme concentration parameters."""

name: str = "conc_enzyme_train"
shape_names: List[str] = ["N_experiment_train", "N_enzyme"]
id_components: List[List[IdComponent]] = [
[IdComponent.EXPERIMENT],
[IdComponent.ENZYME],
Expand All @@ -282,27 +292,11 @@ class ConcEnzyme(MaudParameter):
default_scale: float = 1


class ConcEnzymeTrain(ConcEnzyme):
"""Enzyme concentration parameters in training experiments."""

name: str = "conc_enzyme_train"
shape_names: List[str] = ["N_experiment_train", "N_enzyme"]
prior_in_test_model: bool = False
prior_in_train_model: bool = True


class ConcEnzymeTest(ConcEnzyme):
"""Enzyme concentration parameters in test experiments."""

name: str = "conc_enzyme_test"
shape_names: List[str] = ["N_experiment_test", "N_enzyme"]
prior_in_test_model: bool = True
prior_in_train_model: bool = False


class ConcUnbalanced(MaudParameter):
class ConcUnbalanced(TrainTestParameter):
"""Parent class for unbalanced mic concentration parameters."""

name: str = "conc_unbalanced_train"
shape_names: List[str] = ["N_experiment_train", "N_unbalanced"]
id_components: List[List[IdComponent]] = [
[IdComponent.EXPERIMENT],
[IdComponent.METABOLITE, IdComponent.COMPARTMENT],
Expand All @@ -312,27 +306,11 @@ class ConcUnbalanced(MaudParameter):
default_scale: float = 2.0


class ConcUnbalancedTrain(ConcUnbalanced):
"""Unbalanced mic concentration parameters in training experiments."""

name: str = "conc_unbalanced_train"
shape_names: List[str] = ["N_experiment_train", "N_unbalanced"]
prior_in_test_model: bool = False
prior_in_train_model: bool = True


class ConcUnbalancedTest(ConcUnbalanced):
"""Unbalanced mic concentration parameters in test experiments."""

name: str = "conc_unbalanced_test"
shape_names: List[str] = ["N_experiment_test", "N_enzyme"]
prior_in_test_model: bool = True
prior_in_train_model: bool = False


class ConcPme(MaudParameter):
class ConcPme(TrainTestParameter):
"""Parent class for pme concentration parameters."""

name: str = "conc_pme_train"
shape_names: List[str] = ["N_experiment_train", "N_pme"]
id_components: List[List[IdComponent]] = [
[IdComponent.EXPERIMENT],
[IdComponent.PHOSPHORYLATION_MODIFYING_ENZYME],
Expand All @@ -342,46 +320,12 @@ class ConcPme(MaudParameter):
default_scale: float = 2.0


class ConcPmeTrain(ConcPme):
"""Pme concentration parameters in training experiments."""

name: str = "conc_pme_train"
shape_names: List[str] = ["N_experiment_train", "N_pme"]
prior_in_test_model: bool = False
prior_in_train_model: bool = True


class ConcPmeTest(ConcPme):
"""Pme concentration parameters in test experiments."""

name: str = "conc_pme_test"
shape_names: List[str] = ["N_experiment_test", "N_pme"]
prior_in_test_model: bool = True
prior_in_train_model: bool = False


class Psi(MaudParameter):
class Psi(TrainTestParameter):
"""Stan variable representing per-experiment membrane potentials."""

name: str = "psi_train"
shape_names: List[str] = ["N_experiment_train"]
id_components: List[List[IdComponent]] = [[IdComponent.EXPERIMENT]]
non_negative: bool = False
default_loc: float = 0
default_scale: float = 2


class PsiTrain(Psi):
"""Pme concentration parameters in training experiments."""

name: str = "psi_train"
shape_names: List[str] = ["N_experiment_train"]
prior_in_test_model: bool = False
prior_in_train_model: bool = True


class PsiTest(Psi):
"""Pme concentration parameters in test experiments."""

name: str = "psi_test"
shape_names: List[str] = ["N_experiment_test"]
prior_in_test_model: bool = True
prior_in_train_model: bool = False
Loading
Loading