
Commit

more refactoring
jaidhyani committed Mar 7, 2024
1 parent a29cd87 commit 78ec019
Showing 3 changed files with 139 additions and 105 deletions.
1 change: 1 addition & 0 deletions src/delphi/train/gigaconfig.py
@@ -46,6 +46,7 @@ class GigaConfig:
# learning rate decay settings
decay_lr: bool = True # whether to decay the learning rate
warmup_iters: int = 1000 # how many steps to warm up for
    min_lr: float = 0.0  # should be ~learning_rate/10 per Chinchilla


# Jai Overrides TODO: remove these
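For context on the new min_lr field: a minimal sketch of the warmup-plus-cosine schedule it feeds into, matching the get_lr signature shown in the utils.py hunk below. The body beyond the single returned expression visible in the diff is an assumption, and the example values are illustrative only.

import math

def get_lr_sketch(it, warmup_iters, learning_rate, lr_decay_iters, min_lr):
    # linear warmup to the peak learning rate
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # after the decay window, hold at the floor (min_lr)
    if it > lr_decay_iters:
        return min_lr
    # cosine decay from learning_rate down to min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes 1 -> 0
    return min_lr + coeff * (learning_rate - min_lr)

# e.g. learning_rate=5e-4 with min_lr=5e-5 keeps the floor at ~learning_rate/10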
213 changes: 114 additions & 99 deletions src/delphi/train/training.py
@@ -26,13 +26,12 @@
initialize_model,
resume_model,
save_checkpoint_if_needed,
set_lr,
)

# system
device = get_device()

# -----------------------------------------------------------------------------

# load data
train_docs_ds = load_delphi_dataset(constants.TOKENIZED_CORPUS_DATASET, "train").select(
range(256) # TODO: remove when done debugging
@@ -115,11 +114,8 @@
# optimizer
optimizer = get_optimizer(
model=model,
weight_decay=config.weight_decay,
learning_rate=config.learning_rate,
beta_1=config.beta1,
beta_2=config.beta2,
device_type=device_type,
config=config,
device=device,
checkpoint=checkpoint
if checkpoint is not None and "optimizer" in checkpoint
else None,
@@ -128,8 +124,6 @@


eval_callbacks = [save_checkpoint_if_needed]

# logging
if config.wandb_log:
wandb_utils.init_wandb(config)
eval_callbacks.append(wandb_utils.log_to_wandb)
@@ -144,21 +138,103 @@
epoch = 0


def set_lr(lr_decay_iters, min_lr, iter_num, optimizer):
    lr = (
        get_lr(
            iter_num,
            config.warmup_iters,
            config.learning_rate,
            lr_decay_iters,
            min_lr,
        )
        if config.decay_lr
        else config.learning_rate
    )
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return lr


def train_step(
    train_ds,
    validation_ds,
    lr_decay_iters,
    tokens_per_iter,
    iter_num,
    best_val_loss,
    model_args,
    model,
    optimizer,
    eval_callbacks,
    running_mfu,
    t0,
    local_iter_num,
):
    # here's how each train step works:
    # 1. Set learning rate
    # 2. (every eval_interval steps) evaluate, log to wandb, save checkpoint
    # 3. forward backward update
    # 4. log timing

    # 1. determine and set the learning rate for this iteration
    lr = set_lr(lr_decay_iters, config, optimizer, iter_num)

    # 2. evaluate the loss on train/val sets and write checkpoints
    if iter_num % config.eval_interval == 0:
        losses = estimate_loss(
            model=model,
            eval_iters=config.eval_iters,
            batch_size=config.batch_size,
            split_to_ds={"train": train_ds, "val": validation_ds},
        )
        new_best_val_loss = False
        if losses["val"] < best_val_loss or config.always_save_checkpoint:
            best_val_loss = float(losses["val"])
            new_best_val_loss = True
        eval_data = EvalData(
            iter_num=iter_num,
            tokens_per_iter=tokens_per_iter,
            running_mfu=running_mfu,
            lr=lr,
            losses=losses,
            best_val_loss=best_val_loss,
            new_best_val_loss=new_best_val_loss,
            model=model,
            model_args=model_args,
            optimizer=optimizer,
            config=config,
        )
        print(
            f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
        )
        for callback in eval_callbacks:
            callback(eval_data)

    if iter_num == 0 and config.eval_only:
        return True, None, None, None

    # 3. forward backward update, with optional gradient accumulation to simulate larger batch size
    X, Y = next(train_batch_iter)
    print(
        f"gradient accumulation steps: {config.gradient_accumulation_steps}, num_steps: {num_steps}, iter_num: {iter_num}"
    )
    for micro_step in range(
        min(config.gradient_accumulation_steps, num_steps - iter_num + 1)
    ):
        logits = model(X, Y)
        loss = model.last_loss / config.gradient_accumulation_steps
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = next(train_batch_iter)
        loss.backward()
    # clip the gradient
    if config.grad_clip != 0.0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)  # type: ignore
    optimizer.step()

    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % config.log_interval == 0:
        # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
        lossf = loss.item() * config.gradient_accumulation_steps
        if local_iter_num >= 5:  # let the training loop settle a bit
            mfu = model.estimate_mfu(
                config.batch_size * config.gradient_accumulation_steps, dt
            )
            running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
        print(
            f"{iter_num} | loss {lossf:.4f} | lr {lr:e} | {dt*1000:.2f}ms | mfu {running_mfu*100:.2f}%"
        )
    iter_num += 1
    local_iter_num += 1
    return False, t0, iter_num, local_iter_num


for epoch in range(config.max_epochs):
@@ -168,81 +244,20 @@ def set_lr(lr_decay_iters, min_lr, iter_num, optimizer):
X, Y = next(train_batch_iter)

for _ in tqdm(range(num_steps)):
        # here's how each train step works:
        # 1. Set learning rate
        # 2. (every eval_interval steps) evaluate, log to wandb, save checkpoint
        # 3. forward backward update
        # 4. log timing

        # 1. determine and set the learning rate for this iteration
        lr = set_lr(lr_decay_iters, min_lr, iter_num, optimizer)

        # 2. evaluate the loss on train/val sets and write checkpoints
        if iter_num % config.eval_interval == 0:
            losses = estimate_loss(
                model=model,
                eval_iters=config.eval_iters,
                batch_size=config.batch_size,
                split_to_ds={"train": train_ds, "val": validation_ds},
            )
            new_best_val_loss = False
            if losses["val"] < best_val_loss or config.always_save_checkpoint:
                best_val_loss = float(losses["val"])
                new_best_val_loss = True
            eval_data = EvalData(
                iter_num=iter_num,
                tokens_per_iter=tokens_per_iter,
                running_mfu=running_mfu,
                lr=lr,
                losses=losses,
                best_val_loss=best_val_loss,
                new_best_val_loss=new_best_val_loss,
                model=model,
                model_args=model_args,
                optimizer=optimizer,
                config=config,
            )
            print(
                f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
            )
            for callback in eval_callbacks:
                callback(eval_data)
        if iter_num == 0 and config.eval_only:
            break

        # forward backward update, with optional gradient accumulation to simulate larger batch size
        for micro_step in range(
            min(config.gradient_accumulation_steps, num_steps - iter_num)
        ):
            logits = model(X, Y)
            loss = model.last_loss / config.gradient_accumulation_steps
            # immediately async prefetch next batch while model is doing the forward pass on the GPU
            X, Y = next(train_batch_iter)
            loss.backward()
        # clip the gradient
        if config.grad_clip != 0.0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)  # type: ignore
        optimizer.step()

        # flush the gradients as soon as we can, no need for this memory anymore
        optimizer.zero_grad(set_to_none=True)

        # timing and logging
        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        if iter_num % config.log_interval == 0:
            # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
            lossf = loss.item() * config.gradient_accumulation_steps
            if local_iter_num >= 5:  # let the training loop settle a bit
                mfu = model.estimate_mfu(
                    config.batch_size * config.gradient_accumulation_steps, dt
                )
                running_mfu = (
                    mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
                )
            print(
                f"{iter_num} | loss {lossf:.4f} | lr {lr:e} | {dt*1000:.2f}ms | mfu {running_mfu*100:.2f}%"
            )
        iter_num += 1
        local_iter_num += 1

        breaknow, t0, iter_num, local_iter_num = train_step(
            train_ds,
            validation_ds,
            lr_decay_iters,
            tokens_per_iter,
            iter_num,
            best_val_loss,
            model_args,
            model,
            optimizer,
            eval_callbacks,
            running_mfu,
            t0,
            local_iter_num,
        )
        if breaknow:
            break
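As a standalone illustration of the gradient-accumulation pattern train_step keeps from the old loop: each micro-step backpropagates loss divided by gradient_accumulation_steps so the summed gradients approximate one large batch, and logging multiplies the last micro-loss back up. A minimal sketch with a hypothetical model and batch list; the real model exposes its loss via model.last_loss rather than an explicit cross_entropy.

import torch
import torch.nn.functional as F

def accumulate_and_step(model, optimizer, batches, accumulation_steps, grad_clip=1.0):
    # batches: list of (x, y) tensors; consume one micro-batch per micro-step
    for x, y in batches[:accumulation_steps]:
        logits = model(x)
        # divide so gradients summed over micro-steps match a single big batch
        loss = F.cross_entropy(logits, y) / accumulation_steps
        loss.backward()
    if grad_clip != 0.0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    # scale the last micro-loss back up for logging, as train_step does
    return loss.item() * accumulation_steps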
30 changes: 24 additions & 6 deletions src/delphi/train/utils.py
@@ -83,15 +83,16 @@ def resume_model(resume_from_path: Path, device: str, **model_args) -> ModelMidT

def get_optimizer(
model: Llama2Model,
weight_decay,
learning_rate,
beta_1,
beta_2,
device_type,
config: GigaConfig,
device: str,
checkpoint=None,
) -> AdamW:
device_type = "cuda" if "cuda" in device else "cpu"
optimizer = model.configure_optimizers(
weight_decay, learning_rate, (beta_1, beta_2), device_type
config.weight_decay,
config.learning_rate,
(config.beta1, config.beta2),
device_type,
)
if checkpoint is not None:
optimizer.load_state_dict(checkpoint["optimizer"])
@@ -136,6 +137,23 @@ def get_lr(it, warmup_iters, learning_rate, lr_decay_iters, min_lr):
return min_lr + coeff * (learning_rate - min_lr)


def set_lr(lr_decay_iters: int, config: GigaConfig, optimizer: AdamW, iter_num: int):
lr = (
get_lr(
iter_num,
config.warmup_iters,
config.learning_rate,
lr_decay_iters,
config.min_lr,
)
if config.decay_lr
else config.learning_rate
)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
return lr


@dataclass
class EvalData:
# values we expose to eval callback functions
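Taken together, the utils.py changes let training.py set up the optimizer and the per-iteration learning rate from config alone. A hedged usage sketch of the call pattern; the import path and the surrounding variables (model, config, device, checkpoint, lr_decay_iters, num_steps) are assumed to exist as in the hunks above.

from delphi.train.utils import get_optimizer, set_lr  # assumed import path

optimizer = get_optimizer(
    model=model,
    config=config,
    device=device,
    checkpoint=checkpoint if checkpoint is not None and "optimizer" in checkpoint else None,
)

for iter_num in range(num_steps):
    # sets every param_group's lr (warmup/cosine, floored at config.min_lr) and returns it
    lr = set_lr(lr_decay_iters, config, optimizer, iter_num)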
