diff --git a/deepmd/dpmodel/utils/type_embed.py b/deepmd/dpmodel/utils/type_embed.py index 6ae1629314..6303dd1d0f 100644 --- a/deepmd/dpmodel/utils/type_embed.py +++ b/deepmd/dpmodel/utils/type_embed.py @@ -70,6 +70,7 @@ def __init__( self.padding = padding self.use_econf_tebd = use_econf_tebd self.type_map = type_map + embed_input_dim = ntypes if self.use_econf_tebd: from deepmd.utils.econf_embd import ( ECONF_DIM, @@ -87,8 +88,9 @@ def __init__( [electronic_configuration_embedding[kk] for kk in self.type_map], dtype=PRECISION_DICT[self.precision], ) + embed_input_dim = ECONF_DIM self.embedding_net = EmbeddingNet( - ECONF_DIM if self.use_econf_tebd else ntypes, + embed_input_dim, self.neuron, self.activation_function, self.resnet_dt, diff --git a/deepmd/pt/model/network/network.py b/deepmd/pt/model/network/network.py index aa844fb659..d6fcccf73d 100644 --- a/deepmd/pt/model/network/network.py +++ b/deepmd/pt/model/network/network.py @@ -674,6 +674,7 @@ def __init__( self.use_econf_tebd = use_econf_tebd self.type_map = type_map self.econf_tebd = None + embed_input_dim = ntypes if self.use_econf_tebd: from deepmd.utils.econf_embd import ( ECONF_DIM, @@ -693,22 +694,15 @@ def __init__( dtype=PRECISION_DICT[self.precision], ) ) - self.embedding_net = EmbeddingNet( - ECONF_DIM, - self.neuron, - self.activation_function, - self.resnet_dt, - self.precision, - ) - else: - # no way to pass seed? - self.embedding_net = EmbeddingNet( - ntypes, - self.neuron, - self.activation_function, - self.resnet_dt, - self.precision, - ) + embed_input_dim = ECONF_DIM + # no way to pass seed? + self.embedding_net = EmbeddingNet( + embed_input_dim, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + ) for param in self.parameters(): param.requires_grad = trainable diff --git a/deepmd/tf/utils/type_embed.py b/deepmd/tf/utils/type_embed.py index ffa68abc60..b858e2443c 100644 --- a/deepmd/tf/utils/type_embed.py +++ b/deepmd/tf/utils/type_embed.py @@ -307,26 +307,20 @@ def serialize(self, suffix: str = "") -> dict: else: type_embedding_pattern = TYPE_EMBEDDING_PATTERN assert self.type_embedding_net_variables is not None - if not self.use_econf_tebd: - embedding_net = EmbeddingNet( - in_dim=self.ntypes, - neuron=self.neuron, - activation_function=self.filter_activation_fn_name, - resnet_dt=self.filter_resnet_dt, - precision=self.filter_precision.name, - ) - else: + embed_input_dim = self.ntypes + if self.use_econf_tebd: from deepmd.utils.econf_embd import ( ECONF_DIM, ) - embedding_net = EmbeddingNet( - in_dim=ECONF_DIM, - neuron=self.neuron, - activation_function=self.filter_activation_fn_name, - resnet_dt=self.filter_resnet_dt, - precision=self.filter_precision.name, - ) + embed_input_dim = ECONF_DIM + embedding_net = EmbeddingNet( + in_dim=embed_input_dim, + neuron=self.neuron, + activation_function=self.filter_activation_fn_name, + resnet_dt=self.filter_resnet_dt, + precision=self.filter_precision.name, + ) for key, value in self.type_embedding_net_variables.items(): m = re.search(type_embedding_pattern, key) m = [mm for mm in m.groups() if mm is not None] diff --git a/deepmd/utils/econf_embd.py b/deepmd/utils/econf_embd.py index 8719756dfa..a1b427ac7d 100644 --- a/deepmd/utils/econf_embd.py +++ b/deepmd/utils/econf_embd.py @@ -172,9 +172,9 @@ maxn = 7 maxl = maxn maxm = 2 * maxl + 1 -ECONF_DIM = 59 type_map = dpdata.periodic_table.ELEMENTS +ECONF_DIM = electronic_configuration_embedding[type_map[0]].shape[0] def make_empty_list_vec():