diff --git a/segment_anything_fast/automatic_mask_generator.py b/segment_anything_fast/automatic_mask_generator.py index 6838a1b..06874d1 100644 --- a/segment_anything_fast/automatic_mask_generator.py +++ b/segment_anything_fast/automatic_mask_generator.py @@ -50,6 +50,7 @@ def __init__( point_grids: Optional[List[np.ndarray]] = None, min_mask_region_area: int = 0, output_mode: str = "binary_mask", + process_batch_size: Optional[int] = None, ) -> None: """ Using a SAM model, generates masks for the entire image. @@ -94,6 +95,10 @@ def __init__( 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. For large resolutions, 'binary_mask' may consume large amounts of memory. + process_batch_size (int or None): Set a batch size for the decoding step. + If None, all points will be batched up at once. Set a small number here + to decrease memory footprint. A smaller number will likely increase + latency, but also decrease memory usage. """ assert (points_per_side is None) != ( @@ -133,6 +138,7 @@ def __init__( self.crop_n_points_downscale_factor = crop_n_points_downscale_factor self.min_mask_region_area = min_mask_region_area self.output_mode = output_mode + self.process_batch_size = process_batch_size @torch.no_grad() def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: @@ -241,8 +247,13 @@ def _process_crop( points_for_image = self.point_grids[crop_layer_idx] * points_scale # Generate masks for this crop in batches + data = MaskData() all_points = [points for (points,) in batch_iterator(self.points_per_batch, points_for_image)] - data = self._process_batch(all_points, cropped_im_size, crop_box, orig_size) + process_batch_size = len(all_points) if self.process_batch_size is None else self.process_batch_size + for i in range(0, len(all_points), process_batch_size): + some_points = all_points[i:i+process_batch_size] + batch_data = self._process_batch(some_points, cropped_im_size, crop_box, orig_size) + data.cat(batch_data) data["rles"] = 
mask_to_rle_pytorch_2(data["masks"]) self.predictor.reset_image()