Skip to content

Commit

Permalink
Resolve failed formatting checks
Browse files Browse the repository at this point in the history
Signed-off-by: Sanket Kale <[email protected]>
  • Loading branch information
Sanket Kale committed Nov 18, 2024
1 parent 5d75f6f commit 73726e1
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 28 deletions.
38 changes: 19 additions & 19 deletions csrc/cpu/attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,28 @@ struct KernelVecType<c10::BFloat16> {
#else
#ifdef __aarch64__
#ifndef BF16_SUPPORT
//pass
// pass
#else
template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
#endif
#else
template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
#endif
#endif

Expand Down
18 changes: 9 additions & 9 deletions csrc/cpu/cpu_types_arm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,10 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float ans = 0;
unroll_loop<int, VEC_ELEM_NUM>([&ans, &ar](int i) { ans += ar.values[i]; });
float answer = 0;
unroll_loop<int, VEC_ELEM_NUM>([&answer, &ar](int i) { answer += ar.values[i]; });

return ans;
return answer;
}

FP32Vec8 exp() const {
Expand Down Expand Up @@ -408,23 +408,23 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float ans = 0;
unroll_loop<int, VEC_ELEM_NUM>([&ans, &ar](int i) { ans += ar.values[i]; });
float answer = 0;
unroll_loop<int, VEC_ELEM_NUM>([&answer, &ar](int i) { answer += ar.values[i]; });

return ans;
return answer;
};

template <int group_size> float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);

AliasReg ar;
ar.reg = reg;
float ans = 0;
float answer = 0;
const int start = idx * group_size;
unroll_loop<int, group_size>(
[&ans, &start, ar](int i) { ans += ar.values[start + i]; });
[&answer, &start, ar](int i) { answer += ar.values[start + i]; });

return ans;
return answer;
};

void save(float *ptr) const {
Expand Down

0 comments on commit 73726e1

Please sign in to comment.