diff --git a/.github/workflows/website.yml b/.github/workflows/website.yml
index ede9eef9f..830788388 100644
--- a/.github/workflows/website.yml
+++ b/.github/workflows/website.yml
@@ -8,7 +8,7 @@ on:
       # 'main' triggers updates to 'sleap.ai', all others to 'sleap.ai/develop'
       - main
       - develop
-      - liezl/update-intallation-docs-1.4.1 # again!
+      - liezl/add-channels-for-pip-conda-env
     paths:
       - "docs/**"
       - "README.rst"
diff --git a/docs/_static/bonsai-connection.jpg b/docs/_static/bonsai-connection.jpg
new file mode 100644
index 000000000..32b725416
Binary files /dev/null and b/docs/_static/bonsai-connection.jpg differ
diff --git a/docs/_static/bonsai-filecapture.jpg b/docs/_static/bonsai-filecapture.jpg
new file mode 100644
index 000000000..7a809d67a
Binary files /dev/null and b/docs/_static/bonsai-filecapture.jpg differ
diff --git a/docs/_static/bonsai-predictcentroids.jpg b/docs/_static/bonsai-predictcentroids.jpg
new file mode 100644
index 000000000..e284f2338
Binary files /dev/null and b/docs/_static/bonsai-predictcentroids.jpg differ
diff --git a/docs/_static/bonsai-predictposeidentities.jpg b/docs/_static/bonsai-predictposeidentities.jpg
new file mode 100644
index 000000000..8582fd707
Binary files /dev/null and b/docs/_static/bonsai-predictposeidentities.jpg differ
diff --git a/docs/_static/bonsai-predictposes.jpg b/docs/_static/bonsai-predictposes.jpg
new file mode 100644
index 000000000..2e4f04a22
Binary files /dev/null and b/docs/_static/bonsai-predictposes.jpg differ
diff --git a/docs/_static/bonsai-workflow.jpg b/docs/_static/bonsai-workflow.jpg
new file mode 100644
index 000000000..0481c3dcf
Binary files /dev/null and b/docs/_static/bonsai-workflow.jpg differ
diff --git a/docs/guides/bonsai.md b/docs/guides/bonsai.md
new file mode 100644
index 000000000..d262873b6
--- /dev/null
+++ b/docs/guides/bonsai.md
@@ -0,0 +1,75 @@
+(bonsai)=
+
+# Using Bonsai with SLEAP
+
+Bonsai is a visual language for reactive programming that currently supports running SLEAP models.
+
+:::{note}
+Currently Bonsai supports only single instance, top-down, and top-down-id SLEAP models.
+:::
+
+## Exporting a trained SLEAP model
+
+Before a trained model can be imported into Bonsai, it must be converted with the {code}`sleap-export` command to a format Bonsai supports. For example, to export a top-down-id model:
+
+```bash
+sleap-export -m centroid/model/folder/path -m top_down_id/model/folder/path -e exported/model/path
+```
+
+Please refer to the {ref}`sleap-export` docs for more details on using the command.
+
+This generates the necessary `.pb` file and the other metadata files required by Bonsai. In this example, the files are saved to the specified `exported/model/path` folder.
+
+The `exported/model/path` folder will have a structure like the following:
+
+```plaintext
+exported/model/path
+├── centroid_config.json
+├── confmap_config.json
+├── frozen_graph.pb
+└── info.json
+```
+
+## Installing Bonsai and necessary packages
+
+1. Install Bonsai. See the [Bonsai installation instructions](https://bonsai-rx.org/docs/articles/installation.html).
+
+2. Download and add the necessary packages for Bonsai to run with SLEAP. See the official [Bonsai SLEAP documentation](https://github.com/bonsai-rx/sleap?tab=readme-ov-file#bonsai---sleap) for more information.
+
+## Using Bonsai SLEAP modules
+
+Once you have Bonsai installed with the required packages, you should be able to open the Bonsai application.
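+Before building the workflow, it can help to sanity-check that the export step produced the files Bonsai reads. This is a minimal sketch (the folder path is the example path from above; the exact set of config JSON files varies by model type):
+
+```python
+from pathlib import Path
+
+# Example export folder from the sleap-export command above.
+export_dir = Path("exported/model/path")
+
+# Bonsai needs the frozen TensorFlow graph plus the model metadata.
+expected = ["frozen_graph.pb", "info.json"]
+missing = [name for name in expected if not (export_dir / name).exists()]
+if missing:
+    raise FileNotFoundError(f"Incomplete export, missing: {missing}")
+```
+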
+The workflow must have a `FileCapture` source module, which can be found via the toolbox search in the workflow editor. In the module's `FileName` field, provide the path to the video that was used to train the SLEAP model.
+
+![Bonsai FileCapture module](../_static/bonsai-filecapture.jpg)
+
+### Top-down model
+The top-down model requires both the `PredictCentroids` and the `PredictPoses` modules.
+
+The `PredictCentroids` module predicts the centroids of detections. It has two fields: `ModelFileName` and `TrainingConfig`. The `TrainingConfig` field expects the path to the training config JSON file for the centroid model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.
+
+![Bonsai PredictCentroids module](../_static/bonsai-predictcentroids.jpg)
+
+The `PredictPoses` module predicts the pose instances for the detections. Like the `PredictCentroids` module, it has two fields: `ModelFileName` and `TrainingConfig`. The `TrainingConfig` field expects the path to the training config JSON file for the centered instance model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.
+
+![Bonsai PredictPoses module](../_static/bonsai-predictposes.jpg)
+
+### Top-down-id model
+The `PredictPoseIdentities` module predicts the instances with identities. It also has the `ModelFileName` and `TrainingConfig` fields. The `TrainingConfig` field expects the path to the training config JSON file for the top-down-id model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.
+
+![Bonsai PredictPoseIdentities module](../_static/bonsai-predictposeidentities.jpg)
+
+### Single instance model
+The `PredictSinglePose` module predicts the poses for single instance models. It likewise has the `ModelFileName` and `TrainingConfig` fields. The `TrainingConfig` field expects the path to the training config JSON file for the single instance model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.
+
+### Connecting the modules
+Right-click the `FileCapture` module and select **Create Connection**, then click the required SLEAP module to complete the connection.
+
+![Bonsai module connection](../_static/bonsai-connection.jpg)
+
+Once connected, the workflow in Bonsai will look something like the following:
+
+![Bonsai.SLEAP workflow](../_static/bonsai-workflow.jpg)
+
+Now you can click the green start button to run the workflow, and add more modules to analyze and visualize the results in Bonsai.
+
+For more documentation on the various modules and workflows, please refer to the [official Bonsai docs](https://bonsai-rx.org/docs/articles/editor.html).
diff --git a/docs/guides/index.md b/docs/guides/index.md
index 7eb55b2b2..6d773d9de 100644
--- a/docs/guides/index.md
+++ b/docs/guides/index.md
@@ -30,6 +30,10 @@
 {ref}`remote-inference` when you trained models and you want to run inference on a different machine using a **command-line interface**.
 
+## SLEAP with Bonsai
+
+{ref}`bonsai` when you want to run a trained SLEAP model in Bonsai and visualize the predicted poses, centroids, and identities for further analysis.
+
 ```{toctree}
 :hidden: true
 :maxdepth: 2
@@ -44,4 +48,5 @@ proofreading
 colab
 custom-training
 remote
+bonsai
 ```
diff --git a/docs/installation.md b/docs/installation.md
index 4799a0893..d926c724a 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -27,7 +27,7 @@ local:
 Installation requires entering commands in a terminal. To open one:
 
 ````{tabs}
 ```{tab} Windows
-   Open the *Start menu* and search for the *Anaconda Prompt* (if using Miniconda) or the *Command Prompt* if not.
+   Open the *Start menu* and search for the *Anaconda Prompt* (if using Miniconda) or the *Command Prompt* if not.
    ```{note}
    On Windows, our personal preference is to use alternative terminal apps like [Cmder](https://cmder.net) or [Windows Terminal](https://aka.ms/terminal).
    ```
@@ -66,7 +66,6 @@ If you don't have a `conda` package manager installation, here are some quick in
 Miniforge is a minimal installer for conda that includes the `conda` package manager and is maintained by the [conda-forge](https://conda-forge.org) community. The only difference between Miniforge and Miniconda is that Miniforge uses the `conda-forge` channel by default, which provides a much wider selection of community-maintained packages.
-
 ````{tabs}
 ```{group-tab} Windows
 Open a new PowerShell terminal (does not need to be admin) and enter:
@@ -135,20 +134,20 @@ This is a minimal installer for conda that includes the `conda` package manager
 See the [Miniconda website](https://docs.anaconda.com/free/miniconda/) for up-to-date installation instructions if the above instructions don't work for your system.
-
 (installation-methods)=
+
 ## Installation methods
 
 SLEAP can be installed three different ways: via {ref}`conda package`, {ref}`conda from source`, or {ref}`pip package`. Select one of the methods below to install SLEAP. We recommend {ref}`conda package`.
 
-````{tabs}
+`````{tabs}
 ```{tab} conda package
 **This is the recommended installation method**.
 ````{tabs}
 ```{group-tab} Windows and Linux
 ```bash
 conda create -y -n sleap -c conda-forge -c nvidia -c sleap -c anaconda sleap=1.4.1a2
-  ```
+  ```
 ```{note}
 - This comes with CUDA to enable GPU support. All you need is to have an NVIDIA GPU and [updated drivers](https://nvidia.com/drivers).
 - If you already have CUDA installed on your system, this will not conflict with it.
@@ -222,7 +221,7 @@ SLEAP can be installed three different ways: via {ref}`conda package bool: return super().ask(context, params)
+class DeleteFrameLimitPredictions(InstanceDeleteCommand):
+    @staticmethod
+    def get_frame_instance_list(context: CommandContext, params: Dict):
+        """Called from the parent `InstanceDeleteCommand.ask` method.
+
+        Returns:
+            List of instances to be deleted.
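+            Instances outside the given frame range are selected for
+            deletion; the dialog's 1-based frame numbers are converted
+            to 0-based frame indices for the comparison.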
+ """ + instances = [] + # Select the instances to be deleted + for lf in context.labels.labeled_frames: + if lf.frame_idx < (params["min_frame_idx"] - 1) or lf.frame_idx > ( + params["max_frame_idx"] - 1 + ): + instances.extend([(lf, inst) for inst in lf.instances]) + return instances + + @classmethod + def ask(cls, context: CommandContext, params: Dict) -> bool: + current_video = context.state["video"] + dialog = FrameRangeDialog( + title="Delete Instances in Frame Range...", max_frame_idx=len(current_video) + ) + results = dialog.get_results() + if results: + params["min_frame_idx"] = results["min_frame_idx"] + params["max_frame_idx"] = results["max_frame_idx"] + return super().ask(context, params) + + class TransposeInstances(EditCommand): topics = [UpdateTopic.project_instances, UpdateTopic.tracks] diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index f68dc0180..721bdc321 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -15,20 +15,17 @@ """ -from qtpy import QtCore, QtWidgets, QtGui - -import numpy as np import os - from operator import itemgetter +from pathlib import Path +from typing import Any, Callable, List, Optional -from typing import Any, Callable, Dict, List, Optional, Type +import numpy as np +from qtpy import QtCore, QtGui, QtWidgets -from sleap.gui.state import GuiState from sleap.gui.commands import CommandContext -from sleap.gui.color import ColorManager -from sleap.io.dataset import Labels -from sleap.instance import LabeledFrame, Instance +from sleap.gui.state import GuiState +from sleap.instance import LabeledFrame from sleap.skeleton import Skeleton @@ -386,10 +383,25 @@ def getSelectedRowItem(self) -> Any: class VideosTableModel(GenericTableModel): - properties = ("filename", "frames", "height", "width", "channels") - - def item_to_data(self, obj, item): - return {key: getattr(item, key) for key in self.properties} + properties = ( + "name", + "filepath", + "frames", + "height", + "width", + "channels", + ) + + def item_to_data(self, obj, item: "Video"): + data = {} + for property in self.properties: + if property == "name": + data[property] = Path(item.filename).name + elif property == "filepath": + data[property] = str(Path(item.filename).parent) + else: + data[property] = getattr(item, property) + return data class SkeletonNodesTableModel(GenericTableModel): diff --git a/sleap/gui/dialogs/frame_range.py b/sleap/gui/dialogs/frame_range.py new file mode 100644 index 000000000..7165dd939 --- /dev/null +++ b/sleap/gui/dialogs/frame_range.py @@ -0,0 +1,42 @@ +"""Frame range dialog.""" +from qtpy import QtWidgets +from sleap.gui.dialogs.formbuilder import FormBuilderModalDialog +from typing import Optional + + +class FrameRangeDialog(FormBuilderModalDialog): + def __init__(self, max_frame_idx: Optional[int] = None, title: str = "Frame Range"): + + super().__init__(form_name="frame_range_form") + min_frame_idx_field = self.form_widget.fields["min_frame_idx"] + max_frame_idx_field = self.form_widget.fields["max_frame_idx"] + + if max_frame_idx is not None: + min_frame_idx_field.setRange(1, max_frame_idx) + min_frame_idx_field.setValue(1) + + max_frame_idx_field.setRange(1, max_frame_idx) + max_frame_idx_field.setValue(max_frame_idx) + + min_frame_idx_field.valueChanged.connect(self._update_max_frame_range) + max_frame_idx_field.valueChanged.connect(self._update_min_frame_range) + + self.setWindowTitle(title) + + def _update_max_frame_range(self, value): + min_frame_idx_field = self.form_widget.fields["min_frame_idx"] + max_frame_idx_field 
+
+        max_frame_idx_field.setRange(value, max_frame_idx_field.maximum())
+
+    def _update_min_frame_range(self, value):
+        min_frame_idx_field = self.form_widget.fields["min_frame_idx"]
+
+        min_frame_idx_field.setRange(min_frame_idx_field.minimum(), value)
+
+
+if __name__ == "__main__":
+    app = QtWidgets.QApplication([])
+    dialog = FrameRangeDialog(max_frame_idx=100)
+    print(dialog.get_results())
diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py
index 2c2617036..bc26d826c 100644
--- a/sleap/gui/learning/dialog.py
+++ b/sleap/gui/learning/dialog.py
@@ -637,6 +637,20 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen
         )
         return items_for_inference
 
+    def _validate_id_model(self) -> bool:
+        """Make sure we have instances with tracks set for ID models."""
+        if not self.labels.tracks:
+            return False
+
+        found_tracks = False
+        for inst in self.labels.instances():
+            if type(inst) is sleap.Instance and inst.track is not None:
+                found_tracks = True
+                break
+
+        return found_tracks
+
     def _validate_pipeline(self):
         can_run = True
         message = ""
@@ -655,6 +669,15 @@ def _validate_pipeline(self):
                 f"({', '.join(untrained)})."
             )
 
+        # Make sure we have instances with tracks set for ID models.
+        if self.mode == "training" and self.current_pipeline in (
+            "top-down-id",
+            "bottom-up-id",
+        ):
+            can_run = self._validate_id_model()
+            if not can_run:
+                message = "Cannot run ID model training without tracks."
+
         # Make sure skeleton will be valid for bottom-up inference.
         if self.mode == "training" and self.current_pipeline == "bottom-up":
             skeleton = self.labels.skeletons[0]
diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py
index 949703020..08ee5bf36 100644
--- a/sleap/gui/widgets/video.py
+++ b/sleap/gui/widgets/video.py
@@ -62,6 +62,7 @@
     QShortcut,
     QVBoxLayout,
     QWidget,
+    QPinchGesture,
 )
 
 import sleap
@@ -823,6 +824,8 @@ def __init__(self, state=None, player=None, *args, **kwargs):
         # Set icon as default background.
         self.setImage(QImage(sleap.util.get_package_file("gui/background.png")))
 
+        self.grabGesture(Qt.GestureType.PinchGesture)
+
     def dragEnterEvent(self, event):
         if self.parentWidget():
             self.parentWidget().dragEnterEvent(event)
@@ -1189,6 +1192,23 @@ def keyReleaseEvent(self, event):
         """Custom event hander, disables default QGraphicsView behavior."""
         event.ignore()  # Kicks the event up to parent
 
+    def event(self, event):
+        if event.type() == QtCore.QEvent.Gesture:
+            return self.handleGestureEvent(event)
+        return super().event(event)
+
+    def handleGestureEvent(self, event):
+        gesture = event.gesture(Qt.GestureType.PinchGesture)
+        if gesture:
+            self.handlePinchGesture(gesture)
+        return True
+
+    def handlePinchGesture(self, gesture: QPinchGesture):
+        if gesture.state() == Qt.GestureState.GestureUpdated:
+            factor = gesture.scaleFactor()
+            self.zoomFactor = max(factor * self.zoomFactor, 1)
+            self.updateViewer()
+
 
 class QtNodeLabel(QGraphicsTextItem):
     """
@@ -1570,7 +1590,6 @@ def mousePressEvent(self, event):
 
     def mouseMoveEvent(self, event):
         """Custom event handler for mouse move."""
-        # print(event)
         if self.dragParent:
             self.parentObject().mouseMoveEvent(event)
         else:
@@ -1581,7 +1600,6 @@ def mouseMoveEvent(self, event):
 
     def mouseReleaseEvent(self, event):
         """Custom event handler for mouse release."""
-        # print(event)
         self.unsetCursor()
         if self.dragParent:
             self.parentObject().mouseReleaseEvent(event)
@@ -1610,6 +1628,10 @@ def mouseDoubleClickEvent(self, event: QMouseEvent):
         view = scene.views()[0]
         view.instanceDoubleClicked.emit(self.parentObject().instance, event)
 
+    def hoverEnterEvent(self, event):
+        """Custom event handler for mouse hover enter."""
+        return super().hoverEnterEvent(event)
+
 
 class QtEdge(QGraphicsPolygonItem):
     """
@@ -1809,6 +1831,7 @@ def __init__(
         self.labels = {}
         self.labels_shown = True
         self._selected = False
+        self._is_hovering = False
         self._bounding_rect = QRectF()
 
         # Show predicted instances behind non-predicted ones
@@ -1830,6 +1853,7 @@ def __init__(
         box_pen.setStyle(Qt.DashLine)
         box_pen.setCosmetic(True)
         self.box.setPen(box_pen)
+        self.setAcceptHoverEvents(True)
 
         # Add label for highlighted instance
         self.highlight_label = QtTextWithBackground(parent=self)
@@ -1991,7 +2015,12 @@ def updateBox(self, *args, **kwargs):
         select this instance.
         """
         # Only show box if instance is selected
-        op = 0.7 if self._selected else 0
+        op = 0
+        if self._selected:
+            op = 0.8
+        elif self._is_hovering:
+            op = 0.4
+
         self.box.setOpacity(op)
         # Update the position for the box
         rect = self.getPointsBoundingRect()
@@ -2085,6 +2114,16 @@ def paint(self, painter, option, widget=None):
         """Method required by Qt."""
         pass
 
+    def hoverEnterEvent(self, event):
+        self._is_hovering = True
+        self.updateBox()
+        return super().hoverEnterEvent(event)
+
+    def hoverLeaveEvent(self, event):
+        self._is_hovering = False
+        self.updateBox()
+        return super().hoverLeaveEvent(event)
+
 
 class VisibleBoundingBox(QtWidgets.QGraphicsRectItem):
     """QGraphicsRectItem for user instance bounding boxes.
@@ -2275,7 +2314,7 @@ def mouseReleaseEvent(self, event):
             self.parent.nodes[node_key].setPos(new_x, new_y)
 
             # Update the instance
-            self.parent.updatePoints(complete=True, user_change=True)
+            self.parent.updatePoints(complete=False, user_change=True)
 
         self.resizing = None
diff --git a/sleap/info/summary.py b/sleap/info/summary.py
index c6a6af60e..0cad1617e 100644
--- a/sleap/info/summary.py
+++ b/sleap/info/summary.py
@@ -21,7 +21,7 @@ class StatisticSeries:
     are frame index and value are some numerical value for the frame.
 
     Args:
-        labels: The :class:`Labels` for which to calculate series.
+        labels: The `Labels` for which to calculate series.
     """
 
     labels: Labels
@@ -41,7 +41,7 @@ def get_point_score_series(
         """Get series with statistic of point scores in each frame.
 
         Args:
-            video: The :class:`Video` for which to calculate statistic.
+            video: The `Video` for which to calculate statistic.
             reduction: name of function applied to scores:
                 * sum
                 * min
@@ -67,7 +67,7 @@ def get_instance_score_series(self, video, reduction="sum") -> Dict[int, float]:
         """Get series with statistic of instance scores in each frame.
 
         Args:
-            video: The :class:`Video` for which to calculate statistic.
+            video: The `Video` for which to calculate statistic.
             reduction: name of function applied to scores:
                 * sum
                 * min
@@ -93,7 +93,7 @@ def get_point_displacement_series(self, video, reduction="sum") -> Dict[int, flo
         same track) from the closest earlier labeled frame.
 
         Args:
-            video: The :class:`Video` for which to calculate statistic.
+            video: The `Video` for which to calculate statistic.
             reduction: name of function applied to point scores:
                 * sum
                 * mean
@@ -121,7 +121,7 @@ def get_primary_point_displacement_series(
         Get sum of displacement for single node of each instance per frame.
 
         Args:
-            video: The :class:`Video` for which to calculate statistic.
+            video: The `Video` for which to calculate statistic.
             reduction: name of function applied to point scores:
                 * sum
                 * mean
@@ -226,7 +226,7 @@ def _calculate_frame_velocity(
         Calculate total point displacement between two given frames.
 
         Args:
-            lf: The :class:`LabeledFrame` for which we want velocity
+            lf: The `LabeledFrame` for which we want velocity
             last_lf: The frame from which to calculate displacement.
             reduce_function: Numpy function (e.g., np.sum, np.nanmean)
                 is applied to *point* displacement, and then those
@@ -246,3 +246,35 @@ def _calculate_frame_velocity(
                 inst_dist = reduce_function(point_dist)
                 val += inst_dist if not np.isnan(inst_dist) else 0
         return val
+
+    def get_tracking_score_series(
+        self, video: Video, reduction: str = "min"
+    ) -> Dict[int, float]:
+        """Get series with statistic of tracking scores in each frame.
+
+        Args:
+            video: The `Video` for which to calculate statistic.
+            reduction: name of function applied to scores:
+                * mean
+                * min
+
+        Returns:
+            The series dictionary (see class docs for details)
+        """
+        reduce_fn = {
+            "min": np.nanmin,
+            "mean": np.nanmean,
+        }[reduction]
+
+        series = dict()
+
+        for lf in self.labels.find(video):
+            vals = [
+                inst.tracking_score for inst in lf if hasattr(inst, "tracking_score")
+            ]
+            if vals:
+                val = reduce_fn(vals)
+                if not np.isnan(val):
+                    series[lf.frame_idx] = val
+
+        return series
diff --git a/sleap/instance.py b/sleap/instance.py
index 08a5c6ae6..382ececf2 100644
--- a/sleap/instance.py
+++ b/sleap/instance.py
@@ -1049,7 +1049,9 @@ def scores(self) -> np.ndarray:
         return self.points_and_scores_array[:, 2]
 
     @classmethod
-    def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance":
+    def from_instance(
+        cls, instance: Instance, score: float, tracking_score: float = 0.0
+    ) -> "PredictedInstance":
         """Create a `PredictedInstance` from an `Instance`.
 
         The fields are copied in a shallow manner with the exception of points. For each
@@ -1059,6 +1061,7 @@ def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance":
         Args:
             instance: The `Instance` object to shallow copy data from.
             score: The score for this instance.
+            tracking_score: The tracking score for this instance.
 
         Returns:
             A `PredictedInstance` for the given `Instance`.
@@ -1070,6 +1073,7 @@ def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance":
         )
         kw_args["points"] = PredictedPointArray.from_array(instance._points)
         kw_args["score"] = score
+        kw_args["tracking_score"] = tracking_score
        return cls(**kw_args)
 
     @classmethod
@@ -1080,6 +1084,7 @@ def from_arrays(
         instance_score: float,
         skeleton: Skeleton,
         track: Optional[Track] = None,
+        tracking_score: float = 0.0,
     ) -> "PredictedInstance":
         """Create a predicted instance from data arrays.
 
@@ -1094,6 +1099,7 @@ def from_arrays(
             skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the
                 predicted instance.
             track: Optional `sleap.Track` to associate with the instance.
+            tracking_score: Optional float representing the track matching score.
 
         Returns:
             A new `PredictedInstance`.
@@ -1114,6 +1120,7 @@ def from_arrays(
             skeleton=skeleton,
             score=instance_score,
             track=track,
+            tracking_score=tracking_score,
         )
 
     @classmethod
@@ -1124,6 +1131,7 @@ def from_pointsarray(
         instance_score: float,
         skeleton: Skeleton,
         track: Optional[Track] = None,
+        tracking_score: float = 0.0,
     ) -> "PredictedInstance":
         """Create a predicted instance from data arrays.
 
@@ -1138,12 +1146,18 @@ def from_pointsarray(
             skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the
                 predicted instance.
             track: Optional `sleap.Track` to associate with the instance.
+            tracking_score: Optional float representing the track matching score.
 
         Returns:
             A new `PredictedInstance`.
         """
         return cls.from_arrays(
-            points, point_confidences, instance_score, skeleton, track=track
+            points,
+            point_confidences,
+            instance_score,
+            skeleton,
+            track=track,
+            tracking_score=tracking_score,
         )
 
     @classmethod
@@ -1154,6 +1168,7 @@ def from_numpy(
         instance_score: float,
         skeleton: Skeleton,
         track: Optional[Track] = None,
+        tracking_score: float = 0.0,
     ) -> "PredictedInstance":
         """Create a predicted instance from data arrays.
 
@@ -1168,12 +1183,18 @@ def from_numpy(
             skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the
                 predicted instance.
             track: Optional `sleap.Track` to associate with the instance.
+            tracking_score: Optional float representing the track matching score.
 
         Returns:
             A new `PredictedInstance`.
         """
         return cls.from_arrays(
-            points, point_confidences, instance_score, skeleton, track=track
+            points,
+            point_confidences,
+            instance_score,
+            skeleton,
+            track=track,
+            tracking_score=tracking_score,
        )
diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py
index 67d4110b5..c27382e52 100644
--- a/sleap/nn/inference.py
+++ b/sleap/nn/inference.py
@@ -3780,9 +3780,10 @@ def _object_builder():
                     PredictedInstance.from_numpy(
                         points=pts,
                         point_confidences=confs,
-                        instance_score=np.nanmean(score),
+                        instance_score=np.nanmean(confs),
                         skeleton=skeleton,
                         track=track,
+                        tracking_score=np.nanmean(score),
                     )
                 )
 
@@ -4454,18 +4455,27 @@ def _object_builder():
                     break
 
             # Loop over frames.
-            for image, video_ind, frame_ind, points, confidences, scores in zip(
+            for (
+                image,
+                video_ind,
+                frame_ind,
+                centroid_vals,
+                points,
+                confidences,
+                scores,
+            ) in zip(
                 ex["image"],
                 ex["video_ind"],
                 ex["frame_ind"],
+                ex["centroid_vals"],
                 ex["instance_peaks"],
                 ex["instance_peak_vals"],
                 ex["instance_scores"],
             ):
                 # Loop over instances.
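+                # Note: below, the centroid confidence becomes the instance
+                # score, while the ID model's matching score is kept
+                # separately as the tracking score.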
                 predicted_instances = []
-                for i, (pts, confs, score) in enumerate(
-                    zip(points, confidences, scores)
+                for i, (pts, centroid_val, confs, score) in enumerate(
+                    zip(points, centroid_vals, confidences, scores)
                 ):
                     if np.isnan(pts).all():
                         continue
@@ -4476,9 +4486,10 @@ def _object_builder():
                         PredictedInstance.from_numpy(
                             points=pts,
                             point_confidences=confs,
-                            instance_score=np.nanmean(score),
+                            instance_score=centroid_val,
                             skeleton=skeleton,
                             track=track,
+                            tracking_score=score,
                         )
                     )
diff --git a/sleap/prefs.py b/sleap/prefs.py
index 8790f1d3f..e043afc44 100644
--- a/sleap/prefs.py
+++ b/sleap/prefs.py
@@ -28,6 +28,8 @@ class Preferences(object):
         "node label size": 12,
         "show non-visible nodes": True,
         "share usage data": True,
+        "node marker sizes": (1, 2, 3, 4, 6, 8, 12),
+        "node label sizes": (6, 9, 12, 18, 24, 36),
     }
     _filename = "preferences.yaml"
 
@@ -43,14 +45,14 @@ def load_(self):
         """Load preferences from file (regardless of whether loaded already)."""
         try:
             self._prefs = util.get_config_yaml(self._filename)
-            if not hasattr(self._prefs, "get"):
-                self._prefs = self._defaults
-            else:
-                self._prefs["trail length"] = self._prefs.get(
-                    "trail length", self._defaults["trail length"]
-                )
         except FileNotFoundError:
-            self._prefs = self._defaults
+            pass
+
+        self._prefs = self._prefs or {}
+
+        for k, v in self._defaults.items():
+            if k not in self._prefs:
+                self._prefs[k] = v
 
     def save(self):
         """Save preferences to file."""
diff --git a/tests/data/tracks/clip.predictions.slp b/tests/data/tracks/clip.predictions.slp
new file mode 100644
index 000000000..652e21302
Binary files /dev/null and b/tests/data/tracks/clip.predictions.slp differ
diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py
index ec5dfbc29..c6507caec 100644
--- a/tests/fixtures/datasets.py
+++ b/tests/fixtures/datasets.py
@@ -97,6 +97,20 @@ def min_tracks_2node_labels():
     )
 
 
+@pytest.fixture
+def min_tracks_2node_predictions():
+    """
+    Generated with:
+    ```
+    sleap-track -m "tests/data/models/min_tracks_2node.UNet.bottomup_multiclass" "tests/data/tracks/clip.mp4"
+    ```
+    """
+    return Labels.load_file(
+        "tests/data/tracks/clip.predictions.slp",
+        video_search=["tests/data/tracks/clip.mp4"],
+    )
+
+
 @pytest.fixture
 def min_tracks_13node_labels():
     return Labels.load_file(
diff --git a/tests/gui/learning/test_dialog.py b/tests/gui/learning/test_dialog.py
index 3d77c891f..389bb48a3 100644
--- a/tests/gui/learning/test_dialog.py
+++ b/tests/gui/learning/test_dialog.py
@@ -7,6 +7,7 @@
 import pytest
 from qtpy import QtWidgets
 
+import sleap
 from sleap.gui.learning.dialog import LearningDialog, TrainingEditorWidget
 from sleap.gui.learning.configs import (
     TrainingConfigFilesWidget,
@@ -429,3 +430,22 @@ def test_immutablilty_of_trained_config_info(
     # saving multiple configs from one config info.
     ld.save(output_dir=tmpdir)
     ld.save(output_dir=tmpdir)
+
+
+def test_validate_id_model(qtbot, min_labels_slp, min_labels_slp_path):
+    app = MainWindow(no_usage_data=True)
+    ld = LearningDialog(
+        mode="training",
+        labels_filename=Path(min_labels_slp_path),
+        labels=min_labels_slp,
+    )
+    assert not ld._validate_id_model()
+
+    # Add a track, but don't assign it to any instances.
+    new_track = sleap.Track(name="new_track")
+    min_labels_slp.tracks.append(new_track)
+    assert not ld._validate_id_model()
+
+    # Assign the track to an instance.
+    min_labels_slp[0][0].track = new_track
+    assert ld._validate_id_model()
diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py
index ffd382ab1..e19e00236 100644
--- a/tests/gui/test_commands.py
+++ b/tests/gui/test_commands.py
@@ -20,6 +20,7 @@
     ReplaceVideo,
     OpenSkeleton,
     SaveProjectAs,
+    DeleteFrameLimitPredictions,
     get_new_version_filename,
 )
 from sleap.instance import Instance, LabeledFrame
@@ -851,6 +852,26 @@ def load_and_assert_changes(new_video_path: Path):
         shutil.move(new_video_path, expected_video_path)
 
 
+def test_DeleteFrameLimitPredictions(
+    centered_pair_predictions: Labels, centered_pair_vid: Video
+):
+    """Test deleting instances beyond a certain frame limit."""
+    labels = centered_pair_predictions
+
+    # Set up the command context.
+    context = CommandContext.from_labels(labels)
+    context.state["video"] = centered_pair_vid
+
+    # Set up the params for the command.
+    params = {"min_frame_idx": 900, "max_frame_idx": 1000}
+
+    instances_to_delete = DeleteFrameLimitPredictions.get_frame_instance_list(
+        context, params
+    )
+
+    assert len(instances_to_delete) == 2070
+
+
 @pytest.mark.parametrize("export_extension", [".json.zip", ".slp"])
 def test_exportLabelsPackage(export_extension, centered_pair_labels: Labels, tmpdir):
     def assert_loaded_package_similar(path_to_pkg: Path, sugg=False, pred=False):
diff --git a/tests/info/test_summary.py b/tests/info/test_summary.py
index 2cf76c166..672d97e63 100644
--- a/tests/info/test_summary.py
+++ b/tests/info/test_summary.py
@@ -37,6 +37,19 @@ def test_frame_statistics(simple_predictions):
     x = stats.get_point_displacement_series(video, "max")
     assert len(x) == 2
-    assert len(x) == 2
     assert x[0] == 0
     assert x[1] == 18.0
+
+
+def test_get_tracking_score_series(min_tracks_2node_predictions):
+    stats = StatisticSeries(min_tracks_2node_predictions)
+
+    x = stats.get_tracking_score_series(min_tracks_2node_predictions.video, "min")
+    assert len(x) == 1500
+    assert x[0] == 0.9999966621398926
+    assert x[1000] == 0.9998022317886353
+
+    x = stats.get_tracking_score_series(min_tracks_2node_predictions.video, "mean")
+    assert len(x) == 1500
+    assert x[0] == 0.9999983310699463
+    assert x[1000] == 0.9999011158943176
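
Taken together, the changes above thread a `tracking_score` through `PredictedInstance` and expose it via the new `StatisticSeries.get_tracking_score_series` method. A minimal usage sketch (the predictions path is a placeholder, and a single-video predictions file is assumed):

```python
from sleap import Labels
from sleap.info.summary import StatisticSeries

# Placeholder path: any .slp predictions file produced with tracking enabled.
labels = Labels.load_file("predictions.slp")

# Per-frame minimum tracking score, e.g. for finding frames worth proofreading.
stats = StatisticSeries(labels)
series = stats.get_tracking_score_series(labels.video, reduction="min")

# Frame indices with the ten lowest tracking scores.
worst_frames = sorted(series, key=series.get)[:10]
print(worst_frames)
```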