Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Grad test #24

Merged
merged 13 commits into from
Sep 24, 2024
7 changes: 6 additions & 1 deletion scripts/mcmc_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ def main():
log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1),
log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1),
log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1),
dgf=ind_prior_from_truth(true_parameters.dgf, 0.1),
dgf=(
ind_prior_from_truth(true_parameters.dgf, 0.1)[0],
jnp.diag(
jnp.square(ind_prior_from_truth(true_parameters.dgf, 0.1)[1])
),
),
log_km=ind_prior_from_truth(true_parameters.log_km, 0.1),
log_conc_unbalanced=ind_prior_from_truth(
true_parameters.log_conc_unbalanced, 0.1
Expand Down
19 changes: 16 additions & 3 deletions src/enzax/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax
from jax._src.random import KeyArray
import jax.numpy as jnp
from jax.scipy.stats import norm
from jax.scipy.stats import norm, multivariate_normal
from jaxtyping import Array, Float, PyTree, ScalarLike

from enzax.kinetic_model import (
Expand Down Expand Up @@ -40,7 +40,10 @@ class AllostericMichaelisMentenPriorSet:
log_kcat: Float[Array, "2 n_enzyme"]
log_enzyme: Float[Array, "2 n_enzyme"]
log_drain: Float[Array, "2 n_drain"]
dgf: Float[Array, "2 n_metabolite"]
dgf: tuple[
Float[Array, " n_metabolite"],
Float[Array, " n_metabolite n_metabolite"],
]
log_km: Float[Array, "2 n_km"]
log_ki: Float[Array, "2 n_ki"]
log_conc_unbalanced: Float[Array, "2 n_unbalanced"]
Expand All @@ -63,6 +66,16 @@ def ind_normal_prior_logdensity(param, prior: Float[Array, "2 _"]):
return norm.logpdf(param, loc=prior[0], scale=prior[1]).sum()


def mv_normal_prior_logdensity(
param: Float[Array, " _"],
prior: tuple[Float[Array, " _"], Float[Array, " _ _"]],
):
"""Total log density for an multivariate normal distribution."""
return jnp.sum(
multivariate_normal.logpdf(param, mean=prior[0], cov=prior[1])
)


def posterior_logdensity_amm(
parameters: AllostericMichaelisMentenParameterSet,
structure: KineticModelStructure,
Expand Down Expand Up @@ -91,7 +104,7 @@ def posterior_logdensity_amm(
ind_normal_prior_logdensity(parameters.log_kcat, prior.log_kcat)
+ ind_normal_prior_logdensity(parameters.log_enzyme, prior.log_enzyme)
+ ind_normal_prior_logdensity(parameters.log_drain, prior.log_drain)
+ ind_normal_prior_logdensity(parameters.dgf, prior.dgf)
+ mv_normal_prior_logdensity(parameters.dgf, prior.dgf)
+ ind_normal_prior_logdensity(parameters.log_km, prior.log_km)
+ ind_normal_prior_logdensity(
parameters.log_conc_unbalanced, prior.log_conc_unbalanced
Expand Down
107 changes: 107 additions & 0 deletions tests/data/methionine_pldf_grad.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
{
"log_kcat": {
"MAT1": -8.515730409411995,
"MAT3": -2.5685756593727778,
"METH-Gen": -27.620460850911506,
"GNMT1": -4.548438844490129,
"AHC1": -11.75344338177346,
"MS1": -16.03214664980906,
"BHMT1": 12.038008848152892,
"CBS1": -27.970086306108342,
"MTHFR1": 21.633517021058395,
"PROT1": -11.72862587853564
},
"log_enzyme": {
"MAT1": 6.136798020189005,
"MAT3": 12.083952770228223,
"METH-Gen": -12.96793242131051,
"GNMT1": 10.10408958511087,
"AHC1": 2.8990850478275423,
"MS1": -1.379618220208057,
"BHMT1": 26.690537277753894,
"CBS1": -13.317557876507342,
"MTHFR1": 36.286045450659394,
"PROT1": 2.9239025510653605
},
"log_drain": {
"the_drain": 77.11394277227623
},
"log_km": {
"met-L MAT1": 6.190853781501735,
"atp MAT1": 5.292693046521437,
"met-L MAT3": 2.4814738123803086,
"atp MAT3": 1.705092962802143,
"amet METH-Gen": 6.601750014600864,
"amet GNMT1": 4.590545886354766,
"gly GNMT1": 4.820977250097301,
"ahcys AHC1": 11.683829123279876,
"hcys-L AHC1": -2.9137920929850405,
"adn AHC1": -1.769215504424688,
"hcys-L MS1": 5.240078482919715,
"5mthf MS1": 14.652528429601952,
"hcys-L BHMT1": -10.22016805310029,
"glyb BHMT1": -10.688987009618199,
"hcys-L CBS1": 25.82504526708683,
"ser-L CBS1": 0.049935152823941564,
"mlthf MTHFR1": -20.851184438848044,
"nadph MTHFR1": -2.4216899858227414,
"met-L PROT1": 5.971084568691083
},
"dgf": {
"met-L": 0.0,
"atp": 0.0,
"pi": 0.0,
"ppi": 0.0,
"amet": 0.0,
"ahcys": -1.3226601224426768,
"gly": 0.0,
"sarcs": 0.0,
"hcys-L": 1.3226601224426768,
"adn": 1.3226601224426768,
"thf": 0.0,
"5mthf": 0.0,
"mlthf": 0.0,
"glyb": 0.0,
"dmgly": 0.0,
"ser-L": 0.0,
"nadp": 0.0,
"nadph": 0.0,
"cyst-L": 0.0
},
"log_ki": {
"MAT1": -0.3185780584014895,
"METH-Gen": -0.24799879824208654,
"GNMT1": 0.0018358603202088403
},
"log_conc_unbalanced": {
"atp": 7.654742420277419,
"pi": 14.652528429601,
"ppi": 14.652528429601,
"gly": 9.831551179503698,
"sarcs": 14.652528429601988,
"adn": 19.700379108345924,
"thf": 14.652528429601986,
"mlthf": 35.57574270245104,
"glyb": 25.3415154392192,
"dmgly": 14.652528429601988,
"ser-L": 14.60259327677706,
"nadp": 14.652528429601988,
"nadph": 17.074218415424728,
"cyst-L": 14.652528429601988
},
"log_transfer_constant": {
"METAT": 0.07549016742205726,
"GNMT": 1.3246972471203298,
"CBS": 0.00007813277693021312,
"MTHFR": -0.3918664031389874
},
"log_dissociation_constant": {
"met-L MAT3": 0.009116267028546848,
"amet MAT3": 0.012649140053254744,
"amet GNMT1": 3.233378796396484,
"mlthf GNMT1": -0.0720298340010019,
"amet CBS1": 0.0,
"amet MTHFR1": 0.5327826505992013,
"ahcys MTHFR1": -0.0637729300752556
}
}
110 changes: 110 additions & 0 deletions tests/test_lp_grad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import json
import jax
import pytest
from jax import numpy as jnp

from enzax.examples import methionine
from enzax.mcmc import (
ObservationSet,
AllostericMichaelisMentenPriorSet,
ind_prior_from_truth,
posterior_logdensity_amm,
)
from enzax.steady_state import get_kinetic_model_steady_state

import importlib.resources
from tests import data

import functools

jax.config.update("jax_enable_x64", True)
SEED = 1234

methionine_pldf_grad_file = (
importlib.resources.files(data) / "methionine_pldf_grad.json"
)


def test_lp_grad():
model = methionine
structure = methionine.structure
rate_equations = methionine.rate_equations
true_parameters = methionine.parameters
true_model = methionine.model
default_state_guess = jnp.full((5,), 0.01)
true_states = get_kinetic_model_steady_state(
true_model, default_state_guess
)
prior = AllostericMichaelisMentenPriorSet(
log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1),
log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1),
log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1),
dgf=(
ind_prior_from_truth(true_parameters.dgf, 0.1)[0],
jnp.diag(
jnp.square(ind_prior_from_truth(true_parameters.dgf, 0.1)[1])
),
),
log_km=ind_prior_from_truth(true_parameters.log_km, 0.1),
log_conc_unbalanced=ind_prior_from_truth(
true_parameters.log_conc_unbalanced, 0.1
),
temperature=ind_prior_from_truth(true_parameters.temperature, 0.1),
log_ki=ind_prior_from_truth(true_parameters.log_ki, 0.1),
log_transfer_constant=ind_prior_from_truth(
true_parameters.log_transfer_constant, 0.1
),
log_dissociation_constant=ind_prior_from_truth(
true_parameters.log_dissociation_constant, 0.1
),
)
# get true concentration
true_conc = jnp.zeros(methionine.structure.S.shape[0])
true_conc = true_conc.at[methionine.structure.balanced_species].set(
true_states
)
true_conc = true_conc.at[methionine.structure.unbalanced_species].set(
jnp.exp(true_parameters.log_conc_unbalanced)
)
# get true flux
true_flux = true_model.flux(true_states)
# simulate observations
error_conc = 0.03
error_flux = 0.05
error_enzyme = 0.03
key = jax.random.key(SEED)
obs_conc = jnp.exp(jnp.log(true_conc) + jax.random.normal(key) * error_conc)
obs_enzyme = jnp.exp(
true_parameters.log_enzyme + jax.random.normal(key) * error_enzyme
)
obs_flux = true_flux + jax.random.normal(key) * error_conc
obs = ObservationSet(
conc=obs_conc,
flux=obs_flux,
enzyme=obs_enzyme,
conc_scale=error_conc,
flux_scale=error_flux,
enzyme_scale=error_enzyme,
)
pldf = functools.partial(
posterior_logdensity_amm,
obs=obs,
prior=prior,
structure=structure,
rate_equations=rate_equations,
guess=default_state_guess,
)
pldf_grad = jax.jacrev(pldf)(methionine.parameters)
index_pldf_grad = {
p: {
c: float(getattr(pldf_grad, p)[i])
for i, c in enumerate(model.coords[model.dims[p][0]])
}
for p in model.dims.keys()
}
with open(methionine_pldf_grad_file, "r") as file:
saved_pldf_grad = file.read()

true_gradient = json.loads(saved_pldf_grad)
for p, vals in true_gradient.items():
assert true_gradient[p] == pytest.approx(index_pldf_grad[p])
Loading