diff --git a/Orange/data/variable.py b/Orange/data/variable.py index 83ed19d7e59..55a0ca615da 100644 --- a/Orange/data/variable.py +++ b/Orange/data/variable.py @@ -460,6 +460,10 @@ def copy(self, compute_value=None, *, name=None, **kwargs): var.attributes = dict(self.attributes) return var + def renamed(self, new_name): + # prevent cyclic import, pylint: disable=import-outside-toplevel + from Orange.preprocess.transformation import Identity + return self.copy(name=new_name, compute_value=Identity(variable=self)) del _predicatedescriptor @@ -552,11 +556,16 @@ def repr_val(self, val): str_val = repr_val def copy(self, compute_value=None, *, name=None, **kwargs): - var = super().copy(compute_value=compute_value, name=name, - number_of_decimals=self.number_of_decimals, - **kwargs) - var.adjust_decimals = self.adjust_decimals - var.format_str = self._format_str + # pylint understand not that `var` is `DiscreteVariable`: + # pylint: disable=protected-access + number_of_decimals = kwargs.pop("number_of_decimals", None) + var = super().copy(compute_value=compute_value, name=name, **kwargs) + if number_of_decimals is not None: + var.number_of_decimals = number_of_decimals + else: + var._number_of_decimals = self._number_of_decimals + var.adjust_decimals = self.adjust_decimals + var.format_str = self._format_str return var diff --git a/Orange/widgets/data/owconcatenate.py b/Orange/widgets/data/owconcatenate.py index 3bb1eb2e998..3fabc8fcc8d 100644 --- a/Orange/widgets/data/owconcatenate.py +++ b/Orange/widgets/data/owconcatenate.py @@ -6,14 +6,17 @@ """ -from collections import OrderedDict +from collections import OrderedDict, namedtuple from functools import reduce +from itertools import chain, count +from typing import List import numpy as np from AnyQt.QtWidgets import QFormLayout from AnyQt.QtCore import Qt import Orange.data +from Orange.data.util import get_unique_names_duplicates, get_unique_names from Orange.util import flatten from Orange.widgets import widget, gui, settings from Orange.widgets.settings import Setting @@ -43,6 +46,10 @@ class Outputs: class Error(widget.OWWidget.Error): bow_concatenation = Msg("Inputs must be of the same type.") + class Warning(widget.OWWidget.Warning): + renamed_variables = Msg( + "Variables with duplicated names have been renamed.") + merge_type: int append_source_column: bool source_column_role: int @@ -172,18 +179,15 @@ def incompatible_types(self): return False def apply(self): + self.Warning.renamed_variables.clear() tables, domain, source_var = [], None, None if self.primary_data is not None: tables = [self.primary_data] + list(self.more_data.values()) domain = self.primary_data.domain elif self.more_data: tables = self.more_data.values() - if self.merge_type == OWConcatenate.MergeUnion: - domain = reduce(domain_union, - (table.domain for table in tables)) - else: - domain = reduce(domain_intersection, - (table.domain for table in tables)) + domains = [table.domain for table in tables] + domain = self.merge_domains(domains) if tables and self.append_source_column: assert domain is not None @@ -192,7 +196,7 @@ def apply(self): names = ['{} ({})'.format(name, i) for i, name in enumerate(names)] source_var = Orange.data.DiscreteVariable( - self.source_attr_name, + get_unique_names(domain, self.source_attr_name), values=names ) places = ["class_vars", "attributes", "metas"] @@ -236,36 +240,76 @@ def send_report(self): self.id_roles[self.source_column_role].lower()) self.report_items(items) - -def unique(seq): - seen_set = set() - for el in seq: - if el not in seen_set: - yield el - seen_set.add(el) - - -def domain_union(a, b): - union = Orange.data.Domain( - tuple(unique(a.attributes + b.attributes)), - tuple(unique(a.class_vars + b.class_vars)), - tuple(unique(a.metas + b.metas)) - ) - return union - - -def domain_intersection(a, b): - def tuple_intersection(t1, t2): - inters = set(t1) & set(t2) - return tuple(unique(el for el in t1 + t2 if el in inters)) - - intersection = Orange.data.Domain( - tuple_intersection(a.attributes, b.attributes), - tuple_intersection(a.class_vars, b.class_vars), - tuple_intersection(a.metas, b.metas), - ) - - return intersection + def merge_domains(self, domains): + def fix_names(part): + for i, attr, name in zip(count(), part, name_iter): + if attr.name != name: + part[i] = attr.renamed(name) + self.Warning.renamed_variables() + + oper = set.union if self.merge_type == OWConcatenate.MergeUnion \ + else set.intersection + parts = [self._get_part(domains, oper, part) + for part in ("attributes", "class_vars", "metas")] + all_names = [var.name for var in chain(*parts)] + name_iter = iter(get_unique_names_duplicates(all_names)) + for part in parts: + fix_names(part) + domain = Orange.data.Domain(*parts) + return domain + + @classmethod + def _get_part(cls, domains, oper, part): + # keep the order of variables: first compute union or intersections as + # sets, then iterate through chained parts + vars_by_domain = [getattr(domain, part) for domain in domains] + valid = reduce(oper, map(set, vars_by_domain)) + valid_vars = [var for var in chain(*vars_by_domain) if var in valid] + return cls._unique_vars(valid_vars) + + @staticmethod + def _unique_vars(seq: List[Orange.data.Variable]): + AttrDesc = namedtuple( + "AttrDesc", + ("template", "original", "values", "number_of_decimals")) + + attrs = {} + for el in seq: + desc = attrs.get(el) + if desc is None: + attrs[el] = AttrDesc(el, True, + el.is_discrete and el.values, + el.is_continuous and el.number_of_decimals) + continue + if desc.template.is_discrete: + sattr_values = set(desc.values) + # don't use sets: keep the order + missing_values = [val for val in el.values + if val not in sattr_values] + if missing_values: + attrs[el] = attrs[el]._replace( + original=False, + values=desc.values + missing_values) + elif desc.template.is_continuous: + if el.number_of_decimals > desc.number_of_decimals: + attrs[el] = attrs[el]._replace( + original=False, + number_of_decimals=el.number_of_decimals) + + new_attrs = [] + for desc in attrs.values(): + attr = desc.template + if desc.original: + new_attr = attr + elif desc.template.is_discrete: + new_attr = attr.copy() + for val in desc.values[len(attr.values):]: + new_attr.add_value(val) + else: + assert desc.template.is_continuous + new_attr = attr.copy(number_of_decimals=desc.number_of_decimals) + new_attrs.append(new_attr) + return new_attrs if __name__ == "__main__": # pragma: no cover diff --git a/Orange/widgets/data/tests/test_owconcatenate.py b/Orange/widgets/data/tests/test_owconcatenate.py index 40702971df8..11aef97954d 100644 --- a/Orange/widgets/data/tests/test_owconcatenate.py +++ b/Orange/widgets/data/tests/test_owconcatenate.py @@ -9,9 +9,8 @@ from Orange.data import ( Table, Domain, ContinuousVariable, DiscreteVariable, StringVariable ) -from Orange.widgets.data.owconcatenate import ( - OWConcatenate, domain_intersection, domain_union -) +from Orange.preprocess.transformation import Identity +from Orange.widgets.data.owconcatenate import OWConcatenate from Orange.widgets.tests.base import WidgetTest @@ -133,56 +132,252 @@ def test_type_compatibility(self): self.send_signal(self.widget.Inputs.additional_data, self.DummyTable()) self.assertTrue(self.widget.Error.bow_concatenation.is_shown()) + def test_same_var_name(self): + widget = self.widget + + var1 = DiscreteVariable(name="x", values=list("abcd")) + data1 = Table.from_numpy(Domain([var1]), + np.arange(4).reshape(4, 1), np.zeros((4, 0))) + var2 = DiscreteVariable(name="x", values=list("def")) + data2 = Table.from_numpy(Domain([var2]), + np.arange(3).reshape(3, 1), np.zeros((3, 0))) + + self.send_signal(widget.Inputs.additional_data, data1, 1) + self.send_signal(widget.Inputs.additional_data, data2, 2) + output = self.get_output(widget.Outputs.data) + np.testing.assert_equal(output.X, + np.array([0, 1, 2, 3, 3, 4, 5]).reshape(7, 1)) + + def test_duplicated_id_column(self): + widget = self.widget + + var1 = DiscreteVariable(name="x", values=list("abcd")) + data1 = Table.from_numpy(Domain([var1]), + np.arange(4).reshape(4, 1), np.zeros((4, 0))) + widget.append_source_column = True + widget.source_column_role = 0 + widget.source_attr_name = "x" + self.send_signal(widget.Inputs.primary_data, data1) + out = self.get_output(widget.Outputs.data) + self.assertEqual(out.domain.attributes[0].name, "x") + self.assertEqual(out.domain.class_var.name, "x (1)") -class TestTools(unittest.TestCase): def test_domain_intersect(self): + widget = self.widget + widget.merge_type = OWConcatenate.MergeIntersection + X1, X2, X3 = map(ContinuousVariable, ["X1", "X2", "X3"]) D1, D2, D3 = map(lambda n: DiscreteVariable(n, values=["a", "b"]), ["D1", "D2", "D3"]) S1, S2 = map(StringVariable, ["S1", "S2"]) domain1 = Domain([X1, X2], [D1], [S1]) domain2 = Domain([X3], [D2], [S2]) - res = domain_intersection(domain1, domain2) + res = widget.merge_domains([domain1, domain2]) self.assertSequenceEqual(res.attributes, []) self.assertSequenceEqual(res.class_vars, []) self.assertSequenceEqual(res.metas, []) domain2 = Domain([X2, X3], [D1, D2, D3], [S1, S2]) - res = domain_intersection(domain1, domain2) + res = widget.merge_domains([domain1, domain2]) self.assertSequenceEqual(res.attributes, [X2]) self.assertSequenceEqual(res.class_vars, [D1]) self.assertSequenceEqual(res.metas, [S1]) - res = domain_intersection(domain1, domain1) + res = widget.merge_domains([domain1, domain1]) self.assertSequenceEqual(res.attributes, domain1.attributes) self.assertSequenceEqual(res.class_vars, domain1.class_vars) self.assertSequenceEqual(res.metas, domain1.metas) def test_domain_union(self): + widget = self.widget + widget.merge_type = OWConcatenate.MergeUnion + X1, X2, X3 = map(ContinuousVariable, ["X1", "X2", "X3"]) D1, D2, D3 = map(lambda n: DiscreteVariable(n, values=["a", "b"]), ["D1", "D2", "D3"]) S1, S2 = map(StringVariable, ["S1", "S2"]) domain1 = Domain([X1, X2], [D1], [S1]) domain2 = Domain([X3], [D2], [S2]) - res = domain_union(domain1, domain2) + res = widget.merge_domains([domain1, domain2]) self.assertSequenceEqual(res.attributes, [X1, X2, X3]) self.assertSequenceEqual(res.class_vars, [D1, D2]) self.assertSequenceEqual(res.metas, [S1, S2]) domain2 = Domain([X3, X2], [D2, D1, D3], [S2, S1]) - res = domain_union(domain1, domain2) + res = widget.merge_domains([domain1, domain2]) self.assertSequenceEqual(res.attributes, [X1, X2, X3]) self.assertSequenceEqual(res.class_vars, [D1, D2, D3]) self.assertSequenceEqual(res.metas, [S1, S2]) - res = domain_union(domain1, domain1) + res = widget.merge_domains([domain1, domain1]) self.assertSequenceEqual(res.attributes, domain1.attributes) self.assertSequenceEqual(res.class_vars, domain1.class_vars) self.assertSequenceEqual(res.metas, domain1.metas) + def test_domain_union_duplicated_names(self): + widget = self.widget + widget.merge_type = OWConcatenate.MergeUnion + + X1, X2, X3 = map(ContinuousVariable, ["X1", "X2", "X3"]) + D1, D2 = map(lambda n: DiscreteVariable(n, values=["a", "b"]), + ["D1", "X2"]) + S1, S2 = map(StringVariable, ["S1", "X3"]) + domain1 = Domain([X1, X2], [D1], [S1]) + domain2 = Domain([X3], [D2], [S2]) + res = widget.merge_domains([domain1, domain2]) + + attributes = res.attributes + class_vars = res.class_vars + metas = res.metas + + self.assertEqual([var.name for var in attributes], + ["X1", "X2 (1)", "X3 (1)"]) + self.assertEqual([var.name for var in class_vars], + ["D1", "X2 (2)"]) + self.assertEqual([var.name for var in metas], + ["S1", "X3 (2)"]) + + x21_val_from = attributes[1].compute_value + self.assertIsInstance(x21_val_from, Identity) + self.assertIsInstance(x21_val_from.variable, ContinuousVariable) + self.assertEqual(x21_val_from.variable.name, "X2") + + x22_val_from = class_vars[1].compute_value + self.assertIsInstance(x22_val_from, Identity) + self.assertIsInstance(x22_val_from.variable, DiscreteVariable) + self.assertEqual(x22_val_from.variable.name, "X2") + + x31_val_from = attributes[2].compute_value + self.assertIsInstance(x31_val_from, Identity) + self.assertIsInstance(x31_val_from.variable, ContinuousVariable) + self.assertEqual(x31_val_from.variable.name, "X3") + + x32_val_from = metas[1].compute_value + self.assertIsInstance(x32_val_from, Identity) + self.assertIsInstance(x32_val_from.variable, StringVariable) + self.assertEqual(x32_val_from.variable.name, "X3") + + def test_get_part_union(self): + get_part = OWConcatenate._get_part # pylint: disable=protected-access + + X1, X2, X3, X4 = map(ContinuousVariable, ["X1", "X2", "X3", "X4"]) + D1, D2, D3 = map(lambda n: DiscreteVariable(n, values=["a", "b"]), + ["X1", "X2", "X3"]) + S1, S2, S3 = map(StringVariable, ["X1", "X2", "X3"]) + domain1 = Domain([X1, X2], [D1], [S1, S3]) + domain2 = Domain([X3, X2], [D2, D1], [S2, S3, S1]) + domain3 = Domain([X3, X2, X4], [D2, D1, D3], [S2, S1, S3]) + + self.assertEqual( + get_part([domain1, domain2], set.union, "attributes"), + [X1, X2, X3] + ) + self.assertEqual( + get_part([domain3, domain1, domain2], set.union, "attributes"), + [X3, X2, X4, X1] + ) + self.assertEqual( + get_part([domain1, domain2], set.union, "class_vars"), + [D1, D2] + ) + self.assertEqual( + get_part([domain3, domain1, domain2], set.union, "class_vars"), + [D2, D1, D3] + ) + self.assertEqual( + get_part([domain3, domain1, domain2], set.union, "class_vars"), + [D2, D1, D3] + ) + self.assertEqual( + get_part([domain1, domain2], set.union, "metas"), + [S1, S3, S2] + ) + self.assertEqual( + get_part([domain2, domain1], set.union, "metas"), + [S2, S3, S1] + ) + self.assertEqual( + get_part([domain3, domain2, domain1], set.union, "metas"), + [S2, S1, S3] + ) + + def test_get_part_intersection(self): + get_part = OWConcatenate._get_part # pylint: disable=protected-access + + X1, X2, X3, X4 = map(ContinuousVariable, ["X1", "X2", "X3", "X4"]) + D1, D2, D3 = map(lambda n: DiscreteVariable(n, values=["a", "b"]), + ["X1", "X2", "X3"]) + S1, S2, S3 = map(StringVariable, ["X1", "X2", "X3"]) + domain1 = Domain([X1, X2], [D1], [S1, S3]) + domain2 = Domain([X3, X2], [D2, D1], [S2, S3, S1]) + domain3 = Domain([X3, X2, X4], [D2, D1, D3], [S2, S1, S3]) + + self.assertEqual( + get_part([domain1, domain2], set.intersection, "attributes"), + [X2] + ) + self.assertEqual( + get_part([domain1, domain2, domain3], set.intersection, "class_vars"), + [D1] + ) + self.assertEqual( + get_part([domain3, domain1, domain2], set.intersection, "metas"), + [S1, S3] + ) + self.assertEqual( + get_part([domain2, domain1, domain3], set.intersection, "metas"), + [S3, S1] + ) + + def test_get_unique_vars(self): + X1, X1a, X2, X2a = map(ContinuousVariable, ["X1", "X1", "X2", "X2"]) + X2.number_of_decimals = 3 + X2a.number_of_decimals = 4 + D1 = DiscreteVariable("X1", values=["a", "b", "c"]) + D1a = DiscreteVariable("X1", values=["e", "b", "d"]) + D2 = DiscreteVariable("X2", values=["a", "b", "c"]) + S1 = StringVariable("X1") + + # pylint: disable=unbalanced-tuple-unpacking,protected-access + uX1, uX2, uD1, uD2, uS1 =\ + OWConcatenate._unique_vars([X1, X1a, X2, X2a, D1, D2, D1a, S1]) + + self.assertIs(X1, uX1) + + self.assertEqual(X2, uX2) + self.assertEqual(X2a, uX2) + self.assertEqual(X2.number_of_decimals, 3) + self.assertEqual(X2a.number_of_decimals, 4) + self.assertEqual(uX2.number_of_decimals, 4) + + self.assertEqual(D1.values, list("abc")) + self.assertEqual(D1a.values, list("ebd")) + self.assertEqual(uD1, D1) + self.assertEqual(uD1, D1a) + self.assertEqual(uD1.values, list("abced")) + + self.assertIs(uD2, D2) + + self.assertIs(S1, uS1) + + def test_different_number_decimals(self): + widget = self.widget + + x1 = ContinuousVariable("x", number_of_decimals=3) + x2 = ContinuousVariable("x", number_of_decimals=4) + data1 = Table.from_numpy(Domain([x1]), np.array([[1], [2], [3]])) + data2 = Table.from_numpy(Domain([x2]), np.array([[1], [2], [3]])) + for d1, d2, id1, id2 in ((data1, data2, 1, 2), (data1, data2, 2, 1), + (data2, data1, 1, 2), (data2, data1, 2, 1)): + self.send_signal(widget.Inputs.additional_data, d1, id1) + self.send_signal(widget.Inputs.additional_data, d2, id2) + out_dom = self.get_output(widget.Outputs.data).domain + self.assertEqual(len(out_dom.attributes), 1) + x = out_dom.attributes[0] + self.assertEqual(x.number_of_decimals, 4) + if __name__ == "__main__": unittest.main()