Skip to content

Commit

Permalink
Merge pull request #55 from fractal-napari-plugins-collection/51_pred…
Browse files Browse the repository at this point in the history
…iction_layer_fixes

Fix prediction/annotation layer issues
  • Loading branch information
jluethi authored Sep 16, 2024
2 parents 1734442 + 6981f2b commit 37583cf
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 41 deletions.
26 changes: 16 additions & 10 deletions src/napari_feature_classifier/annotator_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,16 +136,7 @@ def __init__(
for layer in self._viewer.layers:
if type(layer) == napari.layers.Labels and layer.name == "Annotations":
self._viewer.layers.remove(layer)
self._annotations_layer = self._viewer.add_labels(
self._last_selected_label_layer.data,
scale=self._last_selected_label_layer.scale,
name="Annotations",
translate=self._last_selected_label_layer.translate,
)
self._annotations_layer.editable = False

# Set the label selection to a valid label layer => Running into proxy bug
self._viewer.layers.selection.active = self._last_selected_label_layer
self.add_annotations_layer()

# Class selection
self.ClassSelection = ClassSelection # pylint: disable=C0103
Expand Down Expand Up @@ -212,11 +203,26 @@ def selection_changed(self, event):
self._save_destination.enabled = False
self._class_selector.enabled = False

def add_annotations_layer(self):
self._annotations_layer = self._viewer.add_labels(
self._last_selected_label_layer.data,
scale=self._last_selected_label_layer.scale,
name="Annotations",
translate=self._last_selected_label_layer.translate,
)
self._annotations_layer.editable = False
# Set the label selection to a valid label layer
self._viewer.layers.selection.active = self._last_selected_label_layer

def toggle_label(self, labels_layer, event):
"""
Callback for when a label is clicked. It then updates the color of that
label in the annotation layer.
"""
# If the annotations layer is missing, add it back
if "Annotations" not in [x.name for x in self._viewer.layers]:
self.add_annotations_layer()

# Need to translate & scale position that event.position returns by the
# label_layer scale.
# If scale is (1, 1, 1), nothing changes
Expand Down
116 changes: 85 additions & 31 deletions src/napari_feature_classifier/classifier_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def __init__(
self._last_selected_label_layer = get_selected_or_valid_label_layer(
viewer=self._viewer
)

# Initialize the classifier
if classifier:
self._classifier = classifier
Expand All @@ -248,17 +249,10 @@ def __init__(
self._viewer, get_class_selection(class_names=self.class_names)
)

# Handle existing predictions layer
for layer in self._viewer.layers:
if type(layer) == napari.layers.Labels and layer.name == "Predictions":
self._viewer.layers.remove(layer)
self._prediction_layer = self._viewer.add_labels(
self._last_selected_label_layer.data,
scale=self._last_selected_label_layer.scale,
name="Predictions",
translate=self._last_selected_label_layer.translate,
)
self._prediction_layer.contour = 2
self.add_prediction_layer()

# Set the label selection to a valid label layer => Running into proxy bug
self._viewer.layers.selection.active = self._last_selected_label_layer
Expand Down Expand Up @@ -295,11 +289,6 @@ def __init__(
self._export_button.clicked.connect(self.export_results)
self._viewer.layers.selection.events.changed.connect(self.selection_changed)
self._init_prediction_layer(self._last_selected_label_layer)
# Whenever the label layer is clicked, hide the prediction layer
# (e.g. new annotations are made)
# self._last_selected_label_layer.mouse_drag_callbacks.append(
# self.hide_prediction_layer
# )

def run(self):
"""
Expand Down Expand Up @@ -349,6 +338,15 @@ def add_features_to_classifier(self):
dict_of_features[layer.name] = layer.features
self._classifier.add_dict_of_features(dict_of_features)

def add_prediction_layer(self):
self._prediction_layer = self._viewer.add_labels(
self._last_selected_label_layer.data,
scale=self._last_selected_label_layer.scale,
name="Predictions",
translate=self._last_selected_label_layer.translate,
)
self._prediction_layer.contour = 2

def make_predictions(self):
"""
Make predictions for all relevant label layers and add them to the
Expand Down Expand Up @@ -398,17 +396,72 @@ def selection_changed(self):
viewer=self._viewer
):
self._last_selected_label_layer = self._viewer.layers.selection.active
self._init_prediction_layer(self._viewer.layers.selection.active)
# self._last_selected_label_layer.mouse_drag_callbacks.append(
# self.hide_prediction_layer
# )
self._init_prediction_layer(
self._viewer.layers.selection.active, ensure_layer_presence=False
)
self._update_export_destination(self._last_selected_label_layer)

def _init_prediction_layer(self, label_layer: napari.layers.Labels):
def reorder_layers(self):
"""Reorders layers if needed to ensure Annotation & Prediction layers
are above the currently selected label layer.
"""
# Get the current order of layers
all_layers = list(self._viewer.layers)

# Determine the indices of the layers if they exist
indices_to_move = []

# Find the index of "Prediction" layer if it exists
if "Predictions" in self._viewer.layers:
indices_to_move.append(self._viewer.layers.index("Predictions"))

# Find the index of "Annotation" layer if it exists
if "Annotations" in self._viewer.layers:
indices_to_move.append(self._viewer.layers.index("Annotations"))

# Find the index of the reference_label_layer
if self._last_selected_label_layer.name in self._viewer.layers:
indices_to_move.append(
self._viewer.layers.index(self._last_selected_label_layer.name)
)

# Calculate the new order of layer indices
remaining_indices = [
i for i in range(len(all_layers)) if i not in indices_to_move
]
remaining_indices.reverse()
new_order = indices_to_move + remaining_indices
new_order.reverse()

# Reorder the layers using the move_multiple function
self._viewer.layers.move_multiple(new_order)

def _init_prediction_layer(
self, label_layer: napari.layers.Labels, ensure_layer_presence: bool = True
):
"""
Initialize the prediction layer and reset its data (to fit the input
label_layer) and its colormap
label_layer) and its colormap.
ensure_layer_presence creates the Predictions layer if it doesn't exist
yet and triggers layer reordering.
"""
# Ensure that prediction layer exists
if (
"Predictions" not in [x.name for x in self._viewer.layers]
and ensure_layer_presence
):
self.add_prediction_layer()
if ensure_layer_presence:
# Ensure correct layer order: This sometimes fails with weird
# EmitLoopError & IndexError that should be ignored
try:
self.reorder_layers()
except: # noqa
pass

# Ensure that prediction layer is above the current label layer
self._last_selected_label_layer

# Check if the predict column already exists in the layer.features
if "prediction" not in label_layer.features:
unique_labels = np.unique(label_layer.data)[1:]
Expand Down Expand Up @@ -448,12 +501,6 @@ def _init_prediction_layer(self, label_layer: napari.layers.Labels):
cmap=get_colormap(),
)

# def hide_prediction_layer(self, labels_layer, event):
# """
# Hide the prediction layer
# """
# self._prediction_layer.visible = False

def get_relevant_label_layers(self):
relevant_label_layers = []
required_columns = [self._label_column, self._roi_id_colum]
Expand Down Expand Up @@ -613,12 +660,19 @@ def load(self):
with open(clf_path, "rb") as f: # pylint: disable=C0103
clf = pickle.load(f)

self._run_container = ClassifierRunContainer(
self._viewer,
clf,
classifier_save_path=clf_path,
auto_save=True,
)
try:
self._run_container = ClassifierRunContainer(
self._viewer,
clf,
classifier_save_path=clf_path,
auto_save=True,
)
except NotImplementedError:
napari_info(
"Create a label layer with a feature dataframe before loading "
"the classifier"
)
return
self.clear()
self.append(self._run_container)

Expand Down

0 comments on commit 37583cf

Please sign in to comment.