Skip to content

Commit

Permalink
chore: provide multiple typemap zbl torch
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Mar 8, 2024
1 parent 09bd522 commit 2d8a60f
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 8 deletions.
8 changes: 4 additions & 4 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def get_spin_model(model_params):

def get_zbl_model(model_params):
model_params = copy.deepcopy(model_params)
ntypes = len(model_params["type_map"])
ntypes = len(model_params["type_map"]["zbl"])
# descriptor
model_params["descriptor"]["ntypes"] = ntypes
descriptor = BaseDescriptor(**model_params["descriptor"])
Expand All @@ -99,14 +99,14 @@ def get_zbl_model(model_params):
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
fitting = BaseFitting(**fitting_net)
dp_model = DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"])
dp_model = DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"]["dp"])
# pairtab
filepath = model_params["use_srtab"]
pt_model = PairTabAtomicModel(
filepath,
model_params["descriptor"]["rcut"],
model_params["descriptor"]["sel"],
type_map=model_params["type_map"],
type_map=model_params["type_map"]["pairtab"],
)

rmin = model_params["sw_rmin"]
Expand All @@ -118,7 +118,7 @@ def get_zbl_model(model_params):
pt_model,
rmin,
rmax,
type_map=model_params["type_map"],
type_map=model_params["type_map"]["zbl"],
atom_exclude_types=atom_exclude_types,
pair_exclude_types=pair_exclude_types,
)
Expand Down
94 changes: 94 additions & 0 deletions examples/water/zbl/input_torch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
{
"_comment1": " model parameters",
"model": {
"use_srtab": "H2O_tab_potential.txt",
"smin_alpha": 0.1,
"sw_rmin": 0.8,
"sw_rmax": 1.0,
"type_map": {
"pairtab":["O","H"],
"dp": ["O", "H"],
"zbl": ["O", "H"]
},
"descriptor": {
"type": "se_e2_a",
"sel": [
46,
92
],
"rcut_smth": 0.50,
"rcut": 6.00,
"neuron": [
25,
50,
100
],
"resnet_dt": false,
"axis_neuron": 16,
"type_one_side": true,
"precision": "float64",
"seed": 1,
"_comment2": " that's all"
},
"fitting_net": {
"neuron": [
240,
240,
240
],
"resnet_dt": true,
"precision": "float64",
"seed": 1,
"_comment3": " that's all"
},
"_comment4": " that's all"
},

"learning_rate": {
"type": "exp",
"decay_steps": 5000,
"start_lr": 0.001,
"stop_lr": 3.51e-8,
"_comment5": "that's all"
},

"loss": {
"type": "ener",
"start_pref_e": 0.02,
"limit_pref_e": 1,
"start_pref_f": 1000,
"limit_pref_f": 1,
"start_pref_v": 0,
"limit_pref_v": 0,
"_comment6": " that's all"
},

"training": {
"training_data": {
"systems": [
"../data/data_0/",
"../data/data_1/",
"../data/data_2/"
],
"batch_size": "auto",
"_comment7": "that's all"
},
"validation_data": {
"systems": [
"../data/data_3"
],
"batch_size": 1,
"numb_btch": 3,
"_comment8": "that's all"
},
"numb_steps": 1000000,
"seed": 10,
"disp_file": "lcurve.out",
"disp_freq": 100,
"save_freq": 1000,
"_comment9": "that's all"
},

"_comment10": "that's all"
}

9 changes: 5 additions & 4 deletions source/tests/pt/model/water/zbl.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
"smin_alpha": 0.1,
"sw_rmin": 0.8,
"sw_rmax": 1.0,
"type_map": [
"O",
"H"
],
"type_map": {
"pairtab":["O","H"],
"dp": ["O", "H"],
"zbl": ["O", "H"]
},
"descriptor": {
"type": "se_e2_a",
"sel": [
Expand Down

0 comments on commit 2d8a60f

Please sign in to comment.