Skip to content

Commit

Permalink
Merge pull request #29 from dtu-qmcm/usable_ui
Browse files Browse the repository at this point in the history
Improve model definition ui
  • Loading branch information
teddygroves authored Dec 13, 2024
2 parents 823b4f0 + 0214412 commit 1669e64
Show file tree
Hide file tree
Showing 22 changed files with 1,202 additions and 1,348 deletions.
16 changes: 10 additions & 6 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,25 @@ jobs:
- name: checkout code
uses: actions/checkout@v2

- name: Install uv
uses: astral-sh/setup-uv@v4
with:
# Install a specific version of uv.
version: "0.5.5"

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install pdm
run: pip install pdm
- name: Install the project
run: uv sync --all-extras --dev

- name: pre-commit checks
uses: pre-commit/[email protected]

- name: Run tests
run: |
pdm install --dev
pdm run pytest tests --cov=src/enzax
run: uv run pytest tests

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ model = methionine.model

def get_steady_state_from_params(parameters: PyTree):
"""Get the steady state with a one-argument non-pure function."""
_model = RateEquationModel(
parameters, model.structure, model.rate_equations
)
_model = RateEquationModel(parameters, model.structure)
return get_kinetic_model_steady_state(_model, guess)

jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters)
Expand Down
164 changes: 68 additions & 96 deletions docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@ Enzax provides building blocks for you to construct a wide range of differentiab

Here we write a model describing a simple linear pathway with two state variables, two boundary species and three reactions.

First we import some enzax classes:
First we import some enzax classes, as well as [equinox](https://github.com/patrick-kidger/equinox) and both JAX and standard versions of numpy:

```python
import equinox as eqx

from jax import numpy as jnp
import numpy as np

from enzax.kinetic_model import (
KineticModelStructure,
RateEquationModel,
Expand All @@ -32,113 +37,82 @@ from enzax.rate_equations import (

```

Next we specify our model's structure by providing a stoichiometric matrix and saying which of its rows represent state variables (aka "balanced species") and which represent boundary or "unbalanced" species:
Next we specify our model's structure by providing a stoichiometric matrix and saying which of its rows represent state variables (aka "balanced species") and which reactions have which rate equations.

```python
structure = KineticModelStructure(
S=jnp.array(
[[-1, 0, 0], [1, -1, 0], [0, 1, -1], [0, 0, 1]], dtype=jnp.float64
stoichiometry = {
"r1": {"m1e": -1, "m1c": 1},
"r2": {"m1c": -1, "m2c": 1},
"r3": {"m2c": -1, "m2e": 1},
}
reactions = ["r1", "r2", "r3"]
species = ["m1e", "m1c", "m2c", "m2e"]
balanced_species = ["m1c", "m2c"]
rate_equations = [
AllostericReversibleMichaelisMenten(
ix_allosteric_activators=np.array([2]), subunits=1
),
AllostericReversibleMichaelisMenten(
ix_allosteric_inhibitors=np.array([1]), ix_ki_species=np.array([1])
),
balanced_species=jnp.array([1, 2]),
unbalanced_species=jnp.array([0, 3]),
ReversibleMichaelisMenten(water_stoichiometry=0.0),
]
structure = KineticModelStructure(
stoichiometry=stoichiometry,
species=species,
balanced_species=balanced_species,
rate_equations=rate_equations,
)
```

Next we provide some kinetic parameter values:
Next we define what a set of kinetic parameters looks like for our problem, and provide a set of parameters matching this definition:

```python
from enzax.parameters import AllostericMichaelisMentenParameters

parameters = AllostericMichaelisMentenParameters(
log_kcat=jnp.array([-0.1, 0.0, 0.1]),
log_enzyme=jnp.log(jnp.array([0.3, 0.2, 0.1])),
dgf=jnp.array([-3, -1.0]),
log_km=jnp.array([0.1, -0.2, 0.5, 0.0, -1.0, 0.5]),
log_ki=jnp.array([1.0]),
log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])),
class ParameterDefinition(eqx.Module):
log_substrate_km: dict[int, Array]
log_product_km: dict[int, Array]
log_kcat: dict[int, Scalar]
log_enzyme: dict[int, Array]
log_ki: dict[int, Array]
dgf: Array
temperature: Scalar
log_conc_unbalanced: Array
log_dc_inhibitor: dict[int, Array]
log_dc_activator: dict[int, Array]
log_tc: dict[int, Array]

parameters = ParameterDefinition(
log_substrate_km={
"r1": jnp.array([0.1]),
"r2": jnp.array([0.5]),
"r3": jnp.array([-1.0]),
},
log_product_km={
"r1": jnp.array([-0.2]),
"r2": jnp.array([0.0]),
"r3": jnp.array([0.5]),
},
log_kcat={"r1": jnp.array(-0.1), "r2": jnp.array(0.0), "r3": jnp.array(0.1)},
dgf=jnp.array([-3.0, -1.0]),
log_ki={"r1": jnp.array([]), "r2": jnp.array([1.0]), "r3": jnp.array([])},
temperature=jnp.array(310.0),
log_transfer_constant=jnp.array([-0.2, 0.3]),
log_dissociation_constant=jnp.array([-0.1, 0.2]),
log_drain=jnp.array([]),
)
```
Now we can use enzax's rate laws to specify how each reaction behaves:

```python
from enzax.rate_equations import (
AllostericReversibleMichaelisMenten,
ReversibleMichaelisMenten,
)

r0 = AllostericReversibleMichaelisMenten(
kcat_ix=0,
enzyme_ix=0,
km_ix=jnp.array([0, 1], dtype=jnp.int16),
ki_ix=jnp.array([], dtype=jnp.int16),
reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16),
reactant_to_dgf=jnp.array([0, 0], dtype=jnp.int16),
ix_ki_species=jnp.array([], dtype=jnp.int16),
substrate_km_positions=jnp.array([0], dtype=jnp.int16),
substrate_reactant_positions=jnp.array([0], dtype=jnp.int16),
ix_substrate=jnp.array([0], dtype=jnp.int16),
ix_product=jnp.array([1], dtype=jnp.int16),
ix_reactants=jnp.array([0, 1], dtype=jnp.int16),
product_reactant_positions=jnp.array([1], dtype=jnp.int16),
product_km_positions=jnp.array([1], dtype=jnp.int16),
water_stoichiometry=jnp.array(0.0),
tc_ix=0,
ix_dc_inhibition=jnp.array([], dtype=jnp.int16),
ix_dc_activation=jnp.array([0], dtype=jnp.int16),
species_activation=jnp.array([2], dtype=jnp.int16),
species_inhibition=jnp.array([], dtype=jnp.int16),
subunits=1,
)
r1 = AllostericReversibleMichaelisMenten(
kcat_ix=1,
enzyme_ix=1,
km_ix=jnp.array([2, 3], dtype=jnp.int16),
ki_ix=jnp.array([0]),
reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16),
reactant_to_dgf=jnp.array([0, 1], dtype=jnp.int16),
ix_ki_species=jnp.array([1]),
substrate_km_positions=jnp.array([0], dtype=jnp.int16),
substrate_reactant_positions=jnp.array([0], dtype=jnp.int16),
ix_substrate=jnp.array([1], dtype=jnp.int16),
ix_product=jnp.array([2], dtype=jnp.int16),
ix_reactants=jnp.array([1, 2], dtype=jnp.int16),
product_reactant_positions=jnp.array([1], dtype=jnp.int16),
product_km_positions=jnp.array([1], dtype=jnp.int16),
water_stoichiometry=jnp.array(0.0),
tc_ix=1,
ix_dc_inhibition=jnp.array([1], dtype=jnp.int16),
ix_dc_activation=jnp.array([], dtype=jnp.int16),
species_activation=jnp.array([], dtype=jnp.int16),
species_inhibition=jnp.array([1], dtype=jnp.int16),
subunits=1,
)
r2 = ReversibleMichaelisMenten(
kcat_ix=2,
enzyme_ix=2,
km_ix=jnp.array([4, 5], dtype=jnp.int16),
ki_ix=jnp.array([], dtype=jnp.int16),
ix_substrate=jnp.array([2], dtype=jnp.int16),
ix_product=jnp.array([3], dtype=jnp.int16),
ix_reactants=jnp.array([2, 3], dtype=jnp.int16),
reactant_to_dgf=jnp.array([1, 1], dtype=jnp.int16),
reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16),
ix_ki_species=jnp.array([], dtype=jnp.int16),
substrate_km_positions=jnp.array([0], dtype=jnp.int16),
substrate_reactant_positions=jnp.array([0], dtype=jnp.int16),
product_reactant_positions=jnp.array([1], dtype=jnp.int16),
product_km_positions=jnp.array([1], dtype=jnp.int16),
water_stoichiometry=jnp.array(0.0),
log_enzyme={
"r1": jnp.log(jnp.array(0.3)),
"r2": jnp.log(jnp.array(0.2)),
"r3": jnp.log(jnp.array(0.1)),
},
log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])),
log_tc={"r1": jnp.array(-0.2), "r2": jnp.array(0.3)},
log_dc_activator={"r1": jnp.array([-0.1]), "r2": jnp.array([])},
log_dc_inhibitor={"r1": jnp.array([]), "r2": jnp.array([0.2])},
)
```
Note that the parameters use `jnp` whereas the structure uses `np`. This is because we want JAX to trace the parameters, whereas the structure should be static. Read more about this [here](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#static-vs-traced-operations).

Now we can declare our model:

```python
model = RateEquationModel(structure, parameters, [r0, r1, r2])
model = RateEquationModel(structure, parameters)
```

To test out the model, we can see if it returns some fluxes and state variable rates when provided a set of balanced species concentrations:
Expand Down Expand Up @@ -181,9 +155,7 @@ model = methionine.model

def get_steady_state_from_params(parameters: PyTree):
"""Get the steady state with a one-argument non-pure function."""
_model = RateEquationModel(
parameters, model.structure, model.rate_equations
)
_model = RateEquationModel(parameters, model.structure)
return get_kinetic_model_steady_state(_model, guess)

jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters)
Expand Down
15 changes: 9 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ dependencies = [
"diffrax>=0.6.0",
"jaxtyping>=0.2.31",
"arviz>=0.19.0",
"jax>=0.4.35",
"equinox>=0.11.9",
"python-libsbml>=5.20.4",
"sympy2jax>=0.0.5",
"sbmlmath>=0.2.0",
Expand All @@ -33,12 +35,6 @@ build-backend = "hatchling.build"
[tool.pdm]
distribution = true

[tool.pdm.dev-dependencies]
dev = [
"pytest>=8.3.2",
"pytest-cov>=5.0.0",
"pre-commit>=3.8.0",
]
[tool.ruff]
line-length = 80

Expand All @@ -48,3 +44,10 @@ extend-select = ["E501"] # line length is checked

[tool.ruff.lint.isort]
known-first-party = ["enzax"]

[dependency-groups]
dev = [
"pytest>=8.3.3",
"pytest-cov>=5.0.0",
"pre-commit>=3.8.0",
]
Loading

0 comments on commit 1669e64

Please sign in to comment.