Skip to content

Commit

Permalink
debug datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
floatingCatty committed Nov 22, 2023
1 parent b58e472 commit 9885430
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
18 changes: 8 additions & 10 deletions dptb/data/dataset/_abacus_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ def __init__(
self.file_names = h5file_names
self.preprocess_path = preprocess_path

self.r_max = AtomicData_options["r_max"]
self.er_max = AtomicData_options["er_max"]
self.oer_max = AtomicData_options["oer_max"]
self.pbc = AtomicData_options["pbc"]
self.AtomicData_options = AtomicData_options
# self.r_max = AtomicData_options["r_max"]
# self.er_max = AtomicData_options["er_max"]
# self.oer_max = AtomicData_options["oer_max"]
# self.pbc = AtomicData_options["pbc"]

self.index = None
self.num_examples = len(h5file_names)
Expand All @@ -54,21 +55,18 @@ def get(self, idx):

atomic_data = AtomicData.from_points(
pos = data["pos"][:],
r_max = self.r_max,
cell = data["cell"][:],
er_max = self.er_max,
oer_max = self.oer_max,
pbc = self.pbc,
atomic_numbers = data["atomic_numbers"][:],
**self.AtomicData_options,
)

if data["hamiltonian_blocks"]:
basis = {}
for key, value in data["basis"].items():
basis[key] = [(f"{i+1}" + orbitalLId[l]) for i, l in enumerate(value)]
idp = OrbitalMapper(basis)
ham_block_to_feature(atomic_data, idp, data["hamiltonian_blocks"], data["overlap_blocks"])
if data["eigenvalue"] and data["kpoint"]:
ham_block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False))
if data.get("eigenvalue") and data.get("kpoint"):
atomic_data[AtomicDataDict.KPOINT_KEY] = torch.as_tensor(data["kpoint"][:], dtype=torch.get_default_dtype())
atomic_data[AtomicDataDict.ENERGY_EIGENVALUE_KEY] = torch.as_tensor(data["eigenvalue"][:], dtype=torch.get_default_dtype())

Expand Down
14 changes: 7 additions & 7 deletions dptb/data/interfaces/abacus.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,13 @@ def parse_matrix(matrix_path, factor, spinful=False):
f["hamiltonian_blocks"] = h5py.ExternalLink("hamiltonians.h5", "/")
if add_overlap:
f["overlap_blocks"] = h5py.ExternalLink("overlaps.h5", "/")
else:
f["overlap_blocks"] = False
else:
f["hamiltonian_blocks"] = False
# else:
# f["overlap_blocks"] = False
# else:
# f["hamiltonian_blocks"] = False
if get_eigenvalues:
f["kpoint"] = kpts
f["eigenvalue"] = band
else:
f["kpoint"] = False
f["eigenvalue"] = False
# else:
# f["kpoint"] = False
# f["eigenvalue"] = False

0 comments on commit 9885430

Please sign in to comment.