Skip to content

Commit

Permalink
enable test CApiTest.basic_cuda_graph_with_annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffdaily committed Oct 15, 2024
1 parent f464de8 commit 898293b
Showing 1 changed file with 29 additions and 3 deletions.
32 changes: 29 additions & 3 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& mod
#endif
} else if (provider_type == 3) {
#ifdef USE_ROCM
std::cout << "Running simple inference with rocm provider" << std::endl;
OrtROCMProviderOptions rocm_options;
session_options.AppendExecutionProvider_ROCM(rocm_options);
#else
Expand Down Expand Up @@ -384,7 +385,7 @@ static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& mod
}

static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/mul_1.onnx");
#if defined(USE_CUDA) || defined(USE_DML)
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
static constexpr PATH_TYPE CUDA_GRAPH_ANNOTATION_MODEL_URI = TSTR("testdata/mul_1_dynamic.onnx");
#endif
static constexpr PATH_TYPE MATMUL_MODEL_URI = TSTR("testdata/matmul_1.onnx");
Expand Down Expand Up @@ -2341,7 +2342,7 @@ TEST(CApiTest, basic_cuda_graph) {
#endif
}

#if defined(USE_CUDA) || defined(USE_DML)
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
struct CudaGraphInputOutputData_0 {
const std::array<int64_t, 2> x_shape = {3, 2};
std::array<float, 3 * 2> x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
Expand Down Expand Up @@ -2385,6 +2386,12 @@ static void RunWithCudaGraphAnnotation(T& cg_data,
Ort::MemoryAllocation& input_data,
Ort::MemoryAllocation& output_data,
const char* cuda_graph_annotation) {
// a local hipify of select cuda symbols to avoid code duplication
#ifdef USE_ROCM
#define cudaMemcpy hipMemcpy
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#endif
#ifdef USE_DML
Ort::SessionOptions session_options;
Ort::Allocator allocator(session, info_mem);
Expand Down Expand Up @@ -2488,6 +2495,11 @@ static void RunWithCudaGraphAnnotation(T& cg_data,
// Clean up
binding.ClearBoundInputs();
binding.ClearBoundOutputs();
#ifdef USE_ROCM
#undef cudaMemcpy
#undef cudaMemcpyHostToDevice
#undef cudaMemcpyDeviceToHost
#endif
}

TEST(CApiTest, basic_cuda_graph_with_annotation) {
Expand All @@ -2502,7 +2514,7 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) {
ort_dml_api->SessionOptionsAppendExecutionProvider_DML1(session_options, dml_objects.dml_device.Get(), dml_objects.command_queue.Get());

Ort::MemoryInfo info_mem("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault);
#else
#elif defined(USE_CUDA)
// Enable cuda graph in cuda provider option.
OrtCUDAProviderOptionsV2* cuda_options = nullptr;
ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr);
Expand All @@ -2516,6 +2528,20 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) {
static_cast<OrtSessionOptions*>(session_options),
rel_cuda_options.get()) == nullptr);
Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault);
#elif defined(USE_ROCM)
// Enable hip graph in rocm provider option.
OrtROCMProviderOptions* rocm_options = nullptr;
ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr);
std::unique_ptr<OrtROCMProviderOptions, decltype(api.ReleaseROCMProviderOptions)>
rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions);
std::vector<const char*> keys{"enable_hip_graph"};
std::vector<const char*> values{"1"};
ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr);

ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM(
static_cast<OrtSessionOptions*>(session_options),
rel_rocm_options.get()) == nullptr);
Ort::MemoryInfo info_mem("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault);
#endif

Ort::Session session(*ort_env, CUDA_GRAPH_ANNOTATION_MODEL_URI, session_options);
Expand Down

0 comments on commit 898293b

Please sign in to comment.