Skip to content

Commit

Permalink
refactor partition builder and add basic.test (#29)
Browse files Browse the repository at this point in the history
* refactor partition builder and add basic.test

* linting

---------

Co-authored-by: Dmitry Razdoburdin <>
  • Loading branch information
razdoburdin authored Jan 30, 2024
1 parent ca540cf commit 025b286
Show file tree
Hide file tree
Showing 11 changed files with 752 additions and 583 deletions.
1 change: 1 addition & 0 deletions plugin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ endif (PLUGIN_DENSE_PARSER)
if (PLUGIN_SYCL)
set(CMAKE_CXX_COMPILER "icpx")
add_library(plugin_sycl OBJECT
${xgboost_SOURCE_DIR}/plugin/sycl/data/gradient_index.cc
${xgboost_SOURCE_DIR}/plugin/sycl/common/hist_util.cc
${xgboost_SOURCE_DIR}/plugin/sycl/objective/regression_obj.cc
${xgboost_SOURCE_DIR}/plugin/sycl/objective/multiclass_obj.cc
Expand Down
145 changes: 1 addition & 144 deletions plugin/sycl/common/hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <limits>
#include <algorithm>

#include "../data/gradient_index.h"
#include "hist_util.h"

#include <CL/sycl.hpp>
Expand All @@ -14,18 +15,6 @@ namespace xgboost {
namespace sycl {
namespace common {

uint32_t SearchBin(const bst_float* cut_values, const uint32_t* cut_ptrs, Entry const& e) {
auto beg = cut_ptrs[e.index];
auto end = cut_ptrs[e.index + 1];
const auto &values = cut_values;
auto it = std::upper_bound(cut_values + beg, cut_values + end, e.fvalue);
uint32_t idx = it - cut_values;
if (idx == end) {
idx -= 1;
}
return idx;
}

template <typename BinIdxType>
void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) {
const size_t total_len = end - begin;
Expand All @@ -49,138 +38,6 @@ void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) {
}
}

template <typename BinIdxType>
void GHistIndexMatrix::SetIndexData(::sycl::queue qu,
xgboost::common::Span<BinIdxType> index_data_span,
const DeviceMatrix &dmat_device,
size_t nbins,
size_t row_stride,
uint32_t* offsets) {
if (hit_count.size() == 0) return;
const xgboost::Entry *data_ptr = dmat_device.data.DataConst();
const bst_row_t *offset_vec = dmat_device.row_ptr.DataConst();
const size_t num_rows = dmat_device.row_ptr.Size() - 1;
BinIdxType* index_data = index_data_span.data();
const bst_float* cut_values = cut_device.Values().DataConst();
const uint32_t* cut_ptrs = cut_device.Ptrs().DataConst();
::sycl::buffer<size_t, 1> hit_count_buf(hit_count.data(), hit_count.size());

USMVector<BinIdxType> sort_buf(&qu, num_rows * row_stride);
BinIdxType* sort_data = sort_buf.Data();

qu.submit([&](::sycl::handler& cgh) {
auto hit_count_acc = hit_count_buf.template get_access<::sycl::access::mode::atomic>(cgh);
cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::item<1> pid) {
const size_t i = pid.get_id(0);
const size_t ibegin = offset_vec[i];
const size_t iend = offset_vec[i + 1];
const size_t size = iend - ibegin;
const size_t start = i * row_stride;
for (bst_uint j = 0; j < size; ++j) {
uint32_t idx = SearchBin(cut_values, cut_ptrs, data_ptr[ibegin + j]);
index_data[start + j] = offsets ? idx - offsets[j] : idx;
::sycl::atomic_fetch_add<size_t>(hit_count_acc[idx], 1);
}
if (!offsets) {
// Sparse case only
mergeSort<BinIdxType>(index_data + start, index_data + start + size, sort_data + start);
for (bst_uint j = size; j < row_stride; ++j) {
index_data[start + j] = nbins;
}
}
});
}).wait();
}

void GHistIndexMatrix::ResizeIndex(const size_t n_offsets,
const size_t n_index,
const bool isDense) {
if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
index.SetBinTypeSize(BinTypeSize::kUint8BinsTypeSize);
index.Resize((sizeof(uint8_t)) * n_index);
} else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense) {
index.SetBinTypeSize(BinTypeSize::kUint16BinsTypeSize);
index.Resize((sizeof(uint16_t)) * n_index);
} else {
index.SetBinTypeSize(BinTypeSize::kUint32BinsTypeSize);
index.Resize((sizeof(uint32_t)) * n_index);
}
}

void GHistIndexMatrix::Init(::sycl::queue qu,
Context const * ctx,
const DeviceMatrix& p_fmat_device,
int max_bins) {
nfeatures = p_fmat_device.p_mat->Info().num_col_;

cut = xgboost::common::SketchOnDMatrix(ctx, p_fmat_device.p_mat, max_bins);
cut_device.Init(qu, cut);

max_num_bins = max_bins;
const uint32_t nbins = cut.Ptrs().back();
this->nbins = nbins;
hit_count.resize(nbins, 0);

this->p_fmat = p_fmat_device.p_mat;
const bool isDense = p_fmat_device.p_mat->IsDense();
this->isDense_ = isDense;

row_ptr = std::vector<size_t>(p_fmat_device.row_ptr.Begin(), p_fmat_device.row_ptr.End());
row_ptr_device = p_fmat_device.row_ptr;

index.setQueue(qu);

row_stride = 0;
for (const auto& batch : p_fmat_device.p_mat->GetBatches<SparsePage>()) {
const auto& row_offset = batch.offset.ConstHostVector();
for (auto i = 1ull; i < row_offset.size(); i++) {
row_stride = std::max(row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
}
}

const size_t n_offsets = cut.Ptrs().size() - 1;
const size_t n_rows = p_fmat_device.row_ptr.Size() - 1;
const size_t n_index = n_rows * row_stride;
ResizeIndex(n_offsets, n_index, isDense);

CHECK_GT(cut.Values().size(), 0U);

uint32_t* offsets = nullptr;
if (isDense) {
index.ResizeOffset(n_offsets);
offsets = index.Offset();
qu.memcpy(offsets, cut.Ptrs().data(), sizeof(uint32_t) * n_offsets).wait_and_throw();
// for (size_t i = 0; i < n_offsets; ++i) {
// offsets[i] = cut.Ptrs()[i];
// }
}

if (isDense) {
BinTypeSize curent_bin_size = index.GetBinTypeSize();
if (curent_bin_size == BinTypeSize::kUint8BinsTypeSize) {
xgboost::common::Span<uint8_t> index_data_span = {index.data<uint8_t>(),
n_index};
SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets);

} else if (curent_bin_size == BinTypeSize::kUint16BinsTypeSize) {
xgboost::common::Span<uint16_t> index_data_span = {index.data<uint16_t>(),
n_index};
SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets);
} else {
CHECK_EQ(curent_bin_size, BinTypeSize::kUint32BinsTypeSize);
xgboost::common::Span<uint32_t> index_data_span = {index.data<uint32_t>(),
n_index};
SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets);
}
/* For sparse DMatrix we have to store index of feature for each bin
in index field to chose right offset. So offset is nullptr and index is not reduced */
} else {
xgboost::common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets);
}
}

/*!
* \brief Fill histogram with zeroes
*/
Expand Down
201 changes: 0 additions & 201 deletions plugin/sycl/common/hist_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,209 +21,8 @@ namespace common {
template<typename GradientSumT, MemoryType memory_type = MemoryType::shared>
using GHistRow = USMVector<xgboost::detail::GradientPairInternal<GradientSumT>, memory_type>;

template <typename T>
using AtomicRef = ::sycl::atomic_ref<T,
::sycl::memory_order::relaxed,
::sycl::memory_scope::device,
::sycl::access::address_space::ext_intel_global_device_space>;

/*!
* \brief SYCL implementation of HistogramCuts stored in USM buffers to provide access from device kernels
*/
class HistogramCuts {
protected:
using BinIdx = uint32_t;

public:
HistogramCuts() {}

explicit HistogramCuts(::sycl::queue qu) {
cut_ptrs_.Resize(&qu_, 1, 0);
}

~HistogramCuts() {
}

void Init(::sycl::queue qu, xgboost::common::HistogramCuts const& cuts) {
qu_ = qu;
cut_values_.Init(&qu_, cuts.cut_values_.HostVector());
cut_ptrs_.Init(&qu_, cuts.cut_ptrs_.HostVector());
min_vals_.Init(&qu_, cuts.min_vals_.HostVector());
}

// Getters for USM buffers to pass pointers into device kernels
const USMVector<uint32_t>& Ptrs() const { return cut_ptrs_; }
const USMVector<float>& Values() const { return cut_values_; }
const USMVector<float>& MinValues() const { return min_vals_; }

private:
USMVector<bst_float> cut_values_;
USMVector<uint32_t> cut_ptrs_;
USMVector<float> min_vals_;
::sycl::queue qu_;
};

using BinTypeSize = ::xgboost::common::BinTypeSize;

/*!
* \brief Index data and offsets stored in USM buffers to provide access from device kernels
*/
struct Index {
Index() {
SetBinTypeSize(binTypeSize_);
}
Index(const Index& i) = delete;
Index& operator=(Index i) = delete;
Index(Index&& i) = delete;
Index& operator=(Index&& i) = delete;
uint32_t operator[](size_t i) const {
if (!offset_.Empty()) {
return func_(data_.DataConst(), i) + offset_[i%p_];
} else {
return func_(data_.DataConst(), i);
}
}
void SetBinTypeSize(BinTypeSize binTypeSize) {
binTypeSize_ = binTypeSize;
switch (binTypeSize) {
case BinTypeSize::kUint8BinsTypeSize:
func_ = &GetValueFromUint8;
break;
case BinTypeSize::kUint16BinsTypeSize:
func_ = &GetValueFromUint16;
break;
case BinTypeSize::kUint32BinsTypeSize:
func_ = &GetValueFromUint32;
break;
default:
CHECK(binTypeSize == BinTypeSize::kUint8BinsTypeSize ||
binTypeSize == BinTypeSize::kUint16BinsTypeSize ||
binTypeSize == BinTypeSize::kUint32BinsTypeSize);
}
}
BinTypeSize GetBinTypeSize() const {
return binTypeSize_;
}

template<typename T>
T* data() {
return reinterpret_cast<T*>(data_.Data());
}

template<typename T>
const T* data() const {
return reinterpret_cast<const T*>(data_.DataConst());
}

uint32_t* Offset() {
return offset_.Data();
}

const uint32_t* Offset() const {
return offset_.DataConst();
}

size_t Size() const {
return data_.Size() / (binTypeSize_);
}

void Resize(const size_t nBytesData) {
data_.Resize(&qu_, nBytesData);
}

void ResizeOffset(const size_t nDisps) {
offset_.Resize(&qu_, nDisps);
p_ = nDisps;
}

uint8_t* begin() const {
return data_.Begin();
}

uint8_t* end() const {
return data_.End();
}

void setQueue(::sycl::queue qu) {
qu_ = qu;
}

private:
static uint32_t GetValueFromUint8(const uint8_t* t, size_t i) {
return reinterpret_cast<const uint8_t*>(t)[i];
}
static uint32_t GetValueFromUint16(const uint8_t* t, size_t i) {
return reinterpret_cast<const uint16_t*>(t)[i];
}
static uint32_t GetValueFromUint32(const uint8_t* t, size_t i) {
return reinterpret_cast<const uint32_t*>(t)[i];
}

using Func = uint32_t (*)(const uint8_t*, size_t);

USMVector<uint8_t, MemoryType::on_device> data_;
// size of this field is equal to number of features
USMVector<uint32_t, MemoryType::on_device> offset_;
BinTypeSize binTypeSize_ {BinTypeSize::kUint8BinsTypeSize};
size_t p_ {1};
Func func_;

::sycl::queue qu_;
};


/*!
* \brief Preprocessed global index matrix, in CSR format, stored in USM buffers
*
* Transform floating values to integer index in histogram
*/
struct GHistIndexMatrix {
/*! \brief row pointer to rows by element position */
std::vector<size_t> row_ptr;
USMVector<size_t> row_ptr_device;
/*! \brief The index data */
Index index;
/*! \brief hit count of each index */
std::vector<size_t> hit_count;
/*! \brief The corresponding cuts */
xgboost::common::HistogramCuts cut;
HistogramCuts cut_device;
DMatrix* p_fmat;
size_t max_num_bins;
size_t nbins;
size_t nfeatures;
size_t row_stride;

// Create a global histogram matrix based on a given DMatrix device wrapper
void Init(::sycl::queue qu, Context const * ctx,
const sycl::DeviceMatrix& p_fmat_device, int max_num_bins);

template <typename BinIdxType>
void SetIndexData(::sycl::queue qu, xgboost::common::Span<BinIdxType> index_data_span,
const sycl::DeviceMatrix &dmat_device,
size_t nbins, size_t row_stride, uint32_t* offsets);

void ResizeIndex(const size_t n_offsets, const size_t n_index,
const bool isDense);

inline void GetFeatureCounts(std::vector<size_t>* counts) const {
auto nfeature = cut_device.Ptrs().Size() - 1;
for (unsigned fid = 0; fid < nfeature; ++fid) {
auto ibegin = cut_device.Ptrs()[fid];
auto iend = cut_device.Ptrs()[fid + 1];
for (auto i = ibegin; i < iend; ++i) {
(*counts)[fid] += hit_count[i];
}
}
}
inline bool IsDense() const {
return isDense_;
}

private:
bool isDense_;
};

class ColumnMatrix;

/*!
Expand Down
Loading

0 comments on commit 025b286

Please sign in to comment.