From f747802d057c559c2816a780d2a5cd9efb715afc Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Mon, 25 Sep 2023 03:21:59 -0700 Subject: [PATCH 1/4] initiliaze sycl support --- flexgen/profile_matmul.py | 5 +- flexgen/pytorch_backend.py | 138 ++++++++++++++++++++++++++----------- flexgen/utils.py | 14 ++-- flexgen/xpu_utils.py | 67 ++++++++++++++++++ 4 files changed, 178 insertions(+), 46 deletions(-) create mode 100644 flexgen/xpu_utils.py diff --git a/flexgen/profile_matmul.py b/flexgen/profile_matmul.py index 59772e71..419f9906 100644 --- a/flexgen/profile_matmul.py +++ b/flexgen/profile_matmul.py @@ -7,13 +7,16 @@ import torch from flexgen.profile_bandwidth import benchmark_func +from flexgen.xpu_utils import is_xpu_available def bench_matmul(): - for device in ["cuda", "cpu"]: + for device in ["cuda", "cpu", "xpu"]: for n in [1024, 2048]: if device == "cuda": dtype = torch.float16 + elif device == "xpu" and is_xpu_available(): + dtype = torch.bfloat16 else: dtype = torch.float32 diff --git a/flexgen/pytorch_backend.py b/flexgen/pytorch_backend.py index 7f341849..a1531bb7 100644 --- a/flexgen/pytorch_backend.py +++ b/flexgen/pytorch_backend.py @@ -12,6 +12,7 @@ import torch import torch.nn.functional as F import numpy as np +from flexgen.xpu_utils import is_xpu_available from flexgen.utils import (GB, T, cpu_mem_stats, vector_gather, np_dtype_to_torch_dtype, torch_dtype_to_np_dtype, @@ -32,6 +33,7 @@ def fix_recursive_import(): class DeviceType(Enum): CPU = auto() CUDA = auto() + XPU = auto() DISK = auto() MIXED = auto() COMPRESSED = auto() @@ -48,6 +50,8 @@ def convert(name): return DeviceType.MIXED elif name == "compressed": return DeviceType.COMPRESSED + elif name == "xpu" and is_xpu_available(): + return DeviceType.XPU else: raise ValueError(f"Invalid name: {name}") @@ -422,8 +426,11 @@ def mha_gen(self, inputs, attention_mask, w_q, b_q, w_k, b_k, w_v, b_v, else: q = q.float().cpu() k, v = k.float(), v.float() - value = self._attention_value(q, k, v, attention_mask.data, - b, src_s, tgt_s, n_head, head_dim).cuda().half() + if is_xpu_available(): + value = self._attention_value(q, k, v, attention_mask.data, b, src_s, tgt_s, n_head, head_dim).xpu().half() + else: + value = self._attention_value(q, k, v, attention_mask.data, + b, src_s, tgt_s, n_head, head_dim).cuda().half() else: # Sparse attention # shape: (s, b * n_head, head_dim) k = k_cache.data[:src_s] @@ -437,9 +444,14 @@ def mha_gen(self, inputs, attention_mask, w_q, b_q, w_k, b_k, w_v, b_v, attn_sparsity) else: q = q.float().cpu() - value = self._sparse_attention_value(q, k, v_new, v_cache, - attention_mask.data, b, src_s, tgt_s, n_head, head_dim, - attn_sparsity).cuda().half() + if is_xpu_available(): + value = self._sparse_attention_value(q, k, v_new, v_cache, + attention_mask.data, b, src_s, tgt_s, n_head, head_dim, + attn_sparsity).xpu().half() + else: + value = self._sparse_attention_value(q, k, v_new, v_cache, + attention_mask.data, b, src_s, tgt_s, n_head, head_dim, + attn_sparsity).cuda().half() else: # Mixed device attention assert attn_sparsity >= 1.0 value = self._mixed_device_attention(q, k_cache, v_cache, @@ -541,8 +553,11 @@ def _mixed_device_attention(self, q, k_cache, v_cache, k_new, v_new, k_gpu = k_gpu.permute(1, 2, 0) # shape: (b * n_head, s, head_dim) v_gpu = v_gpu.permute(1, 0, 2) - - mask_gpu = mask[:b_gpu].cuda() + + if is_xpu_available(): + mask_gpu = mask[:b_gpu].xpu() + else: + mask_gpu = mask[:b_gpu].cuda() value_gpu = self._attention_value(q_gpu, k_gpu, v_gpu, mask_gpu, b_gpu, src_s, tgt_s, n_head, 
head_dim) @@ -562,8 +577,11 @@ def _mixed_device_attention(self, q, k_cache, v_cache, k_new, v_new, mask_cpu = mask[b_gpu:] value_cpu = self._attention_value(q_cpu, k_cpu, v_cpu, mask_cpu, b_cpu, src_s, tgt_s, n_head, head_dim) - - value = torch.cat([value_gpu, value_cpu.cuda().half()], dim=0) + + if is_xpu_available(): + value = torch.cat([value_gpu, value_cpu.xpu().half()], dim=0) + else: + value = torch.cat([value_gpu, value_cpu.cuda().half()], dim=0) return value def mlp(self, inputs, wi, bi, wo, bo, w_ln, b_ln, donate): @@ -584,22 +602,28 @@ def mlp(self, inputs, wi, bi, wo, bo, w_ln, b_ln, donate): return TorchTensor.create_from_torch(out, self) def synchronize(self): - torch.cuda.synchronize() + if is_xpu_available(): + torch.xpu.synchronize() + else: + torch.cuda.synchronize() def mem_stats(self): if self.device_type == DeviceType.CUDA: cur_mem = torch.cuda.memory_allocated(self.dev) peak_mem = torch.cuda.max_memory_allocated(self.dev) + elif self.device_type == DeviceType.XPU: + cur_mem = torch.xpu.memory_allocated(self.dev) + peak_mem = torch.xpu.max_memory_allocated(self.dev) elif self.device_type == DeviceType.CPU: cur_mem = cpu_mem_stats() - peak_mem = 0 + peak_mem = 0 else: raise NotImplementedError() return cur_mem, peak_mem def print_stats(self, output_file=None): - torch.cuda.synchronize() + self.synchronize() cur_mem, peak_mem = self.mem_stats() if output_file is not None: @@ -621,7 +645,7 @@ def __str__(self): class TorchDisk: """Manage tensors stored on a disk.""" - def __init__(self, path, mem_capacity=None, cuda_id=0, num_copy_threads=4): + def __init__(self, path, mem_capacity=None, cuda_id=0, xpu_id =0, num_copy_threads=4): self.name = path self.path = os.path.abspath(os.path.expanduser(path)) self.mem_capacity = mem_capacity @@ -640,7 +664,7 @@ def __init__(self, path, mem_capacity=None, cuda_id=0, num_copy_threads=4): self.copy_queue = queue.Queue() self.copy_threads = [ threading.Thread( - target=copy_worker_func, args=(self.copy_queue, cuda_id) + target=copy_worker_func, args=(self.copy_queue, cuda_id, xpu_id) ) for _ in range(num_copy_threads) ] for t in self.copy_threads: @@ -834,14 +858,15 @@ def general_copy(dst: TorchTensor, dst_indices: Tuple[slice], elif dst.device.device_type == DeviceType.DISK: # The tensor is on the disk, dispatch to copy threads for asynchronous copy dst.device.submit_copy(dst, dst_indices, src, src_indices) - elif (src.device.device_type == DeviceType.CUDA and - dst.device.device_type == DeviceType.CPU and + elif ((src.device.device_type == DeviceType.CUDA or src.device.device_type == DeviceType.XPU) + and dst.device.device_type == DeviceType.CPU and not dst.data.is_pinned() and src.shape[0] > 1): # The cpu tensor is not pinned, dispatch to copy threads and use pin_memory # as a relay global_disk_device.submit_copy(dst, dst_indices, src, src_indices) elif (src.device.device_type == DeviceType.CPU and - dst.device.device_type == DeviceType.CUDA and + (dst.device.device_type == DeviceType.CUDA or + dst.device.device_type == DeviceType.XPU) and not src.data.is_pinned()): # The cpu tensor is not pinned, use pin_memory as a relay src = src.data[src_indices] if src_indices else src.data @@ -875,32 +900,63 @@ def map_to_torch_tensor(tensor, indices): return data[indices] if indices else data -def copy_worker_func(queue, cuda_id): +def copy_worker_func(queue, cuda_id, xpu_id): """The copy worker thread.""" - torch.cuda.set_device(cuda_id) + if is_xpu_available(): + torch.xpu.set_device(xpu_id) + else: + torch.cuda.set_device(cuda_id) cpu_buf 
= torch.empty((1 * GB,), dtype=torch.float16, pin_memory=True) - copy_stream = torch.cuda.Stream() - - with torch.cuda.stream(copy_stream): - while True: - item = queue.get() - if item is None: + if is_xpu_available(): + copy_stream = torch.xpu.Stream() + + with torch.xpu.stream(copy_stream): + while True: + item = queue.get() + if item is None: + queue.task_done() + return + + dst, dst_indices, src, src_indices = item + src_data = map_to_torch_tensor(src, src_indices) + dst_data = map_to_torch_tensor(dst, dst_indices) + + if (src.device.device_type == DeviceType.XPU or + dst.device.device_type == DeviceType.XPU): + # Use a pinned cpu buffer as a relay + size = np.prod(src_data.shape) + tmp_cpu_buf = cpu_buf[:size].view(src_data.shape) + tmp_cpu_buf.copy_(src_data) + dst_data.copy_(tmp_cpu_buf) + else: + dst_data.copy_(src_data) + queue.task_done() - return - - dst, dst_indices, src, src_indices = item - src_data = map_to_torch_tensor(src, src_indices) - dst_data = map_to_torch_tensor(dst, dst_indices) - - if (src.device.device_type == DeviceType.CUDA or - dst.device.device_type == DeviceType.CUDA): - # Use a pinned cpu buffer as a relay - size = np.prod(src_data.shape) - tmp_cpu_buf = cpu_buf[:size].view(src_data.shape) - tmp_cpu_buf.copy_(src_data) - dst_data.copy_(tmp_cpu_buf) - else: - dst_data.copy_(src_data) - queue.task_done() + + else: + copy_stream = torch.cuda.Stream() + + with torch.cuda.stream(copy_stream): + while True: + item = queue.get() + if item is None: + queue.task_done() + return + + dst, dst_indices, src, src_indices = item + src_data = map_to_torch_tensor(src, src_indices) + dst_data = map_to_torch_tensor(dst, dst_indices) + + if (src.device.device_type == DeviceType.CUDA or + dst.device.device_type == DeviceType.CUDA): + # Use a pinned cpu buffer as a relay + size = np.prod(src_data.shape) + tmp_cpu_buf = cpu_buf[:size].view(src_data.shape) + tmp_cpu_buf.copy_(src_data) + dst_data.copy_(tmp_cpu_buf) + else: + dst_data.copy_(src_data) + + queue.task_done() diff --git a/flexgen/utils.py b/flexgen/utils.py index 7e62c9bb..58af0adf 100644 --- a/flexgen/utils.py +++ b/flexgen/utils.py @@ -10,6 +10,7 @@ import numpy as np import torch +from flexgen.xpu_utils import is_xpu_available KB = 1 << 10 @@ -35,6 +36,7 @@ class Task: class ExecutionEnv: """Hardware environment.""" gpu: Any = None + xpu: Any = None cpu: Any = None disk: Any = None mixed: Any = None @@ -45,8 +47,9 @@ def create(cls, offload_dir): from flexgen.pytorch_backend import TorchDevice, TorchDisk, TorchMixedDevice gpu = TorchDevice("cuda:0") cpu = TorchDevice("cpu") + xpu = TorchDevice("xpu:0") disk = TorchDisk(offload_dir) - return cls(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk])) + return cls(gpu=gpu, xpu=xpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, xpu, disk])) def close_copy_threads(self): self.disk.close_copy_threads() @@ -70,13 +73,13 @@ class BenchmarkResult: } torch_dtype_to_np_dtype = { - torch.float16: np.float16, torch.float32: np.float32, + torch.float16: np.float16, torch.float32: np.float32, torch.uint8: np.uint8, torch.int8: np.int8, torch.int32: np.int32, torch.int64: np.int64, torch.bool: bool, } torch_dtype_to_num_bytes = { - torch.float16: 2, torch.float32: 4, + torch.float16: 2, torch.float32: 4, torch.bfloat16: 2, torch.int8: 1, torch.uint8: 1, torch.int32: 4, torch.int64: 8, torch.bool: 1, } @@ -145,7 +148,10 @@ def cpu_mem_stats(): def torch_mem_stats(): objects = gc.get_objects() - tensors = [obj for obj in objects if torch.is_tensor(obj) and 
obj.is_cuda] + if is_xpu_available(): + tensors = [obj for obj in objects if torch.is_tensor(obj) and obj.is_xpu()]x + else: + tensors = [obj for obj in objects if torch.is_tensor(obj) and obj.is_cuda] total_numel = 0 total_mem = 0 diff --git a/flexgen/xpu_utils.py b/flexgen/xpu_utils.py new file mode 100644 index 00000000..082a2707 --- /dev/null +++ b/flexgen/xpu_utils.py @@ -0,0 +1,67 @@ +import torch +import importlib +import importlib.metadata +import os +import warnings +from functools import lru_cache + +from packaging import version +from packaging.version import parse + + + +def is_ipex_available(): + def get_major_and_minor_from_version(full_version): + return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor) + + _torch_version = importlib.metadata.version("torch") + if importlib.util.find_spec("intel_extension_for_pytorch") is None: + return False + _ipex_version = "N/A" + try: + _ipex_version = importlib.metadata.version("intel_extension_for_pytorch") + except importlib.metadata.PackageNotFoundError: + return False + torch_major_and_minor = get_major_and_minor_from_version(_torch_version) + ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) + if torch_major_and_minor != ipex_major_and_minor: + warnings.warn( + f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*," + f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again." + ) + return False + return True + + +@lru_cache +def is_xpu_available(check_device=False): + "Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment" + if not is_ipex_available(): + return False + + import intel_extension_for_pytorch # noqa: F401 + + if check_device: + try: + # Will raise a RuntimeError if no XPU is found + _ = torch.xpu.device_count() + return torch.xpu.is_available() + except RuntimeError: + return False + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def is_ccl_available(): + ccl_version = "N/A" + try: + _is_ccl_available = ( + importlib.util.find_spec("torch_ccl") is not None + or importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None + ) + + ccl_version = importlib.metadata.version("oneccl_bind_pt") + print(f"Detected oneccl_bind_pt version {ccl_version}") + except importlib.metadata.PackageNotFoundError: + _is_ccl_available = False + return False + return _is_ccl_available \ No newline at end of file From 5de76d985ad2f118b2e619d5a664a9dd3215a003 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Tue, 26 Sep 2023 00:54:24 -0700 Subject: [PATCH 2/4] extend support for flex_opt --- flexgen/flex_opt.py | 173 ++++++++++++++++++++++++++++--------- flexgen/pytorch_backend.py | 39 +++++++-- flexgen/utils.py | 5 +- 3 files changed, 166 insertions(+), 51 deletions(-) diff --git a/flexgen/flex_opt.py b/flexgen/flex_opt.py index 24f87a90..9660934f 100644 --- a/flexgen/flex_opt.py +++ b/flexgen/flex_opt.py @@ -24,6 +24,7 @@ array_1d, array_2d, array_3d, str2bool, project_decode_latency, torch_mem_stats, torch_dtype_to_np_dtype, write_benchmark_log, read_benchmark_log) +from flexgen.xpu_utils import is_xpu_available fix_recursive_import() @@ -34,6 +35,8 @@ class Policy: gpu_batch_size: int num_gpu_batches: int + xpu_batch_size: int = None + num_xpu_batches: int = None # percent = a means a% w_gpu_percent: float @@ -42,7 +45,10 @@ class Policy: cache_cpu_percent: float act_gpu_percent: float act_cpu_percent: float - + w_xpu_percent: 
float = None + cache_xpu_percent: float = None + act_xpu_percent: float = None + # Whether to overlap the I/O and compute overlap: bool @@ -68,14 +74,20 @@ class Policy: @property def w_disk_percent(self): + if is_xpu_available(): + return 100 - self.w_xpu_percent -self.w_cpu_percent return 100 - self.w_gpu_percent - self.w_cpu_percent @property def cache_disk_percent(self): + if is_xpu_available(): + return 100 - self.cache_xpu_percent - self.cache_cpu_percent return 100 - self.cache_gpu_percent - self.cache_cpu_percent @property def act_disk_percent(self): + if is_xpu_available(): + return 100 - self.act_xpu_percent - self.act_cpu_percent return 100 - self.act_gpu_percent - self.act_cpu_percent @@ -92,6 +104,9 @@ def get_choice(cur_percent, percents, choices): def init_weight_list(weight_specs, policy, env): dev_percents = [policy.w_disk_percent, policy.w_cpu_percent, policy.w_gpu_percent] dev_choices = [env.disk, env.cpu, env.gpu] + if is_xpu_available(): + dev_percents.append(policy.w_xpu_percent) + dev_choices.append(env.xpu) sizes = [np.prod(spec[0]) for spec in weight_specs] sizes_cumsum = np.cumsum(sizes) @@ -136,7 +151,10 @@ def __init__(self, config, env, policy): self.config = config self.env = env self.policy = policy - self.compute = self.env.gpu + if is_xpu_available(): + self.compute = self.env.xpu + else: + self.compute = self.env.gpu self.weight_load_dst = (self.compute.compressed_device if policy.compress_weight else self.compute) @@ -183,8 +201,7 @@ def forward(self, hidden, cache_read_buf, weight_read_buf, attention_mask, donate = [False] * 4 h, donate[0] = hidden.val, True mask, donate[1] = attention_mask.val.smart_copy(self.compute) - - if k == self.policy.num_gpu_batches - 1: + if k == self.policy.num_gpu_batches - 1 or (k == self.policy.num_xpu_batches - 1 and is_xpu_available()): # Clear the weight_read_buf if it is the last gpu batch (w_token, donate[2]), (w_pos, donate[3]) = weight_read_buf.pop() else: @@ -200,7 +217,10 @@ def __init__(self, config, env, policy): self.config = config self.env = env self.policy = policy - self.compute = self.env.gpu + if is_xpu_available(): + self.compute = self.env.xpu + else: + self.compute = self.env.gpu self.weight_load_dst = (self.compute.compressed_device if policy.compress_weight else self.compute) @@ -249,8 +269,7 @@ def forward(self, hidden, cache_read_buf, weight_read_buf, attention_mask, cache_write_buf, i, k): donate = [False] * 4 h, donate[0] = hidden.val, True - - if k == self.policy.num_gpu_batches - 1: + if k == self.policy.num_gpu_batches - 1 or (k == self.policy.num_xpu_batches - 1 and is_xpu_available()): # Clear the weight_read_buf if it is the last gpu batch (w_ln, donate[1]), (b_ln, donate[2]), (w_token, donate[3]) = weight_read_buf.pop() else: @@ -267,11 +286,18 @@ def __init__(self, config, env, policy, layer_id): self.env = env self.layer_id = layer_id self.policy = policy - self.compute = self.env.gpu + if is_xpu_available(): + self.compute = self.env.xpu + else: + self.compute = self.env.gpu self.weight_load_dst = (self.compute.compressed_device if policy.compress_weight else self.compute) - self.attention_compute = (self.env.cpu if self.policy.cpu_cache_compute - else self.env.gpu) + if is_xpu_available(): + self.attention_compute = (self.env.cpu if self.policy.cpu_cache_compute + else self.env.xpu) + else: + self.attention_compute = (self.env.cpu if self.policy.cpu_cache_compute + else self.env.gpu) self.task = None @@ -325,6 +351,8 @@ def init_cache_one_gpu_batch(self, cache_home): device = 
self.env.cpu elif self.policy.cache_disk_percent == 100: device = self.env.disk + elif self.policy.cache_xpu_percent == 100 and is_xpu_available(): + device = self.env.xpu else: device = self.env.mixed @@ -431,7 +459,7 @@ def forward(self, hidden, cache_read_buf, weight_read_buf, attention_mask, donate = [False] * 14 h, donate[0] = hidden.val, True - if k == self.policy.num_gpu_batches - 1: + if k == self.policy.num_gpu_batches - 1 or (k == self.policy.num_xpu_batches - 1 and is_xpu_available()): # Clear the weight_read_buf if it is the last gpu batch ((w_q, donate[2]), (b_q, donate[3]), (w_k, donate[4]), (b_k, donate[5]), (w_v, donate[6]), (b_v, donate[7]), (w_out, donate[8]), (b_out, donate[9]), @@ -465,7 +493,10 @@ def __init__(self, config, env, policy, layer_id): self.env = env self.layer_id = layer_id self.policy = policy - self.compute = self.env.gpu + if is_xpu_available(): + self.compute = self.env.xpu + else: + self.compute = self.env.gpu self.weight_load_dst = (self.compute.compressed_device if policy.compress_weight else self.compute) @@ -521,7 +552,7 @@ def forward(self, hidden, cache_read_buf, weight_read_buf, attention_mask, donate = [False] * 7 h, donate[0] = hidden.val, True - if k == self.policy.num_gpu_batches - 1: + if k == self.policy.num_gpu_batches - 1 or (k == self.policy.num_xpu_batches and is_xpu_available()): # Clear the weight_read_buf if it is the last gpu batch ((wi, donate[1]), (bi, donate[2]), (wo, donate[3]), (bo, donate[4]), (w_ln, donate[5]), (b_ln, donate[6])) = weight_read_buf.pop() @@ -569,7 +600,7 @@ def store_cache(self, cache_home, cache_write_buf, i): def forward(self, hidden, cache_read_buf, weight_read_buf, attention_mask, cache_write_buf, i, k): - if k == self.policy.num_gpu_batches - 1: + if k == self.policy.num_gpu_batches - 1 or (k == self.policy.num_xpu_batches and is_xpu_available()): read_buf1, read_buf2 = weight_read_buf.pop() else: read_buf1, read_buf2 = weight_read_buf.val @@ -591,7 +622,10 @@ def __init__(self, self.env = env self.path = path self.policy = policy - self.num_gpu_batches = policy.num_gpu_batches + if is_xpu_available(): + self.num_xpu_batches = policy.num_xpu_batches + else: + self.num_gpu_batches = policy.num_gpu_batches layers = [] layers.append(InputEmbed(self.config, self.env, self.policy)) @@ -611,18 +645,29 @@ def __init__(self, self.act_home = self.env.cpu elif self.policy.act_disk_percent == 100: self.act_home = self.env.disk + elif self.policy.act_xpu_percent == 100: + self.act_home = self.env.xpu else: raise NotImplementedError() - # CUDA streams - self.load_weight_stream = torch.cuda.Stream() - self.load_cache_stream = torch.cuda.Stream() - self.store_cache_stream = torch.cuda.Stream() + if is_xpu_available(): + self.load_weight_stream = torch.xpu.Stream() + self.load_cache_stream = torch.xpu.Stream() + self.store_cache_stream = torch.xpu.Stream() + else: + # CUDA streams + self.load_weight_stream = torch.cuda.Stream() + self.load_cache_stream = torch.cuda.Stream() + self.store_cache_stream = torch.cuda.Stream() + # Intermediate tensors # The following buffers store values used # for the i-th token, j-th layer, k-th gpu batch. 
- num_layers, num_gpu_batches = self.num_layers, self.policy.num_gpu_batches + if is_xpu_available(): + num_layers, num_gpu_batches = self.num_layers, self.policy.num_xpu_batches + else: + num_layers, num_gpu_batches = self.num_layers, self.policy.num_gpu_batches # cache[j][k] self.cache_home = array_2d(num_layers, num_gpu_batches, ValueHolder) @@ -660,8 +705,12 @@ def load_weight(self, i, j, k, overlap=True): # Load from weight_home to weight_read_buf if overlap: - with torch.cuda.stream(self.load_weight_stream): - self.layers[j].load_weight(self.weight_home[j], self.weight_read_buf[j], k) + if is_xpu_available(): + with torch.xpu.stream(self.load_weight_stream): + self.layers[j].load_weight(self.weight_home[j], self.weight_read_buf[j], k) + else: + with torch.cuda.stream(self.load_weight_stream): + self.layers[j].load_weight(self.weight_home[j], self.weight_read_buf[j], k) else: self.layers[j].load_weight(self.weight_home[j], self.weight_read_buf[j], k) @@ -692,8 +741,12 @@ def load_cache(self, i, j, k, overlap=True): # Load from cache_home to cache_read_buf if overlap: - with torch.cuda.stream(self.load_cache_stream): - self.layers[j].load_cache(self.cache_home[j][k], self.cache_read_buf[j][k], i) + if is_xpu_available(): + with torch.xpu.stream(self.load_cache_stream): + self.layers[j].load_cache(self.cache_home[j][k], self.cache_read_buf[j][k], i) + else: + with torch.cuda.stream(self.load_cache_stream): + self.layers[j].load_cache(self.cache_home[j][k], self.cache_read_buf[j][k], i) else: self.layers[j].load_cache(self.cache_home[j][k], self.cache_read_buf[j][k], i) @@ -714,8 +767,12 @@ def store_cache(self, i, j, k, overlap=True): # Store cache_write_buf to cache_home # Delete cache_write_buf if overlap: - with torch.cuda.stream(self.store_cache_stream): - self.layers[j].store_cache(self.cache_home[j][k], self.cache_write_buf[j][k], i) + if is_xpu_available(): + with torch.xpu.stream(self.store_cache_stream): + self.layers[j].store_cache(self.cache_home[j][k], self.cache_write_buf[j][k], i) + else: + with torch.cuda.stream(self.store_cache_stream): + self.layers[j].store_cache(self.cache_home[j][k], self.cache_write_buf[j][k], i) else: self.layers[j].store_cache(self.cache_home[j][k], self.cache_write_buf[j][k], i) @@ -739,7 +796,10 @@ def load_hidden(self, i, j, k): # Load to hidden states buffers dst = self.layers[j].compute if j == 0: - gpu_batch_size = self.policy.gpu_batch_size + if is_xpu_available(): + gpu_batch_size = self.policy.xpu_batch_size + else: + gpu_batch_size = self.policy.gpu_batch_size left, right = k * gpu_batch_size, (k + 1) * gpu_batch_size if i == 0: # load from the input ids val = dst.allocate((gpu_batch_size, self.task.prompt_len), np.int32) @@ -765,7 +825,10 @@ def store_hidden(self, i, j, k): # Store to hidden states buffers if j == self.num_layers - 1: # store to output - gpu_batch_size = self.policy.gpu_batch_size + if is_xpu_available(): + gpu_batch_size = self.policy.xpu_batch_size + else: + gpu_batch_size = self.policy.gpu_batch_size left, right = k * gpu_batch_size, (k + 1) * gpu_batch_size ids = self.hidden[i][j][k].pop().data.detach().cpu().numpy() pos = self.task.prompt_len + i @@ -792,7 +855,10 @@ def compute_layer(self, i, j, k): def sync(self): self.env.disk.synchronize() - torch.cuda.synchronize() + if is_xpu_available(): + torch.xpu.synchronize() + else: + torch.cuda.synchronize() def init_all_weights(self): self.weight_home = array_1d(self.num_layers, ValueHolder) @@ -809,8 +875,11 @@ def update_attention_mask(self, i, k): assert mask.val 
is not None mask.val = mask.val.device.extend_attention_mask(mask.val, [True]) return - - gpu_batch_size = self.policy.gpu_batch_size + + if is_xpu_available(): + gpu_batch_size = self.policy.xpu_batch_size + else: + gpu_batch_size = self.policy.gpu_batch_size left = k * gpu_batch_size right = left + gpu_batch_size input_ids = self.output_ids[left:right, :self.task.prompt_len] @@ -841,8 +910,12 @@ def generate(self, stop=stop, ) num_layers = self.num_layers - num_gpu_batches = self.num_gpu_batches - gpu_batch_size = self.policy.gpu_batch_size + if is_xpu_available(): + num_gpu_batches = self.num_xpu_batches + gpu_batch_size = self.policy.xpu_batch_size + else: + num_gpu_batches = self.num_gpu_batches + gpu_batch_size = self.policy.gpu_batch_size overlap = self.policy.overlap prompt_len, gen_len = task.prompt_len, task.gen_len self.execute_gen_len = task.cut_gen_len if task.cut_gen_len else task.gen_len @@ -857,7 +930,10 @@ def generate(self, # Intermediate tensors # The following buffers store values used # for the i-th token, j-th layer, k-th gpu batch. - num_layers, num_gpu_batches = self.num_layers, self.policy.num_gpu_batches + if is_xpu_available(): + num_layers, num_gpu_batches = self.num_layers, self.policy.num_xpu_batches + else: + num_layers, num_gpu_batches = self.num_layers, self.policy.num_gpu_batches for j in range(num_layers): for k in range(num_gpu_batches): self.cache_home[j][k].clear() @@ -1191,13 +1267,19 @@ def run_flexgen(args): gpu = TorchDevice("cuda:0") cpu = TorchDevice("cpu") + if args.ipex and is_xpu_available(): + xpu = TorchDevice("xpu:0") disk = TorchDisk(args.offload_dir) - env = ExecutionEnv(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk])) - + env = ExecutionEnv(gpu=gpu, xpu=xpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, xpu, cpu, disk])) + if args.ipex and is_xpu_available(): + args.xpu_batch_size, args.num_xpu_batches = args.gpu_batch_size, args.num_xpu_batches policy = Policy(args.gpu_batch_size, args.num_gpu_batches, + args.xpu_batch_size, args.num_xpu_batches, args.percent[0], args.percent[1], args.percent[2], args.percent[3], args.percent[4], args.percent[5], + args.percent[6], args.percent[7], + args.percent[8], args.overlap, args.sep_layer, args.pin_weight, args.cpu_cache_compute, args.attn_sparsity, args.compress_weight, @@ -1244,6 +1326,8 @@ def run_flexgen(args): total_latency = prefill_latency + decode_latency total_throughput = num_generated_tokens / total_latency _, gpu_peak_mem = gpu.mem_stats() + if args.ipex and is_xpu_available(): + _, xpu_peak_mem = xpu.mem_stats() _, cpu_peak_mem = cpu.mem_stats() if DUMMY_WEIGHT not in args.path: @@ -1256,6 +1340,8 @@ def run_flexgen(args): print(show_str) gpu.print_stats() + if args.ipex and is_xpu_available(): + xpu.print_stats() cpu.print_stats() projected = bool(args.debug_mode or cut_gen_len) @@ -1263,10 +1349,10 @@ def run_flexgen(args): filename = get_filename(args) + ".log" else: filename = args.log_file - + log_str = write_benchmark_log(filename, opt_config.model_bytes(), cache_size, hidden_size, - gpu_peak_mem, projected, prefill_latency, prefill_throughput, + xpu_peak_mem if (args.ipex and is_xpu_available()) else gpu_peak_mem, projected, prefill_latency, prefill_throughput, decode_latency, decode_throughput, total_latency, total_throughput) if args.verbose >= 1: print(log_str) @@ -1289,14 +1375,17 @@ def add_parser_arguments(parser): parser.add_argument("--gpu-batch-size", type=int, default=4) parser.add_argument("--num-gpu-batches", type=int, default=1) 
parser.add_argument("--percent", nargs="+", type=int, - default=[100, 0, 100, 0, 100, 0], - help="Six numbers. They are " + default=[100, 0, 100, 0, 100, 0, 100, 100, 100], + help="Nine numbers. They are " "the percentage of weight on GPU, " "the percentage of weight on CPU, " "the percentage of attention cache on GPU, " "the percentage of attention cache on CPU, " "the percentage of activations on GPU, " - "the percentage of activations on CPU") + "the percentage of activations on CPU, " + "the percentage of weight on XPU, " + "the percentage of attention cache on XPU, " + "the percentage of activations on XPU, ") parser.add_argument("--sep-layer", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--pin-weight", type=str2bool, nargs="?", @@ -1307,6 +1396,8 @@ def add_parser_arguments(parser): help="Whether to compress weight.") parser.add_argument("--compress-cache", action="store_true", help="Whether to compress cache.") + parser.add_argument("--ipex", action="store_true", + help="Whether to use xpu runtime on Intel GPU.") parser.add_argument("--log-file", type=str, default="auto") diff --git a/flexgen/pytorch_backend.py b/flexgen/pytorch_backend.py index a1531bb7..fe08411d 100644 --- a/flexgen/pytorch_backend.py +++ b/flexgen/pytorch_backend.py @@ -423,6 +423,9 @@ def mha_gen(self, inputs, attention_mask, w_q, b_q, w_k, b_k, w_v, b_v, if k.is_cuda: value = self._attention_value(q, k, v, attention_mask.data, b, src_s, tgt_s, n_head, head_dim) + elif k.is_xpu(): + value = self._attention_value(q, k, v, attention_mask.data, + b, src_s, tgt_s, n_head, head_dim) else: q = q.float().cpu() k, v = k.float(), v.float() @@ -442,6 +445,10 @@ def mha_gen(self, inputs, attention_mask, w_q, b_q, w_k, b_k, w_v, b_v, value = self._sparse_attention_value(q, k, v_new, v_cache, attention_mask.data, b, src_s, tgt_s, n_head, head_dim, attn_sparsity) + elif k.is_xpu(): + value = self._sparse_attention_value(q, k, v_new, v_cache, + attention_mask.data, b, src_s, tgt_s, n_head, head_dim, + attn_sparsity) else: q = q.float().cpu() if is_xpu_available(): @@ -689,9 +696,14 @@ def delete(self, tensor): os.remove(tensor.data) def init_cache_one_gpu_batch(self, config, task, policy): - num_head, hidden_size, prompt_len, gen_len, gpu_batch_size = ( - config.n_head, config.input_dim, task.prompt_len, task.gen_len, - policy.gpu_batch_size) + if is_xpu_available(): + num_head, hidden_size, prompt_len, gen_len, gpu_batch_size = ( + config.n_head, config.input_dim, task.prompt_len, task.gen_len, + policy.xpu_batch_size) + else: + num_head, hidden_size, prompt_len, gen_len, gpu_batch_size = ( + config.n_head, config.input_dim, task.prompt_len, task.gen_len, + policy.gpu_batch_size) shape = (prompt_len + gen_len - 1, gpu_batch_size * num_head, hidden_size // num_head) k_cache = self.allocate(shape, np.float16) v_cache = self.allocate(shape, np.float16) @@ -760,18 +772,29 @@ def delete(self, tensor): x.delete() def init_cache_one_gpu_batch(self, config, task, policy): - num_head, hidden_size, prompt_len, gen_len, gpu_batch_size = ( - config.n_head, config.input_dim, task.prompt_len, task.gen_len, - policy.gpu_batch_size) + if is_xpu_available(): + num_head, hidden_size, prompt_len, gen_len, gpu_batch_size = ( + config.n_head, config.input_dim, task.prompt_len, task.gen_len, + policy.xpu_batch_size) + else: + num_head, hidden_size, prompt_len, gen_len, gpu_batch_size = ( + config.n_head, config.input_dim, task.prompt_len, task.gen_len, + policy.gpu_batch_size) shape = (prompt_len + gen_len - 1, 
gpu_batch_size * num_head, hidden_size // num_head) # We have to round to a multiple of `num_head` if policy.cache_disk_percent == 0: - len_gpu = int(shape[SEG_DIM] * policy.cache_gpu_percent / 100) // num_head * num_head + if is_xpu_available(): + len_gpu = int(shape[SEG_DIM] * policy.cache_xpu_percent / 100) // num_head * num_head + else: + len_gpu = int(shape[SEG_DIM] * policy.cache_gpu_percent / 100) // num_head * num_head len_cpu = shape[SEG_DIM] - len_gpu len_disk = 0 else: - len_gpu = int(shape[SEG_DIM] * policy.cache_gpu_percent / 100) // num_head * num_head + if is_xpu_available(): + len_gpu = int(shape[SEG_DIM] * policy.cache_xpu_percent / 100) // num_head * num_head + else: + len_gpu = int(shape[SEG_DIM] * policy.cache_gpu_percent / 100) // num_head * num_head len_cpu = int(shape[SEG_DIM] * policy.cache_cpu_percent / 100) // num_head * num_head len_disk = shape[SEG_DIM] - len_gpu - len_cpu lens = [len_gpu, len_cpu, len_disk] diff --git a/flexgen/utils.py b/flexgen/utils.py index 58af0adf..74051ffe 100644 --- a/flexgen/utils.py +++ b/flexgen/utils.py @@ -47,9 +47,10 @@ def create(cls, offload_dir): from flexgen.pytorch_backend import TorchDevice, TorchDisk, TorchMixedDevice gpu = TorchDevice("cuda:0") cpu = TorchDevice("cpu") - xpu = TorchDevice("xpu:0") + if is_xpu_available(): + xpu = TorchDevice("xpu:0") disk = TorchDisk(offload_dir) - return cls(gpu=gpu, xpu=xpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, xpu, disk])) + return cls(gpu=gpu, xpu=xpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, xpu, cpu, disk])) def close_copy_threads(self): self.disk.close_copy_threads() From a8460fb0568ea1c6a4d87dd38a45c4ea604a9c33 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Tue, 26 Sep 2023 04:38:41 -0700 Subject: [PATCH 3/4] support dist --- flexgen/compression.py | 33 ++++++++--- flexgen/dist_flex_opt.py | 117 ++++++++++++++++++++++++++++++--------- flexgen/dist_utils.py | 8 ++- flexgen/flex_opt.py | 2 +- 4 files changed, 123 insertions(+), 37 deletions(-) diff --git a/flexgen/compression.py b/flexgen/compression.py index 37350b9e..4a944e2b 100644 --- a/flexgen/compression.py +++ b/flexgen/compression.py @@ -6,7 +6,7 @@ from flexgen.pytorch_backend import (TorchTensor, TorchDevice, DeviceType, general_copy, fix_recursive_import) from flexgen.utils import np_dtype_to_torch_dtype - +from flexgen.xpu_utils import is_xpu_available @dataclasses.dataclass class CompressionConfig: @@ -49,9 +49,14 @@ def allocate(self, shape, dtype, comp_config, pin_memory=None, name=None): (data, scale, comp_config), self, name=name) def init_cache_one_gpu_batch(self, config, task, policy): - num_head, hidden_size, prompt_len, gen_len, gpu_batch_size = ( - config.n_head, config.input_dim, task.prompt_len, task.gen_len, - policy.gpu_batch_size) + if is_xpu_available(): + num_head, hidden_size, prompt_len, gen_len, gpu_batch_size = ( + config.n_head, config.input_dim, task.prompt_len, task.gen_len, + policy.xpu_batch_size) + else: + num_head, hidden_size, prompt_len, gen_len, gpu_batch_size = ( + config.n_head, config.input_dim, task.prompt_len, task.gen_len, + policy.gpu_batch_size) shape = (prompt_len + gen_len - 1, gpu_batch_size * num_head, hidden_size // num_head) # NOTE: disable pin_memory due to high memory overhead pin_memory = False @@ -65,7 +70,10 @@ def init_attention_compute_workspace(self, config, task, policy): if self.base_device.device_type != DeviceType.CPU: return # Only CPU requires this fp32 workspace - b = policy.gpu_batch_size + if is_xpu_available(): + b = 
policy.xpu_batch_size + else: + b = policy.gpu_batch_size n_head = config.n_head head_dim = config.input_dim // n_head max_seq_len = task.prompt_len + task.gen_len - 1 @@ -334,7 +342,10 @@ def compress_and_decompress(tensor, config): def test_simulated_compression(): torch.manual_seed(0) - a = torch.normal(0, 1, (64, 64, 64), dtype=torch.float16).cuda() + if is_xpu_available(): + a = torch.normal(0, 1, (64, 64, 64), dtype=torch.float16).xpu() + else: + a = torch.normal(0, 1, (64, 64, 64), dtype=torch.float16).cuda() config = CompressionConfig( num_bits=4, group_size=32, group_dim=0, symmetric=False) @@ -346,11 +357,17 @@ def test_simulated_compression(): def test_real_compression(): torch.manual_seed(0) - a = torch.normal(0, 1, (32, 1, 1), dtype=torch.float16).cuda() + if is_xpu_available(): + a = torch.normal(0, 1, (32, 1, 1), dtype=torch.float16).xpu() + else: + a = torch.normal(0, 1, (32, 1, 1), dtype=torch.float16).cuda() config = CompressionConfig( num_bits=4, group_size=32, group_dim=0, symmetric=False) - dev = TorchDevice("cuda:0", 0, 0).compressed_device + if is_xpu_available(): + dev =TorchDevice("xpu:0", 0, 0).compressed_device + else: + dev = TorchDevice("cuda:0", 0, 0).compressed_device packed = dev.compress(a, config) b = dev.decompress(packed) diff --git a/flexgen/dist_flex_opt.py b/flexgen/dist_flex_opt.py index 593bb78f..4ee2ccd3 100644 --- a/flexgen/dist_flex_opt.py +++ b/flexgen/dist_flex_opt.py @@ -22,9 +22,12 @@ from flexgen.timer import timers from flexgen.utils import (Task, ExecutionEnv, GB, T, ValueHolder, array_1d, array_2d, array_3d, array_4d, str2bool, project_decode_latency) +from flexgen.xpu_utils import is_xpu_available, is_ccl_available #os.environ["NCCL_DEBUG"] = "TRACE" +#if is_ccl_available(): +#os.environ["CCL_LOG_LEVEL"] = "TRACE" class DistOptLM(OptLM): @@ -35,7 +38,10 @@ def __init__(self, config, env, path, policy, pipeline_rank, self.env = env self.path = path self.policy = policy - self.num_gpu_batches = self.policy.num_gpu_batches + if is_xpu_available(): + self.num_gpu_batches = self.policy.num_xpu_batches + else: + self.num_gpu_batches = self.policy.num_gpu_batches self.pipeline_rank = pipeline_rank self.num_pipeline_stages = num_pipeline_stages self.num_inner_iterations = num_inner_iterations if num_inner_iterations is not None else num_pipeline_stages @@ -44,9 +50,11 @@ def __init__(self, config, env, path, policy, pipeline_rank, self.comm_device = self.env.cpu elif comm_device == "gpu": self.comm_device = self.env.gpu + elif comm_device == "xpu" and is_xpu_available(): + self.comm_device = self.env.xpu else: raise ValueError(f"Invalid comm_device: {comm_device}") - + layers = [] if pipeline_rank == 0: layers.append(InputEmbed(self.config, self.env, self.policy)) @@ -69,6 +77,8 @@ def __init__(self, config, env, path, policy, pipeline_rank, if self.policy.act_gpu_percent == 100: self.act_home = self.env.gpu + elif self.policy.act_xpu_percent == 100 and is_xpu_available(): + self.act_home =self.env.xpu elif self.policy.act_cpu_percent == 100: self.act_home = self.env.cpu elif self.policy.act_disk_percent == 100: @@ -76,10 +86,15 @@ def __init__(self, config, env, path, policy, pipeline_rank, else: raise NotImplementedError() - # CUDA streams - self.load_weight_stream = torch.cuda.Stream() - self.load_cache_stream = torch.cuda.Stream() - self.store_cache_stream = torch.cuda.Stream() + if is_xpu_available(): + self.load_weight_stream = torch.xpu.Stream() + self.load_cache_stream = torch.xpu.Stream() + self.store_cache_stream = torch.xpu.Stream() + 
else: + # CUDA streams + self.load_weight_stream = torch.cuda.Stream() + self.load_cache_stream = torch.cuda.Stream() + self.store_cache_stream = torch.cuda.Stream() self.task = None self.init_all_weights() @@ -99,8 +114,12 @@ def load_weight(self, b, t, i, j, k): return # Load from weight_home to weight_read_buf - with torch.cuda.stream(self.load_weight_stream): - self.layers[j].load_weight(self.weight_home[j], self.weight_read_buf[j], k) + if is_xpu_available(): + with torch.xpu.stream(self.load_weight_stream): + self.layers[j].load_weight(self.weight_home[j], self.weight_read_buf[j], k) + else: + with torch.cuda.stream(self.load_weight_stream): + self.layers[j].load_weight(self.weight_home[j], self.weight_read_buf[j], k) def init_cache(self, t, j, k): self.layers[j].init_cache_one_gpu_batch(self.cache_home[t][j][k]) @@ -120,8 +139,12 @@ def load_cache(self, t, i, j, k): return # Load from cache_home to cache_read_buf - with torch.cuda.stream(self.load_cache_stream): - self.layers[j].load_cache(self.cache_home[t][j][k], self.cache_read_buf[t][j][k], i) + if is_xpu_available(): + with torch.xpu.stream(self.load_cache_stream): + self.layers[j].load_cache(self.cache_home[t][j][k]), self.cache_read_buf[t][j][k], i) + else: + with torch.cuda.stream(self.load_cache_stream): + self.layers[j].load_cache(self.cache_home[t][j][k], self.cache_read_buf[t][j][k], i) def store_cache(self, t, i, j, k): # Handle corner cases @@ -139,8 +162,12 @@ def store_cache(self, t, i, j, k): # Store cache_write_buf to cache_home # Delete cache_write_buf - with torch.cuda.stream(self.store_cache_stream): - self.layers[j].store_cache(self.cache_home[t][j][k], self.cache_write_buf[t][j][k], i) + if is_xpu_available(): + with torch.xpu.stream(self.store_cache_stream): + self.layers[j].store_cache(self.cache_home[t][j][k], self.cache_write_buf[t][j][k], i) + else: + with torch.cuda.stream(self.store_cache_stream): + self.layers[j].store_cache(self.cache_home[t][j][k], self.cache_write_buf[t][j][k], i) def delete_cache(self, t, j, k): v = self.cache_home[t][j][k].pop() @@ -175,7 +202,10 @@ def load_hidden(self, b, t, i, j, k): # Already received the input from previous hidden states self.hidden[t][i][j][k].val = self.hidden[t][i][j][k].val.move(dst) return - gpu_batch_size = self.policy.gpu_batch_size + if is_xpu_available(): + gpu_batch_size = self.policy.xpu_batch_size + else: + gpu_batch_size = self.policy.gpu_batch_size num_gpu_batches = self.num_gpu_batches num_inner_iterations = self.num_inner_iterations left = ((b * num_inner_iterations + t) * num_gpu_batches + k) * gpu_batch_size @@ -221,7 +251,10 @@ def store_hidden(self, b, t, i, j, k): hidden_val = self.hidden[t][i][j][k].val ids = hidden_val.data.detach().cpu().numpy() - gpu_batch_size = self.policy.gpu_batch_size + if is_xpu_available(): + gpu_batch_size = self.policy.xpu_batch_size + else: + gpu_batch_size = self.policy.gpu_batch_size num_gpu_batches = self.num_gpu_batches num_inner_iterations = self.num_inner_iterations left = ((b * num_inner_iterations + t) * num_gpu_batches + k) * gpu_batch_size @@ -244,8 +277,12 @@ def recv_hidden(self, t, i, j, k, tag=0, async_=False): sender_rank = (self.pipeline_rank - 1) % self.num_pipeline_stages val_holder = self.hidden[t][i][j][k] seq_len = self.task.prompt_len if i == 0 else 1 - shape, dtype = self.layers[j].input_act_shape_and_dtype( - self.policy.gpu_batch_size, seq_len) + if is_xpu_available(): + shape, dtype = self.layers[j].input_act_shape_and_dtype( + self.policy.xpu_batch_size, seq_len) + else: + shape, 
dtype = self.layers[j].input_act_shape_and_dtype( + self.policy.gpu_batch_size, seq_len) if val_holder.val is None: val_holder.val = self.comm_device.allocate(shape, dtype) else: @@ -275,15 +312,22 @@ def update_attention_mask(self, b, t, i, k): mask.val = mask.val.device.extend_attention_mask(mask.val, [True]) return - gpu_batch_size = self.policy.gpu_batch_size + if is_xpu_available(): + gpu_batch_size = self.policy.xpu_batch_size + else: + gpu_batch_size = self.policy.gpu_batch_size num_gpu_batches = self.num_gpu_batches num_inner_iterations = self.num_inner_iterations left = ((b * num_inner_iterations + t) * num_gpu_batches + k) * gpu_batch_size right = left + gpu_batch_size input_ids = self.output_ids[left:right, :self.task.prompt_len] - attention_compute = (self.env.cpu if self.policy.cpu_cache_compute - else self.env.gpu) + if is_xpu_available(): + attention_compute = (self.env.cpu if self.policy.cpu_cache_compute + else self.env.xpu) + else: + attention_compute = (self.env.cpu if self.policy.cpu_cache_compute + else self.env.gpu) val = attention_compute.allocate( (self.policy.gpu_batch_size, self.task.prompt_len), bool) val.load_from_np((input_ids != self.config.pad_token_id)) @@ -311,7 +355,10 @@ def generate(self, num_pipeline_stages = self.num_pipeline_stages num_layers = self.num_layers num_gpu_batches = self.num_gpu_batches - gpu_batch_size = self.policy.gpu_batch_size + if is_xpu_available(): + gpu_batch_size = self.policy.xpu_batch_size + else: + gpu_batch_size = self.policy.gpu_batch_size overlap = self.policy.overlap num_prompts = len(task.inputs) num_inner_iterations = self.num_inner_iterations @@ -549,16 +596,25 @@ def run_flexgen_dist(args): gpu = TorchDevice(f"cuda:{args.local_rank}") cpu = TorchDevice("cpu") + if args.ipex and is_xpu_available(): + xpu = TorchDevice("xpu:0") disk = TorchDisk(args.offload_dir, None, args.local_rank) - env = ExecutionEnv(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk])) + env = ExecutionEnv(gpu=gpu, xpu=xpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, xpu, cpu, disk])) TorchTensor.name_count = count(start=args.rank, step=args.world_size) + if args.ipex and is_xpu_available(): + comm_test(xpu.dev if args.comm_device == "xpu" else cpu.dev) + else: + comm_test(gpu.dev if args.comm_device == "gpu" else cpu.dev) - comm_test(gpu.dev if args.comm_device == "gpu" else cpu.dev) - + if args.ipex and is_xpu_available(): + args.xpu_batch_size, args.num_xpu_batches = args.gpu_batch_size, args.num_xpu_batches policy = Policy(args.gpu_batch_size, args.num_gpu_batches, + args.xpu_batch_size, args.num_xpu_batches, args.percent[0], args.percent[1], args.percent[2], args.percent[3], args.percent[4], args.percent[5], + args.percent[6], args.percent[7], + args.percent[8], args.overlap, args.sep_layer, args.pin_weight, args.cpu_cache_compute, args.attn_sparsity, args.compress_weight, @@ -611,6 +667,8 @@ def run_flexgen_dist(args): total_latency = prefill_latency + decode_latency total_throughput = num_generated_tokens / total_latency _, gpu_peak_mem = gpu.mem_stats() + if args.ipex and is_xpu_available(): + _, xpu_peak_mem = xpu.mem_stats() _, cpu_peak_mem = cpu.mem_stats() if DUMMY_WEIGHT not in args.path: @@ -622,13 +680,15 @@ def run_flexgen_dist(args): print(show_str) gpu.print_stats() + if args.ipex and is_xpu_available(): + xpu.print_stats() cpu.print_stats() projected = args.debug_mode or cut_gen_len log_str = (f"model size: {opt_config.model_bytes()/GB:.3f} GB\t" f"cache size: {cache_size/GB:.3f} GB\t" f"hidden size 
(prefill): {hidden_size/GB:.3f} GB\n" - f"peak gpu mem: {gpu_peak_mem / GB:.3f} GB\n" + f"peak gpu mem: {xpu_peak_mem if (args.ipex and is_xpu_available()) else gpu_peak_mem / GB:.3f} GB\n" f"prefill latency: {prefill_latency:.2f} s\t" f"prefill throughput: {prefill_throughput:.2f} token/s\n" f"decode latency: {decode_latency:.2f} s\t" @@ -656,12 +716,15 @@ def add_distributed_parser_arguments(parser): parser.add_argument('--use-mpi', action='store_true', default=False, help="Get distributed info from MPI") parser.add_argument('--comm-device', type=str, default='gpu', - choices=['gpu', 'cpu'], - help='communication through gpu nvlink or cpu memory ' + choices=['gpu', 'cpu', 'xpu'], + help='communication through gpu nvlink ,xpu pcie or cpu memory ' 'and socket') parser.add_argument('--num-inner-iterations', metavar='I', type=int, default=None) parser.add_argument('--async-comm', action='store_true', default=False, help="Use asynchronous communication") + parser.add_argument("--ipex", action="store_true", + help="Whether to use xpu runtime on Intel GPU.") + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -681,7 +744,7 @@ def add_distributed_parser_arguments(parser): args.rank = 0 args.local_rank = 0 - assert len(args.percent) == 6 + assert len(args.percent) == 9 try: run_flexgen_dist(args) diff --git a/flexgen/dist_utils.py b/flexgen/dist_utils.py index df83d0ac..90d2dd36 100644 --- a/flexgen/dist_utils.py +++ b/flexgen/dist_utils.py @@ -1,5 +1,6 @@ import torch import torch.distributed as dist +from flexgen.xpu_utils import is_xpu_available, is_ccl_available _COMM_DEVICE = None _PIPELINE_PARALLEL_PRED_GROUP = None @@ -11,7 +12,10 @@ def initialize_distributed(head_ip, port, world_size, rank, local_rank, f'world_size={world_size}, rank={rank}, local_rank={local_rank}.') # Initialize distributed environment - torch.cuda.set_device(local_rank) + if is_xpu_available() and is_ccl_available(): + torch.xpu.set_device(local_rank) + else: + torch.cuda.set_device(local_rank) distributed_init_method = f'tcp://{head_ip}:{port}' global _COMM_DEVICE _COMM_DEVICE = comm_device @@ -19,6 +23,8 @@ def initialize_distributed(head_ip, port, world_size, rank, local_rank, backend = 'gloo' elif comm_device == 'gpu': backend = 'nccl' + elif comm_device == 'xpu' and is_ccl_available() and is_xpu_available(): + backend = 'ccl' else: raise ValueError(f'Unknown comm_device: {comm_device}') dist.init_process_group(backend=backend, diff --git a/flexgen/flex_opt.py b/flexgen/flex_opt.py index 9660934f..e7f1b6ee 100644 --- a/flexgen/flex_opt.py +++ b/flexgen/flex_opt.py @@ -1413,6 +1413,6 @@ def add_parser_arguments(parser): add_parser_arguments(parser) args = parser.parse_args() - assert len(args.percent) == 6 + assert len(args.percent) == 9 run_flexgen(args) From c04fbe30337b13a02c6453c03882aa786a9f4ba0 Mon Sep 17 00:00:00 2001 From: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com> Date: Wed, 27 Sep 2023 16:57:27 +0530 Subject: [PATCH 4/4] bug fix typo --- flexgen/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flexgen/utils.py b/flexgen/utils.py index 74051ffe..d09361f8 100644 --- a/flexgen/utils.py +++ b/flexgen/utils.py @@ -150,7 +150,7 @@ def cpu_mem_stats(): def torch_mem_stats(): objects = gc.get_objects() if is_xpu_available(): - tensors = [obj for obj in objects if torch.is_tensor(obj) and obj.is_xpu()]x + tensors = [obj for obj in objects if torch.is_tensor(obj) and obj.is_xpu()] else: tensors = [obj for obj in objects if torch.is_tensor(obj) and 
obj.is_cuda]
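
Note: the recurring pattern across these patches is a runtime switch on is_xpu_available() that routes work to torch.xpu (provided by intel_extension_for_pytorch) instead of torch.cuda. The sketch below condenses that device and dtype selection into a standalone snippet for reference only; pick_device_and_dtype is an illustrative helper, not part of the patch series, and the dtype pairing simply mirrors bench_matmul in PATCH 1/4 (bfloat16 on XPU, float16 on CUDA, float32 on CPU).

import torch

from flexgen.xpu_utils import is_xpu_available


def pick_device_and_dtype():
    # Prefer an Intel XPU when intel_extension_for_pytorch is installed and
    # torch.xpu reports a usable device, then fall back to CUDA, then CPU.
    # The dtype choice follows the bench_matmul change in PATCH 1/4.
    if is_xpu_available():
        return torch.device("xpu:0"), torch.bfloat16
    if torch.cuda.is_available():
        return torch.device("cuda:0"), torch.float16
    return torch.device("cpu"), torch.float32


if __name__ == "__main__":
    device, dtype = pick_device_and_dtype()
    a = torch.randn((1024, 1024), dtype=dtype, device=device)
    b = a @ a  # the matmul runs on whichever backend was selected
    print(device, dtype, tuple(b.shape))

With these patches applied, flex_opt.py and dist_flex_opt.py also accept an --ipex flag and expect nine --percent values: the original six GPU/CPU percentages for weights, attention cache, and activations, followed by the XPU percentages for weights, attention cache, and activations.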