diff --git a/elephant/__init__.py b/elephant/__init__.py index 1ceea64fb..17f7ea1d1 100644 --- a/elephant/__init__.py +++ b/elephant/__init__.py @@ -6,28 +6,30 @@ :license: Modified BSD, see LICENSE.txt for details. """ -from . import (cell_assembly_detection, - change_point_detection, - conversion, - cubic, - current_source_density, - gpfa, - kernels, - neo_tools, - phase_analysis, - signal_processing, - spade, - spectral, - spike_train_correlation, - spike_train_dissimilarity, - spike_train_generation, - spike_train_surrogates, - spike_train_synchrony, - sta, - trials, - unitary_event_analysis, - waveform_features, - statistics) +from . import ( + cell_assembly_detection, + change_point_detection, + conversion, + cubic, + current_source_density, + gpfa, + kernels, + neo_tools, + phase_analysis, + signal_processing, + spade, + spectral, + spike_train_correlation, + spike_train_dissimilarity, + spike_train_generation, + spike_train_surrogates, + spike_train_synchrony, + sta, + trials, + unitary_event_analysis, + waveform_features, + statistics, +) # not included modules on purpose: # parallel: avoid warns when elephant is imported @@ -43,8 +45,9 @@ def _get_version(): import os + elephant_dir = os.path.dirname(__file__) - with open(os.path.join(elephant_dir, 'VERSION')) as version_file: + with open(os.path.join(elephant_dir, "VERSION")) as version_file: version = version_file.read().strip() return version diff --git a/elephant/asset/asset.py b/elephant/asset/asset.py index 4f464d7ca..8bd39f957 100644 --- a/elephant/asset/asset.py +++ b/elephant/asset/asset.py @@ -119,6 +119,7 @@ :func:`viziphant.asset.plot_synchronous_events`. """ + from __future__ import division, print_function, unicode_literals import math @@ -164,15 +165,18 @@ "synchronous_events_contains_all", "synchronous_events_overlap", "get_neurons_in_sse", - "get_sse_start_and_end_time_bins" + "get_sse_start_and_end_time_bins", ] # Create logger and set configuration logger = logging.getLogger(__file__) log_handler = logging.StreamHandler() log_handler.setFormatter( - logging.Formatter(f"[%(asctime)s] {__name__[__name__.rfind('.')+1::]} -" - " %(levelname)s: %(message)s")) + logging.Formatter( + f"[%(asctime)s] {__name__[__name__.rfind('.')+1::]} -" + " %(levelname)s: %(message)s" + ) +) logger.addHandler(log_handler) logger.propagate = False @@ -206,12 +210,11 @@ def _signals_same_attribute(signals, attr_name): If `signals` have different `attr_name` attribute values. 
""" if len(signals) == 0: - raise ValueError('Empty signals list') + raise ValueError("Empty signals list") attribute = getattr(signals[0], attr_name) for sig in signals[1:]: if getattr(sig, attr_name) != attribute: - raise ValueError( - "Signals have different '{}' values".format(attr_name)) + raise ValueError("Signals have different '{}' values".format(attr_name)) return attribute @@ -307,17 +310,22 @@ def _transactions(spiketrains, bin_size, t_start, t_stop, ids=None): try: ids, trains = zip(*spiketrains) except TypeError: - raise TypeError('spiketrains must be either a list of ' + - 'SpikeTrains or a list of (id, SpikeTrain) pairs') + raise TypeError( + "spiketrains must be either a list of " + + "SpikeTrains or a list of (id, SpikeTrain) pairs" + ) # Bin the spike trains and take for each of them the ids of filled bins binned = conv.BinnedSpikeTrain( - trains, bin_size=bin_size, t_start=t_start, t_stop=t_stop) + trains, bin_size=bin_size, t_start=t_start, t_stop=t_stop + ) filled_bins = binned.spike_indices # Compute and return the transaction list - return [[train_id for train_id, b in zip(ids, filled_bins) - if bin_id in b] for bin_id in range(binned.n_bins)] + return [ + [train_id for train_id, b in zip(ids, filled_bins) if bin_id in b] + for bin_id in range(binned.n_bins) + ] def _analog_signal_step_interp(signal, times): @@ -347,8 +355,8 @@ def _analog_signal_step_interp(signal, times): # Compute the ids of the signal times to the left of each time in times time_ids = np.floor( - ((times - signal.t_start) / dt).rescale( - pq.dimensionless).magnitude).astype('i') + ((times - signal.t_start) / dt).rescale(pq.dimensionless).magnitude + ).astype("i") return (signal.magnitude[time_ids] * signal.units).rescale(signal.units) @@ -358,8 +366,9 @@ def _analog_signal_step_interp(signal, times): # ============================================================================= -def _stretched_metric_2d(x, y, stretch, ref_angle, working_memory=None, - mapped_array_file=None, verbose=None): +def _stretched_metric_2d( + x, y, stretch, ref_angle, working_memory=None, mapped_array_file=None, verbose=None +): r""" Given a list of points on the real plane, identified by their abscissa `x` and ordinate `y`, compute a stretched transformation of the Euclidean @@ -432,12 +441,15 @@ def _stretched_metric_2d(x, y, stretch, ref_angle, working_memory=None, """ if verbose is not None: - warnings.warn("The 'verbose' parameter is deprecated and will be " - "removed in the future. Its functionality is still " - "available by using the logging module from Python. " - "We recommend transitioning to the logging module " - "for improved control and flexibility in handling " - "verbosity levels.", DeprecationWarning) + warnings.warn( + "The 'verbose' parameter is deprecated and will be " + "removed in the future. Its functionality is still " + "available by using the logging module from Python. " + "We recommend transitioning to the logging module " + "for improved control and flexibility in handling " + "verbosity levels.", + DeprecationWarning, + ) alpha = np.deg2rad(ref_angle) # reference angle in radians @@ -467,7 +479,6 @@ def calculate_stretch_mat(theta_mat, D_mat): return _stretch_mat if working_memory is None: - logger.info("Finding distances without chunking") # Compute the matrix D[i, j] of Euclidean distances among points @@ -493,8 +504,7 @@ def calculate_stretch_mat(theta_mat, D_mat): # processed per iteration. Mininum is 1. Working memory size is in MB. # Size is computed for a float32 matrix. 
The function # `pairwise_distances_chunked` returns half of the possible size. - estimated_chunk = max( - ((working_memory * 1024 * 1024) // (len(y) * 4)) // 2, 1) + estimated_chunk = max(((working_memory * 1024 * 1024) // (len(y) * 4)) // 2, 1) # The number of rows in a chunk cannot be larger than the maximum estimated_chunk = min(len(x), estimated_chunk) @@ -508,9 +518,11 @@ def calculate_stretch_mat(theta_mat, D_mat): if last_chunk > 0: it_todo += 1 - logger.info(f"Estimated chunk size: {estimated_chunk}; " - f"Dimension: ({len(x)}, {len(y)}), " - f"Number of chunked iterations: {it_todo}") + logger.info( + f"Estimated chunk size: {estimated_chunk}; " + f"Dimension: ({len(x)}, {len(y)}), " + f"Number of chunked iterations: {it_todo}" + ) # x and y sizes are the same if mapped_array_file is None: @@ -519,11 +531,13 @@ def calculate_stretch_mat(theta_mat, D_mat): try: stretch_mat = np.empty((len(x), len(y)), dtype=np.float32) except MemoryError: - required_size = (len(x) * len(y) * 4) / (1024 ** 3) - raise MemoryError("Can't allocate array in memory. Specify " - "a temporary disk file to map the array " - "to the disk. Operations will be slower. " - f"The required size is {required_size} GiB") + required_size = (len(x) * len(y) * 4) / (1024**3) + raise MemoryError( + "Can't allocate array in memory. Specify " + "a temporary disk file to map the array " + "to the disk. Operations will be slower. " + f"The required size is {required_size} GiB" + ) else: # Using an array mapped to disk. Store in the file passed as @@ -531,33 +545,36 @@ def calculate_stretch_mat(theta_mat, D_mat): logger.info(f"Creating disk array at '{mapped_array_file.name}'.") - stretch_mat = np.memmap(mapped_array_file, mode='w+', - shape=(len(x), len(y)), - dtype=np.float32) + stretch_mat = np.memmap( + mapped_array_file, mode="w+", shape=(len(x), len(y)), dtype=np.float32 + ) # Buffer to store the computations per iteration, to avoid # writing to the file after every single operation chunk_mat = np.empty((estimated_chunk, len(y)), dtype=np.float32) for D_chunk in tqdm( - pairwise_distances_chunked(points, - working_memory=working_memory), - desc='Pairwise distances chunked', - total=it_todo): - + pairwise_distances_chunked(points, working_memory=working_memory), + desc="Pairwise distances chunked", + total=it_todo, + ): chunk_size = D_chunk.shape[0] - assert (chunk_size == estimated_chunk or - chunk_size == last_chunk) # Safety check + assert ( + chunk_size == estimated_chunk or chunk_size == last_chunk + ) # Safety check - dX = x_array[:, start: start + chunk_size].T - x_array - dY = y_array[:, start: start + chunk_size].T - y_array + dX = x_array[:, start : start + chunk_size].T - x_array + dY = y_array[:, start : start + chunk_size].T - y_array # If not using an array mapped to the disk, the output of # the theta computations are written directly to the # stretch_mat. 
Otherwise, write to the buffer - out = stretch_mat[start: start + chunk_size, :] \ - if mapped_array_file is None else chunk_mat[:chunk_size, :] + out = ( + stretch_mat[start : start + chunk_size, :] + if mapped_array_file is None + else chunk_mat[:chunk_size, :] + ) theta_chunk = np.arctan2(dY, dX, out=out) @@ -566,7 +583,7 @@ def calculate_stretch_mat(theta_mat, D_mat): # If mapping to file, transfer from the buffer to stretch_mat if mapped_array_file is not None: - stretch_mat[start: start + chunk_size, :] = theta_chunk + stretch_mat[start : start + chunk_size, :] = theta_chunk start += chunk_size @@ -579,25 +596,36 @@ def _interpolate_signals(signals, sampling_times, verbose=None): Interpolate signals at given sampling times. """ if verbose is not None: - warnings.warn("The 'verbose' parameter is deprecated and will be " - "removed in the future. Its functionality is still " - "available by using the logging module from Python. " - "We recommend transitioning to the logging module " - "for improved control and flexibility in handling " - "verbosity levels.", DeprecationWarning) + warnings.warn( + "The 'verbose' parameter is deprecated and will be " + "removed in the future. Its functionality is still " + "available by using the logging module from Python. " + "We recommend transitioning to the logging module " + "for improved control and flexibility in handling " + "verbosity levels.", + DeprecationWarning, + ) # Reshape all signals to one-dimensional array object (e.g. AnalogSignal) for i, signal in enumerate(signals): if signal.ndim == 2: signals[i] = signal.flatten() elif signal.ndim > 2: - raise ValueError('elements in fir_rates must have 2 dimensions') + raise ValueError("elements in fir_rates must have 2 dimensions") logger.info("Create time slices of the rates...") # Interpolate in the time bins - interpolated_signal = np.vstack([_analog_signal_step_interp( - signal, sampling_times).rescale('Hz').magnitude - for signal in signals]) * pq.Hz + interpolated_signal = ( + np.vstack( + [ + _analog_signal_step_interp(signal, sampling_times) + .rescale("Hz") + .magnitude + for signal in signals + ] + ) + * pq.Hz + ) return interpolated_signal @@ -639,8 +667,8 @@ def _choose_backend(self): # If CUDA is detected, always use CUDA. # If OpenCL is detected, don't use it by default to avoid the system # becoming unresponsive until the program terminates. 
- use_cuda = int(os.getenv("ELEPHANT_USE_CUDA", '1')) - use_opencl = int(os.getenv("ELEPHANT_USE_OPENCL", '1')) + use_cuda = int(os.getenv("ELEPHANT_USE_CUDA", "1")) + use_opencl = int(os.getenv("ELEPHANT_USE_OPENCL", "1")) cuda_detected = get_cuda_capability_major() != 0 opencl_detected = get_opencl_capability() @@ -655,9 +683,11 @@ def _split_axis(self, chunk_size, axis_size, min_chunk_size=None): if self.max_chunk_size is not None: chunk_size = min(chunk_size, self.max_chunk_size) if min_chunk_size is not None and chunk_size < min_chunk_size: - raise ValueError(f"[GPU not enough memory] Impossible to split " - f"the array into chunks of size at least " - f"{min_chunk_size} to fit into GPU memory") + raise ValueError( + f"[GPU not enough memory] Impossible to split " + f"the array into chunks of size at least " + f"{min_chunk_size} to fit into GPU memory" + ) n_chunks = math.ceil(axis_size / chunk_size) chunk_size = math.ceil(axis_size / n_chunks) # align in size if min_chunk_size is not None: @@ -675,19 +705,30 @@ def _split_axis(self, chunk_size, axis_size, min_chunk_size=None): class _JSFUniformOrderStat3D(_GPUBackend): - def __init__(self, n, d, precision='float', verbose=None, - cuda_threads=64, cuda_cwr_loops=32, tolerance=1e-5, - max_chunk_size=None): + def __init__( + self, + n, + d, + precision="float", + verbose=None, + cuda_threads=64, + cuda_cwr_loops=32, + tolerance=1e-5, + max_chunk_size=None, + ): super().__init__(max_chunk_size=max_chunk_size) if d > n: raise ValueError(f"d ({d}) must be less or equal n ({n})") if verbose is not None: - warnings.warn("The 'verbose' parameter is deprecated and will be " - "removed in the future. Its functionality is still " - "available by using the logging module from Python. " - "We recommend transitioning to the logging module " - "for improved control and flexibility in handling " - "verbosity levels.", DeprecationWarning) + warnings.warn( + "The 'verbose' parameter is deprecated and will be " + "removed in the future. Its functionality is still " + "available by using the logging module from Python. " + "We recommend transitioning to the logging module " + "for improved control and flexibility in handling " + "verbosity levels.", + DeprecationWarning, + ) self.n = n self.d = d self.precision = precision @@ -750,14 +791,16 @@ def _combinations_with_replacement(self): sequence_sorted[-1] = last_element yield tuple(sequence_sorted) increment_id = self.d - 2 - while increment_id > 0 and sequence_sorted[increment_id - 1] == \ - sequence_sorted[increment_id]: + while ( + increment_id > 0 + and sequence_sorted[increment_id - 1] == sequence_sorted[increment_id] + ): increment_id -= 1 - sequence_sorted[increment_id + 1:] = input_order[increment_id + 1:] + sequence_sorted[increment_id + 1 :] = input_order[increment_id + 1 :] sequence_sorted[increment_id] += 1 def cpu(self, log_du): - log_1 = np.log(1.) 
+ log_1 = np.log(1.0) # Compute the log of the integral's coefficient logK = np.sum(np.log(np.arange(1, self.n + 1))) # Add to the 3D matrix u a bottom layer equal to 0 and a @@ -778,13 +821,16 @@ def cpu(self, log_du): # initialise probabilities to 0 P_total = np.zeros( log_du.shape[0], - dtype=np.float32 if self.precision == 'float' else np.float64 + dtype=np.float32 if self.precision == "float" else np.float64, ) for iter_id, matrix_entries in enumerate( - tqdm(self._combinations_with_replacement(), - total=self.num_iterations, - desc="Joint survival function")): + tqdm( + self._combinations_with_replacement(), + total=self.num_iterations, + desc="Joint survival function", + ) + ): # if we are running with MPI if mpi_accelerated and iter_id % size != rank: continue @@ -819,12 +865,10 @@ def cpu(self, log_du): totals = np.zeros_like(P_total) # exchange all the results - mpi_float_type = MPI.FLOAT \ - if self.precision == 'float' else MPI.DOUBLE + mpi_float_type = MPI.FLOAT if self.precision == "float" else MPI.DOUBLE comm.Allreduce( - [P_total, mpi_float_type], - [totals, mpi_float_type], - op=MPI.SUM) + [P_total, mpi_float_type], [totals, mpi_float_type], op=MPI.SUM + ) # We need to return the collected totals instead of the local # P_total @@ -834,12 +878,16 @@ def cpu(self, log_du): def _compile_template(self, template_name, **kwargs): from jinja2 import Template + cu_template_path = Path(__file__).parent / template_name cu_template = Template(cu_template_path.read_text()) asset_cu = cu_template.render( precision=self.precision, CWR_LOOPS=self.cuda_cwr_loops, - N=self.n, D=self.d, **kwargs) + N=self.n, + D=self.d, + **kwargs, + ) return asset_cu def pyopencl(self, log_du, device_id=0): @@ -859,39 +907,37 @@ def pyopencl(self, log_du, device_id=0): # A queue bounded to the device queue = cl.CommandQueue(context) - max_l_block = device.local_mem_size // ( - self.dtype.itemsize * (self.d + 2)) - n_threads = min(self.cuda_threads, max_l_block, - device.max_work_group_size) + max_l_block = device.local_mem_size // (self.dtype.itemsize * (self.d + 2)) + n_threads = min(self.cuda_threads, max_l_block, device.max_work_group_size) if n_threads > 32: # It's more efficient to make the number of threads # a multiple of the warp size (32). 
n_threads -= n_threads % 32 - iteration_table_str = ", ".join(f"{val}LU" for val in - self.map_iterations.flatten()) + iteration_table_str = ", ".join( + f"{val}LU" for val in self.map_iterations.flatten() + ) iteration_table_str = "{%s}" % iteration_table_str log_factorial = np.r_[0, np.cumsum(np.log(range(1, self.n + 1)))] logK = log_factorial[-1] log_factorial_str = ", ".join(f"{val:.10f}" for val in log_factorial) log_factorial_str = "{%s}" % log_factorial_str - atomic_int = 'int' if self.precision == 'float' else 'long' + atomic_int = "int" if self.precision == "float" else "long" # GPU_MAX_HEAP_SIZE OpenCL flag is set to 2 Gb (1 << 31) by default - mem_avail = min(device.max_mem_alloc_size, device.global_mem_size, - 1 << 31) + mem_avail = min(device.max_mem_alloc_size, device.global_mem_size, 1 << 31) # 4 * (D + 1) * size + 8 * size == mem_avail chunk_size = mem_avail // (4 * log_du.shape[1] + self.dtype.itemsize) - chunk_size, split_idx = self._split_axis(chunk_size=chunk_size, - axis_size=u_length) + chunk_size, split_idx = self._split_axis( + chunk_size=chunk_size, axis_size=u_length + ) P_total = np.empty(u_length, dtype=self.dtype) P_total_gpu = cl_array.Array(queue, shape=chunk_size, dtype=self.dtype) for i_start, i_end in split_idx: - log_du_gpu = cl_array.to_device(queue, log_du[i_start: i_end], - async_=True) + log_du_gpu = cl_array.to_device(queue, log_du[i_start:i_end], async_=True) P_total_gpu.fill(0, queue=queue) chunk_size = i_end - i_start l_block = min(n_threads, chunk_size) @@ -904,9 +950,11 @@ def pyopencl(self, log_du, device_id=0): # grid_size must be at least l_num_blocks grid_size = l_num_blocks - logger.info(f"[Joint prob. matrix] it_todo={it_todo}, " - f"grid_size={grid_size}, L_BLOCK={l_block}, " - f"N_THREADS={n_threads}") + logger.info( + f"[Joint prob. 
matrix] it_todo={it_todo}, " + f"grid_size={grid_size}, L_BLOCK={l_block}, " + f"N_THREADS={n_threads}" + ) # OpenCL defines unsigned long as uint64, therefore we're adding # the LU suffix, not LLU, which would indicate unsupported uint128 @@ -921,7 +969,7 @@ def pyopencl(self, log_du, device_id=0): iteration_table=iteration_table_str, log_factorial=log_factorial_str, ATOMIC_UINT=f"unsigned {atomic_int}", - ASSET_ENABLE_DOUBLE_SUPPORT=int(self.precision == "double") + ASSET_ENABLE_DOUBLE_SUPPORT=int(self.precision == "double"), ) program = cl.Program(context, asset_cl).build() @@ -930,10 +978,16 @@ def pyopencl(self, log_du, device_id=0): cl.enqueue_barrier(queue) kernel = program.jsf_uniform_orderstat_3d_kernel - kernel(queue, (grid_size,), (n_threads,), - P_total_gpu.data, log_du_gpu.data, g_times_l=True) + kernel( + queue, + (grid_size,), + (n_threads,), + P_total_gpu.data, + log_du_gpu.data, + g_times_l=True, + ) - P_total_gpu[:chunk_size].get(ary=P_total[i_start: i_end]) + P_total_gpu[:chunk_size].get(ary=P_total[i_start:i_end]) return P_total @@ -946,8 +1000,7 @@ def pycuda(self, log_du): import pycuda.driver as drv from pycuda.compiler import SourceModule except ImportError as err: - raise ImportError( - "Install pycuda with 'pip install pycuda'") from err + raise ImportError("Install pycuda with 'pip install pycuda'") from err self._check_input(log_du) @@ -957,9 +1010,9 @@ def pycuda(self, log_du): device = pycuda.autoinit.device max_l_block = device.MAX_SHARED_MEMORY_PER_BLOCK // ( - self.dtype.itemsize * (self.d + 2)) - n_threads = min(self.cuda_threads, max_l_block, - device.MAX_THREADS_PER_BLOCK) + self.dtype.itemsize * (self.d + 2) + ) + n_threads = min(self.cuda_threads, max_l_block, device.MAX_THREADS_PER_BLOCK) if n_threads > device.WARP_SIZE: # It's more efficient to make the number of threads # a multiple of the warp size (32). @@ -972,15 +1025,16 @@ def pycuda(self, log_du): free, total = drv.mem_get_info() # 4 * (D + 1) * size + 8 * size == mem_avail chunk_size = free // (4 * log_du.shape[1] + self.dtype.itemsize) - chunk_size, split_idx = self._split_axis(chunk_size=chunk_size, - axis_size=u_length) + chunk_size, split_idx = self._split_axis( + chunk_size=chunk_size, axis_size=u_length + ) P_total = np.empty(u_length, dtype=self.dtype) P_total_gpu = gpuarray.GPUArray(chunk_size, dtype=self.dtype) log_du_gpu = drv.mem_alloc(4 * chunk_size * log_du.shape[1]) for i_start, i_end in split_idx: - drv.memcpy_htod_async(dest=log_du_gpu, src=log_du[i_start: i_end]) + drv.memcpy_htod_async(dest=log_du_gpu, src=log_du[i_start:i_end]) P_total_gpu.fill(0) chunk_size = i_end - i_start l_block = min(n_threads, chunk_size) @@ -994,9 +1048,11 @@ def pycuda(self, log_du): # grid_size must be at least l_num_blocks grid_size = l_num_blocks - logger.info(f"[Joint prob. matrix] it_todo={it_todo}, " - f"grid_size={grid_size}, L_BLOCK={l_block}, " - f"N_THREADS={n_threads}") + logger.info( + f"[Joint prob. 
matrix] it_todo={it_todo}, " + f"grid_size={grid_size}, L_BLOCK={l_block}, " + f"N_THREADS={n_threads}" + ) asset_cu = self._compile_template( template_name="joint_pmat.cu", @@ -1019,10 +1075,14 @@ def pycuda(self, log_du): drv.Context.synchronize() kernel = module.get_function("jsf_uniform_orderstat_3d_kernel") - kernel(P_total_gpu.gpudata, log_du_gpu, grid=(grid_size, 1), - block=(n_threads, 1, 1)) + kernel( + P_total_gpu.gpudata, + log_du_gpu, + grid=(grid_size, 1), + block=(n_threads, 1, 1), + ) - P_total_gpu[:chunk_size].get(ary=P_total[i_start: i_end]) + P_total_gpu[:chunk_size].get(ary=P_total[i_start:i_end]) return P_total @@ -1041,40 +1101,41 @@ def _cuda(self, log_du): template_name="joint_pmat_old.cu", L=f"{log_du.shape[0]}LLU", N_THREADS=self.cuda_threads, - ITERATIONS_TODO=f"{self.num_iterations}LLU" + ITERATIONS_TODO=f"{self.num_iterations}LLU", ) with tempfile.TemporaryDirectory() as asset_tmp_folder: - asset_cu_path = os.path.join(asset_tmp_folder, 'asset.cu') - asset_bin_path = os.path.join(asset_tmp_folder, 'asset.o') - with open(asset_cu_path, 'w') as f: + asset_cu_path = os.path.join(asset_tmp_folder, "asset.cu") + asset_bin_path = os.path.join(asset_tmp_folder, "asset.o") + with open(asset_cu_path, "w") as f: f.write(asset_cu) # -O3 optimization flag is for the host code only; # by default, GPU device code is optimized with -O3. # -w to ignore warnings. - compile_cmd = ['nvcc', '-w', '-O3', '-o', asset_bin_path, - asset_cu_path] - if self.precision == 'double' and get_cuda_capability_major() >= 6: + compile_cmd = ["nvcc", "-w", "-O3", "-o", asset_bin_path, asset_cu_path] + if self.precision == "double" and get_cuda_capability_major() >= 6: # atomicAdd(double) requires compute capability 6.x - compile_cmd.extend(['-arch', 'sm_60']) + compile_cmd.extend(["-arch", "sm_60"]) compile_status = subprocess.run( - compile_cmd, - stdout=subprocess.PIPE, stderr=subprocess.PIPE) + compile_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) logger.info(compile_status.stdout.decode()) logger.info(compile_status.stderr.decode()) compile_status.check_returncode() log_du_path = os.path.join(asset_tmp_folder, "log_du.dat") P_total_path = os.path.join(asset_tmp_folder, "P_total.dat") - with open(log_du_path, 'wb') as f: + with open(log_du_path, "wb") as f: log_du.tofile(f) run_status = subprocess.run( [asset_bin_path, log_du_path, P_total_path], - stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) logger.info(run_status.stdout.decode()) logger.info(run_status.stderr.decode()) run_status.check_returncode() - with open(P_total_path, 'rb') as f: + with open(P_total_path, "rb") as f: P_total = np.fromfile(f, dtype=self.dtype) return P_total @@ -1082,21 +1143,28 @@ def _cuda(self, log_du): def _check_input(self, log_du): it_todo = self.num_iterations if it_todo > np.iinfo(np.uint64).max: - raise ValueError(f"it_todo ({it_todo}) is larger than MAX_UINT64." - " Only Python backend is supported.") + raise ValueError( + f"it_todo ({it_todo}) is larger than MAX_UINT64." + " Only Python backend is supported." + ) # Don't convert log_du to float32 transparently for the user to avoid # situations when the user accidentally passes an array with float64. # Doing so wastes memory for nothing. 
if log_du.dtype != np.float32: raise ValueError("'log_du' must be a float32 array") if log_du.shape[1] != self.d + 1: - raise ValueError(f"log_du.shape[1] ({log_du.shape[1]}) must be " - f"equal to D+1 ({self.d + 1})") + raise ValueError( + f"log_du.shape[1] ({log_du.shape[1]}) must be " + f"equal to D+1 ({self.d + 1})" + ) def compute(self, u): if u.shape[1] != self.d: - raise ValueError("Invalid input data shape axis 1: expected {}, " - "got {}".format(self.d, u.shape[1])) + raise ValueError( + "Invalid input data shape axis 1: expected {}, " "got {}".format( + self.d, u.shape[1] + ) + ) # A faster and memory efficient implementation of # du = np.diff(u, prepend=0, append=1, axis=1).astype(np.float32) du = np.empty((u.shape[0], u.shape[1] + 1), dtype=np.float32) @@ -1110,7 +1178,7 @@ def compute(self, u): # the remaining infinities correctly evaluate to # exp(ln(0)) = exp(-inf) = 0 with warnings.catch_warnings(): - warnings.simplefilter('ignore', RuntimeWarning) + warnings.simplefilter("ignore", RuntimeWarning) log_du = np.log(du, out=du) jsf_backend = self._choose_backend() @@ -1122,15 +1190,17 @@ def compute(self, u): outside_vals = P_total[~inside] if len(outside_vals) > 0: # A watchdog for unexpected results. - warnings.warn(f"{len(outside_vals)}/{P_total.shape[0]} values of " - "the computed joint prob. matrix lie outside of the " - f"valid [0, 1] interval:\n{outside_vals}\nIf you're " - "using PyOpenCL backend, make sure you've disabled " - "GPU Hangcheck as described here https://www.intel." - "com/content/www/us/en/docs/oneapi/installation-" - "guide-linux/2023-1/gpu-disable-hangcheck.html \n" - "Clipping the output array to 0 and 1.") - P_total = np.clip(P_total, a_min=0., a_max=1., out=P_total) + warnings.warn( + f"{len(outside_vals)}/{P_total.shape[0]} values of " + "the computed joint prob. matrix lie outside of the " + f"valid [0, 1] interval:\n{outside_vals}\nIf you're " + "using PyOpenCL backend, make sure you've disabled " + "GPU Hangcheck as described here https://www.intel." + "com/content/www/us/en/docs/oneapi/installation-" + "guide-linux/2023-1/gpu-disable-hangcheck.html \n" + "Clipping the output array to 0 and 1." + ) + P_total = np.clip(P_total, a_min=0.0, a_max=1.0, out=P_total) return P_total @@ -1145,36 +1215,41 @@ class _PMatNeighbors(_GPUBackend): The number of largest neighbors to collect for each entry in `mat`. """ - def __init__(self, filter_shape, n_largest, max_chunk_size=None, - verbose=None): + def __init__(self, filter_shape, n_largest, max_chunk_size=None, verbose=None): super().__init__(max_chunk_size=max_chunk_size) self.n_largest = n_largest self.max_chunk_size = max_chunk_size if verbose is not None: - warnings.warn("The 'verbose' parameter is deprecated and will be " - "removed in the future. Its functionality is still " - "available by using the logging module from Python. " - "We recommend transitioning to the logging module " - "for improved control and flexibility in handling " - "verbosity levels.", DeprecationWarning) + warnings.warn( + "The 'verbose' parameter is deprecated and will be " + "removed in the future. Its functionality is still " + "available by using the logging module from Python. 
" + "We recommend transitioning to the logging module " + "for improved control and flexibility in handling " + "verbosity levels.", + DeprecationWarning, + ) self.verbose = verbose filter_size, filter_width = filter_shape if filter_width >= filter_size: - raise ValueError('filter_shape width must be lower than length') + raise ValueError("filter_shape width must be lower than length") if not ((filter_width % 2) and (filter_size % 2)): warnings.warn( - 'The kernel is not centered on the datapoint in whose' - 'calculation it is used. Consider using odd values' - 'for both entries of filter_shape.') + "The kernel is not centered on the datapoint in whose" + "calculation it is used. Consider using odd values" + "for both entries of filter_shape." + ) # Construct the kernel filt = np.ones((filter_size, filter_size), dtype=bool) filt = np.triu(filt, -filter_width) filt = np.tril(filt, filter_width) if n_largest > len(filt.nonzero()[0]): - raise ValueError(f"Too small filter shape {filter_shape} to " - f"select {n_largest} largest elements.") + raise ValueError( + f"Too small filter shape {filter_shape} to " + f"select {n_largest} largest elements." + ) self.filter_kernel = filt @@ -1182,11 +1257,14 @@ def _check_input(self, mat): symmetric = np.all(np.diagonal(mat) == 0.5) # Check consistent arguments filter_size = self.filter_kernel.shape[0] - if (symmetric and mat.shape[0] < 2 * filter_size - 1) \ - or (not symmetric and min(mat.shape) < filter_size): - raise ValueError(f"'filter_shape' {self.filter_kernel.shape} is " - f"too large for the input matrix of shape " - f"{mat.shape}") + if (symmetric and mat.shape[0] < 2 * filter_size - 1) or ( + not symmetric and min(mat.shape) < filter_size + ): + raise ValueError( + f"'filter_shape' {self.filter_kernel.shape} is " + f"too large for the input matrix of shape " + f"{mat.shape}" + ) if mat.dtype != np.float32: raise ValueError("The input matrix dtype must be float32.") @@ -1209,37 +1287,38 @@ def pyopencl(self, mat): filt_rows = "{%s}" % ", ".join(f"{row}U" for row in filt_rows) filt_cols = "{%s}" % ", ".join(f"{col}U" for col in filt_cols) - lmat_padded = np.zeros((mat.shape[0], mat.shape[1], self.n_largest), - dtype=np.float32) + lmat_padded = np.zeros( + (mat.shape[0], mat.shape[1], self.n_largest), dtype=np.float32 + ) if symmetric: mat = mat[filt_size:] - lmat = lmat_padded[filt_size + filt_size // 2: -filt_size // 2 + 1] + lmat = lmat_padded[filt_size + filt_size // 2 : -filt_size // 2 + 1] else: - lmat = lmat_padded[filt_size // 2: -filt_size // 2 + 1] + lmat = lmat_padded[filt_size // 2 : -filt_size // 2 + 1] # GPU_MAX_HEAP_SIZE OpenCL flag is set to 2 Gb (1 << 31) by default - mem_avail = min(device.max_mem_alloc_size, device.global_mem_size, - 1 << 31) + mem_avail = min(device.max_mem_alloc_size, device.global_mem_size, 1 << 31) # 4 * size * n_cols * n_largest + 4 * (size + filt_size) * n_cols chunk_size = (mem_avail // 4 - filt_size * lmat.shape[1]) // ( - lmat.shape[1] * (self.n_largest + 1)) - chunk_size, split_idx = self._split_axis(chunk_size=chunk_size, - axis_size=lmat.shape[0], - min_chunk_size=filt_size) + lmat.shape[1] * (self.n_largest + 1) + ) + chunk_size, split_idx = self._split_axis( + chunk_size=chunk_size, axis_size=lmat.shape[0], min_chunk_size=filt_size + ) pmat_cl_path = Path(__file__).parent / "pmat_neighbors.cl" pmat_cl_template = Template(pmat_cl_path.read_text()) lmat_gpu = cl_array.Array( - queue, shape=(chunk_size, lmat.shape[1], self.n_largest), - dtype=np.float32 + queue, shape=(chunk_size, lmat.shape[1], 
self.n_largest), dtype=np.float32 ) - for i_start, i_end in tqdm(split_idx, total=len(split_idx), - desc="Largest neighbors OpenCL"): - mat_gpu = cl_array.to_device(queue, - mat[i_start: i_end + filt_size], - async_=True) + for i_start, i_end in tqdm( + split_idx, total=len(split_idx), desc="Largest neighbors OpenCL" + ): + mat_gpu = cl_array.to_device( + queue, mat[i_start : i_end + filt_size], async_=True + ) lmat_gpu.fill(0, queue=queue) chunk_size = i_end - i_start it_todo = chunk_size * (lmat.shape[1] - filt_size + 1) @@ -1252,7 +1331,7 @@ def pyopencl(self, mat): NONZERO_SIZE=self.filter_kernel.sum(), SYMMETRIC=int(symmetric), filt_rows=filt_rows, - filt_cols=filt_cols + filt_cols=filt_cols, ) program = cl.Program(context, pmat_neighbors_cl).build() @@ -1267,12 +1346,13 @@ def pyopencl(self, mat): # work items exactly matches the desired number of iterations. kernel(queue, (it_todo,), None, lmat_gpu.data, mat_gpu.data) - lmat_gpu[:chunk_size].get(ary=lmat[i_start: i_end]) + lmat_gpu[:chunk_size].get(ary=lmat[i_start:i_end]) return lmat_padded def pycuda(self, mat): from jinja2 import Template + try: # PyCuda should not be in requirements-extra because CPU limited # users won't be able to install Elephant. @@ -1281,8 +1361,7 @@ def pycuda(self, mat): import pycuda.driver as drv from pycuda.compiler import SourceModule except ImportError as err: - raise ImportError( - "Install pycuda with 'pip install pycuda'") from err + raise ImportError("Install pycuda with 'pip install pycuda'") from err # if the matrix is symmetric the diagonal was set to 0.5 # when computing the probability matrix @@ -1295,35 +1374,38 @@ def pycuda(self, mat): filt_size = self.filter_kernel.shape[0] filt_rows, filt_cols = self.filter_kernel.nonzero() - lmat_padded = np.zeros((mat.shape[0], mat.shape[1], self.n_largest), - dtype=np.float32) + lmat_padded = np.zeros( + (mat.shape[0], mat.shape[1], self.n_largest), dtype=np.float32 + ) if symmetric: mat = mat[filt_size:] - lmat = lmat_padded[filt_size + filt_size // 2: -filt_size // 2 + 1] + lmat = lmat_padded[filt_size + filt_size // 2 : -filt_size // 2 + 1] else: - lmat = lmat_padded[filt_size // 2: -filt_size // 2 + 1] + lmat = lmat_padded[filt_size // 2 : -filt_size // 2 + 1] free, total = drv.mem_get_info() # 4 * size * n_cols * n_largest + 4 * (size + filt_size) * n_cols chunk_size = (free // 4 - filt_size * lmat.shape[1]) // ( - lmat.shape[1] * (self.n_largest + 1)) - chunk_size, split_idx = self._split_axis(chunk_size=chunk_size, - axis_size=lmat.shape[0], - min_chunk_size=filt_size) + lmat.shape[1] * (self.n_largest + 1) + ) + chunk_size, split_idx = self._split_axis( + chunk_size=chunk_size, axis_size=lmat.shape[0], min_chunk_size=filt_size + ) pmat_cu_path = Path(__file__).parent / "pmat_neighbors.cu" pmat_cu_template = Template(pmat_cu_path.read_text()) lmat_gpu = gpuarray.GPUArray( - (chunk_size, lmat.shape[1], self.n_largest), dtype=np.float32) + (chunk_size, lmat.shape[1], self.n_largest), dtype=np.float32 + ) mat_gpu = drv.mem_alloc(4 * (chunk_size + filt_size) * mat.shape[1]) - for i_start, i_end in tqdm(split_idx, total=len(split_idx), - desc="Largest neighbors CUDA"): - drv.memcpy_htod_async(dest=mat_gpu, - src=mat[i_start: i_end + filt_size]) + for i_start, i_end in tqdm( + split_idx, total=len(split_idx), desc="Largest neighbors CUDA" + ): + drv.memcpy_htod_async(dest=mat_gpu, src=mat[i_start : i_end + filt_size]) lmat_gpu.fill(0) chunk_size = i_end - i_start it_todo = chunk_size * (lmat.shape[1] - filt_size + 1) @@ -1350,15 +1432,18 @@ def 
pycuda(self, mat): grid_size = math.ceil(it_todo / n_threads) if grid_size > device.MAX_GRID_DIM_X: - raise ValueError("Cannot launch a CUDA kernel with " - f"{grid_size} num. of blocks. Adjust the " - "'max_chunk_size' parameter.") + raise ValueError( + "Cannot launch a CUDA kernel with " + f"{grid_size} num. of blocks. Adjust the " + "'max_chunk_size' parameter." + ) kernel = module.get_function("pmat_neighbors") - kernel(lmat_gpu.gpudata, mat_gpu, grid=(grid_size, 1), - block=(n_threads, 1, 1)) + kernel( + lmat_gpu.gpudata, mat_gpu, grid=(grid_size, 1), block=(n_threads, 1, 1) + ) - lmat_gpu[:chunk_size].get(ary=lmat[i_start: i_end]) + lmat_gpu[:chunk_size].get(ary=lmat[i_start:i_end]) return lmat_padded @@ -1401,8 +1486,7 @@ def cpu(self, mat): filter_size = self.filter_kernel.shape[0] # Initialize the matrix of d-largest values as a matrix of zeroes - lmat = np.zeros((mat.shape[0], mat.shape[1], self.n_largest), - dtype=np.float32) + lmat = np.zeros((mat.shape[0], mat.shape[1], self.n_largest), dtype=np.float32) N_bin_y = mat.shape[0] N_bin_x = mat.shape[1] @@ -1415,22 +1499,22 @@ def cpu(self, mat): bin_range_x = range(N_bin_x - filter_size + 1) # compute matrix of largest values - for y in tqdm(bin_range_y, total=len(bin_range_y), - desc="Largest neighbors CPU"): + for y in tqdm( + bin_range_y, total=len(bin_range_y), desc="Largest neighbors CPU" + ): if symmetric: # x range depends on y position bin_range_x = range(y - filter_size + 1) for x in bin_range_x: - patch = mat[y: y + filter_size, x: x + filter_size] + patch = mat[y : y + filter_size, x : x + filter_size] mskd = patch[self.filter_kernel] - largest_vals = np.sort(mskd)[-self.n_largest:] - lmat[y + (filter_size // 2), x + (filter_size // 2), :] = \ - largest_vals + largest_vals = np.sort(mskd)[-self.n_largest :] + lmat[y + (filter_size // 2), x + (filter_size // 2), :] = largest_vals return lmat -def synchronous_events_intersection(sse1, sse2, intersection='linkwise'): +def synchronous_events_intersection(sse1, sse2, intersection="linkwise"): """ Given two sequences of synchronous events (SSEs) `sse1` and `sse2`, each consisting of a pool of positions `(iK, jK)` of matrix entries and @@ -1480,23 +1564,24 @@ def synchronous_events_intersection(sse1, sse2, intersection='linkwise'): if pixel1 not in sse2.keys(): del sse_new[pixel1] - if intersection == 'linkwise': + if intersection == "linkwise": for pixel1, link1 in sse_new.items(): sse_new[pixel1] = link1.intersection(sse2[pixel1]) if len(sse_new[pixel1]) == 0: del sse_new[pixel1] - elif intersection == 'pixelwise': + elif intersection == "pixelwise": # no action required pass else: raise ValueError( - "intersection (=%s) can only be" % intersection + - " 'pixelwise' or 'linkwise'") + "intersection (=%s) can only be" % intersection + + " 'pixelwise' or 'linkwise'" + ) return sse_new -def synchronous_events_difference(sse1, sse2, difference='linkwise'): +def synchronous_events_difference(sse1, sse2, difference="linkwise"): """ Given two sequences of synchronous events (SSEs) `sse1` and `sse2`, each consisting of a pool of pixel positions and associated synchronous events @@ -1546,16 +1631,17 @@ def synchronous_events_difference(sse1, sse2, difference='linkwise'): sse_new = sse1.copy() for pixel1 in sse1.keys(): if pixel1 in sse2.keys(): - if difference == 'pixelwise': + if difference == "pixelwise": del sse_new[pixel1] - elif difference == 'linkwise': + elif difference == "linkwise": sse_new[pixel1] = sse_new[pixel1].difference(sse2[pixel1]) if len(sse_new[pixel1]) == 0: del 
sse_new[pixel1] else: raise ValueError( - "difference (=%s) can only be" % difference + - " 'pixelwise' or 'linkwise'") + "difference (=%s) can only be" % difference + + " 'pixelwise' or 'linkwise'" + ) return sse_new @@ -1919,24 +2005,32 @@ def get_sse_start_and_end_time_bins(sse): def _signals_t_start_stop(signals, t_start=None, t_stop=None): if t_start is None: - t_start = _signals_same_attribute(signals, 't_start') + t_start = _signals_same_attribute(signals, "t_start") if t_stop is None: - t_stop = _signals_same_attribute(signals, 't_stop') + t_stop = _signals_same_attribute(signals, "t_stop") return t_start, t_stop -def _intersection_matrix(spiketrains, spiketrains_y, bin_size, t_start_x, - t_start_y, t_stop_x, t_stop_y, normalization=None): +def _intersection_matrix( + spiketrains, + spiketrains_y, + bin_size, + t_start_x, + t_start_y, + t_stop_x, + t_stop_y, + normalization=None, +): if spiketrains_y is None: spiketrains_y = spiketrains # Compute the binned spike train matrices, along both time axes spiketrains_binned = conv.BinnedSpikeTrain( - spiketrains, bin_size=bin_size, - t_start=t_start_x, t_stop=t_stop_x) + spiketrains, bin_size=bin_size, t_start=t_start_x, t_stop=t_stop_x + ) spiketrains_binned_y = conv.BinnedSpikeTrain( - spiketrains_y, bin_size=bin_size, - t_start=t_start_y, t_stop=t_stop_y) + spiketrains_y, bin_size=bin_size, t_start=t_start_y, t_stop=t_stop_y + ) # Compute imat by matrix multiplication bsts_x = spiketrains_binned.sparse_matrix @@ -1953,29 +2047,31 @@ def _intersection_matrix(spiketrains, spiketrains_y, bin_size, t_start_x, # Normalize the row col_sum = bsts_x[:, ii].sum() if normalization is None or col_sum == 0: - norm_coef = 1. - elif normalization == 'intersection': - norm_coef = np.minimum( - spikes_per_bin_x[ii], spikes_per_bin_y) - elif normalization == 'mean': + norm_coef = 1.0 + elif normalization == "intersection": + norm_coef = np.minimum(spikes_per_bin_x[ii], spikes_per_bin_y) + elif normalization == "mean": # geometric mean - norm_coef = np.sqrt( - spikes_per_bin_x[ii] * spikes_per_bin_y) - elif normalization == 'union': - norm_coef = np.array([(bsts_x[:, ii] - + bsts_y[:, jj]).count_nonzero() - for jj in range(bsts_y.shape[1])]) + norm_coef = np.sqrt(spikes_per_bin_x[ii] * spikes_per_bin_y) + elif normalization == "union": + norm_coef = np.array( + [ + (bsts_x[:, ii] + bsts_y[:, jj]).count_nonzero() + for jj in range(bsts_y.shape[1]) + ] + ) else: - raise ValueError( - "Invalid parameter 'norm': {}".format(normalization)) + raise ValueError("Invalid parameter 'norm': {}".format(normalization)) # If normalization required, for each j such that bsts_y[j] is # identically 0 the code above sets imat[:, j] to identically nan. # Substitute 0s instead. - imat[ii, :] = np.divide(imat[ii, :], norm_coef, - out=np.zeros(imat.shape[1], - dtype=np.float32), - where=norm_coef != 0) + imat[ii, :] = np.divide( + imat[ii, :], + norm_coef, + out=np.zeros(imat.shape[1], dtype=np.float32), + where=norm_coef != 0, + ) # Return the intersection matrix and the edges of the bins used for the # x and y axes, respectively. 
@@ -2029,17 +2125,28 @@ class ASSET(object): """ - def __init__(self, spiketrains_i, spiketrains_j=None, bin_size=3 * pq.ms, - t_start_i=None, t_start_j=None, t_stop_i=None, t_stop_j=None, - bin_tolerance='default', verbose=None): - + def __init__( + self, + spiketrains_i, + spiketrains_j=None, + bin_size=3 * pq.ms, + t_start_i=None, + t_start_j=None, + t_stop_i=None, + t_stop_j=None, + bin_tolerance="default", + verbose=None, + ): if verbose is not None: - warnings.warn("The 'verbose' parameter is deprecated and will be " - "removed in the future. Its functionality is still " - "available by using the logging module from Python. " - "We recommend transitioning to the logging module " - "for improved control and flexibility in handling " - "verbosity levels.", DeprecationWarning) + warnings.warn( + "The 'verbose' parameter is deprecated and will be " + "removed in the future. Its functionality is still " + "available by using the logging module from Python. " + "We recommend transitioning to the logging module " + "for improved control and flexibility in handling " + "verbosity levels.", + DeprecationWarning, + ) self.spiketrains_i = spiketrains_i if spiketrains_j is None: @@ -2047,29 +2154,29 @@ def __init__(self, spiketrains_i, spiketrains_j=None, bin_size=3 * pq.ms, self.spiketrains_j = spiketrains_j self.bin_size = bin_size self.t_start_i, self.t_stop_i = _signals_t_start_stop( - spiketrains_i, - t_start=t_start_i, - t_stop=t_stop_i) + spiketrains_i, t_start=t_start_i, t_stop=t_stop_i + ) self.t_start_j, self.t_stop_j = _signals_t_start_stop( - spiketrains_j, - t_start=t_start_j, - t_stop=t_stop_j) + spiketrains_j, t_start=t_start_j, t_stop=t_stop_j + ) self.verbose = verbose and rank == 0 - msg = 'The time intervals for x and y need to be either identical ' \ - 'or fully disjoint, but they are:\n' \ - 'x: ({}, {}) and y: ({}, {}).'.format(self.t_start_i, - self.t_stop_i, - self.t_start_j, - self.t_stop_j) + msg = ( + "The time intervals for x and y need to be either identical " + "or fully disjoint, but they are:\n" + "x: ({}, {}) and y: ({}, {}).".format( + self.t_start_i, self.t_stop_i, self.t_start_j, self.t_stop_j + ) + ) # the starts have to be perfectly aligned for the binning to work # the stops can differ without impacting the binning if self.t_start_i == self.t_start_j: if not _quantities_almost_equal(self.t_stop_i, self.t_stop_j): raise ValueError(msg) - elif (self.t_start_i < self.t_start_j < self.t_stop_i) \ - or (self.t_start_i < self.t_stop_j < self.t_stop_i): + elif (self.t_start_i < self.t_start_j < self.t_stop_i) or ( + self.t_start_i < self.t_stop_j < self.t_stop_i + ): raise ValueError(msg) # Define the tolerance parameter for binning. @@ -2077,16 +2184,25 @@ def __init__(self, spiketrains_i, spiketrains_j=None, bin_size=3 * pq.ms, # called without passing the parameter, and it will take what is # defined by the behavior of that class. 
Otherwise, set to the value # specified by `bin_tolerance` - tolerance_param = {'tolerance': bin_tolerance} if \ - bin_tolerance != 'default' else {} + tolerance_param = ( + {"tolerance": bin_tolerance} if bin_tolerance != "default" else {} + ) # Compute the binned spike train matrices, along both time axes self.spiketrains_binned_i = conv.BinnedSpikeTrain( - self.spiketrains_i, bin_size=self.bin_size, - t_start=self.t_start_i, t_stop=self.t_stop_i, **tolerance_param) + self.spiketrains_i, + bin_size=self.bin_size, + t_start=self.t_start_i, + t_stop=self.t_stop_i, + **tolerance_param, + ) self.spiketrains_binned_j = conv.BinnedSpikeTrain( - self.spiketrains_j, bin_size=self.bin_size, - t_start=self.t_start_j, t_stop=self.t_stop_j, **tolerance_param) + self.spiketrains_j, + bin_size=self.bin_size, + t_start=self.t_start_j, + t_stop=self.t_stop_j, + **tolerance_param, + ) @property def x_edges(self): @@ -2161,16 +2277,25 @@ def intersection_matrix(self, normalization=None): time was discretized in. """ - imat = _intersection_matrix(self.spiketrains_i, self.spiketrains_j, - self.bin_size, - self.t_start_i, self.t_start_j, - self.t_stop_i, self.t_stop_j, - normalization=normalization) + imat = _intersection_matrix( + self.spiketrains_i, + self.spiketrains_j, + self.bin_size, + self.t_start_i, + self.t_start_j, + self.t_stop_i, + self.t_stop_j, + normalization=normalization, + ) return imat - def probability_matrix_montecarlo(self, n_surrogates, imat=None, - surrogate_method='dither_spikes', - surrogate_dt=None): + def probability_matrix_montecarlo( + self, + n_surrogates, + imat=None, + surrogate_method="dither_spikes", + surrogate_dt=None, + ): """ Given a list of parallel spike trains, estimate the cumulative probability of each entry in their intersection matrix by a Monte Carlo @@ -2259,45 +2384,64 @@ def probability_matrix_montecarlo(self, n_surrogates, imat=None, for surr_id in trange(n_surrogates, desc="pmat_bootstrap"): if mpi_accelerated and surr_id % size != rank: continue - surrogates = [spike_train_surrogates.surrogates( - st, n_surrogates=1, - method=surrogate_method, - dt=surrogate_dt, - decimals=None, - edges=True)[0] - for st in self.spiketrains_i] + surrogates = [ + spike_train_surrogates.surrogates( + st, + n_surrogates=1, + method=surrogate_method, + dt=surrogate_dt, + decimals=None, + edges=True, + )[0] + for st in self.spiketrains_i + ] if symmetric: surrogates_y = surrogates else: - surrogates_y = [spike_train_surrogates.surrogates( - st, n_surrogates=1, method=surrogate_method, - dt=surrogate_dt, decimals=None, edges=True)[0] - for st in self.spiketrains_j] - - imat_surr = _intersection_matrix(surrogates, surrogates_y, - self.bin_size, - self.t_start_i, self.t_start_j, - self.t_stop_i, self.t_stop_j) + surrogates_y = [ + spike_train_surrogates.surrogates( + st, + n_surrogates=1, + method=surrogate_method, + dt=surrogate_dt, + decimals=None, + edges=True, + )[0] + for st in self.spiketrains_j + ] + + imat_surr = _intersection_matrix( + surrogates, + surrogates_y, + self.bin_size, + self.t_start_i, + self.t_start_j, + self.t_stop_i, + self.t_stop_j, + ) - pmat += (imat_surr <= (imat - 1)) + pmat += imat_surr <= (imat - 1) del imat_surr if mpi_accelerated: pmat = comm.allreduce(pmat, op=MPI.SUM) - pmat = pmat * 1. 
/ n_surrogates + pmat = pmat * 1.0 / n_surrogates if symmetric: np.fill_diagonal(pmat, 0.5) return pmat - def probability_matrix_analytical(self, imat=None, - firing_rates_x='estimate', - firing_rates_y='estimate', - kernel_width=100 * pq.ms): + def probability_matrix_analytical( + self, + imat=None, + firing_rates_x="estimate", + firing_rates_y="estimate", + kernel_width=100 * pq.ms, + ): r""" Given a list of spike trains, approximates the cumulative probability of each entry in their intersection matrix. @@ -2363,51 +2507,49 @@ def probability_matrix_analytical(self, imat=None, # Check that the nr. neurons is identical between the two axes if bsts_x_matrix.shape[0] != bsts_y_matrix.shape[0]: - raise ValueError( - 'Different number of neurons along the x and y axis!') + raise ValueError("Different number of neurons along the x and y axis!") # Define the firing rate profiles - if firing_rates_x == 'estimate': + if firing_rates_x == "estimate": # If rates are to be estimated, create the rate profiles as # Quantity objects obtained by boxcar-kernel convolution - fir_rate_x = self._rate_of_binned_spiketrain(bsts_x_matrix, - kernel_width) + fir_rate_x = self._rate_of_binned_spiketrain(bsts_x_matrix, kernel_width) elif isinstance(firing_rates_x, list): # If rates provided as lists of AnalogSignals, create time slices # for both axes, interpolate in the time bins of interest and # convert to Quantity fir_rate_x = _interpolate_signals( - firing_rates_x, self.spiketrains_binned_i.bin_edges[:-1]) + firing_rates_x, self.spiketrains_binned_i.bin_edges[:-1] + ) else: - raise ValueError( - 'fir_rates_x must be a list or the string "estimate"') + raise ValueError('fir_rates_x must be a list or the string "estimate"') if symmetric: fir_rate_y = fir_rate_x - elif firing_rates_y == 'estimate': - fir_rate_y = self._rate_of_binned_spiketrain(bsts_y_matrix, - kernel_width) + elif firing_rates_y == "estimate": + fir_rate_y = self._rate_of_binned_spiketrain(bsts_y_matrix, kernel_width) elif isinstance(firing_rates_y, list): # If rates provided as lists of AnalogSignals, create time slices # for both axes, interpolate in the time bins of interest and # convert to Quantity fir_rate_y = _interpolate_signals( - firing_rates_y, self.spiketrains_binned_j.bin_edges[:-1]) + firing_rates_y, self.spiketrains_binned_j.bin_edges[:-1] + ) else: - raise ValueError( - 'fir_rates_y must be a list or the string "estimate"') + raise ValueError('fir_rates_y must be a list or the string "estimate"') # For each neuron, compute the prob. that that neuron spikes in any bin - logger.info("Compute the probability that each neuron fires in " - "each pair of bins...") + logger.info( + "Compute the probability that each neuron fires in " "each pair of bins..." + ) rate_bins_x = (fir_rate_x * self.bin_size).simplified.magnitude - spike_probs_x = 1. - np.exp(-rate_bins_x) + spike_probs_x = 1.0 - np.exp(-rate_bins_x) if symmetric: spike_probs_y = spike_probs_x else: rate_bins_y = (fir_rate_y * self.bin_size).simplified.magnitude - spike_probs_y = 1. - np.exp(-rate_bins_y) + spike_probs_y = 1.0 - np.exp(-rate_bins_y) # Compute the matrix Mu[i, j] of parameters for the Poisson # distributions which describe, at each (i, j), the approximated @@ -2415,8 +2557,7 @@ def probability_matrix_analytical(self, imat=None, # matrices p_ijk computed for each neuron k: # p_ijk is the probability that neuron k spikes in both bins i and j. # The sum of outer products is equivalent to a dot product. 
- logger.info("Compute the probability matrix by Le Cam's " - "approximation...") + logger.info("Compute the probability matrix by Le Cam's " "approximation...") Mu = spike_probs_x.T.dot(spike_probs_y) # A straightforward implementation is: # pmat_shape = spike_probs_x.shape[1], spike_probs_y.shape[1] @@ -2430,16 +2571,22 @@ def probability_matrix_analytical(self, imat=None, if symmetric: # Substitute 0.5 to the elements along the main diagonal - logger.info("Substitute 0.5 to elements along the main " - "diagonal...") + logger.info("Substitute 0.5 to elements along the main " "diagonal...") np.fill_diagonal(pmat, 0.5) return pmat - def joint_probability_matrix(self, pmat, filter_shape, n_largest, - min_p_value=1e-5, precision='float', - cuda_threads=64, cuda_cwr_loops=32, - tolerance=1e-5): + def joint_probability_matrix( + self, + pmat, + filter_shape, + n_largest, + min_p_value=1e-5, + precision="float", + cuda_threads=64, + cuda_cwr_loops=32, + tolerance=1e-5, + ): """ Map a probability matrix `pmat` to a joint probability matrix `jmat`, where `jmat[i, j]` is the joint p-value of the largest neighbors of @@ -2542,37 +2689,38 @@ def joint_probability_matrix(self, pmat, filter_shape, n_largest, # Find for each P_ij in the probability matrix its neighbors and # maximize them by the maximum value 1-p_value_min pmat = np.asarray(pmat, dtype=np.float32) - pmat_neighb_obj = _PMatNeighbors(filter_shape=filter_shape, - n_largest=n_largest) + pmat_neighb_obj = _PMatNeighbors(filter_shape=filter_shape, n_largest=n_largest) pmat_neighb = pmat_neighb_obj.compute(pmat) logger.info("Finding unique set of values...") - pmat_neighb = np.minimum(pmat_neighb, 1. - min_p_value, - out=pmat_neighb) + pmat_neighb = np.minimum(pmat_neighb, 1.0 - min_p_value, out=pmat_neighb) # in order to avoid doing the same calculation multiple times: # find all unique sets of values in pmat_neighb # and store the corresponding indices # flatten the second and third dimension in order to use np.unique pmat_neighb = pmat_neighb.reshape(pmat.size, n_largest) - pmat_neighb, pmat_neighb_indices = np.unique(pmat_neighb, axis=0, - return_inverse=True) + pmat_neighb, pmat_neighb_indices = np.unique( + pmat_neighb, axis=0, return_inverse=True + ) # Compute the joint p-value matrix jpvmat - n = l * (1 + 2 * w) - w * ( - w + 1) # number of entries covered by kernel - jsf = _JSFUniformOrderStat3D(n=n, d=pmat_neighb.shape[1], - precision=precision, - cuda_threads=cuda_threads, - cuda_cwr_loops=cuda_cwr_loops, - tolerance=tolerance) + n = l * (1 + 2 * w) - w * (w + 1) # number of entries covered by kernel + jsf = _JSFUniformOrderStat3D( + n=n, + d=pmat_neighb.shape[1], + precision=precision, + cuda_threads=cuda_threads, + cuda_cwr_loops=cuda_cwr_loops, + tolerance=tolerance, + ) jpvmat = jsf.compute(u=pmat_neighb) # restore the original shape using the stored indices jpvmat = jpvmat[pmat_neighb_indices].reshape(pmat.shape) - return 1. 
- jpvmat + return 1.0 - jpvmat @staticmethod def mask_matrices(matrices, thresholds): @@ -2617,11 +2765,10 @@ def mask_matrices(matrices, thresholds): if isinstance(thresholds, float): thresholds = np.full(shape=len(matrices), fill_value=thresholds) if len(matrices) != len(thresholds): - raise ValueError( - '`matrices` and `thresholds` must have same length') + raise ValueError("`matrices` and `thresholds` must have same length") mask = np.ones_like(matrices[0], dtype=bool) - for (mat, thresh) in zip(matrices, thresholds): + for mat, thresh in zip(matrices, thresholds): mask &= mat > thresh # Replace nans, coming from False * np.inf, with zeros @@ -2630,9 +2777,16 @@ def mask_matrices(matrices, thresholds): return mask @staticmethod - def cluster_matrix_entries(mask_matrix, max_distance, min_neighbors, - stretch, working_memory=None, array_file=None, - keep_file=False, verbose=None): + def cluster_matrix_entries( + mask_matrix, + max_distance, + min_neighbors, + stretch, + working_memory=None, + array_file=None, + keep_file=False, + verbose=None, + ): r""" Given a matrix `mask_matrix`, replaces its positive elements with integers representing different cluster IDs. Each cluster comprises @@ -2728,12 +2882,15 @@ def cluster_matrix_entries(mask_matrix, max_distance, min_neighbors, """ if verbose is not None: - warnings.warn("The 'verbose' parameter is deprecated and will be " - "removed in the future. Its functionality is still " - "available by using the logging module from Python. " - "We recommend transitioning to the logging module " - "for improved control and flexibility in handling " - "verbosity levels.", DeprecationWarning) + warnings.warn( + "The 'verbose' parameter is deprecated and will be " + "removed in the future. Its functionality is still " + "available by using the logging module from Python. " + "We recommend transitioning to the logging module " + "for improved control and flexibility in handling " + "verbosity levels.", + DeprecationWarning, + ) # Don't do anything if mat is identically zero if np.all(mask_matrix == 0): return mask_matrix @@ -2744,34 +2901,38 @@ def cluster_matrix_entries(mask_matrix, max_distance, min_neighbors, # Allocate temporary file if requested mapped_array_file = None if array_file: - file_path = Path(array_file) if isinstance(array_file, str) \ - else array_file + file_path = Path(array_file) if isinstance(array_file, str) else array_file file_dir = file_path.parent file_name = file_path.stem mapped_array_file = tempfile.NamedTemporaryFile( - prefix=file_name, dir=file_dir, - delete=not keep_file) + prefix=file_name, dir=file_dir, delete=not keep_file + ) # Compute the matrix D[i, j] of euclidean distances between pixels i # and j try: D = _stretched_metric_2d( - xpos_sgnf, ypos_sgnf, stretch=stretch, ref_angle=45, + xpos_sgnf, + ypos_sgnf, + stretch=stretch, + ref_angle=45, working_memory=working_memory, - mapped_array_file=mapped_array_file) + mapped_array_file=mapped_array_file, + ) except MemoryError as err: - raise MemoryError("Set 'working_memory=100' or another value to " - "chunk the data. If this does not solve, use the" - " 'array_file' parameter to pass a location for " - "a temporary file to map the array to the disk." - ) from err + raise MemoryError( + "Set 'working_memory=100' or another value to " + "chunk the data. If this does not solve, use the" + " 'array_file' parameter to pass a location for " + "a temporary file to map the array to the disk." 
+ ) from err logger.info("Running DBSCAN") # Cluster positions of significant pixels via dbscan core_samples, config = dbscan( - D, eps=max_distance, min_samples=min_neighbors, - metric='precomputed') + D, eps=max_distance, min_samples=min_neighbors, metric="precomputed" + ) logger.info("Building cluster matrix") @@ -2780,8 +2941,9 @@ def cluster_matrix_entries(mask_matrix, max_distance, min_neighbors, # * 0 if it is not significant, # * -1 if it is significant but does not belong to any cluster cluster_mat = np.zeros_like(mask_matrix, dtype=np.int32) - cluster_mat[xpos_sgnf, ypos_sgnf] = \ - config * (config == -1) + (config + 1) * (config >= 0) + cluster_mat[xpos_sgnf, ypos_sgnf] = config * (config == -1) + (config + 1) * ( + config >= 0 + ) return cluster_mat @@ -2829,9 +2991,12 @@ def extract_synchronous_events(self, cmat, ids=None): # Compute the transactions associated to the two binnings logger.info("Finding transactions") tracts_x = _transactions( - self.spiketrains_i, bin_size=self.bin_size, t_start=self.t_start_i, + self.spiketrains_i, + bin_size=self.bin_size, + t_start=self.t_start_i, t_stop=self.t_stop_i, - ids=ids) + ids=ids, + ) if self.spiketrains_j is self.spiketrains_i or self.is_symmetric(): diag_id = 0 @@ -2839,25 +3004,26 @@ def extract_synchronous_events(self, cmat, ids=None): else: diag_id = None tracts_y = _transactions( - self.spiketrains_j, bin_size=self.bin_size, - t_start=self.t_start_j, t_stop=self.t_stop_j, ids=ids) + self.spiketrains_j, + bin_size=self.bin_size, + t_start=self.t_start_j, + t_stop=self.t_stop_j, + ids=ids, + ) # Reconstruct each worm, link by link sse_dict = {} - for k in tqdm(range(1, nr_worms + 1), - total=nr_worms, - desc="Extracting SSEs"): # for each worm + for k in tqdm( + range(1, nr_worms + 1), total=nr_worms, desc="Extracting SSEs" + ): # for each worm # worm k is a list of links (each link will be 1 sublist) worm_k = {} - pos_worm_k = np.array( - np.where(cmat == k)).T # position of all links + pos_worm_k = np.array(np.where(cmat == k)).T # position of all links # if no link lies on the reference diagonal if all([y - x != diag_id for (x, y) in pos_worm_k]): for bin_x, bin_y in pos_worm_k: # for each link - # reconstruct the link - link_l = set(tracts_x[bin_x]).intersection( - tracts_y[bin_y]) + link_l = set(tracts_x[bin_x]).intersection(tracts_y[bin_y]) # and assign it to its pixel worm_k[(bin_x, bin_y)] = link_l @@ -2876,9 +3042,10 @@ def _rate_of_binned_spiketrain(self, binned_spiketrains, kernel_width): # Create the boxcar kernel and convolve it with the binned spike trains k = int((kernel_width / self.bin_size).simplified.item()) - kernel = np.full(k, fill_value=1. / k) - rate = np.vstack([np.convolve(bst, kernel, mode='same') - for bst in binned_spiketrains]) + kernel = np.full(k, fill_value=1.0 / k) + rate = np.vstack( + [np.convolve(bst, kernel, mode="same") for bst in binned_spiketrains] + ) # The convolution results in an array decreasing at the borders due # to absence of spikes beyond the borders. Replace the first and last @@ -2889,6 +3056,6 @@ def _rate_of_binned_spiketrain(self, binned_spiketrains, kernel_width): rate[i, -k2:] = rate[i, -k2 - 1] # Multiply the firing rates by the proper unit - rate = rate * (1. 
/ self.bin_size).rescale('Hz') + rate = rate * (1.0 / self.bin_size).rescale("Hz") return rate diff --git a/elephant/causality/granger.py b/elephant/causality/granger.py index 673161373..f0de5f9d1 100644 --- a/elephant/causality/granger.py +++ b/elephant/causality/granger.py @@ -95,16 +95,20 @@ "Causality", "pairwise_granger", "conditional_granger", - "pairwise_spectral_granger" + "pairwise_spectral_granger", ) # the return type of pairwise_granger(), pairwise_spectral_granger() function -Causality = namedtuple('Causality', - ['directional_causality_x_y', - 'directional_causality_y_x', - 'instantaneous_causality', - 'total_interdependence']) +Causality = namedtuple( + "Causality", + [ + "directional_causality_x_y", + "directional_causality_y_x", + "instantaneous_causality", + "total_interdependence", + ], +) def _bic(cov, order, dimension, length): @@ -128,8 +132,7 @@ def _bic(cov, order, dimension, length): Bayesian Information Criterion """ sign, log_det_cov = np.linalg.slogdet(cov) - criterion = 2 * log_det_cov \ - + 2*(dimension**2)*order*np.log(length)/length + criterion = 2 * log_det_cov + 2 * (dimension**2) * order * np.log(length) / length return criterion @@ -155,8 +158,7 @@ def _aic(cov, order, dimension, length): Akaike Information Criterion """ sign, log_det_cov = np.linalg.slogdet(cov) - criterion = 2 * log_det_cov \ - + 2*(dimension**2)*order/length + criterion = 2 * log_det_cov + 2 * (dimension**2) * order / length return criterion @@ -199,13 +201,14 @@ def _lag_covariances(signals, dimension, max_lag): # centralize time series signals_mean = (signals - np.mean(signals, keepdims=True)).T - lag_covariances = np.zeros((max_lag+1, dimension, dimension)) + lag_covariances = np.zeros((max_lag + 1, dimension, dimension)) # determine lagged covariance for different time lags - for lag in range(0, max_lag+1): - lag_covariances[lag] = \ - np.mean(np.einsum('ij,ik -> ijk', signals_mean[:length-lag], - signals_mean[lag:]), axis=0) + for lag in range(0, max_lag + 1): + lag_covariances[lag] = np.mean( + np.einsum("ij,ik -> ijk", signals_mean[: length - lag], signals_mean[lag:]), + axis=0, + ) return lag_covariances @@ -248,20 +251,19 @@ def _yule_walker_matrix(data, dimension, order): lag_covariances = _lag_covariances(data, dimension, order) - yule_walker_matrix = np.zeros((dimension*order, dimension*order)) + yule_walker_matrix = np.zeros((dimension * order, dimension * order)) for block_row in range(order): for block_column in range(block_row, order): - yule_walker_matrix[block_row*dimension: (block_row+1)*dimension, - block_column*dimension: - (block_column+1)*dimension] = \ - lag_covariances[block_column-block_row].T - - yule_walker_matrix[block_column*dimension: - (block_column+1)*dimension, - block_row*dimension: - (block_row+1)*dimension] = \ - lag_covariances[block_column-block_row] + yule_walker_matrix[ + block_row * dimension : (block_row + 1) * dimension, + block_column * dimension : (block_column + 1) * dimension, + ] = lag_covariances[block_column - block_row].T + + yule_walker_matrix[ + block_column * dimension : (block_column + 1) * dimension, + block_row * dimension : (block_row + 1) * dimension, + ] = lag_covariances[block_column - block_row] return yule_walker_matrix, lag_covariances @@ -302,31 +304,30 @@ def _vector_arm(signals, dimension, order): """ - yule_walker_matrix, lag_covariances = \ - _yule_walker_matrix(signals, dimension, order) + yule_walker_matrix, lag_covariances = _yule_walker_matrix(signals, dimension, order) - positive_lag_covariances = 
np.reshape(lag_covariances[1:], - (dimension*order, dimension)) + positive_lag_covariances = np.reshape( + lag_covariances[1:], (dimension * order, dimension) + ) - lstsq_coeffs = np.linalg.lstsq(yule_walker_matrix, - positive_lag_covariances, - rcond=None)[0] + lstsq_coeffs = np.linalg.lstsq( + yule_walker_matrix, positive_lag_covariances, rcond=None + )[0] coeffs = [] for index in range(order): - coeffs.append(lstsq_coeffs[index*dimension:(index+1)*dimension, ].T) + coeffs.append(lstsq_coeffs[index * dimension : (index + 1) * dimension,].T) coeffs = np.stack(coeffs) cov_matrix = np.copy(lag_covariances[0]) for i in range(order): - cov_matrix -= np.matmul(coeffs[i], lag_covariances[i+1]) + cov_matrix -= np.matmul(coeffs[i], lag_covariances[i + 1]) return coeffs, cov_matrix -def _optimal_vector_arm(signals, dimension, max_order, - information_criterion='aic'): +def _optimal_vector_arm(signals, dimension, max_order, information_criterion="aic"): """ Determine optimal auto regressive model by choosing optimal order via Information Criterion @@ -365,13 +366,15 @@ def _optimal_vector_arm(signals, dimension, max_order, for order in range(1, max_order + 1): coeffs, cov_matrix = _vector_arm(signals, dimension, order) - if information_criterion == 'aic': + if information_criterion == "aic": temp_ic = _aic(cov_matrix, order, dimension, length) - elif information_criterion == 'bic': + elif information_criterion == "bic": temp_ic = _bic(cov_matrix, order, dimension, length) else: - raise ValueError("The specified information criterion is not" - "available. Please use 'aic' or 'bic'.") + raise ValueError( + "The specified information criterion is not" + "available. Please use 'aic' or 'bic'." + ) if temp_ic < optimal_ic: optimal_ic = temp_ic @@ -405,7 +408,7 @@ def _bracket_operator(spectrum, num_freqs, num_signals): causal_part = np.fft.ifft(spectrum, axis=0) # Throw away acausal part - causal_part[(num_freqs + 1) // 2:] = 0 + causal_part[(num_freqs + 1) // 2 :] = 0 # Treat coefficient belonging to 0 causal_part[0] /= 2 @@ -490,31 +493,29 @@ def _spectral_factorization(cross_spectrum, num_iterations, term_crit=1e-12): # Initialization identity = np.identity(num_signals) - factorization = np.zeros(np.shape(spectral_density_function), - dtype='complex128') + factorization = np.zeros(np.shape(spectral_density_function), dtype="complex128") # Estimate initial conditions try: initial_cond = np.linalg.cholesky(cross_spectrum[0].real) except np.linalg.LinAlgError: - raise ValueError('Could not calculate Cholesky decomposition of real' - + ' part of zero frequency estimate of cross-spectrum' - + '. This might suggest a problem with the input') + raise ValueError( + "Could not calculate Cholesky decomposition of real" + + " part of zero frequency estimate of cross-spectrum" + + ". This might suggest a problem with the input" + ) factorization += initial_cond converged = False # Iteration for calculating spectral factorization for i in range(num_iterations): - factorization_old = np.copy(factorization) # Implementation of Eq. 
3.1 from "The Factorization of Matricial # Spectral Densities", Wilson 1972, SiAM J Appl Math - X = np.linalg.solve(factorization, - spectral_density_function) - Y = np.linalg.solve(factorization, - _dagger(X)) + X = np.linalg.solve(factorization, spectral_density_function) + Y = np.linalg.solve(factorization, _dagger(X)) Y += identity Y = _bracket_operator(Y, num_freqs, num_signals) @@ -523,27 +524,27 @@ def _spectral_factorization(cross_spectrum, num_iterations, term_crit=1e-12): diff = factorization - factorization_old error = np.max(np.abs(diff)) if error < term_crit: - print(f'Spectral factorization converged after {i} steps') - converged=True + print(f"Spectral factorization converged after {i} steps") + converged = True break if not converged: - raise Exception("Spectral factorization did not converge after " - + f"{num_iterations} steps. Try to increase " - + "'num_iterations', or lower the allowed error " - + " in the termination criterion, currently " - + f"{term_crit}") + raise Exception( + "Spectral factorization did not converge after " + + f"{num_iterations} steps. Try to increase " + + "'num_iterations', or lower the allowed error " + + " in the termination criterion, currently " + + f"{term_crit}" + ) - cov_matrix = np.matmul(factorization[0], - _dagger(factorization[0])) + cov_matrix = np.matmul(factorization[0], _dagger(factorization[0])) - transfer_function = np.matmul(factorization, - np.linalg.inv(factorization[0])) + transfer_function = np.matmul(factorization, np.linalg.inv(factorization[0])) return cov_matrix, transfer_function -def pairwise_granger(signals, max_order, information_criterion='aic'): +def pairwise_granger(signals, max_order, information_criterion="aic"): r""" Determine Granger Causality of two time series @@ -658,12 +659,15 @@ def pairwise_granger(signals, max_order, information_criterion='aic'): # signal_x and signal_y are (1, N) arrays signal_x, signal_y = np.expand_dims(signals, axis=1) - coeffs_x, var_x, p_1 = _optimal_vector_arm(signal_x, 1, max_order, - information_criterion) - coeffs_y, var_y, p_2 = _optimal_vector_arm(signal_y, 1, max_order, - information_criterion) - coeffs_xy, cov_xy, p_3 = _optimal_vector_arm(signals, 2, max_order, - information_criterion) + coeffs_x, var_x, p_1 = _optimal_vector_arm( + signal_x, 1, max_order, information_criterion + ) + coeffs_y, var_y, p_2 = _optimal_vector_arm( + signal_y, 1, max_order, information_criterion + ) + coeffs_xy, cov_xy, p_3 = _optimal_vector_arm( + signals, 2, max_order, information_criterion + ) sign, log_det_cov = np.linalg.slogdet(cov_xy) tolerance = 1e-7 @@ -671,18 +675,20 @@ def pairwise_granger(signals, max_order, information_criterion='aic'): if sign <= 0: raise ValueError( "Determinant of covariance matrix must be always positive: " - "In this case its sign is {}".format(sign)) + "In this case its sign is {}".format(sign) + ) if log_det_cov <= tolerance: - warnings.warn("The value of the log determinant is at or below the " - "tolerance level. Proceeding with computation.", - UserWarning) + warnings.warn( + "The value of the log determinant is at or below the " + "tolerance level. 
Proceeding with computation.", + UserWarning, + ) directional_causality_y_x = np.log(var_x[0]) - np.log(cov_xy[0, 0]) directional_causality_x_y = np.log(var_y[0]) - np.log(cov_xy[1, 1]) - instantaneous_causality = \ - np.log(cov_xy[0, 0]) + np.log(cov_xy[1, 1]) - log_det_cov + instantaneous_causality = np.log(cov_xy[0, 0]) + np.log(cov_xy[1, 1]) - log_det_cov instantaneous_causality = np.asarray(instantaneous_causality) total_interdependence = np.log(var_x[0]) + np.log(var_y[0]) - log_det_cov @@ -691,26 +697,27 @@ def pairwise_granger(signals, max_order, information_criterion='aic'): # Note that standard error scales as 1/sqrt(sample_size) # Calculate significant figures according to standard error length = np.size(signal_x) - asymptotic_std_error = 1/np.sqrt(length) - est_sig_figures = int((-1)*np.around(np.log10(asymptotic_std_error))) - - directional_causality_x_y_round = np.around(directional_causality_x_y, - est_sig_figures) - directional_causality_y_x_round = np.around(directional_causality_y_x, - est_sig_figures) - instantaneous_causality_round = np.around(instantaneous_causality, - est_sig_figures) - total_interdependence_round = np.around(total_interdependence, - est_sig_figures) + asymptotic_std_error = 1 / np.sqrt(length) + est_sig_figures = int((-1) * np.around(np.log10(asymptotic_std_error))) + + directional_causality_x_y_round = np.around( + directional_causality_x_y, est_sig_figures + ) + directional_causality_y_x_round = np.around( + directional_causality_y_x, est_sig_figures + ) + instantaneous_causality_round = np.around(instantaneous_causality, est_sig_figures) + total_interdependence_round = np.around(total_interdependence, est_sig_figures) return Causality( directional_causality_x_y=directional_causality_x_y_round.item(), directional_causality_y_x=directional_causality_y_x_round.item(), instantaneous_causality=instantaneous_causality_round.item(), - total_interdependence=total_interdependence_round.item()) + total_interdependence=total_interdependence_round.item(), + ) -def conditional_granger(signals, max_order, information_criterion='aic'): +def conditional_granger(signals, max_order, information_criterion="aic"): r""" Determine conditional Granger Causality of the second time series on the first time series, given the third time series. 
In other words, for time @@ -763,11 +770,17 @@ def conditional_granger(signals, max_order, information_criterion='aic'): signals_xz = np.vstack([signal_x, signal_z]) coeffs_xz, cov_xz, p_1 = _optimal_vector_arm( - signals_xz, dimension=2, max_order=max_order, - information_criterion=information_criterion) + signals_xz, + dimension=2, + max_order=max_order, + information_criterion=information_criterion, + ) coeffs_xyz, cov_xyz, p_2 = _optimal_vector_arm( - signals, dimension=3, max_order=max_order, - information_criterion=information_criterion) + signals, + dimension=3, + max_order=max_order, + information_criterion=information_criterion, + ) conditional_causality_xy_z = np.log(cov_xz[0, 0]) - np.log(cov_xyz[0, 0]) @@ -775,20 +788,30 @@ def conditional_granger(signals, max_order, information_criterion='aic'): # Note that standard error scales as 1/sqrt(sample_size) # Calculate significant figures according to standard error length = np.size(signal_x) - asymptotic_std_error = 1/np.sqrt(length) - est_sig_figures = int((-1)*np.around(np.log10(asymptotic_std_error))) + asymptotic_std_error = 1 / np.sqrt(length) + est_sig_figures = int((-1) * np.around(np.log10(asymptotic_std_error))) - conditional_causality_xy_z_round = np.around(conditional_causality_xy_z, - est_sig_figures) + conditional_causality_xy_z_round = np.around( + conditional_causality_xy_z, est_sig_figures + ) return conditional_causality_xy_z_round -def pairwise_spectral_granger(signal_i, signal_j, fs=1, nw=4, num_tapers=None, - peak_resolution=None, n_segments=1, - len_segment=None, frequency_resolution=None, - overlap=0.5, num_iterations=300, - term_crit=1e-12): +def pairwise_spectral_granger( + signal_i, + signal_j, + fs=1, + nw=4, + num_tapers=None, + peak_resolution=None, + n_segments=1, + len_segment=None, + frequency_resolution=None, + overlap=0.5, + num_iterations=300, + term_crit=1e-12, +): r"""Determine spectral Granger Causality of two signals. The spectral Granger Causality is obtained through the following steps: @@ -886,18 +909,26 @@ def pairwise_spectral_granger(signal_i, signal_j, fs=1, nw=4, num_tapers=None, of X, i.e. `signal_i` and Y, i.e. `signal_j`. If the total interdependence is positive, X and Y are not independent. 
""" - if isinstance(signal_i, neo.core.AnalogSignal) and \ - isinstance(signal_j, neo.core.AnalogSignal): + if isinstance(signal_i, neo.core.AnalogSignal) and isinstance( + signal_j, neo.core.AnalogSignal + ): signals = signal_i.merge(signal_j) elif isinstance(signal_i, np.ndarray) and isinstance(signal_j, np.ndarray): signals = np.vstack([signal_i, signal_j]) # Calculate cross spectrum for signals freqs, S = segmented_multitaper_cross_spectrum( - signals=signals, n_segments=n_segments, len_segment=len_segment, - frequency_resolution=frequency_resolution, overlap=overlap, fs=fs, - nw=nw, num_tapers=num_tapers, peak_resolution=peak_resolution, - return_onesided=False) + signals=signals, + n_segments=n_segments, + len_segment=len_segment, + frequency_resolution=frequency_resolution, + overlap=overlap, + fs=fs, + nw=nw, + num_tapers=num_tapers, + peak_resolution=peak_resolution, + return_onesided=False, + ) # Remove units attached by the multitaper_cross_spectrum if isinstance(S, pq.Quantity): @@ -915,7 +946,7 @@ def pairwise_spectral_granger(signal_i, signal_j, fs=1, nw=4, num_tapers=None, C, H = _spectral_factorization(S, num_iterations=num_iterations) # Take positive frequencies - mask = (freqs >= 0) + mask = freqs >= 0 freqs = freqs[mask] S = S[mask] @@ -923,32 +954,32 @@ def pairwise_spectral_granger(signal_i, signal_j, fs=1, nw=4, num_tapers=None, # Calculate spectral Granger Causality. # Formulae follow Wen et al., 2013, Phil Trans R Soc - H_tilde_xx = H[:, 0, 0] + C[0, 1]/C[0, 0]*H[:, 0, 1] - H_tilde_yy = H[:, 1, 1] + C[0, 1]/C[1, 1]*H[:, 1, 0] + H_tilde_xx = H[:, 0, 0] + C[0, 1] / C[0, 0] * H[:, 0, 1] + H_tilde_yy = H[:, 1, 1] + C[0, 1] / C[1, 1] * H[:, 1, 0] - directional_causality_y_x = np.log(S[:, 0, 0].real / - (H_tilde_xx - * C[0, 0] - * H_tilde_xx.conj()).real) + directional_causality_y_x = np.log( + S[:, 0, 0].real / (H_tilde_xx * C[0, 0] * H_tilde_xx.conj()).real + ) - directional_causality_x_y = np.log(S[:, 1, 1].real / - (H_tilde_yy - * C[1, 1] - * H_tilde_yy.conj()).real) + directional_causality_x_y = np.log( + S[:, 1, 1].real / (H_tilde_yy * C[1, 1] * H_tilde_yy.conj()).real + ) instantaneous_causality = np.log( (H_tilde_xx * C[0, 0] * H_tilde_xx.conj()).real - * (H_tilde_yy * C[1, 1] * H_tilde_yy.conj()).real) + * (H_tilde_yy * C[1, 1] * H_tilde_yy.conj()).real + ) instantaneous_causality -= np.linalg.slogdet(S)[1] - total_interdependence = (directional_causality_x_y - + directional_causality_y_x - + instantaneous_causality) + total_interdependence = ( + directional_causality_x_y + directional_causality_y_x + instantaneous_causality + ) spectral_causality = Causality( directional_causality_x_y=directional_causality_x_y, directional_causality_y_x=directional_causality_y_x, instantaneous_causality=instantaneous_causality, - total_interdependence=total_interdependence) + total_interdependence=total_interdependence, + ) return freqs, spectral_causality diff --git a/elephant/cell_assembly_detection.py b/elephant/cell_assembly_detection.py index bab3ef47b..678e8fa4f 100644 --- a/elephant/cell_assembly_detection.py +++ b/elephant/cell_assembly_detection.py @@ -77,17 +77,22 @@ import elephant.conversion as conv -__all__ = [ - "cell_assembly_detection" -] - - -def cell_assembly_detection(binned_spiketrain, max_lag, reference_lag=2, - alpha=0.05, min_occurrences=1, size_chunks=100, - max_spikes=np.inf, significance_pruning=True, - subgroup_pruning=True, - same_configuration_pruning=False, - verbose=False): +__all__ = ["cell_assembly_detection"] + + +def cell_assembly_detection( + 
binned_spiketrain, + max_lag, + reference_lag=2, + alpha=0.05, + min_occurrences=1, + size_chunks=100, + max_spikes=np.inf, + significance_pruning=True, + subgroup_pruning=True, + same_configuration_pruning=False, + verbose=False, +): """ Perform the CAD analysis :cite:`cad-Russo2017_e19428` for the binned (discretized) spike trains given in the input. The method looks for @@ -196,12 +201,14 @@ def cell_assembly_detection(binned_spiketrain, max_lag, reference_lag=2, initial_time = time.time() # check parameter input and raise errors if necessary - _raise_errors(binned_spiketrain=binned_spiketrain, - max_lag=max_lag, - alpha=alpha, - min_occurrences=min_occurrences, - size_chunks=size_chunks, - max_spikes=max_spikes) + _raise_errors( + binned_spiketrain=binned_spiketrain, + max_lag=max_lag, + alpha=alpha, + min_occurrences=min_occurrences, + size_chunks=size_chunks, + max_spikes=max_spikes, + ) bin_size = binned_spiketrain.bin_size t_start = binned_spiketrain.t_start @@ -214,21 +221,26 @@ def cell_assembly_detection(binned_spiketrain, max_lag, reference_lag=2, # initialize empty assembly - assembly_in = [{'neurons': None, - 'lags': None, - 'pvalue': None, - 'times': None, - 'signature': None} for _ in range(n_neurons)] + assembly_in = [ + { + "neurons": None, + "lags": None, + "pvalue": None, + "times": None, + "signature": None, + } + for _ in range(n_neurons) + ] # initializing the dictionaries if verbose: - print('Initializing the dictionaries...') + print("Initializing the dictionaries...") for w1 in range(n_neurons): - assembly_in[w1]['neurons'] = [w1] - assembly_in[w1]['lags'] = [] - assembly_in[w1]['pvalue'] = [] - assembly_in[w1]['times'] = binned_spiketrain[w1] - assembly_in[w1]['signature'] = [[1, sum(binned_spiketrain[w1])]] + assembly_in[w1]["neurons"] = [w1] + assembly_in[w1]["lags"] = [] + assembly_in[w1]["pvalue"] = [] + assembly_in[w1]["times"] = binned_spiketrain[w1] + assembly_in[w1]["signature"] = [[1, sum(binned_spiketrain[w1])]] # first order = test over pairs @@ -243,13 +255,13 @@ def cell_assembly_detection(binned_spiketrain, max_lag, reference_lag=2, alph = alpha alpha = alph * 2 / float(number_test_performed) if verbose: - print('actual significance_level', alpha) + print("actual significance_level", alpha) # sign_pairs_matrix is the matrix with entry as 1 for the significant pairs sign_pairs_matrix = np.zeros((n_neurons, n_neurons), dtype=int) assembly = [] if verbose: - print('Testing on pairs...') + print("Testing on pairs...") # nns: count of the existing assemblies nns = 0 @@ -275,7 +287,8 @@ def cell_assembly_detection(binned_spiketrain, max_lag, reference_lag=2, size_chunks=size_chunks, reference_lag=reference_lag, existing_patterns=existing_patterns, - same_configuration_pruning=same_configuration_pruning) + same_configuration_pruning=same_configuration_pruning, + ) if same_configuration_pruning: assem_tp = call_tp[0] else: @@ -283,8 +296,10 @@ def cell_assembly_detection(binned_spiketrain, max_lag, reference_lag=2, # if the assembly given in output is significant and the number # of occurrences is higher than the minimum requested number - if assem_tp['pvalue'][-1] < alpha and \ - assem_tp['signature'][-1][1] > min_occurrences: + if ( + assem_tp["pvalue"][-1] < alpha + and assem_tp["signature"][-1][1] > min_occurrences + ): # save the assembly in the output assembly.append(assem_tp) sign_pairs_matrix[w1][w2] = 1 @@ -309,7 +324,7 @@ def cell_assembly_detection(binned_spiketrain, max_lag, reference_lag=2, # the algorithm will return assemblies composed by # 
maximum max_spikes elements if verbose: - print('\nTesting on higher order assemblies...\n') + print("\nTesting on higher order assemblies...\n") # keep the count of the current size of the assembly current_size_agglomeration = 2 @@ -326,8 +341,7 @@ def cell_assembly_detection(binned_spiketrain, max_lag, reference_lag=2, w1 = 0 while w1 < n_as: - - w1_elements = assembly[w1]['neurons'] + w1_elements = assembly[w1]["neurons"] # Add only neurons that have significant first order # co-occurrences with members of the assembly @@ -341,13 +355,11 @@ def cell_assembly_detection(binned_spiketrain, max_lag, reference_lag=2, # list with the elements to test # that are not already in the assembly - w2_to_test = [item for item in w2_to_test_p - if item not in w1_elements] + w2_to_test = [item for item in w2_to_test_p if item not in w1_elements] pop_flag = 0 # check that there are candidate neurons for agglomeration if w2_to_test: - # bonferroni correction only for the tests actually performed alpha = alph / float(len(w2_to_test) * n_as * (2 * max_lag + 1)) @@ -367,7 +379,8 @@ def cell_assembly_detection(binned_spiketrain, max_lag, reference_lag=2, size_chunks=size_chunks, reference_lag=reference_lag, existing_patterns=existing_patterns, - same_configuration_pruning=same_configuration_pruning) + same_configuration_pruning=same_configuration_pruning, + ) if same_configuration_pruning: assem_tp = call_tp[0] else: @@ -376,26 +389,32 @@ def cell_assembly_detection(binned_spiketrain, max_lag, reference_lag=2, # if it is significant and # the number of occurrences is sufficient and # the length of the assembly is less than the input limit - if assem_tp['pvalue'][-1] < alpha and \ - assem_tp['signature'][-1][1] > min_occurrences and \ - assem_tp['signature'][-1][0] <= max_spikes: + if ( + assem_tp["pvalue"][-1] < alpha + and assem_tp["signature"][-1][1] > min_occurrences + and assem_tp["signature"][-1][0] <= max_spikes + ): # the assembly is saved in the output list of # assemblies assembly.append(assem_tp) assembly_flag = 1 - if len(assem_tp['neurons']) > current_size_agglomeration: + if len(assem_tp["neurons"]) > current_size_agglomeration: # up to the next agglomeration level current_size_agglomeration += 1 # Pruning step 1 # between two assemblies with the same unit set # arranged into different # configurations, choose the most significant one - if significance_pruning is True and \ - current_size_agglomeration > 3: - assembly, n_filtered_assemblies = \ + if ( + significance_pruning is True + and current_size_agglomeration > 3 + ): + assembly, n_filtered_assemblies = ( _significance_pruning_step( - pre_pruning_assembly=assembly) + pre_pruning_assembly=assembly + ) + ) if same_configuration_pruning: item_candidate = call_tp[1] existing_patterns.append(item_candidate) @@ -426,23 +445,21 @@ def cell_assembly_detection(binned_spiketrain, max_lag, reference_lag=2, # Reformat of the activation times for pattern in assembly: - times = np.where(pattern['times'] > 0)[0] * bin_size + t_start - pattern['times'] = times - pattern['lags'] = pattern['lags'] * bin_size - pattern['signature'] = np.array(pattern['signature'], dtype=int) + times = np.where(pattern["times"] > 0)[0] * bin_size + t_start + pattern["times"] = times + pattern["lags"] = pattern["lags"] * bin_size + pattern["signature"] = np.array(pattern["signature"], dtype=int) # Give as output only the maximal groups if verbose: - print('\nGiving outputs of the method...\n') - print('final_assembly') + print("\nGiving outputs of the method...\n") + 
print("final_assembly") for item in assembly: - print(item['neurons'], - item['lags'], - item['signature']) + print(item["neurons"], item["lags"], item["signature"]) # Time needed for the computation if verbose: - print('\ntime', time.time() - initial_time) + print("\ntime", time.time() - initial_time) return assembly @@ -470,7 +487,9 @@ def _chunking(binned_pair, size_chunks, max_lag, best_lag): number of chunks """ - length = len(binned_pair[0], ) + length = len( + binned_pair[0], + ) # number of chunks n_chunks = math.ceil((length - max_lag) / size_chunks) @@ -485,34 +504,33 @@ def _chunking(binned_pair, size_chunks, max_lag, best_lag): # cut the time series according to best_lag - binned_pair_cut = np.array([np.zeros(length - max_lag, dtype=int), - np.zeros(length - max_lag, dtype=int)]) + binned_pair_cut = np.array( + [np.zeros(length - max_lag, dtype=int), np.zeros(length - max_lag, dtype=int)] + ) # choose which entries to consider according to the best lag chosen if best_lag == 0: - binned_pair_cut[0] = binned_pair[0][0:length - max_lag] - binned_pair_cut[1] = binned_pair[1][0:length - max_lag] + binned_pair_cut[0] = binned_pair[0][0 : length - max_lag] + binned_pair_cut[1] = binned_pair[1][0 : length - max_lag] elif best_lag > 0: - binned_pair_cut[0] = binned_pair[0][0:length - max_lag] - binned_pair_cut[1] = binned_pair[1][ - best_lag:length - max_lag + best_lag] + binned_pair_cut[0] = binned_pair[0][0 : length - max_lag] + binned_pair_cut[1] = binned_pair[1][best_lag : length - max_lag + best_lag] else: - binned_pair_cut[0] = binned_pair[0][ - -best_lag:length - max_lag - best_lag] - binned_pair_cut[1] = binned_pair[1][0:length - max_lag] + binned_pair_cut[0] = binned_pair[0][-best_lag : length - max_lag - best_lag] + binned_pair_cut[1] = binned_pair[1][0 : length - max_lag] # put the cut data into the chunked object for iii in range(n_chunks - 1): chunked[iii][0] = binned_pair_cut[0][ - size_chunks * iii:size_chunks * (iii + 1)] + size_chunks * iii : size_chunks * (iii + 1) + ] chunked[iii][1] = binned_pair_cut[1][ - size_chunks * iii:size_chunks * (iii + 1)] + size_chunks * iii : size_chunks * (iii + 1) + ] # last chunk can be of slightly different size - chunked[n_chunks - 1][0] = binned_pair_cut[0][ - size_chunks * (n_chunks - 1):length] - chunked[n_chunks - 1][1] = binned_pair_cut[1][ - size_chunks * (n_chunks - 1):length] + chunked[n_chunks - 1][0] = binned_pair_cut[0][size_chunks * (n_chunks - 1) : length] + chunked[n_chunks - 1][1] = binned_pair_cut[1][size_chunks * (n_chunks - 1) : length] return chunked, n_chunks @@ -539,16 +557,25 @@ def _assert_same_pattern(item_candidate, existing_patterns, max_lag): """ # unique representation of pattern in term of lags, maxlag and neurons # participating - item_candidate = sorted(item_candidate[0] * 2 * max_lag + - item_candidate[1] + max_lag) + item_candidate = sorted( + item_candidate[0] * 2 * max_lag + item_candidate[1] + max_lag + ) if item_candidate in existing_patterns: return True else: return False -def _test_pair(ensemble, spiketrain2, n2, max_lag, size_chunks, reference_lag, - existing_patterns, same_configuration_pruning): +def _test_pair( + ensemble, + spiketrain2, + n2, + max_lag, + size_chunks, + reference_lag, + existing_patterns, + same_configuration_pruning, +): """ Tests if two spike trains have repetitive patterns occurring more frequently than chance. 
@@ -603,7 +630,7 @@ def _test_pair(ensemble, spiketrain2, n2, max_lag, size_chunks, reference_lag, """ # list with the binned spike trains of the two neurons - binned_pair = [ensemble['times'], spiketrain2] + binned_pair = [ensemble["times"], spiketrain2] # For large bin_sizes, the binned spike counts may potentially fluctuate # around a high mean level and never fall below some minimum count @@ -612,8 +639,9 @@ def _test_pair(ensemble, spiketrain2, n2, max_lag, size_chunks, reference_lag, # to the coincidence count although they are completely # uninformative, so we subtract the minima. - binned_pair = np.array([binned_pair[0] - min(binned_pair[0]), - binned_pair[1] - min(binned_pair[1])]) + binned_pair = np.array( + [binned_pair[0] - min(binned_pair[0]), binned_pair[1] - min(binned_pair[1])] + ) ntp = len(binned_pair[0]) # trial length @@ -628,8 +656,9 @@ def _test_pair(ensemble, spiketrain2, n2, max_lag, size_chunks, reference_lag, for i in range(maxrate): par_processes[i] = np.array(binned_pair > i, dtype=int) - par_proc_expectation[i] = (np.sum(par_processes[i][0]) * np.sum( - par_processes[i][1])) / float(ntp) + par_proc_expectation[i] = ( + np.sum(par_processes[i][0]) * np.sum(par_processes[i][1]) + ) / float(ntp) # Decide which is the lag with most coincidences (l_ : best lag) # we are calculating the joint spike count of units A and B at lag l. @@ -643,21 +672,23 @@ def _test_pair(ensemble, spiketrain2, n2, max_lag, size_chunks, reference_lag, bwd_coinc_count = np.array([0 for _ in range(max_lag + 1)]) for lag in range(max_lag + 1): - time_fwd_cc = np.array([binned_pair[0][ - 0:len(binned_pair[0]) - max_lag], - binned_pair[1][ - lag:len(binned_pair[1]) - max_lag + lag]]) - - time_bwd_cc = np.array([binned_pair[0][ - lag:len(binned_pair[0]) - max_lag + lag], - binned_pair[1][ - 0:len(binned_pair[1]) - max_lag]]) + time_fwd_cc = np.array( + [ + binned_pair[0][0 : len(binned_pair[0]) - max_lag], + binned_pair[1][lag : len(binned_pair[1]) - max_lag + lag], + ] + ) + + time_bwd_cc = np.array( + [ + binned_pair[0][lag : len(binned_pair[0]) - max_lag + lag], + binned_pair[1][0 : len(binned_pair[1]) - max_lag], + ] + ) # taking the minimum, place by place for the coincidences - fwd_coinc_count[lag] = np.sum(np.minimum(time_fwd_cc[0], - time_fwd_cc[1])) - bwd_coinc_count[lag] = np.sum(np.minimum(time_bwd_cc[0], - time_bwd_cc[1])) + fwd_coinc_count[lag] = np.sum(np.minimum(time_fwd_cc[0], time_fwd_cc[1])) + bwd_coinc_count[lag] = np.sum(np.minimum(time_bwd_cc[0], time_bwd_cc[1])) # choice of the best lag, taking into account the reference lag if reference_lag <= 0: @@ -670,12 +701,12 @@ def _test_pair(ensemble, spiketrain2, n2, max_lag, size_chunks, reference_lag, fwd_flag = 2 global_maximum_index = np.argmax(bwd_coinc_count) best_lag = (fwd_flag == 1) * global_maximum_index - ( - fwd_flag == 2) * global_maximum_index - max_coinc_count = max(np.amax(fwd_coinc_count), - np.amax(bwd_coinc_count)) + fwd_flag == 2 + ) * global_maximum_index + max_coinc_count = max(np.amax(fwd_coinc_count), np.amax(bwd_coinc_count)) else: # reverse the ctAB_ object and not take into account the first entry - bwd_coinc_count_rev = bwd_coinc_count[1:len(bwd_coinc_count)][::-1] + bwd_coinc_count_rev = bwd_coinc_count[1 : len(bwd_coinc_count)][::-1] hab_l = np.append(bwd_coinc_count_rev, fwd_coinc_count) lags = range(-max_lag, max_lag + 1) max_coinc_count = np.amax(hab_l) @@ -692,13 +723,13 @@ def _test_pair(ensemble, spiketrain2, n2, max_lag, size_chunks, reference_lag, # is already in the list of the significant 
patterns # if it is, don't do the testing # if it is not, continue - previous_neu = ensemble['neurons'] + previous_neu = ensemble["neurons"] pattern_candidate = copy.copy(previous_neu) pattern_candidate.append(n2) pattern_candidate = np.array(pattern_candidate) # add both the new lag and zero - previous_lags = ensemble['lags'] + previous_lags = ensemble["lags"] lags_candidate = copy.copy(previous_lags) lags_candidate.append(best_lag) @@ -708,23 +739,27 @@ def _test_pair(ensemble, spiketrain2, n2, max_lag, size_chunks, reference_lag, item_candidate = [[pattern_candidate], [lags_candidate]] if same_configuration_pruning: - if _assert_same_pattern(item_candidate=item_candidate, - existing_patterns=existing_patterns, - max_lag=max_lag): - en_neurons = copy.copy(ensemble['neurons']) + if _assert_same_pattern( + item_candidate=item_candidate, + existing_patterns=existing_patterns, + max_lag=max_lag, + ): + en_neurons = copy.copy(ensemble["neurons"]) en_neurons.append(n2) - en_lags = copy.copy(ensemble['lags']) + en_lags = copy.copy(ensemble["lags"]) en_lags.append(np.inf) - en_pvalue = copy.copy(ensemble['pvalue']) + en_pvalue = copy.copy(ensemble["pvalue"]) en_pvalue.append(1) - en_n_occ = copy.copy(ensemble['signature']) + en_n_occ = copy.copy(ensemble["signature"]) en_n_occ.append([0, 0]) item_candidate = [] - assembly = {'neurons': en_neurons, - 'lags': en_lags, - 'pvalue': en_pvalue, - 'times': [], - 'signature': en_n_occ} + assembly = { + "neurons": en_neurons, + "lags": en_lags, + "pvalue": en_pvalue, + "times": [], + "signature": en_n_occ, + } return assembly, item_candidate else: # I go on with the testing @@ -732,22 +767,27 @@ def _test_pair(ensemble, spiketrain2, n2, max_lag, size_chunks, reference_lag, pair_expectation = np.sum(par_proc_expectation) # case of no coincidences or limit for the F asimptotical # distribution (too few coincidences) - if max_coinc_count == 0 or pair_expectation <= 5 or \ - pair_expectation >= (min(np.sum(binned_pair[0]), - np.sum(binned_pair[1])) - 5): - en_neurons = copy.copy(ensemble['neurons']) + if ( + max_coinc_count == 0 + or pair_expectation <= 5 + or pair_expectation + >= (min(np.sum(binned_pair[0]), np.sum(binned_pair[1])) - 5) + ): + en_neurons = copy.copy(ensemble["neurons"]) en_neurons.append(n2) - en_lags = copy.copy(ensemble['lags']) + en_lags = copy.copy(ensemble["lags"]) en_lags.append(np.inf) - en_pvalue = copy.copy(ensemble['pvalue']) + en_pvalue = copy.copy(ensemble["pvalue"]) en_pvalue.append(1) - en_n_occ = copy.copy(ensemble['signature']) + en_n_occ = copy.copy(ensemble["signature"]) en_n_occ.append([0, 0]) - assembly = {'neurons': en_neurons, - 'lags': en_lags, - 'pvalue': en_pvalue, - 'times': [], - 'signature': en_n_occ} + assembly = { + "neurons": en_neurons, + "lags": en_lags, + "pvalue": en_pvalue, + "times": [], + "signature": en_n_occ, + } if same_configuration_pruning: item_candidate = [] return assembly, item_candidate @@ -762,12 +802,13 @@ def _test_pair(ensemble, spiketrain2, n2, max_lag, size_chunks, reference_lag, for i in range(maxrate): # for all parallel processes par_processes_a = par_processes[i][0] par_processes_b = par_processes[i][1] - activation_series = \ - np.add(activation_series, - np.multiply(par_processes_a, - par_processes_b)) - coinc_count_matrix = np.array([[0, fwd_coinc_count[0]], - [bwd_coinc_count[2], 0]]) + activation_series = np.add( + activation_series, + np.multiply(par_processes_a, par_processes_b), + ) + coinc_count_matrix = np.array( + [[0, fwd_coinc_count[0]], [bwd_coinc_count[2], 0]] + ) # 
matrix with #AB and #BA # here we specifically choose # 'l* = -2' for the synchrony case @@ -777,66 +818,79 @@ def _test_pair(ensemble, spiketrain2, n2, max_lag, size_chunks, reference_lag, par_processes_b = par_processes[i][1] # multiplication between the two binned time series # shifted by best_lag - activation_series[0:length - best_lag] = \ - np.add(activation_series[0:length - best_lag], - np.multiply(par_processes_a[ - 0:length - best_lag], - par_processes_b[ - best_lag:length])) - coinc_count_matrix = \ - np.array([[0, fwd_coinc_count[global_maximum_index]], - [bwd_coinc_count[global_maximum_index], 0]]) + activation_series[0 : length - best_lag] = np.add( + activation_series[0 : length - best_lag], + np.multiply( + par_processes_a[0 : length - best_lag], + par_processes_b[best_lag:length], + ), + ) + coinc_count_matrix = np.array( + [ + [0, fwd_coinc_count[global_maximum_index]], + [bwd_coinc_count[global_maximum_index], 0], + ] + ) else: for i in range(maxrate): par_processes_a = par_processes[i][0] par_processes_b = par_processes[i][1] - activation_series[-best_lag:length] = \ - np.add(activation_series[-best_lag:length], - np.multiply(par_processes_a[ - -best_lag:length], - par_processes_b[ - 0:length + best_lag])) - coinc_count_matrix = \ - np.array([[0, fwd_coinc_count[global_maximum_index]], - [bwd_coinc_count[global_maximum_index], 0]]) + activation_series[-best_lag:length] = np.add( + activation_series[-best_lag:length], + np.multiply( + par_processes_a[-best_lag:length], + par_processes_b[0 : length + best_lag], + ), + ) + coinc_count_matrix = np.array( + [ + [0, fwd_coinc_count[global_maximum_index]], + [bwd_coinc_count[global_maximum_index], 0], + ] + ) else: if best_lag == 0: for i in range(maxrate): par_processes_a = par_processes[i][0] par_processes_b = par_processes[i][1] - activation_series = \ - np.add(activation_series, - np.multiply(par_processes_a, - par_processes_b)) + activation_series = np.add( + activation_series, + np.multiply(par_processes_a, par_processes_b), + ) elif best_lag > 0: for i in range(maxrate): par_processes_a = par_processes[i][0] par_processes_b = par_processes[i][1] - activation_series[0:length - best_lag] = \ - np.add(activation_series[0:length - best_lag], - np.multiply(par_processes_a[ - 0:length - best_lag], - par_processes_b[ - best_lag:length])) + activation_series[0 : length - best_lag] = np.add( + activation_series[0 : length - best_lag], + np.multiply( + par_processes_a[0 : length - best_lag], + par_processes_b[best_lag:length], + ), + ) else: for i in range(maxrate): par_processes_a = par_processes[i][0] par_processes_b = par_processes[i][1] - activation_series[-best_lag:length] = \ - np.add(activation_series[-best_lag:length], - np.multiply(par_processes_a[ - -best_lag:length], - par_processes_b[ - 0:length + best_lag])) - coinc_count_matrix = np.array([[0, max_coinc_count], - [coinc_count_ref, 0]]) + activation_series[-best_lag:length] = np.add( + activation_series[-best_lag:length], + np.multiply( + par_processes_a[-best_lag:length], + par_processes_b[0 : length + best_lag], + ), + ) + coinc_count_matrix = np.array( + [[0, max_coinc_count], [coinc_count_ref, 0]] + ) # chunking - chunked, nch = _chunking(binned_pair=binned_pair, - size_chunks=size_chunks, - max_lag=max_lag, - best_lag=best_lag) + chunked, nch = _chunking( + binned_pair=binned_pair, + size_chunks=size_chunks, + max_lag=max_lag, + best_lag=best_lag, + ) marginal_counts = np.zeros((nch, maxrate, 2), dtype=int) @@ -854,25 +908,26 @@ def _test_pair(ensemble, 
spiketrain2, n2, max_lag, size_chunks, reference_lag, for iii in range(nch): binned_pair_chunked = np.array(chunked[iii]) - maxrate_t[iii] = max(max(binned_pair_chunked[0]), - max(binned_pair_chunked[1])) + maxrate_t[iii] = max( + max(binned_pair_chunked[0]), max(binned_pair_chunked[1]) + ) ch_nn[iii] = len(chunked[iii][0]) - par_processes_chunked = [None for _ in range( - int(maxrate_t[iii]))] + par_processes_chunked = [None for _ in range(int(maxrate_t[iii]))] for i in range(int(maxrate_t[iii])): par_processes_chunked[i] = np.zeros( - (2, len(binned_pair_chunked[0])), dtype=int) - par_processes_chunked[i] = np.array(binned_pair_chunked > i, - dtype=int) + (2, len(binned_pair_chunked[0])), dtype=int + ) + par_processes_chunked[i] = np.array(binned_pair_chunked > i, dtype=int) for i in range(int(maxrate_t[iii])): par_processes_a = par_processes_chunked[i][0] par_processes_b = par_processes_chunked[i][1] marginal_counts[iii][i][0] = int(np.sum(par_processes_a)) marginal_counts[iii][i][1] = int(np.sum(par_processes_b)) - count_sum = count_sum + min(marginal_counts[iii][i][0], - marginal_counts[iii][i][1]) + count_sum = count_sum + min( + marginal_counts[iii][i][0], marginal_counts[iii][i][1] + ) # marginal_counts[iii][i] has in its entries # '[ #_a^{\alpha,c} , #_b^{\alpha,c}]' @@ -893,20 +948,21 @@ def _test_pair(ensemble, spiketrain2, n2, max_lag, size_chunks, reference_lag, # evaluation of AB + variance and covariance - cov_abab[iii] = [[0 for _ in range(maxrate_t[iii])] - for _ in range(maxrate_t[iii])] + cov_abab[iii] = [ + [0 for _ in range(maxrate_t[iii])] for _ in range(maxrate_t[iii]) + ] # for every rate up to the maxrate in that chunk for i in range(maxrate_t[iii]): - par_marg_counts_i = \ - np.outer(marginal_counts[iii][i], np.ones(2)) + par_marg_counts_i = np.outer(marginal_counts[iii][i], np.ones(2)) - cov_abab[iii][i][i] = \ + cov_abab[iii][i][i] = np.multiply( + np.multiply(par_marg_counts_i, par_marg_counts_i.T) + / float(ch_size), np.multiply( - np.multiply(par_marg_counts_i, par_marg_counts_i.T) - / float(ch_size), - np.multiply(ch_size - par_marg_counts_i, - ch_size - par_marg_counts_i.T) - / float(ch_size * (ch_size - 1))) + ch_size - par_marg_counts_i, ch_size - par_marg_counts_i.T + ) + / float(ch_size * (ch_size - 1)), + ) # calculation of the variance var_t[iii] = var_t[iii] + cov_abab[iii][i][i] @@ -914,50 +970,55 @@ def _test_pair(ensemble, spiketrain2, n2, max_lag, size_chunks, reference_lag, # cross covariances terms if maxrate_t[iii] > 1: for j in range(i + 1, maxrate_t[iii]): - par_marg_counts_j = \ - np.outer(marginal_counts[iii][j], np.ones(2)) - cov_abab[iii][i][j] = \ - 2 * np.multiply( - np.multiply(par_marg_counts_j, - par_marg_counts_j.T) - / float(ch_size), - np.multiply(ch_size - par_marg_counts_i, - ch_size - par_marg_counts_i.T) - / float(ch_size * (ch_size - 1))) + par_marg_counts_j = np.outer( + marginal_counts[iii][j], np.ones(2) + ) + cov_abab[iii][i][j] = 2 * np.multiply( + np.multiply(par_marg_counts_j, par_marg_counts_j.T) + / float(ch_size), + np.multiply( + ch_size - par_marg_counts_i, + ch_size - par_marg_counts_i.T, + ) + / float(ch_size * (ch_size - 1)), + ) # update of the variance var_t[iii] = var_t[iii] + cov_abab[iii][i][j] # evaluation of coinc_count_matrix = #AB - #BA - cov_abba[iii] = [[0 for _ in range(maxrate_t[iii])] - for _ in range(maxrate_t[iii])] + cov_abba[iii] = [ + [0 for _ in range(maxrate_t[iii])] for _ in range(maxrate_t[iii]) + ] for i in range(maxrate_t[iii]): - par_marg_counts_i = \ - 
np.outer(marginal_counts[iii][i], np.ones(2)) - cov_abba[iii][i][i] = \ + par_marg_counts_i = np.outer(marginal_counts[iii][i], np.ones(2)) + cov_abba[iii][i][i] = np.multiply( + np.multiply(par_marg_counts_i, par_marg_counts_i.T) + / float(ch_size), np.multiply( - np.multiply(par_marg_counts_i, par_marg_counts_i.T) - / float(ch_size), - np.multiply(ch_size - par_marg_counts_i, - ch_size - par_marg_counts_i.T) - / float(ch_size * (ch_size - 1) ** 2)) + ch_size - par_marg_counts_i, ch_size - par_marg_counts_i.T + ) + / float(ch_size * (ch_size - 1) ** 2), + ) cov_x[iii] = cov_x[iii] + cov_abba[iii][i][i] if maxrate_t[iii] > 1: for j in range((i + 1), maxrate_t[iii]): - par_marg_counts_j = \ - np.outer(marginal_counts[iii][j], np.ones(2)) - - cov_abba[iii][i][j] = \ - 2 * np.multiply( - np.multiply(par_marg_counts_j, - par_marg_counts_j.T) - / float(ch_size), - np.multiply(ch_size - par_marg_counts_i, - ch_size - par_marg_counts_i.T) - / float(ch_size * (ch_size - 1) ** 2)) + par_marg_counts_j = np.outer( + marginal_counts[iii][j], np.ones(2) + ) + + cov_abba[iii][i][j] = 2 * np.multiply( + np.multiply(par_marg_counts_j, par_marg_counts_j.T) + / float(ch_size), + np.multiply( + ch_size - par_marg_counts_i, + ch_size - par_marg_counts_i.T, + ) + / float(ch_size * (ch_size - 1) ** 2), + ) cov_x[iii] = cov_x[iii] + cov_abba[iii][i][j] @@ -973,23 +1034,25 @@ def _test_pair(ensemble, spiketrain2, n2, max_lag, size_chunks, reference_lag, # p-value obtained through approximation to a Fischer F distribution # (here we employ the survival function) else: - fstat = coinc_count_matrix ** 2 / var_tot + fstat = coinc_count_matrix**2 / var_tot pr_f = f.sf(fstat[0][1], 1, n) # Creation of the dictionary with the results - en_neurons = copy.copy(ensemble['neurons']) + en_neurons = copy.copy(ensemble["neurons"]) en_neurons.append(n2) - en_lags = copy.copy(ensemble['lags']) + en_lags = copy.copy(ensemble["lags"]) en_lags.append(best_lag) - en_pvalue = copy.copy(ensemble['pvalue']) + en_pvalue = copy.copy(ensemble["pvalue"]) en_pvalue.append(pr_f) - en_n_occ = copy.copy(ensemble['signature']) + en_n_occ = copy.copy(ensemble["signature"]) en_n_occ.append([len(en_neurons), sum(activation_series)]) - assembly = {'neurons': en_neurons, - 'lags': en_lags, - 'pvalue': en_pvalue, - 'times': activation_series, - 'signature': en_n_occ} + assembly = { + "neurons": en_neurons, + "lags": en_lags, + "pvalue": en_pvalue, + "times": activation_series, + "signature": en_n_occ, + } if same_configuration_pruning: return assembly, item_candidate else: @@ -1024,14 +1087,13 @@ def _significance_pruning_step(pre_pruning_assembly): assembly = [] for i in range(nns): - elem = sorted(pre_pruning_assembly[i]['neurons']) + elem = sorted(pre_pruning_assembly[i]["neurons"]) # in the list, so that membership can be checked if elem in selection: # find the element that was already in the list pre = selection.index(elem) - if pre_pruning_assembly[i]['pvalue'][-1] <= \ - assembly[pre]['pvalue'][-1]: + if pre_pruning_assembly[i]["pvalue"][-1] <= assembly[pre]["pvalue"][-1]: # if the new element has a p-value that is smaller # than the one had previously selection[pre] = elem @@ -1075,10 +1137,10 @@ def _subgroup_pruning_step(pre_pruning_assembly): for i in range(nns): # check only in the range of the already selected assemblies if selection[i]: - a = pre_pruning_assembly_r[i]['neurons'] + a = pre_pruning_assembly_r[i]["neurons"] for j in range(i + 1, nns): if selection[j]: - b = pre_pruning_assembly_r[j]['neurons'] + b = 
pre_pruning_assembly_r[j]["neurons"] # check if a is included in b or vice versa if set(a).issuperset(set(b)): selection[j] = False @@ -1101,8 +1163,9 @@ def _subgroup_pruning_step(pre_pruning_assembly): return assembly -def _raise_errors(binned_spiketrain, max_lag, alpha, min_occurrences, - size_chunks, max_spikes): +def _raise_errors( + binned_spiketrain, max_lag, alpha, min_occurrences, size_chunks, max_spikes +): """ Returns errors if the parameters given in input are not correct. @@ -1144,29 +1207,31 @@ def _raise_errors(binned_spiketrain, max_lag, alpha, min_occurrences, """ if not isinstance(binned_spiketrain, conv.BinnedSpikeTrain): - raise TypeError( - 'data must be in BinnedSpikeTrain format') + raise TypeError("data must be in BinnedSpikeTrain format") if max_lag < 2: - raise ValueError('max_lag value cant be less than 2') + raise ValueError("max_lag value cant be less than 2") if alpha < 0 or alpha > 1: - raise ValueError('significance level has to be in interval [0,1]') + raise ValueError("significance level has to be in interval [0,1]") if min_occurrences < 1: - raise ValueError('minimal number of occurrences for an assembly ' - 'must be at least 1') + raise ValueError( + "minimal number of occurrences for an assembly " "must be at least 1" + ) if size_chunks < 2: - raise ValueError('length of the chunks cannot be 1 or less') + raise ValueError("length of the chunks cannot be 1 or less") if max_spikes < 2: - raise ValueError('maximal assembly order must be less than 2') + raise ValueError("maximal assembly order must be less than 2") if binned_spiketrain.shape[1] - max_lag < 100: - raise ValueError('The time series is too short, consider ' - 'taking a longer portion of spike train ' - 'or diminish the bin size to be tested') + raise ValueError( + "The time series is too short, consider " + "taking a longer portion of spike train " + "or diminish the bin size to be tested" + ) # alias for the function diff --git a/elephant/change_point_detection.py b/elephant/change_point_detection.py index d42b30a66..5c8dcb382 100644 --- a/elephant/change_point_detection.py +++ b/elephant/change_point_detection.py @@ -45,15 +45,19 @@ import numpy as np import quantities as pq -__all__ = [ - "multiple_filter_test", - "empirical_parameters" -] - - -def multiple_filter_test(window_sizes, spiketrain, t_final, alpha, - n_surrogates=1000, test_quantile=None, - test_param=None, time_step=None): +__all__ = ["multiple_filter_test", "empirical_parameters"] + + +def multiple_filter_test( + window_sizes, + spiketrain, + t_final, + alpha, + n_surrogates=1000, + test_quantile=None, + test_param=None, + time_step=None, +): """ Detects change points. 
@@ -107,15 +111,17 @@ def multiple_filter_test(window_sizes, spiketrain, t_final, alpha, """ if test_quantile is None and test_param is None: - test_quantile, test_param = empirical_parameters(window_sizes, t_final, - alpha, n_surrogates, - time_step) + test_quantile, test_param = empirical_parameters( + window_sizes, t_final, alpha, n_surrogates, time_step + ) elif test_quantile is None: - test_quantile = empirical_parameters(window_sizes, t_final, alpha, - n_surrogates, time_step)[0] + test_quantile = empirical_parameters( + window_sizes, t_final, alpha, n_surrogates, time_step + )[0] elif test_param is None: - test_param = empirical_parameters(window_sizes, t_final, alpha, - n_surrogates, time_step)[1] + test_param = empirical_parameters( + window_sizes, t_final, alpha, n_surrogates, time_step + )[1] # List of lists of detected change points (CPs), to be returned cps = [] @@ -124,8 +130,7 @@ def multiple_filter_test(window_sizes, spiketrain, t_final, alpha, # automatic setting of time_step dt_temp = h / 20 if time_step is None else time_step # filter_process for window of size h - t, differences = _filter_process(dt_temp, h, spiketrain, t_final, - test_param) + t, differences = _filter_process(dt_temp, h, spiketrain, t_final, test_param) time_index = np.arange(len(differences)) # Point detected with window h cps_window = [] @@ -193,8 +198,7 @@ def _brownian_motion(t_in, t_fin, x_in, time_step): except ValueError: raise ValueError("dt must be a time quantity") - x = np.random.normal(0, np.sqrt(dt_sec), - size=int((t_fin_sec - t_in_sec) / dt_sec)) + x = np.random.normal(0, np.sqrt(dt_sec), size=int((t_fin_sec - t_in_sec) / dt_sec)) s = np.cumsum(x) return s + x_in @@ -237,11 +241,11 @@ def _limit_processes(window_sizes, t_final, time_step): for h in window_sizes_sec: # BM on [h,T-h], shifted in time t-->t+h - brownian_right = w[int(2 * h / dt_sec):] + brownian_right = w[int(2 * h / dt_sec) :] # BM on [h,T-h], shifted in time t-->t-h - brownian_left = w[:int(-2 * h / dt_sec)] + brownian_left = w[: int(-2 * h / dt_sec)] # BM on [h,T-h] - brownian_center = w[int(h / dt_sec):int(-h / dt_sec)] + brownian_center = w[int(h / dt_sec) : int(-h / dt_sec)] modul = np.abs(brownian_right + brownian_left - 2 * brownian_center) limit_process_h = modul / (np.sqrt(2 * h)) @@ -250,8 +254,9 @@ def _limit_processes(window_sizes, t_final, time_step): return limit_processes -def empirical_parameters(window_sizes, t_final, alpha, n_surrogates=1000, - time_step=None): +def empirical_parameters( + window_sizes, t_final, alpha, n_surrogates=1000, time_step=None +): r""" This function generates the threshold and the null parameters. The filter processes (`h`) have been proved to converge (for `t_final`, @@ -341,9 +346,8 @@ def empirical_parameters(window_sizes, t_final, alpha, n_surrogates=1000, raise ValueError("window size too large") if time_step is not None: for h in window_sizes: - if int(h.rescale('us')) % int(time_step.rescale('us')) != 0: - raise ValueError( - "Every window size h must be a multiple of time_step") + if int(h.rescale("us")) % int(time_step.rescale("us")) != 0: + raise ValueError("Every window size h must be a multiple of time_step") # Generate a matrix M*: n X m where n = n_surrogates is the number of # simulated limit processes and m is the number of chosen window sizes. 
@@ -418,7 +422,8 @@ def _filter(t_center, window, spiketrain): spk_sec = spiketrain.rescale(u).magnitude except AttributeError: raise ValueError( - "spiketrain must be a list (array) of times or a neo spiketrain") + "spiketrain must be a list (array) of times or a neo spiketrain" + ) # cut spike-train on the right train_right = spk_sec[(t_sec < spk_sec) & (spk_sec < t_sec + h_sec)] @@ -516,7 +521,6 @@ def _filter_process(time_step, h, spk, t_final, test_param): filter_trajectrory = np.asanyarray(filter_trajectrory) # ordered normalization to give each process the same impact on the max - filter_process = ( - np.abs(filter_trajectrory) - emp_mean_h) / np.sqrt(emp_var_h) + filter_process = (np.abs(filter_trajectrory) - emp_mean_h) / np.sqrt(emp_var_h) return time_domain, filter_process diff --git a/elephant/conversion.py b/elephant/conversion.py index f3686d643..d7d382dea 100644 --- a/elephant/conversion.py +++ b/elephant/conversion.py @@ -84,17 +84,19 @@ import quantities as pq import scipy.sparse as sps -from elephant.utils import is_binary, is_time_quantity, \ - check_neo_consistency, round_binning_errors +from elephant.utils import ( + is_binary, + is_time_quantity, + check_neo_consistency, + round_binning_errors, +) -__all__ = [ - "binarize", - "BinnedSpikeTrain" -] +__all__ = ["binarize", "BinnedSpikeTrain"] -def binarize(spiketrain, sampling_rate=None, t_start=None, t_stop=None, - return_times=False): +def binarize( + spiketrain, sampling_rate=None, t_start=None, t_stop=None, return_times=False +): """ Return an array indicating if spikes occurred at individual time points. @@ -173,46 +175,52 @@ def binarize(spiketrain, sampling_rate=None, t_start=None, t_stop=None, """ # get the values from spiketrain if they are not specified. if sampling_rate is None: - sampling_rate = getattr(spiketrain, 'sampling_rate', None) + sampling_rate = getattr(spiketrain, "sampling_rate", None) if sampling_rate is None: - raise ValueError('sampling_rate must either be explicitly defined ' - 'or must be an attribute of spiketrain') + raise ValueError( + "sampling_rate must either be explicitly defined " + "or must be an attribute of spiketrain" + ) if t_start is None: - t_start = getattr(spiketrain, 't_start', 0) + t_start = getattr(spiketrain, "t_start", 0) if t_stop is None: - t_stop = getattr(spiketrain, 't_stop', np.max(spiketrain)) + t_stop = getattr(spiketrain, "t_stop", np.max(spiketrain)) # we don't actually want the sampling rate, we want the sampling period - sampling_period = 1. 
/ sampling_rate + sampling_period = 1.0 / sampling_rate # figure out what units, if any, we are dealing with - if hasattr(spiketrain, 'units'): + if hasattr(spiketrain, "units"): units = spiketrain.units spiketrain = spiketrain.magnitude else: units = None # convert everything to the same units, then get the magnitude - if hasattr(sampling_period, 'units'): + if hasattr(sampling_period, "units"): if units is None: - raise TypeError('sampling_period cannot be a Quantity if ' - 'spiketrain is not a quantity') + raise TypeError( + "sampling_period cannot be a Quantity if " + "spiketrain is not a quantity" + ) sampling_period = sampling_period.rescale(units).magnitude - if hasattr(t_start, 'units'): + if hasattr(t_start, "units"): if units is None: - raise TypeError('t_start cannot be a Quantity if ' - 'spiketrain is not a quantity') + raise TypeError( + "t_start cannot be a Quantity if " "spiketrain is not a quantity" + ) t_start = t_start.rescale(units).magnitude - if hasattr(t_stop, 'units'): + if hasattr(t_stop, "units"): if units is None: - raise TypeError('t_stop cannot be a Quantity if ' - 'spiketrain is not a quantity') + raise TypeError( + "t_stop cannot be a Quantity if " "spiketrain is not a quantity" + ) t_stop = t_stop.rescale(units).magnitude # figure out the bin edges - edges = np.arange(t_start - sampling_period / 2, - t_stop + sampling_period * 3 / 2, - sampling_period) + edges = np.arange( + t_start - sampling_period / 2, t_stop + sampling_period * 3 / 2, sampling_period + ) # we don't want to count any spikes before t_start or after t_stop if edges[-2] > t_stop: edges = edges[:-1] @@ -222,16 +230,16 @@ def binarize(spiketrain, sampling_rate=None, t_start=None, t_stop=None, edges[-1] = t_stop # this is where we actually get the binarized spike train - res = np.histogram(spiketrain, edges)[0].astype('bool') + res = np.histogram(spiketrain, edges)[0].astype("bool") # figure out what to output if not return_times: return res if units is None: - return res, np.arange(t_start, t_stop + sampling_period, - sampling_period) - return res, pq.Quantity(np.arange(t_start, t_stop + sampling_period, - sampling_period), units=units) + return res, np.arange(t_start, t_stop + sampling_period, sampling_period) + return res, pq.Quantity( + np.arange(t_start, t_stop + sampling_period, sampling_period), units=units + ) ########################################################################### @@ -334,11 +342,21 @@ class BinnedSpikeTrain(object): """ - def __init__(self, spiketrains, bin_size=None, n_bins=None, t_start=None, - t_stop=None, tolerance=1e-8, sparse_format="csr"): + def __init__( + self, + spiketrains, + bin_size=None, + n_bins=None, + t_start=None, + t_stop=None, + tolerance=1e-8, + sparse_format="csr", + ): if sparse_format not in ("csr", "csc"): - raise ValueError(f"Invalid 'sparse_format': {sparse_format}. " - "Available: 'csr' and 'csc'") + raise ValueError( + f"Invalid 'sparse_format': {sparse_format}. 
" + "Available: 'csr' and 'csc'" + ) # Converting spiketrains to a list, if spiketrains is one # SpikeTrain object @@ -356,7 +374,8 @@ def __init__(self, spiketrains, bin_size=None, n_bins=None, t_start=None, self._resolve_input_parameters(spiketrains) # Now create the sparse matrix self.sparse_matrix = self._create_sparse_matrix( - spiketrains, sparse_format=sparse_format) + spiketrains, sparse_format=sparse_format + ) @property def shape(self): @@ -387,10 +406,12 @@ def t_stop(self): return pq.Quantity(self._t_stop, units=self.units, copy=False) def __repr__(self): - return f"{type(self).__name__}(t_start={str(self.t_start)}, " \ - f"t_stop={str(self.t_stop)}, bin_size={str(self.bin_size)}; " \ - f"shape={self.shape}, " \ - f"format={self.sparse_matrix.__class__.__name__})" + return ( + f"{type(self).__name__}(t_start={str(self.t_start)}, " + f"t_stop={str(self.t_stop)}, bin_size={str(self.bin_size)}; " + f"shape={self.shape}, " + f"format={self.sparse_matrix.__class__.__name__})" + ) def rescale(self, units): """ @@ -422,21 +443,27 @@ def rescale(self, units): def __resolve_binned(self, spiketrains): spiketrains = np.asarray(spiketrains) - if spiketrains.ndim != 2 or spiketrains.dtype == np.dtype('O'): - raise ValueError("If the input is not a spiketrain(s), it " - "must be an MxN numpy array, each cell of " - "which represents the number of (binned) " - "spikes that fall in an interval - not " - "raw spike times.") + if spiketrains.ndim != 2 or spiketrains.dtype == np.dtype("O"): + raise ValueError( + "If the input is not a spiketrain(s), it " + "must be an MxN numpy array, each cell of " + "which represents the number of (binned) " + "spikes that fall in an interval - not " + "raw spike times." + ) if self.n_bins is not None: - raise ValueError("When the input is a binned matrix, 'n_bins' " - "must be set to None - it's extracted from the " - "input shape.") + raise ValueError( + "When the input is a binned matrix, 'n_bins' " + "must be set to None - it's extracted from the " + "input shape." 
+ ) self.n_bins = spiketrains.shape[1] if self._bin_size is None: if self._t_start is None or self._t_stop is None: - raise ValueError("To determine the bin size, both 't_start' " - "and 't_stop' must be set") + raise ValueError( + "To determine the bin size, both 't_start' " + "and 't_stop' must be set" + ) self._bin_size = (self._t_stop - self._t_start) / self.n_bins if self._t_start is None and self._t_stop is None: raise ValueError("Either 't_start' or 't_stop' must be set") @@ -458,6 +485,7 @@ def _resolve_input_parameters(self, spiketrains): spiketrains : neo.SpikeTrain or list or np.ndarray of neo.SpikeTrain """ + def get_n_bins(): n_bins = (self._t_stop - self._t_start) / self._bin_size if isinstance(n_bins, pq.Quantity): @@ -471,15 +499,22 @@ def check_n_bins_consistency(): "Inconsistent arguments: t_start ({t_start}), " "t_stop ({t_stop}), bin_size ({bin_size}), and " "n_bins ({n_bins})".format( - t_start=self.t_start, t_stop=self.t_stop, - bin_size=self.bin_size, n_bins=self.n_bins)) + t_start=self.t_start, + t_stop=self.t_stop, + bin_size=self.bin_size, + n_bins=self.n_bins, + ) + ) def check_consistency(): if self.t_start >= self.t_stop: raise ValueError("t_start must be smaller than t_stop") if not isinstance(self.n_bins, int) or self.n_bins <= 0: - raise TypeError("The number of bins ({}) must be a positive " - "integer".format(self.n_bins)) + raise TypeError( + "The number of bins ({}) must be a positive " "integer".format( + self.n_bins + ) + ) if not _check_neo_spiketrain(spiketrains): # a binned numpy matrix @@ -496,20 +531,24 @@ def check_consistency(): raise ValueError("Either 'bin_size' or 'n_bins' must be given") try: - check_neo_consistency(spiketrains, - object_type=neo.SpikeTrain, - t_start=self._t_start, - t_stop=self._t_stop, - tolerance=self.tolerance) + check_neo_consistency( + spiketrains, + object_type=neo.SpikeTrain, + t_start=self._t_start, + t_stop=self._t_stop, + tolerance=self.tolerance, + ) except ValueError as er: # different t_start/t_stop - raise ValueError(er, "If you want to bin over the shared " - "[t_start, t_stop] interval, provide " - "shared t_start and t_stop explicitly, " - "which can be obtained like so: " - "t_start, t_stop = elephant.utils." - "get_common_start_stop_times(spiketrains)" - ) + raise ValueError( + er, + "If you want to bin over the shared " + "[t_start, t_stop] interval, provide " + "shared t_start and t_stop explicitly, " + "which can be obtained like so: " + "t_start, t_stop = elephant.utils." 
+ "get_common_start_stop_times(spiketrains)", + ) if self._t_start is None: self._t_start = spiketrains[0].t_start @@ -523,22 +562,26 @@ def check_consistency(): self._t_start = self._t_start.rescale(self.units).item() self._t_stop = self._t_stop.rescale(self.units).item() - start_shared = max(st.t_start.rescale(self.units).item() - for st in spiketrains) - stop_shared = min(st.t_stop.rescale(self.units).item() - for st in spiketrains) + start_shared = max(st.t_start.rescale(self.units).item() for st in spiketrains) + stop_shared = min(st.t_stop.rescale(self.units).item() for st in spiketrains) tolerance = self.tolerance if tolerance is None: tolerance = 0 - if self._t_start < start_shared - tolerance \ - or self._t_stop > stop_shared + tolerance: - raise ValueError("'t_start' ({t_start}) or 't_stop' ({t_stop}) is " - "outside of the shared [{start_shared}, " - "{stop_shared}] interval".format( - t_start=self.t_start, t_stop=self.t_stop, - start_shared=start_shared, - stop_shared=stop_shared)) + if ( + self._t_start < start_shared - tolerance + or self._t_stop > stop_shared + tolerance + ): + raise ValueError( + "'t_start' ({t_start}) or 't_stop' ({t_stop}) is " + "outside of the shared [{start_shared}, " + "{stop_shared}] interval".format( + t_start=self.t_start, + t_stop=self.t_stop, + start_shared=start_shared, + stop_shared=stop_shared, + ) + ) if self.n_bins is None: # bin_size is provided @@ -573,9 +616,12 @@ def bin_edges(self): All edges in interval [:attr:`t_start`, :attr:`t_stop`] with :attr:`n_bins` bins are returned as a quantity array. """ - bin_edges = np.linspace(self._t_start, self._t_start + self.n_bins * - self._bin_size, - num=self.n_bins + 1, endpoint=True) + bin_edges = np.linspace( + self._t_start, + self._t_start + self.n_bins * self._bin_size, + num=self.n_bins + 1, + endpoint=True, + ) return pq.Quantity(bin_edges, units=self.units, copy=False) @property @@ -594,9 +640,9 @@ def bin_centers(self): """ start = self._t_start + self._bin_size / 2 stop = start + (self.n_bins - 1) * self._bin_size - bin_centers = np.linspace(start=start, - stop=stop, - num=self.n_bins, endpoint=True) + bin_centers = np.linspace( + start=start, stop=stop, num=self.n_bins, endpoint=True + ) bin_centers = pq.Quantity(bin_centers, units=self.units, copy=False) return bin_centers @@ -634,12 +680,17 @@ def __eq__(self, other): return False sp1 = self.sparse_matrix sp2 = other.sparse_matrix - if sp1.__class__ is not sp2.__class__ or sp1.shape != sp2.shape \ - or sp1.data.shape != sp2.data.shape: + if ( + sp1.__class__ is not sp2.__class__ + or sp1.shape != sp2.shape + or sp1.data.shape != sp2.data.shape + ): return False - return (sp1.data == sp2.data).all() and \ - (sp1.indptr == sp2.indptr).all() and \ - (sp1.indices == sp2.indices).all() + return ( + (sp1.data == sp2.data).all() + and (sp1.indptr == sp2.indptr).all() + and (sp1.indices == sp2.indices).all() + ) def copy(self): """ @@ -651,20 +702,24 @@ def copy(self): BinnedSpikeTrainView A copied view of itself. 
""" - return BinnedSpikeTrainView(t_start=self._t_start, - t_stop=self._t_stop, - bin_size=self._bin_size, - units=self.units, - sparse_matrix=self.sparse_matrix.copy(), - tolerance=self.tolerance) + return BinnedSpikeTrainView( + t_start=self._t_start, + t_stop=self._t_stop, + bin_size=self._bin_size, + units=self.units, + sparse_matrix=self.sparse_matrix.copy(), + tolerance=self.tolerance, + ) def __iter_sparse_matrix(self): spmat = self.sparse_matrix if isinstance(spmat, sps.csc_matrix): - warnings.warn("The sparse matrix format is CSC. For better " - "performance, specify the CSR format while " - "constructing a " - "BinnedSpikeTrain(sparse_format='csr')") + warnings.warn( + "The sparse matrix format is CSC. For better " + "performance, specify the CSR format while " + "constructing a " + "BinnedSpikeTrain(sparse_format='csr')" + ) spmat = spmat.tocsr() # taken from csr_matrix.__iter__() i0 = 0 @@ -703,18 +758,22 @@ def __getitem__(self, item): elif isinstance(col, slice): start, stop, stride = col.indices(self.n_bins) else: - raise TypeError(f"The second slice argument ({col}), which " - "corresponds to bin indices, must be either int " - "or slice.") + raise TypeError( + f"The second slice argument ({col}), which " + "corresponds to bin indices, must be either int " + "or slice." + ) t_start = self._t_start + start * self._bin_size t_stop = self._t_start + stop * self._bin_size bin_size = stride * self._bin_size - bst = BinnedSpikeTrainView(t_start=t_start, - t_stop=t_stop, - bin_size=bin_size, - units=self.units, - sparse_matrix=spmat, - tolerance=self.tolerance) + bst = BinnedSpikeTrainView( + t_start=t_start, + t_stop=t_stop, + bin_size=bin_size, + units=self.units, + sparse_matrix=spmat, + tolerance=self.tolerance, + ) return bst def __setitem__(self, key, value): @@ -767,25 +826,25 @@ def time_slice(self, t_start=None, t_stop=None, copy=False): else: t_stop = t_stop.rescale(self.units).item() stop_index = (t_stop - self._t_start) / self._bin_size - stop_index = round_binning_errors(stop_index, - tolerance=self.tolerance) + stop_index = round_binning_errors(stop_index, tolerance=self.tolerance) stop_index = min(stop_index, self.n_bins) stop_index = max(stop_index, start_index) - spmat = self.sparse_matrix[:, start_index: stop_index] + spmat = self.sparse_matrix[:, start_index:stop_index] if copy: spmat = spmat.copy() t_start = self._t_start + start_index * self._bin_size t_stop = self._t_start + stop_index * self._bin_size - bst = BinnedSpikeTrainView(t_start=t_start, - t_stop=t_stop, - bin_size=self._bin_size, - units=self.units, - sparse_matrix=spmat, - tolerance=self.tolerance) + bst = BinnedSpikeTrainView( + t_start=t_start, + t_stop=t_stop, + bin_size=self._bin_size, + units=self.units, + sparse_matrix=spmat, + tolerance=self.tolerance, + ) return bst - def to_spike_trains(self, spikes="random", as_array=False, - annotate_bins=False): + def to_spike_trains(self, spikes="random", as_array=False, annotate_bins=False): """ Generate spike trains from the binned spike train object. 
This function is inverse to binning such that @@ -835,13 +894,16 @@ def to_spike_trains(self, spikes="random", as_array=False, bin_indices = np.repeat(indices, spike_count) t_starts = self._t_start + bin_indices * self._bin_size if spikes == "random": - spiketrain = np.random.uniform(low=0, high=self._bin_size, - size=spike_count.sum()) + spiketrain = np.random.uniform( + low=0, high=self._bin_size, size=spike_count.sum() + ) spiketrain += t_starts spiketrain.sort() elif spikes == "left": - spiketrain = [np.arange(shift, count + shift) / (count + shift) - for count in spike_count] + spiketrain = [ + np.arange(shift, count + shift) / (count + shift) + for count in spike_count + ] spiketrain = np.hstack(spiketrain) * self._bin_size spiketrain += t_starts else: @@ -852,12 +914,16 @@ def to_spike_trains(self, spikes="random", as_array=False, array_ants = None if annotate_bins: array_ants = dict(bins=bin_indices) - spiketrain = neo.SpikeTrain(spiketrain, t_start=self._t_start, - t_stop=self._t_stop, - units=self.units, copy=False, - description=description, - array_annotations=array_ants, - bin_size=self.bin_size) + spiketrain = neo.SpikeTrain( + spiketrain, + t_start=self._t_start, + t_stop=self._t_stop, + units=self.units, + copy=False, + description=description, + array_annotations=array_ants, + bin_size=self.bin_size, + ) spiketrains.append(spiketrain) return spiketrains @@ -1036,14 +1102,16 @@ def binarize(self, copy=True): if copy: data = np.ones(len(spmat.data), dtype=spmat.data.dtype) spmat = spmat.__class__( - (data, spmat.indices, spmat.indptr), - shape=spmat.shape, copy=False) - bst = BinnedSpikeTrainView(t_start=self._t_start, - t_stop=self._t_stop, - bin_size=self._bin_size, - units=self.units, - sparse_matrix=spmat, - tolerance=self.tolerance) + (data, spmat.indices, spmat.indptr), shape=spmat.shape, copy=False + ) + bst = BinnedSpikeTrainView( + t_start=self._t_start, + t_stop=self._t_stop, + bin_size=self._bin_size, + units=self.units, + sparse_matrix=spmat, + tolerance=self.tolerance, + ) else: spmat.data[:] = 1 bst = self @@ -1081,7 +1149,7 @@ def _create_sparse_matrix(self, spiketrains, sparse_format): # The data type for numeric values data_dtype = np.int32 - if sparse_format == 'csr': + if sparse_format == "csr": sparse_format = sps.csr_matrix else: # csc @@ -1109,8 +1177,10 @@ def _create_sparse_matrix(self, spiketrains, sparse_format): scale_units = 1 / self._bin_size for idx, st in enumerate(spiketrains): times = st.magnitude - times = times[(times >= self._t_start) & ( - times <= self._t_stop)] - self._t_start + times = ( + times[(times >= self._t_start) & (times <= self._t_stop)] + - self._t_start + ) bins = times * scale_units # shift spikes that are very close @@ -1127,8 +1197,11 @@ def _create_sparse_matrix(self, spiketrains, sparse_format): row_ids.append(np.repeat(idx, repeats=len(f)).astype(numtype)) if n_discarded > 0: - warnings.warn("Binning discarded {} last spike(s) of the " - "input spiketrain".format(n_discarded)) + warnings.warn( + "Binning discarded {} last spike(s) of the " "input spiketrain".format( + n_discarded + ) + ) # Stacking preserves the data type. 
In any case, while creating # the sparse matrix, a copy is performed even if we set 'copy' to False @@ -1138,9 +1211,9 @@ def _create_sparse_matrix(self, spiketrains, sparse_format): column_ids = np.hstack(column_ids) row_ids = np.hstack(row_ids) - sparse_matrix = sparse_format((counts, (row_ids, column_ids)), - shape=shape, dtype=data_dtype, - copy=False) + sparse_matrix = sparse_format( + (counts, (row_ids, column_ids)), shape=shape, dtype=data_dtype, copy=False + ) return sparse_matrix @@ -1172,8 +1245,7 @@ class BinnedSpikeTrainView(BinnedSpikeTrain): This class is an experimental feature. """ - def __init__(self, t_start, t_stop, bin_size, units, sparse_matrix, - tolerance=1e-8): + def __init__(self, t_start, t_stop, bin_size, units, sparse_matrix, tolerance=1e-8): self._t_start = t_start self._t_stop = t_stop self._bin_size = bin_size @@ -1259,31 +1331,35 @@ def discretise_spiketimes(spiketrains, sampling_rate): if not isinstance(st, (np.ndarray, neo.SpikeTrain)): raise TypeError( "spiketrains must be a SpikeTrain, a numpy ndarray, or a " - "list of one of those, not %s." % type(spiketrains)) + "list of one of those, not %s." % type(spiketrains) + ) else: raise TypeError( "spiketrains must be a SpikeTrain or a list of SpikeTrain objects," - " not %s." % type(spiketrains)) + " not %s." % type(spiketrains) + ) if not isinstance(sampling_rate, pq.Quantity): raise TypeError( - "The 'sampling_rate' must be pq.Quantity.\n" - "Found: %s." % type(sampling_rate)) + "The 'sampling_rate' must be pq.Quantity.\n" + "Found: %s." % type(sampling_rate) + ) units = spiketrains[0].times.units - mag_sampling_rate = sampling_rate.rescale(1/units).magnitude.flatten() + mag_sampling_rate = sampling_rate.rescale(1 / units).magnitude.flatten() new_spiketrains = [] for spiketrain in spiketrains: mag_t_start = spiketrain.t_start.rescale(units).magnitude.flatten() mag_times = spiketrain.times.magnitude.flatten() - discrete_times = (mag_times // (1 / mag_sampling_rate) - / mag_sampling_rate) + discrete_times = mag_times // (1 / mag_sampling_rate) / mag_sampling_rate mask = discrete_times < mag_t_start if np.any(mask): - warnings.warn(f'{mask.sum()} spike(s) would be before t_start ' - 'and are set to t_start instead.') + warnings.warn( + f"{mask.sum()} spike(s) would be before t_start " + "and are set to t_start instead." + ) discrete_times[mask] = mag_t_start discrete_times *= units diff --git a/elephant/cubic.py b/elephant/cubic.py index ac117406e..daf9682e4 100644 --- a/elephant/cubic.py +++ b/elephant/cubic.py @@ -52,9 +52,7 @@ import scipy.special import scipy.stats -__all__ = [ - "cubic" -] +__all__ = ["cubic"] # Based on matlab code by Benjamin Staude @@ -107,12 +105,14 @@ def cubic(histogram, max_iterations=100, alpha=0.05): iteration, `max_iterations`. """ if alpha < 0 or alpha > 1: - raise ValueError(f'the significance level alpha ({alpha}) has to be ' - f'in [0, 1] range') + raise ValueError( + f"the significance level alpha ({alpha}) has to be " f"in [0, 1] range" + ) if not isinstance(max_iterations, int) or max_iterations < 0: - raise ValueError(f"'max_iterations' ({max_iterations}) has to be a " - "positive integer") + raise ValueError( + f"'max_iterations' ({max_iterations}) has to be a " "positive integer" + ) # dict of all possible rate functions try: @@ -125,7 +125,7 @@ def cubic(histogram, max_iterations=100, alpha=0.05): kappa = _kstat(histogram) xi_hat = 1 xi = 1 - pval = 0. 
+ pval = 0.0 p = [] test_aborted = False @@ -133,8 +133,10 @@ def cubic(histogram, max_iterations=100, alpha=0.05): while pval < alpha: xi_hat = xi if xi > max_iterations: - warnings.warn(f'Test aborted after ximax={max_iterations} ' - f'iterations with p-value={pval}') + warnings.warn( + f"Test aborted after ximax={max_iterations} " + f"iterations with p-value={pval}" + ) test_aborted = True break @@ -170,11 +172,13 @@ def _H03xi(kappa, xi, L): # Check the order condition of the cumulants necessary to perform CuBIC if kappa[1] < kappa[0]: - raise ValueError(f"The null hypothesis H_0 cannot be tested: the " - f"population count histogram variance ({kappa[1]}) " - f"is less than the mean ({kappa[0]}). This can " - f"happen when the spike train population is not " - f"large enough or the bin size is small.") + raise ValueError( + f"The null hypothesis H_0 cannot be tested: the " + f"population count histogram variance ({kappa[1]}) " + f"is less than the mean ({kappa[0]}). This can " + f"happen when the spike train population is not " + f"large enough or the bin size is small." + ) else: # computation of the maximized cumulants kstar = [_kappamstar(kappa[:2], i, xi) for i in range(2, 7)] @@ -182,8 +186,10 @@ def _H03xi(kappa, xi, L): # variance of third cumulant (from Stuart & Ord) sigmak3star = math.sqrt( - kstar[4] / L + 9 * (kstar[2] * kstar[0] + kstar[1] ** 2) / - (L - 1) + 6 * L * kstar[0] ** 3 / ((L - 1) * (L - 2))) + kstar[4] / L + + 9 * (kstar[2] * kstar[0] + kstar[1] ** 2) / (L - 1) + + 6 * L * kstar[0] ** 3 / ((L - 1) * (L - 2)) + ) # computation of the p-value (the third cumulant is supposed to # be gaussian distributed) p = 1 - scipy.stats.norm(k3star, sigmak3star).cdf(kappa[2]) @@ -212,9 +218,9 @@ def _kappamstar(kappa, m, xi): if xi == 1: kappa_out = kappa[1] else: - kappa_out = \ - (kappa[1] * (xi ** (m - 1) - 1) - - kappa[0] * (xi ** (m - 1) - xi)) / (xi - 1) + kappa_out = ( + kappa[1] * (xi ** (m - 1) - 1) - kappa[0] * (xi ** (m - 1) - xi) + ) / (xi - 1) return kappa_out @@ -236,6 +242,6 @@ def _kstat(data): The first three unbiased cumulants of the population count """ if len(data) == 0: - raise ValueError('The input data must be a non-empty array') + raise ValueError("The input data must be a non-empty array") moments = [scipy.stats.kstat(data, n=n) for n in [1, 2, 3]] return moments diff --git a/elephant/current_source_density.py b/elephant/current_source_density.py index 00a25f0b4..6715a3bca 100644 --- a/elephant/current_source_density.py +++ b/elephant/current_source_density.py @@ -43,25 +43,23 @@ import elephant.current_source_density_src.utility_functions as utils from elephant.current_source_density_src import KCSD, icsd -__all__ = [ - "estimate_csd", - "generate_lfp" -] +__all__ = ["estimate_csd", "generate_lfp"] utils.patch_quantities() -available_1d = ['StandardCSD', 'DeltaiCSD', 'StepiCSD', 'SplineiCSD', 'KCSD1D'] -available_2d = ['KCSD2D', 'MoIKCSD'] -available_3d = ['KCSD3D'] +available_1d = ["StandardCSD", "DeltaiCSD", "StepiCSD", "SplineiCSD", "KCSD1D"] +available_2d = ["KCSD2D", "MoIKCSD"] +available_3d = ["KCSD3D"] -kernel_methods = ['KCSD1D', 'KCSD2D', 'KCSD3D', 'MoIKCSD'] -icsd_methods = ['DeltaiCSD', 'StepiCSD', 'SplineiCSD'] +kernel_methods = ["KCSD1D", "KCSD2D", "KCSD3D", "MoIKCSD"] +icsd_methods = ["DeltaiCSD", "StepiCSD", "SplineiCSD"] -py_iCSD_toolbox = ['StandardCSD'] + icsd_methods +py_iCSD_toolbox = ["StandardCSD"] + icsd_methods -def estimate_csd(lfp, coordinates='coordinates', method=None, - process_estimate=True, **kwargs): +def estimate_csd( + 
lfp, coordinates="coordinates", method=None, process_estimate=True, **kwargs +): """ Function call to compute the current source density (CSD) from extracellular potential recordings (local field potentials - LFP) using @@ -115,7 +113,7 @@ def estimate_csd(lfp, coordinates='coordinates', method=None, Invalid cv_param argument passed """ if not isinstance(lfp, neo.AnalogSignal): - raise TypeError('Parameter `lfp` must be a neo.AnalogSignal object') + raise TypeError("Parameter `lfp` must be a neo.AnalogSignal object") if isinstance(coordinates, str): coordinates = lfp.annotations[coordinates] @@ -125,91 +123,104 @@ def estimate_csd(lfp, coordinates='coordinates', method=None, try: scaled_coords.append(coord.rescale(pq.mm)) except AttributeError: - raise AttributeError('No units given for electrode spatial \ - coordinates') + raise AttributeError( + "No units given for electrode spatial \ + coordinates" + ) coordinates = scaled_coords if method is None: - raise ValueError('Must specify a method of CSD implementation') + raise ValueError("Must specify a method of CSD implementation") if len(coordinates) != lfp.shape[1]: - raise ValueError('Number of signals and coords is not same') + raise ValueError("Number of signals and coords is not same") for ii in coordinates: # CHECK for Dimensionality of electrodes if len(ii) > 3: - raise ValueError('Invalid number of coordinate positions') + raise ValueError("Invalid number of coordinate positions") dim = len(coordinates[0]) # TODO : Generic co-ordinates! if dim == 1 and (method not in available_1d): - raise ValueError('Invalid method, Available options are:', - available_1d) + raise ValueError("Invalid method, Available options are:", available_1d) if dim == 2 and (method not in available_2d): - raise ValueError('Invalid method, Available options are:', - available_2d) + raise ValueError("Invalid method, Available options are:", available_2d) if dim == 3 and (method not in available_3d): - raise ValueError('Invalid method, Available options are:', - available_3d) + raise ValueError("Invalid method, Available options are:", available_3d) if method in kernel_methods: input_array = np.zeros((len(lfp), lfp[0].magnitude.shape[0])) for ii, jj in enumerate(lfp): input_array[ii, :] = jj.rescale(pq.mV).magnitude kernel_method = getattr(KCSD, method) # fetch the class 'KCSD1D' - lambdas = kwargs.pop('lambdas', None) - Rs = kwargs.pop('Rs', None) + lambdas = kwargs.pop("lambdas", None) + Rs = kwargs.pop("Rs", None) k = kernel_method(np.array(coordinates), input_array.T, **kwargs) if process_estimate: k.cross_validate(lambdas, Rs) estm_csd = k.values() estm_csd = np.rollaxis(estm_csd, -1, 0) - output = neo.AnalogSignal(estm_csd * pq.uA / pq.mm**3, - t_start=lfp.t_start, - sampling_rate=lfp.sampling_rate) + output = neo.AnalogSignal( + estm_csd * pq.uA / pq.mm**3, + t_start=lfp.t_start, + sampling_rate=lfp.sampling_rate, + ) if dim == 1: output.annotate(x_coords=k.estm_x) elif dim == 2: output.annotate(x_coords=k.estm_x, y_coords=k.estm_y) elif dim == 3: - output.annotate(x_coords=k.estm_x, y_coords=k.estm_y, - z_coords=k.estm_z) + output.annotate(x_coords=k.estm_x, y_coords=k.estm_y, z_coords=k.estm_z) elif method in py_iCSD_toolbox: - coordinates = np.array(coordinates) * coordinates[0].units if method in icsd_methods: try: - coordinates = coordinates.rescale(kwargs['diam'].units) + coordinates = coordinates.rescale(kwargs["diam"].units) except KeyError: # Then why specify as a default in icsd? 
# All iCSD methods explicitly assume a source # diameter in contrast to the stdCSD that # implicitly assume infinite source radius - raise ValueError(f"Parameter diam must be specified for iCSD " - f"methods: {', '.join(icsd_methods)}") + raise ValueError( + f"Parameter diam must be specified for iCSD " + f"methods: {', '.join(icsd_methods)}" + ) - if 'f_type' in kwargs: - if (kwargs['f_type'] != 'identity') and \ - (kwargs['f_order'] is None): - raise ValueError(f"The order of {kwargs['f_type']} filter must" - f" be specified") + if "f_type" in kwargs: + if (kwargs["f_type"] != "identity") and (kwargs["f_order"] is None): + raise ValueError( + f"The order of {kwargs['f_type']} filter must" f" be specified" + ) csd_method = getattr(icsd, method) # fetch class from icsd.py file - csd_estimator = csd_method(lfp=lfp.T.magnitude * lfp.units, - coord_electrode=coordinates.flatten(), - **kwargs) + csd_estimator = csd_method( + lfp=lfp.T.magnitude * lfp.units, + coord_electrode=coordinates.flatten(), + **kwargs, + ) csd_pqarr = csd_estimator.get_csd() if process_estimate: csd_pqarr_filtered = csd_estimator.filter_csd(csd_pqarr) - output = neo.AnalogSignal(csd_pqarr_filtered.T, - t_start=lfp.t_start, - sampling_rate=lfp.sampling_rate) + output = neo.AnalogSignal( + csd_pqarr_filtered.T, + t_start=lfp.t_start, + sampling_rate=lfp.sampling_rate, + ) else: - output = neo.AnalogSignal(csd_pqarr.T, t_start=lfp.t_start, - sampling_rate=lfp.sampling_rate) + output = neo.AnalogSignal( + csd_pqarr.T, t_start=lfp.t_start, sampling_rate=lfp.sampling_rate + ) output.annotate(x_coords=coordinates) return output -def generate_lfp(csd_profile, x_positions, y_positions=None, z_positions=None, - x_limits=[0., 1.], y_limits=[0., 1.], z_limits=[0., 1.], - resolution=50): +def generate_lfp( + csd_profile, + x_positions, + y_positions=None, + z_positions=None, + x_limits=[0.0, 1.0], + y_limits=[0.0, 1.0], + z_limits=[0.0, 1.0], + resolution=50, +): """ Forward modelling for getting the potentials for testing Current Source Density (CSD). @@ -279,7 +290,7 @@ def generate_lfp(csd_profile, x_positions, y_positions=None, z_positions=None, """ def integrate_1D(x0, csd_x, csd, h): - m = np.sqrt((csd_x - x0) ** 2 + h ** 2) - abs(csd_x - x0) + m = np.sqrt((csd_x - x0) ** 2 + h**2) - abs(csd_x - x0) y = csd * m I = simpson(y, x=csd_x) return I @@ -314,7 +325,7 @@ def integrate_3D(x, y, z, csd, xlin, ylin, zlin, X, Y, Z): x = np.linspace(x_limits[0], x_limits[1], resolution) sigma = 1.0 - h = 50. + h = 50.0 if dim == 1: # Handle one dimensional case, # see https://github.com/NeuralEnsemble/elephant/issues/546 @@ -323,42 +334,47 @@ def integrate_3D(x, y, z, csd, xlin, ylin, zlin, X, Y, Z): chrg_x = x csd = csd_profile(chrg_x) pots = integrate_1D(x_positions, chrg_x, csd, h) - pots /= 2. 
* sigma # eq.: 26 from Potworowski et al + pots /= 2.0 * sigma # eq.: 26 from Potworowski et al ele_pos = x_positions elif dim == 2: y = np.linspace(y_limits[0], y_limits[1], resolution) chrg_x = np.expand_dims(x, axis=1) chrg_y = np.expand_dims(y, axis=0) csd = csd_profile(chrg_x, chrg_y) - pots = integrate_2D(x_positions, y_positions, - x, y, - csd, h, - chrg_x, chrg_y) + pots = integrate_2D(x_positions, y_positions, x, y, csd, h, chrg_x, chrg_y) pots /= 2 * np.pi * sigma ele_pos = np.vstack((x_positions, y_positions)).T elif dim == 3: y = np.linspace(y_limits[0], y_limits[1], resolution) z = np.linspace(z_limits[0], z_limits[1], resolution) chrg_x, chrg_y, chrg_z = np.mgrid[ - x_limits[0]: x_limits[1]: complex(0, resolution), - y_limits[0]: y_limits[1]: complex(0, resolution), - z_limits[0]: z_limits[1]: complex(0, resolution) + x_limits[0] : x_limits[1] : complex(0, resolution), + y_limits[0] : y_limits[1] : complex(0, resolution), + z_limits[0] : z_limits[1] : complex(0, resolution), ] csd = csd_profile(chrg_x, chrg_y, chrg_z) pots = np.zeros(len(x_positions)) for ii in range(len(x_positions)): - pots[ii] = integrate_3D(x_positions[ii], y_positions[ii], - z_positions[ii], - csd, - x, y, z, - chrg_x, chrg_y, chrg_z) + pots[ii] = integrate_3D( + x_positions[ii], + y_positions[ii], + z_positions[ii], + csd, + x, + y, + z, + chrg_x, + chrg_y, + chrg_z, + ) pots /= 4 * np.pi * sigma ele_pos = np.vstack((x_positions, y_positions, z_positions)).T ele_pos = ele_pos * pq.mm - asig = neo.AnalogSignal(np.expand_dims(pots, axis=0), - sampling_rate=pq.kHz, units='mV') + asig = neo.AnalogSignal( + np.expand_dims(pots, axis=0), sampling_rate=pq.kHz, units="mV" + ) asig.annotate(coordinates=ele_pos) return asig diff --git a/elephant/current_source_density_src/KCSD.py b/elephant/current_source_density_src/KCSD.py index 727d41b44..cf1624085 100644 --- a/elephant/current_source_density_src/KCSD.py +++ b/elephant/current_source_density_src/KCSD.py @@ -8,6 +8,7 @@ Nencki Institute of Exprimental Biology, Warsaw. KCSD1D[1][2], KCSD2D[1], KCSD3D[1], MoIKCSD[1] """ + from __future__ import division import numpy as np @@ -23,6 +24,7 @@ class CSD(object): """CSD - The base class for KCSD methods.""" + def __init__(self, ele_pos, pots): self.validate(ele_pos, pots) self.ele_pos = ele_pos @@ -43,11 +45,13 @@ def validate(self, ele_pos, pots): potentials measured by electrodes """ if ele_pos.shape[0] != pots.shape[0]: - raise Exception("Number of measured potentials is not equal " - "to electrode number!") - if ele_pos.shape[0] < 1+ele_pos.shape[1]: #Dim+1 - raise Exception("Number of electrodes must be at least :", - 1+ele_pos.shape[1]) + raise Exception( + "Number of measured potentials is not equal " "to electrode number!" + ) + if ele_pos.shape[0] < 1 + ele_pos.shape[1]: # Dim+1 + raise Exception( + "Number of electrodes must be at least :", 1 + ele_pos.shape[1] + ) if utils.contains_duplicated_electrodes(ele_pos): raise Exception("Error! Duplicated electrode!") @@ -67,6 +71,7 @@ def sanity(self, true_csd, pos_csd): RMSE = np.sqrt(np.mean(np.square(true_csd - csd))) return RMSE + class KCSD(CSD): """KCSD - The base class for all the KCSD variants. This estimates the Current Source Density, for a given configuration of @@ -74,6 +79,7 @@ class KCSD(CSD): The method implented here is based on the original paper by Jan Potworowski et.al. 2012. 
""" + def __init__(self, ele_pos, pots, **kwargs): super(KCSD, self).__init__(ele_pos, pots) self.parameters(**kwargs) @@ -89,28 +95,28 @@ def parameters(self, **kwargs): **kwargs Same as those passed to initialize the Class """ - self.src_type = kwargs.pop('src_type', 'gauss') - self.sigma = kwargs.pop('sigma', 1.0) - self.h = kwargs.pop('h', 1.0) - self.n_src_init = kwargs.pop('n_src_init', 1000) - self.lambd = kwargs.pop('lambd', 0.0) - self.R_init = kwargs.pop('R_init', 0.23) - self.ext_x = kwargs.pop('ext_x', 0.0) - self.xmin = kwargs.pop('xmin', np.min(self.ele_pos[:, 0])) - self.xmax = kwargs.pop('xmax', np.max(self.ele_pos[:, 0])) - self.gdx = kwargs.pop('gdx', 0.01*(self.xmax - self.xmin)) + self.src_type = kwargs.pop("src_type", "gauss") + self.sigma = kwargs.pop("sigma", 1.0) + self.h = kwargs.pop("h", 1.0) + self.n_src_init = kwargs.pop("n_src_init", 1000) + self.lambd = kwargs.pop("lambd", 0.0) + self.R_init = kwargs.pop("R_init", 0.23) + self.ext_x = kwargs.pop("ext_x", 0.0) + self.xmin = kwargs.pop("xmin", np.min(self.ele_pos[:, 0])) + self.xmax = kwargs.pop("xmax", np.max(self.ele_pos[:, 0])) + self.gdx = kwargs.pop("gdx", 0.01 * (self.xmax - self.xmin)) if self.dim >= 2: - self.ext_y = kwargs.pop('ext_y', 0.0) - self.ymin = kwargs.pop('ymin', np.min(self.ele_pos[:, 1])) - self.ymax = kwargs.pop('ymax', np.max(self.ele_pos[:, 1])) - self.gdy = kwargs.pop('gdy', 0.01*(self.ymax - self.ymin)) + self.ext_y = kwargs.pop("ext_y", 0.0) + self.ymin = kwargs.pop("ymin", np.min(self.ele_pos[:, 1])) + self.ymax = kwargs.pop("ymax", np.max(self.ele_pos[:, 1])) + self.gdy = kwargs.pop("gdy", 0.01 * (self.ymax - self.ymin)) if self.dim == 3: - self.ext_z = kwargs.pop('ext_z', 0.0) - self.zmin = kwargs.pop('zmin', np.min(self.ele_pos[:, 2])) - self.zmax = kwargs.pop('zmax', np.max(self.ele_pos[:, 2])) - self.gdz = kwargs.pop('gdz', 0.01*(self.zmax - self.zmin)) + self.ext_z = kwargs.pop("ext_z", 0.0) + self.zmin = kwargs.pop("zmin", np.min(self.ele_pos[:, 2])) + self.zmax = kwargs.pop("zmax", np.max(self.ele_pos[:, 2])) + self.gdz = kwargs.pop("gdz", 0.01 * (self.zmax - self.zmin)) if kwargs: - raise TypeError('Invalid keyword arguments:', kwargs.keys()) + raise TypeError("Invalid keyword arguments:", kwargs.keys()) def method(self): """Actual sequence of methods called for KCSD @@ -120,10 +126,10 @@ def method(self): ---------- None """ - self.create_lookup() #Look up table - self.update_b_pot() #update kernel - self.update_b_src() #update crskernel - self.update_b_interp_pot() #update pot interp + self.create_lookup() # Look up table + self.update_b_pot() # update kernel + self.update_b_src() # update crskernel + self.update_b_interp_pot() # update pot interp def create_lookup(self, dist_table_density=20): """Creates a table for easy potential estimation from CSD. @@ -136,16 +142,14 @@ def create_lookup(self, dist_table_density=20): number of distance values at which potentials are computed. 
Default 100 """ - xs = np.logspace(0., np.log10(self.dist_max+1.), dist_table_density) - xs = xs - 1.0 #starting from 0 + xs = np.logspace(0.0, np.log10(self.dist_max + 1.0), dist_table_density) + xs = xs - 1.0 # starting from 0 dist_table = np.zeros(len(xs)) for i, pos in enumerate(xs): - dist_table[i] = self.forward_model(pos, - self.R, - self.h, - self.sigma, - self.basis) - self.interpolate_pot_at = interpolate.interp1d(xs, dist_table, kind='cubic') + dist_table[i] = self.forward_model( + pos, self.R, self.h, self.sigma, self.basis + ) + self.interpolate_pot_at = interpolate.interp1d(xs, dist_table, kind="cubic") def update_b_pot(self): """Updates the b_pot - array is (#_basis_sources, #_electrodes) @@ -159,7 +163,7 @@ def update_b_pot(self): None """ self.b_pot = self.interpolate_pot_at(self.src_ele_dists) - self.k_pot = np.dot(self.b_pot.T, self.b_pot) #K(x,x') Eq9,Jan2012 + self.k_pot = np.dot(self.b_pot.T, self.b_pot) # K(x,x') Eq9,Jan2012 self.k_pot /= self.n_src def update_b_src(self): @@ -173,7 +177,7 @@ def update_b_src(self): None """ self.b_src = self.basis(self.src_estm_dists, self.R).T - self.k_interp_cross = np.dot(self.b_src, self.b_pot) #K_t(x,y) Eq17 + self.k_interp_cross = np.dot(self.b_src, self.b_pot) # K_t(x,y) Eq17 self.k_interp_cross /= self.n_src def update_b_interp_pot(self): @@ -189,7 +193,7 @@ def update_b_interp_pot(self): self.k_interp_pot = np.dot(self.b_interp_pot, self.b_pot) self.k_interp_pot /= self.n_src - def values(self, estimate='CSD'): + def values(self, estimate="CSD"): """Computes the values of the quantity of interest Parameters ---------- @@ -201,19 +205,20 @@ def values(self, estimate='CSD'): estimation : np.array estimated quantity of shape (ngx, ngy, ngz, nt) """ - if estimate == 'CSD': #Maybe used for estimating the potentials also. + if estimate == "CSD": # Maybe used for estimating the potentials also. estimation_table = self.k_interp_cross - elif estimate == 'POT': + elif estimate == "POT": estimation_table = self.k_interp_pot else: - print('Invalid quantity to be measured, pass either CSD or POT') - k_inv = np.linalg.inv(self.k_pot + self.lambd * - np.identity(self.k_pot.shape[0])) + print("Invalid quantity to be measured, pass either CSD or POT") + k_inv = np.linalg.inv( + self.k_pot + self.lambd * np.identity(self.k_pot.shape[0]) + ) estimation = np.zeros((self.n_estm, self.n_time)) for t in range(self.n_time): beta = np.dot(k_inv, self.pots[:, t]) for i in range(self.n_ele): - estimation[:, t] += estimation_table[:, i] *beta[i] # C*(x) Eq 18 + estimation[:, t] += estimation_table[:, i] * beta[i] # C*(x) Eq 18 return self.process_estimate(estimation) def process_estimate(self, estimation): @@ -242,8 +247,9 @@ def update_R(self, R): R : float """ self.R = R - self.dist_max = max(np.max(self.src_ele_dists), - np.max(self.src_estm_dists)) + self.R + self.dist_max = ( + max(np.max(self.src_ele_dists), np.max(self.src_estm_dists)) + self.R + ) self.method() def update_lambda(self, lambd): @@ -270,34 +276,33 @@ def cross_validate(self, lambdas=None, Rs=None): R : post cross validation Lambda : post cross validation """ - if lambdas is None: #when None - print('No lambda given, using defaults') - lambdas = np.logspace(-2,-25,25,base=10.) 
#Default multiple lambda + if lambdas is None: # when None + print("No lambda given, using defaults") + lambdas = np.logspace(-2, -25, 25, base=10.0) # Default multiple lambda lambdas = np.hstack((lambdas, np.array((0.0)))) - elif lambdas.size == 1: #resize when one entry + elif lambdas.size == 1: # resize when one entry lambdas = lambdas.flatten() - if Rs is None: #when None - Rs = np.array((self.R)).flatten() #Default over one R value + if Rs is None: # when None + Rs = np.array((self.R)).flatten() # Default over one R value errs = np.zeros((Rs.size, lambdas.size)) index_generator = [] for ii in range(self.n_ele): idx_test = [ii] idx_train = list(range(self.n_ele)) - idx_train.remove(ii) #Leave one out + idx_train.remove(ii) # Leave one out index_generator.append((idx_train, idx_test)) - for R_idx,R in enumerate(Rs): #Iterate over R + for R_idx, R in enumerate(Rs): # Iterate over R self.update_R(R) - print('Cross validating R (all lambda) :', R) - for lambd_idx,lambd in enumerate(lambdas): #Iterate over lambdas - errs[R_idx, lambd_idx] = self.compute_cverror(lambd, - index_generator) - err_idx = np.where(errs==np.min(errs)) #Index of the least error - cv_R = Rs[err_idx[0]][0] #First occurance of the least error's + print("Cross validating R (all lambda) :", R) + for lambd_idx, lambd in enumerate(lambdas): # Iterate over lambdas + errs[R_idx, lambd_idx] = self.compute_cverror(lambd, index_generator) + err_idx = np.where(errs == np.min(errs)) # Index of the least error + cv_R = Rs[err_idx[0]][0] # First occurance of the least error's cv_lambda = lambdas[err_idx[1]][0] - self.cv_error = np.min(errs) #otherwise is None - self.update_R(cv_R) #Update solver + self.cv_error = np.min(errs) # otherwise is None + self.update_R(cv_R) # Update solver self.update_lambda(cv_lambda) - print('R, lambda :', cv_R, cv_lambda) + print("R, lambda :", cv_R, cv_lambda) return cv_R, cv_lambda def compute_cverror(self, lambd, index_generator): @@ -328,9 +333,11 @@ def compute_cverror(self, lambd, index_generator): err += np.linalg.norm(V_est - V_test) except np.linalg.LinAlgError: raise np.linalg.LinAlgError( - 'Encountered Singular Matrix Error: try changing ele_pos slightly') + "Encountered Singular Matrix Error: try changing ele_pos slightly" + ) return err + class KCSD1D(KCSD): """KCSD1D - The 1D variant for the Kernel Current Source Density method. This estimates the Current Source Density, for a given configuration of @@ -338,6 +345,7 @@ class KCSD1D(KCSD): electrodes (laminar probes). The method implented here is based on the original paper by Jan Potworowski et.al. 2012. """ + def __init__(self, ele_pos, pots, **kwargs): """Initialize KCSD1D Class. Parameters @@ -394,8 +402,8 @@ def estimate_at(self): ---------- None """ - nx = (self.xmax - self.xmin)/self.gdx - self.estm_x = np.mgrid[self.xmin:self.xmax:complex(0,nx)] + nx = (self.xmax - self.xmin) / self.gdx + self.estm_x = np.mgrid[self.xmin : self.xmax : complex(0, nx)] self.n_estm = self.estm_x.size self.ngx = self.estm_x.shape[0] @@ -417,12 +425,12 @@ def place_basis(self): try: self.basis = basis.basis_1D[source_type] except KeyError: - raise KeyError('Invalid source_type for basis! available are:', - basis.basis_1D.keys()) - (self.src_x, self.R) = utils.distribute_srcs_1D(self.estm_x, - self.n_src_init, - self.ext_x, - self.R_init ) + raise KeyError( + "Invalid source_type for basis! 
available are:", basis.basis_1D.keys() + ) + (self.src_x, self.R) = utils.distribute_srcs_1D( + self.estm_x, self.n_src_init, self.ext_x, self.R_init + ) self.n_src = self.src_x.size self.nsx = self.src_x.shape @@ -436,9 +444,11 @@ def create_src_dist_tables(self): src_loc = src_loc.reshape((len(src_loc), 1)) est_loc = np.array((self.estm_x.ravel())) est_loc = est_loc.reshape((len(est_loc), 1)) - self.src_ele_dists = distance.cdist(src_loc, self.ele_pos, 'euclidean') - self.src_estm_dists = distance.cdist(src_loc, est_loc, 'euclidean') - self.dist_max = max(np.max(self.src_ele_dists), np.max(self.src_estm_dists)) + self.R + self.src_ele_dists = distance.cdist(src_loc, self.ele_pos, "euclidean") + self.src_estm_dists = distance.cdist(src_loc, est_loc, "euclidean") + self.dist_max = ( + max(np.max(self.src_ele_dists), np.max(self.src_estm_dists)) + self.R + ) def forward_model(self, x, R, h, sigma, src_type): """FWD model functions @@ -456,10 +466,8 @@ def forward_model(self, x, R, h, sigma, src_type): pot : float value of potential at specified distance from the source """ - pot, err = integrate.quad(self.int_pot_1D, - -R, R, - args=(x, R, h, src_type)) - pot *= 1./(2.0*sigma) + pot, err = integrate.quad(self.int_pot_1D, -R, R, args=(x, R, h, src_type)) + pot *= 1.0 / (2.0 * sigma) return pot def int_pot_1D(self, xp, x, R, h, basis_func): @@ -487,10 +495,11 @@ def int_pot_1D(self, xp, x, R, h, basis_func): ------- pot : float """ - m = np.sqrt((x-xp)**2 + h**2) - abs(x-xp) - m *= basis_func(abs(xp), R) #xp is the distance + m = np.sqrt((x - xp) ** 2 + h**2) - abs(x - xp) + m *= basis_func(abs(xp), R) # xp is the distance return m + class KCSD2D(KCSD): """KCSD2D - The 2D variant for the Kernel Current Source Density method. This estimates the Current Source Density, for a given configuration of @@ -498,6 +507,7 @@ class KCSD2D(KCSD): electrodes. The method implented here is based on the original paper by Jan Potworowski et.al. 2012. """ + def __init__(self, ele_pos, pots, **kwargs): """Initialize KCSD2D Class. Parameters @@ -557,10 +567,12 @@ def estimate_at(self): ---------- None """ - nx = (self.xmax - self.xmin)/self.gdx - ny = (self.ymax - self.ymin)/self.gdy - self.estm_x, self.estm_y = np.mgrid[self.xmin:self.xmax:complex(0,nx), - self.ymin:self.ymax:complex(0,ny)] + nx = (self.xmax - self.xmin) / self.gdx + ny = (self.ymax - self.ymin) / self.gdy + self.estm_x, self.estm_y = np.mgrid[ + self.xmin : self.xmax : complex(0, nx), + self.ymin : self.ymax : complex(0, ny), + ] self.n_estm = self.estm_x.size self.ngx, self.ngy = self.estm_x.shape @@ -582,14 +594,17 @@ def place_basis(self): try: self.basis = basis.basis_2D[source_type] except KeyError: - raise KeyError('Invalid source_type for basis! available are:', - basis.basis_2D.keys()) - (self.src_x, self.src_y, self.R) = utils.distribute_srcs_2D(self.estm_x, - self.estm_y, - self.n_src_init, - self.ext_x, - self.ext_y, - self.R_init ) + raise KeyError( + "Invalid source_type for basis! 
available are:", basis.basis_2D.keys() + ) + (self.src_x, self.src_y, self.R) = utils.distribute_srcs_2D( + self.estm_x, + self.estm_y, + self.n_src_init, + self.ext_x, + self.ext_y, + self.R_init, + ) self.n_src = self.src_x.size self.nsx, self.nsy = self.src_x.shape @@ -601,9 +616,11 @@ def create_src_dist_tables(self): """ src_loc = np.array((self.src_x.ravel(), self.src_y.ravel())) est_loc = np.array((self.estm_x.ravel(), self.estm_y.ravel())) - self.src_ele_dists = distance.cdist(src_loc.T, self.ele_pos, 'euclidean') - self.src_estm_dists = distance.cdist(src_loc.T, est_loc.T, 'euclidean') - self.dist_max = max(np.max(self.src_ele_dists), np.max(self.src_estm_dists)) + self.R + self.src_ele_dists = distance.cdist(src_loc.T, self.ele_pos, "euclidean") + self.src_estm_dists = distance.cdist(src_loc.T, est_loc.T, "euclidean") + self.dist_max = ( + max(np.max(self.src_ele_dists), np.max(self.src_estm_dists)) + self.R + ) def forward_model(self, x, R, h, sigma, src_type): """FWD model functions @@ -621,12 +638,10 @@ def forward_model(self, x, R, h, sigma, src_type): pot : float value of potential at specified distance from the source """ - pot, err = integrate.dblquad(self.int_pot_2D, - -R, R, - lambda x: -R, - lambda x: R, - args=(x, R, h, src_type)) - pot *= 1./(2.0*np.pi*sigma) #Potential basis functions bi_x_y + pot, err = integrate.dblquad( + self.int_pot_2D, -R, R, lambda x: -R, lambda x: R, args=(x, R, h, src_type) + ) + pot *= 1.0 / (2.0 * np.pi * sigma) # Potential basis functions bi_x_y return pot def int_pot_2D(self, xp, yp, x, R, h, basis_func): @@ -651,13 +666,14 @@ def int_pot_2D(self, xp, yp, x, R, h, basis_func): ------- pot : float """ - y = ((x-xp)**2 + yp**2)**(0.5) + y = ((x - xp) ** 2 + yp**2) ** (0.5) if y < 0.00001: y = 0.00001 dist = np.sqrt(xp**2 + yp**2) - pot = np.arcsinh(h/y)*basis_func(dist, R) + pot = np.arcsinh(h / y) * basis_func(dist, R) return pot + class MoIKCSD(KCSD2D): """MoIKCSD - CSD while including the forward modeling effects of saline. @@ -667,6 +683,7 @@ class MoIKCSD(KCSD2D): The method implented here is based on kCSD method by Jan Potworowski et.al. 2012, which was extended in Ness, Chintaluri 2015 for MEA. """ + def __init__(self, ele_pos, pots, **kwargs): """Initialize MoIKCSD Class. Parameters @@ -714,11 +731,11 @@ def __init__(self, ele_pos, pots, **kwargs): Number of interations in method of images. 
Default is 20 """ - self.MoI_iters = kwargs.pop('MoI_iters', 20) - self.sigma_S = kwargs.pop('sigma_S', 5.0) - self.sigma = kwargs.pop('sigma', 1.0) + self.MoI_iters = kwargs.pop("MoI_iters", 20) + self.sigma_S = kwargs.pop("sigma_S", 5.0) + self.sigma = kwargs.pop("sigma", 1.0) W_TS = (self.sigma - self.sigma_S) / (self.sigma + self.sigma_S) - self.iters = np.arange(self.MoI_iters) + 1 #Eq 6, Ness (2015) + self.iters = np.arange(self.MoI_iters) + 1 # Eq 6, Ness (2015) self.iter_factor = W_TS**self.iters super(MoIKCSD, self).__init__(ele_pos, pots, **kwargs) @@ -738,11 +755,15 @@ def forward_model(self, x, R, h, sigma, src_type): pot : float value of potential at specified distance from the source """ - pot, err = integrate.dblquad(self.int_pot_2D_moi, -R, R, - lambda x: -R, - lambda x: R, - args=(x, R, h, src_type)) - pot *= 1./(2.0*np.pi*sigma) + pot, err = integrate.dblquad( + self.int_pot_2D_moi, + -R, + R, + lambda x: -R, + lambda x: R, + args=(x, R, h, src_type), + ) + pot *= 1.0 / (2.0 * np.pi * sigma) return pot def int_pot_2D_moi(self, xp, yp, x, R, h, basis_func): @@ -768,15 +789,18 @@ def int_pot_2D_moi(self, xp, yp, x, R, h, basis_func): ------- pot : float """ - L = ((x-xp)**2 + yp**2)**(0.5) + L = ((x - xp) ** 2 + yp**2) ** (0.5) if L < 0.00001: L = 0.00001 - correction = np.arcsinh((h-(2*h*self.iters))/L) + np.arcsinh((h+(2*h*self.iters))/L) - pot = np.arcsinh(h/L) + np.sum(self.iter_factor*correction) + correction = np.arcsinh((h - (2 * h * self.iters)) / L) + np.arcsinh( + (h + (2 * h * self.iters)) / L + ) + pot = np.arcsinh(h / L) + np.sum(self.iter_factor * correction) dist = np.sqrt(xp**2 + yp**2) - pot *= basis_func(dist, R) #Eq 20, Ness et.al. + pot *= basis_func(dist, R) # Eq 20, Ness et.al. return pot + class KCSD3D(KCSD): """KCSD3D - The 3D variant for the Kernel Current Source Density method. This estimates the Current Source Density, for a given configuration of @@ -784,6 +808,7 @@ class KCSD3D(KCSD): electrodes. The method implented here is based on the original paper by Jan Potworowski et.al. 2012. """ + def __init__(self, ele_pos, pots, **kwargs): """Initialize KCSD3D Class. Parameters @@ -846,12 +871,14 @@ def estimate_at(self): ---------- None """ - nx = (self.xmax - self.xmin)/self.gdx - ny = (self.ymax - self.ymin)/self.gdy - nz = (self.zmax - self.zmin)/self.gdz - self.estm_x, self.estm_y, self.estm_z = np.mgrid[self.xmin:self.xmax:complex(0,nx), - self.ymin:self.ymax:complex(0,ny), - self.zmin:self.zmax:complex(0,nz)] + nx = (self.xmax - self.xmin) / self.gdx + ny = (self.ymax - self.ymin) / self.gdy + nz = (self.zmax - self.zmin) / self.gdz + self.estm_x, self.estm_y, self.estm_z = np.mgrid[ + self.xmin : self.xmax : complex(0, nx), + self.ymin : self.ymax : complex(0, ny), + self.zmin : self.zmax : complex(0, nz), + ] self.n_estm = self.estm_x.size self.ngx, self.ngy, self.ngz = self.estm_x.shape @@ -873,16 +900,19 @@ def place_basis(self): try: self.basis = basis.basis_3D[source_type] except KeyError: - raise KeyError('Invalid source_type for basis! available are:', - basis.basis_3D.keys()) - (self.src_x, self.src_y, self.src_z, self.R) = utils.distribute_srcs_3D(self.estm_x, - self.estm_y, - self.estm_z, - self.n_src_init, - self.ext_x, - self.ext_y, - self.ext_z, - self.R_init) + raise KeyError( + "Invalid source_type for basis! 
available are:", basis.basis_3D.keys() + ) + (self.src_x, self.src_y, self.src_z, self.R) = utils.distribute_srcs_3D( + self.estm_x, + self.estm_y, + self.estm_z, + self.n_src_init, + self.ext_x, + self.ext_y, + self.ext_z, + self.R_init, + ) self.n_src = self.src_x.size self.nsx, self.nsy, self.nsz = self.src_x.shape @@ -893,15 +923,15 @@ def create_src_dist_tables(self): ---------- None """ - src_loc = np.array((self.src_x.ravel(), - self.src_y.ravel(), - self.src_z.ravel())) - est_loc = np.array((self.estm_x.ravel(), - self.estm_y.ravel(), - self.estm_z.ravel())) - self.src_ele_dists = distance.cdist(src_loc.T, self.ele_pos, 'euclidean') - self.src_estm_dists = distance.cdist(src_loc.T, est_loc.T, 'euclidean') - self.dist_max = max(np.max(self.src_ele_dists), np.max(self.src_estm_dists)) + self.R + src_loc = np.array((self.src_x.ravel(), self.src_y.ravel(), self.src_z.ravel())) + est_loc = np.array( + (self.estm_x.ravel(), self.estm_y.ravel(), self.estm_z.ravel()) + ) + self.src_ele_dists = distance.cdist(src_loc.T, self.ele_pos, "euclidean") + self.src_estm_dists = distance.cdist(src_loc.T, est_loc.T, "euclidean") + self.dist_max = ( + max(np.max(self.src_ele_dists), np.max(self.src_estm_dists)) + self.R + ) def forward_model(self, x, R, h, sigma, src_type): """FWD model functions @@ -921,45 +951,58 @@ def forward_model(self, x, R, h, sigma, src_type): value of potential at specified distance from the source """ if src_type.__name__ == "gauss_3D": - if x == 0: x=0.0001 - pot = special.erf(x/(np.sqrt(2)*R/3.0)) / x + if x == 0: + x = 0.0001 + pot = special.erf(x / (np.sqrt(2) * R / 3.0)) / x elif src_type.__name__ == "gauss_lim_3D": - if x == 0: x=0.0001 - d = R/3. + if x == 0: + x = 0.0001 + d = R / 3.0 if x < R: - e = np.exp(-(x/ (np.sqrt(2)*d))**2) - erf = special.erf(x / (np.sqrt(2)*d)) - pot = 4* np.pi * ( (d**2)*(e - np.exp(-4.5)) + - (1/x)*((np.sqrt(np.pi/2)*(d**3)*erf) - x*(d**2)*e)) + e = np.exp(-((x / (np.sqrt(2) * d)) ** 2)) + erf = special.erf(x / (np.sqrt(2) * d)) + pot = ( + 4 + * np.pi + * ( + (d**2) * (e - np.exp(-4.5)) + + (1 / x) + * ((np.sqrt(np.pi / 2) * (d**3) * erf) - x * (d**2) * e) + ) + ) else: - pot = 15.28828*(d)**3 / x - pot /= (np.sqrt(2*np.pi)*d)**3 + pot = 15.28828 * (d) ** 3 / x + pot /= (np.sqrt(2 * np.pi) * d) ** 3 elif src_type.__name__ == "step_3D": - Q = 4.*np.pi*(R**3)/3. 
+ Q = 4.0 * np.pi * (R**3) / 3.0 if x < R: - pot = (Q * (3 - (x/R)**2)) / (2.*R) + pot = (Q * (3 - (x / R) ** 2)) / (2.0 * R) else: pot = Q / x - pot *= 3/(4*np.pi*R**3) + pot *= 3 / (4 * np.pi * R**3) else: if skmonaco_available: - pot, err = mcmiser(self.int_pot_3D_mc, - npoints=1e5, - xl=[-R, -R, -R], - xu=[R, R, R], - seed=42, - nprocs=num_cores, - args=(x, R, h, src_type)) + pot, err = mcmiser( + self.int_pot_3D_mc, + npoints=1e5, + xl=[-R, -R, -R], + xu=[R, R, R], + seed=42, + nprocs=num_cores, + args=(x, R, h, src_type), + ) else: - pot, err = integrate.tplquad(self.int_pot_3D, - -R, - R, - lambda x: -R, - lambda x: R, - lambda x, y: -R, - lambda x, y: R, - args=(x, R, h, src_type)) - pot *= 1./(4.0*np.pi*sigma) + pot, err = integrate.tplquad( + self.int_pot_3D, + -R, + R, + lambda x: -R, + lambda x: R, + lambda x, y: -R, + lambda x, y: R, + args=(x, R, h, src_type), + ) + pot *= 1.0 / (4.0 * np.pi * sigma) return pot def int_pot_3D(self, xp, yp, zp, x, R, h, basis_func): @@ -984,11 +1027,11 @@ def int_pot_3D(self, xp, yp, zp, x, R, h, basis_func): ------- pot : float """ - y = ((x-xp)**2 + yp**2 + zp**2)**0.5 + y = ((x - xp) ** 2 + yp**2 + zp**2) ** 0.5 if y < 0.00001: y = 0.00001 dist = np.sqrt(xp**2 + yp**2 + zp**2) - pot = 1.0/y + pot = 1.0 / y pot *= basis_func(dist, R) return pot @@ -1019,42 +1062,63 @@ def int_pot_3D_mc(self, xyz, x, R, h, basis_func): xp, yp, zp = xyz return self.int_pot_3D(xp, yp, zp, x, R, h, basis_func) -if __name__ == '__main__': - print('Checking 1D') - ele_pos = np.array(([-0.1],[0], [0.5], [1.], [1.4], [2.], [2.3])) + +if __name__ == "__main__": + print("Checking 1D") + ele_pos = np.array(([-0.1], [0], [0.5], [1.0], [1.4], [2.0], [2.3])) pots = np.array([[-1], [-1], [-1], [0], [0], [1], [-1.5]]) - k = KCSD1D(ele_pos, pots, - gdx=0.01, n_src_init=300, - ext_x=0.0, src_type='gauss') + k = KCSD1D(ele_pos, pots, gdx=0.01, n_src_init=300, ext_x=0.0, src_type="gauss") k.cross_validate() print(k.values()) - print('Checking 2D') - ele_pos = np.array([[-0.2, -0.2],[0, 0], [0, 1], [1, 0], [1,1], [0.5, 0.5], - [1.2, 1.2]]) + print("Checking 2D") + ele_pos = np.array( + [[-0.2, -0.2], [0, 0], [0, 1], [1, 0], [1, 1], [0.5, 0.5], [1.2, 1.2]] + ) pots = np.array([[-1], [-1], [-1], [0], [0], [1], [-1.5]]) - k = KCSD2D(ele_pos, pots, - gdx=0.05, gdy=0.05, - xmin=-2.0, xmax=2.0, - ymin=-2.0, ymax=2.0, - src_type='gauss') + k = KCSD2D( + ele_pos, + pots, + gdx=0.05, + gdy=0.05, + xmin=-2.0, + xmax=2.0, + ymin=-2.0, + ymax=2.0, + src_type="gauss", + ) k.cross_validate() print(k.values()) - print('Checking MoIKCSD') - k = MoIKCSD(ele_pos, pots, - gdx=0.05, gdy=0.05, - xmin=-2.0, xmax=2.0, - ymin=-2.0, ymax= 2.0) + print("Checking MoIKCSD") + k = MoIKCSD( + ele_pos, pots, gdx=0.05, gdy=0.05, xmin=-2.0, xmax=2.0, ymin=-2.0, ymax=2.0 + ) k.cross_validate() - print('Checking KCSD3D') - ele_pos = np.array([(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), - (0, 1, 1), (1, 1, 0), (1, 0, 1), (1, 1, 1), - (0.5, 0.5, 0.5)]) + print("Checking KCSD3D") + ele_pos = np.array( + [ + (0, 0, 0), + (0, 0, 1), + (0, 1, 0), + (1, 0, 0), + (0, 1, 1), + (1, 1, 0), + (1, 0, 1), + (1, 1, 1), + (0.5, 0.5, 0.5), + ] + ) pots = np.array([[-0.5], [0], [-0.5], [0], [0], [0.2], [0], [0], [1]]) - k = KCSD3D(ele_pos, pots, - gdx=0.02, gdy=0.02, gdz=0.02, - n_src_init=1000, src_type='gauss_lim') + k = KCSD3D( + ele_pos, + pots, + gdx=0.02, + gdy=0.02, + gdz=0.02, + n_src_init=1000, + src_type="gauss_lim", + ) k.cross_validate() print(k.values()) diff --git 
a/elephant/current_source_density_src/basis_functions.py b/elephant/current_source_density_src/basis_functions.py index 7c5892122..b136db68b 100644 --- a/elephant/current_source_density_src/basis_functions.py +++ b/elephant/current_source_density_src/basis_functions.py @@ -11,10 +11,12 @@ Laboratory of Neuroinformatics, Nencki Institute of Experimental Biology, Warsaw. """ + from __future__ import division import numpy as np + def gauss(d, stdev, dim): """Gaussian function Parameters @@ -30,9 +32,10 @@ def gauss(d, stdev, dim): Z : floats or np.arrays function evaluated """ - Z = np.exp(-(d**2) / (2* stdev**2) ) / (np.sqrt(2*np.pi)*stdev)**dim + Z = np.exp(-(d**2) / (2 * stdev**2)) / (np.sqrt(2 * np.pi) * stdev) ** dim return Z + def step_1D(d, R): """Returns normalized 1D step function. Parameters @@ -45,10 +48,11 @@ def step_1D(d, R): ------- s : Value of the function (d <= R) / R """ - s = (d <= R) - s = s / R #normalize with width + s = d <= R + s = s / R # normalize with width return s + def gauss_1D(d, three_stdev): """Returns normalized gaussian 2D scale function Parameters @@ -61,10 +65,11 @@ def gauss_1D(d, three_stdev): ------- Z : (three_std/3)*(1/2*pi)*(exp(-0.5)*stddev**(-2) *(d**2)) """ - stdev = three_stdev/3.0 + stdev = three_stdev / 3.0 Z = gauss(d, stdev, 1) return Z + def gauss_lim_1D(d, three_stdev): """Returns gausian 2D function cut off after 3 standard deviations. Parameters @@ -79,9 +84,10 @@ def gauss_lim_1D(d, three_stdev): cut off = three_stdev """ Z = gauss_1D(d, three_stdev) - Z *= (d < three_stdev) + Z *= d < three_stdev return Z + def step_2D(d, R): """Returns normalized 2D step function. Parameters @@ -95,9 +101,10 @@ def step_2D(d, R): ------- s : step function """ - s = (d <= R) / (np.pi*(R**2)) + s = (d <= R) / (np.pi * (R**2)) return s + def gauss_2D(d, three_stdev): """Returns normalized gaussian 2D scale function Parameters @@ -111,10 +118,11 @@ def gauss_2D(d, three_stdev): Z : function Normalized gaussian 2D function """ - stdev = three_stdev/3.0 + stdev = three_stdev / 3.0 Z = gauss(d, stdev, 2) return Z + def gauss_lim_2D(d, three_stdev): """Returns gausian 2D function cut off after 3 standard deviations. Parameters @@ -128,9 +136,10 @@ def gauss_lim_2D(d, three_stdev): Z : function Normalized gaussian 2D function cut off after three_stdev """ - Z = (d <= three_stdev)*gauss_2D(d, three_stdev) + Z = (d <= three_stdev) * gauss_2D(d, three_stdev) return Z + def gauss_3D(d, three_stdev): """Returns normalized gaussian 3D scale function Parameters @@ -144,10 +153,11 @@ def gauss_3D(d, three_stdev): Z : funtion Normalized gaussian 3D function """ - stdev = three_stdev/3.0 + stdev = three_stdev / 3.0 Z = gauss(d, stdev, 3) return Z + def gauss_lim_3D(d, three_stdev): """Returns normalized gaussian 3D scale function cut off after 3stdev Parameters @@ -165,6 +175,7 @@ def gauss_lim_3D(d, three_stdev): Z = Z * (d < (three_stdev)) return Z + def step_3D(d, R): """Returns normalized 3D step function. 
Parameters @@ -178,9 +189,10 @@ def step_3D(d, R): s : step function in 3D """ - s = 3/(4*np.pi*R**3)*(d <= R) + s = 3 / (4 * np.pi * R**3) * (d <= R) return s + basis_1D = { "step": step_1D, "gauss": gauss_1D, diff --git a/elephant/current_source_density_src/icsd.py b/elephant/current_source_density_src/icsd.py index bb3376ab3..93d053844 100644 --- a/elephant/current_source_density_src/icsd.py +++ b/elephant/current_source_density_src/icsd.py @@ -41,7 +41,8 @@ class CSD(object): """Base iCSD class""" - def __init__(self, lfp, f_type='gaussian', f_order=(3, 1)): + + def __init__(self, lfp, f_type="gaussian", f_order=(3, 1)): """Initialize parent class iCSD Parameters @@ -53,13 +54,15 @@ def __init__(self, lfp, f_type='gaussian', f_order=(3, 1)): f_order : list settings for spatial filter, arg passed to filter design function """ - self.name = 'CSD estimate parent class' + self.name = "CSD estimate parent class" self.lfp = lfp self.f_matrix = np.eye(lfp.shape[0]) * pq.m**3 / pq.S self.f_type = f_type self.f_order = f_order - def get_csd(self, ): + def get_csd( + self, + ): """ Perform the CSD estimate from the LFP and forward matrix F, i.e as CSD=F**-1*LFP @@ -76,7 +79,7 @@ def get_csd(self, ): return csd * (self.f_matrix.units**-1 * self.lfp.units).simplified - def filter_csd(self, csd, filterfunction='convolve'): + def filter_csd(self, csd, filterfunction="convolve"): """ Spatial filtering of the CSD estimate, using an N-point filter @@ -88,59 +91,65 @@ def filter_csd(self, csd, filterfunction='convolve'): 'filtfilt' or 'convolve'. Apply spatial filter using scipy.signal.filtfilt or scipy.signal.convolve. """ - if self.f_type == 'gaussian': + if self.f_type == "gaussian": try: - assert(len(self.f_order) == 2) + assert len(self.f_order) == 2 except AssertionError as ae: - raise ae('filter order f_order must be a tuple of length 2') + raise ae("filter order f_order must be a tuple of length 2") else: try: - assert (self.f_order > 0 and isinstance(self.f_order, int)) + assert self.f_order > 0 and isinstance(self.f_order, int) except AssertionError as ae: - raise ae('Filter order must be int > 0!') + raise ae("Filter order must be int > 0!") try: - assert (filterfunction in ['filtfilt', 'convolve']) + assert filterfunction in ["filtfilt", "convolve"] except AssertionError as ae: - raise ae("{} not equal to 'filtfilt' or \ - 'convolve'".format(filterfunction)) + raise ae( + "{} not equal to 'filtfilt' or \ + 'convolve'".format(filterfunction) + ) - if self.f_type == 'boxcar': + if self.f_type == "boxcar": num = ss.windows.boxcar(self.f_order) denom = np.array([num.sum()]) - elif self.f_type == 'hamming': + elif self.f_type == "hamming": num = ss.windows.hamming(self.f_order) denom = np.array([num.sum()]) - elif self.f_type == 'triangular': + elif self.f_type == "triangular": num = ss.windows.triang(self.f_order) denom = np.array([num.sum()]) - elif self.f_type == 'gaussian': + elif self.f_type == "gaussian": num = ss.windows.gaussian(self.f_order[0], self.f_order[1]) denom = np.array([num.sum()]) - elif self.f_type == 'identity': - num = np.array([1.]) - denom = np.array([1.]) + elif self.f_type == "identity": + num = np.array([1.0]) + denom = np.array([1.0]) else: - print('%s Wrong filter type!' % self.f_type) + print("%s Wrong filter type!" 
% self.f_type) raise - num_string = '[ ' + num_string = "[ " for i in num: - num_string = num_string + '%.3f ' % i - num_string = num_string + ']' - denom_string = '[ ' + num_string = num_string + "%.3f " % i + num_string = num_string + "]" + denom_string = "[ " for i in denom: - denom_string = denom_string + '%.3f ' % i - denom_string = denom_string + ']' + denom_string = denom_string + "%.3f " % i + denom_string = denom_string + "]" - print(('discrete filter coefficients: \nb = {}, \ - \na = {}'.format(num_string, denom_string))) + print( + ( + "discrete filter coefficients: \nb = {}, \ + \na = {}".format(num_string, denom_string) + ) + ) - if filterfunction == 'filtfilt': + if filterfunction == "filtfilt": return ss.filtfilt(num, denom, csd, axis=0) * csd.units - elif filterfunction == 'convolve': + elif filterfunction == "convolve": csdf = csd / csd.units for i in range(csdf.shape[1]): - csdf[:, i] = ss.convolve(csdf[:, i], num / denom.sum(), 'same') + csdf[:, i] = ss.convolve(csdf[:, i], num / denom.sum(), "same") return csdf * csd.units @@ -179,9 +188,9 @@ def __init__(self, lfp, coord_electrode, **kwargs): diff_diff_coord = np.diff(np.diff(coord_electrode)).magnitude zeros_ddc = np.zeros_like(diff_diff_coord) try: - assert(np.all(np.isclose(diff_diff_coord, zeros_ddc, atol=1e-12))) + assert np.all(np.isclose(diff_diff_coord, zeros_ddc, atol=1e-12)) except AssertionError as ae: - print('coord_electrode not monotonously varying') + print("coord_electrode not monotonously varying") raise ae if self.vaknin_el: @@ -191,13 +200,13 @@ def __init__(self, lfp, coord_electrode, **kwargs): else: self.lfp = np.empty((lfp.shape[0] + 2, lfp.shape[1])) self.lfp[0,] = lfp[0,] - self.lfp[1:-1, ] = lfp + self.lfp[1:-1,] = lfp self.lfp[-1,] = lfp[-1,] self.lfp = self.lfp * lfp.units else: self.lfp = lfp - self.name = 'Standard CSD method' + self.name = "Standard CSD method" self.coord_electrode = coord_electrode self.f_inv_matrix = self.get_f_inv_matrix() @@ -209,12 +218,12 @@ def parameters(self, **kwargs): **kwargs Same as those passed to initialize the Class """ - self.sigma = kwargs.pop('sigma', 0.3 * pq.S / pq.m) - self.vaknin_el = kwargs.pop('vaknin_el', True) - self.f_type = kwargs.pop('f_type', 'gaussian') - self.f_order = kwargs.pop('f_order', (3, 1)) + self.sigma = kwargs.pop("sigma", 0.3 * pq.S / pq.m) + self.vaknin_el = kwargs.pop("vaknin_el", True) + self.f_type = kwargs.pop("f_type", "gaussian") + self.f_order = kwargs.pop("f_order", (3, 1)) if kwargs: - raise TypeError('Invalid keyword arguments:', kwargs.keys()) + raise TypeError("Invalid keyword arguments:", kwargs.keys()) def get_f_inv_matrix(self): """Calculate the inverse F-matrix for the standard CSD method""" @@ -223,7 +232,7 @@ def get_f_inv_matrix(self): # Inner matrix elements is just the discrete laplacian coefficients for j in range(1, f_inv.shape[0] - 1): - f_inv[j, j - 1: j + 2] = np.array([1., -2., 1.]) + f_inv[j, j - 1 : j + 2] = np.array([1.0, -2.0, 1.0]) return f_inv * -self.sigma / h_val def get_csd(self): @@ -235,7 +244,7 @@ def get_csd(self): csd : np.ndarray * quantity.Quantity Array with the csd estimate """ - csd = np.dot(self.f_inv_matrix, self.lfp)[1:-1, ] + csd = np.dot(self.f_inv_matrix, self.lfp)[1:-1,] # `np.dot()` does not return correct units, so the units of `csd` must # be assigned manually csd_units = (self.f_inv_matrix.units * self.lfp.units).simplified @@ -248,6 +257,7 @@ class DeltaiCSD(CSD): """ delta-iCSD method """ + def __init__(self, lfp, coord_electrode, **kwargs): """ Initialize the 
delta-iCSD method class object @@ -280,37 +290,41 @@ def __init__(self, lfp, coord_electrode, **kwargs): CSD.__init__(self, lfp, self.f_type, self.f_order) try: # Should the class not take care of this?! - assert(self.diam.units == coord_electrode.units) + assert self.diam.units == coord_electrode.units except AssertionError as ae: - print('units of coord_electrode ({}) and diam ({}) differ' - .format(coord_electrode.units, self.diam.units)) + print( + "units of coord_electrode ({}) and diam ({}) differ".format( + coord_electrode.units, self.diam.units + ) + ) raise ae try: - assert(np.all(np.diff(coord_electrode) > 0)) + assert np.all(np.diff(coord_electrode) > 0) except AssertionError as ae: - print('values of coord_electrode not continously increasing') + print("values of coord_electrode not continuously increasing") raise ae try: - assert(self.diam.size == 1 or self.diam.size == coord_electrode.size) + assert self.diam.size == 1 or self.diam.size == coord_electrode.size if self.diam.size == coord_electrode.size: - assert(np.all(self.diam > 0 * self.diam.units)) + assert np.all(self.diam > 0 * self.diam.units) else: - assert(self.diam > 0 * self.diam.units) + assert self.diam > 0 * self.diam.units except AssertionError as ae: - print('diam must be positive scalar or of same shape \ - as coord_electrode') + print( + "diam must be positive scalar or of same shape \ + as coord_electrode" + ) raise ae if self.diam.size == 1: self.diam = np.ones(coord_electrode.size) * self.diam - self.name = 'delta-iCSD method' + self.name = "delta-iCSD method" self.coord_electrode = coord_electrode # initialize F- and iCSD-matrices - self.f_matrix = np.empty((self.coord_electrode.size, - self.coord_electrode.size)) + self.f_matrix = np.empty((self.coord_electrode.size, self.coord_electrode.size)) self.f_matrix = self.get_f_matrix() def parameters(self, **kwargs): @@ -320,38 +334,44 @@ def parameters(self, **kwargs): **kwargs Same as those passed to initialize the Class """ - self.diam = kwargs.pop('diam', 500E-6 * pq.m) - self.sigma = kwargs.pop('sigma', 0.3 * pq.S / pq.m) - self.sigma_top = kwargs.pop('sigma_top', 0.3 * pq.S / pq.m) - self.f_type = kwargs.pop('f_type', 'gaussian') - self.f_order = kwargs.pop('f_order', (3, 1)) + self.diam = kwargs.pop("diam", 500e-6 * pq.m) + self.sigma = kwargs.pop("sigma", 0.3 * pq.S / pq.m) + self.sigma_top = kwargs.pop("sigma_top", 0.3 * pq.S / pq.m) + self.f_type = kwargs.pop("f_type", "gaussian") + self.f_order = kwargs.pop("f_order", (3, 1)) if kwargs: - raise TypeError('Invalid keyword arguments:', kwargs.keys()) + raise TypeError("Invalid keyword arguments:", kwargs.keys()) def get_f_matrix(self): """Calculate the F-matrix""" - f_matrix = np.empty((self.coord_electrode.size, - self.coord_electrode.size)) * self.coord_electrode.units + f_matrix = ( + np.empty((self.coord_electrode.size, self.coord_electrode.size)) + * self.coord_electrode.units + ) for j in range(self.coord_electrode.size): for i in range(self.coord_electrode.size): - f_matrix[j, i] = ((np.sqrt((self.coord_electrode[j] - - self.coord_electrode[i])**2 + - (self.diam[j] / 2)**2) - abs(self.coord_electrode[j] - - self.coord_electrode[i])) + - (self.sigma - self.sigma_top) / (self.sigma + - self.sigma_top) * - (np.sqrt((self.coord_electrode[j] + - self.coord_electrode[i])**2 + (self.diam[j] / 2)**2)- - abs(self.coord_electrode[j] + self.coord_electrode[i]))) - - f_matrix /= (2 * self.sigma) + f_matrix[j, i] = ( + np.sqrt( + (self.coord_electrode[j] - self.coord_electrode[i]) ** 2 + + (self.diam[j] / 2)
** 2 + ) + - abs(self.coord_electrode[j] - self.coord_electrode[i]) + ) + (self.sigma - self.sigma_top) / (self.sigma + self.sigma_top) * ( + np.sqrt( + (self.coord_electrode[j] + self.coord_electrode[i]) ** 2 + + (self.diam[j] / 2) ** 2 + ) + - abs(self.coord_electrode[j] + self.coord_electrode[i]) + ) + + f_matrix /= 2 * self.sigma return f_matrix class StepiCSD(CSD): """step-iCSD method""" - def __init__(self, lfp, coord_electrode, **kwargs): + def __init__(self, lfp, coord_electrode, **kwargs): """ Initializing step-iCSD method class object @@ -389,40 +409,45 @@ def __init__(self, lfp, coord_electrode, **kwargs): CSD.__init__(self, lfp, self.f_type, self.f_order) try: # Should the class not take care of this? - assert(self.diam.units == coord_electrode.units) + assert self.diam.units == coord_electrode.units except AssertionError as ae: - print('units of coord_electrode ({}) and diam ({}) differ' - .format(coord_electrode.units, self.diam.units)) + print( + "units of coord_electrode ({}) and diam ({}) differ".format( + coord_electrode.units, self.diam.units + ) + ) raise ae try: - assert(np.all(np.diff(coord_electrode) > 0)) + assert np.all(np.diff(coord_electrode) > 0) except AssertionError as ae: - print('values of coord_electrode not continously increasing') + print("values of coord_electrode not continuously increasing") raise ae try: - assert(self.diam.size == 1 or self.diam.size == coord_electrode.size) + assert self.diam.size == 1 or self.diam.size == coord_electrode.size if self.diam.size == coord_electrode.size: - assert(np.all(self.diam > 0 * self.diam.units)) + assert np.all(self.diam > 0 * self.diam.units) else: - assert(self.diam > 0 * self.diam.units) + assert self.diam > 0 * self.diam.units except AssertionError as ae: - print('diam must be positive scalar or of same shape \ - as coord_electrode') + print( + "diam must be positive scalar or of same shape \ + as coord_electrode" + ) raise ae if self.diam.size == 1: self.diam = np.ones(coord_electrode.size) * self.diam try: - assert(self.h.size == 1 or self.h.size == coord_electrode.size) + assert self.h.size == 1 or self.h.size == coord_electrode.size if self.h.size == coord_electrode.size: - assert(np.all(self.h > 0 * self.h.units)) + assert np.all(self.h > 0 * self.h.units) except AssertionError as ae: - print('h must be scalar or of same shape as coord_electrode') + print("h must be scalar or of same shape as coord_electrode") raise ae if self.h.size == 1: self.h = np.ones(coord_electrode.size) * self.h - self.name = 'step-iCSD method' + self.name = "step-iCSD method" self.coord_electrode = coord_electrode # compute forward-solution matrix @@ -436,15 +461,15 @@ def parameters(self, **kwargs): Same as those passed to initialize the Class """ - self.diam = kwargs.pop('diam', 500E-6 * pq.m) - self.h = kwargs.pop('h', np.ones(23) * 100E-6 * pq.m) - self.sigma = kwargs.pop('sigma', 0.3 * pq.S / pq.m) - self.sigma_top = kwargs.pop('sigma_top', 0.3 * pq.S / pq.m) - self.tol = kwargs.pop('tol', 1e-6) - self.f_type = kwargs.pop('f_type', 'gaussian') - self.f_order = kwargs.pop('f_order', (3, 1)) + self.diam = kwargs.pop("diam", 500e-6 * pq.m) + self.h = kwargs.pop("h", np.ones(23) * 100e-6 * pq.m) + self.sigma = kwargs.pop("sigma", 0.3 * pq.S / pq.m) + self.sigma_top = kwargs.pop("sigma_top", 0.3 * pq.S / pq.m) + self.tol = kwargs.pop("tol", 1e-6) + self.f_type = kwargs.pop("f_type", "gaussian") + self.f_order = kwargs.pop("f_order", (3, 1)) if kwargs: - raise TypeError('Invalid keyword arguments:', kwargs.keys()) + raise
TypeError("Invalid keyword arguments:", kwargs.keys()) def get_f_matrix(self): """Calculate F-matrix for step iCSD method""" @@ -458,16 +483,28 @@ def get_f_matrix(self): upper_int = self.coord_electrode[i] + self.h[j] / 2 # components of f_matrix object - f_cyl0 = si.quad(self._f_cylinder, - a=lower_int, b=upper_int, - args=(float(self.coord_electrode[j]), - float(self.diam[j]), - float(self.sigma)), - epsabs=self.tol)[0] - f_cyl1 = si.quad(self._f_cylinder, a=lower_int, b=upper_int, - args=(-float(self.coord_electrode[j]), - float(self.diam[j]), float(self.sigma)), - epsabs=self.tol)[0] + f_cyl0 = si.quad( + self._f_cylinder, + a=lower_int, + b=upper_int, + args=( + float(self.coord_electrode[j]), + float(self.diam[j]), + float(self.sigma), + ), + epsabs=self.tol, + )[0] + f_cyl1 = si.quad( + self._f_cylinder, + a=lower_int, + b=upper_int, + args=( + -float(self.coord_electrode[j]), + float(self.diam[j]), + float(self.sigma), + ), + epsabs=self.tol, + )[0] # method of images coefficient mom = (self.sigma - self.sigma_top) / (self.sigma + self.sigma_top) @@ -479,15 +516,18 @@ def get_f_matrix(self): def _f_cylinder(self, zeta, z_val, diam, sigma): """function used by class method""" - f_cyl = 1. / (2. * sigma) * \ - (np.sqrt((diam / 2)**2 + ((z_val - zeta))**2) - abs(z_val - zeta)) + f_cyl = ( + 1.0 + / (2.0 * sigma) + * (np.sqrt((diam / 2) ** 2 + (z_val - zeta) ** 2) - abs(z_val - zeta)) + ) return f_cyl class SplineiCSD(CSD): """spline iCSD method""" - def __init__(self, lfp, coord_electrode, **kwargs): + def __init__(self, lfp, coord_electrode, **kwargs): """ Initializing spline-iCSD method class object @@ -525,28 +565,31 @@ def __init__(self, lfp, coord_electrode, **kwargs): CSD.__init__(self, lfp, self.f_type, self.f_order) try: # Should the class not take care of this?! 
- assert(self.diam.units == coord_electrode.units) + assert self.diam.units == coord_electrode.units except AssertionError as ae: - print('units of coord_electrode ({}) and diam ({}) differ' - .format(coord_electrode.units, self.diam.units)) + print( + "units of coord_electrode ({}) and diam ({}) differ".format( + coord_electrode.units, self.diam.units + ) + ) raise try: - assert(np.all(np.diff(coord_electrode) > 0)) + assert np.all(np.diff(coord_electrode) > 0) except AssertionError as ae: - print('values of coord_electrode not continously increasing') + print("values of coord_electrode not continuously increasing") raise ae try: - assert(self.diam.size == 1 or self.diam.size == coord_electrode.size) + assert self.diam.size == 1 or self.diam.size == coord_electrode.size if self.diam.size == coord_electrode.size: - assert(np.all(self.diam > 0 * self.diam.units)) + assert np.all(self.diam > 0 * self.diam.units) except AssertionError as ae: - print('diam must be scalar or of same shape as coord_electrode') + print("diam must be scalar or of same shape as coord_electrode") raise ae if self.diam.size == 1: self.diam = np.ones(coord_electrode.size) * self.diam - self.name = 'spline-iCSD method' + self.name = "spline-iCSD method" self.coord_electrode = coord_electrode # compute stuff @@ -559,15 +602,15 @@ def parameters(self, **kwargs): **kwargs Same as those passed to initialize the Class """ - self.diam = kwargs.pop('diam', 500E-6 * pq.m) - self.sigma = kwargs.pop('sigma', 0.3 * pq.S / pq.m) - self.sigma_top = kwargs.pop('sigma_top', 0.3 * pq.S / pq.m) - self.tol = kwargs.pop('tol', 1e-6) - self.num_steps = kwargs.pop('num_steps', 200) - self.f_type = kwargs.pop('f_type', 'gaussian') - self.f_order = kwargs.pop('f_order', (3, 1)) + self.diam = kwargs.pop("diam", 500e-6 * pq.m) + self.sigma = kwargs.pop("sigma", 0.3 * pq.S / pq.m) + self.sigma_top = kwargs.pop("sigma_top", 0.3 * pq.S / pq.m) + self.tol = kwargs.pop("tol", 1e-6) + self.num_steps = kwargs.pop("num_steps", 200) + self.f_type = kwargs.pop("f_type", "gaussian") + self.f_order = kwargs.pop("f_order", (3, 1)) if kwargs: - raise TypeError('Invalid keyword arguments:', kwargs.keys()) + raise TypeError("Invalid keyword arguments:", kwargs.keys()) def get_f_matrix(self): """Calculate the F-matrix for cubic spline iCSD method""" @@ -585,59 +628,111 @@ def get_f_matrix(self): # Calc.
elements for j in range(el_len): for i in range(el_len): - f_mat0[j, i] = si.quad(self._f_mat0, a=z_js[i], b=z_js[i + 1], - args=(z_js[j + 1], - float(self.sigma), - float(self.diam[j])), - epsabs=self.tol)[0] - f_mat1[j, i] = si.quad(self._f_mat1, a=z_js[i], b=z_js[i + 1], - args=(z_js[j + 1], z_js[i], - float(self.sigma), - float(self.diam[j])), - epsabs=self.tol)[0] - f_mat2[j, i] = si.quad(self._f_mat2, a=z_js[i], b=z_js[i + 1], - args=(z_js[j + 1], z_js[i], - float(self.sigma), - float(self.diam[j])), - epsabs=self.tol)[0] - f_mat3[j, i] = si.quad(self._f_mat3, a=z_js[i], b=z_js[i + 1], - args=(z_js[j + 1], z_js[i], - float(self.sigma), - float(self.diam[j])), - epsabs=self.tol)[0] + f_mat0[j, i] = si.quad( + self._f_mat0, + a=z_js[i], + b=z_js[i + 1], + args=(z_js[j + 1], float(self.sigma), float(self.diam[j])), + epsabs=self.tol, + )[0] + f_mat1[j, i] = si.quad( + self._f_mat1, + a=z_js[i], + b=z_js[i + 1], + args=(z_js[j + 1], z_js[i], float(self.sigma), float(self.diam[j])), + epsabs=self.tol, + )[0] + f_mat2[j, i] = si.quad( + self._f_mat2, + a=z_js[i], + b=z_js[i + 1], + args=(z_js[j + 1], z_js[i], float(self.sigma), float(self.diam[j])), + epsabs=self.tol, + )[0] + f_mat3[j, i] = si.quad( + self._f_mat3, + a=z_js[i], + b=z_js[i + 1], + args=(z_js[j + 1], z_js[i], float(self.sigma), float(self.diam[j])), + epsabs=self.tol, + )[0] # image technique if conductivity not constant: if self.sigma != self.sigma_top: - f_mat0[j, i] = f_mat0[j, i] + (self.sigma-self.sigma_top) / \ - (self.sigma + self.sigma_top) * \ - si.quad(self._f_mat0, a=z_js[i], b=z_js[i+1], \ - args=(-z_js[j+1], - float(self.sigma), float(self.diam[j])), \ - epsabs=self.tol)[0] - f_mat1[j, i] = f_mat1[j, i] + (self.sigma-self.sigma_top) / \ - (self.sigma + self.sigma_top) * \ - si.quad(self._f_mat1, a=z_js[i], b=z_js[i+1], \ - args=(-z_js[j+1], z_js[i], float(self.sigma), - float(self.diam[j])), epsabs=self.tol)[0] - f_mat2[j, i] = f_mat2[j, i] + (self.sigma-self.sigma_top) / \ - (self.sigma + self.sigma_top) * \ - si.quad(self._f_mat2, a=z_js[i], b=z_js[i+1], \ - args=(-z_js[j+1], z_js[i], float(self.sigma), - float(self.diam[j])), epsabs=self.tol)[0] - f_mat3[j, i] = f_mat3[j, i] + (self.sigma-self.sigma_top) / \ - (self.sigma + self.sigma_top) * \ - si.quad(self._f_mat3, a=z_js[i], b=z_js[i+1], \ - args=(-z_js[j+1], z_js[i], float(self.sigma), - float(self.diam[j])), epsabs=self.tol)[0] + f_mat0[j, i] = ( + f_mat0[j, i] + + (self.sigma - self.sigma_top) + / (self.sigma + self.sigma_top) + * si.quad( + self._f_mat0, + a=z_js[i], + b=z_js[i + 1], + args=(-z_js[j + 1], float(self.sigma), float(self.diam[j])), + epsabs=self.tol, + )[0] + ) + f_mat1[j, i] = ( + f_mat1[j, i] + + (self.sigma - self.sigma_top) + / (self.sigma + self.sigma_top) + * si.quad( + self._f_mat1, + a=z_js[i], + b=z_js[i + 1], + args=( + -z_js[j + 1], + z_js[i], + float(self.sigma), + float(self.diam[j]), + ), + epsabs=self.tol, + )[0] + ) + f_mat2[j, i] = ( + f_mat2[j, i] + + (self.sigma - self.sigma_top) + / (self.sigma + self.sigma_top) + * si.quad( + self._f_mat2, + a=z_js[i], + b=z_js[i + 1], + args=( + -z_js[j + 1], + z_js[i], + float(self.sigma), + float(self.diam[j]), + ), + epsabs=self.tol, + )[0] + ) + f_mat3[j, i] = ( + f_mat3[j, i] + + (self.sigma - self.sigma_top) + / (self.sigma + self.sigma_top) + * si.quad( + self._f_mat3, + a=z_js[i], + b=z_js[i + 1], + args=( + -z_js[j + 1], + z_js[i], + float(self.sigma), + float(self.diam[j]), + ), + epsabs=self.tol, + )[0] + ) e_mat0, e_mat1, e_mat2, e_mat3 = self._calc_e_matrices() # 
Calculate the F-matrix f_matrix = np.eye(el_len + 2) - f_matrix[1:-1, :] = np.dot(f_mat0, e_mat0) + \ - np.dot(f_mat1, e_mat1) + \ - np.dot(f_mat2, e_mat2) + \ - np.dot(f_mat3, e_mat3) + f_matrix[1:-1, :] = ( + np.dot(f_mat0, e_mat0) + + np.dot(f_mat1, e_mat1) + + np.dot(f_mat2, e_mat2) + + np.dot(f_mat3, e_mat3) + ) return f_matrix * self.coord_electrode.units**2 / self.sigma.units @@ -661,9 +756,13 @@ def get_csd(self): cs_lfp = np.r_[[0], np.asarray(self.lfp), [0]].reshape(1, -1).T csd = np.zeros(self.num_steps) else: - cs_lfp = np.vstack((np.zeros(self.lfp.shape[1]), - np.asarray(self.lfp), - np.zeros(self.lfp.shape[1]))) + cs_lfp = np.vstack( + ( + np.zeros(self.lfp.shape[1]), + np.asarray(self.lfp), + np.zeros(self.lfp.shape[1]), + ) + ) csd = np.zeros((self.num_steps, self.lfp.shape[1])) cs_lfp *= self.lfp.units @@ -680,7 +779,7 @@ def get_csd(self): h = np.diff(self.coord_electrode).min() z_js = np.zeros(el_len + 2) z_js[0] = self.coord_electrode[0] - h - z_js[1: -1] = self.coord_electrode + z_js[1:-1] = self.coord_electrode z_js[-1] = self.coord_electrode[-1] + h # create high res spatial grid @@ -691,10 +790,12 @@ def get_csd(self): for j in range(self.num_steps): if out_zs[j] >= z_js[i + 1]: i += 1 - csd[j] = (a_mat0[i, :] + a_mat1[i, :] * - (out_zs[j] - z_js[i]) + - a_mat2[i, :] * (out_zs[j] - z_js[i])**2 + - a_mat3[i, :] * (out_zs[j] - z_js[i])**3).item() + csd[j] = ( + a_mat0[i, :] + + a_mat1[i, :] * (out_zs[j] - z_js[i]) + + a_mat2[i, :] * (out_zs[j] - z_js[i]) ** 2 + + a_mat3[i, :] * (out_zs[j] - z_js[i]) ** 3 + ).item() csd_unit = (self.f_matrix.units**-1 * self.lfp.units).simplified @@ -702,8 +803,11 @@ def get_csd(self): def _f_mat0(self, zeta, z_val, sigma, diam): """0'th order potential function""" - return 1. / (2. * sigma) * \ - (np.sqrt((diam / 2)**2 + ((z_val - zeta))**2) - abs(z_val - zeta)) + return ( + 1.0 + / (2.0 * sigma) + * (np.sqrt((diam / 2) ** 2 + (z_val - zeta) ** 2) - abs(z_val - zeta)) + ) def _f_mat1(self, zeta, z_val, zi_val, sigma, diam): """1'th order potential function""" @@ -711,11 +815,11 @@ def _f_mat1(self, zeta, z_val, zi_val, sigma, diam): def _f_mat2(self, zeta, z_val, zi_val, sigma, diam): """2'nd order potential function""" - return (zeta - zi_val)**2 * self._f_mat0(zeta, z_val, sigma, diam) + return (zeta - zi_val) ** 2 * self._f_mat0(zeta, z_val, sigma, diam) def _f_mat3(self, zeta, z_val, zi_val, sigma, diam): """3'rd order potential function""" - return (zeta - zi_val)**3 * self._f_mat0(zeta, z_val, sigma, diam) + return (zeta - zi_val) ** 3 * self._f_mat0(zeta, z_val, sigma, diam) def _calc_k_matrix(self): """Calculate the K-matrix used by to calculate E-matrices""" @@ -740,14 +844,21 @@ def _calc_k_matrix(self): tj0[-1, -1] = 0 # Defining K-matrix used to calculate e_mat1-3 - return np.dot(np.linalg.inv(np.dot(c_jm1, tjm1) + - 2 * np.dot(c_jm1, tj0) + - 2 * c_jall + - np.dot(c_j0, tjp1)), - 3 * (np.dot(np.dot(c_jm1, c_jm1), tj0) - - np.dot(np.dot(c_jm1, c_jm1), tjm1) + - np.dot(np.dot(c_j0, c_j0), tjp1) - - np.dot(np.dot(c_j0, c_j0), tj0))) + return np.dot( + np.linalg.inv( + np.dot(c_jm1, tjm1) + + 2 * np.dot(c_jm1, tj0) + + 2 * c_jall + + np.dot(c_j0, tjp1) + ), + 3 + * ( + np.dot(np.dot(c_jm1, c_jm1), tj0) + - np.dot(np.dot(c_jm1, c_jm1), tjm1) + + np.dot(np.dot(c_j0, c_j0), tjp1) + - np.dot(np.dot(c_j0, c_j0), tj0) + ), + ) def _calc_e_matrices(self): """Calculate the E-matrices used by cubic spline iCSD method""" @@ -762,127 +873,143 @@ def _calc_e_matrices(self): k_matrix = self._calc_k_matrix() # Define matrixes for C to A 
transformation: - tja = np.eye(el_len + 2)[:-1, ] - tjp1a = np.eye(el_len + 2, k=1)[:-1, ] + tja = np.eye(el_len + 2)[:-1,] + tjp1a = np.eye(el_len + 2, k=1)[:-1,] # Define spline coefficients e_mat0 = tja e_mat1 = np.dot(tja, k_matrix) - e_mat2 = 3 * np.dot(c_mat3**2, (tjp1a - tja)) - \ - np.dot(np.dot(c_mat3, (tjp1a + 2 * tja)), k_matrix) - e_mat3 = 2 * np.dot(c_mat3**3, (tja - tjp1a)) + \ - np.dot(np.dot(c_mat3**2, (tjp1a + tja)), k_matrix) + e_mat2 = 3 * np.dot(c_mat3**2, (tjp1a - tja)) - np.dot( + np.dot(c_mat3, (tjp1a + 2 * tja)), k_matrix + ) + e_mat3 = 2 * np.dot(c_mat3**3, (tja - tjp1a)) + np.dot( + np.dot(c_mat3**2, (tjp1a + tja)), k_matrix + ) return e_mat0, e_mat1, e_mat2, e_mat3 -if __name__ == '__main__': +if __name__ == "__main__": from scipy.io import loadmat import matplotlib.pyplot as plt - - #loading test data - test_data = loadmat('test_data.mat') - - #prepare lfp data for use, by changing the units to SI and append quantities, - #along with electrode geometry, conductivities and assumed source geometry - lfp_data = test_data['pot1'] * 1E-6 * pq.V # [uV] -> [V] - z_data = np.linspace(100E-6, 2300E-6, 23) * pq.m # [m] - diam = 500E-6 * pq.m # [m] - h = 100E-6 * pq.m # [m] - sigma = 0.3 * pq.S / pq.m # [S/m] or [1/(ohm*m)] - sigma_top = 0.3 * pq.S / pq.m # [S/m] or [1/(ohm*m)] - + # loading test data + test_data = loadmat("test_data.mat") + + # prepare lfp data for use, by changing the units to SI and append quantities, + # along with electrode geometry, conductivities and assumed source geometry + lfp_data = test_data["pot1"] * 1e-6 * pq.V # [uV] -> [V] + z_data = np.linspace(100e-6, 2300e-6, 23) * pq.m # [m] + diam = 500e-6 * pq.m # [m] + h = 100e-6 * pq.m # [m] + sigma = 0.3 * pq.S / pq.m # [S/m] or [1/(ohm*m)] + sigma_top = 0.3 * pq.S / pq.m # [S/m] or [1/(ohm*m)] + # Input dictionaries for each method delta_input = { - 'lfp' : lfp_data, - 'coord_electrode' : z_data, - 'diam' : diam, # source diameter - 'sigma' : sigma, # extracellular conductivity - 'sigma_top' : sigma, # conductivity on top of cortex - 'f_type' : 'gaussian', # gaussian filter - 'f_order' : (3, 1), # 3-point filter, sigma = 1. + "lfp": lfp_data, + "coord_electrode": z_data, + "diam": diam, # source diameter + "sigma": sigma, # extracellular conductivity + "sigma_top": sigma, # conductivity on top of cortex + "f_type": "gaussian", # gaussian filter + "f_order": (3, 1), # 3-point filter, sigma = 1. 
} step_input = { - 'lfp' : lfp_data, - 'coord_electrode' : z_data, - 'diam' : diam, - 'h' : h, # source thickness - 'sigma' : sigma, - 'sigma_top' : sigma, - 'tol' : 1E-12, # Tolerance in numerical integration - 'f_type' : 'gaussian', - 'f_order' : (3, 1), + "lfp": lfp_data, + "coord_electrode": z_data, + "diam": diam, + "h": h, # source thickness + "sigma": sigma, + "sigma_top": sigma, + "tol": 1e-12, # Tolerance in numerical integration + "f_type": "gaussian", + "f_order": (3, 1), } spline_input = { - 'lfp' : lfp_data, - 'coord_electrode' : z_data, - 'diam' : diam, - 'sigma' : sigma, - 'sigma_top' : sigma, - 'num_steps' : 201, # Spatial CSD upsampling to N steps - 'tol' : 1E-12, - 'f_type' : 'gaussian', - 'f_order' : (20, 5), + "lfp": lfp_data, + "coord_electrode": z_data, + "diam": diam, + "sigma": sigma, + "sigma_top": sigma, + "num_steps": 201, # Spatial CSD upsampling to N steps + "tol": 1e-12, + "f_type": "gaussian", + "f_order": (20, 5), } std_input = { - 'lfp' : lfp_data, - 'coord_electrode' : z_data, - 'sigma' : sigma, - 'f_type' : 'gaussian', - 'f_order' : (3, 1), + "lfp": lfp_data, + "coord_electrode": z_data, + "sigma": sigma, + "f_type": "gaussian", + "f_order": (3, 1), } - - - #Create the different CSD-method class instances. We use the class methods - #get_csd() and filter_csd() below to get the raw and spatially filtered - #versions of the current-source density estimates. + + # Create the different CSD-method class instances. We use the class methods + # get_csd() and filter_csd() below to get the raw and spatially filtered + # versions of the current-source density estimates. csd_dict = dict( - delta_icsd = DeltaiCSD(**delta_input), - step_icsd = StepiCSD(**step_input), - spline_icsd = SplineiCSD(**spline_input), - std_csd = StandardCSD(**std_input), + delta_icsd=DeltaiCSD(**delta_input), + step_icsd=StepiCSD(**step_input), + spline_icsd=SplineiCSD(**spline_input), + std_csd=StandardCSD(**std_input), ) - - #plot + + # plot for method, csd_obj in list(csd_dict.items()): - fig, axes = plt.subplots(3,1, figsize=(8,8)) - - #plot LFP signal + fig, axes = plt.subplots(3, 1, figsize=(8, 8)) + + # plot LFP signal ax = axes[0] - im = ax.imshow(np.array(lfp_data), origin='upper', vmin=-abs(lfp_data).max(), \ - vmax=abs(lfp_data).max(), cmap='jet_r', interpolation='nearest') - ax.axis(ax.axis('tight')) + im = ax.imshow( + np.array(lfp_data), + origin="upper", + vmin=-abs(lfp_data).max(), + vmax=abs(lfp_data).max(), + cmap="jet_r", + interpolation="nearest", + ) + ax.axis(ax.axis("tight")) cb = plt.colorbar(im, ax=ax) - cb.set_label('LFP (%s)' % lfp_data.dimensionality.string) + cb.set_label("LFP (%s)" % lfp_data.dimensionality.string) ax.set_xticklabels([]) - ax.set_title('LFP') - ax.set_ylabel('ch #') - - #plot raw csd estimate + ax.set_title("LFP") + ax.set_ylabel("ch #") + + # plot raw csd estimate csd = csd_obj.get_csd() ax = axes[1] - im = ax.imshow(np.array(csd), origin='upper', vmin=-abs(csd).max(), \ - vmax=abs(csd).max(), cmap='jet_r', interpolation='nearest') - ax.axis(ax.axis('tight')) + im = ax.imshow( + np.array(csd), + origin="upper", + vmin=-abs(csd).max(), + vmax=abs(csd).max(), + cmap="jet_r", + interpolation="nearest", + ) + ax.axis(ax.axis("tight")) ax.set_title(csd_obj.name) cb = plt.colorbar(im, ax=ax) - cb.set_label('CSD (%s)' % csd.dimensionality.string) + cb.set_label("CSD (%s)" % csd.dimensionality.string) ax.set_xticklabels([]) - ax.set_ylabel('ch #') - - #plot spatially filtered csd estimate + ax.set_ylabel("ch #") + + # plot spatially filtered csd 
estimate ax = axes[2] csd = csd_obj.filter_csd(csd) - im = ax.imshow(np.array(csd), origin='upper', vmin=-abs(csd).max(), \ - vmax=abs(csd).max(), cmap='jet_r', interpolation='nearest') - ax.axis(ax.axis('tight')) - ax.set_title(csd_obj.name + ', filtered') + im = ax.imshow( + np.array(csd), + origin="upper", + vmin=-abs(csd).max(), + vmax=abs(csd).max(), + cmap="jet_r", + interpolation="nearest", + ) + ax.axis(ax.axis("tight")) + ax.set_title(csd_obj.name + ", filtered") cb = plt.colorbar(im, ax=ax) - cb.set_label('CSD (%s)' % csd.dimensionality.string) - ax.set_ylabel('ch #') - ax.set_xlabel('timestep') - - - plt.show() + cb.set_label("CSD (%s)" % csd.dimensionality.string) + ax.set_ylabel("ch #") + ax.set_xlabel("timestep") + plt.show() diff --git a/elephant/current_source_density_src/utility_functions.py b/elephant/current_source_density_src/utility_functions.py index e11b1470c..5f58dfb9e 100644 --- a/elephant/current_source_density_src/utility_functions.py +++ b/elephant/current_source_density_src/utility_functions.py @@ -10,6 +10,7 @@ Laboratory of Neuroinformatics, Nencki Institute of Experimental Biology, Warsaw. """ + from __future__ import division import numpy as np @@ -20,18 +21,21 @@ def patch_quantities(): """patch quantities with the SI unit Siemens if it does not exist""" for symbol, prefix, definition, u_symbol in zip( - ['siemens', 'S', 'mS', 'uS', 'nS', 'pS'], - ['', '', 'milli', 'micro', 'nano', 'pico'], - [pq.A / pq.V, pq.A / pq.V, 'S', 'mS', 'uS', 'nS'], - [None, None, None, None, u'µS', None]): + ["siemens", "S", "mS", "uS", "nS", "pS"], + ["", "", "milli", "micro", "nano", "pico"], + [pq.A / pq.V, pq.A / pq.V, "S", "mS", "uS", "nS"], + [None, None, None, None, "µS", None], + ): if type(definition) is str: definition = lastdefinition / 1000 if not hasattr(pq, symbol): - setattr(pq, symbol, pq.UnitQuantity( - prefix + 'siemens', - definition, - symbol=symbol, - u_symbol=u_symbol)) + setattr( + pq, + symbol, + pq.UnitQuantity( + prefix + "siemens", definition, symbol=symbol, u_symbol=u_symbol + ), + ) lastdefinition = definition return @@ -70,8 +74,7 @@ def distribute_srcs_1D(X, n_src, ext_x, R_init): R : float effective radius of the basis element """ - X_src = np.mgrid[(np.min(X) - ext_x):(np.max(X) + ext_x): - complex(0, n_src)] + X_src = np.mgrid[(np.min(X) - ext_x) : (np.max(X) + ext_x) : complex(0, n_src)] R = R_init return X_src, R @@ -105,10 +108,10 @@ def distribute_srcs_2D(X, Y, n_src, ext_x, ext_y, R_init): [nx, ny, Lx_nn, Ly_nn, ds] = get_src_params_2D(Lx_n, Ly_n, n_src) ext_x_n = (Lx_nn - Lx) / 2 ext_y_n = (Ly_nn - Ly) / 2 - X_src, Y_src = np.mgrid[(np.min(X) - ext_x_n):(np.max(X) + ext_x_n): - complex(0, nx), - (np.min(Y) - ext_y_n):(np.max(Y) + ext_y_n): - complex(0, ny)] + X_src, Y_src = np.mgrid[ + (np.min(X) - ext_x_n) : (np.max(X) + ext_x_n) : complex(0, nx), + (np.min(Y) - ext_y_n) : (np.max(Y) + ext_y_n) : complex(0, ny), + ] # d = round(R_init / ds) R = R_init # R = d * ds return X_src, Y_src, R @@ -178,19 +181,15 @@ def distribute_srcs_3D(X, Y, Z, n_src, ext_x, ext_y, ext_z, R_init): Lx_n = Lx + 2 * ext_x Ly_n = Ly + 2 * ext_y Lz_n = Lz + 2 * ext_z - (nx, ny, nz, Lx_nn, Ly_nn, Lz_nn, ds) = get_src_params_3D(Lx_n, - Ly_n, - Lz_n, - n_src) + (nx, ny, nz, Lx_nn, Ly_nn, Lz_nn, ds) = get_src_params_3D(Lx_n, Ly_n, Lz_n, n_src) ext_x_n = (Lx_nn - Lx) / 2 ext_y_n = (Ly_nn - Ly) / 2 ext_z_n = (Lz_nn - Lz) / 2 - X_src, Y_src, Z_src = np.mgrid[(np.min(X) - ext_x_n):(np.max(X) + ext_x_n): - complex(0, nx), - (np.min(Y) - ext_y_n):(np.max(Y) + ext_y_n): - 
complex(0, ny), - (np.min(Z) - ext_z_n):(np.max(Z) + ext_z_n): - complex(0, nz)] + X_src, Y_src, Z_src = np.mgrid[ + (np.min(X) - ext_x_n) : (np.max(X) + ext_x_n) : complex(0, nx), + (np.min(Y) - ext_y_n) : (np.max(Y) + ext_y_n) : complex(0, ny), + (np.min(Z) - ext_z_n) : (np.max(Z) + ext_z_n) : complex(0, nz), + ] # d = np.round(R_init / ds) R = R_init return (X_src, Y_src, Z_src, R) @@ -218,7 +217,7 @@ def get_src_params_3D(Lx, Ly, Lz, n_src): """ V = Lx * Ly * Lz V_unit = V / n_src - L_unit = V_unit**(1. / 3.) + L_unit = V_unit ** (1.0 / 3.0) nx = np.ceil(Lx / L_unit) ny = np.ceil(Ly / L_unit) nz = np.ceil(Lz / L_unit) @@ -229,40 +228,44 @@ def get_src_params_3D(Lx, Ly, Lz, n_src): return (nx, ny, nz, Lx_n, Ly_n, Lz_n, ds) -def generate_electrodes(dim, xlims=[0.1, 0.9], ylims=[0.1, 0.9], - zlims=[0.1, 0.9], res=5): +def generate_electrodes( + dim, xlims=[0.1, 0.9], ylims=[0.1, 0.9], zlims=[0.1, 0.9], res=5 +): """Generates electrodes, helpful for FWD funtion. - Parameters - ---------- - dim : int - Dimensionality of the electrodes, 1,2 or 3 - xlims : [start, end] - Spatial limits of the electrodes - ylims : [start, end] - Spatial limits of the electrodes - zlims : [start, end] - Spatial limits of the electrodes - res : int - How many electrodes in each dimension - Returns - ------- - ele_x, ele_y, ele_z : flattened np.array of the electrode pos + Parameters + ---------- + dim : int + Dimensionality of the electrodes, 1,2 or 3 + xlims : [start, end] + Spatial limits of the electrodes + ylims : [start, end] + Spatial limits of the electrodes + zlims : [start, end] + Spatial limits of the electrodes + res : int + How many electrodes in each dimension + Returns + ------- + ele_x, ele_y, ele_z : flattened np.array of the electrode pos """ if dim == 1: - ele_x = np.mgrid[xlims[0]: xlims[1]: complex(0, res)] + ele_x = np.mgrid[xlims[0] : xlims[1] : complex(0, res)] ele_x = ele_x.flatten() return ele_x elif dim == 2: - ele_x, ele_y = np.mgrid[xlims[0]: xlims[1]: complex(0, res), - ylims[0]: ylims[1]: complex(0, res)] + ele_x, ele_y = np.mgrid[ + xlims[0] : xlims[1] : complex(0, res), ylims[0] : ylims[1] : complex(0, res) + ] ele_x = ele_x.flatten() ele_y = ele_y.flatten() return ele_x, ele_y elif dim == 3: - ele_x, ele_y, ele_z = np.mgrid[xlims[0]: xlims[1]: complex(0, res), - ylims[0]: ylims[1]: complex(0, res), - zlims[0]: zlims[1]: complex(0, res)] + ele_x, ele_y, ele_z = np.mgrid[ + xlims[0] : xlims[1] : complex(0, res), + ylims[0] : ylims[1] : complex(0, res), + zlims[0] : zlims[1] : complex(0, res), + ] ele_x = ele_x.flatten() ele_y = ele_y.flatten() ele_z = ele_z.flatten() @@ -271,94 +274,120 @@ def generate_electrodes(dim, xlims=[0.1, 0.9], ylims=[0.1, 0.9], def gauss_1d_dipole(x): """1D Gaussian dipole source is placed between 0 and 1 - to be used to test the CSD + to be used to test the CSD - Parameters - ---------- - x : np.array - Spatial pts. at which the true csd is evaluated + Parameters + ---------- + x : np.array + Spatial pts. 
at which the true csd is evaluated - Returns - ------- - f : np.array - The value of the csd at the requested points + Returns + ------- + f : np.array + The value of the csd at the requested points """ - src = 0.5*exp(-((x-0.7)**2)/(2.*0.3))*(2*np.pi*0.3)**-0.5 - snk = -0.5*exp(-((x-0.3)**2)/(2.*0.3))*(2*np.pi*0.3)**-0.5 - f = src+snk + src = 0.5 * exp(-((x - 0.7) ** 2) / (2.0 * 0.3)) * (2 * np.pi * 0.3) ** -0.5 + snk = -0.5 * exp(-((x - 0.3) ** 2) / (2.0 * 0.3)) * (2 * np.pi * 0.3) ** -0.5 + f = src + snk return f + def large_source_2D(x, y): """2D Gaussian large source profile - to use to test csd - Parameters - ---------- - x : np.array - Spatial x pts. at which the true csd is evaluated - y : np.array - Spatial y pts. at which the true csd is evaluated - Returns - ------- - f : np.array - The value of the csd at the requested points + Parameters + ---------- + x : np.array + Spatial x pts. at which the true csd is evaluated + y : np.array + Spatial y pts. at which the true csd is evaluated + Returns + ------- + f : np.array + The value of the csd at the requested points """ zz = [0.4, -0.3, -0.1, 0.6] zs = [0.2, 0.3, 0.4, 0.2] - f1 = 0.5965*exp( (-1*(x-0.1350)**2 - (y-0.8628)**2) /0.4464)* exp(-(-zz[0])**2 / zs[0]) /exp(-(zz[0])**2/zs[0]) - f2 = -0.9269*exp( (-2*(x-0.1848)**2 - (y-0.0897)**2) /0.2046)* exp(-(-zz[1])**2 / zs[1]) /exp(-(zz[1])**2/zs[1]); - f3 = 0.5910*exp( (-3*(x-1.3189)**2 - (y-0.3522)**2) /0.2129)* exp(-(-zz[2])**2 / zs[2]) /exp(-(zz[2])**2/zs[2]); - f4 = -0.1963*exp( (-4*(x-1.3386)**2 - (y-0.5297)**2) /0.2507)* exp(-(-zz[3])**2 / zs[3]) /exp(-(zz[3])**2/zs[3]); - f = f1+f2+f3+f4 + f1 = ( + 0.5965 + * exp((-1 * (x - 0.1350) ** 2 - (y - 0.8628) ** 2) / 0.4464) + * exp(-((-zz[0]) ** 2) / zs[0]) + / exp(-((zz[0]) ** 2) / zs[0]) + ) + f2 = ( + -0.9269 + * exp((-2 * (x - 0.1848) ** 2 - (y - 0.0897) ** 2) / 0.2046) + * exp(-((-zz[1]) ** 2) / zs[1]) + / exp(-((zz[1]) ** 2) / zs[1]) + ) + f3 = ( + 0.5910 + * exp((-3 * (x - 1.3189) ** 2 - (y - 0.3522) ** 2) / 0.2129) + * exp(-((-zz[2]) ** 2) / zs[2]) + / exp(-((zz[2]) ** 2) / zs[2]) + ) + f4 = ( + -0.1963 + * exp((-4 * (x - 1.3386) ** 2 - (y - 0.5297) ** 2) / 0.2507) + * exp(-((-zz[3]) ** 2) / zs[3]) + / exp(-((zz[3]) ** 2) / zs[3]) + ) + f = f1 + f2 + f3 + f4 return f + def small_source_2D(x, y): """2D Gaussian small source profile - to be used to test csd - Parameters - ---------- - x : np.array - Spatial x pts. at which the true csd is evaluated - y : np.array - Spatial y pts. at which the true csd is evaluated - Returns - ------- - f : np.array - The value of the csd at the requested points + Parameters + ---------- + x : np.array + Spatial x pts. at which the true csd is evaluated + y : np.array + Spatial y pts. at which the true csd is evaluated + Returns + ------- + f : np.array + The value of the csd at the requested points """ - def gauss2d(x,y,p): + + def gauss2d(x, y, p): rcen_x = p[0] * np.cos(p[5]) - p[1] * np.sin(p[5]) rcen_y = p[0] * np.sin(p[5]) + p[1] * np.cos(p[5]) xp = x * np.cos(p[5]) - y * np.sin(p[5]) yp = x * np.sin(p[5]) + y * np.cos(p[5]) - g = p[4]*exp(-(((rcen_x-xp)/p[2])**2+ - ((rcen_y-yp)/p[3])**2)/2.) 
+ g = p[4] * exp( + -(((rcen_x - xp) / p[2]) ** 2 + ((rcen_y - yp) / p[3]) ** 2) / 2.0 + ) return g - f1 = gauss2d(x,y,[0.3,0.7,0.038,0.058,0.5,0.]) - f2 = gauss2d(x,y,[0.3,0.6,0.038,0.058,-0.5,0.]) - f3 = gauss2d(x,y,[0.45,0.7,0.038,0.058,0.5,0.]) - f4 = gauss2d(x,y,[0.45,0.6,0.038,0.058,-0.5,0.]) - f = f1+f2+f3+f4 + + f1 = gauss2d(x, y, [0.3, 0.7, 0.038, 0.058, 0.5, 0.0]) + f2 = gauss2d(x, y, [0.3, 0.6, 0.038, 0.058, -0.5, 0.0]) + f3 = gauss2d(x, y, [0.45, 0.7, 0.038, 0.058, 0.5, 0.0]) + f4 = gauss2d(x, y, [0.45, 0.6, 0.038, 0.058, -0.5, 0.0]) + f = f1 + f2 + f3 + f4 return f + def gauss_3d_dipole(x, y, z): """3D Gaussian dipole profile - to be used to test csd. - Parameters - ---------- - x : np.array - Spatial x pts. at which the true csd is evaluated - y : np.array - Spatial y pts. at which the true csd is evaluated - z : np.array - Spatial z pts. at which the true csd is evaluated - Returns - ------- - f : np.array - The value of the csd at the requested points + Parameters + ---------- + x : np.array + Spatial x pts. at which the true csd is evaluated + y : np.array + Spatial y pts. at which the true csd is evaluated + z : np.array + Spatial z pts. at which the true csd is evaluated + Returns + ------- + f : np.array + The value of the csd at the requested points """ x0, y0, z0 = 0.3, 0.7, 0.3 x1, y1, z1 = 0.6, 0.5, 0.7 sig_2 = 0.023 - A = (2*np.pi*sig_2)**-1 - f1 = A*exp( (-(x-x0)**2 -(y-y0)**2 -(z-z0)**2) / (2*sig_2) ) - f2 = -1*A*exp( (-(x-x1)**2 -(y-y1)**2 -(z-z1)**2) / (2*sig_2) ) - f = f1+f2 + A = (2 * np.pi * sig_2) ** -1 + f1 = A * exp((-((x - x0) ** 2) - (y - y0) ** 2 - (z - z0) ** 2) / (2 * sig_2)) + f2 = -1 * A * exp((-((x - x1) ** 2) - (y - y1) ** 2 - (z - z1) ** 2) / (2 * sig_2)) + f = f1 + f2 return f diff --git a/elephant/datasets.py b/elephant/datasets.py index 58d52d31f..d5a4a5411 100644 --- a/elephant/datasets.py +++ b/elephant/datasets.py @@ -38,8 +38,8 @@ def update_to(self, b=1, bsize=1, tsize=None): def calculate_md5(filepath, chunk_size=1024 * 1024): md5 = hashlib.md5() - with open(filepath, 'rb') as f: - for chunk in iter(lambda: f.read(chunk_size), b''): + with open(filepath, "rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): md5.update(chunk) return md5.hexdigest() @@ -52,7 +52,7 @@ def check_integrity(filepath, md5): def download(url, filepath=None, checksum=None, verbose=True): if filepath is None: - filename = url.split('/')[-1] + filename = url.split("/")[-1] filepath = ELEPHANT_TMP_DIR / filename filepath = Path(filepath) if check_integrity(filepath, md5=checksum): @@ -60,8 +60,14 @@ def download(url, filepath=None, checksum=None, verbose=True): folder = filepath.absolute().parent folder.mkdir(exist_ok=True) desc = f"Downloading {url} to '{filepath}'" - with TqdmUpTo(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, - desc=desc, disable=not verbose) as t: + with TqdmUpTo( + unit="B", + unit_scale=True, + unit_divisor=1024, + miniters=1, + desc=desc, + disable=not verbose, + ) as t: try: urlretrieve(url, filename=filepath, reporthook=t.update_to) except URLError: @@ -71,68 +77,67 @@ def download(url, filepath=None, checksum=None, verbose=True): return filepath -def download_datasets(repo_path, filepath=None, checksum=None, - verbose=True): +def download_datasets(repo_path, filepath=None, checksum=None, verbose=True): r""" - This function can be used to download files from elephant-data using - only the path relative to the root of the elephant-data repository. 
- The default URL used, points to elephants corresponding release of - elephant-data. - Different versions of the elephant package may require different - versions of elephant-data. - e.g. the follwoing URLs: - - https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/0.0.1 - points to release v0.0.1. - - https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/master - always points to the latest state of elephant-data. - - https://datasets.python-elephant.org/ - points to the root of elephant data - - To change this URL, use the environment variable `ELEPHANT_DATA_URL`. - When using data, which is not yet contained in the master branch or a - release of elephant data, e.g. during development, this variable can - be used to change the default URL. - For example to use data on branch `multitaper`, change the - `ELEPHANT_DATA_URL` to - https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/multitaper. - For a complete example, see Examples section. - - Parameters - ---------- - repo_path : str - String denoting the path relative to elephant-data repository root - filepath : str, optional - Path to temporary folder where the downloaded files will be stored - checksum : str, otpional - Checksum to verify dara integrity after download - verbose : bool, optional - Whether to disable the entire progressbar wrapper []. - If set to None, disable on non-TTY. - Default: True - - Returns - ------- - filepath : pathlib.Path - Path to downloaded files. - - - Notes - ----- - The default URL always points to elephant-data. Please - do not change its value. For development purposes use the environment - variable 'ELEPHANT_DATA_URL'. - - Examples - -------- - The following example downloads a file from elephant-data branch - 'multitaper', by setting the environment variable to the branch URL: - - >>> import os - >>> from elephant.datasets import download_datasets - >>> os.environ["ELEPHANT_DATA_URL"] = "https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/multitaper" # noqa - >>> download_datasets("unittest/spectral/multitaper_psd/data/time_series.npy") # doctest: +SKIP - PosixPath('/tmp/elephant/time_series.npy') - """ + This function can be used to download files from elephant-data using + only the path relative to the root of the elephant-data repository. + The default URL used points to elephant's corresponding release of + elephant-data. + Different versions of the elephant package may require different + versions of elephant-data. + e.g. the following URLs: + - https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/0.0.1 + points to release v0.0.1. + - https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/master + always points to the latest state of elephant-data. + - https://datasets.python-elephant.org/ + points to the root of elephant data + + To change this URL, use the environment variable `ELEPHANT_DATA_URL`. + When using data, which is not yet contained in the master branch or a + release of elephant data, e.g. during development, this variable can + be used to change the default URL. + For example to use data on branch `multitaper`, change the + `ELEPHANT_DATA_URL` to + https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/multitaper. + For a complete example, see Examples section.
+ + Parameters + ---------- + repo_path : str + String denoting the path relative to elephant-data repository root + filepath : str, optional + Path to temporary folder where the downloaded files will be stored + checksum : str, optional + Checksum to verify data integrity after download + verbose : bool, optional + Whether to disable the entire progressbar wrapper []. + If set to None, disable on non-TTY. + Default: True + + Returns + ------- + filepath : pathlib.Path + Path to downloaded files. + + + Notes + ----- + The default URL always points to elephant-data. Please + do not change its value. For development purposes use the environment + variable 'ELEPHANT_DATA_URL'. + + Examples + -------- + The following example downloads a file from elephant-data branch + 'multitaper', by setting the environment variable to the branch URL: + + >>> import os + >>> from elephant.datasets import download_datasets + >>> os.environ["ELEPHANT_DATA_URL"] = "https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/multitaper" # noqa + >>> download_datasets("unittest/spectral/multitaper_psd/data/time_series.npy") # doctest: +SKIP + PosixPath('/tmp/elephant/time_series.npy') + """ # this url redirects to the current location of elephant-data url_to_root = "https://datasets.python-elephant.org/" @@ -141,33 +146,34 @@ def download_datasets(repo_path, filepath=None, checksum=None, # (version elephant is equal to version elephant-data) default_url = url_to_root + f"raw/v{_get_version()}" - if 'ELEPHANT_DATA_URL' not in environ: # user did not set URL + if "ELEPHANT_DATA_URL" not in environ: # user did not set URL # is 'version-URL' available? (not for elephant development version) try: - urlopen(default_url+'/README.md') + urlopen(default_url + "/README.md") except HTTPError as error: # if corresponding elephant-data version is not found, # use latest commit of elephant-data default_url = url_to_root + f"raw/master" - warnings.warn(f"No corresponding version of elephant-data found.\n" - f"Elephant version: {_get_version()}. " - f"Data URL:{error.url}, error: {error}.\n" - f"Using elephant-data latest instead (This is " - f"expected for elephant development versions).") + warnings.warn( + f"No corresponding version of elephant-data found.\n" + f"Elephant version: {_get_version()}. " + f"Data URL:{error.url}, error: {error}.\n" + f"Using elephant-data latest instead (This is " + f"expected for elephant development versions)." + ) except URLError as error: # if verification of SSL certificate fails, do not verify cert try: # try again without certificate verification ctx = ssl._create_unverified_context() ctx.check_hostname = True - urlopen(default_url + '/README.md') + urlopen(default_url + "/README.md") except HTTPError: # e.g. 404 default_url = url_to_root + f"raw/master" - warnings.warn(f"Data URL:{default_url}, error: {error}." f"{error.reason}") + warnings.warn(f"Data URL:{default_url}, error: {error}." f"{error.reason}") url = f"{getenv('ELEPHANT_DATA_URL', default_url)}/{repo_path}" diff --git a/elephant/functional_connectivity.py b/elephant/functional_connectivity.py index 9e7ea0b32..d6f444ead 100644 --- a/elephant/functional_connectivity.py +++ b/elephant/functional_connectivity.py @@ -23,9 +23,8 @@ :license: Modified BSD, see LICENSE.txt for details.
""" -from elephant.functional_connectivity_src.total_spiking_probability_edges \ - import ( - total_spiking_probability_edges, - ) +from elephant.functional_connectivity_src.total_spiking_probability_edges import ( + total_spiking_probability_edges, +) __all__ = ["total_spiking_probability_edges"] diff --git a/elephant/functional_connectivity_src/total_spiking_probability_edges.py b/elephant/functional_connectivity_src/total_spiking_probability_edges.py index 054fd0e13..d441fad91 100644 --- a/elephant/functional_connectivity_src/total_spiking_probability_edges.py +++ b/elephant/functional_connectivity_src/total_spiking_probability_edges.py @@ -157,8 +157,7 @@ def total_spiking_probability_edges( if normalize: for delay_time in delay_times: NCC_d[:, :, delay_time] /= np.sum( - NCC_d[:, :, delay_time][~np.identity(NCC_d.shape[0], - dtype=bool)] + NCC_d[:, :, delay_time][~np.identity(NCC_d.shape[0], dtype=bool)] ) # Apply edge and running total filter @@ -168,30 +167,26 @@ def total_spiking_probability_edges( NCC_window = NCC_d[ :, :, - max_padding - - filter.needed_padding: max_delay + max_padding - filter.needed_padding : max_delay + max_padding + filter.needed_padding, ] # Compute two convolutions with edge- and running total filter x1 = oaconvolve( - NCC_window, np.expand_dims(filter.edge_filter, (0, 1)), - mode="valid", axes=2 + NCC_window, np.expand_dims(filter.edge_filter, (0, 1)), mode="valid", axes=2 ) x2 = oaconvolve( - x1, np.expand_dims(filter.running_total_filter, (0, 1)), - mode="full", axes=2 + x1, np.expand_dims(filter.running_total_filter, (0, 1)), mode="full", axes=2 ) tspe_matrix += x2 # Take maxima of absolute of delays to get estimation for connectivity - connectivity_matrix_index = np.argmax(np.abs(tspe_matrix), - axis=2, keepdims=True) - connectivity_matrix = np.take_along_axis(tspe_matrix, - connectivity_matrix_index, axis=2 - ).squeeze(axis=2) + connectivity_matrix_index = np.argmax(np.abs(tspe_matrix), axis=2, keepdims=True) + connectivity_matrix = np.take_along_axis( + tspe_matrix, connectivity_matrix_index, axis=2 + ).squeeze(axis=2) delay_matrix = connectivity_matrix_index.squeeze() return connectivity_matrix, delay_matrix @@ -242,8 +237,7 @@ def normalized_cross_correlation( # Uses theoretical zero-padding for shifted values, # but since $0 \cdot x = 0$ values can simply be omitted if delay_time == 0: - CC = spike_trains_array[:, :] @ spike_trains_array[:, : - ].transpose() + CC = spike_trains_array[:, :] @ spike_trains_array[:, :].transpose() elif delay_time > 0: CC = ( @@ -305,8 +299,7 @@ def generate_edge_filter( conditions = [ (i > 0) & (i <= surrounding_window_size), (i > (surrounding_window_size + crossover_window_size)) - & (i <= surrounding_window_size + observed_window_size + - crossover_window_size), + & (i <= surrounding_window_size + observed_window_size + crossover_window_size), ( i > surrounding_window_size diff --git a/elephant/gpfa/__init__.py b/elephant/gpfa/__init__.py index 141688932..0a4f9630f 100644 --- a/elephant/gpfa/__init__.py +++ b/elephant/gpfa/__init__.py @@ -4,6 +4,4 @@ # please run command `pip install -r requirements-extras.txt` pass -__all__ = [ - "GPFA" -] +__all__ = ["GPFA"] diff --git a/elephant/gpfa/gpfa.py b/elephant/gpfa/gpfa.py index 79d490e0d..d4dfb936f 100644 --- a/elephant/gpfa/gpfa.py +++ b/elephant/gpfa/gpfa.py @@ -356,9 +356,7 @@ def fit( print("Number of training trials: {}".format(len(seqs_train))) print("Latent space dimensionality: {}".format(self.x_dim)) print( - "Observation dimensionality: {}".format( - 
self.has_spikes_bool.sum() - ) + "Observation dimensionality: {}".format(self.has_spikes_bool.sum()) ) # The following does the heavy lifting. diff --git a/elephant/gpfa/gpfa_core.py b/elephant/gpfa/gpfa_core.py index dc3112954..b0a8e64d6 100644 --- a/elephant/gpfa/gpfa_core.py +++ b/elephant/gpfa/gpfa_core.py @@ -21,9 +21,18 @@ from . import gpfa_util -def fit(seqs_train, x_dim=3, bin_width=20.0, min_var_frac=0.01, em_tol=1.0E-8, - em_max_iters=500, tau_init=100.0, eps_init=1.0E-3, freq_ll=5, - verbose=False): +def fit( + seqs_train, + x_dim=3, + bin_width=20.0, + min_var_frac=0.01, + em_tol=1.0e-8, + em_max_iters=500, + tau_init=100.0, + eps_init=1.0e-3, + freq_ll=5, + verbose=False, +): """ Fit the GPFA model with the given training data. @@ -95,57 +104,74 @@ def fit(seqs_train, x_dim=3, bin_width=20.0, min_var_frac=0.01, em_tol=1.0E-8, # For compute efficiency, train on equal-length segments of trials seqs_train_cut = gpfa_util.cut_trials(seqs_train) if len(seqs_train_cut) == 0: - warnings.warn('No segments extracted for training. Defaulting to ' - 'segLength=Inf.') + warnings.warn( + "No segments extracted for training. Defaulting to " "segLength=Inf." + ) seqs_train_cut = gpfa_util.cut_trials(seqs_train, seg_length=np.inf) # ================================== # Initialize state model parameters # ================================== params_init = dict() - params_init['covType'] = 'rbf' + params_init["covType"] = "rbf" # GP timescale # Assume binWidth is the time step size. - params_init['gamma'] = (bin_width / tau_init) ** 2 * np.ones(x_dim) + params_init["gamma"] = (bin_width / tau_init) ** 2 * np.ones(x_dim) # GP noise variance - params_init['eps'] = eps_init * np.ones(x_dim) + params_init["eps"] = eps_init * np.ones(x_dim) # ======================================== # Initialize observation model parameters # ======================================== - print('Initializing parameters using factor analysis...') - - y_all = np.hstack(seqs_train_cut['y']) - fa = FactorAnalysis(n_components=x_dim, copy=True, - noise_variance_init=np.diag(np.cov(y_all, bias=True))) + print("Initializing parameters using factor analysis...") + + y_all = np.hstack(seqs_train_cut["y"]) + fa = FactorAnalysis( + n_components=x_dim, + copy=True, + noise_variance_init=np.diag(np.cov(y_all, bias=True)), + ) fa.fit(y_all.T) - params_init['d'] = y_all.mean(axis=1) - params_init['C'] = fa.components_.T - params_init['R'] = np.diag(fa.noise_variance_) + params_init["d"] = y_all.mean(axis=1) + params_init["C"] = fa.components_.T + params_init["R"] = np.diag(fa.noise_variance_) # Define parameter constraints - params_init['notes'] = { - 'learnKernelParams': True, - 'learnGPNoise': False, - 'RforceDiagonal': True, + params_init["notes"] = { + "learnKernelParams": True, + "learnGPNoise": False, + "RforceDiagonal": True, } # ===================== # Fit model parameters # ===================== - print('\nFitting GPFA model...') + print("\nFitting GPFA model...") params_est, seqs_train_cut, ll_cut, iter_time = em( - params_init, seqs_train_cut, min_var_frac=min_var_frac, - max_iters=em_max_iters, tol=em_tol, freq_ll=freq_ll, verbose=verbose) + params_init, + seqs_train_cut, + min_var_frac=min_var_frac, + max_iters=em_max_iters, + tol=em_tol, + freq_ll=freq_ll, + verbose=verbose, + ) - fit_info = {'iteration_time': iter_time, 'log_likelihoods': ll_cut} + fit_info = {"iteration_time": iter_time, "log_likelihoods": ll_cut} return params_est, fit_info -def em(params_init, seqs_train, max_iters=500, tol=1.0E-8, min_var_frac=0.01, 
- freq_ll=5, verbose=False): +def em( + params_init, + seqs_train, + max_iters=500, + tol=1.0e-8, + min_var_frac=0.01, + freq_ll=5, + verbose=False, +): """ Fits GPFA model parameters using expectation-maximization (EM) algorithm. @@ -217,17 +243,16 @@ def em(params_init, seqs_train, max_iters=500, tol=1.0E-8, min_var_frac=0.01, lisf of computation times (in seconds) for each EM iteration """ params = params_init - t = seqs_train['T'] - y_dim, x_dim = params['C'].shape + t = seqs_train["T"] + y_dim, x_dim = params["C"].shape lls = [] ll_old = ll_base = ll = 0.0 iter_time = [] - var_floor = min_var_frac * np.diag(np.cov(np.hstack(seqs_train['y']))) + var_floor = min_var_frac * np.diag(np.cov(np.hstack(seqs_train["y"]))) seqs_latent = None # Loop once for each iteration of EM algorithm - for iter_id in trange(1, max_iters + 1, desc='EM iteration', - disable=not verbose): + for iter_id in trange(1, max_iters + 1, desc="EM iteration", disable=not verbose): if verbose: print() tic = time.time() @@ -236,58 +261,61 @@ def em(params_init, seqs_train, max_iters=500, tol=1.0E-8, min_var_frac=0.01, # ==== E STEP ===== if not np.isnan(ll): ll_old = ll - seqs_latent, ll = exact_inference_with_ll(seqs_train, params, - get_ll=get_ll) + seqs_latent, ll = exact_inference_with_ll(seqs_train, params, get_ll=get_ll) lls.append(ll) # ==== M STEP ==== sum_p_auto = np.zeros((x_dim, x_dim)) for seq_latent in seqs_latent: - sum_p_auto += seq_latent['Vsm'].sum(axis=2) \ - + seq_latent['latent_variable'].dot( - seq_latent['latent_variable'].T) - y = np.hstack(seqs_train['y']) - latent_variable = np.hstack(seqs_latent['latent_variable']) + sum_p_auto += seq_latent["Vsm"].sum(axis=2) + seq_latent[ + "latent_variable" + ].dot(seq_latent["latent_variable"].T) + y = np.hstack(seqs_train["y"]) + latent_variable = np.hstack(seqs_latent["latent_variable"]) sum_yxtrans = y.dot(latent_variable.T) sum_xall = latent_variable.sum(axis=1)[:, np.newaxis] sum_yall = y.sum(axis=1)[:, np.newaxis] # term is (xDim+1) x (xDim+1) - term = np.vstack([np.hstack([sum_p_auto, sum_xall]), - np.hstack([sum_xall.T, t.sum().reshape((1, 1))])]) + term = np.vstack( + [ + np.hstack([sum_p_auto, sum_xall]), + np.hstack([sum_xall.T, t.sum().reshape((1, 1))]), + ] + ) # yDim x (xDim+1) cd = gpfa_util.rdiv(np.hstack([sum_yxtrans, sum_yall]), term) - params['C'] = cd[:, :x_dim] - params['d'] = cd[:, -1] + params["C"] = cd[:, :x_dim] + params["d"] = cd[:, -1] # yCent must be based on the new d # yCent = bsxfun(@minus, [seq.y], currentParams.d); # R = (yCent * yCent' - (yCent * [seq.latent_variable]') * \ # currentParams.C') / sum(T); - c = params['C'] - d = params['d'][:, np.newaxis] - if params['notes']['RforceDiagonal']: + c = params["C"] + d = params["d"][:, np.newaxis] + if params["notes"]["RforceDiagonal"]: sum_yytrans = (y * y).sum(axis=1)[:, np.newaxis] yd = sum_yall * d term = ((sum_yxtrans - d.dot(sum_xall.T)) * c).sum(axis=1) term = term[:, np.newaxis] - r = d ** 2 + (sum_yytrans - 2 * yd - term) / t.sum() + r = d**2 + (sum_yytrans - 2 * yd - term) / t.sum() # Set minimum private variance r = np.maximum(var_floor, r) - params['R'] = np.diag(r[:, 0]) + params["R"] = np.diag(r[:, 0]) else: sum_yytrans = y.dot(y.T) yd = sum_yall.dot(d.T) term = (sum_yxtrans - d.dot(sum_xall.T)).dot(c.T) r = d.dot(d.T) + (sum_yytrans - yd - yd.T - term) / t.sum() - params['R'] = (r + r.T) / 2 # ensure symmetry + params["R"] = (r + r.T) / 2 # ensure symmetry - if params['notes']['learnKernelParams']: + if params["notes"]["learnKernelParams"]: res = 
learn_gp_params(seqs_latent, params, verbose=verbose) - params['gamma'] = res['gamma'] + params["gamma"] = res["gamma"] t_end = time.time() - tic iter_time.append(t_end) @@ -296,18 +324,21 @@ def em(params_init, seqs_train, max_iters=500, tol=1.0E-8, min_var_frac=0.01, if iter_id <= 2: ll_base = ll elif verbose and ll < ll_old: - print('\nError: Data likelihood has decreased ', - 'from {0} to {1}'.format(ll_old, ll)) + print( + "\nError: Data likelihood has decreased ", + "from {0} to {1}".format(ll_old, ll), + ) elif (ll - ll_base) < (1 + tol) * (ll_old - ll_base): break if len(lls) < max_iters: - print('Fitting has converged after {0} EM iterations.)'.format( - len(lls))) + print("Fitting has converged after {0} EM iterations.)".format(len(lls))) - if np.any(np.diag(params['R']) == var_floor): - warnings.warn('Private variance floor used for one or more observed ' - 'dimensions in GPFA.') + if np.any(np.diag(params["R"]) == var_floor): + warnings.warn( + "Private variance floor used for one or more observed " + "dimensions in GPFA." + ) return params, seqs_latent, lls, iter_time @@ -356,31 +387,30 @@ def exact_inference_with_ll(seqs, params, get_ll=True): ll : float data log likelihood, np.nan is returned when `get_ll` is set False """ - y_dim, x_dim = params['C'].shape + y_dim, x_dim = params["C"].shape # copy the contents of the input data structure to output structure dtype_out = [(x, seqs[x].dtype) for x in seqs.dtype.names] - dtype_out.extend([('latent_variable', object), ('Vsm', object), - ('VsmGP', object)]) + dtype_out.extend([("latent_variable", object), ("Vsm", object), ("VsmGP", object)]) seqs_latent = np.empty(len(seqs), dtype=dtype_out) for dtype_name in seqs.dtype.names: seqs_latent[dtype_name] = seqs[dtype_name] # Precomputations - if params['notes']['RforceDiagonal']: - rinv = np.diag(1.0 / np.diag(params['R'])) - logdet_r = (np.log(np.diag(params['R']))).sum() + if params["notes"]["RforceDiagonal"]: + rinv = np.diag(1.0 / np.diag(params["R"])) + logdet_r = (np.log(np.diag(params["R"]))).sum() else: - rinv = linalg.inv(params['R']) + rinv = linalg.inv(params["R"]) rinv = (rinv + rinv.T) / 2 # ensure symmetry - logdet_r = gpfa_util.logdet(params['R']) + logdet_r = gpfa_util.logdet(params["R"]) - c_rinv = params['C'].T.dot(rinv) - c_rinv_c = c_rinv.dot(params['C']) + c_rinv = params["C"].T.dot(rinv) + c_rinv_c = c_rinv.dot(params["C"]) - t_all = seqs_latent['T'] + t_all = seqs_latent["T"] t_uniq = np.unique(t_all) - ll = 0. + ll = 0.0 # Overview: # - Outer loop on each element of Tu. 
@@ -400,7 +430,7 @@ def exact_inference_with_ll(seqs, params, get_ll=True): vsm = np.full((x_dim, x_dim, t), np.nan) idx = np.arange(0, x_dim * t + 1, x_dim) for i in range(t): - vsm[:, :, i] = minv[idx[i]:idx[i + 1], idx[i]:idx[i + 1]] + vsm[:, :, i] = minv[idx[i] : idx[i + 1], idx[i] : idx[i + 1]] # T x T posterior covariance for each GP vsm_gp = np.full((t, t, x_dim), np.nan) @@ -410,9 +440,9 @@ def exact_inference_with_ll(seqs, params, get_ll=True): # Process all trials with length T n_list = np.where(t_all == t)[0] # dif is yDim x sum(T) - dif = np.hstack(seqs_latent[n_list]['y']) - params['d'][:, np.newaxis] + dif = np.hstack(seqs_latent[n_list]["y"]) - params["d"][:, np.newaxis] # term1Mat is (xDim*T) x length(nList) - term1_mat = c_rinv.dot(dif).reshape((x_dim * t, -1), order='F') + term1_mat = c_rinv.dot(dif).reshape((x_dim * t, -1), order="F") # Compute blkProd = CRinvC_big * invM efficiently # blkProd is block persymmetric, so just compute top half @@ -420,27 +450,35 @@ def exact_inference_with_ll(seqs, params, get_ll=True): blk_prod = np.zeros((x_dim * t_half, x_dim * t)) idx = range(0, x_dim * t_half + 1, x_dim) for i in range(t_half): - blk_prod[idx[i]:idx[i + 1], :] = c_rinv_c.dot( - minv[idx[i]:idx[i + 1], :]) - blk_prod = k_big[:x_dim * t_half, :].dot( - gpfa_util.fill_persymm(np.eye(x_dim * t_half, x_dim * t) - - blk_prod, x_dim, t)) + blk_prod[idx[i] : idx[i + 1], :] = c_rinv_c.dot( + minv[idx[i] : idx[i + 1], :] + ) + blk_prod = k_big[: x_dim * t_half, :].dot( + gpfa_util.fill_persymm( + np.eye(x_dim * t_half, x_dim * t) - blk_prod, x_dim, t + ) + ) # latent_variableMat is (xDim*T) x length(nList) - latent_variable_mat = gpfa_util.fill_persymm( - blk_prod, x_dim, t).dot(term1_mat) + latent_variable_mat = gpfa_util.fill_persymm(blk_prod, x_dim, t).dot(term1_mat) for i, n in enumerate(n_list): - seqs_latent[n]['latent_variable'] = \ - latent_variable_mat[:, i].reshape((x_dim, t), order='F') - seqs_latent[n]['Vsm'] = vsm - seqs_latent[n]['VsmGP'] = vsm_gp + seqs_latent[n]["latent_variable"] = latent_variable_mat[:, i].reshape( + (x_dim, t), order="F" + ) + seqs_latent[n]["Vsm"] = vsm + seqs_latent[n]["VsmGP"] = vsm_gp if get_ll: # Compute data likelihood - val = -t * logdet_r - logdet_k_big - logdet_m \ - - y_dim * t * np.log(2 * np.pi) - ll = ll + len(n_list) * val - (rinv.dot(dif) * dif).sum() \ + val = ( + -t * logdet_r - logdet_k_big - logdet_m - y_dim * t * np.log(2 * np.pi) + ) + ll = ( + ll + + len(n_list) * val + - (rinv.dot(dif) * dif).sum() + (term1_mat.T.dot(minv) * term1_mat.T).sum() + ) if get_ll: ll /= 2 @@ -475,11 +513,11 @@ def learn_gp_params(seqs_latent, params, verbose=False): If `params['notes']['learnGPNoise']` set to True. 
""" - if params['covType'] != 'rbf': + if params["covType"] != "rbf": raise ValueError("Only 'rbf' GP covariance type is supported.") - if params['notes']['learnGPNoise']: + if params["notes"]["learnGPNoise"]: raise ValueError("learnGPNoise is not supported.") - param_name = 'gamma' + param_name = "gamma" param_init = params[param_name] param_opt = {param_name: np.empty_like(param_init)} @@ -489,15 +527,19 @@ def learn_gp_params(seqs_latent, params, verbose=False): # Loop once for each state dimension (each GP) for i in range(x_dim): - const = {'eps': params['eps'][i]} + const = {"eps": params["eps"][i]} initp = np.log(param_init[i]) - res_opt = optimize.minimize(gpfa_util.grad_betgam, initp, - args=(precomp[i], const), - method='L-BFGS-B', jac=True) - param_opt['gamma'][i] = np.exp(res_opt.x.item()) + res_opt = optimize.minimize( + gpfa_util.grad_betgam, + initp, + args=(precomp[i], const), + method="L-BFGS-B", + jac=True, + ) + param_opt["gamma"][i] = np.exp(res_opt.x.item()) if verbose: - print('\n Converged p; xDim:{}, p:{}'.format(i, res_opt.x)) + print("\n Converged p; xDim:{}, p:{}".format(i, res_opt.x)) return param_opt @@ -554,12 +596,13 @@ def orthonormalize(params_est, seqs): Training data structure that contains the new field `latent_variable_orth`, the orthonormalized neural trajectories. """ - C = params_est['C'] - X = np.hstack(seqs['latent_variable']) + C = params_est["C"] + X = np.hstack(seqs["latent_variable"]) latent_variable_orth, Corth, _ = gpfa_util.orthonormalize(X, C) seqs = gpfa_util.segment_by_trial( - seqs, latent_variable_orth, 'latent_variable_orth') + seqs, latent_variable_orth, "latent_variable_orth" + ) - params_est['Corth'] = Corth + params_est["Corth"] = Corth return Corth, seqs diff --git a/elephant/gpfa/gpfa_util.py b/elephant/gpfa/gpfa_util.py index bc0e7ad8a..c5a809fbc 100644 --- a/elephant/gpfa/gpfa_util.py +++ b/elephant/gpfa/gpfa_util.py @@ -64,13 +64,12 @@ def get_seqs(data, bin_size, use_sqrt=True): binned = np.sqrt(binned_spiketrain.to_array()) else: binned = binned_spiketrain.to_array() - seqs.append( - (binned_spiketrain.n_bins, binned)) - seqs = np.array(seqs, dtype=[('T', int), ('y', 'O')]) + seqs.append((binned_spiketrain.n_bins, binned)) + seqs = np.array(seqs, dtype=[("T", int), ("y", "O")]) # Remove trials that are shorter than one bin width if len(seqs) > 0: - trials_to_keep = seqs['T'] > 0 + trials_to_keep = seqs["T"] > 0 seqs = seqs[trials_to_keep] return seqs @@ -120,24 +119,28 @@ def cut_trials(seq_in, seg_length=20): seqOut = seq_in return seqOut - dtype_seqOut = [('segId', int), ('T', int), - ('y', object)] + dtype_seqOut = [("segId", int), ("T", int), ("y", object)] seqOut_buff = [] for n, seqIn_n in enumerate(seq_in): - T = seqIn_n['T'] + T = seqIn_n["T"] # Skip trials that are shorter than segLength if T < seg_length: warnings.warn( - 'trial corresponding to index {} shorter than one segLength...' - 'skipping'.format(n)) + "trial corresponding to index {} shorter than one segLength..." 
+ "skipping".format(n) + ) continue numSeg = int(np.ceil(float(T) / seg_length)) # Randomize the sizes of overlaps if numSeg == 1: - cumOL = np.array([0, ]) + cumOL = np.array( + [ + 0, + ] + ) else: totalOL = (seg_length * numSeg) - T probs = np.ones(numSeg - 1, float) / (numSeg - 1) @@ -145,11 +148,11 @@ def cut_trials(seq_in, seg_length=20): cumOL = np.hstack([0, np.cumsum(randOL)]) seg = np.empty(numSeg, dtype_seqOut) - seg['T'] = seg_length + seg["T"] = seg_length for s, seg_s in enumerate(seg): tStart = seg_length * s - cumOL[s] - seg_s['y'] = seqIn_n['y'][:, tStart:tStart + seg_length] + seg_s["y"] = seqIn_n["y"][:, tStart : tStart + seg_length] seqOut_buff.append(seg) @@ -211,21 +214,22 @@ def make_k_big(params, n_timesteps): If `params['covType'] != 'rbf'`. """ - if params['covType'] != 'rbf': + if params["covType"] != "rbf": raise ValueError("Only 'rbf' GP covariance type is supported.") - xDim = params['C'].shape[1] + xDim = params["C"].shape[1] K_big = np.zeros((xDim * n_timesteps, xDim * n_timesteps)) K_big_inv = np.zeros((xDim * n_timesteps, xDim * n_timesteps)) - Tdif = np.tile(np.arange(0, n_timesteps), (n_timesteps, 1)).T \ - - np.tile(np.arange(0, n_timesteps), (n_timesteps, 1)) + Tdif = np.tile(np.arange(0, n_timesteps), (n_timesteps, 1)).T - np.tile( + np.arange(0, n_timesteps), (n_timesteps, 1) + ) logdet_K_big = 0 for i in range(xDim): - K = (1 - params['eps'][i]) * np.exp(-params['gamma'][i] / 2 * - Tdif ** 2) \ - + params['eps'][i] * np.eye(n_timesteps) + K = (1 - params["eps"][i]) * np.exp(-params["gamma"][i] / 2 * Tdif**2) + params[ + "eps" + ][i] * np.eye(n_timesteps) K_big[i::xDim, i::xDim] = K # the original MATLAB program uses here a special algorithm, provided # in C and MEX, for inversion of Toeplitz matrix: @@ -300,25 +304,25 @@ def inv_persymm(M, blk_size): def fill_persymm(p_in, blk_size, n_blocks, blk_size_vert=None): """ - Fills in the bottom half of a block persymmetric matrix, given the - top half. - - Parameters - ---------- - p_in : (xDim*Thalf, xDim*T) np.ndarray - Top half of block persymmetric matrix, where Thalf = ceil(T/2) - blk_size : int - Edge length of one block - n_blocks : int - Number of blocks making up a row of Pin - blk_size_vert : int, optional - Vertical block edge length if blocks are not square. - `blk_size` is assumed to be the horizontal block edge length. - - Returns - ------- - Pout : (xDim*T, xDim*T) np.ndarray - Full block persymmetric matrix + Fills in the bottom half of a block persymmetric matrix, given the + top half. + + Parameters + ---------- + p_in : (xDim*Thalf, xDim*T) np.ndarray + Top half of block persymmetric matrix, where Thalf = ceil(T/2) + blk_size : int + Edge length of one block + n_blocks : int + Number of blocks making up a row of Pin + blk_size_vert : int, optional + Vertical block edge length if blocks are not square. + `blk_size` is assumed to be the horizontal block edge length. 
+ + Returns + ------- + Pout : (xDim*T, xDim*T) np.ndarray + Full block persymmetric matrix """ if blk_size_vert is None: blk_size_vert = blk_size @@ -329,14 +333,16 @@ def fill_persymm(p_in, blk_size, n_blocks, blk_size_vert=None): THalf = int(np.ceil(n_blocks / 2.0)) Pout = np.empty((blk_size_vert * n_blocks, blk_size * n_blocks)) - Pout[:blk_size_vert * THalf, :] = p_in + Pout[: blk_size_vert * THalf, :] = p_in for i in range(Thalf): for j in range(n_blocks): - Pout[Nv - (i + 1) * blk_size_vert:Nv - i * blk_size_vert, - Nh - (j + 1) * blk_size:Nh - j * blk_size] \ - = p_in[i * blk_size_vert:(i + 1) * - blk_size_vert, - j * blk_size:(j + 1) * blk_size] + Pout[ + Nv - (i + 1) * blk_size_vert : Nv - i * blk_size_vert, + Nh - (j + 1) * blk_size : Nh - j * blk_size, + ] = p_in[ + i * blk_size_vert : (i + 1) * blk_size_vert, + j * blk_size : (j + 1) * blk_size, + ] return Pout @@ -375,35 +381,42 @@ def make_precomp(seqs, xDim): Finally, see the notes in the GPFA README. """ - Tall = seqs['T'] + Tall = seqs["T"] Tmax = (Tall).max() - Tdif = np.tile(np.arange(0, Tmax), (Tmax, 1)).T \ - - np.tile(np.arange(0, Tmax), (Tmax, 1)) + Tdif = np.tile(np.arange(0, Tmax), (Tmax, 1)).T - np.tile( + np.arange(0, Tmax), (Tmax, 1) + ) # assign some helpful precomp items # this is computationally cheap, so we keep a few loops in MATLAB # for ease of readability. - precomp = np.empty(xDim, dtype=[( - 'absDif', object), ('difSq', object), ('Tall', object), - ('Tu', object)]) + precomp = np.empty( + xDim, + dtype=[("absDif", object), ("difSq", object), ("Tall", object), ("Tu", object)], + ) for i in range(xDim): - precomp[i]['absDif'] = np.abs(Tdif) - precomp[i]['difSq'] = Tdif ** 2 - precomp[i]['Tall'] = Tall + precomp[i]["absDif"] = np.abs(Tdif) + precomp[i]["difSq"] = Tdif**2 + precomp[i]["Tall"] = Tall # find unique numbers of trial lengths trial_lengths_num_unique = np.unique(Tall) # Loop once for each state dimension (each GP) for i in range(xDim): - precomp_Tu = np.empty(len(trial_lengths_num_unique), dtype=[( - 'nList', object), ('T', int), ('numTrials', int), - ('PautoSUM', object)]) + precomp_Tu = np.empty( + len(trial_lengths_num_unique), + dtype=[ + ("nList", object), + ("T", int), + ("numTrials", int), + ("PautoSUM", object), + ], + ) for j, trial_len_num in enumerate(trial_lengths_num_unique): - precomp_Tu[j]['nList'] = np.where(Tall == trial_len_num)[0] - precomp_Tu[j]['T'] = trial_len_num - precomp_Tu[j]['numTrials'] = len(precomp_Tu[j]['nList']) - precomp_Tu[j]['PautoSUM'] = np.zeros((trial_len_num, - trial_len_num)) - precomp[i]['Tu'] = precomp_Tu + precomp_Tu[j]["nList"] = np.where(Tall == trial_len_num)[0] + precomp_Tu[j]["T"] = trial_len_num + precomp_Tu[j]["numTrials"] = len(precomp_Tu[j]["nList"]) + precomp_Tu[j]["PautoSUM"] = np.zeros((trial_len_num, trial_len_num)) + precomp[i]["Tu"] = precomp_Tu # at this point the basic precomp is built. The previous steps # should be computationally cheap. 
We now try to embed the @@ -419,10 +432,10 @@ def make_precomp(seqs, xDim): # Loop once for each trial length (each of Tu) for j in range(len(trial_lengths_num_unique)): # Loop once for each trial (each of nList) - for n in precomp[i]['Tu'][j]['nList']: - precomp[i]['Tu'][j]['PautoSUM'] += seqs[n]['VsmGP'][:, :, i] \ - + np.outer(seqs[n]['latent_variable'][i, :], - seqs[n]['latent_variable'][i, :]) + for n in precomp[i]["Tu"][j]["nList"]: + precomp[i]["Tu"][j]["PautoSUM"] += seqs[n]["VsmGP"][:, :, i] + np.outer( + seqs[n]["latent_variable"][i, :], seqs[n]["latent_variable"][i, :] + ) return precomp @@ -448,18 +461,18 @@ def grad_betgam(p, pre_comp, const): df : float gradient at p """ - Tall = pre_comp['Tall'] + Tall = pre_comp["Tall"] Tmax = Tall.max() # temp is Tmax x Tmax - temp = (1 - const['eps']) * np.exp(-np.exp(p) / 2 * pre_comp['difSq']) - Kmax = temp + const['eps'] * np.eye(Tmax) - dKdgamma_max = -0.5 * temp * pre_comp['difSq'] + temp = (1 - const["eps"]) * np.exp(-np.exp(p) / 2 * pre_comp["difSq"]) + Kmax = temp + const["eps"] * np.eye(Tmax) + dKdgamma_max = -0.5 * temp * pre_comp["difSq"] dEdgamma = 0 f = 0 - for j in range(len(pre_comp['Tu'])): - T = pre_comp['Tu'][j]['T'] + for j in range(len(pre_comp["Tu"])): + T = pre_comp["Tu"][j]["T"] Thalf = int(np.ceil(T / 2.0)) Kinv = np.linalg.inv(Kmax[:T, :T]) @@ -471,20 +484,22 @@ def grad_betgam(p, pre_comp, const): dg_KinvM = np.diag(KinvM) tr_KinvM = 2 * dg_KinvM.sum() - np.fmod(T, 2) * dg_KinvM[-1] - mkr = int(np.ceil(0.5 * T ** 2)) - numTrials = pre_comp['Tu'][j]['numTrials'] - PautoSUM = pre_comp['Tu'][j]['PautoSUM'] - - pauto_kinv_dot = PautoSUM.ravel('F')[:mkr].dot( - KinvMKinv.ravel('F')[:mkr]) - pauto_kinv_dot_rest = PautoSUM.ravel('F')[-1:mkr - 1:- 1].dot( - KinvMKinv.ravel('F')[:(T ** 2 - mkr)]) - dEdgamma = dEdgamma - 0.5 * numTrials * tr_KinvM \ - + 0.5 * pauto_kinv_dot \ + mkr = int(np.ceil(0.5 * T**2)) + numTrials = pre_comp["Tu"][j]["numTrials"] + PautoSUM = pre_comp["Tu"][j]["PautoSUM"] + + pauto_kinv_dot = PautoSUM.ravel("F")[:mkr].dot(KinvMKinv.ravel("F")[:mkr]) + pauto_kinv_dot_rest = PautoSUM.ravel("F")[-1 : mkr - 1 : -1].dot( + KinvMKinv.ravel("F")[: (T**2 - mkr)] + ) + dEdgamma = ( + dEdgamma + - 0.5 * numTrials * tr_KinvM + + 0.5 * pauto_kinv_dot + 0.5 * pauto_kinv_dot_rest + ) - f = f - 0.5 * numTrials * logdet_K \ - - 0.5 * (PautoSUM * Kinv).sum() + f = f - 0.5 * numTrials * logdet_K - 0.5 * (PautoSUM * Kinv).sum() f = -f # exp(p) is needed because we're computing gradients with @@ -556,8 +571,8 @@ def segment_by_trial(seqs, x, fn): If `seqs['T']) != x.shape[1]`. 
""" - if np.sum(seqs['T']) != x.shape[1]: - raise ValueError('size of X incorrect.') + if np.sum(seqs["T"]) != x.shape[1]: + raise ValueError("size of X incorrect.") dtype_new = [(i, seqs[i].dtype) for i in seqs.dtype.names] dtype_new.append((fn, object)) @@ -566,8 +581,8 @@ def segment_by_trial(seqs, x, fn): seqs_new[dtype_name] = seqs[dtype_name] ctr = 0 - for n, T in enumerate(seqs['T']): - seqs_new[n][fn] = x[:, ctr:ctr + T] + for n, T in enumerate(seqs["T"]): + seqs_new[n][fn] = x[:, ctr : ctr + T] ctr += T return seqs_new diff --git a/elephant/kernels.py b/elephant/kernels.py index 8c305a58b..c6e538eca 100644 --- a/elephant/kernels.py +++ b/elephant/kernels.py @@ -82,8 +82,13 @@ import scipy.stats __all__ = [ - 'RectangularKernel', 'TriangularKernel', 'EpanechnikovLikeKernel', - 'GaussianKernel', 'LaplacianKernel', 'ExponentialKernel', 'AlphaKernel' + "RectangularKernel", + "TriangularKernel", + "EpanechnikovLikeKernel", + "GaussianKernel", + "LaplacianKernel", + "ExponentialKernel", + "AlphaKernel", ] @@ -162,7 +167,8 @@ def __init__(self, sigma, invert=False): def __repr__(self): return "{cls}(sigma={sigma}, invert={invert})".format( - cls=self.__class__.__name__, sigma=self.sigma, invert=self.invert) + cls=self.__class__.__name__, sigma=self.sigma, invert=self.invert + ) def __call__(self, times): """ @@ -207,7 +213,8 @@ def _evaluate(self, times): """ raise NotImplementedError( "The Kernel class should not be used directly, " - "instead the subclasses for the single kernels.") + "instead the subclasses for the single kernels." + ) def boundary_enclosing_area_fraction(self, fraction): """ @@ -240,7 +247,8 @@ def boundary_enclosing_area_fraction(self, fraction): """ raise NotImplementedError( "The Kernel class should not be used directly, " - "instead the subclasses for the single kernels.") + "instead the subclasses for the single kernels." + ) def _check_fraction(self, fraction): """ @@ -265,24 +273,28 @@ def _check_fraction(self, fraction): raise TypeError("`fraction` must be float or integer") if isinstance(self, (TriangularKernel, RectangularKernel)): valid = 0 <= fraction <= 1 - bracket = ']' + bracket = "]" else: valid = 0 <= fraction < 1 - bracket = ')' + bracket = ")" if not valid: - raise ValueError("`fraction` must be in the interval " - "[0, 1{}".format(bracket)) + raise ValueError( + "`fraction` must be in the interval " "[0, 1{}".format(bracket) + ) def _check_time_input(self, t): if not isinstance(t, pq.Quantity): - raise TypeError("The argument 't' of the kernel callable must be " - "of type Quantity") + raise TypeError( + "The argument 't' of the kernel callable must be " "of type Quantity" + ) if t.dimensionality.simplified != self.sigma.dimensionality.simplified: - raise TypeError("The dimensionality of sigma and the input array " - "to the callable kernel object must be the same. " - "Otherwise a normalization to 1 of the kernel " - "cannot be performed.") + raise TypeError( + "The dimensionality of sigma and the input array " + "to the callable kernel object must be the same. " + "Otherwise a normalization to 1 of the kernel " + "cannot be performed." + ) def cdf(self, time): r""" @@ -369,10 +381,11 @@ def median_index(self, times): return 0 is_sorted = (np.diff(times.magnitude) >= 0).all() if not is_sorted: - raise ValueError("The input time array must be sorted (in " - "ascending order).") + raise ValueError( + "The input time array must be sorted (in " "ascending order)." 
+ ) cdf_mean = 0.5 * (self.cdf(times[0]) + self.cdf(times[-1])) - if cdf_mean == 0.: + if cdf_mean == 0.0: # any index of the kernel non-support is valid; choose median return len(times) // 2 icdf = self.icdf(fraction=cdf_mean) @@ -525,16 +538,14 @@ def min_cutoff(self): def _evaluate(self, times): tau = math.sqrt(6) * self.sigma.rescale(times.units).magnitude - kernel = scipy.stats.triang.pdf(times.magnitude, c=0.5, loc=-tau, - scale=2 * tau) + kernel = scipy.stats.triang.pdf(times.magnitude, c=0.5, loc=-tau, scale=2 * tau) kernel = pq.Quantity(kernel, units=1 / times.units) return kernel def cdf(self, time): self._check_time_input(time) tau = math.sqrt(6) * self.sigma.rescale(time.units).magnitude - cdf = scipy.stats.triang.cdf(time.magnitude, c=0.5, loc=-tau, - scale=2 * tau) + cdf = scipy.stats.triang.cdf(time.magnitude, c=0.5, loc=-tau, scale=2 * tau) return cdf def icdf(self, fraction): @@ -598,7 +609,7 @@ def min_cutoff(self): def _evaluate(self, times): tau = math.sqrt(5) * self.sigma.rescale(times.units).magnitude t_div_tau = np.clip(times.magnitude / tau, a_min=-1, a_max=1) - kernel = 3. / (4. * tau) * np.maximum(0., 1 - t_div_tau ** 2) + kernel = 3.0 / (4.0 * tau) * np.maximum(0.0, 1 - t_div_tau**2) kernel = pq.Quantity(kernel, units=1 / times.units) return kernel @@ -606,13 +617,13 @@ def cdf(self, time): self._check_time_input(time) tau = math.sqrt(5) * self.sigma.rescale(time.units).magnitude t_div_tau = np.clip(time.magnitude / tau, a_min=-1, a_max=1) - cdf = 3. / 4 * (t_div_tau - t_div_tau ** 3 / 3.) + 0.5 + cdf = 3.0 / 4 * (t_div_tau - t_div_tau**3 / 3.0) + 0.5 return cdf def icdf(self, fraction): self._check_fraction(fraction) # CDF(t) = -1/4 t^3 + 3/4 t + 1/2 - coefs = [-1. / 4, 0, 3. / 4, 0.5 - fraction] + coefs = [-1.0 / 4, 0, 3.0 / 4, 0.5 - fraction] roots = np.roots(coefs) icdf = next(root for root in roots if -1 <= root <= 1) tau = math.sqrt(5) * self.sigma @@ -667,16 +678,18 @@ def boundary_enclosing_area_fraction(self, fraction): self._check_fraction(fraction) # Python's complex-operator cannot handle quantities, hence the # following construction on quantities is necessary: - Delta_0 = complex(1.0 / (5.0 * self.sigma.magnitude ** 2), 0) / \ - self.sigma.units ** 2 - Delta_1 = complex(2.0 * np.sqrt(5.0) * fraction / - (25.0 * self.sigma.magnitude ** 3), 0) / \ - self.sigma.units ** 3 - C = ((Delta_1 + (Delta_1 ** 2.0 - 4.0 * Delta_0 ** 3.0) ** ( - 1.0 / 2.0)) / - 2.0) ** (1.0 / 3.0) + Delta_0 = ( + complex(1.0 / (5.0 * self.sigma.magnitude**2), 0) / self.sigma.units**2 + ) + Delta_1 = ( + complex(2.0 * np.sqrt(5.0) * fraction / (25.0 * self.sigma.magnitude**3), 0) + / self.sigma.units**3 + ) + C = ((Delta_1 + (Delta_1**2.0 - 4.0 * Delta_0**3.0) ** (1.0 / 2.0)) / 2.0) ** ( + 1.0 / 3.0 + ) u_3 = complex(-1.0 / 2.0, -np.sqrt(3.0) / 2.0) - b = -5.0 * self.sigma ** 2 * (u_3 * C + Delta_0 / (u_3 * C)) + b = -5.0 * self.sigma**2 * (u_3 * C + Delta_0 / (u_3 * C)) return b.real @@ -732,8 +745,7 @@ def cdf(self, time): def icdf(self, fraction): self._check_fraction(fraction) - icdf = scipy.stats.norm.ppf(fraction, loc=0, - scale=self.sigma.magnitude) + icdf = scipy.stats.norm.ppf(fraction, loc=0, scale=self.sigma.magnitude) return icdf * self.sigma.units def boundary_enclosing_area_fraction(self, fraction): @@ -861,7 +873,7 @@ def cdf(self, time): time = np.minimum(time, 0) return np.exp(time / tau) time = np.maximum(time, 0) - return 1. 
- np.exp(-time / tau) + return 1.0 - np.exp(-time / tau) def icdf(self, fraction): self._check_fraction(fraction) @@ -921,7 +933,7 @@ def _evaluate(self, times): times = times.magnitude if self.invert: times = -times - kernel = (times >= 0) * 1 / tau ** 2 * times * np.exp(-times / tau) + kernel = (times >= 0) * 1 / tau**2 * times * np.exp(-times / tau) kernel = pq.Quantity(kernel, units=1 / t_units) return kernel diff --git a/elephant/neo_tools.py b/elephant/neo_tools.py index 53ab745e6..4ec13cb5f 100644 --- a/elephant/neo_tools.py +++ b/elephant/neo_tools.py @@ -26,12 +26,13 @@ "extract_neo_attributes", "get_all_spiketrains", "get_all_events", - "get_all_epochs" + "get_all_epochs", ] -def extract_neo_attributes(neo_object, parents=True, child_first=True, - skip_array=False, skip_none=False): +def extract_neo_attributes( + neo_object, parents=True, child_first=True, skip_array=False, skip_none=False +): """ Given a Neo object, return a dictionary of attributes and annotations. @@ -67,18 +68,16 @@ def extract_neo_attributes(neo_object, parents=True, child_first=True, if not skip_array and hasattr(neo_object, "array_annotations"): # Exclude labels and durations, and any other fields that should not # be a part of array_annotation. - required_keys = set(neo_object.array_annotations).difference( - dir(neo_object)) + required_keys = set(neo_object.array_annotations).difference(dir(neo_object)) for a in required_keys: if "array_annotations" not in attrs: attrs["array_annotations"] = {} - attrs["array_annotations"][a] = \ - neo_object.array_annotations[a].copy() + attrs["array_annotations"][a] = neo_object.array_annotations[a].copy() for attr in neo_object._necessary_attrs + neo_object._recommended_attrs: if skip_array and len(attr) >= 3 and attr[2]: continue attr = attr[0] - if attr == getattr(neo_object, '_quantity_attr', None): + if attr == getattr(neo_object, "_quantity_attr", None): continue attrs[attr] = getattr(neo_object, attr, None) @@ -90,13 +89,16 @@ def extract_neo_attributes(neo_object, parents=True, child_first=True, if not parents: return attrs - for parent in getattr(neo_object, 'parents', []): + for parent in getattr(neo_object, "parents", []): if parent is None: continue - newattr = extract_neo_attributes(parent, parents=True, - child_first=child_first, - skip_array=skip_array, - skip_none=skip_none) + newattr = extract_neo_attributes( + parent, + parents=True, + child_first=child_first, + skip_array=skip_array, + skip_none=skip_none, + ) if child_first: newattr.update(attrs) attrs = newattr @@ -107,8 +109,10 @@ def extract_neo_attributes(neo_object, parents=True, child_first=True, def extract_neo_attrs(*args, **kwargs): - warnings.warn("'extract_neo_attrs' function is deprecated; " - "use 'extract_neo_attributes'", DeprecationWarning) + warnings.warn( + "'extract_neo_attrs' function is deprecated; " "use 'extract_neo_attributes'", + DeprecationWarning, + ) return extract_neo_attributes(*args, **kwargs) @@ -143,19 +147,18 @@ def _get_all_objs(container, class_name): """ if container.__class__.__name__ == class_name: return [container] - classholder = class_name.lower() + 's' + classholder = class_name.lower() + "s" if hasattr(container, classholder): vals = getattr(container, classholder) - elif hasattr(container, 'list_children_by_class'): + elif hasattr(container, "list_children_by_class"): vals = container.list_children_by_class(class_name) - elif hasattr(container, 'values') and not hasattr(container, 'ndim'): + elif hasattr(container, "values") and not hasattr(container, 
"ndim"): vals = container.values() - elif hasattr(container, '__iter__') and not hasattr(container, 'ndim'): + elif hasattr(container, "__iter__") and not hasattr(container, "ndim"): vals = container else: - raise ValueError('Cannot handle object of type %s' % type(container)) - res = list(chain.from_iterable(_get_all_objs(obj, class_name) - for obj in vals)) + raise ValueError("Cannot handle object of type %s" % type(container)) + res = list(chain.from_iterable(_get_all_objs(obj, class_name) for obj in vals)) return unique_objs(res) @@ -183,7 +186,7 @@ def get_all_spiketrains(container): in `container`. """ - return SpikeTrainList(_get_all_objs(container, 'SpikeTrain')) + return SpikeTrainList(_get_all_objs(container, "SpikeTrain")) def get_all_events(container): @@ -208,7 +211,7 @@ def get_all_events(container): A list of the unique `neo.Event` objects in `container`. """ - return _get_all_objs(container, 'Event') + return _get_all_objs(container, "Event") def get_all_epochs(container): @@ -233,4 +236,4 @@ def get_all_epochs(container): A list of the unique `neo.Epoch` objects in `container`. """ - return _get_all_objs(container, 'Epoch') + return _get_all_objs(container, "Epoch") diff --git a/elephant/parallel/__init__.py b/elephant/parallel/__init__.py index 370148c48..b9ec4397a 100644 --- a/elephant/parallel/__init__.py +++ b/elephant/parallel/__init__.py @@ -44,12 +44,9 @@ from .mpi import MPIPoolExecutor, MPICommExecutor except ImportError: # mpi4py is missing - warnings.warn("mpi4py package is missing. Please run 'pip install mpi4py' " - "in a terminal to activate MPI features.") - -__all__ = [ - "ProcessPoolExecutor", - "SingleProcess", - "MPIPoolExecutor", - "MPICommExecutor" -] + warnings.warn( + "mpi4py package is missing. Please run 'pip install mpi4py' " + "in a terminal to activate MPI features." + ) + +__all__ = ["ProcessPoolExecutor", "SingleProcess", "MPIPoolExecutor", "MPICommExecutor"] diff --git a/elephant/parallel/mpi.py b/elephant/parallel/mpi.py index c8ff646c0..3976f0cf7 100644 --- a/elephant/parallel/mpi.py +++ b/elephant/parallel/mpi.py @@ -60,6 +60,7 @@ class MPICommExecutor(MPIPoolExecutor): For more information of how to launch MPI processes in Python refer to https://mpi4py.readthedocs.io/en/stable/mpi4py.futures.html#command-line """ + def __init__(self, comm=None, root=0): super(MPICommExecutor, self).__init__(max_workers=None) if comm is None: diff --git a/elephant/parallel/parallel.py b/elephant/parallel/parallel.py index 84acb91c3..a91e1f960 100644 --- a/elephant/parallel/parallel.py +++ b/elephant/parallel/parallel.py @@ -8,8 +8,9 @@ class SingleProcess(object): """ def __repr__(self): - return "{name}({extra})".format(name=self.__class__.__name__, - extra=self._extra_repr()) + return "{name}({extra})".format( + name=self.__class__.__name__, extra=self._extra_repr() + ) def _extra_repr(self): return "" @@ -64,6 +65,7 @@ class ProcessPoolExecutor(SingleProcess): worker processes will be created as the machine has processors. 
Default: None """ + def __init__(self, max_workers=None): self.max_workers = max_workers diff --git a/elephant/phase_analysis.py b/elephant/phase_analysis.py index e74e9dee4..3dfac9a0a 100644 --- a/elephant/phase_analysis.py +++ b/elephant/phase_analysis.py @@ -34,7 +34,7 @@ "phase_locking_value", "mean_phase_vector", "phase_difference", - "weighted_phase_lag_index" + "weighted_phase_lag_index", ] @@ -127,11 +127,11 @@ def spike_triggered_phase(hilbert_transform, spiketrains, interpolate): num_spiketrains = len(spiketrains) num_phase = len(hilbert_transform) - if num_spiketrains != 1 and num_phase != 1 and \ - num_spiketrains != num_phase: + if num_spiketrains != 1 and num_phase != 1 and num_spiketrains != num_phase: raise ValueError( "Number of spike trains and number of phase signals" - "must match, or either of the two must be a single signal.") + "must match, or either of the two must be a single signal." + ) # For each trial, select the first input start = [elem.t_start for elem in hilbert_transform] @@ -153,17 +153,18 @@ def spike_triggered_phase(hilbert_transform, spiketrains, interpolate): # Take only spikes which lie directly within the signal segment - # ignore spikes sitting on the last sample - sttimeind = np.where(np.logical_and( - spiketrain >= start[phase_i], spiketrain < stop[phase_i]))[0] + sttimeind = np.where( + np.logical_and(spiketrain >= start[phase_i], spiketrain < stop[phase_i]) + )[0] # Extract times for speed reasons times = hilbert_transform[phase_i].times # Find index into signal for each spike ind_at_spike = ( - (spiketrain[sttimeind] - hilbert_transform[phase_i].t_start) / - hilbert_transform[phase_i].sampling_period). \ - simplified.magnitude.astype(int) + (spiketrain[sttimeind] - hilbert_transform[phase_i].t_start) + / hilbert_transform[phase_i].sampling_period + ).simplified.magnitude.astype(int) # Append new list to the results for this spiketrain result_phases.append([]) @@ -172,37 +173,37 @@ def spike_triggered_phase(hilbert_transform, spiketrains, interpolate): # Step through all spikes for spike_i, ind_at_spike_j in enumerate(ind_at_spike): - - if interpolate and ind_at_spike_j+1 < len(times): + if interpolate and ind_at_spike_j + 1 < len(times): # Get relative spike occurrence between the two closest signal # sample points # if z->0 spike is more to the left sample # if z->1 more to the right sample - z = (spiketrain[sttimeind[spike_i]] - times[ind_at_spike_j]) /\ - hilbert_transform[phase_i].sampling_period + z = ( + spiketrain[sttimeind[spike_i]] - times[ind_at_spike_j] + ) / hilbert_transform[phase_i].sampling_period # Save hilbert_transform (interpolate on circle) - p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j] - ).item() - p2 = np.angle(hilbert_transform[phase_i][ind_at_spike_j + 1] - ).item() - interpolation = (1 - z) * np.exp(complex(0, p1)) \ - + z * np.exp(complex(0, p2)) + p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j]).item() + p2 = np.angle(hilbert_transform[phase_i][ind_at_spike_j + 1]).item() + interpolation = (1 - z) * np.exp(complex(0, p1)) + z * np.exp( + complex(0, p2) + ) p12 = np.angle([interpolation]) result_phases[spiketrain_i].append(p12) # Save amplitude result_amps[spiketrain_i].append( - (1 - z) * np.abs( - hilbert_transform[phase_i][ind_at_spike_j]) + - z * np.abs(hilbert_transform[phase_i][ind_at_spike_j + 1])) + (1 - z) * np.abs(hilbert_transform[phase_i][ind_at_spike_j]) + + z * np.abs(hilbert_transform[phase_i][ind_at_spike_j + 1]) + ) else: p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j]) 
result_phases[spiketrain_i].append(p1) # Save amplitude result_amps[spiketrain_i].append( - np.abs(hilbert_transform[phase_i][ind_at_spike_j])) + np.abs(hilbert_transform[phase_i][ind_at_spike_j]) + ) # Save time result_times[spiketrain_i].append(spiketrain[sttimeind[spike_i]]) @@ -257,8 +258,9 @@ def phase_locking_value(phases_i, phases_j): """ if np.shape(phases_i) != np.shape(phases_j): - raise ValueError("trial number and trial length of signal x and y " - "must be equal") + raise ValueError( + "trial number and trial length of signal x and y " "must be equal" + ) # trial by trial and time-resolved # version 0.2: signal x and y have multiple trials @@ -333,8 +335,9 @@ def phase_difference(alpha, beta): return phase_diff -def weighted_phase_lag_index(signal_i, signal_j, sampling_frequency=None, - absolute_value=True): +def weighted_phase_lag_index( + signal_i, signal_j, sampling_frequency=None, absolute_value=True +): r""" Calculates the Weigthed Phase-Lag Index (WPLI) :cite:`phase-Vinck11_1548`. @@ -389,18 +392,19 @@ def weighted_phase_lag_index(signal_i, signal_j, sampling_frequency=None, spectra of a particular frequency of the signals i and j. """ - if isinstance(signal_i, neo.AnalogSignal) and \ - isinstance(signal_j, neo.AnalogSignal): # neo.AnalogSignal input - if signal_i.sampling_rate.rescale("Hz") != \ - signal_j.sampling_rate.rescale("Hz"): + if isinstance(signal_i, neo.AnalogSignal) and isinstance( + signal_j, neo.AnalogSignal + ): # neo.AnalogSignal input + if signal_i.sampling_rate.rescale("Hz") != signal_j.sampling_rate.rescale("Hz"): raise ValueError("sampling rate of signal i and j must be equal") sampling_frequency = signal_i.sampling_rate signal_i = signal_i.magnitude signal_j = signal_j.magnitude else: # np.array() or Quantity input if sampling_frequency is None: - raise ValueError("sampling frequency must be given for np.array or" - "Quantity input") + raise ValueError( + "sampling frequency must be given for np.array or" "Quantity input" + ) if np.shape(signal_i) != np.shape(signal_j): if len(signal_i) != len(signal_j): diff --git a/elephant/signal_processing.py b/elephant/signal_processing.py index 9a901b41a..9ffecae68 100644 --- a/elephant/signal_processing.py +++ b/elephant/signal_processing.py @@ -34,7 +34,7 @@ "wavelet_transform", "hilbert", "rauc", - "derivative" + "derivative", ] @@ -166,8 +166,10 @@ def zscore(signal, inplace=True): # Perform inplace operation only if array is of dtype float. # Otherwise, raise an error. if inplace and not np.issubdtype(sig.dtype, np.floating): - raise ValueError(f"Cannot perform inplace operation as the " - f"signal dtype is not float. Source: {sig.name}") + raise ValueError( + f"Cannot perform inplace operation as the " + f"signal dtype is not float. 
Source: {sig.name}" + ) sig_normalized = sig.magnitude.astype(mean.dtype, copy=not inplace) sig_normalized -= mean @@ -181,8 +183,9 @@ def zscore(signal, inplace=True): sig_dimless = sig else: # Create new object - sig_dimless = sig.duplicate_with_new_data(sig_normalized, - units=pq.dimensionless) + sig_dimless = sig.duplicate_with_new_data( + sig_normalized, units=pq.dimensionless + ) # todo use flag once is fixed # https://github.com/NeuralEnsemble/python-neo/issues/752 sig_dimless.array_annotate(**sig.array_annotations) @@ -195,8 +198,9 @@ def zscore(signal, inplace=True): return signal_ztransformed -def cross_correlation_function(signal, channel_pairs, hilbert_envelope=False, - n_lags=None, scaleopt='unbiased'): +def cross_correlation_function( + signal, channel_pairs, hilbert_envelope=False, n_lags=None, scaleopt="unbiased" +): r""" Computes an estimator of the cross-correlation function :cite:`signal-Stoica2005`. @@ -329,22 +333,27 @@ def cross_correlation_function(signal, channel_pairs, hilbert_envelope=False, # Check input if not isinstance(signal, neo.AnalogSignal): - raise ValueError('Input signal must be of type neo.AnalogSignal') + raise ValueError("Input signal must be of type neo.AnalogSignal") if pairs.shape[1] != 2: - raise ValueError("'channel_pairs' is not a list of channel pair " - "indices. Cannot define pairs for cross-correlation.") + raise ValueError( + "'channel_pairs' is not a list of channel pair " + "indices. Cannot define pairs for cross-correlation." + ) if not isinstance(hilbert_envelope, bool): raise ValueError("'hilbert_envelope' must be a boolean value") if n_lags is not None: if not isinstance(n_lags, int) or n_lags <= 0: - raise ValueError('n_lags must be a non-negative integer') + raise ValueError("n_lags must be a non-negative integer") # z-score analog signal and store channel time series in different arrays # Cross-correlation will be calculated between xsig and ysig z_transformed = signal.magnitude - signal.magnitude.mean(axis=0) - z_transformed = np.divide(z_transformed, signal.magnitude.std(axis=0), - out=z_transformed, - where=z_transformed != 0) + z_transformed = np.divide( + z_transformed, + signal.magnitude.std(axis=0), + out=z_transformed, + where=z_transformed != 0, + ) # transpose (nch, xy, nt) -> (xy, nt, nch) xsig, ysig = np.transpose(z_transformed.T[pairs], (1, 2, 0)) @@ -355,16 +364,16 @@ def cross_correlation_function(signal, channel_pairs, hilbert_envelope=False, # Calculate cross-correlation by taking Fourier transform of signal, # multiply in Fourier space, and transform back. Correct for bias due # to zero-padding - xcorr = scipy.signal.fftconvolve(xsig, ysig[::-1], mode='same', axes=0) - if scaleopt == 'biased': + xcorr = scipy.signal.fftconvolve(xsig, ysig[::-1], mode="same", axes=0) + if scaleopt == "biased": xcorr /= nt - elif scaleopt == 'unbiased': + elif scaleopt == "unbiased": normalizer = np.expand_dims(nt - np.abs(tau), axis=1) xcorr /= normalizer - elif scaleopt in ('normalized', 'coeff'): - normalizer = np.sqrt((xsig ** 2).sum(axis=0) * (ysig ** 2).sum(axis=0)) + elif scaleopt in ("normalized", "coeff"): + normalizer = np.sqrt((xsig**2).sum(axis=0) * (ysig**2).sum(axis=0)) xcorr /= normalizer - elif scaleopt != 'none': + elif scaleopt != "none": raise ValueError("Invalid scaleopt mode: '{}'".format(scaleopt)) # Calculate envelope of cross-correlation function with Hilbert transform. 
@@ -375,20 +384,29 @@ def cross_correlation_function(signal, channel_pairs, hilbert_envelope=False, # Cut off lags outside the desired range if n_lags is not None: tau0 = np.argwhere(tau == 0).item() - xcorr = xcorr[tau0 - n_lags: tau0 + n_lags + 1, :] + xcorr = xcorr[tau0 - n_lags : tau0 + n_lags + 1, :] # Return neo.AnalogSignal - cross_corr = neo.AnalogSignal(xcorr, - units='', - t_start=tau[0] * signal.sampling_period, - t_stop=tau[-1] * signal.sampling_period, - sampling_rate=signal.sampling_rate, - dtype=float) + cross_corr = neo.AnalogSignal( + xcorr, + units="", + t_start=tau[0] * signal.sampling_period, + t_stop=tau[-1] * signal.sampling_period, + sampling_rate=signal.sampling_rate, + dtype=float, + ) return cross_corr -def butter(signal, highpass_frequency=None, lowpass_frequency=None, order=4, - filter_function='filtfilt', sampling_frequency=1.0, axis=-1): +def butter( + signal, + highpass_frequency=None, + lowpass_frequency=None, + order=4, + filter_function="filtfilt", + sampling_frequency=1.0, + axis=-1, +): """ Butterworth filtering function for `neo.AnalogSignal`. @@ -488,45 +506,44 @@ def butter(signal, highpass_frequency=None, lowpass_frequency=None, order=4, >>> freq[0], psd[0, 0] # doctest: +SKIP (array(0.) * Hz, array(7.21464674e-08) * mV**2/Hz) """ - available_filters = 'lfilter', 'filtfilt', 'sosfiltfilt' + available_filters = "lfilter", "filtfilt", "sosfiltfilt" if filter_function not in available_filters: - raise ValueError("Invalid `filter_function`: {filter_function}. " - "Available filters: {available_filters}".format( - filter_function=filter_function, - available_filters=available_filters)) + raise ValueError( + "Invalid `filter_function`: {filter_function}. " + "Available filters: {available_filters}".format( + filter_function=filter_function, available_filters=available_filters + ) + ) # design filter - if hasattr(signal, 'sampling_rate'): + if hasattr(signal, "sampling_rate"): sampling_frequency = signal.sampling_rate.rescale(pq.Hz).magnitude if isinstance(highpass_frequency, pq.quantity.Quantity): highpass_frequency = highpass_frequency.rescale(pq.Hz).magnitude if isinstance(lowpass_frequency, pq.quantity.Quantity): lowpass_frequency = lowpass_frequency.rescale(pq.Hz).magnitude - Fn = sampling_frequency / 2. + Fn = sampling_frequency / 2.0 # filter type is determined according to the values of cut-off # frequencies if lowpass_frequency and highpass_frequency: if highpass_frequency < lowpass_frequency: Wn = (highpass_frequency / Fn, lowpass_frequency / Fn) - btype = 'bandpass' + btype = "bandpass" else: Wn = (lowpass_frequency / Fn, highpass_frequency / Fn) - btype = 'bandstop' + btype = "bandstop" elif lowpass_frequency: Wn = lowpass_frequency / Fn - btype = 'lowpass' + btype = "lowpass" elif highpass_frequency: Wn = highpass_frequency / Fn - btype = 'highpass' + btype = "highpass" else: - raise ValueError( - "Either highpass_frequency or lowpass_frequency must be given" - ) - if filter_function == 'sosfiltfilt': - output = 'sos' + raise ValueError("Either highpass_frequency or lowpass_frequency must be given") + if filter_function == "sosfiltfilt": + output = "sos" else: - output = 'ba' - designed_filter = scipy.signal.butter(order, Wn, btype=btype, - output=output) + output = "ba" + designed_filter = scipy.signal.butter(order, Wn, btype=btype, output=output) # When the input is AnalogSignal, the axis for time index (i.e. 
the # first axis) needs to be rolled to the last @@ -535,15 +552,14 @@ def butter(signal, highpass_frequency=None, lowpass_frequency=None, order=4, data = np.rollaxis(data, 0, len(data.shape)) # apply filter - if filter_function == 'lfilter': + if filter_function == "lfilter": b, a = designed_filter filtered_data = scipy.signal.lfilter(b=b, a=a, x=data, axis=axis) - elif filter_function == 'filtfilt': + elif filter_function == "filtfilt": b, a = designed_filter filtered_data = scipy.signal.filtfilt(b=b, a=a, x=data, axis=axis) else: - filtered_data = scipy.signal.sosfiltfilt(sos=designed_filter, - x=data, axis=axis) + filtered_data = scipy.signal.sosfiltfilt(sos=designed_filter, x=data, axis=axis) if isinstance(signal, neo.AnalogSignal): filtered_data = np.rollaxis(filtered_data, -1, 0) @@ -558,8 +574,9 @@ def butter(signal, highpass_frequency=None, lowpass_frequency=None, order=4, return filtered_data -def wavelet_transform(signal, frequency, n_cycles=6.0, sampling_frequency=1.0, - zero_padding=True): +def wavelet_transform( + signal, frequency, n_cycles=6.0, sampling_frequency=1.0, zero_padding=True +): r""" Compute the wavelet transform of a given signal with Morlet mother wavelet. The parametrization of the wavelet is based on @@ -655,14 +672,20 @@ def wavelet_transform(signal, frequency, n_cycles=6.0, sampling_frequency=1.0, [-2.95996766-0.9872236j ]]) """ + def _morlet_wavelet_ft(freq, n_cycles, fs, n): # Generate the Fourier transform of Morlet wavelet as defined # in Le van Quyen et al. J Neurosci Meth 111:83-98 (2001). - sigma = n_cycles / (6. * freq) + sigma = n_cycles / (6.0 * freq) freqs = np.fft.fftfreq(n, 1.0 / fs) - heaviside = np.array(freqs > 0., dtype=float) - ft_real = np.sqrt(2 * np.pi * freq) * sigma * np.exp( - -2 * (np.pi * sigma * (freqs - freq)) ** 2) * heaviside * fs + heaviside = np.array(freqs > 0.0, dtype=float) + ft_real = ( + np.sqrt(2 * np.pi * freq) + * sigma + * np.exp(-2 * (np.pi * sigma * (freqs - freq)) ** 2) + * heaviside + * fs + ) ft_imag = np.zeros_like(ft_real) return ft_real + 1.0j * ft_imag @@ -674,24 +697,29 @@ def _morlet_wavelet_ft(freq, n_cycles, fs, n): # When the input is AnalogSignal, use its attribute to specify the # sampling frequency - if hasattr(signal, 'sampling_rate'): + if hasattr(signal, "sampling_rate"): sampling_frequency = signal.sampling_rate if isinstance(sampling_frequency, pq.quantity.Quantity): - sampling_frequency = sampling_frequency.rescale('Hz').magnitude + sampling_frequency = sampling_frequency.rescale("Hz").magnitude if isinstance(frequency, (list, tuple, np.ndarray)): freqs = np.asarray(frequency) else: - freqs = np.array([frequency, ]) + freqs = np.array( + [ + frequency, + ] + ) if isinstance(freqs[0], pq.quantity.Quantity): - freqs = [f.rescale('Hz').magnitude for f in freqs] + freqs = [f.rescale("Hz").magnitude for f in freqs] # check whether the given central frequencies are less than the # Nyquist frequency of the signal if np.any(freqs >= sampling_frequency / 2): - raise ValueError("'frequency' elements must be less than the half of " - "the 'sampling_frequency' ({}) Hz" - .format(sampling_frequency)) + raise ValueError( + "'frequency' elements must be less than the half of " + "the 'sampling_frequency' ({}) Hz".format(sampling_frequency) + ) # check if n_cycles is positive if n_cycles <= 0: @@ -729,7 +757,7 @@ def _morlet_wavelet_ft(freq, n_cycles, fs, n): return signal_wt -def hilbert(signal, padding='nextpow'): +def hilbert(signal, padding="nextpow"): """ Apply a Hilbert transform to a `neo.AnalogSignal` 
object in order to obtain its (complex) analytic signal. @@ -806,7 +834,7 @@ def hilbert(signal, padding='nextpow'): if isinstance(padding, int): # User defined padding n = padding - elif padding == 'nextpow': + elif padding == "nextpow": # To speed up calculation of the Hilbert transform, make sure we change # the signal to be of a length that is a power of two. Failure to do so # results in computations of certain signal lengths to not finish (or @@ -823,14 +851,15 @@ def hilbert(signal, padding='nextpow'): # For this reason, nextpow is the default setting for now. n = 2 ** (int(np.log2(n_org - 1)) + 1) - elif padding == 'none' or padding is None: + elif padding == "none" or padding is None: # No padding n = n_org else: raise ValueError("Invalid padding '{}'.".format(padding)) output = signal.duplicate_with_new_data( - scipy.signal.hilbert(signal.magnitude, N=n, axis=0)[:n_org]) + scipy.signal.hilbert(signal.magnitude, N=n, axis=0)[:n_org] + ) # todo use flag once is fixed # https://github.com/NeuralEnsemble/python-neo/issues/752 output.array_annotate(**signal.array_annotations) @@ -918,22 +947,24 @@ def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None): """ if not isinstance(signal, neo.AnalogSignal): - raise ValueError('Input signal is not a neo.AnalogSignal!') + raise ValueError("Input signal is not a neo.AnalogSignal!") if baseline is None: pass - elif baseline == 'mean': + elif baseline == "mean": # subtract mean from each channel signal = signal - signal.mean(axis=0) - elif baseline == 'median': + elif baseline == "median": # subtract median from each channel signal = signal - np.median(signal.as_quantity(), axis=0) elif isinstance(baseline, pq.Quantity): # subtract arbitrary baseline signal = signal - baseline else: - raise ValueError("baseline must be either None, 'mean', 'median', or " - "a Quantity. Got {}".format(baseline)) + raise ValueError( + "baseline must be either None, 'mean', 'median', or " + "a Quantity. Got {}".format(baseline) + ) # slice the signal after subtracting baseline signal = signal.time_slice(t_start, t_stop) @@ -943,12 +974,14 @@ def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None): if isinstance(bin_duration, pq.Quantity): samples_per_bin = int( np.round( - bin_duration.rescale('s') / - signal.sampling_period.rescale('s'))) + bin_duration.rescale("s") / signal.sampling_period.rescale("s") + ) + ) n_bins = int(np.ceil(signal.shape[0] / samples_per_bin)) else: - raise ValueError("bin_duration must be a Quantity. Got {}".format( - bin_duration)) + raise ValueError( + "bin_duration must be a Quantity. 
Got {}".format(bin_duration) + ) else: # all samples in one bin samples_per_bin = signal.shape[0] @@ -972,8 +1005,7 @@ def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None): # return an AnalogSignal with times corresponding to center of each bin t_start = signal.t_start.rescale(bin_duration.units) + bin_duration / 2 - rauc_sig = neo.AnalogSignal(rauc, t_start=t_start, - sampling_period=bin_duration) + rauc_sig = neo.AnalogSignal(rauc, t_start=t_start, sampling_period=bin_duration) return rauc_sig @@ -1017,11 +1049,12 @@ def derivative(signal): """ if not isinstance(signal, neo.AnalogSignal): - raise TypeError('Input signal is not a neo.AnalogSignal!') + raise TypeError("Input signal is not a neo.AnalogSignal!") derivative_sig = neo.AnalogSignal( np.diff(signal.as_quantity(), axis=0) / signal.sampling_period, t_start=signal.t_start + signal.sampling_period / 2, - sampling_period=signal.sampling_period) + sampling_period=signal.sampling_period, + ) return derivative_sig diff --git a/elephant/spade.py b/elephant/spade.py index 251fc6ecd..182797eef 100644 --- a/elephant/spade.py +++ b/elephant/spade.py @@ -102,6 +102,7 @@ :copyright: Copyright 2014-2024 by the Elephant team, see `doc/authors.rst`. :license: BSD, see LICENSE.txt for details. """ + from __future__ import division, print_function, unicode_literals import operator @@ -131,10 +132,10 @@ "test_signature_significance", "approximate_stability", "pattern_set_reduction", - "concept_output_to_patterns" + "concept_output_to_patterns", ] -warnings.simplefilter('once', UserWarning) +warnings.simplefilter("once", UserWarning) try: from mpi4py import MPI # for parallelized routines @@ -151,12 +152,27 @@ HAVE_FIM = False -@deprecated_alias(binsize='bin_size') -def spade(spiketrains, bin_size, winlen, min_spikes=2, min_occ=2, - max_spikes=None, max_occ=None, min_neu=1, approx_stab_pars=None, - n_surr=0, dither=15 * pq.ms, spectrum='#', - alpha=None, stat_corr='fdr_bh', surr_method='dither_spikes', - psr_param=None, output_format='patterns', **surr_kwargs): +@deprecated_alias(binsize="bin_size") +def spade( + spiketrains, + bin_size, + winlen, + min_spikes=2, + min_occ=2, + max_spikes=None, + max_occ=None, + min_neu=1, + approx_stab_pars=None, + n_surr=0, + dither=15 * pq.ms, + spectrum="#", + alpha=None, + stat_corr="fdr_bh", + surr_method="dither_spikes", + psr_param=None, + output_format="patterns", + **surr_kwargs, +): r""" Perform the SPADE :cite:`spade-Torre2013_132`, :cite:`spade-Quaglio2017_41`, :cite:`spade-Stella2019_104022` analysis for @@ -342,40 +358,60 @@ def spade(spiketrains, bin_size, winlen, min_spikes=2, min_occ=2, rank = 0 compute_stability = _check_input( - spiketrains=spiketrains, bin_size=bin_size, winlen=winlen, - min_spikes=min_spikes, min_occ=min_occ, - max_spikes=max_spikes, max_occ=max_occ, min_neu=min_neu, + spiketrains=spiketrains, + bin_size=bin_size, + winlen=winlen, + min_spikes=min_spikes, + min_occ=min_occ, + max_spikes=max_spikes, + max_occ=max_occ, + min_neu=min_neu, approx_stab_pars=approx_stab_pars, - n_surr=n_surr, dither=dither, spectrum=spectrum, - alpha=alpha, stat_corr=stat_corr, surr_method=surr_method, - psr_param=psr_param, output_format=output_format) + n_surr=n_surr, + dither=dither, + spectrum=spectrum, + alpha=alpha, + stat_corr=stat_corr, + surr_method=surr_method, + psr_param=psr_param, + output_format=output_format, + ) time_mining = time.time() if rank == 0 or compute_stability: # Mine the spiketrains for extraction of concepts concepts, rel_matrix = concepts_mining( 
- spiketrains, bin_size, winlen, min_spikes=min_spikes, - min_occ=min_occ, max_spikes=max_spikes, max_occ=max_occ, - min_neu=min_neu, report='a') + spiketrains, + bin_size, + winlen, + min_spikes=min_spikes, + min_occ=min_occ, + max_spikes=max_spikes, + max_occ=max_occ, + min_neu=min_neu, + report="a", + ) time_mining = time.time() - time_mining print(f"Time for data mining: {time_mining}") # Decide if compute the approximated stability if compute_stability: - if 'stability_thresh' in approx_stab_pars.keys(): - stability_thresh = approx_stab_pars.pop('stability_thresh') + if "stability_thresh" in approx_stab_pars.keys(): + stability_thresh = approx_stab_pars.pop("stability_thresh") else: stability_thresh = None # Computing the approximated stability of all the concepts time_stability = time.time() - concepts = approximate_stability( - concepts, rel_matrix, **approx_stab_pars) + concepts = approximate_stability(concepts, rel_matrix, **approx_stab_pars) time_stability = time.time() - time_stability print(f"Time for stability computation: {time_stability}") # Filtering the concepts using stability thresholds if stability_thresh is not None: - concepts = [concept for concept in concepts - if _stability_filter(concept, stability_thresh)] + concepts = [ + concept + for concept in concepts + if _stability_filter(concept, stability_thresh) + ] output = {} pv_spec = None # initialize pv_spec to None @@ -384,18 +420,28 @@ def spade(spiketrains, bin_size, winlen, min_spikes=2, min_occ=2, # Compute pvalue spectrum time_pvalue_spectrum = time.time() pv_spec = pvalue_spectrum( - spiketrains, bin_size, winlen, dither=dither, n_surr=n_surr, - min_spikes=min_spikes, min_occ=min_occ, max_spikes=max_spikes, - max_occ=max_occ, min_neu=min_neu, spectrum=spectrum, - surr_method=surr_method, **surr_kwargs) + spiketrains, + bin_size, + winlen, + dither=dither, + n_surr=n_surr, + min_spikes=min_spikes, + min_occ=min_occ, + max_spikes=max_spikes, + max_occ=max_occ, + min_neu=min_neu, + spectrum=spectrum, + surr_method=surr_method, + **surr_kwargs, + ) time_pvalue_spectrum = time.time() - time_pvalue_spectrum print(f"Time for pvalue spectrum computation: {time_pvalue_spectrum}") # Storing pvalue spectrum - output['pvalue_spectrum'] = pv_spec + output["pvalue_spectrum"] = pv_spec # rank!=0 returning None if rank != 0: - warnings.warn('Returning None because executed on a process != 0') + warnings.warn("Returning None because executed on a process != 0") return None # Initialize non-significant signatures as empty list: @@ -406,41 +452,68 @@ def spade(spiketrains, bin_size, winlen, min_spikes=2, min_occ=2, # Computing non-significant entries of the spectrum applying # the statistical correction ns_signatures = test_signature_significance( - pv_spec, concepts, alpha, winlen, corr=stat_corr, - report='non_significant', spectrum=spectrum) + pv_spec, + concepts, + alpha, + winlen, + corr=stat_corr, + report="non_significant", + spectrum=spectrum, + ) # Storing non-significant entries of the pvalue spectrum - output['non_sgnf_sgnt'] = ns_signatures + output["non_sgnf_sgnt"] = ns_signatures # Filter concepts with pvalue spectrum (psf) if len(ns_signatures) > 0: - concepts = [concept for concept in concepts - if _pattern_spectrum_filter(concept, ns_signatures, - spectrum, winlen)] + concepts = [ + concept + for concept in concepts + if _pattern_spectrum_filter(concept, ns_signatures, spectrum, winlen) + ] # Decide whether to filter concepts using psr if psr_param is not None: # Filter using conditional tests (psr) concepts = 
pattern_set_reduction( - concepts, ns_signatures, winlen=winlen, spectrum=spectrum, - h_subset_filtering=psr_param[0], k_superset_filtering=psr_param[1], - l_covered_spikes=psr_param[2], min_spikes=min_spikes, - min_occ=min_occ) + concepts, + ns_signatures, + winlen=winlen, + spectrum=spectrum, + h_subset_filtering=psr_param[0], + k_superset_filtering=psr_param[1], + l_covered_spikes=psr_param[2], + min_spikes=min_spikes, + min_occ=min_occ, + ) # Storing patterns for output format concepts - if output_format == 'concepts': - output['patterns'] = concepts + if output_format == "concepts": + output["patterns"] = concepts else: # output_format == 'patterns': # Transforming concepts to dictionary containing pattern's infos - output['patterns'] = concept_output_to_patterns( - concepts, winlen, bin_size, pv_spec, spectrum, - spiketrains[0].t_start) + output["patterns"] = concept_output_to_patterns( + concepts, winlen, bin_size, pv_spec, spectrum, spiketrains[0].t_start + ) return output def _check_input( - spiketrains, bin_size, winlen, min_spikes=2, min_occ=2, - max_spikes=None, max_occ=None, min_neu=1, approx_stab_pars=None, - n_surr=0, dither=15 * pq.ms, spectrum='#', - alpha=None, stat_corr='fdr_bh', surr_method='dither_spikes', - psr_param=None, output_format='patterns'): + spiketrains, + bin_size, + winlen, + min_spikes=2, + min_occ=2, + max_spikes=None, + max_occ=None, + min_neu=1, + approx_stab_pars=None, + n_surr=0, + dither=15 * pq.ms, + spectrum="#", + alpha=None, + stat_corr="fdr_bh", + surr_method="dither_spikes", + psr_param=None, + output_format="patterns", +): """ Checks all input given to SPADE Parameters @@ -455,108 +528,125 @@ def _check_input( # Check spiketrains if not all([isinstance(elem, neo.SpikeTrain) for elem in spiketrains]): - raise TypeError( - 'spiketrains must be a list of SpikeTrains') + raise TypeError("spiketrains must be a list of SpikeTrains") # Check that all spiketrains have same t_start and same t_stop - if not all([spiketrain.t_start == spiketrains[0].t_start - for spiketrain in spiketrains]) or \ - not all([spiketrain.t_stop == spiketrains[0].t_stop - for spiketrain in spiketrains]): - raise ValueError( - 'All spiketrains must have the same t_start and t_stop') + if not all( + [spiketrain.t_start == spiketrains[0].t_start for spiketrain in spiketrains] + ) or not all( + [spiketrain.t_stop == spiketrains[0].t_stop for spiketrain in spiketrains] + ): + raise ValueError("All spiketrains must have the same t_start and t_stop") # Check bin_size if not isinstance(bin_size, pq.Quantity): - raise TypeError('bin_size must be a pq.Quantity') + raise TypeError("bin_size must be a pq.Quantity") # Check winlen if not isinstance(winlen, int): - raise TypeError('winlen must be an integer') + raise TypeError("winlen must be an integer") # Check min_spikes if not isinstance(min_spikes, int): - raise TypeError('min_spikes must be an integer') + raise TypeError("min_spikes must be an integer") # Check min_occ if not isinstance(min_occ, int): - raise TypeError('min_occ must be an integer') + raise TypeError("min_occ must be an integer") # Check max_spikes if not (isinstance(max_spikes, int) or max_spikes is None): - raise TypeError('max_spikes must be an integer or None') + raise TypeError("max_spikes must be an integer or None") # Check max_occ if not (isinstance(max_occ, int) or max_occ is None): - raise TypeError('max_occ must be an integer or None') + raise TypeError("max_occ must be an integer or None") # Check min_neu if not isinstance(min_neu, int): - raise 
TypeError('min_neu must be an integer') + raise TypeError("min_neu must be an integer") # Check approx_stab_pars compute_stability = False if isinstance(approx_stab_pars, dict): - if 'n_subsets' in approx_stab_pars.keys() or\ - ('epsilon' in approx_stab_pars.keys() and - 'delta' in approx_stab_pars.keys()): + if "n_subsets" in approx_stab_pars.keys() or ( + "epsilon" in approx_stab_pars.keys() and "delta" in approx_stab_pars.keys() + ): compute_stability = True else: raise ValueError( - 'for approximate stability computation you need to ' - 'pass n_subsets or epsilon and delta.') + "for approximate stability computation you need to " + "pass n_subsets or epsilon and delta." + ) # Check n_surr if not isinstance(n_surr, int): - raise TypeError('n_surr must be an integer') + raise TypeError("n_surr must be an integer") # Check dither if not isinstance(dither, pq.Quantity): - raise TypeError('dither must be a pq.Quantity') + raise TypeError("dither must be a pq.Quantity") # Check spectrum - if spectrum not in ('#', '3d#'): + if spectrum not in ("#", "3d#"): raise ValueError("spectrum must be '#' or '3d#'") # Check alpha if isinstance(alpha, (int, float)): # Check redundant use of alpha - if 0. < alpha < 1. and n_surr == 0: - warnings.warn('0.=1') + raise ValueError("min_neu must be an integer >=1") # By default, set the maximum pattern size to the number of spiketrains if max_z is None: max_z = max(max(map(len, transactions)), min_z + 1) @@ -891,9 +995,9 @@ def _fpgrowth(transactions, min_c=2, min_z=2, max_z=None, # Initializing outputs concepts = [] - if report == '#': + if report == "#": spec_matrix = np.zeros((max_z + 1, max_c + 1)) - if report == '3d#': + if report == "3d#": spec_matrix = np.zeros((max_z + 1, max_c + 1, winlen)) spectrum = [] # check whether all transactions are identical @@ -912,12 +1016,13 @@ def _fpgrowth(transactions, min_c=2, min_z=2, max_z=None, supp=-min_c, zmin=min_z, zmax=max_z, - report='a', - algo='s', + report="a", + algo="s", winlen=winlen, min_neu=min_neu, threads=0, - verbose=4) + verbose=4, + ) break else: fpgrowth_output = [(tuple(transactions[0]), len(transactions))] @@ -926,47 +1031,46 @@ def _fpgrowth(transactions, min_c=2, min_z=2, max_z=None, # if _fpgrowth_filter(concept, winlen, max_c, min_neu)] # filter out subsets of patterns that are found as a side effect # of using the moving window strategy - fpgrowth_output = _filter_for_moving_window_subsets( - fpgrowth_output, winlen) - for (intent, supp) in fpgrowth_output: - if report == 'a': + fpgrowth_output = _filter_for_moving_window_subsets(fpgrowth_output, winlen) + for intent, supp in fpgrowth_output: + if report == "a": if rel_matrix is not None: # Computing the extent of the concept (patterns # occurrences), checking in rel_matrix in which windows # the intent occurred - extent = tuple( - np.nonzero( - np.all(rel_matrix[:, intent], axis=1) - )[0]) + extent = tuple(np.nonzero(np.all(rel_matrix[:, intent], axis=1))[0]) concepts.append((intent, extent)) # Computing 2d spectrum - elif report == '#': + elif report == "#": spec_matrix[len(intent) - 1, supp - 1] += 1 # Computing 3d spectrum - elif report == '3d#': - spec_matrix[len(intent) - 1, supp - 1, max( - np.array(intent) % winlen)] += 1 + elif report == "3d#": + spec_matrix[len(intent) - 1, supp - 1, max(np.array(intent) % winlen)] += 1 del fpgrowth_output - if report == 'a': + if report == "a": return concepts - if report == '#': - for (size, occurrences) in np.transpose(np.where(spec_matrix != 0)): + if report == "#": + for size, occurrences in 
np.transpose(np.where(spec_matrix != 0)): spectrum.append( - (size + 1, occurrences + 1, - int(spec_matrix[size, occurrences]))) - elif report == '3d#': - for (size, occurrences, duration) in\ - np.transpose(np.where(spec_matrix != 0)): + (size + 1, occurrences + 1, int(spec_matrix[size, occurrences])) + ) + elif report == "3d#": + for size, occurrences, duration in np.transpose(np.where(spec_matrix != 0)): spectrum.append( - (size + 1, occurrences + 1, duration, - int(spec_matrix[size, occurrences, duration]))) + ( + size + 1, + occurrences + 1, + duration, + int(spec_matrix[size, occurrences, duration]), + ) + ) del spec_matrix if len(spectrum) > 0: spectrum = np.array(spectrum) - elif report == '#': + elif report == "#": spectrum = np.zeros(shape=(0, 3)) - elif report == '3d#': + elif report == "3d#": spectrum = np.zeros(shape=(0, 4)) return spectrum @@ -1012,7 +1116,7 @@ def _filter_for_moving_window_subsets(concepts, winlen): if winlen == 1: return concepts - if hasattr(concepts[0], 'intent'): + if hasattr(concepts[0], "intent"): # fca format # sort the concepts by (decreasing) support concepts.sort(key=lambda c: -len(c.extent)) @@ -1020,9 +1124,10 @@ def _filter_for_moving_window_subsets(concepts, winlen): support = np.array([len(c.extent) for c in concepts]) # convert transactions relative to last pattern spike - converted_transactions = [_rereference_to_last_spike(concept.intent, - winlen=winlen) - for concept in concepts] + converted_transactions = [ + _rereference_to_last_spike(concept.intent, winlen=winlen) + for concept in concepts + ] else: # fim.fpgrowth format # sort the concepts by (decreasing) support @@ -1031,9 +1136,10 @@ def _filter_for_moving_window_subsets(concepts, winlen): support = np.array([concept[1] for concept in concepts]) # convert transactions relative to last pattern spike - converted_transactions = [_rereference_to_last_spike(concept[0], - winlen=winlen) - for concept in concepts] + converted_transactions = [ + _rereference_to_last_spike(concept[0], winlen=winlen) + for concept in concepts + ] output = [] @@ -1049,16 +1155,17 @@ def _filter_for_moving_window_subsets(concepts, winlen): for i in support_indices: intersection = reduce( operator.and_, - (reverse_map[window_bin] - for window_bin in converted_transactions[i])) + (reverse_map[window_bin] for window_bin in converted_transactions[i]), + ) if len(intersection) == 1: output.append(concepts[i]) return output -def _fast_fca(context, min_c=2, min_z=2, max_z=None, - max_c=None, report='a', winlen=1, min_neu=1): +def _fast_fca( + context, min_c=2, min_z=2, max_z=None, max_c=None, report="a", winlen=1, min_neu=1 +): """ Find concepts of the context with the fast-fca algorithm. 
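# A minimal sketch (with hypothetical toy values) of the set-intersection idiom
# that _filter_for_moving_window_subsets, used by both the fpgrowth and the
# fast-fca mining paths in this diff, relies on: for each candidate pattern the
# sets mapped to its re-referenced bins are intersected with functools.reduce
# and operator.and_, and the candidate is kept only when the intersection holds
# a single element. The reverse_map contents here are illustrative assumptions.
from functools import reduce
import operator

# Hypothetical reverse map: bin -> set of candidate-pattern indices containing it
reverse_map = {0: {0, 2}, 1: {0, 2}, 3: {0}}
candidate_bins = (0, 1, 3)  # re-referenced bins of candidate pattern 0

intersection = reduce(operator.and_, (reverse_map[b] for b in candidate_bins))
print(intersection)  # {0} -> size 1, so candidate 0 passes the filter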
@@ -1123,25 +1230,27 @@ def _fast_fca(context, min_c=2, min_z=2, max_z=None, concepts = [] # Check parameters if min_neu < 1: - raise ValueError('min_neu must be an integer >=1') + raise ValueError("min_neu must be an integer >=1") # By default, set maximum number of attributes if max_z is None: max_z = len(context) # By default, set maximum number of data to number of bins if max_c is None: max_c = len(context) - if report == '#': + if report == "#": spec_matrix = np.zeros((max_z, max_c)) - elif report == '3d#': + elif report == "3d#": spec_matrix = np.zeros((max_z, max_c, winlen)) spectrum = [] # Mining the spiketrains with fast fca algorithm fca_out = fast_fca.FormalConcepts(context) fca_out.computeLattice() fca_concepts = fca_out.concepts - fca_concepts = [concept for concept in fca_concepts - if _fca_filter(concept, winlen, min_c, min_z, max_c, max_z, - min_neu)] + fca_concepts = [ + concept + for concept in fca_concepts + if _fca_filter(concept, winlen, min_c, min_z, max_c, max_z, min_neu) + ] fca_concepts = _filter_for_moving_window_subsets(fca_concepts, winlen) # Applying min/max conditions for fca_concept in fca_concepts: @@ -1149,33 +1258,39 @@ def _fast_fca(context, min_c=2, min_z=2, max_z=None, extent = tuple(fca_concept.extent) concepts.append((intent, extent)) # computing spectrum - if report == '#': + if report == "#": spec_matrix[len(intent) - 1, len(extent) - 1] += 1 - elif report == '3d#': - spec_matrix[len(intent) - 1, len(extent) - 1, max( - np.array(intent) % winlen)] += 1 - if report == 'a': + elif report == "3d#": + spec_matrix[ + len(intent) - 1, len(extent) - 1, max(np.array(intent) % winlen) + ] += 1 + if report == "a": return concepts del concepts # returning spectrum - if report == '#': - for (size, occurrence) in np.transpose(np.where(spec_matrix != 0)): + if report == "#": + for size, occurrence in np.transpose(np.where(spec_matrix != 0)): spectrum.append( - (size + 1, occurrence + 1, int(spec_matrix[size, occurrence]))) + (size + 1, occurrence + 1, int(spec_matrix[size, occurrence])) + ) - if report == '3d#': - for (size, occurrence, duration) in\ - np.transpose(np.where(spec_matrix != 0)): + if report == "3d#": + for size, occurrence, duration in np.transpose(np.where(spec_matrix != 0)): spectrum.append( - (size + 1, occurrence + 1, duration, - int(spec_matrix[size, occurrence, duration]))) + ( + size + 1, + occurrence + 1, + duration, + int(spec_matrix[size, occurrence, duration]), + ) + ) del spec_matrix if len(spectrum) > 0: spectrum = np.array(spectrum) - elif report == '#': + elif report == "#": spectrum = np.zeros(shape=(0, 3)) - elif report == '3d#': + elif report == "3d#": spectrum = np.zeros(shape=(0, 4)) return spectrum @@ -1187,19 +1302,31 @@ def _fca_filter(concept, winlen, min_c, min_z, max_c, max_z, min_neu): """ intent = tuple(concept.intent) extent = tuple(concept.extent) - keep_concepts = \ - min_z <= len(intent) <= max_z and \ - min_c <= len(extent) <= max_c and \ - len(np.unique(np.array(intent) // winlen)) >= min_neu and \ - min(np.array(intent) % winlen) == 0 + keep_concepts = ( + min_z <= len(intent) <= max_z + and min_c <= len(extent) <= max_c + and len(np.unique(np.array(intent) // winlen)) >= min_neu + and min(np.array(intent) % winlen) == 0 + ) return keep_concepts -@deprecated_alias(binsize='bin_size') +@deprecated_alias(binsize="bin_size") def pvalue_spectrum( - spiketrains, bin_size, winlen, dither, n_surr, min_spikes=2, min_occ=2, - max_spikes=None, max_occ=None, min_neu=1, spectrum='#', - surr_method='dither_spikes', 
**surr_kwargs): + spiketrains, + bin_size, + winlen, + dither, + n_surr, + min_spikes=2, + min_occ=2, + max_spikes=None, + max_occ=None, + min_neu=1, + spectrum="#", + surr_method="dither_spikes", + **surr_kwargs, +): """ Compute the p-value spectrum of pattern signatures extracted from surrogates of parallel spike trains, under the null hypothesis of @@ -1295,11 +1422,10 @@ def pvalue_spectrum( size = 1 # Check on number of surrogates if n_surr <= 0: - raise ValueError('n_surr has to be >0') + raise ValueError("n_surr has to be >0") if surr_method not in surr.SURR_METHODS: - raise ValueError( - f'specified surr_method (={surr_method}) not valid') - if spectrum not in ('#', '3d#'): + raise ValueError(f"specified surr_method (={surr_method}) not valid") + if spectrum not in ("#", "3d#"): raise ValueError(f"Invalid spectrum: '{spectrum}'") len_partition = n_surr // size # length of each MPI task @@ -1312,31 +1438,44 @@ def pvalue_spectrum( # number of bins per window. max_spikes = len(spiketrains) * winlen - if spectrum == '#': - max_occs = np.empty(shape=(len_partition + add_remainder, - max_spikes - min_spikes + 1), - dtype=np.uint16) + if spectrum == "#": + max_occs = np.empty( + shape=(len_partition + add_remainder, max_spikes - min_spikes + 1), + dtype=np.uint16, + ) else: # spectrum == '3d#': - max_occs = np.empty(shape=(len_partition + add_remainder, - max_spikes - min_spikes + 1, winlen), - dtype=np.uint16) + max_occs = np.empty( + shape=(len_partition + add_remainder, max_spikes - min_spikes + 1, winlen), + dtype=np.uint16, + ) for surr_id, binned_surrogates in _generate_binned_surrogates( - spiketrains, bin_size=bin_size, dither=dither, - surr_method=surr_method, n_surrogates=len_partition+add_remainder, - **surr_kwargs): - + spiketrains, + bin_size=bin_size, + dither=dither, + surr_method=surr_method, + n_surrogates=len_partition + add_remainder, + **surr_kwargs, + ): # Find all pattern signatures in the current surrogate data set surr_concepts = concepts_mining( - binned_surrogates, bin_size, winlen, min_spikes=min_spikes, - max_spikes=max_spikes, min_occ=min_occ, max_occ=max_occ, - min_neu=min_neu, report=spectrum)[0] + binned_surrogates, + bin_size, + winlen, + min_spikes=min_spikes, + max_spikes=max_spikes, + min_occ=min_occ, + max_occ=max_occ, + min_neu=min_neu, + report=spectrum, + )[0] # The last entry of the signature is the number of times the # signature appeared. This entry is not needed here. 
surr_concepts = surr_concepts[:, :-1] max_occs[surr_id] = _get_max_occ( - surr_concepts, min_spikes, max_spikes, winlen, spectrum) + surr_concepts, min_spikes, max_spikes, winlen, spectrum + ) # Collecting results on the first PCU if size != 1: @@ -1350,69 +1489,89 @@ def pvalue_spectrum( max_occs = np.vstack(max_occs) # Compute the p-value spectrum, and return it - return _get_pvalue_spec(max_occs, min_spikes, max_spikes, min_occ, - n_surr, winlen, spectrum) + return _get_pvalue_spec( + max_occs, min_spikes, max_spikes, min_occ, n_surr, winlen, spectrum + ) def _generate_binned_surrogates( - spiketrains, bin_size, dither, surr_method, n_surrogates, - **surr_kwargs): - if surr_method == 'bin_shuffling': + spiketrains, bin_size, dither, surr_method, n_surrogates, **surr_kwargs +): + if surr_method == "bin_shuffling": binned_spiketrains = [ - conv.BinnedSpikeTrain( - spiketrain, bin_size=bin_size, tolerance=None) - for spiketrain in spiketrains] - max_displacement = int(dither.rescale(pq.ms).magnitude / - bin_size.rescale(pq.ms).magnitude) - elif surr_method in ('joint_isi_dithering', 'isi_dithering'): - isi_dithering = surr_method == 'isi_dithering' - joint_isi_instances = \ - [surr.JointISI(spiketrain, dither=dither, - isi_dithering=isi_dithering, **surr_kwargs) - for spiketrain in spiketrains] + conv.BinnedSpikeTrain(spiketrain, bin_size=bin_size, tolerance=None) + for spiketrain in spiketrains + ] + max_displacement = int( + dither.rescale(pq.ms).magnitude / bin_size.rescale(pq.ms).magnitude + ) + elif surr_method in ("joint_isi_dithering", "isi_dithering"): + isi_dithering = surr_method == "isi_dithering" + joint_isi_instances = [ + surr.JointISI( + spiketrain, dither=dither, isi_dithering=isi_dithering, **surr_kwargs + ) + for spiketrain in spiketrains + ] for surr_id in range(n_surrogates): - if surr_method == 'bin_shuffling': - binned_surrogates = \ - [surr.bin_shuffling(binned_spiketrain, - max_displacement=max_displacement, - **surr_kwargs)[0] - for binned_spiketrain in binned_spiketrains] + if surr_method == "bin_shuffling": + binned_surrogates = [ + surr.bin_shuffling( + binned_spiketrain, max_displacement=max_displacement, **surr_kwargs + )[0] + for binned_spiketrain in binned_spiketrains + ] binned_surrogates = np.array( - [binned_surrogate.to_bool_array()[0] - for binned_surrogate in binned_surrogates]) + [ + binned_surrogate.to_bool_array()[0] + for binned_surrogate in binned_surrogates + ] + ) binned_surrogates = conv.BinnedSpikeTrain( binned_surrogates, bin_size=bin_size, t_start=spiketrains[0].t_start, t_stop=spiketrains[0].t_stop, - tolerance=None) - elif surr_method in ('joint_isi_dithering', 'isi_dithering'): - surrs = [instance.dithering()[0] - for instance in joint_isi_instances] - elif surr_method == 'dither_spikes_with_refractory_period': + tolerance=None, + ) + elif surr_method in ("joint_isi_dithering", "isi_dithering"): + surrs = [instance.dithering()[0] for instance in joint_isi_instances] + elif surr_method == "dither_spikes_with_refractory_period": # The initial refractory period is set to the bin size in order to # prevent that spikes fall into the same bin, if the spike trains # are sparse (min(ISI)>bin size). 
- surrs = \ - [surr.dither_spikes( - spiketrain, dither=dither, n_surrogates=1, - refractory_period=bin_size, **surr_kwargs)[0] - for spiketrain in spiketrains] + surrs = [ + surr.dither_spikes( + spiketrain, + dither=dither, + n_surrogates=1, + refractory_period=bin_size, + **surr_kwargs, + )[0] + for spiketrain in spiketrains + ] else: - surrs = \ - [surr.surrogates( - spiketrain, n_surrogates=1, method=surr_method, - dt=dither, **surr_kwargs)[0] - for spiketrain in spiketrains] - - if surr_method != 'bin_shuffling': + surrs = [ + surr.surrogates( + spiketrain, + n_surrogates=1, + method=surr_method, + dt=dither, + **surr_kwargs, + )[0] + for spiketrain in spiketrains + ] + + if surr_method != "bin_shuffling": binned_surrogates = conv.BinnedSpikeTrain( - surrs, bin_size=bin_size, tolerance=None) + surrs, bin_size=bin_size, tolerance=None + ) yield surr_id, binned_surrogates -def _get_pvalue_spec(max_occs, min_spikes, max_spikes, min_occ, n_surr, winlen, - spectrum): +def _get_pvalue_spec( + max_occs, min_spikes, max_spikes, min_occ, n_surr, winlen, spectrum +): """ This function converts the list of maximal occurrences into the corresponding p-value spectrum. @@ -1437,11 +1596,11 @@ def _get_pvalue_spec(max_occs, min_spikes, max_spikes, min_occ, n_surr, winlen, each entry has the form: [pattern_size, pattern_occ, pattern_dur, p_value] """ - if spectrum not in ('#', '3d#'): + if spectrum not in ("#", "3d#"): raise ValueError(f"Invalid spectrum: '{spectrum}'") pv_spec = [] - if spectrum == '#': + if spectrum == "#": max_occs = np.expand_dims(max_occs, axis=2) winlen = 1 for size_id, pt_size in enumerate(range(min_spikes, max_spikes + 1)): @@ -1449,12 +1608,12 @@ def _get_pvalue_spec(max_occs, min_spikes, max_spikes, min_occ, n_surr, winlen, max_occs_size_dur = max_occs[:, size_id, dur] counts, occs = np.histogram( max_occs_size_dur, - bins=np.arange(min_occ, - np.max(max_occs_size_dur) + 2)) + bins=np.arange(min_occ, np.max(max_occs_size_dur) + 2), + ) occs = occs[:-1].astype(np.uint16) pvalues = np.cumsum(counts[::-1])[::-1] / n_surr for occ_id, occ in enumerate(occs): - if spectrum == '#': + if spectrum == "#": pv_spec.append([pt_size, occ, pvalues[occ_id]]) else: # spectrum == '3d#': pv_spec.append([pt_size, occ, dur, pvalues[occ_id]]) @@ -1483,14 +1642,13 @@ def _get_max_occ(surr_concepts, min_spikes, max_spikes, winlen, spectrum): The first axis corresponds to the pattern size the second to the duration. 
""" - if spectrum == '#': + if spectrum == "#": winlen = 1 max_occ = np.zeros(shape=(max_spikes - min_spikes + 1, winlen)) for size_id, pt_size in enumerate(range(min_spikes, max_spikes + 1)): - concepts_for_size = surr_concepts[ - surr_concepts[:, 0] == pt_size][:, 1:] + concepts_for_size = surr_concepts[surr_concepts[:, 0] == pt_size][:, 1:] for dur in range(winlen): - if spectrum == '#': + if spectrum == "#": occs = concepts_for_size[:, 0] else: # spectrum == '3d#': occs = concepts_for_size[concepts_for_size[:, 1] == dur][:, 0] @@ -1498,8 +1656,8 @@ def _get_max_occ(surr_concepts, min_spikes, max_spikes, winlen, spectrum): for pt_size in range(max_spikes - 1, min_spikes - 1, -1): size_id = pt_size - min_spikes - max_occ[size_id] = np.max(max_occ[size_id:size_id + 2], axis=0) - if spectrum == '#': + max_occ[size_id] = np.max(max_occ[size_id : size_id + 2], axis=0) + if spectrum == "#": max_occ = np.squeeze(max_occ, axis=1) return max_occ @@ -1508,9 +1666,7 @@ def _get_max_occ(surr_concepts, min_spikes, max_spikes, winlen, spectrum): def _stability_filter(concept, stability_thresh): """Criteria by which to filter concepts from the lattice""" # stabilities larger than stability_thresh - keep_concept = \ - concept[2] > stability_thresh[0]\ - or concept[3] > stability_thresh[1] + keep_concept = concept[2] > stability_thresh[0] or concept[3] > stability_thresh[1] return keep_concept @@ -1535,38 +1691,37 @@ def _mask_pvalue_spectrum(pv_spec, concepts, spectrum, winlen): An array of boolean values, indicating if a signature of p-value spectrum is also in the mined concepts of the original data. """ - if spectrum == '#': - signatures = {(len(concept[0]), len(concept[1])) - for concept in concepts} + if spectrum == "#": + signatures = {(len(concept[0]), len(concept[1])) for concept in concepts} else: # spectrum == '3d#': # third entry of signatures is the duration, fixed as the maximum lag - signatures = {(len(concept[0]), len(concept[1]), - max(np.array(concept[0]) % winlen)) - for concept in concepts} + signatures = { + (len(concept[0]), len(concept[1]), max(np.array(concept[0]) % winlen)) + for concept in concepts + } mask = np.zeros(len(pv_spec), dtype=bool) for index, pv_entry in enumerate(pv_spec): - if tuple(pv_entry[:-1]) in signatures \ - and not np.isclose(pv_entry[-1], [1]): + if tuple(pv_entry[:-1]) in signatures and not np.isclose(pv_entry[-1], [1]): # select the highest number of occurrences for size and duration mask[index] = True - if mask[index-1]: - if spectrum == '#': + if mask[index - 1]: + if spectrum == "#": size = pv_spec[index][0] - prev_size = pv_spec[index-1][0] + prev_size = pv_spec[index - 1][0] if prev_size == size: - mask[index-1] = False + mask[index - 1] = False else: size, duration = pv_spec[index][[0, 2]] - prev_size, prev_duration = pv_spec[index-1][[0, 2]] + prev_size, prev_duration = pv_spec[index - 1][[0, 2]] if prev_size == size and duration == prev_duration: - mask[index-1] = False + mask[index - 1] = False return mask -def test_signature_significance(pv_spec, concepts, alpha, winlen, - corr='fdr_bh', report='spectrum', - spectrum='#'): +def test_signature_significance( + pv_spec, concepts, alpha, winlen, corr="fdr_bh", report="spectrum", spectrum="#" +): """ Compute the significance spectrum of a pattern spectrum. 
@@ -1652,16 +1807,30 @@ def test_signature_significance(pv_spec, concepts, alpha, winlen, if alpha == 1: return [] - if spectrum not in ('#', '3d#'): - raise ValueError("spectrum must be either '#' or '3d#', " + - f"got {spectrum} instead") - if report not in ('spectrum', 'significant', 'non_significant'): - raise ValueError("report must be either 'spectrum'," + - " 'significant' or 'non_significant'," + - f"got {report} instead") - if corr not in ('bonferroni', 'sidak', 'holm-sidak', 'holm', - 'simes-hochberg', 'hommel', 'fdr_bh', 'fdr_by', - 'fdr_tsbh', 'fdr_tsbky', '', 'no'): + if spectrum not in ("#", "3d#"): + raise ValueError( + "spectrum must be either '#' or '3d#', " + f"got {spectrum} instead" + ) + if report not in ("spectrum", "significant", "non_significant"): + raise ValueError( + "report must be either 'spectrum'," + + " 'significant' or 'non_significant'," + + f"got {report} instead" + ) + if corr not in ( + "bonferroni", + "sidak", + "holm-sidak", + "holm", + "simes-hochberg", + "hommel", + "fdr_bh", + "fdr_by", + "fdr_tsbh", + "fdr_tsbky", + "", + "no", + ): raise ValueError("Parameter corr not recognized") pv_spec = np.array(pv_spec) @@ -1674,9 +1843,8 @@ def test_signature_significance(pv_spec, concepts, alpha, winlen, tests = [False] * len(pvalues) if len(pvalues_totest) > 0: - # Compute significance for only the non-trivial tests - if corr in ['', 'no']: # ...without statistical correction + if corr in ["", "no"]: # ...without statistical correction tests_selected = pvalues_totest <= alpha else: try: @@ -1684,10 +1852,12 @@ def test_signature_significance(pv_spec, concepts, alpha, winlen, except ModuleNotFoundError: raise ModuleNotFoundError( "Please run 'pip install statsmodels' if you " - "want to use multiple testing correction") + "want to use multiple testing correction" + ) - tests_selected = sm.multipletests(pvalues_totest, alpha=alpha, - method=corr)[0] + tests_selected = sm.multipletests(pvalues_totest, alpha=alpha, method=corr)[ + 0 + ] # assign each corrected pvalue to its corresponding entry # this breaks @@ -1695,31 +1865,39 @@ def test_signature_significance(pv_spec, concepts, alpha, winlen, tests[index] = value # Return the specified results: - if spectrum == '#': - if report == 'spectrum': - sig_spectrum = [(size, occ, test) - for (size, occ, pv), test in zip(pv_spec, tests)] - elif report == 'significant': - sig_spectrum = [(size, occ) for ((size, occ, pv), test) - in zip(pv_spec, tests) if test] + if spectrum == "#": + if report == "spectrum": + sig_spectrum = [ + (size, occ, test) for (size, occ, pv), test in zip(pv_spec, tests) + ] + elif report == "significant": + sig_spectrum = [ + (size, occ) for ((size, occ, pv), test) in zip(pv_spec, tests) if test + ] else: # report == 'non_significant' - sig_spectrum = [(size, occ) - for ((size, occ, pv), test) in zip(pv_spec, tests) - if not test] + sig_spectrum = [ + (size, occ) + for ((size, occ, pv), test) in zip(pv_spec, tests) + if not test + ] else: # spectrum == '3d#' - if report == 'spectrum': - sig_spectrum =\ - [(size, occ, l, test) - for (size, occ, l, pv), test in zip(pv_spec, tests)] - elif report == 'significant': - sig_spectrum = [(size, occ, l) for ((size, occ, l, pv), test) - in zip(pv_spec, tests) if test] + if report == "spectrum": + sig_spectrum = [ + (size, occ, l, test) for (size, occ, l, pv), test in zip(pv_spec, tests) + ] + elif report == "significant": + sig_spectrum = [ + (size, occ, l) + for ((size, occ, l, pv), test) in zip(pv_spec, tests) + if test + ] else: # report == 
'non_significant' - sig_spectrum =\ - [(size, occ, l) - for ((size, occ, l, pv), test) in zip(pv_spec, tests) - if not test] + sig_spectrum = [ + (size, occ, l) + for ((size, occ, l, pv), test) in zip(pv_spec, tests) + if not test + ] return sig_spectrum @@ -1727,18 +1905,16 @@ def _pattern_spectrum_filter(concept, ns_signatures, spectrum, winlen): """ Filter for significant concepts """ - if spectrum == '#': + if spectrum == "#": keep_concept = (len(concept[0]), len(concept[1])) not in ns_signatures - else: # spectrum == '3d#': + else: # spectrum == '3d#': # duration is fixed as the maximum lag duration = max(np.array(concept[0]) % winlen) - keep_concept = (len(concept[0]), len(concept[1]), - duration) not in ns_signatures + keep_concept = (len(concept[0]), len(concept[1]), duration) not in ns_signatures return keep_concept -def approximate_stability(concepts, rel_matrix, n_subsets=0, - delta=0., epsilon=0.): +def approximate_stability(concepts, rel_matrix, n_subsets=0, delta=0.0, epsilon=0.0): r""" Approximate the stability of concepts. Uses the algorithm described in Babin, Kuznetsov (2012): Approximating Concept Stability @@ -1809,11 +1985,16 @@ def approximate_stability(concepts, rel_matrix, n_subsets=0, rank = 0 size = 1 if not (isinstance(n_subsets, int) and n_subsets >= 0): - raise ValueError('n_subsets must be an integer >=0') - if n_subsets == 0 and not (isinstance(delta, float) and delta > 0. and - isinstance(epsilon, float) and epsilon > 0.): - raise ValueError('delta and epsilon must be floats > 0., ' - 'given that n_subsets = 0') + raise ValueError("n_subsets must be an integer >=0") + if n_subsets == 0 and not ( + isinstance(delta, float) + and delta > 0.0 + and isinstance(epsilon, float) + and epsilon > 0.0 + ): + raise ValueError( + "delta and epsilon must be floats > 0., " "given that n_subsets = 0" + ) if len(concepts) == 0: return [] @@ -1821,25 +2002,29 @@ def approximate_stability(concepts, rel_matrix, n_subsets=0, rank_idx = [0] * (size + 1) + [len(concepts)] else: rank_idx = list( - range(0, len(concepts) - len(concepts) % size + 1, - len(concepts) // size)) + [len(concepts)] + range(0, len(concepts) - len(concepts) % size + 1, len(concepts) // size) + ) + [len(concepts)] # Calculate optimal n if n_subsets == 0: - n_subsets = int(round(np.log(2. 
/ delta) / (2 * epsilon ** 2) + 1)) + n_subsets = int(round(np.log(2.0 / delta) / (2 * epsilon**2) + 1)) if rank == 0: - concepts_on_partition = concepts[rank_idx[rank]:rank_idx[rank + 1]] + \ - concepts[rank_idx[-2]:rank_idx[-1]] + concepts_on_partition = ( + concepts[rank_idx[rank] : rank_idx[rank + 1]] + + concepts[rank_idx[-2] : rank_idx[-1]] + ) else: - concepts_on_partition = concepts[rank_idx[rank]:rank_idx[rank + 1]] + concepts_on_partition = concepts[rank_idx[rank] : rank_idx[rank + 1]] output = [] for concept in concepts_on_partition: intent, extent = np.array(concept[0]), np.array(concept[1]) stab_int = _calculate_single_stability_parameter( - intent, extent, n_subsets, rel_matrix, look_at='intent') + intent, extent, n_subsets, rel_matrix, look_at="intent" + ) stab_ext = _calculate_single_stability_parameter( - intent, extent, n_subsets, rel_matrix, look_at='extent') + intent, extent, n_subsets, rel_matrix, look_at="extent" + ) output.append((intent, extent, stab_int, stab_ext)) if size != 1: @@ -1851,9 +2036,9 @@ def approximate_stability(concepts, rel_matrix, n_subsets=0, return output -def _calculate_single_stability_parameter(intent, extent, - n_subsets, rel_matrix, - look_at='intent'): +def _calculate_single_stability_parameter( + intent, extent, n_subsets, rel_matrix, look_at="intent" +): """ Calculates the stability parameter for extent or intent. @@ -1878,7 +2063,7 @@ def _calculate_single_stability_parameter(intent, extent, stability : float Stability parameter for given extent, intent depending on which to look """ - if look_at == 'intent': + if look_at == "intent": element_1, element_2 = intent, extent else: # look_at == 'extent': element_1, element_2 = extent, intent @@ -1886,24 +2071,27 @@ def _calculate_single_stability_parameter(intent, extent, if n_subsets > 2 ** len(element_1): subsets = chain.from_iterable( combinations(element_1, subset_index) - for subset_index in range(len(element_1) + 1)) + for subset_index in range(len(element_1) + 1) + ) else: subsets = _select_random_subsets(element_1, n_subsets) stability = 0 excluded_subsets = [] for subset in subsets: - if any([set(subset).issubset(excluded_subset) - for excluded_subset in excluded_subsets]): + if any( + [ + set(subset).issubset(excluded_subset) + for excluded_subset in excluded_subsets + ] + ): continue # computation of the ' operator for the subset - if look_at == 'intent': - subset_prime = \ - np.where(np.all(rel_matrix[:, subset], axis=1) == 1)[0] + if look_at == "intent": + subset_prime = np.where(np.all(rel_matrix[:, subset], axis=1) == 1)[0] else: # look_at == 'extent': - subset_prime = \ - np.where(np.all(rel_matrix[subset, :], axis=0) == 1)[0] + subset_prime = np.where(np.all(rel_matrix[subset, :], axis=0) == 1)[0] # Condition holds if the closure of the subset of element_1 given in # element_2 is equal to element_2 given in input @@ -1936,8 +2124,9 @@ def _select_random_subsets(element_1, n_subsets): while len(subsets) < n_subsets: num_indices = np.random.binomial(n=len(element_1), p=1 / 2) - random_indices = sorted(np.random.choice( - len(element_1), size=num_indices, replace=False)) + random_indices = sorted( + np.random.choice(len(element_1), size=num_indices, replace=False) + ) random_tuple = tuple(random_indices) if random_tuple not in subsets_indices[num_indices]: @@ -1947,9 +2136,17 @@ def _select_random_subsets(element_1, n_subsets): return subsets -def pattern_set_reduction(concepts, ns_signatures, winlen, spectrum, - h_subset_filtering=0, k_superset_filtering=0, - 
l_covered_spikes=0, min_spikes=2, min_occ=2): +def pattern_set_reduction( + concepts, + ns_signatures, + winlen, + spectrum, + h_subset_filtering=0, + k_superset_filtering=0, + l_covered_spikes=0, + min_spikes=2, + min_occ=2, +): r""" Takes a list concepts and performs pattern set reduction (PSR). @@ -2063,19 +2260,18 @@ def pattern_set_reduction(concepts, ns_signatures, winlen, spectrum, # Collecting all the possible distances between the windows # of the two concepts - time_diff_all = np.array( - [w2 - w1 for w2 in extent2 for w1 in extent1]) + time_diff_all = np.array([w2 - w1 for w2 in extent2 for w1 in extent1]) # sort time differences by ascending absolute value time_diff_sorting = np.argsort(np.abs(time_diff_all)) sorted_time_diff, sorted_time_diff_occ = np.unique( - time_diff_all[time_diff_sorting], - return_counts=True) + time_diff_all[time_diff_sorting], return_counts=True + ) # only consider time differences that are smaller than winlen # and that correspond to intersections that occur at least min_occ # times time_diff_mask = np.logical_and( - np.abs(sorted_time_diff) < winlen, - sorted_time_diff_occ >= min_occ) + np.abs(sorted_time_diff) < winlen, sorted_time_diff_occ >= min_occ + ) # Rescaling the spike times to realign to real time for time_diff in sorted_time_diff[time_diff_mask]: intent1_new = [t_old - time_diff for t_old in intent1] @@ -2099,7 +2295,8 @@ def pattern_set_reduction(concepts, ns_signatures, winlen, spectrum, k_superset_filtering=k_superset_filtering, l_covered_spikes=l_covered_spikes, min_spikes=min_spikes, - min_occ=min_occ) + min_occ=min_occ, + ) elif intent2.issuperset(intent1_new): reject2, reject1 = _perform_combined_filtering( @@ -2115,7 +2312,8 @@ def pattern_set_reduction(concepts, ns_signatures, winlen, spectrum, k_superset_filtering=k_superset_filtering, l_covered_spikes=l_covered_spikes, min_spikes=min_spikes, - min_occ=min_occ) + min_occ=min_occ, + ) else: # none of the intents is a superset of the other one @@ -2131,7 +2329,8 @@ def pattern_set_reduction(concepts, ns_signatures, winlen, spectrum, spectrum=spectrum, ns_signatures=ns_signatures, k_superset_filtering=k_superset_filtering, - min_spikes=min_spikes) + min_spikes=min_spikes, + ) reject2 = _superset_filter( occ_superset=occ2, size_superset=size2, @@ -2140,7 +2339,8 @@ def pattern_set_reduction(concepts, ns_signatures, winlen, spectrum, spectrum=spectrum, ns_signatures=ns_signatures, k_superset_filtering=k_superset_filtering, - min_spikes=min_spikes) + min_spikes=min_spikes, + ) # Reject accordingly: if reject1 and reject2: reject1, reject2 = _covered_spikes_criterion( @@ -2148,7 +2348,8 @@ def pattern_set_reduction(concepts, ns_signatures, winlen, spectrum, size_superset=size1, occ_subset=occ2, size_subset=size2, - l_covered_spikes=l_covered_spikes) + l_covered_spikes=l_covered_spikes, + ) selected[id1] &= not reject1 selected[id2] &= not reject2 @@ -2161,19 +2362,21 @@ def pattern_set_reduction(concepts, ns_signatures, winlen, spectrum, return [p for i, p in enumerate(concepts) if selected[i]] -def _perform_combined_filtering(occ_superset, - size_superset, - dur_superset, - occ_subset, - size_subset, - dur_subset, - spectrum, - ns_signatures, - h_subset_filtering, - k_superset_filtering, - l_covered_spikes, - min_spikes, - min_occ): +def _perform_combined_filtering( + occ_superset, + size_superset, + dur_superset, + occ_subset, + size_subset, + dur_subset, + spectrum, + ns_signatures, + h_subset_filtering, + k_superset_filtering, + l_covered_spikes, + min_spikes, + min_occ, +): """ 
perform combined filtering (see pattern_set_reduction) @@ -2186,7 +2389,8 @@ def _perform_combined_filtering(occ_superset, spectrum=spectrum, ns_signatures=ns_signatures, h_subset_filtering=h_subset_filtering, - min_occ=min_occ) + min_occ=min_occ, + ) reject_superset = _superset_filter( occ_superset=occ_superset, size_superset=size_superset, @@ -2195,7 +2399,8 @@ def _perform_combined_filtering(occ_superset, spectrum=spectrum, ns_signatures=ns_signatures, k_superset_filtering=k_superset_filtering, - min_spikes=min_spikes) + min_spikes=min_spikes, + ) # Reject the superset and/or the subset accordingly: if reject_superset and reject_subset: reject_superset, reject_subset = _covered_spikes_criterion( @@ -2203,12 +2408,21 @@ def _perform_combined_filtering(occ_superset, size_superset=size_superset, occ_subset=occ_subset, size_subset=size_subset, - l_covered_spikes=l_covered_spikes) + l_covered_spikes=l_covered_spikes, + ) return reject_superset, reject_subset -def _subset_filter(occ_superset, occ_subset, size_subset, dur_subset, spectrum, - ns_signatures=None, h_subset_filtering=0, min_occ=2): +def _subset_filter( + occ_superset, + occ_subset, + size_subset, + dur_subset, + spectrum, + ns_signatures=None, + h_subset_filtering=0, + min_occ=2, +): """ perform subset filtering (see pattern_set_reduction) @@ -2216,7 +2430,7 @@ def _subset_filter(occ_superset, occ_subset, size_subset, dur_subset, spectrum, if ns_signatures is None: ns_signatures = [] occ_diff = occ_subset - occ_superset + h_subset_filtering - if spectrum == '#': + if spectrum == "#": signature_to_test = (size_subset, occ_diff) else: # spectrum == '3d#': signature_to_test = (size_subset, occ_diff, dur_subset) @@ -2224,9 +2438,16 @@ def _subset_filter(occ_superset, occ_subset, size_subset, dur_subset, spectrum, return reject_subset -def _superset_filter(occ_superset, size_superset, dur_superset, size_subset, - spectrum, ns_signatures=None, k_superset_filtering=0, - min_spikes=2): +def _superset_filter( + occ_superset, + size_superset, + dur_superset, + size_subset, + spectrum, + ns_signatures=None, + k_superset_filtering=0, + min_spikes=2, +): """ perform superset filtering (see pattern_set_reduction) @@ -2234,20 +2455,17 @@ def _superset_filter(occ_superset, size_superset, dur_superset, size_subset, if ns_signatures is None: ns_signatures = [] size_diff = size_superset - size_subset + k_superset_filtering - if spectrum == '#': + if spectrum == "#": signature_to_test = (size_diff, occ_superset) else: # spectrum == '3d#': signature_to_test = (size_diff, occ_superset, dur_superset) - reject_superset = \ - size_diff < min_spikes or signature_to_test in ns_signatures + reject_superset = size_diff < min_spikes or signature_to_test in ns_signatures return reject_superset -def _covered_spikes_criterion(occ_superset, - size_superset, - occ_subset, - size_subset, - l_covered_spikes): +def _covered_spikes_criterion( + occ_superset, size_superset, occ_subset, size_subset, l_covered_spikes +): """ evaluate covered spikes criterion (see pattern_set_reduction) @@ -2263,8 +2481,9 @@ def _covered_spikes_criterion(occ_superset, return reject_superset, reject_subset -def concept_output_to_patterns(concepts, winlen, bin_size, pv_spec=None, - spectrum='#', t_start=0 * pq.ms): +def concept_output_to_patterns( + concepts, winlen, bin_size, pv_spec=None, spectrum="#", t_start=0 * pq.ms +): """ Construction of dictionaries containing all the information about a pattern starting from a list of concepts and its associated pvalue_spectrum. 
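# A minimal sketch (with hypothetical toy values) of the itemset decoding that
# concept_output_to_patterns performs below: each item encodes
# spiketrain_id * winlen + bin_id, so integer division recovers the neuron ids
# and the modulo gives the within-window bin, from which the lag entries in
# bin_size units are built.
import numpy as np
import quantities as pq

winlen, bin_size = 5, 1 * pq.ms
itemset = np.array([0, 11, 7])  # hypothetical: neuron 0 @ bin 0, neuron 2 @ bin 1, neuron 1 @ bin 2

bin_ids_unsort = itemset % winlen         # [0, 1, 2]
order = np.argsort(bin_ids_unsort)
bin_ids = bin_ids_unsort[order]
neurons = list(itemset[order] // winlen)  # [0, 2, 1], ordered by bin position
lags = bin_ids[1:] * bin_size             # [1, 2] ms, lags of later spikes after the first (bin 0)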
@@ -2327,9 +2546,9 @@ def concept_output_to_patterns(concepts, winlen, bin_size, pv_spec=None, pvalue_dict = defaultdict(float) # Creating a dictionary for the pvalue spectrum for entry in pv_spec: - if spectrum == '3d#': + if spectrum == "3d#": pvalue_dict[(entry[0], entry[1], entry[2])] = entry[-1] - if spectrum == '#': + if spectrum == "#": pvalue_dict[(entry[0], entry[1])] = entry[-1] # Initializing list containing all the patterns t_start = t_start.rescale(bin_size.units) @@ -2341,21 +2560,21 @@ def concept_output_to_patterns(concepts, winlen, bin_size, pv_spec=None, # is represented as spiketrain_id * winlen + bin_id # - The ids of the windows in which the pattern occurred in discretized # time (clipping) - output_dict = {'itemset': itemset, 'windows_ids': window_ids} + output_dict = {"itemset": itemset, "windows_ids": window_ids} # Bins relative to the sliding window in which the spikes of patt fall itemset = np.array(itemset) bin_ids_unsort = itemset % winlen order_bin_ids = np.argsort(bin_ids_unsort) bin_ids = bin_ids_unsort[order_bin_ids] # id of the neurons forming the pattern - output_dict['neurons'] = list(itemset[order_bin_ids] // winlen) + output_dict["neurons"] = list(itemset[order_bin_ids] // winlen) # Lags (in bin_sizes units) of the pattern - output_dict['lags'] = bin_ids[1:] * bin_size + output_dict["lags"] = bin_ids[1:] * bin_size # Times (in bin_size units) in which the pattern occurs - output_dict['times'] = sorted(window_ids) * bin_size + t_start + output_dict["times"] = sorted(window_ids) * bin_size + t_start # pattern dictionary appended to the output - if spectrum == '#': + if spectrum == "#": # Signature (size, n occ) of the pattern signature = (len(itemset), len(window_ids)) else: # spectrum == '3d#': @@ -2363,14 +2582,14 @@ def concept_output_to_patterns(concepts, winlen, bin_size, pv_spec=None, # duration is position of the last bin signature = (len(itemset), len(window_ids), bin_ids[-1]) - output_dict['signature'] = signature + output_dict["signature"] = signature # If None is given in input to the pval spectrum the pvalue # is set to -1 (pvalue spectrum not available) if pv_spec is None: - output_dict['pvalue'] = -1 + output_dict["pvalue"] = -1 else: # p-value assigned to the pattern from the pvalue spectrum - output_dict['pvalue'] = pvalue_dict[signature] + output_dict["pvalue"] = pvalue_dict[signature] output.append(output_dict) return output diff --git a/elephant/spade_src/fast_fca.py b/elephant/spade_src/fast_fca.py index 689380a05..bd420130c 100644 --- a/elephant/spade_src/fast_fca.py +++ b/elephant/spade_src/fast_fca.py @@ -43,9 +43,8 @@ class FormalConcept(object): and lectically ordered lists of upper and lower neighbours. 
""" - def __init__(self, extent=frozenset(), intent=frozenset(), - intentIndexes=[]): - """ intent/extent are a frozensets because they need to be hashable.""" + def __init__(self, extent=frozenset(), intent=frozenset(), intentIndexes=[]): + """intent/extent are a frozensets because they need to be hashable.""" self.cnum = 0 self.extent = extent self.intent = intent @@ -86,16 +85,14 @@ def __lt__(self, other): return 1 def __repr__(self): - """ print the concept.""" + """print the concept.""" strrep = "concept no:" + str(self.cnum) + "\n" strrep += "extent:" + repr(self.extent) + "\n" strrep += "intent:" + repr(self.intent) + "\n" strrep += "introduced objects:" + repr(self.introducedObjects) + "\n" - strrep += "introduced attributes:" + repr( - self.introducedAttributes) + "\n" + strrep += "introduced attributes:" + repr(self.introducedAttributes) + "\n" if hasattr(self, "stability"): - strrep += "stability: {0:1.4f}".format( - self.stability) + "\n" + strrep += "stability: {0:1.4f}".format(self.stability) + "\n" strrep += "upper neighbours: " for un in self.upperNeighbours: strrep += str(un.cnum) + ", " @@ -144,7 +141,7 @@ def __init__(self, relation, objects=None, attributes=None): def objectsPrime(self, objectSet): """return a frozenset of all attributes which are shared by members - of objectSet. """ + of objectSet.""" if len(objectSet) == 0: return frozenset(self.attributes) oiter = iter(objectSet) @@ -155,7 +152,7 @@ def objectsPrime(self, objectSet): def attributesPrime(self, attributeSet): """return a set of all objects which have all attributes in - attribute set. """ + attribute set.""" if len(attributeSet) == 0: return frozenset(self.objects) aiter = iter(attributeSet) @@ -170,19 +167,19 @@ def updateIntent(self, intent, object): def indexList(self, attributeSet): """return ordered list of attribute indexes. For lectic ordering of - concepts. """ + concepts.""" ilist = sorted(map(self.attributes.index, attributeSet)) return ilist class FormalConcepts(object): - """ Computes set of concepts from a binary relation by an algorithm + """Computes set of concepts from a binary relation by an algorithm similar to C. Lindig's Fast Concept Analysis (2002). """ def __init__(self, relation, objects=None, attributes=None, verbose=False): - """ 'relation' has to be an iterable container of tuples. If objects - or attributes are not supplied, determine from relation. """ + """'relation' has to be an iterable container of tuples. 
If objects + or attributes are not supplied, determine from relation.""" self.context = FormalContext(relation, objects, attributes) self.concepts = [] # a lectically ordered list of concepts" self.intentToConceptDict = dict() @@ -198,7 +195,8 @@ def computeUpperNeighbours(self, concept): # might therefore be used to create upper neighbours via ((G u g)'',(G # u g)') upperNeighbourGeneratingObjects = self.context.objects.difference( - concept.extent) + concept.extent + ) # dictionary of intent => set of generating objects upperNeighbourCandidates = defaultdict(set) for g in upperNeighbourGeneratingObjects: @@ -212,8 +210,9 @@ def computeUpperNeighbours(self, concept): # Store every concept in self.conceptDict, because it will # eventually be used and the closure is expensive to compute extent = self.context.attributesPrime(intent) - curConcept = FormalConcept(extent, intent, - self.context.indexList(intent)) + curConcept = FormalConcept( + extent, intent, self.context.indexList(intent) + ) self.intentToConceptDict[intent] = curConcept # remember which g generated what concept @@ -225,13 +224,14 @@ def computeUpperNeighbours(self, concept): # and only if (G u g)'' \ G = set of all g which generated C. for intent, generatingObjects in upperNeighbourCandidates.items(): extraObjects = self.intentToConceptDict[intent].extent.difference( - concept.extent) + concept.extent + ) if extraObjects == generatingObjects: neighbours.append(self.intentToConceptDict[intent]) return neighbours def numberConceptsAndComputeIntroduced(self): - """ Numbers concepts and computes introduced objects and attributes""" + """Numbers concepts and computes introduced objects and attributes""" for curConNum, curConcept in enumerate(self.concepts): curConcept.cnum = curConNum curConcept.introducedObjects = set(curConcept.extent) @@ -242,25 +242,24 @@ def numberConceptsAndComputeIntroduced(self): curConcept.introducedAttributes.difference_update(un.intent) def computeLattice(self): - """ Computes concepts and lattice. + """Computes concepts and lattice. 
self.concepts contains lectically ordered list of concepts after completion.""" intent = self.context.objectsPrime(set()) extent = self.context.attributesPrime(intent) - curConcept = FormalConcept(extent, intent, - self.context.indexList(intent)) + curConcept = FormalConcept(extent, intent, self.context.indexList(intent)) self.concepts = [curConcept] self.intentToConceptDict[curConcept.intent] = curConcept curConceptIndex = 0 - progress_bar = tqdm.tqdm(disable=not self.verbose, - desc="computeLattice") + progress_bar = tqdm.tqdm(disable=not self.verbose, desc="computeLattice") while curConceptIndex >= 0: upperNeighbours = self.computeUpperNeighbours(curConcept) for upperNeighbour in upperNeighbours: - upperNeighbourIndex = bisect.bisect(self.concepts, - upperNeighbour) - if upperNeighbourIndex == 0 or self.concepts[ - upperNeighbourIndex - 1] != upperNeighbour: + upperNeighbourIndex = bisect.bisect(self.concepts, upperNeighbour) + if ( + upperNeighbourIndex == 0 + or self.concepts[upperNeighbourIndex - 1] != upperNeighbour + ): self.concepts.insert(upperNeighbourIndex, upperNeighbour) curConceptIndex += 1 diff --git a/elephant/spectral.py b/elephant/spectral.py index ca46033ad..c4ccd1e70 100644 --- a/elephant/spectral.py +++ b/elephant/spectral.py @@ -31,14 +31,24 @@ "multitaper_psd", "multitaper_cross_spectrum", "segmented_multitaper_cross_spectrum", - "multitaper_coherence" + "multitaper_coherence", ] -def welch_psd(signal, n_segments=8, len_segment=None, - frequency_resolution=None, overlap=0.5, fs=1.0, window='hann', - nfft=None, detrend='constant', return_onesided=True, - scaling='density', axis=-1): +def welch_psd( + signal, + n_segments=8, + len_segment=None, + frequency_resolution=None, + overlap=0.5, + fs=1.0, + window="hann", + nfft=None, + detrend="constant", + return_onesided=True, + scaling="density", + axis=-1, +): """ Estimates power spectrum density (PSD) of a given `neo.AnalogSignal` using Welch's method. @@ -203,30 +213,37 @@ def welch_psd(signal, n_segments=8, len_segment=None, """ # 'hanning' window was removed with release of scipy 1.9.0, it was # deprecated since 1.1.0. - if window == 'hanning': - warnings.warn("'hanning' is deprecated and was removed from scipy " - "with release 1.9.0. Please use 'hann' instead", - DeprecationWarning) - window = 'hann' + if window == "hanning": + warnings.warn( + "'hanning' is deprecated and was removed from scipy " + "with release 1.9.0. Please use 'hann' instead", + DeprecationWarning, + ) + window = "hann" # initialize a parameter dict (to be given to scipy.signal.welch()) with # the parameters directly passed on to scipy.signal.welch() - params = {'window': window, 'nfft': nfft, - 'detrend': detrend, 'return_onesided': return_onesided, - 'scaling': scaling, 'axis': axis} + params = { + "window": window, + "nfft": nfft, + "detrend": detrend, + "return_onesided": return_onesided, + "scaling": scaling, + "axis": axis, + } # add the input data to params. 
When the input is AnalogSignal, the # data is added after rolling the axis for time index to the last data = np.asarray(signal) if isinstance(signal, neo.AnalogSignal): data = np.rollaxis(data, 0, len(data.shape)) - params['x'] = data + params["x"] = data # if the data is given as AnalogSignal, use its attribute to specify # the sampling frequency - if hasattr(signal, 'sampling_rate'): - params['fs'] = signal.sampling_rate.rescale('Hz').magnitude + if hasattr(signal, "sampling_rate"): + params["fs"] = signal.sampling_rate.rescale("Hz").magnitude else: - params['fs'] = fs + params["fs"] = fs if overlap < 0: raise ValueError("overlap must be greater than or equal to 0") @@ -239,13 +256,14 @@ def welch_psd(signal, n_segments=8, len_segment=None, if frequency_resolution <= 0: raise ValueError("frequency_resolution must be positive") if isinstance(frequency_resolution, pq.quantity.Quantity): - dF = frequency_resolution.rescale('Hz').magnitude + dF = frequency_resolution.rescale("Hz").magnitude else: dF = frequency_resolution - nperseg = int(params['fs'] / dF) + nperseg = int(params["fs"] / dF) if nperseg > data.shape[axis]: - raise ValueError("frequency_resolution is too high for the given " - "data size") + raise ValueError( + "frequency_resolution is too high for the given " "data size" + ) elif len_segment is not None: if len_segment <= 0: raise ValueError("len_seg must be a positive number") @@ -263,16 +281,15 @@ def welch_psd(signal, n_segments=8, len_segment=None, # data.shape[-1] # -------------------- =============================== ^^^^^^^^^^^ # summed segment lengths total overlap data length - nperseg = int(data.shape[axis] / (n_segments - overlap * ( - n_segments - 1))) - params['nperseg'] = nperseg - params['noverlap'] = int(nperseg * overlap) + nperseg = int(data.shape[axis] / (n_segments - overlap * (n_segments - 1))) + params["nperseg"] = nperseg + params["noverlap"] = int(nperseg * overlap) freqs, psd = scipy.signal.welch(**params) # attach proper units to return values if isinstance(signal, pq.quantity.Quantity): - if 'scaling' in params and params['scaling'] == 'spectrum': + if "scaling" in params and params["scaling"] == "spectrum": psd = psd * signal.units * signal.units else: psd = psd * signal.units * signal.units / pq.Hz @@ -281,8 +298,9 @@ def welch_psd(signal, n_segments=8, len_segment=None, return freqs, psd -def multitaper_psd(signal, fs=1, nw=4, num_tapers=None, peak_resolution=None, - attach_units=True): +def multitaper_psd( + signal, fs=1, nw=4, num_tapers=None, peak_resolution=None, attach_units=True +): """ Estimates power spectrum density (PSD) of a given 'neo.AnalogSignal' using the Multitaper method. 
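# A minimal sketch (with hypothetical random data) of the multitaper estimate
# implemented below: the signal is multiplied by Slepian (DPSS) tapers, the
# squared one-sided Fourier amplitudes are computed per taper, and the
# taper-wise spectra are averaged and scaled by the sampling rate.
import numpy as np
import scipy.signal

fs, n_samples = 1000.0, 1000
data = np.random.randn(n_samples)

nw = 4                                   # time-halfbandwidth product
num_tapers = int(np.floor(2 * nw) - 1)   # default number of tapers used below
tapers = scipy.signal.windows.dpss(M=n_samples, NW=nw, Kmax=num_tapers, sym=False)

tapered = data[np.newaxis, :] * tapers                   # (num_tapers, n_samples)
estimates = np.abs(np.fft.rfft(tapered, axis=-1)) ** 2   # per-taper spectra
estimates[..., 1:] *= 2                                  # one-sided scaling
psd = estimates.mean(axis=0) / fs
freqs = np.fft.rfftfreq(n_samples, d=1 / fs)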
@@ -357,7 +375,7 @@ def multitaper_psd(signal, fs=1, nw=4, num_tapers=None, peak_resolution=None, fs = signal.sampling_rate if isinstance(fs, pq.Quantity): - fs = fs.rescale('Hz').magnitude + fs = fs.rescale("Hz").magnitude # Add a dim if data has only one dimension if data.ndim == 1: @@ -367,17 +385,17 @@ def multitaper_psd(signal, fs=1, nw=4, num_tapers=None, peak_resolution=None, # If peak resolution is pq.Quantity, get magnitude if isinstance(peak_resolution, pq.quantity.Quantity): - peak_resolution = peak_resolution.rescale('Hz').magnitude + peak_resolution = peak_resolution.rescale("Hz").magnitude # Determine time-halfbandwidth product from given parameters if peak_resolution is not None: if peak_resolution <= 0: raise ValueError("peak_resolution must be positive") nw = length_signal / fs * peak_resolution / 2 - num_tapers = int(np.floor(2*nw) - 1) + num_tapers = int(np.floor(2 * nw) - 1) if num_tapers is None: - num_tapers = int(np.floor(2*nw) - 1) + num_tapers = int(np.floor(2 * nw) - 1) else: if not isinstance(num_tapers, int): raise TypeError("num_tapers must be integer") @@ -385,13 +403,12 @@ def multitaper_psd(signal, fs=1, nw=4, num_tapers=None, peak_resolution=None, raise ValueError("num_tapers must be positive") # Generate frequencies of PSD estimate - freqs = np.fft.rfftfreq(length_signal, d=1/fs) + freqs = np.fft.rfftfreq(length_signal, d=1 / fs) # Generate Slepian sequences - slepian_fcts = scipy.signal.windows.dpss(M=length_signal, - NW=nw, - Kmax=num_tapers, - sym=False) + slepian_fcts = scipy.signal.windows.dpss( + M=length_signal, NW=nw, Kmax=num_tapers, sym=False + ) # Calculate approximately independent spectrum estimates # Use broadcasting to match dim for point-wise multiplication @@ -399,7 +416,7 @@ def multitaper_psd(signal, fs=1, nw=4, num_tapers=None, peak_resolution=None, tapered_signal = data[:, np.newaxis, :] * slepian_fcts # Determine Fourier transform of tapered signal - spectrum_estimates = np.abs(np.fft.rfft(tapered_signal, axis=-1))**2 + spectrum_estimates = np.abs(np.fft.rfft(tapered_signal, axis=-1)) ** 2 spectrum_estimates[..., 1:] *= 2 # Average Fourier transform windowed signal @@ -413,9 +430,17 @@ def multitaper_psd(signal, fs=1, nw=4, num_tapers=None, peak_resolution=None, return freqs, psd -def segmented_multitaper_psd(signal, n_segments=1, len_segment=None, - frequency_resolution=None, overlap=0.5, fs=1, - nw=4, num_tapers=None, peak_resolution=None): +def segmented_multitaper_psd( + signal, + n_segments=1, + len_segment=None, + frequency_resolution=None, + overlap=0.5, + fs=1, + nw=4, + num_tapers=None, + peak_resolution=None, +): """ Estimates power spectrum density (PSD) of a given 'neo.AnalogSignal' using Multitaper method @@ -537,16 +562,22 @@ def segmented_multitaper_psd(signal, n_segments=1, len_segment=None, length_signal = np.shape(data)[1] psd_params_dict = { - 'nw': nw, - 'num_tapers': num_tapers, - 'peak_resolution': peak_resolution, - 'attach_units': False + "nw": nw, + "num_tapers": num_tapers, + "peak_resolution": peak_resolution, + "attach_units": False, } freqs, psd = _segmented_apply_func( - data=data, fs=fs, n_segments=n_segments, len_segment=len_segment, - overlap=overlap, frequency_resolution=frequency_resolution, - func=multitaper_psd, func_params_dict=psd_params_dict) + data=data, + fs=fs, + n_segments=n_segments, + len_segment=len_segment, + overlap=overlap, + frequency_resolution=frequency_resolution, + func=multitaper_psd, + func_params_dict=psd_params_dict, + ) # Attach proper units to return values if 
isinstance(signal, pq.quantity.Quantity): @@ -556,9 +587,15 @@ def segmented_multitaper_psd(signal, n_segments=1, len_segment=None, return freqs, psd -def multitaper_cross_spectrum(signals, fs=1.0, nw=4.0, num_tapers=None, - peak_resolution=None, return_onesided=True, - attach_units=True): +def multitaper_cross_spectrum( + signals, + fs=1.0, + nw=4.0, + num_tapers=None, + peak_resolution=None, + return_onesided=True, + attach_units=True, +): """ Estimates the cross spectrum of a given `neo.AnalogSignal` using the Multitaper method. @@ -655,24 +692,24 @@ def multitaper_cross_spectrum(signals, fs=1.0, nw=4.0, num_tapers=None, # If the data is given as AnalogSignal, use its attribute to specify the # sampling frequency - if hasattr(signals, 'sampling_rate'): - fs = signals.sampling_rate.rescale('Hz') + if hasattr(signals, "sampling_rate"): + fs = signals.sampling_rate.rescale("Hz") # If fs and peak resolution is pq.Quantity, get magnitude if isinstance(fs, pq.quantity.Quantity): - fs = fs.rescale('Hz').magnitude + fs = fs.rescale("Hz").magnitude # Determine time-halfbandwidth product from given parameters if peak_resolution is not None: if isinstance(peak_resolution, pq.quantity.Quantity): - peak_resolution = peak_resolution.rescale('Hz').magnitude + peak_resolution = peak_resolution.rescale("Hz").magnitude if peak_resolution <= 0: raise ValueError("peak_resolution must be positive") nw = length_signal / fs * peak_resolution / 2 - num_tapers = int(np.floor(2*nw) - 1) + num_tapers = int(np.floor(2 * nw) - 1) if num_tapers is None: - num_tapers = int(np.floor(2*nw) - 1) + num_tapers = int(np.floor(2 * nw) - 1) else: if not isinstance(num_tapers, int): raise TypeError("num_tapers must be integer") @@ -680,13 +717,12 @@ def multitaper_cross_spectrum(signals, fs=1.0, nw=4.0, num_tapers=None, raise ValueError("num_tapers must be positive") # Get slepian functions (sym='False' used for spectral analysis) - slepian_fcts = scipy.signal.windows.dpss(M=length_signal, - NW=nw, - Kmax=num_tapers, - sym=False) + slepian_fcts = scipy.signal.windows.dpss( + M=length_signal, NW=nw, Kmax=num_tapers, sym=False + ) # Use broadcasting to match dime for point-wise multiplication - tapered_signals = (data[:, np.newaxis] * slepian_fcts) + tapered_signals = data[:, np.newaxis] * slepian_fcts if return_onesided: # Determine frequencies for real Fourier transform @@ -700,8 +736,9 @@ def multitaper_cross_spectrum(signals, fs=1.0, nw=4.0, num_tapers=None, # Determine full Fourier transform of tapered signal spectrum_estimates = np.fft.fft(tapered_signals, axis=-1) - temp = (spectrum_estimates[np.newaxis, :, :, :] - * np.conjugate(spectrum_estimates[:, np.newaxis, :, :])) + temp = spectrum_estimates[np.newaxis, :, :, :] * np.conjugate( + spectrum_estimates[:, np.newaxis, :, :] + ) # Average Fourier transform windowed signal cross_spec = np.mean(temp, axis=-2, dtype=np.complex64) / fs @@ -713,9 +750,16 @@ def multitaper_cross_spectrum(signals, fs=1.0, nw=4.0, num_tapers=None, return freqs, cross_spec -def _segmented_apply_func(data, func, fs=1.0, n_segments=1, len_segment=None, - frequency_resolution=None, overlap=0.5, - func_params_dict=None): +def _segmented_apply_func( + data, + func, + fs=1.0, + n_segments=1, + len_segment=None, + frequency_resolution=None, + overlap=0.5, + func_params_dict=None, +): """ Estimate a spectral measure of a signal by applying it to segments of a signal and then averaging over the segments. 
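The segment-length bookkeeping in _segmented_apply_func (and in welch_psd above) follows the commented relation in these hunks: n_segments segments of length n_per_seg, overlapping by the fraction overlap, must tile the data. A few lines of throwaway arithmetic, with illustrative numbers only, make the relation concrete:

    # n_per_seg * (n_segments - overlap * (n_segments - 1)) ~= n_samples
    n_samples, n_segments, overlap = 10_000, 8, 0.5

    n_per_seg = int(n_samples / (n_segments - overlap * (n_segments - 1)))
    n_overlap = int(n_per_seg * overlap)
    step = n_per_seg - n_overlap

    # the last segment must still end inside the data
    assert step * (n_segments - 1) + n_per_seg <= n_samples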
@@ -812,20 +856,21 @@ def _segmented_apply_func(data, func, fs=1.0, n_segments=1, len_segment=None, # If fs and peak resolution is pq.Quantity, get magnitude if isinstance(fs, pq.quantity.Quantity): - fs = fs.rescale('Hz').magnitude + fs = fs.rescale("Hz").magnitude # Determine length per segment - n_per_seg if frequency_resolution is not None: if frequency_resolution <= 0: raise ValueError("frequency_resolution must be positive") if isinstance(frequency_resolution, pq.quantity.Quantity): - dF = frequency_resolution.rescale('Hz').magnitude + dF = frequency_resolution.rescale("Hz").magnitude else: dF = frequency_resolution n_per_seg = int(fs / dF) if n_per_seg > data.shape[axis]: - raise ValueError("frequency_resolution is too high for the given " - "data size") + raise ValueError( + "frequency_resolution is too high for the given " "data size" + ) elif len_segment is not None: if len_segment <= 0: raise ValueError("len_seg must be a positive number") @@ -843,42 +888,38 @@ def _segmented_apply_func(data, func, fs=1.0, n_segments=1, len_segment=None, # data.shape[-1] # -------------------- =============================== ^^^^^^^^^^^ # summed segment lengths total overlap data length - n_per_seg = int(data.shape[axis] / - (n_segments - overlap * (n_segments - 1))) + n_per_seg = int(data.shape[axis] / (n_segments - overlap * (n_segments - 1))) n_overlap = int(n_per_seg * overlap) n_overlap_step = n_per_seg - n_overlap n_segments = int((length_signal - n_overlap) / (n_per_seg - n_overlap)) # Generate frequencies for spectral measure estimate - if func_params_dict.get('return_onesided'): - freqs = np.fft.rfftfreq(n_per_seg, d=1/fs) - elif func_params_dict.get('return_onesided') is None: + if func_params_dict.get("return_onesided"): + freqs = np.fft.rfftfreq(n_per_seg, d=1 / fs) + elif func_params_dict.get("return_onesided") is None: # multitaper_psd uses rfft (i.e. 
no return_onesided parameter) - freqs = np.fft.rfftfreq(n_per_seg, d=1/fs) + freqs = np.fft.rfftfreq(n_per_seg, d=1 / fs) else: - freqs = np.fft.fftfreq(n_per_seg, d=1/fs) + freqs = np.fft.fftfreq(n_per_seg, d=1 / fs) # Zero-pad signal to fit segment length remainder = length_signal % n_overlap_step - data = np.pad(data, [(0, 0), (0, remainder)], - mode='constant', constant_values=0) + data = np.pad(data, [(0, 0), (0, remainder)], mode="constant", constant_values=0) # Generate array for storing cross spectra estimates of segments - seg_estimates = np.zeros((n_segments, - data.shape[0], - data.shape[0], - len(freqs)), - dtype=np.complex64) + seg_estimates = np.zeros( + (n_segments, data.shape[0], data.shape[0], len(freqs)), dtype=np.complex64 + ) n_overlap_step = n_per_seg - n_overlap for i in range(n_segments): - _, estimate = func( - data[:, i * n_overlap_step:i * n_overlap_step + n_per_seg], - **func_params_dict) + data[:, i * n_overlap_step : i * n_overlap_step + n_per_seg], + **func_params_dict, + ) # Workaround for mismatched dimensions if estimate.ndim != seg_estimates.ndim - 1: # Multitaper PSD @@ -891,12 +932,18 @@ def _segmented_apply_func(data, func, fs=1.0, n_segments=1, len_segment=None, return freqs, avg_estimate -def segmented_multitaper_cross_spectrum(signals, n_segments=1, - len_segment=None, - frequency_resolution=None, overlap=0.5, - fs=1.0, nw=4.0, num_tapers=None, - peak_resolution=None, - return_onesided=True): +def segmented_multitaper_cross_spectrum( + signals, + n_segments=1, + len_segment=None, + frequency_resolution=None, + overlap=0.5, + fs=1.0, + nw=4.0, + num_tapers=None, + peak_resolution=None, + return_onesided=True, +): """ Estimates the cross spectrum of a given `neo.AnalogSignal` using the Multitaper method on segments of the data. @@ -1008,18 +1055,24 @@ def segmented_multitaper_cross_spectrum(signals, n_segments=1, # Initialize argument dictionary for multitaper_cross_spectrum called on # segments cross_spec_params_dict = { - 'nw': nw, - 'num_tapers': num_tapers, - 'peak_resolution': peak_resolution, - 'return_onesided': return_onesided, - 'attach_units': False} + "nw": nw, + "num_tapers": num_tapers, + "peak_resolution": peak_resolution, + "return_onesided": return_onesided, + "attach_units": False, + } # Apply segmentation freqs, cross_spec = _segmented_apply_func( - data=data, fs=fs, n_segments=n_segments, len_segment=len_segment, - overlap=overlap, frequency_resolution=frequency_resolution, + data=data, + fs=fs, + n_segments=n_segments, + len_segment=len_segment, + overlap=overlap, + frequency_resolution=frequency_resolution, func=multitaper_cross_spectrum, - func_params_dict=cross_spec_params_dict) + func_params_dict=cross_spec_params_dict, + ) # Attach proper units to return values if isinstance(signals, pq.quantity.Quantity): @@ -1029,9 +1082,18 @@ def segmented_multitaper_cross_spectrum(signals, n_segments=1, return freqs, cross_spec -def multitaper_coherence(signal_i, signal_j, n_segments=1, len_segment=None, - frequency_resolution=None, overlap=0.5, fs=1, - nw=4, num_tapers=None, peak_resolution=None): +def multitaper_coherence( + signal_i, + signal_j, + n_segments=1, + len_segment=None, + frequency_resolution=None, + overlap=0.5, + fs=1, + nw=4, + num_tapers=None, + peak_resolution=None, +): """ Estimates the magnitude-squared coherence and phase-lag of two given `neo.AnalogSignal` using the Multitaper method. 
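A minimal sketch of a multitaper_coherence call, assuming only the signature and the np.ndarray branch shown in these hunks (the test signals are illustrative):

    import numpy as np
    from elephant.spectral import multitaper_coherence

    fs = 1000.0
    t = np.arange(0, 2, 1 / fs)
    # two noisy 25 Hz sines with a fixed 45 degree phase offset
    x = np.sin(2 * np.pi * 25 * t) + 0.3 * np.random.randn(t.size)
    y = np.sin(2 * np.pi * 25 * t + np.pi / 4) + 0.3 * np.random.randn(t.size)

    # plain np.ndarray inputs are stacked internally via np.vstack
    freqs, coherence, phase_lag = multitaper_coherence(
        x, y, fs=fs, n_segments=4, nw=4
    )

Near 25 Hz the magnitude-squared coherence should approach 1 and the phase lag should be about pi/4 in absolute value, matching the |Pxy|**2 / (Pxx * Pyy) and angle(Pxy) expressions applied further down.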
@@ -1101,16 +1163,24 @@ def multitaper_coherence(signal_i, signal_j, n_segments=1, len_segment=None, Phase lags associated with the magnitude-square coherence estimate """ - if isinstance(signal_i, neo.core.AnalogSignal) and \ - isinstance(signal_j, neo.core.AnalogSignal): + if isinstance(signal_i, neo.core.AnalogSignal) and isinstance( + signal_j, neo.core.AnalogSignal + ): signals = signal_i.merge(signal_j) elif isinstance(signal_i, np.ndarray) and isinstance(signal_j, np.ndarray): signals = np.vstack([signal_i, signal_j]) freqs, Pxy = segmented_multitaper_cross_spectrum( - signals=signals, n_segments=n_segments, len_segment=len_segment, - frequency_resolution=frequency_resolution, overlap=overlap, fs=fs, - nw=nw, num_tapers=num_tapers, peak_resolution=peak_resolution) + signals=signals, + n_segments=n_segments, + len_segment=len_segment, + frequency_resolution=frequency_resolution, + overlap=overlap, + fs=fs, + nw=nw, + num_tapers=num_tapers, + peak_resolution=peak_resolution, + ) # Calculate magnitude-squared coherence. coherence = np.abs(Pxy[0, 1]) ** 2 / (Pxy[0, 0].real * Pxy[1, 1].real) @@ -1120,10 +1190,20 @@ def multitaper_coherence(signal_i, signal_j, n_segments=1, len_segment=None, return freqs, coherence, phase_lag -def welch_coherence(signal_i, signal_j, n_segments=8, len_segment=None, - frequency_resolution=None, overlap=0.5, fs=1.0, - window='hann', nfft=None, detrend='constant', - scaling='density', axis=-1): +def welch_coherence( + signal_i, + signal_j, + n_segments=8, + len_segment=None, + frequency_resolution=None, + overlap=0.5, + fs=1.0, + window="hann", + nfft=None, + detrend="constant", + scaling="density", + axis=-1, +): r""" Estimates coherence between a given pair of analog signals. @@ -1277,15 +1357,22 @@ def welch_coherence(signal_i, signal_j, n_segments=8, len_segment=None, # TODO: code duplication with welch_psd() # 'hanning' window was removed with release of scipy 1.9.0, it was # deprecated since 1.1.0. - if window == 'hanning': - warnings.warn("'hanning' is deprecated and was removed from scipy " - "with release 1.9.0. Please use 'hann' instead", - DeprecationWarning) - window = 'hann' + if window == "hanning": + warnings.warn( + "'hanning' is deprecated and was removed from scipy " + "with release 1.9.0. 
Please use 'hann' instead", + DeprecationWarning, + ) + window = "hann" # initialize a parameter dict for scipy.signal.csd() - params = {'window': window, 'nfft': nfft, - 'detrend': detrend, 'scaling': scaling, 'axis': axis} + params = { + "window": window, + "nfft": nfft, + "detrend": detrend, + "scaling": scaling, + "axis": axis, + } # When the input is AnalogSignal, the axis for time index is rolled to # the last @@ -1297,10 +1384,10 @@ def welch_coherence(signal_i, signal_j, n_segments=8, len_segment=None, # if the data is given as AnalogSignal, use its attribute to specify # the sampling frequency - if hasattr(signal_i, 'sampling_rate'): - params['fs'] = signal_i.sampling_rate.rescale('Hz').magnitude + if hasattr(signal_i, "sampling_rate"): + params["fs"] = signal_i.sampling_rate.rescale("Hz").magnitude else: - params['fs'] = fs + params["fs"] = fs if overlap < 0: raise ValueError("overlap must be greater than or equal to 0") @@ -1311,13 +1398,14 @@ def welch_coherence(signal_i, signal_j, n_segments=8, len_segment=None, # parameters if frequency_resolution is not None: if isinstance(frequency_resolution, pq.quantity.Quantity): - dF = frequency_resolution.rescale('Hz').magnitude + dF = frequency_resolution.rescale("Hz").magnitude else: dF = frequency_resolution - nperseg = int(params['fs'] / dF) + nperseg = int(params["fs"] / dF) if nperseg > xdata.shape[axis]: - raise ValueError("frequency_resolution is too high for the given" - "data size") + raise ValueError( + "frequency_resolution is too high for the given" "data size" + ) elif len_segment is not None: if len_segment <= 0: raise ValueError("len_seg must be a positive number") @@ -1335,10 +1423,9 @@ def welch_coherence(signal_i, signal_j, n_segments=8, len_segment=None, # data.shape[-1] # ------------------- =============================== ^^^^^^^^^^^ # summed segment lengths total overlap data length - nperseg = int(xdata.shape[axis] / (n_segments - overlap * ( - n_segments - 1))) - params['nperseg'] = nperseg - params['noverlap'] = int(nperseg * overlap) + nperseg = int(xdata.shape[axis] / (n_segments - overlap * (n_segments - 1))) + params["nperseg"] = nperseg + params["noverlap"] = int(nperseg * overlap) freqs, Pxx = scipy.signal.welch(xdata, **params) _, Pyy = scipy.signal.welch(ydata, **params) @@ -1362,6 +1449,7 @@ def welch_coherence(signal_i, signal_j, n_segments=8, len_segment=None, def welch_cohere(*args, **kwargs): - warnings.warn("'welch_cohere' is deprecated; use 'welch_coherence'", - DeprecationWarning) + warnings.warn( + "'welch_cohere' is deprecated; use 'welch_coherence'", DeprecationWarning + ) return welch_coherence(*args, **kwargs) diff --git a/elephant/spike_train_correlation.py b/elephant/spike_train_correlation.py index 65ccf6800..8d96017b6 100644 --- a/elephant/spike_train_correlation.py +++ b/elephant/spike_train_correlation.py @@ -14,6 +14,7 @@ :copyright: Copyright 2014-2024 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. """ + from __future__ import division, print_function, unicode_literals import warnings @@ -31,7 +32,7 @@ "correlation_coefficient", "cross_correlation_histogram", "spike_time_tiling_coefficient", - "spike_train_timescale" + "spike_train_timescale", ] # The highest sparsity of the `BinnedSpikeTrain` matrix for which @@ -89,16 +90,20 @@ def get_valid_lags(binned_spiketrain_i, binned_spiketrain_j): if binned_spiketrain_i.n_bins < binned_spiketrain_j.n_bins: # ex. 1) lags range: [-2, 5] ms # ex. 
2) lags range: [1, 2] ms - left_edge = (binned_spiketrain_j._t_start - - binned_spiketrain_i._t_start) / bin_size - right_edge = (binned_spiketrain_j._t_stop - - binned_spiketrain_i._t_stop) / bin_size + left_edge = ( + binned_spiketrain_j._t_start - binned_spiketrain_i._t_start + ) / bin_size + right_edge = ( + binned_spiketrain_j._t_stop - binned_spiketrain_i._t_stop + ) / bin_size else: # ex. 3) lags range: [-1, 3] ms - left_edge = (binned_spiketrain_j._t_stop - - binned_spiketrain_i._t_stop) / bin_size - right_edge = (binned_spiketrain_j._t_start - - binned_spiketrain_i._t_start) / bin_size + left_edge = ( + binned_spiketrain_j._t_stop - binned_spiketrain_i._t_stop + ) / bin_size + right_edge = ( + binned_spiketrain_j._t_start - binned_spiketrain_i._t_start + ) / bin_size right_edge = int(right_edge) left_edge = int(left_edge) lags = np.arange(left_edge, right_edge + 1, dtype=np.int32) @@ -126,8 +131,7 @@ def correlate_memory(self, cch_mode): # 'valid' mode requires bins correction due to the shift in t_starts # 'full' and 'pad' modes don't need this correction if cch_mode == "valid": - if self.binned_spiketrain_i.n_bins > \ - self.binned_spiketrain_j.n_bins: + if self.binned_spiketrain_i.n_bins > self.binned_spiketrain_j.n_bins: st2_bin_idx_unique += right_edge else: st2_bin_idx_unique += left_edge @@ -144,14 +148,12 @@ def correlate_memory(self, cch_mode): # Compute the CCH at lags in left_edge,...,right_edge only for idx, i in enumerate(st1_bin_idx_unique): il = np.searchsorted(st2_bin_idx_unique, left_edge + i) - ir = np.searchsorted(st2_bin_idx_unique, - right_edge + i, side='right') + ir = np.searchsorted(st2_bin_idx_unique, right_edge + i, side="right") timediff = st2_bin_idx_unique[il:ir] - i - assert ((timediff >= left_edge) & ( - timediff <= right_edge)).all(), \ - 'Not all the entries of cch lie in the window' - cross_corr[timediff - left_edge] += ( - st1_spmat[idx] * st2_spmat[il:ir]) + assert ( + (timediff >= left_edge) & (timediff <= right_edge) + ).all(), "Not all the entries of cch lie in the window" + cross_corr[timediff - left_edge] += st1_spmat[idx] * st2_spmat[il:ir] st2_bin_idx_unique = st2_bin_idx_unique[il:] st2_spmat = st2_spmat[il:] return cross_corr @@ -175,14 +177,13 @@ def correlate_speed(self, cch_mode): st1_arr = self.binned_spiketrain_i.to_array()[0] st2_arr = self.binned_spiketrain_j.to_array()[0] left_edge, right_edge = self.window - if cch_mode == 'pad': + if cch_mode == "pad": # Zero padding to stay between left_edge and right_edge pad_width = min(max(-left_edge, 0), max(right_edge, 0)) - st2_arr = np.pad(st2_arr, pad_width=pad_width, mode='constant') - cch_mode = 'valid' + st2_arr = np.pad(st2_arr, pad_width=pad_width, mode="constant") + cch_mode = "valid" # Cross correlate the spike trains - cross_corr = scipy.signal.fftconvolve(st2_arr, st1_arr[::-1], - mode=cch_mode) + cross_corr = scipy.signal.fftconvolve(st2_arr, st1_arr[::-1], mode=cch_mode) # convolution of integers is integers cross_corr = np.round(cross_corr) return cross_corr @@ -200,11 +201,13 @@ def border_correction(self, cross_corr): np.ndarray Cross-correlation array with the border correction applied. 
""" - min_num_bins = min(self.binned_spiketrain_i.n_bins, - self.binned_spiketrain_j.n_bins) + min_num_bins = min( + self.binned_spiketrain_i.n_bins, self.binned_spiketrain_j.n_bins + ) left_edge, right_edge = self.window - valid_lags = _CrossCorrHist.get_valid_lags(self.binned_spiketrain_i, - self.binned_spiketrain_j) + valid_lags = _CrossCorrHist.get_valid_lags( + self.binned_spiketrain_i, self.binned_spiketrain_j + ) lags_to_compute = np.arange(left_edge, right_edge + 1) outer_subtraction = np.subtract.outer(lags_to_compute, valid_lags) min_distance_from_window = np.abs(outer_subtraction).min(axis=1) @@ -232,8 +235,9 @@ def cross_correlation_coefficient(self, cross_corr): np.ndarray Normalized cross-correlation array in range `[-1, 1]`. """ - max_num_bins = max(self.binned_spiketrain_i.n_bins, - self.binned_spiketrain_j.n_bins) + max_num_bins = max( + self.binned_spiketrain_i.n_bins, self.binned_spiketrain_j.n_bins + ) n_spikes1 = self.binned_spiketrain_i.get_num_of_spikes() n_spikes2 = self.binned_spiketrain_j.get_num_of_spikes() data1 = self.binned_spiketrain_i.sparse_matrix.data @@ -241,8 +245,9 @@ def cross_correlation_coefficient(self, cross_corr): ii = data1.dot(data1) jj = data2.dot(data2) cov_mean = n_spikes1 * n_spikes2 / max_num_bins - std_xy = np.sqrt((ii - n_spikes1 ** 2. / max_num_bins) * ( - jj - n_spikes2 ** 2. / max_num_bins)) + std_xy = np.sqrt( + (ii - n_spikes1**2.0 / max_num_bins) * (jj - n_spikes2**2.0 / max_num_bins) + ) cross_corr_normalized = (cross_corr - cov_mean) / std_xy return cross_corr_normalized @@ -268,12 +273,12 @@ def kernel_smoothing(self, cross_corr_array, kernel): # Define the kern for smoothing as an ndarray if len(kernel) > kern_len_max: raise ValueError( - 'The length of the kernel {} cannot be larger than the ' - 'length {} of the resulting CCH.'.format(len(kernel), - kern_len_max)) + "The length of the kernel {} cannot be larger than the " + "length {} of the resulting CCH.".format(len(kernel), kern_len_max) + ) kernel = np.divide(kernel, kernel.sum()) # Smooth the cross-correlation histogram with the kern - return np.convolve(cross_corr_array, kernel, mode='same') + return np.convolve(cross_corr_array, kernel, mode="same") def covariance(binned_spiketrain, binary=False, fast=True): @@ -372,8 +377,7 @@ def covariance(binned_spiketrain, binary=False, fast=True): array = binned_spiketrain.to_array() return np.cov(array) - return _covariance_sparse( - binned_spiketrain, corrcoef_norm=False) + return _covariance_sparse(binned_spiketrain, corrcoef_norm=False) def correlation_coefficient(binned_spiketrain, binary=False, fast=True): @@ -480,13 +484,13 @@ def correlation_coefficient(binned_spiketrain, binary=False, fast=True): array = binned_spiketrain.to_array() return np.corrcoef(array) - return _covariance_sparse( - binned_spiketrain, corrcoef_norm=True) + return _covariance_sparse(binned_spiketrain, corrcoef_norm=True) def corrcoef(*args, **kwargs): - warnings.warn("'corrcoef' is deprecated; use 'correlation_coefficient'", - DeprecationWarning) + warnings.warn( + "'corrcoef' is deprecated; use 'correlation_coefficient'", DeprecationWarning + ) return correlation_coefficient(*args, **kwargs) @@ -534,25 +538,30 @@ def _covariance_sparse(binned_spiketrain, corrcoef_norm): # Check for empty spike trains n_spikes_per_row = spmat.sum(axis=1) if n_spikes_per_row.min() == 0: - warnings.warn( - 'Detected empty spike trains (rows) in the binned_spiketrain.') + warnings.warn("Detected empty spike trains (rows) in the binned_spiketrain.") res = 
spmat.dot(spmat.T) - n_spikes_per_row * n_spikes_per_row.T / n_bins res = np.asarray(res) if corrcoef_norm: stdx = np.sqrt(res.diagonal()) stdx = np.expand_dims(stdx, axis=0) - res /= (stdx.T * stdx) + res /= stdx.T * stdx else: - res /= (n_bins - 1) + res /= n_bins - 1 res = np.squeeze(res) return res def cross_correlation_histogram( - binned_spiketrain_i, binned_spiketrain_j, window='full', - border_correction=False, binary=False, kernel=None, method='speed', - cross_correlation_coefficient=False): + binned_spiketrain_i, + binned_spiketrain_j, + window="full", + border_correction=False, + binary=False, + kernel=None, + method="speed", + cross_correlation_coefficient=False, +): """ Computes the cross-correlation histogram (CCH) between two binned spike trains `binned_spiketrain_i` and `binned_spiketrain_j`. @@ -700,30 +709,28 @@ def cross_correlation_histogram( # Check that the spike trains are binned with the same temporal # resolution - if binned_spiketrain_i.shape[0] != 1 or \ - binned_spiketrain_j.shape[0] != 1: + if binned_spiketrain_i.shape[0] != 1 or binned_spiketrain_j.shape[0] != 1: raise ValueError("Spike trains must be one dimensional") # rescale to the common units # this does not change the data - only its representation binned_spiketrain_j.rescale(binned_spiketrain_i.units) - if not np.isclose(binned_spiketrain_i._bin_size, - binned_spiketrain_j._bin_size): + if not np.isclose(binned_spiketrain_i._bin_size, binned_spiketrain_j._bin_size): raise ValueError("Bin sizes must be equal") bin_size = binned_spiketrain_i._bin_size left_edge_min = -binned_spiketrain_i.n_bins + 1 right_edge_max = binned_spiketrain_j.n_bins - 1 - t_lags_shift = (binned_spiketrain_j._t_start - - binned_spiketrain_i._t_start) / bin_size + t_lags_shift = ( + binned_spiketrain_j._t_start - binned_spiketrain_i._t_start + ) / bin_size if not np.isclose(t_lags_shift, round(t_lags_shift)): # For example, if bin_size=1 ms, binned_spiketrain_i.t_start=0 ms, and # binned_spiketrain_j.t_start=0.5 ms then there is a global shift in # the binning of the spike trains. - raise ValueError( - "Binned spiketrains time shift is not multiple of bin_size") + raise ValueError("Binned spiketrains time shift is not multiple of bin_size") t_lags_shift = int(round(t_lags_shift)) # In the examples below we fix st2 and "move" st1. @@ -740,35 +747,38 @@ def cross_correlation_histogram( # zero-lag is at 4 ms # Find left and right edges of unaligned (time-dropped) time signals - if len(window) == 2 and np.issubdtype(type(window[0]), np.integer) \ - and np.issubdtype(type(window[1]), np.integer): + if ( + len(window) == 2 + and np.issubdtype(type(window[0]), np.integer) + and np.issubdtype(type(window[1]), np.integer) + ): # ex. 1) lags range: [w[0] - 2, w[1] - 2] ms # ex. 2) lags range: [w[0] + 1, w[1] + 1] ms # ex. 3) lags range: [w[0] + 3, w[0] + 3] ms if window[0] >= window[1]: raise ValueError( "Window's left edge ({left}) must be lower than the right " - "edge ({right})".format(left=window[0], right=window[1])) + "edge ({right})".format(left=window[0], right=window[1]) + ) left_edge, right_edge = np.subtract(window, t_lags_shift) if left_edge < left_edge_min or right_edge > right_edge_max: - raise ValueError( - "The window exceeds the length of the spike trains") + raise ValueError("The window exceeds the length of the spike trains") lags = np.arange(window[0], window[1] + 1, dtype=np.int32) - cch_mode = 'pad' - elif window == 'full': + cch_mode = "pad" + elif window == "full": # cch computed for all the possible entries # ex. 
1) lags range: [-6, 9] ms # ex. 2) lags range: [-4, 7] ms # ex. 3) lags range: [-2, 4] ms left_edge = left_edge_min right_edge = right_edge_max - lags = np.arange(left_edge + t_lags_shift, - right_edge + 1 + t_lags_shift, dtype=np.int32) + lags = np.arange( + left_edge + t_lags_shift, right_edge + 1 + t_lags_shift, dtype=np.int32 + ) cch_mode = window - elif window == 'valid': - lags = _CrossCorrHist.get_valid_lags(binned_spiketrain_i, - binned_spiketrain_j) - left_edge, right_edge = lags[(0, -1), ] + elif window == "valid": + lags = _CrossCorrHist.get_valid_lags(binned_spiketrain_i, binned_spiketrain_j) + left_edge, right_edge = lags[(0, -1),] cch_mode = window else: raise ValueError("Invalid window parameter") @@ -777,18 +787,20 @@ def cross_correlation_histogram( binned_spiketrain_i = binned_spiketrain_i.binarize() binned_spiketrain_j = binned_spiketrain_j.binarize() - cch_builder = _CrossCorrHist(binned_spiketrain_i, binned_spiketrain_j, - window=(left_edge, right_edge)) - if method == 'memory': + cch_builder = _CrossCorrHist( + binned_spiketrain_i, binned_spiketrain_j, window=(left_edge, right_edge) + ) + if method == "memory": cross_corr = cch_builder.correlate_memory(cch_mode=cch_mode) else: cross_corr = cch_builder.correlate_speed(cch_mode=cch_mode) if border_correction: - if window == 'valid': + if window == "valid": warnings.warn( "Border correction does not have any effect in " - "'valid' window mode since there are no border effects!") + "'valid' window mode since there are no border effects!" + ) else: cross_corr = cch_builder.border_correction(cross_corr) if kernel is not None: @@ -796,21 +808,28 @@ def cross_correlation_histogram( if cross_correlation_coefficient: cross_corr = cch_builder.cross_correlation_coefficient(cross_corr) - normalization = 'normalized' if cross_correlation_coefficient else 'counts' - annotations = dict(window=window, border_correction=border_correction, - binary=binary, kernel=kernel is not None, - normalization=normalization) + normalization = "normalized" if cross_correlation_coefficient else "counts" + annotations = dict( + window=window, + border_correction=border_correction, + binary=binary, + kernel=kernel is not None, + normalization=normalization, + ) annotations = dict(cch_parameters=annotations) # Transform the array count into an AnalogSignal - t_start = pq.Quantity((lags[0] - 0.5) * bin_size, - units=binned_spiketrain_i.units, copy=False) + t_start = pq.Quantity( + (lags[0] - 0.5) * bin_size, units=binned_spiketrain_i.units, copy=False + ) cch_result = neo.AnalogSignal( signal=np.expand_dims(cross_corr, axis=1), units=pq.dimensionless, t_start=t_start, - sampling_period=binned_spiketrain_i.bin_size, copy=False, - **annotations) + sampling_period=binned_spiketrain_i.bin_size, + copy=False, + **annotations, + ) return cch_result, lags @@ -818,9 +837,11 @@ def cross_correlation_histogram( cch = cross_correlation_histogram -def spike_time_tiling_coefficient(spiketrain_i: neo.core.SpikeTrain, - spiketrain_j: neo.core.SpikeTrain, - dt: pq.Quantity = 0.005 * pq.s) -> float: +def spike_time_tiling_coefficient( + spiketrain_i: neo.core.SpikeTrain, + spiketrain_j: neo.core.SpikeTrain, + dt: pq.Quantity = 0.005 * pq.s, +) -> float: """ Calculates the Spike Time Tiling Coefficient (STTC) as described in :cite:`correlation-Cutts2014_14288` following their implementation in C. 
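A minimal usage sketch for spike_time_tiling_coefficient, assuming only the signature shown here (the spike times are illustrative):

    import quantities as pq
    import neo
    from elephant.spike_train_correlation import spike_time_tiling_coefficient

    st_a = neo.SpikeTrain([0.120, 0.530, 0.920] * pq.s, t_stop=1.0 * pq.s)
    st_b = neo.SpikeTrain([0.115, 0.525, 0.915] * pq.s, t_stop=1.0 * pq.s)

    # dt is the half-width of the coincidence window around each spike
    sttc = spike_time_tiling_coefficient(st_a, st_b, dt=0.01 * pq.s)

The coefficient lies in [-1, 1]; it is close to 1 here because every spike in one train has a partner within dt in the other.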
@@ -896,9 +917,11 @@ def spike_time_tiling_coefficient(spiketrain_i: neo.core.SpikeTrain, if dt.units != spiketrain_i.units: dt = dt.rescale(spiketrain_i.units) - def run_p(spiketrain_j: neo.core.SpikeTrain, - spiketrain_i: neo.core.SpikeTrain, - dt: pq.Quantity = dt) -> float: + def run_p( + spiketrain_j: neo.core.SpikeTrain, + spiketrain_i: neo.core.SpikeTrain, + dt: pq.Quantity = dt, + ) -> float: """ Returns number of spikes in spiketrain_j which lie within +- dt of any spike from spiketrain_i, divided by the total number of spikes in @@ -909,14 +932,15 @@ def run_p(spiketrain_j: neo.core.SpikeTrain, tiled_spikes_j = np.isclose( spiketrain_j.times.magnitude[:, np.newaxis], spiketrain_i.times.magnitude, - atol=dt.item()) + atol=dt.item(), + ) # Determine which spikes in spiketrain_j satisfy the time window # condition. tiled_spike_indices = np.any(tiled_spikes_j, axis=1) # Extract the spike times in spiketrain_j that satisfy the condition. tiled_spikes_j = spiketrain_j[tiled_spike_indices] # Calculate the ratio of matching spikes in j to the total spikes in j. - return len(tiled_spikes_j)/len(spiketrain_j) + return len(tiled_spikes_j) / len(spiketrain_j) def run_t(spiketrain: neo.core.SpikeTrain, dt: pq.Quantity = dt) -> float: """ @@ -950,7 +974,7 @@ def run_t(spiketrain: neo.core.SpikeTrain, dt: pq.Quantity = dt) -> float: covered_time_overlap += sorted_spikes[0] - t_start else: covered_time_non_overlap += dt - if t_stop - sorted_spikes[- 1] < dt: + if t_stop - sorted_spikes[-1] < dt: covered_time_overlap += t_stop - sorted_spikes[-1] else: covered_time_non_overlap += dt @@ -978,14 +1002,13 @@ def run_t(spiketrain: neo.core.SpikeTrain, dt: pq.Quantity = dt) -> float: # is within dt of a spike in the other train, # so we set the respective (partial) index to 1. if PA * TB == 1 and PB * TA == 1: - index = 1. + index = 1.0 elif PA * TB == 1: index = 0.5 + 0.5 * (PB - TA) / (1 - PB * TA) elif PB * TA == 1: index = 0.5 + 0.5 * (PA - TB) / (1 - PA * TB) else: - index = 0.5 * (PA - TB) / (1 - PA * TB) + \ - 0.5 * (PB - TA) / (1 - PB * TA) + index = 0.5 * (PA - TB) / (1 - PA * TB) + 0.5 * (PB - TA) / (1 - PB * TA) return index @@ -1047,8 +1070,9 @@ def spike_train_timescale(binned_spiketrain, max_tau): """ if binned_spiketrain.get_num_of_spikes() < 2: - warnings.warn("Spike train contains less than 2 spikes! " - "np.nan will be returned.") + warnings.warn( + "Spike train contains less than 2 spikes! " "np.nan will be returned." + ) return np.nan bin_size = binned_spiketrain._bin_size @@ -1064,8 +1088,10 @@ def spike_train_timescale(binned_spiketrain, max_tau): cch_window = [-max_tau_bins, max_tau_bins] corrfct, bin_ids = cross_correlation_histogram( - binned_spiketrain, binned_spiketrain, window=cch_window, - cross_correlation_coefficient=True + binned_spiketrain, + binned_spiketrain, + window=cch_window, + cross_correlation_coefficient=True, ) # Take only t > 0 values, in particular neglecting the delta peak. 
start_id = corrfct.time_index((bin_size / 2) * binned_spiketrain.units) diff --git a/elephant/spike_train_dissimilarity.py b/elephant/spike_train_dissimilarity.py index 3234f8916..6f2c6f19c 100644 --- a/elephant/spike_train_dissimilarity.py +++ b/elephant/spike_train_dissimilarity.py @@ -28,14 +28,12 @@ import elephant.kernels as kernels -__all__ = [ - "victor_purpura_distance", - "van_rossum_distance" -] +__all__ = ["victor_purpura_distance", "van_rossum_distance"] def _create_matrix_from_indexed_function( - shape, func, symmetric_2d=False, **func_params): + shape, func, symmetric_2d=False, **func_params +): mat = np.empty(shape) if symmetric_2d: for i in range(shape[0]): @@ -47,8 +45,9 @@ def _create_matrix_from_indexed_function( return mat -def victor_purpura_distance(spiketrains, cost_factor=1.0 * pq.Hz, kernel=None, - sort=True, algorithm='fast'): +def victor_purpura_distance( + spiketrains, cost_factor=1.0 * pq.Hz, kernel=None, sort=True, algorithm="fast" +): """ Calculates the Victor-Purpura's (VP) distance. It is often denoted as :math:`D^{\\text{spike}}[q]`. @@ -119,14 +118,18 @@ def victor_purpura_distance(spiketrains, cost_factor=1.0 * pq.Hz, kernel=None, ... algorithm='intuitive')[0, 1] """ for train in spiketrains: - if not (isinstance(train, (pq.quantity.Quantity, SpikeTrain)) and - train.dimensionality.simplified == - pq.Quantity(1, "s").dimensionality.simplified): + if not ( + isinstance(train, (pq.quantity.Quantity, SpikeTrain)) + and train.dimensionality.simplified + == pq.Quantity(1, "s").dimensionality.simplified + ): raise TypeError("Spike trains must have a time unit.") - if not (isinstance(cost_factor, pq.quantity.Quantity) and - cost_factor.dimensionality.simplified == - pq.Quantity(1, "Hz").dimensionality.simplified): + if not ( + isinstance(cost_factor, pq.quantity.Quantity) + and cost_factor.dimensionality.simplified + == pq.Quantity(1, "Hz").dimensionality.simplified + ): raise TypeError("cost_factor must be a rate quantity.") if kernel is None: @@ -136,26 +139,27 @@ def victor_purpura_distance(spiketrains, cost_factor=1.0 * pq.Hz, kernel=None, if cost_factor == np.inf: num_spikes = np.atleast_2d([st.size for st in spiketrains]) return num_spikes.T + num_spikes - kernel = kernels.TriangularKernel( - sigma=2.0 / (np.sqrt(6.0) * cost_factor)) + kernel = kernels.TriangularKernel(sigma=2.0 / (np.sqrt(6.0) * cost_factor)) if sort: - spiketrains = [np.sort(st.view(type=pq.Quantity)) - for st in spiketrains] + spiketrains = [np.sort(st.view(type=pq.Quantity)) for st in spiketrains] def compute(i, j): if i == j: return 0.0 - if algorithm == 'fast': + if algorithm == "fast": return _victor_purpura_dist_for_st_pair_fast( - spiketrains[i], spiketrains[j], kernel) - if algorithm == 'intuitive': + spiketrains[i], spiketrains[j], kernel + ) + if algorithm == "intuitive": return _victor_purpura_dist_for_st_pair_intuitive( - spiketrains[i], spiketrains[j], cost_factor) + spiketrains[i], spiketrains[j], cost_factor + ) raise NameError("The algorithm must be either 'fast' or 'intuitive'.") return _create_matrix_from_indexed_function( - (len(spiketrains), len(spiketrains)), compute, kernel.is_symmetric()) + (len(spiketrains), len(spiketrains)), compute, kernel.is_symmetric() + ) def _victor_purpura_dist_for_st_pair_fast(spiketrain_a, spiketrain_b, kernel): @@ -216,32 +220,38 @@ def _victor_purpura_dist_for_st_pair_fast(spiketrain_a, spiketrain_b, kernel): min_dim, max_dim = spiketrain_b.size, spiketrain_a.size + 1 cost = np.asfortranarray(np.tile(np.arange(float(max_dim)), (2, 
1))) decreasing_sequence = np.asfortranarray(cost[:, ::-1]) - kern = kernel((np.atleast_2d(spiketrain_a).T.view(type=pq.Quantity) - - spiketrain_b.view(type=pq.Quantity))) - as_fortran = np.asfortranarray( - ((np.sqrt(6.0) * kernel.sigma) * kern).simplified) + kern = kernel( + ( + np.atleast_2d(spiketrain_a).T.view(type=pq.Quantity) + - spiketrain_b.view(type=pq.Quantity) + ) + ) + as_fortran = np.asfortranarray(((np.sqrt(6.0) * kernel.sigma) * kern).simplified) k = 1 - 2 * as_fortran for i in range(min_dim): # determine G[i, i] == accumulated_min[:, 0] - accumulated_min = cost[:, :-i - 1] + k[i:, i] - accumulated_min[1, :spiketrain_b.size - i] = \ - cost[1, :spiketrain_b.size - i] + k[i, i:] + accumulated_min = cost[:, : -i - 1] + k[i:, i] + accumulated_min[1, : spiketrain_b.size - i] = ( + cost[1, : spiketrain_b.size - i] + k[i, i:] + ) accumulated_min = np.minimum( accumulated_min, # shift - cost[:, 1:max_dim - i]) # insert + cost[:, 1 : max_dim - i], + ) # insert acc_dim = accumulated_min.shape[1] # delete vs min(insert, shift) accumulated_min[:, 0] = min(cost[1, 1], accumulated_min[0, 0]) # determine G[i, :] and G[:, i] by propagating minima. - accumulated_min += decreasing_sequence[:, -acc_dim - 1:-1] + accumulated_min += decreasing_sequence[:, -acc_dim - 1 : -1] accumulated_min = np.minimum.accumulate(accumulated_min, axis=1) cost[:, :acc_dim] = accumulated_min - decreasing_sequence[:, -acc_dim:] return cost[0, -min_dim - 1] -def _victor_purpura_dist_for_st_pair_intuitive(spiketrain_a, spiketrain_b, - cost_factor=1.0 * pq.Hz): +def _victor_purpura_dist_for_st_pair_intuitive( + spiketrain_a, spiketrain_b, cost_factor=1.0 * pq.Hz +): """ Function to calculate the Victor-Purpura distance between two spike trains described in *J. D. Victor and K. P. 
Purpura, Nature and precision of @@ -277,19 +287,23 @@ def _victor_purpura_dist_for_st_pair_intuitive(spiketrain_a, spiketrain_b, """ nspk_a = len(spiketrain_a) nspk_b = len(spiketrain_b) - scr = np.zeros((nspk_a+1, nspk_b+1)) - scr[:, 0] = range(0, nspk_a+1) - scr[0, :] = range(0, nspk_b+1) + scr = np.zeros((nspk_a + 1, nspk_b + 1)) + scr[:, 0] = range(0, nspk_a + 1) + scr[0, :] = range(0, nspk_b + 1) if nspk_a > 0 and nspk_b > 0: - for i in range(1, nspk_a+1): - for j in range(1, nspk_b+1): - scr[i, j] = min(scr[i-1, j]+1, scr[i, j-1]+1) - scr[i, j] = min(scr[i, j], scr[i-1, j-1] + - np.float64(( - cost_factor * abs( - spiketrain_a[i - 1] - - spiketrain_b[j - 1])).simplified)) + for i in range(1, nspk_a + 1): + for j in range(1, nspk_b + 1): + scr[i, j] = min(scr[i - 1, j] + 1, scr[i, j - 1] + 1) + scr[i, j] = min( + scr[i, j], + scr[i - 1, j - 1] + + np.float64( + ( + cost_factor * abs(spiketrain_a[i - 1] - spiketrain_b[j - 1]) + ).simplified + ), + ) return scr[nspk_a, nspk_b] @@ -339,14 +353,18 @@ def van_rossum_distance(spiketrains, time_constant=1.0 * pq.s, sort=True): >>> vr = van_rossum_distance([st_a, st_b], tau)[0, 1] """ for train in spiketrains: - if not (isinstance(train, (pq.quantity.Quantity, SpikeTrain)) and - train.dimensionality.simplified == - pq.Quantity(1, "s").dimensionality.simplified): + if not ( + isinstance(train, (pq.quantity.Quantity, SpikeTrain)) + and train.dimensionality.simplified + == pq.Quantity(1, "s").dimensionality.simplified + ): raise TypeError("Spike trains must have a time unit.") - if not (isinstance(time_constant, pq.quantity.Quantity) and - time_constant.dimensionality.simplified == - pq.Quantity(1, "s").dimensionality.simplified): + if not ( + isinstance(time_constant, pq.quantity.Quantity) + and time_constant.dimensionality.simplified + == pq.Quantity(1, "s").dimensionality.simplified + ): raise TypeError("tau must be a time quantity.") if time_constant == 0: @@ -357,12 +375,11 @@ def van_rossum_distance(spiketrains, time_constant=1.0 * pq.s, sort=True): return np.absolute(spike_counts - np.atleast_2d(spike_counts).T) k_dist = _summed_dist_matrix( - [st.view(type=pq.Quantity) - for st in spiketrains], time_constant, not sort) + [st.view(type=pq.Quantity) for st in spiketrains], time_constant, not sort + ) vr_dist = np.empty_like(k_dist) for i, j in np.ndindex(k_dist.shape): - vr_dist[i, j] = ( - k_dist[i, i] + k_dist[j, j] - k_dist[i, j] - k_dist[j, i]) + vr_dist[i, j] = k_dist[i, i] + k_dist[j, j] - k_dist[i, j] - k_dist[j, i] return np.sqrt(vr_dist) @@ -389,8 +406,7 @@ def _summed_dist_matrix(spiketrains, tau, presorted=False): values.fill(np.nan) for i, v in enumerate(spiketrains): if v.size > 0: - values[i, :v.size] = \ - (v / tau * pq.dimensionless).simplified + values[i, : v.size] = (v / tau * pq.dimensionless).simplified exp_diffs = np.exp(values[:, :-1] - values[:, 1:]) markage = np.zeros(values.shape) @@ -405,18 +421,20 @@ def _summed_dist_matrix(spiketrains, tau, presorted=False): # Cross spiketrain terms for u in range(D.shape[0]): - all_ks = np.searchsorted(values[u], values, 'left') - 1 + all_ks = np.searchsorted(values[u], values, "left") - 1 for v in range(u): - js = np.searchsorted(values[v], values[u], 'right') - 1 + js = np.searchsorted(values[v], values[u], "right") - 1 ks = all_ks[v] - slice_j = np.s_[np.searchsorted(js, 0):sizes[u]] - slice_k = np.s_[np.searchsorted(ks, 0):sizes[v]] + slice_j = np.s_[np.searchsorted(js, 0) : sizes[u]] + slice_k = np.s_[np.searchsorted(ks, 0) : sizes[v]] D[u, v] = np.sum( - 
np.exp(values[v][js[slice_j]] - values[u][slice_j]) * - (1.0 + markage[v][js[slice_j]])) + np.exp(values[v][js[slice_j]] - values[u][slice_j]) + * (1.0 + markage[v][js[slice_j]]) + ) D[u, v] += np.sum( - np.exp(values[u][ks[slice_k]] - values[v][slice_k]) * - (1.0 + markage[u][ks[slice_k]])) + np.exp(values[u][ks[slice_k]] - values[v][slice_k]) + * (1.0 + markage[u][ks[slice_k]]) + ) D[v, u] = D[u, v] return D diff --git a/elephant/spike_train_generation.py b/elephant/spike_train_generation.py index 1c279c61a..82a46d5e7 100644 --- a/elephant/spike_train_generation.py +++ b/elephant/spike_train_generation.py @@ -79,12 +79,17 @@ "homogeneous_gamma_process", "inhomogeneous_gamma_process", "single_interaction_process", - "compound_poisson_process" + "compound_poisson_process", ] -def spike_extraction(signal, threshold=0.0 * pq.mV, sign='above', - time_stamps=None, interval=(-2 * pq.ms, 4 * pq.ms)): +def spike_extraction( + signal, + threshold=0.0 * pq.mV, + sign="above", + time_stamps=None, + interval=(-2 * pq.ms, 4 * pq.ms), +): """ Return the peak times for all events that cross threshold and the waveforms. Usually used for extracting spikes from a membrane @@ -125,17 +130,23 @@ def spike_extraction(signal, threshold=0.0 * pq.mV, sign='above', # Get spike time_stamps if time_stamps is None: time_stamps = peak_detection(signal, threshold, sign=sign) - elif hasattr(time_stamps, 'times'): + elif hasattr(time_stamps, "times"): time_stamps = time_stamps.times elif isinstance(time_stamps, pq.Quantity): - raise TypeError("time_stamps must be None, a pq.Quantity array or" + - " expose the.times interface") + raise TypeError( + "time_stamps must be None, a pq.Quantity array or" + + " expose the.times interface" + ) if len(time_stamps) == 0: - return neo.SpikeTrain(time_stamps, units=signal.times.units, - t_start=signal.t_start, t_stop=signal.t_stop, - waveforms=np.array([]), - sampling_rate=signal.sampling_rate) + return neo.SpikeTrain( + time_stamps, + units=signal.times.units, + t_start=signal.t_start, + t_stop=signal.t_stop, + waveforms=np.array([]), + sampling_rate=signal.sampling_rate, + ) # Unpack the extraction interval from tuple or array extr_left, extr_right = interval @@ -149,8 +160,9 @@ def spike_extraction(signal, threshold=0.0 * pq.mV, sign='above', data_right = (extr_right * signal.sampling_rate).simplified.magnitude - data_stamps = (((time_stamps - signal.t_start) * - signal.sampling_rate).simplified).magnitude + data_stamps = ( + ((time_stamps - signal.t_start) * signal.sampling_rate).simplified + ).magnitude data_stamps = data_stamps.astype(int) @@ -160,32 +172,40 @@ def spike_extraction(signal, threshold=0.0 * pq.mV, sign='above', borders = np.dstack((borders_left, borders_right)).flatten() - waveforms = np.array( - np.split(np.array(signal), borders.astype(int))[1::2]) * signal.units + waveforms = ( + np.array(np.split(np.array(signal), borders.astype(int))[1::2]) * signal.units + ) # len(np.shape(waveforms)) == 1 if waveforms do not have the same width. # this can occur when extraction interval indexes beyond the signal. 
# Workaround: delete spikes shorter than the maximum length with if len(np.shape(waveforms)) == 1: max_len = max(len(waveform) for waveform in waveforms) - to_delete = np.array([idx for idx, x in enumerate(waveforms) - if len(x) < max_len]) + to_delete = np.array( + [idx for idx, x in enumerate(waveforms) if len(x) < max_len] + ) waveforms = np.delete(waveforms, to_delete, axis=0) - warnings.warn("Waveforms " + - ("{:d}, " * len(to_delete)).format(*to_delete) + - "exceeded signal and had to be deleted. " + - "Change 'interval' to keep.") + warnings.warn( + "Waveforms " + + ("{:d}, " * len(to_delete)).format(*to_delete) + + "exceeded signal and had to be deleted. " + + "Change 'interval' to keep." + ) waveforms = waveforms[:, np.newaxis, :] - return neo.SpikeTrain(time_stamps, units=signal.times.units, - t_start=signal.t_start, t_stop=signal.t_stop, - sampling_rate=signal.sampling_rate, - waveforms=waveforms, - left_sweep=extr_left) + return neo.SpikeTrain( + time_stamps, + units=signal.times.units, + t_start=signal.t_start, + t_stop=signal.t_stop, + sampling_rate=signal.sampling_rate, + waveforms=waveforms, + left_sweep=extr_left, + ) -def threshold_detection(signal, threshold=0.0 * pq.mV, sign='above'): +def threshold_detection(signal, threshold=0.0 * pq.mV, sign="above"): """ Returns the times when the analog signal crosses a threshold. Usually used for extracting spike times from a membrane potential. @@ -210,12 +230,12 @@ def threshold_detection(signal, threshold=0.0 * pq.mV, sign='above'): """ if not isinstance(threshold, pq.Quantity): - raise ValueError('threshold must be a pq.Quantity') + raise ValueError("threshold must be a pq.Quantity") - if sign not in ('above', 'below'): + if sign not in ("above", "below"): raise ValueError("sign should be 'above' or 'below'") - if sign == 'above': + if sign == "above": cutout = np.where(signal > threshold)[0] else: # sign == 'below' @@ -234,16 +254,18 @@ def threshold_detection(signal, threshold=0.0 * pq.mV, sign='above'): if events_base is None: # This occurs in some Python 3 builds due to some # bug in quantities. - events_base = np.array( - [event.magnitude for event in events]) # Workaround - - result_st = neo.SpikeTrain(events_base, units=signal.times.units, - t_start=signal.t_start, t_stop=signal.t_stop) + events_base = np.array([event.magnitude for event in events]) # Workaround + + result_st = neo.SpikeTrain( + events_base, + units=signal.times.units, + t_start=signal.t_start, + t_stop=signal.t_stop, + ) return result_st -def peak_detection(signal, threshold=0.0 * pq.mV, sign='above', - as_array=False): +def peak_detection(signal, threshold=0.0 * pq.mV, sign="above", as_array=False): """ Return the peak times for all events that cross threshold. Usually used for extracting spike times from a membrane potential. @@ -274,10 +296,10 @@ def peak_detection(signal, threshold=0.0 * pq.mV, sign='above', if not isinstance(threshold, pq.Quantity): raise ValueError("threshold must be a pq.Quantity") - if sign not in ('above', 'below'): + if sign not in ("above", "below"): raise ValueError("sign should be 'above' or 'below'") - if sign == 'above': + if sign == "above": cutout = np.where(signal > threshold)[0] peak_func = np.argmax else: @@ -300,8 +322,7 @@ def peak_detection(signal, threshold=0.0 * pq.mV, sign='above', # Workaround for bug that occurs when signal goes below thr for 1 dtp, # Workaround eliminates empty slices from np. 
split backward_mask = np.absolute(np.ediff1d(true_borders, to_begin=1)) > 0 - forward_mask = np.absolute(np.ediff1d(true_borders[::-1], - to_begin=1)[::-1]) > 0 + forward_mask = np.absolute(np.ediff1d(true_borders[::-1], to_begin=1)[::-1]) > 0 true_borders = true_borders[backward_mask * forward_mask] split_signal = np.split(np.array(signal), true_borders)[1::2] @@ -315,12 +336,14 @@ def peak_detection(signal, threshold=0.0 * pq.mV, sign='above', if events_base is None: # This occurs in some Python 3 builds due to some # bug in quantities. - events_base = np.array( - [event.magnitude for event in events]) # Workaround - - result_st = neo.SpikeTrain(events_base, units=signal.times.units, - t_start=signal.t_start, - t_stop=signal.t_stop) + events_base = np.array([event.magnitude for event in events]) # Workaround + + result_st = neo.SpikeTrain( + events_base, + units=signal.times.units, + t_start=signal.t_start, + t_stop=signal.t_stop, + ) if as_array: result_st = result_st.magnitude @@ -340,16 +363,14 @@ class AbstractPointProcess: The end of the spike train. Default: 1.*pq.s """ + def __init__( - self, - t_stop: pq.Quantity = 1.*pq.s, - t_start: pq.Quantity = 0.*pq.s + self, t_stop: pq.Quantity = 1.0 * pq.s, t_start: pq.Quantity = 0.0 * pq.s ): - if not (isinstance(t_start, pq.Quantity) and - isinstance(t_stop, pq.Quantity)): + if not (isinstance(t_start, pq.Quantity) and isinstance(t_stop, pq.Quantity)): raise ValueError("t_start and t_stop must be of type pq.Quantity") if t_stop <= t_start: - raise ValueError('t_start must be smaller than t_stop.') + raise ValueError("t_start must be smaller than t_stop.") self.units = t_stop.units self._t_stop = t_stop.item() @@ -373,7 +394,7 @@ def _generate_spiketrain_as_array(self) -> np.ndarray: raise NotImplementedError def generate_spiketrain( - self, as_array: bool = False + self, as_array: bool = False ) -> Union[neo.SpikeTrain, np.ndarray]: """ Generates a single spike train. @@ -395,13 +416,11 @@ def generate_spiketrain( return spikes # else: return neo.SpikeTrain( - spikes, - t_start=self.t_start, t_stop=self.t_stop, units=self.units) + spikes, t_start=self.t_start, t_stop=self.t_stop, units=self.units + ) def generate_n_spiketrains( - self, - n_spiketrains: int, - as_array: bool = False + self, n_spiketrains: int, as_array: bool = False ) -> Union[List[neo.SpikeTrain], List[np.ndarray]]: """ Generates a list of spike trains. @@ -420,8 +439,9 @@ def generate_n_spiketrains( list_of_spiketrain : list of neo.SpikeTrain or list of np.ndarray A list generated spike trains in the specified format. """ - return [self.generate_spiketrain(as_array=as_array) - for _ in range(n_spiketrains)] + return [ + self.generate_spiketrain(as_array=as_array) for _ in range(n_spiketrains) + ] class RenewalProcess(AbstractPointProcess): @@ -442,30 +462,33 @@ class RenewalProcess(AbstractPointProcess): Generate an equilibrium or an ordinary renewal process. 
Default: True """ + isi_generator: stats.rv_continuous def __init__( - self, - rate: pq.Quantity, - t_start: pq.Quantity = 0.*pq.s, - t_stop: pq.Quantity = 1.*pq.s, - equilibrium: bool = True + self, + rate: pq.Quantity, + t_start: pq.Quantity = 0.0 * pq.s, + t_stop: pq.Quantity = 1.0 * pq.s, + equilibrium: bool = True, ): super().__init__(t_start=t_start, t_stop=t_stop) if not isinstance(rate, pq.Quantity): raise ValueError("rate must be of type pq.Quantity") - self.rate = rate.rescale(1./self.units).item() + self.rate = rate.rescale(1.0 / self.units).item() self.equilibrium = equilibrium - self.n_expected_spikes = int(np.ceil( - ((self._t_stop - self._t_start) * self.rate))) + self.n_expected_spikes = int( + np.ceil(((self._t_stop - self._t_start) * self.rate)) + ) if self.n_expected_spikes < 0: raise ValueError( f"Expected no. of spikes: {self.n_expected_spikes} < 0. " f"The firing rate ({self.rate/self.units}) " - f"cannot be negative.") + f"cannot be negative." + ) def _cdf_first_spike_equilibrium(self, time): """ @@ -475,7 +498,7 @@ def _cdf_first_spike_equilibrium(self, time): The parameter time is a magnitude of a time value given in seconds. """ - return self.rate * integrate.quad(self.isi_generator.sf, 0., time)[0] + return self.rate * integrate.quad(self.isi_generator.sf, 0.0, time)[0] def _get_first_spike_equilibrium(self): """ @@ -502,9 +525,9 @@ def derivative_of_function_to_solve(time): return self.rate * self.isi_generator.sf(time) # Initial guess is solution for Poisson process - initial_guess = -np.log(1.-random_uniform)/self.rate - duration = self._t_stop-self._t_start - limits_for_first_spike = (0., duration) + initial_guess = -np.log(1.0 - random_uniform) / self.rate + duration = self._t_stop - self._t_start + limits_for_first_spike = (0.0, duration) # test if solution for first spike is inside the boundaries. If not # return t_stop of the spike train. @@ -512,11 +535,11 @@ def derivative_of_function_to_solve(time): return self._t_stop non_shifted_position_of_first_spike = equation_solver( - function_to_solve, - x0=initial_guess, - bracket=limits_for_first_spike, - fprime=derivative_of_function_to_solve - ).root + function_to_solve, + x0=initial_guess, + bracket=limits_for_first_spike, + fprime=derivative_of_function_to_solve, + ).root return non_shifted_position_of_first_spike + self._t_start @@ -534,8 +557,9 @@ def _generate_spiketrain_as_array(self) -> np.ndarray: spikes = np.array([first_spike]) # 3 STDs corresponds to 99.7% - n_spikes_three_stds = int(np.ceil( - self.n_expected_spikes + 3 * np.sqrt(self.n_expected_spikes))) + n_spikes_three_stds = int( + np.ceil(self.n_expected_spikes + 3 * np.sqrt(self.n_expected_spikes)) + ) # Continue until whole time range is covered while spikes[-1] < self._t_stop: @@ -554,7 +578,7 @@ def expected_cv(self): """ The expected coefficient of variation given the ISI distribution. 
""" - return self.isi_generator.std()/self.isi_generator.mean() + return self.isi_generator.std() / self.isi_generator.mean() class StationaryPoissonProcess(RenewalProcess): @@ -601,40 +625,41 @@ class StationaryPoissonProcess(RenewalProcess): >>> spiketrain_array = StationaryPoissonProcess(rate=20*pq.Hz,t_stop=10000*pq.ms,t_start=5000*pq.ms).generate_spiketrain(as_array=True) >>> spiketrain = StationaryPoissonProcess(rate=50*pq.Hz,t_stop=1000*pq.ms,t_start=0*pq.ms,refractory_period=3*pq.ms).generate_spiketrain() """ + def __init__( - self, - rate: pq.Quantity, - t_start: pq.Quantity = 0.0 * pq.ms, - t_stop: pq.Quantity = 1000.0*pq.ms, - refractory_period: Optional[pq.Quantity] = None, - equilibrium: bool = True + self, + rate: pq.Quantity, + t_start: pq.Quantity = 0.0 * pq.ms, + t_stop: pq.Quantity = 1000.0 * pq.ms, + refractory_period: Optional[pq.Quantity] = None, + equilibrium: bool = True, ): super().__init__( - rate=rate, t_start=t_start, t_stop=t_stop, equilibrium=equilibrium) + rate=rate, t_start=t_start, t_stop=t_stop, equilibrium=equilibrium + ) if refractory_period is not None: if not isinstance(refractory_period, pq.Quantity): - raise ValueError( - "refractory_period must be of type pq.Quantity") - self.refractory_period = refractory_period.rescale( - self.units).item() + raise ValueError("refractory_period must be of type pq.Quantity") + self.refractory_period = refractory_period.rescale(self.units).item() - if self.rate * self.refractory_period >= 1.: + if self.rate * self.refractory_period >= 1.0: raise ValueError( "Period between two successive spikes must be larger " "than the refractory period. Decrease either the " - "firing rate or the refractory period.") + "firing rate or the refractory period." + ) else: self.refractory_period = refractory_period if self.n_expected_spikes > 0 and refractory_period is None: - self.isi_generator = stats.expon(scale=1./self.rate) + self.isi_generator = stats.expon(scale=1.0 / self.rate) elif self.n_expected_spikes > 0 and refractory_period is not None: - self.effective_rate = self.rate / \ - (1. - self.rate * self.refractory_period) + self.effective_rate = self.rate / (1.0 - self.rate * self.refractory_period) self.isi_generator = stats.expon( - scale=1. / self.effective_rate, loc=self.refractory_period) + scale=1.0 / self.effective_rate, loc=self.refractory_period + ) def _get_first_spike_equilibrium(self): if self.refractory_period is None: @@ -645,9 +670,10 @@ def _get_first_spike_equilibrium(self): if random_uniform <= self.rate * self.refractory_period: return random_uniform / self.rate + self._t_start # random_uniform > self.rate * self.refractory_period - return (np.log(1. - self.rate * self.refractory_period) - - np.log(1. - random_uniform) - ) / self.effective_rate + self.refractory_period + return ( + np.log(1.0 - self.rate * self.refractory_period) + - np.log(1.0 - random_uniform) + ) / self.effective_rate + self.refractory_period @property def expected_cv(self): @@ -655,10 +681,10 @@ def expected_cv(self): The expected coefficient of variation given the ISI distribution. """ if self.refractory_period is None: - return 1. + return 1.0 # the case with dead time - return 1. - self.rate * self.refractory_period + return 1.0 - self.rate * self.refractory_period class StationaryGammaProcess(RenewalProcess): @@ -698,39 +724,40 @@ class StationaryGammaProcess(RenewalProcess): ... rate=20*pq.Hz, shape_factor=5.0, t_start=5000*pq.ms, ... 
t_stop=10000*pq.ms).generate_spiketrain(as_array=True) """ + def __init__( - self, - rate: pq.Quantity, - shape_factor: float, - t_start: pq.Quantity = 0.*pq.s, - t_stop: pq.Quantity = 1.*pq.s, - equilibrium: bool = True + self, + rate: pq.Quantity, + shape_factor: float, + t_start: pq.Quantity = 0.0 * pq.s, + t_stop: pq.Quantity = 1.0 * pq.s, + equilibrium: bool = True, ): super().__init__( - rate=rate, t_start=t_start, t_stop=t_stop, equilibrium=equilibrium) + rate=rate, t_start=t_start, t_stop=t_stop, equilibrium=equilibrium + ) if self.n_expected_spikes > 0: self.shape_factor = shape_factor self.isi_generator = stats.gamma( - a=shape_factor, scale=1./(shape_factor * self.rate)) + a=shape_factor, scale=1.0 / (shape_factor * self.rate) + ) def _cdf_first_spike_equilibrium(self, time): """ The parameter time is a magnitude of a time value given in seconds. """ - if time < 0.: - return 0. - return self.rate * time * \ - gammaincc(self.shape_factor, - self.shape_factor*self.rate*time)\ - + gammainc(self.shape_factor+1., - self.shape_factor*self.rate*time) + if time < 0.0: + return 0.0 + return self.rate * time * gammaincc( + self.shape_factor, self.shape_factor * self.rate * time + ) + gammainc(self.shape_factor + 1.0, self.shape_factor * self.rate * time) @property def expected_cv(self): """ The expected coefficient of variation given the ISI distribution. """ - return 1./np.sqrt(self.shape_factor) + return 1.0 / np.sqrt(self.shape_factor) class StationaryLogNormalProcess(RenewalProcess): @@ -770,27 +797,28 @@ class StationaryLogNormalProcess(RenewalProcess): ... rate=20*pq.Hz, sigma=5.0, t_start=5000*pq.ms, ... t_stop=10000*pq.ms).generate_spiketrain(as_array=True) """ + def __init__( - self, - rate: pq.Quantity, - sigma: float, - t_start: pq.Quantity = 0.*pq.s, - t_stop: pq.Quantity = 1.*pq.s, - equilibrium: bool = True + self, + rate: pq.Quantity, + sigma: float, + t_start: pq.Quantity = 0.0 * pq.s, + t_stop: pq.Quantity = 1.0 * pq.s, + equilibrium: bool = True, ): super().__init__( - rate=rate, t_start=t_start, t_stop=t_stop, equilibrium=equilibrium) + rate=rate, t_start=t_start, t_stop=t_stop, equilibrium=equilibrium + ) self.sigma = sigma if self.n_expected_spikes > 0: - self.isi_generator = stats.lognorm( - s=self.sigma, scale=np.exp(self.mu)) + self.isi_generator = stats.lognorm(s=self.sigma, scale=np.exp(self.mu)) @property def mu(self): """ The parameter mu of the log-normal distribution. """ - return -np.log(self.rate) - self.sigma**2/2 + return -np.log(self.rate) - self.sigma**2 / 2 @property def expected_cv(self): @@ -837,20 +865,23 @@ class StationaryInverseGaussianProcess(RenewalProcess): ... rate=20*pq.Hz, cv=5.0, t_start=5000*pq.ms, ... t_stop=10000*pq.ms).generate_spiketrain(as_array=True) """ + def __init__( - self, - rate: pq.Quantity, - cv: float, - t_start: pq.Quantity = 0.*pq.s, - t_stop: pq.Quantity = 1.*pq.s, - equilibrium: bool = True + self, + rate: pq.Quantity, + cv: float, + t_start: pq.Quantity = 0.0 * pq.s, + t_stop: pq.Quantity = 1.0 * pq.s, + equilibrium: bool = True, ): super().__init__( - rate=rate, t_start=t_start, t_stop=t_stop, equilibrium=equilibrium) + rate=rate, t_start=t_start, t_stop=t_stop, equilibrium=equilibrium + ) self._cv = cv if self.n_expected_spikes > 0: self.isi_generator = stats.invgauss( - mu=cv**2, scale=1./(self.rate*cv**2)) + mu=cv**2, scale=1.0 / (self.rate * cv**2) + ) @property def expected_cv(self): @@ -878,55 +909,54 @@ class RateModulatedProcess(AbstractPointProcess): If `rate_signal` contains a negative value. 
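As a quick usage sketch for the stationary renewal classes reformatted above (class names, constructor arguments, generate_spiketrain and expected_cv are taken from this diff; the rates, durations and the 3 ms dead time are arbitrary illustration values), the snippet below draws one long spike train per process and compares the empirical ISI coefficient of variation with expected_cv.

import numpy as np
import quantities as pq
from elephant.spike_train_generation import (
    StationaryPoissonProcess,
    StationaryGammaProcess,
    StationaryInverseGaussianProcess,
)

processes = {
    "Poisson": StationaryPoissonProcess(rate=20 * pq.Hz, t_stop=100 * pq.s),
    "Poisson + dead time": StationaryPoissonProcess(
        rate=20 * pq.Hz, t_stop=100 * pq.s, refractory_period=3 * pq.ms
    ),
    "Gamma": StationaryGammaProcess(
        rate=20 * pq.Hz, shape_factor=4.0, t_stop=100 * pq.s
    ),
    "Inverse Gaussian": StationaryInverseGaussianProcess(
        rate=20 * pq.Hz, cv=0.5, t_stop=100 * pq.s
    ),
}
for name, process in processes.items():
    # one long realization; the empirical CV should approach expected_cv
    spiketrain = process.generate_spiketrain(as_array=True)
    isis = np.diff(spiketrain)
    print(
        f"{name}: expected CV = {process.expected_cv:.3f}, "
        f"empirical CV = {isis.std() / isis.mean():.3f}"
    )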
If `rate_signal` is empty. """ + process_operational_time: RenewalProcess def __init__(self, rate_signal: neo.AnalogSignal): - if not isinstance(rate_signal, neo.AnalogSignal): raise ValueError( - f'rate_signal should be of type neo.AnalogSignal.' - f' Currently it is of type: {type(rate_signal)}') + f"rate_signal should be of type neo.AnalogSignal." + f" Currently it is of type: {type(rate_signal)}" + ) if len(rate_signal) == 0: - raise ValueError('rate_signal can not be empty.') + raise ValueError("rate_signal can not be empty.") if any(rate_signal < 0): - raise ValueError( - 'All elements of rate_signal should be positive.') + raise ValueError("All elements of rate_signal should be positive.") - super().__init__( - t_start=rate_signal.t_start, t_stop=rate_signal.t_stop) + super().__init__(t_start=rate_signal.t_start, t_stop=rate_signal.t_stop) self.rate_signal = rate_signal - self.mean_rate = np.mean(rate_signal.rescale(1./self.units).magnitude) + self.mean_rate = np.mean(rate_signal.rescale(1.0 / self.units).magnitude) - if self.mean_rate == 0.: + if self.mean_rate == 0.0: # if the firing rate is zero, the init functions stops here, since # the other parameters are then not needed. return None - self.sampling_period = \ - self.rate_signal.sampling_period.rescale(self.units).magnitude + self.sampling_period = self.rate_signal.sampling_period.rescale( + self.units + ).magnitude # Operational time corresponds to the integral of the firing rate # over time, here normalized by the average firing rate - operational_time = np.cumsum( - rate_signal.rescale(1./self.units).magnitude) - operational_time *= (self.sampling_period / self.mean_rate) - operational_time = np.hstack((0., operational_time)) + operational_time = np.cumsum(rate_signal.rescale(1.0 / self.units).magnitude) + operational_time *= self.sampling_period / self.mean_rate + operational_time = np.hstack((0.0, operational_time)) self.operational_time = operational_time + self._t_start # The time points at which the firing rates are given self.real_time = np.hstack( - (rate_signal.times.rescale(self.units).magnitude, - self._t_stop)) + (rate_signal.times.rescale(self.units).magnitude, self._t_stop) + ) def _generate_spiketrain_as_array(self) -> np.ndarray: - spiketrain_operational_time = \ + spiketrain_operational_time = ( self.process_operational_time._generate_spiketrain_as_array() + ) if len(spiketrain_operational_time) == 0: return spiketrain_operational_time # indices where between which points in operational time the spikes lie - indices = np.searchsorted(self.operational_time, - spiketrain_operational_time) + indices = np.searchsorted(self.operational_time, spiketrain_operational_time) # In real time the spikes are first aligned # to the left border of the bin. @@ -934,11 +964,9 @@ def _generate_spiketrain_as_array(self) -> np.ndarray: # padded with zeros. spiketrain = self.real_time[indices - 1] # the relative position of the spikes in the operational time bins - positions_in_bins = \ - (spiketrain_operational_time - - self.operational_time[indices - 1]) / \ - (self.operational_time[indices] - - self.operational_time[indices - 1]) + positions_in_bins = ( + spiketrain_operational_time - self.operational_time[indices - 1] + ) / (self.operational_time[indices] - self.operational_time[indices - 1]) # add the positions in the bin times the sampling period in real time spiketrain += self.sampling_period * positions_in_bins @@ -970,26 +998,30 @@ class NonStationaryPoissonProcess(RateModulatedProcess): If `rate_signal` is empty. 
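The _generate_spiketrain_as_array method above maps spikes from operational time back to real time: each operational-time spike is located between two entries of the cumulative rate integral (normalized by the mean rate) and then placed proportionally inside the corresponding real-time sampling bin. A self-contained numpy illustration of that warping, using a made-up ramp rate profile and hand-picked operational-time spike positions rather than the class internals:

import numpy as np

# toy rate profile: 1 s long, sampled every 10 ms, ramping from 5 Hz to 50 Hz
sampling_period = 0.01                       # s
rate = np.linspace(5.0, 50.0, 100)           # Hz
mean_rate = rate.mean()

# operational time = cumulative rate integral, normalized by the mean rate
operational_time = np.hstack(
    (0.0, np.cumsum(rate) * sampling_period / mean_rate)
)
real_time = np.hstack((np.arange(100) * sampling_period, 1.0))

# spikes drawn in operational time (fixed values, purely for illustration)
spikes_operational = np.array([0.12, 0.35, 0.61, 0.90])

indices = np.searchsorted(operational_time, spikes_operational)
position_in_bin = (spikes_operational - operational_time[indices - 1]) / (
    operational_time[indices] - operational_time[indices - 1]
)
# left bin edge in real time plus the proportional offset within the bin
spikes_real = real_time[indices - 1] + sampling_period * position_in_bin
print(spikes_real)  # monotonically increasing, within (0, 1) s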
If `refractory_period` is not of type `pq.Quantity` nor None. """ - def __init__(self, rate_signal: neo.AnalogSignal, - refractory_period: Optional[pq.Quantity] = None): + def __init__( + self, + rate_signal: neo.AnalogSignal, + refractory_period: Optional[pq.Quantity] = None, + ): if refractory_period is not None: if not isinstance(refractory_period, pq.Quantity): - raise ValueError( - "refractory_period must be of type pq.Quantity") - rate_signal = \ - rate_signal / (1. - rate_signal.simplified.magnitude - * refractory_period.simplified.item()) + raise ValueError("refractory_period must be of type pq.Quantity") + rate_signal = rate_signal / ( + 1.0 + - rate_signal.simplified.magnitude * refractory_period.simplified.item() + ) super().__init__(rate_signal=rate_signal) self.process_operational_time = StationaryPoissonProcess( - rate=self.mean_rate * 1. / self.units, t_stop=self.t_stop, - t_start=self.t_start) + rate=self.mean_rate * 1.0 / self.units, + t_stop=self.t_stop, + t_start=self.t_start, + ) self.refractory_period = refractory_period if self.refractory_period is not None: - self.refractory_period = self.refractory_period.rescale( - self.units).item() + self.refractory_period = self.refractory_period.rescale(self.units).item() def _generate_spiketrain_as_array(self) -> np.ndarray: if self.refractory_period is None: @@ -1028,18 +1060,24 @@ class NonStationaryGammaProcess(RateModulatedProcess): If `rate_signal` contains a negative value. If `rate_signal` is empty. """ + def __init__(self, rate_signal: neo.AnalogSignal, shape_factor: float): super().__init__(rate_signal=rate_signal) self.process_operational_time = StationaryGammaProcess( - rate=self.mean_rate * 1./self.units, + rate=self.mean_rate * 1.0 / self.units, shape_factor=shape_factor, t_start=self.t_start, - t_stop=self.t_stop) + t_stop=self.t_stop, + ) -def homogeneous_poisson_process(rate, t_start=0.0 * pq.ms, - t_stop=1000.0 * pq.ms, as_array=False, - refractory_period=None): +def homogeneous_poisson_process( + rate, + t_start=0.0 * pq.ms, + t_stop=1000.0 * pq.ms, + as_array=False, + refractory_period=None, +): """ Returns a spike train whose spikes are a realization of a Poisson process with the given rate, starting at time `t_start` and stopping time `t_stop`. @@ -1096,16 +1134,20 @@ def homogeneous_poisson_process(rate, t_start=0.0 * pq.ms, """ warnings.warn( "'homogeneous_poisson_process' is deprecated;" - " use 'StationaryPoissonProcess'.", DeprecationWarning) - process = StationaryPoissonProcess(rate=rate, t_stop=t_stop, - t_start=t_start, - refractory_period=refractory_period, - equilibrium=False) + " use 'StationaryPoissonProcess'.", + DeprecationWarning, + ) + process = StationaryPoissonProcess( + rate=rate, + t_stop=t_stop, + t_start=t_start, + refractory_period=refractory_period, + equilibrium=False, + ) return process.generate_spiketrain(as_array=as_array) -def inhomogeneous_poisson_process(rate, as_array=False, - refractory_period=None): +def inhomogeneous_poisson_process(rate, as_array=False, refractory_period=None): """ Returns a spike train whose spikes are a realization of an inhomogeneous Poisson process with the given rate profile. 
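For the deprecated wrappers above, the class-based replacements named in the deprecation warnings can be used directly. A minimal sketch, assuming only the constructor signatures shown in this diff (the sinusoidal rate profile, the 2 ms refractory period and the shape factor are illustration values):

import numpy as np
import quantities as pq
import neo
from elephant.spike_train_generation import (
    NonStationaryPoissonProcess,
    NonStationaryGammaProcess,
)

# sinusoidally modulated rate profile, 1 s long, sampled every millisecond
t = np.arange(1000) * 1e-3                      # s
rate_profile = neo.AnalogSignal(
    20.0 + 15.0 * np.sin(2 * np.pi * 2.0 * t),  # stays positive (5..35 Hz)
    units=pq.Hz,
    sampling_period=1 * pq.ms,
)

poisson_train = NonStationaryPoissonProcess(
    rate_signal=rate_profile, refractory_period=2 * pq.ms
).generate_spiketrain()
gamma_train = NonStationaryGammaProcess(
    rate_signal=rate_profile, shape_factor=3.0
).generate_spiketrain()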
@@ -1145,16 +1187,17 @@ def inhomogeneous_poisson_process(rate, as_array=False, warnings.warn( "'inhomogeneous_poisson_process' is deprecated;" " use 'NonStationaryPoissonProcess'.", - DeprecationWarning) + DeprecationWarning, + ) process = NonStationaryPoissonProcess( - rate_signal=rate, - refractory_period=refractory_period) - return process.generate_spiketrain( - as_array=as_array) + rate_signal=rate, refractory_period=refractory_period + ) + return process.generate_spiketrain(as_array=as_array) -def homogeneous_gamma_process(a, b, t_start=0.0 * pq.ms, t_stop=1000.0 * pq.ms, - as_array=False): +def homogeneous_gamma_process( + a, b, t_start=0.0 * pq.ms, t_stop=1000.0 * pq.ms, as_array=False +): """ Returns a spike train whose spikes are a realization of a gamma process with the given parameters, starting at time `t_start` and stopping time @@ -1199,12 +1242,12 @@ def homogeneous_gamma_process(a, b, t_start=0.0 * pq.ms, t_stop=1000.0 * pq.ms, """ warnings.warn( - "'homogeneous_gamma_process' is deprecated;" - " use 'StationaryGammaProcess'.", - DeprecationWarning) + "'homogeneous_gamma_process' is deprecated;" " use 'StationaryGammaProcess'.", + DeprecationWarning, + ) process = StationaryGammaProcess( - rate=b / a, shape_factor=a, t_stop=t_stop, t_start=t_start, - equilibrium=False) + rate=b / a, shape_factor=a, t_stop=t_stop, t_start=t_start, equilibrium=False + ) return process.generate_spiketrain(as_array=as_array) @@ -1243,9 +1286,9 @@ def inhomogeneous_gamma_process(rate, shape_factor, as_array=False): warnings.warn( "'inhomogeneous_gamma_process' is deprecated;" " use 'nonStationaryGammaProcess'.", - DeprecationWarning) - process = NonStationaryGammaProcess( - rate_signal=rate, shape_factor=shape_factor) + DeprecationWarning, + ) + process = NonStationaryGammaProcess(rate_signal=rate, shape_factor=shape_factor) return process.generate_spiketrain(as_array=as_array) @@ -1284,30 +1327,38 @@ def _n_poisson(rate, t_stop, t_start=0.0 * pq.ms, n_spiketrains=1): """ # Check that the provided input is Hertz if not isinstance(rate, pq.Quantity): - raise ValueError('rate must be a pq.Quantity') + raise ValueError("rate must be a pq.Quantity") # Set number n of output spike trains (specified or set to len(rate)) if not (isinstance(n_spiketrains, int) and n_spiketrains > 0): - raise ValueError( - f'n_spiketrains (={n_spiketrains}) must be a positive integer') + raise ValueError(f"n_spiketrains (={n_spiketrains}) must be a positive integer") # one rate for all spike trains if rate.ndim == 0: return StationaryPoissonProcess( - rate=rate, - t_stop=t_stop, - t_start=t_start).generate_n_spiketrains(n_spiketrains) + rate=rate, t_stop=t_stop, t_start=t_start + ).generate_n_spiketrains(n_spiketrains) # different rate for each spike train - return [StationaryPoissonProcess(rate=single_rate, t_stop=t_stop, - t_start=t_start).generate_spiketrain() - for single_rate in rate] + return [ + StationaryPoissonProcess( + rate=single_rate, t_stop=t_stop, t_start=t_start + ).generate_spiketrain() + for single_rate in rate + ] def single_interaction_process( - rate, coincidence_rate, t_stop, n_spiketrains=2, jitter=0 * pq.ms, - coincidences='deterministic', t_start=0 * pq.ms, min_delay=0 * pq.ms, - return_coincidences=False): + rate, + coincidence_rate, + t_stop, + n_spiketrains=2, + jitter=0 * pq.ms, + coincidences="deterministic", + t_start=0 * pq.ms, + min_delay=0 * pq.ms, + return_coincidences=False, +): """ Generates a multidimensional Poisson SIP (single interaction process) plus independent Poisson 
processes :cite:`generation-Kuhn2003_67`. @@ -1389,11 +1440,9 @@ def single_interaction_process( # Check if n is a positive integer if not (isinstance(n_spiketrains, int) and n_spiketrains > 0): - raise ValueError( - f'n_spiketrains (={n_spiketrains}) must be a positive integer') - if coincidences not in ('deterministic', 'stochastic'): - raise ValueError( - "coincidences must be 'deterministic' or 'stochastic'") + raise ValueError(f"n_spiketrains (={n_spiketrains}) must be a positive integer") + if coincidences not in ("deterministic", "stochastic"): + raise ValueError("coincidences must be 'deterministic' or 'stochastic'") # Assign time unit to jitter, or check that its existing unit is a time # unit @@ -1403,54 +1452,51 @@ def single_interaction_process( # matches with n if rate.ndim == 0: if rate < 0 * pq.Hz: - raise ValueError( - f'rate (={rate}) must be non-negative.') + raise ValueError(f"rate (={rate}) must be non-negative.") rates_b = np.repeat(rate, n_spiketrains) else: rates_b = rate.flatten() - if not all(rates_b >= 0.*pq.Hz): - raise ValueError('*rate* must have non-negative elements') + if not all(rates_b >= 0.0 * pq.Hz): + raise ValueError("*rate* must have non-negative elements") # Check: rate>=rate_coincidence if np.any(rates_b < coincidence_rate): - raise ValueError( - 'all elements of *rate* must be >= *rate_coincidence*') + raise ValueError("all elements of *rate* must be >= *rate_coincidence*") # Check min_delay < 1./rate_coincidence - if not (coincidence_rate == 0 * pq.Hz - or min_delay < 1. / coincidence_rate): + if not (coincidence_rate == 0 * pq.Hz or min_delay < 1.0 / coincidence_rate): raise ValueError( - "'*min_delay* (%s) must be lower than 1/*rate_coincidence* (%s)." % - (str(min_delay), str((1. / coincidence_rate).rescale( - min_delay.units)))) + "'*min_delay* (%s) must be lower than 1/*rate_coincidence* (%s)." + % (str(min_delay), str((1.0 / coincidence_rate).rescale(min_delay.units))) + ) # Generate the n Poisson processes there are the basis for the SIP # (coincidences still lacking) embedded_poisson_trains = _n_poisson( - rate=rates_b - coincidence_rate, t_stop=t_stop, t_start=t_start) + rate=rates_b - coincidence_rate, t_stop=t_stop, t_start=t_start + ) # Convert the trains from neo SpikeTrain objects to simpler pq.Quantity # objects - embedded_poisson_trains = [ - emb.view(pq.Quantity) for emb in embedded_poisson_trains] + embedded_poisson_trains = [emb.view(pq.Quantity) for emb in embedded_poisson_trains] # Generate the array of times for coincident events in SIP, not closer than # min_delay. The array is generated as a pq.Quantity. - if coincidences == 'deterministic': + if coincidences == "deterministic": # P. Bouss: we want the closest approximation to the average # coincidence count. 
n_coincidences = (t_stop - t_start) * coincidence_rate # Conversion to integer necessary for python 2 n_coincidences = int(round(n_coincidences.simplified.item())) while True: - coinc_times = t_start + \ - np.sort(np.random.random(n_coincidences)) * ( - t_stop - t_start) + coinc_times = t_start + np.sort(np.random.random(n_coincidences)) * ( + t_stop - t_start + ) if len(coinc_times) < 2 or min(np.diff(coinc_times)) >= min_delay: break else: # coincidences == 'stochastic' - poisson_process = StationaryPoissonProcess(rate=coincidence_rate, - t_stop=t_stop, - t_start=t_start) + poisson_process = StationaryPoissonProcess( + rate=coincidence_rate, t_stop=t_stop, t_start=t_start + ) while True: coinc_times = poisson_process.generate_spiketrain() if len(coinc_times) < 2 or min(np.diff(coinc_times)) >= min_delay: @@ -1460,35 +1506,49 @@ def single_interaction_process( # Set the coincidence times to T-jitter if larger. This ensures that # the last jittered spike time is t_stop) + embedded_coinc = ( + coinc_times + + np.random.random((len(rates_b), len(coinc_times))) * 2 * jitter + - jitter + ) + embedded_coinc = ( + embedded_coinc + + (t_start - embedded_coinc) * (embedded_coinc < t_start) + - (t_stop - embedded_coinc) * (embedded_coinc > t_stop) + ) # Inject coincident events into the n SIP processes generated above, and # merge with the n independent processes sip_process = [ - np.sort(np.concatenate(( - embedded_poisson_trains[m].rescale(t_stop.units), - embedded_coinc[m].rescale(t_stop.units))) * t_stop.units) - for m in range(len(rates_b))] + np.sort( + np.concatenate( + ( + embedded_poisson_trains[m].rescale(t_stop.units), + embedded_coinc[m].rescale(t_stop.units), + ) + ) + * t_stop.units + ) + for m in range(len(rates_b)) + ] # Convert back sip_process and coinc_times from pq.Quantity objects to # neo.SpikeTrain objects sip_process = [ neo.SpikeTrain(t, t_start=t_start, t_stop=t_stop).rescale(t_stop.units) - for t in sip_process] + for t in sip_process + ] coinc_times = [ neo.SpikeTrain(t, t_start=t_start, t_stop=t_stop).rescale(t_stop.units) - for t in embedded_coinc] + for t in embedded_coinc + ] # Return the processes in the specified output_format if not return_coincidences: @@ -1499,7 +1559,7 @@ def single_interaction_process( return output -def _pool_two_spiketrains(spiketrain_1, spiketrain_2, extremes='inner'): +def _pool_two_spiketrains(spiketrain_1, spiketrain_2, extremes="inner"): """ Pool the spikes of two spike trains a and b into a unique spike train. 
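A usage sketch for single_interaction_process as reformatted above. The keyword names match the signature in this diff; the two-element return with return_coincidences=True is assumed from the surrounding (unchanged) code, and all rates, the jitter and the duration are illustration values.

import quantities as pq
from elephant.spike_train_generation import single_interaction_process

# five Poisson trains at 25 Hz sharing coincidences injected at ~3 Hz,
# each coincident spike jittered by up to ±2 ms around the event time
sip_trains, coincidence_trains = single_interaction_process(
    rate=25 * pq.Hz,
    coincidence_rate=3 * pq.Hz,
    t_stop=10 * pq.s,
    n_spiketrains=5,
    jitter=2 * pq.ms,
    coincidences="stochastic",
    return_coincidences=True,
)
print(len(sip_trains), "trains,", len(coincidence_trains[0]), "coincident events")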
@@ -1528,20 +1588,18 @@ def _pool_two_spiketrains(spiketrain_1, spiketrain_2, extremes='inner'): times_2_dimless = spiketrain_2.rescale(unit).magnitude times = np.sort(np.concatenate((times_1_dimless, times_2_dimless))) - if extremes == 'outer': + if extremes == "outer": t_start = min(spiketrain_1.t_start, spiketrain_2.t_start) t_stop = max(spiketrain_1.t_stop, spiketrain_2.t_stop) - elif extremes == 'inner': + elif extremes == "inner": t_start = max(spiketrain_1.t_start, spiketrain_2.t_start) t_stop = min(spiketrain_1.t_stop, spiketrain_2.t_stop) times = times[times > t_start.magnitude] times = times[times < t_stop.magnitude] else: - raise ValueError( - 'extremes (%s) can only be "inner" or "outer"' % extremes) + raise ValueError('extremes (%s) can only be "inner" or "outer"' % extremes) - return neo.SpikeTrain(times=times, units=unit, t_start=t_start, - t_stop=t_stop) + return neo.SpikeTrain(times=times, units=unit, t_start=t_start, t_stop=t_stop) def _sample_int_from_pdf(probability_density, n_samples): @@ -1567,14 +1625,15 @@ def _sample_int_from_pdf(probability_density, n_samples): cumulative_distribution = np.cumsum(probability_density) random_uniforms = np.random.uniform(0, 1, size=n_samples) - random_uniforms = np.repeat(np.expand_dims(random_uniforms, axis=1), - repeats=len(probability_density), - axis=1) + random_uniforms = np.repeat( + np.expand_dims(random_uniforms, axis=1), + repeats=len(probability_density), + axis=1, + ) return (cumulative_distribution < random_uniforms).sum(axis=1) -def _mother_proc_cpp_stat( - amplitude_distribution, t_stop, rate, t_start=0 * pq.ms): +def _mother_proc_cpp_stat(amplitude_distribution, t_stop, rate, t_start=0 * pq.ms): """ Generate the hidden ("mother") Poisson process for a Compound Poisson Process (CPP). 
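The mother-process helper above feeds the public compound_poisson_process entry point (its signature appears further down in this diff). In the amplitude distribution, entry k is the probability that a mother spike is copied into k of the output trains; the values below (90 % single-train spikes, 10 % five-fold coincidences) are arbitrary illustration numbers.

import quantities as pq
from elephant.spike_train_generation import compound_poisson_process

# amplitude distribution over 0..5 spikes per mother event
amplitude_distribution = [0.0, 0.9, 0.0, 0.0, 0.0, 0.1]

cpp_trains = compound_poisson_process(
    rate=10 * pq.Hz,                 # target rate of every output train
    amplitude_distribution=amplitude_distribution,
    t_stop=10 * pq.s,
)
print([len(st) for st in cpp_trains])   # five correlated spike trains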
@@ -1602,12 +1661,12 @@ def _mother_proc_cpp_stat( """ n_spiketrains = len(amplitude_distribution) - 1 # expected amplitude - exp_amplitude = np.dot( - amplitude_distribution, np.arange(n_spiketrains + 1)) + exp_amplitude = np.dot(amplitude_distribution, np.arange(n_spiketrains + 1)) # expected rate of the mother process exp_mother_rate = (n_spiketrains * rate) / exp_amplitude - return StationaryPoissonProcess(rate=exp_mother_rate, t_stop=t_stop, - t_start=t_start).generate_spiketrain() + return StationaryPoissonProcess( + rate=exp_mother_rate, t_stop=t_stop, t_start=t_start + ).generate_spiketrain() def _cpp_hom_stat(amplitude_distribution, t_stop, rate, t_start=0 * pq.ms): @@ -1640,7 +1699,10 @@ def _cpp_hom_stat(amplitude_distribution, t_stop, rate, t_start=0 * pq.ms): # Generate mother process and associated spike labels mother = _mother_proc_cpp_stat( amplitude_distribution=amplitude_distribution, - t_stop=t_stop, rate=rate, t_start=t_start) + t_stop=t_stop, + rate=rate, + t_start=t_start, + ) labels = _sample_int_from_pdf(amplitude_distribution, len(mother)) n_spiketrains = len(amplitude_distribution) - 1 # Number of trains in output @@ -1661,17 +1723,19 @@ def _cpp_hom_stat(amplitude_distribution, t_stop, rate, t_start=0 * pq.ms): spiketrains[train_id] = mother[row].view(pq.Quantity) except MemoryError: # Slower (~2x) but less memory-consuming approach - print('memory case') + print("memory case") for mother_spiketrain, label in zip(mother, labels): train_ids = np.random.choice(n_spiketrains, label) for train_id in train_ids: spiketrains[train_id].append(mother_spiketrain) - return [neo.SpikeTrain(times=spiketrain, t_start=t_start, t_stop=t_stop) - for spiketrain in spiketrains] + return [ + neo.SpikeTrain(times=spiketrain, t_start=t_start, t_stop=t_stop) + for spiketrain in spiketrains + ] -def _cpp_het_stat(amplitude_distribution, t_stop, rates, t_start=0.*pq.ms): +def _cpp_het_stat(amplitude_distribution, t_stop, rates, t_start=0.0 * pq.ms): """ Generate a Compound Poisson Process (CPP) with amplitude distribution A and heterogeneous firing rates r=r[0], r[1], ..., r[-1]. @@ -1702,8 +1766,7 @@ def _cpp_het_stat(amplitude_distribution, t_stop, rates, t_start=0.*pq.ms): # (uncorrelated with heterog. rates + correlated with homog. 
rates) n_spiketrains = len(rates) # number of output spike trains # amplitude expectation - expected_amplitude = np.dot( - amplitude_distribution, np.arange(n_spiketrains + 1)) + expected_amplitude = np.dot(amplitude_distribution, np.arange(n_spiketrains + 1)) r_sum = np.sum(rates) # sum of all output firing rates r_min = np.min(rates) # minimum of the firing rates @@ -1715,33 +1778,43 @@ def _cpp_het_stat(amplitude_distribution, t_stop, rates, t_start=0.*pq.ms): r_mother = r_uncorrelated + r_correlated # Check the analytical constraint for the amplitude distribution - if amplitude_distribution[1] < (r_uncorrelated / r_mother).rescale( - pq.dimensionless).magnitude: - raise ValueError('A[1] too small / A[i], i>1 too high') + if ( + amplitude_distribution[1] + < (r_uncorrelated / r_mother).rescale(pq.dimensionless).magnitude + ): + raise ValueError("A[1] too small / A[i], i>1 too high") # Compute the amplitude distribution of the correlated CPP, and generate it - amplitude_distribution = \ + amplitude_distribution = ( amplitude_distribution * (r_mother / r_correlated).magnitude - amplitude_distribution[1] = \ + ) + amplitude_distribution[1] = ( amplitude_distribution[1] - r_uncorrelated / r_correlated + ) compound_poisson_spiketrains = _cpp_hom_stat( - amplitude_distribution, t_stop, r_min, t_start) + amplitude_distribution, t_stop, r_min, t_start + ) # Generate the independent heterogeneous Poisson processes - poisson_spiketrains = \ - [StationaryPoissonProcess(rate=rate - r_min, t_stop=t_stop, - t_start=t_start).generate_spiketrain() - for rate in rates] + poisson_spiketrains = [ + StationaryPoissonProcess( + rate=rate - r_min, t_stop=t_stop, t_start=t_start + ).generate_spiketrain() + for rate in rates + ] # Pool the correlated CPP and the corresponding Poisson processes - return [_pool_two_spiketrains(compound_poisson_spiketrain, - poisson_spiketrain) - for compound_poisson_spiketrain, poisson_spiketrain - in zip(compound_poisson_spiketrains, poisson_spiketrains)] + return [ + _pool_two_spiketrains(compound_poisson_spiketrain, poisson_spiketrain) + for compound_poisson_spiketrain, poisson_spiketrain in zip( + compound_poisson_spiketrains, poisson_spiketrains + ) + ] def compound_poisson_process( - rate, amplitude_distribution, t_stop, shift=None, t_start=0 * pq.ms): + rate, amplitude_distribution, t_stop, shift=None, t_start=0 * pq.ms +): """ Generate a Compound Poisson Process (CPP; see :cite:`generation-Staude2010_327`) with a given `amplitude_distribution` @@ -1793,41 +1866,48 @@ def compound_poisson_process( if not isinstance(amplitude_distribution, np.ndarray): amplitude_distribution = np.array(amplitude_distribution) # Check A is a probability distribution (it sums to 1 and is positive) - if abs(sum(amplitude_distribution) - 1) > np.finfo('float').eps: + if abs(sum(amplitude_distribution) - 1) > np.finfo("float").eps: raise ValueError( f"'amplitude_distribution' must be a probability vector: " - f"sum(A) = {sum(amplitude_distribution)} != 1") + f"sum(A) = {sum(amplitude_distribution)} != 1" + ) if np.any(amplitude_distribution < 0): - raise ValueError("'amplitude_distribution' must be a probability " - "vector with positive entries") + raise ValueError( + "'amplitude_distribution' must be a probability " + "vector with positive entries" + ) # Check that the rate is not an empty pq.Quantity if rate.ndim == 1 and len(rate) == 0: - raise ValueError('Rate is an empty pq.Quantity array') + raise ValueError("Rate is an empty pq.Quantity array") # Return empty spike trains for specific 
parameters if amplitude_distribution[0] == 1 or np.sum(np.abs(rate.magnitude)) == 0: - return [neo.SpikeTrain([] * t_stop.units, - t_stop=t_stop, - t_start=t_start)] * ( - len(amplitude_distribution) - 1) + return [neo.SpikeTrain([] * t_stop.units, t_stop=t_stop, t_start=t_start)] * ( + len(amplitude_distribution) - 1 + ) # Homogeneous rates if rate.ndim == 0: compound_poisson_spiketrains = _cpp_hom_stat( amplitude_distribution=amplitude_distribution, - t_stop=t_stop, rate=rate, - t_start=t_start) + t_stop=t_stop, + rate=rate, + t_start=t_start, + ) # Heterogeneous rates else: compound_poisson_spiketrains = _cpp_het_stat( amplitude_distribution=amplitude_distribution, - t_stop=t_stop, rates=rate, - t_start=t_start) + t_stop=t_stop, + rates=rate, + t_start=t_start, + ) if shift is not None: # Dither the output spiketrains - compound_poisson_spiketrains = \ - [dither_spike_train(spiketrain, shift=shift, edges=True)[0] - for spiketrain in compound_poisson_spiketrains] + compound_poisson_spiketrains = [ + dither_spike_train(spiketrain, shift=shift, edges=True)[0] + for spiketrain in compound_poisson_spiketrains + ] return compound_poisson_spiketrains diff --git a/elephant/spike_train_surrogates.py b/elephant/spike_train_surrogates.py index 5d6cd4300..db155efd6 100644 --- a/elephant/spike_train_surrogates.py +++ b/elephant/spike_train_surrogates.py @@ -57,28 +57,36 @@ "bin_shuffling", "JointISI", "trial_shifting", - "surrogates" + "surrogates", ] # List of all available surrogate methods -SURR_METHODS = ('dither_spike_train', 'dither_spikes', 'jitter_spikes', - 'randomise_spikes', 'shuffle_isis', 'joint_isi_dithering', - 'dither_spikes_with_refractory_period', 'trial_shifting', - 'bin_shuffling', 'isi_dithering') +SURR_METHODS = ( + "dither_spike_train", + "dither_spikes", + "jitter_spikes", + "randomise_spikes", + "shuffle_isis", + "joint_isi_dithering", + "dither_spikes_with_refractory_period", + "trial_shifting", + "bin_shuffling", + "isi_dithering", +) -def _dither_spikes_with_refractory_period(spiketrain: neo.SpikeTrain, - dither: float, - n_surrogates: int, - refractory_period: float - ) -> np.array: +def _dither_spikes_with_refractory_period( + spiketrain: neo.SpikeTrain, + dither: float, + n_surrogates: int, + refractory_period: float, +) -> np.array: units = spiketrain.units t_start = spiketrain.t_start.rescale(units).magnitude t_stop = spiketrain.t_stop.rescale(units).magnitude # The initially guesses refractory period is compared to the minimal ISI. # The smaller value is taken as the refractory to calculate with. 
- refractory_period = np.min(np.diff(spiketrain.magnitude), - initial=refractory_period) + refractory_period = np.min(np.diff(spiketrain.magnitude), initial=refractory_period) dithered_spiketrains = [] for _ in range(n_surrogates): @@ -88,14 +96,18 @@ def _dither_spikes_with_refractory_period(spiketrain: neo.SpikeTrain, for random_id in random_ordered_ids: spike = dithered_st[random_id] - prev_spike = dithered_st[random_id - 1] \ - if random_id > 0 \ + prev_spike = ( + dithered_st[random_id - 1] + if random_id > 0 else t_start - refractory_period + ) # subtract refractory period so that the first spike can move up # to t_start - next_spike = dithered_st[random_id + 1] \ - if random_id < len(spiketrain) - 1 \ + next_spike = ( + dithered_st[random_id + 1] + if random_id < len(spiketrain) - 1 else t_stop + refractory_period + ) # add refractory period so that the last spike can move up # to t_stop @@ -114,40 +126,45 @@ def _dither_spikes_with_refractory_period(spiketrain: neo.SpikeTrain, return dithered_spiketrains -def _dither_spikes(spiketrain: neo.SpikeTrain, dither: float, - n_surrogates: int, edges: bool) -> np.array: +def _dither_spikes( + spiketrain: neo.SpikeTrain, dither: float, n_surrogates: int, edges: bool +) -> np.array: units = spiketrain.units t_start = spiketrain.t_start.rescale(units).magnitude.item() t_stop = spiketrain.t_stop.rescale(units).magnitude.item() # Main: generate the surrogates - dithered_spiketrains = \ - spiketrain.magnitude.reshape((1, len(spiketrain))) \ - + 2 * dither * np.random.random_sample( - (n_surrogates, len(spiketrain))) - dither + dithered_spiketrains = ( + spiketrain.magnitude.reshape((1, len(spiketrain))) + + 2 * dither * np.random.random_sample((n_surrogates, len(spiketrain))) + - dither + ) dithered_spiketrains.sort(axis=1) if edges: # Leave out all spikes outside [spiketrain.t_start, spiketrain.t_stop] dithered_spiketrains = [ train[np.all([t_start < train, train < t_stop], axis=0)] - for train in dithered_spiketrains] + for train in dithered_spiketrains + ] else: # Move all spikes outside # [spiketrain.t_start, spiketrain.t_stop] to the range's ends dithered_spiketrains = np.minimum( - np.maximum(dithered_spiketrains, t_start), t_stop) + np.maximum(dithered_spiketrains, t_start), t_stop + ) return dithered_spiketrains -@deprecated_alias(n='n_surrogates') -def dither_spikes(spiketrain: neo.SpikeTrain, - dither: pq.Quantity, - n_surrogates: Optional[int] = 1, - decimals: Optional[int] = None, - edges: Optional[bool] = True, - refractory_period: Optional[Union[pq.Quantity, None]] = None - ) -> List[neo.SpikeTrain]: +@deprecated_alias(n="n_surrogates") +def dither_spikes( + spiketrain: neo.SpikeTrain, + dither: pq.Quantity, + n_surrogates: Optional[int] = 1, + decimals: Optional[int] = None, + edges: Optional[bool] = True, + refractory_period: Optional[Union[pq.Quantity, None]] = None, +) -> List[neo.SpikeTrain]: """ Generates surrogates of a spike train by spike dithering. 
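A short usage sketch for dither_spikes as declared above, with and without the refractory-period variant handled by _dither_spikes_with_refractory_period; the dither width, firing rate and dead time are illustration values.

import quantities as pq
from elephant.spike_train_generation import StationaryPoissonProcess
from elephant.spike_train_surrogates import dither_spikes

spiketrain = StationaryPoissonProcess(
    rate=30 * pq.Hz, t_stop=5 * pq.s
).generate_spiketrain()

# 100 surrogates, each spike displaced uniformly within ±25 ms
plain_surrogates = dither_spikes(spiketrain, dither=25 * pq.ms, n_surrogates=100)

# the same, but respecting a refractory period between neighbouring spikes
# (internally capped at the smallest ISI of the original train)
deadtime_surrogates = dither_spikes(
    spiketrain, dither=25 * pq.ms, n_surrogates=100, refractory_period=3 * pq.ms
)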
@@ -228,33 +245,41 @@ def dither_spikes(spiketrain: neo.SpikeTrain, dither = dither.rescale(units).magnitude.item() if not refractory_period: - dithered_spiketrains = _dither_spikes( - spiketrain, dither, n_surrogates, edges) + dithered_spiketrains = _dither_spikes(spiketrain, dither, n_surrogates, edges) elif isinstance(refractory_period, pq.Quantity): refractory_period = refractory_period.rescale(units).magnitude.item() dithered_spiketrains = _dither_spikes_with_refractory_period( - spiketrain, dither, n_surrogates, refractory_period) + spiketrain, dither, n_surrogates, refractory_period + ) else: raise ValueError("refractory_period must be of type pq.Quantity") # Round the surrogate data to decimal position, if requested if decimals: - return [neo.SpikeTrain( + return [ + neo.SpikeTrain( (train * units).rescale(pq.ms).round(decimals).rescale(units), - t_start=spiketrain.t_start, t_stop=spiketrain.t_stop, - sampling_rate=spiketrain.sampling_rate) - for train in dithered_spiketrains] + t_start=spiketrain.t_start, + t_stop=spiketrain.t_stop, + sampling_rate=spiketrain.sampling_rate, + ) + for train in dithered_spiketrains + ] else: # Return the surrogates as list of neo.SpikeTrain - return [neo.SpikeTrain( + return [ + neo.SpikeTrain( train * units, - t_start=spiketrain.t_start, t_stop=spiketrain.t_stop, - sampling_rate=spiketrain.sampling_rate) - for train in dithered_spiketrains] + t_start=spiketrain.t_start, + t_stop=spiketrain.t_stop, + sampling_rate=spiketrain.sampling_rate, + ) + for train in dithered_spiketrains + ] -@deprecated_alias(n='n_surrogates') +@deprecated_alias(n="n_surrogates") def randomise_spikes(spiketrain, n_surrogates=1, decimals=None): """ Generates surrogates of a spike train by spike time randomization. @@ -304,22 +329,29 @@ def randomise_spikes(spiketrain, n_surrogates=1, decimals=None): """ # Create surrogate spike trains as rows of a Quantity array - sts = ((spiketrain.t_stop - spiketrain.t_start) * - np.random.random(size=(n_surrogates, len(spiketrain))) + - spiketrain.t_start).rescale(spiketrain.units) + sts = ( + (spiketrain.t_stop - spiketrain.t_start) + * np.random.random(size=(n_surrogates, len(spiketrain))) + + spiketrain.t_start + ).rescale(spiketrain.units) # Round the surrogate data to decimal position, if requested if decimals is not None: sts = sts.round(decimals) # Convert the Quantity array to a list of SpikeTrains, and return them - return [neo.SpikeTrain(np.sort(st), t_start=spiketrain.t_start, - t_stop=spiketrain.t_stop, - sampling_rate=spiketrain.sampling_rate) - for st in sts] + return [ + neo.SpikeTrain( + np.sort(st), + t_start=spiketrain.t_start, + t_stop=spiketrain.t_stop, + sampling_rate=spiketrain.sampling_rate, + ) + for st in sts + ] -@deprecated_alias(n='n_surrogates') +@deprecated_alias(n="n_surrogates") def shuffle_isis(spiketrain, n_surrogates=1, decimals=None): """ Generates surrogates of a spike train by inter-spike-interval (ISI) @@ -366,11 +398,15 @@ def shuffle_isis(spiketrain, n_surrogates=1, decimals=None): """ if len(spiketrain) == 0: - return [neo.SpikeTrain([] * spiketrain.units, - t_start=spiketrain.t_start, - t_stop=spiketrain.t_stop, - sampling_rate=spiketrain.sampling_rate) - for _ in range(n_surrogates)] + return [ + neo.SpikeTrain( + [] * spiketrain.units, + t_start=spiketrain.t_start, + t_stop=spiketrain.t_stop, + sampling_rate=spiketrain.sampling_rate, + ) + for _ in range(n_surrogates) + ] # A correct sorting is necessary, to calculate the ISIs spiketrain = spiketrain.copy() @@ -385,18 +421,23 @@ def 
shuffle_isis(spiketrain, n_surrogates=1, decimals=None): # Create list of surrogate spike trains by random ISI permutation sts = [] for surrogate_id in range(n_surrogates): - surr_times = np.cumsum(np.random.permutation(isis)) * \ - spiketrain.units + spiketrain.t_start - sts.append(neo.SpikeTrain( - surr_times, t_start=spiketrain.t_start, - t_stop=spiketrain.t_stop, - sampling_rate=spiketrain.sampling_rate)) + surr_times = ( + np.cumsum(np.random.permutation(isis)) * spiketrain.units + + spiketrain.t_start + ) + sts.append( + neo.SpikeTrain( + surr_times, + t_start=spiketrain.t_start, + t_stop=spiketrain.t_stop, + sampling_rate=spiketrain.sampling_rate, + ) + ) return sts -@deprecated_alias(n='n_surrogates') -def dither_spike_train(spiketrain, shift, n_surrogates=1, decimals=None, - edges=True): +@deprecated_alias(n="n_surrogates") +def dither_spike_train(spiketrain, shift, n_surrogates=1, decimals=None, edges=True): """ Generates surrogates of a spike train by spike train shifting. @@ -460,9 +501,11 @@ def dither_spike_train(spiketrain, shift, n_surrogates=1, decimals=None, data = spiketrain.view(pq.Quantity) # Main: generate the surrogates by spike train shifting - surr = data.reshape( - (1, len(data))) + 2 * shift * np.random.random_sample( - (n_surrogates, 1)) - shift + surr = ( + data.reshape((1, len(data))) + + 2 * shift * np.random.random_sample((n_surrogates, 1)) + - shift + ) # Round the surrogate data to decimal position, if requested if decimals is not None: @@ -471,25 +514,39 @@ def dither_spike_train(spiketrain, shift, n_surrogates=1, decimals=None, if edges is False: # Move all spikes outside [spiketrain.t_start, spiketrain.t_stop] to # the range's ends - surr = np.minimum(np.maximum(surr.simplified.magnitude, - spiketrain.t_start.simplified.magnitude), - spiketrain.t_stop.simplified.magnitude) * pq.s + surr = ( + np.minimum( + np.maximum( + surr.simplified.magnitude, spiketrain.t_start.simplified.magnitude + ), + spiketrain.t_stop.simplified.magnitude, + ) + * pq.s + ) else: # Leave out all spikes outside [spiketrain.t_start, spiketrain.t_stop] - tstart, tstop = spiketrain.t_start.simplified.magnitude, \ - spiketrain.t_stop.simplified.magnitude - surr = [np.sort(s[np.all([s >= tstart, s < tstop], axis=0)]) * pq.s - for s in surr.simplified.magnitude] + tstart, tstop = ( + spiketrain.t_start.simplified.magnitude, + spiketrain.t_stop.simplified.magnitude, + ) + surr = [ + np.sort(s[np.all([s >= tstart, s < tstop], axis=0)]) * pq.s + for s in surr.simplified.magnitude + ] # Return the surrogates as SpikeTrains - return [neo.SpikeTrain(s, t_start=spiketrain.t_start, - t_stop=spiketrain.t_stop, - sampling_rate=spiketrain.sampling_rate - ).rescale(spiketrain.units) - for s in surr] + return [ + neo.SpikeTrain( + s, + t_start=spiketrain.t_start, + t_stop=spiketrain.t_stop, + sampling_rate=spiketrain.sampling_rate, + ).rescale(spiketrain.units) + for s in surr + ] -@deprecated_alias(binsize='bin_size', n='n_surrogates') +@deprecated_alias(binsize="bin_size", n="n_surrogates") def jitter_spikes(spiketrain, bin_size, n_surrogates=1): """ Generates surrogates of a spike train by spike jittering. 
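The ISI-preserving and rate-preserving surrogate generators reformatted here can be compared side by side. A minimal sketch using only the signatures visible in this diff, including jitter_spikes whose definition starts above (all parameter values are arbitrary):

import quantities as pq
from elephant.spike_train_generation import StationaryGammaProcess
from elephant.spike_train_surrogates import (
    shuffle_isis,
    randomise_spikes,
    dither_spike_train,
    jitter_spikes,
)

spiketrain = StationaryGammaProcess(
    rate=20 * pq.Hz, shape_factor=5.0, t_stop=10 * pq.s
).generate_spiketrain()

isi_preserving = shuffle_isis(spiketrain, n_surrogates=10)   # permutes the ISIs
rate_only = randomise_spikes(spiketrain, n_surrogates=10)    # uniform spike times
whole_train = dither_spike_train(spiketrain, shift=50 * pq.ms, n_surrogates=10)
jittered = jitter_spikes(spiketrain, bin_size=20 * pq.ms, n_surrogates=10)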
@@ -559,8 +616,9 @@ def jitter_spikes(spiketrain, bin_size, n_surrogates=1): # Compute the bin id of each spike bin_ids = np.array( - (spiketrain.view(pq.Quantity) / - bin_size).rescale(pq.dimensionless).magnitude, dtype=int) + (spiketrain.view(pq.Quantity) / bin_size).rescale(pq.dimensionless).magnitude, + dtype=int, + ) # Compute the size of each time bin (as a numpy array) bin_sizes_dl = np.diff(bin_edges) @@ -574,16 +632,20 @@ def jitter_spikes(spiketrain, bin_size, n_surrogates=1): # poisson 0-1 spike trains to dilat * s + offset. Attach time unit again surr = np.sort(surr_poiss01 * dilats + offsets, axis=1) * std_unit - return [neo.SpikeTrain(s, t_start=spiketrain.t_start, - t_stop=spiketrain.t_stop, - sampling_rate=spiketrain.sampling_rate - ).rescale(spiketrain.units) - for s in surr] + return [ + neo.SpikeTrain( + s, + t_start=spiketrain.t_start, + t_stop=spiketrain.t_stop, + sampling_rate=spiketrain.sampling_rate, + ).rescale(spiketrain.units) + for s in surr + ] def bin_shuffling( - spiketrain, max_displacement, bin_size=None, n_surrogates=1, - sliding=False): + spiketrain, max_displacement, bin_size=None, n_surrogates=1, sliding=False +): """ Bin shuffling surrogate generation. @@ -618,16 +680,22 @@ def bin_shuffling( if isinstance(spiketrain, neo.SpikeTrain): if bin_size is None: raise ValueError( - 'If you want to create surrogates from neo.SpikeTrain objects,' - 'you need to specify the bin_size') + "If you want to create surrogates from neo.SpikeTrain objects," + "you need to specify the bin_size" + ) if sliding: warnings.warn( - 'The sliding option is not implemented yet for bin shuffling' - ' on continuos time spike trains. Results are given for' - ' sliding=False.', UserWarning) + "The sliding option is not implemented yet for bin shuffling" + " on continuos time spike trains. 
Results are given for" + " sliding=False.", + UserWarning, + ) return _continuous_time_bin_shuffling( - spiketrain, max_displacement=max_displacement, bin_size=bin_size, - n_surrogates=n_surrogates) + spiketrain, + max_displacement=max_displacement, + bin_size=bin_size, + n_surrogates=n_surrogates, + ) displacement_window = 2 * max_displacement @@ -642,8 +710,9 @@ def bin_shuffling( # shuffling the binned spike train within the window np.random.shuffle( surrogate_spiketrain[ - window_position: - window_position + displacement_window]) + window_position : window_position + displacement_window + ] + ) else: windows = st_length // displacement_window windows_remainder = st_length % displacement_window @@ -651,11 +720,12 @@ def bin_shuffling( # shuffling the binned spike train within the window np.random.shuffle( surrogate_spiketrain[ - window_position * displacement_window: - (window_position + 1) * displacement_window]) + window_position * displacement_window : (window_position + 1) + * displacement_window + ] + ) if windows_remainder != 0: - np.random.shuffle( - surrogate_spiketrain[windows * displacement_window:]) + np.random.shuffle(surrogate_spiketrain[windows * displacement_window :]) surrogate_spiketrain = surrogate_spiketrain.reshape((1, st_length)) surrogate_spiketrains.append( @@ -664,12 +734,15 @@ def bin_shuffling( bin_size=spiketrain.bin_size, t_start=spiketrain.t_start, t_stop=spiketrain.t_stop, - tolerance=None)) + tolerance=None, + ) + ) return surrogate_spiketrains -def _continuous_time_bin_shuffling(spiketrain, max_displacement, bin_size, - n_surrogates=1): +def _continuous_time_bin_shuffling( + spiketrain, max_displacement, bin_size, n_surrogates=1 +): """ Parameters @@ -698,36 +771,36 @@ def _continuous_time_bin_shuffling(spiketrain, max_displacement, bin_size, split_indices = np.searchsorted( bin_indices, - np.arange(displacement_window, binned_duration, displacement_window)) + np.arange(displacement_window, binned_duration, displacement_window), + ) - bin_indices = np.split( - bin_indices, - split_indices) + bin_indices = np.split(bin_indices, split_indices) surrogate_spiketrains = [] for surrogate_id in range(n_surrogates): - surrogate_bin_indices = np.empty(shape=len(bin_indices), - dtype=np.ndarray) + surrogate_bin_indices = np.empty(shape=len(bin_indices), dtype=np.ndarray) for i, bin_indices_slice in enumerate(bin_indices): - window_start = i*displacement_window + window_start = i * displacement_window random_indices = np.random.permutation(displacement_window) - surrogate_bin_indices[i] = \ - random_indices[bin_indices_slice - window_start] \ - + window_start + surrogate_bin_indices[i] = ( + random_indices[bin_indices_slice - window_start] + window_start + ) surrogate_bin_indices = np.concatenate(surrogate_bin_indices) bin_remainders = bin_size * np.random.random(len(spiketrain)) - surrogate_spiketrain = \ + surrogate_spiketrain = ( surrogate_bin_indices * bin_size + bin_remainders + t_start + ) # ensure last and first spike being inside the boundaries surrogate_spiketrain = surrogate_spiketrain[ - np.all((surrogate_spiketrain > t_start, - surrogate_spiketrain < t_stop), - axis=0)] + np.all( + (surrogate_spiketrain > t_start, surrogate_spiketrain < t_stop), axis=0 + ) + ] surrogate_spiketrain.sort() @@ -812,22 +885,23 @@ class JointISI(object): # Otherwise, the original spiketrain is copied N times. MIN_SPIKES = 3 - @deprecated_alias(num_bins='n_bins', refr_period='refractory_period') - def __init__(self, - spiketrain, - dither=15. * pq.ms, - truncation_limit=100. 
* pq.ms, - n_bins=100, - sigma=2. * pq.ms, - alternate=True, - use_sqrt=False, - method='window', - cutoff=True, - refractory_period=4. * pq.ms, - isi_dithering=False): - + @deprecated_alias(num_bins="n_bins", refr_period="refractory_period") + def __init__( + self, + spiketrain, + dither=15.0 * pq.ms, + truncation_limit=100.0 * pq.ms, + n_bins=100, + sigma=2.0 * pq.ms, + alternate=True, + use_sqrt=False, + method="window", + cutoff=True, + refractory_period=4.0 * pq.ms, + isi_dithering=False, + ): if not isinstance(spiketrain, neo.SpikeTrain): - raise TypeError('spiketrain must be of type neo.SpikeTrain') + raise TypeError("spiketrain must be of type neo.SpikeTrain") # A correct sorting is necessary to calculate the ISIs spiketrain = spiketrain.copy() @@ -841,9 +915,12 @@ def __init__(self, self.sigma = self._get_magnitude(sigma) self.alternate = alternate - if method not in ('fast', 'window'): - raise ValueError("The method can either be 'fast' or 'window', " - "but not '{}'".format(method)) + if method not in ("fast", "window"): + raise ValueError( + "The method can either be 'fast' or 'window', " "but not '{}'".format( + method + ) + ) self.method = method refractory_period = self._get_magnitude(refractory_period) @@ -863,14 +940,14 @@ def __init__(self, @property def refr_period(self): - warnings.warn("'.refr_period' is deprecated; use '.refractory_period'", - DeprecationWarning) + warnings.warn( + "'.refr_period' is deprecated; use '.refractory_period'", DeprecationWarning + ) return self.refractory_period @property def num_bins(self): - warnings.warn("'.num_bins' is deprecated; use '.n_bins'", - DeprecationWarning) + warnings.warn("'.num_bins' is deprecated; use '.n_bins'", DeprecationWarning) return self.n_bins def _get_magnitude(self, quantity): @@ -982,15 +1059,15 @@ def joint_isi_histogram(self): isis = self.isi if not self.isi_dithering: joint_isi_histogram = np.histogram2d( - isis[:-1], isis[1:], + isis[:-1], + isis[1:], bins=[self.n_bins, self.n_bins], - range=[[0., self.truncation_limit], - [0., self.truncation_limit]])[0] + range=[[0.0, self.truncation_limit], [0.0, self.truncation_limit]], + )[0] else: isi_histogram = np.histogram( - isis, - bins=self.n_bins, - range=[0., self.truncation_limit])[0] + isis, bins=self.n_bins, range=[0.0, self.truncation_limit] + )[0] joint_isi_histogram = np.outer(isi_histogram, isi_histogram) if self.use_sqrt: @@ -999,15 +1076,16 @@ def joint_isi_histogram(self): if self.sigma: if self.cutoff: start_index = self._isi_to_index(self.refractory_period) - joint_isi_histogram[ - start_index:, start_index:] = gaussian_filter( - joint_isi_histogram[start_index:, start_index:], - sigma=self.sigma / self.bin_width) + joint_isi_histogram[start_index:, start_index:] = gaussian_filter( + joint_isi_histogram[start_index:, start_index:], + sigma=self.sigma / self.bin_width, + ) joint_isi_histogram[:start_index, :] = 0 joint_isi_histogram[:, :start_index] = 0 else: joint_isi_histogram = gaussian_filter( - joint_isi_histogram, sigma=self.sigma / self.bin_width) + joint_isi_histogram, sigma=self.sigma / self.bin_width + ) return joint_isi_histogram @staticmethod @@ -1030,7 +1108,7 @@ def _normalize_cumulative_distribution(array): If `array` does not contain all equal elements, a only-zeros array is returned. 
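A usage sketch for the JointISI class whose constructor is reformatted above; only keyword names that appear in this diff are used, and the input gamma train and dither width are illustration values.

import quantities as pq
from elephant.spike_train_generation import StationaryGammaProcess
from elephant.spike_train_surrogates import JointISI

spiketrain = StationaryGammaProcess(
    rate=20 * pq.Hz, shape_factor=5.0, t_stop=20 * pq.s
).generate_spiketrain()

joint_isi = JointISI(
    spiketrain,
    dither=15 * pq.ms,
    method="window",
    refractory_period=4 * pq.ms,
)
surrogate_trains = joint_isi.dithering(n_surrogates=5)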
""" - if array[-1] - array[0] > 0.: + if array[-1] - array[0] > 0.0: return (array - array[0]) / (array[-1] - array[0]) return np.zeros_like(array) @@ -1064,8 +1142,9 @@ def dithering(self, n_surrogates=1): for _ in range(n_surrogates): dithered_isi = self._get_dithered_isi(isi_to_dither) - dithered_st = self.spiketrain[0].magnitude + \ - np.r_[0., np.cumsum(dithered_isi)] + dithered_st = ( + self.spiketrain[0].magnitude + np.r_[0.0, np.cumsum(dithered_isi)] + ) sampling_rate = self.spiketrain.sampling_rate # Due to rounding errors, the last spike may be above t_stop. @@ -1073,44 +1152,47 @@ def dithering(self, n_surrogates=1): if dithered_st[-1] > self.spiketrain.t_stop: dithered_st[-1] = self.spiketrain.t_stop - dithered_st = neo.SpikeTrain(dithered_st * self._unit, - t_start=self.spiketrain.t_start, - t_stop=self.spiketrain.t_stop, - sampling_rate=sampling_rate) + dithered_st = neo.SpikeTrain( + dithered_st * self._unit, + t_start=self.spiketrain.t_start, + t_stop=self.spiketrain.t_stop, + sampling_rate=sampling_rate, + ) dithered_sts.append(dithered_st) return dithered_sts def _determine_cumulative_functions(self): rotated_jisih = np.rot90(self.joint_isi_histogram()) - if self.method == 'fast': + if self.method == "fast": self._jisih_cumulatives = [] for double_index in range(self.n_bins): # Taking anti-diagonals of the original joint-ISI histogram diagonal = np.diagonal( - rotated_jisih, offset=-self.n_bins + double_index + 1) + rotated_jisih, offset=-self.n_bins + double_index + 1 + ) jisih_cum = self._normalize_cumulative_distribution( - np.r_[0., np.cumsum(diagonal)]) + np.r_[0.0, np.cumsum(diagonal)] + ) self._jisih_cumulatives.append(jisih_cum) - self._jisih_cumulatives = np.array( - self._jisih_cumulatives, dtype=object) + self._jisih_cumulatives = np.array(self._jisih_cumulatives, dtype=object) else: self._jisih_cumulatives = self._window_cumulatives(rotated_jisih) def _window_cumulatives(self, rotated_jisih): jisih_diag_cums = self._window_diagonal_cumulatives(rotated_jisih) jisih_cumulatives = np.zeros( - (self.n_bins, self.n_bins, - 2 * self._max_change_index + 1)) + (self.n_bins, self.n_bins, 2 * self._max_change_index + 1) + ) for curr_isi_id in range(self.n_bins): for next_isi_id in range(self.n_bins - curr_isi_id): double_index = next_isi_id + curr_isi_id cum_slice = jisih_diag_cums[ double_index, - curr_isi_id: curr_isi_id + 2 * self._max_change_index + 1] + curr_isi_id : curr_isi_id + 2 * self._max_change_index + 1, + ] - normalized_cum = self._normalize_cumulative_distribution( - cum_slice) + normalized_cum = self._normalize_cumulative_distribution(cum_slice) jisih_cumulatives[curr_isi_id][next_isi_id] = normalized_cum return jisih_cumulatives @@ -1118,26 +1200,27 @@ def _window_diagonal_cumulatives(self, rotated_jisih): # An element of the first axis is defined as the sum of indices # for previous and subsequent ISI. - jisih_diag_cums = np.zeros((self.n_bins, - self.n_bins - + 2 * self._max_change_index)) + jisih_diag_cums = np.zeros( + (self.n_bins, self.n_bins + 2 * self._max_change_index) + ) # double_index corresponds to the sum of the indices for the previous # and the subsequent ISI. 
for double_index in range(self.n_bins): - anti_diagonal = np.diagonal( - rotated_jisih, - self.n_bins + double_index + 1) + anti_diagonal = np.diagonal(rotated_jisih, -self.n_bins + double_index + 1) - right_padding = jisih_diag_cums.shape[1] - \ - len(anti_diagonal) - self._max_change_index + right_padding = ( + jisih_diag_cums.shape[1] - len(anti_diagonal) - self._max_change_index + ) cumulated_diagonal = np.cumsum(anti_diagonal) padded_cumulated_diagonal = np.pad( cumulated_diagonal, pad_width=(self._max_change_index, right_padding), - mode='constant', - constant_values=(0., cumulated_diagonal[-1])) + mode="constant", + constant_values=(0.0, cumulated_diagonal[-1]), + ) jisih_diag_cums[double_index] = padded_cumulated_diagonal @@ -1155,49 +1238,37 @@ def _get_dithered_isi(self, isi_to_dither): for start in range(sampling_rhythm): dithered_isi_indices = self._isi_to_index(dithered_isi) - for i in range(start, number_of_isis - 1, - sampling_rhythm): - step = self._get_dithering_step( - dithered_isi, - dithered_isi_indices, - i) + for i in range(start, number_of_isis - 1, sampling_rhythm): + step = self._get_dithering_step(dithered_isi, dithered_isi_indices, i) dithered_isi[i] += step dithered_isi[i + 1] -= step return dithered_isi - def _get_dithering_step(self, - dithered_isi, - dithered_isi_indices, - i): + def _get_dithering_step(self, dithered_isi, dithered_isi_indices, i): curr_isi_id = dithered_isi_indices[i] next_isi_id = dithered_isi_indices[i + 1] double_index = curr_isi_id + next_isi_id if double_index < self.n_bins: - if self.method == 'fast': - cum_dist_func = self._jisih_cumulatives[ - double_index] + if self.method == "fast": + cum_dist_func = self._jisih_cumulatives[double_index] compare_isi = self._index_to_isi(curr_isi_id + 1) else: - cum_dist_func = self._jisih_cumulatives[ - curr_isi_id][next_isi_id] + cum_dist_func = self._jisih_cumulatives[curr_isi_id][next_isi_id] compare_isi = self._max_change_isi - if cum_dist_func[-1] > 0.: + if cum_dist_func[-1] > 0.0: # when the method is 'fast', new_isi_id is where the current # ISI id should go to. 
new_isi_id = np.searchsorted(cum_dist_func, random.random()) - step = self._index_to_isi(new_isi_id)\ - - compare_isi + step = self._index_to_isi(new_isi_id) - compare_isi return step return self._uniform_dither_not_jisi_movable_spikes( - dithered_isi[i], - dithered_isi[i + 1]) + dithered_isi[i], dithered_isi[i + 1] + ) - def _uniform_dither_not_jisi_movable_spikes(self, - curr_isi, - next_isi): + def _uniform_dither_not_jisi_movable_spikes(self, curr_isi, next_isi): left_dither = min(curr_isi - self.refractory_period, self.dither) right_dither = min(next_isi - self.refractory_period, self.dither) step = random.random() * (right_dither + left_dither) - left_dither @@ -1234,28 +1305,34 @@ def trial_shifting(spiketrains, dither, n_surrogates=1): dither = dither.simplified.magnitude units = spiketrains[0].units - t_starts = [single_trial_st.t_start.simplified.magnitude - for single_trial_st in spiketrains] - t_stops = [single_trial_st.t_stop.simplified.magnitude - for single_trial_st in spiketrains] - sampling_rates = [single_trial_st.sampling_rate - for single_trial_st in spiketrains] - spiketrains = [single_trial_st.simplified.magnitude - for single_trial_st in spiketrains] - - surrogate_spiketrains = \ - _trial_shifting(spiketrains, dither, t_starts, t_stops, - n_surrogates) - - surrogate_spiketrains = \ - [[neo.SpikeTrain( - surrogate_spiketrain[trial_id] * pq.s, - t_start=t_starts[trial_id] * pq.s, - t_stop=t_stops[trial_id] * pq.s, - units=units, - sampling_rate=sampling_rates[trial_id]) - for trial_id in range(len(surrogate_spiketrain))] - for surrogate_spiketrain in surrogate_spiketrains] + t_starts = [ + single_trial_st.t_start.simplified.magnitude for single_trial_st in spiketrains + ] + t_stops = [ + single_trial_st.t_stop.simplified.magnitude for single_trial_st in spiketrains + ] + sampling_rates = [single_trial_st.sampling_rate for single_trial_st in spiketrains] + spiketrains = [ + single_trial_st.simplified.magnitude for single_trial_st in spiketrains + ] + + surrogate_spiketrains = _trial_shifting( + spiketrains, dither, t_starts, t_stops, n_surrogates + ) + + surrogate_spiketrains = [ + [ + neo.SpikeTrain( + surrogate_spiketrain[trial_id] * pq.s, + t_start=t_starts[trial_id] * pq.s, + t_stop=t_stops[trial_id] * pq.s, + units=units, + sampling_rate=sampling_rates[trial_id], + ) + for trial_id in range(len(surrogate_spiketrain)) + ] + for surrogate_spiketrain in surrogate_spiketrains + ] return surrogate_spiketrains @@ -1271,10 +1348,13 @@ def _trial_shifting(spiketrains, dither, t_starts, t_stops, n_surrogates): # looping over all trials for trial_id, single_trial_st in enumerate(copied_spiketrain): single_trial_st += dither * (2 * random.random() - 1) - single_trial_st = np.remainder( - single_trial_st - t_starts[trial_id], - t_stops[trial_id] - t_starts[trial_id] - ) + t_starts[trial_id] + single_trial_st = ( + np.remainder( + single_trial_st - t_starts[trial_id], + t_stops[trial_id] - t_starts[trial_id], + ) + + t_starts[trial_id] + ) single_trial_st.sort() surrogate_spiketrain.append(single_trial_st) @@ -1284,7 +1364,8 @@ def _trial_shifting(spiketrains, dither, t_starts, t_stops, n_surrogates): def _trial_shifting_of_concatenated_spiketrain( - spiketrain, dither, trial_length, trial_separation, n_surrogates=1): + spiketrain, dither, trial_length, trial_separation, n_surrogates=1 +): """ Generates surrogates of a spike train by trial shifting. 
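A minimal usage sketch for the public trial_shifting function reformatted above, which takes one spike train per trial and shifts every trial independently, wrapping spikes around the trial borders; the trial count, rate and dither are illustration values.

import quantities as pq
from elephant.spike_train_generation import StationaryPoissonProcess
from elephant.spike_train_surrogates import trial_shifting

# one spike train per trial, all with an identical 2 s duration here
trials = [
    StationaryPoissonProcess(rate=20 * pq.Hz, t_stop=2 * pq.s).generate_spiketrain()
    for _ in range(10)
]
surrogate_trials = trial_shifting(trials, dither=30 * pq.ms, n_surrogates=5)
# surrogate_trials[k] is the list of 10 shifted trials of surrogate k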
@@ -1322,30 +1403,37 @@ def _trial_shifting_of_concatenated_spiketrain( trial_separation = trial_separation.simplified.magnitude dither = dither.simplified.magnitude n_trials = int((t_stop - t_start) // (trial_length + trial_separation)) - t_starts = t_start + \ - np.arange(n_trials) * (trial_length + trial_separation) + t_starts = t_start + np.arange(n_trials) * (trial_length + trial_separation) t_stops = t_starts + trial_length spiketrains = spiketrain.simplified.magnitude - spiketrains = [spiketrains[(spiketrains >= t_starts[trial_id]) & - (spiketrains <= t_stops[trial_id])] - for trial_id in range(n_trials)] + spiketrains = [ + spiketrains[ + (spiketrains >= t_starts[trial_id]) & (spiketrains <= t_stops[trial_id]) + ] + for trial_id in range(n_trials) + ] surrogate_spiketrains = _trial_shifting( - spiketrains, dither, t_starts, t_stops, n_surrogates) - - surrogate_spiketrains = [neo.SpikeTrain( - np.hstack(surrogate_spiketrain) * pq.s, - t_start=t_start * pq.s, - t_stop=t_stop * pq.s, - units=units, - sampling_rate=spiketrain.sampling_rate) - for surrogate_spiketrain in surrogate_spiketrains] + spiketrains, dither, t_starts, t_stops, n_surrogates + ) + + surrogate_spiketrains = [ + neo.SpikeTrain( + np.hstack(surrogate_spiketrain) * pq.s, + t_start=t_start * pq.s, + t_stop=t_stop * pq.s, + units=units, + sampling_rate=spiketrain.sampling_rate, + ) + for surrogate_spiketrain in surrogate_spiketrains + ] return surrogate_spiketrains -@deprecated_alias(n='n_surrogates', surr_method='method') -def surrogates(spiketrain, n_surrogates=1, method='dither_spike_train', - dt=None, **kwargs): +@deprecated_alias(n="n_surrogates", surr_method="method") +def surrogates( + spiketrain, n_surrogates=1, method="dither_spike_train", dt=None, **kwargs +): """ Generates surrogates of a `spiketrain` by a desired generation method. 
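Finally, a sketch of the surrogates dispatcher whose body follows. The first two calls use only what is visible in this diff; the bin_shuffling call assumes that bin_size is forwarded to bin_shuffling as the truncated hunk below suggests, so treat it as illustrative rather than verified.

import quantities as pq
from elephant.spike_train_generation import StationaryPoissonProcess
from elephant.spike_train_surrogates import surrogates

spiketrain = StationaryPoissonProcess(
    rate=30 * pq.Hz, t_stop=5 * pq.s
).generate_spiketrain()

# dt is the generic time parameter (dither width, shift or window) depending
# on the selected method; shuffle_isis needs no dt at all
dithered = surrogates(
    spiketrain, n_surrogates=10, method="dither_spikes", dt=25 * pq.ms
)
shuffled = surrogates(spiketrain, n_surrogates=10, method="shuffle_isis")
bin_shuffled = surrogates(
    spiketrain, n_surrogates=10, method="bin_shuffling",
    dt=25 * pq.ms, bin_size=5 * pq.ms,
)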
@@ -1418,59 +1506,71 @@ def surrogates(spiketrain, n_surrogates=1, method='dither_spike_train', if isinstance(spiketrain, list): if not isinstance(spiketrain[0], neo.SpikeTrain): - raise TypeError('spiketrain must be an instance neo.SpikeTrain or' - ' a list of neo.SpikeTrain') + raise TypeError( + "spiketrain must be an instance neo.SpikeTrain or" + " a list of neo.SpikeTrain" + ) elif not isinstance(spiketrain, neo.SpikeTrain): - raise TypeError('spiketrain must be an instance neo.SpikeTrain or' - ' a list of neo.SpikeTrain') + raise TypeError( + "spiketrain must be an instance neo.SpikeTrain or" + " a list of neo.SpikeTrain" + ) if method == "dither_spikes_with_refractory_period": - warnings.warn("'dither_spikes_with_refractory_period' is deprecated " - "in favor of 'dither_spikes'", DeprecationWarning) + warnings.warn( + "'dither_spikes_with_refractory_period' is deprecated " + "in favor of 'dither_spikes'", + DeprecationWarning, + ) # Define the surrogate function to use, depending on the specified method surrogate_types = { - 'dither_spike_train': dither_spike_train, - 'dither_spikes': dither_spikes, - 'dither_spikes_with_refractory_period': dither_spikes, - 'jitter_spikes': jitter_spikes, - 'randomise_spikes': randomise_spikes, - 'shuffle_isis': shuffle_isis, - 'bin_shuffling': bin_shuffling, - 'trial_shifting': trial_shifting, - 'joint_isi_dithering': lambda n: JointISI( - spiketrain, **kwargs).dithering(n), - 'isi_dithering': lambda n: JointISI( - spiketrain, isi_dithering=True, **kwargs).dithering(n) + "dither_spike_train": dither_spike_train, + "dither_spikes": dither_spikes, + "dither_spikes_with_refractory_period": dither_spikes, + "jitter_spikes": jitter_spikes, + "randomise_spikes": randomise_spikes, + "shuffle_isis": shuffle_isis, + "bin_shuffling": bin_shuffling, + "trial_shifting": trial_shifting, + "joint_isi_dithering": lambda n: JointISI(spiketrain, **kwargs).dithering(n), + "isi_dithering": lambda n: JointISI( + spiketrain, isi_dithering=True, **kwargs + ).dithering(n), } if method not in surrogate_types.keys(): - raise ValueError("Specified surrogate method ('{}') " - "is not valid".format(method)) + raise ValueError( + "Specified surrogate method ('{}') " "is not valid".format(method) + ) method = surrogate_types[method] if dt is None and method not in (randomise_spikes, shuffle_isis): - raise ValueError(f"'{method.__name__}' method requires 'dt' parameter " - f"to be set") + raise ValueError( + f"'{method.__name__}' method requires 'dt' parameter " f"to be set" + ) if method in (dither_spike_train, dither_spikes): - return method( - spiketrain, dt, n_surrogates=n_surrogates, **kwargs) + return method(spiketrain, dt, n_surrogates=n_surrogates, **kwargs) if method in (randomise_spikes, shuffle_isis): return method(spiketrain, n_surrogates=n_surrogates, **kwargs) if method is jitter_spikes: return method(spiketrain, dt, n_surrogates=n_surrogates) if method is trial_shifting: if isinstance(spiketrain, list): - return method( - spiketrain, dither=dt, n_surrogates=n_surrogates) + return method(spiketrain, dither=dt, n_surrogates=n_surrogates) return _trial_shifting_of_concatenated_spiketrain( - spiketrain, dither=dt, n_surrogates=n_surrogates, **kwargs) + spiketrain, dither=dt, n_surrogates=n_surrogates, **kwargs + ) if method is bin_shuffling: max_displacement = int( - dt.simplified.magnitude / kwargs['bin_size'].simplified.magnitude) + dt.simplified.magnitude / kwargs["bin_size"].simplified.magnitude + ) return method( - spiketrain, max_displacement=max_displacement, - 
bin_size=kwargs['bin_size'], n_surrogates=n_surrogates) + spiketrain, + max_displacement=max_displacement, + bin_size=kwargs["bin_size"], + n_surrogates=n_surrogates, + ) # surr_method is 'joint_isi_dithering' or isi_dithering: return method(n_surrogates) diff --git a/elephant/spike_train_synchrony.py b/elephant/spike_train_synchrony.py index 946a24ae2..05894adb8 100644 --- a/elephant/spike_train_synchrony.py +++ b/elephant/spike_train_synchrony.py @@ -16,6 +16,7 @@ :copyright: Copyright 2014-2024 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. """ + from __future__ import division, print_function, unicode_literals import warnings @@ -29,15 +30,12 @@ from elephant.statistics import Complexity from elephant.utils import is_time_quantity, check_same_units -SpikeContrastTrace = namedtuple("SpikeContrastTrace", ( - "contrast", "active_spiketrains", "synchrony", "bin_size")) +SpikeContrastTrace = namedtuple( + "SpikeContrastTrace", ("contrast", "active_spiketrains", "synchrony", "bin_size") +) -__all__ = [ - "SpikeContrastTrace", - "spike_contrast", - "Synchrotool" -] +__all__ = ["SpikeContrastTrace", "spike_contrast", "Synchrotool"] def _get_theta_and_n_per_bin(spiketrains, t_start, t_stop, bin_size): @@ -48,10 +46,9 @@ def _get_theta_and_n_per_bin(spiketrains, t_start, t_stop, bin_size): bin_step = bin_size / 2 edges = np.arange(t_start, t_stop + bin_step, bin_step) # Calculate histogram for every spike train - histogram = np.vstack([ - _binning_half_overlap(st, edges=edges) - for st in spiketrains - ]) + histogram = np.vstack( + [_binning_half_overlap(st, edges=edges) for st in spiketrains] + ) # Amount of spikes per bin theta = histogram.sum(axis=0) # Amount of active spike trains per bin @@ -69,9 +66,14 @@ def _binning_half_overlap(spiketrain, edges): return histogram -def spike_contrast(spiketrains, t_start=None, t_stop=None, - min_bin=10 * pq.ms, bin_shrink_factor=0.9, - return_trace=False): +def spike_contrast( + spiketrains, + t_start=None, + t_stop=None, + min_bin=10 * pq.ms, + bin_shrink_factor=0.9, + return_trace=False, +): """ Calculates the synchrony of spike trains, according to :cite:`synchrony-Ciba18_136`. The spike trains can have different lengths. @@ -156,12 +158,14 @@ def spike_contrast(spiketrains, t_start=None, t_stop=None, 0.419 """ - if not 0. < bin_shrink_factor < 1.: - raise ValueError(f"'bin_shrink_factor' ({bin_shrink_factor}) must be " - "in range (0, 1).") + if not 0.0 < bin_shrink_factor < 1.0: + raise ValueError( + f"'bin_shrink_factor' ({bin_shrink_factor}) must be " "in range (0, 1)." + ) if not len(spiketrains) > 1: - raise ValueError("Spike contrast measure requires more than 1 input " - "spiketrain.") + raise ValueError( + "Spike contrast measure requires more than 1 input " "spiketrain." 
+ ) check_same_units(spiketrains, object_type=neo.SpikeTrain) if not is_time_quantity(t_start, t_stop, allow_none=True): raise TypeError("'t_start' and 't_stop' must be time quantities.") @@ -180,8 +184,9 @@ def spike_contrast(spiketrains, t_start=None, t_stop=None, t_stop = t_stop.rescale(units).item() min_bin = min_bin.rescale(units).item() - spiketrains = [times[(times >= t_start) & (times <= t_stop)] - for times in spiketrains] + spiketrains = [ + times[(times >= t_start) & (times <= t_stop)] for times in spiketrains + ] n_spiketrains = len(spiketrains) n_spikes_total = sum(map(len, spiketrains)) @@ -208,14 +213,12 @@ def spike_contrast(spiketrains, t_start=None, t_stop=None, while bin_size >= bin_min: bin_sizes.append(bin_size) # Calculate Theta and n - theta_k, n_k = _get_theta_and_n_per_bin(spiketrains, - t_start=t_start, - t_stop=t_stop, - bin_size=bin_size) + theta_k, n_k = _get_theta_and_n_per_bin( + spiketrains, t_start=t_start, t_stop=t_stop, bin_size=bin_size + ) # calculate synchrony_curve = contrast * active_st - active_st = (np.sum(n_k * theta_k) / np.sum(theta_k) - 1) / ( - n_spiketrains - 1) + active_st = (np.sum(n_k * theta_k) / np.sum(theta_k) - 1) / (n_spiketrains - 1) contrast = np.sum(np.abs(np.diff(theta_k))) / (2 * n_spikes_total) # Contrast: sum(|derivation|) / (2*#Spikes) synchrony = contrast * active_st @@ -261,23 +264,27 @@ class Synchrotool(Complexity): """ - def __init__(self, spiketrains, - sampling_rate, - bin_size=None, - binary=True, - spread=0, - tolerance=1e-8): - + def __init__( + self, + spiketrains, + sampling_rate, + bin_size=None, + binary=True, + spread=0, + tolerance=1e-8, + ): self.annotated = False - super(Synchrotool, self).__init__(spiketrains=spiketrains, - bin_size=bin_size, - sampling_rate=sampling_rate, - binary=binary, - spread=spread, - tolerance=tolerance) + super(Synchrotool, self).__init__( + spiketrains=spiketrains, + bin_size=bin_size, + sampling_rate=sampling_rate, + binary=binary, + spread=spread, + tolerance=tolerance, + ) - def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): + def delete_synchrofacts(self, threshold, in_place=False, mode="delete"): """ Delete or extract synchronous spiking events. @@ -331,13 +338,16 @@ def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): if not self.annotated: self.annotate_synchrofacts() - if mode not in ['delete', 'extract']: - raise ValueError(f"Invalid mode '{mode}'. Valid modes are: " - f"'delete', 'extract'") + if mode not in ["delete", "extract"]: + raise ValueError( + f"Invalid mode '{mode}'. Valid modes are: " f"'delete', 'extract'" + ) if threshold <= 1: - raise ValueError('A deletion threshold <= 1 would result ' - 'in the deletion of all spikes.') + raise ValueError( + "A deletion threshold <= 1 would result " + "in the deletion of all spikes." 
+ ) if in_place: spiketrain_list = self.input_spiketrains @@ -345,8 +355,8 @@ def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): spiketrain_list = deepcopy(self.input_spiketrains) for idx, st in enumerate(spiketrain_list): - mask = st.array_annotations['complexity'] < threshold - if mode == 'extract': + mask = st.array_annotations["complexity"] < threshold + if mode == "extract": mask = np.invert(mask) new_st = st[mask] if in_place and st.segment is not None: @@ -354,29 +364,28 @@ def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): try: # replace link to spiketrain in segment - new_index = self._get_spiketrain_index( - segment.spiketrains, st) + new_index = self._get_spiketrain_index(segment.spiketrains, st) segment.spiketrains[new_index] = new_st except ValueError: # st is not in this segment even though it points to it - warnings.warn(f"The SpikeTrain at index {idx} of the " - "input list spiketrains has a " - "unidirectional uplink to a segment in " - "whose segment.spiketrains list it does not " - "appear. Only the spiketrains in the input " - "list will be replaced. You can suppress " - "this warning by setting " - "spiketrain.segment=None for the input " - "spiketrains.") + warnings.warn( + f"The SpikeTrain at index {idx} of the " + "input list spiketrains has a " + "unidirectional uplink to a segment in " + "whose segment.spiketrains list it does not " + "appear. Only the spiketrains in the input " + "list will be replaced. You can suppress " + "this warning by setting " + "spiketrain.segment=None for the input " + "spiketrains." + ) block = segment.block if block is not None: # replace link to spiketrain in groups for group in block.groups: try: - idx = self._get_spiketrain_index( - group.spiketrains, - st) + idx = self._get_spiketrain_index(group.spiketrains, st) except ValueError: # st is not in this group, move to next group continue @@ -392,21 +401,20 @@ def annotate_synchrofacts(self): Annotate the complexity of each spike in the ``self.epoch.array_annotations`` *in-place*. 
""" - epoch_complexities = self.epoch.array_annotations['complexity'] + epoch_complexities = self.epoch.array_annotations["complexity"] right_edges = ( self.epoch.times.magnitude.flatten() - + self.epoch.durations.rescale( - self.epoch.times.units).magnitude.flatten() + + self.epoch.durations.rescale(self.epoch.times.units).magnitude.flatten() ) for idx, st in enumerate(self.input_spiketrains): - # all indices of spikes that are within the half-open intervals # defined by the boundaries # note that every second entry in boundaries is an upper boundary spike_to_epoch_idx = np.searchsorted( right_edges, - st.times.rescale(self.epoch.times.units).magnitude.flatten()) + st.times.rescale(self.epoch.times.units).magnitude.flatten(), + ) complexity_per_spike = epoch_complexities[spike_to_epoch_idx] st.array_annotate(complexity=complexity_per_spike) diff --git a/elephant/sta.py b/elephant/sta.py index aa2657059..8b602a6dc 100644 --- a/elephant/sta.py +++ b/elephant/sta.py @@ -24,10 +24,7 @@ from .conversion import BinnedSpikeTrain -__all__ = [ - "spike_triggered_average", - "spike_field_coherence" -] +__all__ = ["spike_triggered_average", "spike_field_coherence"] def spike_triggered_average(signal, spiketrains, window): @@ -81,32 +78,40 @@ def spike_triggered_average(signal, spiketrains, window): # window_stoptime: time to specify the stop time of the averaging # interval relative to a spike window_starttime, window_stoptime = window - if not (isinstance(window_starttime, pq.quantity.Quantity) and - window_starttime.dimensionality.simplified == - pq.Quantity(1, "s").dimensionality): - raise TypeError("The start time of the window (window[0]) " - "must be a time quantity.") - if not (isinstance(window_stoptime, pq.quantity.Quantity) and - window_stoptime.dimensionality.simplified == - pq.Quantity(1, "s").dimensionality): - raise TypeError("The stop time of the window (window[1]) " - "must be a time quantity.") + if not ( + isinstance(window_starttime, pq.quantity.Quantity) + and window_starttime.dimensionality.simplified + == pq.Quantity(1, "s").dimensionality + ): + raise TypeError( + "The start time of the window (window[0]) " "must be a time quantity." + ) + if not ( + isinstance(window_stoptime, pq.quantity.Quantity) + and window_stoptime.dimensionality.simplified + == pq.Quantity(1, "s").dimensionality + ): + raise TypeError( + "The stop time of the window (window[1]) " "must be a time quantity." + ) if window_stoptime <= window_starttime: - raise ValueError("The start time of the window (window[0]) must be " - "earlier than the stop time of the window (window[1]).") + raise ValueError( + "The start time of the window (window[0]) must be " + "earlier than the stop time of the window (window[1])." + ) # checks on signal if not isinstance(signal, AnalogSignal): - raise TypeError( - "Signal must be an AnalogSignal, not %s." % type(signal)) + raise TypeError("Signal must be an AnalogSignal, not %s." % type(signal)) if len(signal.shape) > 1: # num_signals: number of analog signals num_signals = signal.shape[1] else: raise ValueError("Empty analog signal, hence no averaging possible.") if window_stoptime - window_starttime > signal.t_stop - signal.t_start: - raise ValueError("The chosen time window is larger than the " - "time duration of the signal.") + raise ValueError( + "The chosen time window is larger than the " "time duration of the signal." 
+ ) # spiketrains type check if isinstance(spiketrains, (np.ndarray, SpikeTrain)): @@ -116,11 +121,13 @@ def spike_triggered_average(signal, spiketrains, window): if not isinstance(st, (np.ndarray, SpikeTrain)): raise TypeError( "spiketrains must be a SpikeTrain, a numpy ndarray, or a " - "list of one of those, not %s." % type(spiketrains)) + "list of one of those, not %s." % type(spiketrains) + ) else: raise TypeError( "spiketrains must be a SpikeTrain, a numpy ndarray, or a list of " - "one of those, not %s." % type(spiketrains)) + "one of those, not %s." % type(spiketrains) + ) # multiplying spiketrain in case only a single spiketrain is given if len(spiketrains) == 1 and num_signals != 1: @@ -131,28 +138,35 @@ def spike_triggered_average(signal, spiketrains, window): # checking for matching numbers of signals and spiketrains if num_signals != len(spiketrains): - raise ValueError( - "The number of signals and spiketrains has to be the same.") + raise ValueError("The number of signals and spiketrains has to be the same.") # checking the times of signal and spiketrains for i in range(num_signals): if spiketrains[i].t_start < signal.t_start: raise ValueError( "The spiketrain indexed by %i starts earlier than " - "the analog signal." % i) + "the analog signal." % i + ) if spiketrains[i].t_stop > signal.t_stop: raise ValueError( "The spiketrain indexed by %i stops later than " - "the analog signal." % i) + "the analog signal." % i + ) # *** Main algorithm: *** # window_bins: number of bins of the chosen averaging interval - window_bins = int(np.ceil(((window_stoptime - window_starttime) * - signal.sampling_rate).simplified)) + window_bins = int( + np.ceil( + ((window_stoptime - window_starttime) * signal.sampling_rate).simplified + ) + ) # result_sta: array containing finally the spike-triggered averaged signal - result_sta = AnalogSignal(np.zeros((window_bins, num_signals)), - sampling_rate=signal.sampling_rate, units=signal.units) + result_sta = AnalogSignal( + np.zeros((window_bins, num_signals)), + sampling_rate=signal.sampling_rate, + units=signal.units, + ) # setting of correct times of the spike-triggered average # relative to the spike result_sta.t_start = window_starttime @@ -164,15 +178,22 @@ def spike_triggered_average(signal, spiketrains, window): # summing over all respective signal intervals around spiketimes for spiketime in spiketrains[i]: # checks for sufficient signal data around spiketime - if (spiketime + window_starttime >= signal.t_start and - spiketime + window_stoptime <= signal.t_stop): + if ( + spiketime + window_starttime >= signal.t_start + and spiketime + window_stoptime <= signal.t_stop + ): # calculating the startbin in the analog signal of the # averaging window for spike - startbin = int(np.floor(((spiketime + window_starttime - - signal.t_start) * signal.sampling_rate).simplified)) + startbin = int( + np.floor( + ( + (spiketime + window_starttime - signal.t_start) + * signal.sampling_rate + ).simplified + ) + ) # adds the signal in selected interval relative to the spike - result_sta[:, i] += signal[ - startbin: startbin + window_bins, i] + result_sta[:, i] += signal[startbin : startbin + window_bins, i] # counting of the used spikes used_spikes[i] += 1 else: @@ -185,8 +206,7 @@ def spike_triggered_average(signal, spiketrains, window): total_used_spikes += used_spikes[i] if total_used_spikes == 0: - warnings.warn( - "No spike at all was either found or used for averaging") + warnings.warn("No spike at all was either found or used for averaging") 
result_sta.annotate(used_spikes=used_spikes, unused_spikes=unused_spikes) return result_sta @@ -268,23 +288,25 @@ def spike_field_coherence(signal, spiketrain, **kwargs): """ - if not hasattr(scipy.signal, 'coherence'): - raise AttributeError('scipy.signal.coherence is not available. The sfc ' - 'function uses scipy.signal.coherence for ' - 'the coherence calculation. This function is ' - 'available for scipy version 0.16 or newer. ' - 'Please update you scipy version.') + if not hasattr(scipy.signal, "coherence"): + raise AttributeError( + "scipy.signal.coherence is not available. The sfc " + "function uses scipy.signal.coherence for " + "the coherence calculation. This function is " + "available for scipy version 0.16 or newer. " + "Please update you scipy version." + ) # spiketrains type check if not isinstance(spiketrain, (SpikeTrain, BinnedSpikeTrain)): raise TypeError( "spiketrain must be of type SpikeTrain or BinnedSpikeTrain, " - "not %s." % type(spiketrain)) + "not %s." % type(spiketrain) + ) # checks on analogsignal if not isinstance(signal, AnalogSignal): - raise TypeError( - "Signal must be an AnalogSignal, not %s." % type(signal)) + raise TypeError("Signal must be an AnalogSignal, not %s." % type(signal)) if len(signal.shape) > 1: # num_signals: number of individual traces in the analog signal num_signals = signal.shape[1] @@ -296,22 +318,20 @@ def spike_field_coherence(signal, spiketrain, **kwargs): # bin spiketrain if necessary if isinstance(spiketrain, SpikeTrain): - spiketrain = BinnedSpikeTrain( - spiketrain, bin_size=signal.sampling_period) + spiketrain = BinnedSpikeTrain(spiketrain, bin_size=signal.sampling_period) # check the start and stop times of signal and spike trains if spiketrain.t_start < signal.t_start: - raise ValueError( - "The spiketrain starts earlier than the analog signal.") + raise ValueError("The spiketrain starts earlier than the analog signal.") if spiketrain.t_stop > signal.t_stop: - raise ValueError( - "The spiketrain stops later than the analog signal.") + raise ValueError("The spiketrain stops later than the analog signal.") # check equal time resolution for both signals if spiketrain.bin_size != signal.sampling_period: raise ValueError( "The spiketrain and signal must have a " - "common sampling frequency / bin_size") + "common sampling frequency / bin_size" + ) # calculate how many bins to add on the left of the binned spike train delta_t = spiketrain.t_start - signal.t_start @@ -324,13 +344,20 @@ def spike_field_coherence(signal, spiketrain, **kwargs): # duplicate spike trains spiketrain_array = np.zeros((1, len_signals)) spiketrain_array[0, left_edge:right_edge] = spiketrain.to_array() - spiketrains_array = np.repeat(spiketrain_array, repeats=num_signals, axis=0).transpose() + spiketrains_array = np.repeat( + spiketrain_array, repeats=num_signals, axis=0 + ).transpose() # calculate coherence frequencies, sfc = scipy.signal.coherence( - spiketrains_array, signal.magnitude, - fs=signal.sampling_rate.rescale('Hz').magnitude, - axis=0, **kwargs) - - return (pq.Quantity(sfc, units=pq.dimensionless), - pq.Quantity(frequencies, units=pq.Hz)) + spiketrains_array, + signal.magnitude, + fs=signal.sampling_rate.rescale("Hz").magnitude, + axis=0, + **kwargs, + ) + + return ( + pq.Quantity(sfc, units=pq.dimensionless), + pq.Quantity(frequencies, units=pq.Hz), + ) diff --git a/elephant/statistics.py b/elephant/statistics.py index 0ab389572..b251c650a 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -80,8 +80,12 @@ import elephant.kernels 
as kernels import elephant.trials from elephant.conversion import BinnedSpikeTrain -from elephant.utils import deprecated_alias, check_neo_consistency, \ - is_time_quantity, round_binning_errors +from elephant.utils import ( + deprecated_alias, + check_neo_consistency, + is_time_quantity, + round_binning_errors, +) # do not import unicode_literals # (quantities rescale does not work with unicodes) @@ -99,7 +103,7 @@ "complexity_pdf", "Complexity", "fftkernel", - "optimal_kernel_bandwidth" + "optimal_kernel_bandwidth", ] cv = scipy.stats.variation @@ -150,8 +154,9 @@ def isi(spiketrain, axis=-1): else: intervals = np.diff(spiketrain, axis=axis) if (intervals < 0).any(): - warnings.warn("ISI evaluated to negative values. " - "Please sort the input array.") + warnings.warn( + "ISI evaluated to negative values. " "Please sort the input array." + ) return intervals @@ -217,8 +222,12 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None): 0.4301075268817204 """ - if isinstance(spiketrain, neo.SpikeTrain) and t_start is None \ - and t_stop is None and axis is None: + if ( + isinstance(spiketrain, neo.SpikeTrain) + and t_start is None + and t_stop is None + and axis is None + ): # a faster approach for a typical use case n_spikes = len(spiketrain) time_interval = spiketrain.t_stop - spiketrain.t_start @@ -235,22 +244,24 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None): units = spiketrain.units if t_start is None: - t_start = getattr(spiketrain, 't_start', 0 * units) + t_start = getattr(spiketrain, "t_start", 0 * units) t_start = t_start.rescale(units).magnitude if t_stop is None: - t_stop = getattr(spiketrain, 't_stop', - np.max(spiketrain, axis=axis)) + t_stop = getattr(spiketrain, "t_stop", np.max(spiketrain, axis=axis)) t_stop = t_stop.rescale(units).magnitude # calculate as a numpy array - rates = mean_firing_rate(spiketrain.magnitude, t_start=t_start, - t_stop=t_stop, axis=axis) + rates = mean_firing_rate( + spiketrain.magnitude, t_start=t_start, t_stop=t_stop, axis=axis + ) - rates = pq.Quantity(rates, units=1. / units) + rates = pq.Quantity(rates, units=1.0 / units) elif isinstance(spiketrain, (np.ndarray, list, tuple)): if isinstance(t_start, pq.Quantity) or isinstance(t_stop, pq.Quantity): - raise TypeError("'t_start' and 't_stop' cannot be quantities if " - "'spiketrain' is not a Quantity.") + raise TypeError( + "'t_start' and 't_stop' cannot be quantities if " + "'spiketrain' is not a Quantity." + ) spiketrain = np.asarray(spiketrain) if len(spiketrain) == 0: raise ValueError("Empty input spiketrain.") @@ -261,12 +272,15 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None): time_interval = t_stop - t_start if axis and isinstance(t_stop, np.ndarray): t_stop = np.expand_dims(t_stop, axis) - rates = np.sum((spiketrain >= t_start) & (spiketrain <= t_stop), - axis=axis) / time_interval + rates = ( + np.sum((spiketrain >= t_start) & (spiketrain <= t_stop), axis=axis) + / time_interval + ) else: - raise TypeError("Invalid input spiketrain type: '{}'. Allowed: " - "neo.SpikeTrain, Quantity, ndarray". - format(type(spiketrain))) + raise TypeError( + "Invalid input spiketrain type: '{}'. 
Allowed: " + "neo.SpikeTrain, Quantity, ndarray".format(type(spiketrain)) + ) return rates @@ -339,15 +353,15 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): if all(isinstance(st, neo.SpikeTrain) for st in spiketrains): if not is_time_quantity(warn_tolerance): raise TypeError("'warn_tolerance' must be a time quantity.") - durations = [(st.t_stop - st.t_start).simplified.item() - for st in spiketrains] + durations = [(st.t_stop - st.t_start).simplified.item() for st in spiketrains] durations_min = min(durations) durations_max = max(durations) if durations_max - durations_min > warn_tolerance.simplified.item(): - warnings.warn("Fano factor calculated for spike trains of " - "different duration (minimum: {_min}s, maximum " - "{_max}s).".format(_min=durations_min, - _max=durations_max)) + warnings.warn( + "Fano factor calculated for spike trains of " + "different duration (minimum: {_min}s, maximum " + "{_max}s).".format(_min=durations_min, _max=durations_max) + ) fano = spike_counts.var() / spike_counts.mean() return fano @@ -356,24 +370,29 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): def __variation_check(v, with_nan): # ensure the input ia a vector if v.ndim != 1: - raise ValueError("The input must be a vector, not a {}-dim matrix.". - format(v.ndim)) + raise ValueError( + "The input must be a vector, not a {}-dim matrix.".format(v.ndim) + ) # ensure we have enough entries if v.size < 2: if with_nan: - warnings.warn("The input size is too small. Please provide" - "an input with more than 1 entry. Returning `NaN`" - "since the argument `with_nan` is `True`") + warnings.warn( + "The input size is too small. Please provide" + "an input with more than 1 entry. Returning `NaN`" + "since the argument `with_nan` is `True`" + ) return np.NaN - raise ValueError("Input size is too small. Please provide " - "an input with more than 1 entry. Set 'with_nan' " - "to True to replace the error by a warning.") + raise ValueError( + "Input size is too small. Please provide " + "an input with more than 1 entry. Set 'with_nan' " + "to True to replace the error by a warning." + ) return None -@deprecated_alias(v='time_intervals') +@deprecated_alias(v="time_intervals") def cv2(time_intervals, with_nan=False): r""" Calculate the measure of Cv2 for a sequence of time intervals between @@ -438,10 +457,10 @@ def cv2(time_intervals, with_nan=False): # calculate Cv2 and return result cv_i = np.diff(time_intervals) / (time_intervals[:-1] + time_intervals[1:]) - return 2. * np.mean(np.abs(cv_i)) + return 2.0 * np.mean(np.abs(cv_i)) -@deprecated_alias(v='time_intervals') +@deprecated_alias(v="time_intervals") def lv(time_intervals, with_nan=False): r""" Calculate the measure of local variation Lv for a sequence of time @@ -505,10 +524,10 @@ def lv(time_intervals, with_nan=False): return np_nan cv_i = np.diff(time_intervals) / (time_intervals[:-1] + time_intervals[1:]) - return 3. * np.mean(np.power(cv_i, 2)) + return 3.0 * np.mean(np.power(cv_i, 2)) -def lvr(time_intervals, R=5*pq.ms, with_nan=False): +def lvr(time_intervals, R=5 * pq.ms, with_nan=False): r""" Calculate the measure of revised local variation LvR for a sequence of time intervals between events :cite:`statistics-Shinomoto2009_e1000433`. 
@@ -572,19 +591,20 @@ def lvr(time_intervals, R=5*pq.ms, with_nan=False): 0.833907445980624 """ if isinstance(R, pq.Quantity): - R = R.rescale('ms').magnitude + R = R.rescale("ms").magnitude else: - warnings.warn('No units specified for R, assuming milliseconds (ms)') + warnings.warn("No units specified for R, assuming milliseconds (ms)") if R < 0: - raise ValueError('R must be >= 0') + raise ValueError("R must be >= 0") # check units of intervals if available if isinstance(time_intervals, pq.Quantity): - time_intervals = time_intervals.rescale('ms').magnitude + time_intervals = time_intervals.rescale("ms").magnitude else: - warnings.warn('No units specified for time_intervals,' - ' assuming milliseconds (ms)') + warnings.warn( + "No units specified for time_intervals," " assuming milliseconds (ms)" + ) # convert to array, cast to float time_intervals = np.asarray(time_intervals) @@ -596,15 +616,24 @@ def lvr(time_intervals, R=5*pq.ms, with_nan=False): t = time_intervals[:-1] + time_intervals[1:] frac1 = 4 * time_intervals[:-1] * time_intervals[1:] / t**2 frac2 = 4 * R / t - lvr = (3 / (N-1)) * np.sum((1-frac1) * (1+frac2)) + lvr = (3 / (N - 1)) * np.sum((1 - frac1) * (1 + frac2)) return lvr -@deprecated_alias(spiketrain='spiketrains') -def instantaneous_rate(spiketrains, sampling_period, kernel='auto', - cutoff=5.0, t_start=None, t_stop=None, trim=False, - center_kernel=True, border_correction=False, - pool_trials=False, pool_spike_trains=False): +@deprecated_alias(spiketrain="spiketrains") +def instantaneous_rate( + spiketrains, + sampling_period, + kernel="auto", + cutoff=5.0, + t_start=None, + t_stop=None, + trim=False, + center_kernel=True, + border_correction=False, + pool_trials=False, + pool_spike_trains=False, +): r""" Estimates instantaneous firing rate by kernel convolution. 
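A minimal usage sketch matching the instantaneous_rate signature above, with an explicit Gaussian kernel instead of the 'auto' bandwidth (illustrative spike train):

import neo
import quantities as pq
from elephant.statistics import instantaneous_rate
from elephant.kernels import GaussianKernel

st = neo.SpikeTrain([0.12, 0.45, 0.77, 1.31, 1.73] * pq.s, t_stop=2 * pq.s)
# rate estimate on a 10 ms grid, smoothed with a 50 ms Gaussian kernel;
# the result is a neo.AnalogSignal in Hz
rate = instantaneous_rate(
    st, sampling_period=10 * pq.ms, kernel=GaussianKernel(50 * pq.ms)
)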
@@ -811,48 +840,53 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto', """ if isinstance(spiketrains, elephant.trials.Trials): kwargs = { - 'kernel': kernel, - 'cutoff': cutoff, - 't_start': t_start, - 't_stop': t_stop, - 'trim': trim, - 'center_kernel': center_kernel, - 'border_correction': border_correction, - 'pool_trials': False, - 'pool_spike_trains': False, + "kernel": kernel, + "cutoff": cutoff, + "t_start": t_start, + "t_stop": t_stop, + "trim": trim, + "center_kernel": center_kernel, + "border_correction": border_correction, + "pool_trials": False, + "pool_spike_trains": False, } if pool_trials: list_of_lists_of_spiketrains = [ - spiketrains.get_spiketrains_from_trial_as_list( - trial_id=trial_no) - for trial_no in range(spiketrains.n_trials)] + spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no) + for trial_no in range(spiketrains.n_trials) + ] spiketrains_cross_trials = ( - [list_of_lists_of_spiketrains[trial_no][spiketrain_idx] - for trial_no in range(spiketrains.n_trials)] - for spiketrain_idx, spiketrain in - enumerate(list_of_lists_of_spiketrains[0])) - - rates_cross_trials = [instantaneous_rate(spiketrain, - sampling_period, - **kwargs) - for spiketrain in spiketrains_cross_trials] + [ + list_of_lists_of_spiketrains[trial_no][spiketrain_idx] + for trial_no in range(spiketrains.n_trials) + ] + for spiketrain_idx, spiketrain in enumerate( + list_of_lists_of_spiketrains[0] + ) + ) + + rates_cross_trials = [ + instantaneous_rate(spiketrain, sampling_period, **kwargs) + for spiketrain in spiketrains_cross_trials + ] average_rate_cross_trials = ( - np.mean(rates, axis=1) for rates in rates_cross_trials) + np.mean(rates, axis=1) for rates in rates_cross_trials + ) if pool_spike_trains: average_rate = np.mean(list(average_rate_cross_trials), axis=0) analog_signal = rates_cross_trials[0] - return (neo.AnalogSignal( - signal=average_rate, - sampling_period=analog_signal.sampling_period, - units=analog_signal.units, - t_start=analog_signal.t_start, - t_stop=analog_signal.t_stop, - kernel=analog_signal.annotations) - ) + return neo.AnalogSignal( + signal=average_rate, + sampling_period=analog_signal.sampling_period, + units=analog_signal.units, + t_start=analog_signal.t_start, + t_stop=analog_signal.t_stop, + kernel=analog_signal.annotations, + ) list_of_average_rates_cross_trial = neo.AnalogSignal( signal=np.array(list(average_rate_cross_trials)).transpose(), @@ -860,32 +894,44 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto', units=rates_cross_trials[0].units, t_start=rates_cross_trials[0].t_start, t_stop=rates_cross_trials[0].t_stop, - kernel=rates_cross_trials[0].annotations) + kernel=rates_cross_trials[0].annotations, + ) return list_of_average_rates_cross_trial if not pool_trials and not pool_spike_trains: - return [instantaneous_rate( - spiketrains.get_spiketrains_from_trial_as_list( - trial_id=trial_no), sampling_period, **kwargs) - for trial_no in range(spiketrains.n_trials)] + return [ + instantaneous_rate( + spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no), + sampling_period, + **kwargs, + ) + for trial_no in range(spiketrains.n_trials) + ] if not pool_trials and pool_spike_trains: - rates = [instantaneous_rate( - spiketrains.get_spiketrains_from_trial_as_list( - trial_id=trial_no), sampling_period, **kwargs) - for trial_no in range(spiketrains.n_trials)] + rates = [ + instantaneous_rate( + spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no), + sampling_period, + **kwargs, + ) + for trial_no 
in range(spiketrains.n_trials) + ] average_rates = (np.mean(rate, axis=1) for rate in rates) list_of_average_rates_over_spiketrains = [ - neo.AnalogSignal(signal=average_rate, - sampling_period=analog_signal.sampling_period, - units=analog_signal.units, - t_start=analog_signal.t_start, - t_stop=analog_signal.t_stop, - kernel=analog_signal.annotations) - for average_rate, analog_signal in zip(average_rates, rates)] + neo.AnalogSignal( + signal=average_rate, + sampling_period=analog_signal.sampling_period, + units=analog_signal.units, + t_start=analog_signal.t_start, + t_stop=analog_signal.t_stop, + kernel=analog_signal.annotations, + ) + for average_rate, analog_signal in zip(average_rates, rates) + ] return list_of_average_rates_over_spiketrains @@ -893,39 +939,50 @@ def optimal_kernel(st): width_sigma = None if len(st) > 0: width_sigma = optimal_kernel_bandwidth( - st.magnitude, times=None, bootstrap=False)['optw'] + st.magnitude, times=None, bootstrap=False + )["optw"] if width_sigma is None: - raise ValueError("Unable to calculate optimal kernel width for " - "instantaneous rate from input data.") + raise ValueError( + "Unable to calculate optimal kernel width for " + "instantaneous rate from input data." + ) return kernels.GaussianKernel(width_sigma * st.units) - if border_correction and not \ - (kernel == 'auto' or isinstance(kernel, kernels.GaussianKernel)): + if border_correction and not ( + kernel == "auto" or isinstance(kernel, kernels.GaussianKernel) + ): raise ValueError( - 'The border correction is only implemented' - ' for Gaussian kernels.') + "The border correction is only implemented" " for Gaussian kernels." + ) if isinstance(spiketrains, neo.SpikeTrain): - if kernel == 'auto': + if kernel == "auto": kernel = optimal_kernel(spiketrains) spiketrains = [spiketrains] if not all([isinstance(elem, neo.SpikeTrain) for elem in spiketrains]): - raise TypeError(f"'spiketrains' must be a list of neo.SpikeTrain's or " - f"a single neo.SpikeTrain. Found: {type(spiketrains)}") + raise TypeError( + f"'spiketrains' must be a list of neo.SpikeTrain's or " + f"a single neo.SpikeTrain. Found: {type(spiketrains)}" + ) if not is_time_quantity(sampling_period): - raise TypeError(f"The 'sampling_period' must be a time Quantity." - f"Found: {type(sampling_period)}") + raise TypeError( + f"The 'sampling_period' must be a time Quantity." + f"Found: {type(sampling_period)}" + ) if sampling_period.magnitude < 0: - raise ValueError(f"The 'sampling_period' ({sampling_period}) " - f"must be non-negative.") + raise ValueError( + f"The 'sampling_period' ({sampling_period}) " f"must be non-negative." + ) - if not (isinstance(kernel, kernels.Kernel) or kernel == 'auto'): - raise TypeError(f"'kernel' must be instance of class " - f"elephant.kernels.Kernel or string 'auto'. Found: " - f"{type(kernel)}, value {str(kernel)}") + if not (isinstance(kernel, kernels.Kernel) or kernel == "auto"): + raise TypeError( + f"'kernel' must be instance of class " + f"elephant.kernels.Kernel or string 'auto'. 
Found: " + f"{type(kernel)}, value {str(kernel)}" + ) if not isinstance(cutoff, (float, int)): raise TypeError("'cutoff' must be float or integer") @@ -938,17 +995,19 @@ def optimal_kernel(st): if not isinstance(trim, bool): raise TypeError("'trim' must be bool") - check_neo_consistency(spiketrains, - object_type=neo.SpikeTrain, - t_start=t_start, t_stop=t_stop) + check_neo_consistency( + spiketrains, object_type=neo.SpikeTrain, t_start=t_start, t_stop=t_stop + ) - if kernel == 'auto': + if kernel == "auto": if len(spiketrains) == 1: kernel = optimal_kernel(spiketrains[0]) else: - raise ValueError("Cannot estimate a kernel for a list of spike " - "trains. Please provide a kernel explicitly " - "rather than 'auto'.") + raise ValueError( + "Cannot estimate a kernel for a list of spike " + "trains. Please provide a kernel explicitly " + "rather than 'auto'." + ) if t_start is None: t_start = spiketrains[0].t_start @@ -961,24 +1020,23 @@ def optimal_kernel(st): # Calculate parameters for np.histogram n_bins = int(((t_stop - t_start) / sampling_period).simplified) - hist_range_end = t_start + n_bins * \ - sampling_period.rescale(spiketrains[0].units) + hist_range_end = t_start + n_bins * sampling_period.rescale(spiketrains[0].units) hist_range = (t_start.item(), hist_range_end.item()) # Preallocation histogram_arr = np.zeros((len(spiketrains), n_bins), dtype=np.float64) for i, st in enumerate(spiketrains): - histogram_arr[i], _ = np.histogram(st.magnitude, bins=n_bins, - range=hist_range) + histogram_arr[i], _ = np.histogram(st.magnitude, bins=n_bins, range=hist_range) histogram_arr = histogram_arr.T # make it (time, units) # Kernel if cutoff < kernel.min_cutoff: cutoff = kernel.min_cutoff - warnings.warn("The width of the kernel was adjusted to a minimally " - "allowed width.") + warnings.warn( + "The width of the kernel was adjusted to a minimally " "allowed width." + ) scaling_unit = pq.CompoundUnit(f"{sampling_period.rescale('s').item()}*s") cutoff_sigma = cutoff * kernel.sigma.rescale(scaling_unit).magnitude @@ -995,59 +1053,66 @@ def optimal_kernel(st): # `num=2 * t_arr_kernel_half + 1` is always odd. 
# (See Issue #360, https://github.com/NeuralEnsemble/elephant/issues/360) t_arr_kernel_half = math.ceil( - cutoff * (kernel.sigma / sampling_period).simplified.item()) + cutoff * (kernel.sigma / sampling_period).simplified.item() + ) t_arr_kernel_length = 2 * t_arr_kernel_half + 1 # Shift kernel using the calculated median - t_arr_kernel = np.linspace(start=-cutoff_sigma + median, - stop=cutoff_sigma + median, - num=t_arr_kernel_length, - endpoint=True) * scaling_unit + t_arr_kernel = ( + np.linspace( + start=-cutoff_sigma + median, + stop=cutoff_sigma + median, + num=t_arr_kernel_length, + endpoint=True, + ) + * scaling_unit + ) # Calculate the kernel values with t_arr - kernel_arr = np.expand_dims( - kernel(t_arr_kernel).rescale(pq.Hz).magnitude, axis=1) + kernel_arr = np.expand_dims(kernel(t_arr_kernel).rescale(pq.Hz).magnitude, axis=1) # Define mode for scipy.signal.fftconvolve if trim: - fft_mode = 'valid' + fft_mode = "valid" else: - fft_mode = 'same' + fft_mode = "same" - rate = scipy.signal.fftconvolve(histogram_arr, - kernel_arr, - mode=fft_mode) + rate = scipy.signal.fftconvolve(histogram_arr, kernel_arr, mode=fft_mode) # The convolution of non-negative vectors is non-negative rate = np.clip(rate, a_min=0, a_max=None, out=rate) # Adjust t_start and t_stop - if fft_mode == 'valid': + if fft_mode == "valid": median_id = kernel.median_index(t_arr_kernel) kernel_array_size = len(kernel_arr) t_start = t_start + median_id * scaling_unit t_stop = t_stop - (kernel_array_size - median_id) * scaling_unit - kernel_annotation = dict(type=type(kernel).__name__, - sigma=str(kernel.sigma), - invert=kernel.invert) + kernel_annotation = dict( + type=type(kernel).__name__, sigma=str(kernel.sigma), invert=kernel.invert + ) if isinstance(spiketrains, neo.core.spiketrainlist.SpikeTrainList) and ( - pool_spike_trains): + pool_spike_trains + ): rate = np.mean(rate, axis=1) - rate = neo.AnalogSignal(signal=rate, - sampling_period=sampling_period, - units=pq.Hz, t_start=t_start, t_stop=t_stop, - kernel=kernel_annotation) + rate = neo.AnalogSignal( + signal=rate, + sampling_period=sampling_period, + units=pq.Hz, + t_start=t_start, + t_stop=t_stop, + kernel=kernel_annotation, + ) if border_correction: sigma = kernel.sigma.simplified.magnitude times = rate.times.simplified.magnitude correction_factor = 2 / ( - erf((t_stop.simplified.magnitude - times) / ( - np.sqrt(2.) * sigma)) - - erf((t_start.simplified.magnitude - times) / ( - np.sqrt(2.) * sigma))) + erf((t_stop.simplified.magnitude - times) / (np.sqrt(2.0) * sigma)) + - erf((t_start.simplified.magnitude - times) / (np.sqrt(2.0) * sigma)) + ) rate *= correction_factor[:, None] @@ -1055,15 +1120,17 @@ def optimal_kernel(st): # ensure integral over firing rate yield the exact number of spikes for i, spiketrain in enumerate(spiketrains): if len(spiketrain) > 0: - rate[:, i] *= len(spiketrain) /\ - (np.mean(rate[:, i]).magnitude * duration) + rate[:, i] *= len(spiketrain) / ( + np.mean(rate[:, i]).magnitude * duration + ) return rate -@deprecated_alias(binsize='bin_size') -def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None, - output='counts', binary=False): +@deprecated_alias(binsize="bin_size") +def time_histogram( + spiketrains, bin_size, t_start=None, t_stop=None, output="counts", binary=False +): """ Time Histogram of a list of `neo.SpikeTrain` objects. 
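A minimal usage sketch of time_histogram as declared above (illustrative spike trains):

import neo
import quantities as pq
from elephant.statistics import time_histogram

sts = [
    neo.SpikeTrain([0.10, 0.70, 1.20] * pq.s, t_stop=2 * pq.s),
    neo.SpikeTrain([0.30, 0.90, 1.50] * pq.s, t_stop=2 * pq.s),
]
# population spike counts in 100 ms bins, returned as a single-channel neo.AnalogSignal
counts = time_histogram(sts, bin_size=100 * pq.ms, output="counts")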
@@ -1160,15 +1227,13 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None, """ # Bin the spike trains and sum across columns if binary: - binned_spiketrain = BinnedSpikeTrain(spiketrains, - t_start=t_start, - t_stop=t_stop, bin_size=bin_size - ).binarize(copy=False) + binned_spiketrain = BinnedSpikeTrain( + spiketrains, t_start=t_start, t_stop=t_stop, bin_size=bin_size + ).binarize(copy=False) else: - binned_spiketrain = BinnedSpikeTrain(spiketrains, - t_start=t_start, - t_stop=t_stop, bin_size=bin_size - ) + binned_spiketrain = BinnedSpikeTrain( + spiketrains, t_start=t_start, t_stop=t_stop, bin_size=bin_size + ) bin_hist: Union[int, ndarray] = binned_spiketrain.get_num_of_spikes(axis=0) # Flatten array @@ -1182,8 +1247,9 @@ def _counts() -> pq.Quantity: def _mean() -> pq.Quantity: # 'mean': mean spike counts per spike train. - return pq.Quantity(bin_hist / len(spiketrains), - units=pq.dimensionless, copy=False) + return pq.Quantity( + bin_hist / len(spiketrains), units=pq.dimensionless, copy=False + ) def _rate() -> pq.Quantity: # 'rate': mean spike rate per spike train. Like 'mean', but the @@ -1195,16 +1261,19 @@ def _rate() -> pq.Quantity: normalise_func = output_mapping.get(output) normalised_bin_hist = normalise_func() except TypeError: - raise ValueError(f'Parameter output ({output}) is not valid.') + raise ValueError(f"Parameter output ({output}) is not valid.") - return neo.AnalogSignal(signal=np.expand_dims(normalised_bin_hist, axis=1), - sampling_period=bin_size, - units=normalised_bin_hist.units, - t_start=binned_spiketrain.t_start, - normalization=output, copy=False) + return neo.AnalogSignal( + signal=np.expand_dims(normalised_bin_hist, axis=1), + sampling_period=bin_size, + units=normalised_bin_hist.units, + t_start=binned_spiketrain.t_start, + normalization=output, + copy=False, + ) -@deprecated_alias(binsize='bin_size') +@deprecated_alias(binsize="bin_size") def complexity_pdf(spiketrains, bin_size): """ Complexity Distribution of a list of `neo.SpikeTrain` objects @@ -1237,8 +1306,11 @@ def complexity_pdf(spiketrains, bin_size): -------- elephant.conversion.BinnedSpikeTrain """ - warnings.warn("'complexity_pdf' is deprecated in favor of the Complexity " - "class which has a 'pdf' method", DeprecationWarning) + warnings.warn( + "'complexity_pdf' is deprecated in favor of the Complexity " + "class which has a 'pdf' method", + DeprecationWarning, + ) complexity = Complexity(spiketrains, bin_size=bin_size) @@ -1418,20 +1490,22 @@ class Complexity(object): """ - def __init__(self, spiketrains, - sampling_rate=None, - bin_size=None, - binary=True, - spread=0, - tolerance=1e-8): - + def __init__( + self, + spiketrains, + sampling_rate=None, + bin_size=None, + binary=True, + spread=0, + tolerance=1e-8, + ): check_neo_consistency(spiketrains, object_type=neo.SpikeTrain) if bin_size is None and sampling_rate is None: - raise ValueError('No bin_size or sampling_rate was specified!') + raise ValueError("No bin_size or sampling_rate was specified!") if spread < 0: - raise ValueError('Spread must be >=0') + raise ValueError("Spread must be >=0") self.input_spiketrains = spiketrains self.t_start = spiketrains[0].t_start @@ -1446,13 +1520,13 @@ def __init__(self, spiketrains, self.bin_size = 1 / self.sampling_rate if spread == 0: - self.time_histogram, self.complexity_histogram = \ - self._histogram_no_spread() + self.time_histogram, self.complexity_histogram = self._histogram_no_spread() self.epoch = self._epoch_no_spread() else: self.epoch = self._epoch_with_spread() 
- self.time_histogram, self.complexity_histogram = \ + self.time_histogram, self.complexity_histogram = ( self._histogram_with_spread() + ) def pdf(self): """ @@ -1471,7 +1545,8 @@ def pdf(self): np.expand_dims(norm_hist, axis=1), units=pq.dimensionless, t_start=0 * pq.dimensionless, - sampling_period=1 * pq.dimensionless) + sampling_period=1 * pq.dimensionless, + ) return pdf def _histogram_no_spread(self): @@ -1480,9 +1555,9 @@ def _histogram_no_spread(self): """ # Computing the population histogram with parameter binary=True to # clip the spike trains before summing - time_hist = time_histogram(self.input_spiketrains, - self.bin_size, - binary=self.binary) + time_hist = time_histogram( + self.input_spiketrains, self.bin_size, binary=self.binary + ) time_hist_magnitude = time_hist.magnitude.flatten() @@ -1495,21 +1570,22 @@ def _histogram_with_spread(self): """ Calculate the complexity histogram and time histogram for `spread` > 0 """ - complexity_hist = np.bincount( - self.epoch.array_annotations['complexity']) + complexity_hist = np.bincount(self.epoch.array_annotations["complexity"]) num_bins = (self.t_stop - self.t_start).rescale( - self.bin_size.units).item() / self.bin_size.item() + self.bin_size.units + ).item() / self.bin_size.item() num_bins = round_binning_errors(num_bins, tolerance=self.tolerance) time_hist = np.zeros(num_bins, dtype=int) start_bins = (self.epoch.times - self.t_start).rescale( - self.bin_size.units).magnitude / self.bin_size.item() - stop_bins = (self.epoch.times + self.epoch.durations - self.t_start - ).rescale(self.bin_size.units - ).magnitude / self.bin_size.item() + self.bin_size.units + ).magnitude / self.bin_size.item() + stop_bins = (self.epoch.times + self.epoch.durations - self.t_start).rescale( + self.bin_size.units + ).magnitude / self.bin_size.item() if self.sampling_rate is not None: - shift = (.5 / self.sampling_rate / self.bin_size).simplified.item() + shift = (0.5 / self.sampling_rate / self.bin_size).simplified.item() # account for the first bin not being shifted in the epoch creation # if the shift would move it past t_start if self.epoch.times[0] == self.t_start: @@ -1522,17 +1598,19 @@ def _histogram_with_spread(self): stop_bins = round_binning_errors(stop_bins, tolerance=self.tolerance) for idx, (start, stop) in enumerate(zip(start_bins, stop_bins)): - time_hist[start:stop] = \ - self.epoch.array_annotations['complexity'][idx] + time_hist[start:stop] = self.epoch.array_annotations["complexity"][idx] time_hist = neo.AnalogSignal( signal=np.expand_dims(time_hist, axis=1), - sampling_period=self.bin_size, units=pq.dimensionless, - t_start=self.t_start) + sampling_period=self.bin_size, + units=pq.dimensionless, + t_start=self.t_start, + ) - empty_bins = (self.t_stop - self.t_start - self.epoch.durations.sum()) - empty_bins = empty_bins.rescale(self.bin_size.units - ).magnitude / self.bin_size.item() + empty_bins = self.t_stop - self.t_start - self.epoch.durations.sum() + empty_bins = ( + empty_bins.rescale(self.bin_size.units).magnitude / self.bin_size.item() + ) empty_bins = round_binning_errors(empty_bins, tolerance=self.tolerance) complexity_hist[0] = empty_bins @@ -1547,7 +1625,7 @@ def _epoch_no_spread(self): if self.sampling_rate: # ensure that spikes are not on the bin edges - bin_shift = .5 / self.sampling_rate + bin_shift = 0.5 / self.sampling_rate left_edges -= bin_shift # Ensure that an epoch does not start before the minimum t_start. 
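Putting the reformatted Complexity pieces together, a minimal sketch of the constructor (signature shown earlier in this file's diff) and the two results used above, pdf() and epoch (illustrative spike trains):

import neo
import quantities as pq
from elephant.statistics import Complexity

sts = [
    neo.SpikeTrain([0.100, 0.300, 0.700] * pq.s, t_stop=1 * pq.s),
    neo.SpikeTrain([0.102, 0.500, 0.700] * pq.s, t_stop=1 * pq.s),
]
# with only a sampling_rate given, the bin width is 1 / sampling_rate (here 1 ms)
cpx = Complexity(sts, sampling_rate=1000 * pq.Hz, spread=0)
pdf = cpx.pdf()    # complexity distribution as a neo.AnalogSignal
epoch = cpx.epoch  # neo.Epoch carrying per-event 'complexity' array annotations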
@@ -1556,83 +1634,93 @@ def _epoch_no_spread(self): left_edges[0] = self.t_start durations[0] -= bin_shift else: - warnings.warn('No sampling rate specified. ' - 'Note that using the complexity epoch to get ' - 'precise spike times can lead to rounding errors.') + warnings.warn( + "No sampling rate specified. " + "Note that using the complexity epoch to get " + "precise spike times can lead to rounding errors." + ) complexity = self.time_histogram.magnitude.flatten() complexity = complexity.astype(np.uint16) - epoch = neo.Epoch(left_edges, - durations=durations, - array_annotations={'complexity': complexity}) + epoch = neo.Epoch( + left_edges, + durations=durations, + array_annotations={"complexity": complexity}, + ) return epoch def _epoch_with_spread(self): """ Get an epoch object of the complexity distribution with `spread` > 0 """ - bst = conv.BinnedSpikeTrain(self.input_spiketrains, - bin_size=self.bin_size, - tolerance=self.tolerance) + bst = conv.BinnedSpikeTrain( + self.input_spiketrains, bin_size=self.bin_size, tolerance=self.tolerance + ) if self.binary: bst = bst.binarize(copy=False) bincount = bst.get_num_of_spikes(axis=0) nonzero_indices = np.nonzero(bincount)[0] - left_diff = np.diff(nonzero_indices, - prepend=-self.spread - 1) - right_diff = np.diff(nonzero_indices, - append=len(bincount) + self.spread + 1) + left_diff = np.diff(nonzero_indices, prepend=-self.spread - 1) + right_diff = np.diff(nonzero_indices, append=len(bincount) + self.spread + 1) # standalone bins (no merging required) - single_bin_indices = np.logical_and(left_diff > self.spread, - right_diff > self.spread) + single_bin_indices = np.logical_and( + left_diff > self.spread, right_diff > self.spread + ) single_bins = nonzero_indices[single_bin_indices] # bins separated by fewer than spread bins form clusters # that have to be merged - cluster_start_indices = np.logical_and(left_diff > self.spread, - right_diff <= self.spread) + cluster_start_indices = np.logical_and( + left_diff > self.spread, right_diff <= self.spread + ) cluster_starts = nonzero_indices[cluster_start_indices] - cluster_stop_indices = np.logical_and(left_diff <= self.spread, - right_diff > self.spread) + cluster_stop_indices = np.logical_and( + left_diff <= self.spread, right_diff > self.spread + ) cluster_stops = nonzero_indices[cluster_stop_indices] + 1 single_bin_complexities = bincount[single_bins] - cluster_complexities = [bincount[start:stop].sum() - for start, stop in zip(cluster_starts, - cluster_stops)] + cluster_complexities = [ + bincount[start:stop].sum() + for start, stop in zip(cluster_starts, cluster_stops) + ] # merge standalone bins and clusters and sort them combined_starts = np.concatenate((single_bins, cluster_starts)) combined_stops = np.concatenate((single_bins + 1, cluster_stops)) - combined_complexities = np.concatenate((single_bin_complexities, - cluster_complexities)) - sorting = np.argsort(combined_starts, kind='mergesort') + combined_complexities = np.concatenate( + (single_bin_complexities, cluster_complexities) + ) + sorting = np.argsort(combined_starts, kind="mergesort") left_edges = bst.bin_edges[combined_starts[sorting]] right_edges = bst.bin_edges[combined_stops[sorting]] complexities = combined_complexities[sorting].astype(np.uint16) if self.sampling_rate: # ensure that spikes are not on the bin edges - bin_shift = .5 / self.sampling_rate + bin_shift = 0.5 / self.sampling_rate left_edges -= bin_shift right_edges -= bin_shift else: - warnings.warn('No sampling rate specified. 
' - 'Note that using the complexity epoch to get ' - 'precise spike times can lead to rounding errors.') + warnings.warn( + "No sampling rate specified. " + "Note that using the complexity epoch to get " + "precise spike times can lead to rounding errors." + ) # Ensure that an epoch does not start before the minimum t_start. # Note: all spike trains share the same t_start and t_stop. left_edges[0] = max(self.t_start, left_edges[0]) - complexity_epoch = neo.Epoch(times=left_edges, - durations=right_edges - left_edges, - array_annotations={'complexity': - complexities}) + complexity_epoch = neo.Epoch( + times=left_edges, + durations=right_edges - left_edges, + array_annotations={"complexity": complexities}, + ) return complexity_epoch @@ -1642,7 +1730,7 @@ def nextpow2(x): Return the smallest integral power of 2 that is equal or larger than `x`. """ log2_n = math.ceil(math.log2(x)) - n = 2 ** log2_n + n = 2**log2_n return n @@ -1676,7 +1764,7 @@ def fftkernel(x, w): n = nextpow2(Lmax) X = np.fft.fft(x, n) f = np.arange(0, n, 1.0) / n - f = np.concatenate((-f[:int(n / 2)], f[int(n / 2):0:-1])) + f = np.concatenate((-f[: int(n / 2)], f[int(n / 2) : 0 : -1])) K = np.exp(-0.5 * (w * 2 * np.pi * f) ** 2) y = np.fft.ifft(X * K, n) y = y[:L].copy() @@ -1708,17 +1796,15 @@ def cost_function(x, N, w, dt): """ yh = np.abs(fftkernel(x, w / dt)) # density # formula for density - C = np.sum(yh ** 2) * dt - 2 * np.sum(yh * x) * \ - dt + 2 / np.sqrt(2 * np.pi) / w / N + C = np.sum(yh**2) * dt - 2 * np.sum(yh * x) * dt + 2 / np.sqrt(2 * np.pi) / w / N C = C * N * N # formula for rate # C = dt*sum( yh.^2 - 2*yh.*y_hist + 2/sqrt(2*pi)/w*y_hist ) return C, yh -@deprecated_alias(tin='times', w='bandwidth') -def optimal_kernel_bandwidth(spiketimes, times=None, bandwidth=None, - bootstrap=False): +@deprecated_alias(tin="times", w="bandwidth") +def optimal_kernel_bandwidth(spiketimes, times=None, bandwidth=None, bootstrap=False): """ Calculates optimal fixed kernel bandwidth :cite:`statistics-Shimazaki2010_171`, given as the standard deviation @@ -1776,21 +1862,22 @@ def optimal_kernel_bandwidth(spiketimes, times=None, bandwidth=None, isi = np.diff(spiketimes) isi = isi[isi > 0].copy() dt = np.min(isi) - times = np.linspace(np.min(spiketimes), - np.max(spiketimes), - min(int(time / dt + 0.5), - 1000)) # The 1000 seems somewhat arbitrary + times = np.linspace( + np.min(spiketimes), np.max(spiketimes), min(int(time / dt + 0.5), 1000) + ) # The 1000 seems somewhat arbitrary t = times else: time = np.max(times) - np.min(times) - spiketimes = spiketimes[(spiketimes >= np.min(times)) & - (spiketimes <= np.max(times))].copy() + spiketimes = spiketimes[ + (spiketimes >= np.min(times)) & (spiketimes <= np.max(times)) + ].copy() isi = np.diff(spiketimes) isi = isi[isi > 0].copy() dt = np.min(isi) if dt > np.min(np.diff(times)): - t = np.linspace(np.min(times), np.max(times), - min(int(time / dt + 0.5), 1000)) + t = np.linspace( + np.min(times), np.max(times), min(int(time / dt + 0.5), 1000) + ) else: t = times dt = np.min(np.diff(times)) @@ -1824,8 +1911,7 @@ def optimal_kernel_bandwidth(spiketimes, times=None, bandwidth=None, f1, y1 = cost_function(yhist, N, logexp(c1), dt) f2, y2 = cost_function(yhist, N, logexp(c2), dt) k = 0 - while (np.abs(b - a) > (tolerance * (np.abs(c1) + np.abs(c2)))) \ - and (k < imax): + while (np.abs(b - a) > (tolerance * (np.abs(c1) + np.abs(c2)))) and (k < imax): if f1 < f2: b = c2 c2 = c1 @@ -1857,8 +1943,7 @@ def optimal_kernel_bandwidth(spiketimes, times=None, bandwidth=None, for ii in 
range(nbs): idx = np.floor(np.random.rand(N) * N).astype(int) xb = spiketimes[idx] - y_histb, bins = np.histogram( - xb, np.r_[t - dt / 2, t[-1] + dt / 2]) / dt / N + y_histb, bins = np.histogram(xb, np.r_[t - dt / 2, t[-1] + dt / 2]) / dt / N yb_buf = fftkernel(y_histb, optw / dt).real yb_buf = yb_buf / np.sum(yb_buf * dt) yb[ii, :] = np.interp(times, t, yb_buf) @@ -1869,16 +1954,20 @@ def optimal_kernel_bandwidth(spiketimes, times=None, bandwidth=None, # Only perform interpolation if y could be calculated if y is not None: y = np.interp(times, t, y) - return {'y': y, - 't': times, - 'optw': optw, - 'w': bandwidth, - 'C': C, - 'confb95': confb95, - 'yb': yb} + return { + "y": y, + "t": times, + "optw": optw, + "w": bandwidth, + "C": C, + "confb95": confb95, + "yb": yb, + } def sskernel(*args, **kwargs): - warnings.warn("'sskernel' function is deprecated; " - "use 'optimal_kernel_bandwidth'", DeprecationWarning) + warnings.warn( + "'sskernel' function is deprecated; " "use 'optimal_kernel_bandwidth'", + DeprecationWarning, + ) return optimal_kernel_bandwidth(*args, **kwargs) diff --git a/elephant/test/make_spike_extraction_test_data.py b/elephant/test/make_spike_extraction_test_data.py index 4a72983d6..f23807faf 100644 --- a/elephant/test/make_spike_extraction_test_data.py +++ b/elephant/test/make_spike_extraction_test_data.py @@ -25,10 +25,10 @@ def main(): # pragma: no cover """ # Setup and run simulation. - G = NeuronGroup(1, eqs, threshold='v>30*mvolt', reset='v = -70*mvolt') + G = NeuronGroup(1, eqs, threshold="v>30*mvolt", reset="v = -70*mvolt") G.v = -65 * mvolt G.u = b * G.v - M = StateMonitor(G, 'v', record=True) + M = StateMonitor(G, "v", record=True) run(300 * ms) # Store results in neo format. @@ -37,11 +37,11 @@ def main(): # pragma: no cover # Plot results. plt.figure() plt.plot(vm.times * 1000, vm * 1000) # Plot mV and ms instead of V and s. - plt.xlabel('Time (ms)') - plt.ylabel('mv') + plt.xlabel("Time (ms)") + plt.ylabel("mv") # Save results. - iom = neo.io.PyNNNumpyIO('spike_extraction_test_data') + iom = neo.io.PyNNNumpyIO("spike_extraction_test_data") block = neo.core.Block() segment = neo.core.Segment() segment.analogsignals.append(vm) @@ -49,7 +49,7 @@ def main(): # pragma: no cover iom.write(block) # Load results. - iom2 = neo.io.PyNNNumpyIO('spike_extraction_test_data.npz') + iom2 = neo.io.PyNNNumpyIO("spike_extraction_test_data.npz") data = iom2.read() vm = data[0].segments[0].analogsignals[0] @@ -57,9 +57,9 @@ def main(): # pragma: no cover # The two figures should match. plt.figure() plt.plot(vm.times * 1000, vm * 1000) # Plot mV and ms instead of V and s. 
- plt.xlabel('Time (ms)') - plt.ylabel('mv') + plt.xlabel("Time (ms)") + plt.ylabel("mv") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/elephant/test/test_asset.py b/elephant/test/test_asset.py index e9309e4c2..3544da490 100644 --- a/elephant/test/test_asset.py +++ b/elephant/test/test_asset.py @@ -35,93 +35,89 @@ try: import pyopencl + HAVE_PYOPENCL = asset.get_opencl_capability() except ImportError: HAVE_PYOPENCL = False try: import pycuda + HAVE_CUDA = asset.get_cuda_capability_major() > 0 except ImportError: HAVE_CUDA = False class AssetBinningTestCase(unittest.TestCase): - def setUp(self): spiketrain_1 = neo.SpikeTrain( - [1.3, 2.1, 3.9999999999, 4.9999], units=pq.ms, t_stop=6*pq.ms) + [1.3, 2.1, 3.9999999999, 4.9999], units=pq.ms, t_stop=6 * pq.ms + ) spiketrain_2 = neo.SpikeTrain( - [0.9999999999, 1.9999, 4, 5], units=pq.ms, t_stop=6*pq.ms) + [0.9999999999, 1.9999, 4, 5], units=pq.ms, t_stop=6 * pq.ms + ) self.spiketrains_i = [spiketrain_1, spiketrain_2] self.spiketrains_j = [spiketrain_2, spiketrain_1] def test_bin_tolerance_default(self): - asset_obj = asset.ASSET(spiketrains_i=self.spiketrains_i, - spiketrains_j=self.spiketrains_j, - bin_size=1*pq.ms) + asset_obj = asset.ASSET( + spiketrains_i=self.spiketrains_i, + spiketrains_j=self.spiketrains_j, + bin_size=1 * pq.ms, + ) bins_i = asset_obj.spiketrains_binned_i.to_array() bins_j = asset_obj.spiketrains_binned_j.to_array() # Should shift spikes closer than 1e-8 to the right bin edge. # This is the current default tolerance for `BinnedSpikeTrain`. - expected_bins_i = np.array( - [[0, 1, 1, 0, 2, 0], - [0, 2, 0, 0, 1, 1]]) - expected_bins_j = np.array( - [[0, 2, 0, 0, 1, 1], - [0, 1, 1, 0, 2, 0]]) + expected_bins_i = np.array([[0, 1, 1, 0, 2, 0], [0, 2, 0, 0, 1, 1]]) + expected_bins_j = np.array([[0, 2, 0, 0, 1, 1], [0, 1, 1, 0, 2, 0]]) self.assertTrue(np.array_equal(bins_i, expected_bins_i)) self.assertTrue(np.array_equal(bins_j, expected_bins_j)) def test_bin_tolerance_none(self): - asset_obj = asset.ASSET(spiketrains_i=self.spiketrains_i, - spiketrains_j=self.spiketrains_j, - bin_size=1*pq.ms, - bin_tolerance=None) + asset_obj = asset.ASSET( + spiketrains_i=self.spiketrains_i, + spiketrains_j=self.spiketrains_j, + bin_size=1 * pq.ms, + bin_tolerance=None, + ) bins_i = asset_obj.spiketrains_binned_i.to_array() bins_j = asset_obj.spiketrains_binned_j.to_array() # Should not shift any spikes. Bin should be the same as the integer # part of the time. - expected_bins_i = np.array( - [[0, 1, 1, 1, 1, 0], - [1, 1, 0, 0, 1, 1]]) - expected_bins_j = np.array( - [[1, 1, 0, 0, 1, 1], - [0, 1, 1, 1, 1, 0]]) + expected_bins_i = np.array([[0, 1, 1, 1, 1, 0], [1, 1, 0, 0, 1, 1]]) + expected_bins_j = np.array([[1, 1, 0, 0, 1, 1], [0, 1, 1, 1, 1, 0]]) self.assertTrue(np.array_equal(bins_i, expected_bins_i)) self.assertTrue(np.array_equal(bins_j, expected_bins_j)) def test_bin_tolerance_float(self): - asset_obj = asset.ASSET(spiketrains_i=self.spiketrains_i, - spiketrains_j=self.spiketrains_j, - bin_size=1*pq.ms, - bin_tolerance=1e-3) + asset_obj = asset.ASSET( + spiketrains_i=self.spiketrains_i, + spiketrains_j=self.spiketrains_j, + bin_size=1 * pq.ms, + bin_tolerance=1e-3, + ) bins_i = asset_obj.spiketrains_binned_i.to_array() bins_j = asset_obj.spiketrains_binned_j.to_array() # Should shift spikes closer than 1e-3 to the right bin edge. 
- expected_bins_i = np.array( - [[0, 1, 1, 0, 1, 1], - [0, 1, 1, 0, 1, 1]]) - expected_bins_j = np.array( - [[0, 1, 1, 0, 1, 1], - [0, 1, 1, 0, 1, 1]]) + expected_bins_i = np.array([[0, 1, 1, 0, 1, 1], [0, 1, 1, 0, 1, 1]]) + expected_bins_j = np.array([[0, 1, 1, 0, 1, 1], [0, 1, 1, 0, 1, 1]]) self.assertTrue(np.array_equal(bins_i, expected_bins_i)) self.assertTrue(np.array_equal(bins_j, expected_bins_j)) -@unittest.skipUnless(HAVE_SKLEARN, 'requires sklearn') +@unittest.skipUnless(HAVE_SKLEARN, "requires sklearn") class AssetTestCase(unittest.TestCase): - def setUp(self): - os.environ['ELEPHANT_USE_OPENCL'] = '0' + os.environ["ELEPHANT_USE_OPENCL"] = "0" def test_stretched_metric_2d_size(self): nr_points = 4 @@ -153,7 +149,7 @@ def test_stretched_metric_2d_symmetric(self): def test_stretched_metric_2d_equals_euclidean_if_stretch_1(self): x = np.arange(10) - y = y = x ** 2 - 2 * x - 4 + y = y = x**2 - 2 * x - 4 # compute stretched distance matrix stretch = 1 D = stretchedmetric2d(x, y, stretch=stretch, ref_angle=45) @@ -164,17 +160,13 @@ def test_stretched_metric_2d_equals_euclidean_if_stretch_1(self): assert_array_almost_equal(D, E, decimal=5) def test_get_sse_start_and_end_time_bins(self): - sse = {(1, 2): set([1, 2, 3]), - (3, 4): set([5, 6]), - (6, 7): set([0, 1])} + sse = {(1, 2): set([1, 2, 3]), (3, 4): set([5, 6]), (6, 7): set([0, 1])} start, end = asset.get_sse_start_and_end_time_bins(sse) self.assertListEqual(start, [1, 2]) self.assertListEqual(end, [6, 7]) def test_get_neurons_in_sse(self): - sse = {(1, 2): set([1, 2, 3]), - (3, 4): set([5, 6]), - (6, 7): set([0, 1])} + sse = {(1, 2): set([1, 2, 3]), (3, 4): set([5, 6]), (6, 7): set([0, 1])} neurons = asset.get_neurons_in_sse(sse) self.assertListEqual(neurons, [0, 1, 2, 3, 5, 6]) @@ -186,17 +178,17 @@ def test_sse_difference(self): diff_ab_linkwise = {(1, 2): set([3]), (3, 4): set([5, 6])} diff_ba_linkwise = {(1, 2): set([5]), (5, 6): set([0, 2])} self.assertEqual( - asset.synchronous_events_difference(a, b, 'pixelwise'), - diff_ab_pixelwise) + asset.synchronous_events_difference(a, b, "pixelwise"), diff_ab_pixelwise + ) self.assertEqual( - asset.synchronous_events_difference(b, a, 'pixelwise'), - diff_ba_pixelwise) + asset.synchronous_events_difference(b, a, "pixelwise"), diff_ba_pixelwise + ) self.assertEqual( - asset.synchronous_events_difference(a, b, 'linkwise'), - diff_ab_linkwise) + asset.synchronous_events_difference(a, b, "linkwise"), diff_ab_linkwise + ) self.assertEqual( - asset.synchronous_events_difference(b, a, 'linkwise'), - diff_ba_linkwise) + asset.synchronous_events_difference(b, a, "linkwise"), diff_ba_linkwise + ) def test_sse_intersection(self): a = {(1, 2): set([1, 2, 3]), (3, 4): set([5, 6]), (6, 7): set([0, 1])} @@ -206,17 +198,19 @@ def test_sse_intersection(self): inters_ab_linkwise = {(1, 2): set([1, 2]), (6, 7): set([0, 1])} inters_ba_linkwise = {(1, 2): set([1, 2]), (6, 7): set([0, 1])} self.assertEqual( - asset.synchronous_events_intersection(a, b, 'pixelwise'), - inters_ab_pixelwise) + asset.synchronous_events_intersection(a, b, "pixelwise"), + inters_ab_pixelwise, + ) self.assertEqual( - asset.synchronous_events_intersection(b, a, 'pixelwise'), - inters_ba_pixelwise) + asset.synchronous_events_intersection(b, a, "pixelwise"), + inters_ba_pixelwise, + ) self.assertEqual( - asset.synchronous_events_intersection(a, b, 'linkwise'), - inters_ab_linkwise) + asset.synchronous_events_intersection(a, b, "linkwise"), inters_ab_linkwise + ) self.assertEqual( - asset.synchronous_events_intersection(b, a, 
'linkwise'), - inters_ba_linkwise) + asset.synchronous_events_intersection(b, a, "linkwise"), inters_ba_linkwise + ) def test_sse_relations(self): a = {(1, 2): set([1, 2, 3]), (3, 4): set([5, 6]), (6, 7): set([0, 1])} @@ -250,53 +244,38 @@ def test_mask_matrix(self): self.assertIsInstance(mask_1_2[0, 0], np.bool_) self.assertRaises(ValueError, asset.ASSET.mask_matrices, [], []) - self.assertRaises(ValueError, asset.ASSET.mask_matrices, - [np.arange(5)], []) + self.assertRaises(ValueError, asset.ASSET.mask_matrices, [np.arange(5)], []) def test_cluster_matrix_entries(self): # test with symmetric matrix - mat = np.array([[0, 0, 1, 0], - [0, 0, 0, 1], - [1, 0, 0, 0], - [0, 1, 0, 0]]) + mat = np.array([[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]) clustered = asset.ASSET.cluster_matrix_entries( - mat, max_distance=1.5, min_neighbors=2, stretch=1) - correct = np.array([[0, 0, 1, 0], - [0, 0, 0, 1], - [2, 0, 0, 0], - [0, 2, 0, 0]]) + mat, max_distance=1.5, min_neighbors=2, stretch=1 + ) + correct = np.array([[0, 0, 1, 0], [0, 0, 0, 1], [2, 0, 0, 0], [0, 2, 0, 0]]) assert_array_equal(clustered, correct) # test with non-symmetric matrix - mat = np.array([[0, 1, 0, 0], - [0, 0, 1, 0], - [1, 0, 0, 1], - [0, 1, 0, 0]]) + mat = np.array([[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 1], [0, 1, 0, 0]]) clustered = asset.ASSET.cluster_matrix_entries( - mat, max_distance=1.5, min_neighbors=3, stretch=1) - correct = np.array([[0, 1, 0, 0], - [0, 0, 1, 0], - [-1, 0, 0, 1], - [0, -1, 0, 0]]) + mat, max_distance=1.5, min_neighbors=3, stretch=1 + ) + correct = np.array([[0, 1, 0, 0], [0, 0, 1, 0], [-1, 0, 0, 1], [0, -1, 0, 0]]) assert_array_equal(clustered, correct) # test with lowered min_neighbors - mat = np.array([[0, 1, 0, 0], - [0, 0, 1, 0], - [1, 0, 0, 1], - [0, 1, 0, 0]]) + mat = np.array([[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 1], [0, 1, 0, 0]]) clustered = asset.ASSET.cluster_matrix_entries( - mat, max_distance=1.5, min_neighbors=2, stretch=1) - correct = np.array([[0, 1, 0, 0], - [0, 0, 1, 0], - [2, 0, 0, 1], - [0, 2, 0, 0]]) + mat, max_distance=1.5, min_neighbors=2, stretch=1 + ) + correct = np.array([[0, 1, 0, 0], [0, 0, 1, 0], [2, 0, 0, 1], [0, 2, 0, 0]]) assert_array_equal(clustered, correct) mat = np.zeros((4, 4)) clustered = asset.ASSET.cluster_matrix_entries( - mat, max_distance=1.5, min_neighbors=2, stretch=1) + mat, max_distance=1.5, min_neighbors=2, stretch=1 + ) correct = mat assert_array_equal(clustered, correct) @@ -307,12 +286,19 @@ def test_cluster_matrix_entries_chunked(self): min_neighbors = 4 stretch = 2 cmat_true = asset.ASSET.cluster_matrix_entries( - mmat, max_distance=max_distance, min_neighbors=min_neighbors, - stretch=stretch) + mmat, + max_distance=max_distance, + min_neighbors=min_neighbors, + stretch=stretch, + ) for working_memory in [1, 10, 100, 1000]: cmat = asset.ASSET.cluster_matrix_entries( - mmat, max_distance=max_distance, min_neighbors=min_neighbors, - stretch=stretch, working_memory=working_memory) + mmat, + max_distance=max_distance, + min_neighbors=min_neighbors, + stretch=stretch, + working_memory=working_memory, + ) assert_array_equal(cmat, cmat_true) def test_cluster_matrix_entries_chunked_array_file(self): @@ -322,16 +308,22 @@ def test_cluster_matrix_entries_chunked_array_file(self): min_neighbors = 2 stretch = 2 cmat_true = asset.ASSET.cluster_matrix_entries( - mmat, max_distance=max_distance, min_neighbors=min_neighbors, - stretch=stretch) + mmat, + max_distance=max_distance, + min_neighbors=min_neighbors, + stretch=stretch, + ) for working_memory in 
[1, 10, 100, 1000]: with tempfile.TemporaryDirectory() as tmpdir: cmat = asset.ASSET.cluster_matrix_entries( - mmat, max_distance=max_distance, - min_neighbors=min_neighbors, stretch=stretch, + mmat, + max_distance=max_distance, + min_neighbors=min_neighbors, + stretch=stretch, working_memory=working_memory, - array_file=Path(tmpdir) / f"test_dist_{working_memory}") + array_file=Path(tmpdir) / f"test_dist_{working_memory}", + ) assert_array_equal(cmat, cmat_true) def test_pmat_neighbors_gpu(self): @@ -346,9 +338,10 @@ def test_pmat_neighbors_gpu(self): filter_shape = (filter_size, 3) with warnings.catch_warnings(): # ignore even filter sizes - warnings.simplefilter('ignore', UserWarning) + warnings.simplefilter("ignore", UserWarning) pmat_neigh = asset._PMatNeighbors( - filter_shape=filter_shape, n_largest=n_largest) + filter_shape=filter_shape, n_largest=n_largest + ) lmat_true = pmat_neigh.cpu(pmat) if HAVE_PYOPENCL: lmat_opencl = pmat_neigh.pyopencl(pmat) @@ -366,8 +359,9 @@ def test_pmat_neighbors_gpu_chunked(self): pmat2 = np.random.random_sample((70, 27)).astype(np.float32) pmat3 = np.random.random_sample((41, 80)).astype(np.float32) for pmat in (pmat1, pmat2, pmat3): - pmat_neigh = asset._PMatNeighbors(filter_shape=filter_shape, - n_largest=n_largest) + pmat_neigh = asset._PMatNeighbors( + filter_shape=filter_shape, n_largest=n_largest + ) lmat_true = pmat_neigh.cpu(pmat) for max_chunk_size in (17, 20, 29): pmat_neigh.max_chunk_size = max_chunk_size @@ -384,8 +378,9 @@ def test_pmat_neighbors_gpu_overlapped_chunks(self): # Two last chunks overlap. np.random.seed(12) pmat = np.random.random_sample((50, 50)).astype(np.float32) - pmat_neigh = asset._PMatNeighbors(filter_shape=(11, 5), n_largest=3, - max_chunk_size=12) + pmat_neigh = asset._PMatNeighbors( + filter_shape=(11, 5), n_largest=3, max_chunk_size=12 + ) lmat_true = pmat_neigh.cpu(pmat) if HAVE_PYOPENCL: lmat_opencl = pmat_neigh.pyopencl(pmat) @@ -396,8 +391,9 @@ def test_pmat_neighbors_gpu_overlapped_chunks(self): def test_pmat_neighbors_deprecation_warning(self): with self.assertWarns(DeprecationWarning): - asset._PMatNeighbors(filter_shape=(11, 5), n_largest=3, - max_chunk_size=12, verbose=True) + asset._PMatNeighbors( + filter_shape=(11, 5), n_largest=3, max_chunk_size=12, verbose=True + ) def test_pmat_neighbors_invalid_input(self): np.random.seed(12) @@ -410,8 +406,9 @@ def test_pmat_neighbors_invalid_input(self): pmat_neigh = asset._PMatNeighbors(filter_shape=(21, 3), n_largest=3) np.fill_diagonal(pmat, 0.0) self.assertRaises(ValueError, pmat_neigh.compute, pmat) - pmat_neigh = asset._PMatNeighbors(filter_shape=(11, 3), n_largest=3, - max_chunk_size=10) + pmat_neigh = asset._PMatNeighbors( + filter_shape=(11, 3), n_largest=3, max_chunk_size=10 + ) if HAVE_PYOPENCL: # max_chunk_size > filter_shape self.assertRaises(ValueError, pmat_neigh.pyopencl, pmat) @@ -420,102 +417,137 @@ def test_pmat_neighbors_invalid_input(self): self.assertRaises(ValueError, pmat_neigh.pycuda, pmat) # Too small filter_shape - self.assertRaises(ValueError, asset._PMatNeighbors, - filter_shape=(11, 3), n_largest=100) + self.assertRaises( + ValueError, asset._PMatNeighbors, filter_shape=(11, 3), n_largest=100 + ) # w >= l - self.assertRaises(ValueError, asset._PMatNeighbors, - filter_shape=(9, 9), n_largest=3) + self.assertRaises( + ValueError, asset._PMatNeighbors, filter_shape=(9, 9), n_largest=3 + ) # not centered - self.assertWarns(UserWarning, asset._PMatNeighbors, - filter_shape=(10, 6), n_largest=3) + self.assertWarns( + UserWarning, 
asset._PMatNeighbors, filter_shape=(10, 6), n_largest=3 + ) def test_intersection_matrix(self): st1 = neo.SpikeTrain([1, 2, 4] * pq.ms, t_stop=6 * pq.ms) st2 = neo.SpikeTrain([1, 3, 4] * pq.ms, t_stop=6 * pq.ms) - st3 = neo.SpikeTrain([2, 5] * pq.ms, t_start=1 * pq.ms, - t_stop=6 * pq.ms) + st3 = neo.SpikeTrain([2, 5] * pq.ms, t_start=1 * pq.ms, t_stop=6 * pq.ms) bin_size = 1 * pq.ms asset_obj_same_t_start_stop = asset.ASSET( - [st1, st2], bin_size=bin_size, t_stop_i=5 * pq.ms, - t_stop_j=5 * pq.ms) + [st1, st2], bin_size=bin_size, t_stop_i=5 * pq.ms, t_stop_j=5 * pq.ms + ) # Check that the routine works for correct input... # ...same t_start, t_stop on both time axes imat_1_2 = asset_obj_same_t_start_stop.intersection_matrix() - trueimat_1_2 = np.array([[0., 0., 0., 0., 0.], - [0., 2., 1., 1., 2.], - [0., 1., 1., 0., 1.], - [0., 1., 0., 1., 1.], - [0., 2., 1., 1., 2.]]) - assert_array_equal(asset_obj_same_t_start_stop.x_edges, - np.arange(6) * pq.ms) # correct bins - assert_array_equal(asset_obj_same_t_start_stop.y_edges, - np.arange(6) * pq.ms) # correct bins + trueimat_1_2 = np.array( + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 1.0, 1.0, 2.0], + [0.0, 1.0, 1.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 1.0, 1.0], + [0.0, 2.0, 1.0, 1.0, 2.0], + ] + ) + assert_array_equal( + asset_obj_same_t_start_stop.x_edges, np.arange(6) * pq.ms + ) # correct bins + assert_array_equal( + asset_obj_same_t_start_stop.y_edges, np.arange(6) * pq.ms + ) # correct bins assert_array_equal(imat_1_2, trueimat_1_2) # correct matrix # ...different t_start, t_stop on the two time axes asset_obj_different_t_start_stop = asset.ASSET( - [st1, st2], spiketrains_j=[st + 6 * pq.ms for st in [st1, st2]], - bin_size=bin_size, t_start_j=6 * pq.ms, t_stop_i=5 * pq.ms, - t_stop_j=11 * pq.ms) + [st1, st2], + spiketrains_j=[st + 6 * pq.ms for st in [st1, st2]], + bin_size=bin_size, + t_start_j=6 * pq.ms, + t_stop_i=5 * pq.ms, + t_stop_j=11 * pq.ms, + ) imat_1_2 = asset_obj_different_t_start_stop.intersection_matrix() - assert_array_equal(asset_obj_different_t_start_stop.x_edges, - np.arange(6) * pq.ms) # correct bins - assert_array_equal(asset_obj_different_t_start_stop.y_edges, - np.arange(6, 12) * pq.ms) + assert_array_equal( + asset_obj_different_t_start_stop.x_edges, np.arange(6) * pq.ms + ) # correct bins + assert_array_equal( + asset_obj_different_t_start_stop.y_edges, np.arange(6, 12) * pq.ms + ) self.assertTrue(np.all(imat_1_2 == trueimat_1_2)) # correct matrix # test with norm=1 imat_1_2 = asset_obj_same_t_start_stop.intersection_matrix( - normalization='intersection') - trueimat_1_2 = np.array([[0., 0., 0., 0., 0.], - [0., 1., 1., 1., 1.], - [0., 1., 1., 0., 1.], - [0., 1., 0., 1., 1.], - [0., 1., 1., 1., 1.]]) + normalization="intersection" + ) + trueimat_1_2 = np.array( + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 1.0], + ] + ) assert_array_equal(imat_1_2, trueimat_1_2) # test with norm=2 - imat_1_2 = asset_obj_same_t_start_stop.intersection_matrix( - normalization='mean') - sq = np.sqrt(2) / 2. 
- trueimat_1_2 = np.array([[0., 0., 0., 0., 0.], - [0., 1., sq, sq, 1.], - [0., sq, 1., 0., sq], - [0., sq, 0., 1., sq], - [0., 1., sq, sq, 1.]]) + imat_1_2 = asset_obj_same_t_start_stop.intersection_matrix(normalization="mean") + sq = np.sqrt(2) / 2.0 + trueimat_1_2 = np.array( + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, sq, sq, 1.0], + [0.0, sq, 1.0, 0.0, sq], + [0.0, sq, 0.0, 1.0, sq], + [0.0, 1.0, sq, sq, 1.0], + ] + ) assert_array_almost_equal(imat_1_2, trueimat_1_2) # test with norm=3 imat_1_2 = asset_obj_same_t_start_stop.intersection_matrix( - normalization='union') - trueimat_1_2 = np.array([[0., 0., 0., 0., 0.], - [0., 1., .5, .5, 1.], - [0., .5, 1., 0., .5], - [0., .5, 0., 1., .5], - [0., 1., .5, .5, 1.]]) + normalization="union" + ) + trueimat_1_2 = np.array( + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.5, 0.5, 1.0], + [0.0, 0.5, 1.0, 0.0, 0.5], + [0.0, 0.5, 0.0, 1.0, 0.5], + [0.0, 1.0, 0.5, 0.5, 1.0], + ] + ) assert_array_almost_equal(imat_1_2, trueimat_1_2) # Check that errors are raised correctly... # ...for partially overlapping time intervals - self.assertRaises(ValueError, asset.ASSET, - spiketrains_i=[st1, st2], bin_size=bin_size, - t_start_j=1 * pq.ms) + self.assertRaises( + ValueError, + asset.ASSET, + spiketrains_i=[st1, st2], + bin_size=bin_size, + t_start_j=1 * pq.ms, + ) # ...for different SpikeTrain's t_starts - self.assertRaises(ValueError, asset.ASSET, - spiketrains_i=[st1, st3], bin_size=bin_size) + self.assertRaises( + ValueError, asset.ASSET, spiketrains_i=[st1, st3], bin_size=bin_size + ) # ...for different SpikeTrain's t_stops - self.assertRaises(ValueError, asset.ASSET, - spiketrains_i=[st1, st2], bin_size=bin_size, - t_stop_j=5 * pq.ms) + self.assertRaises( + ValueError, + asset.ASSET, + spiketrains_i=[st1, st2], + bin_size=bin_size, + t_stop_j=5 * pq.ms, + ) # regression test Issue #481 # see: https://github.com/NeuralEnsemble/elephant/issues/481 def test_asset_choose_backend_opencl(self): class TestClassBackend(asset._GPUBackend): - def __init__(self): super().__init__() self.backend = self._choose_backend() @@ -531,31 +563,35 @@ def pyopencl(self): # check which backend is chosen if environment variable for opencl # is not set - os.environ.pop('ELEPHANT_USE_OPENCL', None) + os.environ.pop("ELEPHANT_USE_OPENCL", None) # create object of TestClass backend_obj = TestClassBackend() if HAVE_PYOPENCL: - self.assertEqual(backend_obj.backend(), 'opencl') + self.assertEqual(backend_obj.backend(), "opencl") else: # if environment variable is not set and no module pyopencl or # device is found: choose cpu backend - self.assertEqual(backend_obj.backend(), 'cpu') + self.assertEqual(backend_obj.backend(), "cpu") - os.environ['ELEPHANT_USE_OPENCL'] = '0' + os.environ["ELEPHANT_USE_OPENCL"] = "0" def test_asset_deprecation_warning(self): st1 = neo.SpikeTrain([1, 2, 4] * pq.ms, t_stop=6 * pq.ms) st2 = neo.SpikeTrain([1, 3, 4] * pq.ms, t_stop=6 * pq.ms) bin_size = 1 * pq.ms with self.assertWarns(DeprecationWarning): - asset.ASSET([st1, st2], bin_size=bin_size, t_stop_i=5 * pq.ms, - t_stop_j=5 * pq.ms, verbose=True) + asset.ASSET( + [st1, st2], + bin_size=bin_size, + t_stop_i=5 * pq.ms, + t_stop_j=5 * pq.ms, + verbose=True, + ) -@unittest.skipUnless(HAVE_SKLEARN, 'requires sklearn') +@unittest.skipUnless(HAVE_SKLEARN, "requires sklearn") class TestJSFUniformOrderStat3D(unittest.TestCase): - def test_combinations_with_replacement(self): # Test that _combinations_with_replacement yields the same tuples # as in the original implementation with itertools.product(*lists) 
@@ -573,16 +609,14 @@ def _wrong_order(a): for d in range(1, min(6, n + 1)): jsf = asset._JSFUniformOrderStat3D(n=n, d=d) lists = [range(j, n + 1) for j in range(d, 0, -1)] - matrix_entries = list( - jsf._combinations_with_replacement() - ) + matrix_entries = list(jsf._combinations_with_replacement()) matrix_entries_correct = [ - indices for indices in itertools.product(*lists) + indices + for indices in itertools.product(*lists) if not _wrong_order(indices) ] self.assertEqual(matrix_entries, matrix_entries_correct) - self.assertEqual(jsf.num_iterations, - len(matrix_entries_correct)) + self.assertEqual(jsf.num_iterations, len(matrix_entries_correct)) def test_next_sequence_sorted(self): # This test shows the main idea of CUDA ASSET parallelization that @@ -607,7 +641,8 @@ def next_sequence_sorted(jsf, iteration): for d in range(1, min(6, n + 1)): jsf = asset._JSFUniformOrderStat3D(n=n, d=d) for iter_id, seq_sorted_true in enumerate( - jsf._combinations_with_replacement()): + jsf._combinations_with_replacement() + ): seq_sorted = next_sequence_sorted(jsf, iteration=iter_id) self.assertEqual(seq_sorted, seq_sorted_true) @@ -628,11 +663,11 @@ def test_JSFUniformOrderStat3D_deprecation_warning(self): def test_point_mass_output(self): # When N >> D, the expected output is [1, 0] L, N, D = 2, 50, 2 - jsf = asset._JSFUniformOrderStat3D(n=N, d=D, precision='double') + jsf = asset._JSFUniformOrderStat3D(n=N, d=D, precision="double") u = np.arange(L * D, dtype=np.float32).reshape((-1, D)) u /= np.max(u) p_out = jsf.compute(u) - assert_array_almost_equal(p_out, [1., 0.], decimal=5) + assert_array_almost_equal(p_out, [1.0, 0.0], decimal=5) def test_precision(self): L = 2 @@ -640,29 +675,26 @@ def test_precision(self): for d in range(1, min(6, n + 1)): u = np.arange(L * d, dtype=np.float32).reshape((-1, d)) u /= np.max(u) - jsf_double = asset._JSFUniformOrderStat3D(n=n, d=d, - precision='double') - jsf_float = asset._JSFUniformOrderStat3D(n=n, d=d, - precision='float') + jsf_double = asset._JSFUniformOrderStat3D(n=n, d=d, precision="double") + jsf_float = asset._JSFUniformOrderStat3D(n=n, d=d, precision="float") P_total_double = jsf_double.compute(u) P_total_float = jsf_float.compute(u) # decimal 5 is used because the number of iterations is small # in practice, the results deviate starting at decimal 3 or 2 - assert_array_almost_equal(P_total_double, P_total_float, - decimal=5) + assert_array_almost_equal(P_total_double, P_total_float, decimal=5) def test_gpu(self): np.random.seed(12) - for precision, L, n in itertools.product(['float', 'double'], - [1, 23, 100], [1, 3, 10]): + for precision, L, n in itertools.product( + ["float", "double"], [1, 23, 100], [1, 3, 10] + ): for d in range(1, min(4, n + 1)): u = np.random.rand(L, d).cumsum(axis=1) u /= np.max(u) - jsf = asset._JSFUniformOrderStat3D(n=n, d=d, - precision=precision) + jsf = asset._JSFUniformOrderStat3D(n=n, d=d, precision=precision) du = np.diff(u, prepend=0, append=1, axis=1) with warnings.catch_warnings(): - warnings.simplefilter('ignore', RuntimeWarning) + warnings.simplefilter("ignore", RuntimeWarning) log_du = np.log(du, dtype=np.float32) P_total_cpu = jsf.cpu(log_du) for max_chunk_size in [None, 22]: @@ -683,7 +715,7 @@ def test_gpu_threads_and_cwr_loops(self): u /= np.max(u) du = np.diff(u, prepend=0, append=1, axis=1) with warnings.catch_warnings(): - warnings.simplefilter('ignore', RuntimeWarning) + warnings.simplefilter("ignore", RuntimeWarning) log_du = np.log(du, dtype=np.float32) def run_test(jsf, jsf_backend): @@ -696,10 
+728,10 @@ def run_test(jsf, jsf_backend): assert_array_almost_equal(P_total, P_expected) if HAVE_PYOPENCL: - jsf = asset._JSFUniformOrderStat3D(n=N, d=D, precision='float') + jsf = asset._JSFUniformOrderStat3D(n=N, d=D, precision="float") run_test(jsf, jsf.pyopencl) if HAVE_CUDA: - jsf = asset._JSFUniformOrderStat3D(n=N, d=D, precision='float') + jsf = asset._JSFUniformOrderStat3D(n=N, d=D, precision="float") run_test(jsf, jsf.pycuda) def test_gpu_chunked(self): @@ -708,7 +740,7 @@ def test_gpu_chunked(self): u /= np.max(u) du = np.diff(u, prepend=0, append=1, axis=1) with warnings.catch_warnings(): - warnings.simplefilter('ignore', RuntimeWarning) + warnings.simplefilter("ignore", RuntimeWarning) log_du = np.log(du, dtype=np.float32) jsf = asset._JSFUniformOrderStat3D(n=N, d=D) P_true = jsf.cpu(log_du) @@ -730,7 +762,7 @@ def test_watchdog(self): self.assertWarns(UserWarning, jsf.compute, u) -@unittest.skipUnless(HAVE_SKLEARN, 'requires sklearn') +@unittest.skipUnless(HAVE_SKLEARN, "requires sklearn") class AssetTestIntegration(unittest.TestCase): def setUp(self): # common for all tests @@ -749,27 +781,26 @@ def test_probability_matrix_symmetric(self): spiketrains_copy.append(st.copy()) asset_obj = asset.ASSET(spiketrains, bin_size=self.bin_size) - asset_obj_symmetric = asset.ASSET(spiketrains, - spiketrains_j=spiketrains_copy, - bin_size=self.bin_size) + asset_obj_symmetric = asset.ASSET( + spiketrains, spiketrains_j=spiketrains_copy, bin_size=self.bin_size + ) imat = asset_obj.intersection_matrix() - pmat = asset_obj.probability_matrix_analytical( - kernel_width=kernel_width) + pmat = asset_obj.probability_matrix_analytical(kernel_width=kernel_width) imat_symm = asset_obj_symmetric.intersection_matrix() pmat_symm = asset_obj_symmetric.probability_matrix_analytical( - kernel_width=kernel_width) + kernel_width=kernel_width + ) assert_array_almost_equal(pmat, pmat_symm) assert_array_almost_equal(imat, imat_symm) - assert_array_almost_equal(asset_obj.x_edges, - asset_obj_symmetric.x_edges) - assert_array_almost_equal(asset_obj.y_edges, - asset_obj_symmetric.y_edges) + assert_array_almost_equal(asset_obj.x_edges, asset_obj_symmetric.x_edges) + assert_array_almost_equal(asset_obj.y_edges, asset_obj_symmetric.y_edges) - def _test_integration_subtest(self, spiketrains, spiketrains_y, - indices_pmat, index_proba, expected_sses): + def _test_integration_subtest( + self, spiketrains, spiketrains_y, indices_pmat, index_proba, expected_sses + ): # define parameters random.seed(1) kernel_width = 9 * pq.ms @@ -783,31 +814,30 @@ def _test_integration_subtest(self, spiketrains, spiketrains_y, n_surr = 20 def _get_rates(_spiketrains): - kernel_sigma = kernel_width / 2. / np.sqrt(3.) 
+ kernel_sigma = kernel_width / 2.0 / np.sqrt(3.0) kernel = kernels.RectangularKernel(sigma=kernel_sigma) - rates = [statistics.instantaneous_rate( - st, - kernel=kernel, - sampling_period=1 * pq.ms) - for st in _spiketrains] + rates = [ + statistics.instantaneous_rate( + st, kernel=kernel, sampling_period=1 * pq.ms + ) + for st in _spiketrains + ] return rates - asset_obj = asset.ASSET(spiketrains, spiketrains_y, - bin_size=self.bin_size) + asset_obj = asset.ASSET(spiketrains, spiketrains_y, bin_size=self.bin_size) # calculate the intersection matrix imat = asset_obj.intersection_matrix() # calculate probability matrix analytical - pmat = asset_obj.probability_matrix_analytical( - imat, - kernel_width=kernel_width) + pmat = asset_obj.probability_matrix_analytical(imat, kernel_width=kernel_width) # check if pmat is the same when rates are provided pmat_as_rates = asset_obj.probability_matrix_analytical( imat, firing_rates_x=_get_rates(spiketrains), - firing_rates_y=_get_rates(spiketrains_y)) + firing_rates_y=_get_rates(spiketrains_y), + ) assert_array_almost_equal(pmat, pmat_as_rates) # calculate probability matrix montecarlo @@ -815,24 +845,23 @@ def _get_rates(_spiketrains): n_surrogates=n_surr, imat=imat, surrogate_dt=surrogate_dt, - surrogate_method='dither_spikes') + surrogate_method="dither_spikes", + ) # test probability matrices assert_array_equal(np.where(pmat > alpha), indices_pmat) - assert_array_equal(np.where(pmat_montecarlo > alpha), - indices_pmat) + assert_array_equal(np.where(pmat_montecarlo > alpha), indices_pmat) # calculate joint probability matrix jmat = asset_obj.joint_probability_matrix( - pmat, - filter_shape=filter_shape, - n_largest=nr_largest) + pmat, filter_shape=filter_shape, n_largest=nr_largest + ) # test joint probability matrix - assert_array_equal(np.where(jmat > 0.98), index_proba['high']) - assert_array_equal(np.where(jmat > 0.9), index_proba['medium']) - assert_array_equal(np.where(jmat > 0.8), index_proba['low']) + assert_array_equal(np.where(jmat > 0.98), index_proba["high"]) + assert_array_equal(np.where(jmat > 0.9), index_proba["medium"]) + assert_array_equal(np.where(jmat > 0.8), index_proba["low"]) # test if all other entries are zeros mask_zeros = np.ones(jmat.shape, bool) - mask_zeros[index_proba['low']] = False + mask_zeros[index_proba["low"]] = False self.assertTrue(np.all(jmat[mask_zeros] == 0)) # calculate mask matrix and cluster matrix @@ -841,7 +870,8 @@ def _get_rates(_spiketrains): mmat, max_distance=max_distance, min_neighbors=min_neighbors, - stretch=stretch) + stretch=stretch, + ) # extract sses and test them sses = asset_obj.extract_synchronous_events(cmat) @@ -864,40 +894,52 @@ def test_integration(self): # ground truth for pmats starting_bin_1 = int((delay / self.bin_size).magnitude.item()) starting_bin_2 = int( - (2 * delay / self.bin_size + time_between_sses / self.bin_size - ).magnitude.item()) + ( + 2 * delay / self.bin_size + time_between_sses / self.bin_size + ).magnitude.item() + ) indices_pmat_1 = np.arange(starting_bin_1, starting_bin_1 + size_sse) - indices_pmat_2 = np.arange(starting_bin_2, - starting_bin_2 + size_sse) - indices_pmat = (np.concatenate((indices_pmat_1, indices_pmat_2)), - np.concatenate((indices_pmat_2, indices_pmat_1))) + indices_pmat_2 = np.arange(starting_bin_2, starting_bin_2 + size_sse) + indices_pmat = ( + np.concatenate((indices_pmat_1, indices_pmat_2)), + np.concatenate((indices_pmat_2, indices_pmat_1)), + ) # generate spike trains - spiketrains = [neo.SpikeTrain([index_spiketrain, - 
index_spiketrain + - size_sse + - bins_between_sses] * self.bin_size - + delay + 1 * pq.ms, - t_stop=T) - for index_group in range(size_group) - for index_spiketrain in range(size_sse)] + spiketrains = [ + neo.SpikeTrain( + [index_spiketrain, index_spiketrain + size_sse + bins_between_sses] + * self.bin_size + + delay + + 1 * pq.ms, + t_stop=T, + ) + for index_group in range(size_group) + for index_spiketrain in range(size_sse) + ] index_proba = { - "high": (np.array([9, 9, 10, 10, 10, 11, 11]), - np.array([3, 4, 3, 4, 5, 4, 5])), - "medium": (np.array([8, 8, 9, 9, 9, 10, 10, - 10, 11, 11, 11, 12, 12]), - np.array([2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6])), - "low": (np.array([7, 8, 8, 9, 9, 9, 10, 10, 10, - 11, 11, 11, 12, 12, 12, 13, 13]), - np.array([2, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, - 6, 5, 6, 7, 6, 7])) + "high": ( + np.array([9, 9, 10, 10, 10, 11, 11]), + np.array([3, 4, 3, 4, 5, 4, 5]), + ), + "medium": ( + np.array([8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12]), + np.array([2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6]), + ), + "low": ( + np.array( + [7, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13] + ), + np.array([2, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 6, 7]), + ), } - expected_sses = {1: {(9, 3): {0, 3, 6}, (10, 4): {1, 4, 7}, - (11, 5): {2, 5, 8}}} - self._test_integration_subtest(spiketrains, - spiketrains_y=spiketrains, - indices_pmat=indices_pmat, - index_proba=index_proba, - expected_sses=expected_sses) + expected_sses = {1: {(9, 3): {0, 3, 6}, (10, 4): {1, 4, 7}, (11, 5): {2, 5, 8}}} + self._test_integration_subtest( + spiketrains, + spiketrains_y=spiketrains, + indices_pmat=indices_pmat, + index_proba=index_proba, + expected_sses=expected_sses, + ) def test_integration_nonsymmetric(self): # define parameters @@ -914,35 +956,45 @@ def test_integration_nonsymmetric(self): indices_pmat = (indices_pmat_1, indices_pmat_1) # generate spike trains spiketrains = [ - neo.SpikeTrain([index_spiketrain] * self.bin_size + delay, - t_start=0 * pq.ms, - t_stop=2 * delay + size_sse * self.bin_size) + neo.SpikeTrain( + [index_spiketrain] * self.bin_size + delay, + t_start=0 * pq.ms, + t_stop=2 * delay + size_sse * self.bin_size, + ) for index_group in range(size_group) - for index_spiketrain in range(size_sse)] + for index_spiketrain in range(size_sse) + ] spiketrains_y = [ - neo.SpikeTrain([index_spiketrain] * self.bin_size + delay + - time_between_sses + size_sse * self.bin_size, - t_start=size_sse * self.bin_size + 2 * delay, - t_stop=T) + neo.SpikeTrain( + [index_spiketrain] * self.bin_size + + delay + + time_between_sses + + size_sse * self.bin_size, + t_start=size_sse * self.bin_size + 2 * delay, + t_stop=T, + ) for index_group in range(size_group) - for index_spiketrain in range(size_sse)] + for index_spiketrain in range(size_sse) + ] index_proba = { - "high": ([6, 6, 7, 7, 7, 8, 8], - [6, 7, 6, 7, 8, 7, 8]), - "medium": ([5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9], - [5, 6, 5, 6, 7, 6, 7, 8, 7, 8, 9, 8, 9]), - "low": ([4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, - 8, 8, 8, 9, 9, 9, 10, 10], - [4, 5, 4, 5, 6, 5, 6, 7, 6, 7, 8, - 7, 8, 9, 8, 9, 10, 9, 10]) + "high": ([6, 6, 7, 7, 7, 8, 8], [6, 7, 6, 7, 8, 7, 8]), + "medium": ( + [5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9], + [5, 6, 5, 6, 7, 6, 7, 8, 7, 8, 9, 8, 9], + ), + "low": ( + [4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10], + [4, 5, 4, 5, 6, 5, 6, 7, 6, 7, 8, 7, 8, 9, 8, 9, 10, 9, 10], + ), } - expected_sses = {1: {(6, 6): {0, 3, 6}, (7, 7): {1, 4, 7}, - (8, 8): {2, 5, 8}}} - self._test_integration_subtest(spiketrains, - 
spiketrains_y=spiketrains_y, - indices_pmat=indices_pmat, - index_proba=index_proba, - expected_sses=expected_sses) + expected_sses = {1: {(6, 6): {0, 3, 6}, (7, 7): {1, 4, 7}, (8, 8): {2, 5, 8}}} + self._test_integration_subtest( + spiketrains, + spiketrains_y=spiketrains_y, + indices_pmat=indices_pmat, + index_proba=index_proba, + expected_sses=expected_sses, + ) if __name__ == "__main__": diff --git a/elephant/test/test_causality.py b/elephant/test/test_causality.py index 552fe2f05..5c15231e4 100644 --- a/elephant/test/test_causality.py +++ b/elephant/test/test_causality.py @@ -5,6 +5,7 @@ :copyright: Copyright 2014-2024 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. """ + from __future__ import division, print_function import unittest @@ -20,7 +21,6 @@ class PairwiseGrangerTestCase(unittest.TestCase): - @classmethod def setUpClass(cls): np.random.seed(1) @@ -36,15 +36,13 @@ def _generate_ground_truth(length_2d=30000): weights = np.stack((weights_1, weights_2)) - noise_covariance = np.array([[1., 0.0], [0.0, 1.]]) + noise_covariance = np.array([[1.0, 0.0], [0.0, 1.0]]) for i in range(length_2d): for lag in range(order): - signal[:, i + order] += np.dot(weights[lag], - signal[:, i + 1 - lag]) - rnd_var = np.random.multivariate_normal([0, 0], - noise_covariance) - signal[:, i+order] += rnd_var + signal[:, i + order] += np.dot(weights[lag], signal[:, i + 1 - lag]) + rnd_var = np.random.multivariate_normal([0, 0], noise_covariance) + signal[:, i + order] += rnd_var signal = signal[:, 2:] @@ -60,40 +58,53 @@ def setUp(self): # Estimate Granger causality self.causality = elephant.causality.granger.pairwise_granger( - self.signal, max_order=10, - information_criterion='bic') + self.signal, max_order=10, information_criterion="bic" + ) def test_analog_signal_input(self): """ Check if analog signal input result matches an otherwise identical 2D numpy array input result. 
""" - analog_signal = AnalogSignal(self.signal, units='V', - sampling_rate=1*pq.Hz) - analog_signal_causality = \ - elephant.causality.granger.pairwise_granger( - analog_signal, max_order=10, - information_criterion='bic') - self.assertEqual(analog_signal_causality.directional_causality_x_y, - self.causality.directional_causality_x_y) - self.assertEqual(analog_signal_causality.directional_causality_y_x, - self.causality.directional_causality_y_x) - self.assertEqual(analog_signal_causality.instantaneous_causality, - self.causality.instantaneous_causality) - self.assertEqual(analog_signal_causality.total_interdependence, - self.causality.total_interdependence) + analog_signal = AnalogSignal(self.signal, units="V", sampling_rate=1 * pq.Hz) + analog_signal_causality = elephant.causality.granger.pairwise_granger( + analog_signal, max_order=10, information_criterion="bic" + ) + self.assertEqual( + analog_signal_causality.directional_causality_x_y, + self.causality.directional_causality_x_y, + ) + self.assertEqual( + analog_signal_causality.directional_causality_y_x, + self.causality.directional_causality_y_x, + ) + self.assertEqual( + analog_signal_causality.instantaneous_causality, + self.causality.instantaneous_causality, + ) + self.assertEqual( + analog_signal_causality.total_interdependence, + self.causality.total_interdependence, + ) def test_aic(self): identity_matrix = np.eye(2, 2) - self.assertEqual(elephant.causality.granger._aic( - identity_matrix, order=2, dimension=2, length=2 - ), 8.0) + self.assertEqual( + elephant.causality.granger._aic( + identity_matrix, order=2, dimension=2, length=2 + ), + 8.0, + ) def test_bic(self): identity_matrix = np.eye(2, 2) - assert_array_almost_equal(elephant.causality.granger._bic( - identity_matrix, order=2, dimension=2, length=2 - ), 5.54517744, decimal=8) + assert_array_almost_equal( + elephant.causality.granger._bic( + identity_matrix, order=2, dimension=2, length=2 + ), + 5.54517744, + decimal=8, + ) def test_lag_covariances_error(self): """ @@ -101,28 +112,42 @@ def test_lag_covariances_error(self): ValueError is raised. 
""" short_signals = np.array([[1, 2], [3, 4]]) - self.assertRaises(ValueError, - elephant.causality.granger._lag_covariances, - short_signals, dimension=2, max_lag=3) + self.assertRaises( + ValueError, + elephant.causality.granger._lag_covariances, + short_signals, + dimension=2, + max_lag=3, + ) def test_pairwise_granger_error_null_signals(self): null_signals = np.array([[0, 0], [0, 0]]) - self.assertRaises(ValueError, - elephant.causality.granger.pairwise_granger, - null_signals, max_order=2) + self.assertRaises( + ValueError, + elephant.causality.granger.pairwise_granger, + null_signals, + max_order=2, + ) def test_pairwise_granger_identical_signal(self): - same_signal = np.hstack([self.signal[:, 0, np.newaxis], - self.signal[:, 0, np.newaxis]]) - self.assertRaises(ValueError, - elephant.causality.granger.pairwise_granger, - signals=same_signal, max_order=2) + same_signal = np.hstack( + [self.signal[:, 0, np.newaxis], self.signal[:, 0, np.newaxis]] + ) + self.assertRaises( + ValueError, + elephant.causality.granger.pairwise_granger, + signals=same_signal, + max_order=2, + ) def test_pairwise_granger_error_1d_array(self): array_1d = np.ones(10, dtype=np.float32) - self.assertRaises(ValueError, - elephant.causality.granger.pairwise_granger, - array_1d, max_order=2) + self.assertRaises( + ValueError, + elephant.causality.granger.pairwise_granger, + array_1d, + max_order=2, + ) def test_result_namedtuple(self): """ @@ -154,19 +179,19 @@ def test_total_channel_interdependence_equals_sum_of_other_three(self): measures. It should be equal. In this test, however, almost equality is asserted due to a loss of significance with larger datasets. """ - causality_sum = self.causality.directional_causality_x_y \ - + self.causality.directional_causality_y_x \ + causality_sum = ( + self.causality.directional_causality_x_y + + self.causality.directional_causality_y_x + self.causality.instantaneous_causality - assert_array_almost_equal(self.causality.total_interdependence, - causality_sum, decimal=2) + ) + assert_array_almost_equal( + self.causality.total_interdependence, causality_sum, decimal=2 + ) def test_all_four_result_values_are_floats(self): - self.assertIsInstance(self.causality.directional_causality_x_y, - float) - self.assertIsInstance(self.causality.directional_causality_y_x, - float) - self.assertIsInstance(self.causality.instantaneous_causality, - float) + self.assertIsInstance(self.causality.directional_causality_x_y, float) + self.assertIsInstance(self.causality.directional_causality_y_x, float) + self.assertIsInstance(self.causality.instantaneous_causality, float) self.assertIsInstance(self.causality.total_interdependence, float) def test_ground_truth_vector_autoregressive_model(self): @@ -187,30 +212,33 @@ def test_ground_truth_vector_autoregressive_model(self): second_y2_l2 = -5.012730e-01 coefficients, _, _ = elephant.causality.granger._optimal_vector_arm( - self.ground_truth.T, dimension=2, max_order=10, - information_criterion='aic') + self.ground_truth.T, dimension=2, max_order=10, information_criterion="aic" + ) # Arrange the ground truth values in the same shape as coefficients ground_truth_coefficients = np.asarray( - [[[first_y1_l1, first_y2_l1], - [second_y1_l1, second_y2_l1]], - [[first_y1_l2, first_y2_l2], - [second_y1_l2, second_y2_l2]]] + [ + [[first_y1_l1, first_y2_l1], [second_y1_l1, second_y2_l1]], + [[first_y1_l2, first_y2_l2], [second_y1_l2, second_y2_l2]], + ] ) - assert_array_almost_equal(coefficients, ground_truth_coefficients, - decimal=4) + 
assert_array_almost_equal(coefficients, ground_truth_coefficients, decimal=4) def test_wrong_kwarg_optimal_vector_arm(self): - wrong_ic_criterion = 'cic' - - self.assertRaises(ValueError, - elephant.causality.granger._optimal_vector_arm, - self.ground_truth.T, 2, 10, wrong_ic_criterion) + wrong_ic_criterion = "cic" + + self.assertRaises( + ValueError, + elephant.causality.granger._optimal_vector_arm, + self.ground_truth.T, + 2, + 10, + wrong_ic_criterion, + ) class ConditionalGrangerTestCase(unittest.TestCase): - @classmethod def setUpClass(cls): np.random.seed(1) @@ -233,32 +261,23 @@ def _generate_ground_truth(length_2d=30000, causality_type="indirect"): elif causality_type == "both": y_t_lag_2 = 0.2 else: - raise ValueError("causality_type should be either 'indirect' or " - "'both'") + raise ValueError("causality_type should be either 'indirect' or " "'both'") order = 2 signal = np.zeros((3, length_2d + order)) - weights_1 = np.array([[0.8, 0, 0.4], - [0, 0.9, 0], - [0., 0.5, 0.5]]) + weights_1 = np.array([[0.8, 0, 0.4], [0, 0.9, 0], [0.0, 0.5, 0.5]]) - weights_2 = np.array([[-0.5, y_t_lag_2, 0.], - [0., -0.8, 0], - [0, 0, -0.2]]) + weights_2 = np.array([[-0.5, y_t_lag_2, 0.0], [0.0, -0.8, 0], [0, 0, -0.2]]) weights = np.stack((weights_1, weights_2)) - noise_covariance = np.array([[0.3, 0.0, 0.0], - [0.0, 1., 0.0], - [0.0, 0.0, 0.2]]) + noise_covariance = np.array([[0.3, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.2]]) for i in range(length_2d): for lag in range(order): - signal[:, i + order] += np.dot(weights[lag], - signal[:, i + 1 - lag]) - rnd_var = np.random.multivariate_normal([0, 0, 0], - noise_covariance) + signal[:, i + order] += np.dot(weights[lag], signal[:, i + 1 - lag]) + rnd_var = np.random.multivariate_normal([0, 0, 0], noise_covariance) signal[:, i + order] += rnd_var signal = signal[:, 2:] @@ -276,35 +295,48 @@ def setUp(self): # Generate a small dataset for containing both direct and indirect # causality. 
self.non_zero_signal = self._generate_ground_truth( - length_2d=1000, causality_type="both") + length_2d=1000, causality_type="both" + ) # Estimate Granger causality - self.conditional_causality = elephant.causality.granger.\ - conditional_granger(self.signal, max_order=10, - information_criterion='bic') + self.conditional_causality = elephant.causality.granger.conditional_granger( + self.signal, max_order=10, information_criterion="bic" + ) def test_result_is_float(self): self.assertIsInstance(self.conditional_causality, float) def test_ground_truth_zero_value_conditional_causality(self): - self.assertEqual(elephant.causality.granger.conditional_granger( - self.ground_truth, 10, 'bic'), 0.0) + self.assertEqual( + elephant.causality.granger.conditional_granger( + self.ground_truth, 10, "bic" + ), + 0.0, + ) def test_ground_truth_zero_value_conditional_causality_anasig(self): - signals = AnalogSignal(self.ground_truth, sampling_rate=1*pq.Hz, - units='V') - self.assertEqual(elephant.causality.granger.conditional_granger( - signals, 10, 'bic'), 0.0) + signals = AnalogSignal(self.ground_truth, sampling_rate=1 * pq.Hz, units="V") + self.assertEqual( + elephant.causality.granger.conditional_granger(signals, 10, "bic"), 0.0 + ) def test_non_zero_conditional_causality(self): - self.assertGreater(elephant.causality.granger.conditional_granger( - self.non_zero_signal, 10, 'bic'), 0.0) + self.assertGreater( + elephant.causality.granger.conditional_granger( + self.non_zero_signal, 10, "bic" + ), + 0.0, + ) def test_conditional_causality_wrong_input_shape(self): signals = np.random.normal(0, 1, (4, 10, 1)) - self.assertRaises(ValueError, - elephant.causality.granger.conditional_granger, - signals, 10, 'bic') + self.assertRaises( + ValueError, + elephant.causality.granger.conditional_granger, + signals, + 10, + "bic", + ) class PairwiseSpectralGrangerTestCase(unittest.TestCase): @@ -319,18 +351,18 @@ def test_bracket_operator_one_signal(self): # (ii). spectrum_causal = np.fft.ifft(spectrum, axis=0) - spectrum_causal[(n + 1) // 2:] = 0 + spectrum_causal[(n + 1) // 2 :] = 0 spectrum_causal[0] /= 2 spectrum_causal_ground_truth = np.fft.fft(spectrum_causal, axis=0) spectrum_causal_est = elephant.causality.granger._bracket_operator( - spectrum=spectrum, - num_freqs=n, - num_signals=1) + spectrum=spectrum, num_freqs=n, num_signals=1 + ) - np.testing.assert_array_almost_equal(spectrum_causal_est, - spectrum_causal_ground_truth) + np.testing.assert_array_almost_equal( + spectrum_causal_est, spectrum_causal_ground_truth + ) def test_bracket_operator_mult_signal(self): # Generate a spectrum from random dataset and test bracket operator @@ -344,7 +376,7 @@ def test_bracket_operator_mult_signal(self): # (ii). 
spectrum_causal = np.fft.ifft(spectrum, axis=0) - spectrum_causal[(n + 1) // 2:] = 0 + spectrum_causal[(n + 1) // 2 :] = 0 spectrum_causal[0] /= 2 spectrum_causal_ground_truth = np.fft.fft(spectrum_causal, axis=0) @@ -355,12 +387,12 @@ def test_bracket_operator_mult_signal(self): spectrum_causal_ground_truth[0, 2, 1] = 0 spectrum_causal_est = elephant.causality.granger._bracket_operator( - spectrum=spectrum, - num_freqs=n, - num_signals=num_signals) + spectrum=spectrum, num_freqs=n, num_signals=num_signals + ) - np.testing.assert_array_almost_equal(spectrum_causal_est, - spectrum_causal_ground_truth) + np.testing.assert_array_almost_equal( + spectrum_causal_est, spectrum_causal_ground_truth + ) def test_spectral_factorization(self): np.random.seed(11) @@ -368,18 +400,20 @@ def test_spectral_factorization(self): num_signals = 2 signals = np.random.normal(0, 1, (num_signals, n)) - _, cross_spec = multitaper_cross_spectrum(signals, - return_onesided=True) + _, cross_spec = multitaper_cross_spectrum(signals, return_onesided=True) cross_spec = np.transpose(cross_spec, (2, 0, 1)) - cov_matrix, transfer_function = \ + cov_matrix, transfer_function = ( elephant.causality.granger._spectral_factorization( - cross_spec, num_iterations=100) + cross_spec, num_iterations=100 + ) + ) - cross_spec_est = np.matmul(np.matmul(transfer_function, cov_matrix), - elephant.causality.granger._dagger( - transfer_function)) + cross_spec_est = np.matmul( + np.matmul(transfer_function, cov_matrix), + elephant.causality.granger._dagger(transfer_function), + ) np.testing.assert_array_almost_equal(cross_spec, cross_spec_est) @@ -389,21 +423,26 @@ def test_spectral_factorization_non_conv_exception(self): num_signals = 2 signals = np.random.normal(0, 1, (num_signals, n)) - _, cross_spec = multitaper_cross_spectrum(signals, - return_onesided=True) + _, cross_spec = multitaper_cross_spectrum(signals, return_onesided=True) cross_spec = np.transpose(cross_spec, (2, 0, 1)) - self.assertRaises(Exception, - elephant.causality.granger._spectral_factorization, - cross_spec, num_iterations=1) + self.assertRaises( + Exception, + elephant.causality.granger._spectral_factorization, + cross_spec, + num_iterations=1, + ) def test_spectral_factorization_initial_cond(self): # Cross spectrum at zero frequency must always be symmetric wrong_cross_spec = np.array([[[1, 2], [-1, 1]], [[1, 1], [1, 1]]]) - self.assertRaises(ValueError, - elephant.causality.granger._spectral_factorization, - wrong_cross_spec, num_iterations=10) + self.assertRaises( + ValueError, + elephant.causality.granger._spectral_factorization, + wrong_cross_spec, + num_iterations=10, + ) def test_dagger_2d(self): matrix_array = np.array([[1j, 0], [2, 3]], dtype=complex) @@ -420,20 +459,21 @@ def test_total_channel_interdependence_equals_transformed_coherence(self): num_signals = 2 signals = np.random.normal(0, 1, (num_signals, n)) - freqs, coh, phase_lag = multitaper_coherence(signals[0], signals[1], - len_segment=2**7, - num_tapers=2) - f, spectral_causality = \ - elephant.causality.granger.pairwise_spectral_granger( - signals[0], signals[1], len_segment=2**7, num_tapers=2) + freqs, coh, phase_lag = multitaper_coherence( + signals[0], signals[1], len_segment=2**7, num_tapers=2 + ) + f, spectral_causality = elephant.causality.granger.pairwise_spectral_granger( + signals[0], signals[1], len_segment=2**7, num_tapers=2 + ) total_interdependence = spectral_causality[3] # Cut last frequency due to length of segment being even and # multitaper_coherence using the real FFT in 
contrast to # pairwise_spectral_granger which has to use the full FFT. true_total_interdependence = -np.log(1 - coh)[:-1] - np.testing.assert_allclose(total_interdependence, - true_total_interdependence, atol=1e-5) + np.testing.assert_allclose( + total_interdependence, true_total_interdependence, atol=1e-5 + ) def test_pairwise_spectral_granger_against_ground_truth(self): """ @@ -444,76 +484,81 @@ def test_pairwise_spectral_granger_against_ground_truth(self): """ - repo_path = \ - r"unittest/causality/granger/pairwise_spectral_granger/data" + repo_path = r"unittest/causality/granger/pairwise_spectral_granger/data" files_to_download = [ ("time_series.npy", "54e0b3fbd904ccb48c75228c070a1a2a"), ("weights.npy", "eb1fc5590da5507293c63b25b1e3a7fc"), - ("noise_covariance.npy", "6f80ccff2b2aa9485dc9c01d81570bf5") + ("noise_covariance.npy", "6f80ccff2b2aa9485dc9c01d81570bf5"), ] for filename, checksum in files_to_download: - download_datasets(repo_path=f"{repo_path}/{filename}", - checksum=checksum) - signals = np.load(ELEPHANT_TMP_DIR / 'time_series.npy') - weights = np.load(ELEPHANT_TMP_DIR / 'weights.npy') - cov = np.load(ELEPHANT_TMP_DIR / 'noise_covariance.npy') + download_datasets(repo_path=f"{repo_path}/{filename}", checksum=checksum) + signals = np.load(ELEPHANT_TMP_DIR / "time_series.npy") + weights = np.load(ELEPHANT_TMP_DIR / "weights.npy") + cov = np.load(ELEPHANT_TMP_DIR / "noise_covariance.npy") # Estimate spectral Granger Causality - f, spectral_causality = \ - elephant.causality.granger.pairwise_spectral_granger( - signals[0], signals[1], len_segment=2**7, num_tapers=3) + f, spectral_causality = elephant.causality.granger.pairwise_spectral_granger( + signals[0], signals[1], len_segment=2**7, num_tapers=3 + ) # Calculate ground truth spectral Granger Causality # Formulae taken from Ding et al., Granger Causality: Basic Theory and # Application to Neuroscience, 2006 fn = np.linspace(0, np.pi, len(f)) freqs_for_theo = np.array([1, 2])[:, np.newaxis] * fn - A_theo = (np.identity(2)[np.newaxis] - - weights[0] * np.exp( - - 1j * freqs_for_theo[0][:, np.newaxis, np.newaxis])) + A_theo = np.identity(2)[np.newaxis] - weights[0] * np.exp( + -1j * freqs_for_theo[0][:, np.newaxis, np.newaxis] + ) A_theo -= weights[1] * np.exp( - - 1j * freqs_for_theo[1][:, np.newaxis, np.newaxis]) + -1j * freqs_for_theo[1][:, np.newaxis, np.newaxis] + ) - H_theo = np.array([[A_theo[:, 1, 1], -A_theo[:, 0, 1]], - [-A_theo[:, 1, 0], A_theo[:, 0, 0]]]) + H_theo = np.array( + [[A_theo[:, 1, 1], -A_theo[:, 0, 1]], [-A_theo[:, 1, 0], A_theo[:, 0, 0]]] + ) H_theo /= np.linalg.det(A_theo) H_theo = np.moveaxis(H_theo, 2, 0) - S_theo = np.matmul(np.matmul(H_theo, cov), - elephant.causality.granger._dagger(H_theo)) + S_theo = np.matmul( + np.matmul(H_theo, cov), elephant.causality.granger._dagger(H_theo) + ) - H_tilde_xx = (H_theo[:, 0, 0] + (cov[0, 1] / - cov[0, 0] * H_theo[:, 0, 1])) - H_tilde_yy = (H_theo[:, 1, 1] + (cov[0, 1] / - cov[1, 1] * H_theo[:, 1, 0])) + H_tilde_xx = H_theo[:, 0, 0] + (cov[0, 1] / cov[0, 0] * H_theo[:, 0, 1]) + H_tilde_yy = H_theo[:, 1, 1] + (cov[0, 1] / cov[1, 1] * H_theo[:, 1, 0]) - true_directional_causality_y_x = np.log(S_theo[:, 0, 0].real / - (H_tilde_xx - * cov[0, 0] - * H_tilde_xx.conj()).real) + true_directional_causality_y_x = np.log( + S_theo[:, 0, 0].real / (H_tilde_xx * cov[0, 0] * H_tilde_xx.conj()).real + ) - true_directional_causality_x_y = np.log(S_theo[:, 1, 1].real / - (H_tilde_yy * cov[1, 1] * - H_tilde_yy.conj()).real) + true_directional_causality_x_y = np.log( + S_theo[:, 
1, 1].real / (H_tilde_yy * cov[1, 1] * H_tilde_yy.conj()).real + ) true_instantaneous_causality = np.log( (H_tilde_xx * cov[0, 0] * H_tilde_xx.conj()).real - * (H_tilde_yy * cov[1, 1] * H_tilde_yy.conj()).real) + * (H_tilde_yy * cov[1, 1] * H_tilde_yy.conj()).real + ) true_instantaneous_causality -= np.linalg.slogdet(S_theo)[1] np.testing.assert_allclose( spectral_causality.directional_causality_x_y, - true_directional_causality_x_y, atol=0.06) + true_directional_causality_x_y, + atol=0.06, + ) np.testing.assert_allclose( spectral_causality.directional_causality_y_x, - true_directional_causality_y_x, atol=0.06) + true_directional_causality_y_x, + atol=0.06, + ) np.testing.assert_allclose( spectral_causality.instantaneous_causality, - true_instantaneous_causality, atol=0.06) + true_instantaneous_causality, + atol=0.06, + ) def test_pairwise_spectral_granger_against_r_grangers(self): """ @@ -524,32 +569,36 @@ def test_pairwise_spectral_granger_against_r_grangers(self): """ - repo_path = \ - r"unittest/causality/granger/pairwise_spectral_granger/data" + repo_path = r"unittest/causality/granger/pairwise_spectral_granger/data" files_to_download = [ ("time_series_small.npy", "b33dc12d4291db7c2087dd8429f15ab4"), - ("gc_matrix.npy", "c57262145e74a178588ff0a1004879e2") + ("gc_matrix.npy", "c57262145e74a178588ff0a1004879e2"), ] for filename, checksum in files_to_download: - download_datasets(repo_path=f"{repo_path}/{filename}", - checksum=checksum) - signal = np.load(ELEPHANT_TMP_DIR / 'time_series_small.npy') - gc_matrix = np.load(ELEPHANT_TMP_DIR / 'gc_matrix.npy') + download_datasets(repo_path=f"{repo_path}/{filename}", checksum=checksum) + signal = np.load(ELEPHANT_TMP_DIR / "time_series_small.npy") + gc_matrix = np.load(ELEPHANT_TMP_DIR / "gc_matrix.npy") denom = 20 - f, spectral_causality = \ - elephant.causality.granger.pairwise_spectral_granger( - signal[0], signal[1], len_segment=int(len(signal[0]) / denom), - num_tapers=15, fs=1, num_iterations=50) + f, spectral_causality = elephant.causality.granger.pairwise_spectral_granger( + signal[0], + signal[1], + len_segment=int(len(signal[0]) / denom), + num_tapers=15, + fs=1, + num_iterations=50, + ) np.testing.assert_allclose(gc_matrix[::denom, 0], f, atol=4e-5) - np.testing.assert_allclose(gc_matrix[::denom, 1], - spectral_causality[0], atol=0.085) - np.testing.assert_allclose(gc_matrix[::denom, 2], - spectral_causality[1], atol=0.035) + np.testing.assert_allclose( + gc_matrix[::denom, 1], spectral_causality[0], atol=0.085 + ) + np.testing.assert_allclose( + gc_matrix[::denom, 2], spectral_causality[1], atol=0.035 + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_cell_assembly_detection.py b/elephant/test/test_cell_assembly_detection.py index 152eae7d6..c4da62ca1 100644 --- a/elephant/test/test_cell_assembly_detection.py +++ b/elephant/test/test_cell_assembly_detection.py @@ -1,4 +1,3 @@ - """ Unit test for cell_assembly_detection """ @@ -13,9 +12,7 @@ class CadTestCase(unittest.TestCase): - def setUp(self): - # Parameters self.bin_size = 1 * pq.ms self.alpha = 0.05 @@ -53,28 +50,48 @@ def setUp(self): np.random.seed(1) self.patt1_times = neo.SpikeTrain( np.random.uniform(0, 1 - max(self.lags1), self.n_occ1) * pq.s, - t_start=0 * pq.s, t_stop=1 * pq.s) + t_start=0 * pq.s, + t_stop=1 * pq.s, + ) self.patt2_times = neo.SpikeTrain( np.random.uniform(0, 1 - max(self.lags2), self.n_occ2) * pq.s, - t_start=0 * pq.s, t_stop=1 * pq.s) + t_start=0 * pq.s, + t_stop=1 * pq.s, + ) self.patt3_times = 
neo.SpikeTrain( np.random.uniform(0, 1 - max(self.lags3), self.n_occ3) * pq.s, - t_start=0 * pq.s, t_stop=1 * pq.s) + t_start=0 * pq.s, + t_stop=1 * pq.s, + ) # Patterns - self.patt1 = [self.patt1_times] + [neo.SpikeTrain( - self.patt1_times + l * pq.s, t_start=self.t_start * pq.s, - t_stop=self.t_stop * pq.s) for l in self.lags1] - self.patt2 = [self.patt2_times] + [neo.SpikeTrain( - self.patt2_times + l * pq.s, t_start=self.t_start * pq.s, - t_stop=self.t_stop * pq.s) for l in self.lags2] - self.patt3 = [self.patt3_times] + [neo.SpikeTrain( - self.patt3_times + l * pq.s, t_start=self.t_start * pq.s, - t_stop=self.t_stop * pq.s) for l in self.lags3] + self.patt1 = [self.patt1_times] + [ + neo.SpikeTrain( + self.patt1_times + l * pq.s, + t_start=self.t_start * pq.s, + t_stop=self.t_stop * pq.s, + ) + for l in self.lags1 + ] + self.patt2 = [self.patt2_times] + [ + neo.SpikeTrain( + self.patt2_times + l * pq.s, + t_start=self.t_start * pq.s, + t_stop=self.t_stop * pq.s, + ) + for l in self.lags2 + ] + self.patt3 = [self.patt3_times] + [ + neo.SpikeTrain( + self.patt3_times + l * pq.s, + t_start=self.t_start * pq.s, + t_stop=self.t_stop * pq.s, + ) + for l in self.lags3 + ] # Binning spiketrains - self.bin_patt1 = conv.BinnedSpikeTrain(self.patt1, - bin_size=self.bin_size) + self.bin_patt1 = conv.BinnedSpikeTrain(self.patt1, bin_size=self.bin_size) # Data self.msip = self.patt1 + self.patt2 + self.patt3 @@ -88,110 +105,141 @@ def setUp(self): self.elements2 = range(self.n_spk2) self.elements3 = range(self.n_spk3) self.elements_msip = [ - self.elements1, range(self.n_spk1, self.n_spk1 + self.n_spk2), - range(self.n_spk1 + self.n_spk2, - self.n_spk1 + self.n_spk2 + self.n_spk3)] - self.occ1 = np.unique(conv.BinnedSpikeTrain( - self.patt1_times, self.bin_size).spike_indices[0]) - self.occ2 = np.unique(conv.BinnedSpikeTrain( - self.patt2_times, self.bin_size).spike_indices[0]) - self.occ3 = np.unique(conv.BinnedSpikeTrain( - self.patt3_times, self.bin_size).spike_indices[0]) + self.elements1, + range(self.n_spk1, self.n_spk1 + self.n_spk2), + range(self.n_spk1 + self.n_spk2, self.n_spk1 + self.n_spk2 + self.n_spk3), + ] + self.occ1 = np.unique( + conv.BinnedSpikeTrain(self.patt1_times, self.bin_size).spike_indices[0] + ) + self.occ2 = np.unique( + conv.BinnedSpikeTrain(self.patt2_times, self.bin_size).spike_indices[0] + ) + self.occ3 = np.unique( + conv.BinnedSpikeTrain(self.patt3_times, self.bin_size).spike_indices[0] + ) self.occ_msip = [list(self.occ1), list(self.occ2), list(self.occ3)] - self.lags_msip = [self.output_lags1, - self.output_lags2, - self.output_lags3] + self.lags_msip = [self.output_lags1, self.output_lags2, self.output_lags3] # test for single pattern injection input def test_cad_single_sip(self): # collecting cad output output_single = cad.cell_assembly_detection( - binned_spiketrain=self.bin_patt1, max_lag=self.max_lag) + binned_spiketrain=self.bin_patt1, max_lag=self.max_lag + ) # check neurons in the pattern - assert_array_equal(sorted(output_single[0]['neurons']), - self.elements1) + assert_array_equal(sorted(output_single[0]["neurons"]), self.elements1) # check the occurrences time of the pattern - assert_array_equal(output_single[0]['times'], - self.occ1 * self.bin_size) + assert_array_equal(output_single[0]["times"], self.occ1 * self.bin_size) # check the lags - assert_array_equal(sorted(output_single[0]['lags']) * pq.s, - self.output_lags1 * self.bin_size) + assert_array_equal( + sorted(output_single[0]["lags"]) * pq.s, self.output_lags1 * self.bin_size + ) # test with 
multiple (3) patterns injected in the data def test_cad_msip(self): # collecting cad output output_msip = cad.cell_assembly_detection( - binned_spiketrain=self.msip, max_lag=self.max_lag) + binned_spiketrain=self.msip, max_lag=self.max_lag + ) for i, out in enumerate(output_msip): - assert_array_equal(out['times'], self.occ_msip[i] * self.bin_size) - assert_array_equal(sorted(out['lags']) * pq.s, - self.lags_msip[i] * self.bin_size) - assert_array_equal(sorted(out['neurons']), self.elements_msip[i]) + assert_array_equal(out["times"], self.occ_msip[i] * self.bin_size) + assert_array_equal( + sorted(out["lags"]) * pq.s, self.lags_msip[i] * self.bin_size + ) + assert_array_equal(sorted(out["neurons"]), self.elements_msip[i]) # test the errors raised def test_cad_raise_error(self): # test error data input format - self.assertRaises(TypeError, cad.cell_assembly_detection, - binned_spiketrain=[[1, 2, 3], [3, 4, 5]], - max_lag=self.max_lag) + self.assertRaises( + TypeError, + cad.cell_assembly_detection, + binned_spiketrain=[[1, 2, 3], [3, 4, 5]], + max_lag=self.max_lag, + ) # test error significance level - self.assertRaises(ValueError, cad.cell_assembly_detection, - binned_spiketrain=conv.BinnedSpikeTrain( - [neo.SpikeTrain([1, 2, 3] * pq.s, - t_stop=5 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, - t_stop=5 * pq.s)], - bin_size=self.bin_size), - max_lag=self.max_lag, - alpha=-3) + self.assertRaises( + ValueError, + cad.cell_assembly_detection, + binned_spiketrain=conv.BinnedSpikeTrain( + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=5 * pq.s), + ], + bin_size=self.bin_size, + ), + max_lag=self.max_lag, + alpha=-3, + ) # test error minimum number of occurrences - self.assertRaises(ValueError, cad.cell_assembly_detection, - binned_spiketrain=conv.BinnedSpikeTrain( - [neo.SpikeTrain([1, 2, 3] * pq.s, - t_stop=5 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, - t_stop=5 * pq.s)], - bin_size=self.bin_size), - max_lag=self.max_lag, - min_occurrences=-1) + self.assertRaises( + ValueError, + cad.cell_assembly_detection, + binned_spiketrain=conv.BinnedSpikeTrain( + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=5 * pq.s), + ], + bin_size=self.bin_size, + ), + max_lag=self.max_lag, + min_occurrences=-1, + ) # test error minimum number of spikes in a pattern - self.assertRaises(ValueError, cad.cell_assembly_detection, - binned_spiketrain=conv.BinnedSpikeTrain( - [neo.SpikeTrain([1, 2, 3] * pq.s, - t_stop=5 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, - t_stop=5 * pq.s)], - bin_size=self.bin_size), - max_lag=self.max_lag, - max_spikes=1) + self.assertRaises( + ValueError, + cad.cell_assembly_detection, + binned_spiketrain=conv.BinnedSpikeTrain( + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=5 * pq.s), + ], + bin_size=self.bin_size, + ), + max_lag=self.max_lag, + max_spikes=1, + ) # test error chunk size for variance computation - self.assertRaises(ValueError, cad.cell_assembly_detection, - binned_spiketrain=conv.BinnedSpikeTrain( - [neo.SpikeTrain([1, 2, 3] * pq.s, - t_stop=5 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, - t_stop=5 * pq.s)], - bin_size=self.bin_size), - max_lag=self.max_lag, - size_chunks=1) + self.assertRaises( + ValueError, + cad.cell_assembly_detection, + binned_spiketrain=conv.BinnedSpikeTrain( + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=5 * pq.s), + ], + bin_size=self.bin_size, + ), + 
max_lag=self.max_lag, + size_chunks=1, + ) # test error maximum lag - self.assertRaises(ValueError, cad.cell_assembly_detection, - binned_spiketrain=conv.BinnedSpikeTrain( - [neo.SpikeTrain([1, 2, 3] * pq.s, - t_stop=5 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, - t_stop=5 * pq.s)], - bin_size=self.bin_size), - max_lag=1) + self.assertRaises( + ValueError, + cad.cell_assembly_detection, + binned_spiketrain=conv.BinnedSpikeTrain( + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=5 * pq.s), + ], + bin_size=self.bin_size, + ), + max_lag=1, + ) # test error minimum length spike train - self.assertRaises(ValueError, cad.cell_assembly_detection, - binned_spiketrain=conv.BinnedSpikeTrain( - [neo.SpikeTrain([1, 2, 3] * pq.ms, - t_stop=6 * pq.ms), - neo.SpikeTrain([3, 4, 5] * pq.ms, - t_stop=6 * pq.ms)], - bin_size=1 * pq.ms), - max_lag=self.max_lag) + self.assertRaises( + ValueError, + cad.cell_assembly_detection, + binned_spiketrain=conv.BinnedSpikeTrain( + [ + neo.SpikeTrain([1, 2, 3] * pq.ms, t_stop=6 * pq.ms), + neo.SpikeTrain([3, 4, 5] * pq.ms, t_stop=6 * pq.ms), + ], + bin_size=1 * pq.ms, + ), + max_lag=self.max_lag, + ) if __name__ == "__main__": diff --git a/elephant/test/test_change_point_detection.py b/elephant/test/test_change_point_detection.py index 6fce25cac..38faaf347 100644 --- a/elephant/test/test_change_point_detection.py +++ b/elephant/test/test_change_point_detection.py @@ -17,32 +17,33 @@ def setUp(self): mu_ri = (0.25 + 0.05) / 2 mu_le = (0.1 + 0.15 + 0.05) / 3 sigma_ri = ((0.25 - 0.15) ** 2 + (0.05 - 0.15) ** 2) / 2 - sigma_le = ((0.1 - 0.1) ** 2 + (0.15 - 0.1) ** 2 + ( - 0.05 - 0.1) ** 2) / 3 + sigma_le = ((0.1 - 0.1) ** 2 + (0.15 - 0.1) ** 2 + (0.05 - 0.1) ** 2) / 3 self.targ_t08_h025 = 0 self.targ_t08_h05 = (3 - 4) / np.sqrt( - (sigma_ri / mu_ri ** 3) * 0.5 + (sigma_le / mu_le ** 3) * 0.5) + (sigma_ri / mu_ri**3) * 0.5 + (sigma_le / mu_le**3) * 0.5 + ) # Window Large # def test_filter_with_spiketrain_h05(self): - st = neo.SpikeTrain(self.test_array, units='s', t_stop=2.0) + st = neo.SpikeTrain(self.test_array, units="s", t_stop=2.0) target = self.targ_t08_h05 res = mft._filter(0.8 * pq.s, 0.5 * pq.s, st) assert_array_almost_equal(res, target, decimal=9) self.assertRaises(ValueError, mft._filter, 0.8, 0.5 * pq.s, st) self.assertRaises(ValueError, mft._filter, 0.8 * pq.s, 0.5, st) - self.assertRaises(ValueError, mft._filter, 0.8 * pq.s, 0.5 * pq.s, - self.test_array) + self.assertRaises( + ValueError, mft._filter, 0.8 * pq.s, 0.5 * pq.s, self.test_array + ) # Window Small # def test_filter_with_spiketrain_h025(self): - st = neo.SpikeTrain(self.test_array, units='s', t_stop=2.0) + st = neo.SpikeTrain(self.test_array, units="s", t_stop=2.0) target = self.targ_t08_h025 res = mft._filter(0.8 * pq.s, 0.25 * pq.s, st) assert_array_almost_equal(res, target, decimal=9) def test_filter_with_quantities_h025(self): - st = pq.Quantity(self.test_array, units='s') + st = pq.Quantity(self.test_array, units="s") target = self.targ_t08_h025 res = mft._filter(0.8 * pq.s, 0.25 * pq.s, st) assert_array_almost_equal(res, target, decimal=9) @@ -54,7 +55,7 @@ def test_filter_with_plain_array_h025(self): assert_array_almost_equal(res, target, decimal=9) def test_isi_with_quantities_h05(self): - st = pq.Quantity(self.test_array, units='s') + st = pq.Quantity(self.test_array, units="s") target = self.targ_t08_h05 res = mft._filter(0.8 * pq.s, 0.5 * pq.s, st) assert_array_almost_equal(res, target, decimal=9) @@ -69,39 +70,70 @@ def 
test_isi_with_plain_array_h05(self): class FilterProcessTestCase(unittest.TestCase): def setUp(self): self.test_array = [1.1, 1.2, 1.4, 1.6, 1.7, 1.75, 1.8, 1.85, 1.9, 1.95] - x = (7 - 3) / np.sqrt( - (0.0025 / 0.15 ** 3) * 0.5 + (0.0003472 / 0.05833 ** 3) * 0.5) - self.targ_h05 = [[0.5, 1, 1.5], - [(0 - 1.7) / np.sqrt(0.4), (0 - 1.7) / np.sqrt(0.4), - (x - 1.7) / np.sqrt(0.4)]] + x = (7 - 3) / np.sqrt((0.0025 / 0.15**3) * 0.5 + (0.0003472 / 0.05833**3) * 0.5) + self.targ_h05 = [ + [0.5, 1, 1.5], + [ + (0 - 1.7) / np.sqrt(0.4), + (0 - 1.7) / np.sqrt(0.4), + (x - 1.7) / np.sqrt(0.4), + ], + ] def test_filter_process_with_spiketrain_h05(self): - st = neo.SpikeTrain(self.test_array, units='s', t_stop=2.1) + st = neo.SpikeTrain(self.test_array, units="s", t_stop=2.1) target = self.targ_h05 - res = mft._filter_process(0.5 * pq.s, 0.5 * pq.s, st, 2.01 * pq.s, - np.array([[0.5], [1.7], [0.4]])) + res = mft._filter_process( + 0.5 * pq.s, 0.5 * pq.s, st, 2.01 * pq.s, np.array([[0.5], [1.7], [0.4]]) + ) assert_array_almost_equal(res[1], target[1], decimal=3) - self.assertRaises(ValueError, mft._filter_process, 0.5, 0.5 * pq.s, - st, 2.01 * pq.s, np.array([[0.5], [1.7], [0.4]])) - self.assertRaises(ValueError, mft._filter_process, 0.5 * pq.s, 0.5, - st, 2.01 * pq.s, np.array([[0.5], [1.7], [0.4]])) - self.assertRaises(ValueError, mft._filter_process, 0.5 * pq.s, - 0.5 * pq.s, self.test_array, 2.01 * pq.s, - np.array([[0.5], [1.7], [0.4]])) + self.assertRaises( + ValueError, + mft._filter_process, + 0.5, + 0.5 * pq.s, + st, + 2.01 * pq.s, + np.array([[0.5], [1.7], [0.4]]), + ) + self.assertRaises( + ValueError, + mft._filter_process, + 0.5 * pq.s, + 0.5, + st, + 2.01 * pq.s, + np.array([[0.5], [1.7], [0.4]]), + ) + self.assertRaises( + ValueError, + mft._filter_process, + 0.5 * pq.s, + 0.5 * pq.s, + self.test_array, + 2.01 * pq.s, + np.array([[0.5], [1.7], [0.4]]), + ) def test_filter_proces_with_quantities_h05(self): - st = pq.Quantity(self.test_array, units='s') + st = pq.Quantity(self.test_array, units="s") target = self.targ_h05 - res = mft._filter_process(0.5 * pq.s, 0.5 * pq.s, st, 2.01 * pq.s, - np.array([[0.5], [1.7], [0.4]])) + res = mft._filter_process( + 0.5 * pq.s, 0.5 * pq.s, st, 2.01 * pq.s, np.array([[0.5], [1.7], [0.4]]) + ) assert_array_almost_equal(res[0], target[0], decimal=3) def test_filter_proces_with_plain_array_h05(self): st = self.test_array target = self.targ_h05 - res = mft._filter_process(0.5 * pq.s, 0.5 * pq.s, st * pq.s, - 2.01 * pq.s, np.array([[0.5], [1.7], [0.4]])) + res = mft._filter_process( + 0.5 * pq.s, + 0.5 * pq.s, + st * pq.s, + 2.01 * pq.s, + np.array([[0.5], [1.7], [0.4]]), + ) self.assertNotIsInstance(res, pq.Quantity) assert_array_almost_equal(res, target, decimal=3) @@ -118,46 +150,33 @@ def setUp(self): # the user should do the same, if the metohd has to be applied to # several spike trains of the same length `T` and with the same set of # window. 
- self.test_param = np.array([[10., - 25., - 50., - 75., - 100., - 125., - 150.], - [3.167, - 2.955, - 2.721, - 2.548, - 2.412, - 2.293, - 2.180], - [0.150, - 0.185, - 0.224, - 0.249, - 0.269, - 0.288, - 0.301]]) + self.test_param = np.array( + [ + [10.0, 25.0, 50.0, 75.0, 100.0, 125.0, 150.0], + [3.167, 2.955, 2.721, 2.548, 2.412, 2.293, 2.180], + [0.150, 0.185, 0.224, 0.249, 0.269, 0.288, 0.301], + ] + ) self.test_quantile = 2.75 def test_MultipleFilterAlgorithm_with_spiketrain_h05(self): - st = neo.SpikeTrain(self.test_array, units='s', t_stop=2.1) + st = neo.SpikeTrain(self.test_array, units="s", t_stop=2.1) target = [self.targ_h05_dt05] - res = mft.multiple_filter_test([0.5] * pq.s, st, 2.1 * pq.s, 5, 100, - time_step=0.1 * pq.s) + res = mft.multiple_filter_test( + [0.5] * pq.s, st, 2.1 * pq.s, 5, 100, time_step=0.1 * pq.s + ) assert_array_almost_equal(res, target, decimal=9) def test_MultipleFilterAlgorithm_with_quantities_h05(self): - st = pq.Quantity(self.test_array, units='s') + st = pq.Quantity(self.test_array, units="s") target = [self.targ_h05_dt05] - res = mft.multiple_filter_test([0.5] * pq.s, st, 2.1 * pq.s, 5, 100, - time_step=0.5 * pq.s) + res = mft.multiple_filter_test( + [0.5] * pq.s, st, 2.1 * pq.s, 5, 100, time_step=0.5 * pq.s + ) self.assertNotIsInstance(res, pq.Quantity) assert_array_almost_equal(res, target, decimal=9) def test_MultipleFilterAlgorithm_with_longdata(self): - def gamma_train(k, teta, tmax): x = np.random.gamma(k, teta, int(tmax * (k * teta) ** (-1) * 3)) s = np.cumsum(x) @@ -165,17 +184,18 @@ def gamma_train(k, teta, tmax): s = s[idx] # gamma process return s - def alternative_hypothesis(k1, teta1, c1, k2, teta2, c2, k3, teta3, c3, - k4, teta4, T): + def alternative_hypothesis( + k1, teta1, c1, k2, teta2, c2, k3, teta3, c3, k4, teta4, T + ): s1 = gamma_train(k1, teta1, c1) s2 = gamma_train(k2, teta2, c2) + c1 s3 = gamma_train(k3, teta3, c3) + c1 + c2 s4 = gamma_train(k4, teta4, T) + c1 + c2 + c3 return np.concatenate((s1, s2, s3, s4)), [s1[-1], s2[-1], s3[-1]] - st = self.h1 = alternative_hypothesis(1, 1 / 4., 150, 2, 1 / 26., 30, - 1, 1 / 36., 320, - 2, 1 / 33., 200)[0] + st = self.h1 = alternative_hypothesis( + 1, 1 / 4.0, 150, 2, 1 / 26.0, 30, 1, 1 / 36.0, 320, 2, 1 / 33.0, 200 + )[0] window_size = [10, 25, 50, 75, 100, 125, 150] * pq.s self.target_points = [150, 180, 500] @@ -189,18 +209,21 @@ def alternative_hypothesis(k1, teta1, c1, k2, teta2, c2, k3, teta3, c3, 10000, test_quantile=self.test_quantile, test_param=self.test_param, - time_step=1 * pq.s) + time_step=1 * pq.s, + ) self.assertNotIsInstance(result, pq.Quantity) result_concatenated = [] for i in result: result_concatenated = np.hstack([result_concatenated, i]) result_concatenated = np.sort(result_concatenated) - assert_allclose(result_concatenated[:3], target[:3], rtol=0, - atol=5) - print('detected {0} cps: {1}'.format(len(result_concatenated), - result_concatenated)) + assert_allclose(result_concatenated[:3], target[:3], rtol=0, atol=5) + print( + "detected {0} cps: {1}".format( + len(result_concatenated), result_concatenated + ) + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_conversion.py b/elephant/test/test_conversion.py index 21e8a0b29..c86f8a5f2 100644 --- a/elephant/test/test_conversion.py +++ b/elephant/test/test_conversion.py @@ -12,7 +12,7 @@ from neo.core.spiketrainlist import SpikeTrainList import numpy as np import quantities as pq -from numpy.testing import (assert_array_almost_equal, assert_array_equal) +from 
numpy.testing import assert_array_almost_equal, assert_array_equal import elephant.conversion as cv from elephant.utils import get_common_start_stop_times @@ -28,83 +28,87 @@ def setUp(self): self.test_array_1d = np.array([1.23, 0.3, 0.87, 0.56]) def test_binarize_with_spiketrain_exact(self): - st = neo.SpikeTrain(self.test_array_1d, units='ms', - t_stop=10.0, sampling_rate=100) - times = np.arange(0, 10. + .01, .01) - target = np.zeros_like(times).astype('bool') + st = neo.SpikeTrain( + self.test_array_1d, units="ms", t_stop=10.0, sampling_rate=100 + ) + times = np.arange(0, 10.0 + 0.01, 0.01) + target = np.zeros_like(times).astype("bool") for time in self.test_array_1d: target[get_nearest(times, time)] = True - times = pq.Quantity(times, units='ms') + times = pq.Quantity(times, units="ms") res, tres = cv.binarize(st, return_times=True) assert_array_almost_equal(res, target, decimal=9) assert_array_almost_equal(tres, times, decimal=9) def test_binarize_with_spiketrain_exact_set_ends(self): - st = neo.SpikeTrain(self.test_array_1d, units='ms', - t_stop=10.0, sampling_rate=100) - times = np.arange(5., 10. + .01, .01) - target = np.zeros_like(times).astype('bool') - times = pq.Quantity(times, units='ms') - - res, tres = cv.binarize(st, return_times=True, t_start=5., t_stop=10.) + st = neo.SpikeTrain( + self.test_array_1d, units="ms", t_stop=10.0, sampling_rate=100 + ) + times = np.arange(5.0, 10.0 + 0.01, 0.01) + target = np.zeros_like(times).astype("bool") + times = pq.Quantity(times, units="ms") + + res, tres = cv.binarize(st, return_times=True, t_start=5.0, t_stop=10.0) assert_array_almost_equal(res, target, decimal=9) assert_array_almost_equal(tres, times, decimal=9) def test_binarize_with_spiketrain_round(self): - st = neo.SpikeTrain(self.test_array_1d, units='ms', - t_stop=10.0, sampling_rate=10.0) - times = np.arange(0, 10. + .1, .1) - target = np.zeros_like(times).astype('bool') + st = neo.SpikeTrain( + self.test_array_1d, units="ms", t_stop=10.0, sampling_rate=10.0 + ) + times = np.arange(0, 10.0 + 0.1, 0.1) + target = np.zeros_like(times).astype("bool") for time in np.round(self.test_array_1d, 1): target[get_nearest(times, time)] = True - times = pq.Quantity(times, units='ms') + times = pq.Quantity(times, units="ms") res, tres = cv.binarize(st, return_times=True) assert_array_almost_equal(res, target, decimal=9) assert_array_almost_equal(tres, times, decimal=9) def test_binarize_with_quantities_exact(self): - st = pq.Quantity(self.test_array_1d, units='ms') - times = np.arange(0, 1.23 + .01, .01) - target = np.zeros_like(times).astype('bool') + st = pq.Quantity(self.test_array_1d, units="ms") + times = np.arange(0, 1.23 + 0.01, 0.01) + target = np.zeros_like(times).astype("bool") for time in self.test_array_1d: target[get_nearest(times, time)] = True - times = pq.Quantity(times, units='ms') + times = pq.Quantity(times, units="ms") - res, tres = cv.binarize(st, return_times=True, - sampling_rate=100. * pq.kHz) + res, tres = cv.binarize(st, return_times=True, sampling_rate=100.0 * pq.kHz) assert_array_almost_equal(res, target, decimal=9) assert_array_almost_equal(tres, times, decimal=9) def test_binarize_with_quantities_exact_set_ends(self): - st = pq.Quantity(self.test_array_1d, units='ms') - times = np.arange(0, 10. 
+ .01, .01) - target = np.zeros_like(times).astype('bool') + st = pq.Quantity(self.test_array_1d, units="ms") + times = np.arange(0, 10.0 + 0.01, 0.01) + target = np.zeros_like(times).astype("bool") for time in self.test_array_1d: target[get_nearest(times, time)] = True - times = pq.Quantity(times, units='ms') + times = pq.Quantity(times, units="ms") - res, tres = cv.binarize(st, return_times=True, t_stop=10., - sampling_rate=100. * pq.kHz) + res, tres = cv.binarize( + st, return_times=True, t_stop=10.0, sampling_rate=100.0 * pq.kHz + ) assert_array_almost_equal(res, target, decimal=9) assert_array_almost_equal(tres, times, decimal=9) def test_binarize_with_quantities_round_set_ends(self): - st = pq.Quantity(self.test_array_1d, units='ms') - times = np.arange(5., 10. + .1, .1) - target = np.zeros_like(times).astype('bool') - times = pq.Quantity(times, units='ms') - - res, tres = cv.binarize(st, return_times=True, t_start=5., t_stop=10., - sampling_rate=10. * pq.kHz) + st = pq.Quantity(self.test_array_1d, units="ms") + times = np.arange(5.0, 10.0 + 0.1, 0.1) + target = np.zeros_like(times).astype("bool") + times = pq.Quantity(times, units="ms") + + res, tres = cv.binarize( + st, return_times=True, t_start=5.0, t_stop=10.0, sampling_rate=10.0 * pq.kHz + ) assert_array_almost_equal(res, target, decimal=9) assert_array_almost_equal(tres, times, decimal=9) def test_binarize_with_plain_array_exact(self): st = self.test_array_1d - times = np.arange(0, 1.23 + .01, .01) - target = np.zeros_like(times).astype('bool') + times = np.arange(0, 1.23 + 0.01, 0.01) + target = np.zeros_like(times).astype("bool") for time in self.test_array_1d: target[get_nearest(times, time)] = True @@ -114,20 +118,19 @@ def test_binarize_with_plain_array_exact(self): def test_binarize_with_plain_array_exact_set_ends(self): st = self.test_array_1d - times = np.arange(0, 10. + .01, .01) - target = np.zeros_like(times).astype('bool') + times = np.arange(0, 10.0 + 0.01, 0.01) + target = np.zeros_like(times).astype("bool") for time in self.test_array_1d: target[get_nearest(times, time)] = True - res, tres = cv.binarize(st, return_times=True, t_stop=10., - sampling_rate=100.) + res, tres = cv.binarize(st, return_times=True, t_stop=10.0, sampling_rate=100.0) assert_array_almost_equal(res, target, decimal=9) assert_array_almost_equal(tres, times, decimal=9) def test_binarize_no_time(self): st = self.test_array_1d - times = np.arange(0, 1.23 + .01, .01) - target = np.zeros_like(times).astype('bool') + times = np.arange(0, 1.23 + 0.01, 0.01) + target = np.zeros_like(times).astype("bool") for time in self.test_array_1d: target[get_nearest(times, time)] = True @@ -139,50 +142,62 @@ def test_binarize_no_time(self): def test_binariz_rate_with_plain_array_and_units_typeerror(self): st = self.test_array_1d - self.assertRaises(TypeError, cv.binarize, st, - t_start=pq.Quantity(0, 'ms'), - sampling_rate=10.) - self.assertRaises(TypeError, cv.binarize, st, - t_stop=pq.Quantity(10, 'ms'), - sampling_rate=10.) - self.assertRaises(TypeError, cv.binarize, st, - t_start=pq.Quantity(0, 'ms'), - t_stop=pq.Quantity(10, 'ms'), - sampling_rate=10.) - self.assertRaises(TypeError, cv.binarize, st, - t_start=pq.Quantity(0, 'ms'), - t_stop=10., - sampling_rate=10.) - self.assertRaises(TypeError, cv.binarize, st, - t_start=0., - t_stop=pq.Quantity(10, 'ms'), - sampling_rate=10.) - self.assertRaises(TypeError, cv.binarize, st, - sampling_rate=10. 
* pq.Hz) + self.assertRaises( + TypeError, cv.binarize, st, t_start=pq.Quantity(0, "ms"), sampling_rate=10.0 + ) + self.assertRaises( + TypeError, cv.binarize, st, t_stop=pq.Quantity(10, "ms"), sampling_rate=10.0 + ) + self.assertRaises( + TypeError, + cv.binarize, + st, + t_start=pq.Quantity(0, "ms"), + t_stop=pq.Quantity(10, "ms"), + sampling_rate=10.0, + ) + self.assertRaises( + TypeError, + cv.binarize, + st, + t_start=pq.Quantity(0, "ms"), + t_stop=10.0, + sampling_rate=10.0, + ) + self.assertRaises( + TypeError, + cv.binarize, + st, + t_start=0.0, + t_stop=pq.Quantity(10, "ms"), + sampling_rate=10.0, + ) + self.assertRaises(TypeError, cv.binarize, st, sampling_rate=10.0 * pq.Hz) def test_binariz_without_sampling_rate_valueerror(self): st0 = self.test_array_1d - st1 = pq.Quantity(st0, 'ms') + st1 = pq.Quantity(st0, "ms") self.assertRaises(ValueError, cv.binarize, st0) - self.assertRaises(ValueError, cv.binarize, st0, - t_start=0) - self.assertRaises(ValueError, cv.binarize, st0, - t_stop=10) - self.assertRaises(ValueError, cv.binarize, st0, - t_start=0, t_stop=10) - self.assertRaises(ValueError, cv.binarize, st1, - t_start=pq.Quantity(0, 'ms'), t_stop=10.) - self.assertRaises(ValueError, cv.binarize, st1, - t_start=0., t_stop=pq.Quantity(10, 'ms')) + self.assertRaises(ValueError, cv.binarize, st0, t_start=0) + self.assertRaises(ValueError, cv.binarize, st0, t_stop=10) + self.assertRaises(ValueError, cv.binarize, st0, t_start=0, t_stop=10) + self.assertRaises( + ValueError, cv.binarize, st1, t_start=pq.Quantity(0, "ms"), t_stop=10.0 + ) + self.assertRaises( + ValueError, cv.binarize, st1, t_start=0.0, t_stop=pq.Quantity(10, "ms") + ) self.assertRaises(ValueError, cv.binarize, st1) def test_bin_edges_empty_binned_spiketrain(self): - st = neo.SpikeTrain(times=np.array([2.5]) * pq.s, t_start=0 * pq.s, - t_stop=3 * pq.s) + st = neo.SpikeTrain( + times=np.array([2.5]) * pq.s, t_start=0 * pq.s, t_stop=3 * pq.s + ) with self.assertWarns(UserWarning): - bst = cv.BinnedSpikeTrain(st, bin_size=2 * pq.s, t_start=0 * pq.s, - t_stop=3 * pq.s) - assert_array_equal(bst.bin_edges, [0., 2.] 
* pq.s) + bst = cv.BinnedSpikeTrain( + st, bin_size=2 * pq.s, t_start=0 * pq.s, t_stop=3 * pq.s + ) + assert_array_equal(bst.bin_edges, [0.0, 2.0] * pq.s) assert_array_equal(bst.spike_indices, [[]]) # no binned spikes self.assertEqual(bst.get_num_of_spikes(), 0) @@ -193,38 +208,46 @@ def test_regression_431(self): correctly handled by the constructor """ st1 = neo.SpikeTrain( - times=np.array([1, 2, 3]) * pq.ms, - t_start=0 * pq.ms, t_stop=10 * pq.ms) + times=np.array([1, 2, 3]) * pq.ms, t_start=0 * pq.ms, t_stop=10 * pq.ms + ) st2 = neo.SpikeTrain( - times=np.array([4, 5, 6]) * pq.ms, - t_start=0 * pq.ms, t_stop=10 * pq.ms) + times=np.array([4, 5, 6]) * pq.ms, t_start=0 * pq.ms, t_stop=10 * pq.ms + ) real_list = [st1, st2] spiketrainlist = SpikeTrainList([st1, st2]) - real_list_binary = cv.BinnedSpikeTrain(real_list, bin_size=1*pq.ms) - spiketrainlist_binary = cv.BinnedSpikeTrain( - spiketrainlist, bin_size=1 * pq.ms) + real_list_binary = cv.BinnedSpikeTrain(real_list, bin_size=1 * pq.ms) + spiketrainlist_binary = cv.BinnedSpikeTrain(spiketrainlist, bin_size=1 * pq.ms) assert_array_equal( - real_list_binary.to_array(), spiketrainlist_binary.to_array()) + real_list_binary.to_array(), spiketrainlist_binary.to_array() + ) class BinnedSpikeTrainTestCase(unittest.TestCase): def setUp(self): self.spiketrain_a = neo.SpikeTrain( - [0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s) + [0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s + ) self.spiketrain_b = neo.SpikeTrain( - [0.1, 0.7, 1.2, 2.2, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s) + [0.1, 0.7, 1.2, 2.2, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s + ) self.bin_size = 1 * pq.s self.tolerance = 1e-8 def test_binarize(self): - spiketrains = [self.spiketrain_a, self.spiketrain_b, - self.spiketrain_a, self.spiketrain_b] + spiketrains = [ + self.spiketrain_a, + self.spiketrain_b, + self.spiketrain_a, + self.spiketrain_b, + ] for sparse_format in ("csr", "csc"): - bst = cv.BinnedSpikeTrain(spiketrains=spiketrains, - bin_size=self.bin_size, - sparse_format=sparse_format) + bst = cv.BinnedSpikeTrain( + spiketrains=spiketrains, + bin_size=self.bin_size, + sparse_format=sparse_format, + ) bst_bin = bst.binarize(copy=True) bst_copy = bst.copy() assert_array_equal(bst_bin.to_array(), bst.to_bool_array()) @@ -232,28 +255,39 @@ def test_binarize(self): self.assertEqual(bst_bin, bst_copy) def test_slice(self): - spiketrains = [self.spiketrain_a, self.spiketrain_b, - self.spiketrain_a, self.spiketrain_b] - bst = cv.BinnedSpikeTrain(spiketrains=spiketrains, - bin_size=self.bin_size) + spiketrains = [ + self.spiketrain_a, + self.spiketrain_b, + self.spiketrain_a, + self.spiketrain_b, + ] + bst = cv.BinnedSpikeTrain(spiketrains=spiketrains, bin_size=self.bin_size) self.assertEqual(bst[:, :], bst) - self.assertEqual(bst[1:], cv.BinnedSpikeTrain(spiketrains[1:], - bin_size=self.bin_size)) + self.assertEqual( + bst[1:], cv.BinnedSpikeTrain(spiketrains[1:], bin_size=self.bin_size) + ) self.assertEqual(bst[:, :4], bst.time_slice(t_stop=4 * pq.s)) - self.assertEqual(bst[:, 1:-1], cv.BinnedSpikeTrain( - spiketrains, bin_size=self.bin_size, - t_start=1 * pq.s, t_stop=9 * pq.s - )) - self.assertEqual(bst[0, 0], cv.BinnedSpikeTrain( - neo.SpikeTrain([0.5, 0.7], t_stop=1, units='s'), - bin_size=self.bin_size - )) + self.assertEqual( + bst[:, 1:-1], + cv.BinnedSpikeTrain( + spiketrains, bin_size=self.bin_size, t_start=1 * pq.s, t_stop=9 * pq.s + ), + ) + self.assertEqual( + bst[0, 0], + cv.BinnedSpikeTrain( + neo.SpikeTrain([0.5, 0.7], t_stop=1, 
units="s"), bin_size=self.bin_size + ), + ) # 2-seconds stride: leave [0..1, 2..3, 4..5, 6..7] interval - self.assertEqual(bst[0, ::2], cv.BinnedSpikeTrain( - neo.SpikeTrain([0.5, 0.7, 4.3, 6.7], t_stop=10, units='s'), - bin_size=2 * self.bin_size - )) + self.assertEqual( + bst[0, ::2], + cv.BinnedSpikeTrain( + neo.SpikeTrain([0.5, 0.7, 4.3, 6.7], t_stop=10, units="s"), + bin_size=2 * self.bin_size, + ), + ) bst_copy = bst.copy() bst_copy[:] = 1 @@ -261,10 +295,10 @@ def test_slice(self): def test_time_slice(self): spiketrains = [self.spiketrain_a, self.spiketrain_b] - bst = cv.BinnedSpikeTrain(spiketrains=spiketrains, - bin_size=self.bin_size) - bst_equal = bst.time_slice(t_start=bst.t_start - 5 * pq.s, - t_stop=bst.t_stop + 5 * pq.s) + bst = cv.BinnedSpikeTrain(spiketrains=spiketrains, bin_size=self.bin_size) + bst_equal = bst.time_slice( + t_start=bst.t_start - 5 * pq.s, t_stop=bst.t_stop + 5 * pq.s + ) self.assertEqual(bst_equal, bst) bst_same = bst.time_slice(t_start=None, t_stop=None) self.assertIs(bst_same, bst) @@ -275,13 +309,14 @@ def test_time_slice(self): self.assertEqual(bst_empty.n_bins, 0) t_range = np.arange(0, 10, self.bin_size.item()) * pq.s for i, t_start in enumerate(t_range[:-1]): - for t_stop in t_range[i + 1:]: + for t_stop in t_range[i + 1 :]: bst_ij = bst.time_slice(t_start=t_start, t_stop=t_stop) bst_ij2 = bst_ij.time_slice(t_start=t_start, t_stop=t_stop) self.assertEqual(bst_ij2, bst_ij) self.assertEqual(bst_ij2.tolerance, bst.tolerance) - sts = [st.time_slice(t_start=t_start, t_stop=t_stop) - for st in spiketrains] + sts = [ + st.time_slice(t_start=t_start, t_stop=t_stop) for st in spiketrains + ] bst_ref = cv.BinnedSpikeTrain(sts, bin_size=self.bin_size) self.assertEqual(bst_ij, bst_ref) @@ -290,75 +325,80 @@ def test_time_slice(self): def test_to_spike_trains(self): np.random.seed(1) - spiketrains = [homogeneous_poisson_process(rate=10 * pq.Hz, - t_start=-1 * pq.s, - t_stop=10 * pq.s)] + spiketrains = [ + homogeneous_poisson_process( + rate=10 * pq.Hz, t_start=-1 * pq.s, t_stop=10 * pq.s + ) + ] for sparse_format in ("csr", "csc"): bst1 = cv.BinnedSpikeTrain( spiketrains=[self.spiketrain_a, self.spiketrain_b], - bin_size=self.bin_size, sparse_format=sparse_format + bin_size=self.bin_size, + sparse_format=sparse_format, + ) + bst2 = cv.BinnedSpikeTrain( + spiketrains=spiketrains, + bin_size=300 * pq.ms, + sparse_format=sparse_format, ) - bst2 = cv.BinnedSpikeTrain(spiketrains=spiketrains, - bin_size=300 * pq.ms, - sparse_format=sparse_format) for bst in (bst1, bst2): for spikes in ("random", "left", "center"): - spiketrains_gen = bst.to_spike_trains(spikes=spikes, - annotate_bins=True) + spiketrains_gen = bst.to_spike_trains( + spikes=spikes, annotate_bins=True + ) for st, indices in zip(spiketrains_gen, bst.spike_indices): # check sorted self.assertTrue((np.diff(st.magnitude) > 0).all()) - assert_array_equal(st.array_annotations['bins'], - indices) - self.assertEqual(st.annotations['bin_size'], - bst.bin_size) + assert_array_equal(st.array_annotations["bins"], indices) + self.assertEqual(st.annotations["bin_size"], bst.bin_size) self.assertEqual(st.t_start, bst.t_start) self.assertEqual(st.t_stop, bst.t_stop) - bst_same = cv.BinnedSpikeTrain(spiketrains_gen, - bin_size=bst.bin_size, - sparse_format=sparse_format) + bst_same = cv.BinnedSpikeTrain( + spiketrains_gen, + bin_size=bst.bin_size, + sparse_format=sparse_format, + ) self.assertEqual(bst_same, bst) # invalid mode - self.assertRaises(ValueError, bst.to_spike_trains, - spikes='right') + 
self.assertRaises(ValueError, bst.to_spike_trains, spikes="right") def test_get_num_of_spikes(self): spiketrains = [self.spiketrain_a, self.spiketrain_b] for spiketrain in spiketrains: - binned = cv.BinnedSpikeTrain(spiketrain, n_bins=10, - bin_size=1 * pq.s, t_start=0 * pq.s) - self.assertEqual(binned.get_num_of_spikes(), - len(binned.spike_indices[0])) + binned = cv.BinnedSpikeTrain( + spiketrain, n_bins=10, bin_size=1 * pq.s, t_start=0 * pq.s + ) + self.assertEqual(binned.get_num_of_spikes(), len(binned.spike_indices[0])) for sparse_format in ("csr", "csc"): - binned_matrix = cv.BinnedSpikeTrain(spiketrains, n_bins=10, - bin_size=1 * pq.s, - sparse_format=sparse_format) + binned_matrix = cv.BinnedSpikeTrain( + spiketrains, n_bins=10, bin_size=1 * pq.s, sparse_format=sparse_format + ) n_spikes_per_row = binned_matrix.get_num_of_spikes(axis=1) - n_spikes_per_row_from_indices = list( - map(len, binned_matrix.spike_indices)) + n_spikes_per_row_from_indices = list(map(len, binned_matrix.spike_indices)) assert_array_equal(n_spikes_per_row, n_spikes_per_row_from_indices) - self.assertEqual(binned_matrix.get_num_of_spikes(), - sum(n_spikes_per_row_from_indices)) + self.assertEqual( + binned_matrix.get_num_of_spikes(), sum(n_spikes_per_row_from_indices) + ) def test_binned_spiketrain_sparse(self): a = neo.SpikeTrain([1.7, 1.8, 4.3] * pq.s, t_stop=10.0 * pq.s) b = neo.SpikeTrain([1.7, 1.8, 4.3] * pq.s, t_stop=10.0 * pq.s) bin_size = 1 * pq.s nbins = 10 - x = cv.BinnedSpikeTrain([a, b], n_bins=nbins, bin_size=bin_size, - t_start=0 * pq.s) + x = cv.BinnedSpikeTrain( + [a, b], n_bins=nbins, bin_size=bin_size, t_start=0 * pq.s + ) x_sparse = [2, 1, 2, 1] assert_array_equal(x.sparse_matrix.data, x_sparse) assert_array_equal(x.spike_indices, [[1, 1, 4], [1, 1, 4]]) def test_binned_spiketrain_shape(self): a = self.spiketrain_a - x = cv.BinnedSpikeTrain(a, n_bins=10, - bin_size=self.bin_size, - t_start=0 * pq.s) - x_bool = cv.BinnedSpikeTrain(a, n_bins=10, bin_size=self.bin_size, - t_start=0 * pq.s) + x = cv.BinnedSpikeTrain(a, n_bins=10, bin_size=self.bin_size, t_start=0 * pq.s) + x_bool = cv.BinnedSpikeTrain( + a, n_bins=10, bin_size=self.bin_size, t_start=0 * pq.s + ) self.assertEqual(x.to_array().shape, (1, 10)) self.assertEqual(x_bool.to_bool_array().shape, (1, 10)) @@ -368,44 +408,52 @@ def test_binned_spiketrain_shape_list(self): b = self.spiketrain_b c = [a, b] nbins = 5 - x = cv.BinnedSpikeTrain(c, n_bins=nbins, t_start=0 * pq.s, - t_stop=10.0 * pq.s) - x_bool = cv.BinnedSpikeTrain(c, n_bins=nbins, t_start=0 * pq.s, - t_stop=10.0 * pq.s) + x = cv.BinnedSpikeTrain(c, n_bins=nbins, t_start=0 * pq.s, t_stop=10.0 * pq.s) + x_bool = cv.BinnedSpikeTrain( + c, n_bins=nbins, t_start=0 * pq.s, t_stop=10.0 * pq.s + ) self.assertEqual(x.to_array().shape, (2, 5)) self.assertEqual(x_bool.to_bool_array().shape, (2, 5)) def test_binned_spiketrain_neg_times(self): a = neo.SpikeTrain( [-6.5, 0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] * pq.s, - t_start=-6.5 * pq.s, t_stop=10.0 * pq.s) + t_start=-6.5 * pq.s, + t_stop=10.0 * pq.s, + ) bin_size = self.bin_size nbins = 16 - x = cv.BinnedSpikeTrain(a, n_bins=nbins, bin_size=bin_size, - t_start=-6.5 * pq.s) + x = cv.BinnedSpikeTrain(a, n_bins=nbins, bin_size=bin_size, t_start=-6.5 * pq.s) y = [[1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0]] assert_array_equal(x.to_bool_array(), y) def test_binned_spiketrain_neg_times_list(self): a = neo.SpikeTrain( [-6.5, 0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] * pq.s, - t_start=-7 * pq.s, t_stop=7 * pq.s) + t_start=-7 * pq.s, + t_stop=7 * pq.s, 
+ ) b = neo.SpikeTrain( [-0.1, -0.7, 1.2, 2.2, 4.3, 5.5, 8.0] * pq.s, - t_start=-1 * pq.s, t_stop=8 * pq.s) + t_start=-1 * pq.s, + t_stop=8 * pq.s, + ) spiketrains = [a, b] # not the same t_start and t_stop - self.assertRaises(ValueError, cv.BinnedSpikeTrain, - spiketrains=spiketrains, - bin_size=self.bin_size) + self.assertRaises( + ValueError, + cv.BinnedSpikeTrain, + spiketrains=spiketrains, + bin_size=self.bin_size, + ) t_start, t_stop = get_common_start_stop_times(spiketrains) self.assertEqual(t_start, -1 * pq.s) self.assertEqual(t_stop, 7 * pq.s) - x_bool = cv.BinnedSpikeTrain(spiketrains, bin_size=self.bin_size, - t_start=t_start, t_stop=t_stop) - y_bool = [[0, 1, 1, 0, 1, 1, 1, 1], - [1, 0, 1, 1, 0, 1, 1, 0]] + x_bool = cv.BinnedSpikeTrain( + spiketrains, bin_size=self.bin_size, t_start=t_start, t_stop=t_stop + ) + y_bool = [[0, 1, 1, 0, 1, 1, 1, 1], [1, 0, 1, 1, 0, 1, 1, 0]] assert_array_equal(x_bool.to_bool_array(), y_bool) @@ -414,16 +462,15 @@ def test_binned_spiketrain_indices(self): a = self.spiketrain_a bin_size = self.bin_size nbins = 10 - x = cv.BinnedSpikeTrain(a, n_bins=nbins, bin_size=bin_size, - t_start=0 * pq.s) - x_bool = cv.BinnedSpikeTrain(a, n_bins=nbins, bin_size=bin_size, - t_start=0 * pq.s) - y_matrix = [[2., 1., 0., 1., 1., 1., 1., 0., 0., 0.]] - y_bool_matrix = [[1., 1., 0., 1., 1., 1., 1., 0., 0., 0.]] + x = cv.BinnedSpikeTrain(a, n_bins=nbins, bin_size=bin_size, t_start=0 * pq.s) + x_bool = cv.BinnedSpikeTrain( + a, n_bins=nbins, bin_size=bin_size, t_start=0 * pq.s + ) + y_matrix = [[2.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]] + y_bool_matrix = [[1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]] assert_array_equal(x.to_array(), y_matrix) assert_array_equal(x_bool.to_bool_array(), y_bool_matrix) - s = x_bool.to_sparse_bool_array()[ - x_bool.to_sparse_bool_array().nonzero()] + s = x_bool.to_sparse_bool_array()[x_bool.to_sparse_bool_array().nonzero()] assert_array_equal(s, [[True] * 6]) def test_binned_spiketrain_list(self): @@ -433,14 +480,12 @@ def test_binned_spiketrain_list(self): bin_size = self.bin_size nbins = 10 c = [a, b] - x = cv.BinnedSpikeTrain(c, n_bins=nbins, bin_size=bin_size, - t_start=0 * pq.s) - x_bool = cv.BinnedSpikeTrain(c, n_bins=nbins, bin_size=bin_size, - t_start=0 * pq.s) - y_matrix = [[2, 1, 0, 1, 1, 1, 1, 0, 0, 0], - [2, 1, 1, 0, 1, 1, 0, 0, 1, 0]] - y_matrix_bool = [[1, 1, 0, 1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 1, 1, 0, 0, 1, 0]] + x = cv.BinnedSpikeTrain(c, n_bins=nbins, bin_size=bin_size, t_start=0 * pq.s) + x_bool = cv.BinnedSpikeTrain( + c, n_bins=nbins, bin_size=bin_size, t_start=0 * pq.s + ) + y_matrix = [[2, 1, 0, 1, 1, 1, 1, 0, 0, 0], [2, 1, 1, 0, 1, 1, 0, 0, 1, 0]] + y_matrix_bool = [[1, 1, 0, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 1, 1, 0, 0, 1, 0]] assert_array_equal(x.to_array(), y_matrix) assert_array_equal(x_bool.to_bool_array(), y_matrix_bool) @@ -451,11 +496,12 @@ def test_binned_spiketrain_list_t_stop(self): c = [a, b] bin_size = self.bin_size nbins = 10 - x = cv.BinnedSpikeTrain(c, n_bins=nbins, bin_size=bin_size, - t_start=0 * pq.s, - t_stop=None) - x_bool = cv.BinnedSpikeTrain(c, n_bins=nbins, bin_size=bin_size, - t_start=0 * pq.s) + x = cv.BinnedSpikeTrain( + c, n_bins=nbins, bin_size=bin_size, t_start=0 * pq.s, t_stop=None + ) + x_bool = cv.BinnedSpikeTrain( + c, n_bins=nbins, bin_size=bin_size, t_start=0 * pq.s + ) self.assertEqual(x.t_stop, 10 * pq.s) self.assertEqual(x_bool.t_stop, 10 * pq.s) @@ -465,10 +511,12 @@ def test_binned_spiketrain_list_numbins(self): b = self.spiketrain_b c = [a, b] bin_size = 
1 * pq.s - x = cv.BinnedSpikeTrain(c, bin_size=bin_size, t_start=0 * pq.s, - t_stop=10. * pq.s) - x_bool = cv.BinnedSpikeTrain(c, bin_size=bin_size, t_start=0 * pq.s, - t_stop=10. * pq.s) + x = cv.BinnedSpikeTrain( + c, bin_size=bin_size, t_start=0 * pq.s, t_stop=10.0 * pq.s + ) + x_bool = cv.BinnedSpikeTrain( + c, bin_size=bin_size, t_start=0 * pq.s, t_stop=10.0 * pq.s + ) self.assertEqual(x.n_bins, 10) self.assertEqual(x_bool.n_bins, 10) @@ -476,10 +524,12 @@ def test_binned_spiketrain_matrix(self): # Init a = self.spiketrain_a b = self.spiketrain_b - x_bool_a = cv.BinnedSpikeTrain(a, bin_size=pq.s, t_start=0 * pq.s, - t_stop=10. * pq.s) - x_bool_b = cv.BinnedSpikeTrain(b, bin_size=pq.s, t_start=0 * pq.s, - t_stop=10. * pq.s) + x_bool_a = cv.BinnedSpikeTrain( + a, bin_size=pq.s, t_start=0 * pq.s, t_stop=10.0 * pq.s + ) + x_bool_b = cv.BinnedSpikeTrain( + b, bin_size=pq.s, t_start=0 * pq.s, t_stop=10.0 * pq.s + ) # Assumed results y_matrix_a = [[2, 1, 0, 1, 1, 1, 1, 0, 0, 0]] @@ -493,23 +543,31 @@ def test_binned_spiketrain_matrix(self): # Test if t_start is calculated correctly def test_binned_spiketrain_parameter_calc_tstart(self): - x = cv.BinnedSpikeTrain(self.spiketrain_a, bin_size=1 * pq.s, - n_bins=10, t_stop=10. * pq.s) - self.assertEqual(x.t_start, 0. * pq.s) - self.assertEqual(x.t_stop, 10. * pq.s) + x = cv.BinnedSpikeTrain( + self.spiketrain_a, bin_size=1 * pq.s, n_bins=10, t_stop=10.0 * pq.s + ) + self.assertEqual(x.t_start, 0.0 * pq.s) + self.assertEqual(x.t_stop, 10.0 * pq.s) self.assertEqual(x.bin_size, 1 * pq.s) self.assertEqual(x.n_bins, 10) # Test if error raises when type of n_bins is not an integer def test_binned_spiketrain_n_bins_not_int(self): a = self.spiketrain_a - self.assertRaises(ValueError, cv.BinnedSpikeTrain, a, bin_size=pq.s, - n_bins=1.4, t_start=0 * pq.s, - t_stop=10. * pq.s) + self.assertRaises( + ValueError, + cv.BinnedSpikeTrain, + a, + bin_size=pq.s, + n_bins=1.4, + t_start=0 * pq.s, + t_stop=10.0 * pq.s, + ) def test_to_array(self): - x = cv.BinnedSpikeTrain(self.spiketrain_a, bin_size=1 * pq.s, - n_bins=10, t_stop=10. 
* pq.s) + x = cv.BinnedSpikeTrain( + self.spiketrain_a, bin_size=1 * pq.s, n_bins=10, t_stop=10.0 * pq.s + ) arr_float = x.to_array(dtype=float) assert_array_equal(arr_float, x.to_array().astype(float)) @@ -524,19 +582,23 @@ def test_binned_spiketrain_insufficient_arguments(self): a, bin_size=1 * pq.s, t_start=0 * pq.s, - t_stop=0 * pq.s) + t_stop=0 * pq.s, + ) def test_different_input_types(self): a = self.spiketrain_a q = [1, 2, 3] * pq.s - self.assertRaises(ValueError, cv.BinnedSpikeTrain, - spiketrains=[a, q], bin_size=pq.s) + self.assertRaises( + ValueError, cv.BinnedSpikeTrain, spiketrains=[a, q], bin_size=pq.s + ) def test_get_start_stop(self): a = self.spiketrain_a b = neo.SpikeTrain( [-0.1, -0.7, 1.2, 2.2, 4.3, 5.5, 8.0] * pq.s, - t_start=-1 * pq.s, t_stop=8 * pq.s) + t_start=-1 * pq.s, + t_stop=8 * pq.s, + ) start, stop = get_common_start_stop_times(a) self.assertEqual(start, a.t_start) self.assertEqual(stop, a.t_stop) @@ -546,33 +608,51 @@ def test_get_start_stop(self): def test_consistency_errors(self): a = self.spiketrain_a - b = neo.SpikeTrain([-2, -1] * pq.s, t_start=-2 * pq.s, - t_stop=-1 * pq.s) - self.assertRaises(TypeError, cv.BinnedSpikeTrain, [a, b], t_start=5, - t_stop=0, bin_size=pq.s, n_bins=10) - - b = neo.SpikeTrain([-7, -8, -9] * pq.s, t_start=-9 * pq.s, - t_stop=-7 * pq.s) - self.assertRaises(TypeError, cv.BinnedSpikeTrain, b, t_start=None, - t_stop=10, bin_size=pq.s, n_bins=10) - self.assertRaises(ValueError, cv.BinnedSpikeTrain, a, t_start=0 * pq.s, - t_stop=10 * pq.s, bin_size=3 * pq.s, n_bins=10) - - b = neo.SpikeTrain([-4, -2, 0, 1] * pq.s, t_start=-4 * pq.s, - t_stop=1 * pq.s) + b = neo.SpikeTrain([-2, -1] * pq.s, t_start=-2 * pq.s, t_stop=-1 * pq.s) + self.assertRaises( + TypeError, + cv.BinnedSpikeTrain, + [a, b], + t_start=5, + t_stop=0, + bin_size=pq.s, + n_bins=10, + ) + + b = neo.SpikeTrain([-7, -8, -9] * pq.s, t_start=-9 * pq.s, t_stop=-7 * pq.s) + self.assertRaises( + TypeError, + cv.BinnedSpikeTrain, + b, + t_start=None, + t_stop=10, + bin_size=pq.s, + n_bins=10, + ) + self.assertRaises( + ValueError, + cv.BinnedSpikeTrain, + a, + t_start=0 * pq.s, + t_stop=10 * pq.s, + bin_size=3 * pq.s, + n_bins=10, + ) + + b = neo.SpikeTrain([-4, -2, 0, 1] * pq.s, t_start=-4 * pq.s, t_stop=1 * pq.s) self.assertRaises( TypeError, cv.BinnedSpikeTrain, b, bin_size=-2 * pq.s, t_start=-4 * pq.s, - t_stop=0 * pq.s) + t_stop=0 * pq.s, + ) # Test edges def test_binned_spiketrain_bin_edges(self): a = self.spiketrain_a - x = cv.BinnedSpikeTrain(a, bin_size=1 * pq.s, n_bins=10, - t_stop=10. 
* pq.s) + x = cv.BinnedSpikeTrain(a, bin_size=1 * pq.s, n_bins=10, t_stop=10.0 * pq.s) # Test all edges assert_array_equal(x.bin_edges, [float(i) for i in range(11)]) @@ -588,8 +668,7 @@ def test_binned_spiketrain_different_units(self): xb = cv.BinnedSpikeTrain(b, bin_size=bin_size.rescale(pq.ms)) assert_array_equal(xa.to_array(), xb.to_array()) assert_array_equal(xa.to_bool_array(), xb.to_bool_array()) - assert_array_equal(xa.sparse_matrix.data, - xb.sparse_matrix.data) + assert_array_equal(xa.sparse_matrix.data, xb.sparse_matrix.data) assert_array_equal(xa.bin_edges, xb.bin_edges) def test_binary_to_binned_matrix(self): @@ -610,14 +689,32 @@ def test_binary_to_binned_matrix(self): self.assertEqual(x.t_start, 1 * pq.s) # Raise error - self.assertRaises(ValueError, cv.BinnedSpikeTrain, a, - t_start=5 * pq.s, t_stop=0 * pq.s, bin_size=pq.s, - n_bins=10) - self.assertRaises(ValueError, cv.BinnedSpikeTrain, a, t_start=0 * pq.s, - t_stop=10 * pq.s, bin_size=3 * pq.s, n_bins=10) - self.assertRaises(ValueError, cv.BinnedSpikeTrain, a, - bin_size=-2 * pq.s, t_start=-4 * pq.s, - t_stop=0 * pq.s) + self.assertRaises( + ValueError, + cv.BinnedSpikeTrain, + a, + t_start=5 * pq.s, + t_stop=0 * pq.s, + bin_size=pq.s, + n_bins=10, + ) + self.assertRaises( + ValueError, + cv.BinnedSpikeTrain, + a, + t_start=0 * pq.s, + t_stop=10 * pq.s, + bin_size=3 * pq.s, + n_bins=10, + ) + self.assertRaises( + ValueError, + cv.BinnedSpikeTrain, + a, + bin_size=-2 * pq.s, + t_start=-4 * pq.s, + t_stop=0 * pq.s, + ) # Check binary property self.assertTrue(x.is_binary) @@ -629,15 +726,15 @@ def test_binned_to_binned(self): assert_array_equal(y.to_array(), x) # test with a list - x = cv.BinnedSpikeTrain([[0, 1, 2, 3]], bin_size=1 * pq.s, - t_stop=3 * pq.s).to_array() + x = cv.BinnedSpikeTrain( + [[0, 1, 2, 3]], bin_size=1 * pq.s, t_stop=3 * pq.s + ).to_array() y = cv.BinnedSpikeTrain(x, bin_size=1 * pq.s, t_start=0 * pq.s) assert_array_equal(y.to_array(), x) # test with a numpy array a = np.array([[0, 1, 2, 3], [1, 2, 2.5, 3]]) - x = cv.BinnedSpikeTrain(a, bin_size=1 * pq.s, - t_stop=3 * pq.s).to_array() + x = cv.BinnedSpikeTrain(a, bin_size=1 * pq.s, t_stop=3 * pq.s).to_array() y = cv.BinnedSpikeTrain(x, bin_size=1 * pq.s, t_start=0 * pq.s) assert_array_equal(y.to_array(), x) @@ -647,66 +744,96 @@ def test_binned_to_binned(self): # Raise Errors # give a strangely shaped matrix as input (not MxN) a = np.array([[0, 1, 2, 3], [1, 2, 3]], dtype=object) - self.assertRaises(ValueError, cv.BinnedSpikeTrain, a, t_start=0 * pq.s, - bin_size=1 * pq.s) + self.assertRaises( + ValueError, cv.BinnedSpikeTrain, a, t_start=0 * pq.s, bin_size=1 * pq.s + ) # Give no t_start or t_stop a = np.array([[0, 1, 2, 3], [1, 2, 3, 4]]) - self.assertRaises(ValueError, cv.BinnedSpikeTrain, a, - bin_size=1 * pq.s) + self.assertRaises(ValueError, cv.BinnedSpikeTrain, a, bin_size=1 * pq.s) # Input format not supported a = np.array(([0, 1, 2], [0, 1, 2, 3, 4]), dtype=object) - self.assertRaises(ValueError, cv.BinnedSpikeTrain, a, - bin_size=1 * pq.s) + self.assertRaises(ValueError, cv.BinnedSpikeTrain, a, bin_size=1 * pq.s) def test_binnend_spiketrain_different_input_units(self): - train = neo.SpikeTrain(times=np.array([1.001, 1.002, 1.005]) * pq.s, - t_start=1 * pq.s, t_stop=1.01 * pq.s) - bst = cv.BinnedSpikeTrain(train, - t_start=1 * pq.s, t_stop=1.01 * pq.s, - bin_size=1 * pq.ms) + train = neo.SpikeTrain( + times=np.array([1.001, 1.002, 1.005]) * pq.s, + t_start=1 * pq.s, + t_stop=1.01 * pq.s, + ) + bst = cv.BinnedSpikeTrain( + train, t_start=1 * 
pq.s, t_stop=1.01 * pq.s, bin_size=1 * pq.ms + ) self.assertEqual(bst.units, pq.s) - target_edges = np.array([1000, 1001, 1002, 1003, 1004, 1005, 1006, - 1007, 1008, 1009, 1010], dtype=float - ) * pq.ms - target_centers = np.array( - [1000.5, 1001.5, 1002.5, 1003.5, 1004.5, 1005.5, 1006.5, 1007.5, - 1008.5, 1009.5], dtype=float) * pq.ms + target_edges = ( + np.array( + [1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010], + dtype=float, + ) + * pq.ms + ) + target_centers = ( + np.array( + [ + 1000.5, + 1001.5, + 1002.5, + 1003.5, + 1004.5, + 1005.5, + 1006.5, + 1007.5, + 1008.5, + 1009.5, + ], + dtype=float, + ) + * pq.ms + ) assert_array_almost_equal(bst.bin_edges, target_edges) assert_array_almost_equal(bst.bin_centers, target_centers) - bst = cv.BinnedSpikeTrain(train, - t_start=1 * pq.s, t_stop=1010 * pq.ms, - bin_size=1 * pq.ms) + bst = cv.BinnedSpikeTrain( + train, t_start=1 * pq.s, t_stop=1010 * pq.ms, bin_size=1 * pq.ms + ) self.assertEqual(bst.units, pq.s) assert_array_almost_equal(bst.bin_edges, target_edges) assert_array_almost_equal(bst.bin_centers, target_centers) def test_rescale(self): - train = neo.SpikeTrain(times=np.array([1.001, 1.002, 1.005]) * pq.s, - t_start=1 * pq.s, t_stop=1.01 * pq.s) - bst = cv.BinnedSpikeTrain(train, t_start=1 * pq.s, - t_stop=1.01 * pq.s, - bin_size=1 * pq.ms) + train = neo.SpikeTrain( + times=np.array([1.001, 1.002, 1.005]) * pq.s, + t_start=1 * pq.s, + t_stop=1.01 * pq.s, + ) + bst = cv.BinnedSpikeTrain( + train, t_start=1 * pq.s, t_stop=1.01 * pq.s, bin_size=1 * pq.ms + ) self.assertEqual(bst.units, pq.s) self.assertEqual(bst._t_start, 1) # 1 s self.assertEqual(bst._t_stop, 1.01) # 1.01 s self.assertEqual(bst._bin_size, 0.001) # 0.001 s - bst.rescale(units='ms') + bst.rescale(units="ms") self.assertEqual(bst.units, pq.ms) self.assertEqual(bst._t_start, 1000) # 1 s self.assertEqual(bst._t_stop, 1010) # 1.01 s self.assertEqual(bst._bin_size, 1) # 0.001 s def test_repr(self): - train = neo.SpikeTrain(times=np.array([1.001, 1.002, 1.005]) * pq.s, - t_start=1 * pq.s, t_stop=1.01 * pq.s) - bst = cv.BinnedSpikeTrain(train, t_start=1 * pq.s, - t_stop=1.01 * pq.s, - bin_size=1 * pq.ms) - self.assertEqual(repr(bst), "BinnedSpikeTrain(t_start=1.0 s, " - "t_stop=1.01 s, bin_size=0.001 s; " - "shape=(1, 10), format=csr_matrix)") + train = neo.SpikeTrain( + times=np.array([1.001, 1.002, 1.005]) * pq.s, + t_start=1 * pq.s, + t_stop=1.01 * pq.s, + ) + bst = cv.BinnedSpikeTrain( + train, t_start=1 * pq.s, t_stop=1.01 * pq.s, bin_size=1 * pq.ms + ) + self.assertEqual( + repr(bst), + "BinnedSpikeTrain(t_start=1.0 s, " + "t_stop=1.01 s, bin_size=0.001 s; " + "shape=(1, 10), format=csr_matrix)", + ) def test_binned_sparsity(self): train = neo.SpikeTrain(np.arange(10), t_stop=10 * pq.s, units=pq.s) @@ -715,52 +842,51 @@ def test_binned_sparsity(self): # Test fix for rounding errors def test_binned_spiketrain_rounding(self): - train = neo.SpikeTrain(times=np.arange(120000) / 30000. * pq.s, - t_start=0 * pq.s, t_stop=4 * pq.s) - bst = cv.BinnedSpikeTrain(train, - t_start=0 * pq.s, t_stop=4 * pq.s, - bin_size=1. / 30000. 
* pq.s) - assert_array_equal(bst.to_array().nonzero()[1], - np.arange(120000)) + train = neo.SpikeTrain( + times=np.arange(120000) / 30000.0 * pq.s, t_start=0 * pq.s, t_stop=4 * pq.s + ) + bst = cv.BinnedSpikeTrain( + train, t_start=0 * pq.s, t_stop=4 * pq.s, bin_size=1.0 / 30000.0 * pq.s + ) + assert_array_equal(bst.to_array().nonzero()[1], np.arange(120000)) class DiscretiseSpiketrainsTestCase(unittest.TestCase): def setUp(self): times = (np.arange(10) + np.random.uniform(size=10)) * pq.ms - self.spiketrains = [neo.SpikeTrain(times, t_stop=10*pq.ms)] * 5 + self.spiketrains = [neo.SpikeTrain(times, t_stop=10 * pq.ms)] * 5 def test_list_of_spiketrains(self): - discretised_spiketrains = cv.discretise_spiketimes(self.spiketrains, - 1 / pq.ms) + discretised_spiketrains = cv.discretise_spiketimes(self.spiketrains, 1 / pq.ms) for idx in range(len(self.spiketrains)): - np.testing.assert_array_equal(discretised_spiketrains[idx].times, - np.arange(10) * pq.ms) + np.testing.assert_array_equal( + discretised_spiketrains[idx].times, np.arange(10) * pq.ms + ) def test_single_spiketrain(self): - discretised_spiketrain = cv.discretise_spiketimes(self.spiketrains[0], - 1 / pq.ms) - np.testing.assert_array_equal(discretised_spiketrain.times, - np.arange(10) * pq.ms) + discretised_spiketrain = cv.discretise_spiketimes( + self.spiketrains[0], 1 / pq.ms + ) + np.testing.assert_array_equal( + discretised_spiketrain.times, np.arange(10) * pq.ms + ) def test_preserve_t_start(self): - spiketrain = neo.SpikeTrain([0.7, 5.1]*pq.ms, - t_start=0.5*pq.ms, t_stop=10*pq.ms) + spiketrain = neo.SpikeTrain( + [0.7, 5.1] * pq.ms, t_start=0.5 * pq.ms, t_stop=10 * pq.ms + ) with self.assertWarns(UserWarning): - discretised_spiketrain = cv.discretise_spiketimes(spiketrain, - 1 / pq.ms) - np.testing.assert_array_equal(discretised_spiketrain.times, - [0.5, 5] * pq.ms) + discretised_spiketrain = cv.discretise_spiketimes(spiketrain, 1 / pq.ms) + np.testing.assert_array_equal(discretised_spiketrain.times, [0.5, 5] * pq.ms) def test_binning_consistency(self): - discretised_spiketrains = cv.discretise_spiketimes(self.spiketrains, - 1 / pq.ms) - bsts = cv.BinnedSpikeTrain(self.spiketrains, - bin_size=1 * pq.ms) - bsts_discretised = cv.BinnedSpikeTrain(discretised_spiketrains, - bin_size=1 * pq.ms) - np.testing.assert_array_equal(bsts.to_array(), - bsts_discretised.to_array()) + discretised_spiketrains = cv.discretise_spiketimes(self.spiketrains, 1 / pq.ms) + bsts = cv.BinnedSpikeTrain(self.spiketrains, bin_size=1 * pq.ms) + bsts_discretised = cv.BinnedSpikeTrain( + discretised_spiketrains, bin_size=1 * pq.ms + ) + np.testing.assert_array_equal(bsts.to_array(), bsts_discretised.to_array()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_cubic.py b/elephant/test/test_cubic.py index fd6e44d7d..48eac646b 100644 --- a/elephant/test/test_cubic.py +++ b/elephant/test/test_cubic.py @@ -32,17 +32,17 @@ def setUp(self): n0 = 100000 - n2 self.xi = 10 self.data_signal = neo.AnalogSignal( - numpy.array([self.xi] * n2 + [0] * n0).reshape(n0 + n2, 1) * - pq.dimensionless, sampling_period=1 * pq.s) + numpy.array([self.xi] * n2 + [0] * n0).reshape(n0 + n2, 1) + * pq.dimensionless, + sampling_period=1 * pq.s, + ) self.data_array = numpy.array([self.xi] * n2 + [0] * n0) self.alpha = 0.05 self.max_iterations = 10 def test_cubic(self): - # Computing the output of CuBIC for the test data AnalogSignal - xi, p_vals, k, test_aborted = cubic.cubic( - self.data_signal, alpha=self.alpha) + xi, p_vals, k, 
test_aborted = cubic.cubic(self.data_signal, alpha=self.alpha) # Check the types of the outputs self.assertIsInstance(xi, int) @@ -71,8 +71,7 @@ def test_cubic(self): self.assertEqual(xi, self.xi) # Computing the output of CuBIC for the test data Array - xi, p_vals, k, test_aborted = cubic.cubic( - self.data_array, alpha=self.alpha) + xi, p_vals, k, test_aborted = cubic.cubic(self.data_array, alpha=self.alpha) # Check the types of the outputs self.assertIsInstance(xi, int) @@ -106,41 +105,52 @@ def test_cubic(self): def test_cubic_max_iterations(self): # Test exceeding max_iterations with self.assertWarns(UserWarning): - xi_max_iterations, p_vals_max_iterations, k_max_iterations, \ - test_aborted = cubic.cubic(self.data_signal, alpha=1, - max_iterations=self.max_iterations) + xi_max_iterations, p_vals_max_iterations, k_max_iterations, test_aborted = ( + cubic.cubic( + self.data_signal, alpha=1, max_iterations=self.max_iterations + ) + ) self.assertEqual(test_aborted, True) self.assertEqual(xi_max_iterations - 1, self.max_iterations) def test_cubic_errors(self): - # Check error ouputs for mis-settings of the parameters # Empty signal self.assertRaises( - ValueError, cubic.cubic, neo.AnalogSignal( - [] * pq.dimensionless, sampling_period=10 * pq.ms)) + ValueError, + cubic.cubic, + neo.AnalogSignal([] * pq.dimensionless, sampling_period=10 * pq.ms), + ) dummy_data = numpy.tile([1, 2, 3], reps=3) # Multidimensional array - self.assertRaises(ValueError, cubic.cubic, neo.AnalogSignal( - dummy_data * pq.dimensionless, - sampling_period=10 * pq.ms)) + self.assertRaises( + ValueError, + cubic.cubic, + neo.AnalogSignal(dummy_data * pq.dimensionless, sampling_period=10 * pq.ms), + ) self.assertRaises(ValueError, cubic.cubic, dummy_data.copy()) # Negative alpha self.assertRaises(ValueError, cubic.cubic, self.data_array, alpha=-0.1) # Negative number of max_iterations - self.assertRaises(ValueError, cubic.cubic, self.data_array, - max_iterations=-100) + self.assertRaises(ValueError, cubic.cubic, self.data_array, max_iterations=-100) # Checking case in which the second cumulant of the signal is smaller # than the first cumulant (analytical constrain of the method) - self.assertRaises(ValueError, cubic.cubic, neo.AnalogSignal( - numpy.array([1] * 1000).reshape(1000, 1), units=pq.dimensionless, - sampling_period=10 * pq.ms), alpha=self.alpha) + self.assertRaises( + ValueError, + cubic.cubic, + neo.AnalogSignal( + numpy.array([1] * 1000).reshape(1000, 1), + units=pq.dimensionless, + sampling_period=10 * pq.ms, + ), + alpha=self.alpha, + ) if __name__ == "__main__": diff --git a/elephant/test/test_current_source_density.py b/elephant/test/test_current_source_density.py index b48b2b361..fa04e17b3 100644 --- a/elephant/test/test_current_source_density.py +++ b/elephant/test/test_current_source_density.py @@ -17,12 +17,12 @@ import elephant.current_source_density_src.utility_functions as utils from elephant.current_source_density import generate_lfp -available_1d = ['StandardCSD', 'DeltaiCSD', 'StepiCSD', 'SplineiCSD', 'KCSD1D'] -available_2d = ['KCSD2D', 'MoIKCSD'] -available_3d = ['KCSD3D'] -kernel_methods = ['KCSD1D', 'KCSD2D', 'KCSD3D', 'MoIKCSD'] -icsd_methods = ['DeltaiCSD', 'StepiCSD', 'SplineiCSD'] -py_iCSD_toolbox = ['StandardCSD', 'DeltaiCSD', 'StepiCSD', 'SplineiCSD'] +available_1d = ["StandardCSD", "DeltaiCSD", "StepiCSD", "SplineiCSD", "KCSD1D"] +available_2d = ["KCSD2D", "MoIKCSD"] +available_3d = ["KCSD3D"] +kernel_methods = ["KCSD1D", "KCSD2D", "KCSD3D", "MoIKCSD"] +icsd_methods = ["DeltaiCSD", 
"StepiCSD", "SplineiCSD"] +py_iCSD_toolbox = ["StandardCSD", "DeltaiCSD", "StepiCSD", "SplineiCSD"] class LFP_TestCase(unittest.TestCase): @@ -54,69 +54,78 @@ def setUp(self): self.csd_method = csd.estimate_csd # Input dictionaries for each method - self.params = {'DeltaiCSD': {'sigma_top': 0. * pq.S / pq.m, - 'diam': 500E-6 * pq.m}, - 'StepiCSD': {'sigma_top': 0. * pq.S / pq.m, - 'tol': 1E-12, - 'diam': 500E-6 * pq.m}, - 'SplineiCSD': {'sigma_top': 0. * pq.S / pq.m, - 'num_steps': 201, 'tol': 1E-12, - 'diam': 500E-6 * pq.m}, - 'StandardCSD': {}, 'KCSD1D': {'h': 50., - 'Rs': np.array( - (0.1, 0.25, 0.5))}} + self.params = { + "DeltaiCSD": {"sigma_top": 0.0 * pq.S / pq.m, "diam": 500e-6 * pq.m}, + "StepiCSD": { + "sigma_top": 0.0 * pq.S / pq.m, + "tol": 1e-12, + "diam": 500e-6 * pq.m, + }, + "SplineiCSD": { + "sigma_top": 0.0 * pq.S / pq.m, + "num_steps": 201, + "tol": 1e-12, + "diam": 500e-6 * pq.m, + }, + "StandardCSD": {}, + "KCSD1D": {"h": 50.0, "Rs": np.array((0.1, 0.25, 0.5))}, + } def test_validate_inputs(self): self.assertRaises(TypeError, self.csd_method, lfp=[[1], [2], [3]]) - self.assertRaises(ValueError, self.csd_method, lfp=self.lfp, - coordinates=self.ele_pos * pq.mm) + self.assertRaises( + ValueError, self.csd_method, lfp=self.lfp, coordinates=self.ele_pos * pq.mm + ) # inconsistent number of electrodes - self.assertRaises(ValueError, self.csd_method, lfp=self.lfp, - coordinates=[1, 2, 3, 4] * pq.mm, - method='StandardCSD') + self.assertRaises( + ValueError, + self.csd_method, + lfp=self.lfp, + coordinates=[1, 2, 3, 4] * pq.mm, + method="StandardCSD", + ) # bad method name - self.assertRaises(ValueError, self.csd_method, lfp=self.lfp, - method='InvalidMethodName') - self.assertRaises(ValueError, self.csd_method, lfp=self.lfp, - method='KCSD2D') - self.assertRaises(ValueError, self.csd_method, lfp=self.lfp, - method='KCSD3D') + self.assertRaises( + ValueError, self.csd_method, lfp=self.lfp, method="InvalidMethodName" + ) + self.assertRaises(ValueError, self.csd_method, lfp=self.lfp, method="KCSD2D") + self.assertRaises(ValueError, self.csd_method, lfp=self.lfp, method="KCSD3D") def test_inputs_standardcsd(self): - method = 'StandardCSD' + method = "StandardCSD" result = self.csd_method(self.lfp, method=method) self.assertEqual(result.t_start, 0.0 * pq.s) self.assertEqual(result.sampling_rate, 1000 * pq.Hz) self.assertEqual(result.shape[0], 1) def test_inputs_deltasplineicsd(self): - methods = ['DeltaiCSD', 'SplineiCSD'] + methods = ["DeltaiCSD", "SplineiCSD"] for method in methods: - self.assertRaises(ValueError, self.csd_method, lfp=self.lfp, - method=method) - result = self.csd_method(self.lfp, method=method, - **self.params[method]) + self.assertRaises(ValueError, self.csd_method, lfp=self.lfp, method=method) + result = self.csd_method(self.lfp, method=method, **self.params[method]) self.assertEqual(result.t_start, 0.0 * pq.s) self.assertEqual(result.sampling_rate, 1000 * pq.Hz) self.assertEqual(result.times.shape[0], 1) def test_inputs_stepicsd(self): - method = 'StepiCSD' - self.assertRaises(ValueError, self.csd_method, lfp=self.lfp, - method=method) - self.assertRaises(AssertionError, self.csd_method, lfp=self.lfp, - method=method, **self.params[method]) - self.params['StepiCSD'].update({'h': np.ones(5) * 100E-6 * pq.m}) - result = self.csd_method(self.lfp, method=method, - **self.params[method]) + method = "StepiCSD" + self.assertRaises(ValueError, self.csd_method, lfp=self.lfp, method=method) + self.assertRaises( + AssertionError, + self.csd_method, + lfp=self.lfp, + 
method=method, + **self.params[method], + ) + self.params["StepiCSD"].update({"h": np.ones(5) * 100e-6 * pq.m}) + result = self.csd_method(self.lfp, method=method, **self.params[method]) self.assertEqual(result.t_start, 0.0 * pq.s) self.assertEqual(result.sampling_rate, 1000 * pq.Hz) self.assertEqual(result.times.shape[0], 1) def test_inuts_kcsd(self): - method = 'KCSD1D' - result = self.csd_method(self.lfp, method=method, - **self.params[method]) + method = "KCSD1D" + result = self.csd_method(self.lfp, method=method, **self.params[method]) self.assertEqual(result.t_start, 0.0 * pq.s) self.assertEqual(result.sampling_rate, 1000 * pq.Hz) self.assertEqual(len(result.times), 1) @@ -126,13 +135,13 @@ class CSD2D_TestCase(unittest.TestCase): def setUp(self): xx_ele, yy_ele = utils.generate_electrodes(dim=2) self.lfp = csd.generate_lfp(utils.large_source_2D, xx_ele, yy_ele) - self.params = {'KCSD2D': {'sigma': 1., 'Rs': np.array( - (0.1, 0.25, 0.5))}} # Input dictionaries for each method + self.params = { + "KCSD2D": {"sigma": 1.0, "Rs": np.array((0.1, 0.25, 0.5))} + } # Input dictionaries for each method def test_kcsd2d_init(self): - method = 'KCSD2D' - result = csd.estimate_csd(lfp=self.lfp, method=method, - **self.params[method]) + method = "KCSD2D" + result = csd.estimate_csd(lfp=self.lfp, method=method, **self.params[method]) self.assertEqual(result.t_start, 0.0 * pq.s) self.assertEqual(result.sampling_rate, 1000 * pq.Hz) self.assertEqual(len(result.times), 1) @@ -141,16 +150,20 @@ def test_kcsd2d_init(self): class CSD3D_TestCase(unittest.TestCase): def setUp(self): xx_ele, yy_ele, zz_ele = utils.generate_electrodes(dim=3) - self.lfp = csd.generate_lfp(utils.gauss_3d_dipole, - xx_ele, yy_ele, zz_ele) - self.params = {'KCSD3D': {'gdx': 0.1, 'gdy': 0.1, 'gdz': 0.1, - 'src_type': 'step', - 'Rs': np.array((0.1, 0.25, 0.5))}} + self.lfp = csd.generate_lfp(utils.gauss_3d_dipole, xx_ele, yy_ele, zz_ele) + self.params = { + "KCSD3D": { + "gdx": 0.1, + "gdy": 0.1, + "gdz": 0.1, + "src_type": "step", + "Rs": np.array((0.1, 0.25, 0.5)), + } + } def test_kcsd2d_init(self): - method = 'KCSD3D' - result = csd.estimate_csd(lfp=self.lfp, method=method, - **self.params[method]) + method = "KCSD3D" + result = csd.estimate_csd(lfp=self.lfp, method=method, **self.params[method]) self.assertEqual(result.t_start, 0.0 * pq.s) self.assertEqual(result.sampling_rate, 1000 * pq.Hz) self.assertEqual(len(result.times), 1) @@ -160,8 +173,7 @@ class GenerateLfpTestCase(unittest.TestCase): @classmethod def setUpClass(cls) -> None: cls.one_dimensional = np.linspace(0, 10, 2304) - cls.two_dimensional = np.linspace(0, 10, 2304 - ).reshape(2304, 1) + cls.two_dimensional = np.linspace(0, 10, 2304).reshape(2304, 1) def test_generate_lfp_one_dimensional_array(self): """ @@ -180,5 +192,5 @@ def test_generate_lfp_two_dimensional_array(self): generate_lfp(utils.gauss_1d_dipole, self.two_dimensional) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_gpfa.py b/elephant/test/test_gpfa.py index fcb231017..1cd4676a4 100644 --- a/elephant/test/test_gpfa.py +++ b/elephant/test/test_gpfa.py @@ -15,8 +15,10 @@ from elephant.spike_train_generation import StationaryPoissonProcess from elephant.trials import TrialsFromLists + try: import sklearn + HAVE_SKLEARN = True except ModuleNotFoundError: HAVE_SKLEARN = False @@ -27,7 +29,7 @@ from sklearn.model_selection import cross_val_score -@unittest.skipUnless(HAVE_SKLEARN, 'requires sklearn') +@unittest.skipUnless(HAVE_SKLEARN, "requires 
sklearn") class GPFATestCase(unittest.TestCase): @classmethod def setUpClass(cls): @@ -39,9 +41,9 @@ def gen_gamma_spike_train(k, theta, t_max): return s[s < t_max] def gen_test_data(rates, durs, shapes=(1, 1, 1, 1)): - s = gen_gamma_spike_train(shapes[0], 1. / rates[0], durs[0]) + s = gen_gamma_spike_train(shapes[0], 1.0 / rates[0], durs[0]) for i in range(1, 4): - s_i = gen_gamma_spike_train(shapes[i], 1. / rates[i], durs[i]) + s_i = gen_gamma_spike_train(shapes[i], 1.0 / rates[i], durs[i]) s = np.concatenate([s, s_i + np.sum(durs[:i])]) return s @@ -56,22 +58,54 @@ def gen_test_data(rates, durs, shapes=(1, 1, 1, 1)): n_trials = 100 cls.data0 = [] for trial in range(n_trials): - n1 = neo.SpikeTrain(gen_test_data(rates_a, durs), units=1 * pq.s, - t_start=0 * pq.s, t_stop=10 * pq.s) - n2 = neo.SpikeTrain(gen_test_data(rates_a, durs), units=1 * pq.s, - t_start=0 * pq.s, t_stop=10 * pq.s) - n3 = neo.SpikeTrain(gen_test_data(rates_b, durs), units=1 * pq.s, - t_start=0 * pq.s, t_stop=10 * pq.s) - n4 = neo.SpikeTrain(gen_test_data(rates_b, durs), units=1 * pq.s, - t_start=0 * pq.s, t_stop=10 * pq.s) - n5 = neo.SpikeTrain(gen_test_data(rates_a, durs), units=1 * pq.s, - t_start=0 * pq.s, t_stop=10 * pq.s) - n6 = neo.SpikeTrain(gen_test_data(rates_a, durs), units=1 * pq.s, - t_start=0 * pq.s, t_stop=10 * pq.s) - n7 = neo.SpikeTrain(gen_test_data(rates_b, durs), units=1 * pq.s, - t_start=0 * pq.s, t_stop=10 * pq.s) - n8 = neo.SpikeTrain(gen_test_data(rates_b, durs), units=1 * pq.s, - t_start=0 * pq.s, t_stop=10 * pq.s) + n1 = neo.SpikeTrain( + gen_test_data(rates_a, durs), + units=1 * pq.s, + t_start=0 * pq.s, + t_stop=10 * pq.s, + ) + n2 = neo.SpikeTrain( + gen_test_data(rates_a, durs), + units=1 * pq.s, + t_start=0 * pq.s, + t_stop=10 * pq.s, + ) + n3 = neo.SpikeTrain( + gen_test_data(rates_b, durs), + units=1 * pq.s, + t_start=0 * pq.s, + t_stop=10 * pq.s, + ) + n4 = neo.SpikeTrain( + gen_test_data(rates_b, durs), + units=1 * pq.s, + t_start=0 * pq.s, + t_stop=10 * pq.s, + ) + n5 = neo.SpikeTrain( + gen_test_data(rates_a, durs), + units=1 * pq.s, + t_start=0 * pq.s, + t_stop=10 * pq.s, + ) + n6 = neo.SpikeTrain( + gen_test_data(rates_a, durs), + units=1 * pq.s, + t_start=0 * pq.s, + t_stop=10 * pq.s, + ) + n7 = neo.SpikeTrain( + gen_test_data(rates_b, durs), + units=1 * pq.s, + t_start=0 * pq.s, + t_stop=10 * pq.s, + ) + n8 = neo.SpikeTrain( + gen_test_data(rates_b, durs), + units=1 * pq.s, + t_start=0 * pq.s, + t_stop=10 * pq.s, + ) cls.data0.append([n1, n2, n3, n4, n5, n6, n7, n8]) cls.x_dim = 4 @@ -84,27 +118,29 @@ def gen_test_data(rates, durs, shapes=(1, 1, 1, 1)): n_channels = 20 for trial in range(n_trials): rates = np.random.randint(low=1, high=100, size=n_channels) - spike_times = [StationaryPoissonProcess(rate=rate * pq.Hz, - t_stop=1000.0 * pq.ms).generate_spiketrain() - for rate in rates] + spike_times = [ + StationaryPoissonProcess( + rate=rate * pq.Hz, t_stop=1000.0 * pq.ms + ).generate_spiketrain() + for rate in rates + ] cls.data2.append(spike_times) def test_data1(self): gpfa = GPFA(x_dim=self.x_dim, em_max_iters=self.n_iters) gpfa.fit(self.data1) latent_variable_orth = gpfa.transform(self.data1) - self.assertAlmostEqual(gpfa.transform_info['log_likelihood'], - -8172.004695554373, places=5) + self.assertAlmostEqual( + gpfa.transform_info["log_likelihood"], -8172.004695554373, places=5 + ) # Since data1 is inherently 2 dimensional, only the first two # dimensions of latent_variable_orth should have finite power. 
for i in [0, 1]: self.assertNotEqual(latent_variable_orth[0][i].mean(), 0) self.assertNotEqual(latent_variable_orth[0][i].var(), 0) for i in [2, 3]: - self.assertAlmostEqual(latent_variable_orth[0][i].mean(), 0, - places=2) - self.assertAlmostEqual(latent_variable_orth[0][i].var(), 0, - places=2) + self.assertAlmostEqual(latent_variable_orth[0][i].mean(), 0, places=2) + self.assertAlmostEqual(latent_variable_orth[0][i].var(), 0, places=2) def test_transform_testing_data(self): # check if the num. of neurons in the test data matches the @@ -145,13 +181,19 @@ def test_data2(self): returned_data = gpfa.valid_data_names seqs = gpfa.transform(self.data2, returned_data=returned_data) for key, data in seqs.items(): - self.assertEqual(len(data), n_trials, - msg="Failed ndarray field {0}".format(key)) + self.assertEqual( + len(data), n_trials, msg="Failed ndarray field {0}".format(key) + ) t_start = self.data2[0][0].t_stop t_stop = self.data2[0][0].t_start n_bins = int(((t_start - t_stop) / self.bin_size).magnitude) - assert_array_equal(gpfa.transform_info['num_bins'], - [n_bins, ] * n_trials) + assert_array_equal( + gpfa.transform_info["num_bins"], + [ + n_bins, + ] + * n_trials, + ) def test_returned_data(self): gpfa = GPFA(bin_size=self.bin_size, x_dim=8, em_max_iters=self.n_iters) @@ -164,28 +206,29 @@ def test_returned_data(self): self.assertTrue(len(returned_data) == len(seqs)) self.assertTrue(isinstance(seqs, dict)) with self.assertRaises(ValueError): - seqs = gpfa.transform(self.data2, returned_data=['invalid_name']) + seqs = gpfa.transform(self.data2, returned_data=["invalid_name"]) def test_fit_transform(self): - gpfa1 = GPFA(bin_size=self.bin_size, x_dim=self.x_dim, - em_max_iters=self.n_iters) + gpfa1 = GPFA( + bin_size=self.bin_size, x_dim=self.x_dim, em_max_iters=self.n_iters + ) gpfa1.fit(self.data1) latent_variable_orth1 = gpfa1.transform(self.data1) latent_variable_orth2 = GPFA( - bin_size=self.bin_size, x_dim=self.x_dim, - em_max_iters=self.n_iters).fit_transform(self.data1) + bin_size=self.bin_size, x_dim=self.x_dim, em_max_iters=self.n_iters + ).fit_transform(self.data1) for i in range(len(self.data1)): for j in range(self.x_dim): - assert_array_almost_equal(latent_variable_orth1[i][j], - latent_variable_orth2[i][j]) + assert_array_almost_equal( + latent_variable_orth1[i][j], latent_variable_orth2[i][j] + ) def test_get_seq_sqrt(self): data = [self.data2[0]] seqs = gpfa_util.get_seqs(data, bin_size=self.bin_size) - seqs_not_sqrt = gpfa_util.get_seqs(data, bin_size=self.bin_size, - use_sqrt=False) - self.assertEqual(seqs['T'], seqs_not_sqrt['T']) - self.assertEqual(seqs['y'].shape, seqs_not_sqrt['y'].shape) + seqs_not_sqrt = gpfa_util.get_seqs(data, bin_size=self.bin_size, use_sqrt=False) + self.assertEqual(seqs["T"], seqs_not_sqrt["T"]) + self.assertEqual(seqs["y"].shape, seqs_not_sqrt["y"].shape) def test_cut_trials_inf(self): same_data = gpfa_util.cut_trials(self.data2, seg_length=np.Inf) @@ -199,14 +242,14 @@ def test_cut_trials_zero_length(self): def test_cut_trials_same_length(self): data = [self.data2[0]] seqs = gpfa_util.get_seqs(data, bin_size=self.bin_size) - seg_length = seqs[0]['T'] + seg_length = seqs[0]["T"] seqs_cut = gpfa_util.cut_trials(seqs, seg_length=seg_length) - assert_array_almost_equal(seqs[0]['y'], seqs_cut[0]['y']) + assert_array_almost_equal(seqs[0]["y"], seqs_cut[0]["y"]) def test_cut_trials_larger_length(self): data = [self.data2[0]] seqs = gpfa_util.get_seqs(data, bin_size=self.bin_size) - seg_length = seqs[0]['T'] + 1 + seg_length = seqs[0]["T"] + 1 
with self.assertWarns(UserWarning): gpfa_util.cut_trials(seqs, seg_length=seg_length) @@ -220,23 +263,29 @@ def test_logdet(self): assert_array_almost_equal(logdet_fast, logdet_ground_truth) def test_trial_object_gpfa_fit(self): - gpfa_trial_object = GPFA(bin_size=self.bin_size, x_dim=self.x_dim, - em_max_iters=self.n_iters) - gpfa_list_of_lists = GPFA(bin_size=self.bin_size, x_dim=self.x_dim, - em_max_iters=self.n_iters) + gpfa_trial_object = GPFA( + bin_size=self.bin_size, x_dim=self.x_dim, em_max_iters=self.n_iters + ) + gpfa_list_of_lists = GPFA( + bin_size=self.bin_size, x_dim=self.x_dim, em_max_iters=self.n_iters + ) trials = TrialsFromLists(self.data1) gpfa_trial_object.fit(trials) gpfa_list_of_lists.fit(self.data1) - assert_array_almost_equal(gpfa_trial_object.params_estimated['gamma'], - gpfa_list_of_lists.params_estimated['gamma']) + assert_array_almost_equal( + gpfa_trial_object.params_estimated["gamma"], + gpfa_list_of_lists.params_estimated["gamma"], + ) def test_trial_object_gpfa_transform(self): - gpfa_trial_object = GPFA(bin_size=self.bin_size, x_dim=self.x_dim, - em_max_iters=self.n_iters) - gpfa_list_of_lists = GPFA(bin_size=self.bin_size, x_dim=self.x_dim, - em_max_iters=self.n_iters) + gpfa_trial_object = GPFA( + bin_size=self.bin_size, x_dim=self.x_dim, em_max_iters=self.n_iters + ) + gpfa_list_of_lists = GPFA( + bin_size=self.bin_size, x_dim=self.x_dim, em_max_iters=self.n_iters + ) trials = TrialsFromLists(self.data1) gpfa_trial_object.fit(trials) @@ -244,12 +293,15 @@ def test_trial_object_gpfa_transform(self): gpfa_list_of_lists.fit(self.data1) gpfa_list_of_lists.transform(self.data1) - assert_array_almost_equal(gpfa_trial_object.transform_info['Corth'], - gpfa_list_of_lists.transform_info['Corth']) + assert_array_almost_equal( + gpfa_trial_object.transform_info["Corth"], + gpfa_list_of_lists.transform_info["Corth"], + ) def test_trial_object_zero_trials(self): - gpfa_trial_object = GPFA(bin_size=self.bin_size, x_dim=self.x_dim, - em_max_iters=self.n_iters) + gpfa_trial_object = GPFA( + bin_size=self.bin_size, x_dim=self.x_dim, em_max_iters=self.n_iters + ) trials = TrialsFromLists([]) with self.assertRaises(ValueError): diff --git a/elephant/test/test_icsd.py b/elephant/test/test_icsd.py index ebe243a1d..daec34b90 100644 --- a/elephant/test/test_icsd.py +++ b/elephant/test/test_icsd.py @@ -13,24 +13,27 @@ # patch quantities with the SI unit Siemens if it does not exist for symbol, prefix, definition, u_symbol in zip( - ['siemens', 'S', 'mS', 'uS', 'nS', 'pS'], - ['', '', 'milli', 'micro', 'nano', 'pico'], - [pq.A / pq.V, pq.A / pq.V, 'S', 'mS', 'uS', 'nS'], - [None, None, None, None, u'µS', None]): + ["siemens", "S", "mS", "uS", "nS", "pS"], + ["", "", "milli", "micro", "nano", "pico"], + [pq.A / pq.V, pq.A / pq.V, "S", "mS", "uS", "nS"], + [None, None, None, None, "µS", None], +): if isinstance(definition, str): definition = lastdefinition / 1000 if not hasattr(pq, symbol): - setattr(pq, symbol, pq.UnitQuantity( - prefix + 'siemens', - definition, - symbol=symbol, - u_symbol=u_symbol)) + setattr( + pq, + symbol, + pq.UnitQuantity( + prefix + "siemens", definition, symbol=symbol, u_symbol=u_symbol + ), + ) lastdefinition = definition -def potential_of_plane(z_j, z_i=0. * pq.m, - C_i=1 * pq.A / pq.m**2, - sigma=0.3 * pq.S / pq.m): +def potential_of_plane( + z_j, z_i=0.0 * pq.m, C_i=1 * pq.A / pq.m**2, sigma=0.3 * pq.S / pq.m +): """ Return potential of infinite horizontal plane with constant current source density at a vertical offset z_j. 
@@ -53,20 +56,21 @@ def potential_of_plane(z_j, z_i=0. * pq.m, """ try: - assert(z_j.units == z_i.units) + assert z_j.units == z_i.units except AssertionError as ae: - print('units of z_j ({}) and z_i ({}) not equal'.format(z_j.units, - z_i.units)) + print("units of z_j ({}) and z_i ({}) not equal".format(z_j.units, z_i.units)) raise ae return -C_i / (2 * sigma) * abs(z_j - z_i).simplified -def potential_of_disk(z_j, - z_i=0. * pq.m, - C_i=1 * pq.A / pq.m**2, - R_i=1E-3 * pq.m, - sigma=0.3 * pq.S / pq.m): +def potential_of_disk( + z_j, + z_i=0.0 * pq.m, + C_i=1 * pq.A / pq.m**2, + R_i=1e-3 * pq.m, + sigma=0.3 * pq.S / pq.m, +): """ Return potential of circular disk in horizontal plane with constant current source density at a vertical offset z_j. @@ -85,23 +89,30 @@ def potential_of_disk(z_j, conductivity of medium in units of S/m """ try: - assert(z_j.units == z_i.units == R_i.units) + assert z_j.units == z_i.units == R_i.units except AssertionError as ae: - print('units of z_j ({}), z_i ({}) and R_i ({}) not equal'.format( - z_j.units, z_i.units, R_i.units)) + print( + "units of z_j ({}), z_i ({}) and R_i ({}) not equal".format( + z_j.units, z_i.units, R_i.units + ) + ) raise ae - return C_i / (2 * sigma) * ( - np.sqrt((z_j - z_i) ** 2 + R_i**2) - abs(z_j - z_i)).simplified - - -def potential_of_cylinder(z_j, - z_i=0. * pq.m, - C_i=1 * pq.A / pq.m**3, - R_i=1E-3 * pq.m, - h_i=0.1 * pq.m, - sigma=0.3 * pq.S / pq.m, - ): + return ( + C_i + / (2 * sigma) + * (np.sqrt((z_j - z_i) ** 2 + R_i**2) - abs(z_j - z_i)).simplified + ) + + +def potential_of_cylinder( + z_j, + z_i=0.0 * pq.m, + C_i=1 * pq.A / pq.m**3, + R_i=1e-3 * pq.m, + h_i=0.1 * pq.m, + sigma=0.3 * pq.S / pq.m, +): """ Return potential of cylinder in horizontal plane with constant homogeneous current source density at a vertical offset z_j. 
@@ -123,10 +134,13 @@ def potential_of_cylinder(z_j, conductivity of medium in units of S/m """ try: - assert(z_j.units == z_i.units == R_i.units == h_i.units) + assert z_j.units == z_i.units == R_i.units == h_i.units except AssertionError as ae: - print('units of z_j ({}), z_i ({}), R_i ({}) and h ({}) not equal' - .format(z_j.units, z_i.units, R_i.units, h_i.units)) + print( + "units of z_j ({}), z_i ({}), R_i ({}) and h ({}) not equal".format( + z_j.units, z_i.units, R_i.units, h_i.units + ) + ) raise ae # speed up tests by stripping units @@ -137,19 +151,20 @@ def potential_of_cylinder(z_j, # evaluate integrand using quad def integrand(z): - return 1 / (2 * _sigma) * \ - (np.sqrt((z - _z_j)**2 + _R_i**2) - abs(z - _z_j)) + return 1 / (2 * _sigma) * (np.sqrt((z - _z_j) ** 2 + _R_i**2) - abs(z - _z_j)) phi_j, abserr = C_i * si.quad(integrand, z_i - h_i / 2, z_i + h_i / 2) - return (phi_j * z_i.units**2 / sigma.units) + return phi_j * z_i.units**2 / sigma.units -def get_lfp_of_planes(z_j=np.arange(21) * 1E-4 * pq.m, - z_i=np.array([8E-4, 10E-4, 12E-4]) * pq.m, - C_i=np.array([-.5, 1., -.5]) * pq.A / pq.m**2, - sigma=0.3 * pq.S / pq.m, - plot=True): +def get_lfp_of_planes( + z_j=np.arange(21) * 1e-4 * pq.m, + z_i=np.array([8e-4, 10e-4, 12e-4]) * pq.m, + C_i=np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**2, + sigma=0.3 * pq.S / pq.m, + plot=True, +): """ Compute the lfp of spatially separated planes with given current source density @@ -162,33 +177,36 @@ def get_lfp_of_planes(z_j=np.arange(21) * 1E-4 * pq.m, # test plot if plot: import matplotlib.pyplot as plt + plt.figure() plt.subplot(121) ax = plt.gca() - ax.plot(np.zeros(z_j.size), z_j, 'r-o') + ax.plot(np.zeros(z_j.size), z_j, "r-o") for i, C in enumerate(C_i): - ax.plot((0, C), (z_i[i], z_i[i]), 'r-o') + ax.plot((0, C), (z_i[i], z_i[i]), "r-o") ax.set_ylim(z_j.min(), z_j.max()) - ax.set_ylabel('z_j ({})'.format(z_j.units)) - ax.set_xlabel('C_i ({})'.format(C_i.units)) - ax.set_title('planar CSD') + ax.set_ylabel("z_j ({})".format(z_j.units)) + ax.set_xlabel("C_i ({})".format(C_i.units)) + ax.set_title("planar CSD") plt.subplot(122) ax = plt.gca() - ax.plot(phi_j, z_j, 'r-o') + ax.plot(phi_j, z_j, "r-o") ax.set_ylim(z_j.min(), z_j.max()) - ax.set_xlabel('phi_j ({})'.format(phi_j.units)) - ax.set_title('LFP') + ax.set_xlabel("phi_j ({})".format(phi_j.units)) + ax.set_title("LFP") return phi_j, C_i -def get_lfp_of_disks(z_j=np.arange(21) * 1E-4 * pq.m, - z_i=np.array([8E-4, 10E-4, 12E-4]) * pq.m, - C_i=np.array([-.5, 1., -.5]) * pq.A / pq.m**2, - R_i=np.array([1, 1, 1]) * 1E-3 * pq.m, - sigma=0.3 * pq.S / pq.m, - plot=True): +def get_lfp_of_disks( + z_j=np.arange(21) * 1e-4 * pq.m, + z_i=np.array([8e-4, 10e-4, 12e-4]) * pq.m, + C_i=np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**2, + R_i=np.array([1, 1, 1]) * 1e-3 * pq.m, + sigma=0.3 * pq.S / pq.m, + plot=True, +): """ Compute the lfp of spatially separated disks with a given current source density @@ -201,34 +219,37 @@ def get_lfp_of_disks(z_j=np.arange(21) * 1E-4 * pq.m, # test plot if plot: import matplotlib.pyplot as plt + plt.figure() plt.subplot(121) ax = plt.gca() - ax.plot(np.zeros(z_j.size), z_j, 'r-o') + ax.plot(np.zeros(z_j.size), z_j, "r-o") for i, C in enumerate(C_i): - ax.plot((0, C), (z_i[i], z_i[i]), 'r-o') + ax.plot((0, C), (z_i[i], z_i[i]), "r-o") ax.set_ylim(z_j.min(), z_j.max()) - ax.set_ylabel('z_j ({})'.format(z_j.units)) - ax.set_xlabel('C_i ({})'.format(C_i.units)) - ax.set_title('disk CSD\nR={}'.format(R_i)) + ax.set_ylabel("z_j ({})".format(z_j.units)) + 
ax.set_xlabel("C_i ({})".format(C_i.units)) + ax.set_title("disk CSD\nR={}".format(R_i)) plt.subplot(122) ax = plt.gca() - ax.plot(phi_j, z_j, 'r-o') + ax.plot(phi_j, z_j, "r-o") ax.set_ylim(z_j.min(), z_j.max()) - ax.set_xlabel('phi_j ({})'.format(phi_j.units)) - ax.set_title('LFP') + ax.set_xlabel("phi_j ({})".format(phi_j.units)) + ax.set_title("LFP") return phi_j, C_i -def get_lfp_of_cylinders(z_j=np.arange(21) * 1E-4 * pq.m, - z_i=np.array([8E-4, 10E-4, 12E-4]) * pq.m, - C_i=np.array([-.5, 1., -.5]) * pq.A / pq.m**3, - R_i=np.array([1, 1, 1]) * 1E-3 * pq.m, - h_i=np.array([1, 1, 1]) * 1E-4 * pq.m, - sigma=0.3 * pq.S / pq.m, - plot=True): +def get_lfp_of_cylinders( + z_j=np.arange(21) * 1e-4 * pq.m, + z_i=np.array([8e-4, 10e-4, 12e-4]) * pq.m, + C_i=np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**3, + R_i=np.array([1, 1, 1]) * 1e-3 * pq.m, + h_i=np.array([1, 1, 1]) * 1e-4 * pq.m, + sigma=0.3 * pq.S / pq.m, + plot=True, +): """ Compute the lfp of spatially separated disks with a given current source density @@ -241,24 +262,23 @@ def get_lfp_of_cylinders(z_j=np.arange(21) * 1E-4 * pq.m, # test plot if plot: import matplotlib.pyplot as plt + plt.figure() plt.subplot(121) ax = plt.gca() - ax.plot(np.zeros(z_j.size), z_j, 'r-o') - ax.barh(np.asarray(z_i - h_i / 2), - np.asarray(C_i), - np.asarray(h_i), color='r') + ax.plot(np.zeros(z_j.size), z_j, "r-o") + ax.barh(np.asarray(z_i - h_i / 2), np.asarray(C_i), np.asarray(h_i), color="r") ax.set_ylim(z_j.min(), z_j.max()) - ax.set_ylabel('z_j ({})'.format(z_j.units)) - ax.set_xlabel('C_i ({})'.format(C_i.units)) - ax.set_title('cylinder CSD\nR={}'.format(R_i)) + ax.set_ylabel("z_j ({})".format(z_j.units)) + ax.set_xlabel("C_i ({})".format(C_i.units)) + ax.set_title("cylinder CSD\nR={}".format(R_i)) plt.subplot(122) ax = plt.gca() - ax.plot(phi_j, z_j, 'r-o') + ax.plot(phi_j, z_j, "r-o") ax.set_ylim(z_j.min(), z_j.max()) - ax.set_xlabel('phi_j ({})'.format(phi_j.units)) - ax.set_title('LFP') + ax.set_xlabel("phi_j ({})".format(phi_j.units)) + ax.set_title("LFP") return phi_j, C_i @@ -274,14 +294,14 @@ def test_StandardCSD_00(self): # set some parameters for ground truth csd and csd estimates. # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**2 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**2 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**2 # uniform conductivity sigma = 0.3 * pq.S / pq.m @@ -292,11 +312,11 @@ def test_StandardCSD_00(self): # get LFP and CSD at contacts phi_j, C_i = get_lfp_of_planes(z_j, z_i, C_i, sigma, plot) std_input = { - 'lfp': phi_j, - 'coord_electrode': z_j, - 'sigma': sigma, - 'f_type': 'gaussian', - 'f_order': (3, 1), + "lfp": phi_j, + "coord_electrode": z_j, + "sigma": sigma, + "f_type": "gaussian", + "f_order": (3, 1), } std_csd = icsd.StandardCSD(**std_input) csd = std_csd.get_csd() @@ -309,14 +329,14 @@ def test_StandardCSD_01(self): # set some parameters for ground truth csd and csd estimates. 
# contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**2 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * 1E3 * pq.A / pq.m**2 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * 1e3 * pq.A / pq.m**2 # uniform conductivity sigma = 0.3 * pq.S / pq.m @@ -327,11 +347,11 @@ def test_StandardCSD_01(self): # get LFP and CSD at contacts phi_j, C_i = get_lfp_of_planes(z_j, z_i, C_i, sigma, plot) std_input = { - 'lfp': phi_j * 1E3 * pq.mV / pq.V, - 'coord_electrode': z_j, - 'sigma': sigma, - 'f_type': 'gaussian', - 'f_order': (3, 1), + "lfp": phi_j * 1e3 * pq.mV / pq.V, + "coord_electrode": z_j, + "sigma": sigma, + "f_type": "gaussian", + "f_order": (3, 1), } std_csd = icsd.StandardCSD(**std_input) csd = std_csd.get_csd() @@ -344,14 +364,14 @@ def test_StandardCSD_02(self): # set some parameters for ground truth csd and csd estimates. # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**2 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**2 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**2 # uniform conductivity sigma = 0.3 * pq.S / pq.m @@ -362,11 +382,11 @@ def test_StandardCSD_02(self): # get LFP and CSD at contacts phi_j, C_i = get_lfp_of_planes(z_j, z_i, C_i, sigma, plot) std_input = { - 'lfp': phi_j, - 'coord_electrode': z_j * 1E3 * pq.mm / pq.m, - 'sigma': sigma, - 'f_type': 'gaussian', - 'f_order': (3, 1), + "lfp": phi_j, + "coord_electrode": z_j * 1e3 * pq.mm / pq.m, + "sigma": sigma, + "f_type": "gaussian", + "f_order": (3, 1), } std_csd = icsd.StandardCSD(**std_input) csd = std_csd.get_csd() @@ -379,14 +399,14 @@ def test_StandardCSD_03(self): # set some parameters for ground truth csd and csd estimates. 
# contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**2 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**2 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**2 # uniform conductivity sigma = 0.3 * pq.mS / pq.m @@ -397,11 +417,11 @@ def test_StandardCSD_03(self): # get LFP and CSD at contacts phi_j, C_i = get_lfp_of_planes(z_j, z_i, C_i, sigma, plot) std_input = { - 'lfp': phi_j, - 'coord_electrode': z_j, - 'sigma': sigma * 1E3 * pq.mS / pq.S, - 'f_type': 'gaussian', - 'f_order': (3, 1), + "lfp": phi_j, + "coord_electrode": z_j, + "sigma": sigma * 1e3 * pq.mS / pq.S, + "f_type": "gaussian", + "f_order": (3, 1), } std_csd = icsd.StandardCSD(**std_input) csd = std_csd.get_csd() @@ -415,17 +435,17 @@ def test_DeltaiCSD_00(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**2 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**2 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**2 # source radius (delta, step) - R_i = np.ones(z_i.size) * 1E-3 * pq.m + R_i = np.ones(z_i.size) * 1e-3 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -435,16 +455,15 @@ def test_DeltaiCSD_00(self): plot = False # get LFP and CSD at contacts - phi_j, C_i = get_lfp_of_disks(z_j, z_i, C_i, R_i, sigma, - plot) + phi_j, C_i = get_lfp_of_disks(z_j, z_i, C_i, R_i, sigma, plot) delta_input = { - 'lfp': phi_j, - 'coord_electrode': z_j, - 'diam': R_i.mean() * 2, # source diameter - 'sigma': sigma, # extracellular conductivity - 'sigma_top': sigma_top, # conductivity on top of cortex - 'f_type': 'gaussian', # gaussian filter - 'f_order': (3, 1), # 3-point filter, sigma = 1. + "lfp": phi_j, + "coord_electrode": z_j, + "diam": R_i.mean() * 2, # source diameter + "sigma": sigma, # extracellular conductivity + "sigma_top": sigma_top, # conductivity on top of cortex + "f_type": "gaussian", # gaussian filter + "f_order": (3, 1), # 3-point filter, sigma = 1. } delta_icsd = icsd.DeltaiCSD(**delta_input) @@ -459,17 +478,17 @@ def test_DeltaiCSD_01(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**2 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**2 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**2 # source radius (delta, step) - R_i = np.ones(z_i.size) * 1E-3 * pq.m + R_i = np.ones(z_i.size) * 1e-3 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -479,16 +498,15 @@ def test_DeltaiCSD_01(self): plot = False # get LFP and CSD at contacts - phi_j, C_i = get_lfp_of_disks(z_j, z_i, C_i, R_i, sigma, - plot) + phi_j, C_i = get_lfp_of_disks(z_j, z_i, C_i, R_i, sigma, plot) delta_input = { - 'lfp': phi_j * 1E3 * pq.mV / pq.V, - 'coord_electrode': z_j, - 'diam': R_i.mean() * 2, # source diameter - 'sigma': sigma, # extracellular conductivity - 'sigma_top': sigma_top, # conductivity on top of cortex - 'f_type': 'gaussian', # gaussian filter - 'f_order': (3, 1), # 3-point filter, sigma = 1. 
+ "lfp": phi_j * 1e3 * pq.mV / pq.V, + "coord_electrode": z_j, + "diam": R_i.mean() * 2, # source diameter + "sigma": sigma, # extracellular conductivity + "sigma_top": sigma_top, # conductivity on top of cortex + "f_type": "gaussian", # gaussian filter + "f_order": (3, 1), # 3-point filter, sigma = 1. } delta_icsd = icsd.DeltaiCSD(**delta_input) @@ -503,17 +521,17 @@ def test_DeltaiCSD_02(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**2 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**2 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**2 # source radius (delta, step) - R_i = np.ones(z_i.size) * 1E-3 * pq.m + R_i = np.ones(z_i.size) * 1e-3 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -523,16 +541,15 @@ def test_DeltaiCSD_02(self): plot = False # get LFP and CSD at contacts - phi_j, C_i = get_lfp_of_disks(z_j, z_i, C_i, R_i, sigma, - plot) + phi_j, C_i = get_lfp_of_disks(z_j, z_i, C_i, R_i, sigma, plot) delta_input = { - 'lfp': phi_j, - 'coord_electrode': z_j * 1E3 * pq.mm / pq.m, - 'diam': R_i.mean() * 2 * 1E3 * pq.mm / pq.m, # source diameter - 'sigma': sigma, # extracellular conductivity - 'sigma_top': sigma_top, # conductivity on top of cortex - 'f_type': 'gaussian', # gaussian filter - 'f_order': (3, 1), # 3-point filter, sigma = 1. + "lfp": phi_j, + "coord_electrode": z_j * 1e3 * pq.mm / pq.m, + "diam": R_i.mean() * 2 * 1e3 * pq.mm / pq.m, # source diameter + "sigma": sigma, # extracellular conductivity + "sigma_top": sigma_top, # conductivity on top of cortex + "f_type": "gaussian", # gaussian filter + "f_order": (3, 1), # 3-point filter, sigma = 1. } delta_icsd = icsd.DeltaiCSD(**delta_input) @@ -547,17 +564,17 @@ def test_DeltaiCSD_03(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**2 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**2 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**2 # source radius (delta, step) - R_i = np.ones(z_i.size) * 1E-3 * pq.m + R_i = np.ones(z_i.size) * 1e-3 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -567,17 +584,16 @@ def test_DeltaiCSD_03(self): plot = False # get LFP and CSD at contacts - phi_j, C_i = get_lfp_of_disks(z_j, z_i, C_i, R_i, sigma, - plot) + phi_j, C_i = get_lfp_of_disks(z_j, z_i, C_i, R_i, sigma, plot) delta_input = { - 'lfp': phi_j, - 'coord_electrode': z_j, - 'diam': R_i.mean() * 2, # source diameter - 'sigma': sigma * 1E3 * pq.mS / pq.S, # extracellular conductivity - 'sigma_top': sigma_top * 1E3 * pq.mS / pq.S, # conductivity on - # top of cortex - 'f_type': 'gaussian', # gaussian filter - 'f_order': (3, 1), # 3-point filter, sigma = 1. + "lfp": phi_j, + "coord_electrode": z_j, + "diam": R_i.mean() * 2, # source diameter + "sigma": sigma * 1e3 * pq.mS / pq.S, # extracellular conductivity + "sigma_top": sigma_top * 1e3 * pq.mS / pq.S, # conductivity on + # top of cortex + "f_type": "gaussian", # gaussian filter + "f_order": (3, 1), # 3-point filter, sigma = 1. 
} delta_icsd = icsd.DeltaiCSD(**delta_input) @@ -592,17 +608,17 @@ def test_DeltaiCSD_04(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**2 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**2 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**2 # source radius (delta, step) - R_i = np.ones(z_j.size) * 1E-3 * pq.m + R_i = np.ones(z_j.size) * 1e-3 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -612,17 +628,16 @@ def test_DeltaiCSD_04(self): plot = False # get LFP and CSD at contacts - phi_j, C_i = get_lfp_of_disks(z_j, z_i, C_i, R_i, sigma, - plot) + phi_j, C_i = get_lfp_of_disks(z_j, z_i, C_i, R_i, sigma, plot) inds = np.delete(np.arange(21), 5) delta_input = { - 'lfp': phi_j[inds], - 'coord_electrode': z_j[inds], - 'diam': R_i[inds] * 2, # source diameter - 'sigma': sigma, # extracellular conductivity - 'sigma_top': sigma_top, # conductivity on top of cortex - 'f_type': 'gaussian', # gaussian filter - 'f_order': (3, 1), # 3-point filter, sigma = 1. + "lfp": phi_j[inds], + "coord_electrode": z_j[inds], + "diam": R_i[inds] * 2, # source diameter + "sigma": sigma, # extracellular conductivity + "sigma_top": sigma_top, # conductivity on top of cortex + "f_type": "gaussian", # gaussian filter + "f_order": (3, 1), # 3-point filter, sigma = 1. } delta_icsd = icsd.DeltaiCSD(**delta_input) @@ -637,20 +652,20 @@ def test_StepiCSD_units_00(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**3 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**3 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**3 # source radius (delta, step) - R_i = np.ones(z_i.size) * 1E-3 * pq.m + R_i = np.ones(z_i.size) * 1e-3 * pq.m # source height (cylinder) - h_i = np.ones(z_i.size) * 1E-4 * pq.m + h_i = np.ones(z_i.size) * 1e-4 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -660,19 +675,18 @@ def test_StepiCSD_units_00(self): plot = False # get LFP and CSD at contacts - phi_j, C_i = get_lfp_of_cylinders(z_j, z_i, C_i, R_i, h_i, - sigma, plot) + phi_j, C_i = get_lfp_of_cylinders(z_j, z_i, C_i, R_i, h_i, sigma, plot) step_input = { - 'lfp': phi_j, - 'coord_electrode': z_j, - 'diam': R_i.mean() * 2, - 'sigma': sigma, - 'sigma_top': sigma, - 'h': h_i, - 'tol': 1E-12, # Tolerance in numerical integration - 'f_type': 'gaussian', - 'f_order': (3, 1), + "lfp": phi_j, + "coord_electrode": z_j, + "diam": R_i.mean() * 2, + "sigma": sigma, + "sigma_top": sigma, + "h": h_i, + "tol": 1e-12, # Tolerance in numerical integration + "f_type": "gaussian", + "f_order": (3, 1), } step_icsd = icsd.StepiCSD(**step_input) csd = step_icsd.get_csd() @@ -686,20 +700,20 @@ def test_StepiCSD_01(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**3 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**3 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**3 # source radius 
(delta, step) - R_i = np.ones(z_i.size) * 1E-3 * pq.m + R_i = np.ones(z_i.size) * 1e-3 * pq.m # source height (cylinder) - h_i = np.ones(z_i.size) * 1E-4 * pq.m + h_i = np.ones(z_i.size) * 1e-4 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -709,19 +723,18 @@ def test_StepiCSD_01(self): plot = False # get LFP and CSD at contacts - phi_j, C_i = get_lfp_of_cylinders(z_j, z_i, C_i, R_i, h_i, - sigma, plot) + phi_j, C_i = get_lfp_of_cylinders(z_j, z_i, C_i, R_i, h_i, sigma, plot) step_input = { - 'lfp': phi_j * 1E3 * pq.mV / pq.V, - 'coord_electrode': z_j, - 'diam': R_i.mean() * 2, - 'sigma': sigma, - 'sigma_top': sigma, - 'h': h_i, - 'tol': 1E-12, # Tolerance in numerical integration - 'f_type': 'gaussian', - 'f_order': (3, 1), + "lfp": phi_j * 1e3 * pq.mV / pq.V, + "coord_electrode": z_j, + "diam": R_i.mean() * 2, + "sigma": sigma, + "sigma_top": sigma, + "h": h_i, + "tol": 1e-12, # Tolerance in numerical integration + "f_type": "gaussian", + "f_order": (3, 1), } step_icsd = icsd.StepiCSD(**step_input) csd = step_icsd.get_csd() @@ -735,20 +748,20 @@ def test_StepiCSD_02(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**3 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**3 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**3 # source radius (delta, step) - R_i = np.ones(z_i.size) * 1E-3 * pq.m + R_i = np.ones(z_i.size) * 1e-3 * pq.m # source height (cylinder) - h_i = np.ones(z_i.size) * 1E-4 * pq.m + h_i = np.ones(z_i.size) * 1e-4 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -758,19 +771,18 @@ def test_StepiCSD_02(self): plot = False # get LFP and CSD at contacts - phi_j, C_i = get_lfp_of_cylinders(z_j, z_i, C_i, R_i, h_i, - sigma, plot) + phi_j, C_i = get_lfp_of_cylinders(z_j, z_i, C_i, R_i, h_i, sigma, plot) step_input = { - 'lfp': phi_j, - 'coord_electrode': z_j * 1E3 * pq.mm / pq.m, - 'diam': R_i.mean() * 2 * 1E3 * pq.mm / pq.m, - 'sigma': sigma, - 'sigma_top': sigma, - 'h': h_i * 1E3 * pq.mm / pq.m, - 'tol': 1E-12, # Tolerance in numerical integration - 'f_type': 'gaussian', - 'f_order': (3, 1), + "lfp": phi_j, + "coord_electrode": z_j * 1e3 * pq.mm / pq.m, + "diam": R_i.mean() * 2 * 1e3 * pq.mm / pq.m, + "sigma": sigma, + "sigma_top": sigma, + "h": h_i * 1e3 * pq.mm / pq.m, + "tol": 1e-12, # Tolerance in numerical integration + "f_type": "gaussian", + "f_order": (3, 1), } step_icsd = icsd.StepiCSD(**step_input) csd = step_icsd.get_csd() @@ -784,20 +796,20 @@ def test_StepiCSD_03(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**3 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**3 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**3 # source radius (delta, step) - R_i = np.ones(z_i.size) * 1E-3 * pq.m + R_i = np.ones(z_i.size) * 1e-3 * pq.m # source height (cylinder) - h_i = np.ones(z_i.size) * 1E-4 * pq.m + h_i = np.ones(z_i.size) * 1e-4 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -807,19 +819,18 @@ def test_StepiCSD_03(self): plot = False # get LFP and CSD at contacts - 
phi_j, C_i = get_lfp_of_cylinders(z_j, z_i, C_i, R_i, h_i, - sigma, plot) + phi_j, C_i = get_lfp_of_cylinders(z_j, z_i, C_i, R_i, h_i, sigma, plot) step_input = { - 'lfp': phi_j, - 'coord_electrode': z_j, - 'diam': R_i.mean() * 2, - 'sigma': sigma * 1E3 * pq.mS / pq.S, - 'sigma_top': sigma * 1E3 * pq.mS / pq.S, - 'h': h_i, - 'tol': 1E-12, # Tolerance in numerical integration - 'f_type': 'gaussian', - 'f_order': (3, 1), + "lfp": phi_j, + "coord_electrode": z_j, + "diam": R_i.mean() * 2, + "sigma": sigma * 1e3 * pq.mS / pq.S, + "sigma_top": sigma * 1e3 * pq.mS / pq.S, + "h": h_i, + "tol": 1e-12, # Tolerance in numerical integration + "f_type": "gaussian", + "f_order": (3, 1), } step_icsd = icsd.StepiCSD(**step_input) csd = step_icsd.get_csd() @@ -833,20 +844,20 @@ def test_StepiCSD_units_04(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**3 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**3 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**3 # source radius (delta, step) - R_i = np.ones(z_i.size) * 1E-3 * pq.m + R_i = np.ones(z_i.size) * 1e-3 * pq.m # source height (cylinder) - h_i = np.ones(z_i.size) * 1E-4 * pq.m + h_i = np.ones(z_i.size) * 1e-4 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -856,19 +867,18 @@ def test_StepiCSD_units_04(self): plot = False # get LFP and CSD at contacts - phi_j, C_i = get_lfp_of_cylinders(z_j, z_i, C_i, R_i, h_i, - sigma, plot) + phi_j, C_i = get_lfp_of_cylinders(z_j, z_i, C_i, R_i, h_i, sigma, plot) inds = np.delete(np.arange(21), 5) step_input = { - 'lfp': phi_j[inds], - 'coord_electrode': z_j[inds], - 'diam': R_i[inds] * 2, - 'sigma': sigma, - 'sigma_top': sigma, - 'h': h_i[inds], - 'tol': 1E-12, # Tolerance in numerical integration - 'f_type': 'gaussian', - 'f_order': (3, 1), + "lfp": phi_j[inds], + "coord_electrode": z_j[inds], + "diam": R_i[inds] * 2, + "sigma": sigma, + "sigma_top": sigma, + "h": h_i[inds], + "tol": 1e-12, # Tolerance in numerical integration + "f_type": "gaussian", + "f_order": (3, 1), } step_icsd = icsd.StepiCSD(**step_input) csd = step_icsd.get_csd() @@ -882,20 +892,20 @@ def test_SplineiCSD_00(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**3 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**3 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**3 # source radius (delta, step) - R_i = np.ones(z_i.size) * 1E-3 * pq.m + R_i = np.ones(z_i.size) * 1e-3 * pq.m # source height (cylinder) - h_i = np.ones(z_i.size) * 1E-4 * pq.m + h_i = np.ones(z_i.size) * 1e-4 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -904,11 +914,10 @@ def test_SplineiCSD_00(self): # construct interpolators, spline method assume underlying source # pattern generating LFPs that are cubic spline interpolates between # contacts so we generate CSD data relying on the same assumption - f_C = interp1d(z_i, C_i, kind='cubic') + f_C = interp1d(z_i, C_i, kind="cubic") f_R = interp1d(z_i, R_i) num_steps = 201 - z_i_i = np.linspace(float(z_i[0]), float( - z_i[-1]), num_steps) * z_i.units + z_i_i = 
np.linspace(float(z_i[0]), float(z_i[-1]), num_steps) * z_i.units C_i_i = f_C(np.asarray(z_i_i)) * C_i.units R_i_i = f_R(z_i_i) * R_i.units @@ -918,19 +927,18 @@ def test_SplineiCSD_00(self): plot = False # get LFP and CSD at contacts - phi_j, C_i = get_lfp_of_cylinders(z_j, z_i_i, C_i_i, R_i_i, h_i_i, - sigma, plot) + phi_j, C_i = get_lfp_of_cylinders(z_j, z_i_i, C_i_i, R_i_i, h_i_i, sigma, plot) spline_input = { - 'lfp': phi_j, - 'coord_electrode': z_j, - 'diam': R_i * 2, - 'sigma': sigma, - 'sigma_top': sigma, - 'num_steps': num_steps, - 'tol': 1E-12, # Tolerance in numerical integration - 'f_type': 'gaussian', - 'f_order': (3, 1), + "lfp": phi_j, + "coord_electrode": z_j, + "diam": R_i * 2, + "sigma": sigma, + "sigma_top": sigma, + "num_steps": num_steps, + "tol": 1e-12, # Tolerance in numerical integration + "f_type": "gaussian", + "f_order": (3, 1), } spline_icsd = icsd.SplineiCSD(**spline_input) csd = spline_icsd.get_csd() @@ -944,20 +952,20 @@ def test_SplineiCSD_01(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(10, 31) * 1E-4 * pq.m + z_j = np.arange(10, 31) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**3 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**3 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**3 # source radius (delta, step) - R_i = np.ones(z_i.size) * 1E-3 * pq.m + R_i = np.ones(z_i.size) * 1e-3 * pq.m # source height (cylinder) - h_i = np.ones(z_i.size) * 1E-4 * pq.m + h_i = np.ones(z_i.size) * 1e-4 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -966,11 +974,10 @@ def test_SplineiCSD_01(self): # construct interpolators, spline method assume underlying source # pattern generating LFPs that are cubic spline interpolates between # contacts so we generate CSD data relying on the same assumption - f_C = interp1d(z_i, C_i, kind='cubic') + f_C = interp1d(z_i, C_i, kind="cubic") f_R = interp1d(z_i, R_i) num_steps = 201 - z_i_i = np.linspace(float(z_i[0]), float( - z_i[-1]), num_steps) * z_i.units + z_i_i = np.linspace(float(z_i[0]), float(z_i[-1]), num_steps) * z_i.units C_i_i = f_C(np.asarray(z_i_i)) * C_i.units R_i_i = f_R(z_i_i) * R_i.units @@ -980,19 +987,18 @@ def test_SplineiCSD_01(self): plot = False # get LFP and CSD at contacts - phi_j, C_i = get_lfp_of_cylinders(z_j, z_i_i, C_i_i, R_i_i, h_i_i, - sigma, plot) + phi_j, C_i = get_lfp_of_cylinders(z_j, z_i_i, C_i_i, R_i_i, h_i_i, sigma, plot) spline_input = { - 'lfp': phi_j, - 'coord_electrode': z_j, - 'diam': R_i * 2, - 'sigma': sigma, - 'sigma_top': sigma, - 'num_steps': num_steps, - 'tol': 1E-12, # Tolerance in numerical integration - 'f_type': 'gaussian', - 'f_order': (3, 1), + "lfp": phi_j, + "coord_electrode": z_j, + "diam": R_i * 2, + "sigma": sigma, + "sigma_top": sigma, + "num_steps": num_steps, + "tol": 1e-12, # Tolerance in numerical integration + "f_type": "gaussian", + "f_order": (3, 1), } spline_icsd = icsd.SplineiCSD(**spline_input) csd = spline_icsd.get_csd() @@ -1006,20 +1012,20 @@ def test_SplineiCSD_02(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**3 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**3 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**3 # source radius 
(delta, step) - R_i = np.ones(z_i.size) * 1E-3 * pq.m + R_i = np.ones(z_i.size) * 1e-3 * pq.m # source height (cylinder) - h_i = np.ones(z_i.size) * 1E-4 * pq.m + h_i = np.ones(z_i.size) * 1e-4 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -1028,11 +1034,10 @@ def test_SplineiCSD_02(self): # construct interpolators, spline method assume underlying source # pattern generating LFPs that are cubic spline interpolates between # contacts so we generate CSD data relying on the same assumption - f_C = interp1d(z_i, C_i, kind='cubic') + f_C = interp1d(z_i, C_i, kind="cubic") f_R = interp1d(z_i, R_i) num_steps = 201 - z_i_i = np.linspace(float(z_i[0]), float( - z_i[-1]), num_steps) * z_i.units + z_i_i = np.linspace(float(z_i[0]), float(z_i[-1]), num_steps) * z_i.units C_i_i = f_C(np.asarray(z_i_i)) * C_i.units R_i_i = f_R(z_i_i) * R_i.units @@ -1042,19 +1047,18 @@ def test_SplineiCSD_02(self): plot = False # get LFP and CSD at contacts - phi_j, C_i = get_lfp_of_cylinders(z_j, z_i_i, C_i_i, R_i_i, h_i_i, - sigma, plot) + phi_j, C_i = get_lfp_of_cylinders(z_j, z_i_i, C_i_i, R_i_i, h_i_i, sigma, plot) spline_input = { - 'lfp': phi_j * 1E3 * pq.mV / pq.V, - 'coord_electrode': z_j, - 'diam': R_i * 2, - 'sigma': sigma, - 'sigma_top': sigma, - 'num_steps': num_steps, - 'tol': 1E-12, # Tolerance in numerical integration - 'f_type': 'gaussian', - 'f_order': (3, 1), + "lfp": phi_j * 1e3 * pq.mV / pq.V, + "coord_electrode": z_j, + "diam": R_i * 2, + "sigma": sigma, + "sigma_top": sigma, + "num_steps": num_steps, + "tol": 1e-12, # Tolerance in numerical integration + "f_type": "gaussian", + "f_order": (3, 1), } spline_icsd = icsd.SplineiCSD(**spline_input) csd = spline_icsd.get_csd() @@ -1068,20 +1072,20 @@ def test_SplineiCSD_03(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**3 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**3 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**3 # source radius (delta, step) - R_i = np.ones(z_i.size) * 1E-3 * pq.m + R_i = np.ones(z_i.size) * 1e-3 * pq.m # source height (cylinder) - h_i = np.ones(z_i.size) * 1E-4 * pq.m + h_i = np.ones(z_i.size) * 1e-4 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -1090,11 +1094,10 @@ def test_SplineiCSD_03(self): # construct interpolators, spline method assume underlying source # pattern generating LFPs that are cubic spline interpolates between # contacts so we generate CSD data relying on the same assumption - f_C = interp1d(z_i, C_i, kind='cubic') + f_C = interp1d(z_i, C_i, kind="cubic") f_R = interp1d(z_i, R_i) num_steps = 201 - z_i_i = np.linspace(float(z_i[0]), float( - z_i[-1]), num_steps) * z_i.units + z_i_i = np.linspace(float(z_i[0]), float(z_i[-1]), num_steps) * z_i.units C_i_i = f_C(np.asarray(z_i_i)) * C_i.units R_i_i = f_R(z_i_i) * R_i.units @@ -1104,19 +1107,18 @@ def test_SplineiCSD_03(self): plot = False # get LFP and CSD at contacts - phi_j, C_i = get_lfp_of_cylinders(z_j, z_i_i, C_i_i, R_i_i, h_i_i, - sigma, plot) + phi_j, C_i = get_lfp_of_cylinders(z_j, z_i_i, C_i_i, R_i_i, h_i_i, sigma, plot) spline_input = { - 'lfp': phi_j, - 'coord_electrode': z_j * 1E3 * pq.mm / pq.m, - 'diam': R_i * 2 * 1E3 * pq.mm / pq.m, - 'sigma': sigma, - 'sigma_top': sigma, - 'num_steps': num_steps, - 'tol': 
1E-12, # Tolerance in numerical integration - 'f_type': 'gaussian', - 'f_order': (3, 1), + "lfp": phi_j, + "coord_electrode": z_j * 1e3 * pq.mm / pq.m, + "diam": R_i * 2 * 1e3 * pq.mm / pq.m, + "sigma": sigma, + "sigma_top": sigma, + "num_steps": num_steps, + "tol": 1e-12, # Tolerance in numerical integration + "f_type": "gaussian", + "f_order": (3, 1), } spline_icsd = icsd.SplineiCSD(**spline_input) csd = spline_icsd.get_csd() @@ -1130,20 +1132,20 @@ def test_SplineiCSD_04(self): # we will use same source diameter as in ground truth # contact point coordinates - z_j = np.arange(21) * 1E-4 * pq.m + z_j = np.arange(21) * 1e-4 * pq.m # source coordinates z_i = z_j # current source density magnitude C_i = np.zeros(z_i.size) * pq.A / pq.m**3 - C_i[7:12:2] += np.array([-.5, 1., -.5]) * pq.A / pq.m**3 + C_i[7:12:2] += np.array([-0.5, 1.0, -0.5]) * pq.A / pq.m**3 # source radius (delta, step) - R_i = np.ones(z_i.size) * 1E-3 * pq.m + R_i = np.ones(z_i.size) * 1e-3 * pq.m # source height (cylinder) - h_i = np.ones(z_i.size) * 1E-4 * pq.m + h_i = np.ones(z_i.size) * 1e-4 * pq.m # conductivity, use same conductivity for top layer (z_j < 0) sigma = 0.3 * pq.S / pq.m @@ -1152,11 +1154,10 @@ def test_SplineiCSD_04(self): # construct interpolators, spline method assume underlying source # pattern generating LFPs that are cubic spline interpolates between # contacts so we generate CSD data relying on the same assumption - f_C = interp1d(z_i, C_i, kind='cubic') + f_C = interp1d(z_i, C_i, kind="cubic") f_R = interp1d(z_i, R_i) num_steps = 201 - z_i_i = np.linspace(float(z_i[0]), float( - z_i[-1]), num_steps) * z_i.units + z_i_i = np.linspace(float(z_i[0]), float(z_i[-1]), num_steps) * z_i.units C_i_i = f_C(np.asarray(z_i_i)) * C_i.units R_i_i = f_R(z_i_i) * R_i.units @@ -1166,19 +1167,18 @@ def test_SplineiCSD_04(self): plot = False # get LFP and CSD at contacts - phi_j, C_i = get_lfp_of_cylinders(z_j, z_i_i, C_i_i, R_i_i, h_i_i, - sigma, plot) + phi_j, C_i = get_lfp_of_cylinders(z_j, z_i_i, C_i_i, R_i_i, h_i_i, sigma, plot) spline_input = { - 'lfp': phi_j, - 'coord_electrode': z_j, - 'diam': R_i * 2, - 'sigma': sigma * 1E3 * pq.mS / pq.S, - 'sigma_top': sigma * 1E3 * pq.mS / pq.S, - 'num_steps': num_steps, - 'tol': 1E-12, # Tolerance in numerical integration - 'f_type': 'gaussian', - 'f_order': (3, 1), + "lfp": phi_j, + "coord_electrode": z_j, + "diam": R_i * 2, + "sigma": sigma * 1e3 * pq.mS / pq.S, + "sigma_top": sigma * 1e3 * pq.mS / pq.S, + "num_steps": num_steps, + "tol": 1e-12, # Tolerance in numerical integration + "f_type": "gaussian", + "f_order": (3, 1), } spline_icsd = icsd.SplineiCSD(**spline_input) csd = spline_icsd.get_csd() diff --git a/elephant/test/test_kcsd.py b/elephant/test/test_kcsd.py index ee98b96c1..13ca99984 100644 --- a/elephant/test/test_kcsd.py +++ b/elephant/test/test_kcsd.py @@ -24,152 +24,175 @@ def setUp(self): self.csd_profile = utils.gauss_1d_dipole pots = CSD.generate_lfp(self.csd_profile, self.ele_pos) self.pots = np.reshape(pots, (-1, 1)) - self.test_method = 'KCSD1D' - self.test_params = {'h': 50.} + self.test_method = "KCSD1D" + self.test_params = {"h": 50.0} temp_signals = [] for ii in range(len(self.pots)): temp_signals.append(self.pots[ii]) - self.an_sigs = neo.AnalogSignal(np.array(temp_signals).T * pq.mV, - sampling_rate=1000 * pq.Hz) + self.an_sigs = neo.AnalogSignal( + np.array(temp_signals).T * pq.mV, sampling_rate=1000 * pq.Hz + ) self.an_sigs.annotate(coordinates=self.ele_pos * pq.mm) def test_kcsd1d_estimate(self, cv_params={}): 
self.test_params.update(cv_params) - result = CSD.estimate_csd(self.an_sigs, method=self.test_method, - **self.test_params) + result = CSD.estimate_csd( + self.an_sigs, method=self.test_method, **self.test_params + ) self.assertEqual(result.t_start, 0.0 * pq.s) self.assertEqual(result.sampling_rate, 1000 * pq.Hz) - self.assertEqual(result.times, [0.] * pq.s) + self.assertEqual(result.times, [0.0] * pq.s) self.assertEqual(len(result.annotations.keys()), 1) - true_csd = self.csd_profile(result.annotations['x_coords']) + true_csd = self.csd_profile(result.annotations["x_coords"]) rms = np.linalg.norm(np.array(result[0, :]) - true_csd) rms /= np.linalg.norm(true_csd) - self.assertLess(rms, 0.5, msg='RMS between trueCSD and estimate > 0.5') + self.assertLess(rms, 0.5, msg="RMS between trueCSD and estimate > 0.5") def test_valid_inputs(self): - self.test_method = 'InvalidMethodName' + self.test_method = "InvalidMethodName" self.assertRaises(ValueError, self.test_kcsd1d_estimate) - self.test_method = 'KCSD1D' - self.test_params = {'src_type': 22} + self.test_method = "KCSD1D" + self.test_params = {"src_type": 22} self.assertRaises(KeyError, self.test_kcsd1d_estimate) - self.test_method = 'KCSD1D' - self.test_params = {'InvalidKwarg': 21} + self.test_method = "KCSD1D" + self.test_params = {"InvalidKwarg": 21} self.assertRaises(TypeError, self.test_kcsd1d_estimate) - cv_params = {'InvalidCVArg': np.array((0.1, 0.25, 0.5))} + cv_params = {"InvalidCVArg": np.array((0.1, 0.25, 0.5))} self.assertRaises(TypeError, self.test_kcsd1d_estimate, cv_params) class KCSD2D_TestCase(unittest.TestCase): def setUp(self): - xx_ele, yy_ele = utils.generate_electrodes(dim=2, res=9, - xlims=[0.05, 0.95], - ylims=[0.05, 0.95]) + xx_ele, yy_ele = utils.generate_electrodes( + dim=2, res=9, xlims=[0.05, 0.95], ylims=[0.05, 0.95] + ) self.ele_pos = np.vstack((xx_ele, yy_ele)).T self.csd_profile = utils.large_source_2D - pots = CSD.generate_lfp( - self.csd_profile, - xx_ele, - yy_ele, - resolution=100) + pots = CSD.generate_lfp(self.csd_profile, xx_ele, yy_ele, resolution=100) self.pots = np.reshape(pots, (-1, 1)) - self.test_method = 'KCSD2D' - self.test_params = {'gdx': 0.25, 'gdy': 0.25, 'R_init': 0.08, - 'h': 50., 'xmin': 0., 'xmax': 1., - 'ymin': 0., 'ymax': 1.} + self.test_method = "KCSD2D" + self.test_params = { + "gdx": 0.25, + "gdy": 0.25, + "R_init": 0.08, + "h": 50.0, + "xmin": 0.0, + "xmax": 1.0, + "ymin": 0.0, + "ymax": 1.0, + } temp_signals = [] for ii in range(len(self.pots)): temp_signals.append(self.pots[ii]) - self.an_sigs = neo.AnalogSignal(np.array(temp_signals).T * pq.mV, - sampling_rate=1000 * pq.Hz) + self.an_sigs = neo.AnalogSignal( + np.array(temp_signals).T * pq.mV, sampling_rate=1000 * pq.Hz + ) self.an_sigs.annotate(coordinates=self.ele_pos * pq.mm) def test_kcsd2d_estimate(self, cv_params={}): self.test_params.update(cv_params) - result = CSD.estimate_csd(self.an_sigs, method=self.test_method, - **self.test_params) + result = CSD.estimate_csd( + self.an_sigs, method=self.test_method, **self.test_params + ) self.assertEqual(result.t_start, 0.0 * pq.s) self.assertEqual(result.sampling_rate, 1000 * pq.Hz) - self.assertEqual(result.times, [0.] 
* pq.s) + self.assertEqual(result.times, [0.0] * pq.s) self.assertEqual(len(result.annotations.keys()), 2) - true_csd = self.csd_profile(result.annotations['x_coords'], - result.annotations['y_coords']) + true_csd = self.csd_profile( + result.annotations["x_coords"], result.annotations["y_coords"] + ) rms = np.linalg.norm(np.array(result[0, :]) - true_csd) rms /= np.linalg.norm(true_csd) - self.assertLess(rms, 0.5, msg='RMS ' + str(rms) + - 'between trueCSD and estimate > 0.5') + self.assertLess( + rms, 0.5, msg="RMS " + str(rms) + "between trueCSD and estimate > 0.5" + ) def test_moi_estimate(self): - result = CSD.estimate_csd(self.an_sigs, method='MoIKCSD', - MoI_iters=10, lambd=0.0, - gdx=0.2, gdy=0.2) + result = CSD.estimate_csd( + self.an_sigs, method="MoIKCSD", MoI_iters=10, lambd=0.0, gdx=0.2, gdy=0.2 + ) self.assertEqual(result.t_start, 0.0 * pq.s) self.assertEqual(result.sampling_rate, 1000 * pq.Hz) - self.assertEqual(result.times, [0.] * pq.s) + self.assertEqual(result.times, [0.0] * pq.s) self.assertEqual(len(result.annotations.keys()), 2) def test_valid_inputs(self): - self.test_method = 'InvalidMethodName' + self.test_method = "InvalidMethodName" self.assertRaises(ValueError, self.test_kcsd2d_estimate) - self.test_method = 'KCSD2D' - self.test_params = {'src_type': 22} + self.test_method = "KCSD2D" + self.test_params = {"src_type": 22} self.assertRaises(KeyError, self.test_kcsd2d_estimate) - self.test_params = {'InvalidKwarg': 21} + self.test_params = {"InvalidKwarg": 21} self.assertRaises(TypeError, self.test_kcsd2d_estimate) - cv_params = {'InvalidCVArg': np.array((0.1, 0.25, 0.5))} + cv_params = {"InvalidCVArg": np.array((0.1, 0.25, 0.5))} self.assertRaises(TypeError, self.test_kcsd2d_estimate, cv_params) class KCSD3D_TestCase(unittest.TestCase): def setUp(self): - xx_ele, yy_ele, zz_ele = utils.generate_electrodes(dim=3, res=5, - xlims=[0.15, 0.85], - ylims=[0.15, 0.85], - zlims=[0.15, 0.85]) + xx_ele, yy_ele, zz_ele = utils.generate_electrodes( + dim=3, res=5, xlims=[0.15, 0.85], ylims=[0.15, 0.85], zlims=[0.15, 0.85] + ) self.ele_pos = np.vstack((xx_ele, yy_ele, zz_ele)).T self.csd_profile = utils.gauss_3d_dipole pots = CSD.generate_lfp(self.csd_profile, xx_ele, yy_ele, zz_ele) self.pots = np.reshape(pots, (-1, 1)) - self.test_method = 'KCSD3D' - self.test_params = {'gdx': 0.05, 'gdy': 0.05, 'gdz': 0.05, - 'lambd': 5.10896977451e-19, 'src_type': 'step', - 'R_init': 0.31, 'xmin': 0., 'xmax': 1., 'ymin': 0., - 'ymax': 1., 'zmin': 0., 'zmax': 1.} + self.test_method = "KCSD3D" + self.test_params = { + "gdx": 0.05, + "gdy": 0.05, + "gdz": 0.05, + "lambd": 5.10896977451e-19, + "src_type": "step", + "R_init": 0.31, + "xmin": 0.0, + "xmax": 1.0, + "ymin": 0.0, + "ymax": 1.0, + "zmin": 0.0, + "zmax": 1.0, + } temp_signals = [] for ii in range(len(self.pots)): temp_signals.append(self.pots[ii]) - self.an_sigs = neo.AnalogSignal(np.array(temp_signals).T * pq.mV, - sampling_rate=1000 * pq.Hz) + self.an_sigs = neo.AnalogSignal( + np.array(temp_signals).T * pq.mV, sampling_rate=1000 * pq.Hz + ) self.an_sigs.annotate(coordinates=self.ele_pos * pq.mm) def test_kcsd3d_estimate(self, cv_params={}): self.test_params.update(cv_params) - result = CSD.estimate_csd(self.an_sigs, method=self.test_method, - **self.test_params) + result = CSD.estimate_csd( + self.an_sigs, method=self.test_method, **self.test_params + ) self.assertEqual(result.t_start, 0.0 * pq.s) self.assertEqual(result.sampling_rate, 1000 * pq.Hz) - self.assertEqual(result.times, [0.] 
* pq.s) + self.assertEqual(result.times, [0.0] * pq.s) self.assertEqual(len(result.annotations.keys()), 3) - true_csd = self.csd_profile(result.annotations['x_coords'], - result.annotations['y_coords'], - result.annotations['z_coords']) + true_csd = self.csd_profile( + result.annotations["x_coords"], + result.annotations["y_coords"], + result.annotations["z_coords"], + ) rms = np.linalg.norm(np.array(result[0, :]) - true_csd) rms /= np.linalg.norm(true_csd) - self.assertLess(rms, 0.5, msg='RMS ' + str(rms) + - ' between trueCSD and estimate > 0.5') + self.assertLess( + rms, 0.5, msg="RMS " + str(rms) + " between trueCSD and estimate > 0.5" + ) def test_valid_inputs(self): - self.test_method = 'InvalidMethodName' + self.test_method = "InvalidMethodName" self.assertRaises(ValueError, self.test_kcsd3d_estimate) - self.test_method = 'KCSD3D' - self.test_params = {'src_type': 22} + self.test_method = "KCSD3D" + self.test_params = {"src_type": 22} self.assertRaises(KeyError, self.test_kcsd3d_estimate) - self.test_params = {'InvalidKwarg': 21} + self.test_params = {"InvalidKwarg": 21} self.assertRaises(TypeError, self.test_kcsd3d_estimate) - cv_params = {'InvalidCVArg': np.array((0.1, 0.25, 0.5))} + cv_params = {"InvalidCVArg": np.array((0.1, 0.25, 0.5))} self.assertRaises(TypeError, self.test_kcsd3d_estimate, cv_params) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_kernels.py b/elephant/test/test_kernels.py index c31f87600..927c5c669 100644 --- a/elephant/test/test_kernels.py +++ b/elephant/test/test_kernels.py @@ -21,43 +21,37 @@ class kernel_TestCase(unittest.TestCase): def setUp(self): self.kernel_types = tuple( - kern_cls for kern_cls in kernels.__dict__.values() - if isinstance(kern_cls, type) and - issubclass(kern_cls, kernels.Kernel) and - kern_cls is not kernels.Kernel and - kern_cls is not kernels.SymmetricKernel) + kern_cls + for kern_cls in kernels.__dict__.values() + if isinstance(kern_cls, type) + and issubclass(kern_cls, kernels.Kernel) + and kern_cls is not kernels.Kernel + and kern_cls is not kernels.SymmetricKernel + ) def test_error_kernels(self): """ Test of various error cases in the kernels module. 
""" # pass multidimensional sigma - self.assertRaises(TypeError, kernels.RectangularKernel, - sigma=[2.0, 2.3] * pq.s) + self.assertRaises(TypeError, kernels.RectangularKernel, sigma=[2.0, 2.3] * pq.s) # pass one-dimensional sigma, this should be handled - self.assertEqual( - kernels.RectangularKernel(sigma=[2.0] * pq.s).sigma.ndim, 0) + self.assertEqual(kernels.RectangularKernel(sigma=[2.0] * pq.s).sigma.ndim, 0) self.assertRaises(TypeError, kernels.RectangularKernel, sigma=2.0) - self.assertRaises( - ValueError, kernels.RectangularKernel, sigma=-0.03 * pq.s) - self.assertRaises( - ValueError, kernels.AlphaKernel, sigma=2.0 * pq.ms, - invert=2) + self.assertRaises(ValueError, kernels.RectangularKernel, sigma=-0.03 * pq.s) + self.assertRaises(ValueError, kernels.AlphaKernel, sigma=2.0 * pq.ms, invert=2) rec_kernel = kernels.RectangularKernel(sigma=0.3 * pq.ms) - self.assertRaises( - TypeError, rec_kernel, [1, 2, 3]) - self.assertRaises( - TypeError, rec_kernel, [1, 2, 3] * pq.V) + self.assertRaises(TypeError, rec_kernel, [1, 2, 3]) + self.assertRaises(TypeError, rec_kernel, [1, 2, 3] * pq.V) kernel = kernels.Kernel(sigma=0.3 * pq.ms) + self.assertRaises(NotImplementedError, kernel._evaluate, [1, 2, 3] * pq.V) self.assertRaises( - NotImplementedError, kernel._evaluate, [1, 2, 3] * pq.V) + NotImplementedError, kernel.boundary_enclosing_area_fraction, fraction=0.9 + ) self.assertRaises( - NotImplementedError, kernel.boundary_enclosing_area_fraction, - fraction=0.9) - self.assertRaises(TypeError, - rec_kernel.boundary_enclosing_area_fraction, [1, 2]) - self.assertRaises(ValueError, - rec_kernel.boundary_enclosing_area_fraction, -10) + TypeError, rec_kernel.boundary_enclosing_area_fraction, [1, 2] + ) + self.assertRaises(ValueError, rec_kernel.boundary_enclosing_area_fraction, -10) self.assertEqual(kernel.is_symmetric(), False) self.assertEqual(rec_kernel.is_symmetric(), True) @@ -73,16 +67,17 @@ def test_kernels_normalization(self): sigma = 0.1 * pq.mV fraction = 0.9999 kernel_resolution = sigma / 100.0 - kernel_list = [kernel_type(sigma, invert=False) for - kernel_type in self.kernel_types] + kernel_list = [ + kernel_type(sigma, invert=False) for kernel_type in self.kernel_types + ] for kernel in kernel_list: b = kernel.boundary_enclosing_area_fraction(fraction).magnitude n_points = int(2 * b / kernel_resolution.magnitude) - restric_defdomain = np.linspace( - -b, b, num=n_points) * sigma.units + restric_defdomain = np.linspace(-b, b, num=n_points) * sigma.units kern = kernel(restric_defdomain) - norm = spint.cumulative_trapezoid(y=kern.magnitude, - x=restric_defdomain.magnitude)[-1] + norm = spint.cumulative_trapezoid( + y=kern.magnitude, x=restric_defdomain.magnitude + )[-1] self.assertAlmostEqual(norm, 1, delta=0.003) def test_kernels_stddev(self): @@ -94,23 +89,28 @@ def test_kernels_stddev(self): fraction = 0.9999 kernel_resolution = sigma / 50.0 for invert in (False, True): - kernel_list = [kernel_type(sigma, invert) for - kernel_type in self.kernel_types] + kernel_list = [ + kernel_type(sigma, invert) for kernel_type in self.kernel_types + ] for kernel in kernel_list: - b = kernel.boundary_enclosing_area_fraction( - fraction).magnitude + b = kernel.boundary_enclosing_area_fraction(fraction).magnitude n_points = int(2 * b / kernel_resolution.magnitude) - restric_defdomain = np.linspace( - -b, b, num=n_points) * sigma.units + restric_defdomain = np.linspace(-b, b, num=n_points) * sigma.units kern = kernel(restric_defdomain) av_integr = kern * restric_defdomain - average = 
spint.cumulative_trapezoid( - y=av_integr.magnitude, - x=restric_defdomain.magnitude)[-1] * sigma.units + average = ( + spint.cumulative_trapezoid( + y=av_integr.magnitude, x=restric_defdomain.magnitude + )[-1] + * sigma.units + ) var_integr = (restric_defdomain - average) ** 2 * kern - variance = spint.cumulative_trapezoid( - y=var_integr.magnitude, - x=restric_defdomain.magnitude)[-1] * sigma.units ** 2 + variance = ( + spint.cumulative_trapezoid( + y=var_integr.magnitude, x=restric_defdomain.magnitude + )[-1] + * sigma.units**2 + ) stddev = np.sqrt(variance) self.assertAlmostEqual(stddev, sigma, delta=0.01 * sigma) @@ -123,17 +123,18 @@ def test_kernel_boundary_enclosing(self): """ sigma = 0.5 * pq.s kernel_resolution = sigma / 500.0 - kernel_list = [kernel_type(sigma, invert=False) for - kernel_type in self.kernel_types] + kernel_list = [ + kernel_type(sigma, invert=False) for kernel_type in self.kernel_types + ] for fraction in np.arange(0.15, 1.0, 0.4): for kernel in kernel_list: b = kernel.boundary_enclosing_area_fraction(fraction).magnitude n_points = int(2 * b / kernel_resolution.magnitude) - restric_defdomain = np.linspace( - -b, b, num=n_points) * sigma.units + restric_defdomain = np.linspace(-b, b, num=n_points) * sigma.units kern = kernel(restric_defdomain) - frac = spint.cumulative_trapezoid(y=kern.magnitude, - x=restric_defdomain.magnitude)[-1] + frac = spint.cumulative_trapezoid( + y=kern.magnitude, x=restric_defdomain.magnitude + )[-1] self.assertAlmostEqual(frac, fraction, delta=0.002) def test_kernel_output_same_size(self): @@ -177,7 +178,7 @@ def test_boundary_enclosing_area_fraction(self): for fraction in fractions_test: self.assertAlmostEqual( kernel.boundary_enclosing_area_fraction(fraction), - kernel_inverted.boundary_enclosing_area_fraction(fraction) + kernel_inverted.boundary_enclosing_area_fraction(fraction), ) def test_icdf(self): @@ -190,7 +191,7 @@ def test_icdf(self): # ICDF(0) for several kernels produces -inf # of fsolve complains about stuck at local optima with warnings.catch_warnings(): - warnings.simplefilter('ignore', RuntimeWarning) + warnings.simplefilter("ignore", RuntimeWarning) icdf = kernel.icdf(fraction) icdf_inverted = kernel_inverted.icdf(fraction) if kernel.is_symmetric(): @@ -210,9 +211,10 @@ def test_cdf_icdf(self): # ICDF(0) for several kernels produces -inf # of fsolve complains about stuck at local optima with warnings.catch_warnings(): - warnings.simplefilter('ignore', RuntimeWarning) + warnings.simplefilter("ignore", RuntimeWarning) self.assertAlmostEqual( - kernel.cdf(kernel.icdf(fraction)), fraction) + kernel.cdf(kernel.icdf(fraction)), fraction + ) def test_icdf_cdf(self): sigma = 1 * pq.s @@ -222,19 +224,19 @@ def test_icdf_cdf(self): kernel = kern_cls(sigma=sigma, invert=invert) for t in times: cdf = kernel.cdf(t) - self.assertGreaterEqual(cdf, 0.) - self.assertLessEqual(cdf, 1.) 
+ self.assertGreaterEqual(cdf, 0.0) + self.assertLessEqual(cdf, 1.0) if 0 < cdf < 1: - self.assertAlmostEqual( - kernel.icdf(cdf), t, places=2) + self.assertAlmostEqual(kernel.icdf(cdf), t, places=2) def test_icdf_at_1(self): sigma = 1 * pq.s for kern_cls in self.kernel_types: for invert in (False, True): kernel = kern_cls(sigma=sigma, invert=invert) - if isinstance(kernel, (kernels.RectangularKernel, - kernels.TriangularKernel)): + if isinstance( + kernel, (kernels.RectangularKernel, kernels.TriangularKernel) + ): icdf = kernel.icdf(1.0) # check finite self.assertLess(np.abs(icdf.magnitude), np.inf) @@ -245,8 +247,10 @@ def test_cdf_symmetric(self): sigma = 1 * pq.s cutoff = 1e2 * sigma # a large value times = np.linspace(-cutoff, cutoff, num=10) - kern_symmetric = filter(lambda kern_type: issubclass( - kern_type, kernels.SymmetricKernel), self.kernel_types) + kern_symmetric = filter( + lambda kern_type: issubclass(kern_type, kernels.SymmetricKernel), + self.kernel_types, + ) for kern_cls in kern_symmetric: kernel = kern_cls(sigma=sigma, invert=False) kernel_inverted = kern_cls(sigma=sigma, invert=True) @@ -257,11 +261,13 @@ def test_cdf_symmetric(self): class KernelOldImplementation(unittest.TestCase): def setUp(self): self.kernel_types = tuple( - kern_cls for kern_cls in kernels.__dict__.values() - if isinstance(kern_cls, type) and - issubclass(kern_cls, kernels.Kernel) and - kern_cls is not kernels.Kernel and - kern_cls is not kernels.SymmetricKernel) + kern_cls + for kern_cls in kernels.__dict__.values() + if isinstance(kern_cls, type) + and issubclass(kern_cls, kernels.Kernel) + and kern_cls is not kernels.Kernel + and kern_cls is not kernels.SymmetricKernel + ) self.sigma = 1 * pq.s self.time_input = np.linspace(-10, 10, num=100) * self.sigma.units @@ -276,8 +282,9 @@ def evaluate_old(t): for invert in (False, True): kernel = kernels.TriangularKernel(self.sigma, invert=invert) - assert_array_almost_equal(kernel(self.time_input), - evaluate_old(self.time_input)) + assert_array_almost_equal( + kernel(self.time_input), evaluate_old(self.time_input) + ) def test_gaussian(self): def evaluate_old(t): @@ -285,14 +292,16 @@ def evaluate_old(t): t = t.magnitude sigma = kernel.sigma.rescale(t_units).magnitude kernel_pdf = (1.0 / (math.sqrt(2.0 * math.pi) * sigma)) * np.exp( - -0.5 * (t / sigma) ** 2) + -0.5 * (t / sigma) ** 2 + ) kernel_pdf = pq.Quantity(kernel_pdf, units=1 / t_units) return kernel_pdf for invert in (False, True): kernel = kernels.GaussianKernel(self.sigma, invert=invert) - assert_array_almost_equal(kernel(self.time_input), - evaluate_old(self.time_input)) + assert_array_almost_equal( + kernel(self.time_input), evaluate_old(self.time_input) + ) def test_laplacian(self): def evaluate_old(t): @@ -305,8 +314,9 @@ def evaluate_old(t): for invert in (False, True): kernel = kernels.LaplacianKernel(self.sigma, invert=invert) - assert_array_almost_equal(kernel(self.time_input), - evaluate_old(self.time_input)) + assert_array_almost_equal( + kernel(self.time_input), evaluate_old(self.time_input) + ) def test_exponential(self): def evaluate_old(t): @@ -322,18 +332,21 @@ def evaluate_old(t): for invert in (False, True): kernel = kernels.ExponentialKernel(self.sigma, invert=invert) - assert_array_almost_equal(kernel(self.time_input), - evaluate_old(self.time_input)) + assert_array_almost_equal( + kernel(self.time_input), evaluate_old(self.time_input) + ) class KernelMedianIndex(unittest.TestCase): def setUp(self): kernel_types = tuple( - kern_cls for kern_cls in kernels.__dict__.values() 
- if isinstance(kern_cls, type) and - issubclass(kern_cls, kernels.Kernel) and - kern_cls is not kernels.Kernel and - kern_cls is not kernels.SymmetricKernel) + kern_cls + for kern_cls in kernels.__dict__.values() + if isinstance(kern_cls, type) + and issubclass(kern_cls, kernels.Kernel) + and kern_cls is not kernels.Kernel + and kern_cls is not kernels.SymmetricKernel + ) self.sigma = 1 * pq.s self.time_input = np.linspace(-10, 10, num=100) * self.sigma.units self.kernels = [] @@ -358,13 +371,12 @@ def test_not_sorted(self): def test_non_support(self): time_negative = np.linspace(-100, -20) * pq.s for kernel in self.kernels: - if isinstance(kernel, (kernels.GaussianKernel, - kernels.LaplacianKernel)): + if isinstance(kernel, (kernels.GaussianKernel, kernels.LaplacianKernel)): continue kernel.invert = False median_id = kernel.median_index(time_negative) self.assertEqual(median_id, len(time_negative) // 2) - self.assertAlmostEqual(kernel.cdf(time_negative[median_id]), 0.) + self.assertAlmostEqual(kernel.cdf(time_negative[median_id]), 0.0) def test_old_implementation(self): def median_index(t): @@ -382,5 +394,5 @@ def median_index(t): self.assertLessEqual(abs(median_id - median_id_old), 1) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_neo_tools.py b/elephant/test/test_neo_tools.py index 3c8597e6e..10589729d 100644 --- a/elephant/test/test_neo_tools.py +++ b/elephant/test/test_neo_tools.py @@ -5,20 +5,26 @@ :copyright: Copyright 2014-2024 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. """ + import random from itertools import chain import copy import unittest import neo.core + # TODO: In Neo 0.10.0, SpikeTrainList ist not exposed in __init__.py of # neo.core. Remove the following line if SpikeTrainList is accessible via # neo.core from neo.core.spiketrainlist import SpikeTrainList -from neo.test.generate_datasets import generate_one_simple_block, \ - generate_one_simple_segment, \ - random_event, random_epoch, random_spiketrain +from neo.test.generate_datasets import ( + generate_one_simple_block, + generate_one_simple_segment, + random_event, + random_epoch, + random_spiketrain, +) from neo.test.tools import assert_same_sub_schema from numpy.testing import assert_array_equal @@ -26,16 +32,17 @@ import elephant.neo_tools as nt # A list of neo object attributes that contain arrays. 
-ARRAY_ATTRS = ['waveforms', - 'times', - 'durations', - 'labels', - # 'index', - 'channel_names', - 'channel_ids', - 'coordinates', - 'array_annotations' - ] +ARRAY_ATTRS = [ + "waveforms", + "times", + "durations", + "labels", + # 'index', + "channel_names", + "channel_ids", + "coordinates", + "array_annotations", +] def strip_iter_values(targ, array_attrs=ARRAY_ATTRS): @@ -80,41 +87,40 @@ def strip_iter_values(targ, array_attrs=ARRAY_ATTRS): class GetAllObjsTestCase(unittest.TestCase): def setUp(self): random.seed(4245) - self.spiketrain = random_spiketrain( - 'Single SpikeTrain', seed=random.random()) + self.spiketrain = random_spiketrain("Single SpikeTrain", seed=random.random()) self.spiketrain_list = [ - random_spiketrain('SpikeTrain', seed=random.random()), - random_spiketrain('SpikeTrain', seed=random.random())] + random_spiketrain("SpikeTrain", seed=random.random()), + random_spiketrain("SpikeTrain", seed=random.random()), + ] self.spiketrain_dict = { - 'a': random_spiketrain('SpikeTrain', seed=random.random()), - 123: random_spiketrain('SpikeTrain', seed=random.random())} + "a": random_spiketrain("SpikeTrain", seed=random.random()), + 123: random_spiketrain("SpikeTrain", seed=random.random()), + } self.epoch = random_epoch() - self.epoch_list = [ - random_epoch(), random_epoch()] - self.epoch_dict = { - 'a': random_epoch(), 123: random_epoch()} + self.epoch_list = [random_epoch(), random_epoch()] + self.epoch_dict = {"a": random_epoch(), 123: random_epoch()} def test__get_all_objs__float_valueerror(self): - value = 5. + value = 5.0 with self.assertRaises(ValueError): - nt._get_all_objs(value, 'Block') + nt._get_all_objs(value, "Block") def test__get_all_objs__list_float_valueerror(self): - value = [5.] + value = [5.0] with self.assertRaises(ValueError): - nt._get_all_objs(value, 'Block') + nt._get_all_objs(value, "Block") def test__get_all_objs__epoch_for_event_valueerror(self): value = self.epoch with self.assertRaises(ValueError): - nt._get_all_objs(value, 'Event') + nt._get_all_objs(value, "Event") def test__get_all_objs__empty_list(self): targ = [] value = [] - res = nt._get_all_objs(value, 'Block') + res = nt._get_all_objs(value, "Block") self.assertEqual(targ, res) @@ -122,7 +128,7 @@ def test__get_all_objs__empty_nested_list(self): targ = [] value = [[], [[], [[]]]] - res = nt._get_all_objs(value, 'Block') + res = nt._get_all_objs(value, "Block") self.assertEqual(targ, res) @@ -130,15 +136,15 @@ def test__get_all_objs__empty_dict(self): targ = [] value = {} - res = nt._get_all_objs(value, 'Block') + res = nt._get_all_objs(value, "Block") self.assertEqual(targ, res) def test__get_all_objs__empty_nested_dict(self): targ = [] - value = {'a': {}, 'b': {'c': {}, 'd': {'e': {}}}} + value = {"a": {}, "b": {"c": {}, "d": {"e": {}}}} - res = nt._get_all_objs(value, 'Block') + res = nt._get_all_objs(value, "Block") self.assertEqual(targ, res) @@ -146,7 +152,7 @@ def test__get_all_objs__empty_itert(self): targ = [] value = iter([]) - res = nt._get_all_objs(value, 'Block') + res = nt._get_all_objs(value, "Block") self.assertEqual(targ, res) @@ -154,15 +160,15 @@ def test__get_all_objs__empty_nested_iter(self): targ = [] value = iter([iter([]), iter([iter([]), iter([iter([])])])]) - res = nt._get_all_objs(value, 'Block') + res = nt._get_all_objs(value, "Block") self.assertEqual(targ, res) def test__get_all_objs__empty_nested_many(self): targ = [] - value = iter([[], {'c': [], 'd': (iter([]),)}]) + value = iter([[], {"c": [], "d": (iter([]),)}]) - res = nt._get_all_objs(value, 'Block') 
+ res = nt._get_all_objs(value, "Block") self.assertEqual(targ, res) @@ -170,7 +176,7 @@ def test__get_all_objs__spiketrain(self): value = self.spiketrain targ = [self.spiketrain] - res = nt._get_all_objs(value, 'SpikeTrain') + res = nt._get_all_objs(value, "SpikeTrain") assert_same_sub_schema(targ, res) @@ -178,7 +184,7 @@ def test__get_all_objs__list_spiketrain(self): value = self.spiketrain_list targ = self.spiketrain_list - res = nt._get_all_objs(value, 'SpikeTrain') + res = nt._get_all_objs(value, "SpikeTrain") assert_same_sub_schema(targ, res) @@ -186,16 +192,15 @@ def test__get_all_objs__nested_list_epoch(self): targ = self.epoch_list value = [self.epoch_list] - res = nt._get_all_objs(value, 'Epoch') + res = nt._get_all_objs(value, "Epoch") assert_same_sub_schema(targ, res) def test__get_all_objs__iter_spiketrain(self): targ = self.spiketrain_list - value = iter([self.spiketrain_list[0], - self.spiketrain_list[1]]) + value = iter([self.spiketrain_list[0], self.spiketrain_list[1]]) - res = nt._get_all_objs(value, 'SpikeTrain') + res = nt._get_all_objs(value, "SpikeTrain") assert_same_sub_schema(targ, res) @@ -203,15 +208,15 @@ def test__get_all_objs__nested_iter_epoch(self): targ = self.epoch_list value = iter([iter(self.epoch_list)]) - res = nt._get_all_objs(value, 'Epoch') + res = nt._get_all_objs(value, "Epoch") assert_same_sub_schema(targ, res) def test__get_all_objs__dict_spiketrain(self): - targ = [self.spiketrain_dict['a'], self.spiketrain_dict[123]] + targ = [self.spiketrain_dict["a"], self.spiketrain_dict[123]] value = self.spiketrain_dict - res = nt._get_all_objs(value, 'SpikeTrain') + res = nt._get_all_objs(value, "SpikeTrain") self.assertEqual(len(targ), len(res)) for t, r in zip(targ, res): @@ -219,54 +224,50 @@ def test__get_all_objs__dict_spiketrain(self): def test__get_all_objs__nested_dict_spiketrain(self): targ = self.spiketrain_list - value = {'a': self.spiketrain_list[0], - 'b': {'c': self.spiketrain_list[1]}} + value = {"a": self.spiketrain_list[0], "b": {"c": self.spiketrain_list[1]}} - res = nt._get_all_objs(value, 'SpikeTrain') + res = nt._get_all_objs(value, "SpikeTrain") self.assertEqual(len(targ), len(res)) for i, itarg in enumerate(targ): for ires in res: - if itarg.annotations['seed'] == ires.annotations['seed']: + if itarg.annotations["seed"] == ires.annotations["seed"]: assert_same_sub_schema(itarg, ires) break else: - raise ValueError('Target %s not in result' % i) + raise ValueError("Target %s not in result" % i) def test__get_all_objs__nested_many_spiketrain(self): targ = self.spiketrain_list - value = {'a': [self.spiketrain_list[0]], - 'b': iter([self.spiketrain_list[1]])} + value = {"a": [self.spiketrain_list[0]], "b": iter([self.spiketrain_list[1]])} - res = nt._get_all_objs(value, 'SpikeTrain') + res = nt._get_all_objs(value, "SpikeTrain") self.assertEqual(len(targ), len(res)) for i, itarg in enumerate(targ): for ires in res: - if itarg.annotations['seed'] == ires.annotations['seed']: + if itarg.annotations["seed"] == ires.annotations["seed"]: assert_same_sub_schema(itarg, ires) break else: - raise ValueError('Target %s not in result' % i) + raise ValueError("Target %s not in result" % i) def test__get_all_objs__unit_spiketrain(self): - value = neo.core.Group( - self.spiketrain_list, - name='Unit') + value = neo.core.Group(self.spiketrain_list, name="Unit") targ = self.spiketrain_list for train in value.spiketrains: - train.annotations.pop('i', None) - train.annotations.pop('j', None) + train.annotations.pop("i", None) + 
train.annotations.pop("j", None) - res = nt._get_all_objs(value, 'SpikeTrain') + res = nt._get_all_objs(value, "SpikeTrain") assert_same_sub_schema(targ, res) def test__get_all_objs__block_epoch(self): - value = generate_one_simple_block('Block', n=3, seed=0) - targ = [train for train in value.list_children_by_class('Epoch')] - res = nt._get_all_objs(value, 'Epoch') + value = generate_one_simple_block("Block", n=3, seed=0) + targ = [train for train in value.list_children_by_class("Epoch")] + res = nt._get_all_objs(value, "Epoch") assert_same_sub_schema(targ, res) @@ -277,9 +278,13 @@ def setUp(self): self.block = generate_one_simple_block( nb_segment=3, supported_objects=[ - neo.core.Block, neo.core.Segment, + neo.core.Block, + neo.core.Segment, neo.core.SpikeTrain, - neo.core.Event, neo.core.Epoch]) + neo.core.Event, + neo.core.Epoch, + ], + ) def assert_dicts_equal(self, d1, d2): """Assert that two dictionaries are equal, taking into account arrays. @@ -306,18 +311,19 @@ def assert_dicts_equal(self, d1, d2): try: self.assertEqual(d1, d2) except ValueError: - for (key1, value1), (key2, value2) in zip(sorted(d1.items()), - sorted(d2.items())): + for (key1, value1), (key2, value2) in zip( + sorted(d1.items()), sorted(d2.items()) + ): self.assertEqual(key1, key2) try: - if hasattr(value1, 'keys') and hasattr(value2, 'keys'): + if hasattr(value1, "keys") and hasattr(value2, "keys"): self.assert_dicts_equal(value1, value2) - elif hasattr(value1, 'dtype') and hasattr(value2, 'dtype'): + elif hasattr(value1, "dtype") and hasattr(value2, "dtype"): assert_array_equal(value1, value2) else: self.assertEqual(value1, value2) except BaseException as exc: - exc.args += ('key: %s' % key1,) + exc.args += ("key: %s" % key1,) raise def test__extract_neo_attrs__spiketrain_noarray(self): @@ -325,21 +331,25 @@ def test__extract_neo_attrs__spiketrain_noarray(self): targ = copy.deepcopy(obj.annotations) for i, attr in enumerate( - neo.SpikeTrain._necessary_attrs + - neo.SpikeTrain._recommended_attrs): + neo.SpikeTrain._necessary_attrs + neo.SpikeTrain._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ = strip_iter_values(targ) res00 = nt.extract_neo_attributes(obj, parents=False, skip_array=True) - res10 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=True) - res20 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=False) + res10 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=True + ) + res20 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=False + ) res01 = nt.extract_neo_attributes(obj, parents=True, skip_array=True) - res11 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=True) - res21 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=False) + res11 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=True + ) + res21 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=False + ) self.assertEqual(targ, res00) self.assertEqual(targ, res10) @@ -353,8 +363,8 @@ def test__extract_neo_attrs__spiketrain_noarray_skip_none(self): targ = copy.deepcopy(obj.annotations) for i, attr in enumerate( - neo.SpikeTrain._necessary_attrs + - neo.SpikeTrain._recommended_attrs): + neo.SpikeTrain._necessary_attrs + neo.SpikeTrain._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ = strip_iter_values(targ) @@ -362,18 +372,24 @@ def 
test__extract_neo_attrs__spiketrain_noarray_skip_none(self): if value is None: del targ[key] - res00 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - skip_none=True) - res10 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=True, skip_none=True) - res20 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=False, skip_none=True) - res01 = nt.extract_neo_attributes(obj, parents=True, skip_array=True, - skip_none=True) - res11 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=True, skip_none=True) - res21 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=False, skip_none=True) + res00 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, skip_none=True + ) + res10 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=True, skip_none=True + ) + res20 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=False, skip_none=True + ) + res01 = nt.extract_neo_attributes( + obj, parents=True, skip_array=True, skip_none=True + ) + res11 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=True, skip_none=True + ) + res21 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=False, skip_none=True + ) self.assertEqual(targ, res00) self.assertEqual(targ, res10) @@ -386,21 +402,25 @@ def test__extract_neo_attrs__epoch_noarray(self): obj = random_epoch() targ = copy.deepcopy(obj.annotations) for i, attr in enumerate( - neo.Epoch._necessary_attrs + - neo.Epoch._recommended_attrs): + neo.Epoch._necessary_attrs + neo.Epoch._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ = strip_iter_values(targ) res00 = nt.extract_neo_attributes(obj, parents=False, skip_array=True) - res10 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=True) - res20 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=False) + res10 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=True + ) + res20 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=False + ) res01 = nt.extract_neo_attributes(obj, parents=True, skip_array=True) - res11 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=True) - res21 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=False) + res11 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=True + ) + res21 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=False + ) self.assertEqual(targ, res00) self.assertEqual(targ, res10) @@ -413,21 +433,25 @@ def test__extract_neo_attrs__event_noarray(self): obj = random_event() targ = copy.deepcopy(obj.annotations) for i, attr in enumerate( - neo.Event._necessary_attrs + - neo.Event._recommended_attrs): + neo.Event._necessary_attrs + neo.Event._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ = strip_iter_values(targ) res00 = nt.extract_neo_attributes(obj, parents=False, skip_array=True) - res10 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=True) - res20 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=False) + res10 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=True + ) + res20 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=False + ) res01 = 
nt.extract_neo_attributes(obj, parents=True, skip_array=True) - res11 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=True) - res21 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=False) + res11 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=True + ) + res21 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=False + ) self.assertEqual(targ, res00) self.assertEqual(targ, res10) @@ -440,31 +464,31 @@ def test__extract_neo_attrs__spiketrain_parents_empty_array(self): obj = random_spiketrain() targ = copy.deepcopy(obj.annotations) for i, attr in enumerate( - neo.SpikeTrain._necessary_attrs + - neo.SpikeTrain._recommended_attrs): + neo.SpikeTrain._necessary_attrs + neo.SpikeTrain._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) - del targ['times'] + del targ["times"] res000 = nt.extract_neo_attributes(obj, parents=False) - res100 = nt.extract_neo_attributes( - obj, parents=False, child_first=True) - res200 = nt.extract_neo_attributes( - obj, parents=False, child_first=False) - res010 = nt.extract_neo_attributes( - obj, parents=False, skip_array=False) + res100 = nt.extract_neo_attributes(obj, parents=False, child_first=True) + res200 = nt.extract_neo_attributes(obj, parents=False, child_first=False) + res010 = nt.extract_neo_attributes(obj, parents=False, skip_array=False) res110 = nt.extract_neo_attributes( - obj, parents=False, skip_array=False, child_first=True) + obj, parents=False, skip_array=False, child_first=True + ) res210 = nt.extract_neo_attributes( - obj, parents=False, skip_array=False, child_first=False) + obj, parents=False, skip_array=False, child_first=False + ) res001 = nt.extract_neo_attributes(obj, parents=True) res101 = nt.extract_neo_attributes(obj, parents=True, child_first=True) - res201 = nt.extract_neo_attributes( - obj, parents=True, child_first=False) + res201 = nt.extract_neo_attributes(obj, parents=True, child_first=False) res011 = nt.extract_neo_attributes(obj, parents=True, skip_array=False) - res111 = nt.extract_neo_attributes(obj, parents=True, skip_array=False, - child_first=True) - res211 = nt.extract_neo_attributes(obj, parents=True, skip_array=False, - child_first=False) + res111 = nt.extract_neo_attributes( + obj, parents=True, skip_array=False, child_first=True + ) + res211 = nt.extract_neo_attributes( + obj, parents=True, skip_array=False, child_first=False + ) self.assert_dicts_equal(targ, res000) self.assert_dicts_equal(targ, res100) @@ -483,41 +507,41 @@ def test__extract_neo_attrs__spiketrain_parents_empty_array(self): def _fix_neo_issue_749(obj, targ): # TODO: remove once fixed # https://github.com/NeuralEnsemble/python-neo/issues/749 - num_times = len(targ['times']) + num_times = len(targ["times"]) obj = obj[:num_times] - del targ['array_annotations'] + del targ["array_annotations"] return obj def test__extract_neo_attrs__epoch_parents_empty_array(self): obj = random_epoch() targ = copy.deepcopy(obj.annotations) for i, attr in enumerate( - neo.Epoch._necessary_attrs + - neo.Epoch._recommended_attrs): + neo.Epoch._necessary_attrs + neo.Epoch._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) - del targ['times'] + del targ["times"] res000 = nt.extract_neo_attributes(obj, parents=False) - res100 = nt.extract_neo_attributes( - obj, parents=False, child_first=True) - res200 = nt.extract_neo_attributes( - obj, parents=False, child_first=False) - res010 = nt.extract_neo_attributes( - obj, parents=False, 
skip_array=False) + res100 = nt.extract_neo_attributes(obj, parents=False, child_first=True) + res200 = nt.extract_neo_attributes(obj, parents=False, child_first=False) + res010 = nt.extract_neo_attributes(obj, parents=False, skip_array=False) res110 = nt.extract_neo_attributes( - obj, parents=False, skip_array=False, child_first=True) + obj, parents=False, skip_array=False, child_first=True + ) res210 = nt.extract_neo_attributes( - obj, parents=False, skip_array=False, child_first=False) + obj, parents=False, skip_array=False, child_first=False + ) res001 = nt.extract_neo_attributes(obj, parents=True) res101 = nt.extract_neo_attributes(obj, parents=True, child_first=True) - res201 = nt.extract_neo_attributes( - obj, parents=True, child_first=False) + res201 = nt.extract_neo_attributes(obj, parents=True, child_first=False) res011 = nt.extract_neo_attributes(obj, parents=True, skip_array=False) - res111 = nt.extract_neo_attributes(obj, parents=True, skip_array=False, - child_first=True) - res211 = nt.extract_neo_attributes(obj, parents=True, skip_array=False, - child_first=False) + res111 = nt.extract_neo_attributes( + obj, parents=True, skip_array=False, child_first=True + ) + res211 = nt.extract_neo_attributes( + obj, parents=True, skip_array=False, child_first=False + ) self.assert_dicts_equal(targ, res000) self.assert_dicts_equal(targ, res100) @@ -536,31 +560,31 @@ def test__extract_neo_attrs__event_parents_empty_array(self): obj = random_event() targ = copy.deepcopy(obj.annotations) for i, attr in enumerate( - neo.Event._necessary_attrs + - neo.Event._recommended_attrs): + neo.Event._necessary_attrs + neo.Event._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) - del targ['times'] + del targ["times"] res000 = nt.extract_neo_attributes(obj, parents=False) - res100 = nt.extract_neo_attributes( - obj, parents=False, child_first=True) - res200 = nt.extract_neo_attributes( - obj, parents=False, child_first=False) - res010 = nt.extract_neo_attributes( - obj, parents=False, skip_array=False) + res100 = nt.extract_neo_attributes(obj, parents=False, child_first=True) + res200 = nt.extract_neo_attributes(obj, parents=False, child_first=False) + res010 = nt.extract_neo_attributes(obj, parents=False, skip_array=False) res110 = nt.extract_neo_attributes( - obj, parents=False, skip_array=False, child_first=True) + obj, parents=False, skip_array=False, child_first=True + ) res210 = nt.extract_neo_attributes( - obj, parents=False, skip_array=False, child_first=False) + obj, parents=False, skip_array=False, child_first=False + ) res001 = nt.extract_neo_attributes(obj, parents=True) res101 = nt.extract_neo_attributes(obj, parents=True, child_first=True) - res201 = nt.extract_neo_attributes( - obj, parents=True, child_first=False) + res201 = nt.extract_neo_attributes(obj, parents=True, child_first=False) res011 = nt.extract_neo_attributes(obj, parents=True, skip_array=False) - res111 = nt.extract_neo_attributes(obj, parents=True, skip_array=False, - child_first=True) - res211 = nt.extract_neo_attributes(obj, parents=True, skip_array=False, - child_first=False) + res111 = nt.extract_neo_attributes( + obj, parents=True, skip_array=False, child_first=True + ) + res211 = nt.extract_neo_attributes( + obj, parents=True, skip_array=False, child_first=False + ) self.assert_dicts_equal(targ, res000) self.assert_dicts_equal(targ, res100) @@ -576,91 +600,98 @@ def test__extract_neo_attrs__event_parents_empty_array(self): self.assert_dicts_equal(targ, res211) def 
test__extract_neo_attrs__spiketrain_noparents_noarray(self): - obj = self.block.list_children_by_class('SpikeTrain')[0] + obj = self.block.list_children_by_class("SpikeTrain")[0] targ = copy.deepcopy(obj.annotations) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.SpikeTrain._necessary_attrs + - neo.SpikeTrain._recommended_attrs): + neo.SpikeTrain._necessary_attrs + neo.SpikeTrain._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ = strip_iter_values(targ) res0 = nt.extract_neo_attributes(obj, parents=False, skip_array=True) - res1 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=True) - res2 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=False) + res1 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=True + ) + res2 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=False + ) self.assertEqual(targ, res0) self.assertEqual(targ, res1) self.assertEqual(targ, res2) def test__extract_neo_attrs__epoch_noparents_noarray(self): - obj = self.block.list_children_by_class('Epoch')[0] + obj = self.block.list_children_by_class("Epoch")[0] targ = copy.deepcopy(obj.annotations) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.Epoch._necessary_attrs + - neo.Epoch._recommended_attrs): + neo.Epoch._necessary_attrs + neo.Epoch._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) # 'times' is not in obj._necessary_attrs + obj._recommended_attrs targ = strip_iter_values(targ) res0 = nt.extract_neo_attributes(obj, parents=False, skip_array=True) - res1 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=True) - res2 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=False) + res1 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=True + ) + res2 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=False + ) self.assertEqual(targ, res0) self.assertEqual(targ, res1) self.assertEqual(targ, res2) def test__extract_neo_attrs__event_noparents_noarray(self): - obj = self.block.list_children_by_class('Event')[0] + obj = self.block.list_children_by_class("Event")[0] targ = copy.deepcopy(obj.annotations) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.Event._necessary_attrs + - neo.Event._recommended_attrs): + neo.Event._necessary_attrs + neo.Event._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ = strip_iter_values(targ) res0 = nt.extract_neo_attributes(obj, parents=False, skip_array=True) - res1 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=True) - res2 = nt.extract_neo_attributes(obj, parents=False, skip_array=True, - child_first=False) + res1 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=True + ) + res2 = nt.extract_neo_attributes( + obj, parents=False, skip_array=True, child_first=False + ) self.assertEqual(targ, res0) self.assertEqual(targ, res1) self.assertEqual(targ, res2) def test__extract_neo_attrs__spiketrain_noparents_array(self): - obj = self.block.list_children_by_class('SpikeTrain')[0] + obj = self.block.list_children_by_class("SpikeTrain")[0] targ = copy.deepcopy(obj.annotations) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.SpikeTrain._necessary_attrs 
+ - neo.SpikeTrain._recommended_attrs): + neo.SpikeTrain._necessary_attrs + neo.SpikeTrain._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) # 'times' is not in obj._necessary_attrs + obj._recommended_attrs - del targ['times'] + del targ["times"] res00 = nt.extract_neo_attributes(obj, parents=False, skip_array=False) - res10 = nt.extract_neo_attributes(obj, parents=False, skip_array=False, - child_first=True) - res20 = nt.extract_neo_attributes(obj, parents=False, skip_array=False, - child_first=False) + res10 = nt.extract_neo_attributes( + obj, parents=False, skip_array=False, child_first=True + ) + res20 = nt.extract_neo_attributes( + obj, parents=False, skip_array=False, child_first=False + ) res01 = nt.extract_neo_attributes(obj, parents=False) res11 = nt.extract_neo_attributes(obj, parents=False, child_first=True) - res21 = nt.extract_neo_attributes( - obj, parents=False, child_first=False) + res21 = nt.extract_neo_attributes(obj, parents=False, child_first=False) self.assert_dicts_equal(targ, res00) self.assert_dicts_equal(targ, res10) @@ -670,27 +701,28 @@ def test__extract_neo_attrs__spiketrain_noparents_array(self): self.assert_dicts_equal(targ, res21) def test__extract_neo_attrs__epoch_noparents_array(self): - obj = self.block.list_children_by_class('Epoch')[0] + obj = self.block.list_children_by_class("Epoch")[0] targ = copy.deepcopy(obj.annotations) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.Epoch._necessary_attrs + - neo.Epoch._recommended_attrs): + neo.Epoch._necessary_attrs + neo.Epoch._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) # 'times' is not in obj._necessary_attrs + obj._recommended_attrs - del targ['times'] + del targ["times"] res00 = nt.extract_neo_attributes(obj, parents=False, skip_array=False) - res10 = nt.extract_neo_attributes(obj, parents=False, skip_array=False, - child_first=True) - res20 = nt.extract_neo_attributes(obj, parents=False, skip_array=False, - child_first=False) + res10 = nt.extract_neo_attributes( + obj, parents=False, skip_array=False, child_first=True + ) + res20 = nt.extract_neo_attributes( + obj, parents=False, skip_array=False, child_first=False + ) res01 = nt.extract_neo_attributes(obj, parents=False) res11 = nt.extract_neo_attributes(obj, parents=False, child_first=True) - res21 = nt.extract_neo_attributes( - obj, parents=False, child_first=False) + res21 = nt.extract_neo_attributes(obj, parents=False, child_first=False) self.assert_dicts_equal(targ, res00) self.assert_dicts_equal(targ, res10) @@ -700,27 +732,28 @@ def test__extract_neo_attrs__epoch_noparents_array(self): self.assert_dicts_equal(targ, res21) def test__extract_neo_attrs__event_noparents_array(self): - obj = self.block.list_children_by_class('Event')[0] + obj = self.block.list_children_by_class("Event")[0] targ = copy.deepcopy(obj.annotations) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.Event._necessary_attrs + - neo.Event._recommended_attrs): + neo.Event._necessary_attrs + neo.Event._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) # 'times' is not in obj._necessary_attrs + obj._recommended_attrs - del targ['times'] + del targ["times"] res00 = nt.extract_neo_attributes(obj, parents=False, skip_array=False) - res10 = nt.extract_neo_attributes(obj, parents=False, skip_array=False, - child_first=True) - res20 = nt.extract_neo_attributes(obj, parents=False, skip_array=False, - child_first=False) + res10 = 
nt.extract_neo_attributes( + obj, parents=False, skip_array=False, child_first=True + ) + res20 = nt.extract_neo_attributes( + obj, parents=False, skip_array=False, child_first=False + ) res01 = nt.extract_neo_attributes(obj, parents=False) res11 = nt.extract_neo_attributes(obj, parents=False, child_first=True) - res21 = nt.extract_neo_attributes( - obj, parents=False, child_first=False) + res21 = nt.extract_neo_attributes(obj, parents=False, child_first=False) self.assert_dicts_equal(targ, res00) self.assert_dicts_equal(targ, res10) @@ -730,223 +763,230 @@ def test__extract_neo_attrs__event_noparents_array(self): self.assert_dicts_equal(targ, res21) def test__extract_neo_attrs__spiketrain_parents_childfirst_noarray(self): - obj = self.block.list_children_by_class('SpikeTrain')[0] + obj = self.block.list_children_by_class("SpikeTrain")[0] blk = self.block seg = self.block.segments[0] targ = copy.deepcopy(blk.annotations) for i, attr in enumerate( - neo.Block._necessary_attrs + - neo.Block._recommended_attrs): + neo.Block._necessary_attrs + neo.Block._recommended_attrs + ): targ[attr[0]] = getattr(blk, attr[0]) targ.update(copy.deepcopy(seg.annotations)) for i, attr in enumerate( - neo.Segment._necessary_attrs + - neo.Segment._recommended_attrs): + neo.Segment._necessary_attrs + neo.Segment._recommended_attrs + ): targ[attr[0]] = getattr(seg, attr[0]) targ.update(copy.deepcopy(obj.annotations)) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.SpikeTrain._necessary_attrs + - neo.SpikeTrain._recommended_attrs): + neo.SpikeTrain._necessary_attrs + neo.SpikeTrain._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ = strip_iter_values(targ) res0 = nt.extract_neo_attributes(obj, parents=True, skip_array=True) - res1 = nt.extract_neo_attributes(obj, parents=True, skip_array=True, - child_first=True) + res1 = nt.extract_neo_attributes( + obj, parents=True, skip_array=True, child_first=True + ) self.assertEqual(targ, res0) self.assertEqual(targ, res1) def test__extract_neo_attrs__epoch_parents_childfirst_noarray(self): - obj = self.block.list_children_by_class('Epoch')[0] + obj = self.block.list_children_by_class("Epoch")[0] blk = self.block seg = self.block.segments[0] targ = copy.deepcopy(blk.annotations) for i, attr in enumerate( - neo.Block._necessary_attrs + - neo.Block._recommended_attrs): + neo.Block._necessary_attrs + neo.Block._recommended_attrs + ): targ[attr[0]] = getattr(blk, attr[0]) targ.update(copy.deepcopy(seg.annotations)) for i, attr in enumerate( - neo.Segment._necessary_attrs + - neo.Segment._recommended_attrs): + neo.Segment._necessary_attrs + neo.Segment._recommended_attrs + ): targ[attr[0]] = getattr(seg, attr[0]) targ.update(copy.deepcopy(obj.annotations)) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.Epoch._necessary_attrs + - neo.Epoch._recommended_attrs): + neo.Epoch._necessary_attrs + neo.Epoch._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ = strip_iter_values(targ) res0 = nt.extract_neo_attributes(obj, parents=True, skip_array=True) - res1 = nt.extract_neo_attributes(obj, parents=True, skip_array=True, - child_first=True) + res1 = nt.extract_neo_attributes( + obj, parents=True, skip_array=True, child_first=True + ) self.assertEqual(targ, res0) self.assertEqual(targ, res1) def test__extract_neo_attrs__event_parents_childfirst_noarray(self): - obj = self.block.list_children_by_class('Event')[0] + obj = 
self.block.list_children_by_class("Event")[0] blk = self.block seg = self.block.segments[0] targ = copy.deepcopy(blk.annotations) for i, attr in enumerate( - neo.Block._necessary_attrs + - neo.Block._recommended_attrs): + neo.Block._necessary_attrs + neo.Block._recommended_attrs + ): targ[attr[0]] = getattr(blk, attr[0]) targ.update(copy.deepcopy(seg.annotations)) for i, attr in enumerate( - neo.Segment._necessary_attrs + - neo.Segment._recommended_attrs): + neo.Segment._necessary_attrs + neo.Segment._recommended_attrs + ): targ[attr[0]] = getattr(seg, attr[0]) targ.update(copy.deepcopy(obj.annotations)) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.Event._necessary_attrs + - neo.Event._recommended_attrs): + neo.Event._necessary_attrs + neo.Event._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ = strip_iter_values(targ) res0 = nt.extract_neo_attributes(obj, parents=True, skip_array=True) - res1 = nt.extract_neo_attributes(obj, parents=True, skip_array=True, - child_first=True) + res1 = nt.extract_neo_attributes( + obj, parents=True, skip_array=True, child_first=True + ) self.assertEqual(targ, res0) self.assertEqual(targ, res1) def test__extract_neo_attrs__spiketrain_parents_parentfirst_noarray(self): - obj = self.block.list_children_by_class('SpikeTrain')[0] + obj = self.block.list_children_by_class("SpikeTrain")[0] blk = self.block seg = self.block.segments[0] targ = copy.deepcopy(obj.annotations) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.SpikeTrain._necessary_attrs + - neo.SpikeTrain._recommended_attrs): + neo.SpikeTrain._necessary_attrs + neo.SpikeTrain._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ.update(copy.deepcopy(seg.annotations)) for i, attr in enumerate( - neo.Segment._necessary_attrs + - neo.Segment._recommended_attrs): + neo.Segment._necessary_attrs + neo.Segment._recommended_attrs + ): targ[attr[0]] = getattr(seg, attr[0]) targ.update(copy.deepcopy(blk.annotations)) for i, attr in enumerate( - neo.Block._necessary_attrs + - neo.Block._recommended_attrs): + neo.Block._necessary_attrs + neo.Block._recommended_attrs + ): targ[attr[0]] = getattr(blk, attr[0]) targ = strip_iter_values(targ) - res0 = nt.extract_neo_attributes(obj, parents=True, skip_array=True, - child_first=False) + res0 = nt.extract_neo_attributes( + obj, parents=True, skip_array=True, child_first=False + ) self.assertEqual(targ, res0) def test__extract_neo_attrs__epoch_parents_parentfirst_noarray(self): - obj = self.block.list_children_by_class('Epoch')[0] + obj = self.block.list_children_by_class("Epoch")[0] blk = self.block seg = self.block.segments[0] targ = copy.deepcopy(obj.annotations) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.Epoch._necessary_attrs + - neo.Epoch._recommended_attrs): + neo.Epoch._necessary_attrs + neo.Epoch._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ.update(copy.deepcopy(seg.annotations)) for i, attr in enumerate( - neo.Segment._necessary_attrs + - neo.Segment._recommended_attrs): + neo.Segment._necessary_attrs + neo.Segment._recommended_attrs + ): targ[attr[0]] = getattr(seg, attr[0]) targ.update(copy.deepcopy(blk.annotations)) for i, attr in enumerate( - neo.Block._necessary_attrs + - neo.Block._recommended_attrs): + neo.Block._necessary_attrs + neo.Block._recommended_attrs + ): targ[attr[0]] = getattr(blk, attr[0]) targ = strip_iter_values(targ) - 
res0 = nt.extract_neo_attributes(obj, parents=True, skip_array=True, - child_first=False) + res0 = nt.extract_neo_attributes( + obj, parents=True, skip_array=True, child_first=False + ) self.assertEqual(targ, res0) def test__extract_neo_attrs__event_parents_parentfirst_noarray(self): - obj = self.block.list_children_by_class('Event')[0] + obj = self.block.list_children_by_class("Event")[0] blk = self.block seg = self.block.segments[0] targ = copy.deepcopy(obj.annotations) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.Event._necessary_attrs + - neo.Event._recommended_attrs): + neo.Event._necessary_attrs + neo.Event._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ.update(copy.deepcopy(seg.annotations)) for i, attr in enumerate( - neo.Segment._necessary_attrs + - neo.Segment._recommended_attrs): + neo.Segment._necessary_attrs + neo.Segment._recommended_attrs + ): targ[attr[0]] = getattr(seg, attr[0]) targ.update(copy.deepcopy(blk.annotations)) for i, attr in enumerate( - neo.Block._necessary_attrs + - neo.Block._recommended_attrs): + neo.Block._necessary_attrs + neo.Block._recommended_attrs + ): targ[attr[0]] = getattr(blk, attr[0]) targ = strip_iter_values(targ) - res0 = nt.extract_neo_attributes(obj, parents=True, skip_array=True, - child_first=False) + res0 = nt.extract_neo_attributes( + obj, parents=True, skip_array=True, child_first=False + ) self.assertEqual(targ, res0) def test__extract_neo_attrs__spiketrain_parents_childfirst_array(self): - obj = self.block.list_children_by_class('SpikeTrain')[0] + obj = self.block.list_children_by_class("SpikeTrain")[0] blk = self.block seg = self.block.segments[0] targ = copy.deepcopy(blk.annotations) for i, attr in enumerate( - neo.Block._necessary_attrs + - neo.Block._recommended_attrs): + neo.Block._necessary_attrs + neo.Block._recommended_attrs + ): targ[attr[0]] = getattr(blk, attr[0]) targ.update(copy.deepcopy(seg.annotations)) for i, attr in enumerate( - neo.Segment._necessary_attrs + - neo.Segment._recommended_attrs): + neo.Segment._necessary_attrs + neo.Segment._recommended_attrs + ): targ[attr[0]] = getattr(seg, attr[0]) targ.update(copy.deepcopy(obj.annotations)) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.SpikeTrain._necessary_attrs + - neo.SpikeTrain._recommended_attrs): + neo.SpikeTrain._necessary_attrs + neo.SpikeTrain._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) - del targ['times'] + del targ["times"] res00 = nt.extract_neo_attributes(obj, parents=True, skip_array=False) - res10 = nt.extract_neo_attributes(obj, parents=True, skip_array=False, - child_first=True) + res10 = nt.extract_neo_attributes( + obj, parents=True, skip_array=False, child_first=True + ) res01 = nt.extract_neo_attributes(obj, parents=True) res11 = nt.extract_neo_attributes(obj, parents=True, child_first=True) @@ -956,34 +996,35 @@ def test__extract_neo_attrs__spiketrain_parents_childfirst_array(self): self.assert_dicts_equal(targ, res11) def test__extract_neo_attrs__epoch_parents_childfirst_array(self): - obj = self.block.list_children_by_class('Epoch')[0] + obj = self.block.list_children_by_class("Epoch")[0] blk = self.block seg = self.block.segments[0] targ = copy.deepcopy(blk.annotations) for i, attr in enumerate( - neo.Block._necessary_attrs + - neo.Block._recommended_attrs): + neo.Block._necessary_attrs + neo.Block._recommended_attrs + ): targ[attr[0]] = getattr(blk, attr[0]) 
targ.update(copy.deepcopy(seg.annotations)) for i, attr in enumerate( - neo.Segment._necessary_attrs + - neo.Segment._recommended_attrs): + neo.Segment._necessary_attrs + neo.Segment._recommended_attrs + ): targ[attr[0]] = getattr(seg, attr[0]) targ.update(copy.deepcopy(obj.annotations)) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.Epoch._necessary_attrs + - neo.Epoch._recommended_attrs): + neo.Epoch._necessary_attrs + neo.Epoch._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) - del targ['times'] + del targ["times"] res00 = nt.extract_neo_attributes(obj, parents=True, skip_array=False) - res10 = nt.extract_neo_attributes(obj, parents=True, skip_array=False, - child_first=True) + res10 = nt.extract_neo_attributes( + obj, parents=True, skip_array=False, child_first=True + ) res01 = nt.extract_neo_attributes(obj, parents=True) res11 = nt.extract_neo_attributes(obj, parents=True, child_first=True) @@ -993,34 +1034,35 @@ def test__extract_neo_attrs__epoch_parents_childfirst_array(self): self.assert_dicts_equal(targ, res11) def test__extract_neo_attrs__event_parents_childfirst_array(self): - obj = self.block.list_children_by_class('Event')[0] + obj = self.block.list_children_by_class("Event")[0] blk = self.block seg = self.block.segments[0] targ = copy.deepcopy(blk.annotations) for i, attr in enumerate( - neo.Block._necessary_attrs + - neo.Block._recommended_attrs): + neo.Block._necessary_attrs + neo.Block._recommended_attrs + ): targ[attr[0]] = getattr(blk, attr[0]) targ.update(copy.deepcopy(seg.annotations)) for i, attr in enumerate( - neo.Segment._necessary_attrs + - neo.Segment._recommended_attrs): + neo.Segment._necessary_attrs + neo.Segment._recommended_attrs + ): targ[attr[0]] = getattr(seg, attr[0]) targ.update(copy.deepcopy(obj.annotations)) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.Event._necessary_attrs + - neo.Event._recommended_attrs): + neo.Event._necessary_attrs + neo.Event._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) - del targ['times'] + del targ["times"] res00 = nt.extract_neo_attributes(obj, parents=True, skip_array=False) - res10 = nt.extract_neo_attributes(obj, parents=True, skip_array=False, - child_first=True) + res10 = nt.extract_neo_attributes( + obj, parents=True, skip_array=False, child_first=True + ) res01 = nt.extract_neo_attributes(obj, parents=True) res11 = nt.extract_neo_attributes(obj, parents=True, child_first=True) @@ -1030,100 +1072,102 @@ def test__extract_neo_attrs__event_parents_childfirst_array(self): self.assert_dicts_equal(targ, res11) def test__extract_neo_attrs__spiketrain_parents_parentfirst_array(self): - obj = self.block.list_children_by_class('SpikeTrain')[0] + obj = self.block.list_children_by_class("SpikeTrain")[0] blk = self.block seg = self.block.segments[0] targ = copy.deepcopy(obj.annotations) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.SpikeTrain._necessary_attrs + - neo.SpikeTrain._recommended_attrs): + neo.SpikeTrain._necessary_attrs + neo.SpikeTrain._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ.update(copy.deepcopy(seg.annotations)) for i, attr in enumerate( - neo.Segment._necessary_attrs + - neo.Segment._recommended_attrs): + neo.Segment._necessary_attrs + neo.Segment._recommended_attrs + ): targ[attr[0]] = getattr(seg, attr[0]) targ.update(copy.deepcopy(blk.annotations)) for i, attr in enumerate( - 
neo.Block._necessary_attrs + - neo.Block._recommended_attrs): + neo.Block._necessary_attrs + neo.Block._recommended_attrs + ): targ[attr[0]] = getattr(blk, attr[0]) - del targ['times'] + del targ["times"] - res0 = nt.extract_neo_attributes(obj, parents=True, skip_array=False, - child_first=False) + res0 = nt.extract_neo_attributes( + obj, parents=True, skip_array=False, child_first=False + ) res1 = nt.extract_neo_attributes(obj, parents=True, child_first=False) self.assert_dicts_equal(targ, res0) self.assert_dicts_equal(targ, res1) def test__extract_neo_attrs__epoch_parents_parentfirst_array(self): - obj = self.block.list_children_by_class('Epoch')[0] + obj = self.block.list_children_by_class("Epoch")[0] blk = self.block seg = self.block.segments[0] targ = copy.deepcopy(obj.annotations) targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.Epoch._necessary_attrs + - neo.Epoch._recommended_attrs): + neo.Epoch._necessary_attrs + neo.Epoch._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ.update(copy.deepcopy(seg.annotations)) for i, attr in enumerate( - neo.Segment._necessary_attrs + - neo.Segment._recommended_attrs): + neo.Segment._necessary_attrs + neo.Segment._recommended_attrs + ): targ[attr[0]] = getattr(seg, attr[0]) targ.update(copy.deepcopy(blk.annotations)) for i, attr in enumerate( - neo.Block._necessary_attrs + - neo.Block._recommended_attrs): + neo.Block._necessary_attrs + neo.Block._recommended_attrs + ): targ[attr[0]] = getattr(blk, attr[0]) - del targ['times'] + del targ["times"] - res0 = nt.extract_neo_attributes(obj, parents=True, skip_array=False, - child_first=False) + res0 = nt.extract_neo_attributes( + obj, parents=True, skip_array=False, child_first=False + ) res1 = nt.extract_neo_attributes(obj, parents=True, child_first=False) self.assert_dicts_equal(targ, res0) self.assert_dicts_equal(targ, res1) def test__extract_neo_attrs__event_parents_parentfirst_array(self): - obj = self.block.list_children_by_class('Event')[0] + obj = self.block.list_children_by_class("Event")[0] blk = self.block seg = self.block.segments[0] targ = copy.deepcopy(obj.annotations) - targ["array_annotations"] = copy.deepcopy( - dict(obj.array_annotations)) + targ["array_annotations"] = copy.deepcopy(dict(obj.array_annotations)) for i, attr in enumerate( - neo.Event._necessary_attrs + - neo.Event._recommended_attrs): + neo.Event._necessary_attrs + neo.Event._recommended_attrs + ): targ[attr[0]] = getattr(obj, attr[0]) targ.update(copy.deepcopy(seg.annotations)) for i, attr in enumerate( - neo.Segment._necessary_attrs + - neo.Segment._recommended_attrs): + neo.Segment._necessary_attrs + neo.Segment._recommended_attrs + ): targ[attr[0]] = getattr(seg, attr[0]) targ.update(copy.deepcopy(blk.annotations)) for i, attr in enumerate( - neo.Block._necessary_attrs + - neo.Block._recommended_attrs): + neo.Block._necessary_attrs + neo.Block._recommended_attrs + ): targ[attr[0]] = getattr(blk, attr[0]) - del targ['times'] + del targ["times"] - res0 = nt.extract_neo_attributes(obj, parents=True, skip_array=False, - child_first=False) + res0 = nt.extract_neo_attributes( + obj, parents=True, skip_array=False, child_first=False + ) res1 = nt.extract_neo_attributes(obj, parents=True, child_first=False) self.assert_dicts_equal(targ, res0) @@ -1166,8 +1210,7 @@ def test__get_all_spiketrains__segment(self): # Generate a simple segment object containing one spike train, # supporting objects of type Segment and SpikeTrain. 
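The `get_all_spiketrains` tests that follow all check the same contract: every `neo.SpikeTrain` reachable from the input container (a Segment, a Block, or a list/tuple/dict/iterator of Blocks) is collected into one `SpikeTrainList`, and a SpikeTrain object placed in several segments is reported only once, which is what the length comparisons rely on. A small sketch under those assumptions (names and values are illustrative):

import neo
import quantities as pq
import elephant.neo_tools as nt

# One SpikeTrain object shared by two segments of the same block.
shared = neo.SpikeTrain([0.2, 0.7] * pq.s, t_stop=1.0 * pq.s)
seg_a, seg_b = neo.Segment(name="a"), neo.Segment(name="b")
seg_a.spiketrains.append(shared)
seg_b.spiketrains.append(shared)
block = neo.Block()
block.segments.extend([seg_a, seg_b])

all_sts = nt.get_all_spiketrains(block)
# Expected: the shared SpikeTrain appears once, as in the tests' length checks.
print(len(all_sts))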
obj = generate_one_simple_segment( - nb_spiketrain=1, - supported_objects=[neo.core.Segment, neo.core.SpikeTrain] + nb_spiketrain=1, supported_objects=[neo.core.Segment, neo.core.SpikeTrain] ) # Append a deep copy of the first spike train in the segment's # spike train list to itself. @@ -1183,9 +1226,7 @@ def test__get_all_spiketrains__block(self): # Generate a simple block with 3 segments obj = generate_one_simple_block( nb_segment=3, - supported_objects=[neo.core.Block, - neo.core.Segment, - neo.core.SpikeTrain] + supported_objects=[neo.core.Block, neo.core.Segment, neo.core.SpikeTrain], ) # Deep copy the generated block for comparison @@ -1200,15 +1241,17 @@ def test__get_all_spiketrains__block(self): res0 = nt.get_all_spiketrains(obj) # Convert the target deep copy to a SpikeTrainList - targ = SpikeTrainList(targ.list_children_by_class('SpikeTrain')) + targ = SpikeTrainList(targ.list_children_by_class("SpikeTrain")) # Perform assertions to validate the results self.assertTrue( - len(res0) > 0, - "The result of get_all_spiketrains should not be empty.") + len(res0) > 0, "The result of get_all_spiketrains should not be empty." + ) self.assertEqual( - len(targ), len(res0), - "The lengths of the SpikeTrainList and result should be equal.") + len(targ), + len(res0), + "The lengths of the SpikeTrainList and result should be equal.", + ) assert_same_sub_schema(targ, res0) def test__get_all_spiketrains__list(self): @@ -1216,15 +1259,20 @@ def test__get_all_spiketrains__list(self): generate_one_simple_block( nb_segment=3, supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.SpikeTrain]) - for _ in range(3)] + neo.core.Block, + neo.core.Segment, + neo.core.SpikeTrain, + ], + ) + for _ in range(3) + ] targ = copy.deepcopy(obj) iobj2 = obj[1].segments[2].spiketrains[1] obj[2].segments[1].spiketrains.append(iobj2) obj.append(obj[-1]) res0 = nt.get_all_spiketrains(obj) - targ = [iobj.list_children_by_class('SpikeTrain') for iobj in targ] + targ = [iobj.list_children_by_class("SpikeTrain") for iobj in targ] targ = SpikeTrainList(list(chain.from_iterable(targ))) self.assertTrue(len(res0) > 0) @@ -1238,8 +1286,13 @@ def test__get_all_spiketrains__tuple(self): generate_one_simple_block( nb_segment=3, supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.SpikeTrain]) - for _ in range(3)] + neo.core.Block, + neo.core.Segment, + neo.core.SpikeTrain, + ], + ) + for _ in range(3) + ] targ = copy.deepcopy(obj) obj.append(obj[-1]) iobj2 = obj[1].segments[2].spiketrains[1] @@ -1247,7 +1300,7 @@ def test__get_all_spiketrains__tuple(self): obj.append(obj[-1]) res0 = nt.get_all_spiketrains(tuple(obj)) - targ = [iobj.list_children_by_class('SpikeTrain') for iobj in targ] + targ = [iobj.list_children_by_class("SpikeTrain") for iobj in targ] targ = SpikeTrainList(list(chain.from_iterable(targ))) self.assertTrue(len(res0) > 0) @@ -1261,8 +1314,13 @@ def test__get_all_spiketrains__iter(self): generate_one_simple_block( nb_segment=3, supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.SpikeTrain]) - for _ in range(3)] + neo.core.Block, + neo.core.Segment, + neo.core.SpikeTrain, + ], + ) + for _ in range(3) + ] targ = copy.deepcopy(obj) iobj2 = obj[1].segments[2].spiketrains[1] obj[2].segments[1].spiketrains.append(iobj2) @@ -1270,7 +1328,7 @@ def test__get_all_spiketrains__iter(self): res0 = nt.get_all_spiketrains(obj) res0 = nt.get_all_spiketrains(iter(obj)) - targ = [iobj.list_children_by_class('SpikeTrain') for iobj in targ] + targ = [iobj.list_children_by_class("SpikeTrain") for 
iobj in targ] targ = SpikeTrainList(list(chain.from_iterable(targ))) self.assertTrue(len(res0) > 0) @@ -1284,8 +1342,13 @@ def test__get_all_spiketrains__dict(self): generate_one_simple_block( nb_segment=3, supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.SpikeTrain]) - for _ in range(3)] + neo.core.Block, + neo.core.Segment, + neo.core.SpikeTrain, + ], + ) + for _ in range(3) + ] targ = copy.deepcopy(obj) iobj2 = obj[1].segments[2].spiketrains[1] obj[2].segments[1].spiketrains.append(iobj2) @@ -1294,7 +1357,7 @@ def test__get_all_spiketrains__dict(self): obj = dict((i, iobj) for i, iobj in enumerate(obj)) res0 = nt.get_all_spiketrains(obj) - targ = [iobj.list_children_by_class('SpikeTrain') for iobj in targ] + targ = [iobj.list_children_by_class("SpikeTrain") for iobj in targ] targ = SpikeTrainList(list(chain.from_iterable(targ))) self.assertTrue(len(res0) > 0) @@ -1317,7 +1380,8 @@ def test__get_all_events__event(self): def test__get_all_events__segment(self): obj = generate_one_simple_segment( - supported_objects=[neo.core.Segment, neo.core.Event]) + supported_objects=[neo.core.Segment, neo.core.Event] + ) targ = copy.deepcopy(obj) res0 = nt.get_all_events(obj) @@ -1333,15 +1397,15 @@ def test__get_all_events__segment(self): def test__get_all_events__block(self): obj = generate_one_simple_block( nb_segment=3, - supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.Event]) + supported_objects=[neo.core.Block, neo.core.Segment, neo.core.Event], + ) targ = copy.deepcopy(obj) iobj2 = obj.segments[0].events[1] obj.segments[1].events.append(iobj2) res0 = nt.get_all_events(obj) - targ = targ.list_children_by_class('Event') + targ = targ.list_children_by_class("Event") self.assertTrue(len(res0) > 0) @@ -1353,9 +1417,10 @@ def test__get_all_events__list(self): obj = [ generate_one_simple_block( nb_segment=3, - supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.Event]) - for _ in range(3)] + supported_objects=[neo.core.Block, neo.core.Segment, neo.core.Event], + ) + for _ in range(3) + ] targ = copy.deepcopy(obj) obj.append(obj[-1]) iobj2 = obj[1].segments[2].events[1] @@ -1363,7 +1428,7 @@ def test__get_all_events__list(self): obj.append(obj[-1]) res0 = nt.get_all_events(obj) - targ = [iobj.list_children_by_class('Event') for iobj in targ] + targ = [iobj.list_children_by_class("Event") for iobj in targ] targ = list(chain.from_iterable(targ)) self.assertTrue(len(res0) > 0) @@ -1376,9 +1441,10 @@ def test__get_all_events__tuple(self): obj = [ generate_one_simple_block( nb_segment=3, - supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.Event]) - for _ in range(3)] + supported_objects=[neo.core.Block, neo.core.Segment, neo.core.Event], + ) + for _ in range(3) + ] targ = copy.deepcopy(obj) obj.append(obj[-1]) iobj2 = obj[1].segments[2].events[1] @@ -1386,7 +1452,7 @@ def test__get_all_events__tuple(self): obj.append(obj[0]) res0 = nt.get_all_events(tuple(obj)) - targ = [iobj.list_children_by_class('Event') for iobj in targ] + targ = [iobj.list_children_by_class("Event") for iobj in targ] targ = list(chain.from_iterable(targ)) self.assertTrue(len(res0) > 0) @@ -1399,9 +1465,10 @@ def test__get_all_events__iter(self): obj = [ generate_one_simple_block( nb_segment=3, - supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.Event]) - for _ in range(3)] + supported_objects=[neo.core.Block, neo.core.Segment, neo.core.Event], + ) + for _ in range(3) + ] targ = copy.deepcopy(obj) obj.append(obj[-1]) iobj2 = obj[1].segments[2].events[1] @@ 
-1409,7 +1476,7 @@ def test__get_all_events__iter(self): obj.append(obj[0]) res0 = nt.get_all_events(iter(obj)) - targ = [iobj.list_children_by_class('Event') for iobj in targ] + targ = [iobj.list_children_by_class("Event") for iobj in targ] targ = list(chain.from_iterable(targ)) self.assertTrue(len(res0) > 0) @@ -1422,9 +1489,10 @@ def test__get_all_events__dict(self): obj = [ generate_one_simple_block( nb_segment=3, - supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.Event]) - for _ in range(3)] + supported_objects=[neo.core.Block, neo.core.Segment, neo.core.Event], + ) + for _ in range(3) + ] targ = copy.deepcopy(obj) obj.append(obj[-1]) iobj2 = obj[1].segments[2].events[1] @@ -1433,7 +1501,7 @@ def test__get_all_events__dict(self): obj = dict((i, iobj) for i, iobj in enumerate(obj)) res0 = nt.get_all_events(obj) - targ = [iobj.list_children_by_class('Event') for iobj in targ] + targ = [iobj.list_children_by_class("Event") for iobj in targ] targ = list(chain.from_iterable(targ)) self.assertTrue(len(res0) > 0) @@ -1456,7 +1524,8 @@ def test__get_all_epochs__epoch(self): def test__get_all_epochs__segment(self): obj = generate_one_simple_segment( - supported_objects=[neo.core.Segment, neo.core.Epoch]) + supported_objects=[neo.core.Segment, neo.core.Epoch] + ) targ = copy.deepcopy(obj) res0 = nt.get_all_epochs(obj) @@ -1472,13 +1541,13 @@ def test__get_all_epochs__segment(self): def test__get_all_epochs__block(self): obj = generate_one_simple_block( nb_segment=3, - supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.Epoch]) + supported_objects=[neo.core.Block, neo.core.Segment, neo.core.Epoch], + ) targ = copy.deepcopy(obj) res0 = nt.get_all_epochs(obj) - targ = targ.list_children_by_class('Epoch') + targ = targ.list_children_by_class("Epoch") self.assertTrue(len(res0) > 0) @@ -1490,9 +1559,10 @@ def test__get_all_epochs__list(self): obj = [ generate_one_simple_block( nb_segment=3, - supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.Epoch]) - for _ in range(3)] + supported_objects=[neo.core.Block, neo.core.Segment, neo.core.Epoch], + ) + for _ in range(3) + ] targ = copy.deepcopy(obj) obj.append(obj[-1]) iobj2 = obj[1].segments[2].epochs[1] @@ -1500,7 +1570,7 @@ def test__get_all_epochs__list(self): obj.append(obj[-1]) res0 = nt.get_all_epochs(obj) - targ = [iobj.list_children_by_class('Epoch') for iobj in targ] + targ = [iobj.list_children_by_class("Epoch") for iobj in targ] targ = list(chain.from_iterable(targ)) self.assertTrue(len(res0) > 0) @@ -1513,9 +1583,10 @@ def test__get_all_epochs__tuple(self): obj = [ generate_one_simple_block( nb_segment=3, - supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.Epoch]) - for _ in range(3)] + supported_objects=[neo.core.Block, neo.core.Segment, neo.core.Epoch], + ) + for _ in range(3) + ] targ = copy.deepcopy(obj) obj.append(obj[-1]) iobj2 = obj[1].segments[2].epochs[1] @@ -1523,7 +1594,7 @@ def test__get_all_epochs__tuple(self): obj.append(obj[0]) res0 = nt.get_all_epochs(tuple(obj)) - targ = [iobj.list_children_by_class('Epoch') for iobj in targ] + targ = [iobj.list_children_by_class("Epoch") for iobj in targ] targ = list(chain.from_iterable(targ)) self.assertTrue(len(res0) > 0) @@ -1536,9 +1607,10 @@ def test__get_all_epochs__iter(self): obj = [ generate_one_simple_block( nb_segment=3, - supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.Epoch]) - for _ in range(3)] + supported_objects=[neo.core.Block, neo.core.Segment, neo.core.Epoch], + ) + for _ in range(3) + ] targ = 
copy.deepcopy(obj) obj.append(obj[-1]) iobj2 = obj[1].segments[2].epochs[1] @@ -1546,7 +1618,7 @@ def test__get_all_epochs__iter(self): obj.append(obj[0]) res0 = nt.get_all_epochs(iter(obj)) - targ = [iobj.list_children_by_class('Epoch') for iobj in targ] + targ = [iobj.list_children_by_class("Epoch") for iobj in targ] targ = list(chain.from_iterable(targ)) self.assertTrue(len(res0) > 0) @@ -1559,9 +1631,10 @@ def test__get_all_epochs__dict(self): obj = [ generate_one_simple_block( nb_segment=3, - supported_objects=[ - neo.core.Block, neo.core.Segment, neo.core.Epoch]) - for _ in range(3)] + supported_objects=[neo.core.Block, neo.core.Segment, neo.core.Epoch], + ) + for _ in range(3) + ] targ = copy.deepcopy(obj) obj.append(obj[-1]) iobj2 = obj[1].segments[2].epochs[1] @@ -1570,7 +1643,7 @@ def test__get_all_epochs__dict(self): obj = dict((i, iobj) for i, iobj in enumerate(obj)) res0 = nt.get_all_epochs(obj) - targ = [iobj.list_children_by_class('Epoch') for iobj in targ] + targ = [iobj.list_children_by_class("Epoch") for iobj in targ] targ = list(chain.from_iterable(targ)) self.assertTrue(len(res0) > 0) @@ -1580,5 +1653,5 @@ def test__get_all_epochs__dict(self): assert_same_sub_schema(targ, res0) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_parallel.py b/elephant/test/test_parallel.py index f8952e304..f0d7f5470 100644 --- a/elephant/test/test_parallel.py +++ b/elephant/test/test_parallel.py @@ -15,6 +15,7 @@ def setUpClass(cls): cls.executors_cls = [SingleProcess, ProcessPoolExecutor] try: from elephant.parallel.mpi import MPIPoolExecutor, MPICommExecutor + cls.executors_cls.extend([MPIPoolExecutor, MPICommExecutor]) except ImportError: # mpi4py is not installed @@ -24,7 +25,8 @@ def setUpClass(cls): n_spiketrains = 10 cls.spiketrains = tuple( StationaryPoissonProcess( - rate=10 * pq.Hz, t_stop=10 * pq.s).generate_spiketrain() + rate=10 * pq.Hz, t_stop=10 * pq.s + ).generate_spiketrain() for _ in range(n_spiketrains) ) cls.mean_fr = tuple(map(mean_firing_rate, cls.spiketrains)) @@ -33,10 +35,11 @@ def test_mean_firing_rate(self): for executor_cls in self.executors_cls: with self.subTest(executor_cls=executor_cls): executor = executor_cls() - mean_fr = executor.execute(handler=mean_firing_rate, - args_iterate=self.spiketrains) + mean_fr = executor.execute( + handler=mean_firing_rate, args_iterate=self.spiketrains + ) assert_array_almost_equal(mean_fr, self.mean_fr) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_phase_analysis.py b/elephant/test/test_phase_analysis.py index 6b8218c99..8df790f2b 100644 --- a/elephant/test/test_phase_analysis.py +++ b/elephant/test/test_phase_analysis.py @@ -5,6 +5,7 @@ :copyright: Copyright 2014-2024 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. """ + from __future__ import division, print_function import unittest @@ -20,31 +21,37 @@ class SpikeTriggeredPhaseTestCase(unittest.TestCase): - def setUp(self): tlen0 = 100 * pq.s - f0 = 20. 
* pq.Hz + f0 = 20.0 * pq.Hz fs0 = 1 * pq.ms - t0 = np.arange( - 0, tlen0.rescale(pq.s).magnitude, - fs0.rescale(pq.s).magnitude) * pq.s + t0 = ( + np.arange(0, tlen0.rescale(pq.s).magnitude, fs0.rescale(pq.s).magnitude) + * pq.s + ) self.anasig0 = AnalogSignal( np.sin(2 * np.pi * (f0 * t0).simplified.magnitude), - units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0) + units=pq.mV, + t_start=0 * pq.ms, + sampling_period=fs0, + ) self.st0 = SpikeTrain( np.arange(50, tlen0.rescale(pq.ms).magnitude - 50, 50) * pq.ms, - t_start=0 * pq.ms, t_stop=tlen0) + t_start=0 * pq.ms, + t_stop=tlen0, + ) self.st1 = SpikeTrain( - [100., 100.1, 100.2, 100.3, 100.9, 101.] * pq.ms, - t_start=0 * pq.ms, t_stop=tlen0) + [100.0, 100.1, 100.2, 100.3, 100.9, 101.0] * pq.ms, + t_start=0 * pq.ms, + t_stop=tlen0, + ) def test_perfect_locking_one_spiketrain_one_signal(self): phases, amps, times = elephant.phase_analysis.spike_triggered_phase( - elephant.signal_processing.hilbert(self.anasig0), - self.st0, - interpolate=True) + elephant.signal_processing.hilbert(self.anasig0), self.st0, interpolate=True + ) - assert_allclose(phases[0], - np.pi / 2.) + assert_allclose(phases[0], -np.pi / 2.0) assert_allclose(amps[0], 1, atol=0.1) assert_allclose(times[0].magnitude, self.st0.magnitude) self.assertEqual(len(phases[0]), len(self.st0)) @@ -55,11 +62,13 @@ def test_perfect_locking_many_spiketrains_many_signals(self): phases, amps, times = elephant.phase_analysis.spike_triggered_phase( [ elephant.signal_processing.hilbert(self.anasig0), - elephant.signal_processing.hilbert(self.anasig0)], + elephant.signal_processing.hilbert(self.anasig0), + ], [self.st0, self.st0], - interpolate=True) + interpolate=True, + ) - assert_allclose(phases[0], -np.pi / 2.) + assert_allclose(phases[0], -np.pi / 2.0) assert_allclose(amps[0], 1, atol=0.1) assert_allclose(times[0].magnitude, self.st0.magnitude) self.assertEqual(len(phases[0]), len(self.st0)) @@ -70,11 +79,13 @@ def test_perfect_locking_one_spiketrains_many_signals(self): phases, amps, times = elephant.phase_analysis.spike_triggered_phase( [ elephant.signal_processing.hilbert(self.anasig0), - elephant.signal_processing.hilbert(self.anasig0)], + elephant.signal_processing.hilbert(self.anasig0), + ], [self.st0], - interpolate=True) + interpolate=True, + ) - assert_allclose(phases[0], -np.pi / 2.) + assert_allclose(phases[0], -np.pi / 2.0) assert_allclose(amps[0], 1, atol=0.1) assert_allclose(times[0].magnitude, self.st0.magnitude) self.assertEqual(len(phases[0]), len(self.st0)) @@ -85,9 +96,10 @@ def test_perfect_locking_many_spiketrains_one_signal(self): phases, amps, times = elephant.phase_analysis.spike_triggered_phase( elephant.signal_processing.hilbert(self.anasig0), [self.st0, self.st0], - interpolate=True) + interpolate=True, + ) - assert_allclose(phases[0], -np.pi / 2.) 
+ assert_allclose(phases[0], -np.pi / 2.0) assert_allclose(amps[0], 1, atol=0.1) assert_allclose(times[0].magnitude, self.st0.magnitude) self.assertEqual(len(phases[0]), len(self.st0)) @@ -96,9 +108,8 @@ def test_perfect_locking_many_spiketrains_one_signal(self): def test_interpolate(self): phases_int, _, _ = elephant.phase_analysis.spike_triggered_phase( - elephant.signal_processing.hilbert(self.anasig0), - self.st1, - interpolate=True) + elephant.signal_processing.hilbert(self.anasig0), self.st1, interpolate=True + ) self.assertLess(phases_int[0][0], phases_int[0][1]) self.assertLess(phases_int[0][1], phases_int[0][2]) @@ -109,7 +120,8 @@ def test_interpolate(self): phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( elephant.signal_processing.hilbert(self.anasig0), self.st1, - interpolate=False) + interpolate=False, + ) self.assertEqual(phases_noint[0][0], phases_noint[0][1]) self.assertEqual(phases_noint[0][1], phases_noint[0][2]) @@ -125,65 +137,61 @@ def test_interpolate(self): def test_inconsistent_numbers_spiketrains_hilbert(self): self.assertRaises( - ValueError, elephant.phase_analysis.spike_triggered_phase, + ValueError, + elephant.phase_analysis.spike_triggered_phase, [ elephant.signal_processing.hilbert(self.anasig0), - elephant.signal_processing.hilbert(self.anasig0)], - [self.st0, self.st0, self.st0], False) + elephant.signal_processing.hilbert(self.anasig0), + ], + [self.st0, self.st0, self.st0], + False, + ) self.assertRaises( - ValueError, elephant.phase_analysis.spike_triggered_phase, + ValueError, + elephant.phase_analysis.spike_triggered_phase, [ elephant.signal_processing.hilbert(self.anasig0), - elephant.signal_processing.hilbert(self.anasig0)], - [self.st0, self.st0, self.st0], False) + elephant.signal_processing.hilbert(self.anasig0), + ], + [self.st0, self.st0, self.st0], + False, + ) def test_spike_earlier_than_hilbert(self): # This is a spike clearly outside the bounds - st = SpikeTrain( - [-50, 50], - units='s', t_start=-100 * pq.s, t_stop=100 * pq.s) + st = SpikeTrain([-50, 50], units="s", t_start=-100 * pq.s, t_stop=100 * pq.s) phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( - elephant.signal_processing.hilbert(self.anasig0), - st, - interpolate=False) + elephant.signal_processing.hilbert(self.anasig0), st, interpolate=False + ) self.assertEqual(len(phases_noint[0]), 1) # This is a spike right on the border (start of the signal is at 0s, # spike sits at t=0s). By definition of intervals in # Elephant (left borders inclusive, right borders exclusive), this # spike is to be considered. 
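The two border cases spelled out in these comments (left border inclusive, right border exclusive) can be condensed into one sketch; the test for the inclusive left border continues directly below, and the exclusive right border is covered in the next test method. Signal parameters here are illustrative, not the fixture values:

import numpy as np
import quantities as pq
from neo import AnalogSignal, SpikeTrain
import elephant.phase_analysis
import elephant.signal_processing

# A 100 s, 20 Hz sine sampled at 1 kHz, starting at t = 0 s.
t = np.arange(0, 100, 0.001)
signal = AnalogSignal(np.sin(2 * np.pi * 20 * t), units=pq.mV,
                      t_start=0 * pq.s, sampling_period=1 * pq.ms)
analytic = elephant.signal_processing.hilbert(signal)

# The spike at t = 0 s sits on the inclusive left border and is counted;
# a spike at t = 100 s would sit on the exclusive right border and be dropped.
st = SpikeTrain([0, 50] * pq.s, t_start=-100 * pq.s, t_stop=200 * pq.s)
phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
    analytic, st, interpolate=False
)
print(len(phases[0]))  # expected: 2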
- st = SpikeTrain( - [0, 50], - units='s', t_start=-100 * pq.s, t_stop=100 * pq.s) + st = SpikeTrain([0, 50], units="s", t_start=-100 * pq.s, t_stop=100 * pq.s) phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( - elephant.signal_processing.hilbert(self.anasig0), - st, - interpolate=False) + elephant.signal_processing.hilbert(self.anasig0), st, interpolate=False + ) self.assertEqual(len(phases_noint[0]), 2) def test_spike_later_than_hilbert(self): # This is a spike clearly outside the bounds - st = SpikeTrain( - [1, 250], - units='s', t_start=-1 * pq.s, t_stop=300 * pq.s) + st = SpikeTrain([1, 250], units="s", t_start=-1 * pq.s, t_stop=300 * pq.s) phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( - elephant.signal_processing.hilbert(self.anasig0), - st, - interpolate=False) + elephant.signal_processing.hilbert(self.anasig0), st, interpolate=False + ) self.assertEqual(len(phases_noint[0]), 1) # This is a spike right on the border (length of the signal is 100s, # spike sits at t=100s). However, by definition of intervals in # Elephant (left borders inclusive, right borders exclusive), this # spike is not to be considered. - st = SpikeTrain( - [1, 100], - units='s', t_start=-1 * pq.s, t_stop=200 * pq.s) + st = SpikeTrain([1, 100], units="s", t_start=-1 * pq.s, t_stop=200 * pq.s) phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( - elephant.signal_processing.hilbert(self.anasig0), - st, - interpolate=False) + elephant.signal_processing.hilbert(self.anasig0), st, interpolate=False + ) self.assertEqual(len(phases_noint[0]), 1) # This test handles the correct dealing with input signals that have @@ -193,13 +201,14 @@ def test_regression_269(self): # before the end of the signal cu = pq.CompoundUnit("1/30000.*s") st = SpikeTrain( - [30000., (self.anasig0.t_stop - 1 * pq.s).rescale(cu).magnitude], + [30000.0, (self.anasig0.t_stop - 1 * pq.s).rescale(cu).magnitude], units=pq.CompoundUnit("1/30000.*s"), - t_start=-1 * pq.s, t_stop=300 * pq.s) + t_start=-1 * pq.s, + t_stop=300 * pq.s, + ) phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( - elephant.signal_processing.hilbert(self.anasig0), - st, - interpolate=False) + elephant.signal_processing.hilbert(self.anasig0), st, interpolate=False + ) self.assertEqual(len(phases_noint[0]), 2) @@ -221,11 +230,9 @@ def testMeanVector_direction_is_phi_and_length_is_1(self): for a sample with all phases equal to phi on the unit circle. """ - theta_bar_1, r_1 = elephant.phase_analysis.mean_phase_vector( - self.dataset1) + theta_bar_1, r_1 = elephant.phase_analysis.mean_phase_vector(self.dataset1) # mean direction must be phi - self.assertAlmostEqual(theta_bar_1, self.lock_value_phi, - delta=self.tolerance) + self.assertAlmostEqual(theta_bar_1, self.lock_value_phi, delta=self.tolerance) # mean vector length must be almost equal 1 self.assertAlmostEqual(r_1, 1, delta=self.tolerance) @@ -234,8 +241,7 @@ def testMeanVector_length_is_0(self): Test if the mean vector length is 0 for a evenly spaced distribution on the unit circle. """ - theta_bar_2, r_2 = elephant.phase_analysis.mean_phase_vector( - self.dataset2) + theta_bar_2, r_2 = elephant.phase_analysis.mean_phase_vector(self.dataset2) # mean vector length must be almost equal 0 self.assertAlmostEqual(r_2, 0, delta=self.tolerance) @@ -245,8 +251,7 @@ def testMeanVector_ranges_of_direction_and_length(self): and is within (-pi, pi]. Test if the range of the mean vector length is within [0, 1]. 
""" - theta_bar_3, r_3 = elephant.phase_analysis.mean_phase_vector( - self.dataset3) + theta_bar_3, r_3 = elephant.phase_analysis.mean_phase_vector(self.dataset3) # mean vector direction self.assertTrue(-np.pi < theta_bar_3 <= np.pi) # mean vector length @@ -267,8 +272,7 @@ def testPhaseDifference_in_range_minus_pi_to_pi(self): beta = np.random.uniform(-np.pi, np.pi, self.n_samples) phase_diff = elephant.phase_analysis.phase_difference(alpha, beta) - self.assertTrue((-np.pi <= phase_diff).all() - and (phase_diff <= np.pi).all()) + self.assertTrue((-np.pi <= phase_diff).all() and (phase_diff <= np.pi).all()) def testPhaseDifference_is_delta(self): """ @@ -292,18 +296,18 @@ def setUp(self): self.num_trials = 100 # create two random uniform distributions (all trials are identical) - self.signal_x = \ - np.full([self.num_trials, self.num_time_points], - np.random.uniform(-np.pi, np.pi, self.num_time_points)) - self.signal_y = \ - np.full([self.num_trials, self.num_time_points], - np.random.uniform(-np.pi, np.pi, self.num_time_points)) + self.signal_x = np.full( + [self.num_trials, self.num_time_points], + np.random.uniform(-np.pi, np.pi, self.num_time_points), + ) + self.signal_y = np.full( + [self.num_trials, self.num_time_points], + np.random.uniform(-np.pi, np.pi, self.num_time_points), + ) # create two random uniform distributions, where all trails are random - self.random_x = np.random.uniform( - -np.pi, np.pi, (1000, self.num_time_points)) - self.random_y = np.random.uniform( - -np.pi, np.pi, (1000, self.num_time_points)) + self.random_x = np.random.uniform(-np.pi, np.pi, (1000, self.num_time_points)) + self.random_y = np.random.uniform(-np.pi, np.pi, (1000, self.num_time_points)) # simple samples of different shapes to assert ErrorRaising self.simple_x = np.array([[0, -np.pi, np.pi], [0, -np.pi, np.pi]]) @@ -316,12 +320,11 @@ def testPhaseLockingValue_identical_signals_both_identical_trials(self): trials are passed. PLV's needed to be 1, due to the constant phase difference of 0 across trials at each time-point. """ - list1_plv_t = \ - elephant.phase_analysis.phase_locking_value(self.signal_x, - self.signal_x) + list1_plv_t = elephant.phase_analysis.phase_locking_value( + self.signal_x, self.signal_x + ) target_plv_r_is_one = np.ones_like(list1_plv_t) - np.testing.assert_allclose(list1_plv_t, target_plv_r_is_one, - self.tolerance) + np.testing.assert_allclose(list1_plv_t, target_plv_r_is_one, self.tolerance) def testPhaseLockingValue_different_signals_both_identical_trials(self): """ @@ -331,10 +334,10 @@ def testPhaseLockingValue_different_signals_both_identical_trials(self): different time-points. """ list2_plv_t = elephant.phase_analysis.phase_locking_value( - self.signal_x, self.signal_y) + self.signal_x, self.signal_y + ) target_plv_r_is_one = np.ones_like(list2_plv_t) - np.testing.assert_allclose(list2_plv_t, target_plv_r_is_one, - atol=3e-15) + np.testing.assert_allclose(list2_plv_t, target_plv_r_is_one, atol=3e-15) def testPhaseLockingValue_different_signals_both_different_trials(self): """ @@ -344,11 +347,13 @@ def testPhaseLockingValue_different_signals_both_different_trials(self): phase difference across trials for each time-point. 
""" list3_plv_t = elephant.phase_analysis.phase_locking_value( - self.random_x, self.random_y) + self.random_x, self.random_y + ) target_plv_is_zero = np.zeros_like(list3_plv_t) # use default value from np.allclose() for atol=1e-8 to prevent failure - np.testing.assert_allclose(list3_plv_t, target_plv_is_zero, - rtol=1e-2, atol=1.1e-1) + np.testing.assert_allclose( + list3_plv_t, target_plv_is_zero, rtol=1e-2, atol=1.1e-1 + ) def testPhaseLockingValue_raise_Error_if_trial_number_is_different(self): """ @@ -357,8 +362,11 @@ def testPhaseLockingValue_raise_Error_if_trial_number_is_different(self): """ # different numbers of trails np.testing.assert_raises( - ValueError, elephant.phase_analysis.phase_locking_value, - self.simple_x, self.simple_y) + ValueError, + elephant.phase_analysis.phase_locking_value, + self.simple_x, + self.simple_y, + ) def testPhaseLockingValue_raise_Error_if_trial_lengths_are_different(self): """ @@ -367,8 +375,11 @@ def testPhaseLockingValue_raise_Error_if_trial_lengths_are_different(self): """ # different lengths in a trail pair np.testing.assert_raises( - ValueError, elephant.phase_analysis.phase_locking_value, - self.simple_y, self.simple_z) + ValueError, + elephant.phase_analysis.phase_locking_value, + self.simple_y, + self.simple_z, + ) class WeightedPhaseLagIndexTestCase(unittest.TestCase): @@ -385,24 +396,34 @@ def setUpClass(cls): # and then loaded/ read for each test function individually. # REAL DATA - real_data_path = "unittest/phase_analysis/weighted_phase_lag_index/" \ - "data/wpli_real_data" + real_data_path = ( + "unittest/phase_analysis/weighted_phase_lag_index/" "data/wpli_real_data" + ) cls.files_to_download_real = ( - ("i140703-001_ch01_slice_TS_ON_to_GO_ON_correct_trials.mat", - "0e76454c58208cab710e672d04de5168"), - ("i140703-001_ch02_slice_TS_ON_to_GO_ON_correct_trials.mat", - "b06059e5222e91eb640caad0aba15b7f"), - ("i140703-001_cross_spectrum_of_channel_1_and_2_of_slice_" - "TS_ON_to_GO_ON_corect_trials.mat", - "2687ef63a4a456971a5dcc621b02e9a9") + ( + "i140703-001_ch01_slice_TS_ON_to_GO_ON_correct_trials.mat", + "0e76454c58208cab710e672d04de5168", + ), + ( + "i140703-001_ch02_slice_TS_ON_to_GO_ON_correct_trials.mat", + "b06059e5222e91eb640caad0aba15b7f", + ), + ( + "i140703-001_cross_spectrum_of_channel_1_and_2_of_slice_" + "TS_ON_to_GO_ON_corect_trials.mat", + "2687ef63a4a456971a5dcc621b02e9a9", + ), ) for filename, checksum in cls.files_to_download_real: # files will be downloaded to ELEPHANT_TMP_DIR cls.tmp_path = download_datasets( - f"{real_data_path}/{filename}", checksum=checksum) + f"{real_data_path}/{filename}", checksum=checksum + ) # ARTIFICIAL DATA - artificial_data_path = "unittest/phase_analysis/" \ + artificial_data_path = ( + "unittest/phase_analysis/" "weighted_phase_lag_index/data/wpli_specific_artificial_dataset" + ) cls.files_to_download_artificial = ( ("artificial_LFPs_1.mat", "4b99b15f89c0b9a0eb6fc14e9009436f"), ("artificial_LFPs_2.mat", "7144976b5f871fa62f4a831f530deee4"), @@ -410,20 +431,28 @@ def setUpClass(cls): for filename, checksum in cls.files_to_download_artificial: # files will be downloaded to ELEPHANT_TMP_DIR cls.tmp_path = download_datasets( - f"{artificial_data_path}/{filename}", checksum=checksum) + f"{artificial_data_path}/{filename}", checksum=checksum + ) # GROUND TRUTH DATA - ground_truth_data_path = "unittest/phase_analysis/" \ - "weighted_phase_lag_index/data/wpli_ground_truth" + ground_truth_data_path = ( + "unittest/phase_analysis/" "weighted_phase_lag_index/data/wpli_ground_truth" + ) 
cls.files_to_download_ground_truth = ( - ("ground_truth_WPLI_from_ft_connectivity_wpli_" - "with_real_LFPs_R2G.csv", "4d9a7b7afab7d107023956077ab11fef"), - ("ground_truth_WPLI_from_ft_connectivity_wpli_" - "with_artificial_LFPs.csv", "92988f475333d7badbe06b3f23abe494"), + ( + "ground_truth_WPLI_from_ft_connectivity_wpli_" "with_real_LFPs_R2G.csv", + "4d9a7b7afab7d107023956077ab11fef", + ), + ( + "ground_truth_WPLI_from_ft_connectivity_wpli_" + "with_artificial_LFPs.csv", + "92988f475333d7badbe06b3f23abe494", + ), ) for filename, checksum in cls.files_to_download_ground_truth: # files will be downloaded into ELEPHANT_TMP_DIR cls.tmp_path = download_datasets( - f"{ground_truth_data_path}/{filename}", checksum=checksum) + f"{ground_truth_data_path}/{filename}", checksum=checksum + ) def setUp(self): self.tolerance = 1e-15 @@ -432,50 +461,60 @@ def setUp(self): # real LFP-dataset dataset1_real = scipy.io.loadmat( f"{self.tmp_path.parent}/{self.files_to_download_real[0][0]}", - squeeze_me=True) + squeeze_me=True, + ) dataset2_real = scipy.io.loadmat( f"{self.tmp_path.parent}/{self.files_to_download_real[1][0]}", - squeeze_me=True) + squeeze_me=True, + ) # get relevant values - self.lfps1_real = dataset1_real['lfp_matrix'] * pq.uV - self.sf1_real = dataset1_real['sf'] * pq.Hz - self.lfps2_real = dataset2_real['lfp_matrix'] * pq.uV - self.sf2_real = dataset2_real['sf'] * pq.Hz + self.lfps1_real = dataset1_real["lfp_matrix"] * pq.uV + self.sf1_real = dataset1_real["sf"] * pq.Hz + self.lfps2_real = dataset2_real["lfp_matrix"] * pq.uV + self.sf2_real = dataset2_real["sf"] * pq.Hz # create AnalogSignals from the real dataset self.lfps1_real_AnalogSignal = AnalogSignal( - signal=self.lfps1_real, sampling_rate=self.sf1_real) + signal=self.lfps1_real, sampling_rate=self.sf1_real + ) self.lfps2_real_AnalogSignal = AnalogSignal( - signal=self.lfps2_real, sampling_rate=self.sf2_real) + signal=self.lfps2_real, sampling_rate=self.sf2_real + ) # artificial LFP-dataset dataset1_artificial = scipy.io.loadmat( - f"{self.tmp_path.parent}/" - f"{self.files_to_download_artificial[0][0]}", squeeze_me=True) + f"{self.tmp_path.parent}/" f"{self.files_to_download_artificial[0][0]}", + squeeze_me=True, + ) dataset2_artificial = scipy.io.loadmat( - f"{self.tmp_path.parent}/" - f"{self.files_to_download_artificial[1][0]}", squeeze_me=True) + f"{self.tmp_path.parent}/" f"{self.files_to_download_artificial[1][0]}", + squeeze_me=True, + ) # get relevant values - self.lfps1_artificial = dataset1_artificial['lfp_matrix'] * pq.uV - self.sf1_artificial = dataset1_artificial['sf'] * pq.Hz - self.lfps2_artificial = dataset2_artificial['lfp_matrix'] * pq.uV - self.sf2_artificial = dataset2_artificial['sf'] * pq.Hz + self.lfps1_artificial = dataset1_artificial["lfp_matrix"] * pq.uV + self.sf1_artificial = dataset1_artificial["sf"] * pq.Hz + self.lfps2_artificial = dataset2_artificial["lfp_matrix"] * pq.uV + self.sf2_artificial = dataset2_artificial["sf"] * pq.Hz # create AnalogSignals from the artificial dataset self.lfps1_artificial_AnalogSignal = AnalogSignal( - signal=self.lfps1_artificial, sampling_rate=self.sf1_artificial) + signal=self.lfps1_artificial, sampling_rate=self.sf1_artificial + ) self.lfps2_artificial_AnalogSignal = AnalogSignal( - signal=self.lfps2_artificial, sampling_rate=self.sf2_artificial) + signal=self.lfps2_artificial, sampling_rate=self.sf2_artificial + ) # load ground-truth reference calculated by: # Matlab package 'FieldTrip': ft_connectivity_wpli() self.wpli_ground_truth_ft_connectivity_wpli_real = 
np.loadtxt( - f"{self.tmp_path.parent}/" - f"{self.files_to_download_ground_truth[0][0]}", - delimiter=',', dtype=np.float64) + f"{self.tmp_path.parent}/" f"{self.files_to_download_ground_truth[0][0]}", + delimiter=",", + dtype=np.float64, + ) self.wpli_ground_truth_ft_connectivity_artificial = np.loadtxt( - f"{self.tmp_path.parent}/" - f"{self.files_to_download_ground_truth[1][0]}", - delimiter=',', dtype=np.float64) + f"{self.tmp_path.parent}/" f"{self.files_to_download_ground_truth[1][0]}", + delimiter=",", + dtype=np.float64, + ) def test_WPLI_ground_truth_consistency_real_LFP_dataset(self): """ @@ -493,25 +532,27 @@ def test_WPLI_ground_truth_consistency_real_LFP_dataset(self): # Quantity-input with self.subTest(msg="Quantity input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( - self.lfps1_real, self.lfps2_real, self.sf1_real) + self.lfps1_real, self.lfps2_real, self.sf1_real + ) np.testing.assert_allclose( - wpli, self.wpli_ground_truth_ft_connectivity_wpli_real, - equal_nan=True) + wpli, self.wpli_ground_truth_ft_connectivity_wpli_real, equal_nan=True + ) # np.array-input with self.subTest(msg="np.array input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( - self.lfps1_real.magnitude, self.lfps2_real.magnitude, - self.sf1_real) + self.lfps1_real.magnitude, self.lfps2_real.magnitude, self.sf1_real + ) np.testing.assert_allclose( - wpli, self.wpli_ground_truth_ft_connectivity_wpli_real, - equal_nan=True) + wpli, self.wpli_ground_truth_ft_connectivity_wpli_real, equal_nan=True + ) # neo.AnalogSignal-input with self.subTest(msg="neo.AnalogSignal input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( - self.lfps1_real_AnalogSignal, self.lfps2_real_AnalogSignal) + self.lfps1_real_AnalogSignal, self.lfps2_real_AnalogSignal + ) np.testing.assert_allclose( - wpli, self.wpli_ground_truth_ft_connectivity_wpli_real, - equal_nan=True) + wpli, self.wpli_ground_truth_ft_connectivity_wpli_real, equal_nan=True + ) def test_WPLI_ground_truth_consistency_artificial_LFP_dataset(self): """ @@ -526,28 +567,47 @@ def test_WPLI_ground_truth_consistency_artificial_LFP_dataset(self): # Quantity-input with self.subTest(msg="Quantity input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( - self.lfps1_artificial, self.lfps2_artificial, - self.sf1_artificial, absolute_value=False) + self.lfps1_artificial, + self.lfps2_artificial, + self.sf1_artificial, + absolute_value=False, + ) np.testing.assert_allclose( - wpli, self.wpli_ground_truth_ft_connectivity_artificial, - atol=1e-14, rtol=1e-12, equal_nan=True) + wpli, + self.wpli_ground_truth_ft_connectivity_artificial, + atol=1e-14, + rtol=1e-12, + equal_nan=True, + ) # np.array-input with self.subTest(msg="np.array input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( self.lfps1_artificial.magnitude, - self.lfps2_artificial.magnitude, self.sf1_artificial, - absolute_value=False) + self.lfps2_artificial.magnitude, + self.sf1_artificial, + absolute_value=False, + ) np.testing.assert_allclose( - wpli, self.wpli_ground_truth_ft_connectivity_artificial, - atol=1e-14, rtol=1e-12, equal_nan=True) + wpli, + self.wpli_ground_truth_ft_connectivity_artificial, + atol=1e-14, + rtol=1e-12, + equal_nan=True, + ) # neo.AnalogSignal-input with self.subTest(msg="neo.AnalogSignal input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( self.lfps1_artificial_AnalogSignal, - self.lfps2_artificial_AnalogSignal, absolute_value=False) + self.lfps2_artificial_AnalogSignal, + 
absolute_value=False, + ) np.testing.assert_allclose( - wpli, self.wpli_ground_truth_ft_connectivity_artificial, - atol=1e-14, rtol=1e-12, equal_nan=True) + wpli, + self.wpli_ground_truth_ft_connectivity_artificial, + atol=1e-14, + rtol=1e-12, + equal_nan=True, + ) def test_WPLI_is_zero(self): """ @@ -557,25 +617,35 @@ def test_WPLI_is_zero(self): # Quantity-input with self.subTest(msg="Quantity input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( - self.lfps1_artificial, self.lfps2_artificial, - self.sf1_artificial, absolute_value=False) + self.lfps1_artificial, + self.lfps2_artificial, + self.sf1_artificial, + absolute_value=False, + ) np.testing.assert_allclose( - wpli[freq == 70], 0, atol=0.004, rtol=self.tolerance) + wpli[freq == 70], 0, atol=0.004, rtol=self.tolerance + ) # np.array-input with self.subTest(msg="np.array input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( self.lfps1_artificial.magnitude, - self.lfps2_artificial.magnitude, self.sf1_artificial, - absolute_value=False) + self.lfps2_artificial.magnitude, + self.sf1_artificial, + absolute_value=False, + ) np.testing.assert_allclose( - wpli[freq == 70], 0, atol=0.004, rtol=self.tolerance) + wpli[freq == 70], 0, atol=0.004, rtol=self.tolerance + ) # neo.AnalogSignal-input with self.subTest(msg="neo.AnalogSignal input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( self.lfps1_artificial_AnalogSignal, - self.lfps2_artificial_AnalogSignal, absolute_value=False) + self.lfps2_artificial_AnalogSignal, + absolute_value=False, + ) np.testing.assert_allclose( - wpli[freq == 70], 0, atol=0.004, rtol=self.tolerance) + wpli[freq == 70], 0, atol=0.004, rtol=self.tolerance + ) def test_WPLI_is_one(self): """ @@ -585,28 +655,38 @@ def test_WPLI_is_one(self): # Quantity-input with self.subTest(msg="Quantity input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( - self.lfps1_artificial, self.lfps2_artificial, - self.sf1_artificial, absolute_value=False) - mask = ((freq == 16) | (freq == 36)) + self.lfps1_artificial, + self.lfps2_artificial, + self.sf1_artificial, + absolute_value=False, + ) + mask = (freq == 16) | (freq == 36) np.testing.assert_allclose( - wpli[mask], 1, atol=self.tolerance, rtol=self.tolerance) + wpli[mask], 1, atol=self.tolerance, rtol=self.tolerance + ) # np.array-input with self.subTest(msg="np.array input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( self.lfps1_artificial.magnitude, - self.lfps2_artificial.magnitude, self.sf1_artificial, - absolute_value=False) - mask = ((freq == 16) | (freq == 36)) + self.lfps2_artificial.magnitude, + self.sf1_artificial, + absolute_value=False, + ) + mask = (freq == 16) | (freq == 36) np.testing.assert_allclose( - wpli[mask], 1, atol=self.tolerance, rtol=self.tolerance) + wpli[mask], 1, atol=self.tolerance, rtol=self.tolerance + ) # neo.AnalogSignal-input with self.subTest(msg="neo.AnalogSignal input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( self.lfps1_artificial_AnalogSignal, - self.lfps2_artificial_AnalogSignal, absolute_value=False) - mask = ((freq == 16) | (freq == 36)) + self.lfps2_artificial_AnalogSignal, + absolute_value=False, + ) + mask = (freq == 16) | (freq == 36) np.testing.assert_allclose( - wpli[mask], 1, atol=self.tolerance, rtol=self.tolerance) + wpli[mask], 1, atol=self.tolerance, rtol=self.tolerance + ) def test_WPLI_is_minus_one(self): """ @@ -616,26 +696,36 @@ def test_WPLI_is_minus_one(self): # Quantity-input with self.subTest(msg="Quantity 
input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( - self.lfps1_artificial, self.lfps2_artificial, - self.sf1_artificial, absolute_value=False) - mask = ((freq == 52) | (freq == 100)) + self.lfps1_artificial, + self.lfps2_artificial, + self.sf1_artificial, + absolute_value=False, + ) + mask = (freq == 52) | (freq == 100) np.testing.assert_allclose( - wpli[mask], -1, atol=self.tolerance, rtol=self.tolerance) + wpli[mask], -1, atol=self.tolerance, rtol=self.tolerance + ) # np.array-input with self.subTest(msg="np.array input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( self.lfps1_artificial.magnitude, - self.lfps2_artificial.magnitude, self.sf1_artificial, - absolute_value=False) + self.lfps2_artificial.magnitude, + self.sf1_artificial, + absolute_value=False, + ) np.testing.assert_allclose( - wpli[mask], -1, atol=self.tolerance, rtol=self.tolerance) + wpli[mask], -1, atol=self.tolerance, rtol=self.tolerance + ) # neo.AnalogSignal-input with self.subTest(msg="neo.AnalogSignal input"): freq, wpli = elephant.phase_analysis.weighted_phase_lag_index( self.lfps1_artificial_AnalogSignal, - self.lfps2_artificial_AnalogSignal, absolute_value=False) + self.lfps2_artificial_AnalogSignal, + absolute_value=False, + ) np.testing.assert_allclose( - wpli[mask], -1, atol=self.tolerance, rtol=self.tolerance) + wpli[mask], -1, atol=self.tolerance, rtol=self.tolerance + ) def test_WPLI_raises_error_if_signals_have_different_shapes(self): """ @@ -648,40 +738,63 @@ def test_WPLI_raises_error_if_signals_have_different_shapes(self): trials1_length4 = np.array([[0, 1, 1 / 2, -1]]) * pq.uV sampling_frequency = 250 * pq.Hz trials2_length3_analogsignal = AnalogSignal( - signal=trials2_length3, sampling_rate=sampling_frequency) + signal=trials2_length3, sampling_rate=sampling_frequency + ) trials1_length3_analogsignal = AnalogSignal( - signal=trials1_length3, sampling_rate=sampling_frequency) + signal=trials1_length3, sampling_rate=sampling_frequency + ) trials1_length4_analogsignal = AnalogSignal( - signal=trials1_length4, sampling_rate=sampling_frequency) + signal=trials1_length4, sampling_rate=sampling_frequency + ) # different numbers of trails with self.subTest(msg="diff. trial numbers & Quantity input"): np.testing.assert_raises( - ValueError, elephant.phase_analysis.weighted_phase_lag_index, - trials2_length3, trials1_length3, sampling_frequency) + ValueError, + elephant.phase_analysis.weighted_phase_lag_index, + trials2_length3, + trials1_length3, + sampling_frequency, + ) with self.subTest(msg="diff. trial numbers & np.array input"): np.testing.assert_raises( - ValueError, elephant.phase_analysis.weighted_phase_lag_index, - trials2_length3.magnitude, trials1_length3.magnitude, - sampling_frequency) + ValueError, + elephant.phase_analysis.weighted_phase_lag_index, + trials2_length3.magnitude, + trials1_length3.magnitude, + sampling_frequency, + ) with self.subTest(msg="diff. trial numbers & neo.AnalogSignal input"): np.testing.assert_raises( - ValueError, elephant.phase_analysis.weighted_phase_lag_index, - trials2_length3_analogsignal, trials1_length3_analogsignal) + ValueError, + elephant.phase_analysis.weighted_phase_lag_index, + trials2_length3_analogsignal, + trials1_length3_analogsignal, + ) # different lengths in a trail pair with self.subTest(msg="diff. 
trial lengths & Quantity input"): np.testing.assert_raises( - ValueError, elephant.phase_analysis.weighted_phase_lag_index, - trials1_length3, trials1_length4, sampling_frequency) + ValueError, + elephant.phase_analysis.weighted_phase_lag_index, + trials1_length3, + trials1_length4, + sampling_frequency, + ) with self.subTest(msg="diff. trial lengths & np.array input"): np.testing.assert_raises( - ValueError, elephant.phase_analysis.weighted_phase_lag_index, - trials1_length3.magnitude, trials1_length4.magnitude, - sampling_frequency) + ValueError, + elephant.phase_analysis.weighted_phase_lag_index, + trials1_length3.magnitude, + trials1_length4.magnitude, + sampling_frequency, + ) with self.subTest(msg="diff. trial lengths & neo.AnalogSignal input"): np.testing.assert_raises( - ValueError, elephant.phase_analysis.weighted_phase_lag_index, - trials1_length3_analogsignal, trials1_length4_analogsignal) + ValueError, + elephant.phase_analysis.weighted_phase_lag_index, + trials1_length3_analogsignal, + trials1_length4_analogsignal, + ) @staticmethod def test_WPLI_raises_error_if_AnalogSignals_have_diff_sampling_rate(): @@ -689,13 +802,20 @@ def test_WPLI_raises_error_if_AnalogSignals_have_diff_sampling_rate(): Test if WPLI raises a ValueError, when the AnalogSignals have different sampling rates. """ - signal_x_250_hz = AnalogSignal(signal=np.random.random([40, 2100]), - units=pq.mV, sampling_rate=0.25*pq.kHz) - signal_y_1000_hz = AnalogSignal(signal=np.random.random([40, 2100]), - units=pq.mV, sampling_rate=1000*pq.Hz) + signal_x_250_hz = AnalogSignal( + signal=np.random.random([40, 2100]), + units=pq.mV, + sampling_rate=0.25 * pq.kHz, + ) + signal_y_1000_hz = AnalogSignal( + signal=np.random.random([40, 2100]), units=pq.mV, sampling_rate=1000 * pq.Hz + ) np.testing.assert_raises( - ValueError, elephant.phase_analysis.weighted_phase_lag_index, - signal_x_250_hz, signal_y_1000_hz) + ValueError, + elephant.phase_analysis.weighted_phase_lag_index, + signal_x_250_hz, + signal_y_1000_hz, + ) def test_WPLI_raises_error_if_sampling_rate_not_given(self): """ @@ -706,13 +826,19 @@ def test_WPLI_raises_error_if_sampling_rate_not_given(self): signal_y = np.random.random([40, 2100]) * pq.mV with self.subTest(msg="Quantity-input"): np.testing.assert_raises( - ValueError, elephant.phase_analysis.weighted_phase_lag_index, - signal_x, signal_y) + ValueError, + elephant.phase_analysis.weighted_phase_lag_index, + signal_x, + signal_y, + ) with self.subTest(msg="np.array-input"): np.testing.assert_raises( - ValueError, elephant.phase_analysis.weighted_phase_lag_index, - signal_x.magnitude, signal_y.magnitude) + ValueError, + elephant.phase_analysis.weighted_phase_lag_index, + signal_x.magnitude, + signal_y.magnitude, + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_signal_processing.py b/elephant/test/test_signal_processing.py index 7633a7604..0b9c760b7 100644 --- a/elephant/test/test_signal_processing.py +++ b/elephant/test/test_signal_processing.py @@ -5,6 +5,7 @@ :copyright: Copyright 2014-2024 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. """ + from __future__ import division, print_function import unittest @@ -23,10 +24,10 @@ class PairwiseCrossCorrelationTest(unittest.TestCase): # Set parameters sampling_period = 0.02 * pq.s - sampling_rate = 1. / sampling_period + sampling_rate = 1.0 / sampling_period n_samples = 2018 times = np.arange(n_samples) * sampling_period - freq = 1. 
* pq.Hz + freq = 1.0 * pq.Hz def test_cross_correlation_freqs(self): """ @@ -37,20 +38,24 @@ def test_cross_correlation_freqs(self): freq_arr = np.linspace(0.5, 15, 8) * pq.Hz signal = np.zeros((self.n_samples, 3)) for freq in freq_arr: - signal[:, 0] = np.sin(2. * np.pi * freq * self.times) - signal[:, 1] = np.cos(2. * np.pi * freq * self.times) - signal[:, 2] = np.cos(2. * np.pi * freq * self.times + 0.2) + signal[:, 0] = np.sin(2.0 * np.pi * freq * self.times) + signal[:, 1] = np.cos(2.0 * np.pi * freq * self.times) + signal[:, 2] = np.cos(2.0 * np.pi * freq * self.times + 0.2) # Convert signal to neo.AnalogSignal - signal_neo = neo.AnalogSignal(signal, units='mV', - t_start=0. * pq.ms, - sampling_rate=self.sampling_rate, - dtype=float) + signal_neo = neo.AnalogSignal( + signal, + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=self.sampling_rate, + dtype=float, + ) rho = elephant.signal_processing.cross_correlation_function( - signal_neo, [[0, 1], [0, 2]]) + signal_neo, [[0, 1], [0, 2]] + ) # Cross-correlation of sine and cosine should be sine assert_array_almost_equal( - rho.magnitude[:, 0], np.sin(2. * np.pi * freq * rho.times), - decimal=2) + rho.magnitude[:, 0], np.sin(2.0 * np.pi * freq * rho.times), decimal=2 + ) self.assertEqual(rho.shape, (signal.shape[0], 2)) # 2 pairs def test_cross_correlation_nlags(self): @@ -59,14 +64,19 @@ def test_cross_correlation_nlags(self): """ nlags = 30 signal = np.zeros((self.n_samples, 2)) - signal[:, 0] = 0.2 * np.sin(2. * np.pi * self.freq * self.times) - signal[:, 1] = 5.3 * np.cos(2. * np.pi * self.freq * self.times) + signal[:, 0] = 0.2 * np.sin(2.0 * np.pi * self.freq * self.times) + signal[:, 1] = 5.3 * np.cos(2.0 * np.pi * self.freq * self.times) # Convert signal to neo.AnalogSignal - signal = neo.AnalogSignal(signal, units='mV', t_start=0. * pq.ms, - sampling_rate=self.sampling_rate, - dtype=float) + signal = neo.AnalogSignal( + signal, + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=self.sampling_rate, + dtype=float, + ) rho = elephant.signal_processing.cross_correlation_function( - signal, [0, 1], n_lags=nlags) + signal, [0, 1], n_lags=nlags + ) # Test if vector of lags tau has correct length assert len(rho.times) == 2 * int(nlags) + 1 @@ -74,19 +84,25 @@ def test_cross_correlation_phi(self): """ Sine with phase shift phi vs cosine """ - phi = np.pi / 6. + phi = np.pi / 6.0 signal = np.zeros((self.n_samples, 2)) - signal[:, 0] = 0.2 * np.sin(2. * np.pi * self.freq * self.times + phi) - signal[:, 1] = 5.3 * np.cos(2. * np.pi * self.freq * self.times) + signal[:, 0] = 0.2 * np.sin(2.0 * np.pi * self.freq * self.times + phi) + signal[:, 1] = 5.3 * np.cos(2.0 * np.pi * self.freq * self.times) # Convert signal to neo.AnalogSignal - signal = neo.AnalogSignal(signal, units='mV', t_start=0. * pq.ms, - sampling_rate=self.sampling_rate, - dtype=float) - rho = elephant.signal_processing.cross_correlation_function( - signal, [0, 1]) + signal = neo.AnalogSignal( + signal, + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=self.sampling_rate, + dtype=float, + ) + rho = elephant.signal_processing.cross_correlation_function(signal, [0, 1]) # Cross-correlation of sine and cosine should be sine + phi - assert_array_almost_equal(rho.magnitude[:, 0], np.sin( - 2. 
* np.pi * self.freq * rho.times + phi), decimal=2) + assert_array_almost_equal( + rho.magnitude[:, 0], + np.sin(2.0 * np.pi * self.freq * rho.times + phi), + decimal=2, + ) def test_cross_correlation_envelope(self): """ @@ -95,82 +111,281 @@ def test_cross_correlation_envelope(self): # Sine with phase shift phi vs cosine for different frequencies nlags = 800 # nlags need to be smaller than N/2 b/c border effects signal = np.zeros((self.n_samples, 2)) - signal[:, 0] = 0.2 * np.sin(2. * np.pi * self.freq * self.times) - signal[:, 1] = 5.3 * np.cos(2. * np.pi * self.freq * self.times) + signal[:, 0] = 0.2 * np.sin(2.0 * np.pi * self.freq * self.times) + signal[:, 1] = 5.3 * np.cos(2.0 * np.pi * self.freq * self.times) # Convert signal to neo.AnalogSignal - signal = neo.AnalogSignal(signal, units='mV', t_start=0. * pq.ms, - sampling_rate=self.sampling_rate, - dtype=float) + signal = neo.AnalogSignal( + signal, + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=self.sampling_rate, + dtype=float, + ) envelope = elephant.signal_processing.cross_correlation_function( - signal, [0, 1], n_lags=nlags, hilbert_envelope=True) + signal, [0, 1], n_lags=nlags, hilbert_envelope=True + ) # Envelope should be one for sinusoidal function assert_array_almost_equal(envelope, np.ones_like(envelope), decimal=2) def test_cross_correlation_biased(self): - signal = np.c_[np.sin(2. * np.pi * self.freq * self.times), - np.cos(2. * np.pi * self.freq * self.times)] * pq.mV - signal = neo.AnalogSignal(signal, t_start=0. * pq.ms, - sampling_rate=self.sampling_rate) + signal = ( + np.c_[ + np.sin(2.0 * np.pi * self.freq * self.times), + np.cos(2.0 * np.pi * self.freq * self.times), + ] + * pq.mV + ) + signal = neo.AnalogSignal( + signal, t_start=0.0 * pq.ms, sampling_rate=self.sampling_rate + ) raw = elephant.signal_processing.cross_correlation_function( - signal, [0, 1], scaleopt='none' + signal, [0, 1], scaleopt="none" ) biased = elephant.signal_processing.cross_correlation_function( - signal, [0, 1], scaleopt='biased' + signal, [0, 1], scaleopt="biased" ) assert_array_almost_equal(biased, raw / biased.shape[0]) def test_cross_correlation_coeff(self): - signal = np.c_[np.sin(2. * np.pi * self.freq * self.times), - np.cos(2. * np.pi * self.freq * self.times)] * pq.mV - signal = neo.AnalogSignal(signal, t_start=0. * pq.ms, - sampling_rate=self.sampling_rate) + signal = ( + np.c_[ + np.sin(2.0 * np.pi * self.freq * self.times), + np.cos(2.0 * np.pi * self.freq * self.times), + ] + * pq.mV + ) + signal = neo.AnalogSignal( + signal, t_start=0.0 * pq.ms, sampling_rate=self.sampling_rate + ) normalized = elephant.signal_processing.cross_correlation_function( - signal, [0, 1], scaleopt='coeff' + signal, [0, 1], scaleopt="coeff" ) sig1, sig2 = signal.magnitude.T target_numpy = np.correlate(sig1, sig2, mode="same") - target_numpy /= np.sqrt((sig1 ** 2).sum() * (sig2 ** 2).sum()) + target_numpy /= np.sqrt((sig1**2).sum() * (sig2**2).sum()) target_numpy = np.expand_dims(target_numpy, axis=1) - assert_array_almost_equal(normalized.magnitude, - target_numpy, - decimal=3) + assert_array_almost_equal(normalized.magnitude, target_numpy, decimal=3) def test_cross_correlation_coeff_autocorr(self): # Numpy/Matlab equivalent - signal = np.sin(2. * np.pi * self.freq * self.times) + signal = np.sin(2.0 * np.pi * self.freq * self.times) signal = signal[:, np.newaxis] * pq.mV - signal = neo.AnalogSignal(signal, t_start=0. 
* pq.ms, - sampling_rate=self.sampling_rate) + signal = neo.AnalogSignal( + signal, t_start=0.0 * pq.ms, sampling_rate=self.sampling_rate + ) normalized = elephant.signal_processing.cross_correlation_function( - signal, [0, 0], scaleopt='coeff' + signal, [0, 0], scaleopt="coeff" ) # auto-correlation at zero lag should equal 1 self.assertAlmostEqual(normalized[normalized.shape[0] // 2], 1) class ZscoreTestCase(unittest.TestCase): - def setUp(self): - self.test_seq1 = [1, 28, 4, 47, 5, 16, 2, 5, 21, 12, - 4, 12, 59, 2, 4, 18, 33, 25, 2, 34, - 4, 1, 1, 14, 8, 1, 10, 1, 8, 20, - 5, 1, 6, 5, 12, 2, 8, 8, 2, 8, - 2, 10, 2, 1, 1, 2, 15, 3, 20, 6, - 11, 6, 18, 2, 5, 17, 4, 3, 13, 6, - 1, 18, 1, 16, 12, 2, 52, 2, 5, 7, - 6, 25, 6, 5, 3, 15, 4, 3, 16, 3, - 6, 5, 24, 21, 3, 3, 4, 8, 4, 11, - 5, 7, 5, 6, 8, 11, 33, 10, 7, 4] - self.test_seq2 = [6, 3, 0, 0, 18, 4, 14, 98, 3, 56, - 7, 4, 6, 9, 11, 16, 13, 3, 2, 15, - 24, 1, 0, 7, 4, 4, 9, 24, 12, 11, - 9, 7, 9, 8, 5, 2, 7, 12, 15, 17, - 3, 7, 2, 1, 0, 17, 2, 6, 3, 32, - 22, 19, 11, 8, 5, 4, 3, 2, 7, 21, - 24, 2, 5, 10, 11, 14, 6, 8, 4, 12, - 6, 5, 2, 22, 25, 19, 16, 22, 13, 2, - 19, 20, 17, 19, 2, 4, 1, 3, 5, 23, - 20, 15, 4, 7, 10, 14, 15, 15, 20, 1] + self.test_seq1 = [ + 1, + 28, + 4, + 47, + 5, + 16, + 2, + 5, + 21, + 12, + 4, + 12, + 59, + 2, + 4, + 18, + 33, + 25, + 2, + 34, + 4, + 1, + 1, + 14, + 8, + 1, + 10, + 1, + 8, + 20, + 5, + 1, + 6, + 5, + 12, + 2, + 8, + 8, + 2, + 8, + 2, + 10, + 2, + 1, + 1, + 2, + 15, + 3, + 20, + 6, + 11, + 6, + 18, + 2, + 5, + 17, + 4, + 3, + 13, + 6, + 1, + 18, + 1, + 16, + 12, + 2, + 52, + 2, + 5, + 7, + 6, + 25, + 6, + 5, + 3, + 15, + 4, + 3, + 16, + 3, + 6, + 5, + 24, + 21, + 3, + 3, + 4, + 8, + 4, + 11, + 5, + 7, + 5, + 6, + 8, + 11, + 33, + 10, + 7, + 4, + ] + self.test_seq2 = [ + 6, + 3, + 0, + 0, + 18, + 4, + 14, + 98, + 3, + 56, + 7, + 4, + 6, + 9, + 11, + 16, + 13, + 3, + 2, + 15, + 24, + 1, + 0, + 7, + 4, + 4, + 9, + 24, + 12, + 11, + 9, + 7, + 9, + 8, + 5, + 2, + 7, + 12, + 15, + 17, + 3, + 7, + 2, + 1, + 0, + 17, + 2, + 6, + 3, + 32, + 22, + 19, + 11, + 8, + 5, + 4, + 3, + 2, + 7, + 21, + 24, + 2, + 5, + 10, + 11, + 14, + 6, + 8, + 4, + 12, + 6, + 5, + 2, + 22, + 25, + 19, + 16, + 22, + 13, + 2, + 19, + 20, + 17, + 19, + 2, + 4, + 1, + 3, + 5, + 23, + 20, + 15, + 4, + 7, + 10, + 14, + 15, + 15, + 20, + 1, + ] def test_zscore_single_dup(self): """ @@ -178,8 +393,12 @@ def test_zscore_single_dup(self): duplicate. """ signal = neo.AnalogSignal( - self.test_seq1, units='mV', - t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=float) + self.test_seq1, + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=1000.0 * pq.Hz, + dtype=float, + ) m = np.mean(self.test_seq1) s = np.std(self.test_seq1) @@ -187,10 +406,9 @@ def test_zscore_single_dup(self): assert_array_equal(target, scipy.stats.zscore(self.test_seq1)) result = elephant.signal_processing.zscore(signal, inplace=False) - assert_array_almost_equal( - result.magnitude, target.reshape(-1, 1), decimal=9) + assert_array_almost_equal(result.magnitude, target.reshape(-1, 1), decimal=9) - self.assertEqual(result.units, pq.Quantity(1. * pq.dimensionless)) + self.assertEqual(result.units, pq.Quantity(1.0 * pq.dimensionless)) # Assert original signal is untouched self.assertEqual(signal[0].magnitude, self.test_seq1[0]) @@ -204,8 +422,12 @@ def test_zscore_single_inplace(self): operation. """ signal = neo.AnalogSignal( - self.test_seq1, units='mV', - t_start=0. * pq.ms, sampling_rate=1000. 
* pq.Hz, dtype=float) + self.test_seq1, + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=1000.0 * pq.Hz, + dtype=float, + ) m = np.mean(self.test_seq1) s = np.std(self.test_seq1) @@ -213,10 +435,9 @@ def test_zscore_single_inplace(self): result = elephant.signal_processing.zscore(signal, inplace=True) - assert_array_almost_equal( - result.magnitude, target.reshape(-1, 1), decimal=9) + assert_array_almost_equal(result.magnitude, target.reshape(-1, 1), decimal=9) - self.assertEqual(result.units, pq.Quantity(1. * pq.dimensionless)) + self.assertEqual(result.units, pq.Quantity(1.0 * pq.dimensionless)) # Assert original signal is overwritten self.assertEqual(signal[0].magnitude, target[0]) @@ -230,9 +451,12 @@ def test_zscore_single_multidim_dup(self): to return a duplicate. """ signal = neo.AnalogSignal( - np.transpose( - np.vstack([self.test_seq1, self.test_seq2])), units='mV', - t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=float) + np.transpose(np.vstack([self.test_seq1, self.test_seq2])), + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=1000.0 * pq.Hz, + dtype=float, + ) m = np.mean(signal.magnitude, axis=0, keepdims=True) s = np.std(signal.magnitude, axis=0, keepdims=True) @@ -249,12 +473,14 @@ def test_zscore_single_multidim_dup(self): def test_zscore_array_annotations(self): signal = neo.AnalogSignal( - self.test_seq1, units='mV', - t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, - array_annotations=dict(valid=True, my_list=[0])) + self.test_seq1, + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=1000.0 * pq.Hz, + array_annotations=dict(valid=True, my_list=[0]), + ) zscored = elephant.signal_processing.zscore(signal, inplace=False) - self.assertDictEqual(signal.array_annotations, - zscored.array_annotations) + self.assertDictEqual(signal.array_annotations, zscored.array_annotations) def test_zscore_single_multidim_inplace(self): """ @@ -262,14 +488,18 @@ def test_zscore_single_multidim_inplace(self): for an inplace operation. """ signal = neo.AnalogSignal( - np.vstack([self.test_seq1, self.test_seq2]), units='mV', - t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=float) + np.vstack([self.test_seq1, self.test_seq2]), + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=1000.0 * pq.Hz, + dtype=float, + ) m = np.mean(signal.magnitude, axis=0, keepdims=True) s = np.std(signal.magnitude, axis=0, keepdims=True) - ground_truth = np.divide(signal.magnitude - m, s, - out=np.zeros_like(signal.magnitude), - where=s != 0) + ground_truth = np.divide( + signal.magnitude - m, s, out=np.zeros_like(signal.magnitude), where=s != 0 + ) result = elephant.signal_processing.zscore(signal, inplace=True) assert_array_almost_equal(result.magnitude, ground_truth, decimal=8) @@ -287,16 +517,19 @@ def test_zscore_single_dup_int(self): be of type float). """ signal = neo.AnalogSignal( - self.test_seq1, units='mV', - t_start=0. * pq.ms, sampling_rate=1000. 
* pq.Hz, dtype=int) + self.test_seq1, + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=1000.0 * pq.Hz, + dtype=int, + ) m = np.mean(self.test_seq1) s = np.std(self.test_seq1) target = (self.test_seq1 - m) / s result = elephant.signal_processing.zscore(signal, inplace=False) - assert_array_almost_equal(result.magnitude, target.reshape(-1, 1), - decimal=9) + assert_array_almost_equal(result.magnitude, target.reshape(-1, 1), decimal=9) # Assert original signal is untouched self.assertEqual(signal.magnitude[0], self.test_seq1[0]) @@ -311,8 +544,12 @@ def test_zscore_single_inplace_int(self): """ signal = neo.AnalogSignal( - self.test_seq1, units='mV', - t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=int) + self.test_seq1, + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=1000.0 * pq.Hz, + dtype=int, + ) with self.assertRaises(ValueError): elephant.signal_processing.zscore(signal, inplace=True) @@ -324,12 +561,18 @@ def test_zscore_list_dup(self): """ signal1 = neo.AnalogSignal( np.transpose(np.vstack([self.test_seq1, self.test_seq1])), - units='mV', - t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=float) + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=1000.0 * pq.Hz, + dtype=float, + ) signal2 = neo.AnalogSignal( np.transpose(np.vstack([self.test_seq1, self.test_seq2])), - units='mV', - t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=float) + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=1000.0 * pq.Hz, + dtype=float, + ) signal_list = [signal1, signal2] m = np.mean(np.hstack([self.test_seq1, self.test_seq1])) @@ -346,10 +589,14 @@ def test_zscore_list_dup(self): assert_array_almost_equal( result[0].magnitude, - np.transpose(np.vstack([target11, target12])), decimal=9) + np.transpose(np.vstack([target11, target12])), + decimal=9, + ) assert_array_almost_equal( result[1].magnitude, - np.transpose(np.vstack([target21, target22])), decimal=9) + np.transpose(np.vstack([target21, target22])), + decimal=9, + ) # Assert original signal is untouched self.assertEqual(signal1.magnitude[0, 0], self.test_seq1[0]) @@ -366,12 +613,18 @@ def test_zscore_list_inplace(self): """ signal1 = neo.AnalogSignal( np.transpose(np.vstack([self.test_seq1, self.test_seq1])), - units='mV', - t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=float) + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=1000.0 * pq.Hz, + dtype=float, + ) signal2 = neo.AnalogSignal( np.transpose(np.vstack([self.test_seq1, self.test_seq2])), - units='mV', - t_start=0. * pq.ms, sampling_rate=1000. 
* pq.Hz, dtype=float) + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=1000.0 * pq.Hz, + dtype=float, + ) signal_list = [signal1, signal2] m = np.mean(np.hstack([self.test_seq1, self.test_seq1])) @@ -388,10 +641,14 @@ def test_zscore_list_inplace(self): assert_array_almost_equal( result[0].magnitude, - np.transpose(np.vstack([target11, target12])), decimal=9) + np.transpose(np.vstack([target11, target12])), + decimal=9, + ) assert_array_almost_equal( result[1].magnitude, - np.transpose(np.vstack([target21, target22])), decimal=9) + np.transpose(np.vstack([target21, target22])), + decimal=9, + ) # Assert original signal is overwritten self.assertEqual(signal1[0, 0].magnitude, target11[0]) @@ -403,13 +660,15 @@ def test_zscore_list_inplace(self): def test_z_score_wrong_input(self): # wrong type - self.assertRaises(TypeError, elephant.signal_processing.zscore, - signal=[1, 2] * pq.uV) + self.assertRaises( + TypeError, elephant.signal_processing.zscore, signal=[1, 2] * pq.uV + ) # units mismatch asig1 = neo.AnalogSignal([0, 1], units=pq.uV, sampling_rate=1 * pq.ms) asig2 = neo.AnalogSignal([0, 1], units=pq.V, sampling_rate=1 * pq.ms) - self.assertRaises(ValueError, elephant.signal_processing.zscore, - signal=[asig1, asig2]) + self.assertRaises( + ValueError, elephant.signal_processing.zscore, signal=[asig1, asig2] + ) def test_z_score_np_float32_64(self): """ @@ -421,16 +680,18 @@ def test_z_score_np_float32_64(self): test_types = (np.float32, np.float64) for test_type in test_types: with self.subTest(test_type): - signal = neo.AnalogSignal(self.test_seq1, units='mV', - t_start=0. * pq.ms, - sampling_rate=1000. * pq.Hz, - dtype=test_type) + signal = neo.AnalogSignal( + self.test_seq1, + units="mV", + t_start=0.0 * pq.ms, + sampling_rate=1000.0 * pq.Hz, + dtype=test_type, + ) # This should not raise a ValueError elephant.signal_processing.zscore(signal, inplace=True) class ButterTestCase(unittest.TestCase): - def test_butter_filter_type(self): """ Test if correct type of filtering is performed according to how cut-off @@ -438,40 +699,42 @@ def test_butter_filter_type(self): """ # generate white noise AnalogSignal noise = neo.AnalogSignal( - np.random.normal(size=5000), - sampling_rate=1000 * pq.Hz, units='mV') + np.random.normal(size=5000), sampling_rate=1000 * pq.Hz, units="mV" + ) # test high-pass filtering: power at the lowest frequency # should be almost zero # Note: the default detrend function of scipy.signal.welch() seems to # cause artificial finite power at the lowest frequencies. 
Here I avoid # this by using an identity function for detrending - filtered_noise = elephant.signal_processing.butter( - noise, 250.0 * pq.Hz, None) - _, psd = spsig.welch(filtered_noise.T, nperseg=1024, fs=1000.0, - detrend=lambda x: x) + filtered_noise = elephant.signal_processing.butter(noise, 250.0 * pq.Hz, None) + _, psd = spsig.welch( + filtered_noise.T, nperseg=1024, fs=1000.0, detrend=lambda x: x + ) self.assertAlmostEqual(psd[0, 0], 0) # test low-pass filtering: power at the highest frequency # should be almost zero - filtered_noise = elephant.signal_processing.butter( - noise, None, 250.0 * pq.Hz) + filtered_noise = elephant.signal_processing.butter(noise, None, 250.0 * pq.Hz) _, psd = spsig.welch(filtered_noise.T, nperseg=1024, fs=1000.0) self.assertAlmostEqual(psd[0, -1], 0) # test band-pass filtering: power at the lowest and highest frequencies # should be almost zero filtered_noise = elephant.signal_processing.butter( - noise, 200.0 * pq.Hz, 300.0 * pq.Hz) - _, psd = spsig.welch(filtered_noise.T, nperseg=1024, fs=1000.0, - detrend=lambda x: x) + noise, 200.0 * pq.Hz, 300.0 * pq.Hz + ) + _, psd = spsig.welch( + filtered_noise.T, nperseg=1024, fs=1000.0, detrend=lambda x: x + ) self.assertAlmostEqual(psd[0, 0], 0) self.assertAlmostEqual(psd[0, -1], 0) # test band-stop filtering: power at the intermediate frequency # should be almost zero filtered_noise = elephant.signal_processing.butter( - noise, 400.0 * pq.Hz, 100.0 * pq.Hz) + noise, 400.0 * pq.Hz, 100.0 * pq.Hz + ) _, psd = spsig.welch(filtered_noise.T, nperseg=1024, fs=1000.0) self.assertAlmostEqual(psd[0, 256], 0) @@ -486,51 +749,65 @@ def test_butter_filter_function(self): # generate white noise AnalogSignal noise = neo.AnalogSignal( np.random.normal(size=5000), - sampling_rate=1000 * pq.Hz, units='mV', - array_annotations=dict(valid=True, my_list=[0])) + sampling_rate=1000 * pq.Hz, + units="mV", + array_annotations=dict(valid=True, my_list=[0]), + ) - kwds = {'signal': noise, 'highpass_frequency': 250.0 * pq.Hz, - 'lowpass_frequency': None, 'filter_function': 'filtfilt'} + kwds = { + "signal": noise, + "highpass_frequency": 250.0 * pq.Hz, + "lowpass_frequency": None, + "filter_function": "filtfilt", + } filtered_noise = elephant.signal_processing.butter(**kwds) _, psd_filtfilt = spsig.welch( - filtered_noise.T, nperseg=1024, fs=1000.0, detrend=lambda x: x) + filtered_noise.T, nperseg=1024, fs=1000.0, detrend=lambda x: x + ) - kwds['filter_function'] = 'lfilter' + kwds["filter_function"] = "lfilter" filtered_noise = elephant.signal_processing.butter(**kwds) _, psd_lfilter = spsig.welch( - filtered_noise.T, nperseg=1024, fs=1000.0, detrend=lambda x: x) + filtered_noise.T, nperseg=1024, fs=1000.0, detrend=lambda x: x + ) - kwds['filter_function'] = 'sosfiltfilt' + kwds["filter_function"] = "sosfiltfilt" filtered_noise = elephant.signal_processing.butter(**kwds) _, psd_sosfiltfilt = spsig.welch( - filtered_noise.T, nperseg=1024, fs=1000.0, detrend=lambda x: x) + filtered_noise.T, nperseg=1024, fs=1000.0, detrend=lambda x: x + ) self.assertAlmostEqual(psd_filtfilt[0, 0], psd_lfilter[0, 0]) self.assertAlmostEqual(psd_filtfilt[0, 0], psd_sosfiltfilt[0, 0]) # Test if array_annotations are preserved - self.assertDictEqual(noise.array_annotations, - filtered_noise.array_annotations) + self.assertDictEqual(noise.array_annotations, filtered_noise.array_annotations) def test_butter_invalid_filter_function(self): # generate a dummy AnalogSignal anasig_dummy = neo.AnalogSignal( - np.zeros(5000), sampling_rate=1000 * pq.Hz, units='mV') 
+ np.zeros(5000), sampling_rate=1000 * pq.Hz, units="mV" + ) # test exception upon invalid filtfunc string - kwds = {'signal': anasig_dummy, 'highpass_frequency': 250.0 * pq.Hz, - 'filter_function': 'invalid_filter'} - self.assertRaises( - ValueError, elephant.signal_processing.butter, **kwds) + kwds = { + "signal": anasig_dummy, + "highpass_frequency": 250.0 * pq.Hz, + "filter_function": "invalid_filter", + } + self.assertRaises(ValueError, elephant.signal_processing.butter, **kwds) def test_butter_missing_cutoff_freqs(self): # generate a dummy AnalogSignal anasig_dummy = neo.AnalogSignal( - np.zeros(5000), sampling_rate=1000 * pq.Hz, units='mV') + np.zeros(5000), sampling_rate=1000 * pq.Hz, units="mV" + ) # test a case where no cut-off frequencies are given - kwds = {'signal': anasig_dummy, 'highpass_frequency': None, - 'lowpass_frequency': None} - self.assertRaises( - ValueError, elephant.signal_processing.butter, **kwds) + kwds = { + "signal": anasig_dummy, + "highpass_frequency": None, + "lowpass_frequency": None, + } + self.assertRaises(ValueError, elephant.signal_processing.butter, **kwds) def test_butter_input_types(self): # generate white noise data of different types @@ -540,7 +817,8 @@ def test_butter_input_types(self): # check input as NumPy ndarray filtered_noise_np = elephant.signal_processing.butter( - noise_np, 400.0, 100.0, sampling_frequency=1000.0) + noise_np, 400.0, 100.0, sampling_frequency=1000.0 + ) self.assertTrue(isinstance(filtered_noise_np, np.ndarray)) self.assertFalse(isinstance(filtered_noise_np, pq.quantity.Quantity)) self.assertFalse(isinstance(filtered_noise_np, neo.AnalogSignal)) @@ -548,108 +826,121 @@ def test_butter_input_types(self): # check input as Quantity array filtered_noise_pq = elephant.signal_processing.butter( - noise_pq, 400.0 * pq.Hz, 100.0 * pq.Hz, sampling_frequency=1000.0) + noise_pq, 400.0 * pq.Hz, 100.0 * pq.Hz, sampling_frequency=1000.0 + ) self.assertTrue(isinstance(filtered_noise_pq, pq.quantity.Quantity)) self.assertFalse(isinstance(filtered_noise_pq, neo.AnalogSignal)) self.assertEqual(filtered_noise_pq.shape, noise_pq.shape) # check input as neo AnalogSignal - filtered_noise = elephant.signal_processing.butter(noise, - 400.0 * pq.Hz, - 100.0 * pq.Hz) + filtered_noise = elephant.signal_processing.butter( + noise, 400.0 * pq.Hz, 100.0 * pq.Hz + ) self.assertTrue(isinstance(filtered_noise, neo.AnalogSignal)) self.assertEqual(filtered_noise.shape, noise.shape) # check if the results from different input types are identical - self.assertTrue(np.all( - filtered_noise_pq.magnitude == filtered_noise_np)) - self.assertTrue(np.all( - filtered_noise.magnitude[:, 0] == filtered_noise_np)) + self.assertTrue(np.all(filtered_noise_pq.magnitude == filtered_noise_np)) + self.assertTrue(np.all(filtered_noise.magnitude[:, 0] == filtered_noise_np)) def test_butter_axis(self): noise = np.random.normal(size=(4, 5000)) filtered_noise = elephant.signal_processing.butter( - noise, 250.0, sampling_frequency=1000.0) + noise, 250.0, sampling_frequency=1000.0 + ) filtered_noise_transposed = elephant.signal_processing.butter( - noise.T, 250.0, sampling_frequency=1000.0, axis=0) + noise.T, 250.0, sampling_frequency=1000.0, axis=0 + ) self.assertTrue(np.all(filtered_noise == filtered_noise_transposed.T)) def test_butter_multidim_input(self): noise_pq = np.random.normal(size=(4, 5000)) * pq.mV - noise_neo = neo.AnalogSignal( - noise_pq.T, sampling_rate=1000.0 * pq.Hz) - noise_neo1d = neo.AnalogSignal( - noise_pq[0], sampling_rate=1000.0 * pq.Hz) + noise_neo = 
neo.AnalogSignal(noise_pq.T, sampling_rate=1000.0 * pq.Hz) + noise_neo1d = neo.AnalogSignal(noise_pq[0], sampling_rate=1000.0 * pq.Hz) filtered_noise_pq = elephant.signal_processing.butter( - noise_pq, 250.0, sampling_frequency=1000.0) - filtered_noise_neo = elephant.signal_processing.butter( - noise_neo, 250.0) - filtered_noise_neo1d = elephant.signal_processing.butter( - noise_neo1d, 250.0) - self.assertTrue(np.all( - filtered_noise_pq.magnitude == filtered_noise_neo.T.magnitude)) - self.assertTrue(np.all( - filtered_noise_neo1d.magnitude[:, 0] == - filtered_noise_neo.magnitude[:, 0])) + noise_pq, 250.0, sampling_frequency=1000.0 + ) + filtered_noise_neo = elephant.signal_processing.butter(noise_neo, 250.0) + filtered_noise_neo1d = elephant.signal_processing.butter(noise_neo1d, 250.0) + self.assertTrue( + np.all(filtered_noise_pq.magnitude == filtered_noise_neo.T.magnitude) + ) + self.assertTrue( + np.all( + filtered_noise_neo1d.magnitude[:, 0] + == filtered_noise_neo.magnitude[:, 0] + ) + ) class HilbertTestCase(unittest.TestCase): - def setUp(self): # Generate test data of a harmonic function over a long time time = np.arange(0, 1000, 0.1) * pq.ms freq = 10 * pq.Hz - self.amplitude = np.array([ - np.linspace(1, 10, len(time)), - np.linspace(1, 10, len(time)), - np.ones((len(time))), - np.ones((len(time))) * 10.]).T - self.phase = np.array([ - (time * freq).simplified.magnitude * 2. * np.pi, - (time * freq).simplified.magnitude * 2. * np.pi + np.pi / 2, - (time * freq).simplified.magnitude * 2. * np.pi + np.pi, - (time * freq).simplified.magnitude * 2. * 2. * np.pi]).T - - self.phase = np.mod(self.phase + np.pi, 2. * np.pi) - np.pi + self.amplitude = np.array( + [ + np.linspace(1, 10, len(time)), + np.linspace(1, 10, len(time)), + np.ones((len(time))), + np.ones((len(time))) * 10.0, + ] + ).T + self.phase = np.array( + [ + (time * freq).simplified.magnitude * 2.0 * np.pi, + (time * freq).simplified.magnitude * 2.0 * np.pi + np.pi / 2, + (time * freq).simplified.magnitude * 2.0 * np.pi + np.pi, + (time * freq).simplified.magnitude * 2.0 * 2.0 * np.pi, + ] + ).T + + self.phase = np.mod(self.phase + np.pi, 2.0 * np.pi) - np.pi # rising amplitude cosine, random ampl. sine, flat inverse cosine, # flat cosine at double frequency - sigs = np.vstack([ - self.amplitude[:, 0] * np.cos(self.phase[:, 0]), - self.amplitude[:, 1] * np.cos(self.phase[:, 1]), - self.amplitude[:, 2] * np.cos(self.phase[:, 2]), - self.amplitude[:, 3] * np.cos(self.phase[:, 3])]) + sigs = np.vstack( + [ + self.amplitude[:, 0] * np.cos(self.phase[:, 0]), + self.amplitude[:, 1] * np.cos(self.phase[:, 1]), + self.amplitude[:, 2] * np.cos(self.phase[:, 2]), + self.amplitude[:, 3] * np.cos(self.phase[:, 3]), + ] + ) array_annotations = dict(my_list=np.arange(sigs.shape[0])) self.long_signals = neo.AnalogSignal( - sigs.T, units='mV', - t_start=0. 
* pq.ms, + sigs.T, + units="mV", + t_start=0.0 * pq.ms, sampling_rate=(len(time) / (time[-1] - time[0])).rescale(pq.Hz), dtype=float, - array_annotations=array_annotations) + array_annotations=array_annotations, + ) # Generate test data covering a single oscillation cycle in 1s only phases = np.arange(0, 2 * np.pi, np.pi / 256) - sigs = np.vstack([ - np.sin(phases), - np.cos(phases), - np.sin(2 * phases), - np.cos(2 * phases)]) + sigs = np.vstack( + [np.sin(phases), np.cos(phases), np.sin(2 * phases), np.cos(2 * phases)] + ) self.one_period = neo.AnalogSignal( - sigs.T, units=pq.mV, - sampling_rate=len(phases) * pq.Hz) + sigs.T, units=pq.mV, sampling_rate=len(phases) * pq.Hz + ) def test_hilbert_pad_type_error(self): """ Tests if incorrect pad_type raises ValueError. """ - padding = 'wrong_type' + padding = "wrong_type" self.assertRaises( - ValueError, elephant.signal_processing.hilbert, - self.long_signals, padding=padding) + ValueError, + elephant.signal_processing.hilbert, + self.long_signals, + padding=padding, + ) def test_hilbert_output_shape(self): """ @@ -658,21 +949,24 @@ def test_hilbert_output_shape(self): """ true_shape = np.shape(self.long_signals) output = elephant.signal_processing.hilbert( - self.long_signals, padding='nextpow') + self.long_signals, padding="nextpow" + ) self.assertEqual(np.shape(output), true_shape) self.assertEqual(output.units, pq.dimensionless) - output = elephant.signal_processing.hilbert( - self.long_signals, padding=16384) + output = elephant.signal_processing.hilbert(self.long_signals, padding=16384) self.assertEqual(np.shape(output), true_shape) self.assertEqual(output.units, pq.dimensionless) def test_hilbert_array_annotations(self): - output = elephant.signal_processing.hilbert(self.long_signals, - padding='nextpow') + output = elephant.signal_processing.hilbert( + self.long_signals, padding="nextpow" + ) # Test if array_annotations are preserved self.assertSetEqual(set(output.array_annotations.keys()), {"my_list"}) - assert_array_equal(output.array_annotations['my_list'], - self.long_signals.array_annotations['my_list']) + assert_array_equal( + output.array_annotations["my_list"], + self.long_signals.array_annotations["my_list"], + ) def test_hilbert_theoretical_long_signals(self): """ @@ -680,9 +974,8 @@ def test_hilbert_theoretical_long_signals(self): phase of long test signals """ # Performing test using all pad types - for padding in ['nextpow', 'none', 16384]: - h = elephant.signal_processing.hilbert( - self.long_signals, padding=padding) + for padding in ["nextpow", "none", 16384]: + h = elephant.signal_processing.hilbert(self.long_signals, padding=padding) phase = np.angle(h.magnitude) amplitude = np.abs(h.magnitude) @@ -690,9 +983,8 @@ def test_hilbert_theoretical_long_signals(self): # The real part should be equal to the original long_signals assert_array_almost_equal( - real_value, - self.long_signals.magnitude, - decimal=14) + real_value, self.long_signals.magnitude, decimal=14 + ) # Test only in the middle half of the array (border effects) ind1 = int(len(h.times) / 4) @@ -701,13 +993,11 @@ def test_hilbert_theoretical_long_signals(self): # Calculate difference in phase between signal and original phase # and use smaller of any two phase differences phasediff = np.abs(phase[ind1:ind2, :] - self.phase[ind1:ind2, :]) - phasediff[phasediff >= np.pi] = \ - 2 * np.pi - phasediff[phasediff >= np.pi] + phasediff[phasediff >= np.pi] = 2 * np.pi - phasediff[phasediff >= np.pi] # Calculate difference in amplitude between signal and original # 
amplitude - amplitudediff = \ - amplitude[ind1:ind2, :] - self.amplitude[ind1:ind2, :] + amplitudediff = amplitude[ind1:ind2, :] - self.amplitude[ind1:ind2, :] # assert_allclose(phasediff, 0, atol=0.1) assert_allclose(amplitudediff, 0, atol=0.5) @@ -725,9 +1015,8 @@ def test_hilbert_theoretical_one_period(self): decimal = 14 # Performing test using both pad types - for padding in ['nextpow', 'none', 512]: - h = elephant.signal_processing.hilbert( - self.one_period, padding=padding) + for padding in ["nextpow", "none", 512]: + h = elephant.signal_processing.hilbert(self.one_period, padding=padding) amplitude = np.abs(h.magnitude) phase = np.angle(h.magnitude) @@ -735,40 +1024,38 @@ def test_hilbert_theoretical_one_period(self): # The real part should be equal to the original long_signals: assert_array_almost_equal( - real_value, - self.one_period.magnitude, - decimal=decimal) + real_value, self.one_period.magnitude, decimal=decimal + ) # The absolute value should be 1 everywhere, for this input: assert_array_almost_equal( - amplitude, - np.ones(self.one_period.magnitude.shape), - decimal=decimal) + amplitude, np.ones(self.one_period.magnitude.shape), decimal=decimal + ) # For the 'slow' sine - the phase should go from -pi/2 to pi/2 in # the first 256 bins: assert_array_almost_equal( phase[:256, 0], np.arange(-np.pi / 2, np.pi / 2, np.pi / 256), - decimal=decimal) + decimal=decimal, + ) # For the 'slow' cosine - the phase should go from 0 to pi in the # same interval: assert_array_almost_equal( - phase[:256, 1], - np.arange(0, np.pi, np.pi / 256), - decimal=decimal) + phase[:256, 1], np.arange(0, np.pi, np.pi / 256), decimal=decimal + ) # The 'fast' sine should make this phase transition in half the # time: assert_array_almost_equal( phase[:128, 2], np.arange(-np.pi / 2, np.pi / 2, np.pi / 128), - decimal=decimal) + decimal=decimal, + ) # The 'fast' cosine should make this phase transition in half the # time: assert_array_almost_equal( - phase[:128, 3], - np.arange(0, np.pi, np.pi / 128), - decimal=decimal) + phase[:128, 3], np.arange(0, np.pi, np.pi / 128), decimal=decimal + ) class WaveletTestCase(unittest.TestCase): @@ -782,24 +1069,17 @@ def setUp(self): self.test_data2 = np.sin(2 * np.pi * self.test_freq2 * self.times) self.test_data_arr = np.vstack([self.test_data1, self.test_data2]) self.test_data = neo.AnalogSignal( - self.test_data_arr.T * pq.mV, t_start=self.times[0] * pq.s, - t_stop=self.times[-1] * pq.s, sampling_period=(1 / self.fs) * pq.s) + self.test_data_arr.T * pq.mV, + t_start=self.times[0] * pq.s, + t_stop=self.times[-1] * pq.s, + sampling_period=(1 / self.fs) * pq.s, + ) self.true_phase1 = np.angle( - self.test_data1 + - 1j * - np.sin( - 2 * - np.pi * - self.test_freq1 * - self.times)) + self.test_data1 + 1j * np.sin(2 * np.pi * self.test_freq1 * self.times) + ) self.true_phase2 = np.angle( - self.test_data2 - - 1j * - np.cos( - 2 * - np.pi * - self.test_freq2 * - self.times)) + self.test_data2 - 1j * np.cos(2 * np.pi * self.test_freq2 * self.times) + ) self.wt_freqs = [10, 20, 30] def test_wavelet_errors(self): @@ -807,32 +1087,38 @@ def test_wavelet_errors(self): Tests if errors are raised as expected. 
""" # too high center frequency - kwds = {'signal': self.test_data, 'frequency': self.fs / 2} + kwds = {"signal": self.test_data, "frequency": self.fs / 2} self.assertRaises( - ValueError, elephant.signal_processing.wavelet_transform, **kwds) + ValueError, elephant.signal_processing.wavelet_transform, **kwds + ) kwds = { - 'signal': self.test_data_arr, - 'frequency': self.fs / 2, - 'sampling_frequency': self.fs} + "signal": self.test_data_arr, + "frequency": self.fs / 2, + "sampling_frequency": self.fs, + } self.assertRaises( - ValueError, elephant.signal_processing.wavelet_transform, **kwds) + ValueError, elephant.signal_processing.wavelet_transform, **kwds + ) # too high center frequency in a list - kwds = {'signal': self.test_data, - 'frequency': [self.fs / 10, self.fs / 2]} + kwds = {"signal": self.test_data, "frequency": [self.fs / 10, self.fs / 2]} self.assertRaises( - ValueError, elephant.signal_processing.wavelet_transform, **kwds) - kwds = {'signal': self.test_data_arr, - 'frequency': [self.fs / 10, self.fs / 2], - 'sampling_frequency': self.fs} + ValueError, elephant.signal_processing.wavelet_transform, **kwds + ) + kwds = { + "signal": self.test_data_arr, + "frequency": [self.fs / 10, self.fs / 2], + "sampling_frequency": self.fs, + } self.assertRaises( - ValueError, elephant.signal_processing.wavelet_transform, **kwds) + ValueError, elephant.signal_processing.wavelet_transform, **kwds + ) # nco is not positive - kwds = {'signal': self.test_data, 'frequency': self.fs / 10, - 'n_cycles': 0} + kwds = {"signal": self.test_data, "frequency": self.fs / 10, "n_cycles": 0} self.assertRaises( - ValueError, elephant.signal_processing.wavelet_transform, **kwds) + ValueError, elephant.signal_processing.wavelet_transform, **kwds + ) def test_wavelet_io(self): """ @@ -842,14 +1128,14 @@ def test_wavelet_io(self): """ # check the shape of the result array # --- case of single center frequency - wt = elephant.signal_processing.wavelet_transform(self.test_data, - self.fs / 10) + wt = elephant.signal_processing.wavelet_transform(self.test_data, self.fs / 10) self.assertTrue(wt.ndim == self.test_data.ndim) self.assertTrue(wt.shape[0] == self.test_data.shape[0]) # time axis self.assertTrue(wt.shape[1] == self.test_data.shape[1]) # channel axis wt_arr = elephant.signal_processing.wavelet_transform( - self.test_data_arr, self.fs / 10, sampling_frequency=self.fs) + self.test_data_arr, self.fs / 10, sampling_frequency=self.fs + ) self.assertTrue(wt_arr.ndim == self.test_data.ndim) # channel axis self.assertTrue(wt_arr.shape[0] == self.test_data_arr.shape[0]) @@ -857,21 +1143,22 @@ def test_wavelet_io(self): self.assertTrue(wt_arr.shape[1] == self.test_data_arr.shape[1]) wt_arr1d = elephant.signal_processing.wavelet_transform( - self.test_data1, self.fs / 10, sampling_frequency=self.fs) + self.test_data1, self.fs / 10, sampling_frequency=self.fs + ) self.assertTrue(wt_arr1d.ndim == self.test_data1.ndim) # time axis self.assertTrue(wt_arr1d.shape[0] == self.test_data1.shape[0]) # --- case of multiple center frequencies - wt = elephant.signal_processing.wavelet_transform( - self.test_data, self.wt_freqs) + wt = elephant.signal_processing.wavelet_transform(self.test_data, self.wt_freqs) self.assertTrue(wt.ndim == self.test_data.ndim + 1) self.assertTrue(wt.shape[0] == self.test_data.shape[0]) # time axis self.assertTrue(wt.shape[1] == self.test_data.shape[1]) # channel axis self.assertTrue(wt.shape[2] == len(self.wt_freqs)) # frequency axis wt_arr = elephant.signal_processing.wavelet_transform( - 
self.test_data_arr, self.wt_freqs, sampling_frequency=self.fs) + self.test_data_arr, self.wt_freqs, sampling_frequency=self.fs + ) self.assertTrue(wt_arr.ndim == self.test_data_arr.ndim + 1) # channel axis self.assertTrue(wt_arr.shape[0] == self.test_data_arr.shape[0]) @@ -881,7 +1168,8 @@ def test_wavelet_io(self): self.assertTrue(wt_arr.shape[2] == self.test_data_arr.shape[1]) wt_arr1d = elephant.signal_processing.wavelet_transform( - self.test_data1, self.wt_freqs, sampling_frequency=self.fs) + self.test_data1, self.wt_freqs, sampling_frequency=self.fs + ) self.assertTrue(wt_arr1d.ndim == self.test_data1.ndim + 1) # frequency axis self.assertTrue(wt_arr1d.shape[0] == len(self.wt_freqs)) @@ -899,37 +1187,41 @@ def test_wavelet_io(self): # wavelet transform (in NumPy 1.13.1, they become identical). Here we # only check that they are almost equal. wt_1freq = elephant.signal_processing.wavelet_transform( - self.test_data, self.wt_freqs[0]) + self.test_data, self.wt_freqs[0] + ) wt_3freqs = elephant.signal_processing.wavelet_transform( - self.test_data, self.wt_freqs) - assert_array_almost_equal(wt_1freq[:, 0], wt_3freqs[:, 0, 0], - decimal=12) + self.test_data, self.wt_freqs + ) + assert_array_almost_equal(wt_1freq[:, 0], wt_3freqs[:, 0, 0], decimal=12) def test_wavelet_amplitude(self): """ Tests amplitude properties of the obtained wavelet transform """ # check that the amplitude of WT of a sinusoid is (almost) constant - wt = elephant.signal_processing.wavelet_transform(self.test_data, - self.test_freq1) + wt = elephant.signal_processing.wavelet_transform( + self.test_data, self.test_freq1 + ) # take a middle segment in order to avoid edge effects - amp = np.abs(wt[int(len(wt) / 3):int(len(wt) // 3 * 2), 0]) + amp = np.abs(wt[int(len(wt) / 3) : int(len(wt) // 3 * 2), 0]) mean_amp = amp.mean() - assert_array_almost_equal((amp - mean_amp) / mean_amp, - np.zeros_like(amp), decimal=6) + assert_array_almost_equal( + (amp - mean_amp) / mean_amp, np.zeros_like(amp), decimal=6 + ) # check that the amplitude of WT is (almost) zero when center frequency # is considerably different from signal frequency wt_low = elephant.signal_processing.wavelet_transform( - self.test_data, self.test_freq1 / 10) - amp_low = np.abs(wt_low[int(len(wt) / 3):int(len(wt) // 3 * 2), 0]) + self.test_data, self.test_freq1 / 10 + ) + amp_low = np.abs(wt_low[int(len(wt) / 3) : int(len(wt) // 3 * 2), 0]) assert_array_almost_equal(amp_low, np.zeros_like(amp), decimal=6) # check that zero padding hardly affect the result wt_padded = elephant.signal_processing.wavelet_transform( - self.test_data, self.test_freq1, zero_padding=False) - amp_padded = np.abs( - wt_padded[int(len(wt) / 3):int(len(wt) // 3 * 2), 0]) + self.test_data, self.test_freq1, zero_padding=False + ) + amp_padded = np.abs(wt_padded[int(len(wt) / 3) : int(len(wt) // 3 * 2), 0]) assert_array_almost_equal(amp_padded, amp, decimal=9) def test_wavelet_phase(self): @@ -938,28 +1230,26 @@ def test_wavelet_phase(self): """ # check that the phase of WT is (almost) same as that of the original # sinusoid - wt = elephant.signal_processing.wavelet_transform(self.test_data, - self.test_freq1) - phase = np.angle(wt[int(len(wt) / 3):int(len(wt) // 3 * 2), 0]) - true_phase = self.true_phase1[int(len(wt) / 3):int(len(wt) // 3 * 2)] - assert_array_almost_equal(np.exp(1j * phase), np.exp(1j * true_phase), - decimal=6) + wt = elephant.signal_processing.wavelet_transform( + self.test_data, self.test_freq1 + ) + phase = np.angle(wt[int(len(wt) / 3) : int(len(wt) // 3 * 2), 0]) + 
true_phase = self.true_phase1[int(len(wt) / 3) : int(len(wt) // 3 * 2)] + assert_array_almost_equal( + np.exp(1j * phase), np.exp(1j * true_phase), decimal=6 + ) # check that zero padding hardly affect the result wt_padded = elephant.signal_processing.wavelet_transform( - self.test_data, self.test_freq1, zero_padding=False) - phase_padded = np.angle( - wt_padded[int(len(wt) / 3):int(len(wt) // 3 * 2), 0]) + self.test_data, self.test_freq1, zero_padding=False + ) + phase_padded = np.angle(wt_padded[int(len(wt) / 3) : int(len(wt) // 3 * 2), 0]) assert_array_almost_equal( - np.exp( - 1j * phase_padded), - np.exp( - 1j * phase), - decimal=9) + np.exp(1j * phase_padded), np.exp(1j * phase), decimal=9 + ) class DerivativeTestCase(unittest.TestCase): - def setUp(self): self.fs = 1000.0 self.tmin = 0.0 @@ -967,73 +1257,81 @@ def setUp(self): self.times = np.arange(self.tmin, self.tmax, 1 / self.fs) self.test_data1 = np.cos(2 * np.pi * self.times) self.test_data2 = np.vstack( - [np.cos(2 * np.pi * self.times), np.sin(2 * np.pi * self.times)]).T + [np.cos(2 * np.pi * self.times), np.sin(2 * np.pi * self.times)] + ).T self.test_signal1 = neo.AnalogSignal( - self.test_data1 * pq.mV, t_start=self.times[0] * pq.s, - t_stop=self.times[-1] * pq.s, sampling_period=(1 / self.fs) * pq.s) + self.test_data1 * pq.mV, + t_start=self.times[0] * pq.s, + t_stop=self.times[-1] * pq.s, + sampling_period=(1 / self.fs) * pq.s, + ) self.test_signal2 = neo.AnalogSignal( - self.test_data2 * pq.mV, t_start=self.times[0] * pq.s, - t_stop=self.times[-1] * pq.s, sampling_period=(1 / self.fs) * pq.s) + self.test_data2 * pq.mV, + t_start=self.times[0] * pq.s, + t_stop=self.times[-1] * pq.s, + sampling_period=(1 / self.fs) * pq.s, + ) def test_derivative_invalid_signal(self): """Test derivative on non-AnalogSignal""" - kwds = {'signal': np.arange(5)} - self.assertRaises( - TypeError, elephant.signal_processing.derivative, **kwds) + kwds = {"signal": np.arange(5)} + self.assertRaises(TypeError, elephant.signal_processing.derivative, **kwds) def test_derivative_units(self): """Test derivative returns AnalogSignal with correct units""" - derivative = elephant.signal_processing.derivative( - self.test_signal1) + derivative = elephant.signal_processing.derivative(self.test_signal1) self.assertTrue(isinstance(derivative, neo.AnalogSignal)) self.assertEqual( - derivative.units, - self.test_signal1.units / self.test_signal1.times.units) + derivative.units, self.test_signal1.units / self.test_signal1.times.units + ) def test_derivative_times(self): """Test derivative returns AnalogSignal with correct times""" - derivative = elephant.signal_processing.derivative( - self.test_signal1) + derivative = elephant.signal_processing.derivative(self.test_signal1) self.assertTrue(isinstance(derivative, neo.AnalogSignal)) # test that sampling period is correct self.assertEqual( - derivative.sampling_period, - 1 / self.fs * self.test_signal1.times.units) + derivative.sampling_period, 1 / self.fs * self.test_signal1.times.units + ) # test that all times are correct - target_times = self.times[:-1] * self.test_signal1.times.units \ + target_times = ( + self.times[:-1] * self.test_signal1.times.units + derivative.sampling_period / 2 + ) assert_array_almost_equal(derivative.times, target_times) # test that t_start and t_stop are correct self.assertEqual(derivative.t_start, target_times[0]) assert_array_almost_equal( - derivative.t_stop, - target_times[-1] + derivative.sampling_period) + derivative.t_stop, target_times[-1] + derivative.sampling_period + ) 
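# --- Editor's illustrative sketch (not part of the patch) --------------------
# The DerivativeTestCase hunks above check units, times and values of
# elephant.signal_processing.derivative. A minimal usage example, assuming
# only the call signature and behaviour visible in these tests: a
# neo.AnalogSignal goes in, an AnalogSignal in signal-units-per-time-unit
# comes out, one sample shorter and centred between the original samples.
import numpy as np
import quantities as pq
import neo
from elephant.signal_processing import derivative

fs = 1000.0                                        # Hz, assumed example rate
t = np.arange(0.0, 1.0, 1.0 / fs)
sig = neo.AnalogSignal(
    np.cos(2.0 * np.pi * t) * pq.mV, sampling_period=(1.0 / fs) * pq.s
)
dsig = derivative(sig)                             # units: mV / s
print(dsig.shape, dsig.units, dsig.sampling_period)
# ------------------------------------------------------------------------------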
def test_derivative_values(self): """Test derivative returns AnalogSignal with correct values""" - derivative1 = elephant.signal_processing.derivative( - self.test_signal1) - derivative2 = elephant.signal_processing.derivative( - self.test_signal2) + derivative1 = elephant.signal_processing.derivative(self.test_signal1) + derivative2 = elephant.signal_processing.derivative(self.test_signal2) self.assertTrue(isinstance(derivative1, neo.AnalogSignal)) self.assertTrue(isinstance(derivative2, neo.AnalogSignal)) # single channel assert_array_almost_equal( derivative1.magnitude, - np.vstack([np.diff(self.test_data1)]).T / (1 / self.fs)) + np.vstack([np.diff(self.test_data1)]).T / (1 / self.fs), + ) # multi channel - assert_array_almost_equal(derivative2.magnitude, np.vstack([ - np.diff(self.test_data2[:, 0]), - np.diff(self.test_data2[:, 1])]).T / (1 / self.fs)) + assert_array_almost_equal( + derivative2.magnitude, + np.vstack( + [np.diff(self.test_data2[:, 0]), np.diff(self.test_data2[:, 1])] + ).T + / (1 / self.fs), + ) class RAUCTestCase(unittest.TestCase): - def setUp(self): self.fs = 1000.0 self.tmin = 0.0 @@ -1041,231 +1339,225 @@ def setUp(self): self.times = np.arange(self.tmin, self.tmax, 1 / self.fs) self.test_data1 = np.cos(2 * np.pi * self.times) self.test_data2 = np.vstack( - [np.cos(2 * np.pi * self.times), np.sin(2 * np.pi * self.times)]).T + [np.cos(2 * np.pi * self.times), np.sin(2 * np.pi * self.times)] + ).T self.test_signal1 = neo.AnalogSignal( - self.test_data1 * pq.mV, t_start=self.times[0] * pq.s, - t_stop=self.times[-1] * pq.s, sampling_period=(1 / self.fs) * pq.s) + self.test_data1 * pq.mV, + t_start=self.times[0] * pq.s, + t_stop=self.times[-1] * pq.s, + sampling_period=(1 / self.fs) * pq.s, + ) self.test_signal2 = neo.AnalogSignal( - self.test_data2 * pq.mV, t_start=self.times[0] * pq.s, - t_stop=self.times[-1] * pq.s, sampling_period=(1 / self.fs) * pq.s) + self.test_data2 * pq.mV, + t_start=self.times[0] * pq.s, + t_stop=self.times[-1] * pq.s, + sampling_period=(1 / self.fs) * pq.s, + ) def test_rauc_invalid_signal(self): """Test rauc on non-AnalogSignal""" - kwds = {'signal': np.arange(5)} - self.assertRaises( - ValueError, elephant.signal_processing.rauc, **kwds) + kwds = {"signal": np.arange(5)} + self.assertRaises(ValueError, elephant.signal_processing.rauc, **kwds) def test_rauc_invalid_bin_duration(self): """Test rauc on bad bin duration""" - kwds = {'signal': self.test_signal1, 'bin_duration': 'bad'} - self.assertRaises( - ValueError, elephant.signal_processing.rauc, **kwds) + kwds = {"signal": self.test_signal1, "bin_duration": "bad"} + self.assertRaises(ValueError, elephant.signal_processing.rauc, **kwds) def test_rauc_invalid_baseline(self): """Test rauc on bad baseline""" - kwds = {'signal': self.test_signal1, 'baseline': 'bad'} - self.assertRaises( - ValueError, elephant.signal_processing.rauc, **kwds) + kwds = {"signal": self.test_signal1, "baseline": "bad"} + self.assertRaises(ValueError, elephant.signal_processing.rauc, **kwds) def test_rauc_units(self): """Test rauc returns Quantity or AnalogSignal with correct units""" # test that single-bin result is Quantity with correct units - rauc = elephant.signal_processing.rauc( - self.test_signal1) + rauc = elephant.signal_processing.rauc(self.test_signal1) self.assertTrue(isinstance(rauc, pq.Quantity)) self.assertEqual( - rauc.units, - self.test_signal1.units * self.test_signal1.times.units) + rauc.units, self.test_signal1.units * self.test_signal1.times.units + ) # test that multi-bin result is 
AnalogSignal with correct units rauc_arr = elephant.signal_processing.rauc( - self.test_signal1, bin_duration=1 * pq.s) + self.test_signal1, bin_duration=1 * pq.s + ) self.assertTrue(isinstance(rauc_arr, neo.AnalogSignal)) self.assertEqual( - rauc_arr.units, - self.test_signal1.units * self.test_signal1.times.units) + rauc_arr.units, self.test_signal1.units * self.test_signal1.times.units + ) def test_rauc_times_without_overextending_bin(self): """Test rauc returns correct times when signal is binned evenly""" bin_duration = 1 * pq.s # results in all bin centers < original t_stop rauc_arr = elephant.signal_processing.rauc( - self.test_signal1, bin_duration=bin_duration) + self.test_signal1, bin_duration=bin_duration + ) self.assertTrue(isinstance(rauc_arr, neo.AnalogSignal)) # test that sampling period is correct self.assertEqual(rauc_arr.sampling_period, bin_duration) # test that all times are correct - target_times = np.arange(self.tmin, - self.tmax, - bin_duration.magnitude) \ - * bin_duration.units + bin_duration / 2 + target_times = ( + np.arange(self.tmin, self.tmax, bin_duration.magnitude) * bin_duration.units + + bin_duration / 2 + ) assert_array_almost_equal(rauc_arr.times, target_times) # test that t_start and t_stop are correct self.assertEqual(rauc_arr.t_start, target_times[0]) - assert_array_almost_equal( - rauc_arr.t_stop, - target_times[-1] + bin_duration) + assert_array_almost_equal(rauc_arr.t_stop, target_times[-1] + bin_duration) def test_rauc_times_with_overextending_bin(self): """Test rauc returns correct times when signal is NOT binned evenly""" # results in one bin center > original t_stop bin_duration = 0.99 * pq.s rauc_arr = elephant.signal_processing.rauc( - self.test_signal1, bin_duration=bin_duration) + self.test_signal1, bin_duration=bin_duration + ) self.assertTrue(isinstance(rauc_arr, neo.AnalogSignal)) # test that sampling period is correct self.assertEqual(rauc_arr.sampling_period, bin_duration) # test that all times are correct - target_times = np.arange(self.tmin, - self.tmax, - bin_duration.magnitude) \ - * bin_duration.units + bin_duration / 2 + target_times = ( + np.arange(self.tmin, self.tmax, bin_duration.magnitude) * bin_duration.units + + bin_duration / 2 + ) assert_array_almost_equal(rauc_arr.times, target_times) # test that t_start and t_stop are correct self.assertEqual(rauc_arr.t_start, target_times[0]) - assert_array_almost_equal( - rauc_arr.t_stop, - target_times[-1] + bin_duration) + assert_array_almost_equal(rauc_arr.t_stop, target_times[-1] + bin_duration) def test_rauc_values_one_bin(self): """Test rauc returns correct values when there is just one bin""" - rauc1 = elephant.signal_processing.rauc( - self.test_signal1) - rauc2 = elephant.signal_processing.rauc( - self.test_signal2) + rauc1 = elephant.signal_processing.rauc(self.test_signal1) + rauc2 = elephant.signal_processing.rauc(self.test_signal2) self.assertTrue(isinstance(rauc1, pq.Quantity)) self.assertTrue(isinstance(rauc2, pq.Quantity)) # single channel - assert_array_almost_equal( - rauc1.magnitude, - np.array([6.36517679])) + assert_array_almost_equal(rauc1.magnitude, np.array([6.36517679])) # multi channel - assert_array_almost_equal( - rauc2.magnitude, - np.array([6.36517679, 6.36617364])) + assert_array_almost_equal(rauc2.magnitude, np.array([6.36517679, 6.36617364])) def test_rauc_values_multi_bin(self): """Test rauc returns correct values when there are multiple bins""" rauc_arr1 = elephant.signal_processing.rauc( - self.test_signal1, bin_duration=0.99 * pq.s) + 
self.test_signal1, bin_duration=0.99 * pq.s + ) rauc_arr2 = elephant.signal_processing.rauc( - self.test_signal2, bin_duration=0.99 * pq.s) + self.test_signal2, bin_duration=0.99 * pq.s + ) self.assertTrue(isinstance(rauc_arr1, neo.AnalogSignal)) self.assertTrue(isinstance(rauc_arr2, neo.AnalogSignal)) # single channel - assert_array_almost_equal(rauc_arr1.magnitude, np.array([ - [0.62562647], - [0.62567202], - [0.62576076], - [0.62589236], - [0.62606628], - [0.62628184], - [0.62653819], - [0.62683432], - [0.62716907], - [0.62754110], - [0.09304862]])) + assert_array_almost_equal( + rauc_arr1.magnitude, + np.array( + [ + [0.62562647], + [0.62567202], + [0.62576076], + [0.62589236], + [0.62606628], + [0.62628184], + [0.62653819], + [0.62683432], + [0.62716907], + [0.62754110], + [0.09304862], + ] + ), + ) # multi channel - assert_array_almost_equal(rauc_arr2.magnitude, np.array([ - [0.62562647, 0.63623770], - [0.62567202, 0.63554830], - [0.62576076, 0.63486313], - [0.62589236, 0.63418488], - [0.62606628, 0.63351623], - [0.62628184, 0.63285983], - [0.62653819, 0.63221825], - [0.62683432, 0.63159403], - [0.62716907, 0.63098964], - [0.62754110, 0.63040747], - [0.09304862, 0.03039579]])) + assert_array_almost_equal( + rauc_arr2.magnitude, + np.array( + [ + [0.62562647, 0.63623770], + [0.62567202, 0.63554830], + [0.62576076, 0.63486313], + [0.62589236, 0.63418488], + [0.62606628, 0.63351623], + [0.62628184, 0.63285983], + [0.62653819, 0.63221825], + [0.62683432, 0.63159403], + [0.62716907, 0.63098964], + [0.62754110, 0.63040747], + [0.09304862, 0.03039579], + ] + ), + ) def test_rauc_mean_baseline(self): """Test rauc returns correct values when baseline='mean' is given""" - rauc1 = elephant.signal_processing.rauc( - self.test_signal1, baseline='mean') - rauc2 = elephant.signal_processing.rauc( - self.test_signal2, baseline='mean') + rauc1 = elephant.signal_processing.rauc(self.test_signal1, baseline="mean") + rauc2 = elephant.signal_processing.rauc(self.test_signal2, baseline="mean") self.assertTrue(isinstance(rauc1, pq.Quantity)) self.assertTrue(isinstance(rauc2, pq.Quantity)) # single channel - assert_array_almost_equal( - rauc1.magnitude, - np.array([6.36517679])) + assert_array_almost_equal(rauc1.magnitude, np.array([6.36517679])) # multi channel - assert_array_almost_equal( - rauc2.magnitude, - np.array([6.36517679, 6.36617364])) + assert_array_almost_equal(rauc2.magnitude, np.array([6.36517679, 6.36617364])) def test_rauc_median_baseline(self): """Test rauc returns correct values when baseline='median' is given""" - rauc1 = elephant.signal_processing.rauc( - self.test_signal1, baseline='median') - rauc2 = elephant.signal_processing.rauc( - self.test_signal2, baseline='median') + rauc1 = elephant.signal_processing.rauc(self.test_signal1, baseline="median") + rauc2 = elephant.signal_processing.rauc(self.test_signal2, baseline="median") self.assertTrue(isinstance(rauc1, pq.Quantity)) self.assertTrue(isinstance(rauc2, pq.Quantity)) # single channel - assert_array_almost_equal( - rauc1.magnitude, - np.array([6.36517679])) + assert_array_almost_equal(rauc1.magnitude, np.array([6.36517679])) # multi channel - assert_array_almost_equal( - rauc2.magnitude, - np.array([6.36517679, 6.36617364])) + assert_array_almost_equal(rauc2.magnitude, np.array([6.36517679, 6.36617364])) def test_rauc_arbitrary_baseline(self): """Test rauc returns correct values when arbitrary baseline is given""" rauc1 = elephant.signal_processing.rauc( - self.test_signal1, baseline=0.123 * pq.mV) + self.test_signal1, 
baseline=0.123 * pq.mV + ) rauc2 = elephant.signal_processing.rauc( - self.test_signal2, baseline=0.123 * pq.mV) + self.test_signal2, baseline=0.123 * pq.mV + ) self.assertTrue(isinstance(rauc1, pq.Quantity)) self.assertTrue(isinstance(rauc2, pq.Quantity)) # single channel - assert_array_almost_equal( - rauc1.magnitude, - np.array([6.41354725])) + assert_array_almost_equal(rauc1.magnitude, np.array([6.41354725])) # multi channel - assert_array_almost_equal( - rauc2.magnitude, - np.array([6.41354725, 6.41429810])) + assert_array_almost_equal(rauc2.magnitude, np.array([6.41354725, 6.41429810])) def test_rauc_time_slice(self): """Test rauc returns correct values when t_start, t_stop are given""" rauc1 = elephant.signal_processing.rauc( - self.test_signal1, t_start=0.123 * pq.s, t_stop=0.456 * pq.s) + self.test_signal1, t_start=0.123 * pq.s, t_stop=0.456 * pq.s + ) rauc2 = elephant.signal_processing.rauc( - self.test_signal2, t_start=0.123 * pq.s, t_stop=0.456 * pq.s) + self.test_signal2, t_start=0.123 * pq.s, t_stop=0.456 * pq.s + ) self.assertTrue(isinstance(rauc1, pq.Quantity)) self.assertTrue(isinstance(rauc2, pq.Quantity)) # single channel - assert_array_almost_equal( - rauc1.magnitude, - np.array([0.16279006])) + assert_array_almost_equal(rauc1.magnitude, np.array([0.16279006])) # multi channel - assert_array_almost_equal( - rauc2.magnitude, - np.array([0.16279006, 0.26677944])) + assert_array_almost_equal(rauc2.magnitude, np.array([0.16279006, 0.26677944])) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_spade.py b/elephant/test/test_spade.py index 4a9579393..5ac3143ab 100644 --- a/elephant/test/test_spade.py +++ b/elephant/test/test_spade.py @@ -4,6 +4,7 @@ :copyright: Copyright 2014-2024 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. 
""" + import unittest import neo @@ -14,8 +15,10 @@ import elephant.conversion as conv import elephant.spade as spade -from elephant.spike_train_generation import StationaryPoissonProcess, \ - compound_poisson_process +from elephant.spike_train_generation import ( + StationaryPoissonProcess, + compound_poisson_process, +) try: import statsmodels @@ -48,9 +51,8 @@ def setUp(self): self.n_neu = 100 self.amplitude = [0] * self.n_neu + [1] self.cpp = compound_poisson_process( - rate=3 * pq.Hz, - amplitude_distribution=self.amplitude, - t_stop=5 * pq.s) + rate=3 * pq.Hz, amplitude_distribution=self.amplitude, t_stop=5 * pq.s + ) # Number of patterns' occurrences self.n_occ1 = 10 self.n_occ2 = 12 @@ -63,27 +65,38 @@ def setUp(self): self.t_stop = 3000 # Patterns times self.patt1_times = neo.SpikeTrain( - np.arange( - 0, 1000, 1000 // self.n_occ1) * - pq.ms, t_stop=self.t_stop * pq.ms) + np.arange(0, 1000, 1000 // self.n_occ1) * pq.ms, t_stop=self.t_stop * pq.ms + ) self.patt2_times = neo.SpikeTrain( - np.arange( - 1000, 2000, 1000 // self.n_occ2)[:-1] * - pq.ms, t_stop=self.t_stop * pq.ms) + np.arange(1000, 2000, 1000 // self.n_occ2)[:-1] * pq.ms, + t_stop=self.t_stop * pq.ms, + ) self.patt3_times = neo.SpikeTrain( - np.arange( - 2000, 3000, 1000 // self.n_occ3)[:-1] * - pq.ms, t_stop=self.t_stop * pq.ms) + np.arange(2000, 3000, 1000 // self.n_occ3)[:-1] * pq.ms, + t_stop=self.t_stop * pq.ms, + ) # Patterns - self.patt1 = [self.patt1_times] + [neo.SpikeTrain( - self.patt1_times.view(pq.Quantity) + lag * pq.ms, - t_stop=self.t_stop * pq.ms) for lag in self.lags1] - self.patt2 = [self.patt2_times] + [neo.SpikeTrain( - self.patt2_times.view(pq.Quantity) + lag * pq.ms, - t_stop=self.t_stop * pq.ms) for lag in self.lags2] - self.patt3 = [self.patt3_times] + [neo.SpikeTrain( - self.patt3_times.view(pq.Quantity) + lag * pq.ms, - t_stop=self.t_stop * pq.ms) for lag in self.lags3] + self.patt1 = [self.patt1_times] + [ + neo.SpikeTrain( + self.patt1_times.view(pq.Quantity) + lag * pq.ms, + t_stop=self.t_stop * pq.ms, + ) + for lag in self.lags1 + ] + self.patt2 = [self.patt2_times] + [ + neo.SpikeTrain( + self.patt2_times.view(pq.Quantity) + lag * pq.ms, + t_stop=self.t_stop * pq.ms, + ) + for lag in self.lags2 + ] + self.patt3 = [self.patt3_times] + [ + neo.SpikeTrain( + self.patt3_times.view(pq.Quantity) + lag * pq.ms, + t_stop=self.t_stop * pq.ms, + ) + for lag in self.lags3 + ] # Data self.msip = self.patt1 + self.patt2 + self.patt3 # Expected results @@ -95,46 +108,48 @@ def setUp(self): self.elements3 = list(range(self.n_spk3)) self.elements_msip = [ self.elements1, + list(range(self.n_spk1, self.n_spk1 + self.n_spk2)), list( range( - self.n_spk1, - self.n_spk1 + - self.n_spk2)), - list( - range( - self.n_spk1 + - self.n_spk2, - self.n_spk1 + - self.n_spk2 + - self.n_spk3))] - self.occ1 = np.unique(conv.BinnedSpikeTrain( - self.patt1_times, self.bin_size).spike_indices[0]) - self.occ2 = np.unique(conv.BinnedSpikeTrain( - self.patt2_times, self.bin_size).spike_indices[0]) - self.occ3 = np.unique(conv.BinnedSpikeTrain( - self.patt3_times, self.bin_size).spike_indices[0]) - self.occ_msip = [ - list(self.occ1), list(self.occ2), list(self.occ3)] + self.n_spk1 + self.n_spk2, self.n_spk1 + self.n_spk2 + self.n_spk3 + ) + ), + ] + self.occ1 = np.unique( + conv.BinnedSpikeTrain(self.patt1_times, self.bin_size).spike_indices[0] + ) + self.occ2 = np.unique( + conv.BinnedSpikeTrain(self.patt2_times, self.bin_size).spike_indices[0] + ) + self.occ3 = np.unique( + conv.BinnedSpikeTrain(self.patt3_times, 
self.bin_size).spike_indices[0] + ) + self.occ_msip = [list(self.occ1), list(self.occ2), list(self.occ3)] self.lags_msip = [self.lags1, self.lags2, self.lags3] self.patt_psr = self.patt3 + [self.patt3[-1][:3]] # Testing cpp @unittest.skipUnless(HAVE_FIM, "Time consuming with pythonic FIM") def test_spade_cpp(self): - output_cpp = spade.spade(self.cpp, self.bin_size, 1, - approx_stab_pars=dict( - n_subsets=self.n_subset, - stability_thresh=self.stability_thresh), - n_surr=self.n_surr, alpha=self.alpha, - psr_param=self.psr_param, - stat_corr='no', - output_format='patterns')['patterns'] + output_cpp = spade.spade( + self.cpp, + self.bin_size, + 1, + approx_stab_pars=dict( + n_subsets=self.n_subset, stability_thresh=self.stability_thresh + ), + n_surr=self.n_surr, + alpha=self.alpha, + psr_param=self.psr_param, + stat_corr="no", + output_format="patterns", + )["patterns"] elements_cpp = [] lags_cpp = [] # collecting spade output for out in output_cpp: - elements_cpp.append(sorted(out['neurons'])) - lags_cpp.append(list(out['lags'].magnitude)) + elements_cpp.append(sorted(out["neurons"])) + lags_cpp.append(list(out["lags"].magnitude)) # check neurons in the patterns assert_array_equal(elements_cpp, [range(self.n_neu)]) # check the lags @@ -143,33 +158,46 @@ def test_spade_cpp(self): # Testing spectrum cpp def test_spade_spectrum_cpp(self): # Computing Spectrum - spectrum_cpp = spade.concepts_mining(self.cpp, self.bin_size, - 1, report='#')[0] + spectrum_cpp = spade.concepts_mining(self.cpp, self.bin_size, 1, report="#")[0] # Check spectrum assert_array_equal( spectrum_cpp, - [(len(self.cpp), - np.sum(conv.BinnedSpikeTrain( - self.cpp[0], self.bin_size).to_bool_array()), 1)]) + [ + ( + len(self.cpp), + np.sum( + conv.BinnedSpikeTrain( + self.cpp[0], self.bin_size + ).to_bool_array() + ), + 1, + ) + ], + ) # Testing with multiple patterns input def test_spade_msip(self): - output_msip = spade.spade(self.msip, self.bin_size, self.winlen, - approx_stab_pars=dict( - n_subsets=self.n_subset, - stability_thresh=self.stability_thresh), - n_surr=self.n_surr, alpha=self.alpha, - psr_param=self.psr_param, - stat_corr='no', - output_format='patterns')['patterns'] + output_msip = spade.spade( + self.msip, + self.bin_size, + self.winlen, + approx_stab_pars=dict( + n_subsets=self.n_subset, stability_thresh=self.stability_thresh + ), + n_surr=self.n_surr, + alpha=self.alpha, + psr_param=self.psr_param, + stat_corr="no", + output_format="patterns", + )["patterns"] elements_msip = [] occ_msip = [] lags_msip = [] # collecting spade output for out in output_msip: - elements_msip.append(out['neurons']) - occ_msip.append(list(out['times'].magnitude)) - lags_msip.append(list(out['lags'].magnitude)) + elements_msip.append(out["neurons"]) + occ_msip.append(list(out["times"].magnitude)) + lags_msip.append(list(out["lags"].magnitude)) elements_msip = sorted(elements_msip, key=len) occ_msip = sorted(occ_msip, key=len) lags_msip = sorted(lags_msip, key=len) @@ -183,22 +211,26 @@ def test_spade_msip(self): # Testing with multiple patterns input def test_spade_msip_spiketrainlist(self): output_msip = spade.spade( - SpikeTrainList(self.msip), self.bin_size, self.winlen, - approx_stab_pars=dict(n_subsets=self.n_subset, - stability_thresh=self.stability_thresh), + SpikeTrainList(self.msip), + self.bin_size, + self.winlen, + approx_stab_pars=dict( + n_subsets=self.n_subset, stability_thresh=self.stability_thresh + ), n_surr=self.n_surr, alpha=self.alpha, psr_param=self.psr_param, - stat_corr='no', - 
output_format='patterns')['patterns'] + stat_corr="no", + output_format="patterns", + )["patterns"] elements_msip = [] occ_msip = [] lags_msip = [] # collecting spade output for out in output_msip: - elements_msip.append(out['neurons']) - occ_msip.append(list(out['times'].magnitude)) - lags_msip.append(list(out['lags'].magnitude)) + elements_msip.append(out["neurons"]) + occ_msip.append(list(out["times"].magnitude)) + lags_msip.append(list(out["lags"].magnitude)) elements_msip = sorted(elements_msip, key=len) occ_msip = sorted(occ_msip, key=len) lags_msip = sorted(lags_msip, key=len) @@ -225,27 +257,33 @@ def test_parameters(self): n_surr=0, alpha=self.alpha, psr_param=self.psr_param, - stat_corr='no', - output_format='patterns')['patterns'] + stat_corr="no", + output_format="patterns", + )["patterns"] # collecting spade output elements_msip_min_spikes = [] for out in output_msip_min_spikes: - elements_msip_min_spikes.append(out['neurons']) - elements_msip_min_spikes = sorted( - elements_msip_min_spikes, key=len) + elements_msip_min_spikes.append(out["neurons"]) + elements_msip_min_spikes = sorted(elements_msip_min_spikes, key=len) lags_msip_min_spikes = [] for out in output_msip_min_spikes: - lags_msip_min_spikes.append(list(out['lags'].magnitude)) - pvalue = out['pvalue'] - lags_msip_min_spikes = sorted( - lags_msip_min_spikes, key=len) + lags_msip_min_spikes.append(list(out["lags"].magnitude)) + pvalue = out["pvalue"] + lags_msip_min_spikes = sorted(lags_msip_min_spikes, key=len) # check the lags - assert_array_equal(lags_msip_min_spikes, [ - lag for lag in self.lags_msip if len(lag) + 1 >= self.min_spikes]) + assert_array_equal( + lags_msip_min_spikes, + [lag for lag in self.lags_msip if len(lag) + 1 >= self.min_spikes], + ) # check the neurons in the patterns - assert_array_equal(elements_msip_min_spikes, [ - el for el in self.elements_msip if len(el) >= self.min_neu and len( - el) >= self.min_spikes]) + assert_array_equal( + elements_msip_min_spikes, + [ + el + for el in self.elements_msip + if len(el) >= self.min_neu and len(el) >= self.min_spikes + ], + ) # check that the p-values assigned are equal to -1 (n_surr=0) assert_array_equal(-1, pvalue) @@ -255,21 +293,22 @@ def test_parameters(self): self.bin_size, self.winlen, min_occ=self.min_occ, - approx_stab_pars=dict( - n_subsets=self.n_subset), + approx_stab_pars=dict(n_subsets=self.n_subset), n_surr=self.n_surr, alpha=self.alpha, psr_param=self.psr_param, - stat_corr='no', - output_format='patterns')['patterns'] + stat_corr="no", + output_format="patterns", + )["patterns"] # collect spade output occ_msip_min_occ = [] for out in output_msip_min_occ: - occ_msip_min_occ.append(list(out['times'].magnitude)) + occ_msip_min_occ.append(list(out["times"].magnitude)) occ_msip_min_occ = sorted(occ_msip_min_occ, key=len) # test occurrences time - assert_equal(occ_msip_min_occ, [ - occ for occ in self.occ_msip if len(occ) >= self.min_occ]) + assert_equal( + occ_msip_min_occ, [occ for occ in self.occ_msip if len(occ) >= self.min_occ] + ) # test max_spikes parameter output_msip_max_spikes = spade.spade( @@ -277,27 +316,26 @@ def test_parameters(self): self.bin_size, self.winlen, max_spikes=self.max_spikes, - approx_stab_pars=dict( - n_subsets=self.n_subset), + approx_stab_pars=dict(n_subsets=self.n_subset), n_surr=self.n_surr, alpha=self.alpha, psr_param=self.psr_param, - stat_corr='no', - output_format='patterns')['patterns'] + stat_corr="no", + output_format="patterns", + )["patterns"] # collecting spade output elements_msip_max_spikes = [] 
for out in output_msip_max_spikes: - elements_msip_max_spikes.append(out['neurons']) + elements_msip_max_spikes.append(out["neurons"]) lags_msip_max_spikes = [] for out in output_msip_max_spikes: - lags_msip_max_spikes.append(list(out['lags'].magnitude)) - lags_msip_max_spikes = sorted( - lags_msip_max_spikes, key=len) + lags_msip_max_spikes.append(list(out["lags"].magnitude)) + lags_msip_max_spikes = sorted(lags_msip_max_spikes, key=len) # check the lags assert_array_equal( - [len(lags) < self.max_spikes - for lags in lags_msip_max_spikes], - [True] * len(lags_msip_max_spikes)) + [len(lags) < self.max_spikes for lags in lags_msip_max_spikes], + [True] * len(lags_msip_max_spikes), + ) # TODO: does not work with new FIM module # test max_occ parameter @@ -324,37 +362,39 @@ def test_parameters(self): # test to compare the python and the C implementation of FIM # skip this test if C code not available - @unittest.skipIf(not HAVE_FIM, 'Requires fim.so') + @unittest.skipIf(not HAVE_FIM, "Requires fim.so") def test_fpgrowth_fca(self): print("fim.so is found.") - binary_matrix = conv.BinnedSpikeTrain( - self.patt1, self.bin_size).to_sparse_bool_array().tocoo() + binary_matrix = ( + conv.BinnedSpikeTrain(self.patt1, self.bin_size) + .to_sparse_bool_array() + .tocoo() + ) context, transactions, rel_matrix = spade._build_context( - binary_matrix, self.winlen) + binary_matrix, self.winlen + ) # mining the data with python fast_fca - mining_results_fpg = spade._fpgrowth( - transactions, - rel_matrix=rel_matrix) - print( - '################################################################') - print('mining results fpg', mining_results_fpg) + mining_results_fpg = spade._fpgrowth(transactions, rel_matrix=rel_matrix) + print("################################################################") + print("mining results fpg", mining_results_fpg) # mining the data with C fim mining_results_ffca = spade._fast_fca(context) # testing that the outputs are identical - assert_array_equal(sorted(mining_results_ffca[0][0]), sorted( - mining_results_fpg[0][0])) - assert_array_equal(sorted(mining_results_ffca[0][1]), sorted( - mining_results_fpg[0][1])) + assert_array_equal( + sorted(mining_results_ffca[0][0]), sorted(mining_results_fpg[0][0]) + ) + assert_array_equal( + sorted(mining_results_ffca[0][1]), sorted(mining_results_fpg[0][1]) + ) def test_spade_output_for_specific_pattern(self): - np.random.seed(0) n_spiketrains = 3 spiketrains = StationaryPoissonProcess( - rate=20 * pq.Hz, t_stop=5 * pq.s).generate_n_spiketrains( - n_spiketrains=n_spiketrains) + rate=20 * pq.Hz, t_stop=5 * pq.s + ).generate_n_spiketrains(n_spiketrains=n_spiketrains) bin_size = 5 * pq.ms spade_output = spade.spade( @@ -364,35 +404,42 @@ def test_spade_output_for_specific_pattern(self): min_spikes=3, max_spikes=3, min_neu=n_spiketrains, - spectrum='3d#') + spectrum="3d#", + ) - patterns = spade_output['patterns'] - self.assertEqual (len(patterns), 1) + patterns = spade_output["patterns"] + self.assertEqual(len(patterns), 1) pattern = patterns[0] - self.assertEqual(len(np.unique(pattern['neurons'])), n_spiketrains) - assert_array_equal(sorted(pattern['itemset']), [1, 2, 5]) - assert_array_equal(sorted(pattern['windows_ids']), [36, 134, 856]) - assert_array_equal(pattern['signature'], [3, 3, 1]) + self.assertEqual(len(np.unique(pattern["neurons"])), n_spiketrains) + assert_array_equal(sorted(pattern["itemset"]), [1, 2, 5]) + assert_array_equal(sorted(pattern["windows_ids"]), [36, 134, 856]) + assert_array_equal(pattern["signature"], [3, 3, 1]) # 
Tests 3d spectrum # Testing with multiple patterns input def test_spade_msip_3d(self): - output_msip = spade.spade(self.msip, self.bin_size, self.winlen, - approx_stab_pars=dict( - n_subsets=self.n_subset, - stability_thresh=self.stability_thresh), - n_surr=self.n_surr, spectrum='3d#', - alpha=self.alpha, psr_param=self.psr_param, - stat_corr='no', - output_format='patterns')['patterns'] + output_msip = spade.spade( + self.msip, + self.bin_size, + self.winlen, + approx_stab_pars=dict( + n_subsets=self.n_subset, stability_thresh=self.stability_thresh + ), + n_surr=self.n_surr, + spectrum="3d#", + alpha=self.alpha, + psr_param=self.psr_param, + stat_corr="no", + output_format="patterns", + )["patterns"] elements_msip = [] occ_msip = [] lags_msip = [] # collecting spade output for out in output_msip: - elements_msip.append(out['neurons']) - occ_msip.append(list(out['times'].magnitude)) - lags_msip.append(list(out['lags'].magnitude)) + elements_msip.append(out["neurons"]) + occ_msip.append(list(out["times"].magnitude)) + lags_msip.append(list(out["lags"].magnitude)) elements_msip = sorted(elements_msip, key=len) occ_msip = sorted(occ_msip, key=len) lags_msip = sorted(lags_msip, key=len) @@ -411,32 +458,37 @@ def test_parameters_3d(self): self.bin_size, self.winlen, min_spikes=self.min_spikes, - approx_stab_pars=dict( - n_subsets=self.n_subset), + approx_stab_pars=dict(n_subsets=self.n_subset), n_surr=self.n_surr, - spectrum='3d#', + spectrum="3d#", alpha=self.alpha, psr_param=self.psr_param, - stat_corr='no', - output_format='patterns')['patterns'] + stat_corr="no", + output_format="patterns", + )["patterns"] # collecting spade output elements_msip_min_spikes = [] for out in output_msip_min_spikes: - elements_msip_min_spikes.append(out['neurons']) - elements_msip_min_spikes = sorted( - elements_msip_min_spikes, key=len) + elements_msip_min_spikes.append(out["neurons"]) + elements_msip_min_spikes = sorted(elements_msip_min_spikes, key=len) lags_msip_min_spikes = [] for out in output_msip_min_spikes: - lags_msip_min_spikes.append(list(out['lags'].magnitude)) - lags_msip_min_spikes = sorted( - lags_msip_min_spikes, key=len) + lags_msip_min_spikes.append(list(out["lags"].magnitude)) + lags_msip_min_spikes = sorted(lags_msip_min_spikes, key=len) # check the lags - assert_array_equal(lags_msip_min_spikes, [ - lag for lag in self.lags_msip if len(lag) + 1 >= self.min_spikes]) + assert_array_equal( + lags_msip_min_spikes, + [lag for lag in self.lags_msip if len(lag) + 1 >= self.min_spikes], + ) # check the neurons in the patterns - assert_array_equal(elements_msip_min_spikes, [ - el for el in self.elements_msip if len(el) >= self.min_neu and len( - el) >= self.min_spikes]) + assert_array_equal( + elements_msip_min_spikes, + [ + el + for el in self.elements_msip + if len(el) >= self.min_neu and len(el) >= self.min_spikes + ], + ) # test min_occ parameter output_msip_min_occ = spade.spade( @@ -444,197 +496,372 @@ def test_parameters_3d(self): self.bin_size, self.winlen, min_occ=self.min_occ, - approx_stab_pars=dict( - n_subsets=self.n_subset), + approx_stab_pars=dict(n_subsets=self.n_subset), n_surr=self.n_surr, - spectrum='3d#', + spectrum="3d#", alpha=self.alpha, psr_param=self.psr_param, - stat_corr='no', - output_format='patterns')['patterns'] + stat_corr="no", + output_format="patterns", + )["patterns"] # collect spade output occ_msip_min_occ = [] for out in output_msip_min_occ: - occ_msip_min_occ.append(list(out['times'].magnitude)) + occ_msip_min_occ.append(list(out["times"].magnitude)) 
occ_msip_min_occ = sorted(occ_msip_min_occ, key=len) # test occurrences time - assert_equal(occ_msip_min_occ, [ - occ for occ in self.occ_msip if len(occ) >= self.min_occ]) + assert_equal( + occ_msip_min_occ, [occ for occ in self.occ_msip if len(occ) >= self.min_occ] + ) # Test computation spectrum def test_spectrum(self): # test 2d spectrum - spectrum = spade.concepts_mining(self.patt1, self.bin_size, - self.winlen, report='#')[0] + spectrum = spade.concepts_mining( + self.patt1, self.bin_size, self.winlen, report="#" + )[0] # test 3d spectrum assert_array_equal(spectrum, [[len(self.lags1) + 1, self.n_occ1, 1]]) - spectrum_3d = spade.concepts_mining(self.patt1, self.bin_size, - self.winlen, report='3d#')[0] - assert_array_equal(spectrum_3d, [ - [len(self.lags1) + 1, self.n_occ1, max(self.lags1), 1]]) + spectrum_3d = spade.concepts_mining( + self.patt1, self.bin_size, self.winlen, report="3d#" + )[0] + assert_array_equal( + spectrum_3d, [[len(self.lags1) + 1, self.n_occ1, max(self.lags1), 1]] + ) def test_spade_raise_error(self): # Test list not using neo.Spiketrain - self.assertRaises(TypeError, spade.spade, [ - [1, 2, 3], [3, 4, 5]], 1 * pq.ms, 4, stat_corr='no') - self.assertRaises(TypeError, spade.concepts_mining, [ - [1, 2, 3], [3, 4, 5]], 1 * pq.ms, 4) + self.assertRaises( + TypeError, spade.spade, [[1, 2, 3], [3, 4, 5]], 1 * pq.ms, 4, stat_corr="no" + ) + self.assertRaises( + TypeError, spade.concepts_mining, [[1, 2, 3], [3, 4, 5]], 1 * pq.ms, 4 + ) # Test neo.Spiketrain with different t_stop self.assertRaises( - ValueError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - 1 * pq.ms, 4, stat_corr='no') + ValueError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + 1 * pq.ms, + 4, + stat_corr="no", + ) # Test bin_size not pq.Quantity self.assertRaises( - TypeError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1., winlen=4, stat_corr='no') + TypeError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0, + winlen=4, + stat_corr="no", + ) # Test winlen not int self.assertRaises( - TypeError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1. * pq.ms, winlen=4.1, stat_corr='no') + TypeError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0 * pq.ms, + winlen=4.1, + stat_corr="no", + ) # Test min_spikes not int self.assertRaises( - TypeError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1. * pq.ms, winlen=4, min_spikes=3.4, stat_corr='no') + TypeError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0 * pq.ms, + winlen=4, + min_spikes=3.4, + stat_corr="no", + ) # Test min_occ not int self.assertRaises( - TypeError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1. 
* pq.ms, winlen=4, min_occ=3.4, stat_corr='no') + TypeError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0 * pq.ms, + winlen=4, + min_occ=3.4, + stat_corr="no", + ) # Test max_spikes not int self.assertRaises( - TypeError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1. * pq.ms, winlen=4, max_spikes=3.4, stat_corr='no') + TypeError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0 * pq.ms, + winlen=4, + max_spikes=3.4, + stat_corr="no", + ) # Test max_occ not int self.assertRaises( - TypeError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1. * pq.ms, winlen=4, max_occ=3.4, stat_corr='no') + TypeError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0 * pq.ms, + winlen=4, + max_occ=3.4, + stat_corr="no", + ) # Test min_neu not int self.assertRaises( - TypeError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1. * pq.ms, winlen=4, min_neu=3.4, stat_corr='no') + TypeError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0 * pq.ms, + winlen=4, + min_neu=3.4, + stat_corr="no", + ) # Test wrong stability params self.assertRaises( - ValueError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1. * pq.ms, winlen=4, approx_stab_pars={'wrong key': 0}, - stat_corr='no') + ValueError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0 * pq.ms, + winlen=4, + approx_stab_pars={"wrong key": 0}, + stat_corr="no", + ) # Test n_surr not int self.assertRaises( - TypeError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1. * pq.ms, winlen=4, n_surr=3.4, stat_corr='no') + TypeError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0 * pq.ms, + winlen=4, + n_surr=3.4, + stat_corr="no", + ) # Test dither not pq.Quantity self.assertRaises( - TypeError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1. * pq.ms, winlen=4, n_surr=100, alpha=0.05, - dither=15., stat_corr='no') + TypeError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0 * pq.ms, + winlen=4, + n_surr=100, + alpha=0.05, + dither=15.0, + stat_corr="no", + ) # Test wrong alpha self.assertRaises( - TypeError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1. * pq.ms, winlen=4, n_surr=100, alpha='5 %', - dither=15. 
* pq.ms, stat_corr='no') + TypeError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0 * pq.ms, + winlen=4, + n_surr=100, + alpha="5 %", + dither=15.0 * pq.ms, + stat_corr="no", + ) # Test wrong statistical correction self.assertRaises( - ValueError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1. * pq.ms, winlen=4, n_surr=100, alpha=0.05, - dither=15. * pq.ms, stat_corr='wrong correction') + ValueError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0 * pq.ms, + winlen=4, + n_surr=100, + alpha=0.05, + dither=15.0 * pq.ms, + stat_corr="wrong correction", + ) # Test wrong psr_params self.assertRaises( - TypeError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1. * pq.ms, winlen=4, n_surr=100, alpha=0.05, - dither=15. * pq.ms, stat_corr='no', psr_param=(2.5, 3.4, 2.1)) + TypeError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0 * pq.ms, + winlen=4, + n_surr=100, + alpha=0.05, + dither=15.0 * pq.ms, + stat_corr="no", + psr_param=(2.5, 3.4, 2.1), + ) # Test wrong psr_params self.assertRaises( - TypeError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1. * pq.ms, winlen=4, n_surr=100, alpha=0.05, - dither=15. * pq.ms, stat_corr='no', psr_param=3.1) + TypeError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0 * pq.ms, + winlen=4, + n_surr=100, + alpha=0.05, + dither=15.0 * pq.ms, + stat_corr="no", + psr_param=3.1, + ) # Test output format self.assertRaises( - ValueError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - bin_size=1. * pq.ms, winlen=4, n_surr=100, alpha=0.05, - dither=15. 
* pq.ms, stat_corr='no', output_format='wrong_output') + ValueError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + bin_size=1.0 * pq.ms, + winlen=4, + n_surr=100, + alpha=0.05, + dither=15.0 * pq.ms, + stat_corr="no", + output_format="wrong_output", + ) # Test wrong spectrum parameter self.assertRaises( - ValueError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - 1 * pq.ms, 4, n_surr=1, stat_corr='no', - spectrum='invalid_key') + ValueError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + 1 * pq.ms, + 4, + n_surr=1, + stat_corr="no", + spectrum="invalid_key", + ) self.assertRaises( - ValueError, spade.concepts_mining, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - 1 * pq.ms, 4, report='invalid_key') + ValueError, + spade.concepts_mining, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + 1 * pq.ms, + 4, + report="invalid_key", + ) self.assertRaises( - ValueError, spade.pvalue_spectrum, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s)], - 1 * pq.ms, 4, dither=10 * pq.ms, n_surr=1, - spectrum='invalid_key') + ValueError, + spade.pvalue_spectrum, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=6 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=6 * pq.s), + ], + 1 * pq.ms, + 4, + dither=10 * pq.ms, + n_surr=1, + spectrum="invalid_key", + ) # Test negative minimum number of spikes self.assertRaises( - ValueError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=5 * pq.s)], - 1 * pq.ms, 4, min_neu=-3, stat_corr='no') + ValueError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=5 * pq.s), + ], + 1 * pq.ms, + 4, + min_neu=-3, + stat_corr="no", + ) # Test wrong dither method self.assertRaises( - ValueError, spade.spade, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=5 * pq.s)], - 1 * pq.ms, 4, surr_method='invalid_key', stat_corr='no') + ValueError, + spade.spade, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=5 * pq.s), + ], + 1 * pq.ms, + 4, + surr_method="invalid_key", + stat_corr="no", + ) # Test negative number of surrogates self.assertRaises( - ValueError, spade.pvalue_spectrum, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=5 * pq.s)], - 1 * pq.ms, 4, dither=10 * pq.ms, n_surr=100, - surr_method='invalid_key') + ValueError, + spade.pvalue_spectrum, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=5 * pq.s), + ], + 1 * pq.ms, + 4, + dither=10 * pq.ms, + n_surr=100, + surr_method="invalid_key", + ) # Test negative number of surrogates self.assertRaises( - ValueError, spade.pvalue_spectrum, - [neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), - neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=5 * pq.s)], - 1 * pq.ms, 4, 3 * pq.ms, n_surr=-3) + ValueError, + spade.pvalue_spectrum, + [ + neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), + neo.SpikeTrain([3, 4, 5] * pq.s, t_stop=5 * pq.s), + ], + 1 * pq.ms, + 4, + 3 * pq.ms, + n_surr=-3, + ) # Test wrong 
correction parameter - self.assertRaises(ValueError, spade.test_signature_significance, - pv_spec=((2, 3, 0.2), (2, 4, 0.1)), - concepts=([[(2, 3), (1, 2, 3)]]), - alpha=0.01, - winlen=1, - corr='invalid_key') + self.assertRaises( + ValueError, + spade.test_signature_significance, + pv_spec=((2, 3, 0.2), (2, 4, 0.1)), + concepts=([[(2, 3), (1, 2, 3)]]), + alpha=0.01, + winlen=1, + corr="invalid_key", + ) # Test negative number of subset for stability - self.assertRaises(ValueError, spade.approximate_stability, (), - np.ndarray([]), n_subsets=-3) + self.assertRaises( + ValueError, spade.approximate_stability, (), np.ndarray([]), n_subsets=-3 + ) def test_pattern_set_reduction(self): winlen = 6 @@ -652,106 +879,142 @@ def test_pattern_set_reduction(self): # reject concept2 using min_occ # make sure to keep concept1 by setting k_superset_filtering = 1 - concepts = spade.pattern_set_reduction([concept1, concept2], - ns_signatures=[], - winlen=winlen, spectrum='#', - h_subset_filtering=0, min_occ=2, - k_superset_filtering=1) + concepts = spade.pattern_set_reduction( + [concept1, concept2], + ns_signatures=[], + winlen=winlen, + spectrum="#", + h_subset_filtering=0, + min_occ=2, + k_superset_filtering=1, + ) self.assertEqual(concepts, [concept1]) # keep concept2 by increasing h_subset_filtering - concepts = spade.pattern_set_reduction([concept1, concept2], - ns_signatures=[], - winlen=winlen, spectrum='#', - h_subset_filtering=2, min_occ=2, - k_superset_filtering=1) + concepts = spade.pattern_set_reduction( + [concept1, concept2], + ns_signatures=[], + winlen=winlen, + spectrum="#", + h_subset_filtering=2, + min_occ=2, + k_superset_filtering=1, + ) self.assertEqual(concepts, [concept1, concept2]) # reject concept1 using min_spikes - concepts = spade.pattern_set_reduction([concept1, concept2], - ns_signatures=[], - winlen=winlen, spectrum='#', - h_subset_filtering=2, - min_spikes=2, - k_superset_filtering=0) + concepts = spade.pattern_set_reduction( + [concept1, concept2], + ns_signatures=[], + winlen=winlen, + spectrum="#", + h_subset_filtering=2, + min_spikes=2, + k_superset_filtering=0, + ) self.assertEqual(concepts, [concept2]) # reject concept2 using ns_signatures - concepts = spade.pattern_set_reduction([concept1, concept2], - ns_signatures=[(2, 2)], - winlen=winlen, spectrum='#', - h_subset_filtering=1, min_occ=2, - k_superset_filtering=1) + concepts = spade.pattern_set_reduction( + [concept1, concept2], + ns_signatures=[(2, 2)], + winlen=winlen, + spectrum="#", + h_subset_filtering=1, + min_occ=2, + k_superset_filtering=1, + ) self.assertEqual(concepts, [concept1]) # reject concept1 using ns_signatures # make sure to keep concept2 by increasing h_subset_filtering - concepts = spade.pattern_set_reduction([concept1, concept2], - ns_signatures=[(2, 3)], - winlen=winlen, spectrum='#', - h_subset_filtering=3, - min_spikes=2, - min_occ=2, - k_superset_filtering=1) + concepts = spade.pattern_set_reduction( + [concept1, concept2], + ns_signatures=[(2, 3)], + winlen=winlen, + spectrum="#", + h_subset_filtering=3, + min_spikes=2, + min_occ=2, + k_superset_filtering=1, + ) self.assertEqual(concepts, [concept2]) # reject concept2 using the covered spikes criterion - concepts = spade.pattern_set_reduction([concept1, concept2], - ns_signatures=[(2, 2)], - winlen=winlen, spectrum='#', - h_subset_filtering=0, - min_occ=2, - k_superset_filtering=0, - l_covered_spikes=0) + concepts = spade.pattern_set_reduction( + [concept1, concept2], + ns_signatures=[(2, 2)], + winlen=winlen, + spectrum="#", + 
h_subset_filtering=0, + min_occ=2, + k_superset_filtering=0, + l_covered_spikes=0, + ) self.assertEqual(concepts, [concept1]) # reject concept1 using superset filtering # (case with non-empty intersection but no superset) - concepts = spade.pattern_set_reduction([concept1, concept3], - ns_signatures=[], min_spikes=2, - winlen=winlen, spectrum='#', - k_superset_filtering=0) + concepts = spade.pattern_set_reduction( + [concept1, concept3], + ns_signatures=[], + min_spikes=2, + winlen=winlen, + spectrum="#", + k_superset_filtering=0, + ) self.assertEqual(concepts, [concept3]) # keep concept1 by increasing k_superset_filtering - concepts = spade.pattern_set_reduction([concept1, concept3], - ns_signatures=[], min_spikes=2, - winlen=winlen, spectrum='#', - k_superset_filtering=1) + concepts = spade.pattern_set_reduction( + [concept1, concept3], + ns_signatures=[], + min_spikes=2, + winlen=winlen, + spectrum="#", + k_superset_filtering=1, + ) self.assertEqual(concepts, [concept1, concept3]) # reject concept3 using ns_signatures - concepts = spade.pattern_set_reduction([concept1, concept3], - ns_signatures=[(3, 2)], - min_spikes=2, - winlen=winlen, spectrum='#', - k_superset_filtering=1) + concepts = spade.pattern_set_reduction( + [concept1, concept3], + ns_signatures=[(3, 2)], + min_spikes=2, + winlen=winlen, + spectrum="#", + k_superset_filtering=1, + ) self.assertEqual(concepts, [concept1]) # reject concept3 using the covered spikes criterion - concepts = spade.pattern_set_reduction([concept1, concept3], - ns_signatures=[(3, 2), (2, 3)], - min_spikes=2, - winlen=winlen, spectrum='#', - k_superset_filtering=1, - l_covered_spikes=0) + concepts = spade.pattern_set_reduction( + [concept1, concept3], + ns_signatures=[(3, 2), (2, 3)], + min_spikes=2, + winlen=winlen, + spectrum="#", + k_superset_filtering=1, + l_covered_spikes=0, + ) self.assertEqual(concepts, [concept1]) # check that two concepts with disjoint intents are both kept - concepts = spade.pattern_set_reduction([concept3, concept4], - ns_signatures=[], - winlen=winlen, spectrum='#') + concepts = spade.pattern_set_reduction( + [concept3, concept4], ns_signatures=[], winlen=winlen, spectrum="#" + ) self.assertEqual(concepts, [concept3, concept4]) - @unittest.skipUnless(HAVE_STATSMODELS, - "'fdr_bh' stat corr requires statsmodels") + @unittest.skipUnless(HAVE_STATSMODELS, "'fdr_bh' stat corr requires statsmodels") def test_signature_significance_fdr_bh_corr(self): """ A typical corr='fdr_bh' scenario, that requires statsmodels. 
""" sig_spectrum = spade.test_signature_significance( pv_spec=((2, 3, 0.2), (2, 4, 0.05)), - concepts=([[(2, 3), (1, 2, 3)], - [(2, 4), (1, 2, 3, 4)]]), - alpha=0.15, winlen=1, corr='fdr_bh') - self.assertEqual(sig_spectrum, [(2., 3., False), (2., 4., True)]) + concepts=([[(2, 3), (1, 2, 3)], [(2, 4), (1, 2, 3, 4)]]), + alpha=0.15, + winlen=1, + corr="fdr_bh", + ) + self.assertEqual(sig_spectrum, [(2.0, 3.0, False), (2.0, 4.0, True)]) diff --git a/elephant/test/test_spectral.py b/elephant/test/test_spectral.py index 41244fb66..1520f5335 100644 --- a/elephant/test/test_spectral.py +++ b/elephant/test/test_spectral.py @@ -26,39 +26,40 @@ class WelchPSDTestCase(unittest.TestCase): def test_welch_psd_errors(self): # generate a dummy data - data = AnalogSignal(np.zeros(5000), sampling_period=0.001 * pq.s, - units='mV') + data = AnalogSignal(np.zeros(5000), sampling_period=0.001 * pq.s, units="mV") # check for invalid parameter values # - length of segments - self.assertRaises(ValueError, elephant.spectral.welch_psd, data, - len_segment=0) - self.assertRaises(ValueError, elephant.spectral.welch_psd, data, - len_segment=data.shape[0] * 2) + self.assertRaises(ValueError, elephant.spectral.welch_psd, data, len_segment=0) + self.assertRaises( + ValueError, elephant.spectral.welch_psd, data, len_segment=data.shape[0] * 2 + ) # - number of segments - self.assertRaises(ValueError, elephant.spectral.welch_psd, data, - n_segments=0) - self.assertRaises(ValueError, elephant.spectral.welch_psd, data, - n_segments=data.shape[0] * 2) + self.assertRaises(ValueError, elephant.spectral.welch_psd, data, n_segments=0) + self.assertRaises( + ValueError, elephant.spectral.welch_psd, data, n_segments=data.shape[0] * 2 + ) # - frequency resolution - self.assertRaises(ValueError, elephant.spectral.welch_psd, data, - frequency_resolution=-1) - self.assertRaises(ValueError, elephant.spectral.welch_psd, data, - frequency_resolution=data.sampling_rate / - (data.shape[0] + 1)) + self.assertRaises( + ValueError, elephant.spectral.welch_psd, data, frequency_resolution=-1 + ) + self.assertRaises( + ValueError, + elephant.spectral.welch_psd, + data, + frequency_resolution=data.sampling_rate / (data.shape[0] + 1), + ) # - overlap - self.assertRaises(ValueError, elephant.spectral.welch_psd, data, - overlap=-1.0) - self.assertRaises(ValueError, elephant.spectral.welch_psd, data, - overlap=1.1) + self.assertRaises(ValueError, elephant.spectral.welch_psd, data, overlap=-1.0) + self.assertRaises(ValueError, elephant.spectral.welch_psd, data, overlap=1.1) def test_welch_psd_warnings(self): # generate a dummy data - data = AnalogSignal(np.zeros(5000), sampling_period=0.001 * pq.s, - units='mV') + data = AnalogSignal(np.zeros(5000), sampling_period=0.001 * pq.s, units="mV") # Test deprecation warning for 'hanning' window - self.assertWarns(DeprecationWarning, elephant.spectral.welch_psd, - data, window='hanning') + self.assertWarns( + DeprecationWarning, elephant.spectral.welch_psd, data, window="hanning" + ) def test_welch_psd_behavior(self): # generate data by adding white noise and a sinusoid @@ -66,43 +67,53 @@ def test_welch_psd_behavior(self): sampling_period = 0.001 signal_freq = 100.0 noise = np.random.normal(size=data_length) - signal = [np.sin(2 * np.pi * signal_freq * t) - for t in np.arange(0, data_length * sampling_period, - sampling_period)] - data = AnalogSignal(np.array(signal + noise), - sampling_period=sampling_period * pq.s, - units='mV') + signal = [ + np.sin(2 * np.pi * signal_freq * t) + for t in np.arange(0, 
data_length * sampling_period, sampling_period) + ] + data = AnalogSignal( + np.array(signal + noise), sampling_period=sampling_period * pq.s, units="mV" + ) # consistency between different ways of specifying segment length freqs1, psd1 = elephant.spectral.welch_psd( - data, len_segment=data_length // 5, overlap=0) - freqs2, psd2 = elephant.spectral.welch_psd( - data, n_segments=5, overlap=0) + data, len_segment=data_length // 5, overlap=0 + ) + freqs2, psd2 = elephant.spectral.welch_psd(data, n_segments=5, overlap=0) self.assertTrue((psd1 == psd2).all() and (freqs1 == freqs2).all()) # frequency resolution and consistency with data freq_res = 1.0 * pq.Hz - freqs, psd = elephant.spectral.welch_psd( - data, frequency_resolution=freq_res) + freqs, psd = elephant.spectral.welch_psd(data, frequency_resolution=freq_res) self.assertAlmostEqual(freq_res, freqs[1] - freqs[0]) self.assertEqual(freqs[psd.argmax()], signal_freq) freqs_np, psd_np = elephant.spectral.welch_psd( - data.magnitude.flatten(), fs=1 / sampling_period, - frequency_resolution=freq_res) + data.magnitude.flatten(), + fs=1 / sampling_period, + frequency_resolution=freq_res, + ) self.assertTrue((freqs == freqs_np).all() and (psd == psd_np).all()) # check of scipy.signal.welch() parameters - params = {'window': 'hamming', 'nfft': 1024, 'detrend': 'linear', - 'return_onesided': False, 'scaling': 'spectrum'} + params = { + "window": "hamming", + "nfft": 1024, + "detrend": "linear", + "return_onesided": False, + "scaling": "spectrum", + } for key, val in params.items(): freqs, psd = elephant.spectral.welch_psd( - data, len_segment=1000, overlap=0, **{key: val}) - freqs_spsig, psd_spsig = spsig.welch(np.rollaxis(data, 0, len( - data.shape)), fs=1 / sampling_period, nperseg=1000, - noverlap=0, **{key: val}) - self.assertTrue( - (freqs == freqs_spsig).all() and ( - psd == psd_spsig).all()) + data, len_segment=1000, overlap=0, **{key: val} + ) + freqs_spsig, psd_spsig = spsig.welch( + np.rollaxis(data, 0, len(data.shape)), + fs=1 / sampling_period, + nperseg=1000, + noverlap=0, + **{key: val}, + ) + self.assertTrue((freqs == freqs_spsig).all() and (psd == psd_spsig).all()) # - generate multidimensional data for check of parameter `axis` num_channel = 4 @@ -116,9 +127,11 @@ def test_welch_psd_behavior(self): def test_welch_psd_input_types(self): # generate a test data sampling_period = 0.001 - data = AnalogSignal(np.array(np.random.normal(size=5000)), - sampling_period=sampling_period * pq.s, - units='mV') + data = AnalogSignal( + np.array(np.random.normal(size=5000)), + sampling_period=sampling_period * pq.s, + units="mV", + ) # outputs from AnalogSignal input are of Quantity type (standard usage) freqs_neo, psd_neo = elephant.spectral.welch_psd(data) @@ -127,23 +140,21 @@ def test_welch_psd_input_types(self): # outputs from Quantity array input are of Quantity type freqs_pq, psd_pq = elephant.spectral.welch_psd( - data.magnitude.flatten() * data.units, fs=1 / sampling_period) + data.magnitude.flatten() * data.units, fs=1 / sampling_period + ) self.assertTrue(isinstance(freqs_pq, pq.quantity.Quantity)) self.assertTrue(isinstance(psd_pq, pq.quantity.Quantity)) # outputs from Numpy ndarray input are NOT of Quantity type freqs_np, psd_np = elephant.spectral.welch_psd( - data.magnitude.flatten(), fs=1 / sampling_period) + data.magnitude.flatten(), fs=1 / sampling_period + ) self.assertFalse(isinstance(freqs_np, pq.quantity.Quantity)) self.assertFalse(isinstance(psd_np, pq.quantity.Quantity)) # check if the results from different input types 
are identical - self.assertTrue( - (freqs_neo == freqs_pq).all() and ( - psd_neo == psd_pq).all()) - self.assertTrue( - (freqs_neo == freqs_np).all() and ( - psd_neo == psd_np).all()) + self.assertTrue((freqs_neo == freqs_pq).all() and (psd_neo == psd_pq).all()) + self.assertTrue((freqs_neo == freqs_np).all() and (psd_neo == psd_np).all()) def test_welch_psd_multidim_input(self): # generate multidimensional data @@ -155,19 +166,17 @@ def test_welch_psd_multidim_input(self): # Since row-column order in AnalogSignal is different from the # conventional one, `data_np` needs to be transposed when it's used to # define an AnalogSignal - data_neo = AnalogSignal(data_np.T, - sampling_period=sampling_period * pq.s, - units='mV') - data_neo_1dim = AnalogSignal(data_np[0], - sampling_period=sampling_period * pq.s, - units='mV') + data_neo = AnalogSignal( + data_np.T, sampling_period=sampling_period * pq.s, units="mV" + ) + data_neo_1dim = AnalogSignal( + data_np[0], sampling_period=sampling_period * pq.s, units="mV" + ) # check if the results from different input types are identical - freqs_np, psd_np = elephant.spectral.welch_psd(data_np, - fs=1 / sampling_period) + freqs_np, psd_np = elephant.spectral.welch_psd(data_np, fs=1 / sampling_period) freqs_neo, psd_neo = elephant.spectral.welch_psd(data_neo) - freqs_neo_1dim, psd_neo_1dim = elephant.spectral.welch_psd( - data_neo_1dim) + freqs_neo_1dim, psd_neo_1dim = elephant.spectral.welch_psd(data_neo_1dim) self.assertTrue(np.all(freqs_np == freqs_neo)) self.assertTrue(np.all(psd_np == psd_neo)) self.assertTrue(np.all(psd_neo_1dim == psd_neo[0])) @@ -177,54 +186,54 @@ class MultitaperPSDTestCase(unittest.TestCase): def test_multitaper_psd_errors(self): # generate dummy data data_length = 5000 - signal = AnalogSignal(np.zeros(data_length), - sampling_period=0.001 * pq.s, - units='mV') + signal = AnalogSignal( + np.zeros(data_length), sampling_period=0.001 * pq.s, units="mV" + ) fs = signal.sampling_rate self.assertIsInstance(fs, pq.Quantity) # check for invalid parameter values # - number of tapers - self.assertRaises(ValueError, elephant.spectral.multitaper_psd, signal, - num_tapers=-5) - self.assertRaises(TypeError, elephant.spectral.multitaper_psd, signal, - num_tapers=-5.0) + self.assertRaises( + ValueError, elephant.spectral.multitaper_psd, signal, num_tapers=-5 + ) + self.assertRaises( + TypeError, elephant.spectral.multitaper_psd, signal, num_tapers=-5.0 + ) # - peak resolution - self.assertRaises(ValueError, elephant.spectral.multitaper_psd, signal, - peak_resolution=-1) + self.assertRaises( + ValueError, elephant.spectral.multitaper_psd, signal, peak_resolution=-1 + ) def test_multitaper_psd_behavior(self): # generate data (frequency domain to time domain) r = np.ones(2501) * 0.2 r[0], r[500] = 0, 10 # Zero DC, peak at 100 Hz phi = np.random.uniform(-np.pi, np.pi, len(r)) - fake_coeffs = r*np.exp(1j * phi) + fake_coeffs = r * np.exp(1j * phi) fake_ts = scipy.fft.irfft(fake_coeffs) sampling_period = 0.001 freqs = scipy.fft.rfftfreq(len(fake_ts), d=sampling_period) signal_freq = freqs[r.argmax()] - data = AnalogSignal(fake_ts, sampling_period=sampling_period * pq.s, - units='mV') + data = AnalogSignal(fake_ts, sampling_period=sampling_period * pq.s, units="mV") # consistency between different ways of specifying number of tapers - freqs1, psd1 = elephant.spectral.multitaper_psd(data, - fs=data.sampling_rate, - nw=3.5) - freqs2, psd2 = elephant.spectral.multitaper_psd(data, - fs=data.sampling_rate, - nw=3.5, - num_tapers=6) + freqs1, psd1 = 
elephant.spectral.multitaper_psd( + data, fs=data.sampling_rate, nw=3.5 + ) + freqs2, psd2 = elephant.spectral.multitaper_psd( + data, fs=data.sampling_rate, nw=3.5, num_tapers=6 + ) self.assertTrue((psd1 == psd2).all() and (freqs1 == freqs2).all()) # peak resolution and consistency with data peak_res = 1.0 * pq.Hz - freqs, psd = elephant.spectral.multitaper_psd( - data, peak_resolution=peak_res) + freqs, psd = elephant.spectral.multitaper_psd(data, peak_resolution=peak_res) self.assertEqual(freqs[psd.argmax()], signal_freq) freqs_np, psd_np = elephant.spectral.multitaper_psd( - data.magnitude.flatten(), fs=1 / sampling_period, - peak_resolution=peak_res) + data.magnitude.flatten(), fs=1 / sampling_period, peak_resolution=peak_res + ) self.assertTrue((freqs == freqs_np).all() and (psd == psd_np).all()) def test_multitaper_psd_parameter_hierarchy(self): @@ -233,33 +242,30 @@ def test_multitaper_psd_parameter_hierarchy(self): sampling_period = 0.001 signal_freq = 100.0 noise = np.random.normal(size=data_length) - signal = [np.sin(2 * np.pi * signal_freq * t) - for t in np.arange(0, data_length * sampling_period, - sampling_period)] - data = AnalogSignal(np.array(signal + noise), - sampling_period=sampling_period * pq.s, - units='mV') + signal = [ + np.sin(2 * np.pi * signal_freq * t) + for t in np.arange(0, data_length * sampling_period, sampling_period) + ] + data = AnalogSignal( + np.array(signal + noise), sampling_period=sampling_period * pq.s, units="mV" + ) # Test num_tapers vs nw - freqs1, psd1 = elephant.spectral.multitaper_psd(data, - fs=data.sampling_rate, - nw=3, - num_tapers=9) - freqs2, psd2 = elephant.spectral.multitaper_psd(data, - fs=data.sampling_rate, - nw=3) + freqs1, psd1 = elephant.spectral.multitaper_psd( + data, fs=data.sampling_rate, nw=3, num_tapers=9 + ) + freqs2, psd2 = elephant.spectral.multitaper_psd( + data, fs=data.sampling_rate, nw=3 + ) self.assertTrue((freqs1 == freqs2).all() and (psd1 != psd2).all()) # Test peak_resolution vs nw - freqs1, psd1 = elephant.spectral.multitaper_psd(data, - fs=data.sampling_rate, - nw=3, - num_tapers=9, - peak_resolution=1) - freqs2, psd2 = elephant.spectral.multitaper_psd(data, - fs=data.sampling_rate, - nw=3, - num_tapers=9) + freqs1, psd1 = elephant.spectral.multitaper_psd( + data, fs=data.sampling_rate, nw=3, num_tapers=9, peak_resolution=1 + ) + freqs2, psd2 = elephant.spectral.multitaper_psd( + data, fs=data.sampling_rate, nw=3, num_tapers=9 + ) self.assertTrue((freqs1 == freqs2).all() and (psd1 != psd2).all()) def test_multitaper_psd_against_nitime(self): @@ -275,28 +281,31 @@ def test_multitaper_psd_against_nitime(self): files_to_download = [ ("time_series.npy", "ff43797e2ac94613f510b20a31e2e80e"), - ("psd_nitime.npy", "89d1f53957e66c786049ea425b53c0e8") + ("psd_nitime.npy", "89d1f53957e66c786049ea425b53c0e8"), ] for filename, checksum in files_to_download: - download_datasets(repo_path=f"{repo_path}/{filename}", - checksum=checksum) + download_datasets(repo_path=f"{repo_path}/{filename}", checksum=checksum) - time_series = np.load(ELEPHANT_TMP_DIR / 'time_series.npy') - psd_nitime = np.load(ELEPHANT_TMP_DIR / 'psd_nitime.npy') + time_series = np.load(ELEPHANT_TMP_DIR / "time_series.npy") + psd_nitime = np.load(ELEPHANT_TMP_DIR / "psd_nitime.npy") freqs, psd_multitaper = elephant.spectral.multitaper_psd( - signal=time_series, fs=0.1, nw=4, num_tapers=8) + signal=time_series, fs=0.1, nw=4, num_tapers=8 + ) - np.testing.assert_allclose(np.squeeze(psd_multitaper), psd_nitime, - rtol=0.3, atol=0.1) + 
np.testing.assert_allclose( + np.squeeze(psd_multitaper), psd_nitime, rtol=0.3, atol=0.1 + ) def test_multitaper_psd_input_types(self): # generate a test data sampling_period = 0.001 - data = AnalogSignal(np.array(np.random.normal(size=5000)), - sampling_period=sampling_period * pq.s, - units='mV') + data = AnalogSignal( + np.array(np.random.normal(size=5000)), + sampling_period=sampling_period * pq.s, + units="mV", + ) # outputs from AnalogSignal input are of Quantity type (standard usage) freqs_neo, psd_neo = elephant.spectral.multitaper_psd(data) @@ -305,13 +314,15 @@ def test_multitaper_psd_input_types(self): # outputs from Quantity array input are of Quantity type freqs_pq, psd_pq = elephant.spectral.multitaper_psd( - data.magnitude.flatten() * data.units, fs=1 / sampling_period) + data.magnitude.flatten() * data.units, fs=1 / sampling_period + ) self.assertTrue(isinstance(freqs_pq, pq.quantity.Quantity)) self.assertTrue(isinstance(psd_pq, pq.quantity.Quantity)) # outputs from Numpy ndarray input are NOT of Quantity type freqs_np, psd_np = elephant.spectral.multitaper_psd( - data.magnitude.flatten(), fs=1 / sampling_period) + data.magnitude.flatten(), fs=1 / sampling_period + ) self.assertFalse(isinstance(freqs_np, pq.quantity.Quantity)) self.assertFalse(isinstance(psd_np, pq.quantity.Quantity)) @@ -319,20 +330,18 @@ def test_multitaper_psd_input_types(self): fs_hz = 1 * pq.Hz fs_int = 1 freqs_fs_hz, psd_fs_hz = elephant.spectral.multitaper_psd( - data.magnitude.T, fs=fs_hz) + data.magnitude.T, fs=fs_hz + ) freqs_fs_int, psd_fs_int = elephant.spectral.multitaper_psd( - data.magnitude.T, fs=fs_int) + data.magnitude.T, fs=fs_int + ) np.testing.assert_array_equal(freqs_fs_hz, freqs_fs_int) np.testing.assert_array_equal(psd_fs_hz, psd_fs_int) # check if the results from different input types are identical - self.assertTrue( - (freqs_neo == freqs_pq).all() and ( - psd_neo == psd_pq).all()) - self.assertTrue( - (freqs_neo == freqs_np).all() and ( - psd_neo == psd_np).all()) + self.assertTrue((freqs_neo == freqs_pq).all() and (psd_neo == psd_pq).all()) + self.assertTrue((freqs_neo == freqs_np).all() and (psd_neo == psd_np).all()) class SegmentedMultitaperPSDTestCase(unittest.TestCase): @@ -343,57 +352,74 @@ class SegmentedMultitaperPSDTestCase(unittest.TestCase): def test_segmented_multitaper_psd_errors(self): # generate dummy data data_length = 5000 - signal = AnalogSignal(np.zeros(data_length), - sampling_period=0.001 * pq.s, - units='mV') + signal = AnalogSignal( + np.zeros(data_length), sampling_period=0.001 * pq.s, units="mV" + ) fs = signal.sampling_rate # check for invalid parameter values # - frequency resolution - self.assertRaises(ValueError, - elephant.spectral.segmented_multitaper_psd, signal, - frequency_resolution=-10) + self.assertRaises( + ValueError, + elephant.spectral.segmented_multitaper_psd, + signal, + frequency_resolution=-10, + ) # - n per segment # n_per_seg = int(fs / dF), where dF is the frequency_resolution - broken_freq_resolution = fs / (data_length+1) - self.assertRaises(ValueError, - elephant.spectral.segmented_multitaper_psd, signal, - frequency_resolution=broken_freq_resolution) + broken_freq_resolution = fs / (data_length + 1) + self.assertRaises( + ValueError, + elephant.spectral.segmented_multitaper_psd, + signal, + frequency_resolution=broken_freq_resolution, + ) # - length of segment (negative) - self.assertRaises(ValueError, - elephant.spectral.segmented_multitaper_psd, signal, - len_segment=-10) + self.assertRaises( + ValueError, + 
elephant.spectral.segmented_multitaper_psd, + signal, + len_segment=-10, + ) # - length of segment (larger than data length) - self.assertRaises(ValueError, - elephant.spectral.segmented_multitaper_psd, signal, - len_segment=data_length+1) + self.assertRaises( + ValueError, + elephant.spectral.segmented_multitaper_psd, + signal, + len_segment=data_length + 1, + ) # - number of segments (negative) - self.assertRaises(ValueError, - elephant.spectral.segmented_multitaper_psd, signal, - n_segments=-10) + self.assertRaises( + ValueError, + elephant.spectral.segmented_multitaper_psd, + signal, + n_segments=-10, + ) # - number of segments (larger than data length) - self.assertRaises(ValueError, - elephant.spectral.segmented_multitaper_psd, signal, - n_segments=data_length+1) + self.assertRaises( + ValueError, + elephant.spectral.segmented_multitaper_psd, + signal, + n_segments=data_length + 1, + ) def test_segmented_multitaper_psd_behavior(self): # generate data (frequency domain to time domain) r = np.ones(2501) * 0.2 r[0], r[500] = 0, 10 # Zero DC, peak at 100 Hz phi = np.random.uniform(-np.pi, np.pi, len(r)) - fake_coeffs = r*np.exp(1j * phi) + fake_coeffs = r * np.exp(1j * phi) fake_ts = scipy.fft.irfft(fake_coeffs) sampling_period = 0.001 freqs = scipy.fft.rfftfreq(len(fake_ts), d=sampling_period) signal_freq = freqs[r.argmax()] - data = AnalogSignal(fake_ts, sampling_period=sampling_period * pq.s, - units='mV') + data = AnalogSignal(fake_ts, sampling_period=sampling_period * pq.s, units="mV") # consistency between different ways of specifying n_per_seg # n_per_seg = int(fs/dF) and n_per_seg = len_segment @@ -401,10 +427,12 @@ def test_segmented_multitaper_psd_behavior(self): len_segment = int(data.sampling_rate / frequency_resolution) freqs_fr, psd_fr = elephant.spectral.segmented_multitaper_psd( - data, frequency_resolution=frequency_resolution) + data, frequency_resolution=frequency_resolution + ) freqs_ls, psd_ls = elephant.spectral.segmented_multitaper_psd( - data, len_segment=len_segment) + data, len_segment=len_segment + ) np.testing.assert_array_equal(freqs_fr, freqs_ls) np.testing.assert_array_equal(psd_fr, psd_ls) @@ -415,30 +443,33 @@ def test_segmented_multitaper_psd_parameter_hierarchy(self): sampling_period = 0.001 signal_freq = 100.0 noise = np.random.normal(size=data_length) - signal = [np.sin(2 * np.pi * signal_freq * t) - for t in np.arange(0, data_length * sampling_period, - sampling_period)] - data = AnalogSignal(np.array(signal + noise), - sampling_period=sampling_period * pq.s, - units='mV') + signal = [ + np.sin(2 * np.pi * signal_freq * t) + for t in np.arange(0, data_length * sampling_period, sampling_period) + ] + data = AnalogSignal( + np.array(signal + noise), sampling_period=sampling_period * pq.s, units="mV" + ) # test frequency_resolution vs len_segment vs n_segments n_segments = 5 len_segment = 2000 frequency_resolution = 0.25 * pq.Hz - freqs_ns, psd_ns = \ - elephant.spectral.segmented_multitaper_psd( - data, n_segments=n_segments) + freqs_ns, psd_ns = elephant.spectral.segmented_multitaper_psd( + data, n_segments=n_segments + ) - freqs_ls, psd_ls = \ - elephant.spectral.segmented_multitaper_psd( - data, n_segments=n_segments, len_segment=len_segment) + freqs_ls, psd_ls = elephant.spectral.segmented_multitaper_psd( + data, n_segments=n_segments, len_segment=len_segment + ) - freqs_fr, psd_fr = \ - elephant.spectral.segmented_multitaper_psd( - data, n_segments=n_segments, len_segment=len_segment, - frequency_resolution=frequency_resolution) + freqs_fr, psd_fr 
= elephant.spectral.segmented_multitaper_psd(
+            data,
+            n_segments=n_segments,
+            len_segment=len_segment,
+            frequency_resolution=frequency_resolution,
+        )

         self.assertTrue(freqs_ns.shape < freqs_ls.shape < freqs_fr.shape)
         self.assertTrue(psd_ns.shape < psd_ls.shape < psd_fr.shape)

@@ -446,9 +477,11 @@ def test_segmented_multitaper_psd_parameter_hierarchy(self):
     def test_segmented_multitaper_psd_input_types(self):
         # generate a test data
         sampling_period = 0.001
-        data = AnalogSignal(np.array(np.random.normal(size=5000)),
-                            sampling_period=sampling_period * pq.s,
-                            units='mV')
+        data = AnalogSignal(
+            np.array(np.random.normal(size=5000)),
+            sampling_period=sampling_period * pq.s,
+            units="mV",
+        )

         # outputs from AnalogSignal input are of Quantity type (standard usage)
         freqs_neo, psd_neo = elephant.spectral.segmented_multitaper_psd(data)
@@ -457,13 +490,15 @@ def test_segmented_multitaper_psd_input_types(self):

         # outputs from Quantity array input are of Quantity type
         freqs_pq, psd_pq = elephant.spectral.segmented_multitaper_psd(
-            data.magnitude.flatten() * data.units, fs=1 / sampling_period)
+            data.magnitude.flatten() * data.units, fs=1 / sampling_period
+        )
         self.assertTrue(isinstance(freqs_pq, pq.quantity.Quantity))
         self.assertTrue(isinstance(psd_pq, pq.quantity.Quantity))

         # outputs from Numpy ndarray input are NOT of Quantity type
         freqs_np, psd_np = elephant.spectral.segmented_multitaper_psd(
-            data.magnitude.flatten(), fs=1 / sampling_period)
+            data.magnitude.flatten(), fs=1 / sampling_period
+        )
         self.assertFalse(isinstance(freqs_np, pq.quantity.Quantity))
         self.assertFalse(isinstance(psd_np, pq.quantity.Quantity))

@@ -472,10 +507,12 @@ def test_segmented_multitaper_psd_input_types(self):
         freq_res_int = 1

         freqs_int, psd_int = elephant.spectral.segmented_multitaper_psd(
-            data, frequency_resolution=freq_res_int)
+            data, frequency_resolution=freq_res_int
+        )

         freqs_hz, psd_hz = elephant.spectral.segmented_multitaper_psd(
-            data, frequency_resolution=freq_res_hz)
+            data, frequency_resolution=freq_res_hz
+        )

         np.testing.assert_array_equal(freqs_int, freqs_hz)
         np.testing.assert_array_equal(psd_int, psd_hz)

@@ -484,44 +521,54 @@ def test_segmented_multitaper_psd_input_types(self):
         fs_hz = 1 * pq.Hz
         fs_int = 1
         freqs_fs_hz, psd_fs_hz = elephant.spectral.multitaper_psd(
-            data.magnitude.T, fs=fs_hz)
+            data.magnitude.T, fs=fs_hz
+        )

         freqs_fs_int, psd_fs_int = elephant.spectral.multitaper_psd(
-            data.magnitude.T, fs=fs_int)
+            data.magnitude.T, fs=fs_int
+        )

         np.testing.assert_array_equal(freqs_fs_hz, freqs_fs_int)
         np.testing.assert_array_equal(psd_fs_hz, psd_fs_int)

         # check if the results from different input types are identical
-        self.assertTrue(
-            (freqs_neo == freqs_pq).all() and (
-                psd_neo == psd_pq).all())
-        self.assertTrue(
-            (freqs_neo == freqs_np).all() and (
-                psd_neo == psd_np).all())
+        self.assertTrue((freqs_neo == freqs_pq).all() and (psd_neo == psd_pq).all())
+        self.assertTrue((freqs_neo == freqs_np).all() and (psd_neo == psd_np).all())


 class MultitaperCrossSpectrumTestCase(unittest.TestCase):
     def test_multitaper_cross_spectrum_errors(self):
         # generate dummy data
         data_length = 5000
-        signal = AnalogSignal(np.zeros(data_length),
-                              sampling_period=0.001 * pq.s,
-                              units='mV')
+        signal = AnalogSignal(
+            np.zeros(data_length), sampling_period=0.001 * pq.s, units="mV"
+        )
         fs = signal.sampling_rate

         # check for invalid parameter values
         # - number of tapers
-        self.assertRaises(ValueError,
-                          elephant.spectral.multitaper_cross_spectrum, signal,
-                          fs=fs, num_tapers=-5)
-        self.assertRaises(TypeError,
-                          elephant.spectral.multitaper_cross_spectrum, signal,
-                          fs=fs, num_tapers=-5.0)
+        self.assertRaises(
+            ValueError,
+            elephant.spectral.multitaper_cross_spectrum,
+            signal,
+            fs=fs,
+            num_tapers=-5,
+        )
+        self.assertRaises(
+            TypeError,
+            elephant.spectral.multitaper_cross_spectrum,
+            signal,
+            fs=fs,
+            num_tapers=-5.0,
+        )

         # - peak resolution
-        self.assertRaises(ValueError,
-                          elephant.spectral.multitaper_cross_spectrum, signal,
-                          fs=fs, peak_resolution=-1)
+        self.assertRaises(
+            ValueError,
+            elephant.spectral.multitaper_cross_spectrum,
+            signal,
+            fs=fs,
+            peak_resolution=-1,
+        )

     def test_multitaper_cross_spectrum_behavior(self):
         # generate data (frequency domain to time domain)
@@ -529,70 +576,72 @@ def test_multitaper_cross_spectrum_behavior(self):
         r[0], r[500] = 0, 10  # Zero DC, peak at 100 Hz
         phi_x = np.random.uniform(-np.pi, np.pi, len(r))
         phi_y = np.random.uniform(-np.pi, np.pi, len(r))
-        fake_coeffs_x = r*np.exp(1j * phi_x)
-        fake_coeffs_y = r*np.exp(1j * phi_y)
+        fake_coeffs_x = r * np.exp(1j * phi_x)
+        fake_coeffs_y = r * np.exp(1j * phi_y)
         signal_x = scipy.fft.irfft(fake_coeffs_x)
         signal_y = scipy.fft.irfft(fake_coeffs_y)
         sampling_period = 0.001
         freqs = scipy.fft.rfftfreq(len(signal_x), d=sampling_period)
         signal_freq = freqs[r.argmax()]
-        data = AnalogSignal(np.vstack([signal_x, signal_y]).T,
-                            sampling_period=sampling_period * pq.s,
-                            units='mV')
+        data = AnalogSignal(
+            np.vstack([signal_x, signal_y]).T,
+            sampling_period=sampling_period * pq.s,
+            units="mV",
+        )

         # consistency between different ways of specifying number of tapers
-        freqs1, cross_spec1 = \
-            elephant.spectral.multitaper_cross_spectrum(
-                data,
-                fs=data.sampling_rate,
-                nw=3.5)
-        freqs2, cross_spec2 = \
-            elephant.spectral.multitaper_cross_spectrum(
-                data,
-                fs=data.sampling_rate,
-                nw=3.5,
-                num_tapers=6)
-        self.assertTrue((cross_spec1 == cross_spec2).all()
-                        and (freqs1 == freqs2).all())
+        freqs1, cross_spec1 = elephant.spectral.multitaper_cross_spectrum(
+            data, fs=data.sampling_rate, nw=3.5
+        )
+        freqs2, cross_spec2 = elephant.spectral.multitaper_cross_spectrum(
+            data, fs=data.sampling_rate, nw=3.5, num_tapers=6
+        )
+        self.assertTrue((cross_spec1 == cross_spec2).all() and (freqs1 == freqs2).all())

         # peak resolution and consistency with data
         peak_res = 1.0 * pq.Hz
-        freqs, cross_spec = \
-            elephant.spectral.multitaper_cross_spectrum(
-                data, peak_resolution=peak_res)
+        freqs, cross_spec = elephant.spectral.multitaper_cross_spectrum(
+            data, peak_resolution=peak_res
+        )
         self.assertEqual(freqs[cross_spec[0, 0].argmax()], signal_freq)

-        freqs_np, cross_spec_np = \
-            elephant.spectral.multitaper_cross_spectrum(
-                data.magnitude.T, fs=1 / sampling_period,
-                peak_resolution=peak_res)
-        self.assertTrue((freqs == freqs_np).all()
-                        and (cross_spec == cross_spec_np).all())
+        freqs_np, cross_spec_np = elephant.spectral.multitaper_cross_spectrum(
+            data.magnitude.T, fs=1 / sampling_period, peak_resolution=peak_res
+        )
+        self.assertTrue(
+            (freqs == freqs_np).all() and (cross_spec == cross_spec_np).all()
+        )

         # one-sided vs two-sided spectrum
-        freqs_os, cross_spec_os = \
-            elephant.spectral.multitaper_cross_spectrum(
-                data, return_onesided=True)
+        freqs_os, cross_spec_os = elephant.spectral.multitaper_cross_spectrum(
+            data, return_onesided=True
+        )

-        freqs_ts, cross_spec_ts = \
-            elephant.spectral.multitaper_cross_spectrum(
-                data, return_onesided=False)
+        freqs_ts, cross_spec_ts = elephant.spectral.multitaper_cross_spectrum(
+            data, return_onesided=False
+        )

         # Nyquist frequency is negative when using onesided=False
(fftfreq) # See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.rfftfreq.html#scipy.fft.rfftfreq # noqa nonnegative_freqs_indices = np.nonzero(freqs_ts >= 0)[0] nyquist_freq_idx = np.abs(freqs_ts).argmax() - ts_freq_indices = np.append(nonnegative_freqs_indices, - nyquist_freq_idx) - ts_overlap_freqs = np.append( - freqs_ts[nonnegative_freqs_indices].rescale('Hz').magnitude, - np.abs(freqs_ts[nyquist_freq_idx].rescale('Hz').magnitude)) * pq.Hz + ts_freq_indices = np.append(nonnegative_freqs_indices, nyquist_freq_idx) + ts_overlap_freqs = ( + np.append( + freqs_ts[nonnegative_freqs_indices].rescale("Hz").magnitude, + np.abs(freqs_ts[nyquist_freq_idx].rescale("Hz").magnitude), + ) + * pq.Hz + ) np.testing.assert_array_equal(freqs_os, ts_overlap_freqs) np.testing.assert_allclose( cross_spec_os.magnitude, - cross_spec_ts[:, :, ts_freq_indices].magnitude, rtol=1e-12, atol=0) + cross_spec_ts[:, :, ts_freq_indices].magnitude, + rtol=1e-12, + atol=0, + ) def test_multitaper_cross_spectrum_parameter_hierarchy(self): # generate data by adding white noise and a sinusoid @@ -600,31 +649,34 @@ def test_multitaper_cross_spectrum_parameter_hierarchy(self): sampling_period = 0.001 signal_freq = 100.0 noise = np.random.normal(size=(2, data_length)) - time_points = np.arange(0, data_length * sampling_period, - sampling_period) + time_points = np.arange(0, data_length * sampling_period, sampling_period) signal_x = np.sin(2 * np.pi * signal_freq * time_points) + noise[0] signal_y = np.cos(2 * np.pi * signal_freq * time_points) + noise[1] - data = AnalogSignal(np.vstack([signal_x, signal_y]).T, - sampling_period=sampling_period * pq.s, - units='mV') + data = AnalogSignal( + np.vstack([signal_x, signal_y]).T, + sampling_period=sampling_period * pq.s, + units="mV", + ) # Test num_tapers vs nw freqs1, cross_spec1 = elephant.spectral.multitaper_cross_spectrum( - data, fs=data.sampling_rate, nw=3, num_tapers=9) + data, fs=data.sampling_rate, nw=3, num_tapers=9 + ) freqs2, cross_spec2 = elephant.spectral.multitaper_cross_spectrum( - data, fs=data.sampling_rate, nw=3) + data, fs=data.sampling_rate, nw=3 + ) - self.assertTrue((freqs1 == freqs2).all() - and (cross_spec1 != cross_spec2).all()) + self.assertTrue((freqs1 == freqs2).all() and (cross_spec1 != cross_spec2).all()) # Test peak_resolution vs nw freqs1, cross_spec1 = elephant.spectral.multitaper_cross_spectrum( - data, fs=data.sampling_rate, nw=3, num_tapers=9, peak_resolution=1) + data, fs=data.sampling_rate, nw=3, num_tapers=9, peak_resolution=1 + ) freqs2, cross_spec2 = elephant.spectral.multitaper_cross_spectrum( - data, fs=data.sampling_rate, nw=3, num_tapers=9) + data, fs=data.sampling_rate, nw=3, num_tapers=9 + ) - self.assertTrue((freqs1 == freqs2).all() - and (cross_spec1 != cross_spec2).all()) + self.assertTrue((freqs1 == freqs2).all() and (cross_spec1 != cross_spec2).all()) def test_multitaper_cross_spectrum_input_types(self): # generate a test data @@ -632,85 +684,107 @@ def test_multitaper_cross_spectrum_input_types(self): sampling_period = 0.001 signal_freq = 100.0 noise = np.random.normal(size=(2, data_length)) - time_points = np.arange(0, data_length * sampling_period, - sampling_period) + time_points = np.arange(0, data_length * sampling_period, sampling_period) signal_x = np.sin(2 * np.pi * signal_freq * time_points) + noise[0] signal_y = np.cos(2 * np.pi * signal_freq * time_points) + noise[1] - data = AnalogSignal(np.vstack([signal_x, signal_y]).T, - sampling_period=sampling_period * pq.s, - units='mV') + data = AnalogSignal( 
+ np.vstack([signal_x, signal_y]).T, + sampling_period=sampling_period * pq.s, + units="mV", + ) # outputs from AnalogSignal input are of Quantity type (standard usage) - freqs_neo, cross_spec_neo \ - = elephant.spectral.multitaper_cross_spectrum(data) + freqs_neo, cross_spec_neo = elephant.spectral.multitaper_cross_spectrum(data) self.assertTrue(isinstance(freqs_neo, pq.quantity.Quantity)) self.assertTrue(isinstance(cross_spec_neo, pq.quantity.Quantity)) # outputs from Quantity array input are of Quantity type - freqs_pq, cross_spec_pq \ - = elephant.spectral.multitaper_cross_spectrum( - data.magnitude.T * data.units, - fs=1 / (sampling_period * pq.s)) + freqs_pq, cross_spec_pq = elephant.spectral.multitaper_cross_spectrum( + data.magnitude.T * data.units, fs=1 / (sampling_period * pq.s) + ) self.assertTrue(isinstance(freqs_pq, pq.quantity.Quantity)) self.assertTrue(isinstance(cross_spec_pq, pq.quantity.Quantity)) # outputs from Numpy ndarray input are NOT of Quantity type - freqs_np, cross_spec_np \ - = elephant.spectral.multitaper_cross_spectrum( - data.magnitude.T, - fs=1 / (sampling_period * pq.s)) + freqs_np, cross_spec_np = elephant.spectral.multitaper_cross_spectrum( + data.magnitude.T, fs=1 / (sampling_period * pq.s) + ) self.assertFalse(isinstance(freqs_np, pq.quantity.Quantity)) self.assertFalse(isinstance(cross_spec_np, pq.quantity.Quantity)) # check if the results from different input types are identical self.assertTrue( - (freqs_neo == freqs_pq).all() and - (cross_spec_neo == cross_spec_pq).all()) + (freqs_neo == freqs_pq).all() and (cross_spec_neo == cross_spec_pq).all() + ) self.assertTrue( - (freqs_neo == freqs_np).all() and - (cross_spec_neo == cross_spec_np).all()) + (freqs_neo == freqs_np).all() and (cross_spec_neo == cross_spec_np).all() + ) class SegmentedMultitaperCrossSpectrumTestCase(unittest.TestCase): def test_segmented_multitaper_cross_spectrum_errors(self): # generate dummy data data_length = 5000 - signal = AnalogSignal(np.zeros(data_length), - sampling_period=0.001 * pq.s, - units='mV') + signal = AnalogSignal( + np.zeros(data_length), sampling_period=0.001 * pq.s, units="mV" + ) fs = signal.sampling_rate # - frequency resolution self.assertRaises( - ValueError, elephant.spectral.segmented_multitaper_cross_spectrum, - signal, fs=fs, frequency_resolution=-10) + ValueError, + elephant.spectral.segmented_multitaper_cross_spectrum, + signal, + fs=fs, + frequency_resolution=-10, + ) # - n per segment # n_per_seg = int(fs / dF), where dF is the frequency_resolution - broken_freq_resolution = fs / (data_length+1) + broken_freq_resolution = fs / (data_length + 1) self.assertRaises( - ValueError, elephant.spectral.segmented_multitaper_cross_spectrum, - signal, fs=fs, frequency_resolution=broken_freq_resolution) + ValueError, + elephant.spectral.segmented_multitaper_cross_spectrum, + signal, + fs=fs, + frequency_resolution=broken_freq_resolution, + ) # - length of segment (negative) self.assertRaises( - ValueError, elephant.spectral.segmented_multitaper_cross_spectrum, - signal, fs=fs, len_segment=-10) + ValueError, + elephant.spectral.segmented_multitaper_cross_spectrum, + signal, + fs=fs, + len_segment=-10, + ) # - length of segment (larger than data length) self.assertRaises( - ValueError, elephant.spectral.segmented_multitaper_cross_spectrum, - signal, fs=fs, len_segment=data_length+1) + ValueError, + elephant.spectral.segmented_multitaper_cross_spectrum, + signal, + fs=fs, + len_segment=data_length + 1, + ) # - number of segments (negative) self.assertRaises( - 
ValueError, elephant.spectral.segmented_multitaper_cross_spectrum, - signal, fs=fs, n_segments=-10) + ValueError, + elephant.spectral.segmented_multitaper_cross_spectrum, + signal, + fs=fs, + n_segments=-10, + ) # - number of segments (larger than data length) self.assertRaises( - ValueError, elephant.spectral.segmented_multitaper_cross_spectrum, - signal, fs=fs, n_segments=data_length+1) + ValueError, + elephant.spectral.segmented_multitaper_cross_spectrum, + signal, + fs=fs, + n_segments=data_length + 1, + ) def test_segmented_multitaper_cross_spectrum_behavior(self): # generate data (frequency domain to time domain) @@ -718,41 +792,43 @@ def test_segmented_multitaper_cross_spectrum_behavior(self): r[0], r[500] = 0, 10 # Zero DC, peak at 100 Hz phi_x = np.random.uniform(-np.pi, np.pi, len(r)) phi_y = np.random.uniform(-np.pi, np.pi, len(r)) - fake_coeffs_x = r*np.exp(1j * phi_x) - fake_coeffs_y = r*np.exp(1j * phi_y) + fake_coeffs_x = r * np.exp(1j * phi_x) + fake_coeffs_y = r * np.exp(1j * phi_y) signal_x = scipy.fft.irfft(fake_coeffs_x) signal_y = scipy.fft.irfft(fake_coeffs_y) sampling_period = 0.001 freqs = scipy.fft.rfftfreq(len(signal_x), d=sampling_period) signal_freq = freqs[r.argmax()] - data = AnalogSignal(np.vstack([signal_x, signal_y]).T, - sampling_period=sampling_period * pq.s, - units='mV') + data = AnalogSignal( + np.vstack([signal_x, signal_y]).T, + sampling_period=sampling_period * pq.s, + units="mV", + ) # consistency between different ways of specifying n_per_seg # n_per_seg = int(fs/dF) and n_per_seg = len_segment frequency_resolution = 1 * pq.Hz len_segment = int(data.sampling_rate / frequency_resolution) - freqs_fr, cross_spec_fr = \ - elephant.spectral.segmented_multitaper_cross_spectrum( - data, frequency_resolution=frequency_resolution) + freqs_fr, cross_spec_fr = elephant.spectral.segmented_multitaper_cross_spectrum( + data, frequency_resolution=frequency_resolution + ) - freqs_ls, cross_spec_ls = \ - elephant.spectral.segmented_multitaper_cross_spectrum( - data, len_segment=len_segment) + freqs_ls, cross_spec_ls = elephant.spectral.segmented_multitaper_cross_spectrum( + data, len_segment=len_segment + ) np.testing.assert_array_equal(freqs_fr, freqs_ls) np.testing.assert_array_equal(cross_spec_fr, cross_spec_ls) # one-sided vs two-sided spectrum - freqs_os, cross_spec_os = \ - elephant.spectral.segmented_multitaper_cross_spectrum( - data, return_onesided=True) + freqs_os, cross_spec_os = elephant.spectral.segmented_multitaper_cross_spectrum( + data, return_onesided=True + ) - freqs_ts, cross_spec_ts = \ - elephant.spectral.segmented_multitaper_cross_spectrum( - data, return_onesided=False) + freqs_ts, cross_spec_ts = elephant.spectral.segmented_multitaper_cross_spectrum( + data, return_onesided=False + ) # test overlap parameter no_overlap = 0 @@ -760,37 +836,42 @@ def test_segmented_multitaper_cross_spectrum_behavior(self): large_overlap = 0.99 n_segments = 10 - freqs_no, cross_spec_no = \ - elephant.spectral.segmented_multitaper_cross_spectrum( - data, n_segments=n_segments, overlap=no_overlap) + freqs_no, cross_spec_no = elephant.spectral.segmented_multitaper_cross_spectrum( + data, n_segments=n_segments, overlap=no_overlap + ) - freqs_ho, cross_spec_ho = \ - elephant.spectral.segmented_multitaper_cross_spectrum( - data, n_segments=n_segments, overlap=half_overlap) + freqs_ho, cross_spec_ho = elephant.spectral.segmented_multitaper_cross_spectrum( + data, n_segments=n_segments, overlap=half_overlap + ) - freqs_lo, cross_spec_lo = \ - 
elephant.spectral.segmented_multitaper_cross_spectrum( - data, n_segments=n_segments, overlap=large_overlap) + freqs_lo, cross_spec_lo = elephant.spectral.segmented_multitaper_cross_spectrum( + data, n_segments=n_segments, overlap=large_overlap + ) self.assertTrue(freqs_no.shape < freqs_ho.shape < freqs_lo.shape) - self.assertTrue( - cross_spec_no.shape < cross_spec_ho.shape < cross_spec_lo.shape) + self.assertTrue(cross_spec_no.shape < cross_spec_ho.shape < cross_spec_lo.shape) # Nyquist frequency is negative when using onesided=False (fftfreq) # See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.rfftfreq.html#scipy.fft.rfftfreq # noqa nonnegative_freqs_indices = np.nonzero(freqs_ts >= 0)[0] nyquist_freq_idx = np.abs(freqs_ts).argmax() - ts_freq_indices = np.append(nonnegative_freqs_indices, - nyquist_freq_idx) - ts_overlap_freqs = np.append( - freqs_ts[nonnegative_freqs_indices].rescale('Hz').magnitude, - np.abs(freqs_ts[nyquist_freq_idx].rescale('Hz').magnitude)) * pq.Hz + ts_freq_indices = np.append(nonnegative_freqs_indices, nyquist_freq_idx) + ts_overlap_freqs = ( + np.append( + freqs_ts[nonnegative_freqs_indices].rescale("Hz").magnitude, + np.abs(freqs_ts[nyquist_freq_idx].rescale("Hz").magnitude), + ) + * pq.Hz + ) np.testing.assert_array_equal(freqs_os, ts_overlap_freqs) np.testing.assert_allclose( cross_spec_os.magnitude, - cross_spec_ts[:, :, ts_freq_indices].magnitude, rtol=1e-12, atol=0) + cross_spec_ts[:, :, ts_freq_indices].magnitude, + rtol=1e-12, + atol=0, + ) def test_segmented_multitaper_cross_spectrum_parameter_hierarchy(self): # test frequency_resolution vs len_segment vs n_segments @@ -799,33 +880,37 @@ def test_segmented_multitaper_cross_spectrum_parameter_hierarchy(self): r[0], r[500] = 0, 10 # Zero DC, peak at 100 Hz phi_x = np.random.uniform(-np.pi, np.pi, len(r)) phi_y = np.random.uniform(-np.pi, np.pi, len(r)) - fake_coeffs_x = r*np.exp(1j * phi_x) - fake_coeffs_y = r*np.exp(1j * phi_y) + fake_coeffs_x = r * np.exp(1j * phi_x) + fake_coeffs_y = r * np.exp(1j * phi_y) signal_x = scipy.fft.irfft(fake_coeffs_x) signal_y = scipy.fft.irfft(fake_coeffs_y) sampling_period = 0.001 freqs = scipy.fft.rfftfreq(len(signal_x), d=sampling_period) signal_freq = freqs[r.argmax()] - data = AnalogSignal(np.vstack([signal_x, signal_y]).T, - sampling_period=sampling_period * pq.s, - units='mV') + data = AnalogSignal( + np.vstack([signal_x, signal_y]).T, + sampling_period=sampling_period * pq.s, + units="mV", + ) n_segments = 5 len_segment = 2000 frequency_resolution = 1 * pq.Hz - freqs_ns, cross_spec_ns = \ - elephant.spectral.segmented_multitaper_cross_spectrum( - data, n_segments=n_segments) + freqs_ns, cross_spec_ns = elephant.spectral.segmented_multitaper_cross_spectrum( + data, n_segments=n_segments + ) - freqs_ls, cross_spec_ls = \ - elephant.spectral.segmented_multitaper_cross_spectrum( - data, n_segments=n_segments, len_segment=len_segment) + freqs_ls, cross_spec_ls = elephant.spectral.segmented_multitaper_cross_spectrum( + data, n_segments=n_segments, len_segment=len_segment + ) - freqs_fr, cross_spec_fr = \ - elephant.spectral.segmented_multitaper_cross_spectrum( - data, n_segments=n_segments, len_segment=len_segment, - frequency_resolution=frequency_resolution) + freqs_fr, cross_spec_fr = elephant.spectral.segmented_multitaper_cross_spectrum( + data, + n_segments=n_segments, + len_segment=len_segment, + frequency_resolution=frequency_resolution, + ) self.assertNotEqual(freqs_ns.shape, freqs_ls.shape) self.assertNotEqual(freqs_ls.shape, freqs_fr.shape) @@ 
-839,33 +924,33 @@ def test_segmented_multitaper_cross_spectrum_against_multitaper_psd(self): sampling_period = 0.001 signal_freq = 100.0 noise = np.random.normal(size=(2, data_length)) - time_points = np.arange(0, data_length * sampling_period, - sampling_period) + time_points = np.arange(0, data_length * sampling_period, sampling_period) signal_x = np.sin(2 * np.pi * signal_freq * time_points) + noise[0] signal_y = np.cos(2 * np.pi * signal_freq * time_points) + noise[1] - data = AnalogSignal(np.vstack([signal_x, signal_y]).T, - sampling_period=sampling_period * pq.s, - units='mV') + data = AnalogSignal( + np.vstack([signal_x, signal_y]).T, + sampling_period=sampling_period * pq.s, + units="mV", + ) freqs1, psd_multitaper = elephant.spectral.multitaper_psd( - signal=data, fs=data.sampling_rate, nw=4, num_tapers=8) + signal=data, fs=data.sampling_rate, nw=4, num_tapers=8 + ) psd_multitaper[:, 1:] /= 2 # since comparing rfft and fft results - freqs2, cross_spec = \ - elephant.spectral.segmented_multitaper_cross_spectrum( - data, - fs=data.sampling_rate, - nw=4, - num_tapers=8, - return_onesided=True) + freqs2, cross_spec = elephant.spectral.segmented_multitaper_cross_spectrum( + data, fs=data.sampling_rate, nw=4, num_tapers=8, return_onesided=True + ) self.assertTrue((freqs1 == freqs2).all()) - np.testing.assert_allclose(psd_multitaper.magnitude, - np.diagonal(cross_spec).T.real.magnitude, - rtol=0.01, - atol=0.01) + np.testing.assert_allclose( + psd_multitaper.magnitude, + np.diagonal(cross_spec).T.real.magnitude, + rtol=0.01, + atol=0.01, + ) def test_segmented_multitaper_cross_spectrum_input_types(self): # generate a test data @@ -873,25 +958,28 @@ def test_segmented_multitaper_cross_spectrum_input_types(self): sampling_period = 0.001 signal_freq = 100.0 noise = np.random.normal(size=(2, data_length)) - time_points = np.arange(0, data_length * sampling_period, - sampling_period) + time_points = np.arange(0, data_length * sampling_period, sampling_period) signal_x = np.sin(2 * np.pi * signal_freq * time_points) + noise[0] signal_y = np.cos(2 * np.pi * signal_freq * time_points) + noise[1] - data = AnalogSignal(np.vstack([signal_x, signal_y]).T, - sampling_period=sampling_period * pq.s, - units='mV') + data = AnalogSignal( + np.vstack([signal_x, signal_y]).T, + sampling_period=sampling_period * pq.s, + units="mV", + ) # frequency resolution as an integer freq_res_int = 1 freq_res_hz = 1 * pq.Hz - freqs_int, cross_spec_int = \ + freqs_int, cross_spec_int = ( elephant.spectral.segmented_multitaper_cross_spectrum( - data, frequency_resolution=freq_res_int) + data, frequency_resolution=freq_res_int + ) + ) - freqs_hz, cross_spec_hz = \ - elephant.spectral.segmented_multitaper_cross_spectrum( - data, frequency_resolution=freq_res_hz) + freqs_hz, cross_spec_hz = elephant.spectral.segmented_multitaper_cross_spectrum( + data, frequency_resolution=freq_res_hz + ) np.testing.assert_array_equal(freqs_int, freqs_hz) np.testing.assert_array_equal(cross_spec_int, cross_spec_hz) @@ -905,25 +993,25 @@ def test_multitaper_coherence_input_types(self): signal_freq = 100.0 np.random.seed(123) noise = np.random.normal(size=(2, data_length)) - time_points = np.arange(0, data_length * sampling_period, - sampling_period) + time_points = np.arange(0, data_length * sampling_period, sampling_period) # Signals are designed to have coherence peak at `signal_freq` arr_signal_i = np.sin(2 * np.pi * signal_freq * time_points) + noise[0] arr_signal_j = np.cos(2 * np.pi * signal_freq * time_points) + noise[1] fs = 1000 * 
pq.Hz - anasig_signal_i = neo.core.AnalogSignal(arr_signal_i, - sampling_rate=fs, - units=pq.mV) - anasig_signal_j = neo.core.AnalogSignal(arr_signal_j, - sampling_rate=fs, - units=pq.mV) + anasig_signal_i = neo.core.AnalogSignal( + arr_signal_i, sampling_rate=fs, units=pq.mV + ) + anasig_signal_j = neo.core.AnalogSignal( + arr_signal_j, sampling_rate=fs, units=pq.mV + ) arr_f, arr_coh, arr_phi = elephant.spectral.multitaper_coherence( - arr_signal_i, arr_signal_j, fs=fs) - anasig_f, anasig_coh, anasig_phi = \ - elephant.spectral.multitaper_coherence(anasig_signal_i, - anasig_signal_j) + arr_signal_i, arr_signal_j, fs=fs + ) + anasig_f, anasig_coh, anasig_phi = elephant.spectral.multitaper_coherence( + anasig_signal_i, anasig_signal_j + ) np.testing.assert_array_equal(arr_f, anasig_f) np.testing.assert_allclose(arr_coh, anasig_coh, atol=1e-6) @@ -935,47 +1023,43 @@ def test_multitaper_cohere_peak(self): sampling_period = 0.001 signal_freq = 100.0 noise = np.random.normal(size=(2, data_length)) - time_points = np.arange(0, data_length * sampling_period, - sampling_period) + time_points = np.arange(0, data_length * sampling_period, sampling_period) # Signals are designed to have coherence peak at `signal_freq` signal_i = np.sin(2 * np.pi * signal_freq * time_points) + noise[0] signal_j = np.cos(2 * np.pi * signal_freq * time_points) + noise[1] # Estimate coherence and phase lag with the multitaper method freq1, coh1, phase_lag1 = elephant.spectral.multitaper_coherence( - signal_i, - signal_j, - fs=1/sampling_period, - n_segments=16) + signal_i, signal_j, fs=1 / sampling_period, n_segments=16 + ) indices, vals = scipy.signal.find_peaks(coh1, height=0.8, distance=10) peak_freqs = freq1[indices] - np.testing.assert_allclose(peak_freqs, - signal_freq*np.ones(len(peak_freqs)), - rtol=0.05) - - @pytest.mark.skipif(version.parse(np.__version__)>version.parse("1.25.0"), - reason="This test will fail with numpy version" - "1.25.0 - 1.25.2, see issue #24000" - "https://github.com/numpy/numpy/issues/24000 ") + np.testing.assert_allclose( + peak_freqs, signal_freq * np.ones(len(peak_freqs)), rtol=0.05 + ) + + @pytest.mark.skipif( + version.parse(np.__version__) > version.parse("1.25.0"), + reason="This test will fail with numpy version" + "1.25.0 - 1.25.2, see issue #24000" + "https://github.com/numpy/numpy/issues/24000 ", + ) def test_multitaper_cohere_perfect_cohere(self): # Generate dummy data data_length = 10000 sampling_period = 0.001 signal_freq = 100.0 noise = np.random.normal(size=(1, data_length)) - time_points = np.arange(0, data_length * sampling_period, - sampling_period) + time_points = np.arange(0, data_length * sampling_period, sampling_period) signal = np.cos(2 * np.pi * signal_freq * time_points) + noise # Estimate coherence and phase lag with the multitaper method freq1, coh, phase_lag = elephant.spectral.multitaper_coherence( - signal, - signal, - fs=1/sampling_period, - n_segments=16) + signal, signal, fs=1 / sampling_period, n_segments=16 + ) np.testing.assert_array_equal(phase_lag, np.zeros(phase_lag.size)) np.testing.assert_array_equal(coh, np.ones(coh.size)) @@ -984,18 +1068,15 @@ def test_multitaper_cohere_no_cohere(self): # Generate dummy data data_length = 10000 sampling_period = 0.001 - time_points = np.arange(0, data_length * sampling_period, - sampling_period) + time_points = np.arange(0, data_length * sampling_period, sampling_period) signal_i = np.sin(2 * np.pi * 2.5 * time_points) signal_j = np.sin(2 * np.pi * 5 * time_points) # Estimate coherence and phase lag with the 
multitaper method freq, coh, phase_lag = elephant.spectral.multitaper_coherence( - signal_i, - signal_j, - fs=1/sampling_period, - n_segments=16) + signal_i, signal_j, fs=1 / sampling_period, n_segments=16 + ) np.testing.assert_allclose(coh, np.zeros(coh.size), atol=0.002) @@ -1004,8 +1085,7 @@ def test_multitaper_cohere_phase_lag(self): data_length = 10000 sampling_period = 0.001 signal_freq = 100.0 - time_points = np.arange(0, data_length * sampling_period, - sampling_period) + time_points = np.arange(0, data_length * sampling_period, sampling_period) # Signals are designed to have maximal phase lag at 100 with value pi/4 signal_i = np.sin(2 * np.pi * signal_freq * time_points + np.pi / 4) @@ -1013,68 +1093,85 @@ def test_multitaper_cohere_phase_lag(self): # Estimate coherence and phase lag with the multitaper method freq, coh, phase_lag = elephant.spectral.multitaper_coherence( - signal_i, - signal_j, - fs=1/sampling_period, - n_segments=16, - num_tapers=8) + signal_i, signal_j, fs=1 / sampling_period, n_segments=16, num_tapers=8 + ) - indices, vals = scipy.signal.find_peaks(phase_lag, - height=0.8 * np.pi / 4, - distance=10) + indices, vals = scipy.signal.find_peaks( + phase_lag, height=0.8 * np.pi / 4, distance=10 + ) # Get peak frequencies and peak heights peak_freqs = freq[indices] - peak_heights = vals['peak_heights'] + peak_heights = vals["peak_heights"] - np.testing.assert_allclose(peak_freqs, - signal_freq*np.ones(len(peak_freqs)), - rtol=0.05) - np.testing.assert_allclose(peak_heights, - np.pi / 4 * np.ones(len(peak_heights)), - rtol=0.05) + np.testing.assert_allclose( + peak_freqs, signal_freq * np.ones(len(peak_freqs)), rtol=0.05 + ) + np.testing.assert_allclose( + peak_heights, np.pi / 4 * np.ones(len(peak_heights)), rtol=0.05 + ) class WelchCohereTestCase(unittest.TestCase): def test_welch_cohere_errors(self): # generate a dummy data - x = AnalogSignal(np.zeros(5000), sampling_period=0.001 * pq.s, - units='mV') - y = AnalogSignal(np.zeros(5000), sampling_period=0.001 * pq.s, - units='mV') + x = AnalogSignal(np.zeros(5000), sampling_period=0.001 * pq.s, units="mV") + y = AnalogSignal(np.zeros(5000), sampling_period=0.001 * pq.s, units="mV") # check for invalid parameter values # - length of segments - self.assertRaises(ValueError, elephant.spectral.welch_coherence, x, y, - len_segment=0) - self.assertRaises(ValueError, elephant.spectral.welch_coherence, x, y, - len_segment=x.shape[0] * 2) + self.assertRaises( + ValueError, elephant.spectral.welch_coherence, x, y, len_segment=0 + ) + self.assertRaises( + ValueError, + elephant.spectral.welch_coherence, + x, + y, + len_segment=x.shape[0] * 2, + ) # - number of segments - self.assertRaises(ValueError, elephant.spectral.welch_coherence, x, y, - n_segments=0) - self.assertRaises(ValueError, elephant.spectral.welch_coherence, x, y, - n_segments=x.shape[0] * 2) + self.assertRaises( + ValueError, elephant.spectral.welch_coherence, x, y, n_segments=0 + ) + self.assertRaises( + ValueError, + elephant.spectral.welch_coherence, + x, + y, + n_segments=x.shape[0] * 2, + ) # - frequency resolution - self.assertRaises(ValueError, elephant.spectral.welch_coherence, x, y, - frequency_resolution=-1) - self.assertRaises(ValueError, elephant.spectral.welch_coherence, x, y, - frequency_resolution=x.sampling_rate / - (x.shape[0] + 1)) + self.assertRaises( + ValueError, elephant.spectral.welch_coherence, x, y, frequency_resolution=-1 + ) + self.assertRaises( + ValueError, + elephant.spectral.welch_coherence, + x, + y, + 
frequency_resolution=x.sampling_rate / (x.shape[0] + 1), + ) # - overlap - self.assertRaises(ValueError, elephant.spectral.welch_coherence, x, y, - overlap=-1.0) - self.assertRaises(ValueError, elephant.spectral.welch_coherence, x, y, - overlap=1.1) + self.assertRaises( + ValueError, elephant.spectral.welch_coherence, x, y, overlap=-1.0 + ) + self.assertRaises( + ValueError, elephant.spectral.welch_coherence, x, y, overlap=1.1 + ) def test_welch_cohere_warnings(self): # generate a dummy data - x = AnalogSignal(np.zeros(5000), sampling_period=0.001 * pq.s, - units='mV') - y = AnalogSignal(np.zeros(5000), sampling_period=0.001 * pq.s, - units='mV') + x = AnalogSignal(np.zeros(5000), sampling_period=0.001 * pq.s, units="mV") + y = AnalogSignal(np.zeros(5000), sampling_period=0.001 * pq.s, units="mV") # Test deprecation warning for 'hanning' window - self.assertWarns(DeprecationWarning, elephant.spectral.welch_coherence, - x, y, window='hanning') + self.assertWarns( + DeprecationWarning, + elephant.spectral.welch_coherence, + x, + y, + window="hanning", + ) def test_welch_cohere_behavior(self): # generate data by adding white noise and a sinusoid @@ -1083,40 +1180,52 @@ def test_welch_cohere_behavior(self): signal_freq = 100.0 noise1 = np.random.normal(size=data_length) * 0.01 noise2 = np.random.normal(size=data_length) * 0.01 - signal1 = [np.cos(2 * np.pi * signal_freq * t) - for t in np.arange(0, data_length * sampling_period, - sampling_period)] - signal2 = [np.sin(2 * np.pi * signal_freq * t) - for t in np.arange(0, data_length * sampling_period, - sampling_period)] - x = AnalogSignal(np.array(signal1 + noise1), units='mV', - sampling_period=sampling_period * pq.s) - y = AnalogSignal(np.array(signal2 + noise2), units='mV', - sampling_period=sampling_period * pq.s) + signal1 = [ + np.cos(2 * np.pi * signal_freq * t) + for t in np.arange(0, data_length * sampling_period, sampling_period) + ] + signal2 = [ + np.sin(2 * np.pi * signal_freq * t) + for t in np.arange(0, data_length * sampling_period, sampling_period) + ] + x = AnalogSignal( + np.array(signal1 + noise1), + units="mV", + sampling_period=sampling_period * pq.s, + ) + y = AnalogSignal( + np.array(signal2 + noise2), + units="mV", + sampling_period=sampling_period * pq.s, + ) # consistency between different ways of specifying segment length freqs1, coherency1, phase_lag1 = elephant.spectral.welch_coherence( - x, y, len_segment=data_length // 5, overlap=0) + x, y, len_segment=data_length // 5, overlap=0 + ) freqs2, coherency2, phase_lag2 = elephant.spectral.welch_coherence( - x, y, n_segments=5, overlap=0) - self.assertTrue((coherency1 == coherency2).all() and - (phase_lag1 == phase_lag2).all() and - (freqs1 == freqs2).all()) + x, y, n_segments=5, overlap=0 + ) + self.assertTrue( + (coherency1 == coherency2).all() + and (phase_lag1 == phase_lag2).all() + and (freqs1 == freqs2).all() + ) # frequency resolution and consistency with data freq_res = 1.0 * pq.Hz freqs, coherency, phase_lag = elephant.spectral.welch_coherence( - x, y, frequency_resolution=freq_res) + x, y, frequency_resolution=freq_res + ) self.assertAlmostEqual(freq_res, freqs[1] - freqs[0]) - self.assertAlmostEqual(freqs[coherency.argmax()], signal_freq, - places=2) - self.assertAlmostEqual(phase_lag[coherency.argmax()], -np.pi / 2, - places=2) - freqs_np, coherency_np, phase_lag_np = \ - elephant.spectral.welch_coherence(x.magnitude.flatten(), - y.magnitude.flatten(), - fs=1 / sampling_period, - frequency_resolution=freq_res) + 
self.assertAlmostEqual(freqs[coherency.argmax()], signal_freq, places=2) + self.assertAlmostEqual(phase_lag[coherency.argmax()], -np.pi / 2, places=2) + freqs_np, coherency_np, phase_lag_np = elephant.spectral.welch_coherence( + x.magnitude.flatten(), + y.magnitude.flatten(), + fs=1 / sampling_period, + frequency_resolution=freq_res, + ) assert_array_equal(freqs.simplified.magnitude, freqs_np) assert_array_equal(coherency[:, 0], coherency_np) assert_array_equal(phase_lag[:, 0], phase_lag_np) @@ -1126,10 +1235,12 @@ def test_welch_cohere_behavior(self): data_length = 5000 x_multidim = np.random.normal(size=(num_channel, data_length)) y_multidim = np.random.normal(size=(num_channel, data_length)) - freqs, coherency, phase_lag = \ - elephant.spectral.welch_coherence(x_multidim, y_multidim) + freqs, coherency, phase_lag = elephant.spectral.welch_coherence( + x_multidim, y_multidim + ) freqs_T, coherency_T, phase_lag_T = elephant.spectral.welch_coherence( - x_multidim.T, y_multidim.T, axis=0) + x_multidim.T, y_multidim.T, axis=0 + ) assert_array_equal(freqs, freqs_T) assert_array_equal(coherency, coherency_T.T) assert_array_equal(phase_lag, phase_lag_T.T) @@ -1137,43 +1248,52 @@ def test_welch_cohere_behavior(self): def test_welch_cohere_input_types(self): # generate a test data sampling_period = 0.001 - x = AnalogSignal(np.array(np.random.normal(size=5000)), - sampling_period=sampling_period * pq.s, - units='mV') - y = AnalogSignal(np.array(np.random.normal(size=5000)), - sampling_period=sampling_period * pq.s, - units='mV') + x = AnalogSignal( + np.array(np.random.normal(size=5000)), + sampling_period=sampling_period * pq.s, + units="mV", + ) + y = AnalogSignal( + np.array(np.random.normal(size=5000)), + sampling_period=sampling_period * pq.s, + units="mV", + ) # outputs from AnalogSignal input are of Quantity type # (standard usage) - freqs_neo, coherency_neo, phase_lag_neo = \ - elephant.spectral.welch_coherence(x, y) + freqs_neo, coherency_neo, phase_lag_neo = elephant.spectral.welch_coherence( + x, y + ) self.assertTrue(isinstance(freqs_neo, pq.quantity.Quantity)) self.assertTrue(isinstance(phase_lag_neo, pq.quantity.Quantity)) # outputs from Quantity array input are of Quantity type - freqs_pq, coherency_pq, phase_lag_pq = elephant.spectral \ - .welch_coherence(x.magnitude.flatten() * x.units, - y.magnitude.flatten() * y.units, - fs=1 / sampling_period) + freqs_pq, coherency_pq, phase_lag_pq = elephant.spectral.welch_coherence( + x.magnitude.flatten() * x.units, + y.magnitude.flatten() * y.units, + fs=1 / sampling_period, + ) self.assertTrue(isinstance(freqs_pq, pq.quantity.Quantity)) self.assertTrue(isinstance(phase_lag_pq, pq.quantity.Quantity)) # outputs from Numpy ndarray input are NOT of Quantity type - freqs_np, coherency_np, phase_lag_np = elephant.spectral \ - .welch_coherence(x.magnitude.flatten(), - y.magnitude.flatten(), - fs=1 / sampling_period) + freqs_np, coherency_np, phase_lag_np = elephant.spectral.welch_coherence( + x.magnitude.flatten(), y.magnitude.flatten(), fs=1 / sampling_period + ) self.assertFalse(isinstance(freqs_np, pq.quantity.Quantity)) self.assertFalse(isinstance(phase_lag_np, pq.quantity.Quantity)) # check if the results from different input types are identical - self.assertTrue((freqs_neo == freqs_pq).all() and - (coherency_neo[:, 0] == coherency_pq).all() and - (phase_lag_neo[:, 0] == phase_lag_pq).all()) - self.assertTrue((freqs_neo == freqs_np).all() and - (coherency_neo[:, 0] == coherency_np).all() and - (phase_lag_neo[:, 0] == phase_lag_np).all()) + 
self.assertTrue( + (freqs_neo == freqs_pq).all() + and (coherency_neo[:, 0] == coherency_pq).all() + and (phase_lag_neo[:, 0] == phase_lag_pq).all() + ) + self.assertTrue( + (freqs_neo == freqs_np).all() + and (coherency_neo[:, 0] == coherency_np).all() + and (phase_lag_neo[:, 0] == phase_lag_np).all() + ) def test_welch_cohere_multidim_input(self): # generate multidimensional data @@ -1185,29 +1305,30 @@ def test_welch_cohere_multidim_input(self): # Since row-column order in AnalogSignal is different from the # convention in NumPy/SciPy, `data_np` needs to be transposed when it's # used to define an AnalogSignal - x_neo = AnalogSignal(x_np.T, units='mV', - sampling_period=sampling_period * pq.s) - y_neo = AnalogSignal(y_np.T, units='mV', - sampling_period=sampling_period * pq.s) - x_neo_1dim = AnalogSignal(x_np[0], units='mV', - sampling_period=sampling_period * pq.s) - y_neo_1dim = AnalogSignal(y_np[0], units='mV', - sampling_period=sampling_period * pq.s) + x_neo = AnalogSignal(x_np.T, units="mV", sampling_period=sampling_period * pq.s) + y_neo = AnalogSignal(y_np.T, units="mV", sampling_period=sampling_period * pq.s) + x_neo_1dim = AnalogSignal( + x_np[0], units="mV", sampling_period=sampling_period * pq.s + ) + y_neo_1dim = AnalogSignal( + y_np[0], units="mV", sampling_period=sampling_period * pq.s + ) # check if the results from different input types are identical - freqs_np, coherency_np, phase_lag_np = elephant.spectral \ - .welch_coherence(x_np, y_np, fs=1 / sampling_period) - freqs_neo, coherency_neo, phase_lag_neo = \ - elephant.spectral.welch_coherence(x_neo, y_neo) - freqs_neo_1dim, coherency_neo_1dim, phase_lag_neo_1dim = \ + freqs_np, coherency_np, phase_lag_np = elephant.spectral.welch_coherence( + x_np, y_np, fs=1 / sampling_period + ) + freqs_neo, coherency_neo, phase_lag_neo = elephant.spectral.welch_coherence( + x_neo, y_neo + ) + freqs_neo_1dim, coherency_neo_1dim, phase_lag_neo_1dim = ( elephant.spectral.welch_coherence(x_neo_1dim, y_neo_1dim) + ) self.assertTrue(np.all(freqs_np == freqs_neo)) self.assertTrue(np.all(coherency_np.T == coherency_neo)) self.assertTrue(np.all(phase_lag_np.T == phase_lag_neo)) - self.assertTrue( - np.all(coherency_neo_1dim[:, 0] == coherency_neo[:, 0])) - self.assertTrue( - np.all(phase_lag_neo_1dim[:, 0] == phase_lag_neo[:, 0])) + self.assertTrue(np.all(coherency_neo_1dim[:, 0] == coherency_neo[:, 0])) + self.assertTrue(np.all(phase_lag_neo_1dim[:, 0] == phase_lag_neo[:, 0])) if __name__ == "__main__": diff --git a/elephant/test/test_spike_train_correlation.py b/elephant/test/test_spike_train_correlation.py index cae01e479..91ccf2fed 100644 --- a/elephant/test/test_spike_train_correlation.py +++ b/elephant/test/test_spike_train_correlation.py @@ -17,34 +17,36 @@ import elephant.conversion as conv import elephant.spike_train_correlation as sc -from elephant.spike_train_generation import StationaryPoissonProcess, \ - StationaryGammaProcess +from elephant.spike_train_generation import ( + StationaryPoissonProcess, + StationaryGammaProcess, +) import math from elephant.datasets import download_datasets, ELEPHANT_TMP_DIR -from elephant.spike_train_generation import homogeneous_poisson_process, \ - homogeneous_gamma_process +from elephant.spike_train_generation import ( + homogeneous_poisson_process, + homogeneous_gamma_process, +) class CovarianceTestCase(unittest.TestCase): - def setUp(self): # These two arrays must be such that they do not have coincidences # spanning across two neighbor bins assuming ms bins [0,1),[1,2),... 
- self.test_array_1d_0 = [ - 1.3, 7.56, 15.87, 28.23, 30.9, 34.2, 38.2, 43.2] - self.test_array_1d_1 = [ - 1.02, 2.71, 18.82, 28.46, 28.79, 43.6] + self.test_array_1d_0 = [1.3, 7.56, 15.87, 28.23, 30.9, 34.2, 38.2, 43.2] + self.test_array_1d_1 = [1.02, 2.71, 18.82, 28.46, 28.79, 43.6] # Build spike trains - self.st_0 = neo.SpikeTrain( - self.test_array_1d_0, units='ms', t_stop=50.) - self.st_1 = neo.SpikeTrain( - self.test_array_1d_1, units='ms', t_stop=50.) + self.st_0 = neo.SpikeTrain(self.test_array_1d_0, units="ms", t_stop=50.0) + self.st_1 = neo.SpikeTrain(self.test_array_1d_1, units="ms", t_stop=50.0) # And binned counterparts self.binned_st = conv.BinnedSpikeTrain( - [self.st_0, self.st_1], t_start=0 * pq.ms, t_stop=50. * pq.ms, - bin_size=1 * pq.ms) + [self.st_0, self.st_1], + t_start=0 * pq.ms, + t_stop=50.0 * pq.ms, + bin_size=1 * pq.ms, + ) def test_covariance_binned(self): """ @@ -52,10 +54,8 @@ def test_covariance_binned(self): """ # Calculate clipped and unclipped - res_clipped = sc.covariance( - self.binned_st, binary=True, fast=False) - res_unclipped = sc.covariance( - self.binned_st, binary=False, fast=False) + res_clipped = sc.covariance(self.binned_st, binary=True, fast=False) + res_unclipped = sc.covariance(self.binned_st, binary=False, fast=False) # Check dimensions self.assertEqual(len(res_clipped), 2) @@ -66,8 +66,9 @@ def test_covariance_binned(self): mat = self.binned_st.to_array() mean_0 = np.mean(mat[0]) mean_1 = np.mean(mat[1]) - target_from_scratch = \ - np.dot(mat[0] - mean_0, mat[1] - mean_1) / (len(mat[0]) - 1) + target_from_scratch = np.dot(mat[0] - mean_0, mat[1] - mean_1) / ( + len(mat[0]) - 1 + ) # Check result unclipped against result calculated by numpy.corrcoef target_numpy = np.cov(mat) @@ -81,8 +82,9 @@ def test_covariance_binned(self): mat = self.binned_st.to_bool_array() mean_0 = np.mean(mat[0]) mean_1 = np.mean(mat[1]) - target_from_scratch = \ - np.dot(mat[0] - mean_0, mat[1] - mean_1) / (len(mat[0]) - 1) + target_from_scratch = np.dot(mat[0] - mean_0, mat[1] - mean_1) / ( + len(mat[0]) - 1 + ) # Check result unclipped against result calculated by numpy.corrcoef target_numpy = np.cov(mat) @@ -98,8 +100,11 @@ def test_covariance_binned_same_spiketrains(self): """ # Calculate correlation binned_st = conv.BinnedSpikeTrain( - [self.st_0, self.st_0], t_start=0 * pq.ms, t_stop=50. * pq.ms, - bin_size=1 * pq.ms) + [self.st_0, self.st_0], + t_start=0 * pq.ms, + t_stop=50.0 * pq.ms, + bin_size=1 * pq.ms, + ) result = sc.covariance(binned_st, fast=False) # Check dimensions @@ -114,8 +119,8 @@ def test_covariance_binned_short_input(self): """ # Calculate correlation binned_st = conv.BinnedSpikeTrain( - self.st_0, t_start=0 * pq.ms, t_stop=50. 
* pq.ms, - bin_size=1 * pq.ms) + self.st_0, t_start=0 * pq.ms, t_stop=50.0 * pq.ms, bin_size=1 * pq.ms + ) result = sc.covariance(binned_st, binary=True, fast=False) # Check result unclipped against result calculated by numpy.corrcoef @@ -125,42 +130,41 @@ def test_covariance_binned_short_input(self): # Check result and dimensionality of result self.assertEqual(result.ndim, target.ndim) assert_array_almost_equal(result, target) - assert_array_almost_equal(target, - sc.covariance(binned_st, binary=True, - fast=True)) + assert_array_almost_equal( + target, sc.covariance(binned_st, binary=True, fast=True) + ) def test_covariance_fast_mode(self): np.random.seed(27) - st = StationaryPoissonProcess(rate=10 * pq.Hz, t_stop=10 * pq.s - ).generate_spiketrain() + st = StationaryPoissonProcess( + rate=10 * pq.Hz, t_stop=10 * pq.s + ).generate_spiketrain() binned_st = conv.BinnedSpikeTrain(st, n_bins=10) - assert_array_almost_equal(sc.covariance(binned_st, fast=False), - sc.covariance(binned_st, fast=True)) + assert_array_almost_equal( + sc.covariance(binned_st, fast=False), sc.covariance(binned_st, fast=True) + ) class CorrCoefTestCase(unittest.TestCase): - def setUp(self): # These two arrays must be such that they do not have coincidences # spanning across two neighbor bins assuming ms bins [0,1),[1,2),... - self.test_array_1d_0 = [ - 1.3, 7.56, 15.87, 28.23, 30.9, 34.2, 38.2, 43.2] - self.test_array_1d_1 = [ - 1.02, 2.71, 18.82, 28.46, 28.79, 43.6] + self.test_array_1d_0 = [1.3, 7.56, 15.87, 28.23, 30.9, 34.2, 38.2, 43.2] + self.test_array_1d_1 = [1.02, 2.71, 18.82, 28.46, 28.79, 43.6] self.test_array_1d_2 = [] # Build spike trains - self.st_0 = neo.SpikeTrain( - self.test_array_1d_0, units='ms', t_stop=50.) - self.st_1 = neo.SpikeTrain( - self.test_array_1d_1, units='ms', t_stop=50.) - self.st_2 = neo.SpikeTrain( - self.test_array_1d_2, units='ms', t_stop=50.) + self.st_0 = neo.SpikeTrain(self.test_array_1d_0, units="ms", t_stop=50.0) + self.st_1 = neo.SpikeTrain(self.test_array_1d_1, units="ms", t_stop=50.0) + self.st_2 = neo.SpikeTrain(self.test_array_1d_2, units="ms", t_stop=50.0) # And binned counterparts self.binned_st = conv.BinnedSpikeTrain( - [self.st_0, self.st_1], t_start=0 * pq.ms, t_stop=50. 
* pq.ms, - bin_size=1 * pq.ms) + [self.st_0, self.st_1], + t_start=0 * pq.ms, + t_stop=50.0 * pq.ms, + bin_size=1 * pq.ms, + ) def test_corrcoef_binned(self): """ @@ -168,10 +172,8 @@ def test_corrcoef_binned(self): """ # Calculate clipped and unclipped - res_clipped = sc.correlation_coefficient( - self.binned_st, binary=True) - res_unclipped = sc.correlation_coefficient( - self.binned_st, binary=False) + res_clipped = sc.correlation_coefficient(self.binned_st, binary=True) + res_unclipped = sc.correlation_coefficient(self.binned_st, binary=False) # Check dimensions self.assertEqual(len(res_clipped), 2) @@ -182,11 +184,10 @@ def test_corrcoef_binned(self): mat = self.binned_st.to_array() mean_0 = np.mean(mat[0]) mean_1 = np.mean(mat[1]) - target_from_scratch = \ - np.dot(mat[0] - mean_0, mat[1] - mean_1) / \ - np.sqrt( - np.dot(mat[0] - mean_0, mat[0] - mean_0) * - np.dot(mat[1] - mean_1, mat[1] - mean_1)) + target_from_scratch = np.dot(mat[0] - mean_0, mat[1] - mean_1) / np.sqrt( + np.dot(mat[0] - mean_0, mat[0] - mean_0) + * np.dot(mat[1] - mean_1, mat[1] - mean_1) + ) # Check result unclipped against result calculated by numpy.corrcoef target_numpy = np.corrcoef(mat) @@ -200,11 +201,10 @@ def test_corrcoef_binned(self): mat = self.binned_st.to_bool_array() mean_0 = np.mean(mat[0]) mean_1 = np.mean(mat[1]) - target_from_scratch = \ - np.dot(mat[0] - mean_0, mat[1] - mean_1) / \ - np.sqrt( - np.dot(mat[0] - mean_0, mat[0] - mean_0) * - np.dot(mat[1] - mean_1, mat[1] - mean_1)) + target_from_scratch = np.dot(mat[0] - mean_0, mat[1] - mean_1) / np.sqrt( + np.dot(mat[0] - mean_0, mat[0] - mean_0) + * np.dot(mat[1] - mean_1, mat[1] - mean_1) + ) # Check result unclipped against result calculated by numpy.corrcoef target_numpy = np.corrcoef(mat) @@ -220,8 +220,11 @@ def test_corrcoef_binned_same_spiketrains(self): """ # Calculate correlation binned_st = conv.BinnedSpikeTrain( - [self.st_0, self.st_0], t_start=0 * pq.ms, t_stop=50. * pq.ms, - bin_size=1 * pq.ms) + [self.st_0, self.st_0], + t_start=0 * pq.ms, + t_stop=50.0 * pq.ms, + bin_size=1 * pq.ms, + ) result = sc.correlation_coefficient(binned_st, fast=False) target = np.ones((2, 2)) @@ -230,8 +233,8 @@ def test_corrcoef_binned_same_spiketrains(self): # Check result assert_array_almost_equal(result, target) assert_array_almost_equal( - result, sc.correlation_coefficient( - binned_st, fast=True)) + result, sc.correlation_coefficient(binned_st, fast=True) + ) def test_corrcoef_binned_short_input(self): """ @@ -239,17 +242,17 @@ def test_corrcoef_binned_short_input(self): """ # Calculate correlation binned_st = conv.BinnedSpikeTrain( - self.st_0, t_start=0 * pq.ms, t_stop=50. * pq.ms, - bin_size=1 * pq.ms) + self.st_0, t_start=0 * pq.ms, t_stop=50.0 * pq.ms, bin_size=1 * pq.ms + ) result = sc.correlation_coefficient(binned_st, fast=False) - target = np.array(1.) + target = np.array(1.0) # Check result and dimensionality of result self.assertEqual(result.ndim, 0) assert_array_almost_equal(result, target) assert_array_almost_equal( - result, sc.correlation_coefficient( - binned_st, fast=True)) + result, sc.correlation_coefficient(binned_st, fast=True) + ) def test_empty_spike_train(self): """ @@ -257,8 +260,7 @@ def test_empty_spike_train(self): Also check correctness of the output array. 
""" # st_2 is empty - binned_12 = conv.BinnedSpikeTrain([self.st_1, self.st_2], - bin_size=1 * pq.ms) + binned_12 = conv.BinnedSpikeTrain([self.st_1, self.st_2], bin_size=1 * pq.ms) with self.assertWarns(UserWarning): result = sc.correlation_coefficient(binned_12, fast=False) @@ -270,55 +272,57 @@ def test_empty_spike_train(self): def test_corrcoef_fast_mode(self): np.random.seed(27) - st = StationaryPoissonProcess(rate=10 * pq.Hz, t_stop=10 * pq.s - ).generate_spiketrain() + st = StationaryPoissonProcess( + rate=10 * pq.Hz, t_stop=10 * pq.s + ).generate_spiketrain() binned_st = conv.BinnedSpikeTrain(st, n_bins=10) assert_array_almost_equal( - sc.correlation_coefficient( - binned_st, fast=False), sc.correlation_coefficient( - binned_st, fast=True)) + sc.correlation_coefficient(binned_st, fast=False), + sc.correlation_coefficient(binned_st, fast=True), + ) class CrossCorrelationHistogramTest(unittest.TestCase): - def setUp(self): # These two arrays must be such that they do not have coincidences # spanning across two neighbor bins assuming ms bins [0,1),[1,2),... - self.test_array_1d_1 = [ - 1.3, 7.56, 15.87, 28.23, 30.9, 34.2, 38.2, 43.2] - self.test_array_1d_2 = [ - 1.02, 2.71, 18.82, 28.46, 28.79, 43.6] + self.test_array_1d_1 = [1.3, 7.56, 15.87, 28.23, 30.9, 34.2, 38.2, 43.2] + self.test_array_1d_2 = [1.02, 2.71, 18.82, 28.46, 28.79, 43.6] # Build spike trains - self.st_1 = neo.SpikeTrain( - self.test_array_1d_1, units='ms', t_stop=50.) - self.st_2 = neo.SpikeTrain( - self.test_array_1d_2, units='ms', t_stop=50.) + self.st_1 = neo.SpikeTrain(self.test_array_1d_1, units="ms", t_stop=50.0) + self.st_2 = neo.SpikeTrain(self.test_array_1d_2, units="ms", t_stop=50.0) # And binned counterparts self.binned_st1 = conv.BinnedSpikeTrain( - [self.st_1], t_start=0 * pq.ms, t_stop=50. * pq.ms, - bin_size=1 * pq.ms) + [self.st_1], t_start=0 * pq.ms, t_stop=50.0 * pq.ms, bin_size=1 * pq.ms + ) self.binned_st2 = conv.BinnedSpikeTrain( - [self.st_2], t_start=0 * pq.ms, t_stop=50. * pq.ms, - bin_size=1 * pq.ms) + [self.st_2], t_start=0 * pq.ms, t_stop=50.0 * pq.ms, bin_size=1 * pq.ms + ) self.binned_sts = conv.BinnedSpikeTrain( - [self.st_1, self.st_2], t_start=0 * pq.ms, t_stop=50. * pq.ms, - bin_size=1 * pq.ms) + [self.st_1, self.st_2], + t_start=0 * pq.ms, + t_stop=50.0 * pq.ms, + bin_size=1 * pq.ms, + ) # Binned sts to check errors raising self.st_check_bin_size = conv.BinnedSpikeTrain( - [self.st_1], t_start=0 * pq.ms, t_stop=50. * pq.ms, - bin_size=5 * pq.ms) + [self.st_1], t_start=0 * pq.ms, t_stop=50.0 * pq.ms, bin_size=5 * pq.ms + ) self.st_check_t_start = conv.BinnedSpikeTrain( - [self.st_1], t_start=1 * pq.ms, t_stop=50. * pq.ms, - bin_size=1 * pq.ms) + [self.st_1], t_start=1 * pq.ms, t_stop=50.0 * pq.ms, bin_size=1 * pq.ms + ) self.st_check_t_stop = conv.BinnedSpikeTrain( - [self.st_1], t_start=0 * pq.ms, t_stop=40. * pq.ms, - bin_size=1 * pq.ms) + [self.st_1], t_start=0 * pq.ms, t_stop=40.0 * pq.ms, bin_size=1 * pq.ms + ) self.st_check_dimension = conv.BinnedSpikeTrain( - [self.st_1, self.st_2], t_start=0 * pq.ms, t_stop=50. 
* pq.ms, - bin_size=1 * pq.ms) + [self.st_1, self.st_2], + t_start=0 * pq.ms, + t_stop=50.0 * pq.ms, + bin_size=1 * pq.ms, + ) def test_cross_correlation_histogram(self): """ @@ -328,31 +332,39 @@ def test_cross_correlation_histogram(self): # Calculate CCH using Elephant (normal and binary version) with # mode equal to 'full' (whole spike trains are correlated) cch_clipped, bin_ids_clipped = sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, window='full', - binary=True) + self.binned_st1, self.binned_st2, window="full", binary=True + ) cch_unclipped, bin_ids_unclipped = sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, window='full', binary=False) + self.binned_st1, self.binned_st2, window="full", binary=False + ) cch_clipped_mem, bin_ids_clipped_mem = sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, window='full', - binary=True, method='memory') - cch_unclipped_mem, bin_ids_unclipped_mem = \ - sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, window='full', - binary=False, method='memory') + self.binned_st1, + self.binned_st2, + window="full", + binary=True, + method="memory", + ) + cch_unclipped_mem, bin_ids_unclipped_mem = sc.cross_correlation_histogram( + self.binned_st1, + self.binned_st2, + window="full", + binary=False, + method="memory", + ) # Check consistency two methods assert_array_equal( - np.squeeze(cch_clipped.magnitude), np.squeeze( - cch_clipped_mem.magnitude)) + np.squeeze(cch_clipped.magnitude), np.squeeze(cch_clipped_mem.magnitude) + ) assert_array_equal( - np.squeeze(cch_clipped.times), np.squeeze( - cch_clipped_mem.times)) + np.squeeze(cch_clipped.times), np.squeeze(cch_clipped_mem.times) + ) assert_array_equal( - np.squeeze(cch_unclipped.magnitude), np.squeeze( - cch_unclipped_mem.magnitude)) + np.squeeze(cch_unclipped.magnitude), np.squeeze(cch_unclipped_mem.magnitude) + ) assert_array_equal( - np.squeeze(cch_unclipped.times), np.squeeze( - cch_unclipped_mem.times)) + np.squeeze(cch_unclipped.times), np.squeeze(cch_unclipped_mem.times) + ) assert_array_almost_equal(bin_ids_clipped, bin_ids_clipped_mem) assert_array_almost_equal(bin_ids_unclipped, bin_ids_unclipped_mem) @@ -361,9 +373,8 @@ def test_cross_correlation_histogram(self): # swapped compared to Elephant! 
mat1 = self.binned_st1.to_array()[0] mat2 = self.binned_st2.to_array()[0] - target_numpy = np.correlate(mat2, mat1, mode='full') - assert_array_equal( - target_numpy, np.squeeze(cch_unclipped.magnitude)) + target_numpy = np.correlate(mat2, mat1, mode="full") + assert_array_equal(target_numpy, np.squeeze(cch_unclipped.magnitude)) # Check cross correlation function for several displacements tau # Note: Use Elephant corrcoeff to verify result @@ -372,84 +383,106 @@ def test_cross_correlation_histogram(self): # adjust t_start, t_stop to shift by tau t0 = np.min([self.st_1.t_start + t * pq.ms, self.st_2.t_start]) t1 = np.max([self.st_1.t_stop + t * pq.ms, self.st_2.t_stop]) - st1 = neo.SpikeTrain(self.st_1.magnitude + t, units='ms', - t_start=t0 * pq.ms, t_stop=t1 * pq.ms) - st2 = neo.SpikeTrain(self.st_2.magnitude, units='ms', - t_start=t0 * pq.ms, t_stop=t1 * pq.ms) - binned_sts = conv.BinnedSpikeTrain([st1, st2], - bin_size=1 * pq.ms, - t_start=t0 * pq.ms, - t_stop=t1 * pq.ms) + st1 = neo.SpikeTrain( + self.st_1.magnitude + t, + units="ms", + t_start=t0 * pq.ms, + t_stop=t1 * pq.ms, + ) + st2 = neo.SpikeTrain( + self.st_2.magnitude, units="ms", t_start=t0 * pq.ms, t_stop=t1 * pq.ms + ) + binned_sts = conv.BinnedSpikeTrain( + [st1, st2], bin_size=1 * pq.ms, t_start=t0 * pq.ms, t_stop=t1 * pq.ms + ) # caluclate corrcoef corrcoef = sc.correlation_coefficient(binned_sts)[1, 0] # expand t_stop to have two spike trains with same length as st1, # st2 - st1 = neo.SpikeTrain(self.st_1.magnitude, units='ms', - t_start=self.st_1.t_start, - t_stop=self.st_1.t_stop + np.abs(t) * pq.ms) - st2 = neo.SpikeTrain(self.st_2.magnitude, units='ms', - t_start=self.st_2.t_start, - t_stop=self.st_2.t_stop + np.abs(t) * pq.ms) + st1 = neo.SpikeTrain( + self.st_1.magnitude, + units="ms", + t_start=self.st_1.t_start, + t_stop=self.st_1.t_stop + np.abs(t) * pq.ms, + ) + st2 = neo.SpikeTrain( + self.st_2.magnitude, + units="ms", + t_start=self.st_2.t_start, + t_stop=self.st_2.t_stop + np.abs(t) * pq.ms, + ) binned_st1 = conv.BinnedSpikeTrain( - st1, t_start=0 * pq.ms, t_stop=(50 + np.abs(t)) * pq.ms, - bin_size=1 * pq.ms) + st1, + t_start=0 * pq.ms, + t_stop=(50 + np.abs(t)) * pq.ms, + bin_size=1 * pq.ms, + ) binned_st2 = conv.BinnedSpikeTrain( - st2, t_start=0 * pq.ms, t_stop=(50 + np.abs(t)) * pq.ms, - bin_size=1 * pq.ms) + st2, + t_start=0 * pq.ms, + t_stop=(50 + np.abs(t)) * pq.ms, + bin_size=1 * pq.ms, + ) # calculate CCHcoef and take value at t=tau - CCHcoef, _ = sc.cch(binned_st1, binned_st2, - cross_correlation_coefficient=True) - left_edge = - binned_st1.n_bins + 1 + CCHcoef, _ = sc.cch( + binned_st1, binned_st2, cross_correlation_coefficient=True + ) + left_edge = -binned_st1.n_bins + 1 tau_bin = int(t / float(binned_st1.bin_size.magnitude)) - assert_array_almost_equal( - corrcoef, CCHcoef[tau_bin - left_edge].magnitude) + assert_array_almost_equal(corrcoef, CCHcoef[tau_bin - left_edge].magnitude) # Check correlation using binary spike trains mat1 = np.array(self.binned_st1.to_bool_array()[0], dtype=int) mat2 = np.array(self.binned_st2.to_bool_array()[0], dtype=int) - target_numpy = np.correlate(mat2, mat1, mode='full') - assert_array_equal( - target_numpy, np.squeeze(cch_clipped.magnitude)) + target_numpy = np.correlate(mat2, mat1, mode="full") + assert_array_equal(target_numpy, np.squeeze(cch_clipped.magnitude)) # Check the time axis and bin IDs of the resulting AnalogSignal assert_array_almost_equal( - (bin_ids_clipped - 0.5) * self.binned_st1.bin_size, - cch_unclipped.times) + (bin_ids_clipped - 0.5) * 
self.binned_st1.bin_size, cch_unclipped.times + ) assert_array_almost_equal( - (bin_ids_clipped - 0.5) * self.binned_st1.bin_size, - cch_clipped.times) + (bin_ids_clipped - 0.5) * self.binned_st1.bin_size, cch_clipped.times + ) # Calculate CCH using Elephant (normal and binary version) with # mode equal to 'valid' (only completely overlapping intervals of the # spike trains are correlated) cch_clipped, bin_ids_clipped = sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, window='valid', - binary=True) + self.binned_st1, self.binned_st2, window="valid", binary=True + ) cch_unclipped, bin_ids_unclipped = sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, window='valid', - binary=False) + self.binned_st1, self.binned_st2, window="valid", binary=False + ) cch_clipped_mem, bin_ids_clipped_mem = sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, window='valid', - binary=True, method='memory') - cch_unclipped_mem, bin_ids_unclipped_mem = \ - sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, window='valid', - binary=False, method='memory') + self.binned_st1, + self.binned_st2, + window="valid", + binary=True, + method="memory", + ) + cch_unclipped_mem, bin_ids_unclipped_mem = sc.cross_correlation_histogram( + self.binned_st1, + self.binned_st2, + window="valid", + binary=False, + method="memory", + ) # Check consistency two methods assert_array_equal( - np.squeeze(cch_clipped.magnitude), np.squeeze( - cch_clipped_mem.magnitude)) + np.squeeze(cch_clipped.magnitude), np.squeeze(cch_clipped_mem.magnitude) + ) assert_array_equal( - np.squeeze(cch_clipped.times), np.squeeze( - cch_clipped_mem.times)) + np.squeeze(cch_clipped.times), np.squeeze(cch_clipped_mem.times) + ) assert_array_equal( - np.squeeze(cch_unclipped.magnitude), np.squeeze( - cch_unclipped_mem.magnitude)) + np.squeeze(cch_unclipped.magnitude), np.squeeze(cch_unclipped_mem.magnitude) + ) assert_array_equal( - np.squeeze(cch_unclipped.times), np.squeeze( - cch_unclipped_mem.times)) + np.squeeze(cch_unclipped.times), np.squeeze(cch_unclipped_mem.times) + ) assert_array_equal(bin_ids_clipped, bin_ids_clipped_mem) assert_array_equal(bin_ids_unclipped, bin_ids_unclipped_mem) @@ -458,32 +491,39 @@ def test_cross_correlation_histogram(self): # swapped compared to Elephant! 
mat1 = self.binned_st1.to_array()[0] mat2 = self.binned_st2.to_array()[0] - target_numpy = np.correlate(mat2, mat1, mode='valid') - assert_array_equal( - target_numpy, np.squeeze(cch_unclipped.magnitude)) + target_numpy = np.correlate(mat2, mat1, mode="valid") + assert_array_equal(target_numpy, np.squeeze(cch_unclipped.magnitude)) # Check correlation using binary spike trains mat1 = np.array(self.binned_st1.to_bool_array()[0], dtype=int) mat2 = np.array(self.binned_st2.to_bool_array()[0], dtype=int) - target_numpy = np.correlate(mat2, mat1, mode='valid') - assert_array_equal( - target_numpy, np.squeeze(cch_clipped.magnitude)) + target_numpy = np.correlate(mat2, mat1, mode="valid") + assert_array_equal(target_numpy, np.squeeze(cch_clipped.magnitude)) # Check the time axis and bin IDs of the resulting AnalogSignal assert_array_equal( - (bin_ids_clipped - 0.5) * self.binned_st1.bin_size, - cch_unclipped.times) + (bin_ids_clipped - 0.5) * self.binned_st1.bin_size, cch_unclipped.times + ) assert_array_equal( - (bin_ids_clipped - 0.5) * self.binned_st1.bin_size, - cch_clipped.times) + (bin_ids_clipped - 0.5) * self.binned_st1.bin_size, cch_clipped.times + ) # Check for wrong window parameter setting self.assertRaises( - ValueError, sc.cross_correlation_histogram, self.binned_st1, - self.binned_st2, window='dsaij') + ValueError, + sc.cross_correlation_histogram, + self.binned_st1, + self.binned_st2, + window="dsaij", + ) self.assertRaises( - ValueError, sc.cross_correlation_histogram, self.binned_st1, - self.binned_st2, window='dsaij', method='memory') + ValueError, + sc.cross_correlation_histogram, + self.binned_st1, + self.binned_st2, + window="dsaij", + method="memory", + ) def test_raising_error_wrong_inputs(self): """Check that an exception is thrown if the two spike trains are not @@ -491,52 +531,59 @@ def test_raising_error_wrong_inputs(self): # Check the bin_sizes are the same self.assertRaises( ValueError, - sc.cross_correlation_histogram, self.binned_st1, - self.st_check_bin_size) + sc.cross_correlation_histogram, + self.binned_st1, + self.st_check_bin_size, + ) # Check input are one dimensional self.assertRaises( - ValueError, sc.cross_correlation_histogram, - self.st_check_dimension, self.binned_st2) + ValueError, + sc.cross_correlation_histogram, + self.st_check_dimension, + self.binned_st2, + ) self.assertRaises( - ValueError, sc.cross_correlation_histogram, - self.binned_st2, self.st_check_dimension) + ValueError, + sc.cross_correlation_histogram, + self.binned_st2, + self.st_check_dimension, + ) def test_window(self): """Test if the window parameter is correctly interpreted.""" - cch_win, bin_ids = sc.cch( - self.binned_st1, self.binned_st2, window=[-30, 30]) + cch_win, bin_ids = sc.cch(self.binned_st1, self.binned_st2, window=[-30, 30]) cch_win_mem, bin_ids_mem = sc.cch( - self.binned_st1, self.binned_st2, window=[-30, 30], - method='memory') + self.binned_st1, self.binned_st2, window=[-30, 30], method="memory" + ) self.assertEqual(len(bin_ids), cch_win.shape[0]) assert_array_equal(bin_ids, np.arange(-30, 31, 1)) - assert_array_equal( - (bin_ids - 0.5) * self.binned_st1.bin_size, cch_win.times) + assert_array_equal((bin_ids - 0.5) * self.binned_st1.bin_size, cch_win.times) assert_array_equal(bin_ids_mem, np.arange(-30, 31, 1)) assert_array_equal( - (bin_ids_mem - 0.5) * self.binned_st1.bin_size, cch_win.times) + (bin_ids_mem - 0.5) * self.binned_st1.bin_size, cch_win.times + ) assert_array_equal(cch_win, cch_win_mem) cch_unclipped, _ = sc.cross_correlation_histogram( - 
self.binned_st1, self.binned_st2, window='full', binary=False) + self.binned_st1, self.binned_st2, window="full", binary=False + ) assert_array_equal(cch_win, cch_unclipped[19:80]) - _, bin_ids = sc.cch( - self.binned_st1, self.binned_st2, window=[20, 30]) + _, bin_ids = sc.cch(self.binned_st1, self.binned_st2, window=[20, 30]) _, bin_ids_mem = sc.cch( - self.binned_st1, self.binned_st2, window=[20, 30], method='memory') + self.binned_st1, self.binned_st2, window=[20, 30], method="memory" + ) assert_array_equal(bin_ids, np.arange(20, 31, 1)) assert_array_equal(bin_ids_mem, np.arange(20, 31, 1)) - _, bin_ids = sc.cch( - self.binned_st1, self.binned_st2, window=[-30, -20]) + _, bin_ids = sc.cch(self.binned_st1, self.binned_st2, window=[-30, -20]) _, bin_ids_mem = sc.cch( - self.binned_st1, self.binned_st2, window=[-30, -20], - method='memory') + self.binned_st1, self.binned_st2, window=[-30, -20], method="memory" + ) assert_array_equal(bin_ids, np.arange(-30, -19, 1)) assert_array_equal(bin_ids_mem, np.arange(-30, -19, 1)) @@ -544,18 +591,34 @@ def test_window(self): # Check for wrong assignments to the window parameter # Test for window longer than the total length of the spike trains self.assertRaises( - ValueError, sc.cross_correlation_histogram, self.binned_st1, - self.binned_st2, window=[-60, 50]) + ValueError, + sc.cross_correlation_histogram, + self.binned_st1, + self.binned_st2, + window=[-60, 50], + ) self.assertRaises( - ValueError, sc.cross_correlation_histogram, self.binned_st1, - self.binned_st2, window=[-50, 60]) + ValueError, + sc.cross_correlation_histogram, + self.binned_st1, + self.binned_st2, + window=[-50, 60], + ) # Test for no integer or wrong string in input self.assertRaises( - ValueError, sc.cross_correlation_histogram, self.binned_st1, - self.binned_st2, window=[-25.5, 25.5]) + ValueError, + sc.cross_correlation_histogram, + self.binned_st1, + self.binned_st2, + window=[-25.5, 25.5], + ) self.assertRaises( - ValueError, sc.cross_correlation_histogram, self.binned_st1, - self.binned_st2, window='test') + ValueError, + sc.cross_correlation_histogram, + self.binned_st1, + self.binned_st2, + window="test", + ) def test_border_correction(self): """Test if the border correction for bins at the edges is correctly @@ -563,28 +626,35 @@ def test_border_correction(self): # check that nothing changes for valid lags cch_valid, _ = sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, window='full', - border_correction=True, binary=False, kernel=None) - valid_lags = sc._CrossCorrHist.get_valid_lags(self.binned_st1, - self.binned_st2) - left_edge, right_edge = valid_lags[(0, -1), ] - cch_builder = sc._CrossCorrHist(self.binned_st1, self.binned_st2, - window=(left_edge, right_edge)) - cch_valid = cch_builder.correlate_speed(cch_mode='valid') + self.binned_st1, + self.binned_st2, + window="full", + border_correction=True, + binary=False, + kernel=None, + ) + valid_lags = sc._CrossCorrHist.get_valid_lags(self.binned_st1, self.binned_st2) + left_edge, right_edge = valid_lags[(0, -1),] + cch_builder = sc._CrossCorrHist( + self.binned_st1, self.binned_st2, window=(left_edge, right_edge) + ) + cch_valid = cch_builder.correlate_speed(cch_mode="valid") cch_corrected = cch_builder.border_correction(cch_valid) np.testing.assert_array_equal(cch_valid, cch_corrected) # test the border correction for lags without full overlap cch_full, lags_full = sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, window='full') + self.binned_st1, self.binned_st2, window="full" + 
) cch_full_corrected, _ = sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, window='full', - border_correction=True) + self.binned_st1, self.binned_st2, window="full", border_correction=True + ) - n_bins_outside_window = np.min(np.abs( - np.subtract.outer(lags_full, valid_lags)), axis=1) + n_bins_outside_window = np.min( + np.abs(np.subtract.outer(lags_full, valid_lags)), axis=1 + ) min_n_bins = min(self.binned_st1.n_bins, self.binned_st2.n_bins) @@ -595,32 +665,40 @@ def test_border_correction(self): np.testing.assert_array_almost_equal( border_correction[mask], - (float(min_n_bins) - / (min_n_bins - n_bins_outside_window))[mask]) + (float(min_n_bins) / (min_n_bins - n_bins_outside_window))[mask], + ) def test_kernel(self): """Test if the smoothing kernel is correctly defined, and wheter it is applied properly.""" smoothed_cch, _ = sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, kernel=np.ones(3)) + self.binned_st1, self.binned_st2, kernel=np.ones(3) + ) smoothed_cch_mem, _ = sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, kernel=np.ones(3), - method='memory') + self.binned_st1, self.binned_st2, kernel=np.ones(3), method="memory" + ) cch, _ = sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, kernel=None) + self.binned_st1, self.binned_st2, kernel=None + ) cch_mem, _ = sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, kernel=None, method='memory') + self.binned_st1, self.binned_st2, kernel=None, method="memory" + ) self.assertNotEqual(smoothed_cch.all, cch.all) self.assertNotEqual(smoothed_cch_mem.all, cch_mem.all) self.assertRaises( - ValueError, sc.cch, self.binned_st1, self.binned_st2, - kernel=np.ones(100)) + ValueError, sc.cch, self.binned_st1, self.binned_st2, kernel=np.ones(100) + ) self.assertRaises( - ValueError, sc.cch, self.binned_st1, self.binned_st2, - kernel=np.ones(100), method='memory') + ValueError, + sc.cch, + self.binned_st1, + self.binned_st2, + kernel=np.ones(100), + method="memory", + ) def test_exist_alias(self): """ @@ -630,125 +708,128 @@ def test_exist_alias(self): def test_annotations(self): cch, _ = sc.cross_correlation_histogram( - self.binned_st1, self.binned_st2, kernel=np.ones(3)) - target_dict = dict(window='full', border_correction=False, - binary=False, kernel=True, - normalization='counts') - self.assertIn('cch_parameters', cch.annotations) - self.assertEqual(cch.annotations['cch_parameters'], target_dict) + self.binned_st1, self.binned_st2, kernel=np.ones(3) + ) + target_dict = dict( + window="full", + border_correction=False, + binary=False, + kernel=True, + normalization="counts", + ) + self.assertIn("cch_parameters", cch.annotations) + self.assertEqual(cch.annotations["cch_parameters"], target_dict) class CrossCorrelationHistDifferentTStartTStopTest(unittest.TestCase): - def _run_sub_tests(self, st1, st2, lags_true): - for window in ('valid', 'full'): - for method in ('speed', 'memory'): + for window in ("valid", "full"): + for method in ("speed", "memory"): with self.subTest(window=window, method=method): bin_size = 1 * pq.s st1_binned = conv.BinnedSpikeTrain(st1, bin_size=bin_size) st2_binned = conv.BinnedSpikeTrain(st2, bin_size=bin_size) - left, right = lags_true[window][(0, -1), ] + left, right = lags_true[window][(0, -1),] cch_window, lags_window = sc.cross_correlation_histogram( - st1_binned, st2_binned, window=(left, right), + st1_binned, + st2_binned, + window=(left, right), method=method, ) cch, lags = sc.cross_correlation_histogram( - st1_binned, 
st2_binned, window=window) + st1_binned, st2_binned, window=window + ) # target cross correlation - cch_target = np.correlate(st1_binned.to_array()[0], - st2_binned.to_array()[0], - mode=window) + cch_target = np.correlate( + st1_binned.to_array()[0], st2_binned.to_array()[0], mode=window + ) self.assertEqual(len(lags_window), cch_window.shape[0]) - assert_array_almost_equal(cch.magnitude, - cch_window.magnitude) + assert_array_almost_equal(cch.magnitude, cch_window.magnitude) # the output is reversed since we cross-correlate # st2 with st1 rather than st1 with st2 (numpy behavior) - assert_array_almost_equal(np.ravel(cch.magnitude), - cch_target[::-1]) + assert_array_almost_equal(np.ravel(cch.magnitude), cch_target[::-1]) assert_array_equal(lags, lags_true[window]) assert_array_equal(lags, lags_window) def test_cross_correlation_histogram_valid_full_overlap(self): # ex. 1 in the source code - st1 = neo.SpikeTrain([3.5, 4.5, 7.5] * pq.s, t_start=3 * pq.s, - t_stop=8 * pq.s) - st2 = neo.SpikeTrain([1.5, 2.5, 4.5, 8.5, 9.5, 10.5] - * pq.s, t_start=1 * pq.s, t_stop=13 * pq.s) + st1 = neo.SpikeTrain([3.5, 4.5, 7.5] * pq.s, t_start=3 * pq.s, t_stop=8 * pq.s) + st2 = neo.SpikeTrain( + [1.5, 2.5, 4.5, 8.5, 9.5, 10.5] * pq.s, t_start=1 * pq.s, t_stop=13 * pq.s + ) lags_true = { - 'valid': np.arange(-2, 6, dtype=np.int32), - 'full': np.arange(-6, 10, dtype=np.int32) + "valid": np.arange(-2, 6, dtype=np.int32), + "full": np.arange(-6, 10, dtype=np.int32), } self._run_sub_tests(st1, st2, lags_true) def test_cross_correlation_histogram_valid_partial_overlap(self): # ex. 2 in the source code - st1 = neo.SpikeTrain([2.5, 3.5, 4.5, 6.5] * pq.s, t_start=1 * pq.s, - t_stop=7 * pq.s) - st2 = neo.SpikeTrain([3.5, 5.5, 6.5, 7.5, 8.5] * - pq.s, t_start=2 * pq.s, t_stop=9 * pq.s) + st1 = neo.SpikeTrain( + [2.5, 3.5, 4.5, 6.5] * pq.s, t_start=1 * pq.s, t_stop=7 * pq.s + ) + st2 = neo.SpikeTrain( + [3.5, 5.5, 6.5, 7.5, 8.5] * pq.s, t_start=2 * pq.s, t_stop=9 * pq.s + ) lags_true = { - 'valid': np.arange(1, 3, dtype=np.int32), - 'full': np.arange(-4, 8, dtype=np.int32) + "valid": np.arange(1, 3, dtype=np.int32), + "full": np.arange(-4, 8, dtype=np.int32), } self._run_sub_tests(st1, st2, lags_true) def test_cross_correlation_histogram_valid_no_overlap(self): - st1 = neo.SpikeTrain([2.5, 3.5, 4.5, 6.5] * pq.s, t_start=1 * pq.s, - t_stop=7 * pq.s) - st2 = neo.SpikeTrain([3.5, 5.5, 6.5, 7.5, 8.5] * pq.s + 6 * pq.s, - t_start=8 * pq.s, t_stop=15 * pq.s) + st1 = neo.SpikeTrain( + [2.5, 3.5, 4.5, 6.5] * pq.s, t_start=1 * pq.s, t_stop=7 * pq.s + ) + st2 = neo.SpikeTrain( + [3.5, 5.5, 6.5, 7.5, 8.5] * pq.s + 6 * pq.s, + t_start=8 * pq.s, + t_stop=15 * pq.s, + ) lags_true = { - 'valid': np.arange(7, 9, dtype=np.int32), - 'full': np.arange(2, 14, dtype=np.int32) + "valid": np.arange(7, 9, dtype=np.int32), + "full": np.arange(2, 14, dtype=np.int32), } self._run_sub_tests(st1, st2, lags_true) def test_invalid_time_shift(self): # time shift of 0.4 s is not multiple of bin_size=1 s - st1 = neo.SpikeTrain([2.5, 3.5] * pq.s, t_start=1 * pq.s, - t_stop=7 * pq.s) - st2 = neo.SpikeTrain([3.5, 5.5] * pq.s, t_start=1.4 * pq.s, - t_stop=7.4 * pq.s) + st1 = neo.SpikeTrain([2.5, 3.5] * pq.s, t_start=1 * pq.s, t_stop=7 * pq.s) + st2 = neo.SpikeTrain([3.5, 5.5] * pq.s, t_start=1.4 * pq.s, t_stop=7.4 * pq.s) bin_size = 1 * pq.s st1_binned = conv.BinnedSpikeTrain(st1, bin_size=bin_size) st2_binned = conv.BinnedSpikeTrain(st2, bin_size=bin_size) - self.assertRaises(ValueError, sc.cross_correlation_histogram, - st1_binned, st2_binned) + 
self.assertRaises( + ValueError, sc.cross_correlation_histogram, st1_binned, st2_binned + ) class SpikeTimeTilingCoefficientTestCase(unittest.TestCase): - def setUp(self): # These two arrays must be such that they do not have coincidences # spanning across two neighbor bins assuming ms bins [0,1),[1,2),... - self.test_array_1d_1 = [ - 1.3, 7.56, 15.87, 28.23, 30.9, 34.2, 38.2, 43.2] - self.test_array_1d_2 = [ - 1.02, 2.71, 18.82, 28.46, 28.79, 43.6] + self.test_array_1d_1 = [1.3, 7.56, 15.87, 28.23, 30.9, 34.2, 38.2, 43.2] + self.test_array_1d_2 = [1.02, 2.71, 18.82, 28.46, 28.79, 43.6] # Build spike trains - self.st_1 = neo.SpikeTrain( - self.test_array_1d_1, units='ms', t_stop=50.) - self.st_2 = neo.SpikeTrain( - self.test_array_1d_2, units='ms', t_stop=50.) + self.st_1 = neo.SpikeTrain(self.test_array_1d_1, units="ms", t_stop=50.0) + self.st_2 = neo.SpikeTrain(self.test_array_1d_2, units="ms", t_stop=50.0) def test_sttc_dt_smaller_zero(self): - self.assertRaises(ValueError, sc.sttc, self.st_1, self.st_2, - dt=0 * pq.s) - self.assertRaises(ValueError, sc.sttc, self.st_1, self.st_2, - dt=-1 * pq.ms) + self.assertRaises(ValueError, sc.sttc, self.st_1, self.st_2, dt=0 * pq.s) + self.assertRaises(ValueError, sc.sttc, self.st_1, self.st_2, dt=-1 * pq.ms) def test_sttc_different_t_stop(self): - st_1 = neo.SpikeTrain([1], units='ms', t_stop=10.) - st_2 = neo.SpikeTrain([5], units='ms', t_stop=10.) + st_1 = neo.SpikeTrain([1], units="ms", t_stop=10.0) + st_2 = neo.SpikeTrain([5], units="ms", t_stop=10.0) st_2.t_stop = 1 * pq.ms self.assertRaises(ValueError, sc.sttc, st_1, st_2) def test_sttc_different_t_start(self): - st_1 = neo.SpikeTrain([1], units='ms', t_stop=10.) - st_2 = neo.SpikeTrain([5], units='ms', t_stop=10.) + st_1 = neo.SpikeTrain([1], units="ms", t_stop=10.0) + st_2 = neo.SpikeTrain([5], units="ms", t_stop=10.0) st_2.t_start = 1 * pq.ms self.assertRaises(ValueError, sc.sttc, st_1, st_2) @@ -756,16 +837,14 @@ def test_sttc_different_units_dt(self): # test for result # target obtained with pencil and paper according to original paper. target = 0.495860165593 - self.assertAlmostEqual(target, sc.sttc(self.st_1, self.st_2, - 0.005 * pq.s)) + self.assertAlmostEqual(target, sc.sttc(self.st_1, self.st_2, 0.005 * pq.s)) # test for same result with dt given in ms - self.assertAlmostEqual(target, sc.sttc(self.st_1, self.st_2, - 5.0 * pq.ms)) + self.assertAlmostEqual(target, sc.sttc(self.st_1, self.st_2, 5.0 * pq.ms)) def test_sttc_different_units_spiketrains(self): - st1 = neo.SpikeTrain([1], units='ms', t_stop=10.) - st2 = neo.SpikeTrain([5], units='s', t_stop=10.) + st1 = neo.SpikeTrain([1], units="ms", t_stop=10.0) + st2 = neo.SpikeTrain([5], units="s", t_stop=10.0) self.assertRaises(ValueError, sc.sttc, st1, st2) def test_sttc_not_enough_spiketrains(self): @@ -777,8 +856,8 @@ def test_sttc_not_enough_spiketrains(self): def test_sttc_one_spike(self): # test for one spike in a spiketrain - st1 = neo.SpikeTrain([1], units='ms', t_stop=10.) - st2 = neo.SpikeTrain([5], units='ms', t_stop=10.) + st1 = neo.SpikeTrain([1], units="ms", t_stop=10.0) + st2 = neo.SpikeTrain([5], units="ms", t_stop=10.0) self.assertEqual(sc.sttc(st1, st2), 1.0) self.assertTrue(bool(sc.sttc(st1, st2, 0.1 * pq.ms) < 0)) @@ -788,9 +867,9 @@ def test_sttc_high_value_dt(self): def test_sttc_edge_cases(self): # test for TA = PB = 1 but TB /= PA /= 1 and vice versa - st2 = neo.SpikeTrain([5], units='ms', t_stop=10.) - st3 = neo.SpikeTrain([1, 5, 9], units='ms', t_stop=10.) - target2 = 1. / 3. 
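Before the remaining edge-case assertions, a compact sketch of the call under test may help. It assumes the `sc` alias for `elephant.spike_train_correlation` used in this file (`sc.sttc` being the tested alias of `spike_time_tiling_coefficient`); the spike times are illustrative only.

import neo
import quantities as pq
import elephant.spike_train_correlation as sc

# Both trains must share t_start and t_stop; a mismatch raises ValueError,
# as the tests above verify.
st_a = neo.SpikeTrain([1.3, 7.56, 15.87, 28.23, 30.9] * pq.ms, t_stop=50.0 * pq.ms)
st_b = neo.SpikeTrain([1.02, 2.71, 18.82, 28.46] * pq.ms, t_stop=50.0 * pq.ms)

# dt is the coincidence window: it must be strictly positive and can be given
# in any time unit (5 ms and 0.005 s yield the same coefficient).
sttc_value = sc.sttc(st_a, st_b, dt=0.005 * pq.s)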
+ st2 = neo.SpikeTrain([5], units="ms", t_stop=10.0) + st3 = neo.SpikeTrain([1, 5, 9], units="ms", t_stop=10.0) + target2 = 1.0 / 3.0 self.assertAlmostEqual(target2, sc.sttc(st3, st2, 0.003 * pq.s)) self.assertAlmostEqual(target2, sc.sttc(st2, st3, 0.003 * pq.s)) @@ -799,44 +878,111 @@ def test_sttc_unsorted_spiketimes(self): # regression test for issue #563 # https://github.com/NeuralEnsemble/elephant/issues/563 spiketrain_E7 = neo.SpikeTrain( - [1678., 23786.3, 34641.8, 71520.7, 73606.9, 78383.3, - 97387.9, 144313.4, 4607.6, 19275.1, 152894.2, 44240.1], - units='ms', t_stop=300000 * pq.ms) + [ + 1678.0, + 23786.3, + 34641.8, + 71520.7, + 73606.9, + 78383.3, + 97387.9, + 144313.4, + 4607.6, + 19275.1, + 152894.2, + 44240.1, + ], + units="ms", + t_stop=300000 * pq.ms, + ) spiketrain_E3 = neo.SpikeTrain( - [1678., 23786.3, 34641.8, 71520.7, 73606.9, 78383.3, - 97387.9, 144313.4, 4607.6, 19275.1, 152894.2, 44240.1], - units='ms', t_stop=300000 * pq.ms) - sttc_unsorted_E7_E3 = sc.sttc(spiketrain_E7, - spiketrain_E3, dt=0.10 * pq.s) + [ + 1678.0, + 23786.3, + 34641.8, + 71520.7, + 73606.9, + 78383.3, + 97387.9, + 144313.4, + 4607.6, + 19275.1, + 152894.2, + 44240.1, + ], + units="ms", + t_stop=300000 * pq.ms, + ) + sttc_unsorted_E7_E3 = sc.sttc(spiketrain_E7, spiketrain_E3, dt=0.10 * pq.s) self.assertAlmostEqual(sttc_unsorted_E7_E3, 1) spiketrain_E7.sort() spiketrain_E3.sort() - sttc_sorted_E7_E3 = sc.sttc(spiketrain_E7, - spiketrain_E3, dt=0.10 * pq.s) + sttc_sorted_E7_E3 = sc.sttc(spiketrain_E7, spiketrain_E3, dt=0.10 * pq.s) self.assertAlmostEqual(sttc_unsorted_E7_E3, sttc_sorted_E7_E3) spiketrain_E8 = neo.SpikeTrain( - [20646.8, 25875.1, 26154.4, 35121., 55909.7, 79164.8, - 110849.8, 117484.1, 3731.5, 4213.9, 119995.1, 123748.1, - 171016.8, 172989., 185145.2, 12043.5, 185995.9, 186740.1, - 12629.8, 23394.3, 34993.2], units='ms', t_stop=300000 * pq.ms) + [ + 20646.8, + 25875.1, + 26154.4, + 35121.0, + 55909.7, + 79164.8, + 110849.8, + 117484.1, + 3731.5, + 4213.9, + 119995.1, + 123748.1, + 171016.8, + 172989.0, + 185145.2, + 12043.5, + 185995.9, + 186740.1, + 12629.8, + 23394.3, + 34993.2, + ], + units="ms", + t_stop=300000 * pq.ms, + ) spiketrain_B3 = neo.SpikeTrain( - [10600.7, 19699.6, 22803., 40769.3, 121385.7, 127402.9, - 130829.2, 134363.8, 1193.5, 8012.7, 142037.3, 146628.2, - 165925.3, 168489.3, 175194.3, 10339.8, 178676.4, 180807.2, - 201431.3, 22231.1, 38113.4], units='ms', t_stop=300000 * pq.ms) - - self.assertTrue( - sc.sttc(spiketrain_E8, spiketrain_B3, dt=0.10 * pq.s) < 1) - - sttc_unsorted_E8_B3 = sc.sttc(spiketrain_E8, - spiketrain_B3, dt=0.10 * pq.s) + [ + 10600.7, + 19699.6, + 22803.0, + 40769.3, + 121385.7, + 127402.9, + 130829.2, + 134363.8, + 1193.5, + 8012.7, + 142037.3, + 146628.2, + 165925.3, + 168489.3, + 175194.3, + 10339.8, + 178676.4, + 180807.2, + 201431.3, + 22231.1, + 38113.4, + ], + units="ms", + t_stop=300000 * pq.ms, + ) + + self.assertTrue(sc.sttc(spiketrain_E8, spiketrain_B3, dt=0.10 * pq.s) < 1) + + sttc_unsorted_E8_B3 = sc.sttc(spiketrain_E8, spiketrain_B3, dt=0.10 * pq.s) spiketrain_E8.sort() spiketrain_B3.sort() - sttc_sorted_E8_B3 = sc.sttc(spiketrain_E8, - spiketrain_B3, dt=0.10 * pq.s) + sttc_sorted_E8_B3 = sc.sttc(spiketrain_E8, spiketrain_B3, dt=0.10 * pq.s) self.assertAlmostEqual(sttc_unsorted_E8_B3, sttc_sorted_E8_B3) def test_sttc_validation_test(self): @@ -847,25 +993,31 @@ def test_sttc_validation_test(self): NeuralEnsemble/elephant-data/unittest/spike_train_correlation/ spike_time_tiling_coefficient""" - repo_path = 
r"unittest/spike_train_correlation/spike_time_tiling_coefficient/data" # noqa + repo_path = ( + r"unittest/spike_train_correlation/spike_time_tiling_coefficient/data" # noqa + ) - files_to_download = [("spike_time_tiling_coefficient_results.nix", - "e3749d79046622494660a03e89950f51")] + files_to_download = [ + ( + "spike_time_tiling_coefficient_results.nix", + "e3749d79046622494660a03e89950f51", + ) + ] for filename, checksum in files_to_download: - filepath = download_datasets(repo_path=f"{repo_path}/{filename}", - checksum=checksum) + filepath = download_datasets( + repo_path=f"{repo_path}/{filename}", checksum=checksum + ) - reader = NixIO(filepath, mode='ro') + reader = NixIO(filepath, mode="ro") test_data_block = reader.read() for segment in test_data_block[0].segments: spiketrain_i = segment.spiketrains[0] spiketrain_j = segment.spiketrains[1] - dt = segment.annotations['dt'] - sttc_result = segment.annotations['sttc_result'] - self.assertAlmostEqual(sc.sttc(spiketrain_i, spiketrain_j, dt), - sttc_result) + dt = segment.annotations["dt"] + sttc_result = segment.annotations["sttc_result"] + self.assertAlmostEqual(sc.sttc(spiketrain_i, spiketrain_j, dt), sttc_result) def test_sttc_exist_alias(self): # Test if alias cch still exists. @@ -873,7 +1025,6 @@ def test_sttc_exist_alias(self): class SpikeTrainTimescaleTestCase(unittest.TestCase): - def test_timescale_calculation(self): """ Test the timescale generation using an alpha-shaped ISI distribution, @@ -895,9 +1046,9 @@ def test_timescale_calculation(self): np.random.seed(35) for _ in range(10): - spikes = StationaryGammaProcess(rate=2 * nu / 2, shape_factor=2, - t_start=0 * pq.ms, - t_stop=T).generate_spiketrain() + spikes = StationaryGammaProcess( + rate=2 * nu / 2, shape_factor=2, t_start=0 * pq.ms, t_stop=T + ).generate_spiketrain() spikes_bin = conv.BinnedSpikeTrain(spikes, bin_size) timescale_i = sc.spike_train_timescale(spikes_bin, 10 * timescale) assert_array_almost_equal(timescale, timescale_i, decimal=3) @@ -909,13 +1060,11 @@ def test_timescale_errors(self): # Tau max with no units tau_max = 1 - self.assertRaises(ValueError, - sc.spike_train_timescale, spikes_bin, tau_max) + self.assertRaises(ValueError, sc.spike_train_timescale, spikes_bin, tau_max) # Tau max that is not a multiple of the binsize tau_max = 1.1 * pq.ms - self.assertRaises(ValueError, - sc.spike_train_timescale, spikes_bin, tau_max) + self.assertRaises(ValueError, sc.spike_train_timescale, spikes_bin, tau_max) def test_timescale_nan(self): st0 = neo.SpikeTrain([] * pq.ms, t_stop=10 * pq.ms) @@ -939,5 +1088,5 @@ def test_timescale_nan(self): self.assertFalse(math.isnan(timescale)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_spike_train_dissimilarity.py b/elephant/test/test_spike_train_dissimilarity.py index 6cab2858b..e8850c38d 100644 --- a/elephant/test/test_spike_train_dissimilarity.py +++ b/elephant/test/test_spike_train_dissimilarity.py @@ -5,6 +5,7 @@ :copyright: Copyright 2014-2024 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. 
""" + import unittest from neo import SpikeTrain import numpy as np @@ -20,38 +21,37 @@ class TimeScaleDependSpikeTrainDissimMeasuresTestCase(unittest.TestCase): def setUp(self): - self.st00 = SpikeTrain([], units='ms', t_stop=1000.0) - self.st01 = SpikeTrain([1], units='ms', t_stop=1000.0) - self.st02 = SpikeTrain([2], units='ms', t_stop=1000.0) - self.st03 = SpikeTrain([2.9], units='ms', t_stop=1000.0) - self.st04 = SpikeTrain([3.1], units='ms', t_stop=1000.0) - self.st05 = SpikeTrain([5], units='ms', t_stop=1000.0) - self.st06 = SpikeTrain([500], units='ms', t_stop=1000.0) - self.st07 = SpikeTrain([12, 32], units='ms', t_stop=1000.0) - self.st08 = SpikeTrain([32, 52], units='ms', t_stop=1000.0) - self.st09 = SpikeTrain([42], units='ms', t_stop=1000.0) - self.st10 = SpikeTrain([18, 60], units='ms', t_stop=1000.0) - self.st11 = SpikeTrain([10, 20, 30, 40], units='ms', t_stop=1000.0) - self.st12 = SpikeTrain([40, 30, 20, 10], units='ms', t_stop=1000.0) - self.st13 = SpikeTrain([15, 25, 35, 45], units='ms', t_stop=1000.0) - self.st14 = SpikeTrain([10, 20, 30, 40, 50], units='ms', t_stop=1000.0) - self.st15 = SpikeTrain([0.01, 0.02, 0.03, 0.04, 0.05], - units='s', t_stop=1000.0) - self.st16 = SpikeTrain([12, 16, 28, 30, 42], units='ms', t_stop=1000.0) - self.st21 = StationaryPoissonProcess(rate=50 * Hz, t_start=0 * ms, - t_stop=1000 * ms - ).generate_spiketrain() - self.st22 = StationaryPoissonProcess(rate=40 * Hz, t_start=0 * ms, - t_stop=1000 * ms - ).generate_spiketrain() - self.st23 = StationaryPoissonProcess(rate=30 * Hz, t_start=0 * ms, - t_stop=1000 * ms - ).generate_spiketrain() + self.st00 = SpikeTrain([], units="ms", t_stop=1000.0) + self.st01 = SpikeTrain([1], units="ms", t_stop=1000.0) + self.st02 = SpikeTrain([2], units="ms", t_stop=1000.0) + self.st03 = SpikeTrain([2.9], units="ms", t_stop=1000.0) + self.st04 = SpikeTrain([3.1], units="ms", t_stop=1000.0) + self.st05 = SpikeTrain([5], units="ms", t_stop=1000.0) + self.st06 = SpikeTrain([500], units="ms", t_stop=1000.0) + self.st07 = SpikeTrain([12, 32], units="ms", t_stop=1000.0) + self.st08 = SpikeTrain([32, 52], units="ms", t_stop=1000.0) + self.st09 = SpikeTrain([42], units="ms", t_stop=1000.0) + self.st10 = SpikeTrain([18, 60], units="ms", t_stop=1000.0) + self.st11 = SpikeTrain([10, 20, 30, 40], units="ms", t_stop=1000.0) + self.st12 = SpikeTrain([40, 30, 20, 10], units="ms", t_stop=1000.0) + self.st13 = SpikeTrain([15, 25, 35, 45], units="ms", t_stop=1000.0) + self.st14 = SpikeTrain([10, 20, 30, 40, 50], units="ms", t_stop=1000.0) + self.st15 = SpikeTrain([0.01, 0.02, 0.03, 0.04, 0.05], units="s", t_stop=1000.0) + self.st16 = SpikeTrain([12, 16, 28, 30, 42], units="ms", t_stop=1000.0) + self.st21 = StationaryPoissonProcess( + rate=50 * Hz, t_start=0 * ms, t_stop=1000 * ms + ).generate_spiketrain() + self.st22 = StationaryPoissonProcess( + rate=40 * Hz, t_start=0 * ms, t_stop=1000 * ms + ).generate_spiketrain() + self.st23 = StationaryPoissonProcess( + rate=30 * Hz, t_start=0 * ms, t_stop=1000 * ms + ).generate_spiketrain() self.rd_st_list = [self.st21, self.st22, self.st23] - self.st31 = SpikeTrain([12.0], units='ms', t_stop=1000.0) - self.st32 = SpikeTrain([12.0, 12.0], units='ms', t_stop=1000.0) - self.st33 = SpikeTrain([20.0], units='ms', t_stop=1000.0) - self.st34 = SpikeTrain([20.0, 20.0], units='ms', t_stop=1000.0) + self.st31 = SpikeTrain([12.0], units="ms", t_stop=1000.0) + self.st32 = SpikeTrain([12.0, 12.0], units="ms", t_stop=1000.0) + self.st33 = SpikeTrain([20.0], units="ms", t_stop=1000.0) + self.st34 = 
SpikeTrain([20.0, 20.0], units="ms", t_stop=1000.0) self.array1 = np.arange(1, 10) self.array2 = np.arange(1.2, 10) self.qarray1 = self.array1 * Hz @@ -75,472 +75,674 @@ def setUp(self): self.t = np.linspace(0, 200, 20000001) * ms def test_wrong_input(self): - self.assertRaises(TypeError, stds.victor_purpura_distance, - [self.array1, self.array2], self.q3) - self.assertRaises(TypeError, stds.victor_purpura_distance, - [self.qarray1, self.qarray2], self.q3) - self.assertRaises(TypeError, stds.victor_purpura_distance, - [self.qarray1, self.qarray2], 5.0 * ms) - - self.assertRaises(TypeError, stds.victor_purpura_distance, - [self.array1, self.array2], self.q3, - algorithm='intuitive') - self.assertRaises(TypeError, stds.victor_purpura_distance, - [self.qarray1, self.qarray2], self.q3, - algorithm='intuitive') - self.assertRaises(TypeError, stds.victor_purpura_distance, - [self.qarray1, self.qarray2], 5.0 * ms, - algorithm='intuitive') - - self.assertRaises(TypeError, stds.van_rossum_distance, - [self.array1, self.array2], self.tau3) - self.assertRaises(TypeError, stds.van_rossum_distance, - [self.qarray1, self.qarray2], self.tau3) - self.assertRaises(TypeError, stds.van_rossum_distance, - [self.qarray1, self.qarray2], 5.0 * Hz) - - self.assertRaises(TypeError, stds.victor_purpura_distance, - [self.st11, self.st13], self.tau2) - self.assertRaises(TypeError, stds.victor_purpura_distance, - [self.st11, self.st13], 5.0) - self.assertRaises(TypeError, stds.victor_purpura_distance, - [self.st11, self.st13], self.tau2, - algorithm='intuitive') - self.assertRaises(TypeError, stds.victor_purpura_distance, - [self.st11, self.st13], 5.0, - algorithm='intuitive') - self.assertRaises(TypeError, stds.van_rossum_distance, - [self.st11, self.st13], self.q4) - self.assertRaises(TypeError, stds.van_rossum_distance, - [self.st11, self.st13], 5.0) - - self.assertRaises(NotImplementedError, stds.victor_purpura_distance, - [self.st01, self.st02], self.q3, - kernel=kernels.Kernel(2.0 / self.q3)) - self.assertRaises(NotImplementedError, stds.victor_purpura_distance, - [self.st01, self.st02], self.q3, - kernel=kernels.SymmetricKernel(2.0 / self.q3)) - self.assertEqual(stds.victor_purpura_distance( - [self.st01, self.st02], self.q1, - kernel=kernels.TriangularKernel( - 2.0 / (np.sqrt(6.0) * self.q2)))[0, 1], - stds.victor_purpura_distance( - [self.st01, self.st02], self.q3, - kernel=kernels.TriangularKernel( - 2.0 / (np.sqrt(6.0) * self.q2)))[0, 1]) - self.assertEqual(stds.victor_purpura_distance( + self.assertRaises( + TypeError, stds.victor_purpura_distance, [self.array1, self.array2], self.q3 + ) + self.assertRaises( + TypeError, + stds.victor_purpura_distance, + [self.qarray1, self.qarray2], + self.q3, + ) + self.assertRaises( + TypeError, + stds.victor_purpura_distance, + [self.qarray1, self.qarray2], + 5.0 * ms, + ) + + self.assertRaises( + TypeError, + stds.victor_purpura_distance, + [self.array1, self.array2], + self.q3, + algorithm="intuitive", + ) + self.assertRaises( + TypeError, + stds.victor_purpura_distance, + [self.qarray1, self.qarray2], + self.q3, + algorithm="intuitive", + ) + self.assertRaises( + TypeError, + stds.victor_purpura_distance, + [self.qarray1, self.qarray2], + 5.0 * ms, + algorithm="intuitive", + ) + + self.assertRaises( + TypeError, stds.van_rossum_distance, [self.array1, self.array2], self.tau3 + ) + self.assertRaises( + TypeError, stds.van_rossum_distance, [self.qarray1, self.qarray2], self.tau3 + ) + self.assertRaises( + TypeError, stds.van_rossum_distance, [self.qarray1, 
self.qarray2], 5.0 * Hz + ) + + self.assertRaises( + TypeError, stds.victor_purpura_distance, [self.st11, self.st13], self.tau2 + ) + self.assertRaises( + TypeError, stds.victor_purpura_distance, [self.st11, self.st13], 5.0 + ) + self.assertRaises( + TypeError, + stds.victor_purpura_distance, + [self.st11, self.st13], + self.tau2, + algorithm="intuitive", + ) + self.assertRaises( + TypeError, + stds.victor_purpura_distance, + [self.st11, self.st13], + 5.0, + algorithm="intuitive", + ) + self.assertRaises( + TypeError, stds.van_rossum_distance, [self.st11, self.st13], self.q4 + ) + self.assertRaises( + TypeError, stds.van_rossum_distance, [self.st11, self.st13], 5.0 + ) + + self.assertRaises( + NotImplementedError, + stds.victor_purpura_distance, [self.st01, self.st02], - kernel=kernels.TriangularKernel( - 2.0 / (np.sqrt(6.0) * self.q2)))[0, 1], 1.0) - self.assertNotEqual(stds.victor_purpura_distance( + self.q3, + kernel=kernels.Kernel(2.0 / self.q3), + ) + self.assertRaises( + NotImplementedError, + stds.victor_purpura_distance, [self.st01, self.st02], - kernel=kernels.AlphaKernel( - 2.0 / (np.sqrt(6.0) * self.q2)))[0, 1], 1.0) - - self.assertRaises(NameError, stds.victor_purpura_distance, - [self.st11, self.st13], self.q2, algorithm='slow') + self.q3, + kernel=kernels.SymmetricKernel(2.0 / self.q3), + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st01, self.st02], + self.q1, + kernel=kernels.TriangularKernel(2.0 / (np.sqrt(6.0) * self.q2)), + )[0, 1], + stds.victor_purpura_distance( + [self.st01, self.st02], + self.q3, + kernel=kernels.TriangularKernel(2.0 / (np.sqrt(6.0) * self.q2)), + )[0, 1], + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st01, self.st02], + kernel=kernels.TriangularKernel(2.0 / (np.sqrt(6.0) * self.q2)), + )[0, 1], + 1.0, + ) + self.assertNotEqual( + stds.victor_purpura_distance( + [self.st01, self.st02], + kernel=kernels.AlphaKernel(2.0 / (np.sqrt(6.0) * self.q2)), + )[0, 1], + 1.0, + ) + + self.assertRaises( + NameError, + stds.victor_purpura_distance, + [self.st11, self.st13], + self.q2, + algorithm="slow", + ) def test_victor_purpura_distance_fast(self): # Tests of distances of simplest spike trains: - self.assertEqual(stds.victor_purpura_distance( - [self.st00, self.st00], self.q2)[0, 1], 0.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st00, self.st01], self.q2)[0, 1], 1.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st01, self.st00], self.q2)[0, 1], 1.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st01, self.st01], self.q2)[0, 1], 0.0) + self.assertEqual( + stds.victor_purpura_distance([self.st00, self.st00], self.q2)[0, 1], 0.0 + ) + self.assertEqual( + stds.victor_purpura_distance([self.st00, self.st01], self.q2)[0, 1], 1.0 + ) + self.assertEqual( + stds.victor_purpura_distance([self.st01, self.st00], self.q2)[0, 1], 1.0 + ) + self.assertEqual( + stds.victor_purpura_distance([self.st01, self.st01], self.q2)[0, 1], 0.0 + ) # Tests of distances under elementary spike operations - self.assertEqual(stds.victor_purpura_distance( - [self.st01, self.st02], self.q2)[0, 1], 1.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st01, self.st03], self.q2)[0, 1], 1.9) - self.assertEqual(stds.victor_purpura_distance( - [self.st01, self.st04], self.q2)[0, 1], 2.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st01, self.st05], self.q2)[0, 1], 2.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st00, self.st07], self.q2)[0, 1], 2.0) - 
self.assertAlmostEqual(stds.victor_purpura_distance( - [self.st07, self.st08], self.q4)[0, 1], 0.4) - self.assertAlmostEqual(stds.victor_purpura_distance( - [self.st07, self.st10], self.q3)[0, 1], 0.6 + 2) - self.assertEqual(stds.victor_purpura_distance( - [self.st11, self.st14], self.q2)[0, 1], 1) + self.assertEqual( + stds.victor_purpura_distance([self.st01, self.st02], self.q2)[0, 1], 1.0 + ) + self.assertEqual( + stds.victor_purpura_distance([self.st01, self.st03], self.q2)[0, 1], 1.9 + ) + self.assertEqual( + stds.victor_purpura_distance([self.st01, self.st04], self.q2)[0, 1], 2.0 + ) + self.assertEqual( + stds.victor_purpura_distance([self.st01, self.st05], self.q2)[0, 1], 2.0 + ) + self.assertEqual( + stds.victor_purpura_distance([self.st00, self.st07], self.q2)[0, 1], 2.0 + ) + self.assertAlmostEqual( + stds.victor_purpura_distance([self.st07, self.st08], self.q4)[0, 1], 0.4 + ) + self.assertAlmostEqual( + stds.victor_purpura_distance([self.st07, self.st10], self.q3)[0, 1], 0.6 + 2 + ) + self.assertEqual( + stds.victor_purpura_distance([self.st11, self.st14], self.q2)[0, 1], 1 + ) # Tests on timescales - self.assertEqual(stds.victor_purpura_distance( - [self.st11, self.st14], self.q1)[0, 1], - stds.victor_purpura_distance( - [self.st11, self.st14], self.q5)[0, 1]) - self.assertEqual(stds.victor_purpura_distance( - [self.st07, self.st11], self.q0)[0, 1], 6.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st07, self.st11], self.q1)[0, 1], 6.0) - self.assertAlmostEqual(stds.victor_purpura_distance( - [self.st07, self.st11], self.q5)[0, 1], 2.0, 5) - self.assertEqual(stds.victor_purpura_distance( - [self.st07, self.st11], self.q6)[0, 1], 2.0) + self.assertEqual( + stds.victor_purpura_distance([self.st11, self.st14], self.q1)[0, 1], + stds.victor_purpura_distance([self.st11, self.st14], self.q5)[0, 1], + ) + self.assertEqual( + stds.victor_purpura_distance([self.st07, self.st11], self.q0)[0, 1], 6.0 + ) + self.assertEqual( + stds.victor_purpura_distance([self.st07, self.st11], self.q1)[0, 1], 6.0 + ) + self.assertAlmostEqual( + stds.victor_purpura_distance([self.st07, self.st11], self.q5)[0, 1], 2.0, 5 + ) + self.assertEqual( + stds.victor_purpura_distance([self.st07, self.st11], self.q6)[0, 1], 2.0 + ) # Tests on unordered spiketrains - self.assertEqual(stds.victor_purpura_distance( - [self.st11, self.st13], self.q4)[0, 1], - stds.victor_purpura_distance( - [self.st12, self.st13], self.q4)[0, 1]) - self.assertNotEqual(stds.victor_purpura_distance( - [self.st11, self.st13], self.q4, - sort=False)[0, 1], - stds.victor_purpura_distance( - [self.st12, self.st13], self.q4, - sort=False)[0, 1]) + self.assertEqual( + stds.victor_purpura_distance([self.st11, self.st13], self.q4)[0, 1], + stds.victor_purpura_distance([self.st12, self.st13], self.q4)[0, 1], + ) + self.assertNotEqual( + stds.victor_purpura_distance([self.st11, self.st13], self.q4, sort=False)[ + 0, 1 + ], + stds.victor_purpura_distance([self.st12, self.st13], self.q4, sort=False)[ + 0, 1 + ], + ) # Tests on metric properties with random spiketrains # (explicit calculation of second metric axiom in particular case, # because from dist_matrix it is trivial) dist_matrix = stds.victor_purpura_distance( - [self.st21, self.st22, self.st23], self.q3) + [self.st21, self.st22, self.st23], self.q3 + ) for i in range(3): for j in range(3): self.assertGreaterEqual(dist_matrix[i, j], 0) if dist_matrix[i, j] == 0: assert_array_equal(self.rd_st_list[i], self.rd_st_list[j]) - assert_array_equal(stds.victor_purpura_distance( - 
[self.st21, self.st22], self.q3), - stds.victor_purpura_distance( - [self.st22, self.st21], self.q3)) - self.assertLessEqual(dist_matrix[0, 1], - dist_matrix[0, 2] + dist_matrix[1, 2]) - self.assertLessEqual(dist_matrix[0, 2], - dist_matrix[1, 2] + dist_matrix[0, 1]) - self.assertLessEqual(dist_matrix[1, 2], - dist_matrix[0, 1] + dist_matrix[0, 2]) + assert_array_equal( + stds.victor_purpura_distance([self.st21, self.st22], self.q3), + stds.victor_purpura_distance([self.st22, self.st21], self.q3), + ) + self.assertLessEqual(dist_matrix[0, 1], dist_matrix[0, 2] + dist_matrix[1, 2]) + self.assertLessEqual(dist_matrix[0, 2], dist_matrix[1, 2] + dist_matrix[0, 1]) + self.assertLessEqual(dist_matrix[1, 2], dist_matrix[0, 1] + dist_matrix[0, 2]) # Tests on proper unit conversion self.assertAlmostEqual( - stds.victor_purpura_distance([self.st14, self.st16], - self.q3)[0, 1], - stds.victor_purpura_distance([self.st15, self.st16], - self.q3)[0, 1]) + stds.victor_purpura_distance([self.st14, self.st16], self.q3)[0, 1], + stds.victor_purpura_distance([self.st15, self.st16], self.q3)[0, 1], + ) self.assertAlmostEqual( - stds.victor_purpura_distance([self.st16, self.st14], - self.q3)[0, 1], - stds.victor_purpura_distance([self.st16, self.st15], - self.q3)[0, 1]) + stds.victor_purpura_distance([self.st16, self.st14], self.q3)[0, 1], + stds.victor_purpura_distance([self.st16, self.st15], self.q3)[0, 1], + ) self.assertAlmostEqual( - stds.victor_purpura_distance([self.st01, self.st05], - self.q3)[0, 1], - stds.victor_purpura_distance([self.st01, self.st05], - self.q7)[0, 1]) + stds.victor_purpura_distance([self.st01, self.st05], self.q3)[0, 1], + stds.victor_purpura_distance([self.st01, self.st05], self.q7)[0, 1], + ) # Tests on algorithmic behaviour for equal spike times - self.assertAlmostEqual(stds.victor_purpura_distance( - [self.st31, self.st34], self.q3)[0, 1], 0.8 + 1.0) self.assertAlmostEqual( - stds.victor_purpura_distance([self.st31, self.st34], - self.q3)[0, 1], - stds.victor_purpura_distance([self.st32, self.st33], - self.q3)[0, 1]) + stds.victor_purpura_distance([self.st31, self.st34], self.q3)[0, 1], + 0.8 + 1.0, + ) self.assertAlmostEqual( - stds.victor_purpura_distance( - [self.st31, self.st33], self.q3)[0, 1] * 2.0, - stds.victor_purpura_distance( - [self.st32, self.st34], self.q3)[0, 1]) + stds.victor_purpura_distance([self.st31, self.st34], self.q3)[0, 1], + stds.victor_purpura_distance([self.st32, self.st33], self.q3)[0, 1], + ) + self.assertAlmostEqual( + stds.victor_purpura_distance([self.st31, self.st33], self.q3)[0, 1] * 2.0, + stds.victor_purpura_distance([self.st32, self.st34], self.q3)[0, 1], + ) # Tests on spike train list lengthes smaller than 2 - self.assertEqual(stds.victor_purpura_distance( - [self.st21], self.q3)[0, 0], 0) + self.assertEqual(stds.victor_purpura_distance([self.st21], self.q3)[0, 0], 0) self.assertEqual(len(stds.victor_purpura_distance([], self.q3)), 0) def test_victor_purpura_distance_intuitive(self): # Tests of distances of simplest spike trains - self.assertEqual(stds.victor_purpura_distance( - [self.st00, self.st00], self.q2, - algorithm='intuitive')[0, 1], 0.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st00, self.st01], self.q2, - algorithm='intuitive')[0, 1], 1.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st01, self.st00], self.q2, - algorithm='intuitive')[0, 1], 1.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st01, self.st01], self.q2, - algorithm='intuitive')[0, 1], 0.0) + self.assertEqual( + 
stds.victor_purpura_distance( + [self.st00, self.st00], self.q2, algorithm="intuitive" + )[0, 1], + 0.0, + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st00, self.st01], self.q2, algorithm="intuitive" + )[0, 1], + 1.0, + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st01, self.st00], self.q2, algorithm="intuitive" + )[0, 1], + 1.0, + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st01, self.st01], self.q2, algorithm="intuitive" + )[0, 1], + 0.0, + ) # Tests of distances under elementary spike operations - self.assertEqual(stds.victor_purpura_distance( - [self.st01, self.st02], self.q2, - algorithm='intuitive')[0, 1], 1.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st01, self.st03], self.q2, - algorithm='intuitive')[0, 1], 1.9) - self.assertEqual(stds.victor_purpura_distance( - [self.st01, self.st04], self.q2, - algorithm='intuitive')[0, 1], 2.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st01, self.st05], self.q2, - algorithm='intuitive')[0, 1], 2.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st00, self.st07], self.q2, - algorithm='intuitive')[0, 1], 2.0) - self.assertAlmostEqual(stds.victor_purpura_distance( - [self.st07, self.st08], self.q4, - algorithm='intuitive')[0, 1], 0.4) - self.assertAlmostEqual(stds.victor_purpura_distance( - [self.st07, self.st10], self.q3, - algorithm='intuitive')[0, 1], 2.6) - self.assertEqual(stds.victor_purpura_distance( - [self.st11, self.st14], self.q2, - algorithm='intuitive')[0, 1], 1) + self.assertEqual( + stds.victor_purpura_distance( + [self.st01, self.st02], self.q2, algorithm="intuitive" + )[0, 1], + 1.0, + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st01, self.st03], self.q2, algorithm="intuitive" + )[0, 1], + 1.9, + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st01, self.st04], self.q2, algorithm="intuitive" + )[0, 1], + 2.0, + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st01, self.st05], self.q2, algorithm="intuitive" + )[0, 1], + 2.0, + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st00, self.st07], self.q2, algorithm="intuitive" + )[0, 1], + 2.0, + ) + self.assertAlmostEqual( + stds.victor_purpura_distance( + [self.st07, self.st08], self.q4, algorithm="intuitive" + )[0, 1], + 0.4, + ) + self.assertAlmostEqual( + stds.victor_purpura_distance( + [self.st07, self.st10], self.q3, algorithm="intuitive" + )[0, 1], + 2.6, + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st11, self.st14], self.q2, algorithm="intuitive" + )[0, 1], + 1, + ) # Tests on timescales - self.assertEqual(stds.victor_purpura_distance( - [self.st11, self.st14], self.q1, - algorithm='intuitive')[0, 1], - stds.victor_purpura_distance( - [self.st11, self.st14], self.q5, - algorithm='intuitive')[0, 1]) - self.assertEqual(stds.victor_purpura_distance( - [self.st07, self.st11], self.q0, - algorithm='intuitive')[0, 1], 6.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st07, self.st11], self.q1, - algorithm='intuitive')[0, 1], 6.0) - self.assertAlmostEqual(stds.victor_purpura_distance( - [self.st07, self.st11], self.q5, - algorithm='intuitive')[0, 1], 2.0, 5) - self.assertEqual(stds.victor_purpura_distance( - [self.st07, self.st11], self.q6, - algorithm='intuitive')[0, 1], 2.0) + self.assertEqual( + stds.victor_purpura_distance( + [self.st11, self.st14], self.q1, algorithm="intuitive" + )[0, 1], + stds.victor_purpura_distance( + [self.st11, self.st14], self.q5, algorithm="intuitive" + )[0, 1], + ) + 
self.assertEqual( + stds.victor_purpura_distance( + [self.st07, self.st11], self.q0, algorithm="intuitive" + )[0, 1], + 6.0, + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st07, self.st11], self.q1, algorithm="intuitive" + )[0, 1], + 6.0, + ) + self.assertAlmostEqual( + stds.victor_purpura_distance( + [self.st07, self.st11], self.q5, algorithm="intuitive" + )[0, 1], + 2.0, + 5, + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st07, self.st11], self.q6, algorithm="intuitive" + )[0, 1], + 2.0, + ) # Tests on unordered spiketrains - self.assertEqual(stds.victor_purpura_distance( - [self.st11, self.st13], self.q4, - algorithm='intuitive')[0, 1], - stds.victor_purpura_distance( - [self.st12, self.st13], self.q4, - algorithm='intuitive')[0, 1]) - self.assertNotEqual(stds.victor_purpura_distance( - [self.st11, self.st13], self.q4, - sort=False, algorithm='intuitive')[0, 1], - stds.victor_purpura_distance( - [self.st12, self.st13], self.q4, - sort=False, algorithm='intuitive')[0, 1]) + self.assertEqual( + stds.victor_purpura_distance( + [self.st11, self.st13], self.q4, algorithm="intuitive" + )[0, 1], + stds.victor_purpura_distance( + [self.st12, self.st13], self.q4, algorithm="intuitive" + )[0, 1], + ) + self.assertNotEqual( + stds.victor_purpura_distance( + [self.st11, self.st13], self.q4, sort=False, algorithm="intuitive" + )[0, 1], + stds.victor_purpura_distance( + [self.st12, self.st13], self.q4, sort=False, algorithm="intuitive" + )[0, 1], + ) # Tests on metric properties with random spiketrains # (explicit calculation of second metric axiom in particular case, # because from dist_matrix it is trivial) dist_matrix = stds.victor_purpura_distance( - [self.st21, self.st22, self.st23], - self.q3, algorithm='intuitive') + [self.st21, self.st22, self.st23], self.q3, algorithm="intuitive" + ) for i in range(3): for j in range(3): self.assertGreaterEqual(dist_matrix[i, j], 0) if dist_matrix[i, j] == 0: assert_array_equal(self.rd_st_list[i], self.rd_st_list[j]) - assert_array_equal(stds.victor_purpura_distance( - [self.st21, self.st22], self.q3, - algorithm='intuitive'), - stds.victor_purpura_distance( - [self.st22, self.st21], self.q3, - algorithm='intuitive')) - self.assertLessEqual(dist_matrix[0, 1], - dist_matrix[0, 2] + dist_matrix[1, 2]) - self.assertLessEqual(dist_matrix[0, 2], - dist_matrix[1, 2] + dist_matrix[0, 1]) - self.assertLessEqual(dist_matrix[1, 2], - dist_matrix[0, 1] + dist_matrix[0, 2]) + assert_array_equal( + stds.victor_purpura_distance( + [self.st21, self.st22], self.q3, algorithm="intuitive" + ), + stds.victor_purpura_distance( + [self.st22, self.st21], self.q3, algorithm="intuitive" + ), + ) + self.assertLessEqual(dist_matrix[0, 1], dist_matrix[0, 2] + dist_matrix[1, 2]) + self.assertLessEqual(dist_matrix[0, 2], dist_matrix[1, 2] + dist_matrix[0, 1]) + self.assertLessEqual(dist_matrix[1, 2], dist_matrix[0, 1] + dist_matrix[0, 2]) # Tests on proper unit conversion - self.assertAlmostEqual(stds.victor_purpura_distance( - [self.st14, self.st16], self.q3, - algorithm='intuitive')[0, 1], - stds.victor_purpura_distance( - [self.st15, self.st16], self.q3, - algorithm='intuitive')[0, 1]) - self.assertAlmostEqual(stds.victor_purpura_distance( - [self.st16, self.st14], self.q3, - algorithm='intuitive')[0, 1], - stds.victor_purpura_distance( - [self.st16, self.st15], self.q3, - algorithm='intuitive')[0, 1]) - self.assertEqual(stds.victor_purpura_distance( - [self.st01, self.st05], self.q3, - algorithm='intuitive')[0, 1], - stds.victor_purpura_distance( - 
[self.st01, self.st05], self.q7, - algorithm='intuitive')[0, 1]) + self.assertAlmostEqual( + stds.victor_purpura_distance( + [self.st14, self.st16], self.q3, algorithm="intuitive" + )[0, 1], + stds.victor_purpura_distance( + [self.st15, self.st16], self.q3, algorithm="intuitive" + )[0, 1], + ) + self.assertAlmostEqual( + stds.victor_purpura_distance( + [self.st16, self.st14], self.q3, algorithm="intuitive" + )[0, 1], + stds.victor_purpura_distance( + [self.st16, self.st15], self.q3, algorithm="intuitive" + )[0, 1], + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st01, self.st05], self.q3, algorithm="intuitive" + )[0, 1], + stds.victor_purpura_distance( + [self.st01, self.st05], self.q7, algorithm="intuitive" + )[0, 1], + ) # Tests on algorithmic behaviour for equal spike times - self.assertEqual(stds.victor_purpura_distance( - [self.st31, self.st34], self.q3, - algorithm='intuitive')[0, 1], - 0.8 + 1.0) - self.assertEqual(stds.victor_purpura_distance( - [self.st31, self.st34], self.q3, - algorithm='intuitive')[0, 1], - stds.victor_purpura_distance( - [self.st32, self.st33], self.q3, - algorithm='intuitive')[0, 1]) - self.assertEqual(stds.victor_purpura_distance( - [self.st31, self.st33], self.q3, - algorithm='intuitive')[0, 1] * 2.0, - stds.victor_purpura_distance( - [self.st32, self.st34], self.q3, - algorithm='intuitive')[0, 1]) + self.assertEqual( + stds.victor_purpura_distance( + [self.st31, self.st34], self.q3, algorithm="intuitive" + )[0, 1], + 0.8 + 1.0, + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st31, self.st34], self.q3, algorithm="intuitive" + )[0, 1], + stds.victor_purpura_distance( + [self.st32, self.st33], self.q3, algorithm="intuitive" + )[0, 1], + ) + self.assertEqual( + stds.victor_purpura_distance( + [self.st31, self.st33], self.q3, algorithm="intuitive" + )[0, 1] + * 2.0, + stds.victor_purpura_distance( + [self.st32, self.st34], self.q3, algorithm="intuitive" + )[0, 1], + ) # Tests on spike train list lengthes smaller than 2 - self.assertEqual(stds.victor_purpura_distance( - [self.st21], self.q3, - algorithm='intuitive')[0, 0], 0) - self.assertEqual(len(stds.victor_purpura_distance( - [], self.q3, algorithm='intuitive')), 0) + self.assertEqual( + stds.victor_purpura_distance([self.st21], self.q3, algorithm="intuitive")[ + 0, 0 + ], + 0, + ) + self.assertEqual( + len(stds.victor_purpura_distance([], self.q3, algorithm="intuitive")), 0 + ) def test_victor_purpura_algorithm_comparison(self): assert_array_almost_equal( - stds.victor_purpura_distance([self.st21, self.st22, self.st23], - self.q3), - stds.victor_purpura_distance([self.st21, self.st22, self.st23], - self.q3, algorithm='intuitive')) + stds.victor_purpura_distance([self.st21, self.st22, self.st23], self.q3), + stds.victor_purpura_distance( + [self.st21, self.st22, self.st23], self.q3, algorithm="intuitive" + ), + ) def test_victor_purpura_matlab_comparison_float(self): - - repo_path =\ - r"unittest/spike_train_dissimilarity/victor_purpura_distance/data" + repo_path = r"unittest/spike_train_dissimilarity/victor_purpura_distance/data" files_to_download = [ ("times_float.npy", "ed1ff4d2c0eeed4a2b50a456803656be"), - ("matlab_results_float.npy", "a17f049e7ad0ddf7ca812e86fdb92646")] + ("matlab_results_float.npy", "a17f049e7ad0ddf7ca812e86fdb92646"), + ] for filename, checksum in files_to_download: - download_datasets(repo_path=f"{repo_path}/{filename}", - checksum=checksum) + download_datasets(repo_path=f"{repo_path}/{filename}", checksum=checksum) - times_float = 
np.load(ELEPHANT_TMP_DIR / 'times_float.npy') - mat_res_float = np.load(ELEPHANT_TMP_DIR / 'matlab_results_float.npy') + times_float = np.load(ELEPHANT_TMP_DIR / "times_float.npy") + mat_res_float = np.load(ELEPHANT_TMP_DIR / "matlab_results_float.npy") - r_float = SpikeTrain(times_float[0], units='ms', t_start=0, - t_stop=1000 * ms) - s_float = SpikeTrain(times_float[1], units='ms', t_start=0, - t_stop=1000 * ms) - t_float = SpikeTrain(times_float[2], units='ms', t_start=0, - t_stop=1000 * ms) + r_float = SpikeTrain(times_float[0], units="ms", t_start=0, t_stop=1000 * ms) + s_float = SpikeTrain(times_float[1], units="ms", t_start=0, t_stop=1000 * ms) + t_float = SpikeTrain(times_float[2], units="ms", t_start=0, t_stop=1000 * ms) vic_pur_result_float = stds.victor_purpura_distance( [r_float, s_float, t_float], - cost_factor=1.0 / ms, kernel=None, - sort=True, algorithm='intuitive') + cost_factor=1.0 / ms, + kernel=None, + sort=True, + algorithm="intuitive", + ) assert_array_almost_equal(vic_pur_result_float, mat_res_float) def test_victor_purpura_matlab_comparison_int(self): - - repo_path =\ - r"unittest/spike_train_dissimilarity/victor_purpura_distance/data" + repo_path = r"unittest/spike_train_dissimilarity/victor_purpura_distance/data" files_to_download = [ ("times_int.npy", "aa1411c04da3f58d8b8913ae2f935057"), - ("matlab_results_int.npy", "7edd32e50edde12dc1ef4aa5f57f70fb")] + ("matlab_results_int.npy", "7edd32e50edde12dc1ef4aa5f57f70fb"), + ] for filename, checksum in files_to_download: - download_datasets(repo_path=f"{repo_path}/{filename}", - checksum=checksum) + download_datasets(repo_path=f"{repo_path}/{filename}", checksum=checksum) - times_int = np.load(ELEPHANT_TMP_DIR / 'times_int.npy') - mat_res_int = np.load(ELEPHANT_TMP_DIR / 'matlab_results_int.npy') + times_int = np.load(ELEPHANT_TMP_DIR / "times_int.npy") + mat_res_int = np.load(ELEPHANT_TMP_DIR / "matlab_results_int.npy") - r_int = SpikeTrain(times_int[0], units='ms', t_start=0, - t_stop=1000 * ms) - s_int = SpikeTrain(times_int[1], units='ms', t_start=0, - t_stop=1000 * ms) - t_int = SpikeTrain(times_int[2], units='ms', t_start=0, - t_stop=1000 * ms) + r_int = SpikeTrain(times_int[0], units="ms", t_start=0, t_stop=1000 * ms) + s_int = SpikeTrain(times_int[1], units="ms", t_start=0, t_stop=1000 * ms) + t_int = SpikeTrain(times_int[2], units="ms", t_start=0, t_stop=1000 * ms) vic_pur_result_int = stds.victor_purpura_distance( [r_int, s_int, t_int], - cost_factor=1.0 / ms, kernel=None, - sort=True, algorithm='intuitive') + cost_factor=1.0 / ms, + kernel=None, + sort=True, + algorithm="intuitive", + ) assert_array_equal(vic_pur_result_int, mat_res_int) def test_van_rossum_distance(self): # Tests of distances of simplest spike trains - self.assertEqual(stds.van_rossum_distance( - [self.st00, self.st00], self.tau2)[0, 1], 0.0) - self.assertEqual(stds.van_rossum_distance( - [self.st00, self.st01], self.tau2)[0, 1], 1.0) - self.assertEqual(stds.van_rossum_distance( - [self.st01, self.st00], self.tau2)[0, 1], 1.0) - self.assertEqual(stds.van_rossum_distance( - [self.st01, self.st01], self.tau2)[0, 1], 0.0) + self.assertEqual( + stds.van_rossum_distance([self.st00, self.st00], self.tau2)[0, 1], 0.0 + ) + self.assertEqual( + stds.van_rossum_distance([self.st00, self.st01], self.tau2)[0, 1], 1.0 + ) + self.assertEqual( + stds.van_rossum_distance([self.st01, self.st00], self.tau2)[0, 1], 1.0 + ) + self.assertEqual( + stds.van_rossum_distance([self.st01, self.st01], self.tau2)[0, 1], 0.0 + ) # Tests of distances under elementary 
spike operations - self.assertAlmostEqual(stds.van_rossum_distance( - [self.st01, self.st02], self.tau2)[0, 1], - float(np.sqrt(2 * (1.0 - np.exp(-np.absolute( - ((self.st01[0] - self.st02[0]) / - self.tau2).simplified)))))) - self.assertAlmostEqual(stds.van_rossum_distance( - [self.st01, self.st05], self.tau2)[0, 1], - float(np.sqrt(2 * (1.0 - np.exp(-np.absolute( - ((self.st01[0] - self.st05[0]) / - self.tau2).simplified)))))) - self.assertAlmostEqual(stds.van_rossum_distance( - [self.st01, self.st05], self.tau2)[0, 1], - np.sqrt(2.0), 1) - self.assertAlmostEqual(stds.van_rossum_distance( - [self.st01, self.st06], self.tau2)[0, 1], - np.sqrt(2.0), 20) - self.assertAlmostEqual(stds.van_rossum_distance( - [self.st00, self.st07], self.tau1)[0, 1], - np.sqrt(0 + 2)) - self.assertAlmostEqual(stds.van_rossum_distance( - [self.st07, self.st08], self.tau4)[0, 1], - float(np.sqrt(2 * (1.0 - np.exp(-np.absolute( - ((self.st07[0] - self.st08[-1]) / - self.tau4).simplified)))))) + self.assertAlmostEqual( + stds.van_rossum_distance([self.st01, self.st02], self.tau2)[0, 1], + float( + np.sqrt( + 2 + * ( + 1.0 + - np.exp( + -np.absolute( + ((self.st01[0] - self.st02[0]) / self.tau2).simplified + ) + ) + ) + ) + ), + ) + self.assertAlmostEqual( + stds.van_rossum_distance([self.st01, self.st05], self.tau2)[0, 1], + float( + np.sqrt( + 2 + * ( + 1.0 + - np.exp( + -np.absolute( + ((self.st01[0] - self.st05[0]) / self.tau2).simplified + ) + ) + ) + ) + ), + ) + self.assertAlmostEqual( + stds.van_rossum_distance([self.st01, self.st05], self.tau2)[0, 1], + np.sqrt(2.0), + 1, + ) + self.assertAlmostEqual( + stds.van_rossum_distance([self.st01, self.st06], self.tau2)[0, 1], + np.sqrt(2.0), + 20, + ) + self.assertAlmostEqual( + stds.van_rossum_distance([self.st00, self.st07], self.tau1)[0, 1], + np.sqrt(0 + 2), + ) + self.assertAlmostEqual( + stds.van_rossum_distance([self.st07, self.st08], self.tau4)[0, 1], + float( + np.sqrt( + 2 + * ( + 1.0 + - np.exp( + -np.absolute( + ((self.st07[0] - self.st08[-1]) / self.tau4).simplified + ) + ) + ) + ) + ), + ) f_minus_g_squared = ( - (self.t > self.st08[0]) * np.exp( - -((self.t - self.st08[0]) / self.tau3).simplified) + - (self.t > self.st08[1]) * np.exp( - -((self.t - self.st08[1]) / self.tau3).simplified) - - (self.t > self.st09[0]) * np.exp( - -((self.t - self.st09[0]) / self.tau3).simplified)) ** 2 - distance = np.sqrt(2.0 * spint.cumulative_trapezoid( - y=f_minus_g_squared, x=self.t.magnitude)[-1] / - self.tau3.rescale(self.t.units).magnitude) - self.assertAlmostEqual(stds.van_rossum_distance( - [self.st08, self.st09], self.tau3)[0, 1], distance, 5) - self.assertAlmostEqual(stds.van_rossum_distance( - [self.st11, self.st14], self.tau2)[0, 1], 1) + (self.t > self.st08[0]) + * np.exp(-((self.t - self.st08[0]) / self.tau3).simplified) + + (self.t > self.st08[1]) + * np.exp(-((self.t - self.st08[1]) / self.tau3).simplified) + - (self.t > self.st09[0]) + * np.exp(-((self.t - self.st09[0]) / self.tau3).simplified) + ) ** 2 + distance = np.sqrt( + 2.0 + * spint.cumulative_trapezoid(y=f_minus_g_squared, x=self.t.magnitude)[-1] + / self.tau3.rescale(self.t.units).magnitude + ) + self.assertAlmostEqual( + stds.van_rossum_distance([self.st08, self.st09], self.tau3)[0, 1], + distance, + 5, + ) + self.assertAlmostEqual( + stds.van_rossum_distance([self.st11, self.st14], self.tau2)[0, 1], 1 + ) # Tests on timescales self.assertAlmostEqual( stds.van_rossum_distance([self.st11, self.st14], self.tau1)[0, 1], - stds.van_rossum_distance([self.st11, self.st14], self.tau5)[0, 1]) + 
stds.van_rossum_distance([self.st11, self.st14], self.tau5)[0, 1], + ) self.assertAlmostEqual( stds.van_rossum_distance([self.st07, self.st11], self.tau0)[0, 1], - np.sqrt(len(self.st07) + len(self.st11))) + np.sqrt(len(self.st07) + len(self.st11)), + ) self.assertAlmostEqual( stds.van_rossum_distance([self.st07, self.st14], self.tau0)[0, 1], - np.sqrt(len(self.st07) + len(self.st14))) + np.sqrt(len(self.st07) + len(self.st14)), + ) self.assertAlmostEqual( stds.van_rossum_distance([self.st07, self.st11], self.tau1)[0, 1], - np.sqrt(len(self.st07) + len(self.st11))) + np.sqrt(len(self.st07) + len(self.st11)), + ) self.assertAlmostEqual( stds.van_rossum_distance([self.st07, self.st14], self.tau1)[0, 1], - np.sqrt(len(self.st07) + len(self.st14))) + np.sqrt(len(self.st07) + len(self.st14)), + ) self.assertAlmostEqual( stds.van_rossum_distance([self.st07, self.st11], self.tau5)[0, 1], - np.absolute(len(self.st07) - len(self.st11))) + np.absolute(len(self.st07) - len(self.st11)), + ) self.assertAlmostEqual( stds.van_rossum_distance([self.st07, self.st14], self.tau5)[0, 1], - np.absolute(len(self.st07) - len(self.st14))) + np.absolute(len(self.st07) - len(self.st14)), + ) self.assertAlmostEqual( stds.van_rossum_distance([self.st07, self.st11], self.tau6)[0, 1], - np.absolute(len(self.st07) - len(self.st11))) + np.absolute(len(self.st07) - len(self.st11)), + ) self.assertAlmostEqual( stds.van_rossum_distance([self.st07, self.st14], self.tau6)[0, 1], - np.absolute(len(self.st07) - len(self.st14))) + np.absolute(len(self.st07) - len(self.st14)), + ) # Tests on unordered spiketrains self.assertEqual( stds.van_rossum_distance([self.st11, self.st13], self.tau4)[0, 1], - stds.van_rossum_distance([self.st12, self.st13], self.tau4)[0, 1]) + stds.van_rossum_distance([self.st12, self.st13], self.tau4)[0, 1], + ) self.assertNotEqual( - stds.van_rossum_distance([self.st11, self.st13], - self.tau4, sort=False)[0, 1], - stds.van_rossum_distance([self.st12, self.st13], - self.tau4, sort=False)[0, 1]) + stds.van_rossum_distance([self.st11, self.st13], self.tau4, sort=False)[ + 0, 1 + ], + stds.van_rossum_distance([self.st12, self.st13], self.tau4, sort=False)[ + 0, 1 + ], + ) # Tests on metric properties with random spiketrains # (explicit calculation of second metric axiom in particular case, # because from dist_matrix it is trivial) dist_matrix = stds.van_rossum_distance( - [self.st21, self.st22, self.st23], self.tau3) + [self.st21, self.st22, self.st23], self.tau3 + ) for i in range(3): for j in range(3): self.assertGreaterEqual(dist_matrix[i, j], 0) @@ -548,50 +750,55 @@ def test_van_rossum_distance(self): assert_array_equal(self.rd_st_list[i], self.rd_st_list[j]) assert_array_equal( stds.van_rossum_distance([self.st21, self.st22], self.tau3), - stds.van_rossum_distance([self.st22, self.st21], self.tau3)) - self.assertLessEqual(dist_matrix[0, 1], - dist_matrix[0, 2] + dist_matrix[1, 2]) - self.assertLessEqual(dist_matrix[0, 2], - dist_matrix[1, 2] + dist_matrix[0, 1]) - self.assertLessEqual(dist_matrix[1, 2], - dist_matrix[0, 1] + dist_matrix[0, 2]) + stds.van_rossum_distance([self.st22, self.st21], self.tau3), + ) + self.assertLessEqual(dist_matrix[0, 1], dist_matrix[0, 2] + dist_matrix[1, 2]) + self.assertLessEqual(dist_matrix[0, 2], dist_matrix[1, 2] + dist_matrix[0, 1]) + self.assertLessEqual(dist_matrix[1, 2], dist_matrix[0, 1] + dist_matrix[0, 2]) # Tests on proper unit conversion self.assertAlmostEqual( stds.van_rossum_distance([self.st14, self.st16], self.tau3)[0, 1], - 
stds.van_rossum_distance([self.st15, self.st16], self.tau3)[0, 1]) + stds.van_rossum_distance([self.st15, self.st16], self.tau3)[0, 1], + ) self.assertAlmostEqual( stds.van_rossum_distance([self.st16, self.st14], self.tau3)[0, 1], - stds.van_rossum_distance([self.st16, self.st15], self.tau3)[0, 1]) + stds.van_rossum_distance([self.st16, self.st15], self.tau3)[0, 1], + ) self.assertEqual( stds.van_rossum_distance([self.st01, self.st05], self.tau3)[0, 1], - stds.van_rossum_distance([self.st01, self.st05], self.tau7)[0, 1]) + stds.van_rossum_distance([self.st01, self.st05], self.tau7)[0, 1], + ) # Tests on algorithmic behaviour for equal spike times f_minus_g_squared = ( - (self.t > self.st31[0]) * np.exp( - -((self.t - self.st31[0]) / self.tau3).simplified) - - (self.t > self.st34[0]) * np.exp( - -((self.t - self.st34[0]) / self.tau3).simplified) - - (self.t > self.st34[1]) * np.exp( - -((self.t - self.st34[1]) / self.tau3).simplified)) ** 2 - distance = np.sqrt(2.0 * spint.cumulative_trapezoid( - y=f_minus_g_squared, x=self.t.magnitude)[-1] / - self.tau3.rescale(self.t.units).magnitude) - self.assertAlmostEqual(stds.van_rossum_distance([self.st31, self.st34], - self.tau3)[0, 1], - distance, 5) - self.assertEqual(stds.van_rossum_distance([self.st31, self.st34], - self.tau3)[0, 1], - stds.van_rossum_distance([self.st32, self.st33], - self.tau3)[0, 1]) - self.assertEqual(stds.van_rossum_distance([self.st31, self.st33], - self.tau3)[0, 1] * 2.0, - stds.van_rossum_distance([self.st32, self.st34], - self.tau3)[0, 1]) + (self.t > self.st31[0]) + * np.exp(-((self.t - self.st31[0]) / self.tau3).simplified) + - (self.t > self.st34[0]) + * np.exp(-((self.t - self.st34[0]) / self.tau3).simplified) + - (self.t > self.st34[1]) + * np.exp(-((self.t - self.st34[1]) / self.tau3).simplified) + ) ** 2 + distance = np.sqrt( + 2.0 + * spint.cumulative_trapezoid(y=f_minus_g_squared, x=self.t.magnitude)[-1] + / self.tau3.rescale(self.t.units).magnitude + ) + self.assertAlmostEqual( + stds.van_rossum_distance([self.st31, self.st34], self.tau3)[0, 1], + distance, + 5, + ) + self.assertEqual( + stds.van_rossum_distance([self.st31, self.st34], self.tau3)[0, 1], + stds.van_rossum_distance([self.st32, self.st33], self.tau3)[0, 1], + ) + self.assertEqual( + stds.van_rossum_distance([self.st31, self.st33], self.tau3)[0, 1] * 2.0, + stds.van_rossum_distance([self.st32, self.st34], self.tau3)[0, 1], + ) # Tests on spike train list lengthes smaller than 2 - self.assertEqual(stds.van_rossum_distance( - [self.st21], self.tau3)[0, 0], 0) + self.assertEqual(stds.van_rossum_distance([self.st21], self.tau3)[0, 0], 0) self.assertEqual(len(stds.van_rossum_distance([], self.tau3)), 0) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_spike_train_generation.py b/elephant/test/test_spike_train_generation.py index 3ea160c35..368b74311 100644 --- a/elephant/test/test_spike_train_generation.py +++ b/elephant/test/test_spike_train_generation.py @@ -18,13 +18,25 @@ import quantities as pq from scipy.stats import expon, kstest, poisson, variation -from elephant.spike_train_generation import StationaryPoissonProcess, \ - threshold_detection, peak_detection, spike_extraction, \ - AbstractPointProcess, StationaryGammaProcess, StationaryLogNormalProcess, \ - NonStationaryPoissonProcess, NonStationaryGammaProcess, \ - StationaryInverseGaussianProcess, _n_poisson, single_interaction_process, \ - cpp, homogeneous_gamma_process, homogeneous_poisson_process, \ - inhomogeneous_poisson_process, 
inhomogeneous_gamma_process +from elephant.spike_train_generation import ( + StationaryPoissonProcess, + threshold_detection, + peak_detection, + spike_extraction, + AbstractPointProcess, + StationaryGammaProcess, + StationaryLogNormalProcess, + NonStationaryPoissonProcess, + NonStationaryGammaProcess, + StationaryInverseGaussianProcess, + _n_poisson, + single_interaction_process, + cpp, + homogeneous_gamma_process, + homogeneous_poisson_process, + inhomogeneous_poisson_process, + inhomogeneous_gamma_process, +) from elephant.statistics import isi, instantaneous_rate from elephant import kernels @@ -38,21 +50,25 @@ def pdiff(a, b): class AnalogSignalThresholdDetectionTestCase(unittest.TestCase): - def setUp(self): # Load membrane potential simulated using Brian2 # according to make_spike_extraction_test_data.py. curr_dir = os.path.dirname(os.path.realpath(__file__)) - raw_data_file_loc = os.path.join( - curr_dir, 'spike_extraction_test_data.txt') + raw_data_file_loc = os.path.join(curr_dir, "spike_extraction_test_data.txt") raw_data = [] - with open(raw_data_file_loc, 'r') as f: + with open(raw_data_file_loc, "r") as f: for x in f.readlines(): raw_data.append(float(x)) - self.vm = neo.AnalogSignal( - raw_data, units=pq.V, sampling_period=0.1 * pq.ms) - self.true_time_stamps = [0.0123, 0.0354, 0.0712, 0.1191, 0.1694, - 0.2200, 0.2711] * pq.s + self.vm = neo.AnalogSignal(raw_data, units=pq.V, sampling_period=0.1 * pq.ms) + self.true_time_stamps = [ + 0.0123, + 0.0354, + 0.0712, + 0.1191, + 0.1694, + 0.2200, + 0.2711, + ] * pq.s def test_threshold_detection(self): # Test whether spikes are extracted at the correct times from @@ -65,13 +81,19 @@ def test_threshold_detection(self): # spike trains being treated as unsized objects. except TypeError: warnings.warn( - ("The spike train may be an unsized object. This may be" - " related to an issue in Neo with some zero-length SpikeTrain" - " objects. Bypassing this by creating an empty SpikeTrain" - " object.")) - spike_train = neo.SpikeTrain([], t_start=spike_train.t_start, - t_stop=spike_train.t_stop, - units=spike_train.units) + ( + "The spike train may be an unsized object. This may be" + " related to an issue in Neo with some zero-length SpikeTrain" + " objects. Bypassing this by creating an empty SpikeTrain" + " object." + ) + ) + spike_train = neo.SpikeTrain( + [], + t_start=spike_train.t_start, + t_stop=spike_train.t_stop, + units=spike_train.units, + ) # Does threshold_detection gives the correct number of spikes? 
self.assertEqual(len(spike_train), len(self.true_time_stamps)) @@ -88,19 +110,23 @@ def test_peak_detection_threshold(self): class AnalogSignalPeakDetectionTestCase(unittest.TestCase): - def setUp(self): curr_dir = os.path.dirname(os.path.realpath(__file__)) - raw_data_file_loc = os.path.join( - curr_dir, 'spike_extraction_test_data.txt') + raw_data_file_loc = os.path.join(curr_dir, "spike_extraction_test_data.txt") raw_data = [] - with open(raw_data_file_loc, 'r') as f: + with open(raw_data_file_loc, "r") as f: for x in f.readlines(): raw_data.append(float(x)) - self.vm = neo.AnalogSignal( - raw_data, units=pq.V, sampling_period=0.1 * pq.ms) - self.true_time_stamps = [0.0124, 0.0354, 0.0713, 0.1192, 0.1695, - 0.2201, 0.2711] * pq.s + self.vm = neo.AnalogSignal(raw_data, units=pq.V, sampling_period=0.1 * pq.ms) + self.true_time_stamps = [ + 0.0124, + 0.0354, + 0.0713, + 0.1192, + 0.1695, + 0.2201, + 0.2711, + ] * pq.s def test_peak_detection_time_stamps(self): # Test with default arguments @@ -120,42 +146,63 @@ def test_peak_detection_threshold(self): class AnalogSignalSpikeExtractionTestCase(unittest.TestCase): - def setUp(self): curr_dir = os.path.dirname(os.path.realpath(__file__)) - raw_data_file_loc = os.path.join( - curr_dir, 'spike_extraction_test_data.txt') + raw_data_file_loc = os.path.join(curr_dir, "spike_extraction_test_data.txt") raw_data = [] - with open(raw_data_file_loc, 'r') as f: + with open(raw_data_file_loc, "r") as f: for x in f.readlines(): raw_data.append(float(x)) - self.vm = neo.AnalogSignal( - raw_data, units=pq.V, sampling_period=0.1 * pq.ms) - self.first_spike = np.array([-0.04084546, -0.03892033, -0.03664779, - -0.03392689, -0.03061474, -0.02650277, - -0.0212756, -0.01443531, -0.00515365, - 0.00803962, 0.02797951, -0.07, - -0.06974495, -0.06950466, -0.06927778, - -0.06906314, -0.06885969, -0.06866651, - -0.06848277, -0.06830773, -0.06814071, - -0.06798113, -0.06782843, -0.06768213, - -0.06754178, -0.06740699, -0.06727737, - -0.06715259, -0.06703235, -0.06691635]) + self.vm = neo.AnalogSignal(raw_data, units=pq.V, sampling_period=0.1 * pq.ms) + self.first_spike = np.array( + [ + -0.04084546, + -0.03892033, + -0.03664779, + -0.03392689, + -0.03061474, + -0.02650277, + -0.0212756, + -0.01443531, + -0.00515365, + 0.00803962, + 0.02797951, + -0.07, + -0.06974495, + -0.06950466, + -0.06927778, + -0.06906314, + -0.06885969, + -0.06866651, + -0.06848277, + -0.06830773, + -0.06814071, + -0.06798113, + -0.06782843, + -0.06768213, + -0.06754178, + -0.06740699, + -0.06727737, + -0.06715259, + -0.06703235, + -0.06691635, + ] + ) def test_spike_extraction_waveform(self): - spike_train = spike_extraction(self.vm.reshape(-1), - interval=(-1 * pq.ms, 2 * pq.ms)) + spike_train = spike_extraction( + self.vm.reshape(-1), interval=(-1 * pq.ms, 2 * pq.ms) + ) assert_array_almost_equal( - spike_train.waveforms[0][0].magnitude.reshape(-1), - self.first_spike) + spike_train.waveforms[0][0].magnitude.reshape(-1), self.first_spike + ) class AbstractPointProcessTestCase(unittest.TestCase): def test_not_implemented_error(self): process = AbstractPointProcess() - self.assertRaises( - NotImplementedError, process._generate_spiketrain_as_array) + self.assertRaises(NotImplementedError, process._generate_spiketrain_as_array) class StationaryPoissonProcessTestCase(unittest.TestCase): @@ -166,37 +213,40 @@ def test_statistics(self): for rate in [123.0 * pq.Hz, 0.123 * pq.kHz]: for t_stop in [2345 * pq.ms, 2.345 * pq.s]: - for refractory_period in (None, 3. 
* pq.ms): + for refractory_period in (None, 3.0 * pq.ms): np.random.seed(seed=123456) spiketrain_old = homogeneous_poisson_process( - rate, t_stop=t_stop, - refractory_period=refractory_period) + rate, t_stop=t_stop, refractory_period=refractory_period + ) np.random.seed(seed=123456) spiketrain = StationaryPoissonProcess( - rate, t_stop=t_stop, + rate, + t_stop=t_stop, refractory_period=refractory_period, - equilibrium=False + equilibrium=False, ).generate_spiketrain() assert_array_almost_equal( - spiketrain_old.magnitude, - spiketrain.magnitude) + spiketrain_old.magnitude, spiketrain.magnitude + ) intervals = isi(spiketrain) - expected_mean_isi = 1. / rate.simplified + expected_mean_isi = 1.0 / rate.simplified self.assertAlmostEqual( expected_mean_isi.magnitude, intervals.mean().simplified.magnitude, - places=3) + places=3, + ) expected_first_spike = 0 * pq.ms self.assertLess( - spiketrain[0] - expected_first_spike, - 7 * expected_mean_isi) + spiketrain[0] - expected_first_spike, 7 * expected_mean_isi + ) expected_last_spike = t_stop - self.assertLess(expected_last_spike - - spiketrain[-1], 7 * expected_mean_isi) + self.assertLess( + expected_last_spike - spiketrain[-1], 7 * expected_mean_isi + ) if refractory_period is None: # Kolmogorov-Smirnov test @@ -204,24 +254,31 @@ def test_statistics(self): intervals.rescale(t_stop.units).magnitude, "expon", # args are (loc, scale) - args=(0., expected_mean_isi.rescale( - t_stop.units).magnitude), - alternative='two-sided') + args=( + 0.0, + expected_mean_isi.rescale(t_stop.units).magnitude, + ), + alternative="two-sided", + ) else: refractory_period = refractory_period.rescale( - t_stop.units).item() - measured_rate = 1. / expected_mean_isi.rescale( - t_stop.units).item() + t_stop.units + ).item() + measured_rate = ( + 1.0 / expected_mean_isi.rescale(t_stop.units).item() + ) effective_rate = measured_rate / ( - 1. - measured_rate * refractory_period) + 1.0 - measured_rate * refractory_period + ) # Kolmogorov-Smirnov test D, p = kstest( intervals.rescale(t_stop.units).magnitude, "expon", # args are (loc, scale) - args=(refractory_period, 1. / effective_rate), - alternative='two-sided') + args=(refractory_period, 1.0 / effective_rate), + alternative="two-sided", + ) self.assertGreater(p, 0.001) self.assertLess(D, 0.12) @@ -230,14 +287,14 @@ def test_zero_refractory_period(self): t_stop = 20 * pq.s np.random.seed(27) - sp1 = StationaryPoissonProcess( - rate, t_stop=t_stop).generate_spiketrain(as_array=True) + sp1 = StationaryPoissonProcess(rate, t_stop=t_stop).generate_spiketrain( + as_array=True + ) np.random.seed(27) sp2 = StationaryPoissonProcess( - rate, t_stop=t_stop, refractory_period=0. * pq.ms - ).generate_spiketrain( - as_array=True) + rate, t_stop=t_stop, refractory_period=0.0 * pq.ms + ).generate_spiketrain(as_array=True) assert_array_almost_equal(sp1, sp2) @@ -247,7 +304,8 @@ def test_t_start_and_t_stop(self): t_stop = 2 * pq.s sp1 = StationaryPoissonProcess( - rate, t_start=t_start, t_stop=t_stop).generate_spiketrain() + rate, t_start=t_start, t_stop=t_stop + ).generate_spiketrain() sp2 = StationaryPoissonProcess( rate, t_start=t_start, t_stop=t_stop, refractory_period=3 * pq.ms @@ -264,8 +322,10 @@ def test_zero_rate(self): # RuntimeWarning: divide by zero encountered in true_divide # mean_interval = 1 / rate.magnitude, when rate == 0 Hz. 
spiketrain = StationaryPoissonProcess( - rate=0 * pq.Hz, t_stop=10 * pq.s, - refractory_period=refractory_period).generate_spiketrain() + rate=0 * pq.Hz, + t_stop=10 * pq.s, + refractory_period=refractory_period, + ).generate_spiketrain() self.assertEqual(spiketrain.size, 0) def test_nondecrease_spike_times(self): @@ -273,8 +333,8 @@ def test_nondecrease_spike_times(self): np.random.seed(27) spiketrain = StationaryPoissonProcess( - rate=10 * pq.Hz, t_stop=1000 * pq.s, - refractory_period=refractory_period).generate_spiketrain() + rate=10 * pq.Hz, t_stop=1000 * pq.s, refractory_period=refractory_period + ).generate_spiketrain() diffs = np.diff(spiketrain.times) self.assertTrue((diffs >= 0).all()) @@ -283,7 +343,8 @@ def test_compare_with_as_array(self): t_stop = 10 * pq.s for refractory_period in (None, 3 * pq.ms): process = StationaryPoissonProcess( - rate=rate, t_stop=t_stop, refractory_period=refractory_period) + rate=rate, t_stop=t_stop, refractory_period=refractory_period + ) np.random.seed(27) spiketrain = process.generate_spiketrain() self.assertIsInstance(spiketrain, neo.SpikeTrain) @@ -291,25 +352,25 @@ def test_compare_with_as_array(self): spiketrain_array = process.generate_spiketrain().as_array() # don't check with isinstance: Quantity is a subclass of np.ndarray self.assertTrue(isinstance(spiketrain_array, np.ndarray)) - assert_array_almost_equal(spiketrain.times.magnitude, - spiketrain_array) + assert_array_almost_equal(spiketrain.times.magnitude, spiketrain_array) def test_effective_rate_refractory_period(self): np.random.seed(27) rate_expected = 10 * pq.Hz refractory_period = 90 * pq.ms # 10 ms of effective ISI spiketrain = StationaryPoissonProcess( - rate_expected, t_stop=1000 * pq.s, - refractory_period=refractory_period + rate_expected, t_stop=1000 * pq.s, refractory_period=refractory_period ).generate_spiketrain() rate_obtained = len(spiketrain) / spiketrain.t_stop rate_obtained = rate_obtained.simplified - self.assertAlmostEqual(rate_expected.simplified, - rate_obtained.simplified, places=1) + self.assertAlmostEqual( + rate_expected.simplified, rate_obtained.simplified, places=1 + ) intervals = isi(spiketrain) - isi_mean_expected = 1. / rate_expected - self.assertAlmostEqual(isi_mean_expected.simplified, - intervals.mean().simplified, places=3) + isi_mean_expected = 1.0 / rate_expected + self.assertAlmostEqual( + isi_mean_expected.simplified, intervals.mean().simplified, places=3 + ) def test_invalid(self): rate = 10 * pq.Hz @@ -318,21 +379,32 @@ def test_invalid(self): hpp = StationaryPoissonProcess self.assertRaises( - ValueError, hpp, rate=rate, t_start=5 * pq.ms, - t_stop=1 * pq.ms, refractory_period=refractory_period) + ValueError, + hpp, + rate=rate, + t_start=5 * pq.ms, + t_stop=1 * pq.ms, + refractory_period=refractory_period, + ) # no units provided for rate, t_stop - self.assertRaises(ValueError, hpp, rate=10, - refractory_period=refractory_period) - self.assertRaises(ValueError, hpp, rate=rate, t_stop=5, - refractory_period=refractory_period) + self.assertRaises( + ValueError, hpp, rate=10, refractory_period=refractory_period + ) + self.assertRaises( + ValueError, + hpp, + rate=rate, + t_stop=5, + refractory_period=refractory_period, + ) # no units provided for refractory_period self.assertRaises(ValueError, hpp, rate=rate, refractory_period=2) - self.assertRaises(ValueError, StationaryPoissonProcess, - rate, refractory_period=1. 
* pq.s) + self.assertRaises( + ValueError, StationaryPoissonProcess, rate, refractory_period=1.0 * pq.s + ) class StationaryGammaProcessTestCase(unittest.TestCase): - def test_statistics(self): # This is a statistical test that has a non-zero chance of failure # during normal operation. Thus, we set the random seed to a value that @@ -341,12 +413,10 @@ def test_statistics(self): for b in (67.0 * pq.Hz, 0.067 * pq.kHz): for t_stop in (2345 * pq.ms, 2.345 * pq.s): np.random.seed(seed=12345) - spiketrain_old = homogeneous_gamma_process( - a, b, t_stop=t_stop) + spiketrain_old = homogeneous_gamma_process(a, b, t_stop=t_stop) np.random.seed(seed=12345) spiketrain = StationaryGammaProcess( - rate=b / a, shape_factor=a, t_stop=t_stop, - equilibrium=False + rate=b / a, shape_factor=a, t_stop=t_stop, equilibrium=False ).generate_spiketrain() assert_allclose(spiketrain_old.magnitude, spiketrain.magnitude) @@ -354,50 +424,50 @@ def test_statistics(self): expected_spike_count = int((b / a * t_stop).simplified) # should fail about 1 time in 1000 - self.assertLess( - pdiff(expected_spike_count, spiketrain.size), 0.25) + self.assertLess(pdiff(expected_spike_count, spiketrain.size), 0.25) expected_mean_isi = (a / b).rescale(pq.ms) - self.assertLess( - pdiff(expected_mean_isi, intervals.mean()), 0.3) + self.assertLess(pdiff(expected_mean_isi, intervals.mean()), 0.3) expected_first_spike = 0 * pq.ms self.assertLess( - spiketrain[0] - expected_first_spike, - 4 * expected_mean_isi) + spiketrain[0] - expected_first_spike, 4 * expected_mean_isi + ) expected_last_spike = t_stop - self.assertLess(expected_last_spike - - spiketrain[-1], 4 * expected_mean_isi) + self.assertLess( + expected_last_spike - spiketrain[-1], 4 * expected_mean_isi + ) # Kolmogorov-Smirnov test - D, p = kstest(intervals.rescale(t_stop.units), - "gamma", - # args are (a, loc, scale) - args=(a, 0, (1 / b).rescale(t_stop.units)), - alternative='two-sided') + D, p = kstest( + intervals.rescale(t_stop.units), + "gamma", + # args are (a, loc, scale) + args=(a, 0, (1 / b).rescale(t_stop.units)), + alternative="two-sided", + ) self.assertGreater(p, 0.001) self.assertLess(D, 0.25) def test_compare_with_as_array(self): - a = 3. + a = 3.0 b = 10 * pq.Hz np.random.seed(27) spiketrain = StationaryGammaProcess( - rate=b / a, shape_factor=a, - equilibrium=False).generate_spiketrain() + rate=b / a, shape_factor=a, equilibrium=False + ).generate_spiketrain() self.assertIsInstance(spiketrain, neo.SpikeTrain) np.random.seed(27) spiketrain_array = StationaryGammaProcess( - rate=b / a, shape_factor=a, equilibrium=False).generate_spiketrain( - as_array=True) + rate=b / a, shape_factor=a, equilibrium=False + ).generate_spiketrain(as_array=True) # don't check with isinstance: pq.Quantity is a subclass of np.ndarray self.assertTrue(isinstance(spiketrain_array, np.ndarray)) assert_array_almost_equal(spiketrain.times.magnitude, spiketrain_array) class StationaryLogNormalProcessTestCase(unittest.TestCase): - def test_statistics(self): # This is a statistical test that has a non-zero chance of failure # during normal operation. 
Thus, we set the random seed to a value that @@ -407,38 +477,40 @@ def test_statistics(self): for t_stop in (2345 * pq.ms, 2.345 * pq.s): np.random.seed(seed=123456) spiketrain = StationaryLogNormalProcess( - rate=rate, sigma=sigma, t_stop=t_stop, - equilibrium=False + rate=rate, sigma=sigma, t_stop=t_stop, equilibrium=False ).generate_spiketrain() intervals = isi(spiketrain) expected_spike_count = int((rate * t_stop).simplified) # should fail about 1 time in 1000 - self.assertLess( - pdiff(expected_spike_count, spiketrain.size), 0.25) + self.assertLess(pdiff(expected_spike_count, spiketrain.size), 0.25) expected_mean_isi = (1 / rate).rescale(pq.ms) - self.assertLess( - pdiff(expected_mean_isi, intervals.mean()), 0.3) + self.assertLess(pdiff(expected_mean_isi, intervals.mean()), 0.3) expected_first_spike = 0 * pq.ms self.assertLess( - spiketrain[0] - expected_first_spike, - 4 * expected_mean_isi) + spiketrain[0] - expected_first_spike, 4 * expected_mean_isi + ) expected_last_spike = t_stop - self.assertLess(expected_last_spike - - spiketrain[-1], 4 * expected_mean_isi) + self.assertLess( + expected_last_spike - spiketrain[-1], 4 * expected_mean_isi + ) # Kolmogorov-Smirnov test - D, p = kstest(intervals.rescale(t_stop.units), - "lognorm", - # args are (s, loc, scale) - args=(sigma, 0, - (1 / rate * np.exp(- sigma ** 2 / 2) - ).rescale(t_stop.units)), - alternative='two-sided') + D, p = kstest( + intervals.rescale(t_stop.units), + "lognorm", + # args are (s, loc, scale) + args=( + sigma, + 0, + (1 / rate * np.exp(-(sigma**2) / 2)).rescale(t_stop.units), + ), + alternative="two-sided", + ) self.assertGreater(p, 0.001) self.assertLess(D, 0.25) @@ -447,21 +519,19 @@ def test_compare_with_as_array(self): rate = 10 * pq.Hz np.random.seed(27) spiketrain = StationaryLogNormalProcess( - rate=rate, sigma=sigma, - equilibrium=False).generate_spiketrain() + rate=rate, sigma=sigma, equilibrium=False + ).generate_spiketrain() self.assertIsInstance(spiketrain, neo.SpikeTrain) np.random.seed(27) spiketrain_array = StationaryLogNormalProcess( rate=rate, sigma=sigma, equilibrium=False - ).generate_spiketrain( - as_array=True) + ).generate_spiketrain(as_array=True) # don't check with isinstance: pq.Quantity is a subclass of np.ndarray self.assertTrue(isinstance(spiketrain_array, np.ndarray)) assert_array_almost_equal(spiketrain.times.magnitude, spiketrain_array) class StationaryInverseGaussianProcessTestCase(unittest.TestCase): - def test_statistics(self): # This is a statistical test that has a non-zero chance of failure # during normal operation. 
Thus, we set the random seed to a value that @@ -479,30 +549,29 @@ def test_statistics(self): expected_spike_count = int((rate * t_stop).simplified) # should fail about 1 time in 1000 - self.assertLess( - pdiff(expected_spike_count, spiketrain.size), 0.25) + self.assertLess(pdiff(expected_spike_count, spiketrain.size), 0.25) expected_mean_isi = (1 / rate).rescale(pq.ms) - self.assertLess( - pdiff(expected_mean_isi, intervals.mean()), 0.3) + self.assertLess(pdiff(expected_mean_isi, intervals.mean()), 0.3) expected_first_spike = 0 * pq.ms self.assertLess( - spiketrain[0] - expected_first_spike, - 4 * expected_mean_isi) + spiketrain[0] - expected_first_spike, 4 * expected_mean_isi + ) expected_last_spike = t_stop - self.assertLess(expected_last_spike - - spiketrain[-1], 4 * expected_mean_isi) + self.assertLess( + expected_last_spike - spiketrain[-1], 4 * expected_mean_isi + ) # Kolmogorov-Smirnov test - D, p = kstest(intervals.rescale(t_stop.units), - "invgauss", - # args are (mu, loc, scale) - args=(cv ** 2, 0, - (1 / (rate * cv ** 2)).rescale( - t_stop.units)), - alternative='two-sided') + D, p = kstest( + intervals.rescale(t_stop.units), + "invgauss", + # args are (mu, loc, scale) + args=(cv**2, 0, (1 / (rate * cv**2)).rescale(t_stop.units)), + alternative="two-sided", + ) self.assertGreater(p, 0.001) self.assertLess(D, 0.25) @@ -511,12 +580,13 @@ def test_compare_with_as_array(self): rate = 10 * pq.Hz np.random.seed(27) spiketrain = StationaryInverseGaussianProcess( - rate=rate, cv=cv, equilibrium=False).generate_spiketrain() + rate=rate, cv=cv, equilibrium=False + ).generate_spiketrain() self.assertIsInstance(spiketrain, neo.SpikeTrain) np.random.seed(27) spiketrain_array = StationaryInverseGaussianProcess( - rate=rate, cv=cv, equilibrium=False).generate_spiketrain( - as_array=True) + rate=rate, cv=cv, equilibrium=False + ).generate_spiketrain(as_array=True) # don't check with isinstance: pq.Quantity is a subclass of np.ndarray self.assertTrue(isinstance(spiketrain_array, np.ndarray)) assert_array_almost_equal(spiketrain.times.magnitude, spiketrain_array) @@ -525,14 +595,14 @@ def test_compare_with_as_array(self): class FirstSpikeCvTestCase(unittest.TestCase): def setUp(self): np.random.seed(987654321) - self.rate = 100. * pq.Hz - self.t_stop = 10. * pq.s + self.rate = 100.0 * pq.Hz + self.t_stop = 10.0 * pq.s self.n_spiketrains = 10 # can only have CV equal to 1. 
self.poisson_process = StationaryPoissonProcess( - rate=self.rate, - t_stop=self.t_stop) + rate=self.rate, t_stop=self.t_stop + ) # choose all further processes to have CV of 1/2 # CV = 1 - rate * refractory_period @@ -540,126 +610,126 @@ def setUp(self): rate=self.rate, refractory_period=0.5 / self.rate, t_stop=self.t_stop, - equilibrium=False) + equilibrium=False, + ) - self.poisson_refractory_period_equilibrium = \ - StationaryPoissonProcess( - rate=self.rate, - refractory_period=0.5 / self.rate, - t_stop=self.t_stop, - equilibrium=True) + self.poisson_refractory_period_equilibrium = StationaryPoissonProcess( + rate=self.rate, + refractory_period=0.5 / self.rate, + t_stop=self.t_stop, + equilibrium=True, + ) # CV = 1 / sqrt(shape_factor) self.gamma_process_ordinary = StationaryGammaProcess( - rate=self.rate, - shape_factor=4, - t_stop=self.t_stop, - equilibrium=False) + rate=self.rate, shape_factor=4, t_stop=self.t_stop, equilibrium=False + ) self.gamma_process_equilibrium = StationaryGammaProcess( - rate=self.rate, - shape_factor=4, - t_stop=self.t_stop, - equilibrium=True) + rate=self.rate, shape_factor=4, t_stop=self.t_stop, equilibrium=True + ) # CV = sqrt(exp(sigma**2) - 1) self.log_normal_process_ordinary = StationaryLogNormalProcess( rate=self.rate, - sigma=np.sqrt(np.log(5. / 4.)), + sigma=np.sqrt(np.log(5.0 / 4.0)), t_stop=self.t_stop, - equilibrium=False) + equilibrium=False, + ) self.log_normal_process_equilibrium = StationaryLogNormalProcess( rate=self.rate, - sigma=np.sqrt(np.log(5. / 4.)), + sigma=np.sqrt(np.log(5.0 / 4.0)), t_stop=self.t_stop, - equilibrium=True) + equilibrium=True, + ) - self.inverse_gaussian_process_ordinary = \ - StationaryInverseGaussianProcess( - rate=self.rate, - cv=1 / 2, - t_stop=self.t_stop, - equilibrium=False) + self.inverse_gaussian_process_ordinary = StationaryInverseGaussianProcess( + rate=self.rate, cv=1 / 2, t_stop=self.t_stop, equilibrium=False + ) - self.inverse_gaussian_process_equilibrium = \ - StationaryInverseGaussianProcess( - rate=self.rate, - cv=1 / 2, - t_stop=self.t_stop, - equilibrium=True) + self.inverse_gaussian_process_equilibrium = StationaryInverseGaussianProcess( + rate=self.rate, cv=1 / 2, t_stop=self.t_stop, equilibrium=True + ) def test_cv(self): - processes = (self.poisson_process, - self.poisson_refractory_period_ordinary, - self.gamma_process_ordinary, - self.log_normal_process_ordinary, - self.inverse_gaussian_process_ordinary) + processes = ( + self.poisson_process, + self.poisson_refractory_period_ordinary, + self.gamma_process_ordinary, + self.log_normal_process_ordinary, + self.inverse_gaussian_process_ordinary, + ) for process in processes: if process is self.poisson_process: - self.assertAlmostEqual(1., process.expected_cv) + self.assertAlmostEqual(1.0, process.expected_cv) # test the general expected-cv function - self.assertAlmostEqual( - 1., super(type(process), process).expected_cv) + self.assertAlmostEqual(1.0, super(type(process), process).expected_cv) else: self.assertAlmostEqual(0.5, process.expected_cv) # test the general expected-cv function - self.assertAlmostEqual( - 0.5, super(type(process), process).expected_cv) + self.assertAlmostEqual(0.5, super(type(process), process).expected_cv) spiketrains = process.generate_n_spiketrains( - n_spiketrains=self.n_spiketrains, - as_array=True) + n_spiketrains=self.n_spiketrains, as_array=True + ) - cvs = [variation(np.diff(spiketrain)) - for spiketrain in spiketrains] + cvs = [variation(np.diff(spiketrain)) for spiketrain in spiketrains] mean_cv = 
np.mean(cvs) - assert_allclose( - process.expected_cv, mean_cv, atol=0.01) + assert_allclose(process.expected_cv, mean_cv, atol=0.01) def test_first_spike(self): - ordinary_processes = (self.poisson_refractory_period_ordinary, - self.gamma_process_ordinary, - self.log_normal_process_ordinary, - self.inverse_gaussian_process_ordinary) - equilibrium_processes = (self.poisson_refractory_period_equilibrium, - self.gamma_process_equilibrium, - self.log_normal_process_equilibrium, - self.inverse_gaussian_process_equilibrium) + ordinary_processes = ( + self.poisson_refractory_period_ordinary, + self.gamma_process_ordinary, + self.log_normal_process_ordinary, + self.inverse_gaussian_process_ordinary, + ) + equilibrium_processes = ( + self.poisson_refractory_period_equilibrium, + self.gamma_process_equilibrium, + self.log_normal_process_equilibrium, + self.inverse_gaussian_process_equilibrium, + ) for ordinary_process, equilibrium_process in zip( - ordinary_processes, equilibrium_processes): + ordinary_processes, equilibrium_processes + ): ordinary_spiketrains = ordinary_process.generate_n_spiketrains( - self.n_spiketrains) - equilbrium_spiketrains = \ - equilibrium_process.generate_n_spiketrains( - self.n_spiketrains) - first_spikes_ordinary = [spiketrain[0].item() - for spiketrain in ordinary_spiketrains] - first_spikes_equilibrium = \ - [spiketrain[0].item() - for spiketrain in equilbrium_spiketrains] + self.n_spiketrains + ) + equilbrium_spiketrains = equilibrium_process.generate_n_spiketrains( + self.n_spiketrains + ) + first_spikes_ordinary = [ + spiketrain[0].item() for spiketrain in ordinary_spiketrains + ] + first_spikes_equilibrium = [ + spiketrain[0].item() for spiketrain in equilbrium_spiketrains + ] mean_first_spike_ordinary = np.mean(first_spikes_ordinary) mean_first_spike_equilibrium = np.mean(first_spikes_equilibrium) # for regular spike trains (CV=0.5 here) the first spike # in equilibrium is on average earlier than in the ordinary case - self.assertLess(mean_first_spike_equilibrium, - mean_first_spike_ordinary) + self.assertLess(mean_first_spike_equilibrium, mean_first_spike_ordinary) class NonStationaryPoissonProcessTestCase(unittest.TestCase): def setUp(self): rate_list = [[20]] * 1000 + [[200]] * 1000 self.rate_profile = neo.AnalogSignal( - rate_list * pq.Hz, sampling_period=0.001 * pq.s) + rate_list * pq.Hz, sampling_period=0.001 * pq.s + ) rate_0 = [[0]] * 1000 self.rate_profile_0 = neo.AnalogSignal( - rate_0 * pq.Hz, sampling_period=0.001 * pq.s) + rate_0 * pq.Hz, sampling_period=0.001 * pq.s + ) rate_negative = [[-1]] * 1000 self.rate_profile_negative = neo.AnalogSignal( - rate_negative * pq.Hz, sampling_period=0.001 * pq.s) + rate_negative * pq.Hz, sampling_period=0.001 * pq.s + ) def test_statistics(self): # This is a statistical test that has a non-zero chance of failure @@ -669,26 +739,26 @@ def test_statistics(self): for refractory_period in (3 * pq.ms, None): np.random.seed(seed=12345) spiketrain_old = inhomogeneous_poisson_process( - rate, refractory_period=refractory_period) + rate, refractory_period=refractory_period + ) np.random.seed(seed=12345) process = NonStationaryPoissonProcess - spiketrain = process(rate, refractory_period=refractory_period - ).generate_spiketrain() + spiketrain = process( + rate, refractory_period=refractory_period + ).generate_spiketrain() - assert_allclose( - spiketrain_old.magnitude, spiketrain.magnitude) + assert_allclose(spiketrain_old.magnitude, spiketrain.magnitude) intervals = isi(spiketrain) # Computing expected statistics and percentiles 
- expected_spike_count = (np.sum( - rate) * rate.sampling_period).simplified - percentile_count = poisson.ppf(.999, expected_spike_count) - expected_min_isi = (1 / np.min(rate)) - expected_max_isi = (1 / np.max(rate)) - percentile_min_isi = expon.ppf(.999, expected_min_isi) - percentile_max_isi = expon.ppf(.999, expected_max_isi) + expected_spike_count = (np.sum(rate) * rate.sampling_period).simplified + percentile_count = poisson.ppf(0.999, expected_spike_count) + expected_min_isi = 1 / np.min(rate) + expected_max_isi = 1 / np.max(rate) + percentile_min_isi = expon.ppf(0.999, expected_min_isi) + percentile_max_isi = expon.ppf(0.999, expected_max_isi) # Check that minimal ISI is greater than the refractory_period if refractory_period is not None: @@ -704,44 +774,53 @@ def test_statistics(self): self.assertEqual(rate.t_start, spiketrain.t_start) # Testing type - spiketrain_as_array = NonStationaryPoissonProcess( - rate).generate_spiketrain(as_array=True) + spiketrain_as_array = NonStationaryPoissonProcess(rate).generate_spiketrain( + as_array=True + ) self.assertTrue(isinstance(spiketrain_as_array, np.ndarray)) self.assertTrue(isinstance(spiketrain, neo.SpikeTrain)) # Testing type for refractory period refractory_period = 3 * pq.ms spiketrain = NonStationaryPoissonProcess( - rate, refractory_period=refractory_period).generate_spiketrain() + rate, refractory_period=refractory_period + ).generate_spiketrain() spiketrain_as_array = NonStationaryPoissonProcess( - rate, refractory_period=refractory_period).generate_spiketrain( - as_array=True) + rate, refractory_period=refractory_period + ).generate_spiketrain(as_array=True) self.assertTrue(isinstance(spiketrain_as_array, np.ndarray)) self.assertTrue(isinstance(spiketrain, neo.SpikeTrain)) # Check that to high refractory period raises error self.assertRaises( - ValueError, NonStationaryPoissonProcess, + ValueError, + NonStationaryPoissonProcess, self.rate_profile, - refractory_period=1000 * pq.ms) + refractory_period=1000 * pq.ms, + ) def test_effective_rate_refractory_period(self): np.random.seed(27) rate_expected = 10 * pq.Hz refractory_period = 90 * pq.ms # 10 ms of effective ISI - rates = neo.AnalogSignal(np.repeat(rate_expected, 1000), units=pq.Hz, - t_start=0 * pq.ms, sampling_rate=1 * pq.Hz) + rates = neo.AnalogSignal( + np.repeat(rate_expected, 1000), + units=pq.Hz, + t_start=0 * pq.ms, + sampling_rate=1 * pq.Hz, + ) spiketrain = NonStationaryPoissonProcess( - rates, refractory_period=refractory_period).generate_spiketrain() + rates, refractory_period=refractory_period + ).generate_spiketrain() rate_obtained = len(spiketrain) / spiketrain.t_stop self.assertAlmostEqual( - rate_expected.simplified.item(), - rate_obtained.simplified.item(), places=1) + rate_expected.simplified.item(), rate_obtained.simplified.item(), places=1 + ) intervals_inhomo = isi(spiketrain) - isi_mean_expected = 1. 
/ rate_expected - self.assertAlmostEqual(isi_mean_expected.simplified, - intervals_inhomo.mean().simplified, - places=3) + isi_mean_expected = 1.0 / rate_expected + self.assertAlmostEqual( + isi_mean_expected.simplified, intervals_inhomo.mean().simplified, places=3 + ) def test_zero_rate(self): for refractory_period in (3 * pq.ms, None): @@ -751,29 +830,37 @@ def test_zero_rate(self): ).generate_spiketrain() self.assertEqual(spiketrain.size, 0) self.assertRaises( - ValueError, NonStationaryPoissonProcess, - self.rate_profile, refractory_period=5) + ValueError, + NonStationaryPoissonProcess, + self.rate_profile, + refractory_period=5, + ) def test_negative_rates(self): for refractory_period in (3 * pq.ms, None): process = NonStationaryPoissonProcess self.assertRaises( - ValueError, process, + ValueError, + process, self.rate_profile_negative, - refractory_period=refractory_period) + refractory_period=refractory_period, + ) class NonStationaryGammaTestCase(unittest.TestCase): def setUp(self): rate_list = [[20]] * 1000 + [[200]] * 1000 self.rate_profile = neo.AnalogSignal( - rate_list * pq.Hz, sampling_period=0.001 * pq.s) + rate_list * pq.Hz, sampling_period=0.001 * pq.s + ) rate_0 = [[0]] * 1000 self.rate_profile_0 = neo.AnalogSignal( - rate_0 * pq.Hz, sampling_period=0.001 * pq.s) + rate_0 * pq.Hz, sampling_period=0.001 * pq.s + ) rate_negative = [[-1]] * 1000 self.rate_profile_negative = neo.AnalogSignal( - rate_negative * pq.Hz, sampling_period=0.001 * pq.s) + rate_negative * pq.Hz, sampling_period=0.001 * pq.s + ) def test_statistics(self): # This is a statistical test that has a non-zero chance of failure @@ -784,22 +871,23 @@ def test_statistics(self): for rate in [self.rate_profile, self.rate_profile.rescale(pq.kHz)]: np.random.seed(seed=12345) spiketrain_old = inhomogeneous_gamma_process( - rate, shape_factor=shape_factor) + rate, shape_factor=shape_factor + ) np.random.seed(seed=12345) spiketrain = NonStationaryGammaProcess( - rate, shape_factor=shape_factor).generate_spiketrain() + rate, shape_factor=shape_factor + ).generate_spiketrain() assert_allclose(spiketrain_old.magnitude, spiketrain.magnitude) intervals = isi(spiketrain) # Computing expected statistics and percentiles - expected_spike_count = (np.sum( - rate) * rate.sampling_period).simplified - percentile_count = poisson.ppf(.999, expected_spike_count) - expected_min_isi = (1 / np.min(rate)) - expected_max_isi = (1 / np.max(rate)) - percentile_min_isi = expon.ppf(.999, expected_min_isi) - percentile_max_isi = expon.ppf(.999, expected_max_isi) + expected_spike_count = (np.sum(rate) * rate.sampling_period).simplified + percentile_count = poisson.ppf(0.999, expected_spike_count) + expected_min_isi = 1 / np.min(rate) + expected_max_isi = 1 / np.max(rate) + percentile_min_isi = expon.ppf(0.999, expected_min_isi) + percentile_max_isi = expon.ppf(0.999, expected_max_isi) # Testing (each should fail 1 every 1000 times) self.assertLess(spiketrain.size, percentile_count) @@ -812,30 +900,36 @@ def test_statistics(self): # Testing type spiketrain_as_array = NonStationaryGammaProcess( - rate, shape_factor=shape_factor).generate_spiketrain( - as_array=True) + rate, shape_factor=shape_factor + ).generate_spiketrain(as_array=True) self.assertTrue(isinstance(spiketrain_as_array, np.ndarray)) self.assertTrue(isinstance(spiketrain, neo.SpikeTrain)) # check error if rate has wrong format self.assertRaises( - ValueError, NonStationaryGammaProcess, - rate_signal=[0.1, 2.], - shape_factor=shape_factor) + ValueError, + NonStationaryGammaProcess, 
+ rate_signal=[0.1, 2.0], + shape_factor=shape_factor, + ) # check error if negative values in rate self.assertRaises( - ValueError, NonStationaryGammaProcess, + ValueError, + NonStationaryGammaProcess, rate_signal=neo.AnalogSignal( - [-0.1, 10.] * pq.Hz, sampling_period=0.001 * pq.s), - shape_factor=shape_factor) + [-0.1, 10.0] * pq.Hz, sampling_period=0.001 * pq.s + ), + shape_factor=shape_factor, + ) # check error if rate is empty self.assertRaises( - ValueError, NonStationaryGammaProcess, - rate_signal=neo.AnalogSignal( - [] * pq.Hz, sampling_period=0.001 * pq.s), - shape_factor=shape_factor) + ValueError, + NonStationaryGammaProcess, + rate_signal=neo.AnalogSignal([] * pq.Hz, sampling_period=0.001 * pq.s), + shape_factor=shape_factor, + ) def test_recovered_firing_rate_profile(self): np.random.seed(54) @@ -844,40 +938,47 @@ def test_recovered_firing_rate_profile(self): sampling_period = 0.001 * pq.s # an arbitrary rate profile - profile = 0.5 * (1 + np.sin(np.arange(t_start.item(), t_stop.item(), - sampling_period.item()))) + profile = 0.5 * ( + 1 + np.sin(np.arange(t_start.item(), t_stop.item(), sampling_period.item())) + ) n_trials = 100 rtol = 0.1 # 10% of deviation allowed kernel = kernels.RectangularKernel(sigma=0.25 * pq.s) for rate in (10 * pq.Hz, 50 * pq.Hz): - rate_profile = neo.AnalogSignal(rate * profile, - sampling_period=sampling_period) + rate_profile = neo.AnalogSignal( + rate * profile, sampling_period=sampling_period + ) # the recovered firing rate profile should not depend on the # shape factor; here we test float and integer values of the shape # factor: the method supports float values that is not trivial # for inhomogeneous gamma process generation - for shape_factor in (2.5, 10.): + for shape_factor in (2.5, 10.0): spiketrains = NonStationaryGammaProcess( rate_profile, shape_factor=shape_factor ).generate_n_spiketrains(n_trials) - rate_recovered = instantaneous_rate( - spiketrains, - sampling_period=sampling_period, - kernel=kernel, - t_start=t_start, - t_stop=t_stop, trim=True).sum(axis=1) / n_trials + rate_recovered = ( + instantaneous_rate( + spiketrains, + sampling_period=sampling_period, + kernel=kernel, + t_start=t_start, + t_stop=t_stop, + trim=True, + ).sum(axis=1) + / n_trials + ) rate_recovered = rate_recovered.flatten().magnitude trim = (rate_profile.shape[0] - rate_recovered.shape[0]) // 2 rate_profile_valid = rate_profile.magnitude.squeeze() - rate_profile_valid = rate_profile_valid[trim: -trim] - assert_allclose(rate_recovered, rate_profile_valid, - rtol=0, atol=rtol * rate.item()) + rate_profile_valid = rate_profile_valid[trim:-trim] + assert_allclose( + rate_recovered, rate_profile_valid, rtol=0, atol=rtol * rate.item() + ) class NPoissonTestCase(unittest.TestCase): - def setUp(self): self.n = 4 self.rate = 10 * pq.Hz @@ -886,10 +987,7 @@ def setUp(self): def test_poisson(self): # Check the output types for input rate + n number of neurons - pp = _n_poisson( - rate=self.rate, - t_stop=self.t_stop, - n_spiketrains=self.n) + pp = _n_poisson(rate=self.rate, t_stop=self.t_stop, n_spiketrains=self.n) self.assertIsInstance(pp, list) self.assertIsInstance(pp[0], neo.core.spiketrain.SpikeTrain) self.assertEqual(pp[0].simplified.units, 1000 * pq.ms) @@ -904,23 +1002,26 @@ def test_poisson(self): def test_poisson_error(self): # Dimensionless rate - self.assertRaises( - ValueError, _n_poisson, rate=5, t_stop=self.t_stop) + self.assertRaises(ValueError, _n_poisson, rate=5, t_stop=self.t_stop) # Negative rate - self.assertRaises( - ValueError, _n_poisson, 
rate=-5 * pq.Hz, t_stop=self.t_stop) + self.assertRaises(ValueError, _n_poisson, rate=-5 * pq.Hz, t_stop=self.t_stop) # Negative value when rate is a list self.assertRaises( - ValueError, _n_poisson, rate=[-5, 3] * pq.Hz, - t_stop=self.t_stop) + ValueError, _n_poisson, rate=[-5, 3] * pq.Hz, t_stop=self.t_stop + ) # Negative n self.assertRaises( - ValueError, _n_poisson, rate=self.rate, t_stop=self.t_stop, - n_spiketrains=-1) + ValueError, _n_poisson, rate=self.rate, t_stop=self.t_stop, n_spiketrains=-1 + ) # t_start>t_stop self.assertRaises( - ValueError, _n_poisson, rate=self.rate, t_start=4 * pq.ms, - t_stop=3 * pq.ms, n_spiketrains=3) + ValueError, + _n_poisson, + rate=self.rate, + t_start=4 * pq.ms, + t_stop=3 * pq.ms, + n_spiketrains=3, + ) class SingleInteractionProcessTestCase(unittest.TestCase): @@ -944,26 +1045,34 @@ def format_check(self, sip, coinc): def test_sip(self): # Generate an example SIP mode sip, coinc = single_interaction_process( - n_spiketrains=self.n, t_stop=self.t_stop, rate=self.rate, - coincidence_rate=self.rate_c, return_coincidences=True) + n_spiketrains=self.n, + t_stop=self.t_stop, + rate=self.rate, + coincidence_rate=self.rate_c, + return_coincidences=True, + ) # Check the output types self.format_check(sip, coinc) self.assertEqual( - len(coinc[0]), (self.rate_c * self.t_stop).simplified.magnitude) + len(coinc[0]), (self.rate_c * self.t_stop).simplified.magnitude + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") # Generate an example SIP mode giving a list of rates as imput sip, coinc = single_interaction_process( - t_stop=self.t_stop, rate=self.rates, - coincidence_rate=self.rate_c, return_coincidences=True) + t_stop=self.t_stop, + rate=self.rates, + coincidence_rate=self.rate_c, + return_coincidences=True, + ) # Check the output types self.format_check(sip, coinc) self.assertEqual( - len(coinc[0]), - (self.rate_c * self.t_stop).rescale(pq.dimensionless)) + len(coinc[0]), (self.rate_c * self.t_stop).rescale(pq.dimensionless) + ) # Generate an example SIP mode stochastic number of coincidences sip = single_interaction_process( @@ -971,8 +1080,9 @@ def test_sip(self): t_stop=self.t_stop, rate=self.rate, coincidence_rate=self.rate_c, - coincidences='stochastic', - return_coincidences=False) + coincidences="stochastic", + return_coincidences=False, + ) # Check the output types self.assertEqual(type(sip), list) @@ -982,22 +1092,40 @@ def test_sip(self): def test_sip_error(self): # Negative rate self.assertRaises( - ValueError, single_interaction_process, n_spiketrains=self.n, + ValueError, + single_interaction_process, + n_spiketrains=self.n, rate=-5 * pq.Hz, - coincidence_rate=self.rate_c, t_stop=self.t_stop) + coincidence_rate=self.rate_c, + t_stop=self.t_stop, + ) # Negative coincidence rate self.assertRaises( - ValueError, single_interaction_process, n_spiketrains=self.n, - rate=self.rate, coincidence_rate=-3 * pq.Hz, t_stop=self.t_stop) + ValueError, + single_interaction_process, + n_spiketrains=self.n, + rate=self.rate, + coincidence_rate=-3 * pq.Hz, + t_stop=self.t_stop, + ) # Negative value when rate is a list self.assertRaises( - ValueError, single_interaction_process, n_spiketrains=self.n, - rate=[-5, 3, 4, 2] * pq.Hz, coincidence_rate=self.rate_c, - t_stop=self.t_stop) + ValueError, + single_interaction_process, + n_spiketrains=self.n, + rate=[-5, 3, 4, 2] * pq.Hz, + coincidence_rate=self.rate_c, + t_stop=self.t_stop, + ) # Negative n self.assertRaises( - ValueError, single_interaction_process, n_spiketrains=-1, - rate=self.rate, 
coincidence_rate=self.rate_c, t_stop=self.t_stop) + ValueError, + single_interaction_process, + n_spiketrains=-1, + rate=self.rate, + coincidence_rate=self.rate_c, + t_stop=self.t_stop, + ) # Rate_c < rate self.assertRaises( ValueError, @@ -1005,21 +1133,19 @@ def test_sip_error(self): n_spiketrains=self.n, rate=self.rate, coincidence_rate=self.rate + 1 * pq.Hz, - t_stop=self.t_stop) + t_stop=self.t_stop, + ) class CppTestCase(unittest.TestCase): - def format_check(self, cpp, amplitude_distribution, t_start, t_stop): - self.assertEqual( - [type(train) for train in cpp], - [neo.SpikeTrain] * len(cpp)) + self.assertEqual([type(train) for train in cpp], [neo.SpikeTrain] * len(cpp)) self.assertEqual(cpp[0].simplified.units, 1000 * pq.ms) self.assertEqual(type(cpp), list) # testing quantities format of the output self.assertEqual( - [train.simplified.units for train in cpp], - [1000 * pq.ms] * len(cpp)) + [train.simplified.units for train in cpp], [1000 * pq.ms] * len(cpp) + ) # testing output t_start t_stop for spiketrain in cpp: self.assertEqual(spiketrain.t_stop, t_stop) @@ -1028,12 +1154,11 @@ def format_check(self, cpp, amplitude_distribution, t_start, t_stop): def test_cpp_hom(self): # testing output with generic inputs - amplitude_distribution = np.array([0, .9, .1]) + amplitude_distribution = np.array([0, 0.9, 0.1]) t_stop = 10 * 1000 * pq.ms t_start = 5 * 1000 * pq.ms rate = 3 * pq.Hz - cpp_hom = cpp(rate, amplitude_distribution, - t_stop, t_start=t_start) + cpp_hom = cpp(rate, amplitude_distribution, t_stop, t_start=t_start) # testing the output formats self.format_check(cpp_hom, amplitude_distribution, t_start, t_stop) @@ -1041,8 +1166,7 @@ def test_cpp_hom(self): t_stop = 10000 * pq.ms t_start = 5 * 1000 * pq.ms rate = 3 * pq.Hz - cpp_unit = cpp(rate, amplitude_distribution, - t_stop, t_start=t_start) + cpp_unit = cpp(rate, amplitude_distribution, t_stop, t_start=t_start) self.assertEqual(cpp_unit[0].units, t_stop.units) self.assertEqual(cpp_unit[0].t_stop.units, t_stop.units) @@ -1053,33 +1177,30 @@ def test_cpp_hom(self): t_stop = 10 * 1000 * pq.ms t_start = 5 * 1000 * pq.ms rate = 3 * pq.Hz - cpp_hom_empty = cpp( - rate, amplitude_distribution, t_stop, t_start=t_start) + cpp_hom_empty = cpp(rate, amplitude_distribution, t_stop, t_start=t_start) self.assertEqual( - [len(train) for train in cpp_hom_empty], [0] * len(cpp_hom_empty)) + [len(train) for train in cpp_hom_empty], [0] * len(cpp_hom_empty) + ) # testing output with rate equal to 0 - amplitude_distribution = np.array([0, .9, .1]) + amplitude_distribution = np.array([0, 0.9, 0.1]) t_stop = 10 * 1000 * pq.ms t_start = 5 * 1000 * pq.ms rate = 0 * pq.Hz - cpp_hom_empty_r = cpp( - rate, amplitude_distribution, t_stop, t_start=t_start) + cpp_hom_empty_r = cpp(rate, amplitude_distribution, t_stop, t_start=t_start) self.assertEqual( - [len(train) for train in cpp_hom_empty_r], [0] * len( - cpp_hom_empty_r)) + [len(train) for train in cpp_hom_empty_r], [0] * len(cpp_hom_empty_r) + ) # testing output with same spike trains in output - amplitude_distribution = np.array([0., 0., 1.]) + amplitude_distribution = np.array([0.0, 0.0, 1.0]) t_stop = 10 * 1000 * pq.ms t_start = 5 * 1000 * pq.ms rate = 3 * pq.Hz - cpp_hom_eq = cpp( - rate, amplitude_distribution, t_stop, t_start=t_start) + cpp_hom_eq = cpp(rate, amplitude_distribution, t_stop, t_start=t_start) - self.assertTrue( - np.allclose(cpp_hom_eq[0].magnitude, cpp_hom_eq[1].magnitude)) + self.assertTrue(np.allclose(cpp_hom_eq[0].magnitude, cpp_hom_eq[1].magnitude)) def 
test_cpp_hom_errors(self): # testing raises of ValueError (wrong inputs) @@ -1089,7 +1210,8 @@ def test_cpp_hom_errors(self): cpp, amplitude_distribution=[], t_stop=10 * 1000 * pq.ms, - rate=3 * pq.Hz) + rate=3 * pq.Hz, + ) # testing sum of amplitude>1 self.assertRaises( @@ -1097,34 +1219,42 @@ def test_cpp_hom_errors(self): cpp, amplitude_distribution=[1, 1, 1], t_stop=10 * 1000 * pq.ms, - rate=3 * pq.Hz) + rate=3 * pq.Hz, + ) # testing negative value in the amplitude self.assertRaises( - ValueError, cpp, amplitude_distribution=[-1, 1, 1], + ValueError, + cpp, + amplitude_distribution=[-1, 1, 1], t_stop=10 * 1000 * pq.ms, - rate=3 * pq.Hz) + rate=3 * pq.Hz, + ) # test negative rate with warnings.catch_warnings(): warnings.simplefilter("ignore") # Catches RuntimeWarning: invalid value encountered in sqrt # number = np.ceil(n + 3 * np.sqrt(n)), when `n` == -3 Hz. self.assertRaises( - ValueError, cpp, amplitude_distribution=[0, 1, 0], + ValueError, + cpp, + amplitude_distribution=[0, 1, 0], t_stop=10 * 1000 * pq.ms, - rate=-3 * pq.Hz) + rate=-3 * pq.Hz, + ) # test wrong unit for rate self.assertRaises( ValueError, cpp, amplitude_distribution=[0, 1, 0], t_stop=10 * 1000 * pq.ms, - rate=3 * 1000 * pq.ms) + rate=3 * 1000 * pq.ms, + ) # testing raises of AttributeError (missing input units) # Testing missing unit to t_stop self.assertRaises( - ValueError, cpp, amplitude_distribution=[0, 1, 0], t_stop=10, - rate=3 * pq.Hz) + ValueError, cpp, amplitude_distribution=[0, 1, 0], t_stop=10, rate=3 * pq.Hz + ) # Testing missing unit to t_start self.assertRaises( ValueError, @@ -1132,16 +1262,20 @@ def test_cpp_hom_errors(self): amplitude_distribution=[0, 1, 0], t_stop=10 * 1000 * pq.ms, rate=3 * pq.Hz, - t_start=3) + t_start=3, + ) # testing rate missing unit self.assertRaises( - AttributeError, cpp, amplitude_distribution=[0, 1, 0], + AttributeError, + cpp, + amplitude_distribution=[0, 1, 0], t_stop=10 * 1000 * pq.ms, - rate=3) + rate=3, + ) def test_cpp_het(self): # testing output with generic inputs - amplitude_distribution = np.array([0, .9, .1]) + amplitude_distribution = np.array([0, 0.9, 0.1]) t_stop = 10 * 1000 * pq.ms t_start = 5 * 1000 * pq.ms rate = [3, 4] * pq.Hz @@ -1149,8 +1283,7 @@ def test_cpp_het(self): warnings.simplefilter("ignore") # Catch RuntimeWarning: divide by zero encountered in true_divide # mean_interval = 1 / rate.magnitude, when rate == 0 Hz. 
- cpp_het = cpp(rate, amplitude_distribution, - t_stop, t_start=t_start) + cpp_het = cpp(rate, amplitude_distribution, t_stop, t_start=t_start) # testing the output formats self.format_check(cpp_het, amplitude_distribution, t_start, t_stop) self.assertEqual(len(cpp_het), len(rate)) @@ -1159,8 +1292,7 @@ def test_cpp_het(self): t_stop = 10000 * pq.ms t_start = 5 * 1000 * pq.ms rate = [3, 4] * pq.Hz - cpp_unit = cpp( - rate, amplitude_distribution, t_stop, t_start=t_start) + cpp_unit = cpp(rate, amplitude_distribution, t_stop, t_start=t_start) self.assertEqual(cpp_unit[0].units, t_stop.units) self.assertEqual(cpp_unit[0].t_stop.units, t_stop.units) @@ -1170,32 +1302,30 @@ def test_cpp_het(self): t_stop = 10 * 1000 * pq.ms t_start = 5 * 1000 * pq.ms rate = [3, 4] * pq.Hz - cpp_het_empty = cpp( - rate, amplitude_distribution, t_stop, t_start=t_start) + cpp_het_empty = cpp(rate, amplitude_distribution, t_stop, t_start=t_start) self.assertEqual(len(cpp_het_empty[0]), 0) # testing output with rate equal to 0 - amplitude_distribution = np.array([0, .9, .1]) + amplitude_distribution = np.array([0, 0.9, 0.1]) t_stop = 10 * 1000 * pq.ms t_start = 5 * 1000 * pq.ms rate = [0, 0] * pq.Hz - cpp_het_empty_r = cpp( - rate, amplitude_distribution, t_stop, t_start=t_start) + cpp_het_empty_r = cpp(rate, amplitude_distribution, t_stop, t_start=t_start) self.assertEqual( - [len(train) for train in cpp_het_empty_r], [0] * len( - cpp_het_empty_r)) + [len(train) for train in cpp_het_empty_r], [0] * len(cpp_het_empty_r) + ) # testing completely synchronous spike trains amplitude_distribution = np.array([0, 0, 1]) t_stop = 10 * 1000 * pq.ms t_start = 5 * 1000 * pq.ms rate = [3, 3] * pq.Hz - cpp_het_eq = cpp( - rate, amplitude_distribution, t_stop, t_start=t_start) + cpp_het_eq = cpp(rate, amplitude_distribution, t_stop, t_start=t_start) - self.assertTrue(np.allclose( - cpp_het_eq[0].magnitude, cpp_het_eq[1].magnitude)) + self.assertTrue( + np.allclose(cpp_het_eq[0].magnitude, cpp_het_eq[1].magnitude) + ) def test_cpp_het_err(self): # testing raises of ValueError (wrong inputs) @@ -1205,50 +1335,64 @@ def test_cpp_het_err(self): cpp, amplitude_distribution=[], t_stop=10 * 1000 * pq.ms, - rate=[3, 4] * pq.Hz) + rate=[3, 4] * pq.Hz, + ) # testing sum amplitude>1 self.assertRaises( ValueError, cpp, amplitude_distribution=[1, 1, 1], t_stop=10 * 1000 * pq.ms, - rate=[3, 4] * pq.Hz) + rate=[3, 4] * pq.Hz, + ) # testing amplitude negative value self.assertRaises( - ValueError, cpp, amplitude_distribution=[-1, 1, 1], + ValueError, + cpp, + amplitude_distribution=[-1, 1, 1], t_stop=10 * 1000 * pq.ms, - rate=[3, 4] * pq.Hz) + rate=[3, 4] * pq.Hz, + ) # testing negative rate - self.assertRaises(ValueError, cpp, amplitude_distribution=[ - 0, 1, 0], t_stop=10 * 1000 * pq.ms, rate=[-3, 4] * pq.Hz) + self.assertRaises( + ValueError, + cpp, + amplitude_distribution=[0, 1, 0], + t_stop=10 * 1000 * pq.ms, + rate=[-3, 4] * pq.Hz, + ) # testing empty rate self.assertRaises( ValueError, cpp, amplitude_distribution=[0, 1, 0], t_stop=10 * 1000 * pq.ms, - rate=[] * pq.Hz) + rate=[] * pq.Hz, + ) # testing empty amplitude self.assertRaises( ValueError, cpp, amplitude_distribution=[], t_stop=10 * 1000 * pq.ms, - rate=[3, 4] * pq.Hz) + rate=[3, 4] * pq.Hz, + ) # testing different len(A)-1 and len(rate) self.assertRaises( ValueError, cpp, amplitude_distribution=[0, 1], t_stop=10 * 1000 * pq.ms, - rate=[3, 4] * pq.Hz) + rate=[3, 4] * pq.Hz, + ) # testing rate with different unit from Hz self.assertRaises( ValueError, cpp, 
amplitude_distribution=[0, 1], t_stop=10 * 1000 * pq.ms, - rate=[3, 4] * 1000 * pq.ms) + rate=[3, 4] * 1000 * pq.ms, + ) # Testing analytical constrain between amplitude and rate self.assertRaises( ValueError, @@ -1256,13 +1400,18 @@ def test_cpp_het_err(self): amplitude_distribution=[0, 0, 1], t_stop=10 * 1000 * pq.ms, rate=[3, 4] * pq.Hz, - t_start=3) + t_start=3, + ) # testing raises of AttributeError (missing input units) # Testing missing unit to t_stop self.assertRaises( - ValueError, cpp, amplitude_distribution=[0, 1, 0], t_stop=10, - rate=[3, 4] * pq.Hz) + ValueError, + cpp, + amplitude_distribution=[0, 1, 0], + t_stop=10, + rate=[3, 4] * pq.Hz, + ) # Testing missing unit to t_start self.assertRaises( ValueError, @@ -1270,28 +1419,29 @@ def test_cpp_het_err(self): amplitude_distribution=[0, 1, 0], t_stop=10 * 1000 * pq.ms, rate=[3, 4] * pq.Hz, - t_start=3) + t_start=3, + ) # Testing missing unit to rate self.assertRaises( - AttributeError, cpp, amplitude_distribution=[0, 1, 0], + AttributeError, + cpp, + amplitude_distribution=[0, 1, 0], t_stop=10 * 1000 * pq.ms, - rate=[3, 4]) + rate=[3, 4], + ) def test_cpp_jttered(self): # testing output with generic inputs - amplitude_distribution = np.array([0, .9, .1]) + amplitude_distribution = np.array([0, 0.9, 0.1]) t_stop = 10 * 1000 * pq.ms t_start = 5 * 1000 * pq.ms rate = 3 * pq.Hz cpp_shift = cpp( - rate, - amplitude_distribution, - t_stop, - t_start=t_start, - shift=3 * pq.ms) + rate, amplitude_distribution, t_stop, t_start=t_start, shift=3 * pq.ms + ) # testing the output formats self.format_check(cpp_shift, amplitude_distribution, t_start, t_stop) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_spike_train_surrogates.py b/elephant/test/test_spike_train_surrogates.py index a518bdc4f..36f864ba4 100644 --- a/elephant/test/test_spike_train_surrogates.py +++ b/elephant/test/test_spike_train_surrogates.py @@ -19,7 +19,6 @@ class SurrogatesTestCase(unittest.TestCase): - def setUp(self): np.random.seed(0) random.seed(0) @@ -30,11 +29,12 @@ def setUpClass(cls) -> None: cls.st1 = st1 def test_dither_spikes_output_format(self): - self.st1.t_stop = .5 * pq.s + self.st1.t_stop = 0.5 * pq.s n_surrogates = 2 dither = 10 * pq.ms surrogate_trains = surr.dither_spikes( - self.st1, dither=dither, n_surrogates=n_surrogates) + self.st1, dither=dither, n_surrogates=n_surrogates + ) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) @@ -45,15 +45,13 @@ def test_dither_spikes_output_format(self): self.assertEqual(surrogate_train.t_start, self.st1.t_start) self.assertEqual(surrogate_train.t_stop, self.st1.t_stop) self.assertEqual(len(surrogate_train), len(self.st1)) - assert_array_less(0., np.diff(surrogate_train)) # check ordering + assert_array_less(0.0, np.diff(surrogate_train)) # check ordering def test_dither_spikes_empty_train(self): - st = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms) dither = 10 * pq.ms - surrogate_train = surr.dither_spikes( - st, dither=dither, n_surrogates=1)[0] + surrogate_train = surr.dither_spikes(st, dither=dither, n_surrogates=1)[0] self.assertEqual(len(surrogate_train), 0) def test_dither_spikes_refactory_period_zero_or_none(self): @@ -63,22 +61,31 @@ def test_dither_spikes_refactory_period_zero_or_none(self): np.random.seed(42) surrogate_trains_zero = surr.dither_spikes( - self.st1, dither, decimals=decimals, n_surrogates=n_surrogates, - refractory_period=0) + self.st1, + dither, + decimals=decimals, + 
n_surrogates=n_surrogates, + refractory_period=0, + ) np.random.seed(42) surrogate_trains_none = surr.dither_spikes( - self.st1, dither, decimals=decimals, n_surrogates=n_surrogates, - refractory_period=None) + self.st1, + dither, + decimals=decimals, + n_surrogates=n_surrogates, + refractory_period=None, + ) np.testing.assert_array_almost_equal( - surrogate_trains_zero[0].magnitude, - surrogate_trains_none[0].magnitude) + surrogate_trains_zero[0].magnitude, surrogate_trains_none[0].magnitude + ) def test_dither_spikes_output_decimals(self): n_surrogates = 2 dither = 10 * pq.ms np.random.seed(42) surrogate_trains = surr.dither_spikes( - self.st1, dither=dither, decimals=3, n_surrogates=n_surrogates) + self.st1, dither=dither, decimals=3, n_surrogates=n_surrogates + ) np.random.seed(42) dither_values = np.random.random_sample((n_surrogates, len(self.st1))) @@ -87,8 +94,10 @@ def test_dither_spikes_output_decimals(self): observed_non_dithered = 0 for surrogate_train in surrogate_trains: for i in range(len(surrogate_train)): - if surrogate_train[i] - int(surrogate_train[i]) * \ - pq.ms == surrogate_train[i] - surrogate_train[i]: + if ( + surrogate_train[i] - int(surrogate_train[i]) * pq.ms + == surrogate_train[i] - surrogate_train[i] + ): observed_non_dithered += 1 self.assertEqual(observed_non_dithered, expected_non_dithered) @@ -97,21 +106,25 @@ def test_dither_spikes_false_edges(self): n_surrogates = 2 dither = 10 * pq.ms surrogate_trains = surr.dither_spikes( - self.st1, dither=dither, n_surrogates=n_surrogates, edges=False) + self.st1, dither=dither, n_surrogates=n_surrogates, edges=False + ) for surrogate_train in surrogate_trains: for i in range(len(surrogate_train)): self.assertLessEqual(surrogate_train[i], self.st1.t_stop) def test_dither_spikes_with_refractory_period_output_format(self): - - spiketrain = neo.SpikeTrain([90, 93, 97, 100, 105, - 150, 180, 350] * pq.ms, t_stop=.5 * pq.s) + spiketrain = neo.SpikeTrain( + [90, 93, 97, 100, 105, 150, 180, 350] * pq.ms, t_stop=0.5 * pq.s + ) n_surrogates = 2 dither = 10 * pq.ms surrogate_trains = surr.dither_spikes( - spiketrain, dither=dither, n_surrogates=n_surrogates, - refractory_period=4 * pq.ms) + spiketrain, + dither=dither, + n_surrogates=n_surrogates, + refractory_period=4 * pq.ms, + ) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) @@ -123,25 +136,30 @@ def test_dither_spikes_with_refractory_period_output_format(self): self.assertEqual(surrogate_train.t_stop, spiketrain.t_stop) self.assertEqual(len(surrogate_train), len(spiketrain)) # Check that refractory period is conserved - self.assertLessEqual(np.min(np.diff(spiketrain)), - np.min(np.diff(surrogate_train))) + self.assertLessEqual( + np.min(np.diff(spiketrain)), np.min(np.diff(surrogate_train)) + ) sigma_displacement = np.std(surrogate_train - spiketrain) # Check that spikes are moved self.assertLessEqual(dither / 10, sigma_displacement) # Spikes are not moved more than dither self.assertLessEqual(sigma_displacement, dither) - self.assertRaises(ValueError, surr.dither_spikes, - spiketrain, dither=dither, refractory_period=3) + self.assertRaises( + ValueError, + surr.dither_spikes, + spiketrain, + dither=dither, + refractory_period=3, + ) def test_dither_spikes_with_refractory_period_empty_train(self): - spiketrain = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms) dither = 10 * pq.ms surrogate_train = surr.dither_spikes( - spiketrain, dither=dither, n_surrogates=1, - refractory_period=4 * pq.ms)[0] + spiketrain, dither=dither, 
n_surrogates=1, refractory_period=4 * pq.ms + )[0] self.assertEqual(len(surrogate_train), 0) def test_dither_spikes_regression_issue_586(self): @@ -158,14 +176,17 @@ def test_dither_spikes_regression_issue_586(self): # Generate one spiketrain with a spike close to t_stop t_stop = 2 * pq.s st = stg.StationaryPoissonProcess( - rate=10 * pq.Hz, t_stop=t_stop).generate_spiketrain() - st = neo.SpikeTrain(np.hstack([st.magnitude, [1.9999999]]), - units=st.units, t_stop=t_stop) + rate=10 * pq.Hz, t_stop=t_stop + ).generate_spiketrain() + st = neo.SpikeTrain( + np.hstack([st.magnitude, [1.9999999]]), units=st.units, t_stop=t_stop + ) # Dither np.random.seed(5) surrogate_trains = surr.dither_spikes( - st, dither=15 * pq.ms, n_surrogates=30, edges=True, decimals=2) + st, dither=15 * pq.ms, n_surrogates=30, edges=True, decimals=2 + ) for surrogate in surrogate_trains: with self.subTest(surrogate): self.assertLess(surrogate[-1], surrogate.t_stop) @@ -173,8 +194,7 @@ def test_dither_spikes_regression_issue_586(self): def test_randomise_spikes_output_format(self): n_surrogates = 2 - surrogate_trains = surr.randomise_spikes( - self.st1, n_surrogates=n_surrogates) + surrogate_trains = surr.randomise_spikes(self.st1, n_surrogates=n_surrogates) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) @@ -187,7 +207,6 @@ def test_randomise_spikes_output_format(self): self.assertEqual(len(surrogate_train), len(self.st1)) def test_randomise_spikes_empty_train(self): - spiketrain = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms) surrogate_train = surr.randomise_spikes(spiketrain, n_surrogates=1)[0] @@ -196,19 +215,19 @@ def test_randomise_spikes_empty_train(self): def test_randomise_spikes_output_decimals(self): n_surrogates = 2 surrogate_trains = surr.randomise_spikes( - self.st1, n_surrogates=n_surrogates, decimals=3) + self.st1, n_surrogates=n_surrogates, decimals=3 + ) for surrogate_train in surrogate_trains: for i in range(len(surrogate_train)): - self.assertNotEqual(surrogate_train[i] - - int(surrogate_train[i]) * - pq.ms, surrogate_train[i] - - surrogate_train[i]) + self.assertNotEqual( + surrogate_train[i] - int(surrogate_train[i]) * pq.ms, + surrogate_train[i] - surrogate_train[i], + ) def test_shuffle_isis_output_format(self): n_surrogates = 2 - surrogate_trains = surr.shuffle_isis( - self.st1, n_surrogates=n_surrogates) + surrogate_trains = surr.shuffle_isis(self.st1, n_surrogates=n_surrogates) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) @@ -221,7 +240,6 @@ def test_shuffle_isis_output_format(self): self.assertEqual(len(surrogate_train), len(self.st1)) def test_shuffle_isis_empty_train(self): - spiketrain = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms) surrogate_train = surr.shuffle_isis(spiketrain, n_surrogates=1)[0] @@ -242,8 +260,7 @@ def test_shuffle_isis_same_isis(self): self.assertTrue(np.all(ISIs_orig == ISIs_surr)) def test_shuffle_isis_output_decimals(self): - surrogate_train = surr.shuffle_isis( - self.st1, n_surrogates=1, decimals=95)[0] + surrogate_train = surr.shuffle_isis(self.st1, n_surrogates=1, decimals=95)[0] st_pq = self.st1.view(pq.Quantity) surr_pq = surrogate_train.view(pq.Quantity) @@ -257,28 +274,48 @@ def test_shuffle_isis_output_decimals(self): self.assertTrue(np.all(ISIs_orig == ISIs_surr)) def test_shuffle_isis_with_wrongly_ordered_spikes(self): - surr_method = 'shuffle_isis' + surr_method = "shuffle_isis" n_surr = 30 dither = 15 * pq.ms spiketrain = neo.SpikeTrain( - 
[39.65696411, 98.93868274, 120.2417674, 134.70971166, - 154.20788924, - 160.29077989, 179.19884034, 212.86773029, 247.59488061, - 273.04095041, - 297.56437605, 344.99204215, 418.55696486, 460.54298334, - 482.82299125, - 524.236052, 566.38966742, 597.87562722, 651.26965293, - 692.39802855, - 740.90285815, 849.45874695, 974.57724848, 8.79247605], - t_start=0. * pq.ms, t_stop=1000. * pq.ms, units=pq.ms) - surr.surrogates(spiketrain, n_surrogates=n_surr, method=surr_method, - dt=dither) + [ + 39.65696411, + 98.93868274, + 120.2417674, + 134.70971166, + 154.20788924, + 160.29077989, + 179.19884034, + 212.86773029, + 247.59488061, + 273.04095041, + 297.56437605, + 344.99204215, + 418.55696486, + 460.54298334, + 482.82299125, + 524.236052, + 566.38966742, + 597.87562722, + 651.26965293, + 692.39802855, + 740.90285815, + 849.45874695, + 974.57724848, + 8.79247605, + ], + t_start=0.0 * pq.ms, + t_stop=1000.0 * pq.ms, + units=pq.ms, + ) + surr.surrogates(spiketrain, n_surrogates=n_surr, method=surr_method, dt=dither) def test_dither_spike_train_output_format(self): n_surrogates = 2 shift = 10 * pq.ms surrogate_trains = surr.dither_spike_train( - self.st1, shift=shift, n_surrogates=n_surrogates) + self.st1, shift=shift, n_surrogates=n_surrogates + ) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) @@ -291,32 +328,34 @@ def test_dither_spike_train_output_format(self): self.assertEqual(len(surrogate_train), len(self.st1)) def test_dither_spike_train_empty_train(self): - spiketrain = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms) shift = 10 * pq.ms surrogate_train = surr.dither_spike_train( - spiketrain, shift=shift, n_surrogates=1)[0] + spiketrain, shift=shift, n_surrogates=1 + )[0] self.assertEqual(len(surrogate_train), 0) def test_dither_spike_train_output_decimals(self): n_surrogates = 2 shift = 10 * pq.ms surrogate_trains = surr.dither_spike_train( - self.st1, shift=shift, n_surrogates=n_surrogates, decimals=3) + self.st1, shift=shift, n_surrogates=n_surrogates, decimals=3 + ) for surrogate_train in surrogate_trains: for i in range(len(surrogate_train)): - self.assertNotEqual(surrogate_train[i] - - int(surrogate_train[i]) * - pq.ms, surrogate_train[i] - - surrogate_train[i]) + self.assertNotEqual( + surrogate_train[i] - int(surrogate_train[i]) * pq.ms, + surrogate_train[i] - surrogate_train[i], + ) def test_dither_spike_train_false_edges(self): n_surrogates = 2 shift = 10 * pq.ms surrogate_trains = surr.dither_spike_train( - self.st1, shift=shift, n_surrogates=n_surrogates, edges=False) + self.st1, shift=shift, n_surrogates=n_surrogates, edges=False + ) for surrogate_train in surrogate_trains: for i in range(len(surrogate_train)): @@ -326,7 +365,8 @@ def test_jitter_spikes_output_format(self): n_surrogates = 2 bin_size = 100 * pq.ms surrogate_trains = surr.jitter_spikes( - self.st1, bin_size=bin_size, n_surrogates=n_surrogates) + self.st1, bin_size=bin_size, n_surrogates=n_surrogates + ) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) @@ -339,31 +379,30 @@ def test_jitter_spikes_output_format(self): self.assertEqual(len(surrogate_train), len(self.st1)) def test_jitter_spikes_empty_train(self): - spiketrain = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms) bin_size = 75 * pq.ms surrogate_train = surr.jitter_spikes( - spiketrain, bin_size=bin_size, n_surrogates=1)[0] + spiketrain, bin_size=bin_size, n_surrogates=1 + )[0] self.assertEqual(len(surrogate_train), 0) def 
test_jitter_spikes_same_bins(self): bin_size = 100 * pq.ms surrogate_train = surr.jitter_spikes( - self.st1, bin_size=bin_size, n_surrogates=1)[0] + self.st1, bin_size=bin_size, n_surrogates=1 + )[0] bin_ids_orig = np.array( - (self.st1.view( - pq.Quantity) / - bin_size).rescale( - pq.dimensionless).magnitude, - dtype=int) + (self.st1.view(pq.Quantity) / bin_size).rescale(pq.dimensionless).magnitude, + dtype=int, + ) bin_ids_surr = np.array( - (surrogate_train.view( - pq.Quantity) / - bin_size).rescale( - pq.dimensionless).magnitude, - dtype=int) + (surrogate_train.view(pq.Quantity) / bin_size) + .rescale(pq.dimensionless) + .magnitude, + dtype=int, + ) self.assertTrue(np.all(bin_ids_orig == bin_ids_surr)) # Bug encountered when the original and surrogate trains have @@ -371,54 +410,60 @@ def test_jitter_spikes_same_bins(self): self.assertEqual(len(self.st1), len(surrogate_train)) def test_jitter_spikes_unequal_bin_size(self): - - spiketrain = neo.SpikeTrain( - [90, 150, 180, 480] * pq.ms, t_stop=500 * pq.ms) + spiketrain = neo.SpikeTrain([90, 150, 180, 480] * pq.ms, t_stop=500 * pq.ms) bin_size = 75 * pq.ms surrogate_train = surr.jitter_spikes( - spiketrain, bin_size=bin_size, n_surrogates=1)[0] + spiketrain, bin_size=bin_size, n_surrogates=1 + )[0] bin_ids_orig = np.array( - (spiketrain.view( - pq.Quantity) / - bin_size).rescale( - pq.dimensionless).magnitude, - dtype=int) + (spiketrain.view(pq.Quantity) / bin_size) + .rescale(pq.dimensionless) + .magnitude, + dtype=int, + ) bin_ids_surr = np.array( - (surrogate_train.view( - pq.Quantity) / - bin_size).rescale( - pq.dimensionless).magnitude, - dtype=int) + (surrogate_train.view(pq.Quantity) / bin_size) + .rescale(pq.dimensionless) + .magnitude, + dtype=int, + ) self.assertTrue(np.all(bin_ids_orig == bin_ids_surr)) def test_surr_method(self): - - surr_methods = \ - ('dither_spike_train', 'dither_spikes', 'jitter_spikes', - 'randomise_spikes', 'shuffle_isis', 'joint_isi_dithering', - 'dither_spikes_with_refractory_period', 'trial_shifting', - 'bin_shuffling', 'isi_dithering') - - surr_method_kwargs = \ - {'dither_spikes': {}, - 'dither_spikes_with_refractory_period': {'refractory_period': - 3 * pq.ms}, - 'randomise_spikes': {}, - 'shuffle_isis': {}, - 'dither_spike_train': {}, - 'jitter_spikes': {}, - 'bin_shuffling': {'bin_size': 3 * pq.ms}, - 'joint_isi_dithering': {}, - 'isi_dithering': {}, - 'trial_shifting': {'trial_length': 200 * pq.ms, - 'trial_separation': 50 * pq.ms}} + surr_methods = ( + "dither_spike_train", + "dither_spikes", + "jitter_spikes", + "randomise_spikes", + "shuffle_isis", + "joint_isi_dithering", + "dither_spikes_with_refractory_period", + "trial_shifting", + "bin_shuffling", + "isi_dithering", + ) + + surr_method_kwargs = { + "dither_spikes": {}, + "dither_spikes_with_refractory_period": {"refractory_period": 3 * pq.ms}, + "randomise_spikes": {}, + "shuffle_isis": {}, + "dither_spike_train": {}, + "jitter_spikes": {}, + "bin_shuffling": {"bin_size": 3 * pq.ms}, + "joint_isi_dithering": {}, + "isi_dithering": {}, + "trial_shifting": { + "trial_length": 200 * pq.ms, + "trial_separation": 50 * pq.ms, + }, + } dt = 15 * pq.ms - spiketrain = neo.SpikeTrain( - [90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) + spiketrain = neo.SpikeTrain([90, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms) n_surrogates = 3 for method in surr_methods: surrogates = surr.surrogates( @@ -426,48 +471,56 @@ def test_surr_method(self): dt=dt, n_surrogates=n_surrogates, method=method, - **surr_method_kwargs[method] + **surr_method_kwargs[method], ) 
self.assertTrue(len(surrogates) == n_surrogates) for surrogate_train in surrogates: - self.assertTrue( - isinstance(surrogates[0], neo.SpikeTrain)) + self.assertTrue(isinstance(surrogates[0], neo.SpikeTrain)) self.assertEqual(surrogate_train.units, spiketrain.units) self.assertEqual(surrogate_train.t_start, spiketrain.t_start) self.assertEqual(surrogate_train.t_stop, spiketrain.t_stop) self.assertEqual(len(surrogate_train), len(spiketrain)) self.assertTrue(len(surrogates) == n_surrogates) - self.assertRaises(ValueError, surr.surrogates, spiketrain, - n_surrogates=1, - method='spike_shifting', - dt=None, decimals=None, edges=True) - - self.assertRaises(ValueError, surr.surrogates, spiketrain, - method='dither_spikes', dt=None) - - self.assertRaises(TypeError, surr.surrogates, spiketrain.magnitude, - method='dither_spikes', dt=10 * pq.ms) + self.assertRaises( + ValueError, + surr.surrogates, + spiketrain, + n_surrogates=1, + method="spike_shifting", + dt=None, + decimals=None, + edges=True, + ) + + self.assertRaises( + ValueError, surr.surrogates, spiketrain, method="dither_spikes", dt=None + ) + + self.assertRaises( + TypeError, + surr.surrogates, + spiketrain.magnitude, + method="dither_spikes", + dt=10 * pq.ms, + ) def test_joint_isi_dithering_format(self): - - rate = 100. * pq.Hz - t_stop = 1. * pq.s + rate = 100.0 * pq.Hz + t_stop = 1.0 * pq.s process = stg.StationaryPoissonProcess(rate, t_stop=t_stop) spiketrain = process.generate_spiketrain() n_surrogates = 2 dither = 10 * pq.ms # Test fast version - joint_isi_instance = surr.JointISI(spiketrain, dither=dither, - method='fast') - surrogate_trains = joint_isi_instance.dithering( - n_surrogates=n_surrogates) + joint_isi_instance = surr.JointISI(spiketrain, dither=dither, method="fast") + surrogate_trains = joint_isi_instance.dithering(n_surrogates=n_surrogates) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) - self.assertEqual(joint_isi_instance.method, 'fast') + self.assertEqual(joint_isi_instance.method, "fast") for surrogate_train in surrogate_trains: self.assertIsInstance(surrogate_train, neo.SpikeTrain) @@ -477,16 +530,14 @@ def test_joint_isi_dithering_format(self): self.assertEqual(len(surrogate_train), len(spiketrain)) # Test window_version - joint_isi_instance = surr.JointISI(spiketrain, - method='window', - dither=2 * dither, - n_bins=50) - surrogate_trains = joint_isi_instance.dithering( - n_surrogates=n_surrogates) + joint_isi_instance = surr.JointISI( + spiketrain, method="window", dither=2 * dither, n_bins=50 + ) + surrogate_trains = joint_isi_instance.dithering(n_surrogates=n_surrogates) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) - self.assertEqual(joint_isi_instance.method, 'window') + self.assertEqual(joint_isi_instance.method, "window") for surrogate_train in surrogate_trains: self.assertIsInstance(surrogate_train, neo.SpikeTrain) @@ -496,19 +547,20 @@ def test_joint_isi_dithering_format(self): self.assertEqual(len(surrogate_train), len(spiketrain)) # Test isi_dithering - joint_isi_instance = surr.JointISI(spiketrain, - method='window', - dither=2 * dither, - n_bins=50, - isi_dithering=True, - use_sqrt=True, - cutoff=False) - surrogate_trains = joint_isi_instance.dithering( - n_surrogates=n_surrogates) + joint_isi_instance = surr.JointISI( + spiketrain, + method="window", + dither=2 * dither, + n_bins=50, + isi_dithering=True, + use_sqrt=True, + cutoff=False, + ) + surrogate_trains = 
joint_isi_instance.dithering(n_surrogates=n_surrogates) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) - self.assertEqual(joint_isi_instance.method, 'window') + self.assertEqual(joint_isi_instance.method, "window") for surrogate_train in surrogate_trains: self.assertIsInstance(surrogate_train, neo.SpikeTrain) @@ -522,7 +574,8 @@ def test_joint_isi_dithering_format(self): spiketrain, dt=15 * pq.ms, n_surrogates=n_surrogates, - method='joint_isi_dithering') + method="joint_isi_dithering", + ) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) @@ -533,10 +586,9 @@ def test_joint_isi_dithering_format(self): self.assertEqual(surrogate_train.t_stop, spiketrain.t_stop) self.assertEqual(len(surrogate_train), len(spiketrain)) with self.assertRaises(ValueError): - joint_isi_instance = surr.JointISI(spiketrain, - method='wrong method', - dither=2 * dither, - n_bins=50) + joint_isi_instance = surr.JointISI( + spiketrain, method="wrong method", dither=2 * dither, n_bins=50 + ) def test_joint_isi_dithering_empty_train(self): spiketrain = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms) @@ -545,101 +597,174 @@ def test_joint_isi_dithering_empty_train(self): def test_joint_isi_dithering_output(self): process = stg.StationaryPoissonProcess( - rate=100. * pq.Hz, refractory_period=3 * pq.ms, t_stop=0.1 * pq.s) + rate=100.0 * pq.Hz, refractory_period=3 * pq.ms, t_stop=0.1 * pq.s + ) spiketrain = process.generate_spiketrain() surrogate_train = surr.JointISI(spiketrain).dithering()[0] - ground_truth = [0.0060744, 0.01886591, 0.02732847, 0.03683888, - 0.04569622, 0.05196334, 0.05899197, 0.07855664] + ground_truth = [ + 0.0060744, + 0.01886591, + 0.02732847, + 0.03683888, + 0.04569622, + 0.05196334, + 0.05899197, + 0.07855664, + ] assert_array_almost_equal(surrogate_train.magnitude, ground_truth) def test_joint_isi_with_wrongly_ordered_spikes(self): - surr_method = 'joint_isi_dithering' + surr_method = "joint_isi_dithering" n_surr = 30 dither = 15 * pq.ms spiketrain = neo.SpikeTrain( - [39.65696411, 98.93868274, 120.2417674, 134.70971166, - 154.20788924, - 160.29077989, 179.19884034, 212.86773029, 247.59488061, - 273.04095041, - 297.56437605, 344.99204215, 418.55696486, 460.54298334, - 482.82299125, - 524.236052, 566.38966742, 597.87562722, 651.26965293, - 692.39802855, - 740.90285815, 849.45874695, 974.57724848, 8.79247605], - t_start=0. * pq.ms, t_stop=1000. 
* pq.ms, units=pq.ms) - surr.surrogates(spiketrain, n_surrogates=n_surr, method=surr_method, - dt=dither) + [ + 39.65696411, + 98.93868274, + 120.2417674, + 134.70971166, + 154.20788924, + 160.29077989, + 179.19884034, + 212.86773029, + 247.59488061, + 273.04095041, + 297.56437605, + 344.99204215, + 418.55696486, + 460.54298334, + 482.82299125, + 524.236052, + 566.38966742, + 597.87562722, + 651.26965293, + 692.39802855, + 740.90285815, + 849.45874695, + 974.57724848, + 8.79247605, + ], + t_start=0.0 * pq.ms, + t_stop=1000.0 * pq.ms, + units=pq.ms, + ) + surr.surrogates(spiketrain, n_surrogates=n_surr, method=surr_method, dt=dither) def test_joint_isi_spikes_at_border(self): - surr_method = 'joint_isi_dithering' + surr_method = "joint_isi_dithering" n_surr = 30 dither = 15 * pq.ms spiketrain = neo.SpikeTrain( - [4., 28., 45., 51., 83., 87., 96., 111., 126., 131., - 138., 150., - 209., 232., 253., 275., 279., 303., 320., 371., 396., - 401., 429., 447., - 479., 511., 535., 549., 581., 585., 605., 607., 626., - 630., 644., 714., - 832., 835., 853., 858., 878., 905., 909., 932., 950., - 961., 999., 1000.], - t_start=0. * pq.ms, t_stop=1000. * pq.ms, units=pq.ms) - surr.surrogates( - spiketrain, n_surrogates=n_surr, method=surr_method, dt=dither) + [ + 4.0, + 28.0, + 45.0, + 51.0, + 83.0, + 87.0, + 96.0, + 111.0, + 126.0, + 131.0, + 138.0, + 150.0, + 209.0, + 232.0, + 253.0, + 275.0, + 279.0, + 303.0, + 320.0, + 371.0, + 396.0, + 401.0, + 429.0, + 447.0, + 479.0, + 511.0, + 535.0, + 549.0, + 581.0, + 585.0, + 605.0, + 607.0, + 626.0, + 630.0, + 644.0, + 714.0, + 832.0, + 835.0, + 853.0, + 858.0, + 878.0, + 905.0, + 909.0, + 932.0, + 950.0, + 961.0, + 999.0, + 1000.0, + ], + t_start=0.0 * pq.ms, + t_stop=1000.0 * pq.ms, + units=pq.ms, + ) + surr.surrogates(spiketrain, n_surrogates=n_surr, method=surr_method, dt=dither) def test_bin_shuffling_output_format(self): - self.bin_size = 3 * pq.ms self.max_displacement = 10 - spiketrain = neo.SpikeTrain([90, 93, 97, 100, 105, - 150, 180, 350] * pq.ms, t_stop=.5 * pq.s) + spiketrain = neo.SpikeTrain( + [90, 93, 97, 100, 105, 150, 180, 350] * pq.ms, t_stop=0.5 * pq.s + ) binned_spiketrain = conv.BinnedSpikeTrain(spiketrain, self.bin_size) n_surrogates = 2 for sliding in (True, False): surrogate_trains = surr.bin_shuffling( - binned_spiketrain, max_displacement=self.max_displacement, - n_surrogates=n_surrogates, sliding=sliding) + binned_spiketrain, + max_displacement=self.max_displacement, + n_surrogates=n_surrogates, + sliding=sliding, + ) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) self.assertIsInstance(surrogate_trains[0], conv.BinnedSpikeTrain) for surrogate_train in surrogate_trains: - self.assertEqual(surrogate_train.t_start, - binned_spiketrain.t_start) - self.assertEqual(surrogate_train.t_stop, - binned_spiketrain.t_stop) - self.assertEqual(surrogate_train.n_bins, - binned_spiketrain.n_bins) - self.assertEqual(surrogate_train.bin_size, - binned_spiketrain.bin_size) + self.assertEqual(surrogate_train.t_start, binned_spiketrain.t_start) + self.assertEqual(surrogate_train.t_stop, binned_spiketrain.t_stop) + self.assertEqual(surrogate_train.n_bins, binned_spiketrain.n_bins) + self.assertEqual(surrogate_train.bin_size, binned_spiketrain.bin_size) def test_bin_shuffling_empty_train(self): - self.bin_size = 3 * pq.ms self.max_displacement = 10 empty_spiketrain = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms) - binned_spiketrain = conv.BinnedSpikeTrain(empty_spiketrain, - self.bin_size) + 
binned_spiketrain = conv.BinnedSpikeTrain(empty_spiketrain, self.bin_size) surrogate_train = surr.bin_shuffling( - binned_spiketrain, max_displacement=self.max_displacement, - n_surrogates=1)[0] + binned_spiketrain, max_displacement=self.max_displacement, n_surrogates=1 + )[0] self.assertEqual(np.sum(surrogate_train.to_bool_array()), 0) def test_trial_shuffling_output_format(self): - spiketrain = \ - [neo.SpikeTrain([90, 93, 97, 100, 105, 150, 180, 190] * pq.ms, - t_stop=.2 * pq.s), - neo.SpikeTrain([90, 93, 97, 100, 105, 150, 180, 190] * pq.ms, - t_stop=.2 * pq.s)] + spiketrain = [ + neo.SpikeTrain( + [90, 93, 97, 100, 105, 150, 180, 190] * pq.ms, t_stop=0.2 * pq.s + ), + neo.SpikeTrain( + [90, 93, 97, 100, 105, 150, 180, 190] * pq.ms, t_stop=0.2 * pq.s + ), + ] # trial_length = 200 * pq.ms # trial_separation = 50 * pq.ms n_surrogates = 2 dither = 10 * pq.ms surrogate_trains = surr.trial_shifting( - spiketrain, dither=dither, n_surrogates=n_surrogates) + spiketrain, dither=dither, n_surrogates=n_surrogates + ) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) @@ -651,30 +776,37 @@ def test_trial_shuffling_output_format(self): self.assertEqual(surrogate_train.t_start, spiketrain[0].t_start) self.assertEqual(surrogate_train.t_stop, spiketrain[0].t_stop) self.assertEqual(len(surrogate_train), len(spiketrain[0])) - assert_array_less(0., np.diff(surrogate_train)) # check ordering + assert_array_less(0.0, np.diff(surrogate_train)) # check ordering def test_trial_shuffling_empty_train(self): - - empty_spiketrain = [neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms), - neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms)] + empty_spiketrain = [ + neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms), + neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms), + ] dither = 10 * pq.ms surrogate_train = surr.trial_shifting( - empty_spiketrain, dither=dither, n_surrogates=1)[0] + empty_spiketrain, dither=dither, n_surrogates=1 + )[0] self.assertEqual(len(surrogate_train), 2) self.assertEqual(len(surrogate_train[0]), 0) def test_trial_shuffling_output_format_concatenated(self): - spiketrain = neo.SpikeTrain([90, 93, 97, 100, 105, - 150, 180, 350] * pq.ms, t_stop=.5 * pq.s) + spiketrain = neo.SpikeTrain( + [90, 93, 97, 100, 105, 150, 180, 350] * pq.ms, t_stop=0.5 * pq.s + ) trial_length = 200 * pq.ms trial_separation = 50 * pq.ms n_surrogates = 2 dither = 10 * pq.ms surrogate_trains = surr._trial_shifting_of_concatenated_spiketrain( - spiketrain, dither=dither, n_surrogates=n_surrogates, - trial_length=trial_length, trial_separation=trial_separation) + spiketrain, + dither=dither, + n_surrogates=n_surrogates, + trial_length=trial_length, + trial_separation=trial_separation, + ) self.assertIsInstance(surrogate_trains, list) self.assertEqual(len(surrogate_trains), n_surrogates) @@ -685,18 +817,21 @@ def test_trial_shuffling_output_format_concatenated(self): self.assertEqual(surrogate_train.t_start, spiketrain.t_start) self.assertEqual(surrogate_train.t_stop, spiketrain.t_stop) self.assertEqual(len(surrogate_train), len(spiketrain)) - assert_array_less(0., np.diff(surrogate_train)) # check ordering + assert_array_less(0.0, np.diff(surrogate_train)) # check ordering def test_trial_shuffling_empty_train_concatenated(self): - empty_spiketrain = neo.SpikeTrain([] * pq.ms, t_stop=500 * pq.ms) trial_length = 200 * pq.ms trial_separation = 50 * pq.ms dither = 10 * pq.ms surrogate_train = surr._trial_shifting_of_concatenated_spiketrain( - empty_spiketrain, dither=dither, n_surrogates=1, - 
trial_length=trial_length, trial_separation=trial_separation)[0] + empty_spiketrain, + dither=dither, + n_surrogates=1, + trial_length=trial_length, + trial_separation=trial_separation, + )[0] self.assertEqual(len(surrogate_train), 0) diff --git a/elephant/test/test_spike_train_synchrony.py b/elephant/test/test_spike_train_synchrony.py index 58be525eb..d4f283792 100644 --- a/elephant/test/test_spike_train_synchrony.py +++ b/elephant/test/test_spike_train_synchrony.py @@ -10,21 +10,26 @@ from quantities import Hz, ms, second import elephant.spike_train_generation as stgen -from elephant.spike_train_synchrony import Synchrotool, spike_contrast, \ - _get_theta_and_n_per_bin, _binning_half_overlap +from elephant.spike_train_synchrony import ( + Synchrotool, + spike_contrast, + _get_theta_and_n_per_bin, + _binning_half_overlap, +) from elephant.datasets import download_datasets, unzip class TestSpikeContrast(unittest.TestCase): - def test_spike_contrast_random(self): # randomly generated spiketrains that share the same t_start and # t_stop np.random.seed(24) # to make the results reproducible poisson_process_1 = stgen.StationaryPoissonProcess( - rate=10.*Hz, t_start=0.*ms, t_stop=10000.*ms) + rate=10.0 * Hz, t_start=0.0 * ms, t_stop=10000.0 * ms + ) poisson_process_2 = stgen.StationaryPoissonProcess( - rate=1.*Hz, t_start=0.*ms, t_stop=10000.*ms) + rate=1.0 * Hz, t_start=0.0 * ms, t_stop=10000.0 * ms + ) spike_train_1 = poisson_process_1.generate_spiketrain() spike_train_2 = poisson_process_1.generate_spiketrain() spike_train_3 = poisson_process_1.generate_spiketrain() @@ -32,15 +37,22 @@ def test_spike_contrast_random(self): spike_train_5 = poisson_process_2.generate_spiketrain() spike_train_6 = poisson_process_2.generate_spiketrain() - spike_trains = [spike_train_1, spike_train_2, spike_train_3, - spike_train_4, spike_train_5, spike_train_6] + spike_trains = [ + spike_train_1, + spike_train_2, + spike_train_3, + spike_train_4, + spike_train_5, + spike_train_6, + ] synchrony = spike_contrast(spike_trains) self.assertAlmostEqual(synchrony, 0.1875795, places=6) def test_spike_contrast_same_signal(self): np.random.seed(21) spike_train = stgen.StationaryPoissonProcess( - rate=10.*Hz, t_start=0.*ms, t_stop=10000.*ms).generate_spiketrain() + rate=10.0 * Hz, t_start=0.0 * ms, t_stop=10000.0 * ms + ).generate_spiketrain() spike_trains = [spike_train, spike_train] synchrony = spike_contrast(spike_trains, min_bin=1 * ms) self.assertEqual(synchrony, 1.0) @@ -48,7 +60,8 @@ def test_spike_contrast_same_signal(self): def test_spike_contrast_double_duration(self): np.random.seed(19) poisson_process = stgen.StationaryPoissonProcess( - rate=10 * Hz, t_start=0. * ms, t_stop=10000. * ms) + rate=10 * Hz, t_start=0.0 * ms, t_stop=10000.0 * ms + ) spike_train_1 = poisson_process.generate_spiketrain() spike_train_2 = poisson_process.generate_spiketrain() spike_train_3 = poisson_process.generate_spiketrain() @@ -60,24 +73,26 @@ def test_spike_contrast_double_duration(self): def test_spike_contrast_non_overlapping_spiketrains(self): np.random.seed(15) spike_train_1 = stgen.StationaryPoissonProcess( - rate=10 * Hz, t_start=0. * ms, t_stop=10000. * ms + rate=10 * Hz, t_start=0.0 * ms, t_stop=10000.0 * ms ).generate_spiketrain() spike_train_2 = stgen.StationaryPoissonProcess( - rate=10 * Hz, t_start=5000. * ms, t_stop=10000. 
* ms + rate=10 * Hz, t_start=5000.0 * ms, t_stop=10000.0 * ms ).generate_spiketrain() spiketrains = [spike_train_1, spike_train_2] synchrony = spike_contrast(spiketrains, t_stop=5000 * ms) # the synchrony of non-overlapping spiketrains must be zero - self.assertEqual(synchrony, 0.) + self.assertEqual(synchrony, 0.0) def test_spike_contrast_trace(self): np.random.seed(15) poisson_process = stgen.StationaryPoissonProcess( - rate=10 * Hz, t_start=0. * ms, t_stop=1000. * ms) + rate=10 * Hz, t_start=0.0 * ms, t_stop=1000.0 * ms + ) spike_train_1 = poisson_process.generate_spiketrain() spike_train_2 = poisson_process.generate_spiketrain() - synchrony, trace = spike_contrast([spike_train_1, spike_train_2], - return_trace=True) + synchrony, trace = spike_contrast( + [spike_train_1, spike_train_2], return_trace=True + ) self.assertEqual(synchrony, max(trace.synchrony)) self.assertEqual(len(trace.contrast), len(trace.active_spiketrains)) self.assertEqual(len(trace.active_spiketrains), len(trace.synchrony)) @@ -89,9 +104,14 @@ def test_spike_contrast_trace(self): def test_invalid_data(self): # invalid spiketrains self.assertRaises(TypeError, spike_contrast, [[0, 1], [1.5, 2.3]]) - self.assertRaises(ValueError, spike_contrast, - [neo.SpikeTrain([10] * ms, t_stop=1000 * ms), - neo.SpikeTrain([20] * ms, t_stop=1000 * ms)]) + self.assertRaises( + ValueError, + spike_contrast, + [ + neo.SpikeTrain([10] * ms, t_stop=1000 * ms), + neo.SpikeTrain([20] * ms, t_stop=1000 * ms), + ], + ) # a single spiketrain spiketrain_valid = neo.SpikeTrain([0, 1000] * ms, t_stop=1000 * ms) @@ -101,22 +121,19 @@ def test_invalid_data(self): spiketrains = [spiketrain_valid, spiketrain_valid2] # invalid shrink factor - self.assertRaises(ValueError, spike_contrast, spiketrains, - bin_shrink_factor=0.) 
+ self.assertRaises( + ValueError, spike_contrast, spiketrains, bin_shrink_factor=0.0 + ) # invalid t_start, t_stop, and min_bin - self.assertRaises(TypeError, spike_contrast, spiketrains, - t_start=0) - self.assertRaises(TypeError, spike_contrast, spiketrains, - t_stop=1000) - self.assertRaises(TypeError, spike_contrast, spiketrains, - min_bin=0.01) + self.assertRaises(TypeError, spike_contrast, spiketrains, t_start=0) + self.assertRaises(TypeError, spike_contrast, spiketrains, t_stop=1000) + self.assertRaises(TypeError, spike_contrast, spiketrains, min_bin=0.01) def test_t_start_agnostic(self): np.random.seed(15) t_stop = 10 * second - poisson_process = stgen.StationaryPoissonProcess( - rate=10 * Hz, t_stop=t_stop) + poisson_process = stgen.StationaryPoissonProcess(rate=10 * Hz, t_stop=t_stop) spike_train_1 = poisson_process.generate_spiketrain() spike_train_2 = poisson_process.generate_spiketrain() spiketrains = [spike_train_1, spike_train_2] @@ -125,24 +142,17 @@ def test_t_start_agnostic(self): assert synchrony_target > 0 t_shift = 20 * second spiketrains_shifted = [ - neo.SpikeTrain(st.times + t_shift, - t_start=t_shift, - t_stop=t_stop + t_shift) + neo.SpikeTrain(st.times + t_shift, t_start=t_shift, t_stop=t_stop + t_shift) for st in spiketrains ] synchrony = spike_contrast(spiketrains_shifted) self.assertAlmostEqual(synchrony, synchrony_target) def test_get_theta_and_n_per_bin(self): - spike_trains = [ - [1, 2, 3, 9], - [1, 2, 3, 9], - [1, 2, 2.5] - ] - theta, n = _get_theta_and_n_per_bin(spike_trains, - t_start=0, - t_stop=10, - bin_size=5) + spike_trains = [[1, 2, 3, 9], [1, 2, 3, 9], [1, 2, 2.5]] + theta, n = _get_theta_and_n_per_bin( + spike_trains, t_start=0, t_stop=10, bin_size=5 + ) assert_array_equal(theta, [9, 3, 2]) assert_array_equal(n, [3, 3, 2]) @@ -166,8 +176,7 @@ def test_spike_contrast_with_Izhikevich_network_auto(self): izhikevich_gin = r"dataset-3/Data_Izhikevich_network.zip" checksum = "70e848500c1d9c6403b66de8c741d849" - filepath_zip = download_datasets(repo_path=izhikevich_gin, - checksum=checksum) + filepath_zip = download_datasets(repo_path=izhikevich_gin, checksum=checksum) unzip(filepath_zip) filepath_json = filepath_zip.with_suffix(".json") with open(filepath_json) as read_file: @@ -178,102 +187,114 @@ def test_spike_contrast_with_Izhikevich_network_auto(self): for network_simulations in networks_subset: for simulation in network_simulations.values(): - synchrony_true = simulation['synchrony'] + synchrony_true = simulation["synchrony"] spiketrains = [ - neo.SpikeTrain(st, t_start=0 * second, t_stop=2 * second, - units=second) - for st in simulation['spiketrains']] + neo.SpikeTrain( + st, t_start=0 * second, t_stop=2 * second, units=second + ) + for st in simulation["spiketrains"] + ] synchrony = spike_contrast(spiketrains) self.assertAlmostEqual(synchrony, synchrony_true, places=2) class SynchrofactDetectionTestCase(unittest.TestCase): - - def _test_template(self, spiketrains, correct_complexities, sampling_rate, - spread, deletion_threshold=2, mode='delete', - in_place=False, binary=True): - + def _test_template( + self, + spiketrains, + correct_complexities, + sampling_rate, + spread, + deletion_threshold=2, + mode="delete", + in_place=False, + binary=True, + ): synchrofact_obj = Synchrotool( - spiketrains, - sampling_rate=sampling_rate, - binary=binary, - spread=spread) + spiketrains, sampling_rate=sampling_rate, binary=binary, spread=spread + ) # test annotation synchrofact_obj.annotate_synchrofacts() - annotations = [st.array_annotations['complexity'] 
- for st in spiketrains] + annotations = [st.array_annotations["complexity"] for st in spiketrains] assert_array_equal(annotations, correct_complexities) for a in annotations: self.assertEqual(a.dtype, np.dtype(np.uint16).type) - if mode == 'extract': + if mode == "extract": correct_spike_times = [ - spikes[mask] for spikes, mask - in zip(spiketrains, - correct_complexities >= deletion_threshold) + spikes[mask] + for spikes, mask in zip( + spiketrains, correct_complexities >= deletion_threshold + ) ] else: correct_spike_times = [ - spikes[mask] for spikes, mask - in zip(spiketrains, - correct_complexities < deletion_threshold) + spikes[mask] + for spikes, mask in zip( + spiketrains, correct_complexities < deletion_threshold + ) ] # test deletion - synchrofact_obj.delete_synchrofacts(threshold=deletion_threshold, - in_place=in_place, - mode=mode) + synchrofact_obj.delete_synchrofacts( + threshold=deletion_threshold, in_place=in_place, mode=mode + ) cleaned_spike_times = [st.times for st in spiketrains] - for correct_st, cleaned_st in zip(correct_spike_times, - cleaned_spike_times): + for correct_st, cleaned_st in zip(correct_spike_times, cleaned_spike_times): assert_array_almost_equal(cleaned_st, correct_st) def test_no_synchrofacts(self): - # nothing to find here # there used to be an error for spread > 0 when nothing was found sampling_rate = 1 / pq.s - spiketrains = [neo.SpikeTrain([1, 9, 12, 19] * pq.s, - t_stop=20*pq.s), - neo.SpikeTrain([3, 7, 15, 17] * pq.s, - t_stop=20*pq.s)] + spiketrains = [ + neo.SpikeTrain([1, 9, 12, 19] * pq.s, t_stop=20 * pq.s), + neo.SpikeTrain([3, 7, 15, 17] * pq.s, t_stop=20 * pq.s), + ] - correct_annotations = np.array([[1, 1, 1, 1], - [1, 1, 1, 1]]) + correct_annotations = np.array([[1, 1, 1, 1], [1, 1, 1, 1]]) - self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, mode='delete', - deletion_threshold=2) + self._test_template( + spiketrains, + correct_annotations, + sampling_rate, + spread=1, + mode="delete", + deletion_threshold=2, + ) def test_spread_0(self): - # basic test with a minimum number of two spikes per synchrofact # only taking into account multiple spikes # within one bin of size 1 / sampling_rate sampling_rate = 1 / pq.s - spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, - t_stop=20*pq.s), - neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, - t_stop=20*pq.s)] + spiketrains = [ + neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, t_stop=20 * pq.s), + neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, t_stop=20 * pq.s), + ] - correct_annotations = np.array([[2, 1, 1, 1, 2, 1], - [2, 1, 1, 1, 2, 1]]) + correct_annotations = np.array([[2, 1, 1, 1, 2, 1], [2, 1, 1, 1, 2, 1]]) - self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=0, mode='delete', in_place=True, - deletion_threshold=2) + self._test_template( + spiketrains, + correct_annotations, + sampling_rate, + spread=0, + mode="delete", + in_place=True, + deletion_threshold=2, + ) def test_spiketrains_findable(self): - # same test as `test_spread_0` with the addition of # a neo structure: we must not overwrite the spiketrain # list of the segment before determining the index @@ -282,22 +303,26 @@ def test_spiketrains_findable(self): segment = neo.Segment() - segment.spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, - t_stop=20*pq.s), - neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, - t_stop=20*pq.s)] + segment.spiketrains = [ + neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, t_stop=20 * pq.s), + neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, 
t_stop=20 * pq.s), + ] segment.create_relationship() - correct_annotations = np.array([[2, 1, 1, 1, 2, 1], - [2, 1, 1, 1, 2, 1]]) + correct_annotations = np.array([[2, 1, 1, 1, 2, 1], [2, 1, 1, 1, 2, 1]]) - self._test_template(segment.spiketrains, correct_annotations, - sampling_rate, spread=0, mode='delete', - in_place=True, deletion_threshold=2) + self._test_template( + segment.spiketrains, + correct_annotations, + sampling_rate, + spread=0, + mode="delete", + in_place=True, + deletion_threshold=2, + ) def test_unidirectional_uplinks(self): - # same test as `test_spiketrains_findable` but the spiketrains # are rescaled first # the rescaled spiketrains have a unidirectional uplink to segment @@ -308,92 +333,111 @@ def test_unidirectional_uplinks(self): segment = neo.Segment() - segment.spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, - t_stop=20*pq.s), - neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, - t_stop=20*pq.s)] + segment.spiketrains = [ + neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, t_stop=20 * pq.s), + neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, t_stop=20 * pq.s), + ] segment.create_relationship() spiketrains = [st.rescale(pq.s) for st in segment.spiketrains] - correct_annotations = np.array([[2, 1, 1, 1, 2, 1], - [2, 1, 1, 1, 2, 1]]) + correct_annotations = np.array([[2, 1, 1, 1, 2, 1], [2, 1, 1, 1, 2, 1]]) with self.assertWarns(UserWarning): - self._test_template(spiketrains, correct_annotations, - sampling_rate, spread=0, mode='delete', - in_place=True, deletion_threshold=2) + self._test_template( + spiketrains, + correct_annotations, + sampling_rate, + spread=0, + mode="delete", + in_place=True, + deletion_threshold=2, + ) def test_spread_1(self): - # test synchrofact search taking into account adjacent bins # this requires an additional loop with shifted binning sampling_rate = 1 / pq.s - spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, - t_stop=21*pq.s), - neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, - t_stop=21*pq.s)] + spiketrains = [ + neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, t_stop=21 * pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, t_stop=21 * pq.s), + ] - correct_annotations = np.array([[2, 2, 1, 3, 3, 1], - [2, 2, 1, 3, 1, 1]]) + correct_annotations = np.array([[2, 2, 1, 3, 3, 1], [2, 2, 1, 3, 1, 1]]) - self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, mode='delete', in_place=True, - deletion_threshold=2) + self._test_template( + spiketrains, + correct_annotations, + sampling_rate, + spread=1, + mode="delete", + in_place=True, + deletion_threshold=2, + ) def test_n_equals_3(self): - # test synchrofact detection with a minimum number of # three spikes per synchrofact sampling_rate = 1 / pq.s - spiketrains = [neo.SpikeTrain([1, 1, 5, 10, 13, 16, 17, 19] * pq.s, - t_stop=21*pq.s), - neo.SpikeTrain([1, 4, 7, 9, 12, 14, 16, 20] * pq.s, - t_stop=21*pq.s)] + spiketrains = [ + neo.SpikeTrain([1, 1, 5, 10, 13, 16, 17, 19] * pq.s, t_stop=21 * pq.s), + neo.SpikeTrain([1, 4, 7, 9, 12, 14, 16, 20] * pq.s, t_stop=21 * pq.s), + ] - correct_annotations = np.array([[3, 3, 2, 2, 3, 3, 3, 2], - [3, 2, 1, 2, 3, 3, 3, 2]]) + correct_annotations = np.array( + [[3, 3, 2, 2, 3, 3, 3, 2], [3, 2, 1, 2, 3, 3, 3, 2]] + ) - self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, mode='delete', binary=False, - in_place=True, deletion_threshold=3) + self._test_template( + spiketrains, + correct_annotations, + sampling_rate, + spread=1, + mode="delete", + binary=False, + in_place=True, + deletion_threshold=3, + 
) def test_extract(self): - # test synchrofact search taking into account adjacent bins # this requires an additional loop with shifted binning sampling_rate = 1 / pq.s - spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, - t_stop=21*pq.s), - neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, - t_stop=21*pq.s)] + spiketrains = [ + neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, t_stop=21 * pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, t_stop=21 * pq.s), + ] - correct_annotations = np.array([[2, 2, 1, 3, 3, 1], - [2, 2, 1, 3, 1, 1]]) + correct_annotations = np.array([[2, 2, 1, 3, 3, 1], [2, 2, 1, 3, 1, 1]]) - self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, mode='extract', in_place=True, - deletion_threshold=2) + self._test_template( + spiketrains, + correct_annotations, + sampling_rate, + spread=1, + mode="extract", + in_place=True, + deletion_threshold=2, + ) def test_binning_for_input_with_rounding_errors(self): - # a test with inputs divided by 30000 which leads to rounding errors # these errors have to be accounted for by proper binning; # check if we still get the correct result sampling_rate = 30000 / pq.s - spiketrains = [neo.SpikeTrain(np.arange(1000) * pq.s / 30000, - t_stop=.1 * pq.s), - neo.SpikeTrain(np.arange(2000, step=2) * pq.s / 30000, - t_stop=.1 * pq.s)] + spiketrains = [ + neo.SpikeTrain(np.arange(1000) * pq.s / 30000, t_stop=0.1 * pq.s), + neo.SpikeTrain(np.arange(2000, step=2) * pq.s / 30000, t_stop=0.1 * pq.s), + ] first_annotations = np.ones(1000) first_annotations[::2] = 2 @@ -401,26 +445,29 @@ def test_binning_for_input_with_rounding_errors(self): second_annotations = np.ones(1000) second_annotations[:500] = 2 - correct_annotations = np.array([first_annotations, - second_annotations]) + correct_annotations = np.array([first_annotations, second_annotations]) - self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=0, mode='delete', in_place=True, - deletion_threshold=2) + self._test_template( + spiketrains, + correct_annotations, + sampling_rate, + spread=0, + mode="delete", + in_place=True, + deletion_threshold=2, + ) def test_correct_transfer_of_spiketrain_attributes(self): - # for delete=True the spiketrains in the block are changed, # test if their attributes remain correct sampling_rate = 1 / pq.s - spiketrain = neo.SpikeTrain([1, 1, 5, 0] * pq.s, - t_stop=10 * pq.s) + spiketrain = neo.SpikeTrain([1, 1, 5, 0] * pq.s, t_stop=10 * pq.s) block = neo.Block() - group = neo.Group(name='Test Group') + group = neo.Group(name="Test Group") block.groups.append(group) group.spiketrains.append(spiketrain) @@ -432,29 +479,28 @@ def test_correct_transfer_of_spiketrain_attributes(self): spiketrain.annotate(cool_spike_train=True) spiketrain.array_annotate( - spike_number=np.arange(len(spiketrain.times.magnitude))) + spike_number=np.arange(len(spiketrain.times.magnitude)) + ) spiketrain.waveforms = np.sin( np.arange(len(spiketrain.times.magnitude))[:, np.newaxis] - + np.arange(len(spiketrain.times.magnitude))[np.newaxis, :]) + + np.arange(len(spiketrain.times.magnitude))[np.newaxis, :] + ) correct_mask = np.array([False, False, True, True]) # store the correct attributes correct_annotations = spiketrain.annotations.copy() correct_waveforms = spiketrain.waveforms[correct_mask].copy() - correct_array_annotations = {key: value[correct_mask] for key, value in - spiketrain.array_annotations.items()} + correct_array_annotations = { + key: value[correct_mask] + for key, value in spiketrain.array_annotations.items() + } # 
perform a synchrofact search with delete=True synchrofact_obj = Synchrotool( - [spiketrain], - spread=0, - sampling_rate=sampling_rate, - binary=False) - synchrofact_obj.delete_synchrofacts( - mode='delete', - in_place=True, - threshold=2) + [spiketrain], spread=0, sampling_rate=sampling_rate, binary=False + ) + synchrofact_obj.delete_synchrofacts(mode="delete", in_place=True, threshold=2) # Ensure that the spiketrain was not duplicated self.assertEqual(len(block.filter(objects=neo.SpikeTrain)), 1) @@ -468,26 +514,26 @@ def test_correct_transfer_of_spiketrain_attributes(self): cleaned_annotations = cleaned_spiketrain.annotations cleaned_waveforms = cleaned_spiketrain.waveforms cleaned_array_annotations = cleaned_spiketrain.array_annotations - cleaned_array_annotations.pop('complexity') + cleaned_array_annotations.pop("complexity") self.assertDictEqual(correct_annotations, cleaned_annotations) assert_array_almost_equal(cleaned_waveforms, correct_waveforms) - self.assertTrue(len(cleaned_array_annotations) - == len(correct_array_annotations)) + self.assertTrue( + len(cleaned_array_annotations) == len(correct_array_annotations) + ) for key, value in correct_array_annotations.items(): self.assertTrue(key in cleaned_array_annotations.keys()) assert_array_almost_equal(value, cleaned_array_annotations[key]) def test_wrong_input_errors(self): synchrofact_obj = Synchrotool( - [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], - sampling_rate=1/pq.s, + [neo.SpikeTrain([1] * pq.s, t_stop=2 * pq.s)], + sampling_rate=1 / pq.s, binary=True, - spread=1) - self.assertRaises(ValueError, - synchrofact_obj.delete_synchrofacts, - -1) + spread=1, + ) + self.assertRaises(ValueError, synchrofact_obj.delete_synchrofacts, -1) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_sta.py b/elephant/test/test_sta.py index 0b9c7c52e..bcd9ea325 100644 --- a/elephant/test/test_sta.py +++ b/elephant/test/test_sta.py @@ -20,27 +20,46 @@ class sta_TestCase(unittest.TestCase): - def setUp(self): - self.asiga0 = AnalogSignal(np.array([ - np.sin(np.arange(0, 20 * math.pi, 0.1))]).T, - units='mV', sampling_rate=10 / ms) - self.asiga1 = AnalogSignal(np.array([ - np.sin(np.arange(0, 20 * math.pi, 0.1)), - np.cos(np.arange(0, 20 * math.pi, 0.1))]).T, - units='mV', sampling_rate=10 / ms) - self.asiga2 = AnalogSignal(np.array([ - np.sin(np.arange(0, 20 * math.pi, 0.1)), - np.cos(np.arange(0, 20 * math.pi, 0.1)), - np.tan(np.arange(0, 20 * math.pi, 0.1))]).T, - units='mV', sampling_rate=10 / ms) + self.asiga0 = AnalogSignal( + np.array([np.sin(np.arange(0, 20 * math.pi, 0.1))]).T, + units="mV", + sampling_rate=10 / ms, + ) + self.asiga1 = AnalogSignal( + np.array( + [ + np.sin(np.arange(0, 20 * math.pi, 0.1)), + np.cos(np.arange(0, 20 * math.pi, 0.1)), + ] + ).T, + units="mV", + sampling_rate=10 / ms, + ) + self.asiga2 = AnalogSignal( + np.array( + [ + np.sin(np.arange(0, 20 * math.pi, 0.1)), + np.cos(np.arange(0, 20 * math.pi, 0.1)), + np.tan(np.arange(0, 20 * math.pi, 0.1)), + ] + ).T, + units="mV", + sampling_rate=10 / ms, + ) self.st0 = SpikeTrain( [9 * math.pi, 10 * math.pi, 11 * math.pi, 12 * math.pi], - units='ms', t_stop=self.asiga0.t_stop) - self.lst = [SpikeTrain( - [9 * math.pi, 10 * math.pi, 11 * math.pi, 12 * math.pi], - units='ms', t_stop=self.asiga1.t_stop), - SpikeTrain([30, 35, 40], units='ms', t_stop=self.asiga1.t_stop)] + units="ms", + t_stop=self.asiga0.t_stop, + ) + self.lst = [ + SpikeTrain( + [9 * math.pi, 10 * math.pi, 11 * math.pi, 12 * math.pi], + units="ms", + 
t_stop=self.asiga1.t_stop, + ), + SpikeTrain([30, 35, 40], units="ms", t_stop=self.asiga1.t_stop), + ] # *********************************************************************** # ************************ Test for typical values ********************** @@ -49,57 +68,57 @@ def test_spike_triggered_average_with_n_spikes_on_constant_function(self): """Signal should average to the input""" const = 13.8 x = const * np.ones(201) - asiga = AnalogSignal( - np.array([x]).T, units='mV', sampling_rate=10 / ms) - st = SpikeTrain([3, 5.6, 7, 7.1, 16, 16.3], units='ms', t_stop=20) + asiga = AnalogSignal(np.array([x]).T, units="mV", sampling_rate=10 / ms) + st = SpikeTrain([3, 5.6, 7, 7.1, 16, 16.3], units="ms", t_stop=20) window_starttime = -2 * ms window_endtime = 2 * ms - STA = sta.spike_triggered_average( - asiga, st, (window_starttime, window_endtime)) - a = int(((window_endtime - window_starttime) * - asiga.sampling_rate).simplified) - cutout = asiga[0: a] + STA = sta.spike_triggered_average(asiga, st, (window_starttime, window_endtime)) + a = int(((window_endtime - window_starttime) * asiga.sampling_rate).simplified) + cutout = asiga[0:a] cutout.t_start = window_starttime assert_array_almost_equal(STA, cutout, 12) def test_spike_triggered_average_with_shifted_sin_wave(self): """Signal should average to zero""" - STA = sta.spike_triggered_average( - self.asiga0, self.st0, (-4 * ms, 4 * ms)) + STA = sta.spike_triggered_average(self.asiga0, self.st0, (-4 * ms, 4 * ms)) target = 5e-2 * mV - self.assertEqual(np.abs(STA).max().dimensionality.simplified, - pq.Quantity(1, "V").dimensionality.simplified) + self.assertEqual( + np.abs(STA).max().dimensionality.simplified, + pq.Quantity(1, "V").dimensionality.simplified, + ) self.assertLess(np.abs(STA).max(), target) def test_only_one_spike(self): """The output should be the same as the input""" x = np.arange(0, 20, 0.1) - y = x ** 2 + y = x**2 sr = 10 / ms - z = AnalogSignal(np.array([y]).T, units='mV', sampling_rate=sr) + z = AnalogSignal(np.array([y]).T, units="mV", sampling_rate=sr) spiketime = 8 * ms spiketime_in_ms = int((spiketime / ms).simplified) - st = SpikeTrain([spiketime_in_ms], units='ms', t_stop=20) + st = SpikeTrain([spiketime_in_ms], units="ms", t_stop=20) window_starttime = -3 * ms window_endtime = 5 * ms - STA = sta.spike_triggered_average( - z, st, (window_starttime, window_endtime)) - cutout = z[int(((spiketime + window_starttime) * sr).simplified): - int(((spiketime + window_endtime) * sr).simplified)] + STA = sta.spike_triggered_average(z, st, (window_starttime, window_endtime)) + cutout = z[ + int(((spiketime + window_starttime) * sr).simplified) : int( + ((spiketime + window_endtime) * sr).simplified + ) + ] cutout.t_start = window_starttime assert_array_equal(STA, cutout) def test_usage_of_spikes(self): - st = SpikeTrain([16.5 * math.pi, - 17.5 * math.pi, - 18.5 * math.pi, - 19.5 * math.pi], - units='ms', - t_stop=20 * math.pi) + st = SpikeTrain( + [16.5 * math.pi, 17.5 * math.pi, 18.5 * math.pi, 19.5 * math.pi], + units="ms", + t_stop=20 * math.pi, + ) STA = sta.spike_triggered_average( - self.asiga0, st, (-math.pi * ms, math.pi * ms)) - self.assertEqual(STA.annotations['used_spikes'], 3) - self.assertEqual(STA.annotations['unused_spikes'], 1) + self.asiga0, st, (-math.pi * ms, math.pi * ms) + ) + self.assertEqual(STA.annotations["used_spikes"], 3) + self.assertEqual(STA.annotations["unused_spikes"], 1) # *********************************************************************** # **** Test for an invalid value, to check that the 
function raises ***** @@ -108,47 +127,71 @@ def test_usage_of_spikes(self): def test_analog_signal_of_wrong_type(self): """Analog signal given as list, but must be AnalogSignal""" asiga = [0, 1, 2, 3, 4] - self.assertRaises(TypeError, sta.spike_triggered_average, - asiga, self.st0, (-2 * ms, 2 * ms)) + self.assertRaises( + TypeError, sta.spike_triggered_average, asiga, self.st0, (-2 * ms, 2 * ms) + ) def test_spiketrain_of_list_type_in_wrong_sense(self): st = [10, 11, 12] - self.assertRaises(TypeError, sta.spike_triggered_average, - self.asiga0, st, (1 * ms, 2 * ms)) + self.assertRaises( + TypeError, sta.spike_triggered_average, self.asiga0, st, (1 * ms, 2 * ms) + ) def test_spiketrain_of_nonlist_and_nonspiketrain_type(self): st = (10, 11, 12) - self.assertRaises(TypeError, sta.spike_triggered_average, - self.asiga0, st, (1 * ms, 2 * ms)) + self.assertRaises( + TypeError, sta.spike_triggered_average, self.asiga0, st, (1 * ms, 2 * ms) + ) def test_forgotten_AnalogSignal_argument(self): - self.assertRaises(TypeError, sta.spike_triggered_average, - self.st0, (-2 * ms, 2 * ms)) + self.assertRaises( + TypeError, sta.spike_triggered_average, self.st0, (-2 * ms, 2 * ms) + ) def test_one_smaller_nrspiketrains_smaller_nranalogsignals(self): """Number of spiketrains between 1 and number of analogsignals""" - self.assertRaises(ValueError, sta.spike_triggered_average, - self.asiga2, self.lst, (-2 * ms, 2 * ms)) + self.assertRaises( + ValueError, + sta.spike_triggered_average, + self.asiga2, + self.lst, + (-2 * ms, 2 * ms), + ) def test_more_spiketrains_than_analogsignals_forbidden(self): - self.assertRaises(ValueError, sta.spike_triggered_average, - self.asiga0, self.lst, (-2 * ms, 2 * ms)) + self.assertRaises( + ValueError, + sta.spike_triggered_average, + self.asiga0, + self.lst, + (-2 * ms, 2 * ms), + ) def test_spike_earlier_than_analogsignal(self): - st = SpikeTrain([-1 * math.pi, 2 * math.pi], - units='ms', t_start=-2 * math.pi, t_stop=20 * math.pi) - self.assertRaises(ValueError, sta.spike_triggered_average, - self.asiga0, st, (-2 * ms, 2 * ms)) + st = SpikeTrain( + [-1 * math.pi, 2 * math.pi], + units="ms", + t_start=-2 * math.pi, + t_stop=20 * math.pi, + ) + self.assertRaises( + ValueError, sta.spike_triggered_average, self.asiga0, st, (-2 * ms, 2 * ms) + ) def test_spike_later_than_analogsignal(self): - st = SpikeTrain( - [math.pi, 21 * math.pi], units='ms', t_stop=25 * math.pi) - self.assertRaises(ValueError, sta.spike_triggered_average, - self.asiga0, st, (-2 * ms, 2 * ms)) + st = SpikeTrain([math.pi, 21 * math.pi], units="ms", t_stop=25 * math.pi) + self.assertRaises( + ValueError, sta.spike_triggered_average, self.asiga0, st, (-2 * ms, 2 * ms) + ) def test_impossible_window(self): - self.assertRaises(ValueError, sta.spike_triggered_average, - self.asiga0, self.st0, (-2 * ms, -5 * ms)) + self.assertRaises( + ValueError, + sta.spike_triggered_average, + self.asiga0, + self.st0, + (-2 * ms, -5 * ms), + ) def test_window_larger_than_signal(self): self.assertRaises( @@ -156,16 +199,26 @@ def test_window_larger_than_signal(self): sta.spike_triggered_average, self.asiga0, self.st0, - (-15 * math.pi * ms, - 15 * math.pi * ms)) + (-15 * math.pi * ms, 15 * math.pi * ms), + ) def test_wrong_window_starttime_unit(self): - self.assertRaises(TypeError, sta.spike_triggered_average, - self.asiga0, self.st0, (-2 * mV, 2 * ms)) + self.assertRaises( + TypeError, + sta.spike_triggered_average, + self.asiga0, + self.st0, + (-2 * mV, 2 * ms), + ) def test_wrong_window_endtime_unit(self): - 
self.assertRaises(TypeError, sta.spike_triggered_average, - self.asiga0, self.st0, (-2 * ms, 2 * Hz)) + self.assertRaises( + TypeError, + sta.spike_triggered_average, + self.asiga0, + self.st0, + (-2 * ms, 2 * Hz), + ) def test_window_borders_as_complex_numbers(self): self.assertRaises( @@ -173,50 +226,57 @@ def test_window_borders_as_complex_numbers(self): sta.spike_triggered_average, self.asiga0, self.st0, - ((-2 * math.pi + 3j) * ms, - (2 * math.pi + 3j) * ms)) + ((-2 * math.pi + 3j) * ms, (2 * math.pi + 3j) * ms), + ) # *********************************************************************** # **** Test for an empty value (where the argument is a list, array, **** # ********* vector or other container datatype). ************************ def test_empty_analogsignal(self): - asiga = AnalogSignal([], units='mV', sampling_rate=10 / ms) - st = SpikeTrain([5], units='ms', t_stop=10) - self.assertRaises(ValueError, sta.spike_triggered_average, - asiga, st, (-1 * ms, 1 * ms)) + asiga = AnalogSignal([], units="mV", sampling_rate=10 / ms) + st = SpikeTrain([5], units="ms", t_stop=10) + self.assertRaises( + ValueError, sta.spike_triggered_average, asiga, st, (-1 * ms, 1 * ms) + ) def test_one_spiketrain_empty(self): """Test for one empty SpikeTrain, but existing spikes in other""" - st = [SpikeTrain( - [9 * math.pi, 10 * math.pi, 11 * math.pi, 12 * math.pi], - units='ms', t_stop=self.asiga1.t_stop), - SpikeTrain([], units='ms', t_stop=self.asiga1.t_stop)] + st = [ + SpikeTrain( + [9 * math.pi, 10 * math.pi, 11 * math.pi, 12 * math.pi], + units="ms", + t_stop=self.asiga1.t_stop, + ), + SpikeTrain([], units="ms", t_stop=self.asiga1.t_stop), + ] with warnings.catch_warnings(): warnings.simplefilter("ignore") """ Ignore the RuntimeWarning: invalid value encountered in true_divide new_signal = f(other, *args) for the empty SpikeTrain. """ - STA = sta.spike_triggered_average(self.asiga1, - spiketrains=st, - window=(-1 * ms, 1 * ms)) + STA = sta.spike_triggered_average( + self.asiga1, spiketrains=st, window=(-1 * ms, 1 * ms) + ) assert np.isnan(STA.magnitude[:, 1]).all() def test_all_spiketrains_empty(self): - st = SpikeTrain([], units='ms', t_stop=self.asiga1.t_stop) + st = SpikeTrain([], units="ms", t_stop=self.asiga1.t_stop) with warnings.catch_warnings(record=True) as w: # Cause all warnings to always be triggered. warnings.simplefilter("always") # Trigger warnings. 
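To make the quantity under test concrete: the spike-triggered average exercised throughout this file is the mean of signal cut-outs taken around every usable spike. The following plain-NumPy sketch is illustrative only and is not Elephant's sta.spike_triggered_average; spikes whose window would reach past the signal edges are skipped (what the used_spikes/unused_spikes annotations count), and with no usable spikes the result is all-NaN, mirroring test_all_spiketrains_empty.

import numpy as np

def naive_sta(signal, fs, spike_times, window):
    """signal: 1-D array; fs: samples per second;
    spike_times and window=(t_left, t_right): seconds."""
    left, right = int(window[0] * fs), int(window[1] * fs)
    cutouts = []
    for t in spike_times:
        i = int(t * fs)
        if i + left >= 0 and i + right <= signal.size:   # usable spike
            cutouts.append(signal[i + left:i + right])
    if not cutouts:                                      # no spike used -> undefined
        return np.full(right - left, np.nan)
    return np.mean(cutouts, axis=0)

# A constant signal averages to the constant, as in the first test of this file.
sig = np.full(2000, 13.8)
print(naive_sta(sig, fs=10_000, spike_times=[0.05, 0.10, 0.15],
                window=(-0.002, 0.002)).mean())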
- STA = sta.spike_triggered_average( - self.asiga1, st, (-1 * ms, 1 * ms)) - self.assertEqual("No spike at all was either found or used " - "for averaging", str(w[-1].message)) + STA = sta.spike_triggered_average(self.asiga1, st, (-1 * ms, 1 * ms)) + self.assertEqual( + "No spike at all was either found or used " "for averaging", + str(w[-1].message), + ) nan_array = np.empty(20) nan_array.fill(np.nan) - cmp_array = AnalogSignal(np.array([nan_array, nan_array]).T, - units='mV', sampling_rate=10 / ms) + cmp_array = AnalogSignal( + np.array([nan_array, nan_array]).T, units="mV", sampling_rate=10 / ms + ) assert_array_equal(STA.magnitude, cmp_array.magnitude) @@ -224,25 +284,32 @@ def test_all_spiketrains_empty(self): # Tests for new scipy verison (with scipy.signal.coherence) # ========================================================================= -@unittest.skipIf(not hasattr(scipy.signal, 'coherence'), - "Please update scipy " - "to a version >= 0.16") -class sfc_TestCase_new_scipy(unittest.TestCase): +@unittest.skipIf( + not hasattr(scipy.signal, "coherence"), + "Please update scipy " "to a version >= 0.16", +) +class sfc_TestCase_new_scipy(unittest.TestCase): def setUp(self): # standard testsignals tlen0 = 100 * pq.s - f0 = 20. * pq.Hz + f0 = 20.0 * pq.Hz fs0 = 1 * pq.ms - t0 = np.arange( - 0, tlen0.rescale(pq.s).magnitude, - fs0.rescale(pq.s).magnitude) * pq.s + t0 = ( + np.arange(0, tlen0.rescale(pq.s).magnitude, fs0.rescale(pq.s).magnitude) + * pq.s + ) self.anasig0 = AnalogSignal( np.sin(2 * np.pi * (f0 * t0).simplified.magnitude), - units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0) + units=pq.mV, + t_start=0 * pq.ms, + sampling_period=fs0, + ) self.st0 = SpikeTrain( np.arange(0, tlen0.rescale(pq.ms).magnitude, 50) * pq.ms, - t_start=0 * pq.ms, t_stop=tlen0) + t_start=0 * pq.ms, + t_stop=tlen0, + ) self.bst0 = BinnedSpikeTrain(self.st0, bin_size=fs0) # shortened analogsignals @@ -253,82 +320,100 @@ def setUp(self): fs1 = 0.1 * pq.ms self.anasig3 = AnalogSignal( np.sin(2 * np.pi * (f0 * t0).simplified.magnitude), - units=pq.mV, t_start=0 * pq.ms, sampling_period=fs1) + units=pq.mV, + t_start=0 * pq.ms, + sampling_period=fs1, + ) self.bst1 = BinnedSpikeTrain( - self.st0.time_slice(self.anasig3.t_start, self.anasig3.t_stop), - bin_size=fs1) + self.st0.time_slice(self.anasig3.t_start, self.anasig3.t_stop), bin_size=fs1 + ) # analogsignal containing multiple traces self.anasig4 = AnalogSignal( - np.array([ - np.sin(2 * np.pi * (f0 * t0).simplified.magnitude), - np.sin(4 * np.pi * (f0 * t0).simplified.magnitude)]). 
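The spike-field coherence exercised by the sfc_TestCase classes below is the magnitude-squared coherence between an analog signal and a binary-binned spike train; the "new scipy" variant relies on scipy.signal.coherence (available since scipy 0.16). A minimal sketch of that idea, not the elephant.sta.spike_field_coherence implementation itself:

import numpy as np
import scipy.signal

fs = 1000.0                                   # Hz, i.e. a 1 ms sampling period
t = np.arange(0.0, 100.0, 1.0 / fs)           # 100 s of signal
lfp = np.sin(2 * np.pi * 20.0 * t)            # 20 Hz oscillation
spikes = np.zeros_like(lfp)
spikes[::50] = 1.0                            # one spike every 50 ms, locked to 20 Hz

f, cxy = scipy.signal.coherence(lfp, spikes, fs=fs, window="boxcar")
peak = np.argmax(cxy[1:]) + 1                 # skip the DC bin, as the tests do
print(f[peak], cxy[peak])                     # coherence close to 1 near 20 Hz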
- transpose(), - units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0) + np.array( + [ + np.sin(2 * np.pi * (f0 * t0).simplified.magnitude), + np.sin(4 * np.pi * (f0 * t0).simplified.magnitude), + ] + ).transpose(), + units=pq.mV, + t_start=0 * pq.ms, + sampling_period=fs0, + ) # shortened spike train self.st3 = SpikeTrain( np.arange( - (tlen0.rescale(pq.ms).magnitude * .25), - (tlen0.rescale(pq.ms).magnitude * .75), 50) * pq.ms, - t_start=0 * pq.ms, t_stop=tlen0) + (tlen0.rescale(pq.ms).magnitude * 0.25), + (tlen0.rescale(pq.ms).magnitude * 0.75), + 50, + ) + * pq.ms, + t_start=0 * pq.ms, + t_stop=tlen0, + ) self.bst3 = BinnedSpikeTrain(self.st3, bin_size=fs0) - self.st4 = SpikeTrain(np.arange( - (tlen0.rescale(pq.ms).magnitude * .25), - (tlen0.rescale(pq.ms).magnitude * .75), 50) * pq.ms, - t_start=5 * fs0, t_stop=tlen0 - 5 * fs0) + self.st4 = SpikeTrain( + np.arange( + (tlen0.rescale(pq.ms).magnitude * 0.25), + (tlen0.rescale(pq.ms).magnitude * 0.75), + 50, + ) + * pq.ms, + t_start=5 * fs0, + t_stop=tlen0 - 5 * fs0, + ) self.bst4 = BinnedSpikeTrain(self.st4, bin_size=fs0) # spike train with incompatible bin_size - self.bst5 = BinnedSpikeTrain(self.st3, bin_size=fs0 * 2.) + self.bst5 = BinnedSpikeTrain(self.st3, bin_size=fs0 * 2.0) # spike train with same bin_size as the analog signal, but with # bin edges not aligned to the time axis of the analog signal self.bst6 = BinnedSpikeTrain( - self.st3, - bin_size=fs0, - t_start=4.5 * fs0, - t_stop=tlen0 - 4.5 * fs0) + self.st3, bin_size=fs0, t_start=4.5 * fs0, t_stop=tlen0 - 4.5 * fs0 + ) # ========================================================================= # Tests for correct input handling # ========================================================================= def test_wrong_input_type(self): - self.assertRaises(TypeError, - sta.spike_field_coherence, - np.array([1, 2, 3]), self.bst0) - self.assertRaises(TypeError, - sta.spike_field_coherence, - self.anasig0, [1, 2, 3]) - self.assertRaises(ValueError, - sta.spike_field_coherence, - self.anasig0.duplicate_with_new_data([]), self.bst0) + self.assertRaises( + TypeError, sta.spike_field_coherence, np.array([1, 2, 3]), self.bst0 + ) + self.assertRaises(TypeError, sta.spike_field_coherence, self.anasig0, [1, 2, 3]) + self.assertRaises( + ValueError, + sta.spike_field_coherence, + self.anasig0.duplicate_with_new_data([]), + self.bst0, + ) def test_start_stop_times_out_of_range(self): - self.assertRaises(ValueError, - sta.spike_field_coherence, - self.anasig1, self.bst0) + self.assertRaises( + ValueError, sta.spike_field_coherence, self.anasig1, self.bst0 + ) - self.assertRaises(ValueError, - sta.spike_field_coherence, - self.anasig2, self.bst0) + self.assertRaises( + ValueError, sta.spike_field_coherence, self.anasig2, self.bst0 + ) def test_non_matching_input_binning(self): - self.assertRaises(ValueError, - sta.spike_field_coherence, - self.anasig0, self.bst1) + self.assertRaises( + ValueError, sta.spike_field_coherence, self.anasig0, self.bst1 + ) def test_incompatible_spiketrain_analogsignal(self): # These spike trains have incompatible binning (bin_size or alignment # to time axis of analog signal) - self.assertRaises(ValueError, - sta.spike_field_coherence, - self.anasig0, self.bst5) - self.assertRaises(ValueError, - sta.spike_field_coherence, - self.anasig0, self.bst6) + self.assertRaises( + ValueError, sta.spike_field_coherence, self.anasig0, self.bst5 + ) + self.assertRaises( + ValueError, sta.spike_field_coherence, self.anasig0, self.bst6 + ) def test_signal_dimensions(self): # 
single analogsignal trace and single spike train @@ -353,11 +438,11 @@ def test_signal_dimensions(self): def test_non_binned_spiketrain_input(self): s, f = sta.spike_field_coherence(self.anasig0, self.st0) - f_ind = np.where(f >= 19.)[0][0] + f_ind = np.where(f >= 19.0)[0][0] max_ind = np.argmax(s[1:]) + 1 self.assertEqual(f_ind, max_ind) - self.assertAlmostEqual(s[f_ind], 1., delta=0.01) + self.assertAlmostEqual(s[f_ind], 1.0, delta=0.01) # ========================================================================= # Tests for correct return values @@ -372,14 +457,13 @@ def test_spike_field_coherence_perfect_coherence(self): warning RuntimeWarning: invalid value encountered in true_divide Cxy = np.abs(Pxy)**2 / Pxx / Pyy. """ - s, f = sta.spike_field_coherence( - self.anasig0, self.bst0, window='boxcar') + s, f = sta.spike_field_coherence(self.anasig0, self.bst0, window="boxcar") - f_ind = np.where(f >= 19.)[0][0] + f_ind = np.where(f >= 19.0)[0][0] max_ind = np.argmax(s[1:]) + 1 self.assertEqual(f_ind, max_ind) - self.assertAlmostEqual(s[f_ind], 1., delta=0.01) + self.assertAlmostEqual(s[f_ind], 1.0, delta=0.01) def test_output_frequencies(self): nfft = 256 @@ -388,10 +472,8 @@ def test_output_frequencies(self): # check number of frequency samples self.assertEqual(len(f), nfft / 2 + 1) - f_max = self.anasig3.sampling_rate.rescale('Hz').magnitude / 2 - f_ground_truth = np.linspace(start=0, - stop=f_max, - num=nfft // 2 + 1) * pq.Hz + f_max = self.anasig3.sampling_rate.rescale("Hz").magnitude / 2 + f_ground_truth = np.linspace(start=0, stop=f_max, num=nfft // 2 + 1) * pq.Hz # check values of frequency samples assert_array_almost_equal(f, f_ground_truth) @@ -405,13 +487,11 @@ def test_short_spiketrain(self): warning RuntimeWarning: invalid value encountered in true_divide Cxy = np.abs(Pxy)**2 / Pxx / Pyy. """ - s1, f1 = sta.spike_field_coherence( - self.anasig0, self.bst3, window='boxcar') + s1, f1 = sta.spike_field_coherence(self.anasig0, self.bst3, window="boxcar") # this spike train has the same spikes as above, # but it's shorter than anasig0 - s2, f2 = sta.spike_field_coherence( - self.anasig0, self.bst4, window='boxcar') + s2, f2 = sta.spike_field_coherence(self.anasig0, self.bst4, window="boxcar") # the results above should be the same, nevertheless assert_array_equal(s1.magnitude, s2.magnitude) @@ -422,30 +502,38 @@ def test_short_spiketrain(self): # Tests for old scipy verison (without scipy.signal.coherence) # ========================================================================= -@unittest.skipIf(hasattr(scipy.signal, 'coherence'), 'Applies only for old ' - 'scipy versions (<0.16)') -class sfc_TestCase_old_scipy(unittest.TestCase): +@unittest.skipIf( + hasattr(scipy.signal, "coherence"), "Applies only for old " "scipy versions (<0.16)" +) +class sfc_TestCase_old_scipy(unittest.TestCase): def setUp(self): # standard testsignals tlen0 = 100 * pq.s - f0 = 20. 
* pq.Hz + f0 = 20.0 * pq.Hz fs0 = 1 * pq.ms - t0 = np.arange( - 0, tlen0.rescale(pq.s).magnitude, - fs0.rescale(pq.s).magnitude) * pq.s + t0 = ( + np.arange(0, tlen0.rescale(pq.s).magnitude, fs0.rescale(pq.s).magnitude) + * pq.s + ) self.anasig0 = AnalogSignal( np.sin(2 * np.pi * (f0 * t0).simplified.magnitude), - units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0) + units=pq.mV, + t_start=0 * pq.ms, + sampling_period=fs0, + ) self.st0 = SpikeTrain( np.arange(0, tlen0.rescale(pq.ms).magnitude, 50) * pq.ms, - t_start=0 * pq.ms, t_stop=tlen0) + t_start=0 * pq.ms, + t_stop=tlen0, + ) self.bst0 = BinnedSpikeTrain(self.st0, bin_size=fs0) def test_old_scipy_version(self): - self.assertRaises(AttributeError, sta.spike_field_coherence, - self.anasig0, self.bst0) + self.assertRaises( + AttributeError, sta.spike_field_coherence, self.anasig0, self.bst0 + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 426111810..8a67888eb 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -5,6 +5,7 @@ :copyright: Copyright 2014-2024 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. """ + from __future__ import division import itertools @@ -15,8 +16,11 @@ import numpy as np import quantities as pq import scipy.integrate as spint -from numpy.testing import assert_array_almost_equal, assert_array_equal, \ - assert_array_less +from numpy.testing import ( + assert_array_almost_equal, + assert_array_equal, + assert_array_less, +) import elephant.kernels as kernels from elephant import statistics from elephant.spike_train_generation import StationaryPoissonProcess @@ -26,29 +30,33 @@ class IsiTestCase(unittest.TestCase): def setUp(self): - self.test_array_2d = np.array([[0.3, 0.56, 0.87, 1.23], - [0.02, 0.71, 1.82, 8.46], - [0.03, 0.14, 0.15, 0.92]]) - self.targ_array_2d_0 = np.array([[-0.28, 0.15, 0.95, 7.23], - [0.01, -0.57, -1.67, -7.54]]) - self.targ_array_2d_1 = np.array([[0.26, 0.31, 0.36], - [0.69, 1.11, 6.64], - [0.11, 0.01, 0.77]]) + self.test_array_2d = np.array( + [ + [0.3, 0.56, 0.87, 1.23], + [0.02, 0.71, 1.82, 8.46], + [0.03, 0.14, 0.15, 0.92], + ] + ) + self.targ_array_2d_0 = np.array( + [[-0.28, 0.15, 0.95, 7.23], [0.01, -0.57, -1.67, -7.54]] + ) + self.targ_array_2d_1 = np.array( + [[0.26, 0.31, 0.36], [0.69, 1.11, 6.64], [0.11, 0.01, 0.77]] + ) self.targ_array_2d_default = self.targ_array_2d_1 self.test_array_1d = self.test_array_2d[0, :] self.targ_array_1d = self.targ_array_2d_1[0, :] def test_isi_with_spiketrain(self): - st = neo.SpikeTrain( - self.test_array_1d, units='ms', t_stop=10.0, t_start=0.29) - target = pq.Quantity(self.targ_array_1d, 'ms') + st = neo.SpikeTrain(self.test_array_1d, units="ms", t_stop=10.0, t_start=0.29) + target = pq.Quantity(self.targ_array_1d, "ms") res = statistics.isi(st) assert_array_almost_equal(res, target, decimal=9) def test_isi_with_quantities_1d(self): - st = pq.Quantity(self.test_array_1d, units='ms') - target = pq.Quantity(self.targ_array_1d, 'ms') + st = pq.Quantity(self.test_array_1d, units="ms") + target = pq.Quantity(self.targ_array_1d, "ms") res = statistics.isi(st) assert_array_almost_equal(res, target, decimal=9) @@ -92,7 +100,7 @@ def setUp(self): self.test_array_regular = np.arange(1, 6) def test_cv_isi_regular_spiketrain_is_zero(self): - st = neo.SpikeTrain(self.test_array_regular, units='ms', t_stop=10.0) + st = neo.SpikeTrain(self.test_array_regular, 
units="ms", t_stop=10.0) targ = 0.0 res = statistics.cv(statistics.isi(st)) self.assertEqual(res, targ) @@ -107,9 +115,13 @@ def test_cv_isi_regular_array_is_zero(self): class MeanFiringRateTestCase(unittest.TestCase): def setUp(self): self.test_array_3d = np.ones([5, 7, 13]) - self.test_array_2d = np.array([[0.3, 0.56, 0.87, 1.23], - [0.02, 0.71, 1.82, 8.46], - [0.03, 0.14, 0.15, 0.92]]) + self.test_array_2d = np.array( + [ + [0.3, 0.56, 0.87, 1.23], + [0.02, 0.71, 1.82, 8.46], + [0.03, 0.14, 0.15, 0.92], + ] + ) self.targ_array_2d_0 = np.array([3, 3, 3, 3]) self.targ_array_2d_1 = np.array([4, 4, 4]) @@ -129,48 +141,48 @@ def test_invalid_input_spiketrain(self): # empty spiketrain self.assertRaises(ValueError, statistics.mean_firing_rate, []) for st_invalid in (None, 0.1): - self.assertRaises(TypeError, statistics.mean_firing_rate, - st_invalid) + self.assertRaises(TypeError, statistics.mean_firing_rate, st_invalid) def test_mean_firing_rate_with_spiketrain(self): - st = neo.SpikeTrain(self.test_array_1d, units='ms', t_stop=10.0) - target = pq.Quantity(self.targ_array_1d / 10., '1/ms') + st = neo.SpikeTrain(self.test_array_1d, units="ms", t_stop=10.0) + target = pq.Quantity(self.targ_array_1d / 10.0, "1/ms") res = statistics.mean_firing_rate(st) assert_array_almost_equal(res, target, decimal=9) def test_mean_firing_rate_typical_use_case(self): np.random.seed(92) st = StationaryPoissonProcess( - rate=100 * pq.Hz, t_stop=100 * pq.s).generate_spiketrain() + rate=100 * pq.Hz, t_stop=100 * pq.s + ).generate_spiketrain() rate1 = statistics.mean_firing_rate(st) - rate2 = statistics.mean_firing_rate(st, t_start=st.t_start, - t_stop=st.t_stop) + rate2 = statistics.mean_firing_rate(st, t_start=st.t_start, t_stop=st.t_stop) self.assertEqual(rate1.units, rate2.units) self.assertAlmostEqual(rate1.item(), rate2.item()) def test_mean_firing_rate_with_spiketrain_set_ends(self): - st = neo.SpikeTrain(self.test_array_1d, units='ms', t_stop=10.0) - target = pq.Quantity(2 / 0.5, '1/ms') - res = statistics.mean_firing_rate(st, t_start=0.4 * pq.ms, - t_stop=0.9 * pq.ms) + st = neo.SpikeTrain(self.test_array_1d, units="ms", t_stop=10.0) + target = pq.Quantity(2 / 0.5, "1/ms") + res = statistics.mean_firing_rate(st, t_start=0.4 * pq.ms, t_stop=0.9 * pq.ms) assert_array_almost_equal(res, target, decimal=9) def test_mean_firing_rate_with_quantities_1d(self): - st = pq.Quantity(self.test_array_1d, units='ms') - target = pq.Quantity(self.targ_array_1d / self.max_array_1d, '1/ms') + st = pq.Quantity(self.test_array_1d, units="ms") + target = pq.Quantity(self.targ_array_1d / self.max_array_1d, "1/ms") res = statistics.mean_firing_rate(st) assert_array_almost_equal(res, target, decimal=9) def test_mean_firing_rate_with_quantities_1d_set_ends(self): - st = pq.Quantity(self.test_array_1d, units='ms') + st = pq.Quantity(self.test_array_1d, units="ms") # t_stop is not a Quantity - self.assertRaises(TypeError, statistics.mean_firing_rate, st, - t_start=400 * pq.us, t_stop=1.) + self.assertRaises( + TypeError, statistics.mean_firing_rate, st, t_start=400 * pq.us, t_stop=1.0 + ) # t_start is not a Quantity - self.assertRaises(TypeError, statistics.mean_firing_rate, st, - t_start=0.4, t_stop=1. 
* pq.ms) + self.assertRaises( + TypeError, statistics.mean_firing_rate, st, t_start=0.4, t_stop=1.0 * pq.ms + ) def test_mean_firing_rate_with_plain_array_1d(self): st = self.test_array_1d @@ -209,37 +221,36 @@ def test_mean_firing_rate_with_plain_array_2d_1(self): def test_mean_firing_rate_with_plain_array_3d_None(self): st = self.test_array_3d - target = np.sum(self.test_array_3d, None) / 5. - res = statistics.mean_firing_rate(st, axis=None, t_stop=5.) + target = np.sum(self.test_array_3d, None) / 5.0 + res = statistics.mean_firing_rate(st, axis=None, t_stop=5.0) assert not isinstance(res, pq.Quantity) assert_array_almost_equal(res, target, decimal=9) def test_mean_firing_rate_with_plain_array_3d_0(self): st = self.test_array_3d - target = np.sum(self.test_array_3d, 0) / 5. - res = statistics.mean_firing_rate(st, axis=0, t_stop=5.) + target = np.sum(self.test_array_3d, 0) / 5.0 + res = statistics.mean_firing_rate(st, axis=0, t_stop=5.0) assert not isinstance(res, pq.Quantity) assert_array_almost_equal(res, target, decimal=9) def test_mean_firing_rate_with_plain_array_3d_1(self): st = self.test_array_3d - target = np.sum(self.test_array_3d, 1) / 5. - res = statistics.mean_firing_rate(st, axis=1, t_stop=5.) + target = np.sum(self.test_array_3d, 1) / 5.0 + res = statistics.mean_firing_rate(st, axis=1, t_stop=5.0) assert not isinstance(res, pq.Quantity) assert_array_almost_equal(res, target, decimal=9) def test_mean_firing_rate_with_plain_array_3d_2(self): st = self.test_array_3d - target = np.sum(self.test_array_3d, 2) / 5. - res = statistics.mean_firing_rate(st, axis=2, t_stop=5.) + target = np.sum(self.test_array_3d, 2) / 5.0 + res = statistics.mean_firing_rate(st, axis=2, t_stop=5.0) assert not isinstance(res, pq.Quantity) assert_array_almost_equal(res, target, decimal=9) def test_mean_firing_rate_with_plain_array_2d_1_set_ends(self): st = self.test_array_2d target = np.array([4, 1, 3]) / (1.23 - 0.14) - res = statistics.mean_firing_rate(st, axis=1, t_start=0.14, - t_stop=1.23) + res = statistics.mean_firing_rate(st, axis=1, t_start=0.14, t_stop=1.23) assert not isinstance(res, pq.Quantity) assert_array_almost_equal(res, target, decimal=9) @@ -250,22 +261,35 @@ def test_mean_firing_rate_with_plain_array_2d_None(self): assert not isinstance(res, pq.Quantity) assert_array_almost_equal(res, target, decimal=9) - def test_mean_firing_rate_with_plain_array_and_units_start_stop_typeerror( - self): + def test_mean_firing_rate_with_plain_array_and_units_start_stop_typeerror(self): st = self.test_array_2d - self.assertRaises(TypeError, statistics.mean_firing_rate, st, - t_start=pq.Quantity(0, 'ms')) - self.assertRaises(TypeError, statistics.mean_firing_rate, st, - t_stop=pq.Quantity(10, 'ms')) - self.assertRaises(TypeError, statistics.mean_firing_rate, st, - t_start=pq.Quantity(0, 'ms'), - t_stop=pq.Quantity(10, 'ms')) - self.assertRaises(TypeError, statistics.mean_firing_rate, st, - t_start=pq.Quantity(0, 'ms'), - t_stop=10.) 
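For reference, the quantity these MeanFiringRateTestCase checks target is just the spike count inside [t_start, t_stop] divided by that duration, taken along the requested axis for 2-D and 3-D inputs. A plain-NumPy stand-in (illustrative only, not statistics.mean_firing_rate itself) reproduces the set_ends target shown above:

import numpy as np

def naive_mean_firing_rate(spike_times, t_start=None, t_stop=None, axis=None):
    spike_times = np.asarray(spike_times)
    t_start = 0.0 if t_start is None else t_start
    t_stop = spike_times.max() if t_stop is None else t_stop
    in_window = (spike_times >= t_start) & (spike_times <= t_stop)
    return np.sum(in_window, axis=axis) / (t_stop - t_start)

arr = np.array([[0.3, 0.56, 0.87, 1.23],
                [0.02, 0.71, 1.82, 8.46],
                [0.03, 0.14, 0.15, 0.92]])
print(naive_mean_firing_rate(arr, axis=1, t_start=0.14, t_stop=1.23))
# [4, 1, 3] spikes per row over 1.09 time units, matching the target above.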
- self.assertRaises(TypeError, statistics.mean_firing_rate, st, - t_start=0., - t_stop=pq.Quantity(10, 'ms')) + self.assertRaises( + TypeError, statistics.mean_firing_rate, st, t_start=pq.Quantity(0, "ms") + ) + self.assertRaises( + TypeError, statistics.mean_firing_rate, st, t_stop=pq.Quantity(10, "ms") + ) + self.assertRaises( + TypeError, + statistics.mean_firing_rate, + st, + t_start=pq.Quantity(0, "ms"), + t_stop=pq.Quantity(10, "ms"), + ) + self.assertRaises( + TypeError, + statistics.mean_firing_rate, + st, + t_start=pq.Quantity(0, "ms"), + t_stop=10.0, + ) + self.assertRaises( + TypeError, + statistics.mean_firing_rate, + st, + t_start=0.0, + t_stop=pq.Quantity(10, "ms"), + ) class FanoFactorTestCase(unittest.TestCase): @@ -279,9 +303,9 @@ def setUp(self): self.sp_counts = np.zeros(num_st) for i in range(num_st): r = np.random.rand(np.random.randint(20) + 1) - st = neo.core.SpikeTrain(r * pq.ms, - t_start=0.0 * pq.ms, - t_stop=20.0 * pq.ms) + st = neo.core.SpikeTrain( + r * pq.ms, t_start=0.0 * pq.ms, t_stop=20.0 * pq.ms + ) self.test_spiketrains.append(st) self.test_array.append(r) self.test_quantity.append(r * pq.ms) @@ -293,7 +317,8 @@ def test_fanofactor_spiketrains(self): # Test with list of spiketrains self.assertEqual( np.var(self.sp_counts) / np.mean(self.sp_counts), - statistics.fanofactor(self.test_spiketrains)) + statistics.fanofactor(self.test_spiketrains), + ) # One spiketrain in list st = self.test_spiketrains[0] @@ -308,8 +333,7 @@ def test_fanofactor_empty(self): self.assertTrue(np.isnan(statistics.fanofactor([] * pq.ms))) # Empty spiketrain - st = neo.core.SpikeTrain([] * pq.ms, t_start=0 * pq.ms, - t_stop=1.5 * pq.ms) + st = neo.core.SpikeTrain([] * pq.ms, t_start=0 * pq.ms, t_stop=1.5 * pq.ms) self.assertTrue(np.isnan(statistics.fanofactor(st))) def test_fanofactor_spiketrains_same(self): @@ -318,24 +342,30 @@ def test_fanofactor_spiketrains_same(self): self.assertEqual(statistics.fanofactor(sts), 0.0) def test_fanofactor_array(self): - self.assertEqual(statistics.fanofactor(self.test_array), - np.var(self.sp_counts) / np.mean(self.sp_counts)) + self.assertEqual( + statistics.fanofactor(self.test_array), + np.var(self.sp_counts) / np.mean(self.sp_counts), + ) def test_fanofactor_array_same(self): lst = [self.test_array[0]] * 3 self.assertEqual(statistics.fanofactor(lst), 0.0) def test_fanofactor_quantity(self): - self.assertEqual(statistics.fanofactor(self.test_quantity), - np.var(self.sp_counts) / np.mean(self.sp_counts)) + self.assertEqual( + statistics.fanofactor(self.test_quantity), + np.var(self.sp_counts) / np.mean(self.sp_counts), + ) def test_fanofactor_quantity_same(self): lst = [self.test_quantity[0]] * 3 self.assertEqual(statistics.fanofactor(lst), 0.0) def test_fanofactor_list(self): - self.assertEqual(statistics.fanofactor(self.test_list), - np.var(self.sp_counts) / np.mean(self.sp_counts)) + self.assertEqual( + statistics.fanofactor(self.test_list), + np.var(self.sp_counts) / np.mean(self.sp_counts), + ) def test_fanofactor_list_same(self): lst = [self.test_list[0]] * 3 @@ -349,27 +379,118 @@ def test_fanofactor_different_durations(self): def test_fanofactor_wrong_type(self): # warn_tolerance is not a quantity st1 = neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=4 * pq.s) - self.assertRaises(TypeError, statistics.fanofactor, [st1], - warn_tolerance=1e-4) + self.assertRaises(TypeError, statistics.fanofactor, [st1], warn_tolerance=1e-4) class LVTestCase(unittest.TestCase): def setUp(self): - self.test_seq = [1, 28, 4, 47, 5, 16, 2, 5, 21, 12, - 4, 12, 59, 2, 
4, 18, 33, 25, 2, 34, - 4, 1, 1, 14, 8, 1, 10, 1, 8, 20, - 5, 1, 6, 5, 12, 2, 8, 8, 2, 8, - 2, 10, 2, 1, 1, 2, 15, 3, 20, 6, - 11, 6, 18, 2, 5, 17, 4, 3, 13, 6, - 1, 18, 1, 16, 12, 2, 52, 2, 5, 7, - 6, 25, 6, 5, 3, 15, 4, 3, 16, 3, - 6, 5, 24, 21, 3, 3, 4, 8, 4, 11, - 5, 7, 5, 6, 8, 11, 33, 10, 7, 4] + self.test_seq = [ + 1, + 28, + 4, + 47, + 5, + 16, + 2, + 5, + 21, + 12, + 4, + 12, + 59, + 2, + 4, + 18, + 33, + 25, + 2, + 34, + 4, + 1, + 1, + 14, + 8, + 1, + 10, + 1, + 8, + 20, + 5, + 1, + 6, + 5, + 12, + 2, + 8, + 8, + 2, + 8, + 2, + 10, + 2, + 1, + 1, + 2, + 15, + 3, + 20, + 6, + 11, + 6, + 18, + 2, + 5, + 17, + 4, + 3, + 13, + 6, + 1, + 18, + 1, + 16, + 12, + 2, + 52, + 2, + 5, + 7, + 6, + 25, + 6, + 5, + 3, + 15, + 4, + 3, + 16, + 3, + 6, + 5, + 24, + 21, + 3, + 3, + 4, + 8, + 4, + 11, + 5, + 7, + 5, + 6, + 8, + 11, + 33, + 10, + 7, + 4, + ] self.target = 0.971826029994 def test_lv_with_quantities(self): - seq = pq.Quantity(self.test_seq, units='ms') + seq = pq.Quantity(self.test_seq, units="ms") assert_array_almost_equal(statistics.lv(seq), self.target, decimal=9) def test_lv_with_plain_array(self): @@ -396,36 +517,126 @@ def test_2short_spike_train(self): class LVRTestCase(unittest.TestCase): def setUp(self): - self.test_seq = [1, 28, 4, 47, 5, 16, 2, 5, 21, 12, - 4, 12, 59, 2, 4, 18, 33, 25, 2, 34, - 4, 1, 1, 14, 8, 1, 10, 1, 8, 20, - 5, 1, 6, 5, 12, 2, 8, 8, 2, 8, - 2, 10, 2, 1, 1, 2, 15, 3, 20, 6, - 11, 6, 18, 2, 5, 17, 4, 3, 13, 6, - 1, 18, 1, 16, 12, 2, 52, 2, 5, 7, - 6, 25, 6, 5, 3, 15, 4, 3, 16, 3, - 6, 5, 24, 21, 3, 3, 4, 8, 4, 11, - 5, 7, 5, 6, 8, 11, 33, 10, 7, 4] + self.test_seq = [ + 1, + 28, + 4, + 47, + 5, + 16, + 2, + 5, + 21, + 12, + 4, + 12, + 59, + 2, + 4, + 18, + 33, + 25, + 2, + 34, + 4, + 1, + 1, + 14, + 8, + 1, + 10, + 1, + 8, + 20, + 5, + 1, + 6, + 5, + 12, + 2, + 8, + 8, + 2, + 8, + 2, + 10, + 2, + 1, + 1, + 2, + 15, + 3, + 20, + 6, + 11, + 6, + 18, + 2, + 5, + 17, + 4, + 3, + 13, + 6, + 1, + 18, + 1, + 16, + 12, + 2, + 52, + 2, + 5, + 7, + 6, + 25, + 6, + 5, + 3, + 15, + 4, + 3, + 16, + 3, + 6, + 5, + 24, + 21, + 3, + 3, + 4, + 8, + 4, + 11, + 5, + 7, + 5, + 6, + 8, + 11, + 33, + 10, + 7, + 4, + ] self.target = 2.1845363464753134 def test_lvr_with_quantities(self): - seq = pq.Quantity(self.test_seq, units='ms') + seq = pq.Quantity(self.test_seq, units="ms") assert_array_almost_equal(statistics.lvr(seq), self.target, decimal=9) - seq = pq.Quantity(self.test_seq, units='ms').rescale('s', dtype=float) + seq = pq.Quantity(self.test_seq, units="ms").rescale("s", dtype=float) assert_array_almost_equal(statistics.lvr(seq), self.target, decimal=9) def test_lvr_with_plain_array(self): seq = np.array(self.test_seq) with self.assertWarns(UserWarning): - assert_array_almost_equal(statistics.lvr(seq), - self.target, decimal=9) + assert_array_almost_equal(statistics.lvr(seq), self.target, decimal=9) def test_lvr_with_list(self): seq = self.test_seq with self.assertWarns(UserWarning): - assert_array_almost_equal(statistics.lvr(seq), - self.target, decimal=9) + assert_array_almost_equal(statistics.lvr(seq), self.target, decimal=9) def test_lvr_raise_error(self): seq = self.test_seq @@ -437,8 +648,7 @@ def test_lvr_raise_error(self): def test_lvr_refractoriness_kwarg(self): seq = np.array(self.test_seq) with self.assertWarns(UserWarning): - assert_array_almost_equal(statistics.lvr(seq, R=5), - self.target, decimal=9) + assert_array_almost_equal(statistics.lvr(seq, R=5), self.target, decimal=9) def test_2short_spike_train(self): seq = [1] @@ -450,21 +660,113 @@ def 
test_2short_spike_train(self): class CV2TestCase(unittest.TestCase): def setUp(self): - self.test_seq = [1, 28, 4, 47, 5, 16, 2, 5, 21, 12, - 4, 12, 59, 2, 4, 18, 33, 25, 2, 34, - 4, 1, 1, 14, 8, 1, 10, 1, 8, 20, - 5, 1, 6, 5, 12, 2, 8, 8, 2, 8, - 2, 10, 2, 1, 1, 2, 15, 3, 20, 6, - 11, 6, 18, 2, 5, 17, 4, 3, 13, 6, - 1, 18, 1, 16, 12, 2, 52, 2, 5, 7, - 6, 25, 6, 5, 3, 15, 4, 3, 16, 3, - 6, 5, 24, 21, 3, 3, 4, 8, 4, 11, - 5, 7, 5, 6, 8, 11, 33, 10, 7, 4] + self.test_seq = [ + 1, + 28, + 4, + 47, + 5, + 16, + 2, + 5, + 21, + 12, + 4, + 12, + 59, + 2, + 4, + 18, + 33, + 25, + 2, + 34, + 4, + 1, + 1, + 14, + 8, + 1, + 10, + 1, + 8, + 20, + 5, + 1, + 6, + 5, + 12, + 2, + 8, + 8, + 2, + 8, + 2, + 10, + 2, + 1, + 1, + 2, + 15, + 3, + 20, + 6, + 11, + 6, + 18, + 2, + 5, + 17, + 4, + 3, + 13, + 6, + 1, + 18, + 1, + 16, + 12, + 2, + 52, + 2, + 5, + 7, + 6, + 25, + 6, + 5, + 3, + 15, + 4, + 3, + 16, + 3, + 6, + 5, + 24, + 21, + 3, + 3, + 4, + 8, + 4, + 11, + 5, + 7, + 5, + 6, + 8, + 11, + 33, + 10, + 7, + 4, + ] self.target = 1.0022235296529176 def test_cv2_with_quantities(self): - seq = pq.Quantity(self.test_seq, units='ms') + seq = pq.Quantity(self.test_seq, units="ms") assert_array_almost_equal(statistics.cv2(seq), self.target, decimal=9) def test_cv2_with_plain_array(self): @@ -483,7 +785,6 @@ def test_cv2_raise_error(self): class InstantaneousRateTest(unittest.TestCase): - @classmethod def setUpClass(cls) -> None: """ @@ -492,8 +793,7 @@ def setUpClass(cls) -> None: block = _create_trials_block(n_trials=36) cls.block = block - cls.trial_object = TrialsFromBlock(block, - description='trials are segments') + cls.trial_object = TrialsFromBlock(block, description="trials are segments") # create a poisson spike train: cls.st_tr = (0, 20.0) # seconds @@ -502,87 +802,123 @@ def setUpClass(cls) -> None: cls.st_rate = 10.0 # Hertz np.random.seed(19) duration_effective = cls.st_dur - 2 * cls.st_margin - st_num_spikes = np.random.poisson( - cls.st_rate * duration_effective) + st_num_spikes = np.random.poisson(cls.st_rate * duration_effective) spike_train = sorted( - np.random.rand(st_num_spikes) * - duration_effective + - cls.st_margin) + np.random.rand(st_num_spikes) * duration_effective + cls.st_margin + ) # convert spike train into neo objects - cls.spike_train = neo.SpikeTrain(spike_train * pq.s, - t_start=cls.st_tr[0] * pq.s, - t_stop=cls.st_tr[1] * pq.s) + cls.spike_train = neo.SpikeTrain( + spike_train * pq.s, t_start=cls.st_tr[0] * pq.s, t_stop=cls.st_tr[1] * pq.s + ) # generation of a multiply used specific kernel cls.kernel = kernels.TriangularKernel(sigma=0.03 * pq.s) # calculate instantaneous rate cls.sampling_period = 0.01 * pq.s cls.inst_rate = statistics.instantaneous_rate( - cls.spike_train, cls.sampling_period, cls.kernel, cutoff=0) + cls.spike_train, cls.sampling_period, cls.kernel, cutoff=0 + ) def test_instantaneous_rate_warnings(self): with self.assertWarns(UserWarning): # Catches warning: The width of the kernel was adjusted to a # minimally allowed width. 
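The InstantaneousRateTest cases that follow all probe the same kernel-smoothing idea: bin the spikes at the sampling period and convolve with a normalised kernel to obtain a rate in spikes per second. The sketch below is a deliberately simplified stand-in; it ignores units, kernel centring, trimming and border correction, which Elephant handles explicitly.

import numpy as np

def naive_instantaneous_rate(spike_times, t_start, t_stop, sampling_period, sigma):
    edges = np.arange(t_start, t_stop + sampling_period, sampling_period)
    counts, _ = np.histogram(spike_times, bins=edges)
    # Gaussian kernel sampled on the same grid, normalised to unit area.
    half = int(np.ceil(3 * sigma / sampling_period))
    support = np.arange(-half, half + 1) * sampling_period
    kernel = np.exp(-0.5 * (support / sigma) ** 2)
    kernel /= kernel.sum() * sampling_period
    return np.convolve(counts, kernel, mode="same")      # units: 1/s

rate = naive_instantaneous_rate(
    spike_times=np.array([0.5, 0.7, 1.2, 3.1]),
    t_start=0.0, t_stop=5.0, sampling_period=0.01, sigma=0.03)
# The rate integrates to roughly the spike count -- the property that
# test_instantaneous_rate_rate_estimation_consistency asserts below.
print(rate.sum() * 0.01)                                  # ~4.0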
- statistics.instantaneous_rate(self.spike_train, - self.sampling_period, - self.kernel, cutoff=0) + statistics.instantaneous_rate( + self.spike_train, self.sampling_period, self.kernel, cutoff=0 + ) def test_instantaneous_rate_errors(self): self.assertRaises( # input is not neo.SpikeTrain - TypeError, statistics.instantaneous_rate, + TypeError, + statistics.instantaneous_rate, spiketrains=[1, 2, 3] * pq.s, - sampling_period=0.01 * pq.ms, kernel=self.kernel) + sampling_period=0.01 * pq.ms, + kernel=self.kernel, + ) self.assertRaises( # sampling period is not time quantity - TypeError, statistics.instantaneous_rate, - spiketrains=self.spike_train, kernel=self.kernel, - sampling_period=0.01) + TypeError, + statistics.instantaneous_rate, + spiketrains=self.spike_train, + kernel=self.kernel, + sampling_period=0.01, + ) self.assertRaises( # sampling period is < 0 - ValueError, statistics.instantaneous_rate, - spiketrains=self.spike_train, kernel=self.kernel, - sampling_period=-0.01 * pq.ms) + ValueError, + statistics.instantaneous_rate, + spiketrains=self.spike_train, + kernel=self.kernel, + sampling_period=-0.01 * pq.ms, + ) self.assertRaises( # no kernel or kernel='auto' - TypeError, statistics.instantaneous_rate, - spiketrains=self.spike_train, sampling_period=0.01 * pq.ms, - kernel='NONE') + TypeError, + statistics.instantaneous_rate, + spiketrains=self.spike_train, + sampling_period=0.01 * pq.ms, + kernel="NONE", + ) self.assertRaises( # wrong string for kernel='string' - TypeError, statistics.instantaneous_rate, - spiketrains=self.spike_train, sampling_period=0.01 * pq.s, - kernel='wrong_string') + TypeError, + statistics.instantaneous_rate, + spiketrains=self.spike_train, + sampling_period=0.01 * pq.s, + kernel="wrong_string", + ) self.assertRaises( # cutoff is not float or int - TypeError, statistics.instantaneous_rate, - spiketrains=self.spike_train, sampling_period=0.01 * pq.ms, - kernel=self.kernel, cutoff=20 * pq.ms) + TypeError, + statistics.instantaneous_rate, + spiketrains=self.spike_train, + sampling_period=0.01 * pq.ms, + kernel=self.kernel, + cutoff=20 * pq.ms, + ) self.assertRaises( # t_start not time quantity - TypeError, statistics.instantaneous_rate, - spiketrains=self.spike_train, sampling_period=0.01 * pq.ms, - kernel=self.kernel, t_start=2) + TypeError, + statistics.instantaneous_rate, + spiketrains=self.spike_train, + sampling_period=0.01 * pq.ms, + kernel=self.kernel, + t_start=2, + ) self.assertRaises( # t_stop not time quantity - TypeError, statistics.instantaneous_rate, - spiketrains=self.spike_train, sampling_period=0.01 * pq.ms, - kernel=self.kernel, t_stop=20 * pq.mV) + TypeError, + statistics.instantaneous_rate, + spiketrains=self.spike_train, + sampling_period=0.01 * pq.ms, + kernel=self.kernel, + t_stop=20 * pq.mV, + ) self.assertRaises( # trim is not bool - TypeError, statistics.instantaneous_rate, - spiketrains=self.spike_train, sampling_period=0.01 * pq.ms, - kernel=self.kernel, trim=1) + TypeError, + statistics.instantaneous_rate, + spiketrains=self.spike_train, + sampling_period=0.01 * pq.ms, + kernel=self.kernel, + trim=1, + ) self.assertRaises( # can't estimate a kernel for a list of spiketrains - ValueError, statistics.instantaneous_rate, + ValueError, + statistics.instantaneous_rate, spiketrains=[self.spike_train, self.spike_train], - sampling_period=10 * pq.ms, kernel='auto') + sampling_period=10 * pq.ms, + kernel="auto", + ) def test_instantaneous_rate_output(self): # return type correct self.assertIsInstance(self.inst_rate, neo.core.AnalogSignal) 
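Read together, the error cases above spell out the argument contract of statistics.instantaneous_rate: spike trains and window times must be neo/quantity objects with time units, the kernel must be a kernels.Kernel instance or the string 'auto', and cutoff and trim must be plain Python types. A usage sketch of a valid call follows; the concrete values are illustrative only.

import neo
import quantities as pq
import elephant.kernels as kernels
from elephant import statistics

st = neo.SpikeTrain([0.5, 1.2, 3.4] * pq.s, t_stop=5 * pq.s)
rate = statistics.instantaneous_rate(
    st,
    sampling_period=10 * pq.ms,                        # a time Quantity, not a bare float
    kernel=kernels.GaussianKernel(sigma=100 * pq.ms),  # or kernel='auto'
    cutoff=5.0,                                        # plain float/int, not a Quantity
)
print(rate.shape, rate.units)                          # an AnalogSignal with rate units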
# sampling_period input and output same - self.assertEqual(self.inst_rate.sampling_period.simplified, - self.sampling_period.simplified) + self.assertEqual( + self.inst_rate.sampling_period.simplified, self.sampling_period.simplified + ) # return correct units pq.Hz self.assertEqual(self.inst_rate.simplified.units, pq.Hz) # input and output t_stop same - self.assertEqual(self.spike_train.t_stop.simplified, - self.inst_rate.t_stop.simplified) + self.assertEqual( + self.spike_train.t_stop.simplified, self.inst_rate.t_stop.simplified + ) # input and output t_start same - self.assertEqual(self.inst_rate.t_start.simplified, - self.spike_train.t_start.simplified) + self.assertEqual( + self.inst_rate.t_start.simplified, self.spike_train.t_start.simplified + ) def test_instantaneous_rate_rate_estimation_consistency(self): """ @@ -590,15 +926,18 @@ def test_instantaneous_rate_rate_estimation_consistency(self): equal to the number of spikes of the spike train. """ kernel_types = tuple( - kern_cls for kern_cls in kernels.__dict__.values() - if isinstance(kern_cls, type) and - issubclass(kern_cls, kernels.Kernel) and - kern_cls is not kernels.Kernel and - kern_cls is not kernels.SymmetricKernel) + kern_cls + for kern_cls in kernels.__dict__.values() + if isinstance(kern_cls, type) + and issubclass(kern_cls, kernels.Kernel) + and kern_cls is not kernels.Kernel + and kern_cls is not kernels.SymmetricKernel + ) # set sigma - kernels_available = [kern_cls(sigma=0.5 * pq.s, invert=False) - for kern_cls in kernel_types] - kernels_available.append('auto') + kernels_available = [ + kern_cls(sigma=0.5 * pq.s, invert=False) for kern_cls in kernel_types + ] + kernels_available.append("auto") kernel_resolution = 0.01 * pq.s for kernel in kernels_available: border_correction = False @@ -613,14 +952,16 @@ def test_instantaneous_rate_rate_estimation_consistency(self): t_stop=self.st_tr[1] * pq.s, trim=False, center_kernel=center_kernel, - border_correction=border_correction + border_correction=border_correction, ) num_spikes = len(self.spike_train) area_under_curve = spint.cumulative_trapezoid( y=rate_estimate.magnitude[:, 0], - x=rate_estimate.times.rescale('s').magnitude)[-1] - self.assertAlmostEqual(num_spikes, area_under_curve, - delta=0.01 * num_spikes) + x=rate_estimate.times.rescale("s").magnitude, + )[-1] + self.assertAlmostEqual( + num_spikes, area_under_curve, delta=0.01 * num_spikes + ) def test_instantaneous_rate_regression_107(self): # Create a spiketrain with t_start=0s, t_stop=2s and a single spike at @@ -629,12 +970,16 @@ def test_instantaneous_rate_regression_107(self): # into the future' from the perspective of the neuron. t_spike = 1 * pq.s spiketrain = neo.SpikeTrain( - [t_spike], t_start=0 * pq.s, t_stop=2 * pq.s, units=pq.s) + [t_spike], t_start=0 * pq.s, t_stop=2 * pq.s, units=pq.s + ) kernel = kernels.AlphaKernel(200 * pq.ms) sampling_period = 0.1 * pq.ms rate = statistics.instantaneous_rate( - spiketrains=spiketrain, sampling_period=sampling_period, - kernel=kernel, center_kernel=False) + spiketrains=spiketrain, + sampling_period=sampling_period, + kernel=kernel, + center_kernel=False, + ) # find positive nonezero rate estimates rate_nonzero_index = np.nonzero(rate > 1e-6)[0] # find times, where the mass is concentrated, i.e. 
rate is estimated>0 @@ -653,15 +998,18 @@ def test_instantaneous_rate_regression_288(self): np.random.seed(9) sampling_period = 200 * pq.ms spiketrain = StationaryPoissonProcess( - 10 * pq.Hz, t_start=0 * pq.s, - t_stop=10 * pq.s).generate_spiketrain() + 10 * pq.Hz, t_start=0 * pq.s, t_stop=10 * pq.s + ).generate_spiketrain() kernel = kernels.AlphaKernel(sigma=5 * pq.ms, invert=True) _ = statistics.instantaneous_rate( - spiketrain, sampling_period=sampling_period, kernel=kernel) + spiketrain, sampling_period=sampling_period, kernel=kernel + ) except ValueError: - self.fail('When providing a kernel on a much smaller time scale ' - 'than sampling rate requested the instantaneous rate ' - 'estimation will fail on numpy level ') + self.fail( + "When providing a kernel on a much smaller time scale " + "than sampling rate requested the instantaneous rate " + "estimation will fail on numpy level " + ) def test_instantaneous_rate_small_kernel_sigma(self): # Test that the instantaneous rate is overestimated when @@ -672,22 +1020,26 @@ def test_instantaneous_rate_small_kernel_sigma(self): sigma = 5 * pq.ms rate_expected = 10 * pq.Hz spiketrain = StationaryPoissonProcess( - rate_expected, t_start=0 * pq.s, - t_stop=10 * pq.s).generate_spiketrain() + rate_expected, t_start=0 * pq.s, t_stop=10 * pq.s + ).generate_spiketrain() kernel_types = tuple( - kern_cls for kern_cls in kernels.__dict__.values() - if isinstance(kern_cls, type) and - issubclass(kern_cls, kernels.Kernel) and - kern_cls is not kernels.Kernel and - kern_cls is not kernels.SymmetricKernel) + kern_cls + for kern_cls in kernels.__dict__.values() + if isinstance(kern_cls, type) + and issubclass(kern_cls, kernels.Kernel) + and kern_cls is not kernels.Kernel + and kern_cls is not kernels.SymmetricKernel + ) for kern_cls, invert in itertools.product(kernel_types, (False, True)): kernel = kern_cls(sigma=sigma, invert=invert) with self.subTest(kernel=kernel): rate = statistics.instantaneous_rate( spiketrain, sampling_period=sampling_period, - kernel=kernel, center_kernel=True) + kernel=kernel, + center_kernel=True, + ) self.assertGreater(rate.mean(), rate_expected) def test_instantaneous_rate_spikes_on_edges(self): @@ -699,24 +1051,28 @@ def test_instantaneous_rate_spikes_on_edges(self): sampling_period = 0.01 * pq.s # with t_spikes = [-5, 5]s the isi is 10s, so 1/isi 0.1 Hz t_spikes = np.array([-cutoff, cutoff]) * pq.s - spiketrain = neo.SpikeTrain(t_spikes, t_start=t_spikes[0], - t_stop=t_spikes[-1]) + spiketrain = neo.SpikeTrain(t_spikes, t_start=t_spikes[0], t_stop=t_spikes[-1]) kernel_types = tuple( - kern_cls for kern_cls in kernels.__dict__.values() - if isinstance(kern_cls, type) and - issubclass(kern_cls, kernels.Kernel) and - kern_cls is not kernels.Kernel and - kern_cls is not kernels.SymmetricKernel) - kernels_available = [kern_cls(sigma=1 * pq.s, invert=False) - for kern_cls in kernel_types] + kern_cls + for kern_cls in kernels.__dict__.values() + if isinstance(kern_cls, type) + and issubclass(kern_cls, kernels.Kernel) + and kern_cls is not kernels.Kernel + and kern_cls is not kernels.SymmetricKernel + ) + kernels_available = [ + kern_cls(sigma=1 * pq.s, invert=False) for kern_cls in kernel_types + ] for kernel in kernels_available: for center_kernel in (False, True): rate = statistics.instantaneous_rate( spiketrain, sampling_period=sampling_period, kernel=kernel, - cutoff=cutoff, trim=True, - center_kernel=center_kernel) + cutoff=cutoff, + trim=True, + center_kernel=center_kernel, + ) assert_array_almost_equal(rate.magnitude, 0, 
decimal=2) def test_instantaneous_rate_center_kernel(self): @@ -728,28 +1084,38 @@ def test_instantaneous_rate_center_kernel(self): cutoff = 5 sampling_period = 0.01 * pq.s t_spikes = np.linspace(-cutoff, cutoff, num=(2 * cutoff + 1)) * pq.s - spiketrain = neo.SpikeTrain(t_spikes, t_start=t_spikes[0], - t_stop=t_spikes[-1]) + spiketrain = neo.SpikeTrain(t_spikes, t_start=t_spikes[0], t_stop=t_spikes[-1]) kernel = kernels.RectangularKernel(sigma=1 * pq.s) assert cutoff > kernel.min_cutoff, "Choose larger cutoff" kernel_types = tuple( - kern_cls for kern_cls in kernels.__dict__.values() - if isinstance(kern_cls, type) and - issubclass(kern_cls, kernels.SymmetricKernel) and - kern_cls is not kernels.SymmetricKernel) - kernels_symmetric = [kern_cls(sigma=1 * pq.s, invert=False) - for kern_cls in kernel_types] + kern_cls + for kern_cls in kernels.__dict__.values() + if isinstance(kern_cls, type) + and issubclass(kern_cls, kernels.SymmetricKernel) + and kern_cls is not kernels.SymmetricKernel + ) + kernels_symmetric = [ + kern_cls(sigma=1 * pq.s, invert=False) for kern_cls in kernel_types + ] for kernel in kernels_symmetric: for trim in (False, True): rate_centered = statistics.instantaneous_rate( - spiketrain, sampling_period=sampling_period, - kernel=kernel, cutoff=cutoff, trim=trim, - center_kernel=True) + spiketrain, + sampling_period=sampling_period, + kernel=kernel, + cutoff=cutoff, + trim=trim, + center_kernel=True, + ) rate_not_centered = statistics.instantaneous_rate( - spiketrain, sampling_period=sampling_period, - kernel=kernel, cutoff=cutoff, trim=trim, - center_kernel=False) + spiketrain, + sampling_period=sampling_period, + kernel=kernel, + cutoff=cutoff, + trim=trim, + center_kernel=False, + ) assert_array_almost_equal(rate_centered, rate_not_centered) def test_instantaneous_rate_list_of_spiketrains(self): @@ -757,53 +1123,58 @@ def test_instantaneous_rate_list_of_spiketrains(self): duration_effective = self.st_dur - 2 * self.st_margin st_num_spikes = np.random.poisson(self.st_rate * duration_effective) spike_train2 = sorted( - np.random.rand(st_num_spikes) * - duration_effective + - self.st_margin) - spike_train2 = neo.SpikeTrain(spike_train2 * pq.s, - t_start=self.st_tr[0] * pq.s, - t_stop=self.st_tr[1] * pq.s) + np.random.rand(st_num_spikes) * duration_effective + self.st_margin + ) + spike_train2 = neo.SpikeTrain( + spike_train2 * pq.s, + t_start=self.st_tr[0] * pq.s, + t_stop=self.st_tr[1] * pq.s, + ) st_rate_1 = statistics.instantaneous_rate( - self.spike_train, sampling_period=self.sampling_period, - kernel=self.kernel) + self.spike_train, sampling_period=self.sampling_period, kernel=self.kernel + ) st_rate_2 = statistics.instantaneous_rate( - spike_train2, sampling_period=self.sampling_period, - kernel=self.kernel) + spike_train2, sampling_period=self.sampling_period, kernel=self.kernel + ) rate_concat = np.c_[st_rate_1, st_rate_2] combined_rate = statistics.instantaneous_rate( [self.spike_train, spike_train2], sampling_period=self.sampling_period, - kernel=self.kernel) + kernel=self.kernel, + ) # 'time_vector.dtype' in instantaneous_rate() is changed from float64 # to float32 which results in 3e-6 abs difference - assert_array_almost_equal(combined_rate.magnitude, - rate_concat.magnitude, decimal=5) + assert_array_almost_equal( + combined_rate.magnitude, rate_concat.magnitude, decimal=5 + ) def test_instantaneous_rate_regression_144(self): # The following spike train contains spikes that are so close to each # other, that the optimal kernel cannot be detected. 
Therefore, the # function should react with a ValueError. st = neo.SpikeTrain([2.12, 2.13, 2.15] * pq.s, t_stop=10 * pq.s) - self.assertRaises( - ValueError, statistics.instantaneous_rate, st, 1 * pq.ms) + self.assertRaises(ValueError, statistics.instantaneous_rate, st, 1 * pq.ms) def test_instantaneous_rate_regression_245(self): # This test makes sure that the correct kernel width is chosen when # selecting 'auto' as kernel spiketrain = neo.SpikeTrain( - pq.ms * range(1, 30), t_start=0 * pq.ms, t_stop=30 * pq.ms) + pq.ms * range(1, 30), t_start=0 * pq.ms, t_stop=30 * pq.ms + ) # This is the correct procedure to attain the kernel: first, the result # of sskernel retrieves the kernel bandwidth of an optimal Gaussian # kernel in terms of its standard deviation sigma, then uses this value # directly in the function for creating the Gaussian kernel kernel_width_sigma = statistics.optimal_kernel_bandwidth( - spiketrain.magnitude, times=None, bootstrap=False)['optw'] + spiketrain.magnitude, times=None, bootstrap=False + )["optw"] kernel = kernels.GaussianKernel(kernel_width_sigma * spiketrain.units) result_target = statistics.instantaneous_rate( - spiketrain, 10 * pq.ms, kernel=kernel) + spiketrain, 10 * pq.ms, kernel=kernel + ) # Here, we check if the 'auto' argument leads to the same operation. In # the regression, it was incorrectly assumed that the sskernel() @@ -813,7 +1184,8 @@ def test_instantaneous_rate_regression_245(self): # factor 2.7 connects half width of Gaussian distribution with # 99% probability mass with its standard deviation. result_automatic = statistics.instantaneous_rate( - spiketrain, 10 * pq.ms, kernel='auto') + spiketrain, 10 * pq.ms, kernel="auto" + ) assert_array_almost_equal(result_target, result_automatic) @@ -821,13 +1193,15 @@ def test_instantaneous_rate_grows_with_sampling_period(self): np.random.seed(0) rate_expected = 10 * pq.Hz spiketrain = StationaryPoissonProcess( - rate=rate_expected, t_stop=10 * pq.s).generate_spiketrain() + rate=rate_expected, t_stop=10 * pq.s + ).generate_spiketrain() kernel = kernels.GaussianKernel(sigma=100 * pq.ms) rates_mean = [] for sampling_period in np.linspace(1, 1000, num=10) * pq.ms: with self.subTest(sampling_period=sampling_period): rate = statistics.instantaneous_rate( - spiketrain, sampling_period=sampling_period, kernel=kernel) + spiketrain, sampling_period=sampling_period, kernel=kernel + ) rates_mean.append(rate.mean()) # rate means are greater or equal the expected rate assert_array_less(rate_expected, rates_mean) @@ -839,50 +1213,53 @@ def test_instantaneous_rate_regression_360(self): # with spikes at [-0.0001, 0, 0.0001]. # Skip RectangularKernel because it doesn't have a strong peak. 
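A toy NumPy check of the property this regression test asserts (a sketch, not Elephant code): convolving a time-symmetric spike pattern with a symmetric kernel must leave the peak of the estimate at t = 0.

import numpy as np

times = np.linspace(-1.0, 1.0, 201)        # 10 ms grid with 0.0 as the middle sample
counts = np.zeros_like(times)
counts[100] = 3                            # the three near-coincident spikes share one bin

sigma = 0.05
kernel = np.exp(-0.5 * (times / sigma) ** 2)   # symmetric (Gaussian) kernel
kernel /= kernel.sum()

rate = np.convolve(counts, kernel, mode="same")
print(times[np.argmax(rate)])              # 0.0 -- the peak stays centred on the spikes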
kernel_types = tuple( - kern_cls for kern_cls in kernels.__dict__.values() - if isinstance(kern_cls, type) and - issubclass(kern_cls, kernels.SymmetricKernel) and - kern_cls not in (kernels.SymmetricKernel, - kernels.RectangularKernel)) - kernels_symmetric = [kern_cls(sigma=50 * pq.ms, invert=False) - for kern_cls in kernel_types] + kern_cls + for kern_cls in kernels.__dict__.values() + if isinstance(kern_cls, type) + and issubclass(kern_cls, kernels.SymmetricKernel) + and kern_cls not in (kernels.SymmetricKernel, kernels.RectangularKernel) + ) + kernels_symmetric = [ + kern_cls(sigma=50 * pq.ms, invert=False) for kern_cls in kernel_types + ] # first part: a symmetric spiketrain with a symmetric kernel - spiketrain = neo.SpikeTrain(np.array([-0.0001, 0, 0.0001]) * pq.s, - t_start=-1, - t_stop=1) + spiketrain = neo.SpikeTrain( + np.array([-0.0001, 0, 0.0001]) * pq.s, t_start=-1, t_stop=1 + ) for kernel in kernels_symmetric: - rate = statistics.instantaneous_rate(spiketrain, - sampling_period=20 * pq.ms, - kernel=kernel) + rate = statistics.instantaneous_rate( + spiketrain, sampling_period=20 * pq.ms, kernel=kernel + ) # the peak time must be centered at origin t=0 self.assertEqual(rate.times[np.argmax(rate)], 0) # second part: a single spike at t=0 - periods = [2 ** exp for exp in range(-3, 6)] + periods = [2**exp for exp in range(-3, 6)] for period in periods: with self.subTest(period=period): - spiketrain = neo.SpikeTrain(np.array([0]) * pq.s, - t_start=-period * 10 * pq.ms, - t_stop=period * 10 * pq.ms) + spiketrain = neo.SpikeTrain( + np.array([0]) * pq.s, + t_start=-period * 10 * pq.ms, + t_stop=period * 10 * pq.ms, + ) for kernel in kernels_symmetric: rate = statistics.instantaneous_rate( - spiketrain, - sampling_period=period * pq.ms, - kernel=kernel) + spiketrain, sampling_period=period * pq.ms, kernel=kernel + ) self.assertEqual(rate.times[np.argmax(rate)], 0) def test_instantaneous_rate_annotations(self): spiketrain = neo.SpikeTrain([1, 2], t_stop=2 * pq.s, units=pq.s) kernel = kernels.AlphaKernel(sigma=100 * pq.ms) - rate = statistics.instantaneous_rate(spiketrain, - sampling_period=10 * pq.ms, - kernel=kernel) - kernel_annotation = dict(type=type(kernel).__name__, - sigma=str(kernel.sigma), - invert=kernel.invert) - self.assertIn('kernel', rate.annotations) - self.assertEqual(rate.annotations['kernel'], kernel_annotation) + rate = statistics.instantaneous_rate( + spiketrain, sampling_period=10 * pq.ms, kernel=kernel + ) + kernel_annotation = dict( + type=type(kernel).__name__, sigma=str(kernel.sigma), invert=kernel.invert + ) + self.assertIn("kernel", rate.annotations) + self.assertEqual(rate.annotations["kernel"], kernel_annotation) def test_instantaneous_rate_regression_374(self): # Check if the last interval is dropped. @@ -892,28 +1269,32 @@ def test_instantaneous_rate_regression_374(self): # dropped and not be considered in the calculation. 
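The binning property behind regression #374, sketched in plain NumPy as a simplified stand-in for Elephant's internal binning (not the actual implementation): with t_start = 0 s, t_stop = 9.8 s and a 1 s sampling period only floor(9.8) = 9 complete intervals exist, so spikes between 9.0 s and 9.8 s fall into the dropped remainder and must not contribute to the estimate.

import numpy as np

t_start, t_stop, sampling_period = 0.0, 9.8, 1.0
n_bins = int((t_stop - t_start) / sampling_period)          # 9 complete intervals
edges = t_start + np.arange(n_bins + 1) * sampling_period   # last kept edge at 9.0 s
counts, _ = np.histogram([9.65, 9.7, 9.75], bins=edges)
print(n_bins, counts.sum())                                 # 9 bins, 0 spikes retained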
spike_times = np.array([9.65, 9.7, 9.75]) * pq.s - spiketrain = neo.SpikeTrain(spike_times, - t_start=0, - t_stop=9.8) + spiketrain = neo.SpikeTrain(spike_times, t_start=0, t_stop=9.8) kernel = kernels.GaussianKernel(sigma=250 * pq.ms) sampling_period = 1000 * pq.ms rate = statistics.instantaneous_rate( spiketrain, sampling_period=sampling_period, - kernel=kernel, center_kernel=False, trim=False, cutoff=1) + kernel=kernel, + center_kernel=False, + trim=False, + cutoff=1, + ) assert_array_almost_equal(rate.magnitude, 0) def test_instantaneous_rate_rate_times(self): # check if the differences between the rate.times is equal to # sampling_period st = self.spike_train - periods = [1, 0.99, 0.35, 11, st.duration]*pq.s + periods = [1, 0.99, 0.35, 11, st.duration] * pq.s for period in periods: - rate = statistics.instantaneous_rate(st, - sampling_period=period, - kernel=self.kernel, - center_kernel=True, - trim=False) + rate = statistics.instantaneous_rate( + st, + sampling_period=period, + kernel=self.kernel, + center_kernel=True, + trim=False, + ) rate_times_diff = np.diff(rate.times) period_times = np.full_like(rate_times_diff, period) assert_array_almost_equal(rate_times_diff, period_times) @@ -923,29 +1304,30 @@ def test_instantaneous_rate_bin_edges(self): # are multiples of the sampling rate. In the following example, the # rate maximum is expected to be at 5.785s. # See PR#453 https://github.com/NeuralEnsemble/elephant/pull/453 - spike_times = np.array( - [4.45, 4.895, 5.34, 5.785, 6.23, 6.675, 7.12]) * pq.s + spike_times = np.array([4.45, 4.895, 5.34, 5.785, 6.23, 6.675, 7.12]) * pq.s # add 0.01 s - shifted_spike_times = spike_times + .01 * pq.s + shifted_spike_times = spike_times + 0.01 * pq.s - spiketrain = neo.SpikeTrain(shifted_spike_times, - t_start=0, - t_stop=10) + spiketrain = neo.SpikeTrain(shifted_spike_times, t_start=0, t_stop=10) kernel = kernels.GaussianKernel(sigma=500 * pq.ms) sampling_period = 445 * pq.ms rate = statistics.instantaneous_rate( spiketrain, sampling_period=sampling_period, - kernel=kernel, center_kernel=True, trim=False) - self.assertAlmostEqual(spike_times[3].magnitude.item(), - rate.times[rate.argmax()].magnitude.item()) + kernel=kernel, + center_kernel=True, + trim=False, + ) + self.assertAlmostEqual( + spike_times[3].magnitude.item(), rate.times[rate.argmax()].magnitude.item() + ) def test_instantaneous_rate_border_correction(self): np.random.seed(0) n_spiketrains = 125 - rate = 50. * pq.Hz - t_start = 0. * pq.ms - t_stop = 1000. * pq.ms + rate = 50.0 * pq.Hz + t_start = 0.0 * pq.ms + t_stop = 1000.0 * pq.ms sampling_period = 0.1 * pq.ms trial_list = StationaryPoissonProcess( rate=rate, t_start=t_start, t_stop=t_stop @@ -957,8 +1339,8 @@ def test_instantaneous_rate_border_correction(self): instantaneous_rate = statistics.instantaneous_rate( spiketrains=trial, sampling_period=sampling_period, - kernel='auto', - border_correction=correction + kernel="auto", + border_correction=correction, ) rates.append(instantaneous_rate) # The average estimated rate gives the average estimated value of @@ -968,25 +1350,25 @@ def test_instantaneous_rate_border_correction(self): average_estimated_rate = np.mean(rates, axis=0)[:, 0] rtol = 0.05 # Five percent of tolerance if correction: - self.assertLess(np.max(average_estimated_rate), - (1. + rtol) * rate.item()) - self.assertGreater(np.min(average_estimated_rate), - (1. 
- rtol) * rate.item()) + self.assertLess(np.max(average_estimated_rate), (1.0 + rtol) * rate.item()) + self.assertGreater( + np.min(average_estimated_rate), (1.0 - rtol) * rate.item() + ) else: - self.assertLess(np.max(average_estimated_rate), - (1. + rtol) * rate.item()) + self.assertLess(np.max(average_estimated_rate), (1.0 + rtol) * rate.item()) # The minimal rate deviates strongly in the uncorrected case. - self.assertLess(np.min(average_estimated_rate), - (1. - rtol) * rate.item()) + self.assertLess(np.min(average_estimated_rate), (1.0 - rtol) * rate.item()) def test_instantaneous_rate_trials_pool_trials(self): kernel = kernels.GaussianKernel(sigma=500 * pq.ms) - rate = statistics.instantaneous_rate(self.trial_object, - sampling_period=0.1 * pq.ms, - kernel=kernel, - pool_spike_trains=False, - pool_trials=True) + rate = statistics.instantaneous_rate( + self.trial_object, + sampling_period=0.1 * pq.ms, + kernel=kernel, + pool_spike_trains=False, + pool_trials=True, + ) self.assertIsInstance(rate, neo.core.AnalogSignal) def test_instantaneous_rate_list_pool_spike_trains(self): @@ -997,7 +1379,8 @@ def test_instantaneous_rate_list_pool_spike_trains(self): sampling_period=0.1 * pq.ms, kernel=kernel, pool_spike_trains=True, - pool_trials=False) + pool_trials=False, + ) self.assertIsInstance(rate, neo.core.AnalogSignal) self.assertEqual(rate.magnitude.shape[1], 1) @@ -1008,7 +1391,8 @@ def test_instantaneous_rate_list_of_spike_trains(self): sampling_period=0.1 * pq.ms, kernel=kernel, pool_spike_trains=False, - pool_trials=False) + pool_trials=False, + ) self.assertIsInstance(rate, neo.core.AnalogSignal) self.assertEqual(rate.magnitude.shape[1], 2) @@ -1016,9 +1400,11 @@ def test_instantaneous_rate_list_of_spike_trains(self): class TimeHistogramTestCase(unittest.TestCase): def setUp(self): self.spiketrain_a = neo.SpikeTrain( - [0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s) + [0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s + ) self.spiketrain_b = neo.SpikeTrain( - [0.1, 0.7, 1.2, 2.2, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s) + [0.1, 0.7, 1.2, 2.2, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s + ) self.spiketrains = [self.spiketrain_a, self.spiketrain_b] def tearDown(self): @@ -1034,77 +1420,86 @@ def test_time_histogram(self): def test_time_histogram_binary(self): targ = np.array([2, 2, 1, 1, 2, 2, 1, 0, 1, 0]) - histogram = statistics.time_histogram(self.spiketrains, bin_size=pq.s, - binary=True) + histogram = statistics.time_histogram( + self.spiketrains, bin_size=pq.s, binary=True + ) assert_array_equal(targ, histogram.magnitude[:, 0]) def test_time_histogram_tstart_tstop(self): # Start, stop short range targ = np.array([2, 1]) - histogram = statistics.time_histogram(self.spiketrains, bin_size=pq.s, - t_start=5 * pq.s, - t_stop=7 * pq.s) + histogram = statistics.time_histogram( + self.spiketrains, bin_size=pq.s, t_start=5 * pq.s, t_stop=7 * pq.s + ) assert_array_equal(targ, histogram.magnitude[:, 0]) # Test without t_stop targ = np.array([4, 2, 1, 1, 2, 2, 1, 0, 1, 0]) - histogram = statistics.time_histogram(self.spiketrains, - bin_size=1 * pq.s, - t_start=0 * pq.s) + histogram = statistics.time_histogram( + self.spiketrains, bin_size=1 * pq.s, t_start=0 * pq.s + ) assert_array_equal(targ, histogram.magnitude[:, 0]) # Test without t_start - histogram = statistics.time_histogram(self.spiketrains, - bin_size=1 * pq.s, - t_stop=10 * pq.s) + histogram = statistics.time_histogram( + self.spiketrains, bin_size=1 * pq.s, t_stop=10 * pq.s + ) assert_array_equal(targ, 
histogram.magnitude[:, 0]) def test_time_histogram_output(self): # Normalization mean - histogram = statistics.time_histogram(self.spiketrains, bin_size=pq.s, - output='mean') + histogram = statistics.time_histogram( + self.spiketrains, bin_size=pq.s, output="mean" + ) targ = np.array([4, 2, 1, 1, 2, 2, 1, 0, 1, 0], dtype=float) / 2 assert_array_equal(targ.reshape(targ.size, 1), histogram.magnitude) # Normalization rate - histogram = statistics.time_histogram(self.spiketrains, bin_size=pq.s, - output='rate') - assert_array_equal(histogram.view(pq.Quantity), - targ.reshape(targ.size, 1) * 1 / pq.s) + histogram = statistics.time_histogram( + self.spiketrains, bin_size=pq.s, output="rate" + ) + assert_array_equal( + histogram.view(pq.Quantity), targ.reshape(targ.size, 1) * 1 / pq.s + ) # Normalization unspecified, raises error - self.assertRaises(ValueError, statistics.time_histogram, - self.spiketrains, - bin_size=pq.s, output=' ') + self.assertRaises( + ValueError, + statistics.time_histogram, + self.spiketrains, + bin_size=pq.s, + output=" ", + ) def test_annotations(self): np.random.seed(1) spiketrains = StationaryPoissonProcess( - rate=10 * pq.Hz, t_stop=10 * pq.s).generate_n_spiketrains( - n_spiketrains=10) + rate=10 * pq.Hz, t_stop=10 * pq.s + ).generate_n_spiketrains(n_spiketrains=10) for output in ("counts", "mean", "rate"): - histogram = statistics.time_histogram(spiketrains, - bin_size=3 * pq.ms, - output=output) - self.assertIn('normalization', histogram.annotations) - self.assertEqual(histogram.annotations['normalization'], output) + histogram = statistics.time_histogram( + spiketrains, bin_size=3 * pq.ms, output=output + ) + self.assertIn("normalization", histogram.annotations) + self.assertEqual(histogram.annotations["normalization"], output) class ComplexityTestCase(unittest.TestCase): def test_complexity_pdf_deprecated(self): spiketrain_a = neo.SpikeTrain( - [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s) + [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s + ) spiketrain_b = neo.SpikeTrain( - [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s) + [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s + ) spiketrain_c = neo.SpikeTrain( - [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s) - spiketrains = [ - spiketrain_a, spiketrain_b, spiketrain_c] + [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s + ) + spiketrains = [spiketrain_a, spiketrain_b, spiketrain_c] # runs the previous function which will be deprecated targ = np.array([0.92, 0.01, 0.01, 0.06]) - complexity = statistics.complexity_pdf( - spiketrains, bin_size=0.1*pq.s) + complexity = statistics.complexity_pdf(spiketrains, bin_size=0.1 * pq.s) assert_array_equal(targ, complexity.magnitude[:, 0]) self.assertEqual(1, complexity.magnitude[:, 0].sum()) self.assertEqual(len(spiketrains) + 1, len(complexity)) @@ -1113,17 +1508,18 @@ def test_complexity_pdf_deprecated(self): def test_complexity_pdf(self): spiketrain_a = neo.SpikeTrain( - [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s) + [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s + ) spiketrain_b = neo.SpikeTrain( - [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s) + [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s + ) spiketrain_c = neo.SpikeTrain( - [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s) - spiketrains = [ - spiketrain_a, spiketrain_b, spiketrain_c] + [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s 
+ ) + spiketrains = [spiketrain_a, spiketrain_b, spiketrain_c] # runs the previous function which will be deprecated targ = np.array([0.92, 0.01, 0.01, 0.06]) - complexity_obj = statistics.Complexity(spiketrains, - bin_size=0.1 * pq.s) + complexity_obj = statistics.Complexity(spiketrains, bin_size=0.1 * pq.s) pdf = complexity_obj.pdf() assert_array_equal(targ, complexity_obj.pdf().magnitude[:, 0]) self.assertEqual(1, pdf.magnitude[:, 0].sum()) @@ -1134,191 +1530,191 @@ def test_complexity_pdf(self): def test_complexity_histogram_spread_0(self): sampling_rate = 1 / pq.s - spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, - t_stop=20 * pq.s), - neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, - t_stop=20 * pq.s)] + spiketrains = [ + neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, t_stop=20 * pq.s), + neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, t_stop=20 * pq.s), + ] correct_histogram = np.array([10, 8, 2]) - correct_time_histogram = np.array([0, 2, 0, 0, 1, 1, 0, 0, 1, 1, - 0, 1, 1, 0, 0, 0, 2, 0, 1, 1]) + correct_time_histogram = np.array( + [0, 2, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 2, 0, 1, 1] + ) - complexity_obj = statistics.Complexity(spiketrains, - sampling_rate=sampling_rate, - spread=0) + complexity_obj = statistics.Complexity( + spiketrains, sampling_rate=sampling_rate, spread=0 + ) - assert_array_equal(complexity_obj.complexity_histogram, - correct_histogram) + assert_array_equal(complexity_obj.complexity_histogram, correct_histogram) assert_array_equal( complexity_obj.time_histogram.magnitude.flatten().astype(int), - correct_time_histogram) + correct_time_histogram, + ) def test_complexity_histogram_spread_0_nonbinary(self): sampling_rate = 1 / pq.s - spiketrains = [neo.SpikeTrain([1, 5, 5, 9, 11, 16, 19] * pq.s, - t_stop=20 * pq.s), - neo.SpikeTrain([1, 4, 8, 12, 16, 16, 18] * pq.s, - t_stop=20 * pq.s)] + spiketrains = [ + neo.SpikeTrain([1, 5, 5, 9, 11, 16, 19] * pq.s, t_stop=20 * pq.s), + neo.SpikeTrain([1, 4, 8, 12, 16, 16, 18] * pq.s, t_stop=20 * pq.s), + ] correct_histogram = np.array([10, 7, 2, 1]) - correct_time_histogram = np.array([0, 2, 0, 0, 1, 2, 0, 0, 1, 1, - 0, 1, 1, 0, 0, 0, 3, 0, 1, 1]) + correct_time_histogram = np.array( + [0, 2, 0, 0, 1, 2, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 3, 0, 1, 1] + ) - complexity_obj = statistics.Complexity(spiketrains, - sampling_rate=sampling_rate, - binary=False, - spread=0) + complexity_obj = statistics.Complexity( + spiketrains, sampling_rate=sampling_rate, binary=False, spread=0 + ) - assert_array_equal(complexity_obj.complexity_histogram, - correct_histogram) + assert_array_equal(complexity_obj.complexity_histogram, correct_histogram) assert_array_equal( complexity_obj.time_histogram.magnitude.flatten().astype(int), - correct_time_histogram) + correct_time_histogram, + ) def test_complexity_epoch_spread_0(self): sampling_rate = 1 / pq.s - spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, - t_stop=20 * pq.s), - neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, - t_stop=20 * pq.s)] + spiketrains = [ + neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, t_stop=20 * pq.s), + neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, t_stop=20 * pq.s), + ] - complexity_obj = statistics.Complexity(spiketrains, - sampling_rate=sampling_rate, - spread=0) + complexity_obj = statistics.Complexity( + spiketrains, sampling_rate=sampling_rate, spread=0 + ) - self.assertIsInstance(complexity_obj.epoch, - neo.Epoch) + self.assertIsInstance(complexity_obj.epoch, neo.Epoch) def test_complexity_histogram_spread_1(self): sampling_rate = 1 / pq.s - spiketrains 
= [neo.SpikeTrain([0, 1, 5, 9, 11, 13, 20] * pq.s, - t_stop=21 * pq.s), - neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, - t_stop=21 * pq.s)] + spiketrains = [ + neo.SpikeTrain([0, 1, 5, 9, 11, 13, 20] * pq.s, t_stop=21 * pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, t_stop=21 * pq.s), + ] correct_histogram = np.array([9, 5, 1, 2]) - correct_time_histogram = np.array([3, 3, 0, 0, 2, 2, 0, 1, 0, 1, 0, - 3, 3, 3, 0, 0, 1, 0, 1, 0, 1]) + correct_time_histogram = np.array( + [3, 3, 0, 0, 2, 2, 0, 1, 0, 1, 0, 3, 3, 3, 0, 0, 1, 0, 1, 0, 1] + ) - complexity_obj = statistics.Complexity(spiketrains, - sampling_rate=sampling_rate, - spread=1) + complexity_obj = statistics.Complexity( + spiketrains, sampling_rate=sampling_rate, spread=1 + ) - assert_array_equal(complexity_obj.complexity_histogram, - correct_histogram) + assert_array_equal(complexity_obj.complexity_histogram, correct_histogram) assert_array_equal( complexity_obj.time_histogram.magnitude.flatten().astype(int), - correct_time_histogram) + correct_time_histogram, + ) def test_complexity_histogram_spread_1_nonbinary(self): sampling_rate = 1 / pq.s - spiketrains = [neo.SpikeTrain([0, 1, 5, 5, 9, 11, 13, 20] * pq.s, - t_stop=21 * pq.s), - neo.SpikeTrain([1, 4, 7, 12, 16, 16, 18] * pq.s, - t_stop=21 * pq.s)] + spiketrains = [ + neo.SpikeTrain([0, 1, 5, 5, 9, 11, 13, 20] * pq.s, t_stop=21 * pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 16, 18] * pq.s, t_stop=21 * pq.s), + ] correct_histogram = np.array([9, 4, 1, 3]) - correct_time_histogram = np.array([3, 3, 0, 0, 3, 3, 0, 1, 0, 1, 0, - 3, 3, 3, 0, 0, 2, 0, 1, 0, 1]) + correct_time_histogram = np.array( + [3, 3, 0, 0, 3, 3, 0, 1, 0, 1, 0, 3, 3, 3, 0, 0, 2, 0, 1, 0, 1] + ) - complexity_obj = statistics.Complexity(spiketrains, - sampling_rate=sampling_rate, - binary=False, - spread=1) + complexity_obj = statistics.Complexity( + spiketrains, sampling_rate=sampling_rate, binary=False, spread=1 + ) - assert_array_equal(complexity_obj.complexity_histogram, - correct_histogram) + assert_array_equal(complexity_obj.complexity_histogram, correct_histogram) assert_array_equal( complexity_obj.time_histogram.magnitude.flatten().astype(int), - correct_time_histogram) + correct_time_histogram, + ) def test_complexity_histogram_spread_2(self): sampling_rate = 1 / pq.s - spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, - t_stop=21 * pq.s), - neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, - t_stop=21 * pq.s)] + spiketrains = [ + neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, t_stop=21 * pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, t_stop=21 * pq.s), + ] correct_histogram = np.array([5, 0, 1, 1, 0, 0, 0, 1]) - correct_time_histogram = np.array([0, 2, 0, 0, 7, 7, 7, 7, 7, 7, 7, - 7, 7, 7, 0, 0, 3, 3, 3, 3, 3]) + correct_time_histogram = np.array( + [0, 2, 0, 0, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 0, 0, 3, 3, 3, 3, 3] + ) - complexity_obj = statistics.Complexity(spiketrains, - sampling_rate=sampling_rate, - spread=2) + complexity_obj = statistics.Complexity( + spiketrains, sampling_rate=sampling_rate, spread=2 + ) - assert_array_equal(complexity_obj.complexity_histogram, - correct_histogram) + assert_array_equal(complexity_obj.complexity_histogram, correct_histogram) assert_array_equal( complexity_obj.time_histogram.magnitude.flatten().astype(int), - correct_time_histogram) + correct_time_histogram, + ) def test_complexity_histogram_spread_2_nonbinary(self): sampling_rate = 1 / pq.s - spiketrains = [neo.SpikeTrain([1, 5, 5, 9, 11, 13, 20] * pq.s, - t_stop=21 * pq.s), - neo.SpikeTrain([1, 4, 7, 12, 16, 
16, 18] * pq.s, - t_stop=21 * pq.s)] + spiketrains = [ + neo.SpikeTrain([1, 5, 5, 9, 11, 13, 20] * pq.s, t_stop=21 * pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 16, 18] * pq.s, t_stop=21 * pq.s), + ] correct_histogram = np.array([5, 0, 1, 0, 1, 0, 0, 0, 1]) - correct_time_histogram = np.array([0, 2, 0, 0, 8, 8, 8, 8, 8, 8, 8, - 8, 8, 8, 0, 0, 4, 4, 4, 4, 4]) + correct_time_histogram = np.array( + [0, 2, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 0, 0, 4, 4, 4, 4, 4] + ) - complexity_obj = statistics.Complexity(spiketrains, - sampling_rate=sampling_rate, - binary=False, - spread=2) + complexity_obj = statistics.Complexity( + spiketrains, sampling_rate=sampling_rate, binary=False, spread=2 + ) - assert_array_equal(complexity_obj.complexity_histogram, - correct_histogram) + assert_array_equal(complexity_obj.complexity_histogram, correct_histogram) assert_array_equal( complexity_obj.time_histogram.magnitude.flatten().astype(int), - correct_time_histogram) + correct_time_histogram, + ) def test_wrong_input_errors(self): - spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, - t_stop=21 * pq.s), - neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, - t_stop=21 * pq.s)] + spiketrains = [ + neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, t_stop=21 * pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, t_stop=21 * pq.s), + ] - self.assertRaises(ValueError, - statistics.Complexity, - spiketrains) + self.assertRaises(ValueError, statistics.Complexity, spiketrains) - self.assertRaises(ValueError, - statistics.Complexity, - spiketrains, - sampling_rate=1 * pq.s, - spread=-7) + self.assertRaises( + ValueError, + statistics.Complexity, + spiketrains, + sampling_rate=1 * pq.s, + spread=-7, + ) def test_sampling_rate_warning(self): - spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, - t_stop=21 * pq.s), - neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, - t_stop=21 * pq.s)] + spiketrains = [ + neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, t_stop=21 * pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, t_stop=21 * pq.s), + ] with self.assertWarns(UserWarning): - statistics.Complexity(spiketrains, - bin_size=1 * pq.s, - spread=1) + statistics.Complexity(spiketrains, bin_size=1 * pq.s, spread=1) def test_binning_for_input_with_rounding_errors(self): # a test with inputs divided by 30000 which leads to rounding errors @@ -1327,23 +1723,28 @@ def test_binning_for_input_with_rounding_errors(self): sampling_rate = 333 / pq.s - spiketrains = [neo.SpikeTrain(np.arange(1000, step=2) * pq.s / 333, - t_stop=30.33333333333 * pq.s), - neo.SpikeTrain(np.arange(2000, step=4) * pq.s / 333, - t_stop=30.33333333333 * pq.s)] + spiketrains = [ + neo.SpikeTrain( + np.arange(1000, step=2) * pq.s / 333, t_stop=30.33333333333 * pq.s + ), + neo.SpikeTrain( + np.arange(2000, step=4) * pq.s / 333, t_stop=30.33333333333 * pq.s + ), + ] correct_time_histogram = np.zeros(10101) correct_time_histogram[:1000:2] = 1 correct_time_histogram[:2000:4] += 1 - complexity_obj = statistics.Complexity(spiketrains, - sampling_rate=sampling_rate, - spread=1) + complexity_obj = statistics.Complexity( + spiketrains, sampling_rate=sampling_rate, spread=1 + ) assert_array_equal( complexity_obj.time_histogram.magnitude.flatten().astype(int), - correct_time_histogram) + correct_time_histogram, + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_total_spiking_probability_edges.py b/elephant/test/test_total_spiking_probability_edges.py index 60bfa5e2f..a9e2f8299 100644 --- 
a/elephant/test/test_total_spiking_probability_edges.py +++ b/elephant/test/test_total_spiking_probability_edges.py @@ -8,12 +8,12 @@ from scipy.io import loadmat from elephant.conversion import BinnedSpikeTrain -from elephant.functional_connectivity_src.total_spiking_probability_edges \ - import (generate_filter_pairs, - normalized_cross_correlation, - TspeFilterPair, - total_spiking_probability_edges, - ) +from elephant.functional_connectivity_src.total_spiking_probability_edges import ( + generate_filter_pairs, + normalized_cross_correlation, + TspeFilterPair, + total_spiking_probability_edges, +) from elephant.datasets import download_datasets @@ -36,27 +36,34 @@ def test_generate_filter_pairs(self): function_output = generate_filter_pairs(a, b, c) - for filter_pair_function, filter_pair_test in zip(function_output, - test_output): + for filter_pair_function, filter_pair_test in zip(function_output, test_output): np.testing.assert_array_equal( - filter_pair_function.edge_filter, - filter_pair_test.edge_filter) + filter_pair_function.edge_filter, filter_pair_test.edge_filter + ) np.testing.assert_array_equal( filter_pair_function.running_total_filter, - filter_pair_test.running_total_filter) + filter_pair_test.running_total_filter, + ) - self.assertEqual(filter_pair_function.needed_padding, - filter_pair_test.needed_padding) + self.assertEqual( + filter_pair_function.needed_padding, filter_pair_test.needed_padding + ) - self.assertEqual(filter_pair_function.surrounding_window_size, - filter_pair_test.surrounding_window_size) + self.assertEqual( + filter_pair_function.surrounding_window_size, + filter_pair_test.surrounding_window_size, + ) - self.assertEqual(filter_pair_function.observed_window_size, - filter_pair_test.observed_window_size) + self.assertEqual( + filter_pair_function.observed_window_size, + filter_pair_test.observed_window_size, + ) - self.assertEqual(filter_pair_function.crossover_window_size, - filter_pair_test.crossover_window_size) + self.assertEqual( + filter_pair_function.crossover_window_size, + filter_pair_test.crossover_window_size, + ) def test_normalized_cross_correlation(self): # Generate Spiketrains @@ -65,13 +72,14 @@ def test_normalized_cross_correlation(self): spike_times_delayed = spike_times + delay_time * ms spiketrains = BinnedSpikeTrain( - [SpikeTrain(spike_times, t_stop=20.0 * ms), - SpikeTrain(spike_times_delayed, t_stop=20.0 * ms),], + [ + SpikeTrain(spike_times, t_stop=20.0 * ms), + SpikeTrain(spike_times_delayed, t_stop=20.0 * ms), + ], bin_size=1 * ms, ) - test_output = np.array([[[0.0, 0.0], [1.1, 0.0]], [[0.0, 1.1], - [0.0, 0.0]]]) + test_output = np.array([[[0.0, 0.0], [1.1, 0.0]], [[0.0, 1.1], [0.0, 0.0]]]) function_output = normalized_cross_correlation( spiketrains, [-delay_time, delay_time] @@ -80,25 +88,29 @@ def test_normalized_cross_correlation(self): assert np.allclose(function_output, test_output, 0.1) def test_total_spiking_probability_edges(self): - files = ["SW/new_sim0_100.mat", - "BA/new_sim0_100.mat", - "CA/new_sim0_100.mat", - "ER05/new_sim0_100.mat", - "ER10/new_sim0_100.mat", - "ER15/new_sim0_100.mat", - ] + files = [ + "SW/new_sim0_100.mat", + "BA/new_sim0_100.mat", + "CA/new_sim0_100.mat", + "ER05/new_sim0_100.mat", + "ER10/new_sim0_100.mat", + "ER15/new_sim0_100.mat", + ] for datafile in files: - repo_base_path = 'unittest/functional_connectivity/' \ - 'total_spiking_probability_edges/data/' - downloaded_dataset_path = download_datasets(repo_base_path + - datafile) + repo_base_path = ( + "unittest/functional_connectivity/" 
+ "total_spiking_probability_edges/data/" + ) + downloaded_dataset_path = download_datasets(repo_base_path + datafile) spiketrains, original_data = load_spike_train_simulated( - downloaded_dataset_path) + downloaded_dataset_path + ) - connectivity_matrix, delay_matrix = \ - total_spiking_probability_edges(spiketrains) + connectivity_matrix, delay_matrix = total_spiking_probability_edges( + spiketrains + ) # Remove self-connections np.fill_diagonal(connectivity_matrix, 0) @@ -107,6 +119,7 @@ def test_total_spiking_probability_edges(self): self.assertGreater(auc, 0.95) + # ====== HELPER FUNCTIONS ====== @@ -116,7 +129,7 @@ def classify_connections(connectivity_matrix: np.ndarray, threshold: int): mask_excitatory = connectivity_matrix_binarized > threshold mask_inhibitory = connectivity_matrix_binarized < -threshold - mask_left = ~ (mask_excitatory + mask_inhibitory) + mask_left = ~(mask_excitatory + mask_inhibitory) connectivity_matrix_binarized[mask_excitatory] = 1 connectivity_matrix_binarized[mask_inhibitory] = -1 @@ -178,9 +191,11 @@ def roc_curve(estimate, original): return tpr_list, fpr_list, thresholds, auc -def load_spike_train_simulated(path: Union[Path, str], bin_size=None, - t_stop=None, - ) -> Tuple[BinnedSpikeTrain, np.ndarray]: +def load_spike_train_simulated( + path: Union[Path, str], + bin_size=None, + t_stop=None, +) -> Tuple[BinnedSpikeTrain, np.ndarray]: if isinstance(path, str): path = Path(path) @@ -190,8 +205,7 @@ def load_spike_train_simulated(path: Union[Path, str], bin_size=None, data = loadmat(path, simplify_cells=True)["data"] if "asdf" not in data: - raise ValueError('Incorrect Dataformat: Missing spiketrain_data in' - '"asdf"') + raise ValueError("Incorrect Dataformat: Missing spiketrain_data in" '"asdf"') spiketrain_data = data["asdf"] @@ -210,10 +224,11 @@ def load_spike_train_simulated(path: Union[Path, str], bin_size=None, ) ) - spiketrains = BinnedSpikeTrain(spiketrains, bin_size=bin_size, - t_stop=t_stop or recording_duration_ms) + spiketrains = BinnedSpikeTrain( + spiketrains, bin_size=bin_size, t_stop=t_stop or recording_duration_ms + ) # Load original_data - original_data = data['SWM'].T + original_data = data["SWM"].T return spiketrains, original_data diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index b472e7a8e..6f3a417bb 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -15,21 +15,20 @@ from elephant.trials import TrialsFromBlock, TrialsFromLists -def _create_trials_block(n_trials: int = 0, - n_spiketrains: int = 2, - n_analogsignals: int = 2) -> neo.core.Block: - """ Create block with n_trials, n_spiketrains and n_analog_signals """ - block = neo.Block(name='test_block') +def _create_trials_block( + n_trials: int = 0, n_spiketrains: int = 2, n_analogsignals: int = 2 +) -> neo.core.Block: + """Create block with n_trials, n_spiketrains and n_analog_signals""" + block = neo.Block(name="test_block") for trial in range(n_trials): - segment = neo.Segment(name=f'No. {trial}') - spiketrains = StationaryPoissonProcess(rate=50. * pq.Hz, - t_start=0 * pq.ms, - t_stop=1000 * pq.ms - ).generate_n_spiketrains( - n_spiketrains=n_spiketrains) - analogsignals = [AnalogSignal(signal=[.01, 3.3, 9.3], units='uV', - sampling_rate=1 * pq.Hz) - for _ in range(n_analogsignals)] + segment = neo.Segment(name=f"No. 
{trial}") + spiketrains = StationaryPoissonProcess( + rate=50.0 * pq.Hz, t_start=0 * pq.ms, t_stop=1000 * pq.ms + ).generate_n_spiketrains(n_spiketrains=n_spiketrains) + analogsignals = [ + AnalogSignal(signal=[0.01, 3.3, 9.3], units="uV", sampling_rate=1 * pq.Hz) + for _ in range(n_analogsignals) + ] for spiketrain in spiketrains: segment.spiketrains.append(spiketrain) for analogsignal in analogsignals: @@ -54,8 +53,7 @@ def setUpClass(cls) -> None: block = _create_trials_block(n_trials=36) cls.block = block - cls.trial_object = TrialsFromBlock(block, - description='trials are segments') + cls.trial_object = TrialsFromBlock(block, description="trials are segments") def setUp(self) -> None: """ @@ -66,7 +64,7 @@ def test_trials_from_block_description(self) -> None: """ Test description of the trials object. """ - self.assertEqual(self.trial_object.description, 'trials are segments') + self.assertEqual(self.trial_object.description, "trials are segments") def test_trials_from_block_get_item(self) -> None: """ @@ -79,14 +77,16 @@ def test_trials_from_block_get_trial_as_segment(self) -> None: Test get a trial from the trials. """ self.assertIsInstance( - self.trial_object.get_trial_as_segment(0), - neo.core.Segment) + self.trial_object.get_trial_as_segment(0), neo.core.Segment + ) self.assertIsInstance( self.trial_object.get_trial_as_segment(0).spiketrains[0], - neo.core.SpikeTrain) + neo.core.SpikeTrain, + ) self.assertIsInstance( self.trial_object.get_trial_as_segment(0).analogsignals[0], - neo.core.AnalogSignal) + neo.core.AnalogSignal, + ) def test_trials_from_block_get_trials_as_block(self) -> None: """ @@ -94,8 +94,7 @@ def test_trials_from_block_get_trials_as_block(self) -> None: """ block = self.trial_object.get_trials_as_block([0, 3, 5]) self.assertIsInstance(block, neo.core.Block) - self.assertIsInstance(self.trial_object.get_trials_as_block(), - neo.core.Block) + self.assertIsInstance(self.trial_object.get_trials_as_block(), neo.core.Block) self.assertEqual(len(block.segments), 3) def test_trials_from_block_get_trials_as_list(self) -> None: @@ -118,64 +117,71 @@ def test_trials_from_block_n_spiketrains_trial_by_trial(self) -> None: """ Test get number of spiketrains per trial. """ - self.assertEqual(self.trial_object.n_spiketrains_trial_by_trial, - [len(trial.spiketrains) for trial in - self.block.segments]) + self.assertEqual( + self.trial_object.n_spiketrains_trial_by_trial, + [len(trial.spiketrains) for trial in self.block.segments], + ) def test_trials_from_block_n_analogsignals_trial_by_trial(self) -> None: """ Test get number of analogsignals per trial. 
""" - self.assertEqual(self.trial_object.n_analogsignals_trial_by_trial, - [len(trial.analogsignals) for trial in - self.block.segments]) + self.assertEqual( + self.trial_object.n_analogsignals_trial_by_trial, + [len(trial.analogsignals) for trial in self.block.segments], + ) - def test_trials_from_block_get_spiketrains_from_trial_as_list(self - ) -> None: + def test_trials_from_block_get_spiketrains_from_trial_as_list(self) -> None: """ Test get spiketrains from trial as list """ self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_list(0), - neo.core.spiketrainlist.SpikeTrainList) + neo.core.spiketrainlist.SpikeTrainList, + ) self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_list(0)[0], - neo.core.SpikeTrain) + neo.core.SpikeTrain, + ) - def test_trials_from_list_get_spiketrains_from_trial_as_segment(self - ) -> None: + def test_trials_from_list_get_spiketrains_from_trial_as_segment(self) -> None: """ Test get spiketrains from trial as segment """ self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_segment(0), - neo.core.Segment) + self.trial_object.get_spiketrains_from_trial_as_segment(0), neo.core.Segment + ) self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_segment( - 0).spiketrains[0], neo.core.SpikeTrain) + self.trial_object.get_spiketrains_from_trial_as_segment(0).spiketrains[0], + neo.core.SpikeTrain, + ) - def test_trials_from_block_get_analogsignals_from_trial_as_list(self - ) -> None: + def test_trials_from_block_get_analogsignals_from_trial_as_list(self) -> None: """ Test get analogsignals from trial as list """ self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_list(0), list) + self.trial_object.get_analogsignals_from_trial_as_list(0), list + ) self.assertIsInstance( self.trial_object.get_analogsignals_from_trial_as_list(0)[0], - neo.core.AnalogSignal) + neo.core.AnalogSignal, + ) - def test_trials_from_list_get_analogsignals_from_trial_as_segment(self) \ - -> None: + def test_trials_from_list_get_analogsignals_from_trial_as_segment(self) -> None: """ Test get spiketrains from trial as segment """ self.assertIsInstance( self.trial_object.get_analogsignals_from_trial_as_segment(0), - neo.core.Segment) + neo.core.Segment, + ) self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_segment( - 0).analogsignals[0], neo.core.AnalogSignal) + self.trial_object.get_analogsignals_from_trial_as_segment(0).analogsignals[ + 0 + ], + neo.core.AnalogSignal, + ) class TrialsFromListTestCase(unittest.TestCase): @@ -191,16 +197,16 @@ def setUpClass(cls) -> None: # Create Trialobject as list of lists # add spiketrains - trial_list = [[spiketrain for spiketrain in trial.spiketrains] - for trial in block.segments] + trial_list = [ + [spiketrain for spiketrain in trial.spiketrains] for trial in block.segments + ] # add analogsignals for idx, trial in enumerate(block.segments): for analogsignal in trial.analogsignals: trial_list[idx].append(analogsignal) cls.trial_list = trial_list - cls.trial_object = TrialsFromLists(trial_list, - description='trial is a list') + cls.trial_object = TrialsFromLists(trial_list, description="trial is a list") def setUp(self) -> None: """ @@ -211,32 +217,33 @@ def test_trials_from_list_description(self) -> None: """ Test description of the trials object. 
""" - self.assertEqual(self.trial_object.description, 'trial is a list') + self.assertEqual(self.trial_object.description, "trial is a list") def test_trials_from_list_get_item(self) -> None: """ Test get a trial from the trials. """ - self.assertIsInstance(self.trial_object[0], - neo.core.Segment) - self.assertIsInstance(self.trial_object[0].spiketrains[0], - neo.core.SpikeTrain) - self.assertIsInstance(self.trial_object[0].analogsignals[0], - neo.core.AnalogSignal) + self.assertIsInstance(self.trial_object[0], neo.core.Segment) + self.assertIsInstance(self.trial_object[0].spiketrains[0], neo.core.SpikeTrain) + self.assertIsInstance( + self.trial_object[0].analogsignals[0], neo.core.AnalogSignal + ) def test_trials_from_list_get_trial_as_segment(self) -> None: """ Test get a trial from the trials. """ self.assertIsInstance( - self.trial_object.get_trial_as_segment(0), - neo.core.Segment) + self.trial_object.get_trial_as_segment(0), neo.core.Segment + ) self.assertIsInstance( self.trial_object.get_trial_as_segment(0).spiketrains[0], - neo.core.SpikeTrain) + neo.core.SpikeTrain, + ) self.assertIsInstance( self.trial_object.get_trial_as_segment(0).analogsignals[0], - neo.core.AnalogSignal) + neo.core.AnalogSignal, + ) def test_trials_from_list_get_trials_as_block(self) -> None: """ @@ -244,8 +251,7 @@ def test_trials_from_list_get_trials_as_block(self) -> None: """ block = self.trial_object.get_trials_as_block([0, 3, 5]) self.assertIsInstance(block, neo.core.Block) - self.assertIsInstance(self.trial_object.get_trials_as_block(), - neo.core.Block) + self.assertIsInstance(self.trial_object.get_trials_as_block(), neo.core.Block) self.assertEqual(len(block.segments), 3) def test_trials_from_list_get_trials_as_list(self) -> None: @@ -268,18 +274,25 @@ def test_trials_from_list_n_spiketrains_trial_by_trial(self) -> None: """ Test get number of spiketrains per trial. """ - self.assertEqual(self.trial_object.n_spiketrains_trial_by_trial, - [sum(map(lambda x: isinstance(x, neo.core.SpikeTrain), - trial)) for trial in self.trial_list]) + self.assertEqual( + self.trial_object.n_spiketrains_trial_by_trial, + [ + sum(map(lambda x: isinstance(x, neo.core.SpikeTrain), trial)) + for trial in self.trial_list + ], + ) def test_trials_from_list_n_analogsignals_trial_by_trial(self) -> None: """ Test get number of analogsignals per trial. 
""" - self.assertEqual(self.trial_object.n_analogsignals_trial_by_trial, - [sum(map(lambda x: isinstance(x, - neo.core.AnalogSignal), - trial)) for trial in self.trial_list]) + self.assertEqual( + self.trial_object.n_analogsignals_trial_by_trial, + [ + sum(map(lambda x: isinstance(x, neo.core.AnalogSignal), trial)) + for trial in self.trial_list + ], + ) def test_trials_from_list_get_spiketrains_from_trial_as_list(self) -> None: """ @@ -287,47 +300,52 @@ def test_trials_from_list_get_spiketrains_from_trial_as_list(self) -> None: """ self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_list(0), - neo.core.spiketrainlist.SpikeTrainList) + neo.core.spiketrainlist.SpikeTrainList, + ) self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_list(0)[0], - neo.core.SpikeTrain) + neo.core.SpikeTrain, + ) - def test_trials_from_list_get_spiketrains_from_trial_as_segment(self - ) -> None: + def test_trials_from_list_get_spiketrains_from_trial_as_segment(self) -> None: """ Test get spiketrains from trial as segment """ self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_segment(0), - neo.core.Segment) + self.trial_object.get_spiketrains_from_trial_as_segment(0), neo.core.Segment + ) self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_segment( - 0).spiketrains[0], neo.core.SpikeTrain) + self.trial_object.get_spiketrains_from_trial_as_segment(0).spiketrains[0], + neo.core.SpikeTrain, + ) - def test_trials_from_list_get_analogsignals_from_trial_as_list(self - ) -> None: + def test_trials_from_list_get_analogsignals_from_trial_as_list(self) -> None: """ Test get analogsignals from trial as list """ self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_list(0), list) + self.trial_object.get_analogsignals_from_trial_as_list(0), list + ) self.assertIsInstance( self.trial_object.get_analogsignals_from_trial_as_list(0)[0], - neo.core.AnalogSignal) + neo.core.AnalogSignal, + ) - def test_trials_from_list_get_analogsignals_from_trial_as_segment(self - ) \ - -> None: + def test_trials_from_list_get_analogsignals_from_trial_as_segment(self) -> None: """ Test get spiketrains from trial as segment """ self.assertIsInstance( self.trial_object.get_analogsignals_from_trial_as_segment(0), - neo.core.Segment) + neo.core.Segment, + ) self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_segment( - 0).analogsignals[0], neo.core.AnalogSignal) + self.trial_object.get_analogsignals_from_trial_as_segment(0).analogsignals[ + 0 + ], + neo.core.AnalogSignal, + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/elephant/test/test_unitary_event_analysis.py b/elephant/test/test_unitary_event_analysis.py index 4d09836fc..12416a9d1 100644 --- a/elephant/test/test_unitary_event_analysis.py +++ b/elephant/test/test_unitary_event_analysis.py @@ -20,94 +20,109 @@ class UETestCase(unittest.TestCase): - def setUp(self): - sts1_with_trial = [[26., 48., 78., 144., 178.], - [4., 45., 85., 123., 156., 185.], - [22., 53., 73., 88., 120., 147., 167., 193.], - [23., 49., 74., 116., 142., 166., 189.], - [5., 34., 54., 80., 108., 128., 150., 181.], - [18., 61., 107., 170.], - [62., 98., 131., 161.], - [37., 63., 86., 131., 168.], - [39., 76., 100., 127., 153., 198.], - [3., 35., 60., 88., 108., 141., 171., 184.], - [39., 170.], - [25., 68., 170.], - [19., 57., 84., 116., 157., 192.], - [17., 80., 131., 172.], - [33., 65., 124., 162., 192.], - [58., 87., 185.], - [19., 101., 174.], - [84., 118., 
156., 198., 199.], - [5., 55., 67., 96., 114., 148., 172., 199.], - [61., 105., 131., 169., 195.], - [26., 96., 129., 157.], - [41., 85., 157., 199.], - [6., 30., 53., 76., 109., 142., 167., 194.], - [159.], - [6., 51., 78., 113., 154., 183.], - [138.], - [23., 59., 154., 185.], - [12., 14., 52., 54., 109., 145., 192.], - [29., 61., 84., 122., 145., 168.], - [26., 99.], - [3., 31., 55., 85., 108., 158., 191.], - [5., 37., 70., 119., 170.], - [38., 79., 117., 157., 192.], - [174.], - [114.], - []] - sts2_with_trial = [[3., 119.], - [54., 155., 183.], - [35., 133.], - [25., 100., 176.], - [9., 98.], - [6., 97., 198.], - [7., 62., 148.], - [100., 158.], - [7., 62., 122., 179., 191.], - [125., 182.], - [30., 55., 127., 157., 196.], - [27., 70., 173.], - [82., 84., 198.], - [11., 29., 137.], - [5., 49., 61., 101., 142., 190.], - [78., 162., 178.], - [13., 14., 130., 172.], - [22.], - [16., 55., 109., 113., 175.], - [17., 33., 63., 102., 144., 189., 190.], - [58.], - [27., 30., 99., 145., 176.], - [10., 58., 116., 182.], - [14., 68., 104., 126., 162., 194.], - [56., 129., 196.], - [50., 78., 105., 152., 190., 197.], - [24., 66., 113., 117., 161.], - [9., 31., 81., 95., 136., 154.], - [10., 115., 185., 191.], - [71., 140., 157.], - [15., 27., 88., 102., 103., 151., 181., 188.], - [51., 75., 95., 134., 195.], - [18., 55., 75., 131., 186.], - [10., 16., 41., 42., 75., 127.], - [62., 76., 102., 145., 171., 183.], - [66., 71., 85., 140., 154.]] - self.sts1_neo = [neo.SpikeTrain( - i * pq.ms, t_stop=200 * pq.ms) for i in sts1_with_trial] - self.sts2_neo = [neo.SpikeTrain( - i * pq.ms, t_stop=200 * pq.ms) for i in sts2_with_trial] - self.binary_sts = np.array([[[1, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [0, 1, 1, 0, 1]], - [[1, 1, 1, 1, 1], - [0, 1, 1, 1, 1], - [1, 1, 0, 1, 0]]]) + sts1_with_trial = [ + [26.0, 48.0, 78.0, 144.0, 178.0], + [4.0, 45.0, 85.0, 123.0, 156.0, 185.0], + [22.0, 53.0, 73.0, 88.0, 120.0, 147.0, 167.0, 193.0], + [23.0, 49.0, 74.0, 116.0, 142.0, 166.0, 189.0], + [5.0, 34.0, 54.0, 80.0, 108.0, 128.0, 150.0, 181.0], + [18.0, 61.0, 107.0, 170.0], + [62.0, 98.0, 131.0, 161.0], + [37.0, 63.0, 86.0, 131.0, 168.0], + [39.0, 76.0, 100.0, 127.0, 153.0, 198.0], + [3.0, 35.0, 60.0, 88.0, 108.0, 141.0, 171.0, 184.0], + [39.0, 170.0], + [25.0, 68.0, 170.0], + [19.0, 57.0, 84.0, 116.0, 157.0, 192.0], + [17.0, 80.0, 131.0, 172.0], + [33.0, 65.0, 124.0, 162.0, 192.0], + [58.0, 87.0, 185.0], + [19.0, 101.0, 174.0], + [84.0, 118.0, 156.0, 198.0, 199.0], + [5.0, 55.0, 67.0, 96.0, 114.0, 148.0, 172.0, 199.0], + [61.0, 105.0, 131.0, 169.0, 195.0], + [26.0, 96.0, 129.0, 157.0], + [41.0, 85.0, 157.0, 199.0], + [6.0, 30.0, 53.0, 76.0, 109.0, 142.0, 167.0, 194.0], + [159.0], + [6.0, 51.0, 78.0, 113.0, 154.0, 183.0], + [138.0], + [23.0, 59.0, 154.0, 185.0], + [12.0, 14.0, 52.0, 54.0, 109.0, 145.0, 192.0], + [29.0, 61.0, 84.0, 122.0, 145.0, 168.0], + [26.0, 99.0], + [3.0, 31.0, 55.0, 85.0, 108.0, 158.0, 191.0], + [5.0, 37.0, 70.0, 119.0, 170.0], + [38.0, 79.0, 117.0, 157.0, 192.0], + [174.0], + [114.0], + [], + ] + sts2_with_trial = [ + [3.0, 119.0], + [54.0, 155.0, 183.0], + [35.0, 133.0], + [25.0, 100.0, 176.0], + [9.0, 98.0], + [6.0, 97.0, 198.0], + [7.0, 62.0, 148.0], + [100.0, 158.0], + [7.0, 62.0, 122.0, 179.0, 191.0], + [125.0, 182.0], + [30.0, 55.0, 127.0, 157.0, 196.0], + [27.0, 70.0, 173.0], + [82.0, 84.0, 198.0], + [11.0, 29.0, 137.0], + [5.0, 49.0, 61.0, 101.0, 142.0, 190.0], + [78.0, 162.0, 178.0], + [13.0, 14.0, 130.0, 172.0], + [22.0], + [16.0, 55.0, 109.0, 113.0, 175.0], + [17.0, 33.0, 63.0, 
102.0, 144.0, 189.0, 190.0], + [58.0], + [27.0, 30.0, 99.0, 145.0, 176.0], + [10.0, 58.0, 116.0, 182.0], + [14.0, 68.0, 104.0, 126.0, 162.0, 194.0], + [56.0, 129.0, 196.0], + [50.0, 78.0, 105.0, 152.0, 190.0, 197.0], + [24.0, 66.0, 113.0, 117.0, 161.0], + [9.0, 31.0, 81.0, 95.0, 136.0, 154.0], + [10.0, 115.0, 185.0, 191.0], + [71.0, 140.0, 157.0], + [15.0, 27.0, 88.0, 102.0, 103.0, 151.0, 181.0, 188.0], + [51.0, 75.0, 95.0, 134.0, 195.0], + [18.0, 55.0, 75.0, 131.0, 186.0], + [10.0, 16.0, 41.0, 42.0, 75.0, 127.0], + [62.0, 76.0, 102.0, 145.0, 171.0, 183.0], + [66.0, 71.0, 85.0, 140.0, 154.0], + ] + self.sts1_neo = [ + neo.SpikeTrain(i * pq.ms, t_stop=200 * pq.ms) for i in sts1_with_trial + ] + self.sts2_neo = [ + neo.SpikeTrain(i * pq.ms, t_stop=200 * pq.ms) for i in sts2_with_trial + ] + self.binary_sts = np.array( + [ + [[1, 1, 1, 1, 0], [0, 1, 1, 1, 0], [0, 1, 1, 0, 1]], + [[1, 1, 1, 1, 1], [0, 1, 1, 1, 1], [1, 1, 0, 1, 0]], + ] + ) def test_hash_default(self): - m = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], - [1, 0, 1], [0, 1, 1], [1, 1, 1]]) + m = np.array( + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 1, 0], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + ] + ) expected = np.array([77, 43, 23]) h = ue.hash_from_pattern(m) self.assertTrue(np.all(expected == h)) @@ -115,7 +130,7 @@ def test_hash_default(self): def test_hash_default_longpattern(self): m = np.zeros((100, 2)) m[0, 0] = 1 - expected = np.array([2 ** 99, 0]) + expected = np.array([2**99, 0]) h = ue.hash_from_pattern(m) self.assertTrue(np.all(expected == h)) @@ -127,13 +142,33 @@ def test_hash_inverse_longpattern(self): assert_array_equal(m, m_inv) def test_hash_ValueError_wrong_entries(self): - m = np.array([[0, 0, 0], [1, 0, 0], [0, 2, 0], [0, 0, 1], [1, 1, 0], - [1, 0, 1], [0, 1, 1], [1, 1, 1]]) + m = np.array( + [ + [0, 0, 0], + [1, 0, 0], + [0, 2, 0], + [0, 0, 1], + [1, 1, 0], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + ] + ) self.assertRaises(ValueError, ue.hash_from_pattern, m) def test_hash_base_not_two(self): - m = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], - [1, 0, 1], [0, 1, 1], [1, 1, 1]]) + m = np.array( + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 1, 0], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + ] + ) m = m.T base = 3 expected = np.array([0, 9, 3, 1, 12, 10, 4, 13]) @@ -144,15 +179,18 @@ def test_invhash_ValueError(self): """ The hash is larger than sum(2 ** range(N)). 
""" - self.assertRaises( - ValueError, ue.inverse_hash_from_pattern, [128, 8], 4) + self.assertRaises(ValueError, ue.inverse_hash_from_pattern, [128, 8], 4) def test_invhash_default_base(self): N = 3 h = np.array([0, 4, 2, 1, 6, 5, 3, 7]) expected = np.array( - [[0, 1, 0, 0, 1, 1, 0, 1], [0, 0, 1, 0, 1, 0, 1, 1], - [0, 0, 0, 1, 0, 1, 1, 1]]) + [ + [0, 1, 0, 0, 1, 1, 0, 1], + [0, 0, 1, 0, 1, 0, 1, 1], + [0, 0, 0, 1, 0, 1, 1, 1], + ] + ) m = ue.inverse_hash_from_pattern(h, N) self.assertTrue(np.all(expected == m)) @@ -167,23 +205,44 @@ def test_invhash_base_not_two(self): def test_invhash_shape_mat(self): N = 8 h = np.array([178, 212, 232]) - expected = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [ - 1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1]]) + expected = np.array( + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 1, 0], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + ] + ) m = ue.inverse_hash_from_pattern(h, N) self.assertTrue(np.shape(m)[0] == N) def test_hash_invhash_consistency(self): - m = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], - [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1]]) + m = np.array( + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 1, 0], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + ] + ) inv_h = ue.hash_from_pattern(m) m1 = ue.inverse_hash_from_pattern(inv_h, N=8) self.assertTrue(np.all(m == m1)) def test_n_emp_mat_default(self): - mat = np.array([[0, 0, 0, 1, 1], [0, 0, 0, 0, 1], - [1, 0, 1, 1, 1], [1, 0, 1, 1, 1]]) + mat = np.array( + [[0, 0, 0, 1, 1], [0, 0, 0, 0, 1], [1, 0, 1, 1, 1], [1, 0, 1, 1, 1]] + ) pattern_hash = [3, 15] - expected1 = np.array([2., 1.]) + expected1 = np.array([2.0, 1.0]) expected2 = [[0, 2], [4]] nemp, nemp_indices = ue.n_emp_mat(mat, pattern_hash) self.assertTrue(np.all(nemp == expected1)) @@ -194,18 +253,18 @@ def test_n_emp_mat_sum_trial_default(self): mat = self.binary_sts pattern_hash = np.array([4, 6]) N = 3 - expected1 = np.array([1., 3.]) + expected1 = np.array([1.0, 3.0]) expected2 = [[[0], [3]], [[], [2, 4]]] n_emp, n_emp_idx = ue.n_emp_mat_sum_trial(mat, pattern_hash) self.assertTrue(np.all(n_emp == expected1)) for item0_cnt, item0 in enumerate(n_emp_idx): for item1_cnt, item1 in enumerate(item0): - self.assertTrue( - np.allclose(expected2[item0_cnt][item1_cnt], item1)) + self.assertTrue(np.allclose(expected2[item0_cnt][item1_cnt], item1)) def test_n_exp_mat_default(self): - mat = np.array([[0, 0, 0, 1, 1], [0, 0, 0, 0, 1], - [1, 0, 1, 1, 1], [1, 0, 1, 1, 1]]) + mat = np.array( + [[0, 0, 0, 1, 1], [0, 0, 0, 0, 1], [1, 0, 1, 1, 1], [1, 0, 1, 1, 1]] + ) pattern_hash = [3, 11] expected = np.array([1.536, 1.024]) nexp = ue.n_exp_mat(mat, pattern_hash) @@ -223,29 +282,30 @@ def test_n_exp_mat_sum_trial_TrialAverage(self): pattern_hash = np.array([5, 6]) expected = np.array([1.62, 2.52]) n_exp = ue.n_exp_mat_sum_trial( - mat, pattern_hash, method='analytic_TrialAverage') + mat, pattern_hash, method="analytic_TrialAverage" + ) self.assertTrue(np.allclose(n_exp, expected)) def test_n_exp_mat_sum_trial_surrogate(self): mat = self.binary_sts pattern_hash = np.array([5]) n_exp_anal = ue.n_exp_mat_sum_trial( - mat, pattern_hash, method='analytic_TrialAverage') + mat, pattern_hash, method="analytic_TrialAverage" + ) n_exp_surr = ue.n_exp_mat_sum_trial( - mat, pattern_hash, method='surrogate_TrialByTrial', - n_surrogates=1000) + mat, pattern_hash, method="surrogate_TrialByTrial", n_surrogates=1000 + ) self.assertLess( - a=np.abs(n_exp_anal[0] - np.mean(n_exp_surr)) / n_exp_anal[0], - b=0.1) + a=np.abs(n_exp_anal[0] - 
np.mean(n_exp_surr)) / n_exp_anal[0], b=0.1 + ) def test_gen_pval_anal_default(self): - mat = np.array([[[1, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [0, 1, 1, 0, 1]], - - [[1, 1, 1, 1, 1], - [0, 1, 1, 1, 1], - [1, 1, 0, 1, 0]]]) + mat = np.array( + [ + [[1, 1, 1, 1, 0], [0, 1, 1, 1, 0], [0, 1, 1, 0, 1]], + [[1, 1, 1, 1, 1], [0, 1, 1, 1, 1], [1, 1, 0, 1, 0]], + ] + ) pattern_hash = np.array([5, 6]) expected = np.array([1.56, 2.56]) pval_func, n_exp = ue.gen_pval_anal(mat, pattern_hash) @@ -273,19 +333,20 @@ def test__winpos(self): t_stop = 46 * pq.ms winsize = 15 * pq.ms winstep = 3 * pq.ms - expected = [10., 13., 16., 19., 22., 25., 28., 31.] * pq.ms + expected = [10.0, 13.0, 16.0, 19.0, 22.0, 25.0, 28.0, 31.0] * pq.ms self.assertTrue( np.allclose( - ue._winpos(t_start, t_stop, winsize, - winstep).rescale('ms').magnitude, - expected.rescale('ms').magnitude)) + ue._winpos(t_start, t_stop, winsize, winstep).rescale("ms").magnitude, + expected.rescale("ms").magnitude, + ) + ) def test__UE_default(self): mat = self.binary_sts pattern_hash = np.array([4, 6]) expected_S = np.array([-0.26226523, 0.04959301]) expected_idx = [[[0], [3]], [[], [2, 4]]] - expected_nemp = np.array([1., 3.]) + expected_nemp = np.array([1.0, 3.0]) expected_nexp = np.array([1.04, 2.56]) expected_rate = np.array([0.9, 0.7, 0.6]) S, rate_avg, n_exp, n_emp, indices = ue._UE(mat, pattern_hash) @@ -295,81 +356,78 @@ def test__UE_default(self): self.assertTrue(np.allclose(expected_rate, rate_avg)) for item0_cnt, item0 in enumerate(indices): for item1_cnt, item1 in enumerate(item0): - self.assertTrue( - np.allclose(expected_idx[item0_cnt][item1_cnt], item1)) + self.assertTrue(np.allclose(expected_idx[item0_cnt][item1_cnt], item1)) def test__UE_surrogate(self): mat = self.binary_sts pattern_hash = np.array([4]) - _, rate_avg_surr, _, n_emp_surr, indices_surr = \ - ue._UE( - mat, - pattern_hash, - method='surrogate_TrialByTrial', - n_surrogates=100) - _, rate_avg, _, n_emp, indices = \ - ue._UE(mat, pattern_hash, method='analytic_TrialByTrial') + _, rate_avg_surr, _, n_emp_surr, indices_surr = ue._UE( + mat, pattern_hash, method="surrogate_TrialByTrial", n_surrogates=100 + ) + _, rate_avg, _, n_emp, indices = ue._UE( + mat, pattern_hash, method="analytic_TrialByTrial" + ) self.assertTrue(np.allclose(n_emp, n_emp_surr)) self.assertTrue(np.allclose(rate_avg, rate_avg_surr)) for item0_cnt, item0 in enumerate(indices): for item1_cnt, item1 in enumerate(item0): - self.assertTrue( - np.allclose( - indices_surr[item0_cnt][item1_cnt], - item1)) + self.assertTrue(np.allclose(indices_surr[item0_cnt][item1_cnt], item1)) def test_jointJ_window_analysis(self): - sts1 = self.sts1_neo sts2 = self.sts2_neo # joinJ_window_analysis requires the following: # A list of spike trains(neo.SpikeTrain objects) in different trials: - data = list(zip(*[sts1,sts2])) + data = list(zip(*[sts1, sts2])) win_size = 100 * pq.ms bin_size = 5 * pq.ms win_step = 20 * pq.ms pattern_hash = [3] - UE_dic = ue.jointJ_window_analysis(spiketrains=data, - pattern_hash=pattern_hash, - bin_size=bin_size, - win_size=win_size, - win_step=win_step) + UE_dic = ue.jointJ_window_analysis( + spiketrains=data, + pattern_hash=pattern_hash, + bin_size=bin_size, + win_size=win_size, + win_step=win_step, + ) expected_Js = np.array( - [0.57953708, 0.47348757, 0.1729669, - 0.01883295, -0.21934742, -0.80608759]) - expected_n_emp = np.array( - [9., 9., 7., 7., 6., 6.]) - expected_n_exp = np.array( - [6.5, 6.85, 6.05, 6.6, 6.45, 8.7]) - expected_rate = np.array( - [[0.02166667, 0.01861111], - 
[0.02277778, 0.01777778], - [0.02111111, 0.01777778], - [0.02277778, 0.01888889], - [0.02305556, 0.01722222], - [0.02388889, 0.02055556]]) * pq.kHz - expected_indecis_tril26 = [4., 4.] - expected_indecis_tril4 = [1.] - assert_array_almost_equal(UE_dic['Js'].squeeze(), expected_Js) - assert_array_almost_equal(UE_dic['n_emp'].squeeze(), expected_n_emp) - assert_array_almost_equal(UE_dic['n_exp'].squeeze(), expected_n_exp) - assert_array_almost_equal(UE_dic['rate_avg'].squeeze(), expected_rate) - assert_array_almost_equal(UE_dic['indices']['trial26'], - expected_indecis_tril26) - assert_array_almost_equal(UE_dic['indices']['trial4'], - expected_indecis_tril4) + [0.57953708, 0.47348757, 0.1729669, 0.01883295, -0.21934742, -0.80608759] + ) + expected_n_emp = np.array([9.0, 9.0, 7.0, 7.0, 6.0, 6.0]) + expected_n_exp = np.array([6.5, 6.85, 6.05, 6.6, 6.45, 8.7]) + expected_rate = ( + np.array( + [ + [0.02166667, 0.01861111], + [0.02277778, 0.01777778], + [0.02111111, 0.01777778], + [0.02277778, 0.01888889], + [0.02305556, 0.01722222], + [0.02388889, 0.02055556], + ] + ) + * pq.kHz + ) + expected_indecis_tril26 = [4.0, 4.0] + expected_indecis_tril4 = [1.0] + assert_array_almost_equal(UE_dic["Js"].squeeze(), expected_Js) + assert_array_almost_equal(UE_dic["n_emp"].squeeze(), expected_n_emp) + assert_array_almost_equal(UE_dic["n_exp"].squeeze(), expected_n_exp) + assert_array_almost_equal(UE_dic["rate_avg"].squeeze(), expected_rate) + assert_array_almost_equal(UE_dic["indices"]["trial26"], expected_indecis_tril26) + assert_array_almost_equal(UE_dic["indices"]["trial4"], expected_indecis_tril4) # check the input parameters - input_params = UE_dic['input_parameters'] - self.assertEqual(input_params['pattern_hash'], pattern_hash) - self.assertEqual(input_params['bin_size'], bin_size) - self.assertEqual(input_params['win_size'], win_size) - self.assertEqual(input_params['win_step'], win_step) - self.assertEqual(input_params['method'], 'analytic_TrialByTrial') - self.assertEqual(input_params['t_start'], 0 * pq.s) - self.assertEqual(input_params['t_stop'], 200 * pq.ms) + input_params = UE_dic["input_parameters"] + self.assertEqual(input_params["pattern_hash"], pattern_hash) + self.assertEqual(input_params["bin_size"], bin_size) + self.assertEqual(input_params["win_size"], win_size) + self.assertEqual(input_params["win_step"], win_step) + self.assertEqual(input_params["method"], "analytic_TrialByTrial") + self.assertEqual(input_params["t_start"], 0 * pq.s) + self.assertEqual(input_params["t_stop"], 200 * pq.ms) @staticmethod def load_gdf2Neo(fname, trigger, t_pre, t_post): @@ -394,13 +452,13 @@ def load_gdf2Neo(fname, trigger, t_pre, t_post): """ data = np.loadtxt(fname) - if trigger == 'PS_4': + if trigger == "PS_4": trigger_code = 114 - if trigger == 'RS_4': + if trigger == "RS_4": trigger_code = 124 - if trigger == 'RS': + if trigger == "RS": trigger_code = 12 - if trigger == 'ES': + if trigger == "ES": trigger_code = 15 # specify units units_id = np.unique(data[:, 0][data[:, 0] < 7]) @@ -414,23 +472,30 @@ def load_gdf2Neo(fname, trigger, t_pre, t_post): start_tmp = data[i][1] - t_pre.magnitude stop_tmp = data[i][1] + t_post.magnitude sel_data_tmp = np.array( - data[np.where((data[:, 1] <= stop_tmp) & - (data[:, 1] >= start_tmp))]) + data[np.where((data[:, 1] <= stop_tmp) & (data[:, 1] >= start_tmp))] + ) sp_units_tmp = sel_data_tmp[:, 1][ - np.where(sel_data_tmp[:, 0] == id_tmp)[0]] + np.where(sel_data_tmp[:, 0] == id_tmp)[0] + ] if len(sp_units_tmp) > 0: aligned_time = sp_units_tmp - start_tmp - 
data_sel_units.append(neo.SpikeTrain( - aligned_time * pq.ms, t_start=0 * pq.ms, - t_stop=t_pre + t_post)) + data_sel_units.append( + neo.SpikeTrain( + aligned_time * pq.ms, + t_start=0 * pq.ms, + t_stop=t_pre + t_post, + ) + ) else: - data_sel_units.append(neo.SpikeTrain( - [] * pq.ms, t_start=0 * pq.ms, - t_stop=t_pre + t_post)) + data_sel_units.append( + neo.SpikeTrain( + [] * pq.ms, t_start=0 * pq.ms, t_stop=t_pre + t_post + ) + ) data_tr.append(data_sel_units) data_tr.reverse() - data_tr=np.asarray(data_tr, dtype=object) + data_tr = np.asarray(data_tr, dtype=object) spiketrain = np.vstack([i for i in data_tr]).T return spiketrain @@ -438,21 +503,25 @@ def load_gdf2Neo(fname, trigger, t_pre, t_post): # consistent with the result of Riehle et al 1997 Science # (see Rostami et al (2016) [Re] Science, 3(1):1-17). def test_Riehle_et_al_97_UE(self): - url = "http://raw.githubusercontent.com/ReScience-Archives/Rostami-" \ - "Ito-Denker-Gruen-2017/master/data" + url = ( + "http://raw.githubusercontent.com/ReScience-Archives/Rostami-" + "Ito-Denker-Gruen-2017/master/data" + ) files_to_download = ( ("extracted_data.npy", "c4903666ce8a8a31274d6b11238a5ac3"), - ("winny131_23.gdf", "cc2958f7b4fb14dbab71e17bba49bd10") + ("winny131_23.gdf", "cc2958f7b4fb14dbab71e17bba49bd10"), ) for filename, checksum in files_to_download: # The files will be downloaded to ELEPHANT_TMP_DIR download(url=f"{url}/{filename}", checksum=checksum) # load spike data of figure 2 of Riehle et al 1997 - spiketrain = self.load_gdf2Neo(ELEPHANT_TMP_DIR / "winny131_23.gdf", - trigger='RS_4', - t_pre=1799 * pq.ms, - t_post=300 * pq.ms) + spiketrain = self.load_gdf2Neo( + ELEPHANT_TMP_DIR / "winny131_23.gdf", + trigger="RS_4", + t_pre=1799 * pq.ms, + t_post=300 * pq.ms, + ) # calculating UE ... 
winsize = 100 * pq.ms @@ -464,36 +533,42 @@ def test_Riehle_et_al_97_UE(self): t_winpos = ue._winpos(t_start, t_stop, winsize, winstep) significance_level = 0.05 - UE = ue.jointJ_window_analysis(spiketrain, - pattern_hash=pattern_hash, - bin_size=bin_size, - win_size=winsize, - win_step=winstep, - method='analytic_TrialAverage') + UE = ue.jointJ_window_analysis( + spiketrain, + pattern_hash=pattern_hash, + bin_size=bin_size, + win_size=winsize, + win_step=winstep, + method="analytic_TrialAverage", + ) # load extracted data from figure 2 of Riehle et al 1997 - extracted_data = np.load(ELEPHANT_TMP_DIR / 'extracted_data.npy', - encoding='latin1', allow_pickle=True).item() + extracted_data = np.load( + ELEPHANT_TMP_DIR / "extracted_data.npy", + encoding="latin1", + allow_pickle=True, + ).item() Js_sig = ue.jointJ(significance_level) - sig_idx_win = np.where(UE['Js'] >= Js_sig)[0] + sig_idx_win = np.where(UE["Js"] >= Js_sig)[0] diff_UE_rep = [] y_cnt = 0 for trial_id in range(len(spiketrain)): trial_id_str = "trial{}".format(trial_id) - indices_unique = np.unique(UE['indices'][trial_id_str]) + indices_unique = np.unique(UE["indices"][trial_id_str]) if len(indices_unique) > 0: # choose only the significant coincidences indices_unique_significant = [] for j in sig_idx_win: - significant = indices_unique[np.where( - (indices_unique * bin_size >= t_winpos[j]) & - (indices_unique * bin_size < t_winpos[j] + winsize))] + significant = indices_unique[ + np.where( + (indices_unique * bin_size >= t_winpos[j]) + & (indices_unique * bin_size < t_winpos[j] + winsize) + ) + ] indices_unique_significant.extend(significant) - x_tmp = np.unique(indices_unique_significant) * \ - bin_size.magnitude + x_tmp = np.unique(indices_unique_significant) * bin_size.magnitude if len(x_tmp) > 0: - ue_trial = np.sort(extracted_data['ue'][y_cnt]) - diff_UE_rep = np.append( - diff_UE_rep, x_tmp - ue_trial) + ue_trial = np.sort(extracted_data["ue"][y_cnt]) + diff_UE_rep = np.append(diff_UE_rep, x_tmp - ue_trial) y_cnt += +1 np.testing.assert_array_less(np.abs(diff_UE_rep), 0.3) @@ -501,70 +576,77 @@ def test_multiple_neurons(self): np.random.seed(12) # Create a list of lists containing 3 Trials with 5 spiketrains - spiketrains = \ - [StationaryPoissonProcess( - rate=50 * pq.Hz, t_stop=1 * pq.s).generate_n_spiketrains(5) - for _ in range(3)] + spiketrains = [ + StationaryPoissonProcess( + rate=50 * pq.Hz, t_stop=1 * pq.s + ).generate_n_spiketrains(5) + for _ in range(3) + ] spiketrains = list(zip(*spiketrains)) - UE_dic = ue.jointJ_window_analysis(spiketrains, bin_size=5 * pq.ms, - win_size=300 * pq.ms, - win_step=100 * pq.ms) - - js_expected = [[0.3978179], - [0.08131966], - [-1.4239882], - [-0.9377029], - [-0.3374434], - [-0.2043383], - [-1.001536], - [-np.inf]] - indices_expected = \ - {'trial3': [12, 27, 31, 34, 27, 31, 34, 136, 136, 136], - 'trial4': [4, 60, 60, 60, 117, 117, 117]} - n_emp_expected = [[5.], - [4.], - [1.], - [2.], - [2.], - [2.], - [1.], - [0.]] - n_exp_expected = [[3.5591667], - [3.4536111], - [3.3158333], - [3.8466666], - [2.370278], - [2.0811112], - [2.4011111], - [3.0533333]] - rate_expected = [[[0.042, 0.03933334, 0.048]], - [[0.04533333, 0.038, 0.05]], - [[0.046, 0.04, 0.04666667]], - [[0.05066667, 0.042, 0.046]], - [[0.04466667, 0.03666667, 0.04066667]], - [[0.04066667, 0.03533333, 0.04333333]], - [[0.03933334, 0.038, 0.038]], - [[0.04066667, 0.04866667, 0.03666667]]] * (1. 
-                                                                    / pq.ms)
-        input_parameters_expected = {'pattern_hash': [7],
-                                     'bin_size': 5 * pq.ms,
-                                     'win_size': 300 * pq.ms,
-                                     'win_step': 100 * pq.ms,
-                                     'method': 'analytic_TrialByTrial',
-                                     't_start': 0 * pq.s,
-                                     't_stop': 1 * pq.s, 'n_surrogates': 100}
-
-        assert_array_almost_equal(UE_dic['Js'], js_expected)
-        assert_array_almost_equal(UE_dic['n_emp'], n_emp_expected)
-        assert_array_almost_equal(UE_dic['n_exp'], n_exp_expected)
-        assert_array_almost_equal(UE_dic['rate_avg'], rate_expected)
-        self.assertEqual(sorted(UE_dic['indices'].keys()),
-                         sorted(indices_expected.keys()))
+        UE_dic = ue.jointJ_window_analysis(
+            spiketrains, bin_size=5 * pq.ms, win_size=300 * pq.ms, win_step=100 * pq.ms
+        )
+
+        js_expected = [
+            [0.3978179],
+            [0.08131966],
+            [-1.4239882],
+            [-0.9377029],
+            [-0.3374434],
+            [-0.2043383],
+            [-1.001536],
+            [-np.inf],
+        ]
+        indices_expected = {
+            "trial3": [12, 27, 31, 34, 27, 31, 34, 136, 136, 136],
+            "trial4": [4, 60, 60, 60, 117, 117, 117],
+        }
+        n_emp_expected = [[5.0], [4.0], [1.0], [2.0], [2.0], [2.0], [1.0], [0.0]]
+        n_exp_expected = [
+            [3.5591667],
+            [3.4536111],
+            [3.3158333],
+            [3.8466666],
+            [2.370278],
+            [2.0811112],
+            [2.4011111],
+            [3.0533333],
+        ]
+        rate_expected = [
+            [[0.042, 0.03933334, 0.048]],
+            [[0.04533333, 0.038, 0.05]],
+            [[0.046, 0.04, 0.04666667]],
+            [[0.05066667, 0.042, 0.046]],
+            [[0.04466667, 0.03666667, 0.04066667]],
+            [[0.04066667, 0.03533333, 0.04333333]],
+            [[0.03933334, 0.038, 0.038]],
+            [[0.04066667, 0.04866667, 0.03666667]],
+        ] * (1.0 / pq.ms)
+        input_parameters_expected = {
+            "pattern_hash": [7],
+            "bin_size": 5 * pq.ms,
+            "win_size": 300 * pq.ms,
+            "win_step": 100 * pq.ms,
+            "method": "analytic_TrialByTrial",
+            "t_start": 0 * pq.s,
+            "t_stop": 1 * pq.s,
+            "n_surrogates": 100,
+        }
+
+        assert_array_almost_equal(UE_dic["Js"], js_expected)
+        assert_array_almost_equal(UE_dic["n_emp"], n_emp_expected)
+        assert_array_almost_equal(UE_dic["n_exp"], n_exp_expected)
+        assert_array_almost_equal(UE_dic["rate_avg"], rate_expected)
+        self.assertEqual(
+            sorted(UE_dic["indices"].keys()), sorted(indices_expected.keys())
+        )
         for trial_key in indices_expected.keys():
-            assert_array_equal(indices_expected[trial_key],
-                               UE_dic['indices'][trial_key])
-        self.assertEqual(UE_dic['input_parameters'], input_parameters_expected)
+            assert_array_equal(
+                indices_expected[trial_key], UE_dic["indices"][trial_key]
+            )
+        self.assertEqual(UE_dic["input_parameters"], input_parameters_expected)
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/elephant/test/test_utils.py b/elephant/test/test_utils.py
index cc927d53e..9470a59cc 100644
--- a/elephant/test/test_utils.py
+++ b/elephant/test/test_utils.py
@@ -19,52 +19,61 @@ class TestUtils(unittest.TestCase):
-
     def test_check_neo_consistency(self):
-        self.assertRaises(TypeError,
-                          utils.check_neo_consistency,
-                          [], object_type=neo.SpikeTrain)
-        self.assertRaises(TypeError,
-                          utils.check_neo_consistency,
-                          [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s),
-                           np.arange(2)], object_type=neo.SpikeTrain)
-        self.assertRaises(ValueError,
-                          utils.check_neo_consistency,
-                          [neo.SpikeTrain([1]*pq.s,
-                                          t_start=1*pq.s,
-                                          t_stop=2*pq.s),
-                           neo.SpikeTrain([1]*pq.s,
-                                          t_start=0*pq.s,
-                                          t_stop=2*pq.s)],
-                          object_type=neo.SpikeTrain)
-        self.assertRaises(ValueError,
-                          utils.check_neo_consistency,
-                          [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s),
-                           neo.SpikeTrain([1]*pq.s, t_stop=3*pq.s)],
-                          object_type=neo.SpikeTrain)
-        self.assertRaises(ValueError,
-                          utils.check_neo_consistency,
-                          [neo.SpikeTrain([1]*pq.ms, t_stop=2000*pq.ms),
-                           neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)],
-                          object_type=neo.SpikeTrain)
+        self.assertRaises(
+            TypeError, utils.check_neo_consistency, [], object_type=neo.SpikeTrain
+        )
+        self.assertRaises(
+            TypeError,
+            utils.check_neo_consistency,
+            [neo.SpikeTrain([1] * pq.s, t_stop=2 * pq.s), np.arange(2)],
+            object_type=neo.SpikeTrain,
+        )
+        self.assertRaises(
+            ValueError,
+            utils.check_neo_consistency,
+            [
+                neo.SpikeTrain([1] * pq.s, t_start=1 * pq.s, t_stop=2 * pq.s),
+                neo.SpikeTrain([1] * pq.s, t_start=0 * pq.s, t_stop=2 * pq.s),
+            ],
+            object_type=neo.SpikeTrain,
+        )
+        self.assertRaises(
+            ValueError,
+            utils.check_neo_consistency,
+            [
+                neo.SpikeTrain([1] * pq.s, t_stop=2 * pq.s),
+                neo.SpikeTrain([1] * pq.s, t_stop=3 * pq.s),
+            ],
+            object_type=neo.SpikeTrain,
+        )
+        self.assertRaises(
+            ValueError,
+            utils.check_neo_consistency,
+            [
+                neo.SpikeTrain([1] * pq.ms, t_stop=2000 * pq.ms),
+                neo.SpikeTrain([1] * pq.s, t_stop=2 * pq.s),
+            ],
+            object_type=neo.SpikeTrain,
+        )
     def test_round_binning_errors(self):
         n_bins = utils.round_binning_errors(0.999999, tolerance=1e-6)
         self.assertEqual(n_bins, 1)
-        self.assertEqual(utils.round_binning_errors(0.999999, tolerance=None),
-                         0)
+        self.assertEqual(utils.round_binning_errors(0.999999, tolerance=None), 0)
         array = np.array([0, 0.7, 1 - 1e-8, 1 - 1e-9])
         corrected = utils.round_binning_errors(array.copy())
         assert_array_equal(corrected, [0, 0, 1, 1])
         assert_array_equal(
-            utils.round_binning_errors(array.copy(), tolerance=None),
-            [0, 0, 0, 0])
+            utils.round_binning_errors(array.copy(), tolerance=None), [0, 0, 0, 0]
+        )
 class DecoratorTest:
     """
     This class is used as a mock for testing the decorator.
     """
+
     @utils.trials_to_list_of_spiketrainlist
     def method_to_decorate(self, trials=None, trials_obj=None):
         # This is just a mock implementation for testing purposes
@@ -80,23 +89,22 @@ def setUpClass(cls):
         cls.n_channels = 10
         cls.n_trials = 5
         cls.list_of_list_of_spiketrains = [
-            StationaryPoissonProcess(rate=5 * pq.Hz, t_stop=1000.0 * pq.ms
-                                     ).generate_n_spiketrains(cls.n_channels)
-            for _ in range(cls.n_trials)]
+            StationaryPoissonProcess(
+                rate=5 * pq.Hz, t_stop=1000.0 * pq.ms
+            ).generate_n_spiketrains(cls.n_channels)
+            for _ in range(cls.n_trials)
+        ]
         cls.trial_object = TrialsFromLists(cls.list_of_list_of_spiketrains)
     def test_decorator_applied(self):
         # Test that the decorator is applied correctly
-        self.assertTrue(hasattr(
-            DecoratorTest.method_to_decorate, '__wrapped__'
-        ))
+        self.assertTrue(hasattr(DecoratorTest.method_to_decorate, "__wrapped__"))
     def test_decorator_return_with_trials_input_as_arg(self):
         # Test if decorator takes in trial-object and returns
         # list of spiketrainlists
         new_class = DecoratorTest()
-        list_of_spiketrainlists = new_class.method_to_decorate(
-            self.trial_object)
+        list_of_spiketrainlists = new_class.method_to_decorate(self.trial_object)
         self.assertEqual(len(list_of_spiketrainlists), self.n_trials)
         for spiketrainlist in list_of_spiketrainlists:
             self.assertIsInstance(spiketrainlist, SpikeTrainList)
@@ -106,7 +114,8 @@ def test_decorator_return_with_list_of_lists_input_as_arg(self):
         # and does not change input
         new_class = DecoratorTest()
         list_of_list_of_spiketrains = new_class.method_to_decorate(
-            self.list_of_list_of_spiketrains)
+            self.list_of_list_of_spiketrains
+        )
         self.assertEqual(len(list_of_list_of_spiketrains), self.n_trials)
         for list_of_spiketrains in list_of_list_of_spiketrains:
             self.assertIsInstance(list_of_spiketrains, list)
@@ -118,7 +127,8 @@ def test_decorator_return_with_trials_input_as_kwarg(self):
         # list of spiketrainlists
         new_class = DecoratorTest()
         list_of_spiketrainlists = new_class.method_to_decorate(
-            trials_obj=self.trial_object)
+            trials_obj=self.trial_object
+        )
         self.assertEqual(len(list_of_spiketrainlists), self.n_trials)
         for spiketrainlist in list_of_spiketrainlists:
             self.assertIsInstance(spiketrainlist, SpikeTrainList)
@@ -128,7 +138,8 @@ def test_decorator_return_with_list_of_lists_input_as_kwarg(self):
         # and does not change input
         new_class = DecoratorTest()
         list_of_list_of_spiketrains = new_class.method_to_decorate(
-            trials_obj=self.list_of_list_of_spiketrains)
+            trials_obj=self.list_of_list_of_spiketrains
+        )
         self.assertEqual(len(list_of_list_of_spiketrains), self.n_trials)
         for list_of_spiketrains in list_of_list_of_spiketrains:
             self.assertIsInstance(list_of_spiketrains, list)
@@ -136,5 +147,5 @@ def test_decorator_return_with_list_of_lists_input_as_kwarg(self):
             self.assertIsInstance(spiketrain, SpikeTrain)
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/elephant/test/test_waveform_features.py b/elephant/test/test_waveform_features.py
index a4a50ffbf..6c7e6c140 100644
--- a/elephant/test/test_waveform_features.py
+++ b/elephant/test/test_waveform_features.py
@@ -18,11 +18,46 @@ class WaveformWidthTestCase(unittest.TestCase):
     def setUp(self):
-        self.waveform = [29., 42., 41., 18., 24., 28., 34., 34., 9.,
-                         -31., -100., -145., -125., -88., -48., -18., 14., 36.,
-                         30., 33., -4., -25., -3., 30., 51., 47., 70.,
-                         76., 78., 57., 53., 49., 22., 15., 88., 109.,
-                         79., 68.]
+        self.waveform = [
+            29.0,
+            42.0,
+            41.0,
+            18.0,
+            24.0,
+            28.0,
+            34.0,
+            34.0,
+            9.0,
+            -31.0,
+            -100.0,
+            -145.0,
+            -125.0,
+            -88.0,
+            -48.0,
+            -18.0,
+            14.0,
+            36.0,
+            30.0,
+            33.0,
+            -4.0,
+            -25.0,
+            -3.0,
+            30.0,
+            51.0,
+            47.0,
+            70.0,
+            76.0,
+            78.0,
+            57.0,
+            53.0,
+            49.0,
+            22.0,
+            15.0,
+            88.0,
+            109.0,
+            79.0,
+            68.0,
+        ]
         self.target_width = 24
     def test_list(self):
@@ -42,8 +77,7 @@ def test_pq_quantity(self):
     def test_np_array_2d(self):
         waveform = np.asarray(self.waveform)
         waveform = np.vstack([waveform, waveform])
-        self.assertRaises(ValueError, waveform_features.waveform_width,
-                          waveform)
+        self.assertRaises(ValueError, waveform_features.waveform_width, waveform)
     def test_empty_list(self):
         self.assertRaises(ValueError, waveform_features.waveform_width, [])
@@ -53,18 +87,17 @@ def test_cutoff(self):
         waveform = np.arange(size, dtype=float)
         for cutoff in (-1, 1):  # outside of [0, 1) range
-            self.assertRaises(ValueError, waveform_features.waveform_width,
-                              waveform, cutoff=cutoff)
-        for cutoff in np.linspace(0., 1., num=size, endpoint=False):
+            self.assertRaises(
+                ValueError, waveform_features.waveform_width, waveform, cutoff=cutoff
+            )
+        for cutoff in np.linspace(0.0, 1.0, num=size, endpoint=False):
             width = waveform_features.waveform_width(waveform, cutoff=cutoff)
             self.assertEqual(width, size - 1)
 class WaveformSignalToNoiseRatioTestCase(unittest.TestCase):
     def test_zero_waveforms(self):
-        zero_waveforms = [np.zeros((5, 10)),
-                          np.zeros((5, 1, 10)),
-                          np.zeros((5, 3, 10))]
+        zero_waveforms = [np.zeros((5, 10)), np.zeros((5, 1, 10)), np.zeros((5, 3, 10))]
         for zero_wf in zero_waveforms:
             with self.assertWarns(UserWarning):
                 # expect np.nan result when waveform noise is zero.
@@ -77,8 +110,9 @@ def test_waveforms_arange_single_spiketrain(self):
         snr_float = waveform_features.waveform_snr(waveforms)
         self.assertIsInstance(snr_float, float)
         self.assertEqual(snr_float, target_snr)
-        self.assertEqual(waveform_features.waveform_snr(np.squeeze(waveforms)),
-                         target_snr)
+        self.assertEqual(
+            waveform_features.waveform_snr(np.squeeze(waveforms)), target_snr
+        )
     def test_waveforms_arange_multiple_spiketrains(self):
         target_snr = [0.3, 0.3, 0.3]
@@ -88,5 +122,5 @@ def test_waveforms_arange_multiple_spiketrains(self):
         assert_array_almost_equal(snr_arr, target_snr)
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/elephant/trials.py b/elephant/trials.py
index cd006addd..f62133ea9 100644
--- a/elephant/trials.py
+++ b/elephant/trials.py
@@ -122,8 +122,7 @@ def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment:
         pass
     @abstractmethod
-    def get_trials_as_block(self, trial_ids: List[int] = None
-                            ) -> neo.core.Block:
+    def get_trials_as_block(self, trial_ids: List[int] = None) -> neo.core.Block:
         """Get trials as block.
         Parameters
@@ -142,8 +141,9 @@ def get_trials_as_block(self, trial_ids: List[int] = None
         pass
     @abstractmethod
-    def get_trials_as_list(self, trial_ids: List[int] = None
-                           ) -> neo.core.spiketrainlist.SpikeTrainList:
+    def get_trials_as_list(
+        self, trial_ids: List[int] = None
+    ) -> neo.core.spiketrainlist.SpikeTrainList:
         """Get trials as list of segments.
         Parameters
@@ -162,8 +162,9 @@ def get_trials_as_list(self, trial_ids: List[int] = None
         pass
     @abstractmethod
-    def get_spiketrains_from_trial_as_list(self, trial_id: int) -> (
-            neo.core.spiketrainlist.SpikeTrainList):
+    def get_spiketrains_from_trial_as_list(
+        self, trial_id: int
+    ) -> neo.core.spiketrainlist.SpikeTrainList:
         """
         Get all spike trains from a specific trial and return a list.
@@ -180,8 +181,7 @@ def get_spiketrains_from_trial_as_list(self, trial_id: int) -> (
         pass
     @abstractmethod
-    def get_spiketrains_from_trial_as_segment(self, trial_id: int
-                                              ) -> neo.core.Segment:
+    def get_spiketrains_from_trial_as_segment(self, trial_id: int) -> neo.core.Segment:
         """
         Get all spike trains from a specific trial and return a Segment.
@@ -197,8 +197,9 @@ def get_spiketrains_from_trial_as_segment(self, trial_id: int
         pass
     @abstractmethod
-    def get_analogsignals_from_trial_as_list(self, trial_id: int
-                                             ) -> List[neo.core.AnalogSignal]:
+    def get_analogsignals_from_trial_as_list(
+        self, trial_id: int
+    ) -> List[neo.core.AnalogSignal]:
         """
         Get all analogsignals from a specific trial and return a list.
@@ -215,8 +216,9 @@ def get_analogsignals_from_trial_as_list(self, trial_id: int
         pass
     @abstractmethod
-    def get_analogsignals_from_trial_as_segment(self, trial_id: int
-                                                ) -> neo.core.Segment:
+    def get_analogsignals_from_trial_as_segment(
+        self, trial_id: int
+    ) -> neo.core.Segment:
         """
         Get all analogsignal objects from a specific trial and return a
         :class:`neo.Segment`.
@@ -261,8 +263,7 @@ def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment:
         # Get a specific trial by number as a segment
         return self.__getitem__(trial_id)
-    def get_trials_as_block(self, trial_ids: List[int] = None
-                            ) -> neo.core.Block:
+    def get_trials_as_block(self, trial_ids: List[int] = None) -> neo.core.Block:
         # Get a block of trials by trial numbers
         block = Block()
         if not trial_ids:
@@ -271,13 +272,11 @@ def get_trials_as_block(self, trial_ids: List[int] = None
             block.segments.append(self.get_trial_as_segment(trial_number))
         return block
-    def get_trials_as_list(self, trial_ids: List[int] = None
-                           ) -> List[neo.core.Segment]:
+    def get_trials_as_list(self, trial_ids: List[int] = None) -> List[neo.core.Segment]:
         if not trial_ids:
             trial_ids = list(range(self.n_trials))
         # Get a list of segments by trial numbers
-        return [self.get_trial_as_segment(trial_number)
-                for trial_number in trial_ids]
+        return [self.get_trial_as_segment(trial_number) for trial_number in trial_ids]
     @property
     def n_trials(self) -> int:
@@ -294,33 +293,37 @@ def n_analogsignals_trial_by_trial(self) -> List[int]:
         # Get the number of AnalogSignals instances in each trial.
         return [len(trial.analogsignals) for trial in self.block.segments]
-    def get_spiketrains_from_trial_as_list(self, trial_id: int = 0) -> (
-            neo.core.spiketrainlist.SpikeTrainList):
+    def get_spiketrains_from_trial_as_list(
+        self, trial_id: int = 0
+    ) -> neo.core.spiketrainlist.SpikeTrainList:
         # Return a list of all spike trains from a trial
-        return SpikeTrainList(items=[spiketrain for spiketrain in
-                              self.block.segments[trial_id].spiketrains])
+        return SpikeTrainList(
+            items=[
+                spiketrain for spiketrain in self.block.segments[trial_id].spiketrains
+            ]
+        )
-    def get_spiketrains_from_trial_as_segment(self, trial_id: int) -> (
-            neo.core.Segment):
+    def get_spiketrains_from_trial_as_segment(self, trial_id: int) -> neo.core.Segment:
         # Return a segment with all spiketrains from a trial
         segment = neo.core.Segment()
-        for spiketrain in self.get_spiketrains_from_trial_as_list(trial_id
-                                                                   ):
+        for spiketrain in self.get_spiketrains_from_trial_as_list(trial_id):
             segment.spiketrains.append(spiketrain)
         return segment
-    def get_analogsignals_from_trial_as_list(self, trial_id: int) -> (
-            List[neo.core.AnalogSignal]):
+    def get_analogsignals_from_trial_as_list(
+        self, trial_id: int
+    ) -> List[neo.core.AnalogSignal]:
         # Return a list of all analogsignals from a trial
-        return [analogsignal for analogsignal in
-                self.block.segments[trial_id].analogsignals]
+        return [
+            analogsignal for analogsignal in self.block.segments[trial_id].analogsignals
+        ]
-    def get_analogsignals_from_trial_as_segment(self, trial_id: int) -> (
-            neo.core.Segment):
+    def get_analogsignals_from_trial_as_segment(
+        self, trial_id: int
+    ) -> neo.core.Segment:
         # Return a segment with all analogsignals from a trial
         segment = neo.core.Segment()
-        for analogsignal in self.get_analogsignals_from_trial_as_list(
-                trial_id):
+        for analogsignal in self.get_analogsignals_from_trial_as_list(trial_id):
             segment.analogsignals.append(analogsignal)
         return segment
@@ -360,8 +363,7 @@ def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment:
         # Get a specific trial by number as a segment
         return self.__getitem__(trial_id)
-    def get_trials_as_block(self, trial_ids: List[int] = None
-                            ) -> neo.core.Block:
+    def get_trials_as_block(self, trial_ids: List[int] = None) -> neo.core.Block:
         if not trial_ids:
             trial_ids = list(range(self.n_trials))
         # Get a block of trials by trial numbers
@@ -370,13 +372,11 @@ def get_trials_as_block(self, trial_ids: List[int] = None
             block.segments.append(self.get_trial_as_segment(trial_number))
         return block
-    def get_trials_as_list(self, trial_ids: List[int] = None
-                           ) -> List[neo.core.Segment]:
+    def get_trials_as_list(self, trial_ids: List[int] = None) -> List[neo.core.Segment]:
         if not trial_ids:
             trial_ids = list(range(self.n_trials))
         # Get a list of segments by trial numbers
-        return [self.get_trial_as_segment(trial_number)
-                for trial_number in trial_ids]
+        return [self.get_trial_as_segment(trial_number) for trial_number in trial_ids]
     @property
     def n_trials(self) -> int:
@@ -386,42 +386,53 @@ def n_trials(self) -> int:
     @property
     def n_spiketrains_trial_by_trial(self) -> List[int]:
         # Get the number of spiketrains in each trial.
-        return [sum(map(lambda x: isinstance(x, neo.core.SpikeTrain), trial))
-                for trial in self.list_of_trials]
+        return [
+            sum(map(lambda x: isinstance(x, neo.core.SpikeTrain), trial))
+            for trial in self.list_of_trials
+        ]
     @property
     def n_analogsignals_trial_by_trial(self) -> List[int]:
         # Get the number of analogsignals in each trial.
-        return [sum(map(lambda x: isinstance(x, neo.core.AnalogSignal), trial))
-                for trial in self.list_of_trials]
-
-    def get_spiketrains_from_trial_as_list(self, trial_id: int = 0) -> (
-            neo.core.spiketrainlist.SpikeTrainList):
+        return [
+            sum(map(lambda x: isinstance(x, neo.core.AnalogSignal), trial))
+            for trial in self.list_of_trials
+        ]
+
+    def get_spiketrains_from_trial_as_list(
+        self, trial_id: int = 0
+    ) -> neo.core.spiketrainlist.SpikeTrainList:
         # Return a list of all spiketrains from a trial
-        return SpikeTrainList(items=[
-            spiketrain for spiketrain in self.list_of_trials[trial_id]
-            if isinstance(spiketrain, neo.core.SpikeTrain)])
-
-    def get_spiketrains_from_trial_as_segment(self, trial_id: int) -> (
-            neo.core.Segment):
+        return SpikeTrainList(
+            items=[
+                spiketrain
+                for spiketrain in self.list_of_trials[trial_id]
+                if isinstance(spiketrain, neo.core.SpikeTrain)
+            ]
+        )
+
+    def get_spiketrains_from_trial_as_segment(self, trial_id: int) -> neo.core.Segment:
         # Return a segment with all spiketrains from a trial
         segment = neo.core.Segment()
         for spiketrain in self.get_spiketrains_from_trial_as_list(trial_id):
             segment.spiketrains.append(spiketrain)
         return segment
-    def get_analogsignals_from_trial_as_list(self, trial_id: int) -> (
-            List[neo.core.AnalogSignal]):
+    def get_analogsignals_from_trial_as_list(
+        self, trial_id: int
+    ) -> List[neo.core.AnalogSignal]:
         # Return a list of all analogsignals from a trial
-        return [analogsignal for analogsignal in
-                self.list_of_trials[trial_id]
-                if isinstance(analogsignal, neo.core.AnalogSignal)]
-
-    def get_analogsignals_from_trial_as_segment(self, trial_id: int) -> (
-            neo.core.Segment):
+        return [
+            analogsignal
+            for analogsignal in self.list_of_trials[trial_id]
+            if isinstance(analogsignal, neo.core.AnalogSignal)
+        ]
+
+    def get_analogsignals_from_trial_as_segment(
+        self, trial_id: int
+    ) -> neo.core.Segment:
         # Return a segment with all analogsignals from a trial
         segment = neo.core.Segment()
-        for analogsignal in self.get_analogsignals_from_trial_as_list(
-                trial_id):
+        for analogsignal in self.get_analogsignals_from_trial_as_list(trial_id):
             segment.analogsignals.append(analogsignal)
         return segment
diff --git a/elephant/unitary_event_analysis.py b/elephant/unitary_event_analysis.py
index 6396905bc..ea425d480 100644
--- a/elephant/unitary_event_analysis.py
+++ b/elephant/unitary_event_analysis.py
@@ -71,7 +71,7 @@
"n_exp_mat_sum_trial", "gen_pval_anal", "jointJ", - "jointJ_window_analysis" + "jointJ_window_analysis", ] @@ -126,11 +126,11 @@ def hash_from_pattern(m, base=2): # check the entries of the matrix if not is_binary(m): - raise ValueError('Patterns should be binary: 0 or 1') + raise ValueError("Patterns should be binary: 0 or 1") # generate the representation # don't use numpy - it's upperbounded by int64 - powers = [base ** x for x in range(n_neurons)][::-1] + powers = [base**x for x in range(n_neurons)][::-1] # calculate the binary number by use of scalar product return np.dot(powers, m) @@ -184,10 +184,11 @@ def inverse_hash_from_pattern(h, N, base=2): # check if the hash values are not greater than the greatest possible # value for N neurons with the given base - powers = np.array([base ** x for x in range(N)])[::-1] + powers = np.array([base**x for x in range(N)])[::-1] if any(h > sum(powers)): - raise ValueError(f"hash value {h} is not compatible with the number " - f"of neurons {N}") + raise ValueError( + f"hash value {h} is not compatible with the number " f"of neurons {N}" + ) m = h // np.expand_dims(powers, axis=1) m %= base # m is a binary matrix now m = m.astype(int) # convert object to int if the hash was > int64 @@ -335,8 +336,9 @@ def _n_exp_mat_analytic(mat, pattern_hash): nrep = m.shape[1] # multipyling the marginal probability of neurons with regard to the # pattern - pmat = np.multiply(m, np.tile(marg_prob, (1, nrep))) \ - + np.multiply(1 - m, np.tile(1 - marg_prob, (1, nrep))) + pmat = np.multiply(m, np.tile(marg_prob, (1, nrep))) + np.multiply( + 1 - m, np.tile(1 - marg_prob, (1, nrep)) + ) return np.prod(pmat, axis=0) * float(mat.shape[1]) @@ -346,7 +348,7 @@ def _n_exp_mat_surrogate(mat, pattern_hash, n_surrogates=1): time randomization surrogate """ if len(pattern_hash) > 1: - raise ValueError('surrogate method works only for one pattern!') + raise ValueError("surrogate method works only for one pattern!") N_exp_array = np.zeros(n_surrogates) for rz_idx, rz in enumerate(np.arange(n_surrogates)): # row-wise shuffling all elements of zero-one matrix @@ -357,7 +359,7 @@ def _n_exp_mat_surrogate(mat, pattern_hash, n_surrogates=1): return N_exp_array -def n_exp_mat(mat, pattern_hash, method='analytic', n_surrogates=1): +def n_exp_mat(mat, pattern_hash, method="analytic", n_surrogates=1): """ Calculates the expected joint probability for each spike pattern. @@ -423,15 +425,15 @@ def n_exp_mat(mat, pattern_hash, method='analytic', n_surrogates=1): if not np.all((mat >= 0) & (mat <= 1)): raise ValueError("entries of mat should be in range [0, 1]") - if method == 'analytic': + if method == "analytic": return _n_exp_mat_analytic(mat, pattern_hash) - elif method == 'surr': - return _n_exp_mat_surrogate(mat, pattern_hash, - n_surrogates=n_surrogates) + elif method == "surr": + return _n_exp_mat_surrogate(mat, pattern_hash, n_surrogates=n_surrogates) -def n_exp_mat_sum_trial(mat, pattern_hash, method='analytic_TrialByTrial', - n_surrogates=1): +def n_exp_mat_sum_trial( + mat, pattern_hash, method="analytic_TrialByTrial", n_surrogates=1 +): """ Calculates the expected joint probability for each spike pattern sum over trials. 
@@ -489,28 +491,27 @@ def n_exp_mat_sum_trial(mat, pattern_hash, method='analytic_TrialByTrial',
     >>> print(n_exp_anal)
     [1.56 2.56]
     """
-    if method == 'analytic_TrialByTrial':
+    if method == "analytic_TrialByTrial":
         n_exp = np.zeros(len(pattern_hash))
         for mat_tr in mat:
-            n_exp += n_exp_mat(mat_tr, pattern_hash,
-                               method='analytic')
-    elif method == 'analytic_TrialAverage':
-        n_exp = n_exp_mat(
-            np.mean(mat, axis=0), pattern_hash,
-            method='analytic') * mat.shape[0]
-    elif method == 'surrogate_TrialByTrial':
+            n_exp += n_exp_mat(mat_tr, pattern_hash, method="analytic")
+    elif method == "analytic_TrialAverage":
+        n_exp = (
+            n_exp_mat(np.mean(mat, axis=0), pattern_hash, method="analytic")
+            * mat.shape[0]
+        )
+    elif method == "surrogate_TrialByTrial":
         n_exp = np.zeros(n_surrogates)
         for mat_tr in mat:
-            n_exp += n_exp_mat(mat_tr, pattern_hash,
-                               method='surr', n_surrogates=n_surrogates)
+            n_exp += n_exp_mat(
+                mat_tr, pattern_hash, method="surr", n_surrogates=n_surrogates
+            )
     else:
-        raise ValueError(
-            "The method only works on the zero_one matrix at the moment")
+        raise ValueError("The method only works on the zero_one matrix at the moment")
     return n_exp
-def gen_pval_anal(mat, pattern_hash, method='analytic_TrialByTrial',
-                  n_surrogates=1):
+def gen_pval_anal(mat, pattern_hash, method="analytic_TrialByTrial", n_surrogates=1):
     """
     Compute the expected coincidences and a function to calculate
     the p-value for the given empirical coincidences.
@@ -572,26 +573,28 @@ def gen_pval_anal(mat, pattern_hash, method='analytic_TrialByTrial',
     [1.56 2.56]
     """
-    if method == 'analytic_TrialByTrial' or method == 'analytic_TrialAverage':
+    if method == "analytic_TrialByTrial" or method == "analytic_TrialAverage":
         n_exp = n_exp_mat_sum_trial(mat, pattern_hash, method=method)
         def pval(n_emp):
-            p = 1. - scipy.special.gammaincc(n_emp, n_exp)
+            p = 1.0 - scipy.special.gammaincc(n_emp, n_exp)
             return p
-    elif method == 'surrogate_TrialByTrial':
+    elif method == "surrogate_TrialByTrial":
         n_exp = n_exp_mat_sum_trial(
-            mat, pattern_hash, method=method, n_surrogates=n_surrogates)
+            mat, pattern_hash, method=method, n_surrogates=n_surrogates
+        )
         def pval(n_emp):
-            hist = np.bincount(n_exp.astype(int, casting='unsafe'))
+            hist = np.bincount(n_exp.astype(int, casting="unsafe"))
             exp_dist = hist / float(np.sum(hist))
             if len(n_emp) > 1:
-                raise ValueError('In surrogate method the p_value can be'
-                                 'calculated only for one pattern!')
-            return np.sum(exp_dist[int(n_emp[0]):])
+                raise ValueError(
+                    "In surrogate method the p_value can be"
+                    "calculated only for one pattern!"
+                )
+            return np.sum(exp_dist[int(n_emp[0]) :])
     else:
-        raise ValueError("Method is not allowed: {method}".format(
-            method=method))
+        raise ValueError("Method is not allowed: {method}".format(method=method))
     return pval, n_exp
@@ -622,7 +625,7 @@ def jointJ(p_val):
     """
     p_arr = np.asarray(p_val)
-    with np.errstate(divide='ignore'):
+    with np.errstate(divide="ignore"):
         # Ignore 'Division by zero' warning which happens when p_arr is 1.0 or
         # 0.0 (no spikes).
         Js = np.log10(1 - p_arr) - np.log10(p_arr)
@@ -644,32 +647,32 @@ def _bintime(t, bin_size):
     """
     Change the real time to `bin_size` units.
""" - t_dl = t.rescale('ms').magnitude - bin_size_dl = bin_size.rescale('ms').item() + t_dl = t.rescale("ms").magnitude + bin_size_dl = bin_size.rescale("ms").item() return np.floor(np.array(t_dl) / bin_size_dl).astype(int) -def _winpos(t_start, t_stop, win_size, win_step, position='left-edge'): +def _winpos(t_start, t_stop, win_size, win_step, position="left-edge"): """ Calculate the position of the analysis window. """ - t_start_dl = t_start.rescale('ms').item() - t_stop_dl = t_stop.rescale('ms').item() - winsize_dl = win_size.rescale('ms').item() - winstep_dl = win_step.rescale('ms').item() + t_start_dl = t_start.rescale("ms").item() + t_stop_dl = t_stop.rescale("ms").item() + winsize_dl = win_size.rescale("ms").item() + winstep_dl = win_step.rescale("ms").item() # left side of the window time - if position == 'left-edge': - ts_winpos = np.arange( - t_start_dl, t_stop_dl - winsize_dl + winstep_dl, - winstep_dl) * pq.ms + if position == "left-edge": + ts_winpos = ( + np.arange(t_start_dl, t_stop_dl - winsize_dl + winstep_dl, winstep_dl) + * pq.ms + ) else: - raise ValueError( - 'the current version only returns left-edge of the window') + raise ValueError("the current version only returns left-edge of the window") return ts_winpos -def _UE(mat, pattern_hash, method='analytic_TrialByTrial', n_surrogates=1): +def _UE(mat, pattern_hash, method="analytic_TrialByTrial", n_surrogates=1): """ Return the default results of unitary events analysis (Surprise, empirical coincidences and index of where it happened @@ -677,23 +680,30 @@ def _UE(mat, pattern_hash, method='analytic_TrialByTrial', n_surrogates=1): """ rate_avg = _rate_mat_avg_trial(mat) n_emp, indices = n_emp_mat_sum_trial(mat, pattern_hash) - if method == 'surrogate_TrialByTrial': + if method == "surrogate_TrialByTrial": dist_exp, n_exp = gen_pval_anal( - mat, pattern_hash, method, n_surrogates=n_surrogates) + mat, pattern_hash, method, n_surrogates=n_surrogates + ) n_exp = np.mean(n_exp) - elif method == 'analytic_TrialByTrial' or \ - method == 'analytic_TrialAverage': + elif method == "analytic_TrialByTrial" or method == "analytic_TrialAverage": dist_exp, n_exp = gen_pval_anal(mat, pattern_hash, method) pval = dist_exp(n_emp) Js = jointJ(pval) return Js, rate_avg, n_exp, n_emp, indices -def jointJ_window_analysis(spiketrains, bin_size=5 * pq.ms, - win_size=100 * pq.ms, win_step=5 * pq.ms, - pattern_hash=None, method='analytic_TrialByTrial', - t_start=None, t_stop=None, binary=True, - n_surrogates=100): +def jointJ_window_analysis( + spiketrains, + bin_size=5 * pq.ms, + win_size=100 * pq.ms, + win_step=5 * pq.ms, + pattern_hash=None, + method="analytic_TrialByTrial", + t_start=None, + t_stop=None, + binary=True, + n_surrogates=100, +): """ Calculates the joint surprise in a sliding window fashion. 
@@ -785,7 +795,8 @@ def jointJ_window_analysis(spiketrains, bin_size=5 * pq.ms,
     if not isinstance(spiketrains[0][0], neo.SpikeTrain):
         raise ValueError(
             "structure of the data is not correct: 0-axis should be trials, "
-            "1-axis units and 2-axis neo spike trains")
+            "1-axis units and 2-axis neo spike trains"
+        )
     if t_start is None:
         t_start = spiketrains[0][0].t_start
@@ -801,55 +812,69 @@ def jointJ_window_analysis(spiketrains, bin_size=5 * pq.ms,
         pattern_hash = [int(pattern_hash)]
     # position of all windows (left edges)
-    t_winpos = _winpos(t_start, t_stop, win_size, win_step,
-                       position='left-edge')
+    t_winpos = _winpos(t_start, t_stop, win_size, win_step, position="left-edge")
     t_winpos_bintime = _bintime(t_winpos, bin_size)
     winsize_bintime = _bintime(win_size, bin_size)
     winstep_bintime = _bintime(win_step, bin_size)
     if winsize_bintime * bin_size != win_size:
-        warnings.warn(f"The ratio between the win_size ({win_size}) and the "
-                      f"bin_size ({bin_size}) is not an integer")
+        warnings.warn(
+            f"The ratio between the win_size ({win_size}) and the "
+            f"bin_size ({bin_size}) is not an integer"
+        )
     if winstep_bintime * bin_size != win_step:
-        warnings.warn(f"The ratio between the win_step ({win_step}) and the "
-                      f"bin_size ({bin_size}) is not an integer")
-
-    input_parameters = dict(pattern_hash=pattern_hash, bin_size=bin_size,
-                            win_size=win_size, win_step=win_step,
-                            method=method, t_start=t_start, t_stop=t_stop,
-                            n_surrogates=n_surrogates)
+        warnings.warn(
+            f"The ratio between the win_step ({win_step}) and the "
+            f"bin_size ({bin_size}) is not an integer"
+        )
+
+    input_parameters = dict(
+        pattern_hash=pattern_hash,
+        bin_size=bin_size,
+        win_size=win_size,
+        win_step=win_step,
+        method=method,
+        t_start=t_start,
+        t_stop=t_stop,
+        n_surrogates=n_surrogates,
+    )
     n_bins = int(((t_stop - t_start) / bin_size).simplified.item())
-    mat_tr_unit_spt = np.zeros((len(spiketrains), n_neurons, n_bins),
-                               dtype=np.int32)
+    mat_tr_unit_spt = np.zeros((len(spiketrains), n_neurons, n_bins), dtype=np.int32)
     for trial, sts in enumerate(spiketrains):
-        bs = conv.BinnedSpikeTrain(list(sts), t_start=t_start, t_stop=t_stop,
-                                   bin_size=bin_size)
+        bs = conv.BinnedSpikeTrain(
+            list(sts), t_start=t_start, t_stop=t_stop, bin_size=bin_size
+        )
         if not binary:
             raise NotImplementedError(
-                "The method works only with binary matrices at the moment")
+                "The method works only with binary matrices at the moment"
+            )
         mat_tr_unit_spt[trial] = bs.to_bool_array()
     n_windows = len(t_winpos)
     n_hashes = len(pattern_hash)
-    Js_win, n_exp_win, n_emp_win = np.zeros((3, n_windows, n_hashes),
-                                            dtype=np.float32)
+    Js_win, n_exp_win, n_emp_win = np.zeros((3, n_windows, n_hashes), dtype=np.float32)
     rate_avg = np.zeros((n_windows, n_hashes, n_neurons), dtype=np.float32)
     indices_win = defaultdict(list)
     for i, win_pos in enumerate(t_winpos_bintime):
-        mat_win = mat_tr_unit_spt[:, :, win_pos:win_pos + winsize_bintime]
-        Js_win[i], rate_avg[i], n_exp_win[i], n_emp_win[
-            i], indices_lst = _UE(mat_win, pattern_hash=pattern_hash,
-                                  method=method, n_surrogates=n_surrogates)
+        mat_win = mat_tr_unit_spt[:, :, win_pos : win_pos + winsize_bintime]
+        Js_win[i], rate_avg[i], n_exp_win[i], n_emp_win[i], indices_lst = _UE(
+            mat_win, pattern_hash=pattern_hash, method=method, n_surrogates=n_surrogates
+        )
         for j in range(n_trials):
             if len(indices_lst[j][0]) > 0:
                 indices_win[f"trial{j}"].append(indices_lst[j][0] + win_pos)
     for key in indices_win.keys():
         indices_win[key] = np.hstack(indices_win[key])
-    return {'Js': Js_win, 'indices': indices_win, 'n_emp': n_emp_win,
-            'n_exp': n_exp_win, 'rate_avg': rate_avg / bin_size,
-            'input_parameters': input_parameters}
+    return {
+        "Js": Js_win,
+        "indices": indices_win,
+        "n_emp": n_emp_win,
+        "n_exp": n_exp_win,
+        "rate_avg": rate_avg / bin_size,
+        "input_parameters": input_parameters,
+    }
diff --git a/elephant/utils.py b/elephant/utils.py
index b4ddfee22..236130a57 100644
--- a/elephant/utils.py
+++ b/elephant/utils.py
@@ -104,12 +104,8 @@ def _rename_kwargs(func_name, kwargs, aliases):
     for old, new in aliases.items():
         if old in kwargs:
             if new in kwargs:
-                raise TypeError(
-                    f"{func_name} received both '{old}' and " f"'{new}'"
-                )
-            warnings.warn(
-                f"'{old}' is deprecated; use '{new}'", DeprecationWarning
-            )
+                raise TypeError(f"{func_name} received both '{old}' and " f"'{new}'")
+            warnings.warn(f"'{old}' is deprecated; use '{new}'", DeprecationWarning)
             kwargs[new] = kwargs.pop(old)
@@ -177,9 +173,7 @@ def get_common_start_stop_times(neo_objects):
             "Input neo objects must have 't_start' and " "'t_stop' attributes"
         )
     if t_stop < t_start:
-        raise ValueError(
-            f"t_stop ({t_stop}) is smaller than t_start " f"({t_start})"
-        )
+        raise ValueError(f"t_stop ({t_stop}) is smaller than t_start " f"({t_start})")
     return t_start, t_stop
@@ -423,10 +417,7 @@ def trials_to_list_of_spiketrainlist(method):
     @wraps(method)
     def wrapper(*args, **kwargs):
         new_args = tuple(
-            [
-                arg.get_spiketrains_from_trial_as_list(idx)
-                for idx in range(arg.n_trials)
-            ]
+            [arg.get_spiketrains_from_trial_as_list(idx) for idx in range(arg.n_trials)]
             if isinstance(arg, Trials)
             else arg
             for arg in args
diff --git a/elephant/waveform_features.py b/elephant/waveform_features.py
index e36a52f71..507338ba5 100644
--- a/elephant/waveform_features.py
+++ b/elephant/waveform_features.py
@@ -17,10 +17,7 @@
 import neo
 import numpy as np
-__all__ = [
-    "waveform_width",
-    "waveform_snr"
-]
+__all__ = ["waveform_width", "waveform_snr"]
 def waveform_width(waveform, cutoff=0.75):
@@ -63,11 +60,11 @@ def waveform_width(waveform, cutoff=0.75):
     """
     waveform = np.squeeze(waveform)
     if np.ndim(waveform) != 1:
-        raise ValueError('Expected 1-dimensional waveform.')
+        raise ValueError("Expected 1-dimensional waveform.")
     if len(waveform) < 2:
-        raise ValueError('Too short waveform.')
+        raise ValueError("Too short waveform.")
     if not (0 <= cutoff < 1):
-        raise ValueError('Cuttoff must be in range [0, 1).')
+        raise ValueError("Cuttoff must be in range [0, 1).")
     min_border = max(1, int(len(waveform) * cutoff))
     idx_min = np.argmin(waveform[:min_border])
@@ -120,8 +117,10 @@ def waveform_snr(waveforms):
     """
     if isinstance(waveforms, neo.SpikeTrain):
-        warnings.warn("spiketrain input is deprecated; pass "
-                      "'spiketrain.waveforms' directly.", DeprecationWarning)
+        warnings.warn(
+            "spiketrain input is deprecated; pass " "'spiketrain.waveforms' directly.",
+            DeprecationWarning,
+        )
         waveforms = waveforms.waveforms
     # asarray removes quantities, if present
     waveforms = np.squeeze(np.asarray(waveforms))
@@ -137,6 +136,6 @@ def waveform_snr(waveforms):
     snr = peak_range / noise
     if np.isnan(snr).any():
-        warnings.warn('The waveforms noise was evaluated to 0. Returning NaN')
+        warnings.warn("The waveforms noise was evaluated to 0. Returning NaN")
     return snr