diff --git a/bluepyemodel/emodel_pipeline/plotting_utils.py b/bluepyemodel/emodel_pipeline/plotting_utils.py index dc473c41..e213cda1 100644 --- a/bluepyemodel/emodel_pipeline/plotting_utils.py +++ b/bluepyemodel/emodel_pipeline/plotting_utils.py @@ -32,6 +32,9 @@ from bluepyemodel.evaluation.protocols import ThresholdBasedProtocol from bluepyemodel.evaluation.recordings import FixedDtRecordingCustom from bluepyemodel.evaluation.recordings import FixedDtRecordingStimulus +from bluepyemodel.tools.utils import get_curr_name +from bluepyemodel.tools.utils import get_loc_name +from bluepyemodel.tools.utils import get_protocol_name logger = logging.getLogger("__main__") @@ -379,12 +382,7 @@ def get_simulated_FI_curve_for_plotting(evaluator, responses, prot_name): simulated_amp = [] for val in values: if prot_name.lower() in val.lower(): - # val is e.g. IV_40.soma.maximum_voltage_from_voltagebase - n = val.split(".") - # case where protocol has '.' in its name, e.g. IV_40.0 - if len(n) == 4 and n[1].isdigit(): - n = [".".join(n[:2]), n[2], n[3]] - protocol_name = n[0] + protocol_name = get_protocol_name(val) amp_temp = float(protocol_name.split("_")[-1]) if "mean_frequency" in val: simulated_freq.append(values[val]) @@ -593,17 +591,11 @@ def get_ordered_currentscape_keys(keys): ordered_keys = {} for name in keys: - n = name.split(".") - # case where protocol has '.' in its name, e.g. IV_-100.0 - if len(n) == 4 and n[1].isdigit(): - n = [".".join(n[:2]), n[2], n[3]] - prot_name = n[0] + prot_name = get_protocol_name(name) # prot_name can be e.g. RMPProtocol, or RMPProtocol_apical055 if not any(to_skip_ in prot_name for to_skip_ in to_skip): - if len(n) != 3: - raise ValueError(f"Expected 3 elements in {n}") - loc_name = n[1] - curr_name = n[2] + loc_name = get_loc_name(name) + curr_name = get_curr_name(name) if prot_name not in ordered_keys: ordered_keys[prot_name] = {} diff --git a/bluepyemodel/tools/utils.py b/bluepyemodel/tools/utils.py index bef41fc7..73610b18 100644 --- a/bluepyemodel/tools/utils.py +++ b/bluepyemodel/tools/utils.py @@ -256,10 +256,119 @@ def get_amplitude_from_feature_key(feat_key): Args: feat_key (str): feature key, e.g. IV_40.soma.maximum_voltage_from_voltagebase """ - n = feat_key.split(".") - # case where protocol has '.' in its name, e.g. IV_40.0 - if len(n) == 4 and n[1].isdigit(): - n = [".".join(n[:2]), n[2], n[3]] - protocol_name = n[0] + protocol_name = get_protocol_name(feat_key) return float(protocol_name.split("_")[-1]) + + +def parse_feature_name_parts(feature_name): + """ + Splits the feature name into its respective parts, + handling cases where the protocol name contains a dot. + + This function works with both a full feature name string (e.g., "IV_40.0.soma.v.voltage_base") + and a response key (e.g., "IV_40.0.soma.v"). It splits the input into a list of parts, + combining the first two parts if the protocol name contains a dot and is followed by + a numeric component. + + Args: + feature_name (str): The full feature name string or response key to be parsed. + + Returns: + list: A list of strings representing the correctly parsed parts of the feature name. + + Examples: + >>> parse_feature_name_parts("IV_40.0.soma.v.voltage_base") + ['IV_40.0', 'soma', 'v', 'voltage_base'] + + >>> parse_feature_name_parts("IV_40.0.soma.v") + ['IV_40.0', 'soma', 'v'] + """ + parts = feature_name.split(".") + if len(parts) > 1 and parts[1].isdigit(): + return [".".join(parts[:2])] + parts[2:] + return parts + + +def get_protocol_name(feature_name): + """ + Extracts the protocol name from the feature name or response key. + + This function works with both a full feature name string (e.g., "IV_40.0.soma.v.voltage_base") + and a response key (e.g., "IV_40.0.soma.v"). It returns the first part of the input, which is + the protocol name, correctly handling cases where the protocol contains a dot. + + Args: + feature_name (str): The full feature name string or response key. + + Returns: + str: The protocol name part of the feature name. + + Examples: + >>> get_protocol_name("IV_40.0.soma.v.voltage_base") + 'IV_40.0' + + >>> get_protocol_name("IV_40.0.soma.v") + 'IV_40.0' + """ + return parse_feature_name_parts(feature_name)[0] + + +def get_loc_name(feature_name): + """ + Extracts the location name from the feature name or response key. + + This function works with both a full feature name string (e.g., "IV_40.0.soma.v.voltage_base") + and a response key (e.g., "IV_40.0.soma.v"). It returns the second part of the input, which is + the location name, correctly handling cases where the protocol contains a dot. + + Args: + feature_name (str): The full feature name string or response key. + + Returns: + str: The location name part of the feature name. + + Raises: + IndexError: If the location name cannot be determined from the input. + + Examples: + >>> get_loc_name("IV_40.0.soma.v.voltage_base") + 'soma' + + >>> get_loc_name("IV_40.0.soma.v") + 'soma' + """ + parts = parse_feature_name_parts(feature_name) + if len(parts) < 2: + raise IndexError("Location name not found in the feature name.") + return parts[1] + + +def get_curr_name(feature_name): + """ + Extracts the current name from the feature name or response key. + + This function works with both a full feature name string (e.g., "IV_40.0.soma.v.voltage_base") + and a response key (e.g., "IV_40.0.soma.v"). It returns the third part of the input, which is + the current name, correctly handling cases where the protocol contains a dot. + + Args: + feature_name (str): The full feature name string or response key. + + Returns: + str: The current name part of the feature name. + + Raises: + IndexError: If the current name cannot be determined from the input. + + Examples: + >>> get_curr_name("IV_40.0.soma.v.voltage_base") + 'v' + + >>> get_curr_name("IV_40.0.soma.v") + 'v' + """ + parts = parse_feature_name_parts(feature_name) + if len(parts) < 3: + raise IndexError("Current name not found in the feature name.") + return parts[2] diff --git a/bluepyemodel/validation/validation.py b/bluepyemodel/validation/validation.py index 28cf8286..0d217c45 100644 --- a/bluepyemodel/validation/validation.py +++ b/bluepyemodel/validation/validation.py @@ -25,6 +25,7 @@ from bluepyemodel.evaluation.evaluation import compute_responses from bluepyemodel.evaluation.evaluation import get_evaluator_from_access_point from bluepyemodel.tools.utils import are_same_protocol +from bluepyemodel.tools.utils import get_protocol_name from bluepyemodel.validation import validation_functions logger = logging.getLogger(__name__) @@ -76,7 +77,7 @@ def compute_scores(model, validation_protocols): scores = model.evaluator.fitness_calculator.calculate_scores(model.responses) for feature_name in scores: - protocol_name = feature_name.split(".")[0] + protocol_name = get_protocol_name(feature_name) if any(are_same_protocol(p, protocol_name) for p in validation_protocols): model.scores_validation[feature_name] = scores[feature_name] else: diff --git a/tests/unit_tests/test_tools.py b/tests/unit_tests/test_tools.py index 85e1cb84..7c14dab9 100644 --- a/tests/unit_tests/test_tools.py +++ b/tests/unit_tests/test_tools.py @@ -20,6 +20,9 @@ from bluepyemodel.tools.utils import are_same_protocol from bluepyemodel.tools.utils import format_protocol_name_to_list from bluepyemodel.tools.utils import select_rec_for_thumbnail +from bluepyemodel.tools.utils import get_protocol_name +from bluepyemodel.tools.utils import get_loc_name +from bluepyemodel.tools.utils import get_curr_name from tests.utils import DATA @@ -136,3 +139,70 @@ def test_select_rec_for_thumbnail(): assert ( select_rec_for_thumbnail(rec_names, thumbnail_rec="sAHP_20.soma.v") == "IDrest_130.soma.v" ) + + +def test_get_protocol_name(): + # feature keys + feature_name = "IV_40.0.soma.v.voltage_base" + assert get_protocol_name(feature_name) == "IV_40.0" + + feature_name = "IV_40.soma.v.voltage_base" + assert get_protocol_name(feature_name) == "IV_40" + + feature_name = "ProtocolA.1.soma.v.some_feature" + assert get_protocol_name(feature_name) == "ProtocolA.1" + + # response keys + feature_name = "IV_40.0.soma.v" + assert get_protocol_name(feature_name) == "IV_40.0" + + feature_name = "IV_40.soma.v" + assert get_protocol_name(feature_name) == "IV_40" + + feature_name = "ProtocolA.1.soma.v" + assert get_protocol_name(feature_name) == "ProtocolA.1" + + +def test_get_loc_name(): + # feature keys + feature_name = "IV_40.0.soma.v.voltage_base" + assert get_loc_name(feature_name) == "soma" + + feature_name = "IV_40.soma.v.voltage_base" + assert get_loc_name(feature_name) == "soma" + + feature_name = "IV_40.0" + with pytest.raises(IndexError, match="Location name not found in the feature name."): + get_loc_name(feature_name) + + # response keys + feature_name = "IV_40.0.soma.v" + assert get_loc_name(feature_name) == "soma" + + feature_name = "IV_40.soma.v" + assert get_loc_name(feature_name) == "soma" + + feature_name = "ProtocolA.1.soma.v" + assert get_loc_name(feature_name) == "soma" + +def test_get_curr_name(): + # feature keys + feature_name = "IV_40.0.soma.v.voltage_base" + assert get_curr_name(feature_name) == "v" + + feature_name = "IV_40.soma.v.voltage_base" + assert get_curr_name(feature_name) == "v" + + feature_name = "IV_40.0.soma" + with pytest.raises(IndexError, match="Current name not found in the feature name."): + get_curr_name(feature_name) + + # response keys + feature_name = "IV_40.0.soma.v" + assert get_curr_name(feature_name) == "v" + + feature_name = "IV_40.soma.v" + assert get_curr_name(feature_name) == "v" + + feature_name = "ProtocolA.1.soma.v" + assert get_curr_name(feature_name) == "v" \ No newline at end of file