Skip to content

Commit

Permalink
test fparam
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Feb 17, 2024
1 parent 6a65065 commit d8d7873
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 5 deletions.
2 changes: 1 addition & 1 deletion deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ def serialize(self, suffix: str) -> dict:
ntypes=self.ntypes,
# TODO: consider type embeddings
ndim=1,
in_dim=self.dim_descrpt,
in_dim=self.dim_descrpt + self.numb_fparam + self.numb_aparam,
neuron=self.n_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.resnet_dt,
Expand Down
17 changes: 13 additions & 4 deletions source/tests/consistent/fitting/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,27 @@
class FittingTest:
"""Useful utilities for descriptor tests."""

def build_tf_fitting(self, obj, inputs, natoms, atype, suffix):
def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, suffix):
t_inputs = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_inputs")
t_natoms = tf.placeholder(tf.int32, natoms.shape, name="i_natoms")
t_atype = tf.placeholder(tf.int32, [None], name="i_atype")
t_des = obj.build(
extras = {}
feed_dict = {}
if fparam is not None:
t_fparam = tf.placeholder(
GLOBAL_TF_FLOAT_PRECISION, [None], name="i_fparam"
)
extras["fparam"] = t_fparam
feed_dict[t_fparam] = fparam
t_out = obj.build(
t_inputs,
t_natoms,
{"atype": t_atype},
{"atype": t_atype, **extras},
suffix=suffix,
)
return [t_des], {
return [t_out], {
t_inputs: inputs,
t_natoms: natoms,
t_atype: atype,
**feed_dict,
}
31 changes: 31 additions & 0 deletions source/tests/consistent/fitting/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
(True, False), # resnet_dt
("float64", "float32"), # precision
(True, False), # distinguish_types
(0, 1), # numb_fparam
)
class TestEner(CommonTest, FittingTest, unittest.TestCase):
@property
Expand All @@ -50,11 +51,13 @@ def data(self) -> dict:
resnet_dt,
precision,
distinguish_types,
numb_fparam,
) = self.param
return {
"neuron": [5, 5, 5],
"resnet_dt": resnet_dt,
"precision": precision,
"numb_fparam": numb_fparam,
"seed": 20240217,
}

Expand All @@ -64,6 +67,7 @@ def skip_tf(self) -> bool:
resnet_dt,
precision,
distinguish_types,
numb_fparam,
) = self.param
# TODO: distinguish_types
return not distinguish_types or CommonTest.skip_pt
Expand All @@ -74,6 +78,7 @@ def skip_pt(self) -> bool:
resnet_dt,
precision,
distinguish_types,
numb_fparam,
) = self.param
# TODO: float32 has bug
return precision == "float32" or CommonTest.skip_pt
Expand All @@ -84,6 +89,7 @@ def skip_dp(self) -> bool:
resnet_dt,
precision,
distinguish_types,
numb_fparam,
) = self.param
# TODO: float32 has bug
return precision == "float32" or CommonTest.skip_dp
Expand All @@ -102,13 +108,15 @@ def setUp(self):
self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32)
# inconsistent if not sorted
self.atype.sort()
self.fparam = -np.ones((1,), dtype=GLOBAL_NP_FLOAT_PRECISION)

@property
def addtional_data(self) -> dict:
(
resnet_dt,
precision,
distinguish_types,
numb_fparam,
) = self.param
return {
"ntypes": self.ntypes,
Expand All @@ -117,29 +125,52 @@ def addtional_data(self) -> dict:
}

def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]:
(
resnet_dt,
precision,
distinguish_types,
numb_fparam,
) = self.param
return self.build_tf_fitting(
obj,
self.inputs.ravel(),
self.natoms,
self.atype,
self.fparam if numb_fparam else None,
suffix,
)

def eval_pt(self, pt_obj: Any) -> Any:
(
resnet_dt,
precision,
distinguish_types,
numb_fparam,
) = self.param
return (
pt_obj(
torch.from_numpy(self.inputs).to(device=PT_DEVICE),
torch.from_numpy(self.atype).to(device=PT_DEVICE),
fparam=torch.from_numpy(self.fparam).to(device=PT_DEVICE)
if numb_fparam
else None,
)["energy"]
.detach()
.cpu()
.numpy()
)

def eval_dp(self, dp_obj: Any) -> Any:
(
resnet_dt,
precision,
distinguish_types,
numb_fparam,
) = self.param
return dp_obj(
self.inputs,
self.atype,
fparam=self.fparam if numb_fparam else None,
)["energy"]

def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]:
Expand Down

0 comments on commit d8d7873

Please sign in to comment.