[Streamline] Soften initializer tests in Absorb1BitMulIntoMatMul/Conv
Assertions are too restrictive, causing the program to terminate in cases
where the streamlining simply encounters nodes to which the transforms are
not applicable: just skip those nodes.

Only the two transforms currently affecting the streamlining of scaled
dot-product attention have been changed.
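
For context, a minimal sketch of the case that previously tripped the
assertion: in scaled dot-product attention the query-key MatMul multiplies
two dynamic tensors, so its second input carries no initializer. The toy
graph below (names, shapes, and the 0.125 scale) is hypothetical; the
ModelWrapper and transform calls are the real QONNX/FINN APIs.

    from onnx import TensorProto, helper
    from qonnx.core.modelwrapper import ModelWrapper

    from finn.transformation.streamline.absorb import Absorb1BitMulIntoMatMul

    # Both MatMul inputs are dynamic graph inputs, so n.input[1] ("k") has
    # no initializer; only the attention scale is a constant.
    q = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 4, 8])
    k = helper.make_tensor_value_info("k", TensorProto.FLOAT, [1, 8, 4])
    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 4, 4])
    graph = helper.make_graph(
        nodes=[
            helper.make_node("MatMul", ["q", "k"], ["qk"]),
            helper.make_node("Mul", ["qk", "scale"], ["out"]),
        ],
        name="sdp_attention_slice",
        inputs=[q, k],
        outputs=[out],
        initializer=[helper.make_tensor("scale", TensorProto.FLOAT, [1], [0.125])],
    )
    model = ModelWrapper(helper.make_model(graph))
    # Before this commit the line below terminated with "Initializer for
    # matmul weights is not set."; now the MatMul is simply skipped and the
    # model comes back unmodified.
    model = model.transform(Absorb1BitMulIntoMatMul())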
iksnagreb committed Sep 30, 2023
1 parent 1bad1a5 commit 3962fc0
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/finn/transformation/streamline/absorb.py
@@ -224,16 +224,22 @@ def apply(self, model):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
+            # TODO: Maybe test for join-node here and reject?
             if n.op_type == "MatMul":
                 matmul_weight_name = n.input[1]
                 W = model.get_initializer(matmul_weight_name)
                 Wdt = model.get_tensor_datatype(matmul_weight_name)
-                assert W is not None, "Initializer for matmul weights is not set."
+                # Just skip matmuls with non-existing weight initializers
+                if W is None:
+                    continue
                 consumer = model.find_consumer(n.output[0])
+                # TODO: Maybe test for join-node here and reject?
                 if consumer is not None and consumer.op_type == "Mul":
                     mul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+                    # Just skip muls with non-existing scale initializers
+                    if A is None:
+                        continue
                     is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
                     if is_1bit:
                         Wnew = A * W
@@ -260,16 +266,22 @@ def apply(self, model):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
+            # TODO: Maybe test for join-node here and reject?
             if n.op_type == "Conv":
                 conv_weight_name = n.input[1]
                 W = model.get_initializer(conv_weight_name)
                 Wdt = model.get_tensor_datatype(conv_weight_name)
-                assert W is not None, "Initializer for conv weights is not set."
+                # Just skip convs with non-existing weight initializers
+                if W is None:
+                    continue
                 consumer = model.find_consumer(n.output[0])
+                # TODO: Maybe test for join-node here and reject?
                 if consumer is not None and consumer.op_type == "Mul":
                     mul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+                    # Just skip muls with non-existing scale initializers
+                    if A is None:
+                        continue
                     is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
                     is_scalar = np.prod(A.shape) == 1
                     actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
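
Continuing the sketch above, the two softened transforms can be chained in
the usual way; the back-to-back composition is illustrative, while the class
names are the real ones from this file.

    from finn.transformation.streamline.absorb import (
        Absorb1BitMulIntoConv,
        Absorb1BitMulIntoMatMul,
    )

    # Each pass now skips MatMul/Conv nodes (or trailing Muls) that lack
    # initializers instead of terminating with an AssertionError.
    model = model.transform(Absorb1BitMulIntoMatMul())
    model = model.transform(Absorb1BitMulIntoConv())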
