diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 1e5314a821..ce7fc136bd 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -103,7 +103,8 @@ def get_trainer( finetune_links=None, ): multi_task = "model_dict" in config.get("model", {}) - + # https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere + torch.backends.cuda.matmul.allow_tf32 = True # Initialize DDP local_rank = os.environ.get("LOCAL_RANK") if local_rank is not None: