Skip to content

Commit

Permalink
update documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
teddygroves committed Sep 8, 2024
1 parent a5d0660 commit 8d2f5f9
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 14 deletions.
22 changes: 14 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,33 @@ pip install enzax

```python
from enzax.examples import methionine
from enzax.steady_state import solve_steady_state
from enzax.steady_state import get_kinetic_model_steady_state
from jax import numpy as jnp

guess = jnp.full((5,) 0.01)

steady_state = solve_steady_state(
methionine.parameters, methionine.unparameterised_model, guess
)
steady_state = get_kinetic_model_steady_state(methionine.model, guess)
```

### Find a steady state's Jacobian with respect to all parameters

```python
import jax
from enzax.examples import methionine
from enzax.steady_state import solve_steady_state
from enzax.steady_state import get_kinetic_model_steady_state
from jax import numpy as jnp
from jaxtyping import PyTree

guess = jnp.full((5,) 0.01)
model = methionine.model

def get_steady_state_from_params(parameters: PyTree):
"""Get the steady state with a one-argument non-pure function."""
_model = RateEquationModel(
parameters, model.structure, model.rate_equations
)
return get_kinetic_model_steady_state(_model, guess)

jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters)

jacobian = jax.jacrev(solve_steady_state)(
methionine.parameters, methionine.unparameterised_model, guess
)
```
13 changes: 13 additions & 0 deletions docs/api/mcmc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# ::: enzax.mcmc
options:
show_root_heading: true
filters:
- "!check"
members:
- ObservationSet
- PriorSet
- run_nuts
- posterior_logdensity_amm
- get_idata
- ind_prior_from_truth
- ind_normal_prior_logdensity
7 changes: 7 additions & 0 deletions docs/api/steady_state.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# ::: enzax.steady_state
options:
show_root_heading: true
filters:
- "!check"
members:
- get_kinetic_model_steady_state
15 changes: 13 additions & 2 deletions docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,21 @@ guess = jnp.full((5,) 0.01)
steady_state = get_kinetic_model_steady_state(methionine.model, guess)
```

To find the Jacobian of this steady state with respect to the model's parameters, we can wrap `get_kinetic_model_steady_state` in JAX's [`jacrev`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html) function:
To access the Jacobian of this steady state with respect to the model's parameters, we can wrap `get_kinetic_model_steady_state` in a function that has a set of parameters as its only argument, then use JAX's [`jacrev`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html) function:

```python
import jax
from jaxtyping import PyTree

jacobian = jax.jacrev(solve_steady_state)(methionine.model, guess)
guess = jnp.full((5,) 0.01)
model = methionine.model

def get_steady_state_from_params(parameters: PyTree):
"""Get the steady state with a one-argument non-pure function."""
_model = RateEquationModel(
parameters, model.structure, model.rate_equations
)
return get_kinetic_model_steady_state(_model, guess)

jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters)
```
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,5 @@ nav:
- API:
- 'api/kinetic_model.md'
- 'api/rate_equations.md'
- 'api/steady_state.md'
- 'api/mcmc.md'
10 changes: 7 additions & 3 deletions src/enzax/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@

@chex.dataclass
class ObservationSet:
"""Measurements from a single experiment."""

conc: Float[Array, " m"]
flux: ScalarLike
flux: Float[Array, " n"]
enzyme: Float[Array, " e"]
conc_scale: ScalarLike
flux_scale: ScalarLike
Expand All @@ -33,6 +35,8 @@ class ObservationSet:

@chex.dataclass
class AllostericMichaelisMentenPriorSet:
"""Priors for an allosteric Michaelis-Menten model."""

log_kcat: Float[Array, "2 n_enzyme"]
log_enzyme: Float[Array, "2 n_enzyme"]
log_drain: Float[Array, "2 n_drain"]
Expand Down Expand Up @@ -106,7 +110,7 @@ def posterior_logdensity_amm(


@functools.partial(jax.jit, static_argnames=["kernel", "num_samples"])
def inference_loop(rng_key, kernel, initial_state, num_samples):
def _inference_loop(rng_key, kernel, initial_state, num_samples):
"""Run MCMC with blackjax."""

def one_step(state, rng_key):
Expand Down Expand Up @@ -141,7 +145,7 @@ def run_nuts(
)
rng_key, sample_key = jax.random.split(rng_key)
nuts_kernel = blackjax.nuts(logdensity_fn, **tuned_parameters).step
states, info = inference_loop(
states, info = _inference_loop(
sample_key,
kernel=nuts_kernel,
initial_state=initial_state,
Expand Down
12 changes: 11 additions & 1 deletion src/enzax/steady_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,18 @@
@eqx.filter_jit()
def get_kinetic_model_steady_state(
model: KineticModel,
guess: Float[Array, " n"],
guess: Float[Array, " n_balanced"],
) -> PyTree:
"""Get the steady state of a kinetic model, using diffrax.
The better the guess (generally) the faster and more reliable the solving.
:param model: a KineticModel object
:param guess: a JAX array of floats. Must have the same length as the
model's number of balanced species.
"""
term = diffrax.ODETerm(model.dcdt)
solver = diffrax.Kvaerno5()
t0 = 0
Expand Down

0 comments on commit 8d2f5f9

Please sign in to comment.