Skip to content

Commit

Permalink
Merge branch 'devel' into refactor_property
Browse files Browse the repository at this point in the history
  • Loading branch information
Chengqian-Zhang authored Dec 13, 2024
2 parents 6172dc4 + e9ed267 commit 0608e17
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
exclude: ^source/3rdparty
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.8.1
rev: v0.8.2
hooks:
- id: ruff
args: ["--fix"]
Expand Down
65 changes: 24 additions & 41 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
import queue
import time
from multiprocessing.dummy import (
Pool,
)
from queue import (
Queue,
)
from threading import (
Thread,
)
Expand Down Expand Up @@ -204,70 +206,51 @@ def print_summary(
)


QUEUESIZE = 32


class BackgroundConsumer(Thread):
    """Daemon thread that drains a data-loader iterator into a bounded queue.

    Items are forwarded one by one; normal exhaustion (and any error raised
    by the source iterator) is signalled by putting an ``Exception`` instance
    on the queue, which the consuming side re-raises.
    """

    def __init__(self, queue, source) -> None:
        """Create the consumer thread.

        Parameters
        ----------
        queue : Queue
            Bounded queue shared with the reader; ``put`` blocks when full.
        source : iterable
            Main DL iterator to drain.
        """
        super().__init__()
        self.daemon = True  # do not keep the process alive at interpreter exit
        self._queue = queue
        self._source = source  # Main DL iterator

    def run(self) -> None:
        try:
            for item in self._source:
                self._queue.put(item)  # Blocking if the queue is full
        except Exception as e:
            # Forward the failure to the reading side instead of dying
            # silently, which would leave the reader blocked on queue.get().
            self._queue.put(e)
            return
        # Signal the consumer we are done; this should not happen for DataLoader
        self._queue.put(StopIteration())

class BufferedIterator:
    """Iterator wrapper that prefetches items on a background thread.

    A ``BackgroundConsumer`` fills a bounded queue (size ``QUEUESIZE``) from
    ``iterable``; this object pops items from the queue in ``__next__`` and
    warns when data loading appears to be the bottleneck.
    """

    def __init__(self, iterable) -> None:
        """Wrap ``iterable``; the consumer thread is started lazily.

        Parameters
        ----------
        iterable : iterable
            The underlying data-loader iterator; must support ``len()``.
        """
        self._queue = Queue(QUEUESIZE)
        self._iterable = iterable
        # Created lazily on the first __next__ call so the thread is started
        # by the thread/process that actually consumes the data.
        self._consumer = None

    def _create_consumer(self) -> None:
        """Start the background consumer thread and arm the warning timer."""
        self._consumer = BackgroundConsumer(self._queue, self._iterable)
        self._consumer.start()
        self.last_warning_time = time.time()

    def __iter__(self):
        return self

    def __len__(self) -> int:
        return len(self._iterable)

    def __next__(self):
        # Create consumer if not created yet
        if self._consumer is None:
            self._create_consumer()
        # Get next example, tracking how long we block on the queue
        start_wait = time.time()
        item = self._queue.get()
        wait_time = time.time() - start_wait
        if (
            wait_time > 1.0 and start_wait - self.last_warning_time > 15 * 60
        ):  # Even for Multi-Task training, each step usually takes < 1s
            log.warning(
                f"Data loading is slow, waited {wait_time:.2f} seconds. Ignoring this warning for 15 minutes."
            )
            self.last_warning_time = start_wait
        # The producer signals exhaustion/failure with an Exception instance
        # (StopIteration on normal completion); re-raise it here.
        if isinstance(item, Exception):
            raise item
        return item


Expand Down
10 changes: 10 additions & 0 deletions deepmd/utils/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def is_file(self) -> bool:
def is_dir(self) -> bool:
"""Check if self is directory."""

@abstractmethod
def __getnewargs__(self):
"""Return the arguments to be passed to __new__ when unpickling an instance."""

@abstractmethod
def __truediv__(self, key: str) -> "DPPath":
"""Used for / operator."""
Expand Down Expand Up @@ -169,6 +173,9 @@ def __init__(self, path: Union[str, Path], mode: str = "r") -> None:
self.mode = mode
self.path = Path(path)

def __getnewargs__(self):
    """Return the arguments to be passed to ``__new__`` when unpickling.

    Enables pickling (e.g. for multiprocessing data loading) by letting the
    instance be reconstructed from ``(path, mode)``.
    """
    return (self.path, self.mode)

def load_numpy(self) -> np.ndarray:
"""Load NumPy array.
Expand Down Expand Up @@ -304,6 +311,9 @@ def __init__(self, path: str, mode: str = "r") -> None:
# h5 path: default is the root path
self._name = s[1] if len(s) > 1 else "/"

def __getnewargs__(self):
    """Return the arguments to be passed to ``__new__`` when unpickling.

    The instance is reconstructed from ``(root_path, mode)`` — per the
    nearby comment the h5 root path and the file open mode.
    """
    return (self.root_path, self.mode)

@classmethod
@lru_cache(None)
def _load_h5py(cls, path: str, mode: str = "r") -> h5py.File:
Expand Down
8 changes: 7 additions & 1 deletion source/api_cc/include/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ struct NeighborListData {
std::vector<int*> firstneigh;

public:
void copy_from_nlist(const InputNlist& inlist);
/**
* @brief Copy the neighbor list from an InputNlist.
* @param[in] inlist The input neighbor list.
* @param[in] natoms The number of atoms to copy. If natoms is -1, copy all
* atoms.
*/
void copy_from_nlist(const InputNlist& inlist, const int natoms = -1);
void shuffle(const std::vector<int>& fwd_map);
void shuffle(const deepmd::AtomMap& map);
void shuffle_exclude_empty(const std::vector<int>& fwd_map);
Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/src/DeepPotJAX.cc
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ void deepmd::DeepPotJAX::compute(std::vector<ENERGYTYPE>& ener,
input_list[1] = add_input(op, atype, atype_shape, data_tensor[1], status);
// nlist
if (ago == 0) {
nlist_data.copy_from_nlist(lmp_list);
nlist_data.copy_from_nlist(lmp_list, nall - nghost);
nlist_data.shuffle_exclude_empty(fwd_map);
}
size_t max_size = 0;
Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
at::Tensor atype_Tensor =
torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device);
if (ago == 0) {
nlist_data.copy_from_nlist(lmp_list);
nlist_data.copy_from_nlist(lmp_list, nall - nghost);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();
if (do_message_passing) {
Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/src/DeepSpinPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device);
c10::optional<torch::Tensor> mapping_tensor;
if (ago == 0) {
nlist_data.copy_from_nlist(lmp_list);
nlist_data.copy_from_nlist(lmp_list, nall - nghost);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();
if (do_message_passing) {
Expand Down
5 changes: 3 additions & 2 deletions source/api_cc/src/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,9 @@ template void deepmd::select_real_atoms_coord<float>(
const int& nall,
const bool aparam_nall);

void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist) {
int inum = inlist.inum;
void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist,
const int natoms) {
int inum = natoms >= 0 ? natoms : inlist.inum;
ilist.resize(inum);
jlist.resize(inum);
memcpy(&ilist[0], inlist.ilist, inum * sizeof(int));
Expand Down
1 change: 1 addition & 0 deletions source/lmp/pppm_dplr.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class PPPMDPLR : public PPPM {
~PPPMDPLR() override {};
void init() override;
const std::vector<double> &get_fele() const { return fele; };
std::vector<double> &get_fele() { return fele; }

protected:
void compute(int, int) override;
Expand Down

0 comments on commit 0608e17

Please sign in to comment.