diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py index 797b43cb58..dccb32158d 100644 --- a/doctr/models/classification/zoo.py +++ b/doctr/models/classification/zoo.py @@ -50,9 +50,9 @@ def _orientation_predictor( else: allowed_archs = [classification.MobileNetV3] if is_torch_available(): - from doctr.models.utils import _get_torch_compile_type + from doctr.models.utils import _CompiledModule - allowed_archs.append(_get_torch_compile_type()) + allowed_archs.append(_CompiledModule) if not isinstance(arch, tuple(allowed_archs)): raise ValueError(f"unknown architecture: {type(arch)}") diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py index 7e14205889..419ebac7e8 100644 --- a/doctr/models/detection/zoo.py +++ b/doctr/models/detection/zoo.py @@ -59,9 +59,9 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST] if is_torch_available(): # The following is required for torch compiled models - from doctr.models.utils import _get_torch_compile_type + from doctr.models.utils import _CompiledModule - allowed_archs.append(_get_torch_compile_type()) + allowed_archs.append(_CompiledModule) if not isinstance(arch, tuple(allowed_archs)): raise ValueError(f"unknown architecture: {type(arch)}") diff --git a/doctr/models/recognition/zoo.py b/doctr/models/recognition/zoo.py index f4bea2c57a..64522760c7 100644 --- a/doctr/models/recognition/zoo.py +++ b/doctr/models/recognition/zoo.py @@ -38,9 +38,9 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq] if is_torch_available(): # The following is required for torch compiled models - from doctr.models.utils import _get_torch_compile_type + from doctr.models.utils import _CompiledModule - allowed_archs.append(_get_torch_compile_type()) + allowed_archs.append(_CompiledModule) if not isinstance(arch, tuple(allowed_archs)): raise ValueError(f"unknown architecture: {type(arch)}") diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py index afe14be61f..59430c9831 100644 --- a/doctr/models/utils/pytorch.py +++ b/doctr/models/utils/pytorch.py @@ -18,12 +18,11 @@ "export_model_to_onnx", "_copy_tensor", "_bf16_to_float32", - "_get_torch_compile_type", + "_CompiledModule", ] - -def _get_torch_compile_type() -> Any: - return torch._dynamo.eval_frame.OptimizedModule +# torch compiled model type +_CompiledModule = torch._dynamo.eval_frame.OptimizedModule def _copy_tensor(x: torch.Tensor) -> torch.Tensor: