Black
khl02007 committed Jul 9, 2024
1 parent 7a55d6b commit e1632f6
Showing 2 changed files with 139 additions and 38 deletions.
175 changes: 138 additions & 37 deletions sortingview/SpikeSortingView/prepare_spikesortingview_data.py
@@ -36,27 +36,48 @@ def prepare_spikesortingview_data(
     num_frames = recording.get_num_frames()
     num_frames_per_segment = math.ceil(segment_duration_sec * sampling_frequency)
     num_segments = math.ceil(num_frames / num_frames_per_segment)
-    if hasattr(recording, 'has_scaleable_traces') and callable(getattr(recording, 'has_scaleable_traces')):
+    if hasattr(recording, "has_scaleable_traces") and callable(
+        getattr(recording, "has_scaleable_traces")
+    ):
         scalable = recording.has_scaleable_traces()
-    elif hasattr(recording, 'has_scaled') and callable(getattr(recording, 'has_scaled')):
+    elif hasattr(recording, "has_scaled") and callable(
+        getattr(recording, "has_scaled")
+    ):
         scalable = recording.has_scaled()
     else:
         scalable = False

     with kcl.TemporaryDirectory() as tmpdir:
         output_file_name = tmpdir + "/spikesortingview.h5"
         with h5py.File(output_file_name, "w") as f:
             f.create_dataset("unit_ids", data=unit_ids)
-            f.create_dataset("sampling_frequency", data=np.array([sampling_frequency]).astype(np.float32))
+            f.create_dataset(
+                "sampling_frequency",
+                data=np.array([sampling_frequency]).astype(np.float32),
+            )
             f.create_dataset("channel_ids", data=channel_ids)
             f.create_dataset("num_frames", data=np.array([num_frames]).astype(int_type))
             channel_locations = recording.get_channel_locations()
             f.create_dataset("channel_locations", data=np.array(channel_locations))
-            f.create_dataset("num_segments", data=np.array([num_segments]).astype(np.int32))
-            f.create_dataset("num_frames_per_segment", data=np.array([num_frames_per_segment]).astype(np.int32))
-            f.create_dataset("snippet_len", data=np.array([snippet_len[0], snippet_len[1]]).astype(np.int32))
-            f.create_dataset("max_num_snippets_per_segment", data=np.array([max_num_snippets_per_segment]).astype(np.int32))
-            f.create_dataset("channel_neighborhood_size", data=np.array([channel_neighborhood_size]).astype(np.int32))
+            f.create_dataset(
+                "num_segments", data=np.array([num_segments]).astype(np.int32)
+            )
+            f.create_dataset(
+                "num_frames_per_segment",
+                data=np.array([num_frames_per_segment]).astype(np.int32),
+            )
+            f.create_dataset(
+                "snippet_len",
+                data=np.array([snippet_len[0], snippet_len[1]]).astype(np.int32),
+            )
+            f.create_dataset(
+                "max_num_snippets_per_segment",
+                data=np.array([max_num_snippets_per_segment]).astype(np.int32),
+            )
+            f.create_dataset(
+                "channel_neighborhood_size",
+                data=np.array([channel_neighborhood_size]).astype(np.int32),
+            )

             # first get peak channels and channel neighborhoods
             unit_peak_channel_ids = {}
@@ -74,83 +95,161 @@ def prepare_spikesortingview_data(
                 end_frame = min(start_frame + num_frames_per_segment, num_frames)
                 start_frame_with_padding = max(start_frame - snippet_len[0], 0)
                 end_frame_with_padding = min(end_frame + snippet_len[1], num_frames)
-                traces_with_padding = recording.get_traces(start_frame=start_frame_with_padding, end_frame=end_frame_with_padding, return_scaled=scalable)
+                traces_with_padding = recording.get_traces(
+                    start_frame=start_frame_with_padding,
+                    end_frame=end_frame_with_padding,
+                    return_scaled=scalable,
+                )
                 assert isinstance(traces_with_padding, np.ndarray)
                 for unit_id in unit_ids:
                     if str(unit_id) not in unit_peak_channel_ids:
-                        spike_train = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=start_frame, end_frame=end_frame)
+                        spike_train = sorting.get_unit_spike_train(
+                            unit_id=unit_id,
+                            start_frame=start_frame,
+                            end_frame=end_frame,
+                        )
                         assert isinstance(spike_train, np.ndarray)
                         if len(spike_train) > 0:
-                            values = traces_with_padding[spike_train - start_frame_with_padding, :].astype(np.int32)
+                            values = traces_with_padding[
+                                spike_train - start_frame_with_padding, :
+                            ].astype(np.int32)
                             avg_value = np.mean(values, axis=0)
                             peak_channel_ind = np.argmax(np.abs(avg_value))
                             peak_channel_id = channel_ids[peak_channel_ind]
                             channel_neighborhood = get_channel_neighborhood(
-                                channel_ids=channel_ids, channel_locations=channel_locations, peak_channel_id=peak_channel_id, channel_neighborhood_size=channel_neighborhood_size
+                                channel_ids=channel_ids,
+                                channel_locations=channel_locations,
+                                peak_channel_id=peak_channel_id,
+                                channel_neighborhood_size=channel_neighborhood_size,
                             )
                             if len(spike_train) >= 10:
                                 unit_peak_channel_ids[str(unit_id)] = peak_channel_id
                             else:
-                                fallback_unit_peak_channel_ids[str(unit_id)] = peak_channel_id
-                            unit_channel_neighborhoods[str(unit_id)] = channel_neighborhood
+                                fallback_unit_peak_channel_ids[
+                                    str(unit_id)
+                                ] = peak_channel_id
+                            unit_channel_neighborhoods[
+                                str(unit_id)
+                            ] = channel_neighborhood
             for unit_id in unit_ids:
                 peak_channel_id = unit_peak_channel_ids.get(str(unit_id), None)
                 if peak_channel_id is None:
-                    peak_channel_id = fallback_unit_peak_channel_ids.get(str(unit_id), None)
+                    peak_channel_id = fallback_unit_peak_channel_ids.get(
+                        str(unit_id), None
+                    )
                 if peak_channel_id is None:
-                    raise Exception(f"Peak channel not found for unit {unit_id}. This is probably because no spikes were found in any segment for this unit.")
+                    raise Exception(
+                        f"Peak channel not found for unit {unit_id}. This is probably because no spikes were found in any segment for this unit."
+                    )
                 channel_neighborhood = unit_channel_neighborhoods[str(unit_id)]
-                f.create_dataset(f"unit/{unit_id}/peak_channel_id", data=np.array([peak_channel_id]).astype(np.int32))
-                f.create_dataset(f"unit/{unit_id}/channel_neighborhood", data=np.array(channel_neighborhood).astype(np.int32))
+                f.create_dataset(
+                    f"unit/{unit_id}/peak_channel_id",
+                    data=np.array([peak_channel_id]).astype(np.int32),
+                )
+                f.create_dataset(
+                    f"unit/{unit_id}/channel_neighborhood",
+                    data=np.array(channel_neighborhood).astype(np.int32),
+                )

             for iseg in range(num_segments):
                 print(f"Segment {iseg} of {num_segments}")
                 start_frame = iseg * num_frames_per_segment
                 end_frame = min(start_frame + num_frames_per_segment, num_frames)
                 start_frame_with_padding = max(start_frame - snippet_len[0], 0)
                 end_frame_with_padding = min(end_frame + snippet_len[1], num_frames)
-                traces_with_padding = recording.get_traces(start_frame=start_frame_with_padding, end_frame=end_frame_with_padding, return_scaled=scalable)
-                traces_sample = traces_with_padding[start_frame - start_frame_with_padding : start_frame - start_frame_with_padding + int(sampling_frequency * 1), :]
+                traces_with_padding = recording.get_traces(
+                    start_frame=start_frame_with_padding,
+                    end_frame=end_frame_with_padding,
+                    return_scaled=scalable,
+                )
+                traces_sample = traces_with_padding[
+                    start_frame
+                    - start_frame_with_padding : start_frame
+                    - start_frame_with_padding
+                    + int(sampling_frequency * 1),
+                    :,
+                ]
                 f.create_dataset(f"segment/{iseg}/traces_sample", data=traces_sample)
                 all_subsampled_spike_trains = []
                 for unit_id in unit_ids:
                     peak_channel_id = unit_peak_channel_ids.get(str(unit_id), None)
                     if peak_channel_id is None:
-                        peak_channel_id = fallback_unit_peak_channel_ids.get(str(unit_id), None)
+                        peak_channel_id = fallback_unit_peak_channel_ids.get(
+                            str(unit_id), None
+                        )
                     if peak_channel_id is None:
-                        raise Exception(f"Peak channel not found for unit {unit_id}. This is probably because no spikes were found in any segment for this unit.")
-                    spike_train = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=start_frame, end_frame=end_frame).astype(int_type)
-                    f.create_dataset(f"segment/{iseg}/unit/{unit_id}/spike_train", data=spike_train)
+                        raise Exception(
+                            f"Peak channel not found for unit {unit_id}. This is probably because no spikes were found in any segment for this unit."
+                        )
+                    spike_train = sorting.get_unit_spike_train(
+                        unit_id=unit_id, start_frame=start_frame, end_frame=end_frame
+                    ).astype(int_type)
+                    f.create_dataset(
+                        f"segment/{iseg}/unit/{unit_id}/spike_train", data=spike_train
+                    )
                     channel_neighborhood = unit_channel_neighborhoods[str(unit_id)]
                     peak_channel_ind = channel_ids.tolist().index(peak_channel_id)
                     if len(spike_train) > 0:
-                        spike_amplitudes = traces_with_padding[spike_train - start_frame_with_padding, peak_channel_ind]
-                        f.create_dataset(f"segment/{iseg}/unit/{unit_id}/spike_amplitudes", data=spike_amplitudes)
+                        spike_amplitudes = traces_with_padding[
+                            spike_train - start_frame_with_padding, peak_channel_ind
+                        ]
+                        f.create_dataset(
+                            f"segment/{iseg}/unit/{unit_id}/spike_amplitudes",
+                            data=spike_amplitudes,
+                        )
                     else:
                         spike_amplitudes = np.array([], dtype=np.int32)
-                    if max_num_snippets_per_segment is not None and len(spike_train) > max_num_snippets_per_segment:
-                        subsampled_spike_train = subsample(spike_train, max_num_snippets_per_segment)
+                    if (
+                        max_num_snippets_per_segment is not None
+                        and len(spike_train) > max_num_snippets_per_segment
+                    ):
+                        subsampled_spike_train = subsample(
+                            spike_train, max_num_snippets_per_segment
+                        )
                     else:
                         subsampled_spike_train = spike_train
-                    f.create_dataset(f"segment/{iseg}/unit/{unit_id}/subsampled_spike_train", data=subsampled_spike_train)
+                    f.create_dataset(
+                        f"segment/{iseg}/unit/{unit_id}/subsampled_spike_train",
+                        data=subsampled_spike_train,
+                    )
                     all_subsampled_spike_trains.append(subsampled_spike_train)
-                subsampled_spike_trains_concat = np.concatenate(all_subsampled_spike_trains)
+                subsampled_spike_trains_concat = np.concatenate(
+                    all_subsampled_spike_trains
+                )
                 # print('Extracting spike snippets')
-                spike_snippets_concat = extract_spike_snippets(traces=traces_with_padding, times=subsampled_spike_trains_concat - start_frame_with_padding, snippet_len=snippet_len)
+                spike_snippets_concat = extract_spike_snippets(
+                    traces=traces_with_padding,
+                    times=subsampled_spike_trains_concat - start_frame_with_padding,
+                    snippet_len=snippet_len,
+                )
                 # print('Collecting spike snippets')
                 index = 0
                 for ii, unit_id in enumerate(unit_ids):
                     channel_neighborhood = unit_channel_neighborhoods[str(unit_id)]
-                    channel_neighborhood_indices = [channel_ids.tolist().index(ch_id) for ch_id in channel_neighborhood]
+                    channel_neighborhood_indices = [
+                        channel_ids.tolist().index(ch_id)
+                        for ch_id in channel_neighborhood
+                    ]
                     num = len(all_subsampled_spike_trains[ii])
-                    spike_snippets = spike_snippets_concat[index : index + num, :, channel_neighborhood_indices]
+                    spike_snippets = spike_snippets_concat[
+                        index : index + num, :, channel_neighborhood_indices
+                    ]
                     index = index + num
-                    f.create_dataset(f"segment/{iseg}/unit/{unit_id}/subsampled_spike_snippets", data=spike_snippets)
+                    f.create_dataset(
+                        f"segment/{iseg}/unit/{unit_id}/subsampled_spike_snippets",
+                        data=spike_snippets,
+                    )
         uri = kcl.store_file_local(output_file_name)
         return uri


-def get_channel_neighborhood(*, channel_ids: np.ndarray, channel_locations: np.ndarray, peak_channel_id: int, channel_neighborhood_size: int):
+def get_channel_neighborhood(
+    *,
+    channel_ids: np.ndarray,
+    channel_locations: np.ndarray,
+    peak_channel_id: int,
+    channel_neighborhood_size: int,
+):
     channel_locations_by_id = {}
     for ii, channel_id in enumerate(channel_ids):
         channel_locations_by_id[channel_id] = channel_locations[ii]
@@ -174,7 +273,9 @@ def subsample(x: np.ndarray, num: int):
     return x[0 : stride * num : stride]


-def extract_spike_snippets(*, traces: np.ndarray, times: np.ndarray, snippet_len: Tuple[int, int]):
+def extract_spike_snippets(
+    *, traces: np.ndarray, times: np.ndarray, snippet_len: Tuple[int, int]
+):
     a = snippet_len[0]
     b = snippet_len[1]
     T = a + b
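
For context, a minimal usage sketch of the function reformatted above. This is an illustration only, not part of the commit: the import path and keyword names are read off the diff, the synthetic recording/sorting pair comes from spikeinterface's toy_example helper, and the parameter values are arbitrary.

# Hedged usage sketch (assumptions: spikeinterface provides toy_example, and
# prepare_spikesortingview_data accepts exactly these keyword arguments).
import spikeinterface.extractors as se
from sortingview.SpikeSortingView import prepare_spikesortingview_data

# Small synthetic recording/sorting pair, for illustration only
recording, sorting = se.toy_example(num_units=6, duration=60, num_segments=1, seed=0)

uri = prepare_spikesortingview_data(
    recording=recording,               # spikeinterface recording extractor
    sorting=sorting,                   # spikeinterface sorting extractor
    segment_duration_sec=60,           # traces are processed one segment at a time
    snippet_len=(20, 20),              # samples kept before/after each spike time
    max_num_snippets_per_segment=100,  # per-unit cap; larger spike trains are subsampled
    channel_neighborhood_size=7,       # channels retained around each unit's peak channel
)
print(uri)  # kachery URI of the prepared spikesortingview.h5 file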
2 changes: 1 addition & 1 deletion sortingview/version.py
@@ -1,2 +1,2 @@
 # This file was automatically generated by jinjaroot. Do not edit directly.
-__version__ = '0.13.4'
+__version__ = "0.13.4"
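
The HDF5 layout written by prepare_spikesortingview_data can be inspected directly. The sketch below uses only the dataset names visible in the diff; the local path to the generated spikesortingview.h5 file is assumed (in practice the returned kachery URI must first be resolved to a local file).

# Hedged sketch: reading back a few of the datasets written above.
# "spikesortingview.h5" is a placeholder path; dataset names follow the
# f.create_dataset calls shown in the diff.
import h5py

with h5py.File("spikesortingview.h5", "r") as f:
    unit_ids = f["unit_ids"][:]
    sampling_frequency = float(f["sampling_frequency"][0])
    num_segments = int(f["num_segments"][0])
    unit0 = unit_ids[0]
    peak_channel = int(f[f"unit/{unit0}/peak_channel_id"][0])
    spike_train = f[f"segment/0/unit/{unit0}/spike_train"][:]
    snippets = f[f"segment/0/unit/{unit0}/subsampled_spike_snippets"][:]
    print(unit0, peak_channel, spike_train.shape, snippets.shape)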
