diff --git a/clang/runtime/dpct-rt/include/dpct/rng_utils.hpp b/clang/runtime/dpct-rt/include/dpct/rng_utils.hpp index d741739cfdaf..ce6909899a2f 100644 --- a/clang/runtime/dpct-rt/include/dpct/rng_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/rng_utils.hpp @@ -338,9 +338,8 @@ class rng_generator : public rng_generator_base { /// Set the seed of host rng_generator. /// \param seed The engine seed. void set_seed(const std::uint64_t seed) { - if (seed == _seed) { + if (seed == _seed) return; - } _seed = seed; _engine = create_engine(_queue, _seed, _dimensions, _mode); } @@ -348,9 +347,8 @@ class rng_generator : public rng_generator_base { /// Set the dimensions of host rng_generator. /// \param dimensions The engine dimensions. void set_dimensions(const std::uint32_t dimensions) { - if (dimensions == _dimensions) { + if (dimensions == _dimensions) return; - } _dimensions = dimensions; _engine = create_engine(_queue, _seed, _dimensions, _mode); } @@ -358,9 +356,8 @@ class rng_generator : public rng_generator_base { /// Set the queue of host rng_generator. /// \param queue The engine queue. void set_queue(sycl::queue *queue) { - if (queue == _queue) { + if (queue == _queue) return; - } _queue = queue; _engine = create_engine(_queue, _seed, _dimensions, _mode); } @@ -374,9 +371,8 @@ class rng_generator : public rng_generator_base { if constexpr (!std::is_same_v) { throw std::runtime_error("Only mrg32k3a engine support this method."); } - if (mode == _mode) { + if (mode == _mode) return; - } _mode = mode; _engine = create_engine(_queue, _seed, _dimensions, _mode); #endif @@ -390,11 +386,11 @@ class rng_generator : public rng_generator_base { throw std::runtime_error(OneMKLNotSupport); #else if constexpr (std::is_same_v) { - if (direction_numbers == _direction_numbers) { + if (direction_numbers == _direction_numbers) return; - } _direction_numbers = direction_numbers; - _engine = oneapi::mkl::rng::sobol(*_queue, _direction_numbers); + _engine = + create_engine(_queue, _seed, _dimensions, _mode, _direction_numbers); } else { throw std::runtime_error("Only Sobol engine supports this method."); } @@ -409,11 +405,11 @@ class rng_generator : public rng_generator_base { "Interfaces Project does not support this API."); #else if constexpr (std::is_same_v) { - if (engine_idx == _engine_idx) { + if (engine_idx == _engine_idx) return; - } _engine_idx = engine_idx; - _engine = oneapi::mkl::rng::mt2203(*_queue, _seed, _engine_idx); + _engine = create_engine(_queue, _seed, _dimensions, _mode, std::nullopt, + _engine_idx); } else { throw std::runtime_error("Only MT2203 engine supports this method."); } @@ -525,10 +521,12 @@ class rng_generator : public rng_generator_base { } private: - static inline engine_t create_engine(sycl::queue *queue, - const std::uint64_t seed, - const std::uint32_t dimensions, - const random_mode mode) { + static inline engine_t + create_engine(sycl::queue *queue, const std::uint64_t seed, + const std::uint32_t dimensions, const random_mode mode, + std::optional> direction_numbers = + std::nullopt, + std::optional engine_idx = std::nullopt) { #ifdef __INTEL_MKL__ if constexpr (std::is_same_v) { // oneapi::mkl::rng::mrg32k3a_mode is only supported for GPU device. For @@ -546,13 +544,18 @@ class rng_generator : public rng_generator_base { oneapi::mkl::rng::mrg32k3a_mode::optimal_v); } } + } else if constexpr (std::is_same_v) { + if (engine_idx.has_value()) { + return engine_t(*queue, seed, engine_idx.value()); + } + } else if constexpr (std::is_same_v) { + if (direction_numbers.has_value()) { + return engine_t(*queue, direction_numbers.value()); + } + return engine_t(*queue, dimensions); } - return std::is_same_v - ? engine_t(*queue, dimensions) - : engine_t(*queue, seed); -#else - return engine_t(*queue, seed); #endif + return engine_t(*queue, seed); } template