From 7f6757c968432f25e0358124f0926ef6a33bcf8d Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Tue, 1 Oct 2024 10:42:02 +0200 Subject: [PATCH] [datasets] Allow detection task for built-in datasets (#1717) --- Makefile | 2 +- docs/source/using_doctr/using_datasets.rst | 19 +- doctr/datasets/cord.py | 11 +- doctr/datasets/funsd.py | 12 +- doctr/datasets/ic03.py | 12 +- doctr/datasets/ic13.py | 11 +- doctr/datasets/iiit5k.py | 42 ++-- doctr/datasets/imgur5k.py | 11 +- doctr/datasets/sroie.py | 12 +- doctr/datasets/svhn.py | 12 +- doctr/datasets/svt.py | 12 +- doctr/datasets/synthtext.py | 12 +- doctr/datasets/utils.py | 9 +- doctr/datasets/wildreceipt.py | 13 +- references/detection/evaluate_pytorch.py | 14 +- references/detection/evaluate_tensorflow.py | 14 +- tests/pytorch/test_datasets_pt.py | 251 ++++++++++++++++---- tests/tensorflow/test_datasets_tf.py | 242 +++++++++++++++---- 18 files changed, 586 insertions(+), 125 deletions(-) diff --git a/Makefile b/Makefile index 428bc4fc4a..04662b9613 100644 --- a/Makefile +++ b/Makefile @@ -6,8 +6,8 @@ quality: # this target runs checks on all files and potentially modifies some of them style: - ruff check --fix . ruff format . + ruff check --fix . # Run tests for the library test: diff --git a/docs/source/using_doctr/using_datasets.rst b/docs/source/using_doctr/using_datasets.rst index 52c5f7e24d..5fd5dc2776 100644 --- a/docs/source/using_doctr/using_datasets.rst +++ b/docs/source/using_doctr/using_datasets.rst @@ -48,9 +48,9 @@ This datasets contains the information to train or validate a text detection mod from doctr.datasets import CORD # Load straight boxes - train_set = CORD(train=True, download=True) + train_set = CORD(train=True, download=True, detection_task=True) # Load rotated boxes - train_set = CORD(train=True, download=True, use_polygons=True) + train_set = CORD(train=True, download=True, use_polygons=True, detection_task=True) img, target = train_set[0] @@ -99,6 +99,21 @@ This datasets contains the information to train or validate a text recognition m img, target = train_set[0] +OCR +^^^ + +The same dataset table as for detection, but with information about the bounding boxes and labels. + +.. code:: python3 + + from doctr.datasets import CORD + # Load straight boxes + train_set = CORD(train=True, download=True) + # Load rotated boxes + train_set = CORD(train=True, download=True, use_polygons=True) + img, target = train_set[0] + + Object Detection ^^^^^^^^^^^^^^^^ diff --git a/doctr/datasets/cord.py b/doctr/datasets/cord.py index b88fbb28e8..9e2188727d 100644 --- a/doctr/datasets/cord.py +++ b/doctr/datasets/cord.py @@ -33,6 +33,7 @@ class CORD(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. """ @@ -53,6 +54,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: url, sha256, name = self.TRAIN if train else self.TEST @@ -64,10 +66,15 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) # List images tmp_root = os.path.join(self.root, "image") - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] self.train = train np_dtype = np.float32 for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking CORD", total=len(os.listdir(tmp_root))): @@ -109,6 +116,8 @@ def __init__( ) for crop, label in zip(crops, list(text_targets)): self.data.append((crop, label)) + elif detection_task: + self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0))) else: self.data.append(( img_path, diff --git a/doctr/datasets/funsd.py b/doctr/datasets/funsd.py index 0580b473a7..3bd8b088f9 100644 --- a/doctr/datasets/funsd.py +++ b/doctr/datasets/funsd.py @@ -33,6 +33,7 @@ class FUNSD(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. """ @@ -45,6 +46,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -55,6 +57,12 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) + self.train = train np_dtype = np.float32 @@ -63,7 +71,7 @@ def __init__( # # List images tmp_root = os.path.join(self.root, subfolder, "images") - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking FUNSD", total=len(os.listdir(tmp_root))): # File existence check if not os.path.exists(os.path.join(tmp_root, img_path)): @@ -100,6 +108,8 @@ def __init__( # filter labels with unknown characters if not any(char in label for char in ["☑", "☐", "\uf703", "\uf702"]): self.data.append((crop, label)) + elif detection_task: + self.data.append((img_path, np.asarray(box_targets, dtype=np_dtype))) else: self.data.append(( img_path, diff --git a/doctr/datasets/ic03.py b/doctr/datasets/ic03.py index 6f080e4d45..b3af8d958c 100644 --- a/doctr/datasets/ic03.py +++ b/doctr/datasets/ic03.py @@ -32,6 +32,7 @@ class IC03(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. """ @@ -51,6 +52,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: url, sha256, file_name = self.TRAIN if train else self.TEST @@ -62,8 +64,14 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) + self.train = train - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] np_dtype = np.float32 # Load xml data @@ -117,6 +125,8 @@ def __init__( for crop, label in zip(crops, labels): if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0: self.data.append((crop, label)) + elif detection_task: + self.data.append((name.text, boxes)) else: self.data.append((name.text, dict(boxes=boxes, labels=labels))) diff --git a/doctr/datasets/ic13.py b/doctr/datasets/ic13.py index 81ba62f001..0082d92316 100644 --- a/doctr/datasets/ic13.py +++ b/doctr/datasets/ic13.py @@ -38,6 +38,7 @@ class IC13(AbstractDataset): label_folder: folder with all annotation files for the images use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `AbstractDataset`. """ @@ -47,11 +48,17 @@ def __init__( label_folder: str, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: super().__init__( img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) # File existence check if not os.path.exists(label_folder) or not os.path.exists(img_folder): @@ -59,7 +66,7 @@ def __init__( f"unable to locate {label_folder if not os.path.exists(label_folder) else img_folder}" ) - self.data: List[Tuple[Union[Path, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[Path, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] np_dtype = np.float32 img_names = os.listdir(img_folder) @@ -95,5 +102,7 @@ def __init__( crops = crop_bboxes_from_image(img_path=img_path, geoms=box_targets) for crop, label in zip(crops, labels): self.data.append((crop, label)) + elif detection_task: + self.data.append((img_path, box_targets)) else: self.data.append((img_path, dict(boxes=box_targets, labels=labels))) diff --git a/doctr/datasets/iiit5k.py b/doctr/datasets/iiit5k.py index 2b33ebb50b..89619dd8aa 100644 --- a/doctr/datasets/iiit5k.py +++ b/doctr/datasets/iiit5k.py @@ -34,6 +34,7 @@ class IIIT5K(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. """ @@ -45,6 +46,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -55,6 +57,12 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) + self.train = train # Load mat data @@ -62,7 +70,7 @@ def __init__( mat_file = "trainCharBound" if self.train else "testCharBound" mat_data = sio.loadmat(os.path.join(tmp_root, f"{mat_file}.mat"))[mat_file][0] - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] np_dtype = np.float32 for img_path, label, box_targets in tqdm(iterable=mat_data, desc="Unpacking IIIT5K", total=len(mat_data)): @@ -73,24 +81,26 @@ def __init__( if not os.path.exists(os.path.join(tmp_root, _raw_path)): raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, _raw_path)}") + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + box_targets = [ + [ + [box[0], box[1]], + [box[0] + box[2], box[1]], + [box[0] + box[2], box[1] + box[3]], + [box[0], box[1] + box[3]], + ] + for box in box_targets + ] + else: + # xmin, ymin, xmax, ymax + box_targets = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in box_targets] + if recognition_task: self.data.append((_raw_path, _raw_label)) + elif detection_task: + self.data.append((_raw_path, np.asarray(box_targets, dtype=np_dtype))) else: - if use_polygons: - # (x, y) coordinates of top left, top right, bottom right, bottom left corners - box_targets = [ - [ - [box[0], box[1]], - [box[0] + box[2], box[1]], - [box[0] + box[2], box[1] + box[3]], - [box[0], box[1] + box[3]], - ] - for box in box_targets - ] - else: - # xmin, ymin, xmax, ymax - box_targets = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in box_targets] - # label are casted to list where each char corresponds to the character's bounding box self.data.append(( _raw_path, diff --git a/doctr/datasets/imgur5k.py b/doctr/datasets/imgur5k.py index 3e7cf0e07b..4dcfec02b8 100644 --- a/doctr/datasets/imgur5k.py +++ b/doctr/datasets/imgur5k.py @@ -46,6 +46,7 @@ class IMGUR5K(AbstractDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `AbstractDataset`. """ @@ -56,17 +57,23 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: super().__init__( img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) # File existence check if not os.path.exists(label_path) or not os.path.exists(img_folder): raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}") - self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] self.train = train np_dtype = np.float32 @@ -132,6 +139,8 @@ def __init__( tmp_img = Image.fromarray(crop) tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png")) reco_images_counter += 1 + elif detection_task: + self.data.append((img_path, np.asarray(box_targets, dtype=np_dtype))) else: self.data.append((img_path, dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=labels))) diff --git a/doctr/datasets/sroie.py b/doctr/datasets/sroie.py index e72fde68a1..d6e7dac83b 100644 --- a/doctr/datasets/sroie.py +++ b/doctr/datasets/sroie.py @@ -33,6 +33,7 @@ class SROIE(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. """ @@ -52,6 +53,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: url, sha256, name = self.TRAIN if train else self.TEST @@ -63,10 +65,16 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) + self.train = train tmp_root = os.path.join(self.root, "images") - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] np_dtype = np.float32 for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking SROIE", total=len(os.listdir(tmp_root))): @@ -94,6 +102,8 @@ def __init__( for crop, label in zip(crops, labels): if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0: self.data.append((crop, label)) + elif detection_task: + self.data.append((img_path, coords)) else: self.data.append((img_path, dict(boxes=coords, labels=labels))) diff --git a/doctr/datasets/svhn.py b/doctr/datasets/svhn.py index 57085c5213..595113a42d 100644 --- a/doctr/datasets/svhn.py +++ b/doctr/datasets/svhn.py @@ -32,6 +32,7 @@ class SVHN(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. """ @@ -52,6 +53,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: url, sha256, name = self.TRAIN if train else self.TEST @@ -63,8 +65,14 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) + self.train = train - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] np_dtype = np.float32 tmp_root = os.path.join(self.root, "train" if train else "test") @@ -122,6 +130,8 @@ def __init__( for crop, label in zip(crops, label_targets): if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0: self.data.append((crop, label)) + elif detection_task: + self.data.append((img_name, box_targets)) else: self.data.append((img_name, dict(boxes=box_targets, labels=label_targets))) diff --git a/doctr/datasets/svt.py b/doctr/datasets/svt.py index 3eb7b6d599..b9e88b4cc1 100644 --- a/doctr/datasets/svt.py +++ b/doctr/datasets/svt.py @@ -32,6 +32,7 @@ class SVT(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. """ @@ -43,6 +44,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -53,8 +55,14 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) + self.train = train - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] np_dtype = np.float32 # Load xml data @@ -108,6 +116,8 @@ def __init__( for crop, label in zip(crops, labels): if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0: self.data.append((crop, label)) + elif detection_task: + self.data.append((name.text, boxes)) else: self.data.append((name.text, dict(boxes=boxes, labels=labels))) diff --git a/doctr/datasets/synthtext.py b/doctr/datasets/synthtext.py index a60e22e832..8be11e2303 100644 --- a/doctr/datasets/synthtext.py +++ b/doctr/datasets/synthtext.py @@ -35,6 +35,7 @@ class SynthText(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. """ @@ -46,6 +47,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -56,8 +58,14 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) + self.train = train - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] np_dtype = np.float32 # Load mat data @@ -111,6 +119,8 @@ def __init__( tmp_img = Image.fromarray(crop) tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png")) reco_images_counter += 1 + elif detection_task: + self.data.append((img_path[0], np.asarray(word_boxes, dtype=np_dtype))) else: self.data.append((img_path[0], dict(boxes=np.asarray(word_boxes, dtype=np_dtype), labels=labels))) diff --git a/doctr/datasets/utils.py b/doctr/datasets/utils.py index 860e19a229..75182a227a 100644 --- a/doctr/datasets/utils.py +++ b/doctr/datasets/utils.py @@ -169,8 +169,13 @@ def encode_sequences( return encoded_data -def convert_target_to_relative(img: ImageTensor, target: Dict[str, Any]) -> Tuple[ImageTensor, Dict[str, Any]]: - target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img)) +def convert_target_to_relative( + img: ImageTensor, target: Union[np.ndarray, Dict[str, Any]] +) -> Tuple[ImageTensor, Union[Dict[str, Any], np.ndarray]]: + if isinstance(target, np.ndarray): + target = convert_to_relative_coords(target, get_img_shape(img)) + else: + target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img)) return img, target diff --git a/doctr/datasets/wildreceipt.py b/doctr/datasets/wildreceipt.py index 19108d7761..685266931a 100644 --- a/doctr/datasets/wildreceipt.py +++ b/doctr/datasets/wildreceipt.py @@ -40,6 +40,7 @@ class WILDRECEIPT(AbstractDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `AbstractDataset`. """ @@ -50,11 +51,19 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: super().__init__( img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs ) + # Task check + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) + # File existence check if not os.path.exists(label_path) or not os.path.exists(img_folder): raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}") @@ -62,7 +71,7 @@ def __init__( tmp_root = img_folder self.train = train np_dtype = np.float32 - self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] with open(label_path, "r") as file: data = file.read() @@ -100,6 +109,8 @@ def __init__( for crop, label in zip(crops, list(text_targets)): if label and " " not in label: self.data.append((crop, label)) + elif detection_task: + self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0))) else: self.data.append(( img_path, diff --git a/references/detection/evaluate_pytorch.py b/references/detection/evaluate_pytorch.py index 15f60df664..10b20e40cc 100644 --- a/references/detection/evaluate_pytorch.py +++ b/references/detection/evaluate_pytorch.py @@ -37,7 +37,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): if torch.cuda.is_available(): images = images.cuda() images = batch_transforms(images) - targets = [{CLASS_NAME: t["boxes"]} for t in targets] + targets = [{CLASS_NAME: t} for t in targets] if amp: with torch.cuda.amp.autocast(): out = model(images, targets, return_preds=True) @@ -82,7 +82,10 @@ def main(args): train=True, download=True, use_polygons=args.rotation, - sample_transforms=T.Resize(input_shape), + detection_task=True, + sample_transforms=T.Resize( + input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad + ), ) # Monkeypatch subfolder = ds.root.split("/")[-2:] @@ -92,7 +95,10 @@ def main(args): train=False, download=True, use_polygons=args.rotation, - sample_transforms=T.Resize(input_shape), + detection_task=True, + sample_transforms=T.Resize( + input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad + ), ) subfolder = _ds.root.split("/")[-2:] ds.data.extend([(os.path.join(*subfolder, name), target) for name, target in _ds.data]) @@ -155,6 +161,8 @@ def parse_args(): parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for evaluation") parser.add_argument("--device", default=None, type=int, help="device") parser.add_argument("--size", type=int, default=None, help="model input size, H = W") + parser.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image") + parser.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically") parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading") parser.add_argument("--rotation", dest="rotation", action="store_true", help="inference with rotated bbox") parser.add_argument("--resume", type=str, default=None, help="Checkpoint to resume") diff --git a/references/detection/evaluate_tensorflow.py b/references/detection/evaluate_tensorflow.py index abf012ed83..4eef9a40b7 100644 --- a/references/detection/evaluate_tensorflow.py +++ b/references/detection/evaluate_tensorflow.py @@ -35,7 +35,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric): val_loss, batch_cnt = 0, 0 for images, targets in tqdm(val_loader): images = batch_transforms(images) - targets = [{CLASS_NAME: t["boxes"]} for t in targets] + targets = [{CLASS_NAME: t} for t in targets] out = model(images, targets, training=False, return_preds=True) # Compute metric loc_preds = out["preds"] @@ -81,7 +81,10 @@ def main(args): train=True, download=True, use_polygons=args.rotation, - sample_transforms=T.Resize(input_shape[:2]), + detection_task=True, + sample_transforms=T.Resize( + input_shape[:2], preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad + ), ) # Monkeypatch subfolder = ds.root.split("/")[-2:] @@ -91,7 +94,10 @@ def main(args): train=False, download=True, use_polygons=args.rotation, - sample_transforms=T.Resize(input_shape[:2]), + detection_task=True, + sample_transforms=T.Resize( + input_shape[:2], preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad + ), ) subfolder = _ds.root.split("/")[-2:] ds.data.extend([(os.path.join(*subfolder, name), target) for name, target in _ds.data]) @@ -129,6 +135,8 @@ def parse_args(): parser.add_argument("--dataset", type=str, default="FUNSD", help="Dataset to evaluate on") parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for evaluation") parser.add_argument("--size", type=int, default=None, help="model input size, H = W") + parser.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image") + parser.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically") parser.add_argument("--rotation", dest="rotation", action="store_true", help="inference with rotated bbox") parser.add_argument("--resume", type=str, default=None, help="Checkpoint to resume") parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py index 749a86bf06..30f9e6f288 100644 --- a/tests/pytorch/test_datasets_pt.py +++ b/tests/pytorch/test_datasets_pt.py @@ -72,6 +72,36 @@ def _validate_dataset_recognition_part(ds, input_size, batch_size=2): assert isinstance(labels, list) and all(isinstance(elt, str) for elt in labels) +def _validate_dataset_detection_part(ds, input_size, batch_size=2, is_polygons=False): + # Fetch one sample + img, target = ds[0] + + assert isinstance(img, torch.Tensor) + assert img.shape == (3, *input_size) + assert img.dtype == torch.float32 + assert isinstance(target, np.ndarray) and target.dtype == np.float32 + if is_polygons: + assert target.ndim == 3 and target.shape[1:] == (4, 2) + else: + assert target.ndim == 2 and target.shape[1:] == (4,) + assert np.all(np.logical_and(target <= 1, target >= 0)) + + # Check batching + loader = DataLoader( + ds, + batch_size=batch_size, + drop_last=True, + sampler=RandomSampler(ds), + num_workers=0, + pin_memory=True, + collate_fn=ds.collate_fn, + ) + + images, targets = next(iter(loader)) + assert isinstance(images, torch.Tensor) and images.shape == (batch_size, 3, *input_size) + assert isinstance(targets, list) and all(isinstance(elt, np.ndarray) for elt in targets) + + def test_visiondataset(): url = "https://github.com/mindee/doctr/releases/download/v0.6.0/mnist.zip" with pytest.raises(ValueError): @@ -282,13 +312,14 @@ def test_artefact_detection(input_size, num_samples, rotate, mock_doc_artefacts) @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 626 training samples and 360 test samples - [[32, 128], 15, True], # recognition + [[512, 512], 3, False, False], # Actual set has 626 training samples and 360 test samples + [[32, 128], 15, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_sroie(input_size, num_samples, rotate, recognition, mock_sroie_dataset): +def test_sroie(input_size, num_samples, rotate, recognition, detection, mock_sroie_dataset): # monkeypatch the path to temporary dataset datasets.SROIE.TRAIN = (mock_sroie_dataset, None, "sroie2019_train_task1.zip") @@ -298,6 +329,7 @@ def test_sroie(input_size, num_samples, rotate, recognition, mock_sroie_dataset) img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_sroie_dataset.split("/")[:-2]), cache_subdir=mock_sroie_dataset.split("/")[-2], ) @@ -306,67 +338,94 @@ def test_sroie(input_size, num_samples, rotate, recognition, mock_sroie_dataset) assert repr(ds) == f"SROIE(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SROIE( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_sroie_dataset.split("/")[:-2]), + cache_subdir=mock_sroie_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 5, False], # Actual set has 229 train and 233 test samples - [[32, 128], 25, True], # recognition + [[512, 512], 5, False, False], # Actual set has 229 train and 233 test samples + [[32, 128], 25, True, False], # recognition + [[512, 512], 5, False, True], # detection ], ) -def test_ic13_dataset(input_size, num_samples, rotate, recognition, mock_ic13): +def test_ic13_dataset(input_size, num_samples, rotate, recognition, detection, mock_ic13): ds = datasets.IC13( *mock_ic13, img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, ) assert len(ds) == num_samples if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IC13(*mock_ic13, recognition_task=True, detection_task=True) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 7149 train and 796 test samples - [[32, 128], 5, True], # recognition + [[512, 512], 3, False, False], # Actual set has 7149 train and 796 test samples + [[32, 128], 5, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_imgur5k_dataset(input_size, num_samples, rotate, recognition, mock_imgur5k): +def test_imgur5k_dataset(input_size, num_samples, rotate, recognition, detection, mock_imgur5k): ds = datasets.IMGUR5K( *mock_imgur5k, train=True, img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, ) assert len(ds) == num_samples - 1 # -1 because of the test set 90 / 10 split assert repr(ds) == f"IMGUR5K(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IMGUR5K(*mock_imgur5k, train=True, recognition_task=True, detection_task=True) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[32, 128], 3, False], # Actual set has 33402 training samples and 13068 test samples - [[32, 128], 12, True], # recognition + [[32, 128], 3, False, False], # Actual set has 33402 training samples and 13068 test samples + [[32, 128], 12, True, False], # recognition + [[32, 128], 3, False, True], # detection ], ) -def test_svhn(input_size, num_samples, rotate, recognition, mock_svhn_dataset): +def test_svhn(input_size, num_samples, rotate, recognition, detection, mock_svhn_dataset): # monkeypatch the path to temporary dataset datasets.SVHN.TRAIN = (mock_svhn_dataset, None, "svhn_train.tar") @@ -376,6 +435,7 @@ def test_svhn(input_size, num_samples, rotate, recognition, mock_svhn_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_svhn_dataset.split("/")[:-2]), cache_subdir=mock_svhn_dataset.split("/")[-2], ) @@ -384,19 +444,32 @@ def test_svhn(input_size, num_samples, rotate, recognition, mock_svhn_dataset): assert repr(ds) == f"SVHN(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SVHN( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_svhn_dataset.split("/")[:-2]), + cache_subdir=mock_svhn_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 149 training samples and 50 test samples - [[32, 128], 9, True], # recognition + [[512, 512], 3, False, False], # Actual set has 149 training samples and 50 test samples + [[32, 128], 9, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_funsd(input_size, num_samples, rotate, recognition, mock_funsd_dataset): +def test_funsd(input_size, num_samples, rotate, recognition, detection, mock_funsd_dataset): # monkeypatch the path to temporary dataset datasets.FUNSD.URL = mock_funsd_dataset datasets.FUNSD.SHA256 = None @@ -408,6 +481,7 @@ def test_funsd(input_size, num_samples, rotate, recognition, mock_funsd_dataset) img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_funsd_dataset.split("/")[:-2]), cache_subdir=mock_funsd_dataset.split("/")[-2], ) @@ -416,19 +490,32 @@ def test_funsd(input_size, num_samples, rotate, recognition, mock_funsd_dataset) assert repr(ds) == f"FUNSD(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.FUNSD( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_funsd_dataset.split("/")[:-2]), + cache_subdir=mock_funsd_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 800 training samples and 100 test samples - [[32, 128], 9, True], # recognition + [[512, 512], 3, False, False], # Actual set has 800 training samples and 100 test samples + [[32, 128], 9, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_cord(input_size, num_samples, rotate, recognition, mock_cord_dataset): +def test_cord(input_size, num_samples, rotate, recognition, detection, mock_cord_dataset): # monkeypatch the path to temporary dataset datasets.CORD.TRAIN = (mock_cord_dataset, None, "cord_train.zip") @@ -438,6 +525,7 @@ def test_cord(input_size, num_samples, rotate, recognition, mock_cord_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_cord_dataset.split("/")[:-2]), cache_subdir=mock_cord_dataset.split("/")[-2], ) @@ -446,19 +534,32 @@ def test_cord(input_size, num_samples, rotate, recognition, mock_cord_dataset): assert repr(ds) == f"CORD(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.CORD( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_cord_dataset.split("/")[:-2]), + cache_subdir=mock_cord_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 2, False], # Actual set has 772875 training samples and 85875 test samples - [[32, 128], 10, True], # recognition + [[512, 512], 2, False, False], # Actual set has 772875 training samples and 85875 test samples + [[32, 128], 10, True, False], # recognition + [[512, 512], 2, False, True], # detection ], ) -def test_synthtext(input_size, num_samples, rotate, recognition, mock_synthtext_dataset): +def test_synthtext(input_size, num_samples, rotate, recognition, detection, mock_synthtext_dataset): # monkeypatch the path to temporary dataset datasets.SynthText.URL = mock_synthtext_dataset datasets.SynthText.SHA256 = None @@ -469,6 +570,7 @@ def test_synthtext(input_size, num_samples, rotate, recognition, mock_synthtext_ img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_synthtext_dataset.split("/")[:-2]), cache_subdir=mock_synthtext_dataset.split("/")[-2], ) @@ -477,19 +579,32 @@ def test_synthtext(input_size, num_samples, rotate, recognition, mock_synthtext_ assert repr(ds) == f"SynthText(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SynthText( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_synthtext_dataset.split("/")[:-2]), + cache_subdir=mock_synthtext_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[32, 128], 1, False], # Actual set has 2000 training samples and 3000 test samples - [[32, 128], 1, True], # recognition + [[32, 128], 1, False, False], # Actual set has 2000 training samples and 3000 test samples + [[32, 128], 1, True, False], # recognition + [[32, 128], 1, False, True], # detection ], ) -def test_iiit5k(input_size, num_samples, rotate, recognition, mock_iiit5k_dataset): +def test_iiit5k(input_size, num_samples, rotate, recognition, detection, mock_iiit5k_dataset): # monkeypatch the path to temporary dataset datasets.IIIT5K.URL = mock_iiit5k_dataset datasets.IIIT5K.SHA256 = None @@ -500,6 +615,7 @@ def test_iiit5k(input_size, num_samples, rotate, recognition, mock_iiit5k_datase img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_iiit5k_dataset.split("/")[:-2]), cache_subdir=mock_iiit5k_dataset.split("/")[-2], ) @@ -508,19 +624,32 @@ def test_iiit5k(input_size, num_samples, rotate, recognition, mock_iiit5k_datase assert repr(ds) == f"IIIT5K(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size, batch_size=1) + elif detection: + _validate_dataset_detection_part(ds, input_size, batch_size=1, is_polygons=rotate) else: _validate_dataset(ds, input_size, batch_size=1, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IIIT5K( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_iiit5k_dataset.split("/")[:-2]), + cache_subdir=mock_iiit5k_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 100 training samples and 249 test samples - [[32, 128], 3, True], # recognition + [[512, 512], 3, False, False], # Actual set has 100 training samples and 249 test samples + [[32, 128], 3, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_svt(input_size, num_samples, rotate, recognition, mock_svt_dataset): +def test_svt(input_size, num_samples, rotate, recognition, detection, mock_svt_dataset): # monkeypatch the path to temporary dataset datasets.SVT.URL = mock_svt_dataset datasets.SVT.SHA256 = None @@ -531,6 +660,7 @@ def test_svt(input_size, num_samples, rotate, recognition, mock_svt_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_svt_dataset.split("/")[:-2]), cache_subdir=mock_svt_dataset.split("/")[-2], ) @@ -539,19 +669,32 @@ def test_svt(input_size, num_samples, rotate, recognition, mock_svt_dataset): assert repr(ds) == f"SVT(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SVT( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_svt_dataset.split("/")[:-2]), + cache_subdir=mock_svt_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 246 training samples and 249 test samples - [[32, 128], 3, True], # recognition + [[512, 512], 3, False, False], # Actual set has 246 training samples and 249 test samples + [[32, 128], 3, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): +def test_ic03(input_size, num_samples, rotate, recognition, detection, mock_ic03_dataset): # monkeypatch the path to temporary dataset datasets.IC03.TRAIN = (mock_ic03_dataset, None, "ic03_train.zip") @@ -561,6 +704,7 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_ic03_dataset.split("/")[:-2]), cache_subdir=mock_ic03_dataset.split("/")[-2], ) @@ -569,33 +713,52 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): assert repr(ds) == f"IC03(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IC03( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_ic03_dataset.split("/")[:-2]), + cache_subdir=mock_ic03_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 2, False], - [[32, 128], 5, True], + [[512, 512], 2, False, False], # Actual set has 1268 training samples and 472 test samples + [[32, 128], 5, True, False], # recognition + [[512, 512], 2, False, True], # detection ], ) -def test_wildreceipt_dataset(input_size, num_samples, rotate, recognition, mock_wildreceipt_dataset): +def test_wildreceipt_dataset(input_size, num_samples, rotate, recognition, detection, mock_wildreceipt_dataset): ds = datasets.WILDRECEIPT( *mock_wildreceipt_dataset, train=True, img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, ) assert len(ds) == num_samples assert repr(ds) == f"WILDRECEIPT(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.WILDRECEIPT(*mock_wildreceipt_dataset, train=True, recognition_task=True, detection_task=True) + # NOTE: following datasets are only for recognition task diff --git a/tests/tensorflow/test_datasets_tf.py b/tests/tensorflow/test_datasets_tf.py index 5d6c61b116..1129b4264e 100644 --- a/tests/tensorflow/test_datasets_tf.py +++ b/tests/tensorflow/test_datasets_tf.py @@ -54,6 +54,27 @@ def _validate_dataset_recognition_part(ds, input_size, batch_size=2): assert isinstance(labels, list) and all(isinstance(elt, str) for elt in labels) +def _validate_dataset_detection_part(ds, input_size, is_polygons=False, batch_size=2): + # Fetch one sample + img, target = ds[0] + assert isinstance(img, tf.Tensor) + assert img.shape == (*input_size, 3) + assert img.dtype == tf.float32 + assert isinstance(target, np.ndarray) and target.dtype == np.float32 + if is_polygons: + assert target.ndim == 3 and target.shape[1:] == (4, 2) + else: + assert target.ndim == 2 and target.shape[1:] == (4,) + assert np.all(np.logical_and(target <= 1, target >= 0)) + + # Check batching + loader = DataLoader(ds, batch_size=batch_size) + + images, targets = next(iter(loader)) + assert isinstance(images, tf.Tensor) and images.shape == (batch_size, *input_size, 3) + assert isinstance(targets, list) and all(isinstance(elt, np.ndarray) for elt in targets) + + def test_visiondataset(): url = "https://github.com/mindee/doctr/releases/download/v0.6.0/mnist.zip" with pytest.raises(ValueError): @@ -264,13 +285,14 @@ def test_artefact_detection(input_size, num_samples, rotate, mock_doc_artefacts) @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 626 training samples and 360 test samples - [[32, 128], 15, True], # recognition + [[512, 512], 3, False, False], # Actual set has 626 training samples and 360 test samples + [[32, 128], 15, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_sroie(input_size, num_samples, rotate, recognition, mock_sroie_dataset): +def test_sroie(input_size, num_samples, rotate, recognition, detection, mock_sroie_dataset): # monkeypatch the path to temporary dataset datasets.SROIE.TRAIN = (mock_sroie_dataset, None, "sroie2019_train_task1.zip") @@ -280,6 +302,7 @@ def test_sroie(input_size, num_samples, rotate, recognition, mock_sroie_dataset) img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_sroie_dataset.split("/")[:-2]), cache_subdir=mock_sroie_dataset.split("/")[-2], ) @@ -288,67 +311,94 @@ def test_sroie(input_size, num_samples, rotate, recognition, mock_sroie_dataset) assert repr(ds) == f"SROIE(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SROIE( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_sroie_dataset.split("/")[:-2]), + cache_subdir=mock_sroie_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 5, False], # Actual set has 229 train and 233 test samples - [[32, 128], 25, True], # recognition + [[512, 512], 5, False, False], # Actual set has 229 train and 233 test samples + [[32, 128], 25, True, False], # recognition + [[512, 512], 5, False, True], # detection ], ) -def test_ic13_dataset(input_size, num_samples, rotate, recognition, mock_ic13): +def test_ic13_dataset(input_size, num_samples, rotate, recognition, detection, mock_ic13): ds = datasets.IC13( *mock_ic13, img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, ) assert len(ds) == num_samples if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IC13(*mock_ic13, recognition_task=True, detection_task=True) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 7149 train and 796 test samples - [[32, 128], 5, True], # recognition + [[512, 512], 3, False, False], # Actual set has 7149 train and 796 test samples + [[32, 128], 5, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_imgur5k_dataset(input_size, num_samples, rotate, recognition, mock_imgur5k): +def test_imgur5k_dataset(input_size, num_samples, rotate, recognition, detection, mock_imgur5k): ds = datasets.IMGUR5K( *mock_imgur5k, train=True, img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, ) assert len(ds) == num_samples - 1 # -1 because of the test set 90 / 10 split assert repr(ds) == f"IMGUR5K(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IMGUR5K(*mock_imgur5k, train=True, recognition_task=True, detection_task=True) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[32, 128], 3, False], # Actual set has 33402 training samples and 13068 test samples - [[32, 128], 12, True], # recognition + [[32, 128], 3, False, False], # Actual set has 33402 training samples and 13068 test samples + [[32, 128], 12, True, False], # recognition + [[32, 128], 3, False, True], # detection ], ) -def test_svhn(input_size, num_samples, rotate, recognition, mock_svhn_dataset): +def test_svhn(input_size, num_samples, rotate, recognition, detection, mock_svhn_dataset): # monkeypatch the path to temporary dataset datasets.SVHN.TRAIN = (mock_svhn_dataset, None, "svhn_train.tar") @@ -358,6 +408,7 @@ def test_svhn(input_size, num_samples, rotate, recognition, mock_svhn_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_svhn_dataset.split("/")[:-2]), cache_subdir=mock_svhn_dataset.split("/")[-2], ) @@ -366,19 +417,32 @@ def test_svhn(input_size, num_samples, rotate, recognition, mock_svhn_dataset): assert repr(ds) == f"SVHN(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SVHN( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_svhn_dataset.split("/")[:-2]), + cache_subdir=mock_svhn_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 149 training samples and 50 test samples - [[32, 128], 9, True], # recognition + [[512, 512], 3, False, False], # Actual set has 149 training samples and 50 test samples + [[32, 128], 9, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_funsd(input_size, num_samples, rotate, recognition, mock_funsd_dataset): +def test_funsd(input_size, num_samples, rotate, recognition, detection, mock_funsd_dataset): # monkeypatch the path to temporary dataset datasets.FUNSD.URL = mock_funsd_dataset datasets.FUNSD.SHA256 = None @@ -390,6 +454,7 @@ def test_funsd(input_size, num_samples, rotate, recognition, mock_funsd_dataset) img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_funsd_dataset.split("/")[:-2]), cache_subdir=mock_funsd_dataset.split("/")[-2], ) @@ -398,19 +463,32 @@ def test_funsd(input_size, num_samples, rotate, recognition, mock_funsd_dataset) assert repr(ds) == f"FUNSD(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.FUNSD( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_funsd_dataset.split("/")[:-2]), + cache_subdir=mock_funsd_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 800 training samples and 100 test samples - [[32, 128], 9, True], # recognition + [[512, 512], 3, False, False], # Actual set has 800 training samples and 100 test samples + [[32, 128], 9, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_cord(input_size, num_samples, rotate, recognition, mock_cord_dataset): +def test_cord(input_size, num_samples, rotate, recognition, detection, mock_cord_dataset): # monkeypatch the path to temporary dataset datasets.CORD.TRAIN = (mock_cord_dataset, None, "cord_train.zip") @@ -420,6 +498,7 @@ def test_cord(input_size, num_samples, rotate, recognition, mock_cord_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_cord_dataset.split("/")[:-2]), cache_subdir=mock_cord_dataset.split("/")[-2], ) @@ -428,19 +507,32 @@ def test_cord(input_size, num_samples, rotate, recognition, mock_cord_dataset): assert repr(ds) == f"CORD(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.CORD( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_cord_dataset.split("/")[:-2]), + cache_subdir=mock_cord_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 2, False], # Actual set has 772875 training samples and 85875 test samples - [[32, 128], 10, True], # recognition + [[512, 512], 2, False, False], # Actual set has 772875 training samples and 85875 test samples + [[32, 128], 10, True, False], # recognition + [[512, 512], 2, False, True], # detection ], ) -def test_synthtext(input_size, num_samples, rotate, recognition, mock_synthtext_dataset): +def test_synthtext(input_size, num_samples, rotate, recognition, detection, mock_synthtext_dataset): # monkeypatch the path to temporary dataset datasets.SynthText.URL = mock_synthtext_dataset datasets.SynthText.SHA256 = None @@ -451,6 +543,7 @@ def test_synthtext(input_size, num_samples, rotate, recognition, mock_synthtext_ img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_synthtext_dataset.split("/")[:-2]), cache_subdir=mock_synthtext_dataset.split("/")[-2], ) @@ -459,19 +552,32 @@ def test_synthtext(input_size, num_samples, rotate, recognition, mock_synthtext_ assert repr(ds) == f"SynthText(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SynthText( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_synthtext_dataset.split("/")[:-2]), + cache_subdir=mock_synthtext_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[32, 128], 1, False], # Actual set has 2000 training samples and 3000 test samples - [[32, 128], 1, True], # recognition + [[32, 128], 1, False, False], # Actual set has 2000 training samples and 3000 test samples + [[32, 128], 1, True, False], # recognition + [[32, 128], 1, False, True], # detection ], ) -def test_iiit5k(input_size, num_samples, rotate, recognition, mock_iiit5k_dataset): +def test_iiit5k(input_size, num_samples, rotate, recognition, detection, mock_iiit5k_dataset): # monkeypatch the path to temporary dataset datasets.IIIT5K.URL = mock_iiit5k_dataset datasets.IIIT5K.SHA256 = None @@ -482,6 +588,7 @@ def test_iiit5k(input_size, num_samples, rotate, recognition, mock_iiit5k_datase img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_iiit5k_dataset.split("/")[:-2]), cache_subdir=mock_iiit5k_dataset.split("/")[-2], ) @@ -490,19 +597,32 @@ def test_iiit5k(input_size, num_samples, rotate, recognition, mock_iiit5k_datase assert repr(ds) == f"IIIT5K(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size, batch_size=1) + elif detection: + _validate_dataset_detection_part(ds, input_size, batch_size=1, is_polygons=rotate) else: _validate_dataset(ds, input_size, batch_size=1, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IIIT5K( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_iiit5k_dataset.split("/")[:-2]), + cache_subdir=mock_iiit5k_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 100 training samples and 249 test samples - [[32, 128], 3, True], # recognition + [[512, 512], 3, False, False], # Actual set has 100 training samples and 249 test samples + [[32, 128], 3, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_svt(input_size, num_samples, rotate, recognition, mock_svt_dataset): +def test_svt(input_size, num_samples, rotate, recognition, detection, mock_svt_dataset): # monkeypatch the path to temporary dataset datasets.SVT.URL = mock_svt_dataset datasets.SVT.SHA256 = None @@ -513,6 +633,7 @@ def test_svt(input_size, num_samples, rotate, recognition, mock_svt_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_svt_dataset.split("/")[:-2]), cache_subdir=mock_svt_dataset.split("/")[-2], ) @@ -521,19 +642,32 @@ def test_svt(input_size, num_samples, rotate, recognition, mock_svt_dataset): assert repr(ds) == f"SVT(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SVT( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_svt_dataset.split("/")[:-2]), + cache_subdir=mock_svt_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 246 training samples and 249 test samples - [[32, 128], 3, True], # recognition + [[512, 512], 3, False, False], # Actual set has 246 training samples and 249 test samples + [[32, 128], 3, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): +def test_ic03(input_size, num_samples, rotate, recognition, detection, mock_ic03_dataset): # monkeypatch the path to temporary dataset datasets.IC03.TRAIN = (mock_ic03_dataset, None, "ic03_train.zip") @@ -543,6 +677,7 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_ic03_dataset.split("/")[:-2]), cache_subdir=mock_ic03_dataset.split("/")[-2], ) @@ -551,33 +686,52 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): assert repr(ds) == f"IC03(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IC03( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_ic03_dataset.split("/")[:-2]), + cache_subdir=mock_ic03_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 2, False], - [[32, 128], 5, True], + [[512, 512], 2, False, False], # Actual set has 1268 training samples and 472 test samples + [[32, 128], 5, True, False], # recognition + [[512, 512], 2, False, True], # detection ], ) -def test_wildreceipt_dataset(input_size, num_samples, rotate, recognition, mock_wildreceipt_dataset): +def test_wildreceipt_dataset(input_size, num_samples, rotate, recognition, detection, mock_wildreceipt_dataset): ds = datasets.WILDRECEIPT( *mock_wildreceipt_dataset, train=True, img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, ) assert len(ds) == num_samples assert repr(ds) == f"WILDRECEIPT(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.WILDRECEIPT(*mock_wildreceipt_dataset, train=True, recognition_task=True, detection_task=True) + # NOTE: following datasets are only for recognition task