Skip to content

Commit

Permalink
support for importing data with team wide classes (#179)
Browse files Browse the repository at this point in the history
* support for importing data with team wide classes
  • Loading branch information
simedw authored Aug 4, 2021
1 parent e096889 commit 6e83578
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 29 deletions.
5 changes: 5 additions & 0 deletions darwin/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,11 @@ def create_dataset(self, name: str, team: Optional[str] = None) -> RemoteDataset
client=self,
)

def fetch_remote_classes(self, team: Optional[str] = None):
"""Fetches all remote classes on the remote dataset"""
team_slug = self.config.get_team(team or self.default_team)["slug"]
return self.get(f"/teams/{team_slug}/annotation_classes?include_tags=true")["annotation_classes"]

def load_feature_flags(self, team: Optional[str] = None):
"""Gets current features enabled for a team"""
team_slug = self.config.get_team(team or self.default_team)["slug"]
Expand Down
35 changes: 31 additions & 4 deletions darwin/dataset/remote_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,15 +374,42 @@ def create_annotation_class(self, name: str, type: str):
"name": name,
"metadata": {"_color": "auto"},
"annotation_type_ids": [type_id],
"datasets": [{"id": self.dataset_id}],
},
error_handlers=[name_taken, validation_error],
)

def fetch_remote_classes(self):
"""Fetches all remote classes on the remote dataset"""
return self.client.get(f"/datasets/{self.dataset_id}/annotation_classes?include_tags=true")[
"annotation_classes"
def add_annotation_class(self, annotation_class):
# Waiting for a better api for setting classes
# in the meantime this will do
all_classes = self.fetch_remote_classes(True)
match = [
cls
for cls in all_classes
if cls["name"] == annotation_class.name
and annotation_class.annotation_internal_type in cls["annotation_types"]
]
if not match:
raise ValueError(f"Unknown annotation class {annotation_class.name}, id: {annotation_class.id}")

datasets = match[0]["datasets"]
# check that we are not already part of the dataset
for dataset in datasets:
if dataset["id"] == self.dataset_id:
return
datasets.append({"id": self.dataset_id})
return self.client.put(f"/annotation_classes/{match[0]['id']}", {"datasets": datasets, "id": match[0]["id"]})

def fetch_remote_classes(self, team_wide=False):
"""Fetches all remote classes on the remote dataset"""
all_classes = self.client.fetch_remote_classes()
classes_to_return = []
for cls in all_classes:
belongs_to_current_dataset = any([dataset["id"] == self.dataset_id for dataset in cls["datasets"]])
cls["available"] = belongs_to_current_dataset
if team_wide or belongs_to_current_dataset:
classes_to_return.append(cls)
return classes_to_return

def fetch_remote_attributes(self):
"""Fetches all remote attributes on the remote dataset"""
Expand Down
86 changes: 62 additions & 24 deletions darwin/importer/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,24 @@


def build_main_annotations_lookup_table(annotation_classes):
MAIN_ANNOTATION_TYPES = [
"bounding_box",
"cuboid",
"ellipse",
"keypoint",
"line",
"link",
"polygon",
"skeleton",
"tag",
]
lookup = {}
for cls in annotation_classes:
for annotation_type in cls["annotation_types"]:
if annotation_type["granularity"] == "main":
if annotation_type["name"] not in lookup:
lookup[annotation_type["name"]] = {}

lookup[annotation_type["name"]][cls["name"]] = cls["id"]
if annotation_type in MAIN_ANNOTATION_TYPES:
if annotation_type not in lookup:
lookup[annotation_type] = {}
lookup[annotation_type][cls["name"]] = cls["id"]
return lookup


Expand Down Expand Up @@ -65,14 +75,39 @@ def get_remote_files(dataset, filenames):
return remote_files


def _resolve_annotation_classes(annotation_classes: List[dt.AnnotationClass], classes_in_dataset, classes_in_team):
local_classes_not_in_dataset = set()
local_classes_not_in_team = set()

for cls in annotation_classes:
annotation_type = cls.annotation_internal_type or cls.annotation_type
# Only add the new class if it doesn't exist remotely already
if annotation_type in classes_in_dataset and cls.name in classes_in_dataset[annotation_type]:
continue

# Only add the new class if it's not included in the list of the missing classes already
if cls.name in [missing_class.name for missing_class in local_classes_not_in_dataset]:
continue
if cls.name in [missing_class.name for missing_class in local_classes_not_in_team]:
continue

if annotation_type in classes_in_team and cls.name in classes_in_team[annotation_type]:
local_classes_not_in_dataset.add(cls)
else:
local_classes_not_in_team.add(cls)
return local_classes_not_in_dataset, local_classes_not_in_team


def import_annotations(
dataset: "RemoteDataset",
importer: Callable[[Path], Union[List[dt.AnnotationFile], dt.AnnotationFile, None]],
file_paths: List[Union[str, Path]],
append: bool,
):
print("Fetching remote class list...")
remote_classes = build_main_annotations_lookup_table(dataset.fetch_remote_classes())
team_classes = dataset.fetch_remote_classes(True)
classes_in_dataset = build_main_annotations_lookup_table([cls for cls in team_classes if cls["available"]])
classes_in_team = build_main_annotations_lookup_table([cls for cls in team_classes if not cls["available"]])
attributes = build_attribute_lookup(dataset)

print("Retrieving local annotations ...")
Expand Down Expand Up @@ -100,34 +135,37 @@ def import_annotations(
if not secure_continue_request():
return

local_classes_missing_remotely = set()
for local_file in local_files:
for cls in local_file.annotation_classes:
annotation_type = cls.annotation_internal_type or cls.annotation_type
# Only add the new class if it doesn't exist remotely already
if annotation_type in remote_classes and cls.name in remote_classes[annotation_type]:
continue
# Only add the new class if it's not included in the list of the missing classes already
if cls.name in [missing_class.name for missing_class in local_classes_missing_remotely]:
continue
local_classes_missing_remotely.add(cls)
local_classes_not_in_dataset, local_classes_not_in_team = _resolve_annotation_classes(
[annotation_class for file in local_files for annotation_class in file.annotation_classes],
classes_in_dataset,
classes_in_team,
)

print(f"{len(local_classes_missing_remotely)} classes are missing remotely.")
if local_classes_missing_remotely:
print(f"{len(local_classes_not_in_team)} classes needs to be created.")
print(f"{len(local_classes_not_in_dataset)} classes needs to be added to {dataset.identifier}")

if local_classes_not_in_team:
print("About to create the following classes")
for missing_class in local_classes_missing_remotely:
for missing_class in local_classes_not_in_team:
print(
f"\t{missing_class.name}, type: {missing_class.annotation_internal_type or missing_class.annotation_type}"
)
if not secure_continue_request():
return
for missing_class in local_classes_missing_remotely:
for missing_class in local_classes_not_in_team:
dataset.create_annotation_class(
missing_class.name, missing_class.annotation_internal_type or missing_class.annotation_type
)

# Refetch classes to update mappings
remote_classes = build_main_annotations_lookup_table(dataset.fetch_remote_classes())
if local_classes_not_in_dataset:
print(f"About to add the following classes to {dataset.identifier}")
for cls in local_classes_not_in_dataset:
dataset.add_annotation_class(cls)

# Refetch classes to update mappings
if local_classes_not_in_team or local_classes_not_in_dataset:
remote_classes = build_main_annotations_lookup_table(dataset.fetch_remote_classes())
else:
remote_classes = build_main_annotations_lookup_table(team_classes)

# Need to re parse the files since we didn't save the annotations in memory
for local_path in set(local_file.path for local_file in local_files):
Expand Down
1 change: 0 additions & 1 deletion darwin/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def parse_darwin_json(path: Union[str, Path], count: int):
def parse_darwin_image(path, data, count):
annotations = list(filter(None, map(parse_darwin_annotation, data["annotations"])))
annotation_classes = set([annotation.annotation_class for annotation in annotations])

return dt.AnnotationFile(
path,
get_local_filename(data["image"]),
Expand Down

0 comments on commit 6e83578

Please sign in to comment.