Add tensor-based annotation storage to reduce DDP RAM usage with large COCO-format datasets #1885

Open
wants to merge 6 commits into base: master
Changes from 2 commits
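
For context on the diff below: each per-image annotation record is pickled, the resulting byte strings are concatenated into one flat uint8 torch tensor, and a second tensor of cumulative lengths records where each record ends; loading a sample slices the flat buffer and unpickles it. A single tensor is one refcounted object, so forked DataLoader/DDP workers do not dirty copy-on-write pages the way a list of millions of small Python objects does. A minimal, self-contained sketch of that scheme (the class and variable names are illustrative, not the PR's API):

import pickle
from typing import Any, List

import torch


class PackedRecords:
    """Flat uint8 tensor plus cumulative offsets holding pickled records."""

    def __init__(self, records: List[Any]):
        # Pickle each record and wrap the bytes in a uint8 tensor
        buffers = [torch.frombuffer(bytearray(pickle.dumps(r, protocol=-1)), dtype=torch.uint8) for r in records]
        # End offset of each record inside the flat buffer
        self._addr = torch.cumsum(torch.tensor([len(b) for b in buffers], dtype=torch.int64), dim=0)
        self._data = torch.cat(buffers)

    def __len__(self) -> int:
        return len(self._addr)

    def __getitem__(self, index: int) -> Any:
        # Slice the flat buffer between the previous and current end offsets, then unpickle
        start = 0 if index == 0 else int(self._addr[index - 1])
        end = int(self._addr[index])
        return pickle.loads(self._data[start:end].numpy().tobytes())


# Round-trip check with toy records
records = PackedRecords([{"image_id": i, "boxes": [[0.0, 0.0, 10.0, 10.0]]} for i in range(3)])
assert len(records) == 3 and records[1]["image_id"] == 1
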
@@ -1,23 +1,33 @@
import gc
from collections import defaultdict

import copy
import dataclasses
import json
import os

import numpy as np
import pickle
import torch
from tqdm import tqdm
from typing import List, Optional, Tuple

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.deprecate import deprecated_parameter
from super_gradients.common.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException
from super_gradients.common.registry import register_dataset
from super_gradients.training.datasets.data_formats.bbox_formats.xywh import xywh_to_xyxy_inplace
from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL
from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
from super_gradients.training.utils.detection_utils import change_bbox_bounds_for_image_size

logger = get_logger(__name__)


def _serialize(data):
buffer = pickle.dumps(data, protocol=-1)
return torch.frombuffer(buffer, dtype=torch.uint8)


@register_dataset("COCOFormatDetectionDataset")
class COCOFormatDetectionDataset(DetectionDataset):
"""Base dataset to load ANY dataset that is with a similar structure to the COCO dataset.
@@ -104,8 +114,21 @@ def _setup_data_source(self) -> int:

self.original_classes = list(all_class_names)
self.classes = copy.deepcopy(self.original_classes)
self._annotations = annotations
return len(annotations)

self._annotations = [_serialize(x) for x in annotations]

del annotations, all_class_names
gc.collect()

self._addr = torch.tensor([len(x) for x in self._annotations], dtype=torch.int64)
self._addr = torch.cumsum(self._addr, dim=0)
self._annotations = torch.concatenate(self._annotations)
print("Serialized dataset takes {:.2f} MiB".format(len(self._annotations) / 1024**2))

return len(self._addr)

def __len__(self) -> int:
return len(self._addr)

@property
def _all_classes(self) -> List[str]:
@@ -124,7 +147,9 @@ def _load_annotation(self, sample_id: int) -> dict:
:return img_path: Path to the associated image
"""

annotation = self._annotations[sample_id]
start_addr = 0 if sample_id == 0 else self._addr[sample_id - 1].item()
end_addr = self._addr[sample_id].item()
annotation = pickle.loads(self._annotations[start_addr:end_addr].numpy().data)
Review comment (Contributor):
We saw a slowdown due to pickle load; do we want to make the serialize-parse fix optional?
Eventually a memory leak is a memory leak, but for small datasets you get overhead, whereas without the fix you'd be fine. @BloodAxe, thoughts?

Reply (Contributor):
@NatanBagrov yes IMO
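
For a rough sense of the per-sample cost raised here, a micro-benchmark sketch that could be run locally (the record contents are synthetic, not taken from this PR):

import pickle
import timeit

import numpy as np

# Synthetic stand-in for one image's annotation record
record = {
    "image_path": "000000000001.jpg",
    "boxes_xyxy": np.random.rand(20, 4).astype(np.float32),
    "labels": np.random.randint(0, 80, size=20),
}
blob = pickle.dumps(record, protocol=-1)

plain = timeit.timeit(lambda: record["boxes_xyxy"], number=100_000)
parsed = timeit.timeit(lambda: pickle.loads(blob)["boxes_xyxy"], number=100_000)
print(f"plain access: {plain:.3f}s   pickle.loads per sample: {parsed:.3f}s")

If the overhead does matter for small datasets, the opt-out discussed above could plausibly be a boolean constructor argument that keeps self._annotations as a plain Python list and skips the address tensor entirely.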


width = annotation.image_width
height = annotation.image_height
@@ -200,26 +225,27 @@ def parse_coco_into_detection_annotations(
(respecting include_classes/exclude_classes/class_ids_to_ignore) and
annotations is a list of DetectionAnnotation objects.
"""
print(f"Creating a new COCO dataset with {ann}")
with open(ann, "r") as f:
coco = json.load(f)

# Extract class names and class ids
print("Extract class names and class ids")
category_ids = np.array([category["id"] for category in coco["categories"]], dtype=int)
category_names = np.array([category["name"] for category in coco["categories"]], dtype=str)

# Extract box annotations
print("Extract box annotations")
ann_box_xyxy = xywh_to_xyxy_inplace(np.array([annotation["bbox"] for annotation in coco["annotations"]], dtype=np.float32), image_shape=None)

ann_category_id = np.array([annotation["category_id"] for annotation in coco["annotations"]], dtype=int)
ann_iscrowd = np.array([annotation["iscrowd"] for annotation in coco["annotations"]], dtype=bool)
ann_image_ids = np.array([annotation["image_id"] for annotation in coco["annotations"]], dtype=int)

# Extract image stuff
img_ids = np.array([img["id"] for img in coco["images"]], dtype=int)
img_paths = np.array([img["file_name"] if "file_name" in img else "{:012}".format(img["id"]) + ".jpg" for img in coco["images"]], dtype=str)
img_width_height = np.array([(img["width"], img["height"]) for img in coco["images"]], dtype=int)
print("Extract image stuff")
img_ids = [img["id"] for img in coco["images"]]
img_paths = [img["file_name"] if "file_name" in img else "{:012}".format(img["id"]) + ".jpg" for img in coco["images"]]
img_width_height = [(img["width"], img["height"]) for img in coco["images"]]

# Now, we can drop the annotations that belong to the excluded classes
print("Now, we can drop the annotations that belong to the excluded classes")
if int(class_ids_to_ignore is not None) + int(exclude_classes is not None) + int(include_classes is not None) > 1:
raise ValueError("Only one of exclude_classes, class_ids_to_ignore or include_classes can be specified")
elif exclude_classes is not None:
@@ -256,12 +282,12 @@ def parse_coco_into_detection_annotations(
# category_ids can be non-sequential and not ordered
num_categories = len(category_ids)

# Make sequential
print("Make sequential")
order = np.argsort(category_ids, kind="stable")
category_ids = category_ids[order] #
category_names = category_names[order]

# Remap category ids to be in range [0, num_categories)
print("Remap category ids to be in range [0, num_categories)")
class_label_table = np.zeros(np.max(category_ids) + 1, dtype=int) - 1
new_class_ids = np.arange(num_categories, dtype=int)
class_label_table[category_ids] = new_class_ids
@@ -273,9 +299,17 @@

annotations = []

for img_id, image_path, (image_width, image_height) in zip(img_ids, img_paths, img_width_height):
mask = ann_image_ids == img_id

print("Indexing annotations...")
img_id2ann_box_xyxy = defaultdict(list)
img_id2ann_iscrowd = defaultdict(list)
img_id2ann_category_id = defaultdict(list)
for ann_image_id, _ann_box_xyxy, _ann_iscrowd, _ann_category_id in zip(ann_image_ids, ann_box_xyxy, ann_iscrowd, ann_category_id):
img_id2ann_box_xyxy[ann_image_id].append(_ann_box_xyxy)
img_id2ann_iscrowd[ann_image_id].append(_ann_iscrowd)
img_id2ann_category_id[ann_image_id].append(_ann_category_id)

print("Create annotations")
for img_id, image_path, (image_width, image_height) in tqdm(zip(img_ids, img_paths, img_width_height), total=len(img_ids)):
if image_path_prefix is not None:
image_path = os.path.join(image_path_prefix, image_path)

@@ -284,9 +318,9 @@
image_path=image_path,
image_width=image_width,
image_height=image_height,
ann_boxes_xyxy=ann_box_xyxy[mask],
ann_is_crowd=ann_iscrowd[mask],
ann_labels=ann_category_id[mask],
ann_boxes_xyxy=np.asarray(img_id2ann_box_xyxy[img_id], dtype=np.float32).reshape(-1, 4),
ann_is_crowd=np.asarray(img_id2ann_iscrowd[img_id], dtype=bool).reshape(-1),
ann_labels=np.asarray(img_id2ann_category_id[img_id], dtype=int).reshape(-1),
)
annotations.append(ann)
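
The annotation indexing added above replaces the previous per-image boolean mask (ann_image_ids == img_id, which rescans every annotation for every image) with a single pass that groups annotations by image id. A standalone sketch of the same O(num_images * num_annotations) vs. O(num_annotations) pattern, with illustrative data and names:

from collections import defaultdict

import numpy as np

ann_image_ids = np.array([1, 1, 2, 3, 3, 3])
ann_boxes = np.arange(6 * 4, dtype=np.float32).reshape(6, 4)

# Per-image boolean mask: the full annotation array is rescanned for every image
slow = {img_id: ann_boxes[ann_image_ids == img_id] for img_id in (1, 2, 3)}

# Group once, then look up per image
grouped = defaultdict(list)
for img_id, box in zip(ann_image_ids, ann_boxes):
    grouped[int(img_id)].append(box)
fast = {img_id: np.asarray(grouped[img_id], dtype=np.float32).reshape(-1, 4) for img_id in (1, 2, 3)}

assert all(np.array_equal(slow[k], fast[k]) for k in (1, 2, 3))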
