diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index 404a3663..01e727ad 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -41,6 +41,8 @@ class PETScVecInterface: def __init__(self, x, *, comm=None): if PETSc is None: raise RuntimeError("PETSc not available") + if petsctools is None: + raise RuntimeError("petsctools not available") x = Enlist(x) if comm is None: @@ -345,8 +347,9 @@ def mult(self, A, x, y): to_petsc(ub_vec, ubs) tao.setVariableBounds(lb_vec, ub_vec) - self.options = OptionsManager(parameters, None) - self.options.set_from_options(tao) + petsctools.set_from_options( + tao, parameters=parameters, + options_prefix=None) if tao.getType() in {PETSc.TAO.Type.LMVM, PETSc.TAO.Type.BLMVM}: class InitialHessian: @@ -391,7 +394,7 @@ def apply(self, pc, x, y): x = vec_interface.new_petsc() tao.setSolution(x) - with self.options.inserted_options(): + with petsctools.inserted_options(tao): tao.setUp() super().__init__(problem, parameters) @@ -432,7 +435,7 @@ def solve(self): controls = self.tao_objective.reduced_functional.controls m = tuple(control.tape_value()._ad_copy() for control in controls) self._vec_interface.to_petsc(self.x, m) - with self.options.inserted_options(): + with petsctools.inserted_options(self.tao): self.tao.solve() self._vec_interface.from_petsc(self.x, m) if self.tao.getConvergedReason() <= 0: diff --git a/pyproject.toml b/pyproject.toml index a0eb58e3..996dfaa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,13 +38,38 @@ all = [ "sphinxcontrib-bibtex", "tensorflow", ] -doc = ["sphinx", "sphinx-autobuild", "sphinxcontrib-bibtex"] -meshing = ["pygmsh", "meshio"] -moola = ["moola>=0.1.6"] -test = ["pytest>=3.10", "flake8", "coverage"] -visualisation = ["tensorflow", "protobuf", "networkx", "pygraphviz"] -tao = ["petsc4py", "petsctools"] +doc = [ + "sphinx", + "sphinx-autobuild", + "sphinxcontrib-bibtex" +] +meshing = [ + "pygmsh", + "meshio" +] +moola = [ + "moola>=0.1.6" +] +test = [ + "pytest>=3.10", + "flake8", + "coverage" +] +visualisation = [ + "tensorflow", + "protobuf", + "networkx", + "pygraphviz" +] +tao = [ + "petsc4py", + "petsctools>2025.0" +] [tool.setuptools] -packages = ["numpy_adjoint", "pyadjoint", "pyadjoint.optimization"] +packages = [ + "numpy_adjoint", + "pyadjoint", + "pyadjoint.optimization" +]