From 0da0dd338da4593debcf359d086f5f71297083ed Mon Sep 17 00:00:00 2001 From: teddygroves Date: Fri, 29 Nov 2024 13:00:54 +0100 Subject: [PATCH] get mcmc demo to work --- scripts/mcmc_demo.py | 118 ++++++++++--------- scripts/steady_state_demo.py | 2 +- src/enzax/examples/linear.py | 16 ++- src/enzax/examples/methionine.py | 43 +++---- src/enzax/kinetic_model.py | 15 ++- src/enzax/mcmc.py | 15 ++- src/enzax/rate_equations/generalised_mwc.py | 6 +- src/enzax/rate_equations/michaelis_menten.py | 6 +- tests/data/expected_methionine_gradient.json | 2 +- tests/test_rate_equations.py | 6 +- 10 files changed, 121 insertions(+), 108 deletions(-) diff --git a/scripts/mcmc_demo.py b/scripts/mcmc_demo.py index 40d428d..cc8491c 100644 --- a/scripts/mcmc_demo.py +++ b/scripts/mcmc_demo.py @@ -7,14 +7,15 @@ import arviz as az import jax from jax import numpy as jnp +from jax.flatten_util import ravel_pytree +from jax.scipy.stats import norm +from jaxtyping import Array from enzax.examples import methionine +from enzax.kinetic_model import RateEquationModel, get_conc from enzax.mcmc import ( ObservationSet, - AllostericMichaelisMentenPriorSet, get_idata, - ind_prior_from_truth, - posterior_logdensity_amm, run_nuts, ) from enzax.steady_state import get_kinetic_model_steady_state @@ -24,59 +25,55 @@ jax.config.update("jax_enable_x64", True) +def joint_log_density(params, prior_mean, prior_sd, obs, guess): + # find the steady state concentration and flux + model = RateEquationModel(params, methionine.structure) + steady = get_kinetic_model_steady_state(model, guess) + conc = get_conc(steady, params.log_conc_unbalanced, methionine.structure) + flux = model.flux(steady) + # prior + flat_params, _ = ravel_pytree(params) + log_prior = norm.logpdf(flat_params, loc=prior_mean, scale=prior_sd).sum() + # likelihood + flat_log_enzyme, _ = ravel_pytree(params.log_enzyme) + log_likelihood = ( + norm.logpdf(jnp.log(obs.conc), jnp.log(conc), obs.conc_scale).sum() + + norm.logpdf( + jnp.log(obs.enzyme), flat_log_enzyme, obs.enzyme_scale + ).sum() + + norm.logpdf(obs.flux, flux, obs.flux_scale).sum() + ) + return log_prior + log_likelihood + + def main(): """Demonstrate How to make a Bayesian kinetic model with enzax.""" - structure = methionine.structure - rate_equations = methionine.rate_equations true_parameters = methionine.parameters true_model = methionine.model - default_state_guess = jnp.full((5,), 0.01) - true_states = get_kinetic_model_steady_state( - true_model, default_state_guess - ) - prior = AllostericMichaelisMentenPriorSet( - log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1), - log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1), - log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1), - dgf=( - ind_prior_from_truth(true_parameters.dgf, 0.1)[0], - jnp.diag( - jnp.square(ind_prior_from_truth(true_parameters.dgf, 0.1)[1]) - ), - ), - log_km=ind_prior_from_truth(true_parameters.log_km, 0.1), - log_conc_unbalanced=ind_prior_from_truth( - true_parameters.log_conc_unbalanced, 0.1 - ), - temperature=ind_prior_from_truth(true_parameters.temperature, 0.1), - log_ki=ind_prior_from_truth(true_parameters.log_ki, 0.1), - log_transfer_constant=ind_prior_from_truth( - true_parameters.log_transfer_constant, 0.1 - ), - log_dissociation_constant=ind_prior_from_truth( - true_parameters.log_dissociation_constant, 0.1 - ), - ) + default_guess = jnp.full((5,), 0.01) + true_steady = get_kinetic_model_steady_state(true_model, default_guess) # get true concentration - true_conc = jnp.zeros(methionine.structure.S.shape[0]) - true_conc = true_conc.at[methionine.structure.balanced_species].set( - true_states - ) - true_conc = true_conc.at[methionine.structure.unbalanced_species].set( - jnp.exp(true_parameters.log_conc_unbalanced) + true_conc = get_conc( + true_steady, + true_parameters.log_conc_unbalanced, + methionine.structure, ) # get true flux - true_flux = true_model.flux(true_states) + true_flux = true_model.flux(true_steady) # simulate observations error_conc = 0.03 error_flux = 0.05 error_enzyme = 0.03 key = jax.random.key(SEED) - obs_conc = jnp.exp(jnp.log(true_conc) + jax.random.normal(key) * error_conc) + true_log_enz_flat, _ = ravel_pytree(true_parameters.log_enzyme) + key_conc, key_enz, key_flux, key_nuts = jax.random.split(key, num=4) + obs_conc = jnp.exp( + jnp.log(true_conc) + jax.random.normal(key_conc) * error_conc + ) obs_enzyme = jnp.exp( - true_parameters.log_enzyme + jax.random.normal(key) * error_enzyme + true_log_enz_flat + jax.random.normal(key_enz) * error_enzyme ) - obs_flux = true_flux + jax.random.normal(key) * error_conc + obs_flux = true_flux + jax.random.normal(key_flux) * error_conc obs = ObservationSet( conc=obs_conc, flux=obs_flux, @@ -85,17 +82,19 @@ def main(): flux_scale=error_flux, enzyme_scale=error_enzyme, ) - pldf = functools.partial( - posterior_logdensity_amm, - obs=obs, - prior=prior, - structure=structure, - rate_equations=rate_equations, - guess=default_state_guess, + flat_true_params, _ = ravel_pytree(true_parameters) + posterior_log_density = jax.jit( + functools.partial( + joint_log_density, + obs=obs, + prior_mean=flat_true_params, + prior_sd=0.1, + guess=default_guess, + ) ) samples, info = run_nuts( - pldf, - key, + posterior_log_density, + key_nuts, true_parameters, num_warmup=200, num_samples=200, @@ -104,9 +103,7 @@ def main(): is_mass_matrix_diagonal=False, target_acceptance_rate=0.95, ) - idata = get_idata( - samples, info, coords=methionine.coords, dims=methionine.dims - ) + idata = get_idata(samples, info) print(az.summary(idata)) if jnp.any(info.is_divergent): n_divergent = info.is_divergent.sum() @@ -117,10 +114,15 @@ def main(): print("True parameter values vs posterior:") for param in true_parameters.__dataclass_fields__.keys(): true_val = getattr(true_parameters, param) - model_low = jnp.quantile(getattr(samples.position, param), 0.01, axis=0) - model_high = jnp.quantile( - getattr(samples.position, param), 0.99, axis=0 - ) + model_p = getattr(samples.position, param) + if isinstance(true_val, Array): + model_low = jnp.quantile(model_p, 0.01, axis=0) + model_high = jnp.quantile(model_p, 0.99, axis=0) + elif isinstance(true_val, dict): + model_low, model_high = ( + {k: jnp.quantile(v, q, axis=0) for k, v in model_p.items()} + for q in (0.01, 0.99) + ) print(f" {param}:") print(f" true value: {true_val}") print(f" posterior 1%: {model_low}") diff --git a/scripts/steady_state_demo.py b/scripts/steady_state_demo.py index 5feeda3..45cc7c7 100644 --- a/scripts/steady_state_demo.py +++ b/scripts/steady_state_demo.py @@ -54,7 +54,7 @@ def get_steady_state_from_params(parameters: PyTree): print(f"\tSteady state concentration: {conc_steady}") print(f"\tFlux: {flux}") print(f"\tSv: {sv}") - print(f"\tLog Km Jacobian: {jac.log_km}") + print(f"\tLog substrate Km Jacobian: {jac.log_substrate_km}") print(f"\tDgf Jacobian: {jac.dgf}") diff --git a/src/enzax/examples/linear.py b/src/enzax/examples/linear.py index 9f03df2..6a1b3b0 100644 --- a/src/enzax/examples/linear.py +++ b/src/enzax/examples/linear.py @@ -16,7 +16,8 @@ class ParameterDefinition(eqx.Module): - log_km: dict[int, Array] + log_substrate_km: dict[int, Array] + log_product_km: dict[int, Array] log_kcat: dict[int, Scalar] log_enzyme: dict[int, Array] log_ki: dict[int, Array] @@ -50,10 +51,15 @@ class ParameterDefinition(eqx.Module): rate_equations=rate_equations, ) parameters = ParameterDefinition( - log_km={ - 0: jnp.array([[0.1], [-0.2]]), - 1: jnp.array([[0.5], [0.0]]), - 2: jnp.array([[-1.0], [0.5]]), + log_substrate_km={ + 0: jnp.array([0.1]), + 1: jnp.array([0.5]), + 2: jnp.array([-1.0]), + }, + log_product_km={ + 0: jnp.array([-0.2]), + 1: jnp.array([0.0]), + 2: jnp.array([0.5]), }, log_kcat={0: jnp.array(-0.1), 1: jnp.array(0.0), 2: jnp.array(0.1)}, dgf=jnp.array([-3.0, -1.0]), diff --git a/src/enzax/examples/methionine.py b/src/enzax/examples/methionine.py index 408fe4f..9ca0674 100644 --- a/src/enzax/examples/methionine.py +++ b/src/enzax/examples/methionine.py @@ -23,7 +23,8 @@ class ParameterDefinition(eqx.Module): - log_km: dict[int, list[Array]] + log_substrate_km: dict[int, Array] + log_product_km: dict[int, Array] log_kcat: dict[int, Scalar] log_enzyme: dict[int, Array] log_ki: dict[int, Array] @@ -151,32 +152,20 @@ class ParameterDefinition(eqx.Module): -46.4737, # cyst-L ] ), - log_km={ - 1: [jnp.log(jnp.array([0.000106919, 0.00203015]))], # MAT1 met-L, atp - 2: [jnp.log(jnp.array([0.00113258, 0.00236759]))], # MAT3 met-L atp - 3: [jnp.log(jnp.array([9.37e-06]))], # amet METH-Gen - 4: [ - jnp.log(jnp.array([0.000520015, 0.00253545])) - ], # amet GNMT1, # gly GNMT1 - 5: [ - jnp.log(jnp.array([2.32e-05])), # ahcys AHC1 - jnp.log( - jnp.array([1.06e-05, 5.66e-06]) - ), # hcys-L AHC1, # adn AHC1 - ], - 6: [ - jnp.log(jnp.array([1.71e-06, 6.94e-05])) - ], # hcys-L MS1, # 5mthf MS1 - 7: [ - jnp.log(jnp.array([1.98e-05, 0.00845898])) - ], # hcys-L BHMT1, # glyb BHMT1 - 8: [ - jnp.log(jnp.array([4.24e-05, 2.83e-06])) - ], # hcys-L CBS1, # ser-L CBS1 - 9: [ - jnp.log(jnp.array([8.08e-05, 2.09e-05])) - ], # mlthf MTHFR1, # nadph MTHFR1 - 10: [jnp.log(jnp.array([4.39e-05]))], # met-L PROT1 + log_product_km={ + 5: jnp.log(jnp.array([1.06e-05, 5.66e-06])), # hcys-L AHC1, adn AHC1 + }, + log_substrate_km={ + 1: jnp.log(jnp.array([0.000106919, 0.00203015])), # MAT1 met-L, atp + 2: jnp.log(jnp.array([0.00113258, 0.00236759])), # MAT3 met-L atp + 3: jnp.log(jnp.array([9.37e-06])), # METH-Gen amet + 4: jnp.log(jnp.array([0.000520015, 0.00253545])), # GNMT1, amet, gly + 5: jnp.log(jnp.array([2.32e-05])), # ahcys AHC1 + 6: jnp.log(jnp.array([1.71e-06, 6.94e-05])), # MS1 hcys-L, 5mthf + 7: jnp.log(jnp.array([1.98e-05, 0.00845898])), # BHMT1 hcys-L, glyb + 8: jnp.log(jnp.array([4.24e-05, 2.83e-06])), # CBS1 hcys-L, ser-L + 9: jnp.log(jnp.array([8.08e-05, 2.09e-05])), # MTHFR1 mlthf, nadph + 10: jnp.log(jnp.array([4.39e-05])), # PROT1 met-L }, temperature=jnp.array(298.15), log_ki={ diff --git a/src/enzax/kinetic_model.py b/src/enzax/kinetic_model.py index 8478f1f..6430f87 100644 --- a/src/enzax/kinetic_model.py +++ b/src/enzax/kinetic_model.py @@ -13,6 +13,13 @@ from enzax.rate_equation import RateEquation +def get_conc(balanced, log_unbalanced, structure): + conc = jnp.zeros(structure.S.shape[0]) + conc = conc.at[structure.balanced_species_ix].set(balanced) + conc = conc.at[structure.unbalanced_species_ix].set(jnp.exp(log_unbalanced)) + return conc + + @jaxtyped(typechecker=typechecked) @register_pytree_node_class class KineticModelStructure: @@ -143,10 +150,10 @@ def flux( :return: a one dimensional array of (possibly negative) floats representing reaction fluxes. Has same size as number of columns of self.structure.S. """ # Noqa: E501 - conc = jnp.zeros(self.structure.S.shape[0]) - conc = conc.at[self.structure.balanced_species_ix].set(conc_balanced) - conc = conc.at[self.structure.unbalanced_species_ix].set( - jnp.exp(self.parameters.log_conc_unbalanced) + conc = get_conc( + conc_balanced, + self.parameters.log_conc_unbalanced, + self.structure, ) flux_list = [] for i, rate_equation in enumerate(self.structure.rate_equations): diff --git a/src/enzax/mcmc.py b/src/enzax/mcmc.py index 20e0c81..eb17ea0 100644 --- a/src/enzax/mcmc.py +++ b/src/enzax/mcmc.py @@ -106,10 +106,17 @@ def ind_prior_from_truth(truth: Float[Array, " _"], sd: ScalarLike): def get_idata(samples, info, coords=None, dims=None) -> az.InferenceData: """Get an arviz InferenceData from a blackjax NUTS output.""" - sample_dict = { - k: jnp.expand_dims(getattr(samples.position, k), 0) - for k in samples.position.__dataclass_fields__.keys() - } + if coords is None: + coords = dict() + sample_dict = dict() + for k in samples.position.__dataclass_fields__.keys(): + samples_k = getattr(samples.position, k) + if isinstance(samples_k, Array): + sample_dict[k] = jnp.expand_dims(samples_k, 0) + elif isinstance(samples_k, dict): + sample_dict[k] = jnp.expand_dims( + jnp.concat([v.T for v in samples_k.values()]).T, 0 + ) posterior = az.convert_to_inference_data( sample_dict, group="posterior", diff --git a/src/enzax/rate_equations/generalised_mwc.py b/src/enzax/rate_equations/generalised_mwc.py index dbab431..ea9ee60 100644 --- a/src/enzax/rate_equations/generalised_mwc.py +++ b/src/enzax/rate_equations/generalised_mwc.py @@ -41,7 +41,7 @@ def get_allosteric_irreversible_michaelis_menten_input( kcat=jnp.exp(parameters.log_kcat[rxn_ix]), enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), ix_substrate=ix_substrate, - substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]), + substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]), substrate_stoichiometry=Sj[ix_substrate], ix_ki_species=ci_ix, ki=jnp.exp(parameters.log_ki[rxn_ix]), @@ -66,8 +66,8 @@ def get_allosteric_reversible_michaelis_menten_input( return AllostericReversibleMichaelisMentenInput( kcat=jnp.exp(parameters.log_kcat[rxn_ix]), enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), - substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]), - product_kms=jnp.exp(parameters.log_km[rxn_ix][1]), + substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]), + product_kms=jnp.exp(parameters.log_product_km[rxn_ix]), ki=jnp.exp(parameters.log_ki[rxn_ix]), dgf=parameters.dgf[species_to_dgf_ix][ix_reactant], temperature=parameters.temperature, diff --git a/src/enzax/rate_equations/michaelis_menten.py b/src/enzax/rate_equations/michaelis_menten.py index 51fbb44..df44865 100644 --- a/src/enzax/rate_equations/michaelis_menten.py +++ b/src/enzax/rate_equations/michaelis_menten.py @@ -30,7 +30,7 @@ def get_irreversible_michaelis_menten_input( kcat=jnp.exp(parameters.log_kcat[rxn_ix]), enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), ix_substrate=ix_substrate, - substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]), + substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]), substrate_stoichiometry=Sj[ix_substrate], ix_ki_species=ci_ix, ki=jnp.exp(parameters.log_ki[rxn_ix]), @@ -70,8 +70,8 @@ def get_reversible_michaelis_menten_input( return ReversibleMichaelisMentenInput( kcat=jnp.exp(parameters.log_kcat[rxn_ix]), enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]), - substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]), - product_kms=jnp.exp(parameters.log_km[rxn_ix][1]), + substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]), + product_kms=jnp.exp(parameters.log_product_km[rxn_ix]), ki=jnp.exp(parameters.log_ki[rxn_ix]), dgf=parameters.dgf[species_to_dgf_ix][ix_reactant], temperature=parameters.temperature, diff --git a/tests/data/expected_methionine_gradient.json b/tests/data/expected_methionine_gradient.json index e8491a3..0271271 100644 --- a/tests/data/expected_methionine_gradient.json +++ b/tests/data/expected_methionine_gradient.json @@ -1 +1 @@ -[-27.029305319714087, -23.107933540053693, -10.831961241573868, -7.442956276058797, -28.82230903289454, -20.04100693766347, -21.046992055692126, -51.008806318357124, 12.720942038333074, 7.723972171145526, -22.87688514741915, -63.96941462805278, 44.62125359725783, 46.668133361083704, -112.7484826050269, -0.21853936474188224, 91.03560765315373, 10.573024526064392, -26.068355084866106, 37.17973577495537, 11.212172526568366, 120.58711257251609, 19.857181701644674, 51.31272574643639, 69.99249085550201, -52.55796473716283, 122.11340482806465, -94.45124667621965, 51.204443757792795, 56.794366392943715, 28.218282451377277, 137.9773162976142, 27.62149868641839, 32.52255474993412, 99.44496349751932, 1.065834388161285, 160.30549966124383, -20.25453860618623, 17.98290589428491, 1.390914925131441, 1.0827290721714595, -0.008014753935453442, 0.0, 0.0, 0.0, 0.0, 0.0, 5.774425772898505, 0.0, 0.0, -5.774425772898505, -5.774425772898505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.455892028402676, -33.418521957445726, -63.96941505648969, -63.96941589328896, -42.922426803111094, -63.96941505649067, -86.00714492908733, -63.9694153507284, -155.31948281663819, -110.63755110207094, -63.96941623344747, -63.75087668807969, -63.969417159576764, -74.54244120146943, -63.969415350730365, 0.3144610971331224, 0.00030317394151436246, -2.3261150730556386, -0.03979371864763597, -0.055215199560989005, -14.115986977104384, 0.2784311817245577, -0.3295244858681964, -5.783239063037292, -0.0006063958903968876, 1.710877943445662, -336.4591824327092] +[-27.029305319714087, -23.107933540053693, -10.831961241573868, -7.442956276058797, -28.82230903289454, -20.04100693766347, -21.046992055692126, -51.008806318357124, -22.87688514741915, -63.96941462805278, 44.62125359725783, 46.668133361083704, -112.7484826050269, -0.21853936474188224, 91.03560765315373, 10.573024526064392, -26.068355084866106, 12.720942038333074, 7.723972171145526, 37.17973577495537, 11.212172526568366, 120.58711257251609, 19.857181701644674, 51.31272574643639, 69.99249085550201, -52.55796473716283, 122.11340482806465, -94.45124667621965, 51.204443757792795, 56.794366392943715, 28.218282451377277, 137.9773162976142, 27.62149868641839, 32.52255474993412, 99.44496349751932, 1.065834388161285, 160.30549966124383, -20.25453860618623, 17.98290589428491, 1.390914925131441, 1.0827290721714595, -0.008014753935453442, 0.0, 0.0, 0.0, 0.0, 0.0, 5.774425772898505, 0.0, 0.0, -5.774425772898505, -5.774425772898505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.455892028402676, -33.418521957445726, -63.96941505648969, -63.96941589328896, -42.922426803111094, -63.96941505649067, -86.00714492908733, -63.9694153507284, -155.31948281663819, -110.63755110207094, -63.96941623344747, -63.75087668807969, -63.969417159576764, -74.54244120146943, -63.969415350730365, 0.3144610971331224, 0.00030317394151436246, -2.3261150730556386, -0.03979371864763597, -0.055215199560989005, -14.115986977104384, 0.2784311817245577, -0.3295244858681964, -5.783239063037292, -0.0006063958903968876, 1.710877943445662, -336.4591824327092] diff --git a/tests/test_rate_equations.py b/tests/test_rate_equations.py index e14100e..3b2961f 100644 --- a/tests/test_rate_equations.py +++ b/tests/test_rate_equations.py @@ -13,7 +13,8 @@ class ExampleParameterSet(eqx.Module): - log_km: dict[int, Array] + log_substrate_km: dict[int, Array] + log_product_km: dict[int, Array] log_kcat: dict[int, Scalar] log_enzyme: dict[int, Array] log_ki: dict[int, Array] @@ -28,7 +29,8 @@ class ExampleParameterSet(eqx.Module): EXAMPLE_S = np.array([[-1], [1], [0]], dtype=np.float64) EXAMPLE_CONC = jnp.array([0.5, 0.2, 0.1]) EXAMPLE_PARAMETERS = ExampleParameterSet( - log_km={0: jnp.array([[0.1], [-0.2]])}, + log_substrate_km={0: jnp.array([0.1])}, + log_product_km={0: jnp.array([-0.2])}, log_kcat={0: jnp.array(-0.1)}, dgf=jnp.array([-3.0, 1.0]), log_ki={0: jnp.array([1.0])},