Commit
perf: prune tensors in preprocessing workers during training
percevalw committed Nov 13, 2024
1 parent 33883b5 commit f0c226e
Showing 3 changed files with 119 additions and 72 deletions.
1 change: 1 addition & 0 deletions changelog.md
@@ -38,6 +38,7 @@
### Changed
- `eds.span_context_getter`'s parameter `context_sents` is no longer optional and must be explicitly set to 0 to disable sentence context
- In multi-GPU setups, streams that contain torch components are now stripped of their parameter tensors when sent to CPU workers, since these workers only perform preprocessing and postprocessing and therefore do not need the model parameters.
### Fixed
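
The stripping mechanism behind this entry is a custom `copyreg` reducer: once a reducer is registered for `torch.Tensor`, every tensor in the pickled stream is replaced by `None`. Below is a minimal, self-contained sketch of that trick, using plain `pickle` rather than the `torch.save`/`dill` path the commit actually wires up; the dictionary and variable names are illustrative.

```python
import copyreg
import io
import pickle

import torch


def reduce_empty(*args, **kwargs):
    # type(None) is NoneType and NoneType() returns None,
    # so every tensor unpickles as None.
    return type(None), ()


obj = {"weights": torch.randn(3, 3), "config": {"dim": 3}}

old = copyreg.dispatch_table.get(torch.Tensor)
copyreg.pickle(torch.Tensor, reduce_empty)  # register the pruning reducer
try:
    buf = io.BytesIO()
    pickle.dump(obj, buf)
finally:
    # Restore the previous state so other code is unaffected.
    copyreg.dispatch_table.pop(torch.Tensor, None)
    if old is not None:
        copyreg.dispatch_table[torch.Tensor] = old

pruned = pickle.loads(buf.getvalue())
print(pruned)  # {'weights': None, 'config': {'dim': 3}} – the tensor was pruned
```

In the commit, the same reducer (`reduce_empty` in `edsnlp/utils/torch.py`) is registered only while the stream is being dumped for CPU workers, and unregistered right afterwards.
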
102 changes: 30 additions & 72 deletions edsnlp/processing/multiprocessing.py
@@ -20,7 +20,6 @@
from multiprocessing.connection import wait
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
@@ -232,70 +231,7 @@ def cpu_count(): # pragma: no cover
try:
import torch

# Torch may still be imported as a namespace package, so we can access the
# torch.save and torch.load functions
torch_save = torch.save
torch_load = torch.load

MAP_LOCATION = None

try:
from accelerate.hooks import AlignDevicesHook

# We need to replace the "execution_device" attribute of the AlignDevicesHook
# using map_location when unpickling the stream

def save_align_devices_hook(pickler: Any, obj: Any):
pickler.save_reduce(load_align_devices_hook, (obj.__dict__,), obj=obj)

def load_align_devices_hook(state):
state["execution_device"] = MAP_LOCATION
new_obj = AlignDevicesHook.__new__(AlignDevicesHook)
new_obj.__dict__.update(state)
return new_obj

except ImportError:
AlignDevicesHook = None

def dump(*args, **kwargs):
# We need to replace the "execution_device" attribute of the AlignDevicesHook
# using map_location when pickling the stream
old = None
old_settings = dict(dill.settings)
try:
if AlignDevicesHook is not None:
old = dill.Pickler.dispatch.get(AlignDevicesHook)
dill.Pickler.dispatch[AlignDevicesHook] = save_align_devices_hook
dill.settings["recurse"] = True
dill.settings["byref"] = True
return torch_save(*args, pickle_module=dill, **kwargs)
finally:
dill.settings.update(old_settings)
if AlignDevicesHook is not None:
del dill.Pickler.dispatch[AlignDevicesHook]
if old is not None: # pragma: no cover
dill.Pickler.dispatch[AlignDevicesHook] = old

def load(*args, map_location=None, **kwargs):
global MAP_LOCATION
MAP_LOCATION = map_location
if torch.__version__ >= "2.1" and isinstance(args[0], str):
kwargs["mmap"] = True
try:
if torch.__version__ < "2.0.0":
torch_load.__globals__["pickle"] = dill
result = torch_load(
*args,
pickle_module=dill,
map_location=map_location,
**kwargs,
)
finally:
import pickle

torch_load.__globals__["pickle"] = pickle
MAP_LOCATION = None
return result
from edsnlp.utils.torch import dump, load

except (ImportError, AttributeError): # pragma: no cover
# noinspection PyUnusedLocal
@@ -880,6 +816,7 @@ def __init__(self, stream):
) = self.adjust_num_workers(stream)
self.stream = stream
self.stages = stream._make_stages(split_torch_pipes=num_gpu_workers > 0)
self.has_torch_pipes = has_torch_pipes
mp = self.get_multiprocessing_context(
has_torch_pipes=has_torch_pipes,
process_start_method=stream.process_start_method,
@@ -945,7 +882,11 @@ def __init__(self, stream):
name = f"from-{cpu}_to-main"
self.data_queues[name] = mp.Queue(2)

self.temp_file = tempfile.NamedTemporaryFile(delete=False)
self.cpu_temp_file = self.gpu_temp_file = None
if len(self.cpu_worker_names):
self.cpu_temp_file = tempfile.NamedTemporaryFile(delete=False)
if len(self.gpu_worker_names):
self.gpu_temp_file = tempfile.NamedTemporaryFile(delete=False)

self.cpu_workers = []
self.gpu_workers = []
@@ -962,7 +903,7 @@ def __init__(self, stream):
worker_control_queue=self.worker_control_queues[cpu],
final_barrier=self.final_barrier,
schedule=self.cpu_to_gpu_schedules[cpu],
stream_path=self.temp_file.name,
stream_path=self.cpu_temp_file.name,
devices=devices,
gpu_semaphores={
gpu: sem
@@ -985,7 +926,7 @@ def __init__(self, stream):
main_control_queue=self.main_control_queue,
worker_control_queue=self.worker_control_queues[gpu],
final_barrier=self.final_barrier,
stream_path=self.temp_file.name,
stream_path=self.gpu_temp_file.name,
devices=devices,
gpu_semaphores={
cpu: sem
@@ -1025,9 +966,23 @@ def run(self):
)

stream_to_dump = self.stream.worker_copy()
with self.temp_file as fp:
dump((stream_to_dump, self.stages), fp)
fp.close()
if self.cpu_temp_file:
# If we have GPU workers, they will be responsible for the forward pass,
# and CPU workers will only perform preprocessing and postprocessing,
# so they don't need the deep learning parameters.
# TODO: should we make this a stream set_processing option?
keep_tensors = self.has_torch_pipes and len(self.gpu_worker_names) == 0
with self.cpu_temp_file as fp:
dump(
(stream_to_dump, self.stages),
fp,
skip_tensors=not keep_tensors,
)
fp.close()
if self.gpu_temp_file:
with self.gpu_temp_file as fp:
dump((stream_to_dump, self.stages), fp)
fp.close()

del stream_to_dump
for worker in (*self.cpu_workers, *self.gpu_workers):
@@ -1040,7 +995,10 @@ def run(self):
if isinstance(outputs, BaseException): # pragma: no cover
raise outputs

os.unlink(self.temp_file.name)
if self.cpu_temp_file is not None:
os.unlink(self.cpu_temp_file.name)
if self.gpu_temp_file is not None:
os.unlink(self.gpu_temp_file.name)

# Start listening for notifications from workers
self.dequeue_notifications_thread = threading.Thread(
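
With the single shared temp file replaced by a per-role pair, the runner now dumps a pruned copy of the stream for CPU workers and a full copy for GPU workers; tensors are only kept for CPU workers when there are torch pipes but no GPU workers, i.e. when the CPU workers must run the forward pass themselves (`keep_tensors = self.has_torch_pipes and len(self.gpu_worker_names) == 0`). A rough sketch of the size effect, using the `dump` helper this commit adds, with an illustrative dictionary standing in for the `(stream, stages)` payload:

```python
import os
import tempfile

import torch

from edsnlp.utils.torch import dump

# Stand-in for the (stream, stages) payload: large parameters plus the
# lightweight state that CPU workers actually need.
payload = {"embeddings": torch.randn(10_000, 128), "labels": ["PER", "LOC"]}

with tempfile.NamedTemporaryFile(delete=False) as full_fp, \
        tempfile.NamedTemporaryFile(delete=False) as pruned_fp:
    dump(payload, full_fp)                       # full copy, as sent to GPU workers
    dump(payload, pruned_fp, skip_tensors=True)  # pruned copy, as sent to CPU workers

print(os.path.getsize(full_fp.name), ">>", os.path.getsize(pruned_fp.name))
os.unlink(full_fp.name)
os.unlink(pruned_fp.name)
```
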
88 changes: 88 additions & 0 deletions edsnlp/utils/torch.py
@@ -1,8 +1,10 @@
import copyreg
import math
import warnings
from enum import Enum
from typing import TypeVar

import dill
import torch

# filter "is in beta" torch warnings
@@ -91,3 +93,89 @@ def make_windows(lengths, size, stride):
    )
    indexer %= len(scores)
    return windows.masked_fill(~windows_mask, -1), indexer


def reduce_empty(*args, **kwargs):
    return type(None), ()


def load_pruned_obj(obj, _):
    return obj


# Torch may still be imported as a namespace package, so we can access the
# torch.save and torch.load functions

MAP_LOCATION = None


try:
    from accelerate.hooks import AlignDevicesHook

    # We need to replace the "execution_device" attribute of the AlignDevicesHook
    # using map_location when unpickling the stream

    def save_align_devices_hook(pickler, obj):
        pickler.save_reduce(load_align_devices_hook, (obj.__dict__,), obj=obj)

    def load_align_devices_hook(state):
        state["execution_device"] = MAP_LOCATION
        new_obj = AlignDevicesHook.__new__(AlignDevicesHook)
        new_obj.__dict__.update(state)
        return new_obj

except ImportError:
    AlignDevicesHook = None


def dump(
    *args,
    skip_tensors: bool = False,
    **kwargs,
):
    # We need to replace the "execution_device" attribute of the AlignDevicesHook
    # using map_location when pickling the stream
    old = None
    old_settings = dict(dill.settings)
    old_dispatch = {}
    try:
        if skip_tensors:
            if torch.Tensor in copyreg.dispatch_table:
                old_dispatch[torch.Tensor] = copyreg.dispatch_table[torch.Tensor]
            copyreg.pickle(torch.Tensor, reduce_empty)
        if AlignDevicesHook is not None:
            old = dill.Pickler.dispatch.get(AlignDevicesHook)
            dill.Pickler.dispatch[AlignDevicesHook] = save_align_devices_hook
        dill.settings["recurse"] = True
        dill.settings["byref"] = True
        return torch.save(*args, pickle_module=dill, **kwargs)
    finally:
        dill.settings.update(old_settings)
        if AlignDevicesHook is not None:
            del dill.Pickler.dispatch[AlignDevicesHook]
            if old is not None:  # pragma: no cover
                dill.Pickler.dispatch[AlignDevicesHook] = old
        copyreg.dispatch_table.pop(torch.Tensor, None)
        copyreg.dispatch_table.update(old_dispatch)


def load(*args, map_location=None, **kwargs):
    global MAP_LOCATION
    MAP_LOCATION = map_location
    if torch.__version__ >= "2.1" and isinstance(args[0], str):
        kwargs["mmap"] = True
    try:
        if torch.__version__ < "2.0.0":
            torch.load.__globals__["pickle"] = dill
        result = torch.load(
            *args,
            pickle_module=dill,
            map_location=map_location,
            **kwargs,
        )
    finally:
        import pickle

        torch.load.__globals__["pickle"] = pickle
    MAP_LOCATION = None
    return result

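On the loading side, anything pruned at dump time simply comes back as `None`, while everything else round-trips normally. A short sketch using the two helpers above (the file names and contents are illustrative; `map_location` is forwarded to `torch.load` as in the code above):

```python
import torch

from edsnlp.utils.torch import dump, load

obj = {"meta": {"batch_size": 8}, "weights": torch.randn(2, 2)}

dump(obj, "full.pt")
dump(obj, "pruned.pt", skip_tensors=True)

full = load("full.pt", map_location="cpu")
pruned = load("pruned.pt", map_location="cpu")

print(full["weights"].shape)  # torch.Size([2, 2]) – tensors kept
print(pruned["weights"])      # None – tensors were pruned at dump time
print(pruned["meta"])         # {'batch_size': 8} – everything else survives
```

This is why preprocessing and postprocessing state (tokenizers, vocabularies, configuration) still reaches the CPU workers intact even though the model parameters do not.
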