Skip to content

Commit

Permalink
Add getter and setter methods for compile_backend across accelerators. (
Browse files Browse the repository at this point in the history
#5299)

Add getter and setter methods for `compile_backend` across accelerators,
which provide a mechanism to retrieve the compile backend. These APIs
handle user-defined backend selection and raise a `ValueError` with
informative error messages for unsupported backends.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
3 people authored Apr 24, 2024
1 parent fbdf0ea commit fa8458b
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 9 deletions.
9 changes: 9 additions & 0 deletions accelerator/abstract_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class DeepSpeedAccelerator(ABC):
def __init__(self):
self._name = None
self._communication_backend_name = None
self._compile_backend = None

@abc.abstractmethod
def is_synchronized_device(self):
Expand Down Expand Up @@ -295,3 +296,11 @@ def visible_devices_envs(self):
@abc.abstractmethod
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
...

@abc.abstractmethod
def get_compile_backend(self):
...

@abc.abstractmethod
def set_compile_backend(self, backend):
...
12 changes: 12 additions & 0 deletions accelerator/cpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class CPU_Accelerator(DeepSpeedAccelerator):

def __init__(self):
self._name = 'cpu'
self._compile_backend = "inductor"
if oneccl_imported_p:
self._communication_backend_name = 'ccl'
else:
Expand Down Expand Up @@ -330,3 +331,14 @@ def visible_devices_envs(self):
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
for env in self.visible_devices_envs():
current_env[env] = ",".join(map(str, local_accelerator_ids))

def get_compile_backend(self):
return self._compile_backend

def set_compile_backend(self, backend):
supported_backends = torch._dynamo.list_backends(exclude_tags=())
if backend in supported_backends:
self._compile_backend = backend
else:
raise ValueError(
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
12 changes: 12 additions & 0 deletions accelerator/cuda_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'cuda'
self._communication_backend_name = 'nccl'
self._compile_backend = "inductor"
if pynvml is None:
self._init_pynvml()

Expand Down Expand Up @@ -367,3 +368,14 @@ def visible_devices_envs(self):
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
for env in self.visible_devices_envs():
current_env[env] = ",".join(map(str, local_accelerator_ids))

def get_compile_backend(self):
return self._compile_backend

def set_compile_backend(self, backend):
supported_backends = torch._dynamo.list_backends(exclude_tags=())
if backend in supported_backends:
self._compile_backend = backend
else:
raise ValueError(
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
12 changes: 12 additions & 0 deletions accelerator/hpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class HPU_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'hpu'
self._communication_backend_name = 'hccl'
self._compile_backend = "hpu_backend"
try:
import habana_frameworks.torch.hpu as hpu
hpu.setDeterministic(True)
Expand Down Expand Up @@ -301,3 +302,14 @@ def visible_devices_envs(self):
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
for env in self.visible_devices_envs():
current_env[env] = ",".join(map(str, local_accelerator_ids))

def get_compile_backend(self):
return self._compile_backend

def set_compile_backend(self, backend):
supported_backends = torch._dynamo.list_backends(exclude_tags=())
if backend in supported_backends:
self._compile_backend = backend
else:
raise ValueError(
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
12 changes: 12 additions & 0 deletions accelerator/mps_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class MPS_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = "mps"
self._communication_backend_name = None
self._compile_backend = "inductor"

def is_synchronized_device(self):
return False
Expand Down Expand Up @@ -267,3 +268,14 @@ def visible_devices_envs(self):
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
for env in self.visible_devices_envs():
current_env[env] = ",".join(map(str, local_accelerator_ids))

def get_compile_backend(self):
return self._compile_backend

def set_compile_backend(self, backend):
supported_backends = torch._dynamo.list_backends(exclude_tags=())
if backend in supported_backends:
self._compile_backend = backend
else:
raise ValueError(
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
12 changes: 12 additions & 0 deletions accelerator/npu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self):
super().__init__()
self._name = 'npu'
self._communication_backend_name = 'hccl'
self._compile_backend = "inductor"
# dict that holds class name <--> class type mapping i.e.
# 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
# this dict will be filled at init stage
Expand Down Expand Up @@ -285,3 +286,14 @@ def visible_devices_envs(self):
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
for env in self.visible_devices_envs():
current_env[env] = ",".join(map(str, local_accelerator_ids))

def get_compile_backend(self):
return self._compile_backend

def set_compile_backend(self, backend):
supported_backends = torch._dynamo.list_backends(exclude_tags=())
if backend in supported_backends:
self._compile_backend = backend
else:
raise ValueError(
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends }")
12 changes: 12 additions & 0 deletions accelerator/xpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class XPU_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'xpu'
self._communication_backend_name = 'ccl'
self._compile_backend = "inductor"
self.aligned_tensors = []

def is_synchronized_device(self):
Expand Down Expand Up @@ -296,3 +297,14 @@ def visible_devices_envs(self):
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
for env in self.visible_devices_envs():
current_env[env] = ",".join(map(str, local_accelerator_ids))

def get_compile_backend(self):
return self._compile_backend

def set_compile_backend(self, backend):
supported_backends = torch._dynamo.list_backends(exclude_tags=())
if backend in supported_backends:
self._compile_backend = backend
else:
raise ValueError(
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
4 changes: 1 addition & 3 deletions tests/unit/runtime/compile/test_compile_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,9 @@ def base_config():
},
"compile": {
"enabled": True,
"backend": "inductor"
"backend": get_accelerator().get_compile_backend()
}
}
if get_accelerator().device_name() == 'hpu':
config_dict['compile']['backend'] = 'hpu_backend'
return config_dict


Expand Down
4 changes: 1 addition & 3 deletions tests/unit/runtime/compile/test_compile_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,10 @@ def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device):
},
"compile": {
"enabled": True,
"backend": "inductor"
"backend": get_accelerator().get_compile_backend()
}
}

if get_accelerator().device_name() == 'hpu':
config_dict['compile']['backend'] = 'hpu_backend'
if offload_device == OffloadDeviceEnum.cpu:
config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device}
elif offload_device == OffloadDeviceEnum.nvme:
Expand Down
4 changes: 1 addition & 3 deletions tests/unit/runtime/compile/test_load_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,10 @@ def base_config():
},
"compile": {
"enabled": True,
"backend": "inductor"
"backend": get_accelerator().get_compile_backend()
}
}

if get_accelerator().device_name() == 'hpu':
config_dict['compile']['backend'] = 'hpu_backend'
return config_dict


Expand Down

0 comments on commit fa8458b

Please sign in to comment.