diff --git a/README.md b/README.md index d694f94..fdc612f 100644 --- a/README.md +++ b/README.md @@ -21,14 +21,12 @@ pip install enzax ```python from enzax.examples import methionine -from enzax.steady_state import solve_steady_state +from enzax.steady_state import get_kinetic_model_steady_state from jax import numpy as jnp guess = jnp.full((5,) 0.01) -steady_state = solve_steady_state( - methionine.parameters, methionine.unparameterised_model, guess -) +steady_state = get_kinetic_model_steady_state(methionine.model, guess) ``` ### Find a steady state's Jacobian with respect to all parameters @@ -36,12 +34,20 @@ steady_state = solve_steady_state( ```python import jax from enzax.examples import methionine -from enzax.steady_state import solve_steady_state +from enzax.steady_state import get_kinetic_model_steady_state from jax import numpy as jnp +from jaxtyping import PyTree guess = jnp.full((5,) 0.01) +model = methionine.model + +def get_steady_state_from_params(parameters: PyTree): + """Get the steady state with a one-argument non-pure function.""" + _model = RateEquationModel( + parameters, model.structure, model.rate_equations + ) + return get_kinetic_model_steady_state(_model, guess) + +jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters) -jacobian = jax.jacrev(solve_steady_state)( - methionine.parameters, methionine.unparameterised_model, guess -) ``` diff --git a/docs/api/mcmc.md b/docs/api/mcmc.md new file mode 100644 index 0000000..9fc8b8c --- /dev/null +++ b/docs/api/mcmc.md @@ -0,0 +1,13 @@ +# ::: enzax.mcmc + options: + show_root_heading: true + filters: + - "!check" + members: + - ObservationSet + - PriorSet + - run_nuts + - posterior_logdensity_amm + - get_idata + - ind_prior_from_truth + - ind_normal_prior_logdensity diff --git a/docs/api/steady_state.md b/docs/api/steady_state.md new file mode 100644 index 0000000..1d6fb49 --- /dev/null +++ b/docs/api/steady_state.md @@ -0,0 +1,7 @@ +# ::: enzax.steady_state + options: + show_root_heading: true + filters: + - "!check" + members: + - get_kinetic_model_steady_state diff --git a/docs/getting_started.md b/docs/getting_started.md index f64bdd6..1872d2b 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -170,10 +170,21 @@ guess = jnp.full((5,) 0.01) steady_state = get_kinetic_model_steady_state(methionine.model, guess) ``` -To find the Jacobian of this steady state with respect to the model's parameters, we can wrap `get_kinetic_model_steady_state` in JAX's [`jacrev`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html) function: +To access the Jacobian of this steady state with respect to the model's parameters, we can wrap `get_kinetic_model_steady_state` in a function that has a set of parameters as its only argument, then use JAX's [`jacrev`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html) function: ```python import jax +from jaxtyping import PyTree -jacobian = jax.jacrev(solve_steady_state)(methionine.model, guess) +guess = jnp.full((5,) 0.01) +model = methionine.model + +def get_steady_state_from_params(parameters: PyTree): + """Get the steady state with a one-argument non-pure function.""" + _model = RateEquationModel( + parameters, model.structure, model.rate_equations + ) + return get_kinetic_model_steady_state(_model, guess) + +jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters) ``` diff --git a/mkdocs.yml b/mkdocs.yml index ed1c91f..65a956a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -44,3 +44,5 @@ nav: - API: - 'api/kinetic_model.md' - 'api/rate_equations.md' + - 'api/steady_state.md' + - 'api/mcmc.md' diff --git a/src/enzax/mcmc.py b/src/enzax/mcmc.py index e26111b..ef32b09 100644 --- a/src/enzax/mcmc.py +++ b/src/enzax/mcmc.py @@ -23,8 +23,10 @@ @chex.dataclass class ObservationSet: + """Measurements from a single experiment.""" + conc: Float[Array, " m"] - flux: ScalarLike + flux: Float[Array, " n"] enzyme: Float[Array, " e"] conc_scale: ScalarLike flux_scale: ScalarLike @@ -33,6 +35,8 @@ class ObservationSet: @chex.dataclass class AllostericMichaelisMentenPriorSet: + """Priors for an allosteric Michaelis-Menten model.""" + log_kcat: Float[Array, "2 n_enzyme"] log_enzyme: Float[Array, "2 n_enzyme"] log_drain: Float[Array, "2 n_drain"] @@ -106,7 +110,7 @@ def posterior_logdensity_amm( @functools.partial(jax.jit, static_argnames=["kernel", "num_samples"]) -def inference_loop(rng_key, kernel, initial_state, num_samples): +def _inference_loop(rng_key, kernel, initial_state, num_samples): """Run MCMC with blackjax.""" def one_step(state, rng_key): @@ -141,7 +145,7 @@ def run_nuts( ) rng_key, sample_key = jax.random.split(rng_key) nuts_kernel = blackjax.nuts(logdensity_fn, **tuned_parameters).step - states, info = inference_loop( + states, info = _inference_loop( sample_key, kernel=nuts_kernel, initial_state=initial_state, diff --git a/src/enzax/steady_state.py b/src/enzax/steady_state.py index b773da3..a89f745 100644 --- a/src/enzax/steady_state.py +++ b/src/enzax/steady_state.py @@ -15,8 +15,18 @@ @eqx.filter_jit() def get_kinetic_model_steady_state( model: KineticModel, - guess: Float[Array, " n"], + guess: Float[Array, " n_balanced"], ) -> PyTree: + """Get the steady state of a kinetic model, using diffrax. + + The better the guess (generally) the faster and more reliable the solving. + + :param model: a KineticModel object + + :param guess: a JAX array of floats. Must have the same length as the + model's number of balanced species. + + """ term = diffrax.ODETerm(model.dcdt) solver = diffrax.Kvaerno5() t0 = 0