Skip to content

Commit

Permalink
fix: Protect local memory stores with shmem_barrier_all()
Browse files Browse the repository at this point in the history
  • Loading branch information
dalcinl committed Jul 30, 2023
1 parent b6cfc4e commit 1400715
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/libshmem/fallback.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ static inline long *_py_shmem_pSync()
_py_shmem_pSync_array = (long *) shmem_malloc(SHMEM_SYNC_SIZE * sizeof(long));
for (int i = 0; i < SHMEM_SYNC_SIZE; i++)
_py_shmem_pSync_array[i] = SHMEM_SYNC_VALUE;
shmem_sync_all();
shmem_barrier_all();
}
return _py_shmem_pSync_array;
}
Expand All @@ -328,7 +328,7 @@ static inline void *_py_shmem_pWrk(size_t nreduce, size_t eltsize)
shmem_free(_py_shmem_pWrk_array);
_py_shmem_pWrk_size = wrk_size;
_py_shmem_pWrk_array = shmem_malloc(wrk_size);
shmem_sync_all();
shmem_barrier_all();
}
return _py_shmem_pWrk_array;
}
Expand Down
2 changes: 1 addition & 1 deletion src/libshmem/memalloc.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ static void *shmem_py_alloc(size_t size, size_t align, long hints, int clear)
}
if (clear) {
memset(ptr, 0, size);
shmem_sync_all();
shmem_barrier_all();
}
return ptr;
}
Expand Down
6 changes: 3 additions & 3 deletions src/shmem4py/shmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,7 @@ def array(
if tmp.ndim > 1:
a.strides = tmp.strides
np.copyto(a, tmp, casting='no')
lib.shmem_sync_all()
lib.shmem_barrier_all()
return a


Expand Down Expand Up @@ -969,7 +969,7 @@ def ones(
"""
a = new_array(shape, dtype, order, align=align, hints=hints, clear=False)
np.copyto(a, 1, casting='unsafe')
lib.shmem_sync_all()
lib.shmem_barrier_all()
return a


Expand Down Expand Up @@ -1001,7 +1001,7 @@ def full(
dtype = np.array(fill_value).dtype
a = new_array(shape, dtype, order, align=align, hints=hints, clear=False)
np.copyto(a, fill_value, casting='unsafe')
lib.shmem_sync_all()
lib.shmem_barrier_all()
return a


Expand Down

0 comments on commit 1400715

Please sign in to comment.