Skip to content

Commit

Permalink
Add support for pytorch1.5.1
Browse files Browse the repository at this point in the history
* Update pytorch_graph.py

* Update pytorch_parser.py

* Update pytorch_parser.py

* Update pytorch_graph.py

* Update pytorch_emitter.py

* Update pytorch_graph.py

* Update pytorch_parser.py

* Update keras2_emitter.py

* Update convertToIR.py

* Update conversion_imagenet.py
  • Loading branch information
XiaoXYe authored Jul 24, 2020
1 parent 925036e commit c9e18b8
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 70 deletions.
9 changes: 7 additions & 2 deletions mmdnn/conversion/_script/convertToIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,15 @@ def _convert(args):

elif args.srcFramework == 'pytorch':
assert inputshape != None
from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser
from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser040
from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser151
import torch
model = args.network or args.weights
assert model != None
parser = PytorchParser(model, inputshape[0])
if torch.__version__ == "0.4.0":
parser = PytorchParser040(model, inputshape[0])
else:
parser = PytorchParser151(model, inputshape[0])

elif args.srcFramework == 'torch' or args.srcFramework == 'torch7':
from mmdnn.conversion.torch.torch_parser import TorchParser
Expand Down
10 changes: 8 additions & 2 deletions mmdnn/conversion/keras/keras2_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ def _emit_activation(self, IR_node, op, in_scope=False):

def _emit_merge(self, IR_node, func):
if len(IR_node.in_edges) == 1:
if func == "concatenate":
inputs = ', '.join('%s' % self.parent_variable_name(IR_node, i) for i in IR_node.in_edges)
code = "{:<15} = {}".format(
IR_node.variable_name,
inputs)
return code
IR_node.in_edges.append(IR_node.in_edges[0])
inputs = ', '.join('%s' % self.parent_variable_name(IR_node, i) for i in IR_node.in_edges)
axis = ' axis = {},'.format(IR_node.get_attr('axis')) if 'axis' in IR_node.layer.attr else ""
Expand Down Expand Up @@ -1077,7 +1083,7 @@ def call(self, x, mask=None):
squared = K.square(x)
scale = self.k
norm_alpha = self.alpha / self.n
if (K.image_data_format() == 'channels_first'):
if K.image_data_format() == 'channels_first':
b, f, r, c = self.shape
squared = K.expand_dims(squared, 0)
squared = K.spatial_3d_padding(squared, padding=((half_n, half_n), (0, 0), (0,0)))
Expand Down Expand Up @@ -1350,4 +1356,4 @@ def mul_constant(weight_factor, layer_name):
weight = Lambda(lambda x: x*weight_factor)
weight(layer_name)
return weight.output
''')
''')
4 changes: 2 additions & 2 deletions mmdnn/conversion/pytorch/pytorch_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def emit_Conv(self, IR_node):

if IR_node.type == 'DepthwiseConv':
group = in_channels
filter *= group
filter = group

else:
group = IR_node.get_attr('group', 1)
Expand Down Expand Up @@ -522,7 +522,7 @@ def _convert_axis(self, IR_node, axis):

def emit_Concat(self, IR_node):
axis = self._convert_axis(IR_node, IR_node.get_attr('axis'))
code = "{:<15} = torch.cat(({}), {})".format(
code = "{:<15} = torch.cat(({},), {})".format(
IR_node.variable_name,
', '.join(self.parent_variable_name(IR_node, [idx]) for idx in range(len(IR_node.in_edges))),
axis,
Expand Down
135 changes: 99 additions & 36 deletions mmdnn/conversion/pytorch/pytorch_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from mmdnn.conversion.common.DataStructure.graph import GraphNode, Graph
import torch
import torch.jit
import torch.autograd
import torch.serialization
import contextlib
Expand All @@ -16,18 +15,14 @@
class PytorchGraphNode(GraphNode):

def __init__(self, layer):
self._name = layer.scopeName()
self._kind = layer.kind()
self.version = torch.__version__
import re
self._kind = layer.kind()
node_id = re.search(r"[\d]+", layer.__str__())
self.id = node_id.group(0)

super(PytorchGraphNode, self).__init__(layer)
self.attrs = {k : layer[k] for k in layer.attributeNames()}

self.weights_name = '.'.join(
re.findall(r'\[([\w\d.]+)\]', self._name)
)


@property
Expand All @@ -48,6 +43,22 @@ def pytorch_layer(self):
return self.layer


class PytorchGraphNode040(PytorchGraphNode):
def __init__(self, layer):
self._name = layer.scopeName()
import re
self.weights_name = '.'.join(
re.findall(r'\[([\w\d.]+)\]', self._name)
)
super(PytorchGraphNode040, self).__init__(layer)


class PytorchGraphNode151(PytorchGraphNode):

def __init__(self, layer):
self._name = 'node'
super(PytorchGraphNode151, self).__init__(layer)



class PytorchGraph(Graph):
Expand All @@ -58,6 +69,7 @@ def __init__(self, model):
self.model = model
self.state_dict = _unique_state_dict(self.model)
self.shape_dict = dict()
self.layer_weight_map = dict()


@staticmethod
Expand Down Expand Up @@ -110,50 +122,101 @@ def set_training(self, model, mode):

def build(self, shape):
"""
build graph for pytorch 0.4.0
build graph for pytorch
"""

import re
# construct graph
dummy_input = torch.autograd.Variable(torch.randn(shape), requires_grad=False)


with self.set_training(self.model, False):
trace, output = torch.jit.get_trace_graph(self.model, (dummy_input, ))

trace.set_graph(PytorchGraph._optimize_graph(trace.graph(), False))
# nodes
nodes = list(trace.graph().nodes())


# input layer
# TODO



graph, nodes = self.extractgraph(dummy_input)

# build each layer
for node in nodes:

node_id = PytorchGraph.get_node_id(node)
node_scope = node.scopeName()
node_name = node_scope + node_id
node_name = node_name.replace('-','n').replace('\\','n').replace('/','n').replace('_','n').replace('[','n').replace(']','n')
node_name = self.rename_nodes(node, node_id)
output_shape_str = re.findall(r'[^()!]+', node.__str__())[1]
output_shape = [int(x.replace('!', '')) for x in output_shape_str.split(',')]


if '%' in output_shape_str:
out_put_shape = None
else:
output_shape = [int(x.replace('!', '')) for x in output_shape_str.split(',')]
self.shape_dict[node_name] = output_shape
self.layer_map[node_name] = PytorchGraphNode(node)
self.layer_map[node_name] = self.CreateGraphNode(node)
self.layer_name_map[node_name] = node_name
# make connection
self.node_connection(graph, node, node_name)

super(PytorchGraph, self).build()


class PytorchGraph040(PytorchGraph):

def __init__(self, model):
super(PytorchGraph040, self).__init__(model)

# input
for node_input in list(node.inputs()):
def extractgraph(self, dummy_input):
with self.set_training(self.model, False):
import torch.jit
trace, output = torch.jit.get_trace_graph(self.model, (dummy_input, ))

if PytorchGraph.get_node_id(node_input.node()) and node_input.node().scopeName():
trace.set_graph(PytorchGraph._optimize_graph(trace.graph(), False))
# nodes
nodes = list(trace.graph().nodes())
graph = trace.graph()
return graph, nodes

def rename_nodes(self, node, node_id):
node_scope = node.scopeName()
node_name = node_scope + node_id
node_name = node_name.replace('-','n').replace('\\','n').replace('/','n').replace('_','n').replace('[','n').replace(']','n')
return node_name

def node_connection(self, graph, node, node_name):
for node_input in list(node.inputs()):
if PytorchGraph.get_node_id(node_input.node()) and node_input.node().scopeName():
node_input_name = node_input.node().scopeName() + PytorchGraph.get_node_id(node_input.node())
node_input_name = node_input_name.replace('-','n').replace('\\','n').replace('/','n').replace('_','n').replace('[','n').replace(']','n')
self._make_connection(node_input_name, node_name)
# print(node_input_name ,'->', node_name)

def CreateGraphNode(self, node):
return PytorchGraphNode040(node)


super(PytorchGraph, self).build()
class PytorchGraph151(PytorchGraph):

def __init__(self, model):
super(PytorchGraph151, self).__init__(model)

def extractgraph(self, dummy_input):
import re
import torch.onnx.utils
# connect name and id in nodes with weights
graph, params_dict, torch_out = torch.onnx.utils._model_to_graph(self.model, dummy_input, _retain_param_name=True)
nodes = list(graph.nodes())
for node in nodes:
# print(node.__str__())
node_id = PytorchGraph.get_node_id(node)
node_name = 'node' + node_id
node_scope_str = re.findall(r'[^()!]+', node.__str__())[-2]
for x in node_scope_str.split(','):
if re.findall(r'%\S+.weight', x):
node_scope = '.'.join(re.findall(r'%\S+.weight', x)[0].replace('%','',1).split('.')[:-1])
self.layer_weight_map[node_name] = node_scope

graph, params_dict, torch_out = torch.onnx.utils._model_to_graph(self.model, dummy_input)
nodes = list(graph.nodes())
return graph, nodes

def rename_nodes(self, node, node_id):
node_name = 'node' + node_id
return node_name

def node_connection(self, graph, node, node_name):
for node_input in list(node.inputs()):
if PytorchGraph.get_node_id(node_input.node()) and node_input.node() in graph.nodes():
node_input_name = 'node' + PytorchGraph.get_node_id(node_input.node())
self._make_connection(node_input_name, node_name)

def CreateGraphNode(self, node):
return PytorchGraphNode151(node)


Loading

0 comments on commit c9e18b8

Please sign in to comment.