From cbc9f435066197a551ea6239ea0c4b755215ae9b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?=
Date: Mon, 4 Nov 2024 14:33:57 +0100
Subject: [PATCH] fix: check autocast is supported on device

---
 edsnlp/processing/multiprocessing.py | 21 +++++++++++++++------
 edsnlp/processing/simple.py          | 23 +++++++++++++----------
 2 files changed, 28 insertions(+), 16 deletions(-)

diff --git a/edsnlp/processing/multiprocessing.py b/edsnlp/processing/multiprocessing.py
index 9a8535cbe..10be94671 100644
--- a/edsnlp/processing/multiprocessing.py
+++ b/edsnlp/processing/multiprocessing.py
@@ -762,11 +762,15 @@ def process_items(self, stage):
         autocast = self.stream.autocast
         autocast_ctx = nullcontext()
         device = self.devices[self.uid]
-        if autocast:
-            autocast_ctx = torch.autocast(
-                device_type=getattr(device, "type", device).split(":")[0],
-                dtype=autocast if autocast is not True else None,
-            )
+        device_type = getattr(device, "type", device).split(":")[0]
+        try:
+            if autocast:
+                autocast_ctx = torch.autocast(
+                    device_type=device_type,
+                    dtype=autocast if autocast is not True else None,
+                )
+        except RuntimeError:  # pragma: no cover
+            pass

         with torch.no_grad(), autocast_ctx, torch.inference_mode():
             for item in self.iter_tasks(stage):
@@ -1249,7 +1253,12 @@ def adjust_num_workers(stream: Stream):
         num_gpu_workers = 0

     max_cpu_workers = max(num_cpus - num_gpu_workers - 1, 0)
-    default_cpu_workers = max(min(max_cpu_workers, num_gpu_workers * 4), 1)
+    default_cpu_workers = max(
+        min(max_cpu_workers, num_gpu_workers * 4)
+        if num_gpu_workers > 0
+        else max_cpu_workers,
+        1,
+    )
     num_cpu_workers = (
         default_cpu_workers
         if stream.num_cpu_workers is None
diff --git a/edsnlp/processing/simple.py b/edsnlp/processing/simple.py
index fd596f43e..b3258b78b 100644
--- a/edsnlp/processing/simple.py
+++ b/edsnlp/processing/simple.py
@@ -20,19 +20,22 @@ def execute_simple_backend(stream: Stream):
     try:
         torch = sys.modules["torch"]
         no_grad_ctx = torch.no_grad()
-        autocast_device_type = next(
+        device = next(
             (p.device for pipe in stream.torch_components() for p in pipe.parameters()),
             torch.device("cpu"),
-        ).type.split(":")[0]
-        autocast_dtype = stream.autocast if stream.autocast is not True else None
-        autocast_ctx = (
-            torch.autocast(
-                device_type=autocast_device_type,
-                dtype=autocast_dtype,
-            )
-            if stream.autocast
-            else nullcontext()
         )
+        device_type = getattr(device, "type", device).split(":")[0]
+        autocast = stream.autocast
+        autocast_ctx = nullcontext()
+        try:
+            if autocast:
+                autocast_ctx = torch.autocast(
+                    device_type=device_type,
+                    dtype=autocast if autocast is not True else None,
+                )
+        except RuntimeError:  # pragma: no cover
+            pass
+
         inference_mode_ctx = (
             torch.inference_mode()
             if hasattr(torch, "inference_mode")
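Note: below is a minimal standalone sketch of the pattern this patch applies
in both backends; the helper name make_autocast_ctx is illustrative, not part
of the edsnlp codebase. On some backends (e.g. "mps" with older torch
releases), constructing torch.autocast raises a RuntimeError, so the context
is built inside a try/except and inference falls back to full precision:

    from contextlib import nullcontext

    import torch


    def make_autocast_ctx(device, autocast=True):
        # Normalize torch.device("cuda:0") or the string "cuda:0" -> "cuda"
        device_type = getattr(device, "type", device).split(":")[0]
        ctx = nullcontext()
        try:
            if autocast:
                ctx = torch.autocast(
                    device_type=device_type,
                    # autocast=True means "use the device's default dtype";
                    # a dtype value (e.g. torch.bfloat16) is passed through
                    dtype=autocast if autocast is not True else None,
                )
        except RuntimeError:
            pass  # autocast unsupported on this device: keep nullcontext
        return ctx


    # usage: wrap inference with the probed context
    with torch.no_grad(), make_autocast_ctx(torch.device("cpu")):
        pass  # run the model forward pass here

Separately, the adjust_num_workers hunk changes the CPU-only default: with
num_gpu_workers == 0, the old expression min(max_cpu_workers, num_gpu_workers * 4)
collapsed to 0, so the default was always clamped to a single CPU worker;
the new code uses max_cpu_workers in that case.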