diff --git a/pyadjoint/optimization/rol_solver.py b/pyadjoint/optimization/rol_solver.py index 06222478..290c080d 100644 --- a/pyadjoint/optimization/rol_solver.py +++ b/pyadjoint/optimization/rol_solver.py @@ -71,8 +71,8 @@ def scale(self, alpha): def riesz_map(self, derivs): dat = [] opts = {"riesz_representation": self.inner_product} - for deriv in Enlist(derivs): - dat.append(deriv._ad_convert_type(deriv, options=opts)) + for f, deriv in zip(self.dat, derivs): + dat.append(f._ad_convert_type(deriv, options=opts)) return dat def dot(self, yy): @@ -82,6 +82,15 @@ def dot(self, yy): res += x._ad_dot(y, options=opts) return res + def dual(self) -> "ROLVector": + """Create a new `ROLVector` in the dual space of the current `self`. + """ + dat = [] + opts = {"riesz_map": self.inner_product} + for x in self.dat: + dat.append(x._ad_riesz_representation(options=opts)) + return ROLVector(dat, inner_product=self.inner_product) + def norm(self): return self.dot(self) ** 0.5 @@ -123,12 +132,14 @@ def applyJacobian(self, jv, v, x, tol): self.con.jacobian_action(x.dat, v.dat[0], jv.dat) def applyAdjointJacobian(self, jv, v, x, tol): - self.con.jacobian_adjoint_action(x.dat, v.dat, jv.dat[0]) - jv.dat = jv.riesz_map(jv.dat) + tmp = jv.dual() + self.con.jacobian_adjoint_action(x.dat, v.dat, tmp.dat[0]) + jv.dat = jv.riesz_map(tmp.dat) def applyAdjointHessian(self, ahuv, u, v, x, tol): - self.con.hessian_action(x.dat, u.dat[0], v.dat, ahuv.dat[0]) - ahuv.dat = ahuv.riesz_map(ahuv.dat) + tmp = ahuv.dual() + self.con.hessian_action(x.dat, u.dat[0], v.dat, tmp.dat[0]) + ahuv.dat = ahuv.riesz_map(tmp.dat) class ROLSolver(OptimizationSolver): """ diff --git a/pyadjoint/overloaded_type.py b/pyadjoint/overloaded_type.py index 1d286aee..61b5f47e 100644 --- a/pyadjoint/overloaded_type.py +++ b/pyadjoint/overloaded_type.py @@ -116,6 +116,20 @@ def _ad_convert_type(self, value, options={}): """ raise NotImplementedError(f"OverloadedType._ad_convert_type not defined for class {type(self)}.") + def _ad_riesz_representation(self, options={}): + """This method must be overridden. + + Should implement a way to return the Riesz representation of the overloaded object. + + Args: + options (dict): A dictionary with options that may be supplied by the user. If the Riesz representation + functionality offers some options on how to compute it, this is the dictionary that should be used. + + Returns: + OverloadedType: The Riesz representation of the overloaded object. + """ + raise NotImplementedError(f"OverloadedType._riesz_representation not defined for class {type(self)}.") + def _ad_create_checkpoint(self): """This method must be overridden.