Skip to content

Commit

Permalink
SYCL: Fixing breaking issue of sort related kernel level API due to I…
Browse files Browse the repository at this point in the history
…ntel SYCL compiler uplift (#1017)
  • Loading branch information
chunhuanMeng authored Oct 31, 2024
1 parent 43dfdbb commit 668f5c2
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 12 deletions.
3 changes: 2 additions & 1 deletion cmake/BuildFlags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "MSVC"
if(USE_PER_OPERATOR_HEADERS)
list(APPEND SYCL_HOST_FLAGS -DAT_PER_OPERATOR_HEADERS)
endif()

list(APPEND SYCL_HOST_FLAGS -D__INTEL_LLVM_COMPILER_VERSION=${__INTEL_LLVM_COMPILER})
# -- Kernel flags (SYCL_KERNEL_OPTIONS)
# Fast-math is enabled by default in the SYCL compiler.
# Refer to [https://clang.llvm.org/docs/UsersManual.html#cmdoption-fno-fast-math]
Expand Down Expand Up @@ -89,6 +89,7 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "MSVC"
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -D__INTEL_PREVIEW_BREAKING_CHANGES)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -D_GLIBCXX_USE_CXX11_ABI=${GLIBCXX_USE_CXX11_ABI})
endif()
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -D__INTEL_LLVM_COMPILER_VERSION=${__INTEL_LLVM_COMPILER})

CHECK_SYCL_FLAG("-fsycl-fp64-conv-emu" SUPPORTS_FP64_CONV_EMU)
if(SUPPORTS_FP64_CONV_EMU)
Expand Down
57 changes: 46 additions & 11 deletions cmake/Modules/FindSYCLToolkit.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ endif()
if(SYCLTOOLKIT_FOUND)
return()
endif()

set(SYCLTOOLKIT_FOUND TRUE)

include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake)
Expand Down Expand Up @@ -77,7 +78,7 @@ endif()

# Function to write a test case to verify SYCL features.

function(SYCL_CMPLR_TEST_WRITE src)
function(SYCL_CMPLR_TEST_WRITE src macro_name)

set(cpp_macro_if "#if")
set(cpp_macro_endif "#endif")
Expand All @@ -88,8 +89,8 @@ function(SYCL_CMPLR_TEST_WRITE src)

# Feature tests go here

string(APPEND SYCL_CMPLR_TEST_CONTENT "${cpp_macro_if} defined(SYCL_LANGUAGE_VERSION)\n")
string(APPEND SYCL_CMPLR_TEST_CONTENT "cout << \"SYCL_LANGUAGE_VERSION=\"<<SYCL_LANGUAGE_VERSION<<endl;\n")
string(APPEND SYCL_CMPLR_TEST_CONTENT "${cpp_macro_if} defined(${macro_name})\n")
string(APPEND SYCL_CMPLR_TEST_CONTENT "cout << \"${macro_name}=\"<<${macro_name}<<endl;\n")
string(APPEND SYCL_CMPLR_TEST_CONTENT "${cpp_macro_endif}\n")

string(APPEND SYCL_CMPLR_TEST_CONTENT "return 0;}\n")
Expand All @@ -103,6 +104,7 @@ endfunction()
function(SYCL_CMPLR_TEST_BUILD error TEST_SRC_FILE TEST_EXE)

set(SYCL_CXX_FLAGS_LIST "${SYCL_CXX_FLAGS}")
string(REPLACE "-Wno-stringop-overflow" "" SYCL_CXX_FLAGS_LIST "${SYCL_CXX_FLAGS_LIST}")
separate_arguments(SYCL_CXX_FLAGS_LIST)

execute_process(
Expand Down Expand Up @@ -150,19 +152,19 @@ function(SYCL_CMPLR_TEST_RUN error TEST_EXE)

endfunction()

function(SYCL_CMPLR_TEST_EXTRACT test_output)
function(SYCL_CMPLR_TEST_EXTRACT test_output macro_name)

string(REGEX REPLACE "\n" ";" test_output_list "${test_output}")

set(SYCL_LANGUAGE_VERSION "")
set(${macro_name} "")
foreach(strl ${test_output_list})
if(${strl} MATCHES "^SYCL_LANGUAGE_VERSION=([A-Za-z0-9_]+)$")
string(REGEX REPLACE "^SYCL_LANGUAGE_VERSION=" "" extracted_sycl_lang "${strl}")
set(SYCL_LANGUAGE_VERSION ${extracted_sycl_lang})
if(${strl} MATCHES "^${macro_name}=([A-Za-z0-9_]+)$")
string(REGEX REPLACE "^${macro_name}=" "" extracted_sycl_lang "${strl}")
set(${macro_name} ${extracted_sycl_lang})
endif()
endforeach()

set(SYCL_LANGUAGE_VERSION "${SYCL_LANGUAGE_VERSION}" PARENT_SCOPE)
set(${macro_name} "${extracted_sycl_lang}" PARENT_SCOPE)
endfunction()

set(SYCL_FLAGS "")
Expand All @@ -189,7 +191,7 @@ if(${has_werror} EQUAL -1)
# Create the test source file
set(TEST_SRC_FILE "${SYCL_CMPLR_TEST_DIR}/sycl_features.cpp")
set(TEST_EXE "${TEST_SRC_FILE}.exe")
SYCL_CMPLR_TEST_WRITE(${TEST_SRC_FILE})
SYCL_CMPLR_TEST_WRITE(${TEST_SRC_FILE} "SYCL_LANGUAGE_VERSION")

# Build the test and create test executable
SYCL_CMPLR_TEST_BUILD(error ${TEST_SRC_FILE} ${TEST_EXE})
Expand All @@ -204,7 +206,7 @@ if(${has_werror} EQUAL -1)
endif()

# Extract test output for information
SYCL_CMPLR_TEST_EXTRACT(${test_output})
SYCL_CMPLR_TEST_EXTRACT(${test_output} "SYCL_LANGUAGE_VERSION")

# As per the specification, all SYCL-compatible compilers should
# define the macro SYCL_LANGUAGE_VERSION
Expand All @@ -221,5 +223,38 @@ if(${has_werror} EQUAL -1)
set(SYCL_LANGUAGE_VERSION "${SYCL_LANGUAGE_VERSION}" CACHE STRING "SYCL Language version")
endif()

# ---------------------------------------------------------------------------
# Probe the SYCL compiler for the value of the __INTEL_LLVM_COMPILER macro.
#
# A tiny program that prints "__INTEL_LLVM_COMPILER=<value>" (only when the
# macro is defined) is generated, built and run with the SYCL_CMPLR_TEST_*
# helpers; the printed value is then parsed into the CMake variable
# __INTEL_LLVM_COMPILER and stored in the cache. cmake/BuildFlags.cmake
# forwards that value to the compiler as -D__INTEL_LLVM_COMPILER_VERSION=...
#
# Outcomes:
#   * build or run failure   -> configuration stops with FATAL_ERROR
#   * macro not reported     -> SYCLTOOLKIT_FOUND is set to False and
#                               SYCL_REASON_FAILURE / SYCL_NOT_FOUND_MESSAGE
#                               carry the reason
#   * macro reported         -> __INTEL_LLVM_COMPILER holds its value
# ---------------------------------------------------------------------------

# Create a clean working directory.
# Paths are quoted so build trees containing spaces or ';' are handled.
set(SYCL_CMPLR_TEST_DIR "${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/TESTSYCLCMPLR")
file(REMOVE_RECURSE "${SYCL_CMPLR_TEST_DIR}")
file(MAKE_DIRECTORY "${SYCL_CMPLR_TEST_DIR}")

# Create the test source file
set(TEST_SRC_FILE "${SYCL_CMPLR_TEST_DIR}/llvm_features.cpp")
set(TEST_EXE "${TEST_SRC_FILE}.exe")
SYCL_CMPLR_TEST_WRITE("${TEST_SRC_FILE}" "__INTEL_LLVM_COMPILER")

# Build the test and create test executable
SYCL_CMPLR_TEST_BUILD(error "${TEST_SRC_FILE}" "${TEST_EXE}")
if(error)
  message(FATAL_ERROR "Can not build SYCL_CMPLR_TEST")
endif()

# Execute the test to extract information
SYCL_CMPLR_TEST_RUN(error "${TEST_EXE}")
if(error)
  message(FATAL_ERROR "Can not run SYCL_CMPLR_TEST")
endif()

# Extract test output for information.
# NOTE(review): test_output is presumably published by SYCL_CMPLR_TEST_RUN
# via PARENT_SCOPE -- confirm against that function.
# The expansion must be quoted: when the program prints nothing (macro not
# defined) an unquoted ${test_output} expands to zero arguments, the call no
# longer matches the two-parameter signature, and CMake aborts before the
# "not found" handling below can run.
SYCL_CMPLR_TEST_EXTRACT("${test_output}" "__INTEL_LLVM_COMPILER")

# Check whether the value of __INTEL_LLVM_COMPILER macro was successfully extracted.
# The result variable keeps its existing name (nosycllang) in case code
# further down the file reads it.
string(COMPARE EQUAL "${__INTEL_LLVM_COMPILER}" "" nosycllang)
if(nosycllang)
  set(SYCLTOOLKIT_FOUND False)
  set(SYCL_REASON_FAILURE "Can not find __INTEL_LLVM_COMPILER")
  set(SYCL_NOT_FOUND_MESSAGE "${SYCL_REASON_FAILURE}")
endif()


# Include in Cache
set(__INTEL_LLVM_COMPILER "${__INTEL_LLVM_COMPILER}" CACHE STRING "Intel llvm compiler")

message(DEBUG "The SYCL compiler is ${SYCL_COMPILER}")
message(DEBUG "The SYCL Flags are ${SYCL_FLAGS}")
9 changes: 9 additions & 0 deletions src/ATen/native/xpu/sycl/TensorModeKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -792,11 +792,20 @@ void mode_kernel_impl(
auto group_size = problem_size;

// scratch memory size needed by built-in sort
#if defined(__INTEL_LLVM_COMPILER_VERSION) && \
__INTEL_LLVM_COMPILER_VERSION >= 20250000
auto sort_scratch_memory_size =
sycl::ext::oneapi::experimental::default_sorters::group_sorter<
ModeOpValueIndex<scalar_t>,
std::greater<scalar_t>,
1>::memory_required(sycl::memory_scope::work_group, group_size);
#else
auto sort_scratch_memory_size = sycl::ext::oneapi::experimental::
default_sorter<std::greater<scalar_t>>::template memory_required<
ModeOpValueIndex<scalar_t>>(
sycl::memory_scope::work_group,
sycl::range<1>{static_cast<size_t>(group_size)});
#endif

auto values_info = getTensorInfo<scalar_t, int64_t>(values_transposed);
auto indices_info = getTensorInfo<int64_t, int64_t>(indices_transposed);
Expand Down

0 comments on commit 668f5c2

Please sign in to comment.