Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Refine rng_utils.hpp #2510

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 26 additions & 23 deletions clang/runtime/dpct-rt/include/dpct/rng_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,29 +338,26 @@ class rng_generator : public rng_generator_base {
/// Set the seed of host rng_generator.
/// \param seed The engine seed.
void set_seed(const std::uint64_t seed) {
if (seed == _seed) {
if (seed == _seed)
return;
}
_seed = seed;
_engine = create_engine(_queue, _seed, _dimensions, _mode);
}

/// Set the dimensions of host rng_generator.
/// \param dimensions The engine dimensions.
void set_dimensions(const std::uint32_t dimensions) {
if (dimensions == _dimensions) {
if (dimensions == _dimensions)
return;
}
_dimensions = dimensions;
_engine = create_engine(_queue, _seed, _dimensions, _mode);
}

/// Set the queue of host rng_generator.
/// \param queue The engine queue.
void set_queue(sycl::queue *queue) {
if (queue == _queue) {
if (queue == _queue)
return;
}
_queue = queue;
_engine = create_engine(_queue, _seed, _dimensions, _mode);
}
Expand All @@ -374,9 +371,8 @@ class rng_generator : public rng_generator_base {
if constexpr (!std::is_same_v<engine_t, oneapi::mkl::rng::mrg32k3a>) {
throw std::runtime_error("Only mrg32k3a engine support this method.");
}
if (mode == _mode) {
if (mode == _mode)
return;
}
_mode = mode;
_engine = create_engine(_queue, _seed, _dimensions, _mode);
#endif
Expand All @@ -390,11 +386,11 @@ class rng_generator : public rng_generator_base {
throw std::runtime_error(OneMKLNotSupport);
#else
if constexpr (std::is_same_v<engine_t, oneapi::mkl::rng::sobol>) {
if (direction_numbers == _direction_numbers) {
if (direction_numbers == _direction_numbers)
return;
}
_direction_numbers = direction_numbers;
_engine = oneapi::mkl::rng::sobol(*_queue, _direction_numbers);
_engine =
create_engine(_queue, _seed, _dimensions, _mode, _direction_numbers);
} else {
throw std::runtime_error("Only Sobol engine supports this method.");
}
Expand All @@ -409,11 +405,11 @@ class rng_generator : public rng_generator_base {
"Interfaces Project does not support this API.");
#else
if constexpr (std::is_same_v<engine_t, oneapi::mkl::rng::mt2203>) {
if (engine_idx == _engine_idx) {
if (engine_idx == _engine_idx)
return;
}
_engine_idx = engine_idx;
_engine = oneapi::mkl::rng::mt2203(*_queue, _seed, _engine_idx);
_engine = create_engine(_queue, _seed, _dimensions, _mode, std::nullopt,
_engine_idx);
} else {
throw std::runtime_error("Only MT2203 engine supports this method.");
}
Expand Down Expand Up @@ -525,10 +521,12 @@ class rng_generator : public rng_generator_base {
}

private:
static inline engine_t create_engine(sycl::queue *queue,
const std::uint64_t seed,
const std::uint32_t dimensions,
const random_mode mode) {
static inline engine_t
create_engine(sycl::queue *queue, const std::uint64_t seed,
const std::uint32_t dimensions, const random_mode mode,
std::optional<std::vector<std::uint32_t>> direction_numbers =
std::nullopt,
std::optional<std::uint32_t> engine_idx = std::nullopt) {
#ifdef __INTEL_MKL__
if constexpr (std::is_same_v<engine_t, oneapi::mkl::rng::mrg32k3a>) {
// oneapi::mkl::rng::mrg32k3a_mode is only supported for GPU device. For
Expand All @@ -546,13 +544,18 @@ class rng_generator : public rng_generator_base {
oneapi::mkl::rng::mrg32k3a_mode::optimal_v);
}
}
} else if constexpr (std::is_same_v<engine_t, oneapi::mkl::rng::mt2203>) {
if (engine_idx.has_value()) {
return engine_t(*queue, seed, engine_idx.value());
}
} else if constexpr (std::is_same_v<engine_t, oneapi::mkl::rng::sobol>) {
if (direction_numbers.has_value()) {
return engine_t(*queue, direction_numbers.value());
}
return engine_t(*queue, dimensions);
}
return std::is_same_v<engine_t, oneapi::mkl::rng::sobol>
? engine_t(*queue, dimensions)
: engine_t(*queue, seed);
#else
return engine_t(*queue, seed);
#endif
return engine_t(*queue, seed);
}

template <typename distr_t, typename buffer_t, class... distr_params_t>
Expand Down