Skip to content

Commit

Permalink
Adding test for legacy checkpoints (Lightning-AI#17562)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored May 4, 2023
1 parent 137837d commit b87df54
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 23 deletions.
2 changes: 2 additions & 0 deletions .actions/pull_legacy_checkpoints.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#!/bin/bash

# Run this script from the project root.
URL="https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip"
mkdir -p tests/legacy
# wget is simpler but does not work on Windows
python -c "from urllib.request import urlretrieve; urlretrieve('$URL', 'tests/legacy/checkpoints.zip')"
ls -l tests/legacy/

unzip -o tests/legacy/checkpoints.zip -d tests/legacy/
ls -l tests/legacy/checkpoints/
10 changes: 7 additions & 3 deletions .azure/gpu-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,6 @@ jobs:
python requirements/pytorch/check-avail-extras.py
displayName: 'Env details'
- bash: bash .actions/pull_legacy_checkpoints.sh
displayName: 'Get legacy checkpoints'

- bash: python -m pytest pytorch_lightning
workingDirectory: src
condition: eq(variables['PACKAGE_NAME'], 'pytorch')
Expand All @@ -146,6 +143,13 @@ jobs:
condition: eq(variables['PACKAGE_NAME'], 'pytorch')
displayName: 'Adjust tests & examples'
- bash: |
bash .actions/pull_legacy_checkpoints.sh
cd tests/legacy
bash generate_checkpoints.sh
ls -l checkpoints/
displayName: 'Get legacy checkpoints'
- bash: python -m coverage run --source ${COVERAGE_SOURCE} -m pytest --ignore benchmarks -v --durations=50
workingDirectory: tests/tests_pytorch
env:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/_legacy-checkpoints.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,6 @@ jobs:
title: Adding test for legacy checkpoint created with ${{ needs.create-legacy-ckpts.outputs.pl-version }}
delete-branch: true
labels: |
checkpointing
tests
pl
10 changes: 7 additions & 3 deletions .github/workflows/ci-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,6 @@ jobs:
if: ${{ matrix.requires == 'oldest' }}
run: python .actions/assistant.py replace_oldest_ver

- name: Pull legacy checkpoints
run: bash .actions/pull_legacy_checkpoints.sh

- name: Adjust PyTorch versions in requirements files
if: ${{ matrix.requires != 'oldest' && matrix.release != 'pre' }}
run: |
Expand Down Expand Up @@ -159,6 +156,13 @@ jobs:
- name: Prevent using raw source
run: rm -rf src/

- name: Get legacy checkpoints
run: |
bash .actions/pull_legacy_checkpoints.sh
cd tests/legacy
bash generate_checkpoints.sh
ls -l checkpoints/
- name: Testing Warnings
working-directory: tests/tests_pytorch
# needs to run outside of `pytest`
Expand Down
9 changes: 9 additions & 0 deletions tests/legacy/back-compatible-versions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,12 @@
1.8.4
1.8.5
1.8.6
1.9.0
1.9.1
1.9.2
1.9.3
1.9.4
1.9.5
2.0.0
2.0.1
2.0.2
20 changes: 9 additions & 11 deletions tests/legacy/generate_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@ set -e
LEGACY_PATH=$(cd $(dirname $0); pwd -P)
ENV_PATH=$LEGACY_PATH/vEnv
export PYTHONPATH=$(dirname $LEGACY_PATH) # for `import tests_pytorch`
echo LEGACY_PATH: $LEGACY_PATH
echo ENV_PATH: $ENV_PATH
echo PYTHONPATH: $PYTHONPATH
printf "LEGACY_PATH: $LEGACY_PATH"
printf "ENV_PATH: $ENV_PATH"
printf "PYTHONPATH: $PYTHONPATH"
rm -rf $ENV_PATH

function create_and_save_checkpoint {
python --version
python -m pip --version
python -m pip list

python $LEGACY_PATH/simple_classif_training.py
python $LEGACY_PATH/simple_classif_training.py $pl_ver

cp $LEGACY_PATH/simple_classif_training.py $LEGACY_PATH/checkpoints/$pl_ver
mv $LEGACY_PATH/checkpoints/$pl_ver/lightning_logs/version_0/checkpoints/*.ckpt $LEGACY_PATH/checkpoints/$pl_ver/
Expand All @@ -28,11 +29,9 @@ function create_and_save_checkpoint {
# iterate over all arguments assuming that each argument is version
for pl_ver in "$@"
do
echo processing version: $pl_ver
printf "processing version: $pl_ver"

# Don't install/update anything before activating venv
# to avoid breaking any existing environment.
rm -rf $ENV_PATH
# Don't install/update anything before activating venv to avoid breaking any existing environment.
python -m venv $ENV_PATH
source $ENV_PATH/bin/activate

Expand All @@ -47,10 +46,9 @@ done

# use the PL installed in the environment if no PL version is specified
if [[ -z "$@" ]]; then
pl_ver=$(python -c "import pytorch_lightning as pl; print(pl.__version__)")
echo processing version: $pl_ver
printf "processing local version"

python -m pip install -r $LEGACY_PATH/requirements.txt

pl_ver="local"
create_and_save_checkpoint
fi
2 changes: 1 addition & 1 deletion tests/legacy/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
torchmetrics # necessary because old PL verions don't have it as dependency
torchmetrics # necessary because old PL versions don't have it as dependency
scikit-learn
10 changes: 6 additions & 4 deletions tests/legacy/simple_classif_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

import torch

import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import EarlyStopping
import lightning.pytorch as pl
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import EarlyStopping
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.simple_models import ClassificationModel

Expand Down Expand Up @@ -50,5 +51,6 @@ def main_train(dir_path, max_epochs: int = 20):


if __name__ == "__main__":
path_dir = os.path.join(PATH_LEGACY, "checkpoints", str(pl.__version__))
name = sys.argv[1] if len(sys.argv) > 1 else str(pl.__version__)
path_dir = os.path.join(PATH_LEGACY, "checkpoints", name)
main_train(path_dir)
2 changes: 2 additions & 0 deletions tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
# load list of all back compatible versions
with open(os.path.join(_PATH_LEGACY, "back-compatible-versions.txt")) as fp:
LEGACY_BACK_COMPATIBLE_PL_VERSIONS = [ln.strip() for ln in fp.readlines()]
# This shall be created for each CI run
LEGACY_BACK_COMPATIBLE_PL_VERSIONS += ["local"]


@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_imports_unified(pl_version: str):
path_ckpt = path_ckpts[-1]

# only below version 1.5.0 we pickled stuff in checkpoints
if Version(pl_version) < Version("1.5.0"):
if pl_version != "local" and Version(pl_version) < Version("1.5.0"):
context = pytest.warns(UserWarning, match="Redirecting import of")
else:
context = no_warning_call(match="Redirecting import of*")
Expand Down

0 comments on commit b87df54

Please sign in to comment.