From 27f448da3e53c50a5bb5a81f8980b718d259c83f Mon Sep 17 00:00:00 2001 From: nara Date: Mon, 8 Jul 2024 15:43:16 +0900 Subject: [PATCH] Use value infos to determine input shapes of ONNX nodes --- thop/onnx_profile.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/thop/onnx_profile.py b/thop/onnx_profile.py index 10da68c..4109046 100644 --- a/thop/onnx_profile.py +++ b/thop/onnx_profile.py @@ -66,6 +66,8 @@ def calculate_macs(self, model: onnx.ModelProto) -> torch.DoubleTensor: input = model.graph.input output = model.graph.output name2dims = self.create_dict(weight, input, output) + for v in model.graph.value_info: + name2dims[v.name] = np.array([i.dim_value for i in v.type.tensor_type.shape.dim]) macs = 0 for n in nodes: macs_adding, out_size, outname = self.nodes_counter(name2dims, n)