Skip to content

Commit

Permalink
Add process_batch_size argument to control memory
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuhrsch committed Nov 30, 2023
1 parent ec9a2bf commit 308049b
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion segment_anything_fast/automatic_mask_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
point_grids: Optional[List[np.ndarray]] = None,
min_mask_region_area: int = 0,
output_mode: str = "binary_mask",
process_batch_size: Optional[int] = None,
) -> None:
"""
Using a SAM model, generates masks for the entire image.
Expand Down Expand Up @@ -94,6 +95,10 @@ def __init__(
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
For large resolutions, 'binary_mask' may consume large amounts of
memory.
process_batch_size (int or None): Set a batch size for the decoding step.
If None, all points will be batched up at once. Set a small number here
to decrease memory footprint. A smaller number will likely increase
latency, but also decrease memory usage.
"""

assert (points_per_side is None) != (
Expand Down Expand Up @@ -133,6 +138,7 @@ def __init__(
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
self.min_mask_region_area = min_mask_region_area
self.output_mode = output_mode
self.process_batch_size = process_batch_size

@torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -241,8 +247,13 @@ def _process_crop(
points_for_image = self.point_grids[crop_layer_idx] * points_scale

# Generate masks for this crop in batches
data = MaskData()
all_points = [points for (points,) in batch_iterator(self.points_per_batch, points_for_image)]
data = self._process_batch(all_points, cropped_im_size, crop_box, orig_size)
process_batch_size = len(all_points) if self.process_batch_size is None else self.process_batch_size
for i in range(0, len(all_points), process_batch_size):
some_points = all_points[i:i+process_batch_size]
batch_data = self._process_batch(some_points, cropped_im_size, crop_box, orig_size)
data.cat(batch_data)
data["rles"] = mask_to_rle_pytorch_2(data["masks"])
self.predictor.reset_image()

Expand Down

0 comments on commit 308049b

Please sign in to comment.