From b91b9b3a263ddedb916862730b52ca9697648997 Mon Sep 17 00:00:00 2001
From: Jannik Brinkmann
Date: Sat, 24 Feb 2024 13:17:14 +0100
Subject: [PATCH] updated training script

---
 src/delphi/train/__init__.py     |   1 +
 src/delphi/{ => train}/llama2.py |   0
 src/delphi/train/training_old.py | 226 ++++++++++++++++---------------
 src/delphi/train/training_old.sh |   6 +
 4 files changed, 127 insertions(+), 106 deletions(-)
 rename src/delphi/{ => train}/llama2.py (100%)
 create mode 100644 src/delphi/train/training_old.sh

diff --git a/src/delphi/train/__init__.py b/src/delphi/train/__init__.py
index e69de29b..4f4a3966 100644
--- a/src/delphi/train/__init__.py
+++ b/src/delphi/train/__init__.py
@@ -0,0 +1 @@
+from .llama2c.tinystories import Task
\ No newline at end of file
diff --git a/src/delphi/llama2.py b/src/delphi/train/llama2.py
similarity index 100%
rename from src/delphi/llama2.py
rename to src/delphi/train/llama2.py
diff --git a/src/delphi/train/training_old.py b/src/delphi/train/training_old.py
index d2d5e0e1..bd1a6a9d 100644
--- a/src/delphi/train/training_old.py
+++ b/src/delphi/train/training_old.py
@@ -16,6 +16,9 @@
 (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1)
 """

+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
+
 import math
 import os
 import time
@@ -29,8 +32,7 @@

 from llama2 import LLaMA2, LLaMA2Args
 from llama2c import model_export, Task
-
-from shuffle import shuffle_epoch
+from tqdm import tqdm

 # -----------------------------------------------------------------------------
 # I/O
@@ -42,11 +44,12 @@
 always_save_checkpoint = False # if True, always save a checkpoint after each eval
 init_from = "scratch" # 'scratch' or 'resume'
 # wandb logging
-wandb_log = False # disabled by default
-wandb_project = "llamac"
-wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
+wandb_log = True # enabled by default
+wandb_entity = "jannik-brinkmann"
+wandb_project = "delphi"
+wandb_run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
 # data
-batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
+batch_size = 64 # if gradient_accumulation_steps > 1, this is the micro-batch size
 max_seq_len = 256
 vocab_source = "llama2" # llama2|custom; use Lllama 2 vocab from Meta, or custom trained
 vocab_size = 32000 # the Llama 2 tokenizer has 32K tokens
@@ -60,7 +63,7 @@
 # adamw optimizer
 gradient_accumulation_steps = 4 # used to simulate larger batch sizes
 learning_rate = 5e-4 # max learning rate
-max_iters = 100000 # total number of training iterations
+max_epochs = 10 # total number of training epochs
 weight_decay = 1e-1
 beta1 = 0.9
 beta2 = 0.95
@@ -71,21 +74,38 @@
 # system
 device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
 dtype = "bfloat16" # float32|bfloat16|float16
-compile = True # use PyTorch 2.0 to compile the model to be faster
+compile = False # use PyTorch 2.0 to compile the model to be faster
 # -----------------------------------------------------------------------------
 config_keys = [
     k
     for k, v in globals().items()
     if not k.startswith("_") and isinstance(v, (int, float, bool, str))
 ]
-exec(open("configurator.py").read()) # overrides from command line or config file
+exec(open("./llama2c/configurator.py").read()) # overrides from command line or config file
 config = {k: globals()[k] for k in config_keys} # will be useful for logging
 # -----------------------------------------------------------------------------

 # fixing some hyperparams to sensible defaults
-lr_decay_iters = max_iters # should be ~= max_iters per Chinchilla
+num_batches = Task.get_num_batches(
+    batch_size=batch_size,
+    max_seq_len=max_seq_len,
+    vocab_size=vocab_size,
+    vocab_source=vocab_source,
+    device=device,
+)
+num_steps = num_batches // gradient_accumulation_steps
+eval_iters = Task.get_num_batches(
+    split="validation",
+    batch_size=batch_size,
+    max_seq_len=max_seq_len,
+    vocab_size=vocab_size,
+    vocab_source=vocab_source,
+    device=device,
+)
+eval_iters = min(12, eval_iters)
+lr_decay_iters = max_epochs * num_batches # should be ~= max_iters per Chinchilla
 min_lr = 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla

 # validating checks
@@ -218,7 +238,7 @@ def estimate_loss():
     out = {}
     model.eval()
     for split in ["train", "val"]:
-        batch_iter = iter_batches(split=split)
+        batch_iter = iter_batches(split=split, epoch=epoch)
         losses = torch.zeros(eval_iters) # keep on CPU
         for k in range(eval_iters):
             X, Y = next(batch_iter)
@@ -244,110 +264,104 @@ def get_lr(it):
     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
     return min_lr + coeff * (learning_rate - min_lr)
-
-
 # logging
 if wandb_log and master_process:
     import wandb
-    wandb.init(project=wandb_project, name=wandb_run_name, config=config)
-
-
-
+    wandb.init(entity=wandb_entity, project=wandb_project, name=wandb_run_name, config=config)
 # training loop
-train_batch_iter = iter_batches(split="train")
-X, Y = next(train_batch_iter) # fetch the very first batch
 t0 = time.time()
 local_iter_num = 0 # number of iterations in the lifetime of this process
 raw_model = model.module if ddp else model # unwrap DDP container if needed
 running_mfu = -1.0
-while True:
-    # determine and set the learning rate for this iteration
-    lr = get_lr(iter_num) if decay_lr else learning_rate
-    for param_group in optimizer.param_groups:
-        param_group["lr"] = lr
-
-    # evaluate the loss on train/val sets and write checkpoints
-    if iter_num % eval_interval == 0 and master_process:
-        losses = estimate_loss()
-        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
-        if wandb_log:
-            try:
-                wandb.log(
-                    {
-                        "iter": iter_num,
-                        "tokens": iter_num * tokens_per_iter,
-                        "loss/train": losses["train"],
-                        "loss/val": losses["val"],
-                        "lr": lr,
-                        "mfu": running_mfu * 100, # convert to percentage
-                    }, step = iter_num
-                )
-            except Exception as e:
-                print(f"logging to wandb failed: {e}")
-        if losses["val"] < best_val_loss or always_save_checkpoint:
-            best_val_loss = losses["val"]
-            if iter_num > 0:
-                checkpoint = {
-                    "model": raw_model.state_dict(),
-                    "optimizer": optimizer.state_dict(),
-                    "model_args": model_args,
-                    "iter_num": iter_num,
-                    "best_val_loss": best_val_loss,
-                    "config": config,
-                }
-                print(f"saving checkpoint to {out_dir}")
-                torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
-                model_export(raw_model, os.path.join(out_dir, "model.bin"), version=0)
-        if iter_num == 0 and eval_only:
-            break
-
-    # forward backward update, with optional gradient accumulation to simulate larger batch size
-    # and using the GradScaler if data type is float16
-    for micro_step in range(gradient_accumulation_steps):
-        if ddp:
-            # in DDP training we only need to sync gradients at the last micro step.
-            # the official way to do this is with model.no_sync() context manager, but
-            # I really dislike that this bloats the code and forces us to repeat code
-            # looking at the source of that context manager, it just toggles this variable
-            model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
-        with ctx:
-            logits = model(X, Y)
-            loss = raw_model.last_loss
-            loss = loss / gradient_accumulation_steps
-        # immediately async prefetch next batch while model is doing the forward pass on the GPU
-        X, Y = next(train_batch_iter)
-        # backward pass, with gradient scaling if training in fp16
-        scaler.scale(loss).backward()
-    # clip the gradient
-    if grad_clip != 0.0:
-        scaler.unscale_(optimizer)
-        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
-    # step the optimizer and scaler if training in fp16
-    scaler.step(optimizer)
-    scaler.update()
-    # flush the gradients as soon as we can, no need for this memory anymore
-    optimizer.zero_grad(set_to_none=True)
-
-    # timing and logging
-    t1 = time.time()
-    dt = t1 - t0
-    t0 = t1
-    if iter_num % log_interval == 0 and master_process:
-        # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
-        lossf = loss.item() * gradient_accumulation_steps
-        if local_iter_num >= 5: # let the training loop settle a bit
-            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
-            running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
-        print(
-            f"{iter_num} | loss {lossf:.4f} | lr {lr:e} | {dt*1000:.2f}ms | mfu {running_mfu*100:.2f}%"
-        )
-    iter_num += 1
-    local_iter_num += 1
-
-    # termination conditions
-    if iter_num > max_iters:
-        break
+epoch = 0
+for epoch in range(max_epochs):
+    train_batch_iter = iter_batches(split="train", epoch=epoch)
+    X, Y = next(train_batch_iter) # fetch the very first batch
+    for _ in tqdm(range(num_steps)):
+
+        # determine and set the learning rate for this iteration
+        lr = get_lr(iter_num) if decay_lr else learning_rate
+        for param_group in optimizer.param_groups:
+            param_group["lr"] = lr
+
+        # evaluate the loss on train/val sets and write checkpoints
+        if iter_num % eval_interval == 0 and master_process:
+            losses = estimate_loss()
+            print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
+            if wandb_log:
+                try:
+                    wandb.log(
+                        {
+                            "iter": iter_num,
+                            "tokens": iter_num * tokens_per_iter,
+                            "loss/train": losses["train"],
+                            "loss/val": losses["val"],
+                            "lr": lr,
+                            "mfu": running_mfu * 100, # convert to percentage
+                        }, step = iter_num
+                    )
+                except Exception as e:
+                    print(f"logging to wandb failed: {e}")
+            if losses["val"] < best_val_loss or always_save_checkpoint:
+                best_val_loss = losses["val"]
+                if iter_num > 0:
+                    checkpoint = {
+                        "model": raw_model.state_dict(),
+                        "optimizer": optimizer.state_dict(),
+                        "model_args": model_args,
+                        "iter_num": iter_num,
+                        "best_val_loss": best_val_loss,
+                        "config": config,
+                    }
+                    print(f"saving checkpoint to {out_dir}")
+                    torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
+                    model_export(raw_model, os.path.join(out_dir, "model.bin"), version=0)
+            if iter_num == 0 and eval_only:
+                break
+
+        # forward backward update, with optional gradient accumulation to simulate larger batch size
+        # and using the GradScaler if data type is float16
+        for micro_step in range(gradient_accumulation_steps):
+            if ddp:
+                # in DDP training we only need to sync gradients at the last micro step.
+                # the official way to do this is with model.no_sync() context manager, but
+                # I really dislike that this bloats the code and forces us to repeat code
+                # looking at the source of that context manager, it just toggles this variable
+                model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
+            with ctx:
+                logits = model(X, Y)
+                loss = raw_model.last_loss
+                loss = loss / gradient_accumulation_steps
+            # immediately async prefetch next batch while model is doing the forward pass on the GPU
+            X, Y = next(train_batch_iter)
+            # backward pass, with gradient scaling if training in fp16
+            scaler.scale(loss).backward()
+        # clip the gradient
+        if grad_clip != 0.0:
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+        # step the optimizer and scaler if training in fp16
+        scaler.step(optimizer)
+        scaler.update()
+        # flush the gradients as soon as we can, no need for this memory anymore
+        optimizer.zero_grad(set_to_none=True)
+
+        # timing and logging
+        t1 = time.time()
+        dt = t1 - t0
+        t0 = t1
+        if iter_num % log_interval == 0 and master_process:
+            # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
+            lossf = loss.item() * gradient_accumulation_steps
+            if local_iter_num >= 5: # let the training loop settle a bit
+                mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
+                running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
+            print(
+                f"{iter_num} | loss {lossf:.4f} | lr {lr:e} | {dt*1000:.2f}ms | mfu {running_mfu*100:.2f}%"
+            )
+        iter_num += 1
+        local_iter_num += 1

 if ddp:
     destroy_process_group()
diff --git a/src/delphi/train/training_old.sh b/src/delphi/train/training_old.sh
new file mode 100644
index 00000000..451225cc
--- /dev/null
+++ b/src/delphi/train/training_old.sh
@@ -0,0 +1,6 @@
+export CUDA_DEVICE_ORDER=PCI_BUS_ID
+export CUDA_VISIBLE_DEVICES=3
+export TRANSFORMERS_CACHE=/ceph/jbrinkma/cache/transformers
+export HF_DATASETS_CACHE=/ceph/jbrinkma/cache/datasets
+
+python3 training_old.py --vocab_source=custom --vocab_size=4096 --max_seq_len=512 --dim=48 --n_layers=8 --n_heads=8 --n_kv_heads=4
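
The patch replaces the fixed max_iters budget with epoch-based bookkeeping: the number of optimizer steps per epoch is derived from the dataset via Task.get_num_batches(), and the cosine learning-rate decay is stretched over max_epochs * num_batches iterations. The sketch below reproduces that arithmetic in isolation; the num_batches value and the cosine_lr() helper are hypothetical stand-ins, and the warmup/post-decay branches of the full get_lr() are omitted.

import math

gradient_accumulation_steps = 4
max_epochs = 10
learning_rate = 5e-4
min_lr = 0.0

num_batches = 10_000  # hypothetical stand-in for the real Task.get_num_batches(...) result
num_steps = num_batches // gradient_accumulation_steps  # optimizer steps per epoch
lr_decay_iters = max_epochs * num_batches  # decay horizon, as set in the patch

def cosine_lr(it: int) -> float:
    # cosine decay from learning_rate down to min_lr over lr_decay_iters
    decay_ratio = min(it, lr_decay_iters) / lr_decay_iters
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

print(num_steps, lr_decay_iters, cosine_lr(0), cosine_lr(lr_decay_iters // 2))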
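
The inner update is structurally unchanged: gradient accumulation with an optional GradScaler and gradient clipping. Below is a minimal single-process sketch of that mechanism, with a toy linear model and random tensors standing in for the LLaMA2 model and the Task batches (no DDP; the scaler is disabled here, matching non-float16 training).

import torch

gradient_accumulation_steps = 4
grad_clip = 1.0

model = torch.nn.Linear(16, 1)  # toy stand-in for the LLaMA2 model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # enabled only when training in float16

X, Y = torch.randn(8, 16), torch.randn(8, 1)  # stand-in for a micro-batch from iter_batches

for micro_step in range(gradient_accumulation_steps):
    loss = torch.nn.functional.mse_loss(model(X), Y)
    loss = loss / gradient_accumulation_steps  # average the loss over the micro-batches
    scaler.scale(loss).backward()  # gradients accumulate across micro-steps

if grad_clip != 0.0:
    scaler.unscale_(optimizer)  # no-op when disabled; unscales fp16 gradients before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
scaler.step(optimizer)  # plain optimizer.step() when the scaler is disabled
scaler.update()
optimizer.zero_grad(set_to_none=True)  # free gradient memory right after the step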