Skip to content

Commit

Permalink
update build model
Browse files Browse the repository at this point in the history
  • Loading branch information
floatingCatty committed Nov 17, 2023
1 parent 96d9076 commit 1bbaba5
Show file tree
Hide file tree
Showing 9 changed files with 269 additions and 113 deletions.
14 changes: 9 additions & 5 deletions dptb/data/AtomicData.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
AtomicDataDict.ATOM_TYPE_KEY,
AtomicDataDict.BATCH_KEY,
}

_DEFAULT_NODE_FIELDS: Set[str] = {
AtomicDataDict.POSITIONS_KEY,
AtomicDataDict.NODE_FEATURES_KEY,
Expand All @@ -43,9 +44,11 @@
AtomicDataDict.ATOM_TYPE_KEY,
AtomicDataDict.FORCE_KEY,
AtomicDataDict.PER_ATOM_ENERGY_KEY,
AtomicDataDict.NODE_HAMILTONIAN_KEY,
AtomicDataDict.NODE_OVERLAP_KEY,
AtomicDataDict.BATCH_KEY,
}

_DEFAULT_EDGE_FIELDS: Set[str] = {
AtomicDataDict.EDGE_CELL_SHIFT_KEY,
AtomicDataDict.EDGE_VECTORS_KEY,
Expand All @@ -56,6 +59,7 @@
AtomicDataDict.EDGE_CUTOFF_KEY,
AtomicDataDict.EDGE_ENERGY_KEY,
AtomicDataDict.EDGE_OVERLAP_KEY,
AtomicDataDict.EDGE_HAMILTONIAN_KEY,
AtomicDataDict.EDGE_TYPE_KEY,
}

Expand All @@ -67,7 +71,6 @@
AtomicDataDict.ENV_EMBEDDING_KEY,
AtomicDataDict.ENV_FEATURES_KEY,
AtomicDataDict.ENV_CUTOFF_KEY,

}

_DEFAULT_ONSITENV_FIELDS: Set[str] = {
Expand All @@ -79,6 +82,7 @@
AtomicDataDict.ONSITENV_FEATURES_KEY,
AtomicDataDict.ONSITENV_CUTOFF_KEY,
}

_DEFAULT_GRAPH_FIELDS: Set[str] = {
AtomicDataDict.TOTAL_ENERGY_KEY,
AtomicDataDict.STRESS_KEY,
Expand All @@ -91,6 +95,7 @@
AtomicDataDict.OVERLAP_KEY, # new
AtomicDataDict.ENERGY_EIGENVALUE_KEY # new
}

_NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS)
_EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS)
_ENV_FIELDS: Set[str] = set(_DEFAULT_ENV_FIELDS)
Expand Down Expand Up @@ -350,7 +355,6 @@ def from_points(
strict_self_interaction: bool = True,
cell=None,
pbc: Optional[PBC] = None,
reduce: Optional[bool] = True,
er_max: Optional[float] = None,
oer_max: Optional[float] = None,
**kwargs,
Expand Down Expand Up @@ -387,8 +391,8 @@ def from_points(
else:
assert len(pbc) == 3

# TODO: Need to add edge features and edge index.
# TODO: We can only compute the edge vector one times with the largest radial distance among [r_max, er_max, oer_max]

pos = torch.as_tensor(pos, dtype=torch.get_default_dtype())

edge_index, edge_cell_shift, cell = neighbor_list_and_relative_vec(
Expand All @@ -397,7 +401,7 @@ def from_points(
self_interaction=self_interaction,
strict_self_interaction=strict_self_interaction,
cell=cell,
reduce=reduce,
reduce=True,
atomic_numbers=kwargs.get("atomic_numbers", None),
pbc=pbc,
)
Expand Down
2 changes: 2 additions & 0 deletions dptb/data/_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@
EDGE_ENERGY_KEY: Final[str] = "edge_energy"
EDGE_OVERLAP_KEY: Final[str] = "edge_overlap"
NODE_OVERLAP_KEY: Final[str] = "node_overlap"
EDGE_HAMILTONIAN_KEY: Final[str] = "edge_hamiltonian"
NODE_HAMILTONIAN_KEY: Final[str] = "node_hamiltonian"

NODE_FEATURES_KEY: Final[str] = "node_features"
NODE_ATTRS_KEY: Final[str] = "node_attrs"
Expand Down
40 changes: 25 additions & 15 deletions dptb/data/use_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from _build import dataset_from_config\n",
"from build import dataset_from_config\n",
"from dptb.utils.config import Config"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -54,29 +54,39 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing dataset...\n",
"Done!\n"
]
}
],
"outputs": [],
"source": [
"dataset = dataset_from_config(config=config, prefix=\"dataset\")\n",
"\n",
"from dptb.data.dataloader import DataLoader\n",
"\n",
"dl = DataLoader(dataset, 3)\n",
"\n",
"\n",
"data = next(iter(dl))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 0, 8, 16, 24])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.to_dict()[\"ptr\"]"
]
},
{
"cell_type": "code",
"execution_count": 7,
Expand Down
Loading

0 comments on commit 1bbaba5

Please sign in to comment.