Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streamlining of Scaled Dot-Product Attention #901

Draft
wants to merge 17 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions fetch-repos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

QONNX_COMMIT="fd61cfeebbdaba351abf7e9d54cd785d7776fa4f"
QONNX_COMMIT="1a4957ebf2aaf139217fd56109386d4518dd6127"
FINN_EXP_COMMIT="de99347e936d51715f5356a1b6c64e37b91c23c2"
BREVITAS_COMMIT="84f42259ec869eb151af4cb8a8b23ad925f493db"
PYVERILATOR_COMMIT="ce0a08c20cb8c1d1e84181d6f392390f846adbd1"
Expand All @@ -40,7 +40,7 @@ RFSOC4x2_BDF_COMMIT="13fb6f6c02c7dfd7e4b336b18b959ad5115db696"
KV260_BDF_COMMIT="98e0d3efc901f0b974006bc4370c2a7ad8856c79"
EXP_BOARD_FILES_MD5="226ca927a16ea4ce579f1332675e9e9a"

QONNX_URL="https://github.com/fastmachinelearning/qonnx.git"
QONNX_URL="https://github.com/iksnagreb/qonnx.git"
FINN_EXP_URL="https://github.com/Xilinx/finn-experimental.git"
BREVITAS_URL="https://github.com/Xilinx/brevitas.git"
PYVERILATOR_URL="https://github.com/maltanar/pyverilator.git"
Expand Down
35 changes: 29 additions & 6 deletions src/finn/transformation/qonnx/fold_quant_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ def apply(self, model):
mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
mul_shape,
mul_shape, # Note: This shape is known exactly as
# it is an initializer with known shape
)
graph.value_info.append(mul_tensor)
model.set_initializer(mul_tensor.name, scale)
Expand All @@ -168,7 +169,9 @@ def apply(self, model):
act_mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
output_shape,
None, # Note: Explicitly delete the shape
# annotation to be redone by the next shape
# inference
)
graph.value_info.append(act_mul_tensor)
successor.output[0] = act_mul_tensor.name
Expand All @@ -186,19 +189,37 @@ def apply(self, model):
div_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
mul_shape,
None, # Note: Explicitly delete the shape
# annotation to be redone by the next shape
# inference
)
graph.value_info.append(div_tensor)
model.set_initializer(div_tensor.name, scale)

succ_input_name = successor.input[0]
# Detect which input of the add-like successor is
# fed by the quantizer node to select the other
# branch to insert the scale factor
if successor.input[0] == node_out:
succ_input_name = successor.input[1]
else:
succ_input_name = successor.input[0]

act_mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
output_shape,
None, # Note: Explicitly delete the shape
# annotation to be redone by the next shape
# inference
)
graph.value_info.append(act_mul_tensor)
successor.input[0] = act_mul_tensor.name

# Detect which input of the add-like successor is
# fed by the quantizer node to select the other
# branch to insert the scale factor
if successor.input[0] == node_out:
successor.input[1] = act_mul_tensor.name
else:
successor.input[0] = act_mul_tensor.name

div_node = helper.make_node(
"Div",
Expand All @@ -210,6 +231,8 @@ def apply(self, model):
# remove old node
graph.node.remove(n)
graph_modified = True
# Note: Running shape inference is necessary as shape
# annotations have been deleted above
model = model.transform(InferShapes())
return (model, graph_modified)
return (model, graph_modified)
59 changes: 49 additions & 10 deletions src/finn/transformation/qonnx/qonnx_activation_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
import warnings
from abc import ABC, abstractmethod
from onnx import TensorProto, helper
from qonnx.core.modelwrapper import ModelWrapper
Expand Down Expand Up @@ -70,7 +70,7 @@ def _check_compatibility(self):
@abstractmethod
def _calculate_act_bias(self):
"""Calculate the activation bias,
which is introduced as an Add node behind the MultiTrheshold node.
which is introduced as an Add node behind the MultiThreshold node.
"""
raise NotImplementedError()

Expand All @@ -82,7 +82,7 @@ def _calculate_thresholds(self):
@abstractmethod
def _calculate_act_scale(self):
"""Calculate the activation scale,
which is indroduced as a Mul node behind the Add node
which is introduced as a Mul node behind the Add node
for the activation bias.
"""
raise NotImplementedError()
Expand Down Expand Up @@ -157,7 +157,7 @@ def replace_quant_node(self):
# Set scale and bias
# If these values are scalar then they can be set as attributes
# of the MultiThreshold node, if not they get inserted as adder and mul nodes
# behind the MultiTrheshold nodes.
# behind the MultiThreshold nodes.
bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
if scale_scalar and bias_scalar and self._q_node.op_type == "BipolarQuant":
Expand Down Expand Up @@ -355,7 +355,7 @@ def _calculate_thresholds(self):
act_node = self._model.find_direct_predecessors(self._q_node)
act_node = act_node[0]
if act_node.op_type == "Relu":
# Calculate thersholds, see: https://github.com/Xilinx/brevitas/blob/
# Calculate thresholds, see: https://github.com/Xilinx/brevitas/blob/
# a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/
# onnx/finn/handler/act.py#L21
num_distinct_values = 2**bit_width
Expand Down Expand Up @@ -395,8 +395,27 @@ def _calculate_thresholds(self):
else:
thresholds[c][t] = step / selu_scale

# First try to consider the tensor layout of the output for determining
# the number of output channels
layout = self._model.get_tensor_layout(self._q_node.output[0])
# If there is a layout annotation, use this to determine the index of
# the channel dimension
if layout is not None and "C" in layout:
# Lookup the index in list
cdim = layout.index("C")
# If no layout has been annotated or there is no channel dimension, fall
# back to the previous default assumption
else:
# Assume the channels to be in axis 1
cdim = 1
# Issue a warning to the user, so they are aware of this
warnings.warn(
f"No layout annotations for {self._q_node.output[0]}:"
f" Assuming channel dimension at index {cdim}"
)

# ToDo: The index 1 needs to be changed to -1 for the channels last format
num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim]
final_shape = (num_output_channels, num_thresholds)
if thresholds.shape != final_shape:
thresholds = np.broadcast_to(thresholds, final_shape)
Expand All @@ -417,12 +436,12 @@ def _remove_activation_node(self, multi_threshold_node):
act_node = self._model.find_direct_predecessors(self._q_node)
if act_node is None:
raise RuntimeError(
"For handling of Relu activations a predecesor to " "the Quant node must exist."
"For handling of Relu activations a predecessor to " "the Quant node must exist."
)
act_node = act_node[0]
if act_node.op_type not in self.valid_predecessor_op_types():
raise RuntimeError(
"The predecesor of the Quant node must be Relu or Selu for handling "
"The predecessor of the Quant node must be Relu or Selu for handling "
"of activations."
)

Expand Down Expand Up @@ -509,7 +528,7 @@ def _calculate_thresholds(self):
else:
raise RuntimeError("Got an unexpected quantizer node type")

# Calculate thersholds, see: https://github.com/Xilinx/brevitas/
# Calculate thresholds, see: https://github.com/Xilinx/brevitas/
# blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
# export/onnx/finn/handler/act.py#L76
if bit_width == 1.0:
Expand Down Expand Up @@ -537,8 +556,28 @@ def _calculate_thresholds(self):
for t in range(num_thresholds):
thresholds[c][t] = min_threshold[c] + step[c] * t

# First try to consider the tensor layout of the output for
# determining the number of output channels
layout = self._model.get_tensor_layout(self._q_node.output[0])
# If there is a layout annotation, use this to determine the index
# of the channel dimension
if layout is not None and "C" in layout:
# Lookup the index in list
cdim = layout.index("C")
# If no layout has been annotated or there is no channel dimension,
# fall back to the previous default assumption
else:
# Assume the channels to be in axis 1
cdim = 1
# Issue a warning to the user, so they are aware of this
warnings.warn(
f"No layout annotations for {self._q_node.output[0]}:"
f" Assuming channel dimension at index {cdim}"
)

# ToDo: The index 1 needs to be changed to -1 for the channels last format
num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim]

final_shape = (num_output_channels, num_thresholds)
if thresholds.shape != final_shape:
thresholds = np.broadcast_to(thresholds, final_shape)
Expand Down
2 changes: 1 addition & 1 deletion src/finn/transformation/streamline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def apply(self, model):
BatchNormToAffine(),
ConvertSignToThres(),
MoveMulPastMaxPool(),
MoveScalarLinearPastInvariants(),
AbsorbSignBiasIntoMultiThreshold(),
MoveScalarLinearPastInvariants(),
MoveAddPastMul(),
MoveScalarAddPastMatMul(),
MoveAddPastConv(),
Expand Down
85 changes: 68 additions & 17 deletions src/finn/transformation/streamline/absorb.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
import qonnx.core.data_layout as DataLayout
import warnings
from onnx import helper as oh
# Protobuf onnx graph node type
from onnx import NodeProto # noqa
# QONNX wrapper of ONNX model graphs
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.core.datatype import DataType
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.base import Transformation
Expand Down Expand Up @@ -100,6 +104,23 @@ def apply(self, model):
return (model, graph_modified)


# Groups inputs by categories, i.e., groups dynamic inputs first, followed by
# initializers. Keeps order of inputs in each category.
def group_inputs_by_category(node: NodeProto, model: ModelWrapper): # noqa
# First select all dynamic inputs, which are those without initializer
# tensor
dynamics = [
i for i in node.input if model.get_initializer(i) is None
]
# Select all input which are initializers, which, by exclusion, are all
# those not among the dynamic inputs
initializers = [
i for i in node.input if i not in dynamics
]
# Return lists of dynamic anc initializer inputs
return dynamics, initializers


class AbsorbAddIntoMultiThreshold(Transformation):
"""Absorb preceding Add ops into MultiThreshold by updating the threshold
values. Only scalar/1D add vectors can be absorbed."""
Expand All @@ -113,13 +134,19 @@ def apply(self, model):
if n.op_type == "Add" and not model.is_fork_node(n) and not model.is_join_node(n):
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MultiThreshold":
add_weight_name = n.input[1]
threshold_name = consumer.input[1]
A = model.get_initializer(add_weight_name)
T = model.get_initializer(threshold_name)
assert A is not None, "Initializer for add weights is not set."
# As Add is not a join node, there must be one initializer
# and one dynamic input. We do not know their order, but
# can group them accordingly to extract the tensor names
(start,), (add_weight, ) = group_inputs_by_category(
n, model
)
threshold = consumer.input[1]
A = model.get_initializer(add_weight)
T = model.get_initializer(threshold)
# Test for the thresholds actually being initializers
# Note: No need to validate the add_weights anymore, this
# is already handled by the grouping and is_join_node test.
assert T is not None, "Initializer for thresholds is not set."
start_name = n.input[0]
# we can only absorb 0d or 1d adds
is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape)
actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
Expand All @@ -128,13 +155,13 @@ def apply(self, model):
Tnew = T - A.reshape(-1, 1)
# Tnew = T - A.reshape(-1, T.shape[1])
# compute new thresholds and set initializer
model.set_initializer(threshold_name, Tnew)
model.set_initializer(threshold, Tnew)
# wire add input directly to MultiThreshold
consumer.input[0] = start_name
consumer.input[0] = start
# remove the add node
graph.node.remove(n)
graph_modified = True
return (model, graph_modified)
return model, graph_modified


class AbsorbMulIntoMultiThreshold(Transformation):
Expand Down Expand Up @@ -215,7 +242,7 @@ def apply(self, model):


class Absorb1BitMulIntoMatMul(Transformation):
"""Absorb bipolar or binary multiplications into the preciding matrix
"""Absorb bipolar or binary multiplications into the preceding matrix
multiply."""

def apply(self, model):
Expand All @@ -224,16 +251,28 @@ def apply(self, model):
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "MatMul":
# Note: Join-node test is implicitly covered by testing for the
# initializer below
# Note: This cannot handle fork-nodes, as only the first consumer is
# considered below.
# TODO: Fork-nodes could be handled if the muls are the same in all
# branches, but this is not checked nor rewired at all right now.
if n.op_type == "MatMul" and not model.is_fork_node(n):
matmul_weight_name = n.input[1]
W = model.get_initializer(matmul_weight_name)
Wdt = model.get_tensor_datatype(matmul_weight_name)
assert W is not None, "Initializer for matmul weights is not set."
# Just skip matmuls with non-existing weight initializers
if W is None:
continue
consumer = model.find_consumer(n.output[0])
# Note: Join-node test is implicitly covered by testing for the
# initializer below
if consumer is not None and consumer.op_type == "Mul":
mul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
assert A is not None, "Initializer for mul weights is not set."
# Just skip muls with non-existing scale initializers
if A is None:
continue
is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
if is_1bit:
Wnew = A * W
Expand All @@ -252,24 +291,36 @@ def apply(self, model):


class Absorb1BitMulIntoConv(Transformation):
"""Absorb bipolar or binary multiplications into the preciding convolution."""
"""Absorb bipolar or binary multiplications into the preceding convolution."""

def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Conv":
# Note: Join-node test is implicitly covered by testing for the
# initializer below
# Note: This cannot handle fork-nodes, as only the first consumer is
# considered below.
# TODO: Fork-nodes could be handled if the muls are the same in all
# branches, but this is not checked nor rewired at all right now.
if n.op_type == "Conv" and not model.is_fork_node(n):
conv_weight_name = n.input[1]
W = model.get_initializer(conv_weight_name)
Wdt = model.get_tensor_datatype(conv_weight_name)
assert W is not None, "Initializer for conv weights is not set."
# Just skip convs with non-existing weight initializers
if W is None:
continue
consumer = model.find_consumer(n.output[0])
# Note: Join-node test is implicitly covered by testing for the
# initializer below
if consumer is not None and consumer.op_type == "Mul":
mul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
assert A is not None, "Initializer for mul weights is not set."
# Just skip muls with non-existing scale initializers
if A is None:
continue
is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
is_scalar = np.prod(A.shape) == 1
actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
Expand Down
Loading