[optim] Adam defaults to fused when CUDA + differentiable=False (pytorch#90865)

Step 1 toward faster default optimizers.

Preliminary benchmarks show improvements on CUDA for BERT_pytorch and resnet18:
![image](https://user-images.githubusercontent.com/31798555/207707118-14221802-77ce-4ee0-96e3-04638c07924c.png)

Pull Request resolved: pytorch#90865
Approved by: https://github.com/albanD
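
To make the new default concrete, here is a minimal usage sketch (illustrative only, not part of this commit; it assumes a CUDA device is available):

    from torch import nn, optim

    model = nn.Linear(128, 64).cuda()   # all parameters are float32 tensors on CUDA

    # Leaving `fused` unset (None) lets Adam pick the fused CUDA kernel at step
    # time, since differentiable defaults to False and every tensor qualifies.
    opt_default = optim.Adam(model.parameters(), lr=1e-3)

    # An explicit False is still respected and keeps the non-fused path.
    opt_unfused = optim.Adam(model.parameters(), lr=1e-3, fused=False)

    # Forcing the fused kernel explicitly continues to work as before.
    opt_forced = optim.Adam(model.parameters(), lr=1e-3, fused=True)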
janeyx99 authored and pytorchmergebot committed Dec 27, 2022
1 parent 0b255b3 commit a061f13
Showing 1 changed file: torch/optim/adam.py (21 additions, 4 deletions)
@@ -105,9 +105,12 @@ class Adam(Optimizer):
        capturable (bool, optional): whether this instance is safe to capture in a CUDA graph.
            Passing True can impair ungraphed performance, so if you don't intend to
            graph capture this instance, leave it False (default: False)
-       fused (bool, optional): whether fused implementation of optimizer is used.
+       fused (bool, optional): whether the fused implementation (CUDA only) is used.
            Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
-           are supported. (default: False)
+           are supported. Since the fused implementation is usually significantly faster than
+           the for-loop implementation, we try to use it whenever possible (all parameters
+           are on CUDA and are of a supported type). Else, we continue with the for-loop
+           implementation. (default: None)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
@@ -118,7 +121,7 @@ class Adam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, *, foreach: Optional[bool] = None,
                 maximize: bool = False, capturable: bool = False,
-                differentiable: bool = False, fused: bool = False):
+                differentiable: bool = False, fused: Optional[bool] = None):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
@@ -129,6 +132,7 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad,
                        maximize=maximize, foreach=foreach, capturable=capturable,
@@ -288,7 +292,7 @@ def adam(params: List[Tensor],
         foreach: Optional[bool] = None,
         capturable: bool = False,
         differentiable: bool = False,
-        fused: bool = False,
+        fused: Optional[bool] = None,
         grad_scale: Optional[_MultiDeviceReplicator] = None,
         found_inf: Optional[_MultiDeviceReplicator] = None,
         *,
@@ -303,6 +307,19 @@ def adam(params: List[Tensor],
    See :class:`~torch.optim.Adam` for details.
    """

+   # We try to use the fused implementation whenever we can since it is fastest.
+   # It's only available when the tensors are floats on the same CUDA device
+   # and when differentiable=False.
+   # We still respect when the user inputs False for fused.
+   if fused is None:
+       if not differentiable and all(
+           p.is_cuda and torch.is_floating_point(p)
+           for p in params + grads + exp_avgs + exp_avg_sqs + max_exp_avg_sqs + state_steps
+       ):
+           fused = True
+       else:
+           fused = False
+
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

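Read on its own, the selection rule added to the functional adam() boils down to a single predicate over all tensors the optimizer touches. The helper below is an illustrative restatement (the name _should_default_to_fused is made up for this sketch, not a PyTorch API, and the demo lines assume a CUDA device):

    import torch
    from typing import List

    def _should_default_to_fused(differentiable: bool,
                                 tensors: List[torch.Tensor]) -> bool:
        # Mirror of the condition above: only default to the fused kernel when
        # we do not need to differentiate through the optimizer and every tensor
        # (params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps)
        # is a floating-point tensor on CUDA.
        return not differentiable and all(
            t.is_cuda and torch.is_floating_point(t) for t in tensors
        )

    # A single CPU tensor or non-floating dtype anywhere in the lists is enough
    # to fall back to fused=False.
    params = [torch.zeros(4, device="cuda")]
    state_steps = [torch.zeros((), device="cuda")]
    print(_should_default_to_fused(False, params + state_steps))  # True
    print(_should_default_to_fused(True, params + state_steps))   # False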