Skip to content

Commit

Permalink
generic: conv: deconv: reduce kernel argument size
Browse files Browse the repository at this point in the history
  • Loading branch information
sgeor255 committed Oct 24, 2024
1 parent 2bf5ffc commit 729ece5
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 38 deletions.
30 changes: 13 additions & 17 deletions src/gpu/generic/sycl/convolution_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace sycl {
struct convolution_kernel_fwd_t {
static constexpr int max_supported_ndims = 6;

convolution_kernel_fwd_t(const sycl_convolution_conf_t &conf,
convolution_kernel_fwd_t(const sycl_convolution_fwd_conf_t &conf,
::sycl::handler &cgh, const exec_ctx_t &ctx)
: conf_(conf)
, data_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0))
Expand Down Expand Up @@ -191,9 +191,8 @@ struct convolution_kernel_fwd_t {
accumulator *= sm_weights;
}

if (bias_md().ndims() != 0) {
auto bias = load_float_value(
bias_md().data_type(), bias_ptr(), oc_tot);
if (conf_.has_bias) {
auto bias = load_float_value(conf_.bias_dt, bias_ptr(), oc_tot);
accumulator += bias;
}

Expand All @@ -214,7 +213,6 @@ struct convolution_kernel_fwd_t {
private:
const xpu::sycl::md_t &data_md() const { return conf_.data_md; }
const xpu::sycl::md_t &weights_md() const { return conf_.weights_md; }
const xpu::sycl::md_t &bias_md() const { return conf_.bias_md; }
const xpu::sycl::md_t &dst_md() const { return conf_.dst_md; }

void *data_ptr() const { return data_.get_pointer(); }
Expand All @@ -227,7 +225,7 @@ struct convolution_kernel_fwd_t {
void *data_zeropoint_ptr() const { return data_zeropoints_.get_pointer(); }
void *dst_zeropoint_ptr() const { return dst_zeropoints_.get_pointer(); }

sycl_convolution_conf_t conf_;
sycl_convolution_fwd_conf_t conf_;

xpu::sycl::in_memory_arg_t data_;
xpu::sycl::in_memory_arg_t weights_;
Expand All @@ -247,7 +245,7 @@ struct convolution_kernel_fwd_t {
struct convolution_kernel_bwd_data_t {
static constexpr int max_supported_ndims = 6;

convolution_kernel_bwd_data_t(const sycl_convolution_conf_t &conf,
convolution_kernel_bwd_data_t(const sycl_convolution_bwd_data_conf_t &conf,
::sycl::handler &cgh, const exec_ctx_t &ctx)
: conf_(conf)
, diff_data_(CTX_INOUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DIFF_SRC))
Expand Down Expand Up @@ -423,9 +421,8 @@ struct convolution_kernel_bwd_data_t {
accumulator *= sm_weights;
}

if (bias_md().ndims() != 0) {
auto bias = load_float_value(
bias_md().data_type(), bias_ptr(), ic_tot);
if (conf_.has_bias) {
auto bias = load_float_value(conf_.bias_dt, bias_ptr(), ic_tot);
accumulator += bias;
}

Expand All @@ -446,7 +443,6 @@ struct convolution_kernel_bwd_data_t {
private:
const xpu::sycl::md_t &diff_data_md() const { return conf_.diff_data_md; }
const xpu::sycl::md_t &weights_md() const { return conf_.weights_md; }
const xpu::sycl::md_t &bias_md() const { return conf_.bias_md; }
const xpu::sycl::md_t &diff_dst_md() const { return conf_.diff_dst_md; }

void *diff_data_ptr() const { return diff_data_.get_pointer(); }
Expand All @@ -459,7 +455,7 @@ struct convolution_kernel_bwd_data_t {
void *data_zeropoint_ptr() const { return data_zeropoints_.get_pointer(); }
void *dst_zeropoint_ptr() const { return dst_zeropoints_.get_pointer(); }

sycl_convolution_conf_t conf_;
sycl_convolution_bwd_data_conf_t conf_;

xpu::sycl::inout_memory_arg_t diff_data_;
xpu::sycl::in_memory_arg_t weights_;
Expand All @@ -479,7 +475,8 @@ struct convolution_kernel_bwd_data_t {
struct convolution_kernel_bwd_weights_t {
static constexpr int max_supported_ndims = 6;

convolution_kernel_bwd_weights_t(const sycl_convolution_conf_t &conf,
convolution_kernel_bwd_weights_t(
const sycl_convolution_bwd_weights_conf_t &conf,
::sycl::handler &cgh, const exec_ctx_t &ctx, int data_arg,
int diff_dst_arg)
: conf_(conf)
Expand Down Expand Up @@ -572,8 +569,8 @@ struct convolution_kernel_bwd_weights_t {
}
}
}
store_float_value(diff_bias_md().data_type(),
accumulator_bias, diff_bias_ptr(), g * OC + oc);
store_float_value(conf_.bias_dt, accumulator_bias,
diff_bias_ptr(), g * OC + oc);
}
};
if (conf_.is_deconvolution) {
Expand Down Expand Up @@ -624,15 +621,14 @@ struct convolution_kernel_bwd_weights_t {
const xpu::sycl::md_t &diff_weights_md() const {
return conf_.diff_weights_md;
}
const xpu::sycl::md_t &diff_bias_md() const { return conf_.diff_bias_md; }
const xpu::sycl::md_t &diff_dst_md() const { return conf_.diff_dst_md; }

void *data_ptr() const { return data_.get_pointer(); }
void *diff_weights_ptr() const { return diff_weights_.get_pointer(); }
void *diff_bias_ptr() const { return diff_bias_.get_pointer(); }
void *diff_dst_ptr() const { return diff_dst_.get_pointer(); }

sycl_convolution_conf_t conf_;
sycl_convolution_bwd_weights_conf_t conf_;

xpu::sycl::in_memory_arg_t data_;
xpu::sycl::out_memory_arg_t diff_weights_;
Expand Down
19 changes: 13 additions & 6 deletions src/gpu/generic/sycl/ref_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@ namespace generic {
namespace sycl {

status_t ref_convolution_fwd_t::pd_t::init_conf() {
conf_ = sycl_convolution_conf_t();
conf_ = sycl_convolution_fwd_conf_t();

conf_.data_md = xpu::sycl::md_t(src_md());
conf_.weights_md = xpu::sycl::md_t(weights_md(0));
if (with_bias()) { conf_.bias_md = xpu::sycl::md_t(weights_md(1)); }
if (with_bias()) {
conf_.bias_dt = weights_md(1)->data_type;
conf_.has_bias = true;
}
conf_.dst_md = xpu::sycl::md_t(dst_md());
conf_.ndims = ndims();

Expand Down Expand Up @@ -85,11 +88,14 @@ status_t ref_convolution_fwd_t::execute(const exec_ctx_t &ctx) const {
}

status_t ref_convolution_bwd_data_t::pd_t::init_conf() {
conf_ = sycl_convolution_conf_t();
conf_ = sycl_convolution_bwd_data_conf_t();

conf_.diff_data_md = xpu::sycl::md_t(diff_src_md());
conf_.weights_md = xpu::sycl::md_t(weights_md(0));
if (with_bias()) { conf_.bias_md = xpu::sycl::md_t(weights_md(1)); }
if (with_bias()) {
conf_.bias_dt = weights_md(1)->data_type;
conf_.has_bias = true;
}
conf_.diff_dst_md = xpu::sycl::md_t(diff_dst_md());
conf_.ndims = ndims();

Expand Down Expand Up @@ -145,12 +151,13 @@ status_t ref_convolution_bwd_data_t::execute(const exec_ctx_t &ctx) const {
}

status_t ref_convolution_bwd_weights_t::pd_t::init_conf() {
conf_ = sycl_convolution_conf_t();
conf_ = sycl_convolution_bwd_weights_conf_t();

conf_.data_md = xpu::sycl::md_t(src_md());
conf_.diff_weights_md = xpu::sycl::md_t(diff_weights_md(0));
if (with_bias()) {
conf_.diff_bias_md = xpu::sycl::md_t(diff_weights_md(1));
conf_.bias_dt = diff_weights_md(1)->data_type;
conf_.has_bias = true;
}
conf_.diff_dst_md = xpu::sycl::md_t(diff_dst_md());
conf_.ndims = ndims();
Expand Down
6 changes: 3 additions & 3 deletions src/gpu/generic/sycl/ref_convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ struct ref_convolution_fwd_t : public gpu::generic::sycl::primitive_t {
return init_conf();
}

sycl_convolution_conf_t conf_;
sycl_convolution_fwd_conf_t conf_;

private:
status_t init_conf();
Expand Down Expand Up @@ -164,7 +164,7 @@ struct ref_convolution_bwd_data_t : public gpu::generic::sycl::primitive_t {
return init_conf();
}

sycl_convolution_conf_t conf_;
sycl_convolution_bwd_data_conf_t conf_;

private:
status_t init_conf();
Expand Down Expand Up @@ -216,7 +216,7 @@ struct ref_convolution_bwd_weights_t : public gpu::generic::sycl::primitive_t {
return init_conf();
}

sycl_convolution_conf_t conf_;
sycl_convolution_bwd_weights_conf_t conf_;

private:
status_t init_conf();
Expand Down
5 changes: 3 additions & 2 deletions src/gpu/generic/sycl/ref_deconvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ namespace generic {
namespace sycl {

status_t ref_deconvolution_bwd_weights_t::pd_t::init_conf() {
conf_ = sycl_convolution_conf_t();
conf_ = sycl_convolution_bwd_weights_conf_t();

conf_.diff_dst_md = xpu::sycl::md_t(src_md());
if (with_bias()) {
conf_.diff_bias_md = xpu::sycl::md_t(diff_weights_md(1));
conf_.bias_dt = diff_weights_md(1)->data_type;
conf_.has_bias = true;
}
conf_.data_md = xpu::sycl::md_t(diff_dst_md());
conf_.ndims = ndims();
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/generic/sycl/ref_deconvolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ struct ref_deconvolution_bwd_weights_t
return init_conf();
}

sycl_convolution_conf_t conf_;
sycl_convolution_bwd_weights_conf_t conf_;

private:
status_t init_conf();
Expand Down
33 changes: 24 additions & 9 deletions src/gpu/generic/sycl/sycl_primitive_conf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,9 @@ struct sycl_binary_conf_t {
sycl_post_ops_t post_ops;
};

struct sycl_convolution_conf_t {
xpu::sycl::md_t data_md;
xpu::sycl::md_t dst_md;
xpu::sycl::md_t weights_md;
xpu::sycl::md_t bias_md;
xpu::sycl::md_t diff_data_md;
xpu::sycl::md_t diff_dst_md;
xpu::sycl::md_t diff_weights_md;
xpu::sycl::md_t diff_bias_md;
struct sycl_convolution_common_conf_t {
bool has_bias = false;
data_type_t bias_dt;

int padding[3];
int strides[3];
Expand All @@ -81,6 +75,24 @@ struct sycl_convolution_conf_t {
sycl_post_ops_t post_ops;
};

// Kernel configuration for the forward convolution kernel
// (convolution_kernel_fwd_t). Extends the shared convolution configuration
// with only the memory descriptors that kernel accesses. Bias is not carried
// as a full descriptor here: the base struct's has_bias / bias_dt are used
// instead.
struct sycl_convolution_fwd_conf_t : sycl_convolution_common_conf_t {
// Source tensor descriptor (set from src_md() in
// ref_convolution_fwd_t::pd_t::init_conf).
xpu::sycl::md_t data_md;
// Destination tensor descriptor (set from dst_md()).
xpu::sycl::md_t dst_md;
// Weights tensor descriptor (set from weights_md(0)).
xpu::sycl::md_t weights_md;
};

// Kernel configuration for the backward-data convolution kernel
// (convolution_kernel_bwd_data_t). Extends the shared convolution
// configuration with only the memory descriptors that kernel accesses. Bias is
// described by the base struct's has_bias / bias_dt rather than a descriptor.
struct sycl_convolution_bwd_data_conf_t : sycl_convolution_common_conf_t {
// Weights tensor descriptor (set from weights_md(0) in
// ref_convolution_bwd_data_t::pd_t::init_conf).
xpu::sycl::md_t weights_md;
// Descriptor of the diff_src tensor (set from diff_src_md()).
xpu::sycl::md_t diff_data_md;
// Descriptor of the diff_dst tensor (set from diff_dst_md()).
xpu::sycl::md_t diff_dst_md;
};

// Kernel configuration for the backward-weights kernel
// (convolution_kernel_bwd_weights_t), used by both
// ref_convolution_bwd_weights_t and ref_deconvolution_bwd_weights_t. Extends
// the shared convolution configuration with only the memory descriptors that
// kernel accesses. The diff_bias data type comes from the base struct's
// bias_dt rather than a full descriptor.
//
// NOTE: the two primitives populate data_md / diff_dst_md with opposite roles:
//   - convolution:   data_md <- src_md(),      diff_dst_md <- diff_dst_md()
//   - deconvolution: data_md <- diff_dst_md(), diff_dst_md <- src_md()
struct sycl_convolution_bwd_weights_conf_t : sycl_convolution_common_conf_t {
// See the NOTE above for which tensor this describes per primitive.
xpu::sycl::md_t data_md;
// See the NOTE above for which tensor this describes per primitive.
xpu::sycl::md_t diff_dst_md;
// Descriptor of the diff_weights tensor (set from diff_weights_md(0) in
// ref_convolution_bwd_weights_t::pd_t::init_conf).
xpu::sycl::md_t diff_weights_md;
};

struct sycl_eltwise_conf_t {
prop_kind_t prop_kind;
xpu::sycl::md_t src_md;
Expand Down Expand Up @@ -416,6 +428,9 @@ CHECK_SYCL_KERNEL_ARG_TYPE(sycl_sum_conf_t);
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_pooling_base_conf_t);
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_pooling_fwd_conf_t);
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_pooling_bwd_conf_t);
// Apply the kernel-argument type check to each per-propagation convolution
// config, since each one is stored by value in its kernel functor.
// NOTE(review): the macro definition is outside this view -- presumably a
// compile-time validation that the type is usable as a SYCL kernel argument;
// confirm.
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_convolution_fwd_conf_t);
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_convolution_bwd_data_conf_t);
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_convolution_bwd_weights_conf_t);

} // namespace sycl
} // namespace generic
Expand Down

0 comments on commit 729ece5

Please sign in to comment.