diff --git a/docs/source/modeling.rst b/docs/source/modeling.rst index f98e39085..2b909ab1e 100644 --- a/docs/source/modeling.rst +++ b/docs/source/modeling.rst @@ -4,12 +4,18 @@ Modeling :members: :show-inheritance: -ASKEM AMR Petri net generation(:py:mod:`mira.modeling.askenet.petrinet`) ------------------------------------------------------------------------- +ASKEM AMR Petri net generation (:py:mod:`mira.modeling.askenet.petrinet`) +------------------------------------------------------------------------- .. automodule:: mira.modeling.askenet.petrinet :members: :show-inheritance: +ASKEM AMR operations (:py:mod:`mira.modeling.askenet.ops`) +---------------------------------------------------------- +.. automodule:: mira.modeling.askenet.ops + :members: + :show-inheritance: + ASKEM AMR Regulatory net generation (:py:mod:`mira.modeling.askenet.regnet`) ---------------------------------------------------------------------------- .. automodule:: mira.modeling.askenet.regnet diff --git a/mira/modeling/__init__.py b/mira/modeling/__init__.py index 6ffc28f99..6a479a5ab 100644 --- a/mira/modeling/__init__.py +++ b/mira/modeling/__init__.py @@ -92,10 +92,14 @@ def assemble_variable( if key in self.variables: return self.variables[key] - if initials and concept.name in initials: - initial_value = initials[concept.name].value - else: - initial_value = None + # We don't assume that the initial dict key is the same as the + # name of the given concept the initial applies to, so we check + # concept name match instead of key match. + initial_value = None + if initials: + for k, v in initials.items(): + if v.concept.name == concept.name: + initial_value = v.value data = { 'name': concept.name, @@ -141,8 +145,8 @@ def make_model(self): value = self.template_model.parameters[key].value distribution = self.template_model.parameters[key].distribution self.get_create_parameter( - ModelParameter(key, value, distribution, - placeholder=False)) + ModelParameter(key, value, distribution, + placeholder=False)) for template in self.template_model.templates: if isinstance(template, StaticConcept): @@ -245,4 +249,3 @@ def num_controllers(template): return len(template.controllers) else: return 0 - diff --git a/mira/modeling/askenet/ops.py b/mira/modeling/askenet/ops.py new file mode 100644 index 000000000..6ffb50898 --- /dev/null +++ b/mira/modeling/askenet/ops.py @@ -0,0 +1,272 @@ +import copy +import sympy +from mira.metamodel import SympyExprStr +import mira.metamodel.ops as tmops +from mira.sources.askenet.petrinet import template_model_from_askenet_json +from .petrinet import template_model_to_petrinet_json +from mira.metamodel.io import mathml_to_expression +from mira.metamodel.template_model import Parameter, Distribution, Observable +from mira.metamodel.templates import NaturalConversion, NaturalProduction, NaturalDegradation + + +def amr_to_mira(func): + def wrapper(amr, *args, **kwargs): + tm = template_model_from_askenet_json(amr) + result = func(tm, *args, **kwargs) + amr = template_model_to_petrinet_json(result) + return amr + + return wrapper + + +# Edit ID / label / name of State, Transition, Observable, Parameter, Initial +@amr_to_mira +def replace_state_id(tm, old_id, new_id): + """Replace the ID of a state.""" + concepts_name_map = tm.get_concepts_name_map() + if old_id not in concepts_name_map: + raise ValueError(f"State with ID {old_id} not found in model.") + for template in tm.templates: + for concept in template.get_concepts(): + if concept.name == old_id: + concept.name = new_id + template.rate_law = SympyExprStr( + template.rate_law.args[0].subs(sympy.Symbol(old_id), + sympy.Symbol(new_id))) + for observable in tm.observables.values(): + observable.expression = SympyExprStr( + observable.expression.args[0].subs(sympy.Symbol(old_id), + sympy.Symbol(new_id))) + for key, initial in copy.deepcopy(tm.initials).items(): + if initial.concept.name == old_id: + tm.initials[key].concept.name = new_id + # If the key is same as the old ID, we replace that too + if key == old_id: + tm.initials[new_id] = tm.initials.pop(old_id) + return tm + + +@amr_to_mira +def replace_transition_id(tm, old_id, new_id): + """Replace the ID of a transition.""" + for template in tm.templates: + if template.name == old_id: + template.name = new_id + return tm + + +@amr_to_mira +def replace_observable_id(tm, old_id, new_id, name=None): + """Replace the ID of an observable.""" + for obs, observable in copy.deepcopy(tm.observables).items(): + if obs == old_id: + observable.name = new_id + observable.display_name = name if name else observable.display_name + tm.observables[new_id] = observable + tm.observables.pop(old_id) + return tm + + +@amr_to_mira +def remove_observable(tm, removed_id): + for obs, observable in copy.deepcopy(tm.observables).items(): + if obs == removed_id: + tm.observables.pop(obs) + return tm + + +@amr_to_mira +def remove_parameter(tm, removed_id, replacement_value=None): + if replacement_value: + tm.substitute_parameter(removed_id, replacement_value) + else: + tm.eliminate_parameter(removed_id) + return tm + + +@amr_to_mira +def add_observable(tm, new_id, new_name, new_expression): + # Note that if an observable already exists with the given + # key, it will be replaced + rate_law_sympy = mathml_to_expression(new_expression) + new_observable = Observable(name=new_id, display_name=new_name, + expression=rate_law_sympy) + tm.observables[new_id] = new_observable + return tm + + +@amr_to_mira +def replace_parameter_id(tm, old_id, new_id): + """Replace the ID of a parameter.""" + if old_id not in tm.parameters: + raise ValueError(f"Parameter with ID {old_id} not found in model.") + for template in tm.templates: + if template.rate_law: + template.rate_law = SympyExprStr( + template.rate_law.args[0].subs(sympy.Symbol(old_id), + sympy.Symbol(new_id))) + for observable in tm.observables.values(): + observable.expression = SympyExprStr( + observable.expression.args[0].subs(sympy.Symbol(old_id), + sympy.Symbol(new_id))) + for key, param in copy.deepcopy(tm.parameters).items(): + if param.name == old_id: + popped_param = tm.parameters.pop(param.name) + popped_param.name = new_id + tm.parameters[new_id] = popped_param + return tm + + +# Resolve issue where only parameters are added only when they are present in rate laws. +@amr_to_mira +def add_parameter(tm, parameter_id: str, + value: float = None, + distribution=None, + units_mathml: str = None): + distribution = Distribution(**distribution) if distribution else None + if units_mathml: + units = { + 'expression': mathml_to_expression(units_mathml), + 'expression_mathml': units_mathml + } + else: + units = None + data = { + 'name': parameter_id, + 'value': value, + 'distribution': distribution, + 'units': units + } + + new_param = Parameter(**data) + tm.parameters[parameter_id] = new_param + + return tm + + +@amr_to_mira +def replace_initial_id(tm, old_id, new_id): + """Replace the ID of an initial.""" + tm.initials = { + (new_id if k == old_id else k): v for k, v in tm.initials.items() + } + return tm + + +# Remove state +@amr_to_mira +def remove_state(tm, state_id): + new_templates = [] + for template in tm.templates: + to_remove = False + for concept in template.get_concepts(): + if concept.name == state_id: + to_remove = True + if not to_remove: + new_templates.append(template) + tm.templates = new_templates + + for obs, observable in tm.observables.items(): + observable.expression = SympyExprStr( + observable.expression.args[0].subs(sympy.Symbol(state_id), 0)) + return tm + + +@amr_to_mira +def add_state(tm, state_id, grounding: None, units: None): + pass + + +# Remove transition +@amr_to_mira +def remove_transition(tm, transition_id): + tm.templates = [t for t in tm.templates if t.name != transition_id] + return tm + + +@amr_to_mira +def add_transition(tm, new_transition_id, src_id=None, tgt_id=None, rate_law_mathml=None): + # TODO: handle parameters added in the rate law as follows + # option 1 take in optional parameters dict if rate law contains parameters + # that aren't already present + # option 2, reverse engineer rate law and find parameters and states within + # the rate law and add to model + if src_id is None and tgt_id is None: + ValueError("You must pass in at least one of source and target id") + rate_law_sympy = SympyExprStr(mathml_to_expression(rate_law_mathml)) \ + if rate_law_mathml else None + if src_id is None and tgt_id: + template = NaturalProduction(name=new_transition_id, outcome=tgt_id, + rate_law=rate_law_sympy) + elif src_id and tgt_id is None: + template = NaturalDegradation(name=new_transition_id, subject=src_id, + rate_law=rate_law_sympy) + else: + template = NaturalConversion(name=new_transition_id, subject=src_id, + outcome=tgt_id, rate_law=rate_law_sympy) + tm.templates.append(template) + return tm + + +@amr_to_mira +# rate law is of type Sympy Expression +def replace_rate_law_sympy(tm, transition_id, new_rate_law: sympy.Expr): + # NOTE: this assumes that a sympy expression object is given + # though it might make sense to take a string instead + for template in tm.templates: + if template.name == transition_id: + template.rate_law = SympyExprStr(new_rate_law) + return tm + + +def replace_rate_law_mathml(tm, transition_id, new_rate_law): + new_rate_law_sympy = mathml_to_expression(new_rate_law) + return replace_rate_law_sympy(tm, transition_id, new_rate_law_sympy) + + +# currently initials don't support expressions so only implement the following 2 methods for observables +# if we are seeking to replace an expression in an initial, return current template model +@amr_to_mira +def replace_observable_expression_sympy(tm, obs_id, + new_expression_sympy: sympy.Expr): + for obs, observable in tm.observables.items(): + if obs == obs_id: + observable.expression = SympyExprStr(new_expression_sympy) + return tm + + +def replace_intial_expression_sympy(tm, initial_id, + new_expression_sympy: sympy.Expr): + # TODO: once initial expressions are supported, implement this + return tm + + +def replace_observable_expression_mathml(tm, obj_id, new_expression_mathml): + new_expression_sympy = mathml_to_expression(new_expression_mathml) + return replace_observable_expression_sympy(tm, obj_id, + new_expression_sympy) + + +def replace_intial_expression_mathml(tm, initial_id, new_expression_mathml): + # TODO: once initial expressions are supported, implement this + return tm + + +@amr_to_mira +def stratify(*args, **kwargs): + return tmops.stratify(*args, **kwargs) + + +@amr_to_mira +def simplify_rate_laws(*args, **kwargs): + return tmops.simplify_rate_laws(*args, **kwargs) + + +@amr_to_mira +def aggregate_parameters(*args, **kwargs): + return tmops.aggregate_parameters(*args, **kwargs) + + +@amr_to_mira +def counts_to_dimensionless(*args, **kwargs): + return tmops.counts_to_dimensionless(*args, **kwargs) diff --git a/mira/modeling/askenet/petrinet.py b/mira/modeling/askenet/petrinet.py index 784932058..5ce3327a4 100644 --- a/mira/modeling/askenet/petrinet.py +++ b/mira/modeling/askenet/petrinet.py @@ -2,7 +2,8 @@ at https://github.com/DARPA-ASKEM/Model-Representations/tree/main/petrinet. """ -__all__ = ["AskeNetPetriNetModel", "ModelSpecification"] +__all__ = ["AskeNetPetriNetModel", "ModelSpecification", + "template_model_to_petrinet_json"] import json @@ -12,7 +13,8 @@ from pydantic import BaseModel, Field -from mira.metamodel import expression_to_mathml, safe_parse_expr +from mira.metamodel import expression_to_mathml, safe_parse_expr, \ + TemplateModel from .. import Model from .utils import add_metadata_annotations @@ -103,9 +105,12 @@ def __init__(self, model: Model): self.initials.append(initial_data) for key, observable in model.observables.items(): + display_name = observable.observable.display_name \ + if observable.observable.display_name \ + else observable.observable.name obs_data = { 'id': observable.observable.name, - 'name': observable.observable.name, + 'name': display_name, 'expression': str(observable.observable.expression), 'expression_mathml': expression_to_mathml( observable.observable.expression.args[0]), @@ -273,6 +278,21 @@ def to_json_file(self, fname, name=None, description=None, json.dump(js, fh, indent=indent, **kwargs) +def template_model_to_petrinet_json(tm: TemplateModel): + """Convert a template model to a PetriNet JSON dict. + + Parameters + ---------- + tm : + The template model to convert. + + Returns + ------- + A JSON dict representing the PetriNet model. + """ + return AskeNetPetriNetModel(Model(tm)).to_json() + + class Initial(BaseModel): target: str expression: str diff --git a/mira/sources/askenet/petrinet.py b/mira/sources/askenet/petrinet.py index 2699e75d2..5ca309f97 100644 --- a/mira/sources/askenet/petrinet.py +++ b/mira/sources/askenet/petrinet.py @@ -145,7 +145,8 @@ def template_model_from_askenet_json(model_json) -> TemplateModel: continue observable = Observable(name=observable['id'], - expression=observable_expr) + expression=observable_expr, + display_name=observable.get('name')) observables[observable.name] = observable # We get the time variable from the semantics diff --git a/notebooks/applications/Enzyme_substrate_kinetics.ipynb b/notebooks/applications/Enzyme_substrate_kinetics.ipynb index 2d7ed8a25..cd837e90f 100644 --- a/notebooks/applications/Enzyme_substrate_kinetics.ipynb +++ b/notebooks/applications/Enzyme_substrate_kinetics.ipynb @@ -3,7 +3,6 @@ { "cell_type": "code", "execution_count": null, - "id": "475993a6-38f9-4627-9cf6-808d351a9faf", "metadata": {}, "outputs": [], "source": [ @@ -23,7 +22,6 @@ { "cell_type": "code", "execution_count": 3, - "id": "015c410c-0f31-448e-ba17-bda9729e1fc5", "metadata": {}, "outputs": [ { @@ -47,7 +45,6 @@ { "cell_type": "code", "execution_count": 148, - "id": "d45757d7-1fb9-4d0f-ae90-d109e3c7a893", "metadata": {}, "outputs": [], "source": [ @@ -57,7 +54,6 @@ { "cell_type": "code", "execution_count": 151, - "id": "4df86721-da78-4426-8313-ff9e1dc2115c", "metadata": {}, "outputs": [], "source": [ @@ -67,7 +63,6 @@ { "cell_type": "code", "execution_count": 44, - "id": "ce3fd838-2a3c-4f65-a4b6-c9fc6064a803", "metadata": {}, "outputs": [], "source": [ @@ -77,7 +72,6 @@ { "cell_type": "code", "execution_count": 153, - "id": "b2bdf38d-afee-48a9-bcbe-315235142807", "metadata": {}, "outputs": [], "source": [ @@ -89,7 +83,6 @@ { "cell_type": "code", "execution_count": 154, - "id": "a6b5b9fd-8ac7-42cb-bd3c-c09cc2e96795", "metadata": {}, "outputs": [], "source": [ @@ -103,7 +96,6 @@ { "cell_type": "code", "execution_count": 157, - "id": "b10bfcb2-ef77-42b3-9249-3ded6309596a", "metadata": {}, "outputs": [], "source": [ @@ -113,7 +105,6 @@ { "cell_type": "code", "execution_count": 158, - "id": "db584dc2-0700-499e-ac07-3e3c1bb29300", "metadata": {}, "outputs": [ { @@ -137,7 +128,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -151,7 +142,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.7.3" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 09aa4efdb..a9107d5e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,5 +15,6 @@ reverse_relative = true [tool.pytest.ini_options] markers = [ - "slow: marks tests as slow for pytest (deselect with '-m \"not slow\"')" + "slow: marks tests as slow for pytest (deselect with '-m \"not slow\"')", + "sbmlmath: marks tests that import sbmlmath pytest (deselect with '-m \"not sbmlmath\"')" ] diff --git a/tests/test_modeling/test_askenet_ops.py b/tests/test_modeling/test_askenet_ops.py new file mode 100644 index 000000000..a844c6e00 --- /dev/null +++ b/tests/test_modeling/test_askenet_ops.py @@ -0,0 +1,490 @@ +import unittest +import requests +import pytest +from copy import deepcopy as _d +from mira.modeling.askenet.ops import * +from sympy import * +from mira.metamodel.io import mathml_to_expression +from mira.metamodel.templates import Concept + + +class TestAskenetOperations(unittest.TestCase): + """A test case for operations on template models.""" + + @classmethod + def setUpClass(cls): + cls.sir_amr = requests.get( + 'https://raw.githubusercontent.com/DARPA-ASKEM/' + 'Model-Representations/main/petrinet/examples/sir.json').json() + + '''These unit tests are conducted by zipping through lists of each key in a amr file + (e.g. parameters, observables, etc). Zipping in this manner assumes that the order (assuming insertion order) + is preserved before and after mira operation for an amr key + ''' + + def test_replace_state_id(self): + old_id = 'S' + new_id = 'X' + amr = _d(self.sir_amr) + new_amr = replace_state_id(amr, old_id, new_id) + + old_model = amr['model'] + new_model = new_amr['model'] + + old_model_states = old_model['states'] + new_model_states = new_model['states'] + + old_model_transitions = old_model['transitions'] + new_model_transitions = new_model['transitions'] + + self.assertEqual(len(old_model_states), len(new_model_states)) + for old_state, new_state in zip(old_model_states, new_model_states): + + # output states missing description field + if old_state['id'] == old_id: + self.assertEqual(new_state['id'], new_id) + self.assertEqual(old_state['name'], new_state['name']) + self.assertEqual(old_state['grounding']['identifiers'], new_state['grounding']['identifiers']) + self.assertEqual(old_state['units'], new_state['units']) + + self.assertEqual(len(old_model_transitions), len(new_model_transitions)) + + # output transitions are missing a description and transition['properties']['name'] field + # is abbreviated in output amr + for old_transition, new_transition in zip(old_model_transitions, new_model_transitions): + if old_id in old_transition['input'] or old_id in old_transition['output']: + self.assertIn(new_id, new_transition['input']) + self.assertNotIn(old_id, new_transition['output']) + self.assertEqual(len(old_transition['input']), len(new_transition['input'])) + self.assertEqual(len(old_transition['output']), len(new_transition['output'])) + self.assertEqual(old_transition['id'], new_transition['id']) + + old_semantics_ode = amr['semantics']['ode'] + new_semantics_ode = new_amr['semantics']['ode'] + + old_semantics_ode_rates = old_semantics_ode['rates'] + new_semantics_ode_rates = new_semantics_ode['rates'] + + # this test doesn't account for if the expression semantic is preserved (e.g. same type of operations) + # would pass test if we call replace_state_id(I,J) and old expression is "I*X" and new expression is "J+X" + for old_rate, new_rate in zip(old_semantics_ode_rates, new_semantics_ode_rates): + if old_id in old_rate['expression'] or old_id in old_rate['expression_mathml']: + self.assertIn(new_id, new_rate['expression']) + self.assertNotIn(old_id, new_rate['expression']) + + self.assertIn(new_id, new_rate['expression_mathml']) + self.assertNotIn(old_id, new_rate['expression_mathml']) + + self.assertEqual(old_rate['target'], new_rate['target']) + + # initials have float values substituted in for state ids in their expression and expression_mathml field + old_semantics_ode_initials = old_semantics_ode['initials'] + new_semantics_ode_initials = new_semantics_ode['initials'] + + self.assertEqual(len(old_semantics_ode_initials), len(new_semantics_ode_initials)) + for old_initials, new_initials in zip(old_semantics_ode_initials, new_semantics_ode_initials): + if old_id == old_initials['target']: + self.assertEqual(new_initials['target'], new_id) + + old_semantics_ode_parameters = old_semantics_ode['parameters'] + new_semantics_ode_parameters = new_semantics_ode['parameters'] + # This is due to initial expressions vs values + assert len(old_semantics_ode_parameters) == 5 + assert len(new_semantics_ode_parameters) == 2 + + # zip method iterates over length of the smaller iterable len(new_semantics_ode_parameters) = 2 + # as opposed to len(old_semantics_ode_parameters) = 5 , non-state parameters are listed first in input amr + for old_params, new_params in zip(old_semantics_ode_parameters, new_semantics_ode_parameters): + # test to see if old_id/new_id in name/id field and not for id/name equality because these fields + # may contain subscripts or timestamps appended to the old_id/new_id + if old_id in old_params['id'] and old_id in old_params['name']: + self.assertIn(new_id, new_params['id']) + self.assertIn(new_id, new_params['name']) + + old_semantics_ode_observables = old_semantics_ode['observables'] + new_semantics_ode_observables = new_semantics_ode['observables'] + self.assertEqual(len(old_semantics_ode_observables), len(new_semantics_ode_observables)) + + for old_observable, new_observable in zip(old_semantics_ode_observables, new_semantics_ode_observables): + if old_id in old_observable['states'] and old_id in old_observable['expression'] and \ + old_id in old_observable['expression_mathml']: + self.assertIn(new_id, new_observable['expression']) + self.assertNotIn(old_id, new_observable['expression']) + + self.assertIn(new_id, new_observable['expression_mathml']) + self.assertNotIn(old_id, new_observable['expression_mathml']) + + self.assertEqual(old_observable['id'], new_observable['id']) + + def test_replace_transition_id(self): + + old_id = 'inf' + new_id = 'new_inf' + amr = _d(self.sir_amr) + new_amr = replace_transition_id(amr, old_id, new_id) + + old_model_transitions = amr['model']['transitions'] + new_model_transitions = new_amr['model']['transitions'] + + self.assertEqual(len(old_model_transitions), len(new_model_transitions)) + + for old_transitions, new_transition in zip(old_model_transitions, new_model_transitions): + if old_transitions['id'] == old_id: + self.assertEqual(new_transition['id'], new_id) + + def test_replace_observable_id(self): + old_id = 'noninf' + new_id = 'testinf' + new_display_name = 'test-infection' + amr = _d(self.sir_amr) + new_amr = replace_observable_id(amr, old_id, new_id, + name=new_display_name) + + old_semantics_observables = amr['semantics']['ode']['observables'] + new_semantics_observables = new_amr['semantics']['ode']['observables'] + + self.assertEqual(len(old_semantics_observables), len(new_semantics_observables)) + + for old_observable, new_observable in zip(old_semantics_observables, new_semantics_observables): + if old_observable['id'] == old_id: + self.assertEqual(new_observable['id'], new_id) + self.assertEqual(new_observable['name'], new_display_name) + + def test_remove_observable_or_parameter(self): + + old_amr_obs = _d(self.sir_amr) + old_amr_param = _d(self.sir_amr) + + replaced_observable_id = 'noninf' + new_amr_obs = remove_observable(old_amr_obs, replaced_observable_id) + for new_observable in new_amr_obs['semantics']['ode']['observables']: + self.assertNotEqual(new_observable['id'], replaced_observable_id) + + replaced_param_id = 'beta' + replacement_value = 5 + new_amr_param = remove_parameter(old_amr_param, replaced_param_id, replacement_value) + for new_param in new_amr_param['semantics']['ode']['parameters']: + self.assertNotEqual(new_param['id'], replaced_param_id) + for old_rate, new_rate in zip(old_amr_param['semantics']['ode']['rates'], + new_amr_param['semantics']['ode']['rates']): + if replaced_param_id in old_rate['expression'] and replaced_param_id in old_rate['expression_mathml']: + self.assertNotIn(replaced_param_id, new_rate['expression']) + self.assertIn(str(replacement_value), new_rate['expression']) + + self.assertNotIn(replaced_param_id, new_rate['expression_mathml']) + self.assertIn(str(replacement_value), new_rate['expression_mathml']) + + self.assertEqual(old_rate['target'], new_rate['target']) + + # currently don't support expressions for initials + for old_obs, new_obs in zip(old_amr_param['semantics']['ode']['observables'], + new_amr_param['semantics']['ode']['observables']): + if replaced_param_id in old_obs['expression'] and replaced_param_id in old_obs['expression_mathml']: + self.assertNotIn(replaced_param_id, new_obs['expression']) + self.assertIn(str(replacement_value), new_obs['expression']) + + self.assertNotIn(replaced_param_id, new_obs['expression_mathml']) + self.assertIn(str(replacement_value), new_obs['expression_mathml']) + + self.assertEqual(old_obs['id'], new_obs['id']) + + @pytest.mark.sbmlmath + def test_add_observable(self): + amr = _d(self.sir_amr) + new_id = 'testinf' + new_display_name = 'DISPLAY_TEST' + xml_expression = "Edelta" + new_amr = add_observable(amr, new_id, new_display_name, xml_expression) + + # Create a dict out of a list of observable dict entries to easier test for addition of new observables + new_observable_dict = {} + for observable in new_amr['semantics']['ode']['observables']: + name = observable.pop('id') + new_observable_dict[name] = observable + + self.assertIn(new_id, new_observable_dict) + self.assertEqual(new_display_name, new_observable_dict[new_id]['name']) + self.assertEqual(xml_expression, new_observable_dict[new_id]['expression_mathml']) + self.assertEqual(sstr(mathml_to_expression(xml_expression)), new_observable_dict[new_id]['expression']) + + @pytest.mark.sbmlmath + def test_replace_parameter_id(self): + old_id = 'beta' + new_id = 'TEST' + amr = _d(self.sir_amr) + new_amr = replace_parameter_id(amr, old_id, new_id) + + old_semantics_ode_rates = amr['semantics']['ode']['rates'] + new_semantics_ode_rates = new_amr['semantics']['ode']['rates'] + + old_semantics_ode_observables = amr['semantics']['ode']['observables'] + new_semantics_ode_observables = new_amr['semantics']['ode']['observables'] + + old_semantics_ode_parameters = amr['semantics']['ode']['parameters'] + new_semantics_ode_parameters = new_amr['semantics']['ode']['parameters'] + + new_model_states = new_amr['model']['states'] + + self.assertEqual(len(old_semantics_ode_rates), len(new_semantics_ode_rates)) + self.assertEqual(len(old_semantics_ode_observables), len(new_semantics_ode_observables)) + + self.assertEqual(len(old_semantics_ode_parameters) - len(new_model_states), len(new_semantics_ode_parameters)) + + for old_rate, new_rate in zip(old_semantics_ode_rates, new_semantics_ode_rates): + if old_id in old_rate['expression'] and old_id in old_rate['expression_mathml']: + self.assertIn(new_id, new_rate['expression']) + self.assertNotIn(old_id, new_rate['expression']) + + self.assertIn(new_id, new_rate['expression_mathml']) + self.assertNotIn(old_id, new_rate['expression_mathml']) + + # don't test states field for a parameter as it is assumed that replace_parameter_id will only be used with + # parameters such as gamma or beta (i.e. non-states) + for old_observable, new_observable in zip(old_semantics_ode_observables, new_semantics_ode_observables): + if old_id in old_observable['expression'] and old_id in new_observable['expression_mathml']: + self.assertIn(new_id, new_observable['expression']) + self.assertNotIn(old_id, new_observable['expression']) + + self.assertIn(new_id, new_observable['expression_mathml']) + self.assertNotIn(old_id, new_observable['expression_mathml']) + + for old_parameter, new_parameter in zip(old_semantics_ode_parameters, new_semantics_ode_parameters): + if old_parameter['id'] == old_id: + self.assertEqual(new_parameter['id'], new_id) + self.assertEqual(old_parameter['value'], new_parameter['value']) + self.assertEqual(old_parameter['distribution'], new_parameter['distribution']) + self.assertEqual(sstr(old_parameter['units']['expression']), new_parameter['units']['expression']) + self.assertEqual(mathml_to_expression(old_parameter['units']['expression_mathml']), + mathml_to_expression(new_parameter['units']['expression_mathml'])) + + def test_remove_state(self): + removed_state_id = 'S' + amr = _d(self.sir_amr) + + new_amr = remove_state(amr, removed_state_id) + + new_model = new_amr['model'] + new_model_states = new_model['states'] + new_model_transitions = new_model['transitions'] + + new_semantics_ode = new_amr['semantics']['ode'] + new_semantics_ode_rates = new_semantics_ode['rates'] + new_semantics_ode_initials = new_semantics_ode['initials'] + new_semantics_ode_parameters = new_semantics_ode['parameters'] + new_semantics_ode_observables = new_semantics_ode['observables'] + + for new_state in new_model_states: + self.assertNotEquals(removed_state_id, new_state['id']) + + for new_transition in new_model_transitions: + self.assertNotIn(removed_state_id, new_transition['input']) + self.assertNotIn(removed_state_id, new_transition['output']) + + # output rates that originally contained targeted state are removed + for new_rate in new_semantics_ode_rates: + self.assertNotIn(removed_state_id, new_rate['expression']) + self.assertNotIn(removed_state_id, new_rate['expression_mathml']) + + # initials are bugged, all states removed rather than just targeted removed state in output amr + for new_initial in new_semantics_ode_initials: + self.assertNotEquals(removed_state_id, new_initial['target']) + + # parameters that are associated in an expression with a removed state are not present in output amr + # (e.g.) if there exists an expression: "S*I*beta" and we remove S, then beta is no longer present in output + # list of parameters + for new_parameter in new_semantics_ode_parameters: + self.assertNotIn(removed_state_id, new_parameter['id']) + + # output observable expressions that originally contained targeted state still exist with targeted state removed + # (e.g. 'S+R' -> 'R') if 'S' is the removed state + for new_observable in new_semantics_ode_observables: + self.assertNotIn(removed_state_id, new_observable['expression']) + self.assertNotIn(removed_state_id, new_observable['expression_mathml']) + + def test_remove_transition(self): + removed_transition = 'inf' + amr = _d(self.sir_amr) + + new_amr = remove_transition(amr, removed_transition) + new_model_transition = new_amr['model']['transitions'] + + for new_transition in new_model_transition: + self.assertNotEquals(removed_transition, new_transition['id']) + + @pytest.mark.sbmlmath + def test_add_transition(self): + test_subject = Concept(name="test_subject", identifiers={"ido": "0000511"}) + test_outcome = Concept(name="test_outcome", identifiers={"ido": "0000592"}) + expression_xml = "Edelta" + new_transition_id = 'test' + old_natural_conversion_amr = _d(self.sir_amr) + old_natural_production_amr = _d(self.sir_amr) + old_natural_degradation_amr = _d(self.sir_amr) + + # NaturalConversion + new_natural_conversion_amr = add_transition(old_natural_conversion_amr, new_transition_id, + rate_law_mathml=expression_xml, + src_id=test_subject, + tgt_id=test_outcome) + natural_conversion_transition_dict = {} + natural_conversion_rates_dict = {} + natural_conversion_state_dict = {} + + for transition in new_natural_conversion_amr['model']['transitions']: + name = transition.pop('id') + natural_conversion_transition_dict[name] = transition + + for rate in new_natural_conversion_amr['semantics']['ode']['rates']: + name = rate.pop('target') + natural_conversion_rates_dict[name] = rate + + for state in new_natural_conversion_amr['model']['states']: + name = state.pop('id') + natural_conversion_state_dict[name] = state + + self.assertIn(new_transition_id, natural_conversion_transition_dict) + self.assertIn(new_transition_id, natural_conversion_rates_dict) + self.assertEqual(expression_xml, natural_conversion_rates_dict[new_transition_id]['expression_mathml']) + self.assertEqual(sstr(mathml_to_expression(expression_xml)), + natural_conversion_rates_dict[new_transition_id]['expression']) + self.assertIn(test_subject.name, natural_conversion_state_dict) + self.assertIn(test_outcome.name, natural_conversion_state_dict) + + # NaturalProduction + new_natural_production_amr = add_transition(old_natural_production_amr, new_transition_id, + rate_law_mathml=expression_xml, + tgt_id=test_outcome) + natural_production_transition_dict = {} + natural_production_rates_dict = {} + natural_production_state_dict = {} + + for transition in new_natural_production_amr['model']['transitions']: + name = transition.pop('id') + natural_production_transition_dict[name] = transition + + for rate in new_natural_production_amr['semantics']['ode']['rates']: + name = rate.pop('target') + natural_production_rates_dict[name] = rate + + for state in new_natural_production_amr['model']['states']: + name = state.pop('id') + natural_production_state_dict[name] = state + + self.assertIn(new_transition_id, natural_production_transition_dict) + self.assertIn(new_transition_id, natural_production_rates_dict) + self.assertEqual(expression_xml, natural_production_rates_dict[new_transition_id]['expression_mathml']) + self.assertEqual(sstr(mathml_to_expression(expression_xml)), + natural_production_rates_dict[new_transition_id]['expression']) + self.assertIn(test_outcome.name, natural_production_state_dict) + + # NaturalDegradation + new_natural_degradation_amr = add_transition(old_natural_degradation_amr, new_transition_id, + rate_law_mathml=expression_xml, + src_id=test_subject) + natural_degradation_transition_dict = {} + natural_degradation_rates_dict = {} + natural_degradation_states_dict = {} + + for transition in new_natural_degradation_amr['model']['transitions']: + name = transition.pop('id') + natural_degradation_transition_dict[name] = transition + + for rate in new_natural_degradation_amr['semantics']['ode']['rates']: + name = rate.pop('target') + natural_degradation_rates_dict[name] = rate + + for state in new_natural_degradation_amr['model']['states']: + name = state.pop('id') + natural_degradation_states_dict[name] = state + + self.assertIn(new_transition_id, natural_degradation_transition_dict) + self.assertIn(new_transition_id, natural_degradation_rates_dict) + self.assertEqual(expression_xml, natural_degradation_rates_dict[new_transition_id]['expression_mathml']) + self.assertEqual(sstr(mathml_to_expression(expression_xml)), + natural_degradation_rates_dict[new_transition_id]['expression']) + self.assertIn(test_subject.name, natural_degradation_states_dict) + + @pytest.mark.sbmlmath + def test_replace_rate_law_sympy(self): + transition_id = 'inf' + target_expression_xml_str = 'X8' + target_expression_sympy = mathml_to_expression(target_expression_xml_str) + + amr = _d(self.sir_amr) + new_amr = replace_rate_law_sympy(amr, transition_id, target_expression_sympy) + new_semantics_ode_rates = new_amr['semantics']['ode']['rates'] + + for new_rate in new_semantics_ode_rates: + if new_rate['target'] == transition_id: + self.assertEqual(sstr(target_expression_sympy), new_rate['expression']) + self.assertEqual(target_expression_xml_str, new_rate['expression_mathml']) + + @pytest.mark.sbmlmath + def test_replace_rate_law_mathml(self): + amr = _d(self.sir_amr) + transition_id = 'inf' + target_expression_xml_str = "Edelta" + target_expression_sympy = mathml_to_expression(target_expression_xml_str) + + new_amr = replace_rate_law_mathml(amr, transition_id, target_expression_xml_str) + + new_semantics_ode_rates = new_amr['semantics']['ode']['rates'] + + for new_rate in new_semantics_ode_rates: + if new_rate['target'] == transition_id: + self.assertEqual(sstr(target_expression_sympy), new_rate['expression']) + self.assertEqual(target_expression_xml_str, new_rate['expression_mathml']) + + @pytest.mark.sbmlmath + # Following 2 unit tests only test for replacing expressions in observables, not initials + def test_replace_observable_expression_sympy(self): + object_id = 'noninf' + amr = _d(self.sir_amr) + target_expression_xml_str = "Ebeta" + target_expression_sympy = mathml_to_expression(target_expression_xml_str) + new_amr = replace_observable_expression_sympy(amr, object_id, target_expression_sympy) + + for new_obs in new_amr['semantics']['ode']['observables']: + if new_obs['id'] == object_id: + self.assertEqual(sstr(target_expression_sympy), new_obs['expression']) + self.assertEqual(target_expression_xml_str, new_obs['expression_mathml']) + + @pytest.mark.sbmlmath + def test_replace_observable_expression_mathml(self): + object_id = 'noninf' + amr = _d(self.sir_amr) + target_expression_xml_str = "Ebeta" + target_expression_sympy = mathml_to_expression(target_expression_xml_str) + new_amr = replace_observable_expression_mathml(amr, object_id, target_expression_xml_str) + + for new_obs in new_amr['semantics']['ode']['observables']: + if new_obs['id'] == object_id: + self.assertEqual(sstr(target_expression_sympy), new_obs['expression']) + self.assertEqual(target_expression_xml_str, new_obs['expression_mathml']) + + def test_stratify(self): + amr = _d(self.sir_amr) + new_amr = stratify(amr, key='city', strata=['boston', 'nyc']) + + self.assertIsInstance(amr, dict) + self.assertIsInstance(new_amr, dict) + + def test_simplify_rate_laws(self): + amr = _d(self.sir_amr) + new_amr = simplify_rate_laws(amr) + + self.assertIsInstance(amr, dict) + self.assertIsInstance(new_amr, dict) + + def test_aggregate_parameters(self): + amr = _d(self.sir_amr) + new_amr = aggregate_parameters(amr) + + self.assertIsInstance(amr, dict) + self.assertIsInstance(new_amr, dict) + + def test_counts_to_dimensionless(self): + amr = _d(self.sir_amr) + new_amr = counts_to_dimensionless(amr, 'ml', .8) + self.assertIsInstance(amr, dict) + self.assertIsInstance(new_amr, dict) diff --git a/tests/test_ops.py b/tests/test_ops.py index 23f58d290..d0d00c90a 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -5,6 +5,8 @@ from copy import deepcopy as _d import sympy +import requests +import itertools from mira.metamodel import * from mira.metamodel.ops import stratify, simplify_rate_law, counts_to_dimensionless @@ -231,10 +233,8 @@ def test_stratify_directed_simple(self): def assert_unique_controllers(self, tm: TemplateModel): """Assert that controllers are unique.""" for template in tm.templates: - if not isinstance( - template, - (GroupedControlledConversion, GroupedControlledProduction) - ): + if not isinstance(template, (GroupedControlledConversion, + GroupedControlledProduction)): continue counter = Counter( controller.get_key() @@ -298,7 +298,7 @@ def _make_template(rate_law): assert all(t.type == 'ControlledConversion' for t in templates) # This one can be simplified too - rate_law = (1 - _s('alpha')) * _s('S') * (_s('A') + _s('beta')*_s('B')) + rate_law = (1 - _s('alpha')) * _s('S') * (_s('A') + _s('beta') * _s('B')) template = _make_template(rate_law) templates = simplify_rate_law(template, {'alpha': Parameter(name='alpha', @@ -321,12 +321,12 @@ def test_counts_to_dimensionless(): for template in tm.templates: for concept in template.get_concepts(): concept.units = Unit(expression=sympy.Symbol('person')) - tm.initials['susceptible_population'].value = 1e5-1 + tm.initials['susceptible_population'].value = 1e5 - 1 tm.initials['infected_population'].value = 1 tm.initials['immune_population'].value = 0 tm.parameters['beta'].units = \ - Unit(expression=1/(sympy.Symbol('person')*sympy.Symbol('day'))) + Unit(expression=1 / (sympy.Symbol('person') * sympy.Symbol('day'))) old_beta = tm.parameters['beta'].value for initial in tm.initials.values(): @@ -337,11 +337,11 @@ def test_counts_to_dimensionless(): for concept in template.get_concepts(): assert concept.units.expression.args[0].equals(1), concept.units - assert tm.parameters['beta'].units.expression.args[0].equals(1/sympy.Symbol('day')) - assert tm.parameters['beta'].value == old_beta*1e5 + assert tm.parameters['beta'].units.expression.args[0].equals(1 / sympy.Symbol('day')) + assert tm.parameters['beta'].value == old_beta * 1e5 - assert tm.initials['susceptible_population'].value == (1e5-1)/1e5 - assert tm.initials['infected_population'].value == 1/1e5 + assert tm.initials['susceptible_population'].value == (1e5 - 1) / 1e5 + assert tm.initials['infected_population'].value == 1 / 1e5 assert tm.initials['immune_population'].value == 0 for initial in tm.initials.values(): @@ -354,7 +354,7 @@ def test_stratify_observable(): expr = sympy.Add(*[sympy.Symbol(s) for s in symbols]) tm.observables = {'half_population': Observable( name='half_population', - expression=SympyExprStr(expr/2)) + expression=SympyExprStr(expr / 2)) } tm = stratify(tm, key='age', diff --git a/tests/test_templatemodel_delta.py b/tests/test_templatemodel_delta.py index 61fb6f238..abad2979e 100644 --- a/tests/test_templatemodel_delta.py +++ b/tests/test_templatemodel_delta.py @@ -84,10 +84,10 @@ def test_equal_no_context(self): edge_count, f"len(edges)={len(tmd.comparison_graph.edges)}", ) - self.assert_( + self.assertTrue( all("is_refinement" != d["label"] for _, _, d in tmd.comparison_graph.edges(data=True)) ) - self.assert_( + self.assertTrue( all( d["label"] in ["is_equal"] + concept_edge_labels for _, _, d in tmd.comparison_graph.edges(data=True) @@ -111,13 +111,13 @@ def test_equal_context(self): edge_count, f"len(edges)={len(tmd_context.comparison_graph.edges)}", ) - self.assert_( + self.assertTrue( all( "is_refinement" != d["label"] for _, _, d in tmd_context.comparison_graph.edges(data=True) ) ) - self.assert_( + self.assertTrue( all( d["label"] in ["is_equal"] + concept_edge_labels for _, _, d in tmd_context.comparison_graph.edges(data=True) @@ -148,24 +148,24 @@ def test_refinement(self): tmd_vs_boston = TemplateModelDelta(self.sir, self.sir_boston, is_ontological_child_web) tmd_vs_nyc = TemplateModelDelta(self.sir, self.sir_nyc, is_ontological_child_web) - self.assert_( + self.assertTrue( all( d["label"] in ["refinement_of"] + concept_edge_labels for _, _, d in tmd_vs_boston.comparison_graph.edges(data=True) ) ) - self.assert_( + self.assertTrue( all( "is_equal" != d["label"] for _, _, d in tmd_vs_boston.comparison_graph.edges(data=True) ) ) - self.assert_( + self.assertTrue( all( d["label"] in ["refinement_of"] + concept_edge_labels for _, _, d in tmd_vs_nyc.comparison_graph.edges(data=True) ) ) - self.assert_( + self.assertTrue( all("is_equal" != d["label"] for _, _, d in tmd_vs_nyc.comparison_graph.edges(data=True)) ) diff --git a/tox.ini b/tox.ini index 0eb59af35..36779833f 100644 --- a/tox.ini +++ b/tox.ini @@ -10,15 +10,16 @@ envlist = py [testenv] -passenv = PYTHONPATH MIRA_REST_URL +passenv = PYTHONPATH, MIRA_REST_URL extras = tests web +deps = + anyio<4 commands = coverage run -p -m pytest --durations=20 {posargs:tests} -m "not slow and not sbmlmath" ; coverage combine ; coverage xml - [testenv:docs] extras = docs