improve indentation to make things more readable
Tamas Tompa committed Oct 23, 2024
1 parent d1b678c commit e1da8d1
Showing 1 changed file with 37 additions and 32 deletions.
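
The main readability change in this diff, beyond re-indentation, is replacing the `if rank == 0:` wrapper around the checkpoint logic with an early return (a guard clause), which drops one level of nesting. A minimal standalone sketch of that pattern, using hypothetical names rather than the actual MACE function signature:

def maybe_checkpoint(rank: int, valid_loss: float, lowest_loss: float):
    # Guard clause: non-zero ranks return immediately instead of
    # wrapping everything below in "if rank == 0:".
    if rank != 0:
        return False, lowest_loss

    # Only rank 0 reaches this point, one indentation level shallower.
    if valid_loss < lowest_loss:
        lowest_loss = valid_loss
        print("new best validation loss, a checkpoint would be saved here")
    return False, lowest_loss

print(maybe_checkpoint(rank=1, valid_loss=0.4, lowest_loss=0.5))  # (False, 0.5)
print(maybe_checkpoint(rank=0, valid_loss=0.4, lowest_loss=0.5))  # (False, 0.4)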
mace/tools/train.py (69 changes: 37 additions & 32 deletions)
@@ -455,6 +455,7 @@ def validate_and_checkpoint(
     )
     if "ScheduleFree" in type(optimizer).__name__:
         optimizer.eval()
+
     with param_context:
         valid_loss = 0.0
         wandb_log_dict = {}
@@ -487,50 +487,54 @@ def validate_and_checkpoint(
         valid_loss = (
             valid_loss_head  # consider only the last head for the checkpoint
         )
+
     if log_wandb:
         import wandb
         wandb.log(wandb_log_dict)
-    if rank == 0:
-        if valid_loss >= lowest_loss:
-            patience_counter += 1
-            if patience_counter >= patience:
-                if swa is not None and epoch < swa.start:
-                    logging.info(
-                        f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two"
-                    )
-                    epoch = swa.start
-                else:
-                    logging.info(
-                        f"Stopping optimization after {patience_counter} epochs without improvement"
-                    )
-                    return True, epoch, lowest_loss, patience_counter
-            if save_all_checkpoints:
-                param_context = (
-                    ema.average_parameters()
-                    if ema is not None
-                    else nullcontext()
-                )
-                with param_context:
-                    checkpoint_handler.save(
-                        state=CheckpointState(model, optimizer, lr_scheduler),
-                        epochs=epoch,
-                        keep_last=True,
-                    )
-        else:
-            lowest_loss = valid_loss
-            patience_counter = 0
-            param_context = (
-                ema.average_parameters() if ema is not None else nullcontext()
-            )
-            with param_context:
-                checkpoint_handler.save(
-                    state=CheckpointState(model, optimizer, lr_scheduler),
-                    epochs=epoch,
-                    keep_last=keep_last,
-                )
-            keep_last = False or save_all_checkpoints
-    return False, epoch, lowest_loss, patience_counter
+
+    if rank != 0:
+        return False, epoch, lowest_loss, patience_counter
+
+    if valid_loss >= lowest_loss:
+        patience_counter += 1
+        if patience_counter >= patience:
+            if swa is not None and epoch < swa.start:
+                logging.info(
+                    f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two"
+                )
+                epoch = swa.start
+            else:
+                logging.info(
+                    f"Stopping optimization after {patience_counter} epochs without improvement"
+                )
+                return True, epoch, lowest_loss, patience_counter
+        if save_all_checkpoints:
+            param_context = (
+                ema.average_parameters() if ema is not None else nullcontext()
+            )
+            with param_context:
+                checkpoint_handler.save(
+                    state=CheckpointState(model, optimizer, lr_scheduler),
+                    epochs=epoch,
+                    keep_last=True,
+                )
+    else:
+        lowest_loss = valid_loss
+        patience_counter = 0
+        param_context = (
+            ema.average_parameters() if ema is not None else nullcontext()
+        )
+        with param_context:
+            checkpoint_handler.save(
+                state=CheckpointState(model, optimizer, lr_scheduler),
+                epochs=epoch,
+                keep_last=keep_last,
+            )
+        keep_last = False or save_all_checkpoints
+
+    return False, epoch, lowest_loss, patience_counter


def evaluate(
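
The bulk of the second hunk is the patience-based early-stopping logic: when the validation loss fails to improve for `patience` epochs, training either jumps the epoch counter to the Stage Two start (the `swa.start` epoch in the diff) if that stage has not begun yet, or stops. A compact sketch of that decision, using toy names and return values that are assumptions, not the real `validate_and_checkpoint` interface:

def early_stop_decision(valid_loss, lowest_loss, patience_counter, patience, stage_two_started):
    # No improvement: bump the counter, then either switch stages or stop.
    if valid_loss >= lowest_loss:
        patience_counter += 1
        if patience_counter >= patience:
            if not stage_two_started:
                return "start_stage_two", lowest_loss, patience_counter
            return "stop", lowest_loss, patience_counter
        return "continue", lowest_loss, patience_counter
    # Improvement: remember the new best loss and reset the counter.
    return "continue", valid_loss, 0

print(early_stop_decision(1.2, 1.0, 2, 3, stage_two_started=False))  # ('start_stage_two', 1.0, 3)
print(early_stop_decision(0.8, 1.0, 2, 3, stage_two_started=False))  # ('continue', 0.8, 0)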

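One more idiom worth noting from the hunk: `param_context = ema.average_parameters() if ema is not None else nullcontext()`. When an EMA object exists, its `average_parameters()` context manager temporarily swaps the averaged weights in for validation and checkpointing; otherwise `contextlib.nullcontext()` acts as a do-nothing stand-in, so the `with` block reads identically in both cases. A small sketch of that idiom with a dummy context manager in place of the EMA object (hypothetical names, not MACE or torch-ema code):

from contextlib import contextmanager, nullcontext

@contextmanager
def dummy_average_parameters():
    # Stand-in for ema.average_parameters(): swap averaged weights in,
    # then restore the originals on exit.
    print("averaged weights swapped in")
    try:
        yield
    finally:
        print("original weights restored")

def run_validation(ema_available: bool) -> None:
    # Same pattern as the diff: pick a real context manager or a no-op one.
    param_context = dummy_average_parameters() if ema_available else nullcontext()
    with param_context:
        print("validating / saving checkpoint")

run_validation(ema_available=True)
run_validation(ema_available=False)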