Skip to content

Commit

Permalink
missing lines
Browse files Browse the repository at this point in the history
  • Loading branch information
odulcy-mindee committed Dec 17, 2024
1 parent 6c1235c commit 4cc0fb6
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions references/classification/train_pytorch_character.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ def parse_args():
)

parser.add_argument("arch", type=str, help="text-recognition model to train")
parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model")
parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size for training")
Expand Down
2 changes: 1 addition & 1 deletion references/detection/train_pytorch_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def main(rank: int, world_size: int, args):
min_loss = val_loss
if args.save_interval_epoch:
print(f"Saving state at epoch: {epoch + 1}")
torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}_epoch{epoch + 1}.pt")
torch.save(model.module.state_dict(), Path(args.output_dir) / f"{exp_name}_epoch{epoch + 1}.pt")
log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
if any(val is None for val in (recall, precision, mean_iou)):
log_msg += "(Undefined metric value, caused by empty GTs or predictions)"
Expand Down
2 changes: 1 addition & 1 deletion references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def main(args):
val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
model.save_weights(f"./{exp_name}.weights.h5")
model.save_weights(Path(args.output_dir) / f"{exp_name}.weights.h5")
min_loss = val_loss
if args.save_interval_epoch:
print(f"Saving state at epoch: {epoch + 1}")
Expand Down

0 comments on commit 4cc0fb6

Please sign in to comment.