Skip to content

Commit

Permalink
finish cpu pass
Browse files Browse the repository at this point in the history
  • Loading branch information
CaRoLZhangxy committed Mar 19, 2024
1 parent b78f457 commit 49f0204
Show file tree
Hide file tree
Showing 22 changed files with 473 additions and 121 deletions.
3 changes: 2 additions & 1 deletion deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from deepmd.pt.utils.env import (
DEVICE,
load_op
)
from deepmd.pt.utils.finetune import (
change_finetune_model_params,
Expand Down Expand Up @@ -301,7 +302,7 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
set_log_handles(FLAGS.log_level, FLAGS.log_path, mpi_log=None)
log.debug("Log handles were successfully set")
log.info("DeepMD version: %s", __version__)

load_op()
if FLAGS.command == "train":
train(FLAGS)
elif FLAGS.command == "freeze":
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def forward_common_atomic(
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None
) -> Dict[str, torch.Tensor]:
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]
Expand All @@ -107,6 +108,7 @@ def forward_common_atomic(
mapping=mapping,
fparam=fparam,
aparam=aparam,
comm_dict=comm_dict
)

if self.atom_excl is not None:
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def forward_atomic(
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None
) -> Dict[str, torch.Tensor]:
"""Return atomic prediction.
Expand Down Expand Up @@ -167,6 +168,7 @@ def forward_atomic(
extended_atype,
nlist,
mapping=mapping,
comm_dict=comm_dict
)
assert descriptor is not None
# energy, force
Expand Down
20 changes: 14 additions & 6 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Optional,
Tuple,
Union,
Dict,
)

import torch
Expand All @@ -21,6 +22,10 @@
from deepmd.pt.utils.update_sel import (
UpdateSel,
)

from deepmd.pt.utils.env import(
load_op
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -395,6 +400,7 @@ def forward(
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None
):
"""Compute the descriptor.
Expand Down Expand Up @@ -450,11 +456,12 @@ def forward(
# linear to change shape
g1 = self.g1_shape_tranform(g1)
# mapping g1
assert mapping is not None
mapping_ext = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, g1.shape[-1])
)
g1_ext = torch.gather(g1, 1, mapping_ext)
if(comm_dict is None):
assert mapping is not None
# mapping_ext = (
# mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, g1.shape[-1])
# )
# g1_ext = torch.gather(g1, 1, mapping_ext)
# repformer
g1, g2, h2, rot_mat, sw = self.repformers(
nlist_dict[
Expand All @@ -464,8 +471,9 @@ def forward(
],
extended_coord,
extended_atype,
g1_ext,
g1,
mapping,
comm_dict
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
Expand Down
41 changes: 34 additions & 7 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,17 +226,18 @@ def reinit_exclude(
):
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

@torch.jit.script_method
def forward(
self,
nlist: torch.Tensor,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None
):
assert mapping is not None
assert extended_atype_embd is not None
if comm_dict is None:
assert extended_atype_embd is not None
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
atype = extended_atype[:, :nloc]
Expand All @@ -257,8 +258,12 @@ def forward(
sw = sw.masked_fill(~nlist_mask, 0.0)

# [nframes, nloc, tebd_dim]
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]
#atype_embd = extended_atype_embd[:, :nloc, :]
atype_embd = extended_atype_embd
if atype_embd is not None:
assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]
else:
raise NotImplementedError

g1 = self.act(atype_embd)
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
Expand All @@ -275,11 +280,33 @@ def forward(
# if the a neighbor is real or not is indicated by nlist_mask
nlist[nlist == -1] = 0
# nb x nall x ng1
mapping = mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
if comm_dict is None:
assert mapping is not None
mapping = mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
for idx, ll in enumerate(self.layers):
# g1: nb x nloc x ng1
# g1_ext: nb x nall x ng1
g1_ext = torch.gather(g1, 1, mapping)
if comm_dict is None:
assert mapping is not None
g1_ext = torch.gather(g1, 1, mapping)
else:
# padding = torch.zeros(nall-nloc, g1.size(2),device=mydev)
# g1 = torch.cat((g1.squeeze(0), padding), dim=0)
n_padding = nall -nloc
g1 = torch.nn.functional.pad(g1.squeeze(0), (0, 0, 0, n_padding), value=0.0)
assert 'send_list' in comm_dict
assert 'send_proc' in comm_dict
assert 'recv_proc' in comm_dict
assert 'send_num' in comm_dict
assert 'recv_num' in comm_dict
assert 'communicator' in comm_dict
ret = env.op_module.border_op(comm_dict['send_list'],
comm_dict['send_proc'], comm_dict['recv_proc'],
comm_dict['send_num'], comm_dict['recv_num'],
g1,
comm_dict['communicator'],torch.tensor(nloc),torch.tensor(nall-nloc))
g1_ext = ret[0].unsqueeze(0)

g1, g2, h2 = ll.forward(
g1_ext,
g2,
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def forward(
atype_ext: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None
):
"""Compute the descriptor.
Expand Down
6 changes: 5 additions & 1 deletion deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from .dp_model import (
DPModel,
)

from deepmd.pt.utils.env import (
load_op
)

class EnergyModel(DPModel):
model_type = "ener"
Expand Down Expand Up @@ -69,6 +71,7 @@ def forward_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None
):
model_ret = self.forward_common_lower(
extended_coord,
Expand All @@ -78,6 +81,7 @@ def forward_lower(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict
)
if self.fitting_net is not None:
model_predict = {}
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def forward_common(
mapping,
do_atomic_virial=do_atomic_virial,
fparam=fp,
aparam=ap,
aparam=ap
)
model_predict = communicate_extended_output(
model_predict_lower,
Expand All @@ -167,6 +167,7 @@ def forward_common_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None
):
"""Return model prediction. Lower interface that takes
extended atomic coordinates and types, nlist, and mapping
Expand Down Expand Up @@ -210,6 +211,7 @@ def forward_common_lower(
mapping=mapping,
fparam=fp,
aparam=ap,
comm_dict = comm_dict
)
model_predict = fit_output_to_model_output(
atomic_ret,
Expand Down
9 changes: 6 additions & 3 deletions deepmd/pt/utils/env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import os

from typing import (
Any,
)
import numpy as np
import torch

Expand Down Expand Up @@ -80,6 +82,7 @@
"ENERGY_BIAS_TRAINABLE",
"LOCAL_RANK",
]

def load_op():
torch.ops.load_library("/mnt/user/zhangxiangyu/workspace/dpkit/deepmd-kit/dp/lib/")
torch.ops.load_library("/mnt/user/zhangxiangyu/workspace/dpkit/deepmd-kit/source/op_pt/libop_pt.so")

op_module: Any=torch.ops.my_ops
4 changes: 0 additions & 4 deletions source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ endif()
# define names of libs
set(LIB_DEEPMD "deepmd")
set(LIB_DEEPMD_OP "deepmd_op")
set(LIB_DEEPMD_OP_PT "deepmd_op_pt")
if(BUILD_CPP_IF)
set(LIB_DEEPMD_CC "deepmd_cc")
set(LIB_DEEPMD_C "deepmd_c")
Expand Down Expand Up @@ -282,9 +281,6 @@ if(NOT DEEPMD_C_ROOT)
if(ENABLE_TENSORFLOW)
add_subdirectory(op/)
endif()
if(ENABLE_PYTORCH)
add_subdirectory(op_pt/)
endif()
add_subdirectory(lib/)
endif()
if(BUILD_PY_IF)
Expand Down
2 changes: 1 addition & 1 deletion source/api_c/include/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
int** sendlist,
int* sendproc,
int* recvproc,
long int* world);
int world);

/**
* @brief Delete a neighbor list.
Expand Down
42 changes: 24 additions & 18 deletions source/api_c/include/deepmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,20 +502,20 @@ inline double *_DP_Get_Energy_Pointer(double &vec, const int nframes) {

namespace deepmd {
namespace hpp {
struct CommData {
int nswap;
int* sendnum;
int* recvnum;
int* firstrecv;
int** sendlist;
int* sendproc;
int* recvproc;
long int* world;

CommData() : nswap(0), sendnum(nullptr), recvnum(nullptr),
firstrecv(nullptr), sendlist(nullptr),
sendproc(nullptr), recvproc(nullptr),world(nullptr) {}
};
// struct CommData {
// int nswap;
// int* sendnum;
// int* recvnum;
// int* firstrecv;
// int** sendlist;
// int* sendproc;
// int* recvproc;
// long int* world;

// CommData() : nswap(0), sendnum(nullptr), recvnum(nullptr),
// firstrecv(nullptr), sendlist(nullptr),
// sendproc(nullptr), recvproc(nullptr),world(nullptr) {}
// };
/**
* @brief Neighbor list.
**/
Expand All @@ -536,13 +536,20 @@ struct InputNlist {
nl(DP_NewNlist(inum_, ilist_, numneigh_, firstneigh_)) {
DP_CHECK_OK(DP_NlistCheckOK, nl);
};
InputNlist(int inum_, int *ilist_, int *numneigh_, int **firstneigh_, CommData *commdata_)
InputNlist(int inum_, int *ilist_, int *numneigh_, int **firstneigh_, int nswap,
int* sendnum,
int* recvnum,
int* firstrecv,
int** sendlist,
int* sendproc,
int* recvproc,
int world)
: inum(inum_),
ilist(ilist_),
numneigh(numneigh_),
firstneigh(firstneigh_),
nl(DP_NewNlist_comm(inum_, ilist_, numneigh_, firstneigh_,commdata_->nswap,commdata_->sendnum,commdata_->recvnum,commdata_->firstrecv,commdata_->sendlist,commdata_->sendproc,commdata_->recvproc,commdata_->world)) {
DP_CHECK_OK(DP_NlistCheckOK, nl);
nl(DP_NewNlist_comm(inum_, ilist_, numneigh_, firstneigh_,nswap,sendnum,recvnum,firstrecv,sendlist,sendproc,recvproc,world)) {
//DP_CHECK_OK(DP_NlistCheckOK, nl);
};
~InputNlist() { DP_DeleteNlist(nl); };
/// @brief C API neighbor list.
Expand Down Expand Up @@ -822,7 +829,6 @@ class DeepPot {
aparam);
const VALUETYPE *fparam__ = !fparam_.empty() ? &fparam_[0] : nullptr;
const VALUETYPE *aparam__ = !aparam_.empty() ? &aparam_[0] : nullptr;

_DP_DeepPotComputeNList<VALUETYPE>(
dp, nframes, natoms, coord_, atype_, box_, nghost, lmp_list.nl, ago,
fparam__, aparam__, ener_, force_, virial_, nullptr, nullptr);
Expand Down
5 changes: 2 additions & 3 deletions source/api_c/src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ DP_Nlist* DP_NewNlist_comm(int inum_,
int** sendlist,
int* sendproc,
int* recvproc,
long int* world) {
deepmd::CommData commdata(nswap, sendnum, recvnum, firstrecv, sendlist,
int world) {
deepmd::InputNlist nl(inum_, ilist_, numneigh_, firstneigh_,nswap, sendnum, recvnum, firstrecv, sendlist,
sendproc, recvproc, world);
deepmd::InputNlist nl(inum_, ilist_, numneigh_, firstneigh_, &commdata);
DP_Nlist* new_nl = new DP_Nlist(nl);
return new_nl;
}
Expand Down
1 change: 1 addition & 0 deletions source/api_cc/include/DeepPotPT.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ class DeepPotPT : public DeepPotBase {
int gpu_id;
bool gpu_enabled;
at::Tensor firstneigh_tensor;
torch::Dict<std::string, torch::Tensor> comm_dict;
};

} // namespace deepmd
Loading

0 comments on commit 49f0204

Please sign in to comment.