diff --git a/.github/workflows/check-changelog.yml b/.github/workflows/check-changelog.yml
index 53eb98ccd9b01..53f64ba5c886b 100644
--- a/.github/workflows/check-changelog.yml
+++ b/.github/workflows/check-changelog.yml
@@ -10,12 +10,13 @@ jobs:
   check:
     name: A reviewer will let you know if it is required or can be bypassed
     runs-on: ubuntu-latest
-    if: ${{ contains(github.event.pull_request.labels.*.name, 'No Changelog Needed') == 0 }} && github.repository == 'scikit-learn/scikit-learn'
+    if: ${{ contains(github.event.pull_request.labels.*.name, 'No Changelog Needed') == 0 && github.repository == 'scikit-learn/scikit-learn' }}
     steps:
       - name: Get PR number and milestone
         run: |
           echo "PR_NUMBER=${{ github.event.pull_request.number }}" >> $GITHUB_ENV
           echo "TAGGED_MILESTONE=${{ github.event.pull_request.milestone.title }}" >> $GITHUB_ENV
+          echo "${{ github.repository }}"
       - uses: actions/checkout@v3
         with:
           fetch-depth: '0'
diff --git a/README.rst b/README.rst
index 2501ee87c61e9..fbdfdaa95ef4c 100644
--- a/README.rst
+++ b/README.rst
@@ -232,12 +232,42 @@ Python API:
   users from generalizing the ``Criterion`` and ``Splitter`` and creating a neat Python API
   wrapper. Moreover, the ``Tree`` class is not customizable.
   - Our fix: We internally implement a private function to actually build the entire tree,
     ``BaseDecisionTree._build_tree``, which can be overridden in subclasses that customize
     the criterion, splitter, or tree, or any combination of them.
+- ``sklearn.ensemble.BaseForest`` and its subclass algorithms are slow when ``n_samples`` is very high. Binning
+  features into a histogram, which is the basis of "LightGBM" and "HistGradientBoostingClassifier", is a computational
+  trick that can both significantly improve runtime efficiency and help prevent overfitting in trees, since
+  the sorting in "BestSplitter" is done on bins rather than the continuous feature values. This would enable
+  random forests and their variants to scale to millions of samples.
+  - Our fix: We added a ``max_bins=None`` keyword argument to the ``BaseForest`` class, and all its subclasses. The default behavior is no binning. The current implementation is not necessarily efficient; there are several improvements to be made. See below.

 Overall, the existing tree models, such as :class:`~sklearn.tree.DecisionTreeClassifier`
 and :class:`~sklearn.ensemble.RandomForestClassifier` all work exactly the same as they
 would in ``scikit-learn`` main, but these extensions enable 3rd-party packages to extend
 the Cython/Python API easily.

+Roadmap
+-------
+There are several improvements that can be made in this fork. Primarily, the binning feature
+promises to make Random Forests and their variants ultra-fast. However, the binning needs
+to be implemented in a similar fashion to ``HistGradientBoostingClassifier``, which passes
+the binning thresholds through the tree construction step, so that the split nodes
+store the actual numerical value of the bin threshold rather than the "bin index". This requires
+modifying the tree Cython code to take in a ``binning_thresholds`` parameter that is part
+of the fitted ``_BinMapper`` class. This also means no binning is needed at prediction/apply
+time, because the tree already stores the "numerical" threshold value to apply
+to any incoming ``X`` that is not binned.
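+
+As a concrete sketch of that idea (illustrative only, not part of the code changes in this
+patch): the numerical value behind a "bin index" split can be read off the fitted
+``_BinMapper`` (the same private class this patch imports in ``sklearn/ensemble/_forest.py``)
+via its ``bin_thresholds_`` attribute. The equivalence below assumes dense data with no
+missing values, and the ``binning_thresholds`` tree parameter mentioned above remains a
+proposal, not an existing argument::
+
+    import numpy as np
+    from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper
+
+    rng = np.random.default_rng(0)
+    X = rng.standard_normal((1000, 3))
+
+    # n_bins counts the extra bin reserved for missing values, hence 255 + 1.
+    mapper = _BinMapper(n_bins=256)
+    X_binned = mapper.fit_transform(X)  # uint8 array of bin indices
+
+    # A split learned on binned data, e.g. "feature 1 at bin index <= 42",
+    # corresponds to a numerical threshold stored by the bin mapper.
+    feature, bin_index = 1, 42
+    numeric_threshold = mapper.bin_thresholds_[feature][bin_index]
+
+    # Routing raw samples with the numerical threshold yields the same
+    # partition as routing binned samples with the bin index, which is why
+    # a tree that stores the numerical value needs no re-binning of X at
+    # prediction/apply time.
+    assert np.array_equal(
+        X_binned[:, feature] <= bin_index,
+        X[:, feature] <= numeric_threshold,
+    )
+
+In this fork the binning itself is enabled by passing, e.g., ``max_bins=255`` to any forest
+estimator; the sketch above only shows why carrying the thresholds into the tree would make
+the extra ``_bin_data`` re-binning at prediction/apply time unnecessary.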
+
+Besides that modification, the tree and splitter need to be able to handle not just ``np.float32``
+data (the type for X normally in Random Forests), but also ``uint8`` data (the type for X when it
+is binned into e.g. 255 bins). This would not only save RAM, since ``uint8`` storage of millions
+of samples saves many GB of memory, but also improve runtime.
+
+So in summary, the Cython code of the tree submodule needs to take in an extra parameter for
+the binning thresholds if binning occurs and also be able to handle ``X`` being of dtype ``uint8``.
+Afterwards, Random Forests will have fully leveraged the binning feature.
+
+Something to keep in mind is that upstream scikit-learn is actively working on incorporating
+missing-value handling and categorical handling into Random Forests.
+
 Next steps
 ----------
diff --git a/asv_benchmarks/benchmarks/ensemble.py b/asv_benchmarks/benchmarks/ensemble.py
index d592ca0fd2697..8c5a28e3da90f 100644
--- a/asv_benchmarks/benchmarks/ensemble.py
+++ b/asv_benchmarks/benchmarks/ensemble.py
@@ -2,7 +2,6 @@
     RandomForestClassifier,
     GradientBoostingClassifier,
     HistGradientBoostingClassifier,
-    ObliqueRandomForestClassifier,
 )

 from .common import Benchmark, Estimator, Predictor
@@ -14,43 +13,6 @@
 from .utils import make_gen_classif_scorers


-class ObliqueRandomForestClassifierBenchmark(Predictor, Estimator, Benchmark):
-    """
-    Benchmarks for RandomForestClassifier.
-    """
-
-    param_names = ["representation", "n_jobs"]
-    params = (["dense"], Benchmark.n_jobs_vals)
-
-    def setup_cache(self):
-        super().setup_cache()
-
-    def make_data(self, params):
-        representation, n_jobs = params
-
-        data = _20newsgroups_lowdim_dataset()
-
-        return data
-
-    def make_estimator(self, params):
-        representation, n_jobs = params
-
-        n_estimators = 500 if Benchmark.data_size == "large" else 100
-
-        estimator = ObliqueRandomForestClassifier(
-            n_estimators=n_estimators,
-            min_samples_split=10,
-            max_features="log2",
-            n_jobs=n_jobs,
-            random_state=0,
-        )
-
-        return estimator
-
-    def make_scorers(self):
-        make_gen_classif_scorers(self)
-
-
 class RandomForestClassifierBenchmark(Predictor, Estimator, Benchmark):
     """
     Benchmarks for RandomForestClassifier.
diff --git a/doc/modules/ensemble.rst b/doc/modules/ensemble.rst
index 817063cb790aa..c8e4a87ff98a3 100644
--- a/doc/modules/ensemble.rst
+++ b/doc/modules/ensemble.rst
@@ -195,27 +195,6 @@ in bias::
    :align: center
    :scale: 75%

-Oblique Random Forests
-----------------------
-
-In oblique random forests (see :class:`ObliqueRandomForestClassifier` and
-:class:`ObliqueRandomForestRegressor` classes), each tree in the ensemble is built
-from a sample drawn with replacement (i.e., a bootstrap sample) from the
-training set. The oblique random forest is the same as that of a random forest,
-except in how the splits are computed in each tree.
-
-Similar to how random forests achieve a reduced variance by combining diverse trees,
-sometimes at the cost of a slight increase in bias, oblique random forests aim to do the same.
-They are motivated to construct even more diverse trees, thereby improving model generalization.
-In practice the variance reduction is often significant hence yielding an overall better model.
-
-In contrast to the original publication [B2001]_, the scikit-learn
-implementation allows the user to control the number of features to combine in computing
-candidate splits. This is done via the ``feature_combinations`` parameter. For
-more information and intuition, see
-:ref:`documentation on oblique decision trees `.
- - .. _random_forest_parameters: Parameters diff --git a/doc/modules/tree.rst b/doc/modules/tree.rst index 5eb7b331408cd..7fa12fd16d487 100644 --- a/doc/modules/tree.rst +++ b/doc/modules/tree.rst @@ -614,49 +614,6 @@ be pruned. This process stops when the pruned tree's minimal * :ref:`sphx_glr_auto_examples_tree_plot_cost_complexity_pruning.py` -.. _oblique_trees: - -Oblique Trees -============= - -Similar to DTs, **Oblique Trees (OTs)** are a non-parametric supervised learning -method used for :ref:`classification ` and :ref:`regression -`. It was originally described as ``Forest-RC`` in Breiman's -landmark paper on Random Forests [RF]_. Breiman found that combining data features -empirically outperforms DTs on a variety of data sets. - -The algorithm implemented in scikit-learn differs from ``Forest-RC`` in that -it allows the user to specify the number of variables to combine to consider -as a split, :math:`\lambda`. If :math:`\lambda` is set to ``n_features``, then -it is equivalent to ``Forest-RC``. :math:`\lambda` presents a tradeoff between -considering dense combinations of features vs sparse combinations of features. - -Differences compared to decision trees --------------------------------------- - -Compared to DTs, OTs differ in how they compute a candidate split. DTs split -along the passed in data columns in an axis-aligned fashion, whereas OTs split -along oblique curves. Using the Iris dataset, we can similarly construct an OT -as follows: - - >>> from sklearn.datasets import load_iris - >>> from sklearn import tree - >>> iris = load_iris() - >>> X, y = iris.data, iris.target - >>> clf = tree.ObliqueDecisionTreeClassifier() - >>> clf = clf.fit(X, y) - -.. figure:: ../auto_examples/tree/images/sphx_glr_plot_iris_dtc_002.png - :target: ../auto_examples/tree/plot_iris_dtc.html - :scale: 75 - :align: center - -Another major difference to DTs is that OTs can by definition sample more candidate -splits. The parameter ``max_features`` controls how many splits to sample at each -node. For DTs "max_features" is constrained to be at most "n_features" by default, -whereas OTs can sample possibly up to :math:`2^{n_{features}}` candidate splits -because they are combining features. - Classification, regression and multi-output problems ---------------------------------------------------- @@ -709,50 +666,6 @@ optimization (e.g. `GridSearchCV`). If one has prior knowledge about how the dat distributed along its features, such as data being axis-aligned, then one might use a DT. Other considerations are runtime and space complexity. -Mathematical formulation ------------------------- - -Given training vectors :math:`x_i \in R^n`, i=1,..., l and a label vector -:math:`y \in R^l`, an oblique decision tree recursively partitions the -feature space such that the samples with the same labels or similar target -values are grouped together. Normal decision trees partition the feature space -in an axis-aligned manner splitting along orthogonal axes based on the dimensions -(columns) of :math:`x_i`. In oblique trees, nodes sample a random projection vector, -:math:`a_i \in R^n`, where the inner-product of :math:`\langle a_i, x_i \rangle` -is a candidate split value. The entries of :math:`a_i` have values -+/- 1 with probability :math:`\lambda / n` with the rest being 0s. - -Let the data at node :math:`m` be represented by :math:`Q_m` with :math:`n_m` -samples. 
For each candidate split :math:`\theta = (a_i, t_m)` consisting of a -(possibly sparse) vector :math:`a_i` and threshold :math:`t_m`, partition the -data into :math:`Q_m^{left}(\theta)` and :math:`Q_m^{right}(\theta)` subsets - -.. math:: - - Q_m^{left}(\theta) = \{(x, y) | a_i^T x_j \leq t_m\} - - Q_m^{right}(\theta) = Q_m \setminus Q_m^{left}(\theta) - -Note that this formulation is a generalization of decision trees, where -:math:`a_i = e_i`, a standard basis vector with a "1" at index "i" and "0" -elsewhere. - -The quality of a candidate split of node :math:`m` is then computed using an -impurity function or loss function :math:`H()`, in the same exact manner as -decision trees. - -Limitations compared to decision trees --------------------------------------- - - * There currently does not exist support for pruning OTs, such as with the minimal - cost-complexity pruning algorithm. - - * Moreover, OTs do not have built-in support for missing data, so the recommendation - by scikit-learn is for users to first impute, or drop their missing data if they - would like to use OTs. - - * Currently, OTs also does not support sparse inputs for data matrices and labels. - .. topic:: References: .. [BRE] L. Breiman, J. Friedman, R. Olshen, and C. Stone. Classification diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 87b7bd2b17091..d57f3f5717e5f 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -950,11 +950,6 @@ Changelog :mod:`sklearn.tree` ................... -- |MajorFeature| Add oblique decision trees and forests for classification - with :class:`tree.ObliqueDecisionTreeClassifier` and - :class:`ensemble.ObliqueRandomForestClassifier`. :pr:`22754` by - `Adam Li `. - - |Enhancement| :func:`tree.plot_tree`, :func:`tree.export_graphviz` now uses a lower case `x[i]` to represent feature `i`. :pr:`23480` by `Thomas Fan`_. diff --git a/examples/ensemble/plot_oblique_axis_aligned_forests_cc18.py b/examples/ensemble/plot_oblique_axis_aligned_forests_cc18.py deleted file mode 100644 index 54d78674d5068..0000000000000 --- a/examples/ensemble/plot_oblique_axis_aligned_forests_cc18.py +++ /dev/null @@ -1,162 +0,0 @@ -""" -=============================================================================== -Plot oblique forest and axis-aligned random forest predictions on cc18 datasets -=============================================================================== -A performance comparison between oblique forest and standard axis- -aligned random forest using three datasets from OpenML benchmarking suites. -Two of these datasets, namely -[WDBC](https://www.openml.org/search?type=data&sort=runs&id=1510) -and [Phishing Website](https://www.openml.org/search?type=data&sort=runs&id=4534) -datasets consist of 31 features where the former dataset is entirely numeric -and the latter dataset is entirely norminal. The third dataset, dubbed -[cnae-9](https://www.openml.org/search?type=data&status=active&id=1468), is a -numeric dataset that has notably large feature space of 857 features. As you -will notice, of these three datasets, the oblique forest outperforms axis-aligned -random forest on cnae-9 utilizing sparse random projection machanism. All datasets -are subsampled due to computational constraints. 
-""" - -import numpy as np -import pandas as pd -from datetime import datetime - -import seaborn as sns -import matplotlib.pyplot as plt -from sklearn.ensemble import RandomForestClassifier, ObliqueRandomForestClassifier -from sklearn.model_selection import RepeatedKFold, cross_validate -from sklearn.datasets import fetch_openml - -random_state = 123456 -rng = np.random.default_rng(random_state) -t0 = datetime.now() -data_ids = [11, 40499] # openml dataset id -df = pd.DataFrame() - - -def load_cc18(data_id): - dat = fetch_openml(data_id=data_id, as_frame=False) - - d_name = dat["details"]["name"] - d = dat["data"] - y = dat["target"] - - # Subsampling large datasets - n = int(d.shape[0] * 0.1) - subsample_idx = rng.choice(np.arange(d.shape[0]), n) - X = d[subsample_idx, :] - y = y[subsample_idx, ...] - return X, y, d_name - - -def get_scores(X, y, d_name="UNK", n_cv=5, n_repeats=2, random_state=1, kwargs=None): - clfs = [ - RandomForestClassifier(**kwargs[0], random_state=random_state), - ObliqueRandomForestClassifier(**kwargs[1], random_state=random_state), - ] - - tmp = [] - - for i, clf in enumerate(clfs): - cv = RepeatedKFold( - n_splits=n_cv, n_repeats=n_repeats, random_state=random_state - ) - test_score = cross_validate(estimator=clf, X=X, y=y, cv=cv, scoring="accuracy") - - tmp.append( - [ - d_name, - ["RF", "OF"][i], - test_score["test_score"], - test_score["test_score"].mean(), - ] - ) - print( - f'{d_name} mean test score for {["RF", "OF"][i]}:' - f' {test_score["test_score"].mean()}' - ) - - df = pd.DataFrame(tmp, columns=["dataset", "model", "score", "mean"]) - df = df.explode("score") - df["score"] = df["score"].astype(float) - df.reset_index(inplace=True, drop=True) - - return df - - -def load_best_params(data_ids): - folder_path = None - params = [] - - if not folder_path: - # pre-tuned hyper-parameters - params += [ - [ - {"max_depth": 5, "max_features": "sqrt", "n_estimators": 100}, - {"max_depth": 5, "max_features": None, "n_estimators": 100}, - ], - [ - {"max_depth": 10, "max_features": "log2", "n_estimators": 200}, - {"max_depth": 10, "max_features": 80, "n_estimators": 200}, - ], - ] - else: - for data_id in data_ids: - file_path = f"OFvsRF_grid_search_cv_results_openml_{data_id}.csv" - df = pd.read_csv(folder_path + file_path).sort_values( - "mean_test_score", ascending=False - ) - tmp = [] - for clf in ["RF", "OF"]: - tmp.append(eval(df.query(f'clf=="{clf}"')["params"].iloc[0])) - params.append(tmp) - - return params - - -params = load_best_params(data_ids=data_ids) - -for i, data_id in enumerate(data_ids): - X, y, d_name = load_cc18(data_id=data_id) - print(f"Loading [{d_name}] dataset..") - tmp = get_scores( - X=X, y=y, d_name=d_name, random_state=random_state, kwargs=params[i] - ) - df = pd.concat([df, tmp]) - -t_d = (datetime.now() - t0).seconds -print(f"It took {t_d} seconds to run the script") - -# Draw a comparison plot -d_names = df.dataset.unique() -N = d_names.shape[0] - -fig, ax = plt.subplots(1, N, figsize=(6 * N, 6)) - -for i, name in enumerate(d_names): - if N == 1: - axs = ax - else: - axs = ax[i] - dff = df.query(f'dataset == "{name}"') - - sns.stripplot(data=dff, x="model", y="score", ax=axs, dodge=True) - sns.boxplot(data=dff, x="model", y="score", ax=axs, color="white") - axs.set_title(f"{name} (#{data_ids[i]})") - - rf = dff.query('model=="RF"')["mean"].iloc[0] - rff = f"RF (Mean Test Score: {round(rf,3)})" - - of = dff.query('model=="OF"')["mean"].iloc[0] - off = f"OF (Mean Test Score: {round(of,3)})" - - axs.legend([rff, off], loc=4) - - if i 
!= 0: - axs.set_ylabel("") - else: - axs.set_ylabel("Accuracy") - - axs.set_xlabel("") - -plt.savefig(f"plot_cc18_{t_d}s.jpg") -plt.show() diff --git a/examples/ensemble/plot_oblique_axis_aligned_forests_sparse_parity.py b/examples/ensemble/plot_oblique_axis_aligned_forests_sparse_parity.py deleted file mode 100644 index 5dbb9c2ae6e6e..0000000000000 --- a/examples/ensemble/plot_oblique_axis_aligned_forests_sparse_parity.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -============================================================================ -Plot oblique forest and axis-aligned random forest predictions on simulation -============================================================================ -A performance comparison between oblique forest and standard axis- -aligned random forest using sparse parity simulation dataset. -Sparse parity is a variation of the noisy parity problem, -which itself is a multivariate generalization of the noisy XOR problem. -This is a binary classification task in high dimensions. The simulation -will generate uniformly distributed `n_samples` number of sample points -in the range of -1 and +1 with `p` number of features. `p*` is a -parameter used to limit features that carry information about the class. -The informative binary label is then defined as 1 if there are odd number -of the sum of data `X` across first `p*` features that are greater than 0, -otherwise the label is defined as 0. The simulation is further detailed -in this [publication](https://epubs.siam.org/doi/epdf/10.1137/1.9781611974973.56). -""" - -import numpy as np -import pandas as pd -from datetime import datetime -import seaborn as sns -import matplotlib.pyplot as plt -from sklearn.ensemble import RandomForestClassifier, ObliqueRandomForestClassifier -from sklearn.model_selection import RepeatedKFold, cross_validate - -random_state = 123456 -t0 = datetime.now() - - -def sparse_parity(n_samples, p=20, p_star=3, random_seed=None): - if random_seed: - np.random.seed(random_seed) - - X = np.random.uniform(-1, 1, (n_samples, p)) - y = np.zeros(n_samples) - - for i in range(0, n_samples): - y[i] = sum(X[i, :p_star] > 0) % 2 - - return X, y - - -def get_scores(X, y, n_cv=5, n_repeats=1, random_state=1, kwargs=None): - clfs = [ - RandomForestClassifier(**kwargs[0], random_state=random_state), - ObliqueRandomForestClassifier(**kwargs[1], random_state=random_state), - ] - - tmp = [] - - for i, clf in enumerate(clfs): - cv = RepeatedKFold( - n_splits=n_cv, n_repeats=n_repeats, random_state=random_state - ) - test_score = cross_validate(estimator=clf, X=X, y=y, cv=cv, scoring="accuracy") - - tmp.append( - [["RF", "OF"][i], test_score["test_score"], test_score["test_score"].mean()] - ) - - df = pd.DataFrame(tmp, columns=["model", "score", "mean"]) - df = df.explode("score") - df["score"] = df["score"].astype(float) - df.reset_index(inplace=True, drop=True) - - return df - - -# Grid searched hyper-parameters -params = [ - {"max_features": None, "n_estimators": 100, "max_depth": None}, - {"max_features": 40, "n_estimators": 100, "max_depth": 20}, -] - -X, y = sparse_parity(n_samples=10000, random_seed=random_state) - -df = get_scores(X=X, y=y, n_cv=3, n_repeats=1, random_state=random_state, kwargs=params) -t_d = (datetime.now() - t0).seconds -print(f"It took {t_d} seconds to run the script") - -# Draw a comparison plot -fig, ax = plt.subplots(1, 1, figsize=(6, 6)) - -sns.stripplot(data=df, x="model", y="score", ax=ax, dodge=True) -sns.boxplot(data=df, x="model", y="score", ax=ax, color="white") 
-ax.set_title("Sparse Parity") - -rf = df.query('model=="RF"')["mean"].iloc[0] -rff = f"RF (Mean Test Score: {round(rf,3)})" - -of = df.query('model=="OF"')["mean"].iloc[0] -off = f"OF (Mean Test Score: {round(of,3)})" - -ax.legend([rff, off], loc=4) - -plt.savefig(f"plot_sim_{t_d}s.jpg") -plt.show() diff --git a/examples/ensemble/plot_oblique_random_forest.py b/examples/ensemble/plot_oblique_random_forest.py deleted file mode 100644 index 8634b1368e3e4..0000000000000 --- a/examples/ensemble/plot_oblique_random_forest.py +++ /dev/null @@ -1,127 +0,0 @@ -""" -=============================================================================== -Plot oblique forest and axis-aligned random forest predictions on cc18 datasets -=============================================================================== - -A performance comparison between oblique forest and standard axis- -aligned random forest using three datasets from OpenML benchmarking suites. - -Two of these datasets, namely -[WDBC](https://www.openml.org/search?type=data&sort=runs&id=1510) -and [Phishing Website](https://www.openml.org/search?type=data&sort=runs&id=4534) -datasets consist of 31 features where the former dataset is entirely numeric -and the latter dataset is entirely norminal. The third dataset, dubbed -[cnae-9](https://www.openml.org/search?type=data&status=active&id=1468), is a -numeric dataset that has notably large feature space of 857 features. As you -will notice, of these three datasets, the oblique forest outperforms axis-aligned -random forest on cnae-9 utilizing sparse random projection machanism. All datasets -are subsampled due to computational constraints. -""" - -import pandas as pd -from datetime import datetime - -import seaborn as sns -import matplotlib.pyplot as plt -from sklearn.ensemble import RandomForestClassifier, ObliqueRandomForestClassifier -from sklearn.model_selection import RepeatedKFold, cross_validate -from sklearn.datasets import fetch_openml - - -random_state = 123456 -t0 = datetime.now() -data_ids = [4534, 1510, 1468] # openml dataset id -df = pd.DataFrame() - - -def load_cc18(data_id): - df = fetch_openml(data_id=data_id, as_frame=True, parser="pandas") - - # extract the dataset name - d_name = df.details["name"] - - # Subsampling large datasets - if data_id == 1468: - n = 100 - else: - n = int(df.frame.shape[0] * 0.8) - - df = df.frame.sample(n, random_state=random_state) - X, y = df.iloc[:, :-1], df.iloc[:, -1] - - return X, y, d_name - - -def get_scores(X, y, d_name, n_cv=5, n_repeats=1, **kwargs): - clfs = [RandomForestClassifier(**kwargs), ObliqueRandomForestClassifier(**kwargs)] - - tmp = [] - - for i, clf in enumerate(clfs): - cv = RepeatedKFold( - n_splits=n_cv, n_repeats=n_repeats, random_state=kwargs["random_state"] - ) - test_score = cross_validate(estimator=clf, X=X, y=y, cv=cv, scoring="accuracy") - - tmp.append( - [ - d_name, - ["RF", "OF"][i], - test_score["test_score"], - test_score["test_score"].mean(), - ] - ) - - df = pd.DataFrame( - tmp, columns=["dataset", "model", "score", "mean"] - ) # dtype=[('model',object), ('score',float), ('mean',float)]) - df = df.explode("score") - df["score"] = df["score"].astype(float) - df.reset_index(inplace=True, drop=True) - - return df - - -params = { - "max_features": None, - "n_estimators": 50, - "max_depth": None, - "random_state": random_state, - "n_cv": 3, - "n_repeats": 1, -} - -for data_id in data_ids: - X, y, d_name = load_cc18(data_id=data_id) - print(f"Loading [{d_name}] dataset..") - tmp = get_scores(X=X, y=y, d_name=d_name, 
**params) - df = pd.concat([df, tmp]) - -print(f"It took {(datetime.now()-t0).seconds} seconds to run the script") - -# Draw a comparison plot -d_names = df.dataset.unique() -N = d_names.shape[0] - -fig, ax = plt.subplots(1, N) -fig.set_size_inches(6 * N, 6) - -for i, name in enumerate(d_names): - sns.stripplot( - data=df.query(f'dataset == "{name}"'), - x="model", - y="score", - ax=ax[i], - dodge=True, - ) - sns.boxplot( - data=df.query(f'dataset == "{name}"'), - x="model", - y="score", - ax=ax[i], - color="white", - ) - ax[i].set_title(name) - if i != 0: - ax[i].set_ylabel("") - ax[i].set_xlabel("") diff --git a/examples/tree/plot_iris_dtc.py b/examples/tree/plot_iris_dtc.py index 6a8adf44465fc..0dcca718bc6f0 100644 --- a/examples/tree/plot_iris_dtc.py +++ b/examples/tree/plot_iris_dtc.py @@ -2,18 +2,12 @@ ======================================================================= Plot the decision surface of decision trees trained on the iris dataset ======================================================================= - -Plot the decision surface of a decision tree and oblique decision tree -trained on pairs of features of the iris dataset. - -See :ref:`decision tree ` for more information on the estimators. - -For each pair of iris features, the decision tree learns axis-aligned decision +Plot the decision surface of a decision tree trained on pairs +of features of the iris dataset. +See :ref:`decision tree ` for more information on the estimator. +For each pair of iris features, the decision tree learns decision boundaries made of combinations of simple thresholding rules inferred from -the training samples. The oblique decision tree learns oblique decision boundaries -made from linear combinations of the features in the training samples and then -the same thresholding rule as regular decision trees. - +the training samples. We also show the tree structure of a model built on all of the features. 
""" # %% @@ -29,7 +23,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import load_iris -from sklearn.tree import DecisionTreeClassifier, ObliqueDecisionTreeClassifier +from sklearn.tree import DecisionTreeClassifier from sklearn.inspection import DecisionBoundaryDisplay @@ -38,67 +32,52 @@ plot_colors = "ryb" plot_step = 0.02 -clf_labels = ["Axis-aligned", "Oblique"] -random_state = 123456 - -clfs = [ - DecisionTreeClassifier(random_state=random_state), - ObliqueDecisionTreeClassifier(random_state=random_state), -] -for clf, clf_label in zip(clfs, clf_labels): - fig, axes = plt.subplots(2, 3) - axes = axes.flatten() - - for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]): - # We only take the two corresponding features - X = iris.data[:, pair] - y = iris.target - - # Train - clf.fit(X, y) +for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]): + # We only take the two corresponding features + X = iris.data[:, pair] + y = iris.target + + # Train + clf = DecisionTreeClassifier().fit(X, y) + + # Plot the decision boundary + ax = plt.subplot(2, 3, pairidx + 1) + plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5) + DecisionBoundaryDisplay.from_estimator( + clf, + X, + cmap=plt.cm.RdYlBu, + response_method="predict", + ax=ax, + xlabel=iris.feature_names[pair[0]], + ylabel=iris.feature_names[pair[1]], + ) - # Plot the decision boundary - ax = axes[pairidx] - plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5) - DecisionBoundaryDisplay.from_estimator( - clf, - X, + # Plot the training points + for i, color in zip(range(n_classes), plot_colors): + idx = np.where(y == i) + plt.scatter( + X[idx, 0], + X[idx, 1], + c=color, + label=iris.target_names[i], cmap=plt.cm.RdYlBu, - response_method="predict", - ax=ax, - xlabel=iris.feature_names[pair[0]], - ylabel=iris.feature_names[pair[1]], + edgecolor="black", + s=15, ) - # Plot the training points - for i, color in zip(range(n_classes), plot_colors): - idx = np.where(y == i) - ax.scatter( - X[idx, 0], - X[idx, 1], - c=color, - label=iris.target_names[i], - cmap=plt.cm.RdYlBu, - edgecolor="black", - s=15, - ) - - fig.suptitle( - f"Decision surface of {clf_label} decision trees trained on pairs of features" - ) - plt.legend(loc="lower right", borderpad=0, handletextpad=0) - _ = plt.axis("tight") - plt.show() +plt.suptitle("Decision surface of decision trees trained on pairs of features") +plt.legend(loc="lower right", borderpad=0, handletextpad=0) +_ = plt.axis("tight") # %% # Display the structure of a single decision tree trained on all the features # together. 
from sklearn.tree import plot_tree -for clf, clf_label in zip(clfs, clf_labels): - plt.figure() - clf.fit(iris.data, iris.target) - plot_tree(clf, filled=True) - plt.title(f"{clf_label} decision tree trained on all the iris features") - plt.show() +plt.figure() +clf = DecisionTreeClassifier().fit(iris.data, iris.target) +plot_tree(clf, filled=True) +plt.title("Decision tree trained on all the iris features") +plt.show() diff --git a/setup.py b/setup.py index 390e82de0b376..e39e39455b7bc 100644 --- a/setup.py +++ b/setup.py @@ -392,18 +392,6 @@ def check_package_status(package, min_version): "language": "c++", "optimization_level": "O3", }, - { - "sources": ["_oblique_tree.pyx"], - "language": "c++", - "include_np": True, - "optimization_level": "O3", - }, - { - "sources": ["_oblique_splitter.pyx"], - "language": "c++", - "include_np": True, - "optimization_level": "O3", - }, ], "utils": [ {"sources": ["sparsefuncs_fast.pyx"], "include_np": True}, diff --git a/sklearn/ensemble/__init__.py b/sklearn/ensemble/__init__.py index fca823f683d17..e892d36a0ce46 100644 --- a/sklearn/ensemble/__init__.py +++ b/sklearn/ensemble/__init__.py @@ -8,7 +8,6 @@ from ._forest import RandomTreesEmbedding from ._forest import ExtraTreesClassifier from ._forest import ExtraTreesRegressor -from ._forest import ObliqueRandomForestClassifier from ._bagging import BaggingClassifier from ._bagging import BaggingRegressor from ._iforest import IsolationForest @@ -32,7 +31,6 @@ "RandomTreesEmbedding", "ExtraTreesClassifier", "ExtraTreesRegressor", - "ObliqueRandomForestClassifier", "BaggingClassifier", "BaggingRegressor", "IsolationForest", diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 8b73aa7eb96c6..a3c29e4a269ce 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -40,6 +40,7 @@ class calls the ``fit`` method of each sub-estimator on random samples # License: BSD 3 clause +from time import time from numbers import Integral, Real from warnings import catch_warnings, simplefilter, warn import threading @@ -60,7 +61,6 @@ class calls the ``fit`` method of each sub-estimator on random samples DecisionTreeRegressor, ExtraTreeClassifier, ExtraTreeRegressor, - ObliqueDecisionTreeClassifier, ) from ..tree._tree import DTYPE, DOUBLE from ..utils import check_random_state, compute_sample_weight @@ -73,17 +73,17 @@ class calls the ``fit`` method of each sub-estimator on random samples _check_sample_weight, _check_feature_names_in, ) +from ..utils._openmp_helpers import _openmp_effective_n_threads from ..utils.validation import _num_samples from ..utils._param_validation import Interval, StrOptions from ..utils._param_validation import RealNotInt - +from ._hist_gradient_boosting.binning import _BinMapper __all__ = [ "RandomForestClassifier", "RandomForestRegressor", "ExtraTreesClassifier", "ExtraTreesRegressor", - "ObliqueRandomForestClassifier", "RandomTreesEmbedding", ] @@ -212,6 +212,10 @@ class BaseForest(MultiOutputMixin, BaseEnsemble, metaclass=ABCMeta): Interval(RealNotInt, 0.0, 1.0, closed="right"), Interval(Integral, 1, None, closed="left"), ], + "max_bins": [ + None, + Interval(Integral, 1, None, closed="left"), + ], } @abstractmethod @@ -230,6 +234,7 @@ def __init__( class_weight=None, max_samples=None, base_estimator="deprecated", + max_bins=None, ): super().__init__( estimator=estimator, @@ -246,6 +251,7 @@ def __init__( self.warm_start = warm_start self.class_weight = class_weight self.max_samples = max_samples + self.max_bins = max_bins def apply(self, X): 
""" @@ -265,6 +271,15 @@ def apply(self, X): return the index of the leaf x ends up in. """ X = self._validate_X_predict(X) + + # if we trained a binning tree, then we should re-bin the data + # XXX: this is inefficient and should be improved to be in line with what + # the Histogram Gradient Boosting Tree does, where the binning thresholds + # are passed into the tree itself, thus allowing us to set the node feature + # value thresholds within the tree itself. + if self.max_bins is not None: + X = self._bin_data(X, is_training_data=False).astype(DTYPE) + results = Parallel( n_jobs=self.n_jobs, verbose=self.verbose, @@ -345,14 +360,9 @@ def fit(self, X, y, sample_weight=None): # Validate or convert input data if issparse(y): raise ValueError("sparse multilabel-indicator for y is not supported.") - if isinstance(self, ObliqueRandomForestClassifier): - X, y = self._validate_data( - X, y, multi_output=True, accept_sparse=False, dtype=DTYPE - ) - else: - X, y = self._validate_data( - X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE - ) + X, y = self._validate_data( + X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE + ) if sample_weight is not None: sample_weight = _check_sample_weight(sample_weight, X) @@ -427,6 +437,38 @@ def fit(self, X, y, sample_weight=None): n_more_estimators = self.n_estimators - len(self.estimators_) + if self.max_bins is not None: + # `_openmp_effective_n_threads` is used to take cgroups CPU quotes + # into account when determine the maximum number of threads to use. + n_threads = _openmp_effective_n_threads() + + # Bin the data + # For ease of use of the API, the user-facing GBDT classes accept the + # parameter max_bins, which doesn't take into account the bin for + # missing values (which is always allocated). However, since max_bins + # isn't the true maximal number of bins, all other private classes + # (binmapper, histbuilder...) accept n_bins instead, which is the + # actual total number of bins. Everywhere in the code, the + # convention is that n_bins == max_bins + 1 + n_bins = self.max_bins + 1 # + 1 for missing values + self._bin_mapper = _BinMapper( + n_bins=n_bins, + # is_categorical=self.is_categorical_, + known_categories=None, + random_state=random_state, + n_threads=n_threads, + ) + + # XXX: in order for this to work with the underlying tree submodule's Cython + # code, we need to convert this into the original data's DTYPE because + # the Cython code assumes that `DTYPE` is used. + # The proper implementation will be a lot more complicated and should be + # tackled once scikit-learn has finalized their inclusion of missing data + # and categorical support for decision trees + X = self._bin_data(X, is_training_data=True) # .astype(DTYPE) + else: + self._bin_mapper = None + if n_more_estimators < 0: raise ValueError( "n_estimators=%d must be larger or equal to " @@ -635,6 +677,35 @@ def feature_importances_(self): all_importances = np.mean(all_importances, axis=0, dtype=np.float64) return all_importances / np.sum(all_importances) + def _bin_data(self, X, is_training_data): + """Bin data X. + + If is_training_data, then fit the _bin_mapper attribute. + Else, the binned data is converted to a C-contiguous array. 
+ """ + + description = "training" if is_training_data else "validation" + if self.verbose: + print( + "Binning {:.3f} GB of {} data: ".format(X.nbytes / 1e9, description), + end="", + flush=True, + ) + tic = time() + if is_training_data: + X_binned = self._bin_mapper.fit_transform(X) # F-aligned array + else: + X_binned = self._bin_mapper.transform(X) # F-aligned array + # We convert the array to C-contiguous since predicting is faster + # with this layout (training is faster on F-arrays though) + X_binned = np.ascontiguousarray(X_binned) + toc = time() + if self.verbose: + duration = toc - tic + print("{:.3f} s".format(duration)) + + return X_binned + def _accumulate_prediction(predict, X, out, lock): """ @@ -676,6 +747,7 @@ def __init__( class_weight=None, max_samples=None, base_estimator="deprecated", + max_bins=None, ): super().__init__( estimator=estimator, @@ -690,6 +762,7 @@ def __init__( class_weight=class_weight, max_samples=max_samples, base_estimator=base_estimator, + max_bins=max_bins, ) @staticmethod @@ -863,6 +936,14 @@ def predict_proba(self, X): # Check data X = self._validate_X_predict(X) + # if we trained a binning tree, then we should re-bin the data + # XXX: this is inefficient and should be improved to be in line with what + # the Histogram Gradient Boosting Tree does, where the binning thresholds + # are passed into the tree itself, thus allowing us to set the node feature + # value thresholds within the tree itself. + if self.max_bins is not None: + X = self._bin_data(X, is_training_data=False).astype(DTYPE) + # Assign chunk of trees to jobs n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs) @@ -944,6 +1025,7 @@ def __init__( warm_start=False, max_samples=None, base_estimator="deprecated", + max_bins=None, ): super().__init__( estimator, @@ -957,6 +1039,7 @@ def __init__( warm_start=warm_start, max_samples=max_samples, base_estimator=base_estimator, + max_bins=max_bins, ) def predict(self, X): @@ -982,6 +1065,14 @@ def predict(self, X): # Check data X = self._validate_X_predict(X) + # if we trained a binning tree, then we should re-bin the data + # XXX: this is inefficient and should be improved to be in line with what + # the Histogram Gradient Boosting Tree does, where the binning thresholds + # are passed into the tree itself, thus allowing us to set the node feature + # value thresholds within the tree itself. 
+ if self.max_bins is not None: + X = self._bin_data(X, is_training_data=False).astype(DTYPE) + # Assign chunk of trees to jobs n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs) @@ -1406,6 +1497,7 @@ def __init__( class_weight=None, ccp_alpha=0.0, max_samples=None, + max_bins=None, ): super().__init__( estimator=DecisionTreeClassifier(), @@ -1430,6 +1522,7 @@ def __init__( warm_start=warm_start, class_weight=class_weight, max_samples=max_samples, + max_bins=max_bins, ) self.criterion = criterion @@ -1741,6 +1834,7 @@ def __init__( warm_start=False, ccp_alpha=0.0, max_samples=None, + max_bins=None, ): super().__init__( estimator=DecisionTreeRegressor(), @@ -1764,6 +1858,7 @@ def __init__( verbose=verbose, warm_start=warm_start, max_samples=max_samples, + max_bins=max_bins, ) self.criterion = criterion @@ -2091,6 +2186,7 @@ def __init__( class_weight=None, ccp_alpha=0.0, max_samples=None, + max_bins=None, ): super().__init__( estimator=ExtraTreeClassifier(), @@ -2115,6 +2211,7 @@ def __init__( warm_start=warm_start, class_weight=class_weight, max_samples=max_samples, + max_bins=max_bins, ) self.criterion = criterion @@ -2413,6 +2510,7 @@ def __init__( warm_start=False, ccp_alpha=0.0, max_samples=None, + max_bins=None, ): super().__init__( estimator=ExtraTreeRegressor(), @@ -2436,6 +2534,7 @@ def __init__( verbose=verbose, warm_start=warm_start, max_samples=max_samples, + max_bins=max_bins, ) self.criterion = criterion @@ -2812,363 +2911,3 @@ def transform(self, X): """ check_is_fitted(self) return self.one_hot_encoder_.transform(self.apply(X)) - - -class ObliqueRandomForestClassifier(ForestClassifier): - """ - An oblique random forest classifier. - - A oblique random forest is a meta estimator similar to a random - forest that fits a number of oblique decision tree classifiers - on various sub-samples of the dataset and uses averaging to - improve the predictive accuracy and control over-fitting. - - The sub-sample size is controlled with the `max_samples` parameter if - `bootstrap=True` (default), otherwise the whole dataset is used to build - each tree. - - Read more in the :ref:`User Guide `. - - Parameters - ---------- - n_estimators : int, default=100 - The number of trees in the forest. - - .. versionchanged:: 0.22 - The default value of ``n_estimators`` changed from 10 to 100 - in 0.22. - - criterion : {"gini", "entropy"}, default="gini" - The function to measure the quality of a split. Supported criteria are - "gini" for the Gini impurity and "entropy" for the information gain. - Note: this parameter is tree-specific. - - max_depth : int, default=None - The maximum depth of the tree. If None, then nodes are expanded until - all leaves are pure or until all leaves contain less than - min_samples_split samples. - - min_samples_split : int or float, default=2 - The minimum number of samples required to split an internal node: - - - If int, then consider `min_samples_split` as the minimum number. - - If float, then `min_samples_split` is a fraction and - `ceil(min_samples_split * n_samples)` are the minimum - number of samples for each split. - - .. versionchanged:: 0.18 - Added float values for fractions. - - min_samples_leaf : int or float, default=1 - The minimum number of samples required to be at a leaf node. - A split point at any depth will only be considered if it leaves at - least ``min_samples_leaf`` training samples in each of the left and - right branches. This may have the effect of smoothing the model, - especially in regression. 
- - - If int, then consider `min_samples_leaf` as the minimum number. - - If float, then `min_samples_leaf` is a fraction and - `ceil(min_samples_leaf * n_samples)` are the minimum - number of samples for each node. - - .. versionchanged:: 0.18 - Added float values for fractions. - - min_weight_fraction_leaf : float, default=0.0 - The minimum weighted fraction of the sum total of weights (of all - the input samples) required to be at a leaf node. Samples have - equal weight when sample_weight is not provided. - - max_features : {"sqrt", "log2", None}, int or float, default="sqrt" - The number of features to consider when looking for the best split: - - - If int, then consider `max_features` features at each split. - - If float, then `max_features` is a fraction and - `round(max_features * n_features)` features are considered at each - split. - - If "auto", then `max_features=sqrt(n_features)`. - - If "sqrt", then `max_features=sqrt(n_features)`. - - If "log2", then `max_features=log2(n_features)`. - - If None, then `max_features=n_features`. - - .. versionchanged:: 1.1 - The default of `max_features` changed from `"auto"` to `"sqrt"`. - - .. deprecated:: 1.1 - The `"auto"` option was deprecated in 1.1 and will be removed - in 1.3. - - Note: the search for a split does not stop until at least one - valid partition of the node samples is found, even if it requires to - effectively inspect more than ``max_features`` features. - - max_leaf_nodes : int, default=None - Grow trees with ``max_leaf_nodes`` in best-first fashion. - Best nodes are defined as relative reduction in impurity. - If None then unlimited number of leaf nodes. - - min_impurity_decrease : float, default=0.0 - A node will be split if this split induces a decrease of the impurity - greater than or equal to this value. - - The weighted impurity decrease equation is the following:: - - N_t / N * (impurity - N_t_R / N_t * right_impurity - - N_t_L / N_t * left_impurity) - - where ``N`` is the total number of samples, ``N_t`` is the number of - samples at the current node, ``N_t_L`` is the number of samples in the - left child, and ``N_t_R`` is the number of samples in the right child. - - ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, - if ``sample_weight`` is passed. - - .. versionadded:: 0.19 - - bootstrap : bool, default=True - Whether bootstrap samples are used when building trees. If False, the - whole dataset is used to build each tree. - - oob_score : bool, default=False - Whether to use out-of-bag samples to estimate the generalization score. - Only available if bootstrap=True. - - n_jobs : int, default=None - The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`, - :meth:`decision_path` and :meth:`apply` are all parallelized over the - trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` - context. ``-1`` means using all processors. See :term:`Glossary - ` for more details. - - random_state : int, RandomState instance or None, default=None - Controls both the randomness of the bootstrapping of the samples used - when building trees (if ``bootstrap=True``) and the sampling of the - features to consider when looking for the best split at each node - (if ``max_features < n_features``). - See :term:`Glossary ` for details. - - verbose : int, default=0 - Controls the verbosity when fitting and predicting. 
- - warm_start : bool, default=False - When set to ``True``, reuse the solution of the previous call to fit - and add more estimators to the ensemble, otherwise, just fit a whole - new forest. See :term:`the Glossary `. - - class_weight : {"balanced", "balanced_subsample"}, dict or list of dicts, \ - default=None - Weights associated with classes in the form ``{class_label: weight}``. - If not given, all classes are supposed to have weight one. For - multi-output problems, a list of dicts can be provided in the same - order as the columns of y. - - Note that for multioutput (including multilabel) weights should be - defined for each class of every column in its own dict. For example, - for four-class multilabel classification weights should be - [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of - [{1:1}, {2:5}, {3:1}, {4:1}]. - - The "balanced" mode uses the values of y to automatically adjust - weights inversely proportional to class frequencies in the input data - as ``n_samples / (n_classes * np.bincount(y))`` - - The "balanced_subsample" mode is the same as "balanced" except that - weights are computed based on the bootstrap sample for every tree - grown. - - For multi-output, the weights of each column of y will be multiplied. - - Note that these weights will be multiplied with sample_weight (passed - through the fit method) if sample_weight is specified. - - ccp_alpha : non-negative float, default=0.0 - Complexity parameter used for Minimal Cost-Complexity Pruning. The - subtree with the largest cost complexity that is smaller than - ``ccp_alpha`` will be chosen. By default, no pruning is performed. See - :ref:`minimal_cost_complexity_pruning` for details. - - .. versionadded:: 0.22 - - max_samples : int or float, default=None - If bootstrap is True, the number of samples to draw from X - to train each base estimator. - - - If None (default), then draw `X.shape[0]` samples. - - If int, then draw `max_samples` samples. - - If float, then draw `max_samples * X.shape[0]` samples. Thus, - `max_samples` should be in the interval `(0.0, 1.0]`. - - .. versionadded:: 0.22 - - feature_combinations : float, default=None - The number of features to combine on average at each split - of the decision trees. If ``None``, then will default to the minimum of - ``(1.5, n_features)``. This controls the number of non-zeros is the - projection matrix. Setting the value to 1.0 is equivalent to a - traditional decision-tree. ``feature_combinations * max_features`` - gives the number of expected non-zeros in the projection matrix of shape - ``(max_features, n_features)``. Thus this value must always be less than - ``n_features`` in order to be valid. - - Attributes - ---------- - estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier` - The child estimator template used to create the collection of fitted - sub-estimators. - - estimators_ : list of ObliqueDecisionTreeClassifier - The collection of fitted sub-estimators. - - classes_ : ndarray of shape (n_classes,) or a list of such arrays - The classes labels (single output problem), or a list of arrays of - class labels (multi-output problem). - - n_classes_ : int or list - The number of classes (single output problem), or a list containing the - number of classes for each output (multi-output problem). - - n_features_in_ : int - Number of features seen during :term:`fit`. - - feature_names_in_ : ndarray of shape (`n_features_in_`,) - Names of features seen during :term:`fit`. 
Defined only when `X` - has feature names that are all strings. - - n_outputs_ : int - The number of outputs when ``fit`` is performed. - - feature_importances_ : ndarray of shape (n_features,) - The impurity-based feature importances. - The higher, the more important the feature. - The importance of a feature is computed as the (normalized) - total reduction of the criterion brought by that feature. It is also - known as the Gini importance. - - Warning: impurity-based feature importances can be misleading for - high cardinality features (many unique values). See - :func:`sklearn.inspection.permutation_importance` as an alternative. - - oob_score_ : float - Score of the training dataset obtained using an out-of-bag estimate. - This attribute exists only when ``oob_score`` is True. - - oob_decision_function_ : ndarray of shape (n_samples, n_classes) or \ - (n_samples, n_classes, n_outputs) - Decision function computed with out-of-bag estimate on the training - set. If n_estimators is small it might be possible that a data point - was never left out during the bootstrap. In this case, - `oob_decision_function_` might contain NaN. This attribute exists - only when ``oob_score`` is True. - - See Also - -------- - sklearn.tree.ObliqueDecisionTreeClassifier : An oblique decision - tree classifier. - sklearn.ensemble.RandomForestClassifier : An axis-aligned decision - forest classifier. - - Notes - ----- - The default values for the parameters controlling the size of the trees - (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and - unpruned trees which can potentially be very large on some data sets. To - reduce memory consumption, the complexity and size of the trees should be - controlled by setting those parameter values. - - The features are always randomly permuted at each split. Therefore, - the best found split may vary, even with the same training data, - ``max_features=n_features`` and ``bootstrap=False``, if the improvement - of the criterion is identical for several splits enumerated during the - search of the best split. To obtain a deterministic behaviour during - fitting, ``random_state`` has to be fixed. - - References - ---------- - .. [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001. - - Examples - -------- - >>> from sklearn.ensemble import ObliqueRandomForestClassifier - >>> from sklearn.datasets import make_classification - >>> X, y = make_classification(n_samples=1000, n_features=4, - ... n_informative=2, n_redundant=0, - ... random_state=0, shuffle=False) - >>> clf = ObliqueRandomForestClassifier(max_depth=2, random_state=0) - >>> clf.fit(X, y) - ObliqueRandomForestClassifier(...) 
- >>> print(clf.predict([[0, 0, 0, 0]])) - [1] - """ - - _parameter_constraints: dict = { - **ForestClassifier._parameter_constraints, - **ObliqueDecisionTreeClassifier._parameter_constraints, - "class_weight": [ - StrOptions({"balanced_subsample", "balanced"}), - dict, - list, - None, - ], - } - _parameter_constraints.pop("splitter") - - def __init__( - self, - n_estimators=100, - *, - criterion="gini", - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0.0, - max_features="sqrt", - max_leaf_nodes=None, - min_impurity_decrease=0.0, - bootstrap=True, - oob_score=False, - n_jobs=None, - random_state=None, - verbose=0, - warm_start=False, - class_weight=None, - ccp_alpha=0.0, - max_samples=None, - feature_combinations=None, - ): - super().__init__( - estimator=ObliqueDecisionTreeClassifier(), - n_estimators=n_estimators, - estimator_params=( - "criterion", - "max_depth", - "min_samples_split", - "min_samples_leaf", - "min_weight_fraction_leaf", - "max_features", - "max_leaf_nodes", - "min_impurity_decrease", - "random_state", - "ccp_alpha", - "feature_combinations", - ), - bootstrap=bootstrap, - oob_score=oob_score, - n_jobs=n_jobs, - random_state=random_state, - verbose=verbose, - warm_start=warm_start, - class_weight=class_weight, - max_samples=max_samples, - ) - self.criterion = criterion - self.max_depth = max_depth - self.min_samples_split = min_samples_split - self.min_samples_leaf = min_samples_leaf - self.max_features = max_features - self.feature_combinations = feature_combinations - - # unused by oblique forests - self.min_weight_fraction_leaf = min_weight_fraction_leaf - self.max_leaf_nodes = max_leaf_nodes - self.min_impurity_decrease = min_impurity_decrease - self.ccp_alpha = ccp_alpha diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 854b781c0e31f..0150340f24bc6 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -46,7 +46,6 @@ from sklearn.ensemble import ExtraTreesRegressor from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestRegressor -from sklearn.ensemble import ObliqueRandomForestClassifier from sklearn.ensemble import RandomTreesEmbedding from sklearn.metrics import explained_variance_score, f1_score from sklearn.model_selection import train_test_split, cross_val_score @@ -55,7 +54,7 @@ from sklearn.utils.parallel import Parallel from sklearn.utils.validation import check_random_state -from sklearn.metrics import mean_squared_error, accuracy_score +from sklearn.metrics import mean_squared_error from sklearn.tree._classes import SPARSE_SPLITTERS @@ -99,7 +98,6 @@ FOREST_CLASSIFIERS = { "ExtraTreesClassifier": ExtraTreesClassifier, "RandomForestClassifier": RandomForestClassifier, - "ObliqueRandomForestClassifier": ObliqueRandomForestClassifier, } FOREST_REGRESSORS = { @@ -627,9 +625,6 @@ def test_forest_classifier_oob( ForestClassifier, X, y, X_type, lower_bound_accuracy, oob_score ): """Check that OOB score is close to score on a test set.""" - if ForestClassifier == ObliqueRandomForestClassifier and X_type != "array": - pytest.skip() - X = _convert_container(X, constructor_name=X_type) X_train, X_test, y_train, y_test = train_test_split( X, @@ -1251,8 +1246,6 @@ def check_sparse_input(name, X, X_sparse, y): @pytest.mark.parametrize("name", FOREST_ESTIMATORS) @pytest.mark.parametrize("sparse_matrix", (csr_matrix, csc_matrix, coo_matrix)) def test_sparse_input(name, sparse_matrix): - if name == 
"ObliqueRandomForestClassifier": - pytest.skip() X, y = datasets.make_multilabel_classification(random_state=0, n_samples=50) check_sparse_input(name, X, sparse_matrix(X), y) @@ -1283,10 +1276,7 @@ def check_memory_layout(name, dtype): y = iris.target assert_array_almost_equal(est.fit(X, y).predict(X), y) - if ( - est.estimator.splitter in SPARSE_SPLITTERS - and name != "ObliqueRandomForestClassifier" - ): + if est.estimator.splitter in SPARSE_SPLITTERS: # csr matrix X = csr_matrix(iris.data, dtype=dtype) y = iris.target @@ -1619,15 +1609,14 @@ def check_decision_path(name): np.diff(n_nodes_ptr), [e.tree_.node_count for e in est.estimators_] ) - if name != "ObliqueRandomForestClassifier": - # Assert that leaves index are correct - leaves = est.apply(X) - for est_id in range(leaves.shape[1]): - leave_indicator = [ - indicator[i, n_nodes_ptr[est_id] + j] - for i, j in enumerate(leaves[:, est_id]) - ] - assert_array_almost_equal(leave_indicator, np.ones(shape=n_samples)) + # Assert that leaves index are correct + leaves = est.apply(X) + for est_id in range(leaves.shape[1]): + leave_indicator = [ + indicator[i, n_nodes_ptr[est_id] + j] + for i, j in enumerate(leaves[:, est_id]) + ] + assert_array_almost_equal(leave_indicator, np.ones(shape=n_samples)) @pytest.mark.parametrize("name", FOREST_CLASSIFIERS_REGRESSORS) @@ -1867,91 +1856,6 @@ def test_random_trees_embedding_feature_names_out(): assert_array_equal(expected_names, names) -def test_oblique_forest_sparse_parity(): - # Sparse parity dataset - n = 1000 - X, y = _sparse_parity(n, random_state=0) - n_test = 0.1 - X_train, X_test, y_train, y_test = train_test_split( - X, - y, - test_size=n_test, - random_state=0, - ) - - rc_clf = ObliqueRandomForestClassifier(max_features=None, random_state=0) - rc_clf.fit(X_train, y_train) - y_hat = rc_clf.predict(X_test) - rc_accuracy = accuracy_score(y_test, y_hat) - - ri_clf = RandomForestClassifier(random_state=0) - ri_clf.fit(X_train, y_train) - y_hat = ri_clf.predict(X_test) - ri_accuracy = accuracy_score(y_test, y_hat) - - assert ri_accuracy < rc_accuracy - assert ri_accuracy > 0.45 - assert rc_accuracy > 0.5 - - -def test_oblique_forest_orthant(): - """Test oblique forests on orthant problem. - - It is expected that axis-aligned and oblique-aligned - forests will perform similarly. 
- """ - n = 500 - X, y = _orthant(n, p=6, random_state=0) - n_test = 0.3 - X_train, X_test, y_train, y_test = train_test_split( - X, - y, - test_size=n_test, - random_state=0, - ) - - rc_clf = ObliqueRandomForestClassifier(max_features=None, random_state=0) - rc_clf.fit(X_train, y_train) - y_hat = rc_clf.predict(X_test) - rc_accuracy = accuracy_score(y_test, y_hat) - - ri_clf = RandomForestClassifier(max_features="sqrt", random_state=0) - ri_clf.fit(X_train, y_train) - y_hat = ri_clf.predict(X_test) - ri_accuracy = accuracy_score(y_test, y_hat) - - assert rc_accuracy >= ri_accuracy - assert ri_accuracy > 0.84 - assert rc_accuracy > 0.85 - - -def test_oblique_forest_trunk(): - """Test oblique vs axis-aligned forests on Trunk.""" - n = 1000 - X, y = _trunk(n, p=100, random_state=0) - n_test = 0.2 - X_train, X_test, y_train, y_test = train_test_split( - X, - y, - test_size=n_test, - random_state=0, - ) - - rc_clf = ObliqueRandomForestClassifier(max_features=X.shape[1], random_state=0) - rc_clf.fit(X_train, y_train) - y_hat = rc_clf.predict(X_test) - rc_accuracy = accuracy_score(y_test, y_hat) - - ri_clf = RandomForestClassifier(max_features="sqrt", random_state=0) - ri_clf.fit(X_train, y_train) - y_hat = ri_clf.predict(X_test) - ri_accuracy = accuracy_score(y_test, y_hat) - - assert rc_accuracy > ri_accuracy - assert ri_accuracy > 0.83 - assert rc_accuracy > 0.86 - - # TODO(1.4): remove in 1.4 @pytest.mark.parametrize( "name", @@ -2001,3 +1905,60 @@ def test_round_samples_to_one_when_samples_too_low(class_weight): n_estimators=10, max_samples=1e-4, class_weight=class_weight, random_state=0 ) forest.fit(X, y) + + +@pytest.mark.parametrize("name", FOREST_CLASSIFIERS) +def test_classification_toy_withbins(name): + """Check classification on a toy dataset.""" + ForestClassifier = FOREST_CLASSIFIERS[name] + + clf = ForestClassifier(n_estimators=10, random_state=1, max_bins=255) + clf.fit(X, y) + assert_array_equal(clf.predict(T), true_result) + assert 10 == len(clf) + + clf = ForestClassifier( + n_estimators=10, max_features=1, random_state=1, max_bins=255 + ) + clf.fit(X, y) + assert_array_equal(clf.predict(T), true_result) + assert 10 == len(clf) + + # also test apply + leaf_indices = clf.apply(X) + assert leaf_indices.shape == (len(X), clf.n_estimators) + + +@pytest.mark.parametrize("name", FOREST_REGRESSORS) +@pytest.mark.parametrize( + "criterion", ("squared_error", "absolute_error", "friedman_mse") +) +def test_regression_criterion_withbins(name, criterion): + # Check consistency on regression dataset. 
+ ForestRegressor = FOREST_REGRESSORS[name] + + reg = ForestRegressor( + n_estimators=5, criterion=criterion, random_state=1, max_bins=250 + ) + reg.fit(X_reg, y_reg) + score = reg.score(X_reg, y_reg) + assert ( + score > 0.93 + ), "Failed with max_features=None, criterion %s and score = %f" % ( + criterion, + score, + ) + + reg = ForestRegressor( + n_estimators=5, + criterion=criterion, + max_features=6, + random_state=1, + max_bins=250, + ) + reg.fit(X_reg, y_reg) + score = reg.score(X_reg, y_reg) + assert score > 0.92, "Failed with max_features=6, criterion %s and score = %f" % ( + criterion, + score, + ) diff --git a/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py b/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py index 48b7a44f39ea8..b8c7c3cfb7c20 100644 --- a/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py +++ b/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py @@ -333,7 +333,7 @@ def test_confusion_matrix_with_unknown_labels(pyplot, constructor_name): def test_colormap_max(pyplot): """Check that the max color is used for the color of the text.""" - gray = pyplot.get_cmap("gray", 1024) + gray = pyplot.colormaps.get_cmap("gray") confusion_matrix = np.array([[1.0, 0.0], [0.0, 1.0]]) disp = ConfusionMatrixDisplay(confusion_matrix) diff --git a/sklearn/tree/__init__.py b/sklearn/tree/__init__.py index cbf3ba13a93c0..f7a8fd183c7cc 100644 --- a/sklearn/tree/__init__.py +++ b/sklearn/tree/__init__.py @@ -8,7 +8,6 @@ from ._classes import DecisionTreeRegressor from ._classes import ExtraTreeClassifier from ._classes import ExtraTreeRegressor -from ._classes import ObliqueDecisionTreeClassifier from ._export import export_graphviz, plot_tree, export_text __all__ = [ @@ -17,7 +16,6 @@ "DecisionTreeRegressor", "ExtraTreeClassifier", "ExtraTreeRegressor", - "ObliqueDecisionTreeClassifier", "export_graphviz", "plot_tree", "export_text", diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 8429c70e2a74f..bd54483bf2dfe 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -42,21 +42,18 @@ from ._criterion import BaseCriterion from ._splitter import BaseSplitter -from ._oblique_splitter import ObliqueSplitter -from ._oblique_tree import ObliqueTree from ._tree import DepthFirstTreeBuilder from ._tree import BestFirstTreeBuilder from ._tree import Tree from ._tree import _build_pruned_tree_ccp from ._tree import ccp_pruning_path -from . import _tree, _splitter, _criterion, _oblique_splitter +from . import _tree, _splitter, _criterion __all__ = [ "DecisionTreeClassifier", "DecisionTreeRegressor", "ExtraTreeClassifier", "ExtraTreeRegressor", - "ObliqueDecisionTreeClassifier", ] @@ -86,10 +83,6 @@ "random": _splitter.RandomSparseSplitter, } -OBLIQUE_DENSE_SPLITTERS = { - "best": _oblique_splitter.BestObliqueSplitter, -} - # ============================================================================= # Base decision tree # ============================================================================= @@ -1817,415 +1810,3 @@ def __init__( random_state=random_state, ccp_alpha=ccp_alpha, ) - - -class ObliqueDecisionTreeClassifier(DecisionTreeClassifier): - """A decision tree classifier. - - Read more in the :ref:`User Guide `. - - Parameters - ---------- - criterion : {"gini", "entropy"}, default="gini" - The function to measure the quality of a split. Supported criteria are - "gini" for the Gini impurity and "entropy" for the information gain. 
- - splitter : {"best", "random"}, default="best" - The strategy used to choose the split at each node. Supported - strategies are "best" to choose the best split and "random" to choose - the best random split. - - max_depth : int, default=None - The maximum depth of the tree. If None, then nodes are expanded until - all leaves are pure or until all leaves contain less than - min_samples_split samples. - - min_samples_split : int or float, default=2 - The minimum number of samples required to split an internal node: - - - If int, then consider `min_samples_split` as the minimum number. - - If float, then `min_samples_split` is a fraction and - `ceil(min_samples_split * n_samples)` are the minimum - number of samples for each split. - - .. versionchanged:: 0.18 - Added float values for fractions. - - min_samples_leaf : int or float, default=1 - The minimum number of samples required to be at a leaf node. - A split point at any depth will only be considered if it leaves at - least ``min_samples_leaf`` training samples in each of the left and - right branches. This may have the effect of smoothing the model, - especially in regression. - - - If int, then consider `min_samples_leaf` as the minimum number. - - If float, then `min_samples_leaf` is a fraction and - `ceil(min_samples_leaf * n_samples)` are the minimum - number of samples for each node. - - .. versionchanged:: 0.18 - Added float values for fractions. - - min_weight_fraction_leaf : float, default=0.0 - The minimum weighted fraction of the sum total of weights (of all - the input samples) required to be at a leaf node. Samples have - equal weight when sample_weight is not provided. - - max_features : int, float or {"auto", "sqrt", "log2"}, default=None - The number of features to consider when looking for the best split: - - - If int, then consider `max_features` features at each split. - - If float, then `max_features` is a fraction and - `int(max_features * n_features)` features are considered at each - split. - - If "auto", then `max_features=sqrt(n_features)`. - - If "sqrt", then `max_features=sqrt(n_features)`. - - If "log2", then `max_features=log2(n_features)`. - - If None, then `max_features=n_features`. - - Note: the search for a split does not stop until at least one - valid partition of the node samples is found, even if it requires to - effectively inspect more than ``max_features`` features. - - Note: Compared to axis-aligned Random Forests, one can set - max_features to a number greater then ``n_features``. - - random_state : int, RandomState instance or None, default=None - Controls the randomness of the estimator. The features are always - randomly permuted at each split, even if ``splitter`` is set to - ``"best"``. When ``max_features < n_features``, the algorithm will - select ``max_features`` at random at each split before finding the best - split among them. But the best found split may vary across different - runs, even if ``max_features=n_features``. That is the case, if the - improvement of the criterion is identical for several splits and one - split has to be selected at random. To obtain a deterministic behaviour - during fitting, ``random_state`` has to be fixed to an integer. - See :term:`Glossary ` for details. - - max_leaf_nodes : int, default=None - Grow a tree with ``max_leaf_nodes`` in best-first fashion. - Best nodes are defined as relative reduction in impurity. - If None then unlimited number of leaf nodes. 
- - min_impurity_decrease : float, default=0.0 - A node will be split if this split induces a decrease of the impurity - greater than or equal to this value. - - The weighted impurity decrease equation is the following:: - - N_t / N * (impurity - N_t_R / N_t * right_impurity - - N_t_L / N_t * left_impurity) - - where ``N`` is the total number of samples, ``N_t`` is the number of - samples at the current node, ``N_t_L`` is the number of samples in the - left child, and ``N_t_R`` is the number of samples in the right child. - - ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, - if ``sample_weight`` is passed. - - .. versionadded:: 0.19 - - class_weight : dict, list of dict or "balanced", default=None - Weights associated with classes in the form ``{class_label: weight}``. - If None, all classes are supposed to have weight one. For - multi-output problems, a list of dicts can be provided in the same - order as the columns of y. - - Note that for multioutput (including multilabel) weights should be - defined for each class of every column in its own dict. For example, - for four-class multilabel classification weights should be - [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of - [{1:1}, {2:5}, {3:1}, {4:1}]. - - The "balanced" mode uses the values of y to automatically adjust - weights inversely proportional to class frequencies in the input data - as ``n_samples / (n_classes * np.bincount(y))`` - - For multi-output, the weights of each column of y will be multiplied. - - Note that these weights will be multiplied with sample_weight (passed - through the fit method) if sample_weight is specified. - - ccp_alpha : non-negative float, default=0.0 - Complexity parameter used for Minimal Cost-Complexity Pruning. The - subtree with the largest cost complexity that is smaller than - ``ccp_alpha`` will be chosen. By default, no pruning is performed. See - :ref:`minimal_cost_complexity_pruning` for details. - - .. versionadded:: 0.22 - - feature_combinations : float, default=None - The number of features to combine on average at each split - of the decision trees. If ``None``, then will default to the minimum of - ``(1.5, n_features)``. This controls the number of non-zeros is the - projection matrix. Setting the value to 1.0 is equivalent to a - traditional decision-tree. ``feature_combinations * max_features`` - gives the number of expected non-zeros in the projection matrix of shape - ``(max_features, n_features)``. Thus this value must always be less than - ``n_features`` in order to be valid. - - Attributes - ---------- - classes_ : ndarray of shape (n_classes,) or list of ndarray - The classes labels (single output problem), - or a list of arrays of class labels (multi-output problem). - - feature_importances_ : ndarray of shape (n_features,) - The impurity-based feature importances. - The higher, the more important the feature. - The importance of a feature is computed as the (normalized) - total reduction of the criterion brought by that feature. It is also - known as the Gini importance [4]_. - - Warning: impurity-based feature importances can be misleading for - high cardinality features (many unique values). See - :func:`sklearn.inspection.permutation_importance` as an alternative. - - max_features_ : int - The inferred value of max_features. - - n_classes_ : int or list of int - The number of classes (for single output problems), - or a list containing the number of classes for each - output (for multi-output problems). 
- - n_features_in_ : int - Number of features seen during :term:`fit`. - - .. versionadded:: 0.24 - - feature_names_in_ : ndarray of shape (`n_features_in_`,) - Names of features seen during :term:`fit`. Defined only when `X` - has feature names that are all strings. - - .. versionadded:: 1.0 - - n_outputs_ : int - The number of outputs when ``fit`` is performed. - - tree_ : Tree instance - The underlying Tree object. Please refer to - ``help(sklearn.tree._tree.Tree)`` for - attributes of Tree object. - - feature_combinations_ : float - The number of feature combinations on average taken to fit the tree. - - See Also - -------- - DecisionTreeClassifier : An axis-aligned decision tree classifier. - - Notes - ----- - Compared to ``DecisionTreeClassifier``, oblique trees can sample - more features then ``n_features``, where ``n_features`` is the number - of columns in ``X``. This is controlled via the ``max_features`` - parameter. In fact, sampling more times results in better - trees with the caveat that there is an increased computation. It is - always recommended to sample more if one is willing to spend the - computational resources. - - The default values for the parameters controlling the size of the trees - (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and - unpruned trees which can potentially be very large on some data sets. To - reduce memory consumption, the complexity and size of the trees should be - controlled by setting those parameter values. - - The :meth:`predict` method operates using the :func:`numpy.argmax` - function on the outputs of :meth:`predict_proba`. This means that in - case the highest predicted probabilities are tied, the classifier will - predict the tied class with the lowest index in :term:`classes_`. - - References - ---------- - - .. [1] https://en.wikipedia.org/wiki/Decision_tree_learning - - .. [2] L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification - and Regression Trees", Wadsworth, Belmont, CA, 1984. - - .. [3] T. Hastie, R. Tibshirani and J. Friedman. "Elements of Statistical - Learning", Springer, 2009. - - .. [4] L. Breiman, and A. Cutler, "Random Forests", - https://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm - - Examples - -------- - >>> from sklearn.datasets import load_iris - >>> from sklearn.model_selection import cross_val_score - >>> from sklearn.tree import ObliqueDecisionTreeClassifier - >>> clf = ObliqueDecisionTreeClassifier(random_state=0) - >>> iris = load_iris() - >>> cross_val_score(clf, iris.data, iris.target, cv=10) - ... # doctest: +SKIP - ... - array([ 1. , 0.93..., 0.86..., 0.93..., 0.93..., - 0.93..., 0.93..., 1. , 0.93..., 1. 
]) - """ - - _parameter_constraints = { - **DecisionTreeClassifier._parameter_constraints, - "feature_combinations": [Interval(Real, 1.0, None, closed="left"), None], - } - - def __init__( - self, - *, - criterion="gini", - splitter="best", - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0.0, - max_features=None, - random_state=None, - max_leaf_nodes=None, - min_impurity_decrease=0.0, - class_weight=None, - ccp_alpha=0.0, - feature_combinations=None, - ): - super().__init__( - criterion=criterion, - splitter=splitter, - max_depth=max_depth, - min_samples_split=min_samples_split, - min_samples_leaf=min_samples_leaf, - min_weight_fraction_leaf=min_weight_fraction_leaf, - max_features=max_features, - max_leaf_nodes=max_leaf_nodes, - class_weight=class_weight, - random_state=random_state, - min_impurity_decrease=min_impurity_decrease, - ccp_alpha=ccp_alpha, - ) - - self.feature_combinations = feature_combinations - - def _build_tree( - self, - X, - y, - sample_weight, - min_samples_leaf, - min_weight_leaf, - max_leaf_nodes, - min_samples_split, - max_depth, - random_state, - ): - """Build the actual tree. - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - The training input samples. Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csc_matrix``. - y : array-like of shape (n_samples,) or (n_samples, n_outputs) - The target values (class labels) as integers or strings. - sample_weight : array-like of shape (n_samples,), default=None - Sample weights. If None, then samples are equally weighted. Splits - that would create child nodes with net zero or negative weight are - ignored while searching for a split in each node. Splits are also - ignored if they would result in any single class carrying a - negative weight in either child node. - min_samples_leaf : int or float - The minimum number of samples required to be at a leaf node. - min_weight_leaf : float, default=0.0 - The minimum weighted fraction of the sum total of weights. - max_leaf_nodes : int, default=None - Grow a tree with ``max_leaf_nodes`` in best-first fashion. - min_samples_split : int or float, default=2 - The minimum number of samples required to split an internal node: - max_depth : int, default=None - The maximum depth of the tree. If None, then nodes are expanded until - all leaves are pure or until all leaves contain less than - min_samples_split samples. - random_state : int, RandomState instance or None, default=None - Controls the randomness of the estimator. 
- """ - n_samples, n_features = X.shape - - if self.feature_combinations is None: - self.feature_combinations_ = min(n_features, 1.5) - elif self.feature_combinations > n_features: - raise RuntimeError( - f"Feature combinations {self.feature_combinations} should not be " - f"greater than the possible number of features {n_features}" - ) - else: - self.feature_combinations_ = self.feature_combinations - - # Build tree - criterion = self.criterion - if not isinstance(criterion, BaseCriterion): - if is_classifier(self): - criterion = CRITERIA_CLF[self.criterion]( - self.n_outputs_, self.n_classes_ - ) - else: - criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples) - else: - # Make a deepcopy in case the criterion has mutable attributes that - # might be shared and modified concurrently during parallel fitting - criterion = copy.deepcopy(criterion) - - splitter = self.splitter - if issparse(X): - raise ValueError( - "Sparse input is not supported for oblique trees. " - "Please convert your data to a dense array." - ) - else: - OBLIQUE_SPLITTERS = OBLIQUE_DENSE_SPLITTERS - - if not isinstance(self.splitter, ObliqueSplitter): - splitter = OBLIQUE_SPLITTERS[self.splitter]( - criterion, - self.max_features_, - min_samples_leaf, - min_weight_leaf, - random_state, - self.feature_combinations_, - ) - - if is_classifier(self): - self.tree_ = ObliqueTree( - self.n_features_in_, self.n_classes_, self.n_outputs_ - ) - else: - self.tree_ = ObliqueTree( - self.n_features_in_, - # TODO: tree shouldn't need this in this case - np.array([1] * self.n_outputs_, dtype=np.intp), - self.n_outputs_, - ) - - # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise - if max_leaf_nodes < 0: - builder = DepthFirstTreeBuilder( - splitter, - min_samples_split, - min_samples_leaf, - min_weight_leaf, - max_depth, - self.min_impurity_decrease, - ) - else: - builder = BestFirstTreeBuilder( - splitter, - min_samples_split, - min_samples_leaf, - min_weight_leaf, - max_depth, - max_leaf_nodes, - self.min_impurity_decrease, - ) - - builder.build(self.tree_, X, y, sample_weight) - - if self.n_outputs_ == 1 and is_classifier(self): - self.n_classes_ = self.n_classes_[0] - self.classes_ = self.classes_[0] diff --git a/sklearn/tree/_oblique_splitter.pxd b/sklearn/tree/_oblique_splitter.pxd deleted file mode 100644 index 91882793ced3b..0000000000000 --- a/sklearn/tree/_oblique_splitter.pxd +++ /dev/null @@ -1,95 +0,0 @@ -# distutils: language = c++ - -# Authors: Adam Li -# Chester Huynh -# Parth Vora -# -# License: BSD 3 clause - -# See _oblique_splitter.pyx for details. - -import numpy as np -cimport numpy as cnp - -from ._criterion cimport Criterion - -from ._tree cimport DTYPE_t # Type of X -from ._tree cimport DOUBLE_t # Type of y, sample_weight -from ._tree cimport SIZE_t # Type for indices and counters -from ._tree cimport INT32_t # Signed 32 bit integer -from ._tree cimport UINT32_t # Unsigned 32 bit integer - -from ._splitter cimport Splitter -from ._splitter cimport SplitRecord -from ..utils._sorting cimport simultaneous_sort - -from libcpp.vector cimport vector - -cdef struct ObliqueSplitRecord: - # Data to track sample split - SIZE_t feature # Which feature to split on. - SIZE_t pos # Split samples array at the given position, - # i.e. count of samples below threshold for feature. - # pos is >= end if the node is a leaf. - double threshold # Threshold to split at. - double improvement # Impurity improvement given parent node. - double impurity_left # Impurity of the left split. 
- double impurity_right # Impurity of the right split. - - vector[DTYPE_t]* proj_vec_weights # weights of the vector (max_features,) - vector[SIZE_t]* proj_vec_indices # indices of the features (max_features,) - - -cdef class BaseObliqueSplitter(Splitter): - # Base class for oblique splitting, where additional data structures and API is defined. - # - - # Oblique Splitting extra parameters - cdef vector[vector[DTYPE_t]] proj_mat_weights # nonzero weights of sparse proj_mat matrix - cdef vector[vector[SIZE_t]] proj_mat_indices # nonzero indices of sparse proj_mat matrix - - # TODO: assumes all oblique splitters only work with dense data - cdef const DTYPE_t[:, :] X - - # All oblique splitters (i.e. non-axis aligned splitters) require a - # function to sample a projection matrix that is applied to the feature matrix - # to quickly obtain the sampled projections for candidate splits. - cdef void sample_proj_mat( - self, - vector[vector[DTYPE_t]]& proj_mat_weights, - vector[vector[SIZE_t]]& proj_mat_indices - ) nogil - - # Redefined here since the new logic requires calling sample_proj_mat - cdef int node_reset( - self, - SIZE_t start, - SIZE_t end, - double* weighted_n_node_samples - ) except -1 nogil - - cdef int node_split( - self, - double impurity, # Impurity of the node - SplitRecord* split, - SIZE_t* n_constant_features - ) except -1 nogil - - -cdef class ObliqueSplitter(BaseObliqueSplitter): - # The splitter searches in the input space for a linear combination of features and a threshold - # to split the samples samples[start:end]. - - # Oblique Splitting extra parameters - cdef public double feature_combinations # Number of features to combine - cdef SIZE_t n_non_zeros # Number of non-zero features - cdef SIZE_t[::1] indices_to_sample # an array of indices to sample of size mtry X n_features - - # All oblique splitters (i.e. non-axis aligned splitters) require a - # function to sample a projection matrix that is applied to the feature matrix - # to quickly obtain the sampled projections for candidate splits. - cdef void sample_proj_mat( - self, - vector[vector[DTYPE_t]]& proj_mat_weights, - vector[vector[SIZE_t]]& proj_mat_indices - ) noexcept nogil diff --git a/sklearn/tree/_oblique_splitter.pyx b/sklearn/tree/_oblique_splitter.pyx deleted file mode 100644 index 98405a60cc409..0000000000000 --- a/sklearn/tree/_oblique_splitter.pyx +++ /dev/null @@ -1,545 +0,0 @@ -#distutils: language=c++ -#cython: language_level=3 -#cython: boundscheck=False -#cython: wraparound=False -#cython: profile=True - -cimport cython -import numpy as np -cimport numpy as cnp -cnp.import_array() - -from ._criterion cimport Criterion -from ._utils cimport rand_int - -from ._utils cimport log -from ._utils cimport rand_uniform -from ._utils cimport RAND_R_MAX - -from libcpp.vector cimport vector -from cython.operator cimport dereference as deref - - -cdef double INFINITY = np.inf - -# Mitigate precision differences between 32 bit and 64 bit -cdef DTYPE_t FEATURE_THRESHOLD = 1e-7 - -# Constant to switch between algorithm non zero value extract algorithm -# in SparseSplitter -cdef DTYPE_t EXTRACT_NNZ_SWITCH = 0.1 - - -cdef inline void _init_split(ObliqueSplitRecord* self, SIZE_t start_pos) noexcept nogil: - self.impurity_left = INFINITY - self.impurity_right = INFINITY - self.pos = start_pos - self.feature = 0 - self.threshold = 0. - self.improvement = -INFINITY - -cdef class BaseObliqueSplitter(Splitter): - """Abstract oblique splitter class. 
- - Splitters are called by tree builders to find the best splits on - both sparse and dense data, one split at a time. - """ - def __cinit__( - self, - Criterion criterion, - SIZE_t max_features, - SIZE_t min_samples_leaf, - double min_weight_leaf, - object random_state, - *argv - ): - """ - Parameters - ---------- - criterion : Criterion - The criterion to measure the quality of a split. - - max_features : SIZE_t - The maximal number of randomly selected features which can be - considered for a split. - - min_samples_leaf : SIZE_t - The minimal number of samples each leaf can have, where splits - which would result in having less samples in a leaf are not - considered. - - min_weight_leaf : double - The minimal weight each leaf can have, where the weight is the sum - of the weights of each sample in it. - - random_state : object - The user inputted random state to be used for pseudo-randomness - """ - self.criterion = criterion - - self.n_samples = 0 - self.n_features = 0 - - # Max features = output dimensionality of projection vectors - self.max_features = max_features - self.min_samples_leaf = min_samples_leaf - self.min_weight_leaf = min_weight_leaf - self.random_state = random_state - - # Sparse max_features x n_features projection matrix - self.proj_mat_weights = vector[vector[DTYPE_t]](self.max_features) - self.proj_mat_indices = vector[vector[SIZE_t]](self.max_features) - - def __getstate__(self): - return {} - - def __setstate__(self, d): - pass - - cdef int node_reset(self, SIZE_t start, SIZE_t end, - double* weighted_n_node_samples) except -1 nogil: - """Reset splitter on node samples[start:end]. - - Returns -1 in case of failure to allocate memory (and raise MemoryError) - or 0 otherwise. - - Parameters - ---------- - start : SIZE_t - The index of the first sample to consider - end : SIZE_t - The index of the last sample to consider - weighted_n_node_samples : ndarray, dtype=double pointer - The total weight of those samples - """ - - self.start = start - self.end = end - - self.criterion.init(self.y, - self.sample_weight, - self.weighted_n_samples, - self.samples) - self.criterion.set_sample_pointers(start, end) - - weighted_n_node_samples[0] = self.criterion.weighted_n_node_samples - - # Clear all projection vectors - for i in range(self.max_features): - self.proj_mat_weights[i].clear() - self.proj_mat_indices[i].clear() - return 0 - - cdef void sample_proj_mat( - self, - vector[vector[DTYPE_t]]& proj_mat_weights, - vector[vector[SIZE_t]]& proj_mat_indices - ) noexcept nogil: - """ Sample the projection vector. - - This is a placeholder method. - """ - pass - - cdef int pointer_size(self) noexcept nogil: - """Get size of a pointer to record for ObliqueSplitter.""" - - return sizeof(ObliqueSplitRecord) - - cdef int node_split( - self, - double impurity, - SplitRecord* split, - SIZE_t* n_constant_features - ) except -1 nogil: - """Find the best split on node samples[start:end] - - Returns -1 in case of failure to allocate memory (and raise MemoryError) - or 0 otherwise. 
- """ - # typecast the pointer to an ObliqueSplitRecord - cdef ObliqueSplitRecord* oblique_split = (split) - - # Draw random splits and pick the best - cdef SIZE_t[::1] samples = self.samples - cdef SIZE_t start = self.start - cdef SIZE_t end = self.end - - # pointer array to store feature values to split on - cdef DTYPE_t[::1] feature_values = self.feature_values - cdef SIZE_t max_features = self.max_features - cdef SIZE_t min_samples_leaf = self.min_samples_leaf - cdef double min_weight_leaf = self.min_weight_leaf - - # keep track of split record for current node and the best split - # found among the sampled projection vectors - cdef ObliqueSplitRecord best_split, current_split - cdef double current_proxy_improvement = -INFINITY - cdef double best_proxy_improvement = -INFINITY - - cdef SIZE_t feat_i, p # index over computed features and start/end - cdef SIZE_t idx, jdx # index over max_feature, and - cdef SIZE_t partition_end - cdef DTYPE_t temp_d # to compute a projection feature value - - # instantiate the split records - _init_split(&best_split, end) - - # Sample the projection matrix - self.sample_proj_mat(self.proj_mat_weights, self.proj_mat_indices) - - # For every vector in the projection matrix - for feat_i in range(max_features): - # Projection vector has no nonzeros - if self.proj_mat_weights[feat_i].empty(): - continue - - # XXX: 'feature' is not actually used in oblique split records - # Just indicates which split was sampled - current_split.feature = feat_i - current_split.proj_vec_weights = &self.proj_mat_weights[feat_i] - current_split.proj_vec_indices = &self.proj_mat_indices[feat_i] - - # Compute linear combination of features and then - # sort samples according to the feature values. - for idx in range(start, end): - # initialize the feature value to 0 - feature_values[idx] = 0 - for jdx in range(0, current_split.proj_vec_indices.size()): - feature_values[idx] += self.X[ - samples[idx], deref(current_split.proj_vec_indices)[jdx] - ] * deref(current_split.proj_vec_weights)[jdx] - - # Sort the samples - sort(&feature_values[start], &samples[start], end - start) - - # Evaluate all splits - self.criterion.reset() - p = start - while p < end: - while (p + 1 < end and feature_values[p + 1] <= feature_values[p] + FEATURE_THRESHOLD): - p += 1 - - p += 1 - - if p < end: - current_split.pos = p - - # Reject if min_samples_leaf is not guaranteed - if (((current_split.pos - start) < min_samples_leaf) or - ((end - current_split.pos) < min_samples_leaf)): - continue - - self.criterion.update(current_split.pos) - # Reject if min_weight_leaf is not satisfied - if ((self.criterion.weighted_n_left < min_weight_leaf) or - (self.criterion.weighted_n_right < min_weight_leaf)): - continue - - current_proxy_improvement = \ - self.criterion.proxy_impurity_improvement() - - if current_proxy_improvement > best_proxy_improvement: - best_proxy_improvement = current_proxy_improvement - # sum of halves is used to avoid infinite value - current_split.threshold = feature_values[p - 1] / 2.0 + feature_values[p] / 2.0 - - if ( - (current_split.threshold == feature_values[p]) or - (current_split.threshold == INFINITY) or - (current_split.threshold == -INFINITY) - ): - current_split.threshold = feature_values[p - 1] - - best_split = current_split # copy - - # Reorganize into samples[start:best.pos] + samples[best.pos:end] - if best_split.pos < end: - partition_end = end - p = start - - while p < partition_end: - # Account for projection vector - temp_d = 0.0 - for j in 
range(best_split.proj_vec_indices.size()): - temp_d += self.X[samples[p], deref(best_split.proj_vec_indices)[j]] *\ - deref(best_split.proj_vec_weights)[j] - - if temp_d <= best_split.threshold: - p += 1 - - else: - partition_end -= 1 - samples[p], samples[partition_end] = \ - samples[partition_end], samples[p] - - self.criterion.reset() - self.criterion.update(best_split.pos) - self.criterion.children_impurity(&best_split.impurity_left, - &best_split.impurity_right) - best_split.improvement = self.criterion.impurity_improvement( - impurity, best_split.impurity_left, best_split.impurity_right) - - # Return values - deref(oblique_split).proj_vec_indices = best_split.proj_vec_indices - deref(oblique_split).proj_vec_weights = best_split.proj_vec_weights - deref(oblique_split).feature = best_split.feature - deref(oblique_split).pos = best_split.pos - deref(oblique_split).threshold = best_split.threshold - deref(oblique_split).improvement = best_split.improvement - deref(oblique_split).impurity_left = best_split.impurity_left - deref(oblique_split).impurity_right = best_split.impurity_right - return 0 - - -cdef class ObliqueSplitter(BaseObliqueSplitter): - def __cinit__( - self, - Criterion criterion, - SIZE_t max_features, - SIZE_t min_samples_leaf, - double min_weight_leaf, - object random_state, - double feature_combinations, - *argv - ): - """ - Parameters - ---------- - criterion : Criterion - The criterion to measure the quality of a split. - - max_features : SIZE_t - The maximal number of randomly selected features which can be - considered for a split. - - min_samples_leaf : SIZE_t - The minimal number of samples each leaf can have, where splits - which would result in having less samples in a leaf are not - considered. - - min_weight_leaf : double - The minimal weight each leaf can have, where the weight is the sum - of the weights of each sample in it. - - feature_combinations : double - The average number of features to combine in an oblique split. - Each feature is independently included with probability - ``feature_combination`` / ``n_features``. - - random_state : object - The user inputted random state to be used for pseudo-randomness - """ - # Oblique tree parameters - self.feature_combinations = feature_combinations - - # or max w/ 1... - self.n_non_zeros = max(int(self.max_features * self.feature_combinations), 1) - - def __getstate__(self): - return {} - - def __setstate__(self, d): - pass - - cdef int init( - self, - object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight - ) except -1: - Splitter.init(self, X, y, sample_weight) - - self.X = X - - # create a helper array for allowing efficient Fisher-Yates - self.indices_to_sample = np.arange(self.max_features * self.n_features, - dtype=np.intp) - return 0 - - -cdef class BestObliqueSplitter(ObliqueSplitter): - def __reduce__(self): - """Enable pickling the splitter.""" - return (BestObliqueSplitter, - ( - self.criterion, - self.max_features, - self.min_samples_leaf, - self.min_weight_leaf, - self.feature_combinations, - self.random_state - ), self.__getstate__()) - - cdef void sample_proj_mat( - self, - vector[vector[DTYPE_t]]& proj_mat_weights, - vector[vector[SIZE_t]]& proj_mat_indices - ) noexcept nogil: - """Sample oblique projection matrix. - - Randomly sample features to put in randomly sampled projection vectors - weight = 1 or -1 with probability 0.5. - - Note: vectors are passed by value, so & is needed to pass by reference. 
- - Parameters - ---------- - proj_mat_weights : vector of vectors reference - The memory address of projection matrix non-zero weights. - proj_mat_indices : vector of vectors reference - The memory address of projection matrix non-zero indices. - - Notes - ----- - Note that grid_size must be larger than or equal to n_non_zeros because - it is assumed ``feature_combinations`` is forced to be smaller than - ``n_features`` before instantiating an oblique splitter. - """ - - cdef SIZE_t n_features = self.n_features - cdef SIZE_t n_non_zeros = self.n_non_zeros - cdef UINT32_t* random_state = &self.rand_r_state - - cdef int i, feat_i, proj_i, rand_vec_index - cdef DTYPE_t weight - - # construct an array to sample from mTry x n_features set of indices - cdef SIZE_t[::1] indices_to_sample = self.indices_to_sample - cdef SIZE_t grid_size = self.max_features * self.n_features - - # shuffle indices over the 2D grid to sample using Fisher-Yates - for i in range(0, grid_size): - j = rand_int(0, grid_size - i, random_state) - indices_to_sample[j], indices_to_sample[i] = \ - indices_to_sample[i], indices_to_sample[j] - - # sample 'n_non_zeros' in a mtry X n_features projection matrix - # which consists of +/- 1's chosen at a 1/2s rate - for i in range(0, n_non_zeros): - # get the next index from the shuffled index array - rand_vec_index = indices_to_sample[i] - - # get the projection index and feature index - proj_i = rand_vec_index // n_features - feat_i = rand_vec_index % n_features - - # sample a random weight - weight = 1 if (rand_int(0, 2, random_state) == 1) else -1 - - proj_mat_indices[proj_i].push_back(feat_i) # Store index of nonzero - proj_mat_weights[proj_i].push_back(weight) # Store weight of nonzero - - -# Sort n-element arrays pointed to by feature_values and samples, simultaneously, -# by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). -cdef inline void sort(DTYPE_t* feature_values, SIZE_t* samples, SIZE_t n) noexcept nogil: - if n == 0: - return - cdef int maxd = 2 * log(n) - introsort(feature_values, samples, n, maxd) - - -cdef inline void swap(DTYPE_t* feature_values, SIZE_t* samples, - SIZE_t i, SIZE_t j) noexcept nogil: - # Helper for sort - feature_values[i], feature_values[j] = feature_values[j], feature_values[i] - samples[i], samples[j] = samples[j], samples[i] - - -cdef inline DTYPE_t median3(DTYPE_t* feature_values, SIZE_t n) noexcept nogil: - # Median of three pivot selection, after Bentley and McIlroy (1993). - # Engineering a sort function. SP&E. Requires 8/3 comparisons on average. - cdef DTYPE_t a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1] - if a < b: - if b < c: - return b - elif a < c: - return c - else: - return a - elif b < c: - if a < c: - return a - else: - return c - else: - return b - - -# Introsort with median of 3 pivot selection and 3-way partition function -# (robust to repeated elements, e.g. lots of zero features). -cdef void introsort(DTYPE_t* feature_values, SIZE_t *samples, - SIZE_t n, int maxd) noexcept nogil: - cdef DTYPE_t pivot - cdef SIZE_t i, l, r - - while n > 1: - if maxd <= 0: # max depth limit exceeded ("gone quadratic") - heapsort(feature_values, samples, n) - return - maxd -= 1 - - pivot = median3(feature_values, n) - - # Three-way partition. 
- i = l = 0 - r = n - while i < r: - if feature_values[i] < pivot: - swap(feature_values, samples, i, l) - i += 1 - l += 1 - elif feature_values[i] > pivot: - r -= 1 - swap(feature_values, samples, i, r) - else: - i += 1 - - introsort(feature_values, samples, l, maxd) - feature_values += r - samples += r - n -= r - - -cdef inline void sift_down(DTYPE_t* feature_values, SIZE_t* samples, - SIZE_t start, SIZE_t end) noexcept nogil: - # Restore heap order in feature_values[start:end] by moving the max element to start. - cdef SIZE_t child, maxind, root - - root = start - while True: - child = root * 2 + 1 - - # find max of root, left child, right child - maxind = root - if child < end and feature_values[maxind] < feature_values[child]: - maxind = child - if child + 1 < end and feature_values[maxind] < feature_values[child + 1]: - maxind = child + 1 - - if maxind == root: - break - else: - swap(feature_values, samples, root, maxind) - root = maxind - - -cdef void heapsort(DTYPE_t* feature_values, SIZE_t* samples, SIZE_t n) noexcept nogil: - cdef SIZE_t start, end - - # heapify - start = (n - 2) / 2 - end = n - while True: - sift_down(feature_values, samples, start, end) - if start == 0: - break - start -= 1 - - # sort by shrinking the heap, putting the max element immediately after it - end = n - 1 - while end > 0: - swap(feature_values, samples, 0, end) - sift_down(feature_values, samples, 0, end) - end = end - 1 diff --git a/sklearn/tree/_oblique_tree.pxd b/sklearn/tree/_oblique_tree.pxd deleted file mode 100644 index 41c2b7acc95ce..0000000000000 --- a/sklearn/tree/_oblique_tree.pxd +++ /dev/null @@ -1,53 +0,0 @@ -# distutils: language = c++ - -# Authors: Adam Li -# Chester Huynh -# Parth Vora -# -# License: BSD 3 clause - -# See _oblique_tree.pyx for details. 
- -import numpy as np -cimport numpy as cnp - -from libcpp.vector cimport vector - -from ._tree cimport DTYPE_t # Type of X -from ._tree cimport DOUBLE_t # Type of y, sample_weight -from ._tree cimport SIZE_t # Type for indices and counters -from ._tree cimport INT32_t # Signed 32 bit integer -from ._tree cimport UINT32_t # Unsigned 32 bit integer -from ._tree cimport Tree, Node, TreeBuilder - -from ._splitter cimport SplitRecord -from ._oblique_splitter cimport ObliqueSplitRecord - - -cdef class ObliqueTree(Tree): - cdef vector[vector[DTYPE_t]] proj_vec_weights # (capacity, n_features) array of projection vectors - cdef vector[vector[SIZE_t]] proj_vec_indices # (capacity, n_features) array of projection vectors - - # overridden methods - cdef int _resize_c( - self, - SIZE_t capacity=* - ) except -1 nogil - cdef int _set_split_node( - self, - SplitRecord* split_node, - Node *node - ) nogil except -1 - cdef DTYPE_t _compute_feature( - self, - const DTYPE_t[:, :] X_ndarray, - SIZE_t sample_index, - Node *node - ) noexcept nogil - cdef void _compute_feature_importances( - self, - cnp.float64_t[:] importances, - Node* node - ) noexcept nogil - - cpdef cnp.ndarray get_projection_matrix(self) diff --git a/sklearn/tree/_oblique_tree.pyx b/sklearn/tree/_oblique_tree.pyx deleted file mode 100644 index c73bc3c2376ed..0000000000000 --- a/sklearn/tree/_oblique_tree.pyx +++ /dev/null @@ -1,320 +0,0 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False - -# Authors: Adam Li -# Chester Huynh -# Parth Vora -# -# License: BSD 3 clause - -from libc.string cimport memcpy -from libc.string cimport memset -from libc.stdint cimport INTPTR_MAX - -import numpy as np -cimport numpy as cnp -cnp.import_array() - -from scipy.sparse import issparse -from scipy.sparse import csr_matrix - -from cython.operator cimport dereference as deref - -from ._utils cimport safe_realloc -from ._utils cimport sizet_ptr_to_ndarray - -# Gets Node dtype exposed inside oblique_tree. -# See "_tree.pyx" for more details. -cdef Node dummy; -NODE_DTYPE = np.asarray((&dummy)).dtype - -# Mitigate precision differences between 32 bit and 64 bit -cdef DTYPE_t FEATURE_THRESHOLD = 1e-7 - -# ============================================================================= -# ObliqueTree -# ============================================================================= - -cdef class ObliqueTree(Tree): - """Array-based representation of a binary oblique decision tree. - - The oblique decision tree is represented as a number of parallel arrays. The i-th - element of each array holds information about the node `i`. Node 0 is the - tree's root. You can find a detailed description of all arrays in - `_tree.pxd`. NOTE: Some of the arrays only apply to either leaves or split - nodes, resp. In this case the values of nodes of the other type are - arbitrary! - - Attributes - ---------- - node_count : int - The number of nodes (internal nodes + leaves) in the tree. - - capacity : int - The current capacity (i.e., size) of the arrays, which is at least as - great as `node_count`. - - max_depth : int - The depth of the tree, i.e. the maximum depth of its leaves. - - children_left : array of int, shape [node_count] - children_left[i] holds the node id of the left child of node i. - For leaves, children_left[i] == TREE_LEAF. Otherwise, - children_left[i] > i. This child handles the case where - X[:, feature[i]] <= threshold[i]. 
- - children_right : array of int, shape [node_count] - children_right[i] holds the node id of the right child of node i. - For leaves, children_right[i] == TREE_LEAF. Otherwise, - children_right[i] > i. This child handles the case where - X[:, feature[i]] > threshold[i]. - - feature : array of int, shape [node_count] - feature[i] holds the feature to split on, for the internal node i. - - threshold : array of double, shape [node_count] - threshold[i] holds the threshold for the internal node i. - - value : array of double, shape [node_count, n_outputs, max_n_classes] - Contains the constant prediction value of each node. - - impurity : array of double, shape [node_count] - impurity[i] holds the impurity (i.e., the value of the splitting - criterion) at node i. - - n_node_samples : array of int, shape [node_count] - n_node_samples[i] holds the number of training samples reaching node i. - - weighted_n_node_samples : array of int, shape [node_count] - weighted_n_node_samples[i] holds the weighted number of training samples - reaching node i. - """ - def __cinit__( - self, - int n_features, - cnp.ndarray[SIZE_t, ndim=1] n_classes, - int n_outputs - ): - """Constructor.""" - # Input/Output layout - self.n_features = n_features - self.n_outputs = n_outputs - self.n_classes = NULL - safe_realloc(&self.n_classes, n_outputs) - - self.max_n_classes = np.max(n_classes) - self.value_stride = n_outputs * self.max_n_classes - - cdef SIZE_t k - for k in range(n_outputs): - self.n_classes[k] = n_classes[k] - - # Inner structures - self.max_depth = 0 - self.node_count = 0 - self.capacity = 0 - self.value = NULL - self.nodes = NULL - - self.proj_vec_weights = vector[vector[DTYPE_t]](self.capacity) - self.proj_vec_indices = vector[vector[SIZE_t]](self.capacity) - - def __reduce__(self): - """Reduce re-implementation, for pickling.""" - return (ObliqueTree, ( - self.n_features, - sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), - self.n_outputs), self.__getstate__() - ) - - def __getstate__(self): - """Getstate re-implementation, for pickling.""" - d = {} - # capacity is inferred during the __setstate__ using nodes - d["max_depth"] = self.max_depth - d["node_count"] = self.node_count - d["nodes"] = self._get_node_ndarray() - d["values"] = self._get_value_ndarray() - - proj_vecs = self.get_projection_matrix() - d['proj_vecs'] = proj_vecs - return d - - def __setstate__(self, d): - """Setstate re-implementation, for unpickling.""" - self.max_depth = d["max_depth"] - self.node_count = d["node_count"] - - if 'nodes' not in d: - raise ValueError('You have loaded ObliqueTree version which ' - 'cannot be imported') - - node_ndarray = d['nodes'] - value_ndarray = d['values'] - - value_shape = (node_ndarray.shape[0], self.n_outputs, self.max_n_classes) - if (node_ndarray.ndim != 1 or - node_ndarray.dtype != NODE_DTYPE or - not node_ndarray.flags.c_contiguous or - value_ndarray.shape != value_shape or - not value_ndarray.flags.c_contiguous or - value_ndarray.dtype != np.float64): - raise ValueError('Did not recognise loaded array layout') - - self.capacity = node_ndarray.shape[0] - if self._resize_c(self.capacity) != 0: - raise MemoryError("resizing tree to %d" % self.capacity) - - # now set the projection vector weights and indices - proj_vecs = d['proj_vecs'] - self.n_features = proj_vecs.shape[1] - for i in range(0, self.node_count): - for j in range(0, self.n_features): - weight = proj_vecs[i, j] - if weight == 0: - continue - self.proj_vec_weights[i].push_back(weight) - self.proj_vec_indices[i].push_back(j) 
- - nodes = memcpy(self.nodes, cnp.PyArray_DATA(node_ndarray), - self.capacity * sizeof(Node)) - value = memcpy(self.value, cnp.PyArray_DATA(value_ndarray), - self.capacity * self.value_stride * sizeof(double)) - - cpdef cnp.ndarray get_projection_matrix(self): - """Get the projection matrix of shape (node_count, n_features).""" - proj_vecs = np.zeros((self.node_count, self.n_features), dtype=np.float64) - for i in range(0, self.node_count): - for j in range(0, self.proj_vec_weights[i].size()): - weight = self.proj_vec_weights[i][j] - feat = self.proj_vec_indices[i][j] - proj_vecs[i, feat] = weight - return proj_vecs - - cdef int _resize_c(self, SIZE_t capacity=INTPTR_MAX) except -1 nogil: - """Guts of _resize. - - Additionally resizes the projection indices and weights. - - Returns -1 in case of failure to allocate memory (and raise MemoryError) - or 0 otherwise. - """ - if capacity == self.capacity and self.nodes != NULL: - return 0 - - if capacity == INTPTR_MAX: - if self.capacity == 0: - capacity = 3 # default initial value - else: - capacity = 2 * self.capacity - - safe_realloc(&self.nodes, capacity) - safe_realloc(&self.value, capacity * self.value_stride) - - # only thing added for oblique trees - # TODO: this could possibly be removed if we can add projection - # indices and weights to Node - self.proj_vec_weights.resize(capacity) - self.proj_vec_indices.resize(capacity) - - # value memory is initialised to 0 to enable classifier argmax - if capacity > self.capacity: - memset((self.value + self.capacity * self.value_stride), 0, - (capacity - self.capacity) * self.value_stride * sizeof(double)) - - # if capacity smaller than node_count, adjust the counter - if capacity < self.node_count: - self.node_count = capacity - - self.capacity = capacity - return 0 - - cdef int _set_split_node(self, SplitRecord* split_node, Node *node) except -1 nogil: - """Set node data. - """ - # Cython type cast split record into its inherited split record - # For reference, see: - # https://www.codementor.io/@arpitbhayani/powering-inheritance-in-c-using-structure-composition-176sygr724 - cdef ObliqueSplitRecord* oblique_split_node = (split_node) - cdef SIZE_t node_id = self.node_count - - node.feature = deref(oblique_split_node).feature - node.threshold = deref(oblique_split_node).threshold - - # oblique trees store the projection indices and weights - # inside the tree itself - self.proj_vec_weights[node_id] = deref( - deref(oblique_split_node).proj_vec_weights - ) - self.proj_vec_indices[node_id] = deref( - deref(oblique_split_node).proj_vec_indices - ) - return 1 - - cdef DTYPE_t _compute_feature( - self, - const DTYPE_t[:, :] X_ndarray, - SIZE_t sample_index, - Node *node - ) noexcept nogil: - """Compute feature from a given data matrix, X. - - In oblique-aligned trees, this is the projection of X. - In this case, we take a simple linear combination of some columns of X. 
- """ - cdef DTYPE_t proj_feat = 0.0 - cdef DTYPE_t weight = 0.0 - cdef int j = 0 - cdef SIZE_t feature_index - - # get the index of the node - cdef SIZE_t node_id = node - self.nodes - - # cdef SIZE_t n_projections = proj_vec_indices.size() - # compute projection of the data based on trained tree - # proj_vec_weights = self.proj_vec_weights[node_id] - # proj_vec_indices = self.proj_vec_indices[node_id] - for j in range(0, self.proj_vec_indices[node_id].size()): - feature_index = self.proj_vec_indices[node_id][j] - weight = self.proj_vec_weights[node_id][j] - - # skip a multiplication step if there is nothing to be done - if weight == 0.0: - continue - proj_feat += X_ndarray[sample_index, feature_index] * weight - - return proj_feat - - cdef void _compute_feature_importances( - self, - cnp.float64_t[:] importances, - Node* node - ) noexcept nogil: - """Compute feature importances from a Node in the Tree. - - Wrapped in a private function to allow subclassing that - computes feature importances. - """ - cdef Node* nodes = self.nodes - cdef Node* left - cdef Node* right - - # get the index of the node - cdef SIZE_t node_id = node - self.nodes - - left = &nodes[node.left_child] - right = &nodes[node.right_child] - - cdef int i, feature_index - cdef DTYPE_t weight - for i in range(0, self.proj_vec_indices[node_id].size()): - feature_index = self.proj_vec_indices[node_id][i] - weight = self.proj_vec_weights[node_id][i] - if weight < 0: - weight *= -1 - - importances[feature_index] += weight * ( - node.weighted_n_node_samples * node.impurity - - left.weighted_n_node_samples * left.impurity - - right.weighted_n_node_samples * right.impurity) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 89d41bc59b2b0..69f948839259a 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -27,7 +27,6 @@ from sklearn.metrics import mean_poisson_deviance from sklearn.model_selection import train_test_split -from sklearn.model_selection import cross_val_score from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_array_almost_equal @@ -46,7 +45,6 @@ from sklearn.tree import DecisionTreeRegressor from sklearn.tree import ExtraTreeClassifier from sklearn.tree import ExtraTreeRegressor -from sklearn.tree import ObliqueDecisionTreeClassifier from sklearn import tree from sklearn.tree._tree import TREE_LEAF, TREE_UNDEFINED @@ -70,7 +68,6 @@ CLF_TREES = { "DecisionTreeClassifier": DecisionTreeClassifier, "ExtraTreeClassifier": ExtraTreeClassifier, - "ObliqueDecisionTreeClassifier": ObliqueDecisionTreeClassifier, } REG_TREES = { @@ -458,10 +455,7 @@ def test_importances(): n_important = np.sum(importances > 0.1) assert importances.shape[0] == 10, "Failed with {0}".format(name) - if "Oblique" in name: - assert n_important >= 4, "Failed with {0}".format(name) - else: - assert n_important == 4, "Failed with {0}".format(name) + assert n_important == 4, "Failed with {0}".format(name) # Check on iris that importances are the same for all builders clf = DecisionTreeClassifier(random_state=0) @@ -472,9 +466,7 @@ def test_importances(): assert_array_equal(clf.feature_importances_, clf2.feature_importances_) -@pytest.mark.parametrize( - "clf", [DecisionTreeClassifier(), ObliqueDecisionTreeClassifier()] -) +@pytest.mark.parametrize("clf", [DecisionTreeClassifier()]) def test_importances_raises(clf): # Check if variable importance before fit raises ValueError. 
with pytest.raises(ValueError): @@ -927,12 +919,6 @@ def test_pickle(): est2 = pickle.loads(serialized_object) assert type(est2) == est.__class__ - # Oblique decision trees should have matching projection matrices - if name == "ObliqueDecisionTreeClassifier": - est_proj_mat = est.tree_.get_projection_matrix() - est2_proj_mat = est2.tree_.get_projection_matrix() - assert_array_equal(est_proj_mat, est2_proj_mat) - # score should match before/after pickling score2 = est2.score(X, y) assert ( @@ -1072,10 +1058,6 @@ def test_memory_layout(): y = iris.target[::3] assert_array_equal(est.fit(X, y).predict(X), y) - # Oblique trees do not support sparse data - if name == "ObliqueDecisionTreeClassifier": - continue - # csr matrix X = csr_matrix(iris.data, dtype=dtype) y = iris.target @@ -1278,8 +1260,8 @@ def test_behaviour_constant_feature_after_splits(): ) y = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3] for name, TreeEstimator in ALL_TREES.items(): - # do not check extra random trees or oblique trees - if all(_name not in name for _name in ["ExtraTree", "Oblique"]): + # do not check extra random trees + if all(_name not in name for _name in ["ExtraTree"]): est = TreeEstimator(random_state=0, max_features=1) est.fit(X, y) assert est.tree_.max_depth == 2 @@ -1291,8 +1273,6 @@ def test_with_only_one_non_constant_features(): y = np.array([0.0, 1.0, 0.0, 1.0]) for name, TreeEstimator in CLF_TREES.items(): - if name == "ObliqueDecisionTreeClassifier": - continue est = TreeEstimator(random_state=0, max_features=1) est.fit(X, y) assert est.tree_.max_depth == 1 @@ -1592,12 +1572,7 @@ def test_1d_input(name): def _check_min_weight_leaf_split_level(TreeEstimator, X, y, sample_weight): est = TreeEstimator(random_state=0) est.fit(X, y, sample_weight=sample_weight) - - # Oblique trees are more shallow by default - if isinstance(TreeEstimator, ObliqueDecisionTreeClassifier): - assert est.tree_.max_depth == 0 - else: - assert est.tree_.max_depth == 1 + assert est.tree_.max_depth == 1 est = TreeEstimator(random_state=0, min_weight_fraction_leaf=0.4) est.fit(X, y, sample_weight=sample_weight) @@ -1613,8 +1588,6 @@ def check_min_weight_leaf_split_level(name): _check_min_weight_leaf_split_level(TreeEstimator, X, y, sample_weight) # skip for sparse inputs - if name == "ObliqueDecisionTreeClassifier": - pytest.skip() _check_min_weight_leaf_split_level(TreeEstimator, csc_matrix(X), y, sample_weight) @@ -1674,9 +1647,7 @@ def check_decision_path(name): leaves = est.apply(X) leave_indicator = [node_indicator[i, j] for i, j in enumerate(leaves)] - # Oblique trees have possibly different leaves - if "Oblique" not in name: - assert_array_almost_equal(leave_indicator, np.ones(shape=n_samples)) + assert_array_almost_equal(leave_indicator, np.ones(shape=n_samples)) # Ensure only one leave node per sample all_leaves = est.tree_.children_left == TREE_LEAF @@ -1963,15 +1934,9 @@ def test_apply_path_readonly_all_trees(name, splitter, X_format): dataset = DATASETS["clf_small"] X_small = dataset["X"].astype(tree._tree.DTYPE, copy=False) - if name == "ObliqueDecisionTreeClassifier" and splitter == "random": - pytest.skip() - if X_format == "dense": X_readonly = create_memmap_backed_data(X_small) else: - if name == "ObliqueDecisionTreeClassifier": - pytest.skip() - X_readonly = dataset["X_sparse"] # CSR if X_format == "csc": # Cheap CSR to CSC conversion @@ -2453,33 +2418,3 @@ def test_min_sample_split_1_error(Tree): ) with pytest.raises(ValueError, match=msg): tree.fit(X, y) - - -def test_oblique_tree_sampling(): - """Test Oblique Decision 
Trees.
-
-    Oblique trees can sample more candidate splits then
-    a normal axis-aligned tree.
-    """
-    X, y = iris.data, iris.target
-    n_samples, n_features = X.shape
-
-    # add additional noise dimensions
-    rng = np.random.RandomState(0)
-    X_noise = rng.random((n_samples, n_features))
-    X = np.concatenate((X, X_noise), axis=1)
-
-    # oblique decision trees can sample significantly more
-    # diverse sets of splits and will do better if allowed
-    # to sample more
-    tree_ri = DecisionTreeClassifier(random_state=0, max_features=n_features)
-    tree_rc = ObliqueDecisionTreeClassifier(random_state=0, max_features=n_features * 2)
-    ri_cv_scores = cross_val_score(
-        tree_ri, X, y, scoring="accuracy", cv=10, error_score="raise"
-    )
-    rc_cv_scores = cross_val_score(
-        tree_rc, X, y, scoring="accuracy", cv=10, error_score="raise"
-    )
-    assert rc_cv_scores.mean() > ri_cv_scores.mean()
-    assert rc_cv_scores.std() < ri_cv_scores.std()
-    assert rc_cv_scores.mean() > 0.91
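
The new ``*_withbins`` tests added above exercise a ``max_bins`` keyword on the
forest estimators. Below is a minimal usage sketch, assuming the fork-specific
``max_bins`` signature shown in ``test_classification_toy_withbins``; the
dataset, estimator sizes, and printed checks are illustrative only and not part
of the patch::

    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier

    X, y = make_classification(n_samples=200, n_features=10, random_state=0)

    # Default behaviour: no binning of the input features.
    clf_plain = RandomForestClassifier(n_estimators=10, random_state=1).fit(X, y)

    # Histogram-style binning of features, as in the new tests (max_bins=255).
    clf_binned = RandomForestClassifier(
        n_estimators=10, random_state=1, max_bins=255
    ).fit(X, y)

    # The public API is unchanged: predict works as usual, and apply returns
    # one leaf index per estimator, i.e. shape (n_samples, n_estimators).
    print(clf_binned.predict(X[:5]))
    print(clf_binned.apply(X[:5]).shape)

As asserted in the new tests, ``clf.apply(X)`` keeps the shape
``(len(X), clf.n_estimators)``, so downstream code that consumes leaf indices is
unaffected by whether binning is enabled.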