Skip to content

Commit

Permalink
Merge branch 'devel' into add_dpa1
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Dec 9, 2024
2 parents 22fffd8 + ec3b83f commit edea0aa
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 73 deletions.
52 changes: 26 additions & 26 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ repos:
- id: clang-format
exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$)
# markdown, yaml, CSS, javascript
# - repo: https://github.com/pre-commit/mirrors-prettier
# rev: v4.0.0-alpha.8
# hooks:
# - id: prettier
# types_or: [markdown, yaml, css]
# # workflow files cannot be modified by pre-commit.ci
# exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
hooks:
- id: prettier
types_or: [markdown, yaml, css]
# workflow files cannot be modified by pre-commit.ci
exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
# Shell
- repo: https://github.com/scop/pre-commit-shfmt
rev: v3.10.0-2
Expand All @@ -83,25 +83,25 @@ repos:
hooks:
- id: cmake-format
#- id: cmake-lint
# - repo: https://github.com/njzjz/mirrors-bibtex-tidy
# rev: v1.13.0
# hooks:
# - id: bibtex-tidy
# args:
# - --curly
# - --numeric
# - --align=13
# - --blank-lines
# # disable sort: the order of keys and fields has explict meanings
# #- --sort=key
# - --duplicates=key,doi,citation,abstract
# - --merge=combine
# #- --sort-fields
# #- --strip-comments
# - --trailing-commas
# - --encode-urls
# - --remove-empty-fields
# - --wrap=80
- repo: https://github.com/njzjz/mirrors-bibtex-tidy
rev: v1.13.0
hooks:
- id: bibtex-tidy
args:
- --curly
- --numeric
- --align=13
- --blank-lines
# disable sort: the order of keys and fields has explict meanings
#- --sort=key
- --duplicates=key,doi,citation,abstract
- --merge=combine
#- --sort-fields
#- --strip-comments
- --trailing-commas
- --encode-urls
- --remove-empty-fields
- --wrap=80
# license header
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.5.5
Expand Down
65 changes: 24 additions & 41 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
import queue
import time
from multiprocessing.dummy import (
Pool,
)
from queue import (
Queue,
)
from threading import (
Thread,
)
Expand Down Expand Up @@ -204,70 +206,51 @@ def print_summary(
)


_sentinel = object()
QUEUESIZE = 32


class BackgroundConsumer(Thread):
def __init__(self, queue, source, max_len) -> None:
Thread.__init__(self)
def __init__(self, queue, source) -> None:
super().__init__()
self.daemon = True
self._queue = queue
self._source = source # Main DL iterator
self._max_len = max_len #

def run(self) -> None:
for item in self._source:
self._queue.put(item) # Blocking if the queue is full

# Signal the consumer we are done.
self._queue.put(_sentinel)
# Signal the consumer we are done; this should not happen for DataLoader
self._queue.put(StopIteration())


QUEUESIZE = 32


class BufferedIterator:
def __init__(self, iterable) -> None:
self._queue = queue.Queue(QUEUESIZE)
self._queue = Queue(QUEUESIZE)
self._iterable = iterable
self._consumer = None

self.start_time = time.time()
self.warning_time = None
self.total = len(iterable)

def _create_consumer(self) -> None:
self._consumer = BackgroundConsumer(self._queue, self._iterable, self.total)
self._consumer.daemon = True
self._consumer = BackgroundConsumer(self._queue, self._iterable)
self._consumer.start()
self.last_warning_time = time.time()

def __iter__(self):
return self

def __len__(self) -> int:
return self.total
return len(self._iterable)

def __next__(self):
# Create consumer if not created yet
if self._consumer is None:
self._create_consumer()
# Notify the user if there is a data loading bottleneck
if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)):
if time.time() - self.start_time > 5 * 60:
if (
self.warning_time is None
or time.time() - self.warning_time > 15 * 60
):
log.warning(
"Data loading buffer is empty or nearly empty. This may "
"indicate a data loading bottleneck, and increasing the "
"number of workers (--num-workers) may help."
)
self.warning_time = time.time()

# Get next example
start_wait = time.time()
item = self._queue.get()
wait_time = time.time() - start_wait
if (
wait_time > 1.0 and start_wait - self.last_warning_time > 15 * 60
): # Even for Multi-Task training, each step usually takes < 1s
log.warning(
f"Data loading is slow, waited {wait_time:.2f} seconds. Ignoring this warning for 15 minutes."
)
self.last_warning_time = start_wait
if isinstance(item, Exception):
raise item
if item is _sentinel:
raise StopIteration
return item


Expand Down
10 changes: 10 additions & 0 deletions deepmd/utils/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def is_file(self) -> bool:
def is_dir(self) -> bool:
"""Check if self is directory."""

@abstractmethod
def __getnewargs__(self):
"""Return the arguments to be passed to __new__ when unpickling an instance."""

@abstractmethod
def __truediv__(self, key: str) -> "DPPath":
"""Used for / operator."""
Expand Down Expand Up @@ -169,6 +173,9 @@ def __init__(self, path: Union[str, Path], mode: str = "r") -> None:
self.mode = mode
self.path = Path(path)

def __getnewargs__(self):
return (self.path, self.mode)

def load_numpy(self) -> np.ndarray:
"""Load NumPy array.
Expand Down Expand Up @@ -304,6 +311,9 @@ def __init__(self, path: str, mode: str = "r") -> None:
# h5 path: default is the root path
self._name = s[1] if len(s) > 1 else "/"

def __getnewargs__(self):
return (self.root_path, self.mode)

@classmethod
@lru_cache(None)
def _load_h5py(cls, path: str, mode: str = "r") -> h5py.File:
Expand Down
8 changes: 7 additions & 1 deletion source/api_cc/include/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ struct NeighborListData {
std::vector<int*> firstneigh;

public:
void copy_from_nlist(const InputNlist& inlist);
/**
* @brief Copy the neighbor list from an InputNlist.
* @param[in] inlist The input neighbor list.
* @param[in] natoms The number of atoms to copy. If natoms is -1, copy all
* atoms.
*/
void copy_from_nlist(const InputNlist& inlist, const int natoms = -1);
void shuffle(const std::vector<int>& fwd_map);
void shuffle(const deepmd::AtomMap& map);
void shuffle_exclude_empty(const std::vector<int>& fwd_map);
Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/src/DeepPotJAX.cc
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ void deepmd::DeepPotJAX::compute(std::vector<ENERGYTYPE>& ener,
input_list[1] = add_input(op, atype, atype_shape, data_tensor[1], status);
// nlist
if (ago == 0) {
nlist_data.copy_from_nlist(lmp_list);
nlist_data.copy_from_nlist(lmp_list, nall - nghost);
nlist_data.shuffle_exclude_empty(fwd_map);
}
size_t max_size = 0;
Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
at::Tensor atype_Tensor =
torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device);
if (ago == 0) {
nlist_data.copy_from_nlist(lmp_list);
nlist_data.copy_from_nlist(lmp_list, nall - nghost);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();
if (do_message_passing) {
Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/src/DeepSpinPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device);
c10::optional<torch::Tensor> mapping_tensor;
if (ago == 0) {
nlist_data.copy_from_nlist(lmp_list);
nlist_data.copy_from_nlist(lmp_list, nall - nghost);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();
if (do_message_passing) {
Expand Down
5 changes: 3 additions & 2 deletions source/api_cc/src/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,9 @@ template void deepmd::select_real_atoms_coord<float>(
const int& nall,
const bool aparam_nall);

void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist) {
int inum = inlist.inum;
void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist,
const int natoms) {
int inum = natoms >= 0 ? natoms : inlist.inum;
ilist.resize(inum);
jlist.resize(inum);
memcpy(&ilist[0], inlist.ilist, inum * sizeof(int));
Expand Down
1 change: 1 addition & 0 deletions source/lmp/pppm_dplr.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class PPPMDPLR : public PPPM {
~PPPMDPLR() override {};
void init() override;
const std::vector<double> &get_fele() const { return fele; };
std::vector<double> &get_fele() { return fele; }

protected:
void compute(int, int) override;
Expand Down

0 comments on commit edea0aa

Please sign in to comment.