diff --git a/src/experimental/graph/core.nim b/src/experimental/graph/core.nim index 31e57ea..e322488 100644 --- a/src/experimental/graph/core.nim +++ b/src/experimental/graph/core.nim @@ -26,6 +26,7 @@ type tag: Gtags inputs*: seq[Gvalue] gfunc*: Gfunc + locals*: seq[Gvalue] ## for sharing values between forward and among backward functions epoch: int type @@ -75,7 +76,7 @@ proc nodeRepr*(x: Gvalue): string = method newOneOf*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("newOneOf(" & $x & ")") ## Be sure to zero init fields method valCopy*(z: Gvalue, x: Gvalue) {.base.} = raiseErrorBaseMethod("valCopy(" & $z & "," & $x & ")") -method isTrue*(x: Gvalue): bool {.base.} = raiseErrorBaseMethod("isTrue(" & $x & ")") +method isZero*(x: Gvalue): bool {.base.} = raiseErrorBaseMethod("isZero(" & $x & ")") method update*(x: Gvalue, y: int) {.base.} = raiseErrorBaseMethod("update(" & $x & "," & $y & ")") method update*(x: Gvalue, y: float) {.base.} = raiseErrorBaseMethod("update(" & $x & "," & $y & ")") @@ -182,10 +183,10 @@ proc condb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue = raiseValueError("i must be 0 or 1, got: " & $i) proc condf(v: Gvalue) = - if v.inputs[0].isTrue: - v.valCopy v.inputs[1] - else: + if v.inputs[0].isZero: v.valCopy v.inputs[2] + else: + v.valCopy v.inputs[1] let gcond = newGfunc(forward = condf, backward = condb, name = "cond") @@ -200,6 +201,14 @@ proc updated*(x: Gvalue) = inc epoch x.epoch = epoch +proc evaluated*(x: Gvalue) = + ## signal up-to-date value, given inputs, useful for update value outside of eval, as in update to locals in forward + var maxep = 0 + for i in x.inputs: + if maxep < i.epoch: + maxep = i.epoch + x.epoch = maxep + proc eval*(v: Gvalue): Gvalue {.discardable.} = proc r(x: Gvalue) = if gtVisited in x.tag: @@ -210,14 +219,14 @@ proc eval*(v: Gvalue): Gvalue {.discardable.} = x.inputs[0].r if maxep < x.inputs[0].epoch: maxep = x.inputs[0].epoch - if x.inputs[0].isTrue: - x.inputs[1].r - if maxep < x.inputs[1].epoch: - maxep = x.inputs[1].epoch - else: + if x.inputs[0].isZero: x.inputs[2].r if maxep < x.inputs[2].epoch: maxep = x.inputs[2].epoch + else: + x.inputs[1].r + if maxep < x.inputs[1].epoch: + maxep = x.inputs[1].epoch else: for i in x.inputs: i.r diff --git a/src/experimental/graph/gauge.nim b/src/experimental/graph/gauge.nim index 62ac32b..84f216a 100644 --- a/src/experimental/graph/gauge.nim +++ b/src/experimental/graph/gauge.nim @@ -11,42 +11,45 @@ import layout, ../../gauge, physics/qcdTypes type Gauge = seq[DLatticeColorMatrixV] type Ggauge* {.final.} = ref object of Gvalue + isZero: bool = false ## specialized for zero fields, unrelated to actual gval gval: Gauge proc getgauge*(x: Gvalue): Gauge = Ggauge(x).gval -proc update*(x: Gvalue, g: Gauge) = - let u = x.getgauge - threads: - for mu in 0..