diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py
index 5a7f0cc7..92a7fe29 100644
--- a/src/delphi/train/training.py
+++ b/src/delphi/train/training.py
@@ -1,25 +1,7 @@
 """
-This training script can be run both on a single gpu in debug mode,
-and also in a larger training run with distributed data parallel (ddp).
-To run on a single GPU small debug run, example:
-$ python -m train.py --compile=False --eval_iters=10 --batch_size=8
-
-To run with DDP on 4 gpus on 1 node, example:
-$ torchrun --standalone --nproc_per_node=4 train.py
-
-To run with DDP on 4 gpus across 2 nodes, example:
-- Run on the first (master) node with example IP 123.456.123.456:
-$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
-- Run on the worker node:
-$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
-(If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1)
 """
-import torch._dynamo
-
-torch._dynamo.config.suppress_errors = True
-
 import math
 import os
 import time
@@ -28,11 +10,9 @@
 from functools import partial
 
 import torch
-from torch.distributed import destroy_process_group, init_process_group
-from torch.nn.parallel import DistributedDataParallel as DDP
 from tqdm import tqdm
 
-from delphi.train.llama2 import LLaMA2, LLaMA2Args
+from llama2c.model import ModelArgs as Llama2ModelArgs, Transformer as Llama2Model
 from llama2c import Task, model_export
 
 # -----------------------------------------------------------------------------
@@ -78,7 +58,7 @@
 device = (
     "cuda"  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
 )
-dtype = "bfloat16"  # float32|bfloat16|float16
+dtype = "float32"  # float32|bfloat16|float16
 compile = False  # use PyTorch 2.0 to compile the model to be faster
 # -----------------------------------------------------------------------------
 config_keys = [
@@ -123,37 +103,16 @@
 # various inits, derived attributes, I/O setup
 seed = 1337
-ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
-if ddp:
-    init_process_group(backend="nccl")
-    ddp_rank = int(os.environ["RANK"])
-    ddp_local_rank = int(os.environ["LOCAL_RANK"])
-    ddp_world_size = int(os.environ["WORLD_SIZE"])
-    device = f"cuda:{ddp_local_rank}"
-    torch.cuda.set_device(device)
-    master_process = ddp_rank == 0  # this process will do logging, checkpointing etc.
-    seed_offset = ddp_rank  # each process gets a different seed
-    # world_size number of processes will be training simultaneously, so we can scale
-    # down the desired gradient accumulation iterations per process proportionally
-    assert gradient_accumulation_steps % ddp_world_size == 0
-    gradient_accumulation_steps //= ddp_world_size
-else:
-    # if not ddp, we are running on a single gpu, and one process
-    master_process = True
-    seed_offset = 0
-    ddp_world_size = 1
 tokens_per_iter = (
-    gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
+    gradient_accumulation_steps * batch_size * max_seq_len
+)
+print(f"tokens per iteration will be: {tokens_per_iter:,}")
+print(
+    f"breaks down as: {gradient_accumulation_steps} grad accum steps * {batch_size} batch size * {max_seq_len} max seq len"
 )
-if master_process:
-    print(f"tokens per iteration will be: {tokens_per_iter:,}")
-    print(
-        f"breaks down as: {gradient_accumulation_steps} grad accum steps * {ddp_world_size} processes * {batch_size} batch size * {max_seq_len} max seq len"
-    )
-if master_process:
-    os.makedirs(out_dir, exist_ok=True)
-torch.manual_seed(seed + seed_offset)
+os.makedirs(out_dir, exist_ok=True)
+torch.manual_seed(seed)
 torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
 torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
 device_type = "cuda" if "cuda" in device else "cpu"  # for later use in torch.autocast
@@ -163,11 +122,7 @@
     "bfloat16": torch.bfloat16,
     "float16": torch.float16,
 }[dtype]
-ctx = (
-    nullcontext()
-    if device_type == "cpu"
-    else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
-)
+
 
 # task-specific setup
 iter_batches = partial(
@@ -199,8 +154,8 @@
 if init_from == "scratch":
     # init a new model from scratch
     print("Initializing a new model from scratch")
-    gptconf = LLaMA2Args(**model_args)
-    model = LLaMA2(gptconf)
+    gptconf = Llama2ModelArgs(**model_args)
+    model = Llama2Model(gptconf)
 elif init_from == "resume":
     print(f"Resuming training from {out_dir}")
     # resume training from a checkpoint.
@@ -220,8 +175,8 @@
     ]:
         model_args[k] = checkpoint_model_args[k]
     # create the model
-    gptconf = LLaMA2Args(**model_args)
-    model = LLaMA2(gptconf)
+    gptconf = Llama2ModelArgs(**model_args)
+    model = Llama2Model(gptconf)
     state_dict = checkpoint["model"]
     # fix the keys of the state dictionary :(
     # honestly no idea how checkpoints sometimes get this prefix, have to debug more
@@ -234,8 +189,6 @@
     best_val_loss = checkpoint["best_val_loss"]
 model.to(device)
 
-# initialize a GradScaler. If enabled=False scaler is a no-op
-scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))
 
 # optimizer
 optimizer = model.configure_optimizers(
@@ -245,20 +198,7 @@
     optimizer.load_state_dict(checkpoint["optimizer"])
 checkpoint = None  # free up memory
 
-# compile the model
-if compile:
-    print("compiling the model... (takes a ~minute)")
-    unoptimized_model = model
-    model = torch.compile(model)  # requires PyTorch 2.0
-
-# wrap model into DDP container
-if ddp:
-    # Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
-    # construction time since NCCL does not support `ComplexFloat`
-    prefix = "_orig_mod." if compile else ""
-    model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
-    model = DDP(model, device_ids=[ddp_local_rank])
-
 
 # helps estimate an arbitrarily accurate loss over either split using many batches
 @torch.no_grad()
@@ -270,9 +210,8 @@ def estimate_loss():
         losses = torch.zeros(eval_iters)  # keep on CPU
         for k in range(eval_iters):
             X, Y = next(batch_iter)
-            with ctx:
-                logits = model(X, Y)
-                loss = raw_model.last_loss
+            logits = model(X, Y)
+            loss = model.last_loss
             losses[k] = loss.item()
         out[split] = losses.mean()
     model.train()
@@ -295,7 +234,7 @@ def get_lr(it):
 
 
 # logging
-if wandb_log and master_process:
+if wandb_log:
     import wandb
 
     wandb.init(
@@ -305,7 +244,6 @@ def get_lr(it):
 # training loop
 t0 = time.time()
 local_iter_num = 0  # number of iterations in the lifetime of this process
-raw_model = model.module if ddp else model  # unwrap DDP container if needed
 running_mfu = -1.0
 epoch = 0
 for epoch in range(max_epochs):
@@ -318,7 +256,7 @@ def get_lr(it):
         param_group["lr"] = lr
 
     # evaluate the loss on train/val sets and write checkpoints
-    if iter_num % eval_interval == 0 and master_process:
+    if iter_num % eval_interval == 0:
         losses = estimate_loss()
         print(
             f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
@@ -342,7 +280,7 @@ def get_lr(it):
             best_val_loss = losses["val"]
             if iter_num > 0:
                 checkpoint = {
-                    "model": raw_model.state_dict(),
+                    "model": model.state_dict(),
                     "optimizer": optimizer.state_dict(),
                     "model_args": model_args,
                     "iter_num": iter_num,
@@ -352,7 +290,7 @@ def get_lr(it):
                 print(f"saving checkpoint to {out_dir}")
                 torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
                 model_export(
-                    raw_model, os.path.join(out_dir, "model.bin"), version=0
+                    model, os.path.join(out_dir, "model.bin"), version=0
                 )
         if iter_num == 0 and eval_only:
             break
@@ -360,29 +298,19 @@ def get_lr(it):
     # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
     for micro_step in range(gradient_accumulation_steps):
-        if ddp:
-            # in DDP training we only need to sync gradients at the last micro step.
-            # the official way to do this is with model.no_sync() context manager, but
-            # I really dislike that this bloats the code and forces us to repeat code
-            # looking at the source of that context manager, it just toggles this variable
-            model.require_backward_grad_sync = (
-                micro_step == gradient_accumulation_steps - 1
-            )
-        with ctx:
-            logits = model(X, Y)
-            loss = raw_model.last_loss
-            loss = loss / gradient_accumulation_steps
+
+        logits = model(X, Y)
+        loss = model.last_loss
+        loss = loss / gradient_accumulation_steps
         # immediately async prefetch next batch while model is doing the forward pass on the GPU
         X, Y = next(train_batch_iter)
         # backward pass, with gradient scaling if training in fp16
-        scaler.scale(loss).backward()
+        loss.backward()
     # clip the gradient
     if grad_clip != 0.0:
-        scaler.unscale_(optimizer)
         torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
     # step the optimizer and scaler if training in fp16
-    scaler.step(optimizer)
-    scaler.update()
+    optimizer.step()
     # flush the gradients as soon as we can, no need for this memory anymore
     optimizer.zero_grad(set_to_none=True)
@@ -390,11 +318,11 @@ def get_lr(it):
     t1 = time.time()
     dt = t1 - t0
     t0 = t1
-    if iter_num % log_interval == 0 and master_process:
+    if iter_num % log_interval == 0:
         # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
         lossf = loss.item() * gradient_accumulation_steps
         if local_iter_num >= 5:  # let the training loop settle a bit
-            mfu = raw_model.estimate_mfu(
+            mfu = model.estimate_mfu(
                 batch_size * gradient_accumulation_steps, dt
             )
             running_mfu = (
@@ -406,5 +334,4 @@ def get_lr(it):
     iter_num += 1
     local_iter_num += 1
 
-if ddp:
-    destroy_process_group()
+i
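
# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the diff above): the plain single-process,
# fp32 update step that the simplified loop performs -- gradient accumulation
# without autocast/GradScaler, then gradient clipping and optimizer.step().
# The toy linear model and random data are stand-ins; only the names
# gradient_accumulation_steps and grad_clip mirror training.py.
import torch

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
gradient_accumulation_steps = 4
grad_clip = 1.0

for micro_step in range(gradient_accumulation_steps):
    X, Y = torch.randn(8, 16), torch.randn(8, 1)
    # scale the loss so the accumulated gradient averages over the micro-steps
    loss = torch.nn.functional.mse_loss(model(X), Y) / gradient_accumulation_steps
    loss.backward()  # accumulates into .grad; no GradScaler needed in fp32

if grad_clip != 0.0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()
optimizer.zero_grad(set_to_none=True)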