Skip to content

Commit

Permalink
temp knn
Browse files Browse the repository at this point in the history
  • Loading branch information
adarshyoga committed Apr 19, 2024
1 parent 04b4fcb commit 2c5fb18
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion dpbench/benchmarks/default/knn/knn_initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ def _gen_test_data(test_size, data_dim, seed_test, dtype):
)
x_test = _gen_test_data(test_size, data_dim, seed_test, dtype)
predictions = np.empty(test_size, types_dict["int"])
votes_to_classes = np.zeros((test_size, classes_num), dtype)
votes_to_classes = np.zeros((test_size, classes_num), types_dict["int"])

return (x_train, y_train, x_test, predictions, votes_to_classes)
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ sycl::event knn_impl(sycl::queue q,
size_t train_size,
size_t test_size,
IntTy *d_predictions,
FpTy *d_votes_to_classes,
IntTy *d_votes_to_classes,
size_t data_dim)
{
sycl::event partial_hists_ev = q.submit([&](sycl::handler &h) {
Expand Down Expand Up @@ -100,11 +100,11 @@ sycl::event knn_impl(sycl::queue q,

// simple vote
for (size_t j = 0; j < k; ++j) {
d_votes_to_classes[(i*classes_num) + (queue_neighbors[j].label)] += 1.0;
d_votes_to_classes[(i*classes_num) + (queue_neighbors[j].label)] += 1;
}

IntTy max_ind = 0;
FpTy max_value = 0.0;
IntTy max_value = 0.0;

for (IntTy j = 0; j < (IntTy)classes_num; ++j) {
if (d_votes_to_classes[i * classes_num + j] > max_value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ void knn_sync(dpctl::tensor::usm_ndarray x_train,
y_train.get_data<size_t>(), x_test.get_data<float>(), k,
classes_num, train_size, test_size,
predictions.get_data<unsigned int>(),
votes_to_classes.get_data<float>(), data_dim);
votes_to_classes.get_data<unsigned int>(), data_dim);
res_ev.wait();
}
else if (typenum == UAR_DOUBLE) {
sycl::event res_ev = knn_impl<double, size_t>(
x_train.get_queue(), x_train.get_data<double>(),
y_train.get_data<size_t>(), x_test.get_data<double>(), k,
classes_num, train_size, test_size, predictions.get_data<size_t>(),
votes_to_classes.get_data<double>(), data_dim);
votes_to_classes.get_data<size_t>(), data_dim);
res_ev.wait();
}
else {
Expand Down

0 comments on commit 2c5fb18

Please sign in to comment.