diff --git a/caiman/base/rois.py b/caiman/base/rois.py index b5fff51e1..417405828 100644 --- a/caiman/base/rois.py +++ b/caiman/base/rois.py @@ -28,46 +28,36 @@ pass -def com(A: np.ndarray, d1: int, d2: int, d3: Optional[int] = None) -> np.array: +def com(A, d1: int, d2: int, d3: Optional[int] = None, order: str = 'F') -> np.ndarray: """Calculation of the center of mass for spatial components Args: - A: np.ndarray - matrix of spatial components (d x K) + A: np.ndarray or scipy.sparse array or matrix + matrix of spatial components (d x K). - d1: int - number of pixels in x-direction - - d2: int - number of pixels in y-direction - - d3: int - number of pixels in z-direction + d1, d2, d3: ints + d1, d2, and (optionally) d3 are the original dimensions of the data. + + order: 'C' or 'F' + how each column of A should be reshaped to match the given dimensions. Returns: cm: np.ndarray - center of mass for spatial components (K x 2 or 3) + center of mass for spatial components (K x D) """ - if 'csc_matrix' not in str(type(A)): A = scipy.sparse.csc_matrix(A) - if d3 is None: - Coor = np.matrix([np.outer(np.ones(d2), np.arange(d1)).ravel(), - np.outer(np.arange(d2), np.ones(d1)).ravel()], - dtype=A.dtype) - else: - Coor = np.matrix([ - np.outer(np.ones(d3), - np.outer(np.ones(d2), np.arange(d1)).ravel()).ravel(), - np.outer(np.ones(d3), - np.outer(np.arange(d2), np.ones(d1)).ravel()).ravel(), - np.outer(np.arange(d3), - np.outer(np.ones(d2), np.ones(d1)).ravel()).ravel() - ], - dtype=A.dtype) - - cm = (Coor * A / A.sum(axis=0)).T + dims = [d1, d2] + if d3 is not None: + dims.append(d3) + + # make coordinate arrays where coor[d] increases from 0 to npixels[d]-1 along the dth axis + coors = np.meshgrid(*[range(d) for d in dims], indexing='ij') + coor = np.stack([c.ravel(order=order) for c in coors]) + + # take weighted sum of pixel positions along each coordinate + cm = (coor @ A / A.sum(axis=0)).T return np.array(cm) diff --git a/caiman/tests/test_toydata.py b/caiman/tests/test_toydata.py index 2a0f2a8e0..beb51a786 100644 --- a/caiman/tests/test_toydata.py +++ b/caiman/tests/test_toydata.py @@ -6,6 +6,7 @@ import caiman.source_extraction.cnmf.params from caiman.source_extraction import cnmf as cnmf +from caiman.utils.visualization import get_contours #%% @@ -64,6 +65,34 @@ def pipeline(D): ] npt.assert_allclose(corr, 1, .05) + # Check that get_contours works regardless of swap_dim + coor_normal = get_contours(cnm.estimates.A, dims, swap_dim=False) + coor_swapped = get_contours(cnm.estimates.A, dims[::-1], swap_dim=True) + for c_normal, c_swapped in zip(coor_normal, coor_swapped): + if D == 3: + for plane_coor_normal, plane_coor_swapped in zip(c_normal['coordinates'], c_swapped['coordinates']): + compare_contour_coords(plane_coor_normal, plane_coor_swapped[:, ::-1]) + else: + compare_contour_coords(c_normal['coordinates'], c_swapped['coordinates'][:, ::-1]) + + npt.assert_allclose(c_normal['CoM'], c_swapped['CoM'][::-1]) + +def compare_contour_coords(coords1: np.ndarray, coords2: np.ndarray): + """ + Compare 2 matrices of contour coordinates that should be the same, but may be calculated in a different order/ + from different starting points. + + The first point of each contour component is repeated, and this may be a different point depending on orientation. + To get around this, compare differences instead (have to take absolute value b/c direction may be opposite). + Also sort coordinates b/c starting point is unimportant & depends on orientation + """ + diffs_sorted = [] + for coords in [coords1, coords2]: + abs_diffs = np.abs(np.diff(coords, axis=0)) + sort_order = np.lexsort(abs_diffs.T) + diffs_sorted.append(abs_diffs[sort_order, :]) + npt.assert_allclose(diffs_sorted[0], diffs_sorted[1]) + def test_2D(): pipeline(2) diff --git a/caiman/utils/visualization.py b/caiman/utils/visualization.py index a82833507..9f9a6f2e1 100644 --- a/caiman/utils/visualization.py +++ b/caiman/utils/visualization.py @@ -366,7 +366,7 @@ def plot_unit(uid, scl): .redim.range(unit_id=(0, nr-1), scale=(0.0, 1.0))) -def get_contours(A, dims, thr=0.9, thr_method='nrg', swap_dim=False): +def get_contours(A, dims, thr=0.9, thr_method='nrg', swap_dim=False, slice_dim: Optional[int] = None): """Gets contour of spatial components and returns their coordinates Args: @@ -374,15 +374,24 @@ def get_contours(A, dims, thr=0.9, thr_method='nrg', swap_dim=False): Matrix of Spatial components (d x K) dims: tuple of ints - Spatial dimensions of movie (x, y[, z]) + Spatial dimensions of movie thr: scalar between 0 and 1 Energy threshold for computing contours (default 0.9) - thr_method: [optional] string + thr_method: string Method of thresholding: 'max' sets to zero pixels that have value less than a fraction of the max value 'nrg' keeps the pixels that contribute up to a specified fraction of the energy + + swap_dim: bool + If False (default), each column of A should be reshaped in F-order to recover the mask; + this is correct if the dimensions have not been reordered from (y, x[, z]). + If True, each column should be reshaped in C-order; this is correct for dims = ([z, ]x, y). + + slice_dim: int or None + Which dimension to slice along if we have 3D data. (i.e., get contours on each plane along this axis). + The default (None) is 0 if swap_dim is True, else -1. Returns: Coor: list of coordinates with center of mass and @@ -392,18 +401,11 @@ def get_contours(A, dims, thr=0.9, thr_method='nrg', swap_dim=False): if 'csc_matrix' not in str(type(A)): A = csc_matrix(A) d, nr = np.shape(A) - # if we are on a 3D video - if len(dims) == 3: - d1, d2, d3 = dims - x, y = np.mgrid[0:d2:1, 0:d3:1] - else: - d1, d2 = dims - x, y = np.mgrid[0:d1:1, 0:d2:1] coordinates = [] # get the center of mass of neurons( patches ) - cm = caiman.base.rois.com(A, *dims) + cm = caiman.base.rois.com(A, *dims, order='C' if swap_dim else 'F') # for each patches for i in range(nr): @@ -437,9 +439,10 @@ def get_contours(A, dims, thr=0.9, thr_method='nrg', swap_dim=False): Bmat = np.reshape(Bvec, dims, order='C') else: Bmat = np.reshape(Bvec, dims, order='F') - pars['coordinates'] = [] - # for each dimensions we draw the contour - for B in (Bmat if len(dims) == 3 else [Bmat]): + + def get_slice_coords(B: np.ndarray) -> np.ndarray: + """Get contour coordinates for a 2D slice""" + d1, d2 = B.shape vertices = find_contours(B.T, thr) # this fix is necessary for having disjoint figures and borders plotted correctly v = np.atleast_2d([np.nan, np.nan]) @@ -448,16 +451,26 @@ def get_contours(A, dims, thr=0.9, thr_method='nrg', swap_dim=False): if num_close_coords < 2: if num_close_coords == 0: # case angle - newpt = np.round(vtx[-1, :] / [d2, d1]) * [d2, d1] - vtx = np.concatenate((vtx, newpt[np.newaxis, :]), axis=0) + newpt = np.round(np.mean(vtx[[0, -1], :], axis=0) / [d2, d1]) * [d2, d1] + vtx = np.concatenate((newpt[np.newaxis, :], vtx, newpt[np.newaxis, :]), axis=0) else: # case one is border vtx = np.concatenate((vtx, vtx[0, np.newaxis]), axis=0) v = np.concatenate( (v, vtx, np.atleast_2d([np.nan, np.nan])), axis=0) + return v + + if len(dims) == 2: + pars['coordinates'] = get_slice_coords(Bmat) + else: + # make a list of the contour coordinates for each 2D slice + pars['coordinates'] = [] + if slice_dim is None: + slice_dim = 0 if swap_dim else -1 + for s in range(dims[slice_dim]): + B = Bmat.take(s, axis=slice_dim) + pars['coordinates'].append(get_slice_coords(B)) - pars['coordinates'] = v if len( - dims) == 2 else (pars['coordinates'] + [v]) pars['CoM'] = np.squeeze(cm[i, :]) pars['neuron_id'] = i + 1 coordinates.append(pars) @@ -1098,16 +1111,11 @@ def plot_contours(A, Cn, thr=None, thr_method='max', maxthr=0.2, nrgthr=0.9, dis plt.plot(*v.T, c=colors, **contour_args) if display_numbers: - d1, d2 = np.shape(Cn) - d, nr = np.shape(A) - cm = caiman.base.rois.com(A, d1, d2) + nr = A.shape[1] if max_number is None: - max_number = A.shape[1] - for i in range(np.minimum(nr, max_number)): - if swap_dim: - ax.text(cm[i, 0], cm[i, 1], str(i + 1), color=colors, **number_args) - else: - ax.text(cm[i, 1], cm[i, 0], str(i + 1), color=colors, **number_args) + max_number = nr + for i, c in zip(range(np.minimum(nr, max_number)), coordinates): + ax.text(c['CoM'][1], c['CoM'][0], str(i + 1), color=colors, **number_args) return coordinates def plot_shapes(Ab, dims, num_comps=15, size=(15, 15), comps_per_row=None,