diff --git a/docs/api/kinetic_model.md b/docs/api/kinetic_model.md index 89bc6f4..43540eb 100644 --- a/docs/api/kinetic_model.md +++ b/docs/api/kinetic_model.md @@ -4,7 +4,6 @@ filters: - "!check" members: - - KineticModel - - UnparameterisedKineticModel - - KineticModelParameters - KineticModelStructure + - KineticModel + - RateEquationModel diff --git a/docs/getting_started.md b/docs/getting_started.md index 73f3f7f..f64bdd6 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -22,10 +22,8 @@ First we import some enzax classes: ```python from enzax.kinetic_model import ( - KineticModel, - KineticModelParameters, KineticModelStructure, - UnparameterisedKineticModel, + RateEquationModel, ) from enzax.rate_equations import ( AllostericReversibleMichaelisMenten, @@ -49,7 +47,9 @@ structure = KineticModelStructure( Next we provide some kinetic parameter values: ```python -parameters = KineticModelParameters( +from enzax.parameters import AllostericMichaelisMentenParameters + +parameters = AllostericMichaelisMentenParameters( log_kcat=jnp.array([-0.1, 0.0, 0.1]), log_enzyme=jnp.log(jnp.array([0.3, 0.2, 0.1])), dgf=jnp.array([-3, -1.0]), @@ -65,6 +65,11 @@ parameters = KineticModelParameters( Now we can use enzax's rate laws to specify how each reaction behaves: ```python +from enzax.rate_equations import ( + AllostericReversibleMichaelisMenten, + ReversibleMichaelisMenten, +) + r0 = AllostericReversibleMichaelisMenten( kcat_ix=0, enzyme_ix=0, @@ -130,16 +135,10 @@ r2 = ReversibleMichaelisMenten( ) ``` -Next an unparameterised kinetic model +Now we can declare our model: ```python -unparameterised_model = UnparameterisedKineticModel(structure, [r0, r1, r2]) -``` - -Finally a parameterised model: - -```python -model = KineticModel(parameters, unparameterised_model) +model = RateEquationModel(structure, parameters, [r0, r1, r2]) ``` To test out the model, we can see if it returns some fluxes and state variable rates when provided a set of balanced species concentrations: @@ -157,28 +156,24 @@ dcdt ## Find a kinetic model's steady state -Enzax provides a few example kinetic models, including [`methionine`](https://github.com/dtu-qmcm/enzax/blob/main/src/enzax/examples/methionine.py), a model of the mammallian methionine cycle. +Enzax provides a few example kinetic models, including [`methionine`](https://github.com/dtu-qmcm/enzax/blob/main/src/enzax/examples/methionine.py), a model of the mammalian methionine cycle. -Here is how to find this model's steady state (and its parameter gradients) using enzax's `solve_steady_state` function: +Here is how to find this model's steady state (and its parameter gradients) using enzax's `get_kinetic_model_steady_state` function: ```python from enzax.examples import methionine -from enzax.steady_state import solve_steady_state +from enzax.steady_state import get_kinetic_model_steady_state from jax import numpy as jnp guess = jnp.full((5,) 0.01) -steady_state = solve_steady_state( - methionine.parameters, methionine.unparameterised_model, guess -) +steady_state = get_kinetic_model_steady_state(methionine.model, guess) ``` -To find the jacobian of this steady state with respect to the model's parameters, we can wrap `solve_steady_state` in JAX's [`jacrev`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html) function: +To find the Jacobian of this steady state with respect to the model's parameters, we can wrap `get_kinetic_model_steady_state` in JAX's [`jacrev`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html) function: ```python import jax -jacobian = jax.jacrev(solve_steady_state)( - methionine.parameters, methionine.unparameterised_model, guess -) +jacobian = jax.jacrev(solve_steady_state)(methionine.model, guess) ``` diff --git a/scripts/mcmc_demo.py b/scripts/mcmc_demo.py new file mode 100644 index 0000000..63bac12 --- /dev/null +++ b/scripts/mcmc_demo.py @@ -0,0 +1,126 @@ +"""Demonstration of how to make a Bayesian kinetic model with enzax.""" + +import functools +import logging +import warnings + +import arviz as az +import jax +from jax import numpy as jnp + +from enzax.examples import methionine +from enzax.mcmc import ( + ObservationSet, + AllostericMichaelisMentenPriorSet, + get_idata, + ind_prior_from_truth, + posterior_logdensity_amm, + run_nuts, +) +from enzax.steady_state import get_kinetic_model_steady_state + +SEED = 1234 + +jax.config.update("jax_enable_x64", True) + + +def main(): + """Demonstrate How to make a Bayesian kinetic model with enzax.""" + structure = methionine.structure + rate_equations = methionine.rate_equations + true_parameters = methionine.parameters + true_model = methionine.model + default_state_guess = jnp.full((5,), 0.01) + true_states = get_kinetic_model_steady_state( + true_model, default_state_guess + ) + prior = AllostericMichaelisMentenPriorSet( + log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1), + log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1), + log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1), + dgf=ind_prior_from_truth(true_parameters.dgf, 0.1), + log_km=ind_prior_from_truth(true_parameters.log_km, 0.1), + log_conc_unbalanced=ind_prior_from_truth( + true_parameters.log_conc_unbalanced, 0.1 + ), + temperature=ind_prior_from_truth(true_parameters.temperature, 0.1), + log_ki=ind_prior_from_truth(true_parameters.log_ki, 0.1), + log_transfer_constant=ind_prior_from_truth( + true_parameters.log_transfer_constant, 0.1 + ), + log_dissociation_constant=ind_prior_from_truth( + true_parameters.log_dissociation_constant, 0.1 + ), + ) + # get true concentration + true_conc = jnp.zeros(methionine.structure.S.shape[0]) + true_conc = true_conc.at[methionine.structure.balanced_species].set( + true_states + ) + true_conc = true_conc.at[methionine.structure.unbalanced_species].set( + jnp.exp(true_parameters.log_conc_unbalanced) + ) + # get true flux + true_flux = true_model.flux(true_states) + # simulate observations + error_conc = 0.03 + error_flux = 0.05 + error_enzyme = 0.03 + key = jax.random.key(SEED) + obs_conc = jnp.exp(jnp.log(true_conc) + jax.random.normal(key) * error_conc) + obs_enzyme = jnp.exp( + true_parameters.log_enzyme + jax.random.normal(key) * error_enzyme + ) + obs_flux = true_flux + jax.random.normal(key) * error_conc + obs = ObservationSet( + conc=obs_conc, + flux=obs_flux, + enzyme=obs_enzyme, + conc_scale=error_conc, + flux_scale=error_flux, + enzyme_scale=error_enzyme, + ) + pldf = functools.partial( + posterior_logdensity_amm, + obs=obs, + prior=prior, + structure=structure, + rate_equations=rate_equations, + guess=default_state_guess, + ) + samples, info = run_nuts( + pldf, + key, + true_parameters, + num_warmup=200, + num_samples=200, + initial_step_size=0.0001, + max_num_doublings=10, + is_mass_matrix_diagonal=False, + target_acceptance_rate=0.95, + ) + idata = get_idata( + samples, info, coords=methionine.coords, dims=methionine.dims + ) + print(az.summary(idata)) + if jnp.any(info.is_divergent): + n_divergent = info.is_divergent.sum() + msg = f"There were {n_divergent} post-warmup divergent transitions." + warnings.warn(msg) + else: + logging.info("No post-warmup divergent transitions!") + print("True parameter values vs posterior:") + for param in true_parameters.__dataclass_fields__.keys(): + true_val = getattr(true_parameters, param) + model_low = jnp.quantile(getattr(samples.position, param), 0.01, axis=0) + model_high = jnp.quantile( + getattr(samples.position, param), 0.99, axis=0 + ) + print(f" {param}:") + print(f" true value: {true_val}") + print(f" posterior 1%: {model_low}") + print(f" posterior 99%: {model_high}") + + +if __name__ == "__main__": + main() diff --git a/scripts/steady_state_demo.py b/scripts/steady_state_demo.py new file mode 100644 index 0000000..5feeda3 --- /dev/null +++ b/scripts/steady_state_demo.py @@ -0,0 +1,62 @@ +"""Demonstration of how to find a steady state and its gradients with enzax.""" + +import time +from enzax.kinetic_model import RateEquationModel + +import jax +from jax import numpy as jnp + +from enzax.examples import methionine +from enzax.steady_state import get_kinetic_model_steady_state +from jaxtyping import PyTree + +BAD_GUESS = jnp.full((5,), 0.01) +GOOD_GUESS = jnp.array( + [ + 4.233000e-05, # met-L + 3.099670e-05, # amet + 2.170170e-07, # ahcys + 3.521780e-06, # hcys + 6.534400e-06, # 5mthf + ] +) + + +def main(): + """Function for testing the steady state solver.""" + model = methionine.model + # compare good and bad guess + for guess in [BAD_GUESS, GOOD_GUESS]: + + def get_steady_state_from_params(parameters: PyTree): + """Get the steady state from just parameters. + + This lets us get the Jacobian wrt (just) the parameters. + """ + _model = RateEquationModel( + parameters, model.structure, model.rate_equations + ) + return get_kinetic_model_steady_state(_model, guess) + + # solve once for jitting + get_kinetic_model_steady_state(model, GOOD_GUESS) + jax.jacrev(get_steady_state_from_params)(model.parameters) + # timer on + start = time.time() + conc_steady = get_kinetic_model_steady_state(model, guess) + jac = jax.jacrev(get_steady_state_from_params)(model.parameters) + # timer off + runtime = (time.time() - start) * 1e3 + sv = model.dcdt(jnp.array(0.0), conc=conc_steady) + flux = model.flux(conc_steady) + print(f"Results with starting guess {guess}:") + print(f"\tRun time in milliseconds: {round(runtime, 4)}") + print(f"\tSteady state concentration: {conc_steady}") + print(f"\tFlux: {flux}") + print(f"\tSv: {sv}") + print(f"\tLog Km Jacobian: {jac.log_km}") + print(f"\tDgf Jacobian: {jac.dgf}") + + +if __name__ == "__main__": + main() diff --git a/src/enzax/__init__.py b/src/enzax/__init__.py index e69de29..d5f6fc4 100644 --- a/src/enzax/__init__.py +++ b/src/enzax/__init__.py @@ -0,0 +1,3 @@ +from jax import config + +config.update("jax_enable_x64", True) diff --git a/src/enzax/examples/linear.py b/src/enzax/examples/linear.py index 33aa744..134167d 100644 --- a/src/enzax/examples/linear.py +++ b/src/enzax/examples/linear.py @@ -1,22 +1,18 @@ """A simple linear kinetic model.""" -from jax import config from jax import numpy as jnp from enzax.kinetic_model import ( - KineticModel, + RateEquationModel, KineticModelStructure, - UnparameterisedKineticModel, ) from enzax.rate_equations import ( AllostericReversibleMichaelisMenten, ReversibleMichaelisMenten, ) -from enzax.parameters import AllostericMichaelisMentenParameters +from enzax.parameters import AllostericMichaelisMentenParameterSet -config.update("jax_enable_x64", True) - -parameters = AllostericMichaelisMentenParameters( +parameters = AllostericMichaelisMentenParameterSet( log_kcat=jnp.array([-0.1, 0.0, 0.1]), log_enzyme=jnp.log(jnp.array([0.3, 0.2, 0.1])), dgf=jnp.array([-3, -1.0]), @@ -35,74 +31,70 @@ balanced_species=jnp.array([1, 2]), unbalanced_species=jnp.array([0, 3]), ) - -unparameterised_model = UnparameterisedKineticModel( - structure, - [ - AllostericReversibleMichaelisMenten( - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 0], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - ix_product=jnp.array([1], dtype=jnp.int16), - ix_reactants=jnp.array([0, 1], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - tc_ix=0, - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - ix_dc_activation=jnp.array([0], dtype=jnp.int16), - species_activation=jnp.array([2], dtype=jnp.int16), - species_inhibition=jnp.array([], dtype=jnp.int16), - subunits=1, - ), - AllostericReversibleMichaelisMenten( - kcat_ix=1, - enzyme_ix=1, - km_ix=jnp.array([2, 3], dtype=jnp.int16), - ki_ix=jnp.array([0]), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([1]), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([1], dtype=jnp.int16), - ix_product=jnp.array([2], dtype=jnp.int16), - ix_reactants=jnp.array([1, 2], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - tc_ix=1, - ix_dc_inhibition=jnp.array([1], dtype=jnp.int16), - ix_dc_activation=jnp.array([], dtype=jnp.int16), - species_activation=jnp.array([], dtype=jnp.int16), - species_inhibition=jnp.array([1], dtype=jnp.int16), - subunits=1, - ), - ReversibleMichaelisMenten( - kcat_ix=2, - enzyme_ix=2, - km_ix=jnp.array([4, 5], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - ix_substrate=jnp.array([2], dtype=jnp.int16), - ix_product=jnp.array([3], dtype=jnp.int16), - ix_reactants=jnp.array([2, 3], dtype=jnp.int16), - reactant_to_dgf=jnp.array([1, 1], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - ), - ], -) -model = KineticModel(parameters, unparameterised_model) +rate_equations = [ + AllostericReversibleMichaelisMenten( + kcat_ix=0, + enzyme_ix=0, + km_ix=jnp.array([0, 1], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), + reactant_to_dgf=jnp.array([0, 0], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), + ix_substrate=jnp.array([0], dtype=jnp.int16), + ix_product=jnp.array([1], dtype=jnp.int16), + ix_reactants=jnp.array([0, 1], dtype=jnp.int16), + product_reactant_positions=jnp.array([1], dtype=jnp.int16), + product_km_positions=jnp.array([1], dtype=jnp.int16), + water_stoichiometry=jnp.array(0.0), + tc_ix=0, + ix_dc_inhibition=jnp.array([], dtype=jnp.int16), + ix_dc_activation=jnp.array([0], dtype=jnp.int16), + species_activation=jnp.array([2], dtype=jnp.int16), + species_inhibition=jnp.array([], dtype=jnp.int16), + subunits=1, + ), + AllostericReversibleMichaelisMenten( + kcat_ix=1, + enzyme_ix=1, + km_ix=jnp.array([2, 3], dtype=jnp.int16), + ki_ix=jnp.array([0]), + reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), + reactant_to_dgf=jnp.array([0, 1], dtype=jnp.int16), + ix_ki_species=jnp.array([1]), + substrate_km_positions=jnp.array([0], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), + ix_substrate=jnp.array([1], dtype=jnp.int16), + ix_product=jnp.array([2], dtype=jnp.int16), + ix_reactants=jnp.array([1, 2], dtype=jnp.int16), + product_reactant_positions=jnp.array([1], dtype=jnp.int16), + product_km_positions=jnp.array([1], dtype=jnp.int16), + water_stoichiometry=jnp.array(0.0), + tc_ix=1, + ix_dc_inhibition=jnp.array([1], dtype=jnp.int16), + ix_dc_activation=jnp.array([], dtype=jnp.int16), + species_activation=jnp.array([], dtype=jnp.int16), + species_inhibition=jnp.array([1], dtype=jnp.int16), + subunits=1, + ), + ReversibleMichaelisMenten( + kcat_ix=2, + enzyme_ix=2, + km_ix=jnp.array([4, 5], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + ix_substrate=jnp.array([2], dtype=jnp.int16), + ix_product=jnp.array([3], dtype=jnp.int16), + ix_reactants=jnp.array([2, 3], dtype=jnp.int16), + reactant_to_dgf=jnp.array([1, 1], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), + product_reactant_positions=jnp.array([1], dtype=jnp.int16), + product_km_positions=jnp.array([1], dtype=jnp.int16), + water_stoichiometry=jnp.array(0.0), + ), +] steady_state = jnp.array([0.43658744, 0.12695706]) +model = RateEquationModel(parameters, structure, rate_equations) diff --git a/src/enzax/examples/methionine.py b/src/enzax/examples/methionine.py index 6540e0f..1179b11 100644 --- a/src/enzax/examples/methionine.py +++ b/src/enzax/examples/methionine.py @@ -1,12 +1,15 @@ -"""A simple linear kinetic model.""" +"""A kinetic model of the methionine cycle. + +See here for more about the methionine cycle: +https://doi.org/10.1021/acssynbio.3c00662 + +""" -from jax import config from jax import numpy as jnp from enzax.kinetic_model import ( - KineticModel, + RateEquationModel, KineticModelStructure, - UnparameterisedKineticModel, ) from enzax.rate_equations import ( AllostericIrreversibleMichaelisMenten, @@ -14,12 +17,154 @@ IrreversibleMichaelisMenten, ReversibleMichaelisMenten, ) -from enzax.parameters import AllostericMichaelisMentenParameters - +from enzax.parameters import AllostericMichaelisMentenParameterSet -config.update("jax_enable_x64", True) - -parameters = AllostericMichaelisMentenParameters( +coords = { + "enzyme": [ + "MAT1", + "MAT3", + "METH-Gen", + "GNMT1", + "AHC1", + "MS1", + "BHMT1", + "CBS1", + "MTHFR1", + "PROT1", + ], + "drain": ["the_drain"], + "rate": [ + "the_drain", + "MAT1", + "MAT3", + "METH-Gen", + "GNMT1", + "AHC1", + "MS1", + "BHMT1", + "CBS1", + "MTHFR1", + "PROT1", + ], + "metabolite": [ + "met-L", + "atp", + "pi", + "ppi", + "amet", + "ahcys", + "gly", + "sarcs", + "hcys-L", + "adn", + "thf", + "5mthf", + "mlthf", + "glyb", + "dmgly", + "ser-L", + "nadp", + "nadph", + "cyst-L", + ], + "km": [ + "met-L MAT1", + "atp MAT1", + "met-L MAT3", + "atp MAT3", + "amet METH-Gen", + "amet GNMT1", + "gly GNMT1", + "ahcys AHC1", + "hcys-L AHC1", + "adn AHC1", + "hcys-L MS1", + "5mthf MS1", + "hcys-L BHMT1", + "glyb BHMT1", + "hcys-L CBS1", + "ser-L CBS1", + "mlthf MTHFR1", + "nadph MTHFR1", + "met-L PROT1", + ], + "ki": [ + "MAT1", + "METH-Gen", + "GNMT1", + ], + "species": [ + "met-L", + "atp", + "pi", + "ppi", + "amet", + "ahcys", + "gly", + "sarcs", + "hcys-L", + "adn", + "thf", + "5mthf", + "mlthf", + "glyb", + "dmgly", + "ser-L", + "nadp", + "nadph", + "cyst-L", + ], + "balanced_species": [ + "met-L", + "amet", + "ahcys", + "hcys-L", + "5mthf", + ], + "unbalanced_species": [ + "atp", + "pi", + "ppi", + "gly", + "sarcs", + "adn", + "thf", + "mlthf", + "glyb", + "dmgly", + "ser-L", + "nadp", + "nadph", + "cyst-L", + ], + "transfer_constant": [ + "METAT", + "GNMT", + "CBS", + "MTHFR", + ], + "dissociation_constant": [ + "met-L MAT3", + "amet MAT3", + "amet GNMT1", + "mlthf GNMT1", + "amet CBS1", + "amet MTHFR1", + "ahcys MTHFR1", + ], +} +dims = { + "log_kcat": ["enzyme"], + "log_enzyme": ["enzyme"], + "log_drain": ["drain"], + "log_km": ["km"], + "dgf": ["metabolite"], + "log_ki": ["ki"], + "log_conc_unbalanced": ["unbalanced_species"], + "log_transfer_constant": ["transfer_constant"], + "log_dissociation_constant": ["dissociation_constant"], +} +parameters = AllostericMichaelisMentenParameterSet( log_kcat=jnp.log( jnp.array( [ @@ -209,159 +354,155 @@ ] ), ) -unparameterised_model = UnparameterisedKineticModel( - structure, - [ - Drain(sign=jnp.array(1.0), drain_ix=0), # met-L source - IrreversibleMichaelisMenten( # MAT1 - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([0], dtype=jnp.int16), - reactant_stoichiometry=jnp.array( - [-1.0, -1.0, 1.0, 1.0, 1.0], dtype=jnp.int16 - ), - ix_substrate=jnp.array([0, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([4], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), - ), - AllostericIrreversibleMichaelisMenten( # MAT3 - kcat_ix=1, - enzyme_ix=1, - km_ix=jnp.array([2, 3], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array( - [-1.0, -1.0, 1.0, 1.0, 1.0], dtype=jnp.int16 - ), - ix_substrate=jnp.array([0, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), - subunits=2, - tc_ix=0, - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - ix_dc_activation=jnp.array([0, 1], dtype=jnp.int16), - species_inhibition=jnp.array([], dtype=jnp.int16), - species_activation=jnp.array([0, 4], dtype=jnp.int16), - ), - IrreversibleMichaelisMenten( # METH - kcat_ix=2, - enzyme_ix=2, - km_ix=jnp.array([4], dtype=jnp.int16), - ki_ix=jnp.array([1], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - ix_substrate=jnp.array([4], dtype=jnp.int16), - ix_ki_species=jnp.array([5], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), +rate_equations = [ + Drain(sign=jnp.array(1.0), drain_ix=0), # met-L source + IrreversibleMichaelisMenten( # MAT1 + kcat_ix=0, + enzyme_ix=0, + km_ix=jnp.array([0, 1], dtype=jnp.int16), + ki_ix=jnp.array([0], dtype=jnp.int16), + reactant_stoichiometry=jnp.array( + [-1.0, -1.0, 1.0, 1.0, 1.0], dtype=jnp.int16 ), - AllostericIrreversibleMichaelisMenten( # GNMT1 - kcat_ix=3, - enzyme_ix=3, - km_ix=jnp.array([5, 6], dtype=jnp.int16), - ki_ix=jnp.array([2], dtype=jnp.int16), - reactant_stoichiometry=jnp.array( - [-1.0, 1.0, -1.0, 1.0], dtype=jnp.int16 - ), - ix_substrate=jnp.array([4, 6], dtype=jnp.int16), - ix_ki_species=jnp.array([5], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 2], dtype=jnp.int16), - subunits=4, - tc_ix=1, - ix_dc_activation=jnp.array([2], dtype=jnp.int16), - ix_dc_inhibition=jnp.array([3], dtype=jnp.int16), - species_inhibition=jnp.array([12], dtype=jnp.int16), - species_activation=jnp.array([4], dtype=jnp.int16), - ), - ReversibleMichaelisMenten( # AHC - kcat_ix=4, - enzyme_ix=4, - km_ix=jnp.array([7, 8, 9], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1.0, 1.0, 1.0], dtype=jnp.int16), - ix_substrate=jnp.array([5], dtype=jnp.int16), - ix_product=jnp.array([8, 9], dtype=jnp.int16), - ix_reactants=jnp.array([5, 8, 9], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - product_km_positions=jnp.array([1, 2], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - product_reactant_positions=jnp.array([1, 2], dtype=jnp.int16), - water_stoichiometry=jnp.array(-1.0), - reactant_to_dgf=jnp.array([5, 8, 9], dtype=jnp.int16), - ), - IrreversibleMichaelisMenten( # MS - kcat_ix=5, - enzyme_ix=5, - km_ix=jnp.array([10, 11], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([1, -1, 1, -1], dtype=jnp.int16), - ix_substrate=jnp.array([8, 11], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([1, 3], dtype=jnp.int16), - ), - IrreversibleMichaelisMenten( # BHMT - kcat_ix=6, - enzyme_ix=6, - km_ix=jnp.array([12, 13], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([1, -1, -1, 1], dtype=jnp.int16), - ix_substrate=jnp.array([8, 13], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([1, 2], dtype=jnp.int16), - ), - AllostericIrreversibleMichaelisMenten( # CBS1 - kcat_ix=7, - enzyme_ix=7, - km_ix=jnp.array([14, 15], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, -1, 1], dtype=jnp.int16), - ix_substrate=jnp.array([8, 15], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), - subunits=2, - tc_ix=2, - ix_dc_activation=jnp.array([4], dtype=jnp.int16), - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - species_inhibition=jnp.array([4], dtype=jnp.int16), - species_activation=jnp.array([], dtype=jnp.int16), - ), - AllostericIrreversibleMichaelisMenten( # MTHFR - kcat_ix=8, - enzyme_ix=8, - km_ix=jnp.array([16, 17], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([1, -1, 1, -1], dtype=jnp.int16), - ix_substrate=jnp.array([12, 17], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([1, 3], dtype=jnp.int16), - subunits=2, - tc_ix=3, - ix_dc_activation=jnp.array([6], dtype=jnp.int16), - ix_dc_inhibition=jnp.array([5], dtype=jnp.int16), - species_inhibition=jnp.array([4], dtype=jnp.int16), - species_activation=jnp.array([5], dtype=jnp.int16), + ix_substrate=jnp.array([0, 1], dtype=jnp.int16), + ix_ki_species=jnp.array([4], dtype=jnp.int16), + substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), + ), + AllostericIrreversibleMichaelisMenten( # MAT3 + kcat_ix=1, + enzyme_ix=1, + km_ix=jnp.array([2, 3], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array( + [-1.0, -1.0, 1.0, 1.0, 1.0], dtype=jnp.int16 ), - IrreversibleMichaelisMenten( # PROT - kcat_ix=9, - enzyme_ix=9, - km_ix=jnp.array([18], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1.0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), + ix_substrate=jnp.array([0, 1], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), + subunits=2, + tc_ix=0, + ix_dc_inhibition=jnp.array([], dtype=jnp.int16), + ix_dc_activation=jnp.array([0, 1], dtype=jnp.int16), + species_inhibition=jnp.array([], dtype=jnp.int16), + species_activation=jnp.array([0, 4], dtype=jnp.int16), + ), + IrreversibleMichaelisMenten( # METH + kcat_ix=2, + enzyme_ix=2, + km_ix=jnp.array([4], dtype=jnp.int16), + ki_ix=jnp.array([1], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), + ix_substrate=jnp.array([4], dtype=jnp.int16), + ix_ki_species=jnp.array([5], dtype=jnp.int16), + substrate_km_positions=jnp.array([0], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), + ), + AllostericIrreversibleMichaelisMenten( # GNMT1 + kcat_ix=3, + enzyme_ix=3, + km_ix=jnp.array([5, 6], dtype=jnp.int16), + ki_ix=jnp.array([2], dtype=jnp.int16), + reactant_stoichiometry=jnp.array( + [-1.0, 1.0, -1.0, 1.0], dtype=jnp.int16 ), - ], -) -model = KineticModel(parameters, unparameterised_model) + ix_substrate=jnp.array([4, 6], dtype=jnp.int16), + ix_ki_species=jnp.array([5], dtype=jnp.int16), + substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0, 2], dtype=jnp.int16), + subunits=4, + tc_ix=1, + ix_dc_activation=jnp.array([2], dtype=jnp.int16), + ix_dc_inhibition=jnp.array([3], dtype=jnp.int16), + species_inhibition=jnp.array([12], dtype=jnp.int16), + species_activation=jnp.array([4], dtype=jnp.int16), + ), + ReversibleMichaelisMenten( # AHC + kcat_ix=4, + enzyme_ix=4, + km_ix=jnp.array([7, 8, 9], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([-1.0, 1.0, 1.0], dtype=jnp.int16), + ix_substrate=jnp.array([5], dtype=jnp.int16), + ix_product=jnp.array([8, 9], dtype=jnp.int16), + ix_reactants=jnp.array([5, 8, 9], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0], dtype=jnp.int16), + product_km_positions=jnp.array([1, 2], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), + product_reactant_positions=jnp.array([1, 2], dtype=jnp.int16), + water_stoichiometry=jnp.array(-1.0), + reactant_to_dgf=jnp.array([5, 8, 9], dtype=jnp.int16), + ), + IrreversibleMichaelisMenten( # MS + kcat_ix=5, + enzyme_ix=5, + km_ix=jnp.array([10, 11], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([1, -1, 1, -1], dtype=jnp.int16), + ix_substrate=jnp.array([8, 11], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([1, 3], dtype=jnp.int16), + ), + IrreversibleMichaelisMenten( # BHMT + kcat_ix=6, + enzyme_ix=6, + km_ix=jnp.array([12, 13], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([1, -1, -1, 1], dtype=jnp.int16), + ix_substrate=jnp.array([8, 13], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([1, 2], dtype=jnp.int16), + ), + AllostericIrreversibleMichaelisMenten( # CBS1 + kcat_ix=7, + enzyme_ix=7, + km_ix=jnp.array([14, 15], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([-1, -1, 1], dtype=jnp.int16), + ix_substrate=jnp.array([8, 15], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), + subunits=2, + tc_ix=2, + ix_dc_activation=jnp.array([4], dtype=jnp.int16), + ix_dc_inhibition=jnp.array([], dtype=jnp.int16), + species_inhibition=jnp.array([4], dtype=jnp.int16), + species_activation=jnp.array([], dtype=jnp.int16), + ), + AllostericIrreversibleMichaelisMenten( # MTHFR + kcat_ix=8, + enzyme_ix=8, + km_ix=jnp.array([16, 17], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([1, -1, 1, -1], dtype=jnp.int16), + ix_substrate=jnp.array([12, 17], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([1, 3], dtype=jnp.int16), + subunits=2, + tc_ix=3, + ix_dc_activation=jnp.array([6], dtype=jnp.int16), + ix_dc_inhibition=jnp.array([5], dtype=jnp.int16), + species_inhibition=jnp.array([4], dtype=jnp.int16), + species_activation=jnp.array([5], dtype=jnp.int16), + ), + IrreversibleMichaelisMenten( # PROT + kcat_ix=9, + enzyme_ix=9, + km_ix=jnp.array([18], dtype=jnp.int16), + ki_ix=jnp.array([], dtype=jnp.int16), + reactant_stoichiometry=jnp.array([-1.0], dtype=jnp.int16), + ix_substrate=jnp.array([0], dtype=jnp.int16), + ix_ki_species=jnp.array([], dtype=jnp.int16), + substrate_km_positions=jnp.array([0], dtype=jnp.int16), + substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), + ), +] steady_state = jnp.array( [ 4.233000e-05, # met-L @@ -371,3 +512,4 @@ 6.534400e-06, # 5mthf ] ) +model = RateEquationModel(parameters, structure, rate_equations) diff --git a/src/enzax/kinetic_model.py b/src/enzax/kinetic_model.py index 0ccf343..439b7cd 100644 --- a/src/enzax/kinetic_model.py +++ b/src/enzax/kinetic_model.py @@ -1,5 +1,7 @@ """Module containing enzax's definition of a kinetic model.""" +from abc import ABC, abstractmethod + import equinox as eqx import jax.numpy as jnp from jaxtyping import Array, Float, Int, PyTree, ScalarLike, jaxtyped @@ -12,29 +14,44 @@ class KineticModelStructure(eqx.Module): """Structural information about a kinetic model.""" - S: Float[Array, " s r"] = eqx.field(static=True) - balanced_species: Int[Array, " n_balanced"] = eqx.field(static=True) - unbalanced_species: Int[Array, " n_unbalanced"] = eqx.field(static=True) + S: Float[Array, " s r"] + balanced_species: Int[Array, " n_balanced"] + unbalanced_species: Int[Array, " n_unbalanced"] -class UnparameterisedKineticModel(eqx.Module): - """A kinetic model without parameter values.""" +class KineticModel(eqx.Module, ABC): + """Abstract base class for kinetic models.""" + parameters: PyTree structure: KineticModelStructure - rate_equations: list[RateEquation] | None = None + rate_equations: list[RateEquation] = eqx.field(default_factory=list) + @abstractmethod + def flux( + self, + conc_balanced: Float[Array, " n_balanced"], + ) -> Float[Array, " n"]: ... -class KineticModel(eqx.Module): - """A parameterised kinetic model.""" + def dcdt( + self, t: ScalarLike, conc: Float[Array, " n_balanced"], args=None + ) -> Float[Array, " n_balanced"]: + """Get the rate of change of balanced species concentrations. - parameters: PyTree - structure: KineticModelStructure - rate_equations: list[RateEquation] | None = None + Note that the signature is as required for a Diffrax vector field function, hence the redundant variable t and the weird name "args". - def __init__(self, parameters, unparameterised_model): - self.parameters = parameters - self.structure = unparameterised_model.structure - self.rate_equations = unparameterised_model.rate_equations + :param t: redundant variable representing time. + + :param conc: a one dimensional array of positive floats representing concentrations of balanced species. Must have same size as self.structure.ix_balanced + + """ # Noqa: E501 + sv = self.structure.S @ self.flux(conc) + return sv[self.structure.balanced_species] + + +class RateEquationModel(KineticModel): + """A kinetic model that specifies its fluxes using RateEquation objects.""" + + rate_equations: list[RateEquation] = eqx.field(default_factory=list) def flux( self, @@ -55,18 +72,3 @@ def flux( t = [f(conc, self.parameters) for f in self.rate_equations] out = jnp.array(t) return out - - def dcdt( - self, t: ScalarLike, conc: Float[Array, " n_balanced"], args=None - ) -> Float[Array, " n_balanced"]: - """Get the rate of change of balanced species concentrations. - - Note that the signature is as required for a Diffrax vector field function, hence the redundant variable t and the weird name "args". - - :param t: redundant variable representing time. - - :param conc: a one dimensional array of positive floats representing concentrations of balanced species. Must have same size as self.structure.ix_balanced - - """ # Noqa: E501 - sv = self.structure.S @ self.flux(conc) - return sv[self.structure.balanced_species] diff --git a/src/enzax/mcmc.py b/src/enzax/mcmc.py index bf18618..e26111b 100644 --- a/src/enzax/mcmc.py +++ b/src/enzax/mcmc.py @@ -1,29 +1,24 @@ -"""Code for doing mcmc on the parameters of a steady state problem.""" +"""Code for MCMC-based Bayesian inference on kinetic models.""" import functools -import logging -import warnings +from typing import Callable, TypedDict, Unpack import arviz as az import blackjax import chex -import equinox as eqx import jax +from jax._src.random import KeyArray import jax.numpy as jnp from jax.scipy.stats import norm -from jaxtyping import Array, Float, ScalarLike +from jaxtyping import Array, Float, PyTree, ScalarLike -from enzax.examples import methionine from enzax.kinetic_model import ( - KineticModel, - KineticModelParameters, - UnparameterisedKineticModel, + RateEquationModel, + KineticModelStructure, ) -from enzax.steady_state_problem import solve - -SEED = 1234 -jax.config.update("jax_enable_x64", True) -jax.config.update("jax_debug_nans", True) +from enzax.parameters import AllostericMichaelisMentenParameterSet +from enzax.rate_equation import RateEquation +from enzax.steady_state import get_kinetic_model_steady_state @chex.dataclass @@ -37,9 +32,9 @@ class ObservationSet: @chex.dataclass -class PriorSet: - log_kcat: Float[Array, "2 n"] - log_enzyme: Float[Array, "2 n"] +class AllostericMichaelisMentenPriorSet: + log_kcat: Float[Array, "2 n_enzyme"] + log_enzyme: Float[Array, "2 n_enzyme"] log_drain: Float[Array, "2 n_drain"] dgf: Float[Array, "2 n_metabolite"] log_km: Float[Array, "2 n_km"] @@ -47,23 +42,34 @@ class PriorSet: log_conc_unbalanced: Float[Array, "2 n_unbalanced"] temperature: Float[Array, "2"] log_transfer_constant: Float[Array, "2 n_allosteric_enzyme"] - log_dissociation_constant: Float[Array, "2 n_allosteric_effector"] + log_dissociation_constant: Float[Array, "2 n_allosteric_effect"] + + +class AdaptationKwargs(TypedDict): + """Keyword arguments to the blackjax function window_adaptation.""" + initial_step_size: float + max_num_doublings: int + is_mass_matrix_diagonal: bool + target_acceptance_rate: float -def ind_normal_prior_logdensity(param, prior): + +def ind_normal_prior_logdensity(param, prior: Float[Array, "2 _"]): + """Total log density for an independent normal distribution.""" return norm.logpdf(param, loc=prior[0], scale=prior[1]).sum() -@eqx.filter_jit -def posterior_logdensity_fn( - parameters: KineticModelParameters, - unparameterised_model: UnparameterisedKineticModel, +def posterior_logdensity_amm( + parameters: AllostericMichaelisMentenParameterSet, + structure: KineticModelStructure, + rate_equations: list[RateEquation], obs: ObservationSet, - prior: PriorSet, + prior: AllostericMichaelisMentenPriorSet, guess: Float[Array, " n_balanced"], ): - model = KineticModel(parameters, unparameterised_model) - steady = solve(parameters, unparameterised_model, guess) + """Get the log density for an allosteric Michaelis-Menten model.""" + model = RateEquationModel(parameters, structure, rate_equations) + steady = get_kinetic_model_steady_state(model, guess) flux = model.flux(steady) conc = jnp.zeros(model.structure.S.shape[0]) conc = conc.at[model.structure.balanced_species].set(steady) @@ -101,37 +107,37 @@ def posterior_logdensity_fn( @functools.partial(jax.jit, static_argnames=["kernel", "num_samples"]) def inference_loop(rng_key, kernel, initial_state, num_samples): + """Run MCMC with blackjax.""" + def one_step(state, rng_key): state, info = kernel(rng_key, state) return state, (state, info) keys = jax.random.split(rng_key, num_samples) _, (states, info) = jax.lax.scan(one_step, initial_state, keys) - return states, info -def warn_if_divergent(info): - if jnp.any(info.is_divergent): - warnings.warn("I found a divergent transition!") - - -@eqx.filter_jit -def sample(logdensity_fn, rng_key, init_parameters): +def run_nuts( + logdensity_fn: Callable, + rng_key: KeyArray, + init_parameters: PyTree, + num_warmup: int, + num_samples: int, + **adapt_kwargs: Unpack[AdaptationKwargs], +): + """Run the default NUTS algorithm with blackjax.""" warmup = blackjax.window_adaptation( blackjax.nuts, logdensity_fn, progress_bar=True, - initial_step_size=0.0001, - max_num_doublings=10, - is_mass_matrix_diagonal=False, - target_acceptance_rate=0.95, + **adapt_kwargs, ) rng_key, warmup_key = jax.random.split(rng_key) (initial_state, tuned_parameters), (_, info, _) = warmup.run( warmup_key, init_parameters, - num_steps=1000, # type: ignore + num_steps=num_warmup, #  type: ignore ) rng_key, sample_key = jax.random.split(rng_key) nuts_kernel = blackjax.nuts(logdensity_fn, **tuned_parameters).step @@ -139,110 +145,36 @@ def sample(logdensity_fn, rng_key, init_parameters): sample_key, kernel=nuts_kernel, initial_state=initial_state, - num_samples=200, + num_samples=num_samples, ) return states, info -def ind_prior_from_truth(truth, sd): +def ind_prior_from_truth(truth: Float[Array, " _"], sd: ScalarLike): + """Get a set of independent priors centered at the true parameter values. + + Note that the standard deviation currently has to be the same for + all parameters. + + """ return jnp.vstack((truth, jnp.full(truth.shape, sd))) -def get_idata(samples, info): +def get_idata(samples, info, coords=None, dims=None) -> az.InferenceData: + """Get an arviz InferenceData from a blackjax NUTS output.""" sample_dict = { k: jnp.expand_dims(getattr(samples.position, k), 0) for k in samples.position.__dataclass_fields__.keys() } - posterior = az.convert_to_inference_data(sample_dict, group="posterior") + posterior = az.convert_to_inference_data( + sample_dict, + group="posterior", + coords=coords, + dims=dims, + ) sample_stats = az.convert_to_inference_data( {"diverging": info.is_divergent}, group="sample_stats" ) - return az.concat(posterior, sample_stats) - - -def main(): - """Demonstrate the functionality of the mcmc module.""" - true_parameters = methionine.parameters - unparameterised_model = methionine.unparameterised_model - true_model = methionine.model - default_state_guess = jnp.full((5,), 0.01) - true_states = solve( - true_parameters, unparameterised_model, default_state_guess - ) - prior = PriorSet( - log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1), - log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1), - log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1), - dgf=ind_prior_from_truth(true_parameters.dgf, 0.1), - log_km=ind_prior_from_truth(true_parameters.log_km, 0.1), - log_conc_unbalanced=ind_prior_from_truth( - true_parameters.log_conc_unbalanced, 0.1 - ), - temperature=ind_prior_from_truth(true_parameters.temperature, 0.1), - log_ki=ind_prior_from_truth(true_parameters.log_ki, 0.1), - log_transfer_constant=ind_prior_from_truth( - true_parameters.log_transfer_constant, 0.1 - ), - log_dissociation_constant=ind_prior_from_truth( - true_parameters.log_dissociation_constant, 0.1 - ), - ) - # get true concentration - true_conc = jnp.zeros(methionine.structure.S.shape[0]) - true_conc = true_conc.at[methionine.structure.balanced_species].set( - true_states - ) - true_conc = true_conc.at[methionine.structure.unbalanced_species].set( - jnp.exp(true_parameters.log_conc_unbalanced) - ) - # get true flux - true_flux = true_model.flux(true_states) - # simulate observations - error_conc = 0.03 - error_flux = 0.05 - error_enzyme = 0.03 - key = jax.random.key(SEED) - obs_conc = jnp.exp(jnp.log(true_conc) + jax.random.normal(key) * error_conc) - obs_enzyme = jnp.exp( - true_parameters.log_enzyme + jax.random.normal(key) * error_enzyme - ) - obs_flux = true_flux + jax.random.normal(key) * error_conc - obs = ObservationSet( - conc=obs_conc, - flux=obs_flux, - enzyme=obs_enzyme, - conc_scale=error_conc, - flux_scale=error_flux, - enzyme_scale=error_enzyme, - ) - log_M = functools.partial( - posterior_logdensity_fn, - obs=obs, - prior=prior, - unparameterised_model=unparameterised_model, - guess=default_state_guess, - ) - samples, info = sample(log_M, key, true_parameters) - idata = get_idata(samples, info) - print(az.summary(idata)) - if jnp.any(info.is_divergentl): - n_divergent = info.is_divergent.sum() - msg = f"There were {n_divergent} post-warmup divergent transitions." - warnings.warn(msg) - else: - logging.info("No post-warmup divergent transitions!") - print("True parameter values vs posterior:") - for param in true_parameters.__dataclass_fields__.keys(): - true_val = getattr(true_parameters, param) - model_low = jnp.quantile(getattr(samples.position, param), 0.01, axis=0) - model_high = jnp.quantile( - getattr(samples.position, param), 0.99, axis=0 - ) - print(f" {param}:") - print(f" true value: {true_val}") - print(f" posterior 1%: {model_low}") - print(f" posterior 99%: {model_high}") - - -if __name__ == "__main__": - main() + idata = az.concat(posterior, sample_stats) + assert idata is not None + return idata diff --git a/src/enzax/parameters.py b/src/enzax/parameters.py index b10793f..6534169 100644 --- a/src/enzax/parameters.py +++ b/src/enzax/parameters.py @@ -20,10 +20,10 @@ @jaxtyped(typechecker=typechecked) -class MichaelisMentenParameters(eqx.Module): +class MichaelisMentenParameterSet(eqx.Module): """Parameters for a model with Michaelis Menten kinetics. - Reactions can have any of these rate laws: + This kind of parameter set supports models with the following rate laws: - enzax.rate_equations.drain.Drain - enzax.rate_equations.michaelis_menten.IrreversibleMichaelisMenten @@ -41,7 +41,7 @@ class MichaelisMentenParameters(eqx.Module): log_drain: LogDrain -class AllostericMichaelisMentenParameters(MichaelisMentenParameters): +class AllostericMichaelisMentenParameterSet(MichaelisMentenParameterSet): """Parameters for a model with Michaelis Menten kinetics, with allostery. Reactions can be any out of: diff --git a/src/enzax/steady_state.py b/src/enzax/steady_state.py new file mode 100644 index 0000000..b773da3 --- /dev/null +++ b/src/enzax/steady_state.py @@ -0,0 +1,53 @@ +"""Module for solving steady state problems. + +Given a structural kinetic model, a set of parameters and an initial guess, the aim is to find the physiological steady state metabolite concentration and its parameter sensitivities. + +""" # noqa: E501 + +import diffrax +import equinox as eqx +import lineax as lx +from jaxtyping import Array, Float, PyTree + +from enzax.kinetic_model import KineticModel + + +@eqx.filter_jit() +def get_kinetic_model_steady_state( + model: KineticModel, + guess: Float[Array, " n"], +) -> PyTree: + term = diffrax.ODETerm(model.dcdt) + solver = diffrax.Kvaerno5() + t0 = 0 + t1 = 900 + dt0 = 0.000001 + max_steps = None + controller = diffrax.PIDController( + pcoeff=0.1, + icoeff=0.3, + rtol=1e-11, + atol=1e-11, + ) + cond_fn = diffrax.steady_state_event() + event = diffrax.Event(cond_fn) + adjoint = diffrax.ImplicitAdjoint( + linear_solver=lx.AutoLinearSolver(well_posed=False) + ) + sol = diffrax.diffeqsolve( + term, + solver, + t0, + t1, + dt0, + guess, + args=model, + max_steps=max_steps, + stepsize_controller=controller, + event=event, + adjoint=adjoint, + ) + if sol.ys is not None: + return sol.ys[0] + else: + raise ValueError("No steady state found!") diff --git a/src/enzax/steady_state_problem.py b/src/enzax/steady_state_problem.py deleted file mode 100644 index ecbbb95..0000000 --- a/src/enzax/steady_state_problem.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Module for solving steady state problems. - -Given a structural kinetic model, a set of parameters and an initial guess, the aim is to find the physiological steady state metabolite concentration and its parameter sensitivities. - -""" # noqa: E501 - -import time - - -import diffrax -import equinox as eqx -import jax -import jax.numpy as jnp -import lineax as lx -from jaxtyping import Array, Float -from enzax.examples import methionine -from enzax.kinetic_model import ( - KineticModel, - KineticModelParameters, - UnparameterisedKineticModel, -) - -jax.config.update("jax_debug_nans", True) -jax.config.update("jax_enable_x64", True) - - -@eqx.filter_jit() -def solve( - parameters: KineticModelParameters, - unparameterised_model: UnparameterisedKineticModel, - guess: Float[Array, " n"], -): - model = KineticModel(parameters, unparameterised_model) - term = diffrax.ODETerm(model.dcdt) - solver = diffrax.Kvaerno5() - t0 = 0 - t1 = 900 - dt0 = 0.000001 - max_steps = None - controller = diffrax.PIDController( - pcoeff=0.1, - icoeff=0.3, - rtol=1e-11, - atol=1e-11, - ) - cond_fn = diffrax.steady_state_event() - event = diffrax.Event(cond_fn) - adjoint = diffrax.ImplicitAdjoint( - linear_solver=lx.AutoLinearSolver(well_posed=False) - ) - sol = diffrax.diffeqsolve( - term, - solver, - t0, - t1, - dt0, - guess, - args=model, - max_steps=max_steps, - stepsize_controller=controller, - event=event, - adjoint=adjoint, - # progress_meter=diffrax.TextProgressMeter(minimum_increase=0.001), - ) - return sol.ys[0] - - -def main(): - """Function for testing the steady state solver.""" - # guesses - bad_guess = jnp.full((5,), 0.01) - good_guess = jnp.array( - [ - 4.233000e-05, # met-L - 3.099670e-05, # amet - 2.170170e-07, # ahcys - 3.521780e-06, # hcys - 6.534400e-06, # 5mthf - ] - ) - model = methionine.model - # solve once for jitting - solve(methionine.parameters, methionine.unparameterised_model, good_guess) - jax.jacrev(solve)( - methionine.parameters, methionine.unparameterised_model, good_guess - ) - # compare good and bad guess - for guess in [bad_guess, good_guess]: - start = time.time() - conc_steady = solve( - methionine.parameters, methionine.unparameterised_model, guess - ) - jac = jax.jacrev(solve)( - methionine.parameters, methionine.unparameterised_model, guess - ) - runtime = (time.time() - start) * 1e3 - sv = model.dcdt(jnp.array(0.0), conc=conc_steady) - flux = model.flux(conc_steady) - print(f"Results with starting guess {guess}:") - print(f"\tRun time in milliseconds: {round(runtime, 4)}") - print(f"\tSteady state concentration: {conc_steady}") - print(f"\tFlux: {flux}") - print(f"\tSv: {sv}") - print(f"\tJacobian: {jac}") - print(f"\tLog Km Jacobian: {jac.log_km}") - print(f"\tDgf Jacobian: {jac.dgf}") - - -if __name__ == "__main__": - main() diff --git a/tests/test_kinetic_model.py b/tests/test_examples.py similarity index 100% rename from tests/test_kinetic_model.py rename to tests/test_examples.py diff --git a/tests/test_rate_equations.py b/tests/test_rate_equations.py index 8b77f10..a511e77 100644 --- a/tests/test_rate_equations.py +++ b/tests/test_rate_equations.py @@ -7,10 +7,10 @@ IrreversibleMichaelisMenten, ReversibleMichaelisMenten, ) -from enzax.parameters import AllostericMichaelisMentenParameters +from enzax.parameters import AllostericMichaelisMentenParameterSet EXAMPLE_CONC = jnp.array([0.5, 0.2, 0.1, 0.3]) -EXAMPLE_PARAMETERS = AllostericMichaelisMentenParameters( +EXAMPLE_PARAMETERS = AllostericMichaelisMentenParameterSet( log_kcat=jnp.array([-0.1]), log_enzyme=jnp.log(jnp.array([0.3])), dgf=jnp.array([-3, -1.0]),