src: gpu: nvidia: conv: Fix int8 convolution primitive failures
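
Int8 convolutions with non-default src, wei, or dst scales failed because the s8 destination cannot hold the intermediate f32 values that scaling and the post-op chain require. The fix routes the convolution output through a dedicated f32 scratch buffer whenever scales are present and the destination data type is s8, computes the post-ops in f32, and quantizes once at the end. A minimal sketch of the new dispatch test, with simplified stand-in names (the real check is y_f32_is_required() in cudnn_convolution_impl.hpp):

#include <cudnn.h>

// Any non-default scale forces the post-op chain into f32 when the
// destination is s8; the single quantization then happens at the end
// of execute() via cudnnOpTensor.
inline bool needs_f32_dst(bool src_scale, bool wei_scale, bool dst_scale,
        cudnnDataType_t dst_type) {
    return (src_scale || wei_scale || dst_scale)
            && dst_type == CUDNN_DATA_INT8;
}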
kala855 authored and dzarukin committed May 5, 2024
1 parent bc91dbf commit a986231
Showing 3 changed files with 159 additions and 32 deletions.
11 changes: 11 additions & 0 deletions src/gpu/nvidia/cudnn_convolution.cpp
@@ -61,6 +61,16 @@ status_t cudnn_convolution_fwd_t::execute_convolution(
::sycl::access::mode::read_write>(temp_reorder_mem, cgh);
}

impl::sycl::sycl_memory_arg_t<::sycl::access::mode::read_write>
y_fp32_data;

if (!arg_dst_scale.empty() || !arg_src_scale.empty()
|| !arg_wei_scale.empty()) {
memory_storage_t *y_fp32_data_mem = scratch_storage_3.get();
y_fp32_data = impl::sycl::sycl_memory_arg_t<
::sycl::access::mode::read_write>(y_fp32_data_mem, cgh);
}

compat::host_task(cgh, [=, this](const compat::interop_handle &ih) {
auto &sycl_engine = *utils::downcast<sycl_cuda_engine_t *>(
cuda_stream->engine());
@@ -79,6 +89,7 @@ status_t cudnn_convolution_fwd_t::execute_convolution(
args.push_back(arg_src_scale.get_native_pointer(ih));
args.push_back(arg_wei_scale.get_native_pointer(ih));
args.push_back(arg_dst_scale.get_native_pointer(ih));
args.push_back(y_fp32_data.get_native_pointer(ih));

pd()->impl_->execute(handle, args);
});
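Note on the wiring above: the f32 buffer travels to the interop task as one more SYCL memory argument and is appended to the kernel argument list, where the implementation reads it back as args[11] (see execute() in cudnn_convolution_impl.hpp). A condensed sketch of the assumed argument order, with plain pointers standing in for the SYCL interop types:

#include <vector>

// Hypothetical stand-in for the host_task body: native pointers are
// pushed in a fixed order, and the new f32 destination rides at the tail.
inline void append_scale_args(std::vector<void *> &args, void *src_scale,
        void *wei_scale, void *dst_scale, void *y_fp32) {
    args.push_back(src_scale); // args[8]  (assumed index)
    args.push_back(wei_scale); // args[9]  (assumed index)
    args.push_back(dst_scale); // args[10] (assumed index)
    args.push_back(y_fp32);    // args[11]; unset when no scaling is requested
}
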
38 changes: 27 additions & 11 deletions src/gpu/nvidia/cudnn_convolution.hpp
@@ -98,14 +98,16 @@ struct cudnn_convolution_fwd_t : public primitive_t {

if (check_for_zero_dims()) return status::success;

const bool use_scales_dst = !attr()->scales_.has_default_values()
&& dst_md_.data_type == s8;
const bool use_temp_dst = attr()->post_ops_.len() > 0;
if (use_temp_dst) {
if (use_temp_dst || use_scales_dst) {
dst_md_temp_ = dst_md_;
if (dst_md_.data_type == s8) { dst_md_temp_.data_type = f32; }
}

impl_.reset(new cudnn_convolution_impl_fwd_t());
return impl_->init(engine, this, use_temp_dst);
return impl_->init(engine, this, use_temp_dst, use_scales_dst);
}
bool with_scratchpad() const { return impl_->with_scratchpad(); }
std::shared_ptr<cudnn_convolution_impl_base_t> impl_;
@@ -116,6 +118,11 @@ struct cudnn_convolution_fwd_t : public primitive_t {
return false;
}

bool use_scales_dst() const {
if (impl_.get()) return impl_->use_scales_dst();
return false;
}

private:
bool set_default_formats() {
using namespace format_tag;
@@ -161,23 +168,31 @@ struct cudnn_convolution_fwd_t : public primitive_t {
};

status_t init_temp_dst(engine_t *engine) {
const auto impl = pd()->impl_.get();
auto sycl_engine = utils::downcast<sycl_cuda_engine_t *>(engine);
memory_storage_t *scratch_ptr = nullptr;
auto wrap = memory_desc_wrapper(pd()->dst_md_temp_);
CHECK(sycl_engine->create_memory_storage(
&scratch_ptr, memory_flags_t::alloc, wrap.size(), nullptr));
scratch_storage.reset(scratch_ptr);

CHECK(sycl_engine->create_memory_storage(
&scratch_ptr, memory_flags_t::alloc, wrap.size(), nullptr));
scratch_storage_2.reset(scratch_ptr);
if (impl && impl->use_temp_dst()) {
CHECK(sycl_engine->create_memory_storage(
&scratch_ptr, memory_flags_t::alloc, wrap.size(), nullptr));
scratch_storage.reset(scratch_ptr);

CHECK(sycl_engine->create_memory_storage(
&scratch_ptr, memory_flags_t::alloc, wrap.size(), nullptr));
scratch_storage_2.reset(scratch_ptr);
}
if (impl && impl->use_scales_dst()) {
CHECK(sycl_engine->create_memory_storage(
&scratch_ptr, memory_flags_t::alloc, wrap.size(), nullptr));
scratch_storage_3.reset(scratch_ptr);
}

return status::success;
}

virtual status_t init(engine_t *engine) override {
const auto impl = pd()->impl_.get();
if (impl && impl->use_temp_dst()) { init_temp_dst(engine); }
init_temp_dst(engine);

return status::success;
}

@@ -200,6 +215,7 @@ struct cudnn_convolution_fwd_t : public primitive_t {
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
std::shared_ptr<memory_storage_t> scratch_storage;
std::shared_ptr<memory_storage_t> scratch_storage_2;
std::shared_ptr<memory_storage_t> scratch_storage_3;
};

struct cudnn_convolution_bwd_data_t : public primitive_t {
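The header now keeps three scratch storages: the existing pair that backs the post-op ping-pong buffers (allocated only when use_temp_dst() holds) and a new third one for the f32 destination (allocated only when use_scales_dst() holds). A minimal sketch of that allocation policy, with a simplified storage type standing in for oneDNN's memory_storage_t:

#include <cstddef>
#include <memory>

struct storage_stub {
    explicit storage_stub(std::size_t bytes) : size(bytes) {}
    std::size_t size;
};

inline void init_temp_dst_sketch(std::size_t dst_bytes, bool use_temp_dst,
        bool use_scales_dst, std::shared_ptr<storage_stub> &scratch_1,
        std::shared_ptr<storage_stub> &scratch_2,
        std::shared_ptr<storage_stub> &scratch_3) {
    if (use_temp_dst) {
        scratch_1 = std::make_shared<storage_stub>(dst_bytes); // post-op src
        scratch_2 = std::make_shared<storage_stub>(dst_bytes); // post-op dst
    }
    if (use_scales_dst) {
        scratch_3 = std::make_shared<storage_stub>(dst_bytes); // f32 dst
    }
}
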
142 changes: 121 additions & 21 deletions src/gpu/nvidia/cudnn_convolution_impl.hpp
@@ -65,7 +65,15 @@ struct cudnn_convolution_impl_base_t
bool with_bias = false;

bool do_scaling = false;
bool do_dst_scaling = false;
// When we apply scaling to the src and wei,
// the post-ops need to be computed in f32 and
// then quantized using a default dst scale of 1.0f
bool do_src_scaling = false;
bool do_wei_scaling = false;
bool use_temp_dst_ = false;
bool use_scales_dst_ = false;
cudnnDataType_t computation_data_type = CUDNN_DATA_FLOAT;
cudnnDataType_t reorder_type = CUDNN_DATA_INT8;

@@ -97,7 +105,7 @@ struct cudnn_convolution_impl_base_t
bool with_scratchpad() const { return scratchpad_size > 0; }

virtual status_t init(engine_t *engine, convolution_pd_t *pd,
bool use_scratch_dst = false) {
bool use_scratch_dst = false, bool use_scales_dst = false) {
CHECK(configure_parameters(pd));
CHECK(create_cudnn_descs(pd));
CHECK(check_output_dims());
@@ -134,6 +142,13 @@ struct cudnn_convolution_impl_base_t
with_bias = pd->with_bias();
beta = 0.0f;
do_scaling = !pd->attr()->scales_.has_default_values();
do_dst_scaling
= !pd->attr()->scales_.get(DNNL_ARG_DST).has_default_values();
do_src_scaling
= !pd->attr()->scales_.get(DNNL_ARG_SRC).has_default_values();
do_wei_scaling = !pd->attr()
->scales_.get(DNNL_ARG_WEIGHTS)
.has_default_values();
dnnl_descs[x] = *pd->invariant_src_md();
dnnl_descs[weights] = *pd->invariant_wei_md();
dnnl_descs[y] = *pd->invariant_dst_md();
@@ -378,13 +393,17 @@ struct cudnn_convolution_impl_base_t
}

bool use_temp_dst() const { return use_temp_dst_; }

bool use_scales_dst() const { return use_scales_dst_; }
};

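Taken together, the new flags make the effective computation for an s8 destination (a sketch of the intended semantics, where each scale defaults to 1.0f when the corresponding attribute is absent):

    y_s8 = saturate((1 / s_dst) * postops(conv(s_src * x, s_wei * w) + bias))

so only a single f32-to-s8 conversion happens, after all post-ops have run.
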
struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
protected:
cudnnActivationDescriptor_t activation_desc = nullptr;
cudnnActivationDescriptor_t eltwise_desc = nullptr;
cudnnTensorDescriptor_t reorder_dst_desc = nullptr;
cudnnTensorDescriptor_t y_fp32_desc = nullptr;
cudnnOpTensorDescriptor_t op_tensor_desc = nullptr;
cudnnConvolutionFwdAlgo_t fwd_alg_kind;
std::vector<cudnnConvolutionFwdAlgoPerf_t> perf;
int requested_algo_count = 0;
@@ -407,6 +426,11 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
if (reorder_dst_desc)
CUDNN_EXECUTE_FUNC_V(
cudnnDestroyTensorDescriptor, reorder_dst_desc);
if (y_fp32_desc)
CUDNN_EXECUTE_FUNC_V(cudnnDestroyTensorDescriptor, y_fp32_desc);
if (op_tensor_desc)
CUDNN_EXECUTE_FUNC_V(
cudnnDestroyOpTensorDescriptor, op_tensor_desc);
}

status_t configure_post_ops(convolution_pd_t *pd) {
@@ -440,19 +464,36 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
// If the only post-op is fused then there is no need for temp dst
if (conv_bias_eltwise && num_post_ops == 1) use_temp_dst_ = false;

if (data_types[y] == CUDNN_DATA_INT8 && use_temp_dst_) {
// We need to take into account whether we are
// scaling the src, wei, or dst, in which case we
// need to compute the post-ops in f32 and then
// quantize using a 1.0f scaling factor for dst
if (data_types[y] == CUDNN_DATA_INT8 && use_temp_dst_
&& !(do_dst_scaling || do_src_scaling || do_wei_scaling)) {
data_types[y] = CUDNN_DATA_FLOAT;
need_reorder = true;
CHECK(create_and_set_tensor_descriptor_ex(&reorder_dst_desc,
formats[y], reorder_type, ndims[y], dims[y]));
}

// If scales need to be applied and the dst data type is s8
if (y_f32_is_required()) {
CUDNN_EXECUTE_FUNC_V(
cudnnCreateOpTensorDescriptor, &op_tensor_desc);
cudnnOpTensorOp_t opTensorOp = CUDNN_OP_TENSOR_ADD;
cudnnDataType_t opTensorCompType = CUDNN_DATA_FLOAT;
cudnnNanPropagation_t opTensorNanOpt = CUDNN_NOT_PROPAGATE_NAN;
CUDNN_EXECUTE_FUNC_S(cudnnSetOpTensorDescriptor, op_tensor_desc,
opTensorOp, opTensorCompType, opTensorNanOpt);
}

return status::success;
}

status_t init(engine_t *engine, convolution_pd_t *pd,
bool use_scratch_dst) override {
status_t init(engine_t *engine, convolution_pd_t *pd, bool use_scratch_dst,
bool use_scales_dst) override {
use_temp_dst_ = use_scratch_dst;
use_scales_dst_ = use_scales_dst;
CHECK(configure_parameters(pd));
CHECK(create_cudnn_descs(pd));
CHECK(configure_alg_kind(engine, pd));
@@ -475,13 +516,35 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
}
}

void execute_f32_sum(cudnnHandle_t handle, void *y, void *y_fp32_data,
float alpha_, float beta_) const {
float alpha1 = 0.0f;
float alpha2 = alpha_;
float beta = beta_;
CUDNN_EXECUTE_FUNC(cudnnOpTensor, handle, op_tensor_desc, &alpha1,
descs[io::y], y, &alpha2, descs[io::y], y, &beta, y_fp32_desc,
y_fp32_data);
}

void execute_eltwise(cudnnHandle_t handle, void *src, void *dst) const {
float alpha = 1.0f;
float beta = 0.0f;
CUDNN_EXECUTE_FUNC_V(cudnnActivationForward, handle, eltwise_desc,
&alpha, descs[io::y], src, &beta, descs[io::y], dst);
}

void execute_f32_eltwise(cudnnHandle_t handle, void *src, void *dst) const {
float alpha = 1.0f;
float beta = 0.0f;
CUDNN_EXECUTE_FUNC_V(cudnnActivationForward, handle, eltwise_desc,
&alpha, y_fp32_desc, src, &beta, y_fp32_desc, dst);
}

bool y_f32_is_required() const {
return ((do_src_scaling || do_dst_scaling || do_wei_scaling)
&& data_types[io::y] == CUDNN_DATA_INT8);
}

void execute(cudnnHandle_t handle,
const std::vector<void *> &args) const override {
auto x = args[0], weights = args[1], y = args[2], bias = args[3],
@@ -495,6 +558,9 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
weights = w_scratch;
}

float *y_fp32_data = nullptr;
if (y_f32_is_required()) { y_fp32_data = (float *)args[11]; }

bool fused = conv_bias || conv_bias_eltwise;

float scale = 1.0f;
@@ -513,13 +579,16 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
}
}

auto &y_desc = y_f32_is_required() ? y_fp32_desc : descs[io::y];
void *y_data = y_f32_is_required() ? y_fp32_data : output;

if (fused) {
auto err = cudnnConvolutionBiasActivationForward(handle, &scale,
descs[io::x], x, weights_desc, weights, conv_desc,
fwd_alg_kind, scratchpad, scratchpad_size, &beta,
descs[io::y], output, descs[io::bias], bias,
conv_bias_eltwise ? eltwise_desc : activation_desc,
descs[io::y], output);
conv_bias_eltwise ? eltwise_desc : activation_desc, y_desc,
y_data);
// try to fallback into standalone convolution
if (err == CUDNN_STATUS_NOT_SUPPORTED) {
fused = false;
@@ -533,14 +602,14 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
const float bias_beta = 1.0f;
CUDNN_EXECUTE_FUNC_V(cudnnConvolutionForward, handle, &scale,
descs[io::x], x, weights_desc, weights, conv_desc,
fwd_alg_kind, scratchpad, scratchpad_size, &beta,
descs[io::y], output);
fwd_alg_kind, scratchpad, scratchpad_size, &beta, y_desc,
y_data);
if (with_bias) {
CUDNN_EXECUTE_FUNC_V(cudnnAddTensor, handle, &bias_alpha,
descs[io::bias], bias, &bias_beta, descs[io::y],
output);
descs[io::bias], bias, &bias_beta, y_desc, y_data);
}
}

// skip first eltwise in case it is fused into convolution
const int post_ops_start_pos = fused && conv_bias_eltwise;
for (int i = post_ops_start_pos; i < num_post_ops; i++) {
@@ -552,8 +621,14 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
execute_sum(handle, post_op_reorder, post_op_scratch,
sum_scale, 1.0f);
} else if (last_op) {
execute_sum(
handle, post_op_scratch, y, 1.0f, sum_scale);
if (y_f32_is_required()) {
execute_f32_sum(
handle, y, y_fp32_data, 1.0f, sum_scale);
} else {
execute_sum(handle, post_op_scratch, y, 1.0f,
sum_scale);
}

} else {
execute_sum(
handle, y, post_op_scratch, sum_scale, 1.0f);
@@ -563,7 +638,12 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {

case dnnl_eltwise:
if (last_op) {
execute_eltwise(handle, output, y);
if (y_f32_is_required()) {
execute_f32_eltwise(
handle, y_fp32_data, y_fp32_data);
} else {
execute_eltwise(handle, output, y);
}
} else {
execute_eltwise(handle, output, post_op_scratch);
}
@@ -576,13 +656,21 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
execute_reorder(handle, post_op_scratch, y, false);
}

if (dst_scale) {
if (dst_scale || src_scale || wei_scale) {
float host_dst_scale = 1.0f;
CUDA_EXECUTE_FUNC(cuMemcpy, (CUdeviceptr)&host_dst_scale,
(CUdeviceptr)dst_scale, sizeof(float));
if (dst_scale)
CUDA_EXECUTE_FUNC(cuMemcpy, (CUdeviceptr)&host_dst_scale,
(CUdeviceptr)dst_scale, sizeof(float));
float inv_scale = 1.0f / host_dst_scale;
CUDNN_EXECUTE_FUNC(
cudnnScaleTensor, handle, descs[io::y], y, &inv_scale);
if (data_types[io::y] == CUDNN_DATA_INT8) {
float alpha_beta = 0.0f;
CUDNN_EXECUTE_FUNC(cudnnOpTensor, handle, op_tensor_desc,
&inv_scale, y_fp32_desc, y_fp32_data, &alpha_beta,
y_fp32_desc, y_fp32_data, &alpha_beta, descs[io::y], y);
} else {
CUDNN_EXECUTE_FUNC(
cudnnScaleTensor, handle, descs[io::y], y, &inv_scale);
}
}
}
status_t init_scratchpad(engine_t *engine, convolution_pd_t *pd) override {
@@ -594,9 +682,21 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
= utils::downcast<sycl_cuda_stream_t *>(service_stream);
auto handle = cuda_stream->get_cudnn_handle();

CHECK(CUDNN_EXECUTE_FUNC_S(cudnnGetConvolutionForwardWorkspaceSize,
handle, descs[x], weights_desc, conv_desc, descs[y],
fwd_alg_kind, &scratchpad_size));
// The scratchpad size needs to be queried against
// the f32 dst descriptor when scaling is used and
// the output uses s8 values.
if (use_scales_dst_) {
CHECK(create_and_set_tensor_descriptor(&y_fp32_desc,
CUDNN_DATA_FLOAT, ndims[y], dims[y], strides[y]));
CHECK(CUDNN_EXECUTE_FUNC_S(cudnnGetConvolutionForwardWorkspaceSize,
handle, descs[x], weights_desc, conv_desc, y_fp32_desc,
fwd_alg_kind, &scratchpad_size));
} else {
CHECK(CUDNN_EXECUTE_FUNC_S(cudnnGetConvolutionForwardWorkspaceSize,
handle, descs[x], weights_desc, conv_desc, descs[y],
fwd_alg_kind, &scratchpad_size));
}

if (scratchpad_size > 0)
pd->scratchpad_registry().registrar().book(
memory_tracking::names::key_conv_cudnn_algo,
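Two details in this last hunk are worth calling out. First, when the f32 path is active, the forward workspace has to be sized against the f32 dst descriptor rather than the s8 one, since that is the descriptor the convolution actually runs with. Second, the final quantization reuses cudnnOpTensor, which with CUDNN_OP_TENSOR_ADD computes C = alpha1*A + alpha2*B + beta*C; passing 1/dst_scale as alpha1 and zeros elsewhere converts the f32 buffer into the s8 dst in one pass. A sketch of that call, assuming the descriptors are created and set as in the diff above:

#include <cudnn.h>

inline cudnnStatus_t quantize_dst_sketch(cudnnHandle_t handle,
        cudnnOpTensorDescriptor_t op_desc,
        cudnnTensorDescriptor_t y_f32_desc, const void *y_f32,
        cudnnTensorDescriptor_t y_s8_desc, void *y_s8, float dst_scale) {
    float inv_scale = 1.0f / dst_scale; // dst_scale defaults to 1.0f
    float zero = 0.0f;
    // C = inv_scale * A + 0 * B + 0 * C, with C bound to the s8 tensor:
    return cudnnOpTensor(handle, op_desc, &inv_scale, y_f32_desc, y_f32,
            &zero, y_f32_desc, y_f32, &zero, y_s8_desc, y_s8);
}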
