
chatglm: Add pretrain example and test
mengker33 committed Dec 4, 2024
1 parent c94df1b commit cb23dfe
Showing 4 changed files with 40 additions and 4 deletions.
27 changes: 27 additions & 0 deletions examples/language-modeling/README.md
@@ -131,6 +131,33 @@ python ../gaudi_spawn.py \
 This example has been validated with the following DeepSpeed ZeRO-2 config: https://github.com/huggingface/optimum-habana/blob/main/tests/configs/deepspeed_zero_2.json
 
 
+### Multi-card Training with Deepspeed (chatglm3-6b)
+```bash
+python ../gaudi_spawn.py \
+    --world_size 8 --use_deepspeed run_clm.py \
+    --config_name THUDM/chatglm3-6b \
+    --tokenizer_name THUDM/chatglm3-6b \
+    --dataset_name wikitext \
+    --dataset_config_name wikitext-2-raw-v1 \
+    --per_device_train_batch_size 6 \
+    --per_device_eval_batch_size 4 \
+    --do_train \
+    --do_eval \
+    --deepspeed llama2_ds_zero3_config.json \
+    --output_dir /tmp/test-clm \
+    --gaudi_config_name Habana/gpt2 \
+    --use_habana \
+    --use_lazy_mode \
+    --throughput_warmup_steps 3 \
+    --bf16 \
+    --block_size 1024 \
+    --use_cache False \
+    --overwrite_output_dir \
+    --logging_first_step True \
+    --logging_steps 20
+```
+
+
 ## Multi-Node Training with Deepspeed (GPT-NeoX)
 
 The following command triggers the fine-tuning of [GPT-NeoX-20B](https://huggingface.co/EleutherAI/gpt-neox-20b) on WikiText-2 with Deepspeed ZeRO-2.
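Note: the command above reuses `llama2_ds_zero3_config.json` from the same examples directory rather than introducing a chatglm-specific config. As a rough sketch only (the key names below are standard DeepSpeed options, but the values and output path are assumptions, not the real file's contents), a minimal ZeRO-3 config of that general shape can be generated like this:

```python
# Sketch: write a minimal DeepSpeed ZeRO-3 config of the shape the command
# expects. The real llama2_ds_zero3_config.json may set more/different fields.
import json

zero3_config = {
    "steps_per_print": 64,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": False,
        "contiguous_gradients": True,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
}

with open("my_ds_zero3_config.json", "w") as f:  # hypothetical output path
    json.dump(zero3_config, f, indent=4)
```

The `"auto"` values defer batch sizing to the Trainer's DeepSpeed integration, so the JSON stays valid across launch configurations.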
8 changes: 6 additions & 2 deletions examples/language-modeling/run_clm.py
@@ -431,6 +431,10 @@ def main():
         config.update_from_string(model_args.config_overrides)
         logger.info(f"New config: {config}")
 
+    # Note that chatglm2/3 have float16 dtype in config.json, and on Gaudi we need to use bfloat16.
+    if config.model_type == "chatglm":
+        config.dtype = "torch.bfloat16"
+
     tokenizer_kwargs = {
         "cache_dir": model_args.cache_dir,
         "use_fast": model_args.use_fast_tokenizer,
@@ -472,8 +476,8 @@ def main():
 
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
-    # We need to skip this test for baichuan pretrain
-    if config.model_type not in ("baichuan"):
+    # We need to skip this test for baichuan and chatglm pretrain
+    if config.model_type not in ("baichuan", "chatglm"):
         embedding_size = model.get_input_embeddings().weight.shape[0]
         if len(tokenizer) > embedding_size:
             model.resize_token_embeddings(len(tokenizer))
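Incidentally, the removed `("baichuan")` was a plain string, not a tuple, so the old check did substring matching; the new two-element tuple fixes that as a side effect. As a free-standing restatement of the guarded resize (hypothetical helper name; the real logic is inline in run_clm.py, and the diff states only that the skip is needed for these models' pretrain setups):

```python
# Hypothetical helper mirroring the guarded resize above.
def maybe_resize_embeddings(model, tokenizer, model_type: str) -> None:
    if model_type in ("baichuan", "chatglm"):
        return  # skipped outright for these pretrain examples, as in the diff
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))
```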
6 changes: 5 additions & 1 deletion tests/test_examples.py
@@ -81,7 +81,7 @@ def _get_supported_models_for_script(
 
     def is_valid_model_type(model_type: str) -> bool:
         true_model_type = "llama" if model_type == "llama_guard" else model_type
-        if model_type == "protst":
+        if model_type in ("protst", "chatglm"):
             in_task_mapping = True
         else:
             # llama_guard is not a model type in Transformers so CONFIG_MAPPING wouldn't find it
@@ -241,6 +241,7 @@ def to_test(
         "codellama/CodeLlama-13b-Instruct-hf",
         "MIT/ast-finetuned-speech-commands-v2",
         "meta-llama/LlamaGuard-7b",
+        "THUDM/chatglm3-6b",
     ]
 
     case_only_in_gaudi2 = [
@@ -326,6 +327,8 @@ def to_test(
         return True
     elif "gemma" in model_name and IS_GAUDI2:
         return True
+    elif "chatglm3" in model_name and IS_GAUDI2 and deepspeed:
+        return True
 
     return False
 
@@ -365,6 +368,7 @@ def __new__(
             attrs[f"test_{example_name}_{model_name.split('/')[-1]}_{distribution}"] = cls._create_test(
                 model_name, gaudi_config_name, multi_card, deepspeed, fsdp, torch_compile, fp8
             )
+
         attrs["EXAMPLE_NAME"] = example_name
         return super().__new__(cls, name, bases, attrs)
 
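Taken together, the test-side changes gate the new model to one configuration. A condensed, hypothetical restatement of that gating (the real condition lives inline in `to_test()` in tests/test_examples.py):

```python
# Hypothetical helper restating the gating added above.
def should_test_chatglm(model_name: str, is_gaudi2: bool, deepspeed: bool) -> bool:
    # chatglm3 is exercised only on Gaudi2 and only under the DeepSpeed
    # launcher, matching the README example added in this commit.
    return "chatglm3" in model_name and is_gaudi2 and deepspeed
```

Given the naming scheme visible in `__new__`, the generated case should be selectable with something like `python -m pytest tests/test_examples.py -k chatglm3` (the exact test id depends on the distribution label).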
3 changes: 2 additions & 1 deletion tests/utils.py
@@ -64,6 +64,7 @@
     "idefics2": [("HuggingFaceM4/idefics2-8b", "Habana/gpt2")],
     "mllama": [("meta-llama/Llama-3.2-11B-Vision-Instruct", "Habana/gpt2")],
     "gemma": [("google/gemma-2b-it", "Habana/gpt2")],
+    "chatglm": [("THUDM/chatglm3-6b", "Habana/gpt2")],
 }
 
 MODELS_TO_TEST_FOR_QUESTION_ANSWERING = [
@@ -82,7 +83,7 @@
     # "distilbert",
 ]
 
-MODELS_TO_TEST_FOR_CAUSAL_LANGUAGE_MODELING = ["gpt2", "gpt_neox", "bloom", "code_llama", "gemma"]
+MODELS_TO_TEST_FOR_CAUSAL_LANGUAGE_MODELING = ["gpt2", "gpt_neox", "bloom", "code_llama", "gemma", "chatglm"]
 
 MODELS_TO_TEST_FOR_SEQ2SEQ = ["t5"]
 
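To show how these registry entries feed test generation, here is an illustrative snippet: `MODELS_TO_TEST_MAPPING` and the f-string pattern come from the diffs above, while the loop and the `"deepspeed"` distribution label are assumptions.

```python
# Illustrative only: how the new registry entry becomes a generated test name.
MODELS_TO_TEST_MAPPING = {"chatglm": [("THUDM/chatglm3-6b", "Habana/gpt2")]}
example_name, distribution = "run_clm", "deepspeed"  # assumed labels

for model_name, gaudi_config_name in MODELS_TO_TEST_MAPPING["chatglm"]:
    print(f"test_{example_name}_{model_name.split('/')[-1]}_{distribution}")
    # -> test_run_clm_chatglm3-6b_deepspeed
```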
