Skip to content

Commit

Permalink
Fixing bugs introduced in earlier commit.
Browse files Browse the repository at this point in the history
  • Loading branch information
pelesh committed Mar 9, 2024
1 parent 8a9ef21 commit 12a2c5c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 25 deletions.
30 changes: 22 additions & 8 deletions resolve/random/RandSketchingCountSketch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,28 @@ namespace ReSolve
*/
int RandSketchingCountSketch::Theta(vector_type* input, vector_type* output)
{
mem_.deviceSynchronize();
count_sketch_theta(n_,
k_rand_,
d_labels_,
d_flip_,
input->getData(memspace_),
output->getData(memspace_));
mem_.deviceSynchronize();
using namespace memory;

switch (memspace_) {
case DEVICE:
mem_.deviceSynchronize();
count_sketch_theta(n_,
k_rand_,
d_labels_,
d_flip_,
input->getData(memspace_),
output->getData(memspace_));
mem_.deviceSynchronize();
break;
case HOST:
count_sketch_theta(n_,
k_rand_,
h_labels_,
h_flip_,
input->getData(memspace_),
output->getData(memspace_));
break;
}
return 0;
}

Expand Down
43 changes: 26 additions & 17 deletions resolve/random/RandSketchingFWHT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,27 +83,36 @@ namespace ReSolve
switch (memspace_) {
case DEVICE:
mem_.setZeroArrayOnDevice(d_aux_, N_);
FWHT_scaleByD(n_,
d_D_,
input->getData(memspace_),
d_aux_);

mem_.deviceSynchronize();
FWHT(1, log2N_, d_aux_);

mem_.deviceSynchronize();
FWHT_select(k_rand_,
d_perm_,
d_aux_,
output->getData(memspace_));
mem_.deviceSynchronize();
break; // remember - scaling is the solver's problem
case HOST:
// std::memset(d_aux_, 0.0, static_cast<size_t>(N_) * sizeof(real_type));
mem_.setZeroArrayOnHost(d_aux_, N_);
std::memset(d_aux_, 0.0, static_cast<size_t>(N_) * sizeof(real_type));
FWHT_scaleByD(n_,
h_D_,
input->getData(memspace_),
d_aux_);

FWHT(1, log2N_, d_aux_);

FWHT_select(k_rand_,
h_perm_,
d_aux_,
output->getData(memspace_));
break;
}

FWHT_scaleByD(n_,
d_D_,
input->getData(memspace_),
d_aux_);

mem_.deviceSynchronize();
FWHT(1, log2N_, d_aux_);

mem_.deviceSynchronize();
FWHT_select(k_rand_,
d_perm_,
d_aux_,
output->getData(memspace_));
mem_.deviceSynchronize();
return 0;
}

Expand Down

0 comments on commit 12a2c5c

Please sign in to comment.