diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index d0a8cd8115..baade3dbea 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -595,7 +595,10 @@ def dist_cp_load( storage_reader: StorageReader, load_planner: Optional[LoadPlanner] = None, ): - if version.parse(torch.__version__) >= version.parse('2.4.0'): + if ( + version.parse(torch.__version__) >= version.parse('2.4.0') and + version.parse(torch.__version__) < version.parse('2.5.0') + ): from torch.distributed.checkpoint.utils import CheckpointException try: dist_cp.load(