diff --git a/dpbench/benchmarks/default/knn/knn_initialize.py b/dpbench/benchmarks/default/knn/knn_initialize.py index 720b4348..2e76a31f 100644 --- a/dpbench/benchmarks/default/knn/knn_initialize.py +++ b/dpbench/benchmarks/default/knn/knn_initialize.py @@ -41,6 +41,6 @@ def _gen_test_data(test_size, data_dim, seed_test, dtype): ) x_test = _gen_test_data(test_size, data_dim, seed_test, dtype) predictions = np.empty(test_size, types_dict["int"]) - votes_to_classes = np.zeros((test_size, classes_num), dtype) + votes_to_classes = np.zeros((test_size, classes_num), types_dict["int"]) return (x_train, y_train, x_test, predictions, votes_to_classes) diff --git a/dpbench/benchmarks/default/knn/knn_sycl_native_ext/knn_sycl/_knn_kernel.hpp b/dpbench/benchmarks/default/knn/knn_sycl_native_ext/knn_sycl/_knn_kernel.hpp index 63db7a5a..75cccd3b 100644 --- a/dpbench/benchmarks/default/knn/knn_sycl_native_ext/knn_sycl/_knn_kernel.hpp +++ b/dpbench/benchmarks/default/knn/knn_sycl_native_ext/knn_sycl/_knn_kernel.hpp @@ -23,7 +23,7 @@ sycl::event knn_impl(sycl::queue q, size_t train_size, size_t test_size, IntTy *d_predictions, - FpTy *d_votes_to_classes, + IntTy *d_votes_to_classes, size_t data_dim) { sycl::event partial_hists_ev = q.submit([&](sycl::handler &h) { @@ -100,11 +100,11 @@ sycl::event knn_impl(sycl::queue q, // simple vote for (size_t j = 0; j < k; ++j) { - d_votes_to_classes[(i*classes_num) + (queue_neighbors[j].label)] += 1.0; + d_votes_to_classes[(i*classes_num) + (queue_neighbors[j].label)] += 1; } IntTy max_ind = 0; - FpTy max_value = 0.0; + IntTy max_value = 0.0; for (IntTy j = 0; j < (IntTy)classes_num; ++j) { if (d_votes_to_classes[i * classes_num + j] > max_value) { diff --git a/dpbench/benchmarks/default/knn/knn_sycl_native_ext/knn_sycl/_knn_sycl.cpp b/dpbench/benchmarks/default/knn/knn_sycl_native_ext/knn_sycl/_knn_sycl.cpp index 97938bb0..7976c4b8 100644 --- a/dpbench/benchmarks/default/knn/knn_sycl_native_ext/knn_sycl/_knn_sycl.cpp +++ b/dpbench/benchmarks/default/knn/knn_sycl_native_ext/knn_sycl/_knn_sycl.cpp @@ -46,7 +46,7 @@ void knn_sync(dpctl::tensor::usm_ndarray x_train, y_train.get_data(), x_test.get_data(), k, classes_num, train_size, test_size, predictions.get_data(), - votes_to_classes.get_data(), data_dim); + votes_to_classes.get_data(), data_dim); res_ev.wait(); } else if (typenum == UAR_DOUBLE) { @@ -54,7 +54,7 @@ void knn_sync(dpctl::tensor::usm_ndarray x_train, x_train.get_queue(), x_train.get_data(), y_train.get_data(), x_test.get_data(), k, classes_num, train_size, test_size, predictions.get_data(), - votes_to_classes.get_data(), data_dim); + votes_to_classes.get_data(), data_dim); res_ev.wait(); } else {