From 5471878d1d1cfa4f00d865af70e453eed646f6c0 Mon Sep 17 00:00:00 2001 From: Ethan Blackwood Date: Thu, 5 Sep 2024 01:24:00 -0400 Subject: [PATCH 1/2] Fix z shift saving for piecewise mcorr (#314) * Fix failure to compute correlation image for 3D data (pending caiman update) * Fix by passing file as string and re-run without baseline if that fails * Make local_correlations in cnmf compatible with 3D data * Add z shifts for pw_rigid case * Pin caiman version instead of catching error from old version * Fix saving non-rigid z-shifts --- mesmerize_core/algorithms/mcorr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesmerize_core/algorithms/mcorr.py b/mesmerize_core/algorithms/mcorr.py index fbfd1fd..3bac29e 100644 --- a/mesmerize_core/algorithms/mcorr.py +++ b/mesmerize_core/algorithms/mcorr.py @@ -112,7 +112,7 @@ def run_algo(batch_path, uuid, data_path: str = None): y_shifts = mc.y_shifts_els shifts = [x_shifts, y_shifts] if hasattr(mc, 'z_shifts_els'): - shifts += mc.z_shifts_els + shifts.append(mc.z_shifts_els) shift_path = output_dir.joinpath(f"{uuid}_shifts.npy") np.save(str(shift_path), shifts) else: From 5d7c9b3ebe121eb641addccb6acdc8119d14a1d8 Mon Sep 17 00:00:00 2001 From: Ethan Blackwood Date: Sun, 8 Sep 2024 20:55:30 -0400 Subject: [PATCH 2/2] Make get_shifts return value more useful; add test for shifts (#317) * Make get_shifts return value more useful; add test for shifts * Fix type annotation for get_shifts --- mesmerize_core/caiman_extensions/mcorr.py | 29 +++++++---------------- tests/test_core.py | 20 +++++++++++++++- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/mesmerize_core/caiman_extensions/mcorr.py b/mesmerize_core/caiman_extensions/mcorr.py index 9cd5fb7..0699d98 100644 --- a/mesmerize_core/caiman_extensions/mcorr.py +++ b/mesmerize_core/caiman_extensions/mcorr.py @@ -5,7 +5,6 @@ from caiman import load_memmap from ._utils import validate -from typing import * @pd.api.extensions.register_series_accessor("mcorr") @@ -92,9 +91,7 @@ def get_output(self, mode: str = "r") -> np.ndarray: return mc_movie @validate("mcorr") - def get_shifts( - self, pw_rigid: bool = False - ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + def get_shifts(self, pw_rigid) -> list[np.ndarray]: """ Gets file path to shifts array (.npy file) for item, processes shifts array into a list of x and y shifts based on whether rigid or nonrigid @@ -107,26 +104,16 @@ def get_shifts( False = Rigid Returns: -------- - List of Processed X and Y shifts arrays + List of Processed X and Y [and Z] shifts arrays + - For rigid correction, each element is a vector of length n_frames + - For pw_rigid correction, each element is an n_frames x n_patches matrix """ path = self._series.paths.resolve(self._series["outputs"]["shifts"]) shifts = np.load(str(path)) if pw_rigid: - n_pts = shifts.shape[1] - n_lines = shifts.shape[2] - xs = [np.linspace(0, n_pts, n_pts)] - ys = [] - - for i in range(shifts.shape[0]): - for j in range(n_lines): - ys.append(shifts[i, :, j]) + shifts_by_dim = list(shifts) # dims-length list of n_frames x n_patches matrices else: - n_pts = shifts.shape[0] - n_lines = shifts.shape[1] - xs = [np.linspace(0, n_pts, n_pts)] - ys = [] - - for i in range(n_lines): - ys.append(shifts[:, i]) - return xs, ys + shifts_by_dim = list(shifts.T) # dims-length list of n_frames-length vectors + + return shifts_by_dim diff --git a/tests/test_core.py b/tests/test_core.py index 163189e..878a43a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -44,7 +44,7 @@ def _download_ground_truths(): print(f"Downloading ground truths") - url = f"https://zenodo.org/record/6828096/files/ground_truths.zip" + url = f"https://zenodo.org/record/13732996/files/ground_truths.zip" # basically from https://stackoverflow.com/questions/37573483/progress-bar-while-download-file-over-http-with-requests/37573701 response = requests.get(url, stream=True) @@ -252,6 +252,15 @@ def test_mcorr(): ) ) + # test to check shifts output path + assert ( + batch_dir.joinpath(df.iloc[-1]["outputs"]["shifts"]) + == df.paths.resolve(df.iloc[-1]["outputs"]["shifts"]) + == batch_dir.joinpath( + str(df.iloc[-1]["uuid"]), f'{df.iloc[-1]["uuid"]}_shifts.npy' + ) + ) + # test to check mean-projection output path assert ( batch_dir.joinpath(df.iloc[-1]["outputs"]["mean-projection-path"]) @@ -303,6 +312,15 @@ def test_mcorr(): ) numpy.testing.assert_array_equal(mcorr_output, mcorr_output_actual) + + # test to check mcorr get_shifts() + mcorr_shifts = df.iloc[-1].mcorr.get_shifts(pw_rigid=test_params[algo]["main"]["pw_rigid"]) + mcorr_shifts_actual = numpy.load( + ground_truths_dir.joinpath("mcorr", "mcorr_shifts.npy") + ) + numpy.testing.assert_array_equal(mcorr_shifts, mcorr_shifts_actual) + + # test to check caiman get_input_movie_path() assert df.iloc[-1].caiman.get_input_movie_path() == get_full_raw_data_path( df.iloc[0]["input_movie_path"]