From eea42858a1caf5c3d8ec2ef4e15a2c6df667be55 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kerekes=20D=C3=A1vid?=
Date: Wed, 11 Sep 2024 10:46:14 +0200
Subject: [PATCH] Add some preliminary tests and a GitHub Action for them (#30)

* Add some preliminary tests
* Add test GitHub workflow
* Enable workflow on test branch
* Add badge, remove defaults channel
* Add env cache
* Learn to spell
* Cache the correct miniconda dir
* See below
* Update actions to latest version
* Update actions to latest version again
* Enable linting
* Fix linting?
* Fix linting?
* Fix linting?
---
 .github/workflows/python-test.yml | 67 +++++++++++++++++++++++++++++++
 README.md                         |  2 +
 datasets/biomassters.py           | 13 +-----
 datasets/utils.py                 |  2 +
 environment.yaml                  |  6 ++-
 foundation_models/__init__.py     | 31 +------------
 test.py                           | 61 ----------------------------
 tests/__init__.py                 |  0
 tests/test_datasets.py            | 44 ++++++++++++++++++++
 tests/test_imports.py             | 18 +++++++++
 tests/test_models.py              | 58 ++++++++++++++++++++++++++
 11 files changed, 197 insertions(+), 105 deletions(-)
 create mode 100644 .github/workflows/python-test.yml
 delete mode 100644 test.py
 create mode 100644 tests/__init__.py
 create mode 100644 tests/test_datasets.py
 create mode 100644 tests/test_imports.py
 create mode 100644 tests/test_models.py

diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml
new file mode 100644
index 00000000..4ef4d443
--- /dev/null
+++ b/.github/workflows/python-test.yml
@@ -0,0 +1,67 @@
+# This workflow will install Python dependencies, run tests and lint with a single version of Python
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+
+name: Tests
+
+on:
+  push:
+    branches: [ "main", "feature/test" ]
+  pull_request:
+    branches: [ "main" ]
+
+env:
+  CACHE_NUMBER: 0 # increase to reset cache manually
+
+permissions:
+  contents: read
+
+jobs:
+  build:
+    defaults:
+      run:
+        shell: bash -l {0}
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python 3.10
+      uses: actions/setup-python@v5
+      with:
+        python-version: "3.10"
+
+    - name: Lint with flake8
+      run: |
+        pip install flake8
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+
+    - name: Create environment with mamba
+      uses: conda-incubator/setup-miniconda@v3
+      with:
+        mamba-version: "*"
+        channels: pytorch, nvidia, conda-forge
+        auto-activate-base: false
+        activate-environment: gfm-bench8
+
+    - name: Set cache date
+      run: echo "DATE=$(date +'%Y%m%d')" >> $GITHUB_ENV
+    - uses: actions/cache@v4
+      with:
+        path: /usr/share/miniconda/envs/gfm-bench8
+        key: conda-${{ hashFiles('environment.yaml') }}-${{ env.DATE }}-${{ env.CACHE_NUMBER }}
+      id: cache
+
+    - name: Update environment
+      run: mamba env update -n gfm-bench8 -f environment.yaml
+      if: steps.cache.outputs.cache-hit != 'true'
+
+    - name: Check solution
+      run: |
+        mamba env export
+
+    - name: Test with unittest
+      run: |
+        python -m unittest tests.test_imports
diff --git a/README.md b/README.md
index ffed85ea..b11d1850 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,5 @@
+[![Tests](https://github.com/yurujaja/geofm-bench/actions/workflows/python-test.yml/badge.svg)](https://github.com/yurujaja/geofm-bench/actions/workflows/python-test.yml)
+
 ## What is New
 
 In general, the architecture of the whole codebase is refactored and a few bugs and errors are fixed by the way.
diff --git a/datasets/biomassters.py b/datasets/biomassters.py
index 09484e31..94ec389b 100644
--- a/datasets/biomassters.py
+++ b/datasets/biomassters.py
@@ -97,15 +97,4 @@ def get_splits(dataset_config):
 
     @staticmethod
     def download(dataset_config:dict, silent=False):
-        pass
-
-
-if __name__ == '__main__':
-
-    dataset = BioMassters(cfg, split = "test")
-
-    train_dict = dataset.__getitem__(0)
-
-    print(train_dict["image"]["optical"].shape)
-    print(train_dict["image"]["sar"].shape)
-    print(train_dict["target"].shape)
\ No newline at end of file
+        pass
\ No newline at end of file
diff --git a/datasets/utils.py b/datasets/utils.py
index b1fcb775..9b28cd88 100644
--- a/datasets/utils.py
+++ b/datasets/utils.py
@@ -81,11 +81,13 @@ def download_blob_file_pair(blob_file_pair):
     else:
         print("Downloaded {} to {}.".format(name, destination_directory + name))
 
+
 def read_tif(file: pathlib.Path):
     with rasterio.open(file) as dataset:
         arr = dataset.read()  # (bands X height X width)
     return arr.transpose((1, 2, 0))
 
+
 def read_tif_with_metadata(file: pathlib.Path):
     with rasterio.open(file) as dataset:
         arr = dataset.read()  # (bands X height X width)
diff --git a/environment.yaml b/environment.yaml
index e3ecbd25..29d610e9 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -1,8 +1,9 @@
 name: gfm-bench8
 channels:
   - pytorch
-  - conda-forge
   - nvidia
+  - conda-forge
+  - nodefaults
 dependencies:
   - python=3.10
   - geopandas
@@ -25,4 +26,5 @@ dependencies:
   - fastai::opencv-python-headless
   - google-cloud-storage
   - omegaconf
-  - pydataverse
\ No newline at end of file
+  - pydataverse
+  - pytest
\ No newline at end of file
diff --git a/foundation_models/__init__.py b/foundation_models/__init__.py
index 11e0cd75..55e90d10 100644
--- a/foundation_models/__init__.py
+++ b/foundation_models/__init__.py
@@ -1,18 +1,3 @@
-
-# from .spectralgpt import vit_spectral_gpt
-# from .prithvi import MaskedAutoencoderViT
-# from .scalemae import ScaleMAE_baseline
-# from .croma import croma_vit
-# from .remoteclip import RemoteCLIP
-# from .ssl4eo_mae import mae_vit
-# from .ssl4eo_dino import vit
-# from .ssl4eo_moco import moco_vit
-# from .ssl4eo_data2vec import beit
-# from .dofa import dofa_vit
-# from .gfm_swin import SwinTransformer as GFM_SwinTransformer
-# from .gfm_swin import adapt_gfm_pretrained
-# from .satlasnet import Model as SATLASNet
-# from .satlasnet import Weights as SATLASNetWeights
 from .prithvi_encoder import Prithvi_Encoder
 from .remoteclip_encoder import RemoteCLIP_Encoder
 from .scalemae_encoder import ScaleMAE_Encoder
@@ -25,18 +10,4 @@
 from .ssl4eo_moco_encoder import SSL4EO_MOCO_Encoder
 from .ssl4eo_data2vec_encoder import SSL4EO_Data2Vec_Encoder
 from .ssl4eo_mae_encoder import SSL4EO_MAE_OPTICAL_Encoder, SSL4EO_MAE_SAR_Encoder
-from .unet_encoder import UNet_Encoder
-#
-# spectral_gpt_vit_base = vit_spectral_gpt
-# prithvi_vit_base = MaskedAutoencoderViT
-# scale_mae_large = ScaleMAE_baseline
-# croma = croma_vit
-# remote_clip = RemoteCLIP
-# ssl4eo_dino_small = vit
-# ssl4eo_moco = moco_vit
-# ssl4eo_data2vec_small = beit
-# gfm_swin_base = GFM_SwinTransformer
-# satlasnet = SATLASNet
-# dofa_vit = dofa_vit
-# ssl4eo_mae = mae_vit
-# adapt_gfm_pretrained = adapt_gfm_pretrained
\ No newline at end of file
+from .unet_encoder import UNet_Encoder
\ No newline at end of file
diff --git a/test.py b/test.py
deleted file mode 100644
index 8b0d1a9c..00000000
--- a/test.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from datasets import CropTypeMappingSouthSudan
-from omegaconf import OmegaConf
-import pdb
-from collections import Counter
-import numpy as np
-import os
-import tarfile
-import shutil
-
-ds = CropTypeMappingSouthSudan
-cfg = OmegaConf.load('configs/datasets/croptypemapping.yaml')
-
-ds_train, ds_val, ds_test = ds.get_splits(dataset_config=cfg)
-
-s2_sum = []
-s1_sum = []
-for i in range(len(ds_train)):
-    data = ds_train[i]
-    s2_sum.append(data['image']['optical'])
-    s1_sum.append(data['image']['sar'])
-    # print(data['image']['sar'].shape)
-    # print(data['target'].shape)
-    # print(data['metadata']['s2'].shape)
-    # print(data['metadata']['s1'].shape)
-    # print(i)
-    # print(data['image']['optical'].shape)
-    # print(data['image']['sar'].shape)
-    # print(data['target'].shape)
-    # print(data['target'].unique())
-    # print(data['metadata']['s2'].shape)
-    # print(data['metadata']['s1'].shape)
-s2_sum = np.concatenate(s2_sum, axis=1)
-s1_sum = np.concatenate(s1_sum, axis=1)
-print(s2_sum.shape)
-print(s1_sum.shape)
-
-s2_mean = np.mean(s2_sum, axis=(1,2,3))
-s2_std = np.std(s2_sum, axis=(1,2,3))
-s1_mean = np.mean(s1_sum, axis=(1,2,3))
-s1_std = np.std(s1_sum, axis=(1,2,3))
-print(s2_mean)
-print(s2_std)
-print(s1_mean)
-print(s1_std)
-# target_sum = np.concatenate(target_sum)
-
-# # Get unique values and their counts
-# unique_values, counts = np.unique(target_sum, return_counts=True)
-
-# # Print the counts
-# sum = 0
-# for value, count in zip(unique_values, counts):
-#     if value == 0:
-#         continue
-#     sum += count
-#     print(f"Value {value}: {count} times")
-
-# for value, count in zip(unique_values, counts):
-#     if value == 0:
-#         continue
-#     print(f"Value {value}: {count/sum*100:.2f}%")
\ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
new file mode 100644
index 00000000..52b4dc6a
--- /dev/null
+++ b/tests/test_datasets.py
@@ -0,0 +1,44 @@
+import unittest
+import unittest.mock
+
+
+class TestDatasetSetup(unittest.TestCase):
+    def setUp(self):
+        # TODO should we just glob these for convenience?
+        self.datasets = {
+            "ai4smallfarms": "configs/datasets/ai4smallfarms.yaml",
+            "biomassters": "configs/datasets/biomassters.yaml",
+            "croptypemapping": "configs/datasets/croptypemapping.yaml",
+            "fivebillionpixels": "configs/datasets/fivebillionpixels.yaml",
+            "hlsburnscars": "configs/datasets/hlsburnscars.yaml",
+            "mados": "configs/datasets/mados.yaml",
+            "sen1floods11": "configs/datasets/sen1floods11.yaml",
+            "spacenet7": "configs/datasets/spacenet7.yaml",
+            "spacenet7cd": "configs/datasets/spacenet7cd.yaml",
+            "xview2": "configs/datasets/xview2.yaml",
+        }
+
+    def test_download(self):
+        from utils.configs import load_configs
+        import foundation_models.utils
+        from run import parser
+        from utils.registry import DATASET_REGISTRY
+
+        for dataset, config_path in self.datasets.items():
+            mock_argv = [
+                'run.py',
+                '--config', 'configs/run/mados_prithvi.yaml',
+                '--dataset_config', config_path
+            ]
+            with unittest.mock.patch('sys.argv', mock_argv):
+                with self.subTest(dataset=dataset):
+                    print(f"Downloading dataset {dataset}")
+                    cfg = load_configs(parser)
+
+                    dataset_class = DATASET_REGISTRY.get(cfg.dataset.dataset_name)
+                    dataset_class.download(cfg.dataset, silent=False)
+                    dataset_splits = dataset_class.get_splits(cfg.dataset)
+
+                    for ds in dataset_splits:
+                        sample = next(iter(ds))
+                        self.assertTrue(sample)  # TODO some sanity checks here based on the config file
\ No newline at end of file
diff --git a/tests/test_imports.py b/tests/test_imports.py
new file mode 100644
index 00000000..6fd147a3
--- /dev/null
+++ b/tests/test_imports.py
@@ -0,0 +1,18 @@
+import unittest
+
+
+class TestPackageImports(unittest.TestCase):
+    def test_datasets(self):
+        import datasets
+
+    def test_foundation_models(self):
+        import foundation_models
+
+    def test_segmentors(self):
+        import segmentors
+
+    def test_engine(self):
+        import engine
+
+    def test_run(self):
+        import run
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 00000000..0bb76327
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,58 @@
+import unittest
+import unittest.mock
+
+import os
+
+import torch.nn as nn
+
+from omegaconf import OmegaConf
+
+class TestModelBuild(unittest.TestCase):
+    def setUp(self):
+        self.models = {
+            'croma': 'configs/foundation_models/croma.yaml',
+            'dofa': 'configs/foundation_models/dofa.yaml',
+            'gfmswin': 'configs/foundation_models/gfmswin.yaml',
+            'prithvi': 'configs/foundation_models/prithvi.yaml',
+            'remoteclip': 'configs/foundation_models/remoteclip.yaml',
+            'satlasnet': 'configs/foundation_models/satlasnet.yaml',
+            'scalemae': 'configs/foundation_models/scalemae.yaml',
+            'spectralgpt': 'configs/foundation_models/spectralgpt.yaml',
+            'ssl4eo_data2vec': 'configs/foundation_models/ssl4eo_data2vec.yaml',
+            'ssl4eo_dino': 'configs/foundation_models/ssl4eo_dino.yaml',
+            'ssl4eo_mae': 'configs/foundation_models/ssl4eo_mae.yaml',
+            'ssl4eo_moco': 'configs/foundation_models/ssl4eo_moco.yaml',
+            'unet_encoder': 'configs/foundation_models/unet_encoder.yaml',
+        }
+
+    def test_download(self):
+        from utils.configs import load_configs
+        import foundation_models.utils
+        from run import parser
+
+        for model, config_path in self.models.items():
+            mock_argv = [
+                'run.py',
+                '--config', 'configs/run/mados_prithvi.yaml',
+                '--encoder_config', config_path
+            ]
+            with unittest.mock.patch('sys.argv', mock_argv):
+                with self.subTest(model=model):
+                    cfg = load_configs(parser)
+
+                    if 'download_url' in cfg.encoder:
+                        if os.path.isfile(cfg.encoder.encoder_weights):
+                            os.remove(cfg.encoder.encoder_weights)
+                    res = foundation_models.utils.download_model(cfg.encoder)
+                    self.assertTrue(res)
+
+    # def test_build(self):
+    #     for model in self.models.keys():
+    #         with self.subTest(model=model):
+    #             print(f"\nTesting {model}:")
+    #             cfg = {'encoder_config': self.models[model]}
+    #             model_cfg = load_specific_config(cfg, 'encoder_config')
+
+    #             model = make_encoder(model_cfg)
+    #             self.assertIsInstance(model, nn.Module)
+    #             del model
\ No newline at end of file
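
To reproduce the checks locally, a minimal sketch (assuming the gfm-bench8 environment from environment.yaml is created and activated; note that tests.test_datasets and tests.test_models download datasets and pretrained weights, so they can be slow and disk-hungry):

    # import smoke tests only -- this is what the workflow currently runs
    python -m unittest tests.test_imports

    # full suite via standard unittest discovery, including the download tests
    python -m unittest discover -s tests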