diff --git a/pySDC/helpers/spectral_helper.py b/pySDC/helpers/spectral_helper.py index c9ce73bc4..ae2b3f210 100644 --- a/pySDC/helpers/spectral_helper.py +++ b/pySDC/helpers/spectral_helper.py @@ -1255,7 +1255,9 @@ def transform_single_component(self, u, axes=None, padding=None): fft = self.get_fft(_axes, 'object', padding=padding, shape=shape) - _in = self.get_aligned(result, axis_in=alignment, axis_out=self.ndim + _axes[-1], forward=False, fft=fft) + _in = self.get_aligned( + result, axis_in=alignment, axis_out=self.ndim + _axes[-1], forward=False, fft=fft, shape=shape + ) alignment = self.ndim + _axes[-1] @@ -1266,10 +1268,12 @@ def transform_single_component(self, u, axes=None, padding=None): axes_next_base = axes_collapsed[(trf + 1) % len(axes_collapsed)] alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[-1] - result = self.get_aligned(_out, axis_in=self.ndim + _axes[0], axis_out=alignment, fft=fft, forward=True) + result = self.get_aligned( + _out, axis_in=self.ndim + _axes[0], axis_out=alignment, fft=fft, forward=True, shape=shape + ) fft = self.get_fft(axes=axes, padding=padding) - return self.get_aligned(result, axis_in=alignment, axis_out=self.ndim - 1, fft=fft, forward=True) + return self.get_aligned(result, axis_in=alignment, axis_out=self.ndim - 1, fft=fft, forward=True, shape=shape) def transform(self, u, axes=None, padding=None): """ @@ -1384,7 +1388,9 @@ def itransform_single_component(self, u, axes=None, padding=None): fft = self.get_fft(_axes, 'object', padding=padding, shape=shape) - _in = self.get_aligned(result, axis_in=alignment, axis_out=self.ndim + _axes[0], forward=True, fft=fft) + _in = self.get_aligned( + result, axis_in=alignment, axis_out=self.ndim + _axes[0], forward=True, fft=fft, shape=shape + ) if self.comm is not None: _in /= np.prod([self.axes[i].N for i in _axes]) @@ -1393,43 +1399,79 @@ def itransform_single_component(self, u, axes=None, padding=None): _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape) for _ax in _axes: - shape[_ax] = _out.shape[_ax] + if fft: + shape[_ax] = fft._input_shape[_ax] + else: + shape[_ax] = _out.shape[_ax] axes_next_base = axes_collapsed[(trf + 1) % len(axes_collapsed)] alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[0] - result = self.get_aligned(_out, axis_in=self.ndim + _axes[-1], axis_out=alignment, fft=fft, forward=False) + result = self.get_aligned( + _out, axis_in=self.ndim + _axes[-1], axis_out=alignment, fft=fft, forward=False, shape=shape + ) fft = self.get_fft(axes=axes, padding=padding) - return self.get_aligned(result, axis_in=alignment, axis_out=self.ndim - 1, fft=fft) + return self.get_aligned(result, axis_in=alignment, axis_out=self.ndim - 1, fft=fft, shape=shape) - def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, fill=True, **kwargs): + def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, **kwargs): """ Realign the data along the axis when using distributed FFTs Args: u: The solution - axis (int): New alignment + axis_in (int): Current alignment + axis_out (int): New alignment + fft (mpi4py_fft.PFFT), optional: parallel FFT object + forward (bool): Whether the input is in spectral space or not Returns: - solution aligned on `axis` + solution aligned on `axis_in` """ - if self.comm is None: - if fill: - return u - elif forward: - return self.u_init_forward - else: - return self.u_init - - from mpi4py_fft import newDistArray + if self.comm is None or axis_in == axis_out: + return u.copy() fft = self.get_fft(**kwargs) if fft is None else fft - _in = newDistArray(fft, forward).redistribute(axis_in) - if fill: - _in[...] = u + global_fft = self.get_fft(**kwargs) + axisA = [me.axisA for me in global_fft.transfer] + axisB = [me.axisB for me in global_fft.transfer] + + current_axis = axis_in + + if axis_in in axisA and axis_out in axisB: + while current_axis != axis_out: + transfer = global_fft.transfer[axisA.index(current_axis)] + + arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype) + arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype) + arrayA[:] = u[:] + + transfer.forward(arrayA=arrayA, arrayB=arrayB) + + current_axis = transfer.axisB + u = arrayB + return u + elif axis_in in axisB and axis_out in axisA: + while current_axis != axis_out: + transfer = global_fft.transfer[axisB.index(current_axis)] + + arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype) + arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype) + arrayB[:] = u[:] + + transfer.backward(arrayA=arrayA, arrayB=arrayB) + + current_axis = transfer.axisA + u = arrayA + return u + else: # go the potentially slower route of not reusing transfer classes + from mpi4py_fft import newDistArray + + _in = newDistArray(fft, forward).redistribute(axis_in) + if fill: + _in[...] = u - return _in.redistribute(axis_out) + return _in.redistribute(axis_out) def itransform(self, u, axes=None, padding=None): axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes diff --git a/pySDC/tests/test_problems/test_RayleighBenard.py b/pySDC/tests/test_problems/test_RayleighBenard.py index 54d4a0a16..b3b4c260d 100644 --- a/pySDC/tests/test_problems/test_RayleighBenard.py +++ b/pySDC/tests/test_problems/test_RayleighBenard.py @@ -289,9 +289,9 @@ def test_Nyquist_mode_elimination(): if __name__ == '__main__': # test_eval_f(2**0, 2**2, 'z', True) # test_Poisson_problem(1, 'T') - # test_Poisson_problem_v() + test_Poisson_problem_v() # test_Nusselt_numbers(1) # test_buoyancy_computation() # test_viscous_dissipation() # test_CFL() - test_Nyquist_mode_elimination() + # test_Nyquist_mode_elimination()