Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Fitter: Properly delegate preprocessors #2093

Merged
merged 1 commit into from
Mar 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Orange/modelling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def _fit_model(self, data):
X, Y, W = data.X, data.Y, data.W if data.has_weights() else None
return learner.fit(X, Y, W)

def preprocess(self, data):
    """Preprocess ``data`` with the learner that matches its class type.

    The classification learner is used when ``data.domain`` reports a
    discrete class; otherwise the regression learner is used. The chosen
    learner's own ``preprocess`` does the work, and its result is returned
    unchanged.
    """
    problem_type = (
        self.CLASSIFICATION
        if data.domain.has_discrete_class
        else self.REGRESSION
    )
    return self.get_learner(problem_type).preprocess(data)

def get_learner(self, problem_type):
"""Get the learner for a given problem type.

Expand Down
26 changes: 24 additions & 2 deletions Orange/tests/test_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from unittest.mock import Mock, patch

from Orange.classification.base_classification import LearnerClassification
from Orange.data import Table
from Orange.data import Table, ContinuousVariable
from Orange.evaluation import CrossValidation
from Orange.modelling import Fitter
from Orange.preprocess import Randomize
from Orange.preprocess import Randomize, Discretize
from Orange.regression.base_regression import LearnerRegression


Expand Down Expand Up @@ -130,3 +130,25 @@ def test_correctly_sets_preprocessors_on_learner(self):
def test_n_jobs_fitting(self):
    """Run 5-fold cross validation of a ``DummyFitter`` with ``n_jobs=5``.

    ``CrossValidation._MIN_NJOBS_X_SIZE`` is patched to 1 for the duration
    of the call -- presumably so the multi-job code path is taken even for
    this small data set (TODO confirm against ``Orange.evaluation.testing``).
    The test passes as long as the call does not raise.
    """
    min_size_attr = 'Orange.evaluation.testing.CrossValidation._MIN_NJOBS_X_SIZE'
    with patch(min_size_attr, 1):
        CrossValidation(self.heart_disease, [DummyFitter()], k=5, n_jobs=5)

def test_properly_delegates_preprocessing(self):
    """Check that ``Fitter.preprocess`` runs the delegated learner's preprocessors.

    A classification learner whose class-level ``preprocessors`` list holds
    a ``Discretize`` instance is wired into a local ``DummyFitter``. After
    ``fitter.preprocess`` the resulting domain must contain no
    ``ContinuousVariable``, which shows the learner's preprocessing was
    applied.
    NOTE(review): this relies on ``heart_disease`` having a discrete class
    so that the classification learner is the one selected -- confirm.
    """
    class DummyClassificationLearner(LearnerClassification):
        # Preprocessors declared on the learner; the fitter is expected to
        # use these when it delegates preprocessing.
        preprocessors = [Discretize()]

        def __init__(self, classification_param=1, **_):
            super().__init__()
            self.param = classification_param

    class DummyFitter(Fitter):
        __fits__ = {'classification': DummyClassificationLearner,
                    'regression': DummyRegressionLearner}

    data = self.heart_disease
    fitter = DummyFitter()
    # Sanity check: the raw data must contain at least one continuous
    # variable, otherwise the assertion below would pass vacuously.
    self.assertTrue(any(
        isinstance(v, ContinuousVariable) for v in data.domain.variables))
    # Preprocess the data and check that the discretization was applied
    pp_data = fitter.preprocess(data)
    self.assertFalse(any(
        isinstance(v, ContinuousVariable) for v in pp_data.domain.variables))