diff --git a/datasets/utae_dynamicen.py b/datasets/utae_dynamicen.py index 6f2bc620..3747a643 100644 --- a/datasets/utae_dynamicen.py +++ b/datasets/utae_dynamicen.py @@ -140,20 +140,11 @@ def __len__(self): return len(self.files) def __getitem__(self, index): - # padding = (np.array(self.mean)).tolist() (images, dates), label = self.load_data(index) - #base_size = label.shape[1] images = torch.from_numpy(images).permute(3, 0, 1, 2)#.transpose(0, 1) - images = TF.resize(images, size=512) - #images = images[0] - #print(images.shape) - label = torch.from_numpy(np.array(label, dtype=np.int32)).long().unsqueeze(0).unsqueeze(0) - #print(label.shape) - label = TF.resize(label, size=512, interpolation=transforms.InterpolationMode.NEAREST) - label = label.squeeze(0).squeeze(0) - #print(label.shape) - #images = images.transpose(0, 1) + label = torch.from_numpy(np.array(label, dtype=np.int32)).long() + output = { 'image': { 'optical': images,