diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 23a3513..0a33cc1 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -22,21 +22,25 @@ jobs: - name: checkout code uses: actions/checkout@v2 + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + # Install a specific version of uv. + version: "0.5.5" + - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install pdm - run: pip install pdm + - name: Install the project + run: uv sync --all-extras --dev - name: pre-commit checks uses: pre-commit/action@v2.0.3 - name: Run tests - run: | - pdm install --dev - pdm run pytest tests --cov=src/enzax + run: uv run pytest tests - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 diff --git a/README.md b/README.md index fdc612f..8bbec17 100644 --- a/README.md +++ b/README.md @@ -43,9 +43,7 @@ model = methionine.model def get_steady_state_from_params(parameters: PyTree): """Get the steady state with a one-argument non-pure function.""" - _model = RateEquationModel( - parameters, model.structure, model.rate_equations - ) + _model = RateEquationModel(parameters, model.structure) return get_kinetic_model_steady_state(_model, guess) jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters) diff --git a/docs/getting_started.md b/docs/getting_started.md index 1872d2b..d4149b9 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -18,9 +18,14 @@ Enzax provides building blocks for you to construct a wide range of differentiab Here we write a model describing a simple linear pathway with two state variables, two boundary species and three reactions. -First we import some enzax classes: +First we import some enzax classes, as well as [equinox](https://github.com/patrick-kidger/equinox) and both JAX and standard versions of numpy: ```python +import equinox as eqx + +from jax import numpy as jnp +import numpy as np + from enzax.kinetic_model import ( KineticModelStructure, RateEquationModel, @@ -32,113 +37,82 @@ from enzax.rate_equations import ( ``` -Next we specify our model's structure by providing a stoichiometric matrix and saying which of its rows represent state variables (aka "balanced species") and which represent boundary or "unbalanced" species: +Next we specify our model's structure by providing a stoichiometric matrix and saying which of its rows represent state variables (aka "balanced species") and which reactions have which rate equations. ```python -structure = KineticModelStructure( - S=jnp.array( - [[-1, 0, 0], [1, -1, 0], [0, 1, -1], [0, 0, 1]], dtype=jnp.float64 +stoichiometry = { + "r1": {"m1e": -1, "m1c": 1}, + "r2": {"m1c": -1, "m2c": 1}, + "r3": {"m2c": -1, "m2e": 1}, +} +reactions = ["r1", "r2", "r3"] +species = ["m1e", "m1c", "m2c", "m2e"] +balanced_species = ["m1c", "m2c"] +rate_equations = [ + AllostericReversibleMichaelisMenten( + ix_allosteric_activators=np.array([2]), subunits=1 + ), + AllostericReversibleMichaelisMenten( + ix_allosteric_inhibitors=np.array([1]), ix_ki_species=np.array([1]) ), - balanced_species=jnp.array([1, 2]), - unbalanced_species=jnp.array([0, 3]), + ReversibleMichaelisMenten(water_stoichiometry=0.0), +] +structure = KineticModelStructure( + stoichiometry=stoichiometry, + species=species, + balanced_species=balanced_species, + rate_equations=rate_equations, ) ``` -Next we provide some kinetic parameter values: +Next we define what a set of kinetic parameters looks like for our problem, and provide a set of parameters matching this definition: ```python -from enzax.parameters import AllostericMichaelisMentenParameters - -parameters = AllostericMichaelisMentenParameters( - log_kcat=jnp.array([-0.1, 0.0, 0.1]), - log_enzyme=jnp.log(jnp.array([0.3, 0.2, 0.1])), - dgf=jnp.array([-3, -1.0]), - log_km=jnp.array([0.1, -0.2, 0.5, 0.0, -1.0, 0.5]), - log_ki=jnp.array([1.0]), - log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])), +class ParameterDefinition(eqx.Module): + log_substrate_km: dict[int, Array] + log_product_km: dict[int, Array] + log_kcat: dict[int, Scalar] + log_enzyme: dict[int, Array] + log_ki: dict[int, Array] + dgf: Array + temperature: Scalar + log_conc_unbalanced: Array + log_dc_inhibitor: dict[int, Array] + log_dc_activator: dict[int, Array] + log_tc: dict[int, Array] + +parameters = ParameterDefinition( + log_substrate_km={ + "r1": jnp.array([0.1]), + "r2": jnp.array([0.5]), + "r3": jnp.array([-1.0]), + }, + log_product_km={ + "r1": jnp.array([-0.2]), + "r2": jnp.array([0.0]), + "r3": jnp.array([0.5]), + }, + log_kcat={"r1": jnp.array(-0.1), "r2": jnp.array(0.0), "r3": jnp.array(0.1)}, + dgf=jnp.array([-3.0, -1.0]), + log_ki={"r1": jnp.array([]), "r2": jnp.array([1.0]), "r3": jnp.array([])}, temperature=jnp.array(310.0), - log_transfer_constant=jnp.array([-0.2, 0.3]), - log_dissociation_constant=jnp.array([-0.1, 0.2]), - log_drain=jnp.array([]), -) -``` -Now we can use enzax's rate laws to specify how each reaction behaves: - -```python -from enzax.rate_equations import ( - AllostericReversibleMichaelisMenten, - ReversibleMichaelisMenten, -) - -r0 = AllostericReversibleMichaelisMenten( - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 0], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - ix_product=jnp.array([1], dtype=jnp.int16), - ix_reactants=jnp.array([0, 1], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - tc_ix=0, - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - ix_dc_activation=jnp.array([0], dtype=jnp.int16), - species_activation=jnp.array([2], dtype=jnp.int16), - species_inhibition=jnp.array([], dtype=jnp.int16), - subunits=1, -) -r1 = AllostericReversibleMichaelisMenten( - kcat_ix=1, - enzyme_ix=1, - km_ix=jnp.array([2, 3], dtype=jnp.int16), - ki_ix=jnp.array([0]), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([1]), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([1], dtype=jnp.int16), - ix_product=jnp.array([2], dtype=jnp.int16), - ix_reactants=jnp.array([1, 2], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - tc_ix=1, - ix_dc_inhibition=jnp.array([1], dtype=jnp.int16), - ix_dc_activation=jnp.array([], dtype=jnp.int16), - species_activation=jnp.array([], dtype=jnp.int16), - species_inhibition=jnp.array([1], dtype=jnp.int16), - subunits=1, -) -r2 = ReversibleMichaelisMenten( - kcat_ix=2, - enzyme_ix=2, - km_ix=jnp.array([4, 5], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - ix_substrate=jnp.array([2], dtype=jnp.int16), - ix_product=jnp.array([3], dtype=jnp.int16), - ix_reactants=jnp.array([2, 3], dtype=jnp.int16), - reactant_to_dgf=jnp.array([1, 1], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), + log_enzyme={ + "r1": jnp.log(jnp.array(0.3)), + "r2": jnp.log(jnp.array(0.2)), + "r3": jnp.log(jnp.array(0.1)), + }, + log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])), + log_tc={"r1": jnp.array(-0.2), "r2": jnp.array(0.3)}, + log_dc_activator={"r1": jnp.array([-0.1]), "r2": jnp.array([])}, + log_dc_inhibitor={"r1": jnp.array([]), "r2": jnp.array([0.2])}, ) ``` +Note that the parameters use `jnp` whereas the structure uses `np`. This is because we want JAX to trace the parameters, whereas the structure should be static. Read more about this [here](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#static-vs-traced-operations). Now we can declare our model: ```python -model = RateEquationModel(structure, parameters, [r0, r1, r2]) +model = RateEquationModel(structure, parameters) ``` To test out the model, we can see if it returns some fluxes and state variable rates when provided a set of balanced species concentrations: @@ -181,9 +155,7 @@ model = methionine.model def get_steady_state_from_params(parameters: PyTree): """Get the steady state with a one-argument non-pure function.""" - _model = RateEquationModel( - parameters, model.structure, model.rate_equations - ) + _model = RateEquationModel(parameters, model.structure) return get_kinetic_model_steady_state(_model, guess) jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters) diff --git a/pyproject.toml b/pyproject.toml index 6de80d3..80081b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,8 @@ dependencies = [ "diffrax>=0.6.0", "jaxtyping>=0.2.31", "arviz>=0.19.0", + "jax>=0.4.35", + "equinox>=0.11.9", "python-libsbml>=5.20.4", "sympy2jax>=0.0.5", "sbmlmath>=0.2.0", @@ -33,12 +35,6 @@ build-backend = "hatchling.build" [tool.pdm] distribution = true -[tool.pdm.dev-dependencies] -dev = [ - "pytest>=8.3.2", - "pytest-cov>=5.0.0", - "pre-commit>=3.8.0", -] [tool.ruff] line-length = 80 @@ -48,3 +44,10 @@ extend-select = ["E501"] # line length is checked [tool.ruff.lint.isort] known-first-party = ["enzax"] + +[dependency-groups] +dev = [ + "pytest>=8.3.3", + "pytest-cov>=5.0.0", + "pre-commit>=3.8.0", +] diff --git a/scripts/mcmc_demo.py b/scripts/mcmc_demo.py index 40d428d..cc8491c 100644 --- a/scripts/mcmc_demo.py +++ b/scripts/mcmc_demo.py @@ -7,14 +7,15 @@ import arviz as az import jax from jax import numpy as jnp +from jax.flatten_util import ravel_pytree +from jax.scipy.stats import norm +from jaxtyping import Array from enzax.examples import methionine +from enzax.kinetic_model import RateEquationModel, get_conc from enzax.mcmc import ( ObservationSet, - AllostericMichaelisMentenPriorSet, get_idata, - ind_prior_from_truth, - posterior_logdensity_amm, run_nuts, ) from enzax.steady_state import get_kinetic_model_steady_state @@ -24,59 +25,55 @@ jax.config.update("jax_enable_x64", True) +def joint_log_density(params, prior_mean, prior_sd, obs, guess): + # find the steady state concentration and flux + model = RateEquationModel(params, methionine.structure) + steady = get_kinetic_model_steady_state(model, guess) + conc = get_conc(steady, params.log_conc_unbalanced, methionine.structure) + flux = model.flux(steady) + # prior + flat_params, _ = ravel_pytree(params) + log_prior = norm.logpdf(flat_params, loc=prior_mean, scale=prior_sd).sum() + # likelihood + flat_log_enzyme, _ = ravel_pytree(params.log_enzyme) + log_likelihood = ( + norm.logpdf(jnp.log(obs.conc), jnp.log(conc), obs.conc_scale).sum() + + norm.logpdf( + jnp.log(obs.enzyme), flat_log_enzyme, obs.enzyme_scale + ).sum() + + norm.logpdf(obs.flux, flux, obs.flux_scale).sum() + ) + return log_prior + log_likelihood + + def main(): """Demonstrate How to make a Bayesian kinetic model with enzax.""" - structure = methionine.structure - rate_equations = methionine.rate_equations true_parameters = methionine.parameters true_model = methionine.model - default_state_guess = jnp.full((5,), 0.01) - true_states = get_kinetic_model_steady_state( - true_model, default_state_guess - ) - prior = AllostericMichaelisMentenPriorSet( - log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1), - log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1), - log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1), - dgf=( - ind_prior_from_truth(true_parameters.dgf, 0.1)[0], - jnp.diag( - jnp.square(ind_prior_from_truth(true_parameters.dgf, 0.1)[1]) - ), - ), - log_km=ind_prior_from_truth(true_parameters.log_km, 0.1), - log_conc_unbalanced=ind_prior_from_truth( - true_parameters.log_conc_unbalanced, 0.1 - ), - temperature=ind_prior_from_truth(true_parameters.temperature, 0.1), - log_ki=ind_prior_from_truth(true_parameters.log_ki, 0.1), - log_transfer_constant=ind_prior_from_truth( - true_parameters.log_transfer_constant, 0.1 - ), - log_dissociation_constant=ind_prior_from_truth( - true_parameters.log_dissociation_constant, 0.1 - ), - ) + default_guess = jnp.full((5,), 0.01) + true_steady = get_kinetic_model_steady_state(true_model, default_guess) # get true concentration - true_conc = jnp.zeros(methionine.structure.S.shape[0]) - true_conc = true_conc.at[methionine.structure.balanced_species].set( - true_states - ) - true_conc = true_conc.at[methionine.structure.unbalanced_species].set( - jnp.exp(true_parameters.log_conc_unbalanced) + true_conc = get_conc( + true_steady, + true_parameters.log_conc_unbalanced, + methionine.structure, ) # get true flux - true_flux = true_model.flux(true_states) + true_flux = true_model.flux(true_steady) # simulate observations error_conc = 0.03 error_flux = 0.05 error_enzyme = 0.03 key = jax.random.key(SEED) - obs_conc = jnp.exp(jnp.log(true_conc) + jax.random.normal(key) * error_conc) + true_log_enz_flat, _ = ravel_pytree(true_parameters.log_enzyme) + key_conc, key_enz, key_flux, key_nuts = jax.random.split(key, num=4) + obs_conc = jnp.exp( + jnp.log(true_conc) + jax.random.normal(key_conc) * error_conc + ) obs_enzyme = jnp.exp( - true_parameters.log_enzyme + jax.random.normal(key) * error_enzyme + true_log_enz_flat + jax.random.normal(key_enz) * error_enzyme ) - obs_flux = true_flux + jax.random.normal(key) * error_conc + obs_flux = true_flux + jax.random.normal(key_flux) * error_conc obs = ObservationSet( conc=obs_conc, flux=obs_flux, @@ -85,17 +82,19 @@ def main(): flux_scale=error_flux, enzyme_scale=error_enzyme, ) - pldf = functools.partial( - posterior_logdensity_amm, - obs=obs, - prior=prior, - structure=structure, - rate_equations=rate_equations, - guess=default_state_guess, + flat_true_params, _ = ravel_pytree(true_parameters) + posterior_log_density = jax.jit( + functools.partial( + joint_log_density, + obs=obs, + prior_mean=flat_true_params, + prior_sd=0.1, + guess=default_guess, + ) ) samples, info = run_nuts( - pldf, - key, + posterior_log_density, + key_nuts, true_parameters, num_warmup=200, num_samples=200, @@ -104,9 +103,7 @@ def main(): is_mass_matrix_diagonal=False, target_acceptance_rate=0.95, ) - idata = get_idata( - samples, info, coords=methionine.coords, dims=methionine.dims - ) + idata = get_idata(samples, info) print(az.summary(idata)) if jnp.any(info.is_divergent): n_divergent = info.is_divergent.sum() @@ -117,10 +114,15 @@ def main(): print("True parameter values vs posterior:") for param in true_parameters.__dataclass_fields__.keys(): true_val = getattr(true_parameters, param) - model_low = jnp.quantile(getattr(samples.position, param), 0.01, axis=0) - model_high = jnp.quantile( - getattr(samples.position, param), 0.99, axis=0 - ) + model_p = getattr(samples.position, param) + if isinstance(true_val, Array): + model_low = jnp.quantile(model_p, 0.01, axis=0) + model_high = jnp.quantile(model_p, 0.99, axis=0) + elif isinstance(true_val, dict): + model_low, model_high = ( + {k: jnp.quantile(v, q, axis=0) for k, v in model_p.items()} + for q in (0.01, 0.99) + ) print(f" {param}:") print(f" true value: {true_val}") print(f" posterior 1%: {model_low}") diff --git a/scripts/sbml_demo.py b/scripts/sbml_demo.py index 8307423..570fa83 100644 --- a/scripts/sbml_demo.py +++ b/scripts/sbml_demo.py @@ -11,54 +11,60 @@ reactions_sympy = sbml.sbml_to_sympy(model_sbml) sym_module = sbml.sympy_to_enzax(reactions_sympy) -parameters_all = [ - ({p.getId(): p.getValue() for p in r.getKineticLaw().getListOfParameters()}) - for r in model_sbml.getListOfReactions() +species = [s.getId() for s in model_sbml.getListOfSpecies()] + +balanced_species = [ + b.getId() for b in model_sbml.getListOfSpecies() if not b.boundary_condition ] -parameters = {} -for i in parameters_all: - parameters.update(i) -compartments = {c.getId(): c.volume for c in model_sbml.getListOfCompartments()} +reactions = [reaction.getId() for reaction in model_sbml.getListOfReactions()] -species = [s.getId() for s in model_sbml.getListOfSpecies()] +stoichiometry = { + reaction.getId(): { + r.getSpecies(): -r.getStoichiometry(), + p.getSpecies(): p.getStoichiometry(), + } + for reaction in model_sbml.getListOfReactions() + for r in reaction.getListOfReactants() + for p in reaction.getListOfProducts() +} -balanced_species_dict = {} -unbalanced_species_dict = {} -for i in model_sbml.getListOfSpecies(): - if not i.boundary_condition: - balanced_species_dict.update({i.getId(): i.getInitialConcentration()}) - else: - unbalanced_species_dict.update({i.getId(): i.getInitialConcentration()}) +structure = KineticModelStructure( + stoichiometry=stoichiometry, + species=species, + reactions=reactions, + balanced_species=balanced_species, +) -balanced_ix = jnp.array([species.index(b) for b in balanced_species_dict]) -unbalanced_ix = jnp.array([species.index(u) for u in unbalanced_species_dict]) +parameters_local = { + p.getId(): p.getValue() + for r in model_sbml.getListOfReactions() + for p in r.getKineticLaw().getListOfParameters() +} -para = {**parameters, **compartments, **unbalanced_species_dict} +parameters_global = { + p.getId(): p.getValue() + for p in model_sbml.getListOfParameters() + if p.constant +} -stoichmatrix = jnp.zeros( - (model_sbml.getNumSpecies(), model_sbml.getNumReactions()), - dtype=jnp.float64, -) -i = 0 -for reaction in model_sbml.getListOfReactions(): - for r in reaction.getListOfReactants(): - stoichmatrix = stoichmatrix.at[species.index(r.getSpecies()), i].set( - -int(r.getStoichiometry()) - ) - for p in reaction.getListOfProducts(): - stoichmatrix = stoichmatrix.at[species.index(p.getSpecies()), i].set( - int(p.getStoichiometry()) - ) - i += 1 +compartments = {c.getId(): c.volume for c in model_sbml.getListOfCompartments()} -structure = KineticModelStructure( - stoichmatrix, jnp.array(balanced_ix), jnp.array(unbalanced_ix) -) +unbalanced_species = { + u.getId(): u.getInitialConcentration() + for u in model_sbml.getListOfSpecies() + if u.boundary_condition +} + +para = { + **parameters_local, + **parameters_global, + **compartments, + **unbalanced_species, +} kinmodel_sbml = KineticModelSbml( parameters=para, - balanced_ids=balanced_species_dict, structure=structure, sym_module=sym_module, ) diff --git a/scripts/steady_state_demo.py b/scripts/steady_state_demo.py index 5feeda3..f030099 100644 --- a/scripts/steady_state_demo.py +++ b/scripts/steady_state_demo.py @@ -33,9 +33,7 @@ def get_steady_state_from_params(parameters: PyTree): This lets us get the Jacobian wrt (just) the parameters. """ - _model = RateEquationModel( - parameters, model.structure, model.rate_equations - ) + _model = RateEquationModel(parameters, model.structure) return get_kinetic_model_steady_state(_model, guess) # solve once for jitting @@ -54,7 +52,7 @@ def get_steady_state_from_params(parameters: PyTree): print(f"\tSteady state concentration: {conc_steady}") print(f"\tFlux: {flux}") print(f"\tSv: {sv}") - print(f"\tLog Km Jacobian: {jac.log_km}") + print(f"\tLog substrate Km Jacobian: {jac.log_substrate_km}") print(f"\tDgf Jacobian: {jac.dgf}") diff --git a/src/enzax/examples/linear.py b/src/enzax/examples/linear.py index 134167d..45cb847 100644 --- a/src/enzax/examples/linear.py +++ b/src/enzax/examples/linear.py @@ -1,100 +1,88 @@ """A simple linear kinetic model.""" +import equinox as eqx +import numpy as np from jax import numpy as jnp +from jaxtyping import Array, Scalar from enzax.kinetic_model import ( + RateEquationKineticModelStructure, RateEquationModel, - KineticModelStructure, ) from enzax.rate_equations import ( AllostericReversibleMichaelisMenten, ReversibleMichaelisMenten, ) -from enzax.parameters import AllostericMichaelisMentenParameterSet -parameters = AllostericMichaelisMentenParameterSet( - log_kcat=jnp.array([-0.1, 0.0, 0.1]), - log_enzyme=jnp.log(jnp.array([0.3, 0.2, 0.1])), - dgf=jnp.array([-3, -1.0]), - log_km=jnp.array([0.1, -0.2, 0.5, 0.0, -1.0, 0.5]), - log_ki=jnp.array([1.0]), - log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])), - temperature=jnp.array(310.0), - log_transfer_constant=jnp.array([-0.2, 0.3]), - log_dissociation_constant=jnp.array([-0.1, 0.2]), - log_drain=jnp.array([]), -) -structure = KineticModelStructure( - S=jnp.array( - [[-1, 0, 0], [1, -1, 0], [0, 1, -1], [0, 0, 1]], dtype=jnp.float64 - ), - balanced_species=jnp.array([1, 2]), - unbalanced_species=jnp.array([0, 3]), -) + +class ParameterDefinition(eqx.Module): + log_substrate_km: dict[str, Array] + log_product_km: dict[str, Array] + log_kcat: dict[str, Scalar] + log_enzyme: dict[str, Array] + log_ki: dict[str, Array] + dgf: Array + temperature: Scalar + log_conc_unbalanced: Array + log_dc_inhibitor: dict[str, Array] + log_dc_activator: dict[str, Array] + log_tc: dict[str, Array] + + +stoichiometry = { + "r1": {"m1e": -1, "m1c": 1}, + "r2": {"m1c": -1, "m2c": 1}, + "r3": {"m2c": -1, "m2e": 1}, +} +reactions = ["r1", "r2", "r3"] +species = ["m1e", "m1c", "m2c", "m2e"] +balanced_species = ["m1c", "m2c"] rate_equations = [ AllostericReversibleMichaelisMenten( - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 0], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - ix_product=jnp.array([1], dtype=jnp.int16), - ix_reactants=jnp.array([0, 1], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - tc_ix=0, - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - ix_dc_activation=jnp.array([0], dtype=jnp.int16), - species_activation=jnp.array([2], dtype=jnp.int16), - species_inhibition=jnp.array([], dtype=jnp.int16), - subunits=1, + ix_allosteric_activators=np.array([2]), subunits=1 ), AllostericReversibleMichaelisMenten( - kcat_ix=1, - enzyme_ix=1, - km_ix=jnp.array([2, 3], dtype=jnp.int16), - ki_ix=jnp.array([0]), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([1]), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([1], dtype=jnp.int16), - ix_product=jnp.array([2], dtype=jnp.int16), - ix_reactants=jnp.array([1, 2], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - tc_ix=1, - ix_dc_inhibition=jnp.array([1], dtype=jnp.int16), - ix_dc_activation=jnp.array([], dtype=jnp.int16), - species_activation=jnp.array([], dtype=jnp.int16), - species_inhibition=jnp.array([1], dtype=jnp.int16), - subunits=1, - ), - ReversibleMichaelisMenten( - kcat_ix=2, - enzyme_ix=2, - km_ix=jnp.array([4, 5], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - ix_substrate=jnp.array([2], dtype=jnp.int16), - ix_product=jnp.array([3], dtype=jnp.int16), - ix_reactants=jnp.array([2, 3], dtype=jnp.int16), - reactant_to_dgf=jnp.array([1, 1], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), + ix_allosteric_inhibitors=np.array([1]), ix_ki_species=np.array([1]) ), + ReversibleMichaelisMenten(water_stoichiometry=0.0), ] +structure = RateEquationKineticModelStructure( + stoichiometry=stoichiometry, + species=species, + reactions=reactions, + balanced_species=balanced_species, + species_to_dgf_ix=np.array([0, 0, 1, 1]), + rate_equations=rate_equations, +) +parameters = ParameterDefinition( + log_substrate_km={ + "r1": jnp.array([0.1]), + "r2": jnp.array([0.5]), + "r3": jnp.array([-1.0]), + }, + log_product_km={ + "r1": jnp.array([-0.2]), + "r2": jnp.array([0.0]), + "r3": jnp.array([0.5]), + }, + log_kcat={ + "r1": jnp.array(-0.1), + "r2": jnp.array(0.0), + "r3": jnp.array(0.1), + }, + dgf=jnp.array([-3.0, -1.0]), + log_ki={"r1": jnp.array([]), "r2": jnp.array([1.0]), "r3": jnp.array([])}, + temperature=jnp.array(310.0), + log_enzyme={ + "r1": jnp.log(jnp.array(0.3)), + "r2": jnp.log(jnp.array(0.2)), + "r3": jnp.log(jnp.array(0.1)), + }, + log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])), + log_tc={"r1": jnp.array(-0.2), "r2": jnp.array(0.3)}, + log_dc_activator={"r1": jnp.array([-0.1]), "r2": jnp.array([])}, + log_dc_inhibitor={"r1": jnp.array([]), "r2": jnp.array([0.2])}, +) +true_model = RateEquationModel(structure=structure, parameters=parameters) steady_state = jnp.array([0.43658744, 0.12695706]) -model = RateEquationModel(parameters, structure, rate_equations) +model = RateEquationModel(parameters, structure) diff --git a/src/enzax/examples/methionine.py b/src/enzax/examples/methionine.py index 1179b11..7191b02 100644 --- a/src/enzax/examples/methionine.py +++ b/src/enzax/examples/methionine.py @@ -5,11 +5,14 @@ """ +import equinox as eqx +import numpy as np from jax import numpy as jnp +from jaxtyping import Array, Scalar from enzax.kinetic_model import ( + RateEquationKineticModelStructure, RateEquationModel, - KineticModelStructure, ) from enzax.rate_equations import ( AllostericIrreversibleMichaelisMenten, @@ -17,187 +20,103 @@ IrreversibleMichaelisMenten, ReversibleMichaelisMenten, ) -from enzax.parameters import AllostericMichaelisMentenParameterSet -coords = { - "enzyme": [ - "MAT1", - "MAT3", - "METH-Gen", - "GNMT1", - "AHC1", - "MS1", - "BHMT1", - "CBS1", - "MTHFR1", - "PROT1", - ], - "drain": ["the_drain"], - "rate": [ - "the_drain", - "MAT1", - "MAT3", - "METH-Gen", - "GNMT1", - "AHC1", - "MS1", - "BHMT1", - "CBS1", - "MTHFR1", - "PROT1", - ], - "metabolite": [ - "met-L", - "atp", - "pi", - "ppi", - "amet", - "ahcys", - "gly", - "sarcs", - "hcys-L", - "adn", - "thf", - "5mthf", - "mlthf", - "glyb", - "dmgly", - "ser-L", - "nadp", - "nadph", - "cyst-L", - ], - "km": [ - "met-L MAT1", - "atp MAT1", - "met-L MAT3", - "atp MAT3", - "amet METH-Gen", - "amet GNMT1", - "gly GNMT1", - "ahcys AHC1", - "hcys-L AHC1", - "adn AHC1", - "hcys-L MS1", - "5mthf MS1", - "hcys-L BHMT1", - "glyb BHMT1", - "hcys-L CBS1", - "ser-L CBS1", - "mlthf MTHFR1", - "nadph MTHFR1", - "met-L PROT1", - ], - "ki": [ - "MAT1", - "METH-Gen", - "GNMT1", - ], - "species": [ - "met-L", - "atp", - "pi", - "ppi", - "amet", - "ahcys", - "gly", - "sarcs", - "hcys-L", - "adn", - "thf", - "5mthf", - "mlthf", - "glyb", - "dmgly", - "ser-L", - "nadp", - "nadph", - "cyst-L", - ], - "balanced_species": [ - "met-L", - "amet", - "ahcys", - "hcys-L", - "5mthf", - ], - "unbalanced_species": [ - "atp", - "pi", - "ppi", - "gly", - "sarcs", - "adn", - "thf", - "mlthf", - "glyb", - "dmgly", - "ser-L", - "nadp", - "nadph", - "cyst-L", - ], - "transfer_constant": [ - "METAT", - "GNMT", - "CBS", - "MTHFR", - ], - "dissociation_constant": [ - "met-L MAT3", - "amet MAT3", - "amet GNMT1", - "mlthf GNMT1", - "amet CBS1", - "amet MTHFR1", - "ahcys MTHFR1", - ], -} -dims = { - "log_kcat": ["enzyme"], - "log_enzyme": ["enzyme"], - "log_drain": ["drain"], - "log_km": ["km"], - "dgf": ["metabolite"], - "log_ki": ["ki"], - "log_conc_unbalanced": ["unbalanced_species"], - "log_transfer_constant": ["transfer_constant"], - "log_dissociation_constant": ["dissociation_constant"], + +class ParameterDefinition(eqx.Module): + log_substrate_km: dict[str, Array] + log_product_km: dict[str, Array] + log_kcat: dict[str, Scalar] + log_enzyme: dict[str, Array] + log_ki: dict[str, Array] + dgf: Array + temperature: Scalar + log_conc_unbalanced: Array + log_dc_inhibitor: dict[str, Array] + log_dc_activator: dict[str, Array] + log_tc: dict[str, Array] + log_drain: dict[str, Scalar] + + +stoichiometry = { + "the_drain": {"met-L": 1}, + "MAT1": {"met-L": -1, "atp": -1, "pi": 1, "ppi": 1, "amet": 1}, + "MAT3": {"met-L": -1, "atp": -1, "pi": 1, "ppi": 1, "amet": 1}, + "METH-Gen": {"amet": -1, "ahcys": 1}, + "GNMT1": {"amet": -1, "ahcys": 1, "gly": -1, "sarcs": 1}, + "AHC1": {"ahcys": -1, "hcys-L": 1, "adn": 1}, + "MS1": {"hcys-L": -1, "thf": 1, "met-L": 1, "5mthf": -1}, + "BHMT1": {"hcys-L": -1, "glyb": -1, "met-L": 1, "dmgly": 1}, + "CBS1": {"hcys-L": -1, "ser-L": -1, "cyst-L": 1}, + "MTHFR1": {"5mthf": 1, "mlthf": -1, "nadp": 1, "nadph": -1}, + "PROT1": {"met-L": -1}, } -parameters = AllostericMichaelisMentenParameterSet( - log_kcat=jnp.log( - jnp.array( - [ - 7.89577, # MAT1 - 19.9215, # MAT3 - 1.15777, # METH-Gen - 10.5307, # GNMT1 - 234.284, # AHC1 - 1.77471, # MS1 - 13.7676, # BHMT1 - 7.02307, # CBS1 - 3.1654, # MTHFR1 - 0.264744, # PROT1 - ] - ) - ), - log_enzyme=jnp.log( - jnp.array( - [ - 0.000961712, # MAT1 - 0.00098812, # MAT3 - 0.00103396, # METH-Gen - 0.000983692, # GNMT1 - 0.000977878, # AHC1 - 0.00105094, # MS1 - 0.000996603, # BHMT1 - 0.00134056, # CBS1 - 0.0010054, # MTHFR1 - 0.000995525, # PROT1 - ] - ) - ), - log_drain=jnp.log(jnp.array([0.000850127])), +species = [ + "met-L", + "atp", + "pi", + "ppi", + "amet", + "ahcys", + "gly", + "sarcs", + "hcys-L", + "adn", + "thf", + "5mthf", + "mlthf", + "glyb", + "dmgly", + "ser-L", + "nadp", + "nadph", + "cyst-L", +] +balanced_species = [ + "met-L", + "amet", + "ahcys", + "hcys-L", + "5mthf", +] +reactions = [ + "the_drain", + "MAT1", + "MAT3", + "METH-Gen", + "GNMT1", + "AHC1", + "MS1", + "BHMT1", + "CBS1", + "MTHFR1", + "PROT1", +] +parameters = ParameterDefinition( + log_kcat={ + "MAT1": jnp.log(jnp.array(7.89577)), # MAT1 + "MAT3": jnp.log(jnp.array(19.9215)), # MAT3 + "METH-Gen": jnp.log(jnp.array(1.15777)), # METH-Gen + "GNMT1": jnp.log(jnp.array(10.5307)), # GNMT1 + "AHC1": jnp.log(jnp.array(234.284)), # AHC1 + "MS1": jnp.log(jnp.array(1.77471)), # MS1 + "BHMT1": jnp.log(jnp.array(13.7676)), # BHMT1 + "CBS1": jnp.log(jnp.array(7.02307)), # CBS1 + "MTHFR1": jnp.log(jnp.array(3.1654)), # MTHFR1 + "PROT1": jnp.log(jnp.array(0.264744)), # PROT1 + }, + log_enzyme={ + "MAT1": jnp.log(jnp.array(0.000961712)), # MAT1 + "MAT3": jnp.log(jnp.array(0.00098812)), # MAT3 + "METH-Gen": jnp.log(jnp.array(0.00103396)), # METH-Gen + "GNMT1": jnp.log(jnp.array(0.000983692)), # GNMT1 + "AHC1": jnp.log(jnp.array(0.000977878)), # AHC1 + "MS1": jnp.log(jnp.array(0.00105094)), # MS1 + "BHMT1": jnp.log(jnp.array(0.000996603)), # BHMT1 + "CBS1": jnp.log(jnp.array(0.00134056)), # CBS1 + "MTHFR1": jnp.log(jnp.array(0.0010054)), # MTHFR1 + "PROT1": jnp.log(jnp.array(0.000995525)), # PROT1 + }, + log_drain={"the_drain": jnp.log(jnp.array(0.000850127))}, dgf=jnp.array( [ 160.953, # met-L @@ -221,40 +140,44 @@ -46.4737, # cyst-L ] ), - log_km=jnp.log( - jnp.array( - [ - 0.000106919, # met-L MAT1 - 0.00203015, # atp MAT1 - 0.00113258, # met-L MAT3 - 0.00236759, # atp MAT3 - 9.37e-06, # amet METH-Gen - 0.000520015, # amet GNMT1 - 0.00253545, # gly GNMT1 - 2.32e-05, # ahcys AHC1 - 1.06e-05, # hcys-L AHC1 - 5.66e-06, # adn AHC1 - 1.71e-06, # hcys-L MS1 - 6.94e-05, # 5mthf MS1 - 1.98e-05, # hcys-L BHMT1 - 0.00845898, # glyb BHMT1 - 4.24e-05, # hcys-L CBS1 - 2.83e-06, # ser-L CBS1 - 8.08e-05, # mlthf MTHFR1 - 2.09e-05, # nadph MTHFR1 - 4.39e-05, # met-L PROT1 - ] - ) - ), - log_ki=jnp.log( - jnp.array( - [ - 0.000346704, # MAT1 - 5.56e-06, # METH-Gen - 5.31e-05, # GNMT1 - ] - ) - ), + log_product_km={ + "AHC1": jnp.log( + jnp.array([1.06e-05, 5.66e-06]) + ), # hcys-L AHC1, adn AHC1 + }, + log_substrate_km={ + "MAT1": jnp.log( + jnp.array([0.000106919, 0.00203015]) + ), # MAT1 met-L, atp + "MAT3": jnp.log(jnp.array([0.00113258, 0.00236759])), # MAT3 met-L atp + "METH-Gen": jnp.log(jnp.array([9.37e-06])), # METH-Gen amet + "GNMT1": jnp.log( + jnp.array([0.000520015, 0.00253545]) + ), # GNMT1, amet, gly + "AHC1": jnp.log(jnp.array([2.32e-05])), # ahcys AHC1 + "MS1": jnp.log(jnp.array([1.71e-06, 6.94e-05])), # MS1 hcys-L, 5mthf + "BHMT1": jnp.log( + jnp.array([1.98e-05, 0.00845898]) + ), # BHMT1 hcys-L, glyb + "CBS1": jnp.log(jnp.array([4.24e-05, 2.83e-06])), # CBS1 hcys-L, ser-L + "MTHFR1": jnp.log( + jnp.array([8.08e-05, 2.09e-05]) + ), # MTHFR1 mlthf, nadph + "PROT1": jnp.log(jnp.array([4.39e-05])), # PROT1 met-L + }, + temperature=jnp.array(298.15), + log_ki={ + "MAT1": jnp.array([jnp.log(0.000346704)]), # MAT1 + "MAT3": jnp.array([]), + "METH-Gen": jnp.array([jnp.log(5.56e-06)]), # METH-Gen + "GNMT1": jnp.array([jnp.log(5.31e-05)]), # GNMT1 + "AHC1": jnp.array([]), + "MS1": jnp.array([]), + "BHMT1": jnp.array([]), + "CBS1": jnp.array([]), + "MTHFR1": jnp.array([]), + "PROT1": jnp.array([]), + }, log_conc_unbalanced=jnp.log( jnp.array( [ @@ -276,233 +199,68 @@ ] ) ), - temperature=jnp.array(298.15), - log_transfer_constant=jnp.log( - jnp.array( - [ - 0.107657, # METAT - 131.207, # GNMT - 1.03452, # CBS - 0.392035, # MTHFR - ] - ) - ), - log_dissociation_constant=jnp.log( - jnp.array( - [ - 0.00059999, # met-L MAT3 - 0.000316641, # amet MAT3 - 1.98e-05, # amet GNMT1 - 0.000228576, # mlthf GNMT1 - 9.30e-05, # amet CBS1 - 1.46e-05, # amet MTHFR1 - 2.45e-06, # ahcys MTHFR1 - ] - ) - ), -) -structure = KineticModelStructure( - S=jnp.array( - [ - [1, -1, -1, 0, 0, 0, 1, 1, 0, 0, -1], # met-L b - [0, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0], # atp - [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], # pi - [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], # ppi - [0, 1, 1, -1, -1, 0, 0, 0, 0, 0, 0], # amet b - [0, 0, 0, 1, 1, -1, 0, 0, 0, 0, 0], # ahcys b - [0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0], # gly - [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # sarcs - [0, 0, 0, 0, 0, 1, -1, -1, -1, 0, 0], # hcys-L b - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], # adn - [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], # thf - [0, 0, 0, 0, 0, 0, -1, 0, 0, 1, 0], # 5mthf b - [0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0], # mlthf - [0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0], # glyb - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # dmgly - [0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0], # ser-L - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], # nadp - [0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0], # nadph - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], # cyst-L - ], - dtype=jnp.float64, - ), - balanced_species=jnp.array( - [ - 0, # met-L - 4, # amet - 5, # ahcys - 8, # hcys-L - 11, # 5mthf - ] - ), - unbalanced_species=jnp.array( - [ - 1, # atp - 2, # pi - 3, # ppi - 6, # gly - 7, # sarcs - 9, # adn - 10, # thf - 12, # mlthf - 13, # glyb - 14, # dmgly - 15, # ser-L - 16, # nadp - 17, # nadph - 18, # cyst-L - ] - ), + log_tc={ + "MAT3": jnp.array(jnp.log(0.107657)), # MAT3 + "GNMT1": jnp.array(jnp.log(131.207)), # GNMT + "CBS1": jnp.array(jnp.log(1.03452)), # CBS + "MTHFR1": jnp.array(jnp.log(0.392035)), # MTHFR + }, + log_dc_activator={ + "MAT3": jnp.log( + jnp.array([0.00059999, 0.000316641]) + ), # met-L MAT3, # amet MAT3 + "GNMT1": jnp.log(jnp.array([1.98e-05])), # amet GNMT1 + "CBS1": jnp.array([]), # CBS1 + "MTHFR1": jnp.log(jnp.array([2.45e-06])), # ahcys MTHFR1, + }, + log_dc_inhibitor={ + "MAT3": jnp.array([]), # MAT3 + "GNMT1": jnp.log(jnp.array([0.000228576])), # mlthf GNMT1 + "CBS1": jnp.log(jnp.array([9.30e-05])), # amet CBS1 + "MTHFR1": jnp.log(jnp.array([1.46e-05])), # amet MTHFR1 + }, ) -rate_equations = [ - Drain(sign=jnp.array(1.0), drain_ix=0), # met-L source - IrreversibleMichaelisMenten( # MAT1 - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([0], dtype=jnp.int16), - reactant_stoichiometry=jnp.array( - [-1.0, -1.0, 1.0, 1.0, 1.0], dtype=jnp.int16 + +structure = RateEquationKineticModelStructure( + stoichiometry=stoichiometry, + species=species, + reactions=reactions, + balanced_species=balanced_species, + rate_equations=[ + Drain(sign=1.0), # met-L source + IrreversibleMichaelisMenten( # MAT1 + ix_ki_species=np.array([4], dtype=np.int16), ), - ix_substrate=jnp.array([0, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([4], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), - ), - AllostericIrreversibleMichaelisMenten( # MAT3 - kcat_ix=1, - enzyme_ix=1, - km_ix=jnp.array([2, 3], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array( - [-1.0, -1.0, 1.0, 1.0, 1.0], dtype=jnp.int16 + AllostericIrreversibleMichaelisMenten( # MAT3 + subunits=2, + ix_allosteric_activators=np.array([0, 4], dtype=np.int16), ), - ix_substrate=jnp.array([0, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), - subunits=2, - tc_ix=0, - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - ix_dc_activation=jnp.array([0, 1], dtype=jnp.int16), - species_inhibition=jnp.array([], dtype=jnp.int16), - species_activation=jnp.array([0, 4], dtype=jnp.int16), - ), - IrreversibleMichaelisMenten( # METH - kcat_ix=2, - enzyme_ix=2, - km_ix=jnp.array([4], dtype=jnp.int16), - ki_ix=jnp.array([1], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - ix_substrate=jnp.array([4], dtype=jnp.int16), - ix_ki_species=jnp.array([5], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ), - AllostericIrreversibleMichaelisMenten( # GNMT1 - kcat_ix=3, - enzyme_ix=3, - km_ix=jnp.array([5, 6], dtype=jnp.int16), - ki_ix=jnp.array([2], dtype=jnp.int16), - reactant_stoichiometry=jnp.array( - [-1.0, 1.0, -1.0, 1.0], dtype=jnp.int16 + IrreversibleMichaelisMenten( # METH + ix_ki_species=np.array([5], dtype=np.int16) ), - ix_substrate=jnp.array([4, 6], dtype=jnp.int16), - ix_ki_species=jnp.array([5], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 2], dtype=jnp.int16), - subunits=4, - tc_ix=1, - ix_dc_activation=jnp.array([2], dtype=jnp.int16), - ix_dc_inhibition=jnp.array([3], dtype=jnp.int16), - species_inhibition=jnp.array([12], dtype=jnp.int16), - species_activation=jnp.array([4], dtype=jnp.int16), - ), - ReversibleMichaelisMenten( # AHC - kcat_ix=4, - enzyme_ix=4, - km_ix=jnp.array([7, 8, 9], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1.0, 1.0, 1.0], dtype=jnp.int16), - ix_substrate=jnp.array([5], dtype=jnp.int16), - ix_product=jnp.array([8, 9], dtype=jnp.int16), - ix_reactants=jnp.array([5, 8, 9], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - product_km_positions=jnp.array([1, 2], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - product_reactant_positions=jnp.array([1, 2], dtype=jnp.int16), - water_stoichiometry=jnp.array(-1.0), - reactant_to_dgf=jnp.array([5, 8, 9], dtype=jnp.int16), - ), - IrreversibleMichaelisMenten( # MS - kcat_ix=5, - enzyme_ix=5, - km_ix=jnp.array([10, 11], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([1, -1, 1, -1], dtype=jnp.int16), - ix_substrate=jnp.array([8, 11], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([1, 3], dtype=jnp.int16), - ), - IrreversibleMichaelisMenten( # BHMT - kcat_ix=6, - enzyme_ix=6, - km_ix=jnp.array([12, 13], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([1, -1, -1, 1], dtype=jnp.int16), - ix_substrate=jnp.array([8, 13], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([1, 2], dtype=jnp.int16), - ), - AllostericIrreversibleMichaelisMenten( # CBS1 - kcat_ix=7, - enzyme_ix=7, - km_ix=jnp.array([14, 15], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, -1, 1], dtype=jnp.int16), - ix_substrate=jnp.array([8, 15], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0, 1], dtype=jnp.int16), - subunits=2, - tc_ix=2, - ix_dc_activation=jnp.array([4], dtype=jnp.int16), - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - species_inhibition=jnp.array([4], dtype=jnp.int16), - species_activation=jnp.array([], dtype=jnp.int16), - ), - AllostericIrreversibleMichaelisMenten( # MTHFR - kcat_ix=8, - enzyme_ix=8, - km_ix=jnp.array([16, 17], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([1, -1, 1, -1], dtype=jnp.int16), - ix_substrate=jnp.array([12, 17], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0, 1], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([1, 3], dtype=jnp.int16), - subunits=2, - tc_ix=3, - ix_dc_activation=jnp.array([6], dtype=jnp.int16), - ix_dc_inhibition=jnp.array([5], dtype=jnp.int16), - species_inhibition=jnp.array([4], dtype=jnp.int16), - species_activation=jnp.array([5], dtype=jnp.int16), - ), - IrreversibleMichaelisMenten( # PROT - kcat_ix=9, - enzyme_ix=9, - km_ix=jnp.array([18], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1.0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ), -] + AllostericIrreversibleMichaelisMenten( # GNMT1 + subunits=4, + ix_allosteric_inhibitors=np.array([12], dtype=np.int16), + ix_allosteric_activators=np.array([4], dtype=np.int16), + ix_ki_species=np.array([5], dtype=np.int16), + ), + ReversibleMichaelisMenten( # AHC + water_stoichiometry=-1.0, + ), + IrreversibleMichaelisMenten(), # MS + IrreversibleMichaelisMenten(), # BHMT + AllostericIrreversibleMichaelisMenten( # CBS1 + subunits=2, + ix_allosteric_inhibitors=np.array([4], dtype=np.int16), + ), + AllostericIrreversibleMichaelisMenten( # MTHFR + subunits=2, + ix_allosteric_inhibitors=np.array([4], dtype=np.int16), + ix_allosteric_activators=np.array([5], dtype=np.int16), + ), + IrreversibleMichaelisMenten(), # PROT + ], +) steady_state = jnp.array( [ 4.233000e-05, # met-L @@ -512,4 +270,4 @@ 6.534400e-06, # 5mthf ] ) -model = RateEquationModel(parameters, structure, rate_equations) +model = RateEquationModel(parameters, structure) diff --git a/src/enzax/kinetic_model.py b/src/enzax/kinetic_model.py index ec6df15..3cf26b2 100644 --- a/src/enzax/kinetic_model.py +++ b/src/enzax/kinetic_model.py @@ -1,29 +1,134 @@ """Module containing enzax's definition of a kinetic model.""" from abc import ABC, abstractmethod +from typing import Any import equinox as eqx import jax.numpy as jnp -from jaxtyping import Array, Float, Int, PyTree, ScalarLike, jaxtyped +import numpy as np +from jax.tree_util import register_pytree_node_class +from jaxtyping import Array, Float, PyTree, ScalarLike, jaxtyped +from numpy.typing import NDArray from typeguard import typechecked from enzax.rate_equation import RateEquation +def get_ix_from_list(s: str, list_of_strings: list[str]): + return next(i for i, si in enumerate(list_of_strings) if si == s) + + +def get_conc(balanced, log_unbalanced, structure): + conc = jnp.zeros(structure.S.shape[0]) + conc = conc.at[structure.balanced_species_ix].set(balanced) + conc = conc.at[structure.unbalanced_species_ix].set(jnp.exp(log_unbalanced)) + return conc + + @jaxtyped(typechecker=typechecked) -class KineticModelStructure(eqx.Module): +@register_pytree_node_class +class KineticModelStructure: """Structural information about a kinetic model.""" - S: Float[Array, " s r"] - balanced_species: Int[Array, " n_balanced"] - unbalanced_species: Int[Array, " n_unbalanced"] + stoichiometry: dict[str, dict[str, float]] + species: list[str] + reactions: list[str] + balanced_species: list[str] + unbalanced_species: list[str] + species_to_dgf_ix: NDArray[np.int16] + balanced_species_ix: NDArray[np.int16] + unbalanced_species_ix: NDArray[np.int16] + S: NDArray[np.float64] + + def __init__( + self, + stoichiometry, + species, + reactions, + balanced_species, + species_to_dgf_ix=None, + ): + self.stoichiometry = stoichiometry + self.species = species + self.reactions = reactions + self.balanced_species = balanced_species + self.unbalanced_species = [ + s for s in species if s not in balanced_species + ] + if species_to_dgf_ix is None: + self.species_to_dgf_ix = np.arange(len(species), dtype=np.int16) + else: + self.species_to_dgf_ix = species_to_dgf_ix + self.balanced_species_ix = np.array( + [get_ix_from_list(s, species) for s in self.balanced_species], + dtype=np.int16, + ) + self.unbalanced_species_ix = np.array( + [get_ix_from_list(s, species) for s in self.unbalanced_species], + dtype=np.int16, + ) + S = np.zeros(shape=(len(species), len(reactions))) + for ix_reaction, reaction in enumerate(reactions): + for species_i, coeff in stoichiometry[reaction].items(): + ix_species = get_ix_from_list(species_i, species) + S[ix_species, ix_reaction] = coeff + self.S = S.astype(np.float64) + + def tree_flatten(self): + children = ( + self.stoichiometry, + self.species, + self.reactions, + self.balanced_species, + self.species_to_dgf_ix, + ) + aux_data = None + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(*children) + + +class RateEquationKineticModelStructure(KineticModelStructure): + rate_equations: list[RateEquation] + + def __init__( + self, + stoichiometry, + species, + reactions, + balanced_species, + rate_equations, + species_to_dgf_ix=None, + ): + super().__init__( + stoichiometry, + species, + reactions, + balanced_species, + species_to_dgf_ix, + ) + self.rate_equations = rate_equations + + def tree_flatten(self): + children = ( + self.stoichiometry, + self.species, + self.reactions, + self.balanced_species, + self.species_to_dgf_ix, + self.rate_equations, + ) + aux_data = None + return children, aux_data class KineticModel(eqx.Module, ABC): """Abstract base class for kinetic models.""" parameters: PyTree - structure: KineticModelStructure + structure: KineticModelStructure = eqx.field(static=True) @abstractmethod def flux( @@ -31,6 +136,7 @@ def flux( conc_balanced: Float[Array, " n_balanced"], ) -> Float[Array, " n"]: ... + @eqx.filter_jit def dcdt( self, t: ScalarLike, conc: Float[Array, " n_balanced"], args=None ) -> Float[Array, " n_balanced"]: @@ -43,14 +149,15 @@ def dcdt( :param conc: a one dimensional array of positive floats representing concentrations of balanced species. Must have same size as self.structure.ix_balanced """ # Noqa: E501 - sv = self.structure.S @ self.flux(conc) - return sv[self.structure.balanced_species] + v = self.flux(conc) + sv = self.structure.S @ v + return jnp.array(sv[self.structure.balanced_species_ix]) class RateEquationModel(KineticModel): """A kinetic model that specifies its fluxes using RateEquation objects.""" - rate_equations: list[RateEquation] = eqx.field(default_factory=list) + structure: RateEquationKineticModelStructure def flux( self, @@ -63,19 +170,27 @@ def flux( :return: a one dimensional array of (possibly negative) floats representing reaction fluxes. Has same size as number of columns of self.structure.S. """ # Noqa: E501 - conc = jnp.zeros(self.structure.S.shape[0]) - conc = conc.at[self.structure.balanced_species].set(conc_balanced) - conc = conc.at[self.structure.unbalanced_species].set( - jnp.exp(self.parameters.log_conc_unbalanced) + conc = get_conc( + conc_balanced, + self.parameters.log_conc_unbalanced, + self.structure, ) - t = [f(conc, self.parameters) for f in self.rate_equations] - out = jnp.array(t) - return out + flux_list = [] + for reaction_ix, (reaction_id, rate_equation) in enumerate( + zip(self.structure.reactions, self.structure.rate_equations) + ): + ipt = rate_equation.get_input( + parameters=self.parameters, + reaction_id=reaction_id, + reaction_stoichiometry=self.structure.S[:, reaction_ix], + species_to_dgf_ix=self.structure.species_to_dgf_ix, + ) + flux_list.append(rate_equation(conc, ipt)) + return jnp.array(flux_list) class KineticModelSbml(KineticModel): - balanced_ids: PyTree - sym_module: any + sym_module: Any def flux( self, @@ -83,7 +198,8 @@ def flux( ) -> Float[Array, " n"]: flux = jnp.array( self.sym_module( - **self.parameters, **dict(zip(self.balanced_ids, conc_balanced)) + **self.parameters, + **dict(zip(self.structure.balanced_species, conc_balanced)), ) ) return flux diff --git a/src/enzax/mcmc.py b/src/enzax/mcmc.py index 10d16dd..66936b1 100644 --- a/src/enzax/mcmc.py +++ b/src/enzax/mcmc.py @@ -7,19 +7,10 @@ import blackjax import chex import jax -from jax._src.random import KeyArray import jax.numpy as jnp from jax.scipy.stats import norm, multivariate_normal from jaxtyping import Array, Float, PyTree, ScalarLike -from enzax.kinetic_model import ( - RateEquationModel, - KineticModelStructure, -) -from enzax.parameters import AllostericMichaelisMentenParameterSet -from enzax.rate_equation import RateEquation -from enzax.steady_state import get_kinetic_model_steady_state - @chex.dataclass class ObservationSet: @@ -33,25 +24,6 @@ class ObservationSet: enzyme_scale: ScalarLike -@chex.dataclass -class AllostericMichaelisMentenPriorSet: - """Priors for an allosteric Michaelis-Menten model.""" - - log_kcat: Float[Array, "2 n_enzyme"] - log_enzyme: Float[Array, "2 n_enzyme"] - log_drain: Float[Array, "2 n_drain"] - dgf: tuple[ - Float[Array, " n_metabolite"], - Float[Array, " n_metabolite n_metabolite"], - ] - log_km: Float[Array, "2 n_km"] - log_ki: Float[Array, "2 n_ki"] - log_conc_unbalanced: Float[Array, "2 n_unbalanced"] - temperature: Float[Array, "2"] - log_transfer_constant: Float[Array, "2 n_allosteric_enzyme"] - log_dissociation_constant: Float[Array, "2 n_allosteric_effect"] - - class AdaptationKwargs(TypedDict): """Keyword arguments to the blackjax function window_adaptation.""" @@ -76,52 +48,6 @@ def mv_normal_prior_logdensity( ) -def posterior_logdensity_amm( - parameters: AllostericMichaelisMentenParameterSet, - structure: KineticModelStructure, - rate_equations: list[RateEquation], - obs: ObservationSet, - prior: AllostericMichaelisMentenPriorSet, - guess: Float[Array, " n_balanced"], -): - """Get the log density for an allosteric Michaelis-Menten model.""" - model = RateEquationModel(parameters, structure, rate_equations) - steady = get_kinetic_model_steady_state(model, guess) - flux = model.flux(steady) - conc = jnp.zeros(model.structure.S.shape[0]) - conc = conc.at[model.structure.balanced_species].set(steady) - conc = conc.at[model.structure.unbalanced_species].set( - jnp.exp(parameters.log_conc_unbalanced) - ) - likelihood_logdensity = ( - norm.logpdf(jnp.log(obs.conc), jnp.log(conc), obs.conc_scale).sum() - + norm.logpdf(obs.flux, flux[0], obs.flux_scale).sum() - + norm.logpdf( - jnp.log(obs.enzyme), parameters.log_enzyme, obs.enzyme_scale - ).sum() - ) - prior_logdensity = ( - ind_normal_prior_logdensity(parameters.log_kcat, prior.log_kcat) - + ind_normal_prior_logdensity(parameters.log_enzyme, prior.log_enzyme) - + ind_normal_prior_logdensity(parameters.log_drain, prior.log_drain) - + mv_normal_prior_logdensity(parameters.dgf, prior.dgf) - + ind_normal_prior_logdensity(parameters.log_km, prior.log_km) - + ind_normal_prior_logdensity( - parameters.log_conc_unbalanced, prior.log_conc_unbalanced - ) - + ind_normal_prior_logdensity(parameters.temperature, prior.temperature) - + ind_normal_prior_logdensity(parameters.log_ki, prior.log_ki) - + ind_normal_prior_logdensity( - parameters.log_transfer_constant, prior.log_transfer_constant - ) - + ind_normal_prior_logdensity( - parameters.log_dissociation_constant, - prior.log_dissociation_constant, - ) - ) - return prior_logdensity + likelihood_logdensity - - @functools.partial(jax.jit, static_argnames=["kernel", "num_samples"]) def _inference_loop(rng_key, kernel, initial_state, num_samples): """Run MCMC with blackjax.""" @@ -137,7 +63,7 @@ def one_step(state, rng_key): def run_nuts( logdensity_fn: Callable, - rng_key: KeyArray, + rng_key: Array, init_parameters: PyTree, num_warmup: int, num_samples: int, @@ -179,10 +105,17 @@ def ind_prior_from_truth(truth: Float[Array, " _"], sd: ScalarLike): def get_idata(samples, info, coords=None, dims=None) -> az.InferenceData: """Get an arviz InferenceData from a blackjax NUTS output.""" - sample_dict = { - k: jnp.expand_dims(getattr(samples.position, k), 0) - for k in samples.position.__dataclass_fields__.keys() - } + if coords is None: + coords = dict() + sample_dict = dict() + for k in samples.position.__dataclass_fields__.keys(): + samples_k = getattr(samples.position, k) + if isinstance(samples_k, Array): + sample_dict[k] = jnp.expand_dims(samples_k, 0) + elif isinstance(samples_k, dict): + sample_dict[k] = jnp.expand_dims( + jnp.concat([v.T for v in samples_k.values()]).T, 0 + ) posterior = az.convert_to_inference_data( sample_dict, group="posterior", diff --git a/src/enzax/parameters.py b/src/enzax/parameters.py deleted file mode 100644 index 6534169..0000000 --- a/src/enzax/parameters.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Module with parameters and parameter sets. - -These are not required, but they do provide handy shape checking, for example to ensure that your parameter set has the same number of enzymes and kcat parameters. - -""" # Noqa: E501 - -import equinox as eqx -from jaxtyping import Array, Float, Scalar, jaxtyped -from typeguard import typechecked - -LogKcat = Float[Array, " n_enzyme"] -LogEnzyme = Float[Array, " n_enzyme"] -Dgf = Float[Array, " n_metabolite"] -LogKm = Float[Array, " n_km"] -LogKi = Float[Array, " n_ki"] -LogConcUnbalanced = Float[Array, " n_unbalanced"] -LogDrain = Float[Array, " n_drain"] -LogTransferConstant = Float[Array, " n_allosteric_enzyme"] -LogDissociationConstant = Float[Array, " n_allosteric_effect"] - - -@jaxtyped(typechecker=typechecked) -class MichaelisMentenParameterSet(eqx.Module): - """Parameters for a model with Michaelis Menten kinetics. - - This kind of parameter set supports models with the following rate laws: - - - enzax.rate_equations.drain.Drain - - enzax.rate_equations.michaelis_menten.IrreversibleMichaelisMenten - - enzax.rate_equations.michaelis_menten.ReversibleMichaelisMenten - - """ - - log_kcat: LogKcat - log_enzyme: LogEnzyme - dgf: Dgf - log_km: LogKm - log_ki: LogKi - log_conc_unbalanced: LogConcUnbalanced - temperature: Scalar - log_drain: LogDrain - - -class AllostericMichaelisMentenParameterSet(MichaelisMentenParameterSet): - """Parameters for a model with Michaelis Menten kinetics, with allostery. - - Reactions can be any out of: - - - drain.Drain - - michaelis_menten.IrreversibleMichaelisMenten - - michaelis_menten.ReversibleMichaelisMenten - - generalised_mwc.AllostericIrreversibleMichaelisMenten - - generalised_mwc.AllostericReversibleMichaelisMenten - - """ - - log_transfer_constant: LogTransferConstant - log_dissociation_constant: LogDissociationConstant diff --git a/src/enzax/rate_equation.py b/src/enzax/rate_equation.py index aa89443..883e9a9 100644 --- a/src/enzax/rate_equation.py +++ b/src/enzax/rate_equation.py @@ -2,6 +2,8 @@ from abc import ABC, abstractmethod from equinox import Module +import numpy as np +from numpy.typing import NDArray from jaxtyping import Array, Float, PyTree, Scalar @@ -12,8 +14,19 @@ class RateEquation(Module, ABC): """Abstract definition of a rate equation. - A rate equation is an equinox [Module](https://docs.kidger.site/equinox/api/module/module/) with a `__call__` method that takes in a 1 dimensional array of concentrations and an arbitrary PyTree of parameters, returning a scalar value representing a single flux. + A rate equation is an equinox [Module](https://docs.kidger.site/equinox/api/module/module/) with a `__call__` method that takes in a 1 dimensional array of concentrations and an arbitrary PyTree of other inputs, returning a scalar value representing a single flux. """ # Noqa: E501 @abstractmethod - def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: ... + def get_input( + self, + parameters: PyTree, + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ) -> PyTree: ... + + @abstractmethod + def __call__( + self, conc: ConcArray, rate_equation_input: PyTree + ) -> Scalar: ... diff --git a/src/enzax/rate_equations/__init__.py b/src/enzax/rate_equations/__init__.py index a0da47d..32da2d4 100644 --- a/src/enzax/rate_equations/__init__.py +++ b/src/enzax/rate_equations/__init__.py @@ -1,8 +1,8 @@ from enzax.rate_equations.michaelis_menten import ( ReversibleMichaelisMenten, IrreversibleMichaelisMenten, - MichaelisMenten, ) + from enzax.rate_equations.generalised_mwc import ( AllostericReversibleMichaelisMenten, AllostericIrreversibleMichaelisMenten, @@ -12,7 +12,6 @@ AVAILABLE_RATE_EQUATIONS = [ ReversibleMichaelisMenten, IrreversibleMichaelisMenten, - MichaelisMenten, AllostericReversibleMichaelisMenten, AllostericIrreversibleMichaelisMenten, Drain, diff --git a/src/enzax/rate_equations/drain.py b/src/enzax/rate_equations/drain.py index 0e1994c..17833c6 100644 --- a/src/enzax/rate_equations/drain.py +++ b/src/enzax/rate_equations/drain.py @@ -1,27 +1,30 @@ +import equinox as eqx from jax import numpy as jnp +import numpy as np +from numpy.typing import NDArray from jaxtyping import PyTree, Scalar from enzax.rate_equation import ConcArray, RateEquation -def get_drain_flux(sign: Scalar, log_v: Scalar) -> Scalar: - """Get the flux of a drain reaction. - - :param sign: a scalar value (should be either one or zero) representing the direction of the reaction. - - :param log_v: a scalar representing the magnitude of the reaction, on log scale. - - """ # Noqa: E501 - return sign * jnp.exp(log_v) +class DrainInput(eqx.Module): + abs_v: Scalar class Drain(RateEquation): """A drain reaction.""" - sign: Scalar - drain_ix: int + sign: float + + def get_input( + self, + parameters: PyTree, + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ): + return DrainInput(abs_v=jnp.exp(parameters.log_drain[reaction_id])) - def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: + def __call__(self, conc: ConcArray, drain_input: PyTree) -> Scalar: """Get the flux of a drain reaction.""" - log_v = parameters.log_drain[self.drain_ix] - return get_drain_flux(self.sign, log_v) + return self.sign * drain_input.abs_v diff --git a/src/enzax/rate_equations/generalised_mwc.py b/src/enzax/rate_equations/generalised_mwc.py index 90f867e..17b6edb 100644 --- a/src/enzax/rate_equations/generalised_mwc.py +++ b/src/enzax/rate_equations/generalised_mwc.py @@ -1,14 +1,88 @@ +import equinox as eqx from jax import numpy as jnp -from jaxtyping import Array, Float, Int, PyTree, Scalar +from jaxtyping import Array, Float, PyTree, Scalar +import numpy as np +from numpy.typing import NDArray -from enzax.rate_equation import ConcArray from enzax.rate_equations.michaelis_menten import ( + free_enzyme_ratio_imm, + free_enzyme_ratio_rmm, IrreversibleMichaelisMenten, - MichaelisMenten, + IrreversibleMichaelisMentenInput, ReversibleMichaelisMenten, + ReversibleMichaelisMentenInput, ) +class AllostericIrreversibleMichaelisMentenInput( + IrreversibleMichaelisMentenInput +): + dc_inhibitor: Float[Array, " n_inhibitor"] + dc_activator: Float[Array, " n_activator"] + tc: Scalar + + +class AllostericReversibleMichaelisMentenInput(ReversibleMichaelisMentenInput): + dc_inhibitor: Float[Array, " n_inhibitor"] + dc_activator: Float[Array, " n_activator"] + tc: Scalar + + +def get_allosteric_irreversible_michaelis_menten_input( + parameters: PyTree, + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ci_ix: NDArray[np.int16], +) -> AllostericIrreversibleMichaelisMentenInput: + ix_substrate = np.argwhere(reaction_stoichiometry < 0.0).flatten() + return AllostericIrreversibleMichaelisMentenInput( + kcat=jnp.exp(parameters.log_kcat[reaction_id]), + enzyme=jnp.exp(parameters.log_enzyme[reaction_id]), + ix_substrate=ix_substrate, + substrate_kms=jnp.exp(parameters.log_substrate_km[reaction_id]), + substrate_stoichiometry=reaction_stoichiometry[ix_substrate], + ix_ki_species=ci_ix, + ki=jnp.exp(parameters.log_ki[reaction_id]), + dc_inhibitor=jnp.exp(parameters.log_dc_inhibitor[reaction_id]), + dc_activator=jnp.exp(parameters.log_dc_activator[reaction_id]), + tc=jnp.exp(parameters.log_tc[reaction_id]), + ) + + +def get_allosteric_reversible_michaelis_menten_input( + parameters: PyTree, + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ci_ix: NDArray[np.int16], + water_stoichiometry: float, +) -> AllostericReversibleMichaelisMentenInput: + ix_reactant = np.argwhere(reaction_stoichiometry != 0.0).flatten() + ix_substrate = np.argwhere(reaction_stoichiometry < 0.0).flatten() + ix_product = np.argwhere(reaction_stoichiometry > 0.0).flatten() + return AllostericReversibleMichaelisMentenInput( + kcat=jnp.exp(parameters.log_kcat[reaction_id]), + enzyme=jnp.exp(parameters.log_enzyme[reaction_id]), + substrate_kms=jnp.exp(parameters.log_substrate_km[reaction_id]), + product_kms=jnp.exp(parameters.log_product_km[reaction_id]), + ki=jnp.exp(parameters.log_ki[reaction_id]), + dgf=parameters.dgf[species_to_dgf_ix][ix_reactant], + temperature=parameters.temperature, + ix_ki_species=ci_ix, + ix_reactant=ix_reactant, + ix_substrate=ix_substrate, + ix_product=ix_product, + reactant_stoichiometry=reaction_stoichiometry[ix_reactant], + substrate_stoichiometry=reaction_stoichiometry[ix_substrate], + product_stoichiometry=reaction_stoichiometry[ix_product], + water_stoichiometry=water_stoichiometry, + dc_inhibitor=jnp.exp(parameters.log_dc_inhibitor[reaction_id]), + dc_activator=jnp.exp(parameters.log_dc_activator[reaction_id]), + tc=jnp.exp(parameters.log_tc[reaction_id]), + ) + + def generalised_mwc_effect( conc_inhibitor: Float[Array, " n_inhibition"], dc_inhibitor: Float[Array, " n_inhibition"], @@ -29,66 +103,109 @@ def generalised_mwc_effect( return out -class GeneralisedMWC(MichaelisMenten): - """Mixin class for allosteric rate laws, assuming generalised MWC kinetics. - - See Popova and Sel'kov 1975 for the rate law: https://doi.org/10.1016/0014-5793(75)80034-2. - - Note that it is assumed there is a free_enzyme_ratio method available - that is why this is a subclass of MichaelisMenten rather than RateEquation. - """ # noqa: E501 - - subunits: int - tc_ix: int - ix_dc_activation: Int[Array, " n_activation"] - ix_dc_inhibition: Int[Array, " n_inhibition"] - species_activation: Int[Array, " n_activation"] - species_inhibition: Int[Array, " n_inhibition"] - - def get_tc(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_transfer_constant[self.tc_ix]) +class AllostericIrreversibleMichaelisMenten(IrreversibleMichaelisMenten): + """A reaction with irreversible Michaelis Menten kinetics and allostery.""" - def get_dc_activation(self, parameters: PyTree) -> Scalar: - return jnp.exp( - parameters.log_dissociation_constant[self.ix_dc_activation], + ix_allosteric_inhibitors: NDArray[np.int16] = eqx.field( + default_factory=lambda: np.array([], dtype=np.int16) + ) + ix_allosteric_activators: NDArray[np.int16] = eqx.field( + default_factory=lambda: np.array([], dtype=np.int16) + ) + subunits: int = 1 + + def get_input( + self, + parameters: PyTree, + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ): + return get_allosteric_irreversible_michaelis_menten_input( + parameters=parameters, + reaction_id=reaction_id, + reaction_stoichiometry=reaction_stoichiometry, + species_to_dgf_ix=species_to_dgf_ix, + ci_ix=self.ix_ki_species, ) - def get_dc_inhibition(self, parameters: PyTree) -> Scalar: - return jnp.exp( - parameters.log_dissociation_constant[self.ix_dc_inhibition], + def __call__( + self, + conc: Float[Array, " n"], + aimm_input: AllostericIrreversibleMichaelisMentenInput, + ) -> Scalar: + """The flux of an irreversible allosteric Michaelis Menten reaction.""" + fer = free_enzyme_ratio_imm( + substrate_conc=conc[aimm_input.ix_substrate], + substrate_km=aimm_input.substrate_kms, + ki=aimm_input.ki, + inhibitor_conc=conc[self.ix_ki_species], + substrate_stoichiometry=aimm_input.substrate_stoichiometry, ) - - def allosteric_effect(self, conc: ConcArray, parameters: PyTree) -> Scalar: - return generalised_mwc_effect( - conc_inhibitor=conc[self.species_inhibition], - conc_activator=conc[self.species_activation], - free_enzyme_ratio=self.free_enzyme_ratio(conc, parameters), - tc=self.get_tc(parameters), - dc_inhibitor=self.get_dc_inhibition(parameters), - dc_activator=self.get_dc_activation(parameters), + allosteric_effect = generalised_mwc_effect( + conc_inhibitor=conc[self.ix_allosteric_inhibitors], + dc_inhibitor=aimm_input.dc_inhibitor, + dc_activator=aimm_input.dc_activator, + conc_activator=conc[self.ix_allosteric_activators], + free_enzyme_ratio=fer, + tc=aimm_input.tc, subunits=self.subunits, ) - - -class AllostericIrreversibleMichaelisMenten( - GeneralisedMWC, IrreversibleMichaelisMenten -): - """A reaction with irreversible Michaelis Menten kinetics and allostery.""" - - def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar: - """The flux of an irreversible allosteric Michaelis Menten reaction.""" - allosteric_effect = self.allosteric_effect(conc, parameters) - non_allosteric_rate = super().__call__(conc, parameters) + non_allosteric_rate = super().__call__(conc, aimm_input) return non_allosteric_rate * allosteric_effect -class AllostericReversibleMichaelisMenten( - GeneralisedMWC, - ReversibleMichaelisMenten, -): +class AllostericReversibleMichaelisMenten(ReversibleMichaelisMenten): """A reaction with reversible Michaelis Menten kinetics and allostery.""" - def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: - """The flux of an allosteric reversible Michaelis Menten reaction.""" - allosteric_effect = self.allosteric_effect(conc, parameters) - non_allosteric_rate = super().__call__(conc, parameters) + ix_allosteric_inhibitors: NDArray[np.int16] = eqx.field( + default_factory=lambda: np.array([], dtype=np.int16) + ) + ix_allosteric_activators: NDArray[np.int16] = eqx.field( + default_factory=lambda: np.array([], dtype=np.int16) + ) + subunits: int = 1 + + def get_input( + self, + parameters: PyTree, + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ): + return get_allosteric_reversible_michaelis_menten_input( + parameters=parameters, + reaction_id=reaction_id, + reaction_stoichiometry=reaction_stoichiometry, + species_to_dgf_ix=species_to_dgf_ix, + ci_ix=self.ix_ki_species, + water_stoichiometry=self.water_stoichiometry, + ) + + def __call__( + self, + conc: Float[Array, " n"], + armm_input: AllostericReversibleMichaelisMentenInput, + ) -> Scalar: + """The flux of an irreversible allosteric Michaelis Menten reaction.""" + fer = free_enzyme_ratio_rmm( + substrate_conc=conc[armm_input.ix_substrate], + product_conc=conc[armm_input.ix_product], + substrate_kms=armm_input.substrate_kms, + product_kms=armm_input.product_kms, + ix_ki_species=conc[self.ix_ki_species], + ki=armm_input.ki, + substrate_stoichiometry=armm_input.substrate_stoichiometry, + product_stoichiometry=armm_input.product_stoichiometry, + ) + allosteric_effect = generalised_mwc_effect( + conc_inhibitor=conc[self.ix_allosteric_inhibitors], + dc_inhibitor=armm_input.dc_inhibitor, + dc_activator=armm_input.dc_activator, + conc_activator=conc[self.ix_allosteric_activators], + free_enzyme_ratio=fer, + tc=armm_input.tc, + subunits=self.subunits, + ) + non_allosteric_rate = super().__call__(conc, armm_input) return non_allosteric_rate * allosteric_effect diff --git a/src/enzax/rate_equations/michaelis_menten.py b/src/enzax/rate_equations/michaelis_menten.py index 55cf1d0..39d96d7 100644 --- a/src/enzax/rate_equations/michaelis_menten.py +++ b/src/enzax/rate_equations/michaelis_menten.py @@ -1,110 +1,106 @@ -from abc import abstractmethod +import equinox as eqx +import numpy as np from jax import numpy as jnp from jaxtyping import Array, Float, Int, PyTree, Scalar +from numpy.typing import NDArray + +from enzax.rate_equation import RateEquation + + +class IrreversibleMichaelisMentenInput(eqx.Module): + kcat: Scalar + enzyme: Scalar + ix_ki_species: NDArray[np.int16] + ki: Float[Array, " n_ki"] + ix_substrate: NDArray[np.int16] + substrate_kms: Float[Array, " n_substrate"] + substrate_stoichiometry: NDArray[np.float64] + + +def get_irreversible_michaelis_menten_input( + parameters: PyTree, + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ci_ix: NDArray[np.int16], +) -> IrreversibleMichaelisMentenInput: + ix_substrate = np.argwhere(reaction_stoichiometry < 0.0).flatten() + return IrreversibleMichaelisMentenInput( + kcat=jnp.exp(parameters.log_kcat[reaction_id]), + enzyme=jnp.exp(parameters.log_enzyme[reaction_id]), + ix_substrate=ix_substrate, + substrate_kms=jnp.exp(parameters.log_substrate_km[reaction_id]), + substrate_stoichiometry=reaction_stoichiometry[ix_substrate], + ix_ki_species=ci_ix, + ki=jnp.exp(parameters.log_ki[reaction_id]), + ) + -from enzax.rate_equation import RateEquation, ConcArray +class ReversibleMichaelisMentenInput(eqx.Module): + kcat: Scalar + enzyme: Scalar + ki: Float[Array, " n_ki"] + substrate_kms: Float[Array, " n_substrate"] + product_kms: Float[Array, " n_product"] + dgf: Float[Array, " n_reactant"] + temperature: Scalar + ix_ki_species: NDArray[np.int16] + ix_reactant: NDArray[np.int16] + ix_substrate: NDArray[np.int16] + ix_product: NDArray[np.int16] + reactant_stoichiometry: NDArray[np.float64] + substrate_stoichiometry: NDArray[np.float64] + product_stoichiometry: NDArray[np.float64] + water_stoichiometry: float + + +def get_reversible_michaelis_menten_input( + parameters: PyTree, + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ci_ix: NDArray[np.int16], + water_stoichiometry: float, +) -> ReversibleMichaelisMentenInput: + ix_reactant = np.argwhere(reaction_stoichiometry != 0.0).flatten() + ix_substrate = np.argwhere(reaction_stoichiometry < 0.0).flatten() + ix_product = np.argwhere(reaction_stoichiometry > 0.0).flatten() + return ReversibleMichaelisMentenInput( + kcat=jnp.exp(parameters.log_kcat[reaction_id]), + enzyme=jnp.exp(parameters.log_enzyme[reaction_id]), + substrate_kms=jnp.exp(parameters.log_substrate_km[reaction_id]), + product_kms=jnp.exp(parameters.log_product_km[reaction_id]), + ki=jnp.exp(parameters.log_ki[reaction_id]), + dgf=parameters.dgf[species_to_dgf_ix][ix_reactant], + temperature=parameters.temperature, + ix_ki_species=ci_ix, + ix_reactant=ix_reactant, + ix_substrate=ix_substrate, + ix_product=ix_product, + reactant_stoichiometry=reaction_stoichiometry[ix_reactant], + substrate_stoichiometry=reaction_stoichiometry[ix_substrate], + product_stoichiometry=reaction_stoichiometry[ix_product], + water_stoichiometry=water_stoichiometry, + ) def numerator_mm( - conc: ConcArray, - km: Float[Array, " n"], - ix_substrate: Int[Array, " n_substrate"], - substrate_km_positions: Int[Array, " n_substrate"], + substrate_conc: Float[Array, " n_substrate"], + substrate_kms: Int[Array, " n_substrate"], ) -> Scalar: """Get the product of each substrate's concentration over its km. This quantity is the numerator in a Michaelis Menten reaction's rate equation """ # Noqa: E501 - return jnp.prod((conc[ix_substrate] / km[substrate_km_positions])) - - -class MichaelisMenten(RateEquation): - """Base class for Michaelis Menten rate equations. - - Subclasses need to implement the __call__ and free_enzyme_ratio methods. - - """ - - kcat_ix: int - enzyme_ix: int - km_ix: Int[Array, " n"] - ki_ix: Int[Array, " n_ki"] - reactant_stoichiometry: Float[Array, " n"] - ix_substrate: Int[Array, " n_substrate"] - ix_ki_species: Int[Array, " n_ki"] - substrate_km_positions: Int[Array, " n_substrate"] - substrate_reactant_positions: Int[Array, " n_substrate"] - - def get_kcat(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_kcat[self.kcat_ix]) - - def get_km(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_km[self.km_ix]) - - def get_ki(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_ki[self.ki_ix]) - - def get_enzyme(self, parameters: PyTree) -> Scalar: - return jnp.exp(parameters.log_enzyme[self.enzyme_ix]) - - @abstractmethod - def free_enzyme_ratio( - self, - conc: ConcArray, - parameters: PyTree, - ) -> Scalar: ... - - -def free_enzyme_ratio_imm( - conc_sub: Float[Array, " n_substrate"], - km_sub: Float[Array, " n_substrate"], - stoich_sub: Float[Array, " n_substrate"], - ki: Float[Array, " n_ki"], - conc_inhibitor: Float[Array, " n_ki"], -) -> Scalar: - """Free enzyme ratio for irreversible Michaelis Menten reactions.""" - return 1.0 / ( - jnp.prod(((conc_sub / km_sub) + 1) ** jnp.abs(stoich_sub)) - + jnp.sum(conc_inhibitor / ki) - ) - - -class IrreversibleMichaelisMenten(MichaelisMenten): - """A reaction with irreversible Michaelis Menten kinetics.""" - - def free_enzyme_ratio(self, conc, parameters): - return free_enzyme_ratio_imm( - conc_sub=conc[self.ix_substrate], - km_sub=self.get_km(parameters)[self.substrate_km_positions], - ki=self.get_ki(parameters), - conc_inhibitor=conc[self.ix_ki_species], - stoich_sub=self.reactant_stoichiometry[ - self.substrate_reactant_positions - ], - ) - - def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar: - """Get flux of a reaction with irreversible Michaelis Menten kinetics.""" # noqa: E501 - kcat = self.get_kcat(parameters) - enzyme = self.get_enzyme(parameters) - km = self.get_km(parameters) - numerator = numerator_mm( - conc=conc, - km=km, - ix_substrate=self.ix_substrate, - substrate_km_positions=self.substrate_km_positions, - ) - free_enzyme_ratio = self.free_enzyme_ratio(conc, parameters) - return kcat * enzyme * numerator * free_enzyme_ratio + return jnp.prod((substrate_conc / substrate_kms)) def get_reversibility( - conc: Float[Array, " n"], - water_stoichiometry: Scalar, + reactant_conc: Float[Array, " n_reactant"], dgf: Float[Array, " n_reactant"], temperature: Scalar, - reactant_stoichiometry: Float[Array, " n_reactant"], - ix_reactants: Int[Array, " n_reactant"], + reactant_stoichiometry: NDArray[np.float64], + water_stoichiometry: float, ) -> Scalar: """Get the reversibility of a reaction. @@ -113,81 +109,151 @@ def get_reversibility( """ RT = temperature * 0.008314 dgf_water = -150.9 - dgr = reactant_stoichiometry @ dgf + water_stoichiometry * dgf_water - quotient = reactant_stoichiometry @ jnp.log(conc[ix_reactants]) - out = 1.0 - jnp.exp(((dgr + RT * quotient) / RT)) + dgr = ( + reactant_stoichiometry.T @ dgf + water_stoichiometry * dgf_water + ).flatten() + quotient = (reactant_stoichiometry.T @ jnp.log(reactant_conc)).flatten() + out = 1.0 - jnp.exp((dgr + RT * quotient) / RT)[0] return out +def free_enzyme_ratio_imm( + substrate_conc: Float[Array, " n_substrate"], + substrate_km: Float[Array, " n_substrate"], + ki: Float[Array, " n_ki"], + inhibitor_conc: Float[Array, " n_ki"], + substrate_stoichiometry: NDArray[np.float64], +) -> Scalar: + """Free enzyme ratio for irreversible Michaelis Menten reactions.""" + return 1.0 / ( + jnp.prod( + ((substrate_conc / substrate_km) + 1) + ** jnp.abs(substrate_stoichiometry) + ) + + jnp.sum(inhibitor_conc / ki) + ) + + def free_enzyme_ratio_rmm( - conc_sub: Float[Array, " n_substrate"], - km_sub: Float[Array, " n_substrate"], - stoich_sub: Float[Array, " n_substrate"], - conc_prod: Float[Array, " n_product"], - km_prod: Float[Array, " n_prod"], - stoich_prod: Float[Array, " n_prod"], - conc_inhibitor: Float[Array, " n_ki"], + substrate_conc: Float[Array, " n_substrate"], + product_conc: Float[Array, " n_product"], + substrate_kms: Float[Array, " n_substrate"], + product_kms: Float[Array, " n_product"], + ix_ki_species: Float[Array, " n_ki"], ki: Float[Array, " n_ki"], + substrate_stoichiometry: NDArray[np.float64], + product_stoichiometry: NDArray[np.float64], ) -> Scalar: """The free enzyme ratio for a reversible Michaelis Menten reaction.""" return 1.0 / ( -1.0 - + jnp.prod(((conc_sub / km_sub) + 1.0) ** jnp.abs(stoich_sub)) - + jnp.prod(((conc_prod / km_prod) + 1.0) ** jnp.abs(stoich_prod)) - + jnp.sum(conc_inhibitor / ki) + + jnp.prod( + ((substrate_conc / substrate_kms) + 1.0) + ** jnp.abs(substrate_stoichiometry) + ) + + jnp.prod( + ((product_conc / product_kms) + 1.0) + ** jnp.abs(product_stoichiometry) + ) + + jnp.sum(ix_ki_species / ki) ) -class ReversibleMichaelisMenten(MichaelisMenten): +class IrreversibleMichaelisMenten(RateEquation): + """A reaction with irreversible Michaelis Menten kinetics.""" + + ix_ki_species: NDArray[np.int16] = eqx.field( + default_factory=lambda: np.array([], dtype=np.int64) + ) + + def get_input( + self, + parameters: PyTree, + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ): + return get_irreversible_michaelis_menten_input( + parameters=parameters, + reaction_id=reaction_id, + reaction_stoichiometry=reaction_stoichiometry, + species_to_dgf_ix=species_to_dgf_ix, + ci_ix=self.ix_ki_species, + ) + + def __call__( + self, + conc: Float[Array, " n"], + imm_input: IrreversibleMichaelisMentenInput, + ) -> Scalar: + """Get flux of a reaction with irreversible Michaelis Menten kinetics.""" # noqa: E501 + numerator = numerator_mm( + substrate_conc=conc[imm_input.ix_substrate], + substrate_kms=imm_input.substrate_kms, + ) + fer = free_enzyme_ratio_imm( + substrate_conc=conc[imm_input.ix_substrate], + substrate_km=imm_input.substrate_kms, + ki=imm_input.ki, + inhibitor_conc=conc[imm_input.ix_ki_species], + substrate_stoichiometry=imm_input.substrate_stoichiometry, + ) + return imm_input.kcat * imm_input.enzyme * numerator * fer + + +class ReversibleMichaelisMenten(RateEquation): """A reaction with reversible Michaelis Menten kinetics.""" - ix_product: Int[Array, " n_product"] - ix_reactants: Int[Array, " n_reactant"] - product_reactant_positions: Int[Array, " n_product"] - product_km_positions: Int[Array, " n_product"] - water_stoichiometry: Scalar - reactant_to_dgf: Int[Array, " n_reactant"] - - def _get_dgf(self, parameters: PyTree): - return parameters.dgf[self.reactant_to_dgf] - - def free_enzyme_ratio(self, conc: ConcArray, parameters: PyTree) -> Scalar: - return free_enzyme_ratio_rmm( - conc_sub=conc[self.ix_substrate], - km_sub=self.get_km(parameters)[self.substrate_reactant_positions], - stoich_sub=self.reactant_stoichiometry[ - self.substrate_reactant_positions - ], - conc_prod=conc[self.ix_product], - km_prod=self.get_km(parameters)[self.product_reactant_positions], - stoich_prod=self.reactant_stoichiometry[ - self.product_reactant_positions - ], - conc_inhibitor=conc[self.ix_ki_species], - ki=self.get_ki(parameters), + ix_ki_species: NDArray[np.int16] = eqx.field( + default_factory=lambda: np.array([], dtype=np.int16) + ) + water_stoichiometry: float = eqx.field(default_factory=lambda: 0.0) + + def get_input( + self, + parameters: PyTree, + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ): + return get_reversible_michaelis_menten_input( + parameters=parameters, + reaction_id=reaction_id, + reaction_stoichiometry=reaction_stoichiometry, + species_to_dgf_ix=species_to_dgf_ix, + ci_ix=self.ix_ki_species, + water_stoichiometry=self.water_stoichiometry, ) - def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar: + def __call__( + self, + conc: Float[Array, " n"], + rmm_input: ReversibleMichaelisMentenInput, + ) -> Scalar: """Get flux of a reaction with reversible Michaelis Menten kinetics. :param conc: A 1D array of non-negative numbers representing concentrations of the species that the reaction produces and consumes. """ # noqa: E501 - kcat = self.get_kcat(parameters) - enzyme = self.get_enzyme(parameters) - reversibility = get_reversibility( - conc=conc, - water_stoichiometry=self.water_stoichiometry, - dgf=self._get_dgf(parameters), - temperature=parameters.temperature, - reactant_stoichiometry=self.reactant_stoichiometry, - ix_reactants=self.ix_reactants, + rev = get_reversibility( + reactant_conc=conc[rmm_input.ix_reactant], + reactant_stoichiometry=rmm_input.reactant_stoichiometry, + dgf=rmm_input.dgf, + temperature=rmm_input.temperature, + water_stoichiometry=rmm_input.water_stoichiometry, ) numerator = numerator_mm( - conc=conc, - km=self.get_km(parameters), - ix_substrate=self.ix_substrate, - substrate_km_positions=self.substrate_km_positions, + substrate_conc=conc[rmm_input.ix_substrate], + substrate_kms=rmm_input.substrate_kms, + ) + fer = free_enzyme_ratio_rmm( + substrate_conc=conc[rmm_input.ix_substrate], + product_conc=conc[rmm_input.ix_product], + ix_ki_species=conc[rmm_input.ix_ki_species], + substrate_kms=rmm_input.substrate_kms, + product_kms=rmm_input.product_kms, + substrate_stoichiometry=rmm_input.substrate_stoichiometry, + product_stoichiometry=rmm_input.product_stoichiometry, + ki=rmm_input.ki, ) - free_enzyme_ratio = self.free_enzyme_ratio(conc, parameters) - return reversibility * kcat * enzyme * numerator * free_enzyme_ratio + return rev * rmm_input.kcat * rmm_input.enzyme * numerator * fer diff --git a/src/enzax/steady_state.py b/src/enzax/steady_state.py index a89f745..7437ef4 100644 --- a/src/enzax/steady_state.py +++ b/src/enzax/steady_state.py @@ -51,7 +51,6 @@ def get_kinetic_model_steady_state( t1, dt0, guess, - args=model, max_steps=max_steps, stepsize_controller=controller, event=event, diff --git a/tests/data/expected_methionine_gradient.json b/tests/data/expected_methionine_gradient.json new file mode 100644 index 0000000..83abf48 --- /dev/null +++ b/tests/data/expected_methionine_gradient.json @@ -0,0 +1 @@ +[-51.008806318357124, 44.62125359725783, 46.668133361083704, -112.7484826050269, -0.21853936474188224, -20.04100693766347, -21.046992055692126, -27.029305319714087, -23.107933540053693, -10.831961241573868, -7.442956276058797, -28.82230903289454, -22.87688514741915, -63.96941462805278, 91.03560765315373, 10.573024526064392, -26.068355084866106, 12.720942038333074, 7.723972171145526, 51.31272574643639, -52.55796473716283, 122.11340482806465, 19.857181701644674, 37.17973577495537, 11.212172526568366, 120.58711257251609, 69.99249085550201, -94.45124667621965, 51.204443757792795, 52.405258216632795, -45.05001949589764, -149.04233199339694, 27.62149868641839, 36.91166292624504, 109.14923727725667, 133.32327445006953, 378.63173605788427, -20.25453860618623, 17.98290589428491, -0.008014753935453442, 1.390914925131441, 1.0827290721714595, 0.0, 0.0, 0.0, 0.0, 0.0, 5.774425772898505, 0.0, 0.0, -5.774425772898505, -5.774425772898505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.455892028402676, -33.418521957445726, -63.96941505648969, -63.96941589328896, -42.922426803111094, -63.96941505649067, -86.00714492908733, -63.9694153507284, -155.31948281663819, -110.63755110207094, -63.96941623344747, -63.75087668807969, -63.969417159576764, -74.54244120146943, -63.969415350730365, 0.00030317394151436246, 0.3144610971331224, -2.3261150730556386, -14.115986977104384, -0.03979371864763597, -0.055215199560989005, 0.2784311817245577, -0.0006063958903968876, -5.783239063037292, -0.3295244858681964, 1.710877943445662, -336.4591824327092] diff --git a/tests/data/methionine_pldf_grad.json b/tests/data/methionine_pldf_grad.json deleted file mode 100644 index 1d4bf6d..0000000 --- a/tests/data/methionine_pldf_grad.json +++ /dev/null @@ -1,107 +0,0 @@ -{ - "log_kcat": { - "MAT1": -8.515730409411995, - "MAT3": -2.5685756593727778, - "METH-Gen": -27.620460850911506, - "GNMT1": -4.548438844490129, - "AHC1": -11.75344338177346, - "MS1": -16.03214664980906, - "BHMT1": 12.038008848152892, - "CBS1": -27.970086306108342, - "MTHFR1": 21.633517021058395, - "PROT1": -11.72862587853564 - }, - "log_enzyme": { - "MAT1": 6.136798020189005, - "MAT3": 12.083952770228223, - "METH-Gen": -12.96793242131051, - "GNMT1": 10.10408958511087, - "AHC1": 2.8990850478275423, - "MS1": -1.379618220208057, - "BHMT1": 26.690537277753894, - "CBS1": -13.317557876507342, - "MTHFR1": 36.286045450659394, - "PROT1": 2.9239025510653605 - }, - "log_drain": { - "the_drain": 77.11394277227623 - }, - "log_km": { - "met-L MAT1": 6.190853781501735, - "atp MAT1": 5.292693046521437, - "met-L MAT3": 2.4814738123803086, - "atp MAT3": 1.705092962802143, - "amet METH-Gen": 6.601750014600864, - "amet GNMT1": 4.590545886354766, - "gly GNMT1": 4.820977250097301, - "ahcys AHC1": 11.683829123279876, - "hcys-L AHC1": -2.9137920929850405, - "adn AHC1": -1.769215504424688, - "hcys-L MS1": 5.240078482919715, - "5mthf MS1": 14.652528429601952, - "hcys-L BHMT1": -10.22016805310029, - "glyb BHMT1": -10.688987009618199, - "hcys-L CBS1": 25.82504526708683, - "ser-L CBS1": 0.049935152823941564, - "mlthf MTHFR1": -20.851184438848044, - "nadph MTHFR1": -2.4216899858227414, - "met-L PROT1": 5.971084568691083 - }, - "dgf": { - "met-L": 0.0, - "atp": 0.0, - "pi": 0.0, - "ppi": 0.0, - "amet": 0.0, - "ahcys": -1.3226601224426768, - "gly": 0.0, - "sarcs": 0.0, - "hcys-L": 1.3226601224426768, - "adn": 1.3226601224426768, - "thf": 0.0, - "5mthf": 0.0, - "mlthf": 0.0, - "glyb": 0.0, - "dmgly": 0.0, - "ser-L": 0.0, - "nadp": 0.0, - "nadph": 0.0, - "cyst-L": 0.0 - }, - "log_ki": { - "MAT1": -0.3185780584014895, - "METH-Gen": -0.24799879824208654, - "GNMT1": 0.0018358603202088403 - }, - "log_conc_unbalanced": { - "atp": 7.654742420277419, - "pi": 14.652528429601, - "ppi": 14.652528429601, - "gly": 9.831551179503698, - "sarcs": 14.652528429601988, - "adn": 19.700379108345924, - "thf": 14.652528429601986, - "mlthf": 35.57574270245104, - "glyb": 25.3415154392192, - "dmgly": 14.652528429601988, - "ser-L": 14.60259327677706, - "nadp": 14.652528429601988, - "nadph": 17.074218415424728, - "cyst-L": 14.652528429601988 - }, - "log_transfer_constant": { - "METAT": 0.07549016742205726, - "GNMT": 1.3246972471203298, - "CBS": 0.00007813277693021312, - "MTHFR": -0.3918664031389874 - }, - "log_dissociation_constant": { - "met-L MAT3": 0.009116267028546848, - "amet MAT3": 0.012649140053254744, - "amet GNMT1": 3.233378796396484, - "mlthf GNMT1": -0.0720298340010019, - "amet CBS1": 0.0, - "amet MTHFR1": 0.5327826505992013, - "ahcys MTHFR1": -0.0637729300752556 - } -} diff --git a/tests/test_lp_grad.py b/tests/test_lp_grad.py index 71b9284..a7aa861 100644 --- a/tests/test_lp_grad.py +++ b/tests/test_lp_grad.py @@ -1,15 +1,13 @@ import json import jax -import pytest from jax import numpy as jnp +from jax.scipy.stats import norm +from jax.flatten_util import ravel_pytree +from jaxtyping import Array, Scalar from enzax.examples import methionine -from enzax.mcmc import ( - ObservationSet, - AllostericMichaelisMentenPriorSet, - ind_prior_from_truth, - posterior_logdensity_amm, -) +from enzax.kinetic_model import RateEquationModel +from enzax.mcmc import ObservationSet from enzax.steady_state import get_kinetic_model_steady_state import importlib.resources @@ -21,63 +19,86 @@ SEED = 1234 methionine_pldf_grad_file = ( - importlib.resources.files(data) / "methionine_pldf_grad.json" + importlib.resources.files(data) / "expected_methionine_gradient.json" +) + +obs_conc = jnp.array( + [ + 3.99618131e-05, + 1.24186458e-03, + 9.44053469e-04, + 4.72041839e-04, + 2.92625684e-05, + 2.04876101e-07, + 1.37054850e-03, + 9.44053469e-08, + 3.32476221e-06, + 9.53494003e-07, + 2.11467977e-05, + 6.16881926e-06, + 2.97376843e-06, + 1.00785260e-03, + 4.72026734e-05, + 1.49849607e-03, + 1.15174523e-06, + 2.31424323e-04, + 2.11467977e-06, + ], + dtype=jnp.float64, +) +obs_flux = jnp.array( + [ + -0.00425181, + 0.03739644, + 0.01397071, + -0.04154405, + -0.05396867, + 0.01236334, + -0.07089178, + -0.02136595, + 0.00152784, + -0.02482788, + -0.01588131, + ], + dtype=jnp.float64, +) +obs_enzyme = jnp.array( + [ + 0.00097884, + 0.00100336, + 0.00105027, + 0.00099059, + 0.00096148, + 0.00107917, + 0.00104588, + 0.00138744, + 0.00107483, + 0.0009662, + ], + dtype=jnp.float64, ) def test_lp_grad(): - model = methionine structure = methionine.structure - rate_equations = methionine.rate_equations true_parameters = methionine.parameters true_model = methionine.model default_state_guess = jnp.full((5,), 0.01) true_states = get_kinetic_model_steady_state( true_model, default_state_guess ) - prior = AllostericMichaelisMentenPriorSet( - log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1), - log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1), - log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1), - dgf=( - ind_prior_from_truth(true_parameters.dgf, 0.1)[0], - jnp.diag( - jnp.square(ind_prior_from_truth(true_parameters.dgf, 0.1)[1]) - ), - ), - log_km=ind_prior_from_truth(true_parameters.log_km, 0.1), - log_conc_unbalanced=ind_prior_from_truth( - true_parameters.log_conc_unbalanced, 0.1 - ), - temperature=ind_prior_from_truth(true_parameters.temperature, 0.1), - log_ki=ind_prior_from_truth(true_parameters.log_ki, 0.1), - log_transfer_constant=ind_prior_from_truth( - true_parameters.log_transfer_constant, 0.1 - ), - log_dissociation_constant=ind_prior_from_truth( - true_parameters.log_dissociation_constant, 0.1 - ), - ) + flat_true_params, _ = ravel_pytree(methionine.parameters) # get true concentration true_conc = jnp.zeros(methionine.structure.S.shape[0]) - true_conc = true_conc.at[methionine.structure.balanced_species].set( + true_conc = true_conc.at[methionine.structure.balanced_species_ix].set( true_states ) - true_conc = true_conc.at[methionine.structure.unbalanced_species].set( + true_conc = true_conc.at[methionine.structure.unbalanced_species_ix].set( jnp.exp(true_parameters.log_conc_unbalanced) ) - # get true flux - true_flux = true_model.flux(true_states) - # simulate observations error_conc = 0.03 error_flux = 0.05 error_enzyme = 0.03 - key = jax.random.key(SEED) - obs_conc = jnp.exp(jnp.log(true_conc) + jax.random.normal(key) * error_conc) - obs_enzyme = jnp.exp( - true_parameters.log_enzyme + jax.random.normal(key) * error_enzyme - ) - obs_flux = true_flux + jax.random.normal(key) * error_conc obs = ObservationSet( conc=obs_conc, flux=obs_flux, @@ -86,25 +107,56 @@ def test_lp_grad(): flux_scale=error_flux, enzyme_scale=error_enzyme, ) - pldf = functools.partial( - posterior_logdensity_amm, + + def joint_log_density(params, prior_mean, prior_sd, obs): + flat_params, _ = ravel_pytree(params) + model = RateEquationModel(params, methionine.structure) + steady = get_kinetic_model_steady_state(model, default_state_guess) + unbalanced = jnp.exp(params.log_conc_unbalanced) + conc = jnp.zeros(structure.S.shape[0]) + conc = conc.at[structure.balanced_species_ix].set(steady) + conc = conc.at[structure.unbalanced_species_ix].set(unbalanced) + flux = model.flux(steady) + log_prior = norm.pdf(flat_params, prior_mean, prior_sd).sum() + flat_log_enzyme, _ = ravel_pytree(params.log_enzyme) + log_liklihood = jnp.sum( + jnp.array( + [ + norm.logpdf( + jnp.log(obs.conc), jnp.log(conc), obs.conc_scale + ).sum(), + norm.logpdf( + jnp.log(obs.enzyme), flat_log_enzyme, obs.enzyme_scale + ).sum(), + norm.logpdf(obs.flux, flux, obs.flux_scale).sum(), + ] + ) + ) + return log_prior + log_liklihood + + posterior_log_density = functools.partial( + joint_log_density, + prior_mean=flat_true_params, + prior_sd=0.1, obs=obs, - prior=prior, - structure=structure, - rate_equations=rate_equations, - guess=default_state_guess, ) - pldf_grad = jax.jacrev(pldf)(methionine.parameters) - index_pldf_grad = { - p: { - c: float(getattr(pldf_grad, p)[i]) - for i, c in enumerate(model.coords[model.dims[p][0]]) - } - for p in model.dims.keys() - } + gradient = jax.jacrev(posterior_log_density)(methionine.parameters) + _, grad_pytree_def = ravel_pytree(gradient) with open(methionine_pldf_grad_file, "r") as file: - saved_pldf_grad = file.read() - - true_gradient = json.loads(saved_pldf_grad) - for p, vals in true_gradient.items(): - assert true_gradient[p] == pytest.approx(index_pldf_grad[p]) + expected_gradient = grad_pytree_def(jnp.array(json.load(file))) + for k in gradient.__dataclass_fields__.keys(): + obs = getattr(gradient, k) + exp = getattr(expected_gradient, k) + if isinstance(obs, Scalar): + assert jnp.isclose(obs, exp) + elif isinstance(obs, Array): + assert jnp.isclose(obs, exp).all() + elif isinstance(obs, dict): + for kk in obs.keys(): + if isinstance(obs[kk], list): + for o, e in zip(obs[kk], exp[kk]): + assert jnp.isclose(o, e).all() + elif isinstance(obs[kk], Scalar): + assert jnp.isclose(obs[kk], exp[kk]) + elif len(obs[kk]) > 0: + assert jnp.isclose(obs[kk], exp[kk]).all() diff --git a/tests/test_rate_equations.py b/tests/test_rate_equations.py index a511e77..0f414c4 100644 --- a/tests/test_rate_equations.py +++ b/tests/test_rate_equations.py @@ -1,116 +1,107 @@ """Unit tests for rate equations.""" +import equinox as eqx +import numpy as np from jax import numpy as jnp +from jaxtyping import Array, Scalar from enzax.rate_equations import ( AllostericIrreversibleMichaelisMenten, AllostericReversibleMichaelisMenten, IrreversibleMichaelisMenten, ReversibleMichaelisMenten, ) -from enzax.parameters import AllostericMichaelisMentenParameterSet -EXAMPLE_CONC = jnp.array([0.5, 0.2, 0.1, 0.3]) -EXAMPLE_PARAMETERS = AllostericMichaelisMentenParameterSet( - log_kcat=jnp.array([-0.1]), - log_enzyme=jnp.log(jnp.array([0.3])), - dgf=jnp.array([-3, -1.0]), - log_km=jnp.array([0.1, -0.2]), - log_ki=jnp.array([1.0]), - log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.3])), + +class ExampleParameterSet(eqx.Module): + log_substrate_km: dict[str, Array] + log_product_km: dict[str, Array] + log_kcat: dict[str, Scalar] + log_enzyme: dict[str, Array] + log_ki: dict[str, Array] + dgf: Array + temperature: Scalar + log_conc_unbalanced: Array + log_dc_inhibitor: dict[str, Array] + log_dc_activator: dict[str, Array] + log_tc: dict[str, Array] + + +EXAMPLE_S = np.array([[-1], [1], [0]], dtype=np.float64) +EXAMPLE_CONC = jnp.array([0.5, 0.2, 0.1]) +EXAMPLE_PARAMETERS = ExampleParameterSet( + log_substrate_km={"r1": jnp.array([0.1])}, + log_product_km={"r1": jnp.array([-0.2])}, + log_kcat={"r1": jnp.array(-0.1)}, + dgf=jnp.array([-3.0, 1.0]), + log_ki={"r1": jnp.array([1.0])}, temperature=jnp.array(310.0), - log_transfer_constant=jnp.array([-0.2, 0.3]), - log_dissociation_constant=jnp.array([-0.1, 0.2]), - log_drain=jnp.array([]), + log_enzyme={"r1": jnp.log(jnp.array(0.3))}, + log_conc_unbalanced=jnp.array([]), + log_tc={"r1": jnp.array(-0.2)}, + log_dc_activator={"r1": jnp.array([-0.1])}, + log_dc_inhibitor={"r1": jnp.array([0.2])}, ) +EXAMPLE_SPECIES_TO_DGF_IX = np.array([0, 0, 1, 1]) def test_irreversible_michaelis_menten(): expected_rate = 0.08455524 - f = IrreversibleMichaelisMenten( - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), + f = IrreversibleMichaelisMenten() + f_input = f.get_input( + parameters=EXAMPLE_PARAMETERS, + reaction_id="r1", + reaction_stoichiometry=EXAMPLE_S[:, 0], + species_to_dgf_ix=EXAMPLE_SPECIES_TO_DGF_IX, ) - rate = f(EXAMPLE_CONC, EXAMPLE_PARAMETERS) + rate = f(EXAMPLE_CONC, f_input) assert jnp.isclose(rate, expected_rate) def test_reversible_michaelis_menten(): expected_rate = 0.04342889 f = ReversibleMichaelisMenten( - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 0], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - ix_product=jnp.array([1], dtype=jnp.int16), - ix_reactants=jnp.array([0, 1], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), + ix_ki_species=np.array([], dtype=np.int16), + water_stoichiometry=0.0, ) - rate = f(EXAMPLE_CONC, EXAMPLE_PARAMETERS) + f_input = f.get_input( + parameters=EXAMPLE_PARAMETERS, + reaction_id="r1", + reaction_stoichiometry=EXAMPLE_S[:, 0], + species_to_dgf_ix=EXAMPLE_SPECIES_TO_DGF_IX, + ) + rate = f(EXAMPLE_CONC, f_input) assert jnp.isclose(rate, expected_rate) def test_allosteric_irreversible_michaelis_menten(): expected_rate = 0.05608589 f = AllostericIrreversibleMichaelisMenten( - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - tc_ix=0, - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - ix_dc_activation=jnp.array([0], dtype=jnp.int16), - species_activation=jnp.array([2], dtype=jnp.int16), - species_inhibition=jnp.array([], dtype=jnp.int16), + ix_ki_species=np.array([], dtype=np.int16), + ix_allosteric_activators=np.array([2], dtype=np.int16), subunits=1, ) - rate = f(EXAMPLE_CONC, EXAMPLE_PARAMETERS) + f_input = f.get_input( + parameters=EXAMPLE_PARAMETERS, + reaction_id="r1", + reaction_stoichiometry=EXAMPLE_S[:, 0], + species_to_dgf_ix=EXAMPLE_SPECIES_TO_DGF_IX, + ) + rate = f(EXAMPLE_CONC, f_input) assert jnp.isclose(rate, expected_rate) def test_allosteric_reversible_michaelis_menten(): expected_rate = 0.03027414 f = AllostericReversibleMichaelisMenten( - kcat_ix=0, - enzyme_ix=0, - km_ix=jnp.array([0, 1], dtype=jnp.int16), - ki_ix=jnp.array([], dtype=jnp.int16), - reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16), - reactant_to_dgf=jnp.array([0, 0], dtype=jnp.int16), - ix_ki_species=jnp.array([], dtype=jnp.int16), - substrate_km_positions=jnp.array([0], dtype=jnp.int16), - substrate_reactant_positions=jnp.array([0], dtype=jnp.int16), - ix_substrate=jnp.array([0], dtype=jnp.int16), - ix_product=jnp.array([1], dtype=jnp.int16), - ix_reactants=jnp.array([0, 1], dtype=jnp.int16), - product_reactant_positions=jnp.array([1], dtype=jnp.int16), - product_km_positions=jnp.array([1], dtype=jnp.int16), - water_stoichiometry=jnp.array(0.0), - tc_ix=0, - ix_dc_inhibition=jnp.array([], dtype=jnp.int16), - ix_dc_activation=jnp.array([0], dtype=jnp.int16), - species_activation=jnp.array([2], dtype=jnp.int16), - species_inhibition=jnp.array([], dtype=jnp.int16), + ix_ki_species=np.array([], dtype=np.int16), + ix_allosteric_activators=np.array([2], dtype=np.int16), subunits=1, ) - rate = f(EXAMPLE_CONC, EXAMPLE_PARAMETERS) + f_input = f.get_input( + parameters=EXAMPLE_PARAMETERS, + reaction_id="r1", + reaction_stoichiometry=EXAMPLE_S[:, 0], + species_to_dgf_ix=EXAMPLE_SPECIES_TO_DGF_IX, + ) + rate = f(EXAMPLE_CONC, f_input) assert jnp.isclose(rate, expected_rate)