Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve model definition ui #29

Merged
merged 28 commits into from
Dec 13, 2024
Merged

Improve model definition ui #29

merged 28 commits into from
Dec 13, 2024

Conversation

teddygroves
Copy link
Contributor

@teddygroves teddygroves commented Nov 27, 2024

This change makes it easier to define a kinetic model. Now you can create one from scratch like this (believe it or not, this is a big improvement!):

import equinox as eqx

from jax import numpy as jnp
import numpy as np

from enzax.kinetic_model import (
    KineticModelStructure,
    RateEquationModel,
)
from enzax.rate_equations import (
    AllostericReversibleMichaelisMenten,
    ReversibleMichaelisMenten,
)

S = np.array([[-1, 0, 0], [1, -1, 0], [0, 1, -1], [0, 0, 1]])
reactions = ["r1", "r2", "r3"]
species = ["m1e", "m1c", "m2c", "m2e"]
balanced_species = ["m1c", "m2c"]
rate_equations = [
    AllostericReversibleMichaelisMenten(
        ix_allosteric_activators=np.array([2]), subunits=1
    ),
    AllostericReversibleMichaelisMenten(
        ix_allosteric_inhibitors=np.array([1]), ix_ki_species=np.array([1])
    ),
    ReversibleMichaelisMenten(water_stoichiometry=0.0),
]
structure = KineticModelStructure(
    S=S,
    species=species,
    balanced_species=balanced_species,
    rate_equations=rate_equations,
)

class ParameterDefinition(eqx.Module):
    log_substrate_km: dict[int, Array]
    log_product_km: dict[int, Array]
    log_kcat: dict[int, Scalar]
    log_enzyme: dict[int, Array]
    log_ki: dict[int, Array]
    dgf: Array
    temperature: Scalar
    log_conc_unbalanced: Array
    log_dc_inhibitor: dict[int, Array]
    log_dc_activator: dict[int, Array]
    log_tc: dict[int, Array]

parameters = ParameterDefinition(
    log_substrate_km={
        0: jnp.array([0.1]),
        1: jnp.array([0.5]),
        2: jnp.array([-1.0]),
    },
    log_product_km={
        0: jnp.array([-0.2]),
        1: jnp.array([0.0]),
        2: jnp.array([0.5]),
    },
    log_kcat={0: jnp.array(-0.1), 1: jnp.array(0.0), 2: jnp.array(0.1)},
    dgf=jnp.array([-3.0, -1.0]),
    log_ki={0: jnp.array([]), 1: jnp.array([1.0]), 2: jnp.array([])},
    temperature=jnp.array(310.0),
    log_enzyme={
        0: jnp.log(jnp.array(0.3)),
        1: jnp.log(jnp.array(0.2)),
        2: jnp.log(jnp.array(0.1)),
    },
    log_conc_unbalanced=jnp.log(jnp.array([0.5, 0.1])),
    log_tc={0: jnp.array(-0.2), 1: jnp.array(0.3)},
    log_dc_activator={0: jnp.array([-0.1]), 1: jnp.array([])},
    log_dc_inhibitor={0: jnp.array([]), 1: jnp.array([0.2])},
)

model = RateEquationModel(parameters, structure)

There is still a lot of room to make things easier. In particular it should be possible to create the parameter definition automatically from the structure. There should also be helpful validation of the parameters.

Checklist:

  • tests pass
  • README.md up to date
  • docs up to date
  • link to any relevant issues
  • steady state demo works
  • mcmc demo works

@teddygroves
Copy link
Contributor Author

Now the example model tests pass!

@teddygroves
Copy link
Contributor Author

Now all the tests pass!

@teddygroves teddygroves changed the title WIP better ui Improve model definition ui Nov 29, 2024
src/enzax/kinetic_model.py Outdated Show resolved Hide resolved
src/enzax/kinetic_model.py Show resolved Hide resolved
@teddygroves
Copy link
Contributor Author

Oops, thanks @AlberteSloth!

"r3": jnp.array(0.1),
},
dgf=jnp.array([-3.0, -1.0]),
log_ki={"r1": jnp.array([]), "r2": jnp.array([1.0]), "r3": jnp.array([])},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you still need to specify the index of the balanced metabolites for the ui?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently it just has to agree with the order of balanced_species

Copy link
Contributor

@NicholasCowie NicholasCowie left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@teddygroves teddygroves merged commit 1669e64 into main Dec 13, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants