Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] pairwise phase consistency #392

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
8 changes: 8 additions & 0 deletions doc/reference/phase_analysis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,11 @@ Phase Analysis
==============

.. automodule:: elephant.phase_analysis

References
----------

.. bibliography:: ../bib/elephant.bib
:labelprefix: ph
:keyprefix: phase-
:style: unsrt
95 changes: 92 additions & 3 deletions elephant/phase_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
phase_locking_value
mean_phase_vector
phase_difference
pairwise_phase_consistency
weighted_phase_lag_index

References
Expand All @@ -31,6 +32,7 @@

__all__ = [
"spike_triggered_phase",
"pairwise_phase_consistency",
"phase_locking_value",
"mean_phase_vector",
"phase_difference",
Expand Down Expand Up @@ -161,8 +163,8 @@ def spike_triggered_phase(hilbert_transform, spiketrains, interpolate):

# Find index into signal for each spike
ind_at_spike = (
(spiketrain[sttimeind] - hilbert_transform[phase_i].t_start) /
hilbert_transform[phase_i].sampling_period). \
(spiketrain[sttimeind] - hilbert_transform[phase_i].t_start) /
hilbert_transform[phase_i].sampling_period). \
simplified.magnitude.astype(int)

# Append new list to the results for this spiketrain
Expand All @@ -173,7 +175,7 @@ def spike_triggered_phase(hilbert_transform, spiketrains, interpolate):
# Step through all spikes
for spike_i, ind_at_spike_j in enumerate(ind_at_spike):

if interpolate and ind_at_spike_j+1 < len(times):
if interpolate and ind_at_spike_j + 1 < len(times):
# Get relative spike occurrence between the two closest signal
# sample points
# if z->0 spike is more to the left sample
Expand All @@ -182,12 +184,14 @@ def spike_triggered_phase(hilbert_transform, spiketrains, interpolate):
hilbert_transform[phase_i].sampling_period

# Save hilbert_transform (interpolate on circle)

p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j]
).item()
p2 = np.angle(hilbert_transform[phase_i][ind_at_spike_j + 1]
).item()
interpolation = (1 - z) * np.exp(complex(0, p1)) \
+ z * np.exp(complex(0, p2))

p12 = np.angle([interpolation])
result_phases[spiketrain_i].append(p12)

Expand Down Expand Up @@ -217,6 +221,91 @@ def spike_triggered_phase(hilbert_transform, spiketrains, interpolate):
return result_phases, result_amps, result_times


def pairwise_phase_consistency(phases, method='ppc0'):
r"""
The Pairwise Phase Consistency (PPC0) :cite:`phase-Vinck2010_51` is an
improved measure of phase consistency/phase locking value, accounting for
bias due to low trial counts.

PPC0 is computed according to Eq. 14 and 15 of the cited paper.

An improved version of the PPC (PPC1) :cite:`phase-Vinck2012_33` computes
angular difference ony between pairs of spikes within trials.

PPC1 is not implemented yet


.. math::
\text{PPC} = \frac{2}{N(N-1)} \sum_{j=1}^{N-1} \sum_{k=j+1}^N
f(\theta_j, \theta_k)

wherein the function :math:`f` computes the dot product between two unit
vectors and is defined by

.. math::
f(\phi, \omega) = \cos(\phi) \cos(\omega) + \sin(\phi) \sin(\omega)

Parameters
----------
phases : np.ndarray or list of np.ndarray
Spike-triggered phases (output from :func:`spike_triggered_phase`).
If phases is a list of arrays, each array is considered a trial

method : str
'ppc0' - compute PPC between all pairs of spikes

Returns
-------
result_ppc : list of float
Pairwise Phase Consistency

"""
if isinstance(phases, np.ndarray):
phases = [phases]
if not isinstance(phases, (list, tuple)):
raise TypeError("Input must be a list of 1D numpy arrays with phases")

for phase_array in phases:
if not isinstance(phase_array, np.ndarray):
raise TypeError("Each entry of the input list must be an 1D "
"numpy array with phases")
if phase_array.ndim != 1:
raise ValueError("Phase arrays must be 1D (use .flatten())")

if method not in ['ppc0']:
raise ValueError('For method choose out of: ["ppc0"]')

phase_array = np.hstack(phases)
TRuikes marked this conversation as resolved.
Show resolved Hide resolved
n_trials = phase_array.shape[0] # 'spikes' are 'trials' as in paper

# Compute the distance between each pair of phases using dot product
# Optimize computation time using array multiplications instead of for
# loops
p_cos_2d = np.broadcast_to(np.cos(phase_array), (n_trials, n_trials))
p_sin_2d = np.broadcast_to(np.sin(phase_array), (n_trials, n_trials))

# By doing the element-wise multiplication of this matrix with its
# transpose, we get the distance between phases for all possible pairs
# of elements in 'phase'
dot_prod = np.multiply(p_cos_2d, p_cos_2d.T, dtype=np.float32) + \
np.multiply(p_sin_2d, p_sin_2d.T, dtype=np.float32)

# Now average over all elements in temp_results (the diagonal are 1
# and should not be included)
np.fill_diagonal(dot_prod, 0)

if method == 'ppc0':
# Note: each pair i,j is computed twice in dot_prod. do not
# multiply by 2. n_trial * n_trials - n_trials = nr of filled elements
# in dot_prod
ppc = np.sum(dot_prod) / (n_trials * n_trials - n_trials)
return ppc

elif method == 'ppc1':
# TODO: remove all indices from the same trial
return


def phase_locking_value(phases_i, phases_j):
r"""
Calculates the phase locking value (PLV) :cite:`phase-Lachaux99_194`.
Expand Down
126 changes: 126 additions & 0 deletions elephant/test/test_phase_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,132 @@ def test_regression_269(self):
self.assertEqual(len(phases_noint[0]), 2)


class PairwisePhaseConsistencyTestCase(unittest.TestCase):

@classmethod
def setUpClass(cls): # Note: using setUp makes the class call this
TRuikes marked this conversation as resolved.
Show resolved Hide resolved
# function per test, while this way the function is called only
# 1 time per TestCase, slightly more efficient (0.5s tough)

# Same setup as SpikeTriggerePhaseTestCase
tlen0 = 100 * pq.s
f0 = 20. * pq.Hz
fs0 = 1 * pq.ms
t0 = np.arange(
0, tlen0.rescale(pq.s).magnitude,
fs0.rescale(pq.s).magnitude) * pq.s
cls.anasig0 = AnalogSignal(
np.sin(2 * np.pi * (f0 * t0).simplified.magnitude),
units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0)

# Spiketrain with perfect locking
cls.st_perfect = SpikeTrain(
np.arange(50, tlen0.rescale(pq.ms).magnitude - 50, 50) * pq.ms,
t_start=0 * pq.ms, t_stop=tlen0)

# Spiketrain with inperfect locking
cls.st_inperfect = SpikeTrain(
[100., 100.1, 100.2, 100.3, 100.9, 101.] * pq.ms,
t_start=0 * pq.ms, t_stop=tlen0)

# Generate 2 'bursting' spiketrains, both locking on sinus period,
# but with different strengths
n_spikes = 3 # n spikes per burst
burst_interval = (1 / f0.magnitude) * pq.s
burst_start_times = np.arange(
0,
tlen0.rescale('ms').magnitude,
burst_interval.rescale('ms').magnitude
)

# Spiketrain with strong locking
burst_freq_strong = 200. * pq.Hz # strongly locking unit
burst_spike_interval = (1 / burst_freq_strong.magnitude) * pq.s
st_in_burst = np.arange(
0,
burst_spike_interval.rescale('ms').magnitude * n_spikes,
burst_spike_interval.rescale('ms').magnitude
)
st = [st_in_burst + t_offset for t_offset in burst_start_times]
st = np.hstack(st) * pq.ms
cls.st_bursting_strong = SpikeTrain(st,
t_start=0 * pq.ms,
t_stop=tlen0
)

# Spiketrain with weak locking
burst_freq_weak = 100. * pq.Hz # weak locking unit
burst_spike_interval = (1 / burst_freq_weak.magnitude) * pq.s
st_in_burst = np.arange(
0,
burst_spike_interval.rescale('ms').magnitude * n_spikes,
burst_spike_interval.rescale('ms').magnitude
)
st = [st_in_burst + t_offset for t_offset in burst_start_times]
st = np.hstack(st) * pq.ms
cls.st_bursting_weak = SpikeTrain(st,
t_start=0 * pq.ms,
t_stop=tlen0
)

def test_perfect_locking(self):
phases, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
self.st_perfect,
interpolate=True
)
# Pass input as single array
ppc0 = elephant.phase_analysis.pairwise_phase_consistency(
phases[0], method='ppc0'
)
self.assertEqual(ppc0, 1)
self.assertIsInstance(ppc0, float)

# Pass input as list of arrays
n_phases = int(phases[0].shape[0] / 2)
phases_cut = [phases[0][i * 2:i * 2 + 2] for i in range(n_phases)]
ppc0 = elephant.phase_analysis.pairwise_phase_consistency(
phases_cut, method='ppc0'
)
self.assertEqual(ppc0, 1)
self.assertIsInstance(ppc0, float)

def test_inperfect_locking(self):
phases, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
self.st_inperfect,
interpolate=True
)
# Pass input as single array
ppc0 = elephant.phase_analysis.pairwise_phase_consistency(
phases[0], method='ppc0'
)
self.assertLess(ppc0, 1)
self.assertIsInstance(ppc0, float)

def test_strong_vs_weak_locking(self):
phases_weak, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
self.st_bursting_weak,
interpolate=True
)
# Pass input as single array
ppc0_weak = elephant.phase_analysis.pairwise_phase_consistency(
phases_weak[0], method='ppc0'
)
phases_strong, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
self.st_bursting_strong,
interpolate=True
)
# Pass input as single array
ppc0_strong = elephant.phase_analysis.pairwise_phase_consistency(
phases_strong[0], method='ppc0'
)

self.assertLess(ppc0_weak, ppc0_strong)


class MeanVectorTestCase(unittest.TestCase):
def setUp(self):
self.tolerance = 1e-15
Expand Down