diff --git a/doctr/models/detection/predictor/pytorch.py b/doctr/models/detection/predictor/pytorch.py index 3c8dcd0b0e..224bdfca78 100644 --- a/doctr/models/detection/predictor/pytorch.py +++ b/doctr/models/detection/predictor/pytorch.py @@ -13,6 +13,8 @@ __all__ = ["DetectionPredictor"] +from doctr.utils.gpu import select_gpu_device + class DetectionPredictor(nn.Module): """Implements an object able to localize text elements in a document @@ -27,29 +29,26 @@ def __init__( pre_processor: PreProcessor, model: nn.Module, ) -> None: - super().__init__() self.model = model.eval() self.pre_processor = pre_processor self.postprocessor = self.model.postprocessor - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if os.environ.get("CUDA_VISIBLE_DEVICES", []) == "": - self.device = torch.device("cpu") - elif len(os.environ.get("CUDA_VISIBLE_DEVICES", [])) > 0: - self.device = torch.device("cuda") - if "onnx" not in str((type(self.model))) and (self.device == torch.device("cuda")): + + detected_device, selected_device = select_gpu_device() + if "onnx" in str((type(self.model))): + selected_device = 'cpu' # self.model = nn.DataParallel(self.model) # self.model = self.model.half() - self.model = self.model.to(self.device) + self.device = torch.device(selected_device) + self.model = self.model.to(self.device) @torch.no_grad() def forward( self, pages: List[Union[np.ndarray, torch.Tensor]], - return_model_output = False, + return_model_output=False, **kwargs: Any, ) -> List[np.ndarray]: - # Dimension check if any(page.ndim != 3 for page in pages): raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") diff --git a/doctr/models/recognition/predictor/pytorch.py b/doctr/models/recognition/predictor/pytorch.py index 141d092c6a..894699c1bd 100644 --- a/doctr/models/recognition/predictor/pytorch.py +++ b/doctr/models/recognition/predictor/pytorch.py @@ -10,6 +10,7 @@ from torch import nn import os from doctr.models.preprocessor import PreProcessor +from doctr.utils.gpu import select_gpu_device from ._utils import remap_preds, split_crops @@ -31,20 +32,19 @@ def __init__( model: nn.Module, split_wide_crops: bool = True, ) -> None: - super().__init__() self.pre_processor = pre_processor self.model = model.eval() self.postprocessor = self.model.postprocessor - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if os.environ.get("CUDA_VISIBLE_DEVICES", []) == "": - self.device = torch.device("cpu") - elif len(os.environ.get("CUDA_VISIBLE_DEVICES", [])) > 0: - self.device = torch.device("cuda") - if "onnx" not in str((type(self.model))) and (self.device == torch.device("cuda")): + + detected_device, selected_device = select_gpu_device() + if "onnx" in str((type(self.model))): + selected_device = 'cpu' # self.model = nn.DataParallel(self.model) - self.model = self.model.to(self.device) # self.model = self.model.half() + self.device = torch.device(selected_device) + self.model = self.model.to(self.device) + self.split_wide_crops = split_wide_crops self.critical_ar = 8 # Critical aspect ratio self.dil_factor = 1.4 # Dilation factor to overlap the crops @@ -56,7 +56,6 @@ def forward( crops: Sequence[Union[np.ndarray, torch.Tensor]], **kwargs: Any, ) -> List[Tuple[str, float]]: - if len(crops) == 0: return [] # Dimension check diff --git a/doctr/utils/gpu.py b/doctr/utils/gpu.py new file mode 100644 index 0000000000..7a0da2fe0f --- /dev/null +++ b/doctr/utils/gpu.py @@ -0,0 +1,40 @@ +import logging +import os +from typing import Tuple +import torch + +log = logging.getLogger(__name__) + + +def select_gpu_device() -> Tuple[str, str]: + """tries to find either cuda or arm mps gpu accelerator and choses the most appropriate one, + honoring the environment variables (CUDA_VISIBLE_DEVICES), if any have been set. + + returns tuple(best_detected_device, selected_device) + best_detected_device reflects capabilities of the system + selected_device is the device that should be used (might be cpu even if best_detected_device is eg cuda) + """ + if torch.cuda.is_available(): + detected_gpu_device = 'cuda' + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + detected_gpu_device = 'mps' + else: + detected_gpu_device = 'cpu' + + selected_gpu_device = detected_gpu_device + match detected_gpu_device: # various exceptions to the above + case 'cuda': + if os.environ.get("CUDA_VISIBLE_DEVICES") == "": + selected_gpu_device = 'cpu' + case 'mps': + # FIXME detected mps selects cpu here because of the many bugs present in the mps implementation of + # torch'es 1.13 LSTM. As of 5/29/2023, they appear to be actively fixing them. I did try with torch + # 2.0.1 and while the bugs look different it's still broken. Revisit when later versions of torch + # are available. + # pass + selected_gpu_device = 'cpu' + case 'cpu': + pass + + log.info(f"{detected_gpu_device=} {selected_gpu_device=}") + return detected_gpu_device, selected_gpu_device