diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index a4a3d1840a..e8a6eb86b2 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -6,8 +6,11 @@ from functools import partial from typing import Union, cast -from pytensor.compile.function import function -from pytensor.compile.function.pfunc import rebuild_collect_shared +from pytensor.compile import get_default_mode, insert_deepcopy +from pytensor.compile.function.pfunc import pfunc, rebuild_collect_shared +from pytensor.compile.function.types import add_supervisor_to_fgraph +from pytensor.compile.io import In, Out +from pytensor.compile.mode import Mode from pytensor.compile.sharedvalue import SharedVariable from pytensor.configdefaults import config from pytensor.gradient import DisconnectedType, Rop, grad @@ -21,7 +24,7 @@ ) from pytensor.graph.fg import FunctionGraph from pytensor.graph.null_type import NullType -from pytensor.graph.op import HasInnerGraph, Op +from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType from pytensor.graph.replace import clone_replace from pytensor.graph.utils import MissingInputError @@ -433,6 +436,9 @@ def __init__( assert isinstance(name, str), "name must be None or string object" self.name = name self.destroy_map = destroy_map if destroy_map is not None else {} + self._rewritten_fgraph = {} + self._wrapped_inputs = {} + self._wrapped_outputs = {} def __eq__(self, other): # TODO: recognize a copy @@ -847,14 +853,58 @@ def infer_shape(self, fgraph, node, shapes): return ret + def _rewrite_fgraph(self, impl): + if self._rewritten_fgraph.get(impl, None) is None: + mode = get_default_mode() + if impl == "py": + mode = mode.excluding("cxx") + rewriter = mode.optimizer + + # We are cloning fgraph too many times, but one of the existing tests checks for this + # TestOpFromGraph.test_outputs_consistency + fgraph = self.fgraph.clone() + self._wrapped_inputs[impl] = temp_wrapped_inputs = [ + In(inp, borrow=False, mutable=False) for inp in fgraph.inputs + ] + # These are just temporary because the graph rewirite may change them + temp_wrapped_outputs = [ + Out(out, borrow=True) for out in self.fgraph.outputs + ] + add_supervisor_to_fgraph( + fgraph, + temp_wrapped_inputs, + accept_inplace=False, + ) + with config.change_flags(compute_test_value="off"): + rewriter(fgraph) + insert_deepcopy(fgraph, temp_wrapped_inputs, temp_wrapped_outputs) + self._wrapped_outputs[impl] = [ + Out(out, borrow=True) for out in fgraph.outputs + ] + self._rewritten_fgraph[impl] = fgraph + + return ( + self._rewritten_fgraph[impl], + self._wrapped_inputs[impl], + self._wrapped_outputs[impl], + ) + @property def fn(self): - """Lazily compile the inner function graph.""" if getattr(self, "_fn", None) is not None: return self._fn - self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs) - self._fn.trust_input = True + fgraph, wrapped_inputs, wrapped_outputs = self._rewrite_fgraph(impl=None) + + self._fn = pfunc( + wrapped_inputs, + wrapped_outputs, + mode=Mode(linker=get_default_mode().linker, optimizer=None), + accept_inplace=True, + on_unused_input="ignore", + fgraph=fgraph, + trust_input=True, + ) return self._fn @@ -871,6 +921,58 @@ def clone(self): res.fgraph = res.fgraph.clone() return res + def prepare_node( + self, + node: Apply, + storage_map: StorageMapType | None, + compute_map: ComputeMapType | None, + impl: str | None, + ) -> None: + self._rewrite_fgraph(impl) + self.fn + + def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): + from pytensor.link.vm import VMLinker + + self.prepare_node(node, storage_map, compute_map, impl) + fg, _, _ = self._rewrite_fgraph(impl) + fg_no_recycling = [ + new_o + for (new_o, old_o) in zip(fg.outputs, node.outputs, strict=True) + if old_o in no_recycling + ] + + node_input_storage = [storage_map[r] for r in node.inputs] + node_output_storage = [storage_map[r] for r in node.outputs] + node_compute_map = [compute_map[r] for r in node.outputs] + + def create_thunk(linker): + linker.accept(fg, no_recycling=fg_no_recycling) + thunk, _, _ = linker.make_thunk( + input_storage=node_input_storage, + output_storage=node_output_storage, + ) + return thunk + + def thunk_wrapper(thunk=thunk, node_compute_map=node_compute_map): + thunk() + for cm in node_compute_map: + cm[0] = True + + return thunk_wrapper + + if impl != "py": + # try: + # # We default to CLinker because it generates code for the whole graph that the compiler can reason about. + # # Whereas the VMLinker will compile each node separately and call them in a pre-defined VM. + # # It also has less overhead + # return create_thunk(linker=CLinker()) + # except NotImplementedError: + # # Some Op doesn't have a C implementation, VM it is + return create_thunk(VMLinker(use_cloop=True, c_thunks=True)) + else: + return create_thunk(VMLinker(use_cloop=False, c_thunks=False)) + def perform(self, node, inputs, outputs): variables = self.fn(*inputs) assert len(variables) == len(outputs) diff --git a/pytensor/link/c/c_code/lazylinker_c.c b/pytensor/link/c/c_code/lazylinker_c.c index 08f3e4d0fb..b1b0f9ee37 100644 --- a/pytensor/link/c/c_code/lazylinker_c.c +++ b/pytensor/link/c/c_code/lazylinker_c.c @@ -676,20 +676,7 @@ static int lazy_rec_eval(CLazyLinker *self, Py_ssize_t var_idx, PyObject *one, // rval is new ref if (rval) // pycall returned normally (no exception) { - if (rval == Py_None) { - Py_DECREF(rval); // ignore a return of None - } else if (PyList_Check(rval)) { - PyErr_SetString(PyExc_TypeError, - "non-lazy thunk should return None, not list"); - err = 1; - goto pyfail; - } else // don't know what it returned, but it wasn't right. - { - PyErr_SetObject(PyExc_TypeError, rval); - err = 1; - // We don't release rval since we put it in the error above - goto fail; - } + Py_DECREF(rval); // ignore whatever was returned } else // pycall returned NULL (internal error) { err = 1; @@ -981,7 +968,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = { }; static PyObject *get_version(PyObject *dummy, PyObject *args) { - PyObject *result = PyFloat_FromDouble(0.3); + PyObject *result = PyFloat_FromDouble(0.4); return result; } diff --git a/pytensor/link/c/lazylinker_c.py b/pytensor/link/c/lazylinker_c.py index ce67190342..f03f1cc008 100644 --- a/pytensor/link/c/lazylinker_c.py +++ b/pytensor/link/c/lazylinker_c.py @@ -14,7 +14,7 @@ _logger = logging.getLogger(__file__) force_compile = False -version = 0.3 # must match constant returned in function get_version() +version = 0.4 # must match constant returned in function get_version() lazylinker_ext: ModuleType | None = None diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 59148fae3b..6d559bff7d 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -1120,15 +1120,11 @@ def unconditional_constant_folding(fgraph, node): compute_map[o] = [False] thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[]) - required = thunk() - - # A node whose inputs are all provided should always return successfully - assert not required + thunk() rval = [] for output in node.outputs: data = storage_map[output][0] - assert compute_map[output][0], (output, data) # TODO: `Type` itself should provide an interface for constructing # instances appropriate for a given constant. diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index ba0257cdda..6017bc7a60 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -4,6 +4,7 @@ import pytest import pytensor.tensor as pt +from pytensor import scan from pytensor.compile import shared from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function @@ -15,9 +16,10 @@ grad, verify_grad, ) -from pytensor.graph.basic import equal_computations +from pytensor.graph.basic import Apply, equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.graph.null_type import NullType, null_type +from pytensor.graph.op import Op from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.graph.utils import MissingInputError from pytensor.printing import debugprint @@ -622,14 +624,15 @@ def test_outputs_consistency(self): """Make sure that `OpFromGraph.fn` doesn't change the value of `OpFromGraph.inner_outputs`.""" x = scalar("x") - op = OpFromGraph([x], [x**2 / x], mode="FAST_RUN") + op = OpFromGraph([x], [x**2 / x]) # Confirm that the inner-graph is as expected assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x]) # These outputs of the compiled `op.fgraph` should differ from the # original, uncompiled `op.fgraph` outputs - fn = op.fn + with config.change_flags(mode="FAST_RUN"): + fn = op.fn new_inputs = fn.maker.fgraph.inputs new_outputs = fn.maker.fgraph.outputs assert not equal_computations(new_outputs, [x**2 / x], new_inputs, [x]) @@ -740,3 +743,58 @@ def test_debugprint(): for truth, out in zip(exp_res.split("\n"), lines, strict=True): assert truth.strip() == out.strip() + + +@pytest.mark.parametrize("kind", ("ofg", "inlined", "scan")) +@pytest.mark.parametrize("c_op", (True, False), ids=lambda x: f"c_op={x}") +def test_benchmark(c_op, kind, benchmark): + class ExpWithoutC(Op): + def make_node(self, x): + return Apply(self, [x], [x.type()]) + + def perform(self, node, inputs, output_storage): + output_storage[0][0] = np.exp(inputs[0]) + + exp_without_c = ExpWithoutC() + + n = 25 + + def _f(x): + if isinstance(x, np.ndarray): + y = np.exp(x) + else: + if c_op: + y = pt.exp(x) + else: + y = exp_without_c(x) + y /= y.sum() + return y + + x = pt.vector("x") + + if kind == "ofg": + f = OpFromGraph([x], [_f(x)]) + else: + f = _f + + if kind == "scan": + # Scan is included for a reference of how bad the overhead can be + outs, _ = scan(fn=f, outputs_info=[x], n_steps=n) + out = outs[-1] + else: + out = x + for i in range(n): + out = f(out) + + compiled_fn = function([x], out, trust_input=True, mode="FAST_RUN") + compiled_fn.vm.allow_gc = False + + rng = np.random.default_rng(1) + x_test = rng.normal(size=(10,)) + + res = benchmark(compiled_fn, x_test) + + expected_res = x_test + for i in range(n): + expected_res = _f(expected_res) + np.testing.assert_allclose(res, expected_res)