Skip to content

Commit

Permalink
[references] Update Logging (#1847)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Jan 23, 2025
1 parent ebfc9f3 commit d41ba35
Show file tree
Hide file tree
Showing 14 changed files with 802 additions and 367 deletions.
124 changes: 86 additions & 38 deletions references/classification/train_pytorch_character.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,14 @@ def record_lr(
return lr_recorder[: len(loss_recorder)], loss_recorder


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, clearml_log=False):
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None):
if amp:
scaler = torch.cuda.amp.GradScaler()

model.train()
if clearml_log:
from clearml import Logger

logger = Logger.current_logger()

# Iterate over the batches of the dataset
pbar = tqdm(train_loader, position=1)
epoch_train_loss, batch_cnt = 0.0, 0.0
pbar = tqdm(train_loader, dynamic_ncols=True)
for images, targets in pbar:
if torch.cuda.is_available():
images = images.cuda()
Expand All @@ -143,24 +139,28 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a
train_loss = cross_entropy(out, targets)
train_loss.backward()
optimizer.step()

scheduler.step()
last_lr = scheduler.get_last_lr()[0]

pbar.set_description(f"Training loss: {train_loss.item():.6} | LR: {last_lr:.6}")
log(train_loss=train_loss.item(), lr=last_lr)

epoch_train_loss += train_loss.item()
batch_cnt += 1

pbar.set_description(f"Training loss: {train_loss.item():.6}")
if clearml_log:
global iteration
logger.report_scalar(
title="Training Loss", series="train_loss", value=train_loss.item(), iteration=iteration
)
iteration += 1
epoch_train_loss /= batch_cnt
return epoch_train_loss, last_lr


@torch.no_grad()
def evaluate(model, val_loader, batch_transforms, amp=False):
def evaluate(model, val_loader, batch_transforms, amp=False, log=None):
# Model in eval mode
model.eval()
# Validation loop
val_loss, correct, samples, batch_cnt = 0, 0, 0, 0
for images, targets in tqdm(val_loader):
pbar = tqdm(val_loader, dynamic_ncols=True)
for images, targets in pbar:
images = batch_transforms(images)

if torch.cuda.is_available():
Expand All @@ -177,6 +177,9 @@ def evaluate(model, val_loader, batch_transforms, amp=False):
# Compute metric
correct += (out.argmax(dim=1) == targets).sum().item()

pbar.set_description(f"Validation loss: {loss.item():.6}")
log(val_loss=loss.item())

val_loss += loss.item()
batch_cnt += 1
samples += images.shape[0]
Expand All @@ -187,7 +190,8 @@ def evaluate(model, val_loader, batch_transforms, amp=False):


def main(args):
print(args)
pbar = tqdm(disable=True)
pbar.write(str(args))

if args.push_to_hub:
login_to_hub()
Expand Down Expand Up @@ -222,7 +226,7 @@ def main(args):
sampler=SequentialSampler(val_set),
pin_memory=torch.cuda.is_available(),
)
print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in {len(val_loader)} batches)")
pbar.write(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in {len(val_loader)} batches)")

batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301))

Expand All @@ -231,7 +235,7 @@ def main(args):

# Resume weights
if isinstance(args.resume, str):
print(f"Resuming {args.resume}")
pbar.write(f"Resuming {args.resume}")
checkpoint = torch.load(args.resume, map_location="cpu")
model.load_state_dict(checkpoint)

Expand All @@ -251,9 +255,9 @@ def main(args):
model = model.cuda()

if args.test_only:
print("Running evaluation")
pbar.write("Running evaluation")
val_loss, acc = evaluate(model, val_loader, batch_transforms)
print(f"Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
pbar.write(f"Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
return

st = time.time()
Expand Down Expand Up @@ -286,7 +290,7 @@ def main(args):
sampler=RandomSampler(train_set),
pin_memory=torch.cuda.is_available(),
)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)")
pbar.write(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)")

if args.show_samples:
x, target = next(iter(train_loader))
Expand Down Expand Up @@ -343,46 +347,87 @@ def main(args):
"pretrained": args.pretrained,
}

global global_step
global_step = 0 # Shared global step counter

# W&B
if args.wb:
import wandb

run = wandb.init(
name=exp_name,
project="character-classification",
config=config,
)
run = wandb.init(name=exp_name, project="character-classification", config=config)

def wandb_log_at_step(train_loss=None, val_loss=None, lr=None):
wandb.log({
**({"train_loss_step": train_loss} if train_loss is not None else {}),
**({"val_loss_step": val_loss} if val_loss is not None else {}),
**({"step_lr": lr} if lr is not None else {}),
})

# ClearML
if args.clearml:
from clearml import Task
from clearml import Logger, Task

task = Task.init(project_name="docTR/character-classification", task_name=exp_name, reuse_last_task_id=False)
task.upload_artifact("config", config)
global iteration
iteration = 0

def clearml_log_at_step(train_loss=None, val_loss=None, lr=None):
logger = Logger.current_logger()
if train_loss is not None:
logger.report_scalar(
title="Training Step Loss",
series="train_loss_step",
iteration=global_step,
value=train_loss,
)
if val_loss is not None:
logger.report_scalar(
title="Validation Step Loss",
series="val_loss_step",
iteration=global_step,
value=val_loss,
)
if lr is not None:
logger.report_scalar(
title="Step Learning Rate",
series="step_lr",
iteration=global_step,
value=lr,
)

# Unified logger
def log_at_step(train_loss=None, val_loss=None, lr=None):
global global_step
if args.wb:
wandb_log_at_step(train_loss, val_loss, lr)
if args.clearml:
clearml_log_at_step(train_loss, val_loss, lr)
global_step += 1 # Increment the shared global step counter

# Create loss queue
min_loss = np.inf
# Training loop
if args.early_stop:
early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta)
for epoch in range(args.epochs):
fit_one_epoch(
model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, clearml_log=args.clearml
train_loss, actual_lr = fit_one_epoch(
model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, log=log_at_step
)
pbar.write(f"Epoch {epoch + 1}/{args.epochs} - Training loss: {train_loss:.6} | LR: {actual_lr:.6}")

# Validation loop at the end of each epoch
val_loss, acc = evaluate(model, val_loader, batch_transforms)
val_loss, acc = evaluate(model, val_loader, batch_transforms, log=log_at_step)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
pbar.write(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}.pt")
min_loss = val_loss
print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
pbar.write(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})")

# W&B
if args.wb:
wandb.log({
"train_loss": train_loss,
"val_loss": val_loss,
"learning_rate": actual_lr,
"acc": acc,
})

Expand All @@ -391,24 +436,27 @@ def main(args):
from clearml import Logger

logger = Logger.current_logger()
logger.report_scalar(title="Training Loss", series="train_loss", value=train_loss, iteration=epoch)
logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch)
logger.report_scalar(title="Learning Rate", series="lr", value=actual_lr, iteration=epoch)
logger.report_scalar(title="Accuracy", series="acc", value=acc, iteration=epoch)

if args.early_stop and early_stopper.early_stop(val_loss):
print("Training halted early due to reaching patience limit.")
pbar.write("Training halted early due to reaching patience limit.")
break

if args.wb:
run.finish()

if args.push_to_hub:
push_to_hf_hub(model, exp_name, task="classification", run_config=args)

if args.export_onnx:
print("Exporting model to ONNX...")
pbar.write("Exporting model to ONNX...")
dummy_batch = next(iter(val_loader))
dummy_input = dummy_batch[0].cuda() if torch.cuda.is_available() else dummy_batch[0]
model_path = export_model_to_onnx(model, exp_name, dummy_input)
print(f"Exported model saved in {model_path}")
pbar.write(f"Exported model saved in {model_path}")


def parse_args():
Expand Down
Loading

0 comments on commit d41ba35

Please sign in to comment.