diff --git a/setup.cfg b/setup.cfg index 4834011dea..0a0b9eff1d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -164,6 +164,8 @@ exclude = dist .eggs docs/conf.py +per-file-ignores = + src/finn/transformation/streamline/streamline_plus.py: F405, F403 [pyscaffold] # PyScaffold's parameters when the project was created. diff --git a/src/finn/transformation/streamline/__init__.py b/src/finn/transformation/streamline/__init__.py index 2e68de698b..39ef87f81c 100644 --- a/src/finn/transformation/streamline/__init__.py +++ b/src/finn/transformation/streamline/__init__.py @@ -76,8 +76,8 @@ def apply(self, model): BatchNormToAffine(), ConvertSignToThres(), MoveMulPastMaxPool(), - MoveScalarLinearPastInvariants(), AbsorbSignBiasIntoMultiThreshold(), + MoveScalarLinearPastInvariants(), MoveAddPastMul(), MoveScalarAddPastMatMul(), MoveAddPastConv(), diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index 4c280d8f28..55ac575580 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -29,10 +29,11 @@ import numpy as np import qonnx.core.data_layout as DataLayout import warnings + +# Protobuf onnx graph node type +from onnx import NodeProto # noqa from onnx import helper as oh from qonnx.core.datatype import DataType - -# QONNX wrapper of ONNX model graphs from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.registry import getCustomOp from qonnx.transformation.base import Transformation @@ -40,70 +41,145 @@ from qonnx.transformation.infer_shapes import InferShapes from qonnx.util.basic import get_by_name -# Protobuf onnx graph node type -from onnx import NodeProto # noqa +from finn.transformation.util import group_inputs_by_category +# Note: Old name kept for compatibility reasons but actually allows to absorb +# any bias irrespective of signedness which might result in changed signedness +# of the output type class AbsorbSignBiasIntoMultiThreshold(Transformation): """Absorb scalar bias originating from signed int export back into MultiThreshold and re-evaluate the output datatype.""" - def apply(self, model): + def apply(self, model: ModelWrapper): + # Get the model graph out of the model wrapper object graph = model.graph - node_ind = 0 + # Keep track of whether the graph has been modified graph_modified = False - for n in graph.node: - # search for (MultiThreshold, Add) pair - node_ind += 1 + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Only non-branching threshold operations are supported if ( - n.op_type == "MultiThreshold" - and not model.is_fork_node(n) - and not model.is_join_node(n) + node.op_type == "MultiThreshold" + and not model.is_fork_node(node) + and not model.is_join_node(node) ): - consumer = model.find_consumer(n.output[0]) + # We now we are not forking, so there is at most one consumer + consumer = model.find_consumer(node.output[0]) + # At the end of the graph we might have no consumer. If we have + # one, only handle Adds, turn Sub into Add first... 
if consumer is not None and consumer.op_type == "Add": - mt_node = n - add_node = consumer - threshold_name = mt_node.input[1] - add_weight_name = add_node.input[1] - T = model.get_initializer(threshold_name) - A = model.get_initializer(add_weight_name) - if (A is None) or (T is None): - warnings.warn("Threshold or add bias not constant, skipping") + # Try to get the parameter tensor for the addition: Sanity + # check whether this is present, even though we already + # tested for non-joining + bias = model.get_initializer(consumer.input[1]) + + # Warn and skip if there is no constant bias present + if bias is None: + warnings.warn( + f"{self.__class__.__name__}: Bias not constant for" + f" {consumer.name}, skipping." + ) + # Skip to next node, nothing changed so far, no need to + # break here continue - end_name = add_node.output[0] - # we can only absorb scalar adds - is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape) - if not is_scalar: + + # Try to get the parameter tensor for the thresholds: Sanity + # check whether this is present, even though we already + # tested for non-joining + thresholds = model.get_initializer(node.input[1]) + + # Warn and skip if there is no constant bias present + if thresholds is None: + warnings.warn( + f"{self.__class__.__name__}: Thresholds not" + f" constant for {node.name}, skipping." + ) + # Skip to next node, nothing changed so far, no need to + # break here + continue + + # Check whether the bias is as scalar as we cannot absorb + # full tensors into node attributes + if not (bias.ndim == 0 or all(x == 1 for x in bias.shape)): + warnings.warn( + f"{self.__class__.__name__}: Bias not scalar" + f" for {consumer.name}, skipping." + ) + # Skip to next node, nothing changed so far, no need to + # break here continue - bias = A.flatten()[0] - # set MultiThreshold bias property - mt_inst = getCustomOp(mt_node) - bias += mt_inst.get_nodeattr("out_bias") - mt_inst.set_nodeattr("out_bias", bias) + + # Flatten effectively scalar bias tensors and extract to + # have "plain" scalar + bias = bias.flatten()[0] + # CustomOp instance of the thresholding node required for + # convenient attribute manipulation + threshold_op = getCustomOp(node) + # Shift the output bias of the thresholding operator + out_bias = threshold_op.get_nodeattr("out_bias") + bias + # Derive the new output range due to shifting the bias + # Note: We count thresholds steps on top of the bias + new_min = out_bias + new_max = out_bias + thresholds.shape[-1] + + # Allows the signedness to change depending on the new + # output range [new_min,new_max] + if abs(new_min) > abs(new_max): + odt = DataType.get_smallest_possible(new_min) + else: + odt = DataType.get_smallest_possible(new_max) + + # Check whether the new range can be represented with the + # derived integer datatype + if not (odt.allowed(new_max) and odt.allowed(new_min)): + # Cannot be represented, warn and skip transforming + warnings.warn( + f"{self.__class__.__name__}: Cannot absorb bias" + f" from {consumer.name} into {node.name}: {bias}" + ) + # Skip to the next candidate node + continue + + # Remember the old datatype for some further checks and info + old_odt = threshold_op.get_nodeattr("out_dtype") + + # Check whether the datatype changes as this is something + # the "user" should be aware of + if odt.name != old_odt: + warnings.warn( + f"{self.__class__.__name__}: Output datatype for" + f" {node.name} changing from {old_odt} to {odt}" + ) + + # Up until now we did not modify the nodes/grap, just did + # some checks and 
derive the new bias and datatype. Start + # inserting this back into the graph now... + + # Set new bias and datatype attributes into the threshold + # operator + threshold_op.set_nodeattr("out_bias", out_bias) + threshold_op.set_nodeattr("out_dtype", odt.name) + # Remove the bias operator and rewire the graph to skip the + # now-missing node + node.output[0] = consumer.output[0] + graph.node.remove(consumer) + # Update the datatype at the output of the threshold + # operation + model.set_tensor_datatype(node.output[0], odt) + + # Graph modified so we need to apply this transformation + # again graph_modified = True - # compute new DataType for MultiThreshold output - steps = T.shape[-1] - new_min = bias - new_max = steps + bias - odt = DataType.get_smallest_possible(steps).name.replace("UINT", "INT") - odt = DataType[odt] - assert odt.allowed(new_max) and odt.allowed( - new_min - ), """Could - not compute new MultiThreshold DataType (min = %d max = %d)""" % ( - new_min, - new_max, - ) - mt_inst.set_nodeattr("out_dtype", odt.name) - # remove Add node, rewire MultiThreshold - graph.node.remove(add_node) - mt_node.output[0] = end_name - # set datatype - model.set_tensor_datatype(end_name, odt) - if graph_modified: - model = model.transform(InferDataTypes()) - return (model, graph_modified) + # Better break now to clean up and recover annotations first + break + # As we might have changes types and removed nodes better redo some + # annotations + model = model.transform(InferDataTypes()) + model = model.transform(InferShapes()) + # Transformed model and indication whether the transformation should be + # applied again + return model, graph_modified # Groups inputs by categories, i.e., groups dynamic inputs first, followed by @@ -261,7 +337,7 @@ def apply(self, model): class Absorb1BitMulIntoMatMul(Transformation): - """Absorb bipolar or binary multiplications into the preciding matrix + """Absorb bipolar or binary multiplications into the preceding matrix multiply.""" def apply(self, model): @@ -270,16 +346,28 @@ def apply(self, model): graph_modified = False for n in graph.node: node_ind += 1 - if n.op_type == "MatMul": + # Note: Join-node test is implicitly covered by testing for the + # initializer below + # Note: This cannot handle fork-nodes, as only the first consumer is + # considered below. + # TODO: Fork-nodes could be handled if the muls are the same in all + # branches, but this is not checked nor rewired at all right now. + if n.op_type == "MatMul" and not model.is_fork_node(n): matmul_weight_name = n.input[1] W = model.get_initializer(matmul_weight_name) Wdt = model.get_tensor_datatype(matmul_weight_name) - assert W is not None, "Initializer for matmul weights is not set." + # Just skip matmuls with non-existing weight initializers + if W is None: + continue consumer = model.find_consumer(n.output[0]) + # Note: Join-node test is implicitly covered by testing for the + # initializer below if consumer is not None and consumer.op_type == "Mul": mul_weight_name = consumer.input[1] A = model.get_initializer(mul_weight_name) - assert A is not None, "Initializer for mul weights is not set." 
+ # Just skip muls with non-existing scale initializers + if A is None: + continue is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1 if is_1bit: Wnew = A * W @@ -298,7 +386,7 @@ def apply(self, model): class Absorb1BitMulIntoConv(Transformation): - """Absorb bipolar or binary multiplications into the preciding convolution.""" + """Absorb bipolar or binary multiplications into the preceding convolution.""" def apply(self, model): graph = model.graph @@ -306,16 +394,28 @@ def apply(self, model): graph_modified = False for n in graph.node: node_ind += 1 - if n.op_type == "Conv": + # Note: Join-node test is implicitly covered by testing for the + # initializer below + # Note: This cannot handle fork-nodes, as only the first consumer is + # considered below. + # TODO: Fork-nodes could be handled if the muls are the same in all + # branches, but this is not checked nor rewired at all right now. + if n.op_type == "Conv" and not model.is_fork_node(n): conv_weight_name = n.input[1] W = model.get_initializer(conv_weight_name) Wdt = model.get_tensor_datatype(conv_weight_name) - assert W is not None, "Initializer for conv weights is not set." + # Just skip convs with non-existing weight initializers + if W is None: + continue consumer = model.find_consumer(n.output[0]) + # Note: Join-node test is implicitly covered by testing for the + # initializer below if consumer is not None and consumer.op_type == "Mul": mul_weight_name = consumer.input[1] A = model.get_initializer(mul_weight_name) - assert A is not None, "Initializer for mul weights is not set." + # Just skip muls with non-existing scale initializers + if A is None: + continue is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1 is_scalar = np.prod(A.shape) == 1 actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape))) diff --git a/src/finn/transformation/streamline/collapse_repeated.py b/src/finn/transformation/streamline/collapse_repeated.py index d297110186..db18aeed39 100644 --- a/src/finn/transformation/streamline/collapse_repeated.py +++ b/src/finn/transformation/streamline/collapse_repeated.py @@ -26,11 +26,24 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# Helper for creating ONNX nodes from onnx import helper as oh + +# QONNX arbitrary precision data types from qonnx.core.datatype import DataType + +# QONNX wrapper of ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper + +# QONNX graph transformation base class from qonnx.transformation.base import Transformation + +# QONNX graph transformations for inferring datatypes and shapes from qonnx.transformation.infer_shapes import InferShapes +# Gets items from protobuf by name +from qonnx.util.basic import get_by_name + class CollapseRepeatedOp(Transformation): """Collapse repeated consecutive operations with constant parameters into @@ -106,3 +119,94 @@ class CollapseRepeatedMul(CollapseRepeatedOp): def __init__(self): super().__init__("Mul", lambda x, y: y * x) + + +# Collapses repeated transpose operations into a single transpose operation +# having the same effect +class CollapseRepeatedTranspose(Transformation): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Applies to Transpose operation types + if node.op_type == "Transpose": + # Currently does not handle fork- or join-nodes + if model.is_fork_node(node) or model.is_join_node(node): + # Softly skip this node + continue + # As this is not a fork-node, there can be at most one successor + successor = model.find_direct_successors(node) + # If Transpose is the final operation in the graph, there might + # be no successor + if successor is None: + # Softly skip this node + continue + # Now there is exactly one successor which needs to be extracted + # from the list + successor = successor[0] + # Successor must be a Transpose to be collapsed + if successor.op_type != "Transpose": + # Softly skip this node + continue + # Get the (optional) permutation indices of the first transpose + # in case it is a multi-axis transpose + perm1 = get_by_name(node.attribute, "perm") + # Convert permutation indices to list of integers + perm1 = perm1.ints if perm1 is not None else None + + # Get the (optional) permutation indices of the second transpose + # in case it is a multi-axis transpose + perm2 = get_by_name(successor.attribute, "perm") + # Convert permutation indices to list of integers + perm2 = perm2.ints if perm2 is not None else None + + # Get the shape of the input tensor + shape = model.get_tensor_shape( + # fmt: off + node.input[0], fix_missing_init_shape=True + # fmt: on + ) + # List of dimension indices in order + dims = range(len(shape)) + + # Substitute the permutation indices by the reversed index list + # if they are not given: This is default behavior, see the docs: + # https://onnx.ai/onnx/operators/onnx__Transpose.html + perm1 = list(reversed(dims)) if perm1 is None else perm1 + perm2 = list(reversed(dims)) if perm2 is None else perm2 + + # Combined permutation permutes the first permutation of the + # dimensions according to the second permutation + perm = [perm1[i] for i in perm2] + + # Create a new Transpose operator replacing the other two + transpose = oh.make_node( + # Name of the operator type + "Transpose", + # Connect to the inputs to the first transpose + inputs=node.input, + # Connect to the outputs of the second transpose + outputs=successor.output, + # Insert the new permutation indices + 
perm=perm, + ) + # Insert the collapsed transpose operator + graph.node.insert(index + 2, transpose) + # Remove the two original transpose operators + graph.node.remove(node) + graph.node.remove(successor) + # Track whether the graph has been modified, never resets to + # False + graph_modified = True + # Break the loop after adding and removing nodes to start over + # with a clean index + break + # Need to redo the shape inference after potentially removing nodes + model = model.transform(InferShapes()) # noqa: Shadows model + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified diff --git a/src/finn/transformation/streamline/remove.py b/src/finn/transformation/streamline/remove.py new file mode 100644 index 0000000000..a392f9a4ef --- /dev/null +++ b/src/finn/transformation/streamline/remove.py @@ -0,0 +1,101 @@ +# QONNX wrapper of ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper + +# QONNX graph transformation base class +from qonnx.transformation.base import Transformation + +# QONNX graph transformations for inferring datatypes and shapes +from qonnx.transformation.infer_shapes import InferShapes + +# Reuse node removal and rewiring from qonnx +from qonnx.transformation.remove import remove_node_and_rewire + +# Gets items from protobuf by name +from qonnx.util.basic import get_by_name + + +# Removes identity reshape operations, i.e., Reshape where input shape is the +# same as the target shape +class RemoveIdentityReshape(Transformation): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Applies to Reshape operation types + if node.op_type == "Reshape": + # Currently does not handle join-nodes + if model.is_join_node(node): + # Softly skip this node + continue + # Second input to the reshape operation is the target shape + shape = model.get_initializer(node.input[1]) + # If the initializer is present, this is a constant shape + # reshape which can be removed if it does not reshape + if shape is not None: + # Get the shape of the input to the reshape + inp = model.get_tensor_shape(node.input[0]) + # If input and target shape are the same, this is an + # identity operation + if len(shape) == len(inp) and (shape == inp).all(): # noqa + # Remove and rewire this node + remove_node_and_rewire(model, node) + # Track whether the graph has been modified, never + # resets to False + graph_modified = True + # Need to redo the shape inference after potentially removing nodes + model = model.transform(InferShapes()) # noqa: Shadows from outer scope + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified + + +# Removes identity transpose operations, i.e., Transpose where input order is +# the same as the target permutation +class RemoveIdentityTranspose(Transformation): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # 
Applies to Transpose operation types + if node.op_type == "Transpose": + # Currently does not handle join-nodes + if model.is_join_node(node): + # Softly skip this node + continue + # Get the (optional) permutation indices of the transpose in + # case it is a multi-axis transpose + perm = get_by_name(node.attribute, "perm") + # If the permutation indices are given, we can check whether + # they are in order making this an identity transpose + # Note: Without perm attribute, this is implicitly reversing the + # axes, i.e., not an identity transpose + if perm is not None: + # Convert permutation indices to list of integers + perm = perm.ints + # Get the shape of the input tensor + shape = model.get_tensor_shape( + # fmt: off + node.input[0], fix_missing_init_shape=True + # fmt: on + ) + # If the permutation indices cover the input shape in order, + # this transpose does nothing + if perm == [i for i in range(len(shape))]: + # Remove and rewire this node + remove_node_and_rewire(model, node) + # Track whether the graph has been modified, never + # resets to False + graph_modified = True + # Need to redo the shape inference after potentially removing nodes + model = model.transform(InferShapes()) # noqa: Shadows model + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 2c54518edf..792aac5cb5 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -43,6 +43,9 @@ from qonnx.transformation.infer_shapes import InferShapes from qonnx.util.basic import get_by_name +# Groups node inputs by dynamic vs. initializer category +from finn.transformation.util import group_inputs_by_category + class MoveAddPastMul(Transformation): """Move add operations past multiply operations on linear segments of the graph. @@ -116,58 +119,133 @@ def apply(self, model): return model, graph_modified +# Tests whether a tensor is a scalar, i.e., whether all dimensions are 1 +def is_scalar(tensor): + return tensor is not None and all(x == 1 for x in tensor.shape) + + +# Tests whether a node is a scalar multiplication with a constant scale factor +def is_const_scalar_mul(node, model): + # Only handle existing Mul type nodes + if node is not None and node.op_type == "Mul": + # The constant must be an initializer + # Note: Assumes the constant parameter to always be the second input + scale = model.get_initializer(node.input[1]) + # Test for existence of a constant scale factor + return scale is not None and is_scalar(scale) + # Did not match the operator type + return False + + +# Refactored version of the MoveScalarMulPastMatMul transform capable of +# transforming two-input MatMul, like those being part of the attention operator class MoveScalarMulPastMatMul(Transformation): """Move scalar mul operations past matmul operations. 
We want to have muls next to each other such that they can be collapsed into a single mul.""" + # Applies the transform to a whole model graph def apply(self, model): + # Get the model graph out of the model wrapper object graph = model.graph - node_ind = 0 + # Keep track of whether the graph has been modified graph_modified = False - for n in graph.node: - node_ind += 1 - if n.op_type == "Mul" and not model.is_fork_node(n) and not model.is_join_node(n): - consumer = model.find_consumer(n.output[0]) - if ( - consumer is not None - and consumer.op_type == "MatMul" - and not model.is_join_node(consumer) - ): - mul_weight_name = n.input[1] - matmul_weight_name = consumer.input[1] - A = model.get_initializer(mul_weight_name) - W = model.get_initializer(matmul_weight_name) - if (A is None) or (W is None): - warnings.warn("MatMul or Mul params are not constant, skipping") + + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # First pattern matching condition: For the transform to be + # applicable, the node has to be a MatMul operator + if node.op_type == "MatMul": + # Note: When touching the following code, remember to treat both + # branches equivalently! + # TODO: Can this be enforced or at least be made easier by + # extracting common code patterns to a function? + + # Get the left hand side and right hand side inputs + # Note: Assumes the ordering of left to right inputs to match + # indices 0 to 1. However, it does not "hurt" if it is + # reversed as both sides are treated equivalently. + lhs = model.find_producer(node.input[0]) + rhs = model.find_producer(node.input[1]) + + # Give precedence to the left hand side input testing for the + # presence of a scalar multiplication + if is_const_scalar_mul(lhs, model): + # Cannot handle fork nodes: We would have to distribute the + # Mul into all branches + # TODO: Maybe reconsider this at some point, there is + # probably nothing preventing this in general, it is just + # more difficult and apparently not necessary right now. 
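+                    # Note: The swap below relies on the scalar identity
+                    # (a * b) @ c == (a @ c) * b, i.e., it is only safe
+                    # because b has been checked to be a constant scalar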
+ if model.is_fork_node(lhs): + # Softly skip this node continue - start_name = n.input[0] - middle_name = n.output[0] - end_name = consumer.output[0] - mm_out_shape = model.get_tensor_shape(end_name) - if all(x == 1 for x in A.shape): - # if the mul is scalar, we can simply swap the order of ops - # make and insert new nodes - new_matmul = oh.make_node( - "MatMul", - [start_name, matmul_weight_name], - [middle_name], - name=consumer.name, - ) - new_mul = oh.make_node( - "Mul", - [middle_name, mul_weight_name], - [end_name], - name=n.name, - ) - graph.node.insert(node_ind, new_matmul) - graph.node.insert(node_ind + 1, new_mul) - model.set_tensor_shape(middle_name, mm_out_shape) - # remove old nodes - graph.node.remove(n) - graph.node.remove(consumer) - graph_modified = True + # Unpack the connection pattern of a scalar mul feeding the + # lhs input of the matmul + # Names of the three input tensors to the mul-matmul complex + a, b, c = lhs.input[0], lhs.input[1], node.input[1] + # Names of the intermediate and the global output + m, o = lhs.output[0], node.output[0] # noqa: Duplicate code + # Rewire the operator connections locally, swapping mul and + # matmul operator order + matmul = oh.make_node("MatMul", [a, c], [m], node.name) + mul = oh.make_node("Mul", [m, b], [o], lhs.name) + # Insert the rewired nodes into the graph + graph.node.insert(index, matmul) + graph.node.insert(index + 1, mul) + # Adapt the shape of the intermediate tensor as it changed + # according to the output shape of the matmul + model.set_tensor_shape(m, model.get_tensor_shape(o)) + # Remove the old nodes from the graph + graph.node.remove(lhs) + graph.node.remove(node) + # The graph has been modified, this needs to be reported + # back to the caller + graph_modified = True + # Cannot further modify the node (i.e., the rhs) as the + # index and state of the nodes changed and need to be + # queried again from the graph.node at the start of the next + # iteration. + continue + + # Next try whether the right hand side matches the pattern of a + # scalar multiplication + if is_const_scalar_mul(rhs, model): + # Cannot handle fork nodes: We would have to distribute the + # Mul into all branches + # TODO: Maybe reconsider this at some point, there is + # probably nothing preventing this in general, it is just + # more difficult and apparently not necessary right now. 
+ if model.is_fork_node(rhs): + # Softly skip this node + continue + # Unpack the connection pattern of a scalar mul feeding the + # rhs input of the matmul + # Names of the three input tensors to the mul-matmul complex + a, b, c = node.input[0], rhs.input[0], rhs.input[1] + # Names of the intermediate and the global output + m, o = rhs.output[0], node.output[0] # noqa: Duplicate code + # Rewire the operator connections locally, swapping mul and + # matmul operator order + matmul = oh.make_node("MatMul", [a, b], [m], node.name) + mul = oh.make_node("Mul", [m, c], [o], rhs.name) + # Insert the rewired nodes into the graph + graph.node.insert(index, matmul) + graph.node.insert(index + 1, mul) + # Adapt the shape of the intermediate tensor as it changed + # according to the output shape of the matmul + model.set_tensor_shape(m, model.get_tensor_shape(o)) + # Remove the old nodes from the graph + graph.node.remove(rhs) + graph.node.remove(node) + # The graph has been modified, this needs to be reported + # back to the caller + graph_modified = True + + # Finalize the transformation by inferring shapes again (as these might + # have changed) model = model.transform(InferShapes()) - return (model, graph_modified) + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified class MoveScalarAddPastMatMul(Transformation): @@ -1765,3 +1843,601 @@ def apply(self, model: ModelWrapper): # noqa # Return the transformed model and indicate whether the graph # actually has been transformed return model, graph_modified + + +# Moves a transpose operator past elementwise addition or multiplication +class MoveTransposePastEltwise(Transformation): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Applies to Transpose operation types + if node.op_type == "Transpose": + # Currently does not handle fork- or join-nodes + if model.is_fork_node(node) or model.is_join_node(node): + # Softly skip this node + continue + # As this is not a fork-node, there can be at most one successor + successor = model.find_direct_successors(node) + # If Transpose is the final operation in the graph, there might + # be no successor + if successor is None: + # Softly skip this node + continue + # Now there is exactly one successor which needs to be extracted + # from the list + successor = successor[0] + # Applies to elementwise add and mul operations + if successor.op_type in {"Add", "Mul"}: + # Get names of all tensors involved in connecting the nodes + inp = node.input[0] + mid = node.output[0] + out = successor.output[0] + + # y = x^T + a <=> y = (x + a^T)^T + + # Assume left-to-right order of input to the Add operator + xt, a = successor.input + # Check whether the assumption holds true + if xt != mid: + # Leaves only the option of a and xt commuting + xt, a = a, xt + # If this assumption still does not hold true, something is + # wrong with the graph + assert xt == mid, f"Messed up graph pattern at {node.name}" + + # Get the (optional) permutation indices of the transpose in + # case it is a multi-axis transpose + perm = get_by_name(node.attribute, "perm") + # Convert permutation indices to list of integers + perm = list(perm.ints) if perm is not None else 
None + + # Inverse permutation needs to be applied to the initializer + # fmt: off + inverse_perm = None if not perm else [ + perm.index(i) for i in range(len(perm)) + ] + # fmt: on + + # This transformation does only apply to Add nodes where the + # second input is a constant initializer + if (value := model.get_initializer(a)) is not None: + # Do not transpose scalar or effectively scalar + # initializers + if not (value.shape is None or all(x == 1 for x in value.shape)): + # Transpose the initializer and re-insert into the + # model + # fmt: off + model.set_initializer( + a, value.transpose(inverse_perm) + ) + # fmt: on + # Rewire the graph to feed original input and the + # transposed initializer into the Add node first + successor.input[:] = [inp, a] + # Repurpose the middle tensor for the output of the + # addition + successor.output[0] = mid + # The Transpose operator now gets the middle tensor as + # its input + node.input[0] = mid + # Transpose now produces the original output tensor + node.output[0] = out + # Delete the shape annotation of the connecting tensors + # to be re-done later + model.set_tensor_shape(inp, None) + model.set_tensor_shape(mid, None) + model.set_tensor_shape(out, None) + # Track whether the graph has been modified, never + # resets to False + graph_modified = True + # Break the loop after deleting shape annotations to + # immediately re-do these before changing the next + # operator + break + # Need to redo the shape inference after potentially removing nodes + model = model.transform(InferShapes()) # noqa: Shadows model + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified + + +# Moves elementwise additions past MatMul operations: Applicable if each +# operation has one initializer input +class MoveAddPastMatMul(Transformation): + # Applies the transform to a whole model graph # noqa: Duplicate + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Applies to Add operations + if node.op_type == "Add": + # If the add is a join operation, we do not have a constant + # added to the input + if model.is_join_node(node): + # Skip transforming this + continue + # If the Add is a fork operation we should first distribute the + # Add into the branches + if model.is_fork_node(node): + # Issue a warning to make the use aware of this potential + # transformation if the fork is moved first + warnings.warn( + f"{self.__class__.__name__}:" + f" Skipping near match: {node.name} is a fork-node," + f" try MoveLinearPastFork first" + ) + # Skip transforming this node as moving this would lead + # to messed up or detached graph + continue + # Decompose the inputs into the dynamic and the constant + # initializer input + (x_name,), (c_name,) = group_inputs_by_category(node, model) + # Now check the successor node which must be a MatMul + consumer = model.find_direct_successors(node) + # If there is no consumer, this Add seems to be last node of the + # graph + if not consumer: + # Skip transforming this + continue + # There must be exactly one consumer now + consumer = consumer[0] + # This transformation only applies to Add in front of MatMul + if not consumer.op_type == "MatMul": + # Skip this if not MatMul + continue + # MatMul may not be a join 
operation to apply this + # transformation + if model.is_join_node(consumer): + # Skip transforming without warning (there is nothing we can + # do about this) + continue + # Decompose the inputs to the MatMul to get the weight tensor + # name (the other input is the output of the Add) + _, (w_name,) = group_inputs_by_category(consumer, model) + # Read the weights and the constant addition tensor + w = model.get_initializer(w_name) + c = model.get_initializer(c_name) + # Determine whether the weights are the left or right input to + # the MatMul + left = w_name == consumer.input[0] + # Apply the weights to the constant tensor + c = np.matmul(w, c) if left else np.matmul(c, w) + # Insert the transformed tensor back into the mode as an + # initializer + model.set_initializer(c_name, c) + # The connecting tensors of this pattern + inp = x_name + mid = node.output[0] + out = consumer.output[0] + # Rewire the graph pattern connecting the input to the MatMul + # and the MatMul output to the Add node + consumer.input[1 if left else 0] = inp + # The Add now produces the original MatMul output + node.output[0] = out + # The middel tensor connects to the Add input + node.input[0 if node.input[0] == x_name else 1] = mid + # The MatMul feeds the middle tensors + consumer.output[0] = mid + # Delete the shape annotation of the connecting tensors + # to be re-done later + model.set_tensor_shape(mid, None) + model.set_tensor_shape(out, None) + # Delete the type annotations of the connecting tensors + # to be re-done later + # model.set_tensor_datatype(mid, None) + # model.set_tensor_datatype(out, None) + # Track whether the graph has been modified, never + # resets to False + graph_modified = True + # Break the loop after deleting shape annotations to + # immediately re-do these before changing the next + # operator + break + # Redo datatype and shape annotations + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + # Return the transformed model and indicate whether the transformation + # needs to be applied again + return model, graph_modified + + +# Moves constant elementwise multiplication past another joining multiplication +class MoveConstMulPastJoinMul(Transformation): + # Applies the transform to a whole model graph # noqa: Duplicate + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Applies to Mul operation types + if node.op_type == "Mul": + # Currently does not handle fork- or join-nodes + if model.is_fork_node(node) or model.is_join_node(node): + # Softly skip this node + continue + # As this is not a fork-node, there can be at most one successor + successor = model.find_direct_successors(node) + # If Squeeze is the final operation in the graph, there might + # be no successor + if successor is None: + # Softly skip this node + continue + # Now there is exactly one successor which needs to be extracted + # from the list + successor = successor[0] + # Applies to Multiplications + if successor.op_type in {"Mul"}: + # Applies only if the second multiplication is a join-node + if model.is_join_node(successor): + # Get names of all tensors involved in connecting the + # nodes + inp = node.input[0] # noqa: Duplicate + mid = node.output[0] + out = successor.output[0] + # Need to match the correct input of the 
joining second
+                        # multiplication
+                        for i, name in enumerate(successor.input):
+                            # If the successor's input currently matches the
+                            # intermediate tensor, this input needs to be
+                            # rewired
+                            if name == mid:
+                                # Rewire the graph to feed the original input
+                                # into the second Mul node first
+                                successor.input[i] = inp
+                                # Note: Do not break here as it is perfectly
+                                # legal to connect the same tensor multiple
+                                # times to different inputs
+                        # Repurpose the middle tensor for the output of the
+                        # second Mul
+                        successor.output[0] = mid
+                        # The first Mul operator now gets the middle tensor as
+                        # its input
+                        node.input[0] = mid
+                        # The first Mul now produces the original output tensor
+                        node.output[0] = out
+                        # Delete the shape annotation of the connecting tensors
+                        # to be re-done later
+                        model.set_tensor_shape(mid, None)
+                        model.set_tensor_shape(out, None)
+                        # Track whether the graph has been modified, never
+                        # resets to False
+                        graph_modified = True
+                        # Break the loop after deleting shape annotations to
+                        # immediately re-do these before changing the next
+                        # operator
+                        break
+        # Redo datatype and shape annotations
+        model = model.transform(InferShapes())
+        model = model.transform(InferDataTypes())
+        # Return the transformed model and indicate whether the transformation
+        # needs to be applied again
+        return model, graph_modified
+
+
+# Moves elementwise multiplication past elementwise addition if one input to
+# each of the operators is a known constant
+# Note: Reverse of MoveAddPastMul
+class MoveMulPastAdd(Transformation):
+    # Applies the transform to a whole model graph
+    def apply(self, model: ModelWrapper):  # noqa
+        # Get the model graph out of the model wrapper object
+        graph = model.graph
+        # Keep track of whether the graph has been modified
+        graph_modified = False
+        # Iterate all nodes in the graph keeping track of the index
+        for index, node in enumerate(graph.node):
+            # Applies to Mul operation types
+            if node.op_type == "Mul":
+                # Currently does not handle fork- or join-nodes
+                if model.is_fork_node(node) or model.is_join_node(node):
+                    # Softly skip this node
+                    continue
+                # As this is not a fork-node, there can be at most one successor
+                successor = model.find_direct_successors(node)
+                # If the Mul is the final operation in the graph, there might
+                # be no successor
+                if successor is None:
+                    # Softly skip this node
+                    continue
+                # Now there is exactly one successor which needs to be extracted
+                # from the list
+                successor = successor[0]
+                # Applies to additions
+                if successor.op_type in {"Add"}:
+                    # The addition may not join as we need to know the second
+                    # input
+                    if not model.is_join_node(successor):
+                        # Get the constant initializer tensors for both
+                        # operations: y = s * x + b
+                        _, s_name = group_inputs_by_category(node, model)
+                        _, b_name = group_inputs_by_category(successor, model)
+                        # Skip if either node has no constant initializer
+                        if not s_name or not b_name:
+                            # Nothing to rewrite without both constants, softly
+                            # skip this node without a warning
+                            continue
+                        # There must be exactly one constant per operation
+                        assert len(s_name) == 1, f"Too many constant inputs for {node}"
+                        assert len(b_name) == 1, f"Too many constant inputs for {successor}"
+                        # Now read the initializer tensors
+                        s = model.get_initializer(*s_name)
+                        b = model.get_initializer(*b_name)
+                        # Update the addition initializer according to the
+                        # distributive law: s * (x + b / s) == s * x + b
+                        model.set_initializer(*b_name, b / s)
+                        # Get names of all tensors involved in connecting the
+                        # nodes
+                        inp = node.input[0]  # noqa: Duplicate
+                        mid = node.output[0]
+                        out = successor.output[0]
+                        # Rewire the graph to feed the original input into the
+                        # Add node first
+                        successor.input[0] = inp
+                        # Repurpose the middle tensor for the output of the Add
+                        successor.output[0] = mid
+                        # The Mul operator now gets the middle tensor as its
+                        # input
+                        node.input[0] = mid
+                        # Mul now produces the original output tensor
+                        node.output[0] = out
+                        # Delete the shape annotation of the connecting tensors
+                        # to be re-done later
+                        model.set_tensor_shape(mid, None)
+                        model.set_tensor_shape(out, None)
+                        # Track whether the graph has been modified, never
+                        # resets to False
+                        graph_modified = True
+                        # Break the loop after deleting shape annotations to
+                        # immediately re-do these before changing the next
+                        # operator
+                        break
+        # Redo datatype and shape annotations
+        model = model.transform(InferShapes())
+        model = model.transform(InferDataTypes())
+        # Return the transformed model and indicate whether the transformation
+        # needs to be applied again
+        return model, graph_modified
+
+
+# Moves scalar linear elementwise operations past fork nodes, applies to Add,
+# Mul, Sub, Div, etc.
+class MoveScalarLinearPastFork(Transformation):
+    # Applies the transform to a whole model graph
+    def apply(self, model: ModelWrapper):  # noqa
+        # Get the model graph out of the model wrapper object
+        graph = model.graph
+        # Keep track of whether the graph has been modified
+        graph_modified = False
+        # Iterate all nodes in the graph keeping track of the index
+        for index, node in enumerate(graph.node):
+            # Applies to Mul-like and Add-like operation types
+            if node.op_type in {"Add", "Sub", "Mul", "Div"}:
+                # Only handles non-joining forks for now
+                if not model.is_fork_node(node) or model.is_join_node(node):
+                    # Softly skip this node
+                    continue
+                # Only handles one forking output for now
+                if len(node.output) > 1:
+                    # Softly skip this node
+                    continue
+                # Dynamic input and constant parameter of the operation
+                (inp,), (const,) = group_inputs_by_category(node, model)
+                # Test whether the node initializer is a scalar...
+                if not is_scalar(model.get_initializer(const)):
+                    # Softly skip this node
+                    continue
+                # We need to insert a replica of this operation in front of each
+                # consumer node
+                for consumer in model.find_direct_successors(node):
+                    # Create an exact replica of this operator
+                    copy = deepcopy(node)
+                    # Insert a new unique tensor connecting the output of the
+                    # copy to the consumer
+                    copy.output[0] = model.make_new_valueinfo_name()
+                    # The original node might be connecting to multiple inputs
+                    # of the consumer...
+ for idx, inp in enumerate(consumer.input): + # Find each instance of connection from original node + if inp == node.output[0]: + # Rewire to connect to the replica + consumer.input[idx] = copy.output[0] + # Insert the new replica node into the graph + graph.node.insert(index + 1, copy) + # Remove the original node from the graph + graph.node.remove(node) + # Redo datatype and shape annotations + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + # Return the transformed model and indicate whether the transformation + # needs to be applied again + return model, graph_modified + + +# Moves scalar linear channel-wise operations past fork nodes, applies to Add, +# Mul, Sub, Div, etc. +class MoveChannelwiseLinearPastFork(Transformation): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Applies to Mul-like and Add-like operation types + if node.op_type in {"Add", "Sub", "Mul", "Div"}: + # Only handles non-joining forks for now + if not model.is_fork_node(node) or model.is_join_node(node): + # Softly skip this node + continue + # Only handles one forking output for now + if len(node.output) > 1: + # Softly skip this node + continue + + # Left and right side of the operation + (inp,), (const,) = group_inputs_by_category(node, model) + + # First try to consider the tensor layout of the input for + # determining the number of input channels + layout = model.get_tensor_layout(inp) + # If there is no layout annotation, guess based on rank of the + # tensor + if layout is None: + # Maps tensor rank to layout annotation + rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"} + # Lookup the layout required by this input shape + layout = rank_to_layout[len(model.get_tensor_shape(inp))] + # If there is a layout annotation, use this to determine the + # index of the channel dimension + if layout is not None and "C" in layout: + # Lookup the index in list + cdim = layout.index("C") + # If no layout has been annotated or there is no channel + # dimension, fall back to the previous default assumption + else: + # Assume the channels to be in axis 1 + cdim = 1 + # Issue a warning to the user, so they are aware of this + warnings.warn( + f"{self.__class__.__name__}: No layout for {inp}:" + f" Assuming channel dimension at index {cdim}" + ) + + # Tests whether two shapes can be broadcast according to NumPy + # semantics + def can_broadcast_to(lhs, rhs): + # Broadcasting might raise an exception + try: + # Try broadcasting the shapes + if np.broadcast_to(np.zeros(lhs), rhs).shape == rhs: + # These tensors can be broadcast, preserving the + # left-hand-side shape + return True + # These tensors cannot be broadcast + return False + # Failing to broadcast the tensors raises ValueError + except ValueError: + # These tensors cannot be broadcast + return False + + # Per-tensor or per-channel means we have some parameter tensor + # which can be broadcast to the channel dimension of the output + if not can_broadcast_to( + model.get_tensor_shape(const), (model.get_tensor_shape(node.output[0])[cdim],) + ): + # Issue a warning to the user, so they are aware of this + warnings.warn(f"{self.__class__.__name__}: Not channel-wise {const}:") + # Softly skip this node + 
continue + + # We need to insert a replica of this operation in front of each + # consumer node + for consumer in model.find_direct_successors(node): + # Create an exact replica of this operator + copy = deepcopy(node) + # Insert a new unique tensor connecting the output of the + # copy to the consumer + copy.output[0] = model.make_new_valueinfo_name() + # The original node might be connecting to multiple inputs + # of the consumer... + for idx, inp in enumerate(consumer.input): + # Find each instance of connection from original node + if inp == node.output[0]: + # Rewire to connect to the replica + consumer.input[idx] = copy.output[0] + # Insert the new replica node into the graph + graph.node.insert(index + 1, copy) + # Remove the original node from the graph + graph.node.remove(node) + # Redo datatype and shape annotations + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + # Return the transformed model and indicate whether the transformation + # needs to be applied again + return model, graph_modified + + +# Moves scale factor, i.e., scalar Mul and Div, past Im2Col (and Col2Im): These +# cannot be handled by MoveScalarLinearPastInvariants as potential padding makes +# Add-Im2Col not commute to Im2Col-Add +class MoveScalesPastIm2Col(Transformation): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Applies to Mul operation types + if node.op_type in {"Mul", "Div"}: + # Cannot handle fork- or join-multiplications + if model.is_fork_node(node) or model.is_join_node(node): + # Softly skip this node + continue + # Only handles one forking output for now + if len(node.output) > 1: + # Softly skip this node + continue + # The first input must be dynamically received from upstream + if model.get_initializer(node.input[0]) is not None: + # Softly skip this node + continue + # Test whether the node initializer is a scalar... 
+ if not is_scalar(model.get_initializer(node.input[1])): + # Softly skip this node + continue + # As this is not a fork-node, there can be at most one successor + successor = model.find_direct_successors(node) + # If this is the final operation in the graph, there might be no + # successor + if successor is None: + # Softly skip this node + continue + # Now there is exactly one successor which needs to be extracted + # from the list + successor = successor[0] + # Handle both, Im2Col and the inverse Col2Im, as well as padding + if successor.op_type in {"Im2Col", "Col2Im", "Pad"}: + # Get names of all tensors involved in connecting the + # nodes + inp = node.input[0] # noqa: Duplicate + mid = node.output[0] + out = successor.output[0] + # Rewire the graph to feed original input into the + # Add node first + successor.input[0] = inp + # Repurpose the middle tensor for the output of the Add + successor.output[0] = mid + # The Mul operator now gets the middle tensor as its + # input + node.input[0] = mid + # Mul now produces the original output tensor + node.output[0] = out + # Delete the shape annotation of the connecting tensors + # to be re-done later + model.set_tensor_shape(mid, None) + model.set_tensor_shape(out, None) + # Track whether the graph has been modified, never + # resets to False + graph_modified = True + # Break the loop after deleting shape annotations to + # immediately re-do these before changing the next + # operator + break + # Redo datatype and shape annotations + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + # Return the transformed model and indicate whether the transformation + # needs to be applied again + return model, graph_modified diff --git a/src/finn/transformation/streamline/streamline_plus.py b/src/finn/transformation/streamline/streamline_plus.py new file mode 100644 index 0000000000..1ae9ca75b9 --- /dev/null +++ b/src/finn/transformation/streamline/streamline_plus.py @@ -0,0 +1,112 @@ +# Disable formatter as black messes with the nesting below, spreading ([ and ]) +# to multiple lines... 
+# fmt: off + +# Exhaustive composition of ONNX graph transformation +from qonnx.transformation.batchnorm_to_affine import BatchNormToAffine +from qonnx.transformation.composed import ComposedTransformation + +# Some extra QONNX conversion, streamlining transformations +from qonnx.transformation.general import ConvertDivToMul, ConvertSubToAdd + +from finn.transformation.streamline.absorb import * +from finn.transformation.streamline.collapse_repeated import * +from finn.transformation.streamline.remove import * + +# Import whole submodules of basic streamlining transformations +from finn.transformation.streamline.reorder import * + +# Some more specialized streamlining transformations +from finn.transformation.streamline.round_thresholds import ( # noqa: isort + RoundAndClipThresholds, +) +from finn.transformation.streamline.sign_to_thres import ConvertSignToThres + + +# Define a set of custom streamlining transformations: These are applied once +# during the actual streamlining step and once after converting attention to +# hardware (the associated cleanup afterward might enable some Streamlining +# transformations once again) +def StreamlinePlus(): # noqa: Uppercase + # Return a set of exhaustively applied transformations + return ComposedTransformation([ + # On skip-connections: prefer pushing scalar multiplication forward + # before MoveAddPastMul + MoveMulPastFork(), + # The "standard" set of FINN streamlining transformations or at least + # inspired by them but applied exhaustively until none of them changes + # the graph anymore. + # Note: Covers most parts of non-branching linear topologies + ComposedTransformation([ + ConvertSubToAdd(), + ConvertDivToMul(), + BatchNormToAffine(), + ConvertSignToThres(), + MoveMulPastMaxPool(), + AbsorbSignBiasIntoMultiThreshold(), + MoveScalarLinearPastInvariants(), + MoveAddPastMul(), + MoveScalarAddPastMatMul(), + MoveAddPastConv(), + MoveScalarMulPastMatMul(), + MoveScalarMulPastConv(), + MoveAddPastMul(), + CollapseRepeatedAdd(), + CollapseRepeatedMul(), + MoveMulPastMaxPool(), + AbsorbAddIntoMultiThreshold(), + FactorOutMulSignMagnitude(), + AbsorbMulIntoMultiThreshold(), + Absorb1BitMulIntoMatMul(), + Absorb1BitMulIntoConv(), + ]), + # Streamlining scales and biases forward through residual topologies + # Note: This mostly covers forking and joining operations + ComposedTransformation([ + # Note: This is probably the most common way of joining skip + # connections, i.e., this corresponds to the original residual + # addition, i.e., y = f(x) + x + MoveLinearPastEltwiseAdd(), + MoveChannelwiseLinearPastFork(), + MoveScalarLinearPastInvariants(), + MoveMulPastFork(), + MoveMulPastJoinAdd(), + MoveAddPastJoinAdd(), + # Note: This brings constant Muls (i.e., quantizer scales to be + # removed) forward through joining Muls (i.e., those ending up + # as actual hardware operators). 
+ MoveConstMulPastJoinMul(), + ]), + # Streamlining scales and biases forward through shape/layout changing + # operations, i.e., mostly transposes + ComposedTransformation([ + # Convolution inputs and padding + MoveScalesPastIm2Col(), + # Streamlining for Split and Concat operations + MoveScalarLinearPastSplit(), + MoveAffinePastJoinConcat(), + MoveMulPastJoinConcat(), + MoveAddPastJoinConcat(), + # Move transposes around to some place where they could be removed + # later, i.e., where they collapse into identities + MoveTransposePastFork(), + MoveTransposePastSplit(), + MoveTransposePastJoinConcat(), + MoveTransposePastEltwise(), + MoveTransposePastJoinMul(), + MoveTransposePastJoinAdd(), + CollapseRepeatedTranspose(), + # Remove identity shape/layout transformations + RemoveIdentityTranspose(), + RemoveIdentityReshape(), + # Squeeze operators can be moved past the thresholding + MoveSqueezePastMultiThreshold(), + # A certain type of 4d-layout transpose can be absorbed (actually + # moved past) MultiThreshold operations + AbsorbTransposeIntoMultiThreshold(), + ]), + # Only round and clip after all streamlining transformations have + # been applied exhaustively. + # Note: Might still enable another round of streamlining. + RoundAndClipThresholds(), + ]) diff --git a/src/finn/transformation/util.py b/src/finn/transformation/util.py new file mode 100644 index 0000000000..1e9ae1817d --- /dev/null +++ b/src/finn/transformation/util.py @@ -0,0 +1,143 @@ +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. + +# Protobuf onnx graph node type +from onnx import NodeProto + +# QONNX wrapper of ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper + + +# Tests whether a node is a multi-threshold operation +def is_threshold(node: NodeProto): + return node.op_type == "MultiThreshold" + + +# Tests whether a node is a join-node MatMul operation, i.e., a MatMul with two +# runtime inputs but no weights initializers +def is_join_matmul(node: NodeProto, model: ModelWrapper): # noqa + # Only handle existing MatMul type nodes + if node is not None and node.op_type in {"MatMul"}: + # No input must have an initializer + return all(model.get_initializer(i) is None for i in node.input) + # Did not match the operator type + return False + + +# Tests whether a node is a MatMul operator +def is_matmul(node: NodeProto): + # Node must exist and be of type MatMul + return node is not None and node.op_type in {"MatMul"} + + +# Tests whether a node is a Softmax operator +def is_softmax(node: NodeProto): + # Node must exist and be of type Softmax + return node is not None and node.op_type in {"Softmax"} + + +# Tests whether a node is an element-wise Mul +def is_mul(node: NodeProto): + # Node must exist and be of type Mul + return node is not None and node.op_type in {"Mul"} + + +# Tests whether a node is an element-wise Add +def is_add(node: NodeProto): + # Node must exist and be of type Add + return node is not None and node.op_type in {"Add"} + + +def is_end(node: NodeProto, model: ModelWrapper): # noqa + return node is not None and not model.find_direct_predecessors(node) + + +# Follow all input branches of a node until reaching a matmul +def all_upstream_to_matmul(node: NodeProto, model: ModelWrapper): # noqa + # Check whether the node is either a matmul node or the end of the graph + def is_matmul_or_end(n: NodeProto): + return is_matmul(n) or is_end(n, model) + + # Enumerate all inputs and 
collect everything upstream until finding the + # next matmul operation + return (model.find_upstream(i, is_matmul_or_end, True) for i in node.input) + + +# Projects a list of ONNX graph nodes to the string representation of the +# operator types +def op_types(nodes: list[NodeProto]) -> list[str]: + return [node.op_type if node is not None else "None" for node in nodes] + + +# Tests whether a node is a Reshape operator +def is_reshape(node: NodeProto): + return node is not None and node.op_type in {"Reshape"} + + +# Tests whether a node is a Transpose operator +def is_transpose(node: NodeProto): + return node is not None and node.op_type in {"Transpose"} + + +# Tests whether a node is a Reshape-Transpose operator chain +def is_reshape_transpose(node: NodeProto, model: ModelWrapper): # noqa + # Reshape-transpose pattern detection is triggered by detecting a reshape + # operation + if is_reshape(node): + # The reshape may not be a join or fork node + if model.is_join_node(node) or model.is_fork_node(node): + # Reject detection of the pattern + return False + # Get the single successor node + transpose = model.find_direct_successors(node)[0] + # The consumer must be Transpose finalizing the reshaping + if not is_transpose(transpose): + # Reject detection of the pattern + return False + # The transpose may not fork or join either + if model.is_join_node(transpose) or model.is_fork_node(transpose): + # Reject detection of the pattern + return False + # Accept detecting the pattern + return True + # Reject detection of the pattern + return False + + +# Tests whether a node is a Transpose-Reshape operator chain +def is_transpose_reshape(node: NodeProto, model: ModelWrapper): # noqa + # Transpose-Reshape pattern detection is triggered by detecting a transpose + # operation + if is_transpose(node): + # The transpose may not be a join or fork node + if model.is_join_node(node) or model.is_fork_node(node): + # Reject detection of the pattern + return False + # Get the single successor node + reshape = model.find_direct_successors(node)[0] + # The consumer must be a reshape finalizing the transpose-reshape + if not is_reshape(reshape): + # Reject detection of the pattern + return False + # The reshape may not fork or join either + if model.is_join_node(reshape) or model.is_fork_node(reshape): + # Reject detection of the pattern + return False + # Accept detecting the pattern + return True + # Reject detection of the pattern + return False + + +# Groups inputs by categories, i.e., groups dynamic inputs first, followed by +# initializers. Keeps order of inputs in each category. 
+def group_inputs_by_category(node: NodeProto, model: ModelWrapper): # noqa + # First select all dynamic inputs, which are those without initializer + # tensor + dynamics = [i for i in node.input if model.get_initializer(i) is None] + # Select all input which are initializers, which, by exclusion, are all + # those not among the dynamic inputs + initializers = [i for i in node.input if i not in dynamics] + # Return lists of dynamic anc initializer inputs + return dynamics, initializers diff --git a/tests/transformation/streamline/test_move_scalar_past_matmul.py b/tests/transformation/streamline/test_move_scalar_past_matmul.py index e4f4357fff..515e9b9462 100644 --- a/tests/transformation/streamline/test_move_scalar_past_matmul.py +++ b/tests/transformation/streamline/test_move_scalar_past_matmul.py @@ -72,6 +72,43 @@ def test_move_scalar_mul_past_matmul(): assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0] +@pytest.mark.streamline +def test_move_scalar_mul_past_join_matmul(): + top_in1 = oh.make_tensor_value_info("top_in1", TensorProto.FLOAT, [1, 2]) + top_in2 = oh.make_tensor_value_info("top_in2", TensorProto.FLOAT, [2, 1]) + mul1_param = oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1, 1]) + mul2_param = oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1, 1]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [1, 1]) + modelproto = qonnx_make_model( + oh.make_graph( + name="test", + inputs=[top_in1, top_in2], + outputs=[top_out], + value_info=[mul1_param, mul2_param], + nodes=[ + oh.make_node("Mul", ["top_in1", "mul1_param"], ["middle1"]), + oh.make_node("Mul", ["top_in2", "mul2_param"], ["middle2"]), + oh.make_node("MatMul", ["middle1", "middle2"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform(InferShapes()) + model.set_initializer("mul1_param", np.asarray([[3]], dtype=np.float32)) + model.set_initializer("mul2_param", np.asarray([[3]], dtype=np.float32)) + new_model = model.transform(MoveScalarMulPastMatMul()) + inp_dict = { + "top_in1": np.asarray([[-1.0, 1.0]], dtype=np.float32), + "top_in2": np.asarray([[1.0], [-1.0]], dtype=np.float32), + } + assert ox.compare_execution(model, new_model, inp_dict) + assert new_model.graph.node[0].op_type == "MatMul" + assert new_model.graph.node[1].op_type == "Mul" + assert new_model.graph.node[2].op_type == "Mul" + assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0] + assert new_model.graph.node[1].output[0] == new_model.graph.node[2].input[0] + + @pytest.mark.streamline def test_move_scalar_add_past_matmul(): top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [1, 2])
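
For reference, a minimal usage sketch of the new StreamlinePlus composition (not part of the patch; it assumes the QONNX ModelWrapper API used throughout this diff and that ComposedTransformation behaves like any other QONNX Transformation):

    # Hypothetical example: "model.onnx" is a placeholder for any cleaned-up QONNX model
    from qonnx.core.modelwrapper import ModelWrapper
    from qonnx.transformation.infer_shapes import InferShapes

    from finn.transformation.streamline.streamline_plus import StreamlinePlus

    model = ModelWrapper("model.onnx")
    # Shape annotations are expected to be present before streamlining
    model = model.transform(InferShapes())
    # The composed set is applied exhaustively until the graph stops changing,
    # covering linear, residual and transpose/reshape streamlining patterns
    model = model.transform(StreamlinePlus())
    model.save("model_streamlined.onnx")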