Skip to content

Commit

Permalink
Update biomassters.py
Browse files Browse the repository at this point in the history
  • Loading branch information
RituYadav92 authored Sep 20, 2024
1 parent 638bbbf commit 2714a5e
Showing 1 changed file with 5 additions and 15 deletions.
20 changes: 5 additions & 15 deletions datasets/biomassters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,7 @@
from .utils import read_tif
from utils.registry import DATASET_REGISTRY

s1_min = np.array([-25 , -62 , -25, -60], dtype="float32")
s1_max = np.array([ 29 , 28, 30, 22 ], dtype="float32")
s1_mm = s1_max - s1_min
s2_max = np.array(
[19616., 18400., 17536., 17097., 16928., 16768., 16593., 16492., 15401., 15226., 255.],
dtype="float32",
)
IMG_SIZE = (256, 256)

def read_imgs(multi_temporal, temp , fname, data_dir):
def read_imgs(multi_temporal, temp , fname, data_dir, img_size):
imgs_s1, imgs_s2, mask = [], [], []
if multi_temporal:
month_list = list(range(12))
Expand All @@ -34,18 +25,16 @@ def read_imgs(multi_temporal, temp , fname, data_dir):
img_s1 = imread(s1_filepath)
m = img_s1 == -9999
img_s1 = img_s1.astype("float32")
img_s1 = (img_s1 - s1_min) / s1_mm
img_s1 = np.where(m, 0, img_s1)
else:
img_s1 = np.zeros(IMG_SIZE + (4,), dtype="float32")
img_s1 = np.zeros((img_size, img_size) + (4,), dtype="float32")

s2_filepath = data_dir.joinpath(s2_fname)
if s2_filepath.exists():
img_s2 = imread(s2_filepath)
img_s2 = img_s2.astype("float32")
img_s2 = img_s2 / s2_max
else:
img_s2 = np.zeros(IMG_SIZE + (11,), dtype="float32")
img_s2 = np.zeros((img_size, img_size) + (11,), dtype="float32")

img_s1 = np.transpose(img_s1, (2, 0, 1))
img_s2 = np.transpose(img_s2, (2, 0, 1))
Expand All @@ -70,6 +59,7 @@ def __init__(self, cfg, split):
self.multi_temporal = cfg['multi_temporal']
self.temp = cfg['temporal']
self.split = split
self.img_size = cfg['img_size']

self.data_path = pathlib.Path(self.root_path).joinpath(f"{split}_Data_list.csv")
self.id_list = pd.read_csv(self.data_path)['chip_id']
Expand All @@ -84,7 +74,7 @@ def __getitem__(self, index):
chip_id = self.id_list.iloc[index]
fname = str(chip_id)+'_agbm.tif'

imgs_s1, imgs_s2, mask = read_imgs(self.multi_temporal, self.temp, fname, self.dir_features)
imgs_s1, imgs_s2, mask = read_imgs(self.multi_temporal, self.temp, fname, self.dir_features, self.img_size)
with rasterio.open(self.dir_labels.joinpath(fname)) as lbl:
target = lbl.read(1)
target = np.nan_to_num(target)
Expand Down

0 comments on commit 2714a5e

Please sign in to comment.