Skip to content

Commit

Permalink
Fix compute_scores to handle protocol names with '.' (#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilkilic authored Aug 28, 2024
1 parent 8799c74 commit 81d90c2
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 21 deletions.
22 changes: 7 additions & 15 deletions bluepyemodel/emodel_pipeline/plotting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from bluepyemodel.evaluation.protocols import ThresholdBasedProtocol
from bluepyemodel.evaluation.recordings import FixedDtRecordingCustom
from bluepyemodel.evaluation.recordings import FixedDtRecordingStimulus
from bluepyemodel.tools.utils import get_curr_name
from bluepyemodel.tools.utils import get_loc_name
from bluepyemodel.tools.utils import get_protocol_name

logger = logging.getLogger("__main__")

Expand Down Expand Up @@ -379,12 +382,7 @@ def get_simulated_FI_curve_for_plotting(evaluator, responses, prot_name):
simulated_amp = []
for val in values:
if prot_name.lower() in val.lower():
# val is e.g. IV_40.soma.maximum_voltage_from_voltagebase
n = val.split(".")
# case where protocol has '.' in its name, e.g. IV_40.0
if len(n) == 4 and n[1].isdigit():
n = [".".join(n[:2]), n[2], n[3]]
protocol_name = n[0]
protocol_name = get_protocol_name(val)
amp_temp = float(protocol_name.split("_")[-1])
if "mean_frequency" in val:
simulated_freq.append(values[val])
Expand Down Expand Up @@ -593,17 +591,11 @@ def get_ordered_currentscape_keys(keys):

ordered_keys = {}
for name in keys:
n = name.split(".")
# case where protocol has '.' in its name, e.g. IV_-100.0
if len(n) == 4 and n[1].isdigit():
n = [".".join(n[:2]), n[2], n[3]]
prot_name = n[0]
prot_name = get_protocol_name(name)
# prot_name can be e.g. RMPProtocol, or RMPProtocol_apical055
if not any(to_skip_ in prot_name for to_skip_ in to_skip):
if len(n) != 3:
raise ValueError(f"Expected 3 elements in {n}")
loc_name = n[1]
curr_name = n[2]
loc_name = get_loc_name(name)
curr_name = get_curr_name(name)

if prot_name not in ordered_keys:
ordered_keys[prot_name] = {}
Expand Down
119 changes: 114 additions & 5 deletions bluepyemodel/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,119 @@ def get_amplitude_from_feature_key(feat_key):
Args:
feat_key (str): feature key, e.g. IV_40.soma.maximum_voltage_from_voltagebase
"""
n = feat_key.split(".")
# case where protocol has '.' in its name, e.g. IV_40.0
if len(n) == 4 and n[1].isdigit():
n = [".".join(n[:2]), n[2], n[3]]
protocol_name = n[0]
protocol_name = get_protocol_name(feat_key)

return float(protocol_name.split("_")[-1])


def parse_feature_name_parts(feature_name):
"""
Splits the feature name into its respective parts,
handling cases where the protocol name contains a dot.
This function works with both a full feature name string (e.g., "IV_40.0.soma.v.voltage_base")
and a response key (e.g., "IV_40.0.soma.v"). It splits the input into a list of parts,
combining the first two parts if the protocol name contains a dot and is followed by
a numeric component.
Args:
feature_name (str): The full feature name string or response key to be parsed.
Returns:
list: A list of strings representing the correctly parsed parts of the feature name.
Examples:
>>> parse_feature_name_parts("IV_40.0.soma.v.voltage_base")
['IV_40.0', 'soma', 'v', 'voltage_base']
>>> parse_feature_name_parts("IV_40.0.soma.v")
['IV_40.0', 'soma', 'v']
"""
parts = feature_name.split(".")
if len(parts) > 1 and parts[1].isdigit():
return [".".join(parts[:2])] + parts[2:]
return parts


def get_protocol_name(feature_name):
"""
Extracts the protocol name from the feature name or response key.
This function works with both a full feature name string (e.g., "IV_40.0.soma.v.voltage_base")
and a response key (e.g., "IV_40.0.soma.v"). It returns the first part of the input, which is
the protocol name, correctly handling cases where the protocol contains a dot.
Args:
feature_name (str): The full feature name string or response key.
Returns:
str: The protocol name part of the feature name.
Examples:
>>> get_protocol_name("IV_40.0.soma.v.voltage_base")
'IV_40.0'
>>> get_protocol_name("IV_40.0.soma.v")
'IV_40.0'
"""
return parse_feature_name_parts(feature_name)[0]


def get_loc_name(feature_name):
"""
Extracts the location name from the feature name or response key.
This function works with both a full feature name string (e.g., "IV_40.0.soma.v.voltage_base")
and a response key (e.g., "IV_40.0.soma.v"). It returns the second part of the input, which is
the location name, correctly handling cases where the protocol contains a dot.
Args:
feature_name (str): The full feature name string or response key.
Returns:
str: The location name part of the feature name.
Raises:
IndexError: If the location name cannot be determined from the input.
Examples:
>>> get_loc_name("IV_40.0.soma.v.voltage_base")
'soma'
>>> get_loc_name("IV_40.0.soma.v")
'soma'
"""
parts = parse_feature_name_parts(feature_name)
if len(parts) < 2:
raise IndexError("Location name not found in the feature name.")
return parts[1]


def get_curr_name(feature_name):
"""
Extracts the current name from the feature name or response key.
This function works with both a full feature name string (e.g., "IV_40.0.soma.v.voltage_base")
and a response key (e.g., "IV_40.0.soma.v"). It returns the third part of the input, which is
the current name, correctly handling cases where the protocol contains a dot.
Args:
feature_name (str): The full feature name string or response key.
Returns:
str: The current name part of the feature name.
Raises:
IndexError: If the current name cannot be determined from the input.
Examples:
>>> get_curr_name("IV_40.0.soma.v.voltage_base")
'v'
>>> get_curr_name("IV_40.0.soma.v")
'v'
"""
parts = parse_feature_name_parts(feature_name)
if len(parts) < 3:
raise IndexError("Current name not found in the feature name.")
return parts[2]
3 changes: 2 additions & 1 deletion bluepyemodel/validation/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from bluepyemodel.evaluation.evaluation import compute_responses
from bluepyemodel.evaluation.evaluation import get_evaluator_from_access_point
from bluepyemodel.tools.utils import are_same_protocol
from bluepyemodel.tools.utils import get_protocol_name
from bluepyemodel.validation import validation_functions

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,7 +77,7 @@ def compute_scores(model, validation_protocols):

scores = model.evaluator.fitness_calculator.calculate_scores(model.responses)
for feature_name in scores:
protocol_name = feature_name.split(".")[0]
protocol_name = get_protocol_name(feature_name)
if any(are_same_protocol(p, protocol_name) for p in validation_protocols):
model.scores_validation[feature_name] = scores[feature_name]
else:
Expand Down
70 changes: 70 additions & 0 deletions tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from bluepyemodel.tools.utils import are_same_protocol
from bluepyemodel.tools.utils import format_protocol_name_to_list
from bluepyemodel.tools.utils import select_rec_for_thumbnail
from bluepyemodel.tools.utils import get_protocol_name
from bluepyemodel.tools.utils import get_loc_name
from bluepyemodel.tools.utils import get_curr_name
from tests.utils import DATA


Expand Down Expand Up @@ -136,3 +139,70 @@ def test_select_rec_for_thumbnail():
assert (
select_rec_for_thumbnail(rec_names, thumbnail_rec="sAHP_20.soma.v") == "IDrest_130.soma.v"
)


def test_get_protocol_name():
# feature keys
feature_name = "IV_40.0.soma.v.voltage_base"
assert get_protocol_name(feature_name) == "IV_40.0"

feature_name = "IV_40.soma.v.voltage_base"
assert get_protocol_name(feature_name) == "IV_40"

feature_name = "ProtocolA.1.soma.v.some_feature"
assert get_protocol_name(feature_name) == "ProtocolA.1"

# response keys
feature_name = "IV_40.0.soma.v"
assert get_protocol_name(feature_name) == "IV_40.0"

feature_name = "IV_40.soma.v"
assert get_protocol_name(feature_name) == "IV_40"

feature_name = "ProtocolA.1.soma.v"
assert get_protocol_name(feature_name) == "ProtocolA.1"


def test_get_loc_name():
# feature keys
feature_name = "IV_40.0.soma.v.voltage_base"
assert get_loc_name(feature_name) == "soma"

feature_name = "IV_40.soma.v.voltage_base"
assert get_loc_name(feature_name) == "soma"

feature_name = "IV_40.0"
with pytest.raises(IndexError, match="Location name not found in the feature name."):
get_loc_name(feature_name)

# response keys
feature_name = "IV_40.0.soma.v"
assert get_loc_name(feature_name) == "soma"

feature_name = "IV_40.soma.v"
assert get_loc_name(feature_name) == "soma"

feature_name = "ProtocolA.1.soma.v"
assert get_loc_name(feature_name) == "soma"

def test_get_curr_name():
# feature keys
feature_name = "IV_40.0.soma.v.voltage_base"
assert get_curr_name(feature_name) == "v"

feature_name = "IV_40.soma.v.voltage_base"
assert get_curr_name(feature_name) == "v"

feature_name = "IV_40.0.soma"
with pytest.raises(IndexError, match="Current name not found in the feature name."):
get_curr_name(feature_name)

# response keys
feature_name = "IV_40.0.soma.v"
assert get_curr_name(feature_name) == "v"

feature_name = "IV_40.soma.v"
assert get_curr_name(feature_name) == "v"

feature_name = "ProtocolA.1.soma.v"
assert get_curr_name(feature_name) == "v"

0 comments on commit 81d90c2

Please sign in to comment.