Skip to content

Commit

Permalink
update type
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Nov 21, 2024
1 parent 4b839bf commit e75745c
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 10 deletions.
4 changes: 2 additions & 2 deletions doctr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def _orientation_predictor(
else:
allowed_archs = [classification.MobileNetV3]
if is_torch_available():
from doctr.models.utils import _get_torch_compile_type
from doctr.models.utils import _CompiledModule

allowed_archs.append(_get_torch_compile_type())
allowed_archs.append(_CompiledModule)

if not isinstance(arch, tuple(allowed_archs)):
raise ValueError(f"unknown architecture: {type(arch)}")
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST]
if is_torch_available():
# The following is required for torch compiled models
from doctr.models.utils import _get_torch_compile_type
from doctr.models.utils import _CompiledModule

allowed_archs.append(_get_torch_compile_type())
allowed_archs.append(_CompiledModule)

if not isinstance(arch, tuple(allowed_archs)):
raise ValueError(f"unknown architecture: {type(arch)}")
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
if is_torch_available():
# The following is required for torch compiled models
from doctr.models.utils import _get_torch_compile_type
from doctr.models.utils import _CompiledModule

allowed_archs.append(_get_torch_compile_type())
allowed_archs.append(_CompiledModule)

if not isinstance(arch, tuple(allowed_archs)):
raise ValueError(f"unknown architecture: {type(arch)}")
Expand Down
7 changes: 3 additions & 4 deletions doctr/models/utils/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
"export_model_to_onnx",
"_copy_tensor",
"_bf16_to_float32",
"_get_torch_compile_type",
"_CompiledModule",
]


def _get_torch_compile_type() -> Any:
return torch._dynamo.eval_frame.OptimizedModule
# torch compiled model type
_CompiledModule = torch._dynamo.eval_frame.OptimizedModule


def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
Expand Down

0 comments on commit e75745c

Please sign in to comment.