Skip to content

Commit

Permalink
Merge pull request #40 from yurujaja/regression_updates
Browse files Browse the repository at this point in the history
Regression updates
  • Loading branch information
RituYadav92 authored Sep 17, 2024
2 parents 05dd89e + 75691e1 commit 4a2712d
Show file tree
Hide file tree
Showing 8 changed files with 356 additions and 89 deletions.
14 changes: 8 additions & 6 deletions configs/augmentations/regression_default.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
train:
RegPreprocessor: ~
NormalizeMeanStd: ~
NormalizeMinMax: ~
# NormalizeMeanStd: ~
RandomCropToEncoder: ~
RandomFlip:
ud_probability: 0.3
lr_probability: 0.3
# RandomFlip:
# ud_probability: 0.3
# lr_probability: 0.3
test:
RegPreprocessor: ~
NormalizeMeanStd: ~
Tile: ~
NormalizeMinMax: ~
# NormalizeMeanStd: ~
Tile: ~
50 changes: 30 additions & 20 deletions configs/datasets/biomassters.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
dataset_name: BioMassters
root_path: /geomatics/gpuserver-0/vmarsocci/biomassters
download_url: #https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars/resolve/main/hls_burn_scars.tar.gz?download=true
root_path: /geoinfo_vol1/home/r/i/ru/
# download_url: #https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars/resolve/main/hls_burn_scars.tar.gz?download=true
auto_download: False
img_size: 256
multi_temporal: 12
temporal: 12 #6 (summer month if multi_temp false), 12
multi_temporal: True
multi_modal: True


# classes
# ignore_index: -1
ignore_index: -1
num_classes: 1
classes:
- regression
Expand All @@ -25,29 +25,39 @@ bands:
- B2
- B3
- B4
# - B5
# - B6
# - B7
# - B8
# - B8a
# - B11
# - B12
- B5
- B6
- B7
- B8
- B8A
- B11
- B12
- CLP
sar:
- ASC_VV
- ASC_VH
- DSC_VV
- DSC_VH

# TODO: fix the normalization
data_mean:
optical:
- 66.7703
- 88.4452
- 85.1047
# sar:
# - 66.7703
# - 88.4452
# - 85.1047
sar:

data_std:
optical:
- 48.3066
- 51.9129
- 62.7612
# sar:
# - 48.3066
# - 51.9129
# - 62.7612
sar:

data_min:
optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
sar: [-25, -62, -25, -60]

data_max:
optical: [19616., 18400., 17536., 17097., 16928., 16768., 16593., 16492., 15401., 15226., 255.]
sar: [29, 28, 30, 22]
24 changes: 24 additions & 0 deletions configs/segmentors/reg_upernet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
segmentor_name: UPerNet_regress
task_name: regression
# time_frames: 2
#task_model_args:
#num_frames: 1
#mt_strategy: "ltae" #activated only when if num_frames > 1
#num_classes - task parameter passed from the dataset config
#wave_list - task parameter passed from the dataset config

channels: 512

loss:
loss_name: MSELoss # WeightedCrossEntropy
ignore_index: -1

optimizer:
optimizer_name: AdamW
lr: 0.0001
weight_decay: 0.05

scheduler:
scheduler_name: MultiStepLR
lr_milestones: [0.6, 0.9]

5 changes: 2 additions & 3 deletions configs/segmentors/reg_upernet_mt.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
segmentor_name: MTUPerNet
segmentor_name: MTUPerNet_regress
task_name: regression
binary: False
multi_temporal_strategy: linear
# time_frames: 2
time_frames: 12
#task_model_args:
#num_frames: 1
#mt_strategy: "ltae" #activated only when if num_frames > 1
Expand Down
92 changes: 55 additions & 37 deletions datasets/biomassters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,51 @@
import torch
import pandas as pd
import pathlib
import rasterio
from skimage import io
from os.path import join as opj
from .utils import read_tif
from utils.registry import DATASET_REGISTRY

s1_min = np.array([-25 , -62 , -25, -60], dtype="float32")
s1_max = np.array([ 29 , 28, 30, 22 ], dtype="float32")
s1_mm = s1_max - s1_min

s2_max = np.array(
[19616., 18400., 17536., 17097., 16928., 16768., 16593., 16492., 15401., 15226., 255.],
dtype="float32",
)

IMG_SIZE = (256, 256)


def read_imgs(chip_id: str, data_dir: pathlib.Path):
imgs, imgs_s1, imgs_s2, mask = [], [], [], []
for month in range(12):
img_s1 = read_tif(data_dir.joinpath(f"{chip_id}_S1_{month:0>2}.tif"))
m = img_s1 == -9999
img_s1 = img_s1.astype("float32")
img_s1 = (img_s1 - s1_min) / s1_mm
img_s1 = np.where(m, 0, img_s1)
filepath = data_dir.joinpath(f"{chip_id}_S2_{month:0>2}.tif")
if filepath.exists():
img_s2 = read_tif(filepath)
def read_imgs(multi_temporal, temp , fname, data_dir):
imgs_s1, imgs_s2, mask = [], [], []
if multi_temporal:
month_list = list(range(12))
else:
month_list = [temp]

for month in month_list:

s1_fname = '%s_%s_%02d.tif' % (str.split(fname, '_')[0], 'S1', month)
s2_fname = '%s_%s_%02d.tif' % (str.split(fname, '_')[0], 'S2', month)

s1_filepath = data_dir.joinpath(s1_fname)
if s1_filepath.exists():
img_s1 = io.imread(s1_filepath)
m = img_s1 == -9999
img_s1 = img_s1.astype("float32")
img_s1 = (img_s1 - s1_min) / s1_mm
img_s1 = np.where(m, 0, img_s1)
else:
img_s1 = np.zeros(IMG_SIZE + (4,), dtype="float32")

s2_filepath = data_dir.joinpath(s2_fname)
if s2_filepath.exists():
img_s2 = io.imread(s2_filepath)
img_s2 = img_s2.astype("float32")
img_s2 = img_s2 / s2_max
else:
else:
img_s2 = np.zeros(IMG_SIZE + (11,), dtype="float32")

# img = np.concatenate([img_s1, img_s2], axis=2)

img_s1 = np.transpose(img_s1, (2, 0, 1))
img_s2 = np.transpose(img_s2, (2, 0, 1))
imgs_s1.append(img_s1)
Expand All @@ -42,39 +55,43 @@ def read_imgs(chip_id: str, data_dir: pathlib.Path):

mask = np.array(mask)

imgs_s1 = np.stack(imgs_s1, axis=1) # [c, t, h, w]
imgs_s1 = np.stack(imgs_s1, axis=1) # [c, t, h, w] prithvi
imgs_s2 = np.stack(imgs_s2, axis=1) # [c, t, h, w]

return imgs_s1, imgs_s2, mask

@DATASET_REGISTRY.register()
class BioMassters(torch.utils.data.Dataset):
def __init__(self, cfg, split): #, augs=False):
df_path = pathlib.Path(cfg["root_path"]).joinpath("The_BioMassters_-_features_metadata.csv.csv")
df: pd.DataFrame = pd.read_csv(str(df_path))
self.df = df[df.split == split].copy()
self.dir_features = pathlib.Path(cfg["root_path"]).joinpath(f"{split}_features")
self.dir_labels = pathlib.Path(cfg["root_path"]).joinpath( f"{split}_agbm")

self.root_path = cfg['root_path']
self.data_min = cfg['data_min']
self.data_max = cfg['data_max']
self.multi_temporal = cfg['multi_temporal']
self.temp = cfg['temporal']
self.split = split
# self.augs = augs

self.data_path = pathlib.Path(self.root_path).joinpath(f"{split}_Data_list.csv")
self.id_list = pd.read_csv(self.data_path)['chip_id']
self.dir_features = pathlib.Path(self.root_path).joinpath("TRAIN/train_features")
self.dir_labels = pathlib.Path(self.root_path).joinpath( "TRAIN/train_agbm")

def __len__(self):
return len(self.df)
return len(self.id_list)

def __getitem__(self, index):
item = self.df.iloc[index]

# print(item.chip_id)
# print(self.dir_features)

imgs_s1, imgs_s2, mask = read_imgs(item.chip_id, self.dir_features)
if self.dir_labels is not None:
target = read_tif(self.dir_labels.joinpath(f'{item.chip_id}_agbm.tif'))
else:
target = item.chip_id
chip_id = self.id_list.iloc[index]
fname = str(chip_id)+'_agbm.tif'

imgs_s1, imgs_s2, mask = read_imgs(self.multi_temporal, self.temp, fname, self.dir_features)
with rasterio.open(self.dir_labels.joinpath(fname)) as lbl:
target = lbl.read(1)
target = np.nan_to_num(target)
# print(imgs_s1.shape, imgs_s2.shape, len(mask), target.shape)#(4, 1, 256, 256) (11, 1, 256, 256) 1 (256, 256)

# Reshaping tensors from (T, H, W, C) to (C, T, H, W)
# format (B/C, T, H, W)
imgs_s1 = torch.from_numpy(imgs_s1).float()
imgs_s2 = torch.from_numpy(imgs_s2).float()
target = torch.from_numpy(target).float()
Expand All @@ -90,11 +107,12 @@ def __getitem__(self, index):

@staticmethod
def get_splits(dataset_config):
dataset_train = BioMassters(cfg=dataset_config, split="test")
dataset_val = BioMassters(cfg=dataset_config, split="test")
dataset_train = BioMassters(cfg=dataset_config, split="train")
dataset_val = BioMassters(cfg=dataset_config, split="val")
dataset_test = BioMassters(cfg=dataset_config, split="test")
# print('loaded sample points',len(dataset_train), len(dataset_val), len(dataset_test))
return dataset_train, dataset_val, dataset_test

@staticmethod
def download(dataset_config:dict, silent=False):
pass
pass
24 changes: 3 additions & 21 deletions engine/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,11 @@ def __init__(self, args, val_loader, exp_dir, device):

@torch.no_grad()
def evaluate(self, model, model_name='model'):
# TODO: Rework this to allow evaluation only runs
# Move common parts to parent class, and get loss function from the registry.
t = time.time()

model.eval()

tag = f'Evaluating {model_name} on {self.split} set'
# confusion_matrix = torch.zeros((self.num_classes, self.num_classes), device=self.device)

for batch_idx, data in enumerate(tqdm(self.val_loader, desc=tag)):
image, target = data['image'], data['target']
Expand All @@ -186,14 +183,8 @@ def evaluate(self, model, model_name='model'):

logits = model(image, output_shape=target.shape[-2:]).squeeze(dim=1)
mse = F.mse_loss(logits, target)
# pred = torch.argmax(logits, dim=1)
# valid_mask = target != -1
# pred, target = pred[valid_mask], target[valid_mask]
# count = torch.bincount((pred * self.num_classes + target), minlength=self.num_classes ** 2)
# confusion_matrix += count.view(self.num_classes, self.num_classes)

# torch.distributed.all_reduce(confusion_matrix, op=torch.distributed.ReduceOp.SUM)
metrics = {"MSE" : mse.item, "RMSE" : torch.sqrt(mse).item}

metrics = {"MSE" : mse.item(), "RMSE" : torch.sqrt(mse).item()}
self.log_metrics(metrics)

used_time = time.time() - t
Expand All @@ -204,20 +195,11 @@ def evaluate(self, model, model_name='model'):
def __call__(self, model, model_name='model'):
return self.evaluate(model, model_name)


# def compute_metrics(self, confusion_matrix):
# iou = torch.diag(confusion_matrix) / (confusion_matrix.sum(dim=1) + confusion_matrix.sum(dim=0) - torch.diag(confusion_matrix)) * 100
# iou = iou.cpu()
# metrics = {'IoU': [iou[i].item() for i in range(self.num_classes)], 'mIoU': iou.mean().item()}

# return metrics

def log_metrics(self, metrics):
header = "------- MSE and RMSE --------\n"
# iou = '\n'.join(c.ljust(self.max_name_len, ' ') + '\t{:>7}'.format('%.3f' % num) for c, num in zip(self.classes, metrics['MSE'])) + '\n'
mse = "-------------------\n" + 'MSE \t{:>7}'.format('%.3f' % metrics['MSE'])+'\n'
rmse = "-------------------\n" + 'RMSE \t{:>7}'.format('%.3f' % metrics['RMSE'])
self.logger.info(header+mse+rmse)

if self.args.use_wandb and self.args.rank == 0:
self.wandb.log({"val_MSE": metrics["MSE"], "val_RMSE": metrics["RMSE"]})
self.wandb.log({"val_MSE": metrics["MSE"], "val_RMSE": metrics["RMSE"]})
3 changes: 2 additions & 1 deletion environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies:
- pillow
- pytorch>=2.1
- rasterio
- scikit-image
- scikit-learn
- tensorboard
- torchaudio
Expand All @@ -28,4 +29,4 @@ dependencies:
- google-cloud-storage
- omegaconf
- pydataverse
- pytest
- pytest
Loading

0 comments on commit 4a2712d

Please sign in to comment.