Skip to content

Commit

Permalink
fix integer dot product support for device without int8 support (#732)
Browse files Browse the repository at this point in the history
  • Loading branch information
rjodinchr authored Nov 5, 2024
1 parent 56d99b3 commit 1f76af2
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions src/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -688,16 +688,19 @@ void cvk_device::build_extension_ils_list() {
}

if (supports_dot_product()) {
m_extensions.push_back(
MAKE_NAME_VERSION(1, 0, 0, "cl_arm_integer_dot_product_int8"));
m_extensions.push_back(MAKE_NAME_VERSION(
1, 0, 0, "cl_arm_integer_dot_product_accumulate_int8"));
if (supports_int8()) {
m_extensions.push_back(MAKE_NAME_VERSION(
2, 0, 0, CL_KHR_INTEGER_DOT_PRODUCT_EXTENSION_NAME));
m_extensions.push_back(
MAKE_NAME_VERSION(1, 0, 0, "cl_arm_integer_dot_product_int8"));
m_extensions.push_back(MAKE_NAME_VERSION(
1, 0, 0, "cl_arm_integer_dot_product_accumulate_int8"));
m_extensions.push_back(MAKE_NAME_VERSION(
1, 0, 0,
"cl_arm_integer_dot_product_accumulate_saturate_int8"));
}
m_extensions.push_back(MAKE_NAME_VERSION(
1, 0, 0, "cl_arm_integer_dot_product_accumulate_int16"));
m_extensions.push_back(MAKE_NAME_VERSION(
1, 0, 0, "cl_arm_integer_dot_product_accumulate_saturate_int8"));
m_extensions.push_back(MAKE_NAME_VERSION(
2, 0, 0, CL_KHR_INTEGER_DOT_PRODUCT_EXTENSION_NAME));
}

auto split_string = [](std::string input, char delimiter) {
Expand Down Expand Up @@ -797,8 +800,10 @@ void cvk_device::build_extension_ils_list() {
MAKE_NAME_VERSION(3, 0, 0, "__opencl_c_fp64"));
}
if (supports_dot_product()) {
m_opencl_c_features.push_back(MAKE_NAME_VERSION(
3, 0, 0, "__opencl_c_integer_dot_product_input_4x8bit"));
if (supports_int8()) {
m_opencl_c_features.push_back(MAKE_NAME_VERSION(
3, 0, 0, "__opencl_c_integer_dot_product_input_4x8bit"));
}
m_opencl_c_features.push_back(MAKE_NAME_VERSION(
3, 0, 0, "__opencl_c_integer_dot_product_input_4x8bit_packed"));
}
Expand Down Expand Up @@ -1230,10 +1235,11 @@ bool cvk_device::supports_capability(spv::Capability capability) const {
m_float_controls_properties.shaderRoundingModeRTEFloat16 ||
m_float_controls_properties.shaderRoundingModeRTEFloat64;
case spv::CapabilityDotProduct:
case spv::CapabilityDotProductInput4x8Bit:
case spv::CapabilityDotProductInput4x8BitPacked:
case spv::CapabilityDotProductInputAll:
return supports_dot_product();
case spv::CapabilityDotProductInput4x8Bit:
case spv::CapabilityDotProductInputAll:
return supports_dot_product() && supports_int8();
// Capabilities that have not yet been mapped to Vulkan features:
default:
cvk_warn_fn("Capability %d not yet mapped to a feature.", capability);
Expand Down

0 comments on commit 1f76af2

Please sign in to comment.