diff --git a/xtuner/model/modules/dispatch/cohere.py b/xtuner/model/modules/dispatch/cohere.py index d3529f570..8acf06747 100644 --- a/xtuner/model/modules/dispatch/cohere.py +++ b/xtuner/model/modules/dispatch/cohere.py @@ -3,6 +3,8 @@ import torch import torch.distributed as dist +import transformers +from mmengine.utils import digit_version from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb from xtuner.parallel.sequence import get_sequence_parallel_world_size @@ -18,6 +20,14 @@ class Cache: pass +TRANSFORMERS_VERSION = digit_version(transformers.__version__) +IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.43') + +if not IS_LOW_VERSION_TRANSFORMERS: + from transformers.modeling_flash_attention_utils import \ + _flash_attention_forward + + def cohere_attn_forward( self, hidden_states: torch.Tensor, @@ -110,13 +120,25 @@ def cohere_attn_forward( ori_num_head = self.num_heads self.num_heads = query_states.shape[-2] - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - query_states.shape[1], - dropout=dropout_rate) + if IS_LOW_VERSION_TRANSFORMERS: + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_states.shape[1], + dropout=dropout_rate) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_states.shape[1], + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) if enable_sequence_parallel: attn_output = post_process_for_sequence_parallel_attn(attn_output) diff --git a/xtuner/model/modules/dispatch/mistral.py b/xtuner/model/modules/dispatch/mistral.py index d08b0f00e..dc6c7fed8 100644 --- a/xtuner/model/modules/dispatch/mistral.py +++ b/xtuner/model/modules/dispatch/mistral.py @@ -6,7 +6,9 @@ import torch import torch.distributed as dist import torch.nn as nn +import transformers from mmengine import MessageHub +from mmengine.utils import digit_version from transformers.cache_utils import Cache from transformers.models.mistral.modeling_mistral import (apply_rotary_pos_emb, repeat_kv) @@ -28,6 +30,13 @@ except ImportError: pass +TRANSFORMERS_VERSION = digit_version(transformers.__version__) +IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.43') + +if not IS_LOW_VERSION_TRANSFORMERS: + from transformers.modeling_flash_attention_utils import \ + _flash_attention_forward + class MistralRotaryEmbedding(nn.Module): @@ -220,15 +229,28 @@ def mistral_attn_forward( ori_num_head = self.num_heads self.num_heads = query_states.shape[-2] - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - query_length=query_states.shape[1], - dropout=dropout_rate, - use_sliding_windows=use_sliding_windows, - ) + if IS_LOW_VERSION_TRANSFORMERS: + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_length=query_states.shape[1], + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_states.shape[1], + dropout=dropout_rate, + sliding_window=getattr(self.config, 'sliding_window', None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) if enable_sequence_parallel: attn_output = post_process_for_sequence_parallel_attn(attn_output) diff --git a/xtuner/model/modules/dispatch/phi3.py b/xtuner/model/modules/dispatch/phi3.py index 97ebc8d33..10f60f939 100644 --- a/xtuner/model/modules/dispatch/phi3.py +++ b/xtuner/model/modules/dispatch/phi3.py @@ -1,10 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. +import inspect import warnings from typing import Optional, Tuple import torch import torch.distributed as dist +import transformers from mmengine import MessageHub +from mmengine.utils import digit_version from xtuner.parallel.sequence import (get_sequence_parallel_world_size, post_process_for_sequence_parallel_attn, @@ -19,7 +22,12 @@ class Cache: pass -import inspect +TRANSFORMERS_VERSION = digit_version(transformers.__version__) +IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.43') + +if not IS_LOW_VERSION_TRANSFORMERS: + from transformers.modeling_flash_attention_utils import \ + _flash_attention_forward _flash_supports_window_size = False try: @@ -239,15 +247,28 @@ def phi3_attn_forward( ori_num_head = self.num_heads self.num_heads = query_states.shape[-2] - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - query_states.shape[1], - dropout=attn_dropout, - use_sliding_windows=use_sliding_windows, - ) + if IS_LOW_VERSION_TRANSFORMERS: + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_states.shape[1], + dropout=attn_dropout, + use_sliding_windows=use_sliding_windows, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_states.shape[1], + dropout=attn_dropout, + sliding_window=getattr(self.config, 'sliding_window', None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) if enable_sequence_parallel: # (b, s, nd // sp_world_size, dim) -> (b, s // sp_world_size, nd, dim) diff --git a/xtuner/model/modules/dispatch/qwen2.py b/xtuner/model/modules/dispatch/qwen2.py index 1c8c5a8d0..20f2f40f3 100644 --- a/xtuner/model/modules/dispatch/qwen2.py +++ b/xtuner/model/modules/dispatch/qwen2.py @@ -5,7 +5,9 @@ import torch import torch.distributed as dist +import transformers from mmengine import MessageHub +from mmengine.utils import digit_version from transformers.cache_utils import Cache from transformers.models.qwen2.modeling_qwen2 import (apply_rotary_pos_emb, repeat_kv) @@ -26,6 +28,13 @@ except ImportError: pass +TRANSFORMERS_VERSION = digit_version(transformers.__version__) +IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.43') + +if not IS_LOW_VERSION_TRANSFORMERS: + from transformers.modeling_flash_attention_utils import \ + _flash_attention_forward + def qwen2_attn_forward( self, @@ -157,15 +166,35 @@ def qwen2_attn_forward( ori_num_head = self.num_heads self.num_heads = query_states.shape[-2] - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - query_length=query_states.shape[1], - dropout=dropout_rate, - use_sliding_windows=use_sliding_windows, - ) + if IS_LOW_VERSION_TRANSFORMERS: + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_length=query_states.shape[1], + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + else: + if (self.config.use_sliding_window + and getattr(self.config, 'sliding_window', None) is not None + and self.layer_idx >= self.config.max_window_layers): + # There may be bugs here, but we are aligned with Transformers + sliding_window = self.config.sliding_window + else: + sliding_window = None + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_states.shape[1], + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) if enable_sequence_parallel: attn_output = post_process_for_sequence_parallel_attn(attn_output)