diff --git a/src/delphi/train/train_step.py b/src/delphi/train/train_step.py
index 04074cbc..6acc40ae 100644
--- a/src/delphi/train/train_step.py
+++ b/src/delphi/train/train_step.py
@@ -8,8 +8,7 @@
 def train_step(
     train_ds,
     validation_ds,
-    lr_decay_iters,
-    tokens_per_iter,
+    iteration_params,
     iter_num,
     best_val_loss,
     model_args,
@@ -21,8 +20,6 @@ def train_step(
     local_iter_num,
     config,
     train_batch_iter,
-    num_steps,
-    eval_iters,
 ):
     # here's how each train step works:
     # 1. Set learning rate
@@ -31,13 +28,13 @@
     # 4. log timing
 
     # 1. determine and set the learning rate for this iteration
-    lr = set_lr(lr_decay_iters, config, optimizer, iter_num)
+    lr = set_lr(iteration_params.lr_decay_iters, config, optimizer, iter_num)
 
     # 2. evaluate the loss on train/val sets and write checkpoints
     if iter_num % config.eval_interval == 0:
         losses = estimate_loss(
             model=model,
-            eval_iters=eval_iters,
+            eval_iters=iteration_params.eval_iters,
             batch_size=config.batch_size,
             split_to_ds={"train": train_ds, "val": validation_ds},
         )
@@ -47,7 +44,7 @@
             new_best_val_loss = True
         eval_data = EvalData(
             iter_num=iter_num,
-            tokens_per_iter=tokens_per_iter,
+            tokens_per_iter=iteration_params.tokens_per_iter,
             running_mfu=running_mfu,
             lr=lr,
             losses=losses,
@@ -70,10 +67,13 @@
     # 3. forward backward update, with optional gradient accumulation to simulate larger batch size
     X, Y = next(train_batch_iter)
     print(
-        f"gradient accumulation steps: {config.gradient_accumulation_steps}, num_steps: {num_steps}, iter_num: {iter_num}"
+        f"gradient accumulation steps: {config.gradient_accumulation_steps}, num_steps: {iteration_params.num_steps}, iter_num: {iter_num}"
     )
     for micro_step in range(
-        min(config.gradient_accumulation_steps, num_steps - iter_num + 1)
+        min(
+            config.gradient_accumulation_steps,
+            iteration_params.num_steps - iter_num + 1,
+        )
     ):
         logits = model(X, Y)
         loss = model.last_loss / config.gradient_accumulation_steps
diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py
index b62d43c9..96616bff 100644
--- a/src/delphi/train/training.py
+++ b/src/delphi/train/training.py
@@ -65,8 +65,7 @@
         breaknow, t0, iter_num, local_iter_num, best_val_loss = train_step(
             train_ds,
             validation_ds,
-            iteration_params.lr_decay_iters,
-            iteration_params.tokens_per_iter,
+            iteration_params,
             iter_num,
             best_val_loss,
             model_args,
@@ -78,8 +77,6 @@
             local_iter_num,
             config,
             train_batch_iter,
-            iteration_params.num_steps,
-            iteration_params.eval_iters,
         )
         if breaknow:
             break
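Both call sites now pass one object in place of four loose scalars. The patch itself only confirms that iteration_params exposes lr_decay_iters, tokens_per_iter, num_steps, and eval_iters; the sketch below is a minimal guess at such a container, assuming a dataclass. The name IterationParams and the example values are illustrative, not taken from this patch.

    from dataclasses import dataclass

    # Hypothetical container for the per-run iteration settings this patch
    # threads through train_step; only the four field names are confirmed
    # by the diff, everything else here is an assumption.
    @dataclass(frozen=True)
    class IterationParams:
        num_steps: int        # caps gradient-accumulation micro-steps near the end of training
        eval_iters: int       # forwarded to estimate_loss
        lr_decay_iters: int   # forwarded to set_lr for the LR schedule
        tokens_per_iter: int  # reported in EvalData

    # Hypothetical construction mirroring the call site in training.py:
    iteration_params = IterationParams(
        num_steps=2000,
        eval_iters=100,
        lr_decay_iters=2000,
        tokens_per_iter=32_768,
    )

Grouping the values this way keeps the already long train_step signature stable if further schedule parameters are added later.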