Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add chatglm #1478

Merged
merged 3 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ The following model architectures, tasks and device distributions have been vali
| MiniCPM3 | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Baichuan2 | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| DeepSeek-V2 | | :heavy_check_mark: | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| ChatGLM | <div style="text-align:left"><li>DeepSpeed</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
</div>

- Diffusers:
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| MiniCPM3 | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Baichuan2 | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| DeepSeek-V2 | | ✅ | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| ChatGLM | <div style="text-align:left"><li>DeepSpeed</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |

- Diffusers

Expand Down
27 changes: 27 additions & 0 deletions examples/language-modeling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,33 @@ python ../gaudi_spawn.py \
This example has been validated with the following DeepSpeed ZeRO-2 config: https://github.com/huggingface/optimum-habana/blob/main/tests/configs/deepspeed_zero_2.json


### Multi-card Training with Deepspeed (chatglm3-6b)
```bash
python ../gaudi_spawn.py \
--world_size 8 --use_deepspeed run_clm.py \
--config_name THUDM/chatglm3-6b \
--tokenizer_name THUDM/chatglm3-6b \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--per_device_train_batch_size 6 \
--per_device_eval_batch_size 4 \
--do_train \
--do_eval \
--deepspeed llama2_ds_zero3_config.json \
--output_dir /tmp/test-clm \
--gaudi_config_name Habana/gpt2 \
--use_habana \
--use_lazy_mode \
--throughput_warmup_steps 3 \
--bf16 \
--block_size 1024 \
--use_cache False \
--overwrite_output_dir \
--logging_first_step True \
--logging_steps 20
```


## Multi-Node Training with Deepspeed (GPT-NeoX)

The following command triggers the fine-tuning of [GPT-NeoX-20B](https://huggingface.co/EleutherAI/gpt-neox-20b) on WikiText-2 with Deepspeed ZeRO-2.
Expand Down
8 changes: 6 additions & 2 deletions examples/language-modeling/run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,10 @@ def main():
config.update_from_string(model_args.config_overrides)
logger.info(f"New config: {config}")

# Note that chatglm2/3 has float16 dtype from config.json, and on Gaudi we need to use bfloat16.
if config.model_type == "chatglm":
config.dtype = "torch.bfloat16"

tokenizer_kwargs = {
"cache_dir": model_args.cache_dir,
"use_fast": model_args.use_fast_tokenizer,
Expand Down Expand Up @@ -472,8 +476,8 @@ def main():

# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
# We need to skip this test for baichuan pretrain
if config.model_type not in ("baichuan"):
# We need to skip this test for baichuan and chatglm pretrain
if config.model_type not in ("baichuan", "chatglm"):
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer))
Expand Down
4 changes: 3 additions & 1 deletion optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
"minicpm3",
"baichuan",
"deepseek_v2",
"chatglm",
]

# Initial generated token index is set to 1 to accomodate SOS (start of string) token.
Expand Down Expand Up @@ -1087,8 +1088,9 @@ def generate(
"gemma",
"gemma2",
"baichuan",
"chatglm",
]
), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2, starcoder2 and baichuan at the moment"
), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2, starcoder2, baichuan and chatglm at the moment"
if not generation_config.bucket_internal:
assert (
generation_config.bucket_size <= 0
Expand Down
12 changes: 12 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
BaichuanConfig,
BaichuanForCausalLM,
BaichuanTokenizer,
ChatGLMConfig,
ChatGLMForConditionalGeneration,
ChatGLMForSequenceClassification,
ChatGLMTokenizer,
DeciLMConfig,
DeciLMForCausalLM,
DeepseekTokenizerFast,
Expand Down Expand Up @@ -719,3 +723,11 @@ def adapt_transformers_to_gaudi():
transformers.AutoConfig.register("baichuan", BaichuanConfig)
transformers.AutoTokenizer.register(BaichuanConfig, slow_tokenizer_class=BaichuanTokenizer)
transformers.AutoModelForCausalLM.register(BaichuanConfig, BaichuanForCausalLM)

# Register chatglm with optimization on Gaudi
transformers.AutoConfig.register("chatglm", ChatGLMConfig)
transformers.AutoTokenizer.register(ChatGLMConfig, ChatGLMTokenizer)
transformers.AutoModel.register(ChatGLMConfig, ChatGLMForConditionalGeneration)
transformers.AutoModelForCausalLM.register(ChatGLMConfig, ChatGLMForConditionalGeneration)
transformers.AutoModelForSeq2SeqLM.register(ChatGLMConfig, ChatGLMForConditionalGeneration)
transformers.AutoModelForSequenceClassification.register(ChatGLMConfig, ChatGLMForSequenceClassification)
6 changes: 6 additions & 0 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@
gaudi_bloom_convert_to_standard_cache,
gaudi_bloom_model_forward,
)
from .chatglm import (
ChatGLMConfig,
ChatGLMForConditionalGeneration,
ChatGLMForSequenceClassification,
ChatGLMTokenizer,
)
from .clip import (
GaudiCLIPAttention,
GaudiCLIPEncoder,
Expand Down
6 changes: 6 additions & 0 deletions optimum/habana/transformers/models/chatglm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .configuration_chatglm import ChatGLMConfig
from .modeling_chatglm import (
ChatGLMForConditionalGeneration,
ChatGLMForSequenceClassification,
)
from .tokenization_chatglm import ChatGLMTokenizer
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Copyright (C) 2022-2024 Habana Labs, Ltd. an Intel Company
###############################################################################

"""
Adapted from the following sources:
https://huggingface.co/THUDM/chatglm2-6b/blob/main/configuration_chatglm.py
https://huggingface.co/THUDM/chatglm3-6b/blob/main/configuration_chatglm.py
"""

from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
model_type = "chatglm"

def __init__(
self,
_name_or_path=None,
num_layers=28,
padded_vocab_size=65024,
hidden_size=4096,
ffn_hidden_size=13696,
kv_channels=128,
num_attention_heads=32,
seq_length=2048,
hidden_dropout=0.0,
classifier_dropout=None,
attention_dropout=0.0,
layernorm_epsilon=1e-5,
rmsnorm=True,
apply_residual_connection_post_layernorm=False,
post_layer_norm=True,
add_bias_linear=False,
add_qkv_bias=False,
bias_dropout_fusion=True,
multi_query_attention=False,
multi_query_group_num=1,
rope_ratio=1,
apply_query_key_layer_scaling=True,
attention_softmax_in_fp32=True,
fp32_residual_connection=False,
pre_seq_len=None,
prefix_projection=False,
**kwargs,
):
self.name_or_path = _name_or_path
self.num_layers = num_layers
self.vocab_size = padded_vocab_size
self.padded_vocab_size = padded_vocab_size
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.kv_channels = kv_channels
self.num_attention_heads = num_attention_heads
self.seq_length = seq_length
self.hidden_dropout = hidden_dropout
self.classifier_dropout = classifier_dropout
self.attention_dropout = attention_dropout
self.layernorm_epsilon = layernorm_epsilon
self.rmsnorm = rmsnorm
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.post_layer_norm = post_layer_norm
self.add_bias_linear = add_bias_linear
self.add_qkv_bias = add_qkv_bias
self.bias_dropout_fusion = bias_dropout_fusion
self.multi_query_attention = multi_query_attention
self.multi_query_group_num = multi_query_group_num
self.rope_ratio = rope_ratio
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.fp32_residual_connection = fp32_residual_connection
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
super().__init__(**kwargs)
Loading