[Core] Update extra validation feature #303

Merged
merged 2 commits into from
Jan 2, 2025
156 changes: 103 additions & 53 deletions flagscale/train/extra_valid.py
@@ -1,15 +1,14 @@
import os
import sys
import math
import torch

from megatron.training.global_vars import get_args
from megatron.training.global_vars import get_tensorboard_writer
from megatron.training.global_vars import get_wandb_writer
from megatron.training.global_vars import get_tokenizer
from megatron.training.utils import print_rank_0
from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_tokenizer
from megatron.training.utils import print_rank_last
from megatron.training.utils import is_last_rank
from megatron.training.global_vars import get_tensorboard_writer
from megatron.training.global_vars import get_wandb_writer
from megatron.core import mpu
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
@@ -21,30 +20,37 @@


def is_dataset_built_on_rank():
return (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and mpu.get_tensor_model_parallel_rank() == 0
return (
mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
) and mpu.get_tensor_model_parallel_rank() == 0


def core_gpt_dataset_config_from_args(args, data_path):
tokenizer = get_tokenizer()

# Only build the validation dataset
assert data_path is not None, \
"Please provide a valid data_path for extra validation dataset."
return GPTDatasetConfig(
random_seed=args.seed,
sequence_length=args.seq_length,
blend=get_blend_from_list(args.data_path),
blend=get_blend_from_list(data_path),
blend_per_split=None,
renormalize_blend_weights=args.renormalize_blend_weights,
split="0,1,0",
num_dataset_builder_threads=args.num_dataset_builder_threads,
path_to_cache=args.data_cache_path,
mmap_bin_files=args.mmap_bin_files,
tokenizer=tokenizer,
reset_position_ids=args.reset_position_ids,
reset_attention_mask=args.reset_attention_mask,
eod_mask_loss=args.eod_mask_loss,
create_attention_mask=args.create_attention_mask_in_dataloader,
s3_cache_path=args.s3_cache_path,
)


def extra_valid_dataset_provider(data_path, num_samples, tag):
def extra_valid_datasets_provider(data_path, num_samples):
"""Build the train test and extra_validation datasets.

Args:
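Worth calling out in the hunk above: the blend is now built from the data_path argument instead of args.data_path, and split="0,1,0" routes every sample to the validation split. A minimal sketch of why the argument fix matters, with a stubbed stand-in for get_blend_from_list and made-up paths:

from types import SimpleNamespace

def get_blend_from_list(paths):
    # crude stand-in for the Megatron helper: (prefixes, weights) or None
    return (paths, None) if paths else None

args = SimpleNamespace(data_path=["/data/train_text_document"])
extra_path = ["/data/val_domain_a_text_document"]

old_blend = get_blend_from_list(args.data_path)  # pre-fix: always the training blend
new_blend = get_blend_from_list(extra_path)      # post-fix: the requested extra set
print(old_blend)  # (['/data/train_text_document'], None)
print(new_blend)  # (['/data/val_domain_a_text_document'], None)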
@@ -55,10 +61,8 @@ def extra_valid_dataset_provider(data_path, num_samples, tag):
config = core_gpt_dataset_config_from_args(args, data_path)

if args.mock_data:
from megatron.core.datasets.gpt_dataset import MockGPTDataset
dataset_type = MockGPTDataset
else:
from megatron.core.datasets.gpt_dataset import GPTDataset
dataset_type = GPTDataset

print_rank_0(f"> building extra validation dataset ({data_path}, {num_samples}) for GPT ...")
@@ -82,31 +86,42 @@ def build_extra_valid_datasets(build_extra_valid_dataset_provider):
"""Build extra_valid datasets."""

args = get_args()
num_tokens_list = args.extra_valid_data_path[0::3]
paths = args.extra_valid_data_path[1::3]
tags = args.extra_valid_data_path[2::3]

num_samples_list = []
valid_iters_list = []
for num_tokens in num_tokens_list:
assert int(num_tokens) > 0, f"Number of tokens {num_tokens} should be greater than 0"

assert len(args.extra_valid_data_path) % 2 == 0, \
"extra_valid_data_path format should be a list of weight, prefix and tag."

blend = args.extra_valid_data_path
raw_num_tokens_per_dataset, raw_prefix_paths_per_dataset = zip(
*[(blend[i], blend[i+1]) for i in range(0, len(blend), 2)]
)

num_samples_per_dataset = []
valid_iters_per_dataset = []
for rntpd in raw_num_tokens_per_dataset:
try:
num_tokens = int(rntpd)
except ValueError:
raise ValueError(f"Number of tokens {rntpd} is error.")

assert num_tokens > 0, f"Number of tokens {num_tokens} should be greater than 0"
# Convert the token budget into a number of samples, rounding up by sequence length
num_samples = (int(num_tokens) + args.seq_length - 1) // args.seq_length
num_samples = (num_tokens + args.seq_length - 1) // args.seq_length
# Make sure that the number of samples is a multiple of the global batch size.
eval_iters = (num_samples + args.global_batch_size - 1) // args.global_batch_size
num_samples = eval_iters * args.global_batch_size
num_samples_list.append(num_samples)
valid_iters_list.append(eval_iters)
args.extra_valid_iters_list = valid_iters_list
num_samples_per_dataset.append(num_samples)
valid_iters_per_dataset.append(eval_iters)

args.extra_eval_iters_list = valid_iters_per_dataset
args.extra_prefix_paths_list = raw_prefix_paths_per_dataset
args.extra_num_samples_list = num_samples_per_dataset

assert len(paths) == len(num_samples_list), \
f"Number of extra_valid data paths {len(paths)} does not match number of extra_valid data samples {len(num_samples_list)}"
assert len(raw_prefix_paths_per_dataset) == len(num_samples_per_dataset), \
f"Number of extra_valid data paths {len(raw_prefix_paths_per_dataset)} does not match number of extra_valid data samples {len(num_samples_per_dataset)}"

extra_valid_datasets = []
for path, num_samples, tag in zip(paths, num_samples_list, tags):
assert os.path.exists(path + ".bin"), f"Path {path} does not exist"
assert os.path.exists(path + ".idx"), f"Path {path} does not exist"
extra_valid_datasets.append(build_extra_valid_dataset_provider(path, num_samples, tag))
for path, num_samples in zip(raw_prefix_paths_per_dataset, num_samples_per_dataset):
extra_valid_datasets.append(build_extra_valid_dataset_provider([path], num_samples))

return extra_valid_datasets

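The loop above does two rounds of ceil-division: tokens to samples via the sequence length, then samples up to a whole number of eval iterations via the global batch size. A self-contained sketch of that arithmetic (the numbers are made up):

def round_up_samples(num_tokens: int, seq_length: int, global_batch_size: int):
    # ceil(num_tokens / seq_length): token budget -> sample count
    num_samples = (num_tokens + seq_length - 1) // seq_length
    # ceil(num_samples / global_batch_size): samples -> eval iterations
    eval_iters = (num_samples + global_batch_size - 1) // global_batch_size
    # pad so every iteration consumes a full global batch
    return eval_iters * global_batch_size, eval_iters

samples, iters = round_up_samples(1_000_000, seq_length=4096, global_batch_size=128)
print(iters, samples)  # 2 iterations, 256 samples (245 rounded up to 2 * 128)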
@@ -116,11 +131,10 @@ def build_extra_valid_data_loaders(build_extra_valid_dataset_provider):

args = get_args()

paths = args.extra_valid_data_path[1::3]

extra_valid_dataloaders = [None for _ in paths]
extra_valid_dataloaders = [None]

print_rank_0('> building extra validation datasets ...')
print_rank_0('> extra validation consumed_samples is always 0.')

# Rely on distributed-aware core datasets, temporary
is_distributed = getattr(build_extra_valid_dataset_provider, "is_distributed", False)
@@ -155,18 +169,45 @@ def build_extra_valid_data_loaders(build_extra_valid_dataset_provider):
return extra_valid_dataloaders


def cyclic_iter(iter):
while True:
for x in iter:
yield x


def build_extra_valid_data_iterators(build_extra_valid_dataset_provider):
"""Build pretraining data iterators."""
if build_extra_valid_dataset_provider is None:
return None

args = get_args()

# Build loaders.
extra_valid_dataloaders = \
build_extra_valid_data_loaders(
build_extra_valid_dataset_provider)

# Build iterators.
dl_type = args.dataloader_type
assert dl_type in ['single', 'cyclic', 'external']

def _get_iterator(dataloader_type, dataloader):
"""Return dataset iterator."""
if dataloader_type == "single":
return iter(dataloader)
elif dataloader_type == "cyclic":
return iter(cyclic_iter(dataloader))
elif dataloader_type == "external":
# External dataloader is passed through. User is expected to define how to iterate.
return dataloader
else:
raise RuntimeError("unexpected dataloader type")

if extra_valid_dataloaders[0] is not None:
extra_valid_data_iterators = []
for extra_valid_dataloader in extra_valid_dataloaders:
extra_valid_data_iterators.append(iter(extra_valid_dataloader))
extra_valid_data_iterators = [
_get_iterator(dl_type, extra_valid_dataloader)
for extra_valid_dataloader in extra_valid_dataloaders
]
else:
extra_valid_data_iterators = [None for _ in extra_valid_dataloaders]

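cyclic_iter restarts a finite iterable forever, which is what the "cyclic" dataloader type relies on; "single" iterates once and "external" passes the object through untouched. A standalone illustration of the cyclic case, with a plain list standing in for a dataloader:

import itertools

def cyclic_iter(iterable):
    # same shape as the helper above: replay the iterable indefinitely
    while True:
        for x in iterable:
            yield x

loader = [1, 2, 3]  # stand-in for a finite dataloader
print(list(itertools.islice(cyclic_iter(loader), 7)))  # [1, 2, 3, 1, 2, 3, 1]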
@@ -176,7 +217,7 @@ def build_extra_valid_data_iterators(build_extra_valid_dataset_provider):
def extra_evaluate_and_print_results(index, prefix, forward_step_func,
data_iterator, model,
iteration, process_non_loss_data_func, config,
verbose=False, write_to_tensorboard=True):
verbose=False, write_to_tensorboard=True, non_loss_data_func=None):
"""Helper function to evaluate and dump results on screen."""
args = get_args()
if write_to_tensorboard:
@@ -186,23 +227,27 @@ def extra_evaluate_and_print_results(index, prefix, forward_step_func,

wandb_writer = get_wandb_writer()

# To avoid the circular import.
from megatron.training.training import evaluate
from flagscale.train.train import evaluate # To avoid the circular import
total_loss_dict, collected_non_loss_data, timelimit = evaluate(
forward_step_func, data_iterator, model,
process_non_loss_data_func, config, verbose, index)
process_non_loss_data_func, config, verbose, non_loss_data_func, index)

# Timelimit hit during evaluation
if timelimit:
return
extra_valid_data_path = args.extra_valid_data_path
path = extra_valid_data_path[1::3][index]
filename = os.path.basename(path)
tag = extra_valid_data_path[2::3][index]
label = f'{filename}-{tag}'
string = ' extra_validation {} loss at {} | '.format(label, prefix)
loss_section = 'validation loss'
ppl_section = 'validation ppl'

label = ''
extra_prefix_paths_list = getattr(args, "extra_prefix_paths_list", None)
if extra_prefix_paths_list:
label = f'{extra_prefix_paths_list[index]}'

consumed_samples = ''
extra_num_samples_list = getattr(args, "extra_num_samples_list", None)
if extra_num_samples_list:
consumed_samples = extra_num_samples_list[index]

string = f' extra validation {prefix} loss at {label} | '
string += f'consumed samples: {consumed_samples} | '
for key in total_loss_dict:
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
ppl = math.exp(min(20, total_loss_dict[key].item()))
Expand All @@ -214,17 +259,22 @@ def extra_evaluate_and_print_results(index, prefix, forward_step_func,
writer.add_scalar('{} validation {} vs samples'.format(key, label),
total_loss_dict[key].item(),
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'{}/{} validation {}'.format(loss_section, key, label): total_loss_dict[key].item()},
iteration)
if args.log_validation_ppl_to_tensorboard:
writer.add_scalar('{} validation {} ppl'.format(key, label), ppl,
iteration)
writer.add_scalar('{} validation {} ppl vs samples'.format(key, label),
ppl, args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'{}/{} validation {} ppl'.format(ppl_section, key, label): ppl},
iteration)
if wandb_writer and is_last_rank():
wandb_writer.log({
'{} validation {}'.format(key, label): total_loss_dict[key].item()},
iteration)
wandb_writer.log({
'{} validation {} vs samples'.format(key, label): args.consumed_train_samples},
iteration)
wandb_writer.log({'validation ppl/{} validation {} ppl'.format(key, label): ppl},
iteration)
wandb_writer.log({'validation loss/{} validation {}'.format(key, label): total_loss_dict[key].item()},
iteration)

if process_non_loss_data_func is not None and writer and is_last_rank():
process_non_loss_data_func(collected_non_loss_data, iteration, writer)
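One detail in the logging above: perplexity is computed as exp(min(20, loss)), so a diverging loss is reported as a capped value (e^20, roughly 4.85e8) rather than overflowing. A minimal illustration:

import math

def safe_ppl(loss: float) -> float:
    # clamp the exponent so an extreme loss cannot overflow math.exp
    return math.exp(min(20, loss))

print(safe_ppl(2.0))   # ~7.39 for a typical LM loss
print(safe_ppl(50.0))  # capped at e**20 ~= 4.85e+08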