Question about adding / training Mixtral #43

Open
chrismrutherford opened this issue Apr 3, 2024 · 1 comment

I followed your 'adding a new model' guide to add Mixtral. It appears that transformers' Mixtral implementation does not have a MixtralMLP class, even though the guide suggests it should. The other items import OK. As a workaround, I added MistralMLP to mlp_policy_fn in place of MixtralMLP.

The model now begins to train. Previously, without these changes, there was an OOM error just before training started, so something has worked. What is the effect of using MistralMLP in place of MixtralMLP? Am I just training garbage, or is it likely to produce something useful?
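For context, an untested alternative I considered (assuming the installed modeling_mixtral exposes MixtralSparseMoeBlock and MixtralBlockSparseTop2MLP, which the other successful imports suggest) would be to wrap Mixtral's actual feed-forward classes rather than reusing MistralMLP:

# Untested sketch, not the repo's official fix: transformers has no MixtralMLP.
# The feed-forward path in modeling_mixtral is MixtralSparseMoeBlock (routing
# gate + experts); each expert is a MixtralBlockSparseTop2MLP.
from transformers.models.llama.modeling_llama import LlamaMLP
from transformers.models.mistral.modeling_mistral import MistralMLP
from transformers.models.mixtral.modeling_mixtral import (
    MixtralSparseMoeBlock,      # gate + the experts of one decoder layer
    MixtralBlockSparseTop2MLP,  # finer-grained option: a single expert's MLP
)

def mlp_policy_fn(module):
    # Treat Mixtral's sparse MoE block the way Llama/Mistral treat their dense MLP.
    return isinstance(module, (LlamaMLP, MistralMLP, MixtralSparseMoeBlock))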

Background info:

Cannot import MixtralMLP

>>> 
>>> from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MIXTRAL_ATTENTION_CLASSES, MixtralMLP
Traceback (most recent call last):
ImportError: cannot import name 'MixtralMLP' from 'transformers.models.mixtral.modeling_mixtral'
>>> 
>>> from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MIXTRAL_ATTENTION_CLASSES
>>> 
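
To see which classes the installed modeling_mixtral actually defines (rather than guessing names), this quick check works in the same environment:

# List the classes defined in the installed modeling_mixtral module,
# to see what exists in place of the non-existent MixtralMLP.
import inspect
from transformers.models.mixtral import modeling_mixtral

print([
    name
    for name, obj in inspect.getmembers(modeling_mixtral, inspect.isclass)
    if obj.__module__ == modeling_mixtral.__name__
])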

With the Mixtral mod:

python train.py --model_name "/home/chris/repos/Mixtral-8x7B-Instruct-v0.1/" --batch_size 2 --context_length 512 --precision bf16 --train_type qlora --use_gradient_checkpointing true --use_cpu_offload false --dataset alpaca --reentrant_checkpointing true
World size: 4
Creating model 0
Loading model 0
Loading & Quantizing Model Shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [01:52<00:00,  5.95s/it]
Rank 0: Model created: 0.752 GiB
trainable params: 37,748,736 || all params: 46,740,541,440 || trainable%: 0.08076229935944876
Wrapping model w/ FSDP 0
Rank 0: Wrapped model: 9.803 GiB
Applying activation checkpointing 0
Total Training Steps: 6470
Epoch 0, Loss 1.045, LR 1.00e-05:   0%|▏ 

Without the Mixtral mod:

python train.py --model_name "/home/chris/repos/Mixtral-8x7B-Instruct-v0.1/" --batch_size 2 --context_length 512 --precision bf16 --train_type qlora --use_gradient_checkpointing true --use_cpu_offload false --dataset alpaca --reentrant_checkpointing true
World size: 4
Creating model 0
Loading model 0
Loading & Quantizing Model Shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [06:14<00:00, 19.69s/it]
Rank 0: Model created: 0.752 GiB
trainable params: 37,748,736 || all params: 46,740,541,440 || trainable%: 0.08076229935944876
Wrapping model w/ FSDP 0
Traceback (most recent call last):
<etc>
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 28.00 MiB. GPU 2 has a total capacity of 23.69 GiB of which 26.81 MiB is free. Including non-PyTorch memory, this process has 23.66 GiB memory in use. Of the allocated memory 23.22 GiB is allocated by PyTorch, and 47.22 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
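
The message above suggests PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True. That only mitigates fragmentation and would not address the missing Mixtral wrapping, but for anyone who wants to try it, the variable has to be set before the first CUDA allocation, e.g. (untested sketch):

# Untested: set the allocator option before the first CUDA allocation,
# e.g. near the top of train.py (or export it in the shell before launching).
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch  # imported after setting the env var so the allocator sees it on first use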

The mod:

diff --git a/train.py b/train.py
index 9181dc8..ca4809d 100644
--- a/train.py
+++ b/train.py
@@ -68,6 +68,7 @@ except ImportError:
 # for the wrapping policy and `check_fn` in activation checkpointing
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LLAMA_ATTENTION_CLASSES, LlamaMLP
 from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MISTRAL_ATTENTION_CLASSES, MistralMLP
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MIXTRAL_ATTENTION_CLASSES
 
 # To get rid of tokenizers warnings for now
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -429,18 +430,18 @@ def get_wrapping_policy(custom_policy:bool=False):
             )
     def self_attn_policy_fn(module):
         # Check module name is self_attn.
-        return isinstance(module, tuple(*LLAMA_ATTENTION_CLASSES.values(), *MISTRAL_ATTENTION_CLASSES.values()))
+        return isinstance(module, tuple(*LLAMA_ATTENTION_CLASSES.values(), *MISTRAL_ATTENTION_CLASSES.values(), *MIXTRAL_ATTENTION_CLASSES.values()))
 
     def mlp_policy_fn(module):
         # Check module name is self_attn.
-        return isinstance(module, (LlamaMLP, MistralMLP))
+        return isinstance(module, (LlamaMLP, MistralMLP, MistralMLP))
 
     lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
     self_attn_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=self_attn_policy_fn)
     mlp_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=mlp_policy_fn)
     transformer_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
-        transformer_layer_cls=(LlamaDecoderLayer, MistralDecoderLayer),
+        transformer_layer_cls=(LlamaDecoderLayer, MistralDecoderLayer, MixtralDecoderLayer,),
     )
     policies=[lambda_policy, transformer_wrap_policy]
     if custom_policy:
@@ -735,7 +736,7 @@ def fsdp_main(local_rank:int, world_size:int, args:Dict):
 
         )
 
-        check_fn = lambda submodule: isinstance(submodule, (LlamaDecoderLayer, MistralDecoderLayer))
+        check_fn = lambda submodule: isinstance(submodule, (LlamaDecoderLayer, MistralDecoderLayer, MixtralDecoderLayer))
         if rank == 0 or args['verbose']:
             print("Applying activation checkpointing", rank)
         apply_activation_checkpointing(
@@ -1042,4 +1043,4 @@ def main(
     mp.spawn(fsdp_main,
         args=(world_size, args),
         nprocs=torch.cuda.device_count(),
-        join=True)
\ No newline at end of file
+        join=True)

hsb1995 commented Apr 8, 2024

[image]
Is this a code error? Why does the downloaded code use modeling_mistral and MistralMLP?
