Skip to content

Commit

Permalink
test knn labels
Browse files Browse the repository at this point in the history
  • Loading branch information
adarshyoga committed Apr 21, 2024
1 parent a0b47c8 commit 3e38be1
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
2 changes: 1 addition & 1 deletion dpbench/benchmarks/default/knn/knn_initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _gen_data_x(ip_size, data_dim, seed, dtype):

def _gen_data_y(ip_size, classes_num, seed):
default_rng.seed(seed)
data = default_rng.randint(classes_num, size=ip_size)
data = default_rng.randint(classes_num, size=ip_size, dtype=np.int64)
return data

def _gen_train_data(train_size, data_dim, classes_num, seed_train, dtype):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ sycl::event knn_impl(sycl::queue q,

size_t ith_idx = i * classes_num;
size_t jth = queue_neighbors[j].label;
out << ith_idx << " " << jth << sycl::endl;
//out << ith_idx << " " << jth << sycl::endl;
//d_votes_to_classes[ith_idx + jth] = jth;
//d_votes_to_classes[(i*classes_num) + (queue_neighbors[j].label)] += 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ void knn_sync(dpctl::tensor::usm_ndarray x_train,
votes_to_classes))
throw std::runtime_error("Input arrays are not acceptable.");

size_t* train_labels = y_train.get_data<size_t>();
std::cout << "Labels size = " << y_train.get_size() << std::endl;
sycl::queue q = x_train.get_queue();

size_t h_labels[1024];

q.memcpy(h_labels,train_labels,sizeof(size_t)*1024);
q.wait();

for (int idx = 0; idx<1024;idx++) {
std::cout << "Train labels: "<< h_labels[idx] << std::endl;
}

auto typenum = x_train.get_typenum();
if (typenum == UAR_FLOAT) {
sycl::event res_ev = knn_impl<float, unsigned int>(
Expand Down

0 comments on commit 3e38be1

Please sign in to comment.