diff --git a/fetch-repos.sh b/fetch-repos.sh
index 073c052d67..9ad51fefb0 100755
--- a/fetch-repos.sh
+++ b/fetch-repos.sh
@@ -27,7 +27,7 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-QONNX_COMMIT="fd61cfeebbdaba351abf7e9d54cd785d7776fa4f"
+QONNX_COMMIT="1a4957ebf2aaf139217fd56109386d4518dd6127"
 FINN_EXP_COMMIT="de99347e936d51715f5356a1b6c64e37b91c23c2"
 BREVITAS_COMMIT="84f42259ec869eb151af4cb8a8b23ad925f493db"
 PYVERILATOR_COMMIT="ce0a08c20cb8c1d1e84181d6f392390f846adbd1"
@@ -40,7 +40,7 @@ RFSOC4x2_BDF_COMMIT="13fb6f6c02c7dfd7e4b336b18b959ad5115db696"
 KV260_BDF_COMMIT="98e0d3efc901f0b974006bc4370c2a7ad8856c79"
 EXP_BOARD_FILES_MD5="226ca927a16ea4ce579f1332675e9e9a"
 
-QONNX_URL="https://github.com/fastmachinelearning/qonnx.git"
+QONNX_URL="https://github.com/iksnagreb/qonnx.git"
 FINN_EXP_URL="https://github.com/Xilinx/finn-experimental.git"
 BREVITAS_URL="https://github.com/Xilinx/brevitas.git"
 PYVERILATOR_URL="https://github.com/maltanar/pyverilator.git"
diff --git a/src/finn/transformation/qonnx/fold_quant_weights.py b/src/finn/transformation/qonnx/fold_quant_weights.py
index 0f6cbacb82..59ebe4eea3 100644
--- a/src/finn/transformation/qonnx/fold_quant_weights.py
+++ b/src/finn/transformation/qonnx/fold_quant_weights.py
@@ -149,7 +149,8 @@ def apply(self, model):
                         mul_tensor = helper.make_tensor_value_info(
                             model.make_new_valueinfo_name(),
                             TensorProto.FLOAT,
-                            mul_shape,
+                            mul_shape,  # Note: This shape is known exactly as
+                            # it is an initializer with known shape
                         )
                         graph.value_info.append(mul_tensor)
                         model.set_initializer(mul_tensor.name, scale)
@@ -168,7 +169,9 @@ def apply(self, model):
                             act_mul_tensor = helper.make_tensor_value_info(
                                 model.make_new_valueinfo_name(),
                                 TensorProto.FLOAT,
-                                output_shape,
+                                None,  # Note: Explicitly delete the shape
+                                # annotation to be redone by the next shape
+                                # inference
                             )
                             graph.value_info.append(act_mul_tensor)
                             successor.output[0] = act_mul_tensor.name
@@ -186,19 +189,37 @@ def apply(self, model):
                             div_tensor = helper.make_tensor_value_info(
                                 model.make_new_valueinfo_name(),
                                 TensorProto.FLOAT,
-                                mul_shape,
+                                None,  # Note: Explicitly delete the shape
+                                # annotation to be redone by the next shape
+                                # inference
                             )
                             graph.value_info.append(div_tensor)
                             model.set_initializer(div_tensor.name, scale)
 
-                            succ_input_name = successor.input[0]
+                            # Detect which input of the add-like successor is
+                            # fed by the quantizer node to select the other
+                            # branch to insert the scale factor
+                            if successor.input[0] == node_out:
+                                succ_input_name = successor.input[1]
+                            else:
+                                succ_input_name = successor.input[0]
+
                             act_mul_tensor = helper.make_tensor_value_info(
                                 model.make_new_valueinfo_name(),
                                 TensorProto.FLOAT,
-                                output_shape,
+                                None,  # Note: Explicitly delete the shape
+                                # annotation to be redone by the next shape
+                                # inference
                             )
                             graph.value_info.append(act_mul_tensor)
-                            successor.input[0] = act_mul_tensor.name
+
+                            # Detect which input of the add-like successor is
+                            # fed by the quantizer node to select the other
+                            # branch to insert the scale factor
+                            if successor.input[0] == node_out:
+                                successor.input[1] = act_mul_tensor.name
+                            else:
+                                successor.input[0] = act_mul_tensor.name
 
                             div_node = helper.make_node(
                                 "Div",
@@ -210,6 +231,8 @@ def apply(self, model):
                         # remove old node
                         graph.node.remove(n)
                         graph_modified = True
+                        # Note: Running shape inference is necessary as shape
+                        # annotations have been deleted above
                         model = model.transform(InferShapes())
                         return (model, graph_modified)
         return (model, graph_modified)
diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
index 323e391df4..451ba52c29 100644
--- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py
+++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -25,8 +25,8 @@
 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
 import numpy as np
+import warnings
 from abc import ABC, abstractmethod
 from onnx import TensorProto, helper
 from qonnx.core.modelwrapper import ModelWrapper
@@ -70,7 +70,7 @@ def _check_compatibility(self):
     @abstractmethod
     def _calculate_act_bias(self):
         """Calculate the activation bias,
-        which is introduced as an Add node behind the MultiTrheshold node.
+        which is introduced as an Add node behind the MultiThreshold node.
         """
         raise NotImplementedError()
 
@@ -82,7 +82,7 @@ def _calculate_thresholds(self):
     @abstractmethod
     def _calculate_act_scale(self):
         """Calculate the activation scale,
-        which is indroduced as a Mul node behind the Add node
+        which is introduced as a Mul node behind the Add node
         for the activation bias.
         """
         raise NotImplementedError()
@@ -157,7 +157,7 @@ def replace_quant_node(self):
         # Set scale and bias
         # If these values are scalar then they can be set as attributes
         # of the MultiThreshold node, if not they get inserted as adder and mul nodes
-        # behind the MultiTrheshold nodes.
+        # behind the MultiThreshold nodes.
         bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
         scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
         if scale_scalar and bias_scalar and self._q_node.op_type == "BipolarQuant":
@@ -355,7 +355,7 @@ def _calculate_thresholds(self):
         act_node = self._model.find_direct_predecessors(self._q_node)
         act_node = act_node[0]
         if act_node.op_type == "Relu":
-            # Calculate thersholds, see: https://github.com/Xilinx/brevitas/blob/
+            # Calculate thresholds, see: https://github.com/Xilinx/brevitas/blob/
             # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/
             # onnx/finn/handler/act.py#L21
             num_distinct_values = 2**bit_width
@@ -395,8 +395,27 @@ def _calculate_thresholds(self):
                     else:
                         thresholds[c][t] = step / selu_scale
 
+        # First try to consider the tensor layout of the output for determining
+        # the number of output channels
+        layout = self._model.get_tensor_layout(self._q_node.output[0])
+        # If there is a layout annotation, use this to determine the index of
+        # the channel dimension
+        if layout is not None and "C" in layout:
+            # Lookup the index in list
+            cdim = layout.index("C")
+        # If no layout has been annotated or there is no channel dimension, fall
+        # back to the previous default assumption
+        else:
+            # Assume the channels to be in axis 1
+            cdim = 1
+            # Issue a warning to the user, so they are aware of this
+            warnings.warn(
+                f"No layout annotations for {self._q_node.output[0]}:"
+                f" Assuming channel dimension at index {cdim}"
+            )
+
         # ToDo: The index 1 needs to be changed to -1 for the channels last format
-        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
+        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim]
         final_shape = (num_output_channels, num_thresholds)
         if thresholds.shape != final_shape:
             thresholds = np.broadcast_to(thresholds, final_shape)
@@ -417,12 +436,12 @@ def _remove_activation_node(self, multi_threshold_node):
         act_node = self._model.find_direct_predecessors(self._q_node)
         if act_node is None:
             raise RuntimeError(
-                "For handling of Relu activations a predecesor to " "the Quant node must exist."
+                "For handling of Relu activations a predecessor to " "the Quant node must exist."
             )
         act_node = act_node[0]
         if act_node.op_type not in self.valid_predecessor_op_types():
             raise RuntimeError(
-                "The predecesor of the Quant node must be Relu or Selu for handling "
+                "The predecessor of the Quant node must be Relu or Selu for handling "
                 "of activations."
             )
 
@@ -509,7 +528,7 @@ def _calculate_thresholds(self):
         else:
             raise RuntimeError("Got an unexpected quantizer node type")
 
-        # Calculate thersholds, see: https://github.com/Xilinx/brevitas/
+        # Calculate thresholds, see: https://github.com/Xilinx/brevitas/
         # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
         # export/onnx/finn/handler/act.py#L76
         if bit_width == 1.0:
@@ -537,8 +556,28 @@ def _calculate_thresholds(self):
                 for t in range(num_thresholds):
                     thresholds[c][t] = min_threshold[c] + step[c] * t
 
+            # First try to consider the tensor layout of the output for
+            # determining the number of output channels
+            layout = self._model.get_tensor_layout(self._q_node.output[0])
+            # If there is a layout annotation, use this to determine the index
+            # of the channel dimension
+            if layout is not None and "C" in layout:
+                # Lookup the index in list
+                cdim = layout.index("C")
+            # If no layout has been annotated or there is no channel dimension,
+            # fall back to the previous default assumption
+            else:
+                # Assume the channels to be in axis 1
+                cdim = 1
+                # Issue a warning to the user, so they are aware of this
+                warnings.warn(
+                    f"No layout annotations for {self._q_node.output[0]}:"
+                    f" Assuming channel dimension at index {cdim}"
+                )
+
             # ToDo: The index 1 needs to be changed to -1 for the channels last format
-            num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
+            num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim]
+
             final_shape = (num_output_channels, num_thresholds)
             if thresholds.shape != final_shape:
                 thresholds = np.broadcast_to(thresholds, final_shape)
diff --git a/src/finn/transformation/streamline/__init__.py b/src/finn/transformation/streamline/__init__.py
index 2e68de698b..39ef87f81c 100644
--- a/src/finn/transformation/streamline/__init__.py
+++ b/src/finn/transformation/streamline/__init__.py
@@ -76,8 +76,8 @@ def apply(self, model):
             BatchNormToAffine(),
             ConvertSignToThres(),
             MoveMulPastMaxPool(),
-            MoveScalarLinearPastInvariants(),
             AbsorbSignBiasIntoMultiThreshold(),
+            MoveScalarLinearPastInvariants(),
             MoveAddPastMul(),
             MoveScalarAddPastMatMul(),
             MoveAddPastConv(),
diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index e3e2468bba..9d5239eb5f 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -30,6 +30,10 @@
 import qonnx.core.data_layout as DataLayout
 import warnings
 from onnx import helper as oh
+# Protobuf onnx graph node type
+from onnx import NodeProto  # noqa
+# QONNX wrapper of ONNX model graphs
+from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.core.datatype import DataType
 from qonnx.custom_op.registry import getCustomOp
 from qonnx.transformation.base import Transformation
@@ -100,6 +104,23 @@ def apply(self, model):
         return (model, graph_modified)
 
 
+# Groups inputs by categories, i.e., groups dynamic inputs first, followed by
+# initializers. Keeps order of inputs in each category.
+def group_inputs_by_category(node: NodeProto, model: ModelWrapper):  # noqa
+    # First select all dynamic inputs, which are those without initializer
+    # tensor
+    dynamics = [
+        i for i in node.input if model.get_initializer(i) is None
+    ]
+    # Select all inputs which are initializers, which, by exclusion, are all
+    # those not among the dynamic inputs
+    initializers = [
+        i for i in node.input if i not in dynamics
+    ]
+    # Return lists of dynamic and initializer inputs
+    return dynamics, initializers
+
+
 class AbsorbAddIntoMultiThreshold(Transformation):
     """Absorb preceding Add ops into MultiThreshold by updating the threshold
     values. Only scalar/1D add vectors can be absorbed."""
@@ -113,13 +134,19 @@ def apply(self, model):
             if n.op_type == "Add" and not model.is_fork_node(n) and not model.is_join_node(n):
                 consumer = model.find_consumer(n.output[0])
                 if consumer is not None and consumer.op_type == "MultiThreshold":
-                    add_weight_name = n.input[1]
-                    threshold_name = consumer.input[1]
-                    A = model.get_initializer(add_weight_name)
-                    T = model.get_initializer(threshold_name)
-                    assert A is not None, "Initializer for add weights is not set."
+                    # As Add is not a join node, there must be one initializer
+                    # and one dynamic input. We do not know their order, but
+                    # can group them accordingly to extract the tensor names
+                    (start,), (add_weight, ) = group_inputs_by_category(
+                        n, model
+                    )
+                    threshold = consumer.input[1]
+                    A = model.get_initializer(add_weight)
+                    T = model.get_initializer(threshold)
+                    # Test for the thresholds actually being initializers
+                    # Note: No need to validate the add_weights anymore, this
+                    # is already handled by the grouping and is_join_node test.
                     assert T is not None, "Initializer for thresholds is not set."
-                    start_name = n.input[0]
                     # we can only absorb 0d or 1d adds
                     is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape)
                     actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
@@ -128,13 +155,13 @@ def apply(self, model):
                         Tnew = T - A.reshape(-1, 1)
                         # Tnew = T - A.reshape(-1, T.shape[1])
                         # compute new thresholds and set initializer
-                        model.set_initializer(threshold_name, Tnew)
+                        model.set_initializer(threshold, Tnew)
                         # wire add input directly to MultiThreshold
-                        consumer.input[0] = start_name
+                        consumer.input[0] = start
                         # remove the add node
                         graph.node.remove(n)
                         graph_modified = True
-        return (model, graph_modified)
+        return model, graph_modified
 
 
 class AbsorbMulIntoMultiThreshold(Transformation):
@@ -215,7 +242,7 @@ def apply(self, model):
 
 
 class Absorb1BitMulIntoMatMul(Transformation):
-    """Absorb bipolar or binary multiplications into the preciding matrix
+    """Absorb bipolar or binary multiplications into the preceding matrix
     multiply."""
 
     def apply(self, model):
@@ -224,16 +251,28 @@ def apply(self, model):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "MatMul":
+            # Note: Join-node test is implicitly covered by testing for the
+            # initializer below
+            # Note: This cannot handle fork-nodes, as only the first consumer is
+            # considered below.
+            # TODO: Fork-nodes could be handled if the muls are the same in all
+            # branches, but this is not checked nor rewired at all right now.
+            if n.op_type == "MatMul" and not model.is_fork_node(n):
                 matmul_weight_name = n.input[1]
                 W = model.get_initializer(matmul_weight_name)
                 Wdt = model.get_tensor_datatype(matmul_weight_name)
-                assert W is not None, "Initializer for matmul weights is not set."
+                # Just skip matmuls with non-existing weight initializers
+                if W is None:
+                    continue
                 consumer = model.find_consumer(n.output[0])
+                # Note: Join-node test is implicitly covered by testing for the
+                # initializer below
                 if consumer is not None and consumer.op_type == "Mul":
                     mul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+                    # Just skip muls with non-existing scale initializers
+                    if A is None:
+                        continue
                     is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
                     if is_1bit:
                         Wnew = A * W
@@ -252,7 +291,7 @@ def apply(self, model):
 
 
 class Absorb1BitMulIntoConv(Transformation):
-    """Absorb bipolar or binary multiplications into the preciding convolution."""
+    """Absorb bipolar or binary multiplications into the preceding convolution."""
 
     def apply(self, model):
         graph = model.graph
@@ -260,16 +299,28 @@ def apply(self, model):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Conv":
+            # Note: Join-node test is implicitly covered by testing for the
+            # initializer below
+            # Note: This cannot handle fork-nodes, as only the first consumer is
+            # considered below.
+            # TODO: Fork-nodes could be handled if the muls are the same in all
+            # branches, but this is not checked nor rewired at all right now.
+            if n.op_type == "Conv" and not model.is_fork_node(n):
                 conv_weight_name = n.input[1]
                 W = model.get_initializer(conv_weight_name)
                 Wdt = model.get_tensor_datatype(conv_weight_name)
-                assert W is not None, "Initializer for conv weights is not set."
+                # Just skip convs with non-existing weight initializers
+                if W is None:
+                    continue
                 consumer = model.find_consumer(n.output[0])
+                # Note: Join-node test is implicitly covered by testing for the
+                # initializer below
                 if consumer is not None and consumer.op_type == "Mul":
                     mul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+                    # Just skip muls with non-existing scale initializers
+                    if A is None:
+                        continue
                     is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
                     is_scalar = np.prod(A.shape) == 1
                     actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 8ac2d7dad6..74cef0558a 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -100,58 +100,133 @@ def apply(self, model):
         return (model, graph_modified)
 
 
+# Tests whether a tensor is a scalar, i.e., whether all dimensions are 1
+def is_scalar(tensor):
+    return tensor is not None and all(x == 1 for x in tensor.shape)
+
+
+# Tests whether a node is a scalar multiplication with a constant scale factor
+def is_const_scalar_mul(node, model):
+    # Only handle existing Mul type nodes
+    if node is not None and node.op_type == "Mul":
+        # The constant must be an initializer
+        # Note: Assumes the constant parameter to always be the second input
+        scale = model.get_initializer(node.input[1])
+        # Test for existence of a constant scale factor
+        return scale is not None and is_scalar(scale)
+    # Did not match the operator type
+    return False
+
+
+# Refactored version of the MoveScalarMulPastMatMul transform capable of
+# transforming two-input MatMul, like those being part of the attention operator
 class MoveScalarMulPastMatMul(Transformation):
     """Move scalar mul operations past matmul operations. We want to have muls
     next to each other such that they can be collapsed into a single mul."""
 
+    # Applies the transform to a whole model graph
     def apply(self, model):
+        # Get the model graph out of the model wrapper object
         graph = model.graph
-        node_ind = 0
+        # Keep track of whether the graph has been modified
        graph_modified = False
-        for n in graph.node:
-            node_ind += 1
-            if n.op_type == "Mul" and not model.is_fork_node(n) and not model.is_join_node(n):
-                consumer = model.find_consumer(n.output[0])
-                if (
-                    consumer is not None
-                    and consumer.op_type == "MatMul"
-                    and not model.is_join_node(consumer)
-                ):
-                    mul_weight_name = n.input[1]
-                    matmul_weight_name = consumer.input[1]
-                    A = model.get_initializer(mul_weight_name)
-                    W = model.get_initializer(matmul_weight_name)
-                    if (A is None) or (W is None):
-                        warnings.warn("MatMul or Mul params are not constant, skipping")
+
+        # Iterate all nodes in the graph keeping track of the index
+        for index, node in enumerate(graph.node):
+            # First pattern matching condition: For the transform to be
+            # applicable, the node has to be a MatMul operator
+            if node.op_type == "MatMul":
+                # Note: When touching the following code, remember to treat both
+                # branches equivalently!
+                # TODO: Can this be enforced or at least be made easier by
+                # extracting common code patterns to a function?
+
+                # Get the left hand side and right hand side inputs
+                # Note: Assumes the ordering of left to right inputs to match
+                # indices 0 to 1. However, it does not "hurt" if it is
+                # reversed as both sides are treated equivalently.
+                lhs = model.find_producer(node.input[0])
+                rhs = model.find_producer(node.input[1])
+
+                # Give precedence to the left hand side input testing for the
+                # presence of a scalar multiplication
+                if is_const_scalar_mul(lhs, model):
+                    # Cannot handle fork nodes: We would have to distribute the
+                    # Mul into all branches
+                    # TODO: Maybe reconsider this at some point, there is
+                    # probably nothing preventing this in general, it is just
+                    # more difficult and apparently not necessary right now.
+                    if model.is_fork_node(lhs):
+                        # Softly skip this node
                         continue
-                    start_name = n.input[0]
-                    middle_name = n.output[0]
-                    end_name = consumer.output[0]
-                    mm_out_shape = model.get_tensor_shape(end_name)
-                    if all(x == 1 for x in A.shape):
-                        # if the mul is scalar, we can simply swap the order of ops
-                        # make and insert new nodes
-                        new_matmul = oh.make_node(
-                            "MatMul",
-                            [start_name, matmul_weight_name],
-                            [middle_name],
-                            name=consumer.name,
-                        )
-                        new_mul = oh.make_node(
-                            "Mul",
-                            [middle_name, mul_weight_name],
-                            [end_name],
-                            name=n.name,
-                        )
-                        graph.node.insert(node_ind, new_matmul)
-                        graph.node.insert(node_ind + 1, new_mul)
-                        model.set_tensor_shape(middle_name, mm_out_shape)
-                        # remove old nodes
-                        graph.node.remove(n)
-                        graph.node.remove(consumer)
-                        graph_modified = True
+                    # Unpack the connection pattern of a scalar mul feeding the
+                    # lhs input of the matmul
+                    # Names of the three input tensors to the mul-matmul complex
+                    a, b, c = lhs.input[0], lhs.input[1], node.input[1]
+                    # Names of the intermediate and the global output
+                    m, o = lhs.output[0], node.output[0]  # noqa: Duplicate code
+                    # Rewire the operator connections locally, swapping mul and
+                    # matmul operator order
+                    matmul = oh.make_node("MatMul", [a, c], [m], node.name)
+                    mul = oh.make_node("Mul", [m, b], [o], lhs.name)
+                    # Insert the rewired nodes into the graph
+                    graph.node.insert(index, matmul)
+                    graph.node.insert(index + 1, mul)
+                    # Adapt the shape of the intermediate tensor as it changed
+                    # according to the output shape of the matmul
+                    model.set_tensor_shape(m, model.get_tensor_shape(o))
+                    # Remove the old nodes from the graph
+                    graph.node.remove(lhs)
+                    graph.node.remove(node)
+                    # The graph has been modified, this needs to be reported
+                    # back to the caller
+                    graph_modified = True
+                    # Cannot further modify the node (i.e., the rhs) as the
+                    # index and state of the nodes changed and need to be
+                    # queried again from the graph.node at the start of the next
+                    # iteration.
+                    continue
+
+                # Next try whether the right hand side matches the pattern of a
+                # scalar multiplication
+                if is_const_scalar_mul(rhs, model):
+                    # Cannot handle fork nodes: We would have to distribute the
+                    # Mul into all branches
+                    # TODO: Maybe reconsider this at some point, there is
+                    # probably nothing preventing this in general, it is just
+                    # more difficult and apparently not necessary right now.
+                    if model.is_fork_node(rhs):
+                        # Softly skip this node
+                        continue
+                    # Unpack the connection pattern of a scalar mul feeding the
+                    # rhs input of the matmul
+                    # Names of the three input tensors to the mul-matmul complex
+                    a, b, c = node.input[0], rhs.input[0], rhs.input[1]
+                    # Names of the intermediate and the global output
+                    m, o = rhs.output[0], node.output[0]  # noqa: Duplicate code
+                    # Rewire the operator connections locally, swapping mul and
+                    # matmul operator order
+                    matmul = oh.make_node("MatMul", [a, b], [m], node.name)
+                    mul = oh.make_node("Mul", [m, c], [o], rhs.name)
+                    # Insert the rewired nodes into the graph
+                    graph.node.insert(index, matmul)
+                    graph.node.insert(index + 1, mul)
+                    # Adapt the shape of the intermediate tensor as it changed
+                    # according to the output shape of the matmul
+                    model.set_tensor_shape(m, model.get_tensor_shape(o))
+                    # Remove the old nodes from the graph
+                    graph.node.remove(rhs)
+                    graph.node.remove(node)
+                    # The graph has been modified, this needs to be reported
+                    # back to the caller
+                    graph_modified = True
+
+        # Finalize the transformation by inferring shapes again (as these might
+        # have changed)
         model = model.transform(InferShapes())
-        return (model, graph_modified)
+        # Return the transformed model and indicate whether the graph actually
+        # has been transformed
+        return model, graph_modified
 
 
 class MoveScalarAddPastMatMul(Transformation):
@@ -605,6 +680,17 @@ class MoveScalarLinearPastInvariants(Transformation):
     GlobalAveragePool
     """
 
+    # Op-types of currently supported invariants
+    SUPPORTED_INVARIANTS = {
+        "GlobalAveragePool",
+        "Reshape",
+        "Transpose",
+        "Flatten",
+        "Slice",
+        "Squeeze",
+        "Unsqueeze",
+    }
+
     def apply(self, model):
         graph = model.graph
         node_ind = 0
@@ -617,13 +703,7 @@ def apply(self, model):
                 # Extract mode and scales and input shape
                 mode = get_by_name(n.attribute, "mode").s.decode("ascii")
                 is_nearest_neighbor_resample = mode == "nearest"
-            if (
-                n.op_type == "GlobalAveragePool"
-                or n.op_type == "Reshape"
-                or n.op_type == "Transpose"
-                or n.op_type == "Flatten"
-                or is_nearest_neighbor_resample
-            ):
+            if n.op_type in self.SUPPORTED_INVARIANTS or is_nearest_neighbor_resample:
                 in0 = n.input[0]
                 if in0 is None:
                     continue
@@ -633,6 +713,16 @@ def apply(self, model):
                     continue
 
                 if prod0.op_type in ["Mul", "Add", "Div"]:
+                    # Cannot handle fork-nodes, try MoveLinearPastFork first
+                    if model.is_fork_node(prod0):
+                        warnings.warn(
+                            f"{self.__class__.__name__}:"
+                            f" Skipping near match: {prod0.name} is a fork-node,"
+                            f" try MoveLinearPastFork first"
+                        )
+                        # Skip transforming this node as moving this would lead
+                        # to messed up or detached graph
+                        continue
                     # check if second input of producer is an initializer
                     init0 = model.get_initializer(prod0.input[1])
                     # if either initializer is None, skip
diff --git a/tests/transformation/streamline/test_move_scalar_past_matmul.py b/tests/transformation/streamline/test_move_scalar_past_matmul.py
index e4f4357fff..515e9b9462 100644
--- a/tests/transformation/streamline/test_move_scalar_past_matmul.py
+++ b/tests/transformation/streamline/test_move_scalar_past_matmul.py
@@ -72,6 +72,43 @@ def test_move_scalar_mul_past_matmul():
     assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
 
 
+@pytest.mark.streamline
+def test_move_scalar_mul_past_join_matmul():
+    top_in1 = oh.make_tensor_value_info("top_in1", TensorProto.FLOAT, [1, 2])
+    top_in2 = oh.make_tensor_value_info("top_in2", TensorProto.FLOAT, [2, 1])
+    mul1_param = oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1, 1])
+    mul2_param = oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1, 1])
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [1, 1])
+    modelproto = qonnx_make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in1, top_in2],
+            outputs=[top_out],
+            value_info=[mul1_param, mul2_param],
+            nodes=[
+                oh.make_node("Mul", ["top_in1", "mul1_param"], ["middle1"]),
+                oh.make_node("Mul", ["top_in2", "mul2_param"], ["middle2"]),
+                oh.make_node("MatMul", ["middle1", "middle2"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+    model.set_initializer("mul1_param", np.asarray([[3]], dtype=np.float32))
+    model.set_initializer("mul2_param", np.asarray([[3]], dtype=np.float32))
+    new_model = model.transform(MoveScalarMulPastMatMul())
+    inp_dict = {
+        "top_in1": np.asarray([[-1.0, 1.0]], dtype=np.float32),
+        "top_in2": np.asarray([[1.0], [-1.0]], dtype=np.float32),
+    }
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "MatMul"
+    assert new_model.graph.node[1].op_type == "Mul"
+    assert new_model.graph.node[2].op_type == "Mul"
+    assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
+    assert new_model.graph.node[1].output[0] == new_model.graph.node[2].input[0]
+
+
 @pytest.mark.streamline
 def test_move_scalar_add_past_matmul():
     top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [1, 2])