diff --git a/pyg_lib/csrc/ops/cpu/radix_sort.h b/pyg_lib/csrc/ops/cpu/radix_sort.h index 8629958bb..86eb10afa 100644 --- a/pyg_lib/csrc/ops/cpu/radix_sort.h +++ b/pyg_lib/csrc/ops/cpu/radix_sort.h @@ -60,16 +60,16 @@ void radix_sort_kernel(K* input_keys, V* input_values, K* output_keys, V* output_values, - int elements_count, - int* histogram, - int* histogram_ps, + int64_t elements_count, + int64_t* histogram, + int64_t* histogram_ps, int pass) { int tid = omp_get_thread_num(); int nthreads = omp_get_num_threads(); - int elements_count_4 = elements_count / 4 * 4; + int64_t elements_count_4 = elements_count / 4 * 4; - int* local_histogram = &histogram[RDX_HIST_SIZE * tid]; - int* local_histogram_ps = &histogram_ps[RDX_HIST_SIZE * tid]; + int64_t* local_histogram = &histogram[RDX_HIST_SIZE * tid]; + int64_t* local_histogram_ps = &histogram_ps[RDX_HIST_SIZE * tid]; // Step 1: compute histogram for (int i = 0; i < RDX_HIST_SIZE; i++) { @@ -97,7 +97,7 @@ void radix_sort_kernel(K* input_keys, #pragma omp barrier // Step 2: prefix sum if (tid == 0) { - int sum = 0, prev_sum = 0; + int64_t sum = 0, prev_sum = 0; for (int bins = 0; bins < RDX_HIST_SIZE; bins++) { for (int t = 0; t < nthreads; t++) { sum += histogram[t * RDX_HIST_SIZE + bins]; @@ -123,7 +123,7 @@ void radix_sort_kernel(K* input_keys, int bin_3 = (key_3 >> (pass * 8)) & 0xFF; int bin_4 = (key_4 >> (pass * 8)) & 0xFF; - int pos; + int64_t pos; pos = local_histogram_ps[bin_1]++; output_keys[pos] = key_1; output_values[pos] = input_values[i]; @@ -140,7 +140,7 @@ void radix_sort_kernel(K* input_keys, if (tid == (nthreads - 1)) { for (int64_t i = elements_count_4; i < elements_count; ++i) { K key = input_keys[i]; - int pos = local_histogram_ps[(key >> (pass * 8)) & 0xFF]++; + int64_t pos = local_histogram_ps[(key >> (pass * 8)) & 0xFF]++; output_keys[pos] = key; output_values[pos] = input_values[i]; } @@ -161,11 +161,11 @@ std::pair radix_sort_parallel(K* inp_key_buf, int64_t elements_count, int64_t max_value) { 
int maxthreads = omp_get_max_threads(); - std::unique_ptr<int[]> histogram_tmp(new int[RDX_HIST_SIZE * maxthreads]); - std::unique_ptr<int[]> histogram_ps_tmp( - new int[RDX_HIST_SIZE * maxthreads + 1]); - int* histogram = histogram_tmp.get(); - int* histogram_ps = histogram_ps_tmp.get(); + std::unique_ptr<int64_t[]> histogram_tmp(new int64_t[RDX_HIST_SIZE * maxthreads]); + std::unique_ptr<int64_t[]> histogram_ps_tmp( + new int64_t[RDX_HIST_SIZE * maxthreads + 1]); + int64_t* histogram = histogram_tmp.get(); + int64_t* histogram_ps = histogram_ps_tmp.get(); if (max_value == 0) { return std::make_pair(inp_key_buf, inp_value_buf); }