Skip to content

Commit

Permalink
Merge pull request #6604 from janezd/edit-domain-unlink-forward
Browse files Browse the repository at this point in the history
Edit Domain: allow the (existing) checkbox to also control removal of newly created compute value
  • Loading branch information
JakaKokosar authored Nov 13, 2023
2 parents 2e79655 + d36df7f commit 5f02359
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 79 deletions.
67 changes: 29 additions & 38 deletions Orange/widgets/data/oweditdomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ class Categorical(
_DataType, NamedTuple("Categorical", [
("name", str),
("categories", Tuple[str, ...]),
("annotations", AnnotationsType),
("linked", bool)
("annotations", AnnotationsType)
])): pass


Expand All @@ -117,24 +116,21 @@ class Real(
("name", str),
# a precision (int, and a format specifier('f', 'g', or '')
("format", Tuple[int, str]),
("annotations", AnnotationsType),
("linked", bool)
("annotations", AnnotationsType)
])): pass


class String(
_DataType, NamedTuple("String", [
("name", str),
("annotations", AnnotationsType),
("linked", bool)
("annotations", AnnotationsType)
])): pass


class Time(
_DataType, NamedTuple("Time", [
("name", str),
("annotations", AnnotationsType),
("linked", bool)
("annotations", AnnotationsType)
])): pass


Expand Down Expand Up @@ -248,7 +244,7 @@ def __call__(self, vector: DataVector) -> StringVector:
if isinstance(var, String):
return vector
return StringVector(
String(var.name, var.annotations, False),
String(var.name, var.annotations),
lambda: as_string(vector.data()),
)

Expand All @@ -268,19 +264,19 @@ def data() -> MArray:
a = categorical_to_string_vector(d, var.values)
return MArray(as_float_or_nan(a, where=a.mask), mask=a.mask)
return RealVector(
Real(var.name, (6, 'g'), var.annotations, var.linked), data
Real(var.name, (6, 'g'), var.annotations), data
)
elif isinstance(var, Time):
return RealVector(
Real(var.name, (6, 'g'), var.annotations, var.linked),
Real(var.name, (6, 'g'), var.annotations),
lambda: vector.data().astype(float)
)
elif isinstance(var, String):
def data():
s = vector.data()
return MArray(as_float_or_nan(s, where=s.mask), mask=s.mask)
return RealVector(
Real(var.name, (6, "g"), var.annotations, var.linked), data
Real(var.name, (6, "g"), var.annotations), data
)
raise AssertionError

Expand All @@ -296,7 +292,7 @@ def __call__(self, vector: DataVector) -> CategoricalVector:
if isinstance(var, (Real, Time, String)):
data, values = categorical_from_vector(vector.data())
return CategoricalVector(
Categorical(var.name, values, var.annotations, var.linked),
Categorical(var.name, values, var.annotations),
lambda: data
)
raise AssertionError
Expand All @@ -310,7 +306,7 @@ def __call__(self, vector: DataVector) -> TimeVector:
return vector
elif isinstance(var, Real):
return TimeVector(
Time(var.name, var.annotations, var.linked),
Time(var.name, var.annotations),
lambda: vector.data().astype("M8[us]")
)
elif isinstance(var, Categorical):
Expand All @@ -320,15 +316,15 @@ def data():
dt = pd.to_datetime(s, errors="coerce").values.astype("M8[us]")
return MArray(dt, mask=d.mask)
return TimeVector(
Time(var.name, var.annotations, var.linked), data
Time(var.name, var.annotations), data
)
elif isinstance(var, String):
def data():
s = vector.data()
dt = pd.to_datetime(s, errors="coerce").values.astype("M8[us]")
return MArray(dt, mask=s.mask)
return TimeVector(
Time(var.name, var.annotations, var.linked), data
Time(var.name, var.annotations), data
)
raise AssertionError

Expand Down Expand Up @@ -636,7 +632,7 @@ def set_data(self, var, transform=()):
else:
self.add_label_action.actionGroup().setEnabled(False)

self.unlink_var_cb.setDisabled(var is None or not var.linked)
self.unlink_var_cb.setDisabled(var is None)

def get_data(self):
"""Retrieve the modified variable.
Expand All @@ -650,7 +646,7 @@ def get_data(self):
tr.append(Rename(name))
if self.var.annotations != labels:
tr.append(Annotate(labels))
if self.var.linked and self.unlink_var_cb.isChecked():
if self.unlink_var_cb.isChecked():
tr.append(Unlink())
return self.var, tr

Expand Down Expand Up @@ -2033,7 +2029,7 @@ class Outputs:
class Error(widget.OWWidget.Error):
duplicate_var_name = widget.Msg("A variable name is duplicated.")

settings_version = 3
settings_version = 4

_domain_change_hints = Setting({}, schema_only=True)
_merge_dialog_settings = Setting({}, schema_only=True)
Expand Down Expand Up @@ -2324,8 +2320,7 @@ def state(i):
state = [state(i) for i in range(model.rowCount())]
input_vars = data.domain.variables + data.domain.metas
if self.output_table_name in ("", data.name) \
and not any(requires_transform(var, trs)
for var, (_, trs) in zip(input_vars, state)):
and all(tr is None or not tr for _, tr in state):
self.Outputs.data.send(data)
return

Expand Down Expand Up @@ -2476,6 +2471,13 @@ def migrate_settings(cls, settings, version):
settings["_domain_change_hints"] = hints
del settings["context_settings"]

if version < 4 and "_domain_change_hints" in settings:
settings["_domain_change_hints"] = {
(name, desc[:-1]): trs
for (name, desc), trs in settings["_domain_change_hints"].items()
}


def enumerate_columns(
table: Orange.data.Table
) -> Iterable[Tuple[int, str, Orange.data.Variable, Callable[[], ndarray]]]:
Expand Down Expand Up @@ -2649,15 +2651,14 @@ def abstract(var):
(key, str(value))
for key, value in var.attributes.items()
))
linked = var.compute_value is not None
if isinstance(var, Orange.data.DiscreteVariable):
return Categorical(var.name, tuple(var.values), annotations, linked)
return Categorical(var.name, tuple(var.values), annotations)
elif isinstance(var, Orange.data.TimeVariable):
return Time(var.name, annotations, linked)
return Time(var.name, annotations)
elif isinstance(var, Orange.data.ContinuousVariable):
return Real(var.name, (var.number_of_decimals, 'f'), annotations, linked)
return Real(var.name, (var.number_of_decimals, 'f'), annotations)
elif isinstance(var, Orange.data.StringVariable):
return String(var.name, annotations, linked)
return String(var.name, annotations)
else:
raise TypeError

Expand Down Expand Up @@ -2687,23 +2688,13 @@ def apply_transform(var, table, trs):


def requires_unlink(var: Orange.data.Variable, trs: List[Transform]) -> bool:
# Variable is only unlinked if it has compute_value or if it has other
# transformations (that might had added compute_value)
# Variable is only unlinked if it has compute_value or if it has other
# transformations (that might have added compute_value)
return trs is not None \
and any(isinstance(tr, Unlink) for tr in trs) \
and (var.compute_value is not None or len(trs) > 1)


def requires_transform(var: Orange.data.Variable, trs: List[Transform]) -> bool:
# Unlink is treated separately: Unlink is required only if the variable
# has compute_value. Hence tranform is required if it has any
# transformations other than Unlink, or if unlink is indeed required.
return trs is not None and (
not all(isinstance(tr, Unlink) for tr in trs)
or requires_unlink(var, trs)
)


@singledispatch
def apply_transform_var(var, trs):
# type: (Orange.data.Variable, List[Transform]) -> Orange.data.Variable
Expand Down
Loading

0 comments on commit 5f02359

Please sign in to comment.