make sycl::ExpandEntry more similar to original one (#34)
Co-authored-by: Dmitry Razdoburdin <>
razdoburdin authored Feb 27, 2024
1 parent eb1617e commit 9bc99f0
Showing 3 changed files with 57 additions and 58 deletions.
plugin/sycl/tree/expand_entry.h: 48 changes (28 additions, 20 deletions)
@@ -1,6 +1,6 @@
 /*!
- * Copyright 2017-2021 by Contributors
- * \file updater_quantile_hist.h
+ * Copyright 2017-2024 by Contributors
+ * \file expand_entry.h
  */
 #ifndef PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_
 #define PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_
@@ -9,29 +9,37 @@
 #pragma GCC diagnostic ignored "-Wtautological-constant-compare"
 #include "../../src/tree/constraints.h"
 #pragma GCC diagnostic pop
+#include "../../src/tree/hist/expand_entry.h"

 namespace xgboost {
 namespace sycl {
 namespace tree {
 /* tree growing policies */
-struct ExpandEntry {
-  static const int kRootNid = 0;
-  static const int kEmptyNid = -1;
-  int nid;
-  int sibling_nid;
-  int depth;
-  bst_float loss_chg;
-  unsigned timestamp;
-  ExpandEntry(int nid, int sibling_nid, int depth, bst_float loss_chg,
-              unsigned tstmp)
-      : nid(nid), sibling_nid(sibling_nid), depth(depth),
-        loss_chg(loss_chg), timestamp(tstmp) {}
-
-  bool IsValid(xgboost::tree::TrainParam const &param, int32_t num_leaves) const {
-    bool ret = loss_chg <= kRtEps ||
-               (param.max_depth > 0 && this->depth == param.max_depth) ||
-               (param.max_leaves > 0 && num_leaves == param.max_leaves);
-    return ret;
+struct ExpandEntry : public xgboost::tree::ExpandEntryImpl<ExpandEntry> {
+  static constexpr bst_node_t kRootNid = 0;
+
+  xgboost::tree::SplitEntry split;
+
+  ExpandEntry(int nid, int depth) : ExpandEntryImpl{nid, depth} {}
+
+  inline bst_node_t GetSiblingId(const xgboost::RegTree* p_tree) const {
+    CHECK_EQ((*p_tree)[nid].IsRoot(), false);
+    const size_t parent_id = (*p_tree)[nid].Parent();
+    return GetSiblingId(p_tree, parent_id);
+  }
+
+  inline bst_node_t GetSiblingId(const xgboost::RegTree* p_tree, size_t parent_id) const {
+    return p_tree->IsLeftChild(nid) ? p_tree->RightChild(parent_id)
+                                    : p_tree->LeftChild(parent_id);
+  }
+
+  bool IsValidImpl(xgboost::tree::TrainParam const &param, int32_t num_leaves) const {
+    if (split.loss_chg <= kRtEps) return false;
+    if (split.loss_chg < param.min_split_loss) return false;
+    if (param.max_depth > 0 && depth == param.max_depth) return false;
+    if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false;
+
+    return true;
   }
 };

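For reference, here is a minimal self-contained sketch of the CRTP pattern the reworked struct relies on. It is a simplification, not the actual xgboost headers: SplitEntry, TrainParam, and the exact members of ExpandEntryImpl below are assumptions, but the derived class mirrors the IsValidImpl shown in the diff above.

#include <iostream>

struct SplitEntry { float loss_chg = 0.0f; };  // stand-in for xgboost::tree::SplitEntry
struct TrainParam { int max_depth = 6; int max_leaves = 0; float min_split_loss = 0.0f; };
constexpr float kRtEps = 1e-6f;

// Base class: stores the node id and depth, exposes the accessors used by
// LossGuide(), and forwards IsValid() to the derived IsValidImpl().
template <typename Derived>
struct ExpandEntryImpl {
  int nid;
  int depth;
  int GetNodeId() const { return nid; }
  float GetLossChange() const { return static_cast<const Derived*>(this)->split.loss_chg; }
  bool IsValid(const TrainParam& param, int num_leaves) const {
    return static_cast<const Derived*>(this)->IsValidImpl(param, num_leaves);
  }
};

struct ExpandEntry : public ExpandEntryImpl<ExpandEntry> {
  SplitEntry split;
  ExpandEntry(int nid, int depth) : ExpandEntryImpl{nid, depth} {}
  bool IsValidImpl(const TrainParam& param, int num_leaves) const {
    if (split.loss_chg <= kRtEps) return false;
    if (split.loss_chg < param.min_split_loss) return false;
    if (param.max_depth > 0 && depth == param.max_depth) return false;
    if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false;
    return true;
  }
};

int main() {
  ExpandEntry root(0, 0);
  root.split.loss_chg = 0.5f;
  std::cout << root.IsValid(TrainParam{}, 1) << '\n';  // 1: the node can still be expanded
}

One behavioural detail worth noting: IsValidImpl returns true when a node can still be expanded, whereas the old IsValid returned true when expansion had to stop. That polarity flip is why the loss-guided loop in updater_quantile_hist.cc below now negates the check (!candidate.IsValid(param_, num_leaves)).
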
plugin/sycl/tree/updater_quantile_hist.cc: 59 changes (26 additions, 33 deletions)
@@ -148,10 +148,10 @@ void BatchHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT *builder,
     const auto entry = builder->nodes_for_explicit_hist_build_[i];
     auto& this_hist = builder->hist_[entry.nid];

-    if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
+    if (!(*p_tree)[entry.nid].IsRoot()) {
       const size_t parent_id = (*p_tree)[entry.nid].Parent();
       auto& parent_hist = builder->hist_[parent_id];
-      auto& sibling_hist = builder->hist_[entry.sibling_nid];
+      auto& sibling_hist = builder->hist_[entry.GetSiblingId(p_tree, parent_id)];
       hist_sync_events_[i] = common::SubtractionHist(builder->qu_, &sibling_hist, parent_hist,
                                                      this_hist, nbins, ::sycl::event());
     }
@@ -174,14 +174,15 @@ void DistributedHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT* builder
     auto& this_local = builder->hist_local_worker_[entry.nid];
     common::CopyHist(builder->qu_, &this_local, this_hist, nbins);

-    if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
+    if (!(*p_tree)[entry.nid].IsRoot()) {
       const size_t parent_id = (*p_tree)[entry.nid].Parent();
+      auto sibling_nid = entry.GetSiblingId(p_tree, parent_id);
       auto& parent_hist = builder->hist_local_worker_[parent_id];
-      auto& sibling_hist = builder->hist_[entry.sibling_nid];
+      auto& sibling_hist = builder->hist_[sibling_nid];
       common::SubtractionHist(builder->qu_, &sibling_hist, parent_hist,
                               this_hist, nbins, ::sycl::event());
       // Store posible parent node
-      auto& sibling_local = builder->hist_local_worker_[entry.sibling_nid];
+      auto& sibling_local = builder->hist_local_worker_[sibling_nid];
       common::CopyHist(builder->qu_, &sibling_local, sibling_hist, nbins);
     }
   }
@@ -204,9 +205,10 @@ void DistributedHistSynchronizer<GradientSumT>::ParallelSubtractionHist(
     if (!((*p_tree)[entry.nid].IsLeftChild())) {
       auto& this_hist = builder->hist_[entry.nid];

-      if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
-        auto& parent_hist = builder->hist_[(*p_tree)[entry.nid].Parent()];
-        auto& sibling_hist = builder->hist_[entry.sibling_nid];
+      if (!(*p_tree)[entry.nid].IsRoot()) {
+        const size_t parent_id = (*p_tree)[entry.nid].Parent();
+        auto& parent_hist = builder->hist_[parent_id];
+        auto& sibling_hist = builder->hist_[entry.GetSiblingId(p_tree, parent_id)];
         common::SubtractionHist(builder->qu_, &this_hist, parent_hist,
                                 sibling_hist, nbins, ::sycl::event());
       }
@@ -316,9 +318,9 @@ void QuantileHistMaker::Builder<GradientSumT>::BuildHistogramsLossGuide(
   nodes_for_subtraction_trick_.clear();
   nodes_for_explicit_hist_build_.push_back(entry);

-  if (entry.sibling_nid > -1) {
-    nodes_for_subtraction_trick_.emplace_back(entry.sibling_nid, entry.nid,
-                                              p_tree->GetDepth(entry.sibling_nid), 0.0f, 0);
+  if (!(*p_tree)[entry.nid].IsRoot()) {
+    auto sibling_id = entry.GetSiblingId(p_tree);
+    nodes_for_subtraction_trick_.emplace_back(sibling_id, p_tree->GetDepth(sibling_id));
   }

   std::vector<int> sync_ids;
@@ -390,7 +392,6 @@ void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToTree(
     RegTree *p_tree,
     int *num_leaves,
     int depth,
-    unsigned *timestamp,
     std::vector<ExpandEntry>* nodes_for_apply_split,
     std::vector<ExpandEntry>* temp_qexpand_depth) {
   auto evaluator = tree_evaluator_.GetEvaluator();
@@ -417,10 +418,8 @@

       int left_id = (*p_tree)[nid].LeftChild();
       int right_id = (*p_tree)[nid].RightChild();
-      temp_qexpand_depth->push_back(ExpandEntry(left_id, right_id,
-          p_tree->GetDepth(left_id), 0.0, (*timestamp)++));
-      temp_qexpand_depth->push_back(ExpandEntry(right_id, left_id,
-          p_tree->GetDepth(right_id), 0.0, (*timestamp)++));
+      temp_qexpand_depth->push_back(ExpandEntry(left_id, p_tree->GetDepth(left_id)));
+      temp_qexpand_depth->push_back(ExpandEntry(right_id, p_tree->GetDepth(right_id)));
       // - 1 parent + 2 new children
       (*num_leaves)++;
     }
@@ -433,12 +432,11 @@ void QuantileHistMaker::Builder<GradientSumT>::EvaluateAndApplySplits(
     RegTree *p_tree,
     int *num_leaves,
     int depth,
-    unsigned *timestamp,
     std::vector<ExpandEntry> *temp_qexpand_depth) {
   EvaluateSplits(qexpand_depth_wise_, gmat, hist_, *p_tree);

   std::vector<ExpandEntry> nodes_for_apply_split;
-  AddSplitsToTree(gmat, p_tree, num_leaves, depth, timestamp,
+  AddSplitsToTree(gmat, p_tree, num_leaves, depth,
                   &nodes_for_apply_split, temp_qexpand_depth);
   ApplySplit(nodes_for_apply_split, gmat, hist_, p_tree);
 }
@@ -486,12 +484,11 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandWithDepthWise(
     RegTree *p_tree,
     const std::vector<GradientPair> &gpair,
     const USMVector<GradientPair, MemoryType::on_device> &gpair_device) {
-  unsigned timestamp = 0;
   int num_leaves = 0;

   // in depth_wise growing, we feed loss_chg with 0.0 since it is not used anyway
-  qexpand_depth_wise_.emplace_back(ExpandEntry(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
-      p_tree->GetDepth(ExpandEntry::kRootNid), 0.0, timestamp++));
+  qexpand_depth_wise_.emplace_back(ExpandEntry::kRootNid,
+                                   p_tree->GetDepth(ExpandEntry::kRootNid));
   ++num_leaves;
   for (int depth = 0; depth < param_.max_depth + 1; depth++) {
     std::vector<int> sync_ids;
@@ -503,7 +500,7 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandWithDepthWise(
     hist_synchronizer_->SyncHistograms(this, sync_ids, p_tree);
     BuildNodeStats(gmat, p_fmat, p_tree, gpair);

-    EvaluateAndApplySplits(gmat, p_tree, &num_leaves, depth, &timestamp,
+    EvaluateAndApplySplits(gmat, p_tree, &num_leaves, depth,
                            &temp_qexpand_depth);

     // clean up
@@ -527,18 +524,16 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandWithLossGuide(
     const std::vector<GradientPair> &gpair,
     const USMVector<GradientPair, MemoryType::on_device> &gpair_device) {
   builder_monitor_.Start("ExpandWithLossGuide");
-  unsigned timestamp = 0;
   int num_leaves = 0;
   const auto lr = param_.learning_rate;

-  ExpandEntry node(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
-                   p_tree->GetDepth(0), 0.0f, timestamp++);
+  ExpandEntry node(ExpandEntry::kRootNid, p_tree->GetDepth(ExpandEntry::kRootNid));
   BuildHistogramsLossGuide(node, gmat, p_tree, gpair_device);

   this->InitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, *p_tree);

   this->EvaluateSplits({node}, gmat, hist_, *p_tree);
-  node.loss_chg = snode_[ExpandEntry::kRootNid].best.loss_chg;
+  node.split.loss_chg = snode_[ExpandEntry::kRootNid].best.loss_chg;

   qexpand_loss_guided_->push(node);
   ++num_leaves;
@@ -547,7 +542,7 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandWithLossGuide(
     const ExpandEntry candidate = qexpand_loss_guided_->top();
     const int nid = candidate.nid;
     qexpand_loss_guided_->pop();
-    if (candidate.IsValid(param_, num_leaves)) {
+    if (!candidate.IsValid(param_, num_leaves)) {
       (*p_tree)[nid].SetLeaf(snode_[nid].weight * lr);
     } else {
       auto evaluator = tree_evaluator_.GetEvaluator();
@@ -566,10 +561,8 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandWithLossGuide(
       const int cleft = (*p_tree)[nid].LeftChild();
       const int cright = (*p_tree)[nid].RightChild();

-      ExpandEntry left_node(cleft, cright, p_tree->GetDepth(cleft),
-                            0.0f, timestamp++);
-      ExpandEntry right_node(cright, cleft, p_tree->GetDepth(cright),
-                             0.0f, timestamp++);
+      ExpandEntry left_node(cleft, p_tree->GetDepth(cleft));
+      ExpandEntry right_node(cright, p_tree->GetDepth(cright));

       if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
         BuildHistogramsLossGuide(left_node, gmat, p_tree, gpair_device);
@@ -585,8 +578,8 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandWithLossGuide(
       interaction_constraints_.Split(nid, featureid, cleft, cright);

       this->EvaluateSplits({left_node, right_node}, gmat, hist_, *p_tree);
-      left_node.loss_chg = snode_[cleft].best.loss_chg;
-      right_node.loss_chg = snode_[cright].best.loss_chg;
+      left_node.split.loss_chg = snode_[cleft].best.loss_chg;
+      right_node.split.loss_chg = snode_[cright].best.loss_chg;

       qexpand_loss_guided_->push(left_node);
       qexpand_loss_guided_->push(right_node);
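
The hunks above all feed the histogram subtraction trick: once a parent's histogram and one child's histogram are known, the sibling's histogram is their element-wise difference, so only one of the two children needs a full build (the loss-guided path picks the one with fewer rows). A minimal CPU sketch of the idea, with one value per bin instead of the GradientPair bins that common::SubtractionHist processes on the SYCL queue; the function below is illustrative, not part of the plugin:

#include <vector>

// sibling[bin] = parent[bin] - built_child[bin] for every histogram bin:
// each training row of the parent lands in exactly one of the two children,
// so the parent histogram is the element-wise sum of the children's.
std::vector<double> SubtractHist(const std::vector<double>& parent,
                                 const std::vector<double>& built_child) {
  std::vector<double> sibling(parent.size());
  for (size_t bin = 0; bin < parent.size(); ++bin) {
    sibling[bin] = parent[bin] - built_child[bin];
  }
  return sibling;
}

The change in this file is only about where the sibling id comes from: instead of reading a stored entry.sibling_nid, the synchronizers now ask the tree via entry.GetSiblingId(p_tree, parent_id).
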
plugin/sycl/tree/updater_quantile_hist.h: 8 changes (3 additions, 5 deletions)
@@ -367,15 +367,13 @@ class QuantileHistMaker: public TreeUpdater {
       RegTree *p_tree,
       int *num_leaves,
       int depth,
-      unsigned *timestamp,
       std::vector<ExpandEntry> *temp_qexpand_depth);

   void AddSplitsToTree(
       const GHistIndexMatrix &gmat,
       RegTree *p_tree,
       int *num_leaves,
       int depth,
-      unsigned *timestamp,
       std::vector<ExpandEntry>* nodes_for_apply_split,
       std::vector<ExpandEntry>* temp_qexpand_depth);

@@ -388,10 +386,10 @@ class QuantileHistMaker: public TreeUpdater {
   void ReduceHists(const std::vector<int>& sync_ids, size_t nbins);

   inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) {
-    if (lhs.loss_chg == rhs.loss_chg) {
-      return lhs.timestamp > rhs.timestamp;  // favor small timestamp
+    if (lhs.GetLossChange() == rhs.GetLossChange()) {
+      return lhs.GetNodeId() > rhs.GetNodeId();  // favor small timestamp
     } else {
-      return lhs.loss_chg < rhs.loss_chg;  // favor large loss_chg
+      return lhs.GetLossChange() < rhs.GetLossChange();  // favor large loss_chg
     }
   }
   // --data fields--
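
For context, LossGuide above is the ordering used by the loss-guided expansion queue. Here is a standalone sketch with a simplified entry type (the Entry struct and the queue setup are illustrative assumptions, not plugin code); returning true marks lhs as lower priority, so the queue pops the largest loss change first and, now that the timestamp is gone, breaks ties in favour of the smaller node id:

#include <iostream>
#include <queue>
#include <vector>

struct Entry {
  int nid;
  float loss_chg;
  int GetNodeId() const { return nid; }
  float GetLossChange() const { return loss_chg; }
};

// Same shape as the comparator in the diff: true means lhs sorts below rhs.
bool LossGuide(Entry lhs, Entry rhs) {
  if (lhs.GetLossChange() == rhs.GetLossChange()) {
    return lhs.GetNodeId() > rhs.GetNodeId();          // favor the smaller node id on ties
  } else {
    return lhs.GetLossChange() < rhs.GetLossChange();  // favor the larger loss change
  }
}

int main() {
  std::priority_queue<Entry, std::vector<Entry>, decltype(&LossGuide)> q(&LossGuide);
  q.push({1, 0.3f});
  q.push({2, 0.9f});
  q.push({3, 0.9f});
  while (!q.empty()) {
    std::cout << q.top().nid << ' ';  // prints: 2 3 1
    q.pop();
  }
  std::cout << '\n';
}

Since node ids are handed out as the tree grows, the smaller id plays roughly the role the old timestamp did, which is why the "favor small timestamp" comment in the diff still reads correctly in spirit.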
