diff --git a/.gitignore b/.gitignore index 0302c07..77a77f4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ .DS_Store +# scratchpads +scratch.md + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/README.md b/README.md index d694f94..fdc612f 100644 --- a/README.md +++ b/README.md @@ -21,14 +21,12 @@ pip install enzax ```python from enzax.examples import methionine -from enzax.steady_state import solve_steady_state +from enzax.steady_state import get_kinetic_model_steady_state from jax import numpy as jnp guess = jnp.full((5,) 0.01) -steady_state = solve_steady_state( - methionine.parameters, methionine.unparameterised_model, guess -) +steady_state = get_kinetic_model_steady_state(methionine.model, guess) ``` ### Find a steady state's Jacobian with respect to all parameters @@ -36,12 +34,20 @@ steady_state = solve_steady_state( ```python import jax from enzax.examples import methionine -from enzax.steady_state import solve_steady_state +from enzax.steady_state import get_kinetic_model_steady_state from jax import numpy as jnp +from jaxtyping import PyTree guess = jnp.full((5,) 0.01) +model = methionine.model + +def get_steady_state_from_params(parameters: PyTree): + """Get the steady state with a one-argument non-pure function.""" + _model = RateEquationModel( + parameters, model.structure, model.rate_equations + ) + return get_kinetic_model_steady_state(_model, guess) + +jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters) -jacobian = jax.jacrev(solve_steady_state)( - methionine.parameters, methionine.unparameterised_model, guess -) ``` diff --git a/docs/api/kinetic_model.md b/docs/api/kinetic_model.md index 89bc6f4..43540eb 100644 --- a/docs/api/kinetic_model.md +++ b/docs/api/kinetic_model.md @@ -4,7 +4,6 @@ filters: - "!check" members: - - KineticModel - - UnparameterisedKineticModel - - KineticModelParameters - KineticModelStructure + - KineticModel + - RateEquationModel diff --git a/docs/api/mcmc.md b/docs/api/mcmc.md new file mode 100644 index 0000000..9fc8b8c --- /dev/null +++ b/docs/api/mcmc.md @@ -0,0 +1,13 @@ +# ::: enzax.mcmc + options: + show_root_heading: true + filters: + - "!check" + members: + - ObservationSet + - PriorSet + - run_nuts + - posterior_logdensity_amm + - get_idata + - ind_prior_from_truth + - ind_normal_prior_logdensity diff --git a/docs/api/steady_state.md b/docs/api/steady_state.md new file mode 100644 index 0000000..1d6fb49 --- /dev/null +++ b/docs/api/steady_state.md @@ -0,0 +1,7 @@ +# ::: enzax.steady_state + options: + show_root_heading: true + filters: + - "!check" + members: + - get_kinetic_model_steady_state diff --git a/docs/getting_started.md b/docs/getting_started.md index 73f3f7f..1872d2b 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -22,10 +22,8 @@ First we import some enzax classes: ```python from enzax.kinetic_model import ( - KineticModel, - KineticModelParameters, KineticModelStructure, - UnparameterisedKineticModel, + RateEquationModel, ) from enzax.rate_equations import ( AllostericReversibleMichaelisMenten, @@ -49,7 +47,9 @@ structure = KineticModelStructure( Next we provide some kinetic parameter values: ```python -parameters = KineticModelParameters( +from enzax.parameters import AllostericMichaelisMentenParameters + +parameters = AllostericMichaelisMentenParameters( log_kcat=jnp.array([-0.1, 0.0, 0.1]), log_enzyme=jnp.log(jnp.array([0.3, 0.2, 0.1])), dgf=jnp.array([-3, -1.0]), @@ -65,6 +65,11 @@ parameters = KineticModelParameters( Now we can use enzax's rate laws to specify how each reaction behaves: ```python +from enzax.rate_equations import ( + AllostericReversibleMichaelisMenten, + ReversibleMichaelisMenten, +) + r0 = AllostericReversibleMichaelisMenten( kcat_ix=0, enzyme_ix=0, @@ -130,16 +135,10 @@ r2 = ReversibleMichaelisMenten( ) ``` -Next an unparameterised kinetic model +Now we can declare our model: ```python -unparameterised_model = UnparameterisedKineticModel(structure, [r0, r1, r2]) -``` - -Finally a parameterised model: - -```python -model = KineticModel(parameters, unparameterised_model) +model = RateEquationModel(structure, parameters, [r0, r1, r2]) ``` To test out the model, we can see if it returns some fluxes and state variable rates when provided a set of balanced species concentrations: @@ -157,28 +156,35 @@ dcdt ## Find a kinetic model's steady state -Enzax provides a few example kinetic models, including [`methionine`](https://github.com/dtu-qmcm/enzax/blob/main/src/enzax/examples/methionine.py), a model of the mammallian methionine cycle. +Enzax provides a few example kinetic models, including [`methionine`](https://github.com/dtu-qmcm/enzax/blob/main/src/enzax/examples/methionine.py), a model of the mammalian methionine cycle. -Here is how to find this model's steady state (and its parameter gradients) using enzax's `solve_steady_state` function: +Here is how to find this model's steady state (and its parameter gradients) using enzax's `get_kinetic_model_steady_state` function: ```python from enzax.examples import methionine -from enzax.steady_state import solve_steady_state +from enzax.steady_state import get_kinetic_model_steady_state from jax import numpy as jnp guess = jnp.full((5,) 0.01) -steady_state = solve_steady_state( - methionine.parameters, methionine.unparameterised_model, guess -) +steady_state = get_kinetic_model_steady_state(methionine.model, guess) ``` -To find the jacobian of this steady state with respect to the model's parameters, we can wrap `solve_steady_state` in JAX's [`jacrev`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html) function: +To access the Jacobian of this steady state with respect to the model's parameters, we can wrap `get_kinetic_model_steady_state` in a function that has a set of parameters as its only argument, then use JAX's [`jacrev`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html) function: ```python import jax +from jaxtyping import PyTree -jacobian = jax.jacrev(solve_steady_state)( - methionine.parameters, methionine.unparameterised_model, guess -) +guess = jnp.full((5,) 0.01) +model = methionine.model + +def get_steady_state_from_params(parameters: PyTree): + """Get the steady state with a one-argument non-pure function.""" + _model = RateEquationModel( + parameters, model.structure, model.rate_equations + ) + return get_kinetic_model_steady_state(_model, guess) + +jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters) ``` diff --git a/mkdocs.yml b/mkdocs.yml index ed1c91f..65a956a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -44,3 +44,5 @@ nav: - API: - 'api/kinetic_model.md' - 'api/rate_equations.md' + - 'api/steady_state.md' + - 'api/mcmc.md' diff --git a/pyproject.toml b/pyproject.toml index e2258b3..4ba43d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ line-length = 80 [tool.ruff.lint] ignore = ["F722"] +extend-select = ["E501"] # line length is checked [tool.ruff.lint.isort] known-first-party = ["enzax"] diff --git a/scripts/mcmc_demo.py b/scripts/mcmc_demo.py new file mode 100644 index 0000000..63bac12 --- /dev/null +++ b/scripts/mcmc_demo.py @@ -0,0 +1,126 @@ +"""Demonstration of how to make a Bayesian kinetic model with enzax.""" + +import functools +import logging +import warnings + +import arviz as az +import jax +from jax import numpy as jnp + +from enzax.examples import methionine +from enzax.mcmc import ( + ObservationSet, + AllostericMichaelisMentenPriorSet, + get_idata, + ind_prior_from_truth, + posterior_logdensity_amm, + run_nuts, +) +from enzax.steady_state import get_kinetic_model_steady_state + +SEED = 1234 + +jax.config.update("jax_enable_x64", True) + + +def main(): + """Demonstrate How to make a Bayesian kinetic model with enzax.""" + structure = methionine.structure + rate_equations = methionine.rate_equations + true_parameters = methionine.parameters + true_model = methionine.model + default_state_guess = jnp.full((5,), 0.01) + true_states = get_kinetic_model_steady_state( + true_model, default_state_guess + ) + prior = AllostericMichaelisMentenPriorSet( + log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1), + log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1), + log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1), + dgf=ind_prior_from_truth(true_parameters.dgf, 0.1), + log_km=ind_prior_from_truth(true_parameters.log_km, 0.1), + log_conc_unbalanced=ind_prior_from_truth( + true_parameters.log_conc_unbalanced, 0.1 + ), + temperature=ind_prior_from_truth(true_parameters.temperature, 0.1), + log_ki=ind_prior_from_truth(true_parameters.log_ki, 0.1), + log_transfer_constant=ind_prior_from_truth( + true_parameters.log_transfer_constant, 0.1 + ), + log_dissociation_constant=ind_prior_from_truth( + true_parameters.log_dissociation_constant, 0.1 + ), + ) + # get true concentration + true_conc = jnp.zeros(methionine.structure.S.shape[0]) + true_conc = true_conc.at[methionine.structure.balanced_species].set( + true_states + ) + true_conc = true_conc.at[methionine.structure.unbalanced_species].set( + jnp.exp(true_parameters.log_conc_unbalanced) + ) + # get true flux + true_flux = true_model.flux(true_states) + # simulate observations + error_conc = 0.03 + error_flux = 0.05 + error_enzyme = 0.03 + key = jax.random.key(SEED) + obs_conc = jnp.exp(jnp.log(true_conc) + jax.random.normal(key) * error_conc) + obs_enzyme = jnp.exp( + true_parameters.log_enzyme + jax.random.normal(key) * error_enzyme + ) + obs_flux = true_flux + jax.random.normal(key) * error_conc + obs = ObservationSet( + conc=obs_conc, + flux=obs_flux, + enzyme=obs_enzyme, + conc_scale=error_conc, + flux_scale=error_flux, + enzyme_scale=error_enzyme, + ) + pldf = functools.partial( + posterior_logdensity_amm, + obs=obs, + prior=prior, + structure=structure, + rate_equations=rate_equations, + guess=default_state_guess, + ) + samples, info = run_nuts( + pldf, + key, + true_parameters, + num_warmup=200, + num_samples=200, + initial_step_size=0.0001, + max_num_doublings=10, + is_mass_matrix_diagonal=False, + target_acceptance_rate=0.95, + ) + idata = get_idata( + samples, info, coords=methionine.coords, dims=methionine.dims + ) + print(az.summary(idata)) + if jnp.any(info.is_divergent): + n_divergent = info.is_divergent.sum() + msg = f"There were {n_divergent} post-warmup divergent transitions." + warnings.warn(msg) + else: + logging.info("No post-warmup divergent transitions!") + print("True parameter values vs posterior:") + for param in true_parameters.__dataclass_fields__.keys(): + true_val = getattr(true_parameters, param) + model_low = jnp.quantile(getattr(samples.position, param), 0.01, axis=0) + model_high = jnp.quantile( + getattr(samples.position, param), 0.99, axis=0 + ) + print(f" {param}:") + print(f" true value: {true_val}") + print(f" posterior 1%: {model_low}") + print(f" posterior 99%: {model_high}") + + +if __name__ == "__main__": + main() diff --git a/scripts/steady_state_demo.py b/scripts/steady_state_demo.py new file mode 100644 index 0000000..5feeda3 --- /dev/null +++ b/scripts/steady_state_demo.py @@ -0,0 +1,62 @@ +"""Demonstration of how to find a steady state and its gradients with enzax.""" + +import time +from enzax.kinetic_model import RateEquationModel + +import jax +from jax import numpy as jnp + +from enzax.examples import methionine +from enzax.steady_state import get_kinetic_model_steady_state +from jaxtyping import PyTree + +BAD_GUESS = jnp.full((5,), 0.01) +GOOD_GUESS = jnp.array( + [ + 4.233000e-05, # met-L + 3.099670e-05, # amet + 2.170170e-07, # ahcys + 3.521780e-06, # hcys + 6.534400e-06, # 5mthf + ] +) + + +def main(): + """Function for testing the steady state solver.""" + model = methionine.model + # compare good and bad guess + for guess in [BAD_GUESS, GOOD_GUESS]: + + def get_steady_state_from_params(parameters: PyTree): + """Get the steady state from just parameters. + + This lets us get the Jacobian wrt (just) the parameters. + """ + _model = RateEquationModel( + parameters, model.structure, model.rate_equations + ) + return get_kinetic_model_steady_state(_model, guess) + + # solve once for jitting + get_kinetic_model_steady_state(model, GOOD_GUESS) + jax.jacrev(get_steady_state_from_params)(model.parameters) + # timer on + start = time.time() + conc_steady = get_kinetic_model_steady_state(model, guess) + jac = jax.jacrev(get_steady_state_from_params)(model.parameters) + # timer off + runtime = (time.time() - start) * 1e3 + sv = model.dcdt(jnp.array(0.0), conc=conc_steady) + flux = model.flux(conc_steady) + print(f"Results with starting guess {guess}:") + print(f"\tRun time in milliseconds: {round(runtime, 4)}") + print(f"\tSteady state concentration: {conc_steady}") + print(f"\tFlux: {flux}") + print(f"\tSv: {sv}") + print(f"\tLog Km Jacobian: {jac.log_km}") + print(f"\tDgf Jacobian: {jac.dgf}") + + +if __name__ == "__main__": + main() diff --git a/src/enzax/__init__.py b/src/enzax/__init__.py index e69de29..d5f6fc4 100644 --- a/src/enzax/__init__.py +++ b/src/enzax/__init__.py @@ -0,0 +1,3 @@ +from jax import config + +config.update("jax_enable_x64", True) diff --git a/src/enzax/examples/linear.py b/src/enzax/examples/linear.py index dc911db..134167d 100644 --- a/src/enzax/examples/linear.py +++ b/src/enzax/examples/linear.py @@ -1,22 +1,18 @@ """A simple linear kinetic model.""" -from jax import config from jax import numpy as jnp from enzax.kinetic_model import ( - KineticModel, - KineticModelParameters, + RateEquationModel, KineticModelStructure, - UnparameterisedKineticModel, ) from enzax.rate_equations import ( AllostericReversibleMichaelisMenten, ReversibleMichaelisMenten, ) +from enzax.parameters import AllostericMichaelisMentenParameterSet -config.update("jax_enable_x64", True) - -parameters = KineticModelParameters( +parameters = AllostericMichaelisMentenParameterSet( log_kcat=jnp.array([-0.1, 0.0, 0.1]), log_enzyme=jnp.log(jnp.array([0.3, 0.2, 0.1])), dgf=jnp.array([-3, -1.0]), @@ -35,74 +31,70 @@ balanced_species=jnp.array([1, 2]), unbalanced_species=jnp.array([0, 3]), ) - -unparameterised_model = UnparameterisedKineticModel( - structure, - [ - AllostericReversibleMichaelisMenten( - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 0], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - ix_product=jnp.array([1], dtype=jnp.int16), - ix_reactants=jnp.array([0, 1], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - tc_ix=0, - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - ix_dc_activation=jnp.array([0], dtype=jnp.int16), - species_activation=jnp.array([2], dtype=jnp.int16), - species_inhibition=jnp.array([], dtype=jnp.int16), - subunits=1, - ), - AllostericReversibleMichaelisMenten( - kcat_ix=1, - enzyme_ix=1, - km_ix=jnp.array([2, 3], dtype=jnp.int16), - ki_ix=jnp.array([0]), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([1]), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([1], dtype=jnp.int16), - ix_product=jnp.array([2], dtype=jnp.int16), - ix_reactants=jnp.array([1, 2], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - tc_ix=1, - ix_dc_inhibition=jnp.array([1], dtype=jnp.int16), - ix_dc_activation=jnp.array([], dtype=jnp.int16), - species_activation=jnp.array([], dtype=jnp.int16), - species_inhibition=jnp.array([1], dtype=jnp.int16), - subunits=1, - ), - ReversibleMichaelisMenten( - kcat_ix=2, - enzyme_ix=2, - km_ix=jnp.array([4, 5], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - ix_substrate=jnp.array([2], dtype=jnp.int16), - ix_product=jnp.array([3], dtype=jnp.int16), - ix_reactants=jnp.array([2, 3], dtype=jnp.int16), - reactant_to_dgf=jnp.array([1, 1], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - ), - ], -) -model = KineticModel(parameters, unparameterised_model) +rate_equations = [ + AllostericReversibleMichaelisMenten( + kcat_ix=0, + enzyme_ix=0, + km_ix=jnp.array([0, 1], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), + reactant_to_dgf=jnp.array([0, 0], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), + ix_substrate=jnp.array([0], dtype=jnp.int16), + ix_product=jnp.array([1], dtype=jnp.int16), + ix_reactants=jnp.array([0, 1], dtype=jnp.int16), + product_reactant_positions=jnp.array([1], dtype=jnp.int16), + product_km_positions=jnp.array([1], dtype=jnp.int16), + water_stoichiometry=jnp.array(0.0), + tc_ix=0, + ix_dc_inhibition=jnp.array([], dtype=jnp.int16), + ix_dc_activation=jnp.array([0], dtype=jnp.int16), + species_activation=jnp.array([2], dtype=jnp.int16), + species_inhibition=jnp.array([], dtype=jnp.int16), + subunits=1, + ), + AllostericReversibleMichaelisMenten( + kcat_ix=1, + enzyme_ix=1, + km_ix=jnp.array([2, 3], dtype=jnp.int16), + ki_ix=jnp.array([0]), + reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), + reactant_to_dgf=jnp.array([0, 1], dtype=jnp.int16), + ix_ki_species=jnp.array([1]), + substrate_km_positions=jnp.array([0], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), + ix_substrate=jnp.array([1], dtype=jnp.int16), + ix_product=jnp.array([2], dtype=jnp.int16), + ix_reactants=jnp.array([1, 2], dtype=jnp.int16), + product_reactant_positions=jnp.array([1], dtype=jnp.int16), + product_km_positions=jnp.array([1], dtype=jnp.int16), + water_stoichiometry=jnp.array(0.0), + tc_ix=1, + ix_dc_inhibition=jnp.array([1], dtype=jnp.int16), + ix_dc_activation=jnp.array([], dtype=jnp.int16), + species_activation=jnp.array([], dtype=jnp.int16), + species_inhibition=jnp.array([1], dtype=jnp.int16), + subunits=1, + ), + ReversibleMichaelisMenten( + kcat_ix=2, + enzyme_ix=2, + km_ix=jnp.array([4, 5], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + ix_substrate=jnp.array([2], dtype=jnp.int16), + ix_product=jnp.array([3], dtype=jnp.int16), + ix_reactants=jnp.array([2, 3], dtype=jnp.int16), + reactant_to_dgf=jnp.array([1, 1], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), + product_reactant_positions=jnp.array([1], dtype=jnp.int16), + product_km_positions=jnp.array([1], dtype=jnp.int16), + water_stoichiometry=jnp.array(0.0), + ), +] steady_state = jnp.array([0.43658744, 0.12695706]) +model = RateEquationModel(parameters, structure, rate_equations) diff --git a/src/enzax/examples/methionine.py b/src/enzax/examples/methionine.py index 1334b54..1179b11 100644 --- a/src/enzax/examples/methionine.py +++ b/src/enzax/examples/methionine.py @@ -1,13 +1,15 @@ -"""A simple linear kinetic model.""" +"""A kinetic model of the methionine cycle. + +See here for more about the methionine cycle: +https://doi.org/10.1021/acssynbio.3c00662 + +""" -from jax import config from jax import numpy as jnp from enzax.kinetic_model import ( - KineticModel, - KineticModelParameters, + RateEquationModel, KineticModelStructure, - UnparameterisedKineticModel, ) from enzax.rate_equations import ( AllostericIrreversibleMichaelisMenten, @@ -15,11 +17,154 @@ IrreversibleMichaelisMenten, ReversibleMichaelisMenten, ) +from enzax.parameters import AllostericMichaelisMentenParameterSet - -config.update("jax_enable_x64", True) - -parameters = KineticModelParameters( +coords = { + "enzyme": [ + "MAT1", + "MAT3", + "METH-Gen", + "GNMT1", + "AHC1", + "MS1", + "BHMT1", + "CBS1", + "MTHFR1", + "PROT1", + ], + "drain": ["the_drain"], + "rate": [ + "the_drain", + "MAT1", + "MAT3", + "METH-Gen", + "GNMT1", + "AHC1", + "MS1", + "BHMT1", + "CBS1", + "MTHFR1", + "PROT1", + ], + "metabolite": [ + "met-L", + "atp", + "pi", + "ppi", + "amet", + "ahcys", + "gly", + "sarcs", + "hcys-L", + "adn", + "thf", + "5mthf", + "mlthf", + "glyb", + "dmgly", + "ser-L", + "nadp", + "nadph", + "cyst-L", + ], + "km": [ + "met-L MAT1", + "atp MAT1", + "met-L MAT3", + "atp MAT3", + "amet METH-Gen", + "amet GNMT1", + "gly GNMT1", + "ahcys AHC1", + "hcys-L AHC1", + "adn AHC1", + "hcys-L MS1", + "5mthf MS1", + "hcys-L BHMT1", + "glyb BHMT1", + "hcys-L CBS1", + "ser-L CBS1", + "mlthf MTHFR1", + "nadph MTHFR1", + "met-L PROT1", + ], + "ki": [ + "MAT1", + "METH-Gen", + "GNMT1", + ], + "species": [ + "met-L", + "atp", + "pi", + "ppi", + "amet", + "ahcys", + "gly", + "sarcs", + "hcys-L", + "adn", + "thf", + "5mthf", + "mlthf", + "glyb", + "dmgly", + "ser-L", + "nadp", + "nadph", + "cyst-L", + ], + "balanced_species": [ + "met-L", + "amet", + "ahcys", + "hcys-L", + "5mthf", + ], + "unbalanced_species": [ + "atp", + "pi", + "ppi", + "gly", + "sarcs", + "adn", + "thf", + "mlthf", + "glyb", + "dmgly", + "ser-L", + "nadp", + "nadph", + "cyst-L", + ], + "transfer_constant": [ + "METAT", + "GNMT", + "CBS", + "MTHFR", + ], + "dissociation_constant": [ + "met-L MAT3", + "amet MAT3", + "amet GNMT1", + "mlthf GNMT1", + "amet CBS1", + "amet MTHFR1", + "ahcys MTHFR1", + ], +} +dims = { + "log_kcat": ["enzyme"], + "log_enzyme": ["enzyme"], + "log_drain": ["drain"], + "log_km": ["km"], + "dgf": ["metabolite"], + "log_ki": ["ki"], + "log_conc_unbalanced": ["unbalanced_species"], + "log_transfer_constant": ["transfer_constant"], + "log_dissociation_constant": ["dissociation_constant"], +} +parameters = AllostericMichaelisMentenParameterSet( log_kcat=jnp.log( jnp.array( [ @@ -209,159 +354,155 @@ ] ), ) -unparameterised_model = UnparameterisedKineticModel( - structure, - [ - Drain(sign=jnp.array(1.0), drain_ix=0), # met-L source - IrreversibleMichaelisMenten( # MAT1 - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([0], dtype=jnp.int16), - reactant_stoichiometry=jnp.array( - [-1.0, -1.0, 1.0, 1.0, 1.0], dtype=jnp.int16 - ), - ix_substrate=jnp.array([0, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([4], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), - ), - AllostericIrreversibleMichaelisMenten( # MAT3 - kcat_ix=1, - enzyme_ix=1, - km_ix=jnp.array([2, 3], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array( - [-1.0, -1.0, 1.0, 1.0, 1.0], dtype=jnp.int16 - ), - ix_substrate=jnp.array([0, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), - subunits=2, - tc_ix=0, - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - ix_dc_activation=jnp.array([0, 1], dtype=jnp.int16), - species_inhibition=jnp.array([], dtype=jnp.int16), - species_activation=jnp.array([0, 4], dtype=jnp.int16), - ), - IrreversibleMichaelisMenten( # METH - kcat_ix=2, - enzyme_ix=2, - km_ix=jnp.array([4], dtype=jnp.int16), - ki_ix=jnp.array([1], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - ix_substrate=jnp.array([4], dtype=jnp.int16), - ix_ki_species=jnp.array([5], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), +rate_equations = [ + Drain(sign=jnp.array(1.0), drain_ix=0), # met-L source + IrreversibleMichaelisMenten( # MAT1 + kcat_ix=0, + enzyme_ix=0, + km_ix=jnp.array([0, 1], dtype=jnp.int16), + ki_ix=jnp.array([0], dtype=jnp.int16), + reactant_stoichiometry=jnp.array( + [-1.0, -1.0, 1.0, 1.0, 1.0], dtype=jnp.int16 ), - AllostericIrreversibleMichaelisMenten( # GNMT1 - kcat_ix=3, - enzyme_ix=3, - km_ix=jnp.array([5, 6], dtype=jnp.int16), - ki_ix=jnp.array([2], dtype=jnp.int16), - reactant_stoichiometry=jnp.array( - [-1.0, 1.0, -1.0, 1.0], dtype=jnp.int16 - ), - ix_substrate=jnp.array([4, 6], dtype=jnp.int16), - ix_ki_species=jnp.array([5], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 2], dtype=jnp.int16), - subunits=4, - tc_ix=1, - ix_dc_activation=jnp.array([2], dtype=jnp.int16), - ix_dc_inhibition=jnp.array([3], dtype=jnp.int16), - species_inhibition=jnp.array([12], dtype=jnp.int16), - species_activation=jnp.array([4], dtype=jnp.int16), - ), - ReversibleMichaelisMenten( # AHC - kcat_ix=4, - enzyme_ix=4, - km_ix=jnp.array([7, 8, 9], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1.0, 1.0, 1.0], dtype=jnp.int16), - ix_substrate=jnp.array([5], dtype=jnp.int16), - ix_product=jnp.array([8, 9], dtype=jnp.int16), - ix_reactants=jnp.array([5, 8, 9], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - product_km_positions=jnp.array([1, 2], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - product_reactant_positions=jnp.array([1, 2], dtype=jnp.int16), - water_stoichiometry=jnp.array(-1.0), - reactant_to_dgf=jnp.array([5, 8, 9], dtype=jnp.int16), - ), - IrreversibleMichaelisMenten( # MS - kcat_ix=5, - enzyme_ix=5, - km_ix=jnp.array([10, 11], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([1, -1, 1, -1], dtype=jnp.int16), - ix_substrate=jnp.array([8, 11], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([1, 3], dtype=jnp.int16), - ), - IrreversibleMichaelisMenten( # BHMT - kcat_ix=6, - enzyme_ix=6, - km_ix=jnp.array([12, 13], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([1, -1, -1, 1], dtype=jnp.int16), - ix_substrate=jnp.array([8, 13], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([1, 2], dtype=jnp.int16), - ), - AllostericIrreversibleMichaelisMenten( # CBS1 - kcat_ix=7, - enzyme_ix=7, - km_ix=jnp.array([14, 15], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, -1, 1], dtype=jnp.int16), - ix_substrate=jnp.array([8, 15], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), - subunits=2, - tc_ix=2, - ix_dc_activation=jnp.array([4], dtype=jnp.int16), - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - species_inhibition=jnp.array([4], dtype=jnp.int16), - species_activation=jnp.array([], dtype=jnp.int16), - ), - AllostericIrreversibleMichaelisMenten( # MTHFR - kcat_ix=8, - enzyme_ix=8, - km_ix=jnp.array([16, 17], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([1, -1, 1, -1], dtype=jnp.int16), - ix_substrate=jnp.array([12, 17], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([1, 3], dtype=jnp.int16), - subunits=2, - tc_ix=3, - ix_dc_activation=jnp.array([6], dtype=jnp.int16), - ix_dc_inhibition=jnp.array([5], dtype=jnp.int16), - species_inhibition=jnp.array([4], dtype=jnp.int16), - species_activation=jnp.array([5], dtype=jnp.int16), + ix_substrate=jnp.array([0, 1], dtype=jnp.int16), + ix_ki_species=jnp.array([4], dtype=jnp.int16), + substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), + ), + AllostericIrreversibleMichaelisMenten( # MAT3 + kcat_ix=1, + enzyme_ix=1, + km_ix=jnp.array([2, 3], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array( + [-1.0, -1.0, 1.0, 1.0, 1.0], dtype=jnp.int16 ), - IrreversibleMichaelisMenten( # PROT - kcat_ix=9, - enzyme_ix=9, - km_ix=jnp.array([18], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1.0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), + ix_substrate=jnp.array([0, 1], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), + subunits=2, + tc_ix=0, + ix_dc_inhibition=jnp.array([], dtype=jnp.int16), + ix_dc_activation=jnp.array([0, 1], dtype=jnp.int16), + species_inhibition=jnp.array([], dtype=jnp.int16), + species_activation=jnp.array([0, 4], dtype=jnp.int16), + ), + IrreversibleMichaelisMenten( # METH + kcat_ix=2, + enzyme_ix=2, + km_ix=jnp.array([4], dtype=jnp.int16), + ki_ix=jnp.array([1], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), + ix_substrate=jnp.array([4], dtype=jnp.int16), + ix_ki_species=jnp.array([5], dtype=jnp.int16), + substrate_km_positions=jnp.array([0], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), + ), + AllostericIrreversibleMichaelisMenten( # GNMT1 + kcat_ix=3, + enzyme_ix=3, + km_ix=jnp.array([5, 6], dtype=jnp.int16), + ki_ix=jnp.array([2], dtype=jnp.int16), + reactant_stoichiometry=jnp.array( + [-1.0, 1.0, -1.0, 1.0], dtype=jnp.int16 ), - ], -) -model = KineticModel(parameters, unparameterised_model) + ix_substrate=jnp.array([4, 6], dtype=jnp.int16), + ix_ki_species=jnp.array([5], dtype=jnp.int16), + substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0, 2], dtype=jnp.int16), + subunits=4, + tc_ix=1, + ix_dc_activation=jnp.array([2], dtype=jnp.int16), + ix_dc_inhibition=jnp.array([3], dtype=jnp.int16), + species_inhibition=jnp.array([12], dtype=jnp.int16), + species_activation=jnp.array([4], dtype=jnp.int16), + ), + ReversibleMichaelisMenten( # AHC + kcat_ix=4, + enzyme_ix=4, + km_ix=jnp.array([7, 8, 9], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([-1.0, 1.0, 1.0], dtype=jnp.int16), + ix_substrate=jnp.array([5], dtype=jnp.int16), + ix_product=jnp.array([8, 9], dtype=jnp.int16), + ix_reactants=jnp.array([5, 8, 9], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0], dtype=jnp.int16), + product_km_positions=jnp.array([1, 2], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), + product_reactant_positions=jnp.array([1, 2], dtype=jnp.int16), + water_stoichiometry=jnp.array(-1.0), + reactant_to_dgf=jnp.array([5, 8, 9], dtype=jnp.int16), + ), + IrreversibleMichaelisMenten( # MS + kcat_ix=5, + enzyme_ix=5, + km_ix=jnp.array([10, 11], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([1, -1, 1, -1], dtype=jnp.int16), + ix_substrate=jnp.array([8, 11], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([1, 3], dtype=jnp.int16), + ), + IrreversibleMichaelisMenten( # BHMT + kcat_ix=6, + enzyme_ix=6, + km_ix=jnp.array([12, 13], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([1, -1, -1, 1], dtype=jnp.int16), + ix_substrate=jnp.array([8, 13], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([1, 2], dtype=jnp.int16), + ), + AllostericIrreversibleMichaelisMenten( # CBS1 + kcat_ix=7, + enzyme_ix=7, + km_ix=jnp.array([14, 15], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([-1, -1, 1], dtype=jnp.int16), + ix_substrate=jnp.array([8, 15], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), + subunits=2, + tc_ix=2, + ix_dc_activation=jnp.array([4], dtype=jnp.int16), + ix_dc_inhibition=jnp.array([], dtype=jnp.int16), + species_inhibition=jnp.array([4], dtype=jnp.int16), + species_activation=jnp.array([], dtype=jnp.int16), + ), + AllostericIrreversibleMichaelisMenten( # MTHFR + kcat_ix=8, + enzyme_ix=8, + km_ix=jnp.array([16, 17], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([1, -1, 1, -1], dtype=jnp.int16), + ix_substrate=jnp.array([12, 17], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([1, 3], dtype=jnp.int16), + subunits=2, + tc_ix=3, + ix_dc_activation=jnp.array([6], dtype=jnp.int16), + ix_dc_inhibition=jnp.array([5], dtype=jnp.int16), + species_inhibition=jnp.array([4], dtype=jnp.int16), + species_activation=jnp.array([5], dtype=jnp.int16), + ), + IrreversibleMichaelisMenten( # PROT + kcat_ix=9, + enzyme_ix=9, + km_ix=jnp.array([18], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([-1.0], dtype=jnp.int16), + ix_substrate=jnp.array([0], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), + ), +] steady_state = jnp.array( [ 4.233000e-05, # met-L @@ -371,3 +512,4 @@ 6.534400e-06, # 5mthf ] ) +model = RateEquationModel(parameters, structure, rate_equations) diff --git a/src/enzax/kinetic_model.py b/src/enzax/kinetic_model.py index 8f1c541..439b7cd 100644 --- a/src/enzax/kinetic_model.py +++ b/src/enzax/kinetic_model.py @@ -1,59 +1,61 @@ """Module containing enzax's definition of a kinetic model.""" +from abc import ABC, abstractmethod + import equinox as eqx import jax.numpy as jnp -from jaxtyping import Array, Float, Int, PyTree, Scalar, ScalarLike, jaxtyped +from jaxtyping import Array, Float, Int, PyTree, ScalarLike, jaxtyped from typeguard import typechecked -from enzax.rate_equations import RateEquation - - -@jaxtyped(typechecker=typechecked) -class KineticModelParameters(eqx.Module): - """Parameters for a kinetic model.""" - - log_kcat: Float[Array, " n_enzyme"] - log_enzyme: Float[Array, " n_enzyme"] - dgf: Float[Array, " n_metabolite"] - log_km: Float[Array, " n_km"] - log_ki: Float[Array, " n_ki"] - log_conc_unbalanced: Float[Array, " n_unbalanced"] - temperature: Scalar - log_transfer_constant: Float[Array, " n_allosteric_enzyme"] - log_dissociation_constant: Float[Array, " n_allosteric_effector"] - log_drain: Float[Array, " n_drain"] +from enzax.rate_equation import RateEquation @jaxtyped(typechecker=typechecked) class KineticModelStructure(eqx.Module): """Structural information about a kinetic model.""" - S: Float[Array, " s r"] = eqx.field(static=True) - balanced_species: Int[Array, " n_balanced"] = eqx.field(static=True) - unbalanced_species: Int[Array, " n_unbalanced"] = eqx.field(static=True) + S: Float[Array, " s r"] + balanced_species: Int[Array, " n_balanced"] + unbalanced_species: Int[Array, " n_unbalanced"] -class UnparameterisedKineticModel(eqx.Module): - """A kinetic model without parameter values.""" +class KineticModel(eqx.Module, ABC): + """Abstract base class for kinetic models.""" + parameters: PyTree structure: KineticModelStructure - rate_equations: list[RateEquation] + rate_equations: list[RateEquation] = eqx.field(default_factory=list) + @abstractmethod + def flux( + self, + conc_balanced: Float[Array, " n_balanced"], + ) -> Float[Array, " n"]: ... -class KineticModel(eqx.Module): - """A parameterised kinetic model.""" + def dcdt( + self, t: ScalarLike, conc: Float[Array, " n_balanced"], args=None + ) -> Float[Array, " n_balanced"]: + """Get the rate of change of balanced species concentrations. + + Note that the signature is as required for a Diffrax vector field function, hence the redundant variable t and the weird name "args". + + :param t: redundant variable representing time. + + :param conc: a one dimensional array of positive floats representing concentrations of balanced species. Must have same size as self.structure.ix_balanced + + """ # Noqa: E501 + sv = self.structure.S @ self.flux(conc) + return sv[self.structure.balanced_species] - parameters: PyTree - structure: KineticModelStructure - rate_equations: list[RateEquation] - def __init__(self, parameters, unparameterised_model): - self.parameters = parameters - self.structure = unparameterised_model.structure - self.rate_equations = unparameterised_model.rate_equations +class RateEquationModel(KineticModel): + """A kinetic model that specifies its fluxes using RateEquation objects.""" + + rate_equations: list[RateEquation] = eqx.field(default_factory=list) def flux( - self, conc_balanced: Float[Array, " n_balanced"] + self, + conc_balanced: Float[Array, " n_balanced"], ) -> Float[Array, " n"]: """Get fluxes from balanced species concentrations. @@ -61,7 +63,7 @@ def flux( :return: a one dimensional array of (possibly negative) floats representing reaction fluxes. Has same size as number of columns of self.structure.S. - """ + """ # Noqa: E501 conc = jnp.zeros(self.structure.S.shape[0]) conc = conc.at[self.structure.balanced_species].set(conc_balanced) conc = conc.at[self.structure.unbalanced_species].set( @@ -70,20 +72,3 @@ def flux( t = [f(conc, self.parameters) for f in self.rate_equations] out = jnp.array(t) return out - - def dcdt( - self, t: ScalarLike, conc: Float[Array, " n_balanced"], args=None - ) -> Float[Array, " n_balanced"]: - """Get the rate of change of balanced species concentrations. - - Note that the signature is as required for a Diffrax vector field function, hence the redundant variable t and the weird name "args". - - :param t: redundant variable representing time. - - :param conc: a one dimensional array of positive floats representing concentrations of balanced species. Must have same size as self.structure.ix_balanced - - """ - out = (self.structure.S @ self.flux(conc))[ - self.structure.balanced_species - ] - return out diff --git a/src/enzax/mcmc.py b/src/enzax/mcmc.py index bf18618..ef32b09 100644 --- a/src/enzax/mcmc.py +++ b/src/enzax/mcmc.py @@ -1,35 +1,32 @@ -"""Code for doing mcmc on the parameters of a steady state problem.""" +"""Code for MCMC-based Bayesian inference on kinetic models.""" import functools -import logging -import warnings +from typing import Callable, TypedDict, Unpack import arviz as az import blackjax import chex -import equinox as eqx import jax +from jax._src.random import KeyArray import jax.numpy as jnp from jax.scipy.stats import norm -from jaxtyping import Array, Float, ScalarLike +from jaxtyping import Array, Float, PyTree, ScalarLike -from enzax.examples import methionine from enzax.kinetic_model import ( - KineticModel, - KineticModelParameters, - UnparameterisedKineticModel, + RateEquationModel, + KineticModelStructure, ) -from enzax.steady_state_problem import solve - -SEED = 1234 -jax.config.update("jax_enable_x64", True) -jax.config.update("jax_debug_nans", True) +from enzax.parameters import AllostericMichaelisMentenParameterSet +from enzax.rate_equation import RateEquation +from enzax.steady_state import get_kinetic_model_steady_state @chex.dataclass class ObservationSet: + """Measurements from a single experiment.""" + conc: Float[Array, " m"] - flux: ScalarLike + flux: Float[Array, " n"] enzyme: Float[Array, " e"] conc_scale: ScalarLike flux_scale: ScalarLike @@ -37,9 +34,11 @@ class ObservationSet: @chex.dataclass -class PriorSet: - log_kcat: Float[Array, "2 n"] - log_enzyme: Float[Array, "2 n"] +class AllostericMichaelisMentenPriorSet: + """Priors for an allosteric Michaelis-Menten model.""" + + log_kcat: Float[Array, "2 n_enzyme"] + log_enzyme: Float[Array, "2 n_enzyme"] log_drain: Float[Array, "2 n_drain"] dgf: Float[Array, "2 n_metabolite"] log_km: Float[Array, "2 n_km"] @@ -47,23 +46,34 @@ class PriorSet: log_conc_unbalanced: Float[Array, "2 n_unbalanced"] temperature: Float[Array, "2"] log_transfer_constant: Float[Array, "2 n_allosteric_enzyme"] - log_dissociation_constant: Float[Array, "2 n_allosteric_effector"] + log_dissociation_constant: Float[Array, "2 n_allosteric_effect"] + +class AdaptationKwargs(TypedDict): + """Keyword arguments to the blackjax function window_adaptation.""" -def ind_normal_prior_logdensity(param, prior): + initial_step_size: float + max_num_doublings: int + is_mass_matrix_diagonal: bool + target_acceptance_rate: float + + +def ind_normal_prior_logdensity(param, prior: Float[Array, "2 _"]): + """Total log density for an independent normal distribution.""" return norm.logpdf(param, loc=prior[0], scale=prior[1]).sum() -@eqx.filter_jit -def posterior_logdensity_fn( - parameters: KineticModelParameters, - unparameterised_model: UnparameterisedKineticModel, +def posterior_logdensity_amm( + parameters: AllostericMichaelisMentenParameterSet, + structure: KineticModelStructure, + rate_equations: list[RateEquation], obs: ObservationSet, - prior: PriorSet, + prior: AllostericMichaelisMentenPriorSet, guess: Float[Array, " n_balanced"], ): - model = KineticModel(parameters, unparameterised_model) - steady = solve(parameters, unparameterised_model, guess) + """Get the log density for an allosteric Michaelis-Menten model.""" + model = RateEquationModel(parameters, structure, rate_equations) + steady = get_kinetic_model_steady_state(model, guess) flux = model.flux(steady) conc = jnp.zeros(model.structure.S.shape[0]) conc = conc.at[model.structure.balanced_species].set(steady) @@ -100,149 +110,75 @@ def posterior_logdensity_fn( @functools.partial(jax.jit, static_argnames=["kernel", "num_samples"]) -def inference_loop(rng_key, kernel, initial_state, num_samples): +def _inference_loop(rng_key, kernel, initial_state, num_samples): + """Run MCMC with blackjax.""" + def one_step(state, rng_key): state, info = kernel(rng_key, state) return state, (state, info) keys = jax.random.split(rng_key, num_samples) _, (states, info) = jax.lax.scan(one_step, initial_state, keys) - return states, info -def warn_if_divergent(info): - if jnp.any(info.is_divergent): - warnings.warn("I found a divergent transition!") - - -@eqx.filter_jit -def sample(logdensity_fn, rng_key, init_parameters): +def run_nuts( + logdensity_fn: Callable, + rng_key: KeyArray, + init_parameters: PyTree, + num_warmup: int, + num_samples: int, + **adapt_kwargs: Unpack[AdaptationKwargs], +): + """Run the default NUTS algorithm with blackjax.""" warmup = blackjax.window_adaptation( blackjax.nuts, logdensity_fn, progress_bar=True, - initial_step_size=0.0001, - max_num_doublings=10, - is_mass_matrix_diagonal=False, - target_acceptance_rate=0.95, + **adapt_kwargs, ) rng_key, warmup_key = jax.random.split(rng_key) (initial_state, tuned_parameters), (_, info, _) = warmup.run( warmup_key, init_parameters, - num_steps=1000, # type: ignore + num_steps=num_warmup, #  type: ignore ) rng_key, sample_key = jax.random.split(rng_key) nuts_kernel = blackjax.nuts(logdensity_fn, **tuned_parameters).step - states, info = inference_loop( + states, info = _inference_loop( sample_key, kernel=nuts_kernel, initial_state=initial_state, - num_samples=200, + num_samples=num_samples, ) return states, info -def ind_prior_from_truth(truth, sd): +def ind_prior_from_truth(truth: Float[Array, " _"], sd: ScalarLike): + """Get a set of independent priors centered at the true parameter values. + + Note that the standard deviation currently has to be the same for + all parameters. + + """ return jnp.vstack((truth, jnp.full(truth.shape, sd))) -def get_idata(samples, info): +def get_idata(samples, info, coords=None, dims=None) -> az.InferenceData: + """Get an arviz InferenceData from a blackjax NUTS output.""" sample_dict = { k: jnp.expand_dims(getattr(samples.position, k), 0) for k in samples.position.__dataclass_fields__.keys() } - posterior = az.convert_to_inference_data(sample_dict, group="posterior") + posterior = az.convert_to_inference_data( + sample_dict, + group="posterior", + coords=coords, + dims=dims, + ) sample_stats = az.convert_to_inference_data( {"diverging": info.is_divergent}, group="sample_stats" ) - return az.concat(posterior, sample_stats) - - -def main(): - """Demonstrate the functionality of the mcmc module.""" - true_parameters = methionine.parameters - unparameterised_model = methionine.unparameterised_model - true_model = methionine.model - default_state_guess = jnp.full((5,), 0.01) - true_states = solve( - true_parameters, unparameterised_model, default_state_guess - ) - prior = PriorSet( - log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1), - log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1), - log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1), - dgf=ind_prior_from_truth(true_parameters.dgf, 0.1), - log_km=ind_prior_from_truth(true_parameters.log_km, 0.1), - log_conc_unbalanced=ind_prior_from_truth( - true_parameters.log_conc_unbalanced, 0.1 - ), - temperature=ind_prior_from_truth(true_parameters.temperature, 0.1), - log_ki=ind_prior_from_truth(true_parameters.log_ki, 0.1), - log_transfer_constant=ind_prior_from_truth( - true_parameters.log_transfer_constant, 0.1 - ), - log_dissociation_constant=ind_prior_from_truth( - true_parameters.log_dissociation_constant, 0.1 - ), - ) - # get true concentration - true_conc = jnp.zeros(methionine.structure.S.shape[0]) - true_conc = true_conc.at[methionine.structure.balanced_species].set( - true_states - ) - true_conc = true_conc.at[methionine.structure.unbalanced_species].set( - jnp.exp(true_parameters.log_conc_unbalanced) - ) - # get true flux - true_flux = true_model.flux(true_states) - # simulate observations - error_conc = 0.03 - error_flux = 0.05 - error_enzyme = 0.03 - key = jax.random.key(SEED) - obs_conc = jnp.exp(jnp.log(true_conc) + jax.random.normal(key) * error_conc) - obs_enzyme = jnp.exp( - true_parameters.log_enzyme + jax.random.normal(key) * error_enzyme - ) - obs_flux = true_flux + jax.random.normal(key) * error_conc - obs = ObservationSet( - conc=obs_conc, - flux=obs_flux, - enzyme=obs_enzyme, - conc_scale=error_conc, - flux_scale=error_flux, - enzyme_scale=error_enzyme, - ) - log_M = functools.partial( - posterior_logdensity_fn, - obs=obs, - prior=prior, - unparameterised_model=unparameterised_model, - guess=default_state_guess, - ) - samples, info = sample(log_M, key, true_parameters) - idata = get_idata(samples, info) - print(az.summary(idata)) - if jnp.any(info.is_divergentl): - n_divergent = info.is_divergent.sum() - msg = f"There were {n_divergent} post-warmup divergent transitions." - warnings.warn(msg) - else: - logging.info("No post-warmup divergent transitions!") - print("True parameter values vs posterior:") - for param in true_parameters.__dataclass_fields__.keys(): - true_val = getattr(true_parameters, param) - model_low = jnp.quantile(getattr(samples.position, param), 0.01, axis=0) - model_high = jnp.quantile( - getattr(samples.position, param), 0.99, axis=0 - ) - print(f" {param}:") - print(f" true value: {true_val}") - print(f" posterior 1%: {model_low}") - print(f" posterior 99%: {model_high}") - - -if __name__ == "__main__": - main() + idata = az.concat(posterior, sample_stats) + assert idata is not None + return idata diff --git a/src/enzax/parameters.py b/src/enzax/parameters.py new file mode 100644 index 0000000..6534169 --- /dev/null +++ b/src/enzax/parameters.py @@ -0,0 +1,58 @@ +"""Module with parameters and parameter sets. + +These are not required, but they do provide handy shape checking, for example to ensure that your parameter set has the same number of enzymes and kcat parameters. + +""" # Noqa: E501 + +import equinox as eqx +from jaxtyping import Array, Float, Scalar, jaxtyped +from typeguard import typechecked + +LogKcat = Float[Array, " n_enzyme"] +LogEnzyme = Float[Array, " n_enzyme"] +Dgf = Float[Array, " n_metabolite"] +LogKm = Float[Array, " n_km"] +LogKi = Float[Array, " n_ki"] +LogConcUnbalanced = Float[Array, " n_unbalanced"] +LogDrain = Float[Array, " n_drain"] +LogTransferConstant = Float[Array, " n_allosteric_enzyme"] +LogDissociationConstant = Float[Array, " n_allosteric_effect"] + + +@jaxtyped(typechecker=typechecked) +class MichaelisMentenParameterSet(eqx.Module): + """Parameters for a model with Michaelis Menten kinetics. + + This kind of parameter set supports models with the following rate laws: + + - enzax.rate_equations.drain.Drain + - enzax.rate_equations.michaelis_menten.IrreversibleMichaelisMenten + - enzax.rate_equations.michaelis_menten.ReversibleMichaelisMenten + + """ + + log_kcat: LogKcat + log_enzyme: LogEnzyme + dgf: Dgf + log_km: LogKm + log_ki: LogKi + log_conc_unbalanced: LogConcUnbalanced + temperature: Scalar + log_drain: LogDrain + + +class AllostericMichaelisMentenParameterSet(MichaelisMentenParameterSet): + """Parameters for a model with Michaelis Menten kinetics, with allostery. + + Reactions can be any out of: + + - drain.Drain + - michaelis_menten.IrreversibleMichaelisMenten + - michaelis_menten.ReversibleMichaelisMenten + - generalised_mwc.AllostericIrreversibleMichaelisMenten + - generalised_mwc.AllostericReversibleMichaelisMenten + + """ + + log_transfer_constant: LogTransferConstant + log_dissociation_constant: LogDissociationConstant diff --git a/src/enzax/rate_equation.py b/src/enzax/rate_equation.py new file mode 100644 index 0000000..aa89443 --- /dev/null +++ b/src/enzax/rate_equation.py @@ -0,0 +1,19 @@ +"""Module containing rate equations for enzyme-catalysed reactions.""" + +from abc import ABC, abstractmethod +from equinox import Module + +from jaxtyping import Array, Float, PyTree, Scalar + + +ConcArray = Float[Array, " n"] + + +class RateEquation(Module, ABC): + """Abstract definition of a rate equation. + + A rate equation is an equinox [Module](https://docs.kidger.site/equinox/api/module/module/) with a `__call__` method that takes in a 1 dimensional array of concentrations and an arbitrary PyTree of parameters, returning a scalar value representing a single flux. + """ # Noqa: E501 + + @abstractmethod + def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: ... diff --git a/src/enzax/rate_equations.py b/src/enzax/rate_equations.py deleted file mode 100644 index af8e346..0000000 --- a/src/enzax/rate_equations.py +++ /dev/null @@ -1,355 +0,0 @@ -"""Module containing rate equations for enzyme-catalysed reactions.""" - -from abc import ABC, abstractmethod -from equinox import Module - -from jax import numpy as jnp -from jaxtyping import Array, Float, Int, PyTree, Scalar - - -ConcArray = Float[Array, " n"] - - -class RateEquation(Module, ABC): - """Abstract definition of a rate equation. - - A rate equation is an equinox [Module](https://docs.kidger.site/equinox/api/module/module/) with a `__call__` method that takes in a 1 dimensional array of concentrations and an arbitrary PyTree of parameters, returning a scalar value representing a single flux. - """ - - @abstractmethod - def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: ... - - -def get_drain_flux(sign: Scalar, log_v: Scalar) -> Scalar: - """Get the flux of a drain reaction. - - :param sign: a scalar value (should be either one or zero) representing the direction of the reaction. - - :param log_v: a scalar representing the magnitude of the reaction, on log scale. - - """ - return sign * jnp.exp(log_v) - - -class Drain(RateEquation): - """A drain reaction.""" - - sign: Scalar - drain_ix: int - - def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: - """Get the flux of a drain reaction.""" - log_v = parameters.log_drain[self.drain_ix] - return get_drain_flux(self.sign, log_v) - - -def numerator_mm( - conc: ConcArray, - km: Float[Array, " n"], - ix_substrate: Int[Array, " n_substrate"], - substrate_km_positions: Int[Array, " n_substrate"], -) -> Scalar: - """Get the product of each substrate's concentration over its km. - This quantity is the numerator in a Michaelis Menten reaction's rate equation - """ - return jnp.prod((conc[ix_substrate] / km[substrate_km_positions])) - - -class MichaelisMenten(RateEquation): - """Base class for Michaelis Menten rate equations. - - Subclasses need to implement the __call__ method. - - """ - - kcat_ix: int - enzyme_ix: int - km_ix: Int[Array, " n"] - ki_ix: Int[Array, " n_ki"] - reactant_stoichiometry: Float[Array, " n"] - ix_substrate: Int[Array, " n_substrate"] - ix_ki_species: Int[Array, " n_ki"] - substrate_km_positions: Int[Array, " n_substrate"] - substrate_reactant_positions: Int[Array, " n_substrate"] - - def get_kcat(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_kcat[self.kcat_ix]) - - def get_km(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_km[self.km_ix]) - - def get_ki(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_ki[self.ki_ix]) - - def get_enzyme(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_enzyme[self.enzyme_ix]) - - -def free_enzyme_ratio_imm( - conc: ConcArray, - km: Float[Array, " n"], - ki: Float[Array, " n_ki"], - ix_substrate: Int[Array, " n_substrate"], - substrate_km_positions: Int[Array, " n_substrate"], - substrate_reactant_positions: Int[Array, " n_substrate"], - ix_ki_species: Int[Array, " n_ki"], - reactant_stoichiometry: Float[Array, " n"], -) -> Scalar: - """Free enzyme ratio for irreversible Michaelis Menten reactions.""" - return 1.0 / ( - jnp.prod( - ((conc[ix_substrate] / km[substrate_km_positions]) + 1) - ** jnp.abs(reactant_stoichiometry[substrate_reactant_positions]) - ) - + jnp.sum(conc[ix_ki_species] / ki) - ) - - -class IrreversibleMichaelisMenten(MichaelisMenten): - """A reaction with irreversible Michaelis Menten kinetics.""" - - def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar: - """Get flux of a reaction with irreversible Michaelis Menten kinetics.""" - kcat = self.get_kcat(parameters) - enzyme = self.get_enzyme(parameters) - km = self.get_km(parameters) - ki = self.get_ki(parameters) - numerator = numerator_mm( - conc=conc, - km=km, - ix_substrate=self.ix_substrate, - substrate_km_positions=self.substrate_km_positions, - ) - free_enzyme_ratio = free_enzyme_ratio_imm( - conc=conc, - km=km, - ki=ki, - ix_substrate=self.ix_substrate, - substrate_km_positions=self.substrate_km_positions, - substrate_reactant_positions=self.substrate_reactant_positions, - ix_ki_species=self.ix_ki_species, - reactant_stoichiometry=self.reactant_stoichiometry, - ) - return kcat * enzyme * numerator * free_enzyme_ratio - - -def get_reversibility( - conc: Float[Array, " n"], - water_stoichiometry: Scalar, - dgf: Float[Array, " n_reactant"], - temperature: Scalar, - reactant_stoichiometry: Float[Array, " n_reactant"], - ix_reactants: Int[Array, " n_reactant"], -) -> Scalar: - """Get the reversibility of a reaction. - - Hard coded water dgf is taken from . - - """ - RT = temperature * 0.008314 - dgf_water = -150.9 - dgr = reactant_stoichiometry @ dgf + water_stoichiometry * dgf_water - quotient = reactant_stoichiometry @ jnp.log(conc[ix_reactants]) - out = 1.0 - jnp.exp(((dgr + RT * quotient) / RT)) - return out - - -def free_enzyme_ratio_rmm( - conc: ConcArray, - km: Float[Array, " n_reactant"], - ki: Float[Array, " n_ki"], - reactant_stoichiometry: Float[Array, " n_reactant"], - ix_substrate: Int[Array, " n_substrate"], - ix_product: Int[Array, " n_product"], - substrate_km_positions: Int[Array, " n_substrate"], - product_km_positions: Int[Array, " n_product"], - substrate_reactant_positions: Int[Array, " n_substrate"], - product_reactant_positions: Int[Array, " n_product"], - ix_ki_species: Int[Array, " n_ki"], -) -> Scalar: - """The free enzyme ratio for a reversible Michaelis Menten reaction.""" - return 1.0 / ( - -1.0 - + jnp.prod( - ((conc[ix_substrate] / km[substrate_km_positions]) + 1.0) - ** jnp.abs(reactant_stoichiometry[substrate_reactant_positions]) - ) - + jnp.prod( - ((conc[ix_product] / km[product_km_positions]) + 1.0) - ** jnp.abs(reactant_stoichiometry[product_reactant_positions]) - ) - + jnp.sum(conc[ix_ki_species] / ki) - ) - - -class ReversibleMichaelisMenten(MichaelisMenten): - """A reaction with reversible Michaelis Menten kinetics.""" - - ix_product: Int[Array, " n_product"] - ix_reactants: Int[Array, " n_reactant"] - product_reactant_positions: Int[Array, " n_product"] - product_km_positions: Int[Array, " n_product"] - water_stoichiometry: Scalar - reactant_to_dgf: Int[Array, " n_reactant"] - - def _get_dgf(self, parameters: PyTree): - return parameters.dgf[self.reactant_to_dgf] - - def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar: - """Get flux of a reaction with reversible Michaelis Menten kinetics. - - :param conc: A 1D array of non-negative numbers representing concentrations of the species that the reaction produces and consumes. - - """ - kcat = self.get_kcat(parameters) - enzyme = self.get_enzyme(parameters) - km = self.get_km(parameters) - ki = self.get_ki(parameters) - reversibility = get_reversibility( - conc=conc, - water_stoichiometry=self.water_stoichiometry, - dgf=self._get_dgf(parameters), - temperature=parameters.temperature, - reactant_stoichiometry=self.reactant_stoichiometry, - ix_reactants=self.ix_reactants, - ) - numerator = numerator_mm( - conc=conc, - km=km, - ix_substrate=self.ix_substrate, - substrate_km_positions=self.substrate_km_positions, - ) - free_enzyme_ratio = free_enzyme_ratio_rmm( - conc=conc, - km=km, - ki=ki, - reactant_stoichiometry=self.reactant_stoichiometry, - ix_substrate=self.ix_substrate, - ix_product=self.ix_product, - substrate_km_positions=self.substrate_km_positions, - product_km_positions=self.product_km_positions, - substrate_reactant_positions=self.substrate_reactant_positions, - product_reactant_positions=self.product_reactant_positions, - ix_ki_species=self.ix_ki_species, - ) - return reversibility * kcat * enzyme * numerator * free_enzyme_ratio - - -def get_allosteric_effect( - conc: Float[Array, " n_reactant"], - free_enzyme_ratio: Scalar, - tc: Scalar, - dc_inhibition: Float[Array, " n_inhibition"], - dc_activation: Float[Array, " n_activation"], - species_inhibition: Int[Array, " n_inhibition"], - species_activation: Int[Array, " n_activation"], - subunits: int, -) -> Scalar: - """Get the allosteric effect on a rate. - - The equation is generalised Monod Wyman Changeux model as presented in Popova and Sel'kov 1975: https://doi.org/10.1016/0014-5793(75)80034-2. - - """ - qnum = 1 + jnp.sum(conc[species_inhibition] / dc_inhibition) - qdenom = 1 + jnp.sum(conc[species_activation] / dc_activation) - out = 1.0 / (1 + tc * (free_enzyme_ratio * qnum / qdenom) ** subunits) - return out - - -class AllostericRateLaw(MichaelisMenten): - """Mixin class for allosteric rate laws.""" - - subunits: int - tc_ix: int - ix_dc_activation: Int[Array, " n_activation"] - ix_dc_inhibition: Int[Array, " n_inhibition"] - species_activation: Int[Array, " n_activation"] - species_inhibition: Int[Array, " n_inhibition"] - - def get_tc(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_transfer_constant[self.tc_ix]) - - def get_dc_activation(self, parameters: PyTree) -> Scalar: - return jnp.exp( - parameters.log_dissociation_constant[self.ix_dc_activation] - ) - - def get_dc_inhibition(self, parameters: PyTree) -> Scalar: - return jnp.exp( - parameters.log_dissociation_constant[self.ix_dc_inhibition] - ) - - -class AllostericIrreversibleMichaelisMenten( - AllostericRateLaw, IrreversibleMichaelisMenten -): - """A reaction with irreversible Michaelis Menten kinetics and allostery.""" - - def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar: - """The flux of an allosteric irreversible Michaelis Menten reaction.""" - km = self.get_km(parameters) - ki = self.get_ki(parameters) - tc = self.get_tc(parameters) - dc_activation = self.get_dc_activation(parameters) - dc_inhibition = self.get_dc_inhibition(parameters) - free_enzyme_ratio = free_enzyme_ratio_imm( - conc=conc, - km=km, - ki=ki, - ix_substrate=self.ix_substrate, - substrate_km_positions=self.substrate_km_positions, - substrate_reactant_positions=self.substrate_reactant_positions, - ix_ki_species=self.ix_ki_species, - reactant_stoichiometry=self.reactant_stoichiometry, - ) - allosteric_effect = get_allosteric_effect( - conc=conc, - free_enzyme_ratio=free_enzyme_ratio, - tc=tc, - dc_inhibition=dc_inhibition, - dc_activation=dc_activation, - species_inhibition=self.species_inhibition, - species_activation=self.species_activation, - subunits=self.subunits, - ) - non_allosteric_rate = super().__call__(conc, parameters) - return non_allosteric_rate * allosteric_effect - - -class AllostericReversibleMichaelisMenten( - AllostericRateLaw, ReversibleMichaelisMenten -): - """A reaction with reversible Michaelis Menten kinetics and allostery.""" - - def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: - """The flux of an allosteric reversible Michaelis Menten reaction.""" - km = self.get_km(parameters) - ki = self.get_ki(parameters) - tc = self.get_tc(parameters) - dc_activation = self.get_dc_activation(parameters) - dc_inhibition = self.get_dc_inhibition(parameters) - free_enzyme_ratio = free_enzyme_ratio_rmm( - conc=conc, - km=km, - ki=ki, - reactant_stoichiometry=self.reactant_stoichiometry, - ix_substrate=self.ix_substrate, - ix_product=self.ix_product, - substrate_km_positions=self.substrate_km_positions, - product_km_positions=self.product_km_positions, - substrate_reactant_positions=self.substrate_reactant_positions, - product_reactant_positions=self.product_reactant_positions, - ix_ki_species=self.ix_ki_species, - ) - allosteric_effect = get_allosteric_effect( - conc=conc, - free_enzyme_ratio=free_enzyme_ratio, - tc=tc, - dc_inhibition=dc_inhibition, - dc_activation=dc_activation, - species_inhibition=self.species_inhibition, - species_activation=self.species_activation, - subunits=self.subunits, - ) - non_allosteric_rate = super().__call__(conc, parameters) - return non_allosteric_rate * allosteric_effect diff --git a/src/enzax/rate_equations/__init__.py b/src/enzax/rate_equations/__init__.py new file mode 100644 index 0000000..a0da47d --- /dev/null +++ b/src/enzax/rate_equations/__init__.py @@ -0,0 +1,19 @@ +from enzax.rate_equations.michaelis_menten import ( + ReversibleMichaelisMenten, + IrreversibleMichaelisMenten, + MichaelisMenten, +) +from enzax.rate_equations.generalised_mwc import ( + AllostericReversibleMichaelisMenten, + AllostericIrreversibleMichaelisMenten, +) +from enzax.rate_equations.drain import Drain + +AVAILABLE_RATE_EQUATIONS = [ + ReversibleMichaelisMenten, + IrreversibleMichaelisMenten, + MichaelisMenten, + AllostericReversibleMichaelisMenten, + AllostericIrreversibleMichaelisMenten, + Drain, +] diff --git a/src/enzax/rate_equations/drain.py b/src/enzax/rate_equations/drain.py new file mode 100644 index 0000000..0e1994c --- /dev/null +++ b/src/enzax/rate_equations/drain.py @@ -0,0 +1,27 @@ +from jax import numpy as jnp +from jaxtyping import PyTree, Scalar + +from enzax.rate_equation import ConcArray, RateEquation + + +def get_drain_flux(sign: Scalar, log_v: Scalar) -> Scalar: + """Get the flux of a drain reaction. + + :param sign: a scalar value (should be either one or zero) representing the direction of the reaction. + + :param log_v: a scalar representing the magnitude of the reaction, on log scale. + + """ # Noqa: E501 + return sign * jnp.exp(log_v) + + +class Drain(RateEquation): + """A drain reaction.""" + + sign: Scalar + drain_ix: int + + def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: + """Get the flux of a drain reaction.""" + log_v = parameters.log_drain[self.drain_ix] + return get_drain_flux(self.sign, log_v) diff --git a/src/enzax/rate_equations/generalised_mwc.py b/src/enzax/rate_equations/generalised_mwc.py new file mode 100644 index 0000000..90f867e --- /dev/null +++ b/src/enzax/rate_equations/generalised_mwc.py @@ -0,0 +1,94 @@ +from jax import numpy as jnp +from jaxtyping import Array, Float, Int, PyTree, Scalar + +from enzax.rate_equation import ConcArray +from enzax.rate_equations.michaelis_menten import ( + IrreversibleMichaelisMenten, + MichaelisMenten, + ReversibleMichaelisMenten, +) + + +def generalised_mwc_effect( + conc_inhibitor: Float[Array, " n_inhibition"], + dc_inhibitor: Float[Array, " n_inhibition"], + conc_activator: Float[Array, " n_activation"], + dc_activator: Float[Array, " n_activation"], + free_enzyme_ratio: Scalar, + tc: Scalar, + subunits: int, +) -> Scalar: + """Get the allosteric effect on a rate. + + The equation is generalised Monod Wyman Changeux model as presented in Popova and Sel'kov 1975: https://doi.org/10.1016/0014-5793(75)80034-2. + + """ # noqa: E501 + qnum = 1 + jnp.sum(conc_inhibitor / dc_inhibitor) + qdenom = 1 + jnp.sum(conc_activator / dc_activator) + out = 1.0 / (1 + tc * (free_enzyme_ratio * qnum / qdenom) ** subunits) + return out + + +class GeneralisedMWC(MichaelisMenten): + """Mixin class for allosteric rate laws, assuming generalised MWC kinetics. + + See Popova and Sel'kov 1975 for the rate law: https://doi.org/10.1016/0014-5793(75)80034-2. + + Note that it is assumed there is a free_enzyme_ratio method available - that is why this is a subclass of MichaelisMenten rather than RateEquation. + """ # noqa: E501 + + subunits: int + tc_ix: int + ix_dc_activation: Int[Array, " n_activation"] + ix_dc_inhibition: Int[Array, " n_inhibition"] + species_activation: Int[Array, " n_activation"] + species_inhibition: Int[Array, " n_inhibition"] + + def get_tc(self, parameters: PyTree) -> Scalar: + return jnp.exp(parameters.log_transfer_constant[self.tc_ix]) + + def get_dc_activation(self, parameters: PyTree) -> Scalar: + return jnp.exp( + parameters.log_dissociation_constant[self.ix_dc_activation], + ) + + def get_dc_inhibition(self, parameters: PyTree) -> Scalar: + return jnp.exp( + parameters.log_dissociation_constant[self.ix_dc_inhibition], + ) + + def allosteric_effect(self, conc: ConcArray, parameters: PyTree) -> Scalar: + return generalised_mwc_effect( + conc_inhibitor=conc[self.species_inhibition], + conc_activator=conc[self.species_activation], + free_enzyme_ratio=self.free_enzyme_ratio(conc, parameters), + tc=self.get_tc(parameters), + dc_inhibitor=self.get_dc_inhibition(parameters), + dc_activator=self.get_dc_activation(parameters), + subunits=self.subunits, + ) + + +class AllostericIrreversibleMichaelisMenten( + GeneralisedMWC, IrreversibleMichaelisMenten +): + """A reaction with irreversible Michaelis Menten kinetics and allostery.""" + + def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar: + """The flux of an irreversible allosteric Michaelis Menten reaction.""" + allosteric_effect = self.allosteric_effect(conc, parameters) + non_allosteric_rate = super().__call__(conc, parameters) + return non_allosteric_rate * allosteric_effect + + +class AllostericReversibleMichaelisMenten( + GeneralisedMWC, + ReversibleMichaelisMenten, +): + """A reaction with reversible Michaelis Menten kinetics and allostery.""" + + def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: + """The flux of an allosteric reversible Michaelis Menten reaction.""" + allosteric_effect = self.allosteric_effect(conc, parameters) + non_allosteric_rate = super().__call__(conc, parameters) + return non_allosteric_rate * allosteric_effect diff --git a/src/enzax/rate_equations/michaelis_menten.py b/src/enzax/rate_equations/michaelis_menten.py new file mode 100644 index 0000000..55cf1d0 --- /dev/null +++ b/src/enzax/rate_equations/michaelis_menten.py @@ -0,0 +1,193 @@ +from abc import abstractmethod +from jax import numpy as jnp +from jaxtyping import Array, Float, Int, PyTree, Scalar + +from enzax.rate_equation import RateEquation, ConcArray + + +def numerator_mm( + conc: ConcArray, + km: Float[Array, " n"], + ix_substrate: Int[Array, " n_substrate"], + substrate_km_positions: Int[Array, " n_substrate"], +) -> Scalar: + """Get the product of each substrate's concentration over its km. + + This quantity is the numerator in a Michaelis Menten reaction's rate equation + """ # Noqa: E501 + return jnp.prod((conc[ix_substrate] / km[substrate_km_positions])) + + +class MichaelisMenten(RateEquation): + """Base class for Michaelis Menten rate equations. + + Subclasses need to implement the __call__ and free_enzyme_ratio methods. + + """ + + kcat_ix: int + enzyme_ix: int + km_ix: Int[Array, " n"] + ki_ix: Int[Array, " n_ki"] + reactant_stoichiometry: Float[Array, " n"] + ix_substrate: Int[Array, " n_substrate"] + ix_ki_species: Int[Array, " n_ki"] + substrate_km_positions: Int[Array, " n_substrate"] + substrate_reactant_positions: Int[Array, " n_substrate"] + + def get_kcat(self, parameters: PyTree) -> Scalar: + return jnp.exp(parameters.log_kcat[self.kcat_ix]) + + def get_km(self, parameters: PyTree) -> Scalar: + return jnp.exp(parameters.log_km[self.km_ix]) + + def get_ki(self, parameters: PyTree) -> Scalar: + return jnp.exp(parameters.log_ki[self.ki_ix]) + + def get_enzyme(self, parameters: PyTree) -> Scalar: + return jnp.exp(parameters.log_enzyme[self.enzyme_ix]) + + @abstractmethod + def free_enzyme_ratio( + self, + conc: ConcArray, + parameters: PyTree, + ) -> Scalar: ... + + +def free_enzyme_ratio_imm( + conc_sub: Float[Array, " n_substrate"], + km_sub: Float[Array, " n_substrate"], + stoich_sub: Float[Array, " n_substrate"], + ki: Float[Array, " n_ki"], + conc_inhibitor: Float[Array, " n_ki"], +) -> Scalar: + """Free enzyme ratio for irreversible Michaelis Menten reactions.""" + return 1.0 / ( + jnp.prod(((conc_sub / km_sub) + 1) ** jnp.abs(stoich_sub)) + + jnp.sum(conc_inhibitor / ki) + ) + + +class IrreversibleMichaelisMenten(MichaelisMenten): + """A reaction with irreversible Michaelis Menten kinetics.""" + + def free_enzyme_ratio(self, conc, parameters): + return free_enzyme_ratio_imm( + conc_sub=conc[self.ix_substrate], + km_sub=self.get_km(parameters)[self.substrate_km_positions], + ki=self.get_ki(parameters), + conc_inhibitor=conc[self.ix_ki_species], + stoich_sub=self.reactant_stoichiometry[ + self.substrate_reactant_positions + ], + ) + + def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar: + """Get flux of a reaction with irreversible Michaelis Menten kinetics.""" # noqa: E501 + kcat = self.get_kcat(parameters) + enzyme = self.get_enzyme(parameters) + km = self.get_km(parameters) + numerator = numerator_mm( + conc=conc, + km=km, + ix_substrate=self.ix_substrate, + substrate_km_positions=self.substrate_km_positions, + ) + free_enzyme_ratio = self.free_enzyme_ratio(conc, parameters) + return kcat * enzyme * numerator * free_enzyme_ratio + + +def get_reversibility( + conc: Float[Array, " n"], + water_stoichiometry: Scalar, + dgf: Float[Array, " n_reactant"], + temperature: Scalar, + reactant_stoichiometry: Float[Array, " n_reactant"], + ix_reactants: Int[Array, " n_reactant"], +) -> Scalar: + """Get the reversibility of a reaction. + + Hard coded water dgf is taken from . + + """ + RT = temperature * 0.008314 + dgf_water = -150.9 + dgr = reactant_stoichiometry @ dgf + water_stoichiometry * dgf_water + quotient = reactant_stoichiometry @ jnp.log(conc[ix_reactants]) + out = 1.0 - jnp.exp(((dgr + RT * quotient) / RT)) + return out + + +def free_enzyme_ratio_rmm( + conc_sub: Float[Array, " n_substrate"], + km_sub: Float[Array, " n_substrate"], + stoich_sub: Float[Array, " n_substrate"], + conc_prod: Float[Array, " n_product"], + km_prod: Float[Array, " n_prod"], + stoich_prod: Float[Array, " n_prod"], + conc_inhibitor: Float[Array, " n_ki"], + ki: Float[Array, " n_ki"], +) -> Scalar: + """The free enzyme ratio for a reversible Michaelis Menten reaction.""" + return 1.0 / ( + -1.0 + + jnp.prod(((conc_sub / km_sub) + 1.0) ** jnp.abs(stoich_sub)) + + jnp.prod(((conc_prod / km_prod) + 1.0) ** jnp.abs(stoich_prod)) + + jnp.sum(conc_inhibitor / ki) + ) + + +class ReversibleMichaelisMenten(MichaelisMenten): + """A reaction with reversible Michaelis Menten kinetics.""" + + ix_product: Int[Array, " n_product"] + ix_reactants: Int[Array, " n_reactant"] + product_reactant_positions: Int[Array, " n_product"] + product_km_positions: Int[Array, " n_product"] + water_stoichiometry: Scalar + reactant_to_dgf: Int[Array, " n_reactant"] + + def _get_dgf(self, parameters: PyTree): + return parameters.dgf[self.reactant_to_dgf] + + def free_enzyme_ratio(self, conc: ConcArray, parameters: PyTree) -> Scalar: + return free_enzyme_ratio_rmm( + conc_sub=conc[self.ix_substrate], + km_sub=self.get_km(parameters)[self.substrate_reactant_positions], + stoich_sub=self.reactant_stoichiometry[ + self.substrate_reactant_positions + ], + conc_prod=conc[self.ix_product], + km_prod=self.get_km(parameters)[self.product_reactant_positions], + stoich_prod=self.reactant_stoichiometry[ + self.product_reactant_positions + ], + conc_inhibitor=conc[self.ix_ki_species], + ki=self.get_ki(parameters), + ) + + def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar: + """Get flux of a reaction with reversible Michaelis Menten kinetics. + + :param conc: A 1D array of non-negative numbers representing concentrations of the species that the reaction produces and consumes. + + """ # noqa: E501 + kcat = self.get_kcat(parameters) + enzyme = self.get_enzyme(parameters) + reversibility = get_reversibility( + conc=conc, + water_stoichiometry=self.water_stoichiometry, + dgf=self._get_dgf(parameters), + temperature=parameters.temperature, + reactant_stoichiometry=self.reactant_stoichiometry, + ix_reactants=self.ix_reactants, + ) + numerator = numerator_mm( + conc=conc, + km=self.get_km(parameters), + ix_substrate=self.ix_substrate, + substrate_km_positions=self.substrate_km_positions, + ) + free_enzyme_ratio = self.free_enzyme_ratio(conc, parameters) + return reversibility * kcat * enzyme * numerator * free_enzyme_ratio diff --git a/src/enzax/steady_state.py b/src/enzax/steady_state.py new file mode 100644 index 0000000..a89f745 --- /dev/null +++ b/src/enzax/steady_state.py @@ -0,0 +1,63 @@ +"""Module for solving steady state problems. + +Given a structural kinetic model, a set of parameters and an initial guess, the aim is to find the physiological steady state metabolite concentration and its parameter sensitivities. + +""" # noqa: E501 + +import diffrax +import equinox as eqx +import lineax as lx +from jaxtyping import Array, Float, PyTree + +from enzax.kinetic_model import KineticModel + + +@eqx.filter_jit() +def get_kinetic_model_steady_state( + model: KineticModel, + guess: Float[Array, " n_balanced"], +) -> PyTree: + """Get the steady state of a kinetic model, using diffrax. + + The better the guess (generally) the faster and more reliable the solving. + + :param model: a KineticModel object + + :param guess: a JAX array of floats. Must have the same length as the + model's number of balanced species. + + """ + term = diffrax.ODETerm(model.dcdt) + solver = diffrax.Kvaerno5() + t0 = 0 + t1 = 900 + dt0 = 0.000001 + max_steps = None + controller = diffrax.PIDController( + pcoeff=0.1, + icoeff=0.3, + rtol=1e-11, + atol=1e-11, + ) + cond_fn = diffrax.steady_state_event() + event = diffrax.Event(cond_fn) + adjoint = diffrax.ImplicitAdjoint( + linear_solver=lx.AutoLinearSolver(well_posed=False) + ) + sol = diffrax.diffeqsolve( + term, + solver, + t0, + t1, + dt0, + guess, + args=model, + max_steps=max_steps, + stepsize_controller=controller, + event=event, + adjoint=adjoint, + ) + if sol.ys is not None: + return sol.ys[0] + else: + raise ValueError("No steady state found!") diff --git a/src/enzax/steady_state_problem.py b/src/enzax/steady_state_problem.py deleted file mode 100644 index 340e2df..0000000 --- a/src/enzax/steady_state_problem.py +++ /dev/null @@ -1,103 +0,0 @@ -"""Given a structural kinetic model, a set of parameters and an initial guess, find the physiological steady state metabolite concentration and its parameter sensitivities.""" - -import time - - -import diffrax -import equinox as eqx -import jax -import jax.numpy as jnp -import lineax as lx -from jaxtyping import Array, Float -from enzax.examples import methionine -from enzax.kinetic_model import ( - KineticModel, - KineticModelParameters, - UnparameterisedKineticModel, -) - -jax.config.update("jax_debug_nans", True) -jax.config.update("jax_enable_x64", True) - - -@eqx.filter_jit() -def solve( - parameters: KineticModelParameters, - unparameterised_model: UnparameterisedKineticModel, - guess: Float[Array, " n"], -): - model = KineticModel(parameters, unparameterised_model) - term = diffrax.ODETerm(model.dcdt) - solver = diffrax.Kvaerno5() - t0 = 0 - t1 = 900 - dt0 = 0.000001 - max_steps = None - controller = diffrax.PIDController( - pcoeff=0.1, icoeff=0.3, rtol=1e-11, atol=1e-11 - ) - cond_fn = diffrax.steady_state_event() - event = diffrax.Event(cond_fn) - adjoint = diffrax.ImplicitAdjoint( - linear_solver=lx.AutoLinearSolver(well_posed=False) - ) - sol = diffrax.diffeqsolve( - term, - solver, - t0, - t1, - dt0, - guess, - args=model, - max_steps=max_steps, - stepsize_controller=controller, - event=event, - adjoint=adjoint, - # progress_meter=diffrax.TextProgressMeter(minimum_increase=0.001), - ) - return sol.ys[0] - - -def main(): - """Function for testing the steady state solver.""" - # guesses - bad_guess = jnp.full((5,), 0.01) - good_guess = jnp.array( - [ - 4.233000e-05, # met-L - 3.099670e-05, # amet - 2.170170e-07, # ahcys - 3.521780e-06, # hcys - 6.534400e-06, # 5mthf - ] - ) - model = methionine.model - # solve once for jitting - solve(methionine.parameters, methionine.unparameterised_model, good_guess) - jax.jacrev(solve)( - methionine.parameters, methionine.unparameterised_model, good_guess - ) - # compare good and bad guess - for guess in [bad_guess, good_guess]: - start = time.time() - conc_steady = solve( - methionine.parameters, methionine.unparameterised_model, guess - ) - jac = jax.jacrev(solve)( - methionine.parameters, methionine.unparameterised_model, guess - ) - runtime = (time.time() - start) * 1e3 - sv = model.dcdt(jnp.array(0.0), conc=conc_steady) - flux = model.flux(conc_steady) - print(f"Results with starting guess {guess}:") - print(f"\tRun time in milliseconds: {round(runtime, 4)}") - print(f"\tSteady state concentration: {conc_steady}") - print(f"\tFlux: {flux}") - print(f"\tSv: {sv}") - print(f"\tJacobian: {jac}") - print(f"\tLog Km Jacobian: {jac.log_km}") - print(f"\tDgf Jacobian: {jac.dgf}") - - -if __name__ == "__main__": - main() diff --git a/tests/test_kinetic_model.py b/tests/test_examples.py similarity index 100% rename from tests/test_kinetic_model.py rename to tests/test_examples.py diff --git a/tests/test_rate_equations.py b/tests/test_rate_equations.py index 9816a69..a511e77 100644 --- a/tests/test_rate_equations.py +++ b/tests/test_rate_equations.py @@ -7,10 +7,10 @@ IrreversibleMichaelisMenten, ReversibleMichaelisMenten, ) -from enzax.kinetic_model import KineticModelParameters +from enzax.parameters import AllostericMichaelisMentenParameterSet EXAMPLE_CONC = jnp.array([0.5, 0.2, 0.1, 0.3]) -EXAMPLE_PARAMETERS = KineticModelParameters( +EXAMPLE_PARAMETERS = AllostericMichaelisMentenParameterSet( log_kcat=jnp.array([-0.1]), log_enzyme=jnp.log(jnp.array([0.3])), dgf=jnp.array([-3, -1.0]),