[ENH] Neural network widget that works in a separate thread #2958

Merged (5 commits, Jun 1, 2018)
Orange/base.py (4 additions, 1 deletion)
@@ -383,8 +383,11 @@ def __call__(self, data):
m.params = self.params
return m

def _initialize_wrapped(self):
return self.__wraps__(**self.params)

def fit(self, X, Y, W=None):
- clf = self.__wraps__(**self.params)
+ clf = self._initialize_wrapped()
Y = Y.reshape(-1)
if W is None or not self.supports_weights:
return self.__returns__(clf.fit(X, Y))
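The hunk above is the enabling change in the base learner: instead of constructing the wrapped scikit-learn estimator inline in `fit`, `SklLearner` now routes construction through a single overridable `_initialize_wrapped` method, so subclasses can customize the estimator (here, to attach a progress callback) without reimplementing `fit`. A minimal standalone sketch of the pattern, using illustrative class names rather than Orange's actual base classes:

# Illustrative sketch, not part of the pull request; class names are made up.
import numpy as np
from sklearn.neural_network import MLPClassifier

class WrappedLearner:
    __wraps__ = MLPClassifier          # the scikit-learn class being wrapped

    def __init__(self, **params):
        self.params = params

    def _initialize_wrapped(self):
        # single construction point that subclasses may override
        return self.__wraps__(**self.params)

    def fit(self, X, y):
        clf = self._initialize_wrapped()
        return clf.fit(X, y)

class CallbackLearner(WrappedLearner):
    def _initialize_wrapped(self):
        clf = super()._initialize_wrapped()
        clf.verbose = True             # stand-in for attaching a callback
        return clf

X, y = np.random.rand(30, 4), np.random.randint(0, 2, 30)
CallbackLearner(hidden_layer_sizes=(5,), max_iter=10).fit(X, y)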
Orange/classification/neural_network.py (24 additions, 1 deletion)
@@ -5,5 +5,28 @@
__all__ = ["NNClassificationLearner"]


class NIterCallbackMixin:
orange_callback = None

@property
def n_iter_(self):
return self.__orange_n_iter

@n_iter_.setter
def n_iter_(self, v):
self.__orange_n_iter = v
if self.orange_callback:
self.orange_callback(v)


class MLPClassifierWCallback(skl_nn.MLPClassifier, NIterCallbackMixin):
pass


class NNClassificationLearner(NNBase, SklLearner):
- __wraps__ = skl_nn.MLPClassifier
+ __wraps__ = MLPClassifierWCallback

def _initialize_wrapped(self):
clf = SklLearner._initialize_wrapped(self)
clf.orange_callback = getattr(self, "callback", None)
return clf
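The mixin above works because scikit-learn's MLP models assign to `self.n_iter_` while training (the stochastic solvers update it once per pass over the data), so turning `n_iter_` into a property on a subclass lets Orange observe training progress without touching scikit-learn internals. Combined with the `_initialize_wrapped` override, the learner's `callback` attribute ends up as `orange_callback` on the wrapped estimator. A hedged usage sketch (assumes an Orange version that already contains this change):

# Usage sketch, not part of the pull request.
from Orange.data import Table
from Orange.classification.neural_network import NNClassificationLearner

learner = NNClassificationLearner(max_iter=30)
# forwarded to the wrapped MLPClassifierWCallback as `orange_callback`
learner.callback = lambda n_iter: print("finished iteration", n_iter)
model = learner(Table("iris"))   # iteration numbers are printed while fitting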
Orange/modelling/neural_network.py (7 additions, 0 deletions)
@@ -8,3 +8,10 @@
class NNLearner(SklFitter):
__fits__ = {'classification': NNClassificationLearner,
'regression': NNRegressionLearner}

callback = None

def get_learner(self, problem_type):
learner = super().get_learner(problem_type)
learner.callback = self.callback
return learner
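`NNLearner` is the fitter that the widget instantiates; it only dispatches to the classification or regression learner once the data arrives, so the callback has to be copied onto whichever concrete learner `get_learner` produces. A hedged sketch of the effect (assumes an Orange version with this change; the callback body is illustrative):

# Usage sketch, not part of the pull request.
from Orange.data import Table
from Orange.modelling import NNLearner

fitter = NNLearner(max_iter=50)
fitter.callback = lambda n_iter: print("iteration", n_iter)
# The same fitter works for both problem types; get_learner() forwards the
# callback to NNClassificationLearner or NNRegressionLearner as appropriate.
classification_model = fitter(Table("iris"))
regression_model = fitter(Table("housing"))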
Orange/regression/neural_network.py (11 additions, 1 deletion)
@@ -1,9 +1,19 @@
import sklearn.neural_network as skl_nn
from Orange.base import NNBase
from Orange.regression import SklLearner
from Orange.classification.neural_network import NIterCallbackMixin

__all__ = ["NNRegressionLearner"]


class MLPRegressorWCallback(skl_nn.MLPRegressor, NIterCallbackMixin):
pass


class NNRegressionLearner(NNBase, SklLearner):
- __wraps__ = skl_nn.MLPRegressor
+ __wraps__ = MLPRegressorWCallback

def _initialize_wrapped(self):
clf = SklLearner._initialize_wrapped(self)
clf.orange_callback = getattr(self, "callback", None)
return clf
Orange/widgets/model/owneuralnetwork.py (164 additions, 2 deletions)
@@ -1,15 +1,62 @@
from functools import partial
import copy
import logging
import re
import sys
import concurrent.futures

from AnyQt.QtWidgets import QApplication
- from AnyQt.QtCore import Qt
+ from AnyQt.QtCore import Qt, QThread, QObject
from AnyQt.QtCore import pyqtSlot as Slot, pyqtSignal as Signal

from Orange.data import Table
from Orange.modelling import NNLearner
from Orange.widgets import gui
from Orange.widgets.settings import Setting
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner

from Orange.widgets.utils.concurrent import ThreadExecutor, FutureWatcher


class Task(QObject):
"""
A class that will hold the state for a learner evaluation.
"""
done = Signal(object)
progressChanged = Signal(float)

future = None # type: concurrent.futures.Future
watcher = None # type: FutureWatcher
cancelled = False # type: bool

def setFuture(self, future):
if self.future is not None:
raise RuntimeError("future is already set")
self.future = future
self.watcher = FutureWatcher(future, parent=self)
self.watcher.done.connect(self.done)

def cancel(self):
"""
Cancel the task.

Set the `cancelled` field to True and block until the future is done.
"""
# set cancelled state
self.cancelled = True
self.future.cancel()
concurrent.futures.wait([self.future])

def emitProgressUpdate(self, value):
self.progressChanged.emit(value)

def isInterruptionRequested(self):
return self.cancelled


class CancelTaskException(BaseException):
pass


class OWNNLearner(OWBaseLearner):
name = "Neural Network"
@@ -53,11 +100,20 @@ def add_main_layout(self):
label="Alpha:", decimals=5, alignment=Qt.AlignRight,
callback=self.settings_changed, controlWidth=80)
self.max_iter_spin = gui.spin(
- box, self, "max_iterations", 10, 300, step=10,
+ box, self, "max_iterations", 10, 10000, step=10,
label="Max iterations:", orientation=Qt.Horizontal,
alignment=Qt.AlignRight, callback=self.settings_changed,
controlWidth=80)

def setup_layout(self):
super().setup_layout()

self._task = None # type: Optional[Task]
self._executor = ThreadExecutor()

# just a test cancel button
gui.button(self.controlArea, self, "Cancel", callback=self.cancel)

def create_learner(self):
return self.LEARNER(
hidden_layer_sizes=self.get_hidden_layers(),
@@ -81,6 +137,112 @@ def get_hidden_layers(self):
self.hidden_layers_edit.setText("100,")
return layers

def update_model(self):
self.show_fitting_failed(None)
self.model = None
if self.check_data():
self.__update()
else:
self.Outputs.model.send(self.model)

@Slot(float)
def setProgressValue(self, value):
assert self.thread() is QThread.currentThread()
self.progressBarSet(value)

def __update(self):
if self._task is not None:
# First make sure any pending tasks are cancelled.
self.cancel()
assert self._task is None

max_iter = self.learner.kwargs["max_iter"]

# Setup the task state
task = Task()
lastemitted = 0.

def callback(iteration):
nonlocal task # type: Task
nonlocal lastemitted
if task.isInterruptionRequested():
raise CancelTaskException()
progress = round(iteration / max_iter * 100)
if progress != lastemitted:
task.emitProgressUpdate(progress)
lastemitted = progress

# copy to set the callback so that the learner output is not modified
# (currently we cannot pass callbacks to the learner's __call__)
learner = copy.copy(self.learner)
learner.callback = callback

def build_model(data, learner):
try:
return learner(data)
except CancelTaskException:
return None

build_model_func = partial(build_model, self.data, learner)

task.setFuture(self._executor.submit(build_model_func))
task.done.connect(self._task_finished)
task.progressChanged.connect(self.setProgressValue)

self._task = task
self.progressBarInit()
self.setBlocking(True)

@Slot(concurrent.futures.Future)
def _task_finished(self, f):
"""
Parameters
----------
f : Future
The future instance holding the built model
"""
assert self.thread() is QThread.currentThread()
assert self._task is not None
assert self._task.future is f
assert f.done()
self._task.deleteLater()
self._task = None
self.setBlocking(False)
self.progressBarFinished()

try:
self.model = f.result()
except Exception as ex: # pylint: disable=broad-except
# Log the exception with a traceback
log = logging.getLogger()
log.exception(__name__, exc_info=True)
self.model = None
self.show_fitting_failed(ex)
else:
self.model.name = self.learner_name
self.model.instances = self.data
self.Outputs.model.send(self.model)

def cancel(self):
"""
Cancel the current task (if any).
"""
if self._task is not None:
self._task.cancel()
assert self._task.future.done()
# disconnect from the task
self._task.done.disconnect(self._task_finished)
self._task.progressChanged.disconnect(self.setProgressValue)
self._task.deleteLater()
self._task = None

self.progressBarFinished()
self.setBlocking(False)

def onDeleteWidget(self):
self.cancel()
super().onDeleteWidget()


if __name__ == "__main__":
a = QApplication(sys.argv)
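The widget changes above combine three pieces: a `Task` that owns the `Future` and its `FutureWatcher`, a per-iteration callback that converts `n_iter_` updates into `progressChanged` signals delivered back to the GUI thread, and cooperative cancellation, where the callback raises `CancelTaskException` inside the worker because a running scikit-learn fit cannot be cancelled from outside. A Qt-free sketch of that control flow using only the standard library (all names here are illustrative, not Orange or Qt APIs):

# Illustrative sketch, not part of the pull request.
import concurrent.futures
import time

class CancelFit(BaseException):
    """Raised inside the worker thread to abort a running fit."""

class FakeTask:
    cancelled = False                  # what isInterruptionRequested() checks

def train(task, max_iter, report_progress):
    # Stand-in for learner(data): the real code passes a callback that the
    # patched MLP invokes every time it assigns n_iter_.
    for iteration in range(1, max_iter + 1):
        if task.cancelled:
            raise CancelFit()
        time.sleep(0.01)               # pretend to run one training pass
        report_progress(round(iteration / max_iter * 100))

def build_model(task, max_iter, report_progress):
    try:
        train(task, max_iter, report_progress)
        return "model"
    except CancelFit:
        return None                    # mirrors build_model() in the widget

task = FakeTask()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    future = executor.submit(build_model, task, 100, lambda percent: None)
    time.sleep(0.1)                    # let it run briefly, then cancel
    task.cancelled = True
    future.cancel()                    # no-op once running; the flag does the work
    concurrent.futures.wait([future])
    result = None if future.cancelled() else future.result()
    print("result:", result)           # None: the fit was abandoned cleanly

In the widget itself, `progressChanged` and `done` are Qt signals, so `setProgressValue` and `_task_finished` run on the GUI thread via queued connections rather than being called directly from the worker.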
Orange/widgets/tests/base.py (16 additions, 0 deletions)
@@ -514,13 +514,16 @@ def test_input_data(self):
self.assertEqual(self.widget.data, None)
self.send_signal("Data", self.data)
self.assertEqual(self.widget.data, self.data)
self.wait_until_stop_blocking()

def test_input_data_disconnect(self):
"""Check widget's data and model after disconnecting data from input"""
self.send_signal("Data", self.data)
self.assertEqual(self.widget.data, self.data)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.send_signal("Data", None)
self.wait_until_stop_blocking()
self.assertEqual(self.widget.data, None)
self.assertIsNone(self.get_output(self.widget.Outputs.model))

@@ -529,9 +532,11 @@ def test_input_data_learner_adequacy(self):
for inadequate in self.inadequate_dataset:
self.send_signal("Data", inadequate)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.assertTrue(self.widget.Error.data_error.is_shown())
for valid in self.valid_datasets:
self.send_signal("Data", valid)
self.wait_until_stop_blocking()
self.assertFalse(self.widget.Error.data_error.is_shown())

def test_input_preprocessor(self):
@@ -542,6 +547,7 @@ def test_input_preprocessor(self):
randomize, self.widget.preprocessors,
'Preprocessor not added to widget preprocessors')
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.assertEqual(
(randomize,), self.widget.learner.preprocessors,
'Preprocessors were not passed to the learner')
@@ -551,6 +557,7 @@ def test_input_preprocessors(self):
pp_list = PreprocessorList([Randomize(), RemoveNaNColumns()])
self.send_signal("Preprocessor", pp_list)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.assertEqual(
(pp_list,), self.widget.learner.preprocessors,
'`PreprocessorList` was not added to preprocessors')
@@ -560,10 +567,12 @@ def test_input_preprocessor_disconnect(self):
randomize = Randomize()
self.send_signal("Preprocessor", randomize)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.assertEqual(randomize, self.widget.preprocessors)

self.send_signal("Preprocessor", None)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.assertIsNone(self.widget.preprocessors,
'Preprocessors not removed on disconnect.')

@@ -585,6 +594,7 @@ def test_output_model(self):
self.assertIsNone(self.get_output(self.widget.Outputs.model))
self.send_signal('Data', self.data)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
model = self.get_output(self.widget.Outputs.model)
self.assertIsNotNone(model)
self.assertIsInstance(model, self.widget.LEARNER.__returns__)
@@ -598,6 +608,7 @@ def test_output_learner_name(self):
self.widget.name_line_edit.text())
self.widget.name_line_edit.setText(new_name)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.assertEqual(self.get_output("Learner").name, new_name)

def test_output_model_name(self):
@@ -606,6 +617,7 @@ def test_output_model_name(self):
self.widget.name_line_edit.setText(new_name)
self.send_signal("Data", self.data)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
self.assertEqual(self.get_output(self.widget.Outputs.model).name, new_name)

def _get_param_value(self, learner, param):
@@ -626,6 +638,7 @@ def test_parameters_default(self):
for dataset in self.valid_datasets:
self.send_signal("Data", dataset)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
for parameter in self.parameters:
# Skip if the param isn't used for the given data type
if self._should_check_parameter(parameter, dataset):
@@ -639,6 +652,7 @@ def test_parameters(self):
# to only certain problem types
for dataset in self.valid_datasets:
self.send_signal("Data", dataset)
self.wait_until_stop_blocking()

for parameter in self.parameters:
# Skip if the param isn't used for the given data type
@@ -650,6 +664,7 @@ def test_parameters(self):
for value in parameter.values:
parameter.set_value(value)
self.widget.apply_button.button.click()
self.wait_until_stop_blocking()
param = self._get_param_value(self.widget.learner, parameter)
self.assertEqual(
param, parameter.get_value(),
@@ -674,6 +689,7 @@ def test_params_trigger_settings_changed(self):
"""Check that the learner gets updated whenever a param is changed."""
for dataset in self.valid_datasets:
self.send_signal("Data", dataset)
self.wait_until_stop_blocking()

for parameter in self.parameters:
# Skip if the param isn't used for the given data type
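Because the model is now built in a worker thread, clicking Apply no longer produces an output synchronously, so the tests wait for the widget to leave its blocking state before asserting on outputs. A rough sketch of what such a helper can look like (illustrative only, not the actual implementation of `wait_until_stop_blocking` in Orange's test base):

# Illustrative sketch, not part of the pull request.
import time
from AnyQt.QtWidgets import QApplication

def wait_until_stop_blocking(widget, timeout=5.0):
    # Spin the Qt event loop until the widget reports it is no longer
    # blocking input signals, i.e. the background task has finished.
    deadline = time.monotonic() + timeout
    while widget.isBlocking():
        if time.monotonic() > deadline:
            raise TimeoutError("widget did not stop blocking in time")
        QApplication.processEvents()
        time.sleep(0.01)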