[Fix] Fix flagscale entrypoint train.py and update extra_valid for newest megatron #306

Merged: 4 commits, Jan 6, 2025

11 changes: 7 additions & 4 deletions flagscale/train/extra_valid.py
@@ -1,4 +1,3 @@
import os
import math
import torch

@@ -14,6 +13,7 @@
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
from megatron.core.rerun_state_machine import RerunDataIterator
from megatron.legacy.data.data_samplers import build_pretraining_data_loader

from flagscale.train import get_extra_valid_datasets, set_extra_valid_datasets
@@ -198,12 +198,15 @@ def build_extra_valid_data_iterators(build_extra_valid_dataset_provider):
def _get_iterator(dataloader_type, dataloader):
"""Return dataset iterator."""
if dataloader_type == "single":
return iter(dataloader)
return RerunDataIterator(iter(dataloader))
elif dataloader_type == "cyclic":
return iter(cyclic_iter(dataloader))
return RerunDataIterator(iter(cyclic_iter(dataloader)))
elif dataloader_type == "external":
# External dataloader is passed through. User is expected to define how to iterate.
return dataloader
if isinstance(dataloader, list):
return [RerunDataIterator(d) for d in dataloader]
else:
return RerunDataIterator(dataloader)
else:
raise RuntimeError("unexpected dataloader type")

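For context, the core of the extra_valid.py change is that every branch of _get_iterator now returns an iterator wrapped for the newer Megatron-LM rerun state machine. Below is a minimal runnable sketch of that wrapping pattern. RerunDataIteratorSketch, get_iterator_sketch, and the simplified cyclic_iter are illustrative stand-ins introduced here, not the real Megatron-LM classes; the actual RerunDataIterator in megatron.core.rerun_state_machine carries additional replay logic.

```python
# Minimal sketch of the iterator-wrapping pattern from _get_iterator above.
# RerunDataIteratorSketch is a simplified stand-in for
# megatron.core.rerun_state_machine.RerunDataIterator.
import itertools


def cyclic_iter(iterable):
    """Yield items from the dataloader forever (same idea as Megatron's cyclic_iter)."""
    while True:
        for item in iterable:
            yield item


class RerunDataIteratorSketch:
    """Illustrative wrapper that simply delegates to the underlying iterator."""

    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.iterable)


def get_iterator_sketch(dataloader_type, dataloader):
    """Mirror of _get_iterator: every branch returns a wrapped iterator."""
    if dataloader_type == "single":
        return RerunDataIteratorSketch(iter(dataloader))
    elif dataloader_type == "cyclic":
        return RerunDataIteratorSketch(iter(cyclic_iter(dataloader)))
    elif dataloader_type == "external":
        # An external dataloader is passed through; it may also be a list,
        # in which case each element is wrapped individually.
        if isinstance(dataloader, list):
            return [RerunDataIteratorSketch(d) for d in dataloader]
        return RerunDataIteratorSketch(dataloader)
    raise RuntimeError("unexpected dataloader type")


if __name__ == "__main__":
    it = get_iterator_sketch("cyclic", [1, 2, 3])
    print(list(itertools.islice(it, 5)))  # [1, 2, 3, 1, 2]
```
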
110 changes: 60 additions & 50 deletions flagscale/train/train.py
@@ -69,6 +69,7 @@

from megatron.training.async_utils import maybe_finalize_async_save
from megatron.training.utils import (
append_to_progress_log,
calc_params_l2_norm,
check_adlr_autoresume_termination,
logical_and_across_model_parallel_group,
@@ -78,7 +79,6 @@
print_rank_last,
report_memory,
unwrap_model,
append_to_progress_log,
update_use_dist_ckpt,
)
from megatron.training.global_vars import (
@@ -88,7 +88,8 @@
get_timers,
get_tensorboard_writer,
get_wandb_writer,
get_one_logger)
get_one_logger,
)
from megatron.training import one_logger_utils

from megatron.training import ft_integration
@@ -220,7 +221,7 @@ def _get_field(string, type):

def preprocess_common_state_dict(common_state_dict):
import copy
# Convert args key of type namespace to dictionary
# Convert args key of type namespace to dictionary
preprocessed_common_state_dict = copy.deepcopy(common_state_dict)
preprocessed_common_state_dict['args'] = vars(preprocessed_common_state_dict['args'])
# Remove rank and local rank from state dict if it exists, since they are expected to be different
@@ -352,23 +353,17 @@ def pretrain(
train_data_iterator = []
valid_data_iterator = []
test_data_iterator = []
extra_valid_data_iterator = []
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
iterators = build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
train_data_iterator.append(iterators[0])
valid_data_iterator.append(iterators[1])
test_data_iterator.append(iterators[2])
extra_iterators = build_extra_valid_data_iterators(
extra_valid_dataset_provider)
extra_valid_data_iterator.append(extra_iterators)
else:
train_data_iterator, valid_data_iterator, test_data_iterator \
= build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
extra_valid_data_iterator = build_extra_valid_data_iterators(
extra_valid_dataset_provider)
timers('train/valid/test-data-iterators-setup').stop()
print_datetime('after dataloaders are built')
app_metrics['app_build_dataiters_finish_time'] = one_logger_utils.get_timestamp_in_ms()
@@ -406,7 +401,7 @@ def pretrain(
model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func, config, checkpointing_context,
non_loss_data_func, extra_valid_data_iterator)
non_loss_data_func, extra_valid_dataset_provider)

print_datetime('after training is done')

@@ -824,6 +819,7 @@ def train_step(forward_step_func, data_iterator,
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

# Update parameters.

timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
@@ -976,13 +972,13 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
wandb_writer.log({'consumed-tokens': args.consumed_train_samples * args.seq_length / 1000. / 1000 / 1000}, iteration)
if writer:
writer.add_scalar('learning-rate', learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'learning-rate': learning_rate}, iteration)
if args.decoupled_lr is not None:
if args.decoupled_lr is not None:
if writer:
writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples)

if args.skipped_train_samples > 0:
if writer:
writer.add_scalar('skipped-train-samples', args.skipped_train_samples, iteration)
@@ -993,9 +989,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
writer.add_scalar('batch-size vs samples', batch_size,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'skipped-train-samples': args.skipped_train_samples}, iteration)
wandb_writer.log({'batch-size': batch_size}, iteration)
wandb_writer.log({'skipped-train-samples': args.skipped_train_samples}, iteration)
for key in loss_dict:
if writer:
writer.add_scalar(key , loss_dict[key], iteration)
@@ -1107,7 +1101,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
writer.add_scalar('throughput', throughput, iteration)
if wandb_writer:
wandb_writer.log({'throughput': throughput}, iteration)
assert learning_rate is not None
# Decoupled_learning_rate should be not None only on first and last pipeline stage.
log_string += f' learning rate: {learning_rate:.6E} |'
if args.decoupled_lr is not None and (mpu.is_pipeline_first_stage(ignore_virtual=True) or
@@ -1141,7 +1134,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
total_loss_dict[nan_iters_key] = 0
print_rank_last(log_string)
if not args.auto_tune:
if report_memory_flag and learning_rate > 0.:
Contributor:
Why does this PR remove the learning_rate > 0 check?

Collaborator (Author):
Because Megatron-LM removed it, but we did not pick up that change in the previous merge-megatron PR.

if report_memory_flag:
# Report memory after optimizer state has been initialized.
if torch.distributed.get_rank() == 0:
num_microbatches = get_num_microbatches()
@@ -1386,7 +1379,7 @@ def checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func, config, checkpointing_context, non_loss_data_func,
extra_valid_data_iterator=None):
extra_valid_dataset_provider=None):
"""Training function: run train_step desired number of times, run validation, checkpoint."""
args = get_args()
timers = get_timers()
@@ -1663,39 +1656,56 @@ def get_e2e_base_metrics():
ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.EVAL_HEARTBEAT).send_heartbeat()

# Extra Evaluation
if args.extra_eval_interval and iteration % args.extra_eval_interval == 0 and \
getattr(args, "do_extra_valid", False):
# Extra Evaluation =====================================================================
if args.extra_eval_interval and iteration % args.extra_eval_interval == 0:
# NOTE(zhaoyinglia): Must rebuild the dataloaders for extra validation here,
# to guarantee extra validation starts from extra_iter=0 every time,
# but we don't need to rebuild the datasets.
if args.virtual_pipeline_model_parallel_size is not None:
extra_valid_data_iterator = []
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
extra_iterators = build_extra_valid_data_iterators(
extra_valid_dataset_provider)
extra_valid_data_iterator.append(extra_iterators)
else:
extra_valid_data_iterator = build_extra_valid_data_iterators(
extra_valid_dataset_provider)
timers('interval-time').stop()
if args.use_distributed_optimizer and args.overlap_param_gather:
disable_forward_pre_hook(model)
if args.manual_gc and args.manual_gc_eval:
# Collect all objects.
gc.collect()
prefix = 'iteration {}'.format(iteration)
for extra_valid_index, extra_valid_data_itr in enumerate(extra_valid_data_iterator):
timers('extra-eval-time', log_level=0).start(barrier=True)
extra_eval_iters = args.extra_eval_iters_list[extra_valid_index]
extra_evaluate_and_print_results(extra_valid_index, prefix, forward_step_func,
extra_valid_data_itr, model,
iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func)
extra_eval_duration += timers('extra-eval-time').elapsed()
extra_eval_iterations += extra_eval_iters
timers('extra-eval-time').stop()
one_logger_utils.track_e2e_metrics()

if args.manual_gc and args.manual_gc_eval:
# Collect only the objects created and used in evaluation.
gc.collect(generation=0)
if args.use_distributed_optimizer and args.overlap_param_gather:
enable_forward_pre_hook(model)
timers('interval-time', log_level=0).start(barrier=True)
# do_extra_valid flag is used to indicate that we are doing extra validation
# and is set in the build_extra_valid_data_iterators function
if getattr(args, "do_extra_valid", False):
if args.use_distributed_optimizer and args.overlap_param_gather:
disable_forward_pre_hook(model)
if args.manual_gc and args.manual_gc_eval:
# Collect all objects.
gc.collect()
prefix = 'iteration {}'.format(iteration)
for extra_valid_index, extra_valid_data_itr in enumerate(extra_valid_data_iterator):
timers('extra-eval-time', log_level=0).start(barrier=True)
extra_eval_iters = args.extra_eval_iters_list[extra_valid_index]
extra_evaluate_and_print_results(extra_valid_index, prefix, forward_step_func,
extra_valid_data_itr, model,
iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func)
extra_eval_duration += timers('extra-eval-time').elapsed()
extra_eval_iterations += extra_eval_iters
timers('extra-eval-time').stop()
one_logger_utils.track_e2e_metrics()

if args.manual_gc and args.manual_gc_eval:
# Collect only the objects created and used in evaluation.
gc.collect(generation=0)
if args.use_distributed_optimizer and args.overlap_param_gather:
enable_forward_pre_hook(model)
pre_hook_enabled = True
timers('interval-time', log_level=0).start(barrier=True)

if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.EVAL_HEARTBEAT).send_heartbeat()
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.EVAL_HEARTBEAT).send_heartbeat()
# =======================================================================================

# Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
# Some of these only happen at specific iterations.
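The train.py change passes extra_valid_dataset_provider into train() instead of a prebuilt extra_valid_data_iterator, and rebuilds the extra validation dataloaders inside the training loop at every extra-eval interval, so each extra validation pass starts from extra_iter=0 without rebuilding the datasets. Below is a minimal sketch of that rebuild-per-evaluation pattern. The names run_extra_eval_sketch, dataset_provider, and eval_iters_list are hypothetical stand-ins; the real code calls build_extra_valid_data_iterators() and extra_evaluate_and_print_results().

```python
# Minimal sketch of the "rebuild iterators on every extra eval" pattern, using
# hypothetical helpers in place of the real builder and evaluation functions.
from typing import Callable, Iterable, List


def run_extra_eval_sketch(
    dataset_provider: Callable[[], List[Iterable[int]]],
    eval_iters_list: List[int],
    iteration: int,
) -> None:
    # Rebuild the data iterators here (not once up front) so every extra-eval
    # pass starts from the beginning of each extra validation set.
    extra_valid_iterators = [iter(ds) for ds in dataset_provider()]
    for extra_valid_index, data_iter in enumerate(extra_valid_iterators):
        eval_iters = eval_iters_list[extra_valid_index]
        total = 0
        for _ in range(eval_iters):
            total += next(data_iter)  # stand-in for one evaluation step
        print(f"iteration {iteration} | extra valid {extra_valid_index} | sum {total}")


if __name__ == "__main__":
    provider = lambda: [range(10), range(100, 110)]  # two toy "extra valid" datasets
    for step in (100, 200):  # two extra-eval intervals; each restarts from item 0
        run_extra_eval_sketch(provider, eval_iters_list=[3, 3], iteration=step)
```
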
29 changes: 0 additions & 29 deletions megatron/megatron/training/arguments.py
@@ -2356,32 +2356,3 @@ def _add_auto_tuner_args(parser):
help='use auto tuner')

return parser


def _add_hetero_args(parser):
group = parser.add_argument_group(title="heterogeneous training")

group.add_argument('--enable-hetero', action="store_true",
help='the mode of heterogeneous training')
group.add_argument('--hetero-device-types', nargs='*', type=str, default=None,
help='the list of device types: device_type_0 device_type_1 ...')
group.add_argument('--hetero-current-device-type', type=str, default=None,
help='the current device type')
group.add_argument('--hetero-pipeline-layer-split', nargs='*', type=int, default=None,
help='Incompatible with --num-layers-per-virtual-pipeline-stage for now.'
'hetero-pipeline-layer-split must be in the form: layers_0 layers_1 ... layers_n. The number of the list should be equal to pipeline-model-parallel-size.')
group.add_argument('--hetero-process-meshes', nargs='*', type=int, default=None,
help='Use this arg to set TP-CP-DP-PP of each process mesh.'
'This argument must be in the form: TP0, CP0, DP0, PP0, TP1, CP0, DP1, PP1...TPN, CPN, DPN, PPN. CP and TP size can be different, sum of PP should match pipeline-model-parallel-size, DP size should be the same.')
group.add_argument('--hetero-use-cpu-communication', action='store_true', help='Use CPU for communication for heterogeneous communication.')

return parser


def _add_auto_tuner_args(parser):
group = parser.add_argument_group(title="auto tuner")

group.add_argument('--auto-tune', action='store_true',
help='use auto tuner')

return parser