From 6ce6558478ac97ac594262785e716e457b5aba38 Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Tue, 11 Jun 2024 09:05:56 +0200 Subject: [PATCH] [Fix] Tests for Neo 0.13.1, add same object to list (#634) * fix tests for statistics and trials modules * fix tests in neo_tools --- elephant/test/test_neo_tools.py | 82 ++++++++++--------------- elephant/test/test_trials.py | 2 +- requirements/requirements-docs.txt | 2 +- requirements/requirements-tutorials.txt | 2 +- 4 files changed, 35 insertions(+), 53 deletions(-) diff --git a/elephant/test/test_neo_tools.py b/elephant/test/test_neo_tools.py index dbb5d5aad..3c8597e6e 100644 --- a/elephant/test/test_neo_tools.py +++ b/elephant/test/test_neo_tools.py @@ -1163,40 +1163,52 @@ def test__get_all_spiketrains__spiketrain(self): # assert_same_sub_schema(targ, res0) def test__get_all_spiketrains__segment(self): + # Generate a simple segment object containing one spike train, + # supporting objects of type Segment and SpikeTrain. obj = generate_one_simple_segment( - supported_objects=[neo.core.Segment, neo.core.SpikeTrain]) - targ = copy.deepcopy(obj) - obj.spiketrains.append(obj.spiketrains[0]) - + nb_spiketrain=1, + supported_objects=[neo.core.Segment, neo.core.SpikeTrain] + ) + # Append a deep copy of the first spike train in the segment's + # spike train list to itself. + obj.spiketrains.append(copy.deepcopy(obj.spiketrains[0])) + # Call the function get_all_spiketrains with the segment object res0 = nt.get_all_spiketrains(obj) - - targ = targ.spiketrains - - self.assertTrue(len(res0) > 0) - - self.assertEqual(len(targ), len(res0)) - - assert_same_sub_schema(targ, res0) + # Assert that the length of the result res0 is equal to 2. + # This checks if the function correctly returns two spike trains, + # including the original and its copy. + self.assertTrue(len(res0) == 2) def test__get_all_spiketrains__block(self): + # Generate a simple block with 3 segments obj = generate_one_simple_block( nb_segment=3, - supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.SpikeTrain]) + supported_objects=[neo.core.Block, + neo.core.Segment, + neo.core.SpikeTrain] + ) + + # Deep copy the generated block for comparison targ = copy.deepcopy(obj) - iobj1 = obj.segments[0] - obj.segments.append(iobj1) + # Manipulate the block by appending a spiketrain from one segment to + # another iobj2 = obj.segments[0].spiketrains[1] obj.segments[1].spiketrains.append(iobj2) + + # Get all spiketrains from the modified block res0 = nt.get_all_spiketrains(obj) + # Convert the target deep copy to a SpikeTrainList targ = SpikeTrainList(targ.list_children_by_class('SpikeTrain')) - self.assertTrue(len(res0) > 0) - - self.assertEqual(len(targ), len(res0)) - + # Perform assertions to validate the results + self.assertTrue( + len(res0) > 0, + "The result of get_all_spiketrains should not be empty.") + self.assertEqual( + len(targ), len(res0), + "The lengths of the SpikeTrainList and result should be equal.") assert_same_sub_schema(targ, res0) def test__get_all_spiketrains__list(self): @@ -1207,8 +1219,6 @@ def test__get_all_spiketrains__list(self): neo.core.Block, neo.core.Segment, neo.core.SpikeTrain]) for _ in range(3)] targ = copy.deepcopy(obj) - iobj1 = obj[2].segments[0] - obj[2].segments.append(iobj1) iobj2 = obj[1].segments[2].spiketrains[1] obj[2].segments[1].spiketrains.append(iobj2) obj.append(obj[-1]) @@ -1232,8 +1242,6 @@ def test__get_all_spiketrains__tuple(self): for _ in range(3)] targ = copy.deepcopy(obj) obj.append(obj[-1]) - iobj1 = obj[2].segments[0] - obj[2].segments.append(iobj1) iobj2 = obj[1].segments[2].spiketrains[1] obj[2].segments[1].spiketrains.append(iobj2) obj.append(obj[-1]) @@ -1256,8 +1264,6 @@ def test__get_all_spiketrains__iter(self): neo.core.Block, neo.core.Segment, neo.core.SpikeTrain]) for _ in range(3)] targ = copy.deepcopy(obj) - iobj1 = obj[2].segments[0] - obj[2].segments.append(iobj1) iobj2 = obj[1].segments[2].spiketrains[1] obj[2].segments[1].spiketrains.append(iobj2) obj.append(obj[-1]) @@ -1281,8 +1287,6 @@ def test__get_all_spiketrains__dict(self): neo.core.Block, neo.core.Segment, neo.core.SpikeTrain]) for _ in range(3)] targ = copy.deepcopy(obj) - iobj1 = obj[2].segments[0] - obj[2].segments.append(iobj1) iobj2 = obj[1].segments[2].spiketrains[1] obj[2].segments[1].spiketrains.append(iobj2) obj.append(obj[-1]) @@ -1333,8 +1337,6 @@ def test__get_all_events__block(self): neo.core.Block, neo.core.Segment, neo.core.Event]) targ = copy.deepcopy(obj) - iobj1 = obj.segments[0] - obj.segments.append(iobj1) iobj2 = obj.segments[0].events[1] obj.segments[1].events.append(iobj2) res0 = nt.get_all_events(obj) @@ -1356,8 +1358,6 @@ def test__get_all_events__list(self): for _ in range(3)] targ = copy.deepcopy(obj) obj.append(obj[-1]) - iobj1 = obj[2].segments[0] - obj[2].segments.append(iobj1) iobj2 = obj[1].segments[2].events[1] obj[2].segments[1].events.append(iobj2) obj.append(obj[-1]) @@ -1381,8 +1381,6 @@ def test__get_all_events__tuple(self): for _ in range(3)] targ = copy.deepcopy(obj) obj.append(obj[-1]) - iobj1 = obj[2].segments[0] - obj[2].segments.append(iobj1) iobj2 = obj[1].segments[2].events[1] obj[2].segments[1].events.append(iobj2) obj.append(obj[0]) @@ -1406,8 +1404,6 @@ def test__get_all_events__iter(self): for _ in range(3)] targ = copy.deepcopy(obj) obj.append(obj[-1]) - iobj1 = obj[2].segments[0] - obj[2].segments.append(iobj1) iobj2 = obj[1].segments[2].events[1] obj[2].segments[1].events.append(iobj2) obj.append(obj[0]) @@ -1431,8 +1427,6 @@ def test__get_all_events__dict(self): for _ in range(3)] targ = copy.deepcopy(obj) obj.append(obj[-1]) - iobj1 = obj[2].segments[0] - obj[2].segments.append(iobj1) iobj2 = obj[1].segments[2].events[1] obj[2].segments[1].events.append(iobj2) obj.append(obj[0]) @@ -1482,10 +1476,6 @@ def test__get_all_epochs__block(self): neo.core.Block, neo.core.Segment, neo.core.Epoch]) targ = copy.deepcopy(obj) - iobj1 = obj.segments[0] - obj.segments.append(iobj1) - iobj2 = obj.segments[0].epochs[1] - obj.segments[1].epochs.append(iobj2) res0 = nt.get_all_epochs(obj) targ = targ.list_children_by_class('Epoch') @@ -1505,8 +1495,6 @@ def test__get_all_epochs__list(self): for _ in range(3)] targ = copy.deepcopy(obj) obj.append(obj[-1]) - iobj1 = obj[2].segments[0] - obj[2].segments.append(iobj1) iobj2 = obj[1].segments[2].epochs[1] obj[2].segments[1].epochs.append(iobj2) obj.append(obj[-1]) @@ -1530,8 +1518,6 @@ def test__get_all_epochs__tuple(self): for _ in range(3)] targ = copy.deepcopy(obj) obj.append(obj[-1]) - iobj1 = obj[2].segments[0] - obj[2].segments.append(iobj1) iobj2 = obj[1].segments[2].epochs[1] obj[2].segments[1].epochs.append(iobj2) obj.append(obj[0]) @@ -1555,8 +1541,6 @@ def test__get_all_epochs__iter(self): for _ in range(3)] targ = copy.deepcopy(obj) obj.append(obj[-1]) - iobj1 = obj[2].segments[0] - obj[2].segments.append(iobj1) iobj2 = obj[1].segments[2].epochs[1] obj[2].segments[1].epochs.append(iobj2) obj.append(obj[0]) @@ -1580,8 +1564,6 @@ def test__get_all_epochs__dict(self): for _ in range(3)] targ = copy.deepcopy(obj) obj.append(obj[-1]) - iobj1 = obj[2].segments[0] - obj[2].segments.append(iobj1) iobj2 = obj[1].segments[2].epochs[1] obj[2].segments[1].epochs.append(iobj2) obj.append(obj[0]) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index fb0ceab10..b472e7a8e 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -29,7 +29,7 @@ def _create_trials_block(n_trials: int = 0, n_spiketrains=n_spiketrains) analogsignals = [AnalogSignal(signal=[.01, 3.3, 9.3], units='uV', sampling_rate=1 * pq.Hz) - ] * n_analogsignals + for _ in range(n_analogsignals)] for spiketrain in spiketrains: segment.spiketrains.append(spiketrain) for analogsignal in analogsignals: diff --git a/requirements/requirements-docs.txt b/requirements/requirements-docs.txt index 72c42b6c3..e05df77a8 100644 --- a/requirements/requirements-docs.txt +++ b/requirements/requirements-docs.txt @@ -5,5 +5,5 @@ sphinx>=3.3.0 nbsphinx>=0.8.0 sphinxcontrib-bibtex>1.0.0 sphinx-tabs>=1.3.0 -matplotlib>=3.3.2 +matplotlib>=3.3.2, <3.9.0 # conda install -c conda-forge pandoc diff --git a/requirements/requirements-tutorials.txt b/requirements/requirements-tutorials.txt index 3ee70c3cd..5c142ab15 100644 --- a/requirements/requirements-tutorials.txt +++ b/requirements/requirements-tutorials.txt @@ -1,4 +1,4 @@ # Packages required to execute jupyter notebook tutorials -matplotlib>=3.3.2 +matplotlib>=3.3.2, <3.9.0 h5py>=3.1.0 nixio>=1.5.0 \ No newline at end of file