diff --git a/scripts/mcmc_demo.py b/scripts/mcmc_demo.py index 63bac12..40d428d 100644 --- a/scripts/mcmc_demo.py +++ b/scripts/mcmc_demo.py @@ -38,7 +38,12 @@ def main(): log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1), log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1), log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1), - dgf=ind_prior_from_truth(true_parameters.dgf, 0.1), + dgf=( + ind_prior_from_truth(true_parameters.dgf, 0.1)[0], + jnp.diag( + jnp.square(ind_prior_from_truth(true_parameters.dgf, 0.1)[1]) + ), + ), log_km=ind_prior_from_truth(true_parameters.log_km, 0.1), log_conc_unbalanced=ind_prior_from_truth( true_parameters.log_conc_unbalanced, 0.1 diff --git a/src/enzax/mcmc.py b/src/enzax/mcmc.py index ef32b09..10d16dd 100644 --- a/src/enzax/mcmc.py +++ b/src/enzax/mcmc.py @@ -9,7 +9,7 @@ import jax from jax._src.random import KeyArray import jax.numpy as jnp -from jax.scipy.stats import norm +from jax.scipy.stats import norm, multivariate_normal from jaxtyping import Array, Float, PyTree, ScalarLike from enzax.kinetic_model import ( @@ -40,7 +40,10 @@ class AllostericMichaelisMentenPriorSet: log_kcat: Float[Array, "2 n_enzyme"] log_enzyme: Float[Array, "2 n_enzyme"] log_drain: Float[Array, "2 n_drain"] - dgf: Float[Array, "2 n_metabolite"] + dgf: tuple[ + Float[Array, " n_metabolite"], + Float[Array, " n_metabolite n_metabolite"], + ] log_km: Float[Array, "2 n_km"] log_ki: Float[Array, "2 n_ki"] log_conc_unbalanced: Float[Array, "2 n_unbalanced"] @@ -63,6 +66,16 @@ def ind_normal_prior_logdensity(param, prior: Float[Array, "2 _"]): return norm.logpdf(param, loc=prior[0], scale=prior[1]).sum() +def mv_normal_prior_logdensity( + param: Float[Array, " _"], + prior: tuple[Float[Array, " _"], Float[Array, " _ _"]], +): + """Total log density for an multivariate normal distribution.""" + return jnp.sum( + multivariate_normal.logpdf(param, mean=prior[0], cov=prior[1]) + ) + + def posterior_logdensity_amm( parameters: AllostericMichaelisMentenParameterSet, structure: KineticModelStructure, @@ -91,7 +104,7 @@ def posterior_logdensity_amm( ind_normal_prior_logdensity(parameters.log_kcat, prior.log_kcat) + ind_normal_prior_logdensity(parameters.log_enzyme, prior.log_enzyme) + ind_normal_prior_logdensity(parameters.log_drain, prior.log_drain) - + ind_normal_prior_logdensity(parameters.dgf, prior.dgf) + + mv_normal_prior_logdensity(parameters.dgf, prior.dgf) + ind_normal_prior_logdensity(parameters.log_km, prior.log_km) + ind_normal_prior_logdensity( parameters.log_conc_unbalanced, prior.log_conc_unbalanced diff --git a/tests/data/methionine_pldf_grad.json b/tests/data/methionine_pldf_grad.json new file mode 100644 index 0000000..1d4bf6d --- /dev/null +++ b/tests/data/methionine_pldf_grad.json @@ -0,0 +1,107 @@ +{ + "log_kcat": { + "MAT1": -8.515730409411995, + "MAT3": -2.5685756593727778, + "METH-Gen": -27.620460850911506, + "GNMT1": -4.548438844490129, + "AHC1": -11.75344338177346, + "MS1": -16.03214664980906, + "BHMT1": 12.038008848152892, + "CBS1": -27.970086306108342, + "MTHFR1": 21.633517021058395, + "PROT1": -11.72862587853564 + }, + "log_enzyme": { + "MAT1": 6.136798020189005, + "MAT3": 12.083952770228223, + "METH-Gen": -12.96793242131051, + "GNMT1": 10.10408958511087, + "AHC1": 2.8990850478275423, + "MS1": -1.379618220208057, + "BHMT1": 26.690537277753894, + "CBS1": -13.317557876507342, + "MTHFR1": 36.286045450659394, + "PROT1": 2.9239025510653605 + }, + "log_drain": { + "the_drain": 77.11394277227623 + }, + "log_km": { + "met-L MAT1": 6.190853781501735, + "atp MAT1": 5.292693046521437, + "met-L MAT3": 2.4814738123803086, + "atp MAT3": 1.705092962802143, + "amet METH-Gen": 6.601750014600864, + "amet GNMT1": 4.590545886354766, + "gly GNMT1": 4.820977250097301, + "ahcys AHC1": 11.683829123279876, + "hcys-L AHC1": -2.9137920929850405, + "adn AHC1": -1.769215504424688, + "hcys-L MS1": 5.240078482919715, + "5mthf MS1": 14.652528429601952, + "hcys-L BHMT1": -10.22016805310029, + "glyb BHMT1": -10.688987009618199, + "hcys-L CBS1": 25.82504526708683, + "ser-L CBS1": 0.049935152823941564, + "mlthf MTHFR1": -20.851184438848044, + "nadph MTHFR1": -2.4216899858227414, + "met-L PROT1": 5.971084568691083 + }, + "dgf": { + "met-L": 0.0, + "atp": 0.0, + "pi": 0.0, + "ppi": 0.0, + "amet": 0.0, + "ahcys": -1.3226601224426768, + "gly": 0.0, + "sarcs": 0.0, + "hcys-L": 1.3226601224426768, + "adn": 1.3226601224426768, + "thf": 0.0, + "5mthf": 0.0, + "mlthf": 0.0, + "glyb": 0.0, + "dmgly": 0.0, + "ser-L": 0.0, + "nadp": 0.0, + "nadph": 0.0, + "cyst-L": 0.0 + }, + "log_ki": { + "MAT1": -0.3185780584014895, + "METH-Gen": -0.24799879824208654, + "GNMT1": 0.0018358603202088403 + }, + "log_conc_unbalanced": { + "atp": 7.654742420277419, + "pi": 14.652528429601, + "ppi": 14.652528429601, + "gly": 9.831551179503698, + "sarcs": 14.652528429601988, + "adn": 19.700379108345924, + "thf": 14.652528429601986, + "mlthf": 35.57574270245104, + "glyb": 25.3415154392192, + "dmgly": 14.652528429601988, + "ser-L": 14.60259327677706, + "nadp": 14.652528429601988, + "nadph": 17.074218415424728, + "cyst-L": 14.652528429601988 + }, + "log_transfer_constant": { + "METAT": 0.07549016742205726, + "GNMT": 1.3246972471203298, + "CBS": 0.00007813277693021312, + "MTHFR": -0.3918664031389874 + }, + "log_dissociation_constant": { + "met-L MAT3": 0.009116267028546848, + "amet MAT3": 0.012649140053254744, + "amet GNMT1": 3.233378796396484, + "mlthf GNMT1": -0.0720298340010019, + "amet CBS1": 0.0, + "amet MTHFR1": 0.5327826505992013, + "ahcys MTHFR1": -0.0637729300752556 + } +} diff --git a/tests/test_lp_grad.py b/tests/test_lp_grad.py new file mode 100644 index 0000000..71b9284 --- /dev/null +++ b/tests/test_lp_grad.py @@ -0,0 +1,110 @@ +import json +import jax +import pytest +from jax import numpy as jnp + +from enzax.examples import methionine +from enzax.mcmc import ( + ObservationSet, + AllostericMichaelisMentenPriorSet, + ind_prior_from_truth, + posterior_logdensity_amm, +) +from enzax.steady_state import get_kinetic_model_steady_state + +import importlib.resources +from tests import data + +import functools + +jax.config.update("jax_enable_x64", True) +SEED = 1234 + +methionine_pldf_grad_file = ( + importlib.resources.files(data) / "methionine_pldf_grad.json" +) + + +def test_lp_grad(): + model = methionine + structure = methionine.structure + rate_equations = methionine.rate_equations + true_parameters = methionine.parameters + true_model = methionine.model + default_state_guess = jnp.full((5,), 0.01) + true_states = get_kinetic_model_steady_state( + true_model, default_state_guess + ) + prior = AllostericMichaelisMentenPriorSet( + log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1), + log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1), + log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1), + dgf=( + ind_prior_from_truth(true_parameters.dgf, 0.1)[0], + jnp.diag( + jnp.square(ind_prior_from_truth(true_parameters.dgf, 0.1)[1]) + ), + ), + log_km=ind_prior_from_truth(true_parameters.log_km, 0.1), + log_conc_unbalanced=ind_prior_from_truth( + true_parameters.log_conc_unbalanced, 0.1 + ), + temperature=ind_prior_from_truth(true_parameters.temperature, 0.1), + log_ki=ind_prior_from_truth(true_parameters.log_ki, 0.1), + log_transfer_constant=ind_prior_from_truth( + true_parameters.log_transfer_constant, 0.1 + ), + log_dissociation_constant=ind_prior_from_truth( + true_parameters.log_dissociation_constant, 0.1 + ), + ) + # get true concentration + true_conc = jnp.zeros(methionine.structure.S.shape[0]) + true_conc = true_conc.at[methionine.structure.balanced_species].set( + true_states + ) + true_conc = true_conc.at[methionine.structure.unbalanced_species].set( + jnp.exp(true_parameters.log_conc_unbalanced) + ) + # get true flux + true_flux = true_model.flux(true_states) + # simulate observations + error_conc = 0.03 + error_flux = 0.05 + error_enzyme = 0.03 + key = jax.random.key(SEED) + obs_conc = jnp.exp(jnp.log(true_conc) + jax.random.normal(key) * error_conc) + obs_enzyme = jnp.exp( + true_parameters.log_enzyme + jax.random.normal(key) * error_enzyme + ) + obs_flux = true_flux + jax.random.normal(key) * error_conc + obs = ObservationSet( + conc=obs_conc, + flux=obs_flux, + enzyme=obs_enzyme, + conc_scale=error_conc, + flux_scale=error_flux, + enzyme_scale=error_enzyme, + ) + pldf = functools.partial( + posterior_logdensity_amm, + obs=obs, + prior=prior, + structure=structure, + rate_equations=rate_equations, + guess=default_state_guess, + ) + pldf_grad = jax.jacrev(pldf)(methionine.parameters) + index_pldf_grad = { + p: { + c: float(getattr(pldf_grad, p)[i]) + for i, c in enumerate(model.coords[model.dims[p][0]]) + } + for p in model.dims.keys() + } + with open(methionine_pldf_grad_file, "r") as file: + saved_pldf_grad = file.read() + + true_gradient = json.loads(saved_pldf_grad) + for p, vals in true_gradient.items(): + assert true_gradient[p] == pytest.approx(index_pldf_grad[p])