clean up the init interface of pt.dataloader #3434

Merged · 7 commits · Mar 8, 2024
6 changes: 3 additions & 3 deletions deepmd/pt/entrypoints/main.py
@@ -134,7 +134,7 @@ def prepare_trainer_input_single(
        DpLoaderSet(
            validation_systems,
            validation_dataset_params["batch_size"],
-            model_params_single,
+            model_params_single["type_map"],
        )
        if validation_systems
        else None
@@ -143,13 +143,13 @@ def prepare_trainer_input_single(
            train_data_single = DpLoaderSet(
                training_systems,
                training_dataset_params["batch_size"],
-                model_params_single,
+                model_params_single["type_map"],
            )
        else:
            train_data_single = DpLoaderSet(
                training_systems,
                training_dataset_params["batch_size"],
-                model_params_single,
+                model_params_single["type_map"],
            )
        return (
            train_data_single,
21 changes: 17 additions & 4 deletions deepmd/pt/utils/dataloader.py
@@ -55,13 +55,27 @@ def setup_seed(seed):


class DpLoaderSet(Dataset):
-    """A dataset for storing DataLoaders to multiple Systems."""
+    """A dataset for storing DataLoaders to multiple Systems.
+
+    Parameters
+    ----------
+    systems
+        Paths to the data systems.
+    batch_size
+        Maximum frame count in a batch.
+    type_map
+        Names of the atom types.
+    seed
+        Random seed for the dataloader.
+    shuffle
+        Whether the data are shuffled (only effective in serial mode; data are always shuffled in distributed data parallelism).
+    """

    def __init__(
        self,
        systems,
        batch_size,
-        model_params,
+        type_map,
        seed=10,
        shuffle=True,
    ):
@@ -77,8 +91,7 @@ def __init__(
        def construct_dataset(system):
            return DeepmdDataSetForLoader(
                system=system,
-                type_map=model_params["type_map"],
-                shuffle=shuffle,
+                type_map=type_map,
            )

        with Pool(
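For reference, a minimal usage sketch of the new constructor; the system paths, type map, and batch size below are placeholders, not values from this PR:

```python
from deepmd.pt.utils.dataloader import DpLoaderSet

# Hypothetical inputs for illustration; real runs take these from the input script.
systems = ["data/system_1", "data/system_2"]  # paths to the data systems
type_map = ["O", "H"]                         # names of the atom types

# The loader now receives the type map directly instead of the full model_params dict.
loader = DpLoaderSet(
    systems,
    batch_size=32,      # max frame count per batch
    type_map=type_map,
    seed=10,
    shuffle=False,      # only effective in serial mode
)
```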
13 changes: 3 additions & 10 deletions deepmd/pt/utils/dataset.py
@@ -3,6 +3,7 @@

from typing import (
    List,
+    Optional,
)

from torch.utils.data import (
@@ -16,24 +17,16 @@


class DeepmdDataSetForLoader(Dataset):
-    def __init__(
-        self,
-        system: str,
-        type_map: str,
-        shuffle=True,
-    ):
+    def __init__(self, system: str, type_map: Optional[List[str]] = None):
        """Construct DeePMD-style dataset containing frames cross different systems.

        Args:
        - systems: Paths to systems.
-        - batch_size: Max frame count in a batch.
        - type_map: Atom types.
        """
        self.system = system
        self._type_map = type_map
-        self._data_system = DeepmdData(
-            sys_path=system, shuffle_test=shuffle, type_map=self._type_map
-        )
+        self._data_system = DeepmdData(sys_path=system, type_map=self._type_map)
        self.mixed_type = self._data_system.mixed_type
        self._ntypes = self._data_system.get_ntypes()
        self._natoms = self._data_system.get_natoms()
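Correspondingly, a sketch of building the per-system dataset under the new signature (the path below is a placeholder). The shuffle argument is gone here; shuffling remains configurable on DpLoaderSet:

```python
from deepmd.pt.utils.dataset import DeepmdDataSetForLoader

# Hypothetical system path for illustration only.
ds = DeepmdDataSetForLoader(
    system="data/system_1",
    type_map=["O", "H"],  # optional; defaults to None
)
```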
13 changes: 1 addition & 12 deletions source/tests/pt/model/test_model.py
@@ -273,18 +273,7 @@ def test_consistency(self):
            self.wanted_step
        )
        # Build DeePMD graph
-        my_ds = DpLoaderSet(
-            self.systems,
-            self.batch_size,
-            model_params={
-                "descriptor": {
-                    "type": "se_e2_a",
-                    "sel": self.sel,
-                    "rcut": self.rcut,
-                },
-                "type_map": self.type_map,
-            },
-        )
+        my_ds = DpLoaderSet(self.systems, self.batch_size, self.type_map)
        my_ds.add_data_requirement(energy_data_requirement)
        my_model = get_model(
            model_params={
13 changes: 1 addition & 12 deletions source/tests/pt/model/test_saveload_dpa1.py
@@ -46,18 +46,7 @@ def get_dataset(config):
    batch_size = config["training"]["training_data"]["batch_size"]
    type_map = model_config["type_map"]

-    dataset = DpLoaderSet(
-        systems,
-        batch_size,
-        model_params={
-            "descriptor": {
-                "type": "dpa1",
-                "sel": sel,
-                "rcut": rcut,
-            },
-            "type_map": type_map,
-        },
-    )
+    dataset = DpLoaderSet(systems, batch_size, type_map)
    data_stat_nbatch = model_config.get("data_stat_nbatch", 10)
    sampled = make_stat_input(dataset.systems, dataset.dataloaders, data_stat_nbatch)
    return dataset, sampled
13 changes: 1 addition & 12 deletions source/tests/pt/model/test_saveload_se_e2_a.py
@@ -46,18 +46,7 @@ def get_dataset(config):
    batch_size = config["training"]["training_data"]["batch_size"]
    type_map = model_config["type_map"]

-    dataset = DpLoaderSet(
-        systems,
-        batch_size,
-        model_params={
-            "descriptor": {
-                "type": "se_e2_a",
-                "sel": sel,
-                "rcut": rcut,
-            },
-            "type_map": type_map,
-        },
-    )
+    dataset = DpLoaderSet(systems, batch_size, type_map)
    data_stat_nbatch = model_config.get("data_stat_nbatch", 10)
    sampled = make_stat_input(dataset.systems, dataset.dataloaders, data_stat_nbatch)
    return dataset, sampled
9 changes: 1 addition & 8 deletions source/tests/pt/test_sampler.py
@@ -46,14 +46,7 @@ def setUp(self):
        self.my_dataset = DpLoaderSet(
            self.systems,
            self.batch_size,
-            model_params={
-                "descriptor": {
-                    "type": "se_e2_a",
-                    "sel": self.sel,
-                    "rcut": self.rcut,
-                },
-                "type_map": model_config["type_map"],
-            },
+            model_config["type_map"],
            seed=10,
            shuffle=False,
        )
9 changes: 1 addition & 8 deletions source/tests/pt/test_stat.py
@@ -137,14 +137,7 @@ def setUp(self):
        self.my_dataset = DpLoaderSet(
            self.systems,
            self.batch_size,
-            model_params={
-                "descriptor": {
-                    "type": "se_e2_a",
-                    "sel": self.sel,
-                    "rcut": self.rcut,
-                },
-                "type_map": model_config["type_map"],
-            },
+            model_config["type_map"],
            seed=10,
        )
        self.filter_neuron = model_config["descriptor"]["neuron"]