forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FusedAdamWKernel.cu
45 lines (41 loc) · 1.95 KB
/
FusedAdamWKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TypeDefault.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/fused_adamw_amsgrad_impl.cuh>
#include <ATen/native/cuda/fused_adamw_impl.cuh>
#include <c10/util/Exception.h>
namespace at { namespace native {
// note(crcrpar): To observe the CI rules, i.e. 20 minutes per file to compile, we defensively split instantiations into _impl files.
// This is only for CUDA 11.3, for which it took about 20 minutes on my workstation and 28 minutes in CI, respectively.
// As a data point, it took about 20 seconds with CUDA 11.7 installed in my environment.
// See https://github.com/pytorch/pytorch/pull/81705 for details.
// Entry point for the fused AdamW step: validates the tensor lists and then
// forwards every argument to one of two implementations, selected by `amsgrad`.
//
//   amsgrad == false -> _fused_adamw_cuda_impl_
//                       (`max_exp_avg_sqs` is neither validated nor forwarded)
//   amsgrad == true  -> _fused_adamw_amsgrad_cuda_impl_
//                       (`max_exp_avg_sqs` is validated and forwarded)
//
// Validation is delegated to at::native::check_fast_path_restrictions; when it
// returns false, TORCH_CHECK fails with a message stating that the listed
// tensor lists must have the same dtype, device, and layout.
// `state_steps` is not part of that check and is passed through unchanged,
// as are the scalar hyperparameters, `maximize`, `grad_scale`, and `found_inf`.
void _fused_adamw_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf
) {
  // Plain (non-AMSGrad) variant: handle it first and return early.
  if (!amsgrad) {
    const bool lists_are_compatible = at::native::check_fast_path_restrictions(
        {params, grads, exp_avgs, exp_avg_sqs});
    TORCH_CHECK(
        lists_are_compatible,
        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    _fused_adamw_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
    return;
  }

  // AMSGrad variant: `max_exp_avg_sqs` takes part in both the check and the call.
  const bool lists_are_compatible = at::native::check_fast_path_restrictions(
      {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs});
  TORCH_CHECK(
      lists_are_compatible,
      "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
  _fused_adamw_amsgrad_cuda_impl_(
      params,
      grads,
      exp_avgs,
      exp_avg_sqs,
      max_exp_avg_sqs,
      state_steps,
      lr,
      beta1,
      beta2,
      weight_decay,
      eps,
      maximize,
      grad_scale,
      found_inf);
}
}} // namespace at::native