Skip to content

Commit

Permalink
Merge pull request #24 from dtu-qmcm/grad_test
Browse files Browse the repository at this point in the history
Grad test
  • Loading branch information
NicholasCowie authored Sep 24, 2024
2 parents 3bbdf9b + 34f07d0 commit 4d2ca4f
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 1 deletion.
4 changes: 3 additions & 1 deletion scripts/mcmc_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def main():
log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1),
dgf=(
ind_prior_from_truth(true_parameters.dgf, 0.1)[0],
jnp.diag(ind_prior_from_truth(true_parameters.dgf, 0.1)[1]),
jnp.diag(
jnp.square(ind_prior_from_truth(true_parameters.dgf, 0.1)[1])
),
),
log_km=ind_prior_from_truth(true_parameters.log_km, 0.1),
log_conc_unbalanced=ind_prior_from_truth(
Expand Down
107 changes: 107 additions & 0 deletions tests/data/methionine_pldf_grad.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
{
"log_kcat": {
"MAT1": -8.515730409411995,
"MAT3": -2.5685756593727778,
"METH-Gen": -27.620460850911506,
"GNMT1": -4.548438844490129,
"AHC1": -11.75344338177346,
"MS1": -16.03214664980906,
"BHMT1": 12.038008848152892,
"CBS1": -27.970086306108342,
"MTHFR1": 21.633517021058395,
"PROT1": -11.72862587853564
},
"log_enzyme": {
"MAT1": 6.136798020189005,
"MAT3": 12.083952770228223,
"METH-Gen": -12.96793242131051,
"GNMT1": 10.10408958511087,
"AHC1": 2.8990850478275423,
"MS1": -1.379618220208057,
"BHMT1": 26.690537277753894,
"CBS1": -13.317557876507342,
"MTHFR1": 36.286045450659394,
"PROT1": 2.9239025510653605
},
"log_drain": {
"the_drain": 77.11394277227623
},
"log_km": {
"met-L MAT1": 6.190853781501735,
"atp MAT1": 5.292693046521437,
"met-L MAT3": 2.4814738123803086,
"atp MAT3": 1.705092962802143,
"amet METH-Gen": 6.601750014600864,
"amet GNMT1": 4.590545886354766,
"gly GNMT1": 4.820977250097301,
"ahcys AHC1": 11.683829123279876,
"hcys-L AHC1": -2.9137920929850405,
"adn AHC1": -1.769215504424688,
"hcys-L MS1": 5.240078482919715,
"5mthf MS1": 14.652528429601952,
"hcys-L BHMT1": -10.22016805310029,
"glyb BHMT1": -10.688987009618199,
"hcys-L CBS1": 25.82504526708683,
"ser-L CBS1": 0.049935152823941564,
"mlthf MTHFR1": -20.851184438848044,
"nadph MTHFR1": -2.4216899858227414,
"met-L PROT1": 5.971084568691083
},
"dgf": {
"met-L": 0.0,
"atp": 0.0,
"pi": 0.0,
"ppi": 0.0,
"amet": 0.0,
"ahcys": -1.3226601224426768,
"gly": 0.0,
"sarcs": 0.0,
"hcys-L": 1.3226601224426768,
"adn": 1.3226601224426768,
"thf": 0.0,
"5mthf": 0.0,
"mlthf": 0.0,
"glyb": 0.0,
"dmgly": 0.0,
"ser-L": 0.0,
"nadp": 0.0,
"nadph": 0.0,
"cyst-L": 0.0
},
"log_ki": {
"MAT1": -0.3185780584014895,
"METH-Gen": -0.24799879824208654,
"GNMT1": 0.0018358603202088403
},
"log_conc_unbalanced": {
"atp": 7.654742420277419,
"pi": 14.652528429601,
"ppi": 14.652528429601,
"gly": 9.831551179503698,
"sarcs": 14.652528429601988,
"adn": 19.700379108345924,
"thf": 14.652528429601986,
"mlthf": 35.57574270245104,
"glyb": 25.3415154392192,
"dmgly": 14.652528429601988,
"ser-L": 14.60259327677706,
"nadp": 14.652528429601988,
"nadph": 17.074218415424728,
"cyst-L": 14.652528429601988
},
"log_transfer_constant": {
"METAT": 0.07549016742205726,
"GNMT": 1.3246972471203298,
"CBS": 0.00007813277693021312,
"MTHFR": -0.3918664031389874
},
"log_dissociation_constant": {
"met-L MAT3": 0.009116267028546848,
"amet MAT3": 0.012649140053254744,
"amet GNMT1": 3.233378796396484,
"mlthf GNMT1": -0.0720298340010019,
"amet CBS1": 0.0,
"amet MTHFR1": 0.5327826505992013,
"ahcys MTHFR1": -0.0637729300752556
}
}
110 changes: 110 additions & 0 deletions tests/test_lp_grad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import json
import jax
import pytest
from jax import numpy as jnp

from enzax.examples import methionine
from enzax.mcmc import (
ObservationSet,
AllostericMichaelisMentenPriorSet,
ind_prior_from_truth,
posterior_logdensity_amm,
)
from enzax.steady_state import get_kinetic_model_steady_state

import importlib.resources
from tests import data

import functools

jax.config.update("jax_enable_x64", True)
SEED = 1234

methionine_pldf_grad_file = (
importlib.resources.files(data) / "methionine_pldf_grad.json"
)


def test_lp_grad():
model = methionine
structure = methionine.structure
rate_equations = methionine.rate_equations
true_parameters = methionine.parameters
true_model = methionine.model
default_state_guess = jnp.full((5,), 0.01)
true_states = get_kinetic_model_steady_state(
true_model, default_state_guess
)
prior = AllostericMichaelisMentenPriorSet(
log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1),
log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1),
log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1),
dgf=(
ind_prior_from_truth(true_parameters.dgf, 0.1)[0],
jnp.diag(
jnp.square(ind_prior_from_truth(true_parameters.dgf, 0.1)[1])
),
),
log_km=ind_prior_from_truth(true_parameters.log_km, 0.1),
log_conc_unbalanced=ind_prior_from_truth(
true_parameters.log_conc_unbalanced, 0.1
),
temperature=ind_prior_from_truth(true_parameters.temperature, 0.1),
log_ki=ind_prior_from_truth(true_parameters.log_ki, 0.1),
log_transfer_constant=ind_prior_from_truth(
true_parameters.log_transfer_constant, 0.1
),
log_dissociation_constant=ind_prior_from_truth(
true_parameters.log_dissociation_constant, 0.1
),
)
# get true concentration
true_conc = jnp.zeros(methionine.structure.S.shape[0])
true_conc = true_conc.at[methionine.structure.balanced_species].set(
true_states
)
true_conc = true_conc.at[methionine.structure.unbalanced_species].set(
jnp.exp(true_parameters.log_conc_unbalanced)
)
# get true flux
true_flux = true_model.flux(true_states)
# simulate observations
error_conc = 0.03
error_flux = 0.05
error_enzyme = 0.03
key = jax.random.key(SEED)
obs_conc = jnp.exp(jnp.log(true_conc) + jax.random.normal(key) * error_conc)
obs_enzyme = jnp.exp(
true_parameters.log_enzyme + jax.random.normal(key) * error_enzyme
)
obs_flux = true_flux + jax.random.normal(key) * error_conc
obs = ObservationSet(
conc=obs_conc,
flux=obs_flux,
enzyme=obs_enzyme,
conc_scale=error_conc,
flux_scale=error_flux,
enzyme_scale=error_enzyme,
)
pldf = functools.partial(
posterior_logdensity_amm,
obs=obs,
prior=prior,
structure=structure,
rate_equations=rate_equations,
guess=default_state_guess,
)
pldf_grad = jax.jacrev(pldf)(methionine.parameters)
index_pldf_grad = {
p: {
c: float(getattr(pldf_grad, p)[i])
for i, c in enumerate(model.coords[model.dims[p][0]])
}
for p in model.dims.keys()
}
with open(methionine_pldf_grad_file, "r") as file:
saved_pldf_grad = file.read()

true_gradient = json.loads(saved_pldf_grad)
for p, vals in true_gradient.items():
assert true_gradient[p] == pytest.approx(index_pldf_grad[p])

0 comments on commit 4d2ca4f

Please sign in to comment.