diff --git a/resolve/random/RandSketchingCountSketch.cpp b/resolve/random/RandSketchingCountSketch.cpp
index 78f69779..e12f512a 100644
--- a/resolve/random/RandSketchingCountSketch.cpp
+++ b/resolve/random/RandSketchingCountSketch.cpp
@@ -58,14 +58,28 @@ namespace ReSolve
    */
   int RandSketchingCountSketch::Theta(vector_type* input, vector_type* output)
   {
-    mem_.deviceSynchronize();
-    count_sketch_theta(n_,
-                       k_rand_,
-                       d_labels_,
-                       d_flip_,
-                       input->getData(memspace_),
-                       output->getData(memspace_));
-    mem_.deviceSynchronize();
+    using namespace memory;
+
+    switch (memspace_) {
+      case DEVICE:
+        mem_.deviceSynchronize();
+        count_sketch_theta(n_,
+                           k_rand_,
+                           d_labels_,
+                           d_flip_,
+                           input->getData(memspace_),
+                           output->getData(memspace_));
+        mem_.deviceSynchronize();
+        break;
+      case HOST:
+        count_sketch_theta(n_,
+                           k_rand_,
+                           h_labels_,
+                           h_flip_,
+                           input->getData(memspace_),
+                           output->getData(memspace_));
+        break;
+    }
 
     return 0;
   }
diff --git a/resolve/random/RandSketchingFWHT.cpp b/resolve/random/RandSketchingFWHT.cpp
index 40065fec..8daf542c 100644
--- a/resolve/random/RandSketchingFWHT.cpp
+++ b/resolve/random/RandSketchingFWHT.cpp
@@ -83,27 +83,36 @@ namespace ReSolve
     switch (memspace_) {
       case DEVICE:
         mem_.setZeroArrayOnDevice(d_aux_, N_);
+        FWHT_scaleByD(n_,
+                      d_D_,
+                      input->getData(memspace_),
+                      d_aux_);
+
+        mem_.deviceSynchronize();
+        FWHT(1, log2N_, d_aux_);
+
+        mem_.deviceSynchronize();
+        FWHT_select(k_rand_,
+                    d_perm_,
+                    d_aux_,
+                    output->getData(memspace_));
+        mem_.deviceSynchronize();
         break; // remember - scaling is the solver's problem
       case HOST:
-        // std::memset(d_aux_, 0.0, static_cast<size_t>(N_) * sizeof(real_type));
-        mem_.setZeroArrayOnHost(d_aux_, N_);
+        std::memset(d_aux_, 0, static_cast<size_t>(N_) * sizeof(real_type));
+        FWHT_scaleByD(n_,
+                      h_D_,
+                      input->getData(memspace_),
+                      d_aux_);
+
+        FWHT(1, log2N_, d_aux_);
+
+        FWHT_select(k_rand_,
+                    h_perm_,
+                    d_aux_,
+                    output->getData(memspace_));
         break;
     }
-
-    FWHT_scaleByD(n_,
-                  d_D_,
-                  input->getData(memspace_),
-                  d_aux_);
-
-    mem_.deviceSynchronize();
-    FWHT(1, log2N_, d_aux_);
-
-    mem_.deviceSynchronize();
-    FWHT_select(k_rand_,
-                d_perm_,
-                d_aux_,
-                output->getData(memspace_));
-    mem_.deviceSynchronize();
     return 0;
   }
 