Skip to content

Commit

Permalink
add support atomwise rmax for sktb module (#209)
Browse files Browse the repository at this point in the history
* feat: add logging for self-interaction warnings and implement tests for r_max functionality

* feat: add support for atom-wise rs rmax setting in deeptb sk modula.

* remove check atomic_number shape in get_r_map

* add Covalent_radii database

* Add support for basis input in the format : ['s','p','d'], previous only support the format of ['2s','3p','d*'], where the main quantum number must be there before the orbital or the * after the orbital symbol.

* fix test error when add the support for basis input format of ['s','p','d']

* change the unit of bond_length_list from bohr to \AA

* add test_Covalent_radii

* add some comment in dftbsk

* add BondLenCovalent

* add new onsite formula uniform_noref

* remove the unit transition from bohr to \AA since gthe bondlenth are saved in unit AA.

* rename BondLenCovalent to R_cov_list

* chore: Update test_SKHamiltonian to use torch.allclose with specified tolerances

* Refactor AtomicData.py to use environment variable for error handling

* add atomic radius in bondlengthDB

* Refactor covalent radii database to use \AA as unit

* feat: Calculate minimum and maximum atomic radii based on skdata

This commit adds a new function `cal_rmin_rmax` to calculate the minimum and maximum atomic radii based on the given skdata. It uses the `find_first_false` function to find the index of the first occurrence of False in each row of a 2D array. The calculated atomic radii are stored in the `atomic_r_min_dict` and `atomic_r_max_dict` dictionaries.

* add new dftb2nnsk

* add new dftb2nnsk

* feat: Update SKParam class to  update atomic radii and format_skparams to align the distance introduce the mask for rij out the grid from skfiles.

This commit modifies the SKParam class in sk_param.py to calculate and update the minimum and maximum atomic radii based on the given skdata. It introduces a new function `cal_rmin_rmax` that uses the `find_first_false` function to find the index of the first occurrence of False in each row of a 2D array. The calculated atomic radii are stored in the `atomic_r_min` and `atomic_r_max` attributes of the SKParam class.

* feat: Add range check for bond distance in HoppingIntp

This commit adds a range check for the bond distance `rij` in the `HoppingIntp` class of `hopping_dftb.py`. If any `rij` values are outside the interpolation range defined by `min_x` and `max_x`, a warning message is logged and those values are set to 0.0. The interpolated values are then calculated using the modified `rij` array.

* feat: Add poly4pow hopping formula to HoppingFormula class and simplify the class

This commit adds the 'poly4pow' hopping formula to the HoppingFormula class in hopping.py. The formula calculates the SK integrals without the environment dependence of the form of powerlaw. It takes into account parameters such as alpha1, alpha2, alpha3, alpha4, alpha5, and alpha6. The function poly4pow() is used to calculate the value of the hopping formula. If the functype is 'NRL0' or 'NRL1', the NRL_HOP method is called. Otherwise, the method corresponding to the functype is called. If the functype is not recognized, a ValueError is raised.

* update dftb2nnsk to accept large model

* delete temp

* add get_rmap for bond_wise cutoff seting style

* feat: Add poly3exp and poly4exp hopping formulas to HoppingFormula class

This commit adds the 'poly3exp' and 'poly4exp' hopping formulas to the HoppingFormula class in hopping.py. These formulas calculate the SK integrals without the environment dependence of the form of powerlaw. They take into account parameters such as alpha1, alpha2, alpha3, alpha4, alpha5, and alpha6. The functions poly3exp() and poly4exp() are used to calculate the values of the hopping formulas. If the functype is 'NRL0' or 'NRL1', the NRL_HOP method is called. Otherwise, the method corresponding to the functype is called. If the functype is not recognized, a ValueError is raised.

* feat: update AtomicData.neighbor_list_and_relative_vec to support bond wise rmax

* feat:add support bond-wise rmax in nnsk

* feat: Update NNSK class to support bond-wise rs values

This commit updates the NNSK class in nnsk.py to support bond-wise rs values. It introduces a conditional statement to check if the rs values are atom-wise or bond-wise. If the rs values are bond-wise, the get_r_map_bondwise() function is called to generate the r_map dictionary. The r_map_type attribute is set to 2 to indicate bond-wise rs values. This allows for more flexibility in defining the hopping options for the NNSK model.

* feat: Update AtomicData.neighbor_list_and_relative_vec to support bond-wise rmax

* feat: Add support for bond-wise rmax in build_dataset tests

* test:update some tests

* feat: add bondwise cal rmax and rim in skparam

* use bondwise rmin and rmax in skparam and dftb2nnsk

* Refactor bondwise rmax calculation in sk_param.py and dftb2nnsk.py

* Refactor DFTB2NNSK add save and load model

* Refactor NRL_OVERLAP0 and NRL_OVERLAP1 in hopping.py to use torch.Tensor instead of torch.float32 for rs and w parameters

* Refactor get_lr_scheduler to add support for cosine annealing learning rate scheduler

* Refactor  optimise funciton in dftb2nnsk.py

* Refactor collectskf.py and main.py to add support for converting sk files to nn-sk TB model

* Refactor argcheck.py to add support for RMSprop and LBFGS optimizers and CosineAnnealingLR learning rate scheduler

* Refactor argcheck.py to remove duplicate formulas in the hopping function

* update test for dftb2nnsk

* Refactor argcheck.py to remove duplicate formulas in the hopping function

* rename functype to method in dftb2nnsk and fix bug in argcheck.

* Add output inputpara

* add bond integrl plot viz

* add bond integrl plot viz

* Refactor NNSK class to include support for different atomic radius options

* Refactor atomic radius initialization in DFTB2NNSK and NNSK classes

* fix: the support for uniform_noref onsite mode in nnsk

* Refactor OnsiteFormula class in onsite.py

* Refactor save method in NNSK class  add device and dtype in dftb2nnsk and sk_param

* Refactor ElecStruCal class to handle atomic radius options and fix a bug related to the pbc parameter

* Refactor get_cutoffs_from_model_options function to handle the case when r_max is not provided in dftbsk model options

* add example of dftb2nnsk

* test: add skf2nnsk example

* Refactor load method in DFTB2NNSK class to include an optional output parameter

* add no check for model options in build.py

* fix bug in lr_scheduler if argcheck

* add base model and example

* update uasge

* fix the typo for get_r_map_bondwise
  • Loading branch information
QG-phy authored Dec 21, 2024
1 parent 52221e6 commit 42eaf5d
Show file tree
Hide file tree
Showing 37 changed files with 2,657 additions and 307 deletions.
110 changes: 99 additions & 11 deletions dptb/data/AtomicData.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from .util import _TORCH_INTEGER_DTYPES
from dptb.utils.torch_geometric.data import Data
from dptb.utils.constants import atomic_num_dict
import logging

log = logging.getLogger(__name__)

# A type representing ASE-style periodic boundary condtions, which can be partial (the tuple case)
PBC = Union[bool, Tuple[bool, bool, bool]]
Expand Down Expand Up @@ -874,11 +877,10 @@ def without_nodes(self, which_nodes):
return type(self)(**new_dict)


_ERROR_ON_NO_EDGES: bool = os.environ.get("NEQUIP_ERROR_ON_NO_EDGES", "true").lower()
assert _ERROR_ON_NO_EDGES in ("true", "false")
_ERROR_ON_NO_EDGES = os.environ.get("NEQUIP_ERROR_ON_NO_EDGES", "true").lower()
assert _ERROR_ON_NO_EDGES in ("true", "false"), "NEQUIP_ERROR_ON_NO_EDGES must be 'true' or 'false'"
_ERROR_ON_NO_EDGES = _ERROR_ON_NO_EDGES == "true"


def neighbor_list_and_relative_vec(
pos,
r_max,
Expand Down Expand Up @@ -1026,6 +1028,13 @@ def neighbor_list_and_relative_vec(
# so, only when key_rev is not in the dict, we keep the bond. that is when rev_dict.get(key_rev, False) is False, we set o_mast = True.
if not (rev_dict.get(key_rev, False) and rev_dict.get(key, False)):
o_mask[i] = True

if self_interaction:
log.warning("self_interaction is True, but usually we do not want the self-interaction, please check if it is correct.")
# for self-interaction, the above will remove the self-interaction, i.e. i == j, shift == [0, 0, 0]. since -0 = 0.
if (o_shift[i] == np.array([0, 0, 0])).all():
o_mask[i] = True

del rev_dict
del o_first_idex
del o_second_idex
Expand All @@ -1038,6 +1047,7 @@ def neighbor_list_and_relative_vec(
shifts = torch.as_tensor(shifts[mask], dtype=out_dtype, device=out_device)

if not reduce:
assert self_interaction == False, "for self_interaction = True, i i 0 0 0 will be duplicated."
first_idex, second_idex = torch.cat((first_idex, second_idex), dim=0), torch.cat((second_idex, first_idex), dim=0)
shifts = torch.cat((shifts, -shifts), dim=0)

Expand All @@ -1049,7 +1059,7 @@ def neighbor_list_and_relative_vec(
# TODO: mask the edges that is larger than r_max
if mask_r:
edge_vec = pos[edge_index[1]] - pos[edge_index[0]]
if cell is not None:
if cell is not None :
edge_vec = edge_vec + torch.einsum(
"ni,ij->nj",
shifts,
Expand All @@ -1058,17 +1068,36 @@ def neighbor_list_and_relative_vec(

edge_length = torch.linalg.norm(edge_vec, dim=-1)

atom_species_num = [atomic_num_dict[k] for k in r_max.keys()]
for i in set(atomic_numbers):
assert i in atom_species_num
r_map = torch.zeros(max(atom_species_num))
for k, v in r_max.items():
r_map[atomic_num_dict[k]-1] = v
edge_length_max = 0.5 * (r_map[atomic_numbers[edge_index[0]]-1] + r_map[atomic_numbers[edge_index[1]]-1])
# atom_species_num = [atomic_num_dict[k] for k in r_max.keys()]
# for i in set(atomic_numbers):
# assert i in atom_species_num
# r_map = torch.zeros(max(atom_species_num))
# for k, v in r_max.items():
# r_map[atomic_num_dict[k]-1] = v

first_key = next(iter(r_max.keys()))
key_parts = first_key.split("-")

if len(key_parts)==1:
r_map = get_r_map(r_max, atomic_numbers)
edge_length_max = 0.5 * (r_map[atomic_numbers[edge_index[0]]-1] + r_map[atomic_numbers[edge_index[1]]-1])

elif len(key_parts)==2:
r_map = get_r_map_bondwise(r_max, atomic_numbers)
edge_length_max = r_map[atomic_numbers[edge_index[0]]-1,atomic_numbers[edge_index[1]]-1]
else:
raise ValueError("The r_max keys should be either atomic number or atomic number pair.")

r_mask = edge_length <= edge_length_max
if any(~r_mask):
edge_index = edge_index[:, r_mask]
shifts = shifts[r_mask]
# 收集不同类型的边及其对应的最大截断半径
#edge_types = {}
#for i in range(edge_index.shape[1]):
# atom_type_pair = (atomic_numbers[edge_index[0, i]], atomic_numbers[edge_index[1, i]])
# if atom_type_pair not in edge_types:
# edge_types[atom_type_pair] = edge_length_max[i].item()

del edge_length
del edge_vec
Expand All @@ -1077,3 +1106,62 @@ def neighbor_list_and_relative_vec(
del r_mask

return edge_index, shifts, cell_tensor

def get_r_map(r_max: dict, atomic_numbers=None):
"""
Returns a torch tensor representing the mapping of atomic species to their maximum distances.
Args:
r_max (dict): A dictionary mapping atomic species to their maximum distances.
atomic_numbers (list, optional): A list of atomic numbers to validate against the atomic species. Defaults to None.
Returns:
torch.Tensor: A torch tensor representing the mapping of atomic species to their maximum distances.
"""
atom_species_num = [atomic_num_dict[k] for k in r_max.keys()]
if atomic_numbers is not None:
for i in atomic_numbers:
assert i in atom_species_num
r_map = torch.zeros(max(atom_species_num))
for k, v in r_max.items():
r_map[atomic_num_dict[k]-1] = v
return r_map

def get_r_map_bondwise(r_max:dict, atomic_numbers=None):
"""
Calculate the bondwise distance map based on the maximum bond length dictionary.
Args:
r_max (dict): A dictionary containing the maximum bond lengths for different atom pairs.
atomic_numbers (list, optional): A list of atomic numbers. Defaults to None.
Returns:
torch.Tensor: A torch tensor representing the bondwise distance map.
"""
atom_species_num = []
for k in r_max.keys():
assert len(k.split('-')) == 2
atom_a, atom_b = k.split('-')
if atomic_num_dict[atom_a] not in atom_species_num:
atom_species_num.append(atomic_num_dict[atom_a])
if atomic_num_dict[atom_b] not in atom_species_num:
atom_species_num.append(atomic_num_dict[atom_b])

if atomic_numbers is not None:
for i in atomic_numbers:
assert i in atom_species_num

r_map = torch.zeros(max(atom_species_num), max(atom_species_num))
for k, v in r_max.items():
atom_a, atom_b = k.split('-')

inv_value = r_map[atomic_num_dict[atom_b]-1, atomic_num_dict[atom_a]-1]
if inv_value == 0:
r_map[atomic_num_dict[atom_a]-1, atomic_num_dict[atom_b]-1] = v
r_map[atomic_num_dict[atom_b]-1, atomic_num_dict[atom_a]-1] = v
else:
mean_val = (v + inv_value) / 2
r_map[atomic_num_dict[atom_a]-1, atomic_num_dict[atom_b]-1] = mean_val
r_map[atomic_num_dict[atom_b]-1, atomic_num_dict[atom_a]-1] = mean_val

return r_map
2 changes: 1 addition & 1 deletion dptb/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def __init__(
for ib in self.basis.keys():
self.basis[ib] = sorted(
self.basis[ib],
key=lambda s: (anglrMId[re.findall(r"[a-z]",s)[0]], re.findall(r"[1-9*]",s)[0])
key=lambda s: (anglrMId[re.findall(r"[a-z]",s)[0]], re.findall(r"[1-9*]",s)[0] if re.findall(r"[1-9*]",s) else '0')
)

# TODO: get full basis set
Expand Down
87 changes: 84 additions & 3 deletions dptb/entrypoints/collectskf.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, Union
import json
from pathlib import Path
import os
import torch
import glob
from dptb.nn.dftb.sk_param import SKParam

from dptb.nn.dftb2nnsk import DFTB2NNSK
import logging
from dptb.utils.loggers import set_log_handles
from dptb.utils.tools import j_loader, setup_seed, j_must_have
from dptb.utils.argcheck import normalize, collect_cutoffs, normalize_skf2nnsk


__all__ = ["skf2pth"]
__all__ = ["skf2pth", "skf2nnsk"]


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -45,3 +49,80 @@ def skf2pth(
torch.save(skdict, output)


def skf2nnsk(
INPUT:str,
init_model: Optional[str],
output:str,
log_level: int,
log_path: Optional[str] = None,
**kwargs
):
run_opt = {
"init_model": init_model,
"log_path": log_path,
"log_level": log_level
}

# setup output path
if output:
Path(output).parent.mkdir(exist_ok=True, parents=True)
Path(output).mkdir(exist_ok=True, parents=True)
if not log_path:
log_path = os.path.join(str(output), "log.txt")
Path(log_path).parent.mkdir(exist_ok=True, parents=True)

run_opt.update({
"output": str(Path(output).absolute()),
"log_path": str(Path(log_path).absolute())
})
set_log_handles(log_level, Path(log_path) if log_path else None)

jdata = j_loader(INPUT)
jdata = normalize_skf2nnsk(jdata)

common_options = jdata['common_options']
model_options = jdata['model_options']
train_options = jdata['train_options']

basis = j_must_have(common_options, "basis")
skdata_file = j_must_have(common_options, "skdata")

if skdata_file.split('.')[-1] != 'pth':
log.error("The skdata file should be a pth file.")
raise ValueError("The skdata file should be a pth file.")
log.info(f"Loading skdata from {skdata_file}")
skdata = torch.load(skdata_file)

if isinstance(basis, str) and basis == "auto":
log.info("Automatically determining basis")
basis = dict(zip(skdata['OnsiteE'], [['s', 'p', 'd']] * len(skdata['OnsiteE'])))
else:
assert isinstance(basis, dict), "basis must be a dict or 'auto'"

train_options = jdata['train_options']

if init_model:
dftb2nn = DFTB2NNSK.load(ckpt=init_model,
skdata=skdata,
train_options=train_options,
output=run_opt.get('output', './')
)

else:
dftb2nn = DFTB2NNSK(
basis = basis,
skdata = skdata,
method=j_must_have(model_options, "method"),
rs=model_options.get('rs', None),
w = j_must_have(model_options, "w"),
cal_rcuts= model_options.get('rs', None) is None,
atomic_radius= model_options.get('atomic_radius', 'cov'),
train_options=train_options,
output=run_opt.get('output', './')
)

dftb2nn.optimize()




36 changes: 35 additions & 1 deletion dptb/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dptb.entrypoints.data import data
from dptb.utils.loggers import set_log_handles
from dptb.utils.config_check import check_config_train
from dptb.entrypoints.collectskf import skf2pth
from dptb.entrypoints.collectskf import skf2pth, skf2nnsk
from dptb import __version__


Expand Down Expand Up @@ -364,6 +364,37 @@ def main_parser() -> argparse.ArgumentParser:
help="The output pth files of sk params from skfiles."
)

# neighbour
parser_skf2nn = subparsers.add_parser(
"skf2nn",
parents=[parser_log],
help="Convert the sk files to nn-sk TB model",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser_skf2nn.add_argument(
"INPUT", help="the input parameter file in json or yaml format",
type=str,
default=None
)

parser_skf2nn.add_argument(
"-i",
"--init-model",
type=str,
default=None,
help="Initialize the model by the provided checkpoint.",
)

parser_skf2nn.add_argument(
"-o",
"--output",
type=str,
default="./",
help="The output files in training.",
)


return parser

def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
Expand Down Expand Up @@ -424,3 +455,6 @@ def main():

elif args.command == 'cskf':
skf2pth(**dict_args)

elif args.command == 'skf2nn':
skf2nnsk(**dict_args)
15 changes: 8 additions & 7 deletions dptb/nn/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
def build_model(
checkpoint: str=None,
model_options: dict={},
common_options: dict={}
common_options: dict={},
no_check: bool=False
):
"""
The build model method should composed of the following steps:
Expand Down Expand Up @@ -161,12 +162,12 @@ def build_model(
model = DFTBSK.from_reference(checkpoint, **model_options["dftbsk"], **common_options)
else:
model = None
for k, v in model.model_options.items():
if k not in model_options:
log.warning(f"The model options {k} is not defined in input model_options, set to {v}.")
else:
deep_dict_difference(k, v, model_options)
if not no_check:
for k, v in model.model_options.items():
if k not in model_options:
log.warning(f"The model options {k} is not defined in input model_options, set to {v}.")
else:
deep_dict_difference(k, v, model_options)
model.to(model.device)
return model

Expand Down
17 changes: 15 additions & 2 deletions dptb/nn/dftb/hopping_dftb.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dptb.nn.sktb.hopping import BaseHopping
import torch
from dptb.utils._xitorch.interpolate import Interp1D

import logging
log = logging.getLogger(__name__)
class HoppingIntp(BaseHopping):

def __init__(
Expand Down Expand Up @@ -36,7 +37,19 @@ def dftb(self, rij:torch.Tensor, xx:torch.Tensor, yy:torch.Tensor, **kwargs):
assert rij.shape[0] == self.num_ingrls, "the bond distance shape rij is not correct."
else:
raise ValueError("The shape of rij is not correct.")
# 检查 rij 是否在 xx 的范围内
min_x, max_x = self.xx.min(), self.xx.max()
mask_in_range = (rij >= min_x) & (rij <= max_x)
mask_out_range = ~mask_in_range
if mask_out_range.any():
# log.warning("Some rij values are outside the interpolation range and will be set to 0.")
# 创建 rij 的副本,并将范围外的值替换为范围内的值(例如,使用 min_x)
rij_modified = rij.clone()
rij_modified[mask_out_range] = (min_x + max_x) / 2
yyintp = self.intpfunc(xq=rij_modified, y=yy)
yyintp[mask_out_range] = 0.0
else:
yyintp = self.intpfunc(xq=rij, y=yy)

yyintp = self.intpfunc(xq=rij,y=yy)
return yyintp.T

Loading

0 comments on commit 42eaf5d

Please sign in to comment.