From bfdae2cd990c3eed8956f3430c6a8c81f9f20304 Mon Sep 17 00:00:00 2001
From: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com>
Date: Tue, 23 Jan 2024 13:46:43 +0100
Subject: [PATCH] Drop last in distributed mode (#55)

---
 gluefactory/datasets/base_dataset.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/gluefactory/datasets/base_dataset.py b/gluefactory/datasets/base_dataset.py
index ef622cbc..b3114c99 100644
--- a/gluefactory/datasets/base_dataset.py
+++ b/gluefactory/datasets/base_dataset.py
@@ -161,9 +161,12 @@ def get_data_loader(self, split, shuffle=None, pinned=False, distributed=False):
         except omegaconf.MissingMandatoryValue:
             batch_size = self.conf.batch_size
         num_workers = self.conf.get("num_workers", batch_size)
+        drop_last = True if split == "train" else False
         if distributed:
             shuffle = False
-            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+            sampler = torch.utils.data.distributed.DistributedSampler(
+                dataset, drop_last=drop_last
+            )
         else:
             sampler = None
             if shuffle is None:
@@ -178,7 +181,7 @@ def get_data_loader(self, split, shuffle=None, pinned=False, distributed=False):
             num_workers=num_workers,
             worker_init_fn=worker_init_fn,
             prefetch_factor=self.conf.prefetch_factor,
-            drop_last=True if split == "train" else False,
+            drop_last=drop_last,
         )
 
     def get_overfit_loader(self, split):
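Not part of the patch: a minimal standalone sketch of the sampler behavior this change addresses. Without drop_last=True, DistributedSampler pads its index list so every rank draws the same number of samples, which duplicates a few samples at the end of each epoch; with drop_last=True the uneven tail is dropped instead. The toy dataset and the explicit num_replicas/rank values below are hypothetical, used only so the sketch runs without an initialized process group.

# Minimal sketch (assumes PyTorch >= 1.8, where DistributedSampler
# gained the drop_last argument). Passing num_replicas/rank explicitly
# avoids needing torch.distributed to be initialized for this illustration.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(11))  # 11 samples: not divisible by 2 ranks

# drop_last=False (default): indices are padded to 12, so one sample repeats.
padded = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
# drop_last=True: the uneven tail is dropped, leaving 5 samples per rank.
trimmed = DistributedSampler(
    dataset, num_replicas=2, rank=0, shuffle=False, drop_last=True
)
print(len(padded), len(trimmed))  # 6 5

# In training, the sampler feeds a DataLoader; drop_last on the loader
# additionally discards the final incomplete batch, as in the patch above.
loader = DataLoader(dataset, batch_size=2, sampler=trimmed, drop_last=True)
for epoch in range(2):
    trimmed.set_epoch(epoch)  # no-op with shuffle=False; with shuffle=True it reseeds per epoch
    for (batch,) in loader:
        pass

Note the patch computes drop_last once and passes it to both the DistributedSampler and the DataLoader, so the train split drops its uneven tail consistently on every rank while val/test splits keep every sample.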