diff --git a/Orange/modelling/base.py b/Orange/modelling/base.py index d5f944a0c7b..aa0a10cfe1a 100644 --- a/Orange/modelling/base.py +++ b/Orange/modelling/base.py @@ -1,6 +1,7 @@ import numpy as np from Orange.base import Learner, Model, SklLearner +from Orange.data import Table, Domain class Fitter(Learner): @@ -31,10 +32,7 @@ def __init__(self, preprocessors=None, **kwargs): self.__learners = {self.CLASSIFICATION: None, self.REGRESSION: None} def _fit_model(self, data): - if data.domain.has_discrete_class: - learner = self.get_learner(self.CLASSIFICATION) - else: - learner = self.get_learner(self.REGRESSION) + learner = self.get_learner(data) if type(self).fit is Learner.fit: return learner.fit_storage(data) @@ -43,20 +41,34 @@ def _fit_model(self, data): return learner.fit(X, Y, W) def preprocess(self, data): - if data.domain.has_discrete_class: - return self.get_learner(self.CLASSIFICATION).preprocess(data) - else: - return self.get_learner(self.REGRESSION).preprocess(data) + return self.get_learner(data).preprocess(data) def get_learner(self, problem_type): """Get the learner for a given problem type. + Parameters + ---------- + problem_type: str or Table or Domain + If str, one of ``'classification'`` or ``'regression'``. If Table + or Domain, the type is inferred from Domain's first class variable. + Returns ------- Learner The appropriate learner for the given problem type. + Raises + ------ + TypeError + When (inferred) problem type not one of ``'classification'`` + or ``'regression'``. """ + if isinstance(problem_type, Table): + problem_type = problem_type.domain + if isinstance(problem_type, Domain): + problem_type = (self.CLASSIFICATION if problem_type.has_discrete_class else + self.REGRESSION if problem_type.has_continuous_class else + None) # Prevent trying to access the learner when problem type is None if problem_type not in self.__fits__: raise TypeError("No learner to handle '{}'".format(problem_type)) @@ -112,8 +124,5 @@ class SklFitter(Fitter): def _fit_model(self, data): model = super()._fit_model(data) model.used_vals = [np.unique(y) for y in data.Y[:, None].T] - if data.domain.has_discrete_class: - model.params = self.get_params(self.CLASSIFICATION) - else: - model.params = self.get_params(self.REGRESSION) + model.params = self.get_params(data) return model diff --git a/Orange/modelling/linear.py b/Orange/modelling/linear.py index cf907e5641f..493af8309c5 100644 --- a/Orange/modelling/linear.py +++ b/Orange/modelling/linear.py @@ -1,11 +1,24 @@ +import numpy as np + from Orange.classification.sgd import SGDClassificationLearner +from Orange.data import Variable from Orange.modelling import SklFitter +from Orange.preprocess.score import LearnerScorer from Orange.regression import SGDRegressionLearner __all__ = ['SGDLearner'] -class SGDLearner(SklFitter): +class _FeatureScorerMixin(LearnerScorer): + feature_type = Variable + class_type = Variable + + def score(self, data): + model = self.get_learner(data)(data) + return np.atleast_2d(np.abs(model.skl_model.coef_)).mean(0) + + +class SGDLearner(SklFitter, _FeatureScorerMixin): name = 'sgd' __fits__ = {'classification': SGDClassificationLearner, diff --git a/Orange/modelling/randomforest.py b/Orange/modelling/randomforest.py index e64637b0658..d14296fdd97 100644 --- a/Orange/modelling/randomforest.py +++ b/Orange/modelling/randomforest.py @@ -1,12 +1,23 @@ from Orange.base import RandomForestModel from Orange.classification import RandomForestLearner as RFClassification +from Orange.data import Variable from Orange.modelling import SklFitter +from Orange.preprocess.score import LearnerScorer from Orange.regression import RandomForestRegressionLearner as RFRegression __all__ = ['RandomForestLearner'] -class RandomForestLearner(SklFitter): +class _FeatureScorerMixin(LearnerScorer): + feature_type = Variable + class_type = Variable + + def score(self, data): + model = self.get_learner(data)(data) + return model.skl_model.feature_importances_ + + +class RandomForestLearner(SklFitter, _FeatureScorerMixin): name = 'random forest' __fits__ = {'classification': RFClassification, diff --git a/Orange/preprocess/score.py b/Orange/preprocess/score.py index 86b0dd1a2ae..6f86d694c07 100644 --- a/Orange/preprocess/score.py +++ b/Orange/preprocess/score.py @@ -154,11 +154,10 @@ def score(self, data): raise NotImplementedError def score_data(self, data, feature=None): - scores = self.score(data) def average_scores(scores): scores_grouped = defaultdict(list) - for attr, score in zip(self.domain.attributes, scores): + for attr, score in zip(model_domain.attributes, scores): # Go up the chain of preprocessors to obtain the original variable while getattr(attr, 'compute_value', False): attr = getattr(attr.compute_value, 'variable', attr) @@ -167,8 +166,14 @@ def average_scores(scores): if attr in scores_grouped else 0 for attr in data.domain.attributes] - scores = np.atleast_2d(scores) - if data.domain != self.domain: + scores = np.atleast_2d(self.score(data)) + + from Orange.modelling import Fitter # Avoid recursive import + model_domain = (self.get_learner(data).domain + if isinstance(self, Fitter) else + self.domain) + + if data.domain != model_domain: scores = np.array([average_scores(row) for row in scores]) return scores[:, data.domain.attributes.index(feature)] \ diff --git a/Orange/widgets/data/tests/test_owrank.py b/Orange/widgets/data/tests/test_owrank.py index 6162f208898..62a84fe708a 100644 --- a/Orange/widgets/data/tests/test_owrank.py +++ b/Orange/widgets/data/tests/test_owrank.py @@ -1,6 +1,7 @@ import numpy as np from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable +from Orange.modelling import RandomForestLearner, SGDLearner from Orange.preprocess.score import Scorer from Orange.classification import LogisticRegressionLearner from Orange.regression import LinearRegressionLearner @@ -8,6 +9,8 @@ from Orange.widgets.data.owrank import OWRank from Orange.widgets.tests.base import WidgetTest +from AnyQt.QtCore import Qt + class TestOWRank(WidgetTest): def setUp(self): @@ -39,6 +42,31 @@ def test_input_scorer(self): self.assertEqual(self.log_reg, value.score) self.assertIsInstance(value.score, Scorer) + def test_input_scorer_fitter(self): + heart_disease = Table('heart_disease') + self.assertEqual(self.widget.learners, {}) + + for fitter, name in ((RandomForestLearner(), 'random forest'), + (SGDLearner(), 'sgd')): + with self.subTest(fitter=fitter): + self.send_signal("Scorer", fitter, 1) + + for data, model in ((self.housing, self.widget.contRanksModel), + (heart_disease, self.widget.discRanksModel)): + with self.subTest(data=data.name): + self.send_signal('Data', data) + scores = [model.data(model.index(row, model.columnCount() - 1)) + for row in range(model.rowCount())] + self.assertEqual(len(scores), len(data.domain.attributes)) + self.assertFalse(np.isnan(scores).any()) + + last_column = model.headerData( + model.columnCount() - 1, Qt.Horizontal).lower() + self.assertIn(name, last_column) + + self.send_signal("Scorer", None, 1) + self.assertEqual(self.widget.learners, {}) + def test_input_scorer_disconnect(self): """Check widget's scorer after disconnecting scorer on the input""" self.send_signal("Scorer", self.log_reg, 1) diff --git a/doc/visual-programming/source/widgets/data/rank.rst b/doc/visual-programming/source/widgets/data/rank.rst index f80eefb5a10..879283f8914 100644 --- a/doc/visual-programming/source/widgets/data/rank.rst +++ b/doc/visual-programming/source/widgets/data/rank.rst @@ -14,6 +14,11 @@ Signals An input data set. +- **Scorer** (multiple) + + Models that implement the feature scoring interface, such as linear / + logistic regression, random forest, stochastic gradient descent, etc. + **Outputs**: - **Reduced Data** @@ -47,6 +52,12 @@ Scoring methods 6. `ReliefF `_: the ability of an attribute to distinguish between classes on similar data instances 7. `FCBF (Fast Correlation Based Filter) `_: entropy-based measure, which also identifies redundancy due to pairwise correlations between features +Additionally, you can connect certain learners that enable scoring the features +according to how important they are in models that the learners build (e.g. +:ref:`Linear ` / :ref:`Logistic Regression `, +:ref:`Random Forest `, :ref:`SGD `, …). + + Example: Attribute Ranking and Selection ---------------------------------------- diff --git a/doc/visual-programming/source/widgets/model/linearregression.rst b/doc/visual-programming/source/widgets/model/linearregression.rst index c14a3f8fe37..511d7e11d7e 100644 --- a/doc/visual-programming/source/widgets/model/linearregression.rst +++ b/doc/visual-programming/source/widgets/model/linearregression.rst @@ -1,3 +1,5 @@ +.. _model.lr: + Linear Regression ================= diff --git a/doc/visual-programming/source/widgets/model/logisticregression.rst b/doc/visual-programming/source/widgets/model/logisticregression.rst index 860245912b1..533ccde5c7d 100644 --- a/doc/visual-programming/source/widgets/model/logisticregression.rst +++ b/doc/visual-programming/source/widgets/model/logisticregression.rst @@ -1,3 +1,5 @@ +.. _model.logit: + Logistic Regression =================== diff --git a/doc/visual-programming/source/widgets/model/randomforest.rst b/doc/visual-programming/source/widgets/model/randomforest.rst index dd1946c0051..bec22a9e9aa 100644 --- a/doc/visual-programming/source/widgets/model/randomforest.rst +++ b/doc/visual-programming/source/widgets/model/randomforest.rst @@ -1,3 +1,5 @@ +.. _model.rf: + Random Forest ============= diff --git a/doc/visual-programming/source/widgets/model/stochasticgradient.rst b/doc/visual-programming/source/widgets/model/stochasticgradient.rst index 910173383f5..51a73e19492 100644 --- a/doc/visual-programming/source/widgets/model/stochasticgradient.rst +++ b/doc/visual-programming/source/widgets/model/stochasticgradient.rst @@ -1,3 +1,5 @@ +.. _model.sgd: + Stochastic Gradient Descent ===========================