Skip to content

Commit

Permalink
Added support for adding parameters not apart of rates laws to output…
Browse files Browse the repository at this point in the history
… amr and updated unit tests
  • Loading branch information
nanglo123 authored and bgyori committed Sep 15, 2023
1 parent dd2f4d6 commit 53d4410
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 39 deletions.
13 changes: 11 additions & 2 deletions mira/modeling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class Transition:
def __init__(
self, key, consumed, produced, control, rate, template_type, template: Template,
self, key, consumed, produced, control, rate, template_type, template: Template,
):
self.key = key
self.consumed = consumed
Expand Down Expand Up @@ -65,7 +65,7 @@ def __init__(self, template_model):
self.make_model()

def assemble_variable(
self, concept: Concept, initials: Optional[Mapping[str, Initial]] = None,
self, concept: Concept, initials: Optional[Mapping[str, Initial]] = None,
):
"""Assemble a variable from a concept and optional
dictionary of initial values.
Expand Down Expand Up @@ -204,6 +204,15 @@ def make_model(self):
template=template,
))

for key, parameter in self.template_model.parameters.items():
if key not in self.parameters:
value = self.template_model.parameters[key].value
distribution = self.template_model.parameters[key].distribution
self.get_create_parameter(
ModelParameter(key, value, distribution,
placeholder=False)
)

def get_create_parameter(self, parameter: ModelParameter) -> ModelParameter:
if parameter.key not in self.parameters:
self.parameters[parameter.key] = parameter
Expand Down
2 changes: 1 addition & 1 deletion mira/modeling/askenet/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def replace_parameter_id(tm, old_id, new_id):
return tm


# Resolve issue where only parameters are added only when they are present in rate laws.
@amr_to_mira
def add_parameter(tm, parameter_id: str,
name: str = None,
Expand Down Expand Up @@ -255,6 +254,7 @@ def add_transition(tm, new_transition_id, src_id=None, tgt_id=None,
else:
template = NaturalConversion(name=new_transition_id, subject=src_id,
outcome=tgt_id, rate_law=rate_law_sympy)

if params_dict:
# add parameters to template model
for free_symbol_sympy in template.rate_law.free_symbols:
Expand Down
65 changes: 29 additions & 36 deletions tests/test_modeling/test_askenet_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from copy import deepcopy as _d
from mira.modeling.askenet.ops import *
from sympy import *
from mira.metamodel.io import mathml_to_expression, expression_to_mathml
from mira.metamodel.io import mathml_to_expression
from mira.metamodel.templates import Concept

try:
Expand Down Expand Up @@ -94,21 +94,6 @@ def test_replace_state_id(self):
if old_id == old_initials['target']:
self.assertEqual(new_initials['target'], new_id)

old_semantics_ode_parameters = old_semantics_ode['parameters']
new_semantics_ode_parameters = new_semantics_ode['parameters']
# This is due to initial expressions vs values
assert len(old_semantics_ode_parameters) == 5
assert len(new_semantics_ode_parameters) == 2

# zip method iterates over length of the smaller iterable len(new_semantics_ode_parameters) = 2
# as opposed to len(old_semantics_ode_parameters) = 5 , non-state parameters are listed first in input amr
for old_params, new_params in zip(old_semantics_ode_parameters, new_semantics_ode_parameters):
# test to see if old_id/new_id in name/id field and not for id/name equality because these fields
# may contain subscripts or timestamps appended to the old_id/new_id
if old_id in old_params['id'] and old_id in old_params['name']:
self.assertIn(new_id, new_params['id'])
self.assertIn(new_id, new_params['name'])

old_semantics_ode_observables = old_semantics_ode['observables']
new_semantics_ode_observables = new_semantics_ode['observables']
self.assertEqual(len(old_semantics_ode_observables), len(new_semantics_ode_observables))
Expand Down Expand Up @@ -231,13 +216,9 @@ def test_replace_parameter_id(self):
old_semantics_ode_parameters = amr['semantics']['ode']['parameters']
new_semantics_ode_parameters = new_amr['semantics']['ode']['parameters']

new_model_states = new_amr['model']['states']

self.assertEqual(len(old_semantics_ode_rates), len(new_semantics_ode_rates))
self.assertEqual(len(old_semantics_ode_observables), len(new_semantics_ode_observables))

self.assertEqual(len(old_semantics_ode_parameters) - len(new_model_states), len(new_semantics_ode_parameters))

for old_rate, new_rate in zip(old_semantics_ode_rates, new_semantics_ode_rates):
if old_id in old_rate['expression'] and old_id in old_rate['expression_mathml']:
self.assertIn(new_id, new_rate['expression'])
Expand All @@ -256,14 +237,25 @@ def test_replace_parameter_id(self):
self.assertIn(new_id, new_observable['expression_mathml'])
self.assertNotIn(old_id, new_observable['expression_mathml'])

old_param_dict = {}
new_param_dict = {}

for old_parameter, new_parameter in zip(old_semantics_ode_parameters, new_semantics_ode_parameters):
if old_parameter['id'] == old_id:
self.assertEqual(new_parameter['id'], new_id)
self.assertEqual(old_parameter['value'], new_parameter['value'])
self.assertEqual(old_parameter['distribution'], new_parameter['distribution'])
self.assertEqual(sstr(old_parameter['units']['expression']), new_parameter['units']['expression'])
self.assertEqual(mathml_to_expression(old_parameter['units']['expression_mathml']),
mathml_to_expression(new_parameter['units']['expression_mathml']))
old_name = old_parameter.pop('id')
new_name = new_parameter.pop('id')

old_param_dict[old_name] = old_parameter
new_param_dict[new_name] = new_parameter

self.assertIn(new_id, new_param_dict)
self.assertNotIn(old_id, new_param_dict)

self.assertEqual(old_param_dict[old_id]['value'], new_param_dict[new_id]['value'])
self.assertEqual(old_param_dict[old_id]['distribution'], new_param_dict[new_id]['distribution'])
self.assertEqual(sstr(old_param_dict[old_id]['units']['expression']),
new_param_dict[new_id]['units']['expression'])
self.assertEqual(mathml_to_expression(old_param_dict[old_id]['units']['expression_mathml']),
mathml_to_expression(new_param_dict[new_id]['units']['expression_mathml']))

@SBMLMATH_REQUIRED
def test_add_parameter(self):
Expand All @@ -273,9 +265,17 @@ def test_add_parameter(self):
value = 0.35
xml_str = "<apply><times/><ci>E</ci><ci>delta</ci></apply>"
distribution = {'type': 'test_distribution',
'parameters': {'delta': 5}}
'parameters': {'test_dist': 5}}
new_amr = add_parameter(amr, parameter_id=parameter_id, name=name, value=value, distribution=distribution,
units_mathml=xml_str)
param_dict = {}
for param in new_amr['semantics']['ode']['parameters']:
name = param.pop('id')
param_dict[name] = param

self.assertIn(parameter_id, param_dict)
self.assertEqual(param_dict[parameter_id]['value'], value)
self.assertEqual(param_dict[parameter_id]['distribution'], distribution)

def test_remove_state(self):
removed_state_id = 'S'
Expand All @@ -290,7 +290,6 @@ def test_remove_state(self):
new_semantics_ode = new_amr['semantics']['ode']
new_semantics_ode_rates = new_semantics_ode['rates']
new_semantics_ode_initials = new_semantics_ode['initials']
new_semantics_ode_parameters = new_semantics_ode['parameters']
new_semantics_ode_observables = new_semantics_ode['observables']

for new_state in new_model_states:
Expand All @@ -309,12 +308,6 @@ def test_remove_state(self):
for new_initial in new_semantics_ode_initials:
self.assertNotEqual(removed_state_id, new_initial['target'])

# parameters that are associated in an expression with a removed state are not present in output amr
# (e.g.) if there exists an expression: "S*I*beta" and we remove S, then beta is no longer present in output
# list of parameters
for new_parameter in new_semantics_ode_parameters:
self.assertNotIn(removed_state_id, new_parameter['id'])

# output observable expressions that originally contained targeted state still exist with targeted state removed
# (e.g. 'S+R' -> 'R') if 'S' is the removed state
for new_observable in new_semantics_ode_observables:
Expand Down Expand Up @@ -393,7 +386,7 @@ def test_add_transition(self):
'units': expression_xml
}

# Test to see if parameters with attributes initialized are correctly added to parameters list of output amr
# Test to see if parameters with no attributes initialized are correctly added to parameters list of output amr
test_params_dict['E'] = {}

old_natural_conversion_amr = _d(self.sir_amr)
Expand Down

0 comments on commit 53d4410

Please sign in to comment.