From 4ea598c14618bbdd7063ca7c7bdbc309c2104dc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 26 Aug 2023 22:01:36 +0100 Subject: [PATCH 01/10] start work --- skpro/registry/__init__.py | 6 + skpro/registry/_lookup.py | 355 +++++++++++++++++++++++++++++++++++++ 2 files changed, 361 insertions(+) create mode 100644 skpro/registry/__init__.py create mode 100644 skpro/registry/_lookup.py diff --git a/skpro/registry/__init__.py b/skpro/registry/__init__.py new file mode 100644 index 000000000..acaf3f287 --- /dev/null +++ b/skpro/registry/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +"""Registry and lookup functionality.""" + +from skpro.registry._lookup import all_objects + +__all__ = ["all_objects"] diff --git a/skpro/registry/_lookup.py b/skpro/registry/_lookup.py new file mode 100644 index 000000000..3276ba80b --- /dev/null +++ b/skpro/registry/_lookup.py @@ -0,0 +1,355 @@ +# copyright: sktime developers, BSD-3-Clause License (see LICENSE file) +"""Registry lookup methods. + +This module exports the following methods for registry lookup: + +all_objects(object_types, filter_tags) + lookup and filtering of objects +""" + +__author__ = ["fkiraly"] +# all_objects is based on the sklearn utility all_estimators + + +from copy import deepcopy +from operator import itemgetter +from pathlib import Path + +import pandas as pd +from skbase.lookup import all_objects as _all_objects + +from sktime.base import BaseEstimator +from sktime.registry._base_classes import ( + BASE_CLASS_LIST, + BASE_CLASS_LOOKUP, + TRANSFORMER_MIXIN_LIST, +) +from sktime.registry._tags import OBJECT_TAG_REGISTER + +VALID_TRANSFORMER_TYPES = tuple(TRANSFORMER_MIXIN_LIST) +VALID_ESTIMATOR_BASE_TYPES = tuple(BASE_CLASS_LIST) + +VALID_ESTIMATOR_TYPES = ( + BaseEstimator, + *VALID_ESTIMATOR_BASE_TYPES, + *VALID_TRANSFORMER_TYPES, +) + + +def all_objects( + object_types=None, + filter_tags=None, + exclude_objects=None, + return_names=True, + as_dataframe=False, + return_tags=None, + suppress_import_stdout=True, +): + """Get a list of all objects from skpro. + + This function crawls the module and gets all classes that inherit + from skpro's and sklearn's base classes. + + Not included are: the base classes themselves, classes defined in test + modules. + + Parameters + ---------- + object_types: str, list of str, optional (default=None) + Which kind of objects should be returned. + if None, no filter is applied and all objects are returned. + if str or list of str, strings define scitypes specified in search + only objects that are of (at least) one of the scitypes are returned + possible str values are entries of registry.BASE_CLASS_REGISTER (first col) + for instance 'classifier', 'regressor', 'transformer', 'forecaster' + return_names: bool, optional (default=True) + if True, object class name is included in the all_objects() + return in the order: name, object class, optional tags, either as + a tuple or as pandas.DataFrame columns + if False, object class name is removed from the all_objects() + return. + filter_tags: dict of (str or list of str), optional (default=None) + For a list of valid tag strings, use the registry.all_tags utility. + subsets the returned objects as follows: + each key/value pair is statement in "and"/conjunction + key is tag name to sub-set on + value str or list of string are tag values + condition is "key must be equal to value, or in set(value)" + exclude_objects: str, list of str, optional (default=None) + Names of objects to exclude. + as_dataframe: bool, optional (default=False) + if True, all_objects will return a pandas.DataFrame with named + columns for all of the attributes being returned. + if False, all_objects will return a list (either a list of + objects or a list of tuples, see Returns) + return_tags: str or list of str, optional (default=None) + Names of tags to fetch and return each object's value of. + For a list of valid tag strings, use the registry.all_tags utility. + if str or list of str, + the tag values named in return_tags will be fetched for each + object and will be appended as either columns or tuple entries. + suppress_import_stdout : bool, optional. Default=True + whether to suppress stdout printout upon import. + + Returns + ------- + all_objects will return one of the following: + 1. list of objects, if return_names=False, and return_tags is None + 2. list of tuples (optional object name, class, ~optional object + tags), if return_names=True or return_tags is not None. + 3. pandas.DataFrame if as_dataframe = True + if list of objects: + entries are objects matching the query, + in alphabetical order of object name + if list of tuples: + list of (optional object name, object, optional object + tags) matching the query, in alphabetical order of object name, + where + ``name`` is the object name as string, and is an + optional return + ``object`` is the actual object + ``tags`` are the object's values for each tag in return_tags + and is an optional return. + if dataframe: + all_objects will return a pandas.DataFrame. + column names represent the attributes contained in each column. + "objects" will be the name of the column of objects, "names" + will be the name of the column of object class names and the string(s) + passed in return_tags will serve as column names for all columns of + tags that were optionally requested. + + Examples + -------- + >>> from sktime.registry import all_objects + >>> # return a complete list of objects as pd.Dataframe + >>> all_objects(as_dataframe=True) + >>> # return all forecasters by filtering for object type + >>> all_objects("forecaster") + >>> # return all forecasters which handle missing data in the input by tag filtering + >>> all_objects("forecaster", filter_tags={"handles-missing-data": True}) + + References + ---------- + Modified version from scikit-learn's `all_objects()`. + """ + MODULES_TO_IGNORE = ( + "tests", + "setup", + "contrib", + "benchmarking", + "utils", + "all", + "plotting", + ) + + result = [] + ROOT = str(Path(__file__).parent.parent) # sktime package root directory + + if object_types: + clsses = _check_object_types(object_types) + if not isinstance(object_types, list): + object_types = [object_types] + CLASS_LOOKUP = {x: y for x, y in zip(object_types, clsses)} + else: + CLASS_LOOKUP = None + + result = _all_objects( + object_types=object_types, + filter_tags=filter_tags, + exclude_objects=exclude_objects, + return_names=return_names, + as_dataframe=as_dataframe, + return_tags=return_tags, + suppress_import_stdout=suppress_import_stdout, + package_name="sktime", + path=ROOT, + modules_to_ignore=MODULES_TO_IGNORE, + class_lookup=CLASS_LOOKUP, + ) + + return result + + +def _check_list_of_str_or_error(arg_to_check, arg_name): + """Check that certain arguments are str or list of str. + + Parameters + ---------- + arg_to_check: argument we are testing the type of + arg_name: str, + name of the argument we are testing, will be added to the error if + ``arg_to_check`` is not a str or a list of str + + Returns + ------- + arg_to_check: list of str, + if arg_to_check was originally a str it converts it into a list of str + so that it can be iterated over. + + Raises + ------ + TypeError if arg_to_check is not a str or list of str + """ + # check that return_tags has the right type: + if isinstance(arg_to_check, str): + arg_to_check = [arg_to_check] + if not isinstance(arg_to_check, list) or not all( + isinstance(value, str) for value in arg_to_check + ): + raise TypeError( + f"Error in all_objects! Argument {arg_name} must be either\ + a str or list of str" + ) + return arg_to_check + + +def _get_return_tags(object, return_tags): + """Fetch a list of all tags for every_entry of all_objects. + + Parameters + ---------- + object: Baseobject, an sktime object + return_tags: list of str, + names of tags to get values for the object + + Returns + ------- + tags: a tuple with all the objects values for all tags in return tags. + a value is None if it is not a valid tag for the object provided. + """ + tags = tuple(object.get_class_tag(tag) for tag in return_tags) + return tags + + +def _check_tag_cond(object, filter_tags=None, as_dataframe=True): + """Check whether object satisfies filter_tags condition. + + Parameters + ---------- + object: Baseobject, an sktime object + filter_tags: dict of (str or list of str), default=None + subsets the returned objects as follows: + each key/value pair is statement in "and"/conjunction + key is tag name to sub-set on + value str or list of string are tag values + condition is "key must be equal to value, or in set(value)" + as_dataframe: bool, default=False + if False, return is as described below; + if True, return is converted into a pandas.DataFrame for pretty + display + + Returns + ------- + cond_sat: bool, whether object satisfies condition in filter_tags + """ + if not isinstance(filter_tags, dict): + raise TypeError("filter_tags must be a dict") + + cond_sat = True + + for key, value in filter_tags.items(): + if not isinstance(value, list): + value = [value] + cond_sat = cond_sat and object.get_class_tag(key) in set(value) + + return cond_sat + + +def all_tags( + object_types=None, + as_dataframe=False, +): + """Get a list of all tags from sktime. + + Retrieves tags directly from `_tags`, offers filtering functionality. + + Parameters + ---------- + object_types: string, list of string, optional (default=None) + Which kind of objects should be returned. + - If None, no filter is applied and all objects are returned. + - Possible values are 'classifier', 'regressor', 'transformer' and + 'forecaster' to get objects only of these specific types, or a list of + these to get the objects that fit at least one of the types. + as_dataframe: bool, optional (default=False) + if False, return is as described below; + if True, return is converted into a pandas.DataFrame for pretty + display + + Returns + ------- + tags: list of tuples (a, b, c, d), + in alphabetical order by a + a : string - name of the tag as used in the _tags dictionary + b : string - name of the scitype this tag applies to + must be in _base_classes.BASE_CLASS_SCITYPE_LIST + c : string - expected type of the tag value + should be one of: + "bool" - valid values are True/False + "int" - valid values are all integers + "str" - valid values are all strings + ("str", list_of_string) - any string in list_of_string is valid + ("list", list_of_string) - any individual string and sub-list is valid + d : string - plain English description of the tag + """ + + def is_tag_for_type(tag, object_types): + tag_types = tag[1] + tag_types = _check_list_of_str_or_error(tag_types, "tag_types") + + if isinstance(object_types, str): + object_types = [object_types] + + tag_types = set(tag_types) + object_types = set(object_types) + is_valid_tag_for_type = len(tag_types.intersection(object_types)) > 0 + + return is_valid_tag_for_type + + all_tags = OBJECT_TAG_REGISTER + + if object_types: + # checking, but not using the return since that is classes, not strings + _check_object_types(object_types) + all_tags = [tag for tag in all_tags if is_tag_for_type(tag, object_types)] + + all_tags = sorted(all_tags, key=itemgetter(0)) + + # convert to pd.DataFrame if as_dataframe=True + if as_dataframe: + columns = ["name", "scitype", "type", "description"] + all_tags = pd.DataFrame(all_tags, columns=columns) + + return all_tags + + +def _check_object_types(object_types): + """Return list of classes corresponding to type strings.""" + object_types = deepcopy(object_types) + + if not isinstance(object_types, list): + object_types = [object_types] # make iterable + + def _get_err_msg(object_type): + return ( + f"Parameter `object_type` must be None, a string or a list of " + f"strings. Valid string values are: " + f"{tuple(BASE_CLASS_LOOKUP.keys())}, but found: " + f"{repr(object_type)}" + ) + + for i, object_type in enumerate(object_types): + if not isinstance(object_type, (type, str)): + raise ValueError( + "Please specify `object_types` as a list of str or " "types." + ) + if isinstance(object_type, str): + if object_type not in BASE_CLASS_LOOKUP.keys(): + raise ValueError(_get_err_msg(object_type)) + object_type = BASE_CLASS_LOOKUP[object_type] + object_types[i] = object_type + elif isinstance(object_type, type): + pass + else: + raise ValueError(_get_err_msg(object_type)) + return object_types From 153b766a45e44bbf42f21ea906554046e85f24e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 26 Aug 2023 22:40:31 +0100 Subject: [PATCH 02/10] tags --- skpro/distributions/base.py | 1 + skpro/metrics/base.py | 1 + skpro/registry/__init__.py | 12 +- skpro/registry/_tags.py | 225 +++++++++++++++++++++++++++++ skpro/regression/base/_base.py | 3 +- skpro/tests/test_all_estimators.py | 26 +--- 6 files changed, 243 insertions(+), 25 deletions(-) create mode 100644 skpro/registry/_tags.py diff --git a/skpro/distributions/base.py b/skpro/distributions/base.py index 1523b17c0..cc5188274 100644 --- a/skpro/distributions/base.py +++ b/skpro/distributions/base.py @@ -20,6 +20,7 @@ class BaseDistribution(BaseObject): # default tag values - these typically make the "safest" assumption _tags = { + "object_type": "distribution", # type of object, e.g., 'distribution' "python_version": None, # PEP 440 python version specifier to limit versions "python_dependencies": None, # string or str list of pkg soft dependencies "reserved_params": ["index", "columns"], diff --git a/skpro/metrics/base.py b/skpro/metrics/base.py index 6dea97f2a..5dfdc58d6 100644 --- a/skpro/metrics/base.py +++ b/skpro/metrics/base.py @@ -35,6 +35,7 @@ class BaseProbaMetric(BaseObject): """ _tags = { + "object_type": "metric", # type of object "reserved_params": ["multioutput", "score_average"], "scitype:y_pred": "pred_proba", "lower_is_better": True, diff --git a/skpro/registry/__init__.py b/skpro/registry/__init__.py index acaf3f287..bf54ebc12 100644 --- a/skpro/registry/__init__.py +++ b/skpro/registry/__init__.py @@ -2,5 +2,15 @@ """Registry and lookup functionality.""" from skpro.registry._lookup import all_objects +from skpro.registry._tags import ( + OBJECT_TAG_LIST, + OBJECT_TAG_REGISTER, + check_tag_is_valid, +) -__all__ = ["all_objects"] +__all__ = [ + "OBJECT_TAG_LIST", + "OBJECT_TAG_REGISTER", + "all_objects", + "check_tag_is_valid", +] diff --git a/skpro/registry/_tags.py b/skpro/registry/_tags.py new file mode 100644 index 000000000..95bd99d88 --- /dev/null +++ b/skpro/registry/_tags.py @@ -0,0 +1,225 @@ +"""Register of estimator and object tags. + +Note for extenders: new tags should be entered in OBJECT_TAG_REGISTER. +No other place is necessary to add new tags. + +This module exports the following: + +--- +OBJECT_TAG_REGISTER - list of tuples + +each tuple corresponds to a tag, elements as follows: + 0 : string - name of the tag as used in the _tags dictionary + 1 : string - name of the scitype this tag applies to + must be in _base_classes.BASE_CLASS_SCITYPE_LIST + 2 : string - expected type of the tag value + should be one of: + "bool" - valid values are True/False + "int" - valid values are all integers + "str" - valid values are all strings + "list" - valid values are all lists of arbitrary elements + ("str", list_of_string) - any string in list_of_string is valid + ("list", list_of_string) - any individual string and sub-list is valid + ("list", "str") - any individual string or list of strings is valid + validity can be checked by check_tag_is_valid (see below) + 3 : string - plain English description of the tag + +--- + +OBJECT_TAG_TABLE - pd.DataFrame + OBJECT_TAG_REGISTER in table form, as pd.DataFrame + rows of OBJECT_TABLE correspond to elements in OBJECT_TAG_REGISTER + +OBJECT_TAG_LIST - list of string + elements are 0-th entries of OBJECT_TAG_REGISTER, in same order + +--- + +check_tag_is_valid(tag_name, tag_value) - checks whether tag_value is valid for tag_name +""" +import pandas as pd + +OBJECT_TAG_REGISTER = [ + # -------------------------- + # all objects and estimators + # -------------------------- + ( + "reserved_params", + "object", + "list", + "list of reserved parameter names", + ), + ( + "object_type", + "object", + "str", + "type of object, e.g., 'regressor', 'transformer'", + ), + ( + "estimator_type", + "estimator", + "str", + "type of estimator, e.g., 'regressor', 'transformer'", + ), + ( + "python_version", + "object", + "str", + "python version specifier (PEP 440) for estimator, or None = all versions ok", + ), + ( + "python_dependencies", + "object", + ("list", "str"), + "python dependencies of estimator as str or list of str", + ), + ( + "python_dependencies_alias", + "object", + "dict", + "should be provided if import name differs from package name, \ + key-value pairs are package name, import name", + ), + # ------------------ + # BaseProbaRegressor + # ------------------ + ( + "capability:multioutput", + "regressor_proba", + "bool", + "whether estimator supports multioutput regression", + ), + ( + "capability:missing", + "regressor_proba", + "bool", + "whether estimator supports missing values", + ), + # ---------------- + # BaseDistribution + # ---------------- + ( + "capabilities:approx", + "distribution", + ("list", "str"), + "methods of distr that are approximate", + ), + ( + "capabilities:exact", + "distribution", + ("list", "str"), + "methods of distr that are numerically exact", + ), + ( + "distr:measuretype", + "distribution", + ("str", ["continuous", "discrete", "mixed"]), + "measure type of distr", + ), + ( + "approx_mean_spl", + "distribution", + "int", + "sample size used in MC estimates of mean", + ), + ( + "approx_var_spl", + "distribution", + "int", + "sample size used in MC estimates of var", + ), + ( + "approx_energy_spl", + "distribution", + "int", + "sample size used in MC estimates of energy", + ), + ( + "approx_spl", + "distribution", + "int", + "sample size used in other MC estimates", + ), + ( + "bisect_iter", + "distribution", + "int", + "max iters for bisection method in ppf", + ), + # --------------- + # BaseProbaMetric + # --------------- + ( + "scitype:y_pred", + "metric", + "str", + "expected input type for y_pred in performance metric", + ), + ( + "lower_is_better", + "metric", + "bool", + "whether lower (True) or higher (False) is better", + ), + # ---------------------------- + # BaseMetaObject reserved tags + # ---------------------------- + ( + "named_object_parameters", + "object", + "str", + "name of component list attribute for meta-objects", + ), +] + +OBJECT_TAG_TABLE = pd.DataFrame(OBJECT_TAG_REGISTER) +OBJECT_TAG_LIST = OBJECT_TAG_TABLE[0].tolist() + + +def check_tag_is_valid(tag_name, tag_value): + """Check validity of a tag value. + + Parameters + ---------- + tag_name : string, name of the tag + tag_value : object, value of the tag + + Raises + ------ + KeyError - if tag_name is not a valid tag in OBJECT_TAG_LIST + ValueError - if the tag_valid is not a valid for the tag with name tag_name + """ + if tag_name not in OBJECT_TAG_LIST: + raise KeyError(tag_name + " is not a valid tag") + + tag_type = OBJECT_TAG_TABLE[2][OBJECT_TAG_TABLE[0] == "tag_name"] + + if tag_type == "bool" and not isinstance(tag_value, bool): + raise ValueError(tag_name + " must be True/False, found " + tag_value) + + if tag_type == "int" and not isinstance(tag_value, int): + raise ValueError(tag_name + " must be integer, found " + tag_value) + + if tag_type == "str" and not isinstance(tag_value, str): + raise ValueError(tag_name + " must be string, found " + tag_value) + + if tag_type == "list" and not isinstance(tag_value, list): + raise ValueError(tag_name + " must be list, found " + tag_value) + + if tag_type[0] == "str" and tag_value not in tag_type[1]: + raise ValueError( + tag_name + " must be one of " + tag_type[1] + " found " + tag_value + ) + + if tag_type[0] == "list" and not set(tag_value).issubset(tag_type[1]): + raise ValueError( + tag_name + " must be subest of " + tag_type[1] + " found " + tag_value + ) + + if tag_type[0] == "list" and tag_type[1] == "str": + msg = f"{tag_name} must be str or list of str, found {tag_value}" + if not isinstance(tag_value, (str, list)): + raise ValueError(msg) + if isinstance(tag_value, list): + if not all(isinstance(x, str) for x in tag_value): + raise ValueError(msg) diff --git a/skpro/regression/base/_base.py b/skpro/regression/base/_base.py index c2378c22c..c5bce1280 100644 --- a/skpro/regression/base/_base.py +++ b/skpro/regression/base/_base.py @@ -13,7 +13,8 @@ class BaseProbaRegressor(BaseEstimator): """Base class for probabilistic supervised regressors.""" _tags = { - "estimator_type": "regressor", + "object_type": "regressor_proba", # type of object, e.g., 'distribution' + "estimator_type": "regressor_proba", "capability:multioutput": False, "capability:missing": True, } diff --git a/skpro/tests/test_all_estimators.py b/skpro/tests/test_all_estimators.py index fd7e2a2f5..842baa2a0 100644 --- a/skpro/tests/test_all_estimators.py +++ b/skpro/tests/test_all_estimators.py @@ -8,6 +8,8 @@ from skbase.testing import TestAllObjects as _TestAllObjects from skbase.testing.utils.inspect import _get_args +from skpro.registry import OBJECT_TAG_LIST + class PackageConfig: """Contains package config variables for test classes.""" @@ -25,29 +27,7 @@ class PackageConfig: # list of valid tags # expected type: list of str, str are tag names - valid_tags = [ - # all estimators - "reserved_params", - "estimator_type", - "python_version", - "python_dependencies", - # BaseProbaRegressor - "capability:multioutput", - "capability:missing", - # BaseDistribution - "capabilities:approx", # list of str, methods of distr that are approximate - "capabilities:exact", # list of str, methods of distr that are num. exact - "distr:measuretype", # str, "continuous", "discrete", or "mixed" - "approx_mean_spl", # int, sample size used in MC estimates of mean - "approx_var_spl", # int, sample size used in MC estimates of var - "approx_energy_spl", # int, sample size used in MC estimates of energy - "approx_spl", # int, sample size used in other MC estimates - "bisect_iter", # max iters for bisection method in ppf - "scitype:y_pred", # str, expected input type for y_pred in performance metric - "lower_is_better", # bool, whether lower (True) or higher (False) is better - # BaseMetaObject reserved tags - "named_object_parameters", # name of component list attribute for meta-objects - ] + valid_tags = OBJECT_TAG_LIST class TestAllObjects(PackageConfig, _TestAllObjects): From 42495e307c727d555ee0eeab2ddb6102d7273558 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 26 Aug 2023 23:13:29 +0100 Subject: [PATCH 03/10] registry lookup --- skpro/registry/_lookup.py | 72 +++++++++++---------------------------- 1 file changed, 20 insertions(+), 52 deletions(-) diff --git a/skpro/registry/_lookup.py b/skpro/registry/_lookup.py index 3276ba80b..8704cf98c 100644 --- a/skpro/registry/_lookup.py +++ b/skpro/registry/_lookup.py @@ -18,22 +18,8 @@ import pandas as pd from skbase.lookup import all_objects as _all_objects -from sktime.base import BaseEstimator -from sktime.registry._base_classes import ( - BASE_CLASS_LIST, - BASE_CLASS_LOOKUP, - TRANSFORMER_MIXIN_LIST, -) -from sktime.registry._tags import OBJECT_TAG_REGISTER - -VALID_TRANSFORMER_TYPES = tuple(TRANSFORMER_MIXIN_LIST) -VALID_ESTIMATOR_BASE_TYPES = tuple(BASE_CLASS_LIST) - -VALID_ESTIMATOR_TYPES = ( - BaseEstimator, - *VALID_ESTIMATOR_BASE_TYPES, - *VALID_TRANSFORMER_TYPES, -) +from skpro.base import BaseObject +from skpro.registry._tags import OBJECT_TAG_REGISTER def all_objects( @@ -136,35 +122,38 @@ def all_objects( "tests", "setup", "contrib", - "benchmarking", "utils", "all", - "plotting", ) result = [] - ROOT = str(Path(__file__).parent.parent) # sktime package root directory + ROOT = str(Path(__file__).parent.parent) # skpro package root directory + + if isinstance(filter_tags, str): + filter_tags = {filter_tags: True} + filter_tags = filter_tags.copy() if filter_tags else None if object_types: - clsses = _check_object_types(object_types) - if not isinstance(object_types, list): - object_types = [object_types] - CLASS_LOOKUP = {x: y for x, y in zip(object_types, clsses)} - else: - CLASS_LOOKUP = None + if filter_tags and "object_type" not in filter_tags.keys(): + object_tag_filter = {} + else: + object_tag_filter = {"object_type": object_types} + if filter_tags: + filter_tags.update(object_tag_filter) + else: + filter_tags = object_tag_filter result = _all_objects( - object_types=object_types, + object_types=BaseObject, filter_tags=filter_tags, exclude_objects=exclude_objects, return_names=return_names, as_dataframe=as_dataframe, return_tags=return_tags, suppress_import_stdout=suppress_import_stdout, - package_name="sktime", + package_name="skpro", path=ROOT, modules_to_ignore=MODULES_TO_IGNORE, - class_lookup=CLASS_LOOKUP, ) return result @@ -267,14 +256,10 @@ def all_tags( ---------- object_types: string, list of string, optional (default=None) Which kind of objects should be returned. - - If None, no filter is applied and all objects are returned. - - Possible values are 'classifier', 'regressor', 'transformer' and - 'forecaster' to get objects only of these specific types, or a list of - these to get the objects that fit at least one of the types. + If None, no filter is applied and all objects are returned. as_dataframe: bool, optional (default=False) - if False, return is as described below; - if True, return is converted into a pandas.DataFrame for pretty - display + if False, return is as described below; + if True, return is converted into a pandas.DataFrame for pretty display Returns ------- @@ -330,26 +315,9 @@ def _check_object_types(object_types): if not isinstance(object_types, list): object_types = [object_types] # make iterable - def _get_err_msg(object_type): - return ( - f"Parameter `object_type` must be None, a string or a list of " - f"strings. Valid string values are: " - f"{tuple(BASE_CLASS_LOOKUP.keys())}, but found: " - f"{repr(object_type)}" - ) - for i, object_type in enumerate(object_types): if not isinstance(object_type, (type, str)): raise ValueError( "Please specify `object_types` as a list of str or " "types." ) - if isinstance(object_type, str): - if object_type not in BASE_CLASS_LOOKUP.keys(): - raise ValueError(_get_err_msg(object_type)) - object_type = BASE_CLASS_LOOKUP[object_type] - object_types[i] = object_type - elif isinstance(object_type, type): - pass - else: - raise ValueError(_get_err_msg(object_type)) return object_types From f6d011d3d2e150e134fef9b3d00d275c2ac74fae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 26 Aug 2023 23:17:00 +0100 Subject: [PATCH 04/10] Update _lookup.py --- skpro/registry/_lookup.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/skpro/registry/_lookup.py b/skpro/registry/_lookup.py index 8704cf98c..4126881d8 100644 --- a/skpro/registry/_lookup.py +++ b/skpro/registry/_lookup.py @@ -21,6 +21,8 @@ from skpro.base import BaseObject from skpro.registry._tags import OBJECT_TAG_REGISTER +VALID_OBJECT_TYPE_STRINGS = set([x[1] for x in OBJECT_TAG_REGISTER]) + def all_objects( object_types=None, @@ -315,9 +317,20 @@ def _check_object_types(object_types): if not isinstance(object_types, list): object_types = [object_types] # make iterable + def _get_err_msg(object_type): + return ( + f"Parameter `object_type` must be None, a string or a list of " + f"strings. Valid string values are: " + f"{tuple(VALID_OBJECT_TYPE_STRINGS)}, but found: " + f"{repr(object_type)}" + ) + for i, object_type in enumerate(object_types): if not isinstance(object_type, (type, str)): raise ValueError( "Please specify `object_types` as a list of str or " "types." ) + if isinstance(object_type, str): + if object_type not in VALID_OBJECT_TYPE_STRINGS: + raise ValueError(_get_err_msg(object_type)) return object_types From e4c422fafe5fa7cd4e3fa895a29b70cce031d635 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 26 Aug 2023 23:20:55 +0100 Subject: [PATCH 05/10] linting --- skpro/registry/_lookup.py | 4 +++- skpro/registry/_tags.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/skpro/registry/_lookup.py b/skpro/registry/_lookup.py index 4126881d8..b0c93e121 100644 --- a/skpro/registry/_lookup.py +++ b/skpro/registry/_lookup.py @@ -1,4 +1,4 @@ -# copyright: sktime developers, BSD-3-Clause License (see LICENSE file) +# -*- coding: utf-8 -*- """Registry lookup methods. This module exports the following methods for registry lookup: @@ -6,6 +6,8 @@ all_objects(object_types, filter_tags) lookup and filtering of objects """ +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) +# based on the sktime module of same name __author__ = ["fkiraly"] # all_objects is based on the sklearn utility all_estimators diff --git a/skpro/registry/_tags.py b/skpro/registry/_tags.py index 95bd99d88..b2416cb5b 100644 --- a/skpro/registry/_tags.py +++ b/skpro/registry/_tags.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- """Register of estimator and object tags. Note for extenders: new tags should be entered in OBJECT_TAG_REGISTER. From 4560da3b6242d4589b2fd3f7679f7fdbc1ee06c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 26 Aug 2023 23:26:22 +0100 Subject: [PATCH 06/10] Update _lookup.py --- skpro/registry/_lookup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/registry/_lookup.py b/skpro/registry/_lookup.py index b0c93e121..d6f97336c 100644 --- a/skpro/registry/_lookup.py +++ b/skpro/registry/_lookup.py @@ -327,7 +327,7 @@ def _get_err_msg(object_type): f"{repr(object_type)}" ) - for i, object_type in enumerate(object_types): + for object_type in object_types: if not isinstance(object_type, (type, str)): raise ValueError( "Please specify `object_types` as a list of str or " "types." From eebbcf87065fbb089e9f12df784e814f135245e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 26 Aug 2023 23:32:25 +0100 Subject: [PATCH 07/10] docstring --- skpro/registry/_lookup.py | 13 +++++++------ skpro/regression/base/_base.py | 4 ++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/skpro/registry/_lookup.py b/skpro/registry/_lookup.py index d6f97336c..366b364d0 100644 --- a/skpro/registry/_lookup.py +++ b/skpro/registry/_lookup.py @@ -110,17 +110,18 @@ def all_objects( Examples -------- - >>> from sktime.registry import all_objects + >>> from skpro.registry import all_objects >>> # return a complete list of objects as pd.Dataframe >>> all_objects(as_dataframe=True) - >>> # return all forecasters by filtering for object type - >>> all_objects("forecaster") - >>> # return all forecasters which handle missing data in the input by tag filtering - >>> all_objects("forecaster", filter_tags={"handles-missing-data": True}) + >>> # return all probabilistic regressors by filtering for object type + >>> all_objects("regressor_proba") + >>> # return all regressors which handle missing data in the input by tag filtering + >>> all_objects("regressor_proba", filter_tags={""capability:missing"": True}) References ---------- - Modified version from scikit-learn's `all_objects()`. + Adapted version of sktime's ``all_estimators``, + which is an evolution of scikit-learn's ``all_estimators`` """ MODULES_TO_IGNORE = ( "tests", diff --git a/skpro/regression/base/_base.py b/skpro/regression/base/_base.py index c5bce1280..584646e7a 100644 --- a/skpro/regression/base/_base.py +++ b/skpro/regression/base/_base.py @@ -13,8 +13,8 @@ class BaseProbaRegressor(BaseEstimator): """Base class for probabilistic supervised regressors.""" _tags = { - "object_type": "regressor_proba", # type of object, e.g., 'distribution' - "estimator_type": "regressor_proba", + "object_type": "regressor_proba", # type of object, e.g., "distribution" + "estimator_type": "regressor_proba", # type of estimator, e.g., "regressor" "capability:multioutput": False, "capability:missing": True, } From 425be13dd77785c4d97f6b1f10939295f414a243 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 26 Aug 2023 23:35:29 +0100 Subject: [PATCH 08/10] Update _lookup.py --- skpro/registry/_lookup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/registry/_lookup.py b/skpro/registry/_lookup.py index 366b364d0..7a1b5868e 100644 --- a/skpro/registry/_lookup.py +++ b/skpro/registry/_lookup.py @@ -116,7 +116,7 @@ def all_objects( >>> # return all probabilistic regressors by filtering for object type >>> all_objects("regressor_proba") >>> # return all regressors which handle missing data in the input by tag filtering - >>> all_objects("regressor_proba", filter_tags={""capability:missing"": True}) + >>> all_objects("regressor_proba", filter_tags={"capability:missing": True}) References ---------- From ffff5e0777c4f9b1f3e42716635e3617c0d4109a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 26 Aug 2023 23:40:20 +0100 Subject: [PATCH 09/10] Update _lookup.py --- skpro/registry/_lookup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skpro/registry/_lookup.py b/skpro/registry/_lookup.py index 7a1b5868e..1bba0d565 100644 --- a/skpro/registry/_lookup.py +++ b/skpro/registry/_lookup.py @@ -20,7 +20,7 @@ import pandas as pd from skbase.lookup import all_objects as _all_objects -from skpro.base import BaseObject +from skpro.base import BaseObject, BaseEstimator from skpro.registry._tags import OBJECT_TAG_REGISTER VALID_OBJECT_TYPE_STRINGS = set([x[1] for x in OBJECT_TAG_REGISTER]) @@ -149,7 +149,7 @@ def all_objects( filter_tags = object_tag_filter result = _all_objects( - object_types=BaseObject, + object_types=[BaseObject, BaseEstimator], filter_tags=filter_tags, exclude_objects=exclude_objects, return_names=return_names, From 6a4dfbabae5a3a32752e7e1471bef3365114a68c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 26 Aug 2023 23:42:12 +0100 Subject: [PATCH 10/10] Update _lookup.py --- skpro/registry/_lookup.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/skpro/registry/_lookup.py b/skpro/registry/_lookup.py index 1bba0d565..af70977b1 100644 --- a/skpro/registry/_lookup.py +++ b/skpro/registry/_lookup.py @@ -20,7 +20,7 @@ import pandas as pd from skbase.lookup import all_objects as _all_objects -from skpro.base import BaseObject, BaseEstimator +from skpro.base import BaseEstimator, BaseObject from skpro.registry._tags import OBJECT_TAG_REGISTER VALID_OBJECT_TYPE_STRINGS = set([x[1] for x in OBJECT_TAG_REGISTER]) @@ -114,9 +114,13 @@ def all_objects( >>> # return a complete list of objects as pd.Dataframe >>> all_objects(as_dataframe=True) >>> # return all probabilistic regressors by filtering for object type - >>> all_objects("regressor_proba") + >>> all_objects("regressor_proba", as_dataframe=True) >>> # return all regressors which handle missing data in the input by tag filtering - >>> all_objects("regressor_proba", filter_tags={"capability:missing": True}) + >>> all_objects( + ... "regressor_proba", + ... filter_tags={"capability:missing": True}, + ... as_dataframe=True + ... ) References ----------