diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 0000000..59e9bb9 --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,65 @@ +name: Python package + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + build: + strategy: + fail-fast: true + matrix: + os: [ubuntu-latest] + python-version: ["3.12"] + # We aim to support the versions on pytorch.org + # as well as selected previous versions on + # https://pytorch.org/get-started/previous-versions/ + torch-version: ["2.4.0"] + include: + - os: windows-latest + torch-version: 2.4.0 + python-version: "3.12" + + runs-on: ${{ matrix.os }} + + steps: + - name: Cache dependencies + id: pip-cache + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }} + + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install package + run: | + python -m pip install git+https://github.com/RobustBench/robustbench.git + python -m pip install --upgrade pip setuptools wheel + python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu + pip install '.[dev]' + + - name: Run pytest tests + timeout-minutes: 10 + run: | + pip install pytest + python -m pytest + + - name: Build package + run: | + make build + + - name: Check reinstall script + timeout-minutes: 3 + run: | + ./reinstall.sh \ No newline at end of file diff --git a/.github/workflows/codespell.yml b/.github/workflows/codespell.yml new file mode 100644 index 0000000..a480cff --- /dev/null +++ b/.github/workflows/codespell.yml @@ -0,0 +1,21 @@ +--- +name: Codespell + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + codespell: + name: Check for spelling errors + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Codespell + uses: codespell-project/actions-codespell@v1 + with: + ignore_words_list: aros, fpr, tpr, idx, fpr95 \ No newline at end of file diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml new file mode 100644 index 0000000..4378652 --- /dev/null +++ b/.github/workflows/release-pypi.yml @@ -0,0 +1,52 @@ +name: release + +on: + push: + tags: + - 'v*.*.*' + pull_request: + branches: + - main + types: + - labeled + - opened + - edited + - synchronize + - reopened + +jobs: + release: + runs-on: ubuntu-latest + + steps: + - name: Cache dependencies + id: pip-cache + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip + + - name: Checkout code + uses: actions/checkout@v3 + + - name: Build and publish to Test PyPI + if: ${{ (github.ref != 'refs/heads/main') && (github.event.label.name == 'release') }} + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.TEST_PYPI_TOKEN }} + run: | + make dist + ls dist/ + tar tvf dist/aros-node-*.tar.gz + python3 -m twine upload --repository testpypi dist/* + + - name: Build and publish to PyPI + if: ${{ github.event_name == 'push' }} + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} + run: | + make dist + ls dist/ + tar tvf dist/aros-node-*.tar.gz + python3 -m twine upload dist/* \ No newline at end of file diff --git a/.gitignore b/.gitignore index f27f895..bcb9b59 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .DS_Store +.tar.gz # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/AROS/__init__.py b/AROS/__init__.py deleted file mode 100644 index fc97800..0000000 --- a/AROS/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -#© M.W. Mathis Lab | Hossein Mirzaei & M.W. Mathis -# https://github.com/https://github.com/AdaptiveMotorControlLab/AROS -# Licensed under Apache 2.0 \ No newline at end of file diff --git a/aros_node/__init__.py b/aros_node/__init__.py new file mode 100644 index 0000000..f0a11bb --- /dev/null +++ b/aros_node/__init__.py @@ -0,0 +1,9 @@ +# © M.W. Mathis Lab | Hossein Mirzaei & M.W. Mathis +# https://github.com/AdaptiveMotorControlLab/AROS +# Licensed under Apache 2.0 + +from aros_node.version import __version__ +from aros_node.data_loader import LabelChangedDataset, get_subsampled_subset, get_loaders +from aros_node.evaluate import compute_fpr95, compute_auroc, compute_aupr, get_clean_AUC, wrapper_method +from aros_node.stability_loss_function import * +from aros_node.utils import * \ No newline at end of file diff --git a/AROS/data_loader.py b/aros_node/data_loader.py similarity index 100% rename from AROS/data_loader.py rename to aros_node/data_loader.py diff --git a/AROS/evaluate.py b/aros_node/evaluate.py similarity index 99% rename from AROS/evaluate.py rename to aros_node/evaluate.py index c6a2919..1732550 100644 --- a/AROS/evaluate.py +++ b/aros_node/evaluate.py @@ -1,7 +1,7 @@ import numpy as np from tqdm.notebook import tqdm from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, auc -from utils import * +from aros_node.utils import * import argparse import torch import torch.nn as nn diff --git a/AROS/stability_loss_function.py b/aros_node/stability_loss_function.py similarity index 99% rename from AROS/stability_loss_function.py rename to aros_node/stability_loss_function.py index 75dddb4..413c6f3 100644 --- a/AROS/stability_loss_function.py +++ b/aros_node/stability_loss_function.py @@ -2,8 +2,8 @@ from robustbench.utils import load_model import torch.nn as nn from torch.nn.parameter import Parameter -import utils -from utils import * +import aros_node.utils +from aros_node.utils import * from torch.utils.data import DataLoader, Dataset, TensorDataset, Subset, SubsetRandomSampler, ConcatDataset import numpy as np from tqdm.notebook import tqdm diff --git a/AROS/utils.py b/aros_node/utils.py similarity index 100% rename from AROS/utils.py rename to aros_node/utils.py diff --git a/aros_node/version.py b/aros_node/version.py new file mode 100644 index 0000000..f102a9c --- /dev/null +++ b/aros_node/version.py @@ -0,0 +1 @@ +__version__ = "0.0.1" diff --git a/AROS/main.py b/main.py similarity index 96% rename from AROS/main.py rename to main.py index 6845400..4fa7a17 100644 --- a/AROS/main.py +++ b/main.py @@ -1,13 +1,13 @@ -!pip install -r requirements.txt +import aros_node import argparse import torch import torch.nn as nn -from evaluate import * -from utils import * +from aros_node.evaluate import * +from aros_node.utils import * from tqdm.notebook import tqdm -from data_loader import * -from stability_loss_function import * +from aros_node.data_loader import * +from aros_node.stability_loss_function import * def main(): parser = argparse.ArgumentParser(description="Hyperparameters for the script") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8fe2f47 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/reinstall.sh b/reinstall.sh new file mode 100755 index 0000000..082d492 --- /dev/null +++ b/reinstall.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# Re-install the package. By running './reinstall.sh' +# +# Note that AROS uses the build +# system specified in +# PEP517 https://peps.python.org/pep-0517/ and +# PEP518 https://peps.python.org/pep-0518/ +# and hence there is no setup.py file. + +set -e # abort on error + +pip uninstall -y aros-node + +# Get version +VERSION=0.0.1 +echo "Upgrading to AROS v${VERSION}" + +# Upgrade the build system (PEP517/518 compatible) +python3 -m pip install virtualenv +python3 -m pip install --upgrade build +python3 -m build --sdist --wheel . + +# Reinstall the package with most recent version +pip install --upgrade --no-cache-dir "dist/aros_node-${VERSION}-py3-none-any.whl" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c5b0a81..2cb89b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,9 @@ geotorch +torch torchdiffeq -git+https://github.com/RobustBench/robustbench.git -timm==1.0.9 \ No newline at end of file +timm==1.0.9 +robustbench +numpy +scikit-learn +scipy +tqdm \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 64cb16f..231ab7b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,33 +1,18 @@ [metadata] -name = aros -version = 0.0.1 +name = aros-node +version = attr: aros_node.version.__version__ author = Hossein Mirzaei, Mackenzie Mathis author_email = mackenzie@post.harvard.edu description = AROS: Adversarially Robust Out-of-Distribution Detection through Stability long_description = file: README.md long_description_content_type = text/markdown -license_files = LICENSE.md -license_file_type = text/markdown url = https://github.com/AdaptiveMotorControlLab/AROS -project_urls = - Bug Tracker = https://github.com/AdaptiveMotorControlLab/AROS/issues -classifiers = - Development Status :: 4 - Beta - Environment :: GPU :: NVIDIA CUDA - Intended Audience :: Science/Research - Operating System :: OS Independent - Programming Language :: Python :: 3 - Topic :: Scientific/Engineering :: Artificial Intelligence - License :: OSI Approved :: Apache Software License [options] packages = find: include_package_data = True python_requires = >=3.10 -install_requires = - geotorch - torchdiffeq - git+https://github.com/RobustBench/robustbench.git +install_requires = file: requirements.txt [options.extras_require] dev = @@ -35,7 +20,4 @@ dev = toml yapf black - pytest - -[bdist_wheel] -universal=0 \ No newline at end of file + pytest \ No newline at end of file diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py new file mode 100644 index 0000000..058fd90 --- /dev/null +++ b/tests/test_dataloaders.py @@ -0,0 +1,69 @@ +import pytest +import torch +from torch.utils.data import DataLoader, Subset +from torchvision.datasets import CIFAR10, CIFAR100 +from torchvision.transforms import ToTensor +from aros_node import ( + LabelChangedDataset, + get_subsampled_subset, + get_loaders, +) + +# Set up transformations and datasets for tests +transform_tensor = ToTensor() + +@pytest.fixture +def cifar10_datasets(): + trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_tensor) + testset = CIFAR10(root='./data', train=False, download=True, transform=transform_tensor) + return trainset, testset + +@pytest.fixture +def cifar100_datasets(): + trainset = CIFAR100(root='./data', train=True, download=True, transform=transform_tensor) + testset = CIFAR100(root='./data', train=False, download=True, transform=transform_tensor) + return trainset, testset + +def test_label_changed_dataset(cifar10_datasets): + _, testset = cifar10_datasets + new_label = 99 + relabeled_dataset = LabelChangedDataset(testset, new_label) + + assert len(relabeled_dataset) == len(testset), "Relabeled dataset should match the original dataset length" + + for img, label in relabeled_dataset: + assert label == new_label, "All labels should be changed to the new label" + +def test_get_subsampled_subset(cifar10_datasets): + trainset, _ = cifar10_datasets + subset_ratio = 0.1 + subset = get_subsampled_subset(trainset, subset_ratio=subset_ratio) + + expected_size = int(len(trainset) * subset_ratio) + assert len(subset) == expected_size, f"Subset size should be {expected_size}" + +def test_get_loaders_cifar10(cifar10_datasets): + train_loader, test_loader, test_set, test_loader_vs_other = get_loaders('cifar10') + + assert isinstance(train_loader, DataLoader) + assert isinstance(test_loader, DataLoader) + assert isinstance(test_loader_vs_other, DataLoader) + + for images, labels in test_loader: + assert images.shape[0] == 16, "Test loader batch size should be 16" + break + +def test_get_loaders_cifar100(cifar100_datasets): + train_loader, test_loader, test_set, test_loader_vs_other = get_loaders('cifar100') + + assert isinstance(train_loader, DataLoader) + assert isinstance(test_loader, DataLoader) + assert isinstance(test_loader_vs_other, DataLoader) + + for images, labels in test_loader: + assert images.shape[0] == 16, "Test loader batch size should be 16" + break + +def test_get_loaders_invalid_dataset(): + with pytest.raises(ValueError, match="Dataset 'invalid_dataset' is not supported."): + get_loaders('invalid_dataset')