diff --git a/ezr.py b/ezr.py index 36381ed..2a3af3d 100755 --- a/ezr.py +++ b/ezr.py @@ -105,9 +105,9 @@ def XY(at,txt,lo,hi=None,ys=None) -> xy: "`ys` counts symbols seen in one column between `lo`.. `hi` of another column." return o(this=XY, n=0, at=at, txt=txt, lo=lo, hi=hi or lo, ys=ys or {}) -def NODE(klasses: classes, left=None, right=None) -> node: +def NODE(klasses: classes, parent: node, left=None, right=None) -> node: "NODEs are parts of binary trees." - return o(this=NODE, klasses=klasses, left=left, right=right) + return o(this=NODE, klasses=klasses, parent=parent, left=left, right=right, cut=None) def WANT(best="best", bests=1, rests=1) -> want: "Used to score how well a distribution selects for `best'." @@ -315,20 +315,23 @@ def _span(xys : list[xy]) -> list[xy]: def tree(i:data, klasses:classes, want1:Callable, stop:int=4) -> node: "Return a binary tree, each level splitting on the range with most `score`." - def _grow(klasses:classes, lvl:int=1, above:int=1E30) -> node: + def _grow(klasses:classes, lvl:int=1, parent=None) -> node: "Collect the stats needed for branching, then call `_branch()`." counts = {k:len(rows1) for k,rows1 in klasses.items()} - total = sum(n for n in counts.values()) + total = sum(counts.values()) most = counts[max(counts, key=counts.get)] - return _branch(NODE(klasses), lvl, above, total, most) + return _branch(NODE(klasses,parent), lvl, total, most) - def _branch(here:node, lvl:int, above:int, total:int, most:int) -> node: + def _branch(here:node, lvl:int, total:int, most:int) -> node: "Divide the data on tbe best cut. Recurse." - if total > 2*stop and total < above and most < total: #most==total means "purity" (all of one class) - cut = max(cuts, key=lambda cut0: _want(cut0, here.klasses)) - left,right = _split(cut, here.klasses) - here.left = _grow(left, lvl+1, total) - here.right = _grow(right, lvl+1, total) + if total > 2*stop and most < total: #most==total means "purity" (all of one: class) + here.cut = max(cuts, key=lambda cut0: _want(cut0, here.klasses)) + left,right = _split(here.cut, here.klasses) + leftn = sum(len(rows1) for rows1 in left.values()) + rightn = sum(len(rows1) for rows1 in right.values()) + if leftn < total and rightn < total: + here.left = _grow(left, lvl+1, here) + here.right = _grow(right, lvl+1, here) return here def _want(cut:xy, klasses:classes) -> float : @@ -349,12 +352,17 @@ def _split(cut:xy, klasses:classes) -> tuple[classes,classes]: def nodes(i:node, lvl=0, left=True) -> node: if i: yield i,lvl,left - for j,lvl1,left1 in nodes(i.left, lvl+1, left=True) : yield j,lvl1,left1 + for j,lvl1,left1 in nodes(i.left, lvl+1, left=True) : yield j,lvl1,left1 for j,lvl1,right1 in nodes(i.right, lvl+1, left=False): yield j,lvl1,right1 def showTree(i:node): + print("") for j,lvl,isLeft in nodes(i): - print("|.. "*lvl, "if" if isLeft else "else", {k:len(rows) for k,rows in j.klasses.items()}) + pre="" + if lvl>0: + pre = f"if {showXY(j.parent.cut)}" if isLeft else "else " + print(f"{'|.. '*(lvl-1) + pre:35}", + show({k:len(rows) for k,rows in j.klasses.items()})) #--------- --------- --------- --------- --------- --------- --------- --------- -------- # ## Distances