diff --git a/fetch-repos.sh b/fetch-repos.sh index 6e904b3959..397f29637d 100755 --- a/fetch-repos.sh +++ b/fetch-repos.sh @@ -27,12 +27,12 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -QONNX_COMMIT="04e24583fb5c1895744801480db3ced8a5b6a914" +QONNX_COMMIT="47e4357faf66b5b0d1bf77bf908bb47752421e5b" FINN_EXP_COMMIT="de99347e936d51715f5356a1b6c64e37b91c23c2" -BREVITAS_COMMIT="9bb26bf2798de210a267d1e4aed4c20087e0e8a5" +BREVITAS_COMMIT="84f42259ec869eb151af4cb8a8b23ad925f493db" PYVERILATOR_COMMIT="766e457465f5c0dd315490d7b9cc5d74f9a76f4f" CNPY_COMMIT="4e8810b1a8637695171ed346ce68f6984e585ef4" -HLSLIB_COMMIT="c17aa478ae574971d115afa9fa4d9c215857d1ac" +HLSLIB_COMMIT="16e5847a5e3ef76cffe84c8fad2f010d593457d3" OMX_COMMIT="0b59762f9e4c4f7e5aa535ee9bc29f292434ca7a" AVNET_BDF_COMMIT="2d49cfc25766f07792c0b314489f21fe916b639b" XIL_BDF_COMMIT="8cf4bb674a919ac34e3d99d8d71a9e60af93d14e" diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py index c120667d81..d6c0794b00 100644 --- a/src/finn/custom_op/fpgadataflow/__init__.py +++ b/src/finn/custom_op/fpgadataflow/__init__.py @@ -43,6 +43,7 @@ from finn.custom_op.fpgadataflow.duplicatestreams_batch import DuplicateStreams_Batch from finn.custom_op.fpgadataflow.eltwise import StreamingEltwise from finn.custom_op.fpgadataflow.fmpadding_batch import FMPadding_Batch +from finn.custom_op.fpgadataflow.fmpadding_pixel import FMPadding_Pixel from finn.custom_op.fpgadataflow.fmpadding_rtl import FMPadding_rtl from finn.custom_op.fpgadataflow.globalaccpool_batch import GlobalAccPool_Batch from finn.custom_op.fpgadataflow.iodma import IODMA @@ -83,6 +84,7 @@ custom_op["GlobalAccPool_Batch"] = GlobalAccPool_Batch custom_op["Pool_Batch"] = Pool_Batch custom_op["FMPadding_Batch"] = FMPadding_Batch +custom_op["FMPadding_Pixel"] = FMPadding_Pixel custom_op["Thresholding_Batch"] = Thresholding_Batch custom_op["AddStreams_Batch"] = AddStreams_Batch custom_op["LabelSelect_Batch"] = LabelSelect_Batch diff --git a/src/finn/custom_op/fpgadataflow/fmpadding_pixel.py b/src/finn/custom_op/fpgadataflow/fmpadding_pixel.py new file mode 100644 index 0000000000..bc686bc6d2 --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/fmpadding_pixel.py @@ -0,0 +1,335 @@ +# Copyright (c) 2023, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of Xilinx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import numpy as np +import os +import warnings +from qonnx.core.datatype import DataType + +from finn.custom_op.fpgadataflow.hlscustomop import HLSCustomOp +from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy + + +class FMPadding_Pixel(HLSCustomOp): + def __init__(self, onnx_node, **kwargs): + super().__init__(onnx_node, **kwargs) + + def get_nodeattr_types(self): + my_attrs = { + # spatial size of input images + "ImgDim": ("ints", True, []), + # stride to apply, can be non-square + "Stride": ("ints", True, []), + # number of channels in input image + "NumChannels": ("i", True, 0), + # SIMD Input parallelism + "SIMD": ("i", False, 1), + # FINN input datatype + "inputDataType": ("s", True, ""), + # shape describing input vecs per execution + "numInputVectors": ("i", False, 1), + } + my_attrs.update(super().get_nodeattr_types()) + return my_attrs + + def get_padded_odim(self): + "Return the padded spatial size of the output." + idim_h, idim_w = self.get_nodeattr("ImgDim") + stride_h, stride_w = self.get_nodeattr("Stride") + odim_h = idim_h + (idim_h - 1) * (stride_h - 1) + odim_w = idim_w + (idim_w - 1) * (stride_w - 1) + return [odim_h, odim_w] + + def get_exp_cycles(self): + odim_h, odim_w = self.get_padded_odim() + channels = self.get_nodeattr("NumChannels") + simd = self.get_nodeattr("SIMD") + batch_size = self.get_nodeattr("numInputVectors") + exp_cycles = (channels / simd) * batch_size * odim_h * odim_w + return int(exp_cycles) + + def get_normal_input_shape(self, ind=0): + idim_h, idim_w = self.get_nodeattr("ImgDim") + num_ch = self.get_nodeattr("NumChannels") + ishape = (1, idim_h, idim_w, num_ch) + return ishape + + def get_normal_output_shape(self, ind=0): + odim_h, odim_w = self.get_padded_odim() + num_ch = self.get_nodeattr("NumChannels") + oshape = (1, odim_h, odim_w, num_ch) + return oshape + + def get_folded_input_shape(self, ind=0): + normal_ishape = list(self.get_normal_input_shape()) + ifm_ch = self.get_nodeattr("NumChannels") + simd = self.get_nodeattr("SIMD") + assert ifm_ch % simd == 0, "SIMD must divide input channels" + fold = int(normal_ishape[-1] / simd) + folded_ishape = normal_ishape[:-1] + [fold, simd] + return tuple(folded_ishape) + + def get_folded_output_shape(self, ind=0): + normal_oshape = list(self.get_normal_output_shape()) + ifm_ch = self.get_nodeattr("NumChannels") + simd = self.get_nodeattr("SIMD") + assert ifm_ch % simd == 0, "SIMD must divide input channels" + fold = int(normal_oshape[-1] / simd) + folded_oshape = normal_oshape[:-1] + [fold, simd] + return tuple(folded_oshape) + + def make_shape_compatible_op(self, model): + exp_ishape = self.get_normal_input_shape() + oshape = self.get_normal_output_shape() + ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0])) + assert ishape == exp_ishape, "Unexpect input shape for FMPadding_Pixel." 
+        return super().make_const_shape_op(oshape)
+
+    def infer_node_datatype(self, model):
+        node = self.onnx_node
+        idt = model.get_tensor_datatype(node.input[0])
+        if idt != self.get_input_datatype():
+            warn_str = "inputDataType changing for %s: %s -> %s " % (
+                node.name,
+                str(self.get_input_datatype()),
+                str(idt),
+            )
+            warnings.warn(warn_str)
+        self.set_nodeattr("inputDataType", idt.name)
+        model.set_tensor_datatype(node.output[0], idt)
+
+    def verify_node(self):
+        pass
+
+    def get_input_datatype(self, ind=0):
+        """Returns FINN DataType of input."""
+        ret = DataType[self.get_nodeattr("inputDataType")]
+        # the hlslib op always pads with zeros, so ensure that the DataType
+        # is able to represent zeros
+        assert ret.allowed(0), "FMPadding_Pixel DataType must support zero"
+        return ret
+
+    def get_output_datatype(self, ind=0):
+        """Returns FINN DataType of output. (Same as input datatype)"""
+        return self.get_input_datatype()
+
+    def get_instream_width(self, ind=0):
+        ibits = self.get_input_datatype().bitwidth()
+        simd = self.get_nodeattr("SIMD")
+        return ibits * simd
+
+    def get_outstream_width(self, ind=0):
+        obits = self.get_output_datatype().bitwidth()
+        simd = self.get_nodeattr("SIMD")
+        return obits * simd
+
+    def get_number_output_values(self):
+        folded_oshape = self.get_folded_output_shape()
+        return np.prod(folded_oshape[:-1])
+
+    def global_includes(self):
+        self.code_gen_dict["$GLOBALS$"] = ['#include "streamtools.h"']
+
+    def defines(self, var):
+        odim_h, odim_w = self.get_padded_odim()
+        stride_h, stride_w = self.get_nodeattr("Stride")
+        self.code_gen_dict["$DEFINES$"] = [
+            """
+            #define OutputDim_x {}\n
+            #define OutputDim_y {}\n
+            #define Stride_x {}\n
+            #define Stride_y {}\n
+            #define NumChannels {}\n
+            #define SIMD {}\n
+            """.format(
+                odim_w,
+                odim_h,
+                stride_w,
+                stride_h,
+                self.get_nodeattr("NumChannels"),
+                self.get_nodeattr("SIMD"),
+            )
+        ]
+
+    def read_npy_data(self):
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        dtype = self.get_input_datatype()
+        if dtype == DataType["BIPOLAR"]:
+            # use binary for bipolar storage
+            dtype = DataType["BINARY"]
+        elem_bits = dtype.bitwidth()
+        packed_bits = self.get_instream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        elem_hls_type = dtype.get_hls_datatype_str()
+        npy_type = "float"
+        npy_in = "%s/input_0.npy" % code_gen_dir
+        self.code_gen_dict["$READNPYDATA$"] = []
+        self.code_gen_dict["$READNPYDATA$"].append(
+            'npy2apintstream<%s, %s, %d, %s>("%s", in0);'
+            % (packed_hls_type, elem_hls_type, elem_bits, npy_type, npy_in)
+        )
+
+    def strm_decl(self):
+        self.code_gen_dict["$STREAMDECLARATIONS$"] = []
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> in0 ("in0");'.format(self.get_instream_width())
+        )
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> out ("out");'.format(self.get_outstream_width())
+        )
+
+    def docompute(self):
+        in_t = self.get_input_datatype().get_hls_datatype_str()
+        odim_h, odim_w = self.get_padded_odim()
+        stride_h, stride_w = self.get_nodeattr("Stride")
+        hls_call = "FMPadding_Pixel_Nonsquare"
+        self.code_gen_dict["$DOCOMPUTE$"] = [
+            """{}<OutputDim_x, OutputDim_y, Stride_x, Stride_y, NumChannels,
+            SIMD, {}> (in0, out);""".format(
+                hls_call, in_t
+            )
+        ]
+
+    def dataoutstrm(self):
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        dtype = self.get_output_datatype()
+        if dtype == DataType["BIPOLAR"]:
+            # use binary for bipolar storage
+            dtype = DataType["BINARY"]
+        elem_bits = dtype.bitwidth()
+        packed_bits = self.get_outstream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        elem_hls_type = 
dtype.get_hls_datatype_str() + npy_type = "float" + npy_out = "%s/output.npy" % code_gen_dir + oshape = self.get_folded_output_shape() + oshape_cpp_str = str(oshape).replace("(", "{").replace(")", "}") + + self.code_gen_dict["$DATAOUTSTREAM$"] = [ + 'apintstream2npy<%s, %s, %d, %s>(out, %s, "%s");' + % ( + packed_hls_type, + elem_hls_type, + elem_bits, + npy_type, + oshape_cpp_str, + npy_out, + ) + ] + + def save_as_npy(self): + self.code_gen_dict["$SAVEASCNPY$"] = [] + + def blackboxfunction(self): + packed_bits = self.get_instream_width() + packed_hls_type = "ap_uint<%d>" % packed_bits + self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ + "void %s(hls::stream<%s > &in0, hls::stream<%s > &out)" + % (self.onnx_node.name, packed_hls_type, packed_hls_type) + ] + + def pragmas(self): + self.code_gen_dict["$PRAGMAS$"] = [ + "#pragma HLS INTERFACE axis port=in0 name=in0_" + self.hls_sname() + ] + self.code_gen_dict["$PRAGMAS$"].append( + "#pragma HLS INTERFACE axis port=out name=out_" + self.hls_sname() + ) + self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE ap_ctrl_none port=return") + + def execute_node(self, context, graph): + mode = self.get_nodeattr("exec_mode") + node = self.onnx_node + exp_ishape = self.get_normal_input_shape() + exp_oshape = self.get_normal_output_shape() + folded_ishape = self.get_folded_input_shape() + + if mode == "cppsim": + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + elif mode == "rtlsim": + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + else: + raise Exception( + """Invalid value for attribute exec_mode! Is currently set to: {} + has to be set to one of the following value ("cppsim", "rtlsim")""".format( + mode + ) + ) + + inp = context[node.input[0]] + assert str(inp.dtype) == "float32", "Input datatype is not float32" + assert ( + inp.shape == exp_ishape + ), """Input shape doesn't + match expected shape (1, ImgDim_h, ImgDim_w, NumChannels).""" + export_idt = self.get_input_datatype() + + reshaped_input = inp.reshape(folded_ishape) + np.save(os.path.join(code_gen_dir, "input_0.npy"), reshaped_input) + + if mode == "cppsim": + # execute the precompiled model + super().exec_precompiled_singlenode_model() + # load output npy file + super().npy_to_dynamic_output(context) + assert ( + context[node.output[0]].shape == exp_oshape + ), "cppsim did not produce expected output shape" + elif mode == "rtlsim": + sim = self.get_rtlsim() + nbits = self.get_instream_width() + rtlsim_inp = npy_to_rtlsim_input( + "{}/input_0.npy".format(code_gen_dir), export_idt, nbits + ) + super().reset_rtlsim(sim) + super().toggle_clk(sim) + rtlsim_output = self.rtlsim(sim, rtlsim_inp) + odt = export_idt + target_bits = odt.bitwidth() + packed_bits = self.get_outstream_width() + out_npy_path = "{}/output.npy".format(code_gen_dir) + out_shape = self.get_folded_output_shape() + rtlsim_output_to_npy( + rtlsim_output, out_npy_path, odt, out_shape, packed_bits, target_bits + ) + # load and reshape output + output = np.load(out_npy_path) + output = np.asarray([output], dtype=np.float32).reshape(*exp_oshape) + context[node.output[0]] = output + else: + raise Exception( + """Invalid value for attribute exec_mode! 
Is currently set to: {} + has to be set to one of the following value ("cppsim", "rtlsim")""".format( + mode + ) + ) + assert ( + context[node.output[0]].shape == exp_oshape + ), """Output shape doesn't match expected shape + (1, OutputDim_H, OutputDim_W, NumChannels).""" diff --git a/src/finn/transformation/fpgadataflow/infer_pixel_padding_deconv.py b/src/finn/transformation/fpgadataflow/infer_pixel_padding_deconv.py new file mode 100644 index 0000000000..8dbf7071fc --- /dev/null +++ b/src/finn/transformation/fpgadataflow/infer_pixel_padding_deconv.py @@ -0,0 +1,205 @@ +import numpy as np +import warnings +from onnx import TensorProto, helper +from qonnx.transformation.base import Transformation +from qonnx.transformation.lower_convs_to_matmul import _auto_pad_to_explicit_padding +from qonnx.util.basic import get_by_name + + +class InferPixelPaddingDeconv(Transformation): + """ + Lowering and conversion of ConvTranspose (NCHW) nodes to + FMPadding_Pixel + Im2Col + MatMul (NHWC) surrounded by Transpose nodes + note: this transformation produces a mix of hw layers and non hw layers + to implement this on an FPGA the Im2Col and MatMul nodes need to be converted to hw layers + after applying this transformation and the resulting transpose nodes need to be streamlined. + See deconv test case under tests/fpgadataflow for an example. + """ + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "ConvTranspose": + # conversion currently only supported for group=1 + group = get_by_name(n.attribute, "group").i + if group != 1: + warnings.warn( + "%s : Only group=1 is currently supported. Can't infer PixelPaddingDeconv." + % n.name + ) + continue + deconv_input = n.input[0] + deconv_output = n.output[0] + idt = model.get_tensor_datatype(deconv_input) + odt = model.get_tensor_datatype(deconv_output) + k_h = get_by_name(n.attribute, "kernel_shape").ints[0] + k_w = get_by_name(n.attribute, "kernel_shape").ints[1] + stride_h = get_by_name(n.attribute, "strides").ints[0] + stride_w = get_by_name(n.attribute, "strides").ints[1] + weight_name = n.input[1] + W_conv = model.get_initializer(weight_name) + ifm_ch = model.get_tensor_shape(n.input[0])[1] # assume NCHW + ofm_ch = model.get_tensor_shape(n.output[0])[1] # assume NCHW + ifm_dim_h = model.get_tensor_shape(n.input[0])[2] # assume NCHW + ifm_dim_w = model.get_tensor_shape(n.input[0])[3] + ofm_dim_h = model.get_tensor_shape(n.output[0])[2] # assume NCHW + ofm_dim_w = model.get_tensor_shape(n.output[0])[3] + dilation_attr = get_by_name(n.attribute, "dilations") + if dilation_attr is not None: + dilation = dilation_attr.ints + else: + dilation = [1, 1] # default value + # handle both auto_pad and explicit padding + auto_pad = get_by_name(n.attribute, "auto_pad") + if auto_pad is not None: + # find equivalent specified padding + auto_pad = auto_pad.s.decode("utf-8") + if auto_pad == "NOTSET": + # use specified padding + pad = get_by_name(n.attribute, "pads").ints + else: + pad = _auto_pad_to_explicit_padding( + auto_pad, + ifm_dim_h, + ifm_dim_w, + k_h, + k_w, + stride_h, + stride_w, + len(model.get_tensor_shape(n.input[0])) - 2, + ) + else: + # use specified padding + pad = get_by_name(n.attribute, "pads").ints + + # If len(pad) == 2, assume no padding for other dimension + if len(pad) == 2: # only one dimension should be padded + assert ( + ifm_dim_h == 1 or ifm_dim_w == 1 + ), "Padding is assumed to be 1D, image is 2D" + # reuse ConvTranspose weights for new 
matmul weights + # conv weights are [IFM][OFM][k][k] + # We need to rotate the weights and make them [OFM][IFM][k][k] + # for pixel padding deconv to remain mathematically equivalent + # and then convert to [OFM][k][k][IFM] (to remain compatible + # with finn-hlslib and how it does im2col/sliding window) + W_conv = np.rot90(W_conv, 2, [2, 3]) + W_conv = np.moveaxis(W_conv, 0, 1) + W_matmul = W_conv.transpose(0, 2, 3, 1) # W_conv = [OFM, IFM, k_H, k_W] + # reshape into [OFM][k*k*IFM] matrix + W_matmul = W_matmul.reshape(ofm_ch, ifm_ch * k_h * k_w) + # transpose to get ONNX-compatible [k*k*IFM][OFM] matrix + W_matmul = W_matmul.T + model.set_initializer(weight_name, W_matmul) + + # Compute intermediate parameters + padded_odim_h = ifm_dim_h + (ifm_dim_h - 1) * (stride_h - 1) + padded_odim_w = ifm_dim_w + (ifm_dim_w - 1) * (stride_w - 1) + conv_padding = [dilation[0] * (k_h - 1) - pad[0]] * 4 + + # create new intermediate values + inp_trans_out = helper.make_tensor_value_info( + model.make_new_valueinfo_name(), + TensorProto.FLOAT, + (1, ifm_dim_h, ifm_dim_w, ifm_ch), # NHWC + ) + padding_pixel_out = helper.make_tensor_value_info( + model.make_new_valueinfo_name(), + TensorProto.FLOAT, + (1, padded_odim_h, padded_odim_w, ifm_ch), # NHWC + ) + graph.value_info.append(inp_trans_out) + graph.value_info.append(padding_pixel_out) + inp_trans_out = inp_trans_out.name + padding_pixel_out = padding_pixel_out.name + model.set_tensor_datatype(inp_trans_out, idt) + model.set_tensor_datatype(padding_pixel_out, idt) + + need_im2col = True + if all(p == 0 for p in conv_padding): + padding = 0 + + # k_h=k_w==1: pointwise convolution, thus no im2col needed + if k_h == 1 and k_w == 1 and padding == 0 and stride_h == 1 and stride_w == 1: + need_im2col = False + + if need_im2col: + im2col_out = helper.make_tensor_value_info( + model.make_new_valueinfo_name(), + TensorProto.FLOAT, + (1, ofm_dim_h, ofm_dim_w, ifm_ch * k_h * k_w), + ) + graph.value_info.append(im2col_out) + im2col_out = im2col_out.name + model.set_tensor_datatype(im2col_out, idt) + + matmul_out = helper.make_tensor_value_info( + model.make_new_valueinfo_name(), + TensorProto.FLOAT, + (1, ofm_dim_h, ofm_dim_w, ofm_ch), + ) + graph.value_info.append(matmul_out) + matmul_out = matmul_out.name + model.set_tensor_datatype(matmul_out, odt) + + # create new nodes + + # NCHW -> NHWC + inp_trans_node = helper.make_node( + "Transpose", [deconv_input], [inp_trans_out], perm=[0, 2, 3, 1] + ) + # Pixel Padding + fmpadding_pixel_node = helper.make_node( + "FMPadding_Pixel", + [inp_trans_out], + [padding_pixel_out], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + ImgDim=(ifm_dim_h, ifm_dim_w), + Stride=[stride_h, stride_w], + NumChannels=ifm_ch, + inputDataType=str(idt.name), + numInputVectors=1, + SIMD=1, + ) + # lower input tensor + matmul_input = padding_pixel_out + if need_im2col: + matmul_input = im2col_out + im2col_node = helper.make_node( + "Im2Col", + [padding_pixel_out], + [im2col_out], + domain="qonnx.custom_op.general", + stride=[1, 1], + kernel_size=[k_h, k_w], + pad_amount=conv_padding, + input_shape="(1,{},{},{})".format(padded_odim_h, padded_odim_w, ifm_ch), + depthwise=False, + dilations=dilation, + ) + + # do matmul + matmul_node = helper.make_node("MatMul", [matmul_input, weight_name], [matmul_out]) + # NHWC -> NCHW + out_trans_node = helper.make_node( + "Transpose", [matmul_out], [deconv_output], perm=[0, 3, 1, 2] + ) + # insert nodes where the conv is to preserve topological ordering + graph.node.insert(node_ind, 
inp_trans_node) + if need_im2col: + graph.node.insert(node_ind + 1, fmpadding_pixel_node) + graph.node.insert(node_ind + 2, im2col_node) + graph.node.insert(node_ind + 3, matmul_node) + graph.node.insert(node_ind + 4, out_trans_node) + else: + graph.node.insert(node_ind + 1, fmpadding_pixel_node) + graph.node.insert(node_ind + 2, matmul_node) + graph.node.insert(node_ind + 3, out_trans_node) + # remove old nodes + graph.node.remove(n) + + return (model, graph_modified) diff --git a/src/finn/transformation/fpgadataflow/set_folding.py b/src/finn/transformation/fpgadataflow/set_folding.py index eca1053f8f..4045a28e16 100644 --- a/src/finn/transformation/fpgadataflow/set_folding.py +++ b/src/finn/transformation/fpgadataflow/set_folding.py @@ -112,6 +112,7 @@ def apply(self, model): simd_ops = [ "DownSampler", "FMPadding_Batch", + "FMPadding_Pixel", "ConvolutionInputGenerator", "ConvolutionInputGenerator1D", "ConvolutionInputGenerator_rtl", diff --git a/src/finn/transformation/qonnx/fold_quant_weights.py b/src/finn/transformation/qonnx/fold_quant_weights.py index e027010271..0f6cbacb82 100644 --- a/src/finn/transformation/qonnx/fold_quant_weights.py +++ b/src/finn/transformation/qonnx/fold_quant_weights.py @@ -97,7 +97,14 @@ def apply(self, model): model.set_initializer(node_out, q_node_output) else: # Check next operator type - mul_like_nodes = ["Mul", "Div", "Conv", "MatMul", "Gather"] + mul_like_nodes = [ + "Mul", + "Div", + "Conv", + "MatMul", + "Gather", + "ConvTranspose", + ] add_like_nodes = ["Add", "Sub"] all_supported_ops = mul_like_nodes.copy() all_supported_ops.extend(add_like_nodes) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 2e6aebf093..8ac2d7dad6 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -338,6 +338,55 @@ def apply(self, model): return (model, graph_modified) +class MoveScalarMulPastConvTranspose(Transformation): + """Move scalar mul operations past ConvTranspose operations. 
We want to have muls + next to each other such that they can be collapsed into a single mul.""" + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "Mul" and not model.is_fork_node(n) and not model.is_join_node(n): + consumer = model.find_consumer(n.output[0]) + if ( + consumer is not None + and consumer.op_type == "ConvTranspose" + and not model.is_join_node(consumer) + ): + mul_weight_name = n.input[1] + A = model.get_initializer(mul_weight_name) + if A is None: + warnings.warn("Mul param is not constant, skipping") + continue + conv_node = consumer + mul_node = n + start_name = mul_node.input[0] + conv_in_name = conv_node.input[0] + conv_in_shape = model.get_tensor_shape(conv_in_name) + conv_out_name = conv_node.output[0] + conv_out_shape = model.get_tensor_shape(conv_out_name) + if all(x == 1 for x in A.shape): + # if the mul is scalar, we can simply swap the order of ops + # rewire mul input to be conv input + conv_node.input[0] = start_name + model.set_tensor_shape(start_name, conv_in_shape) + # use old conv input tensor as conv output + conv_node.output[0] = conv_in_name + model.set_tensor_shape(conv_in_name, conv_out_shape) + # use new conv output as new mul node input + mul_node.input[0] = conv_in_name + # use old conv output as new mul node output + mul_node.output[0] = conv_out_name + # move add node past conv node + graph.node.remove(mul_node) + graph.node.insert(node_ind, mul_node) + graph_modified = True + model = model.transform(InferShapes()) + return (model, graph_modified) + + class MoveMulPastDWConv(Transformation): """Move channelwise mul operations past depthwise conv operations. We want to have muls next to each other such that they can be collapsed into a single mul.""" diff --git a/tests/brevitas/test_brevitas_deconv.py b/tests/brevitas/test_brevitas_deconv.py new file mode 100644 index 0000000000..dfcecc9187 --- /dev/null +++ b/tests/brevitas/test_brevitas_deconv.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of Xilinx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import pytest + +import brevitas.nn as qnn +import numpy as np +import os +import torch +from brevitas.export import export_qonnx +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.transformation.infer_shapes import InferShapes +from qonnx.util.cleanup import cleanup as qonnx_cleanup + +import finn.core.onnx_exec as oxe +from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN + +export_path = "test_brevitas_deconv.onnx" + + +@pytest.mark.brevitas_export +@pytest.mark.parametrize("ifm_ch", [3]) +@pytest.mark.parametrize("ofm_ch", [5]) +@pytest.mark.parametrize("mh", [4]) +@pytest.mark.parametrize("mw", [4]) +@pytest.mark.parametrize("padding", [1]) +@pytest.mark.parametrize("stride", [2]) +@pytest.mark.parametrize("kw", [4]) +@pytest.mark.parametrize("bias", [False]) +def test_brevitas_QTransposeConv(ifm_ch, ofm_ch, mh, mw, padding, stride, kw, bias): + kh = kw + oh = stride * (mh - 1) - (2 * padding) + kh + if oh % mh != 0: + pytest.skip("Skip test because oh needs to be divisible by mh") + ishape = (1, ifm_ch, mh, mw) # NCHW + inp = torch.randn(ishape) + b_deconv = qnn.QuantConvTranspose2d( + in_channels=ifm_ch, + out_channels=ofm_ch, + kernel_size=kw, + stride=stride, + padding=padding, + bias=bias, + ) + # outp = el(inp) # expects NCHW data format + export_qonnx(b_deconv, input_t=inp, export_path=export_path, opset_version=11) + qonnx_cleanup(export_path, out_file=export_path) + model = ModelWrapper(export_path) + model = model.transform(ConvertQONNXtoFINN()) + model = model.transform(InferShapes()) + inp_tensor = np.random.uniform(low=-1.0, high=1.0, size=ishape).astype(np.float32) + idict = {model.graph.input[0].name: inp_tensor} + odict = oxe.execute_onnx(model, idict, True) + produced = odict[model.graph.output[0].name] + inp_tensor = torch.from_numpy(inp_tensor).float() + expected = b_deconv.forward(inp_tensor).detach().numpy() + assert np.isclose(produced, expected, atol=1e-3).all() + os.remove(export_path) diff --git a/tests/fpgadataflow/test_fpgadataflow_deconv.py b/tests/fpgadataflow/test_fpgadataflow_deconv.py new file mode 100644 index 0000000000..6c25be0f85 --- /dev/null +++ b/tests/fpgadataflow/test_fpgadataflow_deconv.py @@ -0,0 +1,207 @@ +# Copyright (c) 2023, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of Xilinx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +import numpy as np +import os +from onnx import TensorProto, helper +from qonnx.core.datatype import DataType +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.registry import getCustomOp +from qonnx.transformation.general import GiveUniqueNodeNames +from qonnx.transformation.infer_shapes import InferShapes +from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model + +import finn.core.onnx_exec as oxe +from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer +from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim +from finn.transformation.fpgadataflow.convert_to_hls_layers import ( + InferConvInpGen, + InferQuantizedMatrixVectorActivation, +) +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +from finn.transformation.fpgadataflow.infer_pixel_padding_deconv import ( + InferPixelPaddingDeconv, +) +from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim +from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode +from finn.util.basic import pynq_part_map + +test_pynq_board = os.getenv("PYNQ_BOARD", default="Pynq-Z1") +test_fpga_part = pynq_part_map[test_pynq_board] +target_clk_ns = 10 + + +def set_up_reference_model(idt, wdt, k, idim, ifm_ch, ofm_ch, stride, padding): + idim_h, idim_w = idim + stride_h, stride_w = stride + odim_h = (idim_h - 1) * stride_h - 2 * padding + (k - 1) + 1 + odim_w = (idim_w - 1) * stride_w - 2 * padding + (k - 1) + 1 + odt = DataType["INT32"] + + inp = helper.make_tensor_value_info( + "inp", + TensorProto.FLOAT, + [ + 1, + ifm_ch, + idim_h, + idim_w, + ], + ) + outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, ofm_ch, odim_h, odim_w]) + + W = helper.make_tensor_value_info("W", TensorProto.FLOAT, [ifm_ch, ofm_ch, k, k]) + + ConvTranspose = helper.make_node( + "ConvTranspose", + ["inp", "W"], + ["outp"], + dilations=(1, 1), + group=1, + kernel_shape=(k, k), + pads=(padding, padding, padding, padding), + strides=(stride_h, stride_w), + ) + + node_list = [ConvTranspose] + value_info = [W] + + graph = helper.make_graph( + nodes=node_list, + name="convtranspose_graph", + inputs=[inp], + outputs=[outp], + value_info=value_info, + ) + + model = qonnx_make_model(graph, producer_name="convtranspose-model") + model = ModelWrapper(model) + + # initialize model + model.set_tensor_datatype("inp", idt) + model.set_tensor_datatype(model.graph.output[0].name, odt) + model.set_tensor_datatype("W", wdt) + + w_tensor = gen_finn_dt_tensor(wdt, [ifm_ch, ofm_ch, k, k]) + model.set_initializer("W", w_tensor) + + model 
= model.transform(InferShapes()) + + return model + + +# input image dimension +@pytest.mark.parametrize("idim", [[8, 8], [10, 8]]) +# number of rows and number of cols to add +@pytest.mark.parametrize("stride", [[2, 2], [2, 3]]) +# number of channels +@pytest.mark.parametrize("ifm_ch", [2]) +# number of channels +@pytest.mark.parametrize("ofm_ch", [4]) +# Input parallelism +@pytest.mark.parametrize("simd", [1, 2]) +# PE +@pytest.mark.parametrize("pe", [1, 2]) +# kernel size +@pytest.mark.parametrize("k", [2]) +# padding +@pytest.mark.parametrize("padding", [0, 1]) +# exec mode +@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"]) +@pytest.mark.fpgadataflow +@pytest.mark.slow +@pytest.mark.vivado +def test_fpgadataflow_deconv(idim, stride, ifm_ch, ofm_ch, simd, pe, k, padding, exec_mode): + idt = wdt = DataType["INT4"] + wdt = idt + idim_h, idim_w = idim + stride_h, stride_w = stride + + if idim_h == idim_w and stride_h == stride_w: + convinpgen_rtl = False + else: + convinpgen_rtl = True + + if exec_mode == "cppsim" and convinpgen_rtl: + pytest.skip("ConvolutionInputGenerator_rtl has no cppsim, skipping cppsim") + + ref_model = set_up_reference_model(idt, wdt, k, idim, ifm_ch, ofm_ch, stride, padding) + + odim_h = (idim_h - 1) * stride_h - 2 * padding + (k - 1) + 1 + odim_w = (idim_w - 1) * stride_w - 2 * padding + (k - 1) + 1 + + input_tensor = gen_finn_dt_tensor(idt, [1, ifm_ch, idim_h, idim_w]) + input_dict = {"inp": input_tensor} + + model = ref_model.transform(InferPixelPaddingDeconv()) + model = model.transform(InferConvInpGen(use_rtl_variant=convinpgen_rtl)) + model = model.transform(InferQuantizedMatrixVectorActivation()) + model = model.transform(InferShapes()) + model = model.transform(GiveUniqueNodeNames()) + + for n in model.graph.node: + if n.op_type == "ConvolutionInputGenerator" and not convinpgen_rtl: + convinputgen_node = getCustomOp(n) + convinputgen_node.set_nodeattr("SIMD", simd) + elif n.op_type == "MatrixVectorActivation": + mvau_node = getCustomOp(n) + mvau_node.set_nodeattr("PE", pe) + mvau_node.set_nodeattr("SIMD", simd) + + expected_oshape = (1, ofm_ch, odim_h, odim_w) + y_expected = oxe.execute_onnx(ref_model, input_dict)["outp"] + + # cppsim + if exec_mode == "cppsim": + model = model.transform(PrepareCppSim()) + model = model.transform(CompileCppSim()) + model = model.transform(SetExecMode("cppsim")) + + # rtlsim + else: + model = model.transform(PrepareIP(test_fpga_part, target_clk_ns)) + model = model.transform(HLSSynthIP()) + model = model.transform(PrepareRTLSim()) + model = model.transform(SetExecMode("rtlsim")) + + y_produced = oxe.execute_onnx(model, input_dict)["outp"] + assert y_produced.shape == expected_oshape + assert (y_produced == y_expected).all() + + if exec_mode == "rtlsim": + node = model.get_nodes_by_op_type("FMPadding_Pixel")[0] + inst = getCustomOp(node) + cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim") + exp_cycles_dict = model.analysis(exp_cycles_per_layer) + exp_cycles = exp_cycles_dict[node.name] + assert np.isclose(exp_cycles, cycles_rtlsim, atol=10) + assert exp_cycles != 0 diff --git a/tests/transformation/streamline/test_move_scalar_past_convtranspose.py b/tests/transformation/streamline/test_move_scalar_past_convtranspose.py new file mode 100644 index 0000000000..7da22abd87 --- /dev/null +++ b/tests/transformation/streamline/test_move_scalar_past_convtranspose.py @@ -0,0 +1,106 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import pytest + +import numpy as np +import onnx.helper as oh +from onnx import TensorProto +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.transformation.infer_shapes import InferShapes +from qonnx.util.basic import qonnx_make_model + +import finn.core.onnx_exec as ox +from finn.transformation.streamline.reorder import MoveScalarMulPastConvTranspose + + +@pytest.mark.streamline +# input image dimension +@pytest.mark.parametrize("idim", [[8, 8], [10, 8]]) +# number of rows and number of cols to add +@pytest.mark.parametrize("stride", [[2, 2], [2, 3]]) +# number of channels +@pytest.mark.parametrize("ifm_ch", [2, 4]) +# number of channels +@pytest.mark.parametrize("ofm_ch", [2, 4]) +# kernel size +@pytest.mark.parametrize("k", [2, 4]) +# padding +@pytest.mark.parametrize("padding", [False, True]) +def test_move_scalar_past_conv(idim, stride, ifm_ch, ofm_ch, k, padding): + idim_h, idim_w = idim + stride_h, stride_w = stride + + odim_h = (idim_h - 1) * stride_h - 2 * padding + (k - 1) + 1 + odim_w = (idim_w - 1) * stride_w - 2 * padding + (k - 1) + 1 + + input_shape = [1, ifm_ch, idim_h, idim_w] + output_shape = [1, ofm_ch, odim_h, odim_w] + + conv_param_shape = [ifm_ch, ofm_ch, k, k] + + conv_config = {} + conv_config["dilations"] = [1, 1] + conv_config["group"] = 1 + conv_config["kernel_shape"] = [k, k] + if padding: + conv_config["pads"] = [1, 1, 1, 1] + else: + conv_config["pads"] = [0, 0, 0, 0] + conv_config["strides"] = [stride_h, stride_w] + + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape) + + value_info = [oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1])] + value_info += [oh.make_tensor_value_info("p2", TensorProto.FLOAT, conv_param_shape)] + + modelproto = qonnx_make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=value_info, + nodes=[ + oh.make_node("Mul", ["top_in", "p1"], ["t1"]), + oh.make_node("ConvTranspose", ["t1", "p2"], ["top_out"], **conv_config), 
+ ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform(InferShapes()) + + np.random.seed(0) + model.set_initializer("p1", *np.random.rand(1).astype(np.float32)) + model.set_initializer("p2", np.random.rand(*conv_param_shape).astype(np.float32)) + + new_model = model.transform(MoveScalarMulPastConvTranspose()) + inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)} + + assert ox.compare_execution(model, new_model, inp_dict) + assert new_model.graph.node[0].op_type == "ConvTranspose" + assert new_model.graph.node[1].op_type == "Mul"
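
Reviewer note (outside the patch): FMPadding_Pixel inserts stride-1 zero pixels between neighbouring input pixels, which is where the `odim = idim + (idim - 1) * (stride - 1)` expression in `get_padded_odim` comes from. Below is a minimal NumPy sketch of that behaviour, useful as a golden reference when checking cppsim/rtlsim outputs; `pixel_pad_ref` is an illustrative name and not part of the codebase.

```python
import numpy as np


def pixel_pad_ref(x, stride_h, stride_w):
    # x is an NHWC tensor; insert (stride - 1) zero rows/cols between pixels,
    # matching odim = idim + (idim - 1) * (stride - 1) from get_padded_odim()
    n, h, w, c = x.shape
    oh = h + (h - 1) * (stride_h - 1)
    ow = w + (w - 1) * (stride_w - 1)
    out = np.zeros((n, oh, ow, c), dtype=x.dtype)
    out[:, ::stride_h, ::stride_w, :] = x
    return out


x = np.arange(2 * 3 * 3 * 1, dtype=np.float32).reshape(2, 3, 3, 1)
y = pixel_pad_ref(x, 2, 2)
assert y.shape == (2, 5, 5, 1)  # 3 + (3 - 1) * (2 - 1) = 5 in both spatial dims
```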
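
For context on how the new pieces are intended to compose, the sketch below mirrors the flow in tests/fpgadataflow/test_fpgadataflow_deconv.py: lower ConvTranspose with InferPixelPaddingDeconv, then convert the resulting Im2Col/MatMul nodes into HLS layers. It is a sketch only; "deconv.onnx" is a placeholder for a FINN-ready model containing a quantized ConvTranspose, and folding/streamlining steps are omitted.

```python
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.general import GiveUniqueNodeNames
from qonnx.transformation.infer_shapes import InferShapes

from finn.transformation.fpgadataflow.convert_to_hls_layers import (
    InferConvInpGen,
    InferQuantizedMatrixVectorActivation,
)
from finn.transformation.fpgadataflow.infer_pixel_padding_deconv import (
    InferPixelPaddingDeconv,
)

# placeholder path; assumes a FINN-ready model with a quantized ConvTranspose node
model = ModelWrapper("deconv.onnx")
# ConvTranspose -> Transpose + FMPadding_Pixel + Im2Col + MatMul + Transpose
model = model.transform(InferPixelPaddingDeconv())
# convert the remaining non-HW Im2Col / MatMul nodes into HLS layers
# (use_rtl_variant=True selects ConvolutionInputGenerator_rtl, as in the test)
model = model.transform(InferConvInpGen())
model = model.transform(InferQuantizedMatrixVectorActivation())
model = model.transform(InferShapes())
model = model.transform(GiveUniqueNodeNames())
model.save("deconv_hls.onnx")
```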