diff --git a/main_dino.py b/main_dino.py index cade9873d..19e7e652e 100644 --- a/main_dino.py +++ b/main_dino.py @@ -126,6 +126,13 @@ def get_args_parser(): parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up distributed training; see https://pytorch.org/docs/stable/distributed.html""") parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") + + # logging with aim + parser.add_argument("--use_aim", default=True, type=bool, help="whether to use aim for logging.") + parser.add_argument("--aim_repo", default=None, type=str, help="path to Aim repository.") + parser.add_argument("--aim_run_hash", default=None, type=str, + help="Aim run hash. Create a new run if not specified.") + return parser @@ -301,7 +308,7 @@ def train_dino(args): def train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, data_loader, optimizer, lr_schedule, wd_schedule, momentum_schedule,epoch, fp16_scaler, args): - metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger = utils.MetricLogger(args, delimiter=" ") header = 'Epoch: [{}/{}]'.format(epoch, args.epochs) for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)): # update weight decay and learning rate according to their schedule diff --git a/utils.py b/utils.py index 958625012..244bac8bf 100644 --- a/utils.py +++ b/utils.py @@ -309,11 +309,30 @@ def reduce_dict(input_dict, average=True): reduced_dict = {k: v for k, v in zip(names, values)} return reduced_dict +try: + import functools + + from aim import Run + + @functools.lru_cache() + def get_aim_run(repo, run_hash): + from aim import Run + return Run(run_hash=run_hash, repo=repo) + +except ImportError: + print("Warning: Aim is not installed. Install aim to use metric logging.") + get_aim_run = None + class MetricLogger(object): - def __init__(self, delimiter="\t"): + def __init__(self, args, delimiter="\t"): self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter + self.aim_run = None + if args.use_aim and get_aim_run: + self.aim_run = get_aim_run(args.aim_repo, args.aim_run_hash) + for key, value in vars(args).items(): + self.aim_run.set(('cli_args', key), value, strict=False) def update(self, **kwargs): for k, v in kwargs.items(): @@ -321,6 +340,8 @@ def update(self, **kwargs): v = v.item() assert isinstance(v, (float, int)) self.meters[k].update(v) + if self.aim_run: + self.aim_run.track(v, name=k) def __getattr__(self, attr): if attr in self.meters: