From 36d912bb545ff6f8acde41a468f8b056f7edb36f Mon Sep 17 00:00:00 2001
From: sdbds <865105819@qq.com>
Date: Fri, 27 Sep 2024 05:40:13 +0800
Subject: [PATCH] test TF32

---
Notes (free text after the "---" separator; not part of the commit message):

* library/device_utils.py gains a module-level flag HAS_TF32. It is True when
  torch.cuda.get_device_properties(0).major >= 8 and False when that probe
  raises any exception. Only device index 0 is inspected.
* library/device_utils.py gains tf32_on(). When HAS_TF32 is set it turns on
  torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32,
  and additionally sets torch.backends.cudnn.benchmark = True, then prints
  "TF32 enabled"; otherwise it prints that TF32 is not supported.
* train_network.py imports tf32_on and calls it at import time, immediately
  after init_ipex(), so the switch is unconditional for this script.
* The new end of library/device_utils.py has no trailing newline (see the
  "\ No newline at end of file" marker below).
* NOTE(review): the leading whitespace of the context lines was reconstructed
  from the hunk line counts and typical Python layout - confirm it against the
  target files before applying.

 library/device_utils.py | 17 +++++++++++++++++
 train_network.py        |  3 ++-
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/library/device_utils.py b/library/device_utils.py
index 8823c5d9a..ad842ee3e 100644
--- a/library/device_utils.py
+++ b/library/device_utils.py
@@ -8,6 +8,11 @@
 except Exception:
     HAS_CUDA = False
 
+try:
+    HAS_TF32 = torch.cuda.get_device_properties(0).major >= 8
+except Exception:
+    HAS_TF32 = False
+
 try:
     HAS_MPS = torch.backends.mps.is_available()
 except Exception:
@@ -82,3 +87,15 @@ def init_ipex():
             return
     except Exception as e:
         print("failed to initialize ipex:", e)
+
+def tf32_on():
+    """
+    Enable TF32 on supported GPUs.
+    """
+    if HAS_TF32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+        torch.backends.cudnn.benchmark = True
+        print("TF32 enabled")
+    else:
+        print("TF32 is not supported on this device")
\ No newline at end of file
diff --git a/train_network.py b/train_network.py
index ebb2079b9..55a9f56fa 100644
--- a/train_network.py
+++ b/train_network.py
@@ -12,9 +12,10 @@
 from tqdm import tqdm
 
 import torch
-from library.device_utils import init_ipex, clean_memory_on_device
+from library.device_utils import init_ipex, clean_memory_on_device, tf32_on
 
 init_ipex()
+tf32_on()
 
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler