Skip to content

Commit

Permalink
refine code
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Sep 25, 2024
1 parent 09c54f3 commit bc854b2
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 11 deletions.
5 changes: 4 additions & 1 deletion deepmd/pd/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,10 @@ def get_lr(lr_params):

# JIT
if JIT:
self.model = paddle.jit.to_static(self.model, full_graph=False)
raise NotImplementedError(
"JIT is not supported yet when training with Paddle"
)
self.model = paddle.jit.to_static(self.model)

# Model Wrapper
self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params)
Expand Down
51 changes: 41 additions & 10 deletions deepmd/pd/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import os
import queue
import time
from collections.abc import (
Iterator,
)
from multiprocessing.dummy import (
Pool,
)
Expand Down Expand Up @@ -95,10 +98,15 @@ def construct_dataset(system):
type_map=type_map,
)

with Pool(1) as pool:
self.systems: List[DeepmdDataSetForLoader] = pool.map(
construct_dataset, systems
with Pool(
os.cpu_count()
// (
int(os.environ["LOCAL_WORLD_SIZE"])
if dist.is_available() and dist.is_initialized()
else 1
)
) as pool:
self.systems = pool.map(construct_dataset, systems)

self.sampler_list: List[DistributedBatchSampler] = []
self.index = []
Expand Down Expand Up @@ -129,31 +137,54 @@ def construct_dataset(system):
if dist.is_available() and dist.is_initialized():
system_batch_sampler = DistributedBatchSampler(
system,
shuffle=False,
shuffle=(
(not (dist.is_available() and dist.is_initialized()))
and shuffle
),
batch_size=int(batch_size),
)
self.sampler_list.append(system_batch_sampler)
else:
system_batch_sampler = BatchSampler(
system,
shuffle=shuffle,
shuffle=(
(not (dist.is_available() and dist.is_initialized()))
and shuffle
),
batch_size=int(batch_size),
)
self.sampler_list.append(system_batch_sampler)
system_dataloader = DataLoader(
dataset=system,
num_workers=0, # Should be 0 to avoid too many threads forked
batch_sampler=system_batch_sampler,
collate_fn=collate_batch,
# shuffle=(not (dist.is_available() and dist.is_initialized()))
# and shuffle,
use_buffer_reader=False,
places=["cpu"],
)
self.dataloaders.append(system_dataloader)
self.index.append(len(system_dataloader))
self.total_batch += len(system_dataloader)
# Initialize iterator instances for DataLoader

class LazyIter:
    """Wrap an iterable and defer the ``iter()`` call until the first ``__next__``.

    Handing a bare DataLoader to ``iter()`` eagerly spins up its fetch
    machinery; this wrapper postpones that work until data is actually
    requested.
    """

    def __init__(self, item):
        # Hold the raw iterable untouched; no data is fetched here.
        self.item = item

    def __iter__(self):
        # The wrapper is itself the iterator, so iter() stays a no-op.
        return self

    def __next__(self):
        source = self.item
        if not isinstance(source, Iterator):
            # First real access: materialize the iterator now, lazily.
            source = iter(source)
            self.item = source
        return next(source)

self.iters = []
# with paddle.device("cpu"):
for item in self.dataloaders:
self.iters.append(iter(item))
self.iters.append(LazyIter(item))

def set_noise(self, noise_settings):
# noise_settings['noise_type'] # "trunc_normal", "normal", "uniform"
Expand Down

0 comments on commit bc854b2

Please sign in to comment.