Skip to content

Commit

Permalink
Merge pull request #1613 from h-mayorquin/fix_overflow_2.0
Browse files Browse the repository at this point in the history
Fix overflow of Plexon in numpy 2.0
  • Loading branch information
zm711 authored Dec 19, 2024
2 parents 49534ce + 6f56bd0 commit b6a721c
Showing 1 changed file with 24 additions and 15 deletions.
39 changes: 24 additions & 15 deletions neo/rawio/plexonrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ def _parse_header(self):
for index, pos in enumerate(positions):
bl_header = data[pos : pos + 16].view(DataBlockHeader)[0]

# To avoid overflow errors when doing arithmetic operations on numpy scalars
np_scalar_to_python_scalar = lambda x: x.item() if isinstance(x, np.generic) else x
bl_header = {key: np_scalar_to_python_scalar(bl_header[key]) for key in bl_header.dtype.names}

current_upper_byte_of_5_byte_timestamp = int(bl_header["UpperByteOf5ByteTimestamp"])
current_bl_timestamp = int(bl_header["TimeStamp"])
timestamp = current_upper_byte_of_5_byte_timestamp * 2**32 + current_bl_timestamp
Expand Down Expand Up @@ -255,24 +259,29 @@ def _parse_header(self):
else:
chan_loop = range(nb_sig_chan)
for chan_index in chan_loop:
h = slowChannelHeaders[chan_index]
name = h["Name"].decode("utf8")
chan_id = h["Channel"]
slow_channel_headers = slowChannelHeaders[chan_index]

# To avoid overflow errors when doing arithmetic operations on numpy scalars
np_scalar_to_python_scalar = lambda x: x.item() if isinstance(x, np.generic) else x
slow_channel_headers = {key: np_scalar_to_python_scalar(slow_channel_headers[key]) for key in slow_channel_headers.dtype.names}

name = slow_channel_headers["Name"].decode("utf8")
chan_id = slow_channel_headers["Channel"]
length = self._data_blocks[5][chan_id]["size"].sum() // 2
if length == 0:
continue # channel not added
source_id.append(h["SrcId"])
source_id.append(slow_channel_headers["SrcId"])
channel_num_samples.append(length)
sampling_rate = float(h["ADFreq"])
sampling_rate = float(slow_channel_headers["ADFreq"])
sig_dtype = "int16"
units = "" # I don't know units
if global_header["Version"] in [100, 101]:
gain = 5000.0 / (2048 * h["Gain"] * 1000.0)
gain = 5000.0 / (2048 * slow_channel_headers["Gain"] * 1000.0)
elif global_header["Version"] in [102]:
gain = 5000.0 / (2048 * h["Gain"] * h["PreampGain"])
gain = 5000.0 / (2048 * slow_channel_headers["Gain"] * slow_channel_headers["PreampGain"])
elif global_header["Version"] >= 103:
gain = global_header["SlowMaxMagnitudeMV"] / (
0.5 * (2 ** global_header["BitsPerSpikeSample"]) * h["Gain"] * h["PreampGain"]
0.5 * (2 ** global_header["BitsPerSpikeSample"]) * slow_channel_headers["Gain"] * slow_channel_headers["PreampGain"]
)
offset = 0.0

Expand Down Expand Up @@ -358,21 +367,21 @@ def _parse_header(self):
unit_loop = enumerate(self.internal_unit_ids)

for unit_index, (chan_id, unit_id) in unit_loop:
c = np.nonzero(dspChannelHeaders["Channel"] == chan_id)[0][0]
h = dspChannelHeaders[c]
channel_index = np.nonzero(dspChannelHeaders["Channel"] == chan_id)[0][0]
dsp_channel_headers = dspChannelHeaders[channel_index]

name = h["Name"].decode("utf8")
name = dsp_channel_headers["Name"].decode("utf8")
_id = f"ch{chan_id}#{unit_id}"
wf_units = ""
if global_header["Version"] < 103:
wf_gain = 3000.0 / (2048 * h["Gain"] * 1000.0)
wf_gain = 3000.0 / (2048 * dsp_channel_headers["Gain"] * 1000.0)
elif 103 <= global_header["Version"] < 105:
wf_gain = global_header["SpikeMaxMagnitudeMV"] / (
0.5 * 2.0 ** (global_header["BitsPerSpikeSample"]) * h["Gain"] * 1000.0
0.5 * 2.0 ** (global_header["BitsPerSpikeSample"]) * dsp_channel_headers["Gain"] * 1000.0
)
elif global_header["Version"] >= 105:
wf_gain = global_header["SpikeMaxMagnitudeMV"] / (
0.5 * 2.0 ** (global_header["BitsPerSpikeSample"]) * h["Gain"] * global_header["SpikePreAmpGain"]
0.5 * 2.0 ** (global_header["BitsPerSpikeSample"]) * dsp_channel_headers["Gain"] * global_header["SpikePreAmpGain"]
)
wf_offset = 0.0
wf_left_sweep = -1 # DONT KNOWN
Expand Down Expand Up @@ -576,7 +585,7 @@ def read_as_dict(fid, dtype, offset=None):
v = v.replace("\x03", "")
v = v.replace("\x00", "")

info[k] = v
info[k] = v.item() if isinstance(v, np.generic) else v
return info


Expand Down

0 comments on commit b6a721c

Please sign in to comment.