From bfdae2cd990c3eed8956f3430c6a8c81f9f20304 Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com> Date: Tue, 23 Jan 2024 13:46:43 +0100 Subject: [PATCH 1/2] Drop last in distributed mode (#55) --- gluefactory/datasets/base_dataset.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/gluefactory/datasets/base_dataset.py b/gluefactory/datasets/base_dataset.py index ef622cbc..b3114c99 100644 --- a/gluefactory/datasets/base_dataset.py +++ b/gluefactory/datasets/base_dataset.py @@ -161,9 +161,12 @@ def get_data_loader(self, split, shuffle=None, pinned=False, distributed=False): except omegaconf.MissingMandatoryValue: batch_size = self.conf.batch_size num_workers = self.conf.get("num_workers", batch_size) + drop_last = True if split == "train" else False if distributed: shuffle = False - sampler = torch.utils.data.distributed.DistributedSampler(dataset) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, drop_last=drop_last + ) else: sampler = None if shuffle is None: @@ -178,7 +181,7 @@ def get_data_loader(self, split, shuffle=None, pinned=False, distributed=False): num_workers=num_workers, worker_init_fn=worker_init_fn, prefetch_factor=self.conf.prefetch_factor, - drop_last=True if split == "train" else False, + drop_last=drop_last, ) def get_overfit_loader(self, split): From 365ad7b0d515da750845b78374f8b6b79de94e72 Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com> Date: Wed, 24 Jan 2024 15:47:12 +0100 Subject: [PATCH 2/2] Fix typo (#46) --- gluefactory/models/utils/losses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gluefactory/models/utils/losses.py b/gluefactory/models/utils/losses.py index cca17636..06c7958b 100644 --- a/gluefactory/models/utils/losses.py +++ b/gluefactory/models/utils/losses.py @@ -69,5 +69,5 @@ def nll_loss(self, log_assignment, data): weights[:, :m, :n] = positive weights[:, :m, -1] = neg0 - weights[:, -1, :m] = neg1 + weights[:, -1, :n] = neg1 return weights