Skip to content

Commit

Permalink
Merge pull request #11 from AdaptiveMotorControlLab/mwm/tests-packaging
Browse files Browse the repository at this point in the history
some re-arrangement & additions for python packaging
  • Loading branch information
hsirm authored Oct 28, 2024
2 parents 7960ad5 + f360df6 commit 5e0634a
Show file tree
Hide file tree
Showing 17 changed files with 265 additions and 35 deletions.
65 changes: 65 additions & 0 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
name: Python package

on:
push:
branches:
- main
pull_request:
branches:
- main

jobs:
build:
strategy:
fail-fast: true
matrix:
os: [ubuntu-latest]
python-version: ["3.12"]
# We aim to support the versions on pytorch.org
# as well as selected previous versions on
# https://pytorch.org/get-started/previous-versions/
torch-version: ["2.4.0"]
include:
- os: windows-latest
torch-version: 2.4.0
python-version: "3.12"

runs-on: ${{ matrix.os }}

steps:
- name: Cache dependencies
id: pip-cache
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}

- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install package
run: |
python -m pip install git+https://github.com/RobustBench/robustbench.git
python -m pip install --upgrade pip setuptools wheel
python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
pip install '.[dev]'
- name: Run pytest tests
timeout-minutes: 10
run: |
pip install pytest
python -m pytest
- name: Build package
run: |
make build
- name: Check reinstall script
timeout-minutes: 3
run: |
./reinstall.sh
21 changes: 21 additions & 0 deletions .github/workflows/codespell.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
---
name: Codespell

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
codespell:
name: Check for spelling errors
runs-on: ubuntu-latest

steps:
- name: Checkout
uses: actions/checkout@v3
- name: Codespell
uses: codespell-project/actions-codespell@v1
with:
ignore_words_list: aros, fpr, tpr, idx, fpr95
52 changes: 52 additions & 0 deletions .github/workflows/release-pypi.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
name: release

on:
push:
tags:
- 'v*.*.*'
pull_request:
branches:
- main
types:
- labeled
- opened
- edited
- synchronize
- reopened

jobs:
release:
runs-on: ubuntu-latest

steps:
- name: Cache dependencies
id: pip-cache
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip

- name: Checkout code
uses: actions/checkout@v3

- name: Build and publish to Test PyPI
if: ${{ (github.ref != 'refs/heads/main') && (github.event.label.name == 'release') }}
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.TEST_PYPI_TOKEN }}
run: |
make dist
ls dist/
tar tvf dist/aros-node-*.tar.gz
python3 -m twine upload --repository testpypi dist/*
- name: Build and publish to PyPI
if: ${{ github.event_name == 'push' }}
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
run: |
make dist
ls dist/
tar tvf dist/aros-node-*.tar.gz
python3 -m twine upload dist/*
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.DS_Store
.tar.gz

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
3 changes: 0 additions & 3 deletions AROS/__init__.py

This file was deleted.

9 changes: 9 additions & 0 deletions aros_node/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# © M.W. Mathis Lab | Hossein Mirzaei & M.W. Mathis
# https://github.com/AdaptiveMotorControlLab/AROS
# Licensed under Apache 2.0

from aros_node.version import __version__
from aros_node.data_loader import LabelChangedDataset, get_subsampled_subset, get_loaders
from aros_node.evaluate import compute_fpr95, compute_auroc, compute_aupr, get_clean_AUC, wrapper_method
from aros_node.stability_loss_function import *
from aros_node.utils import *
File renamed without changes.
2 changes: 1 addition & 1 deletion AROS/evaluate.py → aros_node/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from tqdm.notebook import tqdm
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, auc
from utils import *
from aros_node.utils import *
import argparse
import torch
import torch.nn as nn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from robustbench.utils import load_model
import torch.nn as nn
from torch.nn.parameter import Parameter
import utils
from utils import *
import aros_node.utils
from aros_node.utils import *
from torch.utils.data import DataLoader, Dataset, TensorDataset, Subset, SubsetRandomSampler, ConcatDataset
import numpy as np
from tqdm.notebook import tqdm
Expand Down
File renamed without changes.
1 change: 1 addition & 0 deletions aros_node/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.0.1"
10 changes: 5 additions & 5 deletions AROS/main.py → main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

!pip install -r requirements.txt
import aros_node
import argparse
import torch
import torch.nn as nn
from evaluate import *
from utils import *
from aros_node.evaluate import *
from aros_node.utils import *
from tqdm.notebook import tqdm
from data_loader import *
from stability_loss_function import *
from aros_node.data_loader import *
from aros_node.stability_loss_function import *

def main():
parser = argparse.ArgumentParser(description="Hyperparameters for the script")
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"
25 changes: 25 additions & 0 deletions reinstall.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/bash

# Re-install the package. By running './reinstall.sh'
#
# Note that AROS uses the build
# system specified in
# PEP517 https://peps.python.org/pep-0517/ and
# PEP518 https://peps.python.org/pep-0518/
# and hence there is no setup.py file.

set -e # abort on error

pip uninstall -y aros-node

# Get version
VERSION=0.0.1
echo "Upgrading to AROS v${VERSION}"

# Upgrade the build system (PEP517/518 compatible)
python3 -m pip install virtualenv
python3 -m pip install --upgrade build
python3 -m build --sdist --wheel .

# Reinstall the package with most recent version
pip install --upgrade --no-cache-dir "dist/aros_node-${VERSION}-py3-none-any.whl"
9 changes: 7 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
geotorch
torch
torchdiffeq
git+https://github.com/RobustBench/robustbench.git
timm==1.0.9
timm==1.0.9
robustbench
numpy
scikit-learn
scipy
tqdm
26 changes: 4 additions & 22 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,41 +1,23 @@
[metadata]
name = aros
version = 0.0.1
name = aros-node
version = attr: aros_node.version.__version__
author = Hossein Mirzaei, Mackenzie Mathis
author_email = [email protected]
description = AROS: Adversarially Robust Out-of-Distribution Detection through Stability
long_description = file: README.md
long_description_content_type = text/markdown
license_files = LICENSE.md
license_file_type = text/markdown
url = https://github.com/AdaptiveMotorControlLab/AROS
project_urls =
Bug Tracker = https://github.com/AdaptiveMotorControlLab/AROS/issues
classifiers =
Development Status :: 4 - Beta
Environment :: GPU :: NVIDIA CUDA
Intended Audience :: Science/Research
Operating System :: OS Independent
Programming Language :: Python :: 3
Topic :: Scientific/Engineering :: Artificial Intelligence
License :: OSI Approved :: Apache Software License

[options]
packages = find:
include_package_data = True
python_requires = >=3.10
install_requires =
geotorch
torchdiffeq
git+https://github.com/RobustBench/robustbench.git
install_requires = file: requirements.txt

[options.extras_require]
dev =
pylint
toml
yapf
black
pytest

[bdist_wheel]
universal=0
pytest
69 changes: 69 additions & 0 deletions tests/test_dataloaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest
import torch
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import ToTensor
from aros_node import (
LabelChangedDataset,
get_subsampled_subset,
get_loaders,
)

# Set up transformations and datasets for tests
transform_tensor = ToTensor()

@pytest.fixture
def cifar10_datasets():
trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_tensor)
testset = CIFAR10(root='./data', train=False, download=True, transform=transform_tensor)
return trainset, testset

@pytest.fixture
def cifar100_datasets():
trainset = CIFAR100(root='./data', train=True, download=True, transform=transform_tensor)
testset = CIFAR100(root='./data', train=False, download=True, transform=transform_tensor)
return trainset, testset

def test_label_changed_dataset(cifar10_datasets):
_, testset = cifar10_datasets
new_label = 99
relabeled_dataset = LabelChangedDataset(testset, new_label)

assert len(relabeled_dataset) == len(testset), "Relabeled dataset should match the original dataset length"

for img, label in relabeled_dataset:
assert label == new_label, "All labels should be changed to the new label"

def test_get_subsampled_subset(cifar10_datasets):
trainset, _ = cifar10_datasets
subset_ratio = 0.1
subset = get_subsampled_subset(trainset, subset_ratio=subset_ratio)

expected_size = int(len(trainset) * subset_ratio)
assert len(subset) == expected_size, f"Subset size should be {expected_size}"

def test_get_loaders_cifar10(cifar10_datasets):
train_loader, test_loader, test_set, test_loader_vs_other = get_loaders('cifar10')

assert isinstance(train_loader, DataLoader)
assert isinstance(test_loader, DataLoader)
assert isinstance(test_loader_vs_other, DataLoader)

for images, labels in test_loader:
assert images.shape[0] == 16, "Test loader batch size should be 16"
break

def test_get_loaders_cifar100(cifar100_datasets):
train_loader, test_loader, test_set, test_loader_vs_other = get_loaders('cifar100')

assert isinstance(train_loader, DataLoader)
assert isinstance(test_loader, DataLoader)
assert isinstance(test_loader_vs_other, DataLoader)

for images, labels in test_loader:
assert images.shape[0] == 16, "Test loader batch size should be 16"
break

def test_get_loaders_invalid_dataset():
with pytest.raises(ValueError, match="Dataset 'invalid_dataset' is not supported."):
get_loaders('invalid_dataset')

0 comments on commit 5e0634a

Please sign in to comment.