Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] Fix eval scripts + possible overflow in Resize #1715

Merged
merged 7 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/app/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

Expand Down
13 changes: 7 additions & 6 deletions doctr/transforms/modules/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,33 +74,34 @@
if self.symmetric_pad:
half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2))
_pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1])
# Pad image
img = pad(img, _pad)

# In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio)
if target is not None:
if self.symmetric_pad:
offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]

Check warning on line 83 in doctr/transforms/modules/pytorch.py

View check run for this annotation

Codecov / codecov/patch

doctr/transforms/modules/pytorch.py#L83

Added line #L83 was not covered by tests

if self.preserve_aspect_ratio:
# Get absolute coords
if target.shape[1:] == (4,):
if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
if np.max(target) <= 1:
offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1]
target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2]
else:
target[:, [0, 2]] *= raw_shape[-1] / img.shape[-1]
target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2]
elif target.shape[1:] == (4, 2):
if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
if np.max(target) <= 1:
offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1]
target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2]
else:
target[..., 0] *= raw_shape[-1] / img.shape[-1]
target[..., 1] *= raw_shape[-2] / img.shape[-2]
else:
raise AssertionError
return img, target
raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)")

Check warning on line 102 in doctr/transforms/modules/pytorch.py

View check run for this annotation

Codecov / codecov/patch

doctr/transforms/modules/pytorch.py#L102

Added line #L102 was not covered by tests

return img, np.clip(target, 0, 1)
odulcy-mindee marked this conversation as resolved.
Show resolved Hide resolved

return img

Expand Down
26 changes: 15 additions & 11 deletions doctr/transforms/modules/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,46 +107,50 @@
target: Optional[np.ndarray] = None,
) -> Union[tf.Tensor, Tuple[tf.Tensor, np.ndarray]]:
input_dtype = img.dtype
self.output_size = (
(self.output_size, self.output_size) if isinstance(self.output_size, int) else self.output_size
)

img = tf.image.resize(img, self.wanted_size, self.method, self.preserve_aspect_ratio, self.antialias)
# It will produce an un-padded resized image, with a side shorter than wanted if we preserve aspect ratio
raw_shape = img.shape[:2]
if self.symmetric_pad:
half_pad = (int((self.output_size[0] - img.shape[0]) / 2), 0)
if self.preserve_aspect_ratio:
if isinstance(self.output_size, (tuple, list)):
# In that case we need to pad because we want to enforce both width and height
if not self.symmetric_pad:
offset = (0, 0)
half_pad = (0, 0)
elif self.output_size[0] == img.shape[0]:
offset = (0, int((self.output_size[1] - img.shape[1]) / 2))
else:
offset = (int((self.output_size[0] - img.shape[0]) / 2), 0)
img = tf.image.pad_to_bounding_box(img, *offset, *self.output_size)
half_pad = (0, int((self.output_size[1] - img.shape[1]) / 2))
# Pad image
img = tf.image.pad_to_bounding_box(img, *half_pad, *self.output_size)

# In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio)
if target is not None:
if self.symmetric_pad:
offset = half_pad[0] / img.shape[0], half_pad[1] / img.shape[1]

Check warning on line 132 in doctr/transforms/modules/tensorflow.py

View check run for this annotation

Codecov / codecov/patch

doctr/transforms/modules/tensorflow.py#L132

Added line #L132 was not covered by tests

if self.preserve_aspect_ratio:
# Get absolute coords
if target.shape[1:] == (4,):
if isinstance(self.output_size, (tuple, list)) and self.symmetric_pad:
if np.max(target) <= 1:
offset = offset[0] / img.shape[0], offset[1] / img.shape[1]
target[:, [0, 2]] = offset[1] + target[:, [0, 2]] * raw_shape[1] / img.shape[1]
target[:, [1, 3]] = offset[0] + target[:, [1, 3]] * raw_shape[0] / img.shape[0]
else:
target[:, [0, 2]] *= raw_shape[1] / img.shape[1]
target[:, [1, 3]] *= raw_shape[0] / img.shape[0]
elif target.shape[1:] == (4, 2):
if isinstance(self.output_size, (tuple, list)) and self.symmetric_pad:
if np.max(target) <= 1:
offset = offset[0] / img.shape[0], offset[1] / img.shape[1]
target[..., 0] = offset[1] + target[..., 0] * raw_shape[1] / img.shape[1]
target[..., 1] = offset[0] + target[..., 1] * raw_shape[0] / img.shape[0]
else:
target[..., 0] *= raw_shape[1] / img.shape[1]
target[..., 1] *= raw_shape[0] / img.shape[0]
else:
raise AssertionError
return tf.cast(img, dtype=input_dtype), target
raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)")

Check warning on line 151 in doctr/transforms/modules/tensorflow.py

View check run for this annotation

Codecov / codecov/patch

doctr/transforms/modules/tensorflow.py#L151

Added line #L151 was not covered by tests

return tf.cast(img, dtype=input_dtype), np.clip(target, 0, 1)

return tf.cast(img, dtype=input_dtype)

Expand Down
2 changes: 1 addition & 1 deletion references/classification/latency_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

def main(args):
if args.gpu:
gpu_devices = tf.config.experimental.list_physical_devices("GPU")
gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)
else:
Expand Down
2 changes: 1 addition & 1 deletion references/classification/train_tensorflow_character.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from doctr.models import login_to_hub, push_to_hf_hub

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

Expand Down
2 changes: 1 addition & 1 deletion references/classification/train_tensorflow_orientation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from doctr.models import login_to_hub, push_to_hf_hub

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

Expand Down
2 changes: 1 addition & 1 deletion references/detection/evaluate_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from tensorflow.keras import mixed_precision
from tqdm import tqdm

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

Expand Down
2 changes: 1 addition & 1 deletion references/detection/latency_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

def main(args):
if args.gpu:
gpu_devices = tf.config.experimental.list_physical_devices("GPU")
gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)
else:
Expand Down
2 changes: 1 addition & 1 deletion references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from doctr.models import login_to_hub, push_to_hf_hub

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

Expand Down
2 changes: 1 addition & 1 deletion references/recognition/evaluate_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from tensorflow.keras import mixed_precision
from tqdm import tqdm

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

Expand Down
2 changes: 1 addition & 1 deletion references/recognition/latency_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

def main(args):
if args.gpu:
gpu_devices = tf.config.experimental.list_physical_devices("GPU")
gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)
else:
Expand Down
2 changes: 1 addition & 1 deletion references/recognition/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from doctr.models import login_to_hub, push_to_hf_hub

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

Expand Down
2 changes: 1 addition & 1 deletion scripts/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
if is_tf_available():
import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

Expand Down
2 changes: 1 addition & 1 deletion scripts/detect_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
if is_tf_available():
import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

Expand Down
35 changes: 31 additions & 4 deletions scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tqdm import tqdm

from doctr import datasets
from doctr import transforms as T
from doctr.file_utils import is_tf_available
from doctr.models import ocr_predictor
from doctr.utils.geometry import extract_crops, extract_rcrops
Expand All @@ -20,7 +21,7 @@
if is_tf_available():
import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)
else:
Expand All @@ -35,24 +36,47 @@ def main(args):
if not args.rotation:
args.eval_straight = True

input_shape = (args.size, args.size)

# Define a sample transform that pulls the boxes out of the annotation dict,
# because the Resize transformation is called with (img, boxes) rather than (img, dict)
def _transform(img, target):
    """Resize one sample and its boxes together.

    Args:
        img: the sample image, passed as-is to ``T.Resize``.
        target: annotation dict; its ``"boxes"`` entry is resized along with
            the image and its ``"labels"`` entry is carried over unchanged.

    Returns:
        A ``(resized_img, annotation)`` pair, where ``annotation`` holds the
        resized ``"boxes"`` and the original ``"labels"``.
    """
    boxes = target["boxes"]
    # Resize options come from the CLI flags (--keep_ratio / --symmetric_pad)
    resize = T.Resize(
        input_shape,
        preserve_aspect_ratio=args.keep_ratio,
        symmetric_pad=args.symmetric_pad,
    )
    resized_img, resized_boxes = resize(img, boxes)
    annotation = {"boxes": resized_boxes, "labels": target["labels"]}
    return resized_img, annotation

predictor = ocr_predictor(
args.detection,
args.recognition,
pretrained=True,
reco_bs=args.batch_size,
preserve_aspect_ratio=False,
preserve_aspect_ratio=False, # we handle the transformation directly in the dataset so this is set to False
symmetric_pad=False, # we handle the transformation directly in the dataset so this is set to False
assume_straight_pages=not args.rotation,
)

if args.img_folder and args.label_file:
testset = datasets.OCRDataset(
img_folder=args.img_folder,
label_file=args.label_file,
sample_transforms=_transform,
)
sets = [testset]
else:
train_set = datasets.__dict__[args.dataset](train=True, download=True, use_polygons=not args.eval_straight)
val_set = datasets.__dict__[args.dataset](train=False, download=True, use_polygons=not args.eval_straight)
train_set = datasets.__dict__[args.dataset](
train=True,
download=True,
use_polygons=not args.eval_straight,
sample_transforms=_transform,
)
val_set = datasets.__dict__[args.dataset](
train=False,
download=True,
use_polygons=not args.eval_straight,
sample_transforms=_transform,
)
sets = [train_set, val_set]

reco_metric = TextMatch()
Expand Down Expand Up @@ -190,6 +214,9 @@ def parse_args():
parser.add_argument("--label_file", type=str, default=None, help="Only for local sets, path to labels")
parser.add_argument("--rotation", dest="rotation", action="store_true", help="run rotated OCR + postprocessing")
parser.add_argument("-b", "--batch_size", type=int, default=32, help="batch size for recognition")
parser.add_argument("--size", type=int, default=1024, help="model input size, H = W")
parser.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image")
parser.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically")
parser.add_argument("--samples", type=int, default=None, help="evaluate only on the N first samples")
parser.add_argument(
"--eval-straight",
Expand Down
35 changes: 31 additions & 4 deletions scripts/evaluate_kie.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from tqdm import tqdm

from doctr import datasets
from doctr import transforms as T
from doctr.file_utils import is_tf_available
from doctr.models import kie_predictor
from doctr.utils.geometry import extract_crops, extract_rcrops
Expand All @@ -22,7 +23,7 @@
if is_tf_available():
import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)
else:
Expand All @@ -37,24 +38,47 @@ def main(args):
if not args.rotation:
args.eval_straight = True

input_shape = (args.size, args.size)

# Define a sample transform that pulls the boxes out of the annotation dict,
# because the Resize transformation is called with (img, boxes) rather than (img, dict)
def _transform(img, target):
    """Resize one sample and its boxes together.

    Args:
        img: the sample image, passed as-is to ``T.Resize``.
        target: annotation dict; its ``"boxes"`` entry is resized along with
            the image and its ``"labels"`` entry is carried over unchanged.

    Returns:
        A ``(resized_img, annotation)`` pair, where ``annotation`` holds the
        resized ``"boxes"`` and the original ``"labels"``.
    """
    boxes = target["boxes"]
    # Resize options come from the CLI flags (--keep_ratio / --symmetric_pad)
    resize = T.Resize(
        input_shape,
        preserve_aspect_ratio=args.keep_ratio,
        symmetric_pad=args.symmetric_pad,
    )
    resized_img, resized_boxes = resize(img, boxes)
    annotation = {"boxes": resized_boxes, "labels": target["labels"]}
    return resized_img, annotation

predictor = kie_predictor(
args.detection,
args.recognition,
pretrained=True,
reco_bs=args.batch_size,
preserve_aspect_ratio=False,
preserve_aspect_ratio=False, # we handle the transformation directly in the dataset so this is set to False
symmetric_pad=False, # we handle the transformation directly in the dataset so this is set to False
assume_straight_pages=not args.rotation,
)

if args.img_folder and args.label_file:
testset = datasets.OCRDataset(
img_folder=args.img_folder,
label_file=args.label_file,
sample_transforms=_transform,
)
sets = [testset]
else:
train_set = datasets.__dict__[args.dataset](train=True, download=True, use_polygons=not args.eval_straight)
val_set = datasets.__dict__[args.dataset](train=False, download=True, use_polygons=not args.eval_straight)
train_set = datasets.__dict__[args.dataset](
train=True,
download=True,
use_polygons=not args.eval_straight,
sample_transforms=_transform,
)
val_set = datasets.__dict__[args.dataset](
train=False,
download=True,
use_polygons=not args.eval_straight,
sample_transforms=_transform,
)
sets = [train_set, val_set]

reco_metric = TextMatch()
Expand Down Expand Up @@ -187,6 +211,9 @@ def parse_args():
parser.add_argument("--label_file", type=str, default=None, help="Only for local sets, path to labels")
parser.add_argument("--rotation", dest="rotation", action="store_true", help="run rotated OCR + postprocessing")
parser.add_argument("-b", "--batch_size", type=int, default=32, help="batch size for recognition")
parser.add_argument("--size", type=int, default=1024, help="model input size, H = W")
parser.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image")
parser.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically")
parser.add_argument("--samples", type=int, default=None, help="evaluate only on the N first samples")
parser.add_argument(
"--eval-straight",
Expand Down
16 changes: 16 additions & 0 deletions tests/pytorch/test_transforms_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,22 @@ def test_resize():
out = transfo(input_t)
assert out.dtype == torch.float16

# --- Test with target (bounding boxes) ---

target_boxes = np.array([[0.1, 0.1, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]])
output_size = (64, 64)

transfo = Resize(output_size, preserve_aspect_ratio=True)
input_t = torch.ones((3, 32, 64), dtype=torch.float32)
out, new_target = transfo(input_t, target_boxes)

assert out.shape[-2:] == output_size
assert new_target.shape == target_boxes.shape
assert np.all(new_target >= 0) and np.all(new_target <= 1)

out = transfo(input_t)
assert out.shape[-2:] == output_size


@pytest.mark.parametrize(
"rgb_min",
Expand Down
Loading
Loading