[ENH] add include_t_stop flag to Synchrotool class, Issue 493 #637

Open · wants to merge 27 commits into master

Commits (27)
718637a
add include_t_stop flag and implement False behavior
Moritz-Alexander-Kern Jul 16, 2024
9606903
add tests
Moritz-Alexander-Kern Jul 16, 2024
52fae9c
fix pep8
Moritz-Alexander-Kern Jul 16, 2024
f2e84b1
add parameter to docstring
Moritz-Alexander-Kern Jul 17, 2024
0ec2354
Merge branch 'master' into fix/index_error_synchrotool_493
Moritz-Alexander-Kern Jul 17, 2024
59f283a
implement changes
Moritz-Alexander-Kern Jul 25, 2024
805ad79
update docstring
Moritz-Alexander-Kern Jul 25, 2024
025e22b
add flag for tests
Moritz-Alexander-Kern Jul 25, 2024
05f0414
fix docstring
Moritz-Alexander-Kern Jul 25, 2024
a2d2a93
change t_stop back
Moritz-Alexander-Kern Jul 26, 2024
a02a0d8
add new flag ignore_shared_time to BinnedSpiketrain
Moritz-Alexander-Kern Aug 9, 2024
b0537b7
make use of new flag in Complexity class no_spread
Moritz-Alexander-Kern Aug 9, 2024
7f83ba7
remove unnecessary import
Moritz-Alexander-Kern Aug 9, 2024
a62f302
re-add type check for annotations
Moritz-Alexander-Kern Aug 9, 2024
88db22e
no need to pass parameter in other tests
Moritz-Alexander-Kern Aug 9, 2024
70b9e9e
pass t_start to time_histogram
Moritz-Alexander-Kern Aug 28, 2024
118d3fb
update docstring for ignore_shared_time
Moritz-Alexander-Kern Aug 28, 2024
f1e0202
add tests for ignore_shared_time_interval
Moritz-Alexander-Kern Aug 28, 2024
7d821f7
fix doctest
Moritz-Alexander-Kern Aug 28, 2024
3779ff5
fix pep8
Moritz-Alexander-Kern Aug 28, 2024
a65c14f
typo
Moritz-Alexander-Kern Sep 18, 2024
54e7e8c
edit typo
Moritz-Alexander-Kern Sep 18, 2024
6284e79
edit gitignore
Moritz-Alexander-Kern Sep 18, 2024
b2e32ab
add test checking correct binning
Moritz-Alexander-Kern Sep 30, 2024
1bc8752
Merge branch 'master' into fix/index_error_synchrotool_493
Moritz-Alexander-Kern Nov 14, 2024
589b276
Merge branch 'master' into fix/index_error_synchrotool_493
Moritz-Alexander-Kern Jan 8, 2025
935db4b
Merge branch 'master' into fix/index_error_synchrotool_493
Moritz-Alexander-Kern Jan 14, 2025
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,5 +1,5 @@
#########################################
# Editor temporary/working/backup files #
# Editor temporary/working/backup files
.#*
[#]*#
*~
@@ -17,6 +17,7 @@ nosetests.xml
*.tmp*
.idea/
venv/
.venv
env/
.pytest_cache/
**/*/__pycache__
36 changes: 26 additions & 10 deletions elephant/conversion.py
@@ -293,6 +293,19 @@ class BinnedSpikeTrain(object):
The sparse matrix format. By default, CSR format is used to perform
slicing and computations efficiently.
Default: 'csr'
ignore_shared_time : bool, optional
If `True`, the method allows `t_start` and `t_stop` to extend beyond
the shared time interval across all spike trains. This means that the
binning process can include spikes that occur outside the common
time range.
If `False` (default), the method enforces that `t_start` and `t_stop`
must fall within the shared time interval of all spike trains. If
either `t_start` or `t_stop` lies outside this range, a `ValueError`
is raised, ensuring that only the time period where all spike trains
overlap is considered for binning.
Use this parameter when you want to include spikes outside the common
time interval, understanding that it may result in bins that do not
have contributions from all spike trains.

Raises
------
@@ -335,7 +348,8 @@ class BinnedSpikeTrain(object):
"""

def __init__(self, spiketrains, bin_size=None, n_bins=None, t_start=None,
t_stop=None, tolerance=1e-8, sparse_format="csr"):
t_stop=None, tolerance=1e-8, sparse_format="csr",
ignore_shared_time=False):
if sparse_format not in ("csr", "csc"):
raise ValueError(f"Invalid 'sparse_format': {sparse_format}. "
"Available: 'csr' and 'csc'")
@@ -352,9 +366,10 @@ def __init__(self, spiketrains, bin_size=None, n_bins=None, t_start=None,
self.n_bins = n_bins
self._bin_size = bin_size
self.units = None # will be set later
self.ignore_shared_time = ignore_shared_time
# Check all parameter, set also missing values
self._resolve_input_parameters(spiketrains)
# Now create the sparse matrix
# Now create the sparse matrix.
self.sparse_matrix = self._create_sparse_matrix(
spiketrains, sparse_format=sparse_format)

@@ -531,14 +546,15 @@ def check_consistency():
tolerance = self.tolerance
if tolerance is None:
tolerance = 0
if self._t_start < start_shared - tolerance \
or self._t_stop > stop_shared + tolerance:
raise ValueError("'t_start' ({t_start}) or 't_stop' ({t_stop}) is "
"outside of the shared [{start_shared}, "
"{stop_shared}] interval".format(
t_start=self.t_start, t_stop=self.t_stop,
start_shared=start_shared,
stop_shared=stop_shared))
if not self.ignore_shared_time:
if self._t_start < start_shared - tolerance \
or self._t_stop > stop_shared + tolerance:
raise ValueError("'t_start' ({t_start}) or 't_stop' ({t_stop}) is "
"outside of the shared [{start_shared}, "
"{stop_shared}] interval".format(
t_start=self.t_start, t_stop=self.t_stop,
start_shared=start_shared,
stop_shared=stop_shared))

if self.n_bins is None:
# bin_size is provided
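The `ignore_shared_time` flag documented in the hunk above is easiest to see in a short usage sketch. This is illustrative only, not part of the diff, and assumes the behaviour introduced here; the spike times are made up for the example.

```python
import neo
import quantities as pq
from elephant.conversion import BinnedSpikeTrain

# Two trains whose shared interval is [1.0 s, 4.0 s].
st1 = neo.SpikeTrain([0.5, 1.5, 2.5] * pq.s, t_start=0.0 * pq.s, t_stop=4.0 * pq.s)
st2 = neo.SpikeTrain([1.2, 2.2, 3.2] * pq.s, t_start=1.0 * pq.s, t_stop=5.0 * pq.s)

# Default behaviour: requesting [0 s, 5 s] reaches outside the shared
# interval and raises a ValueError.
try:
    BinnedSpikeTrain([st1, st2], bin_size=1 * pq.s,
                     t_start=0 * pq.s, t_stop=5 * pq.s)
except ValueError as err:
    print("default:", err)

# With the new flag the same call succeeds; bins outside the shared
# interval simply lack contributions from some of the trains.
bst = BinnedSpikeTrain([st1, st2], bin_size=1 * pq.s,
                       t_start=0 * pq.s, t_stop=5 * pq.s,
                       ignore_shared_time=True)
print(bst.to_array())  # one row per train, five one-second bins
```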
42 changes: 33 additions & 9 deletions elephant/spike_train_synchrony.py
@@ -255,6 +255,15 @@ class Synchrotool(Complexity):
This class inherits from :class:`elephant.statistics.Complexity`, see its
documentation for more details and input parameters description.

Parameters
----------
include_t_stop : bool, optional
If True, the end of the spike train (`t_stop`) is included in the
analysis, ensuring that any spikes close to `t_stop` are properly
annotated.
Default: True.


See also
--------
elephant.statistics.Complexity
@@ -266,16 +275,20 @@ def __init__(self, spiketrains,
bin_size=None,
binary=True,
spread=0,
tolerance=1e-8):
tolerance=1e-8,
include_t_stop=True):

self.annotated = False

super(Synchrotool, self).__init__(spiketrains=spiketrains,
bin_size=bin_size,
sampling_rate=sampling_rate,
binary=binary,
spread=spread,
tolerance=tolerance)
super(Synchrotool, self).__init__(
spiketrains=spiketrains,
bin_size=bin_size,
sampling_rate=sampling_rate,
binary=binary,
spread=spread,
tolerance=tolerance,
t_stop=spiketrains[0].t_stop + (1 / sampling_rate) if include_t_stop else None,
)

def delete_synchrofacts(self, threshold, in_place=False, mode='delete'):
"""
@@ -391,6 +404,11 @@ def annotate_synchrofacts(self):
"""
Annotate the complexity of each spike in the
``self.epoch.array_annotations`` *in-place*.

Raises
------
ValueError
If spikes fall too close to `t_stop` and cannot be associated with a bin.
"""
epoch_complexities = self.epoch.array_annotations['complexity']
right_edges = (
@@ -399,15 +417,21 @@ def annotate_synchrofacts(self):
self.epoch.times.units).magnitude.flatten()
)

for idx, st in enumerate(self.input_spiketrains):
for st in self.input_spiketrains:

# all indices of spikes that are within the half-open intervals
# defined by the boundaries
# note that every second entry in boundaries is an upper boundary
spike_to_epoch_idx = np.searchsorted(
right_edges,
st.times.rescale(self.epoch.times.units).magnitude.flatten())
complexity_per_spike = epoch_complexities[spike_to_epoch_idx]
try:
complexity_per_spike = epoch_complexities[spike_to_epoch_idx]
except IndexError:
raise ValueError(
"Some spikes in the input spike trains may be too close to, or exactly at, t_stop; "
"they cannot be binned and therefore are not annotated. "
"Consider setting include_t_stop=True in the Synchrotool class to address this.")

st.array_annotate(complexity=complexity_per_spike)

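For the `include_t_stop` flag, here is a minimal sketch of the intended behaviour, mirroring the regression tests added at the bottom of this PR. It is illustrative only and assumes the constructor change above is applied: a spike sitting exactly at `t_stop` previously fell outside the last bin and triggered an IndexError during annotation; with the flag at its default of `True`, the binning window is extended by one sampling period so that spike is annotated too.

```python
import numpy as np
import neo
import quantities as pq
from elephant.spike_train_synchrony import Synchrotool

sampling_rate = 1 / pq.ms
# The last spike sits exactly at t_stop.
st = neo.SpikeTrain(np.arange(0, 11) * pq.ms,
                    t_start=0 * pq.ms, t_stop=10 * pq.ms)

# include_t_stop=True (the default): the window is extended by one
# sampling period, so the spike at t_stop is annotated like any other.
tool = Synchrotool([st, st], sampling_rate, spread=0)
tool.annotate_synchrofacts()
print(st.array_annotations['complexity'])  # one entry per spike

# include_t_stop=False keeps the old window; the spike at t_stop cannot
# be binned, and annotate_synchrofacts() now raises an informative
# ValueError instead of an opaque IndexError.
strict = Synchrotool([st, st], sampling_rate, spread=0, include_t_stop=False)
try:
    strict.annotate_synchrofacts()
except ValueError as err:
    print(err)
```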
19 changes: 13 additions & 6 deletions elephant/statistics.py
@@ -1162,12 +1162,14 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None,
if binary:
binned_spiketrain = BinnedSpikeTrain(spiketrains,
t_start=t_start,
t_stop=t_stop, bin_size=bin_size
t_stop=t_stop, bin_size=bin_size,
ignore_shared_time=True
).binarize(copy=False)
else:
binned_spiketrain = BinnedSpikeTrain(spiketrains,
t_start=t_start,
t_stop=t_stop, bin_size=bin_size
t_stop=t_stop, bin_size=bin_size,
ignore_shared_time=True
)

bin_hist: Union[int, ndarray] = binned_spiketrain.get_num_of_spikes(axis=0)
@@ -1423,7 +1425,10 @@ def __init__(self, spiketrains,
bin_size=None,
binary=True,
spread=0,
tolerance=1e-8):
tolerance=1e-8,
t_start=None,
t_stop=None,
):

check_neo_consistency(spiketrains, object_type=neo.SpikeTrain)

@@ -1434,8 +1439,8 @@ def __init__(self, spiketrains,
raise ValueError('Spread must be >=0')

self.input_spiketrains = spiketrains
self.t_start = spiketrains[0].t_start
self.t_stop = spiketrains[0].t_stop
self.t_start = spiketrains[0].t_start if t_start is None else t_start
self.t_stop = spiketrains[0].t_stop if t_stop is None else t_stop
self.sampling_rate = sampling_rate
self.bin_size = bin_size
self.binary = binary
@@ -1482,7 +1487,9 @@ def _histogram_no_spread(self):
# clip the spike trains before summing
time_hist = time_histogram(self.input_spiketrains,
self.bin_size,
binary=self.binary)
binary=self.binary,
t_start=self.t_start,
t_stop=self.t_stop)

time_hist_magnitude = time_hist.magnitude.flatten()

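The net effect of the statistics.py changes is that `Complexity` can accept an explicit window and forward it to `time_histogram`, which now tolerates a window extending beyond the shared interval. A minimal sketch of what this enables, illustrative only and assuming this diff is applied (`spread=0`, so the `_histogram_no_spread` path above is taken; spike times are made up; the printed shapes refer to the `time_histogram` attribute that `Complexity` builds internally):

```python
import neo
import quantities as pq
from elephant.statistics import Complexity

st1 = neo.SpikeTrain([1, 4, 6] * pq.ms, t_stop=10 * pq.ms)
st2 = neo.SpikeTrain([1, 5, 8] * pq.ms, t_stop=10 * pq.ms)

# Default window: taken from the first spike train, as before.
cpx = Complexity([st1, st2], sampling_rate=1 / pq.ms, spread=0)
print(cpx.t_stop)                    # 10.0 ms
print(cpx.time_histogram.shape)      # ten one-millisecond bins

# New in this diff: an explicit window is passed through to
# time_histogram, here extending one bin past the trains' own t_stop.
cpx_ext = Complexity([st1, st2], sampling_rate=1 / pq.ms, spread=0,
                     t_stop=11 * pq.ms)
print(cpx_ext.t_stop)                # 11.0 ms
print(cpx_ext.time_histogram.shape)  # one extra bin
```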
58 changes: 58 additions & 0 deletions elephant/test/test_conversion.py
@@ -218,6 +218,13 @@ def setUp(self):
self.bin_size = 1 * pq.s
self.tolerance = 1e-8

# Create some sample spike trains with different start and stop times
self.spiketrains = (
neo.SpikeTrain([0.1, 0.5, 1.0, 1.5, 2.0] * pq.s, t_start=0.0 * pq.s, t_stop=2.5 * pq.s),
neo.SpikeTrain([0.2, 0.6, 1.1, 1.6, 2.0] * pq.s, t_start=0.1 * pq.s, t_stop=2.0 * pq.s),
neo.SpikeTrain([0.3, 0.7, 1.2, 1.7, 2.1] * pq.s, t_start=0.2 * pq.s, t_stop=2.1 * pq.s),
)

def test_binarize(self):
spiketrains = [self.spiketrain_a, self.spiketrain_b,
self.spiketrain_a, self.spiketrain_b]
@@ -723,6 +730,57 @@ def test_binned_spiketrain_rounding(self):
assert_array_equal(bst.to_array().nonzero()[1],
np.arange(120000))

def test_binned_spiketrain_ignore_shared_time_false_raises_error(self):
"""
Test that a ValueError is raised when ignore_shared_time is False and
t_start or t_stop is outside the shared interval.
"""
t_start = 0.0 * pq.s # Outside shared interval (shared start is 0.2 s)
t_stop = 2.5 * pq.s # Outside shared interval (shared stop is 2.0 s)

with self.assertRaises(ValueError):
cv.BinnedSpikeTrain(spiketrains=self.spiketrains, bin_size=self.bin_size,
t_start=t_start, t_stop=t_stop, ignore_shared_time=False)

def test_binned_spiketrain_ignore_shared_time_true_allows_outside_interval(self):
"""
Test that no error is raised when ignore_shared_time is True, even if
t_start or t_stop is outside the shared interval.
"""
t_start = 0.0 * pq.s # Outside shared interval (shared start is 0.2 s)
t_stop = 2.5 * pq.s # Outside shared interval (shared stop is 2.0 s)

try:
_ = cv.BinnedSpikeTrain(spiketrains=self.spiketrains, bin_size=self.bin_size,
t_start=t_start, t_stop=t_stop, ignore_shared_time=True)
# If we reach this point, the test should pass.
self.assertTrue(True)
except ValueError:
self.fail("BinnedSpikeTrain raised ValueError unexpectedly when ignore_shared_time=True")

def test_ignore_shared_time_correct_binning(self):
# Create spike trains with different time ranges
st1 = neo.SpikeTrain([0.5, 1.5, 2.5, 3.5] * pq.s, t_start=0.0 * pq.s, t_stop=4.0 * pq.s)
st2 = neo.SpikeTrain([1.0, 2.0, 3.0, 4.0] * pq.s, t_start=1.0 * pq.s, t_stop=5.0 * pq.s)
st3 = neo.SpikeTrain([1.5, 2.5, 3.5, 5.5] * pq.s, t_start=1.5 * pq.s, t_stop=5.5 * pq.s)

spiketrains = [st1, st2, st3]
bin_size = 1 * pq.s

# Test with ignore_shared_time=True
bst_ignore = cv.BinnedSpikeTrain(spiketrains, bin_size=bin_size,
t_start=0 * pq.s, t_stop=6 * pq.s,
ignore_shared_time=True)
self.assertEqual(bst_ignore.t_start, 0 * pq.s)
self.assertEqual(bst_ignore.t_stop, 6 * pq.s)
self.assertEqual(bst_ignore.n_bins, 6)
expected_array_ignore = np.array([
[1, 1, 1, 1, 0, 0],
[0, 1, 1, 1, 1, 0],
[0, 1, 1, 1, 0, 1]
])
assert_array_equal(bst_ignore.to_array(), expected_array_ignore)


class DiscretiseSpiketrainsTestCase(unittest.TestCase):
def setUp(self):
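As a plain-NumPy cross-check of the expected matrix in `test_ignore_shared_time_correct_binning` above (a standalone sketch that does not use `BinnedSpikeTrain`), histogramming each train over the same six 1 s bins reproduces the rows:

```python
import numpy as np

# Spike times (in seconds) copied from the test above.
trains = [
    [0.5, 1.5, 2.5, 3.5],   # st1
    [1.0, 2.0, 3.0, 4.0],   # st2
    [1.5, 2.5, 3.5, 5.5],   # st3
]
edges = np.arange(0.0, 7.0, 1.0)  # bin edges at 0, 1, ..., 6 s

counts = np.vstack([np.histogram(t, bins=edges)[0] for t in trains])
print(counts)
# [[1 1 1 1 0 0]
#  [0 1 1 1 1 0]
#  [0 1 1 1 0 1]]
```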
23 changes: 23 additions & 0 deletions elephant/test/test_spike_train_synchrony.py
@@ -488,6 +488,29 @@ def test_wrong_input_errors(self):
synchrofact_obj.delete_synchrofacts,
-1)

def test_regression_PR_612_index_out_of_bounds_raise_warning(self):
"""
https://github.com/NeuralEnsemble/elephant/pull/612
"""
sampling_rate = 1/pq.ms
st = neo.SpikeTrain(np.arange(0, 11)*pq.ms, t_start=0*pq.ms, t_stop=10*pq.ms)

synchrotool_instance = Synchrotool([st, st], sampling_rate, spread=0, include_t_stop=False)

with self.assertRaises(ValueError):
synchrotool_instance.annotate_synchrofacts()

def test_regression_PR_612_index_out_of_bounds(self):
"""
https://github.com/NeuralEnsemble/elephant/pull/612
"""
sampling_rate = 1/pq.ms
st = neo.SpikeTrain(np.arange(0, 11)*pq.ms, t_start=0*pq.ms, t_stop=10*pq.ms)

synchrotool_instance = Synchrotool([st, st], sampling_rate, spread=0, include_t_stop=True)
synchrotool_instance.annotate_synchrofacts()
self.assertEqual(len(st.array_annotations['complexity']), len(st)) # all spikes annotated


if __name__ == '__main__':
unittest.main()