Skip to content

Commit

Permalink
Update dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghao9610 committed Sep 15, 2024
1 parent a43ebae commit a805a75
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 20 deletions.
21 changes: 11 additions & 10 deletions ovdino/detrex/data/datasets/coco_ovd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@


logger = logging.getLogger(__name__)
print_sample = True

__all__ = ["load_coco_json", "register_coco_ovd_instances"]

Expand Down Expand Up @@ -73,6 +74,7 @@ def load_coco_json(
The results do not have the "image" field.
"""
timer = Timer()
global print_sample
json_file = PathManager.get_local_path(json_file)
with contextlib.redirect_stdout(io.StringIO()):
coco_api = COCO(json_file)
Expand Down Expand Up @@ -238,23 +240,20 @@ def load_coco_json(
record["annotations"] = objs

if test_mode and num_sampled_classes > 0:
assert template == "identity"
obj_cat_ids = [obj["category_id"] for obj in objs]
sampled_cat_names = [
clean_words_or_phrase(cat_name) for _, cat_name in id2name.items()
]
sampled_cat_names = [
[template.format(cat_name) for template in template_meta[template]]
for cat_name in sampled_cat_names
]

# sample category from category_list
if not test_mode and num_sampled_classes > 0:
obj_cat_ids = [obj["category_id"] for obj in objs]
continous_cat_ids = sorted(id_map.values())
continuous_cat_ids = sorted(id_map.values())
pos_cat_ids = set(obj_cat_ids)
assert len(pos_cat_ids) <= num_sampled_classes
neg_cat_ids = random.sample(
set(continous_cat_ids) - pos_cat_ids,
set(continuous_cat_ids) - pos_cat_ids,
num_sampled_classes - len(pos_cat_ids),
)
sampled_cat_ids = list(pos_cat_ids) + list(neg_cat_ids)
Expand All @@ -275,10 +274,12 @@ def load_coco_json(

dataset_dicts.append(record)

rank0_print(
f"Loaded {len(dataset_dicts)} data points from {dataset_name}, template: {template}\n"
+ f"Sample: {sampled_cat_names}"
)
if print_sample:
rank0_print(
f"Loaded {len(dataset_dicts)} data points from {dataset_name}, template: {template}\n"
+ f"Sample: {sampled_cat_names}"
)
print_sample = False

if num_instances_without_valid_segmentation > 0:
logger.warning(
Expand Down
21 changes: 11 additions & 10 deletions ovdino/detrex/data/datasets/custom_ovd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@


logger = logging.getLogger(__name__)
print_sample = True

__all__ = ["load_custom_json", "register_custom_ovd_instances"]

Expand Down Expand Up @@ -73,6 +74,7 @@ def load_custom_json(
The results do not have the "image" field.
"""
timer = Timer()
global print_sample
json_file = PathManager.get_local_path(json_file)
with contextlib.redirect_stdout(io.StringIO()):
coco_api = COCO(json_file)
Expand Down Expand Up @@ -238,23 +240,20 @@ def load_custom_json(
record["annotations"] = objs

if test_mode and num_sampled_classes > 0:
assert template == "identity"
obj_cat_ids = [obj["category_id"] for obj in objs]
sampled_cat_names = [
clean_words_or_phrase(cat_name) for _, cat_name in id2name.items()
]
sampled_cat_names = [
[template.format(cat_name) for template in template_meta[template]]
for cat_name in sampled_cat_names
]

# sample category from category_list
if not test_mode and num_sampled_classes > 0:
obj_cat_ids = [obj["category_id"] for obj in objs]
continous_cat_ids = sorted(id_map.values())
continuous_cat_ids = sorted(id_map.values())
pos_cat_ids = set(obj_cat_ids)
assert len(pos_cat_ids) <= num_sampled_classes
neg_cat_ids = random.sample(
set(continous_cat_ids) - pos_cat_ids,
set(continuous_cat_ids) - pos_cat_ids,
num_sampled_classes - len(pos_cat_ids),
)
sampled_cat_ids = list(pos_cat_ids) + list(neg_cat_ids)
Expand All @@ -275,10 +274,12 @@ def load_custom_json(

dataset_dicts.append(record)

rank0_print(
f"Loaded {len(dataset_dicts)} data points from {dataset_name}, template: {template}\n"
+ f"Sample: {sampled_cat_names}"
)
if print_sample:
rank0_print(
f"Loaded {len(dataset_dicts)} data points from {dataset_name}, template: {template}\n"
+ f"Sample: {sampled_cat_names}"
)
print_sample = False

if num_instances_without_valid_segmentation > 0:
logger.warning(
Expand Down

0 comments on commit a805a75

Please sign in to comment.