From 3b6e954cad99283ee3c60f7f6a89003f6b11e3b4 Mon Sep 17 00:00:00 2001 From: timm Date: Sun, 14 Jul 2024 09:58:40 -0400 Subject: [PATCH] bins working now --- src/ezr.lua | 48 +++++++++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/src/ezr.lua b/src/ezr.lua index 3be3024..76199cf 100755 --- a/src/ezr.lua +++ b/src/ezr.lua @@ -221,14 +221,17 @@ eg.data["--sort:read and sort data"] = function(train) -- ## Discretize local BIN={} -function BIN:new(s,n, lo,hi,y) --> BIN +function BIN:new(s,n, lo,hi,ymid,ydiv) --> BIN return l.new(BIN,{name=s,pos=n, lo=lo or the.all.inf, - hi=hi or lo or the.all.inf, y=y or NUM:new(s,n)}) end + hi=hi or lo or the.all.inf, + ydiv=ydiv,ymid=ymid,_helper=NUM:new()}) end function BIN:add(x,y) if x ~= "?" then if x < self.lo then self.lo = x end if x > self.hi then self.hi = x end - self.y:add(y) end end + self._helper:add(y) + self.ymid = self._helper:mid() + self.ydiv = self._helper:div() end end function BIN:__tostring( lo,hi,s) lo,hi,s = self.x.lo, self.x.hi,self.x.name @@ -244,15 +247,16 @@ function BIN:selects(rows, u,lo,hi,x) if x=="?" or lo==hi and lo==x or lo < x and x <= hi then l.push(u,r) end end return u end -function SYM:bins(rows,y,_,_, t,x) --> array[bin] ; proposes one split per symbol value +function SYM:bins(rows,y,_,_, t,xx) --> array[bin] ; proposes one split per symbol value t = {} - for row in pairs(rows) do - x = row[self.pos] - if x ~= "?" then t[x] = t[x] or BIN:new(self.name,self.pos,x) - t[x]:add(x, y(row)) end end + for _,row in pairs(rows) do + xx=row[self.pos] + if xx ~= "?" then + t[xx] = t[xx] or BIN:new(self.name,self.pos,xx) + t[xx]:add(xx, y(row)) end end return t end -function NUM:bins(rows,y, enough,epsilon, x,now,out,changed) +function NUM:bins(rows,y, enough,epsilon, x,now,out,cut) local x,q,new,similar,now,out function x(row) return row[self.pos] end function q(row) return x(row)=="?" and -1E32 or x(row) end @@ -261,31 +265,33 @@ function NUM:bins(rows,y, enough,epsilon, x,now,out,changed) x = {lo=self:clone(), hi=self:clone()}} for i,row in pairs(rows) do if x(row) ~= "?" then now.x.hi:add( x(row)); now.y.hi:add( y(row)) end end - enough = now.x.hi.n^enough - epsilon = now.x.hi:div()^epsilon out = l.copy(now) out.ydiv = now.y.hi:div() for i,row in pairs(rows) do - if x(row) ~= "?" then + if x(row) ~= "?" then now.x.lo:add( now.x.hi:sub( x(row))) now.y.lo:add( now.y.hi:sub( y(row))) - if not now.x.lo:similar(now.x.hi, enough,epsilon) then + if not now.x.lo:similar(now.x.hi, enough,epsilon) then if x(row) ~= x(rows[i+1]) then - if now.y.lo:xpect(now.y.hi) < out.ydiv then - changed = true + if now.y.lo:xpect(now.y.hi) < out.ydiv then + cut = x(row); out = l.copy(now) out.ydiv = now.y.lo:xpect(now.y.hi) end end end end end - if changed then - return { BIN:new(self.name, self.pos, -the.all.inf, out.x.hi.lo, out.y.lo), - BIN:new(self.name, self.pos, out.x.hi.lo, the.all.inf, out.y.hi) } end end + if cut then + return { BIN:new(self.name, self.pos, -the.all.inf, cut, + out.y.lo:mid(), out.y.lo:div()), + BIN:new(self.name, self.pos, cut, the.all.inf, + out.y.hi:mid(), out.y.hi:div())} end end -- --------------------------------------------------------------------------------------- eg.bins={} eg.bins["--bins:[?file] read in csv data"] = function(train, d) d = DATA:new():read(train or the.all.train) - print(d.cols.y[1].name) - l.oo(d.cols.y[1]:bins(d.rows, + print(d:chebyshevs():mid()) + for _,col in pairs(d.cols.x) do + print(col.name) + l.oo(col:bins(d.rows, function(row) return d:chebyshev(row) end, - the.bins.enough, the.bins.epsilon)) + (#d.rows)^the.bins.enough, col:div()*the.bins.epsilon)) end end -- ## Tree