[Core] add skip-train extra validation #307

Merged: 1 commit, merged on Jan 6, 2025
flagscale/train/extra_valid.py (1 addition, 1 deletion)
@@ -253,7 +253,7 @@ def extra_evaluate_and_print_results(index, prefix, forward_step_func,
     if extra_num_samples_list:
         comsumed_samples = extra_num_samples_list[index]

-    string = f' extra validation {prefix} loss at {label} | '
+    string = f' extra validation loss at {prefix} {label} | '
     string += f'consumed samples: {comsumed_samples} | '
     for key in total_loss_dict:
         string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
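The single change above moves {prefix} from between "extra validation" and "loss at" to directly in front of {label}. A minimal sketch of the resulting log line, using assumed placeholder values for prefix and label (the real values are supplied by the caller in flagscale/train/train.py):

    # Illustrative only: 'prefix' and 'label' here are assumed placeholder values.
    prefix = 'iteration 1000 on extra validation set'
    label = 'extra-valid-0'  # hypothetical extra validation dataset label

    old = f' extra validation {prefix} loss at {label} | '
    new = f' extra validation loss at {prefix} {label} | '

    print(old)  # ' extra validation iteration 1000 on extra validation set loss at extra-valid-0 | '
    print(new)  # ' extra validation loss at iteration 1000 on extra validation set extra-valid-0 | '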
flagscale/train/train.py (28 additions, 5 deletions)
@@ -438,6 +438,29 @@ def pretrain(
                                    verbose=True, write_to_tensorboard=not args.skip_train,
                                    non_loss_data_func=non_loss_data_func)

+    if extra_valid_dataset_provider is not None:
+        # NOTE(zhaoyinglia): Must rebuild the dataloaders for extra validation here,
+        # to guarantee extra validation start from extra_iter=0 every time,
+        # but we don't need to rebuild the datasets.
+        if args.virtual_pipeline_model_parallel_size is not None:
+            extra_valid_data_iterator = []
+            for i in range(len(model)):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                extra_iterators = build_extra_valid_data_iterators(
+                    extra_valid_dataset_provider)
+                extra_valid_data_iterator.append(extra_iterators)
+        else:
+            extra_valid_data_iterator = build_extra_valid_data_iterators(
+                extra_valid_dataset_provider)
+        if getattr(args, "do_extra_valid", False):
+            prefix = f'iteration {iteration} on extra validation set'
+            for extra_valid_index, extra_valid_data_itr in enumerate(extra_valid_data_iterator):
+                extra_evaluate_and_print_results(extra_valid_index, prefix, forward_step_func,
+                                                 extra_valid_data_itr, model,
+                                                 iteration, process_non_loss_data_func, config,
+                                                 verbose=True, write_to_tensorboard=not args.skip_train,
+                                                 non_loss_data_func=non_loss_data_func)
+
     wandb_writer = get_wandb_writer()
     if wandb_writer:
         wandb_writer.finish()
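The NOTE in the added block is the key constraint: extra validation must always start from extra_iter=0, so the dataloaders are rebuilt from the existing datasets instead of reusing iterators that earlier in-training validations may have advanced. A minimal framework-free sketch of that rebuild-versus-reuse distinction (dataset and build_iterator below are stand-ins, not FlagScale APIs; build_extra_valid_data_iterators plays the analogous role in the diff):

    # Stand-in names only; this is not FlagScale code.
    dataset = list(range(8))           # the dataset itself does not need rebuilding

    def build_iterator(ds):
        return iter(ds)                # a fresh iterator always starts at item 0

    it = build_iterator(dataset)
    _ = [next(it) for _ in range(5)]   # an earlier validation pass advanced the iterator

    it = build_iterator(dataset)       # rebuild: same dataset, fresh iterator
    assert next(it) == 0               # extra validation starts from extra_iter=0 again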
@@ -1673,7 +1696,7 @@ def get_e2e_base_metrics():
                 extra_valid_dataset_provider)
             timers('interval-time').stop()
             # do_extra_valid flag is used to indicate that we are doing extra validation
-            # and is set in the build_extra_valid_data_iterators function
+            # and is set in the build_extra_valid_data_iterators function
             if getattr(args, "do_extra_valid", False):
                 if args.use_distributed_optimizer and args.overlap_param_gather:
                     disable_forward_pre_hook(model)
@@ -1685,10 +1708,10 @@
                     timers('extra-eval-time', log_level=0).start(barrier=True)
                     extra_eval_iters = args.extra_eval_iters_list[extra_valid_index]
                     extra_evaluate_and_print_results(extra_valid_index, prefix, forward_step_func,
-                                                     extra_valid_data_itr, model,
-                                                     iteration, process_non_loss_data_func,
-                                                     config, verbose=False, write_to_tensorboard=True,
-                                                     non_loss_data_func=non_loss_data_func)
+                                                     extra_valid_data_itr, model,
+                                                     iteration, process_non_loss_data_func,
+                                                     config, verbose=False, write_to_tensorboard=True,
+                                                     non_loss_data_func=non_loss_data_func)
                     extra_eval_duration += timers('extra-eval-time').elapsed()
                     extra_eval_iterations += extra_eval_iters
                     timers('extra-eval-time').stop()
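Taken together, the train.py changes make the end-of-training extra validation in pretrain() mirror the regular validation path: it runs whenever do_extra_valid has been set by build_extra_valid_data_iterators, and it writes to TensorBoard only when training was not skipped. A rough, framework-free sketch of that gating, using hypothetical stand-ins for args and the evaluation call rather than the real Megatron objects:

    # Hypothetical stand-ins; only the gating logic mirrors the diff above.
    from types import SimpleNamespace

    def run_extra_valid_at_end(args, extra_valid_data_iterators, evaluate_fn):
        if not getattr(args, "do_extra_valid", False):
            return
        for extra_valid_index, data_iterator in enumerate(extra_valid_data_iterators):
            evaluate_fn(extra_valid_index, data_iterator,
                        write_to_tensorboard=not args.skip_train)

    # Example: a skip-train (evaluation-only) run still performs extra validation,
    # but does not write TensorBoard scalars.
    args = SimpleNamespace(do_extra_valid=True, skip_train=True)
    run_extra_valid_at_end(
        args,
        extra_valid_data_iterators=[iter([]), iter([])],
        evaluate_fn=lambda i, it, write_to_tensorboard: print(i, write_to_tensorboard),
    )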