-
-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
319 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from typing import List | ||
|
||
import numpy as np | ||
|
||
from AnyQt.QtWidgets import QSizePolicy | ||
from AnyQt.QtCore import Qt | ||
from Orange.data import Variable, Table, ContinuousVariable, TimeVariable | ||
from Orange.data.util import get_unique_names | ||
from Orange.widgets import gui, widget | ||
from Orange.widgets.settings import ( | ||
ContextSetting, Setting, DomainContextHandler | ||
) | ||
from Orange.widgets.utils.widgetpreview import WidgetPreview | ||
from Orange.widgets.utils.state_summary import format_summary_details | ||
from Orange.widgets.widget import Input, Output | ||
from Orange.widgets.utils.itemmodels import DomainModel | ||
|
||
|
||
class OWAggregateColumns(widget.OWWidget): | ||
name = "Aggregate Columns" | ||
description = "Compute a sum, max, min ... of selected columns." | ||
icon = "icons/AggregateColumns.svg" | ||
priority = 100 | ||
keywords = ["aggregate", "sum", "product", "max", "min", "mean", | ||
"median", "variance"] | ||
|
||
class Inputs: | ||
data = Input("Data", Table, default=True) | ||
|
||
class Outputs: | ||
data = Output("Data", Table) | ||
|
||
want_main_area = False | ||
|
||
settingsHandler = DomainContextHandler() | ||
variables: List[Variable] = ContextSetting([]) | ||
operation = Setting("Sum") | ||
var_name = Setting("agg") | ||
auto_apply = Setting(True) | ||
|
||
Operations = {"Sum": np.nansum, "Product": np.nanprod, | ||
"Min": np.nanmin, "Max": np.nanmax, | ||
"Mean": np.nanmean, "Variance": np.nanvar, | ||
"Median": np.nanmedian} | ||
TimePreserving = ("Min", "Max", "Mean", "Median") | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.data = None | ||
|
||
box = gui.vBox(self.controlArea, box=True) | ||
|
||
self.variable_model = DomainModel( | ||
order=DomainModel.MIXED, valid_types=(ContinuousVariable, )) | ||
var_list = gui.listView( | ||
box, self, "variables", model=self.variable_model, | ||
callback=self.commit) | ||
var_list.setSelectionMode(var_list.ExtendedSelection) | ||
|
||
combo = gui.comboBox( | ||
box, self, "operation", | ||
label="Operator: ", orientation=Qt.Horizontal, | ||
items=list(self.Operations), sendSelectedValue=True, | ||
callback=self.commit | ||
) | ||
combo.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed) | ||
|
||
gui.lineEdit( | ||
box, self, "var_name", | ||
label="Variable name: ", orientation=Qt.Horizontal, | ||
callback=self.commit | ||
) | ||
|
||
gui.auto_apply(self.controlArea, self) | ||
|
||
@Inputs.data | ||
def set_data(self, data: Table = None): | ||
self.closeContext() | ||
self.data = data | ||
if self.data: | ||
self.variable_model.set_domain(data.domain) | ||
self.info.set_input_summary(len(self.data), | ||
format_summary_details(self.data)) | ||
self.variables.clear() | ||
self.openContext(data) | ||
else: | ||
self.variable_model.set_domain(None) | ||
self.variables.clear() | ||
self.info.set_input_summary(self.info.NoInput) | ||
self.unconditional_commit() | ||
|
||
def commit(self): | ||
augmented = self._compute_data() | ||
self.Outputs.data.send(augmented) | ||
if augmented is None: | ||
self.info.set_output_summary(self.info.NoOutput) | ||
else: | ||
self.info.set_output_summary( | ||
len(augmented), format_summary_details(augmented)) | ||
|
||
def _compute_data(self): | ||
if not self.data or not self.variables: | ||
return self.data | ||
|
||
new_col = self._compute_column() | ||
new_var = self._new_var() | ||
return self.data.add_column(new_var, new_col) | ||
|
||
def _compute_column(self): | ||
arr = np.empty((len(self.data), len(self.variables))) | ||
for i, var in enumerate(self.variables): | ||
arr[:, i] = self.data.get_column_view(var)[0].astype(float) | ||
func = self.Operations[self.operation] | ||
return func(arr, axis=1) | ||
|
||
def _new_var_name(self): | ||
return get_unique_names(self.data.domain, self.var_name) | ||
|
||
def _new_var(self): | ||
name = self._new_var_name() | ||
if self.operation in self.TimePreserving \ | ||
and all(isinstance(var, TimeVariable) for var in self.variables): | ||
return TimeVariable(name) | ||
return ContinuousVariable(name) | ||
|
||
def send_report(self): | ||
# fp for self.variables, pylint: disable=unsubscriptable-object | ||
if not self.data or not self.variables: | ||
return | ||
var_list = ", ".join(f"'{var.name}'" | ||
for var in self.variables[:31][:-1]) | ||
if len(self.variables) > 30: | ||
var_list += f" and {len(self.variables) - 30} others" | ||
else: | ||
var_list += f" and '{self.variables[-1].name}'" | ||
self.report_items(( | ||
("Output:", | ||
f"'{self._new_var_name()}' as {self.operation.lower()} of {var_list}" | ||
), | ||
)) | ||
|
||
|
||
if __name__ == "__main__": # pragma: no cover | ||
brown = Table("brown-selected") | ||
WidgetPreview(OWAggregateColumns).run(set_data=brown) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
# Test methods with long descriptive names can omit docstrings | ||
# pylint: disable=missing-docstring, abstract-method, protected-access | ||
import unittest | ||
from itertools import chain | ||
|
||
from unittest.mock import Mock | ||
|
||
import numpy as np | ||
|
||
from Orange.data import ( | ||
Table, Domain, | ||
ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable | ||
) | ||
from Orange.widgets.data.owaggregatecolumns import OWAggregateColumns | ||
from Orange.widgets.tests.base import WidgetTest | ||
from Orange.widgets.utils.state_summary import format_summary_details | ||
|
||
|
||
class TestOWAggregateColumn(WidgetTest): | ||
def setUp(self): | ||
#: OWAggregateColumns | ||
self.widget = self.create_widget(OWAggregateColumns) | ||
c1, c2, c3 = map(ContinuousVariable, "c1 c2 c3".split()) | ||
t1, t2 = map(TimeVariable, "t1 t2".split()) | ||
d1, d2, d3 = (DiscreteVariable(n, values=("a", "b", "c")) | ||
for n in "d1 d2 d3".split()) | ||
s1 = StringVariable("s1") | ||
domain1 = Domain([c1, c2, d1, d2, t1], [d3], [s1, c3, t2]) | ||
self.data1 = Table.from_list(domain1, | ||
[[0, 1, 0, 1, 2, 0, "foo", 0, 3], | ||
[3, 1, 0, 1, 42, 0, "bar", 0, 4]]) | ||
|
||
domain2 = Domain([ContinuousVariable("c4")]) | ||
self.data2 = Table.from_list(domain2, [[4], [5]]) | ||
|
||
def test_no_input(self): | ||
widget = self.widget | ||
domain = self.data1.domain | ||
input_sum = widget.info.set_input_summary = Mock() | ||
output_sum = widget.info.set_output_summary = Mock() | ||
|
||
self.send_signal(widget.Inputs.data, self.data1) | ||
self.assertEqual(widget.variables, []) | ||
widget.commit() | ||
output = self.get_output(self.widget.Outputs.data) | ||
self.assertIs(output, self.data1) | ||
input_sum.assert_called_with(len(self.data1), | ||
format_summary_details(self.data1)) | ||
output_sum.assert_called_with(len(output), | ||
format_summary_details(output)) | ||
|
||
widget.variables = [domain[n] for n in "c1 c2 t2".split()] | ||
widget.commit() | ||
output = self.get_output(self.widget.Outputs.data) | ||
self.assertIsNotNone(output) | ||
output_sum.assert_called_with(len(output), | ||
format_summary_details(output)) | ||
|
||
self.send_signal(widget.Inputs.data, None) | ||
widget.commit() | ||
self.assertIsNone(self.get_output(self.widget.Outputs.data)) | ||
input_sum.assert_called_with(widget.info.NoInput) | ||
output_sum.assert_called_with(widget.info.NoOutput) | ||
|
||
def test_compute_data(self): | ||
domain = self.data1.domain | ||
self.send_signal(self.widget.Inputs.data, self.data1) | ||
self.widget.variables = [domain[n] for n in "c1 c2 t2".split()] | ||
|
||
self.widget.operation = "Sum" | ||
output = self.widget._compute_data() | ||
self.assertEqual(output.domain.attributes[:-1], domain.attributes) | ||
np.testing.assert_equal(output.X[:, -1], [4, 8]) | ||
|
||
self.widget.operation = "Max" | ||
output = self.widget._compute_data() | ||
self.assertEqual(output.domain.attributes[:-1], domain.attributes) | ||
np.testing.assert_equal(output.X[:, -1], [3, 4]) | ||
|
||
def test_var_name(self): | ||
domain = self.data1.domain | ||
self.send_signal(self.widget.Inputs.data, self.data1) | ||
self.widget.variables = self.widget.variable_model[:] | ||
|
||
self.widget.var_name = "test" | ||
output = self.widget._compute_data() | ||
self.assertEqual(output.domain.attributes[-1].name, "test") | ||
|
||
self.widget.var_name = "d1" | ||
output = self.widget._compute_data() | ||
self.assertNotIn( | ||
output.domain.attributes[-1].name, | ||
[var.name for var in chain(domain.variables, domain.metas)]) | ||
|
||
def test_var_types(self): | ||
domain = self.data1.domain | ||
self.send_signal(self.widget.Inputs.data, self.data1) | ||
|
||
self.widget.variables = [domain[n] for n in "t1 c2 t2".split()] | ||
for self.widget.operation in self.widget.Operations: | ||
self.assertIsInstance(self.widget._new_var(), ContinuousVariable) | ||
|
||
self.widget.variables = [domain[n] for n in "t1 t2".split()] | ||
for self.widget.operation in self.widget.Operations: | ||
self.assertIsInstance( | ||
self.widget._new_var(), | ||
TimeVariable | ||
if self.widget.operation in ("Min", "Max", "Mean", "Median") | ||
else ContinuousVariable) | ||
|
||
def test_operations(self): | ||
domain = self.data1.domain | ||
self.send_signal(self.widget.Inputs.data, self.data1) | ||
self.widget.variables = [domain[n] for n in "c1 c2 t2".split()] | ||
|
||
m1, m2 = 4 / 3, 8 / 3 | ||
for self.widget.operation, expected in { | ||
"Sum": [4, 8], "Product": [0, 12], | ||
"Min": [0, 1], "Max": [3, 4], | ||
"Mean": [m1, m2], | ||
"Variance": [(m1 ** 2 + (m1 - 1) ** 2 + (m1 - 3) ** 2) / 3, | ||
((m2 - 3) ** 2 + (m2 - 1) ** 2 + (m2 - 4) ** 2) / 3], | ||
"Median": [1, 3]}.items(): | ||
np.testing.assert_equal( | ||
self.widget._compute_column(), expected, | ||
err_msg=f"error in '{self.widget.operation}'") | ||
|
||
def test_operations_with_nan(self): | ||
domain = self.data1.domain | ||
self.send_signal(self.widget.Inputs.data, self.data1) | ||
self.data1.X[1, 0] = np.nan | ||
self.widget.variables = [domain[n] for n in "c1 c2 t2".split()] | ||
|
||
m1, m2 = 4 / 3, 5 / 2 | ||
for self.widget.operation, expected in { | ||
"Sum": [4, 5], "Product": [0, 4], | ||
"Min": [0, 1], "Max": [3, 4], | ||
"Mean": [m1, m2], | ||
"Variance": [(m1 ** 2 + (m1 - 1) ** 2 + (m1 - 3) ** 2) / 3, | ||
((m2 - 1) ** 2 + (m2 - 4) ** 2) / 2], | ||
"Median": [1, 2.5]}.items(): | ||
np.testing.assert_equal( | ||
self.widget._compute_column(), expected, | ||
err_msg=f"error in '{self.widget.operation}'") | ||
|
||
def test_contexts(self): | ||
domain = self.data1.domain | ||
self.send_signal(self.widget.Inputs.data, self.data1) | ||
self.widget.variables = [domain[n] for n in "c1 c2 t2".split()] | ||
saved = self.widget.variables[:] | ||
|
||
self.send_signal(self.widget.Inputs.data, self.data2) | ||
self.assertEqual(self.widget.variables, []) | ||
|
||
self.send_signal(self.widget.Inputs.data, self.data1) | ||
self.assertEqual(self.widget.variables, saved) | ||
|
||
def test_report(self): | ||
self.widget.send_report() | ||
|
||
domain = self.data1.domain | ||
self.send_signal(self.widget.Inputs.data, self.data1) | ||
self.widget.variables = [domain[n] for n in "c1 c2 t2".split()] | ||
self.widget.send_report() | ||
|
||
domain3 = Domain([ContinuousVariable(f"c{i:02}") for i in range(100)]) | ||
data3 = Table.from_numpy(domain3, np.zeros((2, 100))) | ||
self.send_signal(self.widget.Inputs.data, data3) | ||
self.widget.variables[:] = self.widget.variable_model[:] | ||
self.widget.send_report() | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |