From d9d7566f1a6c87efcf3db70b836c5fc84b1f58b7 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 21 Jun 2025 19:50:56 +0200 Subject: [PATCH 01/10] block_diag dot rewrite --- pytensor/tensor/rewriting/math.py | 73 +++++++++++++++++++++++++++-- tests/tensor/rewriting/test_math.py | 73 +++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index e2fe086c00..2d6082e0dc 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -29,9 +29,11 @@ constant, expand_dims, get_underlying_scalar_constant_value, + join, moveaxis, ones_like, register_infer_shape, + split, switch, zeros, zeros_like, @@ -96,6 +98,7 @@ from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift from pytensor.tensor.rewriting.linalg import is_matrix_transpose from pytensor.tensor.shape import Shape, Shape_i +from pytensor.tensor.slinalg import BlockDiagonal from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.type import ( complex_dtypes, @@ -146,6 +149,72 @@ def local_0_dot_x(fgraph, node): return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)] +@register_canonicalize +@register_specialize +@register_stabilize +@node_rewriter([Dot]) +def local_block_diag_dot_to_dot_block_diag(fgraph, node): + r""" + Perform the rewrite ``dot(block_diag(A, B), C) -> block_diag(dot(A, C), dot(B, C))`` + + BlockDiag results in the creation of a matrix of shape ``(n1 * n2, m1 * m2)``. Because dot has complexity + of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than + a single dot on the larger matrix. + """ + x, y = node.inputs + op = node.op + + def check_for_block_diag(x): + return x.owner and ( + isinstance(x.owner.op, BlockDiagonal) + or isinstance(x.owner.op, Blockwise) + and isinstance(x.owner.op.core_op, BlockDiagonal) + ) + + if not (check_for_block_diag(x) or check_for_block_diag(y)): + return None + + # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the + # non-block diagonal, and return a new block diagonal + if check_for_block_diag(x) and not check_for_block_diag(y): + components = x.owner.inputs + y_splits = split( + y, + splits_size=[component.shape[-1] for component in components], + n_splits=len(components), + ) + new_components = [ + op(component, y_split) for component, y_split in zip(components, y_splits) + ] + new_output = join(0, *new_components) + elif not check_for_block_diag(x) and check_for_block_diag(y): + components = y.owner.inputs + new_components = [op(x, component) for component in components] + new_output = join(0, *new_components) + + # Case 2: Both inputs are BlockDiagonal. Here we can proceed only if the static shapes are known and identical. 
In + # that case, blockdiag(a,b) @ blockdiag(c, d) = blockdiag(a @ c, b @ d), but this is not true in the general case + elif any(shape is None for shape in (*x.type.shape, *y.type.shape)): + return None + elif x.ndim == y.ndim and all( + x_shape == y_shape for x_shape, y_shape in zip(x.type.shape, y.type.shape) + ): + x_components = x.owner.inputs + y_components = y.owner.inputs + + if len(x_components) != len(y_components): + return None + + new_output = BlockDiagonal(len(x_components))( + *[op(x_comp, y_comp) for x_comp, y_comp in zip(x_components, y_components)] + ) + else: + return None + + copy_stack_trace(node.outputs[0], new_output) + return [new_output] + + @register_canonicalize @node_rewriter([Dot, _matmul]) def local_lift_transpose_through_dot(fgraph, node): @@ -2582,7 +2651,6 @@ def add_calculate(num, denum, aslist=False, out_type=None): name="add_canonizer_group", ) - register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer") @@ -3720,7 +3788,6 @@ def logmexpm1_to_log1mexp(fgraph, node): ) register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff") - # log(sigmoid(x) / (1 - sigmoid(x))) -> x # i.e logit(sigmoid(x)) -> x local_logit_sigmoid = PatternNodeRewriter( @@ -3734,7 +3801,6 @@ def logmexpm1_to_log1mexp(fgraph, node): register_canonicalize(local_logit_sigmoid) register_specialize(local_logit_sigmoid) - # sigmoid(log(x / (1-x)) -> x # i.e., sigmoid(logit(x)) -> x local_sigmoid_logit = PatternNodeRewriter( @@ -3775,7 +3841,6 @@ def local_useless_conj(fgraph, node): register_specialize(local_polygamma_to_tri_gamma) - local_log_kv = PatternNodeRewriter( # Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x # During stabilize -x is converted to -1.0 * x diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 9d488c4ef0..b2ffb67bf0 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -115,6 +115,7 @@ simplify_mul, ) from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape +from pytensor.tensor.slinalg import BlockDiagonal from pytensor.tensor.type import ( TensorType, cmatrix, @@ -4745,3 +4746,75 @@ def test_local_dot_to_mul(batched, a_shape, b_shape): out.eval({a: a_test, b: b_test}, mode=test_mode), rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode), ) + + +def test_local_block_diag_dot_to_dot_block_diag(): + """ + Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:])) + """ + a = tensor("a", shape=(4, 2)) + b = tensor("b", shape=(2, 4)) + c = tensor("c", shape=(4, 4)) + d = tensor("d", shape=(10,)) + + x = pt.linalg.block_diag(a, b, c) + out = x @ d + + fn = pytensor.function([a, b, c, d], out) + assert not any( + isinstance(node, BlockDiagonal) for node in fn.maker.fgraph.toposort() + ) + + fn_expected = pytensor.function( + [a, b, c, d], + out, + mode=get_default_mode().excluding("local_block_diag_dot_to_dot_block_diag"), + ) + + rng = np.random.default_rng() + a_val = rng.normal(size=a.type.shape).astype(a.type.dtype) + b_val = rng.normal(size=b.type.shape).astype(b.type.dtype) + c_val = rng.normal(size=c.type.shape).astype(c.type.dtype) + d_val = rng.normal(size=d.type.shape).astype(d.type.dtype) + + np.testing.assert_allclose( + fn(a_val, b_val, c_val, d_val), + fn_expected(a_val, b_val, c_val, d_val), + atol=1e-6 if config.floatX == "float32" else 1e-12, + rtol=1e-6 if config.floatX == "float32" else 1e-12, + ) + + +@pytest.mark.parametrize("rewrite", [True, 
False], ids=["rewrite", "no_rewrite"]) +@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"]) +def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite): + rng = np.random.default_rng() + a_size = int(rng.uniform(0, size)) + b_size = int(rng.uniform(0, size - a_size)) + c_size = size - a_size - b_size + + a = tensor("a", shape=(a_size, a_size)) + b = tensor("b", shape=(b_size, b_size)) + c = tensor("c", shape=(c_size, c_size)) + d = tensor("d", shape=(size,)) + + x = pt.linalg.block_diag(a, b, c) + out = x @ d + + mode = get_default_mode() + if not rewrite: + mode = mode.excluding("local_block_diag_dot_to_dot_block_diag") + fn = pytensor.function([a, b, c, d], out, mode=mode) + + a_val = rng.normal(size=a.type.shape).astype(a.type.dtype) + b_val = rng.normal(size=b.type.shape).astype(b.type.dtype) + c_val = rng.normal(size=c.type.shape).astype(c.type.dtype) + d_val = rng.normal(size=d.type.shape).astype(d.type.dtype) + + benchmark( + fn, + a_val, + b_val, + c_val, + d_val, + ) From 5208662be2b4eb9afab5bbaa9427a60971e20de3 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 21 Jun 2025 22:02:28 +0200 Subject: [PATCH 02/10] Handle right-multiplication case --- pytensor/tensor/rewriting/math.py | 29 +++++++++++------------------ tests/tensor/rewriting/test_math.py | 11 ++++++++--- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 2d6082e0dc..a276d59614 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -155,7 +155,7 @@ def local_0_dot_x(fgraph, node): @node_rewriter([Dot]) def local_block_diag_dot_to_dot_block_diag(fgraph, node): r""" - Perform the rewrite ``dot(block_diag(A, B), C) -> block_diag(dot(A, C), dot(B, C))`` + Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))`` BlockDiag results in the creation of a matrix of shape ``(n1 * n2, m1 * m2)``. Because dot has complexity of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than @@ -189,25 +189,18 @@ def check_for_block_diag(x): new_output = join(0, *new_components) elif not check_for_block_diag(x) and check_for_block_diag(y): components = y.owner.inputs - new_components = [op(x, component) for component in components] - new_output = join(0, *new_components) - - # Case 2: Both inputs are BlockDiagonal. Here we can proceed only if the static shapes are known and identical. 
In - # that case, blockdiag(a,b) @ blockdiag(c, d) = blockdiag(a @ c, b @ d), but this is not true in the general case - elif any(shape is None for shape in (*x.type.shape, *y.type.shape)): - return None - elif x.ndim == y.ndim and all( - x_shape == y_shape for x_shape, y_shape in zip(x.type.shape, y.type.shape) - ): - x_components = x.owner.inputs - y_components = y.owner.inputs + x_splits = split( + x, + splits_size=[component.shape[0] for component in components], + n_splits=len(components), + axis=1, + ) - if len(x_components) != len(y_components): - return None + new_components = [ + op(x_split, component) for component, x_split in zip(components, x_splits) + ] + new_output = join(1, *new_components) - new_output = BlockDiagonal(len(x_components))( - *[op(x_comp, y_comp) for x_comp, y_comp in zip(x_components, y_components)] - ) else: return None diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index b2ffb67bf0..440db27483 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4748,17 +4748,22 @@ def test_local_dot_to_mul(batched, a_shape, b_shape): ) -def test_local_block_diag_dot_to_dot_block_diag(): +@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"]) +def test_local_block_diag_dot_to_dot_block_diag(left_multiply): """ Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:])) """ a = tensor("a", shape=(4, 2)) b = tensor("b", shape=(2, 4)) c = tensor("c", shape=(4, 4)) - d = tensor("d", shape=(10,)) + d = tensor("d", shape=(10, 10)) x = pt.linalg.block_diag(a, b, c) - out = x @ d + + if left_multiply: + out = x @ d + else: + out = d @ x fn = pytensor.function([a, b, c, d], out) assert not any( From b619e87f861acb280f24fd0d278853f7b56cceac Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 21 Jun 2025 22:17:33 +0200 Subject: [PATCH 03/10] The robot was right! 
--- tests/tensor/rewriting/test_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 440db27483..506bc520a9 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4767,7 +4767,7 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): fn = pytensor.function([a, b, c, d], out) assert not any( - isinstance(node, BlockDiagonal) for node in fn.maker.fgraph.toposort() + isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort() ) fn_expected = pytensor.function( From 14ee8e2ea0a8729704e2c6fe431eae84d3b88365 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 22 Jun 2025 12:33:12 +0200 Subject: [PATCH 04/10] Respond to feedback --- pytensor/tensor/rewriting/math.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index a276d59614..8674076f1c 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -149,10 +149,8 @@ def local_0_dot_x(fgraph, node): return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)] -@register_canonicalize -@register_specialize @register_stabilize -@node_rewriter([Dot]) +@node_rewriter([Blockwise]) def local_block_diag_dot_to_dot_block_diag(fgraph, node): r""" Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))`` @@ -161,8 +159,8 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node): of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than a single dot on the larger matrix. """ - x, y = node.inputs - op = node.op + if not isinstance(node.op.core_op, BlockDiagonal): + return def check_for_block_diag(x): return x.owner and ( @@ -171,6 +169,15 @@ def check_for_block_diag(x): and isinstance(x.owner.op.core_op, BlockDiagonal) ) + # Check that the BlockDiagonal is an input to a Dot node: + clients = list(get_clients_at_depth(fgraph, node, depth=1)) + if not clients or len(clients) > 1 or not isinstance(clients[0].op, Dot): + return + + [dot_node] = clients + op = dot_node.op + x, y = dot_node.inputs + if not (check_for_block_diag(x) or check_for_block_diag(y)): return None @@ -187,6 +194,7 @@ def check_for_block_diag(x): op(component, y_split) for component, y_split in zip(components, y_splits) ] new_output = join(0, *new_components) + elif not check_for_block_diag(x) and check_for_block_diag(y): components = y.owner.inputs x_splits = split( @@ -201,11 +209,14 @@ def check_for_block_diag(x): ] new_output = join(1, *new_components) + # Case 2: Both inputs are BlockDiagonal. 
Do nothing else: + # TODO: If shapes are statically known and all components have equal shapes, we could rewrite + # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)]) return None copy_stack_trace(node.outputs[0], new_output) - return [new_output] + return {dot_node.outputs[0]: new_output} @register_canonicalize From 1a89309b23854667a9205ea7af3aee6b79115dd8 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 22 Jun 2025 13:00:24 +0200 Subject: [PATCH 05/10] Use `rewrite_mode` defined in `test_math.py` for testing --- tests/tensor/rewriting/test_math.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 506bc520a9..533fc65acc 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4765,7 +4765,7 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): else: out = d @ x - fn = pytensor.function([a, b, c, d], out) + fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode) assert not any( isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort() ) @@ -4773,7 +4773,7 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): fn_expected = pytensor.function( [a, b, c, d], out, - mode=get_default_mode().excluding("local_block_diag_dot_to_dot_block_diag"), + mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"), ) rng = np.random.default_rng() From 3ab2bb09e503d9561623abcb1c84b09c9b4bf23f Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 26 Jun 2025 14:25:24 +0800 Subject: [PATCH 06/10] Handle case with multiple clients --- pytensor/tensor/rewriting/math.py | 83 +++++++++++++++-------------- tests/tensor/rewriting/test_math.py | 15 +++--- 2 files changed, 51 insertions(+), 47 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 8674076f1c..ecf2c39cbf 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -170,53 +170,54 @@ def check_for_block_diag(x): ) # Check that the BlockDiagonal is an input to a Dot node: - clients = list(get_clients_at_depth(fgraph, node, depth=1)) - if not clients or len(clients) > 1 or not isinstance(clients[0].op, Dot): - return + for client in get_clients_at_depth(fgraph, node, depth=1): + if not isinstance(client.op, Dot): + return - [dot_node] = clients - op = dot_node.op - x, y = dot_node.inputs + op = client.op + x, y = client.inputs - if not (check_for_block_diag(x) or check_for_block_diag(y)): - return None + if not (check_for_block_diag(x) or check_for_block_diag(y)): + return None - # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the - # non-block diagonal, and return a new block diagonal - if check_for_block_diag(x) and not check_for_block_diag(y): - components = x.owner.inputs - y_splits = split( - y, - splits_size=[component.shape[-1] for component in components], - n_splits=len(components), - ) - new_components = [ - op(component, y_split) for component, y_split in zip(components, y_splits) - ] - new_output = join(0, *new_components) - - elif not check_for_block_diag(x) and check_for_block_diag(y): - components = y.owner.inputs - x_splits = split( - x, - splits_size=[component.shape[0] for component in components], - n_splits=len(components), - axis=1, - ) + # Case 1: Only one input is BlockDiagonal. 
In this case, multiply all components of the block-diagonal with the + # non-block diagonal, and return a new block diagonal + if check_for_block_diag(x) and not check_for_block_diag(y): + components = x.owner.inputs + y_splits = split( + y, + splits_size=[component.shape[-1] for component in components], + n_splits=len(components), + ) + new_components = [ + op(component, y_split) + for component, y_split in zip(components, y_splits) + ] + new_output = join(0, *new_components) + + elif not check_for_block_diag(x) and check_for_block_diag(y): + components = y.owner.inputs + x_splits = split( + x, + splits_size=[component.shape[0] for component in components], + n_splits=len(components), + axis=1, + ) - new_components = [ - op(x_split, component) for component, x_split in zip(components, x_splits) - ] - new_output = join(1, *new_components) + new_components = [ + op(x_split, component) + for component, x_split in zip(components, x_splits) + ] + new_output = join(1, *new_components) - # Case 2: Both inputs are BlockDiagonal. Do nothing - else: - # TODO: If shapes are statically known and all components have equal shapes, we could rewrite - # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)]) - return None + # Case 2: Both inputs are BlockDiagonal. Do nothing + else: + # TODO: If shapes are statically known and all components have equal shapes, we could rewrite + # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)]) + return None - copy_stack_trace(node.outputs[0], new_output) - return {dot_node.outputs[0]: new_output} + copy_stack_trace(node.outputs[0], new_output) + return {client.outputs[0]: new_output} @register_canonicalize diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 533fc65acc..d4b4112060 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4757,21 +4757,23 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): b = tensor("b", shape=(2, 4)) c = tensor("c", shape=(4, 4)) d = tensor("d", shape=(10, 10)) + e = tensor("e", shape=(10, 10)) x = pt.linalg.block_diag(a, b, c) + # Test multiple clients are all rewritten if left_multiply: - out = x @ d + out = [x @ d, x @ e] else: - out = d @ x + out = [d @ x, e @ x] - fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode) + fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode) assert not any( isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort() ) fn_expected = pytensor.function( - [a, b, c, d], + [a, b, c, d, e], out, mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"), ) @@ -4781,10 +4783,11 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): b_val = rng.normal(size=b.type.shape).astype(b.type.dtype) c_val = rng.normal(size=c.type.shape).astype(c.type.dtype) d_val = rng.normal(size=d.type.shape).astype(d.type.dtype) + e_val = rng.normal(size=e.type.shape).astype(e.type.dtype) np.testing.assert_allclose( - fn(a_val, b_val, c_val, d_val), - fn_expected(a_val, b_val, c_val, d_val), + fn(a_val, b_val, c_val, d_val, e_val), + fn_expected(a_val, b_val, c_val, d_val, e_val), atol=1e-6 if config.floatX == "float32" else 1e-12, rtol=1e-6 if config.floatX == "float32" else 1e-12, ) From 558acbcf1245be648f7371b15d3d438997c628b5 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 26 Jun 2025 14:41:52 +0800 Subject: [PATCH 07/10] use `continue` on rewrite failures when 
checking clients --- pytensor/tensor/rewriting/math.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index ecf2c39cbf..00f4ecdad7 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -172,13 +172,13 @@ def check_for_block_diag(x): # Check that the BlockDiagonal is an input to a Dot node: for client in get_clients_at_depth(fgraph, node, depth=1): if not isinstance(client.op, Dot): - return + continue op = client.op x, y = client.inputs if not (check_for_block_diag(x) or check_for_block_diag(y)): - return None + continue # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the # non-block diagonal, and return a new block diagonal @@ -214,7 +214,7 @@ def check_for_block_diag(x): else: # TODO: If shapes are statically known and all components have equal shapes, we could rewrite # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)]) - return None + continue copy_stack_trace(node.outputs[0], new_output) return {client.outputs[0]: new_output} From 6736e8ec5118f22ab3b87c198b45d4932b4df24f Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Tue, 8 Jul 2025 18:35:28 +0800 Subject: [PATCH 08/10] pair coding results --- pytensor/tensor/rewriting/math.py | 69 +++++++++++------------------ tests/tensor/rewriting/test_math.py | 36 ++++++++++----- 2 files changed, 50 insertions(+), 55 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 00f4ecdad7..43461b3584 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -162,59 +162,40 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node): if not isinstance(node.op.core_op, BlockDiagonal): return - def check_for_block_diag(x): - return x.owner and ( - isinstance(x.owner.op, BlockDiagonal) - or isinstance(x.owner.op, Blockwise) - and isinstance(x.owner.op.core_op, BlockDiagonal) - ) - # Check that the BlockDiagonal is an input to a Dot node: for client in get_clients_at_depth(fgraph, node, depth=1): - if not isinstance(client.op, Dot): + if not ( + ( + isinstance(client.op, Dot) + and all(input.ndim == 2 for input in client.inputs) + ) + or client.op == _matrix_matrix_matmul + ): continue op = client.op - x, y = client.inputs - if not (check_for_block_diag(x) or check_for_block_diag(y)): - continue + client_idx = client.inputs.index(node.outputs[0]) - # Case 1: Only one input is BlockDiagonal. 
In this case, multiply all components of the block-diagonal with the - # non-block diagonal, and return a new block diagonal - if check_for_block_diag(x) and not check_for_block_diag(y): - components = x.owner.inputs - y_splits = split( - y, - splits_size=[component.shape[-1] for component in components], - n_splits=len(components), - ) - new_components = [ - op(component, y_split) - for component, y_split in zip(components, y_splits) - ] - new_output = join(0, *new_components) - - elif not check_for_block_diag(x) and check_for_block_diag(y): - components = y.owner.inputs - x_splits = split( - x, - splits_size=[component.shape[0] for component in components], - n_splits=len(components), - axis=1, - ) + other_input = client.inputs[1 - client_idx] + components = node.inputs - new_components = [ - op(x_split, component) - for component, x_split in zip(components, x_splits) - ] - new_output = join(1, *new_components) + split_axis = -2 if client_idx == 0 else -1 + shape_idx = -1 if client_idx == 0 else -2 - # Case 2: Both inputs are BlockDiagonal. Do nothing - else: - # TODO: If shapes are statically known and all components have equal shapes, we could rewrite - # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)]) - continue + other_dot_input_split = split( + other_input, + splits_size=[component.shape[shape_idx] for component in components], + n_splits=len(components), + axis=split_axis, + ) + new_components = [ + op(component, other_split) + if client_idx == 0 + else op(other_split, component) + for component, other_split in zip(components, other_dot_input_split) + ] + new_output = join(split_axis, *new_components) copy_stack_trace(node.outputs[0], new_output) return {client.outputs[0]: new_output} diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index d4b4112060..fb9a87833b 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4749,15 +4749,21 @@ def test_local_dot_to_mul(batched, a_shape, b_shape): @pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"]) -def test_local_block_diag_dot_to_dot_block_diag(left_multiply): +@pytest.mark.parametrize( + "batch_left", [True, False], ids=["batched_left", "unbatched_left"] +) +@pytest.mark.parametrize( + "batch_right", [True, False], ids=["batched_right", "unbatched_right"] +) +def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch_right): """ Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:])) """ a = tensor("a", shape=(4, 2)) - b = tensor("b", shape=(2, 4)) + b = tensor("b", shape=(2, 4) if not batch_left else (3, 2, 4)) c = tensor("c", shape=(4, 4)) d = tensor("d", shape=(10, 10)) - e = tensor("e", shape=(10, 10)) + e = tensor("e", shape=(10, 10) if not batch_right else (3, 1, 10, 10)) x = pt.linalg.block_diag(a, b, c) @@ -4767,7 +4773,9 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): else: out = [d @ x, e @ x] - fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode) + with config.change_flags(optimizer_verbose=True): + fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode) + assert not any( isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort() ) @@ -4775,9 +4783,11 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): fn_expected = pytensor.function( [a, b, c, d, e], out, - mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"), + 
mode=Mode(linker="py", optimizer=None), ) + # TODO: Count Dots + rng = np.random.default_rng() a_val = rng.normal(size=a.type.shape).astype(a.type.dtype) b_val = rng.normal(size=b.type.shape).astype(b.type.dtype) @@ -4785,12 +4795,16 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): d_val = rng.normal(size=d.type.shape).astype(d.type.dtype) e_val = rng.normal(size=e.type.shape).astype(e.type.dtype) - np.testing.assert_allclose( - fn(a_val, b_val, c_val, d_val, e_val), - fn_expected(a_val, b_val, c_val, d_val, e_val), - atol=1e-6 if config.floatX == "float32" else 1e-12, - rtol=1e-6 if config.floatX == "float32" else 1e-12, - ) + rewrite_outs = fn(a_val, b_val, c_val, d_val, e_val) + expected_outs = fn_expected(a_val, b_val, c_val, d_val, e_val) + + for out, expected in zip(rewrite_outs, expected_outs): + np.testing.assert_allclose( + out, + expected, + atol=1e-6 if config.floatX == "float32" else 1e-12, + rtol=1e-6 if config.floatX == "float32" else 1e-12, + ) @pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"]) From 993fb641287939407d941b3f6f1c36e9cc2b237f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 8 Jul 2025 12:57:23 +0200 Subject: [PATCH 09/10] Cleanup test --- tests/tensor/rewriting/test_math.py | 66 +++++++++++++++++------------ 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index fb9a87833b..d710a5d485 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4750,41 +4750,56 @@ def test_local_dot_to_mul(batched, a_shape, b_shape): @pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"]) @pytest.mark.parametrize( - "batch_left", [True, False], ids=["batched_left", "unbatched_left"] + "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"] ) @pytest.mark.parametrize( - "batch_right", [True, False], ids=["batched_right", "unbatched_right"] + "batch_other", [True, False], ids=["batched_other", "unbatched_other"] ) -def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch_right): +def test_local_block_diag_dot_to_dot_block_diag( + left_multiply, batch_blockdiag, batch_other +): """ Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:])) """ + + def has_blockdiag(graph): + return any( + ( + var.owner + and ( + isinstance(var.owner.op, BlockDiagonal) + or ( + isinstance(var.owner.op, Blockwise) + and isinstance(var.owner.op.core_op, BlockDiagonal) + ) + ) + ) + for var in ancestors([graph]) + ) + a = tensor("a", shape=(4, 2)) - b = tensor("b", shape=(2, 4) if not batch_left else (3, 2, 4)) + b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4)) c = tensor("c", shape=(4, 4)) - d = tensor("d", shape=(10, 10)) - e = tensor("e", shape=(10, 10) if not batch_right else (3, 1, 10, 10)) - x = pt.linalg.block_diag(a, b, c) + d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10)) + # Test multiple clients are all rewritten if left_multiply: - out = [x @ d, x @ e] + out = x @ d else: - out = [d @ x, e @ x] + out = d @ x - with config.change_flags(optimizer_verbose=True): - fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode) - - assert not any( - isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort() - ) + assert has_blockdiag(out) + fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode) + assert not has_blockdiag(fn.maker.fgraph.outputs[0]) 
fn_expected = pytensor.function( - [a, b, c, d, e], + [a, b, c, d], out, mode=Mode(linker="py", optimizer=None), ) + assert has_blockdiag(fn_expected.maker.fgraph.outputs[0]) # TODO: Count Dots @@ -4793,18 +4808,15 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch b_val = rng.normal(size=b.type.shape).astype(b.type.dtype) c_val = rng.normal(size=c.type.shape).astype(c.type.dtype) d_val = rng.normal(size=d.type.shape).astype(d.type.dtype) - e_val = rng.normal(size=e.type.shape).astype(e.type.dtype) - rewrite_outs = fn(a_val, b_val, c_val, d_val, e_val) - expected_outs = fn_expected(a_val, b_val, c_val, d_val, e_val) - - for out, expected in zip(rewrite_outs, expected_outs): - np.testing.assert_allclose( - out, - expected, - atol=1e-6 if config.floatX == "float32" else 1e-12, - rtol=1e-6 if config.floatX == "float32" else 1e-12, - ) + rewrite_out = fn(a_val, b_val, c_val, d_val) + expected_out = fn_expected(a_val, b_val, c_val, d_val) + np.testing.assert_allclose( + rewrite_out, + expected_out, + atol=1e-6 if config.floatX == "float32" else 1e-12, + rtol=1e-6 if config.floatX == "float32" else 1e-12, + ) @pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"]) From 7c3820b93db957f72e21fd803cea77e511e8d962 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Fri, 25 Jul 2025 21:08:39 +0800 Subject: [PATCH 10/10] look for `_matmul` in `local_block_diag_dot_to_dot_block_diag` --- pytensor/tensor/rewriting/math.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 43461b3584..b5aec70884 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -38,6 +38,7 @@ zeros, zeros_like, ) +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import broadcast_arrays @@ -169,7 +170,7 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node): isinstance(client.op, Dot) and all(input.ndim == 2 for input in client.inputs) ) - or client.op == _matrix_matrix_matmul + or client.op == _matmul ): continue
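
For reference, the identity the rewrite exploits -- dot(block_diag(A, B), C) == concat(dot(A, C_top), dot(B, C_bottom)) -- can be checked independently of PyTensor. The sketch below is illustrative only and is not part of the patch series; the names A, B, C and their shapes are made up, and it uses plain NumPy/SciPy rather than the rewritten graph:

    import numpy as np
    from scipy.linalg import block_diag

    rng = np.random.default_rng(0)

    A = rng.normal(size=(4, 2))   # first block, shape (n1, m1)
    B = rng.normal(size=(2, 4))   # second block, shape (n2, m2)
    C = rng.normal(size=(6, 3))   # right operand, rows = m1 + m2

    # Dense product through the (n1 + n2, m1 + m2) block-diagonal matrix
    dense = block_diag(A, B) @ C

    # Form produced by the rewrite: split C row-wise to match the blocks'
    # column counts, multiply each block by its slice, and stack the results
    C_top, C_bottom = np.split(C, [A.shape[1]], axis=0)
    rewritten = np.concatenate([A @ C_top, B @ C_bottom], axis=0)

    np.testing.assert_allclose(dense, rewritten)

The transposed form of the same identity covers the right-multiplication case (d @ block_diag(...)) handled from PATCH 02 onward; the tests in test_math.py above exercise the actual PyTensor rewrite, including the batched and multiple-client cases.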