Commit eea4285

Add some preliminary tests and a github action for them (#30)

* Add some preliminary tests

* Add test github workflow

* Enable workflow on test branch

* Add badge, remove defaults channel

* Add env cache

* Learn to spell

* Cache the correct miniconda dir

* See below

* Update actions to latest version

* Update actions to latest version again

* Enable linting

* Fix linting?

* Fix linting?

* Fix linting?
KerekesDavid authored Sep 11, 2024
1 parent 00f1816 commit eea4285
Showing 11 changed files with 197 additions and 105 deletions.
67 changes: 67 additions & 0 deletions .github/workflows/python-test.yml
@@ -0,0 +1,67 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Tests

on:
  push:
    branches: [ "main", "feature/test" ]
  pull_request:
    branches: [ "main" ]

env:
  CACHE_NUMBER: 0 # increase to reset cache manually

permissions:
  contents: read

jobs:
  build:
    defaults:
      run:
        shell: bash -l {0}

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.10
      uses: actions/setup-python@v5
      with:
        python-version: "3.10"

    - name: Lint with flake8
      run: |
        pip install flake8
        # stop the build if there are Python syntax errors or undefined names
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
    - name: Create environment with mamba
      uses: conda-incubator/setup-miniconda@v3
      with:
        mamba-version: "*"
        channels: pytorch, nvidia, conda-forge
        auto-activate-base: false
        activate-environment: gfm-bench8

    - name: Set cache date
      run: echo "DATE=$(date +'%Y%m%d')" >> $GITHUB_ENV
    - uses: actions/cache@v4
      with:
        path: /usr/share/miniconda/envs/gfm-bench8
        key: conda-${{ hashFiles('environment.yaml') }}-${{ env.DATE }}-${{ env.CACHE_NUMBER }}
      id: cache

    - name: Update environment
      run: mamba env update -n gfm-bench8 -f environment.yaml
      if: steps.cache.outputs.cache-hit != 'true'

    - name: Check solution
      run: |
        mamba env export
    - name: Test with pytest
      run: |
        python -m unittest tests.test_imports
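
The last step above runs the import smoke tests with unittest rather than pytest (despite the step name). A minimal way to reproduce that step locally — a sketch, assuming the repository root is the current working directory and the gfm-bench8 environment is activated:

```python
# Local equivalent of the workflow's "python -m unittest tests.test_imports" step.
import unittest

if __name__ == "__main__":
    # Load the import smoke tests by dotted name and run them verbosely.
    suite = unittest.defaultTestLoader.loadTestsFromName("tests.test_imports")
    unittest.TextTestRunner(verbosity=2).run(suite)
```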
2 changes: 2 additions & 0 deletions README.md
@@ -1,3 +1,5 @@
+[![Tests](https://github.com/yurujaja/geofm-bench/actions/workflows/python-test.yml/badge.svg)](https://github.com/yurujaja/geofm-bench/actions/workflows/python-test.yml)
+
 ## What is New
 In general, the architecture of the whole codebase has been refactored, and a few bugs and errors were fixed along the way.

13 changes: 1 addition & 12 deletions datasets/biomassters.py
@@ -97,15 +97,4 @@ def get_splits(dataset_config):
 
     @staticmethod
     def download(dataset_config:dict, silent=False):
-        pass
-
-
-if __name__ == '__main__':
-
-    dataset = BioMassters(cfg, split = "test")
-
-    train_dict = dataset.__getitem__(0)
-
-    print(train_dict["image"]["optical"].shape)
-    print(train_dict["image"]["sar"].shape)
-    print(train_dict["target"].shape)
+        pass
2 changes: 2 additions & 0 deletions datasets/utils.py
@@ -81,11 +81,13 @@ def download_blob_file_pair(blob_file_pair):
     else:
         print("Downloaded {} to {}.".format(name, destination_directory + name))
 
+
 def read_tif(file: pathlib.Path):
     with rasterio.open(file) as dataset:
         arr = dataset.read() # (bands X height X width)
     return arr.transpose((1, 2, 0))
 
+
 def read_tif_with_metadata(file: pathlib.Path):
     with rasterio.open(file) as dataset:
         arr = dataset.read() # (bands X height X width)
6 changes: 4 additions & 2 deletions environment.yaml
@@ -1,8 +1,9 @@
 name: gfm-bench8
 channels:
   - pytorch
-  - conda-forge
   - nvidia
+  - conda-forge
+  - nodefaults
 dependencies:
   - python=3.10
   - geopandas
@@ -25,4 +26,5 @@ dependencies:
   - fastai::opencv-python-headless
   - google-cloud-storage
   - omegaconf
-  - pydataverse
+  - pydataverse
+  - pytest
31 changes: 1 addition & 30 deletions foundation_models/__init__.py
@@ -1,18 +1,3 @@
-
-# from .spectralgpt import vit_spectral_gpt
-# from .prithvi import MaskedAutoencoderViT
-# from .scalemae import ScaleMAE_baseline
-# from .croma import croma_vit
-# from .remoteclip import RemoteCLIP
-# from .ssl4eo_mae import mae_vit
-# from .ssl4eo_dino import vit
-# from .ssl4eo_moco import moco_vit
-# from .ssl4eo_data2vec import beit
-# from .dofa import dofa_vit
-# from .gfm_swin import SwinTransformer as GFM_SwinTransformer
-# from .gfm_swin import adapt_gfm_pretrained
-# from .satlasnet import Model as SATLASNet
-# from .satlasnet import Weights as SATLASNetWeights
 from .prithvi_encoder import Prithvi_Encoder
 from .remoteclip_encoder import RemoteCLIP_Encoder
 from .scalemae_encoder import ScaleMAE_Encoder
@@ -25,18 +10,4 @@
 from .ssl4eo_moco_encoder import SSL4EO_MOCO_Encoder
 from .ssl4eo_data2vec_encoder import SSL4EO_Data2Vec_Encoder
 from .ssl4eo_mae_encoder import SSL4EO_MAE_OPTICAL_Encoder, SSL4EO_MAE_SAR_Encoder
-from .unet_encoder import UNet_Encoder
-#
-# spectral_gpt_vit_base = vit_spectral_gpt
-# prithvi_vit_base = MaskedAutoencoderViT
-# scale_mae_large = ScaleMAE_baseline
-# croma = croma_vit
-# remote_clip = RemoteCLIP
-# ssl4eo_dino_small = vit
-# ssl4eo_moco = moco_vit
-# ssl4eo_data2vec_small = beit
-# gfm_swin_base = GFM_SwinTransformer
-# satlasnet = SATLASNet
-# dofa_vit = dofa_vit
-# ssl4eo_mae = mae_vit
-# adapt_gfm_pretrained = adapt_gfm_pretrained
+from .unet_encoder import UNet_Encoder
61 changes: 0 additions & 61 deletions test.py

This file was deleted.

Empty file added tests/__init__.py
Empty file.
44 changes: 44 additions & 0 deletions tests/test_datasets.py
@@ -0,0 +1,44 @@
import unittest
import unittest.mock


class testDatasetSetup(unittest.TestCase):
    def setUp(self):
        # TODO should we just glob these for convenience?
        self.datasets = {
            "ai4smallfarms": "configs/datasets/ai4smallfarms.yaml",
            "biomassters": "configs/datasets/biomassters.yaml",
            "croptypemapping": "configs/datasets/croptypemapping.yaml",
            "fivebillionpixels": "configs/datasets/fivebillionpixels.yaml",
            "hlsburnscars": "configs/datasets/hlsburnscars.yaml",
            "mados": "configs/datasets/mados.yaml",
            "sen1floods11": "configs/datasets/sen1floods11.yaml",
            "spacenet7": "configs/datasets/spacenet7.yaml",
            "spacenet7cd": "configs/datasets/spacenet7cd.yaml",
            "xview2": "configs/datasets/xview2.yaml",
        }

    def test_download(self):
        from utils.configs import load_configs
        import foundation_models.utils
        from run import parser
        from utils.registry import DATASET_REGISTRY

        for dataset, config_path in self.datasets.items():
            mock_argv = [
                'run.py',
                '--config', 'configs/run/mados_prithvi.yaml',
                '--dataset_config', config_path
            ]
            with unittest.mock.patch('sys.argv', mock_argv):
                with self.subTest(dataset=dataset):
                    print(f"Downloading dataset {dataset}")
                    cfg = load_configs(parser)

                    dataset = DATASET_REGISTRY.get(cfg.dataset.dataset_name)
                    dataset.download(cfg.dataset, silent=False)
                    dataset_splits = dataset.get_splits(cfg.dataset)

                    for ds in dataset_splits:
                        input = next(iter(ds))
                        self.assertTrue(input)  # TODO some sanity checks here based on the config file
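
The test above drives the project's own CLI parser by patching sys.argv. A self-contained sketch of that pattern with a throwaway argparse parser (the parser, flag name, and config path here are illustrative, not the repository's actual run.py interface):

```python
import argparse
import unittest
import unittest.mock


class TestArgvPatching(unittest.TestCase):
    def test_parser_sees_patched_argv(self):
        # Stand-in parser; the real tests import `parser` from run.py instead.
        parser = argparse.ArgumentParser()
        parser.add_argument("--config")

        mock_argv = ["run.py", "--config", "configs/run/example.yaml"]
        with unittest.mock.patch("sys.argv", mock_argv):
            # parse_args() falls back to sys.argv[1:], which is now the mock.
            args = parser.parse_args()

        self.assertEqual(args.config, "configs/run/example.yaml")


if __name__ == "__main__":
    unittest.main()
```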
18 changes: 18 additions & 0 deletions tests/test_imports.py
@@ -0,0 +1,18 @@
import unittest


class testPackageImports(unittest.TestCase):
    def test_datasets(self):
        import datasets

    def test_foundation_models(self):
        import foundation_models

    def test_segmentors(self):
        import segmentors

    def test_engine(self):
        import engine

    def test_run(self):
        import run
58 changes: 58 additions & 0 deletions tests/test_models.py
@@ -0,0 +1,58 @@
import unittest
import unittest.mock

import os

import torch.nn as nn

from omegaconf import OmegaConf


class testModelBuild(unittest.TestCase):
    def setUp(self):
        self.models = {
            'croma': 'configs/foundation_models/croma.yaml',
            'dofa': 'configs/foundation_models/dofa.yaml',
            'gfmswin': 'configs/foundation_models/gfmswin.yaml',
            'prithvi': 'configs/foundation_models/prithvi.yaml',
            'remoteclip': 'configs/foundation_models/remoteclip.yaml',
            'satlasnet': 'configs/foundation_models/satlasnet.yaml',
            'scalemae': 'configs/foundation_models/scalemae.yaml',
            'spectralgpt': 'configs/foundation_models/spectralgpt.yaml',
            'ssl4eo_data2vec': 'configs/foundation_models/ssl4eo_data2vec.yaml',
            'ssl4eo_dino': 'configs/foundation_models/ssl4eo_dino.yaml',
            'ssl4eo_mae': 'configs/foundation_models/ssl4eo_mae.yaml',
            'ssl4eo_moco': 'configs/foundation_models/ssl4eo_moco.yaml',
            'unet_encoder': 'configs/foundation_models/unet_encoder.yaml',
        }

    def test_download(self):
        from utils.configs import load_configs
        import foundation_models.utils
        from run import parser

        for model, config_path in self.models.items():
            mock_argv = [
                'run.py',
                '--config', 'configs/run/mados_prithvi.yaml',
                '--encoder_config', config_path
            ]
            with unittest.mock.patch('sys.argv', mock_argv):
                with self.subTest(model=model):
                    cfg = load_configs(parser)

                    if 'download_url' in cfg.encoder:
                        if os.path.isfile(cfg.encoder.encoder_weights):
                            os.remove(cfg.encoder.encoder_weights)
                        res = foundation_models.utils.download_model(cfg.encoder)
                        self.assertTrue(res)

    # def test_build(self):
    #     for model in self.models.keys():
    #         with self.subTest(model=model):
    #             print(f"\nTesting {model}:")
    #             cfg = {'encoder_config': self.models[model]}
    #             model_cfg = load_specific_config(cfg, 'encoder_config')

    #             model = make_encoder(model_cfg)
    #             self.assertIsInstance(model, nn.Module)
    #             del model
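
Both new test modules rely on unittest's subTest so that one failing model or dataset does not abort the rest of the loop. A minimal illustration of that behaviour (the values are arbitrary):

```python
import unittest


class TestSubTestIsolation(unittest.TestCase):
    def test_each_value_reported_separately(self):
        # 5 fails its subtest, but 2, 4 and 8 are still checked and reported.
        for n in [2, 4, 5, 8]:
            with self.subTest(n=n):
                self.assertEqual(n % 2, 0)


if __name__ == "__main__":
    unittest.main()
```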
