Skip to content

Commit

Permalink
Merge branch 'main' into get_fermi
Browse files Browse the repository at this point in the history
  • Loading branch information
AsymmetryChou committed Dec 21, 2024
2 parents c2895c0 + 42eaf5d commit dabd47c
Show file tree
Hide file tree
Showing 37 changed files with 2,662 additions and 308 deletions.
110 changes: 99 additions & 11 deletions dptb/data/AtomicData.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from .util import _TORCH_INTEGER_DTYPES
from dptb.utils.torch_geometric.data import Data
from dptb.utils.constants import atomic_num_dict
import logging

log = logging.getLogger(__name__)

# A type representing ASE-style periodic boundary condtions, which can be partial (the tuple case)
PBC = Union[bool, Tuple[bool, bool, bool]]
Expand Down Expand Up @@ -874,11 +877,10 @@ def without_nodes(self, which_nodes):
return type(self)(**new_dict)


_ERROR_ON_NO_EDGES: bool = os.environ.get("NEQUIP_ERROR_ON_NO_EDGES", "true").lower()
assert _ERROR_ON_NO_EDGES in ("true", "false")
_ERROR_ON_NO_EDGES = os.environ.get("NEQUIP_ERROR_ON_NO_EDGES", "true").lower()
assert _ERROR_ON_NO_EDGES in ("true", "false"), "NEQUIP_ERROR_ON_NO_EDGES must be 'true' or 'false'"
_ERROR_ON_NO_EDGES = _ERROR_ON_NO_EDGES == "true"


def neighbor_list_and_relative_vec(
pos,
r_max,
Expand Down Expand Up @@ -1026,6 +1028,13 @@ def neighbor_list_and_relative_vec(
# so, only when key_rev is not in the dict, we keep the bond. that is when rev_dict.get(key_rev, False) is False, we set o_mast = True.
if not (rev_dict.get(key_rev, False) and rev_dict.get(key, False)):
o_mask[i] = True

if self_interaction:
log.warning("self_interaction is True, but usually we do not want the self-interaction, please check if it is correct.")
# for self-interaction, the above will remove the self-interaction, i.e. i == j, shift == [0, 0, 0]. since -0 = 0.
if (o_shift[i] == np.array([0, 0, 0])).all():
o_mask[i] = True

del rev_dict
del o_first_idex
del o_second_idex
Expand All @@ -1038,6 +1047,7 @@ def neighbor_list_and_relative_vec(
shifts = torch.as_tensor(shifts[mask], dtype=out_dtype, device=out_device)

if not reduce:
assert self_interaction == False, "for self_interaction = True, i i 0 0 0 will be duplicated."
first_idex, second_idex = torch.cat((first_idex, second_idex), dim=0), torch.cat((second_idex, first_idex), dim=0)
shifts = torch.cat((shifts, -shifts), dim=0)

Expand All @@ -1049,7 +1059,7 @@ def neighbor_list_and_relative_vec(
# TODO: mask the edges that is larger than r_max
if mask_r:
edge_vec = pos[edge_index[1]] - pos[edge_index[0]]
if cell is not None:
if cell is not None :
edge_vec = edge_vec + torch.einsum(
"ni,ij->nj",
shifts,
Expand All @@ -1058,17 +1068,36 @@ def neighbor_list_and_relative_vec(

edge_length = torch.linalg.norm(edge_vec, dim=-1)

atom_species_num = [atomic_num_dict[k] for k in r_max.keys()]
for i in set(atomic_numbers):
assert i in atom_species_num
r_map = torch.zeros(max(atom_species_num))
for k, v in r_max.items():
r_map[atomic_num_dict[k]-1] = v
edge_length_max = 0.5 * (r_map[atomic_numbers[edge_index[0]]-1] + r_map[atomic_numbers[edge_index[1]]-1])
# atom_species_num = [atomic_num_dict[k] for k in r_max.keys()]
# for i in set(atomic_numbers):
# assert i in atom_species_num
# r_map = torch.zeros(max(atom_species_num))
# for k, v in r_max.items():
# r_map[atomic_num_dict[k]-1] = v

first_key = next(iter(r_max.keys()))
key_parts = first_key.split("-")

if len(key_parts)==1:
r_map = get_r_map(r_max, atomic_numbers)
edge_length_max = 0.5 * (r_map[atomic_numbers[edge_index[0]]-1] + r_map[atomic_numbers[edge_index[1]]-1])

elif len(key_parts)==2:
r_map = get_r_map_bondwise(r_max, atomic_numbers)
edge_length_max = r_map[atomic_numbers[edge_index[0]]-1,atomic_numbers[edge_index[1]]-1]
else:
raise ValueError("The r_max keys should be either atomic number or atomic number pair.")

r_mask = edge_length <= edge_length_max
if any(~r_mask):
edge_index = edge_index[:, r_mask]
shifts = shifts[r_mask]
# 收集不同类型的边及其对应的最大截断半径
#edge_types = {}
#for i in range(edge_index.shape[1]):
# atom_type_pair = (atomic_numbers[edge_index[0, i]], atomic_numbers[edge_index[1, i]])
# if atom_type_pair not in edge_types:
# edge_types[atom_type_pair] = edge_length_max[i].item()

del edge_length
del edge_vec
Expand All @@ -1077,3 +1106,62 @@ def neighbor_list_and_relative_vec(
del r_mask

return edge_index, shifts, cell_tensor

def get_r_map(r_max: dict, atomic_numbers=None):
"""
Returns a torch tensor representing the mapping of atomic species to their maximum distances.
Args:
r_max (dict): A dictionary mapping atomic species to their maximum distances.
atomic_numbers (list, optional): A list of atomic numbers to validate against the atomic species. Defaults to None.
Returns:
torch.Tensor: A torch tensor representing the mapping of atomic species to their maximum distances.
"""
atom_species_num = [atomic_num_dict[k] for k in r_max.keys()]
if atomic_numbers is not None:
for i in atomic_numbers:
assert i in atom_species_num
r_map = torch.zeros(max(atom_species_num))
for k, v in r_max.items():
r_map[atomic_num_dict[k]-1] = v
return r_map

def get_r_map_bondwise(r_max:dict, atomic_numbers=None):
"""
Calculate the bondwise distance map based on the maximum bond length dictionary.
Args:
r_max (dict): A dictionary containing the maximum bond lengths for different atom pairs.
atomic_numbers (list, optional): A list of atomic numbers. Defaults to None.
Returns:
torch.Tensor: A torch tensor representing the bondwise distance map.
"""
atom_species_num = []
for k in r_max.keys():
assert len(k.split('-')) == 2
atom_a, atom_b = k.split('-')
if atomic_num_dict[atom_a] not in atom_species_num:
atom_species_num.append(atomic_num_dict[atom_a])
if atomic_num_dict[atom_b] not in atom_species_num:
atom_species_num.append(atomic_num_dict[atom_b])

if atomic_numbers is not None:
for i in atomic_numbers:
assert i in atom_species_num

r_map = torch.zeros(max(atom_species_num), max(atom_species_num))
for k, v in r_max.items():
atom_a, atom_b = k.split('-')

inv_value = r_map[atomic_num_dict[atom_b]-1, atomic_num_dict[atom_a]-1]
if inv_value == 0:
r_map[atomic_num_dict[atom_a]-1, atomic_num_dict[atom_b]-1] = v
r_map[atomic_num_dict[atom_b]-1, atomic_num_dict[atom_a]-1] = v
else:
mean_val = (v + inv_value) / 2
r_map[atomic_num_dict[atom_a]-1, atomic_num_dict[atom_b]-1] = mean_val
r_map[atomic_num_dict[atom_b]-1, atomic_num_dict[atom_a]-1] = mean_val

return r_map
2 changes: 1 addition & 1 deletion dptb/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def __init__(
for ib in self.basis.keys():
self.basis[ib] = sorted(
self.basis[ib],
key=lambda s: (anglrMId[re.findall(r"[a-z]",s)[0]], re.findall(r"[1-9*]",s)[0])
key=lambda s: (anglrMId[re.findall(r"[a-z]",s)[0]], re.findall(r"[1-9*]",s)[0] if re.findall(r"[1-9*]",s) else '0')
)

# TODO: get full basis set
Expand Down
87 changes: 84 additions & 3 deletions dptb/entrypoints/collectskf.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, Union
import json
from pathlib import Path
import os
import torch
import glob
from dptb.nn.dftb.sk_param import SKParam

from dptb.nn.dftb2nnsk import DFTB2NNSK
import logging
from dptb.utils.loggers import set_log_handles
from dptb.utils.tools import j_loader, setup_seed, j_must_have
from dptb.utils.argcheck import normalize, collect_cutoffs, normalize_skf2nnsk


__all__ = ["skf2pth"]
__all__ = ["skf2pth", "skf2nnsk"]


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -45,3 +49,80 @@ def skf2pth(
torch.save(skdict, output)


def skf2nnsk(
INPUT:str,
init_model: Optional[str],
output:str,
log_level: int,
log_path: Optional[str] = None,
**kwargs
):
run_opt = {
"init_model": init_model,
"log_path": log_path,
"log_level": log_level
}

# setup output path
if output:
Path(output).parent.mkdir(exist_ok=True, parents=True)
Path(output).mkdir(exist_ok=True, parents=True)
if not log_path:
log_path = os.path.join(str(output), "log.txt")
Path(log_path).parent.mkdir(exist_ok=True, parents=True)

run_opt.update({
"output": str(Path(output).absolute()),
"log_path": str(Path(log_path).absolute())
})
set_log_handles(log_level, Path(log_path) if log_path else None)

jdata = j_loader(INPUT)
jdata = normalize_skf2nnsk(jdata)

common_options = jdata['common_options']
model_options = jdata['model_options']
train_options = jdata['train_options']

basis = j_must_have(common_options, "basis")
skdata_file = j_must_have(common_options, "skdata")

if skdata_file.split('.')[-1] != 'pth':
log.error("The skdata file should be a pth file.")
raise ValueError("The skdata file should be a pth file.")
log.info(f"Loading skdata from {skdata_file}")
skdata = torch.load(skdata_file)

if isinstance(basis, str) and basis == "auto":
log.info("Automatically determining basis")
basis = dict(zip(skdata['OnsiteE'], [['s', 'p', 'd']] * len(skdata['OnsiteE'])))
else:
assert isinstance(basis, dict), "basis must be a dict or 'auto'"

train_options = jdata['train_options']

if init_model:
dftb2nn = DFTB2NNSK.load(ckpt=init_model,
skdata=skdata,
train_options=train_options,
output=run_opt.get('output', './')
)

else:
dftb2nn = DFTB2NNSK(
basis = basis,
skdata = skdata,
method=j_must_have(model_options, "method"),
rs=model_options.get('rs', None),
w = j_must_have(model_options, "w"),
cal_rcuts= model_options.get('rs', None) is None,
atomic_radius= model_options.get('atomic_radius', 'cov'),
train_options=train_options,
output=run_opt.get('output', './')
)

dftb2nn.optimize()




36 changes: 35 additions & 1 deletion dptb/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dptb.entrypoints.data import data
from dptb.utils.loggers import set_log_handles
from dptb.utils.config_check import check_config_train
from dptb.entrypoints.collectskf import skf2pth
from dptb.entrypoints.collectskf import skf2pth, skf2nnsk
from dptb import __version__


Expand Down Expand Up @@ -364,6 +364,37 @@ def main_parser() -> argparse.ArgumentParser:
help="The output pth files of sk params from skfiles."
)

# neighbour
parser_skf2nn = subparsers.add_parser(
"skf2nn",
parents=[parser_log],
help="Convert the sk files to nn-sk TB model",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser_skf2nn.add_argument(
"INPUT", help="the input parameter file in json or yaml format",
type=str,
default=None
)

parser_skf2nn.add_argument(
"-i",
"--init-model",
type=str,
default=None,
help="Initialize the model by the provided checkpoint.",
)

parser_skf2nn.add_argument(
"-o",
"--output",
type=str,
default="./",
help="The output files in training.",
)


return parser

def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
Expand Down Expand Up @@ -424,3 +455,6 @@ def main():

elif args.command == 'cskf':
skf2pth(**dict_args)

elif args.command == 'skf2nn':
skf2nnsk(**dict_args)
15 changes: 8 additions & 7 deletions dptb/nn/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
def build_model(
checkpoint: str=None,
model_options: dict={},
common_options: dict={}
common_options: dict={},
no_check: bool=False
):
"""
The build model method should composed of the following steps:
Expand Down Expand Up @@ -161,12 +162,12 @@ def build_model(
model = DFTBSK.from_reference(checkpoint, **model_options["dftbsk"], **common_options)
else:
model = None
for k, v in model.model_options.items():
if k not in model_options:
log.warning(f"The model options {k} is not defined in input model_options, set to {v}.")
else:
deep_dict_difference(k, v, model_options)
if not no_check:
for k, v in model.model_options.items():
if k not in model_options:
log.warning(f"The model options {k} is not defined in input model_options, set to {v}.")
else:
deep_dict_difference(k, v, model_options)
model.to(model.device)
return model

Expand Down
17 changes: 15 additions & 2 deletions dptb/nn/dftb/hopping_dftb.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dptb.nn.sktb.hopping import BaseHopping
import torch
from dptb.utils._xitorch.interpolate import Interp1D

import logging
log = logging.getLogger(__name__)
class HoppingIntp(BaseHopping):

def __init__(
Expand Down Expand Up @@ -36,7 +37,19 @@ def dftb(self, rij:torch.Tensor, xx:torch.Tensor, yy:torch.Tensor, **kwargs):
assert rij.shape[0] == self.num_ingrls, "the bond distance shape rij is not correct."
else:
raise ValueError("The shape of rij is not correct.")
# 检查 rij 是否在 xx 的范围内
min_x, max_x = self.xx.min(), self.xx.max()
mask_in_range = (rij >= min_x) & (rij <= max_x)
mask_out_range = ~mask_in_range
if mask_out_range.any():
# log.warning("Some rij values are outside the interpolation range and will be set to 0.")
# 创建 rij 的副本,并将范围外的值替换为范围内的值(例如,使用 min_x)
rij_modified = rij.clone()
rij_modified[mask_out_range] = (min_x + max_x) / 2
yyintp = self.intpfunc(xq=rij_modified, y=yy)
yyintp[mask_out_range] = 0.0
else:
yyintp = self.intpfunc(xq=rij, y=yy)

yyintp = self.intpfunc(xq=rij,y=yy)
return yyintp.T

Loading

0 comments on commit dabd47c

Please sign in to comment.