diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md index 5d3bb908081..22a12a8313b 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md @@ -7,7 +7,8 @@ In this directory, you will find a C++ example on how to run LLM models on Intel |------------|----------------------------------------------------------------| | Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) | | Qwen2.5 | [Qwen/Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) | - +| Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) | +| Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | ## 0. Requirements To run this C++ example with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU. diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py index 2ab7d72b929..18ee5b1d4ad 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py @@ -436,7 +436,11 @@ def convert_llm_for_deploy(model: torch.nn.Module, "max_prompt_len": max_prompt_len, "layernorm_const": layernorm_const, "group_size": group_size, - "fused_layers": fused_layers} + "fused_layers": fused_layers, + "qkv_bias": True, + "use_prefill_sdp": False, + "weight_num": 7, + "weight_idx": 8} model.config.update(update_dict) model.config.save_pretrained(save_directory) @@ -453,3 +457,39 @@ def convert_llm_for_deploy(model: torch.nn.Module, # save blob of lmhead and bin of embedding convert_lm_head_and_embedding(model, n_splits_linear, save_directory, weight_dir, True) + elif model.config.model_type == "llama": + layernorm_const = True + if model.config.vocab_size == 32000: + # for Llama2-7B + fused_layers = 4 + else: + # for Llama3-8B + fused_layers = 2 + update_dict = {"kv_len": kv_len, + "num_head": model.model.layers[0].self_attn.num_heads, + "head_dim": model.model.layers[0].self_attn.head_dim, + "transpose_value_cache": transpose_value_cache, + "max_prompt_len": max_prompt_len, + "layernorm_const": layernorm_const, + "group_size": group_size, + "fused_layers": fused_layers, + "qkv_bias": False, + "use_prefill_sdp": True, + "weight_num": 7, + "weight_idx": 5} + model.config.update(update_dict) + model.config.save_pretrained(save_directory) + + from .llama import convert_llama_layer, convert_fused_llama_layer + from .llama import convert_lm_head_and_embedding + # save fused_layers blobs of fused decoder layers + convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_down_proj, + save_directory, weight_dir, transpose_value_cache, kv_len, + group_size, layernorm_const, "decode") + # save blob of single prefill layer + convert_llama_layer(model, 0, n_splits_linear, n_splits_down_proj, + save_directory, weight_dir, transpose_value_cache, max_prompt_len, + group_size, layernorm_const, "prefill") + # save blob of lmhead and bin of embedding + convert_lm_head_and_embedding(model, n_splits_linear, + save_directory, weight_dir, True) diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py index 2e62418fa6f..3fb57381ef2 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py @@ -83,7 +83,8 @@ def __init__( self.compile() -def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): +def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir, + convert_model=False): num_heads = model.model.layers[0].self_attn.num_heads num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads head_dim = model.model.layers[0].self_attn.head_dim @@ -119,7 +120,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): vocab_size=vocab_size, n_splits=n_splits_linear ) - last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir) + last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir, + True, False) # save weights bins files if n_splits_linear == 1: @@ -154,14 +156,19 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): attention_scaling=model.model.rotary_emb.attention_scaling, dtype=np.float16, ) - first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding", - temp_dir) + if convert_model: + bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin") + embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file) + first_blob_path = None + else: + first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding", + temp_dir) return first_blob_path, last_blob_path def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, temp_dir, weight_dir, transpose_value_cache, kv_len, group_size, - layernorm_const): + layernorm_const, mode="decode"): num_heads = model.model.layers[0].self_attn.num_heads num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads head_dim = model.model.layers[0].self_attn.head_dim @@ -201,8 +208,16 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, else: # FP16 Linear np_dtype = np.float16 + if mode == "decode": + input_len = 1 + decoder_name = f"decoder_layer_{layer_idx}" + else: + input_len = kv_len + decoder_name = "decoder_layer_prefill" + layernorm_const = False + single_decoder = LowBitLlamaMultiDecoderlayer( - [1, 1, num_heads * head_dim], + [1, input_len, num_heads * head_dim], input_layernorm_weights=[layer_norm_0] if layernorm_const else None, post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None, cached_cos=cached_cos, @@ -213,40 +228,136 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, max_seq_len=kv_len, rms_norm_eps=rms_norm_eps, intermediate_size=intermediate_size, - mode="decode", + mode=mode, transpose_value=transpose_value_cache, dtype=np_dtype, n_splits_linear=n_splits_linear, n_splits_down_proj=n_splits_down_proj, group_size=group_size ) + rest_blob_path = update_names_of_IR_and_export_blob(single_decoder, - f"decoder_layer_{layer_idx}", - temp_dir) + decoder_name, + temp_dir, + True, False) - if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"): - # llama-2-7B & llama-3-8B - if layernorm_const: - st_idx = 5 + if mode == "decode": + if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"): + # llama-2-7B & llama-3-8B + if layernorm_const: + st_idx = 5 + else: + input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin") + post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin") + layer_norm_0.data.numpy().tofile(input_lm_bin_file) + layer_norm_1.data.numpy().tofile(post_lm_bin_file) + st_idx = 7 else: + # llama-3.2-3B & llama-3.2-1B + if layernorm_const: + st_idx = 6 + else: + input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin") + post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_5.bin") + layer_norm_0.data.numpy().tofile(input_lm_bin_file) + layer_norm_1.data.numpy().tofile(post_lm_bin_file) + st_idx = 8 + for idx, (weight, scale) in enumerate(weights): + bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin") + weight.numpy().tofile(bin_file) + bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin") + scale.numpy().tofile(bin_file) + del single_decoder + + +def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_down_proj, + save_dir, weight_dir, transpose_value_cache, kv_len, group_size, + layernorm_const, mode="decode"): + num_heads = model.model.layers[0].self_attn.num_heads + num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads + head_dim = model.model.layers[0].self_attn.head_dim + intermediate_size = model.config.intermediate_size + rms_norm_eps = model.config.rms_norm_eps + layer_num = len(model.model.layers) + fused_layer_num = layer_num // fused_layers + + from ipex_llm.transformers.npu_models.llama_mp import LowBitLlamaMultiDecoderlayer + for i in range(fused_layers): + layer_start = i * fused_layer_num + layer_end = min((i + 1) * fused_layer_num, layer_num) + layer_weights = [] + input_layer_norm_weights = [] + post_attn_layernorm_weights = [] + layer_indexs = range(layer_start, layer_end) + n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) + n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) + for layer_idx in layer_indexs: + curr_layer = model.model.layers[layer_idx] + attn_layer = curr_layer.self_attn + mlp_layer = curr_layer.mlp + + weights = [] + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: + l_weights = [] + scales = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + + cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) + cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) + layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) + layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16) + + layer_weights.extend(weights) + input_layer_norm_weights.append(layer_norm_0) + post_attn_layernorm_weights.append(layer_norm_1) + + # save weight input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin") post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin") layer_norm_0.data.numpy().tofile(input_lm_bin_file) layer_norm_1.data.numpy().tofile(post_lm_bin_file) - st_idx = 7 - else: - # llama-3.2-3B & llama-3.2-1B - if layernorm_const: - st_idx = 6 - else: - input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin") - post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_5.bin") - layer_norm_0.data.numpy().tofile(input_lm_bin_file) - layer_norm_1.data.numpy().tofile(post_lm_bin_file) - st_idx = 8 - for idx, (weight, scale) in enumerate(weights): - bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin") - weight.numpy().tofile(bin_file) - bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin") - scale.numpy().tofile(bin_file) - del single_decoder + st_idx = 5 + # 6, 7 are past k/v + for idx, (weight, scale) in enumerate(weights): + bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin") + weight.numpy().tofile(bin_file) + bin_file = os.path.join(weight_dir, + f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin") + scale.numpy().tofile(bin_file) + + if isinstance(weights[0], tuple): + np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8 + else: # FP16 Linear + np_dtype = np.float16 + + fused_decoder = LowBitLlamaMultiDecoderlayer( + [1, 1, num_heads * head_dim], + input_layernorm_weights=input_layer_norm_weights, + post_attn_layernorm_weights=post_attn_layernorm_weights, + cached_cos=cached_cos, + cached_sin=cached_sin, + num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + num_layers=fused_layer_num, + max_seq_len=kv_len, + rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + mode=mode, + transpose_value=transpose_value_cache, + dtype=np_dtype, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size + ) + update_names_of_IR_and_export_blob(fused_decoder, + f"decoder_layer_{i}", + save_dir, + compile_blob=True, + keep_ir=False) + return 0 diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py index 61233731404..385208277f5 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py @@ -135,13 +135,9 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, if mode == "decode": input_len = 1 decoder_name = f"decoder_layer_{layer_idx}" - compile = True - keep_ir = True else: input_len = kv_len decoder_name = "decoder_layer_prefill" - compile = True - keep_ir = False single_decoder = LowBitQwenMultiDecoderlayer( [1, input_len, num_heads * head_dim], input_layernorm_weights=None, @@ -166,7 +162,7 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, ) rest_blob_path = update_names_of_IR_and_export_blob(single_decoder, decoder_name, - temp_dir, compile, keep_ir) + temp_dir, True, False) # 0, 1, 2 are input_embed/attention_mask/position_id if mode == "decode":