Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for depth-wise policy #63

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions tests/cpp/plugin/test_sycl_hist_updater.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
const USMVector<GradientPair, MemoryType::on_device> &gpair) {
HistUpdater<GradientSumT>::ExpandWithLossGuide(gmat, p_tree, gpair);
}

auto TestExpandWithDepthWise(const common::GHistIndexMatrix& gmat,
DMatrix *p_fmat,
RegTree* p_tree,
const USMVector<GradientPair, MemoryType::on_device> &gpair) {
HistUpdater<GradientSumT>::ExpandWithDepthWise(gmat, p_tree, gpair);
}
};

void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) {
Expand Down Expand Up @@ -532,6 +539,53 @@ void TestHistUpdaterExpandWithLossGuide(const xgboost::tree::TrainParam& param)
ASSERT_NEAR(ans[2], -0.15, 1e-6);
}

template <typename GradientSumT>
void TestHistUpdaterExpandWithDepthWise(const xgboost::tree::TrainParam& param) {
const size_t num_rows = 3;
const size_t num_columns = 1;
const size_t n_bins = 16;

Context ctx;
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});

DeviceManager device_manager;
auto qu = device_manager.GetQueue(ctx.Device());

std::vector<float> data = {7, 3, 15};
auto p_fmat = GetDMatrixFromData(data, num_rows, num_columns);

DeviceMatrix dmat;
dmat.Init(qu, p_fmat.get());
common::GHistIndexMatrix gmat;
gmat.Init(qu, &ctx, dmat, n_bins);

std::vector<GradientPair> gpair_host = {{1, 2}, {3, 1}, {1, 1}};
USMVector<GradientPair, MemoryType::on_device> gpair(&qu, gpair_host);

RegTree tree;
FeatureInteractionConstraintHost int_constraints;
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);

updater.TestExpandWithDepthWise(gmat, p_fmat.get(), &tree, gpair);

const auto& nodes = tree.GetNodes();
std::vector<float> ans(data.size());
for (size_t data_idx = 0; data_idx < data.size(); ++data_idx) {
size_t node_idx = 0;
while (!nodes[node_idx].IsLeaf()) {
node_idx = data[data_idx] < nodes[node_idx].SplitCond() ? nodes[node_idx].LeftChild() : nodes[node_idx].RightChild();
}
ans[data_idx] = nodes[node_idx].LeafValue();
}

ASSERT_NEAR(ans[0], -0.15, 1e-6);
ASSERT_NEAR(ans[1], -0.45, 1e-6);
ASSERT_NEAR(ans[2], -0.15, 1e-6);
}

TEST(SyclHistUpdater, Sampling) {
xgboost::tree::TrainParam param;
param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});
Expand Down Expand Up @@ -608,4 +662,12 @@ TEST(SyclHistUpdater, ExpandWithLossGuide) {
TestHistUpdaterExpandWithLossGuide<double>(param);
}

TEST(SyclHistUpdater, ExpandWithDepthWise) {
xgboost::tree::TrainParam param;
param.UpdateAllowUnknown(Args{{"max_depth", "2"}});

TestHistUpdaterExpandWithDepthWise<float>(param);
TestHistUpdaterExpandWithDepthWise<double>(param);
}

} // namespace xgboost::sycl::tree
Loading