Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added bokeh implementation of view_quilt #1365

Merged
merged 3 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 116 additions & 1 deletion caiman/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,6 @@ def view_quilt(template_image: np.ndarray,

Note:
Currently assumes square patches so takes in a single number for stride/overlap.
TODO: implement bokeh version of this function
"""
im_dims = template_image.shape
patch_rows, patch_cols = get_rectangle_coords(im_dims, stride, overlap)
Expand All @@ -1362,3 +1361,119 @@ def view_quilt(template_image: np.ndarray,
ax=ax)

return ax


def create_quilt_patches(patch_rows, patch_cols):
"""
Helper function for nb_view_quilt.
Create patches given the row and column coordinates.

Args:
patch_rows (ndarray): Array of row start and end positions for each patch.
patch_cols (ndarray): Array of column start and end positions for each patch.

Returns:
list: A list of dictionaries, each containing the center coordinates, width,
and height of a patch.
"""
patches = []
for row in patch_rows:
for col in patch_cols:
center_x = (col[0] + col[1]) / 2
center_y = (row[0] + row[1]) / 2
width = col[1] - col[0]
height = row[1] - row[0]
patches.append({'center_x': center_x, 'center_y': center_y, 'width': width, 'height': height})
return patches


def nb_view_quilt(template_image: np.ndarray,
rf: int,
stride_input: int,
color: Optional[Any]='white',
alpha: Optional[float]=0.2):
"""
Bokeh implementation of view_quilt.
Plot patches on the template image given stride and overlap parameters.

Args:
template_image (ndarray): Row x column summary image upon which to draw patches (e.g., correlation image).
rf (int): Half-size of the patches in pixels (patch width is rf*2 + 1).
stride_input (int): Amount of overlap between the patches in pixels (overlap is stride_input + 1).
color (Optional[Any]): Color of the patches, default 'white'.
alpha (Optional[float]): Patch transparency, default 0.2.
"""

width = (rf*2)+1
overlap = stride_input+1
stride = width-overlap

im_dims = template_image.shape
patch_rows, patch_cols = get_rectangle_coords(im_dims, stride, overlap)
patches = create_quilt_patches(patch_rows, patch_cols)

plot = bpl.figure(x_range=(0, im_dims[1]), y_range=(im_dims[0], 0), width=600, height=600)
#plot.y_range.flipped = True
plot.image(image=[template_image], x=0, y=0, dw=im_dims[1], dh=im_dims[0], palette="Greys256")
source = ColumnDataSource(data=dict(
center_x=[patch['center_x'] for patch in patches],
center_y=[patch['center_y'] for patch in patches],
width=[patch['width'] for patch in patches],
height=[patch['height'] for patch in patches]
))
plot.rect(x='center_x', y='center_y', width='width', height='height', source=source, color=color, alpha=alpha)

# Create sliders
stride_slider = bokeh.models.Slider(start=1, end=100, value=rf, step=1, title="Patch half-size (rf)")
overlap_slider = bokeh.models.Slider(start=0, end=100, value=stride_input, step=1, title="Overlap (stride)")

callback = CustomJS(args=dict(source=source, im_dims=im_dims, stride_slider=stride_slider, overlap_slider=overlap_slider), code="""
function get_rectangle_coords(im_dims, stride, overlap) {
let patch_width = overlap + stride;

let patch_onset_rows = Array.from({length: Math.ceil((im_dims[0] - patch_width) / stride)}, (_, i) => i * stride).concat([im_dims[0] - patch_width]);
let patch_offset_rows = patch_onset_rows.map(x => Math.min(x + patch_width, im_dims[0] - 1));
let patch_rows = patch_onset_rows.map((x, i) => [x, patch_offset_rows[i]]);

let patch_onset_cols = Array.from({length: Math.ceil((im_dims[1] - patch_width) / stride)}, (_, i) => i * stride).concat([im_dims[1] - patch_width]);
let patch_offset_cols = patch_onset_cols.map(x => Math.min(x + patch_width, im_dims[1] - 1));
let patch_cols = patch_onset_cols.map((x, i) => [x, patch_offset_cols[i]]);

return [patch_rows, patch_cols];
}

function create_quilt_patches(patch_rows, patch_cols) {
let patches = [];
for (let row of patch_rows) {
for (let col of patch_cols) {
let center_x = (col[0] + col[1]) / 2;
let center_y = (row[0] + row[1]) / 2;
let width = col[1] - col[0];
let height = row[1] - row[0];
patches.push({'center_x': center_x, 'center_y': center_y, 'width': width, 'height': height});
}
}
return patches;
}

let width = (stride_slider.value * 2) + 1;
let overlap = overlap_slider.value + 1;
let stride = width - overlap

let [patch_rows, patch_cols] = get_rectangle_coords(im_dims, stride, overlap);
let patches = create_quilt_patches(patch_rows, patch_cols);

source.data = {
center_x: patches.map(patch => patch.center_x),
center_y: patches.map(patch => patch.center_y),
width: patches.map(patch => patch.width),
height: patches.map(patch => patch.height)
};
source.change.emit();
""")

stride_slider.js_on_change('value', callback)
overlap_slider.js_on_change('value', callback)


bpl.show(bokeh.layouts.row(plot, bokeh.layouts.column(stride_slider, overlap_slider)))
14 changes: 5 additions & 9 deletions demos/notebooks/demo_pipeline.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
"from caiman.source_extraction.cnmf import cnmf, params\n",
"from caiman.utils.utils import download_demo\n",
"from caiman.utils.visualization import plot_contours, nb_view_patches, nb_plot_contour\n",
"from caiman.utils.visualization import view_quilt\n",
"from caiman.utils.visualization import nb_view_quilt\n",
"\n",
"bpl.output_notebook()\n",
"hv.notebook_extension('bokeh')"
Expand Down Expand Up @@ -741,7 +741,7 @@
"metadata": {},
"source": [
"### Selecting spatial parameters\n",
"To select the spatial parameters (`gSig`, `rf`, `stride`, `K`), you need to look at your movie, or a summary image for your movie, and pick values close to those suggested by the guidelines above. It is helpful to use `view_quilt()` function to see if our key spatial parameters are in the right ballpark (note we recommend running this viewer in interactive qt mode so you can interact with it and get a better feel for the parameters):"
"To select the spatial parameters (`gSig`, `rf`, `stride`, `K`), you need to look at your movie, or a summary image for your movie, and pick values close to those suggested by the guidelines above. It is helpful to use the interactive `nb_view_quilt()` function or `view_quilt()` function to see if our key spatial parameters are in the right ballpark (you can use the sliders in the interactive version to change the `rf` and `stride` parameters and get a better feel for them):"
]
},
{
Expand All @@ -757,13 +757,9 @@
"print(f'Patch width: {cnmf_patch_width} , Stride: {cnmf_patch_stride}, Overlap: {cnmf_patch_overlap}');\n",
"\n",
"# plot the patches\n",
"patch_ax = view_quilt(correlation_image, \n",
" cnmf_patch_stride, \n",
" cnmf_patch_overlap, \n",
" vmin=np.percentile(np.ravel(correlation_image),50), \n",
" vmax=np.percentile(np.ravel(correlation_image),99.5),\n",
" figsize=(4,4));\n",
"patch_ax.set_title(f'CNMF Patches Width {cnmf_patch_width}, Overlap {cnmf_patch_overlap}');"
"patch_ax = nb_view_quilt(correlation_image, \n",
" cnmf_model.params.patch['rf'], \n",
" cnmf_model.params.patch['stride']);"
]
},
{
Expand Down
Loading