Skip to content

Commit

Permalink
[MNT] Differential testing of estimators (#96)
Browse files Browse the repository at this point in the history
This PR introduces the differential testing from `sktime` to the PR CI,
to test only estimators from modules that have changed.

This does not affect the release CI, which runs the tests for all
estimators in `sktime`.
  • Loading branch information
fkiraly authored Sep 13, 2023
1 parent c8ef880 commit f47629a
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 3 deletions.
18 changes: 16 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ jobs:
os: [ubuntu-20.04, windows-latest, macOS-11]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- run: git remote set-branches origin 'main'

- run: git fetch --depth 1

- name: Set up Python
uses: actions/setup-python@v4
Expand All @@ -89,6 +93,9 @@ jobs:
- name: Show dependencies
run: python -m pip list

- name: Show available branches
run: git branch -a

- name: Run tests
run: make test

Expand All @@ -104,7 +111,11 @@ jobs:
os: [ubuntu-20.04, windows-latest, macOS-11]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- run: git remote set-branches origin 'main'

- run: git fetch --depth 1

- name: Set up Python
uses: actions/setup-python@v4
Expand All @@ -121,6 +132,9 @@ jobs:
- name: Show dependencies
run: python -m pip list

- name: Show available branches
run: git branch -a

- name: Run tests
run: make test

Expand Down
30 changes: 30 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Main configuration file for pytest.
Contents:
adds an --only_changed_modules option to pytest
this allows to turn on/off differential testing (for shorter runtime)
"on" condition ensures that only estimators are tested that have changed,
more precisely, only estimators whose class is in a module
that has changed compared to the main branch
by default, this is off, including for default local runs of pytest
"""
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file)

__author__ = ["fkiraly"]


def pytest_addoption(parser):
"""Pytest command line parser options adder."""
parser.addoption(
"--only_changed_modules",
default=False,
help="test only estimators from modules that have changed compared to main",
)


def pytest_configure(config):
"""Pytest configuration preamble."""
from skpro.tests import test_all_estimators

if config.getoption("--only_changed_modules") in [True, "True"]:
test_all_estimators.ONLY_CHANGED_MODULES = True
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ addopts =
--cov-report xml
--cov-report html
--showlocals
--only_changed_modules True
-n auto
filterwarnings =
ignore::UserWarning
Expand Down
22 changes: 21 additions & 1 deletion skpro/tests/test_all_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from skbase.testing.utils.inspect import _get_args

from skpro.registry import OBJECT_TAG_LIST
from skpro.utils.git_diff import is_class_changed

# whether to test only estimators from modules that are changed w.r.t. main
# default is False, can be set to True by pytest --only_changed_modules True flag
ONLY_CHANGED_MODULES = False


class PackageConfig:
Expand All @@ -29,7 +34,22 @@ class PackageConfig:
valid_tags = OBJECT_TAG_LIST


class TestAllObjects(PackageConfig, _TestAllObjects):
class BaseFixtureGenerator:
"""Base class for fixture generation, overrides skbase object retrieval."""

def _all_objects(self):
"""Retrieve list of all object classes of type self.object_type_filter."""
obj_list = super()._all_objects()

# this setting ensures that only estimators are tested that have changed
# in the sense that any line in the module is different from main
if ONLY_CHANGED_MODULES:
obj_list = [obj for obj in obj_list if is_class_changed(obj)]

return obj_list


class TestAllObjects(PackageConfig, BaseFixtureGenerator, _TestAllObjects):
"""Generic tests for all objects in the mini package."""

# override this due to reserved_params index, columns, in the BaseDistribution class
Expand Down
80 changes: 80 additions & 0 deletions skpro/utils/git_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Git related utilities to identify changed modules."""

__author__ = ["fkiraly"]
__all__ = []

import importlib.util
import inspect
import subprocess


def get_module_from_class(cls):
"""Get full parent module string from class.
Parameters
----------
cls : class
class to get module string from, e.g., NaiveForecaster
Returns
-------
str : module string, e.g., sktime.forecasting.naive
"""
module = inspect.getmodule(cls)
return module.__name__ if module else None


def get_path_from_module(module_str):
r"""Get local path string from module string.
Parameters
----------
module_str : str
module string, e.g., sktime.forecasting.naive
Returns
-------
str : local path string, e.g., sktime\forecasting\naive.py
"""
try:
module_spec = importlib.util.find_spec(module_str)
if module_spec is None:
raise ImportError(
f"Error in get_path_from_module, module '{module_str}' not found."
)
return module_spec.origin
except Exception as e:
raise ImportError(f"Error finding module '{module_str}'") from e


def is_module_changed(module_str):
"""Check if a module has changed compared to the main branch.
Parameters
----------
module_str : str
module string, e.g., sktime.forecasting.naive
"""
module_file_path = get_path_from_module(module_str)
cmd = f"git diff remotes/origin/main -- {module_file_path}"
try:
output = subprocess.check_output(cmd, shell=True, text=True)
return bool(output)
except subprocess.CalledProcessError:
return True


def is_class_changed(cls):
"""Check if a class' parent module has changed compared to the main branch.
Parameters
----------
cls : class
class to get module string from, e.g., NaiveForecaster
Returns
-------
bool : True if changed, False otherwise
"""
module_str = get_module_from_class(cls)
return is_module_changed(module_str)

0 comments on commit f47629a

Please sign in to comment.