Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Xview2: oversampling images with building damage #109

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions configs/dataset/xview2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ auto_download: False
img_size: 1024
multi_temporal: False
multi_modal: False
oversample_building_damage: True

# classes
ignore_index: -1
Expand Down
64 changes: 62 additions & 2 deletions pangaea/datasets/xview2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Sources:
# - https://github.com/PaulBorneP/Xview2_Strong_Baseline/blob/master/datasets/base_dataset.py
# - https://github.com/PaulBorneP/Xview2_Strong_Baseline/blob/master/datasets/supervised_dataset.py
# - https://github.com/PaulBorneP/Xview2_Strong_Baseline/blob/master/legacy/datasets.py

from typing import Sequence, Dict, Any, Union, Literal, Tuple
import time
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(
data_max: dict[str, list[str]],
download_url: str,
auto_download: bool,
oversample_building_damage: bool
):
"""Initialize the xView2 dataset.
Link: https://xview2.org/dataset
Expand Down Expand Up @@ -69,6 +71,7 @@ def __init__(
e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]}
download_url (str): url to download the dataset.
auto_download (bool): whether to download the dataset automatically.
oversample_building_damage (bool): whether to oversample images with building damage
"""
super(xView2, self).__init__(
split=split,
Expand Down Expand Up @@ -104,9 +107,9 @@ def __init__(
self.ignore_index = ignore_index
self.download_url = download_url
self.auto_download = auto_download
self.oversample_building_damage = oversample_building_damage

self.all_files = self.get_all_files()


def get_all_files(self) -> Sequence[str]:
all_files = []
Expand All @@ -123,7 +126,13 @@ def get_all_files(self) -> Sequence[str]:

if self.split != "test":
train_val_idcs = self.get_stratified_train_val_split(all_files)

if self.split == "train" and self.oversample_building_damage:
train_val_idcs[self.split] = self.oversample_building_files(all_files, train_val_idcs[self.split])


all_files = [all_files[i] for i in train_val_idcs[self.split]]


return all_files

Expand All @@ -140,6 +149,34 @@ def get_stratified_train_val_split(all_files) -> Tuple[Sequence[int], Sequence[i
stratify=disaster_names)
return {"train": train_idxs, "val": val_idxs}

def oversample_building_files(self, all_files, train_idxs):
# Oversamples buildings on the image-level, by including each image with any building pixels twice in the training set.
file_classes = []
for i, fn in enumerate(all_files):
fl = np.zeros((4,), dtype=bool)
# Only read images that are included in train_idxs
if i in train_idxs:
msk1 = cv2.imread(fn.replace('/images/', '/masks/').replace('_pre_disaster', '_post_disaster'),
cv2.IMREAD_UNCHANGED)
for c in range(1, 5):
fl[c - 1] = c in msk1
file_classes.append(fl)
file_classes = np.asarray(file_classes)

new_train_idxs = []
for i in train_idxs:
new_train_idxs.append(i)
# If any building damage was present in the image, add the image to the training set a second time.
if file_classes[i, 1:].max():
new_train_idxs.append(i)
# If minor or medium damage were present, add it a third time, since these two classes are very hard to detect.
# Source: https://github.com/DIUx-xView/xView2_first_place/blob/master/train34_cls.py
if file_classes[i, 1:3].max():
new_train_idxs.append(i)
train_idxs = np.asarray(new_train_idxs)
return train_idxs


def __len__(self) -> int:
return len(self.all_files)

Expand Down Expand Up @@ -206,4 +243,27 @@ def download(self, silent=False):
tar.extractall(output_path)
print("done.")

os.remove(output_path / temp_file_name)
os.remove(output_path / temp_file_name)

if __name__=="__main__":
dataset = xView2(
split="train",
dataset_name="xView2",
root_path="./data/xView2",
download_url="https://the-dataset-is-not-publicly-available.com",
auto_download=False,
img_size=1024,
multi_temporal=False,
multi_modal=False,
classes=["No building", "No damage","Minor damage","Major damage","Destroyed"],
num_classes=5,
ignore_index=-1,
bands=["B4", "B3", "B2"],
distribution = [0.9415, 0.0448, 0.0049, 0.0057, 0.0031],
data_mean=[66.7703, 88.4452, 85.1047],
data_std=[48.3066, 51.9129, 62.7612],
data_min=[0.0, 0.0, 0.0],
data_max=[255, 255, 255],
)
x,y = dataset[0]
print(x["optical"].shape, y.shape)