diff --git a/docs/sources/CHANGELOG.html b/docs/sources/CHANGELOG.html deleted file mode 100644 index 98abe2dc6..000000000 --- a/docs/sources/CHANGELOG.html +++ /dev/null @@ -1,1469 +0,0 @@ - - -CHANGELOG

Release Notes

-
-

The CHANGELOG for the current development version is available at
-https://github.com/rasbt/mlxtend/blob/master/docs/sources/CHANGELOG.md.

-
-

Version 0.9.1 (2017-11-19)

-
Downloads
- -
New Features
- -
Changes
- -
Bug Fixes
- -

Version 0.9.0 (2017-10-21)

-
Downloads
- -
New Features
- -
Changes
- -
Bug Fixes
- -

Version 0.8.0 (2017-09-09)

-
Downloads
- -
New Features
- -
Changes
- -
Bug Fixes
- -

Version 0.7.0 (2017-06-22)

-
Downloads
- -
New Features
- -
Changes
- -
Bug Fixes
- -

Version 0.6.0 (2017-03-18)

-
Downloads
- -
New Features
- -
Changes
- -
Bug Fixes
- -

Version 0.5.1 (2017-02-14)

-
Downloads
- -
New Features
- -
Changes
- -
Bug Fixes
- -

Version 0.5.0 (2016-11-09)

-
Downloads
- -
New Features
- -
Changes
- -
Bug Fixes
- -

Version 0.4.2 (2016-08-24)

-
Downloads
- -
New Features
- -
Changes
- -
Bug Fixes
- -

Version 0.4.1 (2016-05-01)

-
Downloads
- -
New Features
- -
Changes
- -

Version 0.4.0 (2016-04-09)

-
New Features
- -
Changes
- -

Version 0.3.0 (2016-01-31)

-
Downloads
- -
New Features
- -
Changes
- -

Version 0.2.9 (2015-07-14)

-
Downloads
- -
New Features
- -
Changes
- -

Version 0.2.8 (2015-06-27)

- -

Version 0.2.7 (2015-06-20)

- -

Version 0.2.6 (2015-05-08)

- -

Version 0.2.5 (2015-04-17)

- -

Version 0.2.4 (2015-03-15)

- -

Version 0.2.3 (2015-03-11)

- -

Version 0.2.2 (2015-03-01)

- -

Version 0.2.1 (2015-01-20)

- -

Version 0.2.0 (2015-01-13)

- -

Version 0.1.9 (2015-01-10)

- -

Version 0.1.8 (2015-01-07)

- -

Version 0.1.7 (2015-01-07)

- -

Version 0.1.6 (2015-01-04)

- -

Version 0.1.5 (2014-12-11)

- -

Version 0.1.4 (2014-08-20)

- -

Version 0.1.3 (2014-08-19)

- -

Version 0.1.2 (2014-08-19)

- -

Version 0.1.1 (2014-08-13)

-
\ No newline at end of file diff --git a/docs/sources/CHANGELOG.md b/docs/sources/CHANGELOG.md index 108df5271..64733fa3b 100755 --- a/docs/sources/CHANGELOG.md +++ b/docs/sources/CHANGELOG.md @@ -18,9 +18,9 @@ The CHANGELOG for the current development version is available at ##### New Features - A new `feature_importance_permuation` function to compute the feature importance in classifiers and regressors via the *permutation importance* method ([#358](https://github.com/rasbt/mlxtend/pull/358)) -- The fit method of the ExhaustiveFeatureSelector now optionally accepts **fit_params for the estimator that is used for the feature selection. ([#354](https://github.com/rasbt/mlxtend/pull/354) by Zach Griffith) -- The fit method of the SequentialFeatureSelector now optionally accepts -**fit_params for the estimator that is used for the feature selection. ([#350](https://github.com/rasbt/mlxtend/pull/350) by Zach Griffith) +- The fit method of the `ExhaustiveFeatureSelector` now optionally accepts `**fit_params` for the estimator that is used for the feature selection. ([#354](https://github.com/rasbt/mlxtend/pull/354) by Zach Griffith) +- The fit method of the `SequentialFeatureSelector` now optionally accepts +`**fit_params` for the estimator that is used for the feature selection. ([#350](https://github.com/rasbt/mlxtend/pull/350) by Zach Griffith) ##### Changes @@ -34,6 +34,7 @@ The CHANGELOG for the current development version is available at ##### Bug Fixes - Various changes in the documentation and documentation tools to fix formatting issues ([#363](https://github.com/rasbt/mlxtend/pull/363)) +- Fixes a bug where the `StackingCVClassifier`'s meta features were not stored in the original order when `shuffle=True` ([#370](https://github.com/rasbt/mlxtend/pull/370)) diff --git a/docs/sources/user_guide/classifier/StackingCVClassifier.ipynb b/docs/sources/user_guide/classifier/StackingCVClassifier.ipynb index 92e8fd7e2..0b845923e 100644 --- a/docs/sources/user_guide/classifier/StackingCVClassifier.ipynb +++ b/docs/sources/user_guide/classifier/StackingCVClassifier.ipynb @@ -517,7 +517,7 @@ "text": [ "## StackingCVClassifier\n", "\n", - "*StackingCVClassifier(classifiers, meta_classifier, use_probas=False, cv=2, use_features_in_secondary=False, stratify=True, shuffle=True, verbose=0, store_train_meta_features=False, refit=True)*\n", + "*StackingCVClassifier(classifiers, meta_classifier, use_probas=False, cv=2, use_features_in_secondary=False, stratify=True, shuffle=True, verbose=0, store_train_meta_features=False, use_clones=True)*\n", "\n", "A 'Stacking Cross-Validation' classifier for scikit-learn estimators.\n", "\n", @@ -601,15 +601,19 @@ " `self.train_meta_features_` array, which can be\n", " accessed after calling `fit`.\n", "\n", - "- `refit` : bool (default: True)\n", + "- `use_clones` : bool (default: True)\n", "\n", " Clones the classifiers for stacking classification if True (default)\n", " or else uses the original ones, which will be refitted on the dataset\n", - " upon calling the `fit` method. Setting refit=False is\n", + " upon calling the `fit` method. Hence, if use_clones=True, the original\n", + " input classifiers will remain unmodified upon using the\n", + " StackingCVClassifier's `fit` method.\n", + " Setting `use_clones=False` is\n", " recommended if you are working with estimators that are supporting\n", " the scikit-learn fit/predict API interface but are not compatible\n", " to scikit-learn's `clone` function.\n", "\n", + "\n", "**Attributes**\n", "\n", "- `clfs_` : list, shape=[n_classifiers]\n", @@ -801,15 +805,6 @@ "with open('../../api_modules/mlxtend.classifier/StackingCVClassifier.md', 'r') as f:\n", " print(f.read())" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [] } ], "metadata": { @@ -829,7 +824,19 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.3" + "version": "3.6.4" + }, + "toc": { + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false } }, "nbformat": 4, diff --git a/mlxtend/classifier/stacking_cv_classification.py b/mlxtend/classifier/stacking_cv_classification.py index e86b3b0b6..60c41e36a 100644 --- a/mlxtend/classifier/stacking_cv_classification.py +++ b/mlxtend/classifier/stacking_cv_classification.py @@ -229,7 +229,14 @@ def fit(self, X, y, groups=None): single_model_prediction]) if self.store_train_meta_features: - self.train_meta_features_ = all_model_predictions + # Store the meta features in the order of the + # original X,y arrays + reodered_indices = np.array([]).astype(y.dtype) + for train_index, test_index in skf: + reodered_indices = np.concatenate((reodered_indices, + test_index)) + self.train_meta_features_ = all_model_predictions[np.argsort( + reodered_indices)] # We have to shuffle the labels in the same order as we generated # predictions during CV (we kinda shuffled them when we did diff --git a/mlxtend/classifier/tests/test_stacking_cv_classifier.py b/mlxtend/classifier/tests/test_stacking_cv_classifier.py index 1efd0e469..4b4d9bad6 100644 --- a/mlxtend/classifier/tests/test_stacking_cv_classifier.py +++ b/mlxtend/classifier/tests/test_stacking_cv_classifier.py @@ -19,9 +19,13 @@ from sklearn.model_selection import KFold from sklearn.model_selection import cross_val_score from sklearn.model_selection import train_test_split +from sklearn.metrics import roc_auc_score iris = datasets.load_iris() -X, y = iris.data[:, 1:3], iris.target +X_iris, y_iris = iris.data[:, 1:3], iris.target + +breast_cancer = datasets.load_breast_cancer() +X_breast, y_breast = breast_cancer.data[:, 1:3], breast_cancer.target def test_StackingClassifier(): @@ -34,8 +38,8 @@ def test_StackingClassifier(): shuffle=False) scores = cross_val_score(sclf, - X, - y, + X_iris, + y_iris, cv=5, scoring='accuracy') scores_mean = (round(scores.mean(), 2)) @@ -53,8 +57,8 @@ def test_StackingClassifier_proba(): shuffle=False) scores = cross_val_score(sclf, - X, - y, + X_iris, + y_iris, cv=5, scoring='accuracy') scores_mean = (round(scores.mean(), 2)) @@ -112,8 +116,8 @@ def test_use_probas(): shuffle=False) scores = cross_val_score(sclf, - X, - y, + X_iris, + y_iris, cv=5, scoring='accuracy') scores_mean = (round(scores.mean(), 2)) @@ -131,8 +135,8 @@ def test_use_features_in_secondary(): shuffle=False) scores = cross_val_score(sclf, - X, - y, + X_iris, + y_iris, cv=5, scoring='accuracy') scores_mean = (round(scores.mean(), 2)) @@ -149,8 +153,8 @@ def test_do_not_stratify(): stratify=False) scores = cross_val_score(sclf, - X, - y, + X_iris, + y_iris, cv=5, scoring='accuracy') scores_mean = (round(scores.mean(), 2)) @@ -171,8 +175,8 @@ def test_cross_validation_technique(): cv=cv) scores = cross_val_score(sclf, - X, - y, + X_iris, + y_iris, cv=5, scoring='accuracy') scores_mean = (round(scores.mean(), 2)) @@ -224,7 +228,7 @@ def test_verbose(): def test_list_of_lists(): - X_list = [i for i in X] + X_list = [i for i in X_iris] meta = LogisticRegression() clf1 = RandomForestClassifier() clf2 = GaussianNB() @@ -241,7 +245,7 @@ def test_list_of_lists(): def test_pandas(): - X_df = pd.DataFrame(X) + X_df = pd.DataFrame(X_iris) meta = LogisticRegression() clf1 = RandomForestClassifier() clf2 = GaussianNB() @@ -296,7 +300,7 @@ def test_classifier_gridsearch(): param_grid=params, cv=5, refit=True) - grid.fit(X, y) + grid.fit(X_iris, y_iris) assert len(grid.best_params_['classifiers']) == 3 @@ -308,7 +312,8 @@ def test_train_meta_features_(): stclf = StackingCVClassifier(classifiers=[knn, gnb], meta_classifier=lr, store_train_meta_features=True) - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) + X_train, X_test, y_train, y_test = train_test_split(X_iris, y_iris, + test_size=0.3) stclf.fit(X_train, y_train) train_meta_features = stclf.train_meta_features_ assert train_meta_features.shape == (X_train.shape[0], 2) @@ -318,8 +323,8 @@ def test_predict_meta_features(): knn = KNeighborsClassifier() lr = LogisticRegression() gnb = GaussianNB() - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) - + X_train, X_test, y_train, y_test = train_test_split(X_iris, y_iris, + test_size=0.3) # test default (class labels) stclf = StackingCVClassifier(classifiers=[knn, gnb], meta_classifier=lr, @@ -327,3 +332,19 @@ def test_predict_meta_features(): stclf.fit(X_train, y_train) test_meta_features = stclf.predict(X_test) assert test_meta_features.shape == (X_test.shape[0],) + + +def test_meta_feat_reordering(): + knn = KNeighborsClassifier() + lr = LogisticRegression() + gnb = GaussianNB() + stclf = StackingCVClassifier(classifiers=[knn, gnb], + meta_classifier=lr, + shuffle=True, + store_train_meta_features=True) + X_train, X_test, y_train, y_test = train_test_split(X_breast, y_breast, + test_size=0.3) + stclf.fit(X_train, y_train) + + assert round(roc_auc_score(y_train, + stclf.train_meta_features_[:, 1]), 2) == 0.88