diff --git a/ovdino/detrex/data/datasets/coco_ovd.py b/ovdino/detrex/data/datasets/coco_ovd.py index a65def9..5fa91eb 100644 --- a/ovdino/detrex/data/datasets/coco_ovd.py +++ b/ovdino/detrex/data/datasets/coco_ovd.py @@ -25,6 +25,7 @@ logger = logging.getLogger(__name__) +print_sample = True __all__ = ["load_coco_json", "register_coco_ovd_instances"] @@ -73,6 +74,7 @@ def load_coco_json( The results do not have the "image" field. """ timer = Timer() + global print_sample json_file = PathManager.get_local_path(json_file) with contextlib.redirect_stdout(io.StringIO()): coco_api = COCO(json_file) @@ -238,23 +240,20 @@ def load_coco_json( record["annotations"] = objs if test_mode and num_sampled_classes > 0: + assert template == "identity" obj_cat_ids = [obj["category_id"] for obj in objs] sampled_cat_names = [ clean_words_or_phrase(cat_name) for _, cat_name in id2name.items() ] - sampled_cat_names = [ - [template.format(cat_name) for template in template_meta[template]] - for cat_name in sampled_cat_names - ] # sample category from category_list if not test_mode and num_sampled_classes > 0: obj_cat_ids = [obj["category_id"] for obj in objs] - continous_cat_ids = sorted(id_map.values()) + continuous_cat_ids = sorted(id_map.values()) pos_cat_ids = set(obj_cat_ids) assert len(pos_cat_ids) <= num_sampled_classes neg_cat_ids = random.sample( - set(continous_cat_ids) - pos_cat_ids, + set(continuous_cat_ids) - pos_cat_ids, num_sampled_classes - len(pos_cat_ids), ) sampled_cat_ids = list(pos_cat_ids) + list(neg_cat_ids) @@ -275,10 +274,12 @@ def load_coco_json( dataset_dicts.append(record) - rank0_print( - f"Loaded {len(dataset_dicts)} data points from {dataset_name}, template: {template}\n" - + f"Sample: {sampled_cat_names}" - ) + if print_sample: + rank0_print( + f"Loaded {len(dataset_dicts)} data points from {dataset_name}, template: {template}\n" + + f"Sample: {sampled_cat_names}" + ) + print_sample = False if num_instances_without_valid_segmentation > 0: logger.warning( diff --git a/ovdino/detrex/data/datasets/custom_ovd.py b/ovdino/detrex/data/datasets/custom_ovd.py index 5c2e38b..96785ee 100644 --- a/ovdino/detrex/data/datasets/custom_ovd.py +++ b/ovdino/detrex/data/datasets/custom_ovd.py @@ -25,6 +25,7 @@ logger = logging.getLogger(__name__) +print_sample = True __all__ = ["load_custom_json", "register_custom_ovd_instances"] @@ -73,6 +74,7 @@ def load_custom_json( The results do not have the "image" field. """ timer = Timer() + global print_sample json_file = PathManager.get_local_path(json_file) with contextlib.redirect_stdout(io.StringIO()): coco_api = COCO(json_file) @@ -238,23 +240,20 @@ def load_custom_json( record["annotations"] = objs if test_mode and num_sampled_classes > 0: + assert template == "identity" obj_cat_ids = [obj["category_id"] for obj in objs] sampled_cat_names = [ clean_words_or_phrase(cat_name) for _, cat_name in id2name.items() ] - sampled_cat_names = [ - [template.format(cat_name) for template in template_meta[template]] - for cat_name in sampled_cat_names - ] # sample category from category_list if not test_mode and num_sampled_classes > 0: obj_cat_ids = [obj["category_id"] for obj in objs] - continous_cat_ids = sorted(id_map.values()) + continuous_cat_ids = sorted(id_map.values()) pos_cat_ids = set(obj_cat_ids) assert len(pos_cat_ids) <= num_sampled_classes neg_cat_ids = random.sample( - set(continous_cat_ids) - pos_cat_ids, + set(continuous_cat_ids) - pos_cat_ids, num_sampled_classes - len(pos_cat_ids), ) sampled_cat_ids = list(pos_cat_ids) + list(neg_cat_ids) @@ -275,10 +274,12 @@ def load_custom_json( dataset_dicts.append(record) - rank0_print( - f"Loaded {len(dataset_dicts)} data points from {dataset_name}, template: {template}\n" - + f"Sample: {sampled_cat_names}" - ) + if print_sample: + rank0_print( + f"Loaded {len(dataset_dicts)} data points from {dataset_name}, template: {template}\n" + + f"Sample: {sampled_cat_names}" + ) + print_sample = False if num_instances_without_valid_segmentation > 0: logger.warning(