Skip to content

Commit

Permalink
TL: adapted to last changes on fnop + some debug
Browse files Browse the repository at this point in the history
  • Loading branch information
tlunet committed Oct 30, 2024
1 parent fdf2a3c commit aacb5d4
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions pySDC/playgrounds/dedalus/sdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def setupNN(cls, nnType, **params):
if nnType == "FNOP-1":
from fnop.inference.inference import FNOInference as ModelClass
elif nnType == "FNOP-2":
from fnop.fno import FourierNeuralOp as ModelClass
from fnop.training.fno_pysdc import FourierNeuralOp as ModelClass
cls.model = ModelClass(**params)
cls.initSweep = "NN"

Expand Down Expand Up @@ -430,13 +430,15 @@ def _presetStateCoeffSpace(self, state):
def _toNumpy(self, state):
"""Extract from state fields a 3D numpy array containing ux, uz, b and p,
to be given to a NN model."""
for field in state:
field.change_scales(1)
field.require_grid_space()
return np.asarray(
# ux , uz , b , p
[state[2].data[0], state[2].data[1], state[1].data, state[0].data])

def _setStateWith(self, u, state):
"""Write a 3D numpy array containing ux, uz, b and p into a dedalus state.
Warning : state has to be in grid space"""
"""Write a 3D numpy array containing ux, uz, b and p into a dedalus state."""
np.copyto(state[2].data[0], u[0]) # ux
np.copyto(state[2].data[1], u[1]) # uz
np.copyto(state[1].data, u[2]) # b
Expand Down Expand Up @@ -510,6 +512,11 @@ def _initSweep(self):
np.copyto(LXk[m].data, LXk[0].data)
np.copyto(Fk[m].data, Fk[0].data)

elif self.initSweep == "NN":
# nothing to do, initialization of tendencies already done
# during last sweep ...
pass

else:
raise NotImplementedError(f'initSweep={self.initSweep}')

Expand Down Expand Up @@ -565,19 +572,23 @@ def _sweep(self, k):
if self.initSweep == "NN" and k == self.nSweeps-1:
# => evaluate current state with NN to be used
# for the tendencies at k=0 for the initial guess of next step
uState = self._toNumpy(solver.state)
current = solver.state
state = [field.copy() for field in current]
uState = self._toNumpy(state)
uNext = self.model(uState)
self._setStateWith(uNext, solver.state)
np.clip(uNext[2], a_min=0, a_max=1, out=uNext[2]) # temporary : clip buoyancy between 0 and 1
self._setStateWith(uNext, state)
solver.state = state
tEval += dt

# Evaluate and store LX with current state
self._evalLX(LXk1[m])
# Evaluate and store F(X, t) with current state
self._evalF(Fk1[m], tEval, dt, wall_time)
# Evaluate and store LX with current state
self._evalLX(LXk1[m])

if self.initSweep == "NN" and k == self.nSweeps-1:
# Reset state if it was used for NN initial guess
self._setStateWith(uState, solver.state)
solver.state = current

# Inverse position for iterate k and k+1 in storage
# ie making the new evaluation the old for next iteration
Expand Down

0 comments on commit aacb5d4

Please sign in to comment.