diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py
index dc92238010..368332162c 100644
--- a/pytensor/tensor/extra_ops.py
+++ b/pytensor/tensor/extra_ops.py
@@ -27,7 +27,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import alloc, second
+from pytensor.tensor.basic import alloc, join, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -2018,6 +2018,42 @@ def broadcast_with_others(a, others):
     return brodacasted_vars
 
 
+def concat_with_broadcast(tensor_list, axis=0):
+    """
+    Concatenate a list of tensors, broadcasting the non-concatenated dimensions to align.
+    """
+    if not tensor_list:
+        raise ValueError("Cannot concatenate an empty list of tensors.")
+
+    ndim = tensor_list[0].ndim
+    if not all(t.ndim == ndim for t in tensor_list):
+        raise TypeError(
+            "Only tensors with the same number of dimensions can be concatenated. "
+            f"Input ndims were: {[x.ndim for x in tensor_list]}"
+        )
+
+    axis = normalize_axis_index(axis=axis, ndim=ndim)
+    non_concat_shape = [1 if i != axis else None for i in range(ndim)]
+
+    for tensor_inp in tensor_list:
+        for i, (bcast, sh) in enumerate(
+            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
+        ):
+            if bcast or i == axis:
+                continue
+            non_concat_shape[i] = sh
+
+    assert non_concat_shape.count(None) == 1
+
+    bcast_tensor_inputs = []
+    for tensor_inp in tensor_list:
+        # Overwrite the concat-axis entry of non_concat_shape in place; the list is not needed elsewhere
+        non_concat_shape[axis] = tensor_inp.shape[axis]
+        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))
+
+    return join(axis, *bcast_tensor_inputs)
+
+
 __all__ = [
     "searchsorted",
     "cumsum",
@@ -2035,6 +2071,7 @@ def broadcast_with_others(a, others):
     "ravel_multi_index",
     "broadcast_shape",
     "broadcast_to",
+    "concat_with_broadcast",
     "geomspace",
     "logspace",
     "linspace",
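As a quick illustration of the new helper's semantics (reviewer commentary, not part of the diff): every non-concatenated dimension is broadcast to the common shape before joining, so inputs that only conform up to broadcasting can still be concatenated.

```python
import numpy as np

import pytensor
import pytensor.tensor as pt

# `a` broadcasts along the first axis; `b` supplies the concrete size (5).
a = pt.tensor("a", shape=(1, 3, 5))
b = pt.tensor("b", shape=(5, 3, 10))
c = pt.concat_with_broadcast([a, b], axis=-1)

fn = pytensor.function([a, b], c)

rng = np.random.default_rng(0)
a_val = rng.normal(size=(1, 3, 5)).astype(a.type.dtype)
b_val = rng.normal(size=(5, 3, 10)).astype(b.type.dtype)

# NumPy reference: broadcast the non-concat dims, then concatenate.
expected = np.concatenate([np.broadcast_to(a_val, (5, 3, 5)), b_val], axis=-1)
np.testing.assert_allclose(fn(a_val, b_val), expected)
```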
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index e2fe086c00..acc8becbfb 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -32,18 +32,21 @@
     moveaxis,
     ones_like,
     register_infer_shape,
+    split,
     switch,
     zeros,
     zeros_like,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.extra_ops import broadcast_arrays
+from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
 from pytensor.tensor.math import (
     Dot,
     Prod,
     Sum,
     _conj,
+    _dot,
     _matmul,
     add,
     digamma,
@@ -96,6 +99,7 @@
 from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.rewriting.linalg import is_matrix_transpose
 from pytensor.tensor.shape import Shape, Shape_i
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
     complex_dtypes,
@@ -146,6 +150,68 @@ def local_0_dot_x(fgraph, node):
         return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]
 
 
+@register_stabilize
+@node_rewriter([Blockwise])
+def local_block_diag_dot_to_dot_block_diag(fgraph, node):
+    r"""
+    Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C[:k]), dot(B, C[k:]))``
+
+    ``block_diag(A, B)`` creates a matrix of shape ``(n1 + n2, m1 + m2)``. Because dot has complexity of
+    approximately O(n^3), it is cheaper to perform two dot products on the smaller blocks than a single
+    dot on the larger matrix.
+    """
+    if not isinstance(node.op.core_op, BlockDiagonal):
+        return
+
+    # Check that the BlockDiagonal is an input to a Dot node:
+    for client in itertools.chain.from_iterable(
+        get_clients_at_depth(fgraph, node, depth=i) for i in [1, 2]
+    ):
+        if client.op not in (_dot, _matmul):
+            continue
+
+        [blockdiag_result] = node.outputs
+        blockdiag_inputs = node.inputs
+
+        dot_op = client.op
+
+        try:
+            client_idx = client.inputs.index(blockdiag_result)
+        except ValueError:
+            # If the blockdiag result is not an input to the dot, there is at least one Op between them (usually a
+            # DimShuffle). In this case, we need to figure out which of the inputs of the dot eventually leads to the
+            # blockdiag result.
+            for ancestor in client.inputs:
+                if ancestor.owner and blockdiag_result in ancestor.owner.inputs:
+                    client_idx = client.inputs.index(ancestor)
+                    break
+
+        other_input = client.inputs[1 - client_idx]
+
+        split_axis = -2 if client_idx == 0 else -1
+        split_size_axis = -1 if client_idx == 0 else -2
+
+        other_dot_input_split = split(
+            other_input,
+            splits_size=[
+                component.shape[split_size_axis] for component in blockdiag_inputs
+            ],
+            n_splits=len(blockdiag_inputs),
+            axis=split_axis,
+        )
+
+        split_dot_results = [
+            dot_op(component, other_split)
+            if client_idx == 0
+            else dot_op(other_split, component)
+            for component, other_split in zip(blockdiag_inputs, other_dot_input_split)
+        ]
+        new_output = concat_with_broadcast(split_dot_results, axis=split_axis)
+
+        copy_stack_trace(node.outputs[0], new_output)
+        return {client.outputs[0]: new_output}
+
+
 @register_canonicalize
 @node_rewriter([Dot, _matmul])
 def local_lift_transpose_through_dot(fgraph, node):
@@ -2582,7 +2648,6 @@ def add_calculate(num, denum, aslist=False, out_type=None):
     name="add_canonizer_group",
 )
 
-
 register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer")
 
 
@@ -3720,7 +3785,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
 )
 register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff")
 
-
 # log(sigmoid(x) / (1 - sigmoid(x))) -> x
 # i.e logit(sigmoid(x)) -> x
 local_logit_sigmoid = PatternNodeRewriter(
@@ -3734,7 +3798,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
 register_canonicalize(local_logit_sigmoid)
 register_specialize(local_logit_sigmoid)
 
-
 # sigmoid(log(x / (1-x)) -> x
 # i.e., sigmoid(logit(x)) -> x
 local_sigmoid_logit = PatternNodeRewriter(
@@ -3775,7 +3838,6 @@ def local_useless_conj(fgraph, node):
 register_specialize(local_polygamma_to_tri_gamma)
 
 
-
 local_log_kv = PatternNodeRewriter(
     # Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
     # During stabilize -x is converted to -1.0 * x
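The identity the new rewrite exploits is easy to verify numerically: multiplying by a block-diagonal matrix equals splitting the other operand and multiplying block by block. A minimal NumPy/SciPy sketch (illustrative only, not part of the diff):

```python
import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 2))  # deliberately non-square blocks
B = rng.normal(size=(2, 4))
C = rng.normal(size=(6, 7))  # 6 == A.shape[1] + B.shape[1]

# One large (6, 6) @ (6, 7) product ...
full = block_diag(A, B) @ C

# ... equals two small products on the row-wise split of C, stacked back together.
split = np.concatenate([A @ C[:2], B @ C[2:]], axis=0)

np.testing.assert_allclose(full, split)
```

Here the dense product costs 6·6·7 = 252 multiply-adds while the split form costs 4·2·7 + 2·4·7 = 112, and the savings grow with the size and number of blocks, which is what the benchmark test below measures.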
diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py
index 9f6238ae40..73ca1c8a72 100644
--- a/pytensor/xtensor/rewriting/shape.py
+++ b/pytensor/xtensor/rewriting/shape.py
@@ -2,8 +2,8 @@
 from pytensor.graph import node_rewriter
 from pytensor.tensor import (
     broadcast_to,
+    concat_with_broadcast,
     expand_dims,
-    join,
     moveaxis,
     specify_shape,
     squeeze,
@@ -74,28 +74,7 @@ def lower_concat(fgraph, node):
 
     # Convert input XTensors to Tensors and align batch dimensions
     tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]
-
-    # Broadcast non-concatenated dimensions of each input
-    non_concat_shape = [None] * len(out_dims)
-    for tensor_inp in tensor_inputs:
-        # TODO: This is assuming the graph is correct and every non-concat dimension matches in shape at runtime
-        # I'm running this as "shape_unsafe" to simplify the logic / returned graph
-        for i, (bcast, sh) in enumerate(
-            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
-        ):
-            if bcast or i == concat_axis or non_concat_shape[i] is not None:
-                continue
-            non_concat_shape[i] = sh
-
-        assert non_concat_shape.count(None) == 1
-
-    bcast_tensor_inputs = []
-    for tensor_inp in tensor_inputs:
-        # We modify the concat_axis in place, as we don't need the list anywhere else
-        non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis]
-        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))
-
-    joined_tensor = join(concat_axis, *bcast_tensor_inputs)
+    joined_tensor = concat_with_broadcast(tensor_inputs, axis=concat_axis)
     new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
     return [new_out]
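The xtensor lowering above now delegates to the shared helper instead of inlining the broadcast-and-join logic it used to carry. A small sketch (illustrative, not part of the diff) showing that the old inline formulation and the new call agree:

```python
import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(1, 3, 5))
y = pt.tensor("y", shape=(7, 3, 2))

# Old inline formulation: broadcast the non-concat dims by hand, then join.
old = pt.join(
    -1,
    pt.broadcast_to(x, (7, 3, 5)),
    pt.broadcast_to(y, (7, 3, 2)),
)
# New formulation: a single call to the shared helper.
new = pt.concat_with_broadcast([x, y], axis=-1)

fn = pytensor.function([x, y], [old, new])

rng = np.random.default_rng(0)
x_val = rng.normal(size=(1, 3, 5)).astype(x.type.dtype)
y_val = rng.normal(size=(7, 3, 2)).astype(y.type.dtype)
old_val, new_val = fn(x_val, y_val)
np.testing.assert_allclose(old_val, new_val)
```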
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 9d488c4ef0..a6e734ae82 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -115,6 +115,7 @@
     simplify_mul,
 )
 from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.type import (
     TensorType,
     cmatrix,
@@ -4745,3 +4746,121 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
         out.eval({a: a_test, b: b_test}, mode=test_mode),
         rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
     )
+
+
+@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
+@pytest.mark.parametrize(
+    "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
+)
+@pytest.mark.parametrize(
+    "batch_other", [True, False], ids=["batched_other", "unbatched_other"]
+)
+def test_local_block_diag_dot_to_dot_block_diag(
+    left_multiply, batch_blockdiag, batch_other
+):
+    """
+    Test that dot(block_diag(x, y), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
+    """
+
+    def has_blockdiag(graph):
+        return any(
+            (
+                var.owner
+                and (
+                    isinstance(var.owner.op, BlockDiagonal)
+                    or (
+                        isinstance(var.owner.op, Blockwise)
+                        and isinstance(var.owner.op.core_op, BlockDiagonal)
+                    )
+                )
+            )
+            for var in ancestors([graph])
+        )
+
+    a = tensor("a", shape=(4, 2))
+    b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
+    c = tensor("c", shape=(4, 4))
+    x = pt.linalg.block_diag(a, b, c)
+
+    d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))
+
+    # Multiply from the left or from the right, depending on the parametrization
+    if left_multiply:
+        out = x @ d
+    else:
+        out = d @ x
+
+    assert has_blockdiag(out)
+    fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
+    assert not has_blockdiag(fn.maker.fgraph.outputs[0])
+
+    n_dots_rewrite = sum(
+        isinstance(node.op, Dot | Dot22)
+        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
+        for node in fn.maker.fgraph.apply_nodes
+    )
+    assert n_dots_rewrite == 3
+
+    fn_expected = pytensor.function(
+        [a, b, c, d],
+        out,
+        mode=Mode(linker="py", optimizer=None),
+    )
+    assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
+
+    n_dots_no_rewrite = sum(
+        isinstance(node.op, Dot | Dot22)
+        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
+        for node in fn_expected.maker.fgraph.apply_nodes
+    )
+    assert n_dots_no_rewrite == 1
+
+    rng = np.random.default_rng()
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    rewrite_out = fn(a_val, b_val, c_val, d_val)
+    expected_out = fn_expected(a_val, b_val, c_val, d_val)
+    np.testing.assert_allclose(
+        rewrite_out,
+        expected_out,
+        atol=1e-6 if config.floatX == "float32" else 1e-12,
+        rtol=1e-6 if config.floatX == "float32" else 1e-12,
+    )
+
+
+@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
+@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
+def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
+    rng = np.random.default_rng()
+    a_size = int(rng.uniform(1, int(0.8 * size)))
+    b_size = int(rng.uniform(1, int(0.8 * (size - a_size))))
+    c_size = size - a_size - b_size
+
+    a = tensor("a", shape=(a_size, a_size))
+    b = tensor("b", shape=(b_size, b_size))
+    c = tensor("c", shape=(c_size, c_size))
+    d = tensor("d", shape=(size,))
+
+    x = pt.linalg.block_diag(a, b, c)
+    out = x @ d
+
+    mode = get_default_mode()
+    if not rewrite:
+        mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
+    fn = pytensor.function([a, b, c, d], out, mode=mode)
+
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    benchmark(
+        fn,
+        a_val,
+        b_val,
+        c_val,
+        d_val,
+    )
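To watch the rewrite fire outside the test suite, one can compile a graph by hand and print it; a sketch, assuming the default mode (whose stabilize pass includes the new rewrite):

```python
import pytensor
import pytensor.tensor as pt

a = pt.tensor("a", shape=(4, 4))
b = pt.tensor("b", shape=(6, 6))
d = pt.tensor("d", shape=(10, 10))

out = pt.linalg.block_diag(a, b) @ d
fn = pytensor.function([a, b, d], out)

# After rewriting, the printed graph should contain a Split and two small dots
# instead of a BlockDiagonal feeding one large dot.
pytensor.dprint(fn)
```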
diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py
index dee65c5d76..8274ddbcea 100644
--- a/tests/tensor/test_extra_ops.py
+++ b/tests/tensor/test_extra_ops.py
@@ -1333,3 +1333,48 @@ def test_space_ops(op, dtype, start, stop, num_samples, endpoint, axis):
         atol=1e-6 if config.floatX.endswith("64") else 1e-4,
         rtol=1e-6 if config.floatX.endswith("64") else 1e-4,
     )
+
+
+def test_concat_with_broadcast():
+    rng = np.random.default_rng()
+    a = pt.tensor("a", shape=(1, 3, 5))
+    b = pt.tensor("b", shape=(5, 3, 10))
+
+    c = pt.concat_with_broadcast([a, b], axis=2)
+    fn = function([a, b], c, mode="FAST_COMPILE")
+    assert c.type.shape == (5, 3, 15)
+
+    a_val = rng.normal(size=(1, 3, 5)).astype(config.floatX)
+    b_val = rng.normal(size=(5, 3, 10)).astype(config.floatX)
+    c_val = fn(a_val, b_val)
+
+    # The result should be a tile + concat
+    np.testing.assert_allclose(c_val[:, :, :5], np.tile(a_val, (5, 1, 1)))
+    np.testing.assert_allclose(c_val[:, :, 5:], b_val)
+
+    # If a and b already conform, the result should be the same as a concatenation
+    a = pt.tensor("a", shape=(1, 1, 3, 5, 10))
+    b = pt.tensor("b", shape=(1, 1, 3, 2, 10))
+    c = pt.concat_with_broadcast([a, b], axis=-2)
+    assert c.type.shape == (1, 1, 3, 7, 10)
+
+    fn = function([a, b], c, mode="FAST_COMPILE")
+    a_val = rng.normal(size=(1, 1, 3, 5, 10)).astype(config.floatX)
+    b_val = rng.normal(size=(1, 1, 3, 2, 10)).astype(config.floatX)
+    c_val = fn(a_val, b_val)
+    np.testing.assert_allclose(c_val, np.concatenate([a_val, b_val], axis=-2))
+
+    c = pt.concat_with_broadcast([a], axis=0)
+    fn = function([a], c, mode="FAST_COMPILE")
+    np.testing.assert_allclose(fn(a_val), a_val)
+
+    with pytest.raises(ValueError, match="Cannot concatenate an empty list of tensors"):
+        pt.concat_with_broadcast([], axis=0)
+
+    with pytest.raises(
+        TypeError,
+        match="Only tensors with the same number of dimensions can be concatenated.",
+    ):
+        a = pt.tensor("a", shape=(1, 3, 5))
+        b = pt.tensor("b", shape=(3, 5))
+        pt.concat_with_broadcast([a, b], axis=1)