diff --git a/.gitignore b/.gitignore index 44deaadca3..fcac20e876 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,6 @@ dmypy.json # SynapseAI logs .local.synapse_log* + +# ruff +.ruff_cache diff --git a/examples/audio-classification/run_audio_classification.py b/examples/audio-classification/run_audio_classification.py index 2fcade692b..680608fccf 100644 --- a/examples/audio-classification/run_audio_classification.py +++ b/examples/audio-classification/run_audio_classification.py @@ -38,7 +38,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.26.0") +check_min_version("4.27.0") require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt") diff --git a/examples/contrastive-image-text/run_clip.py b/examples/contrastive-image-text/run_clip.py index 8906dac0f9..ca81bbb8be 100644 --- a/examples/contrastive-image-text/run_clip.py +++ b/examples/contrastive-image-text/run_clip.py @@ -52,7 +52,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.26.0") +check_min_version("4.27.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt") diff --git a/examples/image-classification/run_image_classification.py b/examples/image-classification/run_image_classification.py index 25605bc573..6433ecbb38 100644 --- a/examples/image-classification/run_image_classification.py +++ b/examples/image-classification/run_image_classification.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.26.0") +check_min_version("4.27.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 5d13b370ab..ff7dda6312 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -52,7 +52,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.26.0") +check_min_version("4.27.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py index 41e4ae183f..9b1229673a 100644 --- a/examples/language-modeling/run_mlm.py +++ b/examples/language-modeling/run_mlm.py @@ -50,7 +50,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.26.0") +check_min_version("4.27.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/question-answering/run_qa.py b/examples/question-answering/run_qa.py index 945221aa3c..81a0b3f772 100644 --- a/examples/question-answering/run_qa.py +++ b/examples/question-answering/run_qa.py @@ -49,7 +49,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.26.0") +check_min_version("4.27.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/examples/speech-recognition/run_speech_recognition_ctc.py b/examples/speech-recognition/run_speech_recognition_ctc.py index 153e929aee..edb955309c 100644 --- a/examples/speech-recognition/run_speech_recognition_ctc.py +++ b/examples/speech-recognition/run_speech_recognition_ctc.py @@ -50,7 +50,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.26.0") +check_min_version("4.27.0") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") diff --git a/examples/summarization/run_summarization.py b/examples/summarization/run_summarization.py index 33ad7a41a4..2e7336ede4 100644 --- a/examples/summarization/run_summarization.py +++ b/examples/summarization/run_summarization.py @@ -52,7 +52,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.26.0") +check_min_version("4.27.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index b9983c220a..6c1913846d 100755 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -47,7 +47,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.26.0") +check_min_version("4.27.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") diff --git a/examples/translation/run_translation.py b/examples/translation/run_translation.py index 47ab22307e..40e38a6681 100644 --- a/examples/translation/run_translation.py +++ b/examples/translation/run_translation.py @@ -51,7 +51,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.26.0") +check_min_version("4.27.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") diff --git a/optimum/habana/transformers/deepspeed.py b/optimum/habana/transformers/deepspeed.py index f076aa24a6..9459590f6c 100644 --- a/optimum/habana/transformers/deepspeed.py +++ b/optimum/habana/transformers/deepspeed.py @@ -72,6 +72,11 @@ def trainer_config_process(self, args): self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate") # total_num_steps - will get set in trainer_config_finalize + if args.save_on_each_node: + # deepspeed uses shared storage by default. Let's override this setting if save_on_each_node == True + self.config["checkpoint"] = self.config.get("checkpoint", {}) + self.config["checkpoint"]["use_node_local_storage"] = args.save_on_each_node + # deepspeed's default mode is fp16 unless there is a config that says differently if self.is_true("bf16.enabled"): self._dtype = torch.bfloat16 @@ -171,6 +176,6 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf if load_path is None: raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}") else: - logger.info(f"{resume_from_checkpoint} doesn't have deepspeed checkpoints, doing nothing") + raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") return deepspeed_engine, optimizer, lr_scheduler diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 8d239442d0..c0b08ea59d 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -51,10 +51,10 @@ class GaudiGenerationMixin(GenerationMixin): """ - This class enables to perform fast generation in lazy mode. + This class enables to perform fast generation in lazy mode and with HPU graphs. The only difference with GenerationMixin is that the various generation methods will generate sequences whose size is max_length. Having constant - sizes allows to make the most of lazy mode. + sizes allows to make the most of lazy mode and HPU graphs. """ @torch.no_grad() @@ -79,8 +79,8 @@ def generate( model's default generation configuration. You can override any `generation_config` by passing the corresponding parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`. - For a complete overview of generate, check the [following - guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). + For an overview of generation strategies and code examples, check out the [following + guide](../generation_strategies). @@ -161,6 +161,7 @@ def generate( generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) + generation_config.validate() model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs self._validate_model_kwargs(model_kwargs.copy()) @@ -233,34 +234,28 @@ def generate( device=inputs_tensor.device, ) else: - # if decoder-only then inputs_tensor has to be `input_ids` - input_ids = inputs_tensor + input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") # 6. Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( - ( - ( - "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to" - f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length`" - " via the config is deprecated and `max_length` will be removed from the config in v5 of" - " Transformers -- we recommend using `max_new_tokens` to control the maximum length of the" - " generation." - ), - ), + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) - elif has_default_max_length and generation_config.max_new_tokens is not None: + elif generation_config.max_new_tokens is not None: generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - elif not has_default_max_length and generation_config.max_new_tokens is not None: - raise ValueError( - "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" - " limit to the generated output length. Remove one of those arguments. Please refer to the" - " documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" - ) + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: raise ValueError( @@ -450,6 +445,7 @@ def generate( length_penalty=generation_config.length_penalty, do_early_stopping=generation_config.early_stopping, num_beam_hyps_to_keep=generation_config.num_return_sequences, + max_length=generation_config.max_length, ) # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -487,6 +483,7 @@ def generate( device=inputs_tensor.device, length_penalty=generation_config.length_penalty, do_early_stopping=generation_config.early_stopping, + max_length=generation_config.max_length, ) # 13. interleave input_ids with `num_beams` additional sequences per batch @@ -532,12 +529,12 @@ def generate( beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=generation_config.num_beams, - max_length=stopping_criteria.max_length, device=inputs_tensor.device, length_penalty=generation_config.length_penalty, do_early_stopping=generation_config.early_stopping, num_beam_hyps_to_keep=generation_config.num_return_sequences, num_beam_groups=generation_config.num_beam_groups, + max_length=generation_config.max_length, ) # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -627,6 +624,7 @@ def typeerror(): length_penalty=generation_config.length_penalty, do_early_stopping=generation_config.early_stopping, num_beam_hyps_to_keep=generation_config.num_return_sequences, + max_length=generation_config.max_length, ) # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -679,7 +677,7 @@ def contrastive_search( In most cases, you do not need to call [`~generation.GenerationMixin.contrastive_search`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following - guide](./generation_strategies). + guide](../generation_strategies). @@ -702,8 +700,8 @@ def contrastive_search( used to tell if the generation loop should stop. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. @@ -782,7 +780,7 @@ def greedy_search( In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following - guide](./generation_strategies). + guide](../generation_strategies). @@ -801,8 +799,8 @@ def greedy_search( tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. @@ -882,6 +880,7 @@ def greedy_search( eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -976,8 +975,10 @@ def greedy_search( ) # if eos_token was found in one sentence, set sentence to finished - if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) if lazy_mode and not hpu_graphs: self.htcore_generation.mark_step() @@ -1037,7 +1038,7 @@ def sample( In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following - guide](./generation_strategies). + guide](../generation_strategies). @@ -1059,8 +1060,8 @@ def sample( tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. @@ -1138,7 +1139,7 @@ def sample( ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] + ['Today is a beautiful day, and we must do everything possible to make it a day of celebration.'] ```""" logger.warning("Sampling is slow in lazy mode, eager mode should be preferred at the moment.") @@ -1160,6 +1161,7 @@ def sample( eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -1257,8 +1259,10 @@ def sample( ) # if eos_token was found in one sentence, set sentence to finished - if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) if lazy_mode and not hpu_graphs: self.htcore_generation.mark_step() @@ -1317,7 +1321,7 @@ def beam_search( In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following - guide](./generation_strategies). + guide](../generation_strategies). @@ -1338,8 +1342,8 @@ def beam_search( tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. @@ -1650,7 +1654,7 @@ def beam_sample( In most cases, you do not need to call [`~generation.GenerationMixin.beam_sample`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following - guide](./generation_strategies). + guide](../generation_strategies). @@ -1675,8 +1679,8 @@ def beam_sample( tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. @@ -1793,7 +1797,7 @@ def group_beam_search( In most cases, you do not need to call [`~generation.GenerationMixin.group_beam_search`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following - guide](./generation_strategies). + guide](../generation_strategies). @@ -1814,8 +1818,8 @@ def group_beam_search( tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. @@ -1928,7 +1932,7 @@ def constrained_beam_search( In most cases, you do not need to call [`~generation.GenerationMixin.constrained_beam_search`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following - guide](./generation_strategies). + guide](../generation_strategies). @@ -1954,8 +1958,8 @@ def constrained_beam_search( tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 55e53bf21b..fc2a0b9942 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import transformers.models.gpt2.modeling_gpt2 from transformers.generation import GenerationMixin from transformers.modeling_utils import ModuleUtilsMixin from transformers.models.albert.modeling_albert import AlbertModel @@ -21,6 +22,7 @@ from .generation import GaudiGenerationMixin from .models import ( + GaudiGPT2Attention, gaudi_albert_forward, gaudi_get_extended_attention_mask, gaudi_invert_attention_mask, @@ -29,7 +31,7 @@ ) -def adapt_transformers_to_gaudi(use_habana_mixed_precision: bool): +def adapt_transformers_to_gaudi(): """ Replaces some Transformers' methods for equivalent methods optimized for Gaudi. @@ -56,11 +58,13 @@ def adapt_transformers_to_gaudi(use_habana_mixed_precision: bool): GenerationMixin.group_beam_search = GaudiGenerationMixin.group_beam_search GenerationMixin.constrained_beam_search = GaudiGenerationMixin.constrained_beam_search - if use_habana_mixed_precision: - # When HMP is enabled, replace invert_attention_mask and get_extended_attention_mask - # so that HMP is disabled for specific parts of the code - ModuleUtilsMixin.invert_attention_mask = gaudi_invert_attention_mask - ModuleUtilsMixin.get_extended_attention_mask = gaudi_get_extended_attention_mask - # AlbertModel.forward does not rely on get_extended_attention_mask so it also needs - # to be replaced when using HMP - AlbertModel.forward = gaudi_albert_forward + # Replace invert_attention_mask and get_extended_attention_mask + # so that HMP is disabled for specific parts of the code + ModuleUtilsMixin.invert_attention_mask = gaudi_invert_attention_mask + ModuleUtilsMixin.get_extended_attention_mask = gaudi_get_extended_attention_mask + # AlbertModel.forward does not rely on get_extended_attention_mask so it also needs to be replaced + AlbertModel.forward = gaudi_albert_forward + + # From Transformers 4.27, the bias in the GPT2Attention layer is a Boolean + # Since HCCL cannot handle this dtype, we revert it back to uint8 (same behaviour as Transformers <= 4.26) + transformers.models.gpt2.modeling_gpt2.GPT2Attention = GaudiGPT2Attention diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index c17b077ba2..f93b3c2c8c 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -1,4 +1,5 @@ from .albert import gaudi_albert_forward +from .gpt2 import GaudiGPT2Attention from .modeling_all_models import gaudi_get_extended_attention_mask, gaudi_invert_attention_mask from .vit import gaudi_vit_self_attention_forward from .wav2vec2 import ( diff --git a/optimum/habana/transformers/models/gpt2/__init__.py b/optimum/habana/transformers/models/gpt2/__init__.py new file mode 100644 index 0000000000..b95f1b522c --- /dev/null +++ b/optimum/habana/transformers/models/gpt2/__init__.py @@ -0,0 +1 @@ +from .modeling_gpt2 import GaudiGPT2Attention diff --git a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py new file mode 100644 index 0000000000..bd0758cc50 --- /dev/null +++ b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py @@ -0,0 +1,229 @@ +from typing import Optional, Tuple, Union + +import torch +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer + + +class GaudiGPT2Attention(torch.nn.Module): + """ + Copied from GPT2Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py + The only differences are: + - `self.bias` is a torch.uint8 and not a torch.bool + - it is casted to bool before being used in torch.where + """ + + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = torch.nn.Dropout(config.attn_pdrop) + self.resid_dropout = torch.nn.Dropout(config.resid_pdrop) + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index a2b824b0e4..c390992e83 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -81,7 +81,6 @@ ) from .deepspeed import deepspeed_init from .gaudi_configuration import GAUDI_CONFIG_NAME, GaudiConfig -from .modeling_utils import adapt_transformers_to_gaudi from .trainer_utils import convert_into_dtypes, get_dtype from .training_args import GaudiTrainingArguments @@ -209,9 +208,6 @@ def __init__( logging.enable_default_handler() logging.enable_explicit_format() - # Some methods needs to be tweaked to optimally run on Gaudi - adapt_transformers_to_gaudi(self.gaudi_config.use_habana_mixed_precision) - # Suppress PyTorch autocast warnings with Wav2Vec2 # This is a bug in PyTorch warnings.filterwarnings( @@ -467,7 +463,7 @@ def train( if resume_from_checkpoint is None: raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") - if resume_from_checkpoint is not None: + if resume_from_checkpoint is not None and args.deepspeed is None: self._load_from_checkpoint(resume_from_checkpoint) # If model was re-initialized, put it on the right device and update self.model_wrapped diff --git a/optimum/habana/transformers/trainer_seq2seq.py b/optimum/habana/transformers/trainer_seq2seq.py index 10ed14e46a..752717e94f 100644 --- a/optimum/habana/transformers/trainer_seq2seq.py +++ b/optimum/habana/transformers/trainer_seq2seq.py @@ -183,30 +183,18 @@ def prediction_step( gen_kwargs["hpu_graphs"] if gen_kwargs.get("hpu_graphs") is not None else self.args.use_hpu_graphs ) - if "attention_mask" in inputs: - gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) - if "global_attention_mask" in inputs: - gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) - - # prepare generation inputs - # some encoder-decoder models can have varying encoder's and thus - # varying model input names - if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: - generation_inputs = inputs[self.model.encoder.main_input_name] - else: - generation_inputs = inputs[self.model.main_input_name] - + # TODO (Joao): the following line is needed to keep a consistent result on SQUAD. Ideally, we should not block + # users from preparing a dataset with `decoder_input_ids`. + inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"} try: - generated_tokens = self.model.generate( - generation_inputs, - **gen_kwargs, - ) + generated_tokens = self.model.generate(**inputs, **gen_kwargs) except RuntimeError as error: if "cpu fallback is not supported during hpu graph capturing" in str(error): error.args = ( f"{error}. You should run inference in lazy mode only with `use_lazy_mode=True` and `use_hpu_graphs=False`.", ) raise error + # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop # TODO: remove this hack when the legacy code that initializes generation_config from a model config is # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183 diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index 432242daca..e65ee26610 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -20,6 +20,7 @@ from pathlib import Path from typing import Optional, Union +from packaging import version from transformers.debug_utils import DebugOption from transformers.file_utils import cached_property, is_torch_available, requires_backends from transformers.trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType @@ -253,6 +254,9 @@ def __post_init__(self): FutureWarning, ) self.optim = OptimizerNames.ADAFACTOR + if self.optim == OptimizerNames.ADAMW_TORCH_FUSED and is_torch_available(): + if version.parse(version.parse(torch.__version__).base_version) < version.parse("2.0.0"): + raise ValueError("--optim adamw_torch_fused requires PyTorch 2.0 or higher") if self.report_to is None: logger.info( @@ -435,6 +439,13 @@ def _setup_devices(self) -> "torch.device": backend=self.xpu_backend, rank=rank, world_size=size, timeout=self.ddp_timeout_delta ) elif self.use_habana: + # Some methods needs to be tweaked to optimally run on Gaudi + # Calling this method here to be sure it is done before model instantiation + # Otherwise this will fail when some __init__ methods are overridden (cf. GPT2Attention) + from .modeling_utils import adapt_transformers_to_gaudi + + adapt_transformers_to_gaudi() + if self.use_lazy_mode: logger.info("Enabled lazy mode.") else: diff --git a/tests/create_diff_file_for_example.py b/tests/create_diff_file_for_example.py index 9eec99e5fa..ff45d325cb 100644 --- a/tests/create_diff_file_for_example.py +++ b/tests/create_diff_file_for_example.py @@ -42,7 +42,9 @@ def _ask_yes_or_no_question(message: str) -> str: def diff(filename1: Path, filename2: Path) -> str: if not filename1.exists() or not filename2.exists(): - raise FileNotFoundError("Cannot compute the diff because at least of one the file does not exist.") + raise FileNotFoundError( + f"Cannot compute the diff because at least one of the files does not exist: {filename1} and/or {filename2}." + ) cmd_line = ["diff", str(filename1), str(filename2)] p = subprocess.Popen(cmd_line, stdout=subprocess.PIPE) outs, _ = p.communicate() diff --git a/tests/example_diff/run_audio_classification.txt b/tests/example_diff/run_audio_classification.txt index 09736dd290..d60de1ec54 100644 --- a/tests/example_diff/run_audio_classification.txt +++ b/tests/example_diff/run_audio_classification.txt @@ -21,9 +21,9 @@ > from optimum.habana.utils import set_seed > 48c41 -< check_min_version("4.27.0.dev0") +< check_min_version("4.28.0.dev0") --- -> check_min_version("4.26.0") +> check_min_version("4.27.0") 164,166d156 < freeze_feature_extractor: Optional[bool] = field( < default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."} diff --git a/tests/example_diff/run_clip.txt b/tests/example_diff/run_clip.txt index b7271adca1..140c7b5bb7 100644 --- a/tests/example_diff/run_clip.txt +++ b/tests/example_diff/run_clip.txt @@ -14,9 +14,9 @@ > from optimum.habana.utils import set_seed > 57c55 -< check_min_version("4.27.0.dev0") +< check_min_version("4.28.0.dev0") --- -> check_min_version("4.26.0") +> check_min_version("4.27.0") 230c228 < parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) --- diff --git a/tests/example_diff/run_clm.txt b/tests/example_diff/run_clm.txt index 35b02d5303..d98fa7920e 100644 --- a/tests/example_diff/run_clm.txt +++ b/tests/example_diff/run_clm.txt @@ -25,9 +25,9 @@ > from optimum.habana.utils import set_seed > 58c55 -< check_min_version("4.27.0.dev0") +< check_min_version("4.28.0.dev0") --- -> check_min_version("4.26.0") +> check_min_version("4.27.0") 79c76,77 < "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." --- diff --git a/tests/example_diff/run_glue.txt b/tests/example_diff/run_glue.txt index cda38af1d5..5521da0d43 100644 --- a/tests/example_diff/run_glue.txt +++ b/tests/example_diff/run_glue.txt @@ -13,9 +13,9 @@ > from optimum.habana.utils import set_seed > 51c50 -< check_min_version("4.27.0.dev0") +< check_min_version("4.28.0.dev0") --- -> check_min_version("4.26.0") +> check_min_version("4.27.0") 211c210 < parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) --- diff --git a/tests/example_diff/run_image_classification.txt b/tests/example_diff/run_image_classification.txt index cc6dd731d7..ba08c334a3 100644 --- a/tests/example_diff/run_image_classification.txt +++ b/tests/example_diff/run_image_classification.txt @@ -12,9 +12,9 @@ > from optimum.habana.utils import set_seed > 58c57 -< check_min_version("4.27.0.dev0") +< check_min_version("4.28.0.dev0") --- -> check_min_version("4.26.0") +> check_min_version("4.27.0") 171c170 < parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) --- diff --git a/tests/example_diff/run_mlm.txt b/tests/example_diff/run_mlm.txt index 2b407de35c..9edb95265e 100644 --- a/tests/example_diff/run_mlm.txt +++ b/tests/example_diff/run_mlm.txt @@ -20,9 +20,9 @@ > from optimum.habana.utils import set_seed > 56c53 -< check_min_version("4.27.0.dev0") +< check_min_version("4.28.0.dev0") --- -> check_min_version("4.26.0") +> check_min_version("4.27.0") 200c197 < streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) --- diff --git a/tests/example_diff/run_qa.txt b/tests/example_diff/run_qa.txt index d93ab05e0b..b4b620924f 100644 --- a/tests/example_diff/run_qa.txt +++ b/tests/example_diff/run_qa.txt @@ -18,9 +18,9 @@ > from optimum.habana import GaudiConfig, GaudiTrainingArguments > from optimum.habana.utils import set_seed 52c52 -< check_min_version("4.27.0.dev0") +< check_min_version("4.28.0.dev0") --- -> check_min_version("4.26.0") +> check_min_version("4.27.0") 135c135 < " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)." --- diff --git a/tests/example_diff/run_speech_recognition_ctc.txt b/tests/example_diff/run_speech_recognition_ctc.txt index 81a522135a..6d26e208f5 100644 --- a/tests/example_diff/run_speech_recognition_ctc.txt +++ b/tests/example_diff/run_speech_recognition_ctc.txt @@ -13,9 +13,9 @@ > from optimum.habana.utils import set_seed > 54c53 -< check_min_version("4.27.0.dev0") +< check_min_version("4.28.0.dev0") --- -> check_min_version("4.26.0") +> check_min_version("4.27.0") 141d139 < 374c372 diff --git a/tests/example_diff/run_summarization.txt b/tests/example_diff/run_summarization.txt index c0b4502500..d7ad7a64d5 100644 --- a/tests/example_diff/run_summarization.txt +++ b/tests/example_diff/run_summarization.txt @@ -18,9 +18,9 @@ > from optimum.habana.utils import set_seed > 55c55 -< check_min_version("4.27.0.dev0") +< check_min_version("4.28.0.dev0") --- -> check_min_version("4.26.0") +> check_min_version("4.27.0") 119a120,128 > use_cache: bool = field( > default=True, diff --git a/tests/example_diff/run_translation.txt b/tests/example_diff/run_translation.txt index b6194bf060..6047bcc77b 100644 --- a/tests/example_diff/run_translation.txt +++ b/tests/example_diff/run_translation.txt @@ -13,9 +13,9 @@ > from optimum.habana.utils import set_seed > 55c54 -< check_min_version("4.27.0.dev0") +< check_min_version("4.28.0.dev0") --- -> check_min_version("4.26.0") +> check_min_version("4.27.0") 100a100,108 > use_cache: bool = field( > default=True, diff --git a/tests/test_trainer.py b/tests/test_trainer.py index ef539d2b35..5d32aba293 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -26,7 +26,7 @@ from typing import Optional, Union import numpy as np -from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token +from huggingface_hub import HfFolder, Repository, delete_repo from parameterized import parameterized from requests.exceptions import HTTPError from transformers import IntervalStrategy, PretrainedConfig, is_torch_available @@ -891,11 +891,15 @@ def test_log_level(self): logger = logging.get_logger() log_info_string = "Running training" - # test with the default log_level - should be info and thus log on the main process + # test with the default log_level - should be the same as before and thus we test depending on is_info + is_info = logging.get_verbosity() <= 20 with CaptureLogger(logger) as cl: trainer = get_regression_trainer() trainer.train() - self.assertIn(log_info_string, cl.out) + if is_info: + self.assertIn(log_info_string, cl.out) + else: + self.assertNotIn(log_info_string, cl.out) # test with low log_level - lower than info with CaptureLogger(logger) as cl: @@ -923,7 +927,13 @@ def test_save_checkpoints(self): def test_can_resume_training(self): with tempfile.TemporaryDirectory() as tmpdir: - kwargs = {"output_dir": tmpdir, "train_len": 128, "save_steps": 5, "learning_rate": 0.1} + kwargs = { + "output_dir": tmpdir, + "train_len": 128, + "save_steps": 5, + "learning_rate": 0.1, + "logging_steps": 5, + } trainer = get_regression_trainer(**kwargs) # Disable FusedClipNorm because it makes the test fail trainer.gaudi_config.use_fused_clip_norm = False @@ -1473,7 +1483,6 @@ class GaudiTrainerIntegrationWithHubTester(unittest.TestCase): @classmethod def setUpClass(cls): cls._token = TOKEN - set_access_token(TOKEN) HfFolder.save_token(TOKEN) @classmethod