[ONNX] Add pass that fuses Conv and BatchNormalization (pytorch#40547)
Summary:
Add a pass that fuses Conv and BatchNormalization nodes into a single Conv node.
This pass is only applied in inference mode (training is None or TrainingMode.EVAL).
Since this pass needs access to param_dict, it lives outside the peephole file where this kind of pass (fusing multiple nodes into one) is usually placed.
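
For reference, the arithmetic performed by the fusion can be sketched in a few lines of Python. This is an illustrative sketch of standard Conv/BatchNorm folding under the shape conditions the pass checks, not the C++ implementation itself; the function and argument names are hypothetical:

import torch

def fold_conv_bn(conv_w, conv_b, bn_scale, bn_bias, bn_mean, bn_var, eps):
    # Assumes a conv weight of shape [C_out, ...] and 1-D BatchNorm parameters
    # of length C_out, matching the conditions checked by the pass.
    std = (bn_var + eps).sqrt()
    scale = bn_scale / std
    # Scale every output channel of the conv weight by gamma / sqrt(var + eps).
    fused_w = conv_w * scale.reshape(-1, *([1] * (conv_w.dim() - 1)))
    # Fold the BatchNorm shift into the conv bias (create one if the conv had none).
    if conv_b is None:
        conv_b = torch.zeros_like(bn_mean)
    fused_b = (conv_b - bn_mean) * scale + bn_bias
    return fused_w, fused_b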

This PR also adds a skipIfNoEmbed wrapper to skip the debug_embed_params test:
the Conv/BatchNorm fusion pass changes the parameters of the ResNet model, so the parameters of the ONNX and PyTorch models no longer match. Because that mismatch is expected, the debug_embed_params test for test_resnet would fail by design, and is therefore skipped.
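
As a usage sketch (mirroring the new tests below; the module, input shape, and names are placeholders), the fusion only takes effect when exporting in eval mode, while a training-mode export keeps the BatchNormalization nodes:

import io
import torch

class ConvBN(torch.nn.Module):
    def __init__(self):
        super(ConvBN, self).__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = torch.nn.BatchNorm2d(16)

    def forward(self, x):
        return self.bn(self.conv(x))

model = ConvBN()
x = torch.randn(1, 3, 32, 32)

# Eval-mode export: Conv + BatchNormalization pairs are folded into single Conv nodes.
f_eval = io.BytesIO()
torch.onnx.export(model, (x,), f_eval, training=torch.onnx.TrainingMode.EVAL)

# Training-mode export: BatchNormalization nodes are preserved, so the fusion does not apply.
f_train = io.BytesIO()
torch.onnx.export(model, (x,), f_train, training=torch.onnx.TrainingMode.TRAINING)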

Pull Request resolved: pytorch#40547

Reviewed By: gchanan

Differential Revision: D22631687

Pulled By: bzinodev

fbshipit-source-id: fe45812400398a32541e797f727fd8697eb6d8c0
KsenijaS authored and facebook-github-bot committed Jul 22, 2020
1 parent ad7133d commit af5d0bf
Showing 9 changed files with 346 additions and 7 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -239,6 +239,8 @@ namespace c10 {
_(onnx, LogSoftmax) \
_(onnx, ReduceL1) \
_(onnx, ReduceL2) \
_(onnx, Conv) \
_(onnx, BatchNormalization) \
FORALL_ATTR_BASE_SYMBOLS(_) \
_(attr, Subgraph) \
_(attr, ReverseSubgraph) \
8 changes: 8 additions & 0 deletions test/onnx/test_pytorch_onnx_caffe2.py
@@ -54,6 +54,13 @@ def wrapper(self):
return func(self)
return wrapper

def skipIfNoEmbed(func):
def wrapper(self):
if not self.embed_params:
raise unittest.SkipTest("Skip debug embed_params test")
return func(self)
return wrapper

# def import_model(proto, input, workspace=None, use_gpu=True):
# model_def = onnx.ModelProto.FromString(proto)
# onnx.checker.check_model(model_def)
@@ -504,6 +511,7 @@ def test_inception(self):
self.run_model_test(inception_v3(), train=False, batch_size=BATCH_SIZE,
state_dict=state_dict, input=x)

@skipIfNoEmbed
def test_resnet(self):
state_dict = model_zoo.load_url(model_urls['resnet50'], progress=False)
self.run_model_test(resnet50(), train=False, batch_size=BATCH_SIZE,
45 changes: 45 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -300,6 +300,51 @@ def test_r2plus1d_18_video(self):
x = torch.randn(1, 3, 4, 112, 112, requires_grad=True)
self.run_test(model, (x,), rtol=1e-3, atol=1e-5)

def test_fuse_conv_bn1d(self):
class Fuse(torch.nn.Module):
def __init__(self):
super(Fuse, self).__init__()
self.conv = torch.nn.Conv1d(16, 33, 3, stride=2)
self.bn = torch.nn.BatchNorm1d(33)

def forward(self, x):
out = self.conv(x)
return self.bn(out)

model = Fuse()
x = torch.randn(20, 16, 50, requires_grad=True)
self.run_test(model, (x,))

def test_fuse_conv_bn2d(self):
class Fuse(torch.nn.Module):
def __init__(self):
super(Fuse, self).__init__()
self.conv = torch.nn.Conv2d(3, 2, kernel_size=1, stride=2, padding=3, bias=False)
self.bn = torch.nn.BatchNorm2d(2)

def forward(self, x):
out = self.conv(x)
return self.bn(out)

model = Fuse()
x = torch.randn(2, 3, 2, 2, requires_grad=True)
self.run_test(model, (x,))

def test_fuse_conv_bn3d(self):
class Fuse(torch.nn.Module):
def __init__(self):
super(Fuse, self).__init__()
self.conv = torch.nn.Conv3d(3, 2, (3, 5, 2), stride=(2, 1, 1), padding=(3, 2, 0), bias=False)
self.bn = torch.nn.BatchNorm3d(2)

def forward(self, x):
out = self.conv(x)
return self.bn(out)

model = Fuse()
x = torch.randn(2, 3, 10, 50, 100, requires_grad=True)
self.run_test(model, (x,), rtol=1e-3, atol=1e-6)

def test_reshape_constant_fold(self):
class Reshape(torch.nn.Module):
def __init__(self, ):
107 changes: 106 additions & 1 deletion test/onnx/test_utility_funs.py
@@ -3,11 +3,13 @@

import torch
import torch.onnx
from torch.onnx import utils, OperatorExportTypes
from torch.onnx import utils, OperatorExportTypes, TrainingMode
from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type
import torch.utils.cpp_extension
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion

import torchvision

import onnx
import onnxruntime # noqa

@@ -675,6 +677,109 @@ def forward(self, x):

np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01)

def test_fuse_conv_bn(self):
class Fuse(torch.nn.Module):
def __init__(self):
super(Fuse, self).__init__()
self.conv = torch.nn.Conv2d(3, 2, kernel_size=1, stride=2, padding=3, bias=True)
self.bn = torch.nn.BatchNorm2d(2)

def forward(self, x):
out = self.conv(x)
return self.bn(out)

x = torch.randn(2, 3, 2, 2, requires_grad=True)
graph, _, __ = utils._model_to_graph(Fuse(), (x, ),
do_constant_folding=True,
training=TrainingMode.EVAL)
for node in graph.nodes():
assert node.kind() != "onnx::BatchNormalization"
assert node.kind() == "onnx::Conv"

assert len(list(graph.nodes())) == 1

def test_fuse_resnet18(self):
model = torchvision.models.resnet18(pretrained=True)
x = torch.randn(2, 3, 224, 224, requires_grad=True)
graph, _, __ = utils._model_to_graph(model, (x, ),
do_constant_folding=True)

for node in graph.nodes():
assert node.kind() != "onnx::BatchNormalization"

def test_conv_bn(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
self.bn = torch.nn.BatchNorm2d(16, affine=True)

def forward(self, x):
x = self.conv(x)
bn = self.bn(x)
return bn

model = MyModule()
x = torch.randn(10, 3, 128, 128)

f = io.BytesIO()
torch.onnx.export(model, (x,), f,
opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = {ort_sess.get_inputs()[0].name: x.cpu().numpy()}
ort_outs1 = ort_sess.run(None, ort_inputs)

f = io.BytesIO()
torch.onnx.export(model, (x,), f,
opset_version=self.opset_version, training=torch.onnx.TrainingMode.EVAL)
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = {ort_sess.get_inputs()[0].name: x.cpu().numpy()}
ort_outs2 = ort_sess.run(None, ort_inputs)
[np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in zip(ort_outs1, ort_outs2)]

def test_multiple_conv_bn(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.conv2 = torch.nn.Conv2d(64, 2, kernel_size=1, stride=1, padding=0, bias=False)
self.conv3 = torch.nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1, bias=False)
self.bn = torch.nn.BatchNorm2d(64)
self.bn2 = torch.nn.BatchNorm2d(2)
self.relu = torch.nn.ReLU(inplace=True)
self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)


def forward(self, x):
x = self.conv1(x)
x = self.bn(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.conv3(x)
x = self.bn2(x)
x = self.relu(x)
return x

model = MyModule()
x = torch.randn(2, 3, 224, 224)

f = io.BytesIO()
torch.onnx.export(model, (x,), f,
opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = {ort_sess.get_inputs()[0].name: x.cpu().numpy()}
ort_outs1 = ort_sess.run(None, ort_inputs)
f = io.BytesIO()
torch.onnx.export(model, (x,), f,
opset_version=self.opset_version, training=torch.onnx.TrainingMode.EVAL)
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = {ort_sess.get_inputs()[0].name: x.cpu().numpy()}
ort_outs2 = ort_sess.run(None, ort_inputs)
[np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in zip(ort_outs1, ort_outs2)]


# opset 10 tests
TestUtilityFuns_opset10 = type(str("TestUtilityFuns_opset10"),
1 change: 1 addition & 0 deletions tools/build_variables.bzl
@@ -482,6 +482,7 @@ libtorch_python_core_sources = [
"torch/csrc/jit/python/init.cpp",
"torch/csrc/jit/passes/onnx.cpp",
"torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp",
"torch/csrc/jit/passes/onnx/eval_peephole.cpp",
"torch/csrc/jit/passes/onnx/constant_fold.cpp",
"torch/csrc/jit/passes/onnx/fixup_onnx_conditionals.cpp",
"torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp",
155 changes: 155 additions & 0 deletions torch/csrc/jit/passes/onnx/eval_peephole.cpp
@@ -0,0 +1,155 @@
#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/torch.h>

#include <c10/util/Optional.h>
#include <algorithm>

namespace torch {
namespace jit {

namespace onnx {
using namespace ::c10::onnx;
}

std::vector<at::Tensor> getValues(
Node* node,
const ValueToParamPairMap& valsToParamsMap) {
size_t numInputs = node->inputs().size();
std::vector<at::Tensor> inputTensorValues;
inputTensorValues.reserve(numInputs);
for (auto val : node->inputs()) {
if (val->node()->kind() == prim::Param) {
auto itr = valsToParamsMap.find(val);
if (itr == valsToParamsMap.end()) {
continue;
}
inputTensorValues.push_back(itr->second.second.toTensor());
} else if (val->node()->kind() == onnx::Constant) {
inputTensorValues.push_back(val->node()->t(attr::value));
} else {
continue;
}
}
return inputTensorValues;
}

// This pass fuses Conv and BatchNorm into a single Conv node.
// Conv and BatchNorm can be fused only if the BatchNorm inputs
// (scale, bias, mean, and var) are all tensors of the same shape (C) and
// the size of the first dimension (dim 0) of the Conv input weight
// matches the size of the BatchNorm input scale.
static void fuseConvBatchNorm(Block* b, ValueToParamPairMap& valsToParamsMap) {
for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
for (auto* child_block : it->blocks()) {
fuseConvBatchNorm(child_block, valsToParamsMap);
}
if (it->kind() == onnx::Conv) {
if (it->output()->uses().size() != 1) {
continue;
}
auto bnNode = it->output()->uses()[0].user;
if (bnNode->kind() != onnx::BatchNormalization) {
continue;
}
auto origconvNode = *it;
auto epsilon = bnNode->f(attr::epsilon);
auto w_conv_value = getValues(origconvNode, valsToParamsMap);
if (w_conv_value.size() < 1 ||
(origconvNode->inputs().size() == 3 && w_conv_value.size() != 2)) {
continue;
}

auto bn_value = getValues(bnNode, valsToParamsMap);
if (bn_value.size() != 4) {
continue;
}

auto bn_scale = bn_value[0].clone();
auto bn_B = bn_value[1].clone();
auto bn_mean = bn_value[2].clone();
auto bn_var = bn_value[3].clone();
auto w_conv = w_conv_value[0].clone();
at::Tensor b_conv;

if (!bn_scale.is_floating_point() || !bn_B.is_floating_point() ||
!bn_mean.is_floating_point() || !bn_var.is_floating_point() ||
!w_conv.is_floating_point() || bn_scale.dim() != 1 ||
bn_B.dim() != 1 || bn_mean.dim() != 1 || bn_var.dim() != 1 ||
!(bn_scale.size(0) == bn_B.size(0)) ||
!(bn_B.size(0) == bn_mean.size(0)) ||
!(bn_mean.size(0) == bn_var.size(0)) || !(w_conv.dim() > 2) ||
!(w_conv.size(0) == bn_scale.size(0))) {
continue;
}

bn_var = bn_var.add(epsilon);
bn_var = bn_var.sqrt();
bn_scale = bn_scale.div(bn_var);

// Calculate weight
for (size_t i = 0; i < w_conv.size(0); i++) {
w_conv[i] = w_conv[i].mul(bn_scale[i]);
}

// Calculate bias
if (origconvNode->inputs().size() == 3) {
b_conv = w_conv_value[1].clone();
b_conv = b_conv.sub(bn_mean);
b_conv = b_conv.mul(bn_scale);
b_conv = b_conv.add(bn_B);
} else {
bn_mean = bn_mean.mul(bn_scale);
bn_B = bn_B.sub(bn_mean);
b_conv = bn_B;
}

Node* convNode =
b->owningGraph()->create(onnx::Conv, bnNode->outputs().size());
for (size_t i = 0; i < convNode->outputs().size(); ++i) {
convNode->outputs()[i]->copyMetadata(bnNode->outputs()[i]);
}

convNode->copyAttributes(*origconvNode);
convNode->insertBefore(bnNode);
convNode->addInput(origconvNode->inputs().at(0));

auto conv_W = b->owningGraph()->addInput();
valsToParamsMap.insert(
{conv_W, std::make_pair(conv_W->debugName(), w_conv)});
conv_W->inferTypeFrom(w_conv);
convNode->addInput(conv_W);

auto conv_B = b->addInput();
valsToParamsMap.insert(
{conv_B, std::make_pair(conv_B->debugName(), b_conv)});
conv_B->inferTypeFrom(b_conv);
convNode->addInput(conv_B);

bnNode->replaceAllUsesWith(convNode);
bnNode->removeAllInputs();
it->removeAllInputs();
bnNode->destroy();
it.destroyCurrent();
}
}
}

void buildParamsMapFromValueToParamsMap(
const ValueToParamPairMap& valsToParamsMap,
ParamMap& paramsDict) {
paramsDict.clear();
for (const auto& nameTensorParamPair : valsToParamsMap) {
paramsDict.insert(nameTensorParamPair.second);
}
}

void EvalPeepholeONNX(Block* b, ParamMap& paramsDict) {
auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
fuseConvBatchNorm(b, valsToParamsMap);
buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
return;
}

} // namespace jit
} // namespace torch
12 changes: 12 additions & 0 deletions torch/csrc/jit/passes/onnx/eval_peephole.h
@@ -0,0 +1,12 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

void EvalPeepholeONNX(Block* b, std::map<std::string, IValue>& paramDict);

} // namespace jit

} // namespace torch