diff --git a/eli5/xgboost.py b/eli5/xgboost.py index 56db0504..c8ef78ca 100644 --- a/eli5/xgboost.py +++ b/eli5/xgboost.py @@ -243,7 +243,8 @@ def _prediction_feature_weights(booster, dmatrix, n_targets, http://blog.datadive.net/interpreting-random-forests/ """ # XGBClassifier does not have pred_leaf argument, so use booster - leaf_ids, = booster.predict(dmatrix, pred_leaf=True) + predictions = booster.predict(dmatrix, pred_leaf=True) + leaf_ids = result.reshape((len(predictions.T))) xgb_feature_names = {f: i for i, f in enumerate(xgb_feature_names)} tree_dumps = booster.get_dump(with_stats=True) assert len(tree_dumps) == len(leaf_ids)