diff --git a/utils/dataset.py b/utils/dataset.py index 3afeca8b91..4878e030f9 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -9,10 +9,11 @@ class BasicDataset(Dataset): - def __init__(self, imgs_dir, masks_dir, scale=1): + def __init__(self, imgs_dir, masks_dir, scale=1, mask_suffix=''): self.imgs_dir = imgs_dir self.masks_dir = masks_dir self.scale = scale + self.mask_suffix = mask_suffix assert 0 < scale <= 1, 'Scale must be between 0 and 1' self.ids = [splitext(file)[0] for file in listdir(imgs_dir) @@ -43,7 +44,7 @@ def preprocess(cls, pil_img, scale): def __getitem__(self, i): idx = self.ids[i] - mask_file = glob(self.masks_dir + idx + '.*') + mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*') img_file = glob(self.imgs_dir + idx + '.*') assert len(mask_file) == 1, \ @@ -63,3 +64,8 @@ def __getitem__(self, i): 'image': torch.from_numpy(img).type(torch.FloatTensor), 'mask': torch.from_numpy(mask).type(torch.FloatTensor) } + + +class CarvanaDataset(BasicDataset): + def __init__(self, imgs_dir, masks_dir, scale=1): + super().__init__(imgs_dir, masks_dir, scale, mask_suffix='_mask')