diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 21a285f540..6fe8562ad6 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import functools import logging import time from copy import ( @@ -38,7 +39,6 @@ from deepmd.pt.utils.dataloader import ( BufferedIterator, get_weighted_sampler, - lazy, ) from deepmd.pt.utils.env import ( DEVICE, @@ -214,7 +214,7 @@ def get_single_model( _validation_data.add_data_requirement(_data_requirement) if not resuming: - @lazy + @functools.lru_cache def get_sample(): sampled = make_stat_input( _training_data.systems, diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index b197f46124..65a96418c9 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -262,20 +262,3 @@ def get_weighted_sampler(training_data, prob_style, sys_prob=False): with torch.device("cpu"): sampler = WeightedRandomSampler(probs, len_sampler, replacement=True) return sampler - - -class LazyFunction: - def __init__(self, func): - self.func = func - self.result = None - self.called = False - - def __call__(self, *args, **kwargs): - if not self.called: - self.result = self.func(*args, **kwargs) - self.called = True - return self.result - - -def lazy(func): - return LazyFunction(func)