forked from dmlc/xgboost
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tests for sycl implemenation for GHistIndexMatrix (#32)
* initial * move helper functions to a separate file * lintint --------- Co-authored-by: Dmitry Razdoburdin <>
- Loading branch information
1 parent
9c81341
commit cc34883
Showing
5 changed files
with
148 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
/*! | ||
* Copyright 2022-2024 XGBoost contributors | ||
*/ | ||
#pragma once | ||
|
||
#include "../helpers.h" | ||
|
||
namespace xgboost::sycl { | ||
template<typename T, typename Container> | ||
void VerifySyclVector(const USMVector<T, MemoryType::shared>& sycl_vector, | ||
const Container& host_vector) { | ||
ASSERT_EQ(sycl_vector.Size(), host_vector.size()); | ||
|
||
size_t size = sycl_vector.Size(); | ||
for (size_t i = 0; i < size; ++i) { | ||
ASSERT_EQ(sycl_vector[i], host_vector[i]); | ||
} | ||
} | ||
|
||
template<typename T, typename Container> | ||
void VerifySyclVector(const std::vector<T>& sycl_vector, const Container& host_vector) { | ||
ASSERT_EQ(sycl_vector.size(), host_vector.size()); | ||
|
||
size_t size = sycl_vector.size(); | ||
for (size_t i = 0; i < size; ++i) { | ||
ASSERT_EQ(sycl_vector[i], host_vector[i]); | ||
} | ||
} | ||
|
||
} // namespace xgboost::sycl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
/** | ||
* Copyright 2021-2024 by XGBoost contributors | ||
*/ | ||
|
||
#pragma GCC diagnostic push | ||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare" | ||
#pragma GCC diagnostic ignored "-W#pragma-messages" | ||
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix | ||
#pragma GCC diagnostic pop | ||
|
||
#include "../../../plugin/sycl/data/gradient_index.h" | ||
#include "../../../plugin/sycl/device_manager.h" | ||
#include "sycl_helpers.h" | ||
#include "../helpers.h" | ||
|
||
namespace xgboost::sycl::data { | ||
|
||
TEST(SyclGradientIndex, HistogramCuts) { | ||
size_t max_bins = 8; | ||
|
||
Context ctx; | ||
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); | ||
|
||
DeviceManager device_manager; | ||
auto qu = device_manager.GetQueue(ctx.Device()); | ||
|
||
auto p_fmat = RandomDataGenerator{512, 16, 0.5}.GenerateDMatrix(true); | ||
|
||
xgboost::common::HistogramCuts cut = | ||
xgboost::common::SketchOnDMatrix(&ctx, p_fmat.get(), max_bins); | ||
|
||
common::HistogramCuts cut_sycl; | ||
cut_sycl.Init(qu, cut); | ||
|
||
VerifySyclVector(cut_sycl.Ptrs(), cut.cut_ptrs_.HostVector()); | ||
VerifySyclVector(cut_sycl.Values(), cut.cut_values_.HostVector()); | ||
VerifySyclVector(cut_sycl.MinValues(), cut.min_vals_.HostVector()); | ||
} | ||
|
||
TEST(SyclGradientIndex, Init) { | ||
size_t n_rows = 128; | ||
size_t n_columns = 7; | ||
|
||
Context ctx; | ||
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); | ||
|
||
DeviceManager device_manager; | ||
auto qu = device_manager.GetQueue(ctx.Device()); | ||
|
||
auto p_fmat = RandomDataGenerator{n_rows, n_columns, 0.3}.GenerateDMatrix(); | ||
|
||
sycl::DeviceMatrix dmat; | ||
dmat.Init(qu, p_fmat.get()); | ||
|
||
int max_bins = 256; | ||
common::GHistIndexMatrix gmat_sycl; | ||
gmat_sycl.Init(qu, &ctx, dmat, max_bins); | ||
|
||
xgboost::GHistIndexMatrix gmat{&ctx, p_fmat.get(), max_bins, 0.3, false}; | ||
|
||
{ | ||
ASSERT_EQ(gmat_sycl.max_num_bins, max_bins); | ||
ASSERT_EQ(gmat_sycl.nfeatures, n_columns); | ||
} | ||
|
||
{ | ||
VerifySyclVector(gmat_sycl.hit_count, gmat.hit_count); | ||
} | ||
|
||
{ | ||
std::vector<size_t> feature_count_sycl(n_columns, 0); | ||
gmat_sycl.GetFeatureCounts(feature_count_sycl.data()); | ||
|
||
std::vector<size_t> feature_count(n_columns, 0); | ||
gmat.GetFeatureCounts(feature_count.data()); | ||
VerifySyclVector(feature_count_sycl, feature_count); | ||
} | ||
} | ||
|
||
} // namespace xgboost::sycl::data |