Skip to content

Commit

Permalink
Merge branch 'tf32' of https://github.com/sdbds/sd-scripts into qinglong
Browse files Browse the repository at this point in the history
# Conflicts:
#	requirements.txt
  • Loading branch information
sdbds committed Sep 26, 2024
2 parents f77de6b + 36d912b commit 8071c5f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
17 changes: 17 additions & 0 deletions library/device_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
except Exception:
HAS_CUDA = False

try:
HAS_TF32 = torch.cuda.get_device_properties(0).major >= 8
except Exception:
HAS_TF32 = False

try:
HAS_MPS = torch.backends.mps.is_available()
except Exception:
Expand Down Expand Up @@ -82,3 +87,15 @@ def init_ipex():
return
except Exception as e:
print("failed to initialize ipex:", e)

def tf32_on():
"""
Enable TF32 on supported GPUs.
"""
if HAS_TF32:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
print("TF32 enabled")
else:
print("TF32 is not supported on this device")
3 changes: 2 additions & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from tqdm import tqdm

import torch
from library.device_utils import init_ipex, clean_memory_on_device
from library.device_utils import init_ipex, clean_memory_on_device, tf32_on

init_ipex()
tf32_on()

from accelerate.utils import set_seed
from diffusers import DDPMScheduler
Expand Down

0 comments on commit 8071c5f

Please sign in to comment.