diff --git a/bootstrap-travis b/bootstrap-travis index c8554b7..a6f89c4 100755 --- a/bootstrap-travis +++ b/bootstrap-travis @@ -1,17 +1,24 @@ #!/bin/sh +comms="mpi" +cc="mpicc" +if [ "X$1" = "Xsingle" ]; then + comms="single" + cc="gcc" +fi + qmpv="qmp-2.5.4" qmp="$qmpv.tar.gz" if [ ! -f $qmp ]; then - wget "http://usqcd-software.github.io/downloads/qmp/$qmp" + curl -LO "http://usqcd-software.github.io/downloads/qmp/$qmp" fi tar zxvf $qmp mkdir qmp cd $qmpv ./configure \ --prefix="$PWD/../qmp" \ - --with-qmp-comms-type=mpi \ - CC=mpicc \ + --with-qmp-comms-type=$comms \ + CC=$cc \ CFLAGS="-Wall -O3 -std=gnu99 -g -fPIC" make && make install cd .. @@ -19,7 +26,7 @@ cd .. qiov="qio-3.0.0" qio="$qiov.tar.gz" if [ ! -f $qio ]; then - wget "http://usqcd-software.github.io/downloads/qio/$qio" + curl -LO "http://usqcd-software.github.io/downloads/qio/$qio" fi tar zxvf $qio mkdir qio @@ -27,7 +34,7 @@ cd $qiov ./configure \ --prefix="$PWD/../qio" \ --with-qmp="$PWD/../qmp" \ - CC=mpicc \ + CC=$cc \ CFLAGS="-Wall -O3 -std=gnu99 -g -fPIC" make && make install cd .. diff --git a/build/installNim b/build/installNim index 769bba4..b4c077f 100755 --- a/build/installNim +++ b/build/installNim @@ -10,7 +10,7 @@ usage() { echo " -h (this help message)" echo " stable (install latest stable version)" echo " (install named version, e.g. 0.16.0)" - echo " master (install master branch tracking version)" + #echo " master (install master branch tracking version)" echo " devel (install devel branch tracking version)" echo " default stable||master|devel" echo " (set default version)" @@ -35,11 +35,11 @@ if [ "X$1" = "Xdefault" ]; then fi ver="stable" -branch="master" +branch="version-2-0" if [ "X$1" != "X" ]; then case "$1" in stable) ver="stable";; - master) ver="master";; + #master) ver="master";; devel) ver="devel"; branch="devel";; -h) usage;; *) ver="$1"; branch="v$1";; diff --git a/src/base/omp.nim b/src/base/omp.nim index 12c17b0..80b803b 100644 --- a/src/base/omp.nim +++ b/src/base/omp.nim @@ -3,11 +3,11 @@ import os when defined(noOpenmp): static: echo "OpenMP disabled" template omp_set_num_threads*(x: cint) = discard - template omp_get_num_threads*(): cint = 1 - template omp_get_max_threads*(): cint = 1 - template omp_get_thread_num*(): cint = 0 - template ompPragma(p:string):untyped = discard - template ompBlock*(p:string; body:untyped):untyped = + template omp_get_num_threads*(): cint = cint 1 + template omp_get_max_threads*(): cint = cint 1 + template omp_get_thread_num*(): cint = cint 0 + template ompPragma(p:string) = discard + template ompBlock*(p:string; body:untyped) = block: body else: @@ -24,11 +24,11 @@ else: proc omp_get_max_threads*(): cint {.omp.} proc omp_get_thread_num*(): cint {.omp.} #proc forceOmpOn() {.omp.} - template ompPragma(p:string):untyped = + template ompPragma(p:string) = #forceOmpOn() #{. emit:["#pragma omp ", p] .} {. emit:["_Pragma(\"omp ", p, "\")"] .} - template ompBlock*(p:string; body:untyped):untyped = + template ompBlock*(p:string; body:untyped) = #{. emit:"#pragma omp " & p .} #{. emit:"{ /* Inserted by ompBlock " & p & " */".} #{. emit:["#pragma omp ", p] .} @@ -39,14 +39,14 @@ else: template ompBarrier* = ompPragma("barrier") -template ompParallel*(body:untyped):untyped = +template ompParallel*(body:untyped) = ompBlock("parallel"): if(omp_get_thread_num()!=0): setupForeignThreadGc() body -template ompMaster*(body:untyped):untyped = ompBlock("master", body) -template ompSingle*(body:untyped):untyped = ompBlock("single", body) -template ompCritical*(body:untyped):untyped = ompBlock("critical", body) +template ompMaster*(body:untyped) = ompBlock("master", body) +template ompSingle*(body:untyped) = ompBlock("single", body) +template ompCritical*(body:untyped) = ompBlock("critical", body) when isMainModule: proc test = diff --git a/src/base/profile.nim b/src/base/profile.nim index 2f9a366..7197cbf 100644 --- a/src/base/profile.nim +++ b/src/base/profile.nim @@ -1,7 +1,7 @@ import threading export threading import comms/comms, stdUtils, base/basicOps -import os, strutils, sequtils, std/monotimes +import os, strutils, sequtils, std/monotimes, std/tables, std/algorithm, strformat export monotimes getOptimPragmas() @@ -36,7 +36,7 @@ tic() injects symbols: ]## type - II = typeof(instantiationInfo()) + II = ptr typeof(instantiationInfo()) RTInfo = distinct int RTInfoObj = object nsec: int64 @@ -49,13 +49,19 @@ type CodePoint = distinct int32 CodePointObj = object toDropTimer: bool - name: string + #nsec: int64 + #overhead: int64 + count: uint32 + dropcount: uint32 + name: CStr loc: II SString = static[string] | string List[T] = object # No GCed type allowed len,cap:int32 data:ptr UncheckedArray[T] RTInfoObjList = distinct List[RTInfoObj] + CStr = object + p, s: int32 var rtiListLength:int32 = 0 @@ -72,7 +78,7 @@ proc newList[T](len:int32 = 0):List[T] {.noinit.} = result.len = len result.cap = cap if cap > 0: - result.data = cast[ptr UncheckedArray[T]](alloc(sizeof(T)*cap)) + result.data = cast[ptr UncheckedArray[T]](allocShared(sizeof(T)*cap)) listChangeLen[T](cap) else: result.data = nil @@ -80,7 +86,7 @@ proc newListOfCap[T](cap:int32):List[T] {.noinit.} = result.len = 0 result.cap = cap if cap > 0: - result.data = cast[ptr UncheckedArray[T]](alloc(sizeof(T)*cap)) + result.data = cast[ptr UncheckedArray[T]](allocShared(sizeof(T)*cap)) listChangeLen[T](cap) else: result.data = nil @@ -92,16 +98,16 @@ proc setLen[T](ls:var List[T], len:int32) = if cap0 == 0: cap = 1 while cap < len: cap *= 2 - ls.data = cast[ptr UncheckedArray[T]](alloc(sizeof(T)*cap.int)) + ls.data = cast[ptr UncheckedArray[T]](allocShared(sizeof(T)*cap.int)) else: while cap < len: cap *= 2 - ls.data = cast[ptr UncheckedArray[T]](realloc(ls.data, sizeof(T)*cap.int)) + ls.data = cast[ptr UncheckedArray[T]](reallocShared(ls.data, sizeof(T)*cap.int)) ls.cap = cap listChangeLen[T](int32(cap-cap0)) ls.len = len proc free[T](ls:var List[T]) = if ls.cap > 0: - dealloc(ls.data) + deallocShared(ls.data) listChangeLen[T](-ls.cap) ls.len = 0 ls.cap = 0 @@ -117,6 +123,8 @@ iterator items[T](ls:List[T]):T = proc setLen(ls:var RTInfoObjList, len:int32) {.borrow.} proc free(ls:var RTInfoObjList) {.borrow.} +#proc add(ls:var RTInfoObjList, x:RTInfoObj) {.borrow.} +proc add(ls:var RTInfoObjList, x:RTInfoObj) = add(List[RTInfoObj](ls), x) template len(ls:RTInfoObjList):int32 = List[RTInfoObj](ls).len template `[]`(ls:RTInfoObjList, n:int32):untyped = List[RTInfoObj](ls)[n] iterator mitems(ls:RTInfoObjList):var RTInfoObj = @@ -126,6 +134,51 @@ iterator mitems(ls:RTInfoObjList):var RTInfoObj = func isNil(x:RTInfo):bool = x.int<0 func isNil(x:CodePoint):bool = x.int<0 +const defaultCStrPoolCap {.intDefine.} = 512 +type CStrAtom = array[16,char] +var cstrpool = newListOfCap[CStrAtom](defaultCStrPoolCap) + +proc len(s:CStr):int = int(s.s) +proc newCStr(t:string):CStr = + const a = int32(sizeof(CStrAtom)) + let p = cstrpool.len + let s = int32(t.len) + let n = (s+a-1) div a + cstrpool.setlen(p+n) + var k = 0 + var j = 0 + for i in 0.. DropWasteTimerRatio: - #if ii.filename != "cg.nim": + inc thisCode.count + if oh.float > ns.float*DropWasteTimerRatio: + #if not toDropTimer(prevRTI) and oh.float > ns.float*DropWasteTimerRatio: + inc thisCode.dropcount + #if ii.filename != "scg.nim": # echo "drop timer: ", oh.float, "/", ns.float, "=", oh.float / ns.float # echo " ", prevRTI.int, " ", thisRTI.int, " ", ii, " ", s # Signal stop if the overhead is too large. - dropTimer(prevRTI) + if thisCode.dropcount > 10 and thisCode.dropcount*10 > thisCode.count: + #echo "dropTimer: ", rtiStack[thisRTI.int].tic.name, " ", thisCode.name, " ", thisCode.loc + dropTimer(prevRTI) + dropTimerChildren(thisRTI) + #dropTimer(thisCode) if toDropTimer(thisCode): freezeTimers() restartTimer = true @@ -416,7 +508,7 @@ template tocI(f: SomeNumber; s:SString = ""; n = -1) = bind items const cname = compiles(static[string](s)) - ii = instantiationInfo(n) + var ii {.global.} = instantiationInfo(n) when cname: var thisCode {.global.} = CodePoint(-1) else: @@ -437,11 +529,11 @@ template tocI(f: SomeNumber; s:SString = ""; n = -1) = let theTime = getTics() when not cname: for c in items(localCode): - if cpHeap[c.int].name == s: + if cpHeap[c.int].name.equal(s): thisCode = c break if thisCode.isNil: - thisCode = newCodePoint(ii, s) + thisCode = newCodePoint(ii.addr, s) when not cname: localCode.add thisCode let @@ -464,9 +556,9 @@ template tocI(f: SomeNumber; s:SString = ""; n = -1) = #echo "==== end toc ",s," ",ii else: when cname: - tocSet(localTimer,prevRTI,restartTimer,thisCode,f,s,ii,localTic,false) + tocSet(localTimer,prevRTI,restartTimer,thisCode,f,s,ii.addr,localTic,false) else: - tocSet(localTimer,prevRTI,restartTimer,thisCode,f,s,ii,localTic,addr localCode) + tocSet(localTimer,prevRTI,restartTimer,thisCode,f,s,ii.addr,localTic,addr localCode) when noTicToc: template toc*() = discard @@ -483,10 +575,10 @@ else: template toc*(n:int) = tocI(0, "", n-1) template toc*() = tocI(0, "", -2) -when noTicToc: - template getElapsedTime*: float = 0.0 -else: - template getElapsedTime*: float = ticDiffSecs(getTics(), localTimerStart) +#when noTicToc: +# template getElapsedTime*: float = 0.0 +#else: +template getElapsedTime*: float = ticDiffSecs(getTics(), localTimerStart) proc reset(x:var RTInfoObj) = x.nsec = 0 @@ -571,14 +663,14 @@ type Tstr = tuple label: string stats: string +func markMissing(p:bool,str:string):string = + if p: "[" & str & "]" + else: str template ppT(ts: RTInfoObjList, prefix = "-", total = 0'i64, overhead = 0'i64, count = 0'u32, initIx = 0, showAbove = 0.0, showDropped = true): seq[Tstr] = ppT(List[RTInfoObj](ts), prefix, total, overhead, count, initIx, showAbove, showDropped) proc ppT(ts: List[RTInfoObj], prefix = "-", total = 0'i64, overhead = 0'i64, count = 0'u32, initIx = 0, showAbove = 0.0, showDropped = true): seq[Tstr] = - proc markMissing(p:bool,str:string):string = - if p: "[" & str & "]" - else: str var sub:int64 = 0 subo:int64 = 0 @@ -610,7 +702,7 @@ proc ppT(ts: List[RTInfoObj], prefix = "-", total = 0'i64, overhead = 0'i64, small = total!=0 and ns.float/total.float0) loc = pre & markMissing(noexpand, f0 & "(" & $l0 & "-" & (if f==f0:"" else:f) & $l & ")") - nm = pre & markMissing(noexpand, (if tn=="":"" else:tn&":") & (if pn=="":"" else:pn&"-") & ts[j].curr.name) + nm = pre & markMissing(noexpand, (if tn.len==0:"" else: $tn & ":") & (if pn.len==0:"" else: $pn & "-") & $ts[j].curr.name) if total!=0: let cent = 1e2 * ns.float / total.float @@ -670,7 +762,7 @@ template echoTimers*(expandAbove = 0.0, expandDropped = true, aggregate = true) if n0.0: ", not expanding contributions less than " & $(1e2*expandAbove) & " %" else:"" - echo "Timer total ",(tt.float*1e-6)|(0,-3)," ms, overhead ",(oh.float*1e-6)|(0,-3)," ms ~ ",(1e2*oh.float/tt.float)|(0,-1)," %, memory ",rtiListLength*sizeof(RTInfoObj)," B, max ",rtiListLengthMax*sizeof(RTInfoObj)," B",notshowing + echo "Timer total ",(tt.float*1e-6)|(0,-3)," ms, overhead ",(oh.float*1e-6)|(0,-3)," ms ~ ",(1e2*oh.float/tt.float)|(0,-1)," %, runtime info ",rtiListLength*sizeof(RTInfoObj)," B, max ",rtiListLengthMax*sizeof(RTInfoObj)," B, string ",cstrpool.len*sizeof(cstrpool[0])," B",notshowing echo '='.repeat(width) echo "file(lines)"|(-n), "%"|6, "OH%"|6, "microsecs"|12, "OH"|8, "count"|9, "ns/count"|14, "OH/c"|8, "mf"|8, " label" echo '='.repeat(width) @@ -683,6 +775,110 @@ proc echoTimersRaw* = echo cpHeap echo rtiStack +proc getName(t: ptr RTInfoObj): string = + let tn = t.tic.name + let pn = t.prev.name + let name = (if tn.len==0:"" else: $tn & ":") & (if pn.len==0:"" else: $pn & "-") & $t.curr.name + if t.prev.toDropTimer: + result = "{" & name & "}" + else: result = name + +var hs = initTable[string,RTInfoObj]() +proc makeHotspotTable(lrti: List[RTInfoObj]): tuple[ns:int64,oh:int64] = + var nstot = int64 0 + var ohtot = int64 0 + for ri in lrti: + let nc = ri.count + if ri.istic or nc==0: continue + let + f0 = splitFile(ri.prev.loc.filename)[1] + l0 = ri.prev.loc.line + f = splitFile(ri.curr.loc.filename)[1] + l = ri.curr.loc.line + loc = f0 & "(" & $l0 & "-" & (if f==f0:"" else:f) & $l & ") " & getName(unsafeAddr ri) + coh = ri.childrenOverhead + soh = ri.overhead + nsec = ri.nsec + ns = nsec - coh + oh = soh + coh + nstot += ns + ohtot += oh + hs.withValue(loc, t): # t is found value for loc + t.nsec += ri.nsec + t.flops += ri.flops + t.overhead += ri.overhead + t.childrenOverhead += ri.childrenOverhead + t.count += ri.count + for i in 0..0: + let ins = v.nsec - v.childrenOverhead + skeys.add (ns: ins, loc: "incl" & k) + #if v.children.len>0: + var sns = v.nsec + if v.children.len==0: sns -= v.childrenOverhead + for i in 0..0.0: mfs.int|7 else: " ") + nc = (if t.children.len>0: t.children.len|4 else: " ") + f0 = splitFile(t.prev.loc.filename)[1] + l0 = t.prev.loc.line + f = splitFile(t.curr.loc.filename)[1] + l = t.curr.loc.line + loc = f0 & "(" & $l0 & "-" & (if f==f0:"" else:f) & $l & ")" + lc = loc|(-nloc,'.') + nm = getName(t) + #echo &"{pct:6.3f} {ohpct:6.3f} {count} {mf} {nc} {lc} {nm}" + if incl: + echo &"{pct:6.3f} {count} {mf} {nc} I {lc} {nm}" + else: + tsns += nk.ns + let tsnspct = 100.0 * tsns / nstot + echo &"{pct:6.3f} {tsnspct:7.3f} {count} {mf} {nc} S {lc} {nm}" + when isMainModule: import os proc test = @@ -793,3 +989,4 @@ when isMainModule: toc("end") echoTimers() echoTimersRaw() + echoHotspots() diff --git a/src/base/stdUtils.nim b/src/base/stdUtils.nim index 401c088..e0430c3 100644 --- a/src/base/stdUtils.nim +++ b/src/base/stdUtils.nim @@ -59,19 +59,19 @@ proc indexOf*[T](x: openArray[T], y: auto): int = let n = x.len while result= tbar0: break else: inc threadLocals.share[threadNum].counter - #fence() - fence() -template t0wait* = threadBarrier() + fence() +template t0wait* = t0waitO() template twait0O* = threadBarrier() template twait0X* = if threadNum==0: inc threadLocals.share[0].counter - #fence() + fence() else: inc threadLocals.share[threadNum].counter let tbar0 = threadLocals.share[threadNum].counter let p{.volatile.} = threadLocals.share[0].counter.addr while true: + fence() if p[] >= tbar0: break - fence() -template twait0* = threadBarrier() +template twait0* = twait0O() template threadBarrier* = #t0waitX diff --git a/src/bench/benchQio.nim b/src/bench/benchQio.nim index 790c837..feabf4a 100644 --- a/src/bench/benchQio.nim +++ b/src/bench/benchQio.nim @@ -27,8 +27,8 @@ proc test(lat: seq[int]) = var rs = newRNGField(RngMilc6, lo, intParam("seed", 987654321).uint64) #var r0 = lo.RealS() #var r1 = lo.RealS() - var r0 = newSeqWith[type lo.ColorMatrixS](4, lo.ColorMatrixS) - var r1 = newSeqWith[type lo.ColorMatrixS](4, lo.ColorMatrixS) + var r0 = newSeqWith(4, lo.ColorMatrixS) + var r1 = newSeqWith(4, lo.ColorMatrixS) var fn = stringParam("fn", "testqio.lat") let bytes = r0.len * r0[0].numNumbers * sizeof(r0[0][0].numberType) byts.add bytes diff --git a/src/comms/comms.nim b/src/comms/comms.nim index 6b6e69b..9e4be9a 100644 --- a/src/comms/comms.nim +++ b/src/comms/comms.nim @@ -70,8 +70,13 @@ template allReduce*(c: Comm, x: var UncheckedArray[float64], n: int) = template pushSend*(c: Comm, rank: int, xx: SomeNumber) = var x = xx pushSend(c, rank, &&x, sizeof(x)) +template pushSend*(c: Comm, rank: int, x: object) = + pushSend(c, rank, &&x, sizeof(x)) template pushSend*(c: Comm, rank: int, x: seq) = pushSend(c, rank, &&x[0], x.len*sizeof(x[0])) +template waitSend*(c: Comm) = + let n = c.nsends - 1 + c.waitSends(n, n) template waitSends*(c: Comm) = c.waitSends(0, c.nsends-1) template waitSends*(c: Comm, k: int) = @@ -81,6 +86,8 @@ template waitSends*(c: Comm, k: int) = template pushRecv*(c: Comm, rank: int, x: SomeNumber) = pushRecv(c, rank, &&x, sizeof(x)) +template pushRecv*(c: Comm, rank: int, x: object) = + pushRecv(c, rank, &&x, sizeof(x)) template pushRecv*(c: Comm, rank: int, x: seq) = pushRecv(c, rank, &&x[0], x.len*sizeof(x[0])) template freeRecvs*(c: Comm; k: int) = diff --git a/src/comms/commsUtils.nim b/src/comms/commsUtils.nim index 647c0ad..bb9a423 100644 --- a/src/comms/commsUtils.nim +++ b/src/comms/commsUtils.nim @@ -47,6 +47,7 @@ macro echo0*(args: varargs[untyped]): untyped = var call = newCall(bindSym"echoRaw") result = evalArgs(call, args) result.add(quote do: + bind myRank if myRank==0 and threadNum==0: `call` ) @@ -179,16 +180,18 @@ macro rankSumN*(comm:Comm, a:varargs[typed]):auto = i0 = i break if i0<0: + let b = newNimNode(nnkBracket) var s = newNimNode(nnkStmtList) let t = ident("t") for i in 0.. (x_i - x_j)^2 +import qex, strformat + +qexInit() + +let + lat = intSeqParam("lat", @[16]) + lo = lat.newLayout + seed = uint64 intParam("seed", 1) + ntraj = intParam("ntraj", 2) + nsteps = intParam("nsteps", 2) + tau = floatParam("tau", 1) + +var + nAccept = 0 + x = lo.Real + p = lo.Real + xSave = lo.Real + #pSave = lo.Real + rng = lo.newRNGField(MRG32k3a, seed) + globalRng: MRG32k3a # global RNG +globalRng.seed(seed, 987654321) + +echo "ntraj: ", ntraj +echo "nsteps: ", nsteps +echo "tau: ", tau + +proc refreshMomentum(p: auto) = + threads: + p.gaussian rng + +proc action(p,x: auto): float = + var sp, sx: float + threads: + sp := 0.5*p.norm2 + var t = newShifters(x, 1) + let nd = lo.nDim + for mu in 0.. sp.r2req: - qexError "Max r2 larger than requested." + qexError &"Max r2 ({sp.r2.max}) larger than requested ({sp.r2req})" sp.resetStats proc reunit(g:auto) = diff --git a/src/experimental/graph/core.nim b/src/experimental/graph/core.nim new file mode 100644 index 0000000..e322488 --- /dev/null +++ b/src/experimental/graph/core.nim @@ -0,0 +1,349 @@ +#[ + +- Graph traversals are not thread safe +- backward functions for scalar output may receive nil gradient + +TODO + +- function/lambda + +]# + +from strutils import join, toHex, strip + +type + Gfunc* {.acyclic.} = ref object + ## Represent an functional operation: [input] -> output, + forward: proc(z: Gvalue) + backward: proc(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue ## create new graph for backprop + runCount: int + name: string + Gtag = enum + gtVisited, gtGrad, gtFixedGrad + Gtags = set[Gtag] + Gvalue* {.acyclic.} = ref object of RootObj + ## A Value knows its dependencies, which allows backpropagation. + tag: Gtags + inputs*: seq[Gvalue] + gfunc*: Gfunc + locals*: seq[Gvalue] ## for sharing values between forward and among backward functions + epoch: int + +type + GraphError* = object of Defect + GraphValueError* = object of GraphError + +template raiseError*(msg: string) = + raise newException(GraphError, msg) + +template raiseValueError*(msg: string) = + raise newException(GraphValueError, msg) + +template raiseErrorBaseMethod*(msg: string) = + raiseError( + "Base method invoked: " & msg & + "\nMake sure to pass `--multimethods:on` and check there is a custom method for each derived type.") + +var graphDebug* = false + +proc newGfunc*( + forward: proc(z: Gvalue) = nil, + backward: proc(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue = nil, + name: string): Gfunc = + Gfunc( + forward: forward, + backward: backward, + name: name) + +proc runCount*(f: Gfunc): int = f.runCount + +proc `$`*(x: Gfunc): string + +method `$`*(x: Gvalue): string {.base.} = + let f = x.gfunc + result = "Gvalue(" & $x.epoch & " " & $x.tag & ")" + if f != nil: + result &= " " & $f + +proc `$`*(x: Gfunc): string = x.name & "<" & $x.runCount & ">" + +proc nodeRepr*(x: Gvalue): string = + let f = x.gfunc + result = $x & " (" & $x.epoch & " " & $x.tag & ")" & "@0X" & strip(toHex(cast[int](x)), trailing = false, chars = {'0'}) + if f != nil: + result &= " " & $f & "@0X" & strip(toHex(cast[int](f)), trailing = false, chars = {'0'}) + +method newOneOf*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("newOneOf(" & $x & ")") ## Be sure to zero init fields +method valCopy*(z: Gvalue, x: Gvalue) {.base.} = raiseErrorBaseMethod("valCopy(" & $z & "," & $x & ")") + +method isZero*(x: Gvalue): bool {.base.} = raiseErrorBaseMethod("isZero(" & $x & ")") +method update*(x: Gvalue, y: int) {.base.} = raiseErrorBaseMethod("update(" & $x & "," & $y & ")") +method update*(x: Gvalue, y: float) {.base.} = raiseErrorBaseMethod("update(" & $x & "," & $y & ")") + +proc assignGvalue(z: Gvalue, x: Gvalue) = + z.tag = x.tag + z.inputs = x.inputs + z.gfunc = x.gfunc + z.epoch = x.epoch + z.valCopy x + +proc copyGvalue(x: Gvalue): Gvalue = + result = newOneOf x + result.assignGvalue x + +let identPlaceholderGFunc = newGfunc(name = "identPlaceholder") +proc identPlaceholder(x: Gvalue): Gvalue = + result = x.copyGvalue + result.tag = {} + result.inputs = @[x] + result.gfunc = identPlaceholderGFunc + result.epoch = 0 + +proc tagClearVisited(x: Gvalue) = + ## only works after recursive proc used gtVisited for the graph traversal. + if gtVisited in x.tag: + x.tag.excl gtVisited + for i in x.inputs: + i.tagClearVisited + +proc tagClear(x: Gvalue, t: Gtag) = + proc c(v: Gvalue) = + if gtVisited in v.tag: + return + v.tag.incl gtVisited + v.tag.excl t + for i in v.inputs: + i.c + x.c + x.tagClearVisited + +proc treeRepr*(v: Gvalue): string = + var shared = newseq[Gvalue]() + proc s(x: Gvalue) = + if gtVisited in x.tag: + if shared.find(x) < 0: + shared.add x + else: + x.tag.incl gtVisited + for i in x.inputs: + i.s + proc r(x: Gvalue): seq[string] = + let si = shared.find x + result = @[x.nodeRepr] + if gtVisited in x.tag: + result[0] &= " #" & $si + else: + if si >= 0: + result[0] &= " #" & $si & "#" + x.tag.incl gtVisited + for i in x.inputs: + for ir in i.r: + result.add(" " & ir) + v.s + v.tagClearVisited + result = v.r.join "\n" + v.tagClearVisited + +method `-`*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("`-`(" & $x & ")") +method `+`*(x: Gvalue, y: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("`+`(" & $x & ", " & $y & ")") +method `*`*(x: Gvalue, y: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("`*`(" & $x & ", " & $y & ")") +method `-`*(x: Gvalue, y: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("`-`(" & $x & ", " & $y & ")") +method `/`*(x: Gvalue, y: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("`/`(" & $x & ", " & $y & ")") +method exp*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("exp(" & $x & ")") + +proc cond*(c: Gvalue, x: Gvalue, y: Gvalue): Gvalue + +proc condb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue = + case i + of 0: + let r = z.inputs[0].newOneOf + r.update 0 + return r + of 1: + if zb == nil: + # the output must be a scalar, otherwise crash later + let r1 = z.inputs[1].newOneOf + let r0 = z.inputs[1].newOneOf + r1.update 1 + r0.update 0 + return cond(z.inputs[0], r1, r0) + else: + return cond(z.inputs[0], zb, zb.newOneOf) + of 2: + if zb == nil: + # the output must be a scalar, otherwise crash later + let r0 = z.inputs[2].newOneOf + let r1 = z.inputs[2].newOneOf + r0.update 0 + r1.update 1 + return cond(z.inputs[0], r0, r1) + else: + return cond(z.inputs[0], zb.newOneOf, zb) + else: + raiseValueError("i must be 0 or 1, got: " & $i) + +proc condf(v: Gvalue) = + if v.inputs[0].isZero: + v.valCopy v.inputs[2] + else: + v.valCopy v.inputs[1] + +let gcond = newGfunc(forward = condf, backward = condb, name = "cond") + +proc cond*(c: Gvalue, x: Gvalue, y: Gvalue): Gvalue = + ## Assume the result is the same type as y, otherwise it'll throw exception later in forward valCopy. + result = y.newOneOf + result.inputs = @[c, x, y] + result.gfunc = gcond + +proc updated*(x: Gvalue) = + var epoch {.global.} = 0 + inc epoch + x.epoch = epoch + +proc evaluated*(x: Gvalue) = + ## signal up-to-date value, given inputs, useful for update value outside of eval, as in update to locals in forward + var maxep = 0 + for i in x.inputs: + if maxep < i.epoch: + maxep = i.epoch + x.epoch = maxep + +proc eval*(v: Gvalue): Gvalue {.discardable.} = + proc r(x: Gvalue) = + if gtVisited in x.tag: + return + x.tag.incl gtVisited + var maxep = 0 + if x.gfunc == gcond: + x.inputs[0].r + if maxep < x.inputs[0].epoch: + maxep = x.inputs[0].epoch + if x.inputs[0].isZero: + x.inputs[2].r + if maxep < x.inputs[2].epoch: + maxep = x.inputs[2].epoch + else: + x.inputs[1].r + if maxep < x.inputs[1].epoch: + maxep = x.inputs[1].epoch + else: + for i in x.inputs: + i.r + if maxep < i.epoch: + maxep = i.epoch + if x.epoch < maxep: + let f = x.gfunc + if graphDebug: + var s = "[graph/core] eval: " & x.nodeRepr + for c in x.inputs: + s &= "\n " & c.nodeRepr + echo s + if f.forward != nil: + x.epoch = maxep + f.runCount.inc + f.forward x + else: + raiseError("inputs.len: " & $x.inputs.len & ", but no forward function defined for:\n" & x.nodeRepr) + v.r + v.tagClearVisited + v + +type + Grad = object + input: Gvalue + grad: Gvalue + Grads = object + output: Gvalue + grads: seq[Grad] + +var gradientList = newseq[Grads]() + +proc dumpGradientList* = + echo "# Gradient List:" + for gs in gradientList: + echo "## output: ",gs.output.nodeRepr + for g in gs.grads: + echo "### w.r.t.: ",g.input.nodeRepr + echo g.grad.treeRepr + +proc recordGrad(input: Gvalue, output: Gvalue, gradient: Gvalue) = + for k in 0..= 0: + for k in 0..0: + g = axexpmuly(t0, p, g) + p = p - h * gaugeForce(gc, g) + g = axexpmuly(t1, p, g) + p = p - h * gaugeForce(gc, g) + g = axexpmuly(t05, p, g) + (g, p, @[lambda]) + +proc int4MN3F1GP(gc, g0, p0, dt: Gvalue, n: int, coeffs: openarray[float]): (Gvalue, Gvalue, seq[Gvalue]) = + let lambda = coeffs.get(0, 0.2470939580390842) + let theta = coeffs.get(1, 0.5 - 1.0 / sqrt(24.0 * lambda.getfloat)) + # scale the force gradient coeff to about the same order as the other + let chi = coeffs.get(2, (1.0 - sqrt(6.0 * lambda.getfloat) * (1.0 - lambda.getfloat)) / 12.0 * (2.0 / (1.0 - 2.0*lambda.getfloat) * 10.0)) + var g = g0 + var p = p0 + let a0 = theta*dt + let a02 = 2.0*a0 + let a1 = 0.5*dt - a0 + let b0 = lambda*dt + let b1 = dt - 2.0*b0 + let c1 = 0.1*chi*(dt*dt) + g = axexpmuly(a0, p, g) + for i in 0..0: + g = axexpmuly(a02, p, g) + p = p - b0 * gaugeForce(gc, g) + g = axexpmuly(a1, p, g) + p = p - b1 * gaugeForce(gc, axexpmuly(-c1, gaugeForce(gc, g), g)) + g = axexpmuly(a1, p, g) + p = p - b0 * gaugeForce(gc, g) + g = axexpmuly(a0, p, g) + (g, p, @[lambda, theta, chi]) + +proc int4MN5F2GP(gc, g0, p0, dt: Gvalue, n: int, coeffs: openarray[float]): (Gvalue, Gvalue, seq[Gvalue]) = + let rho = coeffs.get(0, 0.06419108866816235) + let theta = coeffs.get(1, 0.1919807940455741) + let vtheta = coeffs.get(2, 0.1518179640276466) + let lambda = coeffs.get(3, 0.2158369476787619) + # scale the force gradient coeff to about the same order as the other + let xi = coeffs.get(4, 0.0009628905212024874 * (2.0 / lambda.getfloat * 20.0)) + var g = g0 + var p = p0 + let a0 = rho*dt + let a02 = 2.0*a0 + let a1 = theta*dt + let a2 = (0.5-(theta+rho))*dt + let b1 = lambda*dt + let b0 = vtheta*dt + let b2 = (1.0-2.0*(lambda+vtheta))*dt + let c1 = 0.05*xi*(dt*dt) + g = axexpmuly(a0, p, g) + for i in 0..0: + g = axexpmuly(a02, p, g) + p = p - b0 * gaugeForce(gc, g) + g = axexpmuly(a1, p, g) + p = p - b1 * gaugeForce(gc, axexpmuly(-c1, gaugeForce(gc, g), g)) + g = axexpmuly(a2, p, g) + p = p - b2 * gaugeForce(gc, g) + g = axexpmuly(a2, p, g) + p = p - b1 * gaugeForce(gc, axexpmuly(-c1, gaugeForce(gc, g), g)) + g = axexpmuly(a1, p, g) + p = p - b0 * gaugeForce(gc, g) + g = axexpmuly(a0, p, g) + (g, p, @[rho, theta, vtheta, lambda, xi]) + +qexInit() + +tic() + +letParam: + gaugefile = "" + savefile = "config" + savefreq = 0 + lat = + if fileExists(gaugefile): + getFileLattice gaugefile + else: + if gaugefile.len > 0: + qexWarn "Nonexistent gauge file: ", gaugefile + @[8,8,8,16] + beta = 5.4 + dt = 0.025 + trajsThermo = 0 + trajsTrain = 50 + trajsTrainlrWarm = 10 + trajsInfer = 0 + lrmax = 1.0 + lrmin = 0.0001 + weightDecay = 0.0 + seed:uint = 1234567891 + gintalg = "2MN" + lambda = @[0.0] + gsteps = 4 + alwaysAccept:bool = 0 + +echo "rank ", myRank, "/", nRanks +threads: echo "thread ", threadNum, "/", numThreads + +installStandardParams() +echoParams() +processHelpParam() + +let + lo = lat.newLayout + vol = lo.physVol + gc = actWilson(beta) + +var r = lo.newRNGField(RngMilc6, seed) +var R:RngMilc6 # global RNG +R.seed(seed, 987654321) + +var + g = lo.newgauge + p = lo.newgauge + +if fileExists(gaugefile): + tic("load") + if 0 != g.loadGauge gaugefile: + qexError "failed to load gauge file: ", gaugefile + qexLog "loaded gauge from file: ", gaugefile," secs: ",getElapsedTime() + toc("read") + g.reunit + toc("reunit") +else: + #g.random r + g.unit + +g.echoPlaq + +let gdt = toGvalue dt +var params = @[gdt] +let + gg = toGvalue g + gp = toGvalue p + ga0 = gc.gaugeAction gg + t0 = 0.5 * gp.norm2 + h0 = ga0 + t0 + tau = float(gsteps) * gdt + (g1, p1, coeffs) = case gintalg + of "2MN": + int2MN(gc, gg, gp, gdt, gsteps, lambda) + of "4MN3F1GP": + int4MN3F1GP(gc, gg, gp, gdt, gsteps, lambda) + of "4MN5F2GP": + int4MN5F2GP(gc, gg, gp, gdt, gsteps, lambda) + else: + raise newException(ValueError, "unknown intalg: " & gintalg) + ga1 = gc.gaugeAction g1 + t1 = 0.5 * p1.norm2 + h1 = ga1 + t1 + dH = h1 - h0 + acc = cond(dH<0.0, 1.0, exp(-dH)) + loss = -acc * (tau * tau) + +params.add coeffs +var grads = newseq[Gvalue]() +for x in params: + grads.add loss.grad x + +var param = newseq[float]() +for x in params: + param.add x.getfloat +var grad = param +var opt = newAdamW(param, lambda = weightDecay) + +block: + var ps = "param:" + for i in 0.. 0 and traj mod savefreq == 0: + tic("save") + let fn = savefile & &".{traj:05}.lime" + if 0 != g.saveGauge(fn): + qexError "Failed to save gauge to file: ",fn + qexLog "saved gauge to file: ",fn," secs: ",getElapsedTime() + toc("done") + + qexLog "traj ",traj," secs: ",getElapsedTime() + toc("traj end") + +toc() + +processSaveParams() +writeParamFile() +qexFinalize() diff --git a/src/experimental/graph/multi.nim b/src/experimental/graph/multi.nim new file mode 100644 index 0000000..c1f3576 --- /dev/null +++ b/src/experimental/graph/multi.nim @@ -0,0 +1,88 @@ +import core, scalar + +type + Gmulti* {.final.} = ref object of Gvalue + mval: seq[Gvalue] + +proc getmulti*(x: Gvalue): seq[Gvalue] = Gmulti(x).mval + +proc `getmulti=`*(x: Gvalue, y: seq[Gvalue]) = + let xs = Gmulti(x) + xs.mval = y + +proc update*(x: Gvalue, y: seq[Gvalue]) = + x.getmulti = y + x.updated + +proc toGvalue*(x: seq[Gvalue]): Gmulti = + # proc instead of converter to avoid converting seq + result = Gmulti(mval: x) + result.updated + +method newOneOf*(x: Gmulti): Gvalue = + let r = Gmulti(mval: newseq[Gvalue](x.mval.len)) + for i in 0..`*(x, y: Gvalue): Gvalue = not(x < y) +proc `>=`*(x, y: Gvalue): Gvalue = x > y or equal(x,y) +proc `<=`*(x, y: Gvalue): Gvalue = x < y or equal(x,y) + +proc ltsb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue = + case i + of 0, 1: + return toGvalue(0.0) + else: + raiseValueError("i must be 0 or 1, got: " & $i) + +proc ltsf(v: Gvalue) = + let x = Gscalar(v.inputs[0]) + let y = Gscalar(v.inputs[1]) + let z = Gscalar(v) + z.sval = if x.sval < y.sval: 1.0 else: 0.0 + +let lts = newGfunc(forward = ltsf, backward = ltsb, name = "lts") + +method `<`*(x: Gscalar, y: Gscalar): Gvalue = Gscalar(inputs: @[Gvalue(x), y], gfunc: lts) + +proc equalsb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue = + case i + of 0, 1: + return toGvalue(0.0) + else: + raiseValueError("i must be 0 or 1, got: " & $i) + +proc equalsf(v: Gvalue) = + let x = Gscalar(v.inputs[0]) + let y = Gscalar(v.inputs[1]) + let z = Gscalar(v) + z.sval = if x.sval == y.sval: 1.0 else: 0.0 + +let equals = newGfunc(forward = equalsf, backward = equalsb, name = "equals") + +method equal*(x: Gscalar, y: Gscalar): Gvalue = Gscalar(inputs: @[Gvalue(x), y], gfunc: equals) + +proc ltib(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue = + case i + of 0, 1: + return toGvalue(0) + else: + raiseValueError("i must be 0 or 1, got: " & $i) + +proc ltif(v: Gvalue) = + let x = Gint(v.inputs[0]) + let y = Gint(v.inputs[1]) + let z = Gint(v) + z.ival = if x.ival < y.ival: 1 else: 0 + +let lti = newGfunc(forward = ltif, backward = ltib, name = "lti") + +method `<`*(x: Gint, y: Gint): Gvalue = Gint(inputs: @[Gvalue(x), y], gfunc: lti) + +proc equalib(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue = + case i + of 0, 1: + return toGvalue(0) + else: + raiseValueError("i must be 0 or 1, got: " & $i) + +proc equalif(v: Gvalue) = + let x = Gint(v.inputs[0]) + let y = Gint(v.inputs[1]) + let z = Gint(v) + z.ival = if x.ival == y.ival: 1 else: 0 + +let equali = newGfunc(forward = equalif, backward = equalib, name = "equali") + +method equal*(x: Gint, y: Gint): Gvalue = Gint(inputs: @[Gvalue(x), y], gfunc: equali) + +when isMainModule: + import math + import std/assertions + + graphDebug = true + + let x = Gscalar() + let y = Gscalar() + let w = x-2.0 + let v = w+y + let z = v*(-v)/w + let dzdy = z.grad y + + func f(a, b: float): float = (a+b-2.0)*(2.0-a-b)/(a-2.0) + func dfdb(a, b: float): float = -2.0*(a+b-2.0)/(a-2.0) + + let a = 1.1 + let b = 3.7 + let c = 1.3 + + x.update a + y.update b + echo z.treeRepr + echo dzdy.treeRepr + z.eval + dzdy.eval + echo "z = ",z + echo z.treeRepr + echo "dzdy = ",dzdy + echo dzdy.treeRepr + + dumpGradientList() + + doAssert almostEqual(z.getfloat, f(a,b)) + doAssert almostEqual(dzdy.getfloat, dfdb(a,b)) + + y.update c + z.eval + dzdy.eval + echo "z = ",z + echo z.treeRepr + echo "dzdy = ",dzdy + echo dzdy.treeRepr + doAssert almostEqual(z.getfloat, f(a,c)) + doAssert almostEqual(dzdy.getfloat, dfdb(a,c)) + + # may need to change the following after we implement optimization passes + doAssert gsneg.runCount == 4 + doAssert gsadd.runCount == 4 + doAssert gsmul.runCount == 6 + doAssert gssub.runCount == 1 + doAssert gsdiv.runCount == 3 diff --git a/src/experimental/graph/tggauge.nim b/src/experimental/graph/tggauge.nim new file mode 100644 index 0000000..a1c188d --- /dev/null +++ b/src/experimental/graph/tggauge.nim @@ -0,0 +1,336 @@ +import math, unittest + +addOutputFormatter(newConsoleOutputFormatter(colorOutput = false)) + +# basicOps.epsilon collide with fenv.epsilon +import qex except epsilon +import algorithms/numdiff, gauge/stoutsmear +import core, scalar, gauge + +template checkeq(ii: tuple[filename:string, line:int, column:int], sa: string, a: float, sb: string, b: float) = + if not almostEqual(a, b, unitsInLastPlace = 64): + checkpoint(ii.filename & ":" & $ii.line & ":" & $ii.column & ": Check failed: " & sa & " :~ " & sb) + checkpoint(" " & sa & ": " & $a) + checkpoint(" " & sb & ": " & $b) + fail() + +template `:~`(a:Gvalue, b:Gvalue) = + checkeq(instantiationInfo(), astToStr a, a.eval.getfloat, astToStr b, b.eval.getfloat) + +template `:<`(a:Gvalue, b:float) = + let av = a.eval.getfloat.abs + if av >= b: + let ii = instantiationInfo() + let sa = astToStr a + let sb = astToStr b + checkpoint(ii.filename & ":" & $ii.line & ":" & $ii.column & ": Check failed: " & sa & " :< " & sb) + checkpoint(" " & sa & ": " & $av) + checkpoint(" " & sb & ": " & $b) + fail() + +# basic test: y <- f(x), or z = y B† = f(x) B†, with x = x + t A +# d/dt z = d/dt f(x+tA) B† +# d/dt z = (d/dt y) (d/dy z)† = (d/dt x) (d/dx z)† = (d/dx z) A† + +proc ndiff(zt: Gvalue, t: Gscalar): (float, float) = + proc z(v:float):float = + t.update v + zt.eval.getfloat + var dzdt,e: float + ndiff(dzdt, e, z, 0.0, 0.125, ordMax=3) + (dzdt, e) + +template check(ii: tuple[filename:string, line:int, column:int], ast: string, dzdt, e, gdota: float) = + if not almostEqual(gdota, dzdt, unitsInLastPlace = 512*1024): + checkpoint(ii.filename & ":" & $ii.line & ":" & $ii.column & ": Check failed: " & ast) + checkpoint(" ndiff: " & $dzdt & " +/- " & $e) + checkpoint(" grad: " & $gdota) + checkpoint(" reldelta: " & $(abs(dzdt-gdota)/abs(dzdt+gdota))) + fail() + +template ckforce(s: untyped, f: untyped, x: Gvalue, p: Gvalue) = + let t = Gscalar() + let (dsdt, e) = ndiff(s(exp(t*p)*x), t) + let pdotf = eval(redot(p, f(x))).getfloat + check(instantiationInfo(), astTostr(s(x) -> f(x)), dsdt, e, pdotf) + +template ckgrad(f: untyped, x: Gvalue, a: Gvalue) = + let t = Gscalar() + let (dzdt, e) = ndiff(f(x+t*a), t) + let ff = f(x) + let gdota = eval(redot(grad(ff, x), a)).getfloat + check(instantiationInfo(), astTostr(f(x)), dzdt, e, gdota) + +template ckgrad2(f: untyped, x: Gvalue, y: Gvalue, ax: Gvalue, ay: Gvalue) = + let t = Gscalar() + let (dzdt, e) = ndiff(f(x+t*ax, y+t*ay), t) + let ff = f(x, y) + let gdota = eval(redot(grad(ff, x), ax) + redot(grad(ff, y), ay)).getfloat + check(instantiationInfo(), astTostr(f(x,y)), dzdt, e, gdota) + +template ckgradm(f: untyped, x: Gvalue, a: Gvalue, b: Gvalue) = + let t = Gscalar() + let (dzdt, e) = ndiff(f(x+t*a).redot b, t) + let ff = f(x).redot b + let gdota = eval(redot(grad(ff, x), a)).getfloat + check(instantiationInfo(), astTostr(f(x)), dzdt, e, gdota) + +template ckgradm2(f: untyped, x: Gvalue, y: Gvalue, ax: Gvalue, ay: Gvalue, b: Gvalue) = + let t = Gscalar() + let (dzdt, e) = ndiff(f(x+t*ax, y+t*ay).redot b, t) + let ff = f(x, y).redot b + let gdota = eval(redot(grad(ff, x), ax) + redot(grad(ff, y), ay)).getfloat + check(instantiationInfo(), astTostr(f(x,y)), dzdt, e, gdota) + +template ckgradm3(f: untyped, x: Gvalue, y: Gvalue, u: Gvalue, ax: Gvalue, ay: Gvalue, au: Gvalue, b: Gvalue) = + let t = Gscalar() + let (dzdt, e) = ndiff(f(x+t*ax, y+t*ay, u+t*au).redot b, t) + let ff = f(x, y, u).redot b + let gdota = eval(redot(grad(ff, x), ax) + redot(grad(ff, y), ay) + redot(grad(ff, u), au)).getfloat + check(instantiationInfo(), astToStr(f(x,y,u)), dzdt, e, gdota) + +qexInit() + +let + lat = @[8,8,8,16] + lo = lat.newLayout + seed = 1234567891u64 + vol = lo.physVol +var + r = lo.newRNGField(MRG32k3a, seed) + g = lo.newgauge + u = lo.newgauge + p = lo.newgauge + q = lo.newgauge + m = lo.newgauge + ss = lo.newStoutSmear(0.1) +const nc = g[0][0].nrows +threads: + g.random r + u.random r + p.randomTAH r + q.randomTAH r + m.randomTAH r +for i in 0..4: + ss.smear(g, g) + ss.smear(u, u) +threads: + for t in m: + t *= 0.01 + +let a = 0.5 * (sqrt(5.0) - 1.0) +let b = sqrt(2.0) - 1.0 + +suite "gauge basic": + setup: + let gg {.used.} = toGvalue g + let gu {.used.} = toGvalue u + let gp {.used.} = toGvalue p + let gq {.used.} = toGvalue q + let gm {.used.} = toGvalue m + let x {.used.} = toGvalue a + let y {.used.} = toGvalue b + + test "norm2": + let n2 = gg.norm2 + let p2 = gp.norm2 + let dp = grad(0.5 * p2, gp) + n2 :~ 4.0*float(nc*vol) + dp.norm2 :~ p2 + norm2(dp-gp) :~ 0 + ckgrad(norm2, gm, gq) + + test "redot": + let n2 = gg.redot gg + let p2 = gp.redot gp + let dp = grad(0.5 * p2, gp) + n2.eval :~ 4.0*float(nc*vol) + dp.norm2 :~ p2 + norm2(dp-gp).eval :~ 0 + let pq = gp.redot gq + norm2(grad(pq, gp) - gq) :< 1e-26 + norm2(grad(pq, gq) - gp) :< 1e-26 + ckgrad2(redot, gp, gq, gg, gu) + + test "retr": + let rtp = gp.retr + let n2 = retr(gg * gg.adj) + rtp*rtp :< 1e-20 + n2.eval :~ 4.0*float(nc*vol) + let p2 = retr(gp * gq.adj) + p2 :~ redot(gp, gq) + norm2(grad(p2, gp) - gq) :< 1e-26 + norm2(grad(p2, gq) - gp) :< 1e-26 + ckgrad(retr, gp, gq) + + test "adj": + norm2(gg.adj*gg - 1.0)/float(4*nc*vol) :< 1e-22 + norm2(gg*gg.adj - 1.0)/float(4*nc*vol) :< 1e-22 + norm2(gp.adj + gp) :< 1e-26 + norm2(grad(gp.adj.norm2, gp) - 2.0*gp) :< 1e-26 + ckgradm(adj, gg, gp, gq) + + test "neg": + norm2(gp.adj - (-gp)) :< 1e-26 + norm2(-gp) :~ gp.norm2 + ckgradm(`-`, gg, gp, gq) + + test "addsg": + let p2 = norm2(x+gp) + grad(p2, x) :~ retr(2.0*(a+gp)) + norm2(grad(p2, gp) - 2.0*(a+gp)) :< 1e-26 + ckgradm2(`+`, x, gp, y, gq, gg) + + test "addgg": + let pq = norm2(gp+gq) + norm2(grad(pq, gp) - 2.0*(gp+gq)) :< 1e-26 + norm2(grad(pq, gq) - 2.0*(gp+gq)) :< 1e-26 + ckgradm2(`+`, gq, gp, gu, gg, gm) + + test "mulsg": + let p2 = norm2(x*gp) + grad(p2, x) :~ 2.0*a*gp.norm2 + norm2(grad(p2, gp) - 2.0*a*a*gp) :< 1e-26 + ckgradm2(`*`, x, gp, y, gq, gg) + + test "mulgg": + let pq = norm2(gp*gq) + norm2(grad(pq, gp) - 2.0*gp*gq*gq.adj) :< 1e-24 + norm2(grad(pq, gq) - 2.0*gp.adj*gp*gq) :< 1e-24 + ckgradm2(`*`, gq, gp, gu, gg, gm) + + test "subgs": + let p2 = norm2(gp-x) + grad(p2, x) :~ retr(-2.0*(gp-a)) + norm2(grad(p2, gp) - 2.0*(gp-x)) :< 1e-26 + ckgradm2(`-`, gp, x, gq, y, gg) + + test "subgg": + let pq = norm2(gp-gq) + norm2(grad(pq, gp) - 2.0*(gp-gq)) :< 1e-26 + norm2(grad(pq, gq) - 2.0*(gq-gp)) :< 1e-26 + ckgradm2(`-`, gq, gp, gu, gg, gm) + + test "exp": + let egp = exp(gp) + norm2(egp.adj*egp - 1.0) :< 1e-20 + norm2(egp*egp.adj - 1.0) :< 1e-20 + ckgradm(exp, gm, 0.1*gp, gg) + + test "projTAH": + let gt = gg.projTAH + let tgt = gt.retr + tgt*tgt :< 1e-26 + ckgradm(projTAH, gg, gp, gu) + +suite "gauge fused": + setup: + let gg {.used.} = toGvalue g + let gu {.used.} = toGvalue u + let gp {.used.} = toGvalue p + let gq {.used.} = toGvalue q + let gm {.used.} = toGvalue m + let x {.used.} = toGvalue a + let y {.used.} = toGvalue b + + test "adjmul": + let rf = gg.adjmul gu + let rg = gg.adj * gu + norm2(rf - rg) :< 1e-26 + let srf = rf.norm2 + let srg = rg.norm2 + norm2(grad(srf, gg) - grad(srg, gg)) :< 1e-25 + norm2(grad(srf, gu) - grad(srg, gu)) :< 1e-25 + ckgradm2(adjmul, gg, gu, gp, gq, gm) + + test "muladj": + let rf = gg.muladj gu + let rg = gg * gu.adj + norm2(rf - rg) :< 1e-26 + let srf = rf.norm2 + let srg = rg.norm2 + norm2(grad(srf, gg) - grad(srg, gg)) :< 1e-25 + norm2(grad(srf, gu) - grad(srg, gu)) :< 1e-25 + ckgradm2(muladj, gg, gu, gp, gq, gm) + + test "contractProjTAH": + let rf = contractProjTAH(gg, gu) + let rg = projTAH(gg * gu.adj) + norm2(rf - rg) :< 1e-26 + let srf = rf.norm2 + let srg = rg.norm2 + norm2(grad(srf, gg) - grad(srg, gg)) :< 1e-26 + norm2(grad(srf, gu) - grad(srg, gu)) :< 1e-26 + ckgradm2(contractProjTAH, gg, gu, gp, gq, gm) + + test "axexp": + let rf = axexp(x, gm) + let rg = exp(x*gm) + norm2(rf - rg) :< 1e-26 + let srf = retr(rf * gu) + let srg = retr(rg * gu) + grad(srf, x) :~ grad(srg, x) + norm2(grad(srf, gm) - grad(srg, gm)) :< 1e-26 + ckgradm2(axexp, x, gm, y, 0.05*gq, gp) + + test "axexpmuly": + let rf = axexpmuly(x, gm, gg) + let rg = exp(x*gm)*gg + norm2(rf - rg) :< 1e-26 + let srf = retr(rf * gu) + let srg = retr(rg * gu) + grad(srf, x) :~ grad(srg, x) + norm2(grad(srf, gm) - grad(srg, gm)) :< 1e-26 + norm2(grad(srf, gg) - grad(srg, gg)) :< 1e-26 + ckgradm3(axexpmuly, x, gm, gu, y, 0.05*gq, gg, gp) + +suite "gauge action": + let gplaq = block: + var pl = 0.0 + for t in g.plaq: + pl += t + pl + + setup: + let gg {.used.} = toGvalue g + let gu {.used.} = toGvalue u + let gm {.used.} = toGvalue m + + test "wilson action": + let beta = 5.4 + let c = actWilson(beta) + let s = gaugeAction(c, gg) + s :~ -gplaq*float(6*vol*beta) + proc act(x: Gvalue): Gvalue = gaugeAction(c, x) + ckgrad(act, gg, gu) + + test "wilson force": + let beta = 5.4 + let c = actWilson(beta) + proc act(x: Gvalue): Gvalue = gaugeAction(c, x) + proc force(x: Gvalue): Gvalue = gaugeForce(c, x) + ckforce(act, force, gg, 10.0*gm) + + test "wilson force gradient": + let beta = 5.4 + let c = actWilson(beta) + proc force(x: Gvalue): Gvalue = gaugeForce(c, x) + ckgradm(force, gg, gu, gm) + + test "wilson force gradient recomp": + let beta = 5.4 + let c = actWilson(beta) + let a = gaugeAction(c, gg) + let f2 = gaugeForce(c, gg).norm2 + let df2 = grad(f2, gg).norm2 + let rs1 = [a.eval.getfloat, f2.eval.getfloat, df2.eval.getfloat] + c.updated + gg.updated + let rs2 = [a.eval.getfloat, f2.eval.getfloat, df2.eval.getfloat] + c.updated + gg.updated + let rs3 = [a.eval.getfloat, f2.eval.getfloat, df2.eval.getfloat] + check rs1 == rs2 + check rs1 == rs3 + +qexFinalize() diff --git a/src/experimental/graph/tgraph.nim b/src/experimental/graph/tgraph.nim new file mode 100644 index 0000000..e83ce25 --- /dev/null +++ b/src/experimental/graph/tgraph.nim @@ -0,0 +1,331 @@ +import math, unittest + +addOutputFormatter(newConsoleOutputFormatter(colorOutput = false)) + +import core, scalar + +template checkeq(ii: tuple[filename:string, line:int, column:int], sa: string, a: float, sb: string, b: float) = + if not almostEqual(a, b, unitsInLastPlace = 64): + checkpoint(ii.filename & ":" & $ii.line & ":" & $ii.column & ": Check failed: " & sa & " :~ " & sb) + checkpoint(" " & sa & ": " & $a) + checkpoint(" " & sb & ": " & $b) + fail() + +template checkeq(ii: tuple[filename:string, line:int, column:int], sa: string, a: int, sb: string, b: int) = + if a != b: + checkpoint(ii.filename & ":" & $ii.line & ":" & $ii.column & ": Check failed: " & sa & " :~ " & sb) + checkpoint(" " & sa & ": " & $a) + checkpoint(" " & sb & ": " & $b) + fail() + +template `:~`(a:Gvalue, b:float) = + checkeq(instantiationInfo(), astToStr a, a.eval.getfloat, astToStr b, b) + +template `:~`(a:Gvalue, b:int) = + checkeq(instantiationInfo(), astToStr a, a.eval.getint, astToStr b, b) + +suite "scalar basic": + # run once before + setup: + # before each test + let a = 0.5 * (sqrt(5.0) - 1.0) + let b = sqrt(2.0) - 1.0 + let x = toGvalue(a) + let y = toGvalue(b) + #teardown: + # after each test + # run once after + + test "assign": + x :~ a + y :~ b + + test "n": + let z = -x + let dx = z.grad x + z :~ -a + dx :~ -1.0 + + test "a": + let z = x+y + let dx = z.grad x + let dy = z.grad y + z :~ a+b + dx :~ 1.0 + dy :~ 1.0 + + test "m": + let z = x*y + let dx = z.grad x + let dy = z.grad y + z :~ a*b + dx :~ b + dy :~ a + + test "s": + let z = x-y + let dx = z.grad x + let dy = z.grad y + z :~ a-b + dx :~ 1.0 + dy :~ -1.0 + + test "d": + let z = x/y + let dx = z.grad x + let dy = z.grad y + z :~ a/b + dx :~ 1.0/b + dy :~ -a/(b*b) + + test "exp": + let z = exp(x) + let dx = z.grad x + let ddx = dx.grad x + let dddx = ddx.grad x + let e = exp(a) + z :~ e + dx :~ e + ddx :~ e + dddx :~ e + + test "nm": + let z = (-x)*x + let dx = z.grad x + z :~ -a*a + dx :~ -2.0*a + + test "nm exp": + let z = (-exp(x))*exp(x) + let dx = z.grad x + z :~ -exp(2.0*a) + dx :~ -2.0*exp(2.0*a) + + test "am": + let z = (x+y)*x + let dx = z.grad x + let dy = z.grad y + z :~ (a+b)*a + dx :~ 2.0*a+b + dy :~ a + + test "ama": + let w = x + let v = w+y + let z = v*v + let dy = z.grad y + z :~ (a+b)*(a+b) + dy :~ 2.0*(a+b) + + test "amd": + let w = x + let v = w+y + let z = v*v/w + let dy = z.grad y + z :~ (a+b)*(a+b)/a + dy :~ 2.0*(a+b)/a + + test "amnd": + let w = x + let v = w+y + let z = v*(-v)/w + let dy = z.grad y + z :~ (a+b)*(-a-b)/a + dy :~ -2.0*(a+b)/a + + test "samnd": + let w = x-2.0 + let v = w+y + let z = v*(-v)/w + let dy = z.grad y + z :~ (a+b-2.0)*(2.0-a-b)/(a-2.0) + dy :~ -2.0*(a+b-2.0)/(a-2.0) + +suite "scalar d2": + setup: + let a = 0.5 * (sqrt(5.0) - 1.0) + let b = sqrt(2.0) - 1.0 + let c = 2.0 * a - 1.0 + let d = a + 3.0 * b - 1.0 + let x = Gscalar() + let y = Gscalar() + x.update a + y.update b + + test "samnd dx dy": + let w = x-2.0 + let v = w+y + let z = v*(-v)/w + let dy = z.grad y + let dxy = dy.grad x + z :~ (a+b-2.0)*(2.0-a-b)/(a-2.0) + dy :~ -2.0*(a+b-2.0)/(a-2.0) + dxy :~ 2.0*b/((a-2.0)*(a-2.0)) + + test "samnd dx dy repeat": + let w = x-2.0 + let v = w+y + let z = v*(-v)/w + let dy = z.grad y + let dxy = dy.grad x + z :~ (a+b-2.0)*(2.0-a-b)/(a-2.0) + dy :~ -2.0*(a+b-2.0)/(a-2.0) + dxy :~ 2.0*b/((a-2.0)*(a-2.0)) + y.update c + dy :~ -2.0*(a+c-2.0)/(a-2.0) + x.update d + dxy :~ 2.0*c/((d-2.0)*(d-2.0)) + y.update a + z :~ (d+a-2.0)*(2.0-d-a)/(d-2.0) + dy :~ -2.0*(d+a-2.0)/(d-2.0) + dxy :~ 2.0*a/((d-2.0)*(d-2.0)) + + test "samndpdy dx": + let w = x-2.0 + let v = w+y + let z = v*(-v)/w + let dy = z.grad y + let u = z+0.1*dy + let dx = (u*u).grad x + z :~ (a+b-2.0)*(2.0-a-b)/(a-2.0) + dx :~ -2.0*(b+a-2.0)*(5.0*b+5.0*a-9.0)*(5.0*b*b+b-5.0*a*a+20.0*a-20.0)/(25.0*(a-2.0)*(a-2.0)*(a-2.0)) + y.update c + dx :~ -2.0*(c+a-2.0)*(5.0*c+5.0*a-9.0)*(5.0*c*c+c-5.0*a*a+20.0*a-20.0)/(25.0*(a-2.0)*(a-2.0)*(a-2.0)) + x.update d + y.update a + u :~ (d+a-2.0)*(2.0-d-a)/(d-2.0) - 0.1*2.0*(d+a-2.0)/(d-2.0) + dx :~ -2.0*(a+d-2.0)*(5.0*a+5.0*d-9.0)*(5.0*a*a+a-5.0*d*d+20.0*d-20.0)/(25.0*(d-2.0)*(d-2.0)*(d-2.0)) + +suite "bool and cond": + setup: + let a = 0.5 * (sqrt(5.0) - 1.0) + let b = sqrt(2.0) - 1.0 + let c = 2.0 * a - 1.0 + let d = a + 3.0 * b - 1.0 + let x = toGvalue a + let y = toGvalue b + + test "not": + let f = toGvalue 0 + not(f) :~ 1 + not(not f) :~ 0 + let t = toGvalue 1.0 + not(t) :~ 0.0 + not(not t) :~ 1.0 + + test "and": + let fi = toGvalue 0 + let ti = toGvalue 1 + let t = toGvalue 1.0 + let f = toGvalue 0.0 + fi and t :~ 0.0 + t and fi :~ 0 + ti and t :~ 1.0 + t and ti :~ 1 + f and fi :~ 0 + fi and f :~ 0.0 + + test "or": + let fi = toGvalue 0 + let ti = toGvalue 1 + let t = toGvalue 1.0 + let f = toGvalue 0.0 + fi or t :~ 1.0 + t or fi :~ 1 + ti or t :~ 1.0 + t or ti :~ 1 + f or fi :~ 0 + fi or f :~ 0.0 + + test "xor": + let fi = toGvalue 0 + let ti = toGvalue 1 + let t = toGvalue 1.0 + let f = toGvalue 0.0 + fi xor t :~ 1.0 + t xor fi :~ 1 + ti xor t :~ 0.0 + t xor ti :~ 0 + f xor fi :~ 0 + fi xor f :~ 0.0 + + test "condi": + let k = toGvalue 0 + let z = cond(k, x, y) + let dx = z.grad x + let dy = z.grad y + z :~ b + dx :~ 0.0 + dy :~ 1.0 + k.update 1 + z :~ a + dx :~ 1.0 + dy :~ 0.0 + + test "conds": + let k = toGvalue 1.0 + let z = cond(k, x, y) + let dx = z.grad x + let dy = z.grad y + z :~ a + dx :~ 1.0 + dy :~ 0.0 + k.update 0.0 + z :~ b + dx :~ 0.0 + dy :~ 1.0 + + test "condi 2": + let k = toGvalue 0 + let z = cond(k, x, y) + let z2 = z*z + let dx = z2.grad x + let dy = z2.grad y + z2 :~ b*b + dx :~ 0.0 + dy :~ 2.0*b + k.update 1 + y.update c + z2 :~ a*a + dx :~ 2.0*a + dy :~ 0.0 + k.update 0 + z2 :~ c*c + dx :~ 0.0 + dy :~ 2.0*c + + test "conds 2": + let k = toGvalue 1.0 + let z = cond(k, x, y) + let z2 = z*z + let dx = z2.grad x + let dy = z2.grad y + z2 :~ a*a + dx :~ 2.0*a + dy :~ 0.0 + k.update 0.0 + x.update d + z2 :~ b*b + dx :~ 0.0 + dy :~ 2.0*b + k.update 1.0 + x.update c + z2 :~ c*c + dx :~ 2.0*c + dy :~ 0.0 + + test "cond eval shortcut": + let t = toGvalue 2.0 + let f = toGvalue 0.0 + let t2 = t*t + let t3 = t*t*t + check t2.getfloat == 0.0 # should be zero before eval + check t3.getfloat == 0.0 # ditto + var tt = cond(t, t2, t3) + tt :~ 4.0 + check t2.getfloat == 4.0 + check t3.getfloat == 0.0 # should remain zero after eval + tt = cond(f, t3, t2) + tt :~ 4.0 + check t2.getfloat == 4.0 + check t3.getfloat == 0.0 # should remain zero after eval diff --git a/src/experimental/graph/tgstout.nim b/src/experimental/graph/tgstout.nim new file mode 100644 index 0000000..2a35d9d --- /dev/null +++ b/src/experimental/graph/tgstout.nim @@ -0,0 +1,98 @@ +import qex, algorithms/numdiff, gauge/stoutsmear +import core, scalar, gauge + +qexInit() + +letParam: + lat = @[12,12,12,24] + dt = 0.1 + eps = 0.004 + nstep = 3 + beta = 5.4 + seed:uint = 1234567891 + +echoParams() +echo "rank ", myRank, "/", nRanks +threads: echo "thread ", threadNum, "/", numThreads + +let + lo = lat.newLayout + vol = lo.physVol + gc = GaugeActionCoeffs(plaq: beta) + g = lo.newgauge + u = lo.newgauge + +var + r = lo.newRNGField(RngMilc6, seed) + ss = lo.newStoutSmear(dt) + +for i in 0..3: + ss.smear(g, g) + +g.random r +g.echoPlaq +for i in 1..nstep: + if i==1: + ss.smear(g, u) + else: + ss.smear(u, u) +u.echoPlaq +let sgs = gc.gaugeAction1 u +echo "smear S: ",sgs + +proc act(t: float): float = + var ss = lo.newStoutSmear(t) + for i in 1..nstep: + if i==1: + ss.smear(g, u) + else: + ss.smear(u, u) + gc.gaugeAction1 u + +var ndt, err: float +ndiff(ndt, err, act, dt, eps, ordMax=3) +echo "numdiff smear dS/dt: ",ndt," +/- ",err + +proc stout(g, t: Gvalue, n: int): Gvalue = + var g = g + for i in 1..n: + g = axexpmuly(t, gaugeForce(actWilson(-3.0), g), g) + g + +let + gdt = toGvalue dt + gg = toGvalue g + gs = gg.stout(gdt, nstep) + s = gc.gaugeAction gs + ddt = s.grad gdt + +# echo ddt.treeRepr + +gs.eval.getgauge.echoPlaq +let sgg = s.eval.getfloat +echo "graph S: ",sgg + +let gddt = ddt.eval.getfloat +echo "graph dS/dt: ",gddt +# echo ddt.treeRepr + +proc gact(t: float): float = + gdt.update t + s.eval.getfloat +var gndt, gerr: float +ndiff(gndt, gerr, gact, dt, eps, ordMax=4) +echo "numdiff graph dS/dt: ",gndt," +/- ",gerr + +let + rds = abs((sgs-sgg)/(sgs+sgg)) + rgdt = abs((ndt-gddt)/(ndt+gddt)) + rndt = abs((ndt-gndt)/(ndt+gndt)) +echo "rel dS: ",rds +echo "rel graph dS/dt: ",rgdt +echo "rel ndiff dS/dt: ",rndt + +doassert rds < 1e-11 +doassert rgdt < 1e-11 +doassert rndt < 1e-11 + +qexFinalize() diff --git a/src/experimental/stagag.nim b/src/experimental/stagag.nim index 982ee36..e29f934 100644 --- a/src/experimental/stagag.nim +++ b/src/experimental/stagag.nim @@ -1,11 +1,15 @@ import qex, gauge, physics/[qcdTypes,stagSolve] -import times, macros, algorithm +import times, macros, algorithm, sequtils import hmc/metropolis import observables/sources import quda/qudaWrapper import hmc/agradOps -qexinit() +qexInit() +echo "rank ", myRank, "/", nRanks +threads: + echo "thread ", threadNum, "/", numThreads + proc `:=`*(r: var seq, x: seq) = for i in 0..=0 and i 0 or - p grad(hn) var pg = 0.0 if m.hNew > m.hOld: - pg = - m.pAccept * params[i].grad + pg = - pm * params[i].grad paccg[i].push pg #pg = paccg[i].mean - var costg = (if i==0: 2.0*ct*pm else: 0.0) + var costg = (if i==0: 2.0*nsteps*ct*pm else: 0.0) costg = costg + ct * ct * pg let d = m.hNew - m.hOld - let alp = alpha - costg += alp*d*(d*pg + 2*(pm-1)*params[i].grad) - #costg = costg/nff + #let alp = alpha + #costg += alp*d*(d*pg + 2*(pm-1)*params[i].grad) + costg = costg/fc - (ct*ct*pm/(fc*fc))*forceCostGrad(i) #costg = nff*costg*(cost*cost) # extra - to make it minimize + if i==0: + if m.deltaH > 10: + costg = -0.1 if fixtau and i==0: costg = 0 - if fixparams and i>0: costg = 0 + if fixhmasses and i>0 and i<=hmasses.len: costg = 0 + if fixparams and i>hmasses.len: costg = 0 #if i > 0: # costg = (m.hOld-m.hNew)*params[i].grad result.add costg @@ -1267,16 +1750,24 @@ proc checkGrad(m: Met) = gx := g #let tg = vtau.grad #let f0 = vtau.obj * exp(m.hOld-m.hNew) + #let h0 = m.hOld let eps = 1e-6 var gs = newSeq[float](0) for i in 0..0 and i<=hmasses.len: + rate = ratefac * lrateh #let m = cgstat[i].mean #let d = s*g[i] let d = rate*g[i] @@ -1417,6 +1915,24 @@ block: src.wallSource(0, v) #echo src.norm2slice(3) +var gfStats: RunningStat +var ffStats = newSeq[RunningStat](1+hmasses.len) + +proc resetMeasure = + gfStats.clear + for i in 0.. 0: - echo "Starting warmups" - #setupMDx() - alwaysAccept = warmmd - for n in 1..nwarm: - m.update - m.clearStats - pacc.clear +block: + tic("warmup") + if nwarm > 0: + echo "Starting warmups" + #setupMDx() + alwaysAccept = warmmd + for n in 1..nwarm: + m.update + m.clearStats + pacc.clear + toc("end warmup") echo "Starting HMC" #setupMD5() @@ -1493,32 +2044,42 @@ alwaysAccept = false #gutime = 0.0 #gftime = 0.0 #fftime = 0.0 +for i in 0.. 0: if n mod upit == 0: - updateParams(sqrt(float upit)*lrate) + updateParams(sqrt(float upit)) lrate *= anneal - let tup = getElapsedTime() - measure() + lrateh *= anneal let ttot = getElapsedTime() echo "End trajectory update: ", tup, " measure: ", ttot-tup, " total: ", ttot let et = getElapsedTime() - toc() + toc("end training") echo "HMC time: ", et #let at = gutime + gftime + fftime #echo &"gu: {gutime} gf: {gftime} ff: {fftime} ot: {et-at} tt: {et}" +resetMeasure() +for i in 0.. 0: m.clearStats pacc.clear - tic() + tic("inference") for n in 1..trajs: echo "Starting inference: ", n echoParams() @@ -1529,14 +2090,21 @@ if trajs > 0: #echo "cost: ", nff/(vtau.obj*vtau.obj*m.avgPAccept) echo "cost: ", getCost0(m) let tup = getElapsedTime() - echo "End inference: ", tup + for i in 0..1: - result = "#" & $id & "=(" - g.refs = -id - inc id - else: - result = "(" - result &= g.str - for x in g.args: - result &= " " & x.go - result &= ")" - else: - result = g.str - result = g.go - -proc newVar[T](x:T, s="$V"):GraphNode[T] = - GraphNode[T](val:GraphValue[T](v:x), str:s, - initGrad: (proc(g:Graph) = g.grad = GraphValue[T](v:1.T))) - -proc newConst[T](x:T, s="$C"):GraphNode[T] = - GraphNode[T](tag: {gtConst}, val:GraphValue[T](v:x), str:s & "|" & $x & "|", - initGrad: (proc(g:Graph) = g.grad = GraphValue[T](v:1.T))) - -proc wasUpdated*(g:Graph) = - g.tag.excl gtRun - -proc isValid*(g:Graph):bool = - result = true - if g.isop: - for x in g.args: - result &= x.isValid - if result == false: - g.excl gtRun - else: - result = gtRun in g.tag - -proc eval*(g:Graph) = - if g.isop: - if gtRun in g.tags: - if not g.isValid: - eval g - else: - g.tag.incl gtRun - for x in g.args: - x.eval - g.run(g) - else: - g.tag.incl gtRun - -proc clearGrad(g:Graph) = - g.grad = nil - g.tag.excl gtDF - g.tag.excl gtGrad - if g.isop: - for x in g.args: - x.clearGrad - -proc evalGrad(g:Graph) = - if g.isop and gtRun notin g.tag: - g.eval - if gtDF notin g.tag: - g.clearGrad - g.countRefs - g.tag.incl gtDF - g.initGrad(g) - proc go(g:Graph) = - g.tag.incl gtGrad - if g.isop: - g.refs.dec - if g.refs>0: - return - # Wait until the last reference of the shared nodes. - g.back(g) - for x in g.args: - x.go - g.go - -proc evalGrad[G,X](g:GraphNode[G], x:GraphNode[X]):X = - # TODO only descend in to nodes that contains x. - g.Graph.evalGrad - proc go(g:Graph):Graph = - if g == x: - return x - elif g.isop: - for c in g.args: - if c.go == x: - return x - g - if g.go == x: - GraphValue[X](x.grad).v - else: - X 0 - -proc eval[T](g:GraphNode[T]):T = - g.Graph.eval - GraphValue[T](g.val).v - -proc `+`[X,Y](x:GraphNode[X], y:GraphNode[Y]):auto = - type R = type(GraphValue[X](x.val).v+GraphValue[Y](y.val).v) - GraphNode[R](isop:true, str:"+", args: @[x.Graph,y], - run: (proc(g:Graph) = - echo "# Run: ",g.args[0].str," + ",g.args[1].str - let v = GraphValue[R](v:GraphValue[X](g.args[0].val).v+GraphValue[Y](g.args[1].val).v) - g.val = v - g.str &= "(=" & $v.v & ")"), - initGrad: (proc(g:Graph) = g.grad = GraphValue[R](v:1.R)), - back: (proc(g:Graph) = - echo "# Back: ",x.str," * ",y.str - if gtConst notin g.args[0].tag: - let t = GraphValue[R](g.grad).v.X - if g.args[0].grad != nil: - g.args[0].grad = GraphValue[X](v:GraphValue[X](g.args[0].grad).v+t) - else: - g.args[0].grad = GraphValue[X](v:t) - if gtConst notin g.args[1].tag: - let t = GraphValue[R](g.grad).v.Y - if g.args[1].grad != nil: - g.args[1].grad = GraphValue[Y](v:GraphValue[Y](g.args[1].grad).v+t) - else: - g.args[1].grad = GraphValue[Y](v:t))) -proc `+`[X](x:GraphNode[X], y:SomeNumber):auto = x + newConst(y) -proc `+`[Y](x:SomeNumber, y:GraphNode[Y]):auto = newConst(x) + y - -proc `*`[X,Y](x:GraphNode[X], y:GraphNode[Y]):auto = - type R = type(GraphValue[X](x.val).v*GraphValue[Y](y.val).v) - GraphNode[R](isop:true, str:"*", args: @[x.Graph,y], - run: (proc(g:Graph) = - echo "# Run: ",x.str," * ",y.str - let v = GraphValue[R](v:GraphValue[X](g.args[0].val).v*GraphValue[Y](g.args[1].val).v) - g.val = v - g.str &= "(=" & $v.v & ")"), - initGrad: (proc(g:Graph) = g.grad = GraphValue[R](v:1.R)), - back: (proc(g:Graph) = - echo "# Back: ",x.str," * ",y.str - if gtConst notin g.args[0].tag: - let t = GraphValue[R](g.grad).v*GraphValue[Y](g.args[1].val).v - if g.args[0].grad != nil: - g.args[0].grad = GraphValue[X](v:GraphValue[X](g.args[0].grad).v+t) - else: - g.args[0].grad = GraphValue[X](v:t) - if gtConst notin g.args[1].tag: - let t = GraphValue[X](g.args[0].val).v*GraphValue[R](g.grad).v - if g.args[1].grad != nil: - g.args[1].grad = GraphValue[X](v:GraphValue[X](g.args[1].grad).v+t) - else: - g.args[1].grad = GraphValue[Y](v:t))) -proc `*`[X](x:GraphNode[X], y:SomeNumber):auto = x * newConst(y) -proc `*`[Y](x:SomeNumber, y:GraphNode[Y]):auto = newConst(x) * y - -when isMainModule: - let - x = newVar(2.0, "x") - y = newVar(3.0, "y") - z = x*(y+5.0) - t = x*y*(z+1.0)*z - echo "x: ",x - echo "y: ",y - echo "z: ",z - echo "t: ",t - let rt = t.eval - echo "t: ",t - echo "rt: ",rt - echo "rz: ",z.eval - echo "dtdz: ",t.evalGrad z - echo "dtdx: ",t.evalGrad x - echo "dtdy: ",t.evalGrad y - echo "dzdx: ",z.evalGrad x - echo "dzdy: ",z.evalGrad y - let u = (x+t)*(t+z) - echo "u: ",u - echo "dudy: ",u.evalGrad y diff --git a/src/hmc/agradOps.nim b/src/hmc/agradOps.nim index 0f313be..7a46688 100644 --- a/src/hmc/agradOps.nim +++ b/src/hmc/agradOps.nim @@ -8,6 +8,8 @@ type #GaugeFV* = AgVar[GaugeF] GaugeF*[V:static int,T] = seq[Field[V,T]] GaugeFV*[V:static int,T] = AgVar[GaugeF[V,T]] + FieldV*[V:static int,T] = AgVar[Field[V,T]] + template newFloatV*(c: AgTape, x = 0.0): auto = var t = FloatV.new() t.doGrad = true @@ -51,6 +53,26 @@ proc peqmul[G:GaugeF](r: G, x: float, y: G) = for s in r[mu]: r[mu][s] += x * y[mu][s] +proc assigngrad(c: float, r: var float) = + r += c +proc assignfwd[I,O](op: AgOp[I,O]) {.nimcall.} = + #mixin peq + op.outputs.obj := op.inputs.maybeObj + when op.inputs is AgVar: + zero op.inputs.grad +proc assignbck[I,O](op: AgOp[I,O]) {.nimcall.} = + #mixin peq + when op.inputs is AgVar: + if op.inputs.doGrad: + assigngrad(op.outputs.grad, op.inputs.grad) +proc assign(c: var AgTape, r: AgVar, x: auto) = + var op = newAgOp(x, r, assignfwd, assignbck) + c.add op +template assign*(r: AgVar, x: auto) = + r.ctx.assign(r, x) +template `:=`*(r: AgVar, x: auto) = + r.ctx.assign(r, x) + proc addgrad1(c: float, r: var float, y: float) = r += c proc addgrad2(c: float, x: float, r: var float) = @@ -123,6 +145,25 @@ proc mulgrad1(c: float, r: var float, y: float) = proc mulgrad2(c: float, x: float, r: var float) = r += x * c +proc mul[F:Field](r: F, x: float, y: F) = + threads: + r := x * y +proc mulgrad1[F:Field](c: F, r: var float, y: F) = + var rr = 0.0 + threads: + var l: typeof(redot(y[0][], c[0][])) + for s in c: + l += redot(y[s][], c[s][]) + var m = simdReduce l + threadRankSum m + threadSingle: + rr = m + r += rr +proc mulgrad2[F:Field](c: F, x: float, r: F) = + threads: + for s in c: + r[s][] += x * c[s][] + proc mul[G:GaugeF](r: G, x: float, y: G) = threads: for mu in 0..sbQex>op") threadBarrier() if parEven: stagD2ee(s.se, s.so, a, s.g, b, m*m) @@ -55,12 +74,16 @@ proc solveXX*(s: Staggered; r,x: Field; m: SomeNumber; sp0: var SolverParams; stagD2oo(s.se, s.so, a, s.g, b, m*m) toc("stagD2oo") #threadBarrier() - var oa = (apply: op) var cg = newCgState(r, x) if parEven: sp.subset.layoutSubset(r.l, "even") else: sp.subset.layoutSubset(r.l, "odd") + #if precon: + #var oap = (apply: op, applyPrecon: oppre) + #cg.solve(oap, sp) + #else: + var oa = (apply: op, precon: cpNone) cg.solve(oa, sp) toc("cg.solve") sp.calls = 1 @@ -72,9 +95,11 @@ proc solveXX*(s: Staggered; r,x: Field; m: SomeNumber; sp0: var SolverParams; echo "solveEE(QEX): ", sp.getStats else: echo "solveOO(QEX): ", sp.getStats + toc("end sbQex") of sbQuda: - tic() + tic("sbQuda") if parEven: + #echo x.even.norm2, " ", sp.r2req s.qudaSolveEE(r,x,m,sp) toc("qudaSolveEE") else: @@ -87,7 +112,7 @@ proc solveXX*(s: Staggered; r,x: Field; m: SomeNumber; sp0: var SolverParams; if sp0.verbosity>0: echo "solveXX(QUDA): ", sp.getStats of sbGrid: - tic() + tic("sbQuda") if parEven: s.gridSolveEE(r,x,m,sp) toc("gridSolveEE") @@ -103,6 +128,7 @@ proc solveXX*(s: Staggered; r,x: Field; m: SomeNumber; sp0: var SolverParams; sp.iterationsMax = sp.iterations sp.r2.push 0.0 sp0.addStats(sp) + toc("end solveXX") proc solveEE*(s: Staggered; r,x: Field; m: SomeNumber; sp0: var SolverParams) = solveXX(s, r, x, m, sp0, parEven=true) @@ -113,7 +139,7 @@ proc solveOO*(s: Staggered; r,x: Field; m: SomeNumber; sp0: var SolverParams) = # right-preconditioned proc solveReconR(s:Staggered; x,b:Field; m:SomeNumber; sp: var SolverParams; b2e,b2o: float) = - tic() + tic("solveReconR") let b2 = b2e + b2o let r2stop = sp.r2req * b2 let r2stop2 = 0.5 * r2stop @@ -151,7 +177,7 @@ proc solveReconR(s:Staggered; x,b:Field; m:SomeNumber; sp: var SolverParams; # left-preconditioned with odd reconstruction proc solveReconL(s:Staggered; x,b:Field; m:SomeNumber; sp: var SolverParams; b2e,b2o: float) = - tic() + tic("solveReconL") #if b2e == 0.0 or b2o == 0.0: #solveR(s, y, r, m, sp, r2e, r2o) var d = newOneOf(b) @@ -172,6 +198,7 @@ proc solveReconL(s:Staggered; x,b:Field; m:SomeNumber; sp: var SolverParams; toc("setup") s.solveEE(x, d, m, sp) toc("solveEE") + #echo "solveReconL ", d2e, " ", sp.r2req threads: x.even *= 4 threadBarrier() @@ -241,6 +268,7 @@ proc solve*(s:Staggered; x,b:Field; m:SomeNumber; sp0: var SolverParams) = s.D(r, x, m) threadBarrier() r := b - r + threadBarrier() let r2et = r.even.norm2 r2ot = r.odd.norm2 @@ -249,7 +277,7 @@ proc solve*(s:Staggered; x,b:Field; m:SomeNumber; sp0: var SolverParams) = r2o = r2ot r2 = r2e + r2o if sp.verbosity>0: - echo "stagSolve r2: ", r2/b2 + echo "stagSolve r2/b2: ", r2/b2 sp.r2.init r2/b2 sp.calls = 1 @@ -264,6 +292,23 @@ proc solve*(s:Staggered; x,b:Field; m:SomeNumber; sp0: var SolverParams) = echo "stagSolve: ", sp.getStats sp0.addStats(sp) +# multimass (trivial version with multiple single mass calls for now) +proc solve*(s: Staggered; r: seq[Field]; x: Field; m: seq[float]; + sp: seq[SolverParams]) = + let n = m.len + doAssert(r.len == n) + doAssert(sp.len == n) + for i in 0.. 1: +# packp(r, x, l, 1) +template packp1*(r: var openArray[SomeNumber], x: SimdArrayObj, l: var openArray[SomeNumber]) = when numNumbers(x) > 1: packp(r, x, l, 1) -template packp2*[R,L:openArray[SomeNumber]](r: var R, x: SimdArrayObj, l: var L) = +#template packp2*[R,L:openArray[SomeNumber]](r: var R, x: SimdArrayObj, l: var L) = +# when numNumbers(x) > 2: +# packp(r, x, l, 2) +template packp2*(r: var openArray[SomeNumber], x: SimdArrayObj, l: var openArray[SomeNumber]) = when numNumbers(x) > 2: packp(r, x, l, 2) -template packp4*[R,L:openArray[SomeNumber]](r: var R, x: SimdArrayObj, l: var L) = +#template packp4*[R,L:openArray[SomeNumber]](r: var R, x: SimdArrayObj, l: var L) = +# when numNumbers(x) > 4: +# packp(r, x, l, 4) +template packp4*(r: var openArray[SomeNumber], x: SimdArrayObj, l: var openArray[SomeNumber]) = when numNumbers(x) > 4: packp(r, x, l, 4) -template packp8*[R,L:openArray[SomeNumber]](r: var R, x: SimdArrayObj, l: var L) = +#template packp8*[R,L:openArray[SomeNumber]](r: var R, x: SimdArrayObj, l: var L) = +# when numNumbers(x) > 8: +# packp(r, x, l, 8) +template packp8*(r: var openArray[SomeNumber], x: SimdArrayObj, l: var openArray[SomeNumber]) = when numNumbers(x) > 8: packp(r, x, l, 8) @@ -301,16 +313,28 @@ proc packm*(r: var openArray[SomeNumber], x: SimdArrayObj, else: assign(la[][il], x[][i]) inc il -template packm1*[R,L:openArray[SomeNumber]](r: var R, x: SimdArrayObj, l: var L) = +#template packm1*[R,L:openArray[SomeNumber]](r: var R, x: SimdArrayObj, l: var L) = +# when numNumbers(x) > 1: +# packm(r, x, l, 1) +template packm1*(r: var openArray[SomeNumber], x: SimdArrayObj, l: var openArray[SomeNumber]) = when numNumbers(x) > 1: packm(r, x, l, 1) -template packm2*[R,L:openArray[SomeNumber]](r: var R, x: SimdArrayObj, l: var L) = +#template packm2*[R,L:openArray[SomeNumber]](r: var R, x: SimdArrayObj, l: var L) = +# when numNumbers(x) > 2: +# packm(r, x, l, 2) +template packm2*(r: var openArray[SomeNumber], x: SimdArrayObj, l: var openArray[SomeNumber]) = when numNumbers(x) > 2: packm(r, x, l, 2) -template packm4*[R,L:openArray[SomeNumber]](r: var R, x: SimdArrayObj, l: var L) = +#template packm4*[R,L:openArray[SomeNumber]](r: var R, x: SimdArrayObj, l: var L) = +# when numNumbers(x) > 4: +# packm(r, x, l, 4) +template packm4*(r: var openArray[SomeNumber], x: SimdArrayObj, l: var openArray[SomeNumber]) = when numNumbers(x) > 4: packm(r, x, l, 4) -template packm8*[R,L:openArray[SomeNumber]](r: var R, x: SimdArrayObj, l: var L) = +#template packm8*[R,L:openArray[SomeNumber]](r: var R, x: SimdArrayObj, l: var L) = +# when numNumbers(x) > 8: +# packm(r, x, l, 8) +template packm8*(r: var openArray[SomeNumber], x: SimdArrayObj, l: var openArray[SomeNumber]) = when numNumbers(x) > 8: packm(r, x, l, 8) @@ -800,6 +824,9 @@ template makeSimdArray2*(L:typed;B,F:typedesc;N0,N:typed,T:untyped) {.dirty.} = template imadd*(r:var T; x:SomeNumber; y:T) = imadd(r, x.to(type(T)), y) template imsub*(r:var T; x:SomeNumber; y:T) = imsub(r, x.to(type(T)), y) template divd*(r:var T; x:SomeNumber; y:T) = divd(r, x.to(type(T)), y) + template imadd*(r:var T; x:T; y:SomeNumber) = imadd(r, x, y.to(type(T))) + template imsub*(r:var T; x:T; y:SomeNumber) = imsub(r, x, y.to(type(T))) + template divd*(r:var T; x:T; y:SomeNumber) = divd(r, x, y.to(type(T))) template imul*(r:var T; x:SomeNumber) = imul(r, x.to(type(T))) template idiv*(r:var T; x:SomeNumber) = idiv(r, x.to(type(T))) template msub*(r:var T; x:SomeNumber; y,z:T) = msub(r, x.to(type(T)), y, z) diff --git a/src/solvers/cg.nim b/src/solvers/cg.nim index 6fddca6..8f9234d 100644 --- a/src/solvers/cg.nim +++ b/src/solvers/cg.nim @@ -5,57 +5,73 @@ import solverBase export solverBase type + CgPrecon* = enum + cpNone, + cpHerm, + cpLeftRight, + #cpRightNonHerm # x = Ry, r = b-ARy, min r'A^-1r, p'r = 0 CgState*[T] = object - r,Ap,b: T - p,x,z: T - b2,r2,r2old,r2stop,rz,rzold: float + r*,Ap,b*: T + p,x*: T + z,q,LAp: T + b2,r2,r2old,r2stop,rz,rzold,alpha: float iterations: int - precon: bool + precon: CgPrecon proc reset*(cgs: var CgState) = cgs.b2 = -1 cgs.iterations = 0 + cgs.r2 = 1.0 cgs.r2old = 1.0 cgs.rzold = 1.0 cgs.r2stop = 0.0 -proc newCgState*[T](x,b: T; precon=false): CgState[T] = +proc initPrecon*(state: var CgState) = + case state.precon + of cpNone: + state.z = state.r + state.q = state.p + state.LAp = state.Ap + of cpHerm: + state.z = newOneof state.r + state.q = state.p + state.LAp = state.Ap + of cpLeftRight: + state.z = newOneof state.r + state.q = newOneof state.p + state.LAp = newOneOf state.z + +proc newCgState*[T](x,b: T): CgState[T] = result.r = newOneOf b result.Ap = newOneOf b result.b = b result.p = newOneOf x result.x = x - result.precon = precon - if precon: - result.z = newOneof b - else: - result.z = result.r + result.precon = cpNone + result.initPrecon result.reset # solves: A x = b proc solve*(state: var CgState; op: auto; sp: var SolverParams) = mixin apply, applyPrecon - tic() + tic("solve") let vrb = sp.verbosity - template verb(n:int; body:untyped):untyped = + template verb(n:int; body:untyped) = if vrb>=n: body let sub = sp.subset - template subset(body:untyped):untyped = + template subset(body:untyped) = onNoSync(sub): body - template mythreads(body:untyped):untyped = + template mythreads(body:untyped) = threads: onNoSync(sub): body - const precon = compiles(op.applyPrecon(state.z, state.r)) + let precon = op.precon if precon != state.precon: state.precon = precon - if precon: - state.z = newOneOf state.r - else: - state.z = state.r - + state.initPrecon + state.reset let r = state.r p = state.p @@ -63,10 +79,55 @@ proc solve*(state: var CgState; op: auto; sp: var SolverParams) = x = state.x b = state.b z = state.z + q = state.q + LAp = state.LAp var b2 = state.b2 r2 = state.r2 rz = state.rz + #qLAp = state.qLAp + + if precon == cpHerm: + when not compiles(op.applyPrecon(z, r)): + qexError("cg.solve: precon == cpHerm but op.applyPrecon not found") + if precon == cpLeftRight: + when not compiles(op.applyPreconL(z, r)): + qexError("cg.solve: precon == cpLeftRight but op.applyPreconL not found") + when not compiles(op.applyPreconR(p, q)): + qexError("cg.solve: precon == cpLeftRight but op.applyPreconR not found") + + template getRz = + case precon + of cpNone: + rz = r2 + of cpHerm: + subset: + rz = r.redot z + of cpLeftRight: + subset: + rz = z.norm2 # convenient to use rz here for z2 + #of cpRightNonHerm: + # subset: + # rz = Ap.dot z # convenient to use rz here + template preconL(z, r) = + case precon + of cpNone: + discard + of cpHerm: + when compiles(op.applyPrecon(z, r)): + op.applyPrecon(z, r) + of cpLeftRight: + when compiles(op.applyPreconL(z, r)): + op.applyPreconL(z, r) + template preconR(p, q) = + case precon + of cpNone: + discard + of cpHerm: + discard + of cpLeftRight: + when compiles(op.applyPreconR(p, q)): + op.applyPreconR(p, q) if b2<0: # first call mythreads: @@ -78,8 +139,6 @@ proc solve*(state: var CgState; op: auto; sp: var SolverParams) = mythreads: x := 0 r := 0 - if precon: - z := 0 r2 = 0.0 rz = 0.0 else: @@ -99,40 +158,54 @@ proc solve*(state: var CgState; op: auto; sp: var SolverParams) = var itn0 = state.iterations var r2o0 = state.r2old var rzo0 = state.rzold - + var alpha0 = state.alpha toc("cg setup") + if r2 > r2stop: threads: var itn = itn0 var r2o = r2o0 var rzo = rzo0 + var alpha = alpha0 + var qlap = 0.0 verb(1): - #echo(-1, " ", r2) - echo(itn, " ", r2/b2) + echo("CG iteration: ", itn, " r2/b2: ", r2/b2) while itnr2stop: - tic() - when precon: - op.applyPrecon(z, r) - subset: - rz = r.redot z + tic("cg loop") + if itn == 0 or precon != cpLeftRight: + preconL(z, r) # z = L r or z = R r for RightNonHerm else: - rz = r2 - let beta = rz/rzo + subset: + z -= alpha * LAp + getRz() # r.z or z.z or Ap.z + var beta = 0.0 + #if precon == cpRightNonHerm: + # beta = -rz / qLAp + #else: + beta = rz/rzo r2o = r2 rzo = rz subset: - p := z + beta*p - toc("p update", flops=2*numNumbers(r[0])*sub.lenOuter) + if itn == 0: + q := z + else: + q := z + beta*q + toc("q update", flops=2*numNumbers(q[0])*sub.lenOuter) verb(3): echo "beta: ", beta + preconR(p, q) # p = R q + toc("preconR") inc itn op.apply(Ap, p) toc("Ap") + if precon == cpLeftRight: + preconL(LAp, Ap) # LAp = L Ap subset: - let pAp = p.redot(Ap) - toc("pAp", flops=2*numNumbers(p[0])*sub.lenOuter) - let alpha = rz/pAp + #let pAp = p.redot(Ap) + qLAp = q.redot(LAp) + toc("qLAp", flops=2*numNumbers(p[0])*sub.lenOuter) + alpha = rz/qLAp x += alpha*p toc("x", flops=2*numNumbers(p[0])*sub.lenOuter) r -= alpha*Ap @@ -141,43 +214,54 @@ proc solve*(state: var CgState; op: auto; sp: var SolverParams) = toc("r2", flops=2*numNumbers(r[0])*sub.lenOuter) verb(2): #echo(itn, " ", r2) - echo(itn, " ", r2/b2) + echo("CG iteration: ", itn, " r2/b2: ", r2/b2) verb(3): subset: - let pAp = p.redot(Ap) - echo "p2: ", p.norm2 - echo "Ap2: ", Ap.norm2 - echo "pAp: ", pAp - echo "alpha: ", r2o/pAp + #qLAp = q.redot(LAp) + #echo "p2: ", p.norm2 + #echo "Ap2: ", Ap.norm2 + echo "rz: ", rz + echo "qLAp: ", qLAp + echo "alpha: ", alpha echo "x2: ", x.norm2 echo "r2: ", r2 + echo "z2: ", z.norm2 + echo "q2: ", q.norm2 + echo "Ap2: ", Ap.norm2 + echo "LAp2: ", LAp.norm2 op.apply(Ap, x) var fr2: float subset: - fr2 = (b - Ap).norm2 - echo " ", fr2, " ", fr2/b2 + threadBarrier() + Ap -= b + threadBarrier() + fr2 = Ap.norm2 + echo "fr2: ", fr2, " fr2/b2: ", fr2/b2 if itn mod 64 == 0: aggregateTimers() toc("cg iterations") if threadNum==0: itn0 = itn r2o0 = r2o rzo0 = rzo + alpha0 = alpha #var fr2: float #op.apply(Ap, x) #subset: # r := b - Ap # fr2 = r.norm2 #verb(1): - # echo iterations, " acc r2:", r2/b2 - # echo iterations, " tru r2:", fr2/b2 + # echo iterations, " acc r2: ", r2/b2 + # echo iterations, " tru r2: ", fr2/b2 state.iterations = itn0 state.r2old = r2o0 state.r2 = r2 state.rzold = rzo0 state.rz = rz + state.alpha = alpha0 + #state.qLAp = qLAp verb(1): - echo state.iterations, " acc r2:", r2/b2 + echo "CG final iterations: ", state.iterations, " r2/b2: ", r2/b2 #threads: # op.apply(Ap, x) # var fr2: float @@ -213,21 +297,36 @@ when isMainModule: type opArgs = object m: type(m) - var oa = opArgs(m: m) + precon: CgPrecon + var oa = opArgs(m: m, precon: cpNone) proc apply*(oa: opArgs; r: type(v1); x: type(v1)) = r := oa.m*x #mul(r, m, x) + type opArgsP = object m: type(m) - var oap = opArgsP(m: m) + precon: CgPrecon + var oap = opArgsP(m: m, precon: cpHerm) proc apply*(oa: opArgsP; r: type(v1); x: type(v1)) = r := oa.m*x #mul(r, m, x) proc applyPrecon*(oa: opArgsP; r: type(v1); x: type(v1)) = - for e in r: - let t = sqrt(1.0 / m[e][0,0]) - r[e] := t * x[e] - #mul(r, m, x) + for e in r: + let t = sqrt(1.0 / m[e][0,0]) + r[e] := t * x[e] + #mul(r, m, x) + var precL = true + var precR = true + proc applyPreconL*(oa: opArgsP; r: type(v1); x: type(v1)) = + if precL: + applyPrecon(oa, r, x) + else: + r := x + proc applyPreconR*(oa: opArgsP; r: type(v1); x: type(v1)) = + if precR: + applyPrecon(oa, r, x) + else: + r := x var sp:SolverParams sp.r2req = 1e-20 @@ -249,12 +348,20 @@ when isMainModule: template resid(r,b,x,oa: untyped) = oa.apply(r, x) r := b - r + template checkResid(r, b, x, oa: auto) = + resid(r, b, x, oa) + let r2 = r.norm2 + let b2 = b.norm2 + echo "true r2/b2: ", r2/b2 #cgSolve(v2, v1, oa, sp) var cg = newCgState(x=v2, b=v1) + echo "starting cg.solve" cg.solve(oa, sp) - echo sp.finalIterations + checkResid(v3, v1, v2, oa) + echo "end cg.solve iterations: ", sp.finalIterations + echo "starting cg.solve restart test" v2 := 0 cg.reset sp.maxits = 0 @@ -268,7 +375,10 @@ when isMainModule: #cg.r := v3 #cg.r2 = tr2 echo sp.finalIterations, " ", cg.r2, "/", cg.r2stop + checkResid(v3, v1, v2, oa) + echo "end cg.solve restart test" + echo "starting cg.solve restart test 2" v2 := 0 cg.reset sp.maxits = 0 @@ -277,12 +387,64 @@ when isMainModule: cg.solve(oa, sp) let c = cg.x.norm2 echo cg.iterations, ": ", c, " ", cg.r2 + checkResid(v3, v1, v2, oa) + echo "end cg.solve restart test 2" + + echo "starting cg.solve cpHerm restart test" + v2 := 0 + cg.reset + sp.maxits = 0 + sp.verbosity = 0 + while cg.r2 > cg.r2stop: + sp.maxits += 10 + cg.solve(oap, sp) + let c = cg.x.norm2 + echo cg.iterations, ": ", c, " ", cg.r2 + checkResid(v3, v1, v2, oa) + echo "end cg.solve cpHerm restart test" + + echo "starting cg.solve cpLeftRight restart test" + oap.precon = cpLeftRight + v2 := 0 + cg.reset + sp.maxits = 0 + sp.verbosity = 1 + while cg.r2 > cg.r2stop: + sp.maxits += 10 + cg.solve(oap, sp) + let c = cg.x.norm2 + echo cg.iterations, ": ", c, " ", cg.r2 + checkResid(v3, v1, v2, oa) + echo "end cg.solve cpLeftRight restart test" + echo "starting cg.solve cpLeftRight R restart test" + oap.precon = cpLeftRight + precL = false + precR = true v2 := 0 cg.reset sp.maxits = 0 + sp.verbosity = 0 + while cg.r2 > cg.r2stop: + sp.maxits += 10 + cg.solve(oap, sp) + let c = cg.x.norm2 + echo cg.iterations, ": ", c, " ", cg.r2 + checkResid(v3, v1, v2, oa) + echo "end cg.solve cpLeftRight R restart test" + + echo "starting cg.solve cpLeftRight L restart test" + oap.precon = cpLeftRight + precL = true + precR = false + v2 := 0 + cg.reset + sp.maxits = 0 + sp.verbosity = 0 while cg.r2 > cg.r2stop: sp.maxits += 10 cg.solve(oap, sp) let c = cg.x.norm2 echo cg.iterations, ": ", c, " ", cg.r2 + checkResid(v3, v1, v2, oa) + echo "end cg.solve cpLeftRight L restart test" diff --git a/src/solvers/gcr.nim b/src/solvers/gcr.nim index 86d9096..dad25f6 100644 --- a/src/solvers/gcr.nim +++ b/src/solvers/gcr.nim @@ -53,17 +53,21 @@ template `[]`(gs: GcrState, i: int): untyped = gs.vecs[i] #template level(gs: GcrState, i: int): untyped = gs[i].level proc combine*(gs: var GcrState, n: int) = - var c = gs[n-1].alpha / gs[n].alpha - gs[n].Avec += c * gs[n-1].Avec - swap(gs[n-1].Avec, gs[n].Avec) - gs[n-1].Avn = gs[n].Avn + c.norm2 * gs[n-1].Avn - c += gs[n].beta[n-1] - gs[n].vec += c * gs[n-1].vec - swap(gs[n-1].vec, gs[n].vec) - gs[n-1].alpha = gs[n].alpha + let gsp = addr gs + var c = gsp[][n-1].alpha / gsp[][n].alpha + var d = c + gsp[][n].beta[n-1] + threads: + #block: + var c = gsp[][n-1].alpha / gsp[][n].alpha + gsp[][n].Avec += c * gsp[][n-1].Avec + gsp[][n-1].Avn = gsp[][n].Avn + c.norm2 * gsp[][n-1].Avn + gsp[][n].vec += d * gsp[][n-1].vec + gsp[][n-1].alpha = gsp[][n].alpha for i in 0.. rsqstop and total_iterations < max_iterations: + tic() inc(iteration) inc(total_iterations) #echo "begin addvec" gs.addvec + toc("addvec") #echo "end addvec" let nv = gs.nv - 1 if nv == 1: - gs[0].vec := gs[0].alpha * gs[0].vec - op.apply(gs[0].Avec, gs[0].vec) - gs[0].Avn = gs[0].Avec.norm2 - op.apply(Ap, x) - r := b - Ap - let ctmp = dot(gs[0].Avec, r) - gs[0].alpha = ctmp / gs[0].Avn - r -= gs[0].alpha * gs[0].Avec - rsq = r.norm2 - verb(2): - echo iteration, " rsq: ", gs.r2, " -> ", rsq - gs.r2 = rsq - op.preconditioner(gs[nv].vec, gs) - op.apply(gs[nv].Avec, gs[nv].vec) + let gsp = addr gs + threads: + gsp[][0].vec := gsp[][0].alpha * gsp[][0].vec + op.apply(gsp[][0].Avec, gsp[][0].vec) + gsp[][0].Avn = gsp[][0].Avec.norm2 + op.apply(Ap, x) + r := b - Ap + let ctmp = dot(gsp[][0].Avec, r) + gsp[][0].alpha = ctmp / gsp[][0].Avn + r -= gsp[][0].alpha * gsp[][0].Avec + rsq = r.norm2 + verb(2): + echo iteration, " rsq: ", gsp[].r2, " -> ", rsq + gsp[].r2 = rsq + let gsp = addr gs + threads: + op.preconditioner(gsp[][nv].vec, gsp[]) + op.apply(gsp[][nv].Avec, gsp[][nv].vec) gs.orth gs[nv].Avn = gs[nv].Avec.norm2 let ctmp = dot(gs[nv].Avec, r) @@ -239,10 +251,13 @@ proc solve*(gs: var GcrState; opx: var auto; sp: var SolverParams) = # rsq, relnorm2) gs.getx() #gs.fini() - opx = op - #res_arg.final_rsq = rsq div insq - #res_arg.final_rel = relnorm2 - sp.finalIterations = iteration + #opx = op + sp.calls += 1 + sp.iterations += iteration + sp.iterationsMax = max(sp.iterationsMax, iteration) + sp.seconds += getElapsedTime() + sp.flops += 0 + sp.r2.push gs.r2 verb(1): echo "GCR: its: ", iteration, " rsq: ", rsq #return QOP_SUCCESS diff --git a/src/solvers/solverBase.nim b/src/solvers/solverBase.nim index 907b2b1..68924c6 100644 --- a/src/solvers/solverBase.nim +++ b/src/solvers/solverBase.nim @@ -82,11 +82,11 @@ proc addStats*(sp0: var SolverParams, sp1: SolverParams) = sp0.r2 += sp1.r2 proc getStats*(sp: SolverParams, typ0= -1): string = - let c = sp.calls + let c = max(1,sp.calls) let its = sp.iterations let ic = its div c let im = sp.iterationsMax - let s = sp.seconds + let s = max(1e-12, sp.seconds) let sc = s/c.float let f = sp.flops let gf = 1e-9*f/s diff --git a/tests/base/tshift.nim b/tests/base/tshift.nim index 4afe1d5..4ddc04f 100644 --- a/tests/base/tshift.nim +++ b/tests/base/tshift.nim @@ -56,9 +56,9 @@ proc testfb(x,y,z: auto, mu,d: int): float = # echoAll xi, " ", yi, " ", zi result = res -proc test2(Smd: typedesc, lat: array): float = +proc test2[N,T](Smd: typedesc, lat: array[N,T]): float = const vl = int Smd.numNumbers - let nd = lat.len + const nd = lat.len var lo = newLayout(lat, vl) type LatReal = Field[vl, Smd] var x,y,z: LatReal diff --git a/tests/diffnum b/tests/diffnum index b65c7b0..16b5772 100755 --- a/tests/diffnum +++ b/tests/diffnum @@ -13,7 +13,7 @@ if len(sys.argv)>2: rel = 1e-12 if len(sys.argv)>3: rel = float(sys.argv[3]) -spchars = " \t[\]," +spchars = " \t[\\]," if len(sys.argv)>4: spchars = sys.argv[4] #print(fn1, fn2, rel, repr(spchars))