Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sycl Half #1710

Open
wants to merge 19 commits into
base: half_batch
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ option(GINKGO_FAST_TESTS "Reduces the input size for a few tests known to be tim
option(GINKGO_TEST_NONDEFAULT_STREAM "Uses non-default streams in CUDA and HIP tests" OFF)
option(GINKGO_MIXED_PRECISION "Instantiate true mixed-precision kernels (otherwise they will be conversion-based using implicit temporary storage)" OFF)
option(GINKGO_ENABLE_HALF "Enable the use of half precision" ON)
# We do not support MSVC. SYCL will come later
if(MSVC OR GINKGO_BUILD_SYCL)
message(STATUS "HALF is not supported in MSVC, and later support in SYCL")
# We do not support MSVC.
if(MSVC)
yhmtsai marked this conversation as resolved.
Show resolved Hide resolved
message(STATUS "HALF is not supported in MSVC")
set(GINKGO_ENABLE_HALF OFF CACHE BOOL "Enable the use of half precision" FORCE)
endif()
option(GINKGO_SKIP_DEPENDENCY_UPDATE
Expand Down Expand Up @@ -304,9 +304,11 @@ endif()

if(GINKGO_BUILD_SYCL)
ginkgo_extract_dpcpp_version(${CMAKE_CXX_COMPILER} GINKGO_DPCPP_MAJOR_VERSION __LIBSYCL_MAJOR_VERSION)
ginkgo_extract_dpcpp_version(${CMAKE_CXX_COMPILER} GINKGO_DPCPP_MINOR_VERSION __LIBSYCL_MINOR_VERSION)
ginkgo_extract_dpcpp_version(${CMAKE_CXX_COMPILER} GINKGO_DPCPP_VERSION __SYCL_COMPILER_VERSION)
else()
set(GINKGO_DPCPP_MAJOR_VERSION "0")
set(GINKGO_DPCPP_MINOR_VERSION "0")
endif()
configure_file(${Ginkgo_SOURCE_DIR}/include/ginkgo/config.hpp.in
${Ginkgo_BINARY_DIR}/include/ginkgo/config.hpp @ONLY)
Expand Down
202 changes: 202 additions & 0 deletions accessor/sycl_helper.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_ACCESSOR_SYCL_HELPER_HPP_
#define GKO_ACCESSOR_SYCL_HELPER_HPP_


#include <complex>
#include <type_traits>

#include "block_col_major.hpp"
#include "reduced_row_major.hpp"
#include "row_major.hpp"
#include "scaled_reduced_row_major.hpp"
#include "utils.hpp"


// namespace sycl {
// inline namespace _V1 {


// class half;


// }
// } // namespace sycl


namespace gko {


class half;


template <typename V>
class complex;


namespace acc {
namespace detail {


/**
 * Maps a type `T` to the type that represents it in SYCL code.
 *
 * The primary template is the identity mapping. The specializations below
 * replace gko::half, look through cv / pointer / reference qualifiers, and
 * translate the value type of std::complex.
 *
 * @tparam T  the type to map
 */
template <typename T>
struct sycl_type {
    using type = T;
};

// gko::half is replaced by sycl::half.
template <>
struct sycl_type<gko::half> {
    using type = sycl::half;
};

// For cv, pointer and reference qualifiers: map the unqualified type first,
// then put the same qualifier back onto the mapped type.
template <typename T>
struct sycl_type<const T> {
    using type = std::add_const_t<typename sycl_type<T>::type>;
};

template <typename T>
struct sycl_type<volatile T> {
    using type = std::add_volatile_t<typename sycl_type<T>::type>;
};

template <typename T>
struct sycl_type<T*> {
    using type = std::add_pointer_t<typename sycl_type<T>::type>;
};

template <typename T>
struct sycl_type<T&> {
    using type = std::add_lvalue_reference_t<typename sycl_type<T>::type>;
};

template <typename T>
struct sycl_type<T&&> {
    using type = std::add_rvalue_reference_t<typename sycl_type<T>::type>;
};


// std::complex keeps its wrapper; only the underlying value type is mapped.
template <typename T>
struct sycl_type<std::complex<T>> {
    using type = std::complex<typename sycl_type<T>::type>;
};


// std::complex<gko::half> is the exception: it maps to gko::complex (not
// std::complex) of the mapped half type.
template <>
struct sycl_type<std::complex<gko::half>> {
    using type = gko::complex<sycl_type<gko::half>::type>;
};


} // namespace detail


/**
 * This is an alias for SYCL's equivalent of `T`, i.e. a shorthand for
 * `detail::sycl_type<T>::type`.
 *
 * Types without a matching `detail::sycl_type` specialization are left
 * unchanged, because the primary template maps `T` to itself.
 *
 * @tparam T  a type
 */
template <typename T>
using sycl_type_t = typename detail::sycl_type<T>::type;


/**
 * Reinterprets the passed in value as a SYCL type.
 *
 * This overload is selected when `T` is a pointer or reference type; the
 * value is cast directly to `sycl_type_t<T>`.
 *
 * @tparam T  the (pointer or reference) type of the value
 *
 * @param val  the value to reinterpret
 *
 * @return `val` reinterpreted to SYCL type
 */
template <typename T>
std::enable_if_t<(std::is_pointer<T>::value || std::is_reference<T>::value),
                 sycl_type_t<T>>
as_sycl_type(T val)
{
    using target_type = sycl_type_t<T>;
    return reinterpret_cast<target_type>(val);
}


/**
 * Reinterprets the passed in value as a SYCL type.
 *
 * This overload is selected when `T` is neither a pointer nor a reference
 * type; the local copy of the value is read through a pointer to
 * `sycl_type_t<T>`.
 *
 * @tparam T  the (non-pointer, non-reference) type of the value
 *
 * @param val  the value to reinterpret
 *
 * @return `val` reinterpreted to SYCL type
 */
template <typename T>
std::enable_if_t<!(std::is_pointer<T>::value || std::is_reference<T>::value),
                 sycl_type_t<T>>
as_sycl_type(T val)
{
    using target_type = sycl_type_t<T>;
    auto* const reinterpreted = reinterpret_cast<target_type*>(&val);
    return *reinterpreted;
}


/**
 * Changes the types of the passed in range to their SYCL equivalents and
 * reinterprets its pointers as SYCL pointers.
 *
 * This overload handles ranges using the reduced_row_major accessor.
 *
 * @param r  the range whose pointers need to be reinterpreted
 *
 * @return `r` with appropriate types and reinterpreted to SYCL pointers
 */
template <std::size_t dim, typename Type1, typename Type2>
GKO_ACC_INLINE auto as_sycl_range(
    const range<reduced_row_major<dim, Type1, Type2>>& r)
{
    using sycl_accessor =
        reduced_row_major<dim, sycl_type_t<Type1>, sycl_type_t<Type2>>;
    const auto& src = r.get_accessor();
    return range<sycl_accessor>(src.get_size(),
                                as_sycl_type(src.get_stored_data()),
                                src.get_stride());
}

/**
 * Changes the types of the passed in range to their SYCL equivalents and
 * reinterprets its pointers as SYCL pointers.
 *
 * This overload handles ranges using the scaled_reduced_row_major accessor;
 * both the stored data and the scalar pointer are reinterpreted.
 *
 * @param r  the range whose pointers need to be reinterpreted
 *
 * @return `r` with appropriate types and reinterpreted to SYCL pointers
 */
template <std::size_t dim, typename Type1, typename Type2, std::uint64_t mask>
GKO_ACC_INLINE auto as_sycl_range(
    const range<scaled_reduced_row_major<dim, Type1, Type2, mask>>& r)
{
    using sycl_accessor =
        scaled_reduced_row_major<dim, sycl_type_t<Type1>, sycl_type_t<Type2>,
                                 mask>;
    const auto& src = r.get_accessor();
    return range<sycl_accessor>(
        src.get_size(), as_sycl_type(src.get_stored_data()),
        src.get_storage_stride(), as_sycl_type(src.get_scalar()),
        src.get_scalar_stride());
}

/**
 * Changes the value type of the passed in range to its SYCL equivalent and
 * reinterprets its data pointer as a SYCL pointer.
 *
 * This overload handles ranges using the block_col_major accessor.
 *
 * @param r  the range whose pointer needs to be reinterpreted
 *
 * @return `r` with appropriate types and reinterpreted to SYCL pointers
 */
template <typename T, size_type dim>
GKO_ACC_INLINE auto as_sycl_range(const range<block_col_major<T, dim>>& r)
{
    using sycl_accessor = block_col_major<sycl_type_t<T>, dim>;
    const auto& src = r.get_accessor();
    return range<sycl_accessor>(src.lengths, as_sycl_type(src.data),
                                src.stride);
}

/**
 * Changes the value type of the passed in range to its SYCL equivalent and
 * reinterprets its data pointer as a SYCL pointer.
 *
 * This overload handles ranges using the row_major accessor.
 *
 * @param r  the range whose pointer needs to be reinterpreted
 *
 * @return `r` with appropriate types and reinterpreted to SYCL pointers
 */
template <typename T, size_type dim>
GKO_ACC_INLINE auto as_sycl_range(const range<row_major<T, dim>>& r)
{
    // The input uses the row_major accessor, so the result must use it as
    // well; returning a block_col_major range here would change the layout
    // used to index the same data.
    return range<row_major<sycl_type_t<T>, dim>>(
        r.get_accessor().lengths, as_sycl_type(r.get_accessor().data),
        r.get_accessor().stride);
}

/**
 * Backend-neutral name for as_sycl_range(): forwards the passed in range to
 * the matching as_sycl_range() overload and returns its result.
 *
 * @param acc  the range to convert
 *
 * @return the result of `as_sycl_range(acc)`
 */
template <typename AccType>
GKO_ACC_INLINE auto as_device_range(AccType&& acc)
{
return as_sycl_range(std::forward<AccType>(acc));
}


} // namespace acc
} // namespace gko


#endif // GKO_ACCESSOR_SYCL_HELPER_HPP_
2 changes: 1 addition & 1 deletion benchmark/utils/dpcpp_timer.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#include <iostream>

#include <CL/sycl.hpp>
#include <sycl/sycl.hpp>

#include "benchmark/utils/timer_impl.hpp"

Expand Down
2 changes: 1 addition & 1 deletion cmake/build_helpers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ endfunction()

# Extract the DPC++ version
function(ginkgo_extract_dpcpp_version DPCPP_COMPILER GINKGO_DPCPP_VERSION MACRO_VAR)
set(DPCPP_VERSION_PROG "#include <CL/sycl.hpp>\n#include <iostream>\n"
set(DPCPP_VERSION_PROG "#include <sycl/sycl.hpp>\n#include <iostream>\n"
"int main() {std::cout << ${MACRO_VAR} << '\\n'\;"
"return 0\;}")
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/extract_dpcpp_ver.cpp" ${DPCPP_VERSION_PROG})
Expand Down
11 changes: 5 additions & 6 deletions cmake/sycl.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,18 @@ endif()

# Provide a uniform way for those package without add_sycl_to_target
function(gko_add_sycl_to_target)
if(COMMAND add_sycl_to_target)
add_sycl_to_target(${ARGN})
return()
endif()
# We handle them by adding SYCL_FLAGS to compile and link to the target
set(one_value_args TARGET)
set(multi_value_args SOURCES)
cmake_parse_arguments(SYCL
""
"${one_value_args}"
"${multi_value_args}"
${ARGN})
if(COMMAND add_sycl_to_target)
add_sycl_to_target(${ARGN})
return()
endif()
# We handle them by adding SYCL_FLAGS to compile and link to the target
target_compile_options(${SYCL_TARGET} PRIVATE "${SYCL_FLAGS}")
target_link_options(${SYCL_TARGET} PRIVATE "${SYCL_FLAGS}")
endfunction()

14 changes: 4 additions & 10 deletions common/unified/base/kernel_launch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,14 @@ GKO_INLINE GKO_ATTRIBUTES constexpr unpack_member_type<T> unpack_member(T value)
#define GKO_KERNEL


#include "dpcpp/base/math.hpp"
#include "dpcpp/base/types.hpp"

namespace gko {
namespace kernels {
namespace dpcpp {


template <typename T>
using device_type = T;

template <typename T>
device_type<T> as_device_type(T value)
{
return value;
}


template <typename T>
using unpack_member_type = T;

Expand All @@ -97,6 +90,7 @@ GKO_INLINE GKO_ATTRIBUTES constexpr unpack_member_type<T> unpack_member(T value)
return value;
}


} // namespace dpcpp
} // namespace kernels
} // namespace gko
Expand Down
4 changes: 3 additions & 1 deletion common/unified/components/precision_conversion_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ void convert_precision(std::shared_ptr<const DefaultExecutor> exec,
{
run_kernel(
exec,
[] GKO_KERNEL(auto idx, auto in, auto out) { out[idx] = in[idx]; },
[] GKO_KERNEL(auto idx, auto in, auto out) {
out[idx] = static_cast<device_type<TargetType>>(in[idx]);
},
size, in, out);
}

Expand Down
6 changes: 4 additions & 2 deletions common/unified/matrix/dense_kernels.template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ void copy(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto input, auto output) {
output(row, col) = input(row, col);
output(row, col) =
static_cast<device_type<OutValueType>>(input(row, col));
},
input->get_size(), input, output);
}
Expand Down Expand Up @@ -425,7 +426,8 @@ void row_gather(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto orig, auto rows, auto gathered) {
gathered(row, col) = orig(rows[row], col);
gathered(row, col) =
static_cast<device_type<OutputType>>(orig(rows[row], col));
},
row_collection->get_size(), orig, row_idxs, row_collection);
}
Expand Down
19 changes: 1 addition & 18 deletions core/solver/batch_dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,6 @@ using DeviceValueType = gko::kernels::hip::hip_type<ValueType>;
#include "dpcpp/stop/batch_criteria.hpp"


namespace gko {
namespace kernels {
namespace dpcpp {


template <typename T>
inline std::decay_t<T> as_device_type(T val)
{
return val;
}


} // namespace dpcpp
} // namespace kernels
} // namespace gko


namespace gko {
namespace batch {
namespace solver {
Expand All @@ -112,7 +95,7 @@ namespace device = gko::kernels::dpcpp;


template <typename ValueType>
using DeviceValueType = ValueType;
using DeviceValueType = gko::kernels::dpcpp::sycl_type<ValueType>;


} // namespace solver
Expand Down
Loading
Loading