This repository has been archived by the owner on Dec 1, 2024. It is now read-only.

[Feature] Intel dGPU/SYCL support #125

Open · wants to merge 4 commits into main
33 changes: 25 additions & 8 deletions flexgen/compression.py
@@ -6,7 +6,7 @@
from flexgen.pytorch_backend import (TorchTensor, TorchDevice,
DeviceType, general_copy, fix_recursive_import)
from flexgen.utils import np_dtype_to_torch_dtype

from flexgen.xpu_utils import is_xpu_available

@dataclasses.dataclass
class CompressionConfig:
@@ -49,9 +49,14 @@ def allocate(self, shape, dtype, comp_config, pin_memory=None, name=None):
(data, scale, comp_config), self, name=name)

def init_cache_one_gpu_batch(self, config, task, policy):
num_head, hidden_size, prompt_len, gen_len, gpu_batch_size = (
config.n_head, config.input_dim, task.prompt_len, task.gen_len,
policy.gpu_batch_size)
if is_xpu_available():
num_head, hidden_size, prompt_len, gen_len, gpu_batch_size = (
config.n_head, config.input_dim, task.prompt_len, task.gen_len,
policy.xpu_batch_size)
else:
num_head, hidden_size, prompt_len, gen_len, gpu_batch_size = (
config.n_head, config.input_dim, task.prompt_len, task.gen_len,
policy.gpu_batch_size)
shape = (prompt_len + gen_len - 1, gpu_batch_size * num_head, hidden_size // num_head)
# NOTE: disable pin_memory due to high memory overhead
pin_memory = False
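For a sense of scale, a hedged numeric illustration of the cache shape above (all example values are assumed, not taken from this PR):

```python
# Hypothetical example: per-layer K (or V) cache allocated by
# init_cache_one_gpu_batch, with assumed model/runtime values.
prompt_len, gen_len = 512, 32
batch_size, n_head, input_dim = 8, 32, 4096

shape = (prompt_len + gen_len - 1, batch_size * n_head, input_dim // n_head)
elems = shape[0] * shape[1] * shape[2]
print(shape)                        # (543, 256, 128)
print(f"{elems * 2 / 1e6:.1f} MB")  # ~35.6 MB per layer in fp16, before 4-bit compression
```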
@@ -65,7 +70,10 @@ def init_attention_compute_workspace(self, config, task, policy):
if self.base_device.device_type != DeviceType.CPU:
return # Only CPU requires this fp32 workspace

b = policy.gpu_batch_size
if is_xpu_available():
b = policy.xpu_batch_size
else:
b = policy.gpu_batch_size
n_head = config.n_head
head_dim = config.input_dim // n_head
max_seq_len = task.prompt_len + task.gen_len - 1
@@ -334,7 +342,10 @@ def compress_and_decompress(tensor, config):

def test_simulated_compression():
torch.manual_seed(0)
a = torch.normal(0, 1, (64, 64, 64), dtype=torch.float16).cuda()
if is_xpu_available():
a = torch.normal(0, 1, (64, 64, 64), dtype=torch.float16).xpu()
else:
a = torch.normal(0, 1, (64, 64, 64), dtype=torch.float16).cuda()

config = CompressionConfig(
num_bits=4, group_size=32, group_dim=0, symmetric=False)
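A hedged note on the `.xpu()` call above: `Tensor.xpu()` is registered by Intel Extension for PyTorch (IPEX), so the test implicitly assumes that package has been imported; an equivalent spelling under the same assumption:

```python
# Assumes intel_extension_for_pytorch is installed and an XPU device is visible.
import torch
import intel_extension_for_pytorch  # noqa: F401  (registers the 'xpu' device type)

a = torch.normal(0, 1, (64, 64, 64), dtype=torch.float16).to("xpu")
```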
@@ -346,11 +357,17 @@ def test_simulated_compression():

def test_real_compression():
torch.manual_seed(0)
a = torch.normal(0, 1, (32, 1, 1), dtype=torch.float16).cuda()
if is_xpu_available():
a = torch.normal(0, 1, (32, 1, 1), dtype=torch.float16).xpu()
else:
a = torch.normal(0, 1, (32, 1, 1), dtype=torch.float16).cuda()

config = CompressionConfig(
num_bits=4, group_size=32, group_dim=0, symmetric=False)
dev = TorchDevice("cuda:0", 0, 0).compressed_device
if is_xpu_available():
dev = TorchDevice("xpu:0", 0, 0).compressed_device
else:
dev = TorchDevice("cuda:0", 0, 0).compressed_device
packed = dev.compress(a, config)
b = dev.decompress(packed)

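`flexgen/xpu_utils.py` itself is not shown in this excerpt; a minimal sketch of what the two helpers imported above might look like, assuming they probe for Intel Extension for PyTorch (IPEX) and the oneCCL bindings:

```python
# Hypothetical sketch of flexgen/xpu_utils.py (not part of the excerpt above).
import importlib.util

import torch


def is_xpu_available() -> bool:
    """True if the Intel XPU backend is importable and a device is visible."""
    if importlib.util.find_spec("intel_extension_for_pytorch") is None:
        return False
    import intel_extension_for_pytorch  # noqa: F401  (registers torch.xpu)
    return hasattr(torch, "xpu") and torch.xpu.is_available()


def is_ccl_available() -> bool:
    """True if the oneCCL distributed bindings are installed."""
    return importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
```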
117 changes: 90 additions & 27 deletions flexgen/dist_flex_opt.py
@@ -22,9 +22,12 @@
from flexgen.timer import timers
from flexgen.utils import (Task, ExecutionEnv, GB, T, ValueHolder,
array_1d, array_2d, array_3d, array_4d, str2bool, project_decode_latency)
from flexgen.xpu_utils import is_xpu_available, is_ccl_available


#os.environ["NCCL_DEBUG"] = "TRACE"
#if is_ccl_available():
#os.environ["CCL_LOG_LEVEL"] = "TRACE"


class DistOptLM(OptLM):
@@ -35,7 +38,10 @@ def __init__(self, config, env, path, policy, pipeline_rank,
self.env = env
self.path = path
self.policy = policy
self.num_gpu_batches = self.policy.num_gpu_batches
if is_xpu_available():
self.num_gpu_batches = self.policy.num_xpu_batches
else:
self.num_gpu_batches = self.policy.num_gpu_batches
self.pipeline_rank = pipeline_rank
self.num_pipeline_stages = num_pipeline_stages
self.num_inner_iterations = num_inner_iterations if num_inner_iterations is not None else num_pipeline_stages
@@ -44,9 +50,11 @@ def __init__(self, config, env, path, policy, pipeline_rank,
self.comm_device = self.env.cpu
elif comm_device == "gpu":
self.comm_device = self.env.gpu
elif comm_device == "xpu" and is_xpu_available():
self.comm_device = self.env.xpu
else:
raise ValueError(f"Invalid comm_device: {comm_device}")

layers = []
if pipeline_rank == 0:
layers.append(InputEmbed(self.config, self.env, self.policy))
@@ -69,17 +77,24 @@ def __init__(self, config, env, path, policy, pipeline_rank,

if self.policy.act_gpu_percent == 100:
self.act_home = self.env.gpu
elif self.policy.act_xpu_percent == 100 and is_xpu_available():
self.act_home = self.env.xpu
elif self.policy.act_cpu_percent == 100:
self.act_home = self.env.cpu
elif self.policy.act_disk_percent == 100:
self.act_home = self.env.disk
else:
raise NotImplementedError()

# CUDA streams
self.load_weight_stream = torch.cuda.Stream()
self.load_cache_stream = torch.cuda.Stream()
self.store_cache_stream = torch.cuda.Stream()
if is_xpu_available():
self.load_weight_stream = torch.xpu.Stream()
self.load_cache_stream = torch.xpu.Stream()
self.store_cache_stream = torch.xpu.Stream()
else:
# CUDA streams
self.load_weight_stream = torch.cuda.Stream()
self.load_cache_stream = torch.cuda.Stream()
self.store_cache_stream = torch.cuda.Stream()

self.task = None
self.init_all_weights()
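The same `is_xpu_available()` branch is repeated below for every stream and batch-size lookup; a hedged sketch of small helpers that could collapse the dispatch (helper names are hypothetical, not part of this PR):

```python
# Hypothetical convenience helpers; assumes torch.xpu is provided by IPEX.
import torch
from flexgen.xpu_utils import is_xpu_available


def new_copy_stream():
    """Create an asynchronous copy stream on the active accelerator."""
    return torch.xpu.Stream() if is_xpu_available() else torch.cuda.Stream()


def stream_ctx(stream):
    """Enqueue subsequent work on `stream`, whichever backend is active."""
    return torch.xpu.stream(stream) if is_xpu_available() else torch.cuda.stream(stream)


def accel_batch_size(policy):
    """Per-device batch size for the active accelerator."""
    return policy.xpu_batch_size if is_xpu_available() else policy.gpu_batch_size
```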
@@ -99,8 +114,12 @@ def load_weight(self, b, t, i, j, k):
return

# Load from weight_home to weight_read_buf
with torch.cuda.stream(self.load_weight_stream):
self.layers[j].load_weight(self.weight_home[j], self.weight_read_buf[j], k)
if is_xpu_available():
with torch.xpu.stream(self.load_weight_stream):
self.layers[j].load_weight(self.weight_home[j], self.weight_read_buf[j], k)
else:
with torch.cuda.stream(self.load_weight_stream):
self.layers[j].load_weight(self.weight_home[j], self.weight_read_buf[j], k)

def init_cache(self, t, j, k):
self.layers[j].init_cache_one_gpu_batch(self.cache_home[t][j][k])
@@ -120,8 +139,12 @@ def load_cache(self, t, i, j, k):
return

# Load from cache_home to cache_read_buf
with torch.cuda.stream(self.load_cache_stream):
self.layers[j].load_cache(self.cache_home[t][j][k], self.cache_read_buf[t][j][k], i)
if is_xpu_available():
with torch.xpu.stream(self.load_cache_stream):
self.layers[j].load_cache(self.cache_home[t][j][k], self.cache_read_buf[t][j][k], i)
else:
with torch.cuda.stream(self.load_cache_stream):
self.layers[j].load_cache(self.cache_home[t][j][k], self.cache_read_buf[t][j][k], i)

def store_cache(self, t, i, j, k):
# Handle corner cases
@@ -139,8 +162,12 @@ def store_cache(self, t, i, j, k):

# Store cache_write_buf to cache_home
# Delete cache_write_buf
with torch.cuda.stream(self.store_cache_stream):
self.layers[j].store_cache(self.cache_home[t][j][k], self.cache_write_buf[t][j][k], i)
if is_xpu_available():
with torch.xpu.stream(self.store_cache_stream):
self.layers[j].store_cache(self.cache_home[t][j][k], self.cache_write_buf[t][j][k], i)
else:
with torch.cuda.stream(self.store_cache_stream):
self.layers[j].store_cache(self.cache_home[t][j][k], self.cache_write_buf[t][j][k], i)

def delete_cache(self, t, j, k):
v = self.cache_home[t][j][k].pop()
@@ -175,7 +202,10 @@ def load_hidden(self, b, t, i, j, k):
# Already received the input from previous hidden states
self.hidden[t][i][j][k].val = self.hidden[t][i][j][k].val.move(dst)
return
gpu_batch_size = self.policy.gpu_batch_size
if is_xpu_available():
gpu_batch_size = self.policy.xpu_batch_size
else:
gpu_batch_size = self.policy.gpu_batch_size
num_gpu_batches = self.num_gpu_batches
num_inner_iterations = self.num_inner_iterations
left = ((b * num_inner_iterations + t) * num_gpu_batches + k) * gpu_batch_size
@@ -221,7 +251,10 @@ def store_hidden(self, b, t, i, j, k):
hidden_val = self.hidden[t][i][j][k].val

ids = hidden_val.data.detach().cpu().numpy()
gpu_batch_size = self.policy.gpu_batch_size
if is_xpu_available():
gpu_batch_size = self.policy.xpu_batch_size
else:
gpu_batch_size = self.policy.gpu_batch_size
num_gpu_batches = self.num_gpu_batches
num_inner_iterations = self.num_inner_iterations
left = ((b * num_inner_iterations + t) * num_gpu_batches + k) * gpu_batch_size
@@ -244,8 +277,12 @@ def recv_hidden(self, t, i, j, k, tag=0, async_=False):
sender_rank = (self.pipeline_rank - 1) % self.num_pipeline_stages
val_holder = self.hidden[t][i][j][k]
seq_len = self.task.prompt_len if i == 0 else 1
shape, dtype = self.layers[j].input_act_shape_and_dtype(
self.policy.gpu_batch_size, seq_len)
if is_xpu_available():
shape, dtype = self.layers[j].input_act_shape_and_dtype(
self.policy.xpu_batch_size, seq_len)
else:
shape, dtype = self.layers[j].input_act_shape_and_dtype(
self.policy.gpu_batch_size, seq_len)
if val_holder.val is None:
val_holder.val = self.comm_device.allocate(shape, dtype)
else:
@@ -275,15 +312,22 @@ def update_attention_mask(self, b, t, i, k):
mask.val = mask.val.device.extend_attention_mask(mask.val, [True])
return

gpu_batch_size = self.policy.gpu_batch_size
if is_xpu_available():
gpu_batch_size = self.policy.xpu_batch_size
else:
gpu_batch_size = self.policy.gpu_batch_size
num_gpu_batches = self.num_gpu_batches
num_inner_iterations = self.num_inner_iterations
left = ((b * num_inner_iterations + t) * num_gpu_batches + k) * gpu_batch_size
right = left + gpu_batch_size
input_ids = self.output_ids[left:right, :self.task.prompt_len]

attention_compute = (self.env.cpu if self.policy.cpu_cache_compute
else self.env.gpu)
if is_xpu_available():
attention_compute = (self.env.cpu if self.policy.cpu_cache_compute
else self.env.xpu)
else:
attention_compute = (self.env.cpu if self.policy.cpu_cache_compute
else self.env.gpu)
val = attention_compute.allocate(
(self.policy.gpu_batch_size, self.task.prompt_len), bool)
val.load_from_np((input_ids != self.config.pad_token_id))
@@ -311,7 +355,10 @@ def generate(self,
num_pipeline_stages = self.num_pipeline_stages
num_layers = self.num_layers
num_gpu_batches = self.num_gpu_batches
gpu_batch_size = self.policy.gpu_batch_size
if is_xpu_available():
gpu_batch_size = self.policy.xpu_batch_size
else:
gpu_batch_size = self.policy.gpu_batch_size
overlap = self.policy.overlap
num_prompts = len(task.inputs)
num_inner_iterations = self.num_inner_iterations
@@ -549,16 +596,25 @@ def run_flexgen_dist(args):

gpu = TorchDevice(f"cuda:{args.local_rank}")
cpu = TorchDevice("cpu")
if args.ipex and is_xpu_available():
xpu = TorchDevice("xpu:0")
disk = TorchDisk(args.offload_dir, None, args.local_rank)
env = ExecutionEnv(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))
if args.ipex and is_xpu_available():
env = ExecutionEnv(gpu=gpu, xpu=xpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, xpu, cpu, disk]))
TorchTensor.name_count = count(start=args.rank, step=args.world_size)
if args.ipex and is_xpu_available():
comm_test(xpu.dev if args.comm_device == "xpu" else cpu.dev)
else:
comm_test(gpu.dev if args.comm_device == "gpu" else cpu.dev)

comm_test(gpu.dev if args.comm_device == "gpu" else cpu.dev)

if args.ipex and is_xpu_available():
args.xpu_batch_size, args.num_xpu_batches = args.gpu_batch_size, args.num_gpu_batches
policy = Policy(args.gpu_batch_size, args.num_gpu_batches,
args.xpu_batch_size, args.num_xpu_batches,
args.percent[0], args.percent[1],
args.percent[2], args.percent[3],
args.percent[4], args.percent[5],
args.percent[6], args.percent[7],
args.percent[8],
args.overlap, args.sep_layer, args.pin_weight,
args.cpu_cache_compute, args.attn_sparsity,
args.compress_weight,
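The Policy call above and the env construction earlier assume xpu-aware fields defined elsewhere in the PR; a hedged sketch of the assumed additions (names inferred from this diff, and the meaning of the nine --percent values is an assumption):

```python
# Hypothetical sketch of fields this diff appears to rely on; not shown in the excerpt.
import dataclasses


@dataclasses.dataclass
class Policy:
    gpu_batch_size: int
    num_gpu_batches: int
    xpu_batch_size: int   # assumed XPU analogue of gpu_batch_size
    num_xpu_batches: int  # assumed XPU analogue of num_gpu_batches
    # Nine placement percentages (len(args.percent) == 9 below), presumably an
    # xpu share added to the original gpu/cpu splits for weights, cache, and
    # activations; exact ordering is an assumption.
    # ...remaining fields (overlap, sep_layer, pin_weight, ...) unchanged.


@dataclasses.dataclass
class ExecutionEnv:
    gpu: object = None
    cpu: object = None
    disk: object = None
    mixed: object = None
    xpu: object = None    # assumed new handle for the Intel GPU device
```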
@@ -611,6 +667,8 @@ def run_flexgen_dist(args):
total_latency = prefill_latency + decode_latency
total_throughput = num_generated_tokens / total_latency
_, gpu_peak_mem = gpu.mem_stats()
if args.ipex and is_xpu_available():
_, xpu_peak_mem = xpu.mem_stats()
_, cpu_peak_mem = cpu.mem_stats()

if DUMMY_WEIGHT not in args.path:
@@ -622,13 +680,15 @@ def run_flexgen_dist(args):
print(show_str)

gpu.print_stats()
if args.ipex and is_xpu_available():
xpu.print_stats()
cpu.print_stats()
projected = args.debug_mode or cut_gen_len

log_str = (f"model size: {opt_config.model_bytes()/GB:.3f} GB\t"
f"cache size: {cache_size/GB:.3f} GB\t"
f"hidden size (prefill): {hidden_size/GB:.3f} GB\n"
f"peak gpu mem: {gpu_peak_mem / GB:.3f} GB\n"
f"peak gpu mem: {xpu_peak_mem if (args.ipex and is_xpu_available()) else gpu_peak_mem / GB:.3f} GB\n"
f"prefill latency: {prefill_latency:.2f} s\t"
f"prefill throughput: {prefill_throughput:.2f} token/s\n"
f"decode latency: {decode_latency:.2f} s\t"
@@ -656,12 +716,15 @@ def add_distributed_parser_arguments(parser):
parser.add_argument('--use-mpi', action='store_true', default=False,
help="Get distributed info from MPI")
parser.add_argument('--comm-device', type=str, default='gpu',
choices=['gpu', 'cpu'],
help='communication through gpu nvlink or cpu memory '
choices=['gpu', 'cpu', 'xpu'],
help='communication through gpu nvlink, xpu pcie, or cpu memory '
'and socket')
parser.add_argument('--num-inner-iterations', metavar='I', type=int, default=None)
parser.add_argument('--async-comm', action='store_true', default=False,
help="Use asynchronous communication")
parser.add_argument("--ipex", action="store_true",
help="Whether to use xpu runtime on Intel GPU.")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -681,7 +744,7 @@ def add_distributed_parser_arguments(parser):
args.rank = 0
args.local_rank = 0

assert len(args.percent) == 6
assert len(args.percent) == 9

try:
run_flexgen_dist(args)
8 changes: 7 additions & 1 deletion flexgen/dist_utils.py
@@ -1,5 +1,6 @@
import torch
import torch.distributed as dist
from flexgen.xpu_utils import is_xpu_available, is_ccl_available

_COMM_DEVICE = None
_PIPELINE_PARALLEL_PRED_GROUP = None
@@ -11,14 +12,19 @@ def initialize_distributed(head_ip, port, world_size, rank, local_rank,
f'world_size={world_size}, rank={rank}, local_rank={local_rank}.')

# Initialize distributed environment
torch.cuda.set_device(local_rank)
if is_xpu_available() and is_ccl_available():
torch.xpu.set_device(local_rank)
else:
torch.cuda.set_device(local_rank)
distributed_init_method = f'tcp://{head_ip}:{port}'
global _COMM_DEVICE
_COMM_DEVICE = comm_device
if comm_device == 'cpu':
backend = 'gloo'
elif comm_device == 'gpu':
backend = 'nccl'
elif comm_device == 'xpu' and is_ccl_available() and is_xpu_available():
backend = 'ccl'
else:
raise ValueError(f'Unknown comm_device: {comm_device}')
dist.init_process_group(backend=backend,
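One assumption behind the 'ccl' branch above: torch.distributed only knows about the oneCCL backend after its bindings have been imported somewhere in the process, so a hedged sketch of the expected setup (the package name is the upstream oneCCL binding for PyTorch, not something added by this PR):

```python
# Hedged sketch: make the "ccl" backend visible to torch.distributed before
# initialize_distributed() selects it for comm_device == "xpu".
import torch.distributed as dist

try:
    import oneccl_bindings_for_pytorch  # noqa: F401  (registers the "ccl" backend)
    backend = "ccl"
except ImportError:
    backend = "gloo"  # fall back to CPU communication

dist.init_process_group(backend=backend,
                        init_method="tcp://127.0.0.1:29500",
                        world_size=1, rank=0)
```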