diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7e2a849de3..a897f7a6b3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -41,7 +41,7 @@ Install all additional dependencies with the following command: ```shell python -m pip install --upgrade pip -pip install -e .[dev] +pip install -e '.[dev]' pre-commit install ``` diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py index 16ffea0051..7df839484b 100644 --- a/doctr/models/classification/zoo.py +++ b/doctr/models/classification/zoo.py @@ -63,7 +63,7 @@ def _orientation_predictor( def crop_orientation_predictor( - arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any + arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, batch_size: int = 128, **kwargs: Any ) -> OrientationPredictor: """Crop orientation classification architecture. @@ -77,17 +77,18 @@ def crop_orientation_predictor( ---- arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation') pretrained: If True, returns a model pre-trained on our recognition crops dataset + batch_size: number of samples the model processes in parallel **kwargs: keyword arguments to be passed to the OrientationPredictor Returns: ------- OrientationPredictor """ - return _orientation_predictor(arch, pretrained, model_type="crop", **kwargs) + return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="crop", **kwargs) def page_orientation_predictor( - arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any + arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, batch_size: int = 4, **kwargs: Any ) -> OrientationPredictor: """Page orientation classification architecture. @@ -101,10 +102,11 @@ def page_orientation_predictor( ---- arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation') pretrained: If True, returns a model pre-trained on our recognition crops dataset + batch_size: number of samples the model processes in parallel **kwargs: keyword arguments to be passed to the OrientationPredictor Returns: ------- OrientationPredictor """ - return _orientation_predictor(arch, pretrained, model_type="page", **kwargs) + return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="page", **kwargs) diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py index e89009fd8d..b8dfa7636d 100644 --- a/doctr/models/detection/zoo.py +++ b/doctr/models/detection/zoo.py @@ -79,6 +79,9 @@ def detection_predictor( arch: Any = "fast_base", pretrained: bool = False, assume_straight_pages: bool = True, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + batch_size: int = 2, **kwargs: Any, ) -> DetectionPredictor: """Text detection architecture. @@ -94,10 +97,22 @@ def detection_predictor( arch: name of the architecture or model itself to use (e.g. 'db_resnet50') pretrained: If True, returns a model pre-trained on our text detection dataset assume_straight_pages: If True, fit straight boxes to the page + preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before + running the detection model on it + symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right + batch_size: number of samples the model processes in parallel **kwargs: optional keyword arguments passed to the architecture Returns: ------- Detection predictor """ - return _predictor(arch, pretrained, assume_straight_pages, **kwargs) + return _predictor( + arch=arch, + pretrained=pretrained, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + batch_size=batch_size, + **kwargs, + ) diff --git a/doctr/models/preprocessor/pytorch.py b/doctr/models/preprocessor/pytorch.py index 58a236bd08..e074fde829 100644 --- a/doctr/models/preprocessor/pytorch.py +++ b/doctr/models/preprocessor/pytorch.py @@ -27,6 +27,7 @@ class PreProcessor(nn.Module): batch_size: the size of page batches mean: mean value of the training distribution by channel std: standard deviation of the training distribution by channel + **kwargs: additional arguments for the resizing operation """ def __init__( diff --git a/doctr/models/preprocessor/tensorflow.py b/doctr/models/preprocessor/tensorflow.py index 85e06fca3e..15f8be5ac3 100644 --- a/doctr/models/preprocessor/tensorflow.py +++ b/doctr/models/preprocessor/tensorflow.py @@ -25,6 +25,7 @@ class PreProcessor(NestedObject): batch_size: the size of page batches mean: mean value of the training distribution by channel std: standard deviation of the training distribution by channel + **kwargs: additional arguments for the resizing operation """ _children_names: List[str] = ["resize", "normalize"] diff --git a/doctr/models/recognition/zoo.py b/doctr/models/recognition/zoo.py index 0393240431..be6ca4ae44 100644 --- a/doctr/models/recognition/zoo.py +++ b/doctr/models/recognition/zoo.py @@ -52,7 +52,13 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict return predictor -def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False, **kwargs: Any) -> RecognitionPredictor: +def recognition_predictor( + arch: Any = "crnn_vgg16_bn", + pretrained: bool = False, + symmetric_pad: bool = False, + batch_size: int = 128, + **kwargs: Any, +) -> RecognitionPredictor: """Text recognition architecture. Example:: @@ -66,10 +72,12 @@ def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False, ---- arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn') pretrained: If True, returns a model pre-trained on our text recognition dataset + symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right + batch_size: number of samples the model processes in parallel **kwargs: optional parameters to be passed to the architecture Returns: ------- Recognition predictor """ - return _predictor(arch, pretrained, **kwargs) + return _predictor(arch=arch, pretrained=pretrained, symmetric_pad=symmetric_pad, batch_size=batch_size, **kwargs)