batch_normalization: Introduce vectorization optimization in the batch norm elementwise kernel. (#933)

Because the group-stride-loop implementation of this kernel performs poorly for
low-precision data types on PVC ([jira:
PYTORCHDGQ-5162](https://jira.devtools.intel.com/browse/PYTORCHDGQ-5162)),
a partial vectorization optimization is applied instead; a simplified sketch of
the pattern follows the commit metadata below.
xytintel authored Nov 14, 2024
1 parent ed0dbe4 commit e035f6b
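
The core of the change is a partial-vectorization pattern: features are processed in VEC_SIZE-wide chunks that are written back with a single aligned vector store, while any leftover elements at the end of a row fall back to scalar stores. The snippet below is a minimal, host-side C++ sketch of that pattern only; `transform_row` and `vec4` are illustrative names and not part of the kernel, which instead uses `memory::aligned_vector` inside a SYCL nd-range loop.

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

constexpr int VEC_SIZE = 4;

// Illustrative aligned vector type standing in for memory::aligned_vector.
struct alignas(VEC_SIZE * sizeof(float)) vec4 {
  float val[VEC_SIZE];
};

// Normalize-and-affine transform over one row: full VEC_SIZE chunks are
// written with one vector-wide store, the remainder uses scalar stores.
// Assumes `out` is aligned for vec4 stores (the real kernel verifies this
// via memory::can_vectorize_up_to before taking the vectorized path).
void transform_row(const float* in, float* out, std::size_t n,
                   float gamma, float beta, float mean, float invstd) {
  std::size_t i = 0;
  for (; i + VEC_SIZE <= n; i += VEC_SIZE) {
    vec4 v;
    for (int k = 0; k < VEC_SIZE; ++k) {
      v.val[k] = gamma * (in[i + k] - mean) * invstd + beta;
    }
    *reinterpret_cast<vec4*>(out + i) = v;  // one vector-width store
  }
  for (; i < n; ++i) {  // scalar tail for the last (n % VEC_SIZE) elements
    out[i] = gamma * (in[i] - mean) * invstd + beta;
  }
}

int main() {
  std::vector<float> in(10, 3.0f), out(10, 0.0f);
  transform_row(in.data(), out.data(), in.size(), /*gamma=*/1.0f,
                /*beta=*/0.0f, /*mean=*/2.0f, /*invstd=*/0.5f);
  std::printf("out[0]=%.2f out[9]=%.2f\n", out[0], out[9]);  // 0.50 0.50
}
```

The real kernel additionally strides over batches and features using the work-group geometry, and it only takes the vectorized path when the output pointer is suitably aligned, as the diff below shows.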
Showing 1 changed file with 162 additions and 8 deletions.
170 changes: 162 additions & 8 deletions src/ATen/native/xpu/sycl/BatchNormKernels.cpp
@@ -1259,6 +1259,147 @@ struct BatchNormTransformInputKernelFunctor {
  stat_accscalar_t epsilon_;
};

+template <
+    int VEC_SIZE,
+    typename input_scalar_t,
+    typename stat_scalar_t,
+    typename stat_accscalar_t,
+    bool train,
+    typename index_t>
+struct BatchNormTransformInputVectorizedKernelFunctor {
+  void operator()(sycl::nd_item<2> item) const {
+    index_t plane = item.get_group(1);
+
+    if (plane >= input_.size(1)) {
+      return;
+    }
+
+    stat_accscalar_t gamma = weight_.size(0) > 0
+        ? static_cast<stat_accscalar_t>(weight_[plane])
+        : static_cast<stat_accscalar_t>(1);
+    stat_accscalar_t beta = bias_.size(0) > 0
+        ? static_cast<stat_accscalar_t>(bias_[plane])
+        : static_cast<stat_accscalar_t>(0);
+    stat_accscalar_t mean = static_cast<stat_accscalar_t>(mean_[plane]);
+    stat_accscalar_t invstd;
+    if constexpr (train) {
+      invstd = var_or_invstd_[plane];
+    } else {
+      invstd =
+          static_cast<stat_accscalar_t>(1) /
+          device_sqrt(
+              static_cast<stat_accscalar_t>(var_or_invstd_[plane]) + epsilon_);
+    }
+
+    index_t bs = input_.size(0);
+    index_t fs = input_.size(2);
+
+    index_t bstep = item.get_local_range(0) * item.get_group_range(0);
+    for (index_t batch = item.get_global_id(0); batch < bs; batch += bstep) {
+      auto o = output_[batch][plane];
+      auto i = input_[batch][plane];
+
+      for (index_t feature_vec_begin = item.get_local_id(1) * VEC_SIZE;
+           feature_vec_begin < fs;
+           feature_vec_begin += VEC_SIZE * item.get_local_range(1)) {
+        auto remaining = fs - feature_vec_begin;
+        if (remaining < VEC_SIZE) {
+          for (index_t idx = 0; idx < remaining; ++idx) {
+            index_t feature = feature_vec_begin + idx;
+            o[feature] = static_cast<input_scalar_t>(
+                gamma * (i[feature] - mean) * invstd + beta);
+          }
+        } else {
+          using vec_t = memory::aligned_vector<input_scalar_t, VEC_SIZE>;
+          vec_t vec;
+#pragma unroll
+          for (int vt = 0; vt < VEC_SIZE; ++vt) {
+            index_t feature = feature_vec_begin + vt;
+            vec[vt] = static_cast<input_scalar_t>(
+                gamma * (i[feature] - mean) * invstd + beta);
+          }
+          input_scalar_t* write_ptr = &o[feature_vec_begin];
+          *(reinterpret_cast<vec_t*>(write_ptr)) = vec;
+        }
+      }
+    }
+  }
+
+  BatchNormTransformInputVectorizedKernelFunctor(
+      const GenericPackedTensorAccessor<
+          const input_scalar_t,
+          3,
+          RestrictPtrTraits,
+          index_t> input,
+      GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t>
+          output,
+      const GenericPackedTensorAccessor<
+          typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::
+              type,
+          1,
+          RestrictPtrTraits,
+          index_t> mean,
+      const GenericPackedTensorAccessor<
+          typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::
+              type,
+          1,
+          RestrictPtrTraits,
+          index_t> var_or_invstd,
+      const GenericPackedTensorAccessor<
+          const stat_scalar_t,
+          1,
+          RestrictPtrTraits,
+          index_t> weight,
+      const GenericPackedTensorAccessor<
+          const stat_scalar_t,
+          1,
+          RestrictPtrTraits,
+          index_t> bias,
+      stat_accscalar_t epsilon)
+      : input_(input),
+        output_(output),
+        mean_(mean),
+        var_or_invstd_(var_or_invstd),
+        weight_(weight),
+        bias_(bias),
+        epsilon_(epsilon) {}
+
+ private:
+  const GenericPackedTensorAccessor<
+      const input_scalar_t,
+      3,
+      RestrictPtrTraits,
+      index_t>
+      input_;
+  GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t>
+      output_;
+  const GenericPackedTensorAccessor<
+      typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type,
+      1,
+      RestrictPtrTraits,
+      index_t>
+      mean_;
+  const GenericPackedTensorAccessor<
+      typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type,
+      1,
+      RestrictPtrTraits,
+      index_t>
+      var_or_invstd_;
+  const GenericPackedTensorAccessor<
+      const stat_scalar_t,
+      1,
+      RestrictPtrTraits,
+      index_t>
+      weight_;
+  const GenericPackedTensorAccessor<
+      const stat_scalar_t,
+      1,
+      RestrictPtrTraits,
+      index_t>
+      bias_;
+  stat_accscalar_t epsilon_;
+};

template <typename input_scalar_t, typename stat_scalar_t, typename index_t>
void batch_norm_elemt_template(
    const Tensor& output_,

@@ -1315,14 +1456,27 @@ void batch_norm_elemt_template(
  nwg_y = std::min<int>(nwg_y, syclMaxWorkItemsPerTile() / (tf * tb));
  sycl::range<2> global_range(nwg_y * tb, nwg_x * tf);

-  auto kfn = BatchNormTransformInputKernelFunctor<
-      input_scalar_t,
-      stat_scalar_t,
-      stat_accscalar_t,
-      true,
-      index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
-
-  sycl_kernel_submit(global_range, local_range, queue, kfn);
+  auto output_ptr = (char*)output_reshaped.data_ptr();
+  if (output_reshaped.is_contiguous() &&
+      memory::can_vectorize_up_to<input_scalar_t>(output_ptr) >= 4 &&
+      sizeof(input_scalar_t) < sizeof(float)) {
+    auto kfn = BatchNormTransformInputVectorizedKernelFunctor<
+        4,
+        input_scalar_t,
+        stat_scalar_t,
+        stat_accscalar_t,
+        true,
+        index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
+    sycl_kernel_submit(global_range, local_range, queue, kfn);
+  } else {
+    auto kfn = BatchNormTransformInputKernelFunctor<
+        input_scalar_t,
+        stat_scalar_t,
+        stat_accscalar_t,
+        true,
+        index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
+    sycl_kernel_submit(global_range, local_range, queue, kfn);
+  }
}

template <
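For context on the dispatch added in `batch_norm_elemt_template` above: the vectorized functor is selected only when the reshaped output is contiguous, its data pointer supports at least 4-element vector accesses, and the element type is narrower than `float` (in practice half or bfloat16). The sketch below mimics that eligibility check in spirit; `vectorizable_elements` is a hypothetical stand-in for `memory::can_vectorize_up_to`, not the library function.

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for memory::can_vectorize_up_to<scalar_t>(ptr):
// returns the largest element count (up to 8) for which the pointer is
// aligned to a full vector of scalar_t, based purely on its address.
template <typename scalar_t>
int vectorizable_elements(const void* ptr) {
  const auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  for (int vec = 8; vec > 1; vec /= 2) {
    if (addr % (vec * sizeof(scalar_t)) == 0) {
      return vec;
    }
  }
  return 1;
}

int main() {
  // A 16-byte-aligned buffer of 2-byte elements (stand-in for half/bfloat16).
  alignas(16) std::uint16_t output[64] = {};
  const bool contiguous = true;  // the real check queries the output tensor
  const bool take_vectorized_path = contiguous &&
      vectorizable_elements<std::uint16_t>(output) >= 4 &&
      sizeof(std::uint16_t) < sizeof(float);  // low-precision dtypes only
  std::printf("vectorized kernel selected: %s\n",
              take_vectorized_path ? "yes" : "no");
}
```

Anything that fails the check falls back to the original BatchNormTransformInputKernelFunctor, so behavior is unchanged for unaligned outputs and full-precision data types.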
