diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index e3e2468bba..72f49d3ac0 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -224,16 +224,22 @@ def apply(self, model):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
+            # TODO: Maybe test for join-node here and reject?
             if n.op_type == "MatMul":
                 matmul_weight_name = n.input[1]
                 W = model.get_initializer(matmul_weight_name)
                 Wdt = model.get_tensor_datatype(matmul_weight_name)
-                assert W is not None, "Initializer for matmul weights is not set."
+                # Just skip matmuls with non-existing weight initializers
+                if W is None:
+                    continue
                 consumer = model.find_consumer(n.output[0])
+                # TODO: Maybe test for join-node here and reject?
                 if consumer is not None and consumer.op_type == "Mul":
                     mul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+                    # Just skip muls with non-existing scale initializers
+                    if A is None:
+                        continue
                     is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
                     if is_1bit:
                         Wnew = A * W
@@ -260,16 +266,22 @@ def apply(self, model):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
+            # TODO: Maybe test for join-node here and reject?
             if n.op_type == "Conv":
                 conv_weight_name = n.input[1]
                 W = model.get_initializer(conv_weight_name)
                 Wdt = model.get_tensor_datatype(conv_weight_name)
-                assert W is not None, "Initializer for conv weights is not set."
+                # Just skip convs with non-existing weight initializers
+                if W is None:
+                    continue
                 consumer = model.find_consumer(n.output[0])
+                # TODO: Maybe test for join-node here and reject?
                 if consumer is not None and consumer.op_type == "Mul":
                     mul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+                    # Just skip muls with non-existing scale initializers
+                    if A is None:
+                        continue
                     is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
                     is_scalar = np.prod(A.shape) == 1
                     actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
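
For context, a minimal sketch of the case these new guards handle (illustrative only, not part of the patch): a MatMul whose weight input is produced at runtime by another node, rather than stored as an initializer, makes get_initializer return None, and the old asserts aborted the whole transformation. The import paths assume a current FINN/QONNX setup, and the graph and tensor names below are made up for the example.

# Illustrative sketch: a "dynamic" MatMul (both inputs are runtime tensors)
# followed by a Mul with a scale initializer. Before this patch,
# Absorb1BitMulIntoMatMul hit the assert on the missing weight initializer;
# with the patch it leaves the node pair untouched.
import numpy as np
from onnx import TensorProto, helper
from qonnx.core.modelwrapper import ModelWrapper  # assumed import path

from finn.transformation.streamline.absorb import Absorb1BitMulIntoMatMul

a = helper.make_tensor_value_info("a", TensorProto.FLOAT, [1, 4])
b = helper.make_tensor_value_info("b", TensorProto.FLOAT, [4, 4])
out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 4])
nodes = [
    helper.make_node("MatMul", ["a", "b"], ["mm_out"]),
    helper.make_node("Mul", ["mm_out", "scale"], ["out"]),
]
graph = helper.make_graph(nodes, "dyn_matmul", [a, b], [out])
model = ModelWrapper(helper.make_model(graph))
model.set_initializer("scale", np.ones((1, 4), dtype=np.float32))

# "b" is a graph input, not an initializer, so get_initializer returns None.
assert model.get_initializer("b") is None
# Previously: AssertionError ("Initializer for matmul weights is not set.").
# Now: the MatMul is skipped and the model comes back unchanged.
model = model.transform(Absorb1BitMulIntoMatMul())

The same reasoning applies to the Conv hunk: only Conv nodes with a stored weight initializer (and a following Mul with a stored scale) are candidates for absorption, so anything else is now skipped instead of failing the whole pass.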