diff --git a/Orange/widgets/data/owmergedata.py b/Orange/widgets/data/owmergedata.py index 6a99c72194a..6d1e65f3868 100644 --- a/Orange/widgets/data/owmergedata.py +++ b/Orange/widgets/data/owmergedata.py @@ -409,7 +409,8 @@ def merge(self): return None method = self._merge_methods[self.merging] lefti, righti, rightu = method(self, left, left_mask, right, right_mask) - reduced_extra_data = self._compute_reduced_extra_data(right_vars) + reduced_extra_data = \ + self._compute_reduced_extra_data(right_vars, lefti, righti, rightu) return self._join_table_by_indices( reduced_extra_data, lefti, righti, rightu) @@ -446,18 +447,40 @@ def _check_uniqueness(self, left, left_mask, right, right_mask): ok = False return ok - def _compute_reduced_extra_data(self, extra_vars): + def _compute_reduced_extra_data(self, + right_match_vars, lefti, righti, rightu): """Prepare a table with extra columns that will appear in the merged table""" domain = self.data.domain extra_domain = self.extra_data.domain - all_vars = set(chain(domain.variables, domain.metas)) - if self.merging != self.OuterJoin: - all_vars |= set(extra_vars) - extra_vars = chain(extra_domain.variables, extra_domain.metas) - return self.extra_data[:, [var for var in extra_vars - if var not in all_vars]] + def var_needed(var): + if rightu is not None and rightu.size: + return True + if var in right_match_vars and self.merging != self.OuterJoin: + return False + if var not in domain: + return True + both_defined = (lefti != -1) * (righti != -1) + left_col = \ + self.data.get_column_view(var)[0][lefti[both_defined]] + right_col = \ + self.extra_data.get_column_view(var)[0][righti[both_defined]] + if var.is_primitive(): + left_col = left_col.astype(float) + right_col = right_col.astype(float) + mask_left = np.isfinite(left_col) + mask_right = np.isfinite(right_col) + return not ( + np.all(mask_left == mask_right) + and np.all(left_col[mask_left] == right_col[mask_right])) + else: + return not np.all(left_col == right_col) + + extra_vars = [ + var for var in chain(extra_domain.variables, extra_domain.metas) + if var_needed(var)] + return self.extra_data[:, extra_vars] @staticmethod def _values(data, var, mask): @@ -532,20 +555,34 @@ def _join_table_by_indices(self, reduced_extra, lefti, righti, rightu): of rows given in indices""" if not lefti.size: return None - domain = Orange.data.Domain( - *(getattr(self.data.domain, x) + getattr(reduced_extra.domain, x) - for x in ("attributes", "class_vars", "metas"))) - domain = self._domain_rename_duplicates(domain) - X = self._join_array_by_indices(self.data.X, reduced_extra.X, lefti, righti) + lt_dom = self.data.domain + xt_dom = reduced_extra.domain + domain = self._domain_rename_duplicates( + lt_dom.attributes + xt_dom.attributes, + lt_dom.class_vars + xt_dom.class_vars, + lt_dom.metas + xt_dom.metas) + X = self._join_array_by_indices( + self.data.X, reduced_extra.X, lefti, righti) Y = self._join_array_by_indices( np.c_[self.data.Y], np.c_[reduced_extra.Y], lefti, righti) string_cols = [i for i, var in enumerate(domain.metas) if var.is_string] metas = self._join_array_by_indices( self.data.metas, reduced_extra.metas, lefti, righti, string_cols) if rightu is not None: - extras = self.extra_data[rightu].transform(domain) + # This domain is used for transforming the extra rows for outer join + # It must use the original - not renamed - variables from right, so + # values are copied, + # but new domain for the left, so renamed values are *not* copied + right_domain = Orange.data.Domain( + domain.attributes[:len(lt_dom.attributes)] + xt_dom.attributes, + domain.class_vars[:len(lt_dom.class_vars)] + xt_dom.class_vars, + domain.metas[:len(lt_dom.metas)] + xt_dom.metas) + extras = self.extra_data[rightu].transform(right_domain) X = np.vstack((X, extras.X)) - Y = np.vstack((Y, extras.Y)) + extras_Y = extras.Y + if extras_Y.ndim == 1: + extras_Y = extras_Y.reshape(-1, 1) + Y = np.vstack((Y, extras_Y)) metas = np.vstack((metas, extras.metas)) table = Orange.data.Table.from_numpy(domain, X, Y, metas) table.name = getattr(self.data, 'name', '') @@ -558,28 +595,27 @@ def _join_table_by_indices(self, reduced_extra, lefti, righti, rightu): return table - def _domain_rename_duplicates(self, domain): + def _domain_rename_duplicates(self, attributes, class_vars, metas): """Check for duplicate variable names in domain. If any, rename the variables, by replacing them with new ones (names are appended a number). """ - attrs, cvars, metas = [], [], [] - n_attrs, n_cvars, n_metas = (len(domain.attributes), - len(domain.class_vars), len(domain.metas)) - lists = [attrs] * n_attrs + [cvars] * n_cvars + [metas] * n_metas + attrs, cvars, mets = [], [], [] + n_attrs, n_cvars, n_metas = len(attributes), len(class_vars), len(metas) + lists = [attrs] * n_attrs + [cvars] * n_cvars + [mets] * n_metas - variables = domain.variables + domain.metas - proposed_names = [m.name for m in variables] + all_vars = attributes + class_vars + metas + proposed_names = [m.name for m in all_vars] unique_names = get_unique_names_duplicates(proposed_names) duplicates = set() for p_name, u_name, var, c in zip(proposed_names, unique_names, - variables, lists): + all_vars, lists): if p_name != u_name: duplicates.add(p_name) var = var.copy(name=u_name) c.append(var) if duplicates: self.Warning.renamed_vars(", ".join(duplicates)) - return Orange.data.Domain(attrs, cvars, metas) + return Orange.data.Domain(attrs, cvars, mets) @staticmethod def _join_array_by_indices(left, right, lefti, righti, string_cols=None): diff --git a/Orange/widgets/data/tests/test_owmergedata.py b/Orange/widgets/data/tests/test_owmergedata.py index be3873242f4..00a910c846f 100644 --- a/Orange/widgets/data/tests/test_owmergedata.py +++ b/Orange/widgets/data/tests/test_owmergedata.py @@ -1,5 +1,7 @@ # Test methods with long descriptive names can omit docstrings # pylint: disable=missing-docstring +# There are never too many tests, so: +# pylint: disable=too-many-lines,too-many-public-methods from itertools import chain import unittest @@ -447,10 +449,15 @@ def test_output_merge_by_ids_inner(self): def test_output_merge_by_ids_outer(self): """Check output for merging option 'Concatenate tables, merge rows' by Source position (index)""" - domain = self.dataA.domain + domainA = self.dataA.domain + values = domainA.class_var.values + domain = Domain(domainA.attributes, + (DiscreteVariable("clsA (1)", values), + DiscreteVariable("clsA (2)", values)), + domainA.metas) result = Table(domain, np.array([[1, 1], [2, 0], [3, np.nan], [np.nan, 0]]), - np.array([1, 2, np.nan, 0]), + np.array([[1, 1], [2, 2], [np.nan, np.nan], [np.nan, 0]]), np.array([[1.0, "m2"], [np.nan, "m3"], [0.0, ""], [np.nan, "m1"]]).astype(object)) self.widget.attr_boxes.set_state([(INSTANCEID, INSTANCEID)]) @@ -463,6 +470,33 @@ def test_output_merge_by_ids_outer(self): np.testing.assert_equal( out.ids, np.hstack((self.dataA.ids[1:], self.dataA.ids[:1]))) + def test_output_merge_by_ids_outer_single_class(self): + """Check output for merging option 'Concatenate tables, merge rows' by + Source position (index) when all extra rows are matched and there is + only a single class variable in the output""" + domainA = self.dataA.domain + values = domainA.class_var.values + domain = Domain(domainA.attributes, + DiscreteVariable("clsA", values), + domainA.metas) + result = Table(domain, + np.array([[0, 0], [1, 1], [2, 0], [3, np.nan]]), + np.array([[0], [1], [2], [np.nan]]), + np.array([[0.0, "m1"], [1.0, "m2"], [np.nan, "m3"], + [0.0, ""]]).astype(object)) + self.widget.attr_boxes.set_state([(INSTANCEID, INSTANCEID)]) + self.widget.merging = 2 + self.widget.controls.merging.buttons[self.widget.OuterJoin].click() + # When Y is a single column, Table.Y returns a vector, not a 2d array, + # which cause an exception in outer_join's vstack for Y if extra data + # has no unmatched rows. + # This test also checks this condition. + self.send_signal(self.widget.Inputs.data, self.dataA[:, [0, "clsA", -1]]) + self.send_signal(self.widget.Inputs.extra_data, self.dataA[:3, [1, -2]]) + out = self.get_output(self.widget.Outputs.data) + self.assertTablesEqual(out, result) + np.testing.assert_equal(out.ids, self.dataA.ids) + def test_output_merge_by_index_left(self): """Check output for merging option 'Append columns from Extra Data' by Position (index)""" @@ -605,7 +639,7 @@ def test_output_merge_by_attribute_outer_same_attr(self): np.testing.assert_equal( out.X, np.array([[0, 6], [1, 4], [2, 7], [np.nan, 5]])) - self.assertEqual(" ".join(out.metas.flatten()), "a b c d") + self.assertEqual(" ".join(out.metas.flatten()), "a a b b c c d") def test_output_merge_by_class_left(self): """Check output for merging option 'Append columns from Extra Data' by @@ -953,6 +987,51 @@ def test_duplicate_names(self): self.assertListEqual([m.name for m in merged_data.domain.metas], ["Feature (1)", "Feature (2)"]) + def test_keep_non_duplicate_variables(self): + domain = Domain([ContinuousVariable("A"), ContinuousVariable("B")]) + data = Table(domain, np.array([[0., 0], [0, 1]])) + extra_data = Table(domain, np.array([[0., 1], [0, 1]])) + self.send_signal(self.widget.Inputs.data, data) + self.send_signal(self.widget.Inputs.extra_data, extra_data) + merged_data = self.get_output(self.widget.Outputs.data) + self.assertListEqual([m.name for m in merged_data.domain.variables], + ["A", "B (1)", "B (2)"]) + + def test_keep_non_duplicate_variables_missing_rows(self): + c = DiscreteVariable("C", values=["a", "b", "c"]) + domain = Domain([ContinuousVariable("A"), ContinuousVariable("B"), c]) + data = Table(domain, np.array([[0., 0, 0], [1, 1, 1]])) + extra_data = Table(domain, np.array([[0., 1, 1], [0, 1, 2]])) + self.send_signal(self.widget.Inputs.data, data) + self.send_signal(self.widget.Inputs.extra_data, extra_data) + self.widget.attr_boxes.set_state([(c, c)]) + + # Only one row is matched; A has different values and it's duplicated, + # and B has the same values, so we get only one copy + self.widget.merging = self.widget.InnerJoin + self.widget.unconditional_commit() + merged_data = self.get_output(self.widget.Outputs.data) + self.assertListEqual([m.name for m in merged_data.domain.variables], + ["A (1)", "B", "C", "A (2)"]) + + # Table has additional rows; keep all columns + self.widget.merging = self.widget.OuterJoin + self.widget.unconditional_commit() + merged_data = self.get_output(self.widget.Outputs.data) + self.assertListEqual( + [m.name for m in merged_data.domain.variables], + ["A (1)", "B (1)", "C (1)", "A (2)", "B (2)", "C (2)"]) + + # First row is unmatched, data for B(2) is missing, but attribute + # shouldn't be added + extra_data = Table(domain, np.array([[1., 1, 1], [0, 1, 2]])) + self.send_signal(self.widget.Inputs.extra_data, extra_data) + self.widget.merging = self.widget.LeftJoin + self.widget.unconditional_commit() + merged_data = self.get_output(self.widget.Outputs.data) + self.assertListEqual([m.name for m in merged_data.domain.variables], + ["A", "B", "C"]) + if __name__ == "__main__": unittest.main()