consolidate iteration_param args
jaidhyani committed Mar 7, 2024
1 parent 85fe484 commit 1ca7a00
Showing 2 changed files with 10 additions and 13 deletions.
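
The diff below replaces four separate scalar arguments (lr_decay_iters, tokens_per_iter, num_steps, eval_iters) with a single iteration_params object. The commit does not show that object's definition; a minimal sketch of what such a container might look like, assuming a plain dataclass (the class name IterationParams and the field types are guesses inferred from the attribute accesses in the diff):

    from dataclasses import dataclass

    @dataclass
    class IterationParams:
        """Hypothetical container for the iteration settings that
        train_step reads as iteration_params.<field> in the diff below."""

        num_steps: int        # total training iterations to run
        eval_iters: int       # batches sampled when estimating train/val loss
        lr_decay_iters: int   # iterations over which the learning rate decays
        tokens_per_iter: int  # tokens consumed per training iteration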
18 changes: 9 additions & 9 deletions src/delphi/train/train_step.py
@@ -8,8 +8,7 @@
 def train_step(
     train_ds,
     validation_ds,
-    lr_decay_iters,
-    tokens_per_iter,
+    iteration_params,
     iter_num,
     best_val_loss,
     model_args,
@@ -21,8 +20,6 @@ def train_step(
     local_iter_num,
     config,
     train_batch_iter,
-    num_steps,
-    eval_iters,
 ):
     # here's how each train step works:
     # 1. Set learning rate
@@ -31,13 +28,13 @@
     # 4. log timing

     # 1. determine and set the learning rate for this iteration
-    lr = set_lr(lr_decay_iters, config, optimizer, iter_num)
+    lr = set_lr(iteration_params.lr_decay_iters, config, optimizer, iter_num)

     # 2. evaluate the loss on train/val sets and write checkpoints
     if iter_num % config.eval_interval == 0:
         losses = estimate_loss(
             model=model,
-            eval_iters=eval_iters,
+            eval_iters=iteration_params.eval_iters,
             batch_size=config.batch_size,
             split_to_ds={"train": train_ds, "val": validation_ds},
         )
@@ -47,7 +44,7 @@
             new_best_val_loss = True
         eval_data = EvalData(
             iter_num=iter_num,
-            tokens_per_iter=tokens_per_iter,
+            tokens_per_iter=iteration_params.tokens_per_iter,
             running_mfu=running_mfu,
             lr=lr,
             losses=losses,
@@ -70,10 +67,13 @@
     # 3. forward backward update, with optional gradient accumulation to simulate larger batch size
     X, Y = next(train_batch_iter)
     print(
-        f"gradient accumulation steps: {config.gradient_accumulation_steps}, num_steps: {num_steps}, iter_num: {iter_num}"
+        f"gradient accumulation steps: {config.gradient_accumulation_steps}, num_steps: {iteration_params.num_steps}, iter_num: {iter_num}"
     )
     for micro_step in range(
-        min(config.gradient_accumulation_steps, num_steps - iter_num + 1)
+        min(
+            config.gradient_accumulation_steps,
+            iteration_params.num_steps - iter_num + 1,
+        )
     ):
         logits = model(X, Y)
         loss = model.last_loss / config.gradient_accumulation_steps
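
On step 3 above: gradient accumulation runs several forward/backward micro-steps, each on its own micro-batch, before a single optimizer update, simulating a batch gradient_accumulation_steps times larger; the min(...) cap simply keeps the last iteration from running more micro-steps than remain in the num_steps budget. A minimal, self-contained sketch of the pattern with a toy torch model (get_batch and the model here are illustrative stand-ins, not the repository's API):

    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    gradient_accumulation_steps = 4

    def get_batch():
        # stand-in for a real data iterator
        return torch.randn(8, 10), torch.randn(8, 1)

    optimizer.zero_grad(set_to_none=True)
    for micro_step in range(gradient_accumulation_steps):
        X, Y = get_batch()
        loss = torch.nn.functional.mse_loss(model(X), Y)
        # scale each micro-loss so the summed gradients match one large-batch step
        (loss / gradient_accumulation_steps).backward()
    optimizer.step()  # one update for the whole simulated batch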
5 changes: 1 addition & 4 deletions src/delphi/train/training.py
@@ -65,8 +65,7 @@
     breaknow, t0, iter_num, local_iter_num, best_val_loss = train_step(
         train_ds,
         validation_ds,
-        iteration_params.lr_decay_iters,
-        iteration_params.tokens_per_iter,
+        iteration_params,
         iter_num,
         best_val_loss,
         model_args,
@@ -78,8 +77,6 @@
         local_iter_num,
         config,
         train_batch_iter,
-        iteration_params.num_steps,
-        iteration_params.eval_iters,
     )
     if breaknow:
         break
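
At the call site the consolidation is mechanical: the training loop forwards the iteration_params object whole instead of unpacking four of its fields, so a future iteration setting only needs to be added to the container and read where it is used, not threaded through the train_step signature.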
