Skip to content

Commit

Permalink
sycl half
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Nov 20, 2024
1 parent 55daa5c commit 92dc246
Show file tree
Hide file tree
Showing 51 changed files with 603 additions and 261 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ option(GINKGO_TEST_NONDEFAULT_STREAM "Uses non-default streams in CUDA and HIP t
option(GINKGO_MIXED_PRECISION "Instantiate true mixed-precision kernels (otherwise they will be conversion-based using implicit temporary storage)" OFF)
option(GINKGO_ENABLE_HALF "Enable the use of half precision" ON)
# We do not support MSVC. SYCL will come later
if(MSVC OR GINKGO_BUILD_SYCL)
if(MSVC)
message(STATUS "HALF is not supported in MSVC, and later support in SYCL")
set(GINKGO_ENABLE_HALF OFF CACHE BOOL "Enable the use of half precision" FORCE)
endif()
Expand Down
12 changes: 6 additions & 6 deletions accessor/sycl_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
#include "utils.hpp"


namespace sycl {
inline namespace _V1 {
// namespace sycl {
// inline namespace _V1 {


class half;
// class half;


}
} // namespace sycl
// }
// } // namespace sycl


namespace gko {
Expand Down Expand Up @@ -181,7 +181,7 @@ GKO_ACC_INLINE auto as_sycl_range(const range<row_major<T, dim>>& r)
template <typename AccType>
GKO_ACC_INLINE auto as_device_range(AccType&& acc)
{
return as_device_range(std::forward<AccType>(acc));
return as_sycl_range(std::forward<AccType>(acc));
}


Expand Down
4 changes: 3 additions & 1 deletion common/unified/base/kernel_launch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ GKO_INLINE GKO_ATTRIBUTES constexpr unpack_member_type<T> unpack_member(T value)
#define GKO_KERNEL


#include "dpcpp/base/math.hpp"
#include "dpcpp/base/types.hpp"

namespace gko {
namespace kernels {
namespace dpcpp {
#include "dpcpp/base/types.hpp"


template <typename T>
Expand Down
4 changes: 3 additions & 1 deletion common/unified/components/precision_conversion_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ void convert_precision(std::shared_ptr<const DefaultExecutor> exec,
{
run_kernel(
exec,
[] GKO_KERNEL(auto idx, auto in, auto out) { out[idx] = in[idx]; },
[] GKO_KERNEL(auto idx, auto in, auto out) {
out[idx] = static_cast<device_type<TargetType>>(in[idx]);
},
size, in, out);
}

Expand Down
6 changes: 4 additions & 2 deletions common/unified/matrix/dense_kernels.template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ void copy(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto input, auto output) {
output(row, col) = input(row, col);
output(row, col) =
static_cast<device_type<OutValueType>>(input(row, col));
},
input->get_size(), input, output);
}
Expand Down Expand Up @@ -425,7 +426,8 @@ void row_gather(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto orig, auto rows, auto gathered) {
gathered(row, col) = orig(rows[row], col);
gathered(row, col) =
static_cast<device_type<OutputType>>(orig(rows[row], col));
},
row_collection->get_size(), orig, row_idxs, row_collection);
}
Expand Down
19 changes: 1 addition & 18 deletions core/solver/batch_dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,23 +85,6 @@ using DeviceValueType = gko::kernels::hip::hip_type<ValueType>;
#include "dpcpp/stop/batch_criteria.hpp"


namespace gko {
namespace kernels {
namespace dpcpp {


template <typename T>
inline std::decay_t<T> as_device_type(T val)
{
return val;
}


} // namespace dpcpp
} // namespace kernels
} // namespace gko


namespace gko {
namespace batch {
namespace solver {
Expand All @@ -111,7 +94,7 @@ namespace device = gko::kernels::dpcpp;


template <typename ValueType>
using DeviceValueType = ValueType;
using DeviceValueType = gko::kernels::dpcpp::sycl_type<ValueType>;


} // namespace solver
Expand Down
2 changes: 2 additions & 0 deletions dpcpp/base/batch_multi_vector_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "dpcpp/base/dim3.dp.hpp"
#include "dpcpp/base/dpct.hpp"
#include "dpcpp/base/helper.hpp"
#include "dpcpp/base/math.hpp"
#include "dpcpp/base/types.hpp"
#include "dpcpp/components/cooperative_groups.dp.hpp"
#include "dpcpp/components/intrinsics.dp.hpp"
#include "dpcpp/components/reduction.dp.hpp"
Expand Down
13 changes: 7 additions & 6 deletions dpcpp/base/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "core/base/batch_struct.hpp"
#include "dpcpp/base/config.hpp"
#include "dpcpp/base/types.hpp"


namespace gko {
Expand All @@ -32,10 +33,10 @@ namespace dpcpp {
* Generates an immutable uniform batch struct from a batch of multi-vectors.
*/
template <typename ValueType>
inline batch::multi_vector::uniform_batch<const ValueType> get_batch_struct(
const batch::MultiVector<ValueType>* const op)
inline batch::multi_vector::uniform_batch<const device_type<ValueType>>
get_batch_struct(const batch::MultiVector<ValueType>* const op)
{
return {op->get_const_values(), op->get_num_batch_items(),
return {as_device_type(op->get_const_values()), op->get_num_batch_items(),
static_cast<int32>(op->get_common_size()[1]),
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[1])};
Expand All @@ -46,10 +47,10 @@ inline batch::multi_vector::uniform_batch<const ValueType> get_batch_struct(
* Generates a uniform batch struct from a batch of multi-vectors.
*/
template <typename ValueType>
inline batch::multi_vector::uniform_batch<ValueType> get_batch_struct(
batch::MultiVector<ValueType>* const op)
inline batch::multi_vector::uniform_batch<device_type<ValueType>>
get_batch_struct(batch::MultiVector<ValueType>* const op)
{
return {op->get_values(), op->get_num_batch_items(),
return {as_device_type(op->get_values()), op->get_num_batch_items(),
static_cast<int32>(op->get_common_size()[1]),
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[1])};
Expand Down
2 changes: 1 addition & 1 deletion dpcpp/base/device_matrix_data_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ void sum_duplicates(std::shared_ptr<const DefaultExecutor> exec, size_type,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_DEVICE_MATRIX_DATA_SUM_DUPLICATES_KERNEL);


Expand Down
1 change: 1 addition & 0 deletions dpcpp/base/kernel_launch_reduction.dp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "core/synthesizer/implementation_selection.hpp"
#include "dpcpp/base/config.hpp"
#include "dpcpp/base/dim3.dp.hpp"
#include "dpcpp/base/types.hpp"
#include "dpcpp/components/cooperative_groups.dp.hpp"
#include "dpcpp/components/reduction.dp.hpp"
#include "dpcpp/components/thread_ids.dp.hpp"
Expand Down
Loading

0 comments on commit 92dc246

Please sign in to comment.