Point clouds (#100)
* fix holes and performance

* expose confidence maps

* start for point cloud tutorial

* new PointCloud type

* finish pointcloud tutorial

* switch to open3d tensor api

* fix mypy issues

* disable testing for point cloud notebooks

* set timeout back to 300

* tests for crop_point_cloud()

* update rerun point cloud logging

* rename pointcloud to point_cloud

* point cloud spelling

* fix notebooks

* update changelog
Victorlouisdg authored Jan 18, 2024
1 parent 1a3425e commit f4675a1
Showing 21 changed files with 1,525 additions and 73 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
@@ -11,15 +11,26 @@ This project uses a [CalVer](https://calver.org/) versioning scheme with monthly
- internal dependencies are now listed as regular dependencies in the `setup.py` file to overcome issues and make the installation process less complicated. This implies you need to install packages according to their dependencies and can no longer use the `external` tag as in `pip install airo-typing[external]`.
see [issue #91](https://github.com/airo-ugent/airo-mono/issues/91) and
[PR](https://github.com/airo-ugent/airo-mono/pull/108) for more details.
- `PointCloud` dataclass replaces the `ColoredPointCloudType` to support point cloud attributes


### Added
- `PointCloud` dataclass as the main data structure for point clouds in airo-mono
- Notebooks to get started with point clouds, checking performance and logging to rerun
- Functions to crop point clouds and filter points with a mask (e.g. low-confidence points)
- Functions to convert from our numpy-based dataclass to and from open3d point clouds
- `BoundingBox3DType`



### Changed
- dropped support for python 3.8 and added 3.11 to the testing matrix [#103](https://github.com/airo-ugent/airo-mono/issues/103)

### Fixed
- Fixed bug in `get_colored_point_cloud()` that removed some points, see issue #25.

### Removed
- `ColoredPointCloudType`

## 2024.1.0

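The changelog entries above describe a `PointCloud` dataclass with points, optional colors and optional extra attributes. The definition itself is not part of the hunks shown here, so the following is only a minimal sketch inferred from how the class is used elsewhere in this commit (`PointCloud(points, colors, attributes)`, `.points`, `.colors`, `.attributes`). The field types, defaults and the `confidence` attribute name are assumptions.

```python
from dataclasses import dataclass
from typing import Dict, Optional

import numpy as np


@dataclass
class PointCloud:
    """Hypothetical stand-in for airo_typing.PointCloud, inferred from its usage in this commit."""

    points: np.ndarray  # (N, 3) positions
    colors: Optional[np.ndarray] = None  # (N, 3) colors, one row per point
    attributes: Optional[Dict[str, np.ndarray]] = None  # extra per-point arrays


# Two points with colors and one extra per-point attribute (name chosen for illustration).
cloud = PointCloud(
    points=np.array([[0.0, 0.0, 1.0], [0.1, 0.0, 1.2]], dtype=np.float32),
    colors=np.array([[255, 0, 0], [0, 255, 0]], dtype=np.uint8),
    attributes={"confidence": np.array([5.0, 80.0], dtype=np.float32)},
)
print(cloud.points.shape, cloud.colors.dtype, list(cloud.attributes))
```

Keeping colors and attributes as separate arrays, rather than packing everything into one Nx6 float array, is presumably what lets each array keep its own dtype.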
2 changes: 1 addition & 1 deletion README.md
@@ -9,7 +9,7 @@ Below is a short overview of the packages:

| Package | Description| owner |
|-------|-------|--------|
| `airo-camera-toolkit`|code for working with RGB(D) cameras, images and pointclouds |@tlpss|
| `airo-camera-toolkit`|code for working with RGB(D) cameras, images and point clouds |@tlpss|
|`airo-dataset-tools`| code for creating, loading and working with datasets| @Victorlouisdg|
| `airo-robots`| minimal interfaces for interacting with the controllers of robot arms and grippers| @tlpss|
| `airo-spatial-algebra`|code for working with SE3 poses |@tlpss|
1 change: 1 addition & 0 deletions airo-camera-toolkit/.gitignore
@@ -1,3 +1,4 @@
airo_camera_toolkit/image_transforms/tutorial_image.jpg
airo_camera_toolkit/calibration/saved_calibrations
**/calibration_20**/
notebooks/data
8 changes: 6 additions & 2 deletions airo-camera-toolkit/README.md
@@ -1,5 +1,5 @@
# airo-camera-toolkit
This package contains code for working with RGB(D) cameras, images and pointclouds.
This package contains code for working with RGB(D) cameras, images and point clouds.


Overview of the functionality and the structure:
@@ -9,6 +9,7 @@ airo-camera-toolkit/
├── cameras/ # actual camera drivers
├── image_transformations/ # reversible geometric 2D transforms
├── pinhole_operations/ # 2D-3D operations
├── point_clouds/ # conversions and operations
├── utils/ # a.o. annotation tool and converter
├── interfaces.py
└── cli.py
@@ -87,9 +88,12 @@ See [annotation_tool.md](./airo_camera_toolkit/annotation_tool.md) for usage ins
See the [README](./airo_camera_toolkit/image_transforms/README.md) in the `image_transforms` folder for more details.

## Real-time visualisation
For realtime visualisation of robotics data we strongly encourage using [rerun.io](https://www.rerun.io/) instead of manually hacking something together with opencv/pyqt/... No wrappers are needed here, just pip install the SDK. An example notebook to get to know this tool and its potential can be found [here](docs/rerun-zed-example.ipynb).
For realtime visualisation of robotics data we strongly encourage using [rerun.io](https://www.rerun.io/) instead of manually hacking something together with opencv/pyqt/... No wrappers are needed here, just pip install the SDK. An example notebook to get to know this tool and its potential can be found [here](notebooks/rerun-zed-tutorial.ipynb).
See this [README](./docs/rerun.md) for more details.

## Point clouds
See the tutorial notebook [here](notebooks/point_clouds_tutorial.ipynb) for an introduction.

## Multiprocessing
Camera processing can be computationally expensive.
If this is a problem for your application, see [multiprocess/README.md](./airo_camera_toolkit/cameras/multiprocess/README.md).
53 changes: 29 additions & 24 deletions airo-camera-toolkit/airo_camera_toolkit/cameras/zed/zed2i.py
@@ -20,18 +20,19 @@

import time

import cv2
import numpy as np
from airo_camera_toolkit.interfaces import StereoRGBDCamera
from airo_camera_toolkit.utils.image_converter import ImageConverter
from airo_typing import (
CameraIntrinsicsMatrixType,
CameraResolutionType,
ColoredPointCloudType,
HomogeneousMatrixType,
NumpyDepthMapType,
NumpyFloatImageType,
NumpyIntImageType,
OpenCVIntImageType,
PointCloud,
)


@@ -149,7 +150,10 @@ def __init__( # type: ignore[no-any-unimported]
self.image_matrix = sl.Mat()
self.depth_image_matrix = sl.Mat()
self.depth_matrix = sl.Mat()
self.pointcloud_matrix = sl.Mat()
self.point_cloud_matrix = sl.Mat()

self.confidence_matrix = sl.Mat()
self.confidence_map = None

@property
def resolution(self) -> CameraResolutionType:
@@ -188,7 +192,7 @@ def pose_of_right_view_in_left_view(self) -> HomogeneousMatrixType:

@property
def depth_enabled(self) -> bool:
"""Runtime parameter to enable/disable the depth & pointcloud computation. This speeds up the RGB image capture."""
"""Runtime parameter to enable/disable the depth & point cloud computation. Disabling it speeds up the RGB image capture."""
return self.runtime_params.enable_depth

@depth_enabled.setter
@@ -218,9 +222,10 @@ def _retrieve_rgb_image_as_int(self, view: str = StereoRGBDCamera.LEFT_RGB) -> N
else:
view = sl.VIEW.LEFT
self.camera.retrieve_image(self.image_matrix, view)
image: OpenCVIntImageType = self.image_matrix.get_data()
image = image[..., :3] # remove alpha channel
image = image[..., ::-1] # convert from BGR to RGB
image_bgra: OpenCVIntImageType = self.image_matrix.get_data()
# image = image[..., :3] # remove alpha channel
# image = image[..., ::-1] # convert from BGR to RGB
image = cv2.cvtColor(image_bgra, cv2.COLOR_BGRA2RGB)
return image

def _retrieve_depth_map(self) -> NumpyDepthMapType:
@@ -238,31 +243,31 @@ def _retrieve_depth_image(self) -> NumpyIntImageType:
image = image[..., :3]
return image

def get_colored_point_cloud(self) -> ColoredPointCloudType:
def _retrieve_colored_point_cloud(self) -> PointCloud:
assert self.depth_mode != self.NONE_DEPTH_MODE, "Cannot retrieve depth data if depth mode is NONE"
assert self.depth_enabled, "Cannot retrieve depth data if depth is disabled"

self._grab_images()
self.camera.retrieve_measure(self.pointcloud_matrix, sl.MEASURE.XYZRGBA)
self.camera.retrieve_measure(self.point_cloud_matrix, sl.MEASURE.XYZ)
# shape (width, height, 4) with the 4th dim being x,y,z,(rgba packed into float)
# can be nan,nan,nan, nan (no point in the pointcloud on this pixel)
# can be nan,nan,nan, nan (no point in the point cloud on this pixel)
# or x,y,z, nan (no color information on this pixel??)
# or x,y,z, value (color information on this pixel)

# filter out all that have nan in any of the positions of the 3rd dim
# and reshape to (width*height, 4)
point_cloud = self.pointcloud_matrix.get_data()
point_cloud = point_cloud[~np.isnan(point_cloud).any(axis=2), :]
point_cloud = self.point_cloud_matrix.get_data()
points = point_cloud[:, :, :3].reshape(-1, 3)
colors = self._retrieve_rgb_image_as_int().reshape(-1, 3)

return PointCloud(points, colors)

# unpack the colors, drop alpha channel and convert to 0-1 range
points = point_cloud[:, :3]
colors = point_cloud[:, 3]
rgba = np.ravel(colors).view(np.uint8).reshape(-1, 4)
rgb = rgba[:, :3]
rgb_float = rgb.astype(np.float32) / 255.0 # convert to 0-1 range
def _retrieve_confidence_map(self) -> NumpyFloatImageType:
self.camera.retrieve_measure(self.confidence_matrix, sl.MEASURE.CONFIDENCE)
return self.confidence_matrix.get_data() # single channel float32 image

colored_pointcloud = np.concatenate((points, rgb_float), axis=1)
return colored_pointcloud
def get_colored_point_cloud(self) -> PointCloud:
assert self.depth_mode != self.NONE_DEPTH_MODE, "Cannot retrieve depth data if depth mode is NONE"
assert self.depth_enabled, "Cannot retrieve depth data if depth is disabled"

self._grab_images()
return self._retrieve_colored_point_cloud()

@staticmethod
def list_camera_serial_numbers() -> List[str]:
@@ -297,7 +302,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
# test rgbd stereo camera

with Zed2i(Zed2i.RESOLUTION_2K, fps=15, depth_mode=Zed2i.PERFORMANCE_DEPTH_MODE) as zed:
print(zed.get_colored_point_cloud()[0]) # TODO: test the pointcloud more explicity?
print(zed.get_colored_point_cloud().points) # TODO: test the point cloud more explicitly?
manual_test_stereo_rgbd_camera(zed)

# profile rgb throughput, should be at 60FPS, i.e. 0.017s
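In the `zed2i.py` changes above, `_retrieve_colored_point_cloud()` no longer drops NaN points. It reshapes the XYZ measure and the RGB image to per-pixel rows, so row `i` of the points, the colors and a flattened confidence map all refer to the same pixel. The sketch below shows how a caller might then build a mask over those rows. It uses synthetic NumPy arrays instead of the ZED SDK, and the array shapes, the threshold of 50 and the "lower value means more confident" convention are assumptions to check against the ZED SDK documentation.

```python
import numpy as np

height, width = 2, 3

# Stand-in for point_cloud_matrix.get_data(): assumed (height, width, 4) float32,
# with NaN in the xyz channels where no depth could be estimated.
xyz = np.ones((height, width, 4), dtype=np.float32)
xyz[0, 1, :3] = np.nan

# Stand-ins for the RGB image and the single-channel confidence map.
rgb = np.zeros((height, width, 3), dtype=np.uint8)
confidence = np.array([[10, 20, 95], [5, 99, 30]], dtype=np.float32)

# Same reshapes as in the diff: one row per pixel, in row-major pixel order.
points = xyz[:, :, :3].reshape(-1, 3)
colors = rgb.reshape(-1, 3)

# Keep pixels that have a finite position and an acceptable confidence value.
has_position = np.isfinite(points).all(axis=1)
is_confident = confidence.reshape(-1) < 50.0  # assumption: lower means more confident
mask = has_position & is_confident

filtered_points = points[mask]
filtered_colors = colors[mask]
print(filtered_points.shape, filtered_colors.shape)
```

Because the arrays stay aligned with the image grid until a mask is applied, the same mask can be reused for any other per-pixel array.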
10 changes: 5 additions & 5 deletions airo-camera-toolkit/airo_camera_toolkit/interfaces.py
@@ -3,11 +3,11 @@
from airo_typing import (
CameraIntrinsicsMatrixType,
CameraResolutionType,
ColoredPointCloudType,
HomogeneousMatrixType,
NumpyDepthMapType,
NumpyFloatImageType,
NumpyIntImageType,
PointCloud,
)


@@ -103,13 +103,13 @@ def get_depth_image(self) -> NumpyIntImageType:
self._grab_images()
return self._retrieve_depth_image()

def get_colored_point_cloud(self) -> ColoredPointCloudType:
def get_colored_point_cloud(self) -> PointCloud:
"""Get the latest point cloud of the camera.
The point cloud contains 6D arrays of floats, that provide the estimated position in the camera frame
of points on the image plane (pixels). The last 3 floats are the corresponding RGB color (in the range [0, 1]).
The point cloud contains the estimated position in the camera frame of points on the image plane (pixels).
Each point also has a color associated with it, which is the color of the corresponding pixel in the RGB image.
Returns:
np.ndarray: Nx6 array containing PointCloud with color information. Each entry is (x,y,z,r,g,b)
PointCloud: the points (= positions) and colors
"""
# TODO: offer a base implementation that uses the depth map and the rgb image to construct this pointcloud?
raise NotImplementedError
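The TODO in `get_colored_point_cloud()` above mentions a base implementation built from the depth map and the RGB image. This commit does not include one. The following is only a sketch of the standard pinhole unprojection such an implementation could use, assuming a 3x3 intrinsics matrix with `fx`, `fy` on the diagonal and `cx`, `cy` in the last column, and a depth map holding z-depth per pixel. The function name `unproject_depth_map` is hypothetical.

```python
import numpy as np


def unproject_depth_map(depth_map: np.ndarray, intrinsics: np.ndarray) -> np.ndarray:
    """Hypothetical helper: turn an (H, W) z-depth map into (H*W, 3) camera-frame points."""
    height, width = depth_map.shape
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]

    # u runs along image columns, v along image rows; both have shape (H, W).
    u, v = np.meshgrid(np.arange(width), np.arange(height))

    z = depth_map
    x = (u - cx) * z / fx
    y = (v - cy) * z / fy
    return np.stack((x, y, z), axis=-1).reshape(-1, 3)


intrinsics = np.array([[100.0, 0.0, 1.0], [0.0, 100.0, 1.0], [0.0, 0.0, 1.0]])
depth_map = np.full((3, 3), 2.0)
rgb_image = np.zeros((3, 3, 3), dtype=np.uint8)

points = unproject_depth_map(depth_map, intrinsics)
colors = rgb_image.reshape(-1, 3)  # same row-major pixel order as the points
print(points.shape, colors.shape)
```

The points and colors could then be wrapped as `PointCloud(points, colors)`, matching the return type in the new signature.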
@@ -128,7 +128,7 @@ def extract_depth_from_depthmap_heuristic(
This function takes the percentile of a region around the specified point and assumes we are interested in the nearest object present.
This is not always true (think about the backside of a box looking under a 45 degree angle) but it serves as a good proxy. The more confident
you are of your keypoints and the better the heatmaps are, the lower you could set the mask size and percentile. If you are very, very confident
you could directly take the pointcloud as well instead of manually querying the heatmap, but I find that they are more noisy.
you could directly take the point cloud as well instead of manually querying the heatmap, but I find that they are more noisy.
Also note that this function assumes there are no negative infinity values (no objects closer than 30cm!)
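The docstring above describes taking a percentile of the depth values in a region around a pixel, on the assumption that the nearest object is the one of interest. The body of `extract_depth_from_depthmap_heuristic` is not shown in this hunk, so the sketch below is not that function. It only illustrates the idea with a hypothetical helper, hypothetical parameter names and defaults, and an added assumption that non-finite depth values should be ignored.

```python
import numpy as np


def depth_percentile_around_pixel(
    depth_map: np.ndarray, u: int, v: int, mask_size: int = 11, percentile: float = 0.05
) -> float:
    """Hypothetical helper: low percentile of the finite depths in a square window around (u, v)."""
    half = mask_size // 2
    height, width = depth_map.shape

    # Clip the window to the image borders.
    region = depth_map[max(v - half, 0) : min(v + half + 1, height), max(u - half, 0) : min(u + half + 1, width)]

    finite = region[np.isfinite(region)]
    if finite.size == 0:
        raise ValueError("no valid depth values around the given pixel")
    return float(np.quantile(finite, percentile))


depth_map = np.full((5, 5), 2.0)
depth_map[2, 2] = 1.0  # a nearer object at the centre pixel
depth_map[0, 0] = np.nan  # a pixel without depth

nearest = depth_percentile_around_pixel(depth_map, u=2, v=2, mask_size=3, percentile=0.0)
median = depth_percentile_around_pixel(depth_map, u=2, v=2, mask_size=3, percentile=0.5)
print(nearest, median)
```

A smaller window and lower percentile trust the keypoint more, which matches the trade-off the docstring describes.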
Empty file.
@@ -0,0 +1,52 @@
from typing import Any

import open3d as o3d
import open3d.core as o3c
from airo_typing import PointCloud


def point_cloud_to_open3d(point_cloud: PointCloud) -> Any: # TODO: change Any back to o3d.t.geometry.PointCloud
"""Converts a PointCloud dataclass object to an open3d tensor point cloud.
Note that the memory buffers of the underlying numpy arrays are shared between the two.
Args:
point_cloud: the point cloud to convert
Returns:
pcd: the open3d tensor point cloud
"""
positions = o3c.Tensor.from_numpy(point_cloud.points)

map_to_tensors = {
"positions": positions,
}

if point_cloud.colors is not None:
colors = o3c.Tensor.from_numpy(point_cloud.colors)
map_to_tensors["colors"] = colors

if point_cloud.attributes is not None:
for attribute_name, array in point_cloud.attributes.items():
map_to_tensors[attribute_name] = o3c.Tensor.from_numpy(array)

pcd = o3d.t.geometry.PointCloud(map_to_tensors)
return pcd


def open3d_to_point_cloud(pcd: Any) -> PointCloud: # TODO: change Any back to o3d.t.geometry.PointCloud
"""Converts an open3d point cloud to a PointCloud dataclass object.
Note that the memory buffers of the underlying numpy arrays are shared between the two.
Args:
pcd: the open3d tensor point cloud
"""
points = pcd.point.positions.numpy()
colors = pcd.point.colors.numpy() if "colors" in pcd.point else None

attributes = {}
for attribute_name, array in pcd.point.items():
if attribute_name in ["positions", "colors"]:
continue
attributes[attribute_name] = array.numpy()

return PointCloud(points, colors, attributes or None)  # also pass the collected attributes; None when there are none
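Both conversion docstrings above note that the memory buffers are shared, so an in-place edit on one side is visible on the other. The sketch below illustrates that aliasing with plain NumPy only. It does not call open3d, and `np.asarray` is used here merely as a stand-in for a zero-copy conversion.

```python
import numpy as np

points = np.zeros((4, 3), dtype=np.float32)

shared = np.asarray(points)  # no copy: refers to the same buffer
independent = points.copy()  # explicit copy: safe to modify separately

shared[0, 0] = 1.0  # in-place write through the alias

print(np.shares_memory(points, shared), np.shares_memory(points, independent))
print(points[0, 0], independent[0, 0])
```

When the original arrays must stay untouched, copying them before or after conversion avoids surprises, at the cost of extra memory.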
63 changes: 63 additions & 0 deletions airo-camera-toolkit/airo_camera_toolkit/point_clouds/operations.py
@@ -0,0 +1,63 @@
from typing import Any

import numpy as np
from airo_typing import BoundingBox3DType, PointCloud


def filter_point_cloud(point_cloud: PointCloud, mask: Any) -> PointCloud:
"""Creates a new point cloud that is filtered by the given mask.
Will also filter the colors and attributes if they are present.
Args:
point_cloud: the point cloud to filter
mask: the mask to filter the point cloud by, used to index the attribute arrays, can be boolean or indices
Returns:
the new filtered point cloud
"""
points = point_cloud.points[mask]
colors = None if point_cloud.colors is None else point_cloud.colors[mask]

attributes = None
if point_cloud.attributes is not None:
attributes = {}
for key, value in point_cloud.attributes.items():
attributes[key] = value[mask]

point_cloud_filtered = PointCloud(points, colors, attributes)
return point_cloud_filtered


def generate_point_cloud_crop_mask(point_cloud: PointCloud, bounding_box: BoundingBox3DType) -> np.ndarray:
"""Creates a mask that can be used to filter a point cloud to the given bounding box.
Args:
bounding_box: the bounding box that surrounds the points to keep
point_cloud: the point cloud to crop
Returns:
the mask that can be used to filter the point cloud
"""
points = point_cloud.points
x, y, z = points[:, 0], points[:, 1], points[:, 2]
(x_min, y_min, z_min), (x_max, y_max, z_max) = bounding_box
crop_mask = (x >= x_min) & (x <= x_max) & (y >= y_min) & (y <= y_max) & (z >= z_min) & (z <= z_max)
return crop_mask


def crop_point_cloud(
point_cloud: PointCloud,
bounding_box: BoundingBox3DType,
) -> PointCloud:
"""Creates a new point cloud that is cropped to the given bounding box.
Will also crop the colors and attributes if they are present.
Args:
bounding_box: the bounding box that surrounds the points to keep
point_cloud: the point cloud to crop
Returns:
the new cropped point cloud
"""
crop_mask = generate_point_cloud_crop_mask(point_cloud, bounding_box)
return filter_point_cloud(point_cloud, crop_mask.nonzero())
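The following self-contained sketch mirrors the cropping logic above so it can be run without installing the packages. `PointCloud` here is a hypothetical stand-in for the `airo_typing` dataclass, the bounding box is assumed to be a `(min_corner, max_corner)` pair as the tuple unpacking in `generate_point_cloud_crop_mask` suggests, and the function name `crop` is illustrative. Unlike `crop_point_cloud()`, it indexes with the boolean mask directly rather than with `mask.nonzero()`.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import numpy as np

Vector3 = Tuple[float, float, float]


@dataclass
class PointCloud:
    # Hypothetical stand-in for airo_typing.PointCloud, inferred from this diff.
    points: np.ndarray
    colors: Optional[np.ndarray] = None
    attributes: Optional[Dict[str, np.ndarray]] = None


def crop(point_cloud: PointCloud, bounding_box: Tuple[Vector3, Vector3]) -> PointCloud:
    """Keep only the points inside the axis-aligned box (bounds inclusive)."""
    lower, upper = np.asarray(bounding_box[0]), np.asarray(bounding_box[1])
    mask = np.all((point_cloud.points >= lower) & (point_cloud.points <= upper), axis=1)

    # Apply the same mask to every per-point array so rows stay aligned.
    colors = None if point_cloud.colors is None else point_cloud.colors[mask]
    attributes = None
    if point_cloud.attributes is not None:
        attributes = {name: values[mask] for name, values in point_cloud.attributes.items()}
    return PointCloud(point_cloud.points[mask], colors, attributes)


cloud = PointCloud(
    points=np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5], [2.0, 0.0, 0.0]]),
    colors=np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=np.uint8),
    attributes={"confidence": np.array([10.0, 20.0, 30.0])},
)
cropped = crop(cloud, ((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)))
print(len(cropped.points), cropped.attributes["confidence"])
```

The input cloud is left unchanged, since boolean indexing in NumPy returns copies.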
@@ -0,0 +1,24 @@
from typing import Any, Tuple

import open3d as o3d
from airo_typing import Vector3DType


def open3d_point(
position: Vector3DType, color: Tuple[float, float, float], radius: float = 0.01
) -> Any: # Change Any back to o3d.geometry.TriangleMesh
"""Creates a small sphere mesh for visualization in open3d.
Args:
position: 3D position of the point
color: RGB color of the point as 0-1 floats
radius: radius of the sphere
Returns:
sphere: an open3d mesh
"""
sphere = o3d.geometry.TriangleMesh.create_sphere(radius=radius)
sphere.translate(position)
sphere.paint_uniform_color(color)
sphere.compute_vertex_normals()
return sphere
2 changes: 1 addition & 1 deletion airo-camera-toolkit/docs/rerun.md
@@ -11,7 +11,7 @@ rerun.log_image("zed_top", image)
rerun.log_scalar("force_z", force[2])
...
```
See the [example notebook](./rerun-zed-example.ipynb) for more.
See the [example notebook](../notebooks/rerun-zed-example.ipynb) for more.

> :information_source: A note on starting the Rerun viewer: you can start it by calling `rerun.spawn()` from Python. However when starting Rerun like that, [there is no way to specify a memory limit](https://www.rerun.io/docs/howto/limit-ram). This quickly becomes a problem when logging images, so we recommend starting Rerun from a terminal:
>```