diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b52d517d82..38f5abf616 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -29,7 +29,7 @@ repos:
         exclude: ^source/3rdparty
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.8.1
+    rev: v0.8.2
     hooks:
       - id: ruff
         args: ["--fix"]
diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py
index d33b17b035..67e5195f6d 100644
--- a/deepmd/pt/utils/dataloader.py
+++ b/deepmd/pt/utils/dataloader.py
@@ -1,11 +1,13 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import logging
 import os
-import queue
 import time
 from multiprocessing.dummy import (
     Pool,
 )
+from queue import (
+    Queue,
+)
 from threading import (
     Thread,
 )
@@ -204,70 +206,51 @@ def print_summary(
     )
 
 
-_sentinel = object()
-QUEUESIZE = 32
-
-
 class BackgroundConsumer(Thread):
-    def __init__(self, queue, source, max_len) -> None:
-        Thread.__init__(self)
+    def __init__(self, queue, source) -> None:
+        super().__init__()
+        self.daemon = True
         self._queue = queue
         self._source = source  # Main DL iterator
-        self._max_len = max_len  #
 
     def run(self) -> None:
         for item in self._source:
             self._queue.put(item)  # Blocking if the queue is full
 
-        # Signal the consumer we are done.
-        self._queue.put(_sentinel)
+        # Signal the consumer we are done; this should not happen for DataLoader
+        self._queue.put(StopIteration())
+
+
+QUEUESIZE = 32
 
 
 class BufferedIterator:
     def __init__(self, iterable) -> None:
-        self._queue = queue.Queue(QUEUESIZE)
+        self._queue = Queue(QUEUESIZE)
         self._iterable = iterable
-        self._consumer = None
-
-        self.start_time = time.time()
-        self.warning_time = None
-        self.total = len(iterable)
-
-    def _create_consumer(self) -> None:
-        self._consumer = BackgroundConsumer(self._queue, self._iterable, self.total)
-        self._consumer.daemon = True
+        self._consumer = BackgroundConsumer(self._queue, self._iterable)
         self._consumer.start()
+        self.last_warning_time = time.time()
 
     def __iter__(self):
         return self
 
     def __len__(self) -> int:
-        return self.total
+        return len(self._iterable)
 
     def __next__(self):
-        # Create consumer if not created yet
-        if self._consumer is None:
-            self._create_consumer()
-        # Notify the user if there is a data loading bottleneck
-        if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)):
-            if time.time() - self.start_time > 5 * 60:
-                if (
-                    self.warning_time is None
-                    or time.time() - self.warning_time > 15 * 60
-                ):
-                    log.warning(
-                        "Data loading buffer is empty or nearly empty. This may "
-                        "indicate a data loading bottleneck, and increasing the "
-                        "number of workers (--num-workers) may help."
-                    )
-                    self.warning_time = time.time()
-
-        # Get next example
+        start_wait = time.time()
         item = self._queue.get()
+        wait_time = time.time() - start_wait
+        if (
+            wait_time > 1.0 and start_wait - self.last_warning_time > 15 * 60
+        ):  # Even for Multi-Task training, each step usually takes < 1s
+            log.warning(
+                f"Data loading is slow, waited {wait_time:.2f} seconds. Ignoring this warning for 15 minutes."
+            )
+            self.last_warning_time = start_wait
         if isinstance(item, Exception):
             raise item
-        if item is _sentinel:
-            raise StopIteration
         return item
 
 
diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py
index c542ccf661..87a44aa70d 100644
--- a/deepmd/utils/path.py
+++ b/deepmd/utils/path.py
@@ -114,6 +114,10 @@ def is_file(self) -> bool:
     def is_dir(self) -> bool:
         """Check if self is directory."""
 
+    @abstractmethod
+    def __getnewargs__(self):
+        """Return the arguments to be passed to __new__ when unpickling an instance."""
+
     @abstractmethod
     def __truediv__(self, key: str) -> "DPPath":
         """Used for / operator."""
@@ -169,6 +173,9 @@ def __init__(self, path: Union[str, Path], mode: str = "r") -> None:
         self.mode = mode
         self.path = Path(path)
 
+    def __getnewargs__(self):
+        return (self.path, self.mode)
+
     def load_numpy(self) -> np.ndarray:
         """Load NumPy array.
 
@@ -304,6 +311,9 @@ def __init__(self, path: str, mode: str = "r") -> None:
         # h5 path: default is the root path
         self._name = s[1] if len(s) > 1 else "/"
 
+    def __getnewargs__(self):
+        return (self.root_path, self.mode)
+
     @classmethod
     @lru_cache(None)
     def _load_h5py(cls, path: str, mode: str = "r") -> h5py.File:
diff --git a/source/api_cc/include/common.h b/source/api_cc/include/common.h
index 2bd0cf7135..612f699ea4 100644
--- a/source/api_cc/include/common.h
+++ b/source/api_cc/include/common.h
@@ -33,7 +33,13 @@ struct NeighborListData {
   std::vector<int*> firstneigh;
 
  public:
-  void copy_from_nlist(const InputNlist& inlist);
+  /**
+   * @brief Copy the neighbor list from an InputNlist.
+   * @param[in] inlist The input neighbor list.
+   * @param[in] natoms The number of atoms to copy. If natoms is -1, copy all
+   * atoms.
+   */
+  void copy_from_nlist(const InputNlist& inlist, const int natoms = -1);
   void shuffle(const std::vector<int>& fwd_map);
   void shuffle(const deepmd::AtomMap& map);
   void shuffle_exclude_empty(const std::vector<int>& fwd_map);
diff --git a/source/api_cc/src/DeepPotJAX.cc b/source/api_cc/src/DeepPotJAX.cc
index 805380081d..07f8b9119b 100644
--- a/source/api_cc/src/DeepPotJAX.cc
+++ b/source/api_cc/src/DeepPotJAX.cc
@@ -566,7 +566,7 @@ void deepmd::DeepPotJAX::compute(std::vector<ENERGYTYPE>& ener,
   input_list[1] = add_input(op, atype, atype_shape, data_tensor[1], status);
   // nlist
   if (ago == 0) {
-    nlist_data.copy_from_nlist(lmp_list);
+    nlist_data.copy_from_nlist(lmp_list, nall - nghost);
     nlist_data.shuffle_exclude_empty(fwd_map);
   }
   size_t max_size = 0;
diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc
index 6910de3ccd..abd35eaf1e 100644
--- a/source/api_cc/src/DeepPotPT.cc
+++ b/source/api_cc/src/DeepPotPT.cc
@@ -169,7 +169,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
   at::Tensor atype_Tensor =
       torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device);
   if (ago == 0) {
-    nlist_data.copy_from_nlist(lmp_list);
+    nlist_data.copy_from_nlist(lmp_list, nall - nghost);
     nlist_data.shuffle_exclude_empty(fwd_map);
     nlist_data.padding();
     if (do_message_passing) {
diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc
index aef2d60150..7421b623db 100644
--- a/source/api_cc/src/DeepSpinPT.cc
+++ b/source/api_cc/src/DeepSpinPT.cc
@@ -177,7 +177,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
       torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device);
   c10::optional<torch::Tensor> mapping_tensor;
   if (ago == 0) {
-    nlist_data.copy_from_nlist(lmp_list);
+    nlist_data.copy_from_nlist(lmp_list, nall - nghost);
     nlist_data.shuffle_exclude_empty(fwd_map);
     nlist_data.padding();
     if (do_message_passing) {
diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc
index 5a4f05d75c..c51ae9a8b4 100644
--- a/source/api_cc/src/common.cc
+++ b/source/api_cc/src/common.cc
@@ -232,8 +232,9 @@ template void deepmd::select_real_atoms_coord(
     const int& nall,
     const bool aparam_nall);
 
-void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist) {
-  int inum = inlist.inum;
+void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist,
+                                               const int natoms) {
+  int inum = natoms >= 0 ? natoms : inlist.inum;
   ilist.resize(inum);
   jlist.resize(inum);
   memcpy(&ilist[0], inlist.ilist, inum * sizeof(int));
diff --git a/source/lmp/pppm_dplr.h b/source/lmp/pppm_dplr.h
index 1484a16e72..b7e221c686 100644
--- a/source/lmp/pppm_dplr.h
+++ b/source/lmp/pppm_dplr.h
@@ -28,6 +28,7 @@ class PPPMDPLR : public PPPM {
  ~PPPMDPLR() override {};
  void init() override;
  const std::vector<double> &get_fele() const { return fele; };
+ std::vector<double> &get_fele() { return fele; }
 
 protected:
  void compute(int, int) override;
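Note (illustrative, not part of the patch above): the dataloader change replaces the sentinel-object protocol with a daemon producer thread that fills a bounded queue, puts a StopIteration instance when the source is exhausted, and warns only when a single get() stalls for more than a second and at most once per 15 minutes. A minimal self-contained sketch of that pattern follows; the names BufferedIter and _Producer are placeholders, not taken from the repository.

import logging
import time
from queue import Queue
from threading import Thread

log = logging.getLogger(__name__)


class _Producer(Thread):
    """Fill the queue from the source iterator in the background."""

    def __init__(self, queue, source) -> None:
        super().__init__(daemon=True)
        self._queue = queue
        self._source = source

    def run(self) -> None:
        for item in self._source:
            self._queue.put(item)  # blocks while the buffer is full
        self._queue.put(StopIteration())  # signal exhaustion to the consumer


class BufferedIter:
    """Iterate over `iterable` through a bounded background buffer."""

    def __init__(self, iterable, maxsize=32, warn_every=15 * 60) -> None:
        self._queue = Queue(maxsize)
        _Producer(self._queue, iterable).start()
        self._warn_every = warn_every
        self._last_warning = time.time()

    def __iter__(self):
        return self

    def __next__(self):
        start = time.time()
        item = self._queue.get()
        waited = time.time() - start
        if waited > 1.0 and start - self._last_warning > self._warn_every:
            log.warning("Data loading is slow, waited %.2f seconds.", waited)
            self._last_warning = start
        if isinstance(item, Exception):
            raise item  # re-raises the StopIteration sentinel and real errors
        return item


# Usage: wrap any (possibly slow) batch source.
for batch in BufferedIter(iter(range(3))):
    print(batch)  # 0, 1, 2

The __getnewargs__ additions in deepmd/utils/path.py rely on the standard pickle protocol: for protocol 2 and later, pickle rebuilds an object by calling cls.__new__(cls, *obj.__getnewargs__()), which is what makes a class whose __new__ requires arguments (here, a path and a mode) unpicklable without the hook. The hypothetical Handle class below is only a sketch of that mechanism, not code from the repository.

import pickle


class Handle:
    """Hypothetical stand-in for a class whose __new__ needs arguments."""

    def __new__(cls, path, mode="r"):
        obj = super().__new__(cls)
        obj.path = path
        obj.mode = mode
        return obj

    def __getnewargs__(self):
        # pickle calls Handle.__new__(Handle, *these) when loading
        return (self.path, self.mode)


h = pickle.loads(pickle.dumps(Handle("data.h5", "r")))
print(h.path, h.mode)  # -> data.h5 r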