Merged
22 commits
fdd89a6
[Streamline] Prefer AbsorbSignBiasIntoMultiThreshold transform
iksnagreb Sep 30, 2023
be33bbc
[Streamline] Refactor MoveScalarMulPastMatMul to handle join-node matmul
iksnagreb Sep 30, 2023
09c1993
Remove misplaced/outdated comment
iksnagreb Sep 30, 2023
9dade0c
[Streamline] Soften initializer tests in Absorb1BitMulIntoMatMul/Conv
iksnagreb Sep 30, 2023
8bae5d7
Address some linting issues
iksnagreb Oct 19, 2023
b22ebe3
[Tests] Add test for MoveScalarMulPastMatMul handling join nodes
iksnagreb Oct 19, 2023
c10fa1d
[Deps] Update qonnx version to include FoldTransposeIntoQuantInit fix
iksnagreb Oct 27, 2023
475a27b
[Streamline] Fix FoldQuantWeights input order and shape annotations
iksnagreb Nov 13, 2023
bd6a8f8
[Streamline] Fix AbsorbAddIntoMultiThreshold assumed input order
iksnagreb Nov 13, 2023
1f7dd4c
[Streamline] Add support for Slice to MoveScalarLinearPastInvariants
iksnagreb Nov 15, 2023
b3e50d7
[Streamline] Absorb1BitMulIntoMatMul/Conv does not handle fork-nodes
iksnagreb Nov 17, 2023
0413368
[Deps] Temporarily switch qonnx to my fork including necessary fixes
iksnagreb Nov 17, 2023
2bf7949
Make quantized activation handlers data layout aware
iksnagreb Nov 20, 2023
8783fd4
[Deps] Update qonnx
iksnagreb Nov 20, 2023
2bf37f1
[Deps] Update qonnx
iksnagreb Dec 13, 2023
a4fc498
[Deps] Update qonnx
iksnagreb Mar 13, 2024
6c56382
Fix some typos
iksnagreb Apr 4, 2024
15a9daa
Merge remote-tracking branch 'xilinx/dev' into feature/attention-stre…
iksnagreb Jan 20, 2025
fb6fe31
Merge branch 'feature/split-concat' into feature/streamline-plus
iksnagreb Feb 4, 2025
83174a4
[Streamline] Introduce StreamlinePlus: Exhaustive streamlining
iksnagreb Feb 4, 2025
f7e1178
Merge remote-tracking branch 'eki-project/dev' into feature/streamlin…
iksnagreb Feb 6, 2025
2ffb095
Merge branch 'dev' into feature/streamline-plus
iksnagreb Feb 6, 2025
2 changes: 2 additions & 0 deletions setup.cfg
@@ -164,6 +164,8 @@ exclude =
dist
.eggs
docs/conf.py
per-file-ignores =
src/finn/transformation/streamline/streamline_plus.py: F405, F403

[pyscaffold]
# PyScaffold's parameters when the project was created.
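For context: F403 ("star import used") and F405 ("name may be undefined, or defined from star imports") are flake8's wildcard-import codes, so this carve-out suggests streamline_plus.py collects the individual streamlining transforms via star imports. A hypothetical sketch of the pattern these ignores permit (the actual imports in streamline_plus.py may differ):

# Pull all streamlining transforms into one namespace; this trips F403
# at the import and F405 wherever a star-imported name is used.
from finn.transformation.streamline.absorb import *
from finn.transformation.streamline.reorder import *

transforms = [AbsorbSignBiasIntoMultiThreshold(), MoveAddPastMul()]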
2 changes: 1 addition & 1 deletion src/finn/transformation/streamline/__init__.py
@@ -76,8 +76,8 @@ def apply(self, model):
BatchNormToAffine(),
ConvertSignToThres(),
MoveMulPastMaxPool(),
MoveScalarLinearPastInvariants(),
AbsorbSignBiasIntoMultiThreshold(),
MoveScalarLinearPastInvariants(),
MoveAddPastMul(),
MoveScalarAddPastMatMul(),
MoveAddPastConv(),
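This hunk only swaps the order of two entries in Streamline's transform list, so AbsorbSignBiasIntoMultiThreshold now runs before MoveScalarLinearPastInvariants; invoking the pipeline is unchanged. A minimal usage sketch (the model path is a placeholder):

from qonnx.core.modelwrapper import ModelWrapper
from finn.transformation.streamline import Streamline

# Load a model and run the full streamlining pipeline, which applies
# the transform list above in order.
model = ModelWrapper("model.onnx")  # placeholder path
model = model.transform(Streamline())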
220 changes: 160 additions & 60 deletions src/finn/transformation/streamline/absorb.py
@@ -29,81 +29,157 @@
import numpy as np
import qonnx.core.data_layout as DataLayout
import warnings

# Protobuf onnx graph node type
from onnx import NodeProto # noqa
from onnx import helper as oh
from qonnx.core.datatype import DataType

# QONNX wrapper of ONNX model graphs
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.base import Transformation
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.util.basic import get_by_name

# Protobuf onnx graph node type
from onnx import NodeProto # noqa
from finn.transformation.util import group_inputs_by_category


# Note: Old name kept for compatibility reasons, but this actually allows
# absorbing any bias irrespective of signedness, which might result in a
# changed signedness of the output type
class AbsorbSignBiasIntoMultiThreshold(Transformation):
"""Absorb scalar bias originating from signed int export back into
MultiThreshold and re-evaluate the output datatype."""

def apply(self, model):
def apply(self, model: ModelWrapper):
# Get the model graph out of the model wrapper object
graph = model.graph
node_ind = 0
# Keep track of whether the graph has been modified
graph_modified = False
for n in graph.node:
# search for (MultiThreshold, Add) pair
node_ind += 1
# Iterate all nodes in the graph keeping track of the index
for index, node in enumerate(graph.node):
# Only non-branching threshold operations are supported
if (
n.op_type == "MultiThreshold"
and not model.is_fork_node(n)
and not model.is_join_node(n)
node.op_type == "MultiThreshold"
and not model.is_fork_node(node)
and not model.is_join_node(node)
):
consumer = model.find_consumer(n.output[0])
# We know we are not forking, so there is at most one consumer
consumer = model.find_consumer(node.output[0])
# At the end of the graph we might have no consumer. If we have
# one, only handle Adds, turn Sub into Add first...
if consumer is not None and consumer.op_type == "Add":
mt_node = n
add_node = consumer
threshold_name = mt_node.input[1]
add_weight_name = add_node.input[1]
T = model.get_initializer(threshold_name)
A = model.get_initializer(add_weight_name)
if (A is None) or (T is None):
warnings.warn("Threshold or add bias not constant, skipping")
# Try to get the parameter tensor for the addition: Sanity
# check whether this is present, even though we already
# tested for non-joining
bias = model.get_initializer(consumer.input[1])

# Warn and skip if there is no constant bias present
if bias is None:
warnings.warn(
f"{self.__class__.__name__}: Bias not constant for"
f" {consumer.name}, skipping."
)
# Skip to next node, nothing changed so far, no need to
# break here
continue
end_name = add_node.output[0]
# we can only absorb scalar adds
is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape)
if not is_scalar:

# Try to get the parameter tensor for the thresholds: Sanity
# check whether this is present, even though we already
# tested for non-joining
thresholds = model.get_initializer(node.input[1])

# Warn and skip if there are no constant thresholds present
if thresholds is None:
warnings.warn(
f"{self.__class__.__name__}: Thresholds not"
f" constant for {node.name}, skipping."
)
# Skip to next node, nothing changed so far, no need to
# break here
continue

# Check whether the bias is scalar, as we cannot absorb
# full tensors into node attributes
if not (bias.ndim == 0 or all(x == 1 for x in bias.shape)):
warnings.warn(
f"{self.__class__.__name__}: Bias not scalar"
f" for {consumer.name}, skipping."
)
# Skip to next node, nothing changed so far, no need to
# break here
continue
bias = A.flatten()[0]
# set MultiThreshold bias property
mt_inst = getCustomOp(mt_node)
bias += mt_inst.get_nodeattr("out_bias")
mt_inst.set_nodeattr("out_bias", bias)

# Flatten the effectively scalar bias tensor and extract a
# plain scalar value
bias = bias.flatten()[0]
# CustomOp instance of the thresholding node required for
# convenient attribute manipulation
threshold_op = getCustomOp(node)
# Shift the output bias of the thresholding operator
out_bias = threshold_op.get_nodeattr("out_bias") + bias
# Derive the new output range due to shifting the bias
# Note: We count threshold steps on top of the bias
new_min = out_bias
new_max = out_bias + thresholds.shape[-1]

# Allows the signedness to change depending on the new
# output range [new_min,new_max]
if abs(new_min) > abs(new_max):
odt = DataType.get_smallest_possible(new_min)
else:
odt = DataType.get_smallest_possible(new_max)

# Check whether the new range can be represented with the
# derived integer datatype
if not (odt.allowed(new_max) and odt.allowed(new_min)):
# Cannot be represented, warn and skip transforming
warnings.warn(
f"{self.__class__.__name__}: Cannot absorb bias"
f" from {consumer.name} into {node.name}: {bias}"
)
# Skip to the next candidate node
continue

# Remember the old datatype for some further checks and info
old_odt = threshold_op.get_nodeattr("out_dtype")

# Check whether the datatype changes as this is something
# the "user" should be aware of
if odt.name != old_odt:
warnings.warn(
f"{self.__class__.__name__}: Output datatype for"
f" {node.name} changing from {old_odt} to {odt}"
)

# Up until now we did not modify the nodes/graph, just did
# some checks and derived the new bias and datatype. Start
# inserting this back into the graph now...

# Set new bias and datatype attributes into the threshold
# operator
threshold_op.set_nodeattr("out_bias", out_bias)
threshold_op.set_nodeattr("out_dtype", odt.name)
# Remove the bias operator and rewire the graph to skip the
# now-missing node
node.output[0] = consumer.output[0]
graph.node.remove(consumer)
# Update the datatype at the output of the threshold
# operation
model.set_tensor_datatype(node.output[0], odt)

# Graph modified so we need to apply this transformation
# again
graph_modified = True
# compute new DataType for MultiThreshold output
steps = T.shape[-1]
new_min = bias
new_max = steps + bias
odt = DataType.get_smallest_possible(steps).name.replace("UINT", "INT")
odt = DataType[odt]
assert odt.allowed(new_max) and odt.allowed(
new_min
), """Could
not compute new MultiThreshold DataType (min = %d max = %d)""" % (
new_min,
new_max,
)
mt_inst.set_nodeattr("out_dtype", odt.name)
# remove Add node, rewire MultiThreshold
graph.node.remove(add_node)
mt_node.output[0] = end_name
# set datatype
model.set_tensor_datatype(end_name, odt)
if graph_modified:
model = model.transform(InferDataTypes())
return (model, graph_modified)
# Better break now to clean up and recover annotations first
break
# As we might have changed types and removed nodes, better redo the
# annotations
model = model.transform(InferDataTypes())
model = model.transform(InferShapes())
# Transformed model and indication whether the transformation should be
# applied again
return model, graph_modified
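To make the range arithmetic above concrete, here is a small self-contained sketch of the same derivation with made-up numbers (255 threshold steps, an out_bias of -127, an absorbed scalar bias of -1):

from qonnx.core.datatype import DataType

out_bias = -127 + (-1)  # old out_bias plus the absorbed Add bias (hypothetical)
num_steps = 255         # thresholds.shape[-1]
new_min = out_bias              # no threshold steps crossed
new_max = out_bias + num_steps  # all threshold steps crossed
# Pick the type from the bound with the larger magnitude, then check both fit
odt = DataType.get_smallest_possible(new_min if abs(new_min) > abs(new_max) else new_max)
assert odt.allowed(new_min) and odt.allowed(new_max)
print(odt.name)  # INT8, covering [-128, 127]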


# Groups inputs by categories, i.e., groups dynamic inputs first, followed by
@@ -261,7 +337,7 @@ def apply(self, model):


class Absorb1BitMulIntoMatMul(Transformation):
"""Absorb bipolar or binary multiplications into the preciding matrix
"""Absorb bipolar or binary multiplications into the preceding matrix
multiply."""

def apply(self, model):
@@ -270,16 +346,28 @@ def apply(self, model):
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "MatMul":
# Note: Join-node test is implicitly covered by testing for the
# initializer below
# Note: This cannot handle fork-nodes, as only the first consumer is
# considered below.
# TODO: Fork-nodes could be handled if the muls are the same in all
# branches, but this is neither checked nor rewired right now.
if n.op_type == "MatMul" and not model.is_fork_node(n):
matmul_weight_name = n.input[1]
W = model.get_initializer(matmul_weight_name)
Wdt = model.get_tensor_datatype(matmul_weight_name)
assert W is not None, "Initializer for matmul weights is not set."
# Just skip matmuls with non-existing weight initializers
if W is None:
continue
consumer = model.find_consumer(n.output[0])
# Note: Join-node test is implicitly covered by testing for the
# initializer below
if consumer is not None and consumer.op_type == "Mul":
mul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
assert A is not None, "Initializer for mul weights is not set."
# Just skip muls with non-existing scale initializers
if A is None:
continue
is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
if is_1bit:
Wnew = A * W
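The folding is exact for 1-bit scales: a bipolar (+/-1) Mul after a MatMul only flips the sign of output columns, so it can be pushed into the weight initializer without precision loss. A minimal numpy sketch with made-up values:

import numpy as np

x = np.array([[1.0, 2.0]])  # dynamic input
W = np.array([[2.0, -3.0],
              [4.0, 5.0]])  # MatMul weight initializer
A = np.array([-1.0, 1.0])   # bipolar Mul scale, one factor per output column

# (x @ W) * A == x @ (A * W): scaling output column j is the same as
# scaling column j of the weights, which is exactly Wnew = A * W above.
assert np.allclose((x @ W) * A, x @ (A * W))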
@@ -298,24 +386,36 @@ def apply(self, model):


class Absorb1BitMulIntoConv(Transformation):
"""Absorb bipolar or binary multiplications into the preciding convolution."""
"""Absorb bipolar or binary multiplications into the preceding convolution."""

def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Conv":
# Note: Join-node test is implicitly covered by testing for the
# initializer below
# Note: This cannot handle fork-nodes, as only the first consumer is
# considered below.
# TODO: Fork-nodes could be handled if the muls are the same in all
# branches, but this is neither checked nor rewired right now.
if n.op_type == "Conv" and not model.is_fork_node(n):
conv_weight_name = n.input[1]
W = model.get_initializer(conv_weight_name)
Wdt = model.get_tensor_datatype(conv_weight_name)
assert W is not None, "Initializer for conv weights is not set."
# Just skip convs with non-existing weight initializers
if W is None:
continue
consumer = model.find_consumer(n.output[0])
# Note: Join-node test is implicitly covered by testing for the
# initializer below
if consumer is not None and consumer.op_type == "Mul":
mul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
assert A is not None, "Initializer for mul weights is not set."
# Just skip muls with non-existing scale initializers
if A is None:
continue
is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
is_scalar = np.prod(A.shape) == 1
actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
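The convolution variant follows the same idea per output channel. A sketch of just the weight update, with made-up shapes (assuming OIHW conv weights and an NCHW per-channel scale):

import numpy as np

W = np.random.randn(8, 4, 3, 3).astype(np.float32)  # OIHW conv weights
A = np.where(np.random.rand(1, 8, 1, 1) < 0.5, -1.0, 1.0)  # bipolar scale (NCHW)

# Scaling conv outputs channel-wise equals scaling each output channel's
# filter: conv(x, Wnew) == conv(x, W) * A, so the Mul folds into W.
Wnew = W * A.reshape(-1, 1, 1, 1)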