Skip to content

Commit

Permalink
Merge branch 'master' into enh/add_caching_action
Browse files Browse the repository at this point in the history
  • Loading branch information
Moritz-Alexander-Kern authored Jun 11, 2024
2 parents 2102057 + bdd98ee commit 2fb7fb0
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 51 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/ruff-formatting.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
name: Ruff formatting

on:
workflow_dispatch:
schedule:
- cron: "0 12 * * 0" # Weekly at noon UTC on Sundays


jobs:
lint:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4

- name: Check with Ruff
id: ruff-check
uses: chartboost/ruff-action@v1
with:
src: './elephant'
args: 'format --check'
continue-on-error: true

- name: Fix with Ruff
uses: chartboost/ruff-action@v1
if : ${{ steps.ruff-check.outcome == 'failure' }}
with:
src: './elephant'
args: 'format --verbose'

- name: Create PR
uses: peter-evans/create-pull-request@v5
if : ${{ steps.ruff-check.outcome == 'failure' }}
with:
commit-message: ruff formatting
title: Ruff formatting
body: Reformatting code with ruff
branch: ruff-formatting
82 changes: 32 additions & 50 deletions elephant/test/test_neo_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,40 +1163,52 @@ def test__get_all_spiketrains__spiketrain(self):
# assert_same_sub_schema(targ, res0)

def test__get_all_spiketrains__segment(self):
# Generate a simple segment object containing one spike train,
# supporting objects of type Segment and SpikeTrain.
obj = generate_one_simple_segment(
supported_objects=[neo.core.Segment, neo.core.SpikeTrain])
targ = copy.deepcopy(obj)
obj.spiketrains.append(obj.spiketrains[0])

nb_spiketrain=1,
supported_objects=[neo.core.Segment, neo.core.SpikeTrain]
)
# Append a deep copy of the first spike train in the segment's
# spike train list to itself.
obj.spiketrains.append(copy.deepcopy(obj.spiketrains[0]))
# Call the function get_all_spiketrains with the segment object
res0 = nt.get_all_spiketrains(obj)

targ = targ.spiketrains

self.assertTrue(len(res0) > 0)

self.assertEqual(len(targ), len(res0))

assert_same_sub_schema(targ, res0)
# Assert that the length of the result res0 is equal to 2.
# This checks if the function correctly returns two spike trains,
# including the original and its copy.
self.assertTrue(len(res0) == 2)

def test__get_all_spiketrains__block(self):
# Generate a simple block with 3 segments
obj = generate_one_simple_block(
nb_segment=3,
supported_objects=[
neo.core.Block, neo.core.Segment, neo.core.SpikeTrain])
supported_objects=[neo.core.Block,
neo.core.Segment,
neo.core.SpikeTrain]
)

# Deep copy the generated block for comparison
targ = copy.deepcopy(obj)

iobj1 = obj.segments[0]
obj.segments.append(iobj1)
# Manipulate the block by appending a spiketrain from one segment to
# another
iobj2 = obj.segments[0].spiketrains[1]
obj.segments[1].spiketrains.append(iobj2)

# Get all spiketrains from the modified block
res0 = nt.get_all_spiketrains(obj)

# Convert the target deep copy to a SpikeTrainList
targ = SpikeTrainList(targ.list_children_by_class('SpikeTrain'))

self.assertTrue(len(res0) > 0)

self.assertEqual(len(targ), len(res0))

# Perform assertions to validate the results
self.assertTrue(
len(res0) > 0,
"The result of get_all_spiketrains should not be empty.")
self.assertEqual(
len(targ), len(res0),
"The lengths of the SpikeTrainList and result should be equal.")
assert_same_sub_schema(targ, res0)

def test__get_all_spiketrains__list(self):
Expand All @@ -1207,8 +1219,6 @@ def test__get_all_spiketrains__list(self):
neo.core.Block, neo.core.Segment, neo.core.SpikeTrain])
for _ in range(3)]
targ = copy.deepcopy(obj)
iobj1 = obj[2].segments[0]
obj[2].segments.append(iobj1)
iobj2 = obj[1].segments[2].spiketrains[1]
obj[2].segments[1].spiketrains.append(iobj2)
obj.append(obj[-1])
Expand All @@ -1232,8 +1242,6 @@ def test__get_all_spiketrains__tuple(self):
for _ in range(3)]
targ = copy.deepcopy(obj)
obj.append(obj[-1])
iobj1 = obj[2].segments[0]
obj[2].segments.append(iobj1)
iobj2 = obj[1].segments[2].spiketrains[1]
obj[2].segments[1].spiketrains.append(iobj2)
obj.append(obj[-1])
Expand All @@ -1256,8 +1264,6 @@ def test__get_all_spiketrains__iter(self):
neo.core.Block, neo.core.Segment, neo.core.SpikeTrain])
for _ in range(3)]
targ = copy.deepcopy(obj)
iobj1 = obj[2].segments[0]
obj[2].segments.append(iobj1)
iobj2 = obj[1].segments[2].spiketrains[1]
obj[2].segments[1].spiketrains.append(iobj2)
obj.append(obj[-1])
Expand All @@ -1281,8 +1287,6 @@ def test__get_all_spiketrains__dict(self):
neo.core.Block, neo.core.Segment, neo.core.SpikeTrain])
for _ in range(3)]
targ = copy.deepcopy(obj)
iobj1 = obj[2].segments[0]
obj[2].segments.append(iobj1)
iobj2 = obj[1].segments[2].spiketrains[1]
obj[2].segments[1].spiketrains.append(iobj2)
obj.append(obj[-1])
Expand Down Expand Up @@ -1333,8 +1337,6 @@ def test__get_all_events__block(self):
neo.core.Block, neo.core.Segment, neo.core.Event])
targ = copy.deepcopy(obj)

iobj1 = obj.segments[0]
obj.segments.append(iobj1)
iobj2 = obj.segments[0].events[1]
obj.segments[1].events.append(iobj2)
res0 = nt.get_all_events(obj)
Expand All @@ -1356,8 +1358,6 @@ def test__get_all_events__list(self):
for _ in range(3)]
targ = copy.deepcopy(obj)
obj.append(obj[-1])
iobj1 = obj[2].segments[0]
obj[2].segments.append(iobj1)
iobj2 = obj[1].segments[2].events[1]
obj[2].segments[1].events.append(iobj2)
obj.append(obj[-1])
Expand All @@ -1381,8 +1381,6 @@ def test__get_all_events__tuple(self):
for _ in range(3)]
targ = copy.deepcopy(obj)
obj.append(obj[-1])
iobj1 = obj[2].segments[0]
obj[2].segments.append(iobj1)
iobj2 = obj[1].segments[2].events[1]
obj[2].segments[1].events.append(iobj2)
obj.append(obj[0])
Expand All @@ -1406,8 +1404,6 @@ def test__get_all_events__iter(self):
for _ in range(3)]
targ = copy.deepcopy(obj)
obj.append(obj[-1])
iobj1 = obj[2].segments[0]
obj[2].segments.append(iobj1)
iobj2 = obj[1].segments[2].events[1]
obj[2].segments[1].events.append(iobj2)
obj.append(obj[0])
Expand All @@ -1431,8 +1427,6 @@ def test__get_all_events__dict(self):
for _ in range(3)]
targ = copy.deepcopy(obj)
obj.append(obj[-1])
iobj1 = obj[2].segments[0]
obj[2].segments.append(iobj1)
iobj2 = obj[1].segments[2].events[1]
obj[2].segments[1].events.append(iobj2)
obj.append(obj[0])
Expand Down Expand Up @@ -1482,10 +1476,6 @@ def test__get_all_epochs__block(self):
neo.core.Block, neo.core.Segment, neo.core.Epoch])
targ = copy.deepcopy(obj)

iobj1 = obj.segments[0]
obj.segments.append(iobj1)
iobj2 = obj.segments[0].epochs[1]
obj.segments[1].epochs.append(iobj2)
res0 = nt.get_all_epochs(obj)

targ = targ.list_children_by_class('Epoch')
Expand All @@ -1505,8 +1495,6 @@ def test__get_all_epochs__list(self):
for _ in range(3)]
targ = copy.deepcopy(obj)
obj.append(obj[-1])
iobj1 = obj[2].segments[0]
obj[2].segments.append(iobj1)
iobj2 = obj[1].segments[2].epochs[1]
obj[2].segments[1].epochs.append(iobj2)
obj.append(obj[-1])
Expand All @@ -1530,8 +1518,6 @@ def test__get_all_epochs__tuple(self):
for _ in range(3)]
targ = copy.deepcopy(obj)
obj.append(obj[-1])
iobj1 = obj[2].segments[0]
obj[2].segments.append(iobj1)
iobj2 = obj[1].segments[2].epochs[1]
obj[2].segments[1].epochs.append(iobj2)
obj.append(obj[0])
Expand All @@ -1555,8 +1541,6 @@ def test__get_all_epochs__iter(self):
for _ in range(3)]
targ = copy.deepcopy(obj)
obj.append(obj[-1])
iobj1 = obj[2].segments[0]
obj[2].segments.append(iobj1)
iobj2 = obj[1].segments[2].epochs[1]
obj[2].segments[1].epochs.append(iobj2)
obj.append(obj[0])
Expand All @@ -1580,8 +1564,6 @@ def test__get_all_epochs__dict(self):
for _ in range(3)]
targ = copy.deepcopy(obj)
obj.append(obj[-1])
iobj1 = obj[2].segments[0]
obj[2].segments.append(iobj1)
iobj2 = obj[1].segments[2].epochs[1]
obj[2].segments[1].epochs.append(iobj2)
obj.append(obj[0])
Expand Down
2 changes: 1 addition & 1 deletion elephant/test/test_trials.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _create_trials_block(n_trials: int = 0,
n_spiketrains=n_spiketrains)
analogsignals = [AnalogSignal(signal=[.01, 3.3, 9.3], units='uV',
sampling_rate=1 * pq.Hz)
] * n_analogsignals
for _ in range(n_analogsignals)]
for spiketrain in spiketrains:
segment.spiketrains.append(spiketrain)
for analogsignal in analogsignals:
Expand Down

0 comments on commit 2fb7fb0

Please sign in to comment.