train only readouts with lbfgs
ttompa committed Dec 19, 2024
1 parent 5eae9d7 · commit 9cb5d65
Showing 2 changed files with 7 additions and 8 deletions.
mace/cli/run_train.py (1 addition, 1 deletion)

@@ -612,7 +612,7 @@ def run(args: argparse.Namespace) -> None:
         batch_mode = args.lbfgs_config.get("batch_mode", False)

         logging.info("Switching optimizer to LBFGS")
-        optimizer = LBFGSNew(model.parameters(),
+        optimizer = LBFGSNew(model.readouts.parameters(),
                              tolerance_grad=1e-6,
                              history_size=history_size,
                              max_iter=max_iter,
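The change above hands only the readout-head parameters to the L-BFGS optimizer, so the rest of the model stays fixed during this phase. Below is a minimal, self-contained sketch of that pattern. It uses torch.optim.LBFGS as a stand-in for LBFGSNew (whose import path and batch_mode option belong to the repository, not to PyTorch), and TinyModel, history_size=10, and max_iter=20 are placeholder values for illustration only.

import torch

class TinyModel(torch.nn.Module):
    """Hypothetical model: a backbone plus a readouts head, mirroring model.readouts."""
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(8, 8)
        self.readouts = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.readouts(self.backbone(x))

model = TinyModel()

# Only model.readouts.parameters() is passed, so the optimizer never updates the backbone.
optimizer = torch.optim.LBFGS(
    model.readouts.parameters(),
    tolerance_grad=1e-6,   # as in the diff
    history_size=10,       # placeholder for the configured history_size
    max_iter=20,           # placeholder for the configured max_iter
)

Optionally, requires_grad can also be set to False on the backbone parameters so the backward pass skips gradients the optimizer will never use; the diff itself only restricts the parameter group.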
mace/tools/train.py (6 additions, 7 deletions)

@@ -234,6 +234,12 @@ def train(

         # Validate
         if epoch % eval_interval == 0:
+            logging.info("GPU Memory Report:")
+            logging.info(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
+            logging.info(f"Cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
+            logging.info(f"Total: {torch.cuda.get_device_properties(0).total_memory / 1024**2:.2f} MB")
+            logging.info(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
+            logging.info(f"Free memory: {torch.cuda.mem_get_info()[0] / (1024**2):.2f} MB")
             model_to_evaluate = (
                 model if distributed_model is None else distributed_model
             )
@@ -365,13 +371,6 @@ def take_step(
     start_time = time.time()
     batch = batch.to(device)
     batch_dict = batch.to_dict()
-
-    logging.info("GPU Memory Report:")
-    logging.info(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
-    logging.info(f"Cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
-    logging.info(f"Total: {torch.cuda.get_device_properties(0).total_memory / 1024**2:.2f} MB")
-    logging.info(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
-    logging.info(f"Free memory: {torch.cuda.mem_get_info()[0] / (1024**2):.2f} MB")

     def closure():
         optimizer.zero_grad(set_to_none=True)
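The closure kept in take_step is what L-BFGS-style optimizers require: they may re-evaluate the loss and gradients several times within a single step. A minimal sketch of the pattern follows; model, x, y, and the MSE loss are placeholders standing in for the MACE model, batch_dict, and loss function used in the repository.

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.LBFGS(model.parameters(), max_iter=20, tolerance_grad=1e-6)
x, y = torch.randn(32, 8), torch.randn(32, 1)

def closure():
    # Re-evaluate the loss and its gradients; LBFGS may call this several times per step.
    optimizer.zero_grad(set_to_none=True)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

final_loss = optimizer.step(closure)  # step() invokes the closure and returns its loss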

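Since the same six logging calls are now emitted from the validation branch, one tidy option (not part of this commit) is to wrap them in a small helper and call it once per eval_interval. A hypothetical sketch, assuming a single-GPU setup as implied by get_device_properties(0):

import logging
import torch

def log_gpu_memory() -> None:
    """Hypothetical helper wrapping the torch.cuda calls used in the diff."""
    if not torch.cuda.is_available():
        return  # nothing to report on CPU-only runs
    mb = 1024 ** 2
    logging.info("GPU Memory Report:")
    logging.info(f"Allocated: {torch.cuda.memory_allocated() / mb:.2f} MB")
    logging.info(f"Cached: {torch.cuda.memory_reserved() / mb:.2f} MB")
    logging.info(f"Total: {torch.cuda.get_device_properties(0).total_memory / mb:.2f} MB")
    logging.info(f"Max Allocated: {torch.cuda.max_memory_allocated() / mb:.2f} MB")
    logging.info(f"Free memory: {torch.cuda.mem_get_info()[0] / mb:.2f} MB")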