diff --git a/flagscale/train/extra_valid.py b/flagscale/train/extra_valid.py
index 02a5eca4..fa1747ca 100644
--- a/flagscale/train/extra_valid.py
+++ b/flagscale/train/extra_valid.py
@@ -1,4 +1,3 @@
-import os
 import math
 
 import torch
@@ -14,6 +13,7 @@
 from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
 from megatron.core.datasets.utils import get_blend_from_list
 from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
+from megatron.core.rerun_state_machine import RerunDataIterator
 from megatron.legacy.data.data_samplers import build_pretraining_data_loader
 
 from flagscale.train import get_extra_valid_datasets, set_extra_valid_datasets
@@ -198,12 +198,15 @@ def build_extra_valid_data_iterators(build_extra_valid_dataset_provider):
     def _get_iterator(dataloader_type, dataloader):
         """Return dataset iterator."""
         if dataloader_type == "single":
-            return iter(dataloader)
+            return RerunDataIterator(iter(dataloader))
         elif dataloader_type == "cyclic":
-            return iter(cyclic_iter(dataloader))
+            return RerunDataIterator(iter(cyclic_iter(dataloader)))
         elif dataloader_type == "external":
             # External dataloader is passed through. User is expected to define how to iterate.
-            return dataloader
+            if isinstance(dataloader, list):
+                return [RerunDataIterator(d) for d in dataloader]
+            else:
+                return RerunDataIterator(dataloader)
         else:
             raise RuntimeError("unexpected dataloader type")
 
diff --git a/flagscale/train/train.py b/flagscale/train/train.py
index 46292364..22972d72 100644
--- a/flagscale/train/train.py
+++ b/flagscale/train/train.py
@@ -69,6 +69,7 @@
 
 from megatron.training.async_utils import maybe_finalize_async_save
 from megatron.training.utils import (
+    append_to_progress_log,
     calc_params_l2_norm,
     check_adlr_autoresume_termination,
     logical_and_across_model_parallel_group,
@@ -78,7 +79,6 @@
     print_rank_last,
     report_memory,
     unwrap_model,
-    append_to_progress_log,
     update_use_dist_ckpt,
 )
 from megatron.training.global_vars import (
@@ -88,7 +88,8 @@
     get_timers,
     get_tensorboard_writer,
     get_wandb_writer,
-    get_one_logger)
+    get_one_logger,
+)
 
 from megatron.training import one_logger_utils
 from megatron.training import ft_integration
@@ -220,7 +221,7 @@ def _get_field(string, type):
 
 def preprocess_common_state_dict(common_state_dict):
     import copy
-    # Convert args key of type namespace to dictionary 
+    # Convert args key of type namespace to dictionary
     preprocessed_common_state_dict = copy.deepcopy(common_state_dict)
     preprocessed_common_state_dict['args'] = vars(preprocessed_common_state_dict['args'])
     # Remove rank and local rank from state dict if it exists, since they are expected to be different
@@ -352,7 +353,6 @@ def pretrain(
         train_data_iterator = []
         valid_data_iterator = []
         test_data_iterator = []
-        extra_valid_data_iterator = []
         for i in range(len(model)):
             mpu.set_virtual_pipeline_model_parallel_rank(i)
             iterators = build_train_valid_test_data_iterators(
@@ -360,15 +360,10 @@ def pretrain(
             train_data_iterator.append(iterators[0])
             valid_data_iterator.append(iterators[1])
             test_data_iterator.append(iterators[2])
-            extra_iterators = build_extra_valid_data_iterators(
-                extra_valid_dataset_provider)
-            extra_valid_data_iterator.append(extra_iterators)
     else:
         train_data_iterator, valid_data_iterator, test_data_iterator \
             = build_train_valid_test_data_iterators(
                 train_valid_test_dataset_provider)
-        extra_valid_data_iterator = build_extra_valid_data_iterators(
-            extra_valid_dataset_provider)
     timers('train/valid/test-data-iterators-setup').stop()
     print_datetime('after dataloaders are built')
     app_metrics['app_build_dataiters_finish_time'] = one_logger_utils.get_timestamp_in_ms()
@@ -406,7 +401,7 @@ def pretrain(
                 model, optimizer, opt_param_scheduler,
                 train_data_iterator, valid_data_iterator,
                 process_non_loss_data_func, config, checkpointing_context,
-                non_loss_data_func, extra_valid_data_iterator)
+                non_loss_data_func, extra_valid_dataset_provider)
 
         print_datetime('after training is done')
 
@@ -824,6 +819,7 @@ def train_step(forward_step_func, data_iterator,
         unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
 
     # Update parameters.
+    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
     update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
     timers('optimizer').stop()
 
@@ -976,13 +972,13 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
             wandb_writer.log({'consumed-tokens': args.consumed_train_samples * args.seq_length / 1000. / 1000 / 1000}, iteration)
         if writer:
             writer.add_scalar('learning-rate', learning_rate, iteration)
+            writer.add_scalar('learning-rate vs samples', learning_rate,
+                              args.consumed_train_samples)
         if wandb_writer:
             wandb_writer.log({'learning-rate': learning_rate}, iteration)
-            if args.decoupled_lr is not None:
+        if args.decoupled_lr is not None:
+            if writer:
                 writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
-            writer.add_scalar('learning-rate vs samples', learning_rate,
-                              args.consumed_train_samples)
-
         if args.skipped_train_samples > 0:
             if writer:
                 writer.add_scalar('skipped-train-samples', args.skipped_train_samples, iteration)
@@ -993,9 +989,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
             writer.add_scalar('batch-size vs samples', batch_size,
                               args.consumed_train_samples)
         if wandb_writer:
-            wandb_writer.log({'skipped-train-samples': args.skipped_train_samples}, iteration)
             wandb_writer.log({'batch-size': batch_size}, iteration)
-            wandb_writer.log({'skipped-train-samples': args.skipped_train_samples}, iteration)
         for key in loss_dict:
             if writer:
                 writer.add_scalar(key , loss_dict[key], iteration)
@@ -1107,7 +1101,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
             writer.add_scalar('throughput', throughput, iteration)
             if wandb_writer:
                 wandb_writer.log({'throughput': throughput}, iteration)
-        assert learning_rate is not None
         # Decoupled_learning_rate should be not None only on first and last pipeline stage.
         log_string += f' learning rate: {learning_rate:.6E} |'
         if args.decoupled_lr is not None and (mpu.is_pipeline_first_stage(ignore_virtual=True) or
@@ -1141,7 +1134,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
         total_loss_dict[nan_iters_key] = 0
         print_rank_last(log_string)
         if not args.auto_tune:
-            if report_memory_flag and learning_rate > 0.:
+            if report_memory_flag:
                 # Report memory after optimizer state has been initialized.
                 if torch.distributed.get_rank() == 0:
                     num_microbatches = get_num_microbatches()
@@ -1386,7 +1379,7 @@ def checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
 def train(forward_step_func, model, optimizer, opt_param_scheduler,
           train_data_iterator, valid_data_iterator,
           process_non_loss_data_func, config, checkpointing_context, non_loss_data_func,
-          extra_valid_data_iterator=None):
+          extra_valid_dataset_provider=None):
     """Training function: run train_step desired number of times, run validation, checkpoint."""
     args = get_args()
     timers = get_timers()
@@ -1663,39 +1656,56 @@ def get_e2e_base_metrics():
                 ft_integration.get_rank_monitor_client(
                     ft_integration.StateMachineActions.EVAL_HEARTBEAT).send_heartbeat()
 
-        # Extra Evaluation
-        if args.extra_eval_interval and iteration % args.extra_eval_interval == 0 and \
-            getattr(args, "do_extra_valid", False):
+        # Extra Evaluation =====================================================================
+        if args.extra_eval_interval and iteration % args.extra_eval_interval == 0:
+            # NOTE(zhaoyinglia): Must rebuild the dataloaders for extra validation here,
+            # to guarantee extra validation start from extra_iter=0 every time,
+            # but we don't need to rebuild the datasets.
+            if args.virtual_pipeline_model_parallel_size is not None:
+                extra_valid_data_iterator = []
+                for i in range(len(model)):
+                    mpu.set_virtual_pipeline_model_parallel_rank(i)
+                    extra_iterators = build_extra_valid_data_iterators(
+                        extra_valid_dataset_provider)
+                    extra_valid_data_iterator.append(extra_iterators)
+            else:
+                extra_valid_data_iterator = build_extra_valid_data_iterators(
+                    extra_valid_dataset_provider)
             timers('interval-time').stop()
-            if args.use_distributed_optimizer and args.overlap_param_gather:
-                disable_forward_pre_hook(model)
-            if args.manual_gc and args.manual_gc_eval:
-                # Collect all objects.
-                gc.collect()
-            prefix = 'iteration {}'.format(iteration)
-            for extra_valid_index, extra_valid_data_itr in enumerate(extra_valid_data_iterator):
-                timers('extra-eval-time', log_level=0).start(barrier=True)
-                extra_eval_iters = args.extra_eval_iters_list[extra_valid_index]
-                extra_evaluate_and_print_results(extra_valid_index, prefix, forward_step_func,
-                                                 extra_valid_data_itr, model,
-                                                 iteration, process_non_loss_data_func,
-                                                 config, verbose=False, write_to_tensorboard=True,
-                                                 non_loss_data_func=non_loss_data_func)
-                extra_eval_duration += timers('extra-eval-time').elapsed()
-                extra_eval_iterations += extra_eval_iters
-                timers('extra-eval-time').stop()
-            one_logger_utils.track_e2e_metrics()
-
-            if args.manual_gc and args.manual_gc_eval:
-                # Collect only the objects created and used in evaluation.
-                gc.collect(generation=0)
-            if args.use_distributed_optimizer and args.overlap_param_gather:
-                enable_forward_pre_hook(model)
-            timers('interval-time', log_level=0).start(barrier=True)
+            # do_extra_valid flag is used to indicate that we are doing extra validation
+            # and is set in the build_extra_valid_data_iterators function
+            if getattr(args, "do_extra_valid", False):
+                if args.use_distributed_optimizer and args.overlap_param_gather:
+                    disable_forward_pre_hook(model)
+                if args.manual_gc and args.manual_gc_eval:
+                    # Collect all objects.
+                    gc.collect()
+                prefix = 'iteration {}'.format(iteration)
+                for extra_valid_index, extra_valid_data_itr in enumerate(extra_valid_data_iterator):
+                    timers('extra-eval-time', log_level=0).start(barrier=True)
+                    extra_eval_iters = args.extra_eval_iters_list[extra_valid_index]
+                    extra_evaluate_and_print_results(extra_valid_index, prefix, forward_step_func,
+                                                     extra_valid_data_itr, model,
+                                                     iteration, process_non_loss_data_func,
+                                                     config, verbose=False, write_to_tensorboard=True,
+                                                     non_loss_data_func=non_loss_data_func)
+                    extra_eval_duration += timers('extra-eval-time').elapsed()
+                    extra_eval_iterations += extra_eval_iters
+                    timers('extra-eval-time').stop()
+                one_logger_utils.track_e2e_metrics()
+
+                if args.manual_gc and args.manual_gc_eval:
+                    # Collect only the objects created and used in evaluation.
+                    gc.collect(generation=0)
+                if args.use_distributed_optimizer and args.overlap_param_gather:
+                    enable_forward_pre_hook(model)
+                    pre_hook_enabled = True
+                timers('interval-time', log_level=0).start(barrier=True)
 
-            if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
-                ft_integration.get_rank_monitor_client(
-                    ft_integration.StateMachineActions.EVAL_HEARTBEAT).send_heartbeat()
+                if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
+                    ft_integration.get_rank_monitor_client(
+                        ft_integration.StateMachineActions.EVAL_HEARTBEAT).send_heartbeat()
+        # =======================================================================================
 
         # Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
         # Some of these only happen at specific iterations.
diff --git a/megatron/megatron/training/arguments.py b/megatron/megatron/training/arguments.py
index 1bc98d0d..5692f558 100644
--- a/megatron/megatron/training/arguments.py
+++ b/megatron/megatron/training/arguments.py
@@ -2356,32 +2356,3 @@ def _add_auto_tuner_args(parser):
                        help='use auto tuner')
 
     return parser
-
-
-def _add_hetero_args(parser):
-    group = parser.add_argument_group(title="heterogeneous training")
-
-    group.add_argument('--enable-hetero', action="store_true",
-                       help='the mode of heterogeneous training')
-    group.add_argument('--hetero-device-types', nargs='*', type=str, default=None,
-                       help='the list of device types: device_type_0 device_type_1 ...')
-    group.add_argument('--hetero-current-device-type', type=str, default=None,
-                       help='the current device type')
-    group.add_argument('--hetero-pipeline-layer-split', nargs='*', type=int, default=None,
-                       help='Incompatible with --num-layers-per-virtual-pipeline-stage for now.'
-                       'hetero-pipeline-layer-split must be in the form: layers_0 layers_1 ... layers_n. The number of the list should be equal to pipeline-model-parallel-size.')
-    group.add_argument('--hetero-process-meshes', nargs='*', type=int, default=None,
-                       help='Use this arg to set TP-CP-DP-PP of each process mesh.'
-                       'This argument must be in the form: TP0, CP0, DP0, PP0, TP1, CP0, DP1, PP1...TPN, CPN, DPN, PPN. CP and TP size can be different, sum of PP should match pipeline-model-parallel-size, DP size should be the same.')
-    group.add_argument('--hetero-use-cpu-communication', action='store_true', help='Use CPU for communication for heterogeneous communication.')
-
-    return parser
-
-
-def _add_auto_tuner_args(parser):
-    group = parser.add_argument_group(title="auto tuner")
-
-    group.add_argument('--auto-tune', action='store_true',
-                       help='use auto tuner')
-
-    return parser
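
Note on the NOTE(zhaoyinglia) comment above: the extra-validation dataloaders are rebuilt at every extra-eval interval so that each pass starts from extra_iter=0, while the expensive dataset objects stay cached. A minimal sketch of that behavior in plain Python (not FlagScale code; `build_loader` and `run_extra_valid` are illustrative stand-ins for `build_extra_valid_data_iterators` and one extra-validation pass):

```python
# Illustrative sketch only -- not FlagScale code.

def build_loader():
    # Stand-in for build_extra_valid_data_iterators(): cheap to call again,
    # assuming the underlying datasets are cached and reused.
    return iter(range(4))

def run_extra_valid(data_iter, eval_iters=2):
    # Stand-in for one extra-validation pass consuming eval_iters batches.
    return [next(data_iter) for _ in range(eval_iters)]

# Build once and reuse across intervals: the second pass resumes mid-stream.
shared = build_loader()
print(run_extra_valid(shared))  # [0, 1]
print(run_extra_valid(shared))  # [2, 3]  -- does not restart from batch 0

# Rebuild per interval, as train() now does: every pass starts from batch 0.
print(run_extra_valid(build_loader()))  # [0, 1]
print(run_extra_valid(build_loader()))  # [0, 1]
```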