update weighted hamil loss (#215)
* update weighted loss

* update loss

* fix default type in HamilLossWT

* fix bug in loss wt

* change onsite/hopping weight from list to dict
floatingCatty authored Dec 23, 2024
1 parent 9b7272d commit 5307de2
Showing 2 changed files with 161 additions and 1 deletion.
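In brief: the first hunk below restricts the existing overlap loss to the bond/atom types actually present in a batch (previously, absent types contributed zero rows to the average) and adds a matching onsite overlap term; the second hunk adds the new HamilLossWT class, which scales the per-type block errors by user-supplied weights. Reading from the diff, each term combines a weighted MAE with a weighted RMSE; in notation introduced just for this note (w_t the weight of type t, and the bars denoting scatter means of the absolute and squared block errors over atoms or edges of type t):

$$ L_{\text{onsite}} = \frac{1}{2}\Big[\big\langle w_t\,\overline{|\Delta|}_t\big\rangle + \sqrt{\big\langle w_t^2\,\overline{\Delta^2}_t\big\rangle}\Big] $$

where the outer average runs over the valid reduced-matrix-element entries of the types present in the batch. The hopping and overlap terms follow the same pattern over bond types, and the final return averages the terms (1/2 without overlap, 1/3 with).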
156 changes: 155 additions & 1 deletion dptb/nnops/loss.py
@@ -470,9 +470,163 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
index = data[AtomicDataDict.EDGE_TYPE_KEY].flatten(),
dim=0,
dim_size=len(self.idp.bond_types)
-            )[self.idp.mask_to_erme].mean().sqrt()
+            )[hopping_index][self.idp.mask_to_erme[hopping_index]].mean().sqrt()
overlap_loss *= 0.5

overlap_onsite_loss = data[AtomicDataDict.NODE_OVERLAP_KEY]-ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
overlap_onsite_loss = scatter_mean(
src = overlap_onsite_loss.abs(),
index = data[AtomicDataDict.ATOM_TYPE_KEY].flatten(),
dim=0,
dim_size=len(self.idp.type_names)
)[onsite_index][self.idp.mask_to_nrme[onsite_index]].mean() + scatter_mean(
src = overlap_onsite_loss**2,
index = data[AtomicDataDict.ATOM_TYPE_KEY].flatten(),
dim=0,
dim_size=len(self.idp.type_names)
)[onsite_index][self.idp.mask_to_nrme[onsite_index]].mean().sqrt()
overlap_loss += overlap_onsite_loss * 0.5

return (1/3) * (hopping_loss + onsite_loss + overlap_loss)
else:
return 0.5 * (onsite_loss + hopping_loss)


@Loss.register("hamil_wt")
class HamilLossWT(nn.Module):
def __init__(
self,
basis: Dict[str, Union[str, list]]=None,
idp: Union[OrbitalMapper, None]=None,
overlap: bool=False,
onsite_shift: bool=False,
onsite_weight: Union[float, int, dict]=1.,
hopping_weight: Union[float, int, dict]=1.,
dtype: Union[str, torch.dtype] = torch.float32,
device: Union[str, torch.device] = torch.device("cpu"),
**kwargs,
):

super(HamilLossWT, self).__init__()
self.overlap = overlap
self.device = device
self.onsite_shift = onsite_shift

if basis is not None:
self.idp = OrbitalMapper(basis, method="e3tb", device=self.device)
if idp is not None:
assert idp == self.idp, "The provided idp and basis arguments should define the same basis."
else:
assert idp is not None, "Either basis or idp should be provided."
self.idp = idp

self.onsite_weight = torch.ones(self.idp.num_types, device=self.device)
self.hopping_weight = torch.ones(len(self.idp.bond_types), device=self.device)
if isinstance(onsite_weight, (float, int)):
self.onsite_weight *= onsite_weight
elif isinstance(onsite_weight, dict):
for k, v in onsite_weight.items():
self.onsite_weight[self.idp.chemical_symbol_to_type[k]] = v
else:
raise TypeError("onsite_weight should be a float, int, or dict")

if isinstance(hopping_weight, (float, int)):
self.hopping_weight *= hopping_weight
elif isinstance(hopping_weight, dict):
for k, v in hopping_weight.items():
self.hopping_weight[self.idp.bond_to_type[k]] = v
else:
raise TypeError("hopping_weight should be a float, int, or dict")

self.onsite_weight = self.onsite_weight.unsqueeze(1)
self.hopping_weight = self.hopping_weight.unsqueeze(1)
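# after unsqueeze the weights have shapes (num_types, 1) and (num_bond_types, 1),
# so they broadcast over the reduced-matrix-element dimension of the
# (num_types, n_rme) per-type means that scatter_mean returns in forward()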

def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
# mask the data
# data[AtomicDataDict.NODE_FEATURES_KEY].masked_fill(~self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY]], 0.)
# data[AtomicDataDict.EDGE_FEATURES_KEY].masked_fill(~self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY]], 0.)

if self.onsite_shift:
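# estimate the mean difference mu between predicted and reference onsite
# diagonals, then shift the reference by mu * S (the overlap), per structure,
# so the loss is insensitive to a constant potential offset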
batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
# assert batch.max() == 0, "The onsite shift is only supported for batchsize=1."
mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
if batch.max() == 0: # when the batch contains a single frame
mu = mu.mean().detach()
ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
elif batch.max() >= 1:
slices = [data["__slices__"]["pos"][i]-data["__slices__"]["pos"][i-1] for i in range(1,len(data["__slices__"]["pos"]))]
slices = [0] + slices
ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
mu = mu.detach()
ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, device=self.device)
for i in range(1, batch.max().item()+1):
edge_mu_index[data["__slices__"]["edge_index"][i]:data["__slices__"]["edge_index"][i+1]] += i
ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu[edge_mu_index, None] * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]

onsite_loss = data[AtomicDataDict.NODE_FEATURES_KEY]-ref_data[AtomicDataDict.NODE_FEATURES_KEY]
onsite_index = data[AtomicDataDict.ATOM_TYPE_KEY].flatten().unique()
onsite_loss = (self.onsite_weight * scatter_mean(
src = onsite_loss.abs(),
index = data[AtomicDataDict.ATOM_TYPE_KEY].flatten(),
dim=0,
dim_size=len(self.idp.type_names)
)[onsite_index])[self.idp.mask_to_nrme[onsite_index]].mean() + (self.onsite_weight**2 * scatter_mean(
src = onsite_loss**2,
index = data[AtomicDataDict.ATOM_TYPE_KEY].flatten(),
dim=0,
dim_size=len(self.idp.type_names)
)[onsite_index])[self.idp.mask_to_nrme[onsite_index]].mean().sqrt()
onsite_loss *= 0.5

hopping_index = data[AtomicDataDict.EDGE_TYPE_KEY].flatten().unique()
hopping_loss = data[AtomicDataDict.EDGE_FEATURES_KEY]-ref_data[AtomicDataDict.EDGE_FEATURES_KEY]
hopping_loss = (self.hopping_weight * scatter_mean(
src = hopping_loss.abs(),
index = data[AtomicDataDict.EDGE_TYPE_KEY].flatten(),
dim=0,
dim_size=len(self.idp.bond_types)
)[hopping_index])[self.idp.mask_to_erme[hopping_index]].mean() + (self.hopping_weight**2 * scatter_mean(
src = hopping_loss**2,
index = data[AtomicDataDict.EDGE_TYPE_KEY].flatten(),
dim=0,
dim_size=len(self.idp.bond_types)
)[hopping_index])[self.idp.mask_to_erme[hopping_index]].mean().sqrt()
hopping_loss *= 0.5

if self.overlap:
overlap_loss = data[AtomicDataDict.EDGE_OVERLAP_KEY]-ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
overlap_loss = (self.hopping_weight * scatter_mean(
src = overlap_loss.abs(),
index = data[AtomicDataDict.EDGE_TYPE_KEY].flatten(),
dim=0,
dim_size=len(self.idp.bond_types)
)[hopping_index])[self.idp.mask_to_erme[hopping_index]].mean() + (self.hopping_weight **2 * scatter_mean(
src = overlap_loss**2,
index = data[AtomicDataDict.EDGE_TYPE_KEY].flatten(),
dim=0,
dim_size=len(self.idp.bond_types)
)[hopping_index])[self.idp.mask_to_erme[hopping_index]].mean().sqrt()
overlap_loss *= 0.5

overlap_onsite_loss = data[AtomicDataDict.NODE_OVERLAP_KEY]-ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
overlap_onsite_loss = (self.onsite_weight * scatter_mean(
src = overlap_onsite_loss.abs(),
index = data[AtomicDataDict.ATOM_TYPE_KEY].flatten(),
dim=0,
dim_size=len(self.idp.type_names)
)[onsite_index])[self.idp.mask_to_nrme[onsite_index]].mean() + ((self.onsite_weight ** 2) * scatter_mean(
src = overlap_onsite_loss**2,
index = data[AtomicDataDict.ATOM_TYPE_KEY].flatten(),
dim=0,
dim_size=len(self.idp.type_names)
)[onsite_index])[self.idp.mask_to_nrme[onsite_index]].mean().sqrt()
overlap_loss += overlap_onsite_loss * 0.5

return (1/3) * (hopping_loss + onsite_loss + overlap_loss)
else:
return 0.5 * (onsite_loss + hopping_loss)
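
For reference, a minimal usage sketch of the new loss class. The basis strings, the chemical-symbol weight keys, and the "Si-Si" bond-type key format below are illustrative assumptions, not taken from this diff:

from dptb.nnops.loss import HamilLossWT

# hypothetical silicon basis in dptb's e3tb convention (assumed spelling)
loss_fn = HamilLossWT(
    basis={"Si": ["3s", "3p", "d*"]},
    overlap=True,
    onsite_weight={"Si": 2.0},      # dict keyed by chemical symbol
    hopping_weight={"Si-Si": 0.5},  # dict keyed by bond type (assumed "A-B" format)
)

# data and ref_data are AtomicDataDict-style dicts from the dptb data pipeline;
# the forward pass returns a scalar tensor:
# loss = loss_fn(data, ref_data)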
6 changes: 6 additions & 0 deletions dptb/utils/argcheck.py
@@ -824,6 +824,11 @@ def loss_options():
Argument("onsite_shift", bool, optional=True, default=False, doc="Whether to use onsite shift in loss function. Default: False"),
]

wt = [
Argument("onsite_weight", [int, float, dict], optional=True, default=1., doc="Weight of the onsite loss term. A float/int applies to all atom types; a dict maps chemical symbols to weights. Default: 1."),
Argument("hopping_weight", [int, float, dict], optional=True, default=1., doc="Weight of the hopping loss term. A float/int applies to all bond types; a dict maps bond types to weights. Default: 1."),
]

eigvals = [
Argument("diff_on", bool, optional=True, default=False, doc="Whether to use random differences in loss function. Default: False"),
Argument("eout_weight", float, optional=True, default=0.01, doc="The weight of eigenvalue out of range. Default: 0.01"),
@@ -842,6 +847,7 @@ def loss_options():
Argument("skints", dict, sub_fields=skints),
Argument("hamil_abs", dict, sub_fields=hamil),
Argument("hamil_blas", dict, sub_fields=hamil),
Argument("hamil_wt", dict, sub_fields=hamil+wt),
], optional=False, doc=doc_method)
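
A training input could then select the weighted loss as sketched below (the loss_options/train nesting follows dptb's usual input layout, and the weight values are illustrative assumptions):

loss_options = {
    "train": {
        "method": "hamil_wt",
        "onsite_weight": {"Si": 2.0},   # hypothetical per-symbol weight
        "hopping_weight": 1.0,          # uniform weight for all bond types
    }
}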


