Skip to content

Commit

Permalink
Apply suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Dec 4, 2024
1 parent 768ec80 commit a92fc95
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.4
rev: v0.8.1
hooks:
- id: ruff
args: [ --fix ]
Expand Down
6 changes: 3 additions & 3 deletions tests/pytorch/test_models_classification_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from doctr.models import classification
from doctr.models.classification.predictor import OrientationPredictor
from doctr.models.utils import export_model_to_onnx
from doctr.models.utils import _CompiledModule, export_model_to_onnx


def _test_classification(model, input_shape, output_size, batch_size=2):
Expand Down Expand Up @@ -156,7 +156,7 @@ def test_crop_orientation_model(mock_text_box):
compiled_model = torch.compile(classification.mobilenet_v3_small_crop_orientation(pretrained=True))
compiled_classifier = classification.crop_orientation_predictor(compiled_model)

assert isinstance(compiled_model, torch._dynamo.eval_frame.OptimizedModule)
assert isinstance(compiled_model, _CompiledModule)
assert isinstance(compiled_classifier, OrientationPredictor)
assert compiled_classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3]
assert compiled_classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90]
Expand Down Expand Up @@ -201,7 +201,7 @@ def test_page_orientation_model(mock_payslip):
compiled_model = torch.compile(classification.mobilenet_v3_small_page_orientation(pretrained=True))
compiled_classifier = classification.page_orientation_predictor(compiled_model)

assert isinstance(compiled_model, torch._dynamo.eval_frame.OptimizedModule)
assert isinstance(compiled_model, _CompiledModule)
assert isinstance(compiled_classifier, OrientationPredictor)
assert compiled_classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3]
assert compiled_classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90]
Expand Down
4 changes: 2 additions & 2 deletions tests/pytorch/test_models_detection_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from doctr.models.detection._utils import dilate, erode
from doctr.models.detection.fast.pytorch import reparameterize
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.utils import export_model_to_onnx
from doctr.models.utils import _CompiledModule, export_model_to_onnx


@pytest.mark.parametrize("train_mode", [True, False])
Expand Down Expand Up @@ -211,7 +211,7 @@ def test_torch_compiled_models(arch_name, mock_payslip):

# Compile the model
compiled_model = torch.compile(detection.__dict__[arch_name](pretrained=True).eval())
assert isinstance(compiled_model, torch._dynamo.eval_frame.OptimizedModule)
assert isinstance(compiled_model, _CompiledModule)
compiled_predictor = detection.zoo.detection_predictor(compiled_model)
compiled_out, seg_maps = compiled_predictor(doc, return_maps=True)

Expand Down
4 changes: 2 additions & 2 deletions tests/pytorch/test_models_recognition_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from doctr.models.recognition.predictor import RecognitionPredictor
from doctr.models.recognition.sar.pytorch import SARPostProcessor
from doctr.models.recognition.vitstr.pytorch import ViTSTRPostProcessor
from doctr.models.utils import export_model_to_onnx
from doctr.models.utils import _CompiledModule, export_model_to_onnx

system_available_memory = int(psutil.virtual_memory().available / 1024**3)

Expand Down Expand Up @@ -178,7 +178,7 @@ def test_torch_compiled_models(arch_name, mock_text_box):

# Compile the model
compiled_model = torch.compile(recognition.__dict__[arch_name](pretrained=True).eval())
assert isinstance(compiled_model, torch._dynamo.eval_frame.OptimizedModule)
assert isinstance(compiled_model, _CompiledModule)
compiled_predictor = recognition.zoo.recognition_predictor(compiled_model)
compiled_out = compiled_predictor(doc)

Expand Down

0 comments on commit a92fc95

Please sign in to comment.