Skip to content

Commit

Permalink
Merge pull request #21 from dtu-qmcm/rate_equation_refactor
Browse files Browse the repository at this point in the history
Make way for non-rate-equation based kinetic models
  • Loading branch information
teddygroves authored Sep 8, 2024
2 parents ab9c89f + 8d2f5f9 commit db56074
Show file tree
Hide file tree
Showing 26 changed files with 1,214 additions and 916 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
.DS_Store

# scratchpads
scratch.md

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
22 changes: 14 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,33 @@ pip install enzax

```python
from enzax.examples import methionine
from enzax.steady_state import solve_steady_state
from enzax.steady_state import get_kinetic_model_steady_state
from jax import numpy as jnp

guess = jnp.full((5,) 0.01)

steady_state = solve_steady_state(
methionine.parameters, methionine.unparameterised_model, guess
)
steady_state = get_kinetic_model_steady_state(methionine.model, guess)
```

### Find a steady state's Jacobian with respect to all parameters

```python
import jax
from enzax.examples import methionine
from enzax.steady_state import solve_steady_state
from enzax.steady_state import get_kinetic_model_steady_state
from jax import numpy as jnp
from jaxtyping import PyTree

guess = jnp.full((5,) 0.01)
model = methionine.model

def get_steady_state_from_params(parameters: PyTree):
"""Get the steady state with a one-argument non-pure function."""
_model = RateEquationModel(
parameters, model.structure, model.rate_equations
)
return get_kinetic_model_steady_state(_model, guess)

jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters)

jacobian = jax.jacrev(solve_steady_state)(
methionine.parameters, methionine.unparameterised_model, guess
)
```
5 changes: 2 additions & 3 deletions docs/api/kinetic_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
filters:
- "!check"
members:
- KineticModel
- UnparameterisedKineticModel
- KineticModelParameters
- KineticModelStructure
- KineticModel
- RateEquationModel
13 changes: 13 additions & 0 deletions docs/api/mcmc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# ::: enzax.mcmc
options:
show_root_heading: true
filters:
- "!check"
members:
- ObservationSet
- PriorSet
- run_nuts
- posterior_logdensity_amm
- get_idata
- ind_prior_from_truth
- ind_normal_prior_logdensity
7 changes: 7 additions & 0 deletions docs/api/steady_state.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# ::: enzax.steady_state
options:
show_root_heading: true
filters:
- "!check"
members:
- get_kinetic_model_steady_state
50 changes: 28 additions & 22 deletions docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ First we import some enzax classes:

```python
from enzax.kinetic_model import (
KineticModel,
KineticModelParameters,
KineticModelStructure,
UnparameterisedKineticModel,
RateEquationModel,
)
from enzax.rate_equations import (
AllostericReversibleMichaelisMenten,
Expand All @@ -49,7 +47,9 @@ structure = KineticModelStructure(
Next we provide some kinetic parameter values:

```python
parameters = KineticModelParameters(
from enzax.parameters import AllostericMichaelisMentenParameters

parameters = AllostericMichaelisMentenParameters(
log_kcat=jnp.array([-0.1, 0.0, 0.1]),
log_enzyme=jnp.log(jnp.array([0.3, 0.2, 0.1])),
dgf=jnp.array([-3, -1.0]),
Expand All @@ -65,6 +65,11 @@ parameters = KineticModelParameters(
Now we can use enzax's rate laws to specify how each reaction behaves:

```python
from enzax.rate_equations import (
AllostericReversibleMichaelisMenten,
ReversibleMichaelisMenten,
)

r0 = AllostericReversibleMichaelisMenten(
kcat_ix=0,
enzyme_ix=0,
Expand Down Expand Up @@ -130,16 +135,10 @@ r2 = ReversibleMichaelisMenten(
)
```

Next an unparameterised kinetic model
Now we can declare our model:

```python
unparameterised_model = UnparameterisedKineticModel(structure, [r0, r1, r2])
```

Finally a parameterised model:

```python
model = KineticModel(parameters, unparameterised_model)
model = RateEquationModel(structure, parameters, [r0, r1, r2])
```

To test out the model, we can see if it returns some fluxes and state variable rates when provided a set of balanced species concentrations:
Expand All @@ -157,28 +156,35 @@ dcdt

## Find a kinetic model's steady state

Enzax provides a few example kinetic models, including [`methionine`](https://github.com/dtu-qmcm/enzax/blob/main/src/enzax/examples/methionine.py), a model of the mammallian methionine cycle.
Enzax provides a few example kinetic models, including [`methionine`](https://github.com/dtu-qmcm/enzax/blob/main/src/enzax/examples/methionine.py), a model of the mammalian methionine cycle.

Here is how to find this model's steady state (and its parameter gradients) using enzax's `solve_steady_state` function:
Here is how to find this model's steady state (and its parameter gradients) using enzax's `get_kinetic_model_steady_state` function:

```python
from enzax.examples import methionine
from enzax.steady_state import solve_steady_state
from enzax.steady_state import get_kinetic_model_steady_state
from jax import numpy as jnp

guess = jnp.full((5,) 0.01)

steady_state = solve_steady_state(
methionine.parameters, methionine.unparameterised_model, guess
)
steady_state = get_kinetic_model_steady_state(methionine.model, guess)
```

To find the jacobian of this steady state with respect to the model's parameters, we can wrap `solve_steady_state` in JAX's [`jacrev`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html) function:
To access the Jacobian of this steady state with respect to the model's parameters, we can wrap `get_kinetic_model_steady_state` in a function that has a set of parameters as its only argument, then use JAX's [`jacrev`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html) function:

```python
import jax
from jaxtyping import PyTree

jacobian = jax.jacrev(solve_steady_state)(
methionine.parameters, methionine.unparameterised_model, guess
)
guess = jnp.full((5,) 0.01)
model = methionine.model

def get_steady_state_from_params(parameters: PyTree):
"""Get the steady state with a one-argument non-pure function."""
_model = RateEquationModel(
parameters, model.structure, model.rate_equations
)
return get_kinetic_model_steady_state(_model, guess)

jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters)
```
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,5 @@ nav:
- API:
- 'api/kinetic_model.md'
- 'api/rate_equations.md'
- 'api/steady_state.md'
- 'api/mcmc.md'
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ line-length = 80

[tool.ruff.lint]
ignore = ["F722"]
extend-select = ["E501"] # line length is checked

[tool.ruff.lint.isort]
known-first-party = ["enzax"]
126 changes: 126 additions & 0 deletions scripts/mcmc_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Demonstration of how to make a Bayesian kinetic model with enzax."""

import functools
import logging
import warnings

import arviz as az
import jax
from jax import numpy as jnp

from enzax.examples import methionine
from enzax.mcmc import (
ObservationSet,
AllostericMichaelisMentenPriorSet,
get_idata,
ind_prior_from_truth,
posterior_logdensity_amm,
run_nuts,
)
from enzax.steady_state import get_kinetic_model_steady_state

SEED = 1234

jax.config.update("jax_enable_x64", True)


def main():
"""Demonstrate How to make a Bayesian kinetic model with enzax."""
structure = methionine.structure
rate_equations = methionine.rate_equations
true_parameters = methionine.parameters
true_model = methionine.model
default_state_guess = jnp.full((5,), 0.01)
true_states = get_kinetic_model_steady_state(
true_model, default_state_guess
)
prior = AllostericMichaelisMentenPriorSet(
log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1),
log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1),
log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1),
dgf=ind_prior_from_truth(true_parameters.dgf, 0.1),
log_km=ind_prior_from_truth(true_parameters.log_km, 0.1),
log_conc_unbalanced=ind_prior_from_truth(
true_parameters.log_conc_unbalanced, 0.1
),
temperature=ind_prior_from_truth(true_parameters.temperature, 0.1),
log_ki=ind_prior_from_truth(true_parameters.log_ki, 0.1),
log_transfer_constant=ind_prior_from_truth(
true_parameters.log_transfer_constant, 0.1
),
log_dissociation_constant=ind_prior_from_truth(
true_parameters.log_dissociation_constant, 0.1
),
)
# get true concentration
true_conc = jnp.zeros(methionine.structure.S.shape[0])
true_conc = true_conc.at[methionine.structure.balanced_species].set(
true_states
)
true_conc = true_conc.at[methionine.structure.unbalanced_species].set(
jnp.exp(true_parameters.log_conc_unbalanced)
)
# get true flux
true_flux = true_model.flux(true_states)
# simulate observations
error_conc = 0.03
error_flux = 0.05
error_enzyme = 0.03
key = jax.random.key(SEED)
obs_conc = jnp.exp(jnp.log(true_conc) + jax.random.normal(key) * error_conc)
obs_enzyme = jnp.exp(
true_parameters.log_enzyme + jax.random.normal(key) * error_enzyme
)
obs_flux = true_flux + jax.random.normal(key) * error_conc
obs = ObservationSet(
conc=obs_conc,
flux=obs_flux,
enzyme=obs_enzyme,
conc_scale=error_conc,
flux_scale=error_flux,
enzyme_scale=error_enzyme,
)
pldf = functools.partial(
posterior_logdensity_amm,
obs=obs,
prior=prior,
structure=structure,
rate_equations=rate_equations,
guess=default_state_guess,
)
samples, info = run_nuts(
pldf,
key,
true_parameters,
num_warmup=200,
num_samples=200,
initial_step_size=0.0001,
max_num_doublings=10,
is_mass_matrix_diagonal=False,
target_acceptance_rate=0.95,
)
idata = get_idata(
samples, info, coords=methionine.coords, dims=methionine.dims
)
print(az.summary(idata))
if jnp.any(info.is_divergent):
n_divergent = info.is_divergent.sum()
msg = f"There were {n_divergent} post-warmup divergent transitions."
warnings.warn(msg)
else:
logging.info("No post-warmup divergent transitions!")
print("True parameter values vs posterior:")
for param in true_parameters.__dataclass_fields__.keys():
true_val = getattr(true_parameters, param)
model_low = jnp.quantile(getattr(samples.position, param), 0.01, axis=0)
model_high = jnp.quantile(
getattr(samples.position, param), 0.99, axis=0
)
print(f" {param}:")
print(f" true value: {true_val}")
print(f" posterior 1%: {model_low}")
print(f" posterior 99%: {model_high}")


if __name__ == "__main__":
main()
62 changes: 62 additions & 0 deletions scripts/steady_state_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Demonstration of how to find a steady state and its gradients with enzax."""

import time
from enzax.kinetic_model import RateEquationModel

import jax
from jax import numpy as jnp

from enzax.examples import methionine
from enzax.steady_state import get_kinetic_model_steady_state
from jaxtyping import PyTree

BAD_GUESS = jnp.full((5,), 0.01)
GOOD_GUESS = jnp.array(
[
4.233000e-05, # met-L
3.099670e-05, # amet
2.170170e-07, # ahcys
3.521780e-06, # hcys
6.534400e-06, # 5mthf
]
)


def main():
"""Function for testing the steady state solver."""
model = methionine.model
# compare good and bad guess
for guess in [BAD_GUESS, GOOD_GUESS]:

def get_steady_state_from_params(parameters: PyTree):
"""Get the steady state from just parameters.
This lets us get the Jacobian wrt (just) the parameters.
"""
_model = RateEquationModel(
parameters, model.structure, model.rate_equations
)
return get_kinetic_model_steady_state(_model, guess)

# solve once for jitting
get_kinetic_model_steady_state(model, GOOD_GUESS)
jax.jacrev(get_steady_state_from_params)(model.parameters)
# timer on
start = time.time()
conc_steady = get_kinetic_model_steady_state(model, guess)
jac = jax.jacrev(get_steady_state_from_params)(model.parameters)
# timer off
runtime = (time.time() - start) * 1e3
sv = model.dcdt(jnp.array(0.0), conc=conc_steady)
flux = model.flux(conc_steady)
print(f"Results with starting guess {guess}:")
print(f"\tRun time in milliseconds: {round(runtime, 4)}")
print(f"\tSteady state concentration: {conc_steady}")
print(f"\tFlux: {flux}")
print(f"\tSv: {sv}")
print(f"\tLog Km Jacobian: {jac.log_km}")
print(f"\tDgf Jacobian: {jac.dgf}")


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions src/enzax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from jax import config

config.update("jax_enable_x64", True)
Loading

0 comments on commit db56074

Please sign in to comment.