diff --git a/dpbench/benchmarks/knn/knn_sycl_native_ext/knn_sycl/_knn_kernel.hpp b/dpbench/benchmarks/knn/knn_sycl_native_ext/knn_sycl/_knn_kernel.hpp index 9a87858c..ac1bcbf3 100644 --- a/dpbench/benchmarks/knn/knn_sycl_native_ext/knn_sycl/_knn_kernel.hpp +++ b/dpbench/benchmarks/knn/knn_sycl_native_ext/knn_sycl/_knn_kernel.hpp @@ -5,13 +5,15 @@ #include #include -struct neighbors +template class theKernel; + +template struct neighbors { - double dist; + FpTy dist; size_t label; }; -template +template sycl::event knn_impl(sycl::queue q, FpTy *d_train, size_t *d_train_labels, @@ -20,18 +22,18 @@ sycl::event knn_impl(sycl::queue q, size_t classes_num, size_t train_size, size_t test_size, - size_t *d_predictions, + IntTy *d_predictions, FpTy *d_votes_to_classes, size_t data_dim) { sycl::event partial_hists_ev = q.submit([&](sycl::handler &h) { - h.parallel_for( + h.parallel_for>( sycl::range<1>{test_size}, [=](sycl::id<1> myID) { size_t i = myID[0]; // here k has to be 5 in order to match with numpy no. of // neighbors - struct neighbors queue_neighbors[5]; + struct neighbors queue_neighbors[5]; // count distances for (size_t j = 0; j < k; ++j) { @@ -102,10 +104,10 @@ sycl::event knn_impl(sycl::queue q, queue_neighbors[j].label]++; } - size_t max_ind = 0; + IntTy max_ind = 0; FpTy max_value = 0.0; - for (size_t j = 0; j < classes_num; ++j) { + for (IntTy j = 0; j < (IntTy)classes_num; ++j) { if (d_votes_to_classes[i * classes_num + j] > max_value) { max_value = d_votes_to_classes[i * classes_num + j]; max_ind = j; diff --git a/dpbench/benchmarks/knn/knn_sycl_native_ext/knn_sycl/_knn_sycl.cpp b/dpbench/benchmarks/knn/knn_sycl_native_ext/knn_sycl/_knn_sycl.cpp index e7c13b9b..97938bb0 100644 --- a/dpbench/benchmarks/knn/knn_sycl_native_ext/knn_sycl/_knn_sycl.cpp +++ b/dpbench/benchmarks/knn/knn_sycl_native_ext/knn_sycl/_knn_sycl.cpp @@ -5,8 +5,6 @@ #include "_knn_kernel.hpp" #include -#include - template bool ensure_compatibility(const Args &...args) { std::vector arrays = {args...}; @@ -41,16 +39,28 @@ void knn_sync(dpctl::tensor::usm_ndarray x_train, votes_to_classes)) throw std::runtime_error("Input arrays are not acceptable."); - if (x_train.get_typenum() != UAR_DOUBLE) { - throw std::runtime_error("Expected a double precision FP array."); + auto typenum = x_train.get_typenum(); + if (typenum == UAR_FLOAT) { + sycl::event res_ev = knn_impl( + x_train.get_queue(), x_train.get_data(), + y_train.get_data(), x_test.get_data(), k, + classes_num, train_size, test_size, + predictions.get_data(), + votes_to_classes.get_data(), data_dim); + res_ev.wait(); + } + else if (typenum == UAR_DOUBLE) { + sycl::event res_ev = knn_impl( + x_train.get_queue(), x_train.get_data(), + y_train.get_data(), x_test.get_data(), k, + classes_num, train_size, test_size, predictions.get_data(), + votes_to_classes.get_data(), data_dim); + res_ev.wait(); + } + else { + throw std::runtime_error( + "Expected a double or single precision FP array."); } - - sycl::event res_ev = knn_impl( - x_train.get_queue(), x_train.get_data(), - y_train.get_data(), x_test.get_data(), k, classes_num, - train_size, test_size, predictions.get_data(), - votes_to_classes.get_data(), data_dim); - res_ev.wait(); } PYBIND11_MODULE(_knn_sycl, m)