diff --git a/plugin/CMakeLists.txt b/plugin/CMakeLists.txt index 7ad96703fa5a..f74c50cb9233 100644 --- a/plugin/CMakeLists.txt +++ b/plugin/CMakeLists.txt @@ -5,6 +5,7 @@ endif (PLUGIN_DENSE_PARSER) if (PLUGIN_SYCL) set(CMAKE_CXX_COMPILER "icpx") add_library(plugin_sycl OBJECT + ${xgboost_SOURCE_DIR}/plugin/sycl/data/gradient_index.cc ${xgboost_SOURCE_DIR}/plugin/sycl/common/hist_util.cc ${xgboost_SOURCE_DIR}/plugin/sycl/objective/regression_obj.cc ${xgboost_SOURCE_DIR}/plugin/sycl/objective/multiclass_obj.cc diff --git a/plugin/sycl/common/hist_util.cc b/plugin/sycl/common/hist_util.cc index 17307b153f60..09589c23c197 100644 --- a/plugin/sycl/common/hist_util.cc +++ b/plugin/sycl/common/hist_util.cc @@ -6,6 +6,7 @@ #include #include +#include "../data/gradient_index.h" #include "hist_util.h" #include @@ -14,18 +15,6 @@ namespace xgboost { namespace sycl { namespace common { -uint32_t SearchBin(const bst_float* cut_values, const uint32_t* cut_ptrs, Entry const& e) { - auto beg = cut_ptrs[e.index]; - auto end = cut_ptrs[e.index + 1]; - const auto &values = cut_values; - auto it = std::upper_bound(cut_values + beg, cut_values + end, e.fvalue); - uint32_t idx = it - cut_values; - if (idx == end) { - idx -= 1; - } - return idx; -} - template void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) { const size_t total_len = end - begin; @@ -49,138 +38,6 @@ void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) { } } -template -void GHistIndexMatrix::SetIndexData(::sycl::queue qu, - xgboost::common::Span index_data_span, - const DeviceMatrix &dmat_device, - size_t nbins, - size_t row_stride, - uint32_t* offsets) { - if (hit_count.size() == 0) return; - const xgboost::Entry *data_ptr = dmat_device.data.DataConst(); - const bst_row_t *offset_vec = dmat_device.row_ptr.DataConst(); - const size_t num_rows = dmat_device.row_ptr.Size() - 1; - BinIdxType* index_data = index_data_span.data(); - const bst_float* cut_values = cut_device.Values().DataConst(); - const uint32_t* cut_ptrs = cut_device.Ptrs().DataConst(); - ::sycl::buffer hit_count_buf(hit_count.data(), hit_count.size()); - - USMVector sort_buf(&qu, num_rows * row_stride); - BinIdxType* sort_data = sort_buf.Data(); - - qu.submit([&](::sycl::handler& cgh) { - auto hit_count_acc = hit_count_buf.template get_access<::sycl::access::mode::atomic>(cgh); - cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::item<1> pid) { - const size_t i = pid.get_id(0); - const size_t ibegin = offset_vec[i]; - const size_t iend = offset_vec[i + 1]; - const size_t size = iend - ibegin; - const size_t start = i * row_stride; - for (bst_uint j = 0; j < size; ++j) { - uint32_t idx = SearchBin(cut_values, cut_ptrs, data_ptr[ibegin + j]); - index_data[start + j] = offsets ? 
idx - offsets[j] : idx; - ::sycl::atomic_fetch_add(hit_count_acc[idx], 1); - } - if (!offsets) { - // Sparse case only - mergeSort(index_data + start, index_data + start + size, sort_data + start); - for (bst_uint j = size; j < row_stride; ++j) { - index_data[start + j] = nbins; - } - } - }); - }).wait(); -} - -void GHistIndexMatrix::ResizeIndex(const size_t n_offsets, - const size_t n_index, - const bool isDense) { - if ((max_num_bins - 1 <= static_cast(std::numeric_limits::max())) && isDense) { - index.SetBinTypeSize(BinTypeSize::kUint8BinsTypeSize); - index.Resize((sizeof(uint8_t)) * n_index); - } else if ((max_num_bins - 1 > static_cast(std::numeric_limits::max()) && - max_num_bins - 1 <= static_cast(std::numeric_limits::max())) && isDense) { - index.SetBinTypeSize(BinTypeSize::kUint16BinsTypeSize); - index.Resize((sizeof(uint16_t)) * n_index); - } else { - index.SetBinTypeSize(BinTypeSize::kUint32BinsTypeSize); - index.Resize((sizeof(uint32_t)) * n_index); - } -} - -void GHistIndexMatrix::Init(::sycl::queue qu, - Context const * ctx, - const DeviceMatrix& p_fmat_device, - int max_bins) { - nfeatures = p_fmat_device.p_mat->Info().num_col_; - - cut = xgboost::common::SketchOnDMatrix(ctx, p_fmat_device.p_mat, max_bins); - cut_device.Init(qu, cut); - - max_num_bins = max_bins; - const uint32_t nbins = cut.Ptrs().back(); - this->nbins = nbins; - hit_count.resize(nbins, 0); - - this->p_fmat = p_fmat_device.p_mat; - const bool isDense = p_fmat_device.p_mat->IsDense(); - this->isDense_ = isDense; - - row_ptr = std::vector(p_fmat_device.row_ptr.Begin(), p_fmat_device.row_ptr.End()); - row_ptr_device = p_fmat_device.row_ptr; - - index.setQueue(qu); - - row_stride = 0; - for (const auto& batch : p_fmat_device.p_mat->GetBatches()) { - const auto& row_offset = batch.offset.ConstHostVector(); - for (auto i = 1ull; i < row_offset.size(); i++) { - row_stride = std::max(row_stride, static_cast(row_offset[i] - row_offset[i - 1])); - } - } - - const size_t n_offsets = cut.Ptrs().size() - 1; - const size_t n_rows = p_fmat_device.row_ptr.Size() - 1; - const size_t n_index = n_rows * row_stride; - ResizeIndex(n_offsets, n_index, isDense); - - CHECK_GT(cut.Values().size(), 0U); - - uint32_t* offsets = nullptr; - if (isDense) { - index.ResizeOffset(n_offsets); - offsets = index.Offset(); - qu.memcpy(offsets, cut.Ptrs().data(), sizeof(uint32_t) * n_offsets).wait_and_throw(); - // for (size_t i = 0; i < n_offsets; ++i) { - // offsets[i] = cut.Ptrs()[i]; - // } - } - - if (isDense) { - BinTypeSize curent_bin_size = index.GetBinTypeSize(); - if (curent_bin_size == BinTypeSize::kUint8BinsTypeSize) { - xgboost::common::Span index_data_span = {index.data(), - n_index}; - SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets); - - } else if (curent_bin_size == BinTypeSize::kUint16BinsTypeSize) { - xgboost::common::Span index_data_span = {index.data(), - n_index}; - SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets); - } else { - CHECK_EQ(curent_bin_size, BinTypeSize::kUint32BinsTypeSize); - xgboost::common::Span index_data_span = {index.data(), - n_index}; - SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets); - } - /* For sparse DMatrix we have to store index of feature for each bin - in index field to chose right offset. So offset is nullptr and index is not reduced */ - } else { - xgboost::common::Span index_data_span = {index.data(), n_index}; - SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets); - } -} - /*! 
* \brief Fill histogram with zeroes */ diff --git a/plugin/sycl/common/hist_util.h b/plugin/sycl/common/hist_util.h index 68d5f1010799..d022cfcdca3e 100644 --- a/plugin/sycl/common/hist_util.h +++ b/plugin/sycl/common/hist_util.h @@ -21,209 +21,8 @@ namespace common { template using GHistRow = USMVector, memory_type>; -template -using AtomicRef = ::sycl::atomic_ref; - -/*! - * \brief SYCL implementation of HistogramCuts stored in USM buffers to provide access from device kernels - */ -class HistogramCuts { - protected: - using BinIdx = uint32_t; - - public: - HistogramCuts() {} - - explicit HistogramCuts(::sycl::queue qu) { - cut_ptrs_.Resize(&qu_, 1, 0); - } - - ~HistogramCuts() { - } - - void Init(::sycl::queue qu, xgboost::common::HistogramCuts const& cuts) { - qu_ = qu; - cut_values_.Init(&qu_, cuts.cut_values_.HostVector()); - cut_ptrs_.Init(&qu_, cuts.cut_ptrs_.HostVector()); - min_vals_.Init(&qu_, cuts.min_vals_.HostVector()); - } - - // Getters for USM buffers to pass pointers into device kernels - const USMVector& Ptrs() const { return cut_ptrs_; } - const USMVector& Values() const { return cut_values_; } - const USMVector& MinValues() const { return min_vals_; } - - private: - USMVector cut_values_; - USMVector cut_ptrs_; - USMVector min_vals_; - ::sycl::queue qu_; -}; - using BinTypeSize = ::xgboost::common::BinTypeSize; -/*! - * \brief Index data and offsets stored in USM buffers to provide access from device kernels - */ -struct Index { - Index() { - SetBinTypeSize(binTypeSize_); - } - Index(const Index& i) = delete; - Index& operator=(Index i) = delete; - Index(Index&& i) = delete; - Index& operator=(Index&& i) = delete; - uint32_t operator[](size_t i) const { - if (!offset_.Empty()) { - return func_(data_.DataConst(), i) + offset_[i%p_]; - } else { - return func_(data_.DataConst(), i); - } - } - void SetBinTypeSize(BinTypeSize binTypeSize) { - binTypeSize_ = binTypeSize; - switch (binTypeSize) { - case BinTypeSize::kUint8BinsTypeSize: - func_ = &GetValueFromUint8; - break; - case BinTypeSize::kUint16BinsTypeSize: - func_ = &GetValueFromUint16; - break; - case BinTypeSize::kUint32BinsTypeSize: - func_ = &GetValueFromUint32; - break; - default: - CHECK(binTypeSize == BinTypeSize::kUint8BinsTypeSize || - binTypeSize == BinTypeSize::kUint16BinsTypeSize || - binTypeSize == BinTypeSize::kUint32BinsTypeSize); - } - } - BinTypeSize GetBinTypeSize() const { - return binTypeSize_; - } - - template - T* data() { - return reinterpret_cast(data_.Data()); - } - - template - const T* data() const { - return reinterpret_cast(data_.DataConst()); - } - - uint32_t* Offset() { - return offset_.Data(); - } - - const uint32_t* Offset() const { - return offset_.DataConst(); - } - - size_t Size() const { - return data_.Size() / (binTypeSize_); - } - - void Resize(const size_t nBytesData) { - data_.Resize(&qu_, nBytesData); - } - - void ResizeOffset(const size_t nDisps) { - offset_.Resize(&qu_, nDisps); - p_ = nDisps; - } - - uint8_t* begin() const { - return data_.Begin(); - } - - uint8_t* end() const { - return data_.End(); - } - - void setQueue(::sycl::queue qu) { - qu_ = qu; - } - - private: - static uint32_t GetValueFromUint8(const uint8_t* t, size_t i) { - return reinterpret_cast(t)[i]; - } - static uint32_t GetValueFromUint16(const uint8_t* t, size_t i) { - return reinterpret_cast(t)[i]; - } - static uint32_t GetValueFromUint32(const uint8_t* t, size_t i) { - return reinterpret_cast(t)[i]; - } - - using Func = uint32_t (*)(const uint8_t*, size_t); - - USMVector data_; - // size of this field 
is equal to number of features - USMVector offset_; - BinTypeSize binTypeSize_ {BinTypeSize::kUint8BinsTypeSize}; - size_t p_ {1}; - Func func_; - - ::sycl::queue qu_; -}; - - -/*! - * \brief Preprocessed global index matrix, in CSR format, stored in USM buffers - * - * Transform floating values to integer index in histogram - */ -struct GHistIndexMatrix { - /*! \brief row pointer to rows by element position */ - std::vector row_ptr; - USMVector row_ptr_device; - /*! \brief The index data */ - Index index; - /*! \brief hit count of each index */ - std::vector hit_count; - /*! \brief The corresponding cuts */ - xgboost::common::HistogramCuts cut; - HistogramCuts cut_device; - DMatrix* p_fmat; - size_t max_num_bins; - size_t nbins; - size_t nfeatures; - size_t row_stride; - - // Create a global histogram matrix based on a given DMatrix device wrapper - void Init(::sycl::queue qu, Context const * ctx, - const sycl::DeviceMatrix& p_fmat_device, int max_num_bins); - - template - void SetIndexData(::sycl::queue qu, xgboost::common::Span index_data_span, - const sycl::DeviceMatrix &dmat_device, - size_t nbins, size_t row_stride, uint32_t* offsets); - - void ResizeIndex(const size_t n_offsets, const size_t n_index, - const bool isDense); - - inline void GetFeatureCounts(std::vector* counts) const { - auto nfeature = cut_device.Ptrs().Size() - 1; - for (unsigned fid = 0; fid < nfeature; ++fid) { - auto ibegin = cut_device.Ptrs()[fid]; - auto iend = cut_device.Ptrs()[fid + 1]; - for (auto i = ibegin; i < iend; ++i) { - (*counts)[fid] += hit_count[i]; - } - } - } - inline bool IsDense() const { - return isDense_; - } - - private: - bool isDense_; -}; - class ColumnMatrix; /*! diff --git a/plugin/sycl/common/partition_builder.h b/plugin/sycl/common/partition_builder.h index b942e85441fb..a07af244f0f4 100644 --- a/plugin/sycl/common/partition_builder.h +++ b/plugin/sycl/common/partition_builder.h @@ -9,17 +9,110 @@ #pragma GCC diagnostic ignored "-W#pragma-messages" #include #pragma GCC diagnostic pop +#include + #include #include #include +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#include "../../../src/common/column_matrix.h" +#pragma GCC diagnostic pop + #include "../data.h" +#include "row_set.h" +#include "../data/gradient_index.h" +#include "../tree/expand_entry.h" #include namespace xgboost { namespace sycl { namespace common { + +// split row indexes (rid_span) to 2 parts (both stored in rid_buf) depending +// on comparison of indexes values (idx_span) and split point (split_cond) +// Handle dense columns +template +inline ::sycl::event PartitionDenseKernel( + ::sycl::queue* qu, + const GHistIndexMatrix& gmat, + const RowSetCollection::Elem& rid_span, + const size_t fid, + const int32_t split_cond, + xgboost::common::Span* rid_buf, + size_t* parts_size, + ::sycl::event event) { + const size_t row_stride = gmat.row_stride; + const BinIdxType* gradient_index = gmat.index.data(); + const size_t* rid = rid_span.begin; + const size_t range_size = rid_span.Size(); + const size_t offset = gmat.cut.Ptrs()[fid]; + + size_t* p_rid_buf = rid_buf->data(); + + return qu->submit([&](::sycl::handler& cgh) { + cgh.depends_on(event); + cgh.parallel_for<>(::sycl::range<1>(range_size), [=](::sycl::item<1> nid) { + const size_t id = rid[nid.get_id(0)]; + const int32_t value = static_cast(gradient_index[id * row_stride + fid] + offset); + const bool is_left = value <= split_cond; + if (is_left) { + AtomicRef n_left(parts_size[0]); + p_rid_buf[n_left.fetch_add(1)] 
= id;
+      } else {
+        AtomicRef n_right(parts_size[1]);
+        p_rid_buf[range_size - n_right.fetch_add(1) - 1] = id;
+      }
+    });
+  });
+}
+
+// split row indexes (rid_span) to 2 parts (both stored in rid_buf) depending
+// on comparison of indexes values (idx_span) and split point (split_cond)
+// Handle sparse columns
+template
+inline ::sycl::event PartitionSparseKernel(::sycl::queue* qu,
+                                           const GHistIndexMatrix& gmat,
+                                           const RowSetCollection::Elem& rid_span,
+                                           const size_t fid,
+                                           const int32_t split_cond,
+                                           xgboost::common::Span* rid_buf,
+                                           size_t* parts_size,
+                                           ::sycl::event event) {
+  const size_t row_stride = gmat.row_stride;
+  const BinIdxType* gradient_index = gmat.index.data();
+  const size_t* rid = rid_span.begin;
+  const size_t range_size = rid_span.Size();
+  const uint32_t* cut_ptrs = gmat.cut_device.Ptrs().DataConst();
+  const bst_float* cut_vals = gmat.cut_device.Values().DataConst();
+
+  size_t* p_rid_buf = rid_buf->data();
+  return qu->submit([&](::sycl::handler& cgh) {
+    cgh.depends_on(event);
+    cgh.parallel_for<>(::sycl::range<1>(range_size), [=](::sycl::item<1> nid) {
+      const size_t id = rid[nid.get_id(0)];
+
+      const BinIdxType* gr_index_local = gradient_index + row_stride * id;
+      const int32_t fid_local = std::lower_bound(gr_index_local,
+                                                 gr_index_local + row_stride,
+                                                 cut_ptrs[fid]) - gr_index_local;
+      const bool is_left = (fid_local >= row_stride ||
+                            gr_index_local[fid_local] >= cut_ptrs[fid + 1]) ?
+                           default_left :
+                           gr_index_local[fid_local] <= split_cond;
+      if (is_left) {
+        AtomicRef n_left(parts_size[0]);
+        p_rid_buf[n_left.fetch_add(1)] = id;
+      } else {
+        AtomicRef n_right(parts_size[1]);
+        p_rid_buf[range_size - n_right.fetch_add(1) - 1] = id;
+      }
+    });
+  });
+}
+
 // The builder is required for samples partition to left and rights children for set of nodes
 class PartitionBuilder {
  public:
@@ -29,6 +122,7 @@ class PartitionBuilder {
   template
   void Init(::sycl::queue* qu, size_t n_nodes, Func funcNTaks) {
+    qu_ = qu;
     nodes_offsets_.resize(n_nodes+1);
     result_rows_.resize(2 * n_nodes);
     n_nodes_ = n_nodes;
@@ -39,24 +133,11 @@ class PartitionBuilder {
       nodes_offsets_[i] = nodes_offsets_[i-1] + funcNTaks(i-1);
     }
-
     if (data_.Size() < nodes_offsets_[n_nodes]) {
       data_.Resize(qu, nodes_offsets_[n_nodes]);
     }
-    prefix_sums_.Resize(qu, maxLocalSums);
-  }
-
-
-  xgboost::common::Span GetData(int nid) {
-    return { data_.Data() + nodes_offsets_[nid], nodes_offsets_[nid + 1] - nodes_offsets_[nid] };
   }
-
-  xgboost::common::Span GetPrefixSums() {
-    return { prefix_sums_.Data(), prefix_sums_.Size() };
-  }
-
-
   size_t GetLocalSize(const xgboost::common::Range1d& range) {
     size_t range_size = range.end() - range.begin();
     size_t local_subgroups = range_size / (maxLocalSums * subgroupSize) +
@@ -69,42 +150,127 @@ class PartitionBuilder {
     return subgroupSize;
   }

-  size_t* GetResultRowsPtr() {
-    return result_rows_.data();
-  }
-
   size_t GetNLeftElems(int nid) const {
-    // return result_left_rows_[nid];
     return result_rows_[2 * nid];
   }

   size_t GetNRightElems(int nid) const {
-    // return result_right_rows_[nid];
     return result_rows_[2 * nid + 1];
   }

+  void SetNLeftElems(int nid, size_t val) {
+    result_rows_[2 * nid] = val;
+  }

-  ::sycl::event MergeToArray(::sycl::queue* qu, size_t node_in_set,
-                             size_t* data_result,
-                             ::sycl::event priv_event) {
-    size_t n_nodes_total = GetNLeftElems(node_in_set) + GetNRightElems(node_in_set);
-    if (n_nodes_total > 0) {
-      const size_t* data = data_.Data() + nodes_offsets_[node_in_set];
-      return qu->memcpy(data_result, data, sizeof(size_t) * n_nodes_total, priv_event);
+
+
void SetNRightElems(int nid, size_t val) { + result_rows_[2 * nid + 1] = val; + } + + xgboost::common::Span GetData(int nid) { + return { data_.Data() + nodes_offsets_[nid], nodes_offsets_[nid + 1] - nodes_offsets_[nid] }; + } + + + template + ::sycl::event Partition(const int32_t split_cond, + const GHistIndexMatrix& gmat, + const RowSetCollection::Elem& rid_span, + const xgboost::RegTree::Node& node, + xgboost::common::Span* rid_buf, + size_t* parts_size, + ::sycl::event event) { + const bst_uint fid = node.SplitIndex(); + const bool default_left = node.DefaultLeft(); + + if (gmat.IsDense()) { + if (default_left) { + return PartitionDenseKernel(qu_, gmat, rid_span, fid, + split_cond, rid_buf, parts_size, event); + } else { + return PartitionDenseKernel(qu_, gmat, rid_span, fid, + split_cond, rid_buf, parts_size, event); + } } else { - return ::sycl::event(); + if (default_left) { + return PartitionSparseKernel(qu_, gmat, rid_span, fid, + split_cond, rid_buf, parts_size, event); + } else { + return PartitionSparseKernel(qu_, gmat, rid_span, fid, + split_cond, rid_buf, parts_size, event); + } + } + } + + // Entry point for Partition + void Partition(const GHistIndexMatrix& gmat, + const std::vector nodes, + const RowSetCollection& row_set_collection, + const std::vector& split_conditions, + RegTree* p_tree, + ::sycl::event* general_event) { + nodes_events_.resize(n_nodes_); + + parts_size_.ResizeAndFill(qu_, 2 * n_nodes_, 0, general_event); + + for (size_t node_in_set = 0; node_in_set < n_nodes_; node_in_set++) { + const int32_t nid = nodes[node_in_set].nid; + ::sycl::event& node_event = nodes_events_[node_in_set]; + const auto& rid_span = row_set_collection[nid]; + if (rid_span.Size() > 0) { + const RegTree::Node& node = (*p_tree)[nid]; + xgboost::common::Span rid_buf = GetData(node_in_set); + size_t* part_size = parts_size_.Data() + 2 * node_in_set; + int32_t split_condition = split_conditions[node_in_set]; + switch (gmat.index.GetBinTypeSize()) { + case common::BinTypeSize::kUint8BinsTypeSize: + node_event = Partition(split_condition, gmat, rid_span, node, + &rid_buf, part_size, *general_event); + break; + case common::BinTypeSize::kUint16BinsTypeSize: + node_event = Partition(split_condition, gmat, rid_span, node, + &rid_buf, part_size, *general_event); + break; + case common::BinTypeSize::kUint32BinsTypeSize: + node_event = Partition(split_condition, gmat, rid_span, node, + &rid_buf, part_size, *general_event); + break; + default: + CHECK(false); // no default behavior + } + } else { + node_event = ::sycl::event(); + } + } + + *general_event = qu_->memcpy(result_rows_.data(), + parts_size_.DataConst(), + sizeof(size_t) * 2 * n_nodes_, + nodes_events_); + } + + void MergeToArray(size_t nid, + size_t* data_result, + ::sycl::event* event) { + size_t n_nodes_total = GetNLeftElems(nid) + GetNRightElems(nid); + if (n_nodes_total > 0) { + const size_t* data = data_.Data() + nodes_offsets_[nid]; + qu_->memcpy(data_result, data, sizeof(size_t) * n_nodes_total, *event); } } protected: std::vector nodes_offsets_; std::vector result_rows_; + std::vector<::sycl::event> nodes_events_; size_t n_nodes_; + USMVector parts_size_; USMVector data_; - USMVector prefix_sums_; + + ::sycl::queue* qu_; }; } // namespace common diff --git a/plugin/sycl/data.h b/plugin/sycl/data.h index c49733e5b7e7..f4a3cb4120c6 100644 --- a/plugin/sycl/data.h +++ b/plugin/sycl/data.h @@ -26,6 +26,13 @@ namespace xgboost { namespace sycl { + +template +using AtomicRef = ::sycl::atomic_ref; + enum class MemoryType { shared, 
on_device}; @@ -169,20 +176,20 @@ class USMVector { } } - ::sycl::event ResizeAndFill(::sycl::queue* qu, size_t size_new, int v) { + void ResizeAndFill(::sycl::queue* qu, size_t size_new, int v, ::sycl::event* event) { if (size_new <= size_) { size_ = size_new; - return qu->memset(data_.get(), v, size_new * sizeof(T)); + *event = qu->memset(data_.get(), v, size_new * sizeof(T), *event); } else if (size_new <= capacity_) { size_ = size_new; - return qu->memset(data_.get(), v, size_new * sizeof(T)); + *event = qu->memset(data_.get(), v, size_new * sizeof(T), *event); } else { size_t size_old = size_; auto data_old = data_; size_ = size_new; capacity_ = size_new; data_ = allocate_memory_(qu, size_); - return qu->memset(data_.get(), v, size_new * sizeof(T)); + *event = qu->memset(data_.get(), v, size_new * sizeof(T), *event); } } diff --git a/plugin/sycl/data/gradient_index.cc b/plugin/sycl/data/gradient_index.cc new file mode 100644 index 000000000000..b9e2a8d2d197 --- /dev/null +++ b/plugin/sycl/data/gradient_index.cc @@ -0,0 +1,186 @@ +/*! + * Copyright 2017-2023 by Contributors + * \file gradient_index.cc + */ +#include +#include +#include + +#include "gradient_index.h" + +#include + +namespace xgboost { +namespace sycl { +namespace common { + +uint32_t SearchBin(const bst_float* cut_values, const uint32_t* cut_ptrs, Entry const& e) { + auto beg = cut_ptrs[e.index]; + auto end = cut_ptrs[e.index + 1]; + const auto &values = cut_values; + auto it = std::upper_bound(cut_values + beg, cut_values + end, e.fvalue); + uint32_t idx = it - cut_values; + if (idx == end) { + idx -= 1; + } + return idx; +} + +template +void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) { + const size_t total_len = end - begin; + for (size_t block_len = 1; block_len < total_len; block_len <<= 1) { + for (size_t cur_block = 0; cur_block + block_len < total_len; cur_block += 2 * block_len) { + size_t start = cur_block; + size_t mid = start + block_len; + size_t finish = mid + block_len < total_len ? 
mid + block_len : total_len; + size_t left_pos = start; + size_t right_pos = mid; + size_t pos = start; + while (left_pos < mid || right_pos < finish) { + if (left_pos < mid && (right_pos == finish || begin[left_pos] < begin[right_pos])) { + buf[pos++] = begin[left_pos++]; + } else { + buf[pos++] = begin[right_pos++]; + } + } + for (size_t i = start; i < finish; i++) begin[i] = buf[i]; + } + } +} + +template +void GHistIndexMatrix::SetIndexData(::sycl::queue qu, + xgboost::common::Span index_data_span, + const DeviceMatrix &dmat_device, + size_t nbins, + size_t row_stride, + uint32_t* offsets) { + if (hit_count.size() == 0) return; + const xgboost::Entry *data_ptr = dmat_device.data.DataConst(); + const bst_row_t *offset_vec = dmat_device.row_ptr.DataConst(); + const size_t num_rows = dmat_device.row_ptr.Size() - 1; + BinIdxType* index_data = index_data_span.data(); + const bst_float* cut_values = cut_device.Values().DataConst(); + const uint32_t* cut_ptrs = cut_device.Ptrs().DataConst(); + ::sycl::buffer hit_count_buf(hit_count.data(), hit_count.size()); + + USMVector sort_buf(&qu, num_rows * row_stride); + BinIdxType* sort_data = sort_buf.Data(); + + qu.submit([&](::sycl::handler& cgh) { + auto hit_count_acc = hit_count_buf.template get_access<::sycl::access::mode::atomic>(cgh); + cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::item<1> pid) { + const size_t i = pid.get_id(0); + const size_t ibegin = offset_vec[i]; + const size_t iend = offset_vec[i + 1]; + const size_t size = iend - ibegin; + const size_t start = i * row_stride; + for (bst_uint j = 0; j < size; ++j) { + uint32_t idx = SearchBin(cut_values, cut_ptrs, data_ptr[ibegin + j]); + index_data[start + j] = offsets ? idx - offsets[j] : idx; + ::sycl::atomic_fetch_add(hit_count_acc[idx], 1); + } + if (!offsets) { + // Sparse case only + mergeSort(index_data + start, index_data + start + size, sort_data + start); + for (bst_uint j = size; j < row_stride; ++j) { + index_data[start + j] = nbins; + } + } + }); + }).wait(); +} + +void GHistIndexMatrix::ResizeIndex(const size_t n_offsets, + const size_t n_index, + const bool isDense) { + if ((max_num_bins - 1 <= static_cast(std::numeric_limits::max())) && isDense) { + index.SetBinTypeSize(BinTypeSize::kUint8BinsTypeSize); + index.Resize((sizeof(uint8_t)) * n_index); + } else if ((max_num_bins - 1 > static_cast(std::numeric_limits::max()) && + max_num_bins - 1 <= static_cast(std::numeric_limits::max())) && isDense) { + index.SetBinTypeSize(BinTypeSize::kUint16BinsTypeSize); + index.Resize((sizeof(uint16_t)) * n_index); + } else { + index.SetBinTypeSize(BinTypeSize::kUint32BinsTypeSize); + index.Resize((sizeof(uint32_t)) * n_index); + } +} + +void GHistIndexMatrix::Init(::sycl::queue qu, + Context const * ctx, + const DeviceMatrix& p_fmat_device, + int max_bins) { + nfeatures = p_fmat_device.p_mat->Info().num_col_; + + cut = xgboost::common::SketchOnDMatrix(ctx, p_fmat_device.p_mat, max_bins); + cut_device.Init(qu, cut); + + max_num_bins = max_bins; + const uint32_t nbins = cut.Ptrs().back(); + this->nbins = nbins; + hit_count.resize(nbins, 0); + + this->p_fmat = p_fmat_device.p_mat; + const bool isDense = p_fmat_device.p_mat->IsDense(); + this->isDense_ = isDense; + + row_ptr = std::vector(p_fmat_device.row_ptr.Begin(), p_fmat_device.row_ptr.End()); + row_ptr_device = p_fmat_device.row_ptr; + + index.setQueue(qu); + + row_stride = 0; + for (const auto& batch : p_fmat_device.p_mat->GetBatches()) { + const auto& row_offset = batch.offset.ConstHostVector(); + for (auto i = 1ull; i 
< row_offset.size(); i++) {
+      row_stride = std::max(row_stride, static_cast(row_offset[i] - row_offset[i - 1]));
+    }
+  }
+
+  const size_t n_offsets = cut.Ptrs().size() - 1;
+  const size_t n_rows = p_fmat_device.row_ptr.Size() - 1;
+  const size_t n_index = n_rows * row_stride;
+  ResizeIndex(n_offsets, n_index, isDense);
+
+  CHECK_GT(cut.Values().size(), 0U);
+
+  uint32_t* offsets = nullptr;
+  if (isDense) {
+    index.ResizeOffset(n_offsets);
+    offsets = index.Offset();
+    qu.memcpy(offsets, cut.Ptrs().data(), sizeof(uint32_t) * n_offsets).wait_and_throw();
+    // for (size_t i = 0; i < n_offsets; ++i) {
+    //   offsets[i] = cut.Ptrs()[i];
+    // }
+  }
+
+  if (isDense) {
+    BinTypeSize current_bin_size = index.GetBinTypeSize();
+    if (current_bin_size == BinTypeSize::kUint8BinsTypeSize) {
+      xgboost::common::Span index_data_span = {index.data(),
+                                               n_index};
+      SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets);
+
+    } else if (current_bin_size == BinTypeSize::kUint16BinsTypeSize) {
+      xgboost::common::Span index_data_span = {index.data(),
+                                               n_index};
+      SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets);
+    } else {
+      CHECK_EQ(current_bin_size, BinTypeSize::kUint32BinsTypeSize);
+      xgboost::common::Span index_data_span = {index.data(),
+                                               n_index};
+      SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets);
+    }
+    /* For sparse DMatrix we have to store index of feature for each bin
+       in index field to choose the right offset. So offset is nullptr and index is not reduced */
+  } else {
+    xgboost::common::Span index_data_span = {index.data(), n_index};
+    SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets);
+  }
+}
+
+}  // namespace common
+}  // namespace sycl
+}  // namespace xgboost
diff --git a/plugin/sycl/data/gradient_index.h b/plugin/sycl/data/gradient_index.h
new file mode 100644
index 000000000000..0787e03aefa4
--- /dev/null
+++ b/plugin/sycl/data/gradient_index.h
@@ -0,0 +1,219 @@
+/*!
+ * Copyright 2017-2023 by Contributors
+ * \file gradient_index.h
+ */
+#ifndef PLUGIN_SYCL_DATA_GRADIENT_INDEX_H_
+#define PLUGIN_SYCL_DATA_GRADIENT_INDEX_H_
+
+#include
+
+#include "../data.h"
+#include "../../src/common/hist_util.h"
+
+#include
+
+namespace xgboost {
+namespace sycl {
+namespace common {
+
+/*!
+ * \brief SYCL implementation of HistogramCuts stored in USM buffers to provide access from device kernels
+ */
+class HistogramCuts {
+ protected:
+  using BinIdx = uint32_t;
+
+ public:
+  HistogramCuts() {}
+
+  explicit HistogramCuts(::sycl::queue qu) {
+    cut_ptrs_.Resize(&qu, 1, 0u);
+  }
+
+  ~HistogramCuts() {
+  }
+
+  void Init(::sycl::queue qu, xgboost::common::HistogramCuts const& cuts) {
+    qu_ = qu;
+    cut_values_.Init(&qu_, cuts.cut_values_.HostVector());
+    cut_ptrs_.Init(&qu_, cuts.cut_ptrs_.HostVector());
+    min_vals_.Init(&qu_, cuts.min_vals_.HostVector());
+  }
+
+  // Getters for USM buffers to pass pointers into device kernels
+  const USMVector& Ptrs() const { return cut_ptrs_; }
+  const USMVector& Values() const { return cut_values_; }
+  const USMVector& MinValues() const { return min_vals_; }
+
+ private:
+  USMVector cut_values_;
+  USMVector cut_ptrs_;
+  USMVector min_vals_;
+  ::sycl::queue qu_;
+};
+
+using BinTypeSize = ::xgboost::common::BinTypeSize;
+
+/*!
+ * \brief Index data and offsets stored in USM buffers to provide access from device kernels + */ +struct Index { + Index() { + SetBinTypeSize(binTypeSize_); + } + Index(const Index& i) = delete; + Index& operator=(Index i) = delete; + Index(Index&& i) = delete; + Index& operator=(Index&& i) = delete; + uint32_t operator[](size_t i) const { + if (!offset_.Empty()) { + return func_(data_.DataConst(), i) + offset_[i%p_]; + } else { + return func_(data_.DataConst(), i); + } + } + void SetBinTypeSize(BinTypeSize binTypeSize) { + binTypeSize_ = binTypeSize; + switch (binTypeSize) { + case BinTypeSize::kUint8BinsTypeSize: + func_ = &GetValueFromUint8; + break; + case BinTypeSize::kUint16BinsTypeSize: + func_ = &GetValueFromUint16; + break; + case BinTypeSize::kUint32BinsTypeSize: + func_ = &GetValueFromUint32; + break; + default: + CHECK(binTypeSize == BinTypeSize::kUint8BinsTypeSize || + binTypeSize == BinTypeSize::kUint16BinsTypeSize || + binTypeSize == BinTypeSize::kUint32BinsTypeSize); + } + } + BinTypeSize GetBinTypeSize() const { + return binTypeSize_; + } + + template + T* data() { + return reinterpret_cast(data_.Data()); + } + + template + const T* data() const { + return reinterpret_cast(data_.DataConst()); + } + + uint32_t* Offset() { + return offset_.Data(); + } + + const uint32_t* Offset() const { + return offset_.DataConst(); + } + + size_t Size() const { + return data_.Size() / (binTypeSize_); + } + + void Resize(const size_t nBytesData) { + data_.Resize(&qu_, nBytesData); + } + + void ResizeOffset(const size_t nDisps) { + offset_.Resize(&qu_, nDisps); + p_ = nDisps; + } + + uint8_t* begin() const { + return data_.Begin(); + } + + uint8_t* end() const { + return data_.End(); + } + + void setQueue(::sycl::queue qu) { + qu_ = qu; + } + + private: + static uint32_t GetValueFromUint8(const uint8_t* t, size_t i) { + return reinterpret_cast(t)[i]; + } + static uint32_t GetValueFromUint16(const uint8_t* t, size_t i) { + return reinterpret_cast(t)[i]; + } + static uint32_t GetValueFromUint32(const uint8_t* t, size_t i) { + return reinterpret_cast(t)[i]; + } + + using Func = uint32_t (*)(const uint8_t*, size_t); + + USMVector data_; + // size of this field is equal to number of features + USMVector offset_; + BinTypeSize binTypeSize_ {BinTypeSize::kUint8BinsTypeSize}; + size_t p_ {1}; + Func func_; + + ::sycl::queue qu_; +}; + + +/*! + * \brief Preprocessed global index matrix, in CSR format, stored in USM buffers + * + * Transform floating values to integer index in histogram + */ +struct GHistIndexMatrix { + /*! \brief row pointer to rows by element position */ + std::vector row_ptr; + USMVector row_ptr_device; + /*! \brief The index data */ + Index index; + /*! \brief hit count of each index */ + std::vector hit_count; + /*! 
\brief The corresponding cuts */
+  xgboost::common::HistogramCuts cut;
+  HistogramCuts cut_device;
+  DMatrix* p_fmat;
+  size_t max_num_bins;
+  size_t nbins;
+  size_t nfeatures;
+  size_t row_stride;
+
+  // Create a global histogram matrix based on a given DMatrix device wrapper
+  void Init(::sycl::queue qu, Context const * ctx,
+            const sycl::DeviceMatrix& p_fmat_device, int max_num_bins);
+
+  template
+  void SetIndexData(::sycl::queue qu, xgboost::common::Span index_data_span,
+                    const sycl::DeviceMatrix &dmat_device,
+                    size_t nbins, size_t row_stride, uint32_t* offsets);
+
+  void ResizeIndex(const size_t n_offsets, const size_t n_index,
+                   const bool isDense);
+
+  inline void GetFeatureCounts(std::vector* counts) const {
+    auto nfeature = cut_device.Ptrs().Size() - 1;
+    for (unsigned fid = 0; fid < nfeature; ++fid) {
+      auto ibegin = cut_device.Ptrs()[fid];
+      auto iend = cut_device.Ptrs()[fid + 1];
+      for (auto i = ibegin; i < iend; ++i) {
+        (*counts)[fid] += hit_count[i];
+      }
+    }
+  }
+  inline bool IsDense() const {
+    return isDense_;
+  }
+
+ private:
+  bool isDense_;
+};
+
+}  // namespace common
+}  // namespace sycl
+}  // namespace xgboost
+#endif  // PLUGIN_SYCL_DATA_GRADIENT_INDEX_H_
diff --git a/plugin/sycl/tree/expand_entry.h b/plugin/sycl/tree/expand_entry.h
new file mode 100644
index 000000000000..61b803f66819
--- /dev/null
+++ b/plugin/sycl/tree/expand_entry.h
@@ -0,0 +1,42 @@
+/*!
+ * Copyright 2017-2021 by Contributors
+ * \file expand_entry.h
+ */
+#ifndef PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_
+#define PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
+#include "../../src/tree/constraints.h"
+#pragma GCC diagnostic pop
+
+namespace xgboost {
+namespace sycl {
+namespace tree {
+/* tree growing policies */
+struct ExpandEntry {
+  static const int kRootNid = 0;
+  static const int kEmptyNid = -1;
+  int nid;
+  int sibling_nid;
+  int depth;
+  bst_float loss_chg;
+  unsigned timestamp;
+  ExpandEntry(int nid, int sibling_nid, int depth, bst_float loss_chg,
+              unsigned tstmp)
+      : nid(nid), sibling_nid(sibling_nid), depth(depth),
+        loss_chg(loss_chg), timestamp(tstmp) {}
+
+  bool IsValid(xgboost::tree::TrainParam const &param, int32_t num_leaves) const {
+    bool ret = loss_chg <= kRtEps ||
+               (param.max_depth > 0 && this->depth == param.max_depth) ||
+               (param.max_leaves > 0 && num_leaves == param.max_leaves);
+    return ret;
+  }
+};
+
+}  // namespace tree
+}  // namespace sycl
+}  // namespace xgboost
+
+#endif  // PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_
diff --git a/plugin/sycl/tree/updater_quantile_hist.cc b/plugin/sycl/tree/updater_quantile_hist.cc
index c6bb87c699ef..d5cb0c16232c
--- a/plugin/sycl/tree/updater_quantile_hist.cc
+++ b/plugin/sycl/tree/updater_quantile_hist.cc
@@ -1192,122 +1192,6 @@ GradStats QuantileHistMaker::Builder::EnumerateSplit
   return GradStats(sum_grad, sum_hess);
 }

-// split row indexes (rid_span) to 2 parts (both stored in rid_buf) depending
-// on comparison of indexes values (idx_span) and split point (split_cond)
-// Handle dense columns
-template
-inline ::sycl::event PartitionDenseKernel(
-    ::sycl::queue* qu,
-    const GHistIndexMatrix& gmat,
-    const RowSetCollection::Elem& rid_span,
-    const size_t fid,
-    const int32_t split_cond,
-    xgboost::common::Span* rid_buf,
-    size_t* parts_size,
-    ::sycl::event priv_event) {
-  const size_t row_stride = gmat.row_stride;
-  const BinIdxType* gradient_index = gmat.index.data();
-  const size_t* rid = rid_span.begin;
-  const size_t range_size =
rid_span.Size(); - const size_t offset = gmat.cut.Ptrs()[fid]; - - size_t* p_rid_buf = rid_buf->data(); - - auto event = qu->submit([&](::sycl::handler& cgh) { - cgh.depends_on(priv_event); - cgh.parallel_for<>(::sycl::range<1>(range_size), [=](::sycl::item<1> nid) { - const size_t id = rid[nid.get_id(0)]; - const int32_t value = static_cast(gradient_index[id * row_stride + fid] + offset); - const bool is_left = value <= split_cond; - if (is_left) { - common::AtomicRef n_left(parts_size[0]); - p_rid_buf[n_left.fetch_add(1)] = id; - } else { - common::AtomicRef n_right(parts_size[1]); - p_rid_buf[range_size - n_right.fetch_add(1) - 1] = id; - } - }); - }); - return event; -} - -// split row indexes (rid_span) to 2 parts (both stored in rid_buf) depending -// on comparison of indexes values (idx_span) and split point (split_cond) -// Handle dense columns -template -inline ::sycl::event PartitionSparseKernel(::sycl::queue* qu, - const GHistIndexMatrix& gmat, - const RowSetCollection::Elem& rid_span, - const size_t fid, - const int32_t split_cond, - xgboost::common::Span* rid_buf, - size_t* parts_size, - ::sycl::event priv_event) { - const size_t row_stride = gmat.row_stride; - const BinIdxType* gradient_index = gmat.index.data(); - const size_t* rid = rid_span.begin; - const size_t range_size = rid_span.Size(); - const uint32_t* cut_ptrs = gmat.cut_device.Ptrs().DataConst(); - const bst_float* cut_vals = gmat.cut_device.Values().DataConst(); - - size_t* p_rid_buf = rid_buf->data(); - auto event = qu->submit([&](::sycl::handler& cgh) { - cgh.depends_on(priv_event); - cgh.parallel_for<>(::sycl::range<1>(range_size), [=](::sycl::item<1> nid) { - const size_t id = rid[nid.get_id(0)]; - - const BinIdxType* gr_index_local = gradient_index + row_stride * id; - const int32_t fid_local = std::lower_bound(gr_index_local, - gr_index_local + row_stride, - cut_ptrs[fid]) - gr_index_local; - const bool is_left = (fid_local >= row_stride || - gr_index_local[fid_local] >= cut_ptrs[fid + 1]) ? 
- default_left : - gr_index_local[fid_local] <= split_cond; - if (is_left) { - common::AtomicRef n_left(parts_size[0]); - p_rid_buf[n_left.fetch_add(1)] = id; - } else { - common::AtomicRef n_right(parts_size[1]); - p_rid_buf[range_size - n_right.fetch_add(1) - 1] = id; - } - }); - }); - return event; -} - -template -template -::sycl::event QuantileHistMaker::Builder::PartitionKernel( - const size_t nid, - const int32_t split_cond, - const GHistIndexMatrix& gmat, - const RegTree::Node& node, - xgboost::common::Span* rid_buf, - size_t* parts_size, - ::sycl::event priv_event) { - const bst_uint fid = node.SplitIndex(); - const bool default_left = node.DefaultLeft(); - - if (gmat.IsDense()) { - if (default_left) { - return PartitionDenseKernel(&qu_, gmat, row_set_collection_[nid], fid, - split_cond, rid_buf, parts_size, priv_event); - } else { - return PartitionDenseKernel(&qu_, gmat, row_set_collection_[nid], fid, - split_cond, rid_buf, parts_size, priv_event); - } - } else { - if (default_left) { - return PartitionSparseKernel(&qu_, gmat, row_set_collection_[nid], fid, - split_cond, rid_buf, parts_size, priv_event); - } else { - return PartitionSparseKernel(&qu_, gmat, row_set_collection_[nid], fid, - split_cond, rid_buf, parts_size, priv_event); - } - } -} - template void QuantileHistMaker::Builder::FindSplitConditions( const std::vector& nodes, @@ -1368,50 +1252,15 @@ void QuantileHistMaker::Builder::ApplySplit( return row_set_collection_[nid].Size(); }); - // Add resize_and_fill method to save one call - auto event = parts_size_.ResizeAndFill(&qu_, 2 * n_nodes, 0); - apply_split_events_.resize(n_nodes); - - for (size_t node_in_set = 0; node_in_set < n_nodes; node_in_set++) { - const int32_t nid = nodes[node_in_set].nid; - ::sycl::event& cur_event = apply_split_events_[node_in_set]; - if (row_set_collection_[nid].Size() > 0) { - const RegTree::Node& node = (*p_tree)[nid]; - xgboost::common::Span rid_buf = partition_builder_.GetData(node_in_set); - size_t* part_size = parts_size_.Data() + 2 * node_in_set; - int32_t split_condition = split_conditions[node_in_set]; - switch (gmat.index.GetBinTypeSize()) { - case common::BinTypeSize::kUint8BinsTypeSize: - cur_event = PartitionKernel(nid, split_condition, gmat, node, - &rid_buf, part_size, event); - break; - case common::BinTypeSize::kUint16BinsTypeSize: - cur_event = PartitionKernel(nid, split_condition, gmat, node, - &rid_buf, part_size, event); - break; - case common::BinTypeSize::kUint32BinsTypeSize: - cur_event = PartitionKernel(nid, split_condition, gmat, node, - &rid_buf, part_size, event); - break; - default: - CHECK(false); // no default behavior - } - } else { - cur_event = ::sycl::event(); - } - } - - ::sycl::event event_cpy = qu_.memcpy(partition_builder_.GetResultRowsPtr(), - parts_size_.DataConst(), - sizeof(size_t) * 2 * n_nodes, - apply_split_events_); + ::sycl::event event; + partition_builder_.Partition(gmat, nodes, row_set_collection_, + split_conditions, p_tree, &event); qu_.wait_and_throw(); - merge_to_array_events_.resize(n_nodes); + for (size_t node_in_set = 0; node_in_set < n_nodes; node_in_set++) { - ::sycl::event& cur_event = merge_to_array_events_[node_in_set]; const int32_t nid = nodes[node_in_set].nid; size_t* data_result = const_cast(row_set_collection_[nid].begin); - cur_event = partition_builder_.MergeToArray(&qu_, node_in_set, data_result, event_cpy); + partition_builder_.MergeToArray(node_in_set, data_result, &event); } qu_.wait_and_throw(); @@ -1487,30 +1336,6 @@ void 
QuantileHistMaker::Builder::InitNewNode(int nid,

 template struct QuantileHistMaker::Builder;
 template struct QuantileHistMaker::Builder;

-template ::sycl::event QuantileHistMaker::Builder::PartitionKernel(
-    const size_t nid, const int32_t split_cond, const GHistIndexMatrix &gmat,
-    const RegTree::Node& node, xgboost::common::Span* rid_buf,
-    size_t* parts_size, ::sycl::event priv_event);
-template ::sycl::event QuantileHistMaker::Builder::PartitionKernel(
-    const size_t nid, const int32_t split_cond, const GHistIndexMatrix &gmat,
-    const RegTree::Node& node, xgboost::common::Span* rid_buf,
-    size_t* parts_size, ::sycl::event priv_event);
-template ::sycl::event QuantileHistMaker::Builder::PartitionKernel(
-    const size_t nid, const int32_t split_cond, const GHistIndexMatrix &gmat,
-    const RegTree::Node& node, xgboost::common::Span* rid_buf,
-    size_t* parts_size, ::sycl::event priv_event);
-template ::sycl::event QuantileHistMaker::Builder::PartitionKernel(
-    const size_t nid, const int32_t split_cond, const GHistIndexMatrix &gmat,
-    const RegTree::Node& node, xgboost::common::Span* rid_buf,
-    size_t* parts_size, ::sycl::event priv_event);
-template ::sycl::event QuantileHistMaker::Builder::PartitionKernel(
-    const size_t nid, const int32_t split_cond, const GHistIndexMatrix &gmat,
-    const RegTree::Node& node, xgboost::common::Span* rid_buf,
-    size_t* parts_size, ::sycl::event priv_event);
-template ::sycl::event QuantileHistMaker::Builder::PartitionKernel(
-    const size_t nid, const int32_t split_cond, const GHistIndexMatrix &gmat,
-    const RegTree::Node& node, xgboost::common::Span* rid_buf,
-    size_t* parts_size, ::sycl::event priv_event);

 XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl")
 .describe("Grow tree using quantized histogram with SYCL.")
diff --git a/plugin/sycl/tree/updater_quantile_hist.h b/plugin/sycl/tree/updater_quantile_hist.h
index 8eb3b3d02c09..0246e88f09a6
--- a/plugin/sycl/tree/updater_quantile_hist.h
+++ b/plugin/sycl/tree/updater_quantile_hist.h
@@ -14,6 +14,7 @@
 #include
 #include
+#include "../data/gradient_index.h"
 #include "../common/hist_util.h"
 #include "../common/row_set.h"
 #include "../common/partition_builder.h"
@@ -262,28 +263,6 @@ class QuantileHistMaker: public TreeUpdater {
   friend class BatchHistRowsAdder;
   friend class DistributedHistRowsAdder;

-  /* tree growing policies */
-  struct ExpandEntry {
-    static const int kRootNid = 0;
-    static const int kEmptyNid = -1;
-    int nid;
-    int sibling_nid;
-    int depth;
-    bst_float loss_chg;
-    unsigned timestamp;
-    ExpandEntry(int nid, int sibling_nid, int depth, bst_float loss_chg,
-                unsigned tstmp)
-        : nid(nid), sibling_nid(sibling_nid), depth(depth),
-          loss_chg(loss_chg), timestamp(tstmp) {}
-
-    bool IsValid(xgboost::tree::TrainParam const &param, int32_t num_leaves) const {
-      bool ret = loss_chg <= kRtEps ||
-                 (param.max_depth > 0 && this->depth == param.max_depth) ||
-                 (param.max_leaves > 0 && num_leaves == param.max_leaves);
-      return ret;
-    }
-  };
-
   struct SplitQuery {
     int nid;
     int fid;
@@ -464,9 +443,6 @@ class QuantileHistMaker: public TreeUpdater {
   constexpr static size_t kNumParallelBuffers = 1;
   std::array, kNumParallelBuffers> hist_buffers_;
   std::array<::sycl::event, kNumParallelBuffers> hist_build_events_;
-  USMVector parts_size_;
-  std::vector parts_size_cpu_;
-  std::vector<::sycl::event> apply_split_events_;
   std::vector<::sycl::event> merge_to_array_events_;
   // rabit::op::Reducer histred_;
   std::unique_ptr> hist_synchronizer_;
@@ -531,7 +507,7 @@ template class
DistributedHistSynchronizer: public HistSynchronizer {
  public:
   using BuilderT = QuantileHistMaker::Builder;
-  using ExpandEntryT = typename BuilderT::ExpandEntry;
+  using ExpandEntryT = ExpandEntry;

   void SyncHistograms(BuilderT* builder,
                       const std::vector& sync_ids,
diff --git a/tests/cpp/plugin/test_sycl_partition_builder.cc b/tests/cpp/plugin/test_sycl_partition_builder.cc
new file mode 100644
index 000000000000..12e2f2f5f933
--- /dev/null
+++ b/tests/cpp/plugin/test_sycl_partition_builder.cc
@@ -0,0 +1,91 @@
+/**
+ * Copyright 2020-2023 by XGBoost contributors
+ */
+#include
+
+#include
+#include
+#include
+
+#include "../../../plugin/sycl/common/partition_builder.h"
+#include "../../../plugin/sycl/device_manager.h"
+#include "../helpers.h"
+
+namespace xgboost::sycl::common {
+
+TEST(SyclPartitionBuilder, BasicTest) {
+  constexpr size_t kNodes = 5;
+  // Number of rows for each node
+  std::vector rows = { 5, 5, 10, 1, 2 };
+
+  DeviceManager device_manager;
+  auto qu = device_manager.GetQueue(DeviceOrd::SYCL_default());
+  PartitionBuilder builder;
+  builder.Init(&qu, kNodes, [&](size_t i) {
+    return rows[i];
+  });
+
+  // We test here only the basics, thus a synthetic partition builder is adopted
+  // Number of rows to go left for each node.
+  std::vector rows_for_left_node = { 2, 0, 7, 1, 2 };
+
+  size_t first_row_id = 0;
+  for (size_t nid = 0; nid < kNodes; ++nid) {
+    size_t n_rows_nodes = rows[nid];
+
+    auto rid_buff = builder.GetData(nid);
+    size_t rid_buff_size = rid_buff.size();
+    auto* rid_buff_ptr = rid_buff.data();
+
+    size_t n_left = rows_for_left_node[nid];
+    size_t n_right = rows[nid] - n_left;
+
+    qu.submit([&](::sycl::handler& cgh) {
+      cgh.parallel_for<>(::sycl::range<1>(n_left), [=](::sycl::id<1> pid) {
+        int row_id = first_row_id + pid[0];
+        rid_buff_ptr[pid[0]] = row_id;
+      });
+    });
+    qu.wait();
+    first_row_id += n_left;
+
+    // We are storing indexes for the right side in the tail of the array to save some memory
+    qu.submit([&](::sycl::handler& cgh) {
+      cgh.parallel_for<>(::sycl::range<1>(n_right), [=](::sycl::id<1> pid) {
+        int row_id = first_row_id + pid[0];
+        rid_buff_ptr[rid_buff_size - pid[0] - 1] = row_id;
+      });
+    });
+    qu.wait();
+    first_row_id += n_right;
+
+    builder.SetNLeftElems(nid, n_left);
+    builder.SetNRightElems(nid, n_right);
+  }
+
+  ::sycl::event event;
+  std::vector v(*std::max_element(rows.begin(), rows.end()));
+  size_t row_id = 0;
+  for (size_t nid = 0; nid < kNodes; ++nid) {
+    builder.MergeToArray(nid, v.data(), &event);
+    qu.wait();
+
+    // Check that row ids for the left side are correct
+    for (size_t j = 0; j < rows_for_left_node[nid]; ++j) {
+      ASSERT_EQ(v[j], row_id++);
+    }
+
+    // Check that row ids for the right side are correct
+    for (size_t j = 0; j < rows[nid] - rows_for_left_node[nid]; ++j) {
+      ASSERT_EQ(v[rows[nid] - j - 1], row_id++);
+    }
+
+    // Check that the numbers of left/right rows are correct
+    size_t n_left = builder.GetNLeftElems(nid);
+    size_t n_right = builder.GetNRightElems(nid);
+    ASSERT_EQ(n_left, rows_for_left_node[nid]);
+    ASSERT_EQ(n_right, (rows[nid] - rows_for_left_node[nid]));
+  }
+}
+
+} // namespace xgboost::sycl::common
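
Note on the partitioning scheme: the kernels added in plugin/sycl/common/partition_builder.h pack rows going left from the front of rid_buf and rows going right from the back, using two atomic counters (parts_size[0] and parts_size[1]), so a single buffer of size n holds both partitions. Below is a minimal host-side sketch of that scheme; it substitutes std::atomic for ::sycl::atomic_ref and a serial loop for parallel_for, and all data and names in it are made up for illustration.

#include <atomic>
#include <cstddef>
#include <iostream>
#include <vector>

// Two-sided partition as in PartitionDenseKernel: rows whose bin value is
// <= split_cond are packed from the front, the rest from the back.
std::vector<size_t> Partition(const std::vector<size_t>& rid,
                              const std::vector<int>& bin_value,
                              int split_cond) {
  const size_t n = rid.size();
  std::vector<size_t> buf(n);
  std::atomic<size_t> n_left{0}, n_right{0};  // plays the role of parts_size[0..1]
  for (size_t i = 0; i < n; ++i) {            // a parallel_for in the real kernel
    const size_t id = rid[i];
    if (bin_value[id] <= split_cond) {
      buf[n_left.fetch_add(1)] = id;
    } else {
      buf[n - n_right.fetch_add(1) - 1] = id;
    }
  }
  return buf;
}

int main() {
  std::vector<size_t> rid = {0, 1, 2, 3, 4};
  std::vector<int> bins = {5, 1, 7, 2, 9};
  auto buf = Partition(rid, bins, 3);            // left: bins <= 3
  for (size_t id : buf) std::cout << id << ' ';  // prints: 1 3 4 2 0 (right side reversed)
  std::cout << '\n';
}

This is also why MergeToArray can copy GetNLeftElems(nid) + GetNRightElems(nid) contiguous elements in one memcpy: both halves live in the same per-node slice of data_.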
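
Note on binning: SearchBin, moved into plugin/sycl/data/gradient_index.cc, maps a feature value to a global bin index with std::upper_bound over the feature's slice [cut_ptrs[f], cut_ptrs[f+1]) of the cut values, clamping so the largest value still falls into the feature's last bin. A self-contained sketch of the same logic, with std::vector in place of the USM pointers and invented cuts; the feature/fvalue signature replaces the Entry argument of the real function.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// cut_ptrs[f] .. cut_ptrs[f+1] delimits feature f's cut values; the bin is
// the first cut greater than fvalue, clamped to the feature's last bin.
uint32_t SearchBin(const std::vector<float>& cut_values,
                   const std::vector<uint32_t>& cut_ptrs,
                   uint32_t feature, float fvalue) {
  const auto beg = cut_ptrs[feature];
  const auto end = cut_ptrs[feature + 1];
  auto it = std::upper_bound(cut_values.begin() + beg,
                             cut_values.begin() + end, fvalue);
  uint32_t idx = it - cut_values.begin();
  if (idx == end) idx -= 1;  // clamp values beyond the last cut
  return idx;
}

int main() {
  // Two features: feature 0 owns bins [0, 3), feature 1 owns bins [3, 5).
  std::vector<float> cut_values = {0.5f, 1.5f, 2.5f, 10.f, 20.f};
  std::vector<uint32_t> cut_ptrs = {0, 3, 5};
  assert(SearchBin(cut_values, cut_ptrs, 0, 1.0f) == 1);
  assert(SearchBin(cut_values, cut_ptrs, 0, 99.f) == 2);  // clamped to last bin
  assert(SearchBin(cut_values, cut_ptrs, 1, 15.f) == 4);
}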
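
Note on index width: GHistIndexMatrix::ResizeIndex picks the narrowest type that can store a bin index. For dense data SetIndexData subtracts the per-feature offset (idx - offsets[j]), so only max_num_bins - 1 has to fit; sparse data keeps global bin indices and always uses 32 bits. A sketch of just that selection rule, with the enum stubbed in (the real enumerators are kUint8BinsTypeSize, kUint16BinsTypeSize, kUint32BinsTypeSize):

#include <cstdint>
#include <iostream>
#include <limits>

enum class BinTypeSize : uint8_t { kUint8 = 1, kUint16 = 2, kUint32 = 4 };

// Same rule as ResizeIndex: dense data stores within-feature indices, so the
// width only has to cover max_num_bins - 1; sparse data needs 32-bit globals.
BinTypeSize ChooseBinTypeSize(size_t max_num_bins, bool is_dense) {
  if (is_dense && max_num_bins - 1 <= std::numeric_limits<uint8_t>::max()) {
    return BinTypeSize::kUint8;
  } else if (is_dense && max_num_bins - 1 <= std::numeric_limits<uint16_t>::max()) {
    return BinTypeSize::kUint16;
  }
  return BinTypeSize::kUint32;
}

int main() {
  std::cout << static_cast<int>(ChooseBinTypeSize(256, true)) << '\n';    // 1 byte
  std::cout << static_cast<int>(ChooseBinTypeSize(70000, true)) << '\n';  // 4 bytes
  std::cout << static_cast<int>(ChooseBinTypeSize(256, false)) << '\n';   // 4 bytes: sparse
}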