Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use all available cores in feature extraction #176

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
68 changes: 34 additions & 34 deletions bluepyefe/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,30 @@
along with this library; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
"""
from collections import defaultdict
import logging
from multiprocessing import Pool
import numpy
import matplotlib.pyplot as plt
import pathlib

from bluepyefe.ecode import eCodes
from bluepyefe.reader import *
from bluepyefe.plotting import _save_fig
from matplotlib.backends.backend_pdf import PdfPages

from bluepyefe.recording import Recording

logger = logging.getLogger(__name__)


class Cell(object):
def extract_efeatures_helper(recording, efeatures, efeature_names, efel_settings):
"""Helper function to compute efeatures for a single recording."""
recording.compute_efeatures(
efeatures, efeature_names, efel_settings)
return recording


class Cell:

"""Contains the metadata related to a cell as well as the
electrophysiological recordings once they are read"""
Expand All @@ -46,9 +56,14 @@ def __init__(self, name):

self.name = name

self.recordings = []
self.recordings: dict[str, list[Recording]] = defaultdict(list)
self.rheobase = None

@property
def recordings_as_list(self):
"""Return all the recordings as a list."""
return [rec for recordings_list in self.recordings.values() for rec in recordings_list]

def reader(self, config_data, recording_reader=None):
"""Define the reader method used to read the ephys data for the
present recording and returns the data contained in the file.
Expand Down Expand Up @@ -90,9 +105,8 @@ def reader(self, config_data, recording_reader=None):
)

def get_protocol_names(self):
"""List of all the protocols available for the present cell."""

return list(set([rec.protocol_name for rec in self.recordings]))
"""List of all the protocol names available for the present cell."""
return list(self.recordings.keys())

def get_recordings_by_protocol_name(self, protocol_name):
"""List of all the recordings available for the present cell for a
Expand All @@ -102,27 +116,7 @@ def get_recordings_by_protocol_name(self, protocol_name):
protocol_name (str): name of the protocol for which to get
the recordings.
"""

return [
rec
for rec in self.recordings
if rec.protocol_name == protocol_name
]

def get_recordings_id_by_protocol_name(self, protocol_name):
"""List of the indexes of the recordings available for the present
cell for a given protocol.

Args:
protocol_name (str): name of the protocol for which to get
the recordings.
"""

return [
i
for i, trace in enumerate(self.recordings)
if trace.protocol_name == protocol_name
]
return self.recordings.get(protocol_name)

def read_recordings(
self,
Expand Down Expand Up @@ -163,7 +157,7 @@ def read_recordings(
protocol_name,
efel_settings
)
self.recordings.append(rec)
self.recordings[protocol_name].append(rec)
break
else:
raise KeyError(
Expand Down Expand Up @@ -192,19 +186,25 @@ def extract_efeatures(
is to be extracted several time on different sections
of the same recording.
"""
recordings_of_protocol: list[Recording] = self.recordings.get(protocol_name)

# Run in parallel via multiprocessing
with Pool(maxtasksperchild=1) as pool:
tasks = [
(recording, efeatures, efeature_names, efel_settings)
for recording in recordings_of_protocol
]
results = pool.starmap(extract_efeatures_helper, tasks)

for i in self.get_recordings_id_by_protocol_name(protocol_name):
self.recordings[i].compute_efeatures(
efeatures, efeature_names, efel_settings)
self.recordings[protocol_name] = results

def compute_relative_amp(self):
"""Compute the relative current amplitude for all the recordings as a
percentage of the rheobase."""

if self.rheobase not in (0.0, None, False, numpy.nan):

for i in range(len(self.recordings)):
self.recordings[i].compute_relative_amp(self.rheobase)
for recording in self.recordings_as_list:
recording.compute_relative_amp(self.rheobase)

else:

Expand Down
9 changes: 3 additions & 6 deletions bluepyefe/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
along with this library; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
"""
import os
import pickle
import functools
import logging
Expand Down Expand Up @@ -448,10 +447,8 @@ def _build_current_dict(cells, default_std_value):
threshold = {}

for cell in cells:

holding[cell.name] = numpy.nanmean(
[t.hypamp for t in cell.recordings]
)
holding_currents = [rec.hypamp for rec in cell.recordings_as_list]
holding[cell.name] = numpy.nanmean(holding_currents)

if cell.rheobase is not None:
threshold[cell.name] = cell.rheobase
Expand Down Expand Up @@ -762,7 +759,7 @@ def _extract_auto_targets(

recordings = []
for c in cells:
recordings += c.recordings
recordings += c.recordings_as_list

for i in range(len(auto_targets)):
auto_targets[i].select_ecode_and_amplitude(recordings)
Expand Down
13 changes: 8 additions & 5 deletions bluepyefe/rheobase.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@


def _get_list_spiking_amplitude(cell, protocols_rheobase):
"""Return the list of sorted list of amplitude that triggered at least
one spike"""
"""Return the list of sorted amplitudes that triggered at least
one spike, along with their corresponding spike counts."""

amps = []
spike_counts = []

for i, rec in enumerate(cell.recordings):
for rec in cell.recordings_as_list:
if rec.protocol_name in protocols_rheobase:
if rec.spikecount is not None:

Expand All @@ -42,13 +42,16 @@ def _get_list_spiking_amplitude(cell, protocols_rheobase):
logger.warning(
f"A recording of cell {cell.name} protocol "
f"{rec.protocol_name} shows spikes at a "
"suspiciously low current in a trace from file"
f" {rec.files}. Check that the ton and toff are"
"suspiciously low current in a trace from file "
f"{rec.files}. Check that the ton and toff are "
"correct or for the presence of unwanted spikes."
)

# Sort amplitudes and their corresponding spike counts
if amps:
amps, spike_counts = zip(*sorted(zip(amps, spike_counts)))
else:
amps, spike_counts = (), ()

return amps, spike_counts

Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""conftest.py contains fixtures that are automatically imported by pytest."""
import pytest
import matplotlib


@pytest.fixture(autouse=True, scope='session')
def set_matplotlib_backend():
matplotlib.use('Agg') # to avoid opening windows during testing
5 changes: 2 additions & 3 deletions tests/ecode/test_apthresh.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""bluepyefe.ecode.APThreshold tests"""

import unittest
import pytest
import glob
import json

Expand Down Expand Up @@ -96,13 +95,13 @@ def run_test_with_absolute_amplitude(self, absolute_amplitude):
bluepyefe.extract.compute_rheobase(cells, protocols_rheobase=["IDthresh"])

self.assertEqual(len(cells), 1)
self.assertEqual(len(cells[0].recordings), 21)
self.assertEqual(len(cells[0].recordings_as_list), 21)
self.assertLess(abs(cells[0].rheobase - 0.1103), 0.01)

# amplitude test for one recording
# sort the recordings because they can be in any order,
# and we want to select the same one each time we test
apthresh_recs = [rec for rec in cells[0].recordings if rec.protocol_name == "APThreshold"]
apthresh_recs = cells[0].recordings["APThreshold"]
rec1 = sorted(apthresh_recs, key=lambda x: x.amp)[1]
self.assertLess(abs(rec1.amp - 0.1740), 0.01)
self.assertLess(abs(rec1.amp_rel - 157.7), 0.1)
Expand Down
4 changes: 2 additions & 2 deletions tests/ecode/test_sahp.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ def run_test_with_absolute_amplitude(self, absolute_amplitude):
bluepyefe.extract.compute_rheobase(cells, protocols_rheobase=["IDthresh"])

self.assertEqual(len(cells), 1)
self.assertEqual(len(cells[0].recordings), 24)
self.assertEqual(len(cells[0].recordings_as_list), 24)
self.assertLess(abs(cells[0].rheobase - 0.1103), 0.01)

# amplitude test for one recording
# sort the recordings because they can be in any order,
# and we want to select the same one each time we test
sahp_recs = [rec for rec in cells[0].recordings if rec.protocol_name == "sAHP"]
sahp_recs = cells[0].recordings["sAHP"]
rec1 = sorted(sahp_recs, key=lambda x: x.amp2)[1]
self.assertLess(abs(rec1.amp - 0.0953), 0.01)
self.assertLess(abs(rec1.amp2 - 0.3153), 0.01)
Expand Down
18 changes: 11 additions & 7 deletions tests/test_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import unittest

import bluepyefe.cell
import bluepyefe.recording
from bluepyefe.cell import Cell, extract_efeatures_helper
from bluepyefe.rheobase import compute_rheobase_absolute


class CellTest(unittest.TestCase):
def setUp(self):

self.cell = bluepyefe.cell.Cell(name="MouseNeuron")
self.cell = Cell(name="MouseNeuron")
self.protocol_name = "IDRest"

file_metadata = {
"i_file": "./tests/exp_data/B95_Ch0_IDRest_107.ibw",
Expand All @@ -25,18 +25,22 @@ def setUp(self):
self.cell.read_recordings(protocol_data=[file_metadata], protocol_name="IDRest")

self.cell.extract_efeatures(
protocol_name="IDRest", efeatures=["Spikecount", "AP1_amp"]
protocol_name=self.protocol_name, efeatures=["Spikecount", "AP1_amp"]
)

def test_efeature_extraction(self):
recording = self.cell.recordings[0]
recording = self.cell.recordings[self.protocol_name][0]
self.assertEqual(2, len(recording.efeatures))
self.assertEqual(recording.efeatures["Spikecount"], 9.0)
self.assertLess(abs(recording.efeatures["AP1_amp"] - 66.4), 2.0)

def test_extract_efeatures_helper(self):
recording = self.cell.recordings[self.protocol_name][0]
extract_efeatures_helper(recording, ["Spikecount", "AP1_amp"], None, None)

def test_amp_threshold(self):
recording = self.cell.recordings[0]
compute_rheobase_absolute(self.cell, ["IDRest"])
recording = self.cell.recordings[self.protocol_name][0]
compute_rheobase_absolute(self.cell, [self.protocol_name])
self.cell.compute_relative_amp()
self.assertEqual(recording.amp, self.cell.rheobase)
self.assertEqual(recording.amp_rel, 100.0)
Expand Down
21 changes: 4 additions & 17 deletions tests/test_efel_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,21 @@ def setUp(self):

def test_efel_threshold(self):

self.cell.recordings[0].efeatures = {}
self.cell.recordings["IDRest"][0].efeatures = {}

self.cell.extract_efeatures(
protocol_name="IDRest",
efeatures=["Spikecount", "AP1_amp"],
efel_settings={'Threshold': 40.}
)

recording = self.cell.recordings[0]
recording = self.cell.recordings["IDRest"][0]
self.assertEqual(recording.efeatures["Spikecount"], 0.)
self.assertLess(abs(recording.efeatures["AP1_amp"] - 66.68), 0.01)
Copy link
Contributor Author

@anilbey anilbey Apr 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this code was not getting executed for a long time. The other test was shadowing it


def test_efel_strictstim(self):

self.cell.recordings[0].efeatures = {}
self.cell.recordings["IDRest"][0].efeatures = {}

self.cell.extract_efeatures(
protocol_name="IDRest",
Expand All @@ -54,20 +54,7 @@ def test_efel_strictstim(self):
}
)

self.assertEqual(self.cell.recordings[0].efeatures["Spikecount"], 0.)

def test_efel_threshold(self):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicate of another test with the same name

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the other test has more checks


self.cell.recordings[0].efeatures = {}

self.cell.extract_efeatures(
protocol_name="IDRest",
efeatures=["Spikecount"],
efel_settings={'Threshold': 40.}
)

recording = self.cell.recordings[0]
self.assertEqual(recording.efeatures["Spikecount"], 0.)
self.assertEqual(self.cell.recordings["IDRest"][0].efeatures["Spikecount"], 0.)


if __name__ == "__main__":
Expand Down
15 changes: 5 additions & 10 deletions tests/test_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import bluepyefe.extract
import bluepyefe.tools
from tests.utils import download_sahp_datafiles


def get_config(absolute_amplitude=False):
Expand Down Expand Up @@ -90,8 +89,8 @@ def test_extract(self):
bluepyefe.extract.compute_rheobase(cells, protocols_rheobase=["IDRest"])

self.assertEqual(len(cells), 2)
self.assertEqual(len(cells[0].recordings), 5)
self.assertEqual(len(cells[1].recordings), 5)
self.assertEqual(len(cells[0].recordings_as_list), 5)
self.assertEqual(len(cells[1].recordings_as_list), 5)

self.assertLess(abs(cells[0].rheobase - 0.119), 0.01)
self.assertLess(abs(cells[1].rheobase - 0.0923), 0.01)
Expand Down Expand Up @@ -158,7 +157,7 @@ def test_extract_auto(self):

recordings = []
for c in cells:
recordings += c.recordings
recordings += c.recordings_as_list

for i in range(len(auto_targets)):
auto_targets[i].select_ecode_and_amplitude(recordings)
Expand All @@ -181,8 +180,8 @@ def test_extract_absolute(self):
)

self.assertEqual(len(cells), 2)
self.assertEqual(len(cells[0].recordings), 5)
self.assertEqual(len(cells[1].recordings), 5)
self.assertEqual(len(cells[0].recordings_as_list), 5)
self.assertEqual(len(cells[1].recordings_as_list), 5)

self.assertEqual(cells[0].rheobase, None)
self.assertEqual(cells[1].rheobase, None)
Expand All @@ -198,10 +197,6 @@ def test_extract_absolute(self):
cells=cells, protocols=protocols, output_directory="MouseCells"
)

for cell in cells:
for r in cell.recordings:
print(r.amp, r.efeatures)

for protocol in protocols:
if protocol.name == "IDRest" and protocol.amplitude == 0.25:
for target in protocol.feature_targets:
Expand Down
1 change: 0 additions & 1 deletion tests/test_lccr_csv_reader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""bluepyefe.nwbreader tests"""
import unittest
import h5py
from pathlib import Path
from bluepyefe.reader import csv_lccr_reader

Expand Down
Loading
Loading