forked from meta-llama/llama-recipes
-
Notifications
You must be signed in to change notification settings - Fork 1
/
mixed_precision.py
38 lines (32 loc) · 949 Bytes
/
mixed_precision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import torch
from torch.distributed.fsdp import (
MixedPrecision,
)
# Uniform float16 policy: parameters, gradient communication and buffers are
# all kept in fp16. A training loop that selects this policy must use a
# gradient scaler.
fpSixteen = MixedPrecision(
    param_dtype=torch.float16,   # parameter precision
    reduce_dtype=torch.float16,  # gradient communication precision
    buffer_dtype=torch.float16,  # buffer precision
)
# Uniform bfloat16 policy: parameters, gradient communication and buffers are
# all kept in bf16. cast_forward_inputs=True asks FSDP to also cast the
# forward inputs of wrapped modules.
bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,   # parameter precision
    reduce_dtype=torch.bfloat16,  # gradient communication precision
    buffer_dtype=torch.bfloat16,  # buffer precision
    cast_forward_inputs=True,
)
# Mixed policy: parameters stay in float32, while gradient communication and
# buffers drop to bfloat16.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,    # parameter precision (full)
    reduce_dtype=torch.bfloat16,  # gradient communication precision
    buffer_dtype=torch.bfloat16,  # buffer precision
)
# Full-precision policy: parameters, gradient communication and buffers are
# all float32, i.e. no reduced-precision dtype is requested anywhere.
fp32_policy = MixedPrecision(
    param_dtype=torch.float32,   # parameter precision
    reduce_dtype=torch.float32,  # gradient communication precision
    buffer_dtype=torch.float32,  # buffer precision
)