Skip to content

Commit

Permalink
Update pastis dataset
Browse files (browse the repository at this point in the history)
  • Loading branch information
gle-bellier committed Oct 15, 2024
1 parent 7ccc8db commit b7d05b9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 38 deletions.
8 changes: 4 additions & 4 deletions configs/dataset/pastis.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ multi_modal: True
#limited_label: False

# classes
ignore_index: 0
num_classes: 19
ignore_index: 19
num_classes: 20
classes:
- Background
- Meadow
Expand All @@ -32,7 +32,7 @@ classes:
- Orchard
- Mixed Cereal
- Sorghum
#- Void Label
- Void Label
distribution:
- 0.00000
- 0.25675
Expand All @@ -53,7 +53,7 @@ distribution:
- 0.02460
- 0.00696
- 0.00580
#- 0.29476
- 0.29476

bands:
optical:
Expand Down
49 changes: 15 additions & 34 deletions pangaea/datasets/pastis.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,31 +141,15 @@ def __init__(
folds = [4]
else:
folds = [5]

self.dataset_name = dataset_name
self.bands = bands
self.split = split
self.path = root_path
self.data_mean = data_mean
self.data_std = data_std
self.data_min = data_min
self.data_max = data_max
self.classes = classes
self.img_size = img_size
self.distribution = distribution

self.num_classes = num_classes
self.ignore_index = ignore_index
self.grid_size = multi_temporal
self.download_url = download_url
self.auto_download = auto_download
self.modalities = ["s2", "aerial", "s1-asc"]
self.nb_split = 1

reference_date = "2018-09-01"
self.reference_date = datetime(*map(int, reference_date.split("-")))

self.meta_patch = gpd.read_file(os.path.join(self.path, "metadata.geojson"))
self.meta_patch = gpd.read_file(
os.path.join(self.root_path, "metadata.geojson")
)

self.num_classes = 20

Expand Down Expand Up @@ -193,19 +177,16 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor
part = i % (self.nb_split * self.nb_split)
label = torch.from_numpy(
np.load(
os.path.join(self.path, "ANNOTATIONS/TARGET_" + str(name) + ".npy")
os.path.join(self.root_path, "ANNOTATIONS/TARGET_" + str(name) + ".npy")
)[0].astype(np.int32)
)
# remove void class
label[label == 19] = self.ignore_index
# label = label[1:-1] # remove Background and Void classes
output = {"label": label, "name": name}

for modality in self.modalities:
if modality == "aerial":
with rasterio.open(
os.path.join(
self.path,
self.root_path,
"DATA_SPOT/PASTIS_SPOT6_RVB_1M00_2019/SPOT6_RVB_1M00_2019_"
+ str(name)
+ ".tif",
Expand All @@ -220,7 +201,7 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor
torch.from_numpy(
np.load(
os.path.join(
self.path,
self.root_path,
"DATA_{}".format(modality_name.upper()),
"{}_{}.npy".format(modality_name.upper(), name),
)
Expand All @@ -237,7 +218,7 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor
torch.from_numpy(
np.load(
os.path.join(
self.path,
self.root_path,
"DATA_{}".format(modality_name.upper()),
"{}_{}.npy".format(modality_name.upper(), name),
)
Expand All @@ -254,7 +235,7 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor
torch.from_numpy(
np.load(
os.path.join(
self.path,
self.root_path,
"DATA_{}".format(modality_name.upper()),
"{}_{}.npy".format(modality_name.upper(), name),
)
Expand Down Expand Up @@ -286,7 +267,7 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor
torch.from_numpy(
np.load(
os.path.join(
self.path,
self.root_path,
"DATA_{}".format(modality_name.upper()),
"{}_{}.npy".format(modality_name.upper(), name),
)
Expand Down Expand Up @@ -319,7 +300,7 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor
torch.from_numpy(
np.load(
os.path.join(
self.path,
self.root_path,
"DATA_{}".format(modality_name.upper()),
"{}_{}.npy".format(modality_name.upper(), name),
)
Expand All @@ -337,7 +318,7 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor
torch.from_numpy(
np.load(
os.path.join(
self.path,
self.root_path,
"DATA_{}".format(modality.upper()),
"{}_{}.npy".format(modality.upper(), name),
)
Expand All @@ -360,17 +341,17 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor
optical_ts = rearrange(output["s2"], "t c h w -> c t h w")
sar_ts = rearrange(output["s1-asc"], "t c h w -> c t h w")

if self.grid_size == 1:
if self.multi_temporal == 1:
# we only take the last frame
optical_ts = optical_ts[:, -1]
sar_ts = sar_ts[:, -1]
else:
# select evenly spaced samples
optical_indexes = torch.linspace(
0, optical_ts.shape[1] - 1, self.grid_size, dtype=torch.long
0, optical_ts.shape[1] - 1, self.multi_temporal, dtype=torch.long
)
sar_indexes = torch.linspace(
0, sar_ts.shape[1] - 1, self.grid_size, dtype=torch.long
0, sar_ts.shape[1] - 1, self.multi_temporal, dtype=torch.long
)

optical_ts = optical_ts[:, optical_indexes]
Expand All @@ -381,7 +362,7 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor
"optical": optical_ts.to(torch.float32),
"sar": sar_ts.to(torch.float32),
},
"target": output["label"],
"target": output["label"].to(torch.int64),
"metadata": {},
}

Expand Down

0 comments on commit b7d05b9

Please sign in to comment.