From f54bb8f868255a19cff045055c458f1f1d5d4dab Mon Sep 17 00:00:00 2001 From: Thibaut Lunet Date: Wed, 23 Oct 2024 10:30:22 +0200 Subject: [PATCH] TL: coupling with fnop --- pySDC/playgrounds/dedalus/sdc.py | 52 +++++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/pySDC/playgrounds/dedalus/sdc.py b/pySDC/playgrounds/dedalus/sdc.py index b56b39c61..1a9c3107f 100644 --- a/pySDC/playgrounds/dedalus/sdc.py +++ b/pySDC/playgrounds/dedalus/sdc.py @@ -120,6 +120,9 @@ class IMEXSDCCore(object): dt = None axpy = None + # For NN use to compute initial guess, etc ... + model = None + @classmethod def setParameters(cls, nNodes=None, nodeType=None, quadType=None, implSweep=None, explSweep=None, initSweep=None, @@ -185,6 +188,15 @@ def setParameters(cls, nNodes=None, nodeType=None, quadType=None, diagonal *= np.all(np.diag(np.diag(cls.QDelta0)) == cls.QDelta0) cls.diagonal = diagonal + @classmethod + def setupNN(cls, nnType, **params): + if nnType == "FNOP-C": + from fnop.inference.inference import FNOInference as ModelClass + elif nnType == "FNOP-T": + from fnop.fno import FourierNeuralOp as ModelClass + cls.model = ModelClass(**params) + cls.initSweep = "NN" + # ------------------------------------------------------------------------- # Class properties # ------------------------------------------------------------------------- @@ -252,6 +264,7 @@ def __init__(self, solver): self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype) self.dt = None self.firstEval = True + self.firstStep = True @property def M(self): @@ -414,6 +427,22 @@ def _presetStateCoeffSpace(self, state): for field in state: field.preset_layout('c') + def _toNumpy(self, state): + """Extract from state fields a 3D numpy array containing ux, uz, b and p, + to be given to a NN model.""" + return np.asarray( + # ux , uz , b , p + [state[2].data[0], state[2].data[1], state[1].data, state[0].data]) + + def _setStateWith(self, u, state): + """Write a 3D numpy array containing ux, uz, b and p into a dedalus state. + Warning : state has to be in grid space""" + np.copyto(state[2].data[0], u[0]) # ux + np.copyto(state[2].data[1], u[1]) # uz + np.copyto(state[1].data, u[2]) # b + np.copyto(state[0].data, u[3]) # p + + def _initSweep(self): """ Initialize node terms for one given time-step @@ -474,7 +503,7 @@ def _initSweep(self): # Evaluate and store F(X, t) with current state self._evalF(Fk[m], t0+dt*tau[m], dt, wall_time) - elif self.initSweep == 'COPY': + elif self.initSweep == 'COPY' or (self.initSweep == "NN" and self.firstStep): self._evalLX(LXk[0]) self._evalF(Fk[0], t0, dt, wall_time) for m in range(1, self.M): @@ -525,16 +554,30 @@ def _sweep(self, k): self._solveAndStoreState(k, m) # Avoid non necessary RHS evaluations work - if not self.forceProl and k == self.nSweeps-1: + if not self.forceProl and k == self.nSweeps-1 and self.initSweep != "NN": if self.diagonal: continue elif m == self.M-1: continue + tEval = t0+dt*tau[m] + # In case NN is used for initial guess (last sweep only) + if self.initSweep == "NN" and k == self.nSweeps-1: + # => evaluate current state with NN to be used + # for the tendencies at k=0 for the initial guess of next step + uState = self._toNumpy(solver.state) + uNext = self.model(uState) + self._setStateWith(uNext, solver.state) + tEval += dt + # Evaluate and store LX with current state self._evalLX(LXk1[m]) # Evaluate and store F(X, t) with current state - self._evalF(Fk1[m], t0+dt*tau[m], dt, wall_time) + self._evalF(Fk1[m], tEval, dt, wall_time) + + if self.initSweep == "NN" and k == self.nSweeps-1: + # Reset state if it was used for NN initial guess + self._setStateWith(uState, solver.state) # Inverse position for iterate k and k+1 in storage # ie making the new evaluation the old for next iteration @@ -611,6 +654,7 @@ def step(self, dt, wall_time): if self.doProlongation: self._prolongation() - # Update simulation time and reset evaluation tag + # Update simulation time and update tags self.solver.sim_time += dt self.firstEval = True + self.firstStep = False