From ba7f3c5a58418f6a4bdcfacc36ddb6c435ba41ca Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Wed, 25 Sep 2024 13:05:54 -0500 Subject: [PATCH] experimental/graph/hmcgauge: tuning for pure gauge HMC --- src/experimental/graph/hmcgauge.nim | 323 ++++++++++++++++++++++++++++ 1 file changed, 323 insertions(+) create mode 100644 src/experimental/graph/hmcgauge.nim diff --git a/src/experimental/graph/hmcgauge.nim b/src/experimental/graph/hmcgauge.nim new file mode 100644 index 0000000..1e20775 --- /dev/null +++ b/src/experimental/graph/hmcgauge.nim @@ -0,0 +1,323 @@ +import qex +import core, scalar, gauge +from os import fileExists +from strformat import `&` +from math import cos, PI + +proc newOneOf(x: float): float = 0.0 + +type AdamW[Param] = object + alpha, beta1, beta2, eps, lambda: float + m: Param + v: Param + +func newAdamW[Param](param: Param, alpha = 0.001, beta1 = 0.9, beta2 = 0.999, eps = 1e-8, lambda = 0.01): AdamW[Param] = + result = AdamW[Param](alpha: alpha, beta1: beta1, beta2: beta2, eps: eps, lambda: lambda) + result.m = newOneOf param + result.v = newOneOf param + +func newAdam[Param](param: Param, alpha = 0.001, beta1 = 0.9, beta2 = 0.999, eps = 1e-8): AdamW[Param] = + newAdamW[Param](param, alpha, beta1, beta2, eps, lambda = 0.0) + +proc optimize[Param](opt: var AdamW[Param], param: var Param, grad: Param, t: int, lr: float) = + ## Follows arXiv:1711.05101 (AdamW); it is standard Adam if lambda == 0, which effectively scales grad by alpha/stdev(grad) for descent. + ## The decay term is equivalent to adding lambda/(2 scale) param^2 ~ lambda/(2 alpha) stdev(grad) param^2 to the objective. + ## Normalized weight decay suggests lambda = lambda_norm sqrt(b/(B T)), for batch size b, number of training points B, and total number of epochs T. 
+ let + a = opt.alpha + b1 = opt.beta1 + b2 = opt.beta2 + sb1 = 1.0 - b1 + sb2 = 1.0 - b2 + sb1t = 1.0 - b1^t + sb2t = 1.0 - b2^t + dr = opt.lambda + eps = opt.eps + for i in 0..0: + g = axexpmuly(t0, p, g) + p = p - h * gaugeForce(gc, g) + g = axexpmuly(t1, p, g) + p = p - h * gaugeForce(gc, g) + g = axexpmuly(t05, p, g) + (g, p, @[lambda]) + +proc int4MN3F1GP(gc, g0, p0, dt: Gvalue, n: int, coeffs: openarray[float]): (Gvalue, Gvalue, seq[Gvalue]) = + let lambda = coeffs.get(0, 0.2470939580390842) + let theta = coeffs.get(1, 0.5 - 1.0 / sqrt(24.0 * lambda.getfloat)) + # the default chi is (1 - sqrt(6 lambda)(1 - lambda))/12 multiplied by 10 * 2/(1 - 2 lambda), which brings the force gradient coeff to about the same order as the others; c1 below multiplies by 0.1 to cancel the factor of 10 + let chi = coeffs.get(2, (1.0 - sqrt(6.0 * lambda.getfloat) * (1.0 - lambda.getfloat)) / 12.0 * (2.0 / (1.0 - 2.0*lambda.getfloat) * 10.0)) + var g = g0 + var p = p0 + let a0 = theta*dt + let a02 = 2.0*a0 + let a1 = 0.5*dt - a0 + let b0 = lambda*dt + let b1 = dt - 2.0*b0 + let c1 = 0.1*chi*(dt*dt) + g = axexpmuly(a0, p, g) + for i in 0..0: + g = axexpmuly(a02, p, g) + p = p - b0 * gaugeForce(gc, g) + g = axexpmuly(a1, p, g) + p = p - b1 * gaugeForce(gc, axexpmuly(-c1, gaugeForce(gc, g), g)) + g = axexpmuly(a1, p, g) + p = p - b0 * gaugeForce(gc, g) + g = axexpmuly(a0, p, g) + (g, p, @[lambda, theta, chi]) + +proc int4MN5F2GP(gc, g0, p0, dt: Gvalue, n: int, coeffs: openarray[float]): (Gvalue, Gvalue, seq[Gvalue]) = + let rho = coeffs.get(0, 0.06419108866816235) + let theta = coeffs.get(1, 0.1919807940455741) + let vtheta = coeffs.get(2, 0.1518179640276466) + let lambda = coeffs.get(3, 0.2158369476787619) + # the default xi is multiplied by 20 * 2/lambda, which brings the force gradient coeff to about the same order as the others; c1 below multiplies by 0.05 to cancel the factor of 20 + let xi = coeffs.get(4, 0.0009628905212024874 * (2.0 / lambda.getfloat * 20.0)) + var g = g0 + var p = p0 + let a0 = rho*dt + let a02 = 2.0*a0 + let a1 = theta*dt + let a2 = (0.5-(theta+rho))*dt + let b1 = lambda*dt + let b0 = vtheta*dt + let b2 = (1.0-2.0*(lambda+vtheta))*dt + let c1 = 0.05*xi*(dt*dt) + g = axexpmuly(a0, p, g) + for i in 0..0: + g = 
axexpmuly(a02, p, g) + p = p - b0 * gaugeForce(gc, g) + g = axexpmuly(a1, p, g) + p = p - b1 * gaugeForce(gc, axexpmuly(-c1, gaugeForce(gc, g), g)) + g = axexpmuly(a2, p, g) + p = p - b2 * gaugeForce(gc, g) + g = axexpmuly(a2, p, g) + p = p - b1 * gaugeForce(gc, axexpmuly(-c1, gaugeForce(gc, g), g)) + g = axexpmuly(a1, p, g) + p = p - b0 * gaugeForce(gc, g) + g = axexpmuly(a0, p, g) + (g, p, @[rho, theta, vtheta, lambda, xi]) + +qexInit() + +tic() + +letParam: + gaugefile = "" + savefile = "config" + savefreq = 0 + lat = + if fileExists(gaugefile): + getFileLattice gaugefile + else: + if gaugefile.len > 0: + qexWarn "Nonexistent gauge file: ", gaugefile + @[8,8,8,16] + beta = 5.4 + dt = 0.025 + trajsThermo = 0 + trajsTrain = 50 + trajsTrainlrWarm = 10 + trajsInfer = 0 + lrmax = 1.0 + lrmin = 0.0001 + weightDecay = 0.0 + seed:uint = 1234567891 + gintalg = "2MN" + lambda = @[0.0] + gsteps = 4 + alwaysAccept:bool = 0 + +echo "rank ", myRank, "/", nRanks +threads: echo "thread ", threadNum, "/", numThreads + +installStandardParams() +echoParams() +processHelpParam() + +let + lo = lat.newLayout + vol = lo.physVol + gc = actWilson(beta) + +var r = lo.newRNGField(RngMilc6, seed) +var R:RngMilc6 # global RNG +R.seed(seed, 987654321) + +var + g = lo.newgauge + p = lo.newgauge + +if fileExists(gaugefile): + tic("load") + if 0 != g.loadGauge gaugefile: + qexError "failed to load gauge file: ", gaugefile + qexLog "loaded gauge from file: ", gaugefile," secs: ",getElapsedTime() + toc("read") + g.reunit + toc("reunit") +else: + #g.random r + g.unit + +g.echoPlaq + +let gdt = toGvalue dt +var params = @[gdt] +let + gg = toGvalue g + gp = toGvalue p + ga0 = gc.gaugeAction gg + t0 = 0.5 * gp.norm2 + h0 = ga0 + t0 + tau = float(gsteps) * gdt + (g1, p1, coeffs) = case gintalg + of "2MN": + int2MN(gc, gg, gp, gdt, gsteps, lambda) + of "4MN3F1GP": + int4MN3F1GP(gc, gg, gp, gdt, gsteps, lambda) + of "4MN5F2GP": + int4MN5F2GP(gc, gg, gp, gdt, gsteps, lambda) + else: + raise 
newException(ValueError, "unknown intalg: " & gintalg) + ga1 = gc.gaugeAction g1 + t1 = 0.5 * p1.norm2 + h1 = ga1 + t1 + dH = h1 - h0 + acc = cond(dH<0.0, 1.0, exp(-dH)) + loss = -acc * (tau * tau) + +params.add coeffs +var grads = newseq[Gvalue]() +for x in params: + grads.add loss.grad x + +var param = newseq[float]() +for x in params: + param.add x.getfloat +var grad = param +var opt = newAdamW(param, lambda = weightDecay) + +block: + var ps = "param:" + for i in 0.. 0 and traj mod savefreq == 0: + tic("save") + let fn = savefile & &".{traj:05}.lime" + if 0 != g.saveGauge(fn): + qexError "Failed to save gauge to file: ",fn + qexLog "saved gauge to file: ",fn," secs: ",getElapsedTime() + toc("done") + + qexLog "traj ",traj," secs: ",getElapsedTime() + toc("traj end") + +toc() + +processSaveParams() +writeParamFile() +qexFinalize()