Skip to content

Commit

Permalink
add tests for trial object
Browse files Browse the repository at this point in the history
  • Loading branch information
Moritz-Alexander-Kern committed Oct 9, 2024
1 parent 7b24455 commit 2d4f92d
Showing 1 changed file with 115 additions and 105 deletions.
220 changes: 115 additions & 105 deletions elephant/test/test_unitary_event_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import quantities as pq
from numpy.testing import assert_array_equal

from elephant.trials import TrialsFromLists
import elephant.unitary_event_analysis as ue
from elephant.datasets import download, ELEPHANT_TMP_DIR
from numpy.testing import assert_array_almost_equal
Expand Down Expand Up @@ -324,52 +325,56 @@ def test_jointJ_window_analysis(self):
sts2 = self.sts2_neo

# joinJ_window_analysis requires the following:
# A list of spike trains(neo.SpikeTrain objects) in different trials:
data = list(zip(*[sts1,sts2]))

win_size = 100 * pq.ms
bin_size = 5 * pq.ms
win_step = 20 * pq.ms
pattern_hash = [3]
UE_dic = ue.jointJ_window_analysis(spiketrains=data,
pattern_hash=pattern_hash,
bin_size=bin_size,
win_size=win_size,
win_step=win_step)
expected_Js = np.array(
[0.57953708, 0.47348757, 0.1729669,
0.01883295, -0.21934742, -0.80608759])
expected_n_emp = np.array(
[9., 9., 7., 7., 6., 6.])
expected_n_exp = np.array(
[6.5, 6.85, 6.05, 6.6, 6.45, 8.7])
expected_rate = np.array(
[[0.02166667, 0.01861111],
[0.02277778, 0.01777778],
[0.02111111, 0.01777778],
[0.02277778, 0.01888889],
[0.02305556, 0.01722222],
[0.02388889, 0.02055556]]) * pq.kHz
expected_indecis_tril26 = [4., 4.]
expected_indecis_tril4 = [1.]
assert_array_almost_equal(UE_dic['Js'].squeeze(), expected_Js)
assert_array_almost_equal(UE_dic['n_emp'].squeeze(), expected_n_emp)
assert_array_almost_equal(UE_dic['n_exp'].squeeze(), expected_n_exp)
assert_array_almost_equal(UE_dic['rate_avg'].squeeze(), expected_rate)
assert_array_almost_equal(UE_dic['indices']['trial26'],
expected_indecis_tril26)
assert_array_almost_equal(UE_dic['indices']['trial4'],
expected_indecis_tril4)

# check the input parameters
input_params = UE_dic['input_parameters']
self.assertEqual(input_params['pattern_hash'], pattern_hash)
self.assertEqual(input_params['bin_size'], bin_size)
self.assertEqual(input_params['win_size'], win_size)
self.assertEqual(input_params['win_step'], win_step)
self.assertEqual(input_params['method'], 'analytic_TrialByTrial')
self.assertEqual(input_params['t_start'], 0 * pq.s)
self.assertEqual(input_params['t_stop'], 200 * pq.ms)
# A list of spike trains(neo.SpikeTrain objects) in different trials, or trials.Trial object
test_cases = (
list(zip(*[sts1, sts2])), # list
TrialsFromLists(list(zip(*[sts1, sts2]))), # Trial object
)
for data in test_cases:
with self.subTest(data=data):
win_size = 100 * pq.ms
bin_size = 5 * pq.ms
win_step = 20 * pq.ms
pattern_hash = [3]
UE_dic = ue.jointJ_window_analysis(spiketrains=data,
pattern_hash=pattern_hash,
bin_size=bin_size,
win_size=win_size,
win_step=win_step)
expected_Js = np.array(
[0.57953708, 0.47348757, 0.1729669,
0.01883295, -0.21934742, -0.80608759])
expected_n_emp = np.array(
[9., 9., 7., 7., 6., 6.])
expected_n_exp = np.array(
[6.5, 6.85, 6.05, 6.6, 6.45, 8.7])
expected_rate = np.array(
[[0.02166667, 0.01861111],
[0.02277778, 0.01777778],
[0.02111111, 0.01777778],
[0.02277778, 0.01888889],
[0.02305556, 0.01722222],
[0.02388889, 0.02055556]]) * pq.kHz
expected_indecis_tril26 = [4., 4.]
expected_indecis_tril4 = [1.]
assert_array_almost_equal(UE_dic['Js'].squeeze(), expected_Js)
assert_array_almost_equal(UE_dic['n_emp'].squeeze(), expected_n_emp)
assert_array_almost_equal(UE_dic['n_exp'].squeeze(), expected_n_exp)
assert_array_almost_equal(UE_dic['rate_avg'].squeeze(), expected_rate)
assert_array_almost_equal(UE_dic['indices']['trial26'],
expected_indecis_tril26)
assert_array_almost_equal(UE_dic['indices']['trial4'],
expected_indecis_tril4)

# check the input parameters
input_params = UE_dic['input_parameters']
self.assertEqual(input_params['pattern_hash'], pattern_hash)
self.assertEqual(input_params['bin_size'], bin_size)
self.assertEqual(input_params['win_size'], win_size)
self.assertEqual(input_params['win_step'], win_step)
self.assertEqual(input_params['method'], 'analytic_TrialByTrial')
self.assertEqual(input_params['t_start'], 0 * pq.s)
self.assertEqual(input_params['t_stop'], 200 * pq.ms)

@staticmethod
def load_gdf2Neo(fname, trigger, t_pre, t_post):
Expand Down Expand Up @@ -501,69 +506,74 @@ def test_multiple_neurons(self):
np.random.seed(12)

# Create a list of lists containing 3 Trials with 5 spiketrains
spiketrains = \
spiketrains_poisson = \
[StationaryPoissonProcess(
rate=50 * pq.Hz, t_stop=1 * pq.s).generate_n_spiketrains(5)
for _ in range(3)]

spiketrains = list(zip(*spiketrains))
UE_dic = ue.jointJ_window_analysis(spiketrains, bin_size=5 * pq.ms,
win_size=300 * pq.ms,
win_step=100 * pq.ms)

js_expected = [[0.3978179],
[0.08131966],
[-1.4239882],
[-0.9377029],
[-0.3374434],
[-0.2043383],
[-1.001536],
[-np.inf]]
indices_expected = \
{'trial3': [12, 27, 31, 34, 27, 31, 34, 136, 136, 136],
'trial4': [4, 60, 60, 60, 117, 117, 117]}
n_emp_expected = [[5.],
[4.],
[1.],
[2.],
[2.],
[2.],
[1.],
[0.]]
n_exp_expected = [[3.5591667],
[3.4536111],
[3.3158333],
[3.8466666],
[2.370278],
[2.0811112],
[2.4011111],
[3.0533333]]
rate_expected = [[[0.042, 0.03933334, 0.048]],
[[0.04533333, 0.038, 0.05]],
[[0.046, 0.04, 0.04666667]],
[[0.05066667, 0.042, 0.046]],
[[0.04466667, 0.03666667, 0.04066667]],
[[0.04066667, 0.03533333, 0.04333333]],
[[0.03933334, 0.038, 0.038]],
[[0.04066667, 0.04866667, 0.03666667]]] * (1. / pq.ms)
input_parameters_expected = {'pattern_hash': [7],
'bin_size': 5 * pq.ms,
'win_size': 300 * pq.ms,
'win_step': 100 * pq.ms,
'method': 'analytic_TrialByTrial',
't_start': 0 * pq.s,
't_stop': 1 * pq.s, 'n_surrogates': 100}

assert_array_almost_equal(UE_dic['Js'], js_expected)
assert_array_almost_equal(UE_dic['n_emp'], n_emp_expected)
assert_array_almost_equal(UE_dic['n_exp'], n_exp_expected)
assert_array_almost_equal(UE_dic['rate_avg'], rate_expected)
self.assertEqual(sorted(UE_dic['indices'].keys()),
sorted(indices_expected.keys()))
for trial_key in indices_expected.keys():
assert_array_equal(indices_expected[trial_key],
UE_dic['indices'][trial_key])
self.assertEqual(UE_dic['input_parameters'], input_parameters_expected)
test_cases = (
list(zip(*spiketrains_poisson)), # list
TrialsFromLists(list(zip(*spiketrains_poisson))), # Trial object
)
for spiketrains in test_cases:
with self.subTest(data=spiketrains):
UE_dic = ue.jointJ_window_analysis(spiketrains, bin_size=5 * pq.ms,
win_size=300 * pq.ms,
win_step=100 * pq.ms)

js_expected = [[0.3978179],
[0.08131966],
[-1.4239882],
[-0.9377029],
[-0.3374434],
[-0.2043383],
[-1.001536],
[-np.inf]]
indices_expected = \
{'trial3': [12, 27, 31, 34, 27, 31, 34, 136, 136, 136],
'trial4': [4, 60, 60, 60, 117, 117, 117]}
n_emp_expected = [[5.],
[4.],
[1.],
[2.],
[2.],
[2.],
[1.],
[0.]]
n_exp_expected = [[3.5591667],
[3.4536111],
[3.3158333],
[3.8466666],
[2.370278],
[2.0811112],
[2.4011111],
[3.0533333]]
rate_expected = [[[0.042, 0.03933334, 0.048]],
[[0.04533333, 0.038, 0.05]],
[[0.046, 0.04, 0.04666667]],
[[0.05066667, 0.042, 0.046]],
[[0.04466667, 0.03666667, 0.04066667]],
[[0.04066667, 0.03533333, 0.04333333]],
[[0.03933334, 0.038, 0.038]],
[[0.04066667, 0.04866667, 0.03666667]]] * (1. / pq.ms)
input_parameters_expected = {'pattern_hash': [7],
'bin_size': 5 * pq.ms,
'win_size': 300 * pq.ms,
'win_step': 100 * pq.ms,
'method': 'analytic_TrialByTrial',
't_start': 0 * pq.s,
't_stop': 1 * pq.s, 'n_surrogates': 100}

assert_array_almost_equal(UE_dic['Js'], js_expected)
assert_array_almost_equal(UE_dic['n_emp'], n_emp_expected)
assert_array_almost_equal(UE_dic['n_exp'], n_exp_expected)
assert_array_almost_equal(UE_dic['rate_avg'], rate_expected)
self.assertEqual(sorted(UE_dic['indices'].keys()),
sorted(indices_expected.keys()))
for trial_key in indices_expected.keys():
assert_array_equal(indices_expected[trial_key],
UE_dic['indices'][trial_key])
self.assertEqual(UE_dic['input_parameters'], input_parameters_expected)


if __name__ == '__main__':
Expand Down

0 comments on commit 2d4f92d

Please sign in to comment.