1414 register_stabilize ,
1515)
1616from pytensor .tensor .shape import Reshape
17- from pytensor .tensor .subtensor import AdvancedIncSubtensor , AdvancedSubtensor , Subtensor
17+ from pytensor .tensor .subtensor import (
18+ AdvancedIncSubtensor ,
19+ AdvancedSubtensor ,
20+ Subtensor ,
21+ indices_from_subtensor ,
22+ )
1823
1924
2025@node_rewriter ([Blockwise ])
@@ -216,9 +221,9 @@ def local_blockwise_reshape(fgraph, node):
216221
217222 Reshape is tricky to vectorize eagerly, because a graph like
218223 `x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
219- that must be vectorized before we arrize at the reshape operation.
224+ that must be vectorized before we arrive at the reshape operation.
220225
221- For the square Reshape case, we must wait for all the intemediate
226+ For the square Reshape case, we must wait for all the intermediate
222227 operations to be lifted as Allocs
223228 """
224229 if not isinstance (node .op .core_op , Reshape ):
@@ -234,6 +239,29 @@ def local_blockwise_reshape(fgraph, node):
234239 return [new_out ]
235240
236241
242+ @register_stabilize
243+ @register_specialize
244+ @node_rewriter ([Blockwise ])
245+ def local_blockwise_of_subtensor (fgraph , node ):
246+ """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
247+
248+ Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
249+ """
250+ if not isinstance (node .op .core_op , Subtensor ):
251+ return
252+
253+ x , * idxs = node .inputs
254+ if not all (all (idx .type .broadcastable ) for idx in idxs ):
255+ return
256+
257+ core_idxs = indices_from_subtensor (
258+ [idx .squeeze () for idx in idxs ], node .op .core_op .idx_list
259+ )
260+ # Add empty slices for the batch dims
261+ none_slices = (slice (None ),) * node .op .batch_ndim (node )
262+ return [x [(* none_slices , * core_idxs )]]
263+
264+
237265@node_rewriter (tracks = [Blockwise ], inplace = True )
238266def blockwise_inplace (fgraph , node ):
239267 blockwise_op = node .op
0 commit comments