Skip to content

Commit

Permalink
get mcmc demo to work
Browse files Browse the repository at this point in the history
  • Loading branch information
teddygroves committed Nov 29, 2024
1 parent 6b8b627 commit 0da0dd3
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 108 deletions.
118 changes: 60 additions & 58 deletions scripts/mcmc_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
import arviz as az
import jax
from jax import numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.stats import norm
from jaxtyping import Array

from enzax.examples import methionine
from enzax.kinetic_model import RateEquationModel, get_conc
from enzax.mcmc import (
ObservationSet,
AllostericMichaelisMentenPriorSet,
get_idata,
ind_prior_from_truth,
posterior_logdensity_amm,
run_nuts,
)
from enzax.steady_state import get_kinetic_model_steady_state
Expand All @@ -24,59 +25,55 @@
jax.config.update("jax_enable_x64", True)


def joint_log_density(params, prior_mean, prior_sd, obs, guess):
    """Return the joint log density (log prior + log likelihood) of ``params``.

    The methionine model is solved for its steady state starting from
    ``guess``; the resulting concentrations and fluxes are then compared
    with the measurements in ``obs``.

    :param params: parameter pytree; must expose ``log_conc_unbalanced`` and
        ``log_enzyme`` and be accepted by ``RateEquationModel``.
    :param prior_mean: location of the independent normal prior placed on
        the flattened parameters.
    :param prior_sd: scale of that normal prior.
    :param obs: observations exposing ``conc``, ``enzyme``, ``flux`` and the
        matching ``conc_scale``, ``enzyme_scale``, ``flux_scale``.
    :param guess: starting point handed to the steady state solver.
    :return: scalar sum of the log prior and the log likelihood.
    """
    structure = methionine.structure

    # Solve the model, then derive what would be measured at steady state.
    model = RateEquationModel(params, structure)
    steady = get_kinetic_model_steady_state(model, guess)
    conc = get_conc(steady, params.log_conc_unbalanced, structure)
    flux = model.flux(steady)

    # Prior: independent normal on every entry of the flattened parameters.
    flat_params = ravel_pytree(params)[0]
    log_prior = norm.logpdf(flat_params, loc=prior_mean, scale=prior_sd).sum()

    # Likelihood: concentrations and enzymes are compared on the log scale,
    # fluxes on the natural scale.
    flat_log_enzyme = ravel_pytree(params.log_enzyme)[0]
    lp_conc = norm.logpdf(
        jnp.log(obs.conc), jnp.log(conc), obs.conc_scale
    ).sum()
    lp_enzyme = norm.logpdf(
        jnp.log(obs.enzyme), flat_log_enzyme, obs.enzyme_scale
    ).sum()
    lp_flux = norm.logpdf(obs.flux, flux, obs.flux_scale).sum()
    log_likelihood = lp_conc + lp_enzyme + lp_flux

    return log_prior + log_likelihood


def main():
"""Demonstrate How to make a Bayesian kinetic model with enzax."""
structure = methionine.structure
rate_equations = methionine.rate_equations
true_parameters = methionine.parameters
true_model = methionine.model
default_state_guess = jnp.full((5,), 0.01)
true_states = get_kinetic_model_steady_state(
true_model, default_state_guess
)
prior = AllostericMichaelisMentenPriorSet(
log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1),
log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1),
log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1),
dgf=(
ind_prior_from_truth(true_parameters.dgf, 0.1)[0],
jnp.diag(
jnp.square(ind_prior_from_truth(true_parameters.dgf, 0.1)[1])
),
),
log_km=ind_prior_from_truth(true_parameters.log_km, 0.1),
log_conc_unbalanced=ind_prior_from_truth(
true_parameters.log_conc_unbalanced, 0.1
),
temperature=ind_prior_from_truth(true_parameters.temperature, 0.1),
log_ki=ind_prior_from_truth(true_parameters.log_ki, 0.1),
log_transfer_constant=ind_prior_from_truth(
true_parameters.log_transfer_constant, 0.1
),
log_dissociation_constant=ind_prior_from_truth(
true_parameters.log_dissociation_constant, 0.1
),
)
default_guess = jnp.full((5,), 0.01)
true_steady = get_kinetic_model_steady_state(true_model, default_guess)
# get true concentration
true_conc = jnp.zeros(methionine.structure.S.shape[0])
true_conc = true_conc.at[methionine.structure.balanced_species].set(
true_states
)
true_conc = true_conc.at[methionine.structure.unbalanced_species].set(
jnp.exp(true_parameters.log_conc_unbalanced)
true_conc = get_conc(
true_steady,
true_parameters.log_conc_unbalanced,
methionine.structure,
)
# get true flux
true_flux = true_model.flux(true_states)
true_flux = true_model.flux(true_steady)
# simulate observations
error_conc = 0.03
error_flux = 0.05
error_enzyme = 0.03
key = jax.random.key(SEED)
obs_conc = jnp.exp(jnp.log(true_conc) + jax.random.normal(key) * error_conc)
true_log_enz_flat, _ = ravel_pytree(true_parameters.log_enzyme)
key_conc, key_enz, key_flux, key_nuts = jax.random.split(key, num=4)
obs_conc = jnp.exp(
jnp.log(true_conc) + jax.random.normal(key_conc) * error_conc
)
obs_enzyme = jnp.exp(
true_parameters.log_enzyme + jax.random.normal(key) * error_enzyme
true_log_enz_flat + jax.random.normal(key_enz) * error_enzyme
)
obs_flux = true_flux + jax.random.normal(key) * error_conc
obs_flux = true_flux + jax.random.normal(key_flux) * error_conc
obs = ObservationSet(
conc=obs_conc,
flux=obs_flux,
Expand All @@ -85,17 +82,19 @@ def main():
flux_scale=error_flux,
enzyme_scale=error_enzyme,
)
pldf = functools.partial(
posterior_logdensity_amm,
obs=obs,
prior=prior,
structure=structure,
rate_equations=rate_equations,
guess=default_state_guess,
flat_true_params, _ = ravel_pytree(true_parameters)
posterior_log_density = jax.jit(
functools.partial(
joint_log_density,
obs=obs,
prior_mean=flat_true_params,
prior_sd=0.1,
guess=default_guess,
)
)
samples, info = run_nuts(
pldf,
key,
posterior_log_density,
key_nuts,
true_parameters,
num_warmup=200,
num_samples=200,
Expand All @@ -104,9 +103,7 @@ def main():
is_mass_matrix_diagonal=False,
target_acceptance_rate=0.95,
)
idata = get_idata(
samples, info, coords=methionine.coords, dims=methionine.dims
)
idata = get_idata(samples, info)
print(az.summary(idata))
if jnp.any(info.is_divergent):
n_divergent = info.is_divergent.sum()
Expand All @@ -117,10 +114,15 @@ def main():
print("True parameter values vs posterior:")
for param in true_parameters.__dataclass_fields__.keys():
true_val = getattr(true_parameters, param)
model_low = jnp.quantile(getattr(samples.position, param), 0.01, axis=0)
model_high = jnp.quantile(
getattr(samples.position, param), 0.99, axis=0
)
model_p = getattr(samples.position, param)
if isinstance(true_val, Array):
model_low = jnp.quantile(model_p, 0.01, axis=0)
model_high = jnp.quantile(model_p, 0.99, axis=0)
elif isinstance(true_val, dict):
model_low, model_high = (
{k: jnp.quantile(v, q, axis=0) for k, v in model_p.items()}
for q in (0.01, 0.99)
)
print(f" {param}:")
print(f" true value: {true_val}")
print(f" posterior 1%: {model_low}")
Expand Down
2 changes: 1 addition & 1 deletion scripts/steady_state_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_steady_state_from_params(parameters: PyTree):
print(f"\tSteady state concentration: {conc_steady}")
print(f"\tFlux: {flux}")
print(f"\tSv: {sv}")
print(f"\tLog Km Jacobian: {jac.log_km}")
print(f"\tLog substrate Km Jacobian: {jac.log_substrate_km}")
print(f"\tDgf Jacobian: {jac.dgf}")


Expand Down
16 changes: 11 additions & 5 deletions src/enzax/examples/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@


class ParameterDefinition(eqx.Module):
log_km: dict[int, Array]
log_substrate_km: dict[int, Array]
log_product_km: dict[int, Array]
log_kcat: dict[int, Scalar]
log_enzyme: dict[int, Array]
log_ki: dict[int, Array]
Expand Down Expand Up @@ -50,10 +51,15 @@ class ParameterDefinition(eqx.Module):
rate_equations=rate_equations,
)
parameters = ParameterDefinition(
log_km={
0: jnp.array([[0.1], [-0.2]]),
1: jnp.array([[0.5], [0.0]]),
2: jnp.array([[-1.0], [0.5]]),
log_substrate_km={
0: jnp.array([0.1]),
1: jnp.array([0.5]),
2: jnp.array([-1.0]),
},
log_product_km={
0: jnp.array([-0.2]),
1: jnp.array([0.0]),
2: jnp.array([0.5]),
},
log_kcat={0: jnp.array(-0.1), 1: jnp.array(0.0), 2: jnp.array(0.1)},
dgf=jnp.array([-3.0, -1.0]),
Expand Down
43 changes: 16 additions & 27 deletions src/enzax/examples/methionine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@


class ParameterDefinition(eqx.Module):
log_km: dict[int, list[Array]]
log_substrate_km: dict[int, Array]
log_product_km: dict[int, Array]
log_kcat: dict[int, Scalar]
log_enzyme: dict[int, Array]
log_ki: dict[int, Array]
Expand Down Expand Up @@ -151,32 +152,20 @@ class ParameterDefinition(eqx.Module):
-46.4737, # cyst-L
]
),
log_km={
1: [jnp.log(jnp.array([0.000106919, 0.00203015]))], # MAT1 met-L, atp
2: [jnp.log(jnp.array([0.00113258, 0.00236759]))], # MAT3 met-L atp
3: [jnp.log(jnp.array([9.37e-06]))], # amet METH-Gen
4: [
jnp.log(jnp.array([0.000520015, 0.00253545]))
], # amet GNMT1, # gly GNMT1
5: [
jnp.log(jnp.array([2.32e-05])), # ahcys AHC1
jnp.log(
jnp.array([1.06e-05, 5.66e-06])
), # hcys-L AHC1, # adn AHC1
],
6: [
jnp.log(jnp.array([1.71e-06, 6.94e-05]))
], # hcys-L MS1, # 5mthf MS1
7: [
jnp.log(jnp.array([1.98e-05, 0.00845898]))
], # hcys-L BHMT1, # glyb BHMT1
8: [
jnp.log(jnp.array([4.24e-05, 2.83e-06]))
], # hcys-L CBS1, # ser-L CBS1
9: [
jnp.log(jnp.array([8.08e-05, 2.09e-05]))
], # mlthf MTHFR1, # nadph MTHFR1
10: [jnp.log(jnp.array([4.39e-05]))], # met-L PROT1
log_product_km={
5: jnp.log(jnp.array([1.06e-05, 5.66e-06])), # hcys-L AHC1, adn AHC1
},
log_substrate_km={
1: jnp.log(jnp.array([0.000106919, 0.00203015])), # MAT1 met-L, atp
2: jnp.log(jnp.array([0.00113258, 0.00236759])), # MAT3 met-L atp
3: jnp.log(jnp.array([9.37e-06])), # METH-Gen amet
4: jnp.log(jnp.array([0.000520015, 0.00253545])), # GNMT1, amet, gly
5: jnp.log(jnp.array([2.32e-05])), # ahcys AHC1
6: jnp.log(jnp.array([1.71e-06, 6.94e-05])), # MS1 hcys-L, 5mthf
7: jnp.log(jnp.array([1.98e-05, 0.00845898])), # BHMT1 hcys-L, glyb
8: jnp.log(jnp.array([4.24e-05, 2.83e-06])), # CBS1 hcys-L, ser-L
9: jnp.log(jnp.array([8.08e-05, 2.09e-05])), # MTHFR1 mlthf, nadph
10: jnp.log(jnp.array([4.39e-05])), # PROT1 met-L
},
temperature=jnp.array(298.15),
log_ki={
Expand Down
15 changes: 11 additions & 4 deletions src/enzax/kinetic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
from enzax.rate_equation import RateEquation


def get_conc(balanced, log_unbalanced, structure):
    """Assemble a concentration vector from balanced and unbalanced parts.

    :param balanced: concentrations written at
        ``structure.balanced_species_ix``.
    :param log_unbalanced: log concentrations; their exponential is written
        at ``structure.unbalanced_species_ix``.
    :param structure: object exposing ``S`` (only ``S.shape[0]`` is read),
        ``balanced_species_ix`` and ``unbalanced_species_ix``.
    :return: array of length ``structure.S.shape[0]``; positions covered by
        neither index set stay at zero.
    """
    n_species = structure.S.shape[0]
    unbalanced = jnp.exp(log_unbalanced)
    return (
        jnp.zeros(n_species)
        .at[structure.balanced_species_ix]
        .set(balanced)
        .at[structure.unbalanced_species_ix]
        .set(unbalanced)
    )


@jaxtyped(typechecker=typechecked)
@register_pytree_node_class
class KineticModelStructure:
Expand Down Expand Up @@ -143,10 +150,10 @@ def flux(
:return: a one dimensional array of (possibly negative) floats representing reaction fluxes. Has same size as number of columns of self.structure.S.
""" # Noqa: E501
conc = jnp.zeros(self.structure.S.shape[0])
conc = conc.at[self.structure.balanced_species_ix].set(conc_balanced)
conc = conc.at[self.structure.unbalanced_species_ix].set(
jnp.exp(self.parameters.log_conc_unbalanced)
conc = get_conc(
conc_balanced,
self.parameters.log_conc_unbalanced,
self.structure,
)
flux_list = []
for i, rate_equation in enumerate(self.structure.rate_equations):
Expand Down
15 changes: 11 additions & 4 deletions src/enzax/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,17 @@ def ind_prior_from_truth(truth: Float[Array, " _"], sd: ScalarLike):

def get_idata(samples, info, coords=None, dims=None) -> az.InferenceData:
"""Get an arviz InferenceData from a blackjax NUTS output."""
sample_dict = {
k: jnp.expand_dims(getattr(samples.position, k), 0)
for k in samples.position.__dataclass_fields__.keys()
}
if coords is None:
coords = dict()
sample_dict = dict()
for k in samples.position.__dataclass_fields__.keys():
samples_k = getattr(samples.position, k)
if isinstance(samples_k, Array):
sample_dict[k] = jnp.expand_dims(samples_k, 0)
elif isinstance(samples_k, dict):
sample_dict[k] = jnp.expand_dims(
jnp.concat([v.T for v in samples_k.values()]).T, 0
)
posterior = az.convert_to_inference_data(
sample_dict,
group="posterior",
Expand Down
6 changes: 3 additions & 3 deletions src/enzax/rate_equations/generalised_mwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_allosteric_irreversible_michaelis_menten_input(
kcat=jnp.exp(parameters.log_kcat[rxn_ix]),
enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]),
ix_substrate=ix_substrate,
substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]),
substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]),
substrate_stoichiometry=Sj[ix_substrate],
ix_ki_species=ci_ix,
ki=jnp.exp(parameters.log_ki[rxn_ix]),
Expand All @@ -66,8 +66,8 @@ def get_allosteric_reversible_michaelis_menten_input(
return AllostericReversibleMichaelisMentenInput(
kcat=jnp.exp(parameters.log_kcat[rxn_ix]),
enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]),
substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]),
product_kms=jnp.exp(parameters.log_km[rxn_ix][1]),
substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]),
product_kms=jnp.exp(parameters.log_product_km[rxn_ix]),
ki=jnp.exp(parameters.log_ki[rxn_ix]),
dgf=parameters.dgf[species_to_dgf_ix][ix_reactant],
temperature=parameters.temperature,
Expand Down
6 changes: 3 additions & 3 deletions src/enzax/rate_equations/michaelis_menten.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_irreversible_michaelis_menten_input(
kcat=jnp.exp(parameters.log_kcat[rxn_ix]),
enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]),
ix_substrate=ix_substrate,
substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]),
substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]),
substrate_stoichiometry=Sj[ix_substrate],
ix_ki_species=ci_ix,
ki=jnp.exp(parameters.log_ki[rxn_ix]),
Expand Down Expand Up @@ -70,8 +70,8 @@ def get_reversible_michaelis_menten_input(
return ReversibleMichaelisMentenInput(
kcat=jnp.exp(parameters.log_kcat[rxn_ix]),
enzyme=jnp.exp(parameters.log_enzyme[rxn_ix]),
substrate_kms=jnp.exp(parameters.log_km[rxn_ix][0]),
product_kms=jnp.exp(parameters.log_km[rxn_ix][1]),
substrate_kms=jnp.exp(parameters.log_substrate_km[rxn_ix]),
product_kms=jnp.exp(parameters.log_product_km[rxn_ix]),
ki=jnp.exp(parameters.log_ki[rxn_ix]),
dgf=parameters.dgf[species_to_dgf_ix][ix_reactant],
temperature=parameters.temperature,
Expand Down
2 changes: 1 addition & 1 deletion tests/data/expected_methionine_gradient.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
[-27.029305319714087, -23.107933540053693, -10.831961241573868, -7.442956276058797, -28.82230903289454, -20.04100693766347, -21.046992055692126, -51.008806318357124, 12.720942038333074, 7.723972171145526, -22.87688514741915, -63.96941462805278, 44.62125359725783, 46.668133361083704, -112.7484826050269, -0.21853936474188224, 91.03560765315373, 10.573024526064392, -26.068355084866106, 37.17973577495537, 11.212172526568366, 120.58711257251609, 19.857181701644674, 51.31272574643639, 69.99249085550201, -52.55796473716283, 122.11340482806465, -94.45124667621965, 51.204443757792795, 56.794366392943715, 28.218282451377277, 137.9773162976142, 27.62149868641839, 32.52255474993412, 99.44496349751932, 1.065834388161285, 160.30549966124383, -20.25453860618623, 17.98290589428491, 1.390914925131441, 1.0827290721714595, -0.008014753935453442, 0.0, 0.0, 0.0, 0.0, 0.0, 5.774425772898505, 0.0, 0.0, -5.774425772898505, -5.774425772898505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.455892028402676, -33.418521957445726, -63.96941505648969, -63.96941589328896, -42.922426803111094, -63.96941505649067, -86.00714492908733, -63.9694153507284, -155.31948281663819, -110.63755110207094, -63.96941623344747, -63.75087668807969, -63.969417159576764, -74.54244120146943, -63.969415350730365, 0.3144610971331224, 0.00030317394151436246, -2.3261150730556386, -0.03979371864763597, -0.055215199560989005, -14.115986977104384, 0.2784311817245577, -0.3295244858681964, -5.783239063037292, -0.0006063958903968876, 1.710877943445662, -336.4591824327092]
[-27.029305319714087, -23.107933540053693, -10.831961241573868, -7.442956276058797, -28.82230903289454, -20.04100693766347, -21.046992055692126, -51.008806318357124, -22.87688514741915, -63.96941462805278, 44.62125359725783, 46.668133361083704, -112.7484826050269, -0.21853936474188224, 91.03560765315373, 10.573024526064392, -26.068355084866106, 12.720942038333074, 7.723972171145526, 37.17973577495537, 11.212172526568366, 120.58711257251609, 19.857181701644674, 51.31272574643639, 69.99249085550201, -52.55796473716283, 122.11340482806465, -94.45124667621965, 51.204443757792795, 56.794366392943715, 28.218282451377277, 137.9773162976142, 27.62149868641839, 32.52255474993412, 99.44496349751932, 1.065834388161285, 160.30549966124383, -20.25453860618623, 17.98290589428491, 1.390914925131441, 1.0827290721714595, -0.008014753935453442, 0.0, 0.0, 0.0, 0.0, 0.0, 5.774425772898505, 0.0, 0.0, -5.774425772898505, -5.774425772898505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.455892028402676, -33.418521957445726, -63.96941505648969, -63.96941589328896, -42.922426803111094, -63.96941505649067, -86.00714492908733, -63.9694153507284, -155.31948281663819, -110.63755110207094, -63.96941623344747, -63.75087668807969, -63.969417159576764, -74.54244120146943, -63.969415350730365, 0.3144610971331224, 0.00030317394151436246, -2.3261150730556386, -0.03979371864763597, -0.055215199560989005, -14.115986977104384, 0.2784311817245577, -0.3295244858681964, -5.783239063037292, -0.0006063958903968876, 1.710877943445662, -336.4591824327092]
6 changes: 4 additions & 2 deletions tests/test_rate_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@


class ExampleParameterSet(eqx.Module):
log_km: dict[int, Array]
log_substrate_km: dict[int, Array]
log_product_km: dict[int, Array]
log_kcat: dict[int, Scalar]
log_enzyme: dict[int, Array]
log_ki: dict[int, Array]
Expand All @@ -28,7 +29,8 @@ class ExampleParameterSet(eqx.Module):
EXAMPLE_S = np.array([[-1], [1], [0]], dtype=np.float64)
EXAMPLE_CONC = jnp.array([0.5, 0.2, 0.1])
EXAMPLE_PARAMETERS = ExampleParameterSet(
log_km={0: jnp.array([[0.1], [-0.2]])},
log_substrate_km={0: jnp.array([0.1])},
log_product_km={0: jnp.array([-0.2])},
log_kcat={0: jnp.array(-0.1)},
dgf=jnp.array([-3.0, 1.0]),
log_ki={0: jnp.array([1.0])},
Expand Down

0 comments on commit 0da0dd3

Please sign in to comment.