Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] add suite test for docstring and get_test_params coverage #482

Merged
merged 9 commits into from
Oct 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ addopts =
--ignore build_tools
--ignore examples
--ignore docs
--doctest-modules
--durations 10
--timeout 600
--cov skpro
Expand Down
16 changes: 15 additions & 1 deletion skpro/distributions/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,5 +622,19 @@ def get_test_params(cls, parameter_set="default"):
"index": pd.Index(np.arange(3)),
"columns": pd.Index(np.arange(2)),
}
params2 = {
"bins": [
[[0, 1.5, 2, 3, 4], [5, 5.5, 5.8, 6.5, 7, 7.5]],
[(2, 12, 5), [0, 1, 2, 3, 4]],
[[1.5, 2.5, 3.1, 4, 5.4], [-4, -2, -1.5, 5, 10]],
],
"bin_mass": [
[[0.1, 0.2, 0.3, 0.4], [0.25, 0.1, 0, 0.4, 0.25]],
[[0.1, 0.2, 0.4, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1]],
[[0.06, 0.15, 0.09, 0.7], [0.4, 0.15, 0.325, 0.125]],
],
"index": pd.Index(np.arange(3)),
"columns": pd.Index(np.arange(2)),
}

return [params1]
return [params1, params2]
10 changes: 10 additions & 0 deletions skpro/metrics/survival/_c_harrell.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,13 @@ def _evaluate_by_index(self, y_true, y_pred, **kwargs):
return pd.DataFrame(res_df.mean(axis=1), columns=["C_Harrell"])
else:
return res_df

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator."""
# array case examples
params1 = {}
params2 = {"score": "quantile", "score_args": {"alpha": 0.5}}
params3 = {"normalization": "index"}

return [params1, params2, params3]
7 changes: 7 additions & 0 deletions skpro/metrics/survival/_spll.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,10 @@ def _evaluate_by_index(self, y_true, y_pred, **kwargs):
return pd.DataFrame(res.mean(axis=1), columns=["SPLL"])
else:
return res

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Retrieve test parameters."""
params1 = {}
params2 = {"multivariate": True}
return [params1, params2]
5 changes: 4 additions & 1 deletion skpro/regression/online/_dont_refit.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def get_test_params(cls, parameter_set="default"):
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`
"""
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LinearRegression, Ridge

from skpro.regression.residual import ResidualDouble
from skpro.survival.coxph import CoxPH
Expand All @@ -102,5 +102,8 @@ def get_test_params(cls, parameter_set="default"):
if _check_estimator_deps(CoxPH, severity="none"):
coxph = CoxPH()
params.append({"estimator": coxph})
else:
ridge = Ridge()
params.append({"estimator": ResidualDouble(ridge)})

return params
5 changes: 4 additions & 1 deletion skpro/regression/online/_refit.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def get_test_params(cls, parameter_set="default"):
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`
"""
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LinearRegression, Ridge

from skpro.regression.residual import ResidualDouble
from skpro.survival.coxph import CoxPH
Expand All @@ -177,5 +177,8 @@ def get_test_params(cls, parameter_set="default"):
if _check_estimator_deps(CoxPH, severity="none"):
coxph = CoxPH()
params.append({"estimator": coxph})
else:
ridge = Ridge()
params.append({"estimator": ResidualDouble(ridge)})

return params
5 changes: 4 additions & 1 deletion skpro/regression/online/_refit_every.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def get_test_params(cls, parameter_set="default"):
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`
"""
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LinearRegression, Ridge

from skpro.regression.residual import ResidualDouble
from skpro.survival.coxph import CoxPH
Expand All @@ -156,5 +156,8 @@ def get_test_params(cls, parameter_set="default"):
if _check_estimator_deps(CoxPH, severity="none"):
coxph = CoxPH()
params.append({"estimator": coxph})
else:
ridge = Ridge()
params.append({"estimator": ResidualDouble(ridge)})

return params
48 changes: 45 additions & 3 deletions skpro/tests/test_all_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class BaseFixtureGenerator(_BaseFixtureGenerator):
whether test with name test_name should be excluded for estimator est
should be used only for encoding general rules, not individual skips
individual skips should go on the EXCLUDED_TESTS list in _config
requires _generate_estimator_class and _generate_estimator_instance as is
requires _generate_object_class and _generate_object_instance as is
_excluded_scenario: static method (test_name: str, scenario) -> bool
whether scenario should be skipped in test with test_name test_name
requires _generate_estimator_scenario as is
Expand All @@ -74,10 +74,10 @@ class BaseFixtureGenerator(_BaseFixtureGenerator):
ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
object_instance: instance of estimator inheriting from BaseObject
ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
instances are generated by create_test_instance class method of estimator_class
instances are generated by create_test_instance class method of object_class
scenario: instance of TestScenario
ranges over all scenarios returned by retrieve_scenarios
applicable for estimator_class or estimator_instance
applicable for object_class or object_instance
"""

# overrides object retrieval in scikit-base
Expand Down Expand Up @@ -166,6 +166,12 @@ def _excluded_scenario(test_name, scenario):
class TestAllObjects(PackageConfig, BaseFixtureGenerator, _TestAllObjects):
"""Generic tests for all objects in the mini package."""

def test_doctest_examples(self, object_class):
"""Runs doctests for estimator class."""
import doctest

doctest.run_docstring_examples(object_class, globals())

# override this due to reserved_params index, columns, in the BaseDistribution class
# index and columns params behave like pandas, i.e., are changed after __init__
def test_constructor(self, object_class):
Expand Down Expand Up @@ -297,6 +303,42 @@ def unreserved(params):
)
assert is_equal, msg

def test_get_test_params_coverage(self, object_class):
"""Check that get_test_params has good test coverage.

Checks that:

* get_test_params returns at least two test parameter sets
"""
param_list = object_class.get_test_params()

if isinstance(param_list, dict):
param_list = [param_list]

def _coerce_to_list_of_str(obj):
if isinstance(obj, str):
return obj
elif isinstance(obj, list):
return obj
else:
return []

reserved_param_names = object_class.get_class_tag(
"reserved_params", tag_value_default=None
)
reserved_param_names = _coerce_to_list_of_str(reserved_param_names)
reserved_set = set(reserved_param_names)

param_names = object_class.get_param_names()
unreserved_param_names = set(param_names).difference(reserved_set)

if len(unreserved_param_names) > 0:
msg = (
f"{object_class.__name__}.get_test_params should return "
f"at least two test parameter sets, but only {len(param_list)} found."
)
assert len(param_list) > 1, msg


class TestAllEstimators(PackageConfig, BaseFixtureGenerator, _QuickTester):
"""Package level tests for all sktime estimators, i.e., objects with fit."""
Expand Down
117 changes: 117 additions & 0 deletions skpro/tests/test_doctest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file)
"""Doctest checks directed through pytest with conditional skipping."""

import importlib
import inspect
import pkgutil
from functools import lru_cache

from skpro.tests.test_all_estimators import ONLY_CHANGED_MODULES
from skpro.tests.test_switch import run_test_module_changed

EXCLUDE_MODULES_STARTING_WITH = ("all", "test")


def _all_functions(module_name):
"""Get all functions from a module, including submodules.

Excludes:

* modules starting with 'all' or 'test'.
* if the flag ``ONLY_CHANGED_MODULES`` is set, modules that have not changed,
compared to the ``main`` branch.

Parameters
----------
module_name : str
Name of the module.

Returns
-------
functions_list : list
List of tuples (function_name, function_object).
"""
res = _all_functions_cached(module_name, only_changed_modules=ONLY_CHANGED_MODULES)
# copy the result to avoid modifying the cached result
return res.copy()


@lru_cache
def _all_functions_cached(module_name, only_changed_modules=False):
"""Get all functions from a module, including submodules.

Excludes:

* modules starting with 'all' or 'test'.
* if ``only_changed_modules`` is ``True``, modules that have not changed,
compared to the ``main`` branch.

Parameters
----------
module_name : str
Name of the module.
only_changed_modules : bool, optional (default=False)
If True, only functions from modules that have changed are returned.

Returns
-------
functions_list : list
List of tuples (function_name, function_object).
"""
# Import the package
package = importlib.import_module(module_name)

# Initialize an empty list to hold all functions
functions_list = []

# Walk through the package's modules
package_path = package.__path__[0]
for _, modname, _ in pkgutil.walk_packages(
path=[package_path], prefix=package.__name__ + "."
):
# Skip modules starting with 'all' or 'test'
if modname.split(".")[-1].startswith(EXCLUDE_MODULES_STARTING_WITH):
continue

# Skip modules that have not changed
if only_changed_modules and not run_test_module_changed(modname):
continue

# Import the module
module = importlib.import_module(modname)

# Get all functions from the module
for name, obj in inspect.getmembers(module, inspect.isfunction):
# if function is imported from another module, skip it
if obj.__module__ != module.__name__:
continue
# add the function to the list
functions_list.append((name, obj))

return functions_list


def pytest_generate_tests(metafunc):
"""Test parameterization routine for pytest.

Fixtures parameterized
----------------------
func : all functions from sktime, as returned by _all_functions
if ONLY_CHANGED_MODULES is set, only functions from modules that have changed
"""
# we assume all four arguments are present in the test below
funcs_and_names = _all_functions("skpro")

if len(funcs_and_names) > 0:
funcs, names = zip(*funcs_and_names)

metafunc.parametrize("func", funcs, ids=names)
else:
metafunc.parametrize("func", [])


def test_all_functions_doctest(func):
"""Run doctest for all functions in skpro."""
import doctest

doctest.run_docstring_examples(func, globals())
Loading