-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add some preliminary tests and a github action for them (#30)
* Add some preliminary tests * Add test github workflow * Enable workflow on test branch * Add badge, remove defaults channel * Add env cache * Learn to spell * Cache the correct miniconda dir * See below * Update actions to latest version * Update actions to latest version again * Enable linting * Fix linting? * Fix linting? * Fix linting?
- Loading branch information
1 parent
00f1816
commit eea4285
Showing
11 changed files
with
197 additions
and
105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a single version of Python | ||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python | ||
|
||
name: Tests | ||
|
||
on: | ||
push: | ||
branches: [ "main", "feature/test" ] | ||
pull_request: | ||
branches: [ "main" ] | ||
|
||
env: | ||
CACHE_NUMBER: 0 # increase to reset cache manually | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
build: | ||
defaults: | ||
run: | ||
shell: bash -l {0} | ||
|
||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Set up Python 3.10 | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: "3.10" | ||
|
||
- name: Lint with flake8 | ||
run: | | ||
pip install flake8 | ||
# stop the build if there are Python syntax errors or undefined names | ||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics | ||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide | ||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics | ||
- name: Create environment with mamba | ||
uses: conda-incubator/setup-miniconda@v3 | ||
with: | ||
mamba-version: "*" | ||
channels: pytorch, nvidia, conda-forge | ||
auto-activate-base: false | ||
activate-environment: gfm-bench8 | ||
|
||
- name: Set cache date | ||
run: echo "DATE=$(date +'%Y%m%d')" >> $GITHUB_ENV | ||
- uses: actions/cache@v4 | ||
with: | ||
path: /usr/share/miniconda/envs/gfm-bench8 | ||
key: conda-${{ hashFiles('environment.yaml') }}-${{ env.DATE }}-${{ env.CACHE_NUMBER }} | ||
id: cache | ||
|
||
- name: Update environment | ||
run: mamba env update -n gfm-bench8 -f environment.yaml | ||
if: steps.cache.outputs.cache-hit != 'true' | ||
|
||
- name: Check solution | ||
run: | | ||
mamba env export | ||
- name: Test with pytest | ||
run: | | ||
python -m unittest tests.test_imports |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import unittest | ||
|
||
|
||
class testDatasetSetup(unittest.TestCase): | ||
def setUp(self): | ||
# TODO should we just glob these for convinience? | ||
self.datasets = { | ||
"ai4smallfarms": "configs/datasets/ai4smallfarms.yaml", | ||
"biomassters": "configs/datasets/biomassters.yaml", | ||
"croptypemapping": "configs/datasets/croptypemapping.yaml", | ||
"fivebillionpixels": "configs/datasets/fivebillionpixels.yaml", | ||
"hlsburnscars": "configs/datasets/hlsburnscars.yaml", | ||
"mados": "configs/datasets/mados.yaml", | ||
"sen1floods11": "configs/datasets/sen1floods11.yaml", | ||
"spacenet7": "configs/datasets/spacenet7.yaml", | ||
"spacenet7cd": "configs/datasets/spacenet7cd.yaml", | ||
"xview2": "configs/datasets/xview2.yaml", | ||
} | ||
|
||
def test_download(self): | ||
from utils.configs import load_configs | ||
import foundation_models.utils | ||
from run import parser | ||
from utils.registry import DATASET_REGISTRY | ||
|
||
for dataset in self.datasets.keys(): | ||
for dataset, config_path in self.datasets.items(): | ||
mock_argv = [ | ||
'run.py', | ||
'--config', 'configs/run/mados_prithvi.yaml', | ||
'--dataset_config', config_path | ||
] | ||
with unittest.mock.patch('sys.argv', mock_argv): | ||
with self.subTest(dataset=dataset): | ||
print(f"Downloading dataset {dataset}") | ||
cfg = load_configs(parser) | ||
|
||
dataset = DATASET_REGISTRY.get(cfg.dataset.dataset_name) | ||
dataset.download(cfg.dataset, silent=False) | ||
dataset_splits = dataset.get_splits(cfg.dataset) | ||
|
||
for ds in dataset_splits: | ||
input = next(iter(ds)) | ||
self.assertTrue(input) # TODO some sanity checks here based on the config file |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import unittest | ||
|
||
|
||
class testPackageImports(unittest.TestCase): | ||
def test_datasets(self): | ||
import datasets | ||
|
||
def test_foundation_models(self): | ||
import foundation_models | ||
|
||
def test_segmentors(self): | ||
import segmentors | ||
|
||
def test_engine(self): | ||
import engine | ||
|
||
def test_run(self): | ||
import run |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import unittest | ||
|
||
import os | ||
|
||
import torch.nn as nn | ||
|
||
from omegaconf import OmegaConf | ||
|
||
class testModelBuild(unittest.TestCase): | ||
def setUp(self): | ||
self.models = { | ||
'croma': 'configs/foundation_models/croma.yaml', | ||
'dofa': 'configs/foundation_models/dofa.yaml', | ||
'gfmswin': 'configs/foundation_models/gfmswin.yaml', | ||
'prithvi': 'configs/foundation_models/prithvi.yaml', | ||
'remoteclip': 'configs/foundation_models/remoteclip.yaml', | ||
'satlasnet': 'configs/foundation_models/satlasnet.yaml', | ||
'scalemae': 'configs/foundation_models/scalemae.yaml', | ||
'spectralgpt': 'configs/foundation_models/spectralgpt.yaml', | ||
'ssl4eo_data2vec': 'configs/foundation_models/ssl4eo_data2vec.yaml', | ||
'ssl4eo_dino': 'configs/foundation_models/ssl4eo_dino.yaml', | ||
'ssl4eo_mae': 'configs/foundation_models/ssl4eo_mae.yaml', | ||
'ssl4eo_moco': 'configs/foundation_models/ssl4eo_moco.yaml', | ||
'unet_encoder': 'configs/foundation_models/unet_encoder.yaml', | ||
'ssl4eo_moco': 'configs/models_config/ssl4eo_mae.yaml', | ||
} | ||
|
||
def test_download(self): | ||
from utils.configs import load_configs | ||
import foundation_models.utils | ||
from run import parser | ||
|
||
for model, config_path in self.models.items(): | ||
mock_argv = [ | ||
'run.py', | ||
'--config', 'configs/run/mados_prithvi.yaml', | ||
'--encoder_config', config_path | ||
] | ||
with unittest.mock.patch('sys.argv', mock_argv): | ||
with self.subTest(model=model): | ||
cfg = load_configs(parser) | ||
|
||
if 'download_url' in cfg.encoder: | ||
if os.path.isfile(cfg.encoder.encoder_weights): | ||
os.remove(cfg.encoder.encoder_weights) | ||
res = foundation_models.utils.download_model(cfg.encoder) | ||
self.assertTrue(res) | ||
|
||
# def test_build(self): | ||
# for model in self.models.keys(): | ||
# with self.subTest(model=model): | ||
# print(f"\nTesting {model}:") | ||
# cfg = {'encoder_config': self.models[model]} | ||
# model_cfg = load_specific_config(cfg, 'encoder_config') | ||
|
||
# model = make_encoder(model_cfg) | ||
# self.assertIsInstance(model, nn.Module) | ||
# del model |