diff --git a/element_calcium_imaging/plotting/draw_rois.py b/element_calcium_imaging/plotting/draw_rois.py index 8b63a981..8308b78a 100644 --- a/element_calcium_imaging/plotting/draw_rois.py +++ b/element_calcium_imaging/plotting/draw_rois.py @@ -14,9 +14,6 @@ Serverside, ServersideOutputTransform, ) -from scipy import ndimage -from skimage import draw, measure -from tifffile import TiffFile from .utilities import * @@ -25,6 +22,7 @@ def draw_rois(db_prefix: str): + scan = dj.create_virtual_module("scan", f"{db_prefix}scan") imaging = dj.create_virtual_module("imaging", f"{db_prefix}imaging") all_keys = (imaging.MotionCorrection).fetch("KEY") @@ -34,29 +32,42 @@ def draw_rois(db_prefix: str): app.layout = html.Div( [ html.H2("Draw ROIs", style={"color": colors["text"]}), - html.Label("Select data key from dropdown", style={"color": colors["text"]}), + html.Label( + "Select data key from dropdown", style={"color": colors["text"]} + ), dcc.Dropdown( id="toplevel-dropdown", options=[str(key) for key in all_keys] ), html.Br(), html.Div( [ - html.Button("Load Image", id="load-image-button", style={"margin-right": "20px"}), + html.Button( + "Load Image", + id="load-image-button", + style={"margin-right": "20px"}, + ), dcc.RadioItems( - id='image-type-radio', + id="image-type-radio", options=[ - {'label': 'Average Image', 'value': 'average_image'}, - {'label': 'Max Projection Image', 'value': 'max_projection_image'} + {"label": "Average Image", "value": "average_image"}, + { + "label": "Max Projection Image", + "value": "max_projection_image", + }, ], - value='average_image', # Default value - labelStyle={'display': 'inline-block', 'margin-right': '10px'}, # Inline display with some margin - style={'display': 'inline-block', 'color': colors['text']} # Inline display to keep it on the same line + value="average_image", + labelStyle={"display": "inline-block", "margin-right": "10px"}, + style={"display": "inline-block", "color": colors["text"]}, ), html.Div( [ html.Button("Submit Curated Masks", id="submit-button"), ], - style={"textAlign": "right", "flex": "1", "display": "inline-block"}, + style={ + "textAlign": "right", + "flex": "1", + "display": "inline-block", + }, ), ], style={ @@ -76,6 +87,7 @@ def draw_rois(db_prefix: str): "drawclosedpath", "drawrect", "drawcircle", + "drawline", "eraseshape", ], }, @@ -84,10 +96,10 @@ def draw_rois(db_prefix: str): ], style={ "display": "flex", - "justify-content": "center", # Centers the child horizontally - "align-items": "center", # Centers the child vertically (if you have vertical space to work with) + "justify-content": "center", + "align-items": "center", "padding": "0.0", - "margin": "auto" # Automatically adjust the margins to center the div + "margin": "auto", }, ), html.Pre(id="annotations"), @@ -120,13 +132,22 @@ def store_key(value): def create_figure(value, render_n_clicks, image_type): if render_n_clicks is not None: if image_type == "average_image": - summary_images = (imaging.MotionCorrection.Summary & yaml.safe_load(value)).fetch("average_image") + summary_images = ( + imaging.MotionCorrection.Summary & yaml.safe_load(value) + ).fetch("average_image") else: - summary_images = (imaging.MotionCorrection.Summary & yaml.safe_load(value)).fetch("max_proj_image") + summary_images = ( + imaging.MotionCorrection.Summary & yaml.safe_load(value) + ).fetch("max_proj_image") average_images = [image.astype("float") for image in summary_images] roi_contours = get_contours(yaml.safe_load(value), db_prefix) logger.info("Generating figure.") - fig = px.imshow(np.asarray(average_images), animation_frame=0, binary_string=True, labels=dict(animation_frame="plane")) + fig = px.imshow( + np.asarray(average_images), + animation_frame=0, + binary_string=True, + labels=dict(animation_frame="plane"), + ) for contour in roi_contours: # Note: contour[:, 1] are x-coordinates, contour[:, 0] are y-coordinates fig.add_trace( @@ -171,30 +192,36 @@ def on_relayout(relayout_data): elif any(["shapes" in key for key in relayout_data]): return Serverside(relayout_data) - @app.callback( Output("submit-output", "children"), Input("submit-button", "n_clicks"), State("store-mask", "annotation_list"), - State("store-key", "value") + State("store-key", "value"), ) def submit_annotations(n_clicks, annotation_list, value): - print("submitting annotations") x_mask_li = [] y_mask_li = [] if n_clicks is not None: - if "shapes" in annotation_list: - shapes = [d["type"] for d in annotation_list["shapes"]] - for shape, annotation in zip(shapes, annotation_list["shapes"]): - mask = create_mask(annotation, shape) - y_mask_li.append(mask[0]) - x_mask_li.append(mask[1]) - - suite2p_masks = convert_masks_to_suite2p_format( - [np.array([x_mask_li, y_mask_li])], (512, 512) - ) - fluo_traces = extract_signals_suite2p(yaml.safe_load(value), suite2p_masks) - + if annotation_list: + if "shapes" in annotation_list: + logger.info("Creating Masks.") + shapes = [d["type"] for d in annotation_list["shapes"]] + for shape, annotation in zip(shapes, annotation_list["shapes"]): + mask = create_mask(annotation, shape) + y_mask_li.append(mask[0]) + x_mask_li.append(mask[1]) + print("Masks created") + insert_into_database( + scan, imaging, yaml.safe_load(value), x_mask_li, y_mask_li + ) + else: + logger.warn( + "Incorrect annotation list format. This is a known bug. Please draw a line anywhere on the image and click `Submit Curated Masks`. It will be ignored in the final submission but will format the list correctly." + ) + return no_update + else: + logger.warn("No annotations to submit.") + return no_update else: return no_update diff --git a/element_calcium_imaging/plotting/utilities.py b/element_calcium_imaging/plotting/utilities.py index 0308a4f8..83afbe1b 100644 --- a/element_calcium_imaging/plotting/utilities.py +++ b/element_calcium_imaging/plotting/utilities.py @@ -1,7 +1,25 @@ +import pathlib import datajoint as dj import numpy as np from scipy import ndimage from skimage import draw, measure +from element_interface.utils import find_full_path + + +logger = dj.logger + + +def get_imaging_root_data_dir(): + """Retrieve imaging root data directory.""" + imaging_root_dirs = dj.config.get("custom", {}).get("imaging_root_data_dir", None) + if not imaging_root_dirs: + return None + elif isinstance(imaging_root_dirs, (str, pathlib.Path)): + return [imaging_root_dirs] + elif isinstance(imaging_root_dirs, list): + return imaging_root_dirs + else: + raise TypeError("`imaging_root_data_dir` must be a string, pathlib, or list") def path_to_indices(path): @@ -88,9 +106,7 @@ def create_mask(coordinates, shape_type): xy_coordinates = np.asarray( [item for item in coordinates.values()], dtype="int" ) - mask = np.asarray( - create_ellipse_mask(xy_coordinates, (512, 512)) - ).nonzero() + mask = np.asarray(create_ellipse_mask(xy_coordinates, (512, 512))).nonzero() elif shape_type == "rect": try: mask = np.asarray( @@ -140,36 +156,7 @@ def get_contours(image_key, prefix): return contours -def convert_masks_to_suite2p_format(masks, frame_dims): - """ - Convert masks to the format expected by Suite2P. - - Parameters: - masks (list of np.ndarray): A list where each item is an array representing a mask, - with non-zero values for the ROI and zeros elsewhere. - frame_dims (tuple): The dimensions of the imaging frame, (height, width). - - Returns: - np.ndarray: A 2D array where each column represents a flattened binary mask for an ROI. - """ - # Calculate the total number of pixels in a frame - num_pixels = frame_dims[0] * frame_dims[1] - - # Initialize an empty array to store the flattened binary masks - suite2p_masks = np.zeros((num_pixels, len(masks)), dtype=np.float32) - - # Convert each mask - for idx, mask in enumerate(masks): - # Ensure the mask is binary (1 for ROI, 0 for background) - binary_mask = np.where(mask > 0, 1, 0).astype(np.float32) - - # Flatten the binary mask and add it as a column in the suite2p_masks array - suite2p_masks[:, idx] = binary_mask.flatten() - - return suite2p_masks - - -def load_imaging_data_for_session(key): +def load_imaging_data_for_session(scan, key): image_files = (scan.ScanInfo.ScanFile & key).fetch("file_path") image_files = [ find_full_path(get_imaging_root_data_dir(), image_file) @@ -177,20 +164,57 @@ def load_imaging_data_for_session(key): ] acq_software = (scan.Scan & key).fetch1("acq_software") if acq_software == "ScanImage": - imaging_data = tifffile.imread(image_files[0]) + import tifffile + + imaging_data = tifffile.imread(image_files[0]) elif acq_software == "NIS": + import nd2 + imaging_data = nd2.imread(image_files[0]) else: - raise ValueError(f"Support for images with acquisition software: {acq_software} is not yet implemented into the widget.") + raise ValueError( + f"Support for images with acquisition software: {acq_software} is not yet implemented into the widget." + ) return imaging_data -def extract_signals_suite2p(key, masks): - from suite2p.extraction.extract import extrace_traces - - F, _ = extrace_traces(load_imaging_data_for_session(key), masks, neuropil_masks=np.zeros_like(masks)) - - -def insert_signals_into_datajoint(signals, session_key): - # Implement logic to insert the extracted signals into DataJoint - pass \ No newline at end of file +def insert_into_database(scan_module, imaging_module, session_key, x_masks, y_masks): + images = load_imaging_data_for_session(scan_module, session_key) + print(f"Images shape: {images.shape}") + mask_id = (imaging_module.Segmentation.Mask & session_key).fetch( + "mask_id", order_by="DESC mask_id", limit=1 + ) + print(f"Mask ID: {mask_id}") + logger.info(f"Inserting {len(x_masks)} masks into the database.") + # imaging_module.Segmentation.Mask.insert( + # [ + # dict( + # **session_key, + # mask=mask_id + mask_num, + # segmentation_channel=1, + # mask_npix=y_mask.shape[0], + # mask_center_x=int(sum(x_mask) / x_mask.shape[0]), + # mask_center_y=int(sum(y_mask) / y_mask.shape[0]), + # mask_center_z=0, + # mask_xpix=x_mask, + # mask_ypix=y_mask, + # mask_zpix=0, + # mask_weights=np.ones_like(y_mask), + # ) + # for mask_num, (x_mask, y_mask) in enumerate(zip(x_masks, y_masks)) + # ], + # allow_direct_insert=True, + # ) + logger.info(f"Inserting {len(x_masks)} traces into the database.") + # imaging_module.Fluorescence.Trace.insert( + # [ + # dict( + # **session_key, + # mask=mask_id + mask_num, + # fluo_channel=1, + # fluorescence=images[:, y_mask, x_mask].mean(axis=1), + # ) + # for mask_num, (x_mask, y_mask) in enumerate(zip(x_masks, y_masks)) + # ], + # allow_direct_insert=True, + # ) diff --git a/notebooks/test_widget.ipynb b/notebooks/test_widget.ipynb new file mode 100644 index 00000000..55a166d1 --- /dev/null +++ b/notebooks/test_widget.ipynb @@ -0,0 +1,50 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# change to the upper level folder to detect dj_local_conf.json\n", + "if os.path.basename(os.getcwd()) == \"notebooks\":\n", + " os.chdir(\"..\")\n", + "\n", + "import datajoint as dj\n", + "from element_calcium_imaging.plotting.draw_rois import draw_rois" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "draw_rois(\"neuro_\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "elements", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}