diff --git a/src/experimental/graph/multi.nim b/src/experimental/graph/multi.nim index 4fcd37a..c1f3576 100644 --- a/src/experimental/graph/multi.nim +++ b/src/experimental/graph/multi.nim @@ -37,9 +37,15 @@ method updateAt*(x: Gvalue, i: Gvalue, y: Gvalue): Gvalue {.base.} = raiseErrorB proc getAtmb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue = case i of 0: - return z.inputs[0].newOneOf.updateAt(i, zb) + if zb == nil: + # the output must be a scalar, otherwise crash later + return z.inputs[0].newOneOf.updateAt(z.inputs[1], toGvalue(1.0)) + else: + return z.inputs[0].newOneOf.updateAt(z.inputs[1], zb) + of 1: + return toGvalue(0) else: - raiseValueError("i must be 0, got: " & $i) + raiseValueError("i must be 0 or 1, got: " & $i) proc getAtmf(v: Gvalue) = let x = Gmulti(v.inputs[0]) @@ -56,11 +62,13 @@ method `[]`*(x: Gmulti, i: Gint): Gvalue = proc updateAtmb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue = case i of 0: - return zb.updateAt(i, zb.inputs[0].newOneOf) + return zb.updateAt(z.inputs[1], z.inputs[2].newOneOf) + of 1: + return toGvalue(0) of 2: return zb[i] else: - raiseValueError("i must be 0 or 2, got: " & $i) + raiseValueError("i must be 0, 1, or 2, got: " & $i) proc updateAtmf(v: Gvalue) = let x = Gmulti(v.inputs[0]) @@ -68,10 +76,13 @@ proc updateAtmf(v: Gvalue) = let z = Gmulti(v) for k in 0..