From a061f139dccb5f56c9d14e25ef54ff821b4dd3c8 Mon Sep 17 00:00:00 2001
From: Jane Xu
Date: Tue, 27 Dec 2022 01:28:47 +0000
Subject: [PATCH] [optim] Adam defaults to fused when CUDA + differentiable=False (#90865)

Step 1 toward faster default optimizers.

Preliminary benchmarks show improvements on CUDA for BERT_pytorch and resnet18:

![image](https://user-images.githubusercontent.com/31798555/207707118-14221802-77ce-4ee0-96e3-04638c07924c.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90865
Approved by: https://github.com/albanD
---
 torch/optim/adam.py | 25 +++++++++++++++++++++----
 1 file changed, 21 insertions(+), 4 deletions(-)

diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index 7463e8e13300d..c5123e67659ce 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -105,9 +105,12 @@ class Adam(Optimizer):
         capturable (bool, optional): whether this instance is safe to capture in a CUDA graph.
             Passing True can impair ungraphed performance, so if you don't intend to
             graph capture this instance, leave it False (default: False)
-        fused (bool, optional): whether fused implementation of optimizer is used.
+        fused (bool, optional): whether the fused implementation (CUDA only) is used.
             Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
-            are supported. (default: False)
+            are supported. Since the fused implementation is usually significantly faster than
+            the for-loop implementation, we try to use it whenever possible (all parameters
+            are on CUDA and are of a supported type). Else, we continue with the for-loop
+            implementation. (default: None)

     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
@@ -118,7 +121,7 @@ class Adam(Optimizer):
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                  weight_decay=0, amsgrad=False, *, foreach: Optional[bool] = None,
                  maximize: bool = False, capturable: bool = False,
-                 differentiable: bool = False, fused: bool = False):
+                 differentiable: bool = False, fused: Optional[bool] = None):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -129,6 +132,7 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
         if not 0.0 <= weight_decay:
             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
         defaults = dict(lr=lr, betas=betas, eps=eps,
                         weight_decay=weight_decay, amsgrad=amsgrad,
                         maximize=maximize, foreach=foreach, capturable=capturable,
@@ -288,7 +292,7 @@ def adam(params: List[Tensor],
          foreach: Optional[bool] = None,
          capturable: bool = False,
          differentiable: bool = False,
-         fused: bool = False,
+         fused: Optional[bool] = None,
          grad_scale: Optional[_MultiDeviceReplicator] = None,
          found_inf: Optional[_MultiDeviceReplicator] = None,
          *,
@@ -303,6 +307,19 @@ def adam(params: List[Tensor],

     See :class:`~torch.optim.Adam` for details.
     """
+    # We try to use the fused implementation whenever we can since it is fastest.
+    # It's only available when the tensors are floats on the same CUDA device
+    # and when differentiable=False.
+    # We still respect when the user inputs False for fused.
+    if fused is None:
+        if not differentiable and all(
+            p.is_cuda and torch.is_floating_point(p)
+            for p in params + grads + exp_avgs + exp_avg_sqs + max_exp_avg_sqs + state_steps
+        ):
+            fused = True
+        else:
+            fused = False
+
     if not all(isinstance(t, torch.Tensor) for t in state_steps):
         raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
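
For context, here is a minimal usage sketch (not part of the patch) of the behavior this change enables, assuming a CUDA-enabled PyTorch build that includes it; the model, tensor shapes, and learning rate are arbitrary illustrations:

```python
# Sketch of the new default: leaving `fused` unset (None) lets step() pick the
# fused CUDA kernel when every parameter/state tensor is a floating-point CUDA
# tensor and differentiable=False; an explicit fused=False still opts out.
import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 4).cuda()
    model(torch.randn(8, 16, device="cuda")).sum().backward()

    # fused defaults to None here, so the functional adam() resolves it to True.
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    opt.step()

    # An explicit fused=False is respected and keeps the non-fused path.
    opt_unfused = torch.optim.Adam(model.parameters(), lr=1e-3, fused=False)
    opt_unfused.step()
```

The selection only flips to the fused path when all tensors qualify; CPU parameters, unsupported dtypes, or differentiable=True keep the existing for-loop implementation.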