diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py
index 7c2014c5c1..622e6a11db 100644
--- a/optimum/habana/transformers/trainer.py
+++ b/optimum/habana/transformers/trainer.py
@@ -12,11 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+The Gaudi Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
+"""

 import contextlib
 import copy
 import functools
-import importlib.metadata
 import inspect
 import json
 import math
@@ -35,9 +37,15 @@
 import torch
 from accelerate import skip_first_batches
 from accelerate.data_loader import SeedableRandomSampler
-from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin, save_fsdp_model
+from accelerate.utils import (
+    DistributedDataParallelKwargs,
+    GradientAccumulationPlugin,
+    load_fsdp_model,
+    load_fsdp_optimizer,
+    save_fsdp_model,
+    save_fsdp_optimizer,
+)
 from huggingface_hub import upload_folder
-from packaging import version
 from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler
 from transformers import Trainer
 from transformers.data.data_collator import DataCollator
@@ -50,7 +58,7 @@
 )
 from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from transformers.trainer import _get_fsdp_ckpt_kwargs
+from transformers.trainer import _get_fsdp_ckpt_kwargs, _is_peft_model
 from transformers.trainer_callback import ExportableState, TrainerCallback, TrainerState
 from transformers.trainer_pt_utils import (
     DistributedTensorGatherer,
@@ -91,7 +99,6 @@
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     PushInProgress,
-    is_accelerate_available,
     is_datasets_available,
     is_peft_available,
     is_safetensors_available,
@@ -117,52 +124,76 @@
 if is_datasets_available():
     import datasets

-
 if is_safetensors_available():
     import safetensors.torch

-
 if is_peft_available():
     from peft import PeftModel
     from peft.utils import PeftType

-
 if is_deepspeed_available():
     from accelerate.utils import DeepSpeedSchedulerWrapper

-if is_accelerate_available():
-    from accelerate.utils import (
-        load_fsdp_optimizer,
-        save_fsdp_optimizer,
-    )
+from accelerate.utils import DataLoaderConfiguration

-if TYPE_CHECKING:
-    import optuna

+def _get_input_update_settings(model, lazy_mode: Optional[bool] = None) -> Tuple[bool, Dict]:
+    """
+    Determines whether the input settings need to be updated.

-DATA_SAMPLERS = [RandomSampler, SeedableRandomSampler]
+    Currently, attn_softmax_bf16, use_flash_attention, flash_attention_recompute and
+    flash_attention_causal_mask are enabled only for llama, qwen2, starcoder2, gemma, baichuan
+    and chatglm;
+    lazy_mode is enabled only for llama, qwen2, starcoder2 and mistral.

-if is_accelerate_available("0.28.0"):
-    from accelerate.utils import DataLoaderConfiguration
+    Args:
+        model: The model instance for which the input update settings are being evaluated
+        lazy_mode (`Optional[bool]`): Whether to use lazy mode for the model (defaults to `None`)

+    Returns:
+        Tuple[bool, Dict]: A flag indicating whether the input settings should be updated,
+        and a dictionary containing the specific input settings that need to be updated, if any.
+    """
+    inputs_update: Dict = {}

-def _is_peft_model(model):
-    if is_peft_available():
-        classes_to_check = (PeftModel,) if is_peft_available() else ()
-        # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
-        if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
-            from peft import PeftMixedModel
+    should_update_inputs = (getattr(model, "generation_config", None) is not None) and (
+        model.config.model_type in ("llama", "qwen2", "starcoder2", "gemma", "baichuan", "chatglm")
+    )
+    if should_update_inputs:
+        if model.generation_config.attn_softmax_bf16:
+            inputs_update["attn_softmax_bf16"] = True
+        if model.generation_config.use_flash_attention:
+            inputs_update["use_flash_attention"] = True
+        if model.generation_config.flash_attention_recompute:
+            inputs_update["flash_attention_recompute"] = True
+        if model.generation_config.flash_attention_causal_mask:
+            inputs_update["flash_attention_causal_mask"] = True
+
+    should_update_inputs = (
+        (getattr(model, "generation_config", None) is not None)
+        and (model.config.model_type in ("llama", "qwen2", "starcoder2", "mistral"))
+        and (lazy_mode is not None)
+    )
+    if should_update_inputs:
+        if _is_peft_model(model):
+            forward_method = getattr(model.get_base_model(), "forward")
+        else:
+            forward_method = getattr(model, "forward")
+        signature = inspect.signature(forward_method)
+        if "lazy_mode" in signature.parameters:
+            inputs_update["lazy_mode"] = lazy_mode
+
+    should_update_inputs: bool = len(inputs_update) > 0

-        classes_to_check = (*classes_to_check, PeftMixedModel)
-        return isinstance(model, classes_to_check)
-    return False
+    return should_update_inputs, inputs_update


 if TYPE_CHECKING:
-    if is_datasets_available():
-        import datasets
+    import optuna
+
+DATA_SAMPLERS = [RandomSampler, SeedableRandomSampler]

 logger = logging.get_logger(__name__)
@@ -328,13 +359,12 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
             num_samples = len(self.train_dataset)
             if (
                 not self.args.dataloader_drop_last
-                and len(self.train_dataset) % self.args.per_device_train_batch_size != 0
+                and num_samples % self.args.per_device_train_batch_size != 0
                 and self.args.parallel_mode != ParallelMode.DISTRIBUTED
             ):
                 # Make the total number of samples divisible by the batch size in lazy mode if needed
                 num_samples += (
-                    self.args.per_device_train_batch_size
-                    - len(self.train_dataset) % self.args.per_device_train_batch_size
+                    self.args.per_device_train_batch_size - num_samples % self.args.per_device_train_batch_size
                 )
             return RandomSampler(self.train_dataset, num_samples=num_samples)
@@ -602,7 +632,7 @@ def _inner_training_loop(
         num_train_tokens = None
         if has_length(train_dataloader):
             len_dataloader = len(train_dataloader)
-            num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
+            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
             num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
             num_examples = self.num_examples(train_dataloader)
             if args.max_steps > 0:
@@ -873,6 +903,15 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio
         self._globalstep_last_logged = self.state.global_step
         self._zero_model_grad(model)
         _grad_norm: Optional[float] = None
+        _should_compute_grad_norm: bool = self.accelerator.distributed_type != GaudiDistributedType.DEEPSPEED and (
+            # Gradient clipping
+            args.max_grad_norm is not None and args.max_grad_norm > 0
+        )
+
+        # attn_softmax_bf16 and use_flash_attention are enabled only for llama, qwen2, starcoder2, gemma and baichuan
+        # lazy_mode for llama, qwen2, starcoder2 and mistral
+        _should_update_inputs, _inputs_update = _get_input_update_settings(self.model, lazy_mode=args.use_lazy_mode)
+
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

         if args.eval_on_start:
@@ -974,32 +1013,10 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio
                 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                 # attn_softmax_bf16 and use_flash_attention is enabled only for llama, qwen2, starcoder2, gemma, baichuan and chatglm
-                if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
-                    if self.model.config.model_type in [
-                        "llama",
-                        "qwen2",
-                        "starcoder2",
-                        "gemma",
-                        "baichuan",
-                        "chatglm",
-                    ]:
-                        if self.model.generation_config.attn_softmax_bf16:
-                            inputs["attn_softmax_bf16"] = True
-                        if self.model.generation_config.use_flash_attention:
-                            inputs["use_flash_attention"] = True
-                        if self.model.generation_config.flash_attention_recompute:
-                            inputs["flash_attention_recompute"] = True
-                        if self.model.generation_config.flash_attention_causal_mask:
-                            inputs["flash_attention_causal_mask"] = True
-                if self.model.config is not None:
-                    if self.model.config.model_type in ["llama", "qwen2", "mistral", "starcoder2"]:
-                        if _is_peft_model(model):
-                            forward_method = getattr(model.get_base_model(), "forward")
-                        else:
-                            forward_method = getattr(model, "forward")
-                        signature = inspect.signature(forward_method)
-                        if "lazy_mode" in signature.parameters:
-                            inputs["lazy_mode"] = args.use_lazy_mode
+                # lazy_mode for llama, qwen2, starcoder2 and mistral
+                if _should_update_inputs:
+                    inputs.update(_inputs_update)
+                # TODO: keep syncs for fast DDP?

                 with self.accelerator.accumulate(model):
                     tr_loss_step = self.training_step(model, inputs)
@@ -1044,10 +1061,9 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio
                     if is_last_step_and_steps_less_than_grad_acc:
                         self.accelerator.gradient_state._set_sync_gradients(True)

-                    # Gradient clipping
-                    if args.max_grad_norm is not None and args.max_grad_norm > 0:
+                    # Compute _grad_norm only when gradient clipping is enabled
+                    if _should_compute_grad_norm:
                         # deepspeed does its own clipping
-
                         if self.gaudi_config.use_fused_clip_norm and args.use_habana:
                             # TODO: to merge self.accelerator.clip_grad_norm_ when HMP is removed
                             _grad_norm = self.FusedNorm.clip_norm(model.parameters())
@@ -1060,7 +1076,6 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio

                     self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)

-                    optimizer_was_run = True
                     self.optimizer.step()

                     self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
@@ -1070,8 +1085,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio
                         # Delay optimizer scheduling until metrics are generated
                         if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                             self.lr_scheduler.step()
-                    self._zero_model_grad(model)
-
+
+                    self._zero_model_grad(model)
                     self.state.global_step += 1
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     if args.use_lazy_mode:
@@ -1170,22 +1185,22 @@ def _load_best_model(self):
             best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)

             model = self.model
-            # TODO: check if the code below works
-            # if self.is_deepspeed_enabled:
-            #     deepspeed_load_checkpoint(
-            #         self.model_wrapped,
-            #         self.state.best_model_checkpoint,
-            #         load_module_strict=not _is_peft_model(self.model),
-            #     )
-            # elif self.is_fsdp_enabled:
-            #     load_result = load_fsdp_model(
-            #         self.accelerator.state.fsdp_plugin,
-            #         self.accelerator,
-            #         model,
-            #         self.state.best_model_checkpoint,
-            #         **_get_fsdp_ckpt_kwargs(),
-            #     )
-            if (
+
+            if self.is_deepspeed_enabled:
+                deepspeed_load_checkpoint(
+                    self.model_wrapped,
+                    self.state.best_model_checkpoint,
+                    load_module_strict=not _is_peft_model(self.model),
+                )
+            elif self.is_fsdp_enabled:
+                load_result = load_fsdp_model(
+                    self.accelerator.state.fsdp_plugin,
+                    self.accelerator,
+                    model,
+                    self.state.best_model_checkpoint,
+                    **_get_fsdp_ckpt_kwargs(),
+                )
+            elif (
                 os.path.exists(best_model_path)
                 or os.path.exists(best_safe_model_path)
                 or os.path.exists(best_adapter_model_path)
@@ -1265,7 +1280,7 @@ def _maybe_log_save_evaluate(self, tr_loss, _grad_norm, model, trial, epoch, ign

             # This grad_norm block was outside of _maybe_log_save_evaluate method causing perf degradataion.
             # Moving it here so the grad tensor is only copied when it's needed.
-            if is_accelerate_available() and self.accelerator.distributed_type == GaudiDistributedType.DEEPSPEED:
+            if self.accelerator.distributed_type == GaudiDistributedType.DEEPSPEED:
                 grad_norm = model.get_global_grad_norm()
                 # In some cases the grad norm may not return a float
                 if hasattr(grad_norm, "item"):
@@ -1276,7 +1291,7 @@ def _maybe_log_save_evaluate(self, tr_loss, _grad_norm, model, trial, epoch, ign
                 and self.accelerator.distributed_type != GaudiDistributedType.FSDP
                 and _grad_norm.size() == torch.Size([1])
             ):
-                grad_norm = _grad_norm.item()
+                grad_norm = _grad_norm.detach().item()
             else:
                 grad_norm = None

@@ -1340,72 +1355,6 @@ def _load_rng_state(self, checkpoint):
                 "\nThis won't yield the same results as if the training had not been interrupted."
             )

-    def _save_checkpoint(self, model, trial, metrics=None):
-        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
-        # want to save except FullyShardedDDP.
-        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
-
-        # Save model checkpoint
-        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
-
-        if self.hp_search_backend is None and trial is None:
-            self.store_flos()
-
-        run_dir = self._get_output_dir(trial=trial)
-        output_dir = os.path.join(run_dir, checkpoint_folder)
-        self.save_model(output_dir, _internal_call=True)
-
-        if not self.args.save_only_model:
-            # Save optimizer and scheduler
-            self._save_optimizer_and_scheduler(output_dir)
-            # Save RNG state
-            self._save_rng_state(output_dir)
-
-        # Determine the new best metric / best model checkpoint
-        if metrics is not None and self.args.metric_for_best_model is not None:
-            metric_to_check = self.args.metric_for_best_model
-            if not metric_to_check.startswith("eval_"):
-                metric_to_check = f"eval_{metric_to_check}"
-            try:
-                metric_value = metrics[metric_to_check]
-            except KeyError as exc:
-                raise KeyError(
-                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
-                    f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
-                ) from exc
-
-            operator = np.greater if self.args.greater_is_better else np.less
-            if (
-                self.state.best_metric is None
-                or self.state.best_model_checkpoint is None
-                or operator(metric_value, self.state.best_metric)
-            ):
-                self.state.best_metric = metric_value
-                self.state.best_model_checkpoint = output_dir
-
-        # Save the Trainer state
-        if self.args.should_save:
-            # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
-            for cb in [
-                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
-            ]:
-                cb_name = cb.__class__.__name__
-                cb_state = cb.state()
-                if isinstance(self.state.stateful_callbacks[cb_name], list):
-                    self.state.stateful_callbacks[cb_name].append(cb_state)
-                else:
-                    self.state.stateful_callbacks[cb_name] = cb_state
-            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
-
-        if self.args.push_to_hub:
-            self._push_from_checkpoint(output_dir)
-
-        # Maybe delete some older checkpoints.
-        if self.args.should_save:
-            # Solely rely on numerical checkpoint id for rotation.
-            # mtime is not reliable especially on some fuse fs in cloud environments.
-            self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
-
     def _save_rng_state(self, output_dir):
         # Save RNG state in non-distributed training
         rng_states = {
@@ -1583,7 +1532,7 @@ def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
         arguments, depending on the situation. Modified by Habana to enable using `autocast` on Gaudi devices.
         """
         if self.use_cpu_amp:
-            ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=torch.bfloat16)
+            ctx_manager = torch.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=cache_enabled)
         elif self.use_hpu_amp:
             ctx_manager = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True)
         else:
@@ -1954,6 +1903,12 @@ def evaluation_loop(
         # Will be useful when we have an iterable dataset so don't know its length.
         observed_num_examples = 0

+        # attn_softmax_bf16 and use_flash_attention are enabled only for llama, qwen2, starcoder2, gemma and baichuan
+        _should_update_inputs, _inputs_update = _get_input_update_settings(self.model)
+
+        # Set a default dtype for the logits
+        logits_dtype: str = "float32"
+
         # Main evaluation loop
         start_time_eval = time.time()
         for step, inputs in enumerate(dataloader):
@@ -1974,44 +1929,30 @@ def evaluation_loop(
                    batch_size = observed_batch_size

            # attn_softmax_bf16 and use_flash_attention are enabled only for llama, qwen2, starcoder2, gemma, baichuan and chatglm
-            if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
-                if self.model.config.model_type in ["llama", "qwen2", "starcoder2", "gemma", "baichuan", "chatglm"]:
-                    if self.model.generation_config.attn_softmax_bf16:
-                        inputs["attn_softmax_bf16"] = True
-                    if self.model.generation_config.use_flash_attention:
-                        inputs["use_flash_attention"] = True
-                    if self.model.generation_config.flash_attention_recompute:
-                        inputs["flash_attention_recompute"] = True
-                    if self.model.generation_config.flash_attention_causal_mask:
-                        inputs["flash_attention_causal_mask"] = True
+            if _should_update_inputs:
+                inputs.update(_inputs_update)

            # Prediction step
            losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            main_input_name = getattr(self.model, "main_input_name", "input_ids")
            inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None

-            # Save the logits dtype since we need to convert them into floats during the process
-            # They will be converted back into their original dtype right before computing metrics
-            if logits is not None:
-                logits_dtype = get_dtype(logits)
-
            # Update containers
            if losses is not None:
                losses = self.gather_function((losses.repeat(batch_size)))
                all_losses.add(losses)
-            if labels is not None:
-                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
-                if self.args.context_parallel_size != 1:
-                    labels = labels.clone()
-                labels = self.gather_function((labels))
-                if not self.args.batch_eval_metrics or description == "Prediction":
-                    all_labels.add(labels)
            if inputs_decode is not None:
                inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
                inputs_decode = self.gather_function((inputs_decode))
                if not self.args.batch_eval_metrics or description == "Prediction":
                    all_inputs.add(inputs_decode)
+            if labels is not None:
+                # Pad labels here, in preparation for preprocess_logits_for_metrics in the logits block below.
+                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
+            # Save the logits dtype since we need to convert them into floats during the process
+            # They will be converted back into their original dtype right before computing metrics
            if logits is not None:
+                logits_dtype = get_dtype(logits)
                if args.use_habana and logits_dtype != "float32":
                    logits = to_device_dtype(logits, target_dtype=torch.float32)
                logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
@@ -2020,6 +1961,12 @@ def evaluation_loop(
                logits = self.gather_function((logits))
                if not self.args.batch_eval_metrics or description == "Prediction":
                    all_preds.add(logits)
+            if labels is not None:
+                if self.args.context_parallel_size != 1:
+                    labels = labels.clone()
+                labels = self.gather_function((labels))
+                if not self.args.batch_eval_metrics or description == "Prediction":
+                    all_labels.add(labels)

            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

@@ -2321,6 +2268,7 @@ def prediction_loop(
        # backward compatibility
        if self.is_deepspeed_enabled:
            self.deepspeed = self.model_wrapped
+        model.eval()

        # if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
        #     self.optimizer.eval()
@@ -2474,7 +2422,7 @@ def prediction_loop(

    def create_accelerator_and_postprocess(self):
        grad_acc_kwargs = {}
-        if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
+        if self.args.accelerator_config.gradient_accumulation_kwargs is not None:
            grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs

        # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
@@ -2494,25 +2442,18 @@ def create_accelerator_and_postprocess(self):

        accelerator_config = self.args.accelerator_config.to_dict()

-        if is_accelerate_available("0.28.0"):
-            dataloader_config = DataLoaderConfiguration(
-                split_batches=accelerator_config.pop("split_batches"),
-                dispatch_batches=accelerator_config.pop("dispatch_batches"),
-                even_batches=accelerator_config.pop("even_batches"),
-                use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
-            )
+        dataloader_config = DataLoaderConfiguration(
+            split_batches=accelerator_config.pop("split_batches"),
+            dispatch_batches=accelerator_config.pop("dispatch_batches"),
+            even_batches=accelerator_config.pop("even_batches"),
+            use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
+        )
        non_blocking = accelerator_config.pop("non_blocking")
-        if not is_accelerate_available("0.30.0"):
-            if non_blocking:
-                raise ImportError(
-                    "`non_blocking` is only supported in accelerate v0.30.0 and above. Please upgrade accelerate to use this feature."
-                )
-        else:
-            if non_blocking and not self.args.dataloader_pin_memory:
-                logger.warning(
-                    "`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both."
-                )
-            dataloader_config.non_blocking = non_blocking
+        if non_blocking and not self.args.dataloader_pin_memory:
+            logger.warning(
+                "`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both."
+ ) + dataloader_config.non_blocking = non_blocking # this would have been updated above, no need for it anymore accelerator_config.pop("gradient_accumulation_kwargs") @@ -2521,11 +2462,8 @@ def create_accelerator_and_postprocess(self): "gradient_accumulation_plugin": gradient_accumulation_plugin, "distribution_strategy": self.args.distribution_strategy, "dynamic": self.args.compile_dynamic, + "dataloader_config": dataloader_config, } - if is_accelerate_available("0.28.0"): - args["dataloader_config"] = dataloader_config - else: - args.update(accelerator_config) # create accelerator object self.accelerator = GaudiAccelerator(**args)