Skip to content

Commit

Permalink
chatglm: Add pretrain example and test
Browse files Browse the repository at this point in the history
  • Loading branch information
mengker33 committed Dec 2, 2024
1 parent 1904b2b commit 0b9f898
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 4 deletions.
28 changes: 28 additions & 0 deletions examples/language-modeling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,34 @@ python ../gaudi_spawn.py \
This example has been validated with the following DeepSpeed ZeRO-2 config: https://github.com/huggingface/optimum-habana/blob/main/tests/configs/deepspeed_zero_2.json


### Multi-card Training with Deepspeed (glm-4-9b-chat)
```bash
python ../gaudi_spawn.py \
--world_size 8 --use_deepspeed run_clm.py \
--config_name THUDM/glm-4-9b-chat \
--tokenizer_name THUDM/glm-4-9b-chat \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--do_train \
--do_eval \
--deepspeed llama2_ds_zero3_config.json \
--output_dir /tmp/test-clm \
--gaudi_config_name Habana/gpt2 \
--use_habana \
--use_lazy_mode \
--throughput_warmup_steps 3 \
--bf16 \
--block_size 1024 \
--use_cache False \
--overwrite_output_dir \
--logging_first_step True \
--logging_steps 20
```
Note that if pretrain chatglm2-6b and chatglm3-6b, we need to set "GLM" env variable to ensure the corresponding tokenizer is registered and loaded, i.e., GLM=3 for chatglm3 tokenizer, and GLM=2 for chatglm2; If GLM is not set, glm4 tokenizer will be registered by default.


## Multi-Node Training with Deepspeed (GPT-NeoX)

The following command triggers the fine-tuning of [GPT-NeoX-20B](https://huggingface.co/EleutherAI/gpt-neox-20b) on WikiText-2 with Deepspeed ZeRO-2.
Expand Down
8 changes: 6 additions & 2 deletions examples/language-modeling/run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,10 @@ def main():
config.update_from_string(model_args.config_overrides)
logger.info(f"New config: {config}")

# Note that chatglm2/3 has float16 dtype from config.json, and on Gaudi we need to use bfloat16.
if config.model_type == "chatglm":
config.dtype = "torch.bfloat16"

tokenizer_kwargs = {
"cache_dir": model_args.cache_dir,
"use_fast": model_args.use_fast_tokenizer,
Expand Down Expand Up @@ -472,8 +476,8 @@ def main():

# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
# We need to skip this test for baichuan pretrain
if config.model_type not in ("baichuan"):
# We need to skip this test for baichuan and chatglm pretrain
if config.model_type not in ("baichuan", "chatglm"):
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer))
Expand Down
6 changes: 5 additions & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _get_supported_models_for_script(

def is_valid_model_type(model_type: str) -> bool:
true_model_type = "llama" if model_type == "llama_guard" else model_type
if model_type == "protst":
if model_type in ("protst", "chatglm"):
in_task_mapping = True
else:
# llama_guard is not a model type in Transformers so CONFIG_MAPPING wouldn't find it
Expand Down Expand Up @@ -234,6 +234,7 @@ def to_test(
"codellama/CodeLlama-13b-Instruct-hf",
"MIT/ast-finetuned-speech-commands-v2",
"meta-llama/LlamaGuard-7b",
"THUDM/glm-4-9b-chat",
]

case_only_in_gaudi2 = [
Expand Down Expand Up @@ -310,6 +311,8 @@ def to_test(
return True
elif "ast-finetuned-speech-commands-v2" in model_name and IS_GAUDI2:
return True
elif "glm-4-9b-chat" in model_name and IS_GAUDI2 and deepspeed:
return True

return False

Expand Down Expand Up @@ -348,6 +351,7 @@ def __new__(
attrs[f"test_{example_name}_{model_name.split('/')[-1]}_{distribution}"] = cls._create_test(
model_name, gaudi_config_name, multi_card, deepspeed, fsdp, torch_compile, fp8
)

attrs["EXAMPLE_NAME"] = example_name
return super().__new__(cls, name, bases, attrs)

Expand Down
3 changes: 2 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"qwen2": [("Qwen/Qwen2-7B", "Habana/qwen"), ("Qwen/Qwen2-72B", "Habana/qwen")],
"idefics2": [("HuggingFaceM4/idefics2-8b", "Habana/gpt2")],
"mllama": [("meta-llama/Llama-3.2-11B-Vision-Instruct", "Habana/gpt2")],
"chatglm": [("THUDM/chatglm3-6b", "Habana/gpt2")],
}

MODELS_TO_TEST_FOR_QUESTION_ANSWERING = [
Expand All @@ -81,7 +82,7 @@
# "distilbert",
]

MODELS_TO_TEST_FOR_CAUSAL_LANGUAGE_MODELING = ["gpt2", "gpt_neox", "bloom", "code_llama"]
MODELS_TO_TEST_FOR_CAUSAL_LANGUAGE_MODELING = ["gpt2", "gpt_neox", "bloom", "code_llama", "chatglm"]

MODELS_TO_TEST_FOR_SEQ2SEQ = ["t5"]

Expand Down

0 comments on commit 0b9f898

Please sign in to comment.