Skip to content

Commit

Permalink
graph/multi: fix issues
Browse files Browse the repository at this point in the history
  • Loading branch information
jxy committed Jul 18, 2024
1 parent 7d828b3 commit 4c080ca
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions src/experimental/graph/multi.nim
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,15 @@ method updateAt*(x: Gvalue, i: Gvalue, y: Gvalue): Gvalue {.base.} = raiseErrorB
proc getAtmb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue =
case i
of 0:
return z.inputs[0].newOneOf.updateAt(i, zb)
if zb == nil:
# the output must be a scalar, otherwise crash later
return z.inputs[0].newOneOf.updateAt(z.inputs[1], toGvalue(1.0))
else:
return z.inputs[0].newOneOf.updateAt(z.inputs[1], zb)
of 1:
return toGvalue(0)
else:
raiseValueError("i must be 0, got: " & $i)
raiseValueError("i must be 0 or 1, got: " & $i)

proc getAtmf(v: Gvalue) =
let x = Gmulti(v.inputs[0])
Expand All @@ -56,22 +62,27 @@ method `[]`*(x: Gmulti, i: Gint): Gvalue =
proc updateAtmb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue =
case i
of 0:
return zb.updateAt(i, zb.inputs[0].newOneOf)
return zb.updateAt(z.inputs[1], z.inputs[2].newOneOf)
of 1:
return toGvalue(0)
of 2:
return zb[i]
else:
raiseValueError("i must be 0 or 2, got: " & $i)
raiseValueError("i must be 0, 1, or 2, got: " & $i)

proc updateAtmf(v: Gvalue) =
let x = Gmulti(v.inputs[0])
let i = v.inputs[1].getint
let z = Gmulti(v)
for k in 0..<z.mval.len:
if k == i:
z.mval[k] = v.inputs[2]
z.mval[k].valCopy(v.inputs[2])
else:
z.mval[k] = x.mval[i]
z.mval[k].valCopy(x.mval[k])

let updateAtm = newGfunc(forward = updateAtmf, backward = updateAtmb, name = "updateAtm")

method updateAt*(x: Gmulti, i: Gint, y: Gvalue): Gvalue = Gmulti(mval: newseq[Gvalue](x.mval.len), inputs: @[Gvalue(x), i, y], gfunc: updateAtm)
method updateAt*(x: Gmulti, i: Gint, y: Gvalue): Gvalue =
result = x.newOneOf
result.inputs = @[Gvalue(x), i, y]
result.gfunc = updateAtm

0 comments on commit 4c080ca

Please sign in to comment.