From c8ddebdd7aff6b39596c30ac18d61b9c57a10013 Mon Sep 17 00:00:00 2001 From: teddygroves Date: Wed, 27 Nov 2024 12:21:30 +0100 Subject: [PATCH 01/27] WIP better ui --- src/enzax/kinetic_model.py | 69 +++- src/enzax/rate_equation.py | 20 +- src/enzax/rate_equations/__init__.py | 15 +- src/enzax/rate_equations/generalised_mwc.py | 187 +++++----- src/enzax/rate_equations/michaelis_menten.py | 358 +++++++++++-------- 5 files changed, 389 insertions(+), 260 deletions(-) diff --git a/src/enzax/kinetic_model.py b/src/enzax/kinetic_model.py index bb8e7fc..9f60d28 100644 --- a/src/enzax/kinetic_model.py +++ b/src/enzax/kinetic_model.py @@ -15,15 +15,59 @@ class KineticModelStructure(eqx.Module): """Structural information about a kinetic model.""" S: Float[Array, " s r"] - balanced_species: Int[Array, " n_balanced"] - unbalanced_species: Int[Array, " n_unbalanced"] + species: list[str] + reactions: list[str] + balanced_species: list[str] + species_to_dgf_ix: Int[Array, " s"] + balanced_species_ix: Int[Array, " b"] + unbalanced_species_ix: Int[Array, " u"] + + def __init__( + self, + S, + species, + reactions, + balanced_species, + species_to_dgf_ix, + ): + self.S = S + self.species = species + self.reactions = reactions + self.balanced_species = balanced_species + self.species_to_dgf_ix = species_to_dgf_ix + self.balanced_species_ix = jnp.array( + [i for i, s in enumerate(species) if s in balanced_species], + dtype=jnp.int16, + ) + self.unbalanced_species_ix = jnp.array( + [i for i, s in enumerate(species) if s not in balanced_species], + dtype=jnp.int16, + ) + + +class RateEquationKineticModelStructure(KineticModelStructure): + rate_equations: list[RateEquation] + + def __init__( + self, + S, + species, + reactions, + balanced_species, + species_to_dgf_ix, + rate_equations, + ): + super().__init__( + S, species, reactions, balanced_species, species_to_dgf_ix + ) + self.rate_equations = rate_equations class KineticModel(eqx.Module, ABC): """Abstract base class for kinetic models.""" parameters: PyTree - structure: KineticModelStructure + structure: KineticModelStructure = eqx.field(static=True) @abstractmethod def flux( @@ -44,7 +88,7 @@ def dcdt( """ # Noqa: E501 sv = self.structure.S @ self.flux(conc) - return sv[self.structure.balanced_species] + return sv[self.structure.balanced_species_ix] class RateEquationModel(KineticModel): @@ -64,10 +108,17 @@ def flux( """ # Noqa: E501 conc = jnp.zeros(self.structure.S.shape[0]) - conc = conc.at[self.structure.balanced_species].set(conc_balanced) - conc = conc.at[self.structure.unbalanced_species].set( + conc = conc.at[self.structure.balanced_species_ix].set(conc_balanced) + conc = conc.at[self.structure.unbalanced_species_ix].set( jnp.exp(self.parameters.log_conc_unbalanced) ) - t = [f(conc, self.parameters) for f in self.rate_equations] - out = jnp.array(t) - return out + flux_list = [] + for i, rate_equation in enumerate(self.structure.rate_equations): + ipt = rate_equation.get_input( + self.parameters, + i, + self.structure.S, + self.structure.species_to_dgf_ix, + ) + flux_list.append(rate_equation(conc, ipt)) + return jnp.array(flux_list) diff --git a/src/enzax/rate_equation.py b/src/enzax/rate_equation.py index aa89443..e3dbe24 100644 --- a/src/enzax/rate_equation.py +++ b/src/enzax/rate_equation.py @@ -1,10 +1,11 @@ """Module containing rate equations for enzyme-catalysed reactions.""" from abc import ABC, abstractmethod -from equinox import Module +import numpy as np +from equinox import Module from jaxtyping import Array, Float, PyTree, Scalar - +from numpy.typing import NDArray ConcArray = Float[Array, " n"] @@ -12,8 +13,19 @@ class RateEquation(Module, ABC): """Abstract definition of a rate equation. - A rate equation is an equinox [Module](https://docs.kidger.site/equinox/api/module/module/) with a `__call__` method that takes in a 1 dimensional array of concentrations and an arbitrary PyTree of parameters, returning a scalar value representing a single flux. + A rate equation is an equinox [Module](https://docs.kidger.site/equinox/api/module/module/) with a `__call__` method that takes in a 1 dimensional array of concentrations and an arbitrary PyTree of other inputs, returning a scalar value representing a single flux. """ # Noqa: E501 @abstractmethod - def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: ... + def get_input( + self, + parameters: PyTree, + rxn_ix: int, + S: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ) -> PyTree: ... + + @abstractmethod + def __call__( + self, conc: ConcArray, rate_equation_input: PyTree + ) -> Scalar: ... diff --git a/src/enzax/rate_equations/__init__.py b/src/enzax/rate_equations/__init__.py index a0da47d..95f03d1 100644 --- a/src/enzax/rate_equations/__init__.py +++ b/src/enzax/rate_equations/__init__.py @@ -1,19 +1,18 @@ from enzax.rate_equations.michaelis_menten import ( ReversibleMichaelisMenten, IrreversibleMichaelisMenten, - MichaelisMenten, -) -from enzax.rate_equations.generalised_mwc import ( - AllostericReversibleMichaelisMenten, - AllostericIrreversibleMichaelisMenten, ) + +# from enzax.rate_equations.generalised_mwc import ( +# AllostericReversibleMichaelisMenten, +# AllostericIrreversibleMichaelisMenten, +# ) from enzax.rate_equations.drain import Drain AVAILABLE_RATE_EQUATIONS = [ ReversibleMichaelisMenten, IrreversibleMichaelisMenten, - MichaelisMenten, - AllostericReversibleMichaelisMenten, - AllostericIrreversibleMichaelisMenten, + # AllostericReversibleMichaelisMenten, + # AllostericIrreversibleMichaelisMenten, Drain, ] diff --git a/src/enzax/rate_equations/generalised_mwc.py b/src/enzax/rate_equations/generalised_mwc.py index 90f867e..224f588 100644 --- a/src/enzax/rate_equations/generalised_mwc.py +++ b/src/enzax/rate_equations/generalised_mwc.py @@ -1,94 +1,93 @@ -from jax import numpy as jnp -from jaxtyping import Array, Float, Int, PyTree, Scalar - -from enzax.rate_equation import ConcArray -from enzax.rate_equations.michaelis_menten import ( - IrreversibleMichaelisMenten, - MichaelisMenten, - ReversibleMichaelisMenten, -) - - -def generalised_mwc_effect( - conc_inhibitor: Float[Array, " n_inhibition"], - dc_inhibitor: Float[Array, " n_inhibition"], - conc_activator: Float[Array, " n_activation"], - dc_activator: Float[Array, " n_activation"], - free_enzyme_ratio: Scalar, - tc: Scalar, - subunits: int, -) -> Scalar: - """Get the allosteric effect on a rate. - - The equation is generalised Monod Wyman Changeux model as presented in Popova and Sel'kov 1975: https://doi.org/10.1016/0014-5793(75)80034-2. - - """ # noqa: E501 - qnum = 1 + jnp.sum(conc_inhibitor / dc_inhibitor) - qdenom = 1 + jnp.sum(conc_activator / dc_activator) - out = 1.0 / (1 + tc * (free_enzyme_ratio * qnum / qdenom) ** subunits) - return out - - -class GeneralisedMWC(MichaelisMenten): - """Mixin class for allosteric rate laws, assuming generalised MWC kinetics. - - See Popova and Sel'kov 1975 for the rate law: https://doi.org/10.1016/0014-5793(75)80034-2. - - Note that it is assumed there is a free_enzyme_ratio method available - that is why this is a subclass of MichaelisMenten rather than RateEquation. - """ # noqa: E501 - - subunits: int - tc_ix: int - ix_dc_activation: Int[Array, " n_activation"] - ix_dc_inhibition: Int[Array, " n_inhibition"] - species_activation: Int[Array, " n_activation"] - species_inhibition: Int[Array, " n_inhibition"] - - def get_tc(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_transfer_constant[self.tc_ix]) - - def get_dc_activation(self, parameters: PyTree) -> Scalar: - return jnp.exp( - parameters.log_dissociation_constant[self.ix_dc_activation], - ) - - def get_dc_inhibition(self, parameters: PyTree) -> Scalar: - return jnp.exp( - parameters.log_dissociation_constant[self.ix_dc_inhibition], - ) - - def allosteric_effect(self, conc: ConcArray, parameters: PyTree) -> Scalar: - return generalised_mwc_effect( - conc_inhibitor=conc[self.species_inhibition], - conc_activator=conc[self.species_activation], - free_enzyme_ratio=self.free_enzyme_ratio(conc, parameters), - tc=self.get_tc(parameters), - dc_inhibitor=self.get_dc_inhibition(parameters), - dc_activator=self.get_dc_activation(parameters), - subunits=self.subunits, - ) - - -class AllostericIrreversibleMichaelisMenten( - GeneralisedMWC, IrreversibleMichaelisMenten -): - """A reaction with irreversible Michaelis Menten kinetics and allostery.""" - - def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar: - """The flux of an irreversible allosteric Michaelis Menten reaction.""" - allosteric_effect = self.allosteric_effect(conc, parameters) - non_allosteric_rate = super().__call__(conc, parameters) - return non_allosteric_rate * allosteric_effect - - -class AllostericReversibleMichaelisMenten( - GeneralisedMWC, - ReversibleMichaelisMenten, -): - """A reaction with reversible Michaelis Menten kinetics and allostery.""" - - def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: - """The flux of an allosteric reversible Michaelis Menten reaction.""" - allosteric_effect = self.allosteric_effect(conc, parameters) - non_allosteric_rate = super().__call__(conc, parameters) - return non_allosteric_rate * allosteric_effect +# from jax import numpy as jnp +# from jaxtyping import Array, Float, Int, PyTree, Scalar + +# from enzax.rate_equation import ConcArray +# from enzax.rate_equations.michaelis_menten import ( +# IrreversibleMichaelisMenten, +# ReversibleMichaelisMenten, +# ) + + +# def generalised_mwc_effect( +# conc_inhibitor: Float[Array, " n_inhibition"], +# dc_inhibitor: Float[Array, " n_inhibition"], +# conc_activator: Float[Array, " n_activation"], +# dc_activator: Float[Array, " n_activation"], +# free_enzyme_ratio: Scalar, +# tc: Scalar, +# subunits: int, +# ) -> Scalar: +# """Get the allosteric effect on a rate. + +# The equation is generalised Monod Wyman Changeux model as presented in Popova and Sel'kov 1975: https://doi.org/10.1016/0014-5793(75)80034-2. + +# """ # noqa: E501 +# qnum = 1 + jnp.sum(conc_inhibitor / dc_inhibitor) +# qdenom = 1 + jnp.sum(conc_activator / dc_activator) +# out = 1.0 / (1 + tc * (free_enzyme_ratio * qnum / qdenom) ** subunits) +# return out + + +# class GeneralisedMWC: +# """Mixin class for allosteric rate laws, assuming generalised MWC kinetics. + +# See Popova and Sel'kov 1975 for the rate law: https://doi.org/10.1016/0014-5793(75)80034-2. + +# Note that it is assumed there is a free_enzyme_ratio method available - that is why this is a subclass of MichaelisMenten rather than RateEquation. +# """ # noqa: E501 + +# subunits: int +# tc_ix: int +# ix_dc_activation: Int[Array, " n_activation"] +# ix_dc_inhibition: Int[Array, " n_inhibition"] +# species_activation: Int[Array, " n_activation"] +# species_inhibition: Int[Array, " n_inhibition"] + +# def get_tc(self, parameters: PyTree) -> Scalar: +# return jnp.exp(parameters.log_transfer_constant[self.tc_ix]) + +# def get_dc_activation(self, parameters: PyTree) -> Scalar: +# return jnp.exp( +# parameters.log_dissociation_constant[self.ix_dc_activation], +# ) + +# def get_dc_inhibition(self, parameters: PyTree) -> Scalar: +# return jnp.exp( +# parameters.log_dissociation_constant[self.ix_dc_inhibition], +# ) + +# def allosteric_effect(self, conc: ConcArray, parameters: PyTree) -> Scalar: +# return generalised_mwc_effect( +# conc_inhibitor=conc[self.species_inhibition], +# conc_activator=conc[self.species_activation], +# free_enzyme_ratio=self.free_enzyme_ratio(conc, parameters), +# tc=self.get_tc(parameters), +# dc_inhibitor=self.get_dc_inhibition(parameters), +# dc_activator=self.get_dc_activation(parameters), +# subunits=self.subunits, +# ) + + +# class AllostericIrreversibleMichaelisMenten( +# GeneralisedMWC, IrreversibleMichaelisMenten +# ): +# """A reaction with irreversible Michaelis Menten kinetics and allostery.""" + +# def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar: +# """The flux of an irreversible allosteric Michaelis Menten reaction.""" +# allosteric_effect = self.allosteric_effect(conc, parameters) +# non_allosteric_rate = super().__call__(conc, parameters) +# return non_allosteric_rate * allosteric_effect + + +# class AllostericReversibleMichaelisMenten( +# GeneralisedMWC, +# ReversibleMichaelisMenten, +# ): +# """A reaction with reversible Michaelis Menten kinetics and allostery.""" + +# def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: +# """The flux of an allosteric reversible Michaelis Menten reaction.""" +# allosteric_effect = self.allosteric_effect(conc, parameters) +# non_allosteric_rate = super().__call__(conc, parameters) +# return non_allosteric_rate * allosteric_effect diff --git a/src/enzax/rate_equations/michaelis_menten.py b/src/enzax/rate_equations/michaelis_menten.py index 55cf1d0..91011f8 100644 --- a/src/enzax/rate_equations/michaelis_menten.py +++ b/src/enzax/rate_equations/michaelis_menten.py @@ -1,110 +1,108 @@ -from abc import abstractmethod +import equinox as eqx +import numpy as np from jax import numpy as jnp from jaxtyping import Array, Float, Int, PyTree, Scalar +from numpy.typing import NDArray + +from enzax.rate_equation import RateEquation + + +class IrreversibleMichaelisMentenInput(eqx.Module): + kcat: Scalar + enzyme: Scalar + ix_ki_species: NDArray[np.int16] + ki: Float[Array, " n_ki"] + ix_substrate: NDArray[np.int16] + substrate_kms: Float[Array, " n_substrate"] + substrate_stoichiometry: NDArray[np.float64] + + +def get_irreversible_michaelis_menten_input( + parameters: PyTree, + rxn_ix: int, + S: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ci_ix: NDArray[np.int16], +) -> IrreversibleMichaelisMentenInput: + Sj = S[:, rxn_ix] + ix_substrate = np.argwhere(Sj < 0.0) + return IrreversibleMichaelisMentenInput( + kcat=jnp.exp(parameters.log_kcat[rxn_ix]), + enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), + ix_substrate=ix_substrate, + substrate_kms=jnp.exp(parameters.log_km[rxn_ix]), + substrate_stoichiometry=Sj[ix_substrate], + ix_ki_species=ci_ix, + ki=parameters.jnp.exp(parameters.log_ki[rxn_ix]), + ) + -from enzax.rate_equation import RateEquation, ConcArray +class ReversibleMichaelisMentenInput(eqx.Module): + kcat: Scalar + enzyme: Scalar + ki: Float[Array, " n_ki"] + substrate_kms: Float[Array, " n_substrate"] + product_kms: Float[Array, " n_product"] + dgf: Float[Array, " n_reactant"] + temperature: Scalar + ix_ki_species: NDArray[np.int16] + ix_reactant: NDArray[np.int16] + ix_substrate: NDArray[np.int16] + ix_product: NDArray[np.int16] + reactant_stoichiometry: NDArray[np.float64] + substrate_stoichiometry: NDArray[np.float64] + product_stoichiometry: NDArray[np.float64] + water_stoichiometry: float + + +def get_reversible_michaelis_menten_input( + parameters: PyTree, + rxn_ix: int, + S: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ci_ix: NDArray[np.int16], + water_stoichiometry: float, +) -> ReversibleMichaelisMentenInput: + Sj = S[:, rxn_ix] + ix_reactant = np.argwhere(Sj != 0.0) + ix_substrate = np.argwhere(Sj < 0.0) + ix_product = np.argwhere(Sj > 0.0) + return ReversibleMichaelisMentenInput( + kcat=jnp.exp(parameters.log_kcat[rxn_ix]), + enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), + substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]), + product_kms=jnp.exp(parameters.log_km[rxn_ix][1]), + ki=jnp.exp(parameters.log_ki[rxn_ix]), + dgf=jnp.exp(parameters.dgf[species_to_dgf_ix][ix_reactant]), + temperature=parameters.temperature, + ix_ki_species=ci_ix, + ix_reactant=ix_reactant, + ix_substrate=ix_substrate, + ix_product=ix_product, + reactant_stoichiometry=Sj[ix_reactant], + substrate_stoichiometry=Sj[ix_substrate], + product_stoichiometry=Sj[ix_product], + water_stoichiometry=water_stoichiometry, + ) def numerator_mm( - conc: ConcArray, - km: Float[Array, " n"], - ix_substrate: Int[Array, " n_substrate"], - substrate_km_positions: Int[Array, " n_substrate"], + substrate_conc: Float[Array, " n_substrate"], + substrate_kms: Int[Array, " n_substrate"], ) -> Scalar: """Get the product of each substrate's concentration over its km. This quantity is the numerator in a Michaelis Menten reaction's rate equation """ # Noqa: E501 - return jnp.prod((conc[ix_substrate] / km[substrate_km_positions])) - - -class MichaelisMenten(RateEquation): - """Base class for Michaelis Menten rate equations. - - Subclasses need to implement the __call__ and free_enzyme_ratio methods. - - """ - - kcat_ix: int - enzyme_ix: int - km_ix: Int[Array, " n"] - ki_ix: Int[Array, " n_ki"] - reactant_stoichiometry: Float[Array, " n"] - ix_substrate: Int[Array, " n_substrate"] - ix_ki_species: Int[Array, " n_ki"] - substrate_km_positions: Int[Array, " n_substrate"] - substrate_reactant_positions: Int[Array, " n_substrate"] - - def get_kcat(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_kcat[self.kcat_ix]) - - def get_km(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_km[self.km_ix]) - - def get_ki(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_ki[self.ki_ix]) - - def get_enzyme(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_enzyme[self.enzyme_ix]) - - @abstractmethod - def free_enzyme_ratio( - self, - conc: ConcArray, - parameters: PyTree, - ) -> Scalar: ... - - -def free_enzyme_ratio_imm( - conc_sub: Float[Array, " n_substrate"], - km_sub: Float[Array, " n_substrate"], - stoich_sub: Float[Array, " n_substrate"], - ki: Float[Array, " n_ki"], - conc_inhibitor: Float[Array, " n_ki"], -) -> Scalar: - """Free enzyme ratio for irreversible Michaelis Menten reactions.""" - return 1.0 / ( - jnp.prod(((conc_sub / km_sub) + 1) ** jnp.abs(stoich_sub)) - + jnp.sum(conc_inhibitor / ki) - ) - - -class IrreversibleMichaelisMenten(MichaelisMenten): - """A reaction with irreversible Michaelis Menten kinetics.""" - - def free_enzyme_ratio(self, conc, parameters): - return free_enzyme_ratio_imm( - conc_sub=conc[self.ix_substrate], - km_sub=self.get_km(parameters)[self.substrate_km_positions], - ki=self.get_ki(parameters), - conc_inhibitor=conc[self.ix_ki_species], - stoich_sub=self.reactant_stoichiometry[ - self.substrate_reactant_positions - ], - ) - - def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar: - """Get flux of a reaction with irreversible Michaelis Menten kinetics.""" # noqa: E501 - kcat = self.get_kcat(parameters) - enzyme = self.get_enzyme(parameters) - km = self.get_km(parameters) - numerator = numerator_mm( - conc=conc, - km=km, - ix_substrate=self.ix_substrate, - substrate_km_positions=self.substrate_km_positions, - ) - free_enzyme_ratio = self.free_enzyme_ratio(conc, parameters) - return kcat * enzyme * numerator * free_enzyme_ratio + return jnp.prod((substrate_conc / substrate_kms)) def get_reversibility( - conc: Float[Array, " n"], - water_stoichiometry: Scalar, + reactant_conc: Float[Array, " n_reactant"], dgf: Float[Array, " n_reactant"], temperature: Scalar, - reactant_stoichiometry: Float[Array, " n_reactant"], - ix_reactants: Int[Array, " n_reactant"], + reactant_stoichiometry: NDArray[np.float64], + water_stoichiometry: float, ) -> Scalar: """Get the reversibility of a reaction. @@ -113,81 +111,151 @@ def get_reversibility( """ RT = temperature * 0.008314 dgf_water = -150.9 - dgr = reactant_stoichiometry @ dgf + water_stoichiometry * dgf_water - quotient = reactant_stoichiometry @ jnp.log(conc[ix_reactants]) - out = 1.0 - jnp.exp(((dgr + RT * quotient) / RT)) + dgr = ( + reactant_stoichiometry.T @ dgf + water_stoichiometry * dgf_water + ).flatten() + quotient = (reactant_stoichiometry.T @ jnp.log(reactant_conc)).flatten() + out = 1.0 - jnp.exp((dgr + RT * quotient) / RT)[0] return out +def free_enzyme_ratio_imm( + substrate_conc: Float[Array, " n_substrate"], + substrate_km: Float[Array, " n_substrate"], + ki: Float[Array, " n_ki"], + inhibitor_conc: Float[Array, " n_ki"], + substrate_stoichiometry: NDArray[np.float64], +) -> Scalar: + """Free enzyme ratio for irreversible Michaelis Menten reactions.""" + return 1.0 / ( + jnp.prod( + ((substrate_conc / substrate_km) + 1) + ** jnp.abs(substrate_stoichiometry) + ) + + jnp.sum(inhibitor_conc / ki) + ) + + def free_enzyme_ratio_rmm( - conc_sub: Float[Array, " n_substrate"], - km_sub: Float[Array, " n_substrate"], - stoich_sub: Float[Array, " n_substrate"], - conc_prod: Float[Array, " n_product"], - km_prod: Float[Array, " n_prod"], - stoich_prod: Float[Array, " n_prod"], - conc_inhibitor: Float[Array, " n_ki"], + substrate_conc: Float[Array, " n_substrate"], + product_conc: Float[Array, " n_product"], + substrate_kms: Float[Array, " n_substrate"], + product_kms: Float[Array, " n_product"], + inhibitor_conc: Float[Array, " n_ki"], ki: Float[Array, " n_ki"], + substrate_stoichiometry: NDArray[np.float64], + product_stoichiometry: NDArray[np.float64], ) -> Scalar: """The free enzyme ratio for a reversible Michaelis Menten reaction.""" return 1.0 / ( -1.0 - + jnp.prod(((conc_sub / km_sub) + 1.0) ** jnp.abs(stoich_sub)) - + jnp.prod(((conc_prod / km_prod) + 1.0) ** jnp.abs(stoich_prod)) - + jnp.sum(conc_inhibitor / ki) + + jnp.prod( + ((substrate_conc / substrate_kms) + 1.0) + ** jnp.abs(substrate_stoichiometry) + ) + + jnp.prod( + ((product_conc / product_kms) + 1.0) + ** jnp.abs(product_stoichiometry) + ) + + jnp.sum(inhibitor_conc / ki) ) -class ReversibleMichaelisMenten(MichaelisMenten): +class IrreversibleMichaelisMenten(RateEquation): + """A reaction with irreversible Michaelis Menten kinetics.""" + + competitive_inhibitor_ix: NDArray[np.int16] + + def get_input( + self, + parameters: PyTree, + rxn_ix: int, + S: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ): + return get_irreversible_michaelis_menten_input( + parameters=parameters, + rxn_ix=rxn_ix, + S=S, + species_to_dgf_ix=species_to_dgf_ix, + ci_ix=self.competitive_inhibitor_ix, + ) + + def __call__( + self, + conc: Float[Array, " n"], + imm_input: IrreversibleMichaelisMentenInput, + ) -> Scalar: + """Get flux of a reaction with irreversible Michaelis Menten kinetics.""" # noqa: E501 + numerator = numerator_mm( + substrate_conc=conc[imm_input.ix_substrate], + substrate_kms=imm_input.substrate_kms, + ) + fer = free_enzyme_ratio_imm( + substrate_conc=conc[imm_input.ix_substrate], + substrate_km=imm_input.substrate_kms, + ki=imm_input.ki, + inhibitor_conc=conc[imm_input.ix_ki_species], + substrate_stoichiometry=imm_input.substrate_stoichiometry, + ) + return imm_input.kcat * imm_input.enzyme * numerator * fer + + +class ReversibleMichaelisMenten(RateEquation): """A reaction with reversible Michaelis Menten kinetics.""" - ix_product: Int[Array, " n_product"] - ix_reactants: Int[Array, " n_reactant"] - product_reactant_positions: Int[Array, " n_product"] - product_km_positions: Int[Array, " n_product"] - water_stoichiometry: Scalar - reactant_to_dgf: Int[Array, " n_reactant"] - - def _get_dgf(self, parameters: PyTree): - return parameters.dgf[self.reactant_to_dgf] - - def free_enzyme_ratio(self, conc: ConcArray, parameters: PyTree) -> Scalar: - return free_enzyme_ratio_rmm( - conc_sub=conc[self.ix_substrate], - km_sub=self.get_km(parameters)[self.substrate_reactant_positions], - stoich_sub=self.reactant_stoichiometry[ - self.substrate_reactant_positions - ], - conc_prod=conc[self.ix_product], - km_prod=self.get_km(parameters)[self.product_reactant_positions], - stoich_prod=self.reactant_stoichiometry[ - self.product_reactant_positions - ], - conc_inhibitor=conc[self.ix_ki_species], - ki=self.get_ki(parameters), + competitive_inhibitor_ix: NDArray[np.int16] = eqx.field( + default_factory=lambda: np.array([], dtype=np.int16), static=True + ) + water_stoichiometry: float = eqx.field( + default_factory=lambda: 0.0, static=True + ) + + def get_input( + self, + parameters: PyTree, + rxn_ix: int, + S: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ): + return get_reversible_michaelis_menten_input( + parameters=parameters, + rxn_ix=rxn_ix, + S=S, + species_to_dgf_ix=species_to_dgf_ix, + ci_ix=self.competitive_inhibitor_ix, + water_stoichiometry=self.water_stoichiometry, ) - def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar: + def __call__( + self, + conc: Float[Array, " n"], + rmm_input: ReversibleMichaelisMentenInput, + ) -> Scalar: """Get flux of a reaction with reversible Michaelis Menten kinetics. :param conc: A 1D array of non-negative numbers representing concentrations of the species that the reaction produces and consumes. """ # noqa: E501 - kcat = self.get_kcat(parameters) - enzyme = self.get_enzyme(parameters) - reversibility = get_reversibility( - conc=conc, - water_stoichiometry=self.water_stoichiometry, - dgf=self._get_dgf(parameters), - temperature=parameters.temperature, - reactant_stoichiometry=self.reactant_stoichiometry, - ix_reactants=self.ix_reactants, + rev = get_reversibility( + reactant_conc=conc[rmm_input.ix_reactant], + reactant_stoichiometry=rmm_input.reactant_stoichiometry, + dgf=rmm_input.dgf, + temperature=rmm_input.temperature, + water_stoichiometry=rmm_input.water_stoichiometry, ) numerator = numerator_mm( - conc=conc, - km=self.get_km(parameters), - ix_substrate=self.ix_substrate, - substrate_km_positions=self.substrate_km_positions, + substrate_conc=conc[rmm_input.ix_substrate], + substrate_kms=rmm_input.substrate_kms, + ) + fer = free_enzyme_ratio_rmm( + substrate_conc=conc[rmm_input.ix_substrate], + product_conc=conc[rmm_input.ix_product], + inhibitor_conc=conc[rmm_input.ix_ki_species], + substrate_kms=rmm_input.substrate_kms, + product_kms=rmm_input.product_kms, + substrate_stoichiometry=rmm_input.substrate_stoichiometry, + product_stoichiometry=rmm_input.product_stoichiometry, + ki=rmm_input.ki, ) - free_enzyme_ratio = self.free_enzyme_ratio(conc, parameters) - return reversibility * kcat * enzyme * numerator * free_enzyme_ratio + return rev * rmm_input.kcat * rmm_input.enzyme * numerator * fer From c14a469b7053f5f42c1964b078f4642bd10c9977 Mon Sep 17 00:00:00 2001 From: teddygroves Date: Wed, 27 Nov 2024 16:45:11 +0100 Subject: [PATCH 02/27] allostery, tests (not passing!) --- pyproject.toml | 15 +- src/enzax/kinetic_model.py | 59 +++- src/enzax/rate_equation.py | 7 +- src/enzax/rate_equations/__init__.py | 12 +- src/enzax/rate_equations/generalised_mwc.py | 306 +++++++++++++------ src/enzax/rate_equations/michaelis_menten.py | 28 +- src/enzax/steady_state.py | 1 - tests/test_rate_equations.py | 143 +++++---- 8 files changed, 361 insertions(+), 210 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bd7e050..a2873bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,8 @@ dependencies = [ "diffrax>=0.6.0", "jaxtyping>=0.2.31", "arviz>=0.19.0", + "jax>=0.4.35", + "equinox>=0.11.9", ] requires-python = ">=3.12" readme = "README.md" @@ -30,12 +32,6 @@ build-backend = "hatchling.build" [tool.pdm] distribution = true -[tool.pdm.dev-dependencies] -dev = [ - "pytest>=8.3.2", - "pytest-cov>=5.0.0", - "pre-commit>=3.8.0", -] [tool.ruff] line-length = 80 @@ -45,3 +41,10 @@ extend-select = ["E501"] # line length is checked [tool.ruff.lint.isort] known-first-party = ["enzax"] + +[dependency-groups] +dev = [ + "pytest>=8.3.3", + "pytest-cov>=5.0.0", + "pre-commit>=3.8.0", +] diff --git a/src/enzax/kinetic_model.py b/src/enzax/kinetic_model.py index 9f60d28..1f9c39d 100644 --- a/src/enzax/kinetic_model.py +++ b/src/enzax/kinetic_model.py @@ -4,23 +4,27 @@ import equinox as eqx import jax.numpy as jnp -from jaxtyping import Array, Float, Int, PyTree, ScalarLike, jaxtyped +import numpy as np +from jax.tree_util import register_pytree_node_class +from jaxtyping import Array, Float, PyTree, ScalarLike, jaxtyped +from numpy.typing import NDArray from typeguard import typechecked from enzax.rate_equation import RateEquation @jaxtyped(typechecker=typechecked) -class KineticModelStructure(eqx.Module): +@register_pytree_node_class +class KineticModelStructure: """Structural information about a kinetic model.""" - S: Float[Array, " s r"] + S: NDArray[np.float64] species: list[str] reactions: list[str] balanced_species: list[str] - species_to_dgf_ix: Int[Array, " s"] - balanced_species_ix: Int[Array, " b"] - unbalanced_species_ix: Int[Array, " u"] + species_to_dgf_ix: NDArray[np.int16] + balanced_species_ix: NDArray[np.int16] + unbalanced_species_ix: NDArray[np.int16] def __init__( self, @@ -35,15 +39,30 @@ def __init__( self.reactions = reactions self.balanced_species = balanced_species self.species_to_dgf_ix = species_to_dgf_ix - self.balanced_species_ix = jnp.array( + self.balanced_species_ix = np.array( [i for i, s in enumerate(species) if s in balanced_species], - dtype=jnp.int16, + dtype=np.int16, ) - self.unbalanced_species_ix = jnp.array( + self.unbalanced_species_ix = np.array( [i for i, s in enumerate(species) if s not in balanced_species], - dtype=jnp.int16, + dtype=np.int16, ) + def tree_flatten(self): + children = ( + self.S, + self.species, + self.reactions, + self.balanced_species, + self.species_to_dgf_ix, + ) + aux_data = None + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(*children) + class RateEquationKineticModelStructure(KineticModelStructure): rate_equations: list[RateEquation] @@ -62,6 +81,18 @@ def __init__( ) self.rate_equations = rate_equations + def tree_flatten(self): + children = ( + self.S, + self.species, + self.reactions, + self.balanced_species, + self.species_to_dgf_ix, + self.rate_equations, + ) + aux_data = None + return children, aux_data + class KineticModel(eqx.Module, ABC): """Abstract base class for kinetic models.""" @@ -75,6 +106,7 @@ def flux( conc_balanced: Float[Array, " n_balanced"], ) -> Float[Array, " n"]: ... + @eqx.filter_jit def dcdt( self, t: ScalarLike, conc: Float[Array, " n_balanced"], args=None ) -> Float[Array, " n_balanced"]: @@ -87,14 +119,15 @@ def dcdt( :param conc: a one dimensional array of positive floats representing concentrations of balanced species. Must have same size as self.structure.ix_balanced """ # Noqa: E501 - sv = self.structure.S @ self.flux(conc) - return sv[self.structure.balanced_species_ix] + v = self.flux(conc) + sv = self.structure.S @ v + return jnp.array(sv[self.structure.balanced_species_ix]) class RateEquationModel(KineticModel): """A kinetic model that specifies its fluxes using RateEquation objects.""" - rate_equations: list[RateEquation] = eqx.field(default_factory=list) + structure: RateEquationKineticModelStructure def flux( self, diff --git a/src/enzax/rate_equation.py b/src/enzax/rate_equation.py index e3dbe24..4c11631 100644 --- a/src/enzax/rate_equation.py +++ b/src/enzax/rate_equation.py @@ -1,12 +1,13 @@ """Module containing rate equations for enzyme-catalysed reactions.""" from abc import ABC, abstractmethod - -import numpy as np from equinox import Module -from jaxtyping import Array, Float, PyTree, Scalar +import numpy as np from numpy.typing import NDArray +from jaxtyping import Array, Float, PyTree, Scalar + + ConcArray = Float[Array, " n"] diff --git a/src/enzax/rate_equations/__init__.py b/src/enzax/rate_equations/__init__.py index 95f03d1..32da2d4 100644 --- a/src/enzax/rate_equations/__init__.py +++ b/src/enzax/rate_equations/__init__.py @@ -3,16 +3,16 @@ IrreversibleMichaelisMenten, ) -# from enzax.rate_equations.generalised_mwc import ( -# AllostericReversibleMichaelisMenten, -# AllostericIrreversibleMichaelisMenten, -# ) +from enzax.rate_equations.generalised_mwc import ( + AllostericReversibleMichaelisMenten, + AllostericIrreversibleMichaelisMenten, +) from enzax.rate_equations.drain import Drain AVAILABLE_RATE_EQUATIONS = [ ReversibleMichaelisMenten, IrreversibleMichaelisMenten, - # AllostericReversibleMichaelisMenten, - # AllostericIrreversibleMichaelisMenten, + AllostericReversibleMichaelisMenten, + AllostericIrreversibleMichaelisMenten, Drain, ] diff --git a/src/enzax/rate_equations/generalised_mwc.py b/src/enzax/rate_equations/generalised_mwc.py index 224f588..0d56866 100644 --- a/src/enzax/rate_equations/generalised_mwc.py +++ b/src/enzax/rate_equations/generalised_mwc.py @@ -1,93 +1,213 @@ -# from jax import numpy as jnp -# from jaxtyping import Array, Float, Int, PyTree, Scalar - -# from enzax.rate_equation import ConcArray -# from enzax.rate_equations.michaelis_menten import ( -# IrreversibleMichaelisMenten, -# ReversibleMichaelisMenten, -# ) - - -# def generalised_mwc_effect( -# conc_inhibitor: Float[Array, " n_inhibition"], -# dc_inhibitor: Float[Array, " n_inhibition"], -# conc_activator: Float[Array, " n_activation"], -# dc_activator: Float[Array, " n_activation"], -# free_enzyme_ratio: Scalar, -# tc: Scalar, -# subunits: int, -# ) -> Scalar: -# """Get the allosteric effect on a rate. - -# The equation is generalised Monod Wyman Changeux model as presented in Popova and Sel'kov 1975: https://doi.org/10.1016/0014-5793(75)80034-2. - -# """ # noqa: E501 -# qnum = 1 + jnp.sum(conc_inhibitor / dc_inhibitor) -# qdenom = 1 + jnp.sum(conc_activator / dc_activator) -# out = 1.0 / (1 + tc * (free_enzyme_ratio * qnum / qdenom) ** subunits) -# return out - - -# class GeneralisedMWC: -# """Mixin class for allosteric rate laws, assuming generalised MWC kinetics. - -# See Popova and Sel'kov 1975 for the rate law: https://doi.org/10.1016/0014-5793(75)80034-2. - -# Note that it is assumed there is a free_enzyme_ratio method available - that is why this is a subclass of MichaelisMenten rather than RateEquation. -# """ # noqa: E501 - -# subunits: int -# tc_ix: int -# ix_dc_activation: Int[Array, " n_activation"] -# ix_dc_inhibition: Int[Array, " n_inhibition"] -# species_activation: Int[Array, " n_activation"] -# species_inhibition: Int[Array, " n_inhibition"] - -# def get_tc(self, parameters: PyTree) -> Scalar: -# return jnp.exp(parameters.log_transfer_constant[self.tc_ix]) - -# def get_dc_activation(self, parameters: PyTree) -> Scalar: -# return jnp.exp( -# parameters.log_dissociation_constant[self.ix_dc_activation], -# ) - -# def get_dc_inhibition(self, parameters: PyTree) -> Scalar: -# return jnp.exp( -# parameters.log_dissociation_constant[self.ix_dc_inhibition], -# ) - -# def allosteric_effect(self, conc: ConcArray, parameters: PyTree) -> Scalar: -# return generalised_mwc_effect( -# conc_inhibitor=conc[self.species_inhibition], -# conc_activator=conc[self.species_activation], -# free_enzyme_ratio=self.free_enzyme_ratio(conc, parameters), -# tc=self.get_tc(parameters), -# dc_inhibitor=self.get_dc_inhibition(parameters), -# dc_activator=self.get_dc_activation(parameters), -# subunits=self.subunits, -# ) - - -# class AllostericIrreversibleMichaelisMenten( -# GeneralisedMWC, IrreversibleMichaelisMenten -# ): -# """A reaction with irreversible Michaelis Menten kinetics and allostery.""" - -# def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar: -# """The flux of an irreversible allosteric Michaelis Menten reaction.""" -# allosteric_effect = self.allosteric_effect(conc, parameters) -# non_allosteric_rate = super().__call__(conc, parameters) -# return non_allosteric_rate * allosteric_effect - - -# class AllostericReversibleMichaelisMenten( -# GeneralisedMWC, -# ReversibleMichaelisMenten, -# ): -# """A reaction with reversible Michaelis Menten kinetics and allostery.""" - -# def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: -# """The flux of an allosteric reversible Michaelis Menten reaction.""" -# allosteric_effect = self.allosteric_effect(conc, parameters) -# non_allosteric_rate = super().__call__(conc, parameters) -# return non_allosteric_rate * allosteric_effect +import equinox as eqx +from jax import numpy as jnp +from jaxtyping import Array, Float, PyTree, Scalar +import numpy as np +from numpy.typing import NDArray + +from enzax.rate_equations.michaelis_menten import ( + free_enzyme_ratio_imm, + free_enzyme_ratio_rmm, + IrreversibleMichaelisMenten, + IrreversibleMichaelisMentenInput, + ReversibleMichaelisMenten, + ReversibleMichaelisMentenInput, +) + + +class AllostericIrreversibleMichaelisMentenInput( + IrreversibleMichaelisMentenInput +): + dc_inhibitor: Float[Array, " n_inhibitor"] + dc_activator: Float[Array, " n_activator"] + tc: Scalar + + +class AllostericReversibleMichaelisMentenInput(ReversibleMichaelisMentenInput): + dc_inhibitor: Float[Array, " n_inhibitor"] + dc_activator: Float[Array, " n_activator"] + tc: Scalar + + +def get_allosteric_irreversible_michaelis_menten_input( + parameters: PyTree, + rxn_ix: int, + S: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ci_ix: NDArray[np.int16], +) -> AllostericIrreversibleMichaelisMentenInput: + Sj = S[:, rxn_ix] + ix_substrate = np.argwhere(Sj < 0.0) + return AllostericIrreversibleMichaelisMentenInput( + kcat=jnp.exp(parameters.log_kcat[rxn_ix]), + enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), + ix_substrate=ix_substrate, + substrate_kms=jnp.exp(parameters.log_km[rxn_ix]), + substrate_stoichiometry=Sj[ix_substrate], + ix_ki_species=ci_ix, + ki=jnp.exp(parameters.log_ki[rxn_ix]), + dc_inhibitor=jnp.exp(parameters.log_dc_inhibitor[rxn_ix]), + dc_activator=jnp.exp(parameters.log_dc_activator[rxn_ix]), + tc=jnp.exp(parameters.log_tc[rxn_ix]), + ) + + +def get_allosteric_reversible_michaelis_menten_input( + parameters: PyTree, + rxn_ix: int, + S: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ci_ix: NDArray[np.int16], + water_stoichiometry: float, +) -> AllostericReversibleMichaelisMentenInput: + Sj = S[:, rxn_ix] + ix_reactant = np.argwhere(Sj != 0.0) + ix_substrate = np.argwhere(Sj < 0.0) + ix_product = np.argwhere(Sj > 0.0) + return AllostericReversibleMichaelisMentenInput( + kcat=jnp.exp(parameters.log_kcat[rxn_ix]), + enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), + substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]), + product_kms=jnp.exp(parameters.log_km[rxn_ix][1]), + ki=jnp.exp(parameters.log_ki[rxn_ix]), + dgf=jnp.exp(parameters.dgf[species_to_dgf_ix][ix_reactant]), + temperature=parameters.temperature, + ix_ki_species=ci_ix, + ix_reactant=ix_reactant, + ix_substrate=ix_substrate, + ix_product=ix_product, + reactant_stoichiometry=Sj[ix_reactant], + substrate_stoichiometry=Sj[ix_substrate], + product_stoichiometry=Sj[ix_product], + water_stoichiometry=water_stoichiometry, + dc_inhibitor=jnp.exp(parameters.log_dc_inhibitor[rxn_ix]), + dc_activator=jnp.exp(parameters.log_dc_activator[rxn_ix]), + tc=jnp.exp(parameters.log_tc[rxn_ix]), + ) + + +def generalised_mwc_effect( + conc_inhibitor: Float[Array, " n_inhibition"], + dc_inhibitor: Float[Array, " n_inhibition"], + conc_activator: Float[Array, " n_activation"], + dc_activator: Float[Array, " n_activation"], + free_enzyme_ratio: Scalar, + tc: Scalar, + subunits: int, +) -> Scalar: + """Get the allosteric effect on a rate. + + The equation is generalised Monod Wyman Changeux model as presented in Popova and Sel'kov 1975: https://doi.org/10.1016/0014-5793(75)80034-2. + + """ # noqa: E501 + qnum = 1 + jnp.sum(conc_inhibitor / dc_inhibitor) + qdenom = 1 + jnp.sum(conc_activator / dc_activator) + out = 1.0 / (1 + tc * (free_enzyme_ratio * qnum / qdenom) ** subunits) + return out + + +class AllostericIrreversibleMichaelisMenten(IrreversibleMichaelisMenten): + """A reaction with irreversible Michaelis Menten kinetics and allostery.""" + + ix_inhibitors: NDArray[np.int16] = eqx.field( + default_factory=lambda: np.array([], dtype=np.int16) + ) + ix_activators: NDArray[np.int16] = eqx.field( + default_factory=lambda: np.array([], dtype=np.int16) + ) + subunits: int = 1 + + def get_input( + self, + parameters: PyTree, + rxn_ix: int, + S: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ): + return get_allosteric_irreversible_michaelis_menten_input( + parameters=parameters, + rxn_ix=rxn_ix, + S=S, + species_to_dgf_ix=species_to_dgf_ix, + ci_ix=self.ix_ki_species, + ) + + def __call__( + self, + conc: Float[Array, " n"], + aimm_input: AllostericIrreversibleMichaelisMentenInput, + ) -> Scalar: + """The flux of an irreversible allosteric Michaelis Menten reaction.""" + fer = free_enzyme_ratio_imm( + substrate_conc=conc[aimm_input.ix_substrate], + substrate_km=aimm_input.substrate_kms, + ki=aimm_input.ki, + inhibitor_conc=conc[self.ix_inhibitors], + substrate_stoichiometry=aimm_input.substrate_stoichiometry, + ) + allosteric_effect = generalised_mwc_effect( + conc_inhibitor=conc[self.ix_inhibitors], + dc_inhibitor=aimm_input.dc_inhibitor, + dc_activator=aimm_input.dc_activator, + conc_activator=conc[self.ix_activators], + free_enzyme_ratio=fer, + tc=aimm_input.tc, + subunits=self.subunits, + ) + non_allosteric_rate = super().__call__(conc, aimm_input) + return non_allosteric_rate * allosteric_effect + + +class AllostericReversibleMichaelisMenten(ReversibleMichaelisMenten): + """A reaction with reversible Michaelis Menten kinetics and allostery.""" + + ix_allosteric_inhibitors: NDArray[np.int16] = eqx.field( + default_factory=lambda: np.array([], dtype=np.int16) + ) + ix_allosteric_activators: NDArray[np.int16] = eqx.field( + default_factory=lambda: np.array([], dtype=np.int16) + ) + subunits: int = 1 + + def get_input( + self, + parameters: PyTree, + rxn_ix: int, + S: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ): + return get_allosteric_reversible_michaelis_menten_input( + parameters=parameters, + rxn_ix=rxn_ix, + S=S, + species_to_dgf_ix=species_to_dgf_ix, + ci_ix=self.ix_ki_species, + water_stoichiometry=self.water_stoichiometry, + ) + + def __call__( + self, + conc: Float[Array, " n"], + armm_input: AllostericReversibleMichaelisMentenInput, + ) -> Scalar: + """The flux of an irreversible allosteric Michaelis Menten reaction.""" + fer = free_enzyme_ratio_rmm( + substrate_conc=conc[armm_input.ix_substrate], + product_conc=conc[armm_input.ix_product], + substrate_kms=armm_input.substrate_kms, + product_kms=armm_input.product_kms, + ix_ki_species=conc[self.ix_ki_species], + ki=armm_input.ki, + substrate_stoichiometry=armm_input.substrate_stoichiometry, + product_stoichiometry=armm_input.product_stoichiometry, + ) + allosteric_effect = generalised_mwc_effect( + conc_inhibitor=conc[self.ix_allosteric_inhibitors], + dc_inhibitor=armm_input.dc_inhibitor, + dc_activator=armm_input.dc_activator, + conc_activator=conc[self.ix_allosteric_activators], + free_enzyme_ratio=fer, + tc=armm_input.tc, + subunits=self.subunits, + ) + non_allosteric_rate = super().__call__(conc, armm_input) + return non_allosteric_rate * allosteric_effect diff --git a/src/enzax/rate_equations/michaelis_menten.py b/src/enzax/rate_equations/michaelis_menten.py index 91011f8..433ce32 100644 --- a/src/enzax/rate_equations/michaelis_menten.py +++ b/src/enzax/rate_equations/michaelis_menten.py @@ -25,15 +25,15 @@ def get_irreversible_michaelis_menten_input( ci_ix: NDArray[np.int16], ) -> IrreversibleMichaelisMentenInput: Sj = S[:, rxn_ix] - ix_substrate = np.argwhere(Sj < 0.0) + ix_substrate = np.argwhere(Sj < 0.0).flatten() return IrreversibleMichaelisMentenInput( kcat=jnp.exp(parameters.log_kcat[rxn_ix]), enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), ix_substrate=ix_substrate, - substrate_kms=jnp.exp(parameters.log_km[rxn_ix]), + substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]), substrate_stoichiometry=Sj[ix_substrate], ix_ki_species=ci_ix, - ki=parameters.jnp.exp(parameters.log_ki[rxn_ix]), + ki=jnp.exp(parameters.log_ki[rxn_ix]), ) @@ -141,7 +141,7 @@ def free_enzyme_ratio_rmm( product_conc: Float[Array, " n_product"], substrate_kms: Float[Array, " n_substrate"], product_kms: Float[Array, " n_product"], - inhibitor_conc: Float[Array, " n_ki"], + ix_ki_species: Float[Array, " n_ki"], ki: Float[Array, " n_ki"], substrate_stoichiometry: NDArray[np.float64], product_stoichiometry: NDArray[np.float64], @@ -157,14 +157,16 @@ def free_enzyme_ratio_rmm( ((product_conc / product_kms) + 1.0) ** jnp.abs(product_stoichiometry) ) - + jnp.sum(inhibitor_conc / ki) + + jnp.sum(ix_ki_species / ki) ) class IrreversibleMichaelisMenten(RateEquation): """A reaction with irreversible Michaelis Menten kinetics.""" - competitive_inhibitor_ix: NDArray[np.int16] + ix_ki_species: NDArray[np.int16] = eqx.field( + default_factory=lambda: np.array([], dtype=np.int64) + ) def get_input( self, @@ -178,7 +180,7 @@ def get_input( rxn_ix=rxn_ix, S=S, species_to_dgf_ix=species_to_dgf_ix, - ci_ix=self.competitive_inhibitor_ix, + ci_ix=self.ix_ki_species, ) def __call__( @@ -204,12 +206,10 @@ def __call__( class ReversibleMichaelisMenten(RateEquation): """A reaction with reversible Michaelis Menten kinetics.""" - competitive_inhibitor_ix: NDArray[np.int16] = eqx.field( - default_factory=lambda: np.array([], dtype=np.int16), static=True - ) - water_stoichiometry: float = eqx.field( - default_factory=lambda: 0.0, static=True + ix_ki_species: NDArray[np.int16] = eqx.field( + default_factory=lambda: np.array([], dtype=np.int16) ) + water_stoichiometry: float = eqx.field(default_factory=lambda: 0.0) def get_input( self, @@ -223,7 +223,7 @@ def get_input( rxn_ix=rxn_ix, S=S, species_to_dgf_ix=species_to_dgf_ix, - ci_ix=self.competitive_inhibitor_ix, + ci_ix=self.ix_ki_species, water_stoichiometry=self.water_stoichiometry, ) @@ -251,7 +251,7 @@ def __call__( fer = free_enzyme_ratio_rmm( substrate_conc=conc[rmm_input.ix_substrate], product_conc=conc[rmm_input.ix_product], - inhibitor_conc=conc[rmm_input.ix_ki_species], + ix_ki_species=conc[rmm_input.ix_ki_species], substrate_kms=rmm_input.substrate_kms, product_kms=rmm_input.product_kms, substrate_stoichiometry=rmm_input.substrate_stoichiometry, diff --git a/src/enzax/steady_state.py b/src/enzax/steady_state.py index a89f745..7437ef4 100644 --- a/src/enzax/steady_state.py +++ b/src/enzax/steady_state.py @@ -51,7 +51,6 @@ def get_kinetic_model_steady_state( t1, dt0, guess, - args=model, max_steps=max_steps, stepsize_controller=controller, event=event, diff --git a/tests/test_rate_equations.py b/tests/test_rate_equations.py index a511e77..84a5a35 100644 --- a/tests/test_rate_equations.py +++ b/tests/test_rate_equations.py @@ -1,116 +1,111 @@ """Unit tests for rate equations.""" +import equinox as eqx +import numpy as np from jax import numpy as jnp +from jaxtyping import Array, Scalar from enzax.rate_equations import ( AllostericIrreversibleMichaelisMenten, AllostericReversibleMichaelisMenten, IrreversibleMichaelisMenten, ReversibleMichaelisMenten, ) -from enzax.parameters import AllostericMichaelisMentenParameterSet + +class ExampleParameterSet(eqx.Module): + log_km: dict[int, Array] + log_kcat: dict[int, Scalar] + log_enzyme: dict[int, Array] + log_ki: dict[int, Array] + dgf: Array + temperature: Scalar + log_conc_unbalanced: Array + log_dc_inhibitor: dict[int, Array] + log_dc_activator: dict[int, Array] + log_tc: dict[int, Array] + + +EXAMPLE_S = np.array( + [[-1, 0, 0], [1, -1, 0], [0, 1, -1], [0, 0, 1]], dtype=np.float64 +) EXAMPLE_CONC = jnp.array([0.5, 0.2, 0.1, 0.3]) -EXAMPLE_PARAMETERS = AllostericMichaelisMentenParameterSet( - log_kcat=jnp.array([-0.1]), - log_enzyme=jnp.log(jnp.array([0.3])), - dgf=jnp.array([-3, -1.0]), - log_km=jnp.array([0.1, -0.2]), - log_ki=jnp.array([1.0]), - log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.3])), +EXAMPLE_PARAMETERS = ExampleParameterSet( + log_km={ + 0: jnp.array([[0.1], [-0.2]]), + 1: jnp.array([[0.5], [0.0]]), + 2: jnp.array([[-1.0], [0.5]]), + }, + log_kcat={0: jnp.array(-0.1), 1: jnp.array(0.0), 2: jnp.array(0.1)}, + dgf=jnp.array([-3.0, 1.0]), + log_ki={0: jnp.array([1.0]), 1: jnp.array([1.0]), 2: jnp.array([])}, temperature=jnp.array(310.0), - log_transfer_constant=jnp.array([-0.2, 0.3]), - log_dissociation_constant=jnp.array([-0.1, 0.2]), - log_drain=jnp.array([]), + log_enzyme={0: jnp.array(0.3), 1: jnp.array(0.2), 2: jnp.array(0.1)}, + log_conc_unbalanced=jnp.array([0.5, 0.3]), + log_tc={0: jnp.array(-0.2), 1: jnp.array(0.3)}, + log_dc_activator={0: jnp.array([-0.1]), 1: jnp.array([])}, + log_dc_inhibitor={0: jnp.array([]), 1: jnp.array([0.2])}, ) +EXAMPLE_SPECIES_TO_DGF_IX = np.array([0, 0, 1, 1]) def test_irreversible_michaelis_menten(): expected_rate = 0.08455524 - f = IrreversibleMichaelisMenten( - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), + f = IrreversibleMichaelisMenten() + f_input = f.get_input( + parameters=EXAMPLE_PARAMETERS, + rxn_ix=0, + S=EXAMPLE_S, + species_to_dgf_ix=EXAMPLE_SPECIES_TO_DGF_IX, ) - rate = f(EXAMPLE_CONC, EXAMPLE_PARAMETERS) + rate = f(EXAMPLE_CONC, f_input) assert jnp.isclose(rate, expected_rate) def test_reversible_michaelis_menten(): expected_rate = 0.04342889 f = ReversibleMichaelisMenten( - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 0], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - ix_product=jnp.array([1], dtype=jnp.int16), - ix_reactants=jnp.array([0, 1], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), + ix_ki_species=np.array([], dtype=np.int16), + water_stoichiometry=0.0, ) - rate = f(EXAMPLE_CONC, EXAMPLE_PARAMETERS) + f_input = f.get_input( + parameters=EXAMPLE_PARAMETERS, + rxn_ix=0, + S=EXAMPLE_S, + species_to_dgf_ix=EXAMPLE_SPECIES_TO_DGF_IX, + ) + rate = f(EXAMPLE_CONC, f_input) assert jnp.isclose(rate, expected_rate) def test_allosteric_irreversible_michaelis_menten(): expected_rate = 0.05608589 f = AllostericIrreversibleMichaelisMenten( - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - tc_ix=0, - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - ix_dc_activation=jnp.array([0], dtype=jnp.int16), - species_activation=jnp.array([2], dtype=jnp.int16), - species_inhibition=jnp.array([], dtype=jnp.int16), + ix_ki_species=np.array([], dtype=np.int16), + ix_activators=np.array([2], dtype=np.int16), subunits=1, ) - rate = f(EXAMPLE_CONC, EXAMPLE_PARAMETERS) + f_input = f.get_input( + parameters=EXAMPLE_PARAMETERS, + rxn_ix=0, + S=EXAMPLE_S, + species_to_dgf_ix=EXAMPLE_SPECIES_TO_DGF_IX, + ) + rate = f(EXAMPLE_CONC, f_input) assert jnp.isclose(rate, expected_rate) def test_allosteric_reversible_michaelis_menten(): expected_rate = 0.03027414 f = AllostericReversibleMichaelisMenten( - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 0], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - ix_product=jnp.array([1], dtype=jnp.int16), - ix_reactants=jnp.array([0, 1], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - tc_ix=0, - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - ix_dc_activation=jnp.array([0], dtype=jnp.int16), - species_activation=jnp.array([2], dtype=jnp.int16), - species_inhibition=jnp.array([], dtype=jnp.int16), + ix_ki_species=np.array([], dtype=np.int16), + ix_allosteric_activators=np.array([2], dtype=np.int16), subunits=1, ) - rate = f(EXAMPLE_CONC, EXAMPLE_PARAMETERS) + f_input = f.get_input( + parameters=EXAMPLE_PARAMETERS, + rxn_ix=0, + S=EXAMPLE_S, + species_to_dgf_ix=EXAMPLE_SPECIES_TO_DGF_IX, + ) + rate = f(EXAMPLE_CONC, f_input) assert jnp.isclose(rate, expected_rate) From 90e79abdb254fa17380f6255ade9363c912bff8b Mon Sep 17 00:00:00 2001 From: teddygroves Date: Wed, 27 Nov 2024 16:52:36 +0100 Subject: [PATCH 03/27] tidying --- src/enzax/rate_equations/generalised_mwc.py | 10 ++++---- tests/test_rate_equations.py | 28 ++++++++------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/src/enzax/rate_equations/generalised_mwc.py b/src/enzax/rate_equations/generalised_mwc.py index 0d56866..89bc3ad 100644 --- a/src/enzax/rate_equations/generalised_mwc.py +++ b/src/enzax/rate_equations/generalised_mwc.py @@ -108,10 +108,10 @@ def generalised_mwc_effect( class AllostericIrreversibleMichaelisMenten(IrreversibleMichaelisMenten): """A reaction with irreversible Michaelis Menten kinetics and allostery.""" - ix_inhibitors: NDArray[np.int16] = eqx.field( + ix_allosteric_inhibitors: NDArray[np.int16] = eqx.field( default_factory=lambda: np.array([], dtype=np.int16) ) - ix_activators: NDArray[np.int16] = eqx.field( + ix_allosteric_activators: NDArray[np.int16] = eqx.field( default_factory=lambda: np.array([], dtype=np.int16) ) subunits: int = 1 @@ -141,14 +141,14 @@ def __call__( substrate_conc=conc[aimm_input.ix_substrate], substrate_km=aimm_input.substrate_kms, ki=aimm_input.ki, - inhibitor_conc=conc[self.ix_inhibitors], + inhibitor_conc=conc[self.ix_ki_species], substrate_stoichiometry=aimm_input.substrate_stoichiometry, ) allosteric_effect = generalised_mwc_effect( - conc_inhibitor=conc[self.ix_inhibitors], + conc_inhibitor=conc[self.ix_allosteric_inhibitors], dc_inhibitor=aimm_input.dc_inhibitor, dc_activator=aimm_input.dc_activator, - conc_activator=conc[self.ix_activators], + conc_activator=conc[self.ix_allosteric_activators], free_enzyme_ratio=fer, tc=aimm_input.tc, subunits=self.subunits, diff --git a/tests/test_rate_equations.py b/tests/test_rate_equations.py index 84a5a35..a982cd3 100644 --- a/tests/test_rate_equations.py +++ b/tests/test_rate_equations.py @@ -25,25 +25,19 @@ class ExampleParameterSet(eqx.Module): log_tc: dict[int, Array] -EXAMPLE_S = np.array( - [[-1, 0, 0], [1, -1, 0], [0, 1, -1], [0, 0, 1]], dtype=np.float64 -) -EXAMPLE_CONC = jnp.array([0.5, 0.2, 0.1, 0.3]) +EXAMPLE_S = np.array([[-1], [1]], dtype=np.float64) +EXAMPLE_CONC = jnp.array([0.5, 0.2]) EXAMPLE_PARAMETERS = ExampleParameterSet( - log_km={ - 0: jnp.array([[0.1], [-0.2]]), - 1: jnp.array([[0.5], [0.0]]), - 2: jnp.array([[-1.0], [0.5]]), - }, - log_kcat={0: jnp.array(-0.1), 1: jnp.array(0.0), 2: jnp.array(0.1)}, + log_km={0: jnp.array([[0.1], [-0.2]])}, + log_kcat={0: jnp.array(-0.1)}, dgf=jnp.array([-3.0, 1.0]), - log_ki={0: jnp.array([1.0]), 1: jnp.array([1.0]), 2: jnp.array([])}, + log_ki={0: jnp.array([1.0])}, temperature=jnp.array(310.0), - log_enzyme={0: jnp.array(0.3), 1: jnp.array(0.2), 2: jnp.array(0.1)}, - log_conc_unbalanced=jnp.array([0.5, 0.3]), - log_tc={0: jnp.array(-0.2), 1: jnp.array(0.3)}, - log_dc_activator={0: jnp.array([-0.1]), 1: jnp.array([])}, - log_dc_inhibitor={0: jnp.array([]), 1: jnp.array([0.2])}, + log_enzyme={0: jnp.array(0.3)}, + log_conc_unbalanced=jnp.array([]), + log_tc={0: jnp.array(-0.2)}, + log_dc_activator={0: jnp.array([-0.1])}, + log_dc_inhibitor={0: jnp.array([0.2])}, ) EXAMPLE_SPECIES_TO_DGF_IX = np.array([0, 0, 1, 1]) @@ -81,7 +75,7 @@ def test_allosteric_irreversible_michaelis_menten(): expected_rate = 0.05608589 f = AllostericIrreversibleMichaelisMenten( ix_ki_species=np.array([], dtype=np.int16), - ix_activators=np.array([2], dtype=np.int16), + ix_allosteric_activators=np.array([2], dtype=np.int16), subunits=1, ) f_input = f.get_input( From 259961ea3fbad2db38baaa8fdddc85ad6492f0c8 Mon Sep 17 00:00:00 2001 From: teddygroves Date: Wed, 27 Nov 2024 17:03:33 +0100 Subject: [PATCH 04/27] Fix missing jnp.log in test --- tests/test_rate_equations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rate_equations.py b/tests/test_rate_equations.py index a982cd3..bbd0d6b 100644 --- a/tests/test_rate_equations.py +++ b/tests/test_rate_equations.py @@ -33,7 +33,7 @@ class ExampleParameterSet(eqx.Module): dgf=jnp.array([-3.0, 1.0]), log_ki={0: jnp.array([1.0])}, temperature=jnp.array(310.0), - log_enzyme={0: jnp.array(0.3)}, + log_enzyme={0: jnp.log(jnp.array(0.3))}, log_conc_unbalanced=jnp.array([]), log_tc={0: jnp.array(-0.2)}, log_dc_activator={0: jnp.array([-0.1])}, From 33dbc3df1506666affb1d54b0b1d1cc3018a37f0 Mon Sep 17 00:00:00 2001 From: teddygroves Date: Wed, 27 Nov 2024 17:14:51 +0100 Subject: [PATCH 05/27] More fixes - now the rate equations test passes! --- src/enzax/rate_equations/generalised_mwc.py | 2 +- tests/test_rate_equations.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/enzax/rate_equations/generalised_mwc.py b/src/enzax/rate_equations/generalised_mwc.py index 89bc3ad..6995103 100644 --- a/src/enzax/rate_equations/generalised_mwc.py +++ b/src/enzax/rate_equations/generalised_mwc.py @@ -41,7 +41,7 @@ def get_allosteric_irreversible_michaelis_menten_input( kcat=jnp.exp(parameters.log_kcat[rxn_ix]), enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), ix_substrate=ix_substrate, - substrate_kms=jnp.exp(parameters.log_km[rxn_ix]), + substrate_kms=jnp.exp(parameters.log_km[rxn_ix][ix_substrate]), substrate_stoichiometry=Sj[ix_substrate], ix_ki_species=ci_ix, ki=jnp.exp(parameters.log_ki[rxn_ix]), diff --git a/tests/test_rate_equations.py b/tests/test_rate_equations.py index bbd0d6b..e14100e 100644 --- a/tests/test_rate_equations.py +++ b/tests/test_rate_equations.py @@ -25,8 +25,8 @@ class ExampleParameterSet(eqx.Module): log_tc: dict[int, Array] -EXAMPLE_S = np.array([[-1], [1]], dtype=np.float64) -EXAMPLE_CONC = jnp.array([0.5, 0.2]) +EXAMPLE_S = np.array([[-1], [1], [0]], dtype=np.float64) +EXAMPLE_CONC = jnp.array([0.5, 0.2, 0.1]) EXAMPLE_PARAMETERS = ExampleParameterSet( log_km={0: jnp.array([[0.1], [-0.2]])}, log_kcat={0: jnp.array(-0.1)}, From 952a9099dbb4402bf7078e6471be3750ec5f0041 Mon Sep 17 00:00:00 2001 From: teddygroves Date: Thu, 28 Nov 2024 08:38:34 +0100 Subject: [PATCH 06/27] Updating linear example --- src/enzax/examples/linear.py | 136 ++++++++++++++--------------------- tests/test_examples.py | 4 +- 2 files changed, 57 insertions(+), 83 deletions(-) diff --git a/src/enzax/examples/linear.py b/src/enzax/examples/linear.py index 134167d..11e9667 100644 --- a/src/enzax/examples/linear.py +++ b/src/enzax/examples/linear.py @@ -1,100 +1,74 @@ """A simple linear kinetic model.""" +import equinox as eqx +import numpy as np from jax import numpy as jnp +from jaxtyping import Array, Scalar from enzax.kinetic_model import ( + RateEquationKineticModelStructure, RateEquationModel, - KineticModelStructure, ) from enzax.rate_equations import ( AllostericReversibleMichaelisMenten, ReversibleMichaelisMenten, ) -from enzax.parameters import AllostericMichaelisMentenParameterSet -parameters = AllostericMichaelisMentenParameterSet( - log_kcat=jnp.array([-0.1, 0.0, 0.1]), - log_enzyme=jnp.log(jnp.array([0.3, 0.2, 0.1])), - dgf=jnp.array([-3, -1.0]), - log_km=jnp.array([0.1, -0.2, 0.5, 0.0, -1.0, 0.5]), - log_ki=jnp.array([1.0]), - log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])), - temperature=jnp.array(310.0), - log_transfer_constant=jnp.array([-0.2, 0.3]), - log_dissociation_constant=jnp.array([-0.1, 0.2]), - log_drain=jnp.array([]), -) -structure = KineticModelStructure( - S=jnp.array( - [[-1, 0, 0], [1, -1, 0], [0, 1, -1], [0, 0, 1]], dtype=jnp.float64 - ), - balanced_species=jnp.array([1, 2]), - unbalanced_species=jnp.array([0, 3]), -) + +class ParameterDefinition(eqx.Module): + log_km: dict[int, Array] + log_kcat: dict[int, Scalar] + log_enzyme: dict[int, Array] + log_ki: dict[int, Array] + dgf: Array + temperature: Scalar + log_conc_unbalanced: Array + log_dc_inhibitor: dict[int, Array] + log_dc_activator: dict[int, Array] + log_tc: dict[int, Array] + + +S = np.array([[-1, 0, 0], [1, -1, 0], [0, 1, -1], [0, 0, 1]], dtype=np.float64) +reactions = ["r1", "r2", "r3"] +species = ["m1e", "m1c", "m2c", "m2e"] +balanced_species = ["m1c", "m2c"] rate_equations = [ AllostericReversibleMichaelisMenten( - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 0], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - ix_product=jnp.array([1], dtype=jnp.int16), - ix_reactants=jnp.array([0, 1], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - tc_ix=0, - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - ix_dc_activation=jnp.array([0], dtype=jnp.int16), - species_activation=jnp.array([2], dtype=jnp.int16), - species_inhibition=jnp.array([], dtype=jnp.int16), - subunits=1, + ix_allosteric_activators=np.array([2]), subunits=1 ), AllostericReversibleMichaelisMenten( - kcat_ix=1, - enzyme_ix=1, - km_ix=jnp.array([2, 3], dtype=jnp.int16), - ki_ix=jnp.array([0]), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([1]), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([1], dtype=jnp.int16), - ix_product=jnp.array([2], dtype=jnp.int16), - ix_reactants=jnp.array([1, 2], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - tc_ix=1, - ix_dc_inhibition=jnp.array([1], dtype=jnp.int16), - ix_dc_activation=jnp.array([], dtype=jnp.int16), - species_activation=jnp.array([], dtype=jnp.int16), - species_inhibition=jnp.array([1], dtype=jnp.int16), - subunits=1, - ), - ReversibleMichaelisMenten( - kcat_ix=2, - enzyme_ix=2, - km_ix=jnp.array([4, 5], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - ix_substrate=jnp.array([2], dtype=jnp.int16), - ix_product=jnp.array([3], dtype=jnp.int16), - ix_reactants=jnp.array([2, 3], dtype=jnp.int16), - reactant_to_dgf=jnp.array([1, 1], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), + ix_allosteric_inhibitors=np.array([1]), ix_ki_species=np.array([1]) ), + ReversibleMichaelisMenten(water_stoichiometry=0.0), ] +structure = RateEquationKineticModelStructure( + S=S, + species=species, + reactions=reactions, + balanced_species=balanced_species, + species_to_dgf_ix=np.array([0, 0, 1, 1]), + rate_equations=rate_equations, +) +parameters = ParameterDefinition( + log_km={ + 0: jnp.array([[0.1], [-0.2]]), + 1: jnp.array([[0.5], [0.0]]), + 2: jnp.array([[-1.0], [0.5]]), + }, + log_kcat={0: jnp.array(-0.1), 1: jnp.array(0.0), 2: jnp.array(0.1)}, + dgf=jnp.array([-3.0, 1.0]), + log_ki={0: jnp.array([]), 1: jnp.array([1.0]), 2: jnp.array([])}, + temperature=jnp.array(310.0), + log_enzyme={ + 0: jnp.log(jnp.array(0.3)), + 1: jnp.log(jnp.array(0.2)), + 2: jnp.log(jnp.array(0.1)), + }, + log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])), + log_tc={0: jnp.array(-0.2), 1: jnp.array(0.3)}, + log_dc_activator={0: jnp.array([-0.1]), 1: jnp.array([])}, + log_dc_inhibitor={0: jnp.array([]), 1: jnp.array([0.2])}, +) +true_model = RateEquationModel(structure=structure, parameters=parameters) steady_state = jnp.array([0.43658744, 0.12695706]) -model = RateEquationModel(parameters, structure, rate_equations) +model = RateEquationModel(parameters, structure) diff --git a/tests/test_examples.py b/tests/test_examples.py index d11cc2a..0c183e5 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,13 +1,13 @@ from jax import numpy as jnp import pytest -from enzax.examples import linear, methionine +from enzax.examples import linear @pytest.mark.parametrize( ["model", "steady_state"], [ - (methionine.model, methionine.steady_state), + # (methionine.model, methionine.steady_state), (linear.model, linear.steady_state), ], ) From 3f3937e673d8ae998586ad425d7aaa90657a659e Mon Sep 17 00:00:00 2001 From: teddygroves Date: Thu, 28 Nov 2024 09:08:21 +0100 Subject: [PATCH 07/27] Fix incorrect dgf in linear model --- src/enzax/examples/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzax/examples/linear.py b/src/enzax/examples/linear.py index 11e9667..9f03df2 100644 --- a/src/enzax/examples/linear.py +++ b/src/enzax/examples/linear.py @@ -56,7 +56,7 @@ class ParameterDefinition(eqx.Module): 2: jnp.array([[-1.0], [0.5]]), }, log_kcat={0: jnp.array(-0.1), 1: jnp.array(0.0), 2: jnp.array(0.1)}, - dgf=jnp.array([-3.0, 1.0]), + dgf=jnp.array([-3.0, -1.0]), log_ki={0: jnp.array([]), 1: jnp.array([1.0]), 2: jnp.array([])}, temperature=jnp.array(310.0), log_enzyme={ From 2334e1c52ecfa21c6e6956a5874a58fd26e8f6ad Mon Sep 17 00:00:00 2001 From: teddygroves Date: Thu, 28 Nov 2024 13:55:09 +0100 Subject: [PATCH 08/27] trying to get methionine to work --- src/enzax/examples/methionine.py | 630 ++++++------------- src/enzax/kinetic_model.py | 9 +- src/enzax/rate_equations/drain.py | 31 +- src/enzax/rate_equations/generalised_mwc.py | 4 +- src/enzax/rate_equations/michaelis_menten.py | 2 +- tests/test_examples.py | 4 +- 6 files changed, 220 insertions(+), 460 deletions(-) diff --git a/src/enzax/examples/methionine.py b/src/enzax/examples/methionine.py index 1179b11..1e16011 100644 --- a/src/enzax/examples/methionine.py +++ b/src/enzax/examples/methionine.py @@ -5,11 +5,14 @@ """ +import equinox as eqx +import numpy as np from jax import numpy as jnp +from jaxtyping import Array, Scalar from enzax.kinetic_model import ( + RateEquationKineticModelStructure, RateEquationModel, - KineticModelStructure, ) from enzax.rate_equations import ( AllostericIrreversibleMichaelisMenten, @@ -17,187 +20,114 @@ IrreversibleMichaelisMenten, ReversibleMichaelisMenten, ) -from enzax.parameters import AllostericMichaelisMentenParameterSet -coords = { - "enzyme": [ - "MAT1", - "MAT3", - "METH-Gen", - "GNMT1", - "AHC1", - "MS1", - "BHMT1", - "CBS1", - "MTHFR1", - "PROT1", - ], - "drain": ["the_drain"], - "rate": [ - "the_drain", - "MAT1", - "MAT3", - "METH-Gen", - "GNMT1", - "AHC1", - "MS1", - "BHMT1", - "CBS1", - "MTHFR1", - "PROT1", - ], - "metabolite": [ - "met-L", - "atp", - "pi", - "ppi", - "amet", - "ahcys", - "gly", - "sarcs", - "hcys-L", - "adn", - "thf", - "5mthf", - "mlthf", - "glyb", - "dmgly", - "ser-L", - "nadp", - "nadph", - "cyst-L", - ], - "km": [ - "met-L MAT1", - "atp MAT1", - "met-L MAT3", - "atp MAT3", - "amet METH-Gen", - "amet GNMT1", - "gly GNMT1", - "ahcys AHC1", - "hcys-L AHC1", - "adn AHC1", - "hcys-L MS1", - "5mthf MS1", - "hcys-L BHMT1", - "glyb BHMT1", - "hcys-L CBS1", - "ser-L CBS1", - "mlthf MTHFR1", - "nadph MTHFR1", - "met-L PROT1", - ], - "ki": [ - "MAT1", - "METH-Gen", - "GNMT1", - ], - "species": [ - "met-L", - "atp", - "pi", - "ppi", - "amet", - "ahcys", - "gly", - "sarcs", - "hcys-L", - "adn", - "thf", - "5mthf", - "mlthf", - "glyb", - "dmgly", - "ser-L", - "nadp", - "nadph", - "cyst-L", - ], - "balanced_species": [ - "met-L", - "amet", - "ahcys", - "hcys-L", - "5mthf", - ], - "unbalanced_species": [ - "atp", - "pi", - "ppi", - "gly", - "sarcs", - "adn", - "thf", - "mlthf", - "glyb", - "dmgly", - "ser-L", - "nadp", - "nadph", - "cyst-L", - ], - "transfer_constant": [ - "METAT", - "GNMT", - "CBS", - "MTHFR", - ], - "dissociation_constant": [ - "met-L MAT3", - "amet MAT3", - "amet GNMT1", - "mlthf GNMT1", - "amet CBS1", - "amet MTHFR1", - "ahcys MTHFR1", + +class ParameterDefinition(eqx.Module): + log_km: dict[int, list[Array]] + log_kcat: dict[int, Scalar] + log_enzyme: dict[int, Array] + log_ki: dict[int, Array] + dgf: Array + temperature: Scalar + log_conc_unbalanced: Array + log_dc_inhibitor: dict[int, Array] + log_dc_activator: dict[int, Array] + log_tc: dict[int, Array] + log_drain: dict[int, Scalar] + + +S = np.array( + [ + [1, -1, -1, 0, 0, 0, 1, 1, 0, 0, -1], # met-L b + [0, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0], # atp + [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], # pi + [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], # ppi + [0, 1, 1, -1, -1, 0, 0, 0, 0, 0, 0], # amet b + [0, 0, 0, 1, 1, -1, 0, 0, 0, 0, 0], # ahcys b + [0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0], # gly + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # sarcs + [0, 0, 0, 0, 0, 1, -1, -1, -1, 0, 0], # hcys-L b + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], # adn + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], # thf + [0, 0, 0, 0, 0, 0, -1, 0, 0, 1, 0], # 5mthf b + [0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0], # mlthf + [0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0], # glyb + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # dmgly + [0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0], # ser-L + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], # nadp + [0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0], # nadph + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], # cyst-L ], -} -dims = { - "log_kcat": ["enzyme"], - "log_enzyme": ["enzyme"], - "log_drain": ["drain"], - "log_km": ["km"], - "dgf": ["metabolite"], - "log_ki": ["ki"], - "log_conc_unbalanced": ["unbalanced_species"], - "log_transfer_constant": ["transfer_constant"], - "log_dissociation_constant": ["dissociation_constant"], -} -parameters = AllostericMichaelisMentenParameterSet( - log_kcat=jnp.log( - jnp.array( - [ - 7.89577, # MAT1 - 19.9215, # MAT3 - 1.15777, # METH-Gen - 10.5307, # GNMT1 - 234.284, # AHC1 - 1.77471, # MS1 - 13.7676, # BHMT1 - 7.02307, # CBS1 - 3.1654, # MTHFR1 - 0.264744, # PROT1 - ] - ) - ), - log_enzyme=jnp.log( - jnp.array( - [ - 0.000961712, # MAT1 - 0.00098812, # MAT3 - 0.00103396, # METH-Gen - 0.000983692, # GNMT1 - 0.000977878, # AHC1 - 0.00105094, # MS1 - 0.000996603, # BHMT1 - 0.00134056, # CBS1 - 0.0010054, # MTHFR1 - 0.000995525, # PROT1 - ] - ) - ), - log_drain=jnp.log(jnp.array([0.000850127])), + dtype=np.float64, +) +reactions = [] +species = [ + "met-L", + "atp", + "pi", + "ppi", + "amet", + "ahcys", + "gly", + "sarcs", + "hcys-L", + "adn", + "thf", + "5mthf", + "mlthf", + "glyb", + "dmgly", + "ser-L", + "nadp", + "nadph", + "cyst-L", +] +balanced_species = [ + "met-L", + "amet", + "ahcys", + "hcys-L", + "5mthf", +] +reactions = [ + "the_drain", + "MAT1", + "MAT3", + "METH-Gen", + "GNMT1", + "AHC1", + "MS1", + "BHMT1", + "CBS1", + "MTHFR1", + "PROT1", +] +parameters = ParameterDefinition( + log_kcat={ + 1: jnp.log(jnp.array(7.89577)), # MAT1 + 2: jnp.log(jnp.array(19.9215)), # MAT3 + 3: jnp.log(jnp.array(1.15777)), # METH-Gen + 4: jnp.log(jnp.array(10.5307)), # GNMT1 + 5: jnp.log(jnp.array(234.284)), # AHC1 + 6: jnp.log(jnp.array(1.77471)), # MS1 + 7: jnp.log(jnp.array(13.7676)), # BHMT1 + 8: jnp.log(jnp.array(7.02307)), # CBS1 + 9: jnp.log(jnp.array(3.1654)), # MTHFR1 + 10: jnp.log(jnp.array(0.264744)), # PROT1 + }, + log_enzyme={ + 1: jnp.log(jnp.array(0.000961712)), # MAT1 + 2: jnp.log(jnp.array(0.00098812)), # MAT3 + 3: jnp.log(jnp.array(0.00103396)), # METH-Gen + 4: jnp.log(jnp.array(0.000983692)), # GNMT1 + 5: jnp.log(jnp.array(0.000977878)), # AHC1 + 6: jnp.log(jnp.array(0.00105094)), # MS1 + 7: jnp.log(jnp.array(0.000996603)), # BHMT1 + 8: jnp.log(jnp.array(0.00134056)), # CBS1 + 9: jnp.log(jnp.array(0.0010054)), # MTHFR1 + 10: jnp.log(jnp.array(0.000995525)), # PROT1 + }, + log_drain={0: jnp.log(jnp.array(0.000850127))}, dgf=jnp.array( [ 160.953, # met-L @@ -221,40 +151,34 @@ -46.4737, # cyst-L ] ), - log_km=jnp.log( - jnp.array( - [ - 0.000106919, # met-L MAT1 - 0.00203015, # atp MAT1 - 0.00113258, # met-L MAT3 - 0.00236759, # atp MAT3 - 9.37e-06, # amet METH-Gen - 0.000520015, # amet GNMT1 - 0.00253545, # gly GNMT1 - 2.32e-05, # ahcys AHC1 - 1.06e-05, # hcys-L AHC1 - 5.66e-06, # adn AHC1 - 1.71e-06, # hcys-L MS1 - 6.94e-05, # 5mthf MS1 - 1.98e-05, # hcys-L BHMT1 - 0.00845898, # glyb BHMT1 - 4.24e-05, # hcys-L CBS1 - 2.83e-06, # ser-L CBS1 - 8.08e-05, # mlthf MTHFR1 - 2.09e-05, # nadph MTHFR1 - 4.39e-05, # met-L PROT1 - ] - ) - ), - log_ki=jnp.log( - jnp.array( - [ - 0.000346704, # MAT1 - 5.56e-06, # METH-Gen - 5.31e-05, # GNMT1 - ] - ) - ), + log_km={ + 1: [jnp.array([0.000106919, 0.00203015])], # MAT1 met-L, atp + 2: [jnp.array([0.00113258, 0.00236759])], # MAT3 met-L atp + 3: [jnp.array([9.37e-06])], # amet METH-Gen + 4: [jnp.array([0.000520015, 0.00253545])], # amet GNMT1, # gly GNMT1 + 5: [ + jnp.array([2.32e-05]), # ahcys AHC1 + jnp.array([1.06e-05, 5.66e-06]), # hcys-L AHC1, # adn AHC1 + ], + 6: [jnp.array([1.71e-06, 6.94e-05])], # hcys-L MS1, # 5mthf MS1 + 7: [jnp.array([1.98e-05, 0.00845898])], # hcys-L BHMT1, # glyb BHMT1 + 8: [jnp.array([4.24e-05, 2.83e-06])], # hcys-L CBS1, # ser-L CBS1 + 9: [jnp.array([8.08e-05, 2.09e-05])], # mlthf MTHFR1, # nadph MTHFR1 + 10: [jnp.array([4.39e-05])], # met-L PROT1 + }, + temperature=jnp.array(298.15), + log_ki={ + 1: jnp.array([jnp.log(0.000346704)]), # MAT1 + 2: jnp.array([]), + 3: jnp.array([jnp.log(5.56e-06)]), # METH-Gen + 4: jnp.array([jnp.log(5.31e-05)]), # GNMT1 + 5: jnp.array([]), + 6: jnp.array([]), + 7: jnp.array([]), + 8: jnp.array([]), + 9: jnp.array([]), + 10: jnp.array([]), + }, log_conc_unbalanced=jnp.log( jnp.array( [ @@ -276,233 +200,63 @@ ] ) ), - temperature=jnp.array(298.15), - log_transfer_constant=jnp.log( - jnp.array( - [ - 0.107657, # METAT - 131.207, # GNMT - 1.03452, # CBS - 0.392035, # MTHFR - ] - ) - ), - log_dissociation_constant=jnp.log( - jnp.array( - [ - 0.00059999, # met-L MAT3 - 0.000316641, # amet MAT3 - 1.98e-05, # amet GNMT1 - 0.000228576, # mlthf GNMT1 - 9.30e-05, # amet CBS1 - 1.46e-05, # amet MTHFR1 - 2.45e-06, # ahcys MTHFR1 - ] - ) - ), + log_tc={ + 2: jnp.array(jnp.log(0.107657)), # MAT3 + 4: jnp.array(jnp.log(131.207)), # GNMT + 8: jnp.array(jnp.log(1.03452)), # CBS + 9: jnp.array(jnp.log(0.392035)), # MTHFR + }, + log_dc_activator={ + 2: jnp.log( + jnp.array([0.00059999, 0.000316641]) + ), # met-L MAT3, # amet MAT3 + 4: jnp.log(jnp.array([1.98e-05])), # amet GNMT1 + 8: jnp.array([]), # CBS1 + 9: jnp.log(jnp.array([2.45e-06])), # ahcys MTHFR1, + }, + log_dc_inhibitor={ + 2: jnp.array([]), # MAT3 + 4: jnp.log(jnp.array([0.000228576])), # mlthf GNMT1 + 8: jnp.log(jnp.array([9.30e-05])), # amet CBS1 + 9: jnp.log(jnp.array([1.46e-05])), # amet MTHFR1 + }, ) -structure = KineticModelStructure( - S=jnp.array( - [ - [1, -1, -1, 0, 0, 0, 1, 1, 0, 0, -1], # met-L b - [0, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0], # atp - [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], # pi - [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], # ppi - [0, 1, 1, -1, -1, 0, 0, 0, 0, 0, 0], # amet b - [0, 0, 0, 1, 1, -1, 0, 0, 0, 0, 0], # ahcys b - [0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0], # gly - [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # sarcs - [0, 0, 0, 0, 0, 1, -1, -1, -1, 0, 0], # hcys-L b - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], # adn - [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], # thf - [0, 0, 0, 0, 0, 0, -1, 0, 0, 1, 0], # 5mthf b - [0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0], # mlthf - [0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0], # glyb - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # dmgly - [0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0], # ser-L - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], # nadp - [0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0], # nadph - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], # cyst-L - ], - dtype=jnp.float64, - ), - balanced_species=jnp.array( - [ - 0, # met-L - 4, # amet - 5, # ahcys - 8, # hcys-L - 11, # 5mthf - ] - ), - unbalanced_species=jnp.array( - [ - 1, # atp - 2, # pi - 3, # ppi - 6, # gly - 7, # sarcs - 9, # adn - 10, # thf - 12, # mlthf - 13, # glyb - 14, # dmgly - 15, # ser-L - 16, # nadp - 17, # nadph - 18, # cyst-L - ] - ), -) -rate_equations = [ - Drain(sign=jnp.array(1.0), drain_ix=0), # met-L source - IrreversibleMichaelisMenten( # MAT1 - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([0], dtype=jnp.int16), - reactant_stoichiometry=jnp.array( - [-1.0, -1.0, 1.0, 1.0, 1.0], dtype=jnp.int16 + +structure = RateEquationKineticModelStructure( + S=S, + species=species, + reactions=reactions, + balanced_species=balanced_species, + rate_equations=[ + Drain(sign=1.0), # met-L source + IrreversibleMichaelisMenten(), # MAT1 + AllostericIrreversibleMichaelisMenten( # MAT3 + subunits=2, + ix_allosteric_activators=np.array([0, 4], dtype=np.int16), ), - ix_substrate=jnp.array([0, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([4], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), - ), - AllostericIrreversibleMichaelisMenten( # MAT3 - kcat_ix=1, - enzyme_ix=1, - km_ix=jnp.array([2, 3], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array( - [-1.0, -1.0, 1.0, 1.0, 1.0], dtype=jnp.int16 + IrreversibleMichaelisMenten(), # METH + AllostericIrreversibleMichaelisMenten( # GNMT1 + subunits=4, + ix_allosteric_inhibitors=np.array([12], dtype=np.int16), + ix_allosteric_activators=np.array([4], dtype=np.int16), ), - ix_substrate=jnp.array([0, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), - subunits=2, - tc_ix=0, - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - ix_dc_activation=jnp.array([0, 1], dtype=jnp.int16), - species_inhibition=jnp.array([], dtype=jnp.int16), - species_activation=jnp.array([0, 4], dtype=jnp.int16), - ), - IrreversibleMichaelisMenten( # METH - kcat_ix=2, - enzyme_ix=2, - km_ix=jnp.array([4], dtype=jnp.int16), - ki_ix=jnp.array([1], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - ix_substrate=jnp.array([4], dtype=jnp.int16), - ix_ki_species=jnp.array([5], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ), - AllostericIrreversibleMichaelisMenten( # GNMT1 - kcat_ix=3, - enzyme_ix=3, - km_ix=jnp.array([5, 6], dtype=jnp.int16), - ki_ix=jnp.array([2], dtype=jnp.int16), - reactant_stoichiometry=jnp.array( - [-1.0, 1.0, -1.0, 1.0], dtype=jnp.int16 + ReversibleMichaelisMenten( + water_stoichiometry=-1.0, + ), # AHC + IrreversibleMichaelisMenten(), # MS + IrreversibleMichaelisMenten(), # BHMT + AllostericIrreversibleMichaelisMenten( # CBS1 + subunits=2, + ix_allosteric_inhibitors=np.array([4], dtype=np.int16), ), - ix_substrate=jnp.array([4, 6], dtype=jnp.int16), - ix_ki_species=jnp.array([5], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 2], dtype=jnp.int16), - subunits=4, - tc_ix=1, - ix_dc_activation=jnp.array([2], dtype=jnp.int16), - ix_dc_inhibition=jnp.array([3], dtype=jnp.int16), - species_inhibition=jnp.array([12], dtype=jnp.int16), - species_activation=jnp.array([4], dtype=jnp.int16), - ), - ReversibleMichaelisMenten( # AHC - kcat_ix=4, - enzyme_ix=4, - km_ix=jnp.array([7, 8, 9], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1.0, 1.0, 1.0], dtype=jnp.int16), - ix_substrate=jnp.array([5], dtype=jnp.int16), - ix_product=jnp.array([8, 9], dtype=jnp.int16), - ix_reactants=jnp.array([5, 8, 9], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - product_km_positions=jnp.array([1, 2], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - product_reactant_positions=jnp.array([1, 2], dtype=jnp.int16), - water_stoichiometry=jnp.array(-1.0), - reactant_to_dgf=jnp.array([5, 8, 9], dtype=jnp.int16), - ), - IrreversibleMichaelisMenten( # MS - kcat_ix=5, - enzyme_ix=5, - km_ix=jnp.array([10, 11], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([1, -1, 1, -1], dtype=jnp.int16), - ix_substrate=jnp.array([8, 11], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([1, 3], dtype=jnp.int16), - ), - IrreversibleMichaelisMenten( # BHMT - kcat_ix=6, - enzyme_ix=6, - km_ix=jnp.array([12, 13], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([1, -1, -1, 1], dtype=jnp.int16), - ix_substrate=jnp.array([8, 13], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([1, 2], dtype=jnp.int16), - ), - AllostericIrreversibleMichaelisMenten( # CBS1 - kcat_ix=7, - enzyme_ix=7, - km_ix=jnp.array([14, 15], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, -1, 1], dtype=jnp.int16), - ix_substrate=jnp.array([8, 15], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), - subunits=2, - tc_ix=2, - ix_dc_activation=jnp.array([4], dtype=jnp.int16), - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - species_inhibition=jnp.array([4], dtype=jnp.int16), - species_activation=jnp.array([], dtype=jnp.int16), - ), - AllostericIrreversibleMichaelisMenten( # MTHFR - kcat_ix=8, - enzyme_ix=8, - km_ix=jnp.array([16, 17], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([1, -1, 1, -1], dtype=jnp.int16), - ix_substrate=jnp.array([12, 17], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([1, 3], dtype=jnp.int16), - subunits=2, - tc_ix=3, - ix_dc_activation=jnp.array([6], dtype=jnp.int16), - ix_dc_inhibition=jnp.array([5], dtype=jnp.int16), - species_inhibition=jnp.array([4], dtype=jnp.int16), - species_activation=jnp.array([5], dtype=jnp.int16), - ), - IrreversibleMichaelisMenten( # PROT - kcat_ix=9, - enzyme_ix=9, - km_ix=jnp.array([18], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1.0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ), -] + AllostericIrreversibleMichaelisMenten( # MTHFR + subunits=2, + ix_allosteric_inhibitors=np.array([4], dtype=np.int16), + ix_allosteric_activators=np.array([5], dtype=np.int16), + ), + IrreversibleMichaelisMenten(), # PROT + ], +) steady_state = jnp.array( [ 4.233000e-05, # met-L @@ -512,4 +266,4 @@ 6.534400e-06, # 5mthf ] ) -model = RateEquationModel(parameters, structure, rate_equations) +model = RateEquationModel(parameters, structure) diff --git a/src/enzax/kinetic_model.py b/src/enzax/kinetic_model.py index 1f9c39d..8478f1f 100644 --- a/src/enzax/kinetic_model.py +++ b/src/enzax/kinetic_model.py @@ -32,13 +32,16 @@ def __init__( species, reactions, balanced_species, - species_to_dgf_ix, + species_to_dgf_ix=None, ): self.S = S self.species = species self.reactions = reactions self.balanced_species = balanced_species - self.species_to_dgf_ix = species_to_dgf_ix + if species_to_dgf_ix is None: + self.species_to_dgf_ix = np.arange(len(species), dtype=np.int16) + else: + self.species_to_dgf_ix = species_to_dgf_ix self.balanced_species_ix = np.array( [i for i, s in enumerate(species) if s in balanced_species], dtype=np.int16, @@ -73,8 +76,8 @@ def __init__( species, reactions, balanced_species, - species_to_dgf_ix, rate_equations, + species_to_dgf_ix=None, ): super().__init__( S, species, reactions, balanced_species, species_to_dgf_ix diff --git a/src/enzax/rate_equations/drain.py b/src/enzax/rate_equations/drain.py index 0e1994c..680a108 100644 --- a/src/enzax/rate_equations/drain.py +++ b/src/enzax/rate_equations/drain.py @@ -1,27 +1,30 @@ +import equinox as eqx from jax import numpy as jnp +import numpy as np +from numpy.typing import NDArray from jaxtyping import PyTree, Scalar from enzax.rate_equation import ConcArray, RateEquation -def get_drain_flux(sign: Scalar, log_v: Scalar) -> Scalar: - """Get the flux of a drain reaction. - - :param sign: a scalar value (should be either one or zero) representing the direction of the reaction. - - :param log_v: a scalar representing the magnitude of the reaction, on log scale. - - """ # Noqa: E501 - return sign * jnp.exp(log_v) +class DrainInput(eqx.Module): + abs_v: Scalar class Drain(RateEquation): """A drain reaction.""" - sign: Scalar - drain_ix: int + sign: float + + def get_input( + self, + parameters: PyTree, + rxn_ix: int, + S: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ): + return DrainInput(abs_v=jnp.exp(parameters.log_drain[rxn_ix])) - def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: + def __call__(self, conc: ConcArray, drain_input: PyTree) -> Scalar: """Get the flux of a drain reaction.""" - log_v = parameters.log_drain[self.drain_ix] - return get_drain_flux(self.sign, log_v) + return self.sign * drain_input.abs_v diff --git a/src/enzax/rate_equations/generalised_mwc.py b/src/enzax/rate_equations/generalised_mwc.py index 6995103..408856e 100644 --- a/src/enzax/rate_equations/generalised_mwc.py +++ b/src/enzax/rate_equations/generalised_mwc.py @@ -41,7 +41,7 @@ def get_allosteric_irreversible_michaelis_menten_input( kcat=jnp.exp(parameters.log_kcat[rxn_ix]), enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), ix_substrate=ix_substrate, - substrate_kms=jnp.exp(parameters.log_km[rxn_ix][ix_substrate]), + substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]), substrate_stoichiometry=Sj[ix_substrate], ix_ki_species=ci_ix, ki=jnp.exp(parameters.log_ki[rxn_ix]), @@ -69,7 +69,7 @@ def get_allosteric_reversible_michaelis_menten_input( substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]), product_kms=jnp.exp(parameters.log_km[rxn_ix][1]), ki=jnp.exp(parameters.log_ki[rxn_ix]), - dgf=jnp.exp(parameters.dgf[species_to_dgf_ix][ix_reactant]), + dgf=parameters.dgf[species_to_dgf_ix][ix_reactant], temperature=parameters.temperature, ix_ki_species=ci_ix, ix_reactant=ix_reactant, diff --git a/src/enzax/rate_equations/michaelis_menten.py b/src/enzax/rate_equations/michaelis_menten.py index 433ce32..3177f34 100644 --- a/src/enzax/rate_equations/michaelis_menten.py +++ b/src/enzax/rate_equations/michaelis_menten.py @@ -73,7 +73,7 @@ def get_reversible_michaelis_menten_input( substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]), product_kms=jnp.exp(parameters.log_km[rxn_ix][1]), ki=jnp.exp(parameters.log_ki[rxn_ix]), - dgf=jnp.exp(parameters.dgf[species_to_dgf_ix][ix_reactant]), + dgf=parameters.dgf[species_to_dgf_ix][ix_reactant], temperature=parameters.temperature, ix_ki_species=ci_ix, ix_reactant=ix_reactant, diff --git a/tests/test_examples.py b/tests/test_examples.py index 0c183e5..d11cc2a 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,13 +1,13 @@ from jax import numpy as jnp import pytest -from enzax.examples import linear +from enzax.examples import linear, methionine @pytest.mark.parametrize( ["model", "steady_state"], [ - # (methionine.model, methionine.steady_state), + (methionine.model, methionine.steady_state), (linear.model, linear.steady_state), ], ) From 34cc521fb371fab46cfad2dcb15dfbc5323abd61 Mon Sep 17 00:00:00 2001 From: teddygroves Date: Thu, 28 Nov 2024 14:13:57 +0100 Subject: [PATCH 09/27] Fix incorrect kms in methionine model --- src/enzax/examples/methionine.py | 34 +++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/enzax/examples/methionine.py b/src/enzax/examples/methionine.py index 1e16011..a2202fe 100644 --- a/src/enzax/examples/methionine.py +++ b/src/enzax/examples/methionine.py @@ -152,19 +152,31 @@ class ParameterDefinition(eqx.Module): ] ), log_km={ - 1: [jnp.array([0.000106919, 0.00203015])], # MAT1 met-L, atp - 2: [jnp.array([0.00113258, 0.00236759])], # MAT3 met-L atp - 3: [jnp.array([9.37e-06])], # amet METH-Gen - 4: [jnp.array([0.000520015, 0.00253545])], # amet GNMT1, # gly GNMT1 + 1: [jnp.log(jnp.array([0.000106919, 0.00203015]))], # MAT1 met-L, atp + 2: [jnp.log(jnp.array([0.00113258, 0.00236759]))], # MAT3 met-L atp + 3: [jnp.log(jnp.array([9.37e-06]))], # amet METH-Gen + 4: [ + jnp.log(jnp.array([0.000520015, 0.00253545])) + ], # amet GNMT1, # gly GNMT1 5: [ - jnp.array([2.32e-05]), # ahcys AHC1 - jnp.array([1.06e-05, 5.66e-06]), # hcys-L AHC1, # adn AHC1 + jnp.log(jnp.array([2.32e-05])), # ahcys AHC1 + jnp.log( + jnp.array([1.06e-05, 5.66e-06]) + ), # hcys-L AHC1, # adn AHC1 ], - 6: [jnp.array([1.71e-06, 6.94e-05])], # hcys-L MS1, # 5mthf MS1 - 7: [jnp.array([1.98e-05, 0.00845898])], # hcys-L BHMT1, # glyb BHMT1 - 8: [jnp.array([4.24e-05, 2.83e-06])], # hcys-L CBS1, # ser-L CBS1 - 9: [jnp.array([8.08e-05, 2.09e-05])], # mlthf MTHFR1, # nadph MTHFR1 - 10: [jnp.array([4.39e-05])], # met-L PROT1 + 6: [ + jnp.log(jnp.array([1.71e-06, 6.94e-05])) + ], # hcys-L MS1, # 5mthf MS1 + 7: [ + jnp.log(jnp.array([1.98e-05, 0.00845898])) + ], # hcys-L BHMT1, # glyb BHMT1 + 8: [ + jnp.log(jnp.array([4.24e-05, 2.83e-06])) + ], # hcys-L CBS1, # ser-L CBS1 + 9: [ + jnp.log(jnp.array([8.08e-05, 2.09e-05])) + ], # mlthf MTHFR1, # nadph MTHFR1 + 10: [jnp.log(jnp.array([4.39e-05]))], # met-L PROT1 }, temperature=jnp.array(298.15), log_ki={ From 70395e714406fccc70e31171c8668382bf419dcd Mon Sep 17 00:00:00 2001 From: teddygroves Date: Thu, 28 Nov 2024 14:30:39 +0100 Subject: [PATCH 10/27] flatten species indexes --- src/enzax/rate_equations/generalised_mwc.py | 8 ++++---- src/enzax/rate_equations/michaelis_menten.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/enzax/rate_equations/generalised_mwc.py b/src/enzax/rate_equations/generalised_mwc.py index 408856e..dbab431 100644 --- a/src/enzax/rate_equations/generalised_mwc.py +++ b/src/enzax/rate_equations/generalised_mwc.py @@ -36,7 +36,7 @@ def get_allosteric_irreversible_michaelis_menten_input( ci_ix: NDArray[np.int16], ) -> AllostericIrreversibleMichaelisMentenInput: Sj = S[:, rxn_ix] - ix_substrate = np.argwhere(Sj < 0.0) + ix_substrate = np.argwhere(Sj < 0.0).flatten() return AllostericIrreversibleMichaelisMentenInput( kcat=jnp.exp(parameters.log_kcat[rxn_ix]), enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), @@ -60,9 +60,9 @@ def get_allosteric_reversible_michaelis_menten_input( water_stoichiometry: float, ) -> AllostericReversibleMichaelisMentenInput: Sj = S[:, rxn_ix] - ix_reactant = np.argwhere(Sj != 0.0) - ix_substrate = np.argwhere(Sj < 0.0) - ix_product = np.argwhere(Sj > 0.0) + ix_reactant = np.argwhere(Sj != 0.0).flatten() + ix_substrate = np.argwhere(Sj < 0.0).flatten() + ix_product = np.argwhere(Sj > 0.0).flatten() return AllostericReversibleMichaelisMentenInput( kcat=jnp.exp(parameters.log_kcat[rxn_ix]), enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), diff --git a/src/enzax/rate_equations/michaelis_menten.py b/src/enzax/rate_equations/michaelis_menten.py index 3177f34..51fbb44 100644 --- a/src/enzax/rate_equations/michaelis_menten.py +++ b/src/enzax/rate_equations/michaelis_menten.py @@ -64,9 +64,9 @@ def get_reversible_michaelis_menten_input( water_stoichiometry: float, ) -> ReversibleMichaelisMentenInput: Sj = S[:, rxn_ix] - ix_reactant = np.argwhere(Sj != 0.0) - ix_substrate = np.argwhere(Sj < 0.0) - ix_product = np.argwhere(Sj > 0.0) + ix_reactant = np.argwhere(Sj != 0.0).flatten() + ix_substrate = np.argwhere(Sj < 0.0).flatten() + ix_product = np.argwhere(Sj > 0.0).flatten() return ReversibleMichaelisMentenInput( kcat=jnp.exp(parameters.log_kcat[rxn_ix]), enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), From 31a0cdcce56ad48688b97e6822e257ad6e571c2c Mon Sep 17 00:00:00 2001 From: teddygroves Date: Thu, 28 Nov 2024 14:51:31 +0100 Subject: [PATCH 11/27] Fix methionine kis --- src/enzax/examples/methionine.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/enzax/examples/methionine.py b/src/enzax/examples/methionine.py index a2202fe..408fe4f 100644 --- a/src/enzax/examples/methionine.py +++ b/src/enzax/examples/methionine.py @@ -241,20 +241,25 @@ class ParameterDefinition(eqx.Module): balanced_species=balanced_species, rate_equations=[ Drain(sign=1.0), # met-L source - IrreversibleMichaelisMenten(), # MAT1 + IrreversibleMichaelisMenten( # MAT1 + ix_ki_species=np.array([4], dtype=np.int16), + ), AllostericIrreversibleMichaelisMenten( # MAT3 subunits=2, ix_allosteric_activators=np.array([0, 4], dtype=np.int16), ), - IrreversibleMichaelisMenten(), # METH + IrreversibleMichaelisMenten( # METH + ix_ki_species=np.array([5], dtype=np.int16) + ), AllostericIrreversibleMichaelisMenten( # GNMT1 subunits=4, ix_allosteric_inhibitors=np.array([12], dtype=np.int16), ix_allosteric_activators=np.array([4], dtype=np.int16), + ix_ki_species=np.array([5], dtype=np.int16), ), - ReversibleMichaelisMenten( + ReversibleMichaelisMenten( # AHC water_stoichiometry=-1.0, - ), # AHC + ), IrreversibleMichaelisMenten(), # MS IrreversibleMichaelisMenten(), # BHMT AllostericIrreversibleMichaelisMenten( # CBS1 From a2019229406053456502d8367dd7bfa436429253 Mon Sep 17 00:00:00 2001 From: teddygroves Date: Thu, 28 Nov 2024 16:27:42 +0100 Subject: [PATCH 12/27] Update gradient test --- tests/data/methionine_pldf_grad.json | 197 ++++++++++++--------------- tests/test_lp_grad.py | 162 +++++++++++++--------- 2 files changed, 188 insertions(+), 171 deletions(-) diff --git a/tests/data/methionine_pldf_grad.json b/tests/data/methionine_pldf_grad.json index 1d4bf6d..c1c66cd 100644 --- a/tests/data/methionine_pldf_grad.json +++ b/tests/data/methionine_pldf_grad.json @@ -1,107 +1,90 @@ -{ - "log_kcat": { - "MAT1": -8.515730409411995, - "MAT3": -2.5685756593727778, - "METH-Gen": -27.620460850911506, - "GNMT1": -4.548438844490129, - "AHC1": -11.75344338177346, - "MS1": -16.03214664980906, - "BHMT1": 12.038008848152892, - "CBS1": -27.970086306108342, - "MTHFR1": 21.633517021058395, - "PROT1": -11.72862587853564 - }, - "log_enzyme": { - "MAT1": 6.136798020189005, - "MAT3": 12.083952770228223, - "METH-Gen": -12.96793242131051, - "GNMT1": 10.10408958511087, - "AHC1": 2.8990850478275423, - "MS1": -1.379618220208057, - "BHMT1": 26.690537277753894, - "CBS1": -13.317557876507342, - "MTHFR1": 36.286045450659394, - "PROT1": 2.9239025510653605 - }, - "log_drain": { - "the_drain": 77.11394277227623 - }, - "log_km": { - "met-L MAT1": 6.190853781501735, - "atp MAT1": 5.292693046521437, - "met-L MAT3": 2.4814738123803086, - "atp MAT3": 1.705092962802143, - "amet METH-Gen": 6.601750014600864, - "amet GNMT1": 4.590545886354766, - "gly GNMT1": 4.820977250097301, - "ahcys AHC1": 11.683829123279876, - "hcys-L AHC1": -2.9137920929850405, - "adn AHC1": -1.769215504424688, - "hcys-L MS1": 5.240078482919715, - "5mthf MS1": 14.652528429601952, - "hcys-L BHMT1": -10.22016805310029, - "glyb BHMT1": -10.688987009618199, - "hcys-L CBS1": 25.82504526708683, - "ser-L CBS1": 0.049935152823941564, - "mlthf MTHFR1": -20.851184438848044, - "nadph MTHFR1": -2.4216899858227414, - "met-L PROT1": 5.971084568691083 - }, - "dgf": { - "met-L": 0.0, - "atp": 0.0, - "pi": 0.0, - "ppi": 0.0, - "amet": 0.0, - "ahcys": -1.3226601224426768, - "gly": 0.0, - "sarcs": 0.0, - "hcys-L": 1.3226601224426768, - "adn": 1.3226601224426768, - "thf": 0.0, - "5mthf": 0.0, - "mlthf": 0.0, - "glyb": 0.0, - "dmgly": 0.0, - "ser-L": 0.0, - "nadp": 0.0, - "nadph": 0.0, - "cyst-L": 0.0 - }, - "log_ki": { - "MAT1": -0.3185780584014895, - "METH-Gen": -0.24799879824208654, - "GNMT1": 0.0018358603202088403 - }, - "log_conc_unbalanced": { - "atp": 7.654742420277419, - "pi": 14.652528429601, - "ppi": 14.652528429601, - "gly": 9.831551179503698, - "sarcs": 14.652528429601988, - "adn": 19.700379108345924, - "thf": 14.652528429601986, - "mlthf": 35.57574270245104, - "glyb": 25.3415154392192, - "dmgly": 14.652528429601988, - "ser-L": 14.60259327677706, - "nadp": 14.652528429601988, - "nadph": 17.074218415424728, - "cyst-L": 14.652528429601988 - }, - "log_transfer_constant": { - "METAT": 0.07549016742205726, - "GNMT": 1.3246972471203298, - "CBS": 0.00007813277693021312, - "MTHFR": -0.3918664031389874 - }, - "log_dissociation_constant": { - "met-L MAT3": 0.009116267028546848, - "amet MAT3": 0.012649140053254744, - "amet GNMT1": 3.233378796396484, - "mlthf GNMT1": -0.0720298340010019, - "amet CBS1": 0.0, - "amet MTHFR1": 0.5327826505992013, - "ahcys MTHFR1": -0.0637729300752556 - } -} +[ + -2.70293053e+01, + -2.31079335e+01, + -1.08319612e+01, + -7.44295628e+00, + -2.88223090e+01, + -2.00410069e+01, + -2.10469921e+01, + -5.10088063e+01, + 1.27209420e+01, + 7.72397217e+00, + -2.28768851e+01, + -6.39694146e+01, + 4.46212536e+01, + 4.66681334e+01, + -1.12748483e+02, + -2.18539365e-01, + 9.10356077e+01, + 1.05730245e+01, + -2.60683551e+01, + 3.71797358e+01, + 1.12121725e+01, + 1.20587113e+02, + 1.98571817e+01, + 5.13127257e+01, + 6.99924909e+01, + -5.25579647e+01, + 1.22113405e+02, + -9.44512467e+01, + 5.12044438e+01, + 5.67943664e+01, + 2.82182825e+01, + 1.37977316e+02, + 2.76214987e+01, + 3.25225547e+01, + 9.94449635e+01, + 1.06583439e+00, + 1.60305500e+02, + -2.02545386e+01, + 1.79829059e+01, + 1.39091493e+00, + 1.08272907e+00, + -8.01475394e-03, + 0.00000000e+00, + 0.00000000e+00, + 0.00000000e+00, + 0.00000000e+00, + 0.00000000e+00, + 5.77442577e+00, + 0.00000000e+00, + 0.00000000e+00, + -5.77442577e+00, + -5.77442577e+00, + 0.00000000e+00, + 0.00000000e+00, + 0.00000000e+00, + 0.00000000e+00, + 0.00000000e+00, + 0.00000000e+00, + 0.00000000e+00, + 0.00000000e+00, + 0.00000000e+00, + 4.55892028e-01, + -3.34185220e+01, + -6.39694151e+01, + -6.39694159e+01, + -4.29224268e+01, + -6.39694151e+01, + -8.60071449e+01, + -6.39694154e+01, + -1.55319483e+02, + -1.10637551e+02, + -6.39694162e+01, + -6.37508767e+01, + -6.39694172e+01, + -7.45424412e+01, + -6.39694154e+01, + 3.14461097e-01, + 3.03173942e-04, + -2.32611507e+00, + -3.97937186e-02, + -5.52151996e-02, + -1.41159870e+01, + 2.78431182e-01, + -3.29524486e-01, + -5.78323906e+00, + -6.06395890e-04, + 1.71087794e+00, + -3.36459182e+02 +] diff --git a/tests/test_lp_grad.py b/tests/test_lp_grad.py index 71b9284..6fd6267 100644 --- a/tests/test_lp_grad.py +++ b/tests/test_lp_grad.py @@ -1,15 +1,12 @@ import json import jax -import pytest from jax import numpy as jnp +from jax.scipy.stats import norm +from jax.flatten_util import ravel_pytree from enzax.examples import methionine -from enzax.mcmc import ( - ObservationSet, - AllostericMichaelisMentenPriorSet, - ind_prior_from_truth, - posterior_logdensity_amm, -) +from enzax.kinetic_model import RateEquationModel +from enzax.mcmc import ObservationSet from enzax.steady_state import get_kinetic_model_steady_state import importlib.resources @@ -20,64 +17,85 @@ jax.config.update("jax_enable_x64", True) SEED = 1234 -methionine_pldf_grad_file = ( - importlib.resources.files(data) / "methionine_pldf_grad.json" +methionine_pldf_grad_file = importlib.resources.files(data) / "t.eqx" + +obs_conc = jnp.array( + [ + 3.99618131e-05, + 1.24186458e-03, + 9.44053469e-04, + 4.72041839e-04, + 2.92625684e-05, + 2.04876101e-07, + 1.37054850e-03, + 9.44053469e-08, + 3.32476221e-06, + 9.53494003e-07, + 2.11467977e-05, + 6.16881926e-06, + 2.97376843e-06, + 1.00785260e-03, + 4.72026734e-05, + 1.49849607e-03, + 1.15174523e-06, + 2.31424323e-04, + 2.11467977e-06, + ], + dtype=jnp.float64, +) +obs_flux = jnp.array( + [ + -0.00425181, + 0.03739644, + 0.01397071, + -0.04154405, + -0.05396867, + 0.01236334, + -0.07089178, + -0.02136595, + 0.00152784, + -0.02482788, + -0.01588131, + ], + dtype=jnp.float64, +) +obs_enzyme = jnp.array( + [ + 0.00097884, + 0.00100336, + 0.00105027, + 0.00099059, + 0.00096148, + 0.00107917, + 0.00104588, + 0.00138744, + 0.00107483, + 0.0009662, + ], + dtype=jnp.float64, ) def test_lp_grad(): - model = methionine structure = methionine.structure - rate_equations = methionine.rate_equations true_parameters = methionine.parameters true_model = methionine.model default_state_guess = jnp.full((5,), 0.01) true_states = get_kinetic_model_steady_state( true_model, default_state_guess ) - prior = AllostericMichaelisMentenPriorSet( - log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1), - log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1), - log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1), - dgf=( - ind_prior_from_truth(true_parameters.dgf, 0.1)[0], - jnp.diag( - jnp.square(ind_prior_from_truth(true_parameters.dgf, 0.1)[1]) - ), - ), - log_km=ind_prior_from_truth(true_parameters.log_km, 0.1), - log_conc_unbalanced=ind_prior_from_truth( - true_parameters.log_conc_unbalanced, 0.1 - ), - temperature=ind_prior_from_truth(true_parameters.temperature, 0.1), - log_ki=ind_prior_from_truth(true_parameters.log_ki, 0.1), - log_transfer_constant=ind_prior_from_truth( - true_parameters.log_transfer_constant, 0.1 - ), - log_dissociation_constant=ind_prior_from_truth( - true_parameters.log_dissociation_constant, 0.1 - ), - ) + flat_true_params, _ = ravel_pytree(methionine.parameters) # get true concentration true_conc = jnp.zeros(methionine.structure.S.shape[0]) - true_conc = true_conc.at[methionine.structure.balanced_species].set( + true_conc = true_conc.at[methionine.structure.balanced_species_ix].set( true_states ) - true_conc = true_conc.at[methionine.structure.unbalanced_species].set( + true_conc = true_conc.at[methionine.structure.unbalanced_species_ix].set( jnp.exp(true_parameters.log_conc_unbalanced) ) - # get true flux - true_flux = true_model.flux(true_states) - # simulate observations error_conc = 0.03 error_flux = 0.05 error_enzyme = 0.03 - key = jax.random.key(SEED) - obs_conc = jnp.exp(jnp.log(true_conc) + jax.random.normal(key) * error_conc) - obs_enzyme = jnp.exp( - true_parameters.log_enzyme + jax.random.normal(key) * error_enzyme - ) - obs_flux = true_flux + jax.random.normal(key) * error_conc obs = ObservationSet( conc=obs_conc, flux=obs_flux, @@ -86,25 +104,41 @@ def test_lp_grad(): flux_scale=error_flux, enzyme_scale=error_enzyme, ) - pldf = functools.partial( - posterior_logdensity_amm, + + def joint_log_density(params, prior_mean, prior_sd, obs): + flat_params, _ = ravel_pytree(params) + model = RateEquationModel(params, methionine.structure) + steady = get_kinetic_model_steady_state(model, default_state_guess) + unbalanced = jnp.exp(params.log_conc_unbalanced) + conc = jnp.zeros(structure.S.shape[0]) + conc = conc.at[structure.balanced_species_ix].set(steady) + conc = conc.at[structure.unbalanced_species_ix].set(unbalanced) + flux = model.flux(steady) + log_prior = norm.pdf(flat_params, prior_mean, prior_sd).sum() + flat_log_enzyme, _ = ravel_pytree(params.log_enzyme) + log_liklihood = jnp.sum( + jnp.array( + [ + norm.logpdf( + jnp.log(obs.conc), jnp.log(conc), obs.conc_scale + ).sum(), + norm.logpdf( + jnp.log(obs.enzyme), flat_log_enzyme, obs.enzyme_scale + ).sum(), + norm.logpdf(obs.flux, flux, obs.flux_scale).sum(), + ] + ) + ) + return log_prior + log_liklihood + + posterior_log_density = functools.partial( + joint_log_density, + prior_mean=flat_true_params, + prior_sd=0.1, obs=obs, - prior=prior, - structure=structure, - rate_equations=rate_equations, - guess=default_state_guess, ) - pldf_grad = jax.jacrev(pldf)(methionine.parameters) - index_pldf_grad = { - p: { - c: float(getattr(pldf_grad, p)[i]) - for i, c in enumerate(model.coords[model.dims[p][0]]) - } - for p in model.dims.keys() - } + gradient = jax.jacrev(posterior_log_density)(methionine.parameters) + _, grad_pytree_def = ravel_pytree(gradient) with open(methionine_pldf_grad_file, "r") as file: - saved_pldf_grad = file.read() - - true_gradient = json.loads(saved_pldf_grad) - for p, vals in true_gradient.items(): - assert true_gradient[p] == pytest.approx(index_pldf_grad[p]) + expected_gradient = grad_pytree_def(jnp.array(json.load(file))) + assert gradient == expected_gradient From cccb879c10de63704f5eb510670ba0c08c9ffd59 Mon Sep 17 00:00:00 2001 From: teddygroves Date: Thu, 28 Nov 2024 16:35:10 +0100 Subject: [PATCH 13/27] fix gradient test --- tests/data/expected_methionine_gradient.json | 1 + 1 file changed, 1 insertion(+) create mode 100644 tests/data/expected_methionine_gradient.json diff --git a/tests/data/expected_methionine_gradient.json b/tests/data/expected_methionine_gradient.json new file mode 100644 index 0000000..e8491a3 --- /dev/null +++ b/tests/data/expected_methionine_gradient.json @@ -0,0 +1 @@ +[-27.029305319714087, -23.107933540053693, -10.831961241573868, -7.442956276058797, -28.82230903289454, -20.04100693766347, -21.046992055692126, -51.008806318357124, 12.720942038333074, 7.723972171145526, -22.87688514741915, -63.96941462805278, 44.62125359725783, 46.668133361083704, -112.7484826050269, -0.21853936474188224, 91.03560765315373, 10.573024526064392, -26.068355084866106, 37.17973577495537, 11.212172526568366, 120.58711257251609, 19.857181701644674, 51.31272574643639, 69.99249085550201, -52.55796473716283, 122.11340482806465, -94.45124667621965, 51.204443757792795, 56.794366392943715, 28.218282451377277, 137.9773162976142, 27.62149868641839, 32.52255474993412, 99.44496349751932, 1.065834388161285, 160.30549966124383, -20.25453860618623, 17.98290589428491, 1.390914925131441, 1.0827290721714595, -0.008014753935453442, 0.0, 0.0, 0.0, 0.0, 0.0, 5.774425772898505, 0.0, 0.0, -5.774425772898505, -5.774425772898505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.455892028402676, -33.418521957445726, -63.96941505648969, -63.96941589328896, -42.922426803111094, -63.96941505649067, -86.00714492908733, -63.9694153507284, -155.31948281663819, -110.63755110207094, -63.96941623344747, -63.75087668807969, -63.969417159576764, -74.54244120146943, -63.969415350730365, 0.3144610971331224, 0.00030317394151436246, -2.3261150730556386, -0.03979371864763597, -0.055215199560989005, -14.115986977104384, 0.2784311817245577, -0.3295244858681964, -5.783239063037292, -0.0006063958903968876, 1.710877943445662, -336.4591824327092] From b9736e65f60e021382ff9df9c401848a4f51951a Mon Sep 17 00:00:00 2001 From: teddygroves Date: Thu, 28 Nov 2024 16:36:55 +0100 Subject: [PATCH 14/27] delete unused files, tidy grad test --- src/enzax/mcmc.py | 73 ---------------------- src/enzax/parameters.py | 58 ------------------ tests/data/methionine_pldf_grad.json | 90 ---------------------------- tests/test_lp_grad.py | 4 +- 4 files changed, 3 insertions(+), 222 deletions(-) delete mode 100644 src/enzax/parameters.py delete mode 100644 tests/data/methionine_pldf_grad.json diff --git a/src/enzax/mcmc.py b/src/enzax/mcmc.py index 10d16dd..20e0c81 100644 --- a/src/enzax/mcmc.py +++ b/src/enzax/mcmc.py @@ -12,14 +12,6 @@ from jax.scipy.stats import norm, multivariate_normal from jaxtyping import Array, Float, PyTree, ScalarLike -from enzax.kinetic_model import ( - RateEquationModel, - KineticModelStructure, -) -from enzax.parameters import AllostericMichaelisMentenParameterSet -from enzax.rate_equation import RateEquation -from enzax.steady_state import get_kinetic_model_steady_state - @chex.dataclass class ObservationSet: @@ -33,25 +25,6 @@ class ObservationSet: enzyme_scale: ScalarLike -@chex.dataclass -class AllostericMichaelisMentenPriorSet: - """Priors for an allosteric Michaelis-Menten model.""" - - log_kcat: Float[Array, "2 n_enzyme"] - log_enzyme: Float[Array, "2 n_enzyme"] - log_drain: Float[Array, "2 n_drain"] - dgf: tuple[ - Float[Array, " n_metabolite"], - Float[Array, " n_metabolite n_metabolite"], - ] - log_km: Float[Array, "2 n_km"] - log_ki: Float[Array, "2 n_ki"] - log_conc_unbalanced: Float[Array, "2 n_unbalanced"] - temperature: Float[Array, "2"] - log_transfer_constant: Float[Array, "2 n_allosteric_enzyme"] - log_dissociation_constant: Float[Array, "2 n_allosteric_effect"] - - class AdaptationKwargs(TypedDict): """Keyword arguments to the blackjax function window_adaptation.""" @@ -76,52 +49,6 @@ def mv_normal_prior_logdensity( ) -def posterior_logdensity_amm( - parameters: AllostericMichaelisMentenParameterSet, - structure: KineticModelStructure, - rate_equations: list[RateEquation], - obs: ObservationSet, - prior: AllostericMichaelisMentenPriorSet, - guess: Float[Array, " n_balanced"], -): - """Get the log density for an allosteric Michaelis-Menten model.""" - model = RateEquationModel(parameters, structure, rate_equations) - steady = get_kinetic_model_steady_state(model, guess) - flux = model.flux(steady) - conc = jnp.zeros(model.structure.S.shape[0]) - conc = conc.at[model.structure.balanced_species].set(steady) - conc = conc.at[model.structure.unbalanced_species].set( - jnp.exp(parameters.log_conc_unbalanced) - ) - likelihood_logdensity = ( - norm.logpdf(jnp.log(obs.conc), jnp.log(conc), obs.conc_scale).sum() - + norm.logpdf(obs.flux, flux[0], obs.flux_scale).sum() - + norm.logpdf( - jnp.log(obs.enzyme), parameters.log_enzyme, obs.enzyme_scale - ).sum() - ) - prior_logdensity = ( - ind_normal_prior_logdensity(parameters.log_kcat, prior.log_kcat) - + ind_normal_prior_logdensity(parameters.log_enzyme, prior.log_enzyme) - + ind_normal_prior_logdensity(parameters.log_drain, prior.log_drain) - + mv_normal_prior_logdensity(parameters.dgf, prior.dgf) - + ind_normal_prior_logdensity(parameters.log_km, prior.log_km) - + ind_normal_prior_logdensity( - parameters.log_conc_unbalanced, prior.log_conc_unbalanced - ) - + ind_normal_prior_logdensity(parameters.temperature, prior.temperature) - + ind_normal_prior_logdensity(parameters.log_ki, prior.log_ki) - + ind_normal_prior_logdensity( - parameters.log_transfer_constant, prior.log_transfer_constant - ) - + ind_normal_prior_logdensity( - parameters.log_dissociation_constant, - prior.log_dissociation_constant, - ) - ) - return prior_logdensity + likelihood_logdensity - - @functools.partial(jax.jit, static_argnames=["kernel", "num_samples"]) def _inference_loop(rng_key, kernel, initial_state, num_samples): """Run MCMC with blackjax.""" diff --git a/src/enzax/parameters.py b/src/enzax/parameters.py deleted file mode 100644 index 6534169..0000000 --- a/src/enzax/parameters.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Module with parameters and parameter sets. - -These are not required, but they do provide handy shape checking, for example to ensure that your parameter set has the same number of enzymes and kcat parameters. - -""" # Noqa: E501 - -import equinox as eqx -from jaxtyping import Array, Float, Scalar, jaxtyped -from typeguard import typechecked - -LogKcat = Float[Array, " n_enzyme"] -LogEnzyme = Float[Array, " n_enzyme"] -Dgf = Float[Array, " n_metabolite"] -LogKm = Float[Array, " n_km"] -LogKi = Float[Array, " n_ki"] -LogConcUnbalanced = Float[Array, " n_unbalanced"] -LogDrain = Float[Array, " n_drain"] -LogTransferConstant = Float[Array, " n_allosteric_enzyme"] -LogDissociationConstant = Float[Array, " n_allosteric_effect"] - - -@jaxtyped(typechecker=typechecked) -class MichaelisMentenParameterSet(eqx.Module): - """Parameters for a model with Michaelis Menten kinetics. - - This kind of parameter set supports models with the following rate laws: - - - enzax.rate_equations.drain.Drain - - enzax.rate_equations.michaelis_menten.IrreversibleMichaelisMenten - - enzax.rate_equations.michaelis_menten.ReversibleMichaelisMenten - - """ - - log_kcat: LogKcat - log_enzyme: LogEnzyme - dgf: Dgf - log_km: LogKm - log_ki: LogKi - log_conc_unbalanced: LogConcUnbalanced - temperature: Scalar - log_drain: LogDrain - - -class AllostericMichaelisMentenParameterSet(MichaelisMentenParameterSet): - """Parameters for a model with Michaelis Menten kinetics, with allostery. - - Reactions can be any out of: - - - drain.Drain - - michaelis_menten.IrreversibleMichaelisMenten - - michaelis_menten.ReversibleMichaelisMenten - - generalised_mwc.AllostericIrreversibleMichaelisMenten - - generalised_mwc.AllostericReversibleMichaelisMenten - - """ - - log_transfer_constant: LogTransferConstant - log_dissociation_constant: LogDissociationConstant diff --git a/tests/data/methionine_pldf_grad.json b/tests/data/methionine_pldf_grad.json deleted file mode 100644 index c1c66cd..0000000 --- a/tests/data/methionine_pldf_grad.json +++ /dev/null @@ -1,90 +0,0 @@ -[ - -2.70293053e+01, - -2.31079335e+01, - -1.08319612e+01, - -7.44295628e+00, - -2.88223090e+01, - -2.00410069e+01, - -2.10469921e+01, - -5.10088063e+01, - 1.27209420e+01, - 7.72397217e+00, - -2.28768851e+01, - -6.39694146e+01, - 4.46212536e+01, - 4.66681334e+01, - -1.12748483e+02, - -2.18539365e-01, - 9.10356077e+01, - 1.05730245e+01, - -2.60683551e+01, - 3.71797358e+01, - 1.12121725e+01, - 1.20587113e+02, - 1.98571817e+01, - 5.13127257e+01, - 6.99924909e+01, - -5.25579647e+01, - 1.22113405e+02, - -9.44512467e+01, - 5.12044438e+01, - 5.67943664e+01, - 2.82182825e+01, - 1.37977316e+02, - 2.76214987e+01, - 3.25225547e+01, - 9.94449635e+01, - 1.06583439e+00, - 1.60305500e+02, - -2.02545386e+01, - 1.79829059e+01, - 1.39091493e+00, - 1.08272907e+00, - -8.01475394e-03, - 0.00000000e+00, - 0.00000000e+00, - 0.00000000e+00, - 0.00000000e+00, - 0.00000000e+00, - 5.77442577e+00, - 0.00000000e+00, - 0.00000000e+00, - -5.77442577e+00, - -5.77442577e+00, - 0.00000000e+00, - 0.00000000e+00, - 0.00000000e+00, - 0.00000000e+00, - 0.00000000e+00, - 0.00000000e+00, - 0.00000000e+00, - 0.00000000e+00, - 0.00000000e+00, - 4.55892028e-01, - -3.34185220e+01, - -6.39694151e+01, - -6.39694159e+01, - -4.29224268e+01, - -6.39694151e+01, - -8.60071449e+01, - -6.39694154e+01, - -1.55319483e+02, - -1.10637551e+02, - -6.39694162e+01, - -6.37508767e+01, - -6.39694172e+01, - -7.45424412e+01, - -6.39694154e+01, - 3.14461097e-01, - 3.03173942e-04, - -2.32611507e+00, - -3.97937186e-02, - -5.52151996e-02, - -1.41159870e+01, - 2.78431182e-01, - -3.29524486e-01, - -5.78323906e+00, - -6.06395890e-04, - 1.71087794e+00, - -3.36459182e+02 -] diff --git a/tests/test_lp_grad.py b/tests/test_lp_grad.py index 6fd6267..9189519 100644 --- a/tests/test_lp_grad.py +++ b/tests/test_lp_grad.py @@ -17,7 +17,9 @@ jax.config.update("jax_enable_x64", True) SEED = 1234 -methionine_pldf_grad_file = importlib.resources.files(data) / "t.eqx" +methionine_pldf_grad_file = ( + importlib.resources.files(data) / "expected_methionine_gradient.json" +) obs_conc = jnp.array( [ From 97e5e7b613649122c18221324a81cf4021914215 Mon Sep 17 00:00:00 2001 From: teddygroves Date: Thu, 28 Nov 2024 16:49:29 +0100 Subject: [PATCH 15/27] Update gh action to use uv --- .github/workflows/run_tests.yml | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 23a3513..0a33cc1 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -22,21 +22,25 @@ jobs: - name: checkout code uses: actions/checkout@v2 + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + # Install a specific version of uv. + version: "0.5.5" + - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install pdm - run: pip install pdm + - name: Install the project + run: uv sync --all-extras --dev - name: pre-commit checks uses: pre-commit/action@v2.0.3 - name: Run tests - run: | - pdm install --dev - pdm run pytest tests --cov=src/enzax + run: uv run pytest tests - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 From 6b8b627fde26eb72fb1a57caa8f6936f2fed02bf Mon Sep 17 00:00:00 2001 From: teddygroves Date: Thu, 28 Nov 2024 17:12:20 +0100 Subject: [PATCH 16/27] Update grad test to be approximate --- tests/test_lp_grad.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/test_lp_grad.py b/tests/test_lp_grad.py index 9189519..a7aa861 100644 --- a/tests/test_lp_grad.py +++ b/tests/test_lp_grad.py @@ -3,6 +3,7 @@ from jax import numpy as jnp from jax.scipy.stats import norm from jax.flatten_util import ravel_pytree +from jaxtyping import Array, Scalar from enzax.examples import methionine from enzax.kinetic_model import RateEquationModel @@ -143,4 +144,19 @@ def joint_log_density(params, prior_mean, prior_sd, obs): _, grad_pytree_def = ravel_pytree(gradient) with open(methionine_pldf_grad_file, "r") as file: expected_gradient = grad_pytree_def(jnp.array(json.load(file))) - assert gradient == expected_gradient + for k in gradient.__dataclass_fields__.keys(): + obs = getattr(gradient, k) + exp = getattr(expected_gradient, k) + if isinstance(obs, Scalar): + assert jnp.isclose(obs, exp) + elif isinstance(obs, Array): + assert jnp.isclose(obs, exp).all() + elif isinstance(obs, dict): + for kk in obs.keys(): + if isinstance(obs[kk], list): + for o, e in zip(obs[kk], exp[kk]): + assert jnp.isclose(o, e).all() + elif isinstance(obs[kk], Scalar): + assert jnp.isclose(obs[kk], exp[kk]) + elif len(obs[kk]) > 0: + assert jnp.isclose(obs[kk], exp[kk]).all() From 0da0dd338da4593debcf359d086f5f71297083ed Mon Sep 17 00:00:00 2001 From: teddygroves Date: Fri, 29 Nov 2024 13:00:54 +0100 Subject: [PATCH 17/27] get mcmc demo to work --- scripts/mcmc_demo.py | 118 ++++++++++--------- scripts/steady_state_demo.py | 2 +- src/enzax/examples/linear.py | 16 ++- src/enzax/examples/methionine.py | 43 +++---- src/enzax/kinetic_model.py | 15 ++- src/enzax/mcmc.py | 15 ++- src/enzax/rate_equations/generalised_mwc.py | 6 +- src/enzax/rate_equations/michaelis_menten.py | 6 +- tests/data/expected_methionine_gradient.json | 2 +- tests/test_rate_equations.py | 6 +- 10 files changed, 121 insertions(+), 108 deletions(-) diff --git a/scripts/mcmc_demo.py b/scripts/mcmc_demo.py index 40d428d..cc8491c 100644 --- a/scripts/mcmc_demo.py +++ b/scripts/mcmc_demo.py @@ -7,14 +7,15 @@ import arviz as az import jax from jax import numpy as jnp +from jax.flatten_util import ravel_pytree +from jax.scipy.stats import norm +from jaxtyping import Array from enzax.examples import methionine +from enzax.kinetic_model import RateEquationModel, get_conc from enzax.mcmc import ( ObservationSet, - AllostericMichaelisMentenPriorSet, get_idata, - ind_prior_from_truth, - posterior_logdensity_amm, run_nuts, ) from enzax.steady_state import get_kinetic_model_steady_state @@ -24,59 +25,55 @@ jax.config.update("jax_enable_x64", True) +def joint_log_density(params, prior_mean, prior_sd, obs, guess): + # find the steady state concentration and flux + model = RateEquationModel(params, methionine.structure) + steady = get_kinetic_model_steady_state(model, guess) + conc = get_conc(steady, params.log_conc_unbalanced, methionine.structure) + flux = model.flux(steady) + # prior + flat_params, _ = ravel_pytree(params) + log_prior = norm.logpdf(flat_params, loc=prior_mean, scale=prior_sd).sum() + # likelihood + flat_log_enzyme, _ = ravel_pytree(params.log_enzyme) + log_likelihood = ( + norm.logpdf(jnp.log(obs.conc), jnp.log(conc), obs.conc_scale).sum() + + norm.logpdf( + jnp.log(obs.enzyme), flat_log_enzyme, obs.enzyme_scale + ).sum() + + norm.logpdf(obs.flux, flux, obs.flux_scale).sum() + ) + return log_prior + log_likelihood + + def main(): """Demonstrate How to make a Bayesian kinetic model with enzax.""" - structure = methionine.structure - rate_equations = methionine.rate_equations true_parameters = methionine.parameters true_model = methionine.model - default_state_guess = jnp.full((5,), 0.01) - true_states = get_kinetic_model_steady_state( - true_model, default_state_guess - ) - prior = AllostericMichaelisMentenPriorSet( - log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1), - log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1), - log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1), - dgf=( - ind_prior_from_truth(true_parameters.dgf, 0.1)[0], - jnp.diag( - jnp.square(ind_prior_from_truth(true_parameters.dgf, 0.1)[1]) - ), - ), - log_km=ind_prior_from_truth(true_parameters.log_km, 0.1), - log_conc_unbalanced=ind_prior_from_truth( - true_parameters.log_conc_unbalanced, 0.1 - ), - temperature=ind_prior_from_truth(true_parameters.temperature, 0.1), - log_ki=ind_prior_from_truth(true_parameters.log_ki, 0.1), - log_transfer_constant=ind_prior_from_truth( - true_parameters.log_transfer_constant, 0.1 - ), - log_dissociation_constant=ind_prior_from_truth( - true_parameters.log_dissociation_constant, 0.1 - ), - ) + default_guess = jnp.full((5,), 0.01) + true_steady = get_kinetic_model_steady_state(true_model, default_guess) # get true concentration - true_conc = jnp.zeros(methionine.structure.S.shape[0]) - true_conc = true_conc.at[methionine.structure.balanced_species].set( - true_states - ) - true_conc = true_conc.at[methionine.structure.unbalanced_species].set( - jnp.exp(true_parameters.log_conc_unbalanced) + true_conc = get_conc( + true_steady, + true_parameters.log_conc_unbalanced, + methionine.structure, ) # get true flux - true_flux = true_model.flux(true_states) + true_flux = true_model.flux(true_steady) # simulate observations error_conc = 0.03 error_flux = 0.05 error_enzyme = 0.03 key = jax.random.key(SEED) - obs_conc = jnp.exp(jnp.log(true_conc) + jax.random.normal(key) * error_conc) + true_log_enz_flat, _ = ravel_pytree(true_parameters.log_enzyme) + key_conc, key_enz, key_flux, key_nuts = jax.random.split(key, num=4) + obs_conc = jnp.exp( + jnp.log(true_conc) + jax.random.normal(key_conc) * error_conc + ) obs_enzyme = jnp.exp( - true_parameters.log_enzyme + jax.random.normal(key) * error_enzyme + true_log_enz_flat + jax.random.normal(key_enz) * error_enzyme ) - obs_flux = true_flux + jax.random.normal(key) * error_conc + obs_flux = true_flux + jax.random.normal(key_flux) * error_conc obs = ObservationSet( conc=obs_conc, flux=obs_flux, @@ -85,17 +82,19 @@ def main(): flux_scale=error_flux, enzyme_scale=error_enzyme, ) - pldf = functools.partial( - posterior_logdensity_amm, - obs=obs, - prior=prior, - structure=structure, - rate_equations=rate_equations, - guess=default_state_guess, + flat_true_params, _ = ravel_pytree(true_parameters) + posterior_log_density = jax.jit( + functools.partial( + joint_log_density, + obs=obs, + prior_mean=flat_true_params, + prior_sd=0.1, + guess=default_guess, + ) ) samples, info = run_nuts( - pldf, - key, + posterior_log_density, + key_nuts, true_parameters, num_warmup=200, num_samples=200, @@ -104,9 +103,7 @@ def main(): is_mass_matrix_diagonal=False, target_acceptance_rate=0.95, ) - idata = get_idata( - samples, info, coords=methionine.coords, dims=methionine.dims - ) + idata = get_idata(samples, info) print(az.summary(idata)) if jnp.any(info.is_divergent): n_divergent = info.is_divergent.sum() @@ -117,10 +114,15 @@ def main(): print("True parameter values vs posterior:") for param in true_parameters.__dataclass_fields__.keys(): true_val = getattr(true_parameters, param) - model_low = jnp.quantile(getattr(samples.position, param), 0.01, axis=0) - model_high = jnp.quantile( - getattr(samples.position, param), 0.99, axis=0 - ) + model_p = getattr(samples.position, param) + if isinstance(true_val, Array): + model_low = jnp.quantile(model_p, 0.01, axis=0) + model_high = jnp.quantile(model_p, 0.99, axis=0) + elif isinstance(true_val, dict): + model_low, model_high = ( + {k: jnp.quantile(v, q, axis=0) for k, v in model_p.items()} + for q in (0.01, 0.99) + ) print(f" {param}:") print(f" true value: {true_val}") print(f" posterior 1%: {model_low}") diff --git a/scripts/steady_state_demo.py b/scripts/steady_state_demo.py index 5feeda3..45cc7c7 100644 --- a/scripts/steady_state_demo.py +++ b/scripts/steady_state_demo.py @@ -54,7 +54,7 @@ def get_steady_state_from_params(parameters: PyTree): print(f"\tSteady state concentration: {conc_steady}") print(f"\tFlux: {flux}") print(f"\tSv: {sv}") - print(f"\tLog Km Jacobian: {jac.log_km}") + print(f"\tLog substrate Km Jacobian: {jac.log_substrate_km}") print(f"\tDgf Jacobian: {jac.dgf}") diff --git a/src/enzax/examples/linear.py b/src/enzax/examples/linear.py index 9f03df2..6a1b3b0 100644 --- a/src/enzax/examples/linear.py +++ b/src/enzax/examples/linear.py @@ -16,7 +16,8 @@ class ParameterDefinition(eqx.Module): - log_km: dict[int, Array] + log_substrate_km: dict[int, Array] + log_product_km: dict[int, Array] log_kcat: dict[int, Scalar] log_enzyme: dict[int, Array] log_ki: dict[int, Array] @@ -50,10 +51,15 @@ class ParameterDefinition(eqx.Module): rate_equations=rate_equations, ) parameters = ParameterDefinition( - log_km={ - 0: jnp.array([[0.1], [-0.2]]), - 1: jnp.array([[0.5], [0.0]]), - 2: jnp.array([[-1.0], [0.5]]), + log_substrate_km={ + 0: jnp.array([0.1]), + 1: jnp.array([0.5]), + 2: jnp.array([-1.0]), + }, + log_product_km={ + 0: jnp.array([-0.2]), + 1: jnp.array([0.0]), + 2: jnp.array([0.5]), }, log_kcat={0: jnp.array(-0.1), 1: jnp.array(0.0), 2: jnp.array(0.1)}, dgf=jnp.array([-3.0, -1.0]), diff --git a/src/enzax/examples/methionine.py b/src/enzax/examples/methionine.py index 408fe4f..9ca0674 100644 --- a/src/enzax/examples/methionine.py +++ b/src/enzax/examples/methionine.py @@ -23,7 +23,8 @@ class ParameterDefinition(eqx.Module): - log_km: dict[int, list[Array]] + log_substrate_km: dict[int, Array] + log_product_km: dict[int, Array] log_kcat: dict[int, Scalar] log_enzyme: dict[int, Array] log_ki: dict[int, Array] @@ -151,32 +152,20 @@ class ParameterDefinition(eqx.Module): -46.4737, # cyst-L ] ), - log_km={ - 1: [jnp.log(jnp.array([0.000106919, 0.00203015]))], # MAT1 met-L, atp - 2: [jnp.log(jnp.array([0.00113258, 0.00236759]))], # MAT3 met-L atp - 3: [jnp.log(jnp.array([9.37e-06]))], # amet METH-Gen - 4: [ - jnp.log(jnp.array([0.000520015, 0.00253545])) - ], # amet GNMT1, # gly GNMT1 - 5: [ - jnp.log(jnp.array([2.32e-05])), # ahcys AHC1 - jnp.log( - jnp.array([1.06e-05, 5.66e-06]) - ), # hcys-L AHC1, # adn AHC1 - ], - 6: [ - jnp.log(jnp.array([1.71e-06, 6.94e-05])) - ], # hcys-L MS1, # 5mthf MS1 - 7: [ - jnp.log(jnp.array([1.98e-05, 0.00845898])) - ], # hcys-L BHMT1, # glyb BHMT1 - 8: [ - jnp.log(jnp.array([4.24e-05, 2.83e-06])) - ], # hcys-L CBS1, # ser-L CBS1 - 9: [ - jnp.log(jnp.array([8.08e-05, 2.09e-05])) - ], # mlthf MTHFR1, # nadph MTHFR1 - 10: [jnp.log(jnp.array([4.39e-05]))], # met-L PROT1 + log_product_km={ + 5: jnp.log(jnp.array([1.06e-05, 5.66e-06])), # hcys-L AHC1, adn AHC1 + }, + log_substrate_km={ + 1: jnp.log(jnp.array([0.000106919, 0.00203015])), # MAT1 met-L, atp + 2: jnp.log(jnp.array([0.00113258, 0.00236759])), # MAT3 met-L atp + 3: jnp.log(jnp.array([9.37e-06])), # METH-Gen amet + 4: jnp.log(jnp.array([0.000520015, 0.00253545])), # GNMT1, amet, gly + 5: jnp.log(jnp.array([2.32e-05])), # ahcys AHC1 + 6: jnp.log(jnp.array([1.71e-06, 6.94e-05])), # MS1 hcys-L, 5mthf + 7: jnp.log(jnp.array([1.98e-05, 0.00845898])), # BHMT1 hcys-L, glyb + 8: jnp.log(jnp.array([4.24e-05, 2.83e-06])), # CBS1 hcys-L, ser-L + 9: jnp.log(jnp.array([8.08e-05, 2.09e-05])), # MTHFR1 mlthf, nadph + 10: jnp.log(jnp.array([4.39e-05])), # PROT1 met-L }, temperature=jnp.array(298.15), log_ki={ diff --git a/src/enzax/kinetic_model.py b/src/enzax/kinetic_model.py index 8478f1f..6430f87 100644 --- a/src/enzax/kinetic_model.py +++ b/src/enzax/kinetic_model.py @@ -13,6 +13,13 @@ from enzax.rate_equation import RateEquation +def get_conc(balanced, log_unbalanced, structure): + conc = jnp.zeros(structure.S.shape[0]) + conc = conc.at[structure.balanced_species_ix].set(balanced) + conc = conc.at[structure.unbalanced_species_ix].set(jnp.exp(log_unbalanced)) + return conc + + @jaxtyped(typechecker=typechecked) @register_pytree_node_class class KineticModelStructure: @@ -143,10 +150,10 @@ def flux( :return: a one dimensional array of (possibly negative) floats representing reaction fluxes. Has same size as number of columns of self.structure.S. """ # Noqa: E501 - conc = jnp.zeros(self.structure.S.shape[0]) - conc = conc.at[self.structure.balanced_species_ix].set(conc_balanced) - conc = conc.at[self.structure.unbalanced_species_ix].set( - jnp.exp(self.parameters.log_conc_unbalanced) + conc = get_conc( + conc_balanced, + self.parameters.log_conc_unbalanced, + self.structure, ) flux_list = [] for i, rate_equation in enumerate(self.structure.rate_equations): diff --git a/src/enzax/mcmc.py b/src/enzax/mcmc.py index 20e0c81..eb17ea0 100644 --- a/src/enzax/mcmc.py +++ b/src/enzax/mcmc.py @@ -106,10 +106,17 @@ def ind_prior_from_truth(truth: Float[Array, " _"], sd: ScalarLike): def get_idata(samples, info, coords=None, dims=None) -> az.InferenceData: """Get an arviz InferenceData from a blackjax NUTS output.""" - sample_dict = { - k: jnp.expand_dims(getattr(samples.position, k), 0) - for k in samples.position.__dataclass_fields__.keys() - } + if coords is None: + coords = dict() + sample_dict = dict() + for k in samples.position.__dataclass_fields__.keys(): + samples_k = getattr(samples.position, k) + if isinstance(samples_k, Array): + sample_dict[k] = jnp.expand_dims(samples_k, 0) + elif isinstance(samples_k, dict): + sample_dict[k] = jnp.expand_dims( + jnp.concat([v.T for v in samples_k.values()]).T, 0 + ) posterior = az.convert_to_inference_data( sample_dict, group="posterior", diff --git a/src/enzax/rate_equations/generalised_mwc.py b/src/enzax/rate_equations/generalised_mwc.py index dbab431..ea9ee60 100644 --- a/src/enzax/rate_equations/generalised_mwc.py +++ b/src/enzax/rate_equations/generalised_mwc.py @@ -41,7 +41,7 @@ def get_allosteric_irreversible_michaelis_menten_input( kcat=jnp.exp(parameters.log_kcat[rxn_ix]), enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), ix_substrate=ix_substrate, - substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]), + substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]), substrate_stoichiometry=Sj[ix_substrate], ix_ki_species=ci_ix, ki=jnp.exp(parameters.log_ki[rxn_ix]), @@ -66,8 +66,8 @@ def get_allosteric_reversible_michaelis_menten_input( return AllostericReversibleMichaelisMentenInput( kcat=jnp.exp(parameters.log_kcat[rxn_ix]), enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), - substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]), - product_kms=jnp.exp(parameters.log_km[rxn_ix][1]), + substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]), + product_kms=jnp.exp(parameters.log_product_km[rxn_ix]), ki=jnp.exp(parameters.log_ki[rxn_ix]), dgf=parameters.dgf[species_to_dgf_ix][ix_reactant], temperature=parameters.temperature, diff --git a/src/enzax/rate_equations/michaelis_menten.py b/src/enzax/rate_equations/michaelis_menten.py index 51fbb44..df44865 100644 --- a/src/enzax/rate_equations/michaelis_menten.py +++ b/src/enzax/rate_equations/michaelis_menten.py @@ -30,7 +30,7 @@ def get_irreversible_michaelis_menten_input( kcat=jnp.exp(parameters.log_kcat[rxn_ix]), enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), ix_substrate=ix_substrate, - substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]), + substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]), substrate_stoichiometry=Sj[ix_substrate], ix_ki_species=ci_ix, ki=jnp.exp(parameters.log_ki[rxn_ix]), @@ -70,8 +70,8 @@ def get_reversible_michaelis_menten_input( return ReversibleMichaelisMentenInput( kcat=jnp.exp(parameters.log_kcat[rxn_ix]), enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), - substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]), - product_kms=jnp.exp(parameters.log_km[rxn_ix][1]), + substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]), + product_kms=jnp.exp(parameters.log_product_km[rxn_ix]), ki=jnp.exp(parameters.log_ki[rxn_ix]), dgf=parameters.dgf[species_to_dgf_ix][ix_reactant], temperature=parameters.temperature, diff --git a/tests/data/expected_methionine_gradient.json b/tests/data/expected_methionine_gradient.json index e8491a3..0271271 100644 --- a/tests/data/expected_methionine_gradient.json +++ b/tests/data/expected_methionine_gradient.json @@ -1 +1 @@ -[-27.029305319714087, -23.107933540053693, -10.831961241573868, -7.442956276058797, -28.82230903289454, -20.04100693766347, -21.046992055692126, -51.008806318357124, 12.720942038333074, 7.723972171145526, -22.87688514741915, -63.96941462805278, 44.62125359725783, 46.668133361083704, -112.7484826050269, -0.21853936474188224, 91.03560765315373, 10.573024526064392, -26.068355084866106, 37.17973577495537, 11.212172526568366, 120.58711257251609, 19.857181701644674, 51.31272574643639, 69.99249085550201, -52.55796473716283, 122.11340482806465, -94.45124667621965, 51.204443757792795, 56.794366392943715, 28.218282451377277, 137.9773162976142, 27.62149868641839, 32.52255474993412, 99.44496349751932, 1.065834388161285, 160.30549966124383, -20.25453860618623, 17.98290589428491, 1.390914925131441, 1.0827290721714595, -0.008014753935453442, 0.0, 0.0, 0.0, 0.0, 0.0, 5.774425772898505, 0.0, 0.0, -5.774425772898505, -5.774425772898505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.455892028402676, -33.418521957445726, -63.96941505648969, -63.96941589328896, -42.922426803111094, -63.96941505649067, -86.00714492908733, -63.9694153507284, -155.31948281663819, -110.63755110207094, -63.96941623344747, -63.75087668807969, -63.969417159576764, -74.54244120146943, -63.969415350730365, 0.3144610971331224, 0.00030317394151436246, -2.3261150730556386, -0.03979371864763597, -0.055215199560989005, -14.115986977104384, 0.2784311817245577, -0.3295244858681964, -5.783239063037292, -0.0006063958903968876, 1.710877943445662, -336.4591824327092] +[-27.029305319714087, -23.107933540053693, -10.831961241573868, -7.442956276058797, -28.82230903289454, -20.04100693766347, -21.046992055692126, -51.008806318357124, -22.87688514741915, -63.96941462805278, 44.62125359725783, 46.668133361083704, -112.7484826050269, -0.21853936474188224, 91.03560765315373, 10.573024526064392, -26.068355084866106, 12.720942038333074, 7.723972171145526, 37.17973577495537, 11.212172526568366, 120.58711257251609, 19.857181701644674, 51.31272574643639, 69.99249085550201, -52.55796473716283, 122.11340482806465, -94.45124667621965, 51.204443757792795, 56.794366392943715, 28.218282451377277, 137.9773162976142, 27.62149868641839, 32.52255474993412, 99.44496349751932, 1.065834388161285, 160.30549966124383, -20.25453860618623, 17.98290589428491, 1.390914925131441, 1.0827290721714595, -0.008014753935453442, 0.0, 0.0, 0.0, 0.0, 0.0, 5.774425772898505, 0.0, 0.0, -5.774425772898505, -5.774425772898505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.455892028402676, -33.418521957445726, -63.96941505648969, -63.96941589328896, -42.922426803111094, -63.96941505649067, -86.00714492908733, -63.9694153507284, -155.31948281663819, -110.63755110207094, -63.96941623344747, -63.75087668807969, -63.969417159576764, -74.54244120146943, -63.969415350730365, 0.3144610971331224, 0.00030317394151436246, -2.3261150730556386, -0.03979371864763597, -0.055215199560989005, -14.115986977104384, 0.2784311817245577, -0.3295244858681964, -5.783239063037292, -0.0006063958903968876, 1.710877943445662, -336.4591824327092] diff --git a/tests/test_rate_equations.py b/tests/test_rate_equations.py index e14100e..3b2961f 100644 --- a/tests/test_rate_equations.py +++ b/tests/test_rate_equations.py @@ -13,7 +13,8 @@ class ExampleParameterSet(eqx.Module): - log_km: dict[int, Array] + log_substrate_km: dict[int, Array] + log_product_km: dict[int, Array] log_kcat: dict[int, Scalar] log_enzyme: dict[int, Array] log_ki: dict[int, Array] @@ -28,7 +29,8 @@ class ExampleParameterSet(eqx.Module): EXAMPLE_S = np.array([[-1], [1], [0]], dtype=np.float64) EXAMPLE_CONC = jnp.array([0.5, 0.2, 0.1]) EXAMPLE_PARAMETERS = ExampleParameterSet( - log_km={0: jnp.array([[0.1], [-0.2]])}, + log_substrate_km={0: jnp.array([0.1])}, + log_product_km={0: jnp.array([-0.2])}, log_kcat={0: jnp.array(-0.1)}, dgf=jnp.array([-3.0, 1.0]), log_ki={0: jnp.array([1.0])}, From a0e8c46657ce9169e3a2782f0ced8b02962d5c8f Mon Sep 17 00:00:00 2001 From: teddygroves Date: Fri, 29 Nov 2024 13:02:23 +0100 Subject: [PATCH 18/27] get steady state demo to work --- scripts/steady_state_demo.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/scripts/steady_state_demo.py b/scripts/steady_state_demo.py index 45cc7c7..f030099 100644 --- a/scripts/steady_state_demo.py +++ b/scripts/steady_state_demo.py @@ -33,9 +33,7 @@ def get_steady_state_from_params(parameters: PyTree): This lets us get the Jacobian wrt (just) the parameters. """ - _model = RateEquationModel( - parameters, model.structure, model.rate_equations - ) + _model = RateEquationModel(parameters, model.structure) return get_kinetic_model_steady_state(_model, guess) # solve once for jitting From c3dffffce596e7da2e2894684a4f3b8171ede34e Mon Sep 17 00:00:00 2001 From: teddygroves Date: Fri, 29 Nov 2024 13:08:09 +0100 Subject: [PATCH 19/27] update readme --- README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/README.md b/README.md index fdc612f..8bbec17 100644 --- a/README.md +++ b/README.md @@ -43,9 +43,7 @@ model = methionine.model def get_steady_state_from_params(parameters: PyTree): """Get the steady state with a one-argument non-pure function.""" - _model = RateEquationModel( - parameters, model.structure, model.rate_equations - ) + _model = RateEquationModel(parameters, model.structure) return get_kinetic_model_steady_state(_model, guess) jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters) From 841ae08238c6091d1241478439dd1720ed94fc5d Mon Sep 17 00:00:00 2001 From: teddygroves Date: Fri, 29 Nov 2024 13:24:02 +0100 Subject: [PATCH 20/27] update getting started docs --- docs/getting_started.md | 160 ++++++++++++++++------------------------ 1 file changed, 64 insertions(+), 96 deletions(-) diff --git a/docs/getting_started.md b/docs/getting_started.md index 1872d2b..db15a51 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -18,9 +18,14 @@ Enzax provides building blocks for you to construct a wide range of differentiab Here we write a model describing a simple linear pathway with two state variables, two boundary species and three reactions. -First we import some enzax classes: +First we import some enzax classes, as well as [equinox](https://github.com/patrick-kidger/equinox) and both JAX and standard versions of numpy: ```python +import equinox + +from jax import numpy as jnp +import numpy as np + from enzax.kinetic_model import ( KineticModelStructure, RateEquationModel, @@ -32,113 +37,78 @@ from enzax.rate_equations import ( ``` -Next we specify our model's structure by providing a stoichiometric matrix and saying which of its rows represent state variables (aka "balanced species") and which represent boundary or "unbalanced" species: +Next we specify our model's structure by providing a stoichiometric matrix and saying which of its rows represent state variables (aka "balanced species") and which reactions have which rate equations. ```python -structure = KineticModelStructure( - S=jnp.array( - [[-1, 0, 0], [1, -1, 0], [0, 1, -1], [0, 0, 1]], dtype=jnp.float64 +S = np.array([[-1, 0, 0], [1, -1, 0], [0, 1, -1], [0, 0, 1]]) +reactions = ["r1", "r2", "r3"] +species = ["m1e", "m1c", "m2c", "m2e"] +balanced_species = ["m1c", "m2c"] +rate_equations = [ + AllostericReversibleMichaelisMenten( + ix_allosteric_activators=np.array([2]), subunits=1 + ), + AllostericReversibleMichaelisMenten( + ix_allosteric_inhibitors=np.array([1]), ix_ki_species=np.array([1]) ), - balanced_species=jnp.array([1, 2]), - unbalanced_species=jnp.array([0, 3]), + ReversibleMichaelisMenten(water_stoichiometry=0.0), +] +structure = KineticModelStructure( + S=S, + species=species, + balanced_species=balanced_species, + rate_equations=rate_equations, ) ``` -Next we provide some kinetic parameter values: +Next we define what a set of kinetic parameters looks like for our problem, and provide a set of parameters matching this definition: ```python -from enzax.parameters import AllostericMichaelisMentenParameters - -parameters = AllostericMichaelisMentenParameters( - log_kcat=jnp.array([-0.1, 0.0, 0.1]), - log_enzyme=jnp.log(jnp.array([0.3, 0.2, 0.1])), - dgf=jnp.array([-3, -1.0]), - log_km=jnp.array([0.1, -0.2, 0.5, 0.0, -1.0, 0.5]), - log_ki=jnp.array([1.0]), - log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])), +class ParameterDefinition(eqx.Module): + log_substrate_km: dict[int, Array] + log_product_km: dict[int, Array] + log_kcat: dict[int, Scalar] + log_enzyme: dict[int, Array] + log_ki: dict[int, Array] + dgf: Array + temperature: Scalar + log_conc_unbalanced: Array + log_dc_inhibitor: dict[int, Array] + log_dc_activator: dict[int, Array] + log_tc: dict[int, Array] + +parameters = ParameterDefinition( + log_substrate_km={ + 0: jnp.array([0.1]), + 1: jnp.array([0.5]), + 2: jnp.array([-1.0]), + }, + log_product_km={ + 0: jnp.array([-0.2]), + 1: jnp.array([0.0]), + 2: jnp.array([0.5]), + }, + log_kcat={0: jnp.array(-0.1), 1: jnp.array(0.0), 2: jnp.array(0.1)}, + dgf=jnp.array([-3.0, -1.0]), + log_ki={0: jnp.array([]), 1: jnp.array([1.0]), 2: jnp.array([])}, temperature=jnp.array(310.0), - log_transfer_constant=jnp.array([-0.2, 0.3]), - log_dissociation_constant=jnp.array([-0.1, 0.2]), - log_drain=jnp.array([]), -) -``` -Now we can use enzax's rate laws to specify how each reaction behaves: - -```python -from enzax.rate_equations import ( - AllostericReversibleMichaelisMenten, - ReversibleMichaelisMenten, -) - -r0 = AllostericReversibleMichaelisMenten( - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 0], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - ix_product=jnp.array([1], dtype=jnp.int16), - ix_reactants=jnp.array([0, 1], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - tc_ix=0, - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - ix_dc_activation=jnp.array([0], dtype=jnp.int16), - species_activation=jnp.array([2], dtype=jnp.int16), - species_inhibition=jnp.array([], dtype=jnp.int16), - subunits=1, -) -r1 = AllostericReversibleMichaelisMenten( - kcat_ix=1, - enzyme_ix=1, - km_ix=jnp.array([2, 3], dtype=jnp.int16), - ki_ix=jnp.array([0]), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([1]), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([1], dtype=jnp.int16), - ix_product=jnp.array([2], dtype=jnp.int16), - ix_reactants=jnp.array([1, 2], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - tc_ix=1, - ix_dc_inhibition=jnp.array([1], dtype=jnp.int16), - ix_dc_activation=jnp.array([], dtype=jnp.int16), - species_activation=jnp.array([], dtype=jnp.int16), - species_inhibition=jnp.array([1], dtype=jnp.int16), - subunits=1, -) -r2 = ReversibleMichaelisMenten( - kcat_ix=2, - enzyme_ix=2, - km_ix=jnp.array([4, 5], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - ix_substrate=jnp.array([2], dtype=jnp.int16), - ix_product=jnp.array([3], dtype=jnp.int16), - ix_reactants=jnp.array([2, 3], dtype=jnp.int16), - reactant_to_dgf=jnp.array([1, 1], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), + log_enzyme={ + 0: jnp.log(jnp.array(0.3)), + 1: jnp.log(jnp.array(0.2)), + 2: jnp.log(jnp.array(0.1)), + }, + log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])), + log_tc={0: jnp.array(-0.2), 1: jnp.array(0.3)}, + log_dc_activator={0: jnp.array([-0.1]), 1: jnp.array([])}, + log_dc_inhibitor={0: jnp.array([]), 1: jnp.array([0.2])}, ) ``` +Note that the parameters use `jnp` whereas the structure uses `np`. This is because we want JAX to trace the parameters, whereas the structure should be static. Read more about this [here](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#static-vs-traced-operations). Now we can declare our model: ```python -model = RateEquationModel(structure, parameters, [r0, r1, r2]) +model = RateEquationModel(structure, parameters) ``` To test out the model, we can see if it returns some fluxes and state variable rates when provided a set of balanced species concentrations: @@ -181,9 +151,7 @@ model = methionine.model def get_steady_state_from_params(parameters: PyTree): """Get the steady state with a one-argument non-pure function.""" - _model = RateEquationModel( - parameters, model.structure, model.rate_equations - ) + _model = RateEquationModel(parameters, model.structure) return get_kinetic_model_steady_state(_model, guess) jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters) From 296befa71eb09ef27fd0b6cb0755400d851f9f56 Mon Sep 17 00:00:00 2001 From: teddygroves Date: Fri, 29 Nov 2024 14:12:21 +0100 Subject: [PATCH 21/27] Fix typo in python example --- docs/getting_started.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/getting_started.md b/docs/getting_started.md index db15a51..65f0708 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -21,7 +21,7 @@ Here we write a model describing a simple linear pathway with two state variable First we import some enzax classes, as well as [equinox](https://github.com/patrick-kidger/equinox) and both JAX and standard versions of numpy: ```python -import equinox +import equinox as eqx from jax import numpy as jnp import numpy as np From 62884a1293f09398fc12932dc3eba3134f7c88dc Mon Sep 17 00:00:00 2001 From: teddygroves Date: Wed, 11 Dec 2024 15:22:21 +0100 Subject: [PATCH 22/27] define rate equation models using reaction ids --- docs/getting_started.md | 36 ++-- src/enzax/examples/linear.py | 56 +++--- src/enzax/examples/methionine.py | 176 +++++++++---------- src/enzax/kinetic_model.py | 52 ++++-- src/enzax/rate_equation.py | 4 +- src/enzax/rate_equations/drain.py | 6 +- src/enzax/rate_equations/generalised_mwc.py | 72 ++++---- src/enzax/rate_equations/michaelis_menten.py | 60 +++---- tests/data/expected_methionine_gradient.json | 2 +- tests/test_rate_equations.py | 48 ++--- 10 files changed, 268 insertions(+), 244 deletions(-) diff --git a/docs/getting_started.md b/docs/getting_started.md index 65f0708..d4149b9 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -40,7 +40,11 @@ from enzax.rate_equations import ( Next we specify our model's structure by providing a stoichiometric matrix and saying which of its rows represent state variables (aka "balanced species") and which reactions have which rate equations. ```python -S = np.array([[-1, 0, 0], [1, -1, 0], [0, 1, -1], [0, 0, 1]]) +stoichiometry = { + "r1": {"m1e": -1, "m1c": 1}, + "r2": {"m1c": -1, "m2c": 1}, + "r3": {"m2c": -1, "m2e": 1}, +} reactions = ["r1", "r2", "r3"] species = ["m1e", "m1c", "m2c", "m2e"] balanced_species = ["m1c", "m2c"] @@ -54,7 +58,7 @@ rate_equations = [ ReversibleMichaelisMenten(water_stoichiometry=0.0), ] structure = KineticModelStructure( - S=S, + stoichiometry=stoichiometry, species=species, balanced_species=balanced_species, rate_equations=rate_equations, @@ -79,28 +83,28 @@ class ParameterDefinition(eqx.Module): parameters = ParameterDefinition( log_substrate_km={ - 0: jnp.array([0.1]), - 1: jnp.array([0.5]), - 2: jnp.array([-1.0]), + "r1": jnp.array([0.1]), + "r2": jnp.array([0.5]), + "r3": jnp.array([-1.0]), }, log_product_km={ - 0: jnp.array([-0.2]), - 1: jnp.array([0.0]), - 2: jnp.array([0.5]), + "r1": jnp.array([-0.2]), + "r2": jnp.array([0.0]), + "r3": jnp.array([0.5]), }, - log_kcat={0: jnp.array(-0.1), 1: jnp.array(0.0), 2: jnp.array(0.1)}, + log_kcat={"r1": jnp.array(-0.1), "r2": jnp.array(0.0), "r3": jnp.array(0.1)}, dgf=jnp.array([-3.0, -1.0]), - log_ki={0: jnp.array([]), 1: jnp.array([1.0]), 2: jnp.array([])}, + log_ki={"r1": jnp.array([]), "r2": jnp.array([1.0]), "r3": jnp.array([])}, temperature=jnp.array(310.0), log_enzyme={ - 0: jnp.log(jnp.array(0.3)), - 1: jnp.log(jnp.array(0.2)), - 2: jnp.log(jnp.array(0.1)), + "r1": jnp.log(jnp.array(0.3)), + "r2": jnp.log(jnp.array(0.2)), + "r3": jnp.log(jnp.array(0.1)), }, log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])), - log_tc={0: jnp.array(-0.2), 1: jnp.array(0.3)}, - log_dc_activator={0: jnp.array([-0.1]), 1: jnp.array([])}, - log_dc_inhibitor={0: jnp.array([]), 1: jnp.array([0.2])}, + log_tc={"r1": jnp.array(-0.2), "r2": jnp.array(0.3)}, + log_dc_activator={"r1": jnp.array([-0.1]), "r2": jnp.array([])}, + log_dc_inhibitor={"r1": jnp.array([]), "r2": jnp.array([0.2])}, ) ``` Note that the parameters use `jnp` whereas the structure uses `np`. This is because we want JAX to trace the parameters, whereas the structure should be static. Read more about this [here](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#static-vs-traced-operations). diff --git a/src/enzax/examples/linear.py b/src/enzax/examples/linear.py index 6a1b3b0..45cb847 100644 --- a/src/enzax/examples/linear.py +++ b/src/enzax/examples/linear.py @@ -16,20 +16,24 @@ class ParameterDefinition(eqx.Module): - log_substrate_km: dict[int, Array] - log_product_km: dict[int, Array] - log_kcat: dict[int, Scalar] - log_enzyme: dict[int, Array] - log_ki: dict[int, Array] + log_substrate_km: dict[str, Array] + log_product_km: dict[str, Array] + log_kcat: dict[str, Scalar] + log_enzyme: dict[str, Array] + log_ki: dict[str, Array] dgf: Array temperature: Scalar log_conc_unbalanced: Array - log_dc_inhibitor: dict[int, Array] - log_dc_activator: dict[int, Array] - log_tc: dict[int, Array] + log_dc_inhibitor: dict[str, Array] + log_dc_activator: dict[str, Array] + log_tc: dict[str, Array] -S = np.array([[-1, 0, 0], [1, -1, 0], [0, 1, -1], [0, 0, 1]], dtype=np.float64) +stoichiometry = { + "r1": {"m1e": -1, "m1c": 1}, + "r2": {"m1c": -1, "m2c": 1}, + "r3": {"m2c": -1, "m2e": 1}, +} reactions = ["r1", "r2", "r3"] species = ["m1e", "m1c", "m2c", "m2e"] balanced_species = ["m1c", "m2c"] @@ -43,7 +47,7 @@ class ParameterDefinition(eqx.Module): ReversibleMichaelisMenten(water_stoichiometry=0.0), ] structure = RateEquationKineticModelStructure( - S=S, + stoichiometry=stoichiometry, species=species, reactions=reactions, balanced_species=balanced_species, @@ -52,28 +56,32 @@ class ParameterDefinition(eqx.Module): ) parameters = ParameterDefinition( log_substrate_km={ - 0: jnp.array([0.1]), - 1: jnp.array([0.5]), - 2: jnp.array([-1.0]), + "r1": jnp.array([0.1]), + "r2": jnp.array([0.5]), + "r3": jnp.array([-1.0]), }, log_product_km={ - 0: jnp.array([-0.2]), - 1: jnp.array([0.0]), - 2: jnp.array([0.5]), + "r1": jnp.array([-0.2]), + "r2": jnp.array([0.0]), + "r3": jnp.array([0.5]), + }, + log_kcat={ + "r1": jnp.array(-0.1), + "r2": jnp.array(0.0), + "r3": jnp.array(0.1), }, - log_kcat={0: jnp.array(-0.1), 1: jnp.array(0.0), 2: jnp.array(0.1)}, dgf=jnp.array([-3.0, -1.0]), - log_ki={0: jnp.array([]), 1: jnp.array([1.0]), 2: jnp.array([])}, + log_ki={"r1": jnp.array([]), "r2": jnp.array([1.0]), "r3": jnp.array([])}, temperature=jnp.array(310.0), log_enzyme={ - 0: jnp.log(jnp.array(0.3)), - 1: jnp.log(jnp.array(0.2)), - 2: jnp.log(jnp.array(0.1)), + "r1": jnp.log(jnp.array(0.3)), + "r2": jnp.log(jnp.array(0.2)), + "r3": jnp.log(jnp.array(0.1)), }, log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])), - log_tc={0: jnp.array(-0.2), 1: jnp.array(0.3)}, - log_dc_activator={0: jnp.array([-0.1]), 1: jnp.array([])}, - log_dc_inhibitor={0: jnp.array([]), 1: jnp.array([0.2])}, + log_tc={"r1": jnp.array(-0.2), "r2": jnp.array(0.3)}, + log_dc_activator={"r1": jnp.array([-0.1]), "r2": jnp.array([])}, + log_dc_inhibitor={"r1": jnp.array([]), "r2": jnp.array([0.2])}, ) true_model = RateEquationModel(structure=structure, parameters=parameters) steady_state = jnp.array([0.43658744, 0.12695706]) diff --git a/src/enzax/examples/methionine.py b/src/enzax/examples/methionine.py index 9ca0674..7191b02 100644 --- a/src/enzax/examples/methionine.py +++ b/src/enzax/examples/methionine.py @@ -23,45 +23,33 @@ class ParameterDefinition(eqx.Module): - log_substrate_km: dict[int, Array] - log_product_km: dict[int, Array] - log_kcat: dict[int, Scalar] - log_enzyme: dict[int, Array] - log_ki: dict[int, Array] + log_substrate_km: dict[str, Array] + log_product_km: dict[str, Array] + log_kcat: dict[str, Scalar] + log_enzyme: dict[str, Array] + log_ki: dict[str, Array] dgf: Array temperature: Scalar log_conc_unbalanced: Array - log_dc_inhibitor: dict[int, Array] - log_dc_activator: dict[int, Array] - log_tc: dict[int, Array] - log_drain: dict[int, Scalar] + log_dc_inhibitor: dict[str, Array] + log_dc_activator: dict[str, Array] + log_tc: dict[str, Array] + log_drain: dict[str, Scalar] -S = np.array( - [ - [1, -1, -1, 0, 0, 0, 1, 1, 0, 0, -1], # met-L b - [0, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0], # atp - [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], # pi - [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], # ppi - [0, 1, 1, -1, -1, 0, 0, 0, 0, 0, 0], # amet b - [0, 0, 0, 1, 1, -1, 0, 0, 0, 0, 0], # ahcys b - [0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0], # gly - [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # sarcs - [0, 0, 0, 0, 0, 1, -1, -1, -1, 0, 0], # hcys-L b - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], # adn - [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], # thf - [0, 0, 0, 0, 0, 0, -1, 0, 0, 1, 0], # 5mthf b - [0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0], # mlthf - [0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0], # glyb - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # dmgly - [0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0], # ser-L - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], # nadp - [0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0], # nadph - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], # cyst-L - ], - dtype=np.float64, -) -reactions = [] +stoichiometry = { + "the_drain": {"met-L": 1}, + "MAT1": {"met-L": -1, "atp": -1, "pi": 1, "ppi": 1, "amet": 1}, + "MAT3": {"met-L": -1, "atp": -1, "pi": 1, "ppi": 1, "amet": 1}, + "METH-Gen": {"amet": -1, "ahcys": 1}, + "GNMT1": {"amet": -1, "ahcys": 1, "gly": -1, "sarcs": 1}, + "AHC1": {"ahcys": -1, "hcys-L": 1, "adn": 1}, + "MS1": {"hcys-L": -1, "thf": 1, "met-L": 1, "5mthf": -1}, + "BHMT1": {"hcys-L": -1, "glyb": -1, "met-L": 1, "dmgly": 1}, + "CBS1": {"hcys-L": -1, "ser-L": -1, "cyst-L": 1}, + "MTHFR1": {"5mthf": 1, "mlthf": -1, "nadp": 1, "nadph": -1}, + "PROT1": {"met-L": -1}, +} species = [ "met-L", "atp", @@ -105,30 +93,30 @@ class ParameterDefinition(eqx.Module): ] parameters = ParameterDefinition( log_kcat={ - 1: jnp.log(jnp.array(7.89577)), # MAT1 - 2: jnp.log(jnp.array(19.9215)), # MAT3 - 3: jnp.log(jnp.array(1.15777)), # METH-Gen - 4: jnp.log(jnp.array(10.5307)), # GNMT1 - 5: jnp.log(jnp.array(234.284)), # AHC1 - 6: jnp.log(jnp.array(1.77471)), # MS1 - 7: jnp.log(jnp.array(13.7676)), # BHMT1 - 8: jnp.log(jnp.array(7.02307)), # CBS1 - 9: jnp.log(jnp.array(3.1654)), # MTHFR1 - 10: jnp.log(jnp.array(0.264744)), # PROT1 + "MAT1": jnp.log(jnp.array(7.89577)), # MAT1 + "MAT3": jnp.log(jnp.array(19.9215)), # MAT3 + "METH-Gen": jnp.log(jnp.array(1.15777)), # METH-Gen + "GNMT1": jnp.log(jnp.array(10.5307)), # GNMT1 + "AHC1": jnp.log(jnp.array(234.284)), # AHC1 + "MS1": jnp.log(jnp.array(1.77471)), # MS1 + "BHMT1": jnp.log(jnp.array(13.7676)), # BHMT1 + "CBS1": jnp.log(jnp.array(7.02307)), # CBS1 + "MTHFR1": jnp.log(jnp.array(3.1654)), # MTHFR1 + "PROT1": jnp.log(jnp.array(0.264744)), # PROT1 }, log_enzyme={ - 1: jnp.log(jnp.array(0.000961712)), # MAT1 - 2: jnp.log(jnp.array(0.00098812)), # MAT3 - 3: jnp.log(jnp.array(0.00103396)), # METH-Gen - 4: jnp.log(jnp.array(0.000983692)), # GNMT1 - 5: jnp.log(jnp.array(0.000977878)), # AHC1 - 6: jnp.log(jnp.array(0.00105094)), # MS1 - 7: jnp.log(jnp.array(0.000996603)), # BHMT1 - 8: jnp.log(jnp.array(0.00134056)), # CBS1 - 9: jnp.log(jnp.array(0.0010054)), # MTHFR1 - 10: jnp.log(jnp.array(0.000995525)), # PROT1 + "MAT1": jnp.log(jnp.array(0.000961712)), # MAT1 + "MAT3": jnp.log(jnp.array(0.00098812)), # MAT3 + "METH-Gen": jnp.log(jnp.array(0.00103396)), # METH-Gen + "GNMT1": jnp.log(jnp.array(0.000983692)), # GNMT1 + "AHC1": jnp.log(jnp.array(0.000977878)), # AHC1 + "MS1": jnp.log(jnp.array(0.00105094)), # MS1 + "BHMT1": jnp.log(jnp.array(0.000996603)), # BHMT1 + "CBS1": jnp.log(jnp.array(0.00134056)), # CBS1 + "MTHFR1": jnp.log(jnp.array(0.0010054)), # MTHFR1 + "PROT1": jnp.log(jnp.array(0.000995525)), # PROT1 }, - log_drain={0: jnp.log(jnp.array(0.000850127))}, + log_drain={"the_drain": jnp.log(jnp.array(0.000850127))}, dgf=jnp.array( [ 160.953, # met-L @@ -153,32 +141,42 @@ class ParameterDefinition(eqx.Module): ] ), log_product_km={ - 5: jnp.log(jnp.array([1.06e-05, 5.66e-06])), # hcys-L AHC1, adn AHC1 + "AHC1": jnp.log( + jnp.array([1.06e-05, 5.66e-06]) + ), # hcys-L AHC1, adn AHC1 }, log_substrate_km={ - 1: jnp.log(jnp.array([0.000106919, 0.00203015])), # MAT1 met-L, atp - 2: jnp.log(jnp.array([0.00113258, 0.00236759])), # MAT3 met-L atp - 3: jnp.log(jnp.array([9.37e-06])), # METH-Gen amet - 4: jnp.log(jnp.array([0.000520015, 0.00253545])), # GNMT1, amet, gly - 5: jnp.log(jnp.array([2.32e-05])), # ahcys AHC1 - 6: jnp.log(jnp.array([1.71e-06, 6.94e-05])), # MS1 hcys-L, 5mthf - 7: jnp.log(jnp.array([1.98e-05, 0.00845898])), # BHMT1 hcys-L, glyb - 8: jnp.log(jnp.array([4.24e-05, 2.83e-06])), # CBS1 hcys-L, ser-L - 9: jnp.log(jnp.array([8.08e-05, 2.09e-05])), # MTHFR1 mlthf, nadph - 10: jnp.log(jnp.array([4.39e-05])), # PROT1 met-L + "MAT1": jnp.log( + jnp.array([0.000106919, 0.00203015]) + ), # MAT1 met-L, atp + "MAT3": jnp.log(jnp.array([0.00113258, 0.00236759])), # MAT3 met-L atp + "METH-Gen": jnp.log(jnp.array([9.37e-06])), # METH-Gen amet + "GNMT1": jnp.log( + jnp.array([0.000520015, 0.00253545]) + ), # GNMT1, amet, gly + "AHC1": jnp.log(jnp.array([2.32e-05])), # ahcys AHC1 + "MS1": jnp.log(jnp.array([1.71e-06, 6.94e-05])), # MS1 hcys-L, 5mthf + "BHMT1": jnp.log( + jnp.array([1.98e-05, 0.00845898]) + ), # BHMT1 hcys-L, glyb + "CBS1": jnp.log(jnp.array([4.24e-05, 2.83e-06])), # CBS1 hcys-L, ser-L + "MTHFR1": jnp.log( + jnp.array([8.08e-05, 2.09e-05]) + ), # MTHFR1 mlthf, nadph + "PROT1": jnp.log(jnp.array([4.39e-05])), # PROT1 met-L }, temperature=jnp.array(298.15), log_ki={ - 1: jnp.array([jnp.log(0.000346704)]), # MAT1 - 2: jnp.array([]), - 3: jnp.array([jnp.log(5.56e-06)]), # METH-Gen - 4: jnp.array([jnp.log(5.31e-05)]), # GNMT1 - 5: jnp.array([]), - 6: jnp.array([]), - 7: jnp.array([]), - 8: jnp.array([]), - 9: jnp.array([]), - 10: jnp.array([]), + "MAT1": jnp.array([jnp.log(0.000346704)]), # MAT1 + "MAT3": jnp.array([]), + "METH-Gen": jnp.array([jnp.log(5.56e-06)]), # METH-Gen + "GNMT1": jnp.array([jnp.log(5.31e-05)]), # GNMT1 + "AHC1": jnp.array([]), + "MS1": jnp.array([]), + "BHMT1": jnp.array([]), + "CBS1": jnp.array([]), + "MTHFR1": jnp.array([]), + "PROT1": jnp.array([]), }, log_conc_unbalanced=jnp.log( jnp.array( @@ -202,29 +200,29 @@ class ParameterDefinition(eqx.Module): ) ), log_tc={ - 2: jnp.array(jnp.log(0.107657)), # MAT3 - 4: jnp.array(jnp.log(131.207)), # GNMT - 8: jnp.array(jnp.log(1.03452)), # CBS - 9: jnp.array(jnp.log(0.392035)), # MTHFR + "MAT3": jnp.array(jnp.log(0.107657)), # MAT3 + "GNMT1": jnp.array(jnp.log(131.207)), # GNMT + "CBS1": jnp.array(jnp.log(1.03452)), # CBS + "MTHFR1": jnp.array(jnp.log(0.392035)), # MTHFR }, log_dc_activator={ - 2: jnp.log( + "MAT3": jnp.log( jnp.array([0.00059999, 0.000316641]) ), # met-L MAT3, # amet MAT3 - 4: jnp.log(jnp.array([1.98e-05])), # amet GNMT1 - 8: jnp.array([]), # CBS1 - 9: jnp.log(jnp.array([2.45e-06])), # ahcys MTHFR1, + "GNMT1": jnp.log(jnp.array([1.98e-05])), # amet GNMT1 + "CBS1": jnp.array([]), # CBS1 + "MTHFR1": jnp.log(jnp.array([2.45e-06])), # ahcys MTHFR1, }, log_dc_inhibitor={ - 2: jnp.array([]), # MAT3 - 4: jnp.log(jnp.array([0.000228576])), # mlthf GNMT1 - 8: jnp.log(jnp.array([9.30e-05])), # amet CBS1 - 9: jnp.log(jnp.array([1.46e-05])), # amet MTHFR1 + "MAT3": jnp.array([]), # MAT3 + "GNMT1": jnp.log(jnp.array([0.000228576])), # mlthf GNMT1 + "CBS1": jnp.log(jnp.array([9.30e-05])), # amet CBS1 + "MTHFR1": jnp.log(jnp.array([1.46e-05])), # amet MTHFR1 }, ) structure = RateEquationKineticModelStructure( - S=S, + stoichiometry=stoichiometry, species=species, reactions=reactions, balanced_species=balanced_species, diff --git a/src/enzax/kinetic_model.py b/src/enzax/kinetic_model.py index 0bc642a..0ac431c 100644 --- a/src/enzax/kinetic_model.py +++ b/src/enzax/kinetic_model.py @@ -14,6 +14,10 @@ from enzax.rate_equation import RateEquation +def get_ix_from_list(s: str, list_of_strings: list[str]): + return next(i for i, si in enumerate(list_of_strings) if si == s) + + def get_conc(balanced, log_unbalanced, structure): conc = jnp.zeros(structure.S.shape[0]) conc = conc.at[structure.balanced_species_ix].set(balanced) @@ -26,42 +30,53 @@ def get_conc(balanced, log_unbalanced, structure): class KineticModelStructure: """Structural information about a kinetic model.""" - S: NDArray[np.float64] + stoichiometry: dict[str, dict[str, float]] species: list[str] reactions: list[str] balanced_species: list[str] + unbalanced_species: list[str] species_to_dgf_ix: NDArray[np.int16] balanced_species_ix: NDArray[np.int16] unbalanced_species_ix: NDArray[np.int16] + S: NDArray[np.float64] def __init__( self, - S, + stoichiometry, species, reactions, balanced_species, species_to_dgf_ix=None, ): - self.S = S + self.stoichiometry = stoichiometry self.species = species self.reactions = reactions self.balanced_species = balanced_species + self.unbalanced_species = [ + s for s in species if s not in balanced_species + ] if species_to_dgf_ix is None: self.species_to_dgf_ix = np.arange(len(species), dtype=np.int16) else: self.species_to_dgf_ix = species_to_dgf_ix self.balanced_species_ix = np.array( - [i for i, s in enumerate(species) if s in balanced_species], + [get_ix_from_list(s, species) for s in self.balanced_species], dtype=np.int16, ) self.unbalanced_species_ix = np.array( - [i for i, s in enumerate(species) if s not in balanced_species], + [get_ix_from_list(s, species) for s in self.unbalanced_species], dtype=np.int16, ) + S = np.zeros(shape=(len(species), len(reactions))) + for ix_reaction, reaction in enumerate(reactions): + for species_i, coeff in stoichiometry[reaction].items(): + ix_species = get_ix_from_list(species_i, species) + S[ix_species, ix_reaction] = coeff + self.S = S.astype(np.float64) def tree_flatten(self): children = ( - self.S, + self.stoichiometry, self.species, self.reactions, self.balanced_species, @@ -80,7 +95,7 @@ class RateEquationKineticModelStructure(KineticModelStructure): def __init__( self, - S, + stoichiometry, species, reactions, balanced_species, @@ -88,13 +103,17 @@ def __init__( species_to_dgf_ix=None, ): super().__init__( - S, species, reactions, balanced_species, species_to_dgf_ix + stoichiometry, + species, + reactions, + balanced_species, + species_to_dgf_ix, ) self.rate_equations = rate_equations def tree_flatten(self): children = ( - self.S, + self.stoichiometry, self.species, self.reactions, self.balanced_species, @@ -157,18 +176,17 @@ def flux( self.structure, ) flux_list = [] - for i, rate_equation in enumerate(self.structure.rate_equations): + for reaction_ix, (reaction_id, rate_equation) in enumerate( + zip(self.structure.reactions, self.structure.rate_equations) + ): ipt = rate_equation.get_input( - self.parameters, - i, - self.structure.S, - self.structure.species_to_dgf_ix, + parameters=self.parameters, + reaction_id=reaction_id, + reaction_stoichiometry=self.structure.S[:, reaction_ix], + species_to_dgf_ix=self.structure.species_to_dgf_ix, ) flux_list.append(rate_equation(conc, ipt)) return jnp.array(flux_list) - t = [f(conc, self.parameters) for f in self.rate_equations] - out = jnp.array(t) - return out class KineticModelSbml(KineticModel): diff --git a/src/enzax/rate_equation.py b/src/enzax/rate_equation.py index 4c11631..883e9a9 100644 --- a/src/enzax/rate_equation.py +++ b/src/enzax/rate_equation.py @@ -21,8 +21,8 @@ class RateEquation(Module, ABC): def get_input( self, parameters: PyTree, - rxn_ix: int, - S: NDArray[np.float64], + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], species_to_dgf_ix: NDArray[np.int16], ) -> PyTree: ... diff --git a/src/enzax/rate_equations/drain.py b/src/enzax/rate_equations/drain.py index 680a108..17833c6 100644 --- a/src/enzax/rate_equations/drain.py +++ b/src/enzax/rate_equations/drain.py @@ -19,11 +19,11 @@ class Drain(RateEquation): def get_input( self, parameters: PyTree, - rxn_ix: int, - S: NDArray[np.float64], + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], species_to_dgf_ix: NDArray[np.int16], ): - return DrainInput(abs_v=jnp.exp(parameters.log_drain[rxn_ix])) + return DrainInput(abs_v=jnp.exp(parameters.log_drain[reaction_id])) def __call__(self, conc: ConcArray, drain_input: PyTree) -> Scalar: """Get the flux of a drain reaction.""" diff --git a/src/enzax/rate_equations/generalised_mwc.py b/src/enzax/rate_equations/generalised_mwc.py index ea9ee60..17b6edb 100644 --- a/src/enzax/rate_equations/generalised_mwc.py +++ b/src/enzax/rate_equations/generalised_mwc.py @@ -30,58 +30,56 @@ class AllostericReversibleMichaelisMentenInput(ReversibleMichaelisMentenInput): def get_allosteric_irreversible_michaelis_menten_input( parameters: PyTree, - rxn_ix: int, - S: NDArray[np.float64], + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], species_to_dgf_ix: NDArray[np.int16], ci_ix: NDArray[np.int16], ) -> AllostericIrreversibleMichaelisMentenInput: - Sj = S[:, rxn_ix] - ix_substrate = np.argwhere(Sj < 0.0).flatten() + ix_substrate = np.argwhere(reaction_stoichiometry < 0.0).flatten() return AllostericIrreversibleMichaelisMentenInput( - kcat=jnp.exp(parameters.log_kcat[rxn_ix]), - enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), + kcat=jnp.exp(parameters.log_kcat[reaction_id]), + enzyme=jnp.exp(parameters.log_enzyme[reaction_id]), ix_substrate=ix_substrate, - substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]), - substrate_stoichiometry=Sj[ix_substrate], + substrate_kms=jnp.exp(parameters.log_substrate_km[reaction_id]), + substrate_stoichiometry=reaction_stoichiometry[ix_substrate], ix_ki_species=ci_ix, - ki=jnp.exp(parameters.log_ki[rxn_ix]), - dc_inhibitor=jnp.exp(parameters.log_dc_inhibitor[rxn_ix]), - dc_activator=jnp.exp(parameters.log_dc_activator[rxn_ix]), - tc=jnp.exp(parameters.log_tc[rxn_ix]), + ki=jnp.exp(parameters.log_ki[reaction_id]), + dc_inhibitor=jnp.exp(parameters.log_dc_inhibitor[reaction_id]), + dc_activator=jnp.exp(parameters.log_dc_activator[reaction_id]), + tc=jnp.exp(parameters.log_tc[reaction_id]), ) def get_allosteric_reversible_michaelis_menten_input( parameters: PyTree, - rxn_ix: int, - S: NDArray[np.float64], + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], species_to_dgf_ix: NDArray[np.int16], ci_ix: NDArray[np.int16], water_stoichiometry: float, ) -> AllostericReversibleMichaelisMentenInput: - Sj = S[:, rxn_ix] - ix_reactant = np.argwhere(Sj != 0.0).flatten() - ix_substrate = np.argwhere(Sj < 0.0).flatten() - ix_product = np.argwhere(Sj > 0.0).flatten() + ix_reactant = np.argwhere(reaction_stoichiometry != 0.0).flatten() + ix_substrate = np.argwhere(reaction_stoichiometry < 0.0).flatten() + ix_product = np.argwhere(reaction_stoichiometry > 0.0).flatten() return AllostericReversibleMichaelisMentenInput( - kcat=jnp.exp(parameters.log_kcat[rxn_ix]), - enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), - substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]), - product_kms=jnp.exp(parameters.log_product_km[rxn_ix]), - ki=jnp.exp(parameters.log_ki[rxn_ix]), + kcat=jnp.exp(parameters.log_kcat[reaction_id]), + enzyme=jnp.exp(parameters.log_enzyme[reaction_id]), + substrate_kms=jnp.exp(parameters.log_substrate_km[reaction_id]), + product_kms=jnp.exp(parameters.log_product_km[reaction_id]), + ki=jnp.exp(parameters.log_ki[reaction_id]), dgf=parameters.dgf[species_to_dgf_ix][ix_reactant], temperature=parameters.temperature, ix_ki_species=ci_ix, ix_reactant=ix_reactant, ix_substrate=ix_substrate, ix_product=ix_product, - reactant_stoichiometry=Sj[ix_reactant], - substrate_stoichiometry=Sj[ix_substrate], - product_stoichiometry=Sj[ix_product], + reactant_stoichiometry=reaction_stoichiometry[ix_reactant], + substrate_stoichiometry=reaction_stoichiometry[ix_substrate], + product_stoichiometry=reaction_stoichiometry[ix_product], water_stoichiometry=water_stoichiometry, - dc_inhibitor=jnp.exp(parameters.log_dc_inhibitor[rxn_ix]), - dc_activator=jnp.exp(parameters.log_dc_activator[rxn_ix]), - tc=jnp.exp(parameters.log_tc[rxn_ix]), + dc_inhibitor=jnp.exp(parameters.log_dc_inhibitor[reaction_id]), + dc_activator=jnp.exp(parameters.log_dc_activator[reaction_id]), + tc=jnp.exp(parameters.log_tc[reaction_id]), ) @@ -119,14 +117,14 @@ class AllostericIrreversibleMichaelisMenten(IrreversibleMichaelisMenten): def get_input( self, parameters: PyTree, - rxn_ix: int, - S: NDArray[np.float64], + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], species_to_dgf_ix: NDArray[np.int16], ): return get_allosteric_irreversible_michaelis_menten_input( parameters=parameters, - rxn_ix=rxn_ix, - S=S, + reaction_id=reaction_id, + reaction_stoichiometry=reaction_stoichiometry, species_to_dgf_ix=species_to_dgf_ix, ci_ix=self.ix_ki_species, ) @@ -171,14 +169,14 @@ class AllostericReversibleMichaelisMenten(ReversibleMichaelisMenten): def get_input( self, parameters: PyTree, - rxn_ix: int, - S: NDArray[np.float64], + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], species_to_dgf_ix: NDArray[np.int16], ): return get_allosteric_reversible_michaelis_menten_input( parameters=parameters, - rxn_ix=rxn_ix, - S=S, + reaction_id=reaction_id, + reaction_stoichiometry=reaction_stoichiometry, species_to_dgf_ix=species_to_dgf_ix, ci_ix=self.ix_ki_species, water_stoichiometry=self.water_stoichiometry, diff --git a/src/enzax/rate_equations/michaelis_menten.py b/src/enzax/rate_equations/michaelis_menten.py index df44865..39d96d7 100644 --- a/src/enzax/rate_equations/michaelis_menten.py +++ b/src/enzax/rate_equations/michaelis_menten.py @@ -19,21 +19,20 @@ class IrreversibleMichaelisMentenInput(eqx.Module): def get_irreversible_michaelis_menten_input( parameters: PyTree, - rxn_ix: int, - S: NDArray[np.float64], + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], species_to_dgf_ix: NDArray[np.int16], ci_ix: NDArray[np.int16], ) -> IrreversibleMichaelisMentenInput: - Sj = S[:, rxn_ix] - ix_substrate = np.argwhere(Sj < 0.0).flatten() + ix_substrate = np.argwhere(reaction_stoichiometry < 0.0).flatten() return IrreversibleMichaelisMentenInput( - kcat=jnp.exp(parameters.log_kcat[rxn_ix]), - enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), + kcat=jnp.exp(parameters.log_kcat[reaction_id]), + enzyme=jnp.exp(parameters.log_enzyme[reaction_id]), ix_substrate=ix_substrate, - substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]), - substrate_stoichiometry=Sj[ix_substrate], + substrate_kms=jnp.exp(parameters.log_substrate_km[reaction_id]), + substrate_stoichiometry=reaction_stoichiometry[ix_substrate], ix_ki_species=ci_ix, - ki=jnp.exp(parameters.log_ki[rxn_ix]), + ki=jnp.exp(parameters.log_ki[reaction_id]), ) @@ -57,31 +56,30 @@ class ReversibleMichaelisMentenInput(eqx.Module): def get_reversible_michaelis_menten_input( parameters: PyTree, - rxn_ix: int, - S: NDArray[np.float64], + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], species_to_dgf_ix: NDArray[np.int16], ci_ix: NDArray[np.int16], water_stoichiometry: float, ) -> ReversibleMichaelisMentenInput: - Sj = S[:, rxn_ix] - ix_reactant = np.argwhere(Sj != 0.0).flatten() - ix_substrate = np.argwhere(Sj < 0.0).flatten() - ix_product = np.argwhere(Sj > 0.0).flatten() + ix_reactant = np.argwhere(reaction_stoichiometry != 0.0).flatten() + ix_substrate = np.argwhere(reaction_stoichiometry < 0.0).flatten() + ix_product = np.argwhere(reaction_stoichiometry > 0.0).flatten() return ReversibleMichaelisMentenInput( - kcat=jnp.exp(parameters.log_kcat[rxn_ix]), - enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), - substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]), - product_kms=jnp.exp(parameters.log_product_km[rxn_ix]), - ki=jnp.exp(parameters.log_ki[rxn_ix]), + kcat=jnp.exp(parameters.log_kcat[reaction_id]), + enzyme=jnp.exp(parameters.log_enzyme[reaction_id]), + substrate_kms=jnp.exp(parameters.log_substrate_km[reaction_id]), + product_kms=jnp.exp(parameters.log_product_km[reaction_id]), + ki=jnp.exp(parameters.log_ki[reaction_id]), dgf=parameters.dgf[species_to_dgf_ix][ix_reactant], temperature=parameters.temperature, ix_ki_species=ci_ix, ix_reactant=ix_reactant, ix_substrate=ix_substrate, ix_product=ix_product, - reactant_stoichiometry=Sj[ix_reactant], - substrate_stoichiometry=Sj[ix_substrate], - product_stoichiometry=Sj[ix_product], + reactant_stoichiometry=reaction_stoichiometry[ix_reactant], + substrate_stoichiometry=reaction_stoichiometry[ix_substrate], + product_stoichiometry=reaction_stoichiometry[ix_product], water_stoichiometry=water_stoichiometry, ) @@ -171,14 +169,14 @@ class IrreversibleMichaelisMenten(RateEquation): def get_input( self, parameters: PyTree, - rxn_ix: int, - S: NDArray[np.float64], + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], species_to_dgf_ix: NDArray[np.int16], ): return get_irreversible_michaelis_menten_input( parameters=parameters, - rxn_ix=rxn_ix, - S=S, + reaction_id=reaction_id, + reaction_stoichiometry=reaction_stoichiometry, species_to_dgf_ix=species_to_dgf_ix, ci_ix=self.ix_ki_species, ) @@ -214,14 +212,14 @@ class ReversibleMichaelisMenten(RateEquation): def get_input( self, parameters: PyTree, - rxn_ix: int, - S: NDArray[np.float64], + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], species_to_dgf_ix: NDArray[np.int16], ): return get_reversible_michaelis_menten_input( parameters=parameters, - rxn_ix=rxn_ix, - S=S, + reaction_id=reaction_id, + reaction_stoichiometry=reaction_stoichiometry, species_to_dgf_ix=species_to_dgf_ix, ci_ix=self.ix_ki_species, water_stoichiometry=self.water_stoichiometry, diff --git a/tests/data/expected_methionine_gradient.json b/tests/data/expected_methionine_gradient.json index 0271271..83abf48 100644 --- a/tests/data/expected_methionine_gradient.json +++ b/tests/data/expected_methionine_gradient.json @@ -1 +1 @@ -[-27.029305319714087, -23.107933540053693, -10.831961241573868, -7.442956276058797, -28.82230903289454, -20.04100693766347, -21.046992055692126, -51.008806318357124, -22.87688514741915, -63.96941462805278, 44.62125359725783, 46.668133361083704, -112.7484826050269, -0.21853936474188224, 91.03560765315373, 10.573024526064392, -26.068355084866106, 12.720942038333074, 7.723972171145526, 37.17973577495537, 11.212172526568366, 120.58711257251609, 19.857181701644674, 51.31272574643639, 69.99249085550201, -52.55796473716283, 122.11340482806465, -94.45124667621965, 51.204443757792795, 56.794366392943715, 28.218282451377277, 137.9773162976142, 27.62149868641839, 32.52255474993412, 99.44496349751932, 1.065834388161285, 160.30549966124383, -20.25453860618623, 17.98290589428491, 1.390914925131441, 1.0827290721714595, -0.008014753935453442, 0.0, 0.0, 0.0, 0.0, 0.0, 5.774425772898505, 0.0, 0.0, -5.774425772898505, -5.774425772898505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.455892028402676, -33.418521957445726, -63.96941505648969, -63.96941589328896, -42.922426803111094, -63.96941505649067, -86.00714492908733, -63.9694153507284, -155.31948281663819, -110.63755110207094, -63.96941623344747, -63.75087668807969, -63.969417159576764, -74.54244120146943, -63.969415350730365, 0.3144610971331224, 0.00030317394151436246, -2.3261150730556386, -0.03979371864763597, -0.055215199560989005, -14.115986977104384, 0.2784311817245577, -0.3295244858681964, -5.783239063037292, -0.0006063958903968876, 1.710877943445662, -336.4591824327092] +[-51.008806318357124, 44.62125359725783, 46.668133361083704, -112.7484826050269, -0.21853936474188224, -20.04100693766347, -21.046992055692126, -27.029305319714087, -23.107933540053693, -10.831961241573868, -7.442956276058797, -28.82230903289454, -22.87688514741915, -63.96941462805278, 91.03560765315373, 10.573024526064392, -26.068355084866106, 12.720942038333074, 7.723972171145526, 51.31272574643639, -52.55796473716283, 122.11340482806465, 19.857181701644674, 37.17973577495537, 11.212172526568366, 120.58711257251609, 69.99249085550201, -94.45124667621965, 51.204443757792795, 52.405258216632795, -45.05001949589764, -149.04233199339694, 27.62149868641839, 36.91166292624504, 109.14923727725667, 133.32327445006953, 378.63173605788427, -20.25453860618623, 17.98290589428491, -0.008014753935453442, 1.390914925131441, 1.0827290721714595, 0.0, 0.0, 0.0, 0.0, 0.0, 5.774425772898505, 0.0, 0.0, -5.774425772898505, -5.774425772898505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.455892028402676, -33.418521957445726, -63.96941505648969, -63.96941589328896, -42.922426803111094, -63.96941505649067, -86.00714492908733, -63.9694153507284, -155.31948281663819, -110.63755110207094, -63.96941623344747, -63.75087668807969, -63.969417159576764, -74.54244120146943, -63.969415350730365, 0.00030317394151436246, 0.3144610971331224, -2.3261150730556386, -14.115986977104384, -0.03979371864763597, -0.055215199560989005, 0.2784311817245577, -0.0006063958903968876, -5.783239063037292, -0.3295244858681964, 1.710877943445662, -336.4591824327092] diff --git a/tests/test_rate_equations.py b/tests/test_rate_equations.py index 3b2961f..0f414c4 100644 --- a/tests/test_rate_equations.py +++ b/tests/test_rate_equations.py @@ -13,33 +13,33 @@ class ExampleParameterSet(eqx.Module): - log_substrate_km: dict[int, Array] - log_product_km: dict[int, Array] - log_kcat: dict[int, Scalar] - log_enzyme: dict[int, Array] - log_ki: dict[int, Array] + log_substrate_km: dict[str, Array] + log_product_km: dict[str, Array] + log_kcat: dict[str, Scalar] + log_enzyme: dict[str, Array] + log_ki: dict[str, Array] dgf: Array temperature: Scalar log_conc_unbalanced: Array - log_dc_inhibitor: dict[int, Array] - log_dc_activator: dict[int, Array] - log_tc: dict[int, Array] + log_dc_inhibitor: dict[str, Array] + log_dc_activator: dict[str, Array] + log_tc: dict[str, Array] EXAMPLE_S = np.array([[-1], [1], [0]], dtype=np.float64) EXAMPLE_CONC = jnp.array([0.5, 0.2, 0.1]) EXAMPLE_PARAMETERS = ExampleParameterSet( - log_substrate_km={0: jnp.array([0.1])}, - log_product_km={0: jnp.array([-0.2])}, - log_kcat={0: jnp.array(-0.1)}, + log_substrate_km={"r1": jnp.array([0.1])}, + log_product_km={"r1": jnp.array([-0.2])}, + log_kcat={"r1": jnp.array(-0.1)}, dgf=jnp.array([-3.0, 1.0]), - log_ki={0: jnp.array([1.0])}, + log_ki={"r1": jnp.array([1.0])}, temperature=jnp.array(310.0), - log_enzyme={0: jnp.log(jnp.array(0.3))}, + log_enzyme={"r1": jnp.log(jnp.array(0.3))}, log_conc_unbalanced=jnp.array([]), - log_tc={0: jnp.array(-0.2)}, - log_dc_activator={0: jnp.array([-0.1])}, - log_dc_inhibitor={0: jnp.array([0.2])}, + log_tc={"r1": jnp.array(-0.2)}, + log_dc_activator={"r1": jnp.array([-0.1])}, + log_dc_inhibitor={"r1": jnp.array([0.2])}, ) EXAMPLE_SPECIES_TO_DGF_IX = np.array([0, 0, 1, 1]) @@ -49,8 +49,8 @@ def test_irreversible_michaelis_menten(): f = IrreversibleMichaelisMenten() f_input = f.get_input( parameters=EXAMPLE_PARAMETERS, - rxn_ix=0, - S=EXAMPLE_S, + reaction_id="r1", + reaction_stoichiometry=EXAMPLE_S[:, 0], species_to_dgf_ix=EXAMPLE_SPECIES_TO_DGF_IX, ) rate = f(EXAMPLE_CONC, f_input) @@ -65,8 +65,8 @@ def test_reversible_michaelis_menten(): ) f_input = f.get_input( parameters=EXAMPLE_PARAMETERS, - rxn_ix=0, - S=EXAMPLE_S, + reaction_id="r1", + reaction_stoichiometry=EXAMPLE_S[:, 0], species_to_dgf_ix=EXAMPLE_SPECIES_TO_DGF_IX, ) rate = f(EXAMPLE_CONC, f_input) @@ -82,8 +82,8 @@ def test_allosteric_irreversible_michaelis_menten(): ) f_input = f.get_input( parameters=EXAMPLE_PARAMETERS, - rxn_ix=0, - S=EXAMPLE_S, + reaction_id="r1", + reaction_stoichiometry=EXAMPLE_S[:, 0], species_to_dgf_ix=EXAMPLE_SPECIES_TO_DGF_IX, ) rate = f(EXAMPLE_CONC, f_input) @@ -99,8 +99,8 @@ def test_allosteric_reversible_michaelis_menten(): ) f_input = f.get_input( parameters=EXAMPLE_PARAMETERS, - rxn_ix=0, - S=EXAMPLE_S, + reaction_id="r1", + reaction_stoichiometry=EXAMPLE_S[:, 0], species_to_dgf_ix=EXAMPLE_SPECIES_TO_DGF_IX, ) rate = f(EXAMPLE_CONC, f_input) From fc15ee65fd5a2f3dceb8ed9e847c74813239493a Mon Sep 17 00:00:00 2001 From: teddygroves Date: Wed, 11 Dec 2024 15:23:44 +0100 Subject: [PATCH 23/27] Fix sbml model errors --- src/enzax/kinetic_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/enzax/kinetic_model.py b/src/enzax/kinetic_model.py index 0ac431c..45de6ca 100644 --- a/src/enzax/kinetic_model.py +++ b/src/enzax/kinetic_model.py @@ -190,7 +190,6 @@ def flux( class KineticModelSbml(KineticModel): - balanced_ids: PyTree sym_module: Any def flux( @@ -199,7 +198,8 @@ def flux( ) -> Float[Array, " n"]: flux = jnp.array( self.sym_module( - **self.parameters, **dict(zip(self.balanced_ids, conc_balanced)) + **self.parameters, + **dict(zip(self.structure.balanced_ids, conc_balanced)), ) ) return flux From 01cd8cea5bb1d1ee44b34a5690c27e79f08048c4 Mon Sep 17 00:00:00 2001 From: teddygroves Date: Wed, 11 Dec 2024 15:32:12 +0100 Subject: [PATCH 24/27] Remove use of deprecated KeyArray --- src/enzax/mcmc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/enzax/mcmc.py b/src/enzax/mcmc.py index eb17ea0..66936b1 100644 --- a/src/enzax/mcmc.py +++ b/src/enzax/mcmc.py @@ -7,7 +7,6 @@ import blackjax import chex import jax -from jax._src.random import KeyArray import jax.numpy as jnp from jax.scipy.stats import norm, multivariate_normal from jaxtyping import Array, Float, PyTree, ScalarLike @@ -64,7 +63,7 @@ def one_step(state, rng_key): def run_nuts( logdensity_fn: Callable, - rng_key: KeyArray, + rng_key: Array, init_parameters: PyTree, num_warmup: int, num_samples: int, From e406db326159e739e34fd2edde93112ddadae1dd Mon Sep 17 00:00:00 2001 From: AlberteSloth <158040192+AlberteSloth@users.noreply.github.com> Date: Wed, 11 Dec 2024 15:32:39 +0100 Subject: [PATCH 25/27] Update KineticModelSbml Changed balanced_ids to balanced_species in KineticModelSbml --- src/enzax/kinetic_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzax/kinetic_model.py b/src/enzax/kinetic_model.py index 45de6ca..3cf26b2 100644 --- a/src/enzax/kinetic_model.py +++ b/src/enzax/kinetic_model.py @@ -199,7 +199,7 @@ def flux( flux = jnp.array( self.sym_module( **self.parameters, - **dict(zip(self.structure.balanced_ids, conc_balanced)), + **dict(zip(self.structure.balanced_species, conc_balanced)), ) ) return flux From 5495772e6566c5fe60e4764dd12ccd6e9f6ca53e Mon Sep 17 00:00:00 2001 From: AlberteSloth <158040192+AlberteSloth@users.noreply.github.com> Date: Fri, 13 Dec 2024 14:21:13 +0100 Subject: [PATCH 26/27] Updated sbml_demo --- scripts/sbml_demo.py | 77 +++++++++++++++++++++++--------------------- 1 file changed, 40 insertions(+), 37 deletions(-) diff --git a/scripts/sbml_demo.py b/scripts/sbml_demo.py index 8307423..1bbe8af 100644 --- a/scripts/sbml_demo.py +++ b/scripts/sbml_demo.py @@ -11,54 +11,57 @@ reactions_sympy = sbml.sbml_to_sympy(model_sbml) sym_module = sbml.sympy_to_enzax(reactions_sympy) -parameters_all = [ - ({p.getId(): p.getValue() for p in r.getKineticLaw().getListOfParameters()}) - for r in model_sbml.getListOfReactions() +species = [s.getId() for s in model_sbml.getListOfSpecies()] + +balanced_species = [ + b.getId() + for b in model_sbml.getListOfSpecies() + if not b.boundary_condition ] -parameters = {} -for i in parameters_all: - parameters.update(i) -compartments = {c.getId(): c.volume for c in model_sbml.getListOfCompartments()} +reactions = [reaction.getId() + for reaction in model_sbml.getListOfReactions() +] -species = [s.getId() for s in model_sbml.getListOfSpecies()] +stoichiometry = {reaction.getId(): {r.getSpecies(): -r.getStoichiometry(), p.getSpecies(): p.getStoichiometry()} + for reaction in model_sbml.getListOfReactions() + for r in reaction.getListOfReactants() + for p in reaction.getListOfProducts() +} + +structure = KineticModelStructure( + stoichiometry=stoichiometry, + species=species, + reactions=reactions, + balanced_species=balanced_species +) -balanced_species_dict = {} -unbalanced_species_dict = {} -for i in model_sbml.getListOfSpecies(): - if not i.boundary_condition: - balanced_species_dict.update({i.getId(): i.getInitialConcentration()}) - else: - unbalanced_species_dict.update({i.getId(): i.getInitialConcentration()}) +parameters_local = { + p.getId(): p.getValue() + for r in model_sbml.getListOfReactions() + for p in r.getKineticLaw().getListOfParameters() +} -balanced_ix = jnp.array([species.index(b) for b in balanced_species_dict]) -unbalanced_ix = jnp.array([species.index(u) for u in unbalanced_species_dict]) +parameters_global = { + p.getId(): p.getValue() + for p in model_sbml.getListOfParameters() + if p.constant +} -para = {**parameters, **compartments, **unbalanced_species_dict} +compartments = {c.getId(): c.volume + for c in model_sbml.getListOfCompartments() +} -stoichmatrix = jnp.zeros( - (model_sbml.getNumSpecies(), model_sbml.getNumReactions()), - dtype=jnp.float64, -) -i = 0 -for reaction in model_sbml.getListOfReactions(): - for r in reaction.getListOfReactants(): - stoichmatrix = stoichmatrix.at[species.index(r.getSpecies()), i].set( - -int(r.getStoichiometry()) - ) - for p in reaction.getListOfProducts(): - stoichmatrix = stoichmatrix.at[species.index(p.getSpecies()), i].set( - int(p.getStoichiometry()) - ) - i += 1 +unbalanced_species = { + u.getId(): u.getInitialConcentration() + for u in model_sbml.getListOfSpecies() + if u.boundary_condition +} -structure = KineticModelStructure( - stoichmatrix, jnp.array(balanced_ix), jnp.array(unbalanced_ix) -) +para = {**parameters_local, **parameters_global, **compartments, **unbalanced_species} kinmodel_sbml = KineticModelSbml( parameters=para, - balanced_ids=balanced_species_dict, structure=structure, sym_module=sym_module, ) From 0214412d605882b3662799ddcdf14fd570acbd0b Mon Sep 17 00:00:00 2001 From: teddygroves Date: Fri, 13 Dec 2024 14:27:39 +0100 Subject: [PATCH 27/27] fix pre-commit --- scripts/sbml_demo.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/scripts/sbml_demo.py b/scripts/sbml_demo.py index 1bbe8af..570fa83 100644 --- a/scripts/sbml_demo.py +++ b/scripts/sbml_demo.py @@ -14,17 +14,17 @@ species = [s.getId() for s in model_sbml.getListOfSpecies()] balanced_species = [ - b.getId() - for b in model_sbml.getListOfSpecies() - if not b.boundary_condition + b.getId() for b in model_sbml.getListOfSpecies() if not b.boundary_condition ] -reactions = [reaction.getId() - for reaction in model_sbml.getListOfReactions() -] +reactions = [reaction.getId() for reaction in model_sbml.getListOfReactions()] -stoichiometry = {reaction.getId(): {r.getSpecies(): -r.getStoichiometry(), p.getSpecies(): p.getStoichiometry()} - for reaction in model_sbml.getListOfReactions() +stoichiometry = { + reaction.getId(): { + r.getSpecies(): -r.getStoichiometry(), + p.getSpecies(): p.getStoichiometry(), + } + for reaction in model_sbml.getListOfReactions() for r in reaction.getListOfReactants() for p in reaction.getListOfProducts() } @@ -33,7 +33,7 @@ stoichiometry=stoichiometry, species=species, reactions=reactions, - balanced_species=balanced_species + balanced_species=balanced_species, ) parameters_local = { @@ -48,17 +48,20 @@ if p.constant } -compartments = {c.getId(): c.volume - for c in model_sbml.getListOfCompartments() -} +compartments = {c.getId(): c.volume for c in model_sbml.getListOfCompartments()} unbalanced_species = { - u.getId(): u.getInitialConcentration() - for u in model_sbml.getListOfSpecies() + u.getId(): u.getInitialConcentration() + for u in model_sbml.getListOfSpecies() if u.boundary_condition } -para = {**parameters_local, **parameters_global, **compartments, **unbalanced_species} +para = { + **parameters_local, + **parameters_global, + **compartments, + **unbalanced_species, +} kinmodel_sbml = KineticModelSbml( parameters=para,