[references] Update Logging #1847

Merged: 6 commits, Jan 23, 2025
124 changes: 86 additions & 38 deletions references/classification/train_pytorch_character.py
@@ -110,18 +110,14 @@
return lr_recorder[: len(loss_recorder)], loss_recorder


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, clearml_log=False):
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None):
if amp:
scaler = torch.cuda.amp.GradScaler()

model.train()
if clearml_log:
from clearml import Logger

logger = Logger.current_logger()

# Iterate over the batches of the dataset
pbar = tqdm(train_loader, position=1)
epoch_train_loss, batch_cnt = 0.0, 0.0
pbar = tqdm(train_loader, dynamic_ncols=True)
for images, targets in pbar:
if torch.cuda.is_available():
images = images.cuda()
@@ -143,24 +139,28 @@
train_loss = cross_entropy(out, targets)
train_loss.backward()
optimizer.step()

scheduler.step()
last_lr = scheduler.get_last_lr()[0]

pbar.set_description(f"Training loss: {train_loss.item():.6} | LR: {last_lr:.6}")
if log is not None:
    log(train_loss=train_loss.item(), lr=last_lr)

epoch_train_loss += train_loss.item()
batch_cnt += 1

pbar.set_description(f"Training loss: {train_loss.item():.6}")
if clearml_log:
global iteration
logger.report_scalar(
title="Training Loss", series="train_loss", value=train_loss.item(), iteration=iteration
)
iteration += 1
epoch_train_loss /= batch_cnt
return epoch_train_loss, last_lr
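
The refactor replaces the ClearML-specific `clearml_log` flag with a backend-agnostic `log` callback: every argument is an optional keyword, so a backend consumes only what it cares about. A minimal sketch of that contract (the `print_log` name is hypothetical, not part of this PR):

def print_log(train_loss=None, val_loss=None, lr=None):
    # Hypothetical console backend honoring the same contract as log_at_step below.
    parts = []
    if train_loss is not None:
        parts.append(f"train_loss={train_loss:.6f}")
    if val_loss is not None:
        parts.append(f"val_loss={val_loss:.6f}")
    if lr is not None:
        parts.append(f"lr={lr:.6f}")
    print(" | ".join(parts))

Calling `fit_one_epoch(..., log=print_log)` would then print per-step scalars instead of shipping them to W&B or ClearML.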


@torch.no_grad()
def evaluate(model, val_loader, batch_transforms, amp=False):
def evaluate(model, val_loader, batch_transforms, amp=False, log=None):
# Model in eval mode
model.eval()
# Validation loop
val_loss, correct, samples, batch_cnt = 0, 0, 0, 0
for images, targets in tqdm(val_loader):
pbar = tqdm(val_loader, dynamic_ncols=True)
for images, targets in pbar:
images = batch_transforms(images)

if torch.cuda.is_available():
@@ -177,6 +177,9 @@
# Compute metric
correct += (out.argmax(dim=1) == targets).sum().item()

pbar.set_description(f"Validation loss: {loss.item():.6}")
if log is not None:
    log(val_loss=loss.item())

val_loss += loss.item()
batch_cnt += 1
samples += images.shape[0]
@@ -187,7 +190,8 @@


def main(args):
print(args)
pbar = tqdm(disable=True)
pbar.write(str(args))
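
Throughout `main`, `print` is swapped for writes through a disabled tqdm bar: `tqdm.write` prints above any live progress bar instead of tearing it, and a `disable=True` instance provides that writer without rendering a bar of its own. A standalone sketch of the effect (not part of the diff):

from time import sleep
from tqdm import tqdm

status = tqdm(disable=True)  # invisible bar, used only for safe writes
for _ in tqdm(range(3), desc="epoch"):
    sleep(0.1)
    status.write("this line prints cleanly above the running bar")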

if args.push_to_hub:
login_to_hub()
@@ -222,7 +226,7 @@
sampler=SequentialSampler(val_set),
pin_memory=torch.cuda.is_available(),
)
print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in {len(val_loader)} batches)")
pbar.write(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in {len(val_loader)} batches)")

batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301))

@@ -231,7 +235,7 @@

# Resume weights
if isinstance(args.resume, str):
print(f"Resuming {args.resume}")
pbar.write(f"Resuming {args.resume}")
checkpoint = torch.load(args.resume, map_location="cpu")
model.load_state_dict(checkpoint)

@@ -251,9 +255,9 @@
model = model.cuda()

if args.test_only:
print("Running evaluation")
pbar.write("Running evaluation")
val_loss, acc = evaluate(model, val_loader, batch_transforms)
print(f"Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
pbar.write(f"Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
return

st = time.time()
@@ -286,7 +290,7 @@
sampler=RandomSampler(train_set),
pin_memory=torch.cuda.is_available(),
)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)")
pbar.write(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)")

if args.show_samples:
x, target = next(iter(train_loader))
@@ -343,46 +347,87 @@
"pretrained": args.pretrained,
}

global global_step

[Codacy Static Code Analysis warning, references/classification/train_pytorch_character.py#L350: Global variable 'global_step' undefined at the module level]
global_step = 0 # Shared global step counter

# W&B
if args.wb:
import wandb

run = wandb.init(
name=exp_name,
project="character-classification",
config=config,
)
run = wandb.init(name=exp_name, project="character-classification", config=config)

def wandb_log_at_step(train_loss=None, val_loss=None, lr=None):
wandb.log({
**({"train_loss_step": train_loss} if train_loss is not None else {}),
**({"val_loss_step": val_loss} if val_loss is not None else {}),
**({"step_lr": lr} if lr is not None else {}),
})

# ClearML
if args.clearml:
from clearml import Task
from clearml import Logger, Task

task = Task.init(project_name="docTR/character-classification", task_name=exp_name, reuse_last_task_id=False)
task.upload_artifact("config", config)
global iteration
iteration = 0

def clearml_log_at_step(train_loss=None, val_loss=None, lr=None):
logger = Logger.current_logger()
if train_loss is not None:
logger.report_scalar(
title="Training Step Loss",
series="train_loss_step",
iteration=global_step,
value=train_loss,
)
if val_loss is not None:
logger.report_scalar(
title="Validation Step Loss",
series="val_loss_step",
iteration=global_step,
value=val_loss,
)
if lr is not None:
logger.report_scalar(
title="Step Learning Rate",
series="step_lr",
iteration=global_step,
value=lr,
)

# Unified logger
def log_at_step(train_loss=None, val_loss=None, lr=None):
global global_step
if args.wb:
wandb_log_at_step(train_loss, val_loss, lr)
if args.clearml:
clearml_log_at_step(train_loss, val_loss, lr)
global_step += 1 # Increment the shared global step counter
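
`log_at_step` is the glue: one call fans out to every enabled backend, and the shared counter advances exactly once, so the W&B and ClearML step axes stay aligned. The pattern in isolation (condensed sketch, not the PR's exact code):

global_step = 0
sinks = []  # filled per CLI flags, e.g. [wandb_log_at_step, clearml_log_at_step]

def log_at_step(train_loss=None, val_loss=None, lr=None):
    global global_step
    for sink in sinks:
        sink(train_loss=train_loss, val_loss=val_loss, lr=lr)
    global_step += 1  # one increment per logging event, shared by all sinks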

# Track the lowest validation loss seen so far
min_loss = np.inf
# Training loop
if args.early_stop:
early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta)
for epoch in range(args.epochs):
fit_one_epoch(
model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, clearml_log=args.clearml
train_loss, actual_lr = fit_one_epoch(
model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, log=log_at_step
)
pbar.write(f"Epoch {epoch + 1}/{args.epochs} - Training loss: {train_loss:.6} | LR: {actual_lr:.6}")

# Validation loop at the end of each epoch
val_loss, acc = evaluate(model, val_loader, batch_transforms)
val_loss, acc = evaluate(model, val_loader, batch_transforms, log=log_at_step)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
pbar.write(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}.pt")
min_loss = val_loss
print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
pbar.write(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})")

# W&B
if args.wb:
wandb.log({
"train_loss": train_loss,
"val_loss": val_loss,
"learning_rate": actual_lr,
"acc": acc,
})

@@ -391,24 +436,27 @@
from clearml import Logger

logger = Logger.current_logger()
logger.report_scalar(title="Training Loss", series="train_loss", value=train_loss, iteration=epoch)
logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch)
logger.report_scalar(title="Learning Rate", series="lr", value=actual_lr, iteration=epoch)
logger.report_scalar(title="Accuracy", series="acc", value=acc, iteration=epoch)

if args.early_stop and early_stopper.early_stop(val_loss):
print("Training halted early due to reaching patience limit.")
pbar.write("Training halted early due to reaching patience limit.")
break
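
`EarlyStopper` comes from the shared reference utilities and its implementation isn't part of this diff. A plausible minimal version, for illustration only (assumed semantics: stop once `patience` consecutive epochs fail to improve the best validation loss by `min_delta`):

class EarlyStopper:
    # Illustrative sketch; the real class lives in the references utilities.
    def __init__(self, patience=5, min_delta=0.01):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_val_loss = float("inf")

    def early_stop(self, val_loss):
        if val_loss < self.best_val_loss - self.min_delta:
            self.best_val_loss = val_loss  # meaningful improvement: reset patience
            self.counter = 0
        else:
            self.counter += 1  # no meaningful improvement this epoch
        return self.counter >= self.patience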

if args.wb:
run.finish()

if args.push_to_hub:
push_to_hf_hub(model, exp_name, task="classification", run_config=args)

if args.export_onnx:
print("Exporting model to ONNX...")
pbar.write("Exporting model to ONNX...")
dummy_batch = next(iter(val_loader))
dummy_input = dummy_batch[0].cuda() if torch.cuda.is_available() else dummy_batch[0]
model_path = export_model_to_onnx(model, exp_name, dummy_input)
print(f"Exported model saved in {model_path}")
pbar.write(f"Exported model saved in {model_path}")


def parse_args():