Rocm jaxlib v0.4.30 qa nccl maxnchannels #75

Closed
Changes from all commits (46 commits)
56e22fa
[ROCm]: Fix LLVM path issue for ROCm 6.3
Jul 11, 2024
dc89176
[ROCm]: Fix LLVM path for ROCm 6.2
Jul 25, 2024
4d88619
PR #15311: [ROCm] GPU/CPU unified memory for rocm
i-chaochen Jul 29, 2024
973f86b
Merge pull request #33 from ROCm/rocm-jaxlib-v0.4.30-uni_mem
i-chaochen Aug 5, 2024
b1ac447
Let the other stream wait for the main stream before issuing memcpy d2h
zhenying-liu Jun 11, 2024
62b0e7b
Merge pull request #34 from ROCm/rocm-jaxlib-v0.4.30-qa-d2hmem-stream…
i-chaochen Aug 8, 2024
49f81a7
workspace fixing
pemeliya Aug 16, 2024
b61059e
[ROCM] Updated fp8 matmul with adjustments for updated hipBlasLt support
wenchenvincent May 21, 2024
2e76267
[ROCM] Addressed reviewer comment.
wenchenvincent Jun 21, 2024
652fd38
[ROCM] Fix build after 21311f23028acfd4bc2c3e4a3f76bad8c9e640e8
draganmladjenovic Aug 2, 2024
e085f5d
[ROCM] Add basic scaffolding and enable MLIR fusion
draganmladjenovic Mar 28, 2024
b9622e2
Enable dot algorithms for AMD GPUs
hsharsha Jun 14, 2024
d8e87a6
added bias pointer workaround
pemeliya May 22, 2024
83f1366
added precision settings for autotuner and buffer_comparator small re…
pemeliya Jun 5, 2024
a6b98de
added precision settings for autotuner and buffer_comparator small re…
pemeliya Jun 5, 2024
63aeab4
small changes
i-chaochen Jul 4, 2024
3cb9699
adopted changes
i-chaochen Jul 10, 2024
aad6467
Use deterministic ops flag in determinism test
hsharsha Aug 28, 2024
9560ccf
Fix MLIR tests specifically w.r.t number of threads and number of blocks
hsharsha Aug 16, 2024
f57f22a
Disable FP8 rewrite pattern test on ROCm
hsharsha Jul 18, 2024
339dde0
Disable workspace setting
hsharsha Aug 28, 2024
8c73dfe
Merge pull request #36 from ROCm/rocm-jaxlib-v0.4.30-qa-autotuning
hsharsha Aug 29, 2024
8ae1de7
Merge branch 'rocm-jaxlib-v0.4.30-qa' into rocm-jaxlib-v0.4.30-qa-cle…
hsharsha Aug 29, 2024
5945307
Merge pull request #35 from ROCm/rocm-jaxlib-v0.4.30-qa-cleanup
hsharsha Aug 29, 2024
7dc2933
scrub navi
ScXfjiang Sep 11, 2024
ed82401
Merge pull request #42 from ROCm/rocm-jaxlib-v0.4.30-qa_navi_scrub
ScXfjiang Sep 17, 2024
8353dcd
PR #16938: Add NANOO FP8 support for collaborative communication unit…
ScXfjiang Sep 19, 2024
a42d9cd
Merge pull request #45 from ROCm/rocm-jaxlib-v0.4.30-qa_collective_fp8
ScXfjiang Sep 24, 2024
48a9d97
[ROCm] Include clang-19 and clang-20 headers
hsharsha Sep 25, 2024
ed5b782
Merge pull request #48 from ROCm/rocm-jaxlib-v0.4.30-qa-clang20
i-chaochen Sep 25, 2024
7fd3ae6
Reset stream function in Gemm algorithm picker (#39)
hsharsha Oct 2, 2024
f3e91a6
[ROCm] Fixed linker issues with rocblas_get_version_string_size and r…
zoranjovanovic-ns Oct 9, 2024
4ea5b6f
Add multigpu script and disable triton tests
Oct 10, 2024
c718ef3
Merge pull request #53 from ROCm/rocm-jaxlib-v0.4.30-qa-multigpu-disa…
hsharsha Oct 15, 2024
e8b1ff4
[ROCm] Added include of hipblas.h in hipblaslt_wrapper.h
jayfurmanek Oct 22, 2024
e2dde69
Merge pull request #55 from ROCm/rocm-jaxlib-v0.4.30-qa-hipblasfix
jayfurmanek Oct 22, 2024
e2b918d
buffer init fix and gpu_hlo_runner test
pemeliya Oct 28, 2024
bf81e49
Merge pull request #59 from ROCm/r0.4.30_buffer_init_and_hlo_runner
pemeliya Oct 30, 2024
4951842
[ROCm] Fixed linker issues related to fp8 buffer_comparator functions
zoranjovanovic-ns Oct 22, 2024
98a8fe9
SWDEV-492517 (#65)
zoranjovanovic-ns Nov 13, 2024
49a7651
Merge pull request #66 from ROCm/rocm-jaxlib-v0.4.30-qa-SWDEV-476829-2
jayfurmanek Nov 13, 2024
d9660ac
[ROCm] Pass AMDGPU_TARGETS to crosstool wrapper
hsharsha Sep 24, 2024
08d8691
[Rocm] fix arch
Ruturaj4 Oct 10, 2024
430d8c3
Merge pull request #63 from ROCm/ci_clang_31_2
pramenku Nov 6, 2024
1d70f15
Skip gpu_hlo_runner_test if input is not provided (#71)
mmakevic-amd Nov 20, 2024
ee307e3
Add NCCL_MAX_NCHANNELS env variable to multi gpu tests
hsharsha Nov 28, 2024
80 changes: 80 additions & 0 deletions build_tools/rocm/run_xla_multi_gpu.sh
@@ -0,0 +1,80 @@
#!/usr/bin/env bash
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================

set -e
set -x

N_BUILD_JOBS=$(grep -c ^processor /proc/cpuinfo)
# If rocm-smi exists locally (it should) use it to find
# out how many GPUs we have to test with.
rocm-smi -i
STATUS=$?
if [ $STATUS -ne 0 ]; then TF_GPU_COUNT=1; else
TF_GPU_COUNT=$(rocm-smi -i|grep 'Device ID' |grep 'GPU' |wc -l)
fi
if [[ $TF_GPU_COUNT -lt 4 ]]; then
echo "Found only ${TF_GPU_COUNT} gpus, multi-gpu tests need atleast 4 gpus."
exit
fi

TF_TESTS_PER_GPU=1
N_TEST_JOBS=$(expr ${TF_GPU_COUNT} \* ${TF_TESTS_PER_GPU})

echo ""
echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS} concurrent test job(s)."
echo ""

# First positional argument (if any) specifies the ROCM_INSTALL_DIR
if [[ -n $1 ]]; then
ROCM_INSTALL_DIR=$1
else
if [[ -z "${ROCM_PATH}" ]]; then
ROCM_INSTALL_DIR=/opt/rocm-6.0.2
else
ROCM_INSTALL_DIR=$ROCM_PATH
fi
fi

export PYTHON_BIN_PATH=`which python3`
export TF_NEED_ROCM=1
export ROCM_PATH=$ROCM_INSTALL_DIR
TAGS_FILTER="-oss_excluded,-oss_serial"
UNSUPPORTED_GPU_TAGS="$(echo -requires-gpu-sm{60,70,80,86,89,90}{,-only})"
TAGS_FILTER="${TAGS_FILTER},${UNSUPPORTED_GPU_TAGS// /,}"

bazel \
test \
--config=rocm \
--build_tag_filters=${TAGS_FILTER} \
--test_tag_filters=${TAGS_FILTER} \
--test_timeout=920,2400,7200,9600 \
--test_sharding_strategy=disabled \
--test_output=errors \
--flaky_test_attempts=3 \
--keep_going \
--local_test_jobs=${N_TEST_JOBS} \
--test_env=TF_TESTS_PER_GPU=$TF_TESTS_PER_GPU \
--test_env=TF_GPU_COUNT=$TF_GPU_COUNT \
--action_env=XLA_FLAGS=--xla_gpu_force_compilation_parallelism=16 \
--action_env=XLA_FLAGS=--xla_gpu_enable_llvm_module_compilation_parallelism=true \
--action_env=NCCL_MAX_NCHANNELS=1 \
-- //xla/tests:collective_ops_test_e2e_gpu_amd_any \
//xla/tests:collective_ops_test_gpu_amd_any \
//xla/tests:replicated_io_feed_test_gpu_amd_any \
//xla/tools/multihost_hlo_runner:functional_hlo_runner_test_gpu_amd_any \
//xla/pjrt/distributed:topology_util_test \
//xla/pjrt/distributed:client_server_test
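For reference, a minimal Python sketch (not part of this PR) of setting the same variables by hand, e.g. when reproducing a single test outside Bazel; the GPU count here is an assumed value for a 4-GPU node:

import os

# Sketch only: mirror the variables the script above exports / forwards to tests.
os.environ["NCCL_MAX_NCHANNELS"] = "1"  # cap the number of NCCL/RCCL channels
os.environ["TF_GPU_COUNT"] = "4"        # assumed GPU count on the test node
os.environ["TF_TESTS_PER_GPU"] = "1"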
35 changes: 35 additions & 0 deletions third_party/llvm/rocdl_shuffle_down.patch
@@ -0,0 +1,35 @@
From a46b9e979ffa523bfed61487a2404e1f48140288 Mon Sep 17 00:00:00 2001
From: Dragan Mladjenovic <[email protected]>
Date: Fri, 29 Mar 2024 12:27:36 +0000
Subject: [PATCH] Support gpu::ShuffleMode::DOWN lowering

---
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e2cb3687d872..9317e30290c6 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -140,7 +140,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value width = adaptor.getWidth();
Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
@@ -151,6 +151,10 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
// TODO: Use ds_swizzle for XOR when step/offsets are constants for better
// perf.
switch (op.getMode()) {
+ case gpu::ShuffleMode::DOWN:
+ dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
+ adaptor.getOffset());
+ break;
case gpu::ShuffleMode::XOR:
dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
adaptor.getOffset());
--
2.25.1
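For context, the added DOWN case simply offsets the source lane; a one-line Python illustration (hypothetical helper, not part of the patch) of the lane index it produces, with range checking left to the unchanged parts of the lowering:

def shuffle_down_dst_lane(src_lane_id, offset):
    # Mirrors the new case in the patch: dstLane = srcLaneId + offset.
    return src_lane_id + offset

print(shuffle_down_dst_lane(5, 2))  # lane 5 reads from lane 7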

1 change: 1 addition & 0 deletions third_party/llvm/workspace.bzl
@@ -23,6 +23,7 @@ def repo(name):
"//third_party/llvm:mathextras.patch",
"//third_party/llvm:toolchains.patch",
"//third_party/llvm:zstd.patch",
"//third_party/llvm:rocdl_shuffle_down.patch",
],
link_files = {"//third_party/llvm:run_lit.sh": "mlir/run_lit.sh"},
)
27 changes: 27 additions & 0 deletions third_party/triton/temporary/amd_pr7.patch
@@ -0,0 +1,27 @@
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
index b0976f8..bcdc5c7 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -956,6 +956,22 @@ struct FpToFpOpConversion
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
inVals.push_back(operands[i][0]);
}
+
+ bool isSrcFP16 = srcElementType.isF16();
+ bool isSrcBF16 = srcElementType.isBF16();
+
+ if ((isSrcFP16 || isSrcBF16)
+ && isDstFP32) {
+ SmallVector<Value> outVals;
+ for (Value &v : inVals) {
+ if(isSrcFP16)
+ outVals.push_back(convertFp16ToFp32(loc, rewriter, v));
+ else
+ outVals.push_back(convertBf16ToFp32(loc, rewriter, v));
+ }
+ return outVals;
+ }
+
if (useFP16IntermediateSrc)
for (Value &v : inVals)
v = cvtFp32ToFp16(loc, rewriter, v,
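The hunk above appears to add a shortcut that upcasts fp16/bf16 inputs directly to fp32 and returns early instead of falling through to the generic conversion below. A rough Python analogue of that control flow (illustrative names only; numpy stands in for the real converters):

import numpy as np

def upcast_half_inputs(in_vals, src_is_fp16, src_is_bf16, dst_is_fp32):
    # Sketch of the added early-return: half-precision sources go straight to fp32.
    if (src_is_fp16 or src_is_bf16) and dst_is_fp32:
        # Stand-ins for convertFp16ToFp32 / convertBf16ToFp32 in the real lowering.
        return [np.float32(v) for v in in_vals]
    return None  # fall through to the generic conversion path (unchanged)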
4 changes: 3 additions & 1 deletion third_party/triton/temporary/series.bzl
@@ -5,4 +5,6 @@ These are created temporarily and should be moved to the first copybara workflow
internal patch during the next triton integration process.
"""

temporary_patch_list = []
temporary_patch_list = [
"//third_party/triton/temporary:amd_pr7.patch",
]
@@ -32,6 +32,7 @@ HIP_RUNTIME_LIBRARY = '%{hip_runtime_library}'
ROCR_RUNTIME_PATH = '%{rocr_runtime_path}'
ROCR_RUNTIME_LIBRARY = '%{rocr_runtime_library}'
VERBOSE = '%{crosstool_verbose}'=='1'
ROCM_AMDGPU_TARGETS = '%{rocm_amdgpu_targets}'

def Log(s):
print('gpus/crosstool: {0}'.format(s))
@@ -96,6 +97,29 @@ def GetHostCompilerOptions(argv):

return opts

def GetHipccOptions(argv):
"""Collect the -hipcc_options values from argv.

Args:
argv: A list of strings, possibly the argv passed to main().

Returns:
The string that can be passed directly to hipcc.
"""

parser = ArgumentParser()
parser.add_argument('--offload-arch', nargs='*', action='append')
# TODO find a better place for this
parser.add_argument('-gline-tables-only', action='store_true')

args, _ = parser.parse_known_args(argv)

hipcc_opts = ' -gline-tables-only ' if args.gline_tables_only else ''
if args.offload_arch:
hipcc_opts = hipcc_opts + ' '.join(['--offload-arch=' + a for a in sum(args.offload_arch, [])])

return hipcc_opts

def system(cmd):
"""Invokes cmd with os.system().

@@ -112,7 +136,6 @@ def system(cmd):
else:
return -os.WTERMSIG(retv)


def InvokeHipcc(argv, log=False):
"""Call hipcc with arguments assembled from argv.

@@ -125,6 +148,7 @@
"""

host_compiler_options = GetHostCompilerOptions(argv)
hipcc_compiler_options = GetHipccOptions(argv)
opt_option = GetOptionValue(argv, 'O')
m_options = GetOptionValue(argv, 'm')
m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
@@ -163,7 +187,7 @@
srcs = ' '.join(src_files)
out = ' -o ' + out_file[0]

hipccopts = ' '
hipccopts = hipcc_compiler_options + ' '
# In hip-clang environment, we need to make sure that hip header is included
# before some standard math header like <complex> is included in any source.
# Otherwise, we get build error.
@@ -219,6 +243,7 @@ def main():

if VERBOSE: print('PWD=' + os.getcwd())
if VERBOSE: print('HIPCC_ENV=' + HIPCC_ENV)
if VERBOSE: print('ROCM_AMDGPU_TARGETS= ' + ROCM_AMDGPU_TARGETS)

if args.x and args.x[0] == 'rocm':
# compilation for GPU objects
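The diff above (apparently the ROCm crosstool wrapper driver) adds GetHipccOptions; a quick usage sketch with made-up arguments shows what it forwards to hipcc, assuming the function is importable from the wrapper:

argv = ["--offload-arch=gfx90a", "--offload-arch=gfx942", "-gline-tables-only", "-O3"]
print(GetHipccOptions(argv))
# roughly: " -gline-tables-only --offload-arch=gfx90a --offload-arch=gfx942"
# (-O3 is ignored here; it is handled by the rest of the wrapper)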
10 changes: 9 additions & 1 deletion third_party/tsl/third_party/gpus/rocm_configure.bzl
@@ -205,6 +205,10 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin):
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17.0.0/include")
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17/include")
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/18/include")
if int(rocm_config.rocm_version_number) >= 60200:
inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/18/include")
inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/19/include")
inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/20/include")

# Support hcc based off clang 10.0.0 (for ROCm 3.3)
inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")
@@ -535,7 +539,7 @@ def _genrule(src_dir, genrule_name, command, outs):
)

def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets):
amdgpu_target_flags = ["--amdgpu-target=" +
amdgpu_target_flags = ["--offload-arch=" +
amdgpu_target for amdgpu_target in amdgpu_targets]
return str(amdgpu_target_flags)

@@ -707,6 +711,7 @@ def _create_local_rocm_repository(repository_ctx):
"-DTENSORFLOW_USE_ROCM=1",
"-D__HIP_PLATFORM_AMD__",
"-DEIGEN_USE_HIP",
"-DUSE_ROCM",
])

rocm_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
@@ -744,6 +749,9 @@ def _create_local_rocm_repository(repository_ctx):
"%{hip_runtime_library}": "amdhip64",
"%{crosstool_verbose}": _crosstool_verbose(repository_ctx),
"%{gcc_host_compiler_path}": str(cc),
"%{rocm_amdgpu_targets}": ",".join(
["\"%s\"" % c for c in rocm_config.amdgpu_targets],
),
},
)

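To make the rocm_configure.bzl change concrete, a small Python sketch (with assumed gfx targets, not values taken from the PR) of what the updated helpers now produce:

# Hypothetical targets; the real list comes from rocm_config.amdgpu_targets.
amdgpu_targets = ["gfx90a", "gfx942"]
extra_copts = ["--offload-arch=" + t for t in amdgpu_targets]
# ["--offload-arch=gfx90a", "--offload-arch=gfx942"]  (previously --amdgpu-target=...)
rocm_amdgpu_targets = ",".join(['"%s"' % t for t in amdgpu_targets])
# '"gfx90a","gfx942"'  -- the value substituted for %{rocm_amdgpu_targets} in the crosstool wrapper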
35 changes: 35 additions & 0 deletions third_party/tsl/third_party/llvm/rocdl_shuffle_down.patch
@@ -0,0 +1,35 @@
From a46b9e979ffa523bfed61487a2404e1f48140288 Mon Sep 17 00:00:00 2001
From: Dragan Mladjenovic <[email protected]>
Date: Fri, 29 Mar 2024 12:27:36 +0000
Subject: [PATCH] Support gpu::ShuffleMode::DOWN lowering

---
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e2cb3687d872..9317e30290c6 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -140,7 +140,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value width = adaptor.getWidth();
Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
@@ -151,6 +151,10 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
// TODO: Use ds_swizzle for XOR when step/offsets are constants for better
// perf.
switch (op.getMode()) {
+ case gpu::ShuffleMode::DOWN:
+ dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
+ adaptor.getOffset());
+ break;
case gpu::ShuffleMode::XOR:
dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
adaptor.getOffset());
--
2.25.1

1 change: 1 addition & 0 deletions third_party/tsl/third_party/llvm/workspace.bzl
@@ -23,6 +23,7 @@ def repo(name):
"//third_party/llvm:mathextras.patch",
"//third_party/llvm:toolchains.patch",
"//third_party/llvm:zstd.patch",
"//third_party/llvm:rocdl_shuffle_down.patch",
],
link_files = {"//third_party/llvm:run_lit.sh": "mlir/run_lit.sh"},
)
15 changes: 12 additions & 3 deletions xla/debug_options_flags.cc
@@ -49,7 +49,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_llvm_enable_invariant_load_metadata(true);
opts.set_xla_llvm_disable_expensive_passes(false);
opts.set_xla_backend_optimization_level(3);
opts.set_xla_gpu_autotune_level(4);
opts.set_xla_gpu_autotune_level(5);
opts.set_xla_gpu_autotune_max_solutions(0);
opts.set_xla_cpu_multi_thread_eigen(true);
opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
@@ -180,7 +180,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_partitioning_algorithm(
DebugOptions::PARTITIONING_ALGORITHM_NOOP);

opts.set_xla_gpu_enable_triton_gemm(true);
opts.set_xla_gpu_enable_triton_gemm(false);
opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true);
opts.set_xla_gpu_triton_gemm_any(false);
opts.set_xla_gpu_enable_triton_softmax_fusion(false);
@@ -270,6 +270,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_shard_autotuning(false);

opts.set_xla_gpu_autotune_gemm_rtol(0.1f);

return opts;
}

@@ -816,13 +818,20 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level),
debug_options->xla_gpu_autotune_level(),
"Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; 2 = "
"on+init; 3 = on+init+reinit; 4 = on+init+reinit+check."));
"on+init; 3 = on+init+reinit; 4 = on+init+reinit+check; "
" 5 = on+init+reinit+check and skip WRONG_RESULT solutions. See also "
" the related flag xla_gpu_autotune_gemm_rtol."));
flag_list->push_back(tsl::Flag(
"xla_gpu_autotune_max_solutions",
int64_setter_for(&DebugOptions::set_xla_gpu_autotune_max_solutions),
debug_options->xla_gpu_autotune_max_solutions(),
"Maximal number of GEMM solutions to consider for autotuning: 0 means "
"consider all solutions returned by the GEMM library."));
flag_list->push_back(tsl::Flag(
"xla_gpu_autotune_gemm_rtol",
float_setter_for(&DebugOptions::set_xla_gpu_autotune_gemm_rtol),
debug_options->xla_gpu_autotune_gemm_rtol(),
"Relative precision for comparing GEMM solutions vs the reference one"));
flag_list->push_back(tsl::Flag(
"xla_force_host_platform_device_count",
int32_setter_for(&DebugOptions::set_xla_force_host_platform_device_count),
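A usage sketch (not part of the diff) for exercising the new default autotune level together with the companion tolerance flag; the rtol value below is an arbitrary example:

import os

# Both flag names come from the diff above; 0.2 is a made-up override of the 0.1 default.
os.environ["XLA_FLAGS"] = ("--xla_gpu_autotune_level=5 "
                           "--xla_gpu_autotune_gemm_rtol=0.2")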