Skip to content

Commit

Permalink
Merge pull request #4541 from markotoplak/kmeans-faster
Browse files Browse the repository at this point in the history
[FIX] K-means slowness
  • Loading branch information
thocevar authored Mar 31, 2020
2 parents 5d6a884 + cce48e4 commit c88978d
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Orange/clustering/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class Clustering(metaclass=WrapperMeta):
preprocessors = [Continuize(), SklImpute()]

def __init__(self, preprocessors, parameters):
self.preprocessors = tuple(preprocessors or self.preprocessors)
self.preprocessors = preprocessors if preprocessors is not None else self.preprocessors
self.params = {k: v for k, v in parameters.items()
if k not in ["self", "preprocessors", "__class__"]}

Expand Down
36 changes: 22 additions & 14 deletions Orange/widgets/unsupervised/owkmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,23 @@ def _compute_clustering(data, k, init, n_init, max_iter, random_state):
if k > len(data):
raise NotEnoughData()

return KMeans(
model = KMeans(
n_clusters=k, init=init, n_init=n_init, max_iter=max_iter,
random_state=random_state
random_state=random_state, preprocessors=[]
).get_model(data)

if data.X.shape[0] <= SILHOUETTE_MAX_SAMPLES:
model.silhouette_samples = silhouette_samples(data.X, model.labels)
model.silhouette = np.mean(model.silhouette_samples)
else:
model.silhouette_samples = None
model.silhouette = \
silhouette_score(data.X, model.labels,
sample_size=SILHOUETTE_MAX_SAMPLES,
random_state=RANDOM_STATE)

return model

@Slot(int, int)
def __progress_changed(self, n, d):
assert QThread.currentThread() is self.thread()
Expand Down Expand Up @@ -338,9 +350,10 @@ def __commit_finished(self):
def __launch_tasks(self, ks):
# type: (List[int]) -> None
"""Execute clustering in separate threads for all given ks."""
preprocessed_data = self.preproces(self.data)
futures = [self.__executor.submit(
self._compute_clustering,
data=Normalize()(self.data) if self.normalize else self.data,
data=preprocessed_data,
k=k,
init=self.INIT_METHODS[self.smart_init][1],
n_init=self.n_init,
Expand Down Expand Up @@ -443,9 +456,8 @@ def invalidate(self, unconditional=False):
self.commit()

def update_results(self):
scores = [mk if isinstance(mk, str) else silhouette_score(
self.preproces(self.data).X, mk.labels) for mk in (
self.clusterings[k] for k in range(self.k_from, self.k_to + 1))]
scores = [mk if isinstance(mk, str) else mk.silhouette for mk in
(self.clusterings[k] for k in range(self.k_from, self.k_to + 1))]
best_row = max(
range(len(scores)), default=0,
key=lambda x: 0 if isinstance(scores[x], str) else scores[x]
Expand Down Expand Up @@ -479,11 +491,6 @@ def preproces(self, data):
data = preprocessor(data)
return data

def samples_scores(self, clust_ids):
d = self.preproces(self.data)
return np.arctan(
silhouette_samples(d.X, clust_ids)) / np.pi + 0.5

def send_data(self):
if self.optimize_k:
row = self.selected_row()
Expand All @@ -505,9 +512,9 @@ def send_data(self):
clust_ids = km.labels
silhouette_var = ContinuousVariable(
get_unique_names(domain, "Silhouette"))
if len(self.data) <= SILHOUETTE_MAX_SAMPLES:
if km.silhouette_samples is not None:
self.Warning.no_silhouettes.clear()
scores = self.samples_scores(clust_ids)
scores = np.arctan(km.silhouette_samples) / np.pi + 0.5
clust_scores = []
for i in range(km.k):
in_clust = clust_ids == i
Expand All @@ -526,10 +533,11 @@ def send_data(self):
new_table.get_column_view(cluster_var)[0][:] = clust_ids
new_table.get_column_view(silhouette_var)[0][:] = scores

domain_attributes = set(domain.attributes)
centroid_attributes = [
attr.compute_value.variable
if isinstance(attr.compute_value, ReplaceUnknowns)
and attr.compute_value.variable in domain.attributes
and attr.compute_value.variable in domain_attributes
else attr
for attr in km.domain.attributes]
centroid_domain = add_columns(
Expand Down
6 changes: 2 additions & 4 deletions Orange/widgets/unsupervised/tests/test_owkmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,7 @@ def test_centroids_on_output(self):
self.send_signal(widget.Inputs.data, self.data)
self.commit_and_wait()
widget.clusterings[widget.k].labels = np.array([0] * 100 + [1] * 203).flatten()

widget.samples_scores = lambda x: np.arctan(
np.arange(303) / 303) / np.pi + 0.5
widget.clusterings[widget.k].silhouette_samples = np.arange(303) / 303
widget.send_data()
out = self.get_output(widget.Outputs.centroids)
np.testing.assert_array_almost_equal(
Expand Down Expand Up @@ -323,7 +321,7 @@ def test_select_best_row(self):
# the best selection is 3 clusters, so row no. 1
self.assertEqual(widget.selected_row(), 1)

widget.normalize = True
self.widget.controls.normalize.toggle()
self.send_signal(self.widget.Inputs.data, Table("housing"), wait=5000)
self.commit_and_wait()
widget.update_results()
Expand Down
8 changes: 8 additions & 0 deletions benchmark/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@

from Orange.data import Table


try: # try to disable App Nap on OS X
import appnope
appnope.nope()
except ImportError:
pass


# override method prefix for niceness
BENCH_METHOD_PREFIX = 'bench'
unittest.TestLoader.testMethodPrefix = BENCH_METHOD_PREFIX
Expand Down
58 changes: 58 additions & 0 deletions benchmark/bench_owkmeans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numpy as np

from Orange.data import Domain, Table, ContinuousVariable
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.unsupervised.owkmeans import OWKMeans

from .base import benchmark


def table(rows, cols):
    """Return a reproducible random table with ``rows`` x ``cols`` values.

    The domain consists of ``cols`` continuous variables named "0", "1", ...
    Values come from a ``RandomState`` seeded with 0, so every call with the
    same shape produces the same data.
    """
    attributes = [ContinuousVariable(str(i)) for i in range(cols)]
    values = np.random.RandomState(0).rand(rows, cols)
    return Table.from_numpy(  # pylint: disable=W0201
        Domain(attributes), values)


class BenchOWKmeans(WidgetTest):
    """Benchmarks that drive the OWKMeans widget with prepared tables.

    Every ``bench_*`` method creates a fresh widget, sends one table to its
    data input and waits for the commit to finish.
    """

    @classmethod
    def setUpClass(cls):
        """Build the benchmark input tables once for the whole class."""
        super().setUpClass()
        cls.d_100_100 = table(100, 100)
        # 10000 rows x 1 column; presumably large enough to exceed
        # SILHOUETTE_MAX_SAMPLES in owkmeans -- TODO confirm the threshold.
        cls.d_sampled_silhouette = table(10000, 1)
        cls.d_10_500 = table(10, 500)

    def setUp(self):
        """Declare the ``widget`` attribute; each benchmark assigns it."""
        self.widget = None  # to avoid lint errors

    def widget_from_to(self):
        """Create a widget with auto-commit off and k_from/k_to set to 2/6."""
        widget = self.create_widget(
            OWKMeans, stored_settings={"auto_commit": False})
        self.widget = widget
        controls = widget.controls
        controls.k_from.setValue(2)
        controls.k_to.setValue(6)

    def _send_and_wait(self, data):
        """Send ``data`` to the widget's data input and wait for the commit."""
        self.send_signal(self.widget.Inputs.data, data)
        self.commit_and_wait(wait=100 * 1000)

    @benchmark(number=3, warmup=1, repeat=3)
    def bench_from_to_100_100(self):
        """Run the 2..6 widget setup on the 100 x 100 table."""
        self.widget_from_to()
        self._send_and_wait(self.d_100_100)

    @benchmark(number=3, warmup=1, repeat=3)
    def bench_from_to_100_100_no_normalize(self):
        """Same as ``bench_from_to_100_100`` with ``normalize`` set to False."""
        self.widget_from_to()
        self.widget.normalize = False
        self._send_and_wait(self.d_100_100)

    @benchmark(number=3, warmup=1, repeat=3)
    def bench_from_to_sampled_silhouette(self):
        """Run the 2..6 widget setup on the 10000 x 1 table."""
        self.widget_from_to()
        self._send_and_wait(self.d_sampled_silhouette)

    @benchmark(number=3, warmup=1, repeat=3)
    def bench_wide(self):
        """Run a widget with its default k settings on the 10 x 500 table."""
        self.widget = self.create_widget(
            OWKMeans, stored_settings={"auto_commit": False})
        self._send_and_wait(self.d_10_500)

0 comments on commit c88978d

Please sign in to comment.