Commit
added new reduced error logic to accelerated stochastic engines
daurer committed Mar 5, 2024
1 parent 0401be3 commit 2c90200
Showing 2 changed files with 37 additions and 8 deletions.
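Both files receive the same change: instead of always building a per-view error dictionary, the engines now either record local (per-view) errors or accumulate a running sum of the three error metrics (fourier, photon, exit) together with a view count, which is later averaged across MPI ranks. A minimal single-process sketch of that accumulation logic, with made-up error arrays standing in for the values fetched from the GPU (the `reduce_errors` helper and the sample blocks are illustrative, not part of the commit):

```python
import numpy as np

def reduce_errors(blocks, record_local_error=False):
    # Mirrors the new engine logic: per-view dict vs. summed (3,)-vector.
    reduced_error = np.zeros((3,))
    reduced_error_count = 0
    local_error = {}
    for view_IDs, err_fourier, err_phot, err_exit in blocks:
        # stack the three metrics into shape (n_views, 3), one row per view
        errs = np.ascontiguousarray(np.vstack([err_fourier, err_phot, err_exit]).T)
        if record_local_error:
            local_error.update(zip(view_IDs, errs))
        else:
            reduced_error += errs.sum(axis=0)
            reduced_error_count += errs.shape[0]
    if record_local_error:
        return local_error
    # in the engines, both quantities are MPI-summed via parallel.allreduce first
    return reduced_error / reduced_error_count

# two fake data blocks with 2 and 3 views each
blocks = [
    (["V00", "V01"], np.array([0.1, 0.2]), np.array([1.0, 1.1]), np.array([0.5, 0.6])),
    (["V02", "V03", "V04"], np.array([0.3, 0.4, 0.5]), np.array([1.2, 1.3, 1.4]), np.array([0.7, 0.8, 0.9])),
]
print(reduce_errors(blocks))                           # mean [fourier, phot, exit] over all 5 views
print(reduce_errors(blocks, record_local_error=True))  # {'V00': array([...]), ...}
```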
ptypy/accelerate/cuda_cupy/engines/stochastic.py (24 changes: 19 additions & 5 deletions)
@@ -227,9 +227,13 @@ def engine_iterate(self, num=1):
         Compute one iteration.
         """
         self.dID_list = list(self.di.S.keys())
-        error = {}
 
         for it in range(num):
+
+            reduced_error = np.zeros((3,))
+            reduced_error_count = 0
+            local_error = {}
 
             for iblock, dID in enumerate(self.dID_list):
 
                 # find probe, object and exit ID in dependence of dID
@@ -378,14 +382,24 @@ def engine_iterate(self, num=1):
                 err_fourier = prep.err_fourier_gpu.get()
                 err_phot = prep.err_phot_gpu.get()
                 err_exit = prep.err_exit_gpu.get()
-                errs = np.ascontiguousarray(
-                    np.vstack([err_fourier, err_phot, err_exit]).T)
-                error.update(zip(prep.view_IDs, errs))
+                errs = np.ascontiguousarray(np.vstack([err_fourier, err_phot, err_exit]).T)
+                if self.p.record_local_error:
+                    local_error.update(zip(prep.view_IDs, errs))
+                else:
+                    reduced_error += errs.sum(axis=0)
+                    reduced_error_count += errs.shape[0]
 
+            if self.p.record_local_error:
+                error = local_error
+            else:
+                # Gather errors across all MPI ranks
+                error = parallel.allreduce(reduced_error)
+                count = parallel.allreduce(reduced_error_count)
+                error /= count
 
         # wait for the async transfers
         self.qu_dtoh.synchronize()
 
         self.error = error
         return error
 
     def position_update_local(self, prep, i):
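The `parallel.allreduce` calls in the new branch come from ptypy's MPI helpers and sum their argument over all ranks (a pass-through in a single-process run), so dividing the summed error vector by the summed view count yields the mean error over every view in the reconstruction. For readers more familiar with mpi4py, the averaging step is roughly equivalent to the following sketch (the partial sums are made-up values):

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

# per-rank partial sums, as accumulated in the block loop above (made-up numbers)
reduced_error = np.array([0.8, 3.1, 1.9])  # summed [fourier, phot, exit] errors
reduced_error_count = 5                    # number of views seen by this rank

# sum both quantities across ranks, then normalize by the global view count
global_error = np.empty_like(reduced_error)
comm.Allreduce(reduced_error, global_error, op=MPI.SUM)
global_count = comm.allreduce(reduced_error_count, op=MPI.SUM)
mean_error = global_error / global_count   # mean [fourier, phot, exit] per view
print(mean_error)
```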
ptypy/accelerate/cuda_pycuda/engines/stochastic.py (21 changes: 18 additions & 3 deletions)
@@ -222,9 +222,13 @@ def engine_iterate(self, num=1):
         Compute one iteration.
         """
         self.dID_list = list(self.di.S.keys())
-        error = {}
 
         for it in range(num):
+
+            reduced_error = np.zeros((3,))
+            reduced_error_count = 0
+            local_error = {}
 
             for iblock, dID in enumerate(self.dID_list):
 
                 # find probe, object and exit ID in dependence of dID
@@ -357,12 +361,23 @@ def engine_iterate(self, num=1):
                 err_phot = prep.err_phot_gpu.get()
                 err_exit = prep.err_exit_gpu.get()
                 errs = np.ascontiguousarray(np.vstack([err_fourier, err_phot, err_exit]).T)
-                error.update(zip(prep.view_IDs, errs))
+                if self.p.record_local_error:
+                    local_error.update(zip(prep.view_IDs, errs))
+                else:
+                    reduced_error += errs.sum(axis=0)
+                    reduced_error_count += errs.shape[0]
 
+            if self.p.record_local_error:
+                error = local_error
+            else:
+                # Gather errors across all MPI ranks
+                error = parallel.allreduce(reduced_error)
+                count = parallel.allreduce(reduced_error_count)
+                error /= count
 
         # wait for the async transfers
         self.qu_dtoh.synchronize()
 
         self.error = error
         return error
 
     def position_update_local(self, prep, i):
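The pycuda engine receives the identical logic, keeping both GPU backends in sync. Assuming `record_local_error` is exposed in the engine parameter tree like the other options (its descriptor and default are not part of this diff), selecting the reduced-error behaviour from a reconstruction script would look roughly like this (engine name and surrounding setup are illustrative):

```python
from ptypy import utils as u

p = u.Param()
# ... scan and data parameters omitted ...
p.engines = u.Param()
p.engines.engine00 = u.Param()
p.engines.engine00.name = "EPIE_cupy"           # accelerated stochastic engine (assumed name)
p.engines.engine00.numiter = 100
p.engines.engine00.record_local_error = False   # store only the MPI-averaged error triple
```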
