Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Cell-related gradient modifications #12

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setup_requires: List[str] = []
install_requires: List[str] = [
"ase>=3.18, <4.0.0", # Note that we require ase==3.21.1 for pytest.
"pymatgen",
"pymatgen>=2020.1.10",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The get_neighbor_list() function used in this library was introduced in pymatgen only recently, so a minimum version is now required.

]
extras_require: Dict[str, List[str]] = {
"develop": ["pysen[lint]==0.9.1"],
Expand Down
32 changes: 19 additions & 13 deletions tests/functions_tests/test_triplets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ def test_calc_triplets():
[1, 2, 3, 4, 5, 6, -1, -2, -3, -4, -5, -6], dtype=torch.float32, device=device
)
# print("shift", shift.shape)
triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets(
edge_index, shift
)
triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(edge_index, shift)
# print("triplet_node_index", triplet_node_index.shape, triplet_node_index)
# print("multiplicity", multiplicity.shape, multiplicity)
# print("triplet_shift", triplet_shift.shape, triplet_shift)
Expand All @@ -38,6 +36,20 @@ def test_calc_triplets():
)
assert multiplicity.shape == (n_triplets,)
assert torch.all(multiplicity.cpu() == torch.ones((n_triplets,), dtype=torch.float32))

assert torch.allclose(
edge_jk.cpu(),
torch.tensor([[7, 6], [8, 6], [8, 7], [9, 10], [9, 11], [11, 10]], dtype=torch.long),
)
# shift for edge `i->j`, `i->k`, `j->k`.
triplet_shift = torch.stack(
[
-shift[edge_jk[:, 0]],
-shift[edge_jk[:, 1]],
shift[edge_jk[:, 0]] - shift[edge_jk[:, 1]],
],
dim=1,
)
assert torch.allclose(
triplet_shift.cpu()[:, :, 0],
torch.tensor(
Expand All @@ -61,7 +73,7 @@ def test_calc_triplets_noshift():
edge_index = torch.tensor(
[[0, 1, 1, 3, 1, 2, 3, 0], [1, 2, 3, 0, 0, 1, 1, 3]], dtype=torch.long, device=device
)
triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets(
triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(
edge_index, dtype=torch.float64
)
# print("triplet_node_index", triplet_node_index.shape, triplet_node_index)
Expand All @@ -78,13 +90,7 @@ def test_calc_triplets_noshift():
assert multiplicity.shape == (n_triplets,)
assert multiplicity.dtype == torch.float64
assert torch.all(multiplicity.cpu() == torch.ones((n_triplets,), dtype=torch.float64))
assert torch.all(
triplet_shift.cpu()
== torch.zeros(
(n_triplets, 3, 3),
dtype=torch.float32,
)
)
assert torch.all(edge_jk.cpu() == torch.tensor([[1, 0], [2, 3]], dtype=torch.long))
assert torch.all(batch_triplets.cpu() == torch.zeros((n_triplets,), dtype=torch.long))


Expand All @@ -95,7 +101,7 @@ def test_calc_triplets_noshift():
def test_calc_triplets_no_triplets(edge_index):
# edge_index = edge_index.to("cuda:0")
# No triplet exist in this graph. Case1: No edge, Case 2 No triplets in this edge.
triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets(edge_index)
triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(edge_index)
# print("triplet_node_index", triplet_node_index.shape, triplet_node_index)
# print("multiplicity", multiplicity.shape, multiplicity)
# print("triplet_shift", triplet_shift.shape, triplet_shift)
Expand All @@ -104,7 +110,7 @@ def test_calc_triplets_no_triplets(edge_index):
# 0 triplets exist.
assert triplet_node_index.shape == (0, 3)
assert multiplicity.shape == (0,)
assert triplet_shift.shape == (0, 3, 3)
assert edge_jk.shape == (0, 2)
assert batch_triplets.shape == (0,)


Expand Down
6 changes: 5 additions & 1 deletion tests/test_torch_dftd3_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def _create_atoms() -> List[Atoms]:
atoms = molecule("CH3CH2OCH3")

slab = fcc111("Au", size=(2, 1, 3), vacuum=80.0)
slab.set_cell(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified the cell shape to expose potential bugs related to skewed cells.

slab.get_cell().array @ np.array([[1.0, 0.1, 0.2], [0.05, 1.0, 0.02], [0.03, 0.04, 1.0]])
)
slab.pbc = np.array([True, True, True])
return [atoms, slab]

Expand Down Expand Up @@ -58,6 +61,8 @@ def _assert_energy_force_stress_equal(calc1, calc2, atoms: Atoms):
atoms.calc = calc1
f1 = atoms.get_forces()
e1 = atoms.get_potential_energy()
if np.all(atoms.pbc == np.array([True, True, True])):
s1 = atoms.get_stress()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix the test bug reported in #7.


calc2.reset()
atoms.calc = calc2
Expand All @@ -66,7 +71,6 @@ def _assert_energy_force_stress_equal(calc1, calc2, atoms: Atoms):
assert np.allclose(e1, e2, atol=1e-4, rtol=1e-4)
assert np.allclose(f1, f2, atol=1e-5, rtol=1e-5)
if np.all(atoms.pbc == np.array([True, True, True])):
s1 = atoms.get_stress()
s2 = atoms.get_stress()
assert np.allclose(s1, s2, atol=1e-5, rtol=1e-5)

Expand Down
17 changes: 11 additions & 6 deletions torch_dftd/functions/dftd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,21 +283,25 @@ def edisp(
shift_abc = None if shift_abc is None else torch.cat([shift_abc, -shift_abc], dim=0)
with torch.no_grad():
# triplet_node_index, triplet_edge_index = calc_triplets_cycle(edge_index_abc, n_atoms, shift=shift_abc)
triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets(
edge_index_abc, shift=shift_abc, dtype=pos.dtype, batch_edge=batch_edge_abc
triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(
edge_index_abc,
shift=shift_abc,
dtype=pos.dtype,
batch_edge=batch_edge_abc,
)
batch_triplets = None if batch_edge is None else batch_triplets

# Apply `cnthr` cutoff threshold for r_kj
idx_j, idx_k = triplet_node_index[:, 1], triplet_node_index[:, 2]
ts2 = triplet_shift[:, 2]
ts2 = shift_abc[edge_jk[:, 0]] - shift_abc[edge_jk[:, 1]]
r_jk = calc_distances(pos, torch.stack([idx_j, idx_k], dim=0), cell, ts2, batch_triplets)
kj_within_cutoff = r_jk <= cnthr
del ts2

triplet_node_index = triplet_node_index[kj_within_cutoff]
multiplicity, triplet_shift, batch_triplets = (
multiplicity, edge_jk, batch_triplets = (
multiplicity[kj_within_cutoff],
triplet_shift[kj_within_cutoff],
edge_jk[kj_within_cutoff],
None if batch_triplets is None else batch_triplets[kj_within_cutoff],
)

Expand All @@ -306,7 +310,8 @@ def edisp(
triplet_node_index[:, 1],
triplet_node_index[:, 2],
)
ts0, ts1, ts2 = triplet_shift[:, 0], triplet_shift[:, 1], triplet_shift[:, 2]
ts0 = -shift_abc[edge_jk[:, 0]]
ts1 = -shift_abc[edge_jk[:, 1]]

r_ij = calc_distances(pos, torch.stack([idx_i, idx_j], dim=0), cell, ts0, batch_triplets)
r_ik = calc_distances(pos, torch.stack([idx_i, idx_k], dim=0), cell, ts1, batch_triplets)
Expand Down
8 changes: 1 addition & 7 deletions torch_dftd/functions/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,7 @@ def calc_distances(
Ri = pos[idx_i]
Rj = pos[idx_j]
if cell is not None:
if batch_edge is None:
# shift (n_edges, 3), cell (3, 3) -> offsets (n_edges, 3)
offsets = torch.mm(shift, cell)
else:
# shift (n_edges, 3), cell[batch] (n_atoms, 3, 3) -> offsets (n_edges, 3)
offsets = torch.bmm(shift[:, None, :], cell[batch_edge])[:, 0]
Rj += offsets
Rj += shift
# eps is to avoid Nan in backward when Dij = 0 with sqrt.
Dij = torch.sqrt(torch.sum((Ri - Rj) ** 2, dim=-1) + eps)
return Dij
41 changes: 25 additions & 16 deletions torch_dftd/functions/triplets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def calc_triplets(
i.e.: idx_i, idx_j, idx_k = triplet_node_index
multiplicity (Tensor): (n_triplets,) multiplicity indicates duplication of same triplet pair.
It only takes 1 in non-pbc, but it takes 2 or 3 in pbc case. dtype is specified in the argument.
triplet_shift (Tensor): (n_triplets, 3=(ij, ik, jk), 3=(xyz)) shift for edge `i->j`, `i->k`, `j->k`.
i.e.: idx_ij, idx_ik, idx_jk = triplet_shift
edge_jk (Tensor): (n_triplet_edges, 2=(j, k)) edge indices for j and k.
i.e.: idx_j, idx_k = triplet_shift
batch_triplets (Tensor): (n_triplets,) batch indices for each triplets.
"""
dst, src = edge_index
Expand All @@ -38,9 +38,12 @@ def calc_triplets(
dst = dst[sort_inds]

if shift is None:
shift = torch.zeros((src.shape[0], 3), dtype=dtype, device=edge_index.device)
edge_indices = torch.arange(src.shape[0], dtype=torch.long, device=edge_index.device)
# shift = torch.zeros((src.shape[0], 3), dtype=dtype, device=edge_index.device)
else:
shift = shift[is_larger][sort_inds]
edge_indices = torch.arange(shift.shape[0], dtype=torch.long, device=edge_index.device)
edge_indices = edge_indices[is_larger][sort_inds]
# shift = shift[is_larger][sort_inds]

if batch_edge is None:
batch_edge = torch.zeros(src.shape[0], dtype=torch.long, device=edge_index.device)
Expand All @@ -55,37 +58,38 @@ def calc_triplets(

if str(unique.device) == "cpu":
return _calc_triplets_core(
counts, unique, dst, shift, batch_edge, counts_cumsum, dtype=dtype
counts, unique, dst, edge_indices, batch_edge, counts_cumsum, dtype=dtype
)
else:
return _calc_triplets_core_gpu(
counts, unique, dst, shift, batch_edge, counts_cumsum, dtype=dtype
counts, unique, dst, edge_indices, batch_edge, counts_cumsum, dtype=dtype
)


def _calc_triplets_core(counts, unique, dst, shift, batch_edge, counts_cumsum, dtype):
def _calc_triplets_core(counts, unique, dst, edge_indices, batch_edge, counts_cumsum, dtype):
device = unique.device
n_triplets = torch.sum(counts * (counts - 1) / 2)
if n_triplets == 0:
# (n_triplet_edges, 3)
triplet_node_index = torch.zeros((0, 3), dtype=torch.long, device=device)
# (n_triplet_edges)
multiplicity = torch.zeros((0,), dtype=dtype, device=device)
# (n_triplet_edges, 3=(ij, ik, jk), 3=(xyz) )
triplet_shift = torch.zeros((0, 3, 3), dtype=dtype, device=device)
# (n_triplet_edges, 2=(j, k))
edge_jk = torch.zeros((0, 2), dtype=torch.long, device=device)
# (n_triplet_edges)
batch_triplets = torch.zeros((0,), dtype=torch.long, device=device)
return triplet_node_index, multiplicity, triplet_shift, batch_triplets
return triplet_node_index, multiplicity, edge_jk, batch_triplets

triplet_node_index_list = [] # (n_triplet_edges, 3)
shift_list = [] # (n_triplet_edges, 3, 3) represents shift vector
edge_jk_list = [] # (n_triplet_edges, 2) represents j and k indices
multiplicity_list = [] # (n_triplet_edges) represents multiplicity
batch_triplets_list = [] # (n_triplet_edges) represents batch index for triplets
for i in range(len(unique)):
_src = unique[i].item()
_n_edges = counts[i].item()
_dst = dst[counts_cumsum[i] : counts_cumsum[i + 1]]
_shift = shift[counts_cumsum[i] : counts_cumsum[i + 1]]
# _shift = shift[counts_cumsum[i] : counts_cumsum[i + 1]]
_offset = counts_cumsum[i].item()
_batch_index = batch_edge[counts_cumsum[i]].item()
for j in range(_n_edges - 1):
for k in range(j + 1, _n_edges):
Expand All @@ -101,8 +105,12 @@ def _calc_triplets_core(counts, unique, dst, shift, batch_edge, counts_cumsum, d
_j, _k = k, j

triplet_node_index_list.append([_src, _dst0, _dst1])
shift_list.append(
torch.stack([-_shift[_j], -_shift[_k], _shift[_j] - _shift[_k]], dim=0)
edge_jk_list.append(
[
_offset + _j,
_offset + _k,
]
# torch.stack([-_shift[_j], -_shift[_k], _shift[_j] - _shift[_k]], dim=0)
)
# --- multiplicity ---
if _dst0 == _dst1:
Expand All @@ -126,7 +134,8 @@ def _calc_triplets_core(counts, unique, dst, shift, batch_edge, counts_cumsum, d
triplet_node_index = torch.as_tensor(triplet_node_index_list, device=device)
# (n_triplet_edges)
multiplicity = torch.as_tensor(multiplicity_list, dtype=dtype, device=device)
# (n_triplet_edges, 2=(j, k))
edge_jk = edge_indices[torch.tensor(edge_jk_list, dtype=torch.long, device=device)]
# (n_triplet_edges, 3=(ij, ik, jk), 3=(xyz) )
triplet_shift = torch.stack(shift_list, dim=0)
batch_triplets = torch.as_tensor(batch_triplets_list, dtype=torch.long, device=device)
return triplet_node_index, multiplicity, triplet_shift, batch_triplets
return triplet_node_index, multiplicity, edge_jk, batch_triplets
31 changes: 12 additions & 19 deletions torch_dftd/functions/triplets_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def _cupy2torch(array: cp.ndarray) -> Tensor:

if _cupy_available:
_calc_triplets_core_gpu_kernel = cp.ElementwiseKernel(
"raw int64 counts, raw int64 unique, raw int64 dst, raw T shift, raw int64 batch_edge, raw int64 counts_cumsum",
"raw int64 triplet_node_index, raw T multiplicity, raw T triplet_shift, raw int64 batch_triplets",
"raw int64 counts, raw int64 unique, raw int64 dst, raw int64 edge_indices, raw int64 batch_edge, raw int64 counts_cumsum",
"raw int64 triplet_node_index, raw T multiplicity, raw int64 edge_jk, raw int64 batch_triplets",
"""
long long n_unique = unique.size();
long long a = 0;
Expand Down Expand Up @@ -100,16 +100,9 @@ def _cupy2torch(array: cp.ndarray) -> Tensor:
}
}

// --- triplet_shift ---
triplet_shift[9 * i] = -shift[3 * (_offset + b)];
triplet_shift[9 * i + 1] = -shift[3 * (_offset + b) + 1];
triplet_shift[9 * i + 2] = -shift[3 * (_offset + b) + 2];
triplet_shift[9 * i + 3] = -shift[3 * (_offset + c)];
triplet_shift[9 * i + 4] = -shift[3 * (_offset + c) + 1];
triplet_shift[9 * i + 5] = -shift[3 * (_offset + c) + 2];
triplet_shift[9 * i + 6] = shift[3 * (_offset + b)] - shift[3 * (_offset + c)];
triplet_shift[9 * i + 7] = shift[3 * (_offset + b) + 1] - shift[3 * (_offset + c) + 1];
triplet_shift[9 * i + 8] = shift[3 * (_offset + b) + 2] - shift[3 * (_offset + c) + 2];
// --- edge_jk ---
edge_jk[2 * i] = edge_indices[_offset + b];
edge_jk[2 * i + 1] = edge_indices[_offset + c];

// --- batch_triplets ---
batch_triplets[i] = _batch_index;
Expand All @@ -124,7 +117,7 @@ def _calc_triplets_core_gpu(
counts: Tensor,
unique: Tensor,
dst: Tensor,
shift: Tensor,
edge_indices: Tensor,
batch_edge: Tensor,
counts_cumsum: Tensor,
dtype: torch.dtype = torch.float32,
Expand All @@ -140,26 +133,26 @@ def _calc_triplets_core_gpu(
triplet_node_index = torch.zeros((n_triplets, 3), dtype=torch.long, device=device)
# (n_triplet_edges)
multiplicity = torch.zeros((n_triplets,), dtype=dtype, device=device)
# (n_triplet_edges, 3=(ij, ik, jk), 3=(xyz) )
triplet_shift = torch.zeros((n_triplets, 3, 3), dtype=dtype, device=device)
# (n_triplet_edges, 2=(j, k))
edge_jk = torch.zeros((n_triplets, 2), dtype=torch.long, device=device)
# (n_triplet_edges)
batch_triplets = torch.zeros((n_triplets,), dtype=torch.long, device=device)
if n_triplets == 0:
return triplet_node_index, multiplicity, triplet_shift, batch_triplets
return triplet_node_index, multiplicity, edge_jk, batch_triplets

_calc_triplets_core_gpu_kernel(
_torch2cupy(counts),
_torch2cupy(unique),
_torch2cupy(dst),
_torch2cupy(shift),
_torch2cupy(edge_indices),
_torch2cupy(batch_edge),
_torch2cupy(counts_cumsum),
# n_triplets,
_torch2cupy(triplet_node_index),
_torch2cupy(multiplicity),
_torch2cupy(triplet_shift),
_torch2cupy(edge_jk),
_torch2cupy(batch_triplets),
size=n_triplets,
)
# torch tensor buffer is already modified in above cupy functions.
return triplet_node_index, multiplicity, triplet_shift, batch_triplets
return triplet_node_index, multiplicity, edge_jk, batch_triplets
17 changes: 15 additions & 2 deletions torch_dftd/nn/base_dftd_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def calc_energy_and_forces(
# We need to explicitly include this dependency to calculate cell gradient
# for stress computation.
# pos is assumed to be inside "cell", so relative position `rel_pos` lies between 0~1.
assert isinstance(shift, Tensor)
if batch is None:
rel_pos = torch.mm(pos, torch.inverse(cell))
pos = torch.mm(rel_pos.detach(), cell)
Expand All @@ -127,6 +128,7 @@ def calc_energy_and_forces(

pos.retain_grad()
cell.retain_grad()
shift.retain_grad()
E_disp = self.calc_energy_batch(
Z, pos, edge_index, cell, pbc, shift, batch, batch_edge, damping=damping
)
Expand All @@ -150,13 +152,24 @@ def calc_energy_and_forces(
# Get stress in Voigt notation (xx, yy, zz, yz, xz, xy)
if batch is None:
cell_volume = torch.det(cell).abs()
stress = torch.mm(cell.grad, cell.T) / cell_volume
cell_grad = torch.mm(torch.inverse(cell.T), torch.mm(pos.T, pos.grad))
cell_grad += torch.mm(torch.inverse(cell.T), torch.mm(shift.T, shift.grad))
stress = torch.mm(cell_grad, cell.T) / cell_volume
stress = stress.view(-1)[[0, 4, 8, 5, 2, 1]]
results_list[0]["stress"] = stress.detach().cpu().numpy()
else:
cell_volume = torch.det(cell).abs()
cell_T = cell.permute(0, 2, 1)
# cell (bs, 3, 3)
stress = torch.bmm(cell.grad, cell.permute(0, 2, 1)) / cell_volume[:, None, None]
edge_grad = shift.new_zeros((n_graphs, 3, 3))
edge_grad.scatter_add_(
0,
batch_edge.view(batch_edge.size()[0], 1, 1).expand(batch_edge.size()[0], 3, 3),
shift[:, :, None] * shift.grad[:, None, :],
)
cell_grad = cell.grad
cell_grad += torch.bmm(torch.inverse(cell_T), edge_grad)
stress = torch.bmm(cell_grad, cell_T) / cell_volume[:, None, None]
stress = stress.view(-1, 9)[:, [0, 4, 8, 5, 2, 1]].detach().cpu().numpy()
for i in range(n_graphs):
results_list[i]["stress"] = stress[i]
Expand Down
6 changes: 4 additions & 2 deletions torch_dftd/nn/dftd2_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@ def calc_energy_batch(
damping: str = "zero",
) -> Tensor:
"""Forward computation to calculate atomic wise dispersion energy"""
shift = pos.new_zeros((edge_index.size()[1], 3, 3)) if shift is None else shift
pos_bohr = pos / d3_autoang # angstrom -> bohr
if cell is None:
cell_bohr = None
cell_bohr: Optional[Tensor] = None
else:
cell_bohr = cell / d3_autoang # angstrom -> bohr
r = calc_distances(pos_bohr, edge_index, cell_bohr, shift, batch_edge=batch_edge)
shift_bohr = shift / d3_autoang # angstrom -> bohr
r = calc_distances(pos_bohr, edge_index, cell_bohr, shift_bohr, batch_edge=batch_edge)

# E_disp (n_graphs,): Energy in eV unit
E_disp = d3_autoev * edisp_d2(
Expand Down
Loading