Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Edit Domain: Add option to unlink variable from source variable #4863

Merged
merged 1 commit into from
Jun 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 103 additions & 43 deletions Orange/widgets/data/oweditdomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class Categorical(
("name", str),
("categories", Tuple[str, ...]),
("annotations", AnnotationsType),
("linked", bool)
])): pass


Expand All @@ -104,20 +105,23 @@ class Real(
# a precision (int), and a format specifier ('f', 'g', or '')
("format", Tuple[int, str]),
("annotations", AnnotationsType),
("linked", bool)
])): pass


class String(_DataType, NamedTuple("String", [
    ("name", str),
    ("annotations", AnnotationsType),
    ("linked", bool),
])):
    """Abstract description of a string variable.

    Fields
    ------
    name : str
        The variable's name.
    annotations : AnnotationsType
        The variable's annotations.
    linked : bool
        Whether the described variable is linked to a source variable;
        `abstract` fills this in from ``var.compute_value is not None``.
    """


class Time(_DataType, NamedTuple("Time", [
    ("name", str),
    ("annotations", AnnotationsType),
    ("linked", bool),
])):
    """Abstract description of a time variable.

    Fields
    ------
    name : str
        The variable's name.
    annotations : AnnotationsType
        The variable's annotations.
    linked : bool
        Whether the described variable is linked to a source variable;
        `abstract` fills this in from ``var.compute_value is not None``.
    """


Expand Down Expand Up @@ -175,10 +179,14 @@ def __call__(self, var):
return var._replace(annotations=self.annotations)


Transform = Union[Rename, CategoriesMapping, Annotate]
TransformTypes = (Rename, CategoriesMapping, Annotate)
class Unlink(_DataType, namedtuple("Unlink", [])):
    """Unlink variable from its source, that is, remove compute_value.

    A field-less marker transform: its presence in a variable's list of
    transforms is the whole request, so every instance is an empty tuple.
    """

CategoricalTransformTypes = (CategoriesMapping, )

Transform = Union[Rename, CategoriesMapping, Annotate, Unlink]
TransformTypes = (Rename, CategoriesMapping, Annotate, Unlink)

CategoricalTransformTypes = (CategoriesMapping, Unlink)


# Reinterpret vector transformations.
Expand Down Expand Up @@ -221,7 +229,7 @@ def __call__(self, vector: DataVector) -> StringVector:
if isinstance(var, String):
return vector
return StringVector(
String(var.name, var.annotations),
String(var.name, var.annotations, False),
lambda: as_string(vector.data()),
)

Expand All @@ -241,19 +249,19 @@ def data() -> MArray:
a = categorical_to_string_vector(d, var.values)
return MArray(as_float_or_nan(a, where=a.mask), mask=a.mask)
return RealVector(
Real(var.name, (6, 'g'), var.annotations), data
Real(var.name, (6, 'g'), var.annotations, var.linked), data
)
elif isinstance(var, Time):
return RealVector(
Real(var.name, (6, 'g'), var.annotations),
Real(var.name, (6, 'g'), var.annotations, var.linked),
lambda: vector.data().astype(float)
)
elif isinstance(var, String):
def data():
s = vector.data()
return MArray(as_float_or_nan(s, where=s.mask), mask=s.mask)
return RealVector(
Real(var.name, (6, "g"), var.annotations), data
Real(var.name, (6, "g"), var.annotations, var.linked), data
)
raise AssertionError

Expand All @@ -266,22 +274,10 @@ def __call__(self, vector: DataVector) -> CategoricalVector:
var, _ = vector
if isinstance(var, Categorical):
return vector
if isinstance(var, Real):
data, values = categorical_from_vector(vector.data())
return CategoricalVector(
Categorical(var.name, values, var.annotations),
lambda: data
)
elif isinstance(var, Time):
if isinstance(var, (Real, Time, String)):
data, values = categorical_from_vector(vector.data())
return CategoricalVector(
Categorical(var.name, values, var.annotations),
lambda: data
)
elif isinstance(var, String):
data, values = categorical_from_vector(vector.data())
return CategoricalVector(
Categorical(var.name, values, var.annotations),
Categorical(var.name, values, var.annotations, var.linked),
lambda: data
)
raise AssertionError
Expand All @@ -295,7 +291,7 @@ def __call__(self, vector: DataVector) -> TimeVector:
return vector
elif isinstance(var, Real):
return TimeVector(
Time(var.name, var.annotations),
Time(var.name, var.annotations, var.linked),
lambda: vector.data().astype("M8[us]")
)
elif isinstance(var, Categorical):
Expand All @@ -305,15 +301,15 @@ def data():
dt = pd.to_datetime(s, errors="coerce").values.astype("M8[us]")
return MArray(dt, mask=d.mask)
return TimeVector(
Time(var.name, var.annotations), data
Time(var.name, var.annotations, var.linked), data
)
elif isinstance(var, String):
def data():
s = vector.data()
dt = pd.to_datetime(s, errors="coerce").values.astype("M8[us]")
return MArray(dt, mask=s.mask)
return TimeVector(
Time(var.name, var.annotations), data
Time(var.name, var.annotations, var.linked), data
)
raise AssertionError

Expand Down Expand Up @@ -532,6 +528,17 @@ def __init__(self, parent=None, **kwargs):
)
form.addRow("Name:", self.name_edit)

self.unlink_var_cb = QCheckBox(
"Unlink variable from its source variable", self,
toolTip="Make Orange forget that the variable is derived from "
"another.\n"
"Use this for instance when you want to consider variables "
"with the same name but from different sources as the same "
"variable."
)
self.unlink_var_cb.toggled.connect(self._set_unlink)
form.addRow("", self.unlink_var_cb)

vlayout = QVBoxLayout(margin=0, spacing=1)
self.labels_edit = view = QTreeView(
objectName="annotation-pairs-edit",
Expand Down Expand Up @@ -616,17 +623,23 @@ def set_data(self, var, transform=()):
if var is not None:
name = var.name
annotations = var.annotations
unlink = False
for tr in transform:
if isinstance(tr, Rename):
name = tr.name
elif isinstance(tr, Annotate):
annotations = tr.annotations
elif isinstance(tr, Unlink):
unlink = True
self.name_edit.setText(name)
self.labels_model.set_dict(dict(annotations))
self.add_label_action.actionGroup().setEnabled(True)
self.unlink_var_cb.setChecked(unlink)
else:
self.add_label_action.actionGroup().setEnabled(False)

self.unlink_var_cb.setDisabled(var is None or not var.linked)

def get_data(self):
"""Retrieve the modified variable.
"""
Expand All @@ -639,6 +652,8 @@ def get_data(self):
tr.append(Rename(name))
if self.var.annotations != labels:
tr.append(Annotate(labels))
if self.var.linked and self.unlink_var_cb.isChecked():
tr.append(Unlink())
return self.var, tr

def clear(self):
Expand All @@ -647,6 +662,7 @@ def clear(self):
self.var = None
self.name_edit.setText("")
self.labels_model.setRowCount(0)
self.unlink_var_cb.setChecked(False)

@Slot()
def on_name_changed(self):
Expand All @@ -661,6 +677,10 @@ def on_label_selection_changed(self):
selected = self.labels_edit.selectionModel().selectedRows()
self.remove_label_action.setEnabled(bool(len(selected)))

def _set_unlink(self, unlink):
    """Apply the requested unlink state and announce the edit.

    Sets the checked state of ``unlink_var_cb`` to `unlink`, then emits
    ``variable_changed`` so listeners can pick up the modified variable.
    Presumably invoked through the ``unlink_var_cb.toggled`` connection
    made in the editor's constructor -- confirm against the full class.
    """
    checkbox = self.unlink_var_cb
    checkbox.setChecked(unlink)
    self.variable_changed.emit()


class GroupItemsDialog(QDialog):
"""
Expand Down Expand Up @@ -1157,7 +1177,7 @@ def __init__(self, *args, **kwargs):
hlayout.addStretch(10)
vlayout.addLayout(hlayout)

form.insertRow(1, "Values:", vlayout)
form.insertRow(2, "Values:", vlayout)

QWidget.setTabOrder(self.name_edit, self.values_edit)
QWidget.setTabOrder(self.values_edit, button1)
Expand Down Expand Up @@ -2030,23 +2050,32 @@ def state(i):
model.data(midx, TransformRole))

state = [state(i) for i in range(model.rowCount())]
if all(tr is None or not tr for _, tr in state) \
and self.output_table_name in ("", data.name):
input_vars = data.domain.variables + data.domain.metas
if self.output_table_name in ("", data.name) \
and not any(requires_transform(var, trs)
for var, (_, trs) in zip(input_vars, state)):
self.Outputs.data.send(data)
self.info.set_output_summary(len(data),
format_summary_details(data))
return

output_vars = []
input_vars = data.domain.variables + data.domain.metas
assert all(v_.vtype.name == v.name
for v, (v_, _) in zip(input_vars, state))
output_vars = []
unlinked_vars = []
unlink_domain = False
for (_, tr), v in zip(state, input_vars):
if tr:
var = apply_transform(v, data, tr)
if requires_unlink(v, tr):
unlinked_var = var.copy(compute_value=None)
unlink_domain = True
else:
unlinked_var = var
else:
var = v
unlinked_var = var = v
output_vars.append(var)
unlinked_vars.append(unlinked_var)

if len(output_vars) != len({v.name for v in output_vars}):
self.Error.duplicate_var_name()
Expand All @@ -2058,15 +2087,23 @@ def state(i):
nx = len(domain.attributes)
ny = len(domain.class_vars)

Xs = output_vars[:nx]
Ys = output_vars[nx: nx + ny]
Ms = output_vars[nx + ny:]
# Move non primitive Xs, Ys to metas (if they were changed)
Ms += [v for v in Xs + Ys if not v.is_primitive()]
Xs = [v for v in Xs if v.is_primitive()]
Ys = [v for v in Ys if v.is_primitive()]
domain = Orange.data.Domain(Xs, Ys, Ms)
def construct_domain(vars_list):
# Move non primitive Xs, Ys to metas (if they were changed)
Xs = [v for v in vars_list[:nx] if v.is_primitive()]
Ys = [v for v in vars_list[nx: nx + ny] if v.is_primitive()]
Ms = vars_list[nx + ny:] + \
[v for v in vars_list[:nx + ny] if not v.is_primitive()]
return Orange.data.Domain(Xs, Ys, Ms)

domain = construct_domain(output_vars)
new_data = data.transform(domain)
if unlink_domain:
unlinked_domain = construct_domain(unlinked_vars)
new_data = new_data.from_numpy(
unlinked_domain,
new_data.X, new_data.Y, new_data.metas, new_data.W,
new_data.attributes, new_data.ids
)
if self.output_table_name:
new_data.name = self.output_table_name
self.Outputs.data.send(new_data)
Expand Down Expand Up @@ -2236,7 +2273,7 @@ def i(text):
def text(text):
return "<span>{}</span>".format(escape(text))
assert trs
rename = annotate = catmap = None
rename = annotate = catmap = unlink = None
reinterpret = None

for tr in trs:
Expand All @@ -2246,6 +2283,8 @@ def text(text):
annotate = tr
elif isinstance(tr, CategoriesMapping):
catmap = tr
elif isinstance(tr, Unlink):
unlink = tr
elif isinstance(tr, ReinterpretTransformTypes):
reinterpret = tr

Expand All @@ -2258,6 +2297,8 @@ def text(text):
header = "{} → {}".format(var.name, rename.name)
else:
header = var.name
if unlink is not None:
header += "(unlinked from source)"

values_section = None
if catmap is not None:
Expand Down Expand Up @@ -2323,14 +2364,15 @@ def abstract(var):
(key, str(value))
for key, value in var.attributes.items()
))
linked = var.compute_value is not None
if isinstance(var, Orange.data.DiscreteVariable):
return Categorical(var.name, tuple(var.values), annotations)
return Categorical(var.name, tuple(var.values), annotations, linked)
elif isinstance(var, Orange.data.TimeVariable):
return Time(var.name, annotations)
return Time(var.name, annotations, linked)
elif isinstance(var, Orange.data.ContinuousVariable):
return Real(var.name, (var.number_of_decimals, 'f'), annotations)
return Real(var.name, (var.number_of_decimals, 'f'), annotations, linked)
elif isinstance(var, Orange.data.StringVariable):
return String(var.name, annotations)
return String(var.name, annotations, linked)
else:
raise TypeError

Expand Down Expand Up @@ -2359,6 +2401,24 @@ def apply_transform(var, table, trs):
return var


def requires_unlink(var: Orange.data.Variable, trs: List[Transform]) -> bool:
    """Tell whether `var` actually has to be unlinked given transforms `trs`.

    Returns True only when `trs` contains an `Unlink` and there is
    something to unlink: either `var` already has a ``compute_value``, or
    `trs` holds further transformations (which might have added one).
    Returns False when `trs` is None.
    """
    if trs is None:
        return False
    if not any(isinstance(tr, Unlink) for tr in trs):
        # Unlinking was never requested.
        return False
    # An Unlink on its own is a no-op for a variable without compute_value.
    return var.compute_value is not None or len(trs) > 1


def requires_transform(var: Orange.data.Variable, trs: List[Transform]) -> bool:
    """Tell whether applying `trs` to `var` would change anything.

    `Unlink` is treated separately, because it only has an effect when
    `requires_unlink` says so. A transform is therefore required if `trs`
    contains any transformation other than `Unlink`, or if the unlink
    itself is required. Returns False when `trs` is None or empty.
    """
    if trs is None:
        return False
    only_unlinks = all(isinstance(tr, Unlink) for tr in trs)
    return not only_unlinks or requires_unlink(var, trs)


@singledispatch
def apply_transform_var(var, trs):
# type: (Orange.data.Variable, List[Transform]) -> Orange.data.Variable
Expand Down
Loading