From b12c3d6215ea05ddfb36c08a3fa78d2168118062 Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Tue, 7 May 2024 11:47:30 -0500 Subject: [PATCH] experimental/symbolic/graph: open for testing and suggestions --- src/experimental/symbolic/graph.nim | 492 ++++++++++++++++++++++++++++ src/graph.nim | 200 ----------- 2 files changed, 492 insertions(+), 200 deletions(-) create mode 100644 src/experimental/symbolic/graph.nim delete mode 100644 src/graph.nim diff --git a/src/experimental/symbolic/graph.nim b/src/experimental/symbolic/graph.nim new file mode 100644 index 00000000..dd63d5b3 --- /dev/null +++ b/src/experimental/symbolic/graph.nim @@ -0,0 +1,492 @@ +## lazy generic computational graph + +#[ + +requires: `--multimethods:on` + +We want the symbolic graph nodes to be type generic. Therefore we +need dynamic dispatch based on object types at runtime. This +implementation uses Nim's builtin multimethods for this purpose. + +Nim's multimethods may slow down single method dispatch time, which +would affect the performance of the comms module. We need to measure +it to understand the impact. + +The functions, ident and add, are the same as a variable argument +function, in terms of symbolic formulae, but we treat functions +with different number of arguments differently, because the +implementations of the functions would be different. It also avoids +increasing dynamic dispatch overhead. + +Typed values are enclosed in derived types of `SymNodeValueConcrete` +and referenced from nodes, which have a single type. Since Nim +doesn't allow mixing generic functions with methods, we need to +define a method for each and every combination of concrete types +we use, and wrap our existing generic functions in each method. + +Both SymNodeValueConcrete and SymNode are references. In the graph +we build, a shared node means the same variable. 
We create new +nodes that are equal to the existing nodes with the ident function to +create new referenced node objects, in order to avoid false sharing. +We use copySymNodeValue to create new referenced value objects, +such that different nodes refer to distinct value objects. + +We use the tag sntVisited to avoid repeatedly traversing shared nodes. +The names of the recursive graph traversal functions all end with Rec, just to +remind us to call `tagClearRec(z, sntVisited)` at the top level. + +Further optimizations are only possible after building all the graphs: +- Remove ident nodes +- Analyze and reuse allocations when possible + +]# + +# +# basic type support +# + +type + SymNodeTag = enum + sntVisited, sntAssigned, sntNeedUpdate, sntNeedGradient, sntFixedGradient + # sntReusable, ... + SymNodeTags = set[SymNodeTag] + SymNodeValue = ref object of RootObj ## Represent unallocated symbolic value + SymNodeValueConcrete = ref object of SymNodeValue ## For any concrete values + SymNodeGradient = object + ## for a particular variable v + dependent: SymNode ## a variable v that depends on this node, x + gradient: SymNode ## dv/dx + SymNode = ref object + # This object can be cyclic, because gradients refer to ancestor nodes + value: SymNodeValue + inputs: seq[SymNode] + forward: proc(z: SymNode) ## runs the actual compute + arg: SymNodeValue ## extra argument forward/backward uses + runCount: int + allocateValue: proc(z: SymNode) + backward: proc(z: SymNode, i: int, dep: SymNode): SymNode ## create graphs + gradients: seq[SymNodeGradient] ## saved gradient graphs + name: string + tag: SymNodeTags + id: int + SymNodeError = object of Defect + SymNodeValueError = object of Defect + +template raiseError(msg: string) = + raise newException(SymNodeError, msg) + +template raiseValueError(msg: string) = + raise newException(SymNodeValueError, msg) + +template raiseErrorBaseMethod(msg: string) = + raise newException( + SymNodeError, + "Base method invoked: " & msg & + "\nMake sure to pass 
`--multimethods:on` and check there is a custom method for each derived type.") + +method `$`(v: SymNodeValue): string {.base.} = "SymNodeValue" + +func `$`(z: SymNode): string = + z.name & "#" & $z.id + +func nodeRepr(z: SymNode): string +func treeRepr(z: SymNode): string + +method copySymNodeValue(v: SymNodeValue): SymNodeValue {.base.} = + ## nothing to copy + v + +method copySymNodeValue(v: SymNodeValueConcrete): SymNodeValueConcrete = + raiseValueError("Custom method required for concrete value: " & $v) + +proc newSymNode( + value = SymNodeValue(), + inputs: seq[SymNode] = @[], + forward: proc(z: SymNode) = nil, + arg: SymNodeValue = nil, + runCount: int = 0, + allocateValue: proc(z: SymNode) = nil, + backward: proc(z: SymNode, i: int, dep: SymNode): SymNode = nil, + gradients: seq[SymNodeGradient] = @[], + name: string = "", + tag: SymNodeTags = {}): SymNode = + ## Create new SymNode with a unique id. + var id {.global.} = 0 + result = SymNode(value: value, inputs: inputs, forward: forward, arg: arg, runCount: runCount, + allocateValue: allocateValue, backward: backward, gradients: gradients, name: name, tag: tag, id: id) + id.inc + +proc copySymNode(z: SymNode): SymNode = + newSymNode(value = z.value.copySymNodeValue, inputs = z.inputs, forward = z.forward, + arg = if z.arg != nil: z.arg.copySymNodeValue else: nil, runCount = z.runCount, + allocateValue = z.allocateValue, backward = z.backward, gradients = z.gradients, + name = z.name, tag = z.tag) + +proc assignSymNode(z: SymNode, x: SymNode) = + z.value = x.value.copySymNodeValue + z.inputs = x.inputs + z.forward = x.forward + if x.arg != nil: + z.arg = x.arg.copySymNodeValue + z.runCount = x.runCount + z.allocateValue = x.allocateValue + z.backward = x.backward + z.gradients = x.gradients + z.name = x.name + z.tag = x.tag + +proc gradientDependentOrNil(z: SymNode, dep: SymNode): SymNode = + ## May return nil if not exists. 
+ for g in z.gradients: + if dep == g.dependent: + return g.gradient + # We don't have a matching dependent variable. + return nil + +proc gradientDependentAssign(z: SymNode, dep: SymNode, grad: SymNode) = + ## Replace if exists, otherwise add to the list. + for g in z.gradients.mitems: + if dep == g.dependent: + g.gradient = grad + return + z.gradients.add SymNodeGradient(dependent: dep, gradient: grad) + +proc assign(z: SymNode, v: SymNodeValueConcrete) = + z.value = v + z.tag.incl sntAssigned + +# +# generic symbol support +# + +proc newSym(s: string): SymNode = + newSymNode(name = s) + +method identSymNodeValue(z: SymNodeValue, x: SymNodeValue) {.base.} = + raiseErrorBaseMethod("args:\n " & z.repr & "\n " & x.repr) + +method identAllocateSymNodeValue(z: SymNode, x: SymNodeValue) {.base.} = + raiseErrorBaseMethod("args:\n " & z.nodeRepr & "\n " & x.repr) + +method iaddSymNodeValue(z: SymNodeValue, x: SymNodeValue, y: SymNodeValue) {.base.} = + raiseErrorBaseMethod("args:\n " & z.repr & "\n " & x.repr & "\n " & y.repr) + +method iaddAllocateSymNodeValue(z: SymNode, x: SymNodeValue, y: SymNodeValue) {.base.} = + raiseErrorBaseMethod("args:\n " & z.nodeRepr & "\n " & x.repr & "\n " & y.repr) + +# +# float support +# + +type SymNodeValueFloat = ref object of SymNodeValueConcrete + floatValue: float + +method `$`(v: SymNodeValueFloat): string = $v.floatValue + +proc assign(z: SymNode, v: float) = + z.assign SymNodeValueFloat(floatValue: v) + +method copySymNodeValue(v: SymNodeValueFloat): SymNodeValueFloat = + SymNodeValueFloat(floatValue: v.floatValue) + +method identSymNodeValue(z: SymNodeValueFloat, x: SymNodeValueFloat) = + z.floatValue = x.floatValue + +method identAllocateSymNodeValue(z: SymNode, x: SymNodeValueFloat) = + z.value = SymNodeValueFloat() + +method iaddSymNodeValue(z: SymNodeValueFloat, x: SymNodeValueFloat, y: SymNodeValueFloat) = + z.floatValue = x.floatValue + y.floatValue + +method iaddAllocateSymNodeValue(z: SymNode, x: SymNodeValueFloat, y: 
SymNodeValueFloat) = + z.value = SymNodeValueFloat() + +# +# minimum algebra for the nodes +# + +proc ident(x:SymNode): SymNode +proc add(x: SymNode, y: SymNode): SymNode + +proc identForward(z: SymNode) = + identSymNodeValue(z.value, z.inputs[0].value) + +proc identAllocate(z: SymNode) = + identAllocateSymNodeValue(z, z.inputs[0].value) + +proc identBackward(z: SymNode, i: int, dep: SymNode): SymNode = + let g = z.gradientDependentOrNil dep + if g == nil: + return newSymNode(value = SymNodeValueFloat(floatValue: 1.0), name = "One[ident]") + else: + return g.ident + +proc ident(x:SymNode): SymNode = + newSymNode( + inputs = @[x], + forward = identForward, + allocateValue = identAllocate, + backward = identBackward, + name = "ident") + +proc addForward(z: SymNode) = + iaddSymNodeValue(z.value, z.inputs[0].value, z.inputs[1].value) + +proc addAllocate(z: SymNode) = + iaddAllocateSymNodeValue(z, z.inputs[0].value, z.inputs[1].value) + +proc addBackward(z: SymNode, i: int, dep: SymNode): SymNode = + let g = z.gradientDependentOrNil dep + if g == nil: + return newSymNode(value = SymNodeValueFloat(floatValue: 1.0), name = "One[add]") + else: + return g.ident + +proc add(x: SymNode, y: SymNode): SymNode = + newSymNode( + inputs = @[x, y], + forward = addForward, + allocateValue = addAllocate, + backward = addBackward, + name = "add") + +# +# graph traversal and evaluation +# + +proc tagClearRec(z: SymNode, tag: SymNodeTag) = + ## This does not use sntVisited, so it will repeat on shared nodes. 
+ if tag in z.tag: + z.tag.excl tag + for i in z.inputs: + i.tagClearRec tag + +proc allocateRec(z: SymNode) = + if sntVisited notin z.tag: + z.tag.incl sntVisited + for i in z.inputs: + i.allocateRec + if not (z.value of SymNodeValueConcrete): + if z.allocateValue == nil: + raiseError("undefined allocateValue for node: " & z.nodeRepr) + z.allocateValue z + +proc allocate(z: SymNode) = + z.allocateRec + z.tagClearRec sntVisited + +proc tagUpdateRec(z: SymNode) = + if sntVisited in z.tag: + return + z.tag.incl sntVisited + if sntAssigned in z.tag: + z.tag.excl sntAssigned + else: + var needupdate = false + for i in z.inputs: + needupdate = needupdate or sntAssigned in i.tag + i.tagUpdateRec + needupdate = needupdate or sntNeedUpdate in i.tag + if needupdate: + z.tag.incl sntNeedUpdate + +proc evalRec(z: SymNode) = + if sntVisited in z.tag: + if sntNeedUpdate in z.tag: + raiseError "cycle detected" + elif sntNeedUpdate in z.tag: + z.tag.incl sntVisited + for i in z.inputs: + i.evalRec + if z.forward != nil: + z.forward z + z.runCount.inc + elif z.inputs.len > 0: + raiseError("inputs.len: " & $z.inputs.len & ", but no forward function defined for:\n" & z.nodeRepr) + z.tag.excl sntNeedUpdate + +proc eval(z: SymNode) = + z.tagUpdateRec + z.tagClearRec sntVisited + z.evalRec + z.tagClearRec sntVisited + +proc tagUpdateNeedGradientRec(z: SymNode) = + if sntVisited in z.tag: + return + z.tag.incl sntVisited + var needgradient = false + for i in z.inputs: + i.tagUpdateNeedGradientRec + needgradient = needgradient or sntNeedGradient in i.tag + if needgradient and sntNeedGradient notin z.tag: + z.tag.incl sntNeedGradient + +proc gradientRec(z: SymNode, dep: SymNode) = + ## gradient of dep with respect to z + # We tag newly created nodes from z.backward(z, i, dep), with needUpdate. + for i in 0.. 0: + result &= ", inputs: {" + for i in 0.. 
0: + result &= ", " + result &= "[" & $i & "]: " & $z.inputs[i] + result &= "}" + if z.gradients.len > 0: + result &= ", gradients: {" + for i in 0.. 0: + result &= ", " + result &= "[" & $i & "]: " & $z.gradients[i].dependent & " -> " & $z.gradients[i].gradient + result &= "}" + if z.value != nil: + result &= ": " & $z.value + +func toStringRec(z: SymNode, pre: string, shared: seq[SymNode]): string = + result = pre & z.nodeRepr + if sntVisited in z.tag: + result &= " [shared]" + else: + z.tag.incl sntVisited + for zid in 0..1: - result = "#" & $id & "=(" - g.refs = -id - inc id - else: - result = "(" - result &= g.str - for x in g.args: - result &= " " & x.go - result &= ")" - else: - result = g.str - result = g.go - -proc newVar[T](x:T, s="$V"):GraphNode[T] = - GraphNode[T](val:GraphValue[T](v:x), str:s, - initGrad: (proc(g:Graph) = g.grad = GraphValue[T](v:1.T))) - -proc newConst[T](x:T, s="$C"):GraphNode[T] = - GraphNode[T](tag: {gtConst}, val:GraphValue[T](v:x), str:s & "|" & $x & "|", - initGrad: (proc(g:Graph) = g.grad = GraphValue[T](v:1.T))) - -proc eval(g:Graph) = - if g.isop and gtRun notin g.tag: - g.tag.incl gtRun - for x in g.args: - x.eval - g.run(g) - -proc clearGrad(g:Graph) = - g.grad = nil - g.tag.excl gtDF - g.tag.excl gtGrad - if g.isop: - for x in g.args: - x.clearGrad - -proc evalGrad(g:Graph) = - if g.isop and gtRun notin g.tag: - g.eval - if gtDF notin g.tag: - g.clearGrad - g.countRefs - g.tag.incl gtDF - g.initGrad(g) - proc go(g:Graph) = - g.tag.incl gtGrad - if g.isop: - g.refs.dec - if g.refs>0: - return - # Wait until the last reference of the shared nodes. - g.back(g) - for x in g.args: - x.go - g.go - -proc evalGrad[G,X](g:GraphNode[G], x:GraphNode[X]):X = - # TODO only descend in to nodes that contains x. 
- g.Graph.evalGrad - proc go(g:Graph):Graph = - if g == x: - return x - elif g.isop: - for c in g.args: - if c.go == x: - return x - g - if g.go == x: - GraphValue[X](x.grad).v - else: - X 0 - -proc eval[T](g:GraphNode[T]):T = - g.Graph.eval - GraphValue[T](g.val).v - -proc `+`[X,Y](x:GraphNode[X], y:GraphNode[Y]):auto = - type R = type(GraphValue[X](x.val).v+GraphValue[Y](y.val).v) - GraphNode[R](isop:true, str:"+", args: @[x.Graph,y], - run: (proc(g:Graph) = - echo "# Run: ",g.args[0].str," + ",g.args[1].str - let v = GraphValue[R](v:GraphValue[X](g.args[0].val).v+GraphValue[Y](g.args[1].val).v) - g.val = v - g.str &= "(=" & $v.v & ")"), - initGrad: (proc(g:Graph) = g.grad = GraphValue[R](v:1.R)), - back: (proc(g:Graph) = - echo "# Back: ",x.str," * ",y.str - if gtConst notin g.args[0].tag: - let t = GraphValue[R](g.grad).v.X - if g.args[0].grad != nil: - g.args[0].grad = GraphValue[X](v:GraphValue[X](g.args[0].grad).v+t) - else: - g.args[0].grad = GraphValue[X](v:t) - if gtConst notin g.args[1].tag: - let t = GraphValue[R](g.grad).v.Y - if g.args[1].grad != nil: - g.args[1].grad = GraphValue[Y](v:GraphValue[Y](g.args[1].grad).v+t) - else: - g.args[1].grad = GraphValue[Y](v:t))) -proc `+`[X](x:GraphNode[X], y:SomeNumber):auto = x + newConst(y) -proc `+`[Y](x:SomeNumber, y:GraphNode[Y]):auto = newConst(x) + y - -proc `*`[X,Y](x:GraphNode[X], y:GraphNode[Y]):auto = - type R = type(GraphValue[X](x.val).v*GraphValue[Y](y.val).v) - GraphNode[R](isop:true, str:"*", args: @[x.Graph,y], - run: (proc(g:Graph) = - echo "# Run: ",x.str," * ",y.str - let v = GraphValue[R](v:GraphValue[X](g.args[0].val).v*GraphValue[Y](g.args[1].val).v) - g.val = v - g.str &= "(=" & $v.v & ")"), - initGrad: (proc(g:Graph) = g.grad = GraphValue[R](v:1.R)), - back: (proc(g:Graph) = - echo "# Back: ",x.str," * ",y.str - if gtConst notin g.args[0].tag: - let t = GraphValue[R](g.grad).v*GraphValue[Y](g.args[1].val).v - if g.args[0].grad != nil: - g.args[0].grad = 
GraphValue[X](v:GraphValue[X](g.args[0].grad).v+t) - else: - g.args[0].grad = GraphValue[X](v:t) - if gtConst notin g.args[1].tag: - let t = GraphValue[X](g.args[0].val).v*GraphValue[R](g.grad).v - if g.args[1].grad != nil: - g.args[1].grad = GraphValue[X](v:GraphValue[X](g.args[1].grad).v+t) - else: - g.args[1].grad = GraphValue[Y](v:t))) -proc `*`[X](x:GraphNode[X], y:SomeNumber):auto = x * newConst(y) -proc `*`[Y](x:SomeNumber, y:GraphNode[Y]):auto = newConst(x) * y - -when isMainModule: - let - x = newVar(2.0, "x") - y = newVar(3.0, "y") - z = x*(y+5.0) - t = x*y*(z+1.0)*z - echo "x: ",x - echo "y: ",y - echo "z: ",z - echo "t: ",t - let rt = t.eval - echo "t: ",t - echo "rt: ",rt - echo "rz: ",z.eval - echo "dtdz: ",t.evalGrad z - echo "dtdx: ",t.evalGrad x - echo "dtdy: ",t.evalGrad y - echo "dzdx: ",z.evalGrad x - echo "dzdy: ",z.evalGrad y - let u = (x+t)*(t+z) - echo "u: ",u - echo "dudy: ",u.evalGrad y