Skip to content

Commit

Permalink
Add 'ABACUSDataset' in data module (#9)
Browse files Browse the repository at this point in the history
* Prototype code for loading Hamiltonian

* add 'ABACUSDataset' in data module

* modified "basis.dat" storage & can load overlap

* recover some original dataset settings

* add ABACUSDataset in init
  • Loading branch information
SharpLonde authored Nov 21, 2023
1 parent f2653fe commit 959a196
Show file tree
Hide file tree
Showing 9 changed files with 502 additions and 6 deletions.
2 changes: 1 addition & 1 deletion dptb/data/AtomicData.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def from_points(
if cell is not None:
kwargs[AtomicDataDict.ONSITENV_CELL_SHIFT_KEY] = onsitenv_cell_shift
kwargs[AtomicDataDict.ONSITENV_INDEX_KEY] = onsitenv_index

return cls(edge_index=edge_index, pos=torch.as_tensor(pos), **kwargs)

@classmethod
Expand Down
2 changes: 2 additions & 0 deletions dptb/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
NpzDataset,
ASEDataset,
HDF5Dataset,
ABACUSDataset,
)
from .dataloader import DataLoader, Collater, PartialSampler
from .build import dataset_from_config
Expand All @@ -31,6 +32,7 @@
NpzDataset,
ASEDataset,
HDF5Dataset,
ABACUSDataset,
DataLoader,
Collater,
PartialSampler,
Expand Down
4 changes: 2 additions & 2 deletions dptb/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from importlib import import_module

from dptb import data
from dptb.data.transforms import BondMapper
from dptb.data.transforms import TypeMapper
from dptb.data import AtomicDataset, register_fields
from dptb.utils import instantiate, get_w_prefix

Expand Down Expand Up @@ -81,7 +81,7 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset:
)

# Build a TypeMapper from the config
type_mapper, _ = instantiate(BondMapper, prefix=prefix, optional_args=config)
type_mapper, _ = instantiate(TypeMapper, prefix=prefix, optional_args=config)

# Register fields:
# This might reregister fields, but that's OK:
Expand Down
3 changes: 2 additions & 1 deletion dptb/data/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from ._ase_dataset import ASEDataset
from ._npz_dataset import NpzDataset
from ._hdf5_dataset import HDF5Dataset
from ._abacus_dataset import ABACUSDataset

__all__ = [ASEDataset, AtomicDataset, AtomicInMemoryDataset, NpzDataset, HDF5Dataset]
__all__ = [ABACUSDataset, ASEDataset, AtomicDataset, AtomicInMemoryDataset, NpzDataset, HDF5Dataset]
78 changes: 78 additions & 0 deletions dptb/data/dataset/_abacus_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Dict, Any, List, Callable, Union, Optional
import os
import numpy as np
import h5py

import torch

from .. import (
AtomicData,
AtomicDataDict,
)
from ..transforms import TypeMapper, OrbitalMapper
from ._base_datasets import AtomicDataset
from dptb.utils.tools import ham_block_to_feature

# Map angular-momentum quantum number l to its orbital letter (s/p/d/f).
orbitalLId = {0:"s", 1:"p", 2:"d", 3:"f"}

class ABACUSDataset(AtomicDataset):
    """Dataset over preprocessed ABACUS frames stored as per-frame HDF5 files.

    Each HDF5 file must provide ``pos``, ``cell`` and ``atomic_numbers``
    datasets.  It may additionally carry a ``basis`` group together with
    ``hamiltonian_blocks`` / ``overlap_blocks``, and ``eigenvalue`` /
    ``kpoint`` entries; when present and non-empty these are attached to
    the returned ``AtomicData`` object.
    """

    def __init__(
        self,
        root: str,
        key_mapping: Optional[Dict[str, str]] = None,
        preprocess_path: Optional[str] = None,
        h5file_names: Optional[List[str]] = None,
        AtomicData_options: Optional[Dict[str, Any]] = None,
        type_mapper: Optional[TypeMapper] = None,
    ):
        """
        Args:
            root: dataset root directory, forwarded to ``AtomicDataset``.
            key_mapping: maps HDF5 dataset names to ``AtomicDataDict`` field
                keys; ``None`` selects the standard mapping below.
            preprocess_path: directory containing the preprocessed HDF5 files.
            h5file_names: file names relative to ``preprocess_path``, one per frame.
            AtomicData_options: must contain "r_max", "er_max", "oer_max"
                and "pbc" (a missing key raises ``KeyError``, as before).
            type_mapper: optional ``TypeMapper`` forwarded to the base class.
        """
        super().__init__(root=root, type_mapper=type_mapper)
        # None sentinels replace the original mutable default arguments
        # while preserving the exact same effective defaults.
        if key_mapping is None:
            key_mapping = {
                "pos": AtomicDataDict.POSITIONS_KEY,
                "energy": AtomicDataDict.TOTAL_ENERGY_KEY,
                "atomic_numbers": AtomicDataDict.ATOMIC_NUMBERS_KEY,
                "kpoints": AtomicDataDict.KPOINT_KEY,
                "eigenvalues": AtomicDataDict.ENERGY_EIGENVALUE_KEY,
            }
        options = {} if AtomicData_options is None else AtomicData_options

        self.key_mapping = key_mapping
        self.key_list = list(key_mapping.keys())
        self.value_list = list(key_mapping.values())
        self.file_names = h5file_names
        self.preprocess_path = preprocess_path

        # Neighbour-list cutoffs and periodicity are required options.
        self.r_max = options["r_max"]
        self.er_max = options["er_max"]
        self.oer_max = options["oer_max"]
        self.pbc = options["pbc"]

        self.index = None
        self.num_examples = len(h5file_names)

    def get(self, idx):
        """Build and return the ``AtomicData`` object for frame ``idx``.

        Opens the frame's HDF5 file, builds the structure and neighbour
        lists via ``AtomicData.from_points``, then attaches optional
        Hamiltonian/overlap features and eigenvalue/k-point tensors.
        """
        file_path = os.path.join(self.preprocess_path, self.file_names[idx])
        # Context manager guarantees the HDF5 handle is released even if
        # parsing raises (the original opened the file and never closed it).
        with h5py.File(file_path, "r") as data:
            atomic_data = AtomicData.from_points(
                pos=data["pos"][:],
                r_max=self.r_max,
                cell=data["cell"][:],
                er_max=self.er_max,
                oer_max=self.oer_max,
                pbc=self.pbc,
                atomic_numbers=data["atomic_numbers"][:],
            )

            # Membership test first: a file without these entries used to
            # raise KeyError; the truthiness test (non-empty group/dataset)
            # is kept so present-but-empty entries are still skipped.
            if "hamiltonian_blocks" in data and data["hamiltonian_blocks"]:
                # Turn stored l quantum numbers into orbital labels,
                # e.g. [0, 0, 1] -> ["1s", "2s", "3p"].
                basis = {
                    key: [f"{i+1}" + orbitalLId[l] for i, l in enumerate(value)]
                    for key, value in data["basis"].items()
                }
                idp = OrbitalMapper(basis)
                # Attach Hamiltonian (and overlap) matrix blocks as features.
                ham_block_to_feature(atomic_data, idp, data["hamiltonian_blocks"], data["overlap_blocks"])
            if ("eigenvalue" in data and "kpoint" in data
                    and data["eigenvalue"] and data["kpoint"]):
                atomic_data[AtomicDataDict.KPOINT_KEY] = torch.as_tensor(
                    data["kpoint"][:], dtype=torch.get_default_dtype()
                )
                atomic_data[AtomicDataDict.ENERGY_EIGENVALUE_KEY] = torch.as_tensor(
                    data["eigenvalue"][:], dtype=torch.get_default_dtype()
                )

        return atomic_data

    def len(self) -> int:
        """Number of frames (one per HDF5 file)."""
        return self.num_examples
1 change: 1 addition & 0 deletions dptb/data/dataset/_base_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def process(self):
for i in include_frames
]


else:
raise ValueError("Invalid return from `self.get_data()`")

Expand Down
Loading

0 comments on commit 959a196

Please sign in to comment.