diff --git a/docs/source/modeling.rst b/docs/source/modeling.rst
index f98e39085..2b909ab1e 100644
--- a/docs/source/modeling.rst
+++ b/docs/source/modeling.rst
@@ -4,12 +4,18 @@ Modeling
:members:
:show-inheritance:
-ASKEM AMR Petri net generation(:py:mod:`mira.modeling.askenet.petrinet`)
-------------------------------------------------------------------------
+ASKEM AMR Petri net generation (:py:mod:`mira.modeling.askenet.petrinet`)
+-------------------------------------------------------------------------
.. automodule:: mira.modeling.askenet.petrinet
:members:
:show-inheritance:
+ASKEM AMR operations (:py:mod:`mira.modeling.askenet.ops`)
+----------------------------------------------------------
+.. automodule:: mira.modeling.askenet.ops
+ :members:
+ :show-inheritance:
+
ASKEM AMR Regulatory net generation (:py:mod:`mira.modeling.askenet.regnet`)
----------------------------------------------------------------------------
.. automodule:: mira.modeling.askenet.regnet
diff --git a/mira/modeling/__init__.py b/mira/modeling/__init__.py
index 6ffc28f99..6a479a5ab 100644
--- a/mira/modeling/__init__.py
+++ b/mira/modeling/__init__.py
@@ -92,10 +92,14 @@ def assemble_variable(
if key in self.variables:
return self.variables[key]
- if initials and concept.name in initials:
- initial_value = initials[concept.name].value
- else:
- initial_value = None
+ # We don't assume that the initial dict key is the same as the
+ # name of the given concept the initial applies to, so we check
+ # concept name match instead of key match.
+ initial_value = None
+ if initials:
+ for k, v in initials.items():
+ if v.concept.name == concept.name:
+ initial_value = v.value
data = {
'name': concept.name,
@@ -141,8 +145,8 @@ def make_model(self):
value = self.template_model.parameters[key].value
distribution = self.template_model.parameters[key].distribution
self.get_create_parameter(
- ModelParameter(key, value, distribution,
- placeholder=False))
+ ModelParameter(key, value, distribution,
+ placeholder=False))
for template in self.template_model.templates:
if isinstance(template, StaticConcept):
@@ -245,4 +249,3 @@ def num_controllers(template):
return len(template.controllers)
else:
return 0
-
diff --git a/mira/modeling/askenet/ops.py b/mira/modeling/askenet/ops.py
new file mode 100644
index 000000000..6ffb50898
--- /dev/null
+++ b/mira/modeling/askenet/ops.py
@@ -0,0 +1,272 @@
+import copy
+import sympy
+from mira.metamodel import SympyExprStr
+import mira.metamodel.ops as tmops
+from mira.sources.askenet.petrinet import template_model_from_askenet_json
+from .petrinet import template_model_to_petrinet_json
+from mira.metamodel.io import mathml_to_expression
+from mira.metamodel.template_model import Parameter, Distribution, Observable
+from mira.metamodel.templates import NaturalConversion, NaturalProduction, NaturalDegradation
+
+
+def amr_to_mira(func):
+ def wrapper(amr, *args, **kwargs):
+ tm = template_model_from_askenet_json(amr)
+ result = func(tm, *args, **kwargs)
+ amr = template_model_to_petrinet_json(result)
+ return amr
+
+ return wrapper
+
+
+# Edit ID / label / name of State, Transition, Observable, Parameter, Initial
+@amr_to_mira
+def replace_state_id(tm, old_id, new_id):
+ """Replace the ID of a state."""
+ concepts_name_map = tm.get_concepts_name_map()
+ if old_id not in concepts_name_map:
+ raise ValueError(f"State with ID {old_id} not found in model.")
+ for template in tm.templates:
+ for concept in template.get_concepts():
+ if concept.name == old_id:
+ concept.name = new_id
+ template.rate_law = SympyExprStr(
+ template.rate_law.args[0].subs(sympy.Symbol(old_id),
+ sympy.Symbol(new_id)))
+ for observable in tm.observables.values():
+ observable.expression = SympyExprStr(
+ observable.expression.args[0].subs(sympy.Symbol(old_id),
+ sympy.Symbol(new_id)))
+ for key, initial in copy.deepcopy(tm.initials).items():
+ if initial.concept.name == old_id:
+ tm.initials[key].concept.name = new_id
+ # If the key is same as the old ID, we replace that too
+ if key == old_id:
+ tm.initials[new_id] = tm.initials.pop(old_id)
+ return tm
+
+
+@amr_to_mira
+def replace_transition_id(tm, old_id, new_id):
+ """Replace the ID of a transition."""
+ for template in tm.templates:
+ if template.name == old_id:
+ template.name = new_id
+ return tm
+
+
+@amr_to_mira
+def replace_observable_id(tm, old_id, new_id, name=None):
+ """Replace the ID of an observable."""
+ for obs, observable in copy.deepcopy(tm.observables).items():
+ if obs == old_id:
+ observable.name = new_id
+ observable.display_name = name if name else observable.display_name
+ tm.observables[new_id] = observable
+ tm.observables.pop(old_id)
+ return tm
+
+
+@amr_to_mira
+def remove_observable(tm, removed_id):
+ for obs, observable in copy.deepcopy(tm.observables).items():
+ if obs == removed_id:
+ tm.observables.pop(obs)
+ return tm
+
+
+@amr_to_mira
+def remove_parameter(tm, removed_id, replacement_value=None):
+ if replacement_value:
+ tm.substitute_parameter(removed_id, replacement_value)
+ else:
+ tm.eliminate_parameter(removed_id)
+ return tm
+
+
+@amr_to_mira
+def add_observable(tm, new_id, new_name, new_expression):
+ # Note that if an observable already exists with the given
+ # key, it will be replaced
+ rate_law_sympy = mathml_to_expression(new_expression)
+ new_observable = Observable(name=new_id, display_name=new_name,
+ expression=rate_law_sympy)
+ tm.observables[new_id] = new_observable
+ return tm
+
+
+@amr_to_mira
+def replace_parameter_id(tm, old_id, new_id):
+ """Replace the ID of a parameter."""
+ if old_id not in tm.parameters:
+ raise ValueError(f"Parameter with ID {old_id} not found in model.")
+ for template in tm.templates:
+ if template.rate_law:
+ template.rate_law = SympyExprStr(
+ template.rate_law.args[0].subs(sympy.Symbol(old_id),
+ sympy.Symbol(new_id)))
+ for observable in tm.observables.values():
+ observable.expression = SympyExprStr(
+ observable.expression.args[0].subs(sympy.Symbol(old_id),
+ sympy.Symbol(new_id)))
+ for key, param in copy.deepcopy(tm.parameters).items():
+ if param.name == old_id:
+ popped_param = tm.parameters.pop(param.name)
+ popped_param.name = new_id
+ tm.parameters[new_id] = popped_param
+ return tm
+
+
+# Resolve issue where only parameters are added only when they are present in rate laws.
+@amr_to_mira
+def add_parameter(tm, parameter_id: str,
+ value: float = None,
+ distribution=None,
+ units_mathml: str = None):
+ distribution = Distribution(**distribution) if distribution else None
+ if units_mathml:
+ units = {
+ 'expression': mathml_to_expression(units_mathml),
+ 'expression_mathml': units_mathml
+ }
+ else:
+ units = None
+ data = {
+ 'name': parameter_id,
+ 'value': value,
+ 'distribution': distribution,
+ 'units': units
+ }
+
+ new_param = Parameter(**data)
+ tm.parameters[parameter_id] = new_param
+
+ return tm
+
+
+@amr_to_mira
+def replace_initial_id(tm, old_id, new_id):
+ """Replace the ID of an initial."""
+ tm.initials = {
+ (new_id if k == old_id else k): v for k, v in tm.initials.items()
+ }
+ return tm
+
+
+# Remove state
+@amr_to_mira
+def remove_state(tm, state_id):
+ new_templates = []
+ for template in tm.templates:
+ to_remove = False
+ for concept in template.get_concepts():
+ if concept.name == state_id:
+ to_remove = True
+ if not to_remove:
+ new_templates.append(template)
+ tm.templates = new_templates
+
+ for obs, observable in tm.observables.items():
+ observable.expression = SympyExprStr(
+ observable.expression.args[0].subs(sympy.Symbol(state_id), 0))
+ return tm
+
+
+@amr_to_mira
+def add_state(tm, state_id, grounding: None, units: None):
+ pass
+
+
+# Remove transition
+@amr_to_mira
+def remove_transition(tm, transition_id):
+ tm.templates = [t for t in tm.templates if t.name != transition_id]
+ return tm
+
+
+@amr_to_mira
+def add_transition(tm, new_transition_id, src_id=None, tgt_id=None, rate_law_mathml=None):
+ # TODO: handle parameters added in the rate law as follows
+ # option 1 take in optional parameters dict if rate law contains parameters
+ # that aren't already present
+ # option 2, reverse engineer rate law and find parameters and states within
+ # the rate law and add to model
+ if src_id is None and tgt_id is None:
+ ValueError("You must pass in at least one of source and target id")
+ rate_law_sympy = SympyExprStr(mathml_to_expression(rate_law_mathml)) \
+ if rate_law_mathml else None
+ if src_id is None and tgt_id:
+ template = NaturalProduction(name=new_transition_id, outcome=tgt_id,
+ rate_law=rate_law_sympy)
+ elif src_id and tgt_id is None:
+ template = NaturalDegradation(name=new_transition_id, subject=src_id,
+ rate_law=rate_law_sympy)
+ else:
+ template = NaturalConversion(name=new_transition_id, subject=src_id,
+ outcome=tgt_id, rate_law=rate_law_sympy)
+ tm.templates.append(template)
+ return tm
+
+
+@amr_to_mira
+# rate law is of type Sympy Expression
+def replace_rate_law_sympy(tm, transition_id, new_rate_law: sympy.Expr):
+ # NOTE: this assumes that a sympy expression object is given
+ # though it might make sense to take a string instead
+ for template in tm.templates:
+ if template.name == transition_id:
+ template.rate_law = SympyExprStr(new_rate_law)
+ return tm
+
+
+def replace_rate_law_mathml(tm, transition_id, new_rate_law):
+ new_rate_law_sympy = mathml_to_expression(new_rate_law)
+ return replace_rate_law_sympy(tm, transition_id, new_rate_law_sympy)
+
+
+# currently initials don't support expressions so only implement the following 2 methods for observables
+# if we are seeking to replace an expression in an initial, return current template model
+@amr_to_mira
+def replace_observable_expression_sympy(tm, obs_id,
+ new_expression_sympy: sympy.Expr):
+ for obs, observable in tm.observables.items():
+ if obs == obs_id:
+ observable.expression = SympyExprStr(new_expression_sympy)
+ return tm
+
+
+def replace_intial_expression_sympy(tm, initial_id,
+ new_expression_sympy: sympy.Expr):
+ # TODO: once initial expressions are supported, implement this
+ return tm
+
+
+def replace_observable_expression_mathml(tm, obj_id, new_expression_mathml):
+ new_expression_sympy = mathml_to_expression(new_expression_mathml)
+ return replace_observable_expression_sympy(tm, obj_id,
+ new_expression_sympy)
+
+
+def replace_intial_expression_mathml(tm, initial_id, new_expression_mathml):
+ # TODO: once initial expressions are supported, implement this
+ return tm
+
+
+@amr_to_mira
+def stratify(*args, **kwargs):
+ return tmops.stratify(*args, **kwargs)
+
+
+@amr_to_mira
+def simplify_rate_laws(*args, **kwargs):
+ return tmops.simplify_rate_laws(*args, **kwargs)
+
+
+@amr_to_mira
+def aggregate_parameters(*args, **kwargs):
+ return tmops.aggregate_parameters(*args, **kwargs)
+
+
+@amr_to_mira
+def counts_to_dimensionless(*args, **kwargs):
+ return tmops.counts_to_dimensionless(*args, **kwargs)
diff --git a/mira/modeling/askenet/petrinet.py b/mira/modeling/askenet/petrinet.py
index 784932058..5ce3327a4 100644
--- a/mira/modeling/askenet/petrinet.py
+++ b/mira/modeling/askenet/petrinet.py
@@ -2,7 +2,8 @@
at https://github.com/DARPA-ASKEM/Model-Representations/tree/main/petrinet.
"""
-__all__ = ["AskeNetPetriNetModel", "ModelSpecification"]
+__all__ = ["AskeNetPetriNetModel", "ModelSpecification",
+ "template_model_to_petrinet_json"]
import json
@@ -12,7 +13,8 @@
from pydantic import BaseModel, Field
-from mira.metamodel import expression_to_mathml, safe_parse_expr
+from mira.metamodel import expression_to_mathml, safe_parse_expr, \
+ TemplateModel
from .. import Model
from .utils import add_metadata_annotations
@@ -103,9 +105,12 @@ def __init__(self, model: Model):
self.initials.append(initial_data)
for key, observable in model.observables.items():
+ display_name = observable.observable.display_name \
+ if observable.observable.display_name \
+ else observable.observable.name
obs_data = {
'id': observable.observable.name,
- 'name': observable.observable.name,
+ 'name': display_name,
'expression': str(observable.observable.expression),
'expression_mathml': expression_to_mathml(
observable.observable.expression.args[0]),
@@ -273,6 +278,21 @@ def to_json_file(self, fname, name=None, description=None,
json.dump(js, fh, indent=indent, **kwargs)
+def template_model_to_petrinet_json(tm: TemplateModel):
+ """Convert a template model to a PetriNet JSON dict.
+
+ Parameters
+ ----------
+ tm :
+ The template model to convert.
+
+ Returns
+ -------
+ A JSON dict representing the PetriNet model.
+ """
+ return AskeNetPetriNetModel(Model(tm)).to_json()
+
+
class Initial(BaseModel):
target: str
expression: str
diff --git a/mira/sources/askenet/petrinet.py b/mira/sources/askenet/petrinet.py
index 2699e75d2..5ca309f97 100644
--- a/mira/sources/askenet/petrinet.py
+++ b/mira/sources/askenet/petrinet.py
@@ -145,7 +145,8 @@ def template_model_from_askenet_json(model_json) -> TemplateModel:
continue
observable = Observable(name=observable['id'],
- expression=observable_expr)
+ expression=observable_expr,
+ display_name=observable.get('name'))
observables[observable.name] = observable
# We get the time variable from the semantics
diff --git a/notebooks/applications/Enzyme_substrate_kinetics.ipynb b/notebooks/applications/Enzyme_substrate_kinetics.ipynb
index 2d7ed8a25..cd837e90f 100644
--- a/notebooks/applications/Enzyme_substrate_kinetics.ipynb
+++ b/notebooks/applications/Enzyme_substrate_kinetics.ipynb
@@ -3,7 +3,6 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "475993a6-38f9-4627-9cf6-808d351a9faf",
"metadata": {},
"outputs": [],
"source": [
@@ -23,7 +22,6 @@
{
"cell_type": "code",
"execution_count": 3,
- "id": "015c410c-0f31-448e-ba17-bda9729e1fc5",
"metadata": {},
"outputs": [
{
@@ -47,7 +45,6 @@
{
"cell_type": "code",
"execution_count": 148,
- "id": "d45757d7-1fb9-4d0f-ae90-d109e3c7a893",
"metadata": {},
"outputs": [],
"source": [
@@ -57,7 +54,6 @@
{
"cell_type": "code",
"execution_count": 151,
- "id": "4df86721-da78-4426-8313-ff9e1dc2115c",
"metadata": {},
"outputs": [],
"source": [
@@ -67,7 +63,6 @@
{
"cell_type": "code",
"execution_count": 44,
- "id": "ce3fd838-2a3c-4f65-a4b6-c9fc6064a803",
"metadata": {},
"outputs": [],
"source": [
@@ -77,7 +72,6 @@
{
"cell_type": "code",
"execution_count": 153,
- "id": "b2bdf38d-afee-48a9-bcbe-315235142807",
"metadata": {},
"outputs": [],
"source": [
@@ -89,7 +83,6 @@
{
"cell_type": "code",
"execution_count": 154,
- "id": "a6b5b9fd-8ac7-42cb-bd3c-c09cc2e96795",
"metadata": {},
"outputs": [],
"source": [
@@ -103,7 +96,6 @@
{
"cell_type": "code",
"execution_count": 157,
- "id": "b10bfcb2-ef77-42b3-9249-3ded6309596a",
"metadata": {},
"outputs": [],
"source": [
@@ -113,7 +105,6 @@
{
"cell_type": "code",
"execution_count": 158,
- "id": "db584dc2-0700-499e-ac07-3e3c1bb29300",
"metadata": {},
"outputs": [
{
@@ -137,7 +128,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": "Python 3",
"language": "python",
"name": "python3"
},
@@ -151,7 +142,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.4"
+ "version": "3.7.3"
}
},
"nbformat": 4,
diff --git a/pyproject.toml b/pyproject.toml
index 09aa4efdb..a9107d5e1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,5 +15,6 @@ reverse_relative = true
[tool.pytest.ini_options]
markers = [
- "slow: marks tests as slow for pytest (deselect with '-m \"not slow\"')"
+ "slow: marks tests as slow for pytest (deselect with '-m \"not slow\"')",
+ "sbmlmath: marks tests that import sbmlmath pytest (deselect with '-m \"not sbmlmath\"')"
]
diff --git a/tests/test_modeling/test_askenet_ops.py b/tests/test_modeling/test_askenet_ops.py
new file mode 100644
index 000000000..a844c6e00
--- /dev/null
+++ b/tests/test_modeling/test_askenet_ops.py
@@ -0,0 +1,490 @@
+import unittest
+import requests
+import pytest
+from copy import deepcopy as _d
+from mira.modeling.askenet.ops import *
+from sympy import *
+from mira.metamodel.io import mathml_to_expression
+from mira.metamodel.templates import Concept
+
+
+class TestAskenetOperations(unittest.TestCase):
+ """A test case for operations on template models."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.sir_amr = requests.get(
+ 'https://raw.githubusercontent.com/DARPA-ASKEM/'
+ 'Model-Representations/main/petrinet/examples/sir.json').json()
+
+ '''These unit tests are conducted by zipping through lists of each key in a amr file
+ (e.g. parameters, observables, etc). Zipping in this manner assumes that the order (assuming insertion order)
+ is preserved before and after mira operation for an amr key
+ '''
+
+ def test_replace_state_id(self):
+ old_id = 'S'
+ new_id = 'X'
+ amr = _d(self.sir_amr)
+ new_amr = replace_state_id(amr, old_id, new_id)
+
+ old_model = amr['model']
+ new_model = new_amr['model']
+
+ old_model_states = old_model['states']
+ new_model_states = new_model['states']
+
+ old_model_transitions = old_model['transitions']
+ new_model_transitions = new_model['transitions']
+
+ self.assertEqual(len(old_model_states), len(new_model_states))
+ for old_state, new_state in zip(old_model_states, new_model_states):
+
+ # output states missing description field
+ if old_state['id'] == old_id:
+ self.assertEqual(new_state['id'], new_id)
+ self.assertEqual(old_state['name'], new_state['name'])
+ self.assertEqual(old_state['grounding']['identifiers'], new_state['grounding']['identifiers'])
+ self.assertEqual(old_state['units'], new_state['units'])
+
+ self.assertEqual(len(old_model_transitions), len(new_model_transitions))
+
+ # output transitions are missing a description and transition['properties']['name'] field
+ # is abbreviated in output amr
+ for old_transition, new_transition in zip(old_model_transitions, new_model_transitions):
+ if old_id in old_transition['input'] or old_id in old_transition['output']:
+ self.assertIn(new_id, new_transition['input'])
+ self.assertNotIn(old_id, new_transition['output'])
+ self.assertEqual(len(old_transition['input']), len(new_transition['input']))
+ self.assertEqual(len(old_transition['output']), len(new_transition['output']))
+ self.assertEqual(old_transition['id'], new_transition['id'])
+
+ old_semantics_ode = amr['semantics']['ode']
+ new_semantics_ode = new_amr['semantics']['ode']
+
+ old_semantics_ode_rates = old_semantics_ode['rates']
+ new_semantics_ode_rates = new_semantics_ode['rates']
+
+ # this test doesn't account for if the expression semantic is preserved (e.g. same type of operations)
+ # would pass test if we call replace_state_id(I,J) and old expression is "I*X" and new expression is "J+X"
+ for old_rate, new_rate in zip(old_semantics_ode_rates, new_semantics_ode_rates):
+ if old_id in old_rate['expression'] or old_id in old_rate['expression_mathml']:
+ self.assertIn(new_id, new_rate['expression'])
+ self.assertNotIn(old_id, new_rate['expression'])
+
+ self.assertIn(new_id, new_rate['expression_mathml'])
+ self.assertNotIn(old_id, new_rate['expression_mathml'])
+
+ self.assertEqual(old_rate['target'], new_rate['target'])
+
+ # initials have float values substituted in for state ids in their expression and expression_mathml field
+ old_semantics_ode_initials = old_semantics_ode['initials']
+ new_semantics_ode_initials = new_semantics_ode['initials']
+
+ self.assertEqual(len(old_semantics_ode_initials), len(new_semantics_ode_initials))
+ for old_initials, new_initials in zip(old_semantics_ode_initials, new_semantics_ode_initials):
+ if old_id == old_initials['target']:
+ self.assertEqual(new_initials['target'], new_id)
+
+ old_semantics_ode_parameters = old_semantics_ode['parameters']
+ new_semantics_ode_parameters = new_semantics_ode['parameters']
+ # This is due to initial expressions vs values
+ assert len(old_semantics_ode_parameters) == 5
+ assert len(new_semantics_ode_parameters) == 2
+
+ # zip method iterates over length of the smaller iterable len(new_semantics_ode_parameters) = 2
+ # as opposed to len(old_semantics_ode_parameters) = 5 , non-state parameters are listed first in input amr
+ for old_params, new_params in zip(old_semantics_ode_parameters, new_semantics_ode_parameters):
+ # test to see if old_id/new_id in name/id field and not for id/name equality because these fields
+ # may contain subscripts or timestamps appended to the old_id/new_id
+ if old_id in old_params['id'] and old_id in old_params['name']:
+ self.assertIn(new_id, new_params['id'])
+ self.assertIn(new_id, new_params['name'])
+
+ old_semantics_ode_observables = old_semantics_ode['observables']
+ new_semantics_ode_observables = new_semantics_ode['observables']
+ self.assertEqual(len(old_semantics_ode_observables), len(new_semantics_ode_observables))
+
+ for old_observable, new_observable in zip(old_semantics_ode_observables, new_semantics_ode_observables):
+ if old_id in old_observable['states'] and old_id in old_observable['expression'] and \
+ old_id in old_observable['expression_mathml']:
+ self.assertIn(new_id, new_observable['expression'])
+ self.assertNotIn(old_id, new_observable['expression'])
+
+ self.assertIn(new_id, new_observable['expression_mathml'])
+ self.assertNotIn(old_id, new_observable['expression_mathml'])
+
+ self.assertEqual(old_observable['id'], new_observable['id'])
+
+ def test_replace_transition_id(self):
+
+ old_id = 'inf'
+ new_id = 'new_inf'
+ amr = _d(self.sir_amr)
+ new_amr = replace_transition_id(amr, old_id, new_id)
+
+ old_model_transitions = amr['model']['transitions']
+ new_model_transitions = new_amr['model']['transitions']
+
+ self.assertEqual(len(old_model_transitions), len(new_model_transitions))
+
+ for old_transitions, new_transition in zip(old_model_transitions, new_model_transitions):
+ if old_transitions['id'] == old_id:
+ self.assertEqual(new_transition['id'], new_id)
+
+ def test_replace_observable_id(self):
+ old_id = 'noninf'
+ new_id = 'testinf'
+ new_display_name = 'test-infection'
+ amr = _d(self.sir_amr)
+ new_amr = replace_observable_id(amr, old_id, new_id,
+ name=new_display_name)
+
+ old_semantics_observables = amr['semantics']['ode']['observables']
+ new_semantics_observables = new_amr['semantics']['ode']['observables']
+
+ self.assertEqual(len(old_semantics_observables), len(new_semantics_observables))
+
+ for old_observable, new_observable in zip(old_semantics_observables, new_semantics_observables):
+ if old_observable['id'] == old_id:
+ self.assertEqual(new_observable['id'], new_id)
+ self.assertEqual(new_observable['name'], new_display_name)
+
+ def test_remove_observable_or_parameter(self):
+
+ old_amr_obs = _d(self.sir_amr)
+ old_amr_param = _d(self.sir_amr)
+
+ replaced_observable_id = 'noninf'
+ new_amr_obs = remove_observable(old_amr_obs, replaced_observable_id)
+ for new_observable in new_amr_obs['semantics']['ode']['observables']:
+ self.assertNotEqual(new_observable['id'], replaced_observable_id)
+
+ replaced_param_id = 'beta'
+ replacement_value = 5
+ new_amr_param = remove_parameter(old_amr_param, replaced_param_id, replacement_value)
+ for new_param in new_amr_param['semantics']['ode']['parameters']:
+ self.assertNotEqual(new_param['id'], replaced_param_id)
+ for old_rate, new_rate in zip(old_amr_param['semantics']['ode']['rates'],
+ new_amr_param['semantics']['ode']['rates']):
+ if replaced_param_id in old_rate['expression'] and replaced_param_id in old_rate['expression_mathml']:
+ self.assertNotIn(replaced_param_id, new_rate['expression'])
+ self.assertIn(str(replacement_value), new_rate['expression'])
+
+ self.assertNotIn(replaced_param_id, new_rate['expression_mathml'])
+ self.assertIn(str(replacement_value), new_rate['expression_mathml'])
+
+ self.assertEqual(old_rate['target'], new_rate['target'])
+
+ # currently don't support expressions for initials
+ for old_obs, new_obs in zip(old_amr_param['semantics']['ode']['observables'],
+ new_amr_param['semantics']['ode']['observables']):
+ if replaced_param_id in old_obs['expression'] and replaced_param_id in old_obs['expression_mathml']:
+ self.assertNotIn(replaced_param_id, new_obs['expression'])
+ self.assertIn(str(replacement_value), new_obs['expression'])
+
+ self.assertNotIn(replaced_param_id, new_obs['expression_mathml'])
+ self.assertIn(str(replacement_value), new_obs['expression_mathml'])
+
+ self.assertEqual(old_obs['id'], new_obs['id'])
+
+ @pytest.mark.sbmlmath
+ def test_add_observable(self):
+ amr = _d(self.sir_amr)
+ new_id = 'testinf'
+ new_display_name = 'DISPLAY_TEST'
+ xml_expression = "Edelta"
+ new_amr = add_observable(amr, new_id, new_display_name, xml_expression)
+
+ # Create a dict out of a list of observable dict entries to easier test for addition of new observables
+ new_observable_dict = {}
+ for observable in new_amr['semantics']['ode']['observables']:
+ name = observable.pop('id')
+ new_observable_dict[name] = observable
+
+ self.assertIn(new_id, new_observable_dict)
+ self.assertEqual(new_display_name, new_observable_dict[new_id]['name'])
+ self.assertEqual(xml_expression, new_observable_dict[new_id]['expression_mathml'])
+ self.assertEqual(sstr(mathml_to_expression(xml_expression)), new_observable_dict[new_id]['expression'])
+
+ @pytest.mark.sbmlmath
+ def test_replace_parameter_id(self):
+ old_id = 'beta'
+ new_id = 'TEST'
+ amr = _d(self.sir_amr)
+ new_amr = replace_parameter_id(amr, old_id, new_id)
+
+ old_semantics_ode_rates = amr['semantics']['ode']['rates']
+ new_semantics_ode_rates = new_amr['semantics']['ode']['rates']
+
+ old_semantics_ode_observables = amr['semantics']['ode']['observables']
+ new_semantics_ode_observables = new_amr['semantics']['ode']['observables']
+
+ old_semantics_ode_parameters = amr['semantics']['ode']['parameters']
+ new_semantics_ode_parameters = new_amr['semantics']['ode']['parameters']
+
+ new_model_states = new_amr['model']['states']
+
+ self.assertEqual(len(old_semantics_ode_rates), len(new_semantics_ode_rates))
+ self.assertEqual(len(old_semantics_ode_observables), len(new_semantics_ode_observables))
+
+ self.assertEqual(len(old_semantics_ode_parameters) - len(new_model_states), len(new_semantics_ode_parameters))
+
+ for old_rate, new_rate in zip(old_semantics_ode_rates, new_semantics_ode_rates):
+ if old_id in old_rate['expression'] and old_id in old_rate['expression_mathml']:
+ self.assertIn(new_id, new_rate['expression'])
+ self.assertNotIn(old_id, new_rate['expression'])
+
+ self.assertIn(new_id, new_rate['expression_mathml'])
+ self.assertNotIn(old_id, new_rate['expression_mathml'])
+
+ # don't test states field for a parameter as it is assumed that replace_parameter_id will only be used with
+ # parameters such as gamma or beta (i.e. non-states)
+ for old_observable, new_observable in zip(old_semantics_ode_observables, new_semantics_ode_observables):
+ if old_id in old_observable['expression'] and old_id in new_observable['expression_mathml']:
+ self.assertIn(new_id, new_observable['expression'])
+ self.assertNotIn(old_id, new_observable['expression'])
+
+ self.assertIn(new_id, new_observable['expression_mathml'])
+ self.assertNotIn(old_id, new_observable['expression_mathml'])
+
+ for old_parameter, new_parameter in zip(old_semantics_ode_parameters, new_semantics_ode_parameters):
+ if old_parameter['id'] == old_id:
+ self.assertEqual(new_parameter['id'], new_id)
+ self.assertEqual(old_parameter['value'], new_parameter['value'])
+ self.assertEqual(old_parameter['distribution'], new_parameter['distribution'])
+ self.assertEqual(sstr(old_parameter['units']['expression']), new_parameter['units']['expression'])
+ self.assertEqual(mathml_to_expression(old_parameter['units']['expression_mathml']),
+ mathml_to_expression(new_parameter['units']['expression_mathml']))
+
+ def test_remove_state(self):
+ removed_state_id = 'S'
+ amr = _d(self.sir_amr)
+
+ new_amr = remove_state(amr, removed_state_id)
+
+ new_model = new_amr['model']
+ new_model_states = new_model['states']
+ new_model_transitions = new_model['transitions']
+
+ new_semantics_ode = new_amr['semantics']['ode']
+ new_semantics_ode_rates = new_semantics_ode['rates']
+ new_semantics_ode_initials = new_semantics_ode['initials']
+ new_semantics_ode_parameters = new_semantics_ode['parameters']
+ new_semantics_ode_observables = new_semantics_ode['observables']
+
+ for new_state in new_model_states:
+ self.assertNotEquals(removed_state_id, new_state['id'])
+
+ for new_transition in new_model_transitions:
+ self.assertNotIn(removed_state_id, new_transition['input'])
+ self.assertNotIn(removed_state_id, new_transition['output'])
+
+ # output rates that originally contained targeted state are removed
+ for new_rate in new_semantics_ode_rates:
+ self.assertNotIn(removed_state_id, new_rate['expression'])
+ self.assertNotIn(removed_state_id, new_rate['expression_mathml'])
+
+ # initials are bugged, all states removed rather than just targeted removed state in output amr
+ for new_initial in new_semantics_ode_initials:
+ self.assertNotEquals(removed_state_id, new_initial['target'])
+
+ # parameters that are associated in an expression with a removed state are not present in output amr
+ # (e.g.) if there exists an expression: "S*I*beta" and we remove S, then beta is no longer present in output
+ # list of parameters
+ for new_parameter in new_semantics_ode_parameters:
+ self.assertNotIn(removed_state_id, new_parameter['id'])
+
+ # output observable expressions that originally contained targeted state still exist with targeted state removed
+ # (e.g. 'S+R' -> 'R') if 'S' is the removed state
+ for new_observable in new_semantics_ode_observables:
+ self.assertNotIn(removed_state_id, new_observable['expression'])
+ self.assertNotIn(removed_state_id, new_observable['expression_mathml'])
+
+ def test_remove_transition(self):
+ removed_transition = 'inf'
+ amr = _d(self.sir_amr)
+
+ new_amr = remove_transition(amr, removed_transition)
+ new_model_transition = new_amr['model']['transitions']
+
+ for new_transition in new_model_transition:
+ self.assertNotEquals(removed_transition, new_transition['id'])
+
+ @pytest.mark.sbmlmath
+ def test_add_transition(self):
+ test_subject = Concept(name="test_subject", identifiers={"ido": "0000511"})
+ test_outcome = Concept(name="test_outcome", identifiers={"ido": "0000592"})
+ expression_xml = "Edelta"
+ new_transition_id = 'test'
+ old_natural_conversion_amr = _d(self.sir_amr)
+ old_natural_production_amr = _d(self.sir_amr)
+ old_natural_degradation_amr = _d(self.sir_amr)
+
+ # NaturalConversion
+ new_natural_conversion_amr = add_transition(old_natural_conversion_amr, new_transition_id,
+ rate_law_mathml=expression_xml,
+ src_id=test_subject,
+ tgt_id=test_outcome)
+ natural_conversion_transition_dict = {}
+ natural_conversion_rates_dict = {}
+ natural_conversion_state_dict = {}
+
+ for transition in new_natural_conversion_amr['model']['transitions']:
+ name = transition.pop('id')
+ natural_conversion_transition_dict[name] = transition
+
+ for rate in new_natural_conversion_amr['semantics']['ode']['rates']:
+ name = rate.pop('target')
+ natural_conversion_rates_dict[name] = rate
+
+ for state in new_natural_conversion_amr['model']['states']:
+ name = state.pop('id')
+ natural_conversion_state_dict[name] = state
+
+ self.assertIn(new_transition_id, natural_conversion_transition_dict)
+ self.assertIn(new_transition_id, natural_conversion_rates_dict)
+ self.assertEqual(expression_xml, natural_conversion_rates_dict[new_transition_id]['expression_mathml'])
+ self.assertEqual(sstr(mathml_to_expression(expression_xml)),
+ natural_conversion_rates_dict[new_transition_id]['expression'])
+ self.assertIn(test_subject.name, natural_conversion_state_dict)
+ self.assertIn(test_outcome.name, natural_conversion_state_dict)
+
+ # NaturalProduction
+ new_natural_production_amr = add_transition(old_natural_production_amr, new_transition_id,
+ rate_law_mathml=expression_xml,
+ tgt_id=test_outcome)
+ natural_production_transition_dict = {}
+ natural_production_rates_dict = {}
+ natural_production_state_dict = {}
+
+ for transition in new_natural_production_amr['model']['transitions']:
+ name = transition.pop('id')
+ natural_production_transition_dict[name] = transition
+
+ for rate in new_natural_production_amr['semantics']['ode']['rates']:
+ name = rate.pop('target')
+ natural_production_rates_dict[name] = rate
+
+ for state in new_natural_production_amr['model']['states']:
+ name = state.pop('id')
+ natural_production_state_dict[name] = state
+
+ self.assertIn(new_transition_id, natural_production_transition_dict)
+ self.assertIn(new_transition_id, natural_production_rates_dict)
+ self.assertEqual(expression_xml, natural_production_rates_dict[new_transition_id]['expression_mathml'])
+ self.assertEqual(sstr(mathml_to_expression(expression_xml)),
+ natural_production_rates_dict[new_transition_id]['expression'])
+ self.assertIn(test_outcome.name, natural_production_state_dict)
+
+ # NaturalDegradation
+ new_natural_degradation_amr = add_transition(old_natural_degradation_amr, new_transition_id,
+ rate_law_mathml=expression_xml,
+ src_id=test_subject)
+ natural_degradation_transition_dict = {}
+ natural_degradation_rates_dict = {}
+ natural_degradation_states_dict = {}
+
+ for transition in new_natural_degradation_amr['model']['transitions']:
+ name = transition.pop('id')
+ natural_degradation_transition_dict[name] = transition
+
+ for rate in new_natural_degradation_amr['semantics']['ode']['rates']:
+ name = rate.pop('target')
+ natural_degradation_rates_dict[name] = rate
+
+ for state in new_natural_degradation_amr['model']['states']:
+ name = state.pop('id')
+ natural_degradation_states_dict[name] = state
+
+ self.assertIn(new_transition_id, natural_degradation_transition_dict)
+ self.assertIn(new_transition_id, natural_degradation_rates_dict)
+ self.assertEqual(expression_xml, natural_degradation_rates_dict[new_transition_id]['expression_mathml'])
+ self.assertEqual(sstr(mathml_to_expression(expression_xml)),
+ natural_degradation_rates_dict[new_transition_id]['expression'])
+ self.assertIn(test_subject.name, natural_degradation_states_dict)
+
+ @pytest.mark.sbmlmath
+ def test_replace_rate_law_sympy(self):
+ transition_id = 'inf'
+ target_expression_xml_str = 'X8'
+ target_expression_sympy = mathml_to_expression(target_expression_xml_str)
+
+ amr = _d(self.sir_amr)
+ new_amr = replace_rate_law_sympy(amr, transition_id, target_expression_sympy)
+ new_semantics_ode_rates = new_amr['semantics']['ode']['rates']
+
+ for new_rate in new_semantics_ode_rates:
+ if new_rate['target'] == transition_id:
+ self.assertEqual(sstr(target_expression_sympy), new_rate['expression'])
+ self.assertEqual(target_expression_xml_str, new_rate['expression_mathml'])
+
+ @pytest.mark.sbmlmath
+ def test_replace_rate_law_mathml(self):
+ amr = _d(self.sir_amr)
+ transition_id = 'inf'
+ target_expression_xml_str = "Edelta"
+ target_expression_sympy = mathml_to_expression(target_expression_xml_str)
+
+ new_amr = replace_rate_law_mathml(amr, transition_id, target_expression_xml_str)
+
+ new_semantics_ode_rates = new_amr['semantics']['ode']['rates']
+
+ for new_rate in new_semantics_ode_rates:
+ if new_rate['target'] == transition_id:
+ self.assertEqual(sstr(target_expression_sympy), new_rate['expression'])
+ self.assertEqual(target_expression_xml_str, new_rate['expression_mathml'])
+
+ @pytest.mark.sbmlmath
+ # Following 2 unit tests only test for replacing expressions in observables, not initials
+ def test_replace_observable_expression_sympy(self):
+ object_id = 'noninf'
+ amr = _d(self.sir_amr)
+ target_expression_xml_str = "Ebeta"
+ target_expression_sympy = mathml_to_expression(target_expression_xml_str)
+ new_amr = replace_observable_expression_sympy(amr, object_id, target_expression_sympy)
+
+ for new_obs in new_amr['semantics']['ode']['observables']:
+ if new_obs['id'] == object_id:
+ self.assertEqual(sstr(target_expression_sympy), new_obs['expression'])
+ self.assertEqual(target_expression_xml_str, new_obs['expression_mathml'])
+
+ @pytest.mark.sbmlmath
+ def test_replace_observable_expression_mathml(self):
+ object_id = 'noninf'
+ amr = _d(self.sir_amr)
+ target_expression_xml_str = "Ebeta"
+ target_expression_sympy = mathml_to_expression(target_expression_xml_str)
+ new_amr = replace_observable_expression_mathml(amr, object_id, target_expression_xml_str)
+
+ for new_obs in new_amr['semantics']['ode']['observables']:
+ if new_obs['id'] == object_id:
+ self.assertEqual(sstr(target_expression_sympy), new_obs['expression'])
+ self.assertEqual(target_expression_xml_str, new_obs['expression_mathml'])
+
+ def test_stratify(self):
+ amr = _d(self.sir_amr)
+ new_amr = stratify(amr, key='city', strata=['boston', 'nyc'])
+
+ self.assertIsInstance(amr, dict)
+ self.assertIsInstance(new_amr, dict)
+
+ def test_simplify_rate_laws(self):
+ amr = _d(self.sir_amr)
+ new_amr = simplify_rate_laws(amr)
+
+ self.assertIsInstance(amr, dict)
+ self.assertIsInstance(new_amr, dict)
+
+ def test_aggregate_parameters(self):
+ amr = _d(self.sir_amr)
+ new_amr = aggregate_parameters(amr)
+
+ self.assertIsInstance(amr, dict)
+ self.assertIsInstance(new_amr, dict)
+
+ def test_counts_to_dimensionless(self):
+ amr = _d(self.sir_amr)
+ new_amr = counts_to_dimensionless(amr, 'ml', .8)
+ self.assertIsInstance(amr, dict)
+ self.assertIsInstance(new_amr, dict)
diff --git a/tests/test_ops.py b/tests/test_ops.py
index 23f58d290..d0d00c90a 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -5,6 +5,8 @@
from copy import deepcopy as _d
import sympy
+import requests
+import itertools
from mira.metamodel import *
from mira.metamodel.ops import stratify, simplify_rate_law, counts_to_dimensionless
@@ -231,10 +233,8 @@ def test_stratify_directed_simple(self):
def assert_unique_controllers(self, tm: TemplateModel):
"""Assert that controllers are unique."""
for template in tm.templates:
- if not isinstance(
- template,
- (GroupedControlledConversion, GroupedControlledProduction)
- ):
+ if not isinstance(template, (GroupedControlledConversion,
+ GroupedControlledProduction)):
continue
counter = Counter(
controller.get_key()
@@ -298,7 +298,7 @@ def _make_template(rate_law):
assert all(t.type == 'ControlledConversion' for t in templates)
# This one can be simplified too
- rate_law = (1 - _s('alpha')) * _s('S') * (_s('A') + _s('beta')*_s('B'))
+ rate_law = (1 - _s('alpha')) * _s('S') * (_s('A') + _s('beta') * _s('B'))
template = _make_template(rate_law)
templates = simplify_rate_law(template,
{'alpha': Parameter(name='alpha',
@@ -321,12 +321,12 @@ def test_counts_to_dimensionless():
for template in tm.templates:
for concept in template.get_concepts():
concept.units = Unit(expression=sympy.Symbol('person'))
- tm.initials['susceptible_population'].value = 1e5-1
+ tm.initials['susceptible_population'].value = 1e5 - 1
tm.initials['infected_population'].value = 1
tm.initials['immune_population'].value = 0
tm.parameters['beta'].units = \
- Unit(expression=1/(sympy.Symbol('person')*sympy.Symbol('day')))
+ Unit(expression=1 / (sympy.Symbol('person') * sympy.Symbol('day')))
old_beta = tm.parameters['beta'].value
for initial in tm.initials.values():
@@ -337,11 +337,11 @@ def test_counts_to_dimensionless():
for concept in template.get_concepts():
assert concept.units.expression.args[0].equals(1), concept.units
- assert tm.parameters['beta'].units.expression.args[0].equals(1/sympy.Symbol('day'))
- assert tm.parameters['beta'].value == old_beta*1e5
+ assert tm.parameters['beta'].units.expression.args[0].equals(1 / sympy.Symbol('day'))
+ assert tm.parameters['beta'].value == old_beta * 1e5
- assert tm.initials['susceptible_population'].value == (1e5-1)/1e5
- assert tm.initials['infected_population'].value == 1/1e5
+ assert tm.initials['susceptible_population'].value == (1e5 - 1) / 1e5
+ assert tm.initials['infected_population'].value == 1 / 1e5
assert tm.initials['immune_population'].value == 0
for initial in tm.initials.values():
@@ -354,7 +354,7 @@ def test_stratify_observable():
expr = sympy.Add(*[sympy.Symbol(s) for s in symbols])
tm.observables = {'half_population': Observable(
name='half_population',
- expression=SympyExprStr(expr/2))
+ expression=SympyExprStr(expr / 2))
}
tm = stratify(tm,
key='age',
diff --git a/tests/test_templatemodel_delta.py b/tests/test_templatemodel_delta.py
index 61fb6f238..abad2979e 100644
--- a/tests/test_templatemodel_delta.py
+++ b/tests/test_templatemodel_delta.py
@@ -84,10 +84,10 @@ def test_equal_no_context(self):
edge_count,
f"len(edges)={len(tmd.comparison_graph.edges)}",
)
- self.assert_(
+ self.assertTrue(
all("is_refinement" != d["label"] for _, _, d in tmd.comparison_graph.edges(data=True))
)
- self.assert_(
+ self.assertTrue(
all(
d["label"] in ["is_equal"] + concept_edge_labels
for _, _, d in tmd.comparison_graph.edges(data=True)
@@ -111,13 +111,13 @@ def test_equal_context(self):
edge_count,
f"len(edges)={len(tmd_context.comparison_graph.edges)}",
)
- self.assert_(
+ self.assertTrue(
all(
"is_refinement" != d["label"]
for _, _, d in tmd_context.comparison_graph.edges(data=True)
)
)
- self.assert_(
+ self.assertTrue(
all(
d["label"] in ["is_equal"] + concept_edge_labels for _, _,
d in tmd_context.comparison_graph.edges(data=True)
@@ -148,24 +148,24 @@ def test_refinement(self):
tmd_vs_boston = TemplateModelDelta(self.sir, self.sir_boston, is_ontological_child_web)
tmd_vs_nyc = TemplateModelDelta(self.sir, self.sir_nyc, is_ontological_child_web)
- self.assert_(
+ self.assertTrue(
all(
d["label"] in ["refinement_of"] + concept_edge_labels
for _, _, d in tmd_vs_boston.comparison_graph.edges(data=True)
)
)
- self.assert_(
+ self.assertTrue(
all(
"is_equal" != d["label"]
for _, _, d in tmd_vs_boston.comparison_graph.edges(data=True)
)
)
- self.assert_(
+ self.assertTrue(
all(
d["label"] in ["refinement_of"] + concept_edge_labels
for _, _, d in tmd_vs_nyc.comparison_graph.edges(data=True)
)
)
- self.assert_(
+ self.assertTrue(
all("is_equal" != d["label"] for _, _, d in tmd_vs_nyc.comparison_graph.edges(data=True))
)
diff --git a/tox.ini b/tox.ini
index 0eb59af35..36779833f 100644
--- a/tox.ini
+++ b/tox.ini
@@ -10,15 +10,16 @@ envlist =
py
[testenv]
-passenv = PYTHONPATH MIRA_REST_URL
+passenv = PYTHONPATH, MIRA_REST_URL
extras =
tests
web
+deps =
+ anyio<4
commands =
coverage run -p -m pytest --durations=20 {posargs:tests} -m "not slow and not sbmlmath"
; coverage combine
; coverage xml
-
[testenv:docs]
extras =
docs