diff --git a/caiman/utils/visualization.py b/caiman/utils/visualization.py index 39e0db78f..a82833507 100644 --- a/caiman/utils/visualization.py +++ b/caiman/utils/visualization.py @@ -1344,7 +1344,6 @@ def view_quilt(template_image: np.ndarray, Note: Currently assumes square patches so takes in a single number for stride/overlap. - TODO: implement bokeh version of this function """ im_dims = template_image.shape patch_rows, patch_cols = get_rectangle_coords(im_dims, stride, overlap) @@ -1362,3 +1361,119 @@ def view_quilt(template_image: np.ndarray, ax=ax) return ax + + +def create_quilt_patches(patch_rows, patch_cols): + """ + Helper function for nb_view_quilt. + Create patches given the row and column coordinates. + + Args: + patch_rows (ndarray): Array of row start and end positions for each patch. + patch_cols (ndarray): Array of column start and end positions for each patch. + + Returns: + list: A list of dictionaries, each containing the center coordinates, width, + and height of a patch. + """ + patches = [] + for row in patch_rows: + for col in patch_cols: + center_x = (col[0] + col[1]) / 2 + center_y = (row[0] + row[1]) / 2 + width = col[1] - col[0] + height = row[1] - row[0] + patches.append({'center_x': center_x, 'center_y': center_y, 'width': width, 'height': height}) + return patches + + +def nb_view_quilt(template_image: np.ndarray, + rf: int, + stride_input: int, + color: Optional[Any]='white', + alpha: Optional[float]=0.2): + """ + Bokeh implementation of view_quilt. + Plot patches on the template image given stride and overlap parameters. + + Args: + template_image (ndarray): Row x column summary image upon which to draw patches (e.g., correlation image). + rf (int): Half-size of the patches in pixels (patch width is rf*2 + 1). + stride_input (int): Amount of overlap between the patches in pixels (overlap is stride_input + 1). + color (Optional[Any]): Color of the patches, default 'white'. + alpha (Optional[float]): Patch transparency, default 0.2. + """ + + width = (rf*2)+1 + overlap = stride_input+1 + stride = width-overlap + + im_dims = template_image.shape + patch_rows, patch_cols = get_rectangle_coords(im_dims, stride, overlap) + patches = create_quilt_patches(patch_rows, patch_cols) + + plot = bpl.figure(x_range=(0, im_dims[1]), y_range=(im_dims[0], 0), width=600, height=600) + #plot.y_range.flipped = True + plot.image(image=[template_image], x=0, y=0, dw=im_dims[1], dh=im_dims[0], palette="Greys256") + source = ColumnDataSource(data=dict( + center_x=[patch['center_x'] for patch in patches], + center_y=[patch['center_y'] for patch in patches], + width=[patch['width'] for patch in patches], + height=[patch['height'] for patch in patches] + )) + plot.rect(x='center_x', y='center_y', width='width', height='height', source=source, color=color, alpha=alpha) + + # Create sliders + stride_slider = bokeh.models.Slider(start=1, end=100, value=rf, step=1, title="Patch half-size (rf)") + overlap_slider = bokeh.models.Slider(start=0, end=100, value=stride_input, step=1, title="Overlap (stride)") + + callback = CustomJS(args=dict(source=source, im_dims=im_dims, stride_slider=stride_slider, overlap_slider=overlap_slider), code=""" + function get_rectangle_coords(im_dims, stride, overlap) { + let patch_width = overlap + stride; + + let patch_onset_rows = Array.from({length: Math.ceil((im_dims[0] - patch_width) / stride)}, (_, i) => i * stride).concat([im_dims[0] - patch_width]); + let patch_offset_rows = patch_onset_rows.map(x => Math.min(x + patch_width, im_dims[0] - 1)); + let patch_rows = patch_onset_rows.map((x, i) => [x, patch_offset_rows[i]]); + + let patch_onset_cols = Array.from({length: Math.ceil((im_dims[1] - patch_width) / stride)}, (_, i) => i * stride).concat([im_dims[1] - patch_width]); + let patch_offset_cols = patch_onset_cols.map(x => Math.min(x + patch_width, im_dims[1] - 1)); + let patch_cols = patch_onset_cols.map((x, i) => [x, patch_offset_cols[i]]); + + return [patch_rows, patch_cols]; + } + + function create_quilt_patches(patch_rows, patch_cols) { + let patches = []; + for (let row of patch_rows) { + for (let col of patch_cols) { + let center_x = (col[0] + col[1]) / 2; + let center_y = (row[0] + row[1]) / 2; + let width = col[1] - col[0]; + let height = row[1] - row[0]; + patches.push({'center_x': center_x, 'center_y': center_y, 'width': width, 'height': height}); + } + } + return patches; + } + + let width = (stride_slider.value * 2) + 1; + let overlap = overlap_slider.value + 1; + let stride = width - overlap + + let [patch_rows, patch_cols] = get_rectangle_coords(im_dims, stride, overlap); + let patches = create_quilt_patches(patch_rows, patch_cols); + + source.data = { + center_x: patches.map(patch => patch.center_x), + center_y: patches.map(patch => patch.center_y), + width: patches.map(patch => patch.width), + height: patches.map(patch => patch.height) + }; + source.change.emit(); + """) + + stride_slider.js_on_change('value', callback) + overlap_slider.js_on_change('value', callback) + + + bpl.show(bokeh.layouts.row(plot, bokeh.layouts.column(stride_slider, overlap_slider))) \ No newline at end of file diff --git a/demos/notebooks/demo_pipeline.ipynb b/demos/notebooks/demo_pipeline.ipynb index 0923ddd82..7fb668cc3 100644 --- a/demos/notebooks/demo_pipeline.ipynb +++ b/demos/notebooks/demo_pipeline.ipynb @@ -73,7 +73,7 @@ "from caiman.source_extraction.cnmf import cnmf, params\n", "from caiman.utils.utils import download_demo\n", "from caiman.utils.visualization import plot_contours, nb_view_patches, nb_plot_contour\n", - "from caiman.utils.visualization import view_quilt\n", + "from caiman.utils.visualization import nb_view_quilt\n", "\n", "bpl.output_notebook()\n", "hv.notebook_extension('bokeh')" @@ -741,7 +741,7 @@ "metadata": {}, "source": [ "### Selecting spatial parameters\n", - "To select the spatial parameters (`gSig`, `rf`, `stride`, `K`), you need to look at your movie, or a summary image for your movie, and pick values close to those suggested by the guidelines above. It is helpful to use `view_quilt()` function to see if our key spatial parameters are in the right ballpark (note we recommend running this viewer in interactive qt mode so you can interact with it and get a better feel for the parameters):" + "To select the spatial parameters (`gSig`, `rf`, `stride`, `K`), you need to look at your movie, or a summary image for your movie, and pick values close to those suggested by the guidelines above. It is helpful to use the interactive `nb_view_quilt()` function or `view_quilt()` function to see if our key spatial parameters are in the right ballpark (you can use the sliders in the interactive version to change the `rf` and `stride` parameters and get a better feel for them):" ] }, { @@ -757,13 +757,9 @@ "print(f'Patch width: {cnmf_patch_width} , Stride: {cnmf_patch_stride}, Overlap: {cnmf_patch_overlap}');\n", "\n", "# plot the patches\n", - "patch_ax = view_quilt(correlation_image, \n", - " cnmf_patch_stride, \n", - " cnmf_patch_overlap, \n", - " vmin=np.percentile(np.ravel(correlation_image),50), \n", - " vmax=np.percentile(np.ravel(correlation_image),99.5),\n", - " figsize=(4,4));\n", - "patch_ax.set_title(f'CNMF Patches Width {cnmf_patch_width}, Overlap {cnmf_patch_overlap}');" + "patch_ax = nb_view_quilt(correlation_image, \n", + " cnmf_model.params.patch['rf'], \n", + " cnmf_model.params.patch['stride']);" ] }, {