From c06f54a7d9f9d58d9282be713f15abde783c376e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Nov 2024 03:34:54 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/tabulate.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/utils/tabulate.py b/deepmd/pt/utils/tabulate.py index 911e46e620..7394ac082d 100644 --- a/deepmd/pt/utils/tabulate.py +++ b/deepmd/pt/utils/tabulate.py @@ -3,7 +3,6 @@ from functools import ( cached_property, ) -from unittest import result import numpy as np import torch @@ -332,15 +331,15 @@ def _get_layer_size(self): else: raise RuntimeError("Unsupported descriptor") return layer_size - + def _get_network_variable(self, var_name: str) -> dict: """Get network variables (weights or biases) for all layers. - + Parameters ---------- var_name : str Name of the variable to get ('w' for weights, 'b' for biases) - + Returns ------- dict @@ -350,7 +349,9 @@ def _get_network_variable(self, var_name: str) -> dict: for layer in range(1, self.layer_size + 1): result["layer_" + str(layer)] = [] if self.descrpt_type == "Atten": - node = self.embedding_net_nodes[0]["layers"][layer - 1]["@variables"][var_name] + node = self.embedding_net_nodes[0]["layers"][layer - 1]["@variables"][ + var_name + ] result["layer_" + str(layer)].append(node) elif self.descrpt_type == "A": if self.type_one_side: