Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix copy-view issue in epochs #12121

Merged
merged 16 commits into from
Nov 9, 2023
Merged
15 changes: 13 additions & 2 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,7 @@ def _get_data(
units=None,
tmin=None,
tmax=None,
copy=False,
on_empty="warn",
verbose=None,
):
Expand Down Expand Up @@ -1648,6 +1649,8 @@ def _get_data(
if self.preload:
# we will store our result in our existing array
data = self._data
if copy:
data = data.copy()
else:
# we start out with an empty array, allocate only if necessary
data = np.empty((0, len(self.info["ch_names"]), len(self.times)))
Expand Down Expand Up @@ -1796,7 +1799,9 @@ def _detrend_picks(self):
return []

@fill_doc
def get_data(self, picks=None, item=None, units=None, tmin=None, tmax=None):
def get_data(
self, picks=None, item=None, units=None, tmin=None, tmax=None, copy=True
):
"""Get all epochs as a 3D array.

Parameters
Expand All @@ -1821,13 +1826,19 @@ def get_data(self, picks=None, item=None, units=None, tmin=None, tmax=None):
End time of data to get in seconds.

.. versionadded:: 0.24.0
copy : bool | None
Whether to return a copy of the object's data, or (if possible) a view.
See :std:label:`basics.copies-and-views <numpy:basics.copies-and-views>`
larsoner marked this conversation as resolved.
Show resolved Hide resolved
for an explanation. Default is ``True``.
larsoner marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
data : array of shape (n_epochs, n_channels, n_times)
A view on epochs data.
"""
return self._get_data(picks=picks, item=item, units=units, tmin=tmin, tmax=tmax)
return self._get_data(
picks=picks, item=item, units=units, tmin=tmin, tmax=tmax, copy=copy
)

@verbose
def apply_function(
Expand Down
9 changes: 8 additions & 1 deletion mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,13 @@ def test_get_data():
with pytest.raises(TypeError, match="tmax .* float, None"):
epochs.get_data(tmin=1, tmax=np.ones(5))

# Test copy
data = epochs.get_data(copy=True)
assert not np.shares_memory(data, epochs._data)

data = epochs.get_data(copy=False)
assert np.shares_memory(data, epochs._data)


def test_hierarchical():
"""Test hierarchical access."""
Expand Down Expand Up @@ -1033,7 +1040,7 @@ def test_epochs_baseline_basic(preload, tmp_path):
epochs = mne.Epochs(raw, events, None, 0, 1e-3, baseline=None, preload=preload)
epochs.drop_bad()
epochs_nobl = epochs.copy()
epochs_data = epochs.get_data()
epochs_data = epochs.get_data(copy=False)
assert epochs_data.shape == (1, 2, 2)
expected = data.copy()
assert_array_equal(epochs_data[0], expected)
Expand Down
2 changes: 1 addition & 1 deletion mne/viz/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def test_plot_psd_epochs(epochs):
fig = spectrum.plot_topomap(bands=[(20, "20 Hz"), (15, 25, "15-25 Hz")])
# test with a flat channel
err_str = "for channel %s" % epochs.ch_names[2]
epochs.get_data()[0, 2, :] = 0
epochs.get_data(copy=False)[0, 2, :] = 0
for dB in [True, False]:
with pytest.warns(UserWarning, match=err_str):
epochs.compute_psd().plot(dB=dB)
Expand Down
Loading