Skip to content

Commit

Permalink
Bundle MIGraphX with ROCm when built together (#47)
Browse files Browse the repository at this point in the history
* create package for migraphx ep
* add migrahx to the gpu providers for benchmark.py
* remove rocm from migraphx perfs tests
  • Loading branch information
apwojcik authored Jul 18, 2024
1 parent 479551b commit 5f3b3be
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 5 deletions.
1 change: 1 addition & 0 deletions onnxruntime/python/tools/transformers/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def run_onnxruntime(
if (
use_gpu
and ("CUDAExecutionProvider" not in onnxruntime.get_available_providers())
and ("MIGraphXExecutionProvider" not in onnxruntime.get_available_providers())
and ("ROCMExecutionProvider" not in onnxruntime.get_available_providers())
and ("DmlExecutionProvider" not in onnxruntime.get_available_providers())
):
Expand Down
4 changes: 0 additions & 4 deletions onnxruntime/test/perftest/ort_test_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
} else if (provider_name_ == onnxruntime::kMIGraphXExecutionProvider) {
#ifdef USE_MIGRAPHX
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(session_options, 0));
OrtROCMProviderOptions rocm_options;
rocm_options.miopen_conv_exhaustive_search = performance_test_config.run_config.cudnn_conv_algo;
rocm_options.do_copy_in_default_stream = !performance_test_config.run_config.do_cuda_copy_in_separate_stream;
session_options.AppendExecutionProvider_ROCM(rocm_options);
#else
ORT_THROW("MIGraphX is not supported in this build\n");
#endif
Expand Down
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def parse_arg_remove_string(argv, arg_name_equal):
elif parse_arg_remove_boolean(sys.argv, "--use_rocm"):
is_rocm = True
rocm_version = parse_arg_remove_string(sys.argv, "--rocm_version=")
if parse_arg_remove_boolean(sys.argv, "--use_migraphx"):
is_migraphx = True
elif parse_arg_remove_boolean(sys.argv, "--use_migraphx"):
is_migraphx = True
elif parse_arg_remove_boolean(sys.argv, "--use_openvino"):
Expand All @@ -89,8 +91,10 @@ def parse_arg_remove_string(argv, arg_name_equal):
elif parse_arg_remove_boolean(sys.argv, "--use_qnn"):
package_name = "onnxruntime-qnn"

if is_rocm or is_migraphx:
if is_rocm:
package_name = "onnxruntime-rocm" if not nightly_build else "ort-rocm-nightly"
elif is_migraphx:
package_name = "onnxruntime-migraphx" if not nightly_build else "ort-migraphx-nightly"

# PEP 513 defined manylinux1_x86_64 and manylinux1_i686
# PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686
Expand Down
2 changes: 2 additions & 0 deletions tools/ci_build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2224,6 +2224,8 @@ def build_python_wheel(
args.append("--use_rocm")
if rocm_version:
args.append(f"--rocm_version={rocm_version}")
if use_migraphx:
args.append("--use_migraphx")
elif use_migraphx:
args.append("--use_migraphx")
elif use_openvino:
Expand Down

0 comments on commit 5f3b3be

Please sign in to comment.