diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index f189766c9c..931c7009b3 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1801,8 +1801,7 @@ def do_constant_folding(self, fgraph, node): | pytensor.tensor.blas.Gemv | pytensor.tensor.blas_c.CGemv | pytensor.tensor.blas.Ger - | pytensor.tensor.blas_c.CGer - | pytensor.tensor.blas_scipy.ScipyGer, + | pytensor.tensor.blas_c.CGer, ) ): # Ops that will work inplace on the Alloc. So if they diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index fc8afcea50..cad07f7bea 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -83,6 +83,7 @@ from pathlib import Path import numpy as np +from scipy.linalg import get_blas_funcs from pytensor.graph import vectorize_graph from pytensor.npy_2_compat import normalize_axis_tuple @@ -288,18 +289,15 @@ def make_node(self, A, alpha, x, y): return Apply(self, inputs, [A.type()]) - def perform(self, node, inp, out): - cA, calpha, cx, cy = inp - (cZ,) = out - if self.destructive: - A = cA - else: - A = cA.copy() - if calpha != 1: - A += calpha * np.outer(cx, cy) + def perform(self, node, inputs, output_storage): + A, alpha, x, y = inputs + ger_func = get_blas_funcs("ger", dtype=A.dtype) + if A.flags["C_CONTIGUOUS"]: + # Work on transposed system to avoid copying + A = ger_func(alpha, y, x, a=A.T, overwrite_a=self.destructive).T else: - A += np.outer(cx, cy) - cZ[0] = A + A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive) + output_storage[0][0] = A def infer_shape(self, fgraph, node, input_shapes): return [input_shapes[0]] @@ -1128,16 +1126,8 @@ def make_node(self, x, y): outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))] return Apply(self, [x, y], outputs) - def perform(self, node, inp, out): - x, y = inp - (z,) = out - try: - z[0] = np.asarray(np.dot(x, y)) - except ValueError as e: - # The error raised by numpy has no shape information, we mean to - # add that - e.args = (*e.args, x.shape, y.shape) - raise + def perform(self, node, inputs, output_storage): + output_storage[0][0] = np.dot(*inputs) def infer_shape(self, fgraph, node, input_shapes): return [[input_shapes[0][0], input_shapes[1][1]]] diff --git a/pytensor/tensor/blas_scipy.py b/pytensor/tensor/blas_scipy.py index bb3ccf9354..b4494a98a0 100644 --- a/pytensor/tensor/blas_scipy.py +++ b/pytensor/tensor/blas_scipy.py @@ -2,31 +2,22 @@ Implementations of BLAS Ops based on scipy's BLAS bindings. """ +from scipy.linalg.blas import get_blas_funcs + from pytensor.tensor.blas import Ger class ScipyGer(Ger): def perform(self, node, inputs, output_storage): - from scipy.linalg.blas import get_blas_funcs - cA, calpha, cx, cy = inputs (cZ,) = output_storage - # N.B. some versions of scipy (e.g. mine) don't actually work - # in-place on a, even when I tell it to. A = cA - local_ger = get_blas_funcs("ger", dtype=cA.dtype) - if A.size == 0: - # We don't have to compute anything, A is empty. - # We need this special case because Numpy considers it - # C-contiguous, which is confusing. - if not self.destructive: - # Sometimes numpy thinks empty matrices can share memory, - # so here to stop DebugMode from complaining. 
- A = A.copy() - elif A.flags["C_CONTIGUOUS"]: - A = local_ger(calpha, cy, cx, a=A.T, overwrite_a=int(self.destructive)).T + ger_func = get_blas_funcs("ger", dtype=cA.dtype) + if A.flags["C_CONTIGUOUS"]: + # Work on transposed system to avoid copying + A = ger_func(calpha, cy, cx, a=A.T, overwrite_a=self.destructive).T else: - A = local_ger(calpha, cx, cy, a=A, overwrite_a=int(self.destructive)) + A = ger_func(calpha, cx, cy, a=A, overwrite_a=self.destructive) cZ[0] = A diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 714f597b32..bee8c3df9b 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -40,12 +40,13 @@ get_normalized_batch_axes, scalar_elemwise, ) -from pytensor.tensor.shape import shape, specify_broadcastable +from pytensor.tensor.shape import shape, specify_shape from pytensor.tensor.type import ( DenseTensorType, complex_dtypes, continuous_dtypes, discrete_dtypes, + float_dtypes, int_dtypes, tensor, uint_dtypes, @@ -2986,9 +2987,7 @@ def clip(x, min, max): class Dot(Op): """ - Computes the dot product of two variables. For two matrices, this is - equivalent to matrix multiplication. For two vectors, this is the inner - product. + Computes the dot product of two matrices variables Notes ----- @@ -3001,92 +3000,57 @@ class Dot(Op): """ + gufunc_signature = "(m,n),(n,p)->(m,p)" + gufunc_spec = ("np.matmul", 2, 1) __props__ = () - # the rationale for Dot22 is related to getting GEMM Ops into the - # graph. See Dot22 in tensor.blas for details. - - def make_node(self, *inputs): - inputs = list(map(as_tensor_variable, inputs)) + def make_node(self, x, y): + x = as_tensor_variable(x) + y = as_tensor_variable(y) - if len(inputs) != 2: - raise TypeError(f"Two arguments required, {len(inputs)} given ") - if inputs[0].ndim not in (1, 2): + if x.type.ndim != 2: raise TypeError( - "Input 0 (0-indexed) must have ndim of " - f"1 or 2, {int(inputs[0].ndim)} given. Consider calling " - "pytensor.tensor.dot instead." + f"Dot Op expects a 2D tensor as input 0, got {x} with {x.type.ndim} dimensions" ) - if inputs[1].ndim not in (1, 2): + if y.type.ndim != 2: raise TypeError( - "Input 1 (0-indexed) must have ndim of " - f"1 or 2, {int(inputs[1].ndim)} given. Consider calling " - "pytensor.tensor.dot instead." 
+ f"Dot Op expects a 2D tensor as input 1, got {y} with {y.type.ndim} dimensions" ) - sx, sy = (input.type.shape for input in inputs) - if len(sy) == 2: - sz = sx[:-1] + sy[-1:] - elif len(sy) == 1: - sz = sx[:-1] - - i_dtypes = [input.type.dtype for input in inputs] - outputs = [tensor(dtype=ps.upcast(*i_dtypes), shape=sz)] - return Apply(self, inputs, outputs) - - def perform(self, node, inp, out): - x, y = inp - (z,) = out + sx, sy = x.type.shape, y.type.shape + if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]: + raise ValueError( + f"Incompatible shared dimension for dot product: {sx}, {sy}" + ) + sz = sx[:-1] + sy[-1:] + outputs = [tensor(dtype=ps.upcast(x.type.dtype, y.type.dtype), shape=sz)] + return Apply(self, [x, y], outputs) - # the asarray is here because dot between two vectors - # gives a numpy float object but we need to return a 0d - # ndarray - z[0] = np.asarray(np.dot(x, y)) + def perform(self, node, inputs, output_storage): + output_storage[0][0] = np.matmul(*inputs) def grad(self, inp, grads): x, y = inp (gz,) = grads - xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim - - # grad is scalar, so x is vector and y is vector - if gdim == 0: - xgrad = gz * y - ygrad = gz * x - - # x is vector, y is matrix, grad is vector - elif xdim == 1 and ydim == 2: - xgrad = dot(gz, y.T) - ygrad = outer(x.T, gz) - - # x is matrix, y is vector, grad is vector - elif xdim == 2 and ydim == 1: - xgrad = outer(gz, y.T) - ygrad = dot(x.T, gz) - # x is matrix, y is matrix, grad is matrix - elif xdim == ydim == 2: - xgrad = dot(gz, y.T) - ygrad = dot(x.T, gz) + xgrad = self(gz, y.T) + ygrad = self(x.T, gz) # If x or y contain broadcastable dimensions but only one of # them know that a matching dimensions is broadcastable, the # above code don't always return the right broadcast pattern. # This cause problem down the road. See gh-1461. 
- if xgrad.broadcastable != x.broadcastable: - xgrad = specify_broadcastable( - xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b) - ) - if ygrad.broadcastable != y.broadcastable: - ygrad = specify_broadcastable( - ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b) - ) - - rval = xgrad, ygrad + if xgrad.type.shape != x.type.shape: + xgrad = specify_shape(xgrad, x.type.shape) + if ygrad.type.shape != y.type.shape: + ygrad = specify_shape(ygrad, y.type.shape) - for elem in rval: - assert elem.dtype.find("float") != -1 + if xgrad.type.dtype not in float_dtypes: + raise TypeError("Dot grad x output must be a float type") + if ygrad.type.dtype not in float_dtypes: + raise TypeError("Dot grad y output must be a float type") - return rval + return xgrad, ygrad def R_op(self, inputs, eval_points): # R_op for a \dot b evaluated at c for a and d for b is @@ -3111,24 +3075,7 @@ def R_op(self, inputs, eval_points): def infer_shape(self, fgraph, node, shapes): xshp, yshp = shapes - x, y = node.inputs - - # vector / vector - if x.ndim == 1 and y.ndim == 1: - return [()] - # matrix / vector - if x.ndim == 2 and y.ndim == 1: - return [xshp[:-1]] - # vector / matrix - if x.ndim == 1 and y.ndim == 2: - return [yshp[-1:]] - # matrix / matrix - if x.ndim == 2 and y.ndim == 2: - return [xshp[:-1] + yshp[-1:]] - raise NotImplementedError() - - def __str__(self): - return "dot" + return [[xshp[0], yshp[1]]] _dot = Dot() @@ -3210,7 +3157,24 @@ def dense_dot(a, b): elif a.ndim > 2 or b.ndim > 2: return tensordot(a, b, [[a.ndim - 1], [np.maximum(0, b.ndim - 2)]]) else: - return _dot(a, b) + row_vector = a.ndim == 1 + if row_vector: + # Promote to row matrix + a = a[None] + + col_vector = b.ndim == 1 + if col_vector: + # Promote to column matrix + b = b[:, None] + + out = _dot(a, b) + if row_vector: + # If we promoted a to a row matrix, we need to squeeze the first dimension + out = out.squeeze(0) + if col_vector: + # If we promoted b to a column matrix, we need to squeeze the last dimension + out = out.squeeze(-1) + return out def tensordot( @@ -3916,27 +3880,7 @@ def logsumexp(x, axis=None, keepdims=False): return log(sum(exp(x), axis=axis, keepdims=keepdims)) -# Predefine all batched variations of Dot -_inner_prod = Blockwise( - _dot, - signature="(n),(n)->()", -) - -_matrix_vec_prod = Blockwise( - _dot, - signature="(m,k),(k)->(m)", -) - -_vec_matrix_prod = Blockwise( - _dot, - signature="(k),(k,n)->(n)", -) - -_matrix_matrix_matmul = Blockwise( - _dot, - signature="(m,k),(k,n)->(m,n)", - gufunc_spec=("numpy.matmul", 2, 1), -) +_matmul = Blockwise(_dot, name="matmul") def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None): @@ -3988,11 +3932,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None if x1.type.ndim == 1 and x2.type.ndim == 1: out = _dot(x1, x2) elif x1.type.ndim == 1: - out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2) + out = vecmat(x1, x2) elif x2.type.ndim == 1: - out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1) + out = matvec(x1, x2) else: - out = _matrix_matrix_matmul(x1, x2) + out = _matmul(x1, x2) if dtype is not None: out = out.astype(dtype) @@ -4042,7 +3986,7 @@ def vecdot( >>> z_batch = pt.vecdot(x_batch, y_batch) # shape (3,) >>> # Equivalent to numpy.vecdot(x_batch, y_batch) """ - out = _inner_prod(x1, x2) + out = matmul(x1[..., None, :], x2[..., :, None]).squeeze((-2, -1)) if dtype is not None: out = out.astype(dtype) @@ -4091,7 +4035,7 @@ def matvec( >>> result = pt.matvec(batched_A, 
batched_v) # shape (2, 3) >>> # Equivalent to numpy.matvec(batched_A, batched_v) """ - out = _matrix_vec_prod(x1, x2) + out = matmul(x1, x2[..., None]).squeeze(-1) if dtype is not None: out = out.astype(dtype) @@ -4129,18 +4073,18 @@ def vecmat( -------- >>> import pytensor.tensor as pt >>> # Vector-matrix product - >>> v = pt.vector("v", shape=(3,)) # shape (3,) - >>> A = pt.matrix("A", shape=(3, 4)) # shape (3, 4) + >>> v = pt.vector("v", shape=(3,)) + >>> A = pt.matrix("A", shape=(3, 4)) >>> result = pt.vecmat(v, A) # shape (4,) >>> # Equivalent to numpy.vecmat(v, A) >>> >>> # Batched vector-matrix product - >>> batched_v = pt.matrix("v", shape=(2, 3)) # shape (2, 3) - >>> batched_A = pt.tensor3("A", shape=(2, 3, 4)) # shape (2, 3, 4) + >>> batched_v = pt.matrix("v", shape=(2, 3)) + >>> batched_A = pt.tensor3("A", shape=(2, 3, 4)) >>> result = pt.vecmat(batched_v, batched_A) # shape (2, 4) >>> # Equivalent to numpy.vecmat(batched_v, batched_A) """ - out = _vec_matrix_prod(x1, x2) + out = matmul(x2.mT, x1[..., None]).squeeze(-1) if dtype is not None: out = out.astype(dtype) @@ -4155,18 +4099,18 @@ def vectorize_node_dot(op, node, batched_x, batched_y): old_y_ndim = old_y.type.ndim match (old_x_ndim, old_y_ndim): case (1, 1): - batch_op = _inner_prod + batch_fn = vecdot case (2, 1): - batch_op = _matrix_vec_prod + batch_fn = matvec case (1, 2): - batch_op = _vec_matrix_prod + batch_fn = vecmat case (2, 2): - batch_op = _matrix_matrix_matmul + batch_fn = matmul case _: raise ValueError( f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D." ) - return batch_op(batched_x, batched_y).owner + return batch_fn(batched_x, batched_y).owner def nan_to_num(x, nan=0.0, posinf=None, neginf=None): diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py index 6d411d3827..2293e1d8dd 100644 --- a/pytensor/tensor/rewriting/__init__.py +++ b/pytensor/tensor/rewriting/__init__.py @@ -1,7 +1,6 @@ import pytensor.tensor.rewriting.basic import pytensor.tensor.rewriting.blas import pytensor.tensor.rewriting.blas_c -import pytensor.tensor.rewriting.blas_scipy import pytensor.tensor.rewriting.blockwise import pytensor.tensor.rewriting.einsum import pytensor.tensor.rewriting.elemwise diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index e626b0720b..685cec5785 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -98,7 +98,7 @@ from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import ( Dot, - _matrix_matrix_matmul, + _matmul, add, mul, neg, @@ -107,7 +107,6 @@ ) from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift from pytensor.tensor.type import ( - DenseTensorType, TensorType, integer_dtypes, values_eq_approx_remove_inf_nan, @@ -580,12 +579,6 @@ def print_profile(cls, stream, prof, level=0): def local_dot_to_dot22(fgraph, node): # This works for tensor.outer too because basic.outer is a macro that # produces a dot(dimshuffle,dimshuffle) of form 4 below - if not isinstance(node.op, Dot): - return - - if any(not isinstance(i.type, DenseTensorType) for i in node.inputs): - return False - x, y = node.inputs if y.type.dtype != x.type.dtype: # TODO: upcast one so the types match @@ -593,16 +586,7 @@ def local_dot_to_dot22(fgraph, node): return if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"): - if x.ndim == 2 and y.ndim == 2: - new_out = [_dot22(*node.inputs)] - elif x.ndim == 2 and y.ndim == 1: - new_out 
= [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)] - elif x.ndim == 1 and y.ndim == 2: - new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)] - elif x.ndim == 1 and y.ndim == 1: - new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()] - else: - return + new_out = [_dot22(*node.inputs)] copy_stack_trace(node.outputs, new_out) return new_out @@ -636,93 +620,89 @@ def local_inplace_ger(fgraph, node): @node_rewriter([gemm_no_inplace]) def local_gemm_to_gemv(fgraph, node): """GEMM acting on row or column matrices -> GEMV.""" - if node.op == gemm_no_inplace: - z, a, x, y, b = node.inputs - if z.broadcastable == x.broadcastable == (True, False): - r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b) - new_out = [r.dimshuffle("x", 0)] - elif z.broadcastable == y.broadcastable == (False, True): - r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b) - new_out = [r.dimshuffle(0, "x")] - else: - return - copy_stack_trace(node.outputs, new_out) - return new_out + z, a, x, y, b = node.inputs + if z.broadcastable == x.broadcastable == (True, False): + r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b) + new_out = [r.dimshuffle("x", 0)] + elif z.broadcastable == y.broadcastable == (False, True): + r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b) + new_out = [r.dimshuffle(0, "x")] + else: + return + copy_stack_trace(node.outputs, new_out) + return new_out @node_rewriter([gemm_no_inplace]) def local_gemm_to_ger(fgraph, node): """GEMM computing an outer-product -> GER.""" - if node.op == gemm_no_inplace: - z, a, x, y, b = node.inputs - if x.broadcastable[1] and y.broadcastable[0]: - # x and y are both vectors so this might qualifies for a GER - xv = x.dimshuffle(0) - yv = y.dimshuffle(1) - try: - bval = ptb.get_underlying_scalar_constant_value(b) - except NotScalarConstantError: - # b isn't a constant, GEMM is doing useful pre-scaling - return - - if bval == 1: # best case a natural GER - rval = ger(z, a, xv, yv) - new_out = [rval] - elif bval == 0: # GER on zeros_like should be faster than GEMM - zeros = ptb.zeros([x.shape[0], y.shape[1]], x.dtype) - rval = ger(zeros, a, xv, yv) - new_out = [rval] - else: - # if bval is another constant, then z is being usefully - # pre-scaled and GER isn't really the right tool for the job. 
- return - copy_stack_trace(node.outputs, new_out) - return new_out - + z, a, x, y, b = node.inputs + if x.broadcastable[1] and y.broadcastable[0]: + # x and y are both vectors so this might qualifies for a GER + xv = x.dimshuffle(0) + yv = y.dimshuffle(1) + try: + bval = ptb.get_underlying_scalar_constant_value(b) + except NotScalarConstantError: + # b isn't a constant, GEMM is doing useful pre-scaling + return -# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline -# working -@node_rewriter([_dot22]) -def local_dot22_to_ger_or_gemv(fgraph, node): - """dot22 computing an outer-product -> GER.""" - if node.op == _dot22: - x, y = node.inputs - xb = x.broadcastable - yb = y.broadcastable - one = ptb.as_tensor_variable(np.asarray(1, dtype=x.dtype)) - zero = ptb.as_tensor_variable(np.asarray(0, dtype=x.dtype)) - if xb[1] and yb[0]: - # x and y are both vectors so this might qualifies for a GER - xv = x.dimshuffle(0) - yv = y.dimshuffle(1) - zeros = ptb.zeros([x.shape[0], y.shape[1]], dtype=x.dtype) - rval = ger(zeros, one, xv, yv) + if bval == 1: # best case a natural GER + rval = ger(z, a, xv, yv) + new_out = [rval] + elif bval == 0: # GER on zeros_like should be faster than GEMM + zeros = ptb.zeros([x.shape[0], y.shape[1]], x.dtype) + rval = ger(zeros, a, xv, yv) new_out = [rval] - elif xb[0] and yb[1]: - # x and y are both vectors so this qualifies for a sdot / ddot - # PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not - xv = x.dimshuffle(1) - zeros = ptb.AllocEmpty(x.dtype)(1) - rval = gemv_no_inplace(zeros, one, y.T, xv, zero) - new_out = [rval.dimshuffle("x", 0)] - elif xb[0] and not yb[0] and not yb[1]: - # x is vector, y is matrix so try gemv - xv = x.dimshuffle(1) - zeros = ptb.AllocEmpty(x.dtype)(y.shape[1]) - rval = gemv_no_inplace(zeros, one, y.T, xv, zero) - new_out = [rval.dimshuffle("x", 0)] - elif not xb[0] and not xb[1] and yb[1]: - # x is matrix, y is vector, try gemv - yv = y.dimshuffle(0) - zeros = ptb.AllocEmpty(x.dtype)(x.shape[0]) - rval = gemv_no_inplace(zeros, one, x, yv, zero) - new_out = [rval.dimshuffle(0, "x")] else: + # if bval is another constant, then z is being usefully + # pre-scaled and GER isn't really the right tool for the job. 
return copy_stack_trace(node.outputs, new_out) return new_out +# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline working +@node_rewriter([_dot22]) +def local_dot22_to_ger_or_gemv(fgraph, node): + """dot22 computing an outer-product -> GER.""" + x, y = node.inputs + xb = x.broadcastable + yb = y.broadcastable + one = ptb.as_tensor_variable(np.asarray(1, dtype=x.dtype)) + zero = ptb.as_tensor_variable(np.asarray(0, dtype=x.dtype)) + if xb[1] and yb[0]: + # x and y are both vectors so this might qualifies for a GER + xv = x.dimshuffle(0) + yv = y.dimshuffle(1) + zeros = ptb.zeros([x.shape[0], y.shape[1]], dtype=x.dtype) + rval = ger(zeros, one, xv, yv) + new_out = [rval] + elif xb[0] and yb[1]: + # x and y are both vectors so this qualifies for a sdot / ddot + # PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not + xv = x.dimshuffle(1) + zeros = ptb.AllocEmpty(x.dtype)(1) + rval = gemv_no_inplace(zeros, one, y.T, xv, zero) + new_out = [rval.dimshuffle("x", 0)] + elif xb[0] and not yb[0] and not yb[1]: + # x is vector, y is matrix so try gemv + xv = x.dimshuffle(1) + zeros = ptb.AllocEmpty(x.dtype)(y.shape[1]) + rval = gemv_no_inplace(zeros, one, y.T, xv, zero) + new_out = [rval.dimshuffle("x", 0)] + elif not xb[0] and not xb[1] and yb[1]: + # x is matrix, y is vector, try gemv + yv = y.dimshuffle(0) + zeros = ptb.AllocEmpty(x.dtype)(x.shape[0]) + rval = gemv_no_inplace(zeros, one, x, yv, zero) + new_out = [rval.dimshuffle(0, "x")] + else: + return + copy_stack_trace(node.outputs, new_out) + return new_out + + ################################# # # Set up the BlasOpt optimizer @@ -758,7 +738,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node): ignore_newtrees=False, ), "fast_run", - position=15, + position=11, ) @@ -903,12 +883,12 @@ def local_dot22_to_dot22scalar(fgraph, node): "local_dot22_to_dot22scalar", in2out(local_dot22_to_dot22scalar), "fast_run", - position=11, + position=12, ) @register_specialize -@node_rewriter([_matrix_matrix_matmul]) +@node_rewriter([_matmul]) def specialize_matmul_to_batched_dot(fgraph, node): """Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot. 
@@ -916,6 +896,10 @@ def specialize_matmul_to_batched_dot(fgraph, node): """ x, y = node.inputs + if x.type.ndim < 3: + # This doesn't actually have a batch dimension + return None + # BatchedDot does not allow implicit broadcasting of the batch dimensions # We do not want to explicitly broadcast as it may result in huge arrays if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]: @@ -926,6 +910,7 @@ def specialize_matmul_to_batched_dot(fgraph, node): if len(x_shape) > 3: # If we have more than one batch dim, ravel it x = x.reshape((-1, x_shape[-2], x_shape[-1])) + if len(y_shape) > 3: y = y.reshape((-1, y_shape[-2], y_shape[-1])) new_out = _batched_dot(x, y) diff --git a/pytensor/tensor/rewriting/blas_scipy.py b/pytensor/tensor/rewriting/blas_scipy.py deleted file mode 100644 index 2ed0279e45..0000000000 --- a/pytensor/tensor/rewriting/blas_scipy.py +++ /dev/null @@ -1,37 +0,0 @@ -from pytensor.graph.rewriting.basic import in2out -from pytensor.tensor.blas import ger, ger_destructive -from pytensor.tensor.blas_scipy import scipy_ger_inplace, scipy_ger_no_inplace -from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb - - -@node_rewriter([ger, ger_destructive]) -def use_scipy_ger(fgraph, node): - if node.op == ger: - return [scipy_ger_no_inplace(*node.inputs)] - - -@node_rewriter([scipy_ger_no_inplace]) -def make_ger_destructive(fgraph, node): - if node.op == scipy_ger_no_inplace: - return [scipy_ger_inplace(*node.inputs)] - - -use_scipy_blas = in2out(use_scipy_ger) -make_scipy_blas_destructive = in2out(make_ger_destructive) - - -# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof -# sucks [citation needed], but it is almost always present. -# C implementations should be scheduled earlier than this, so that they take -# precedence. Once the original Ger is replaced, then these optimizations -# have no effect. 
-blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100) - -# this matches the InplaceBlasOpt defined in blas.py -optdb.register( - "make_scipy_blas_destructive", - make_scipy_blas_destructive, - "fast_run", - "inplace", - position=50.2, -) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index afe69a198b..f08f19f06c 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -30,21 +30,17 @@ from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.tensor.basic import ( MakeVector, - alloc, - cast, constant, - get_underlying_scalar_constant_value, ) from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import add, exp, mul from pytensor.tensor.rewriting.basic import ( alloc_like, broadcasted_by, register_canonicalize, register_specialize, + register_stabilize, ) -from pytensor.tensor.shape import shape_padleft from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -346,6 +342,7 @@ def is_dimshuffle_useless(new_order, input): @register_canonicalize +@register_stabilize @register_specialize @node_rewriter([DimShuffle]) def local_dimshuffle_lift(fgraph, node): @@ -434,66 +431,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): """ if len(node.outputs) > 1: - return - try: - shape_i = fgraph.shape_feature.shape_i - except AttributeError: - shape_i = None - if isinstance(node.op, Elemwise): - scalar_op = node.op.scalar_op - # print "aa", scalar_op.output_types_preference - if getattr(scalar_op, "output_types_preference", None) in ( - ps.upgrade_to_float, - ps.upcast_out, - ): - # this is the kind of op that we can screw with the input - # dtypes by upcasting explicitly - output_dtype = node.outputs[0].type.dtype - new_inputs = [] - for i in node.inputs: - if i.type.dtype == output_dtype: - new_inputs.append(i) - else: - try: - cval_i = get_underlying_scalar_constant_value( - i, only_process_constants=True - ) - if all(i.broadcastable): - new_inputs.append( - shape_padleft(cast(cval_i, output_dtype), i.ndim) - ) - else: - if shape_i is None: - return - new_inputs.append( - alloc( - cast(cval_i, output_dtype), - *[shape_i(d)(i) for d in range(i.ndim)], - ) - ) - # print >> sys.stderr, "AAA", - # *[Shape_i(d)(i) for d in range(i.ndim)] - except NotScalarConstantError: - # for the case of a non-scalar - if isinstance(i, TensorConstant): - new_inputs.append(cast(i, output_dtype)) - else: - new_inputs.append(i) + return None + + if getattr(node.op.scalar_op, "output_types_preference", None) not in ( + ps.upgrade_to_float, + ps.upcast_out, + ): + return None - if new_inputs != node.inputs: - rval = [node.op(*new_inputs)] - if not node.outputs[0].type.is_super(rval[0].type): - # This can happen for example when floatX=float32 - # and we do the true division between and int64 - # and a constant that will get typed as int8. + # this is the kind of op that we can screw with the input + # dtypes by upcasting explicitly + [old_out] = node.outputs + output_dtype = old_out.type.dtype + new_inputs = list(node.inputs) + changed = False + for i, inp in enumerate(node.inputs): + if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant): + new_inputs[i] = constant(inp.data.astype(output_dtype)) + changed = True + + if not changed: + return None - # As this is just to allow merging more case, if - # the upcast don't work, we can just skip it. 
- return + rval = node.op(*new_inputs) + if not old_out.type.is_super(rval.type): + # This can happen for example when floatX=float32 + # and we do the true division between and int64 + # and a constant that will get typed as int8. + # As this is just to allow merging more case, if + # the upcast don't work, we can just skip it. + return None - # Copy over output stacktrace from before upcasting - copy_stack_trace(node.outputs[0], rval) - return rval + # Copy over output stacktrace from before upcasting + copy_stack_trace(old_out, rval) + return [rval] @node_rewriter([add, mul]) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 2a1a71ae40..5ac219f928 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -23,10 +23,9 @@ diag, diagonal, ) -from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod +from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod from pytensor.tensor.nlinalg import ( SVD, KroneckerProduct, @@ -103,12 +102,12 @@ def transinv_to_invtrans(fgraph, node): @register_stabilize -@node_rewriter([Dot, Dot22]) +@node_rewriter([Dot]) def inv_as_solve(fgraph, node): """ This utilizes a boolean `symmetric` tag on the matrices. """ - if isinstance(node.op, Dot | Dot22): + if isinstance(node.op, Dot): l, r = node.inputs if ( l.owner @@ -278,14 +277,7 @@ def cholesky_ldotlt(fgraph, node): A = node.inputs[0] if not ( A.owner is not None - and ( - ( - isinstance(A.owner.op, Dot | Dot22) - # This rewrite only applies to matrix Dot - and A.owner.inputs[0].type.ndim == 2 - ) - or (A.owner.op == _matrix_matrix_matmul) - ) + and ((isinstance(A.owner.op, Dot)) or (A.owner.op == _matmul)) ): return diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index d126502bde..6871309d8c 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -19,7 +19,6 @@ node_rewriter, ) from pytensor.graph.rewriting.utils import get_clients_at_depth -from pytensor.raise_op import assert_op from pytensor.tensor.basic import ( Alloc, Join, @@ -28,14 +27,15 @@ as_tensor_variable, cast, constant, + expand_dims, get_underlying_scalar_constant_value, moveaxis, ones_like, register_infer_shape, switch, + zeros, zeros_like, ) -from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import broadcast_arrays @@ -44,14 +44,10 @@ Prod, Sum, _conj, - _inner_prod, - _matrix_matrix_matmul, - _matrix_vec_prod, - _vec_matrix_prod, + _matmul, add, digamma, dot, - eq, erf, erfc, exp, @@ -98,6 +94,7 @@ register_useless, ) from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift +from pytensor.tensor.rewriting.linalg import is_matrix_transpose from pytensor.tensor.shape import Shape, Shape_i from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.type import ( @@ -131,16 +128,12 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): return consts, origconsts, nonconsts -@register_canonicalize -@register_stabilize +@register_canonicalize("shape_unsafe") +@register_stabilize("shape_unsafe") @node_rewriter([Dot]) def local_0_dot_x(fgraph, node): - if not isinstance(node.op, Dot): - return False - - x = 
node.inputs[0] - y = node.inputs[1] - replace = ( + x, y = node.inputs + if ( get_underlying_scalar_constant_value( x, only_process_constants=True, raise_not_constant=False ) @@ -149,157 +142,203 @@ def local_0_dot_x(fgraph, node): y, only_process_constants=True, raise_not_constant=False ) == 0 - ) - - if replace: - constant_zero = constant(0, dtype=node.outputs[0].type.dtype) - if x.ndim == 2 and y.ndim == 2: - constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0])) - return [alloc(constant_zero, x.shape[0], y.shape[1])] - elif x.ndim == 1 and y.ndim == 2: - constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0])) - return [alloc(constant_zero, y.shape[1])] - elif x.ndim == 2 and y.ndim == 1: - constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0])) - return [alloc(constant_zero, x.shape[0])] - elif x.ndim == 1 and y.ndim == 1: - constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0])) - return [constant_zero] + ): + return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)] @register_canonicalize -@node_rewriter([DimShuffle]) +@node_rewriter([Dot, _matmul]) def local_lift_transpose_through_dot(fgraph, node): r"""Perform the rewrite ``dot(x,y).T -> dot(y.T, x.T)``. These rewrites "lift" (propagate towards the inputs) `DimShuffle` through dot product. It allows to put the graph in a more standard shape, and to later merge consecutive `DimShuffle`\s. + """ + + clients = fgraph.clients[node.out] + if len(clients) != 1: + # If the dot is used in more than one place, we don't want to duplicate it + return None - The transformation should be apply whether or not the transpose is - inplace. The newly-introduced transpositions are not inplace, this will - be taken care of in a later rewrite phase. + [(client, _)] = clients - """ - if not (isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)): - return False - if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)): - return False - x, y = node.inputs[0].owner.inputs + if not (isinstance(client.op, DimShuffle) and is_matrix_transpose(client.out)): + return None - if x.ndim == y.ndim == 2: - # Output is dot product of transposed inputs in reverse order - ret = [dot(y.T, x.T)] + x, y = node.inputs + # Output is dot product of transposed inputs in reverse order + ret = node.op(y.mT, x.mT) - # Copy over stack trace to output from result of dot-product - copy_stack_trace(node.inputs[0], ret) - return ret + # Copy over stack trace to output from result of dot-product + copy_stack_trace(node.out, ret) + return {client.out: ret} -@register_stabilize -@register_specialize -@node_rewriter(tracks=[Blockwise]) -def local_batched_matmul_to_core_matmul(fgraph, node): - """Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul. - Example, if x has batch dimensions, but y not: +def _batched_matmul_to_core_matmul(fgraph, node, allow_reshape: bool): + """Move batch dimensions of matmul operands to core matmul + + Example, if x has batch dimensions that don't overlap with batch dimensions of y x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1]) - It also works when y has batch dimensions, but x not. 
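# An illustrative NumPy sketch of the reshape trick described in the docstring
# above, for the case where only x carries a batch dimension (shapes arbitrary):
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(7, 3, 4))  # batched (b, m, k)
y = rng.normal(size=(4, 5))     # core (k, n)

batched = x @ y  # (7, 3, 5) via broadcasting
core = (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
assert np.allclose(batched, core)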
- """ + It also works for batch dimensions of y that don't overlap with batch dimensions of x - # Check whether we have a matmul operation in this node - if not ( - isinstance(node.op.core_op, Dot) - and len(node.op.inputs_sig[0]) == 2 - and len(node.op.inputs_sig[1]) == 2 - ): - return None + The rewrite only uses reshape when mixing dimensions, and it can refuse to apply if `allow_reshape=False` + """ x, y = node.inputs batch_ndim = node.op.batch_ndim(node) - # Check if x has batch dimensions, but y not (or only broadcastable dimensions) - if any(not b_dim for b_dim in x.type.broadcastable[:-2]) and all( - y.type.broadcastable[:-2] - ): - x_stacked = x.reshape((-1, x.shape[-1])) - out_stacked = x_stacked @ y.squeeze(tuple(range(batch_ndim))) - out = out_stacked.reshape((*x.shape[:-1], y.shape[-1])) - return [out] - - # Otherwise, check if y has batch dimension, but x not - elif any(not b_dim for b_dim in y.type.broadcastable[:-2]) and all( - x.type.broadcastable[:-2] - ): - # For the y batch case we need to first move the batch axes and then reshape - # y.shape == (*b, k, n) - y_tr = moveaxis(y, -2, 0) # (k, *b, n) - y_stacked = y_tr.reshape((y.shape[-2], -1)) # (k, *b * n) - out_stacked = x.squeeze(tuple(range(batch_ndim))) @ y_stacked # (m, *b * n) - out_stacked_tr = out_stacked.reshape( - (x.shape[-2], *y.shape[:-2], y.shape[-1]) - ) # (m, *b, n) - out = moveaxis(out_stacked_tr, 0, -2) # (*b, m, n) - return [out] - - # Both x and y have batch dimensions, nothing to do here - return None + x_axis_to_merge = [ + i + for i, (bcast_x, bcast_y) in enumerate( + zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2]) + ) + if bcast_y and not bcast_x + ] + + y_axis_to_merge = [ + i + for i, (bcast_x, bcast_y) in enumerate( + zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2]) + ) + if bcast_x and not bcast_y + ] + + if not (x_axis_to_merge or y_axis_to_merge): + return None + + x_shape = tuple(x.shape) + y_shape = tuple(y.shape) + x_is_row = x.type.broadcastable[-2] + y_is_col = y.type.broadcastable[-1] + n_x_axis_to_merge = len(x_axis_to_merge) + n_y_axis_to_merge = len(y_axis_to_merge) + n_axis_to_merge = n_x_axis_to_merge + n_y_axis_to_merge + + x_stacked, y_stacked = x, y + dims_were_merged = False + + if n_x_axis_to_merge: + # ravel batch dimensions of x on the core (m) axis + x_axis_destination = tuple(range(-n_x_axis_to_merge - 2, -2)) + x_stacked = moveaxis(x, x_axis_to_merge, x_axis_destination) + if x_is_row: + # x was a row matrix, squeeze it to clean up the graph + x_stacked = x_stacked.squeeze(-2) + if n_x_axis_to_merge > 1 or not x_is_row: + if not allow_reshape: + # TODO: We could allow the y rewrite to go on + # Or just move one axis (the largest) if x is row + return None + + # Ravel moved batch dims together with (m) if needed + x_stacked_shape = tuple(x_stacked.shape) + x_stacked = x_stacked.reshape( + (*x_stacked_shape[: batch_ndim - n_x_axis_to_merge], -1, x_shape[-1]) + ) + dims_were_merged = True + + if n_y_axis_to_merge: + # ravel batch dimensions of y on the core (n) axis + y_axis_destination = tuple(range(-n_y_axis_to_merge - 1, -1)) + y_stacked = moveaxis(y, y_axis_to_merge, y_axis_destination) + if y_is_col: + # y was a column matrix, squeeze it to clean up the graph + y_stacked = y_stacked.squeeze(-1) + if n_y_axis_to_merge > 1 or not y_is_col: + if not allow_reshape: + # TODO: We could allow the x rewrite to go on + # Or just move one axis (the largest) if y is col + return False + # Ravel moved batch dims together with (n) if needed + y_stacked_shape = 
tuple(y_stacked.shape) + y_stacked = y_stacked.reshape( + (*y_stacked_shape[: batch_ndim - n_y_axis_to_merge], y_shape[-2], -1) + ) + dims_were_merged = True + + # Squeeze x_dims corresponding to merged dimensions of y + x_axis_to_squeeze = np.array(y_axis_to_merge) + for i in reversed(x_axis_to_merge): + # The corresponding dimensions of y may have shifted when we merged dimensions of x + x_axis_to_squeeze[x_axis_to_squeeze > i] -= 1 + x_stacked = x_stacked.squeeze(tuple(x_axis_to_squeeze)) + + # Same for y + y_axis_to_squeeze = np.array(x_axis_to_merge) + for i in reversed(y_axis_to_merge): + y_axis_to_squeeze[y_axis_to_squeeze > i] -= 1 + y_stacked = y_stacked.squeeze(tuple(y_axis_to_squeeze)) + + out_stacked = x_stacked @ y_stacked + + # Split back any merged dimensions + if dims_were_merged: + x_merged_shapes = [x_shape[i] for i in x_axis_to_merge] + if not x_is_row: + # Otherwise we handle that later with expand_dims, which is cleaner + x_merged_shapes.append(x_shape[-2]) + y_merged_shapes = [y_shape[i] for i in y_axis_to_merge] + if not y_is_col: + # Otherwise we handle that later with expand_dims, which is cleaner + y_merged_shapes.append(y_shape[-1]) + out_stacked_shape = tuple(out_stacked.shape) + out_unstacked = out_stacked.reshape( + ( + *out_stacked_shape[: batch_ndim - n_axis_to_merge], + *x_merged_shapes, + *y_merged_shapes, + ) + ) + else: + out_unstacked = out_stacked + + # Add back dummy row, col axis + # We do this separately to avoid the reshape as much as we can + if y_is_col and (n_y_axis_to_merge or dims_were_merged): + out_unstacked = expand_dims(out_unstacked, -1) + if x_is_row and (n_x_axis_to_merge or dims_were_merged): + out_unstacked = expand_dims(out_unstacked, -n_y_axis_to_merge - 2) + + # Move batch axis back to their original location + source = range(-n_axis_to_merge - 2, 0) + destination = (*x_axis_to_merge, -2, *y_axis_to_merge, -1) + out = moveaxis(out_unstacked, source, destination) + return [out] @register_canonicalize +@node_rewriter(tracks=[_matmul]) +def local_batched_matmul_to_core_matmul(fgraph, node): + # Allow passing batch dimensions of matmul to core vector / column matrices + return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=False) + + @register_specialize -@node_rewriter([_inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul]) -def local_blockwise_dot_to_mul(fgraph, node): - """Rewrite blockwise dots that correspond to multiplication without summation. +@node_rewriter(tracks=[_matmul]) +def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node): + # Allow stacking batch dimensions of matmul with core dimensions, with a reshape operation + # We only apply this in specialize, because grahs with reshape are hard to work with + return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=True) - We don't touch the regular dot, to not interfere with the BLAS optimizations. 
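# An illustrative NumPy sketch of the case the local_dot_to_mul rewrite below
# targets: when the shared dimension of the matrix product is 1, the dot is a
# broadcasted elementwise multiply (shapes arbitrary):
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(5, 3, 1))  # (..., m, 1)
b = rng.normal(size=(5, 1, 4))  # (..., 1, n)
assert np.allclose(a @ b, a * b)  # both give shape (5, 3, 4)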
- """ + +@register_canonicalize +@register_specialize +@node_rewriter([_matmul, Dot]) +def local_dot_to_mul(fgraph, node): + """Rewrite blockwise dots that correspond to multiplication without summation.""" a, b = node.inputs a_static_shape = a.type.shape b_static_shape = b.type.shape - core_a_ndim = len(node.op.inputs_sig[0]) - core_b_ndim = len(node.op.inputs_sig[1]) - if core_a_ndim > 2 or core_b_ndim > 2: - # Shouldn't happen, but here just in case + # Check if we have (..., m, 1) * (..., 1, n) -> (..., m, n) + if not (a_static_shape[-1] == 1 or b_static_shape[-2] == 1): return None - if core_b_ndim == 1: - if a_static_shape[-1] == 1 or b_static_shape[-1] == 1: - if core_a_ndim == 1: - # inner product: (..., 1) * (..., 1) -> (...) - # just squeeze the last dimensions of a and b - new_a = a.squeeze(-1) - new_b = b.squeeze(-1) - else: - # matrix vector product: (..., m, 1) * (..., 1) -> (..., m) - # the last dimension of b is already aligned for the elemwise multiplication - # after we squeeze the last dimension of a - new_a = a.squeeze(-1) - new_b = b - else: - return None - - else: - if a_static_shape[-1] == 1 or b_static_shape[-2] == 1: - if core_a_ndim == 1: - # vector_matrix product: (..., 1) * (..., 1, n) -> (..., n) - # the last dimension of a is already aligned for the elemwise multiplication - # after we squeeze the one to last dimension of b - new_a = a - new_b = b.squeeze(-2) - else: - # matrix matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n) - # the dimensions of a and b are already aligned for the elemwise multiplication - new_a = a - new_b = b - else: - return None - - new_a = copy_stack_trace(a, new_a) - new_b = copy_stack_trace(b, new_b) - new_out = copy_stack_trace(node.out, mul(new_a, new_b)) + new_out = mul(a, b) + copy_stack_trace(node.out, new_out) return [new_out] diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 0ca6e0b452..31b8bfd2bd 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -351,7 +351,8 @@ def local_useless_slice(fgraph, node): new_idxs[dim] = slice(start, stop, step) if change_flag or ((last_useful_idx + 1) < len(idxs)): - out = x[tuple(new_idxs[: last_useful_idx + 1])] + new_idxs = tuple(new_idxs[: last_useful_idx + 1]) + out = x[new_idxs] if new_idxs else x # Copy over previous output stacktrace copy_stack_trace(node.outputs, out) return [out] @@ -369,74 +370,73 @@ def local_subtensor_merge(fgraph, node): """ from pytensor.scan.op import Scan - if isinstance(node.op, Subtensor): - u = node.inputs[0] - if u.owner and isinstance(u.owner.op, Subtensor): - # We can merge :) - # x actual tensor on which we are picking slices - x = u.owner.inputs[0] - # slices of the first applied subtensor - slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list) - slices2 = get_idx_list(node.inputs, node.op.idx_list) - - # Don't try to do the optimization on do-while scan outputs, - # as it will create a dependency on the shape of the outputs - if ( - x.owner is not None - and isinstance(x.owner.op, Scan) - and x.owner.op.info.as_while - ): - return None + u = node.inputs[0] + if not (u.owner is not None and isinstance(u.owner.op, Subtensor)): + return None - # Get the shapes of the vectors ! 
- try: - # try not to introduce new shape into the graph - xshape = fgraph.shape_feature.shape_of[x] - ushape = fgraph.shape_feature.shape_of[u] - except AttributeError: - # Following the suggested use of shape_feature which should - # consider the case when the compilation mode doesn't - # include the ShapeFeature - xshape = x.shape - ushape = u.shape - - merged_slices = [] - pos_2 = 0 - pos_1 = 0 - while (pos_1 < len(slices1)) and (pos_2 < len(slices2)): - slice1 = slices1[pos_1] - if isinstance(slice1, slice): - merged_slices.append( - merge_two_slices( - fgraph, slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2] - ) - ) - pos_2 += 1 - else: - merged_slices.append(slice1) - pos_1 += 1 - - if pos_2 < len(slices2): - merged_slices += slices2[pos_2:] - else: - merged_slices += slices1[pos_1:] + # We can merge :) + # x actual tensor on which we are picking slices + x = u.owner.inputs[0] + # slices of the first applied subtensor + slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list) + slices2 = get_idx_list(node.inputs, node.op.idx_list) - merged_slices = tuple(as_index_constant(s) for s in merged_slices) - subtens = Subtensor(merged_slices) + # Don't try to do the optimization on do-while scan outputs, + # as it will create a dependency on the shape of the outputs + if ( + x.owner is not None + and isinstance(x.owner.op, Scan) + and x.owner.op.info.as_while + ): + return None - sl_ins = get_slice_elements( - merged_slices, lambda x: isinstance(x, Variable) + # Get the shapes of the vectors ! + try: + # try not to introduce new shape into the graph + xshape = fgraph.shape_feature.shape_of[x] + ushape = fgraph.shape_feature.shape_of[u] + except AttributeError: + # Following the suggested use of shape_feature which should + # consider the case when the compilation mode doesn't + # include the ShapeFeature + xshape = x.shape + ushape = u.shape + + merged_slices = [] + pos_2 = 0 + pos_1 = 0 + while (pos_1 < len(slices1)) and (pos_2 < len(slices2)): + slice1 = slices1[pos_1] + if isinstance(slice1, slice): + merged_slices.append( + merge_two_slices( + fgraph, slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2] + ) ) - # Do not call make_node for test_value - out = subtens(x, *sl_ins) + pos_2 += 1 + else: + merged_slices.append(slice1) + pos_1 += 1 - # Copy over previous output stacktrace - # and stacktrace from previous slicing operation. - # Why? Because, the merged slicing operation could have failed - # because of either of the two original slicing operations - orig_out = node.outputs[0] - copy_stack_trace([orig_out, node.inputs[0]], out) - return [out] + if pos_2 < len(slices2): + merged_slices += slices2[pos_2:] + else: + merged_slices += slices1[pos_1:] + + merged_slices = tuple(as_index_constant(s) for s in merged_slices) + subtens = Subtensor(merged_slices) + + sl_ins = get_slice_elements(merged_slices, lambda x: isinstance(x, Variable)) + # Do not call make_node for test_value + out = subtens(x, *sl_ins) + + # Copy over previous output stacktrace + # and stacktrace from previous slicing operation. + # Why? 
Because, the merged slicing operation could have failed + # because of either of the two original slicing operations + orig_out = node.outputs[0] + copy_stack_trace([orig_out, node.inputs[0]], out) + return [out] @register_specialize @@ -825,6 +825,12 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2): if not isinstance(slice1, slice): raise ValueError("slice1 should be of type `slice`") + # Simple case where one of the slices is useless + if is_full_slice(slice1): + return slice2 + elif is_full_slice(slice2): + return slice1 + sl1, reverse1 = get_canonical_form_slice(slice1, len1) sl2, reverse2 = get_canonical_form_slice(slice2, len2) diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index 5a367a302a..bfb78a98e5 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -5,7 +5,7 @@ from pytensor import Variable from pytensor.compile import optdb -from pytensor.graph import Constant, FunctionGraph, node_rewriter +from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple from pytensor.scalar import basic as ps @@ -20,6 +20,7 @@ join, register_infer_shape, ) +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import squeeze @@ -118,49 +119,57 @@ def local_subtensor_of_dot(fgraph, node): the remaining entries of ``idxs`` (if any), modified to skip the second-to-last dimension of ``B`` (because dot sums over this dimension). """ - if not isinstance(node.op, Subtensor): - return - if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)): + x, *idx_vars = node.inputs + if not ( + x.owner is not None + and ( + isinstance(x.owner.op, Dot) + or ( + isinstance(x.owner.op, Blockwise) + and isinstance(x.owner.op.core_op, Dot) + ) + ) + ): return # If there is other node that use the outputs of the dot # We don't want to compute twice the sub part. - if len(fgraph.clients[node.inputs[0]]) > 1: + if len(fgraph.clients[x]) > 1: return - a = node.inputs[0].owner.inputs[0] - b = node.inputs[0].owner.inputs[1] - - idx_list = get_idx_list(node.inputs, node.op.idx_list) - - num_a_indices = min(a.ndim - 1, len(idx_list)) - a_indices = idx_list[:num_a_indices] - b_indices = idx_list[num_a_indices:] - - # This is necessary because np.dot sums the last index of a with the second to last of b - # so we want to skip the second-to-last index into b. - # This wasn't necessary for a, because we just omitted the last index. 
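# An illustrative NumPy sketch of the identity behind lifting a Subtensor
# through Dot, as done in local_subtensor_of_dot (shapes arbitrary):
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.normal(size=(5, 3)), rng.normal(size=(3, 4))
# Indexing the leading dimension of a dot equals dotting the indexed operand;
# indices past the first dimension apply to b, skipping its summed first axis
assert np.allclose((a @ b)[1:4], a[1:4] @ b)
assert np.allclose((a @ b)[1:4, 2:], a[1:4] @ b[:, 2:])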
- # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:] - # (dot also handles b.ndim < 2 as a special case) - if b.ndim > 1 and len(b_indices) >= b.ndim - 1: - b_indices = ( - b_indices[: b.ndim - 2] - + (slice(None, None, None),) - + b_indices[b.ndim - 2 :] - ) + a = x.owner.inputs[0] + b = x.owner.inputs[1] + idx_list = indices_from_subtensor(idx_vars, node.op.idx_list) - a_sub = a.__getitem__(tuple(a_indices)) - b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b + if not idx_list: + # Nothing to do, `local_useless_slice` will handle this case + return None + + batch_ndim = ( + x.owner.op.batch_ndim(x.owner) if isinstance(x.owner.op, Blockwise) else 0 + ) - # Copy over previous output stacktrace to a_sub and b_sub, - # because an error in the subtensor operation (e.g. an index error) - # on either a or b must correspond to an error in the - # subtensor operation on their dot product. - copy_stack_trace(node.outputs[0], [a_sub, b_sub]) + if batch_ndim: + batch_idx_list, idx_list = idx_list[:batch_ndim], idx_list[batch_ndim:] + if not idx_list: + # Indexing only over batch dimensions of Blockwise, nothing to do here + # This will be handled by `local_subtensor_of_batch_dims` + return None + # We perform the rest of the rewrite on dummy a, b that correspond to the core case + a = a.type.clone(shape=a.type.shape[batch_ndim:])() + b = b.type.clone(shape=b.type.shape[batch_ndim:])() - # Copy over previous output stacktrace and previous dot product stacktrace, - # because an error here may correspond to an either in either the original - # dot product, or in the dot product after the subtensor operation. + a_indices = idx_list[:1] + b_indices = (slice(None), *idx_list[1:]) + + a_sub = a[tuple(a_indices)] + b_sub = b[tuple(b_indices)] r = dot(a_sub, b_sub) + + if batch_ndim: + # Replace dummy inputs by the original batch ones + r = vectorize_graph(r, replace={a: x.owner.inputs[0], b: x.owner.inputs[1]}) + r = r[tuple(batch_idx_list)] + copy_stack_trace([node.outputs[0], node.inputs[0]], r) return [r] @@ -169,8 +178,8 @@ def local_subtensor_of_dot(fgraph, node): @register_canonicalize("shape_unsafe") @register_specialize("shape_unsafe") @node_rewriter([Subtensor]) -def local_subtensor_of_elemwise(fgraph, node): - """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior. +def local_subtensor_of_batch_dims(fgraph, node): + """Lift a Subtensor through the batch dims of an (Elemwise or Blockwise) operation and its implicit broadcasting behavior. exp(x)[:, 0] -> exp(x[:, 0]) add(x, y)[0] -> add(x[0], y[0]) @@ -178,7 +187,7 @@ def local_subtensor_of_elemwise(fgraph, node): """ elem, *idx = node.inputs - if not (elem.owner and isinstance(elem.owner.op, Elemwise)): + if not (elem.owner and isinstance(elem.owner.op, Elemwise | Blockwise)): return None if len(fgraph.clients[elem]) > 1: @@ -188,9 +197,34 @@ def local_subtensor_of_elemwise(fgraph, node): idx_tuple = indices_from_subtensor(idx, node.op.idx_list) + batch_ndim = ( + elem.owner.op.batch_ndim(elem.owner) + if isinstance(elem.owner.op, Blockwise) + else elem.ndim + ) + + if len(idx_tuple) > batch_ndim: + # Indexing on core dimensions of Blockwise. 
We split the indices and lift the batch ones only + batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:] + if all(is_full_slice(idx) for idx in batch_indices): + # No batch indices, nothing to do + return None + elem_with_batch_indices = elem[batch_indices] + [elem_with_batch_indices_lifted] = local_subtensor_of_batch_dims.transform( + fgraph, elem_with_batch_indices.owner + ) + # Reapply the core_indices + core_ndim = elem.type.ndim - batch_ndim + # Number of batch dims may have changed with the lifting of indices, so we recompute + new_batch_ndim = elem_with_batch_indices_lifted.type.ndim - core_ndim + new_indices = (*(slice(None),) * new_batch_ndim, *core_indices) + new_elem = elem_with_batch_indices_lifted[new_indices] + copy_stack_trace(node.outputs[0], new_elem) + return [new_elem] + elem_inputs = elem.owner.inputs - elem_bcast = elem.type.broadcastable - if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs): + elem_bcast = elem.type.broadcastable[:batch_ndim] + if all(inp.type.broadcastable[:batch_ndim] == elem_bcast for inp in elem_inputs): # No need to worry about implicit broadcasting. indexed_inputs = [inp[idx_tuple] for inp in elem_inputs] @@ -201,7 +235,7 @@ def local_subtensor_of_elemwise(fgraph, node): zip( idx_tuple, elem_bcast, - *(inp.type.broadcastable for inp in elem_inputs), + *(inp.type.broadcastable[:batch_ndim] for inp in elem_inputs), # Indices can be shorter than input ndims strict=False, ) @@ -435,6 +469,41 @@ def local_subtensor_of_expand_dims(fgraph, node): return [out] +@register_canonicalize +@register_specialize +@node_rewriter([Subtensor]) +def local_subtensor_of_squeeze(fgraph, node): + """Lift subtensor through a squeeze operation""" + x, *idxs_vars = node.inputs + if not ( + x.owner is not None + and isinstance(x.owner.op, DimShuffle) + and x.owner.op.is_squeeze + ): + return None + + [x_before_squeeze] = x.owner.inputs + idxs = indices_from_subtensor(idxs_vars, node.op.idx_list) + dropped_dims = x.owner.op.drop + + # Apply indices directly on x + # Add empty slices on the axis that squeeze would have removed + new_idxs = np.insert(np.array(idxs, dtype=object), dropped_dims, slice(None)) + x_indexed = x_before_squeeze[tuple(new_idxs)] + + # Reapply squeeze + # Indexing may have squeezed some dimensions, so we need to recalculate dropped_dims + new_dropped_dims = np.array(dropped_dims) + for i, new_idx in reversed(tuple(enumerate(new_idxs))): + if not isinstance(new_idx, slice): + # If it's not a slice, it's an integer which drops the dimension + new_dropped_dims[new_dropped_dims > i] -= 1 + new_x = x_indexed.squeeze(tuple(new_dropped_dims)) + + copy_stack_trace(x, new_x) + return [new_x] + + @register_canonicalize @register_specialize @node_rewriter([Subtensor]) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 3b880616df..95cf6ec557 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -631,7 +631,7 @@ def test_Dot(x, y): x, x_test_value = x y, y_test_value = y - g = ptm.Dot()(x, y) + g = ptm.dot(x, y) compare_numba_and_py( [x, y], diff --git a/tests/tensor/rewriting/test_blas.py b/tests/tensor/rewriting/test_blas.py index d939ceedce..10e040367c 100644 --- a/tests/tensor/rewriting/test_blas.py +++ b/tests/tensor/rewriting/test_blas.py @@ -1,10 +1,10 @@ import numpy as np import pytest -from pytensor import function +from pytensor import config, function from pytensor import tensor as pt from pytensor.compile import get_default_mode -from pytensor.graph 
import FunctionGraph +from pytensor.graph import FunctionGraph, ancestors from pytensor.tensor import ( col, dscalar, @@ -21,7 +21,6 @@ vectorize, ) from pytensor.tensor.blas import BatchedDot -from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.rewriting.blas import ( _as_scalar, @@ -37,8 +36,11 @@ def XYZab(): return matrix(), matrix(), matrix(), scalar(), scalar() -@pytest.mark.parametrize("valid_case", (True, False)) -def test_specialize_matmul_to_batched_dot(valid_case): +@pytest.mark.skipif( + config.mode == "FAST_COMPILE", reason="Test requires specialization rewrites" +) +@pytest.mark.parametrize("aligned", (True, False)) +def test_specialize_matmul_to_batched_dot(aligned): signature = BatchedDot.gufunc_signature rewrite = specialize_matmul_to_batched_dot.__name__ @@ -49,23 +51,36 @@ def core_np(x, y): return np.matmul(x, y) x = tensor(shape=(7, 5, 3, 3)) - if valid_case: + if aligned: y = tensor(shape=(7, 5, 3, 3)) else: y = tensor(shape=(5, 3, 3)) + out = vectorize(core_pt, signature=signature)(x, y) + + assert ( + sum( + isinstance(var.owner.op, BatchedDot) + for var in ancestors([out]) + if var.owner + ) + == 0 + ) + vectorize_pt = function( [x, y], - vectorize(core_pt, signature=signature)(x, y), + out, mode=get_default_mode().including(rewrite), ) - blocwkise_node = any( - isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes + + assert ( + sum( + isinstance(var.owner.op, BatchedDot) + for var in ancestors(vectorize_pt.maker.fgraph.outputs) + if var.owner + ) + == 1 ) - if valid_case: - assert not blocwkise_node - else: - assert blocwkise_node x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype) y_test = np.random.normal(size=y.type.shape).astype(y.type.dtype) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 3699a3fcff..3addead03f 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -42,6 +42,7 @@ Prod, Sum, _conj, + _matmul, add, arccosh, arcsinh, @@ -1203,52 +1204,6 @@ def test_local_log_add_exp(): # TODO: test that the rewrite works in the presence of broadcasting. -def test_local_subtensor_of_dot(): - m1 = matrix() - m2 = matrix() - d1 = np.arange(6).reshape((3, 2)).astype(config.floatX) - d2 = np.arange(8).reshape((2, 4)).astype(config.floatX) + 10 - mode = get_default_mode().including("local_subtensor_of_dot") - - def test_equality(a, b): - return a.shape == b.shape and np.allclose(a, b) - - # [cst] - f = function([m1, m2], pytensor.tensor.dot(m1, m2)[1], mode=mode) - topo = f.maker.fgraph.toposort() - assert test_equality(f(d1, d2), np.dot(d1, d2)[1]) - # DimShuffle happen in FAST_COMPILE - assert isinstance(topo[-1].op, CGemv | Gemv | DimShuffle) - - # slice - f = function([m1, m2], pytensor.tensor.dot(m1, m2)[1:2], mode=mode) - topo = f.maker.fgraph.toposort() - assert test_equality(f(d1, d2), np.dot(d1, d2)[1:2]) - assert isinstance(topo[-1].op, Dot22) - - m1 = tensor3() - m2 = tensor3() - idx = iscalar() - d1 = np.arange(30).reshape(2, 5, 3).astype(config.floatX) - d2 = np.arange(72).reshape(4, 3, 6).astype(config.floatX) + 100 - - f = function( - [m1, m2, idx], pytensor.tensor.dot(m1, m2)[idx, 1:4, :, idx:], mode=mode - ) - assert test_equality(f(d1, d2, 1), np.dot(d1, d2)[1, 1:4, :, 1:]) - # if we return the gradients. We need to use same mode as before. 
-    assert check_stack_trace(f, ops_to_check="last")
-
-    f = function(
-        [m1, m2, idx], pytensor.tensor.dot(m1, m2)[1:4, :, idx:, idx], mode=mode
-    )
-    assert test_equality(f(d1, d2, 1), np.dot(d1, d2)[1:4, :, 1:, 1])
-
-    # Now test that the stack trace is copied over properly,
-    # if we return the gradients. We need to use same mode as before.
-    assert check_stack_trace(f, ops_to_check="last")
-
-
 def test_local_elemwise_sub_zeros():
     scal = scalar()
     vect = vector()
@@ -4612,6 +4567,88 @@ def test_local_batched_matmul_to_core_matmul():
     np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
 
 
+@pytest.mark.parametrize(
+    "mat_shape, vec_shape",
+    [
+        [(1, 2, 2), (5, 2)],
+        [(5, 2, 2), (1, 2)],
+        [(1, 1, 2, 2), (7, 5, 2)],
+        [(7, 5, 2, 2), (1, 1, 5, 2)],
+        [(1, 5, 1, 2, 2), (7, 5, 7, 2)],
+        [(7, 5, 7, 2, 2), (1, 5, 1, 2)],
+        [(5, 1, 3, 1, 2, 2), (1, 7, 3, 7, 2)],
+        [(1, 7, 3, 7, 2, 2), (5, 1, 3, 1, 2)],
+    ],
+    ids=str,
+)
+@pytest.mark.parametrize("func", ("matvec", "vecmat", "vecdot"))
+def test_batch_matvec_to_matmul(func, mat_shape, vec_shape):
+    def count_matvec_nodes(graph):
+        # Counts how many matmul nodes actually correspond to matvec or vecmat
+        return len(
+            [
+                var
+                for var in ancestors([graph])
+                if (
+                    var.owner is not None
+                    and var.owner.op == _matmul
+                    and (
+                        (var.owner.inputs[0].type.shape[-2] == 1)
+                        or (var.owner.inputs[1].type.shape[-1] == 1)
+                    )
+                )
+            ]
+        )
+
+    mat = pt.tensor("mat", shape=mat_shape, dtype="float64")
+    vec = pt.tensor("vec", shape=vec_shape, dtype="float64")
+
+    if func == "matvec":
+        out = pt.matvec(mat, vec)
+    elif func == "vecmat":
+        out = pt.vecmat(vec, mat)
+    elif func == "vecdot":
+        out = pt.vecdot(mat[..., 0], vec)
+    else:
+        raise NotImplementedError(func)
+
+    assert count_matvec_nodes(out) == 1
+
+    rewritten_out = rewrite_graph(
+        out,
+        include=(
+            "canonicalize",
+            "specialize",
+        ),
+        exclude=(
+            "local_eager_useless_unbatched_blockwise",
+            "specialize_matmul_to_batched_dot",
+        ),
+    )
+    # No `matvec` should remain in the rewritten output if one of the vectors can be treated as a matrix
+    expected = not any(
+        mat_dim == 1 and vec_dim != 1
+        for vec_dim, mat_dim in zip(vec_shape[:-1], mat_shape[:-2])
+    )
+    if not expected and func == "vecdot":
+        # In this case there are two vectors, so we may still end up with a `matvec` unless the second vec can also be treated as a matrix
+        expected = not any(
+            mat_dim != 1 and vec_dim == 1
+            for vec_dim, mat_dim in zip(vec_shape[:-1], mat_shape[:-2])
+        )
+
+    assert count_matvec_nodes(rewritten_out) == expected
+
+    rng = np.random.default_rng(mat_shape + vec_shape)
+    eval_dict = {mat: rng.random(mat.type.shape), vec: rng.random(vec.type.shape)}
+    # Check that results are correct without further rewrites
+    no_optimization = Mode(linker="py", optimizer=None)
+    np.testing.assert_allclose(
+        rewritten_out.eval(eval_dict, mode=no_optimization),
+        out.eval(eval_dict, mode=no_optimization),
+    )
+
+
 def test_log_kv_stabilization():
     x = pt.scalar("x")
     out = log(kv(4.5, x))
@@ -4662,8 +4699,8 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
     out = dot(a, b)
 
     if batched:
-        batch_a = tensor("batch_a", shape=(1, 5, *a_shape))
-        batch_b = tensor("batch_b", shape=(7, 1, *b_shape))
+        batch_a = tensor("batch_a", shape=(2, 1, 5, *a_shape))
+        batch_b = tensor("batch_b", shape=(2, 7, 1, *b_shape))
         out = vectorize_graph(out, {a: batch_a, b: batch_b})
         a = batch_a
         b = batch_b
@@ -4677,14 +4714,16 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
         == 1
     )
 
-    # For now rewrite only applies to Batched Dots
     rewritten_out = rewrite_graph(out)
     assert rewritten_out.type.shape == out.type.shape
-    assert sum(
-        isinstance(var.owner.op, (Blockwise | Dot))
-        for var in ancestors([rewritten_out])
-        if var.owner
-    ) == (0 if batched else 1)
+    assert (
+        sum(
+            isinstance(var.owner.op, (Blockwise | Dot))
+            for var in ancestors([rewritten_out])
+            if var.owner
+        )
+        == 0
+    )
 
     a_test = np.random.normal(size=a.type.shape).astype(a.type.dtype)
     b_test = np.random.normal(size=b.type.shape).astype(b.type.dtype)
diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py
index ccfa033859..6f87f305a6 100644
--- a/tests/tensor/rewriting/test_subtensor_lift.py
+++ b/tests/tensor/rewriting/test_subtensor_lift.py
@@ -14,6 +14,7 @@
 from pytensor.graph import (
     Constant,
     FunctionGraph,
+    Op,
     RewriteDatabaseQuery,
     Type,
     rewrite_graph,
@@ -23,6 +24,7 @@
 from pytensor.printing import debugprint
 from pytensor.tensor import (
     add,
+    dvector,
     exp,
     iscalar,
     iscalars,
@@ -37,11 +39,14 @@
     vector,
 )
 from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
+from pytensor.tensor.blas import Dot22, Gemv
+from pytensor.tensor.blas_c import CGemv
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.subtensor_lift import (
     local_subtensor_make_vector,
-    local_subtensor_of_elemwise,
+    local_subtensor_of_batch_dims,
     local_subtensor_shape_constant,
 )
 from pytensor.tensor.shape import SpecifyShape, _shape
@@ -58,7 +63,7 @@
 NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
 
 
-class TestLocalSubtensorOfElemwise:
+class TestLocalSubtensorOfBatchDims:
     def test_unary_multiple_clients(self):
         # as test0, but we reuse the output of the elemwise
         # So we should not lift the subtensor
@@ -144,7 +149,7 @@ def test_multinary_multiple_clients(self):
             ),
         ],
     )
-    def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
+    def test_elemwise(self, original_fn, expected_fn):
         rng = np.random.default_rng(257)
         x = pt.matrix("x", shape=(5, 3))
         y = pt.matrix("y", shape=(5, 3))
@@ -163,7 +168,7 @@ def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
             out.eval({x: x_test, y: y_test}, **eval_kwargs),
         )
 
-    def test_local_subtensor_of_elemwise_multiple_clients(self):
+    def test_elemwise_multiple_clients(self):
         x = pt.matrix("x", shape=(5, 3))
         y = pt.matrix("y", shape=(5, 3))
         out1 = add(x, y)
@@ -171,11 +176,90 @@ def test_local_subtensor_of_elemwise_multiple_clients(self):
 
         # Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
         fgraph = FunctionGraph([x, y], [out1, out2], clone=False)
-        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None
+        assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is None
 
         # Otherwise it should work
         fgraph.remove_output(0)
-        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
+        assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is not None
+
+    def test_blockwise(self):
+        class CoreTestOp(Op):
+            itypes = [dvector, dvector]
+            otypes = [dvector]
+
+            def perform(self, node, inputs, output_storage):
+                output_storage[0][0] = np.convolve(*inputs, mode="valid")
+
+        core_test_op = CoreTestOp()
+        block_test_op = Blockwise(core_test_op, signature="(a),(b)->(c)")
+
+        x = tensor3("x", shape=(7, 5, 11), dtype="float64")
+        y = tensor("y", shape=(7, 33), dtype="float64")
+        out = block_test_op(x, y[:, None, :])
+        assert isinstance(out.owner.op, Blockwise)
+
+        out_sliced = out[2:][:, 3:]
+        rewritten_out_sliced = rewrite_graph(out_sliced)
+        expected_out_sliced = block_test_op(x[2:, 3:], y[2:][:, None, :])
+        assert equal_computations([rewritten_out_sliced], [expected_out_sliced])
+
+        rng = np.random.default_rng(191)
+        x_test = rng.normal(size=x.type.shape).astype(x.type.dtype)
+        y_test = rng.normal(size=y.type.shape).astype(y.type.dtype)
+        np.testing.assert_allclose(
+            rewritten_out_sliced.eval(
+                {x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE
+            ),
+            out_sliced.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
+        )
+
+        # Check slice on core dims
+        out_sliced = out[2:][:, 0][:, 4:]
+        rewritten_out_sliced = rewrite_graph(out_sliced)
+        expected_out_sliced = block_test_op(x[2:, 0], y[2:])[:, 4:]
+        assert equal_computations([rewritten_out_sliced], [expected_out_sliced])
+
+
+def test_local_subtensor_of_dot():
+    m1 = matrix()
+    m2 = matrix()
+    d1 = np.arange(6).reshape((3, 2)).astype(config.floatX)
+    d2 = np.arange(8).reshape((2, 4)).astype(config.floatX) + 10
+    mode = get_default_mode().including("local_subtensor_of_dot")
+
+    def test_equality(a, b):
+        return a.shape == b.shape and np.allclose(a, b)
+
+    # [cst]
+    f = function([m1, m2], pt.dot(m1, m2)[1], mode=mode)
+    topo = f.maker.fgraph.toposort()
+    assert test_equality(f(d1, d2), np.dot(d1, d2)[1])
+    # DimShuffle happens in FAST_COMPILE
+    assert isinstance(topo[-1].op, CGemv | Gemv | DimShuffle)
+
+    # slice
+    f = function([m1, m2], pt.dot(m1, m2)[1:2], mode=mode)
+    topo = f.maker.fgraph.toposort()
+    assert test_equality(f(d1, d2), np.dot(d1, d2)[1:2])
+    assert isinstance(topo[-1].op, Dot22)
+
+    m1 = tensor3()
+    m2 = tensor3()
+    idx = iscalar()
+    d1 = np.arange(30).reshape(2, 5, 3).astype(config.floatX)
+    d2 = np.arange(72).reshape(4, 3, 6).astype(config.floatX) + 100
+
+    f = function([m1, m2, idx], pt.dot(m1, m2)[idx, 1:4, :, idx:], mode=mode)
+    assert test_equality(f(d1, d2, 1), np.dot(d1, d2)[1, 1:4, :, 1:])
+    # Check that the stack trace is copied over properly. We need to use the same mode as before.
+    assert check_stack_trace(f, ops_to_check="last")
+
+    f = function([m1, m2, idx], pt.dot(m1, m2)[1:4, :, idx:, idx], mode=mode)
+    assert test_equality(f(d1, d2, 1), np.dot(d1, d2)[1:4, :, 1:, 1])
+
+    # Now test that the stack trace is copied over properly.
+    # We need to use the same mode as before.
+    assert check_stack_trace(f, ops_to_check="last")
 
 
 @pytest.mark.parametrize(
diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py
index f3fcf72cc5..1332266e3d 100644
--- a/tests/tensor/test_blas.py
+++ b/tests/tensor/test_blas.py
@@ -1903,17 +1903,9 @@ def test_f32_1_2(self):
     def test_f64_4_5(self):
         return self.given_dtype("float64", 4, 5, destructive=False)
 
-    @pytest.mark.xfail(
-        condition=config.floatX == "float32",
-        reason="GER from complex64 is not introduced in float32 mode",
-    )
     def test_c64_7_1(self):
         return self.given_dtype("complex64", 7, 1)
 
-    @pytest.mark.xfail(
-        raises=AssertionError,
-        reason="Unclear how this test was supposed to work with complex128",
-    )
     def test_c128_1_9(self):
         return self.given_dtype("complex128", 1, 9)
 
diff --git a/tests/tensor/test_blas_scipy.py b/tests/tensor/test_blas_scipy.py
deleted file mode 100644
index 716eab7bbe..0000000000
--- a/tests/tensor/test_blas_scipy.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import pickle
-
-import numpy as np
-
-import pytensor
-from pytensor import tensor as pt
-from pytensor.tensor.blas_scipy import ScipyGer
-from pytensor.tensor.math import outer
-from pytensor.tensor.type import tensor
-from tests.tensor.test_blas import TestBlasStrides, gemm_no_inplace
-from tests.unittest_tools import OptimizationTestMixin
-
-
-class TestScipyGer(OptimizationTestMixin):
-    def setup_method(self):
-        self.mode = pytensor.compile.get_default_mode()
-        self.mode = self.mode.including("fast_run")
-        self.mode = self.mode.excluding("c_blas")  # c_blas trumps scipy Ops
-        dtype = self.dtype = "float64"  # optimization isn't dtype-dependent
-        self.A = tensor(dtype=dtype, shape=(None, None))
-        self.a = tensor(dtype=dtype, shape=())
-        self.x = tensor(dtype=dtype, shape=(None,))
-        self.y = tensor(dtype=dtype, shape=(None,))
-        self.Aval = np.ones((2, 3), dtype=dtype)
-        self.xval = np.asarray([1, 2], dtype=dtype)
-        self.yval = np.asarray([1.5, 2.7, 3.9], dtype=dtype)
-
-    def function(self, inputs, outputs):
-        return pytensor.function(inputs, outputs, self.mode)
-
-    def run_f(self, f):
-        f(self.Aval, self.xval, self.yval)
-        f(self.Aval[::-1, ::-1], self.xval[::-1], self.yval[::-1])
-
-    def b(self, bval):
-        return pt.as_tensor_variable(np.asarray(bval, dtype=self.dtype))
-
-    def test_outer(self):
-        f = self.function([self.x, self.y], outer(self.x, self.y))
-        self.assertFunctionContains(f, ScipyGer(destructive=True))
-
-    def test_A_plus_outer(self):
-        f = self.function([self.A, self.x, self.y], self.A + outer(self.x, self.y))
-        self.assertFunctionContains(f, ScipyGer(destructive=False))
-        self.run_f(f)  # DebugMode tests correctness
-
-    def test_A_plus_scaled_outer(self):
-        f = self.function(
-            [self.A, self.x, self.y], self.A + 0.1 * outer(self.x, self.y)
-        )
-        self.assertFunctionContains(f, ScipyGer(destructive=False))
-        self.run_f(f)  # DebugMode tests correctness
-
-    def test_scaled_A_plus_scaled_outer(self):
-        f = self.function(
-            [self.A, self.x, self.y], 0.2 * self.A + 0.1 * outer(self.x, self.y)
-        )
-        self.assertFunctionContains(f, gemm_no_inplace)
-        self.run_f(f)  # DebugMode tests correctness
-
-    def test_pickle(self):
-        out = ScipyGer(destructive=False)(self.A, self.a, self.x, self.y)
-        f = pytensor.function([self.A, self.a, self.x, self.y], out)
-        new_f = pickle.loads(pickle.dumps(f))
-
-        assert isinstance(new_f.maker.fgraph.toposort()[-1].op, ScipyGer)
-        assert np.allclose(
-            f(self.Aval, 1.0, self.xval, self.yval),
-            new_f(self.Aval, 1.0, self.xval, self.yval),
-        )
-
-
-class TestBlasStridesScipy(TestBlasStrides):
-    mode = pytensor.compile.get_default_mode()
-    mode = mode.including("fast_run").excluding("gpu", "c_blas")
diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index 9b4b8ebbb9..950b53850a 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -2092,9 +2092,9 @@ def is_super_shape(var1, var2):
 
 def test_matrix_vector_ops():
     """Test vecdot, matvec, and vecmat helper functions."""
-    rng = np.random.default_rng(seed=utt.fetch_seed())
+    rng = np.random.default_rng(2089)
 
-    # Create test data with batch dimension (2)
+    atol = 1e-7 if config.floatX == "float32" else 1e-15
     batch_size = 2
     dim_k = 4  # Common dimension
     dim_m = 3  # Matrix rows
@@ -2109,7 +2109,6 @@ def test_matrix_vector_ops():
     mat_kn_val = random(batch_size, dim_k, dim_n, rng=rng).astype(config.floatX)
     vec_k_val = random(batch_size, dim_k, rng=rng).astype(config.floatX)
 
-    # Create tensor variables with matching dtype
    mat_mk = tensor(
         name="mat_mk", shape=(batch_size, dim_m, dim_k), dtype=config.floatX
     )
@@ -2130,7 +2129,7 @@ def test_matrix_vector_ops():
     expected_vecdot = np.zeros((batch_size,), dtype=np.int32)
     for i in range(batch_size):
         expected_vecdot[i] = np.sum(vec_k_val[i] * vec_k_val[i])
-    np.testing.assert_allclose(result, expected_vecdot)
+    np.testing.assert_allclose(result, expected_vecdot, atol=atol)
 
     # Test 2: matvec - matrix-vector product
     matvec_out = matvec(mat_mk, vec_k)
@@ -2141,7 +2140,7 @@ def test_matrix_vector_ops():
     expected_matvec = np.zeros((batch_size, dim_m), dtype=config.floatX)
     for i in range(batch_size):
         expected_matvec[i] = np.dot(mat_mk_val[i], vec_k_val[i])
-    np.testing.assert_allclose(result_matvec, expected_matvec)
+    np.testing.assert_allclose(result_matvec, expected_matvec, atol=atol)
 
     # Test 3: vecmat - vector-matrix product
     vecmat_out = vecmat(vec_k, mat_kn)
@@ -2152,7 +2151,7 @@ def test_matrix_vector_ops():
     expected_vecmat = np.zeros((batch_size, dim_n), dtype=config.floatX)
     for i in range(batch_size):
         expected_vecmat[i] = np.dot(vec_k_val[i], mat_kn_val[i])
-    np.testing.assert_allclose(result_vecmat, expected_vecmat)
+    np.testing.assert_allclose(result_vecmat, expected_vecmat, atol=atol)
 
 
 class TestTensordot:
@@ -2797,7 +2796,7 @@ def test_Dot(self):
         bdvec_val = random(4, rng=rng)
         self._compile_and_check(
             [advec, bdvec],
-            [Dot()(advec, bdvec)],
+            [dot(advec, bdvec)],
             [advec_val, bdvec_val],
             (Dot, blas.Dot22, blas.Gemv, blas_c.CGemv),
         )
@@ -2809,7 +2808,7 @@ def test_Dot(self):
         bdmat_val = random(5, 3, rng=rng)
         self._compile_and_check(
             [admat, bdmat],
-            [Dot()(admat, bdmat)],
+            [dot(admat, bdmat)],
             [admat_val, bdmat_val],
             (Dot, blas.Dot22),
         )
@@ -2818,7 +2817,7 @@ def test_Dot(self):
         bdmat_val = random(4, 5, rng=rng)
         self._compile_and_check(
             [advec, bdmat],
-            [Dot()(advec, bdmat)],
+            [dot(advec, bdmat)],
             [advec_val, bdmat_val],
             (Dot, blas.Dot22, blas.Gemv, blas_c.CGemv),
         )
@@ -2827,7 +2826,7 @@ def test_Dot(self):
         admat_val = random(5, 4, rng=rng)
         self._compile_and_check(
             [admat, bdvec],
-            [Dot()(admat, bdvec)],
+            [dot(admat, bdvec)],
             [admat_val, bdvec_val],
             (Dot, blas.Dot22, blas.Gemv, blas_c.CGemv),
         )
diff --git a/tests/test_gradient.py b/tests/test_gradient.py
index 89712c19dd..8de9c24b18 100644
--- a/tests/test_gradient.py
+++ b/tests/test_gradient.py
@@ -32,7 +32,7 @@
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import Op
 from pytensor.scan.op import Scan
-from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, tanh
+from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, sqrt, tanh
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.random import RandomStream
 from pytensor.tensor.type import (
@@ -1143,6 +1143,24 @@ def test_benchmark(self, vectorize, benchmark):
         fn = function([x], jac_y, trust_input=True)
         benchmark(fn, np.array([0, 1, 2], dtype=x.type.dtype))
 
+    def test_benchmark_partial_jacobian(self, vectorize, benchmark):
+        # Example from https://github.com/jax-ml/jax/discussions/5904#discussioncomment-422956
+        N = 1000
+        rng = np.random.default_rng(2025)
+        x_test = rng.random((N,))
+
+        f_mat = rng.random((N, N))
+        x = vector("x", dtype="float64")
+
+        def f(x):
+            return sqrt(f_mat @ x / N)
+
+        full_jacobian = jacobian(f(x), x, vectorize=vectorize)
+        partial_jacobian = full_jacobian[:5, :5]
+
+        f = pytensor.function([x], partial_jacobian, trust_input=True)
+        benchmark(f, x_test)
+
 
 def test_hessian():
     x = vector()
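Editorial note, not part of the patch: the sketch below illustrates the idea exercised by the new test_benchmark_partial_jacobian above. Slicing the Jacobian expression before compilation lets the graph rewrites (including the subtensor-lift rules touched in this patch) avoid computing entries that are never used, so only a 5x5 block is evaluated. The shapes and the vectorize=True keyword mirror the test; treat this as an illustration under those assumptions, not as API documentation.

import numpy as np

import pytensor
from pytensor.gradient import jacobian
from pytensor.tensor import sqrt, vector

N = 1000
rng = np.random.default_rng(2025)
f_mat = rng.random((N, N))

x = vector("x", dtype="float64")
y = sqrt(f_mat @ x / N)  # same function as in the benchmark above

# Build the full Jacobian expression, then slice it *before* compiling,
# so rewrites can prune the rows and columns that are never needed.
partial_jac = jacobian(y, x, vectorize=True)[:5, :5]
fn = pytensor.function([x], partial_jac)

print(fn(rng.random(N)).shape)  # (5, 5)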