Skip to content

Commit

Permalink
graph: specialize zero fields, use locals to share compute/memory
Browse files Browse the repository at this point in the history
  • Loading branch information
jxy committed Jul 23, 2024
1 parent bd9e9a2 commit f31c1b2
Show file tree
Hide file tree
Showing 4 changed files with 425 additions and 139 deletions.
27 changes: 18 additions & 9 deletions src/experimental/graph/core.nim
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type
tag: Gtags
inputs*: seq[Gvalue]
gfunc*: Gfunc
locals*: seq[Gvalue] ## for sharing values between forward and among backward functions
epoch: int

type
Expand Down Expand Up @@ -75,7 +76,7 @@ proc nodeRepr*(x: Gvalue): string =
method newOneOf*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("newOneOf(" & $x & ")") ## Be sure to zero init fields
method valCopy*(z: Gvalue, x: Gvalue) {.base.} = raiseErrorBaseMethod("valCopy(" & $z & "," & $x & ")")

method isTrue*(x: Gvalue): bool {.base.} = raiseErrorBaseMethod("isTrue(" & $x & ")")
method isZero*(x: Gvalue): bool {.base.} = raiseErrorBaseMethod("isZero(" & $x & ")")
method update*(x: Gvalue, y: int) {.base.} = raiseErrorBaseMethod("update(" & $x & "," & $y & ")")
method update*(x: Gvalue, y: float) {.base.} = raiseErrorBaseMethod("update(" & $x & "," & $y & ")")

Expand Down Expand Up @@ -182,10 +183,10 @@ proc condb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue =
raiseValueError("i must be 0 or 1, got: " & $i)

proc condf(v: Gvalue) =
if v.inputs[0].isTrue:
v.valCopy v.inputs[1]
else:
if v.inputs[0].isZero:
v.valCopy v.inputs[2]
else:
v.valCopy v.inputs[1]

let gcond = newGfunc(forward = condf, backward = condb, name = "cond")

Expand All @@ -200,6 +201,14 @@ proc updated*(x: Gvalue) =
inc epoch
x.epoch = epoch

proc evaluated*(x: Gvalue) =
## signal up-to-date value, given inputs, useful for update value outside of eval, as in update to locals in forward
var maxep = 0
for i in x.inputs:
if maxep < i.epoch:
maxep = i.epoch
x.epoch = maxep

proc eval*(v: Gvalue): Gvalue {.discardable.} =
proc r(x: Gvalue) =
if gtVisited in x.tag:
Expand All @@ -210,14 +219,14 @@ proc eval*(v: Gvalue): Gvalue {.discardable.} =
x.inputs[0].r
if maxep < x.inputs[0].epoch:
maxep = x.inputs[0].epoch
if x.inputs[0].isTrue:
x.inputs[1].r
if maxep < x.inputs[1].epoch:
maxep = x.inputs[1].epoch
else:
if x.inputs[0].isZero:
x.inputs[2].r
if maxep < x.inputs[2].epoch:
maxep = x.inputs[2].epoch
else:
x.inputs[1].r
if maxep < x.inputs[1].epoch:
maxep = x.inputs[1].epoch
else:
for i in x.inputs:
i.r
Expand Down
Loading

0 comments on commit f31c1b2

Please sign in to comment.