Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Discretize: Simplify interface, add nicer binning #5919

Merged
merged 9 commits into from
Apr 25, 2022
208 changes: 174 additions & 34 deletions Orange/preprocess/discretize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import scipy.sparse as sp

from Orange.data import DiscreteVariable, Domain
from Orange.data import DiscreteVariable, Domain, TimeVariable, Table
from Orange.data.sql.table import SqlTable
from Orange.statistics import distribution, contingency, util as ut
from Orange.statistics.basic_stats import BasicStats
Expand Down Expand Up @@ -58,13 +58,17 @@ def _fmt_interval(low, high, formatter):
return f"{formatter(low)} - {formatter(high)}"

@classmethod
def create_discretized_var(cls, var, points):
def fmt(val):
sval = var.str_val(val)
# For decimal numbers, remove trailing 0's and . if no decimals left
if re.match(r"^\d+\.\d+", sval):
return sval.rstrip("0").rstrip(".")
return sval
def create_discretized_var(cls, var, points, ndigits=None):
if ndigits is None:
def fmt(val):
sval = var.str_val(val)
# For decimal numbers, remove trailing 0's and . if no decimals left
if re.match(r"^\d+\.\d+", sval):
return sval.rstrip("0").rstrip(".")
return sval
else:
def fmt(val):
return f"{val:.{ndigits}f}"

lpoints = list(points)
if lpoints:
Expand Down Expand Up @@ -96,8 +100,8 @@ def __init__(self, var, points):
self.points = points

def __call__(self):
return 'width_bucket(%s, ARRAY%s::double precision[])' % (
self.var.to_sql(), str(self.points))
return f'width_bucket({self.var.to_sql()}, ' \
f'ARRAY{str(self.points)}::double precision[])'


class SingleValueSql:
Expand Down Expand Up @@ -163,30 +167,174 @@ def __init__(self, n=4):
self.n = n

# noinspection PyProtectedMember
def __call__(self, data, attribute, fixed=None):
def __call__(self, data: Table, attribute, fixed=None):
if fixed:
min, max = fixed[attribute.name]
points = self._split_eq_width(min, max)
mn, mx = fixed[attribute.name]
points = self._split_eq_width(mn, mx)
else:
if type(data) == SqlTable:
stats = BasicStats(data, attribute)
points = self._split_eq_width(stats.min, stats.max)
else:
values = data[:, attribute]
values = values.X if values.X.size else values.Y
values, _ = data.get_column_view(attribute)
if values.size:
min, max = ut.nanmin(values), ut.nanmax(values)
points = self._split_eq_width(min, max)
mn, mx = ut.nanmin(values), ut.nanmax(values)
points = self._split_eq_width(mn, mx)
else:
points = []
return Discretizer.create_discretized_var(
data.domain[attribute], points)

def _split_eq_width(self, min, max):
if np.isnan(min) or np.isnan(max) or min == max:
def _split_eq_width(self, mn, mx):
if np.isnan(mn) or np.isnan(mx) or mn == mx:
return []
dif = (max - min) / self.n
return [min + (i + 1) * dif for i in range(self.n - 1)]
dif = (mx - mn) / self.n
return [mn + i * dif for i in range(1, self.n)]


class TooManyIntervals(ValueError):
    """Raised when a fixed-width discretization would need too many intervals."""


class FixedWidth(Discretization):
    """Discretization into intervals of a fixed width.

    Thresholds are placed at integer multiples of `width` that lie above the
    smallest and up to the largest non-missing value of the column.

    .. attribute:: width

        Width of each interval.

    .. attribute:: digits

        Number of decimals used when formatting thresholds; passed on as
        `ndigits` to `Discretizer.create_discretized_var` (default: `None`).
    """
    def __init__(self, width, digits=None):
        super().__init__()
        self.width = width
        self.digits = digits

    def __call__(self, data: Table, attribute):
        """Return the discretized variable for `attribute` in `data`.

        Raises:
            TooManyIntervals: if more than 100 thresholds would be needed.
        """
        column, _ = data.get_column_view(attribute)
        thresholds = []
        if column.size:
            lowest, highest = ut.nanmin(column), ut.nanmax(column)
            # A NaN minimum means there are no defined values to bin
            if not np.isnan(lowest):
                # Index of the first multiple of `width` above each extreme
                first = int(1 + np.floor(lowest / self.width))
                stop = int(1 + np.floor(highest / self.width))
                if stop - first > 100:
                    raise TooManyIntervals
                thresholds = [k * self.width for k in range(first, stop)]
        return Discretizer.create_discretized_var(
            data.domain[attribute], thresholds, ndigits=self.digits)


class FixedTimeWidth(Discretization):
    """Discretization of a time variable into intervals of a fixed duration.

    .. attribute:: width

        Number of time units per interval.

    .. attribute:: unit

        Index of the time unit: 0=year, 1=month, 2=day, 3=hour, 4=minute,
        5=second. For weeks, use day with a width of 7.
    """
    def __init__(self, width, unit):
        super().__init__()
        self.width = width
        self.unit = unit

    def __call__(self, data: Table, attribute):
        """Return the discretized variable for `attribute` in `data`.

        Raises:
            TooManyIntervals: if `_time_range` returns `None` for the
                requested unit and width.
        """
        # strftime pattern for threshold labels, indexed by `self.unit`
        label_formats = [
            "%Y",
            "%y %b",
            "%y %b %d",
            "%y %b %d %H:%M",
            "%y %b %d %H:%M",
            "%H:%M:%S",
        ]
        label_format = label_formats[self.unit]

        column, _ = data.get_column_view(attribute)
        boundaries = []
        if column.size:
            earliest, latest = ut.nanmin(column), ut.nanmax(column)
            # A NaN minimum means there are no defined values to bin
            if not np.isnan(earliest):
                start = utc_from_timestamp(earliest).timetuple()
                end = utc_from_timestamp(latest).timetuple()
                candidates = _time_range(
                    start, end, self.unit, self.width, 0, 100)
                if candidates is None:
                    raise TooManyIntervals
                # Pad each tuple to a full struct_time; the first and the
                # last element are dropped (presumably the outer boundaries
                # of the range -- confirm against `_time_range`)
                boundaries = [time.struct_time(stamp + (0, 0, 0))
                              for stamp in candidates][1:-1]

        points = np.array([calendar.timegm(stamp) for stamp in boundaries])
        interval_labels = _simplified_time_intervals(
            [time.strftime(label_format, stamp) for stamp in boundaries])
        var = data.domain[attribute]
        return DiscreteVariable(name=var.name, values=interval_labels,
                                compute_value=Discretizer(var, points),
                                sparse=var.sparse)


def _simplified_time_intervals(labels):
def no_common(a, b):
for i, pa, pb in zip(count(), a, b):
if pa != pb:
if common + i == 2:
i -= 1
return b[i:]
# can't come here (unless a == b?!)
return b # pragma: no cover


if not labels:
return []
common = 100
labels = [label.split() for label in labels]
for common, parts in enumerate(map(set, zip(*labels))):
if len(parts) > 1:
break
if common == 2: # If we keep days, we must also keep months
common = 1
labels = [label[common:] for label in labels]
join = " ".join
return [f"< {join(labels[0])}"] + [
f"{join(low)} - {join(no_common(low, high))}"
for low, high in zip(labels, labels[1:])
] + [f"≥ {join(labels[-1])}"]



class Binning(Discretization):
    """Discretization with nice thresholds

    This class creates different decimal or time binnings and picks the one
    in which the number of intervals is closest to the desired number.
    Closeness is the absolute difference between the desired number and the
    number of bins in a candidate binning; ties are broken in favour of the
    binning with more bins.

    .. attribute:: n

        Desired number of bins (default: 4).
    """
    def __init__(self, n=4):
        # Call the base initializer, as the other discretizations here do
        super().__init__()
        self.n = n

    def __call__(self, data: Table, attribute):
        """Return the discretized variable for `attribute` in `data`.

        Time variables are binned with `time_binnings`, all others with
        `decimal_binnings`. A column without values yields a variable
        without thresholds.
        """
        attribute = data.domain[attribute]
        values, _ = data.get_column_view(attribute)
        values = values.astype(float)
        if not values.size:
            return self._create_binned_var(None, attribute)

        # `attribute` is already the variable from the domain; no need to
        # look it up again
        if isinstance(attribute, TimeVariable):
            binnings = time_binnings(values)
        else:
            binnings = decimal_binnings(values)
        return self._create_binned_var(binnings, attribute)

    def _create_binned_var(self, binnings, variable):
        """Pick the binning closest to `self.n` bins and build the variable.

        Args:
            binnings (list of BinDefinition or None): candidate binnings
            variable: the variable being discretized

        Returns:
            DiscreteVariable: discretized variable; it has no thresholds if
            there are no candidates or if the chosen binning has no
            thresholds besides the bottom and the top one.
        """
        if not binnings:
            return Discretizer.create_discretized_var(variable, [])

        # If self.n is 2, require two intervals (one threshold, excluding top
        # and bottom), else require at least three intervals
        # ... unless this is the only option, in which case we use it
        # Break ties in favour of more bins
        min_inner_thresholds = 1 + (self.n != 2)
        candidates = (
            binning for binning in binnings
            if len(binning.thresholds) - 2 >= min_inner_thresholds)
        binning = min(
            candidates,
            key=lambda binning: (abs(self.n - (len(binning.short_labels) - 1)),
                                 -len(binning.short_labels)),
            default=binnings[-1])

        # Only the bottom and the top threshold: nothing to split on
        if len(binning.thresholds) == 2:
            return Discretizer.create_discretized_var(variable, [])

        # Labels of the inner thresholds, without bottom and top
        blabels = binning.labels[1:-1]
        labels = [f"< {blabels[0]}"] + [
            f"{lab1} - {lab2}" for lab1, lab2 in zip(blabels, blabels[1:])
        ] + [f"≥ {blabels[-1]}"]

        discretizer = Discretizer(variable, list(binning.thresholds[1:-1]))
        dvar = DiscreteVariable(name=variable.name, values=labels,
                                compute_value=discretizer,
                                sparse=variable.sparse)
        dvar.source_variable = variable
        return dvar


class BinDefinition(NamedTuple):
Expand Down Expand Up @@ -234,7 +382,7 @@ def decimal_binnings(
data, *, min_width=0, min_bins=2, max_bins=50,
min_unique=5, add_unique=0,
factors=(0.01, 0.02, 0.025, 0.05, 0.1, 0.2, 0.25, 0.5, 1, 2, 5, 10, 20),
label_fmt="%g"):
label_fmt="%g") -> List[BinDefinition]:
"""
Find a set of nice splits of data into bins

Expand Down Expand Up @@ -283,22 +431,13 @@ def decimal_binnings(
or a function for formatting thresholds (e.g. var.str_val)

Returns:
bin_boundaries (list of np.ndarray): a list of bin boundaries,
including the top boundary of the last interval, hence the list
size equals the number bins + 1. These array match the `bin`
argument of `numpy.histogram`.

This is returned if `return_defs` is left `True`.

bin_definition (list of BinDefinition):
`BinDefinition` is a named tuple containing the beginning of the
first bin (`start`), number of bins (`nbins`) and their widths
(`width`). The last value can also be a `nd.array` with `nbins + 1`
elements, which describes bins of unequal width and is used for
binnings that match the unique values in the data (see `min_unique`
and `add_unique`).

This is returned if `return_defs` is `False`.
"""
bins = []

Expand Down Expand Up @@ -329,7 +468,8 @@ def decimal_binnings(
return bins


def time_binnings(data, *, min_bins=2, max_bins=50, min_unique=5, add_unique=0):
def time_binnings(data, *, min_bins=2, max_bins=50, min_unique=5, add_unique=0
) -> List[BinDefinition]:
"""
Find a set of nice splits of time variable data into bins

Expand All @@ -355,7 +495,7 @@ def time_binnings(data, *, min_bins=2, max_bins=50, min_unique=5, add_unique=0):
number of unique values

Returns:
bin_boundaries (list): a list of possible binning.
bin_boundaries (list of BinDefinition): a list of possible binning.
Each element of `bin_boundaries` is a tuple consisting of a label
describing the bin size (e.g. `2 weeks`) and a list of thresholds.
Thresholds are given as pairs
Expand Down Expand Up @@ -448,7 +588,7 @@ def _simplified_labels(labels):
to_remove = "42"
while True:
firsts = {f for f, *_ in (lab.split() for lab in labels)}
if len(firsts) > 1:
if len(firsts) != 1: # can be 0 if there are no labels
break
to_remove = firsts.pop()
flen = len(to_remove)
Expand Down
Loading