Skip to content

Commit

Permalink
WIP better ui
Browse files Browse the repository at this point in the history
  • Loading branch information
teddygroves committed Nov 27, 2024
1 parent b927ced commit c8ddebd
Show file tree
Hide file tree
Showing 5 changed files with 389 additions and 260 deletions.
69 changes: 60 additions & 9 deletions src/enzax/kinetic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,59 @@ class KineticModelStructure(eqx.Module):
"""Structural information about a kinetic model."""

S: Float[Array, " s r"]
balanced_species: Int[Array, " n_balanced"]
unbalanced_species: Int[Array, " n_unbalanced"]
species: list[str]
reactions: list[str]
balanced_species: list[str]
species_to_dgf_ix: Int[Array, " s"]
balanced_species_ix: Int[Array, " b"]
unbalanced_species_ix: Int[Array, " u"]

def __init__(
self,
S,
species,
reactions,
balanced_species,
species_to_dgf_ix,
):
self.S = S
self.species = species
self.reactions = reactions
self.balanced_species = balanced_species
self.species_to_dgf_ix = species_to_dgf_ix
self.balanced_species_ix = jnp.array(
[i for i, s in enumerate(species) if s in balanced_species],
dtype=jnp.int16,
)
self.unbalanced_species_ix = jnp.array(
[i for i, s in enumerate(species) if s not in balanced_species],
dtype=jnp.int16,
)


class RateEquationKineticModelStructure(KineticModelStructure):
rate_equations: list[RateEquation]

def __init__(
self,
S,
species,
reactions,
balanced_species,
species_to_dgf_ix,
rate_equations,
):
super().__init__(
S, species, reactions, balanced_species, species_to_dgf_ix
)
self.rate_equations = rate_equations


class KineticModel(eqx.Module, ABC):
"""Abstract base class for kinetic models."""

parameters: PyTree
structure: KineticModelStructure
structure: KineticModelStructure = eqx.field(static=True)

@abstractmethod
def flux(
Expand All @@ -44,7 +88,7 @@ def dcdt(
""" # Noqa: E501
sv = self.structure.S @ self.flux(conc)
return sv[self.structure.balanced_species]
return sv[self.structure.balanced_species_ix]


class RateEquationModel(KineticModel):
Expand All @@ -64,10 +108,17 @@ def flux(
""" # Noqa: E501
conc = jnp.zeros(self.structure.S.shape[0])
conc = conc.at[self.structure.balanced_species].set(conc_balanced)
conc = conc.at[self.structure.unbalanced_species].set(
conc = conc.at[self.structure.balanced_species_ix].set(conc_balanced)
conc = conc.at[self.structure.unbalanced_species_ix].set(
jnp.exp(self.parameters.log_conc_unbalanced)
)
t = [f(conc, self.parameters) for f in self.rate_equations]
out = jnp.array(t)
return out
flux_list = []
for i, rate_equation in enumerate(self.structure.rate_equations):
ipt = rate_equation.get_input(
self.parameters,
i,
self.structure.S,
self.structure.species_to_dgf_ix,
)
flux_list.append(rate_equation(conc, ipt))
return jnp.array(flux_list)
20 changes: 16 additions & 4 deletions src/enzax/rate_equation.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,31 @@
"""Module containing rate equations for enzyme-catalysed reactions."""

from abc import ABC, abstractmethod
from equinox import Module

import numpy as np
from equinox import Module
from jaxtyping import Array, Float, PyTree, Scalar

from numpy.typing import NDArray

ConcArray = Float[Array, " n"]


class RateEquation(Module, ABC):
"""Abstract definition of a rate equation.
A rate equation is an equinox [Module](https://docs.kidger.site/equinox/api/module/module/) with a `__call__` method that takes in a 1 dimensional array of concentrations and an arbitrary PyTree of parameters, returning a scalar value representing a single flux.
A rate equation is an equinox [Module](https://docs.kidger.site/equinox/api/module/module/) with a `__call__` method that takes in a 1 dimensional array of concentrations and an arbitrary PyTree of other inputs, returning a scalar value representing a single flux.
""" # Noqa: E501

@abstractmethod
def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar: ...
def get_input(
self,
parameters: PyTree,
rxn_ix: int,
S: NDArray[np.float64],
species_to_dgf_ix: NDArray[np.int16],
) -> PyTree: ...

@abstractmethod
def __call__(
self, conc: ConcArray, rate_equation_input: PyTree
) -> Scalar: ...
15 changes: 7 additions & 8 deletions src/enzax/rate_equations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from enzax.rate_equations.michaelis_menten import (
ReversibleMichaelisMenten,
IrreversibleMichaelisMenten,
MichaelisMenten,
)
from enzax.rate_equations.generalised_mwc import (
AllostericReversibleMichaelisMenten,
AllostericIrreversibleMichaelisMenten,
)

# from enzax.rate_equations.generalised_mwc import (
# AllostericReversibleMichaelisMenten,
# AllostericIrreversibleMichaelisMenten,
# )
from enzax.rate_equations.drain import Drain

AVAILABLE_RATE_EQUATIONS = [
ReversibleMichaelisMenten,
IrreversibleMichaelisMenten,
MichaelisMenten,
AllostericReversibleMichaelisMenten,
AllostericIrreversibleMichaelisMenten,
# AllostericReversibleMichaelisMenten,
# AllostericIrreversibleMichaelisMenten,
Drain,
]
187 changes: 93 additions & 94 deletions src/enzax/rate_equations/generalised_mwc.py
Original file line number Diff line number Diff line change
@@ -1,94 +1,93 @@
from jax import numpy as jnp
from jaxtyping import Array, Float, Int, PyTree, Scalar

from enzax.rate_equation import ConcArray
from enzax.rate_equations.michaelis_menten import (
IrreversibleMichaelisMenten,
MichaelisMenten,
ReversibleMichaelisMenten,
)


def generalised_mwc_effect(
conc_inhibitor: Float[Array, " n_inhibition"],
dc_inhibitor: Float[Array, " n_inhibition"],
conc_activator: Float[Array, " n_activation"],
dc_activator: Float[Array, " n_activation"],
free_enzyme_ratio: Scalar,
tc: Scalar,
subunits: int,
) -> Scalar:
"""Get the allosteric effect on a rate.
The equation is generalised Monod Wyman Changeux model as presented in Popova and Sel'kov 1975: https://doi.org/10.1016/0014-5793(75)80034-2.
""" # noqa: E501
qnum = 1 + jnp.sum(conc_inhibitor / dc_inhibitor)
qdenom = 1 + jnp.sum(conc_activator / dc_activator)
out = 1.0 / (1 + tc * (free_enzyme_ratio * qnum / qdenom) ** subunits)
return out


class GeneralisedMWC(MichaelisMenten):
"""Mixin class for allosteric rate laws, assuming generalised MWC kinetics.
See Popova and Sel'kov 1975 for the rate law: https://doi.org/10.1016/0014-5793(75)80034-2.
Note that it is assumed there is a free_enzyme_ratio method available - that is why this is a subclass of MichaelisMenten rather than RateEquation.
""" # noqa: E501

subunits: int
tc_ix: int
ix_dc_activation: Int[Array, " n_activation"]
ix_dc_inhibition: Int[Array, " n_inhibition"]
species_activation: Int[Array, " n_activation"]
species_inhibition: Int[Array, " n_inhibition"]

def get_tc(self, parameters: PyTree) -> Scalar:
return jnp.exp(parameters.log_transfer_constant[self.tc_ix])

def get_dc_activation(self, parameters: PyTree) -> Scalar:
return jnp.exp(
parameters.log_dissociation_constant[self.ix_dc_activation],
)

def get_dc_inhibition(self, parameters: PyTree) -> Scalar:
return jnp.exp(
parameters.log_dissociation_constant[self.ix_dc_inhibition],
)

def allosteric_effect(self, conc: ConcArray, parameters: PyTree) -> Scalar:
return generalised_mwc_effect(
conc_inhibitor=conc[self.species_inhibition],
conc_activator=conc[self.species_activation],
free_enzyme_ratio=self.free_enzyme_ratio(conc, parameters),
tc=self.get_tc(parameters),
dc_inhibitor=self.get_dc_inhibition(parameters),
dc_activator=self.get_dc_activation(parameters),
subunits=self.subunits,
)


class AllostericIrreversibleMichaelisMenten(
GeneralisedMWC, IrreversibleMichaelisMenten
):
"""A reaction with irreversible Michaelis Menten kinetics and allostery."""

def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar:
"""The flux of an irreversible allosteric Michaelis Menten reaction."""
allosteric_effect = self.allosteric_effect(conc, parameters)
non_allosteric_rate = super().__call__(conc, parameters)
return non_allosteric_rate * allosteric_effect


class AllostericReversibleMichaelisMenten(
GeneralisedMWC,
ReversibleMichaelisMenten,
):
"""A reaction with reversible Michaelis Menten kinetics and allostery."""

def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar:
"""The flux of an allosteric reversible Michaelis Menten reaction."""
allosteric_effect = self.allosteric_effect(conc, parameters)
non_allosteric_rate = super().__call__(conc, parameters)
return non_allosteric_rate * allosteric_effect
# from jax import numpy as jnp
# from jaxtyping import Array, Float, Int, PyTree, Scalar

# from enzax.rate_equation import ConcArray
# from enzax.rate_equations.michaelis_menten import (
# IrreversibleMichaelisMenten,
# ReversibleMichaelisMenten,
# )


# def generalised_mwc_effect(
# conc_inhibitor: Float[Array, " n_inhibition"],
# dc_inhibitor: Float[Array, " n_inhibition"],
# conc_activator: Float[Array, " n_activation"],
# dc_activator: Float[Array, " n_activation"],
# free_enzyme_ratio: Scalar,
# tc: Scalar,
# subunits: int,
# ) -> Scalar:
# """Get the allosteric effect on a rate.

# The equation is generalised Monod Wyman Changeux model as presented in Popova and Sel'kov 1975: https://doi.org/10.1016/0014-5793(75)80034-2.

# """ # noqa: E501
# qnum = 1 + jnp.sum(conc_inhibitor / dc_inhibitor)
# qdenom = 1 + jnp.sum(conc_activator / dc_activator)
# out = 1.0 / (1 + tc * (free_enzyme_ratio * qnum / qdenom) ** subunits)
# return out


# class GeneralisedMWC:
# """Mixin class for allosteric rate laws, assuming generalised MWC kinetics.

# See Popova and Sel'kov 1975 for the rate law: https://doi.org/10.1016/0014-5793(75)80034-2.

# Note that it is assumed there is a free_enzyme_ratio method available - that is why this is a subclass of MichaelisMenten rather than RateEquation.
# """ # noqa: E501

# subunits: int
# tc_ix: int
# ix_dc_activation: Int[Array, " n_activation"]
# ix_dc_inhibition: Int[Array, " n_inhibition"]
# species_activation: Int[Array, " n_activation"]
# species_inhibition: Int[Array, " n_inhibition"]

# def get_tc(self, parameters: PyTree) -> Scalar:
# return jnp.exp(parameters.log_transfer_constant[self.tc_ix])

# def get_dc_activation(self, parameters: PyTree) -> Scalar:
# return jnp.exp(
# parameters.log_dissociation_constant[self.ix_dc_activation],
# )

# def get_dc_inhibition(self, parameters: PyTree) -> Scalar:
# return jnp.exp(
# parameters.log_dissociation_constant[self.ix_dc_inhibition],
# )

# def allosteric_effect(self, conc: ConcArray, parameters: PyTree) -> Scalar:
# return generalised_mwc_effect(
# conc_inhibitor=conc[self.species_inhibition],
# conc_activator=conc[self.species_activation],
# free_enzyme_ratio=self.free_enzyme_ratio(conc, parameters),
# tc=self.get_tc(parameters),
# dc_inhibitor=self.get_dc_inhibition(parameters),
# dc_activator=self.get_dc_activation(parameters),
# subunits=self.subunits,
# )


# class AllostericIrreversibleMichaelisMenten(
# GeneralisedMWC, IrreversibleMichaelisMenten
# ):
# """A reaction with irreversible Michaelis Menten kinetics and allostery."""

# def __call__(self, conc: Float[Array, " n"], parameters: PyTree) -> Scalar:
# """The flux of an irreversible allosteric Michaelis Menten reaction."""
# allosteric_effect = self.allosteric_effect(conc, parameters)
# non_allosteric_rate = super().__call__(conc, parameters)
# return non_allosteric_rate * allosteric_effect


# class AllostericReversibleMichaelisMenten(
# GeneralisedMWC,
# ReversibleMichaelisMenten,
# ):
# """A reaction with reversible Michaelis Menten kinetics and allostery."""

# def __call__(self, conc: ConcArray, parameters: PyTree) -> Scalar:
# """The flux of an allosteric reversible Michaelis Menten reaction."""
# allosteric_effect = self.allosteric_effect(conc, parameters)
# non_allosteric_rate = super().__call__(conc, parameters)
# return non_allosteric_rate * allosteric_effect
Loading

0 comments on commit c8ddebd

Please sign in to comment.