Mergev2 merge data nequip refactoring into main #61

Merged
merged 97 commits into from
Feb 2, 2024
97 commits
baf3d20
add data
floatingCatty Oct 17, 2023
7a000ca
adapt data and nn module of nequip into deeptb
floatingCatty Oct 19, 2023
a3b6525
just modify some imports
floatingCatty Oct 21, 2023
32a11fb
Merge branch 'main' of github.com:deepmodeling/DeePTB into data_nequip
floatingCatty Oct 26, 2023
ff8f325
update torch-geometry
floatingCatty Oct 28, 2023
9df08f8
add kpoint eigenvalue support
floatingCatty Oct 30, 2023
20edbd2
add support for nested tensor
floatingCatty Oct 31, 2023
d65a115
update
floatingCatty Oct 31, 2023
1c969e9
update data and add batchlize hamiltonian
floatingCatty Nov 3, 2023
94dda6c
update se3 rotation
floatingCatty Nov 3, 2023
21771e4
update test
floatingCatty Nov 3, 2023
7e69899
Merge branch 'main' of github.com:deepmodeling/DeePTB into data_nequip
floatingCatty Nov 3, 2023
ba7e83f
update
floatingCatty Nov 3, 2023
2854864
debug e3
floatingCatty Nov 3, 2023
9f3fbf9
update hamileig
floatingCatty Nov 3, 2023
0dfdd5d
Merge branch 'main' of github.com:deepmodeling/DeePTB into data_nequip
floatingCatty Nov 3, 2023
2379217
delete nequip nn and write our own based on PyG
floatingCatty Nov 4, 2023
775dac6
update nn
floatingCatty Nov 6, 2023
78534e8
nn refactor, write hamiltonian and hop function
floatingCatty Nov 8, 2023
40231a3
update sk hamiltonian and onsite function
floatingCatty Nov 8, 2023
ed22f3e
refactor sktb and add register for descriptor
floatingCatty Nov 9, 2023
feea761
update param prototype and dptb
floatingCatty Nov 9, 2023
a13e8f1
refactor index mapping to data transform
floatingCatty Nov 10, 2023
26afa52
debug sktb and e3tb module
floatingCatty Nov 11, 2023
c6b3b2e
finish debuging sk and e3
floatingCatty Nov 12, 2023
f9f1ed5
update data interfaces
floatingCatty Nov 12, 2023
6dda4fd
update r2k and transform
floatingCatty Nov 13, 2023
aa8710f
remove dash line in file names
floatingCatty Nov 14, 2023
5acd67d
fnishied debugging deeptb module
floatingCatty Nov 14, 2023
07c0a3c
finish debugging hr2hk
floatingCatty Nov 14, 2023
03dafc3
update overlap support
floatingCatty Nov 15, 2023
c57107f
update base trainer and example quantities
floatingCatty Nov 16, 2023
96d9076
Merge branch 'main' of github.com:deepmodeling/DeePTB into data_nequip
floatingCatty Nov 16, 2023
1bbaba5
update build model
floatingCatty Nov 17, 2023
7416057
update trainer
floatingCatty Nov 18, 2023
9e401da
update pyproject.toml dependencies
floatingCatty Nov 18, 2023
87b485a
update bond reduction and self-interaction
floatingCatty Nov 20, 2023
fc9c067
debug nnsk
floatingCatty Nov 20, 2023
43926f6
nnsk run succeed, add from v1 json model
floatingCatty Nov 20, 2023
f2653fe
add nnsk test example of AlAs coupond system
floatingCatty Nov 21, 2023
959a196
Add 'ABACUSDataset' in data module (#9)
SharpLonde Nov 21, 2023
5c16677
debug new dptb and trainer
floatingCatty Nov 21, 2023
b58e472
Merge branch 'data_nequip' of github.com:floatingCatty/DeePTB into da…
floatingCatty Nov 21, 2023
9885430
debug datasets
floatingCatty Nov 22, 2023
f27cb78
pass cmd line train mod to new model and data
floatingCatty Nov 22, 2023
014ff21
add some comments in neighbor_list_and_relative_vec.
QG-phy Nov 23, 2023
0e98830
Merge pull request #10 from floatingCatty/pr/44
QG-phy Nov 23, 2023
9b67da3
add overlap fitting support
floatingCatty Nov 23, 2023
9856bdb
update baseline descriptor and debug validationer
floatingCatty Nov 25, 2023
8e3950c
Merge branch 'data_nequip' of github.com:floatingCatty/DeePTB into da…
floatingCatty Nov 25, 2023
ebf326f
update e3deeph module
floatingCatty Nov 29, 2023
2db61f2
update deephe3 module
floatingCatty Nov 30, 2023
8f85806
Added ABACUSInMemoryDataset in data module (#11)
SharpLonde Dec 1, 2023
5a0de48
update dataset and add deephdataset
floatingCatty Dec 1, 2023
439c622
Merge branch 'data_nequip' of github.com:floatingCatty/DeePTB into da…
floatingCatty Dec 1, 2023
044d7f9
gpu support and debugging
floatingCatty Dec 5, 2023
37f8058
add dptb+nnsk mix model, debugging build, restart
floatingCatty Dec 5, 2023
9a8f604
align run.py, test.py, main.py
floatingCatty Dec 5, 2023
9018a21
debugging
floatingCatty Dec 5, 2023
e58fffe
final
floatingCatty Dec 6, 2023
589d7e4
add new model backbone on allegro
floatingCatty Dec 12, 2023
d02323e
add new e3 embeding and lr schedular
floatingCatty Dec 18, 2023
a9f4a75
Added `DefaultDataset` (#12)
SharpLonde Dec 19, 2023
b52ccb9
aggregating new data class
floatingCatty Dec 19, 2023
d4a458d
debug plugin savor and support atom specific cutoffs
floatingCatty Dec 25, 2023
9d74673
refactor bond reduction and rme parameterization
floatingCatty Dec 28, 2023
51f7dc1
add E3 fitting analysis and E3 rescale
floatingCatty Dec 29, 2023
f069551
update LossAnalysis and e3baseline model
floatingCatty Dec 31, 2023
5aae7ff
update band calc and debug nnsk add orbitals
floatingCatty Jan 2, 2024
cebe7d0
update datatype switch
floatingCatty Jan 2, 2024
f84c016
Unified dataset IO (#13)
SharpLonde Jan 4, 2024
1b95c32
update e3 descriptor and OrbitalMapper
floatingCatty Jan 5, 2024
4e2ead2
Merge branch 'data_nequip' of github.com:floatingCatty/DeePTB into da…
floatingCatty Jan 5, 2024
398139a
Bug fix in reading trajectory data (#15)
SharpLonde Jan 6, 2024
a17871c
add comment and complete eig loss
floatingCatty Jan 9, 2024
2023be5
update new embedding and dependencies
floatingCatty Jan 18, 2024
1ea8ef4
New version of `E3statistics` (#17)
SharpLonde Jan 18, 2024
2d0da97
adding statistics initialization
floatingCatty Jan 19, 2024
dfb14fd
debug nnsk batchlization and eigenvalues loading
floatingCatty Jan 20, 2024
f79e637
debug nnsk
floatingCatty Jan 22, 2024
ab62be9
optimizing saving best checkpoint
floatingCatty Jan 22, 2024
62d7db0
Pr/44 (#19)
QG-phy Jan 22, 2024
73558bc
debug nnsk add orbital and strain
floatingCatty Jan 22, 2024
a9e3685
Merge branch 'data_nequip' of https://github.com/floatingCatty/DeePTB…
floatingCatty Jan 22, 2024
3e07ce2
update `.npy` files loading procedure in DefaultDataset (#18)
SharpLonde Jan 23, 2024
bb2f435
optimizing init and restart param loading
floatingCatty Jan 23, 2024
c7b9d51
Merge branch 'data_nequip' of https://github.com/floatingCatty/DeePTB…
floatingCatty Jan 23, 2024
6c7dbb0
update nnsk push thr
floatingCatty Jan 23, 2024
e4ff19f
update mix model param and deeptb sktb param
floatingCatty Jan 30, 2024
c4b586e
BUG FIX in loading `kpoints.npy` files with `ndim==3` (#20)
SharpLonde Feb 1, 2024
e3563f9
refactor test
floatingCatty Feb 1, 2024
deebcb6
update nrl
floatingCatty Feb 1, 2024
a16c699
Merge branch 'merge_v2' into main
floatingCatty Feb 2, 2024
859e253
denote run
floatingCatty Feb 2, 2024
5ca0ab0
add ref and val batch size
floatingCatty Feb 2, 2024
834f144
update assert bond length less than 83
floatingCatty Feb 2, 2024
f17e7f1
update param init std in nnsk
floatingCatty Feb 2, 2024
993 changes: 993 additions & 0 deletions dptb/data/AtomicData.py

Large diffs are not rendered by default.

233 changes: 233 additions & 0 deletions dptb/data/AtomicDataDict.py
@@ -0,0 +1,233 @@
"""nequip.data.jit: TorchScript functions for dealing with AtomicData.

These TorchScript functions operate on ``Dict[str, torch.Tensor]`` representations
of the ``AtomicData`` class which are produced by ``AtomicData.to_AtomicDataDict()``.

Authors: Albert Musaelian
"""
from typing import Dict, Any

import torch
import torch.jit

from e3nn import o3

# Make the keys available in this module
from ._keys import * # noqa: F403, F401

# Also import the module to use in TorchScript, this is a hack to avoid bug:
# https://github.com/pytorch/pytorch/issues/52312
from . import _keys

# Define a type alias
Type = Dict[str, torch.Tensor]


def validate_keys(keys, graph_required=True):
    # Validate combinations
    if graph_required:
        if not (_keys.POSITIONS_KEY in keys and _keys.EDGE_INDEX_KEY in keys):
            raise KeyError("At least pos and edge_index must be supplied")
    if _keys.EDGE_CELL_SHIFT_KEY in keys and "cell" not in keys:
        raise ValueError("If `edge_cell_shift` given, `cell` must be given.")


_SPECIAL_IRREPS = [None]


def _fix_irreps_dict(d: Dict[str, Any]):
    return {k: (i if i in _SPECIAL_IRREPS else o3.Irreps(i)) for k, i in d.items()}


def _irreps_compatible(ir1: Dict[str, o3.Irreps], ir2: Dict[str, o3.Irreps]):
    return all(ir1[k] == ir2[k] for k in ir1 if k in ir2)


@torch.jit.script
def with_edge_vectors(data: Type, with_lengths: bool = True) -> Type:
    """Compute the edge displacement vectors for a graph.

    If ``data.pos.requires_grad`` and/or ``data.cell.requires_grad``, this
    method will return edge vectors correctly connected in the autograd graph.

    Returns:
        The same ``data`` dict, with the [n_edges, 3] edge displacement
        vectors (and, if ``with_lengths``, their [n_edges] lengths) added.
    """
    if _keys.EDGE_VECTORS_KEY in data:
        if with_lengths and _keys.EDGE_LENGTH_KEY not in data:
            data[_keys.EDGE_LENGTH_KEY] = torch.linalg.norm(
                data[_keys.EDGE_VECTORS_KEY], dim=-1
            )

        return data
    else:
        # Build it dynamically
        # Note that this is
        # (1) backwardable, because everything (pos, cell, shifts)
        #     is Tensors.
        # (2) works on a Batch constructed from AtomicData
        pos = data[_keys.POSITIONS_KEY]
        edge_index = data[_keys.EDGE_INDEX_KEY]
        edge_vec = pos[edge_index[1]] - pos[edge_index[0]]
        if _keys.CELL_KEY in data:
            # ^ note that to save time we don't check that the edge_cell_shifts are trivial if no cell is provided; we just assume they are either not present or all zero.
            # -1 gives a batch dim no matter what
            cell = data[_keys.CELL_KEY].view(-1, 3, 3)
            edge_cell_shift = data[_keys.EDGE_CELL_SHIFT_KEY]
            if cell.shape[0] > 1:
                batch = data[_keys.BATCH_KEY]
                # Cell has a batch dimension
                # note the ASE cell vectors as rows convention
                edge_vec = edge_vec + torch.einsum(
                    "ni,nij->nj", edge_cell_shift, cell[batch[edge_index[0]]]
                )
                # TODO: is there a more efficient way to do the above without
                # creating an [n_edge] and [n_edge, 3, 3] tensor?
            else:
                # Cell has either no batch dimension, or a useless one,
                # so we can avoid creating the large intermediate cell tensor.
                # Note that we do NOT check that the batch array, if it is present,
                # is trivial, but this does need to be consistent.
                edge_vec = edge_vec + torch.einsum(
                    "ni,ij->nj",
                    edge_cell_shift,
                    cell.squeeze(0),  # remove batch dimension
                )

        data[_keys.EDGE_VECTORS_KEY] = edge_vec
        if with_lengths:
            data[_keys.EDGE_LENGTH_KEY] = torch.linalg.norm(edge_vec, dim=-1)
        return data

@torch.jit.script
def with_env_vectors(data: Type, with_lengths: bool = True) -> Type:
    """Compute the environment displacement vectors for a graph.

    If ``data.pos.requires_grad`` and/or ``data.cell.requires_grad``, this
    method will return environment vectors correctly connected in the autograd graph.

    Returns:
        The same ``data`` dict, with the [n_env, 3] environment displacement
        vectors (and, if ``with_lengths``, their [n_env] lengths) added.
    """
    if _keys.ENV_VECTORS_KEY in data:
        if with_lengths and _keys.ENV_LENGTH_KEY not in data:
            data[_keys.ENV_LENGTH_KEY] = torch.linalg.norm(
                data[_keys.ENV_VECTORS_KEY], dim=-1
            )
        return data
    else:
        # Build it dynamically
        # Note that this is
        # (1) backwardable, because everything (pos, cell, shifts)
        #     is Tensors.
        # (2) works on a Batch constructed from AtomicData
        pos = data[_keys.POSITIONS_KEY]
        env_index = data[_keys.ENV_INDEX_KEY]
        env_vec = pos[env_index[1]] - pos[env_index[0]]
        if _keys.CELL_KEY in data:
            # ^ note that to save time we don't check that the env_cell_shifts are trivial if no cell is provided; we just assume they are either not present or all zero.
            # -1 gives a batch dim no matter what
            cell = data[_keys.CELL_KEY].view(-1, 3, 3)
            env_cell_shift = data[_keys.ENV_CELL_SHIFT_KEY]
            if cell.shape[0] > 1:
                batch = data[_keys.BATCH_KEY]
                # Cell has a batch dimension
                # note the ASE cell vectors as rows convention
                env_vec = env_vec + torch.einsum(
                    "ni,nij->nj", env_cell_shift, cell[batch[env_index[0]]]
                )
                # TODO: is there a more efficient way to do the above without
                # creating an [n_env] and [n_env, 3, 3] tensor?
            else:
                # Cell has either no batch dimension, or a useless one,
                # so we can avoid creating the large intermediate cell tensor.
                # Note that we do NOT check that the batch array, if it is present,
                # is trivial, but this does need to be consistent.
                env_vec = env_vec + torch.einsum(
                    "ni,ij->nj",
                    env_cell_shift,
                    cell.squeeze(0),  # remove batch dimension
                )
        data[_keys.ENV_VECTORS_KEY] = env_vec
        if with_lengths:
            data[_keys.ENV_LENGTH_KEY] = torch.linalg.norm(env_vec, dim=-1)
        return data

@torch.jit.script
def with_onsitenv_vectors(data: Type, with_lengths: bool = True) -> Type:
    """Compute the onsite-environment displacement vectors for a graph.

    If ``data.pos.requires_grad`` and/or ``data.cell.requires_grad``, this
    method will return onsite-environment vectors correctly connected in the autograd graph.

    Returns:
        The same ``data`` dict, with the [n_onsitenv, 3] onsite-environment
        displacement vectors (and, if ``with_lengths``, their lengths) added.
    """
    if _keys.ONSITENV_VECTORS_KEY in data:
        if with_lengths and _keys.ONSITENV_LENGTH_KEY not in data:
            data[_keys.ONSITENV_LENGTH_KEY] = torch.linalg.norm(
                data[_keys.ONSITENV_VECTORS_KEY], dim=-1
            )
        return data
    else:
        # Build it dynamically
        # Note that this is
        # (1) backwardable, because everything (pos, cell, shifts)
        #     is Tensors.
        # (2) works on a Batch constructed from AtomicData
        pos = data[_keys.POSITIONS_KEY]
        env_index = data[_keys.ONSITENV_INDEX_KEY]
        env_vec = pos[env_index[1]] - pos[env_index[0]]
        if _keys.CELL_KEY in data:
            # ^ note that to save time we don't check that the onsitenv cell shifts are trivial if no cell is provided; we just assume they are either not present or all zero.
            # -1 gives a batch dim no matter what
            cell = data[_keys.CELL_KEY].view(-1, 3, 3)
            env_cell_shift = data[_keys.ONSITENV_CELL_SHIFT_KEY]
            if cell.shape[0] > 1:
                batch = data[_keys.BATCH_KEY]
                # Cell has a batch dimension
                # note the ASE cell vectors as rows convention
                env_vec = env_vec + torch.einsum(
                    "ni,nij->nj", env_cell_shift, cell[batch[env_index[0]]]
                )
                # TODO: is there a more efficient way to do the above without
                # creating an [n_onsitenv] and [n_onsitenv, 3, 3] tensor?
            else:
                # Cell has either no batch dimension, or a useless one,
                # so we can avoid creating the large intermediate cell tensor.
                # Note that we do NOT check that the batch array, if it is present,
                # is trivial, but this does need to be consistent.
                env_vec = env_vec + torch.einsum(
                    "ni,ij->nj",
                    env_cell_shift,
                    cell.squeeze(0),  # remove batch dimension
                )
        data[_keys.ONSITENV_VECTORS_KEY] = env_vec
        if with_lengths:
            data[_keys.ONSITENV_LENGTH_KEY] = torch.linalg.norm(env_vec, dim=-1)
        return data


@torch.jit.script
def with_batch(data: Type) -> Type:
    """Get batch Tensor.

    If this AtomicDataPrimitive has no ``batch``, one of all zeros will be
    allocated and returned.
    """
    if _keys.BATCH_KEY in data:
        return data
    else:
        pos = data[_keys.POSITIONS_KEY]
        batch = torch.zeros(len(pos), dtype=torch.long, device=pos.device)
        data[_keys.BATCH_KEY] = batch
        # ugly way to make a tensor of [0, len(pos)], but it avoids transfers or casts
        data[_keys.BATCH_PTR_KEY] = torch.arange(
            start=0,
            end=len(pos) + 1,
            step=len(pos),
            dtype=torch.long,
            device=pos.device,
        )

        return data
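
The periodic-shift arithmetic in `with_edge_vectors` (and its `env`/`onsitenv` variants) can be checked by hand on a tiny system. The sketch below is illustrative only: it uses plain Python lists instead of tensors so each term is visible, and the helper name `shifted_edge_vector`, the two-atom cubic cell, and the variable names are assumptions made for this example, not part of the diff. It follows the same conventions as the code above: cell vectors stored as rows, `edge_index[0]` as source atoms and `edge_index[1]` as target atoms.

```python
from math import sqrt
from typing import List, Sequence


def shifted_edge_vector(
    pos_src: Sequence[float],
    pos_dst: Sequence[float],
    cell_shift: Sequence[float],
    cell: Sequence[Sequence[float]],
) -> List[float]:
    """Return pos_dst - pos_src + cell_shift @ cell (cell vectors as rows)."""
    return [
        pos_dst[j] - pos_src[j] + sum(cell_shift[i] * cell[i][j] for i in range(3))
        for j in range(3)
    ]


# Hypothetical example system: two atoms in a cubic cell of side 4.
cell = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]
positions = [[0.5, 0.0, 0.0], [3.5, 0.0, 0.0]]

# Row 0 holds source atoms, row 1 holds target atoms.
edge_index = [[0, 1], [1, 0]]
# Each edge crosses the cell boundary along x, so it carries a +/-1 shift.
edge_cell_shift = [[-1, 0, 0], [1, 0, 0]]

edge_vectors = [
    shifted_edge_vector(positions[src], positions[dst], shift, cell)
    for src, dst, shift in zip(edge_index[0], edge_index[1], edge_cell_shift)
]
edge_lengths = [sqrt(sum(x * x for x in vec)) for vec in edge_vectors]

# Default batch bookkeeping for a single structure, mirroring `with_batch`:
# every atom belongs to graph 0 and the pointer array is [0, n_atoms].
n_atoms = len(positions)
batch = [0] * n_atoms
batch_ptr = list(range(0, n_atoms + 1, n_atoms))
```

Without the shift, the two atoms would appear 3 units apart; with it, both edges have length 1, which is the distance through the periodic boundary. The `range(0, n + 1, n)` call is the list analogue of the `torch.arange(start=0, end=len(pos) + 1, step=len(pos))` trick used in `with_batch`.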
49 changes: 49 additions & 0 deletions dptb/data/__init__.py
@@ -0,0 +1,49 @@
from .AtomicData import (
    AtomicData,
    PBC,
    register_fields,
    deregister_fields,
    _register_field_prefix,
    _NODE_FIELDS,
    _EDGE_FIELDS,
    _GRAPH_FIELDS,
    _LONG_FIELDS,
)
from .dataset import (
    AtomicDataset,
    AtomicInMemoryDataset,
    NpzDataset,
    ASEDataset,
    HDF5Dataset,
    ABACUSDataset,
    ABACUSInMemoryDataset,
    DefaultDataset,
)
from .dataloader import DataLoader, Collater, PartialSampler
from .build import dataset_from_config
from .test_data import EMTTestDataset

# `__all__` must list names as strings; listing the objects themselves
# makes `from dptb.data import *` raise a TypeError.
__all__ = [
    "AtomicData",
    "PBC",
    "register_fields",
    "deregister_fields",
    "_register_field_prefix",
    "AtomicDataset",
    "AtomicInMemoryDataset",
    "NpzDataset",
    "ASEDataset",
    "HDF5Dataset",
    "ABACUSDataset",
    "ABACUSInMemoryDataset",
    "DefaultDataset",
    "DataLoader",
    "Collater",
    "PartialSampler",
    "dataset_from_config",
    "_NODE_FIELDS",
    "_EDGE_FIELDS",
    "_GRAPH_FIELDS",
    "_LONG_FIELDS",
    "EMTTestDataset",
]
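
Python's star-import machinery expects `__all__` to contain names as strings. A minimal sketch of the difference, using a hypothetical throwaway module `demo_pkg` with a dummy class `Thing` (neither is part of DeePTB), might look like this:

```python
import sys
import types

# Hypothetical stand-in for a package `__init__` module.
demo = types.ModuleType("demo_pkg")
demo.Thing = type("Thing", (), {})
sys.modules["demo_pkg"] = demo


def star_import(all_value):
    """Set `__all__` on the demo module and run `from demo_pkg import *`."""
    demo.__all__ = all_value
    namespace = {}
    exec("from demo_pkg import *", namespace)
    return namespace


# Case 1: `__all__` holds the class object itself rather than its name.
try:
    star_import([demo.Thing])
    objects_accepted = True
except TypeError:
    objects_accepted = False

# Case 2: `__all__` holds the name as a string.
names_namespace = star_import(["Thing"])
```

Explicit imports such as `from dptb.data import AtomicData` do not consult `__all__`, so this only matters for star imports and for tools that read `__all__` to discover the public API.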