diff --git a/maud/data_model/experiment.py b/maud/data_model/experiment.py index 9fea641e..843ba641 100644 --- a/maud/data_model/experiment.py +++ b/maud/data_model/experiment.py @@ -1,20 +1,13 @@ -"""Provides dataclass Experiment.""" +"""Provides model Experiment.""" from enum import Enum from typing import List, Optional -from pydantic import Field, validator -from pydantic.dataclasses import dataclass +from pydantic import BaseModel, Field, computed_field, field_validator from maud.data_model.hardcoding import ID_SEPARATOR -class MSConfig: - """Config for MeasurementSet, allowing it to contain pandas objects.""" - - arbitrary_types_allowed = True - - class MeasurementType(str, Enum): """Possible types of measurement.""" @@ -23,8 +16,7 @@ class MeasurementType(str, Enum): ENZYME = "enzyme" -@dataclass -class Measurement: +class Measurement(BaseModel): """Maud representation of a measurement.""" experiment: str @@ -35,62 +27,61 @@ class Measurement: compartment: Optional[str] = None reaction: Optional[str] = None enzyme: Optional[str] = None - target_id: str = Field(default=None, init=False, exclude=True) - def __post_init__(self): + @computed_field + def target_id(self) -> str: """Add target_id field.""" if self.target_type == MeasurementType.MIC: - self.target_id = ID_SEPARATOR.join( - [self.metabolite, self.compartment] - ) + assert self.metabolite is not None + assert self.compartment is not None + return ID_SEPARATOR.join([self.metabolite, self.compartment]) elif self.target_type == MeasurementType.FLUX: - self.target_id = self.reaction - elif self.target_type == MeasurementType.ENZYME: - self.target_id = self.enzyme + assert self.reaction is not None + return self.reaction + else: + assert self.enzyme is not None + return self.enzyme -@dataclass -class EnzymeKnockout: +class EnzymeKnockout(BaseModel): """Maud representation of an enzyme being knocked out in an experiment.""" experiment: str enzyme: str - id: str = Field(init=False, exclude=True) - def __post_init__(self): + @computed_field + def id(self) -> str: """Add id field.""" - self.id = ID_SEPARATOR.join(["eko", self.experiment, self.enzyme]) + return ID_SEPARATOR.join(["eko", self.experiment, self.enzyme]) -@dataclass -class PhosphorylationModifyingEnzymeKnockout: +class PhosphorylationModifyingEnzymeKnockout(BaseModel): """Maud representation of a pme being knocked out in an experiment.""" experiment: str + enzyme: str pme: str - id: str = Field(init=False, exclude=True) - def __post_init__(self): + @computed_field + def id(self) -> str: """Add id field.""" - self.id = ID_SEPARATOR.join(["pko", self.experiment, self.enzyme]) + return ID_SEPARATOR.join(["pko", self.experiment, self.enzyme]) -@dataclass -class InitConcentration: +class InitConcentration(BaseModel): """Indication of the initial value of a concentration in the ODE.""" metabolite: str compartment: str value: float - target_id: str = Field(default=None, init=False, exclude=True) - def __post_init__(self): + @computed_field + def target_id(self) -> str: """Add target_id field.""" - self.target_id = ID_SEPARATOR.join([self.metabolite, self.compartment]) + return ID_SEPARATOR.join([self.metabolite, self.compartment]) -@dataclass -class Experiment: +class Experiment(BaseModel): """Maud representation of an experiment. 
This means a case where the boundary conditions and all measured quantities @@ -109,7 +100,7 @@ class Experiment: default_factory=lambda: [] ) - @validator("temperature") + @field_validator("temperature") def temp_must_be_non_negative(cls, v): """Make sure the temperature isn't negative.""" assert v >= 0 diff --git a/maud/data_model/kinetic_model.py b/maud/data_model/kinetic_model.py index edba893d..f7de22f6 100644 --- a/maud/data_model/kinetic_model.py +++ b/maud/data_model/kinetic_model.py @@ -4,8 +4,13 @@ from typing import Dict, List, Optional, Union import pandas as pd -from pydantic import Field, root_validator, validator -from pydantic.dataclasses import dataclass +from pydantic import ( + BaseModel, + ConfigDict, + computed_field, + field_validator, + model_validator, +) from maud.data_model.hardcoding import ID_SEPARATOR @@ -25,50 +30,41 @@ class ModificationType(int, Enum): inhibition = 2 -class KMConfig: - """Config allowing the KineticModel class to contain pandas objects.""" - - arbitrary_types_allowed = True - - -@dataclass -class Metabolite: +class Metabolite(BaseModel): """Maud representation of a metabolite.""" id: str name: Optional[str] inchi_key: Optional[str] - @validator("id") + @field_validator("id") def id_must_not_contain_seps(cls, v): """Check that the id doesn't contain ID_SEPARATOR.""" assert ID_SEPARATOR not in v return v -@dataclass -class Enzyme: +class Enzyme(BaseModel): """Maud representation of an enzyme.""" id: str name: Optional[str] subunits: int - @validator("id") + @field_validator("id") def id_must_not_contain_seps(cls, v): """Check that the id doesn't contain ID_SEPARATOR.""" assert ID_SEPARATOR not in v return v - @validator("subunits") + @field_validator("subunits") def subunits_must_be_positive(cls, v): """Check that the subunits attribute is biologically possible.""" assert v > 0 return v -@dataclass -class PhosphorylationModifyingEnzyme: +class PhosphorylationModifyingEnzyme(BaseModel): """Maud representation of a phosphorylation modifying enzyme. For example, a phosphatase? @@ -77,15 +73,14 @@ class PhosphorylationModifyingEnzyme: id: str - @validator("id") + @field_validator("id") def id_must_not_contain_seps(cls, v): """Check that the id doesn't contain ID_SEPARATOR.""" assert ID_SEPARATOR not in v return v -@dataclass -class Compartment: +class Compartment(BaseModel): """Maud representation of an intra-cellular compartment. For example, cytosol or mitochondria. @@ -96,15 +91,14 @@ class Compartment: name: Optional[str] volume: float - @validator("id") + @field_validator("id") def id_must_not_contain_seps(cls, v): """Check that the id doesn't contain ID_SEPARATOR.""" assert ID_SEPARATOR not in v return v -@dataclass -class Reaction: +class Reaction(BaseModel): """Maud representation of a chemical reaction.""" id: str @@ -114,21 +108,20 @@ class Reaction: water_stoichiometry: float transported_charge: float - @validator("id") + @field_validator("id") def id_must_not_contain_seps(cls, v): """Check that the id doesn't contain ID_SEPARATOR.""" assert ID_SEPARATOR not in v, "ID must not contain separator" return v - @validator("stoichiometry") + @field_validator("stoichiometry") def stoichiometry_must_be_non_zero(cls, v): """Check that the stoichiometry is not zero.""" assert v != 0, "stoichiometry must be non-zero" return v -@dataclass -class MetaboliteInCompartment: +class MetaboliteInCompartment(BaseModel): """Maud representation of a metabolite/compartment pair. 
This is needed because metabolites often exist in multiple compartments, and @@ -142,15 +135,14 @@ class MetaboliteInCompartment: metabolite_id: str compartment_id: str balanced: bool - id: str = Field(init=False, exclude=True) - def __post_init__(self): + @computed_field + def id(self) -> str: """Add the id field.""" - self.id = ID_SEPARATOR.join([self.metabolite_id, self.compartment_id]) + return ID_SEPARATOR.join([self.metabolite_id, self.compartment_id]) -@dataclass -class EnzymeReaction: +class EnzymeReaction(BaseModel): """Maud representation of an enzyme/reaction pair. This is needed because some enzymes catalyse multiple reactions. @@ -159,26 +151,25 @@ class EnzymeReaction: enzyme_id: str reaction_id: str - id: str = Field(init=False, exclude=True) - def __post_init__(self): + @computed_field + def id(self) -> str: """Add the id field.""" - self.id = self.enzyme_id + ID_SEPARATOR + self.reaction_id + return self.enzyme_id + ID_SEPARATOR + self.reaction_id -@dataclass -class Allostery: +class Allostery(BaseModel): """Maud representation of an allosteric modification.""" enzyme_id: str metabolite_id: str compartment_id: str modification_type: ModificationType - id: str = Field(init=False, exclude=True) - def __post_init__(self): - """Add the id and mic_id fields.""" - self.id = ID_SEPARATOR.join( + @computed_field + def id(self) -> str: + """Add the id field.""" + return ID_SEPARATOR.join( [ self.enzyme_id, self.metabolite_id, @@ -186,24 +177,25 @@ def __post_init__(self): self.modification_type.name, ] ) - self.mic_id = ID_SEPARATOR.join( - [self.metabolite_id, self.compartment_id] - ) + + @computed_field + def mic_id(self) -> str: + """Add the mic_id field.""" + return ID_SEPARATOR.join([self.metabolite_id, self.compartment_id]) -@dataclass -class CompetitiveInhibition: +class CompetitiveInhibition(BaseModel): """Maud representation of a competitive inhibition.""" enzyme_id: str reaction_id: str metabolite_id: str compartment_id: str - id: str = Field(init=False, exclude=True) - def __post_init__(self): - """Add the id, er_id and mic_id fields.""" - self.id = ID_SEPARATOR.join( + @computed_field + def id(self) -> str: + """Add the id field.""" + return ID_SEPARATOR.join( [ self.enzyme_id, self.reaction_id, @@ -211,25 +203,30 @@ def __post_init__(self): self.compartment_id, ] ) - self.er_id = ID_SEPARATOR.join([self.enzyme_id, self.reaction_id]) - self.mic_id = ID_SEPARATOR.join( - [self.metabolite_id, self.compartment_id] - ) + @computed_field + def er_id(self) -> str: + """Add the er_id field.""" + return ID_SEPARATOR.join([self.enzyme_id, self.reaction_id]) + + @computed_field + def mic_id(self) -> str: + """Add the mic_id field.""" + return ID_SEPARATOR.join([self.metabolite_id, self.compartment_id]) -@dataclass -class Phosphorylation: + +class Phosphorylation(BaseModel): """Maud representation of a phosphorylation modification.""" name: Optional[str] modifying_enzyme_id: str modified_enzyme_id: str modification_type: ModificationType - id: str = Field(init=False, exclude=True) - def __post_init__(self): + @computed_field + def id(self) -> str: """Add the id field.""" - self.id = ID_SEPARATOR.join( + return ID_SEPARATOR.join( [ self.modifying_enzyme_id, self.modified_enzyme_id, @@ -238,8 +235,7 @@ def __post_init__(self): ) -@dataclass(config=KMConfig) -class KineticModel: +class KineticModel(BaseModel): """Representation of a system of metabolic network.""" name: str @@ -253,26 +249,31 @@ class KineticModel: allosteric_enzymes: Optional[List[Enzyme]] competitive_inhibitions: 
Optional[List[CompetitiveInhibition]] phosphorylations: Optional[List[Phosphorylation]] - drains: List[Reaction] = Field(init=False, exclude=True) - edges: List[Union[Reaction, EnzymeReaction]] = Field( - init=False, exclude=True - ) - stoichiometric_matrix: pd.DataFrame = Field(init=False, exclude=True) - phosphorylation_modifying_enzymes: Optional[ - List[PhosphorylationModifyingEnzyme] - ] = Field(init=False, exclude=True) - - def __post_init__(self): - """Add drains, edges and stoichiometric matrix.""" - self.drains = [ + model_config: ConfigDict = {"arbitrary_types_allowed": True} + + @computed_field + def drains(self) -> List[Reaction]: + """Add the drains field.""" + return [ r for r in self.reactions if r.mechanism == ReactionMechanism.drain ] - self.edges = self.drains + self.ers - self.stoichiometric_matrix = get_stoichiometric_matrix( - self.edges, self.mics, self.reactions - ) - self.phosphorylation_modifying_enzymes = ( + @computed_field + def edges(self) -> List[Union[Reaction, EnzymeReaction]]: + """Add the edges field.""" + return self.drains + self.ers + + @computed_field + def stoichiometric_matrix(self) -> pd.DataFrame: + """Add the stoichiometric_matrix field.""" + return get_stoichiometric_matrix(self.edges, self.mics, self.reactions) + + @computed_field + def phosphorylation_modifying_enzymes( + self, + ) -> Optional[List[PhosphorylationModifyingEnzyme]]: + """Add the phosphorylation_modifying_enzymes field.""" + return ( [ PhosphorylationModifyingEnzyme(pme_id) for pme_id in list( @@ -283,108 +284,108 @@ def __post_init__(self): else None ) - @validator("metabolites") + @field_validator("metabolites") def metabolite_ids_must_be_unique(cls, v): """Make sure there aren't any duplicated metabolite ids.""" met_ids = [m.id for m in v] assert len(met_ids) == len(set(met_ids)) return v - @validator("enzymes") + @field_validator("enzymes") def enzyme_ids_must_be_unique(cls, v): """Make sure there aren't any duplicated enzyme ids.""" met_ids = [m.id for m in v] assert len(met_ids) == len(set(met_ids)) return v - @validator("compartments") + @field_validator("compartments") def compartment_ids_must_be_unique(cls, v): """Make sure there aren't any duplicated compartment ids.""" met_ids = [m.id for m in v] assert len(met_ids) == len(set(met_ids)) return v - @validator("reactions") + @field_validator("reactions") def reaction_ids_must_be_unique(cls, v): """Make sure there aren't any duplicated reaction ids.""" rxn_ids = [r.id for r in v] assert len(rxn_ids) == len(set(rxn_ids)) return v - @root_validator(pre=False, skip_on_failure=True) - def stoic_keys_must_be_mic_ids(cls, values): + @model_validator(mode="after") + def stoic_keys_must_be_mic_ids(self) -> "KineticModel": """Make sure reaction stoichiometries have existent mic ids.""" - mic_ids = [mic.id for mic in values["mics"]] - for r in values["reactions"]: + mic_ids = [mic.id for mic in self.mics] + for r in self.reactions: for stoich_mic_id in r.stoichiometry.keys(): assert ( stoich_mic_id in mic_ids ), f"{r.id} has stoichiometry for bad mic_id {stoich_mic_id}" - return values + return self - @root_validator(pre=False, skip_on_failure=True) - def mic_references_must_exist(cls, values): + @model_validator(mode="after") + def mic_references_must_exist(self) -> "KineticModel": """Make sure mics have existent metabolite and compartment ids.""" - metabolite_ids = [m.id for m in values["metabolites"]] - compartment_ids = [c.id for c in values["compartments"]] - for mic in values["mics"]: + metabolite_ids = [m.id for m in 
self.metabolites] + compartment_ids = [c.id for c in self.compartments] + for mic in self.mics: assert ( mic.metabolite_id in metabolite_ids ), f"{mic.id} has bad metabolite_id." assert ( mic.compartment_id in compartment_ids ), f"{mic.id} has bad compartment_id." - return values + return self - @root_validator(pre=False, skip_on_failure=True) - def er_references_must_exist(cls, values): + @model_validator(mode="after") + def er_references_must_exist(self) -> "KineticModel": """Make sure ers have existent enzyme and reaction ids.""" - enzyme_ids = [e.id for e in values["enzymes"]] - reaction_ids = [r.id for r in values["reactions"]] - for er in values["ers"]: + enzyme_ids = [e.id for e in self.enzymes] + reaction_ids = [r.id for r in self.reactions] + for er in self.ers: assert er.enzyme_id in enzyme_ids, f"{er.id} has bad enzyme_id" assert ( er.reaction_id in reaction_ids ), f"{er.id} has bad reaction_id" - return values + return self - @root_validator(pre=False, skip_on_failure=True) - def allostery_references_must_exist(cls, values): + @model_validator(mode="after") + def allostery_references_must_exist(self) -> "KineticModel": """Make sure allosteries' external ids exist.""" - if values["allosteries"] is None: - return values - enzyme_ids = [e.id for e in values["enzymes"]] - mic_ids = [mic.id for mic in values["mics"]] - for allostery in values["allosteries"]: + if self.allosteries is None: + return self + enzyme_ids = [e.id for e in self.enzymes] + mic_ids = [mic.id for mic in self.mics] + for allostery in self.allosteries: assert ( allostery.enzyme_id in enzyme_ids ), f"{allostery.id} has bad enzyme_id" assert allostery.mic_id in mic_ids, f"{allostery.id} has bad mic_id" - return values + return self - @root_validator(pre=False, skip_on_failure=True) - def ci_references_must_exist(cls, values): + @model_validator(mode="after") + def ci_references_must_exist(self) -> "KineticModel": """Make sure competitive inhibitions' external ids exist.""" - if values["competitive_inhibitions"] is None: - return values - er_ids = [er.id for er in values["ers"]] - mic_ids = [mic.id for mic in values["mics"]] - for ci in values["competitive_inhibitions"]: + if self.competitive_inhibitions is None: + return self + er_ids = [er.id for er in self.ers] + mic_ids = [mic.id for mic in self.mics] + for ci in self.competitive_inhibitions: assert ci.er_id in er_ids, f"{ci.id} has bad er_id" assert ci.mic_id in mic_ids, f"{ci.mic_id} has bad mic_id" - return values + return self - @root_validator(pre=False, skip_on_failure=True) - def phosphorylation_references_must_exist(cls, values): + @model_validator(mode="after") + def phosphorylation_references_must_exist(self): """Make sure phosphorylations' external ids exist.""" - if values["phosphorylations"] is None: - return values - enzyme_ids = [e.id for e in values["enzymes"]] - for p in values["phosphorylations"]: + if self.phosphorylations is None: + return self + enzyme_ids = [e.id for e in self.enzymes] + for p in self.phosphorylations: assert ( p.modified_enzyme_id in enzyme_ids ), f"{p.id} has bad enzyme_id" - return values + return self def get_stoichiometric_matrix( diff --git a/maud/data_model/maud_config.py b/maud/data_model/maud_config.py index dfe300de..18e283bb 100644 --- a/maud/data_model/maud_config.py +++ b/maud/data_model/maud_config.py @@ -1,31 +1,29 @@ -"""Provides dataclass MaudConfig.""" +"""Provides model MaudConfig.""" + from typing import Optional -from pydantic import Field -from pydantic.class_validators import root_validator -from 
pydantic.dataclasses import dataclass +from pydantic import BaseModel, ConfigDict, Field, model_validator -@dataclass(frozen=True) -class ODESolverConfig: +class ODESolverConfig(BaseModel): """Config that is specific to an ODE solver.""" rel_tol: float = 1e-9 abs_tol: float = 1e-9 max_num_steps: int = int(1e7) + model_config: ConfigDict = {"frozen": True} -@dataclass(frozen=True) -class AlgebraSolverConfig: +class AlgebraSolverConfig(BaseModel): """Config that is specific to an ODE solver.""" rel_tol: float = 1e-7 abs_tol: float = 1e-7 max_num_steps: int = int(1e6) + model_config: ConfigDict = {"frozen": True} -@dataclass -class MaudConfig: +class MaudConfig(BaseModel): """User's configuration for a Maud input. :param name: name for the input. Used to name the output directory @@ -79,13 +77,11 @@ class MaudConfig: molecule_unit: str = "mmol" volume_unit: str = "L" - @root_validator - def do_not_penalize_if_rejecting(cls, values): + @model_validator(mode="after") + def do_not_penalize_if_rejecting(self): """Check that locations are non-null.""" - assert not ( - values["penalize_non_steady"] and values["reject_non_steady"] - ), ( + assert not (self.penalize_non_steady and self.reject_non_steady), ( "Penalizing the non-steady state has no effect if the non-steady" " state is rejected; set one of the two to false." ) - return values + return self diff --git a/maud/data_model/maud_init.py b/maud/data_model/maud_init.py index 300d718b..c7e2556d 100644 --- a/maud/data_model/maud_init.py +++ b/maud/data_model/maud_init.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd -from pydantic.dataclasses import dataclass +from pydantic import BaseModel, computed_field from maud.data_model.experiment import Measurement from maud.data_model.hardcoding import ID_SEPARATOR @@ -12,8 +12,7 @@ from maud.data_model.prior import IndPrior1d, IndPrior2d, PriorMVN -@dataclass -class InitAtomInput: +class InitAtomInput(BaseModel): """Maud representation of an init input for a single quantity.""" init: float @@ -28,8 +27,7 @@ class InitAtomInput: ParamInitInput = Optional[List[InitAtomInput]] -@dataclass -class InitInput: +class InitInput(BaseModel): """Maud representation of a full init input.""" dgf: ParamInitInput = None @@ -57,91 +55,112 @@ def get_init_atom_input_ids( ) -@dataclass(init=False) -class Init1d: +class Init1d(BaseModel): """A 1 dimensional initial values specification.""" - inits_unscaled: List[float] - inits_scaled: Optional[List[float]] = None - - def __init__( - self, - ids: List[List[str]], - id_components: List[List[IdComponent]], - prior: Union[IndPrior1d, PriorMVN], - param_init_input: ParamInitInput, - non_negative: bool, - measurements: Optional[List[Measurement]] = None, - ): - inits_pd = pd.Series(prior.location, index=ids[0], dtype="float64") - if non_negative: # non-negative parameter location is on ln scale + ids: List[List[str]] + id_components: List[List[IdComponent]] + prior: Union[IndPrior1d, PriorMVN] + param_init_input: ParamInitInput + non_negative: bool + measurements: Optional[List[Measurement]] = None + + @computed_field + def inits_unscaled(self) -> List[float]: + """Add the inits_unscaled field.""" + inits_pd = pd.Series( + self.prior.location, index=self.ids[0], dtype="float64" + ) + if self.non_negative: # non-negative parameter location is on ln scale inits_pd = np.exp(inits_pd) - if measurements is not None: - for m in measurements: + if self.measurements is not None: + for m in self.measurements: if m.target_id in inits_pd.index: inits_pd.loc[m.target_id] = m.value -
if param_init_input is not None: - for iai in param_init_input: - iai_id = get_init_atom_input_ids(iai, id_components)[0] + if self.param_init_input is not None: + for iai in self.param_init_input: + iai_id = get_init_atom_input_ids(iai, self.id_components)[0] if iai_id in inits_pd.index: inits_pd.loc[iai_id] = iai.init - self.inits_unscaled = inits_pd.tolist() - if isinstance(prior, PriorMVN): # no need to rescale an MVN parameter - self.inits_scaled = None - else: - inits_for_scaling = ( - np.log(inits_pd) if non_negative else inits_pd.copy() - ) - loc_pd = pd.Series(prior.location, index=ids[0], dtype="float64") - scale_pd = pd.Series(prior.scale, index=ids[0], dtype="float64") - inits_pd_scaled = (inits_for_scaling - loc_pd) / scale_pd - self.inits_scaled = inits_pd_scaled.tolist() - - -@dataclass(init=False) -class Init2d: + return inits_pd.tolist() + + @computed_field + def inits_scaled(self) -> Optional[List[float]]: + """Add the inits_scaled field.""" + if isinstance( + self.prior, PriorMVN + ): # no need to rescale an MVN parameter + return None + inits_pd = pd.Series( + self.inits_unscaled, index=self.ids[0], dtype="float64" + ) + inits_for_scaling = ( + np.log(inits_pd) if self.non_negative else inits_pd.copy() + ) + loc_pd = pd.Series( + self.prior.location, index=self.ids[0], dtype="float64" + ) + scale_pd = pd.Series( + self.prior.scale, index=self.ids[0], dtype="float64" + ) + inits_pd_scaled = (inits_for_scaling - loc_pd) / scale_pd + return inits_pd_scaled.tolist() + + +class Init2d(BaseModel): """A 2 dimensional initial values specification.""" - inits_unscaled: List[List[float]] - inits_scaled: Optional[List[List[float]]] = None - - def __init__( - self, - ids: List[List[str]], - id_components: List[List[IdComponent]], - prior: IndPrior2d, - param_init_input: ParamInitInput, - non_negative: bool, - measurements: Optional[List[Measurement]] = None, - ): - inits_pd = pd.DataFrame(prior.location, index=ids[0], columns=ids[1]) - if non_negative: # non-negative parameter location is on ln scale + ids: List[List[str]] + id_components: List[List[IdComponent]] + prior: IndPrior2d + param_init_input: ParamInitInput + non_negative: bool + measurements: Optional[List[Measurement]] = None + + @computed_field + def inits_unscaled(self) -> List[List[float]]: + """Add the inits_unscaled field.""" + inits_pd = pd.DataFrame( + self.prior.location, index=self.ids[0], columns=self.ids[1] + ) + if self.non_negative: # non-negative parameter location is on ln scale inits_pd = np.exp(inits_pd) - if measurements is not None: - for m in measurements: + if self.measurements is not None: + for m in self.measurements: if ( m.target_id in inits_pd.columns and m.experiment in inits_pd.index ): inits_pd.loc[m.experiment, m.target_id] = m.value - if param_init_input is not None: - for iai in param_init_input: + if self.param_init_input is not None: + for iai in self.param_init_input: iai_id_row, iai_id_col = get_init_atom_input_ids( - iai, id_components + iai, self.id_components ) if ( iai_id_row in inits_pd.index and iai_id_col in inits_pd.columns ): inits_pd.loc[iai_id_row, iai_id_col] = iai.init + return inits_pd.values.tolist() + + @computed_field + def inits_scaled(self) -> Optional[List[List[float]]]: + """Add the inits_scaled field.""" + inits_pd = pd.DataFrame( + self.inits_unscaled, index=self.ids[0], columns=self.ids[1] + ) inits_for_scaling = ( - np.log(inits_pd) if non_negative else inits_pd.copy() + np.log(inits_pd) if self.non_negative else inits_pd.copy() + ) + loc_pd = pd.DataFrame( + 
self.prior.location, index=self.ids[0], columns=self.ids[1] + ) + scale_pd = pd.DataFrame( + self.prior.scale, index=self.ids[0], columns=self.ids[1] ) - loc_pd = pd.DataFrame(prior.location, index=ids[0], columns=ids[1]) - scale_pd = pd.DataFrame(prior.scale, index=ids[0], columns=ids[1]) inits_pd_scaled = (inits_for_scaling - loc_pd) / scale_pd - self.inits_unscaled = inits_pd.values.tolist() - self.inits_scaled = inits_pd_scaled.values.tolist() + return inits_pd_scaled.values.tolist() Init = Union[Init1d, Init2d] diff --git a/maud/data_model/maud_input.py b/maud/data_model/maud_input.py index a1ea30c5..ffc0f305 100644 --- a/maud/data_model/maud_input.py +++ b/maud/data_model/maud_input.py @@ -1,55 +1,74 @@ -"""Provides dataclass MaudInput containing Maud needs to run.""" +"""Provides model MaudInput containing everything Maud needs to run.""" -from dataclasses import fields from typing import Dict, List -from pydantic import Field -from pydantic.dataclasses import dataclass +from pydantic import BaseModel, Field, computed_field from maud.data_model.experiment import Experiment from maud.data_model.kinetic_model import KineticModel from maud.data_model.maud_config import MaudConfig from maud.data_model.maud_init import InitInput -from maud.data_model.maud_parameter import ParameterSet -from maud.data_model.prior_input import PriorInput -from maud.getting_parameters import get_maud_parameters +from maud.data_model.maud_parameter import MaudParameter +from maud.data_model.parameter_input import ParameterSetInput +from maud.data_model.parameter_set import ParameterSet from maud.getting_stan_inputs import get_stan_inputs -@dataclass -class MaudInput: +class MaudInput(BaseModel): """Everything that is needed to run Maud.""" config: MaudConfig kinetic_model: KineticModel experiments: List[Experiment] - prior_input: PriorInput = Field(default_factory=PriorInput) + parameter_set_input: ParameterSetInput = Field( + default_factory=ParameterSetInput + ) init_input: InitInput = Field(default_factory=InitInput) - parameters: ParameterSet = Field(init=False, exclude=True) - stan_input_train: Dict = Field(init=False, exclude=True) - stan_input_test: Dict = Field(init=False, exclude=True) - - def __post_init__(self): - """Add attributes that depend on other ones.""" - self.parameters = get_maud_parameters( - self.kinetic_model, - self.experiments, - self.prior_input, - self.init_input, + + @computed_field + def parameters(self) -> ParameterSet: + """Add the parameters field.""" + return ParameterSet( + kinetic_model=self.kinetic_model, + experiments=self.experiments, + parameter_set_input=self.parameter_set_input, + init_input=self.init_input, ) - self.stan_input_train, self.stan_input_test = get_stan_inputs( - self.parameters, - self.experiments, - self.kinetic_model, - self.config, + + @computed_field + def stan_input_train(self) -> Dict: + """Add the stan_input_train field.""" + train, _ = get_stan_inputs( + parameters=self.parameters, + experiments=self.experiments, + kinetic_model=self.kinetic_model, + config=self.config, ) + return train + + @computed_field + def stan_input_test(self) -> Dict: + """Add the stan_input_test field.""" + _, test = get_stan_inputs( + parameters=self.parameters, + experiments=self.experiments, + kinetic_model=self.kinetic_model, + config=self.config, + ) + return test + + @computed_field + def inits_dict(self) -> Dict: + """Add the inits_dict field.""" inits_dict = {} - for p in map( - lambda f: getattr(self.parameters, f.name), - fields(self.parameters), - ): + params 
= [ + getattr(self.parameters, p) + for p in self.parameters.dict().keys() + if isinstance(getattr(self.parameters, p), MaudParameter) + ] + for p in params: inits_dict[p.name] = p.inits.inits_unscaled if p.inits.inits_scaled is not None: scaled_pref = "log_" if p.non_negative else "" inits_dict[scaled_pref + p.name + "_z"] = p.inits.inits_scaled - self.inits_dict = inits_dict + return inits_dict diff --git a/maud/data_model/maud_parameter.py b/maud/data_model/maud_parameter.py index 98411115..a1e10b61 100644 --- a/maud/data_model/maud_parameter.py +++ b/maud/data_model/maud_parameter.py @@ -1,637 +1,331 @@ -"""Definitions of Stan variables and StanVariableSet.""" +"""Provides model MaudParameter, and subclasses for all parameters Maud uses.""" +from copy import deepcopy from typing import List, Optional, Union -from pydantic import Field, root_validator, validator -from pydantic.dataclasses import dataclass +from pydantic import BaseModel, computed_field, field_validator, model_validator +from maud.data_model.experiment import Measurement from maud.data_model.hardcoding import ID_SEPARATOR from maud.data_model.id_component import IdComponent from maud.data_model.maud_init import Init, Init1d, Init2d, InitAtomInput -from maud.data_model.prior import Prior -from maud.data_model.prior_input import IndPriorAtomInput, PriorMVNInput -from maud.getting_priors import ( - get_ind_prior_1d, - get_ind_prior_2d, - get_mvn_prior, +from maud.data_model.parameter_input import ( + ParameterInputAtom, + ParameterInputMVN, ) +from maud.data_model.prior import IndPrior1d, IndPrior2d, PriorMVN -@dataclass -class MaudParameter: +class MaudParameter(BaseModel): """A parameter in Maud's statistical model.""" name: str shape_names: List[str] - ids: List[List[str]] id_components: List[List[IdComponent]] - split_ids: Optional[List[List[List[str]]]] non_negative: bool default_scale: float default_loc: float - init_input: Optional[List[InitAtomInput]] - prior_input: Optional[Union[PriorMVNInput, List[IndPriorAtomInput]]] prior_in_test_model: bool prior_in_train_model: bool - prior: Prior = Field(init=False, exclude=True) - inits: Init = Field(init=False, exclude=True) + user_input: Optional[Union[List[ParameterInputAtom], ParameterInputMVN]] + init_input: Optional[List[InitAtomInput]] + ids: List[List[str]] + split_ids: List[List[List[str]]] + measurements: Optional[List[Measurement]] = None + + @computed_field + def prior(self) -> Union[IndPrior1d, IndPrior2d, PriorMVN]: + """Return a prior, calculated from the user input.""" + if self.name == "dgf": + initialiser = PriorMVN + elif len(self.shape_names) == 1: + initialiser = IndPrior1d + else: + initialiser = IndPrior2d + return initialiser( + user_input=self.user_input, + ids=self.ids, + id_components=self.id_components, + non_negative=self.non_negative, + default_loc=self.default_loc, + default_scale=self.default_scale, + ) - @validator("id_components") + @computed_field + def inits(self) -> Init: + """Add the inits field.""" + initialiser = Init1d if len(self.shape_names) == 1 else Init2d + return initialiser( + ids=self.ids, + id_components=self.id_components, + prior=self.prior, + param_init_input=self.init_input, + non_negative=self.non_negative, + measurements=self.measurements, + ) + + @field_validator("id_components") def id_components_have_same_length(cls, v): """Make sure that id_components contains lists with the same length.""" first_length = len(v[0]) assert all([len(x) == first_length for x in v]) return v - @root_validator(pre=False, 
skip_on_failure=True) - def split_ids_exist_if_needed(cls, values): + @model_validator(mode="after") + def split_ids_exist_if_needed(self): """Check split ids exist when there are non-trivial id components.""" - if any(len(idc) > 1 for idc in values["id_components"]): - assert values["split_ids"] is not None - return values + if any(len(idc) > 1 for idc in self.id_components): + assert self.split_ids is not None + return self + + +class TrainTestParameter(MaudParameter): + """Mark parameter to have different priors between train and test. + + The class must be filled with `shape_names` and `name` for train. + The test version is created by calling the `test` method. + """ + + prior_in_test_model: bool = False + prior_in_train_model: bool = True + + def test(self): + """Generate the test counterpart.""" + test_self = deepcopy(self) + test_self.shape_names = [ + shape_name.replace("train", "test") + for shape_name in self.shape_names + ] + test_self.name = self.name.replace("train", "test") + test_self.prior_in_test_model = True + test_self.prior_in_train_model = False + return test_self -@dataclass(init=False) class Km(MaudParameter): """Parameter representing a model's Michaelis constants.""" - def __init__(self, ids, split_ids, prior_input, init_input): - self.name = "km" - self.ids = ids - self.split_ids = split_ids - self.prior_input = prior_input - self.init_input = init_input - self.shape_names = ["N_km"] - self.id_components = [ - [ - IdComponent.ENZYME, - IdComponent.METABOLITE, - IdComponent.COMPARTMENT, - ] + name: str = "km" + shape_names: List[str] = ["N_km"] + id_components: List[List[IdComponent]] = [ + [ + IdComponent.ENZYME, + IdComponent.METABOLITE, + IdComponent.COMPARTMENT, ] - self.non_negative = True - self.default_loc = -0.69 # roughly 0.5 - self.default_scale = 1 - self.prior_in_test_model = False - self.prior_in_train_model = True - self.mic_ids = [ + ] + non_negative: bool = True + default_loc: float = -0.69 # roughly 0.5 + default_scale: float = 1 + prior_in_test_model: bool = False + prior_in_train_model: bool = True + + @computed_field + def mic_ids(self) -> List[str]: + """Add the mic_ids field.""" + return [ ID_SEPARATOR.join([met_id, compartment_id]) for _, met_id, compartment_id in zip(*self.split_ids[0]) ] - self.prior = get_ind_prior_1d( - self.prior_input, - self.ids[0], - self.id_components, - self.non_negative, - self.default_loc, - self.default_scale, - ) - self.inits = Init1d( - self.ids, - self.id_components, - self.prior, - self.init_input, - self.non_negative, - ) - @validator("split_ids") + @field_validator("split_ids") def split_ids_must_have_right_shape(cls, v): """Check that there are the right number of split ids.""" assert v is not None, "Km split_ids are None" assert len(v) == 1, "Km split ids have wrong length" - assert len(v[0]) == 4 + assert len(v[0]) == 3, f"Wrong number of split id components: {v[0]}" return v -@dataclass(init=False) class Kcat(MaudParameter): """Stan variable representing a model's turnover numbers.""" - def __init__(self, ids, split_ids, prior_input, init_input): - self.name = "kcat" - self.ids = ids - self.shape_names = ["N_enzyme_reaction"] - self.id_components = [[IdComponent.ENZYME, IdComponent.REACTION]] - self.split_ids = split_ids - self.prior_input = prior_input - self.init_input = init_input - self.non_negative = True - self.default_loc = -0.69 # roughly 0.5 - self.default_scale = 1 - self.prior_in_test_model = False - self.prior_in_train_model = True - self.prior = get_ind_prior_1d( - self.prior_input, - 
self.ids[0], - self.id_components, - self.non_negative, - self.default_loc, - self.default_scale, - ) - self.inits = Init1d( - self.ids, - self.id_components, - self.prior, - self.init_input, - self.non_negative, - ) + name: str = "kcat" + shape_names: List[str] = ["N_enzyme_reaction"] + id_components: List[List[IdComponent]] = [ + [IdComponent.ENZYME, IdComponent.REACTION] + ] + non_negative: bool = True + default_loc: float = -0.69 # roughly 0.5 + default_scale: float = 1 + prior_in_test_model: bool = False + prior_in_train_model: bool = True -@dataclass(init=False) class Ki(MaudParameter): """Stan variable representing a model's inhibition constants.""" - def __init__(self, ids, split_ids, prior_input, init_input): - self.name = "ki" - self.ids = ids - self.split_ids = split_ids - self.prior_input = prior_input - self.init_input = init_input - self.shape_names = ["N_ci"] - self.id_components = [ - [ - IdComponent.ENZYME, - IdComponent.REACTION, - IdComponent.METABOLITE, - IdComponent.COMPARTMENT, - ] + name: str = "ki" + shape_names: List[str] = ["N_ci"] + id_components: List[List[IdComponent]] = [ + [ + IdComponent.ENZYME, + IdComponent.REACTION, + IdComponent.METABOLITE, + IdComponent.COMPARTMENT, ] - self.non_negative = True - self.default_loc = -0.69 # roughly 0.5 - self.default_scale = 1 - self.prior_in_test_model = False - self.prior_in_train_model = True - self.mic_ids = [ + ] + non_negative: bool = True + default_loc: float = -0.69 # roughly 0.5 + default_scale: float = 1 + prior_in_test_model: bool = False + prior_in_train_model: bool = True + + @computed_field + def mic_ids(self) -> List[str]: + """Get the mic ids.""" + return [ ID_SEPARATOR.join([met_id, compartment_id]) for _, _, met_id, compartment_id in zip(*self.split_ids[0]) ] - self.er_ids = [ - enzyme_id + + @computed_field + def er_ids(self) -> List[str]: + """Get the enzyme-reaction ids.""" + return [ + ID_SEPARATOR.join([enzyme_id, reaction_id]) for enzyme_id, reaction_id, _, _ in zip(*self.split_ids[0]) ] - self.prior = get_ind_prior_1d( - self.prior_input, - self.ids[0], - self.id_components, - self.non_negative, - self.default_loc, - self.default_scale, - ) - self.inits = Init1d( - self.ids, - self.id_components, - self.prior, - self.init_input, - self.non_negative, - ) -@dataclass(init=False) class Dgf(MaudParameter): """Stan variable representing a model's standard formation energies.""" - def __init__(self, ids, split_ids, prior_input, init_input): - self.name = "dgf" - self.ids = ids - self.split_ids = split_ids - self.prior_input = prior_input - self.init_input = init_input - self.shape_names = ["N_metabolite"] - self.id_components = [[IdComponent.METABOLITE]] - self.non_negative = False - self.default_loc = 0 - self.default_scale = 10 - self.prior_in_test_model = False - self.prior_in_train_model = True - self.prior = get_mvn_prior( - self.prior_input, - self.ids[0], - self.id_components, - self.non_negative, - self.default_loc, - self.default_scale, - ) - self.inits = Init1d( - self.ids, - self.id_components, - self.prior, - self.init_input, - self.non_negative, - ) + name: str = "dgf" + shape_names: List[str] = ["N_metabolite"] + id_components: List[List[IdComponent]] = [[IdComponent.METABOLITE]] + non_negative: bool = False + default_loc: float = 0 + default_scale: float = 10 + prior_in_test_model: bool = False + prior_in_train_model: bool = True -@dataclass(init=False) class DissociationConstant(MaudParameter): """Stan variable representing a model's dissociation constants.""" - def __init__(self, ids, 
split_ids, prior_input, init_input): - self.name = "dissociation_constant" - self.ids = ids - self.split_ids = split_ids - self.prior_input = prior_input - self.init_input = init_input - self.shape_names = ["N_aa"] - self.id_components = [ - [ - IdComponent.ENZYME, - IdComponent.METABOLITE, - IdComponent.COMPARTMENT, - IdComponent.MODIFICATION_TYPE, - ] + name: str = "dissociation_constant" + shape_names: List[str] = ["N_aa"] + id_components: List[List[IdComponent]] = [ + [ + IdComponent.ENZYME, + IdComponent.METABOLITE, + IdComponent.COMPARTMENT, + IdComponent.MODIFICATION_TYPE, ] - self.non_negative = True - self.mic_ids = [ + ] + non_negative: bool = True + default_loc: float = -0.69 # roughly 0.5 + default_scale: float = 1 + prior_in_test_model: bool = False + prior_in_train_model: bool = True + + @computed_field + def mic_ids(self) -> List[str]: + """Get the mic ids.""" + return [ ID_SEPARATOR.join([met_id, compartment_id]) - for _, met_id, compartment_id, mt in zip(*self.split_ids[0]) + for _, _, met_id, compartment_id in zip(*self.split_ids[0]) ] - self.enzyme_ids = self.split_ids[0][0] - self.default_loc = -0.69 # roughly 0.5 - self.default_scale = 1 - self.prior_in_test_model = False - self.prior_in_train_model = True - self.prior = get_ind_prior_1d( - self.prior_input, - self.ids[0], - self.id_components, - self.non_negative, - self.default_loc, - self.default_scale, - ) - self.inits = Init1d( - self.ids, - self.id_components, - self.prior, - self.init_input, - self.non_negative, - ) + + @computed_field + def enzyme_ids(self) -> List[str]: + """Get the enzyme ids.""" + return self.split_ids[0][0] -@dataclass(init=False) class TransferConstant(MaudParameter): """Stan variable representing a model's transfer constants.""" - def __init__(self, ids, split_ids, prior_input, init_input): - self.name = "transfer_constant" - self.ids = ids - self.shape_names = ["N_ae"] - self.id_components = [[IdComponent.ENZYME]] - self.split_ids = split_ids - self.prior_input = prior_input - self.init_input = init_input - self.non_negative = True - self.default_loc = -0.69 # roughly 0.5 - self.default_scale = 1 - self.prior_in_test_model = False - self.prior_in_train_model = True - self.prior = get_ind_prior_1d( - self.prior_input, - self.ids[0], - self.id_components, - self.non_negative, - self.default_loc, - self.default_scale, - ) - self.inits = Init1d( - self.ids, - self.id_components, - self.prior, - self.init_input, - self.non_negative, - ) + name: str = "transfer_constant" + shape_names: List[str] = ["N_ae"] + id_components: List[List[IdComponent]] = [[IdComponent.ENZYME]] + non_negative: bool = True + default_loc: float = -0.69 # roughly 0.5 + default_scale: float = 1 + prior_in_test_model: bool = False + prior_in_train_model: bool = True -@dataclass(init=False) class KcatPme(MaudParameter): """Stan variable representing Kcats of phosphorylation modifying enzymes.""" - def __init__(self, ids, split_ids, prior_input, init_input): - self.name = "kcat_pme" - self.ids = ids - self.shape_names = ["N_pme"] - self.split_ids = split_ids - self.prior_input = prior_input - self.init_input = init_input - self.id_components = [[IdComponent.PHOSPHORYLATION_MODIFYING_ENZYME]] - self.non_negative = True - self.default_loc = -0.69 # roughly 0.5 - self.default_scale = 1 - self.prior_in_test_model = False - self.prior_in_train_model = True - self.prior = get_ind_prior_1d( - self.prior_input, - self.ids[0], - self.id_components, - self.non_negative, - self.default_loc, - self.default_scale, - ) - self.inits = 
Init1d( - self.ids, - self.id_components, - self.prior, - self.init_input, - self.non_negative, - ) + name: str = "kcat_pme" + shape_names: List[str] = ["N_pme"] + id_components: List[List[IdComponent]] = [ + [IdComponent.PHOSPHORYLATION_MODIFYING_ENZYME] + ] + non_negative: bool = True + default_loc: float = -0.69 # roughly 0.5 + default_scale: float = 1 + prior_in_test_model: bool = False + prior_in_train_model: bool = True -@dataclass(init=False) -class Drain(MaudParameter): +class Drain(TrainTestParameter): """Stan variable type for drain parameters.""" - def __init__(self, ids, split_ids, prior_input, init_input): - self.ids = ids - self.id_components = [[IdComponent.EXPERIMENT], [IdComponent.REACTION]] - self.split_ids = split_ids - self.prior_input = prior_input - self.init_input = init_input - self.non_negative = False - self.default_loc = 0 - self.default_scale = 1 - self.prior = get_ind_prior_2d( - self.prior_input, - self.ids, - self.id_components, - self.non_negative, - self.default_loc, - self.default_scale, - ) - self.inits = Init2d( - self.ids, - self.id_components, - self.prior, - self.init_input, - self.non_negative, - ) - - -@dataclass(init=False) -class DrainTrain(Drain): - """Stan variable for drain parameters of training experiments.""" + name: str = "drain_train" + shape_names: List[str] = ["N_experiment_train", "N_drain"] + id_components: List[List[IdComponent]] = [ + [IdComponent.EXPERIMENT], + [IdComponent.REACTION], + ] + non_negative: bool = False + default_loc: float = 0 + default_scale: float = 1 - def __init__(self, ids, split_ids, prior_input, init_input): - self.name = "drain_train" - self.shape_names = ["N_experiment_train", "N_drain"] - self.prior_in_test_model = False - self.prior_in_train_model = True - super().__init__(ids, split_ids, prior_input, init_input) - -@dataclass(init=False) -class DrainTest(Drain): - """Stan variable for drain parameters of test experiments.""" - - def __init__(self, ids, split_ids, prior_input, init_input): - self.name = "drain_test" - self.shape_names = ["N_experiment_test", "N_drain"] - self.prior_in_test_model = True - self.prior_in_train_model = False - super().__init__(ids, split_ids, prior_input, init_input) - - -@dataclass(init=False) -class ConcEnzyme(MaudParameter): +class ConcEnzyme(TrainTestParameter): """Parent class for enzyme concentration parameters.""" - def __init__(self, ids, split_ids, prior_input, init_input, measurements): - self.ids = ids - self.id_components = [[IdComponent.EXPERIMENT], [IdComponent.ENZYME]] - self.split_ids = split_ids - self.prior_input = prior_input - self.init_input = init_input - self.non_negative = True - self.default_loc = -2.3 - self.default_scale = 2.0 - self.default_loc = 0.5 - self.default_scale = 1 - self.prior = get_ind_prior_2d( - self.prior_input, - self.ids, - self.id_components, - self.non_negative, - self.default_loc, - self.default_scale, - ) - self.inits = Init2d( - self.ids, - self.id_components, - self.prior, - self.init_input, - self.non_negative, - measurements, - ) - - -@dataclass(init=False) -class ConcEnzymeTrain(ConcEnzyme): - """Enzyme concentration parameters in training experiments.""" - - def __init__(self, ids, split_ids, prior_input, init_input, measurements): - self.name = "conc_enzyme_train" - self.shape_names = ["N_experiment_train", "N_enzyme"] - self.prior_in_test_model = False - self.prior_in_train_model = True - super().__init__(ids, split_ids, prior_input, init_input, measurements) + name: str = "conc_enzyme_train" + shape_names: List[str] = 
["N_experiment_train", "N_enzyme"] + id_components: List[List[IdComponent]] = [ + [IdComponent.EXPERIMENT], + [IdComponent.ENZYME], + ] + non_negative: bool = True + default_loc: float = -2.3 + default_scale: float = 2.0 + default_loc: float = 0.5 + default_scale: float = 1 -@dataclass(init=False) -class ConcEnzymeTest(ConcEnzyme): - """Enzyme concentration parameters in test experiments.""" - - def __init__(self, ids, split_ids, prior_input, init_input, measurements): - self.name = "conc_enzyme_test" - self.shape_names = ["N_experiment_test", "N_enzyme"] - self.prior_in_test_model = True - self.prior_in_train_model = False - super().__init__(ids, split_ids, prior_input, init_input, measurements) - - -@dataclass(init=False) -class ConcUnbalanced(MaudParameter): +class ConcUnbalanced(TrainTestParameter): """Parent class for unbalanced mic concentration parameters.""" - def __init__(self, ids, split_ids, prior_input, init_input, measurements): - self.ids = ids - self.split_ids = split_ids - self.prior_input = prior_input - self.init_input = init_input - self.id_components = [ - [IdComponent.EXPERIMENT], - [IdComponent.METABOLITE, IdComponent.COMPARTMENT], - ] - self.non_negative = True - self.default_loc = -2.3 - self.default_scale = 2.0 - self.prior = get_ind_prior_2d( - self.prior_input, - self.ids, - self.id_components, - self.non_negative, - self.default_loc, - self.default_scale, - ) - self.inits = Init2d( - self.ids, - self.id_components, - self.prior, - self.init_input, - self.non_negative, - measurements, - ) - - -@dataclass(init=False) -class ConcUnbalancedTrain(ConcUnbalanced): - """Unbalanced mic concentration parameters in training experiments.""" + name: str = "conc_unbalanced_train" + shape_names: List[str] = ["N_experiment_train", "N_unbalanced"] + id_components: List[List[IdComponent]] = [ + [IdComponent.EXPERIMENT], + [IdComponent.METABOLITE, IdComponent.COMPARTMENT], + ] + non_negative: bool = True + default_loc: float = -2.3 + default_scale: float = 2.0 - def __init__(self, ids, split_ids, prior_input, init_input, measurements): - self.name = "conc_unbalanced_train" - self.shape_names = ["N_experiment_train", "N_unbalanced"] - self.prior_in_test_model = False - self.prior_in_train_model = True - super().__init__(ids, split_ids, prior_input, init_input, measurements) - -@dataclass(init=False) -class ConcUnbalancedTest(ConcUnbalanced): - """Unbalanced mic concentration parameters in test experiments.""" - - def __init__(self, ids, split_ids, prior_input, init_input, measurements): - self.name = "conc_unbalanced_test" - self.shape_names = ["N_experiment_test", "N_enzyme"] - self.prior_in_test_model = True - self.prior_in_train_model = False - super().__init__(ids, split_ids, prior_input, init_input, measurements) - - -@dataclass(init=False) -class ConcPme(MaudParameter): +class ConcPme(TrainTestParameter): """Parent class for pme concentration parameters.""" - def __init__(self, ids, split_ids, prior_input, init_input): - self.ids = ids - self.split_ids = split_ids - self.prior_input = prior_input - self.init_input = init_input - self.id_components = [ - [IdComponent.EXPERIMENT], - [IdComponent.PHOSPHORYLATION_MODIFYING_ENZYME], - ] - self.non_negative = True - self.default_loc = 0.1 - self.default_scale = 2.0 - self.prior = get_ind_prior_2d( - self.prior_input, - self.ids, - self.id_components, - self.non_negative, - self.default_loc, - self.default_scale, - ) - self.inits = Init2d( - self.ids, - self.id_components, - self.prior, - self.init_input, - self.non_negative, - ) - 
- -@dataclass(init=False) -class ConcPmeTrain(ConcPme): - """Pme concentration parameters in training experiments.""" + name: str = "conc_pme_train" + shape_names: List[str] = ["N_experiment_train", "N_pme"] + id_components: List[List[IdComponent]] = [ + [IdComponent.EXPERIMENT], + [IdComponent.PHOSPHORYLATION_MODIFYING_ENZYME], + ] + non_negative: bool = True + default_loc: float = 0.1 + default_scale: float = 2.0 - def __init__(self, ids, split_ids, prior_input, init_input): - self.name = "conc_pme_train" - self.shape_names = ["N_experiment_train", "N_pme"] - self.prior_in_test_model = False - self.prior_in_train_model = True - super().__init__(ids, split_ids, prior_input, init_input) - -@dataclass(init=False) -class ConcPmeTest(ConcPme): - """Pme concentration parameters in test experiments.""" - - def __init__(self, ids, split_ids, prior_input, init_input): - self.name = "conc_pme_test" - self.shape_names = ["N_experiment_test", "N_pme"] - self.prior_in_test_model = True - self.prior_in_train_model = False - super().__init__(ids, split_ids, prior_input, init_input) - - -@dataclass(init=False) -class Psi(MaudParameter): +class Psi(TrainTestParameter): """Stan variable representing per-experiment membrane potentials.""" - def __init__(self, ids, split_ids, prior_input, init_input): - self.ids = ids - self.split_ids = split_ids - self.prior_input = prior_input - self.init_input = init_input - self.id_components = [[IdComponent.EXPERIMENT]] - self.non_negative = False - self.default_loc = 0 - self.default_scale = 2 - self.prior = get_ind_prior_1d( - self.prior_input, - self.ids[0], - self.id_components, - self.non_negative, - self.default_loc, - self.default_scale, - ) - self.inits = Init1d( - self.ids, - self.id_components, - self.prior, - self.init_input, - self.non_negative, - ) - - -@dataclass(init=False) -class PsiTrain(Psi): - """Pme concentration parameters in training experiments.""" - - def __init__(self, ids, split_ids, prior_input, init_input): - self.name = "psi_train" - self.shape_names = ["N_experiment_train"] - self.prior_in_test_model = False - self.prior_in_train_model = True - super().__init__(ids, split_ids, prior_input, init_input) - - -@dataclass(init=False) -class PsiTest(Psi): - """Pme concentration parameters in test experiments.""" - - def __init__(self, ids, split_ids, prior_input, init_input): - self.name = "psi_test" - self.shape_names = ["N_experiment_test"] - self.prior_in_test_model = True - self.prior_in_train_model = False - super().__init__(ids, split_ids, prior_input, init_input) - - -@dataclass -class ParameterSet: - """The parameters of a Maud input.""" - - dgf: Dgf - km: Km - ki: Ki - kcat: Kcat - dissociation_constant: DissociationConstant - transfer_constant: TransferConstant - kcat_pme: KcatPme - drain_train: DrainTrain - drain_test: DrainTest - conc_enzyme_train: ConcEnzymeTrain - conc_enzyme_test: ConcEnzymeTest - conc_unbalanced_train: ConcUnbalancedTrain - conc_unbalanced_test: ConcUnbalancedTest - conc_pme_train: ConcPmeTrain - conc_pme_test: ConcPmeTest - psi_train: PsiTrain - psi_test: PsiTest + name: str = "psi_train" + shape_names: List[str] = ["N_experiment_train"] + id_components: List[List[IdComponent]] = [[IdComponent.EXPERIMENT]] + non_negative: bool = False + default_loc: float = 0 + default_scale: float = 2 diff --git a/maud/data_model/parameter_input.py b/maud/data_model/parameter_input.py index 03da4254..7e144c04 100644 --- a/maud/data_model/parameter_input.py +++ b/maud/data_model/parameter_input.py @@ -1,13 +1,17 @@ 
"""Definitions of the user's input for priors. Directly read from toml.""" -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union -from pydantic import root_validator, validator -from pydantic.dataclasses import dataclass +from pydantic import ( + BaseModel, + Field, + PositiveFloat, + field_validator, + model_validator, +) -@dataclass -class ParameterInputAtom: +class ParameterInputAtom(BaseModel): """Parameter input for a single quantity.""" metabolite: Optional[str] = None @@ -19,26 +23,27 @@ class ParameterInputAtom: modification_type: Optional[str] = None location: Optional[float] = None exploc: Optional[float] = None - scale: Optional[float] = None + scale: Optional[PositiveFloat] = None pct1: Optional[float] = None pct99: Optional[float] = None init: Optional[float] = None + fixed_value: Optional[float] = None - @validator("scale") + @field_validator("scale") def scale_is_positive(cls, v): """Check that scale is positive.""" if v is not None and v <= 0: raise ValueError("scale must be a positive number.") return v - @root_validator(pre=False, skip_on_failure=True) - def prior_is_specified_correctly(cls, values): + @model_validator(mode="before") + def prior_is_specified_correctly(cls, data): """Check that location, scale etc are correct.""" - lc = values["location"] - el = values["exploc"] - sc = values["scale"] - p1 = values["pct1"] - p99 = values["pct99"] + lc = data["location"] if "location" in data.keys() else None + el = data["exploc"] if "exploc" in data.keys() else None + sc = data["scale"] if "scale" in data.keys() else None + p1 = data["pct1"] if "pct1" in data.keys() else None + p99 = data["pct99"] if "pct99" in data.keys() else None happy_cases = [ {"not_none": [lc, sc], "none": [el, p1, p99]}, {"not_none": [el, sc], "none": [lc, p1, p99]}, @@ -54,31 +59,40 @@ def prior_is_specified_correctly(cls, values): "Set one out of the following pairs of attributes: " "location and scale, exploc and scale, or pct1 and pct99." 
) - return values + return data -@dataclass -class ParameterInputMVN: +class ParameterInputMVN(BaseModel): """User input for a parameter with multivariate normal prior.""" ids: List[str] + fixed_values: Optional[Dict[str, float]] = None mean_vector: List[float] covariance_matrix: List[List[float]] -@dataclass -class ParametersInput: +class ParameterSetInput(BaseModel): """User input for all parameters.""" - dgf: Optional[Union[ParameterInputMVN, List[ParameterInputAtom]]] = None - km: Optional[List[ParameterInputAtom]] = None - kcat: Optional[List[ParameterInputAtom]] = None - kcat_pme: Optional[List[ParameterInputAtom]] = None - ki: Optional[List[ParameterInputAtom]] = None - psi: Optional[List[ParameterInputAtom]] = None - dissociation_constant: Optional[List[ParameterInputAtom]] = None - transfer_constant: Optional[List[ParameterInputAtom]] = None - conc_unbalanced: Optional[List[ParameterInputAtom]] = None - drain: Optional[List[ParameterInputAtom]] = None - conc_enzyme: Optional[List[ParameterInputAtom]] = None - conc_pme: Optional[List[ParameterInputAtom]] = None + dgf: Optional[Union[ParameterInputMVN, List[ParameterInputAtom]]] = Field( + default_factory=list + ) + km: Optional[List[ParameterInputAtom]] = Field(default_factory=list) + kcat: Optional[List[ParameterInputAtom]] = Field(default_factory=list) + kcat_pme: Optional[List[ParameterInputAtom]] = Field(default_factory=list) + ki: Optional[List[ParameterInputAtom]] = Field(default_factory=list) + psi: Optional[List[ParameterInputAtom]] = Field(default_factory=list) + dissociation_constant: Optional[List[ParameterInputAtom]] = Field( + default_factory=list + ) + transfer_constant: Optional[List[ParameterInputAtom]] = Field( + default_factory=list + ) + conc_unbalanced: Optional[List[ParameterInputAtom]] = Field( + default_factory=list + ) + drain: Optional[List[ParameterInputAtom]] = Field(default_factory=list) + conc_enzyme: Optional[List[ParameterInputAtom]] = Field( + default_factory=list + ) + conc_pme: Optional[List[ParameterInputAtom]] = Field(default_factory=list) diff --git a/maud/data_model/parameter_set.py b/maud/data_model/parameter_set.py new file mode 100644 index 00000000..cc207c75 --- /dev/null +++ b/maud/data_model/parameter_set.py @@ -0,0 +1,313 @@ +"""Provides the ParameterSet model. + +This is where logic for constructing MaudParameter objects should live. 
+ +""" + +from pydantic import BaseModel, computed_field + +import maud.data_model.maud_parameter as mp +from maud.data_model.experiment import Experiment, Measurement, MeasurementType +from maud.data_model.hardcoding import ID_SEPARATOR +from maud.data_model.kinetic_model import KineticModel, ReactionMechanism +from maud.data_model.maud_init import InitInput +from maud.data_model.parameter_input import ParameterSetInput + + +class ParameterSet(BaseModel): + """the parameters of a maud input.""" + + kinetic_model: KineticModel + experiments: list[Experiment] + parameter_set_input: ParameterSetInput + init_input: InitInput + + @computed_field + def dgf(self) -> mp.Dgf: + """Add the dgf field.""" + metabolite_ids = [m.id for m in self.kinetic_model.metabolites] + return mp.Dgf( + ids=[metabolite_ids], + split_ids=[[metabolite_ids]], + user_input=self.parameter_set_input.dgf, + init_input=self.init_input.dgf, + ) + + @computed_field + def km(self) -> mp.Km: + """Add the km field.""" + ids = [] + enzs = [] + mets = [] + cpts = [] + for er in self.kinetic_model.ers: + rxn = [ + r + for r in self.kinetic_model.reactions + if r.id == er.reaction_id + ][0] + enz = [ + e for e in self.kinetic_model.enzymes if e.id == er.enzyme_id + ][0] + mic_ids = ( + list(rxn.stoichiometry.keys()) + if rxn.mechanism + != ReactionMechanism.irreversible_michaelis_menten + else [k for k, v in rxn.stoichiometry.items() if v < 0] + ) + for mic_id in mic_ids: + id = ID_SEPARATOR.join([enz.id, mic_id]) + met_id, cpt_id = mic_id.split(ID_SEPARATOR) + if id not in ids: + ids.append(id) + enzs.append(enz.id) + mets.append(met_id) + cpts.append(cpt_id) + return mp.Km( + ids=[ids], + split_ids=[[enzs, mets, cpts]], + user_input=self.parameter_set_input.km, + init_input=self.init_input.km, + ) + + @computed_field + def ki(self) -> mp.Ki: + """Add the ki field.""" + ids = [] + enzs = [] + rxns = [] + mets = [] + cpts = [] + if self.kinetic_model.competitive_inhibitions is not None: + for ci in self.kinetic_model.competitive_inhibitions: + ids.append(ci.id) + enzs.append(ci.enzyme_id) + rxns.append(ci.reaction_id) + mets.append(ci.metabolite_id) + cpts.append(ci.compartment_id) + return mp.Ki( + ids=[ids], + split_ids=[[enzs, rxns, mets, cpts]], + user_input=self.parameter_set_input.ki, + init_input=self.init_input.ki, + ) + + @computed_field + def kcat(self) -> mp.Kcat: + """Add the kcat field.""" + ids = [] + enzs = [] + rxns = [] + for er in self.kinetic_model.ers: + ids.append(er.id) + enzs.append(er.enzyme_id) + rxns.append(er.reaction_id) + return mp.Kcat( + ids=[ids], + split_ids=[[enzs, rxns]], + user_input=self.parameter_set_input.kcat, + init_input=self.init_input.kcat, + ) + + @computed_field + def dissociation_constant(self) -> mp.DissociationConstant: + """Add the dissociation_constant field.""" + ids = [] + enzs = [] + mets = [] + cpts = [] + mts = [] + if self.kinetic_model.allosteries is not None: + for a in self.kinetic_model.allosteries: + ids.append(a.id) + enzs.append(a.enzyme_id) + mets.append(a.metabolite_id) + cpts.append(a.compartment_id) + mts.append(a.modification_type.name) + return mp.DissociationConstant( + ids=[ids], + split_ids=[[enzs, mets, cpts, mts]], + user_input=self.parameter_set_input.dissociation_constant, + init_input=self.init_input.dissociation_constant, + ) + + @computed_field + def transfer_constant(self) -> mp.TransferConstant: + """Add the transfer_constant field.""" + allosteric_enzyme_ids = ( + [e.id for e in self.kinetic_model.allosteric_enzymes] + if 
self.kinetic_model.allosteric_enzymes is not None + else [] + ) + return mp.TransferConstant( + ids=[allosteric_enzyme_ids], + split_ids=[[allosteric_enzyme_ids]], + user_input=self.parameter_set_input.transfer_constant, + init_input=self.init_input.transfer_constant, + ) + + @computed_field + def kcat_pme(self) -> mp.KcatPme: + """Add the kcat_pme field.""" + phos_modifying_enzymes = ( + [p.modifying_enzyme_id for p in self.kinetic_model.phosphorylations] + if self.kinetic_model.phosphorylations is not None + else [] + ) + return mp.KcatPme( + ids=[phos_modifying_enzymes], + split_ids=[[phos_modifying_enzymes]], + user_input=self.parameter_set_input.kcat_pme, + init_input=self.init_input.kcat_pme, + ) + + def _get_experiments(self, train: bool) -> list[str]: + return [ + e.id + for e in self.experiments + if (e.is_train if train else e.is_test) + ] + + def _get_drain(self, train: bool) -> mp.Drain: + drain_ids = [ + d.id + for d in self.kinetic_model.reactions + if d.mechanism == ReactionMechanism.drain + ] + exp_ids = self._get_experiments(train) + result = mp.Drain( + ids=[exp_ids, drain_ids], + split_ids=[[exp_ids], [drain_ids]], + user_input=self.parameter_set_input.drain, + init_input=self.init_input.drain, + ) + return result if train else result.test() + + @computed_field + def drain_train(self) -> mp.Drain: + """Add the drain_train field.""" + return self._get_drain(train=True) + + @computed_field + def drain_test(self) -> mp.Drain: + """Add the drain_test field.""" + return self._get_drain(train=False) + + def _get_measurements( + self, train: bool, mtype: MeasurementType + ) -> list[Measurement]: + return [ + m + for e in self.experiments + for m in e.measurements + if (e.is_train if train else e.is_test) and m.target_type == mtype + ] + + def _get_conc_enzyme(self, train: bool) -> mp.ConcEnzyme: + enzyme_ids = [e.id for e in self.kinetic_model.enzymes] + exp_ids = self._get_experiments(train) + measurements = self._get_measurements(train, MeasurementType.ENZYME) + result = mp.ConcEnzyme( + ids=[exp_ids, enzyme_ids], + split_ids=[[exp_ids], [enzyme_ids]], + user_input=self.parameter_set_input.conc_enzyme, + init_input=self.init_input.conc_enzyme, + measurements=measurements, + ) + return result if train else result.test() + + @computed_field + def conc_enzyme_train(self) -> mp.ConcEnzyme: + """Add the conc_enzyme_train field.""" + return self._get_conc_enzyme(train=True) + + @computed_field + def conc_enzyme_test(self) -> mp.ConcEnzyme: + """Add the conc_enzyme_test field.""" + return self._get_conc_enzyme(train=False) + + def _get_conc_unbalanced(self, train: bool) -> mp.ConcUnbalanced: + exp_ids = self._get_experiments(train) + measurements = self._get_measurements(train, MeasurementType.MIC) + unbalanced_mic_ids, unbalanced_mic_mets, unbalanced_mic_cpts = map( + list, + zip( + *[ + [m.id, m.metabolite_id, m.compartment_id] + for m in self.kinetic_model.mics + if not m.balanced + ] + ), + ) + result = mp.ConcUnbalanced( + ids=[exp_ids, unbalanced_mic_ids], + split_ids=[ + [exp_ids], + [unbalanced_mic_mets, unbalanced_mic_cpts], + ], + user_input=self.parameter_set_input.conc_unbalanced, + init_input=self.init_input.conc_unbalanced, + measurements=measurements, + ) + return result if train else result.test() + + @computed_field + def conc_unbalanced_train(self) -> mp.ConcUnbalanced: + """Add the conc_unbalanced_train field.""" + return self._get_conc_unbalanced(train=True) + + @computed_field + def conc_unbalanced_test(self) -> mp.ConcUnbalanced: + """Add the 
conc_unbalanced_test field.""" + return self._get_conc_unbalanced(train=False) + + def _get_conc_pme(self, train: bool) -> mp.ConcPme: + """Add the conc_pme_train field.""" + exp_ids = self._get_experiments(train) + pme_ids = ( + [p.modifying_enzyme_id for p in self.kinetic_model.phosphorylations] + if self.kinetic_model.phosphorylations is not None + else [] + ) + result = mp.ConcPme( + ids=[exp_ids, pme_ids], + split_ids=[[exp_ids], [pme_ids]], + user_input=self.parameter_set_input.conc_pme, + init_input=self.init_input.conc_pme, + ) + return result if train else result.test() + + @computed_field + def conc_pme_train(self) -> mp.ConcPme: + """Add the conc_pme_train field.""" + return self._get_conc_pme(train=True) + + @computed_field + def conc_pme_test(self) -> mp.ConcPme: + """Add the conc_pme_test field.""" + return self._get_conc_pme(train=False) + + def _get_psi(self, train: bool) -> mp.Psi: + """Add the psi_train field.""" + exp_ids = [ + e.id + for e in self.experiments + if (e.is_train if train else e.is_test) + ] + result = mp.Psi( + ids=[exp_ids], + split_ids=[[exp_ids]], + user_input=self.parameter_set_input.psi, + init_input=self.init_input.psi, + ) + return result if train else result.test() + + @computed_field + def psi_train(self) -> mp.Psi: + """Add the psi_train field.""" + return self._get_psi(train=True) + + @computed_field + def psi_test(self) -> mp.Psi: + """Add the psi_test field.""" + return self._get_psi(train=False) diff --git a/maud/data_model/prior.py b/maud/data_model/prior.py index c92c69f1..d6b2c75b 100644 --- a/maud/data_model/prior.py +++ b/maud/data_model/prior.py @@ -1,30 +1,111 @@ """Definitions of priors.""" import math -from typing import List, Union +from typing import List, Optional, Union import numpy as np +import pandas as pd from numpy.linalg import LinAlgError -from pydantic.class_validators import root_validator, validator -from pydantic.dataclasses import dataclass +from pydantic import ( + BaseModel, + NonNegativeFloat, + PositiveFloat, + computed_field, + field_validator, + model_validator, +) +from maud.data_model.hardcoding import ID_SEPARATOR +from maud.data_model.id_component import IdComponent +from maud.data_model.parameter_input import ( + ParameterInputAtom, + ParameterInputMVN, +) +from maud.utility_functions import ( + get_lognormal_parameters_from_quantiles, + get_normal_parameters_from_quantiles, +) -@dataclass -class IndPriorAtom: - """Prior for a single quantity.""" - location: float - scale: float +def get_pia_loc(pia: ParameterInputAtom, non_negative: bool) -> float: + """Get the location from a parameter input atom.""" + if pia.location is not None: + return pia.location + elif pia.exploc is not None: + return np.log(pia.exploc) + elif pia.pct1 is not None: + f = ( + get_lognormal_parameters_from_quantiles + if non_negative + else get_normal_parameters_from_quantiles + ) + loc, _ = f(pia.pct1, 0.01, pia.pct99, 0.99) + return loc + else: + raise ValueError(f"Incorrectly specified Prior input atom: {pia}") -@dataclass -class IndPrior1d: +def get_pia_scale(pia: ParameterInputAtom, non_negative: bool) -> float: + """Get the scale from a parameter input atom.""" + if pia.scale is not None: + return pia.scale + elif pia.pct1 is not None: + f = ( + get_lognormal_parameters_from_quantiles + if non_negative + else get_normal_parameters_from_quantiles + ) + _, scale = f(pia.pct1, 0.01, pia.pct99, 0.99) + return scale + else: + raise ValueError(f"Incorrectly specified Prior input atom: {pia}") + + +class IndPrior1d(BaseModel): 
"""Independent location/scale prior for a 1D parameter.""" - location: List[float] - scale: List[float] + user_input: Optional[List[ParameterInputAtom]] + ids: List[List[str]] + id_components: List[List[IdComponent]] + non_negative: bool + default_loc: float + default_scale: PositiveFloat + + @computed_field + def location(self) -> List[float]: + """Add the location field.""" + if len(self.ids[0]) == 0: + return [] + loc_series = pd.Series(self.default_loc, index=self.ids) + if self.user_input is not None: + for pia in self.user_input: + loc_i = get_pia_loc(pia, non_negative=self.non_negative) + ids_i = [ + ID_SEPARATOR.join([getattr(pia, c) for c in idci]) + for idci in self.id_components + ] + if ids_i[0] in loc_series.index: + loc_series.loc[ids_i[0]] = loc_i + return loc_series.tolist() - @validator("location") + @computed_field + def scale(self) -> List[PositiveFloat]: + """Add the scale field.""" + if len(self.ids[0]) == 0: + return [] + scale_series = pd.Series(self.default_scale, index=self.ids) + if self.user_input is not None: + for pia in self.user_input: + scale_i = get_pia_scale(pia, non_negative=self.non_negative) + ids_i = [ + ID_SEPARATOR.join([getattr(pia, c) for c in idci]) + for idci in self.id_components + ] + if ids_i[0] in scale_series.index: + scale_series.loc[ids_i[0]] = scale_i + return scale_series.tolist() + + @field_validator("location") def no_null_locations(cls, v): """Check that locations are non-null.""" for x in v: @@ -32,7 +113,7 @@ def no_null_locations(cls, v): raise ValueError("Location cannot be null.") return v - @validator("scale") + @field_validator("scale") def no_null_scales(cls, v): """Check that scales are non-null.""" for x in v: @@ -40,33 +121,74 @@ def no_null_scales(cls, v): raise ValueError("Scale cannot be null.") return v - @validator("scale") + @field_validator("scale") def scales_are_positive(cls, v): """Check that scales are positive.""" if len(v) > 0 and any(x <= 0 for x in v): raise ValueError("Scale parameter must be positive.") return v - @root_validator(pre=False, skip_on_failure=True) - def lengths_match(cls, values): + @model_validator(mode="after") + def lengths_match(self): """Check that location and scale have the same index.""" - n_locs = len(values["location"]) - n_scales = len(values["scale"]) + n_locs = len(self.location) + n_scales = len(self.scale) if n_locs != n_scales: raise ValueError( "Location, scale and ids must have the same length." 
) - return values + return self -@dataclass -class IndPrior2d: +class IndPrior2d(BaseModel): """Independent location/scale prior for a 2D parameter.""" - location: List[List[float]] - scale: List[List[float]] + user_input: Optional[List[ParameterInputAtom]] + ids: List[List[str]] + id_components: List[List[IdComponent]] + non_negative: bool + default_loc: float + default_scale: PositiveFloat + + @computed_field + def location(self) -> List[List[float]]: + """Add the location field.""" + if any(len(ids_i) == 0 for ids_i in self.ids): + return [[]] + loc_df = pd.DataFrame( + self.default_loc, index=self.ids[0], columns=self.ids[1] + ) + if self.user_input is not None: + for pia in self.user_input: + loc_i = get_pia_loc(pia, non_negative=self.non_negative) + ids_i = [ + ID_SEPARATOR.join([getattr(pia, c) for c in idci]) + for idci in self.id_components + ] + if ids_i[0] in loc_df.index and ids_i[1] in loc_df.columns: + loc_df.loc[ids_i[0], ids_i[1]] = loc_i + return loc_df.values.tolist() - @validator("location") + @computed_field + def scale(self) -> List[List[PositiveFloat]]: + """Add the scale field.""" + if any(len(ids_i) == 0 for ids_i in self.ids): + return [[]] + scale_df = pd.DataFrame( + self.default_scale, index=self.ids[0], columns=self.ids[1] + ) + if self.user_input is not None: + for pia in self.user_input: + scale_i = get_pia_scale(pia, non_negative=self.non_negative) + ids_i = [ + ID_SEPARATOR.join([getattr(pia, c) for c in idci]) + for idci in self.id_components + ] + if ids_i[0] in scale_df.index and ids_i[1] in scale_df.columns: + scale_df.loc[ids_i[0], ids_i[1]] = scale_i + return scale_df.values.tolist() + + @field_validator("location") def no_null_locations(cls, v): """Check that locations are non-null.""" for vi in v: @@ -75,7 +197,7 @@ def no_null_locations(cls, v): raise ValueError("Location cannot be null.") return v - @validator("scale") + @field_validator("scale") def no_null_scales(cls, v): """Check that scales are non-null.""" for vi in v: @@ -84,7 +206,7 @@ def no_null_scales(cls, v): raise ValueError("Scale cannot be null.") return v - @validator("scale") + @field_validator("scale") def scales_are_positive(cls, v): """Check that scales are all positive.""" for s in v: @@ -92,27 +214,78 @@ def scales_are_positive(cls, v): raise ValueError("Scale parameter must be positive.") return v - @root_validator(pre=False, skip_on_failure=True) - def lengths_are_correct(cls, values): + @model_validator(mode="after") + def lengths_are_correct(self): """Check that ids, location and scale have correct length.""" - loc = values["location"] - scale = values["scale"] + loc = self.location + scale = self.scale if not len(loc) == len(scale): raise ValueError("First dimension length incorrect.") for i, (loc_i, scale_i) in enumerate(zip(loc, scale)): if not len(loc_i) == len(scale_i): raise ValueError(f"Length mismatch at index {i}.") - return values + return self -@dataclass -class PriorMVN: +class PriorMVN(BaseModel): """Prior Location vector and covariance matrix for a 1D parameter.""" - location: List[float] - covariance_matrix: List[List[float]] + user_input: Optional[Union[List[ParameterInputAtom], ParameterInputMVN]] + ids: List[List[str]] + id_components: List[List[IdComponent]] + non_negative: bool + default_loc: float + default_scale: PositiveFloat + + @computed_field + def location(self) -> List[float]: + """Add the location field.""" + ids = self.ids[0] + loc_series = pd.Series(self.default_loc, index=self.ids) + if isinstance(self.user_input, ParameterInputMVN): + 
loc_series = pd.Series( + self.user_input.mean_vector, index=self.user_input.ids + ).reindex(ids) + elif self.user_input is not None: + for pia in self.user_input: + ids_i = [ + ID_SEPARATOR.join([getattr(pia, c) for c in idci]) + for idci in self.id_components + ] + loc_i = get_pia_loc(pia, non_negative=self.non_negative) + loc_series.loc[ids_i[0]] = loc_i + return loc_series.tolist() + + @computed_field + def covariance_matrix(self) -> List[List[NonNegativeFloat]]: + """Add the covariance_matrix field.""" + ids = self.ids[0] + cov_df = pd.DataFrame( + np.diagflat(np.tile(self.default_scale, len(ids))), + index=ids, + columns=ids, + ) + if isinstance(self.user_input, ParameterInputMVN): + cov_df = ( + pd.DataFrame( + self.user_input.covariance_matrix, + index=self.user_input.ids, + columns=self.user_input.ids, + ) + .reindex(ids) + .reindex(columns=ids) + ) + elif self.user_input is not None: + for pia in self.user_input: + ids_i = [ + ID_SEPARATOR.join([getattr(pia, c) for c in idci]) + for idci in self.id_components + ] + cov_ii = get_pia_scale(pia, non_negative=self.non_negative) + cov_df.loc[ids_i[0], ids_i[0]] = cov_ii + return cov_df.values.tolist() - @validator("covariance_matrix") + @field_validator("covariance_matrix") def no_null_covariances(cls, v): """Check that scales are non-null.""" for vi in v: @@ -121,7 +294,7 @@ def no_null_covariances(cls, v): raise ValueError("Covariance cannot be null.") return v - @validator("location") + @field_validator("location") def no_null_locations(cls, v): """Check that no locations are nans.""" for x in v: @@ -129,7 +302,7 @@ def no_null_locations(cls, v): raise ValueError("Location cannot be null.") return v - @validator("covariance_matrix") + @field_validator("covariance_matrix") def cov_matrix_is_pos_def(cls, v): """Check that covariance matrix is positive definite.""" try: @@ -140,18 +313,18 @@ def cov_matrix_is_pos_def(cls, v): ) from e return v - @root_validator(pre=False, skip_on_failure=True) - def lengths_are_correct(cls, values): + @model_validator(mode="after") + def lengths_are_correct(self): """Check that ids, location and cov matrix have correct lengths.""" - loc = values["location"] - cov = values["covariance_matrix"] + loc = self.location + cov = self.covariance_matrix if not all(len(x) == len(cov[0]) for x in cov): raise ValueError( "All elements of covariance matrix must have same length." ) if not len(loc) == len(cov[0]): raise ValueError("First dimension length incorrect.") - return values + return self Prior = Union[IndPrior1d, IndPrior2d, PriorMVN] diff --git a/maud/data_model/prior_input.py b/maud/data_model/prior_input.py deleted file mode 100644 index ce59f448..00000000 --- a/maud/data_model/prior_input.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Definitions of the user's input for priors. 
Directly read from toml.""" - -from typing import List, Optional, Union - -from pydantic import root_validator, validator -from pydantic.dataclasses import dataclass - - -@dataclass -class IndPriorAtomInput: - """Prior input for a single quantity.""" - - metabolite: Optional[str] = None - compartment: Optional[str] = None - enzyme: Optional[str] = None - reaction: Optional[str] = None - experiment: Optional[str] = None - phosphorylation_modifying_enzyme: Optional[str] = None - modification_type: Optional[str] = None - location: Optional[float] = None - exploc: Optional[float] = None - scale: Optional[float] = None - pct1: Optional[float] = None - pct99: Optional[float] = None - - @validator("scale") - def scale_is_positive(cls, v): - """Check that scale is positive.""" - if v is not None and v <= 0: - raise ValueError("scale must be a positive number.") - return v - - @root_validator(pre=False, skip_on_failure=True) - def prior_is_specified_correctly(cls, values): - """Check that location, scale etc are correct.""" - lc = values["location"] - el = values["exploc"] - sc = values["scale"] - p1 = values["pct1"] - p99 = values["pct99"] - happy_cases = [ - {"not_none": [lc, sc], "none": [el, p1, p99]}, - {"not_none": [el, sc], "none": [lc, p1, p99]}, - {"not_none": [p1, p99], "none": [lc, el, sc]}, - ] - good = [ - all(v is not None for v in case["not_none"]) - and all(v is None for v in case["none"]) - for case in happy_cases - ] - if not any(good): - raise ValueError( - "Set one out of the following pairs of attributes: " - "location and scale, exploc and scale, or pct1 and pct99." - ) - return values - - -@dataclass -class PriorMVNInput: - """Prior input for a parameter with multivariate normal distribution.""" - - ids: List[str] - mean_vector: List[float] - covariance_matrix: List[List[float]] - - -@dataclass -class PriorInput: - """A full prior input.""" - - dgf: Optional[Union[PriorMVNInput, List[IndPriorAtomInput]]] = None - km: Optional[List[IndPriorAtomInput]] = None - kcat: Optional[List[IndPriorAtomInput]] = None - kcat_pme: Optional[List[IndPriorAtomInput]] = None - ki: Optional[List[IndPriorAtomInput]] = None - psi: Optional[List[IndPriorAtomInput]] = None - dissociation_constant: Optional[List[IndPriorAtomInput]] = None - transfer_constant: Optional[List[IndPriorAtomInput]] = None - conc_unbalanced: Optional[List[IndPriorAtomInput]] = None - drain: Optional[List[IndPriorAtomInput]] = None - conc_enzyme: Optional[List[IndPriorAtomInput]] = None - conc_pme: Optional[List[IndPriorAtomInput]] = None diff --git a/maud/getting_parameters.py b/maud/getting_parameters.py deleted file mode 100644 index e637709f..00000000 --- a/maud/getting_parameters.py +++ /dev/null @@ -1,249 +0,0 @@ -"""Provides function get_maud_parameters.""" - -from typing import List, Tuple - -from maud.data_model.experiment import Experiment, MeasurementType -from maud.data_model.hardcoding import ID_SEPARATOR -from maud.data_model.kinetic_model import KineticModel, ReactionMechanism -from maud.data_model.maud_init import InitInput -from maud.data_model.maud_parameter import ( - ConcEnzymeTest, - ConcEnzymeTrain, - ConcPmeTest, - ConcPmeTrain, - ConcUnbalancedTest, - ConcUnbalancedTrain, - Dgf, - DissociationConstant, - DrainTest, - DrainTrain, - Kcat, - KcatPme, - Ki, - Km, - ParameterSet, - PsiTest, - PsiTrain, - TransferConstant, -) -from maud.data_model.prior_input import PriorInput - -AllostericCoords = Tuple[List[str], List[str], List[str], List[str], List[str]] -KmCoords = Tuple[List[str], List[str], 
List[str], List[str]] -KcatCoords = Tuple[List[str], List[str], List[str]] -KiCoords = Tuple[List[str], List[str], List[str], List[str], List[str]] - - -def get_km_coords(kinetic_model: KineticModel) -> KmCoords: - """Get ids and split ids for a model's Kms.""" - km_ids = [] - km_enzs = [] - km_mets = [] - km_cpts = [] - for er in kinetic_model.ers: - rxn = [r for r in kinetic_model.reactions if r.id == er.reaction_id][0] - enz = [e for e in kinetic_model.enzymes if e.id == er.enzyme_id][0] - mic_ids = ( - list(rxn.stoichiometry.keys()) - if rxn.mechanism != ReactionMechanism.irreversible_michaelis_menten - else [k for k, v in rxn.stoichiometry.items() if v < 0] - ) - for mic_id in mic_ids: - km_id = ID_SEPARATOR.join([enz.id, mic_id]) - met_id, cpt_id = mic_id.split(ID_SEPARATOR) - if km_id not in km_ids: - km_ids.append(km_id) - km_enzs.append(enz.id) - km_mets.append(met_id) - km_cpts.append(cpt_id) - return km_ids, km_enzs, km_mets, km_cpts - - -def get_dc_coords(kinetic_model: KineticModel) -> AllostericCoords: - """Get ids and split ids for a model's dissociation constants.""" - ids = [] - enzs = [] - mets = [] - cpts = [] - mts = [] - if kinetic_model.allosteries is not None: - for a in kinetic_model.allosteries: - ids.append(a.id) - enzs.append(a.enzyme_id) - mets.append(a.metabolite_id) - cpts.append(a.compartment_id) - mts.append(a.modification_type.name) - return ids, enzs, mets, cpts, mts - - -def get_ci_coords(kinetic_model: KineticModel) -> KiCoords: - """Get ids and split ids for a model's competitive inhibition constants.""" - ids = [] - enzs = [] - rxns = [] - mets = [] - cpts = [] - if kinetic_model.competitive_inhibitions is not None: - for ci in kinetic_model.competitive_inhibitions: - ids.append(ci.id) - enzs.append(ci.enzyme_id) - rxns.append(ci.reaction_id) - mets.append(ci.metabolite_id) - cpts.append(ci.compartment_id) - return ids, enzs, rxns, mets, cpts - - -def get_kcat_coords(kinetic_model: KineticModel) -> KcatCoords: - """Get ids and split ids for a model's Kcats.""" - ids = [] - enzs = [] - rxns = [] - for er in kinetic_model.ers: - ids.append(er.id) - enzs.append(er.enzyme_id) - rxns.append(er.reaction_id) - return ids, enzs, rxns - - -def get_maud_parameters( - kmod: KineticModel, - experiments: List[Experiment], - pi: PriorInput, - ii: InitInput, -): - """Get a ParameterSet object from a KineticModel and a MeasurementSet.""" - km_ids, km_enzs, km_mets, km_cpts = get_km_coords(kmod) - dc_ids, dc_enzs, dc_mets, dc_cpts, dc_mts = get_dc_coords(kmod) - ci_ids, ci_enzs, ci_rxns, ci_mets, ci_cpts = get_ci_coords(kmod) - enzyme_ids = [e.id for e in kmod.enzymes] - kcat_ids, kcat_enzs, kcat_rxns = get_kcat_coords(kmod) - allosteric_enzyme_ids = ( - [e.id for e in kmod.allosteric_enzymes] - if kmod.allosteric_enzymes is not None - else [] - ) - metabolite_ids = [m.id for m in kmod.metabolites] - phos_modifying_enzymes = ( - [p.modifying_enzyme_id for p in kmod.phosphorylations] - if kmod.phosphorylations is not None - else [] - ) - drain_ids = [ - d.id for d in kmod.reactions if d.mechanism == ReactionMechanism.drain - ] - unbalanced_mic_ids, unbalanced_mic_mets, unbalanced_mic_cpts = map( - list, - zip( - *[ - [m.id, m.metabolite_id, m.compartment_id] - for m in kmod.mics - if not m.balanced - ] - ), - ) - exp_ids_train = [e.id for e in experiments if e.is_train] - exp_ids_test = [e.id for e in experiments if e.is_test] - return ParameterSet( - dgf=Dgf([metabolite_ids], [[metabolite_ids]], pi.dgf, ii.dgf), - km=Km([km_ids], [[km_enzs, km_mets, km_cpts]], pi.km, 
ii.km), - kcat=Kcat([kcat_ids], [[kcat_enzs, kcat_rxns]], pi.kcat, ii.kcat), - ki=Ki([ci_ids], [[ci_enzs, ci_rxns, ci_mets, ci_cpts]], pi.ki, ii.ki), - dissociation_constant=DissociationConstant( - [dc_ids], - [[dc_enzs, dc_mets, dc_cpts, dc_mts]], - pi.dissociation_constant, - ii.dissociation_constant, - ), - transfer_constant=TransferConstant( - [allosteric_enzyme_ids], - [[allosteric_enzyme_ids]], - pi.transfer_constant, - ii.transfer_constant, - ), - kcat_pme=KcatPme( - [phos_modifying_enzymes], - [[phos_modifying_enzymes]], - pi.kcat_pme, - ii.kcat_pme, - ), - drain_train=DrainTrain( - [exp_ids_train, drain_ids], - [[exp_ids_train], [drain_ids]], - pi.drain, - ii.drain, - ), - drain_test=DrainTest( - [exp_ids_train, drain_ids], - [[exp_ids_train], [drain_ids]], - pi.drain, - ii.drain, - ), - conc_enzyme_train=ConcEnzymeTrain( - [exp_ids_train, enzyme_ids], - [[exp_ids_train], [enzyme_ids]], - pi.conc_enzyme, - ii.conc_enzyme, - [ - m - for e in experiments - for m in e.measurements - if e.is_train and m.target_type == MeasurementType.ENZYME - ], - ), - conc_enzyme_test=ConcEnzymeTest( - [exp_ids_train, enzyme_ids], - [[exp_ids_train], [enzyme_ids]], - pi.conc_enzyme, - ii.conc_enzyme, - [ - m - for e in experiments - for m in e.measurements - if e.is_test and m.target_type == MeasurementType.ENZYME - ], - ), - conc_unbalanced_train=ConcUnbalancedTrain( - [exp_ids_train, unbalanced_mic_ids], - [ - [exp_ids_train], - [unbalanced_mic_mets, unbalanced_mic_cpts], - ], - pi.conc_unbalanced, - ii.conc_unbalanced, - [ - m - for e in experiments - for m in e.measurements - if e.is_train and m.target_type == MeasurementType.MIC - ], - ), - conc_unbalanced_test=ConcUnbalancedTest( - [exp_ids_train, unbalanced_mic_ids], - [ - [exp_ids_train], - [unbalanced_mic_mets, unbalanced_mic_cpts], - ], - pi.conc_unbalanced, - ii.conc_unbalanced, - [ - m - for e in experiments - for m in e.measurements - if e.is_test and m.target_type == MeasurementType.MIC - ], - ), - conc_pme_train=ConcPmeTrain( - [exp_ids_train, phos_modifying_enzymes], - [[exp_ids_train], [phos_modifying_enzymes]], - pi.conc_pme, - ii.conc_pme, - ), - conc_pme_test=ConcPmeTest( - [exp_ids_train, phos_modifying_enzymes], - [[exp_ids_train], [phos_modifying_enzymes]], - pi.conc_pme, - ii.conc_pme, - ), - psi_train=PsiTrain([exp_ids_train], [[exp_ids_test]], pi.psi, ii.psi), - psi_test=PsiTest([exp_ids_train], [[exp_ids_train]], pi.psi, ii.psi), - ) diff --git a/maud/getting_priors.py b/maud/getting_priors.py deleted file mode 100644 index 9fc4a304..00000000 --- a/maud/getting_priors.py +++ /dev/null @@ -1,136 +0,0 @@ -"""Functions for creating prior objects from PriorInput objects. - -This module handles setting priors to default values and assigning priors -consistently with their ids. - -""" - -from typing import List, Optional, Tuple, Union - -import numpy as np -import pandas as pd - -from maud.data_model.hardcoding import ID_SEPARATOR -from maud.data_model.maud_parameter import IdComponent -from maud.data_model.prior import IndPrior1d, IndPrior2d, PriorMVN -from maud.data_model.prior_input import IndPriorAtomInput, PriorMVNInput -from maud.utility_functions import ( - get_lognormal_parameters_from_quantiles, - get_normal_parameters_from_quantiles, -) - - -def get_loc_and_scale( - ipai: IndPriorAtomInput, non_negative: bool -) -> Tuple[float, float]: - """Get location and scale from a user input prior. - - This is non-trivial because of the "exploc" input and the option to input - priors as percentiles. 
- - """ - if ipai.location is not None and ipai.scale is not None: - return ipai.location, ipai.scale - elif ipai.exploc is not None and ipai.scale is not None: - return np.log(ipai.exploc), ipai.scale - elif ipai.pct1 is not None and ipai.pct99 is not None: - quantfunc = ( - get_lognormal_parameters_from_quantiles - if non_negative - else get_normal_parameters_from_quantiles - ) - return quantfunc(ipai.pct1, 0.01, ipai.pct99, 0.99) - else: - raise ValueError("Incorrect prior input.") - - -def unpack_ind_prior_atom_input( - ipai: IndPriorAtomInput, - id_components: List[List[IdComponent]], - non_negative: bool, -) -> Tuple[List[str], float, float]: - """Get a 0d prior from a user input.""" - ids = [ - ID_SEPARATOR.join([getattr(ipai, c) for c in idci]) - for idci in id_components - ] - loc, scale = get_loc_and_scale(ipai, non_negative) - return ids, loc, scale - - -def get_ind_prior_1d( - pi: Optional[List[IndPriorAtomInput]], - ids: List[str], - id_components: List[List[IdComponent]], - non_negative: bool, - default_loc: float, - default_scale: float, -) -> IndPrior1d: - """Get an independent 1d prior from a prior input and StanVariable.""" - if len(ids) == 0: - return IndPrior1d(location=[], scale=[]) - loc_series = pd.Series(default_loc, index=ids) - scale_series = pd.Series(default_scale, index=ids) - if pi is not None: - for ipai in pi: - ids_i, loc_i, scale_i = unpack_ind_prior_atom_input( - ipai, id_components, non_negative - ) - if ids_i[0] in loc_series.index: - loc_series.update({ids_i[0]: loc_i}) - scale_series.update({ids_i[0]: scale_i}) - return IndPrior1d(loc_series.tolist(), scale_series.tolist()) - - -def get_ind_prior_2d( - pi: Optional[List[IndPriorAtomInput]], - ids: List[List[str]], - id_components: List[List[IdComponent]], - non_negative: bool, - default_loc: float, - default_scale: float, -) -> IndPrior2d: - """Get an independent 2d prior from a prior input and StanVariable.""" - if any(len(ids_i) == 0 for ids_i in ids): - return IndPrior2d(location=[[]], scale=[[]]) - loc_df = pd.DataFrame(float(default_loc), index=ids[0], columns=ids[1]) - scale_df = pd.DataFrame(float(default_scale), index=ids[0], columns=ids[1]) - if pi is not None: - for ipai in pi: - ids_i, loc_i, scale_i = unpack_ind_prior_atom_input( - ipai, id_components, non_negative - ) - if ids_i[0] in loc_df.index and ids_i[1] in loc_df.columns: - loc_df.loc[ids_i[0], ids_i[1]] = loc_i - scale_df.loc[ids_i[0], ids_i[1]] = scale_i - return IndPrior2d(loc_df.values.tolist(), scale_df.values.tolist()) - - -def get_mvn_prior( - pi: Optional[Union[List[IndPriorAtomInput], PriorMVNInput]], - ids: List[str], - id_components: List[List[IdComponent]], - non_negative: bool, - default_loc: float, - default_scale: float, -) -> PriorMVN: - """Get a multivariate normal prior from a prior input and StanVariable.""" - loc_series = pd.Series(default_loc, index=ids) - cov_df = pd.DataFrame( - np.diagflat(np.tile(default_scale, len(ids))), index=ids, columns=ids - ) - if isinstance(pi, PriorMVNInput): - loc_series = pd.Series(pi.mean_vector, index=pi.ids).reindex(ids) - cov_df = ( - pd.DataFrame(pi.covariance_matrix, index=pi.ids, columns=pi.ids) - .reindex(ids) - .reindex(columns=ids) - ) - elif isinstance(pi, list): - for ipai in pi: - ids_i, loc_i, cov_ii = unpack_ind_prior_atom_input( - ipai, id_components, non_negative - ) - loc_series.loc[ids_i[0]] = loc_i - cov_df.loc[ids_i[0], ids_i[0]] = cov_ii - return PriorMVN(loc_series.tolist(), cov_df.values.tolist()) diff --git a/maud/getting_stan_inputs.py 
b/maud/getting_stan_inputs.py index 52eddef6..008225b3 100644 --- a/maud/getting_stan_inputs.py +++ b/maud/getting_stan_inputs.py @@ -1,6 +1,5 @@ """Provides function get_stan_inputs for generating Stan input dictionaries.""" -from dataclasses import fields from typing import Dict, Iterable, List, Tuple, Union from scipy.stats import gmean @@ -21,8 +20,9 @@ ReactionMechanism, ) from maud.data_model.maud_config import MaudConfig -from maud.data_model.maud_parameter import ParameterSet -from maud.data_model.prior import IndPrior1d, IndPrior2d +from maud.data_model.maud_parameter import MaudParameter +from maud.data_model.parameter_set import ParameterSet +from maud.data_model.prior import PriorMVN def get_stan_inputs( @@ -66,17 +66,20 @@ def get_stan_inputs( def get_prior_inputs(parameters: ParameterSet) -> Tuple[Dict, Dict]: """Get the priors component of an input to Maud's Stan model.""" + params = [ + getattr(parameters, p) + for p in parameters.model_computed_fields.keys() + if isinstance(getattr(parameters, p), MaudParameter) + ] ind_priors_train = { f"priors_{p.name}": [p.prior.location, p.prior.scale] - for p in map(lambda f: getattr(parameters, f.name), fields(parameters)) - if p.prior_in_train_model - and (isinstance(p.prior, IndPrior1d) or isinstance(p.prior, IndPrior2d)) + for p in params + if p.prior_in_train_model and not isinstance(p.prior, PriorMVN) } ind_priors_test = { f"priors_{p.name}": [p.prior.location, p.prior.scale] - for p in map(lambda f: getattr(parameters, f.name), fields(parameters)) - if p.prior_in_test_model - and (isinstance(p.prior, IndPrior1d) or isinstance(p.prior, IndPrior2d)) + for p in params + if p.prior_in_test_model and not isinstance(p.prior, PriorMVN) } assert hasattr(parameters.dgf.prior, "covariance_matrix") dgf_priors = { diff --git a/maud/loading_maud_inputs.py b/maud/loading_maud_inputs.py index 75252978..d9ec41fa 100644 --- a/maud/loading_maud_inputs.py +++ b/maud/loading_maud_inputs.py @@ -8,7 +8,7 @@ from maud.data_model.maud_config import MaudConfig from maud.data_model.maud_init import InitInput from maud.data_model.maud_input import MaudInput -from maud.data_model.prior_input import PriorInput +from maud.data_model.parameter_input import ParameterSetInput from maud.parsing_kinetic_models import parse_kinetic_model @@ -27,7 +27,7 @@ def load_maud_input(data_path: str) -> MaudInput: kinetic_model_path = os.path.join(data_path, config.kinetic_model_file) experiments_path = os.path.join(data_path, config.experiments_file) raw_kinetic_model = toml.load(kinetic_model_path) - prior_input_path = os.path.join(data_path, config.priors_file) + parameter_input_path = os.path.join(data_path, config.priors_file) if config.user_inits_file is not None: init_input_path = os.path.join(data_path, config.user_inits_file) init_input = InitInput(**toml.load(init_input_path)) @@ -38,11 +38,11 @@ def load_maud_input(data_path: str) -> MaudInput: experiments = [ parse_experiment(e) for e in toml.load(experiments_path)["experiment"] ] - prior_input = PriorInput(**toml.load(prior_input_path)) + parameter_set_input = ParameterSetInput(**toml.load(parameter_input_path)) return MaudInput( config=config, kinetic_model=kinetic_model, - prior_input=prior_input, + parameter_set_input=parameter_set_input, experiments=experiments, init_input=init_input, ) diff --git a/maud/parsing_kinetic_models.py b/maud/parsing_kinetic_models.py index 6e50f3a2..9902c40e 100644 --- a/maud/parsing_kinetic_models.py +++ b/maud/parsing_kinetic_models.py @@ -28,19 +28,12 @@ def 
parse_kinetic_model(raw: dict) -> KineticModel: """ now = datetime.now().strftime("%Y%m%d%H%M%S") name = read_with_fallback("name", raw, now) - compartments = [ - Compartment(c["id"], c["name"], c["volume"]) for c in raw["compartment"] - ] + compartments = [Compartment(**c) for c in raw["compartment"]] mics = [ - MetaboliteInCompartment( - mic["metabolite_id"], mic["compartment_id"], mic["balanced"] - ) + MetaboliteInCompartment(**mic) for mic in raw["metabolite_in_compartment"] ] - ers = [ - EnzymeReaction(er["enzyme_id"], er["reaction_id"]) - for er in raw["enzyme_reaction"] - ] + ers = [EnzymeReaction(**er) for er in raw["enzyme_reaction"]] reactions = [ Reaction( id=r["id"], diff --git a/pyproject.toml b/pyproject.toml index fd11285b..c3e81bd0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "cmdstanpy >= 1.2.0", "click", "depinfo == 1.7.0", - "pydantic == 1.9.0", + "pydantic >= 2.0", ] [project.entry-points.console_scripts]
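
The idioms this patch applies throughout the data model (pydantic v1 dataclasses with `__post_init__`-filled fields, `validator` and `root_validator` becoming pydantic v2 `BaseModel`s with `computed_field`, `field_validator` and `model_validator(mode="after")`) can be seen side by side in a minimal, self-contained sketch. The `Interval` model below is purely illustrative and not part of Maud; it only assumes pydantic >= 2.0, as required by the pyproject.toml change above.

"""Minimal sketch of the pydantic v1 -> v2 idioms used in this patch.

The Interval model is illustrative only; it is not part of Maud.
"""

import math

from pydantic import BaseModel, computed_field, field_validator, model_validator


class Interval(BaseModel):
    """A closed interval, with its width exposed as a computed field."""

    lower: float
    upper: float

    # v1: @validator("upper"); v2: @field_validator("upper"),
    # written without @classmethod to match the style used in this patch.
    @field_validator("upper")
    def upper_is_finite(cls, v):
        """Check that the upper bound is finite."""
        assert math.isfinite(v), "upper must be finite"
        return v

    # v1: @root_validator(pre=False, skip_on_failure=True) receiving a values
    # dict; v2: @model_validator(mode="after") receiving and returning self.
    @model_validator(mode="after")
    def bounds_are_ordered(self):
        """Check that lower does not exceed upper."""
        if self.lower > self.upper:
            raise ValueError("lower must not exceed upper")
        return self

    # v1: a Field(init=False, exclude=True) attribute assigned in __post_init__;
    # v2: a @computed_field, included in model_dump() output.
    @computed_field
    def width(self) -> float:
        """Derived quantity, analogous to target_id or id in the patch."""
        return self.upper - self.lower


if __name__ == "__main__":
    iv = Interval(lower=1.0, upper=3.5)
    print(iv.model_dump())  # -> {'lower': 1.0, 'upper': 3.5, 'width': 2.5}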
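
The new `get_pia_loc` and `get_pia_scale` helpers accept a prior given either directly (location or exploc plus scale) or as 1st and 99th percentiles, in which case they delegate to `get_normal_parameters_from_quantiles` or `get_lognormal_parameters_from_quantiles` from `maud.utility_functions`. Those utilities are not shown in this diff; the sketch below illustrates the standard quantile-matching arithmetic they presumably perform. The function names in the sketch are hypothetical and the example values are made up for illustration.

"""Sketch of deriving normal/lognormal parameters from two quantiles.

Illustration only; not Maud's implementation of the quantile utilities.
"""

import numpy as np
from scipy.stats import norm


def normal_params_from_quantiles(x1, p1, x2, p2):
    """Return (mu, sigma) of the normal whose p1 and p2 quantiles are x1 and x2."""
    z1, z2 = norm.ppf(p1), norm.ppf(p2)
    sigma = (x2 - x1) / (z2 - z1)
    mu = x1 - sigma * z1
    return mu, sigma


def lognormal_params_from_quantiles(x1, p1, x2, p2):
    """Same calculation on the log scale, for non-negative parameters."""
    return normal_params_from_quantiles(np.log(x1), p1, np.log(x2), p2)


if __name__ == "__main__":
    # A hypothetical non-negative parameter specified with pct1=0.01 and
    # pct99=100, as a ParameterInputAtom would allow, maps to a lognormal
    # location of about 0.0 and a scale of about 1.98.
    loc, scale = lognormal_params_from_quantiles(0.01, 0.01, 100.0, 0.99)
    print(round(loc, 3), round(scale, 3))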