diff --git a/darwin/client.py b/darwin/client.py
index dfd987c63..168fa5287 100644
--- a/darwin/client.py
+++ b/darwin/client.py
@@ -345,6 +345,11 @@ def create_dataset(self, name: str, team: Optional[str] = None) -> RemoteDataset
             client=self,
         )
 
+    def fetch_remote_classes(self, team: Optional[str] = None):
+        """Fetches all annotation classes (tags included) of the given team, or of the default team"""
+        team_slug = self.config.get_team(team or self.default_team)["slug"]
+        return self.get(f"/teams/{team_slug}/annotation_classes?include_tags=true")["annotation_classes"]
+
     def load_feature_flags(self, team: Optional[str] = None):
         """Gets current features enabled for a team"""
         team_slug = self.config.get_team(team or self.default_team)["slug"]
diff --git a/darwin/dataset/remote_dataset.py b/darwin/dataset/remote_dataset.py
index ae97f360f..376304f1c 100644
--- a/darwin/dataset/remote_dataset.py
+++ b/darwin/dataset/remote_dataset.py
@@ -374,15 +374,54 @@ def create_annotation_class(self, name: str, type: str):
                 "name": name,
                 "metadata": {"_color": "auto"},
                 "annotation_type_ids": [type_id],
+                "datasets": [{"id": self.dataset_id}],
             },
             error_handlers=[name_taken, validation_error],
         )
 
-    def fetch_remote_classes(self):
-        """Fetches all remote classes on the remote dataset"""
-        return self.client.get(f"/datasets/{self.dataset_id}/annotation_classes?include_tags=true")[
-            "annotation_classes"
+    def add_annotation_class(self, annotation_class):
+        """Attaches an existing team annotation class to this dataset.
+
+        Returns None when the class is already part of the dataset, otherwise the
+        result of the PUT request. Raises ValueError when no team class matches
+        the given name and type.
+        """
+        # Waiting for a better api for setting classes
+        # in the meantime this will do
+        # Fall back to the public type, the same way the importer resolves classes
+        annotation_type = annotation_class.annotation_internal_type or annotation_class.annotation_type
+        all_classes = self.fetch_remote_classes(True)
+        match = [
+            cls
+            for cls in all_classes
+            if cls["name"] == annotation_class.name and annotation_type in cls["annotation_types"]
         ]
+        if not match:
+            # NOTE(review): local annotation classes may not carry an id; avoid an AttributeError here
+            class_id = getattr(annotation_class, "id", None)
+            raise ValueError(f"Unknown annotation class {annotation_class.name}, id: {class_id}")
+
+        datasets = match[0]["datasets"]
+        # check that we are not already part of the dataset
+        for dataset in datasets:
+            if dataset["id"] == self.dataset_id:
+                return
+        datasets.append({"id": self.dataset_id})
+        return self.client.put(f"/annotation_classes/{match[0]['id']}", {"datasets": datasets, "id": match[0]["id"]})
+
+    def fetch_remote_classes(self, team_wide=False):
+        """Fetches the team's annotation classes, flagging each one with cls["available"].
+
+        Only the classes attached to this dataset are returned unless team_wide is True.
+        """
+        all_classes = self.client.fetch_remote_classes()
+        classes_to_return = []
+        for cls in all_classes:
+            belongs_to_current_dataset = any(dataset["id"] == self.dataset_id for dataset in cls["datasets"])
+            cls["available"] = belongs_to_current_dataset
+            if team_wide or belongs_to_current_dataset:
+                classes_to_return.append(cls)
+        return classes_to_return
 
     def fetch_remote_attributes(self):
         """Fetches all remote attributes on the remote dataset"""
diff --git a/darwin/importer/importer.py b/darwin/importer/importer.py
index 63ba5c5f0..9a05867e9 100644
--- a/darwin/importer/importer.py
+++ b/darwin/importer/importer.py
@@ -11,14 +11,24 @@
 
 
 def build_main_annotations_lookup_table(annotation_classes):
+    MAIN_ANNOTATION_TYPES = [
+        "bounding_box",
+        "cuboid",
+        "ellipse",
+        "keypoint",
+        "line",
+        "link",
+        "polygon",
+        "skeleton",
+        "tag",
+    ]
     lookup = {}
     for cls in annotation_classes:
         for annotation_type in cls["annotation_types"]:
-            if annotation_type["granularity"] == "main":
-                if annotation_type["name"] not in lookup:
-                    lookup[annotation_type["name"]] = {}
-
-                lookup[annotation_type["name"]][cls["name"]] = cls["id"]
+            if annotation_type in MAIN_ANNOTATION_TYPES:
+                if annotation_type not in lookup:
+                    lookup[annotation_type] = {}
+                lookup[annotation_type][cls["name"]] = cls["id"]
 
     return lookup
 
@@ -65,6 +75,35 @@ def get_remote_files(dataset, filenames):
     return remote_files
 
 
+def _resolve_annotation_classes(annotation_classes: List[dt.AnnotationClass], classes_in_dataset, classes_in_team):
+    """Splits the local classes that are missing from the dataset into two sets.
+
+    Returns (classes that exist in the team but not in the dataset, classes missing from the team).
+    Classes already in the dataset are skipped; only the first class seen per name is kept.
+    """
+    local_classes_not_in_dataset = set()
+    local_classes_not_in_team = set()
+    # Names already recorded in either set, so the duplicate check stays O(1)
+    missing_class_names = set()
+
+    for cls in annotation_classes:
+        annotation_type = cls.annotation_internal_type or cls.annotation_type
+        # Only add the new class if it doesn't exist remotely already
+        if annotation_type in classes_in_dataset and cls.name in classes_in_dataset[annotation_type]:
+            continue
+
+        # Only add the new class if it's not included in the list of the missing classes already
+        if cls.name in missing_class_names:
+            continue
+        missing_class_names.add(cls.name)
+
+        if annotation_type in classes_in_team and cls.name in classes_in_team[annotation_type]:
+            local_classes_not_in_dataset.add(cls)
+        else:
+            local_classes_not_in_team.add(cls)
+    return local_classes_not_in_dataset, local_classes_not_in_team
+
+
 def import_annotations(
     dataset: "RemoteDataset",
     importer: Callable[[Path], Union[List[dt.AnnotationFile], dt.AnnotationFile, None]],
@@ -72,7 +111,9 @@ def import_annotations(
     append: bool,
 ):
     print("Fetching remote class list...")
-    remote_classes = build_main_annotations_lookup_table(dataset.fetch_remote_classes())
+    team_classes = dataset.fetch_remote_classes(True)
+    classes_in_dataset = build_main_annotations_lookup_table([cls for cls in team_classes if cls["available"]])
+    classes_in_team = build_main_annotations_lookup_table([cls for cls in team_classes if not cls["available"]])
     attributes = build_attribute_lookup(dataset)
 
     print("Retrieving local annotations ...")
@@ -100,34 +141,38 @@ def import_annotations(
         if not secure_continue_request():
             return
 
-    local_classes_missing_remotely = set()
-    for local_file in local_files:
-        for cls in local_file.annotation_classes:
-            annotation_type = cls.annotation_internal_type or cls.annotation_type
-            # Only add the new class if it doesn't exist remotely already
-            if annotation_type in remote_classes and cls.name in remote_classes[annotation_type]:
-                continue
-            # Only add the new class if it's not included in the list of the missing classes already
-            if cls.name in [missing_class.name for missing_class in local_classes_missing_remotely]:
-                continue
-            local_classes_missing_remotely.add(cls)
+    local_classes_not_in_dataset, local_classes_not_in_team = _resolve_annotation_classes(
+        [annotation_class for file in local_files for annotation_class in file.annotation_classes],
+        classes_in_dataset,
+        classes_in_team,
+    )
 
-    print(f"{len(local_classes_missing_remotely)} classes are missing remotely.")
-    if local_classes_missing_remotely:
+    print(f"{len(local_classes_not_in_team)} classes needs to be created.")
+    print(f"{len(local_classes_not_in_dataset)} classes needs to be added to {dataset.identifier}")
+
+    if local_classes_not_in_team:
         print("About to create the following classes")
-        for missing_class in local_classes_missing_remotely:
+        for missing_class in local_classes_not_in_team:
             print(
                 f"\t{missing_class.name}, type: {missing_class.annotation_internal_type or missing_class.annotation_type}"
             )
         if not secure_continue_request():
             return
-        for missing_class in local_classes_missing_remotely:
+        for missing_class in local_classes_not_in_team:
             dataset.create_annotation_class(
                 missing_class.name, missing_class.annotation_internal_type or missing_class.annotation_type
             )
-
-        # Refetch classes to update mappings
-        remote_classes = build_main_annotations_lookup_table(dataset.fetch_remote_classes())
+    if local_classes_not_in_dataset:
+        print(f"About to add the following classes to {dataset.identifier}")
+        for cls in local_classes_not_in_dataset:
+            dataset.add_annotation_class(cls)
+
+    # Refetch classes to update mappings
+    if local_classes_not_in_team or local_classes_not_in_dataset:
+        remote_classes = build_main_annotations_lookup_table(dataset.fetch_remote_classes())
+    else:
+        # Nothing changed remotely: reuse the dataset-only lookup built above
+        remote_classes = classes_in_dataset
 
     # Need to re parse the files since we didn't save the annotations in memory
     for local_path in set(local_file.path for local_file in local_files):
diff --git a/darwin/utils.py b/darwin/utils.py
index 7a0ef016d..c90c64a0e 100644
--- a/darwin/utils.py
+++ b/darwin/utils.py
@@ -173,7 +173,6 @@ def parse_darwin_json(path: Union[str, Path], count: int):
 def parse_darwin_image(path, data, count):
     annotations = list(filter(None, map(parse_darwin_annotation, data["annotations"])))
     annotation_classes = set([annotation.annotation_class for annotation in annotations])
-
     return dt.AnnotationFile(
         path,
         get_local_filename(data["image"]),