From c9e18b8fa1caaab021aefc29587966b67eb09883 Mon Sep 17 00:00:00 2001 From: XiaoXYe <50827462+XiaoXYe@users.noreply.github.com> Date: Fri, 24 Jul 2020 13:43:43 +0800 Subject: [PATCH] Add support for pytorch1.5.1 * Update pytorch_graph.py * Update pytorch_parser.py * Update pytorch_parser.py * Update pytorch_graph.py * Update pytorch_emitter.py * Update pytorch_graph.py * Update pytorch_parser.py * Update keras2_emitter.py * Update convertToIR.py * Update conversion_imagenet.py --- mmdnn/conversion/_script/convertToIR.py | 9 +- mmdnn/conversion/keras/keras2_emitter.py | 10 +- mmdnn/conversion/pytorch/pytorch_emitter.py | 4 +- mmdnn/conversion/pytorch/pytorch_graph.py | 135 ++++++++++++++------ mmdnn/conversion/pytorch/pytorch_parser.py | 99 ++++++++++---- tests/conversion_imagenet.py | 14 +- 6 files changed, 201 insertions(+), 70 deletions(-) diff --git a/mmdnn/conversion/_script/convertToIR.py b/mmdnn/conversion/_script/convertToIR.py index 2cfa339e..7c906151 100644 --- a/mmdnn/conversion/_script/convertToIR.py +++ b/mmdnn/conversion/_script/convertToIR.py @@ -86,10 +86,15 @@ def _convert(args): elif args.srcFramework == 'pytorch': assert inputshape != None - from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser + from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser040 + from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser151 + import torch model = args.network or args.weights assert model != None - parser = PytorchParser(model, inputshape[0]) + if torch.__version__ == "0.4.0": + parser = PytorchParser040(model, inputshape[0]) + else: + parser = PytorchParser151(model, inputshape[0]) elif args.srcFramework == 'torch' or args.srcFramework == 'torch7': from mmdnn.conversion.torch.torch_parser import TorchParser diff --git a/mmdnn/conversion/keras/keras2_emitter.py b/mmdnn/conversion/keras/keras2_emitter.py index dd347c34..b6ad046b 100644 --- a/mmdnn/conversion/keras/keras2_emitter.py +++ b/mmdnn/conversion/keras/keras2_emitter.py @@ -168,6 +168,12 @@ def _emit_activation(self, IR_node, op, in_scope=False): def _emit_merge(self, IR_node, func): if len(IR_node.in_edges) == 1: + if func == "concatenate": + inputs = ', '.join('%s' % self.parent_variable_name(IR_node, i) for i in IR_node.in_edges) + code = "{:<15} = {}".format( + IR_node.variable_name, + inputs) + return code IR_node.in_edges.append(IR_node.in_edges[0]) inputs = ', '.join('%s' % self.parent_variable_name(IR_node, i) for i in IR_node.in_edges) axis = ' axis = {},'.format(IR_node.get_attr('axis')) if 'axis' in IR_node.layer.attr else "" @@ -1077,7 +1083,7 @@ def call(self, x, mask=None): squared = K.square(x) scale = self.k norm_alpha = self.alpha / self.n - if (K.image_data_format() == 'channels_first'): + if K.image_data_format() == 'channels_first': b, f, r, c = self.shape squared = K.expand_dims(squared, 0) squared = K.spatial_3d_padding(squared, padding=((half_n, half_n), (0, 0), (0,0))) @@ -1350,4 +1356,4 @@ def mul_constant(weight_factor, layer_name): weight = Lambda(lambda x: x*weight_factor) weight(layer_name) return weight.output -''') \ No newline at end of file +''') diff --git a/mmdnn/conversion/pytorch/pytorch_emitter.py b/mmdnn/conversion/pytorch/pytorch_emitter.py index 4a4a1dcc..ca6d6e0e 100644 --- a/mmdnn/conversion/pytorch/pytorch_emitter.py +++ b/mmdnn/conversion/pytorch/pytorch_emitter.py @@ -158,7 +158,7 @@ def emit_Conv(self, IR_node): if IR_node.type == 'DepthwiseConv': group = in_channels - filter *= group + filter = group else: group = IR_node.get_attr('group', 1) @@ -522,7 +522,7 @@ def _convert_axis(self, IR_node, axis): def emit_Concat(self, IR_node): axis = self._convert_axis(IR_node, IR_node.get_attr('axis')) - code = "{:<15} = torch.cat(({}), {})".format( + code = "{:<15} = torch.cat(({},), {})".format( IR_node.variable_name, ', '.join(self.parent_variable_name(IR_node, [idx]) for idx in range(len(IR_node.in_edges))), axis, diff --git a/mmdnn/conversion/pytorch/pytorch_graph.py b/mmdnn/conversion/pytorch/pytorch_graph.py index ccb63f74..1b8306ef 100644 --- a/mmdnn/conversion/pytorch/pytorch_graph.py +++ b/mmdnn/conversion/pytorch/pytorch_graph.py @@ -5,7 +5,6 @@ from mmdnn.conversion.common.DataStructure.graph import GraphNode, Graph import torch -import torch.jit import torch.autograd import torch.serialization import contextlib @@ -16,18 +15,14 @@ class PytorchGraphNode(GraphNode): def __init__(self, layer): - self._name = layer.scopeName() - self._kind = layer.kind() + self.version = torch.__version__ import re + self._kind = layer.kind() node_id = re.search(r"[\d]+", layer.__str__()) self.id = node_id.group(0) - super(PytorchGraphNode, self).__init__(layer) self.attrs = {k : layer[k] for k in layer.attributeNames()} - self.weights_name = '.'.join( - re.findall(r'\[([\w\d.]+)\]', self._name) - ) @property @@ -48,6 +43,22 @@ def pytorch_layer(self): return self.layer +class PytorchGraphNode040(PytorchGraphNode): + def __init__(self, layer): + self._name = layer.scopeName() + import re + self.weights_name = '.'.join( + re.findall(r'\[([\w\d.]+)\]', self._name) + ) + super(PytorchGraphNode040, self).__init__(layer) + + +class PytorchGraphNode151(PytorchGraphNode): + + def __init__(self, layer): + self._name = 'node' + super(PytorchGraphNode151, self).__init__(layer) + class PytorchGraph(Graph): @@ -58,6 +69,7 @@ def __init__(self, model): self.model = model self.state_dict = _unique_state_dict(self.model) self.shape_dict = dict() + self.layer_weight_map = dict() @staticmethod @@ -110,50 +122,101 @@ def set_training(self, model, mode): def build(self, shape): """ - build graph for pytorch 0.4.0 + build graph for pytorch """ - import re # construct graph dummy_input = torch.autograd.Variable(torch.randn(shape), requires_grad=False) - - with self.set_training(self.model, False): - trace, output = torch.jit.get_trace_graph(self.model, (dummy_input, )) - - trace.set_graph(PytorchGraph._optimize_graph(trace.graph(), False)) - # nodes - nodes = list(trace.graph().nodes()) - - - # input layer - # TODO - - - + graph, nodes = self.extractgraph(dummy_input) + # build each layer for node in nodes: - node_id = PytorchGraph.get_node_id(node) - node_scope = node.scopeName() - node_name = node_scope + node_id - node_name = node_name.replace('-','n').replace('\\','n').replace('/','n').replace('_','n').replace('[','n').replace(']','n') + node_name = self.rename_nodes(node, node_id) output_shape_str = re.findall(r'[^()!]+', node.__str__())[1] - output_shape = [int(x.replace('!', '')) for x in output_shape_str.split(',')] - - + if '%' in output_shape_str: + out_put_shape = None + else: + output_shape = [int(x.replace('!', '')) for x in output_shape_str.split(',')] self.shape_dict[node_name] = output_shape - self.layer_map[node_name] = PytorchGraphNode(node) + self.layer_map[node_name] = self.CreateGraphNode(node) self.layer_name_map[node_name] = node_name + # make connection + self.node_connection(graph, node, node_name) + + super(PytorchGraph, self).build() + + +class PytorchGraph040(PytorchGraph): + + def __init__(self, model): + super(PytorchGraph040, self).__init__(model) - # input - for node_input in list(node.inputs()): + def extractgraph(self, dummy_input): + with self.set_training(self.model, False): + import torch.jit + trace, output = torch.jit.get_trace_graph(self.model, (dummy_input, )) - if PytorchGraph.get_node_id(node_input.node()) and node_input.node().scopeName(): + trace.set_graph(PytorchGraph._optimize_graph(trace.graph(), False)) + # nodes + nodes = list(trace.graph().nodes()) + graph = trace.graph() + return graph, nodes + + def rename_nodes(self, node, node_id): + node_scope = node.scopeName() + node_name = node_scope + node_id + node_name = node_name.replace('-','n').replace('\\','n').replace('/','n').replace('_','n').replace('[','n').replace(']','n') + return node_name + + def node_connection(self, graph, node, node_name): + for node_input in list(node.inputs()): + if PytorchGraph.get_node_id(node_input.node()) and node_input.node().scopeName(): node_input_name = node_input.node().scopeName() + PytorchGraph.get_node_id(node_input.node()) node_input_name = node_input_name.replace('-','n').replace('\\','n').replace('/','n').replace('_','n').replace('[','n').replace(']','n') self._make_connection(node_input_name, node_name) - # print(node_input_name ,'->', node_name) + + def CreateGraphNode(self, node): + return PytorchGraphNode040(node) - super(PytorchGraph, self).build() +class PytorchGraph151(PytorchGraph): + + def __init__(self, model): + super(PytorchGraph151, self).__init__(model) + + def extractgraph(self, dummy_input): + import re + import torch.onnx.utils + # connect name and id in nodes with weights + graph, params_dict, torch_out = torch.onnx.utils._model_to_graph(self.model, dummy_input, _retain_param_name=True) + nodes = list(graph.nodes()) + for node in nodes: + # print(node.__str__()) + node_id = PytorchGraph.get_node_id(node) + node_name = 'node' + node_id + node_scope_str = re.findall(r'[^()!]+', node.__str__())[-2] + for x in node_scope_str.split(','): + if re.findall(r'%\S+.weight', x): + node_scope = '.'.join(re.findall(r'%\S+.weight', x)[0].replace('%','',1).split('.')[:-1]) + self.layer_weight_map[node_name] = node_scope + + graph, params_dict, torch_out = torch.onnx.utils._model_to_graph(self.model, dummy_input) + nodes = list(graph.nodes()) + return graph, nodes + + def rename_nodes(self, node, node_id): + node_name = 'node' + node_id + return node_name + + def node_connection(self, graph, node, node_name): + for node_input in list(node.inputs()): + if PytorchGraph.get_node_id(node_input.node()) and node_input.node() in graph.nodes(): + node_input_name = 'node' + PytorchGraph.get_node_id(node_input.node()) + self._make_connection(node_input_name, node_name) + + def CreateGraphNode(self, node): + return PytorchGraphNode151(node) + + diff --git a/mmdnn/conversion/pytorch/pytorch_parser.py b/mmdnn/conversion/pytorch/pytorch_parser.py index f8762c5d..4686d3f7 100644 --- a/mmdnn/conversion/pytorch/pytorch_parser.py +++ b/mmdnn/conversion/pytorch/pytorch_parser.py @@ -9,7 +9,8 @@ from mmdnn.conversion.common.IR.graph_pb2 import NodeDef, GraphDef, DataType from mmdnn.conversion.common.utils import * from mmdnn.conversion.common.DataStructure.parser import Parser -from mmdnn.conversion.pytorch.pytorch_graph import PytorchGraph +from mmdnn.conversion.pytorch.pytorch_graph import PytorchGraph040 +from mmdnn.conversion.pytorch.pytorch_graph import PytorchGraph151 import torch import torchvision @@ -21,6 +22,7 @@ class PytorchParser(Parser): 'onnx::Gemm': 'FullyConnected', 'onnx::MaxPool': 'Maxpool', 'onnx::AveragePool': 'Avgpool', + 'onnx::GlobalAveragePool': 'GAvgpool', 'onnx::Dropout': 'Dropout', 'onnx::BatchNormalization': 'BatchNormalization', 'onnx::Add': 'Add', @@ -28,7 +30,8 @@ class PytorchParser(Parser): 'onnx::Relu': 'Relu', 'onnx::Tanh': 'Tanh', 'onnx::Sigmoid': 'Sigmoid', - 'onnx::Mul': 'Mul' + 'onnx::Mul': 'Mul', + 'onnx::Pad': 'Pad' # TODO @@ -59,6 +62,8 @@ class PytorchParser(Parser): def src_graph(self): return self.pytorch_graph + def get_weight_name(self, node): + pass #################### # Public Functions # @@ -78,17 +83,17 @@ def __init__(self, model_file_name, input_shape): model = torch.load(model_file_name, map_location='cpu') self.weight_loaded = True - + self.model = model # Build network graph - self.pytorch_graph = PytorchGraph(model) + self.pytorch_graph = None + + def build_graph(self, input_shape): self.input_shape = tuple([1] + input_shape) self.pytorch_graph.build(self.input_shape) self.state_dict = self.pytorch_graph.state_dict self.shape_dict = self.pytorch_graph.shape_dict - def gen_IR(self): - for layer in self.src_graph.topological_sort: current_node = self.src_graph.get_node(layer) onnx_node_type = current_node.type @@ -148,11 +153,14 @@ def _set_output_shape(self, source_node, IR_node): # Layers # ########## def rename_UNKNOWN(self, source_node): - print (source_node.layer) - print (source_node.layer.data.size()) - assert False print("PyTorch parser has not supported operator [%s] with name [%s]." % (source_node.type, source_node.name)) + assert False + print(source_node.layer) + print(source_node.layer.data.size()) + + + def gen_Input(self): IR_node = self.IR_graph.node.add() @@ -226,11 +234,10 @@ def rename_Conv(self, source_node): kwargs['group'] = attr['group'] + weights_scope = self.get_weight_name(source_node) - - bias_name = '{0}.bias'.format(source_node.weights_name) - weights_name = '{0}.weight'.format(source_node.weights_name) - + bias_name = '{0}.bias'.format(weights_scope) + weights_name = '{0}.weight'.format(weights_scope) weight = self.state_dict[weights_name] weight = weight.numpy() @@ -266,12 +273,12 @@ def rename_BatchNormalization(self, source_node): attr = source_node.attrs # epsilon IR_node.attr['epsilon'].f = attr['epsilon'] + weights_scope = self.get_weight_name(source_node) - - bias_name = '{0}.bias'.format(source_node.weights_name) - weights_name = '{0}.weight'.format(source_node.weights_name) - mean_name = '{0}.running_mean'.format(source_node.weights_name) - var_name = '{0}.running_var'.format(source_node.weights_name) + bias_name = '{0}.bias'.format(weights_scope) + weights_name = '{0}.weight'.format(weights_scope) + mean_name = '{0}.running_mean'.format(weights_scope) + var_name = '{0}.running_var'.format(weights_scope) @@ -304,6 +311,15 @@ def rename_BatchNormalization(self, source_node): # var self.set_weight(source_node.name, "var", variance) + def rename_Pad(self, source_node): + IR_node = self._convert_identity_operation(source_node, new_op="Pad") + attr = source_node.attrs + kwargs = dict() + kwargs['mode'] = attr['mode'] + kwargs['pads'] = attr['pads'] + kwargs['constant_values'] = attr['value'] + assign_IRnode_values(IR_node, kwargs) + def rename_Relu(self, source_node): IR_node = self._convert_identity_operation(source_node, new_op="Relu") @@ -340,7 +356,10 @@ def rename_Avgpool(self, source_node): kwargs['dilations'] = [1] + [1, 1] + [1] else: kwargs['dilations'] = [1] + attr['dilations'] + [1] - kwargs['pads'] = [0] + attr['pads'][0:2] + [0, 0] + attr['pads'][2:] + [0] + if 'pads' in attr: + kwargs['pads'] = [0] + attr['pads'][0:2] + [0, 0] + attr['pads'][2:] + [0] + else: + kwargs['pads'] = [0, 0, 0, 0, 0, 0, 0, 0] kwargs['kernel_shape'] = [1] + attr['kernel_shape'] + [1] IR_node = self._convert_identity_operation(source_node, new_op="Pool") @@ -348,14 +367,28 @@ def rename_Avgpool(self, source_node): assign_IRnode_values(IR_node, kwargs) + def rename_GAvgpool(self, source_node): + attr = source_node.attrs + input_shape = self.pytorch_graph.shape_dict[source_node.in_edges[0]] + kwargs = dict() + kwargs['strides'] = [1, 1, 1, 1] + kwargs['dilations'] = [1] + [1, 1] + [1] + kwargs['pads'] = [0, 0, 0, 0, 0, 0, 0, 0] + kwargs['kernel_shape'] = [1] + input_shape[2:] + [1] + IR_node = self._convert_identity_operation(source_node, new_op="Pool") + + kwargs['pooling_type'] = 'AVG' + + assign_IRnode_values(IR_node, kwargs) + def rename_Flatten(self, source_node): IR_node = self._convert_identity_operation(source_node, new_op="Flatten") def rename_FullyConnected(self, source_node): IR_node = self._convert_identity_operation(source_node, new_op="FullyConnected") - - bias_name = '{0}.bias'.format(source_node.weights_name) - weights_name = '{0}.weight'.format(source_node.weights_name) + weights_scope = self.get_weight_name(source_node) + bias_name = '{0}.bias'.format(weights_scope) + weights_name = '{0}.weight'.format(weights_scope) W = self.state_dict[weights_name].numpy().transpose() @@ -435,7 +468,6 @@ def rename_Addmm(self, source_node): assign_IRnode_values(IR_node, kwargs) - print(IR_node) #################### @@ -472,3 +504,24 @@ def _convert_pooling(self, source_node): raise ValueError('Unknown pooling type') assign_IRnode_values(IR_node, kwargs) + +class PytorchParser040(PytorchParser): + + def __init__(self, model_file_name, input_shape): + super(PytorchParser040, self).__init__(model_file_name, input_shape) + self.pytorch_graph = PytorchGraph040(self.model) + self.build_graph(input_shape) + + def get_weight_name(self, node): + return node.weights_name + +class PytorchParser151(PytorchParser): + + def __init__(self, model_file_name, input_shape): + super(PytorchParser151, self).__init__(model_file_name, input_shape) + self.pytorch_graph = PytorchGraph151(self.model) + self.build_graph(input_shape) + + def get_weight_name(self, node): + return self.pytorch_graph.layer_weight_map[node.name] + diff --git a/tests/conversion_imagenet.py b/tests/conversion_imagenet.py index d4731ccb..d9123863 100644 --- a/tests/conversion_imagenet.py +++ b/tests/conversion_imagenet.py @@ -98,7 +98,6 @@ def tensorflow_frozen_parse(architecture_name, test_input_path): parser.run(IR_file) del parser del TensorflowParser2 - return original_predict @@ -255,8 +254,9 @@ def paddle_parse(architecture_name, test_input_path): @staticmethod def pytorch_parse(architecture_name, test_input_path): from mmdnn.conversion.examples.pytorch.extractor import pytorch_extractor - from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser - + from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser040 + from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser151 + import torch # download model architecture_file = pytorch_extractor.download(architecture_name, TestModels.cachedir) @@ -285,10 +285,14 @@ def pytorch_parse(architecture_name, test_input_path): # original to IR IR_file = TestModels.tmpdir + 'pytorch_' + architecture_name + "_converted" - parser = PytorchParser(architecture_file, [3, size, size]) + if torch.__version__ == "0.4.0": + parser = PytorchParser040(architecture_file, [3, size, size]) + else: + parser = PytorchParser151(architecture_file, [3, size, size]) parser.run(IR_file) del parser - del PytorchParser + del PytorchParser040 + del PytorchParser151 return original_predict