From 4002dd84b9ee8b6735a93c3b364966be5dbd7ec6 Mon Sep 17 00:00:00 2001 From: yurujaja Date: Fri, 20 Sep 2024 17:03:26 +0200 Subject: [PATCH] Squash merge: Merged from LeungTsang-main pr#51 --- configs/datasets/dynamicen.yaml | 2 +- datasets/mados.py | 3 +- datasets/utae_dynamicen.py | 2 +- engine/data_preprocessor.py | 74 ++++++++++------------------ engine/trainer.py | 52 ++++++------------- foundation_models/prithvi_encoder.py | 4 -- run.py | 44 ++++++++--------- segmentors/upernet.py | 2 +- 8 files changed, 66 insertions(+), 117 deletions(-) diff --git a/configs/datasets/dynamicen.yaml b/configs/datasets/dynamicen.yaml index 190c10c5..46b75340 100644 --- a/configs/datasets/dynamicen.yaml +++ b/configs/datasets/dynamicen.yaml @@ -4,7 +4,7 @@ download_url: None auto_download: False img_size: 1024 -multi_temporal: False +multi_temporal: 6 multi_modal: False # classes diff --git a/datasets/mados.py b/datasets/mados.py index 9e917217..62a4826d 100644 --- a/datasets/mados.py +++ b/datasets/mados.py @@ -23,7 +23,6 @@ from .utils import DownloadProgressBar from utils.registry import DATASET_REGISTRY -import matplotlib.pyplot as plt ############################################################### @@ -156,4 +155,4 @@ def get_splits(dataset_config): dataset_train = MADOS(cfg=dataset_config, split="train", is_train=True) dataset_val = MADOS(cfg=dataset_config, split="val", is_train=False) dataset_test = MADOS(cfg=dataset_config, split="test", is_train=False) - return dataset_train, dataset_val, dataset_test + return dataset_train, dataset_val, dataset_test \ No newline at end of file diff --git a/datasets/utae_dynamicen.py b/datasets/utae_dynamicen.py index 3747a643..ca478da4 100644 --- a/datasets/utae_dynamicen.py +++ b/datasets/utae_dynamicen.py @@ -36,7 +36,7 @@ def __init__(self, cfg, split, is_train=True): self.split = split self.is_train = is_train - self.mode = 'single' + self.mode = 'weekly' self.files = [] diff --git a/engine/data_preprocessor.py b/engine/data_preprocessor.py index e8867003..c9d4633b 100644 --- a/engine/data_preprocessor.py +++ b/engine/data_preprocessor.py @@ -1,13 +1,15 @@ -import logging -import math import random -from typing import Callable +import math -import numpy as np -import omegaconf import torch import torch.nn.functional as F import torchvision.transforms as T +import torchvision.transforms.functional as TF +from typing import Callable + +import numpy as np +import logging +import omegaconf from utils.registry import AUGMENTER_REGISTRY @@ -16,7 +18,7 @@ def get_collate_fn(cfg: omegaconf.DictConfig) -> Callable: modalities = cfg.encoder.input_bands.keys() def collate_fn( - batch: dict[dict[str, torch.Tensor]], + batch: dict[dict[str, torch.Tensor]] ) -> dict[dict[str, torch.Tensor]]: """Collate function for torch DataLoader args: @@ -149,19 +151,9 @@ def __init__(self, cfg, modality): self.input_bands = getattr(cfg.encoder.input_bands, modality, []) self.encoder_name = cfg.encoder.encoder_name - self.used_bands_mask = torch.tensor( - [b in self.input_bands for b in self.dataset_bands], dtype=torch.bool - ) - self.avail_bands_mask = torch.tensor( - [b in self.dataset_bands for b in self.input_bands], dtype=torch.bool - ) - self.avail_bands_indices = torch.tensor( - [ - self.dataset_bands.index(b) if b in self.dataset_bands else -1 - for b in self.input_bands - ], - dtype=torch.long, - ) + self.used_bands_mask = torch.tensor([b in self.input_bands for b in self.dataset_bands], dtype=torch.bool) + self.avail_bands_mask = torch.tensor([b in self.dataset_bands for b in self.input_bands], dtype=torch.bool) + self.avail_bands_indices = torch.tensor([self.dataset_bands.index(b) if b in self.dataset_bands else -1 for b in self.input_bands], dtype=torch.long) self.need_padded = self.avail_bands_mask.sum() < len(self.input_bands) @@ -191,18 +183,10 @@ def __init__(self, cfg, modality): ) def preprocess_band_statistics(self, data_mean, data_std, data_min, data_max): - data_mean = [ - data_mean[i] if i != -1 else 0.0 for i in self.avail_bands_indices.tolist() - ] - data_std = [ - data_std[i] if i != -1 else 1.0 for i in self.avail_bands_indices.tolist() - ] - data_min = [ - data_min[i] if i != -1 else -1.0 for i in self.avail_bands_indices.tolist() - ] - data_max = [ - data_max[i] if i != -1 else 1.0 for i in self.avail_bands_indices.tolist() - ] + data_mean = [data_mean[i] if i != -1 else 0.0 for i in self.avail_bands_indices.tolist()] + data_std = [data_std[i] if i != -1 else 1.0 for i in self.avail_bands_indices.tolist()] + data_min = [data_min[i] if i != -1 else -1.0 for i in self.avail_bands_indices.tolist()] + data_max = [data_max[i] if i != -1 else 1.0 for i in self.avail_bands_indices.tolist()] return data_mean, data_std, data_min, data_max def preprocess_single_timeframe(self, image): @@ -326,19 +310,19 @@ def __getitem__(self, index): # Ignore overlapping borders if h_index != 0: - tiled_data["target"][..., 0:h_label_offset, :] = ( - self.dataset_cfg.ignore_index - ) + tiled_data["target"][ + ..., 0:h_label_offset, : + ] = self.dataset_cfg.ignore_index if w_index != 0: tiled_data["target"][..., 0:w_label_offset] = self.dataset_cfg.ignore_index if h_index != self.tiles_per_dim - 1: - tiled_data["target"][..., self.output_size - h_label_offset :, :] = ( - self.dataset_cfg.ignore_index - ) + tiled_data["target"][ + ..., self.output_size - h_label_offset :, : + ] = self.dataset_cfg.ignore_index if w_index != self.tiles_per_dim - 1: - tiled_data["target"][..., self.output_size - w_label_offset :] = ( - self.dataset_cfg.ignore_index - ) + tiled_data["target"][ + ..., self.output_size - w_label_offset : + ] = self.dataset_cfg.ignore_index return tiled_data @@ -507,18 +491,14 @@ def __getitem__(self, index): data = self.dataset[index] for k, v in data["image"].items(): if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands: - data["image"][k] = T.Resize(self.size)(v) + data["image"][k] = T.resize(v, self.size, interpolation=T.InterpolationMode.BILINEAR, antialias=True) if data["target"].ndim == 2: data["target"] = data["target"].unsqueeze(0) - data["target"] = T.Resize( - self.size, interpolation=T.InterpolationMode.NEAREST - )(data["target"]) + data["target"] = T.resize(data["target"], self.size, interpolation=T.InterpolationMode.NEAREST) data["target"] = data["target"].squeeze(0) else: - data["target"] = T.Resize( - self.size, interpolation=T.InterpolationMode.NEAREST - )(data["target"]) + data["target"] = T.resize(data["target"], self.size, interpolation=T.InterpolationMode.NEAREST) return data diff --git a/engine/trainer.py b/engine/trainer.py index 38f5325f..4c2df54e 100644 --- a/engine/trainer.py +++ b/engine/trainer.py @@ -92,22 +92,22 @@ def train_one_epoch(self, epoch): with torch.cuda.amp.autocast(enabled=self.enable_mixed_precision, dtype=self.precision): logits = self.model(image, output_shape=target.shape[-2:]) loss = self.compute_loss(logits, target) - self.compute_logging_metrics(logits.detach().clone(), target.detach().clone()) self.optimizer.zero_grad() - self.scaler.scale(loss).backward() - self.scaler.step(self.optimizer) - self.scaler.update() + if not torch.isnan(loss): + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + self.training_stats['loss'].update(loss.item()) + with torch.no_grad(): + self.compute_logging_metrics(logits, target) + if (batch_idx + 1) % self.args.log_interval == 0: + self.log(batch_idx + 1, epoch) + else: + self.logger.warning("Skip batch {} because of nan loss".format(batch_idx + 1)) self.lr_scheduler.step() - self.training_stats['loss'].update(loss.item()) - if (batch_idx + 1) % self.args.log_interval == 0: - self.log(batch_idx + 1, epoch) - self.training_stats['batch_time'].update(time.time() - end_time) - #print(self.training_stats['batch_time'].val, self.training_stats['batch_time'].avg) - end_time = time.time() - if self.use_wandb and self.rank == 0: self.wandb.log( { @@ -122,6 +122,9 @@ def train_one_epoch(self, epoch): step=epoch * len(self.train_loader) + batch_idx, ) + self.training_stats['batch_time'].update(time.time() - end_time) + end_time = time.time() + def get_checkpoint(self, epoch): checkpoint = { "model": self.model.module.state_dict(), @@ -277,35 +280,8 @@ def compute_loss(self, logits, target): @torch.no_grad() def compute_logging_metrics(self, logits, target): - # logits = F.interpolate(logits, size=target.shape[1:], mode='bilinear') - # print(logits.shape) - # print(target.shape) - mse = F.mse_loss(logits.squeeze(dim=1), target) - - # pred = torch.argmax(logits, dim=1, keepdim=True) - # target = target.unsqueeze(1) - # ignore_mask = target == -1 - # target[ignore_mask] = 0 - # ignore_mask = ignore_mask.expand(-1, logits.shape[1], -1, -1) - - # binary_pred = torch.zeros(logits.shape, dtype=bool, device=self.device) - # binary_target = torch.zeros(logits.shape, dtype=bool, device=self.device) - # binary_pred.scatter_(dim=1, index=pred, src=torch.ones_like(binary_pred)) - # binary_target.scatter_(dim=1, index=target, src=torch.ones_like(binary_target)) - # binary_pred[ignore_mask] = 0 - # binary_target[ignore_mask] = 0 - - # intersection = torch.logical_and(binary_pred, binary_target) - # union = torch.logical_or(binary_pred, binary_target) - - # acc = intersection.sum() / binary_target.sum() * 100 - # macc = torch.nanmean(intersection.sum(dim=(0, 2, 3)) / binary_target.sum(dim=(0, 2, 3))) * 100 - # miou = torch.nanmean(intersection.sum(dim=(0, 2, 3)) / union.sum(dim=(0, 2, 3))) * 100 - self.training_metrics['MSE'].update(mse.item()) - # self.training_metrics['mAcc'].update(macc.item()) - # self.training_metrics['mIoU'].update(miou.item()) diff --git a/foundation_models/prithvi_encoder.py b/foundation_models/prithvi_encoder.py index 6f94ca38..79508983 100644 --- a/foundation_models/prithvi_encoder.py +++ b/foundation_models/prithvi_encoder.py @@ -108,11 +108,7 @@ def forward(self, image): for i, blk in enumerate(self.blocks): x = blk(x) if i in self.output_layers: - #out = self.norm(x) if i == 11 else x - # print(x.shape) out = x[:, 1:, :].permute(0, 2, 1).view(x.shape[0], -1, self.num_frames, self.img_size // self.patch_size, self.img_size // self.patch_size).squeeze(2).contiguous() - # out = x[:, 1:, :].permute(0, 2, 1).reshape(x.shape[0], -1, self.img_size // self.patch_size, self.img_size // self.patch_size).contiguous() - output.append(out) return output diff --git a/run.py b/run.py index f861287d..4f20c1da 100644 --- a/run.py +++ b/run.py @@ -248,8 +248,6 @@ def main(): collate_fn = get_collate_fn(cfg) # training if not cfg.eval_dir: - - if 0 < cfg.limited_label < 1: indices = random.sample(range(len(train_dataset)), int(len(train_dataset)*cfg.limited_label)) train_dataset = Subset(train_dataset, indices) @@ -357,30 +355,30 @@ def main(): trainer.train() # Evaluation - else: - test_loader = DataLoader( - test_dataset, - sampler=DistributedSampler(test_dataset), - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, - pin_memory=True, - persistent_workers=False, - drop_last=False, - collate_fn=collate_fn, - ) + test_loader = DataLoader( + test_dataset, + sampler=DistributedSampler(test_dataset), + batch_size=cfg.batch_size, + num_workers=cfg.num_workers, + pin_memory=True, + persistent_workers=False, + drop_last=False, + collate_fn=collate_fn, + ) - logger.info("Built {} dataset for evaluation.".format(dataset_name)) + logger.info("Built {} dataset for evaluation.".format(dataset_name)) - if task_name == "regression": - # TODO: This doesn't work atm - test_evaluator = RegEvaluator(cfg, test_loader, exp_dir, device) - else: - test_evaluator = SegEvaluator(cfg, test_loader, exp_dir, device) + if task_name == "regression": + # TODO: This doesn't work atm + test_evaluator = RegEvaluator(cfg, test_loader, exp_dir, device) + else: + test_evaluator = SegEvaluator(cfg, test_loader, exp_dir, device) + + model_ckpt_path = os.path.join( + exp_dir, next(f for f in os.listdir(exp_dir) if f.endswith("_best.pth")) + ) + test_evaluator.evaluate(model, "best model", model_ckpt_path) - model_ckpt_path = os.path.join( - exp_dir, next(f for f in os.listdir(exp_dir) if f.endswith("_best.pth")) - ) - test_evaluator.evaluate(model, "best model", model_ckpt_path) if cfg.use_wandb and cfg.rank == 0: wandb.finish() diff --git a/segmentors/upernet.py b/segmentors/upernet.py index fdfc26e6..f4fe5951 100644 --- a/segmentors/upernet.py +++ b/segmentors/upernet.py @@ -231,7 +231,7 @@ def forward(self, img, output_shape=None): if self.multi_temporal_strategy == "ltae": feats[i] = self.tmap(feats[i]) elif self.multi_temporal_strategy == "linear": - feats[i] = self.tmap(feats[i].permute(0,1,3,4,2)).squeeze(-1) + feats[i] = self.tmap(feats[i].permute(0, 1, 3, 4, 2)).squeeze(-1) feat = self.neck(feats) feat = self._forward_feature(feat)