From 2714a5e3f535aca3dbc81c34d67d457bdad20c71 Mon Sep 17 00:00:00 2001 From: Ritu Yadav <40523539+RituYadav92@users.noreply.github.com> Date: Fri, 20 Sep 2024 23:23:34 +0200 Subject: [PATCH] Update biomassters.py --- datasets/biomassters.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/datasets/biomassters.py b/datasets/biomassters.py index 472101ad..8624726c 100644 --- a/datasets/biomassters.py +++ b/datasets/biomassters.py @@ -8,16 +8,7 @@ from .utils import read_tif from utils.registry import DATASET_REGISTRY -s1_min = np.array([-25 , -62 , -25, -60], dtype="float32") -s1_max = np.array([ 29 , 28, 30, 22 ], dtype="float32") -s1_mm = s1_max - s1_min -s2_max = np.array( - [19616., 18400., 17536., 17097., 16928., 16768., 16593., 16492., 15401., 15226., 255.], - dtype="float32", -) -IMG_SIZE = (256, 256) - -def read_imgs(multi_temporal, temp , fname, data_dir): +def read_imgs(multi_temporal, temp , fname, data_dir, img_size): imgs_s1, imgs_s2, mask = [], [], [] if multi_temporal: month_list = list(range(12)) @@ -34,18 +25,16 @@ def read_imgs(multi_temporal, temp , fname, data_dir): img_s1 = imread(s1_filepath) m = img_s1 == -9999 img_s1 = img_s1.astype("float32") - img_s1 = (img_s1 - s1_min) / s1_mm img_s1 = np.where(m, 0, img_s1) else: - img_s1 = np.zeros(IMG_SIZE + (4,), dtype="float32") + img_s1 = np.zeros((img_size, img_size) + (4,), dtype="float32") s2_filepath = data_dir.joinpath(s2_fname) if s2_filepath.exists(): img_s2 = imread(s2_filepath) img_s2 = img_s2.astype("float32") - img_s2 = img_s2 / s2_max else: - img_s2 = np.zeros(IMG_SIZE + (11,), dtype="float32") + img_s2 = np.zeros((img_size, img_size) + (11,), dtype="float32") img_s1 = np.transpose(img_s1, (2, 0, 1)) img_s2 = np.transpose(img_s2, (2, 0, 1)) @@ -70,6 +59,7 @@ def __init__(self, cfg, split): self.multi_temporal = cfg['multi_temporal'] self.temp = cfg['temporal'] self.split = split + self.img_size = cfg['img_size'] self.data_path = pathlib.Path(self.root_path).joinpath(f"{split}_Data_list.csv") self.id_list = pd.read_csv(self.data_path)['chip_id'] @@ -84,7 +74,7 @@ def __getitem__(self, index): chip_id = self.id_list.iloc[index] fname = str(chip_id)+'_agbm.tif' - imgs_s1, imgs_s2, mask = read_imgs(self.multi_temporal, self.temp, fname, self.dir_features) + imgs_s1, imgs_s2, mask = read_imgs(self.multi_temporal, self.temp, fname, self.dir_features, self.img_size) with rasterio.open(self.dir_labels.joinpath(fname)) as lbl: target = lbl.read(1) target = np.nan_to_num(target)