Skip to content

Commit

Permalink
fix: check autocast is supported on device
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Nov 4, 2024
1 parent ffa4231 commit cbc9f43
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 16 deletions.
21 changes: 15 additions & 6 deletions edsnlp/processing/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,11 +762,15 @@ def process_items(self, stage):
autocast = self.stream.autocast
autocast_ctx = nullcontext()
device = self.devices[self.uid]
if autocast:
autocast_ctx = torch.autocast(
device_type=getattr(device, "type", device).split(":")[0],
dtype=autocast if autocast is not True else None,
)
device_type = getattr(device, "type", device).split(":")[0]
try:
if autocast:
autocast_ctx = torch.autocast(
device_type=device_type,
dtype=autocast if autocast is not True else None,
)
except RuntimeError: # pragma: no cover
pass

with torch.no_grad(), autocast_ctx, torch.inference_mode():
for item in self.iter_tasks(stage):
Expand Down Expand Up @@ -1249,7 +1253,12 @@ def adjust_num_workers(stream: Stream):
num_gpu_workers = 0

max_cpu_workers = max(num_cpus - num_gpu_workers - 1, 0)
default_cpu_workers = max(min(max_cpu_workers, num_gpu_workers * 4), 1)
default_cpu_workers = max(
min(max_cpu_workers, num_gpu_workers * 4)
if num_gpu_workers > 0
else max_cpu_workers,
1,
)
num_cpu_workers = (
default_cpu_workers
if stream.num_cpu_workers is None
Expand Down
23 changes: 13 additions & 10 deletions edsnlp/processing/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,22 @@ def execute_simple_backend(stream: Stream):
try:
torch = sys.modules["torch"]
no_grad_ctx = torch.no_grad()
autocast_device_type = next(
device = next(
(p.device for pipe in stream.torch_components() for p in pipe.parameters()),
torch.device("cpu"),
).type.split(":")[0]
autocast_dtype = stream.autocast if stream.autocast is not True else None
autocast_ctx = (
torch.autocast(
device_type=autocast_device_type,
dtype=autocast_dtype,
)
if stream.autocast
else nullcontext()
)
device_type = getattr(device, "type", device).split(":")[0]
autocast = stream.autocast
autocast_ctx = nullcontext()
try:
if autocast:
autocast_ctx = torch.autocast(
device_type=device_type,
dtype=autocast if autocast is not True else None,
)
except RuntimeError: # pragma: no cover
pass

inference_mode_ctx = (
torch.inference_mode()
if hasattr(torch, "inference_mode")
Expand Down

0 comments on commit cbc9f43

Please sign in to comment.