Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve model definition ui #29

Merged
merged 28 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
c8ddebd
WIP better ui
teddygroves Nov 27, 2024
c14a469
allostery, tests (not passing!)
teddygroves Nov 27, 2024
90e79ab
tidying
teddygroves Nov 27, 2024
259961e
Fix missing jnp.log in test
teddygroves Nov 27, 2024
33dbc3d
More fixes - now the rate equations test passes!
teddygroves Nov 27, 2024
952a909
Updating linear example
teddygroves Nov 28, 2024
3f3937e
Fix incorrect dgf in linear model
teddygroves Nov 28, 2024
2334e1c
trying to get methionine to work
teddygroves Nov 28, 2024
34cc521
Fix incorrect kms in methionine model
teddygroves Nov 28, 2024
70395e7
flatten species indexes
teddygroves Nov 28, 2024
31a0cdc
Fix methionine kis
teddygroves Nov 28, 2024
a201922
Update gradient test
teddygroves Nov 28, 2024
cccb879
fix gradient test
teddygroves Nov 28, 2024
b9736e6
delete unused files, tidy grad test
teddygroves Nov 28, 2024
97e5e7b
Update gh action to use uv
teddygroves Nov 28, 2024
6b8b627
Update grad test to be approximate
teddygroves Nov 28, 2024
0da0dd3
get mcmc demo to work
teddygroves Nov 29, 2024
a0e8c46
get steady state demo to work
teddygroves Nov 29, 2024
c3dffff
update readme
teddygroves Nov 29, 2024
12b4dac
Merge branch 'main' into usable_ui
teddygroves Nov 29, 2024
841ae08
update getting started docs
teddygroves Nov 29, 2024
296befa
Fix typo in python example
teddygroves Nov 29, 2024
62884a1
define rate equation models using reaction ids
teddygroves Dec 11, 2024
fc15ee6
Fix sbml model errors
teddygroves Dec 11, 2024
01cd8ce
Remove use of deprecated KeyArray
teddygroves Dec 11, 2024
e406db3
Update KineticModelSbml
AlberteSloth Dec 11, 2024
5495772
Updated sbml_demo
AlberteSloth Dec 13, 2024
0214412
fix pre-commit
teddygroves Dec 13, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,25 @@ jobs:
- name: checkout code
uses: actions/checkout@v2

- name: Install uv
uses: astral-sh/setup-uv@v4
with:
# Install a specific version of uv.
version: "0.5.5"

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install pdm
run: pip install pdm
- name: Install the project
run: uv sync --all-extras --dev

- name: pre-commit checks
uses: pre-commit/[email protected]

- name: Run tests
run: |
pdm install --dev
pdm run pytest tests --cov=src/enzax
run: uv run pytest tests

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ model = methionine.model

def get_steady_state_from_params(parameters: PyTree):
"""Get the steady state with a one-argument non-pure function."""
_model = RateEquationModel(
parameters, model.structure, model.rate_equations
)
_model = RateEquationModel(parameters, model.structure)
return get_kinetic_model_steady_state(_model, guess)

jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters)
Expand Down
164 changes: 68 additions & 96 deletions docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@ Enzax provides building blocks for you to construct a wide range of differentiab

Here we write a model describing a simple linear pathway with two state variables, two boundary species and three reactions.

First we import some enzax classes:
First we import some enzax classes, as well as [equinox](https://github.com/patrick-kidger/equinox) and both JAX and standard versions of numpy:

```python
import equinox as eqx

from jax import numpy as jnp
import numpy as np

from enzax.kinetic_model import (
KineticModelStructure,
RateEquationModel,
Expand All @@ -32,113 +37,82 @@ from enzax.rate_equations import (

```

Next we specify our model's structure by providing a stoichiometric matrix and saying which of its rows represent state variables (aka "balanced species") and which represent boundary or "unbalanced" species:
Next we specify our model's structure by providing a stoichiometric matrix and saying which of its rows represent state variables (aka "balanced species") and which reactions have which rate equations.

```python
structure = KineticModelStructure(
S=jnp.array(
[[-1, 0, 0], [1, -1, 0], [0, 1, -1], [0, 0, 1]], dtype=jnp.float64
stoichiometry = {
"r1": {"m1e": -1, "m1c": 1},
"r2": {"m1c": -1, "m2c": 1},
"r3": {"m2c": -1, "m2e": 1},
}
reactions = ["r1", "r2", "r3"]
species = ["m1e", "m1c", "m2c", "m2e"]
balanced_species = ["m1c", "m2c"]
rate_equations = [
AllostericReversibleMichaelisMenten(
ix_allosteric_activators=np.array([2]), subunits=1
),
AllostericReversibleMichaelisMenten(
ix_allosteric_inhibitors=np.array([1]), ix_ki_species=np.array([1])
),
balanced_species=jnp.array([1, 2]),
unbalanced_species=jnp.array([0, 3]),
ReversibleMichaelisMenten(water_stoichiometry=0.0),
]
structure = KineticModelStructure(
stoichiometry=stoichiometry,
species=species,
balanced_species=balanced_species,
rate_equations=rate_equations,
)
```

Next we provide some kinetic parameter values:
Next we define what a set of kinetic parameters looks like for our problem, and provide a set of parameters matching this definition:

```python
from enzax.parameters import AllostericMichaelisMentenParameters

parameters = AllostericMichaelisMentenParameters(
log_kcat=jnp.array([-0.1, 0.0, 0.1]),
log_enzyme=jnp.log(jnp.array([0.3, 0.2, 0.1])),
dgf=jnp.array([-3, -1.0]),
log_km=jnp.array([0.1, -0.2, 0.5, 0.0, -1.0, 0.5]),
log_ki=jnp.array([1.0]),
log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])),
class ParameterDefinition(eqx.Module):
log_substrate_km: dict[int, Array]
log_product_km: dict[int, Array]
log_kcat: dict[int, Scalar]
log_enzyme: dict[int, Array]
log_ki: dict[int, Array]
dgf: Array
temperature: Scalar
log_conc_unbalanced: Array
log_dc_inhibitor: dict[int, Array]
log_dc_activator: dict[int, Array]
log_tc: dict[int, Array]

parameters = ParameterDefinition(
log_substrate_km={
"r1": jnp.array([0.1]),
"r2": jnp.array([0.5]),
"r3": jnp.array([-1.0]),
},
log_product_km={
"r1": jnp.array([-0.2]),
"r2": jnp.array([0.0]),
"r3": jnp.array([0.5]),
},
log_kcat={"r1": jnp.array(-0.1), "r2": jnp.array(0.0), "r3": jnp.array(0.1)},
dgf=jnp.array([-3.0, -1.0]),
log_ki={"r1": jnp.array([]), "r2": jnp.array([1.0]), "r3": jnp.array([])},
temperature=jnp.array(310.0),
log_transfer_constant=jnp.array([-0.2, 0.3]),
log_dissociation_constant=jnp.array([-0.1, 0.2]),
log_drain=jnp.array([]),
)
```
Now we can use enzax's rate laws to specify how each reaction behaves:

```python
from enzax.rate_equations import (
AllostericReversibleMichaelisMenten,
ReversibleMichaelisMenten,
)

r0 = AllostericReversibleMichaelisMenten(
kcat_ix=0,
enzyme_ix=0,
km_ix=jnp.array([0, 1], dtype=jnp.int16),
ki_ix=jnp.array([], dtype=jnp.int16),
reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16),
reactant_to_dgf=jnp.array([0, 0], dtype=jnp.int16),
ix_ki_species=jnp.array([], dtype=jnp.int16),
substrate_km_positions=jnp.array([0], dtype=jnp.int16),
substrate_reactant_positions=jnp.array([0], dtype=jnp.int16),
ix_substrate=jnp.array([0], dtype=jnp.int16),
ix_product=jnp.array([1], dtype=jnp.int16),
ix_reactants=jnp.array([0, 1], dtype=jnp.int16),
product_reactant_positions=jnp.array([1], dtype=jnp.int16),
product_km_positions=jnp.array([1], dtype=jnp.int16),
water_stoichiometry=jnp.array(0.0),
tc_ix=0,
ix_dc_inhibition=jnp.array([], dtype=jnp.int16),
ix_dc_activation=jnp.array([0], dtype=jnp.int16),
species_activation=jnp.array([2], dtype=jnp.int16),
species_inhibition=jnp.array([], dtype=jnp.int16),
subunits=1,
)
r1 = AllostericReversibleMichaelisMenten(
kcat_ix=1,
enzyme_ix=1,
km_ix=jnp.array([2, 3], dtype=jnp.int16),
ki_ix=jnp.array([0]),
reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16),
reactant_to_dgf=jnp.array([0, 1], dtype=jnp.int16),
ix_ki_species=jnp.array([1]),
substrate_km_positions=jnp.array([0], dtype=jnp.int16),
substrate_reactant_positions=jnp.array([0], dtype=jnp.int16),
ix_substrate=jnp.array([1], dtype=jnp.int16),
ix_product=jnp.array([2], dtype=jnp.int16),
ix_reactants=jnp.array([1, 2], dtype=jnp.int16),
product_reactant_positions=jnp.array([1], dtype=jnp.int16),
product_km_positions=jnp.array([1], dtype=jnp.int16),
water_stoichiometry=jnp.array(0.0),
tc_ix=1,
ix_dc_inhibition=jnp.array([1], dtype=jnp.int16),
ix_dc_activation=jnp.array([], dtype=jnp.int16),
species_activation=jnp.array([], dtype=jnp.int16),
species_inhibition=jnp.array([1], dtype=jnp.int16),
subunits=1,
)
r2 = ReversibleMichaelisMenten(
kcat_ix=2,
enzyme_ix=2,
km_ix=jnp.array([4, 5], dtype=jnp.int16),
ki_ix=jnp.array([], dtype=jnp.int16),
ix_substrate=jnp.array([2], dtype=jnp.int16),
ix_product=jnp.array([3], dtype=jnp.int16),
ix_reactants=jnp.array([2, 3], dtype=jnp.int16),
reactant_to_dgf=jnp.array([1, 1], dtype=jnp.int16),
reactant_stoichiometry=jnp.array([-1, 1], dtype=jnp.int16),
ix_ki_species=jnp.array([], dtype=jnp.int16),
substrate_km_positions=jnp.array([0], dtype=jnp.int16),
substrate_reactant_positions=jnp.array([0], dtype=jnp.int16),
product_reactant_positions=jnp.array([1], dtype=jnp.int16),
product_km_positions=jnp.array([1], dtype=jnp.int16),
water_stoichiometry=jnp.array(0.0),
log_enzyme={
"r1": jnp.log(jnp.array(0.3)),
"r2": jnp.log(jnp.array(0.2)),
"r3": jnp.log(jnp.array(0.1)),
},
log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])),
log_tc={"r1": jnp.array(-0.2), "r2": jnp.array(0.3)},
log_dc_activator={"r1": jnp.array([-0.1]), "r2": jnp.array([])},
log_dc_inhibitor={"r1": jnp.array([]), "r2": jnp.array([0.2])},
)
```
Note that the parameters use `jnp` whereas the structure uses `np`. This is because we want JAX to trace the parameters, whereas the structure should be static. Read more about this [here](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#static-vs-traced-operations).

Now we can declare our model:

```python
model = RateEquationModel(structure, parameters, [r0, r1, r2])
model = RateEquationModel(structure, parameters)
```

To test out the model, we can see if it returns some fluxes and state variable rates when provided a set of balanced species concentrations:
Expand Down Expand Up @@ -181,9 +155,7 @@ model = methionine.model

def get_steady_state_from_params(parameters: PyTree):
"""Get the steady state with a one-argument non-pure function."""
_model = RateEquationModel(
parameters, model.structure, model.rate_equations
)
_model = RateEquationModel(parameters, model.structure)
return get_kinetic_model_steady_state(_model, guess)

jacobian = jax.jacrev(get_steady_state_from_params)(model.parameters)
Expand Down
15 changes: 9 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ dependencies = [
"diffrax>=0.6.0",
"jaxtyping>=0.2.31",
"arviz>=0.19.0",
"jax>=0.4.35",
"equinox>=0.11.9",
"python-libsbml>=5.20.4",
"sympy2jax>=0.0.5",
"sbmlmath>=0.2.0",
Expand All @@ -33,12 +35,6 @@ build-backend = "hatchling.build"
[tool.pdm]
distribution = true

[tool.pdm.dev-dependencies]
dev = [
"pytest>=8.3.2",
"pytest-cov>=5.0.0",
"pre-commit>=3.8.0",
]
[tool.ruff]
line-length = 80

Expand All @@ -48,3 +44,10 @@ extend-select = ["E501"] # line length is checked

[tool.ruff.lint.isort]
known-first-party = ["enzax"]

[dependency-groups]
dev = [
"pytest>=8.3.3",
"pytest-cov>=5.0.0",
"pre-commit>=3.8.0",
]
Loading
Loading