Skip to content

Commit

Permalink
manually specify coco categories for cvat2coco
Browse files Browse the repository at this point in the history
  • Loading branch information
tlpss committed Oct 17, 2023
1 parent 7353868 commit 60abe9b
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 391 deletions.
11 changes: 8 additions & 3 deletions airo-dataset-tools/airo_dataset_tools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,16 @@ def view_coco_dataset_cli(

@cli.command(name="convert-cvat-to-coco-keypoints")
@click.argument("cvat_xml_file", type=str, required=True)
@click.argument("coco_categories_json_file", type=str, required=True)
@click.option("--add_bbox", is_flag=True, default=False, help="include bounding box in coco annotations")
@click.option("--add_segmentation", is_flag=True, default=False, help="include segmentation in coco annotations")
def convert_cvat_to_coco_cli(cvat_xml_file: str, add_bbox: bool, add_segmentation: bool) -> None:
"""Convert CVAT XML to COCO keypoints json"""
coco = cvat_image_to_coco(cvat_xml_file, add_bbox=add_bbox, add_segmentation=add_segmentation)
def convert_cvat_to_coco_cli(
cvat_xml_file: str, coco_categories_json_file: str, add_bbox: bool, add_segmentation: bool
) -> None:
"""Convert CVAT XML to COCO keypoints json according to specified coco categories"""
coco = cvat_image_to_coco(
cvat_xml_file, coco_categories_json_file, add_bbox=add_bbox, add_segmentation=add_segmentation
)
path = os.path.dirname(cvat_xml_file)
filename = os.path.basename(cvat_xml_file)
path = os.path.join(path, filename.split(".")[0] + ".json")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import tqdm
from airo_dataset_tools.cvat_labeling.load_xml_to_dict import get_dict_from_xml
from airo_dataset_tools.data_parsers.coco import (
CocoCategory,
CocoImage,
CocoKeypointAnnotation,
CocoKeypointCategory,
Expand All @@ -19,7 +20,7 @@


def cvat_image_to_coco( # noqa: C901, too complex
cvat_xml_path: str, add_bbox: bool = True, add_segmentation: bool = True
cvat_xml_path: str, coco_configuration_json_path: str, add_bbox: bool = True, add_segmentation: bool = True
) -> dict:
"""Function that converts an annotation XML in the CVAT 1.1 Image format to the COCO keypoints format.
If you don't need keypoints, you can simply use CVAT to create a COCOinstances format and should not use this function!
Expand All @@ -31,6 +32,7 @@ def cvat_image_to_coco( # noqa: C901, too complex
Args:
cvat_xml_path (str): path to the annotation XML file in the CVAT 1.1 Image format that should be converted.
coco_configuration_json_path (str): path to the COCO categories to use for annotating this dataset.
add_bbox (bool): add bounding box annotations to the COCO dataset, requires all keypoint annotations to have a bbox annotation
add_segmentation (bool): add segmentation annotations to the COCO dataset, requires all keypoint annotations to have a mask annotation. Bboxes will be created from the segmentation masks.
Expand All @@ -44,29 +46,18 @@ def cvat_image_to_coco( # noqa: C901, too complex
coco_annotations: List[CocoKeypointAnnotation] = []
coco_categories: List[CocoKeypointCategory] = []

annotation_id_counter = 1 # counter for the annotation ID

# create the COCOKeypointCategories
categories_dict = defaultdict(list)

for annotation_category in cvat_parsed.annotations.meta.get_job_or_task().labels.label:
assert isinstance(annotation_category, LabelItem)
category_str, annotation_name = annotation_category.name.split(".")
categories_dict[category_str].append(annotation_name)
# load the COCO categories from the configuration file
with open(coco_configuration_json_path, "r") as file:
coco_categories_config = json.load(file)
for category_dict in coco_categories_config["categories"]:
category = CocoCategory(**category_dict)
coco_categories.append(category)

for category_str, semantic_types in categories_dict.items():
if add_bbox:
assert "bbox" in semantic_types, "bbox annotations are required"
if add_segmentation:
assert "mask" in semantic_types, "segmentation masks are required"
_validate_coco_categories_are_in_cvat(
cvat_parsed, coco_categories, add_bbox=add_bbox, add_segmentation=add_segmentation
)

semantic_types = [
semantic_type for semantic_type in semantic_types if semantic_type != "bbox" and semantic_type != "mask"
]
coco_category = CocoKeypointCategory(
name=category_str, id=len(coco_categories) + 1, keypoints=semantic_types, supercategory=""
)
coco_categories.append(coco_category)
annotation_id_counter = 1 # counter for the annotation ID

# iterate over all cvat annotations (grouped per image)
# and create the COCO Keypoint annotations
Expand Down Expand Up @@ -124,6 +115,31 @@ def cvat_image_to_coco( # noqa: C901, too complex
####################


def _validate_coco_categories_are_in_cvat(
    cvat_parsed: CVATImagesParser, coco_categories: List[CocoKeypointCategory], add_bbox: bool, add_segmentation: bool
) -> None:
    """Check that the CVAT labels are consistent with the configured COCO categories.

    CVAT label names are expected to look like ``<category>.<semantic_type>``. The labels
    are grouped per category, and each group is checked as follows:

    * if ``add_bbox`` is set, the category must have a ``bbox`` label;
    * if ``add_segmentation`` is set, the category must have a ``mask`` label;
    * a COCO category with the same name must exist in ``coco_categories``;
    * every keypoint of that COCO category must have a corresponding CVAT label.

    Args:
        cvat_parsed: the parsed CVAT annotation file.
        coco_categories: the COCO categories that were loaded from the configuration.
        add_bbox: whether bounding boxes will be added to the COCO annotations.
        add_segmentation: whether segmentation masks will be added to the COCO annotations.

    Raises:
        AssertionError: if any of the checks above fails.
        ValueError: if a CVAT label name does not contain exactly one ``.`` separator
            (raised by the tuple unpacking of ``split``).
    """
    # gather the semantic types (bbox, mask, keypoint names) per category from CVAT
    cvat_categories_dict = defaultdict(list)
    for annotation_category in cvat_parsed.annotations.meta.get_job_or_task().labels.label:
        assert isinstance(annotation_category, LabelItem)
        category_str, annotation_name = annotation_category.name.split(".")
        cvat_categories_dict[category_str].append(annotation_name)

    for category_str, semantic_types in cvat_categories_dict.items():
        if add_bbox:
            assert "bbox" in semantic_types, "bbox annotations are required"
        if add_segmentation:
            assert "mask" in semantic_types, "segmentation masks are required"

        # Find the matching COCO category. A plain `for ... break` loop would leave the
        # loop variable bound to the *last* category when nothing matches, so the keypoints
        # of an unrelated category would be validated. Use an explicit None default instead.
        coco_category = next(
            (candidate for candidate in coco_categories if candidate.name == category_str),
            None,
        )
        assert (
            coco_category is not None
        ), f"CVAT category {category_str} has no matching COCO category in the configuration"

        for category_keypoint in coco_category.keypoints:
            # keypoints are compared against the CVAT label name strings, so the keypoint is formatted directly
            assert (
                category_keypoint in semantic_types
            ), f"semantic type {category_keypoint} not found in the CVAT labels of category {category_str}"


def _get_n_category_instances_in_image(cvat_image: ImageItem, category_name: str) -> int:
"""returns the number of instances for the specified category in the CVAT ImageItem.
Expand Down Expand Up @@ -197,7 +213,7 @@ def _get_segmentation_for_instance_from_cvat_image(cvat_image: ImageItem, instan
"""returns the segmentation polygon for the instance in the cvat image."""
instance_id_str = str(instance_id)
if cvat_image.polygon is None:
raise ValueError("segmentation annotations are required for image {cvat_image.name}")
raise ValueError(f"segmentation annotations are required for image {cvat_image.name}")
if not isinstance(cvat_image.polygon, list):
if instance_id_str == cvat_image.polygon.group_id:
polygon_str = cvat_image.polygon.points
Expand Down Expand Up @@ -274,6 +290,8 @@ def _extract_coco_keypoint_from_cvat_point(cvat_point: Point) -> List:
path = pathlib.Path(__file__).parent.absolute()
cvat_xml_file = str(path / "example" / "annotations.xml")

coco = cvat_image_to_coco(cvat_xml_file, add_bbox=True, add_segmentation=False)
coco_categories_file = str(path / "example" / "coco_categories.json")

coco = cvat_image_to_coco(cvat_xml_file, coco_categories_file, add_bbox=True, add_segmentation=False)
with open("coco.json", "w") as file:
json.dump(coco, file)
Loading

0 comments on commit 60abe9b

Please sign in to comment.