From 231a875382da112ff8265add27f7470e3ffa2175 Mon Sep 17 00:00:00 2001 From: carrascomj Date: Fri, 13 Oct 2023 13:00:03 +0200 Subject: [PATCH 1/2] refactor: unify some train/test code --- maud/data_model/maud_parameter.py | 134 +++++++-------------- maud/data_model/parameter_set.py | 186 ++++++++++++------------------ 2 files changed, 115 insertions(+), 205 deletions(-) diff --git a/maud/data_model/maud_parameter.py b/maud/data_model/maud_parameter.py index 53fd6050..a1e10b61 100644 --- a/maud/data_model/maud_parameter.py +++ b/maud/data_model/maud_parameter.py @@ -1,5 +1,6 @@ """Provides model MaudParameter, and subclasses for all parameters Maud uses.""" +from copy import deepcopy from typing import List, Optional, Union from pydantic import BaseModel, computed_field, field_validator, model_validator @@ -78,6 +79,29 @@ def split_ids_exist_if_needed(self): return self +class TrainTestParameter(MaudParameter): + """Mark parameter to have different priors between train and test. + + The class must be filled with `shape_names` and `name` for train. + The test version is created by calling the `test` method. + """ + + prior_in_test_model: bool = False + prior_in_train_model: bool = True + + def test(self): + """Generate the test counterpart.""" + test_self = deepcopy(self) + test_self.shape_names = [ + shape_name.replace("train", "test") + for shape_name in self.shape_names + ] + test_self.name = self.name.replace("train", "test") + test_self.prior_in_test_model = True + test_self.prior_in_train_model = False + return test_self + + class Km(MaudParameter): """Parameter representing a model's Michaelis constants.""" @@ -238,9 +262,11 @@ class KcatPme(MaudParameter): prior_in_train_model: bool = True -class Drain(MaudParameter): +class Drain(TrainTestParameter): """Stan variable type for drain parameters.""" + name: str = "drain_train" + shape_names: List[str] = ["N_experiment_train", "N_drain"] id_components: List[List[IdComponent]] = [ [IdComponent.EXPERIMENT], [IdComponent.REACTION], @@ -250,27 +276,11 @@ class Drain(MaudParameter): default_scale: float = 1 -class DrainTrain(Drain): - """Stan variable for drain parameters of training experiments.""" - - name: str = "drain_train" - shape_names: List[str] = ["N_experiment_train", "N_drain"] - prior_in_test_model: bool = False - prior_in_train_model: bool = True - - -class DrainTest(Drain): - """Stan variable for drain parameters of test experiments.""" - - name: str = "drain_test" - shape_names: List[str] = ["N_experiment_test", "N_drain"] - prior_in_test_model: bool = True - prior_in_train_model: bool = False - - -class ConcEnzyme(MaudParameter): +class ConcEnzyme(TrainTestParameter): """Parent class for enzyme concentration parameters.""" + name: str = "conc_enzyme_train" + shape_names: List[str] = ["N_experiment_train", "N_enzyme"] id_components: List[List[IdComponent]] = [ [IdComponent.EXPERIMENT], [IdComponent.ENZYME], @@ -282,27 +292,11 @@ class ConcEnzyme(MaudParameter): default_scale: float = 1 -class ConcEnzymeTrain(ConcEnzyme): - """Enzyme concentration parameters in training experiments.""" - - name: str = "conc_enzyme_train" - shape_names: List[str] = ["N_experiment_train", "N_enzyme"] - prior_in_test_model: bool = False - prior_in_train_model: bool = True - - -class ConcEnzymeTest(ConcEnzyme): - """Enzyme concentration parameters in test experiments.""" - - name: str = "conc_enzyme_test" - shape_names: List[str] = ["N_experiment_test", "N_enzyme"] - prior_in_test_model: bool = True - prior_in_train_model: bool = False - - -class ConcUnbalanced(MaudParameter): +class ConcUnbalanced(TrainTestParameter): """Parent class for unbalanced mic concentration parameters.""" + name: str = "conc_unbalanced_train" + shape_names: List[str] = ["N_experiment_train", "N_unbalanced"] id_components: List[List[IdComponent]] = [ [IdComponent.EXPERIMENT], [IdComponent.METABOLITE, IdComponent.COMPARTMENT], @@ -312,27 +306,11 @@ class ConcUnbalanced(MaudParameter): default_scale: float = 2.0 -class ConcUnbalancedTrain(ConcUnbalanced): - """Unbalanced mic concentration parameters in training experiments.""" - - name: str = "conc_unbalanced_train" - shape_names: List[str] = ["N_experiment_train", "N_unbalanced"] - prior_in_test_model: bool = False - prior_in_train_model: bool = True - - -class ConcUnbalancedTest(ConcUnbalanced): - """Unbalanced mic concentration parameters in test experiments.""" - - name: str = "conc_unbalanced_test" - shape_names: List[str] = ["N_experiment_test", "N_enzyme"] - prior_in_test_model: bool = True - prior_in_train_model: bool = False - - -class ConcPme(MaudParameter): +class ConcPme(TrainTestParameter): """Parent class for pme concentration parameters.""" + name: str = "conc_pme_train" + shape_names: List[str] = ["N_experiment_train", "N_pme"] id_components: List[List[IdComponent]] = [ [IdComponent.EXPERIMENT], [IdComponent.PHOSPHORYLATION_MODIFYING_ENZYME], @@ -342,46 +320,12 @@ class ConcPme(MaudParameter): default_scale: float = 2.0 -class ConcPmeTrain(ConcPme): - """Pme concentration parameters in training experiments.""" - - name: str = "conc_pme_train" - shape_names: List[str] = ["N_experiment_train", "N_pme"] - prior_in_test_model: bool = False - prior_in_train_model: bool = True - - -class ConcPmeTest(ConcPme): - """Pme concentration parameters in test experiments.""" - - name: str = "conc_pme_test" - shape_names: List[str] = ["N_experiment_test", "N_pme"] - prior_in_test_model: bool = True - prior_in_train_model: bool = False - - -class Psi(MaudParameter): +class Psi(TrainTestParameter): """Stan variable representing per-experiment membrane potentials.""" + name: str = "psi_train" + shape_names: List[str] = ["N_experiment_train"] id_components: List[List[IdComponent]] = [[IdComponent.EXPERIMENT]] non_negative: bool = False default_loc: float = 0 default_scale: float = 2 - - -class PsiTrain(Psi): - """Pme concentration parameters in training experiments.""" - - name: str = "psi_train" - shape_names: List[str] = ["N_experiment_train"] - prior_in_test_model: bool = False - prior_in_train_model: bool = True - - -class PsiTest(Psi): - """Pme concentration parameters in test experiments.""" - - name: str = "psi_test" - shape_names: List[str] = ["N_experiment_test"] - prior_in_test_model: bool = True - prior_in_train_model: bool = False diff --git a/maud/data_model/parameter_set.py b/maud/data_model/parameter_set.py index e94260bb..be84cea2 100644 --- a/maud/data_model/parameter_set.py +++ b/maud/data_model/parameter_set.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, computed_field import maud.data_model.maud_parameter as mp -from maud.data_model.experiment import Experiment, MeasurementType +from maud.data_model.experiment import Experiment, Measurement, MeasurementType from maud.data_model.hardcoding import ID_SEPARATOR from maud.data_model.kinetic_model import KineticModel, ReactionMechanism from maud.data_model.maud_init import InitInput @@ -161,117 +161,75 @@ def kcat_pme(self) -> mp.KcatPme: init_input=self.init_input.kcat_pme, ) - @computed_field - def drain_train(self) -> mp.DrainTrain: + def _get_experiments(self, train: bool) -> list[str]: + return [ + e.id + for e in self.experiments + if (e.is_train if train else e.is_test) + ] + + def _get_drain(self, train: bool) -> mp.Drain: """Add the drain_train field.""" drain_ids = [ d.id for d in self.kinetic_model.reactions if d.mechanism == ReactionMechanism.drain ] - exp_ids = [e.id for e in self.experiments if e.is_train] - return mp.DrainTrain( + exp_ids = self._get_experiments(train) + result = mp.Drain( ids=[exp_ids, drain_ids], split_ids=[[exp_ids], [drain_ids]], user_input=self.parameter_set_input.drain, init_input=self.init_input.drain, ) + return result if train else result.test() @computed_field - def drain_test(self) -> mp.DrainTest: - """Add the drain_test field.""" - drain_ids = [ - d.id - for d in self.kinetic_model.reactions - if d.mechanism == ReactionMechanism.drain - ] - exp_ids = [e.id for e in self.experiments if e.is_test] - return mp.DrainTest( - ids=[exp_ids, drain_ids], - split_ids=[[exp_ids], [drain_ids]], - user_input=self.parameter_set_input.drain, - init_input=self.init_input.drain, - ) + def drain_train(self) -> mp.Drain: + """Add the drain_train field.""" + return self._get_drain(train=True) @computed_field - def conc_enzyme_train(self) -> mp.ConcEnzymeTrain: - """Add the conc_enzyme_train field.""" - enzyme_ids = [e.id for e in self.kinetic_model.enzymes] - exp_ids = [e.id for e in self.experiments if e.is_train] - measurements = [ + def drain_test(self) -> mp.Drain: + """Add the drain_test field.""" + return self._get_drain(train=False) + + def _get_measurements( + self, train: bool, mtype: MeasurementType + ) -> list[Measurement]: + return [ m for e in self.experiments for m in e.measurements - if e.is_train and m.target_type == MeasurementType.ENZYME + if (e.is_train if train else e.is_test) and m.target_type == mtype ] - return mp.ConcEnzymeTrain( - ids=[exp_ids, enzyme_ids], - split_ids=[[exp_ids], [enzyme_ids]], - user_input=self.parameter_set_input.conc_enzyme, - init_input=self.init_input.conc_enzyme, - measurements=measurements, - ) - @computed_field - def conc_enzyme_test(self) -> mp.ConcEnzymeTest: - """Add the conc_enzyme_test field.""" + def _get_conc_enzyme(self, train: bool) -> mp.ConcEnzyme: enzyme_ids = [e.id for e in self.kinetic_model.enzymes] - exp_ids = [e.id for e in self.experiments if e.is_test] - measurements = [ - m - for e in self.experiments - for m in e.measurements - if e.is_test and m.target_type == MeasurementType.ENZYME - ] - return mp.ConcEnzymeTest( + exp_ids = self._get_experiments(train) + measurements = self._get_measurements(train, MeasurementType.ENZYME) + result = mp.ConcEnzyme( ids=[exp_ids, enzyme_ids], split_ids=[[exp_ids], [enzyme_ids]], user_input=self.parameter_set_input.conc_enzyme, init_input=self.init_input.conc_enzyme, measurements=measurements, ) + return result if train else result.test() @computed_field - def conc_unbalanced_train(self) -> mp.ConcUnbalancedTrain: - """Add the conc_unbalanced_train field.""" - exp_ids = [e.id for e in self.experiments if e.is_train] - measurements = [ - m - for e in self.experiments - for m in e.measurements - if e.is_train and m.target_type == MeasurementType.MIC - ] - unbalanced_mic_ids, unbalanced_mic_mets, unbalanced_mic_cpts = map( - list, - zip( - *[ - [m.id, m.metabolite_id, m.compartment_id] - for m in self.kinetic_model.mics - if not m.balanced - ] - ), - ) - return mp.ConcUnbalancedTrain( - ids=[exp_ids, unbalanced_mic_ids], - split_ids=[ - [exp_ids], - [unbalanced_mic_mets, unbalanced_mic_cpts], - ], - user_input=self.parameter_set_input.conc_unbalanced, - init_input=self.init_input.conc_unbalanced, - measurements=measurements, - ) + def conc_enzyme_train(self) -> mp.ConcEnzyme: + """Add the conc_enzyme_train field.""" + return self._get_conc_enzyme(train=True) @computed_field - def conc_unbalanced_test(self) -> mp.ConcUnbalancedTest: - """Add the conc_unbalanced_test field.""" - exp_ids = [e.id for e in self.experiments if e.is_test] - measurements = [ - m - for e in self.experiments - for m in e.measurements - if e.is_test and m.target_type == MeasurementType.MIC - ] + def conc_enzyme_test(self) -> mp.ConcEnzyme: + """Add the conc_enzyme_test field.""" + return self._get_conc_enzyme(train=False) + + def _get_conc_unbalanced(self, train: bool) -> mp.ConcUnbalanced: + exp_ids = self._get_experiments(train) + measurements = self._get_measurements(train, MeasurementType.MIC) unbalanced_mic_ids, unbalanced_mic_mets, unbalanced_mic_cpts = map( list, zip( @@ -282,7 +240,7 @@ def conc_unbalanced_test(self) -> mp.ConcUnbalancedTest: ] ), ) - return mp.ConcUnbalancedTest( + result = mp.ConcUnbalanced( ids=[exp_ids, unbalanced_mic_ids], split_ids=[ [exp_ids], @@ -292,57 +250,65 @@ def conc_unbalanced_test(self) -> mp.ConcUnbalancedTest: init_input=self.init_input.conc_unbalanced, measurements=measurements, ) + return result if train else result.test() @computed_field - def conc_pme_train(self) -> mp.ConcPmeTrain: + def conc_unbalanced_train(self) -> mp.ConcUnbalanced: + """Add the conc_unbalanced_train field.""" + return self._get_conc_unbalanced(train=True) + + @computed_field + def conc_unbalanced_test(self) -> mp.ConcUnbalanced: + """Add the conc_unbalanced_test field.""" + return self._get_conc_unbalanced(train=False) + + def _get_conc_pme(self, train: bool) -> mp.ConcPme: """Add the conc_pme_train field.""" - exp_ids = [e.id for e in self.experiments if e.is_train] + exp_ids = self._get_experiments(train) pme_ids = ( [p.modifying_enzyme_id for p in self.kinetic_model.phosphorylations] if self.kinetic_model.phosphorylations is not None else [] ) - return mp.ConcPmeTrain( + result = mp.ConcPme( ids=[exp_ids, pme_ids], split_ids=[[exp_ids], [pme_ids]], user_input=self.parameter_set_input.conc_pme, init_input=self.init_input.conc_pme, ) + return result if train else result.test() @computed_field - def conc_pme_test(self) -> mp.ConcPmeTest: - """Add the conc_pme_test field.""" - exp_ids = [e.id for e in self.experiments if e.is_test] - pme_ids = ( - [p.modifying_enzyme_id for p in self.kinetic_model.phosphorylations] - if self.kinetic_model.phosphorylations is not None - else [] - ) - return mp.ConcPmeTest( - ids=[exp_ids, pme_ids], - split_ids=[[exp_ids], [pme_ids]], - user_input=self.parameter_set_input.conc_pme, - init_input=self.init_input.conc_pme, - ) + def conc_pme_train(self) -> mp.ConcPme: + """Add the conc_pme_train field.""" + return self._get_conc_pme(True) @computed_field - def psi_train(self) -> mp.PsiTrain: + def conc_pme_test(self) -> mp.ConcPme: + """Add the conc_pme_test field.""" + return self._get_conc_pme(False) + + def _get_psi(self, train: bool) -> mp.Psi: """Add the psi_train field.""" - exp_ids = [e.id for e in self.experiments if e.is_train] - return mp.PsiTrain( + exp_ids = [ + e.id + for e in self.experiments + if (e.is_train if train else e.is_test) + ] + result = mp.Psi( ids=[exp_ids], split_ids=[[exp_ids]], user_input=self.parameter_set_input.psi, init_input=self.init_input.psi, ) + return result if train else result.test() + + @computed_field + def psi_train(self) -> mp.Psi: + """Add the psi_train field.""" + return self._get_psi(train=True) @computed_field - def psi_test(self) -> mp.PsiTest: + def psi_test(self) -> mp.Psi: """Add the psi_test field.""" - exp_ids = [e.id for e in self.experiments if e.is_test] - return mp.PsiTest( - ids=[exp_ids], - split_ids=[[exp_ids]], - user_input=self.parameter_set_input.psi, - init_input=self.init_input.psi, - ) + return self._get_psi(train=True).test() From f98aa771f17794515d6969ac498c821ded1cee36 Mon Sep 17 00:00:00 2001 From: carrascomj Date: Fri, 13 Oct 2023 13:19:38 +0200 Subject: [PATCH 2/2] fix: remove typos of test/train param set --- maud/data_model/parameter_set.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/maud/data_model/parameter_set.py b/maud/data_model/parameter_set.py index be84cea2..cc207c75 100644 --- a/maud/data_model/parameter_set.py +++ b/maud/data_model/parameter_set.py @@ -169,7 +169,6 @@ def _get_experiments(self, train: bool) -> list[str]: ] def _get_drain(self, train: bool) -> mp.Drain: - """Add the drain_train field.""" drain_ids = [ d.id for d in self.kinetic_model.reactions @@ -281,12 +280,12 @@ def _get_conc_pme(self, train: bool) -> mp.ConcPme: @computed_field def conc_pme_train(self) -> mp.ConcPme: """Add the conc_pme_train field.""" - return self._get_conc_pme(True) + return self._get_conc_pme(train=True) @computed_field def conc_pme_test(self) -> mp.ConcPme: """Add the conc_pme_test field.""" - return self._get_conc_pme(False) + return self._get_conc_pme(train=False) def _get_psi(self, train: bool) -> mp.Psi: """Add the psi_train field.""" @@ -311,4 +310,4 @@ def psi_train(self) -> mp.Psi: @computed_field def psi_test(self) -> mp.Psi: """Add the psi_test field.""" - return self._get_psi(train=True).test() + return self._get_psi(train=False)