Skip to content

Commit

Permalink
fix bug to add large tensor support
Browse files Browse the repository at this point in the history
Signed-off-by: kaixuan <[email protected]>
  • Loading branch information
kaixuanliu committed Mar 14, 2024
1 parent 0ae8cdf commit 3591822
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions pyg_lib/csrc/ops/cpu/radix_sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,16 @@ void radix_sort_kernel(K* input_keys,
V* input_values,
K* output_keys,
V* output_values,
int elements_count,
int* histogram,
int* histogram_ps,
int64_t elements_count,
int64_t* histogram,
int64_t* histogram_ps,
int pass) {
int tid = omp_get_thread_num();
int nthreads = omp_get_num_threads();
int elements_count_4 = elements_count / 4 * 4;
int64_t elements_count_4 = elements_count / 4 * 4;

Check warning on line 69 in pyg_lib/csrc/ops/cpu/radix_sort.h

View check run for this annotation

Codecov / codecov/patch

pyg_lib/csrc/ops/cpu/radix_sort.h#L69

Added line #L69 was not covered by tests

int* local_histogram = &histogram[RDX_HIST_SIZE * tid];
int* local_histogram_ps = &histogram_ps[RDX_HIST_SIZE * tid];
int64_t* local_histogram = &histogram[RDX_HIST_SIZE * tid];
int64_t* local_histogram_ps = &histogram_ps[RDX_HIST_SIZE * tid];

Check warning on line 72 in pyg_lib/csrc/ops/cpu/radix_sort.h

View check run for this annotation

Codecov / codecov/patch

pyg_lib/csrc/ops/cpu/radix_sort.h#L71-L72

Added lines #L71 - L72 were not covered by tests

// Step 1: compute histogram
for (int i = 0; i < RDX_HIST_SIZE; i++) {
Expand Down Expand Up @@ -97,7 +97,7 @@ void radix_sort_kernel(K* input_keys,
#pragma omp barrier
// Step 2: prefix sum
if (tid == 0) {
int sum = 0, prev_sum = 0;
int64_t sum = 0, prev_sum = 0;

Check warning on line 100 in pyg_lib/csrc/ops/cpu/radix_sort.h

View check run for this annotation

Codecov / codecov/patch

pyg_lib/csrc/ops/cpu/radix_sort.h#L100

Added line #L100 was not covered by tests
for (int bins = 0; bins < RDX_HIST_SIZE; bins++) {
for (int t = 0; t < nthreads; t++) {
sum += histogram[t * RDX_HIST_SIZE + bins];
Expand All @@ -123,7 +123,7 @@ void radix_sort_kernel(K* input_keys,
int bin_3 = (key_3 >> (pass * 8)) & 0xFF;
int bin_4 = (key_4 >> (pass * 8)) & 0xFF;

int pos;
int64_t pos;
pos = local_histogram_ps[bin_1]++;
output_keys[pos] = key_1;
output_values[pos] = input_values[i];
Expand All @@ -140,7 +140,7 @@ void radix_sort_kernel(K* input_keys,
if (tid == (nthreads - 1)) {
for (int64_t i = elements_count_4; i < elements_count; ++i) {
K key = input_keys[i];
int pos = local_histogram_ps[(key >> (pass * 8)) & 0xFF]++;
int64_t pos = local_histogram_ps[(key >> (pass * 8)) & 0xFF]++;

Check warning on line 143 in pyg_lib/csrc/ops/cpu/radix_sort.h

View check run for this annotation

Codecov / codecov/patch

pyg_lib/csrc/ops/cpu/radix_sort.h#L143

Added line #L143 was not covered by tests
output_keys[pos] = key;
output_values[pos] = input_values[i];
}
Expand All @@ -161,11 +161,11 @@ std::pair<K*, V*> radix_sort_parallel(K* inp_key_buf,
int64_t elements_count,
int64_t max_value) {
int maxthreads = omp_get_max_threads();
std::unique_ptr<int[]> histogram_tmp(new int[RDX_HIST_SIZE * maxthreads]);
std::unique_ptr<int[]> histogram_ps_tmp(
new int[RDX_HIST_SIZE * maxthreads + 1]);
int* histogram = histogram_tmp.get();
int* histogram_ps = histogram_ps_tmp.get();
std::unique_ptr<int64_t[]> histogram_tmp(new int64_t[RDX_HIST_SIZE * maxthreads]);
std::unique_ptr<int64_t[]> histogram_ps_tmp(
new int64_t[RDX_HIST_SIZE * maxthreads + 1]);
int64_t* histogram = histogram_tmp.get();
int64_t* histogram_ps = histogram_ps_tmp.get();

Check warning on line 168 in pyg_lib/csrc/ops/cpu/radix_sort.h

View check run for this annotation

Codecov / codecov/patch

pyg_lib/csrc/ops/cpu/radix_sort.h#L164-L168

Added lines #L164 - L168 were not covered by tests
if (max_value == 0) {
return std::make_pair(inp_key_buf, inp_value_buf);
}
Expand Down

0 comments on commit 3591822

Please sign in to comment.