From 52e71aad426873ee143d7234b2e4e53e270fdd72 Mon Sep 17 00:00:00 2001 From: Joseph Thomas Guman Date: Sun, 14 Apr 2024 15:00:17 -0700 Subject: [PATCH 1/3] Initial implementation of np.unique(return_index=True) Signed-off-by: Joseph Guman --- cunumeric/array.py | 17 +- cunumeric/config.py | 2 + cunumeric/deferred.py | 61 ++++++- cunumeric/eager.py | 16 +- cunumeric/module.py | 9 +- cunumeric/thunk.py | 4 +- cunumeric_cpp.cmake | 2 + src/cunumeric/cunumeric_c.h | 1 + src/cunumeric/mapper.cc | 3 +- src/cunumeric/set/unique.cc | 37 ++-- src/cunumeric/set/unique.cu | 171 ++++++++++++++++--- src/cunumeric/set/unique_omp.cc | 98 ++++++++--- src/cunumeric/set/unique_reduce_template.inl | 85 ++++++--- src/cunumeric/set/unique_template.inl | 43 ++++- src/cunumeric/set/unzip_indices.cc | 35 ++++ src/cunumeric/set/unzip_indices.h | 34 ++++ src/cunumeric/set/unzip_indices_omp.cc | 29 ++++ src/cunumeric/set/unzip_indices_template.inl | 85 +++++++++ src/cunumeric/set/zip_indices.h | 32 ++++ tests/integration/test_unique.py | 20 ++- 20 files changed, 667 insertions(+), 117 deletions(-) create mode 100644 src/cunumeric/set/unzip_indices.cc create mode 100644 src/cunumeric/set/unzip_indices.h create mode 100644 src/cunumeric/set/unzip_indices_omp.cc create mode 100644 src/cunumeric/set/unzip_indices_template.inl create mode 100644 src/cunumeric/set/zip_indices.h diff --git a/cunumeric/array.py b/cunumeric/array.py index 3b628ae4d..177277cf2 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -4152,7 +4152,9 @@ def view( writeable=self._writeable, ) - def unique(self) -> ndarray: + def unique( + self, return_index: bool = False + ) -> Union[ndarray, tuple[ndarray, ndarray]]: """a.unique() Find the unique elements of an array. @@ -4168,8 +4170,17 @@ def unique(self) -> ndarray: Multiple GPUs, Multiple CPUs """ - thunk = self._thunk.unique() - return ndarray(shape=thunk.shape, thunk=thunk) + thunk = self._thunk.unique(return_index) + if return_index: + if TYPE_CHECKING: + thunk = cast(tuple[NumPyThunk, NumPyThunk], thunk) + return ndarray(shape=thunk[0].shape, thunk=thunk[0]), ndarray( + shape=thunk[1].shape, thunk=thunk[1] + ) + else: + if TYPE_CHECKING: + thunk = cast(NumPyThunk, thunk) + return ndarray(shape=thunk.shape, thunk=thunk) @classmethod def _get_where_thunk( diff --git a/cunumeric/config.py b/cunumeric/config.py index c18d36f4b..1175aa71b 100644 --- a/cunumeric/config.py +++ b/cunumeric/config.py @@ -212,6 +212,7 @@ class _CunumericSharedLib: CUNUMERIC_UNARY_RED: int CUNUMERIC_UNIQUE: int CUNUMERIC_UNIQUE_REDUCE: int + CUNUMERIC_UNZIP_INDICES: int CUNUMERIC_UNLOAD_CUDALIBS: int CUNUMERIC_UNPACKBITS: int CUNUMERIC_UOP_ABSOLUTE: int @@ -378,6 +379,7 @@ class CuNumericOpCode(IntEnum): UNARY_RED = _cunumeric.CUNUMERIC_UNARY_RED UNIQUE = _cunumeric.CUNUMERIC_UNIQUE UNIQUE_REDUCE = _cunumeric.CUNUMERIC_UNIQUE_REDUCE + UNZIP = _cunumeric.CUNUMERIC_UNZIP_INDICES UNLOAD_CUDALIBS = _cunumeric.CUNUMERIC_UNLOAD_CUDALIBS UNPACKBITS = _cunumeric.CUNUMERIC_UNPACKBITS WHERE = _cunumeric.CUNUMERIC_WHERE diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 9d9fa963d..9ae1a5154 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -3487,25 +3487,68 @@ def scan( assert self.shape == swapped.shape self.copy(swapped, deep=True) - def unique(self) -> NumPyThunk: - result = self.runtime.create_unbound_thunk(self.base.type) - + def unique( + self, return_index: bool = False + ) -> Union[NumPyThunk, tuple[NumPyThunk, Optional[NumPyThunk]]]: task = 
self.context.create_auto_task(CuNumericOpCode.UNIQUE) - task.add_output(result.base) task.add_input(self.base) + task.add_scalar_arg(return_index, ty.bool_) + result = None + # Assuming legate core will always choose GPU variant if self.runtime.num_gpus > 0: task.add_nccl_communicator() + result = self.runtime.create_unbound_thunk(self.base.type) + elif return_index: + result = self.runtime.create_unbound_thunk( + ty.struct_type( + [ + self.base.type, + ty.int64, + ], + True, + ) + ) + else: + result = self.runtime.create_unbound_thunk(self.base.type) + task.add_output(result.base) + + returned_indices = None + if return_index: + returned_indices = self.runtime.create_unbound_thunk(ty.int64) + if self.runtime.num_gpus > 0: + task.add_output(returned_indices.base) + + for i in range(self.ndim): + task.add_scalar_arg(self.shape[i], ty.int32) task.execute() - if self.runtime.num_gpus == 0 and self.runtime.num_procs > 1: - result.base = self.context.tree_reduce( - CuNumericOpCode.UNIQUE_REDUCE, result.base - ) + if self.runtime.num_gpus == 0: + if self.runtime.num_procs > 1: + result.base = self.context.tree_reduce( + CuNumericOpCode.UNIQUE_REDUCE, + result.base, + scalar_args=[(return_index, ty.bool_)], + ) + if return_index: + task = self.context.create_auto_task(CuNumericOpCode.UNZIP) + task.add_input(result.base) - return result + result = self.runtime.create_unbound_thunk(self.base.type) + + task.add_output(result.base) + + returned_indices = cast(DeferredArray, returned_indices) + task.add_output(returned_indices.base) + + task.execute() + + if return_index: + return result, returned_indices + else: + return result @auto_convert("rhs", "v") def searchsorted(self, rhs: Any, v: Any, side: SortSide = "left") -> None: diff --git a/cunumeric/eager.py b/cunumeric/eager.py index 4e6e504c2..86b4d33b8 100644 --- a/cunumeric/eager.py +++ b/cunumeric/eager.py @@ -1713,11 +1713,21 @@ def scan( else: raise RuntimeError(f"unsupported scan op {op}") - def unique(self) -> NumPyThunk: + def unique( + self, return_index: bool = False + ) -> Union[NumPyThunk, tuple[NumPyThunk, Optional[NumPyThunk]]]: if self.deferred is not None: - return self.deferred.unique() + return self.deferred.unique(return_index=return_index) else: - return EagerArray(self.runtime, np.unique(self.array)) + if return_index: + np_values, np_indices = np.unique( + self.array, return_index=return_index + ) + return EagerArray(self.runtime, np_values), EagerArray( + self.runtime, np_indices + ) + else: + return EagerArray(self.runtime, np.unique(self.array)) def create_window(self, op_code: WindowOpCode, M: int, *args: Any) -> None: if self.deferred is not None: diff --git a/cunumeric/module.py b/cunumeric/module.py index 424f89df4..cd8ef8006 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -6803,7 +6803,7 @@ def unique( return_inverse: bool = False, return_counts: bool = False, axis: Optional[int] = None, -) -> ndarray: +) -> Union[ndarray, tuple[ndarray, ndarray]]: """ Find the unique elements of an array. @@ -6868,12 +6868,13 @@ def unique( `axis` is also not handled currently. 
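    With ``return_index=True``, the call returns a tuple ``(values, indices)``
    where ``indices[i]`` is the position in the flattened input of the first
    occurrence of ``values[i]``, matching NumPy's behavior. A small
    illustration of the intended semantics (a sketch, not output captured
    from this build):

    >>> import cunumeric as num
    >>> a = num.array([1, 2, 6, 4, 2, 3, 2])
    >>> values, indices = num.unique(a, return_index=True)
    >>> values
    array([1, 2, 3, 4, 6])
    >>> indices
    array([0, 1, 5, 3, 2])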
""" - if _builtin_any((return_index, return_inverse, return_counts, axis)): + if _builtin_any((return_inverse, return_counts, axis)): raise NotImplementedError( - "Keyword arguments for `unique` are not yet supported" + "Keyword arguments for `unique` outside" + " of return_index are not yet supported" ) - return ar.unique() + return ar.unique(return_index) ################################## diff --git a/cunumeric/thunk.py b/cunumeric/thunk.py index 68aafb6c9..b1cb450dc 100644 --- a/cunumeric/thunk.py +++ b/cunumeric/thunk.py @@ -723,7 +723,9 @@ def scan( ... @abstractmethod - def unique(self) -> NumPyThunk: + def unique( + self, return_index: bool = False + ) -> Union[NumPyThunk, tuple[NumPyThunk, Optional[NumPyThunk]]]: ... @abstractmethod diff --git a/cunumeric_cpp.cmake b/cunumeric_cpp.cmake index be5c0fbe6..5bb9b52b7 100644 --- a/cunumeric_cpp.cmake +++ b/cunumeric_cpp.cmake @@ -164,6 +164,7 @@ list(APPEND cunumeric_SOURCES src/cunumeric/search/nonzero.cc src/cunumeric/set/unique.cc src/cunumeric/set/unique_reduce.cc + src/cunumeric/set/unzip_indices.cc src/cunumeric/stat/bincount.cc src/cunumeric/convolution/convolve.cc src/cunumeric/transform/flip.cc @@ -217,6 +218,7 @@ if(Legion_USE_OpenMP) src/cunumeric/search/nonzero_omp.cc src/cunumeric/set/unique_omp.cc src/cunumeric/set/unique_reduce_omp.cc + src/cunumeric/set/unzip_indices_omp.cc src/cunumeric/stat/bincount_omp.cc src/cunumeric/convolution/convolve_omp.cc src/cunumeric/transform/flip_omp.cc diff --git a/src/cunumeric/cunumeric_c.h b/src/cunumeric/cunumeric_c.h index b38ab6620..0016a9be5 100644 --- a/src/cunumeric/cunumeric_c.h +++ b/src/cunumeric/cunumeric_c.h @@ -72,6 +72,7 @@ enum CuNumericOpCode { CUNUMERIC_UNARY_RED, CUNUMERIC_UNIQUE, CUNUMERIC_UNIQUE_REDUCE, + CUNUMERIC_UNZIP_INDICES, CUNUMERIC_UNLOAD_CUDALIBS, CUNUMERIC_UNPACKBITS, CUNUMERIC_WHERE, diff --git a/src/cunumeric/mapper.cc b/src/cunumeric/mapper.cc index 5fd36bceb..a542ce2e9 100644 --- a/src/cunumeric/mapper.cc +++ b/src/cunumeric/mapper.cc @@ -105,7 +105,8 @@ std::vector CuNumericMapper::store_mappings( } case CUNUMERIC_MATMUL: case CUNUMERIC_MATVECMUL: - case CUNUMERIC_UNIQUE_REDUCE: { + case CUNUMERIC_UNIQUE_REDUCE: + case CUNUMERIC_UNZIP_INDICES: { // TODO: Our actual requirements are a little less strict than this; we require each array or // vector to have a stride of 1 on at least one dimension. 
std::vector mappings; diff --git a/src/cunumeric/set/unique.cc b/src/cunumeric/set/unique.cc index 7aa09d0e5..fc4d20cf9 100644 --- a/src/cunumeric/set/unique.cc +++ b/src/cunumeric/set/unique.cc @@ -25,25 +25,42 @@ template struct UniqueImplBody { using VAL = legate_type_of; - void operator()(Array& output, + void operator()(std::vector& outputs, const AccessorRO& in, const Pitches& pitches, const Rect& rect, const size_t volume, const std::vector& comms, const DomainPoint& point, - const Domain& launch_domain) + const Domain& launch_domain, + const bool return_index, + const DomainPoint& parent_point) { - std::set dedup_set; + auto& output = outputs[0]; + if (return_index) { + std::set, IndexEquality> dedup_set; + for (size_t idx = 0; idx < volume; ++idx) { + auto p = pitches.unflatten(idx, rect.lo); + auto value = in[p]; + int64_t index = rowwise_linearize(DIM, p, parent_point); - for (size_t idx = 0; idx < volume; ++idx) { - auto p = pitches.unflatten(idx, rect.lo); - dedup_set.insert(in[p]); - } + dedup_set.insert(ZippedIndex({value, index})); + } + + auto result = output.create_output_buffer, 1>(dedup_set.size(), true); + size_t pos = 0; + for (auto e : dedup_set) { result[pos++] = e; } + } else { + std::set dedup_set; + for (size_t idx = 0; idx < volume; ++idx) { + auto p = pitches.unflatten(idx, rect.lo); + dedup_set.insert(in[p]); + } - auto result = output.create_output_buffer(dedup_set.size(), true); - size_t pos = 0; - for (auto e : dedup_set) result[pos++] = e; + auto result = output.create_output_buffer(dedup_set.size(), true); + size_t pos = 0; + for (auto e : dedup_set) result[pos++] = e; + } } }; diff --git a/src/cunumeric/set/unique.cu b/src/cunumeric/set/unique.cu index 302077c5f..b770cdbae 100644 --- a/src/cunumeric/set/unique.cu +++ b/src/cunumeric/set/unique.cu @@ -42,19 +42,42 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) out[offset] = accessor[point]; } +template +__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + fill_subset_indices(int64_t* out, + const Point lo, + const Pitches pitches, + const size_t volume, + const DomainPoint parent_point) +{ + size_t offset = blockIdx.x * blockDim.x + threadIdx.x; + if (offset >= volume) return; + auto point = pitches.unflatten(offset, lo); + int multiplier = 1; + size_t index = 0; + for (int i = DIM - 1; i >= 0; i--) { + index += point[i] * multiplier; + multiplier *= parent_point[i]; + } + out[offset] = index; +} + template using Piece = std::pair, size_t>; auto get_aligned_size = [](auto size) { return std::max(16, (size + 15) / 16 * 16); }; template -static Piece tree_reduce(Array& output, - Piece my_piece, - size_t my_id, - size_t num_ranks, - cudaStream_t stream, - ncclComm_t* comm) +static std::pair, Piece> tree_reduce(std::vector& outputs, + Piece my_piece, + Piece indices, + size_t my_id, + size_t num_ranks, + cudaStream_t stream, + ncclComm_t* comm, + bool return_index) { + auto& output = outputs[0]; size_t remaining = num_ranks; size_t radix = 2; auto all_sizes = create_buffer(num_ranks, Memory::Z_COPY_MEM); @@ -67,6 +90,7 @@ static Piece tree_reduce(Array& output, CHECK_CUDA(cudaStreamSynchronize(stream)); Piece other_piece; + Piece other_index; size_t offset = radix / 2; bool received_something = false; CHECK_NCCL(ncclGroupStart()); @@ -83,6 +107,14 @@ static Piece tree_reduce(Array& output, other_piece.first = create_buffer(buf_size); CHECK_NCCL( ncclRecv(other_piece.first.ptr(0), recv_size, ncclInt8, other_id, *comm, stream)); + if (return_index) { + 
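+          // Receive the sender's index buffer alongside its values: it holds the
+          // same element count but is sized and aligned for int64_t, pairing with
+          // the matching ncclSend of `indices` in the sender branch below.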
other_index.second = other_size; + auto recv_size_index = get_aligned_size(other_size * sizeof(int64_t)); + auto buf_size_index = (recv_size_index + sizeof(int64_t) - 1) / sizeof(int64_t); + other_index.first = create_buffer(buf_size_index); + CHECK_NCCL( + ncclRecv(other_index.first.ptr(0), recv_size_index, ncclInt8, other_id, *comm, stream)); + } received_something = true; } } else if (my_id % radix == offset) // This is one of the senders @@ -90,6 +122,11 @@ static Piece tree_reduce(Array& output, auto other_id = my_id - offset; auto send_size = get_aligned_size(my_piece.second * sizeof(VAL)); CHECK_NCCL(ncclSend(my_piece.first.ptr(0), send_size, ncclInt8, other_id, *comm, stream)); + if (return_index) { + auto send_size_index = get_aligned_size(indices.second * sizeof(int64_t)); + CHECK_NCCL( + ncclSend(indices.first.ptr(0), send_size_index, ncclInt8, other_id, *comm, stream)); + } } CHECK_NCCL(ncclGroupEnd()); @@ -101,19 +138,44 @@ static Piece tree_reduce(Array& output, auto p_mine = my_piece.first.ptr(0); auto p_other = other_piece.first.ptr(0); - thrust::merge(DEFAULT_POLICY.on(stream), - p_mine, - p_mine + my_piece.second, - p_other, - p_other + other_piece.second, - p_merged); - auto* end = thrust::unique(DEFAULT_POLICY.on(stream), p_merged, p_merged + merged_size); + Buffer merged_index; + if (return_index) { + merged_index = create_buffer(merged_size); + auto p_merged_index = merged_index.ptr(0); + auto p_mine_index = indices.first.ptr(0); + auto p_other_index = other_index.first.ptr(0); + + auto my_zip = thrust::make_zip_iterator(thrust::make_tuple(p_mine, p_mine_index)); + auto other_zip = thrust::make_zip_iterator(thrust::make_tuple(p_other, p_other_index)); + auto final_zip = thrust::make_zip_iterator(thrust::make_tuple(p_merged, p_merged_index)); + + thrust::merge(DEFAULT_POLICY.on(stream), + my_zip, + my_zip + my_piece.second, + other_zip, + other_zip + other_piece.second, + final_zip); + + auto end = thrust::unique_by_key( + DEFAULT_POLICY.on(stream), p_merged, p_merged + merged_size, p_merged_index); + + my_piece.second = end.first - p_merged; + indices.second = my_piece.second; + } else { + thrust::merge(DEFAULT_POLICY.on(stream), + p_mine, + p_mine + my_piece.second, + p_other, + p_other + other_piece.second, + p_merged); + auto* end = thrust::unique(DEFAULT_POLICY.on(stream), p_merged, p_merged + merged_size); + my_piece.second = end - p_merged; + } // Make sure we release the memory so that we can reuse it my_piece.first.destroy(); other_piece.first.destroy(); - my_piece.second = end - p_merged; auto buf_size = (get_aligned_size(my_piece.second * sizeof(VAL)) + sizeof(VAL) - 1) / sizeof(VAL); assert(my_piece.second <= buf_size); @@ -125,6 +187,24 @@ static Piece tree_reduce(Array& output, cudaMemcpyDeviceToDevice, stream)); merged.destroy(); + + if (return_index) { + indices.first.destroy(); + other_index.first.destroy(); + + auto buf_size_index = + (get_aligned_size(my_piece.second * sizeof(int64_t)) + sizeof(int64_t) - 1) / + sizeof(int64_t); + assert(my_piece.second <= buf_size_index); + indices.first = outputs[1].create_output_buffer(buf_size_index); + + CHECK_CUDA(cudaMemcpyAsync(indices.first.ptr(0), + merged_index.ptr(0), + sizeof(int64_t) * my_piece.second, + cudaMemcpyDeviceToDevice, + stream)); + merged_index.destroy(); + } } remaining = (remaining + 1) / 2; @@ -134,30 +214,45 @@ static Piece tree_reduce(Array& output, if (my_id != 0) { my_piece.second = 0; my_piece.first = output.create_output_buffer(0); + indices.second = 0; + indices.first = 
output.create_output_buffer(0); } - return my_piece; + return {my_piece, indices}; } template struct UniqueImplBody { using VAL = legate_type_of; - void operator()(Array& output, + void operator()(std::vector& outputs, const AccessorRO& in, const Pitches& pitches, const Rect& rect, const size_t volume, const std::vector& comms, const DomainPoint& point, - const Domain& launch_domain) + const Domain& launch_domain, + const bool return_index, + const DomainPoint& parent_point) { - auto stream = get_cached_stream(); + auto& output = outputs[0]; + auto stream = get_cached_stream(); // Make a copy of the input as we're going to sort it auto temp = create_buffer(volume); VAL* ptr = temp.ptr(0); VAL* end = ptr; + + int64_t* index_ptr = nullptr; + if (return_index) { + auto index_temp = create_buffer(volume); + index_ptr = index_temp.ptr(0); + const size_t num_blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + fill_subset_indices<<>>( + index_ptr, rect.lo, pitches, volume, parent_point); + } + if (volume > 0) { if (in.accessor.is_dense_arbitrary(rect)) { auto* src = in.ptr(rect.lo); @@ -170,30 +265,58 @@ struct UniqueImplBody { } CHECK_CUDA_STREAM(stream); - // Find unique values - thrust::sort(DEFAULT_POLICY.on(stream), ptr, ptr + volume); - end = thrust::unique(DEFAULT_POLICY.on(stream), ptr, ptr + volume); + if (return_index) { + // Find unique values with corresponding index + auto zip_start = thrust::make_zip_iterator(thrust::make_tuple(ptr, index_ptr)); + + thrust::sort(DEFAULT_POLICY.on(stream), zip_start, zip_start + volume); + auto tuple_end = + thrust::unique_by_key(DEFAULT_POLICY.on(stream), ptr, ptr + volume, index_ptr); + end = tuple_end.first; + } else { + // Find unique values + thrust::sort(DEFAULT_POLICY.on(stream), ptr, ptr + volume); + end = thrust::unique(DEFAULT_POLICY.on(stream), ptr, ptr + volume); + } } Piece result; - result.second = end - ptr; + Piece indices; + result.second = end - ptr; + indices.second = end - ptr; auto buf_size = (get_aligned_size(result.second * sizeof(VAL)) + sizeof(VAL) - 1) / sizeof(VAL); assert(end - ptr <= buf_size); result.first = output.create_output_buffer(buf_size); - if (result.second > 0) + if (return_index) { + auto buf_size = + (get_aligned_size(result.second * sizeof(int64_t)) + sizeof(int64_t) - 1) / sizeof(int64_t); + indices.first = outputs[1].create_output_buffer(buf_size); + } + if (result.second > 0) { CHECK_CUDA(cudaMemcpyAsync( result.first.ptr(0), ptr, sizeof(VAL) * result.second, cudaMemcpyDeviceToDevice, stream)); + if (return_index) + CHECK_CUDA(cudaMemcpyAsync(indices.first.ptr(0), + index_ptr, + sizeof(int64_t) * indices.second, + cudaMemcpyDeviceToDevice, + stream)); + } if (comms.size() > 0) { // The launch domain is 1D because of the output region assert(point.dim == 1); auto comm = comms[0].get(); - result = tree_reduce(output, result, point[0], launch_domain.get_volume(), stream, comm); + auto ret = tree_reduce( + outputs, result, indices, point[0], launch_domain.get_volume(), stream, comm, return_index); + result = ret.first; + if (return_index) indices = ret.second; } CHECK_CUDA_STREAM(stream); // Finally we pack the result output.bind_data(result.first, Point<1>(result.second)); + if (return_index) { outputs[1].bind_data(indices.first, Point<1>(indices.second)); } } }; diff --git a/src/cunumeric/set/unique_omp.cc b/src/cunumeric/set/unique_omp.cc index 37a86582b..96b2a60e3 100644 --- a/src/cunumeric/set/unique_omp.cc +++ b/src/cunumeric/set/unique_omp.cc @@ -27,48 +27,96 @@ template struct 
UniqueImplBody { using VAL = legate_type_of; - void operator()(Array& output, + void operator()(std::vector& outputs, const AccessorRO& in, const Pitches& pitches, const Rect& rect, const size_t volume, const std::vector& comms, const DomainPoint& point, - const Domain& launch_domain) + const Domain& launch_domain, + const bool return_index, + const DomainPoint& parent_point) { + auto& output = outputs[0]; const auto max_threads = omp_get_max_threads(); - std::vector> dedup_set(max_threads); + + if (return_index) { + std::vector, IndexEquality>> dedup_set(max_threads); #pragma omp parallel - { - const int tid = omp_get_thread_num(); - auto& my_dedup_set = dedup_set[tid]; + { + const int tid = omp_get_thread_num(); + auto& my_dedup_set = dedup_set[tid]; #pragma omp for schedule(static) - for (size_t idx = 0; idx < volume; ++idx) { - auto p = pitches.unflatten(idx, rect.lo); - my_dedup_set.insert(in[p]); + for (size_t idx = 0; idx < volume; ++idx) { + auto p = pitches.unflatten(idx, rect.lo); + auto value = in[p]; + int64_t index = rowwise_linearize(DIM, p, parent_point); + + my_dedup_set.insert(ZippedIndex({value, index})); + } } - } - size_t remaining = max_threads; - size_t radix = (max_threads + 1) / 2; - while (remaining > 1) { + size_t remaining = max_threads; + size_t radix = (max_threads + 1) / 2; + while (remaining > 1) { #pragma omp for schedule(static, 1) - for (size_t idx = 0; idx < radix; ++idx) { - if (idx + radix < remaining) { - auto& my_set = dedup_set[idx]; - auto& other_set = dedup_set[idx + radix]; - my_set.insert(other_set.begin(), other_set.end()); + for (size_t idx = 0; idx < radix; ++idx) { + if (idx + radix < remaining) { + auto& my_set = dedup_set[idx]; + auto& other_set = dedup_set[idx + radix]; + + for (auto e : other_set) { + auto temp = my_set.find(e); + if (temp != my_set.end() && e.index < temp->index) my_set.erase(temp); + my_set.insert(e); + } + } + } + remaining = radix; + radix = (radix + 1) / 2; + } + + auto& final_dedup_set = dedup_set[0]; + auto result = output.create_output_buffer, 1>(final_dedup_set.size(), true); + size_t pos = 0; + for (auto e : final_dedup_set) result[pos++] = e; + + } else { + std::vector> dedup_set(max_threads); + +#pragma omp parallel + { + const int tid = omp_get_thread_num(); + auto& my_dedup_set = dedup_set[tid]; +#pragma omp for schedule(static) + for (size_t idx = 0; idx < volume; ++idx) { + auto p = pitches.unflatten(idx, rect.lo); + my_dedup_set.insert(in[p]); } } - remaining = radix; - radix = (radix + 1) / 2; - } - auto& final_dedup_set = dedup_set[0]; - auto result = output.create_output_buffer(final_dedup_set.size(), true); - size_t pos = 0; - for (auto e : final_dedup_set) result[pos++] = e; + size_t remaining = max_threads; + size_t radix = (max_threads + 1) / 2; + while (remaining > 1) { +#pragma omp for schedule(static, 1) + for (size_t idx = 0; idx < radix; ++idx) { + if (idx + radix < remaining) { + auto& my_set = dedup_set[idx]; + auto& other_set = dedup_set[idx + radix]; + my_set.insert(other_set.begin(), other_set.end()); + } + } + remaining = radix; + radix = (radix + 1) / 2; + } + + auto& final_dedup_set = dedup_set[0]; + auto result = output.create_output_buffer(final_dedup_set.size(), true); + size_t pos = 0; + for (auto e : final_dedup_set) result[pos++] = e; + } } }; diff --git a/src/cunumeric/set/unique_reduce_template.inl b/src/cunumeric/set/unique_reduce_template.inl index c38d289eb..6a2fbcebe 100644 --- a/src/cunumeric/set/unique_reduce_template.inl +++ b/src/cunumeric/set/unique_reduce_template.inl 
@@ -19,20 +19,33 @@ // Useful for IDEs #include "cunumeric/set/unique_reduce.h" #include "cunumeric/pitches.h" +#include "cunumeric/set/zip_indices.h" #include #include #include #include +#include namespace cunumeric { using namespace legate; +template +struct IndexFreeEqual { + bool operator()(const ZippedIndex& a, const ZippedIndex& b) + { + return a.value == b.value; + } +}; + template struct UniqueReduceImpl { template - void operator()(Array& output, std::vector& input_arrs, const exe_pol_t& exe_pol) + void operator()(Array& output, + std::vector& input_arrs, + const exe_pol_t& exe_pol, + bool return_index) { using VAL = legate_type_of; @@ -41,33 +54,65 @@ struct UniqueReduceImpl { auto shape = input_arr.shape<1>(); res_size += shape.hi[0] - shape.lo[0] + 1; } - auto result = output.create_output_buffer(Point<1>(res_size)); - VAL* res_ptr = result.ptr(0); - size_t offset = 0; - for (auto& input_arr : input_arrs) { - size_t strides[1]; - Rect<1> shape = input_arr.shape<1>(); - size_t volume = shape.volume(); - const VAL* in_ptr = input_arr.read_accessor(shape).ptr(shape, strides); - assert(shape.volume() <= 1 || strides[0] == 1); - thrust::copy(exe_pol, in_ptr, in_ptr + volume, res_ptr + offset); - offset += volume; - } - assert(offset == res_size); + // Is splitting into two completely distinct cases the most concise way to do this? + if (return_index) { + auto result = output.create_output_buffer, 1>(Point<1>(res_size)); + ZippedIndex* res_ptr = result.ptr(0); - thrust::sort(exe_pol, res_ptr, res_ptr + res_size); - VAL* actual_end = thrust::unique(exe_pol, res_ptr, res_ptr + res_size); - output.bind_data(result, Point<1>(actual_end - res_ptr)); + size_t offset = 0; + for (auto& input_arr : input_arrs) { + size_t strides[1]; + Rect<1> shape = input_arr.shape<1>(); + size_t volume = shape.volume(); + const ZippedIndex* in_ptr = + input_arr.read_accessor, 1>(shape).ptr(shape, strides); + assert(shape.volume() <= 1 || strides[0] == 1); + thrust::copy(exe_pol, in_ptr, in_ptr + volume, res_ptr + offset); + offset += volume; + } + assert(offset == res_size); + + thrust::sort(exe_pol, res_ptr, res_ptr + res_size, ZippedComparator()); + ZippedIndex* actual_end = + thrust::unique(exe_pol, res_ptr, res_ptr + res_size, IndexFreeEqual()); + output.bind_data(result, Point<1>(actual_end - res_ptr)); + } else { + auto result = output.create_output_buffer(Point<1>(res_size)); + VAL* res_ptr = result.ptr(0); + + size_t offset = 0; + for (auto& input_arr : input_arrs) { + size_t strides[1]; + Rect<1> shape = input_arr.shape<1>(); + size_t volume = shape.volume(); + const VAL* in_ptr = input_arr.read_accessor(shape).ptr(shape, strides); + assert(shape.volume() <= 1 || strides[0] == 1); + thrust::copy(exe_pol, in_ptr, in_ptr + volume, res_ptr + offset); + offset += volume; + } + assert(offset == res_size); + + thrust::sort(exe_pol, res_ptr, res_ptr + res_size); + VAL* actual_end = thrust::unique(exe_pol, res_ptr, res_ptr + res_size); + output.bind_data(result, Point<1>(actual_end - res_ptr)); + } } }; template static void unique_reduce_template(TaskContext& context, const exe_pol_t& exe_pol) { - auto& inputs = context.inputs(); - auto& output = context.outputs()[0]; - type_dispatch(output.code(), UniqueReduceImpl{}, output, inputs, exe_pol); + auto& inputs = context.inputs(); + auto& output = context.outputs()[0]; + bool return_index = context.scalars()[0].value(); + Type::Code code{output.code()}; + if (return_index) { + assert(Type::Code::STRUCT == code); + auto& field_type = 
static_cast(output.type()).field_type(0); + code = field_type.code; + } + type_dispatch(code, UniqueReduceImpl{}, output, inputs, exe_pol, return_index); } } // namespace cunumeric diff --git a/src/cunumeric/set/unique_template.inl b/src/cunumeric/set/unique_template.inl index 1ab1a7e1f..7409e177a 100644 --- a/src/cunumeric/set/unique_template.inl +++ b/src/cunumeric/set/unique_template.inl @@ -19,22 +19,33 @@ // Useful for IDEs #include "cunumeric/set/unique.h" #include "cunumeric/pitches.h" +#include "cunumeric/set/zip_indices.h" namespace cunumeric { using namespace legate; +template +struct IndexEquality { + bool operator()(const ZippedIndex& a, const ZippedIndex& b) const + { + return a.value < b.value; + } +}; + template struct UniqueImplBody; template struct UniqueImpl { template - void operator()(Array& output, + void operator()(std::vector& outputs, Array& input, std::vector& comms, const DomainPoint& point, - const Domain& launch_domain) const + const Domain& launch_domain, + const bool return_index, + std::vector& parent_extents) const { using VAL = legate_type_of; @@ -42,26 +53,42 @@ struct UniqueImpl { Pitches pitches; size_t volume = pitches.flatten(rect); + Point parent_point; + if (return_index) { + for (int i = 0; i < DIM; i++) { parent_point[i] = parent_extents[i]; } + } + auto in = input.read_accessor(rect); UniqueImplBody()( - output, in, pitches, rect, volume, comms, point, launch_domain); + outputs, in, pitches, rect, volume, comms, point, launch_domain, return_index, parent_point); } }; template static void unique_template(TaskContext& context) { - auto& input = context.inputs()[0]; - auto& output = context.outputs()[0]; - auto& comms = context.communicators(); + auto& input = context.inputs()[0]; + auto& outputs = context.outputs(); + auto& comms = context.communicators(); + bool return_index = context.scalars()[0].value(); + if (outputs.size() > 1) { assert(return_index); } + std::vector parent_extents(input.dim()); + if (return_index) { + for (int i = 0; i < parent_extents.size(); i++) { + parent_extents[i] = context.scalars()[1 + i].value(); + } + } + double_dispatch(input.dim(), input.code(), UniqueImpl{}, - output, + outputs, input, comms, context.get_task_index(), - context.get_launch_domain()); + context.get_launch_domain(), + return_index, + parent_extents); } } // namespace cunumeric diff --git a/src/cunumeric/set/unzip_indices.cc b/src/cunumeric/set/unzip_indices.cc new file mode 100644 index 000000000..4c89352c7 --- /dev/null +++ b/src/cunumeric/set/unzip_indices.cc @@ -0,0 +1,35 @@ +/* Copyright 2022 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "cunumeric/set/unzip_indices.h" +#include "cunumeric/set/unzip_indices_template.inl" + +namespace cunumeric { + +/*static*/ void UnzipIndicesTask::cpu_variant(TaskContext& context) +{ + unzip_indices_template(context, thrust::host); +} + +namespace // unnamed +{ +static void __attribute__((constructor)) register_tasks(void) +{ + UnzipIndicesTask::register_variants(); +} +} // namespace + +} // namespace cunumeric diff --git a/src/cunumeric/set/unzip_indices.h b/src/cunumeric/set/unzip_indices.h new file mode 100644 index 000000000..c7905ff36 --- /dev/null +++ b/src/cunumeric/set/unzip_indices.h @@ -0,0 +1,34 @@ +/* Copyright 2022 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#pragma once + +#include "cunumeric/cunumeric.h" + +namespace cunumeric { + +class UnzipIndicesTask : public CuNumericTask { + public: + static const int TASK_ID = CUNUMERIC_UNZIP_INDICES; + + public: + static void cpu_variant(legate::TaskContext& context); +#ifdef LEGATE_USE_OPENMP + static void omp_variant(legate::TaskContext& context); +#endif +}; + +} // namespace cunumeric diff --git a/src/cunumeric/set/unzip_indices_omp.cc b/src/cunumeric/set/unzip_indices_omp.cc new file mode 100644 index 000000000..340ec981a --- /dev/null +++ b/src/cunumeric/set/unzip_indices_omp.cc @@ -0,0 +1,29 @@ +/* Copyright 2022 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/set/unzip_indices.h" +#include "cunumeric/set/unzip_indices_template.inl" + +#include + +namespace cunumeric { + +/*static*/ void UnzipIndicesTask::omp_variant(TaskContext& context) +{ + unzip_indices_template(context, thrust::omp::par); +} + +} // namespace cunumeric diff --git a/src/cunumeric/set/unzip_indices_template.inl b/src/cunumeric/set/unzip_indices_template.inl new file mode 100644 index 000000000..9809f79c7 --- /dev/null +++ b/src/cunumeric/set/unzip_indices_template.inl @@ -0,0 +1,85 @@ +/* Copyright 2022 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#pragma once + +// Useful for IDEs +#include "cunumeric/set/unzip_indices.h" +#include "cunumeric/pitches.h" +#include "cunumeric/set/zip_indices.h" + +#include +#include +#include +#include +#include +#include + +namespace cunumeric { + +using namespace legate; + +template +struct ValExtract { + VAL operator()(const ZippedIndex& x) { return x.value; } +}; + +template +struct IndexExtract { + int64_t operator()(const ZippedIndex& x) { return x.index; } +}; + +template +struct UnzipIndicesImpl { + template + void operator()(std::vector& outputs, Array& input, const exe_pol_t& exe_pol) + { + using VAL = legate_type_of; + + auto input_shape = input.shape<1>(); + size_t res_size = input_shape.hi[0] - input_shape.lo[0] + 1; + + auto value = outputs[0].create_output_buffer(res_size, true); + auto index = outputs[1].create_output_buffer(res_size, true); + VAL* value_ptr = value.ptr(0); + int64_t* index_ptr = index.ptr(0); + + size_t strides[1]; + const ZippedIndex* in_ptr = + input.read_accessor, 1>(input_shape).ptr(input_shape, strides); + // unique_reduce has this check, so it's probably worthwhile to keep it here + assert(input_shape.volume() <= 1 || strides[0] == 1); + + thrust::transform(exe_pol, in_ptr, in_ptr + res_size, value_ptr, ValExtract()); + thrust::transform(exe_pol, in_ptr, in_ptr + res_size, index_ptr, IndexExtract()); + } +}; + +template +static void unzip_indices_template(TaskContext& context, const exe_pol_t& exe_pol) +{ + auto& input = context.inputs()[0]; + auto& outputs = context.outputs(); + + Type::Code code{input.code()}; + assert(Type::Code::STRUCT == code); + auto& field_type = static_cast(input.type()).field_type(0); + code = field_type.code; + + type_dispatch(code, UnzipIndicesImpl{}, outputs, input, exe_pol); +} + +} // namespace cunumeric diff --git a/src/cunumeric/set/zip_indices.h b/src/cunumeric/set/zip_indices.h new file mode 100644 index 000000000..e540e2e2c --- /dev/null +++ b/src/cunumeric/set/zip_indices.h @@ -0,0 +1,32 @@ +namespace cunumeric { + +using namespace legate; + +template +struct ZippedIndex { + VAL value; + int64_t index; +}; + +// Surprisingly it seems as though thrust can't figure out this comparison +template +struct ZippedComparator { + bool operator()(const ZippedIndex& a, const ZippedIndex& b) + { + return (a.value == b.value) ? 
a.index < b.index : a.value < b.value; + } +}; + +inline int64_t rowwise_linearize(int32_t DIM, const DomainPoint& p, const DomainPoint& parent_point) +{ + int multiplier = 1; + int64_t index = 0; + for (int i = DIM - 1; i >= 0; i--) { + index += p[i] * multiplier; + multiplier *= parent_point[i]; + } + + return index; +} + +} // namespace cunumeric diff --git a/tests/integration/test_unique.py b/tests/integration/test_unique.py index 28374586c..f3a4be957 100644 --- a/tests/integration/test_unique.py +++ b/tests/integration/test_unique.py @@ -31,28 +31,31 @@ def test_with_nonzero(): @pytest.mark.parametrize("ndim", range(LEGATE_MAX_DIM + 1)) -def test_ndim(ndim): - shape = (4,) * ndim +@pytest.mark.parametrize("return_index", (True, False)) +def test_ndim(ndim, return_index): + shape = (10,) * ndim a = num.random.randint(0, 3, size=shape) a_np = np.array(a) - b = np.unique(a) - b_np = num.unique(a_np) + b = np.unique(a, return_index=return_index) + b_np = num.unique(a_np, return_index=return_index) - assert np.array_equal(b, b_np) + if return_index: + assert num.array_equal(b[0], b_np[0]) + assert num.array_equal(b[1], b_np[1]) + else: + assert num.array_equal(b, b_np) @pytest.mark.xfail -@pytest.mark.parametrize("return_index", (True, False)) @pytest.mark.parametrize("return_inverse", (True, False)) @pytest.mark.parametrize("return_counts", (True, False)) @pytest.mark.parametrize("axis", (0, 1)) -def test_parameters(return_index, return_inverse, return_counts, axis): +def test_parameters(return_inverse, return_counts, axis): arr_num = num.random.randint(0, 3, size=(3, 3)) arr_np = np.array(arr_num) res_num = num.unique( arr_num, - return_index=return_index, return_inverse=return_inverse, return_counts=return_counts, ) @@ -60,7 +63,6 @@ def test_parameters(return_index, return_inverse, return_counts, axis): # for `unique` are not yet supported res_np = np.unique( arr_np, - return_index=return_index, return_inverse=return_inverse, return_counts=return_counts, axis=axis, From 50a11de435b34cc31390962549bb1813dd2f4228 Mon Sep 17 00:00:00 2001 From: Joseph Thomas Guman Date: Tue, 14 May 2024 16:39:26 -0700 Subject: [PATCH 2/3] Fixing some nits, expanding tests, and removing unnecessary thrust usage when unzipping indices --- cunumeric/array.py | 12 ++--- cunumeric/deferred.py | 8 +++- src/cunumeric/set/unique_reduce_template.inl | 1 - src/cunumeric/set/unzip_indices.cc | 20 ++++++++- src/cunumeric/set/unzip_indices_omp.cc | 23 ++++++++-- src/cunumeric/set/unzip_indices_template.inl | 46 +++++--------------- tests/integration/test_unique.py | 10 ++++- 7 files changed, 72 insertions(+), 48 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index 177277cf2..8fcacd1b7 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -4170,17 +4170,17 @@ def unique( Multiple GPUs, Multiple CPUs """ - thunk = self._thunk.unique(return_index) + deferred_result = self._thunk.unique(return_index) if return_index: if TYPE_CHECKING: - thunk = cast(tuple[NumPyThunk, NumPyThunk], thunk) - return ndarray(shape=thunk[0].shape, thunk=thunk[0]), ndarray( - shape=thunk[1].shape, thunk=thunk[1] + deferred_result = cast(tuple[NumPyThunk, NumPyThunk], deferred_result) + return ndarray(shape=deferred_result[0].shape, thunk=deferred_result[0]), ndarray( + shape=deferred_result[1].shape, thunk=deferred_result[1] ) else: if TYPE_CHECKING: - thunk = cast(NumPyThunk, thunk) - return ndarray(shape=thunk.shape, thunk=thunk) + deferred_result = cast(NumPyThunk, deferred_result) + return 
ndarray(shape=deferred_result.shape, thunk=deferred_result) @classmethod def _get_where_thunk( diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 9ae1a5154..f482c90c9 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -3497,6 +3497,7 @@ def unique( result = None # Assuming legate core will always choose GPU variant + # CPU uses legate.core Reduce op, which requires storing indices in struct if self.runtime.num_gpus > 0: task.add_nccl_communicator() result = self.runtime.create_unbound_thunk(self.base.type) @@ -3516,8 +3517,9 @@ def unique( returned_indices = None if return_index: - returned_indices = self.runtime.create_unbound_thunk(ty.int64) + # GPU variant uses NCCL for reduction so can directly output indices if self.runtime.num_gpus > 0: + returned_indices = self.runtime.create_unbound_thunk(ty.int64) task.add_output(returned_indices.base) for i in range(self.ndim): @@ -3536,12 +3538,14 @@ def unique( task = self.context.create_auto_task(CuNumericOpCode.UNZIP) task.add_input(result.base) - result = self.runtime.create_unbound_thunk(self.base.type) + result = self.runtime.create_empty_thunk(result.shape, self.base.type) + returned_indices = self.runtime.create_empty_thunk(result.shape, ty.int64) task.add_output(result.base) returned_indices = cast(DeferredArray, returned_indices) task.add_output(returned_indices.base) + task.add_alignment(result.base, returned_indices.base) task.execute() diff --git a/src/cunumeric/set/unique_reduce_template.inl b/src/cunumeric/set/unique_reduce_template.inl index 6a2fbcebe..cc32a4d49 100644 --- a/src/cunumeric/set/unique_reduce_template.inl +++ b/src/cunumeric/set/unique_reduce_template.inl @@ -25,7 +25,6 @@ #include #include #include -#include namespace cunumeric { diff --git a/src/cunumeric/set/unzip_indices.cc b/src/cunumeric/set/unzip_indices.cc index 4c89352c7..a7d4e24e2 100644 --- a/src/cunumeric/set/unzip_indices.cc +++ b/src/cunumeric/set/unzip_indices.cc @@ -19,9 +19,27 @@ namespace cunumeric { +using namespace legate; + +template +struct UniqueImplBody { + using VAL = legate_type_of; + + void operator()(const AccessorWO& values, + const AccessorWO& indices, + const AccessorRO, 1>& in, + const Rect<1> input_shape) + { + for (coord_t i = input_shape.lo[0]; i < input_shape.hi[0] + 1; i++) { + values[i] = in[i].value; + indices[i] = in[i].index; + } + } +}; + /*static*/ void UnzipIndicesTask::cpu_variant(TaskContext& context) { - unzip_indices_template(context, thrust::host); + unzip_indices_template(context); } namespace // unnamed diff --git a/src/cunumeric/set/unzip_indices_omp.cc b/src/cunumeric/set/unzip_indices_omp.cc index 340ec981a..c8cf24c7e 100644 --- a/src/cunumeric/set/unzip_indices_omp.cc +++ b/src/cunumeric/set/unzip_indices_omp.cc @@ -17,13 +17,30 @@ #include "cunumeric/set/unzip_indices.h" #include "cunumeric/set/unzip_indices_template.inl" -#include - namespace cunumeric { +using namespace legate; + +template +struct UniqueImplBody { + using VAL = legate_type_of; + + void operator()(const AccessorWO& values, + const AccessorWO& indices, + const AccessorRO, 1>& in, + const Rect<1> input_shape) + { +#pragma omp parallel for schedule(static) + for (coord_t i = input_shape.lo[0]; i < input_shape.hi[0] + 1; i++) { + values[i] = in[i].value; + indices[i] = in[i].index; + } + } +}; + /*static*/ void UnzipIndicesTask::omp_variant(TaskContext& context) { - unzip_indices_template(context, thrust::omp::par); + unzip_indices_template(context); } } // namespace cunumeric diff --git 
a/src/cunumeric/set/unzip_indices_template.inl b/src/cunumeric/set/unzip_indices_template.inl index 9809f79c7..895eb2290 100644 --- a/src/cunumeric/set/unzip_indices_template.inl +++ b/src/cunumeric/set/unzip_indices_template.inl @@ -21,55 +21,33 @@ #include "cunumeric/pitches.h" #include "cunumeric/set/zip_indices.h" -#include -#include -#include -#include -#include -#include - namespace cunumeric { using namespace legate; -template -struct ValExtract { - VAL operator()(const ZippedIndex& x) { return x.value; } -}; +template +struct UniqueImplBody; -template -struct IndexExtract { - int64_t operator()(const ZippedIndex& x) { return x.index; } -}; - -template +template struct UnzipIndicesImpl { template - void operator()(std::vector& outputs, Array& input, const exe_pol_t& exe_pol) + void operator()(std::vector& outputs, Array& input) { using VAL = legate_type_of; auto input_shape = input.shape<1>(); - size_t res_size = input_shape.hi[0] - input_shape.lo[0] + 1; - - auto value = outputs[0].create_output_buffer(res_size, true); - auto index = outputs[1].create_output_buffer(res_size, true); - VAL* value_ptr = value.ptr(0); - int64_t* index_ptr = index.ptr(0); + if(input_shape.volume() == 0) return; - size_t strides[1]; - const ZippedIndex* in_ptr = - input.read_accessor, 1>(input_shape).ptr(input_shape, strides); - // unique_reduce has this check, so it's probably worthwhile to keep it here - assert(input_shape.volume() <= 1 || strides[0] == 1); + auto values = outputs[0].write_accessor(input_shape); + auto indices = outputs[1].write_accessor(input_shape); + auto in = input.read_accessor, 1>(input_shape); - thrust::transform(exe_pol, in_ptr, in_ptr + res_size, value_ptr, ValExtract()); - thrust::transform(exe_pol, in_ptr, in_ptr + res_size, index_ptr, IndexExtract()); + UniqueImplBody()(values, indices, in, input_shape); } }; -template -static void unzip_indices_template(TaskContext& context, const exe_pol_t& exe_pol) +template +static void unzip_indices_template(TaskContext& context) { auto& input = context.inputs()[0]; auto& outputs = context.outputs(); @@ -79,7 +57,7 @@ static void unzip_indices_template(TaskContext& context, const exe_pol_t& exe_po auto& field_type = static_cast(input.type()).field_type(0); code = field_type.code; - type_dispatch(code, UnzipIndicesImpl{}, outputs, input, exe_pol); + type_dispatch(code, UnzipIndicesImpl{}, outputs, input); } } // namespace cunumeric diff --git a/tests/integration/test_unique.py b/tests/integration/test_unique.py index f3a4be957..cc61aad49 100644 --- a/tests/integration/test_unique.py +++ b/tests/integration/test_unique.py @@ -60,7 +60,7 @@ def test_parameters(return_inverse, return_counts, axis): return_counts=return_counts, ) # cuNumeric raises NotImplementedError: Keyword arguments - # for `unique` are not yet supported + # for `unique` outside of return_index are not yet supported res_np = np.unique( arr_np, return_inverse=return_inverse, @@ -69,6 +69,14 @@ def test_parameters(return_inverse, return_counts, axis): ) assert np.array_equal(res_np, res_num) +def test_index_selection(): + a = np.array([[1,2,3,4,1,1,9,1,1],[1,4,7,2,0,1,3,10,10]]) + + b_num = np.unique(a, return_index=True) + b_np = num.unique(a, return_index=True) + + assert num.array_equal(b_num[0], b_np[0]) + assert num.array_equal(b_num[1], b_np[1]) if __name__ == "__main__": import sys From e7b0b3a7bb4291f2b8c6a8305c03be699b188e7e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 May 2024 23:41:15 
+0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- cunumeric/array.py | 8 ++++++-- cunumeric/deferred.py | 8 ++++++-- src/cunumeric/set/unzip_indices.cc | 2 +- src/cunumeric/set/unzip_indices_omp.cc | 2 +- src/cunumeric/set/unzip_indices_template.inl | 6 +++--- tests/integration/test_unique.py | 4 +++- 6 files changed, 20 insertions(+), 10 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index 8fcacd1b7..542f23bf3 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -4173,8 +4173,12 @@ def unique( deferred_result = self._thunk.unique(return_index) if return_index: if TYPE_CHECKING: - deferred_result = cast(tuple[NumPyThunk, NumPyThunk], deferred_result) - return ndarray(shape=deferred_result[0].shape, thunk=deferred_result[0]), ndarray( + deferred_result = cast( + tuple[NumPyThunk, NumPyThunk], deferred_result + ) + return ndarray( + shape=deferred_result[0].shape, thunk=deferred_result[0] + ), ndarray( shape=deferred_result[1].shape, thunk=deferred_result[1] ) else: diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index f482c90c9..8b3fd2108 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -3538,8 +3538,12 @@ def unique( task = self.context.create_auto_task(CuNumericOpCode.UNZIP) task.add_input(result.base) - result = self.runtime.create_empty_thunk(result.shape, self.base.type) - returned_indices = self.runtime.create_empty_thunk(result.shape, ty.int64) + result = self.runtime.create_empty_thunk( + result.shape, self.base.type + ) + returned_indices = self.runtime.create_empty_thunk( + result.shape, ty.int64 + ) task.add_output(result.base) diff --git a/src/cunumeric/set/unzip_indices.cc b/src/cunumeric/set/unzip_indices.cc index a7d4e24e2..0291225d5 100644 --- a/src/cunumeric/set/unzip_indices.cc +++ b/src/cunumeric/set/unzip_indices.cc @@ -31,7 +31,7 @@ struct UniqueImplBody { const Rect<1> input_shape) { for (coord_t i = input_shape.lo[0]; i < input_shape.hi[0] + 1; i++) { - values[i] = in[i].value; + values[i] = in[i].value; indices[i] = in[i].index; } } diff --git a/src/cunumeric/set/unzip_indices_omp.cc b/src/cunumeric/set/unzip_indices_omp.cc index c8cf24c7e..caf0dd708 100644 --- a/src/cunumeric/set/unzip_indices_omp.cc +++ b/src/cunumeric/set/unzip_indices_omp.cc @@ -32,7 +32,7 @@ struct UniqueImplBody { { #pragma omp parallel for schedule(static) for (coord_t i = input_shape.lo[0]; i < input_shape.hi[0] + 1; i++) { - values[i] = in[i].value; + values[i] = in[i].value; indices[i] = in[i].index; } } diff --git a/src/cunumeric/set/unzip_indices_template.inl b/src/cunumeric/set/unzip_indices_template.inl index 895eb2290..1e758fa74 100644 --- a/src/cunumeric/set/unzip_indices_template.inl +++ b/src/cunumeric/set/unzip_indices_template.inl @@ -36,11 +36,11 @@ struct UnzipIndicesImpl { using VAL = legate_type_of; auto input_shape = input.shape<1>(); - if(input_shape.volume() == 0) return; + if (input_shape.volume() == 0) return; - auto values = outputs[0].write_accessor(input_shape); + auto values = outputs[0].write_accessor(input_shape); auto indices = outputs[1].write_accessor(input_shape); - auto in = input.read_accessor, 1>(input_shape); + auto in = input.read_accessor, 1>(input_shape); UniqueImplBody()(values, indices, in, input_shape); } diff --git a/tests/integration/test_unique.py b/tests/integration/test_unique.py index cc61aad49..3e4063352 100644 --- a/tests/integration/test_unique.py +++ b/tests/integration/test_unique.py @@ -69,8 +69,9 @@ def 
test_parameters(return_inverse, return_counts, axis): ) assert np.array_equal(res_np, res_num) + def test_index_selection(): - a = np.array([[1,2,3,4,1,1,9,1,1],[1,4,7,2,0,1,3,10,10]]) + a = np.array([[1, 2, 3, 4, 1, 1, 9, 1, 1], [1, 4, 7, 2, 0, 1, 3, 10, 10]]) b_num = np.unique(a, return_index=True) b_np = num.unique(a, return_index=True) @@ -78,6 +79,7 @@ def test_index_selection(): assert num.array_equal(b_num[0], b_np[0]) assert num.array_equal(b_num[1], b_np[1]) + if __name__ == "__main__": import sys
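
Taken together, the three patches implement return_index in three steps: each point task pairs every element with its row-major global index (fill_subset_indices / rowwise_linearize) and deduplicates locally, keeping the smallest index per value; the per-task pieces are then merged pairwise in a tree (UNIQUE_REDUCE on CPU/OMP, the NCCL tree_reduce on GPU); finally the surviving (value, index) pairs are unzipped into the two result arrays. The sketch below reproduces that pipeline with plain NumPy in a single process; the helper names are hypothetical and it stands in for, rather than reproduces, the Legate task code:

import numpy as np


def _dedup(values, indices):
    # Keep the smallest index per value: sort by (value, index), then keep
    # the first element of each run of equal values; this matches what the
    # ZippedIndex sets and thrust::unique_by_key do in the patches.
    if values.size == 0:
        return values, indices
    order = np.lexsort((indices, values))  # last key (values) is primary
    values, indices = values[order], indices[order]
    keep = np.concatenate(([True], values[1:] != values[:-1]))
    return values[keep], indices[keep]


def unique_with_index(a, num_tasks=4):
    flat = a.ravel()  # row-major flattening, like rowwise_linearize
    # Per-task step: pair each element with its global flat index, then dedup.
    chunks = np.array_split(np.arange(flat.size), num_tasks)
    pieces = [_dedup(flat[c], c) for c in chunks]
    # Tree reduction: merge pieces pairwise until one remains.
    while len(pieces) > 1:
        pieces = [
            _dedup(
                np.concatenate([p[0] for p in pieces[i : i + 2]]),
                np.concatenate([p[1] for p in pieces[i : i + 2]]),
            )
            for i in range(0, len(pieces), 2)
        ]
    return pieces[0]  # final "unzip": (unique values, first-occurrence indices)


a = np.array([[1, 2, 3, 4, 1, 1, 9, 1, 1], [1, 4, 7, 2, 0, 1, 3, 10, 10]])
values, indices = unique_with_index(a)
ref_values, ref_indices = np.unique(a, return_index=True)
assert np.array_equal(values, ref_values)
assert np.array_equal(indices, ref_indices)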