forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLinearAlgebraKernel.cpp
89 lines (80 loc) · 2.67 KB
/
LinearAlgebraKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/LinearAlgebra.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/cpu/Reduce.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/util/irange.h>
namespace at::native { namespace {
// Element-wise kernel computing  beta * self + alpha * vec1 * vec2  over `iter`.
//
// `iter` feeds three inputs per element, consumed in the order
// (self, vec1, vec2); how the iterator was configured is up to the caller.
// `beta` and `alpha` are converted to the iterator's dtype before use.
// When beta converts to zero (or false), the self operand never enters the
// result, so NaN/Inf values stored in self cannot propagate to the output.
void addr_kernel(TensorIterator &iter,
                 const Scalar& beta, const Scalar& alpha) {
  // Bool path: logical AND stands in for multiplication and logical OR for
  // addition.
  if (iter.dtype() == ScalarType::Bool) {
    const bool self_weight = beta.to<bool>();
    const bool outer_weight = alpha.to<bool>();

    if (self_weight) {
      cpu_kernel(iter,
        [=](bool self_elem, bool lhs_elem, bool rhs_elem) -> bool {
          return (self_weight && self_elem) ||
                 (outer_weight && lhs_elem && rhs_elem);
        }
      );
    } else {
      // beta is false: the self operand is ignored entirely.
      cpu_kernel(iter,
        [=](bool /*self_elem*/, bool lhs_elem, bool rhs_elem) -> bool {
          return outer_weight && lhs_elem && rhs_elem;
        }
      );
    }
    return;
  }

  // Every other dtype handled by the dispatch macro: each case supplies both
  // a scalar lambda and a Vectorized lambda implementing the same formula.
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf,
    iter.dtype(), "addr_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;

      const auto self_weight = beta.to<scalar_t>();
      const auto outer_weight = alpha.to<scalar_t>();
      const auto self_weight_vec = Vec(self_weight);
      const auto outer_weight_vec = Vec(outer_weight);

      const scalar_t zero(0);
      const bool ignore_self = (self_weight == zero);

      // NOTE(review): __ubsan_ignore_undefined__ presumably silences UBSan for
      // overflow in the products below -- confirm against the macro definition.
      if (!ignore_self) {
        cpu_kernel_vec(iter,
          [=](scalar_t self_elem,
              scalar_t lhs_elem,
              scalar_t rhs_elem) __ubsan_ignore_undefined__ -> scalar_t {
            return self_weight * self_elem + outer_weight * lhs_elem * rhs_elem;
          },
          [=](Vec self_lane,
              Vec lhs_lane,
              Vec rhs_lane) __ubsan_ignore_undefined__ {
            return self_weight_vec * self_lane + outer_weight_vec * lhs_lane * rhs_lane;
          }
        );
      } else {
        // beta == 0: skip self so that NaN/Inf stored there cannot reach the
        // output through a multiplication by zero.
        cpu_kernel_vec(iter,
          [=](scalar_t /*self_elem*/,
              scalar_t lhs_elem,
              scalar_t rhs_elem) __ubsan_ignore_undefined__ -> scalar_t {
            return outer_weight * lhs_elem * rhs_elem;
          },
          [=](Vec /*self_lane*/,
              Vec lhs_lane,
              Vec rhs_lane) __ubsan_ignore_undefined__ {
            return outer_weight_vec * lhs_lane * rhs_lane;
          }
        );
      }
    }
  );
}
} // anonymous namespace
// Hooks addr_kernel (defined above in the anonymous namespace) up to addr_stub
// via REGISTER_DISPATCH. addr_stub is declared outside this file
// (presumably in ATen/native/LinearAlgebra.h -- confirm).
REGISTER_DISPATCH(addr_stub, &addr_kernel)
} // namespace at::native