[ENH] check_estimator utility (#90)
This adds a `check_estimator` utility for checking implemented objects
against their type-specific contracts.
fkiraly authored Sep 12, 2023
1 parent 7f2ce21 commit f1baf8e
Showing 3 changed files with 219 additions and 10 deletions.
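For orientation before the diffs, a minimal usage sketch of the new utility, mirroring the docstring examples added in `skpro/utils/estimator_checks.py` below (assumes `pytest` is installed, since `check_estimator` is a developer utility):

```python
# Minimal usage sketch of the new check_estimator utility (requires pytest).
from sklearn.linear_model import LinearRegression

from skpro.regression.residual import ResidualDouble
from skpro.utils import check_estimator

# run all tests for the class; uses instances from get_test_params
results = check_estimator(ResidualDouble)

# run all tests for a specific constructed instance
results = check_estimator(ResidualDouble(LinearRegression()))

# results maps pytest-style "test[fixture]" strings to "PASSED"
# or to the exception raised by the failing test
failed = {k: v for k, v in results.items() if v != "PASSED"}
```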
4 changes: 4 additions & 0 deletions skpro/utils/__init__.py
@@ -1 +1,5 @@
"""Utility functionality."""

from skpro.utils.estimator_checks import check_estimator

__all__ = ["check_estimator"]
178 changes: 178 additions & 0 deletions skpro/utils/estimator_checks.py
@@ -0,0 +1,178 @@
"""Estimator checker for extension."""

__author__ = ["fkiraly"]
__all__ = ["check_estimator"]

from inspect import isclass

from skpro.utils.validation._dependencies import _check_soft_dependencies


def check_estimator(
estimator,
raise_exceptions=False,
tests_to_run=None,
fixtures_to_run=None,
verbose=True,
tests_to_exclude=None,
fixtures_to_exclude=None,
):
"""Run all tests on one single estimator.
Tests that are run on estimator:
all tests in test_all_estimators
all interface compatibility tests from the module of estimator's scitype
for example, test_all_regressors if estimator is a regressor
Parameters
----------
estimator : estimator class or estimator instance
raise_exceptions : bool, optional, default=False
whether to return exceptions/failures in the results dict, or raise them
* if False: returns exceptions in returned `results` dict
* if True: raises exceptions as they occur
tests_to_run : str or list of str, optional. Default = run all tests.
Names (test/function name string) of tests to run.
sub-sets tests that are run to the tests given here.
fixtures_to_run : str or list of str, optional. Default = run all tests.
pytest test-fixture combination codes, which test-fixture combinations to run.
sub-sets tests and fixtures to run to the list given here.
If both tests_to_run and fixtures_to_run are provided, runs the *union*,
i.e., all test-fixture combinations for tests in tests_to_run,
plus all test-fixture combinations in fixtures_to_run.
verbose : bool, optional, default=True
whether to print out informative summary of tests run.
tests_to_exclude : str or list of str, names of tests to exclude. default = None
removes tests that should not be run, after subsetting via tests_to_run.
fixtures_to_exclude : str or list of str, fixtures to exclude. default = None
removes test-fixture combinations that should not be run.
This is done after subsetting via fixtures_to_run.

Returns
-------
results : dict of results of the tests run on the estimator
keys are test/fixture strings, identical as in pytest, e.g., test[fixture]
entries are the string "PASSED" if the test passed,
or the exception raised if the test did not pass
returned only if all tests pass, or raise_exceptions=False

Raises
------
if raise_exceptions=True,
raises any exception produced by the tests directly

Examples
--------
>>> from skpro.regression.residual import ResidualDouble
>>> from skpro.utils import check_estimator

Running all tests for the ResidualDouble class,
this uses all instances from get_test_params and compatible scenarios

>>> results = check_estimator(ResidualDouble)
All tests PASSED!

Running all tests for a specific ResidualDouble instance,
this uses the instance that is passed and compatible scenarios

>>> from sklearn.linear_model import LinearRegression
>>> results = check_estimator(ResidualDouble(LinearRegression()))
All tests PASSED!

Running a specific test (all fixtures) for ResidualDouble

>>> results = check_estimator(ResidualDouble, tests_to_run="test_clone")
All tests PASSED!

{'test_clone[ResidualDouble-0]': 'PASSED',
'test_clone[ResidualDouble-1]': 'PASSED'}

Running one specific test-fixture combination for ResidualDouble

>>> check_estimator(
... ResidualDouble, fixtures_to_run="test_clone[ResidualDouble-1]"
... )
All tests PASSED!
{'test_clone[ResidualDouble-1]': 'PASSED'}
"""
msg = (
"check_estimator is a testing utility for developers, and "
"requires pytest to be present "
"in the python environment, but pytest was not found. "
"pytest is a developer dependency and not included in the base "
"sktime installation. Please run: `pip install pytest` to "
"install the pytest package. "
"To install sktime with all developer dependencies, run:"
" `pip install sktime[dev]`"
)
_check_soft_dependencies("pytest", msg=msg)

from skpro.base import BaseEstimator
from skpro.distributions.tests.test_all_distrs import TestAllDistributions
from skpro.regression.tests.test_all_regressors import TestAllRegressors
from skpro.tests.test_all_estimators import TestAllObjects

testclass_dict = dict()

testclass_dict["regressor"] = TestAllRegressors
testclass_dict["distribution"] = TestAllDistributions

results = TestAllObjects().run_tests(
obj=estimator,
raise_exceptions=raise_exceptions,
tests_to_run=tests_to_run,
fixtures_to_run=fixtures_to_run,
tests_to_exclude=tests_to_exclude,
fixtures_to_exclude=fixtures_to_exclude,
)

def is_estimator(obj):
"""Return whether obj is an estimator class or estimator object."""
if isclass(obj):
return issubclass(obj, BaseEstimator)
else:
return isinstance(obj, BaseEstimator)

# commented out for now - add when TestAllEstimators is added
# if is_estimator(estimator):
# results_estimator = TestAllEstimators().run_tests(
# obj=estimator,
# raise_exceptions=raise_exceptions,
# tests_to_run=tests_to_run,
# fixtures_to_run=fixtures_to_run,
# tests_to_exclude=tests_to_exclude,
# fixtures_to_exclude=fixtures_to_exclude,
# )
# results.update(results_estimator)

try:
scitype_of_estimator = estimator.get_tag("object_type", "object")
except Exception:
scitype_of_estimator = ""

if scitype_of_estimator in testclass_dict.keys():
results_scitype = testclass_dict[scitype_of_estimator]().run_tests(
estimator=estimator,
raise_exceptions=raise_exceptions,
tests_to_run=tests_to_run,
fixtures_to_run=fixtures_to_run,
tests_to_exclude=tests_to_exclude,
fixtures_to_exclude=fixtures_to_exclude,
)
results.update(results_scitype)

failed_tests = [key for key in results.keys() if results[key] != "PASSED"]
if len(failed_tests) > 0:
msg = failed_tests
msg = ["FAILED: " + x for x in msg]
msg = "\n".join(msg)
else:
msg = "All tests PASSED!"

if verbose:
# printing is an intended feature, for console usage and interactive debugging
print(msg) # noqa T001

return results
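The entry point above supports subsetting and exclusion of tests and fixtures, and either collecting or raising failures; a short sketch of these options (fixture codes follow pytest naming, the one used for exclusion is taken from the docstring examples and is purely illustrative):

```python
from skpro.regression.residual import ResidualDouble
from skpro.utils import check_estimator

# run only test_clone, across all of its fixtures, but skip one
# test-fixture combination; fixture codes follow pytest naming
results = check_estimator(
    ResidualDouble,
    tests_to_run="test_clone",
    fixtures_to_exclude="test_clone[ResidualDouble-1]",
    verbose=False,
)

# with raise_exceptions=False (the default), failures are collected in
# the returned dict; with raise_exceptions=True they are raised as soon
# as they occur, which is convenient for interactive debugging
results = check_estimator(ResidualDouble, raise_exceptions=True)
```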
47 changes: 37 additions & 10 deletions skpro/utils/validation/_dependencies.py
@@ -17,6 +17,7 @@ def _check_soft_dependencies(
package_import_alias=None,
severity="error",
obj=None,
msg=None,
suppress_import_stdout=False,
):
"""Check if required soft dependencies are installed and raise error or warning.
@@ -49,6 +50,8 @@
or a class is passed when it is called at the start of a single-class module,
the error message is more informative and will refer to the class/object;
if str is passed, will be used as name of the class/object or module
msg : str, or None, default=None
if str, will override the error message or warning shown with msg
suppress_import_stdout : bool, optional. Default=False
whether to suppress stdout printout upon import.
@@ -64,17 +67,24 @@
if len(packages) == 1 and isinstance(packages[0], (tuple, list)):
packages = packages[0]
if not all(isinstance(x, str) for x in packages):
raise TypeError("packages must be str or tuple of str")
raise TypeError(
"packages argument of _check_soft_dependencies must be str or tuple of "
f"str, but found packages argument of type {type(packages)}"
)

if package_import_alias is None:
package_import_alias = {}
msg = "package_import_alias must be a dict with str keys and values"
msg_pkg_import_alias = (
"package_import_alias argument of _check_soft_dependencies must "
"be a dict with str keys and values, but found "
f"package_import_alias of type {type(package_import_alias)}"
)
if not isinstance(package_import_alias, dict):
raise TypeError(msg)
raise TypeError(msg_pkg_import_alias)
if not all(isinstance(x, str) for x in package_import_alias.keys()):
raise TypeError(msg)
raise TypeError(msg_pkg_import_alias)
if not all(isinstance(x, str) for x in package_import_alias.values()):
raise TypeError(msg)
raise TypeError(msg_pkg_import_alias)

if obj is None:
class_name = "This functionality"
@@ -85,14 +95,25 @@
elif isinstance(obj, str):
class_name = obj
else:
raise TypeError("obj must be a class, an object, a str, or None")
raise TypeError(
"obj argument of _check_soft_dependencies must be a class, an object,"
" a str, or None, but found obj of type"
f" {type(obj)}"
)

if msg is not None and not isinstance(msg, str):
raise TypeError(
"msg argument of _check_soft_dependencies must be a str, "
f"or None, but found msg of type {type(msg)}"
)

for package in packages:
try:
req = Requirement(package)
except InvalidRequirement:
msg_version = (
f"wrong format for package requirement string, "
f"wrong format for package requirement string "
f"passed via packages argument of _check_soft_dependencies, "
f'must be PEP 440 compatible requirement string, e.g., "pandas"'
f' or "pandas>1.1", but found "{package}"'
)
@@ -117,15 +138,15 @@
pkg_ref = import_module(package_import_name)
# if package cannot be imported, make the user aware of installation requirement
except ModuleNotFoundError as e:
if obj is None:
if obj is None and msg is None:
msg = (
f"{e}. '{package}' is a soft dependency and not included in the "
f"base skpro installation. Please run: `pip install {package}` to "
f"install the {package} package. "
f"To install all soft dependencies, run: `pip install "
f"skpro[all_extras]`"
)
else:
elif msg is None: # obj is not None, msg is None
msg = (
f"{class_name} requires package '{package}' to be present "
f"in the python environment, but '{package}' was not found. "
@@ -135,6 +156,9 @@
f"To install all soft dependencies, run: `pip install "
f"skpro[all_extras]`"
)
# if msg is not None, none of the above is executed,
# so if msg is passed it overrides the default messages

if severity == "error":
raise ModuleNotFoundError(msg) from e
elif severity == "warning":
@@ -307,10 +331,13 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
compatible = compatible and _check_python_version(obj, severity=severity)

pkg_deps = obj.get_class_tag("python_dependencies", None)
pck_alias = obj.get_class_tag("python_dependencies_alias", None)
if pkg_deps is not None and not isinstance(pkg_deps, list):
pkg_deps = [pkg_deps]
if pkg_deps is not None:
pkg_deps_ok = _check_soft_dependencies(*pkg_deps, severity=severity, obj=obj)
pkg_deps_ok = _check_soft_dependencies(
*pkg_deps, severity=severity, obj=obj, package_import_alias=pck_alias
)
compatible = compatible and pkg_deps_ok

return compatible

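The `msg` argument added to `_check_soft_dependencies` lets callers replace the default installation hint, which is how `check_estimator` customises its `pytest` message; a minimal sketch of the pattern, using `matplotlib` and `scikit-learn` purely as illustrative dependencies:

```python
from skpro.utils.validation._dependencies import _check_soft_dependencies

# default behaviour: if matplotlib is absent, raises ModuleNotFoundError
# with the generic "soft dependency ... pip install ..." message
_check_soft_dependencies("matplotlib")

# msg overrides the default text entirely; severity="warning" emits a
# warning instead of raising
_check_soft_dependencies(
    "matplotlib",
    msg="matplotlib is required for plotting; run `pip install matplotlib`.",
    severity="warning",
)

# package_import_alias maps the PyPI name to the import name, as used by
# the new python_dependencies_alias tag lookup in _check_estimator_deps
_check_soft_dependencies(
    "scikit-learn", package_import_alias={"scikit-learn": "sklearn"}
)
```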