Skip to content

Commit

Permalink
[Platforms] Add device_type in Platform (#10508)
Browse files Browse the repository at this point in the history
Signed-off-by: MengqingCao <[email protected]>
  • Loading branch information
MengqingCao authored Nov 21, 2024
1 parent 6c1208d commit 9d82717
Show file tree
Hide file tree
Showing 10 changed files with 11 additions and 15 deletions.
17 changes: 2 additions & 15 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,21 +1193,8 @@ class DeviceConfig:
def __init__(self, device: str = "auto") -> None:
if device == "auto":
# Automated device type detection
if current_platform.is_cuda_alike():
self.device_type = "cuda"
elif current_platform.is_neuron():
self.device_type = "neuron"
elif current_platform.is_hpu():
self.device_type = "hpu"
elif current_platform.is_openvino():
self.device_type = "openvino"
elif current_platform.is_tpu():
self.device_type = "tpu"
elif current_platform.is_cpu():
self.device_type = "cpu"
elif current_platform.is_xpu():
self.device_type = "xpu"
else:
self.device_type = current_platform.device_type
if self.device_type is None:
raise RuntimeError("Failed to infer device type")
else:
# Device type is assigned explicitly
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

class CpuPlatform(Platform):
_enum = PlatformEnum.CPU
device_type: str = "cpu"

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:

class CudaPlatform(Platform):
_enum = PlatformEnum.CUDA
device_type: str = "cuda"

@classmethod
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/hpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

class HpuPlatform(Platform):
_enum = PlatformEnum.HPU
device_type: str = "hpu"

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def to_int(self) -> int:

class Platform:
_enum: PlatformEnum
device_type: str

def is_cuda(self) -> bool:
return self._enum == PlatformEnum.CUDA
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/neuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

class NeuronPlatform(Platform):
_enum = PlatformEnum.NEURON
device_type: str = "neuron"

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

class OpenVinoPlatform(Platform):
_enum = PlatformEnum.OPENVINO
device_type: str = "openvino"

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
device_type: str = "cuda"

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
device_type: str = "tpu"

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

class XPUPlatform(Platform):
_enum = PlatformEnum.XPU
device_type: str = "xpu"

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
Expand Down

0 comments on commit 9d82717

Please sign in to comment.