Skip to content

Commit

Permalink
Review and update operations
Browse files Browse the repository at this point in the history
  • Loading branch information
bgyori committed Sep 8, 2023
1 parent 4df236f commit e42cc42
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 54 deletions.
127 changes: 76 additions & 51 deletions mira/modeling/askenet/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,42 +56,50 @@ def replace_transition_id(tm, old_id, new_id):


@amr_to_mira
def replace_observable_id(tm, old_id, new_id, display_name):
def replace_observable_id(tm, old_id, new_id, name=None):
"""Replace the ID of an observable."""
for obs, observable in copy.deepcopy(tm.observables).items():
if obs == old_id:
observable.name = new_id
observable.display_name = display_name
observable.display_name = name if name else observable.display_name
tm.observables[new_id] = observable
tm.observables.pop(old_id)
return tm


@amr_to_mira
def remove_observable_or_parameter(tm, replaced_id, replacement_value=None):
def remove_observable(tm, removed_id):
for obs, observable in copy.deepcopy(tm.observables).items():
if obs == removed_id:
tm.observables.pop(obs)
return tm


@amr_to_mira
def remove_parameter(tm, removed_id, replacement_value=None):
if replacement_value:
tm.substitute_parameter(replaced_id, replacement_value)
tm.substitute_parameter(removed_id, replacement_value)
else:
for obs, observable in copy.deepcopy(tm.observables).items():
if obs == replaced_id:
tm.observables.pop(obs)
tm.eliminate_parameter(removed_id)
return tm


@amr_to_mira
def add_observable(tm, new_id, new_display_name, new_rate_law):
if new_id in tm.observables:
print('This observable id is already present')
return tm
rate_law_sympy = mathml_to_expression(new_rate_law)
new_observable = Observable(name=new_id, display_name=new_display_name, expression=rate_law_sympy)
def add_observable(tm, new_id, new_name, new_expression):
# Note that if an observable already exists with the given
# key, it will be replaced
rate_law_sympy = mathml_to_expression(new_expression)
new_observable = Observable(name=new_id, display_name=new_name,
expression=rate_law_sympy)
tm.observables[new_id] = new_observable
return tm


@amr_to_mira
def replace_parameter_id(tm, old_id, new_id):
"""Replace the ID of a parameter."""
if old_id not in tm.parameters:
raise ValueError(f"Parameter with ID {old_id} not found in model.")
for template in tm.templates:
if template.rate_law:
template.rate_law = SympyExprStr(
Expand All @@ -103,31 +111,31 @@ def replace_parameter_id(tm, old_id, new_id):
sympy.Symbol(new_id)))
for key, param in copy.deepcopy(tm.parameters).items():
if param.name == old_id:
try:
popped_param = tm.parameters.pop(param.name)
popped_param.name = new_id
tm.parameters[new_id] = popped_param
except KeyError:
print('Old id: {}, is not present in the parameter dictionary of the template model'.format(old_id))
popped_param = tm.parameters.pop(param.name)
popped_param.name = new_id
tm.parameters[new_id] = popped_param
return tm


# Resolve issue where only parameters are added only when they are present in rate laws.
@amr_to_mira
def add_parameter(tm, parameter_id: str, expression_xml: str, value: float, distribution_type: str,
min_value: float, max_value: float):
distribution = Distribution(type=distribution_type,
parameters={
'maximum': max_value,
'minimum': min_value
})
sympy_expression = mathml_to_expression(expression_xml)
def add_parameter(tm, parameter_id: str,
value: float = None,
distribution=None,
units_mathml: str = None):
distribution = Distribution(**distribution) if distribution else None
if units_mathml:
units = {
'expression': mathml_to_expression(units_mathml),
'expression_mathml': units_mathml
}
else:
units = None
data = {
'name': parameter_id,
'value': value,
'distribution': distribution,
'units': {'expression': sympy_expression,
'expression_mathml': expression_xml}
'units': units
}

new_param = Parameter(**data)
Expand Down Expand Up @@ -177,27 +185,34 @@ def remove_transition(tm, transition_id):


@amr_to_mira
# option 1 take in optional parameters dict if rate law contains parameters that aren't already present
# option 2, reverse engineer rate law and find parameters and states within the rate law and add to model
def add_transition(tm, new_transition_id, rate_law_mathml, src_id=None, tgt_id=None):
rate_law_sympy = SympyExprStr(mathml_to_expression(rate_law_mathml))
def add_transition(tm, new_transition_id, src_id=None, tgt_id=None, rate_law_mathml=None):
# TODO: handle parameters added in the rate law as follows
# option 1 take in optional parameters dict if rate law contains parameters
# that aren't already present
# option 2, reverse engineer rate law and find parameters and states within
# the rate law and add to model
if src_id is None and tgt_id is None:
print("You must pass in at least one of source and target id")
elif src_id is None and tgt_id:
template = NaturalProduction(name=new_transition_id, outcome=tgt_id, rate_law=rate_law_sympy)
tm.templates.append(template)
ValueError("You must pass in at least one of source and target id")
rate_law_sympy = SympyExprStr(mathml_to_expression(rate_law_mathml)) \
if rate_law_mathml else None
if src_id is None and tgt_id:
template = NaturalProduction(name=new_transition_id, outcome=tgt_id,
rate_law=rate_law_sympy)
elif src_id and tgt_id is None:
template = NaturalDegradation(name=new_transition_id, subject=src_id, rate_law=rate_law_sympy)
tm.templates.append(template)
template = NaturalDegradation(name=new_transition_id, subject=src_id,
rate_law=rate_law_sympy)
else:
template = NaturalConversion(name=new_transition_id, subject=src_id, outcome=tgt_id, rate_law=rate_law_sympy)
tm.templates.append(template)
template = NaturalConversion(name=new_transition_id, subject=src_id,
outcome=tgt_id, rate_law=rate_law_sympy)
tm.templates.append(template)
return tm


@amr_to_mira
# rate law is of type Sympy Expression
def replace_rate_law_sympy(tm, transition_id, new_rate_law):
def replace_rate_law_sympy(tm, transition_id, new_rate_law: sympy.Expr):
# NOTE: this assumes that a sympy expression object is given
# though it might make sense to take a string instead
for template in tm.templates:
if template.name == transition_id:
template.rate_law = SympyExprStr(new_rate_law)
Expand All @@ -212,19 +227,29 @@ def replace_rate_law_mathml(tm, transition_id, new_rate_law):
# currently initials don't support expressions so only implement the following 2 methods for observables
# if we are seeking to replace an expression in an initial, return current template model
@amr_to_mira
def replace_expression_sympy(tm, object_id, new_expression_sympy, initial_flag):
if initial_flag:
return tm
else:
for obs, observable in tm.observables.items():
if obs == object_id:
observable.expression = SympyExprStr(new_expression_sympy)
def replace_observable_expression_sympy(tm, obs_id,
new_expression_sympy: sympy.Expr):
for obs, observable in tm.observables.items():
if obs == obs_id:
observable.expression = SympyExprStr(new_expression_sympy)
return tm


def replace_intial_expression_sympy(tm, initial_id,
new_expression_sympy: sympy.Expr):
# TODO: once initial expressions are supported, implement this
return tm


def replace_expression_mathml(tm, object_id, new_expression_mathml, initial_flag):
def replace_observable_exression_mathml(tm, obj_id, new_expression_mathml):
new_expression_sympy = mathml_to_expression(new_expression_mathml)
return replace_expression_sympy(tm, object_id, new_expression_sympy, initial_flag)
return replace_observable_expression_sympy(tm, obj_id,
new_expression_sympy)


def replace_intial_expression_mathml(tm, initial_id, new_expression_mathml):
# TODO: once initial expressions are supported, implement this
return tm


@amr_to_mira
Expand Down
7 changes: 4 additions & 3 deletions tests/test_modeling/test_askenet_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def test_replace_observable_id(self):
new_id = 'testinf'
new_display_name = 'test-infection'
amr = _d(self.sir_amr)
new_amr = replace_observable_id(amr, old_id, new_id, new_display_name)
new_amr = replace_observable_id(amr, old_id, new_id,
name=new_display_name)

old_semantics_observables = amr['semantics']['ode']['observables']
new_semantics_observables = new_amr['semantics']['ode']['observables']
Expand All @@ -155,13 +156,13 @@ def test_remove_observable_or_parameter(self):
old_amr_param = _d(self.sir_amr)

replaced_observable_id = 'noninf'
new_amr_obs = remove_observable_or_parameter(old_amr_obs, replaced_observable_id)
new_amr_obs = remove_observable(old_amr_obs, replaced_observable_id)
for new_observable in new_amr_obs['semantics']['ode']['observables']:
self.assertNotEqual(new_observable['id'], replaced_observable_id)

replaced_param_id = 'beta'
replacement_value = 5
new_amr_param = remove_observable_or_parameter(old_amr_param, replaced_param_id, replacement_value)
new_amr_param = remove_parameter(old_amr_param, replaced_param_id, replacement_value)
for new_param in new_amr_param['semantics']['ode']['parameters']:
self.assertNotEqual(new_param['id'], replaced_param_id)
for old_rate, new_rate in zip(old_amr_param['semantics']['ode']['rates'],
Expand Down

0 comments on commit e42cc42

Please sign in to comment.