diff --git a/dptb/data/dataset/_abacus_dataset.py b/dptb/data/dataset/_abacus_dataset.py index ac340182..04d7dfb5 100644 --- a/dptb/data/dataset/_abacus_dataset.py +++ b/dptb/data/dataset/_abacus_dataset.py @@ -39,10 +39,11 @@ def __init__( self.file_names = h5file_names self.preprocess_path = preprocess_path - self.r_max = AtomicData_options["r_max"] - self.er_max = AtomicData_options["er_max"] - self.oer_max = AtomicData_options["oer_max"] - self.pbc = AtomicData_options["pbc"] + self.AtomicData_options = AtomicData_options + # self.r_max = AtomicData_options["r_max"] + # self.er_max = AtomicData_options["er_max"] + # self.oer_max = AtomicData_options["oer_max"] + # self.pbc = AtomicData_options["pbc"] self.index = None self.num_examples = len(h5file_names) @@ -54,12 +55,9 @@ def get(self, idx): atomic_data = AtomicData.from_points( pos = data["pos"][:], - r_max = self.r_max, cell = data["cell"][:], - er_max = self.er_max, - oer_max = self.oer_max, - pbc = self.pbc, atomic_numbers = data["atomic_numbers"][:], + **self.AtomicData_options, ) if data["hamiltonian_blocks"]: @@ -67,8 +65,8 @@ def get(self, idx): for key, value in data["basis"].items(): basis[key] = [(f"{i+1}" + orbitalLId[l]) for i, l in enumerate(value)] idp = OrbitalMapper(basis) - ham_block_to_feature(atomic_data, idp, data["hamiltonian_blocks"], data["overlap_blocks"]) - if data["eigenvalue"] and data["kpoint"]: + ham_block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False)) + if data.get("eigenvalue") and data.get("kpoint"): atomic_data[AtomicDataDict.KPOINT_KEY] = torch.as_tensor(data["kpoint"][:], dtype=torch.get_default_dtype()) atomic_data[AtomicDataDict.ENERGY_EIGENVALUE_KEY] = torch.as_tensor(data["eigenvalue"][:], dtype=torch.get_default_dtype()) diff --git a/dptb/data/interfaces/abacus.py b/dptb/data/interfaces/abacus.py index 7bb6ac49..b390b3f7 100644 --- a/dptb/data/interfaces/abacus.py +++ b/dptb/data/interfaces/abacus.py @@ -295,13 +295,13 @@ def parse_matrix(matrix_path, factor, spinful=False): f["hamiltonian_blocks"] = h5py.ExternalLink("hamiltonians.h5", "/") if add_overlap: f["overlap_blocks"] = h5py.ExternalLink("overlaps.h5", "/") - else: - f["overlap_blocks"] = False - else: - f["hamiltonian_blocks"] = False + # else: + # f["overlap_blocks"] = False + # else: + # f["hamiltonian_blocks"] = False if get_eigenvalues: f["kpoint"] = kpts f["eigenvalue"] = band - else: - f["kpoint"] = False - f["eigenvalue"] = False \ No newline at end of file + # else: + # f["kpoint"] = False + # f["eigenvalue"] = False \ No newline at end of file