diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp
index 48c6e5076..30a3bf5c2 100644
--- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp
@@ -1259,6 +1259,147 @@ struct BatchNormTransformInputKernelFunctor {
   stat_accscalar_t epsilon_;
 };
 
+template <
+    int VEC_SIZE,
+    typename input_scalar_t,
+    typename stat_scalar_t,
+    typename stat_accscalar_t,
+    bool train,
+    typename index_t>
+struct BatchNormTransformInputVectorizedKernelFunctor {
+  void operator()(sycl::nd_item<2> item) const {
+    index_t plane = item.get_group(1);
+
+    if (plane >= input_.size(1)) {
+      return;
+    }
+
+    stat_accscalar_t gamma = weight_.size(0) > 0
+        ? static_cast<stat_accscalar_t>(weight_[plane])
+        : static_cast<stat_accscalar_t>(1);
+    stat_accscalar_t beta = bias_.size(0) > 0
+        ? static_cast<stat_accscalar_t>(bias_[plane])
+        : static_cast<stat_accscalar_t>(0);
+    stat_accscalar_t mean = static_cast<stat_accscalar_t>(mean_[plane]);
+    stat_accscalar_t invstd;
+    if constexpr (train) {
+      invstd = var_or_invstd_[plane];
+    } else {
+      invstd =
+          static_cast<stat_accscalar_t>(1) /
+          device_sqrt(
+              static_cast<stat_accscalar_t>(var_or_invstd_[plane]) + epsilon_);
+    }
+
+    index_t bs = input_.size(0);
+    index_t fs = input_.size(2);
+
+    index_t bstep = item.get_local_range(0) * item.get_group_range(0);
+    for (index_t batch = item.get_global_id(0); batch < bs; batch += bstep) {
+      auto o = output_[batch][plane];
+      auto i = input_[batch][plane];
+
+      for (index_t feature_vec_begin = item.get_local_id(1) * VEC_SIZE;
+           feature_vec_begin < fs;
+           feature_vec_begin += VEC_SIZE * item.get_local_range(1)) {
+        auto remaining = fs - feature_vec_begin;
+        if (remaining < VEC_SIZE) {
+          for (index_t idx = 0; idx < remaining; ++idx) {
+            index_t feature = feature_vec_begin + idx;
+            o[feature] = static_cast<input_scalar_t>(
+                gamma * (i[feature] - mean) * invstd + beta);
+          }
+        } else {
+          using vec_t = memory::aligned_vector<input_scalar_t, VEC_SIZE>;
+          vec_t vec;
+#pragma unroll
+          for (int vt = 0; vt < VEC_SIZE; ++vt) {
+            index_t feature = feature_vec_begin + vt;
+            vec[vt] = static_cast<input_scalar_t>(
+                gamma * (i[feature] - mean) * invstd + beta);
+          }
+          input_scalar_t* write_ptr = &o[feature_vec_begin];
+          *(reinterpret_cast<vec_t*>(write_ptr)) = vec;
+        }
+      }
+    }
+  }
+
+  BatchNormTransformInputVectorizedKernelFunctor(
+      const GenericPackedTensorAccessor<
+          const input_scalar_t,
+          3,
+          RestrictPtrTraits,
+          index_t> input,
+      GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t>
+          output,
+      const GenericPackedTensorAccessor<
+          typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::
+              type,
+          1,
+          RestrictPtrTraits,
+          index_t> mean,
+      const GenericPackedTensorAccessor<
+          typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::
+              type,
+          1,
+          RestrictPtrTraits,
+          index_t> var_or_invstd,
+      const GenericPackedTensorAccessor<
+          const stat_scalar_t,
+          1,
+          RestrictPtrTraits,
+          index_t> weight,
+      const GenericPackedTensorAccessor<
+          const stat_scalar_t,
+          1,
+          RestrictPtrTraits,
+          index_t> bias,
+      stat_accscalar_t epsilon)
+      : input_(input),
+        output_(output),
+        mean_(mean),
+        var_or_invstd_(var_or_invstd),
+        weight_(weight),
+        bias_(bias),
+        epsilon_(epsilon) {}
+
+ private:
+  const GenericPackedTensorAccessor<
+      const input_scalar_t,
+      3,
+      RestrictPtrTraits,
+      index_t>
+      input_;
+  GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t>
+      output_;
+  const GenericPackedTensorAccessor<
+      typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type,
+      1,
+      RestrictPtrTraits,
+      index_t>
+      mean_;
+  const GenericPackedTensorAccessor<
+      typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type,
+      1,
+      RestrictPtrTraits,
+      index_t>
+      var_or_invstd_;
+  const GenericPackedTensorAccessor<
+      const stat_scalar_t,
+      1,
+      RestrictPtrTraits,
+      index_t>
+      weight_;
+  const GenericPackedTensorAccessor<
+      const stat_scalar_t,
+      1,
+      RestrictPtrTraits,
+      index_t>
+      bias_;
+  stat_accscalar_t epsilon_;
+};
+
 template <typename input_scalar_t, typename stat_scalar_t, typename index_t>
 void batch_norm_elemt_template(
     const Tensor& output_,
@@ -1315,14 +1456,27 @@ void batch_norm_elemt_template(
   nwg_y = std::min(nwg_y, syclMaxWorkItemsPerTile() / (tf * tb));
   sycl::range<2> global_range(nwg_y * tb, nwg_x * tf);
 
-  auto kfn = BatchNormTransformInputKernelFunctor<
-      input_scalar_t,
-      stat_scalar_t,
-      stat_accscalar_t,
-      true,
-      index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
-
-  sycl_kernel_submit(global_range, local_range, queue, kfn);
+  auto output_ptr = (char*)output_reshaped.data_ptr();
+  if (output_reshaped.is_contiguous() &&
+      memory::can_vectorize_up_to<input_scalar_t>(output_ptr) >= 4 &&
+      sizeof(input_scalar_t) < sizeof(float)) {
+    auto kfn = BatchNormTransformInputVectorizedKernelFunctor<
+        4,
+        input_scalar_t,
+        stat_scalar_t,
+        stat_accscalar_t,
+        true,
+        index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
+    sycl_kernel_submit(global_range, local_range, queue, kfn);
+  } else {
+    auto kfn = BatchNormTransformInputKernelFunctor<
+        input_scalar_t,
+        stat_scalar_t,
+        stat_accscalar_t,
+        true,
+        index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
+    sycl_kernel_submit(global_range, local_range, queue, kfn);
+  }
 }
 
 template <
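For reviewers unfamiliar with the write pattern the new functor relies on, the host-side sketch below illustrates the same idea outside SYCL: the feature dimension is walked in groups of VEC_SIZE, each group is computed into a small vector value and written back with one vector-sized store, and a remainder shorter than VEC_SIZE falls back to scalar stores. The names here (vec4_t, transform_row, kVecSize) are hypothetical and not part of the patch; the sketch only mirrors the kernel's main-body/tail split and assumes the output row is suitably aligned, which is what the can_vectorize_up_to check guards in the patch.

// Illustrative host-side sketch only; hypothetical names, not part of the patch.
#include <cstddef>
#include <cstdio>
#include <vector>

constexpr int kVecSize = 4;

// Stand-in for memory::aligned_vector<scalar_t, VEC_SIZE>: a POD wrapper whose
// alignment lets the compiler emit one wide store per group of elements.
struct alignas(kVecSize * sizeof(float)) vec4_t {
  float val[kVecSize];
};

// Applies y = gamma * (x - mean) * invstd + beta over one feature row,
// mirroring the kernel's vectorized main body and scalar tail.
void transform_row(
    const float* in,
    float* out,
    std::size_t fs,
    float gamma,
    float beta,
    float mean,
    float invstd) {
  std::size_t feature = 0;
  // Main body: compute kVecSize results locally, then write them back with a
  // single vector-sized store, as the SYCL kernel does. Assumes `out` is
  // aligned to sizeof(vec4_t) at these offsets.
  for (; feature + kVecSize <= fs; feature += kVecSize) {
    vec4_t v;
    for (int vt = 0; vt < kVecSize; ++vt) {
      v.val[vt] = gamma * (in[feature + vt] - mean) * invstd + beta;
    }
    *reinterpret_cast<vec4_t*>(out + feature) = v;
  }
  // Tail: fewer than kVecSize elements remain; fall back to scalar stores.
  for (; feature < fs; ++feature) {
    out[feature] = gamma * (in[feature] - mean) * invstd + beta;
  }
}

int main() {
  std::vector<float> in(10, 2.0f), out(10, 0.0f);
  transform_row(
      in.data(), out.data(), in.size(),
      /*gamma=*/1.5f, /*beta=*/0.5f, /*mean=*/1.0f, /*invstd=*/2.0f);
  std::printf("out[0]=%f out[9]=%f\n", out[0], out[9]); // both 3.500000
  return 0;
}

Note that the patch takes the vectorized path only when the reshaped output is contiguous, memory::can_vectorize_up_to reports at least 4 elements for the output pointer, and the element type is smaller than float (i.e. reduced-precision inputs); otherwise it dispatches the existing scalar BatchNormTransformInputKernelFunctor unchanged.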