Skip to content

Commit

Permalink
saved
Browse files Browse the repository at this point in the history
  • Loading branch information
timmenzies committed May 27, 2024
1 parent af38c4e commit d8410dd
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions ezr.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ def XY(at,txt,lo,hi=None,ys=None) -> xy:
"`ys` counts symbols seen in one column between `lo`.. `hi` of another column."
return o(this=XY, n=0, at=at, txt=txt, lo=lo, hi=hi or lo, ys=ys or {})

def NODE(klasses: classes, left=None, right=None) -> node:
def NODE(klasses: classes, parent: node, left=None, right=None) -> node:
"NODEs are parts of binary trees."
return o(this=NODE, klasses=klasses, left=left, right=right)
return o(this=NODE, klasses=klasses, parent=parent, left=left, right=right, cut=None)

def WANT(best="best", bests=1, rests=1) -> want:
"Used to score how well a distribution selects for `best'."
Expand Down Expand Up @@ -315,20 +315,23 @@ def _span(xys : list[xy]) -> list[xy]:

def tree(i:data, klasses:classes, want1:Callable, stop:int=4) -> node:
"Return a binary tree, each level splitting on the range with most `score`."
def _grow(klasses:classes, lvl:int=1, above:int=1E30) -> node:
def _grow(klasses:classes, lvl:int=1, parent=None) -> node:
"Collect the stats needed for branching, then call `_branch()`."
counts = {k:len(rows1) for k,rows1 in klasses.items()}
total = sum(n for n in counts.values())
total = sum(counts.values())
most = counts[max(counts, key=counts.get)]
return _branch(NODE(klasses), lvl, above, total, most)
return _branch(NODE(klasses,parent), lvl, total, most)

def _branch(here:node, lvl:int, above:int, total:int, most:int) -> node:
def _branch(here:node, lvl:int, total:int, most:int) -> node:
"Divide the data on tbe best cut. Recurse."
if total > 2*stop and total < above and most < total: #most==total means "purity" (all of one class)
cut = max(cuts, key=lambda cut0: _want(cut0, here.klasses))
left,right = _split(cut, here.klasses)
here.left = _grow(left, lvl+1, total)
here.right = _grow(right, lvl+1, total)
if total > 2*stop and most < total: #most==total means "purity" (all of one: class)
here.cut = max(cuts, key=lambda cut0: _want(cut0, here.klasses))
left,right = _split(here.cut, here.klasses)
leftn = sum(len(rows1) for rows1 in left.values())
rightn = sum(len(rows1) for rows1 in right.values())
if leftn < total and rightn < total:
here.left = _grow(left, lvl+1, here)
here.right = _grow(right, lvl+1, here)
return here

def _want(cut:xy, klasses:classes) -> float :
Expand All @@ -349,12 +352,17 @@ def _split(cut:xy, klasses:classes) -> tuple[classes,classes]:
def nodes(i:node, lvl=0, left=True) -> node:
if i:
yield i,lvl,left
for j,lvl1,left1 in nodes(i.left, lvl+1, left=True) : yield j,lvl1,left1
for j,lvl1,left1 in nodes(i.left, lvl+1, left=True) : yield j,lvl1,left1
for j,lvl1,right1 in nodes(i.right, lvl+1, left=False): yield j,lvl1,right1

def showTree(i:node):
print("")
for j,lvl,isLeft in nodes(i):
print("|.. "*lvl, "if" if isLeft else "else", {k:len(rows) for k,rows in j.klasses.items()})
pre=""
if lvl>0:
pre = f"if {showXY(j.parent.cut)}" if isLeft else "else "
print(f"{'|.. '*(lvl-1) + pre:35}",
show({k:len(rows) for k,rows in j.klasses.items()}))
#--------- --------- --------- --------- --------- --------- --------- --------- --------
# ## Distances

Expand Down

0 comments on commit d8410dd

Please sign in to comment.