From 9bc99f08bd3017a5d804a6a3adab088b9787b8da Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Tue, 27 Feb 2024 14:02:44 +0100 Subject: [PATCH] make sycl::ExpandEntry more similar to original one (#34) Co-authored-by: Dmitry Razdoburdin <> --- plugin/sycl/tree/expand_entry.h | 48 ++++++++++-------- plugin/sycl/tree/updater_quantile_hist.cc | 59 ++++++++++------------- plugin/sycl/tree/updater_quantile_hist.h | 8 ++- 3 files changed, 57 insertions(+), 58 deletions(-) diff --git a/plugin/sycl/tree/expand_entry.h b/plugin/sycl/tree/expand_entry.h index 61b803f66819..2520ff95db5a 100644 --- a/plugin/sycl/tree/expand_entry.h +++ b/plugin/sycl/tree/expand_entry.h @@ -1,6 +1,6 @@ /*! - * Copyright 2017-2021 by Contributors - * \file updater_quantile_hist.h + * Copyright 2017-2024 by Contributors + * \file expand_entry.h */ #ifndef PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_ #define PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_ @@ -9,29 +9,37 @@ #pragma GCC diagnostic ignored "-Wtautological-constant-compare" #include "../../src/tree/constraints.h" #pragma GCC diagnostic pop +#include "../../src/tree/hist/expand_entry.h" namespace xgboost { namespace sycl { namespace tree { /* tree growing policies */ -struct ExpandEntry { - static const int kRootNid = 0; - static const int kEmptyNid = -1; - int nid; - int sibling_nid; - int depth; - bst_float loss_chg; - unsigned timestamp; - ExpandEntry(int nid, int sibling_nid, int depth, bst_float loss_chg, - unsigned tstmp) - : nid(nid), sibling_nid(sibling_nid), depth(depth), - loss_chg(loss_chg), timestamp(tstmp) {} - - bool IsValid(xgboost::tree::TrainParam const &param, int32_t num_leaves) const { - bool ret = loss_chg <= kRtEps || - (param.max_depth > 0 && this->depth == param.max_depth) || - (param.max_leaves > 0 && num_leaves == param.max_leaves); - return ret; +struct ExpandEntry : public xgboost::tree::ExpandEntryImpl { + static constexpr bst_node_t kRootNid = 0; + + xgboost::tree::SplitEntry split; + + ExpandEntry(int nid, int depth) : 
ExpandEntryImpl{nid, depth} {} + + inline bst_node_t GetSiblingId(const xgboost::RegTree* p_tree) const { + CHECK_EQ((*p_tree)[nid].IsRoot(), false); + const size_t parent_id = (*p_tree)[nid].Parent(); + return GetSiblingId(p_tree, parent_id); + } + + inline bst_node_t GetSiblingId(const xgboost::RegTree* p_tree, size_t parent_id) const { + return p_tree->IsLeftChild(nid) ? p_tree->RightChild(parent_id) + : p_tree->LeftChild(parent_id); + } + + bool IsValidImpl(xgboost::tree::TrainParam const &param, int32_t num_leaves) const { + if (split.loss_chg <= kRtEps) return false; + if (split.loss_chg < param.min_split_loss) return false; + if (param.max_depth > 0 && depth == param.max_depth) return false; + if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false; + + return true; } }; diff --git a/plugin/sycl/tree/updater_quantile_hist.cc b/plugin/sycl/tree/updater_quantile_hist.cc index 3411b66e7dc5..5b0ef4517cb0 100644 --- a/plugin/sycl/tree/updater_quantile_hist.cc +++ b/plugin/sycl/tree/updater_quantile_hist.cc @@ -148,10 +148,10 @@ void BatchHistSynchronizer::SyncHistograms(BuilderT *builder, const auto entry = builder->nodes_for_explicit_hist_build_[i]; auto& this_hist = builder->hist_[entry.nid]; - if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) { + if (!(*p_tree)[entry.nid].IsRoot()) { const size_t parent_id = (*p_tree)[entry.nid].Parent(); auto& parent_hist = builder->hist_[parent_id]; - auto& sibling_hist = builder->hist_[entry.sibling_nid]; + auto& sibling_hist = builder->hist_[entry.GetSiblingId(p_tree, parent_id)]; hist_sync_events_[i] = common::SubtractionHist(builder->qu_, &sibling_hist, parent_hist, this_hist, nbins, ::sycl::event()); } @@ -174,14 +174,15 @@ void DistributedHistSynchronizer::SyncHistograms(BuilderT* builder auto& this_local = builder->hist_local_worker_[entry.nid]; common::CopyHist(builder->qu_, &this_local, this_hist, nbins); - if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) { + if 
(!(*p_tree)[entry.nid].IsRoot()) { const size_t parent_id = (*p_tree)[entry.nid].Parent(); + auto sibling_nid = entry.GetSiblingId(p_tree, parent_id); auto& parent_hist = builder->hist_local_worker_[parent_id]; - auto& sibling_hist = builder->hist_[entry.sibling_nid]; + auto& sibling_hist = builder->hist_[sibling_nid]; common::SubtractionHist(builder->qu_, &sibling_hist, parent_hist, this_hist, nbins, ::sycl::event()); // Store posible parent node - auto& sibling_local = builder->hist_local_worker_[entry.sibling_nid]; + auto& sibling_local = builder->hist_local_worker_[sibling_nid]; common::CopyHist(builder->qu_, &sibling_local, sibling_hist, nbins); } } @@ -204,9 +205,10 @@ void DistributedHistSynchronizer::ParallelSubtractionHist( if (!((*p_tree)[entry.nid].IsLeftChild())) { auto& this_hist = builder->hist_[entry.nid]; - if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) { - auto& parent_hist = builder->hist_[(*p_tree)[entry.nid].Parent()]; - auto& sibling_hist = builder->hist_[entry.sibling_nid]; + if (!(*p_tree)[entry.nid].IsRoot()) { + const size_t parent_id = (*p_tree)[entry.nid].Parent(); + auto& parent_hist = builder->hist_[parent_id]; + auto& sibling_hist = builder->hist_[entry.GetSiblingId(p_tree, parent_id)]; common::SubtractionHist(builder->qu_, &this_hist, parent_hist, sibling_hist, nbins, ::sycl::event()); } @@ -316,9 +318,9 @@ void QuantileHistMaker::Builder::BuildHistogramsLossGuide( nodes_for_subtraction_trick_.clear(); nodes_for_explicit_hist_build_.push_back(entry); - if (entry.sibling_nid > -1) { - nodes_for_subtraction_trick_.emplace_back(entry.sibling_nid, entry.nid, - p_tree->GetDepth(entry.sibling_nid), 0.0f, 0); + if (!(*p_tree)[entry.nid].IsRoot()) { + auto sibling_id = entry.GetSiblingId(p_tree); + nodes_for_subtraction_trick_.emplace_back(sibling_id, p_tree->GetDepth(sibling_id)); } std::vector sync_ids; @@ -390,7 +392,6 @@ void QuantileHistMaker::Builder::AddSplitsToTree( RegTree *p_tree, int *num_leaves, int depth, - 
unsigned *timestamp, std::vector* nodes_for_apply_split, std::vector* temp_qexpand_depth) { auto evaluator = tree_evaluator_.GetEvaluator(); @@ -417,10 +418,8 @@ void QuantileHistMaker::Builder::AddSplitsToTree( int left_id = (*p_tree)[nid].LeftChild(); int right_id = (*p_tree)[nid].RightChild(); - temp_qexpand_depth->push_back(ExpandEntry(left_id, right_id, - p_tree->GetDepth(left_id), 0.0, (*timestamp)++)); - temp_qexpand_depth->push_back(ExpandEntry(right_id, left_id, - p_tree->GetDepth(right_id), 0.0, (*timestamp)++)); + temp_qexpand_depth->push_back(ExpandEntry(left_id, p_tree->GetDepth(left_id))); + temp_qexpand_depth->push_back(ExpandEntry(right_id, p_tree->GetDepth(right_id))); // - 1 parent + 2 new children (*num_leaves)++; } @@ -433,12 +432,11 @@ void QuantileHistMaker::Builder::EvaluateAndApplySplits( RegTree *p_tree, int *num_leaves, int depth, - unsigned *timestamp, std::vector *temp_qexpand_depth) { EvaluateSplits(qexpand_depth_wise_, gmat, hist_, *p_tree); std::vector nodes_for_apply_split; - AddSplitsToTree(gmat, p_tree, num_leaves, depth, timestamp, + AddSplitsToTree(gmat, p_tree, num_leaves, depth, &nodes_for_apply_split, temp_qexpand_depth); ApplySplit(nodes_for_apply_split, gmat, hist_, p_tree); } @@ -486,12 +484,11 @@ void QuantileHistMaker::Builder::ExpandWithDepthWise( RegTree *p_tree, const std::vector &gpair, const USMVector &gpair_device) { - unsigned timestamp = 0; int num_leaves = 0; // in depth_wise growing, we feed loss_chg with 0.0 since it is not used anyway - qexpand_depth_wise_.emplace_back(ExpandEntry(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid, - p_tree->GetDepth(ExpandEntry::kRootNid), 0.0, timestamp++)); + qexpand_depth_wise_.emplace_back(ExpandEntry::kRootNid, + p_tree->GetDepth(ExpandEntry::kRootNid)); ++num_leaves; for (int depth = 0; depth < param_.max_depth + 1; depth++) { std::vector sync_ids; @@ -503,7 +500,7 @@ void QuantileHistMaker::Builder::ExpandWithDepthWise( hist_synchronizer_->SyncHistograms(this, sync_ids, 
p_tree); BuildNodeStats(gmat, p_fmat, p_tree, gpair); - EvaluateAndApplySplits(gmat, p_tree, &num_leaves, depth, &timestamp, + EvaluateAndApplySplits(gmat, p_tree, &num_leaves, depth, &temp_qexpand_depth); // clean up @@ -527,18 +524,16 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide( const std::vector &gpair, const USMVector &gpair_device) { builder_monitor_.Start("ExpandWithLossGuide"); - unsigned timestamp = 0; int num_leaves = 0; const auto lr = param_.learning_rate; - ExpandEntry node(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid, - p_tree->GetDepth(0), 0.0f, timestamp++); + ExpandEntry node(ExpandEntry::kRootNid, p_tree->GetDepth(ExpandEntry::kRootNid)); BuildHistogramsLossGuide(node, gmat, p_tree, gpair_device); this->InitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, *p_tree); this->EvaluateSplits({node}, gmat, hist_, *p_tree); - node.loss_chg = snode_[ExpandEntry::kRootNid].best.loss_chg; + node.split.loss_chg = snode_[ExpandEntry::kRootNid].best.loss_chg; qexpand_loss_guided_->push(node); ++num_leaves; @@ -547,7 +542,7 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide( const ExpandEntry candidate = qexpand_loss_guided_->top(); const int nid = candidate.nid; qexpand_loss_guided_->pop(); - if (candidate.IsValid(param_, num_leaves)) { + if (!candidate.IsValid(param_, num_leaves)) { (*p_tree)[nid].SetLeaf(snode_[nid].weight * lr); } else { auto evaluator = tree_evaluator_.GetEvaluator(); @@ -566,10 +561,8 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide( const int cleft = (*p_tree)[nid].LeftChild(); const int cright = (*p_tree)[nid].RightChild(); - ExpandEntry left_node(cleft, cright, p_tree->GetDepth(cleft), - 0.0f, timestamp++); - ExpandEntry right_node(cright, cleft, p_tree->GetDepth(cright), - 0.0f, timestamp++); + ExpandEntry left_node(cleft, p_tree->GetDepth(cleft)); + ExpandEntry right_node(cright, p_tree->GetDepth(cright)); if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) { 
BuildHistogramsLossGuide(left_node, gmat, p_tree, gpair_device); @@ -585,8 +578,8 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide( interaction_constraints_.Split(nid, featureid, cleft, cright); this->EvaluateSplits({left_node, right_node}, gmat, hist_, *p_tree); - left_node.loss_chg = snode_[cleft].best.loss_chg; - right_node.loss_chg = snode_[cright].best.loss_chg; + left_node.split.loss_chg = snode_[cleft].best.loss_chg; + right_node.split.loss_chg = snode_[cright].best.loss_chg; qexpand_loss_guided_->push(left_node); qexpand_loss_guided_->push(right_node); diff --git a/plugin/sycl/tree/updater_quantile_hist.h b/plugin/sycl/tree/updater_quantile_hist.h index 0246e88f09a6..a8b4831fca9e 100644 --- a/plugin/sycl/tree/updater_quantile_hist.h +++ b/plugin/sycl/tree/updater_quantile_hist.h @@ -367,7 +367,6 @@ class QuantileHistMaker: public TreeUpdater { RegTree *p_tree, int *num_leaves, int depth, - unsigned *timestamp, std::vector *temp_qexpand_depth); void AddSplitsToTree( @@ -375,7 +374,6 @@ class QuantileHistMaker: public TreeUpdater { RegTree *p_tree, int *num_leaves, int depth, - unsigned *timestamp, std::vector* nodes_for_apply_split, std::vector* temp_qexpand_depth); @@ -388,10 +386,10 @@ class QuantileHistMaker: public TreeUpdater { void ReduceHists(const std::vector& sync_ids, size_t nbins); inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) { - if (lhs.loss_chg == rhs.loss_chg) { - return lhs.timestamp > rhs.timestamp; // favor small timestamp + if (lhs.GetLossChange() == rhs.GetLossChange()) { + return lhs.GetNodeId() > rhs.GetNodeId(); // favor small node id } else { - return lhs.loss_chg < rhs.loss_chg; // favor large loss_chg + return lhs.GetLossChange() < rhs.GetLossChange(); // favor large loss_chg } } // --data fields--