diff --git a/ptypy/accelerate/cuda_cupy/engines/stochastic.py b/ptypy/accelerate/cuda_cupy/engines/stochastic.py index f798d569e..6348fa887 100644 --- a/ptypy/accelerate/cuda_cupy/engines/stochastic.py +++ b/ptypy/accelerate/cuda_cupy/engines/stochastic.py @@ -227,9 +227,13 @@ def engine_iterate(self, num=1): Compute one iteration. """ self.dID_list = list(self.di.S.keys()) - error = {} + for it in range(num): + reduced_error = np.zeros((3,)) + reduced_error_count = 0 + local_error = {} + for iblock, dID in enumerate(self.dID_list): # find probe, object and exit ID in dependence of dID @@ -378,14 +382,24 @@ def engine_iterate(self, num=1): err_fourier = prep.err_fourier_gpu.get() err_phot = prep.err_phot_gpu.get() err_exit = prep.err_exit_gpu.get() - errs = np.ascontiguousarray( - np.vstack([err_fourier, err_phot, err_exit]).T) - error.update(zip(prep.view_IDs, errs)) + errs = np.ascontiguousarray(np.vstack([err_fourier, err_phot, err_exit]).T) + if self.p.record_local_error: + local_error.update(zip(prep.view_IDs, errs)) + else: + reduced_error += errs.sum(axis=0) + reduced_error_count += errs.shape[0] + + if self.p.record_local_error: + error = local_error + else: + # Gather errors across all MPI ranks + error = parallel.allreduce(reduced_error) + count = parallel.allreduce(reduced_error_count) + error /= count # wait for the async transfers self.qu_dtoh.synchronize() - self.error = error return error def position_update_local(self, prep, i): diff --git a/ptypy/accelerate/cuda_pycuda/engines/stochastic.py b/ptypy/accelerate/cuda_pycuda/engines/stochastic.py index d45a67218..dec8de24f 100644 --- a/ptypy/accelerate/cuda_pycuda/engines/stochastic.py +++ b/ptypy/accelerate/cuda_pycuda/engines/stochastic.py @@ -222,9 +222,13 @@ def engine_iterate(self, num=1): Compute one iteration. """ self.dID_list = list(self.di.S.keys()) - error = {} + for it in range(num): + reduced_error = np.zeros((3,)) + reduced_error_count = 0 + local_error = {} + for iblock, dID in enumerate(self.dID_list): # find probe, object and exit ID in dependence of dID @@ -357,12 +361,23 @@ def engine_iterate(self, num=1): err_phot = prep.err_phot_gpu.get() err_exit = prep.err_exit_gpu.get() errs = np.ascontiguousarray(np.vstack([err_fourier, err_phot, err_exit]).T) - error.update(zip(prep.view_IDs, errs)) + if self.p.record_local_error: + local_error.update(zip(prep.view_IDs, errs)) + else: + reduced_error += errs.sum(axis=0) + reduced_error_count += errs.shape[0] + + if self.p.record_local_error: + error = local_error + else: + # Gather errors across all MPI ranks + error = parallel.allreduce(reduced_error) + count = parallel.allreduce(reduced_error_count) + error /= count # wait for the async transfers self.qu_dtoh.synchronize() - self.error = error return error def position_update_local(self, prep, i):