-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Updates and test coverage for hipGraph support in rocRAND (#439)
This change allows device-side generators to be used inside of hipGraphs. More specifically, you can wrap calls to rocrand_generate_* inside of a hipGraph. There are a few things to be aware of: - Generator creation (rocrand_create_generator), initialization (rocrand_initialize_generator), and destruction (rocrand_destroy_generator) must still happen outside the hipGraph. - After the generator is created, you may call API functions to set its seed, offset, and order. - After the generator is initialized (but before stream capture or manual graph creation begins), use rocrand_set_stream to set the stream the generator will use within the graph. - A generator's seed, offset, and stream may not be changed from within the hipGraph. Attempting to do so may result in unpredicable behaviour. - API calls for the poisson distribution (eg. rocrand_generate_poisson) are not yet supported inside of hipGraphs. I've added a note to the changelog that mentions these details. In addition to the changes necessary to support the behaviour described above, this change also: - updates the changelog to alert the user to the restrictions mentioned above - adds new unit test coverage to exercises generators and distributions within hipGraphs.
- Loading branch information
Showing
3 changed files
with
244 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
#include <stdio.h> | ||
#include <gtest/gtest.h> | ||
|
||
#include <hip/hip_runtime.h> | ||
#include <rocrand/rocrand.h> | ||
|
||
#include "test_common.hpp" | ||
#include "test_rocrand_common.hpp" | ||
#include "test_utils_hipgraphs.hpp" | ||
|
||
class rocrand_hipgraph_generate_tests : public ::testing::TestWithParam<rocrand_rng_type> {}; | ||
|
||
void test_float(std::function<rocrand_status(rocrand_generator, float*, size_t, float, float)> generate_fn, rocrand_rng_type rng_type) | ||
{ | ||
rocrand_generator generator; | ||
ROCRAND_CHECK( | ||
rocrand_create_generator( | ||
&generator, | ||
rng_type | ||
) | ||
); | ||
|
||
ROCRAND_CHECK(rocrand_initialize_generator(generator)); | ||
|
||
const size_t size = 12563; | ||
float mean = 5.0f; | ||
float stddev = 2.0f; | ||
float * data; | ||
HIP_CHECK(hipMallocHelper(&data, size * sizeof(float))); | ||
HIP_CHECK(hipDeviceSynchronize()); | ||
|
||
// Default stream does not support hipGraph stream capture, so create a non-blocking one | ||
hipStream_t stream = 0; | ||
HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); | ||
rocrand_set_stream(generator, stream); | ||
|
||
hipGraphExec_t graph_instance; | ||
hipGraph_t graph = test_utils::createGraphHelper(stream); | ||
|
||
// Any sizes | ||
ROCRAND_CHECK( | ||
generate_fn(generator, data, 1, mean, stddev) | ||
); | ||
|
||
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); | ||
test_utils::resetGraphHelper(graph, graph_instance, stream); | ||
|
||
// Any alignment | ||
ROCRAND_CHECK( | ||
generate_fn(generator, data+1, 2, mean, stddev) | ||
); | ||
|
||
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); | ||
test_utils::resetGraphHelper(graph, graph_instance, stream); | ||
|
||
ROCRAND_CHECK( | ||
generate_fn(generator, data, size, mean, stddev) | ||
); | ||
|
||
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); | ||
|
||
HIP_CHECK(hipFree(data)); | ||
ROCRAND_CHECK(rocrand_destroy_generator(generator)); | ||
test_utils::cleanupGraphHelper(graph, graph_instance); | ||
HIP_CHECK(hipStreamDestroy(stream)); | ||
} | ||
|
||
TEST_P(rocrand_hipgraph_generate_tests, normal_float_test) | ||
{ | ||
auto generator_fcn = [](rocrand_generator generator, float* output_data, size_t n, float mean, float stddev) | ||
{ | ||
return rocrand_generate_normal(generator, output_data, n, mean, stddev); | ||
}; | ||
|
||
test_float(generator_fcn, GetParam()); | ||
} | ||
|
||
TEST_P(rocrand_hipgraph_generate_tests, log_normal_float_test) | ||
{ | ||
auto generator_fcn = [](rocrand_generator generator, float* output_data, size_t n, float mean, float stddev) | ||
{ | ||
return rocrand_generate_log_normal(generator, output_data, n, mean, stddev); | ||
}; | ||
|
||
test_float(generator_fcn, GetParam()); | ||
} | ||
|
||
TEST_P(rocrand_hipgraph_generate_tests, uniform_float_test) | ||
{ | ||
const rocrand_rng_type rng_type = GetParam(); | ||
|
||
rocrand_generator generator; | ||
ROCRAND_CHECK( | ||
rocrand_create_generator( | ||
&generator, | ||
rng_type | ||
) | ||
); | ||
|
||
ROCRAND_CHECK(rocrand_initialize_generator(generator)); | ||
|
||
const size_t size = 12563; | ||
float * data; | ||
HIP_CHECK(hipMallocHelper(&data, size * sizeof(float))); | ||
HIP_CHECK(hipDeviceSynchronize()); | ||
|
||
// Default stream does not support hipGraph stream capture, so create a non-blocking one | ||
hipStream_t stream = 0; | ||
HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); | ||
rocrand_set_stream(generator, stream); | ||
|
||
hipGraphExec_t graph_instance; | ||
hipGraph_t graph = test_utils::createGraphHelper(stream); | ||
|
||
// Any sizes | ||
ROCRAND_CHECK( | ||
rocrand_generate_uniform(generator, data, 1) | ||
); | ||
|
||
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); | ||
test_utils::resetGraphHelper(graph, graph_instance, stream); | ||
|
||
// Any alignment | ||
ROCRAND_CHECK( | ||
rocrand_generate_uniform(generator, data+1, 2) | ||
); | ||
|
||
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); | ||
test_utils::resetGraphHelper(graph, graph_instance, stream); | ||
|
||
ROCRAND_CHECK( | ||
rocrand_generate_uniform(generator, data, size) | ||
); | ||
|
||
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); | ||
|
||
HIP_CHECK(hipFree(data)); | ||
ROCRAND_CHECK(rocrand_destroy_generator(generator)); | ||
test_utils::cleanupGraphHelper(graph, graph_instance); | ||
HIP_CHECK(hipStreamDestroy(stream)); | ||
} | ||
|
||
INSTANTIATE_TEST_SUITE_P(rocrand_hipgraph_generate_tests, | ||
rocrand_hipgraph_generate_tests, | ||
::testing::ValuesIn(rng_types)); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. | ||
// | ||
// Permission is hereby granted, free of charge, to any person obtaining a copy | ||
// of this software and associated documentation files (the "Software"), to deal | ||
// in the Software without restriction, including without limitation the rights | ||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
// copies of the Software, and to permit persons to whom the Software is | ||
// furnished to do so, subject to the following conditions: | ||
// | ||
// The above copyright notice and this permission notice shall be included in | ||
// all copies or substantial portions of the Software. | ||
// | ||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
// THE SOFTWARE. | ||
|
||
#ifndef ROCRAND_TEST_UTILS_HIPGRAPHS_HPP | ||
#define ROCRAND_TEST_UTILS_HIPGRAPHS_HPP | ||
|
||
#include <hip/hip_runtime.h> | ||
#include "test_common.hpp" | ||
|
||
// Helper functions for testing with hipGraph stream capture. | ||
// Note: graphs will not work on the default stream. | ||
namespace test_utils | ||
{ | ||
|
||
inline hipGraph_t createGraphHelper(hipStream_t& stream, const bool beginCapture=true) | ||
{ | ||
// Create a new graph | ||
hipGraph_t graph; | ||
HIP_CHECK_NON_VOID(hipGraphCreate(&graph, 0)); | ||
|
||
// Optionally begin stream capture | ||
if (beginCapture) | ||
HIP_CHECK_NON_VOID(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal)); | ||
|
||
return graph; | ||
} | ||
|
||
inline void cleanupGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance) | ||
{ | ||
HIP_CHECK_NON_VOID(hipGraphDestroy(graph)); | ||
HIP_CHECK_NON_VOID(hipGraphExecDestroy(instance)); | ||
} | ||
|
||
inline void resetGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance, hipStream_t& stream, const bool beginCapture=true) | ||
{ | ||
// Destroy the old graph and instance | ||
cleanupGraphHelper(graph, instance); | ||
|
||
// Create a new graph and optionally begin capture | ||
graph = createGraphHelper(stream, beginCapture); | ||
} | ||
|
||
inline hipGraphExec_t endCaptureGraphHelper(hipGraph_t& graph, hipStream_t& stream, const bool launchGraph=false, const bool sync=false) | ||
{ | ||
// End the capture | ||
HIP_CHECK_NON_VOID(hipStreamEndCapture(stream, &graph)); | ||
|
||
// Instantiate the graph | ||
hipGraphExec_t instance; | ||
HIP_CHECK_NON_VOID(hipGraphInstantiate(&instance, graph, nullptr, nullptr, 0)); | ||
|
||
// Optionally launch the graph | ||
if (launchGraph) | ||
HIP_CHECK_NON_VOID(hipGraphLaunch(instance, stream)); | ||
|
||
// Optionally synchronize the stream when we're done | ||
if (sync) | ||
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream)); | ||
|
||
return instance; | ||
} | ||
|
||
inline void launchGraphHelper(hipGraphExec_t& instance, hipStream_t& stream, const bool sync=false) | ||
{ | ||
HIP_CHECK_NON_VOID(hipGraphLaunch(instance, stream)); | ||
|
||
// Optionally sync after the launch | ||
if (sync) | ||
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream)); | ||
} | ||
|
||
} // end namespace test_utils | ||
|
||
#endif //ROCRAND_TEST_UTILS_HIPGRAPHS_HPP |