From 78ec019866fdb94da8f084304b3f9c5951090cfd Mon Sep 17 00:00:00 2001
From: JaiDhyani
Date: Wed, 6 Mar 2024 17:22:16 -0800
Subject: [PATCH] more refactoring

---
 src/delphi/train/gigaconfig.py |   1 +
 src/delphi/train/training.py   | 213 ++++++++++++++++++---------------
 src/delphi/train/utils.py      |  30 ++++-
 3 files changed, 139 insertions(+), 105 deletions(-)

diff --git a/src/delphi/train/gigaconfig.py b/src/delphi/train/gigaconfig.py
index a99ad4b3..f9a35dfe 100644
--- a/src/delphi/train/gigaconfig.py
+++ b/src/delphi/train/gigaconfig.py
@@ -46,6 +46,7 @@ class GigaConfig:
 
     # learning rate decay settings
     decay_lr: bool = True  # whether to decay the learning rate
     warmup_iters: int = 1000  # how many steps to warm up for
+    min_lr: float = 0.0  # should be ~learning_rate/10 per Chinchilla
 
     # Jai Overrides TODO: remove these
diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py
index b77f4bc4..23e363f1 100644
--- a/src/delphi/train/training.py
+++ b/src/delphi/train/training.py
@@ -26,13 +26,12 @@
     initialize_model,
     resume_model,
     save_checkpoint_if_needed,
+    set_lr,
 )
 
 # system
 device = get_device()
-# -----------------------------------------------------------------------------
-
 # load data
 train_docs_ds = load_delphi_dataset(constants.TOKENIZED_CORPUS_DATASET, "train").select(
     range(256)  # TODO: remove when done debugging
 )
@@ -115,11 +114,8 @@
 # optimizer
 optimizer = get_optimizer(
     model=model,
-    weight_decay=config.weight_decay,
-    learning_rate=config.learning_rate,
-    beta_1=config.beta1,
-    beta_2=config.beta2,
-    device_type=device_type,
+    config=config,
+    device=device,
     checkpoint=checkpoint
     if checkpoint is not None and "optimizer" in checkpoint
     else None,
 )
@@ -128,8 +124,6 @@
 
 eval_callbacks = [save_checkpoint_if_needed]
 
-
-# logging
 if config.wandb_log:
     wandb_utils.init_wandb(config)
     eval_callbacks.append(wandb_utils.log_to_wandb)
@@ -144,21 +138,103 @@
 epoch = 0
 
 
-def set_lr(lr_decay_iters, min_lr, iter_num, optimizer):
-    lr = (
-        get_lr(
-            iter_num,
-            config.warmup_iters,
-            config.learning_rate,
-            lr_decay_iters,
-            min_lr,
+def train_step(
+    train_ds,
+    validation_ds,
+    lr_decay_iters,
+    tokens_per_iter,
+    iter_num,
+    best_val_loss,
+    model_args,
+    model,
+    optimizer,
+    eval_callbacks,
+    running_mfu,
+    t0,
+    local_iter_num,
+):
+    # here's how each train step works:
+    # 1. Set learning rate
+    # 2. (every eval_interval steps) evaluate, log to wandb, save checkpoint
+    # 3. forward backward update
+    # 4. log timing
+
+    # 1. determine and set the learning rate for this iteration
+    lr = set_lr(lr_decay_iters, config, optimizer, iter_num)
+
+    # 2. evaluate the loss on train/val sets and write checkpoints
+    if iter_num % config.eval_interval == 0:
+        losses = estimate_loss(
+            model=model,
+            eval_iters=config.eval_iters,
+            batch_size=config.batch_size,
+            split_to_ds={"train": train_ds, "val": validation_ds},
+        )
+        new_best_val_loss = False
+        if losses["val"] < best_val_loss or config.always_save_checkpoint:
+            best_val_loss = float(losses["val"])
+            new_best_val_loss = True
+        eval_data = EvalData(
+            iter_num=iter_num,
+            tokens_per_iter=tokens_per_iter,
+            running_mfu=running_mfu,
+            lr=lr,
+            losses=losses,
+            best_val_loss=best_val_loss,
+            new_best_val_loss=new_best_val_loss,
+            model=model,
+            model_args=model_args,
+            optimizer=optimizer,
+            config=config,
         )
-        if config.decay_lr
-        else config.learning_rate
+        print(
+            f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
+        )
+        for callback in eval_callbacks:
+            callback(eval_data)
+
+    if iter_num == 0 and config.eval_only:
+        return True, None, None, None
+
+    # 3. forward backward update, with optional gradient accumulation to simulate larger batch size
+    X, Y = next(train_batch_iter)
+    print(
+        f"gradient accumulation steps: {config.gradient_accumulation_steps}, num_steps: {num_steps}, iter_num: {iter_num}"
     )
-    for param_group in optimizer.param_groups:
-        param_group["lr"] = lr
-    return lr
+    for micro_step in range(
+        min(config.gradient_accumulation_steps, num_steps - iter_num + 1)
+    ):
+        logits = model(X, Y)
+        loss = model.last_loss / config.gradient_accumulation_steps
+        # immediately async prefetch next batch while model is doing the forward pass on the GPU
+        X, Y = next(train_batch_iter)
+        loss.backward()
+    # clip the gradient
+    if config.grad_clip != 0.0:
+        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)  # type: ignore
+    optimizer.step()
+
+    # flush the gradients as soon as we can, no need for this memory anymore
+    optimizer.zero_grad(set_to_none=True)
+
+    # timing and logging
+    t1 = time.time()
+    dt = t1 - t0
+    t0 = t1
+    if iter_num % config.log_interval == 0:
+        # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
+        lossf = loss.item() * config.gradient_accumulation_steps
+        if local_iter_num >= 5:  # let the training loop settle a bit
+            mfu = model.estimate_mfu(
+                config.batch_size * config.gradient_accumulation_steps, dt
+            )
+            running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
+        print(
+            f"{iter_num} | loss {lossf:.4f} | lr {lr:e} | {dt*1000:.2f}ms | mfu {running_mfu*100:.2f}%"
+        )
+    iter_num += 1
+    local_iter_num += 1
+    return False, t0, iter_num, local_iter_num
 
 
 for epoch in range(config.max_epochs):
@@ -168,81 +244,20 @@ def set_lr(lr_decay_iters, min_lr, iter_num, optimizer):
     X, Y = next(train_batch_iter)
 
     for _ in tqdm(range(num_steps)):
-        # here's how each train step works:
-        # 1. Set learning rate
-        # 2. (every eval_interval steps) evaluate, log to wandb, save checkpoint
-        # 3. forward backward update
-        # 4. log timing
-
-        # 1. determine and set the learning rate for this iteration
-        lr = set_lr(lr_decay_iters, min_lr, iter_num, optimizer)
-
-        # 2. evaluate the loss on train/val sets and write checkpoints
-        if iter_num % config.eval_interval == 0:
-            losses = estimate_loss(
-                model=model,
-                eval_iters=config.eval_iters,
-                batch_size=config.batch_size,
-                split_to_ds={"train": train_ds, "val": validation_ds},
-            )
-            new_best_val_loss = False
-            if losses["val"] < best_val_loss or config.always_save_checkpoint:
-                best_val_loss = float(losses["val"])
-                new_best_val_loss = True
-            eval_data = EvalData(
-                iter_num=iter_num,
-                tokens_per_iter=tokens_per_iter,
-                running_mfu=running_mfu,
-                lr=lr,
-                losses=losses,
-                best_val_loss=best_val_loss,
-                new_best_val_loss=new_best_val_loss,
-                model=model,
-                model_args=model_args,
-                optimizer=optimizer,
-                config=config,
-            )
-            print(
-                f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
-            )
-            for callback in eval_callbacks:
-                callback(eval_data)
-        if iter_num == 0 and config.eval_only:
+        breaknow, t0, iter_num, local_iter_num = train_step(
+            train_ds,
+            validation_ds,
+            lr_decay_iters,
+            tokens_per_iter,
+            iter_num,
+            best_val_loss,
+            model_args,
+            model,
+            optimizer,
+            eval_callbacks,
+            running_mfu,
+            t0,
+            local_iter_num,
+        )
+        if breaknow:
             break
-
-        # forward backward update, with optional gradient accumulation to simulate larger batch size
-        for micro_step in range(
-            min(config.gradient_accumulation_steps, num_steps - iter_num)
-        ):
-            logits = model(X, Y)
-            loss = model.last_loss / config.gradient_accumulation_steps
-            # immediately async prefetch next batch while model is doing the forward pass on the GPU
-            X, Y = next(train_batch_iter)
-            loss.backward()
-        # clip the gradient
-        if config.grad_clip != 0.0:
-            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)  # type: ignore
-        optimizer.step()
-
-        # flush the gradients as soon as we can, no need for this memory anymore
-        optimizer.zero_grad(set_to_none=True)
-
-        # timing and logging
-        t1 = time.time()
-        dt = t1 - t0
-        t0 = t1
-        if iter_num % config.log_interval == 0:
-            # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
-            lossf = loss.item() * config.gradient_accumulation_steps
-            if local_iter_num >= 5:  # let the training loop settle a bit
-                mfu = model.estimate_mfu(
-                    config.batch_size * config.gradient_accumulation_steps, dt
-                )
-                running_mfu = (
-                    mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
-                )
-            print(
-                f"{iter_num} | loss {lossf:.4f} | lr {lr:e} | {dt*1000:.2f}ms | mfu {running_mfu*100:.2f}%"
-            )
-        iter_num += 1
-        local_iter_num += 1
diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py
index ba79b2ee..72ac3f3f 100644
--- a/src/delphi/train/utils.py
+++ b/src/delphi/train/utils.py
@@ -83,15 +83,16 @@ def resume_model(resume_from_path: Path, device: str, **model_args) -> ModelMidT
 
 def get_optimizer(
     model: Llama2Model,
-    weight_decay,
-    learning_rate,
-    beta_1,
-    beta_2,
-    device_type,
+    config: GigaConfig,
+    device: str,
     checkpoint=None,
 ) -> AdamW:
+    device_type = "cuda" if "cuda" in device else "cpu"
     optimizer = model.configure_optimizers(
-        weight_decay, learning_rate, (beta_1, beta_2), device_type
+        config.weight_decay,
+        config.learning_rate,
+        (config.beta1, config.beta2),
+        device_type,
     )
     if checkpoint is not None:
         optimizer.load_state_dict(checkpoint["optimizer"])
@@ -136,6 +137,23 @@ def get_lr(it, warmup_iters, learning_rate, lr_decay_iters, min_lr):
     return min_lr + coeff * (learning_rate - min_lr)
 
 
+def set_lr(lr_decay_iters: int, config: GigaConfig, optimizer: AdamW, iter_num: int):
+    lr = (
+        get_lr(
+            iter_num,
+            config.warmup_iters,
+            config.learning_rate,
+            lr_decay_iters,
+            config.min_lr,
+        )
+        if config.decay_lr
+        else config.learning_rate
+    )
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+    return lr
+
+
 @dataclass
 class EvalData:
     # values we expose to eval callback functions
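
A note on the schedule the new `set_lr` applies: it recomputes the learning rate every
iteration and writes it into each optimizer param group. Only `get_lr`'s signature and its
final `return min_lr + coeff * (learning_rate - min_lr)` line appear in this patch, so the
sketch below assumes the usual nanoGPT-style linear warmup followed by cosine decay for
`coeff`; the hyperparameter values are hypothetical (min_lr = learning_rate / 10 follows
the Chinchilla comment in gigaconfig.py) and are for illustration only.

    import math

    import torch

    def get_lr(it, warmup_iters, learning_rate, lr_decay_iters, min_lr):
        # Assumed nanoGPT-style schedule; only the last line is visible in the patch.
        if it < warmup_iters:  # linear warmup from 0 up to learning_rate
            return learning_rate * it / warmup_iters
        if it > lr_decay_iters:  # past the decay horizon, hold at the floor
            return min_lr
        decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # decays 1 -> 0
        return min_lr + coeff * (learning_rate - min_lr)

    # What set_lr then does with the result: push it into every param group.
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    for iter_num in (0, 500, 1000, 5500, 10_000):
        lr = get_lr(iter_num, warmup_iters=1000, learning_rate=1e-3,
                    lr_decay_iters=10_000, min_lr=1e-4)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        print(f"iter {iter_num:>6}: lr {lr:.2e}")

Under these assumptions the rate is 0 at iter 0, ramps linearly to 1e-3 at the end of
warmup, then decays along a cosine to the 1e-4 floor at lr_decay_iters.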