From 90903545fdf63d6b2fc312f86f7791ae2bbac4d4 Mon Sep 17 00:00:00 2001
From: mikeevmm <miguelmurca@gmail.com>
Date: Sun, 4 Apr 2021 19:21:53 +0100
Subject: [PATCH 1/4] feat: implementation of bounded LBFGS algorithm

---
 .../python/optimizer/__init__.py              |    2 +
 .../python/optimizer/lbfgsb.py                | 1625 +++++++++++++++++
 2 files changed, 1627 insertions(+)
 create mode 100644 tensorflow_probability/python/optimizer/lbfgsb.py

diff --git a/tensorflow_probability/python/optimizer/__init__.py b/tensorflow_probability/python/optimizer/__init__.py
index 2187b56593..7ad87a8b93 100644
--- a/tensorflow_probability/python/optimizer/__init__.py
+++ b/tensorflow_probability/python/optimizer/__init__.py
@@ -27,6 +27,7 @@
 from tensorflow_probability.python.optimizer.differential_evolution import minimize as differential_evolution_minimize
 from tensorflow_probability.python.optimizer.differential_evolution import one_step as differential_evolution_one_step
 from tensorflow_probability.python.optimizer.lbfgs import minimize as lbfgs_minimize
+from tensorflow_probability.python.optimizer.lbfgsb import minimize as lbfgsb_minimize
 from tensorflow_probability.python.optimizer.nelder_mead import minimize as nelder_mead_minimize
 from tensorflow_probability.python.optimizer.nelder_mead import nelder_mead_one_step
 from tensorflow_probability.python.optimizer.proximal_hessian_sparse import minimize as proximal_hessian_sparse_minimize
@@ -42,6 +43,7 @@
     'differential_evolution_minimize',
     'differential_evolution_one_step',
     'lbfgs_minimize',
+    'lbfgsb_minimize',
     'nelder_mead_minimize',
     'nelder_mead_one_step',
     'proximal_hessian_sparse_minimize',
diff --git a/tensorflow_probability/python/optimizer/lbfgsb.py b/tensorflow_probability/python/optimizer/lbfgsb.py
new file mode 100644
index 0000000000..4e7c9e6d8e
--- /dev/null
+++ b/tensorflow_probability/python/optimizer/lbfgsb.py
@@ -0,0 +1,1625 @@
+# Copyright 2018 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""A constrained version of the Limited-Memory BFGS minimization algorithm.
+
+Limited-memory quasi-Newton methods are useful for solving large problems
+whose Hessian matrices cannot be computed at a reasonable cost or are not
+sparse. Instead of storing fully dense n x n approximations of Hessian
+matrices, they only save a few vectors of length n that represent the
+approximations implicitly.
+
+This module implements the algorithm known as L-BFGS-B, which, as its name
+suggests, is a limited-memory version of the BFGS algorithm, with bounds.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Dependency imports
+import tensorflow.compat.v2 as tf
+
+from tensorflow_probability.python.internal import dtype_util
+from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow_probability.python.optimizer import bfgs_utils
+from tensorflow_probability.python.optimizer import lbfgs
+
+
+LBfgsBOptimizerResults = collections.namedtuple(
+    'LBfgsBOptimizerResults', [
+        'converged',  # Scalar boolean tensor indicating whether the minimum
+                      # was found within tolerance.
+        'failed',  # Scalar boolean tensor indicating whether a line search
+                   # step failed to find a suitable step size satisfying Wolfe
+                   # conditions. In the absence of any constraints on the
+                   # number of objective evaluations permitted, this value will
+                   # be the complement of `converged`. However, if there is
+                   # a constraint and the search stopped due to available
+                   # evaluations being exhausted, both `failed` and `converged`
+                   # will be simultaneously False.
+        'num_iterations',  # The number of iterations of the BFGS update.
+        'num_objective_evaluations',  # The total number of objective
+                                      # evaluations performed.
+        'position',  # A tensor containing the last argument value found
+                     # during the search. If the search converged, then
+                     # this value is the argmin of the objective function.
+        'lower_bounds',  # A tensor containing the lower bounds of the
+                         # constrained optimization, broadcast to the shape
+                         # of `position`.
+        'upper_bounds',  # A tensor containing the upper bounds of the
+                         # constrained optimization, broadcast to the shape
+                         # of `position`.
+        'objective_value',  # A tensor containing the value of the objective
+                            # function at the `position`. If the search
+                            # converged, then this is the (local) minimum of
+                            # the objective function.
+        'objective_gradient',  # A tensor containing the gradient of the
+                               # objective function at the
+                               # `final_position`. If the search converged
+                               # the max-norm of this tensor should be
+                               # below the tolerance.
+        'position_deltas',  # A tensor encoding information about the latest
+                            # changes in `position` during the algorithm
+                            # execution. Its shape is of the form
+                            # `(num_correction_pairs,) + position.shape` where
+                            # `num_correction_pairs` is given as an argument to
+                            # the minimize function.
+        'gradient_deltas',  # A tensor encoding information about the latest
+                            # changes in `objective_gradient` during the
+                            # algorithm execution. Has the same shape as
+                            # position_deltas.
+        'history',  # Number of correction pairs (position/gradient deltas)
+                    # currently stored, per batch member.
+    ])
+
+_ConstrainedCauchyState = collections.namedtuple(
+    '_ConstrainedCauchyResult', [
+        'theta', # `\theta` in [2]; in the Cauchy search, relates to the implicit
+                 # Hessian `B = \theta*I - WMW'` (`I` the identity, see [1,2] for details)
+        'm', # `M_k` matrix in [2]; part of the implicit representation of the Hessian,
+             # see the comment above
+        'breakpoints', # `t_i` in [Byrd et al.][2];
+                       # the breakpoints in the branch definition of the
+                       # projection of the gradients, batched
+        'steepest', # `d` in [2]; steepest descent clamped to bounds
+        'free_vars_idx', # `\mathcal{F}` of [2]; the indices of (currently) free variables.
+                         # Indices that are no longer free are marked with a negative value.
+                         # This is used instead of a ragged tensor because the size of the
+                         # state object must remain constant between iterations of the
+                         # while loop.
+        'free_mask', # Boolean mask of free variables
+        'p', # `p = W^T d` in [2]; the steepest descent direction projected
+             # onto the space of the correction pairs (`W = [Y theta*S]`)
+        'c', # `c` in [2]; accumulates `W^T (x^cp - x)` along the piecewise path
+        'df', # `f'` in [2]; first derivative of the quadratic model along the path
+        'ddf', # `f''` in [2]; second derivative of the quadratic model along the path
+        'dt_min', # `\Delta t_min` in [2]; the minimizing parameter
+                  # along the search direction
+        'breakpoint_min', # `t` in [2]
+        'breakpoint_min_idx', # `b` in [2]
+        'dt', # `\Delta t` in [2]
+        'breakpoint_min_old', # `t_old` in [2]
+        'cauchy_point', # `x^cp` in [2]; the (generalized) Cauchy point being computed
+        'active', # Boolean mask of batch members still iterating in the Cauchy search
+    ])
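+
+# NOTE: Throughout the Cauchy point computation, the limited-memory Hessian
+#  approximation of [2] is used implicitly in the form
+#      B = theta*I - W M W^T,    W = [Y  theta*S],
+#  where S and Y collect the stored position and gradient deltas, and M is the
+#  small `2m x 2m` matrix assembled in `_cauchy_init_m`.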
+
+def minimize(value_and_gradients_function,
+             initial_position,
+             bounds=None,
+             previous_optimizer_results=None,
+             num_correction_pairs=10,
+             tolerance=1e-8,
+             x_tolerance=0,
+             f_relative_tolerance=0,
+             initial_inverse_hessian_estimate=None,
+             max_iterations=50,
+             parallel_iterations=1,
+             stopping_condition=None,
+             max_line_search_iterations=50,
+             name=None):
+  """Applies the L-BFGS-B algorithm to minimize a differentiable function.
+
+  Performs optionally constrained minimization of a differentiable function using the
+  L-BFGS-B scheme. See [Nocedal and Wright(2006)][1] for details on the unconstrained
+  version, and [Byrd et al.][2] for details on the constrained algorithm.
+
+  ### Usage:
+
+  The following example (an illustrative sketch, assuming the usual `np`, `tf`
+  and `tfp` imports) applies the L-BFGS-B optimizer to find the constrained
+  minimum of a simple high-dimensional quadratic objective function.
+
+  ```python
+    # A high-dimensional quadratic bowl.
+    ndims = 60
+    minimum = np.ones([ndims], dtype='float64')
+    scales = np.arange(ndims, dtype='float64') + 1.0
+
+    # The objective function and the gradient.
+    def quadratic_loss_and_gradient(x):
+      return tfp.math.value_and_gradient(
+          lambda x: tf.reduce_sum(
+              scales * tf.math.squared_difference(x, minimum), axis=-1),
+          x)
+
+    start = np.arange(ndims, 0, -1, dtype='float64')
+    # Constrain every coordinate to [0, 0.8]. The unconstrained minimum
+    # (all ones) violates the upper bound, so the constrained minimum is
+    # expected to lie on the boundary.
+    lower = np.zeros([ndims], dtype='float64')
+    upper = 0.8 * np.ones([ndims], dtype='float64')
+
+    optim_results = tfp.optimizer.lbfgsb_minimize(
+        quadratic_loss_and_gradient,
+        initial_position=start,
+        bounds=(lower, upper),
+        num_correction_pairs=10,
+        tolerance=1e-8)
+  ```
+
+  ### References:
+
+  [1] Jorge Nocedal, Stephen Wright. Numerical Optimization. Springer Series
+      in Operations Research. pp 176-180. 2006
+
+  http://pages.mtu.edu/~struther/Courses/OLD/Sp2013/5630/Jorge_Nocedal_Numerical_optimization_267490.pdf
+
+  [2] Richard H. Byrd, Peihuang Lu, Jorge Nocedal, & Ciyou Zhu (1995).
+      A Limited Memory Algorithm for Bound Constrained Optimization
+      SIAM Journal on Scientific Computing, 16(5), 1190–1208.
+
+  https://doi.org/10.1137/0916069
+
+  Args:
+    value_and_gradients_function:  A Python callable that accepts a point as a
+      real `Tensor` and returns a tuple of `Tensor`s of real dtype containing
+      the value of the function and its gradient at that point. The function
+      to be minimized. The input is of shape `[..., n]`, where `n` is the size
+      of the domain of input points, and all others are batching dimensions.
+      The first component of the return value is a real `Tensor` of matching
+      shape `[...]`. The second component (the gradient) is also of shape
+      `[..., n]` like the input value to the function.
+    initial_position: Real `Tensor` of shape `[..., n]`. The starting point, or
+      points when using batching dimensions, of the search procedure. At these
+      points the function value and the gradient norm should be finite.
+      Exactly one of `initial_position` and `previous_optimizer_results` can be
+      non-None.
+    bounds: Tuple of two real `Tensor`s of shape `[..., n]`. The first element
+      holds the lower bounds of the constrained optimization and the second
+      element holds the upper bounds. If `bounds` is `None`, the optimization
+      is deferred to the unconstrained version (see also `lbfgs_minimize`). If
+      only one of the elements of the tuple is `None`, the optimization is
+      assumed to be unconstrained from below/above, respectively.
+    previous_optimizer_results: An `LBfgsBOptimizerResults` namedtuple to
+      initialize the optimizer state from, instead of an `initial_position`.
+      This can be passed in from a previous return value to resume optimization
+      with a different `stopping_condition`. Exactly one of `initial_position`
+      and `previous_optimizer_results` can be non-None.
+    num_correction_pairs: Positive integer. Specifies the maximum number of
+      (position_delta, gradient_delta) correction pairs to keep as implicit
+      approximation of the Hessian matrix.
+    tolerance: Scalar `Tensor` of real dtype. Specifies the gradient tolerance
+      for the procedure. If the supremum norm of the gradient vector is below
+      this number, the algorithm is stopped.
+    x_tolerance: Scalar `Tensor` of real dtype. If the absolute change in the
+      position between one iteration and the next is smaller than this number,
+      the algorithm is stopped.
+    f_relative_tolerance: Scalar `Tensor` of real dtype. If the relative change
+      in the objective value between one iteration and the next is smaller
+      than this value, the algorithm is stopped.
+    initial_inverse_hessian_estimate: None. Option currently not supported.
+    max_iterations: Scalar positive int32 `Tensor`. The maximum number of
+      iterations for L-BFGS updates.
+    parallel_iterations: Positive integer. The number of iterations allowed to
+      run in parallel.
+    stopping_condition: (Optional) A Python function that takes as input two
+      Boolean tensors of shape `[...]`, and returns a Boolean scalar tensor.
+      The input tensors are `converged` and `failed`, indicating the current
+      status of each respective batch member; the return value states whether
+      the algorithm should stop. The default is tfp.optimizer.converged_all
+      which only stops when all batch members have either converged or failed.
+      An alternative is tfp.optimizer.converged_any which stops as soon as one
+      batch member has converged, or when all have failed.
+    max_line_search_iterations: Python int. The maximum number of iterations
+      for the `hager_zhang` line search algorithm.
+    name: (Optional) Python str. The name prefixed to the ops created by this
+      function. If not supplied, the default name 'minimize' is used.
+
+  Returns:
+    optimizer_results: A namedtuple containing the following items:
+      converged: Scalar boolean tensor indicating whether the minimum was
+        found within tolerance.
+      failed:  Scalar boolean tensor indicating whether a line search
+        step failed to find a suitable step size satisfying Wolfe
+        conditions. In the absence of any constraints on the
+        number of objective evaluations permitted, this value will
+        be the complement of `converged`. However, if there is
+        a constraint and the search stopped due to available
+        evaluations being exhausted, both `failed` and `converged`
+        will be simultaneously False.
+      num_iterations: The number of iterations of the BFGS update.
+      num_objective_evaluations: The total number of objective
+        evaluations performed.
+      position: A tensor containing the last argument value found
+        during the search. If the search converged, then
+        this value is the argmin of the objective function.
+      objective_value: A tensor containing the value of the objective
+        function at the `position`. If the search converged, then this is
+        the (local) minimum of the objective function.
+      objective_gradient: A tensor containing the gradient of the objective
+        function at the `position`. If the search converged the
+        max-norm of this tensor should be below the tolerance.
+      position_deltas: A tensor encoding information about the latest
+        changes in `position` during the algorithm execution.
+      gradient_deltas: A tensor encoding information about the latest
+        changes in `objective_gradient` during the algorithm execution.
+      lower_bounds: A tensor containing the lower bounds, broadcast to the
+        shape of `position`.
+      upper_bounds: A tensor containing the upper bounds, broadcast to the
+        shape of `position`.
+      history: The number of correction pairs currently stored for each
+        batch member.
+  """
+
+  def _lbfgs_defer():
+    # Without bounds, defer to the unconstrained L-BFGS implementation.
+    return lbfgs.minimize(
+        value_and_gradients_function,
+        initial_position=initial_position,
+        previous_optimizer_results=previous_optimizer_results,
+        num_correction_pairs=num_correction_pairs,
+        tolerance=tolerance,
+        x_tolerance=x_tolerance,
+        f_relative_tolerance=f_relative_tolerance,
+        initial_inverse_hessian_estimate=initial_inverse_hessian_estimate,
+        max_iterations=max_iterations,
+        parallel_iterations=parallel_iterations,
+        stopping_condition=stopping_condition,
+        max_line_search_iterations=max_line_search_iterations,
+        name=name)
+
+  if bounds is None:
+    return _lbfgs_defer()
+    
+  if len(bounds) != 2:
+    raise ValueError(
+      '`bounds` parameter has unexpected number of elements '
+      '(expected 2).')
+
+  lower_bounds, upper_bounds = bounds
+  
+  if lower_bounds is None and upper_bounds is None:
+    return _lbfgs_defer()
+  # Defer further conversion of the bounds to appropriate tensors
+  # until the shape of the input is known
+
+  if initial_inverse_hessian_estimate is not None:
+    raise NotImplementedError(
+        'Support of initial_inverse_hessian_estimate arg not yet implemented')
+
+  if stopping_condition is None:
+    stopping_condition = bfgs_utils.converged_all
+
+  with tf.name_scope(name or 'minimize'):
+    if (initial_position is None) == (previous_optimizer_results is None):
+      raise ValueError(
+          'Exactly one of `initial_position` or '
+          '`previous_optimizer_results` may be specified.')
+
+    if initial_position is not None:
+      initial_position = tf.convert_to_tensor(
+          initial_position, name='initial_position')
+      # Force at least one batching dimension
+      if len(ps.shape(initial_position)) == 1:
+        initial_position = initial_position[tf.newaxis, :]
+      position_shape = ps.shape(initial_position)
+      dtype = dtype_util.base_dtype(initial_position.dtype)
+
+    if previous_optimizer_results is not None:
+      position_shape = ps.shape(previous_optimizer_results.position)
+      dtype = dtype_util.base_dtype(previous_optimizer_results.position.dtype)
+
+    # TODO: This isn't agnostic to the number of batch dimensions, it only
+    #  supports one batch dimension, but I've found RaggedTensors to be far
+    #  too finicky/undocumented to handle multiple batch dimensions in any
+    #  sane way. (Even the way it's working so far is less than ideal.) 
+    if len(position_shape) > 2:
+      raise NotImplementedError("More than a batch dimension is not implemented. "
+                                "Consider flattening and then reshaping the results.") 
+    # NOTE: Broadcasting the batched dimensions breaks when there are no
+    #  batched dimensions. Although this isn't handled like this in
+    #  `lbfgs.py`, I'd rather force a batch dimension with a single
+    #  element than do conditional checks later.
+    if len(position_shape) == 1:
+      position_shape = tf.concat([[1], position_shape], axis=0)
+      initial_position = tf.broadcast_to(initial_position, position_shape)
+
+    # NOTE: Could maybe use bfgs_utils._broadcast here, but would have to check
+    #  that the non-batching dimensions also match; using `tf.broadcast_to` has
+    #  the advantage that passing a (1,)-shaped tensor as bounds will correctly
+    #  bound every variable at the single value.
+    if lower_bounds is None:
+      lower_bounds = tf.constant(
+        [-float('inf')], shape=position_shape, dtype=dtype, name='lower_bounds')
+    else:
+      lower_bounds = tf.cast(tf.convert_to_tensor(lower_bounds), dtype=dtype)
+      try:
+        lower_bounds = tf.broadcast_to(
+          lower_bounds, position_shape, name='lower_bounds')
+      except tf.errors.InvalidArgumentError:
+        raise ValueError(
+          'Failed to broadcast lower bounds tensor to the shape of starting position. '
+          'Are the lower bounds well formed?')
+    if upper_bounds is None:
+      upper_bounds = tf.constant(
+        [float('inf')], shape=position_shape, dtype=dtype, name='upper_bounds')
+    else:
+      upper_bounds = tf.cast(tf.convert_to_tensor(upper_bounds), dtype=dtype)
+      try:
+        upper_bounds = tf.broadcast_to(
+          upper_bounds, position_shape, name='upper_bounds')
+      except tf.errors.InvalidArgumentError:
+        raise ValueError(
+          'Failed to broadcast upper bounds tensor to the shape of starting position. '
+          'Are the upper bounds well formed?')
+
+    # Clamp the starting position to the bounds, because the algorithm expects the
+    # variables to be in range for the Hessian inverse estimation, but also because
+    # that fast-tracks the first iteration of the Cauchy optimization.
+    initial_position = tf.clip_by_value(initial_position, lower_bounds, upper_bounds)
+
+    tolerance = tf.convert_to_tensor(
+        tolerance, dtype=dtype, name='grad_tolerance')
+    f_relative_tolerance = tf.convert_to_tensor(
+        f_relative_tolerance, dtype=dtype, name='f_relative_tolerance')
+    x_tolerance = tf.convert_to_tensor(
+        x_tolerance, dtype=dtype, name='x_tolerance')
+    max_iterations = tf.convert_to_tensor(max_iterations, name='max_iterations')
+
+    # The `state` here is a `LBfgsBOptimizerResults` tuple with values for the
+    # current state of the algorithm computation.
+    def _cond(state):
+      """Continue if iterations remain and stopping condition is not met."""
+      return ((state.num_iterations < max_iterations) &
+              tf.logical_not(stopping_condition(state.converged, state.failed)))
+
+    def _body(current_state):
+      """Main optimization loop."""
+      current_state = bfgs_utils.terminate_if_not_finite(current_state)
+  
+      cauchy_point, free_mask = \
+        _cauchy_minimization(current_state, num_correction_pairs, parallel_iterations)
+
+      search_direction = _get_search_direction(current_state)
+
+      # TODO(b/120134934): Check if the derivative at the start point is not
+      # negative, if so then reset position/gradient deltas and recompute
+      # search direction.
+      # NOTE: Erasing is currently handled in `_bounded_line_search_step`
+      search_direction = tf.where(
+                          free_mask,
+                          search_direction,
+                          0.)
+      bad_direction = \
+        (tf.reduce_sum(search_direction * current_state.objective_gradient, axis=-1) > 0)
+
+      cauchy_search = _cauchy_line_search_step(current_state,
+          value_and_gradients_function, search_direction,
+          tolerance, f_relative_tolerance, x_tolerance, stopping_condition,
+          max_line_search_iterations, free_mask, cauchy_point)
+      
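+      # The second line search restarts from the current position and moves
+      #  toward the point found by the Cauchy-constrained search above,
+      #  clamping the step to the bounds.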
+      search_direction = cauchy_search.position - current_state.position
+      next_state = _bounded_line_search_step(current_state,
+          value_and_gradients_function, search_direction,
+          tolerance, f_relative_tolerance, x_tolerance, stopping_condition,
+          max_line_search_iterations, bad_direction)
+
+      # If not failed or converged, update the Hessian estimate.
+      # Only push the new correction pair if it satisfies the curvature
+      # condition s.y > 0.
+      position_delta = next_state.position - current_state.position
+      gradient_delta = next_state.objective_gradient - current_state.objective_gradient
+      positive_prod = (tf.math.reduce_sum(position_delta * gradient_delta, axis=-1) > \
+                        1E-8*tf.reduce_sum(gradient_delta**2, axis=-1))
+      should_push = ~(next_state.converged | next_state.failed) & positive_prod & ~bad_direction
+      new_position_deltas = _queue_push(
+              next_state.position_deltas, should_push, position_delta)
+      new_gradient_deltas = _queue_push(
+              next_state.gradient_deltas, should_push, gradient_delta)
+      new_history = tf.where(
+              should_push,
+              tf.math.minimum(next_state.history + 1, num_correction_pairs),
+              next_state.history)
+      
+      if not tf.executing_eagerly():
+        # Hint the compiler that the shape of the properties has not changed
+        new_position_deltas = tf.ensure_shape(
+          new_position_deltas, next_state.position_deltas.shape)
+        new_gradient_deltas = tf.ensure_shape(
+          new_gradient_deltas, next_state.gradient_deltas.shape)
+        new_history = tf.ensure_shape(
+          new_history, next_state.history.shape)
+
+      state_after_inv_hessian_update = bfgs_utils.update_fields(
+          next_state,
+          position_deltas=new_position_deltas,
+          gradient_deltas=new_gradient_deltas,
+          history=new_history)
+
+      return [state_after_inv_hessian_update]
+
+    if previous_optimizer_results is None:
+      assert initial_position is not None
+      initial_state = _get_initial_state(value_and_gradients_function,
+                                          initial_position,
+                                          lower_bounds,
+                                          upper_bounds,
+                                          num_correction_pairs,
+                                          tolerance)
+    else:
+      initial_state = previous_optimizer_results
+
+    return tf.while_loop(
+        cond=_cond,
+        body=_body,
+        loop_vars=[initial_state],
+        parallel_iterations=parallel_iterations)[0]
+
+
+def _cauchy_minimization(bfgs_state, num_correction_pairs, parallel_iterations):
+  """Calculates the Cauchy point (minimizes the quadratic approximation to the
+  objective function at the current position, in the direction of steepest
+  descent), but bounding the gradient by the corresponding bounds.
+
+  See algorithm CP and associated discussion of [Byrd,Lu,Nocedal,Zhu][2]
+  for details.
+
+  Args:
+    bfgs_state: A `_ConstrainedCauchyState` initialized to the starting point of the
+    constrained minimization.
+  Returns:
+    A potentially modified `state`, the obtained `cauchy_point` and boolean
+    `free_mask` indicating which variables are free (`True`) and which variables
+    are under active constrain (`False`)
+  """
+  cauchy_state = _get_initial_cauchy_state(bfgs_state, num_correction_pairs)
+  # NOTE: See lbfgsb.f (l. 1649)
+  ddf_org = -cauchy_state.theta * cauchy_state.df
+
+  def _cond(state):
+    """Test convergence to Cauchy point at current branch"""
+    return tf.math.reduce_any(state.active)
+
+  def _body(state):
+    """Cauchy point iterative loop
+    
+    (While loop of CP algorithm [2])"""
+    # Remove b from the free indices
+    free_vars_idx, free_mask = _cauchy_remove_breakpoint_min(
+                                state.free_vars_idx,
+                                state.breakpoint_min_idx,
+                                state.free_mask,
+                                state.active)
+
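+    # `d_b` and `x_b` are, respectively, the steepest-descent component and
+    #  the current coordinate of the bounding variable `b` found in the
+    #  previous iteration; `x_cp_b` is the bound that variable hits (or `x_b`
+    #  if its descent component is zero).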
+    # Shape: [b]
+    d_b = tf.where(
+            state.active,
+            tf.gather(
+              state.steepest,
+              state.breakpoint_min_idx,
+              batch_dims=1),
+            0.)
+    # Shape: [b]
+    x_b = tf.where(
+            state.active,
+            tf.gather(
+              bfgs_state.position,
+              state.breakpoint_min_idx,
+              batch_dims=1),
+            0.)
+
+    # Shape: [b]
+    x_cp_b = tf.where(
+              state.active,
+              tf.where(
+                d_b > 0.,
+                tf.gather(
+                  bfgs_state.upper_bounds,
+                  state.breakpoint_min_idx,
+                  batch_dims=1),
+                tf.where(
+                  d_b < 0.,
+                  tf.gather(
+                    bfgs_state.lower_bounds,
+                    state.breakpoint_min_idx,
+                    batch_dims=1),
+                  x_b)),
+              tf.gather(
+                state.cauchy_point,
+                state.breakpoint_min_idx,
+                batch_dims=1))
+
+    keep_idx = (tf.range(ps.shape(state.cauchy_point)[-1]) != \
+                  state.breakpoint_min_idx[..., tf.newaxis])
+    cauchy_point = tf.where(
+                    state.active[..., tf.newaxis],
+                    tf.where(
+                      keep_idx,
+                      state.cauchy_point,
+                      x_cp_b[..., tf.newaxis]),
+                    state.cauchy_point)
+
+    z_b = tf.where(
+            state.active,
+            x_cp_b - x_b,
+            0.)
+
+    c = tf.where(
+        state.active[..., tf.newaxis],
+        state.c + state.dt[...,tf.newaxis] * state.p,
+        state.c)
+    
+    # The matrix M has shape
+    #
+    #  [[ 0  0   ]
+    #   [ 0  M_h ]]
+    # 
+    # where M_h is the M matrix considering the current history `h`.
+    # Therefore, for W, we should consider that the last `h` columns
+    #  are
+    #     Y[k-h,...,k-1] theta*S[k-h,...k-1]
+    #  (so that the first `2*(m-h)` columns are 0).
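+    #
+    # For example (illustrative), with `m = 3` correction pairs and a current
+    #  history `h = 1`, the gather below permutes the assembled row of W so
+    #  that the four columns for empty history slots come first and the two
+    #  "live" columns (`y_{k-1}`, `theta*s_{k-1}`) come last, matching the
+    #  layout of M above.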
+
+    # 1. Create the "full" W matrix row
+    # TODO: Transpose seems inevitable, because of batch dims?
+    w_b = tf.concat(
+              [
+                tf.gather(
+                  tf.transpose(
+                    bfgs_state.gradient_deltas,
+                    perm=[1,0,2]),
+                  state.breakpoint_min_idx,
+                  axis=-1,
+                  batch_dims=1),
+                state.theta[..., tf.newaxis] * \
+                  tf.gather(
+                    tf.transpose(
+                      bfgs_state.position_deltas,
+                      perm=[1,0,2]),
+                    state.breakpoint_min_idx,
+                    axis=-1,
+                    batch_dims=1)
+              ],
+              axis=-1)
+    # 2. "Permute" the relevant items to the right
+    idx = tf.concat(
+            [
+              tf.ragged.range(
+                num_correction_pairs - bfgs_state.history),
+              tf.ragged.range(
+                num_correction_pairs,
+                2*num_correction_pairs - bfgs_state.history),
+              tf.ragged.range(
+                num_correction_pairs - bfgs_state.history,
+                num_correction_pairs),
+              tf.ragged.range(
+                2*num_correction_pairs - bfgs_state.history,
+                2*num_correction_pairs)
+            ],
+            axis=-1).to_tensor()
+    w_b = tf.gather(
+              w_b,
+              idx,
+              batch_dims=1)
+
+    # NOTE Use of d_b = -g_b
+    df = tf.where(
+          state.active,
+          state.df + state.dt * state.ddf + \
+            d_b**2 - \
+            state.theta * d_b * z_b + \
+            d_b * tf.einsum(
+                    '...j,...jk,...k->...',
+                    w_b,
+                    state.m,
+                    c),
+          state.df)
+          
+    # NOTE use of d_b = -g_b
+    ddf = tf.where(
+            state.active,
+            state.ddf - state.theta * d_b**2 + \
+              2. * d_b * tf.einsum(
+                          "...i,...ij,...j->...",
+                          w_b,
+                          state.m,
+                          state.p) - \
+              d_b**2 * tf.einsum(
+                        "...i,...ij,...j->...",
+                        w_b,
+                        state.m,
+                        w_b),
+            state.ddf)
+    # NOTE: See lbfgsb.f (l. 1649)
+    # TODO: How to get machine epsilon?
+    ddf = tf.math.maximum(ddf, 1E-8*ddf_org)
+
+    # NOTE use of d_b = -g_b
+    p = tf.where(
+          state.active[..., tf.newaxis],
+          state.p - d_b[..., tf.newaxis] * w_b,
+          state.p)
+
+    steepest_idx = tf.range(
+        ps.shape(state.steepest)[-1],
+        dtype=state.breakpoint_min_idx.dtype)[tf.newaxis, ...]
+    steepest = tf.where(
+      state.active[..., tf.newaxis],
+      tf.where(
+        steepest_idx == state.breakpoint_min_idx[..., tf.newaxis],
+        0.,
+        state.steepest),
+      state.steepest)
+    
+    dt_min = tf.where(
+              state.active,
+              -tf.math.divide_no_nan(df, ddf),
+              state.dt_min)
+
+    breakpoint_min_old = tf.where(
+                          state.active,
+                          state.breakpoint_min,
+                          state.breakpoint_min_old)
+    
+    # Find b
+    breakpoint_min_idx, breakpoint_min = \
+      _cauchy_get_breakpoint_min(
+        state.breakpoints,
+        free_vars_idx)
+    breakpoint_min_idx = tf.where(
+                          state.active,
+                          breakpoint_min_idx,
+                          state.breakpoint_min_idx)
+    breakpoint_min = tf.where(
+                      state.active,
+                      breakpoint_min,
+                      state.breakpoint_min)
+
+    dt = tf.where(
+          state.active,
+          breakpoint_min - state.breakpoint_min,
+          state.dt)
+          
+    active = tf.where(
+              state.active,
+              _cauchy_update_active(free_vars_idx, dt_min, dt),
+              state.active)
+
+    # We have to hint the "compiler" that the shapes of the new
+    # values are the same as the old values.
+    if not tf.executing_eagerly():
+      steepest = tf.ensure_shape(steepest, state.steepest.shape)
+      free_vars_idx = tf.ensure_shape(free_vars_idx, state.free_vars_idx.shape)
+      free_mask = tf.ensure_shape(free_mask, state.free_mask.shape)
+      p = tf.ensure_shape(p, state.p.shape)
+      c = tf.ensure_shape(c, state.c.shape)
+      df = tf.ensure_shape(df, state.df.shape)
+      ddf = tf.ensure_shape(ddf, state.ddf.shape)
+      dt_min = tf.ensure_shape(dt_min, state.dt_min.shape)
+      breakpoint_min = tf.ensure_shape(breakpoint_min, state.breakpoint_min.shape)
+      breakpoint_min_idx = tf.ensure_shape(breakpoint_min_idx, state.breakpoint_min_idx.shape)
+      dt = tf.ensure_shape(dt, state.dt.shape)
+      breakpoint_min_old = tf.ensure_shape(breakpoint_min_old, state.breakpoint_min_old.shape)
+      cauchy_point = tf.ensure_shape(cauchy_point, state.cauchy_point.shape)
+      active = tf.ensure_shape(active, state.active.shape)
+
+    new_state = bfgs_utils.update_fields(
+                  state, steepest=steepest, free_vars_idx=free_vars_idx,
+                  free_mask=free_mask, p=p, c=c, df=df, ddf=ddf, dt_min=dt_min,
+                  breakpoint_min=breakpoint_min, breakpoint_min_idx=breakpoint_min_idx,
+                  dt=dt, breakpoint_min_old=breakpoint_min_old,
+                  cauchy_point=cauchy_point, active=active)
+    
+    return [new_state]
+
+  cauchy_loop = tf.while_loop(
+        cond=_cond,
+        body=_body,
+        loop_vars=[cauchy_state],
+        parallel_iterations=parallel_iterations)[0]
+
+  # The loop broke, so the last identified `b` index never got
+  # removed
+  _free_vars_idx, free_mask = _cauchy_remove_breakpoint_min(
+                              cauchy_loop.free_vars_idx,
+                              cauchy_loop.breakpoint_min_idx,
+                              cauchy_loop.free_mask,
+                              cauchy_loop.active)
+
+  dt_min = tf.math.maximum(cauchy_loop.dt_min, 0)
+  t_old = cauchy_loop.breakpoint_min_old + dt_min
+  
+  # A breakpoint of -1 means that we ran out of free variables
+  flagged_breakpoint_min = tf.where(
+                              cauchy_loop.breakpoint_min < 0,
+                              float('inf'),
+                              cauchy_loop.breakpoint_min)
+  cauchy_point = tf.where(
+      ~(bfgs_state.converged | bfgs_state.failed)[..., tf.newaxis],
+      tf.where(
+        cauchy_loop.breakpoints >= flagged_breakpoint_min[..., tf.newaxis],
+        bfgs_state.position + t_old[..., tf.newaxis] * cauchy_loop.steepest,
+        cauchy_loop.cauchy_point),
+      bfgs_state.position)
+
+  # NOTE: We only return the cauchy point and the free mask, so there is no
+  #  need to update the actual state, even though we could at this point update
+  #  `free_vars_idx`, `free_mask`, and `cauchy_point`
+  free_mask = free_mask & ~(cauchy_loop.breakpoints != cauchy_loop.breakpoint_min)
+
+  return cauchy_point, free_mask
+
+
+def _cauchy_update_active(free_vars_idx, dt_min, dt):
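+  """Returns a boolean mask of batch members whose Cauchy search continues.
+
+  A batch member stays active while it still has free variables and the
+  minimizer of the quadratic model along the current segment (`dt_min`) lies
+  at or beyond the end of that segment (`dt`).
+  """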
+  return tf.reduce_any(free_vars_idx >= 0, axis=-1) & (dt_min >= dt)
+
+
+def _hz_line_search(state, value_and_gradients_function,
+      search_direction, max_iterations, inactive):
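+  """Runs a Hager-Zhang line search from `state.position` along `search_direction`.
+
+  Batch members flagged in `inactive` are passed to the line search as already
+  converged, so no additional function evaluations are spent on them.
+  """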
+  line_search_value_grad_func = bfgs_utils._restrict_along_direction(
+      value_and_gradients_function, state.position, search_direction)
+  derivative_at_start_pt = tf.reduce_sum(
+      state.objective_gradient * search_direction, axis=-1)
+  val_0 = bfgs_utils.ValueAndGradient(x=bfgs_utils._broadcast(0, state.position),
+                           f=state.objective_value,
+                           df=derivative_at_start_pt,
+                           full_gradient=state.objective_gradient)
+  return bfgs_utils.linesearch.hager_zhang(
+      line_search_value_grad_func,
+      initial_step_size=bfgs_utils._broadcast(1, state.position),
+      value_at_zero=val_0,
+      converged=inactive,
+      max_iterations=max_iterations)  # No search needed for these.
+
+
+def _cauchy_line_search_step(state, value_and_gradients_function, search_direction,
+                     grad_tolerance, f_relative_tolerance, x_tolerance,
+                     stopping_condition, max_iterations, free_mask, cauchy_point):
+  """Performs the line search in given direction, backtracking in direction to the cauchy point,
+  and clamping actively contrained variables to the cauchy point."""
+  inactive = state.failed | state.converged
+  ls_result = _hz_line_search(state, value_and_gradients_function,
+                search_direction, max_iterations, inactive)
+  
+  state_after_ls = bfgs_utils.update_fields(
+      state,
+      failed=state.failed | (~state.converged & ~ls_result.converged & tf.reduce_any(free_mask, axis=-1)),
+      num_iterations=state.num_iterations + 1,
+      num_objective_evaluations=(
+          state.num_objective_evaluations + ls_result.func_evals + 1))
+
+  def _do_update_position():
+    # For inactive batch members `left.x` is zero. However, their
+    # `search_direction` might also be undefined, so we can't rely on
+    # multiplication by zero to produce a `position_delta` of zero.
+    alpha = ls_result.left.x[..., tf.newaxis]
+    ideal_position = tf.where(
+        inactive[..., tf.newaxis],
+        state.position,
+        tf.where(
+          free_mask,
+          state.position + search_direction * alpha,
+          cauchy_point))
+
+    # Backtrack from the ideal position in direction to the Cauchy point
+    cauchy_to_ideal = ideal_position - cauchy_point
+    clip_lower = tf.math.divide_no_nan(
+                  state.lower_bounds - cauchy_point,
+                  cauchy_to_ideal)
+    clip_upper = tf.math.divide_no_nan(
+                  state.upper_bounds - cauchy_point,
+                  cauchy_to_ideal)
+    clip = tf.math.reduce_min(
+            tf.where(
+              cauchy_to_ideal > 0,
+              clip_upper,
+              tf.where(
+                cauchy_to_ideal < 0,
+                clip_lower,
+                float('inf'))),
+            axis=-1)
+    alpha = tf.minimum(1.0, clip)[..., tf.newaxis]
+    
+    next_position = tf.where(
+        inactive[..., tf.newaxis],
+        state.position,
+        tf.where(
+          free_mask,
+          cauchy_point + alpha * cauchy_to_ideal,
+          cauchy_point))
+    
+    # NOTE: one extra call to the function
+    next_objective, next_gradient = \
+      value_and_gradients_function(next_position)
+
+    return _update_position(
+        state_after_ls,
+        next_position,
+        next_objective,
+        next_gradient,
+        grad_tolerance,
+        f_relative_tolerance,
+        x_tolerance,
+        tf.constant(False))
+
+  return ps.cond(
+      stopping_condition(state.converged, state.failed),
+      true_fn=lambda: state_after_ls,
+      false_fn=_do_update_position)
+
+
+def _bounded_line_search_step(state, value_and_gradients_function, search_direction,
+                     grad_tolerance, f_relative_tolerance, x_tolerance,
+                     stopping_condition, max_iterations, bad_direction):
+  """Performs a line search in given direction, clamping to the bounds, and fixing the actively
+  constrained values to the given values."""
+  inactive = state.failed | state.converged | bad_direction
+  ls_result = _hz_line_search(state, value_and_gradients_function,
+                search_direction, max_iterations, inactive)
+
+  new_failed = state.failed | (~state.converged & ~ls_result.converged \
+                              & tf.reduce_any(search_direction != 0, axis=-1)) \
+                                & ~bad_direction
+  new_num_iterations = state.num_iterations + 1
+  new_num_objective_evaluations = (
+          state.num_objective_evaluations + ls_result.func_evals + 1)
+
+  if not tf.executing_eagerly():
+    # Hint the compiler that the properties' shape will not change
+    new_failed = tf.ensure_shape(
+      new_failed, state.failed.shape)
+    new_num_iterations = tf.ensure_shape(
+      new_num_iterations, state.num_iterations.shape)
+    new_num_objective_evaluations = tf.ensure_shape(
+      new_num_objective_evaluations, state.num_objective_evaluations.shape)
+
+  state_after_ls = bfgs_utils.update_fields(
+      state,
+      failed=new_failed,
+      num_iterations=new_num_iterations,
+      num_objective_evaluations=new_num_objective_evaluations)
+
+  def _do_update_position():
+    lower_term = tf.math.divide_no_nan(
+                  state.lower_bounds - state.position,
+                  search_direction)
+    upper_term = tf.math.divide_no_nan(
+                  state.upper_bounds - state.position,
+                  search_direction)
+    
+    under_clip = tf.math.reduce_max(
+                  tf.where(
+                    (search_direction > 0),
+                    lower_term,
+                    tf.where(
+                      (search_direction < 0),
+                      upper_term,
+                      -float('inf'))),
+                  axis=-1)
+    over_clip = tf.math.reduce_min(
+                  tf.where(
+                    (search_direction > 0),
+                    upper_term,
+                    tf.where(
+                      (search_direction < 0),
+                      lower_term,
+                      float('inf'))),
+                  axis=-1)
+
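+    # `under_clip` and `over_clip` bracket the step sizes along
+    #  `search_direction` for which every coordinate remains within its
+    #  bounds; the line-search step is clipped to that interval.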
+    alpha_clip = tf.clip_by_value(
+                  ls_result.left.x,
+                  under_clip,
+                  over_clip)[..., tf.newaxis]
+
+    # For inactive batch members `left.x` is zero. However, their
+    # `search_direction` might also be undefined, so we can't rely on
+    # multiplication by zero to produce a `position_delta` of zero.
+    next_position = tf.where(
+        inactive[..., tf.newaxis],
+        state.position,
+        state.position + search_direction * alpha_clip)
+          
+    # one extra call to the function, counted above
+    next_objective, next_gradient = \
+      value_and_gradients_function(next_position)
+
+    return _update_position(
+        state_after_ls,
+        next_position,
+        next_objective,
+        next_gradient,
+        grad_tolerance,
+        f_relative_tolerance,
+        x_tolerance,
+        bad_direction)
+
+  return ps.cond(
+      stopping_condition(state.converged, state.failed),
+      true_fn=lambda: state_after_ls,
+      false_fn=_do_update_position)
+
+
+def _update_position(state,
+                     next_position,
+                     next_objective,
+                     next_gradient,
+                     grad_tolerance,
+                     f_relative_tolerance,
+                     x_tolerance,
+                     erase_memory):
+  """Updates the state advancing its position by a given position_delta.
+  Also erases the LBFGS memory if indicated."""
+  state = bfgs_utils.terminate_if_not_finite(state, next_objective, next_gradient)
+
+  converged = ~state.failed & \
+                      _check_convergence_bounded(state.position,
+                                                 next_position,
+                                                 state.objective_value,
+                                                 next_objective,
+                                                 next_gradient,
+                                                 grad_tolerance,
+                                                 f_relative_tolerance,
+                                                 x_tolerance,
+                                                 state.lower_bounds,
+                                                 state.upper_bounds)
+  new_position_deltas = tf.where(
+                      erase_memory[..., tf.newaxis],
+                      tf.zeros_like(state.position_deltas),
+                      state.position_deltas)
+  new_gradient_deltas = tf.where(
+                      erase_memory[..., tf.newaxis],
+                      tf.zeros_like(state.gradient_deltas),
+                      state.gradient_deltas)
+  new_history = tf.where(
+              erase_memory,
+              tf.zeros_like(state.history),
+              state.history)
+  new_converged = (state.converged | converged)
+
+  if not tf.executing_eagerly():
+    # Hint the compiler that the properties have not changed shape
+    new_converged = tf.ensure_shape(new_converged, state.converged.shape)
+    next_position = tf.ensure_shape(next_position, state.position.shape)
+    next_objective = tf.ensure_shape(next_objective, state.objective_value.shape)
+    next_gradient = tf.ensure_shape(next_gradient, state.objective_gradient.shape)
+    new_position_deltas = tf.ensure_shape(new_position_deltas, state.position_deltas.shape)
+    new_gradient_deltas = tf.ensure_shape(new_gradient_deltas, state.gradient_deltas.shape)
+    new_history = tf.ensure_shape(new_history, state.history.shape)
+
+  return bfgs_utils.update_fields(
+      state,
+      converged=new_converged,
+      position=next_position,
+      objective_value=next_objective,
+      objective_gradient=next_gradient,
+      position_deltas=new_position_deltas,
+      gradient_deltas=new_gradient_deltas,
+      history=new_history)
+
+
+def _check_convergence_bounded(current_position,
+                       next_position,
+                       current_objective,
+                       next_objective,
+                       next_gradient,
+                       grad_tolerance,
+                       f_relative_tolerance,
+                       x_tolerance,
+                       lower_bounds,
+                       upper_bounds):
+  """Checks if the algorithm satisfies the convergence criteria."""
+  proj_grad_converged = bfgs_utils.norm(
+                          tf.clip_by_value(
+                            next_position - next_gradient,
+                            lower_bounds,
+                            upper_bounds) - next_position, dims=1) <= grad_tolerance
+  x_converged = bfgs_utils.norm(next_position - current_position, dims=1) <= x_tolerance
+  f_converged = bfgs_utils.norm(next_objective - current_objective, dims=0) <= \
+                  f_relative_tolerance * current_objective
+  return proj_grad_converged | x_converged | f_converged
+
+
+def _get_initial_state(value_and_gradients_function,
+                       initial_position,
+                       lower_bounds,
+                       upper_bounds,
+                       num_correction_pairs,
+                       tolerance):
+  """Create LBfgsBOptimizerResults with initial state of search procedure."""
+  init_args = bfgs_utils.get_initial_state_args(
+      value_and_gradients_function,
+      initial_position,
+      tolerance)
+  init_args.update(lower_bounds=lower_bounds, upper_bounds=upper_bounds)
+  empty_queue = _make_empty_queue_for(num_correction_pairs, initial_position)
+  init_args.update(
+    position_deltas=empty_queue,
+    gradient_deltas=empty_queue,
+    history=tf.zeros(ps.shape(initial_position)[:-1], dtype=tf.int32))
+  return LBfgsBOptimizerResults(**init_args)
+
+
+def _get_initial_cauchy_state(state, num_correction_pairs):
+  """Create _ConstrainedCauchyState with initial parameters"""
+  
+  theta = tf.math.divide_no_nan(
+              tf.reduce_sum(state.gradient_deltas[-1, ...]**2, axis=-1),
+              tf.reduce_sum(state.gradient_deltas[-1,...] * state.position_deltas[-1, ...], axis=-1))
+  theta = tf.where(
+            theta != 0,
+            theta,
+            1.0)
+
+  m, refresh = _cauchy_init_m(
+                  state,
+                  ps.shape(state.position_deltas),
+                  theta,
+                  num_correction_pairs)
+  # Erase the history where M isn't invertible
+  state = \
+    bfgs_utils.update_fields(
+      state,
+      gradient_deltas=tf.where(
+                        refresh[..., tf.newaxis],
+                        tf.zeros_like(state.gradient_deltas),
+                        state.gradient_deltas),
+      position_deltas=tf.where(
+                        refresh[..., tf.newaxis],
+                        tf.zeros_like(state.position_deltas),
+                        state.position_deltas),
+      history=tf.where(refresh, 0, state.history))
+  theta = tf.where(refresh, 1.0, theta)
+
+  breakpoints = _cauchy_init_breakpoints(state)
+
+  steepest = tf.where(
+              breakpoints != 0.,
+              -state.objective_gradient,
+              0.)
+
+  free_mask = (breakpoints > 0)
+  free_vars_idx = tf.where(
+                    free_mask,
+                    tf.broadcast_to(
+                      tf.range(ps.shape(state.position)[-1], dtype=tf.int32),
+                      ps.shape(state.position)),
+                    -1)
+
+  # We need to account for the varying histories:
+  # we assume that the first `2*(m-h)` components of `p = W^T d`
+  # are 0 (where `m` is the number of correction pairs
+  # and `h` is the history), in concordance with the first
+  # `2*(m-h)` rows and columns of M being 0.
+  # 1. Calculate all elements
+  p = tf.concat(
+        [
+          tf.einsum(
+                  "m...i,...i->...m",
+                  state.gradient_deltas,
+                  steepest),
+          theta[..., tf.newaxis] * \
+                tf.einsum(
+                  "m...i,...i->...m",
+                  state.position_deltas,
+                  steepest)
+        ],
+        axis=-1)
+  # 2. Assemble the rows in the correct order
+  idx = tf.concat(
+          [
+            tf.ragged.range(
+              num_correction_pairs - state.history),
+            tf.ragged.range(
+              num_correction_pairs,
+              2*num_correction_pairs - state.history),
+            tf.ragged.range(
+              num_correction_pairs - state.history,
+              num_correction_pairs),
+            tf.ragged.range(
+              2*num_correction_pairs - state.history,
+              2*num_correction_pairs)
+          ],
+          axis=-1).to_tensor()
+  p = tf.gather(
+        p,
+        idx,
+        batch_dims=1)
+
+  c = tf.zeros_like(p)
+
+  df = -tf.reduce_sum(steepest**2, axis=-1)
+  ddf = -theta*df - tf.einsum("...i,...ij,...j->...", p, m, p)
+  dt_min = -tf.math.divide_no_nan(df, ddf)
+
+  breakpoint_min_idx, breakpoint_min = \
+    _cauchy_get_breakpoint_min(breakpoints, free_vars_idx)
+
+  dt = breakpoint_min
+
+  breakpoint_min_old = tf.zeros_like(breakpoint_min)
+
+  cauchy_point = state.position
+
+  active = ~(state.converged | state.failed) & \
+              _cauchy_update_active(free_vars_idx, dt_min, dt)
+
+  return _ConstrainedCauchyState(
+    theta, m, breakpoints, steepest, free_vars_idx, free_mask,
+    p, c, df, ddf, dt_min, breakpoint_min, breakpoint_min_idx,
+    dt, breakpoint_min_old, cauchy_point, active)
+
+
+def _cauchy_init_m(state, deltas_shape, theta, num_correction_pairs):
+  def build_m():
+    # All of the below block matrices have dimensions [..., m, m]
+    #  where `...` denotes the batch dimensions, and `m` the number
+    #  of correction pairs (compare to `deltas_shape`, which is [m,...,n]).
+    # New elements are pushed in "from the back", so we want to index
+    #  position_deltas and gradient_deltas with negative indices.
+    # Index 0 of `position_deltas` and `gradient_deltas` is oldest, and index -1
+    #  is most recent, so the below respects the indexing of the article.
+
+    # 1. calculate inner product (s_i.y_j) in shape [..., m, m]
+    l = tf.einsum(
+          "m...i,u...i->...mu",
+          state.position_deltas,
+          state.gradient_deltas)
+    # 2. Zero out diagonal and upper triangular
+    l_shape = ps.shape(l)
+    l = tf.linalg.set_diag(
+          tf.linalg.band_part(l, -1, 0),
+          tf.zeros([l_shape[0], l_shape[-1]]))
+    l_transpose = tf.linalg.matrix_transpose(l)
+    s_t_s = tf.einsum(
+              'm...i,n...i->...mn',
+              state.position_deltas,
+              state.position_deltas)
+    d = tf.linalg.diag(
+          tf.einsum(
+          'm...i,m...i->...m',
+          state.position_deltas,
+          state.gradient_deltas))
+
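+    # The assembled inverse is the block matrix of [2]:
+    #
+    #   M^{-1} = [[ -D     L^T         ]
+    #             [  L     theta*S^T S ]]
+    #
+    # with D = diag(s_i^T y_i) and L strictly lower triangular,
+    #  L_ij = s_i^T y_j for i > j.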
+    # Assemble into full matrix
+    # TODO: Is there no better way to create a block matrix?
+    block_d = tf.concat([-d, tf.zeros_like(d)], axis=-1)
+    block_d = tf.concat([block_d, tf.zeros_like(block_d)], axis=-2)
+    block_l_transpose = tf.concat([tf.zeros_like(l_transpose), l_transpose], axis=-1)
+    block_l_transpose = tf.concat([block_l_transpose, tf.zeros_like(block_l_transpose)], axis=-2)
+    block_l = tf.concat([l, tf.zeros_like(l)], axis=-1)
+    block_l = tf.concat([tf.zeros_like(block_l), block_l], axis=-2)
+    block_s_t_s = tf.concat([tf.zeros_like(s_t_s), s_t_s], axis=-1)
+    block_s_t_s = tf.concat([tf.zeros_like(block_s_t_s), block_s_t_s], axis=-2)
+
+    # shape [b, 2m, 2m]
+    m_inv = block_d + block_l_transpose + block_l + \
+              theta[..., tf.newaxis, tf.newaxis] * block_s_t_s
+    
+    # Adjust for varying history:
+    # Push columns indexed h,...,2m-h to the left (but to the right of 0...m-h)
+    #  and same index rows to the bottom
+    idx = tf.concat(
+            [tf.ragged.range(num_correction_pairs-state.history),
+              tf.ragged.range(num_correction_pairs, 2*num_correction_pairs-state.history),
+              tf.ragged.range(num_correction_pairs-state.history, num_correction_pairs),
+              tf.ragged.range(2*num_correction_pairs-state.history, 2*num_correction_pairs)],
+            axis=-1).to_tensor()
+    m_inv = tf.gather(
+              m_inv,
+              idx,
+              axis=-1,
+              batch_dims=1)
+    m_inv = tf.gather(
+              m_inv,
+              idx,
+              axis=-2,
+              batch_dims=1)
+
+    # Insert an identity in the empty block
+    identity_mask = \
+      (tf.range(ps.shape(m_inv)[-1])[tf.newaxis, ...] < \
+        2*(num_correction_pairs - state.history[..., tf.newaxis]))[..., tf.newaxis]
+    
+    m_inv = tf.where(
+              identity_mask,
+              tf.eye(deltas_shape[0]*2, batch_shape=[deltas_shape[1]]),
+              m_inv)
+
+    # If M is not invertible, refresh the memory
+    refresh = (tf.linalg.det(m_inv) == 0)
+
+    # Invert where invertible; 0s otherwise
+    m = tf.where(
+          refresh[..., tf.newaxis, tf.newaxis],
+          tf.zeros_like(m_inv),
+          tf.linalg.inv(
+            tf.where(
+              refresh[..., tf.newaxis, tf.newaxis],
+              tf.eye(deltas_shape[0]*2, batch_shape=[deltas_shape[1]]),
+              m_inv)))
+
+    # Re-zero the introduced identity blocks
+    m = tf.where(
+          identity_mask,
+          tf.zeros_like(m),
+          m)
+
+    return m, refresh
+  
+  # M is 0 for the first iterations
+  return tf.cond(
+          state.num_iterations < 1,
+          lambda: (tf.zeros([deltas_shape[1], 2*deltas_shape[0], 2*deltas_shape[0]]),
+                    tf.broadcast_to(False, ps.shape(state.history))),
+          build_m)
+
+
+def _cauchy_init_breakpoints(state):
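+  """Computes the breakpoints `t_i` of [2].
+
+  For each variable, the breakpoint is the step along the steepest descent
+  direction at which that variable hits its bound:
+  `(x_i - u_i)/g_i` if `g_i < 0`, `(x_i - l_i)/g_i` if `g_i > 0`, and
+  infinity where the gradient component is zero.
+  """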
+  breakpoints = \
+    tf.where(
+      state.objective_gradient < 0,
+      tf.math.divide_no_nan(
+        state.position - state.upper_bounds,
+        state.objective_gradient),
+      tf.where(
+        state.objective_gradient > 0,
+        tf.math.divide_no_nan(
+          state.position - state.lower_bounds,
+          state.objective_gradient),
+        float('inf')))
+
+  return breakpoints
+
+
+def _cauchy_remove_breakpoint_min(free_vars_idx,
+                                  breakpoint_min_idx,
+                                  free_mask,
+                                  active):
+  """Update the free variable indices to remove the minimum breakpoint index.
+
+  Returns:
+    Updated `free_vars_idx`, `free_mask`
+  """
+
+  # NOTE: In situations where none of the indices are free, breakpoint_min_idx
+  #  will falsely report 0. However, this is fine, because in this situation,
+  #  every element of free_vars_idx is -1, and so there is no match.
+  matching = (free_vars_idx == breakpoint_min_idx[..., tf.newaxis])
+  free_vars_idx = tf.where(
+                    matching,
+                    -1,
+                    free_vars_idx)
+  free_mask = tf.where(
+                active[..., tf.newaxis],
+                free_vars_idx >= 0,
+                free_mask)
+  
+  return free_vars_idx, free_mask
+
+
+def _cauchy_get_breakpoint_min(breakpoints, free_vars_idx):
+  """Finds the smallest breakpoint among the free variables, returning the
+  minimum breakpoint and the corresponding variable index.
+
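+  For example:
+
+  ```python
+  breakpoints = tf.constant([[0.3, 0.1, 0.7]])
+  free_vars_idx = tf.constant([[0, -1, 2]])  # Variable 1 is actively bound.
+  _cauchy_get_breakpoint_min(breakpoints, free_vars_idx)
+  # => breakpoint_min_idx == [0], breakpoint_min == [0.3]
+  # The breakpoint 0.1 is ignored, because its variable is not free.
+  ```
+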
+  Returns:
+    Tuple of `breakpoint_min_idx`, `breakpoint_min`
+    where
+      `breakpoint_min_idx` is the index that has min. breakpoint
+      `breakpoint_min` is the corresponding breakpoint
+  """
+  # A tensor of shape [batch, dims] that has +infinity where free_vars_idx < 0,
+  #  and has breakpoints[free_vars_idx] otherwise.
+  flagged_breakpoints = tf.where(
+                          free_vars_idx < 0,
+                          float('inf'),
+                          tf.gather(
+                            breakpoints,
+                            tf.where(
+                              free_vars_idx < 0,
+                              0,
+                              free_vars_idx),
+                            batch_dims=1))
+
+  argmin_idx = tf.math.argmin(
+                flagged_breakpoints,
+                axis=-1,
+                output_type=tf.int32)
+  
+  # NOTE: For situations where there are no more free indices
+  #  (and therefore argmin_idx indexes into -1), we set
+  #  breakpoint_min_idx to 0 and flag that there are no free
+  #  indices by setting the breakpoint to -1 (this is an impossible
+  #  value, as breakpoints are greater than or equal to 0).
+  #  This is because in branching situations, indexing with
+  #  breakpoint_min_idx can occur, and later be discarded, but all
+  #  elements in breakpoint_min_idx must be a priori valid indices.
+  no_free = tf.gather(
+              free_vars_idx,
+              argmin_idx,
+              batch_dims=1) < 0
+  breakpoint_min_idx = tf.where(
+                        no_free,
+                        0,
+                        tf.gather(
+                          free_vars_idx,
+                          argmin_idx,
+                          batch_dims=1))
+  breakpoint_min = tf.where(
+                    no_free,
+                    -1.,
+                    tf.gather(
+                      breakpoints,
+                      argmin_idx,
+                      batch_dims=1))
+
+  return breakpoint_min_idx, breakpoint_min
+
+
+def _get_search_direction(state):
+  """Computes the search direction to follow at the current state.
+
+  On the `k`-th iteration of the main L-BFGS algorithm, the state has collected
+  the most recent `m` correction pairs in position_deltas and gradient_deltas,
+  where `k = state.num_iterations` and `m = min(k, num_correction_pairs)`.
+
+  Assuming these, the code below is an implementation of the L-BFGS two-loop
+  recursion algorithm given by [Nocedal and Wright(2006)][1]:
+
+  ```None
+    q_direction = objective_gradient
+    for i in reversed(range(m)):  # First loop.
+      inv_rho[i] = gradient_deltas[i]^T * position_deltas[i]
+      alpha[i] = position_deltas[i]^T * q_direction / inv_rho[i]
+      q_direction = q_direction - alpha[i] * gradient_deltas[i]
+
+    kth_inv_hessian_factor = (gradient_deltas[-1]^T * position_deltas[-1] /
+                              gradient_deltas[-1]^T * gradient_deltas[-1])
+    r_direction = kth_inv_hessian_factor * I * q_direction
+
+    for i in range(m):  # Second loop.
+      beta = gradient_deltas[i]^T * r_direction / inv_rho[i]
+      r_direction = r_direction + position_deltas[i] * (alpha[i] - beta)
+
+    return -r_direction  # Approximates - H_k * objective_gradient.
+  ```
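+
+  Unlike the unconstrained L-BFGS version, only the `state.history` most
+  recent correction pairs of each batch member enter the recursion; the
+  remaining (zero-filled) slots are skipped via the `active` masks in the
+  loop bodies below.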
+
+  Args:
+    state: A `LBfgsBOptimizerResults` tuple with the current state of the
+      search procedure.
+
+  Returns:
+    A real `Tensor` of the same shape as the `state.position`. The direction
+    along which to perform line search.
+  """
+  # The number of correction pairs collected so far is tracked per batch
+  # member in `state.history`, rather than by a single scalar as in the
+  # unconstrained L-BFGS implementation.
+
+  def _two_loop_algorithm():
+    """L-BFGS two-loop algorithm."""
+    # Correction pairs are always appended to the end, so only the latest
+    # `num_elements` vectors have valid position/gradient deltas. Vectors
+    # that haven't been computed yet are zero.
+    position_deltas = state.position_deltas
+    gradient_deltas = state.gradient_deltas
+    num_correction_pairs, num_batches, _point_dims = \
+      ps.shape(gradient_deltas, out_type=tf.int32)
+
+    # Pre-compute all `inv_rho[i]`s.
+    inv_rhos = tf.reduce_sum(
+        gradient_deltas * position_deltas, axis=-1)
+
+    def first_loop(acc, args):
+      _, q_direction, num_iter = acc
+      position_delta, gradient_delta, inv_rho = args
+      active = (num_iter < state.history)
+      alpha = tf.math.divide_no_nan(
+                tf.reduce_sum(
+                  position_delta * q_direction,
+                  axis=-1),
+                inv_rho)
+      direction_delta = alpha[..., tf.newaxis] * gradient_delta
+      new_q_direction = tf.where(
+                          active[..., tf.newaxis],
+                          q_direction - direction_delta,
+                          q_direction)
+
+      return (alpha, new_q_direction, num_iter + 1)
+
+    # Run first loop body computing and collecting `alpha[i]`s, while also
+    # computing the updated `q_direction` at each step.
+    zero = tf.zeros_like(inv_rhos[0])
+    alphas, q_directions, _num_iters = tf.scan(
+        first_loop, [position_deltas, gradient_deltas, inv_rhos],
+        initializer=(zero, state.objective_gradient, 0), reverse=True)
+
+    # We use `H^0_k = gamma_k * I` as an estimate for the initial inverse
+    # hessian for the k-th iteration; then `r_direction = H^0_k * q_direction`.
+    idx = tf.transpose(
+            tf.stack(
+              [tf.where(
+                state.history > 0,
+                num_correction_pairs - state.history,
+                0),
+              tf.range(num_batches)]))
+    gamma_k = tf.math.divide_no_nan(
+                tf.gather_nd(inv_rhos, idx),
+                tf.reduce_sum(
+                  tf.gather_nd(gradient_deltas, idx)**2,
+                  axis=-1))
+    gamma_k = tf.where(
+                (state.history > 0),
+                gamma_k,
+                1.0)
+    r_direction = gamma_k[..., tf.newaxis] * tf.gather_nd(q_directions, idx)
+
+    def second_loop(acc, args):
+      r_direction, iter_idx = acc
+      alpha, position_delta, gradient_delta, inv_rho = args
+      active = (iter_idx >= num_correction_pairs - state.history)
+      beta = tf.math.divide_no_nan(
+              tf.reduce_sum(
+                gradient_delta * r_direction,
+                axis=-1),
+              inv_rho)
+      direction_delta = (alpha - beta)[..., tf.newaxis] * position_delta
+      new_r_direction = tf.where(
+                          active[..., tf.newaxis],
+                          r_direction + direction_delta,
+                          r_direction)
+      return (new_r_direction, iter_idx + 1)
+
+    # Finally, run second loop body computing the updated `r_direction` at each
+    # step.
+    r_directions, _num_iters = tf.scan(
+        second_loop, [alphas, position_deltas, gradient_deltas, inv_rhos],
+        initializer=(r_direction, 0))
+
+    return -r_directions[-1]
+
+  return ps.cond(tf.reduce_any(state.history != 0),
+                 _two_loop_algorithm,
+                 lambda: -state.objective_gradient)
+
+
+def _get_ragged_sizes(tensor, dtype=tf.int32):
+  """Creates a tensor indicating the size of each component of
+  a ragged dimension.
+
+  For example:
+
+  ```python
+  element = tf.ragged.constant([[1,2], [3,4,5], [], [0]])
+  _get_ragged_sizes(element)
+  # => <tf.Tensor: shape=(4, 1), dtype=int32, numpy=
+  #      array([[2],
+  #             [3],
+  #             [0],
+  #             [1]], dtype=int32)>
+  ```
+  """
+  return tf.reduce_sum(
+            tf.ones_like(
+              tensor,
+              dtype=dtype),
+            axis=-1)[..., tf.newaxis]
+
+
+def _get_range_like_ragged(tensor, dtype=tf.int32):
+  """Creates a batched range for the elements of the batched tensor.
+
+  For example:
+
+  ```python
+  element = tf.ragged.constant([[1,2], [3,4,5], [], [0]])
+  _get_range_like_ragged(element)
+  # => <tf.RaggedTensor [[0, 1], [0, 1, 2], [], [0]]>
+  ```
+
+  Args:
+    tensor: a RaggedTensor of shape `[n, None]`.
+
+  Returns:
+    A ragged tensor of shape `[n, None]` where the ragged dimensions
+    match the ragged dimensions of `tensor`, and are a range from `0` to
+    the size of the ragged dimension.
+  """
+  sizes = _get_ragged_sizes(tensor)
+  flat_ranges = tf.ragged.range(
+                  tf.reshape(
+                    sizes,
+                    [tf.reduce_prod(sizes.shape)]),
+                  dtype=dtype)
+  return tf.RaggedTensor.from_row_lengths(flat_ranges, sizes.shape[:-1])[0]
+
+
+def _make_empty_queue_for(k, element):
+  """Creates a `tf.Tensor` suitable to hold `k` element-shaped tensors.
+
+  For example:
+
+  ```python
+    element = tf.constant([[0., 1., 2., 3., 4.],
+                           [5., 6., 7., 8., 9.]])
+
+    # A queue capable of holding 3 elements.
+    _make_empty_queue_for(3, element)
+    # => [[[ 0.,  0.,  0.,  0.,  0.],
+    #      [ 0.,  0.,  0.,  0.,  0.]],
+    #
+    #     [[ 0.,  0.,  0.,  0.,  0.],
+    #      [ 0.,  0.,  0.,  0.,  0.]],
+    #
+    #     [[ 0.,  0.,  0.,  0.,  0.],
+    #      [ 0.,  0.,  0.,  0.,  0.]]]
+  ```
+
+  Args:
+    k: A positive scalar integer, number of elements that each queue will hold.
+    element: A `tf.Tensor`, only its shape and dtype information are relevant.
+
+  Returns:
+    A zero-filled `tf.Tensor` of shape `(k,) + tf.shape(element)` and same dtype
+    as `element`.
+  """
+  queue_shape = ps.concat([[k], ps.shape(element)], axis=0)
+  return tf.zeros(queue_shape, dtype=dtype_util.base_dtype(element.dtype))
+
+
+def _queue_push(queue, should_update, new_vecs):
+  """Conditionally push new vectors into a batch of first-in-first-out queues.
+
+  The `queue` of shape `[k, ..., n]` can be thought of as a batch of queues,
+  each holding `k` n-D vectors; while `new_vecs` of shape `[..., n]` is a
+  fresh new batch of n-D vectors. The `should_update` batch of Boolean scalars,
+  i.e. shape `[...]`, indicates batch members whose corresponding n-D vector in
+  `new_vecs` should be added at the back of its queue, pushing out the
+  corresponding n-D vector from the front. Batch members in `new_vecs` for
+  which `should_update` is False are ignored.
+
+  Note: the choice of placing `k` at the dimension 0 of the queue is
+  constrained by the L-BFGS two-loop algorithm above. The algorithm uses
+  tf.scan to iterate over the `k` correction pairs simultaneously across all
+  batches, and tf.scan itself can only iterate over dimension 0.
+
+  For example:
+
+  ```python
+    k, b, n = (3, 2, 5)
+    queue = tf.reshape(tf.range(30), (k, b, n))
+    # => [[[ 0,  1,  2,  3,  4],
+    #      [ 5,  6,  7,  8,  9]],
+    #
+    #     [[10, 11, 12, 13, 14],
+    #      [15, 16, 17, 18, 19]],
+    #
+    #     [[20, 21, 22, 23, 24],
+    #      [25, 26, 27, 28, 29]]]
+
+    element = tf.reshape(tf.range(30, 40), (b, n))
+    # => [[30, 31, 32, 33, 34],
+          [35, 36, 37, 38, 39]]
+
+    should_update = tf.constant([True, False])  # Shape: (b,)
+
+    _queue_push(queue, should_update, element)
+    # => [[[10, 11, 12, 13, 14],
+    #      [ 5,  6,  7,  8,  9]],
+    #
+    #     [[20, 21, 22, 23, 24],
+    #      [15, 16, 17, 18, 19]],
+    #
+    #     [[30, 31, 32, 33, 34],
+    #      [25, 26, 27, 28, 29]]]
+  ```
+
+  Args:
+    queue: A `tf.Tensor` of shape `[k, ..., n]`; a batch of queues each with
+      `k` n-D vectors.
+    should_update: A Boolean `tf.Tensor` of shape `[...]` indicating batch
+      members where new vectors should be added to their queues.
+    new_vecs: A `tf.Tensor` of shape `[..., n]`; a batch of n-D vectors to add
+      at the end of their respective queues, pushing out the first element from
+      each.
+
+  Returns:
+    A new `tf.Tensor` of shape `[k, ..., n]`.
+  """
+  new_queue = tf.concat([queue[1:], [new_vecs]], axis=0)
+  return tf.where(
+      should_update[tf.newaxis, ..., tf.newaxis], new_queue, queue)

From ea9557c4fcc21a23eb6980084a6f41bd1dfb9d1e Mon Sep 17 00:00:00 2001
From: mikeevmm <miguelmurca@gmail.com>
Date: Thu, 29 Apr 2021 12:10:43 +0100
Subject: [PATCH 2/4] feat: correct subspace minimization (less fn. evals.)

---
 .../python/optimizer/lbfgsb.py                | 1197 ++++++++++-------
 1 file changed, 679 insertions(+), 518 deletions(-)

diff --git a/tensorflow_probability/python/optimizer/lbfgsb.py b/tensorflow_probability/python/optimizer/lbfgsb.py
index 4e7c9e6d8e..1ec36a61f1 100644
--- a/tensorflow_probability/python/optimizer/lbfgsb.py
+++ b/tensorflow_probability/python/optimizer/lbfgsb.py
@@ -28,10 +28,6 @@
 from __future__ import print_function
 
 import collections
-from numpy.core.fromnumeric import argmin, clip
-
-from tensorflow.python.ops.gen_array_ops import gather, lower_bound, where
-from tensorflow_probability.python.internal.backend.numpy import dtype, numpy_math
 
 # Dependency imports
 import tensorflow.compat.v2 as tf
@@ -87,7 +83,7 @@
     ])
 
 _ConstrainedCauchyState = collections.namedtuple(
-    '_ConstrainedCauchyResult', [
+    '_CauchyMinimizationResult', [
         'theta', # `\theta` in [2]; n the Cauchy search, relates to the implicit Hessian
                  # `B = \theta*I - WMW'` (`I` the identity, see [1,2] for details)
         'm', # `M_k` matrix in [2]; part of the implicit representation of the Hessian,
@@ -103,7 +99,7 @@
                           #  while loop.
         'free_mask', # Boolean mask of free variables
         'p', # as in [2]
-        'c', # as in [2]
+        'c', # as in [2]; eventually made to equal `W'(cauchy_point - position)`
         'df', # `f'` in [2]; (corrected) gradient 2-norm
         'ddf', # `f''` in [2]; (corrected) laplacian 2-norm (?)
         'dt_min', # `\Delta t_min` in [2]; the minimizing parameter
@@ -142,7 +138,29 @@ def minimize(value_and_gradients_function,
   constrained minimum for a simple high-dimensional quadratic objective function.
 
   ```python
-    TODO
+  ndims = 60
+  minimum = tf.convert_to_tensor(
+      np.ones([ndims]), dtype=tf.float32)
+  lower_bounds = tf.convert_to_tensor(
+      np.arange(ndims), dtype=tf.float32)
+  upper_bounds = tf.convert_to_tensor(
+      np.arange(100, 100-ndims, -1), dtype=tf.float32)
+  scales = tf.convert_to_tensor(
+      (np.random.rand(ndims) + 1.)*5. + 1., dtype=tf.float32)
+  start = tf.constant(np.random.rand(2, ndims)*100, dtype=tf.float32)
+
+  # The objective function and the gradient.
+  def quadratic_loss_and_gradient(x):
+      return tfp.math.value_and_gradient(
+          lambda x: tf.reduce_sum(
+              scales * tf.math.squared_difference(x, minimum), axis=-1),
+          x)
+  opt_results = tfp.optimizer.lbfgsb_minimize(
+                  quadratic_loss_and_gradient,
+                  initial_position=start,
+                  num_correction_pairs=10,
+                  tolerance=1e-10,
+                  bounds=[lower_bounds, upper_bounds])
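+
+  # `opt_results.position` then holds the bound-constrained argmin estimates,
+  # and `opt_results.converged` indicates which batch members converged.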
   ```
 
   ### References:
@@ -158,6 +176,13 @@ def minimize(value_and_gradients_function,
 
   https://doi.org/10.1137/0916069
 
+  [3] Jose Luis Morales, Jorge Nocedal (2011).
+      "Remark On Algorithm 788: L-BFGS-B: Fortran Subroutines for Large-Scale
+        Bound Constrained Optimization"
+      ACM Trans. Math. Softw. 38, 1, Article 7.
+
+  https://dl.acm.org/doi/abs/10.1145/2049662.2049669
+
   Args:
     value_and_gradients_function:  A Python callable that accepts a point as a
       real `Tensor` and returns a tuple of `Tensor`s of real dtype containing
@@ -369,40 +394,28 @@ def _body(current_state):
       """Main optimization loop."""
       current_state = bfgs_utils.terminate_if_not_finite(current_state)
   
-      cauchy_point, free_mask = \
-        _cauchy_minimization(current_state, num_correction_pairs, parallel_iterations)
-
-      search_direction = _get_search_direction(current_state)
-
-      # TODO(b/120134934): Check if the derivative at the start point is not
-      # negative, if so then reset position/gradient deltas and recompute
-      # search direction.
-      # NOTE: Erasing is currently handled in `_bounded_line_search_step`
-      search_direction = tf.where(
-                          free_mask,
-                          search_direction,
-                          0.)
-      bad_direction = \
-        (tf.reduce_sum(search_direction * current_state.objective_gradient, axis=-1) > 0)
-
-      cauchy_search = _cauchy_line_search_step(current_state,
-          value_and_gradients_function, search_direction,
-          tolerance, f_relative_tolerance, x_tolerance, stopping_condition,
-          max_line_search_iterations, free_mask, cauchy_point)
+      cauchy_state, current_state = (
+        _cauchy_minimization(current_state, num_correction_pairs, parallel_iterations))
+
+      search_direction, current_state, refresh = (
+        _find_search_direction(current_state, cauchy_state, num_correction_pairs))
       
-      search_direction = cauchy_search.position - current_state.position
-      next_state = _bounded_line_search_step(current_state,
-          value_and_gradients_function, search_direction,
+      next_state = _constrained_line_search_step(
+          current_state, value_and_gradients_function, search_direction,
           tolerance, f_relative_tolerance, x_tolerance, stopping_condition,
-          max_line_search_iterations, bad_direction)
+          max_line_search_iterations, refresh)
 
       # If not failed or converged, update the Hessian estimate.
-      # Only do this if the new pairs obey the s.y > 0
-      position_delta = next_state.position - current_state.position
-      gradient_delta = next_state.objective_gradient - current_state.objective_gradient
-      positive_prod = (tf.math.reduce_sum(position_delta * gradient_delta, axis=-1) > \
-                        1E-8*tf.reduce_sum(gradient_delta**2, axis=-1))
-      should_push = ~(next_state.converged | next_state.failed) & positive_prod & ~bad_direction
+      # Only do this if the new pairs obey s'y > eps * ||g||^2
+      position_delta = (next_state.position - current_state.position)
+      gradient_delta = (next_state.objective_gradient - current_state.objective_gradient)
+      # Article is ambiguous; see lbfgs.f:863
+      positive_prod = (
+        tf.reduce_sum(position_delta * gradient_delta, axis=-1) >
+          dtype_util.eps(current_state.position.dtype) *
+            tf.reduce_sum(current_state.objective_gradient**2, axis=-1)
+      )
+      should_push = ~(next_state.converged | next_state.failed) & positive_prod & ~refresh
       new_position_deltas = _queue_push(
               next_state.position_deltas, should_push, position_delta)
       new_gradient_deltas = _queue_push(
@@ -421,13 +434,13 @@ def _body(current_state):
         new_history = tf.ensure_shape(
           new_history, next_state.history.shape)
 
-      state_after_inv_hessian_update = bfgs_utils.update_fields(
+      next_state = bfgs_utils.update_fields(
           next_state,
           position_deltas=new_position_deltas,
           gradient_deltas=new_gradient_deltas,
           history=new_history)
 
-      return [state_after_inv_hessian_update]
+      return [next_state]
 
     if previous_optimizer_results is None:
       assert initial_position is not None
@@ -455,26 +468,28 @@ def _cauchy_minimization(bfgs_state, num_correction_pairs, parallel_iterations):
   See algorithm CP and associated discussion of [Byrd,Lu,Nocedal,Zhu][2]
   for details.
 
+  This function may modify the given `bfgs_state`, in that it refreshes the memory
+  for batches that are found to be in an invalid state.
+
   Args:
-    bfgs_state: A `_ConstrainedCauchyState` initialized to the starting point of the
-    constrained minimization.
+    bfgs_state: current `LBfgsBOptimizerResults` state
+    num_correction_pairs: typically `m`; the (maximum) number of past steps to keep as
+      history for the LBFGS algorithm
+    parallel_iterations: argument of `tf.while` loops
   Returns:
-    A potentially modified `state`, the obtained `cauchy_point` and boolean
-    `free_mask` indicating which variables are free (`True`) and which variables
-    are under active constrain (`False`)
+    A `_CauchyMinimizationResult` containing the results of the Cauchy point computation.
+    Updated `bfgs_state`
   """
-  cauchy_state = _get_initial_cauchy_state(bfgs_state, num_correction_pairs)
-  # NOTE: See lbfgsb.f (l. 1649)
+  cauchy_state, bfgs_state = _get_initial_cauchy_state(bfgs_state, num_correction_pairs)
+  # NOTE: See lbfgsb.f (l. 1524)
   ddf_org = -cauchy_state.theta * cauchy_state.df
 
   def _cond(state):
     """Test convergence to Cauchy point at current branch"""
-    return tf.math.reduce_any(state.active)
+    return tf.reduce_any(state.active)
 
   def _body(state):
-    """Cauchy point iterative loop
-    
-    (While loop of CP algorithm [2])"""
+    """Cauchy point iterative loop (While loop of CP algorithm [2])"""
     # Remove b from the free indices
     free_vars_idx, free_mask = _cauchy_remove_breakpoint_min(
                                 state.free_vars_idx,
@@ -520,7 +535,8 @@ def _body(state):
                 state.breakpoint_min_idx,
                 batch_dims=1))
 
-    keep_idx = (tf.range(ps.shape(state.cauchy_point)[-1]) != \
+    # Set the `b`th component of the `cauchy_point` to `x_cp_b`
+    keep_idx = (tf.range(ps.shape(state.cauchy_point)[-1])[tf.newaxis, ...] != 
                   state.breakpoint_min_idx[..., tf.newaxis])
     cauchy_point = tf.where(
                     state.active[..., tf.newaxis],
@@ -562,14 +578,14 @@ def _body(state):
                   state.breakpoint_min_idx,
                   axis=-1,
                   batch_dims=1),
-                state.theta[..., tf.newaxis] * \
+                (state.theta[..., tf.newaxis] *
                   tf.gather(
                     tf.transpose(
                       bfgs_state.position_deltas,
                       perm=[1,0,2]),
                     state.breakpoint_min_idx,
                     axis=-1,
-                    batch_dims=1)
+                    batch_dims=1))
               ],
               axis=-1)
     # 2. "Permute" the relevant items to the right
@@ -596,34 +612,33 @@ def _body(state):
     # NOTE Use of d_b = -g_b
     df = tf.where(
           state.active,
-          state.df + state.dt * state.ddf + \
-            d_b**2 - \
-            state.theta * d_b * z_b + \
+          (state.df + state.dt * state.ddf +
+            d_b**2 -
+            state.theta * d_b * z_b +
             d_b * tf.einsum(
                     '...j,...jk,...k->...',
                     w_b,
                     state.m,
-                    c),
+                    c)),
           state.df)
           
     # NOTE use of d_b = -g_b
     ddf = tf.where(
             state.active,
-            state.ddf - state.theta * d_b**2 + \
+            (state.ddf - state.theta * d_b**2 +
               2. * d_b * tf.einsum(
                           "...i,...ij,...j->...",
                           w_b,
                           state.m,
-                          state.p) - \
+                          state.p) -
               d_b**2 * tf.einsum(
                         "...i,...ij,...j->...",
                         w_b,
                         state.m,
-                        w_b),
+                        w_b)),
             state.ddf)
     # NOTE: See lbfgsb.f (l. 1649)
-    # TODO: How to get machine epsilon?
-    ddf = tf.math.maximum(ddf, 1E-8*ddf_org)
+    ddf = tf.math.maximum(ddf, dtype_util.eps(ddf.dtype)*ddf_org)
 
     # NOTE use of d_b = -g_b
     p = tf.where(
@@ -653,10 +668,10 @@ def _body(state):
                           state.breakpoint_min_old)
     
     # Find b
-    breakpoint_min_idx, breakpoint_min = \
+    breakpoint_min_idx, breakpoint_min = (
       _cauchy_get_breakpoint_min(
         state.breakpoints,
-        free_vars_idx)
+        free_vars_idx))
     breakpoint_min_idx = tf.where(
                           state.active,
                           breakpoint_min_idx,
@@ -670,11 +685,9 @@ def _body(state):
           state.active,
           breakpoint_min - state.breakpoint_min,
           state.dt)
-          
-    active = tf.where(
-              state.active,
-              _cauchy_update_active(free_vars_idx, dt_min, dt),
-              state.active)
+    
+    active = (state.active & 
+      _cauchy_update_active(free_vars_idx, state.breakpoints, dt_min, dt))
 
     # We have to hint the "compiler" that the shapes of the new
     # values are the same as the old values.
@@ -710,212 +723,444 @@ def _body(state):
         parallel_iterations=parallel_iterations)[0]
 
   # The loop broke, so the last identified `b` index never got
-  # removed
-  _free_vars_idx, free_mask = _cauchy_remove_breakpoint_min(
-                              cauchy_loop.free_vars_idx,
-                              cauchy_loop.breakpoint_min_idx,
-                              cauchy_loop.free_mask,
-                              cauchy_loop.active)
+  # removed. Since the free mask is recalculated below with a different
+  # method and is not needed to terminate the algorithm, we do not
+  # correct for this here.
+  #free_vars_idx, free_mask = _cauchy_remove_breakpoint_min(
+  #                            cauchy_loop.free_vars_idx,
+  #                            cauchy_loop.breakpoint_min_idx,
+  #                            cauchy_loop.free_mask,
+  #                            cauchy_loop.active)
 
   dt_min = tf.math.maximum(cauchy_loop.dt_min, 0)
   t_old = cauchy_loop.breakpoint_min_old + dt_min
   
   # A breakpoint of -1 means that we ran out of free variables
-  flagged_breakpoint_min = tf.where(
-                              cauchy_loop.breakpoint_min < 0,
-                              float('inf'),
-                              cauchy_loop.breakpoint_min)
+  change_cauchy = (
+    (cauchy_loop.breakpoint_min >= 0)[..., tf.newaxis] &
+    (cauchy_loop.breakpoints >= cauchy_loop.breakpoint_min[..., tf.newaxis])
+  )
   cauchy_point = tf.where(
       ~(bfgs_state.converged | bfgs_state.failed)[..., tf.newaxis],
       tf.where(
-        cauchy_loop.breakpoints >= flagged_breakpoint_min[..., tf.newaxis],
+        change_cauchy,
         bfgs_state.position + t_old[..., tf.newaxis] * cauchy_loop.steepest,
         cauchy_loop.cauchy_point),
       bfgs_state.position)
 
-  # NOTE: We only return the cauchy point and the free mask, so there is no
-  #  need to update the actual state, even though we could at this point update
-  #  `free_vars_idx`, `free_mask`, and `cauchy_point`
-  free_mask = free_mask & ~(cauchy_loop.breakpoints != cauchy_loop.breakpoint_min)
+  c = cauchy_loop.c + dt_min[..., tf.newaxis]*cauchy_loop.p
+  # NOTE: `c` is already permuted to match the subspace of `M`, because `w_b`
+  #  was already permuted.
+  # You can explicitly check this by comparing its value with W'.(x^c - x)
+  #  at this point.
+
+  # Update the free mask:
+  # Instead of updating the mask as suggested in [2, CP Algorithm], we recalculate
+  # whether each variable is free by checking whether the Cauchy point is near
+  # a bound. This matches other implementations, and avoids weirdness where
+  # the first considered variable is always marked as constrained.
+  # NOTE: the 10 epsilon margin is fairly arbitrary.
+  free_mask = (
+    tf.math.minimum(
+      tf.math.abs(cauchy_point - bfgs_state.upper_bounds),
+      tf.math.abs(cauchy_point - bfgs_state.lower_bounds),
+    ) > 10. * dtype_util.eps(cauchy_point.dtype)
+  )
+  free_vars_idx = (
+    tf.where(
+      free_mask,
+      tf.range(ps.shape(free_mask)[-1])[tf.newaxis, ...],
+      -1))
 
-  return cauchy_point, free_mask
+  # Update the final cauchy_state
+  # Hint the compiler that shape of things will not change
+  if not tf.executing_eagerly():
+    free_vars_idx = (
+      tf.ensure_shape(
+        free_vars_idx,
+        cauchy_loop.free_vars_idx.shape))
+    free_mask = (
+      tf.ensure_shape(
+        free_mask,
+        cauchy_loop.free_mask.shape))
+    cauchy_point = (
+      tf.ensure_shape(
+        cauchy_point,
+        cauchy_loop.cauchy_point.shape))
+    c = (
+      tf.ensure_shape(
+        c,
+        cauchy_loop.c.shape))
+  # Do the actual updating
+  final_cauchy_state = bfgs_utils.update_fields(
+    cauchy_loop,
+    free_vars_idx=free_vars_idx,
+    free_mask=free_mask,
+    cauchy_point=cauchy_point,
+    c=c)
+
+  return final_cauchy_state, bfgs_state
+
+
+def _cauchy_update_active(free_vars_idx, breakpoints, dt_min, dt):
+  """Determines whether each batch of a `_CauchyMinimizationResult` is active.
+
+  The conditions for a batch being active (i.e. for the loop of "Algorithm CP"
+  of [2] to proceed for that batch) are:
+
+  1. That `dt_min >= dt` (as made explicit in the paper),
+  2. That there are free variables, and
+  3. That of those free variables, at least one of the corresponding breakpoints is
+      finite.
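+
+  For example, a batch member with `free_vars_idx = [0, -1]`,
+  `breakpoints = [inf, 2.]`, `dt_min = 1.` and `dt = 0.5` is not active,
+  because its only free variable has an infinite breakpoint (condition 3).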
 
+  Args:
+    free_vars_idx: tensor of shape [batch, dims] where each element corresponds to
+      the index of the variable if the variable is free, and `-1` if the variable is
+      actively constrained
+    breakpoints: the breakpoints (`t` in [2]) of the `_CauchyMinimizationResult`
+    dt_min: the current `dt_min` property of the `_CauchyMinimizationResult`
+    dt: the current `dt` property of the `_CauchyMinimizationResult`
+  """
+  free_vars = (free_vars_idx >= 0)
+  return (
+    (dt_min >= dt) &
+    tf.reduce_any(free_vars, axis=-1) &
+    tf.reduce_any(free_vars & (breakpoints != float('inf')), axis=-1))
 
-def _cauchy_update_active(free_vars_idx, dt_min, dt):
-  return tf.where(
-            tf.reduce_any(free_vars_idx >= 0, axis=-1) & (dt_min >= dt),
-            True,
-            False)
+
+def _find_search_direction(bfgs_state, cauchy_state, num_correction_pairs):
+  """Finds the search direction based on the direct primal method.
+
+  This function corresponds to points 1-6 of the Direct Primal Method presented
+  in [2, p. 1199], with the first modification suggested in [3].
+
+  If an invalid condition is reached for a given batch, its history is reset. Therefore,
+  this function also returns an updated `bfgs_state`. 
+
+  Args:
+    bfgs_state: the `LBfgsBOptimizerResults` object representing the current iteration.
+    cauchy_state: the `_CauchyMinimizationResult` results of a cauchy search computation.
+                  Typically the output of `_cauchy_minimization`.
+    num_correction_pairs: The (maximum) number of correction pairs stored in memory (`m`)
+  Returns:
+    Tensor of batched search directions,
+    Updated `bfgs_state`
+    Boolean mask of batches that have been refreshed
+  """
+  # Let the reduced gradient be [2, eq. 5.4]
+  #
+  #     ρ = Z'r
+  #     r = g + Θ(x^c - x) - W.M.c
+  #
+  # and the search direction [2, eq. 5.7]
+  #
+  #     d = -B⁻¹ρ
+  #
+  # and [2, eq. 5.10]
+  # 
+  #     B⁻¹ = 1/Θ [ I + 1/Θ Z'.W.N⁻¹.M.W'.Z ]
+  #     N   = I - 1/Θ M.W'.Z.Z'.W
+  #
+  # Therefore, (see NOTE below regarding minus sign)
+  #
+  #     d = Z' . (-1/Θ) . [ r + 1/Θ W.N⁻¹.M.W'.Z.Z'.r ]
+  #
+  # Letting
+  # 
+  #     K = M.W'.Z.Z'
+  #
+  # this is rewritten as
+  #
+  #     d = -Z' . (1/Θ) . [ r + 1/Θ W.N⁻¹.K.r ]
+  #     N = I - (1/Θ) K.W
+
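+  # Shapes used below (with `b` batch members, `n` variables and `m` stored
+  # correction pairs): `theta` is [b]; the matrix `M` (`m` below) is
+  # [b, 2m, 2m]; `w_transpose` is [b, 2m, n], with rows permuted by `idx` to
+  # match the layout of `M` (see `_cauchy_init_m`); `k` is [b, 2m, n]; the
+  # matrix `N` (`n` below) is [b, 2m, 2m]; and `r` and `d` are [b, n].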
+  idx = (
+    tf.concat([
+      tf.ragged.range(
+        num_correction_pairs - bfgs_state.history),
+      tf.ragged.range(
+        num_correction_pairs,
+        2*num_correction_pairs - bfgs_state.history),
+      tf.ragged.range(
+        num_correction_pairs - bfgs_state.history,
+        num_correction_pairs),
+      tf.ragged.range(
+        2*num_correction_pairs - bfgs_state.history,
+        2*num_correction_pairs)
+    ],
+    axis=-1).to_tensor())
+
+  w_transpose = (
+    tf.gather(
+      tf.transpose(
+        tf.concat(
+          [bfgs_state.gradient_deltas,
+          cauchy_state.theta[tf.newaxis, ..., tf.newaxis] * bfgs_state.position_deltas],
+          axis=0),
+        perm=[1, 0, 2]
+      ),
+      idx,
+    batch_dims=1)
+  )
+
+  k = (
+    tf.einsum(
+      '...ij,...jk->...ik',
+      cauchy_state.m,
+      tf.where(
+        cauchy_state.free_mask[..., tf.newaxis, :],
+        w_transpose,
+        0.
+      )
+    )
+  )
+
+  n = (
+    tf.eye(2*num_correction_pairs, batch_shape=ps.shape(bfgs_state.position)[:-1]) -
+    tf.einsum(
+      '...ij,...kj->...ik',
+      k,
+      w_transpose
+    ) / cauchy_state.theta[..., tf.newaxis, tf.newaxis]
+  )
+
+  n_mask = (
+    tf.range(2*num_correction_pairs)[tf.newaxis, ...] <
+      (2*num_correction_pairs - 2*bfgs_state.history)[..., tf.newaxis]
+  )[..., tf.newaxis]
+
+  n = (
+    tf.where(
+      n_mask,
+      tf.eye(2*num_correction_pairs, batch_shape=ps.shape(bfgs_state.position)[:-1]),
+      n
+    )
+  )
+
+  # NOTE: For no history, N is at the moment the identity (so never triggers a refresh),
+  #  and is correctly completely zeroed once the inversion is complete
+  refresh = (tf.linalg.det(n) == 0)
+
+  n = (
+    tf.where(
+      (refresh[..., tf.newaxis, tf.newaxis] | n_mask),
+      0.,
+      tf.linalg.inv(
+        tf.where(
+          refresh[..., tf.newaxis, tf.newaxis],
+          tf.eye(2*num_correction_pairs, batch_shape=ps.shape(bfgs_state.position)[:-1]),
+          n
+        )
+      )
+    )
+  )
+
+  r = (
+    bfgs_state.objective_gradient +
+    (cauchy_state.cauchy_point - bfgs_state.position) * cauchy_state.theta[..., tf.newaxis] -
+    tf.einsum(
+      '...ji,...jk,...k->...i',
+      w_transpose,
+      cauchy_state.m,
+      cauchy_state.c
+    )
+  )
+
+  # TODO: The derivation in the comment at the top of this function gives a
+  #      leading minus sign here, but neither the article nor the Fortran
+  #      implementation uses it. It is unclear why, yet the negative sign
+  #      appears to produce the correct results.
+  #      (See lbfgsb.f:3021)
+  d = (
+    -(r +
+      tf.einsum(
+        '...ji,...jk,...kl,...l->...i',
+        w_transpose,
+        n,
+        k,
+        r
+      ) / cauchy_state.theta[..., tf.newaxis]
+    ) / cauchy_state.theta[..., tf.newaxis]
+  )
+
+  d = (
+    tf.where(
+      cauchy_state.free_mask,
+      d,
+      0.
+    )
+  )
+
+  # Per [3]:
+  # Clip the `(cauchy point) + d` into the bounds
+  lower_term = tf.math.divide_no_nan(
+                bfgs_state.lower_bounds - cauchy_state.cauchy_point,
+                d)
+  upper_term = tf.math.divide_no_nan(
+                bfgs_state.upper_bounds - cauchy_state.cauchy_point,
+                d)
+  clip_per_var = (
+                tf.where(
+                  (d > 0),
+                  upper_term,
+                  tf.where(
+                    (d < 0),
+                    lower_term,
+                    float('inf')))
+  )
+
+  movement_clip = (
+    tf.math.minimum(
+      tf.math.reduce_min(clip_per_var, axis=-1),
+      1.)
+  )
+  # NOTE: `d` is zeroed for constrained variables, and `movement_clip` is at most 1.
+  minimizer = (
+    cauchy_state.cauchy_point + movement_clip[..., tf.newaxis]*d
+  )
+  
+  # Per [3]: If the search direction obtained with this minimizer is not a direction
+  # of strong descent, do not clip `d` to obtain the minimizer (i.e. fall back to the
+  # original algorithm)
+  fallback = (
+    tf.reduce_sum(
+      (minimizer - bfgs_state.position) * bfgs_state.objective_gradient, axis=-1) > 0
+  )
+
+  minimizer = (
+    tf.where(
+      fallback[..., tf.newaxis],
+      cauchy_state.cauchy_point + d,
+      minimizer
+    )
+  )
+
+  search_direction = (minimizer - bfgs_state.position)
+
+  # Reset if the search direction still isn't a direction of strong descent
+  refresh |= (
+    tf.reduce_sum(search_direction * bfgs_state.objective_gradient, axis=-1) > 0)
+
+  # Apply refresh
+  bfgs_state = _erase_history(bfgs_state, refresh)
+
+  return search_direction, bfgs_state, refresh
 
 
-def _hz_line_search(state, value_and_gradients_function,
-      search_direction, max_iterations, inactive):
+def _hz_line_search(starting_position, starting_value, starting_gradient,
+      value_and_gradients_function, search_direction, max_iterations, inactive):
+  """Performs Hager Zhang line search via `bfgs_utils.linesearch.hager_zhang`."""
   line_search_value_grad_func = bfgs_utils._restrict_along_direction(
-      value_and_gradients_function, state.position, search_direction)
+      value_and_gradients_function, starting_position, search_direction)
   derivative_at_start_pt = tf.reduce_sum(
-      state.objective_gradient * search_direction, axis=-1)
-  val_0 = bfgs_utils.ValueAndGradient(x=bfgs_utils._broadcast(0, state.position),
-                           f=state.objective_value,
+      starting_gradient * search_direction, axis=-1)
+  val_0 = bfgs_utils.ValueAndGradient(x=bfgs_utils._broadcast(0, starting_position),
+                           f=starting_value,
                            df=derivative_at_start_pt,
-                           full_gradient=state.objective_gradient)
+                           full_gradient=starting_gradient)
   return bfgs_utils.linesearch.hager_zhang(
       line_search_value_grad_func,
-      initial_step_size=bfgs_utils._broadcast(1, state.position),
+      initial_step_size=bfgs_utils._broadcast(1, starting_position),
       value_at_zero=val_0,
       converged=inactive,
       max_iterations=max_iterations)  # No search needed for these.
 
 
-def _cauchy_line_search_step(state, value_and_gradients_function, search_direction,
+def _constrained_line_search_step(bfgs_state, value_and_gradients_function, search_direction,
                      grad_tolerance, f_relative_tolerance, x_tolerance,
-                     stopping_condition, max_iterations, free_mask, cauchy_point):
-  """Performs the line search in given direction, backtracking in direction to the cauchy point,
-  and clamping actively contrained variables to the cauchy point."""
-  inactive = state.failed | state.converged
-  ls_result = _hz_line_search(state, value_and_gradients_function,
-                search_direction, max_iterations, inactive)
-  
-  state_after_ls = bfgs_utils.update_fields(
-      state,
-      failed=state.failed | (~state.converged & ~ls_result.converged & tf.reduce_any(free_mask, axis=-1)),
-      num_iterations=state.num_iterations + 1,
-      num_objective_evaluations=(
-          state.num_objective_evaluations + ls_result.func_evals + 1))
-
-  def _do_update_position():
-    # For inactive batch members `left.x` is zero. However, their
-    # `search_direction` might also be undefined, so we can't rely on
-    # multiplication by zero to produce a `position_delta` of zero.
-    alpha = ls_result.left.x[..., tf.newaxis]
-    ideal_position = tf.where(
-        inactive[..., tf.newaxis],
-        state.position,
-        tf.where(
-          free_mask,
-          state.position + search_direction * alpha,
-          cauchy_point))
-
-    # Backtrack from the ideal position in direction to the Cauchy point
-    cauchy_to_ideal = ideal_position - cauchy_point
-    clip_lower = tf.math.divide_no_nan(
-                  state.lower_bounds - cauchy_point,
-                  cauchy_to_ideal)
-    clip_upper = tf.math.divide_no_nan(
-                  state.upper_bounds - cauchy_point,
-                  cauchy_to_ideal)
-    clip = tf.math.reduce_min(
-            tf.where(
-              cauchy_to_ideal > 0,
-              clip_upper,
-              tf.where(
-                cauchy_to_ideal < 0,
-                clip_lower,
-                float('inf'))),
-            axis=-1)
-    alpha = tf.minimum(1.0, clip)[..., tf.newaxis]
-    
-    next_position = tf.where(
-        inactive[..., tf.newaxis],
-        state.position,
-        tf.where(
-          free_mask,
-          cauchy_point + alpha * cauchy_to_ideal,
-          cauchy_point))
-    
-    # NOTE: one extra call to the function
-    next_objective, next_gradient = \
-      value_and_gradients_function(next_position)
+                     stopping_condition, max_iterations, refresh):
+  """Performs a line search in the given direction, truncated to the bounds.
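+
+  Batch members that have already converged or failed, or whose memory was
+  just refreshed (per the `refresh` mask), are treated as inactive and their
+  position is left unchanged.
+  """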
+  inactive = (bfgs_state.failed | bfgs_state.converged) | refresh
 
-    return _update_position(
-        state_after_ls,
-        next_position,
-        next_objective,
-        next_gradient,
-        grad_tolerance,
-        f_relative_tolerance,
-        x_tolerance,
-        tf.constant(False))
-
-  return ps.cond(
-      stopping_condition(state.converged, state.failed),
-      true_fn=lambda: state_after_ls,
-      false_fn=_do_update_position)
-
-
-def _bounded_line_search_step(state, value_and_gradients_function, search_direction,
-                     grad_tolerance, f_relative_tolerance, x_tolerance,
-                     stopping_condition, max_iterations, bad_direction):
-  """Performs a line search in given direction, clamping to the bounds, and fixing the actively
-  constrained values to the given values."""
-  inactive = state.failed | state.converged | bad_direction
-  ls_result = _hz_line_search(state, value_and_gradients_function,
-                search_direction, max_iterations, inactive)
-
-  new_failed = state.failed | (~state.converged & ~ls_result.converged \
-                              & tf.reduce_any(search_direction != 0, axis=-1)) \
-                                & ~bad_direction
-  new_num_iterations = state.num_iterations + 1
-  new_num_objective_evaluations = (
-          state.num_objective_evaluations + ls_result.func_evals + 1)
-
-  if not tf.executing_eagerly():
-    # Hint the compiler that the properties' shape will not change
-    new_failed = tf.ensure_shape(
-      new_failed, state.failed.shape)
-    new_num_iterations = tf.ensure_shape(
-      new_num_iterations, state.num_iterations.shape)
-    new_num_objective_evaluations = tf.ensure_shape(
-      new_num_objective_evaluations, state.num_objective_evaluations.shape)
-
-  state_after_ls = bfgs_utils.update_fields(
-      state,
-      failed=new_failed,
-      num_iterations=new_num_iterations,
-      num_objective_evaluations=new_num_objective_evaluations)
-
-  def _do_update_position():
+  def _do_line_search_step():
+    """Runs the Hager-Zhang line search along the bounds-truncated direction."""
+    # Truncation bounds
     lower_term = tf.math.divide_no_nan(
-                  state.lower_bounds - state.position,
-                  search_direction)
+                    bfgs_state.lower_bounds - bfgs_state.position,
+                    search_direction)
     upper_term = tf.math.divide_no_nan(
-                  state.upper_bounds - state.position,
+                  bfgs_state.upper_bounds - bfgs_state.position,
                   search_direction)
+
+    # Truncate the search direction to bounds before search
+    bounds_clip = (
+      tf.reduce_min(
+          tf.where(
+            (search_direction > 0),
+            upper_term,
+            tf.where(
+              (search_direction < 0),
+              lower_term,
+              float('inf'))),
+          axis=-1)
+    )
+    pre_clip = (
+      tf.math.minimum(
+        bounds_clip,
+        1.)
+    )
+
+    clipped_search_direction = search_direction * pre_clip[..., tf.newaxis]
     
-    under_clip = tf.math.reduce_max(
-                  tf.where(
-                    (search_direction > 0),
-                    lower_term,
-                    tf.where(
-                      (search_direction < 0),
-                      upper_term,
-                      -float('inf'))),
-                  axis=-1)
-    over_clip = tf.math.reduce_min(
-                  tf.where(
-                    (search_direction > 0),
-                    upper_term,
-                    tf.where(
-                      (search_direction < 0),
-                      lower_term,
-                      float('inf'))),
-                  axis=-1)
+    ls_result = _hz_line_search(
+        bfgs_state.position, bfgs_state.objective_value,
+        bfgs_state.objective_gradient, value_and_gradients_function,
+        clipped_search_direction, max_iterations, inactive)
+
+    new_failed = ((bfgs_state.failed | (~bfgs_state.converged & ~ls_result.converged)) & ~inactive)
+    new_num_iterations = bfgs_state.num_iterations + 1
+    new_num_objective_evaluations = (
+            bfgs_state.num_objective_evaluations + ls_result.func_evals + 1)
+
+    # Also truncate to bounds after search
+    step = (
+      tf.math.minimum(
+        bounds_clip,
+        ls_result.left.x
+      )
+    )
 
-    alpha_clip = tf.clip_by_value(
-                  ls_result.left.x,
-                  under_clip,
-                  over_clip)[..., tf.newaxis]
+    # Hint the compiler that the properties' shape will not change
+    if not tf.executing_eagerly():
+      new_failed = tf.ensure_shape(
+        new_failed, bfgs_state.failed.shape)
+      new_num_iterations = tf.ensure_shape(
+        new_num_iterations, bfgs_state.num_iterations.shape)
+      new_num_objective_evaluations = tf.ensure_shape(
+        new_num_objective_evaluations, bfgs_state.num_objective_evaluations.shape)
+
+    state_after_ls = bfgs_utils.update_fields(
+        state=bfgs_state,
+        failed=new_failed,
+        num_iterations=new_num_iterations,
+        num_objective_evaluations=new_num_objective_evaluations)
+    
+    return step, state_after_ls
+  
+  # NOTE: It's important that the `false_fn` output matches the shape of the
+  # `true_fn` output, for graph-mode purposes.
+  step, state_after_ls = (
+    tf.cond(
+      pred=tf.math.logical_not(tf.reduce_all(inactive)),
+      true_fn=_do_line_search_step,
+      false_fn=lambda: (tf.zeros_like(inactive, dtype=search_direction.dtype), bfgs_state)
+    ))
 
+  def _do_update_position():
+    """Update the position"""
     # For inactive batch members `left.x` is zero. However, their
     # `search_direction` might also be undefined, so we can't rely on
     # multiplication by zero to produce a `position_delta` of zero.
+    # Also, the step has been truncated above so that the resulting position
+    # cannot leave the bounds.
     next_position = tf.where(
         inactive[..., tf.newaxis],
-        state.position,
-        state.position + search_direction * alpha_clip)
+        bfgs_state.position,
+        bfgs_state.position + step[..., tf.newaxis] * search_direction)
           
     # one extra call to the function, counted above
-    next_objective, next_gradient = \
+    next_objective, next_gradient = (
       value_and_gradients_function(next_position)
+    )
 
     return _update_position(
         state_after_ls,
@@ -925,10 +1170,11 @@ def _do_update_position():
         grad_tolerance,
         f_relative_tolerance,
         x_tolerance,
-        bad_direction)
+        inactive)
 
   return ps.cond(
-      stopping_condition(state.converged, state.failed),
+      (stopping_condition(bfgs_state.converged, bfgs_state.failed) &
+        tf.math.logical_not(tf.reduce_all(inactive))),
       true_fn=lambda: state_after_ls,
       false_fn=_do_update_position)
 
@@ -940,34 +1186,22 @@ def _update_position(state,
                      grad_tolerance,
                      f_relative_tolerance,
                      x_tolerance,
-                     erase_memory):
-  """Updates the state advancing its position by a given position_delta.
-  Also erases the LBFGS memory if indicated."""
+                     inactive):
+  """Updates the state advancing its position by a given position_delta."""
   state = bfgs_utils.terminate_if_not_finite(state, next_objective, next_gradient)
 
-  converged = ~state.failed & \
-                      _check_convergence_bounded(state.position,
-                                                 next_position,
-                                                 state.objective_value,
-                                                 next_objective,
-                                                 next_gradient,
-                                                 grad_tolerance,
-                                                 f_relative_tolerance,
-                                                 x_tolerance,
-                                                 state.lower_bounds,
-                                                 state.upper_bounds)
-  new_position_deltas = tf.where(
-                      erase_memory[..., tf.newaxis],
-                      tf.zeros_like(state.position_deltas),
-                      state.position_deltas)
-  new_gradient_deltas = tf.where(
-                      erase_memory[..., tf.newaxis],
-                      tf.zeros_like(state.gradient_deltas),
-                      state.gradient_deltas)
-  new_history = tf.where(
-              erase_memory,
-              tf.zeros_like(state.history),
-              state.history)
+  converged = (~inactive &
+                ~state.failed &
+                  _check_convergence_bounded(state.position,
+                                            next_position,
+                                            state.objective_value,
+                                            next_objective,
+                                            next_gradient,
+                                            grad_tolerance,
+                                            f_relative_tolerance,
+                                            x_tolerance,
+                                            state.lower_bounds,
+                                            state.upper_bounds))
   new_converged = (state.converged | converged)
 
   if not tf.executing_eagerly():
@@ -976,19 +1210,57 @@ def _update_position(state,
     next_position = tf.ensure_shape(next_position, state.position.shape)
     next_objective = tf.ensure_shape(next_objective, state.objective_value.shape)
     next_gradient = tf.ensure_shape(next_gradient, state.objective_gradient.shape)
-    new_position_deltas = tf.ensure_shape(new_position_deltas, state.position_deltas.shape)
-    new_gradient_deltas = tf.ensure_shape(new_gradient_deltas, state.gradient_deltas.shape)
-    new_history = tf.ensure_shape(new_history, state.history.shape)
 
   return bfgs_utils.update_fields(
       state,
       converged=new_converged,
       position=next_position,
       objective_value=next_objective,
-      objective_gradient=next_gradient,
-      position_deltas=new_position_deltas,
-      gradient_deltas=new_gradient_deltas,
-      history=new_history)
+      objective_gradient=next_gradient)
+
+
+def _erase_history(bfgs_state, where_erase):
+  """Erases the BFGS correction pairs for the specified batches.
+
+  This function will zero `gradient_deltas`, `position_deltas`, and `history`.
+
+  Args:
+    `bfgs_state`: a `LBfgsBOptimizerResults` to modify
+    `where_erase`: a Boolean tensor with shape matching the batch dimensions
+                  with `True` for the batches to erase the history of.
+  Returns:
+    Modified `bfgs_state`.
+  """
+  # Calculate new values
+  new_gradient_deltas = (tf.where(
+                        where_erase[..., tf.newaxis],
+                        0.,
+                        bfgs_state.gradient_deltas))
+  new_position_deltas = (tf.where(
+                        where_erase[..., tf.newaxis],
+                        0.,
+                        bfgs_state.position_deltas))
+  new_history = tf.where(where_erase, 0, bfgs_state.history)
+  # Assure the compiler that the shape of things has not changed
+  if not tf.executing_eagerly():
+    new_gradient_deltas = (
+      tf.ensure_shape(
+        new_gradient_deltas,
+        bfgs_state.gradient_deltas.shape))
+    new_position_deltas = (
+      tf.ensure_shape(
+        new_position_deltas,
+        bfgs_state.position_deltas.shape))
+    new_history = (
+      tf.ensure_shape(
+        new_history,
+        bfgs_state.history.shape))
+  # Update and return
+  return bfgs_utils.update_fields(
+          bfgs_state,
+          gradient_deltas=new_gradient_deltas,
+          position_deltas=new_position_deltas,
+          history=new_history)
 
 
 def _check_convergence_bounded(current_position,
@@ -1002,17 +1274,22 @@ def _check_convergence_bounded(current_position,
                        lower_bounds,
                        upper_bounds):
   """Checks if the algorithm satisfies the convergence criteria."""
+  # NOTE: The original algorithm (as described in [2]) only considers halting on
+  # the projected gradient condition. However, `x_converged` and `f_converged` do
+  # not seem to pose a problem when refreshing is correctly accounted for (so that
+  # the optimization does not halt upon a refresh), and the default values of `0`
+  # for `f_relative_tolerance` and `x_tolerance` further strengthen these conditions.
   proj_grad_converged = bfgs_utils.norm(
                           tf.clip_by_value(
                             next_position - next_gradient,
                             lower_bounds,
                             upper_bounds) - next_position, dims=1) <= grad_tolerance
   x_converged = bfgs_utils.norm(next_position - current_position, dims=1) <= x_tolerance
-  f_converged = bfgs_utils.norm(next_objective - current_objective, dims=0) <= \
-                  f_relative_tolerance * current_objective
+  f_converged = (
+    bfgs_utils.norm(next_objective - current_objective, dims=0) <= 
+      f_relative_tolerance * current_objective)
   return proj_grad_converged | x_converged | f_converged
 
-
 def _get_initial_state(value_and_gradients_function,
                        initial_position,
                        lower_bounds,
@@ -1033,50 +1310,56 @@ def _get_initial_state(value_and_gradients_function,
   return LBfgsBOptimizerResults(**init_args)
 
 
-def _get_initial_cauchy_state(state, num_correction_pairs):
-  """Create _ConstrainedCauchyState with initial parameters"""
+def _get_initial_cauchy_state(bfgs_state, num_correction_pairs):
+  """Create `_ConstrainedCauchyState` with initial parameters.
+  
+  This will calculate the elements of `_ConstrainedCauchyState` based on the given
+  `LBfgsBOptimizerResults` state object. Some of these properties may not be
+  computable for certain batch members, in which case their memory is reset.
+
+  Args:
+    bfgs_state: `LBfgsBOptimizerResults` object representing the current state of the
+      LBFGSB optimization
+    num_correction_pairs: typically `m`; the (maximum) number of past steps to keep as
+      history for the LBFGS algorithm
+  
+  Returns:
+    Initialized `_ConstrainedCauchyState`
+    Updated `bfgs_state`
+  """
   
   theta = tf.math.divide_no_nan(
-              tf.reduce_sum(state.gradient_deltas[-1, ...]**2, axis=-1),
-              tf.reduce_sum(state.gradient_deltas[-1,...] * state.position_deltas[-1, ...], axis=-1))
+              tf.reduce_sum(bfgs_state.gradient_deltas[-1, ...]**2, axis=-1),
+              (tf.reduce_sum(bfgs_state.gradient_deltas[-1,...] *
+                bfgs_state.position_deltas[-1, ...], axis=-1)))
   theta = tf.where(
             theta != 0,
             theta,
-            1.0)
+            1.)
 
   m, refresh = _cauchy_init_m(
-                  state,
-                  ps.shape(state.position_deltas),
+                  bfgs_state,
+                  ps.shape(bfgs_state.position_deltas),
                   theta,
                   num_correction_pairs)
+  
   # Erase the history where M isn't invertible
-  state = \
-    bfgs_utils.update_fields(
-      state,
-      gradient_deltas=tf.where(
-                        refresh[..., tf.newaxis],
-                        tf.zeros_like(state.gradient_deltas),
-                        state.gradient_deltas),
-      position_deltas=tf.where(
-                        refresh[..., tf.newaxis],
-                        tf.zeros_like(state.position_deltas),
-                        state.position_deltas),
-      history=tf.where(refresh, 0, state.history))
-  theta = tf.where(refresh, 1.0, theta)
-
-  breakpoints = _cauchy_init_breakpoints(state)
+  bfgs_state = _erase_history(bfgs_state, refresh)
+  theta = tf.where(refresh, 1., theta)
+
+  breakpoints = _cauchy_init_breakpoints(bfgs_state)
 
   steepest = tf.where(
               breakpoints != 0.,
-              -state.objective_gradient,
+              -bfgs_state.objective_gradient,
               0.)
 
   free_mask = (breakpoints > 0)
   free_vars_idx = tf.where(
                     free_mask,
                     tf.broadcast_to(
-                      tf.range(ps.shape(state.position)[-1], dtype=tf.int32),
-                      ps.shape(state.position)),
+                      tf.range(ps.shape(bfgs_state.position)[-1], dtype=tf.int32),
+                      ps.shape(bfgs_state.position)),
                     -1)
 
   # We need to account for the varying histories:
@@ -1089,28 +1372,28 @@ def _get_initial_cauchy_state(state, num_correction_pairs):
         [
           tf.einsum(
                   "m...i,...i->...m",
-                  state.gradient_deltas,
+                  bfgs_state.gradient_deltas,
                   steepest),
-          theta[..., tf.newaxis] * \
+          (theta[..., tf.newaxis] *
                 tf.einsum(
                   "m...i,...i->...m",
-                  state.position_deltas,
-                  steepest)
+                  bfgs_state.position_deltas,
+                  steepest))
         ],
         axis=-1)
   # 2. Assemble the rows in the correct order
   idx = tf.concat(
           [
             tf.ragged.range(
-              num_correction_pairs - state.history),
+              num_correction_pairs - bfgs_state.history),
             tf.ragged.range(
               num_correction_pairs,
-              2*num_correction_pairs - state.history),
+              2*num_correction_pairs - bfgs_state.history),
             tf.ragged.range(
-              num_correction_pairs - state.history,
+              num_correction_pairs - bfgs_state.history,
               num_correction_pairs),
             tf.ragged.range(
-              2*num_correction_pairs - state.history,
+              2*num_correction_pairs - bfgs_state.history,
               2*num_correction_pairs)
           ],
           axis=-1).to_tensor()
@@ -1125,26 +1408,31 @@ def _get_initial_cauchy_state(state, num_correction_pairs):
   ddf = -theta*df - tf.einsum("...i,...ij,...j->...", p, m, p)
   dt_min = -tf.math.divide_no_nan(df, ddf)
 
-  breakpoint_min_idx, breakpoint_min = \
-    _cauchy_get_breakpoint_min(breakpoints, free_vars_idx)
+  breakpoint_min_idx, breakpoint_min = (
+    _cauchy_get_breakpoint_min(breakpoints, free_vars_idx))
 
   dt = breakpoint_min
 
   breakpoint_min_old = tf.zeros_like(breakpoint_min)
 
-  cauchy_point = state.position
+  cauchy_point = bfgs_state.position
 
-  active = ~(state.converged | state.failed) & \
-              _cauchy_update_active(free_vars_idx, dt_min, dt)
+  active = (~(bfgs_state.converged | bfgs_state.failed) &
+              _cauchy_update_active(free_vars_idx, breakpoints, dt_min, dt))
 
-  return _ConstrainedCauchyState(
-    theta, m, breakpoints, steepest, free_vars_idx, free_mask,
-    p, c, df, ddf, dt_min, breakpoint_min, breakpoint_min_idx,
-    dt, breakpoint_min_old, cauchy_point, active)
+  cauchy_state = (
+    _ConstrainedCauchyState(
+      theta, m, breakpoints, steepest, free_vars_idx, free_mask,
+      p, c, df, ddf, dt_min, breakpoint_min, breakpoint_min_idx,
+      dt, breakpoint_min_old, cauchy_point, active))
+  
+  return cauchy_state, bfgs_state
 
 
 def _cauchy_init_m(state, deltas_shape, theta, num_correction_pairs):
+  """Initialize the M matrix for a `_CauchyMinimizationResult` state."""
   def build_m():
+    """Construct and invert the M block matrix."""
     # All of the below block matrices have dimensions [..., m, m]
     #  where `...` denotes the batch dimensions, and `m` the number
     #  of correction pairs (compare to `deltas_shape`, which is [m,...,n]).
@@ -1176,18 +1464,23 @@ def build_m():
 
     # Assemble into full matrix
     # TODO: Is there no better way to create a block matrix?
-    block_d = tf.concat([-d, tf.zeros_like(d)], axis=-1)
-    block_d = tf.concat([block_d, tf.zeros_like(block_d)], axis=-2)
-    block_l_transpose = tf.concat([tf.zeros_like(l_transpose), l_transpose], axis=-1)
-    block_l_transpose = tf.concat([block_l_transpose, tf.zeros_like(block_l_transpose)], axis=-2)
-    block_l = tf.concat([l, tf.zeros_like(l)], axis=-1)
-    block_l = tf.concat([tf.zeros_like(block_l), block_l], axis=-2)
-    block_s_t_s = tf.concat([tf.zeros_like(s_t_s), s_t_s], axis=-1)
-    block_s_t_s = tf.concat([tf.zeros_like(block_s_t_s), block_s_t_s], axis=-2)
+    m_inv = tf.concat(
+              [
+                tf.concat([-d, l_transpose], axis=-1),
+                tf.concat([l, theta[..., tf.newaxis, tf.newaxis] * s_t_s], axis=-1)
+              ], axis=-2)
 
     # shape [b, 2m, 2m]
-    m_inv = block_d + block_l_transpose + block_l + \
-              theta[..., tf.newaxis, tf.newaxis] * block_s_t_s
     
     # Adjust for varying history:
     # Push columns indexed h,...,2m-h to the left (but to the right of 0...m-h)
@@ -1210,9 +1503,9 @@ def build_m():
               batch_dims=1)
 
     # Insert an identity in the empty block
-    identity_mask = \
-      (tf.range(ps.shape(m_inv)[-1])[tf.newaxis, ...] < \
-        2*(num_correction_pairs - state.history[..., tf.newaxis]))[..., tf.newaxis]
+    identity_mask = (
+      (tf.range(ps.shape(m_inv)[-1])[tf.newaxis, ...] <
+        2*(num_correction_pairs - state.history[..., tf.newaxis]))[..., tf.newaxis])
     
     m_inv = tf.where(
               identity_mask,
@@ -1249,7 +1542,8 @@ def build_m():
 
 
 def _cauchy_init_breakpoints(state):
-  breakpoints = \
+  """Calculate the breakpoints for a `_CauchyMinimizationResult` state."""
+  breakpoints = (
     tf.where(
       state.objective_gradient < 0,
       tf.math.divide_no_nan(
@@ -1261,6 +1555,7 @@ def _cauchy_init_breakpoints(state):
           state.position - state.lower_bounds,
           state.objective_gradient),
         float('inf')))
+  )
 
   return breakpoints
 
@@ -1271,6 +1566,18 @@ def _cauchy_remove_breakpoint_min(free_vars_idx,
                                   active):
   """Update the free variable indices to remove the minimum breakpoint index.
 
+  This will set the `breakpoint_min_idx`th entry of `free_mask` to `False`,
+  and of `free_vars_idx` to `-1`.
+
+  Args:
+    free_vars_idx: tensor of shape [batch, dims] where each entry is the index of
+      the corresponding variable if that variable is free, and -1 otherwise
+    breakpoint_min_idx: tensor of shape [batch] denoting the indices to mark as
+      constrained for each batch
+    free_mask: tensor of shape [batch, dims] where `True` denotes a free variable, and
+      `False` an actively constrained variable
+    active: tensor of shape [batch] denoting whether each batch should be updated
+
   Returns:
     Updated `free_vars_idx`, `free_mask`
   """
@@ -1280,28 +1587,35 @@ def _cauchy_remove_breakpoint_min(free_vars_idx,
   #  every element of free_vars_idx is -1, and so there is no match.
   matching = (free_vars_idx == breakpoint_min_idx[..., tf.newaxis])
   free_vars_idx = tf.where(
-                    matching,
+                    active[..., tf.newaxis] & matching,
                     -1,
                     free_vars_idx)
   free_mask = tf.where(
                 active[..., tf.newaxis],
                 free_vars_idx >= 0,
                 free_mask)
-  
   return free_vars_idx, free_mask
 
 
 def _cauchy_get_breakpoint_min(breakpoints, free_vars_idx):
-  """Find the smallest breakpoint of free indices, returning the minimum breakpoint
-  and the corresponding index.
+  """Find the smallest breakpoint of free indices.
+
+  If every breakpoint is equal, this function will return the first found variable
+  that is not actively constrained.
+
+  Args:
+    breakpoints: tensor of breakpoints as initialized in a `_CauchyMinimizationResult`
+      state
+    free_vars_idx: tensor denoting free and constrained variables, as initialized in
+      a `_CauchyMinimizationResult` state
 
   Returns:
-    Tuple of `breakpoint_min_idx`, `breakpoint_min`
-    where
-      `breakpoint_min_idx` is the index that has min. breakpoint
-      `breakpoint_min` is the corresponding breakpoint
+    Index that has min. breakpoint
+    Corresponding breakpoint
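+
+  For example (an illustrative sketch with a single batch; the exact outputs
+  are only indicative):
+
+  ```python
+  breakpoints   = tf.constant([[0.5, 2.0, 1.0]])
+  free_vars_idx = tf.constant([[-1, 1, 2]])  # variable 0 is constrained
+  # The free breakpoints are 2.0 (variable 1) and 1.0 (variable 2), so this
+  # should yield breakpoint_min_idx == [2] and breakpoint_min == [1.0].
+  _cauchy_get_breakpoint_min(breakpoints, free_vars_idx)
+  ```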
   """
-  # A tensor of shape [batch, dims] that has +infinity where free_vars_idx < 0,
+  no_free = (~tf.reduce_any(free_vars_idx >= 0, axis=-1))
+  
+  # A tensor of shape [batch, dims] that has inf where free_vars_idx < 0,
   #  and has breakpoints[free_vars_idx] otherwise.
   flagged_breakpoints = tf.where(
                           free_vars_idx < 0,
@@ -1319,6 +1633,36 @@ def _cauchy_get_breakpoint_min(breakpoints, free_vars_idx):
                 axis=-1,
                 output_type=tf.int32)
   
+  # Sometimes free variables have 'inf' breakpoints, and then there
+  # is no guarantee that argmin will not have picked a constrained variable
+  # In this case, grab the first free variable by iterating along the variables
+  # until one is free
+
+  def _check_gathered(active, _):
+    """Whether we are still looking for a free variable."""
+    return tf.reduce_any(active)
+  
+  def _get_first(active, new_idx):
+    """Check if next variable is free."""
+    new_idx = tf.where(active, new_idx+1, new_idx)
+    active = (~no_free &
+                (tf.gather(
+                  free_vars_idx,
+                  new_idx,
+                  batch_dims=1) < 0))
+    return [active, new_idx]
+
+  active = (~no_free & 
+              (tf.gather(
+                free_vars_idx,
+                argmin_idx,
+                batch_dims=1) < 0))
+  _, argmin_idx = (
+    tf.while_loop(
+      cond=_check_gathered,
+      body=_get_first,
+      loop_vars=[active, argmin_idx]))
+  
   # NOTE: For situations where there are no more free indices
   #  (and therefore argmin_idx indexes into -1), we set
   #  breakpoint_min_idx to 0 and flag that there are no free
@@ -1327,10 +1671,6 @@ def _cauchy_get_breakpoint_min(breakpoints, free_vars_idx):
   #  This is because in branching situations, indexing with
   #  breakpoint_min_idx can occur, and later be discarded, but all
   #  elements in breakpoint_min_idx must be a priori valid indices.
-  no_free = tf.gather(
-              free_vars_idx,
-              argmin_idx,
-              batch_dims=1) < 0
   breakpoint_min_idx = tf.where(
                         no_free,
                         0,
@@ -1349,185 +1689,6 @@ def _cauchy_get_breakpoint_min(breakpoints, free_vars_idx):
   return breakpoint_min_idx, breakpoint_min
 
 
-def _get_search_direction(state):
-  """Computes the search direction to follow at the current state.
-
-  On the `k`-th iteration of the main L-BFGS algorithm, the state has collected
-  the most recent `m` correction pairs in position_deltas and gradient_deltas,
-  where `k = state.num_iterations` and `m = min(k, num_correction_pairs)`.
-
-  Assuming these, the code below is an implementation of the L-BFGS two-loop
-  recursion algorithm given by [Nocedal and Wright(2006)][1]:
-
-  ```None
-    q_direction = objective_gradient
-    for i in reversed(range(m)):  # First loop.
-      inv_rho[i] = gradient_deltas[i]^T * position_deltas[i]
-      alpha[i] = position_deltas[i]^T * q_direction / inv_rho[i]
-      q_direction = q_direction - alpha[i] * gradient_deltas[i]
-
-    kth_inv_hessian_factor = (gradient_deltas[-1]^T * position_deltas[-1] /
-                              gradient_deltas[-1]^T * gradient_deltas[-1])
-    r_direction = kth_inv_hessian_factor * I * q_direction
-
-    for i in range(m):  # Second loop.
-      beta = gradient_deltas[i]^T * r_direction / inv_rho[i]
-      r_direction = r_direction + position_deltas[i] * (alpha[i] - beta)
-
-    return -r_direction  # Approximates - H_k * objective_gradient.
-  ```
-
-  Args:
-    state: A `LBfgsBOptimizerResults` tuple with the current state of the
-      search procedure.
-
-  Returns:
-    A real `Tensor` of the same shape as the `state.position`. The direction
-    along which to perform line search.
-  """
-  # The number of correction pairs that have been collected so far.
-  #num_elements = ps.minimum(
-  #    state.num_iterations,  # TODO(b/162733947): Change loop state -> closure.
-  #    ps.shape(state.position_deltas)[0])
-
-  def _two_loop_algorithm():
-    """L-BFGS two-loop algorithm."""
-    # Correction pairs are always appended to the end, so only the latest
-    # `num_elements` vectors have valid position/gradient deltas. Vectors
-    # that haven't been computed yet are zero.
-    position_deltas = state.position_deltas
-    gradient_deltas = state.gradient_deltas
-    num_correction_pairs, num_batches, _point_dims = \
-      ps.shape(gradient_deltas, out_type=tf.int32)
-
-    # Pre-compute all `inv_rho[i]`s.
-    inv_rhos = tf.reduce_sum(
-        gradient_deltas * position_deltas, axis=-1)
-
-    def first_loop(acc, args):
-      _, q_direction, num_iter = acc
-      position_delta, gradient_delta, inv_rho = args
-      active = (num_iter < state.history)
-      alpha = tf.math.divide_no_nan(
-                tf.reduce_sum(
-                  position_delta * q_direction,
-                  axis=-1),
-                inv_rho)
-      direction_delta = alpha[..., tf.newaxis] * gradient_delta
-      new_q_direction = tf.where(
-                          active[..., tf.newaxis],
-                          q_direction - direction_delta,
-                          q_direction)
-
-      return (alpha, new_q_direction, num_iter + 1)
-
-    # Run first loop body computing and collecting `alpha[i]`s, while also
-    # computing the updated `q_direction` at each step.
-    zero = tf.zeros_like(inv_rhos[0])
-    alphas, q_directions, _num_iters = tf.scan(
-        first_loop, [position_deltas, gradient_deltas, inv_rhos],
-        initializer=(zero, state.objective_gradient, 0), reverse=True)
-
-    # We use `H^0_k = gamma_k * I` as an estimate for the initial inverse
-    # hessian for the k-th iteration; then `r_direction = H^0_k * q_direction`.
-    idx = tf.transpose(
-            tf.stack(
-              [tf.where(
-                state.history > 0,
-                num_correction_pairs - state.history,
-                0),
-              tf.range(num_batches)]))
-    gamma_k = tf.math.divide_no_nan(
-                tf.gather_nd(inv_rhos, idx),
-                tf.reduce_sum(
-                  tf.gather_nd(gradient_deltas, idx)**2,
-                  axis=-1))
-    gamma_k = tf.where(
-                (state.history > 0),
-                gamma_k,
-                1.0)
-    r_direction = gamma_k[..., tf.newaxis] * tf.gather_nd(q_directions, idx)
-
-    def second_loop(acc, args):
-      r_direction, iter_idx = acc
-      alpha, position_delta, gradient_delta, inv_rho = args
-      active = (iter_idx >= num_correction_pairs - state.history)
-      beta = tf.math.divide_no_nan(
-              tf.reduce_sum(
-                gradient_delta * r_direction,
-                axis=-1),
-              inv_rho)
-      direction_delta = (alpha - beta)[..., tf.newaxis] * position_delta
-      new_r_direction = tf.where(
-                          active[..., tf.newaxis],
-                          r_direction + direction_delta,
-                          r_direction)
-      return (new_r_direction, iter_idx + 1)
-
-    # Finally, run second loop body computing the updated `r_direction` at each
-    # step.
-    r_directions, _num_iters = tf.scan(
-        second_loop, [alphas, position_deltas, gradient_deltas, inv_rhos],
-        initializer=(r_direction, 0))
-
-    return -r_directions[-1]
-
-  return ps.cond(tf.reduce_any(state.history != 0),
-                 _two_loop_algorithm,
-                 lambda: -state.objective_gradient)
-
-
-def _get_ragged_sizes(tensor, dtype=tf.int32):
-  """Creates a tensor indicating the size of each component of
-  a ragged dimension.
-
-  For example:
-
-  ```python
-  element = tf.ragged.constant([[1,2], [3,4,5], [], [0]])
-  _get_ragged_sizes(element)
-  # => <tf.Tensor: shape=(4, 1), dtype=int32, numpy=
-  #      array([[2],
-  #             [3],
-  #             [0],
-  #             [1]], dtype=int32)>
-  ```
-  """
-  return tf.reduce_sum(
-            tf.ones_like(
-              tensor,
-              dtype=dtype),
-            axis=-1)[..., tf.newaxis]
-
-
-def _get_range_like_ragged(tensor, dtype=tf.int32):
-  """Creates a batched range for the elements of the batched tensor.
-
-  For example:
-
-  ```python
-  element = tf.ragged.constant([[1,2], [3,4,5], [], [0]])
-  _get_range_like_ragged(element)
-  # => <tf.RaggedTensor [[0, 1], [0, 1, 2], [], [0]]>
-
-  Args:
-    tensor: a RaggedTensor of shape `[n, None]`.
-
-  Returns:
-    A ragged tensor of shape `[n, None]` where the ragged dimensions
-    match the ragged dimensions of `tensor`, and are a range from `0` to
-    the size of the ragged dimension.
-  ```
-  """
-  sizes = _get_ragged_sizes(tensor)
-  flat_ranges = tf.ragged.range(
-                  tf.reshape(
-                    sizes,
-                    [tf.reduce_prod(sizes.shape)]),
-                  dtype=dtype)
-  return tf.RaggedTensor.from_row_lengths(flat_ranges, sizes.shape[:-1])[0]
-
-
 def _make_empty_queue_for(k, element):
   """Creates a `tf.Tensor` suitable to hold `k` element-shaped tensors.
 

From 41b22679c6df83d8e9ea77146b1d010b7054ac40 Mon Sep 17 00:00:00 2001
From: mikeevmm <miguelmurca@gmail.com>
Date: Fri, 25 Jun 2021 14:49:44 +0100
Subject: [PATCH 3/4] refact: minor optimizations and documentation
 refactoring.

---
 .../python/optimizer/lbfgs.py                 | 1752 ++++++++++++++---
 1 file changed, 1434 insertions(+), 318 deletions(-)

diff --git a/tensorflow_probability/python/optimizer/lbfgs.py b/tensorflow_probability/python/optimizer/lbfgs.py
index 886a3084ef..e4382b5bf6 100644
--- a/tensorflow_probability/python/optimizer/lbfgs.py
+++ b/tensorflow_probability/python/optimizer/lbfgs.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""The Limited-Memory BFGS minimization algorithm.
+"""A constrained version of the Limited-Memory BFGS minimization algorithm.
 
 Limited-memory quasi-Newton methods are useful for solving large problems
 whose Hessian matrices cannot be computed at a reasonable cost or are not
@@ -20,8 +20,8 @@
 matrices, they only save a few vectors of length n that represent the
 approximations implicitly.
 
-This module implements the algorithm known as L-BFGS, which, as its name
-suggests, is a limited-memory version of the BFGS algorithm.
+This module implements the algorithm known as L-BFGS-B, which, as its name
+suggests, is a limited-memory version of the BFGS algorithm, with bounds.
 """
 from __future__ import absolute_import
 from __future__ import division
@@ -35,183 +35,261 @@
 from tensorflow_probability.python.internal import dtype_util
 from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.optimizer import bfgs_utils
-
-
-LBfgsOptimizerResults = collections.namedtuple(
-    'LBfgsOptimizerResults', [
-        'converged',  # Scalar boolean tensor indicating whether the minimum
-                      # was found within tolerance.
-        'failed',  # Scalar boolean tensor indicating whether a line search
-                   # step failed to find a suitable step size satisfying Wolfe
-                   # conditions. In the absence of any constraints on the
-                   # number of objective evaluations permitted, this value will
-                   # be the complement of `converged`. However, if there is
-                   # a constraint and the search stopped due to available
-                   # evaluations being exhausted, both `failed` and `converged`
-                   # will be simultaneously False.
-        'num_iterations',  # The number of iterations of the BFGS update.
-        'num_objective_evaluations',  # The total number of objective
-                                      # evaluations performed.
-        'position',  # A tensor containing the last argument value found
-                     # during the search. If the search converged, then
-                     # this value is the argmin of the objective function.
-        'objective_value',  # A tensor containing the value of the objective
-                            # function at the `position`. If the search
-                            # converged, then this is the (local) minimum of
-                            # the objective function.
-        'objective_gradient',  # A tensor containing the gradient of the
-                               # objective function at the
-                               # `final_position`. If the search converged
-                               # the max-norm of this tensor should be
-                               # below the tolerance.
-        'position_deltas',  # A tensor encoding information about the latest
-                            # changes in `position` during the algorithm
-                            # execution. Its shape is of the form
-                            # `(num_correction_pairs,) + position.shape` where
-                            # `num_correction_pairs` is given as an argument to
-                            # the minimize function.
-        'gradient_deltas',  # A tensor encoding information about the latest
-                            # changes in `objective_gradient` during the
-                            # algorithm execution. Has the same shape as
-                            # position_deltas.
-    ])
+from tensorflow_probability.python.optimizer import lbfgs_minimize
+
+
+LBfgsBOptimizerResults = collections.namedtuple(
+  'LBfgsBOptimizerResults', [
+    'converged',  # Scalar boolean tensor indicating whether the minimum
+            # was found within tolerance.
+    'failed',  # Scalar boolean tensor indicating whether a line search
+           # step failed to find a suitable step size satisfying Wolfe
+           # conditions. In the absence of any constraints on the
+           # number of objective evaluations permitted, this value will
+           # be the complement of `converged`. However, if there is
+           # a constraint and the search stopped due to available
+           # evaluations being exhausted, both `failed` and `converged`
+           # will be simultaneously False.
+    'num_iterations',  # The number of iterations of the BFGS update.
+    'num_objective_evaluations',  # The total number of objective
+                    # evaluations performed.
+    'position',  # A tensor containing the last argument value found
+           # during the search. If the search converged, then
+           # this value is the argmin of the objective function.
+    'lower_bounds',  # A tensor containing the lower bounds to the constrained
+                     # optimization, cast to the shape of `position`.
+    'upper_bounds',  # A tensor containing the upper bounds to the constrained
+                     # optimization, cast to the shape of `position`.
+    'objective_value',  # A tensor containing the value of the objective
+              # function at the `position`. If the search
+              # converged, then this is the (local) minimum of
+              # the objective function.
+    'objective_gradient',  # A tensor containing the gradient of the
+                 # objective function at the
+                 # `final_position`. If the search converged
+                 # the max-norm of this tensor should be
+                 # below the tolerance.
+    'position_deltas',  # A tensor encoding information about the latest
+              # changes in `position` during the algorithm
+              # execution. Its shape is of the form
+              # `(num_correction_pairs,) + position.shape` where
+              # `num_correction_pairs` is given as an argument to
+              # the minimize function.
+    'gradient_deltas',  # A tensor encoding information about the latest
+              # changes in `objective_gradient` during the
+              # algorithm execution. Has the same shape as
+              # position_deltas.
+    'history',  # How many gradient/position deltas should be considered.
+  ])
+
+_ConstrainedCauchyState = collections.namedtuple(
+  '_CauchyMinimizationResult', [
+    'theta',  # `\theta` in [2]; in the Cauchy search, relates to the implicit
+              # Hessian `B = \theta*I - WMW'` (`I` the identity; see [1,2]).
+    'm',  # The `M_k` matrix in [2]; part of the implicit representation of
+          # the Hessian (see the comment above).
+    'breakpoints',  # `t_i` in [Byrd et al.][2]; the breakpoints in the branch
+                    # definition of the projection of the gradients, batched.
+    'breakpoints_argsort',  # Range 0...n-1 sorted by increasing breakpoints.
+    'next_free_idx',  # Tensor of shape [batch]; the index into
+                      # `breakpoints_argsort` for the breakpoint in effect.
+    'steepest',  # `d` in [2]; steepest descent clamped to the bounds.
+    'p',  # As in [2].
+    'c',  # As in [2]; eventually made to equal `W'(cauchy_point - position)`.
+    'df',  # `f'` in [2].
+    'ddf',  # `f''` in [2].
+    'dt',  # `\Delta t` in [2].
+    'dt_min',  # `\Delta t_min` in [2].
+    'tsum',  # Sum of all the breakpoints considered so far.
+    'breakpoint_min_old',  # `t_old` in [2].
+    'cauchy_point',  # `x^cp` in [2]; the actual Cauchy point we are looking for.
+    'active',  # Which batches are in active optimization.
+    'free_mask',  # Boolean tensor indicating which variables are free (`True`)
+                  # rather than actively constrained (`False`).
+  ])
 
 
 def minimize(value_and_gradients_function,
-             initial_position,
-             previous_optimizer_results=None,
-             num_correction_pairs=10,
-             tolerance=1e-8,
-             x_tolerance=0,
-             f_relative_tolerance=0,
-             initial_inverse_hessian_estimate=None,
-             max_iterations=50,
-             parallel_iterations=1,
-             stopping_condition=None,
-             max_line_search_iterations=50,
-             name=None):
-  """Applies the L-BFGS algorithm to minimize a differentiable function.
-
-  Performs unconstrained minimization of a differentiable function using the
-  L-BFGS scheme. See [Nocedal and Wright(2006)][1] for details of the algorithm.
+       initial_position,
+       bounds=None,
+       previous_optimizer_results=None,
+       num_correction_pairs=10,
+       tolerance=1e-5,
+       x_tolerance=0,
+       f_relative_tolerance=0,
+       initial_inverse_hessian_estimate=None,
+       max_iterations=50,
+       parallel_iterations=1,
+       stopping_condition=None,
+       max_line_search_iterations=50,
+       name=None):
+  """Applies the L-BFGS-B algorithm to minimize a differentiable function.
+
+  Performs optionally constrained minimization of a differentiable function using the
+  L-BFGS-B scheme. See [Nocedal and Wright(2006)][1] for details on the unconstrained
+  version, and [Byrd et al.][2] for details on the constrained algorithm.
 
   ### Usage:
 
-  The following example demonstrates the L-BFGS optimizer attempting to find the
-  minimum for a simple high-dimensional quadratic objective function.
+  The following example demonstrates the L-BFGS-B optimizer attempting to find the
+  constrained minimum for a simple high-dimensional quadratic objective function.
 
   ```python
-    # A high-dimensional quadratic bowl.
-    ndims = 60
-    minimum = np.ones([ndims], dtype='float64')
-    scales = np.arange(ndims, dtype='float64') + 1.0
-
-    # The objective function and the gradient.
-    def quadratic_loss_and_gradient(x):
-      return tfp.math.value_and_gradient(
-          lambda x: tf.reduce_sum(
-              scales * tf.math.squared_difference(x, minimum), axis=-1),
-          x)
-    start = np.arange(ndims, 0, -1, dtype='float64')
-    optim_results = tfp.optimizer.lbfgs_minimize(
-        quadratic_loss_and_gradient,
-        initial_position=start,
-        num_correction_pairs=10,
-        tolerance=1e-8)
-
-    # Check that the search converged
-    assert(optim_results.converged)
-    # Check that the argmin is close to the actual value.
-    np.testing.assert_allclose(optim_results.position, minimum)
+  ndims = 60
+  minimum = tf.convert_to_tensor(
+    np.ones([ndims]), dtype=tf.float32)
+  lower_bounds = tf.convert_to_tensor(
+    np.arange(ndims), dtype=tf.float32)
+  upper_bounds = tf.convert_to_tensor(
+    np.arange(100, 100-ndims, -1), dtype=tf.float32)
+  scales = tf.convert_to_tensor(
+    (np.random.rand(ndims) + 1.)*5. + 1., dtype=tf.float32)
+  start = tf.constant(np.random.rand(2, ndims)*100, dtype=tf.float32)
+
+  # The objective function and the gradient.
+  def quadratic_loss_and_gradient(x):
+    return tfp.math.value_and_gradient(
+      lambda x: tf.reduce_sum(
+        scales * tf.math.squared_difference(x, minimum), axis=-1),
+      x)
+  opt_results = tfp.optimizer.lbfgsb_minimize(
+          quadratic_loss_and_gradient,
+          initial_position=start,
+          num_correction_pairs=10,
+          tolerance=1e-10,
+          bounds=[lower_bounds, upper_bounds])
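+
+  # A minimal follow-up check (not part of the original example, and assuming
+  # eager execution): the bounded search is expected to return positions that
+  # respect the box constraints.
+  assert np.all(opt_results.position.numpy() >= lower_bounds.numpy())
+  assert np.all(opt_results.position.numpy() <= upper_bounds.numpy())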
   ```
 
   ### References:
 
   [1] Jorge Nocedal, Stephen Wright. Numerical Optimization. Springer Series
-      in Operations Research. pp 176-180. 2006
+    in Operations Research. pp 176-180. 2006
 
   http://pages.mtu.edu/~struther/Courses/OLD/Sp2013/5630/Jorge_Nocedal_Numerical_optimization_267490.pdf
 
+  [2] Richard H. Byrd, Peihuang Lu, Jorge Nocedal, & Ciyou Zhu (1995).
+    A Limited Memory Algorithm for Bound Constrained Optimization
+    SIAM Journal on Scientific Computing, 16(5), 1190–1208.
+
+  https://doi.org/10.1137/0916069
+
+  [3] Jose Luis Morales, Jorge Nocedal (2011).
+    "Remark On Algorithm 788: L-BFGS-B: Fortran Subroutines for Large-Scale
+      Bound Constrained Optimization"
+    ACM Trans. Math. Softw. 38, 1, Article 7.
+
+  https://dl.acm.org/doi/abs/10.1145/2049662.2049669
+
   Args:
     value_and_gradients_function:  A Python callable that accepts a point as a
-      real `Tensor` and returns a tuple of `Tensor`s of real dtype containing
-      the value of the function and its gradient at that point. The function
-      to be minimized. The input is of shape `[..., n]`, where `n` is the size
-      of the domain of input points, and all others are batching dimensions.
-      The first component of the return value is a real `Tensor` of matching
-      shape `[...]`. The second component (the gradient) is also of shape
-      `[..., n]` like the input value to the function.
+    real `Tensor` and reporting arguments, and returns a tuple of `Tensor`s of
+    real dtype containing the value of the function and its gradient at that
+    point. The function to be minimized. The input is of shape `[..., n]`,
+    where `n` is the size of the domain of input points, and all others are
+    batching dimensions. The first component of the return value is a real
+    `Tensor` of matching shape `[...]`. The second component (the gradient) is
+    also of shape `[..., n]` like the input value to the function.
+    The reporting arguments consist of a Boolean `Tensor` of shape `[...]`
+    denoting which batches have terminated, and two real `Tensor`s of shape
+    `[..., n]`, denoting the last evaluated objective values and gradients
+    (respectively).
     initial_position: Real `Tensor` of shape `[..., n]`. The starting point, or
-      points when using batching dimensions, of the search procedure. At these
-      points the function value and the gradient norm should be finite.
-      Exactly one of `initial_position` and `previous_optimizer_results` can be
-      non-None.
-    previous_optimizer_results: An `LBfgsOptimizerResults` namedtuple to
-      intialize the optimizer state from, instead of an `initial_position`.
-      This can be passed in from a previous return value to resume optimization
-      with a different `stopping_condition`. Exactly one of `initial_position`
-      and `previous_optimizer_results` can be non-None.
+    points when using batching dimensions, of the search procedure. At these
+    points the function value and the gradient norm should be finite.
+    Exactly one of `initial_position` and `previous_optimizer_results` can be
+    non-None.
+    bounds: Tuple of two real `Tensor`s of shape `[..., n]`. The first element
+    indicates the lower bounds in the constrained optimization, and the second
+    element of the tuple indicates the upper bounds of the optimization. If
+    `bounds` is `None`, the optimization is deferred to the unconstrained
+    version (see also `lbfgs_minimize`). If one of the elements of the tuple
+    is `None`, the optimization is assumed to be unconstrained (from above/below,
+    respectively). 
+    previous_optimizer_results: An `LBfgsBOptimizerResults` namedtuple to
+    initialize the optimizer state from, instead of an `initial_position`.
+    This can be passed in from a previous return value to resume optimization
+    with a different `stopping_condition`. Exactly one of `initial_position`
+    and `previous_optimizer_results` can be non-None.
     num_correction_pairs: Positive integer. Specifies the maximum number of
-      (position_delta, gradient_delta) correction pairs to keep as implicit
-      approximation of the Hessian matrix.
+    (position_delta, gradient_delta) correction pairs to keep as implicit
+    approximation of the Hessian matrix.
     tolerance: Scalar `Tensor` of real dtype. Specifies the gradient tolerance
-      for the procedure. If the supremum norm of the gradient vector is below
-      this number, the algorithm is stopped.
+    for the procedure. If the supremum norm of the gradient vector is below
+    this number, the algorithm is stopped.
     x_tolerance: Scalar `Tensor` of real dtype. If the absolute change in the
-      position between one iteration and the next is smaller than this number,
-      the algorithm is stopped.
+    position between one iteration and the next is smaller than this number,
+    the algorithm is stopped.
     f_relative_tolerance: Scalar `Tensor` of real dtype. If the relative change
-      in the objective value between one iteration and the next is smaller
-      than this value, the algorithm is stopped.
+    in the objective value between one iteration and the next is smaller
+    than this value, the algorithm is stopped. The change is measured relative
+    to the largest of the current objective value, the previous objective
+    value, and `1`.
     initial_inverse_hessian_estimate: None. Option currently not supported.
     max_iterations: Scalar positive int32 `Tensor`. The maximum number of
-      iterations for L-BFGS updates.
+    iterations for L-BFGS updates.
     parallel_iterations: Positive integer. The number of iterations allowed to
-      run in parallel.
+    run in parallel.
     stopping_condition: (Optional) A Python function that takes as input two
-      Boolean tensors of shape `[...]`, and returns a Boolean scalar tensor.
-      The input tensors are `converged` and `failed`, indicating the current
-      status of each respective batch member; the return value states whether
-      the algorithm should stop. The default is tfp.optimizer.converged_all
-      which only stops when all batch members have either converged or failed.
-      An alternative is tfp.optimizer.converged_any which stops as soon as one
-      batch member has converged, or when all have failed.
+    Boolean tensors of shape `[...]`, and returns a Boolean scalar tensor.
+    The input tensors are `converged` and `failed`, indicating the current
+    status of each respective batch member; the return value states whether
+    the algorithm should stop. The default is tfp.optimizer.converged_all
+    which only stops when all batch members have either converged or failed.
+    An alternative is tfp.optimizer.converged_any which stops as soon as one
+    batch member has converged, or when all have failed.
     max_line_search_iterations: Python int. The maximum number of iterations
-      for the `hager_zhang` line search algorithm.
+    for the `hager_zhang` line search algorithm.
     name: (Optional) Python str. The name prefixed to the ops created by this
-      function. If not supplied, the default name 'minimize' is used.
+    function. If not supplied, the default name 'minimize' is used.
 
   Returns:
     optimizer_results: A namedtuple containing the following items:
-      converged: Scalar boolean tensor indicating whether the minimum was
-        found within tolerance.
-      failed:  Scalar boolean tensor indicating whether a line search
-        step failed to find a suitable step size satisfying Wolfe
-        conditions. In the absence of any constraints on the
-        number of objective evaluations permitted, this value will
-        be the complement of `converged`. However, if there is
-        a constraint and the search stopped due to available
-        evaluations being exhausted, both `failed` and `converged`
-        will be simultaneously False.
-      num_objective_evaluations: The total number of objective
-        evaluations performed.
-      position: A tensor containing the last argument value found
-        during the search. If the search converged, then
-        this value is the argmin of the objective function.
-      objective_value: A tensor containing the value of the objective
-        function at the `position`. If the search converged, then this is
-        the (local) minimum of the objective function.
-      objective_gradient: A tensor containing the gradient of the objective
-        function at the `position`. If the search converged the
-        max-norm of this tensor should be below the tolerance.
-      position_deltas: A tensor encoding information about the latest
-        changes in `position` during the algorithm execution.
-      gradient_deltas: A tensor encoding information about the latest
-        changes in `objective_gradient` during the algorithm execution.
+    converged: Scalar boolean tensor indicating whether the minimum was
+      found within tolerance.
+    failed:  Scalar boolean tensor indicating whether a line search
+      step failed to find a suitable step size satisfying Wolfe
+      conditions. In the absence of any constraints on the
+      number of objective evaluations permitted, this value will
+      be the complement of `converged`. However, if there is
+      a constraint and the search stopped due to available
+      evaluations being exhausted, both `failed` and `converged`
+      will be simultaneously False.
+    num_objective_evaluations: The total number of objective
+      evaluations performed.
+    position: A tensor containing the last argument value found
+      during the search. If the search converged, then
+      this value is the argmin of the objective function.
+    objective_value: A tensor containing the value of the objective
+      function at the `position`. If the search converged, then this is
+      the (local) minimum of the objective function.
+    objective_gradient: A tensor containing the gradient of the objective
+      function at the `position`. If the search converged the
+      max-norm of this tensor should be below the tolerance.
+    position_deltas: A tensor encoding information about the latest
+      changes in `position` during the algorithm execution.
+    gradient_deltas: A tensor encoding information about the latest
+      changes in `objective_gradient` during the algorithm execution.
   """
+
+  if bounds is None:
+    lower_bounds, upper_bounds = None, None
+  elif len(bounds) != 2:
+    raise ValueError(
+      '`bounds` parameter has unexpected number of elements '
+      '(expected 2).')
+  else:
+    lower_bounds, upper_bounds = bounds
+
+  # Defer further conversion of the bounds to appropriate tensors
+  # until the shape of the input is known
+
   if initial_inverse_hessian_estimate is not None:
     raise NotImplementedError(
-        'Support of initial_inverse_hessian_estimate arg not yet implemented')
+      'Support of initial_inverse_hessian_estimate arg not yet implemented')
 
   if stopping_condition is None:
     stopping_condition = bfgs_utils.converged_all
@@ -219,180 +297,1223 @@ def quadratic_loss_and_gradient(x):
   with tf.name_scope(name or 'minimize'):
     if (initial_position is None) == (previous_optimizer_results is None):
       raise ValueError(
-          'Exactly one of `initial_position` or '
-          '`previous_optimizer_results` may be specified.')
+        'Exactly one of `initial_position` or '
+        '`previous_optimizer_results` may be specified.')
 
     if initial_position is not None:
       initial_position = tf.convert_to_tensor(
-          initial_position, name='initial_position')
+        initial_position, name='initial_position')
+      # Force at least one batching dimension
+      if len(ps.shape(initial_position)) == 1:
+        initial_position = initial_position[tf.newaxis, :]
+      position_shape = ps.shape(initial_position)
       dtype = dtype_util.base_dtype(initial_position.dtype)
 
     if previous_optimizer_results is not None:
-      dtype = dtype_util.base_dtype(previous_optimizer_results.position.dtype)
+      position_shape = ps.shape(previous_optimizer_results.position)
+      dtype = dtype_util.base_dtype(
+        previous_optimizer_results.position.dtype)
+
+    # TODO: This isn't agnostic to the number of batch dimensions, it only
+    #  supports one batch dimension, but I've found RaggedTensors to be far
+    #  too finicky/undocumented to handle multiple batch dimensions in any
+    #  sane way. (Even the way it's working so far is less than ideal.)
+    if len(position_shape) > 2:
+      raise NotImplementedError(
+        "More than a batch dimension is not implemented. "
+        "Consider flattening and then reshaping the results.")
+    # NOTE: Broadcasting the batched dimensions breaks when there are no
+    #  batched dimensions. Although this isn't handled like this in
+    #  `lbfgs.py`, I'd rather force a batch dimension with a single
+    #  element than do conditional checks later.
+    if len(position_shape) == 1:
+      position_shape = tf.concat([[1], position_shape], axis=0)
+      initial_position = tf.broadcast_to(
+        initial_position, position_shape)
+
+    # NOTE: Could maybe use bfgs_utils._broadcast here, but would have to check
+    #  that the non-batching dimensions also match; using `tf.broadcast_to` has
+    #  the advantage that passing a (1,)-shaped tensor as bounds will correctly
+    #  bound every variable at the single value.
+    if lower_bounds is None:
+      lower_bounds = tf.constant(
+        [-float('inf')], shape=position_shape, dtype=dtype, name='lower_bounds')
+    else:
+      lower_bounds = tf.cast(
+        tf.convert_to_tensor(lower_bounds), dtype=dtype)
+      try:
+        lower_bounds = tf.broadcast_to(
+          lower_bounds, position_shape, name='lower_bounds')
+      except tf.errors.InvalidArgumentError:
+        raise ValueError(
+          'Failed to broadcast lower bounds tensor to the shape of starting '
+          'position. Are the lower bounds well formed?')
+    if upper_bounds is None:
+      upper_bounds = tf.constant(
+        [float('inf')], shape=position_shape, dtype=dtype, name='upper_bounds')
+    else:
+      upper_bounds = tf.cast(
+        tf.convert_to_tensor(upper_bounds), dtype=dtype)
+      try:
+        upper_bounds = tf.broadcast_to(
+          upper_bounds, position_shape, name='upper_bounds')
+      except tf.errors.InvalidArgumentError:
+        raise ValueError(
+          'Failed to broadcast upper bounds tensor to the shape of starting '
+          'position. Are the upper bounds well formed?')
+
+    # Clamp the starting position to the bounds, because the algorithm expects
+    # the variables to be in range for the Hessian inverse estimation, but also
+    # because that fast-tracks the first iteration of the Cauchy optimization.
+    initial_position = tf.clip_by_value(
+      initial_position, lower_bounds, upper_bounds)
 
     tolerance = tf.convert_to_tensor(
-        tolerance, dtype=dtype, name='grad_tolerance')
+      tolerance, dtype=dtype, name='grad_tolerance')
     f_relative_tolerance = tf.convert_to_tensor(
-        f_relative_tolerance, dtype=dtype, name='f_relative_tolerance')
+      f_relative_tolerance, dtype=dtype, name='f_relative_tolerance')
     x_tolerance = tf.convert_to_tensor(
-        x_tolerance, dtype=dtype, name='x_tolerance')
-    max_iterations = tf.convert_to_tensor(max_iterations, name='max_iterations')
+      x_tolerance, dtype=dtype, name='x_tolerance')
+    max_iterations = tf.convert_to_tensor(
+      max_iterations, name='max_iterations')
 
-    # The `state` here is a `LBfgsOptimizerResults` tuple with values for the
+    # The `state` here is a `LBfgsBOptimizerResults` tuple with values for the
     # current state of the algorithm computation.
     def _cond(state):
       """Continue if iterations remain and stopping condition is not met."""
       return ((state.num_iterations < max_iterations) &
-              tf.logical_not(stopping_condition(state.converged, state.failed)))
+          tf.logical_not(stopping_condition(state.converged, state.failed)))
 
     def _body(current_state):
       """Main optimization loop."""
       current_state = bfgs_utils.terminate_if_not_finite(current_state)
-      search_direction = _get_search_direction(current_state)
+      cauchy_state, current_state = _cauchy_minimization(
+        current_state, num_correction_pairs, parallel_iterations)
 
-      # TODO(b/120134934): Check if the derivative at the start point is not
-      # negative, if so then reset position/gradient deltas and recompute
-      # search direction.
+      search_direction, current_state, clip_before, refreshed = (
+        _find_search_direction(
+          current_state, cauchy_state, num_correction_pairs))
 
-      next_state = bfgs_utils.line_search_step(
-          current_state,
-          value_and_gradients_function, search_direction,
-          tolerance, f_relative_tolerance, x_tolerance, stopping_condition,
-          max_line_search_iterations)
+      # If any batch needs a refresh, restart the whole thing, to reduce number
+      # of function evaluations
 
-      # If not failed or converged, update the Hessian estimate.
-      should_update = ~(next_state.converged | next_state.failed)
-      state_after_inv_hessian_update = bfgs_utils.update_fields(
+      def _continue_minimization():
+        """Proceeds with minimization iteration."""
+        next_state = _constrained_line_search_step(
+          current_state, value_and_gradients_function, search_direction,
+          tolerance, f_relative_tolerance, x_tolerance, stopping_condition,
+          max_line_search_iterations, clip_before)
+
+        # If not failed or converged, update the Hessian estimate.
+        # Only do this if the new pair obeys s.y >= eps * ||g||
+        position_delta = (next_state.position - current_state.position)
+        gradient_delta = (next_state.objective_gradient -
+                  current_state.objective_gradient)
+        # Article is ambiguous; see lbfgs.f:863
+        curvature_cond = (
+          tf.reduce_sum(position_delta * gradient_delta, axis=-1) >=
+          bfgs_utils.norm(current_state.objective_gradient, dims=1) *
+          dtype_util.eps(position_delta.dtype))
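+        # Pairs that fail this curvature test are skipped rather than pushed,
+        # which helps keep the implicit Hessian approximation positive definite.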
+        should_push = (~(next_state.converged | next_state.failed) &
+                 curvature_cond & ~refreshed)
+        # TODO: Track number of skipped pairs
+        new_position_deltas = _queue_push(
+          next_state.position_deltas, should_push, position_delta)
+        new_gradient_deltas = _queue_push(
+          next_state.gradient_deltas, should_push, gradient_delta)
+        new_history = tf.where(
+          should_push,
+          tf.math.minimum(next_state.history + 1,
+                  num_correction_pairs),
+          next_state.history)
+
+        if not tf.executing_eagerly():
+          # Hint the compiler that the shape of the properties has not changed
+          new_position_deltas = tf.ensure_shape(
+            new_position_deltas, next_state.position_deltas.shape)
+          new_gradient_deltas = tf.ensure_shape(
+            new_gradient_deltas, next_state.gradient_deltas.shape)
+          new_history = tf.ensure_shape(
+            new_history, next_state.history.shape)
+
+        next_state = bfgs_utils.update_fields(
           next_state,
-          position_deltas=_queue_push(
-              current_state.position_deltas, should_update,
-              next_state.position - current_state.position),
-          gradient_deltas=_queue_push(
-              current_state.gradient_deltas, should_update,
-              next_state.objective_gradient - current_state.objective_gradient))
-      return [state_after_inv_hessian_update]
+          position_deltas=new_position_deltas,
+          gradient_deltas=new_gradient_deltas,
+          history=new_history)
+
+        return [next_state]
+
+      return tf.cond(
+        pred=tf.reduce_any(refreshed),
+        true_fn=lambda: [current_state],
+        false_fn=_continue_minimization)
 
     if previous_optimizer_results is None:
       assert initial_position is not None
       initial_state = _get_initial_state(value_and_gradients_function,
-                                         initial_position,
-                                         num_correction_pairs,
-                                         tolerance)
+                         initial_position,
+                         lower_bounds,
+                         upper_bounds,
+                         num_correction_pairs,
+                         tolerance)
     else:
       initial_state = previous_optimizer_results
 
     return tf.while_loop(
-        cond=_cond,
-        body=_body,
-        loop_vars=[initial_state],
-        parallel_iterations=parallel_iterations)[0]
-
-
-def _get_initial_state(value_and_gradients_function,
-                       initial_position,
-                       num_correction_pairs,
-                       tolerance):
-  """Create LBfgsOptimizerResults with initial state of search procedure."""
-  init_args = bfgs_utils.get_initial_state_args(
-      value_and_gradients_function,
-      initial_position,
-      tolerance)
-  empty_queue = _make_empty_queue_for(num_correction_pairs, initial_position)
-  init_args.update(position_deltas=empty_queue, gradient_deltas=empty_queue)
-  return LBfgsOptimizerResults(**init_args)
+      cond=_cond,
+      body=_body,
+      loop_vars=[initial_state],
+      parallel_iterations=parallel_iterations)[0]
 
 
-def _get_search_direction(state):
-  """Computes the search direction to follow at the current state.
+def _cauchy_minimization(bfgs_state, num_correction_pairs, parallel_iterations):
+  """Calculates the Cauchy point, bounding the gradient by the bounds.
 
-  On the `k`-th iteration of the main L-BFGS algorithm, the state has collected
-  the most recent `m` correction pairs in position_deltas and gradient_deltas,
-  where `k = state.num_iterations` and `m = min(k, num_correction_pairs)`.
+  This function minimizes the quadratic approximation to the objective
+  function at the current position, in the direction of steepest descent,
+  but bounding the gradient by the corresponding bounds.
 
-  Assuming these, the code below is an implementation of the L-BFGS two-loop
-  recursion algorithm given by [Nocedal and Wright(2006)][1]:
+  See algorithm CP and associated discussion of [Byrd,Lu,Nocedal,Zhu][2]
+  for details.
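+
+  Concretely (a summary of [2], stated here only for orientation): the
+  generalized Cauchy point is the first local minimizer of the quadratic model
+
+  ```None
+    m(x) = f(x_k) + g_k^T (x - x_k) + 0.5 (x - x_k)^T B_k (x - x_k)
+  ```
+
+  along the piecewise-linear path `x(t) = P(x_k - t*g_k, l, u)`, where `P`
+  clips its first argument to the bounds `[l, u]` and `B_k = theta*I - W M W'`
+  is the limited-memory Hessian approximation.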
 
-  ```None
-    q_direction = objective_gradient
-    for i in reversed(range(m)):  # First loop.
-      inv_rho[i] = gradient_deltas[i]^T * position_deltas[i]
-      alpha[i] = position_deltas[i]^T * q_direction / inv_rho[i]
-      q_direction = q_direction - alpha[i] * gradient_deltas[i]
+  This function may modify the given `bfgs_state`, in that it refreshes the
+  memory for batches that are found to be in an invalid state.
 
-    kth_inv_hessian_factor = (gradient_deltas[-1]^T * position_deltas[-1] /
-                              gradient_deltas[-1]^T * gradient_deltas[-1])
-    r_direction = kth_inv_hessian_factor * I * q_direction
+  Args:
+    bfgs_state: current `LBfgsBOptimizerResults` state
+    num_correction_pairs: the (maximum) number of past steps to keep as
+    history for the LBFGS algorithm
+    parallel_iterations: the `parallel_iterations` argument of the inner
+    `tf.while_loop`
+  Returns:
+    A `_CauchyMinimizationResult` containing the results of the Cauchy point
+    computation.
+    Updated `bfgs_state`
+  """
+  cauchy_state, bfgs_state = _get_initial_cauchy_state(
+    bfgs_state, num_correction_pairs)
+  n = ps.shape(bfgs_state.position)[-1]
+  idx_range = tf.range(ps.shape(bfgs_state.position)[-1])[tf.newaxis, ...]
+  # NOTE: See lbfgsb.f (l. 1524)
+  ddf_org = -cauchy_state.theta * cauchy_state.df
+
+  def _cond(state):
+    """Test convergence to Cauchy point at current branch"""
+    return tf.reduce_any(state.active)
+
+  def _body(state):
+    """Cauchy point iterative loop (While loop of CP algorithm [2])"""
+    # Because of `where` statements, the indices for gathering must always
+    # be valid, even if the result is not used afterwards. For batches that
+    # are no longer active, the `next_free_idx` (which points to the index
+    # of the current minimum breakpoint via `breakpoints_argsort`) may
+    # exceed the size of `breakpoints_argsort` (if the batch isn't active
+    # because there are no free variables left). So, instead, we take 0 as a
+    # dummy value, which will later be discarded by the `where` statements.
+    next_free_idx = tf.where(
+      state.active,
+      state.next_free_idx,
+      0)
+    breakpoint_min_idx = tf.where(
+      state.active,
+      tf.gather(
+        state.breakpoints_argsort,
+        next_free_idx,
+        batch_dims=1),
+      0)
+    breakpoint_min = tf.where(
+      state.active,
+      tf.gather(
+        state.breakpoints,
+        breakpoint_min_idx,
+        batch_dims=1),
+      state.breakpoint_min_old)
+
+    dt = (breakpoint_min - state.breakpoint_min_old)
+
+    # NOTE: We immediately update active to simulate an early return
+    # This value should be used below (instead of `state.active`)
+    active = (state.active & (state.dt_min >= dt))
+
+    # Set the considered variable as fixed
+    tsum = tf.where(active, state.tsum + dt, state.tsum)
+    breakpoint_min_idx_mask = (
+      idx_range == breakpoint_min_idx[..., tf.newaxis])
+    steepest = tf.where(
+      active[..., tf.newaxis],
+      tf.where(
+        breakpoint_min_idx_mask,
+        0.,
+        state.steepest),
+      state.steepest)
+    free_mask = tf.where(
+      active[..., tf.newaxis],
+      (state.free_mask & ~breakpoint_min_idx_mask),
+      state.free_mask)
+    d_b = tf.gather(
+      state.steepest,
+      breakpoint_min_idx,
+      batch_dims=1)
+    x_cp_b = tf.gather(
+      tf.where(
+        (d_b > 0.)[..., tf.newaxis],
+        bfgs_state.upper_bounds,
+        tf.where(
+            (d_b < 0.)[..., tf.newaxis],
+          bfgs_state.lower_bounds,
+          state.cauchy_point
+        )),
+      breakpoint_min_idx,
+      batch_dims=1)
+    cauchy_point = tf.where(
+      active[..., tf.newaxis],
+      tf.where(
+        breakpoint_min_idx_mask,
+        x_cp_b[..., tf.newaxis],
+        state.cauchy_point),
+      state.cauchy_point)
+
+    # If we're out of free variables, set dt_min to dt and "return"
+    next_free_idx = tf.where(active, next_free_idx + 1, next_free_idx)
+    no_more_free = (next_free_idx >= n)
+    dt_min = tf.where(no_more_free, dt, state.dt_min)
+    active &= ~no_more_free
+
+    # Update remaining properties
+    # - Update `c`
+    c = tf.where(
+      active[..., tf.newaxis],
+      state.c + dt[..., tf.newaxis] * state.p,
+      state.c)
+    # - Get the `b`th row of W (needed for f', f'')
+    # The matrix M has shape
+    #
+    #  [[ 0  0   ]
+    #   [ 0  M_h ]]
+    #
+    # where M_h is the M matrix considering the current history `h`.
+    # Therefore, for W, we should consider that the last `h` columns are
+    #     Y[k-h,...,k-1]  theta*S[k-h,...,k-1]
+    # (so that the first `2*(m-h)` columns are 0).
+    # 1. Create the "full" W matrix row
+    w_b = tf.concat(
+      [tf.gather(
+        bfgs_state.gradient_deltas,
+        breakpoint_min_idx,
+        axis=-1,
+        batch_dims=1),
+       (state.theta[..., tf.newaxis] *
+        tf.gather(
+         bfgs_state.position_deltas,
+         breakpoint_min_idx,
+         axis=-1,
+         batch_dims=1))
+       ],
+      axis=-1)
+    # 2. "Permute" the relevant items to the right
+    idx = tf.concat(
+      [
+        tf.ragged.range(
+          num_correction_pairs - bfgs_state.history),
+        tf.ragged.range(
+          num_correction_pairs,
+          2*num_correction_pairs - bfgs_state.history),
+        tf.ragged.range(
+          num_correction_pairs - bfgs_state.history,
+          num_correction_pairs),
+        tf.ragged.range(
+          2*num_correction_pairs - bfgs_state.history,
+          2*num_correction_pairs)
+      ],
+      axis=-1).to_tensor()
+    w_b = tf.gather(
+      w_b,
+      idx,
+      batch_dims=1)
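+    # For example, with `num_correction_pairs = 2` and `history = 1`, `idx` is
+    # `[0, 2, 1, 3]`, so the two zero entries of `w_b` move to the front and
+    # the valid (y, theta*s) entries move to the back, matching the block
+    # layout of `M`.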
+
+    # - Update f'
+    x_b = tf.gather(
+      bfgs_state.position,
+      breakpoint_min_idx,
+      batch_dims=1)
+    # NOTE Use of d_b = -g_b
+    df = tf.where(
+      active,
+      (state.df + dt * state.ddf +
+       d_b**2 -
+       state.theta * d_b * (x_cp_b - x_b) +
+       d_b * tf.einsum(
+         '...j,...jk,...k->...',
+         w_b,
+         state.m,
+         c)),
+      state.df)
+
+    # - Update f''
+    # NOTE use of d_b = -g_b
+    ddf = tf.where(
+      active,
+      (state.ddf - state.theta * d_b**2 +
+       2. * d_b * tf.einsum(
+         "...i,...ij,...j->...",
+         w_b,
+         state.m,
+         state.p) -
+       d_b**2 * tf.einsum(
+         "...i,...ij,...j->...",
+         w_b,
+         state.m,
+         w_b)),
+      state.ddf)
+    # NOTE: See lbfgsb.f (l. 1649)
+    ddf = tf.where(
+      active,
+      tf.math.maximum(ddf, dtype_util.eps(ddf.dtype)*ddf_org),
+      state.ddf)
+
+    # - Update p
+    # NOTE use of d_b = -g_b
+    p = tf.where(
+      active[..., tf.newaxis],
+      state.p - d_b[..., tf.newaxis] * w_b,
+      state.p)
+
+    # - Update dt_min
+    dt_min = tf.where(
+      active, -tf.math.divide_no_nan(df, ddf), state.dt_min)
+
+    # Create the updated state
+
+    # We need to hint the compiler that nothing changed shapes
+    if not tf.executing_eagerly():
+      steepest = tf.ensure_shape(steepest, state.steepest.shape)
+      p = tf.ensure_shape(p, state.p.shape)
+      c = tf.ensure_shape(c, state.c.shape)
+      df = tf.ensure_shape(df, state.df.shape)
+      ddf = tf.ensure_shape(ddf, state.ddf.shape)
+      dt = tf.ensure_shape(dt, state.dt.shape)
+      dt_min = tf.ensure_shape(dt_min, state.dt_min.shape)
+      tsum = tf.ensure_shape(tsum, state.tsum.shape)
+      breakpoint_min = tf.ensure_shape(
+        breakpoint_min, state.breakpoint_min_old.shape)
+      next_free_idx = tf.ensure_shape(
+        next_free_idx, state.next_free_idx.shape)
+      cauchy_point = tf.ensure_shape(
+        cauchy_point, state.cauchy_point.shape)
+      free_mask = tf.ensure_shape(free_mask, state.free_mask.shape)
+      active = tf.ensure_shape(active, state.active.shape)
+
+    new_state = bfgs_utils.update_fields(
+      state, steepest=steepest, p=p, c=c, df=df, ddf=ddf, dt=dt,
+      dt_min=dt_min, tsum=tsum, breakpoint_min_old=breakpoint_min,
+      next_free_idx=next_free_idx, cauchy_point=cauchy_point,
+      free_mask=free_mask, active=active)
+
+    return [new_state]
+
+  cauchy_loop = tf.while_loop(
+    cond=_cond,
+    body=_body,
+    loop_vars=[cauchy_state],
+    parallel_iterations=parallel_iterations)[0]
+
+  # NOTE: See lbfgsb.f lines 1584, 1606, 1667, 1682
+  free_remaining = (cauchy_loop.next_free_idx < n)
+  dt_min = tf.where(
+    free_remaining,
+    tf.math.maximum(cauchy_loop.dt_min, 0),
+    cauchy_loop.dt_min)
+  tsum = tf.where(
+    free_remaining,
+    cauchy_loop.tsum + dt_min,
+    cauchy_loop.tsum)
+
+  cauchy_point = tf.where(
+    (bfgs_state.converged | bfgs_state.failed)[..., tf.newaxis],
+    bfgs_state.position,
+    tf.where(
+      free_remaining[..., tf.newaxis],
+      cauchy_loop.cauchy_point +
+      tsum[..., tf.newaxis] * cauchy_loop.steepest,
+      cauchy_loop.cauchy_point))
+
+  c = cauchy_loop.c + dt_min[..., tf.newaxis]*cauchy_loop.p
+  # NOTE: `c` is already permuted to match the subspace of `M`, because `w_b`
+  #  was already permuted.
+  # You can explicitly check this by comparing its value with W'.(x^c - x)
+  #  at this point.
+
+  # Set points where gradient is 0 as fixed
+  # TODO: Does this cause problems with saddle points?
+  free_mask = (cauchy_loop.free_mask & (bfgs_state.objective_gradient != 0))
+
+  # Hint the compiler that shape of things will not change
+  if not tf.executing_eagerly():
+    dt_min = tf.ensure_shape(dt_min, cauchy_loop.dt_min.shape)
+    tsum = tf.ensure_shape(tsum, cauchy_loop.tsum.shape)
+    cauchy_point = tf.ensure_shape(
+      cauchy_point, cauchy_loop.cauchy_point.shape)
+    c = tf.ensure_shape(c, cauchy_loop.c.shape)
+    free_mask = tf.ensure_shape(free_mask, cauchy_loop.free_mask.shape)
+  # Do the actual updating
+  final_cauchy_state = bfgs_utils.update_fields(
+    cauchy_loop, dt_min=dt_min, tsum=tsum, cauchy_point=cauchy_point, c=c,
+    free_mask=free_mask)
+
+  return final_cauchy_state, bfgs_state
+
+
+def _get_initial_cauchy_state(bfgs_state, num_correction_pairs):
+  """Create `_ConstrainedCauchyState` with initial parameters.
+
+  This will calculate the elements of `_ConstrainedCauchyState` based on the
+  given `LBfgsBOptimizerResults` state object. Some of these properties cannot
+  be computed for every batch member; for those batches the BFGS history is
+  reset (hence the updated `bfgs_state` in the returns).
 
-    for i in range(m):  # Second loop.
-      beta = gradient_deltas[i]^T * r_direction / inv_rho[i]
-      r_direction = r_direction + position_deltas[i] * (alpha[i] - beta)
+  Args:
+    bfgs_state: `LBfgsBOptimizerResults` object representing the current state
+      of the L-BFGS-B optimization.
+    num_correction_pairs: typically `m`; the (maximum) number of past steps to
+      keep as history for the L-BFGS algorithm.
 
-    return -r_direction  # Approximates - H_k * objective_gradient.
-  ```
+  Returns:
+    Initialized `_ConstrainedCauchyState`
+    Updated `bfgs_state`
+  """
+  cauchy_point = bfgs_state.position
+
+  theta = tf.math.divide_no_nan(
+    tf.reduce_sum(bfgs_state.gradient_deltas[..., -1, :]**2, axis=-1),
+    (tf.reduce_sum(bfgs_state.gradient_deltas[..., -1, :] *
+             bfgs_state.position_deltas[..., -1, :], axis=-1)))
+  theta = tf.where(bfgs_state.history == 0, 1., theta)
+
+  m, refresh = _cauchy_init_m(bfgs_state, theta, num_correction_pairs)
+
+  # Erase the history where M isn't invertible
+  bfgs_state = _erase_history(bfgs_state, refresh)
+  theta = tf.where(refresh, 1., theta)
+
+  breakpoints = _cauchy_init_breakpoints(bfgs_state)
+  breakpoints_argsort = tf.argsort(breakpoints)
+
+  steepest = tf.where((breakpoints > 0.), -bfgs_state.objective_gradient, 0.)
+
+  # We need to account for the varying histories:
+  # we assume that the first `2*(m-h)` rows of W'^T
+  # are 0 (where `m` is the number of correction pairs
+  # and `h` is the history), in concordance with the first
+  # `2*(m-h)` rows of M being 0.
+  # 1. Calculate all elements
+  p = tf.concat(
+    [
+      tf.einsum(
+        "...mi,...i->...m",
+        bfgs_state.gradient_deltas,
+        steepest),
+      (theta[..., tf.newaxis] *
+       tf.einsum(
+        "...mi,...i->...m",
+        bfgs_state.position_deltas,
+        steepest))
+    ],
+    axis=-1)
+  # 2. Assemble the rows in the correct order
+  idx = tf.concat(
+    [
+      tf.ragged.range(
+        num_correction_pairs - bfgs_state.history),
+      tf.ragged.range(
+        num_correction_pairs,
+        2*num_correction_pairs - bfgs_state.history),
+      tf.ragged.range(
+        num_correction_pairs - bfgs_state.history,
+        num_correction_pairs),
+      tf.ragged.range(
+        2*num_correction_pairs - bfgs_state.history,
+        2*num_correction_pairs)
+    ],
+    axis=-1).to_tensor()
+  p = tf.gather(
+    p,
+    idx,
+    batch_dims=1)
+
+  c = tf.zeros_like(p)
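+  # Initial directional derivatives of the piecewise quadratic along the
+  # projected steepest-descent path (cf. [2]):
+  #   f'  = -||steepest||^2,    f'' = -theta*f' - p^T.M.p,
+  # and the first candidate step is dt_min = -f'/f''.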
+  df = -tf.reduce_sum(steepest**2, axis=-1)
+  ddf = -theta*df - tf.einsum("...i,...ij,...j->...", p, m, p)
+  dt_min = -tf.math.divide_no_nan(df, ddf)
+  tsum = tf.zeros_like(dt_min)
+
+  # NOTE: These are placeholder values.
+  # All of these have shape [batch], which matches dt_min
+  dt = tf.zeros_like(dt_min)
+  breakpoint_min_old = tf.zeros_like(dt_min)
+
+  next_free_idx = tf.reduce_sum(tf.where(breakpoints <= 0., 1, 0), axis=-1)
+  free_mask = (breakpoints > 0.)
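+  # Variables with a non-positive breakpoint are already at their bound in the
+  # descent direction and start out fixed; `next_free_idx` therefore indexes
+  # the first strictly positive breakpoint in `breakpoints_argsort`.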
+
+  # NOTE: _cauchy_update_active should NOT be accounted for here; the first
+  # iteration should always run (if the batch is overall active)
+  active = ~(bfgs_state.converged | bfgs_state.failed)
+
+  cauchy_state = _ConstrainedCauchyState(
+    theta=theta, m=m, breakpoints=breakpoints,
+    breakpoints_argsort=breakpoints_argsort, next_free_idx=next_free_idx,
+    steepest=steepest, p=p, c=c, df=df, ddf=ddf, dt=dt, dt_min=dt_min,
+    tsum=tsum, breakpoint_min_old=breakpoint_min_old, cauchy_point=cauchy_point,
+    active=active, free_mask=free_mask)
+
+  return cauchy_state, bfgs_state
+
+
+def _cauchy_init_breakpoints(state):
+  """Calculate the breakpoints for a `_CauchyMinimizationResult` state."""
+  breakpoints = (
+    tf.where(
+      state.objective_gradient < 0,
+      tf.math.divide_no_nan(
+        state.position - state.upper_bounds,
+        state.objective_gradient),
+      tf.where(
+        state.objective_gradient > 0,
+        tf.math.divide_no_nan(
+          state.position - state.lower_bounds,
+          state.objective_gradient),
+        float('inf')))
+  )
+
+  return breakpoints
+
+
+def _find_search_direction(bfgs_state, cauchy_state, num_correction_pairs):
+  """Finds the search direction based on the direct primal method.
+
+  This function corresponds to points 1-6 of the Direct Primal Method presented
+  in [2, p. 1199] for subspace minimization, with the first modification
+  suggested in [3].
+
+  If an invalid condition is reached for a given batch, its history is reset.
+  Therefore, this function also returns an updated `bfgs_state`. 
 
   Args:
-    state: A `LBfgsOptimizerResults` tuple with the current state of the
-      search procedure.
+    bfgs_state: the `LBfgsBOptimizerResults` object representing the current
+      iteration.
+    cauchy_state: the `_CauchyMinimizationResult` results of a Cauchy search
+      computation. Typically the output of `_cauchy_minimization`.
+    num_correction_pairs: the (maximum) number of correction pairs stored in
+      memory (`m`).
+  Returns:
+    Tensor of batched search directions,
+    Updated `bfgs_state`,
+    Tensor of Boolean dtype indicating whether the search direction should be
+      clamped to bounds before the search is performed,
+    Tensor of Boolean dtype indicating which batches have been refreshed.
+  """
+  def _find_constrained_minimizer():
+    """Performs free subspace minimization based on the Direct Method."""
+    # Let the reduced gradient be [2, eq. 5.4]
+    #
+    #     ρ = Z'r
+    #     r = g + Θ(x^c - x) + (1/Θ).W.M.c
+    #
+    # and the search direction [2, eq. 5.7]
+    #
+    #     d = -B⁻¹ρ
+    #
+    # and [2, eq. 5.10]
+    #
+    #     B⁻¹ = 1/Θ [ I + 1/Θ Z'.W.N⁻¹.M.W'.Z ]
+    #     N   = I - 1/Θ M.W'.Z.Z'.W
+    #
+    # Therefore,
+    #
+    #     d = Z' . (-1/Θ) . [ r + 1/Θ W.N⁻¹.M.W'.Z.Z'.r ]
+    #
+    # NOTE that the leading sign does not match that of [2, eq. 5.11]. This is
+    # because the article conflates the definition of r in [2, eq. 5.4] and the
+    # definition employed in the Fortran implementation, where
+    #
+    #    r = -Z'B(x^c - x) - Z'g
+    #
+    # From which follows
+    #
+    #    d = Z'  (1/Θ) . [ r + 1/Θ W.N⁻¹.M.W'.Z.Z'.r ]
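+    # In the code below, `w_transpose` plays the role of W' (with the
+    # varying-history permutation applied), `cauchy_state.m` is M,
+    # `cauchy_state.theta` is Θ, and the Z.Z' projections are implemented by
+    # masking the variable dimension with `cauchy_state.free_mask`.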
+    idx = (
+      tf.concat([
+        tf.ragged.range(
+          num_correction_pairs - bfgs_state.history),
+        tf.ragged.range(
+          num_correction_pairs,
+          2*num_correction_pairs - bfgs_state.history),
+        tf.ragged.range(
+          num_correction_pairs - bfgs_state.history,
+          num_correction_pairs),
+        tf.ragged.range(
+          2*num_correction_pairs - bfgs_state.history,
+          2*num_correction_pairs)
+      ],
+        axis=-1).to_tensor())
+
+    w_transpose = (
+      tf.gather(
+        tf.concat(
+          [bfgs_state.gradient_deltas,
+           cauchy_state.theta[..., tf.newaxis, tf.newaxis] *
+           bfgs_state.position_deltas],
+          axis=-2),
+        idx,
+        batch_dims=1)
+    )
+
+    r = (
+      cauchy_state.theta[..., tf.newaxis] *
+      (bfgs_state.position - cauchy_state.cauchy_point) +
+      tf.einsum(
+        '...ji,...jk,...k->...i',
+        w_transpose,
+        cauchy_state.m,
+        cauchy_state.c) -
+      bfgs_state.objective_gradient)
+
+    n = (
+      tf.eye(
+        num_rows=num_correction_pairs*2,
+        batch_shape=ps.shape(bfgs_state.position)[:-1]) -
+      (tf.einsum(
+        '...ij,...jk,...lk->...il',
+        cauchy_state.m,
+        w_transpose,
+        tf.where(
+          cauchy_state.free_mask[..., tf.newaxis, :],
+          w_transpose,
+          0.)
+      ) / cauchy_state.theta[..., tf.newaxis, tf.newaxis]))
+
+    # NOTE: No need to "mask" the no-history subspace of N: because of I - (...)
+    # we correctly get a block form. The extraneous identity block is then
+    # zeroed when the product with M is taken
+    refresh = (tf.linalg.det(n) == 0.)
+
+    n = tf.linalg.inv(
+      tf.where(
+        refresh[..., tf.newaxis, tf.newaxis],
+        tf.eye(
+          num_rows=num_correction_pairs*2,
+          batch_shape=ps.shape(bfgs_state.position)[:-1]),
+        n))
+
+    n = tf.where(
+      refresh[..., tf.newaxis, tf.newaxis],
+      tf.zeros_like(n),
+      n)
+
+    # d is composed in three parts
+    d = tf.einsum('...ji,...jk,...kl,...lm,...m->...i',
+            w_transpose,
+            n,
+            cauchy_state.m,
+            tf.where(
+              cauchy_state.free_mask[..., tf.newaxis, :],
+              w_transpose,
+              0.),
+            r)
+
+    d = r + d/cauchy_state.theta[..., tf.newaxis]
+    d = d/cauchy_state.theta[..., tf.newaxis]
+
+    d = tf.where(cauchy_state.free_mask, d, 0.)
+
+    # Per [3]:
+    # Project `(cauchy point) + d` into the bounds
+    # NOTE: `d` is zeroed for constrained variables, and `movement_clip` is
+    # at most 1.
+    minimizer = tf.clip_by_value(
+      cauchy_state.cauchy_point + d,
+      bfgs_state.lower_bounds,
+      bfgs_state.upper_bounds)
+
+    # Per [3]: If the search direction obtained with this minimizer is not a
+    # direction of strong descent, allow the minimizer to be oob, and clip the
+    # direction (i.e. fall back to the original algorithm). The clipping is
+    # handled outside this fn.
+    fallback = (tf.reduce_sum((minimizer - bfgs_state.position) *
+                  bfgs_state.objective_gradient, axis=-1) > 0)
+
+    minimizer = tf.where(
+      fallback[..., tf.newaxis],
+      cauchy_state.cauchy_point + d,
+      minimizer)
+
+    active = (tf.reduce_any(cauchy_state.free_mask, axis=-1) &
+          (bfgs_state.history > 0))
+    minimizer = tf.where(
+      active[..., tf.newaxis], minimizer, cauchy_state.cauchy_point)
+
+    return minimizer, refresh, fallback
+
+  # NOTE: we're abusing `bfgs_state.history.shape` again to get the batch
+  # dimensions. Also: the Cauchy point is a minimization along the (projected)
+  # minus gradient direction; this is why we can skip subspace minimization if
+  # there's no history (because the search direction would indeed have been the
+  # minus gradient), but should run it otherwise (to make use of the BFGS
+  # information).
+  skip_subspace = (
+    (~tf.reduce_any(cauchy_state.free_mask)) |
+    tf.reduce_all(bfgs_state.history == 0))
+  minimizer, refresh, clip_before = (
+    tf.cond(
+      pred=skip_subspace,
+      true_fn=lambda: (cauchy_state.cauchy_point,
+               tf.broadcast_to(
+                 False, ps.shape(bfgs_state.history)),
+               tf.broadcast_to(True, ps.shape(bfgs_state.history))),
+      false_fn=_find_constrained_minimizer))
+
+  search_direction = (minimizer - bfgs_state.position)
+
+  # Reset if the search direction still isn't a direction of strong descent
+  refresh |= (
+    tf.reduce_sum(
+      search_direction * bfgs_state.objective_gradient, axis=-1) > 0)
+
+  # Refresh conditions only make sense if a batch had not already converged
+  refresh &= ~ (bfgs_state.converged | bfgs_state.failed)
+
+  # Apply refresh
+  bfgs_state = _erase_history(bfgs_state, refresh)
+
+  return search_direction, bfgs_state, clip_before, refresh
+
+
+def _constrained_line_search_step(bfgs_state, value_and_gradients_function,
+                  search_direction, grad_tolerance, f_relative_tolerance,
+                  x_tolerance, stopping_condition, max_iterations, clip_before):
+  """Performs a constrained line search clamped to bounds in given direction."""
+  inactive = (bfgs_state.failed | bfgs_state.converged)
+
+  def _do_line_search_step():
+    """Do unconstrained line search."""
+    nonlocal search_direction
+    # Truncation bounds
+    lower_term = tf.math.divide_no_nan(
+      bfgs_state.lower_bounds - bfgs_state.position,
+      search_direction)
+    upper_term = tf.math.divide_no_nan(
+      bfgs_state.upper_bounds - bfgs_state.position,
+      search_direction)
+    bounds_clip = (
+      tf.reduce_min(
+        tf.where(
+          (search_direction > 0),
+          upper_term,
+          tf.where(
+            (search_direction < 0),
+            lower_term,
+            float('inf'))),
+        axis=-1)
+    )
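+    # `bounds_clip` is, per batch member, the largest step size for which
+    # `position + step * search_direction` stays within the bounds.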
+
+    search_direction *= tf.where(
+      clip_before,
+      tf.math.minimum(1., bounds_clip),
+      1.)[..., tf.newaxis]
+
+    def _fn_with_report(x):
+      return value_and_gradients_function(
+        x, inactive, bfgs_state.objective_value, bfgs_state.objective_gradient)
+
+    ls_result = _hz_line_search(
+      bfgs_state.position, bfgs_state.objective_value,
+      bfgs_state.objective_gradient,
+      _fn_with_report, search_direction,
+      max_iterations, inactive)
+
+    # Truncate to bounds after search
+    step = (
+      tf.math.minimum(
+        bounds_clip,
+        ls_result.left.x
+      )
+    )
+
+    # For inactive batch members `left.x` is zero. However, their
+    # `search_direction` might also be undefined, so we can't rely on
+    # multiplication by zero to produce a `position_delta` of zero.
+    next_position = tf.where(
+      inactive[..., tf.newaxis],
+      bfgs_state.position,
+      bfgs_state.position + step[..., tf.newaxis] * search_direction)
+
+    # If the movement isn't clipped, we can use the final results of the
+    # line search.
+    reevaluated = (tf.reduce_any(ls_result.left.x > bounds_clip))
+    next_objective, next_gradient = (
+      tf.cond(
+        pred=reevaluated,
+        true_fn=lambda: value_and_gradients_function(
+          next_position, inactive, bfgs_state.objective_value,
+          bfgs_state.objective_gradient),
+        false_fn=lambda: (ls_result.left.f,
+                  ls_result.left.full_gradient)
+      )
+    )
+
+    new_failed = (bfgs_state.failed | (
+      ~inactive & ~bfgs_state.converged & ~ls_result.converged))
+    new_num_iterations = bfgs_state.num_iterations + 1
+    new_num_objective_evaluations = tf.cond(
+      pred=reevaluated,
+      true_fn=lambda: (
+        bfgs_state.num_objective_evaluations + ls_result.func_evals + 1),
+      false_fn=lambda: (
+        bfgs_state.num_objective_evaluations + ls_result.func_evals))
+
+    # Hint the compiler that the properties' shape will not change
+    if not tf.executing_eagerly():
+      new_failed = tf.ensure_shape(new_failed, bfgs_state.failed.shape)
+      new_num_iterations = tf.ensure_shape(
+        new_num_iterations, bfgs_state.num_iterations.shape)
+      new_num_objective_evaluations = tf.ensure_shape(
+        new_num_objective_evaluations, bfgs_state.num_objective_evaluations.shape)
+
+    state_after_ls = bfgs_utils.update_fields(
+      state=bfgs_state,
+      failed=new_failed,
+      num_iterations=new_num_iterations,
+      num_objective_evaluations=new_num_objective_evaluations)
+
+    return state_after_ls, next_position, next_objective, next_gradient
+
+  # NOTE: It's important that the default (`pred` False) branch matches
+  # the shapes of the `pred` True branch, for graph-mode purposes.
+  state_after_ls, next_position, next_objective, next_gradient = (
+    tf.cond(
+      pred=tf.math.logical_not(tf.reduce_all(inactive)),
+      true_fn=_do_line_search_step,
+      false_fn=lambda: (bfgs_state,
+                bfgs_state.position,
+                bfgs_state.objective_value,
+                bfgs_state.objective_gradient)
+    ))
+
+  def _do_update_position():
+    """Update the position"""
+    return _update_position(
+      state_after_ls,
+      next_position,
+      next_objective,
+      next_gradient,
+      grad_tolerance,
+      f_relative_tolerance,
+      x_tolerance,
+      inactive)
+
+  return ps.cond(
+    (stopping_condition(bfgs_state.converged, bfgs_state.failed) |
+     tf.reduce_all(inactive)),
+    true_fn=lambda: state_after_ls,
+    false_fn=_do_update_position)
+
+
+def _hz_line_search(starting_position, starting_value, starting_gradient,
+          value_and_gradients_function, search_direction, max_iterations,
+          inactive):
+  """Performs Hager Zhang line search via `bfgs_utils.linesearch.hager_zhang`."""
+  line_search_value_grad_func = bfgs_utils._restrict_along_direction(
+    value_and_gradients_function, starting_position, search_direction)
+  derivative_at_start_pt = tf.reduce_sum(
+    starting_gradient * search_direction, axis=-1)
+  val_0 = bfgs_utils.ValueAndGradient(
+    x=bfgs_utils._broadcast(0, starting_position),
+    f=starting_value,
+    df=derivative_at_start_pt,
+    full_gradient=starting_gradient)
+  return bfgs_utils.linesearch.hager_zhang(
+    line_search_value_grad_func,
+    initial_step_size=bfgs_utils._broadcast(1, starting_position),
+    value_at_zero=val_0,
+    converged=inactive,
+    max_iterations=max_iterations)
+
+
+def _update_position(state,
+           next_position,
+           next_objective,
+           next_gradient,
+           grad_tolerance,
+           f_relative_tolerance,
+           x_tolerance,
+           inactive):
+  """Updates the state advancing its position by a given position_delta."""
+  state = bfgs_utils.terminate_if_not_finite(
+    state, next_objective, next_gradient)
+
+  converged = (~inactive & ~state.failed &
+         _check_convergence_bounded(state.position,
+                      next_position,
+                      state.objective_value,
+                      next_objective,
+                      next_gradient,
+                      grad_tolerance,
+                      f_relative_tolerance,
+                      x_tolerance,
+                      state.lower_bounds,
+                      state.upper_bounds))
+  new_converged = (state.converged | converged)
+
+  if not tf.executing_eagerly():
+    # Hint the compiler that the properties have not changed shape
+    new_converged = tf.ensure_shape(new_converged, state.converged.shape)
+    next_position = tf.ensure_shape(next_position, state.position.shape)
+    next_objective = tf.ensure_shape(
+      next_objective, state.objective_value.shape)
+    next_gradient = tf.ensure_shape(
+      next_gradient, state.objective_gradient.shape)
+
+  return bfgs_utils.update_fields(
+    state,
+    converged=new_converged,
+    position=next_position,
+    objective_value=next_objective,
+    objective_gradient=next_gradient)
+
+
+def _erase_history(bfgs_state, where_erase):
+  """Erases the BFGS correction pairs for the specified batches.
+
+  This function will zero `gradient_deltas`, `position_deltas`, and `history`.
 
+  Args:
+    bfgs_state: an `LBfgsBOptimizerResults` to modify.
+    where_erase: a boolean `Tensor` with shape matching the batch dimensions,
+      with `True` for the batches whose history should be erased.
   Returns:
-    A real `Tensor` of the same shape as the `state.position`. The direction
-    along which to perform line search.
+    Modified `bfgs_state`.
   """
-  # The number of correction pairs that have been collected so far.
-  num_elements = ps.minimum(
-      state.num_iterations,  # TODO(b/162733947): Change loop state -> closure.
-      ps.shape(state.position_deltas)[0])
-
-  def _two_loop_algorithm():
-    """L-BFGS two-loop algorithm."""
-    # Correction pairs are always appended to the end, so only the latest
-    # `num_elements` vectors have valid position/gradient deltas. Vectors
-    # that haven't been computed yet are zero.
-    position_deltas = state.position_deltas
-    gradient_deltas = state.gradient_deltas
-
-    # Pre-compute all `inv_rho[i]`s.
-    inv_rhos = tf.reduce_sum(
-        gradient_deltas * position_deltas, axis=-1)
-
-    def first_loop(acc, args):
-      _, q_direction = acc
-      position_delta, gradient_delta, inv_rho = args
-      alpha = tf.math.divide_no_nan(
-          tf.reduce_sum(position_delta * q_direction, axis=-1), inv_rho)
-      direction_delta = alpha[..., tf.newaxis] * gradient_delta
-      return (alpha, q_direction - direction_delta)
-
-    # Run first loop body computing and collecting `alpha[i]`s, while also
-    # computing the updated `q_direction` at each step.
-    zero = tf.zeros_like(inv_rhos[-num_elements])
-    alphas, q_directions = tf.scan(
-        first_loop, [position_deltas, gradient_deltas, inv_rhos],
-        initializer=(zero, state.objective_gradient), reverse=True)
-
-    # We use `H^0_k = gamma_k * I` as an estimate for the initial inverse
-    # hessian for the k-th iteration; then `r_direction = H^0_k * q_direction`.
-    gamma_k = inv_rhos[-1] / tf.reduce_sum(
-        gradient_deltas[-1] * gradient_deltas[-1], axis=-1)
-    r_direction = gamma_k[..., tf.newaxis] * q_directions[-num_elements]
-
-    def second_loop(r_direction, args):
-      alpha, position_delta, gradient_delta, inv_rho = args
-      beta = tf.math.divide_no_nan(
-          tf.reduce_sum(gradient_delta * r_direction, axis=-1), inv_rho)
-      direction_delta = (alpha - beta)[..., tf.newaxis] * position_delta
-      return r_direction + direction_delta
-
-    # Finally, run second loop body computing the updated `r_direction` at each
-    # step.
-    r_directions = tf.scan(
-        second_loop, [alphas, position_deltas, gradient_deltas, inv_rhos],
-        initializer=r_direction)
-    return -r_directions[-1]
-
-  return ps.cond(ps.equal(num_elements, 0),
-                 lambda: -state.objective_gradient,
-                 _two_loop_algorithm)
+  # Calculate new values
+  new_gradient_deltas = (tf.where(
+    where_erase[..., tf.newaxis, tf.newaxis],
+    0.,
+    bfgs_state.gradient_deltas))
+  new_position_deltas = (tf.where(
+    where_erase[..., tf.newaxis, tf.newaxis],
+    0.,
+    bfgs_state.position_deltas))
+  new_history = tf.where(where_erase, 0, bfgs_state.history)
+  # Assure the compiler that the shape of things has not changed
+  if not tf.executing_eagerly():
+    new_gradient_deltas = (
+      tf.ensure_shape(
+        new_gradient_deltas,
+        bfgs_state.gradient_deltas.shape))
+    new_position_deltas = (
+      tf.ensure_shape(
+        new_position_deltas,
+        bfgs_state.position_deltas.shape))
+    new_history = (
+      tf.ensure_shape(
+        new_history,
+        bfgs_state.history.shape))
+  # Update and return
+  return bfgs_utils.update_fields(
+    bfgs_state,
+    gradient_deltas=new_gradient_deltas,
+    position_deltas=new_position_deltas,
+    history=new_history)
+
+
+def _check_convergence_bounded(current_position,
+                 next_position,
+                 current_objective,
+                 next_objective,
+                 next_gradient,
+                 grad_tolerance,
+                 f_relative_tolerance,
+                 x_tolerance,
+                 lower_bounds,
+                 upper_bounds):
+  """Checks if the algorithm satisfies the convergence criteria."""
+  # NOTE: The original algorithm (as described in [2]) only considers halting on
+  # the projected gradient condition. However, `x_converged` and `f_converged`
+  # do not seem to pose a problem when refreshing is correctly accounted for
+  # (so that the optimization does not halt upon a refresh), and the default
+  # values of `0` for `f_relative_tolerance` and `x_tolerance` further
+  # strengthen these conditions.
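+  # The projected-gradient criterion below declares convergence when the norm
+  # of `clip(x - g, lower_bounds, upper_bounds) - x` (the projected gradient
+  # step) falls within `grad_tolerance`, playing the role of `pgtol` in the
+  # reference Fortran implementation.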
+  proj_grad_converged = bfgs_utils.norm(
+    tf.clip_by_value(
+      next_position - next_gradient,
+      lower_bounds,
+      upper_bounds) - next_position, dims=1) <= grad_tolerance
+  x_converged = bfgs_utils.norm(
+    next_position - current_position, dims=1) <= x_tolerance
+  f_ref = tf.math.maximum(1., tf.math.maximum(
+    tf.math.abs(next_objective),
+    tf.math.abs(current_objective)))
+  f_converged = (tf.math.abs(next_objective - current_objective)
+           <= f_ref*f_relative_tolerance)
+  return proj_grad_converged | x_converged | f_converged
+
+
+def _get_initial_state(value_and_gradients_function,
+             initial_position,
+             lower_bounds,
+             upper_bounds,
+             num_correction_pairs,
+             tolerance):
+  """Create LBfgsBOptimizerResults with initial state of search procedure."""
+  init_args = get_initial_state_args(value_and_gradients_function,
+                     initial_position,
+                     tolerance)
+  empty_queue = _make_empty_queue_for(num_correction_pairs, initial_position)
+  zero_history = tf.zeros(ps.shape(initial_position)[:-1], dtype=tf.int32)
+  init_args.update(
+    lower_bounds=lower_bounds,
+    upper_bounds=upper_bounds,
+    position_deltas=empty_queue,
+    gradient_deltas=empty_queue,
+    history=zero_history)
+  return LBfgsBOptimizerResults(**init_args)
+
+
+def get_initial_state_args(value_and_gradients_function,
+               initial_position,
+               grad_tolerance,
+               control_inputs=None):
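+  """Returns a dict with the initial-state fields of the search procedure.
+
+  Evaluates `value_and_gradients_function` at `initial_position` (using this
+  module's four-argument reporting signature) and performs an initial
+  gradient-norm convergence check.
+  """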
+  none_finished = tf.broadcast_to(False, ps.shape(initial_position)[:-1])
+  zero_values = bfgs_utils._broadcast(0., initial_position)
+  zero_gradients = tf.zeros_like(initial_position)
+  if control_inputs:
+    with tf.control_dependencies(control_inputs):
+      f0, df0 = value_and_gradients_function(
+        initial_position, none_finished, zero_values, zero_gradients)
+  else:
+    f0, df0 = value_and_gradients_function(
+      initial_position, none_finished, zero_values, zero_gradients)
+  # This is a gradient-based convergence check.  We only do it for finite
+  # objective values because we assume the gradient reported at a position with
+  # a non-finite objective value is untrustworthy.  The main loop handles
+  # non-finite objective values itself (see `terminate_if_not_finite`).
+  init_converged = (tf.math.is_finite(f0) &
+            (bfgs_utils.norm(df0, dims=1) < grad_tolerance))
+  return dict(
+    converged=init_converged,
+    failed=tf.zeros_like(init_converged),  # i.e. False.
+    num_iterations=tf.convert_to_tensor(0),
+    num_objective_evaluations=tf.convert_to_tensor(1),
+    position=initial_position,
+    objective_value=f0,
+    objective_gradient=df0)
+
+
+def _cauchy_init_m(state, theta, num_correction_pairs):
+  """Initialize the M matrix for a `_CauchyMinimizationResult` state."""
+  def build_m():
+    """Construct and invert the M block matrix."""
+    # All of the below block matrices have dimensions [..., 2m, 2m]
+    #  where `...` denotes the batch dimensions, and `m` the number
+    #  of correction pairs.
+    # New elements are pushed in "from the back", so we want to index
+    #  position_deltas and gradient_deltas with negative indices.
+    # Index 0 of `position_deltas` and `gradient_deltas` is oldest, and index -1
+    #  is most recent, so the below respects the indexing of the article.
+
+    # 1. calculate inner product (s_i.y_j) in shape [..., m, m]
+    l = tf.einsum(
+      "...mi,...ui->...mu",
+      state.position_deltas,
+      state.gradient_deltas)
+    # 2. Zero out diagonal and upper triangular
+    l_shape = ps.shape(l)
+    l = tf.linalg.set_diag(
+      tf.linalg.band_part(l, -1, 0),
+      tf.zeros([l_shape[0], l_shape[-1]]))
+    l_transpose = tf.linalg.matrix_transpose(l)
+    s_t_s = tf.einsum(
+      '...mi,...ni->...mn',
+      state.position_deltas,
+      state.position_deltas)
+    d = tf.linalg.diag(
+      tf.einsum(
+        '...mi,...mi->...m',
+        state.position_deltas,
+        state.gradient_deltas))
+
+    # Assemble into full matrix
+    # shape [b, 2m, 2m]
+    m_inv = tf.concat(
+      [
+        tf.concat([-d, l_transpose], axis=-1),
+        tf.concat(
+          [l, theta[..., tf.newaxis, tf.newaxis] * s_t_s], axis=-1)
+      ], axis=-2)
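+    # i.e., the middle matrix of the compact representation (cf. [2]):
+    #   M^{-1} = [[ -D     L^T        ]
+    #             [  L   theta*S^T.S  ]]
+    # before accounting for the varying history below.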
+
+    # Adjust for varying history:
+    # Permute columns so that the `2*(m-h)` columns corresponding to missing
+    # history come first and the `2*h` valid columns come last, then apply
+    # the same permutation to the rows.
+    idx = tf.concat(
+      [tf.ragged.range(num_correction_pairs-state.history),
+       tf.ragged.range(num_correction_pairs, 2 *
+              num_correction_pairs-state.history),
+       tf.ragged.range(num_correction_pairs -
+              state.history, num_correction_pairs),
+       tf.ragged.range(
+              2*num_correction_pairs-state.history, 2*num_correction_pairs)],
+      axis=-1).to_tensor()
+    m_inv = tf.gather(
+      m_inv,
+      idx,
+      axis=-1,
+      batch_dims=1)
+    m_inv = tf.gather(
+      m_inv,
+      idx,
+      axis=-2,
+      batch_dims=1)
+
+    # Insert an identity in the empty block
+    identity_mask = (
+      (tf.range(ps.shape(m_inv)[-1])[tf.newaxis, ...] <
+       2*(num_correction_pairs - state.history[..., tf.newaxis]))[..., tf.newaxis])
+
+    m_inv = tf.where(
+      identity_mask,
+      tf.eye(ps.shape(m_inv)[-1], batch_shape=ps.shape(m_inv)[:-2]),
+      m_inv)
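+    # The identity rows keep `m_inv` invertible while fewer than
+    # `num_correction_pairs` corrections are stored; they are zeroed out again
+    # after inversion so that M keeps the documented [[0, 0], [0, M_h]] block
+    # structure.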
+
+    # If M is not invertible, refresh the memory
+    # TODO: Checking the determinant is likely overkill?
+    refresh = (tf.linalg.det(m_inv) == 0)
+
+    # Invert where invertible; 0s otherwise
+    m = tf.where(
+      refresh[..., tf.newaxis, tf.newaxis],
+      tf.zeros_like(m_inv),
+      tf.linalg.inv(
+        tf.where(
+          refresh[..., tf.newaxis, tf.newaxis],
+          tf.eye(ps.shape(m_inv)[-1],
+               batch_shape=ps.shape(m_inv)[:-2]),
+          m_inv)))
+
+    # Re-zero the introduced identity blocks
+    m = tf.where(
+      identity_mask,
+      tf.zeros_like(m),
+      m)
+
+    return m, refresh
+
+  # M is 0 for the first iterations
+  # We abuse `state.history` to extract the batch shape
+  m_shape = ps.concat([ps.shape(state.history),
+             [num_correction_pairs*2, num_correction_pairs*2]], axis=0)
+  return tf.cond(
+    state.num_iterations < 1,
+    lambda: (tf.zeros(m_shape),
+         tf.broadcast_to(False, ps.shape(state.history))),
+    build_m)
 
 
 def _make_empty_queue_for(k, element):
@@ -402,18 +1523,16 @@ def _make_empty_queue_for(k, element):
 
   ```python
     element = tf.constant([[0., 1., 2., 3., 4.],
-                           [5., 6., 7., 8., 9.]])
+               [5., 6., 7., 8., 9.]])
 
     # A queue capable of holding 3 elements.
     _make_empty_queue_for(3, element)
-    # => [[[ 0.,  0.,  0.,  0.,  0.],
-    #      [ 0.,  0.,  0.,  0.,  0.]],
-    #
-    #     [[ 0.,  0.,  0.,  0.,  0.],
-    #      [ 0.,  0.,  0.,  0.,  0.]],
-    #
-    #     [[ 0.,  0.,  0.,  0.,  0.],
-    #      [ 0.,  0.,  0.,  0.,  0.]]]
+    # => [[[0., 0., 0., 0., 0.],
+    #      [0., 0., 0., 0., 0.],
+    #      [0., 0., 0., 0., 0.]],
+    #     [[0., 0., 0., 0., 0.],
+    #      [0., 0., 0., 0., 0.],
+    #      [0., 0., 0., 0., 0.]]]
   ```
 
   Args:
@@ -421,17 +1540,18 @@ def _make_empty_queue_for(k, element):
     element: A `tf.Tensor`, only its shape and dtype information are relevant.
 
   Returns:
-    A zero-filed `tf.Tensor` of shape `(k,) + tf.shape(element)` and same dtype
-    as `element`.
+    A zero-filled `tf.Tensor` of shape `(s[:-1], k, s[-1])`, where
+    `s = tf.shape(element)`, and same dtype as `element`.
   """
-  queue_shape = ps.concat([[k], ps.shape(element)], axis=0)
+  queue_shape = ps.concat(
+    [ps.shape(element)[:-1], [k], ps.shape(element)[-1:]], axis=0)
   return tf.zeros(queue_shape, dtype=dtype_util.base_dtype(element.dtype))
 
 
 def _queue_push(queue, should_update, new_vecs):
   """Conditionally push new vectors into a batch of first-in-first-out queues.
 
-  The `queue` of shape `[k, ..., n]` can be thought of as a batch of queues,
+  The `queue` of shape `[..., k, n]` can be thought of as a batch of queues,
   each holding `k` n-D vectors; while `new_vecs` of shape `[..., n]` is a
   fresh new batch of n-D vectors. The `should_update` batch of Boolean scalars,
   i.e. shape `[...]`, indicates batch members whose corresponding n-D vector in
@@ -439,54 +1559,50 @@ def _queue_push(queue, should_update, new_vecs):
   corresponding n-D vector from the front. Batch members in `new_vecs` for
   which `should_update` is False are ignored.
 
-  Note: the choice of placing `k` at the dimension 0 of the queue is
-  constrained by the L-BFGS two-loop algorithm above. The algorithm uses
-  tf.scan to iterate over the `k` correction pairs simulatneously across all
-  batches, and tf.scan itself can only iterate over dimension 0.
+  Note: whereas `lbfgs.py` places the `k` at dimension 0 due to constraints
+  of `tf.scan`, those do not apply here, and in fact it is more advantageous
+  to have the batch dimensions before `k`.
 
   For example:
 
   ```python
-    k, b, n = (3, 2, 5)
-    queue = tf.reshape(tf.range(30), (k, b, n))
+    b, k, n = (2, 3, 5)
+    queue = tf.reshape(tf.range(30), (b, k, n))
     # => [[[ 0,  1,  2,  3,  4],
-    #      [ 5,  6,  7,  8,  9]],
-    #
-    #     [[10, 11, 12, 13, 14],
-    #      [15, 16, 17, 18, 19]],
-    #
-    #     [[20, 21, 22, 23, 24],
-    #      [25, 26, 27, 28, 29]]]
+    #      [ 5,  6,  7,  8,  9],
+    #      [10, 11, 12, 13, 14]],
+    #    [[15, 16, 17, 18, 19],
+    #     [20, 21, 22, 23, 24],
+    #     [25, 26, 27, 28, 29]]]
 
     element = tf.reshape(tf.range(30, 40), (b, n))
     # => [[30, 31, 32, 33, 34],
-          [35, 36, 37, 38, 39]]
+      [35, 36, 37, 38, 39]]
 
     should_update = tf.constant([True, False])  # Shape: (b,)
 
-    _queue_add(should_update, queue, element)
-    # => [[[10, 11, 12, 13, 14],
-    #      [ 5,  6,  7,  8,  9]],
-    #
-    #     [[20, 21, 22, 23, 24],
-    #      [15, 16, 17, 18, 19]],
-    #
-    #     [[30, 31, 32, 33, 34],
+    _queue_push(queue, should_update, element)
+    # => [[[ 5,  6,  7,  8,  9],
+    #      [10, 11, 12, 13, 14],
+    #      [30, 31, 32, 33, 34]],
+    #     [[15, 16, 17, 18, 19],
+    #      [20, 21, 22, 23, 24],
     #      [25, 26, 27, 28, 29]]]
   ```
 
   Args:
-    queue: A `tf.Tensor` of shape `[k, ..., n]`; a batch of queues each with
-      `k` n-D vectors.
+    queue: A `tf.Tensor` of shape `[..., k, n]`; a batch of queues each with
+      `k` n-D vectors.
     should_update: A Boolean `tf.Tensor` of shape `[...]` indicating batch
-      members where new vectors should be added to their queues.
+      members where new vectors should be added to their queues.
     new_vecs: A `tf.Tensor` of shape `[..., n]`; a batch of n-D vectors to add
-      at the end of their respective queues, pushing out the first element from
-      each.
+      at the end of their respective queues, pushing out the first element from
+      each.
 
   Returns:
-    A new `tf.Tensor` of shape `[k, ..., n]`.
+    A new `tf.Tensor` of shape `[..., k, n]`.
   """
-  new_queue = tf.concat([queue[1:], [new_vecs]], axis=0)
+  new_queue = tf.concat(
+    [queue[..., 1:, :], new_vecs[..., tf.newaxis, :]], axis=-2)
   return tf.where(
-      should_update[tf.newaxis, ..., tf.newaxis], new_queue, queue)
+    should_update[..., tf.newaxis, tf.newaxis], new_queue, queue)

From 64d30a428ea3d81b08c5a23faad2795372411084 Mon Sep 17 00:00:00 2001
From: mikeevmm <miguelmurca@gmail.com>
Date: Fri, 25 Jun 2021 14:55:57 +0100
Subject: [PATCH 4/4] wip: test file base

---
 .../python/optimizer/lbfgsb_test.py           | 573 ++++++++++++++++++
 1 file changed, 573 insertions(+)
 create mode 100644 tensorflow_probability/python/optimizer/lbfgsb_test.py

diff --git a/tensorflow_probability/python/optimizer/lbfgsb_test.py b/tensorflow_probability/python/optimizer/lbfgsb_test.py
new file mode 100644
index 0000000000..03dd8fdf9c
--- /dev/null
+++ b/tensorflow_probability/python/optimizer/lbfgsb_test.py
@@ -0,0 +1,573 @@
+# Copyright 2018 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for the constrained L-BFGS-B optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+from absl.testing import parameterized
+import numpy as np
+from scipy.stats import special_ortho_group
+
+import tensorflow.compat.v1 as tf1
+import tensorflow.compat.v2 as tf
+import tensorflow_probability as tfp
+
+from tensorflow_probability.python.internal import test_util
+
+
+def _make_val_and_grad_fn(value_fn):
+  @functools.wraps(value_fn)
+  def val_and_grad(x):
+    return tfp.math.value_and_gradient(value_fn, x)
+  return val_and_grad
+
+
+def _norm(x):
+  return np.linalg.norm(x, np.inf)
+
+
+@test_util.test_all_tf_execution_regimes
+class LBfgsTest(test_util.TestCase):
+  """Tests for LBFGSB optimization algorithm."""
+
+  def test_quadratic_bowl_2d(self):
+    """Can minimize a two dimensional quadratic function when constrained."""
+    minimum = np.array([1.0, 1.0])
+    scales = np.array([2.0, 3.0])
+    lower_bounds = np.array([0., 2.])
+    upper_bounds = np.array([2., 5.])
+    expected = np.array([1.0, 2.0])
+
+    @_make_val_and_grad_fn
+    def quadratic(x):
+      return tf.reduce_sum(scales * tf.math.squared_difference(x, minimum))
+
+    start = tf.constant([0.6, 0.8])
+    results = self.evaluate(tfp.optimizer.lbfgsb_minimize(
+        quadratic, initial_position=start, tolerance=1e-8,
+        lower_bounds=lower_bounds, upper_bounds=upper_bounds))
+    self.assertTrue(results.converged)
+    self.assertLessEqual(_norm(results.objective_gradient), 1e-8)
+    self.assertArrayNear(results.position, expected, 1e-5)
+
+  # TODO:
+  def test_high_dims_quadratic_bowl_trivial(self):
+    """Can minimize a high-dimensional trivial bowl (sphere)."""
+    ndims = 100
+    minimum = np.ones([ndims], dtype='float64')
+    scales = np.ones([ndims], dtype='float64')
+
+    @_make_val_and_grad_fn
+    def quadratic(x):
+      return tf.reduce_sum(scales * tf.math.squared_difference(x, minimum))
+
+    start = np.zeros([ndims], dtype='float64')
+    results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+        quadratic, initial_position=start, tolerance=1e-8))
+    self.assertTrue(results.converged)
+    self.assertEqual(results.num_iterations, 1)  # Solved by first line search.
+    self.assertLessEqual(_norm(results.objective_gradient), 1e-8)
+    self.assertArrayNear(results.position, minimum, 1e-5)
+
+  # TODO:
+  def test_quadratic_bowl_40d(self):
+    """Can minimize a high-dimensional quadratic function."""
+    dim = 40
+    np.random.seed(14159)
+    minimum = np.random.randn(dim)
+    scales = np.exp(np.random.randn(dim))
+
+    @_make_val_and_grad_fn
+    def quadratic(x):
+      return tf.reduce_sum(scales * tf.math.squared_difference(x, minimum))
+
+    start = tf.ones_like(minimum)
+    results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+        quadratic, initial_position=start, tolerance=1e-8))
+    self.assertTrue(results.converged)
+    self.assertLessEqual(_norm(results.objective_gradient), 1e-8)
+    self.assertArrayNear(results.position, minimum, 1e-5)
+
+  # TODO:
+  def test_quadratic_with_skew(self):
+    """Can minimize a general quadratic function."""
+    dim = 50
+    np.random.seed(26535)
+    minimum = np.random.randn(dim)
+    principal_values = np.diag(np.exp(np.random.randn(dim)))
+    rotation = special_ortho_group.rvs(dim)
+    hessian = np.dot(np.transpose(rotation), np.dot(principal_values, rotation))
+
+    @_make_val_and_grad_fn
+    def quadratic(x):
+      y = x - minimum
+      yp = tf.tensordot(hessian, y, axes=[1, 0])
+      return tf.reduce_sum(y * yp) / 2
+
+    start = tf.ones_like(minimum)
+    results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+        quadratic, initial_position=start, tolerance=1e-8))
+    self.assertTrue(results.converged)
+    self.assertLessEqual(_norm(results.objective_gradient), 1e-8)
+    self.assertArrayNear(results.position, minimum, 1e-5)
+
+  # TODO:
+  def test_quadratic_with_strong_skew(self):
+    """Can minimize a strongly skewed quadratic function."""
+    np.random.seed(89793)
+    minimum = np.random.randn(3)
+    principal_values = np.diag(np.array([0.1, 2.0, 50.0]))
+    rotation = special_ortho_group.rvs(3)
+    hessian = np.dot(np.transpose(rotation), np.dot(principal_values, rotation))
+
+    @_make_val_and_grad_fn
+    def quadratic(x):
+      y = x - minimum
+      yp = tf.tensordot(hessian, y, axes=[1, 0])
+      return tf.reduce_sum(y * yp) / 2
+
+    start = tf.ones_like(minimum)
+    results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+        quadratic, initial_position=start, tolerance=1e-8))
+    self.assertTrue(results.converged)
+    self.assertLessEqual(_norm(results.objective_gradient), 1e-8)
+    self.assertArrayNear(results.position, minimum, 1e-5)
+
+  # TODO:
+  def test_rosenbrock_2d(self):
+    """Tests L-BFGS on the Rosenbrock function.
+
+    The Rosenbrock function is a standard optimization test case. In two
+    dimensions, the function is (a, b > 0):
+      f(x, y) = (a - x)^2 + b (y - x^2)^2
+    The function has a global minimum at (a, a^2). This minimum lies inside
+    a parabolic valley (y = x^2).
+    """
+    def rosenbrock(coord):
+      """The Rosenbrock function in two dimensions with a=1, b=100.
+
+      Args:
+        coord: A Tensor of shape [2]. The coordinate of the point to evaluate
+          the function at.
+
+      Returns:
+        fv: A scalar tensor containing the value of the Rosenbrock function at
+          the supplied point.
+        dfx: Scalar tensor. The derivative of the function with respect to x.
+        dfy: Scalar tensor. The derivative of the function with respect to y.
+      """
+      x, y = coord[0], coord[1]
+      fv = (1 - x)**2 + 100 * (y - x**2)**2
+      dfx = 2 * (x - 1) + 400 * x * (x**2 - y)
+      dfy = 200 * (y - x**2)
+      return fv, tf.stack([dfx, dfy])
+
+    start = tf.constant([-1.2, 1.0])
+    results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+        rosenbrock, initial_position=start, tolerance=1e-5))
+    self.assertTrue(results.converged)
+    self.assertLessEqual(_norm(results.objective_gradient), 1e-5)
+    self.assertArrayNear(results.position, np.array([1.0, 1.0]), 1e-5)
+
+  # TODO:
+  def test_himmelblau(self):
+    """Tests minimization on the Himmelblau's function.
+
+    Himmelblau's function is a standard optimization test case. The function is
+    given by:
+
+      f(x, y) = (x^2 + y - 11)^2 + (x + y^2 - 7)^2
+
+    The function has four minima located at (3, 2), (-2.805118, 3.131312),
+    (-3.779310, -3.283186), (3.584428, -1.848126).
+
+    All these minima may be reached from appropriate starting points. To keep
+    the runtime of this test small, here we only find the first two minima.
+    However, all four can be easily found in `test_himmelblau_batch_all` below
+    with the help of batching.
+    """
+    @_make_val_and_grad_fn
+    def himmelblau(coord):
+      x, y = coord[0], coord[1]
+      return (x * x + y - 11) ** 2 + (x + y * y - 7) ** 2
+
+    starts_and_targets = [
+        # Start Point, Target Minimum, Num evaluations expected.
+        [(1, 1), (3, 2), 31],
+        [(-2, 2), (-2.805118, 3.131312), 17],
+    ]
+    dtype = 'float64'
+    for start, expected_minima, expected_evals in starts_and_targets:
+      start = tf.constant(start, dtype=dtype)
+      results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+          himmelblau, initial_position=start, tolerance=1e-8))
+      self.assertTrue(results.converged)
+      self.assertArrayNear(results.position,
+                           np.array(expected_minima, dtype=dtype),
+                           1e-5)
+      self.assertEqual(results.num_objective_evaluations, expected_evals)
+
+  # TODO:
+  def test_himmelblau_batch_all(self):
+    @_make_val_and_grad_fn
+    def himmelblau(coord):
+      x, y = coord[..., 0], coord[..., 1]
+      return (x * x + y - 11) ** 2 + (x + y * y - 7) ** 2
+
+    dtype = 'float64'
+    starts = tf.constant([[1, 1],
+                          [-2, 2],
+                          [-1, -1],
+                          [1, -2]], dtype=dtype)
+    expected_minima = np.array([[3, 2],
+                                [-2.805118, 3.131312],
+                                [-3.779310, -3.283186],
+                                [3.584428, -1.848126]], dtype=dtype)
+    batch_results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+        himmelblau, initial_position=starts,
+        stopping_condition=tfp.optimizer.converged_all, tolerance=1e-8))
+
+    self.assertFalse(np.any(batch_results.failed))  # None have failed.
+    self.assertTrue(np.all(batch_results.converged))  # All converged.
+
+    # All converged points are near expected minima.
+    for actual, expected in zip(batch_results.position, expected_minima):
+      self.assertArrayNear(actual, expected, 1e-5)
+    self.assertEqual(batch_results.num_objective_evaluations, 36)
+
+  # TODO:
+  def test_himmelblau_batch_any(self):
+    @_make_val_and_grad_fn
+    def himmelblau(coord):
+      x, y = coord[..., 0], coord[..., 1]
+      return (x * x + y - 11) ** 2 + (x + y * y - 7) ** 2
+
+    dtype = 'float64'
+    starts = tf.constant([[1, 1],
+                          [-2, 2],
+                          [-1, -1],
+                          [1, -2]], dtype=dtype)
+    expected_minima = np.array([[3, 2],
+                                [-2.805118, 3.131312],
+                                [-3.779310, -3.283186],
+                                [3.584428, -1.848126]], dtype=dtype)
+
+    # Run with `converged_any` stopping condition, to stop as soon as any of
+    # the batch members have converged.
+    batch_results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+        himmelblau, initial_position=starts,
+        stopping_condition=tfp.optimizer.converged_any, tolerance=1e-8))
+
+    self.assertFalse(np.any(batch_results.failed))  # None have failed.
+    self.assertTrue(np.any(batch_results.converged))  # At least one converged.
+    self.assertFalse(np.all(batch_results.converged))  # But not all did.
+
+    # Converged points are near expected minima.
+    for actual, expected in zip(batch_results.position[batch_results.converged],
+                                expected_minima[batch_results.converged]):
+      self.assertArrayNear(actual, expected, 1e-5)
+    self.assertEqual(batch_results.num_objective_evaluations, 28)
+
+  # TODO:
+  def test_himmelblau_batch_any_resume_then_all(self):
+    @_make_val_and_grad_fn
+    def himmelblau(coord):
+      x, y = coord[..., 0], coord[..., 1]
+      return (x * x + y - 11) ** 2 + (x + y * y - 7) ** 2
+
+    dtype = 'float64'
+    starts = tf.constant([[1, 1],
+                          [-2, 2],
+                          [-1, -1],
+                          [1, -2]], dtype=dtype)
+    expected_minima = np.array([[3, 2],
+                                [-2.805118, 3.131312],
+                                [-3.779310, -3.283186],
+                                [3.584428, -1.848126]], dtype=dtype)
+
+    # Run with `converged_any` stopping condition, to stop as soon as any of
+    # the batch members have converged.
+    raw_batch_results = tfp.optimizer.lbfgs_minimize(
+        himmelblau, initial_position=starts,
+        stopping_condition=tfp.optimizer.converged_any, tolerance=1e-8)
+    batch_results = self.evaluate(raw_batch_results)
+
+    self.assertFalse(np.any(batch_results.failed))  # None have failed.
+    self.assertTrue(np.any(batch_results.converged))  # At least one converged.
+    self.assertFalse(np.all(batch_results.converged))  # But not all did.
+
+    # Converged points are near expected minima.
+    for actual, expected in zip(batch_results.position[batch_results.converged],
+                                expected_minima[batch_results.converged]):
+      self.assertArrayNear(actual, expected, 1e-5)
+    self.assertEqual(batch_results.num_objective_evaluations, 28)
+
+    # Run with `converged_all`, starting from previous state.
+    batch_results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+        himmelblau, initial_position=None,
+        previous_optimizer_results=raw_batch_results,
+        stopping_condition=tfp.optimizer.converged_all, tolerance=1e-8))
+
+    # All converged points are near expected minima and the number of
+    # evaluations is as if we never stopped.
+    for actual, expected in zip(batch_results.position, expected_minima):
+      self.assertArrayNear(actual, expected, 1e-5)
+    self.assertEqual(batch_results.num_objective_evaluations, 36)
+
+  # TODO:
+  def test_initial_position_and_previous_optimizer_results_are_exclusive(self):
+    minimum = np.array([1.0, 1.0])
+    scales = np.array([2.0, 3.0])
+
+    @_make_val_and_grad_fn
+    def quadratic(x):
+      return tf.reduce_sum(scales * tf.math.squared_difference(x, minimum))
+
+    start = tf.constant([0.6, 0.8])
+
+    def run(position, state):
+      raw_results = tfp.optimizer.lbfgs_minimize(
+          quadratic, initial_position=position,
+          previous_optimizer_results=state, tolerance=1e-8)
+      self.evaluate(raw_results)
+      return raw_results
+
+    self.assertRaises(ValueError, run, None, None)
+    results = run(start, None)
+    self.assertRaises(ValueError, run, start, results)
+
+  # TODO:
+  def test_data_fitting(self):
+    """Tests MLE estimation for a simple geometric GLM."""
+    n, dim = 100, 30
+    dtype = tf.float64
+    np.random.seed(234095)
+    x = np.random.choice([0, 1], size=[dim, n])
+    s = 0.01 * np.sum(x, 0)
+    p = 1. / (1 + np.exp(-s))
+    y = np.random.geometric(p)
+    x_data = tf.convert_to_tensor(x, dtype=dtype)
+    y_data = tf.convert_to_tensor(y, dtype=dtype)[..., tf.newaxis]
+
+    @_make_val_and_grad_fn
+    def neg_log_likelihood(state):
+      state_ext = tf.expand_dims(state, 0)
+      linear_part = tf.matmul(state_ext, x_data)
+      linear_part_ex = tf.stack([tf.zeros_like(linear_part),
+                                 linear_part], axis=0)
+      term1 = tf.squeeze(
+          tf.matmul(
+              tf.reduce_logsumexp(linear_part_ex, axis=0), y_data),
+          -1)
+      term2 = (
+          0.5 * tf.reduce_sum(state_ext * state_ext, axis=-1) -
+          tf.reduce_sum(linear_part, axis=-1))
+      return  tf.squeeze(term1 + term2)
+
+    start = tf.ones(shape=[dim], dtype=dtype)
+
+    results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+        neg_log_likelihood, initial_position=start, tolerance=1e-6))
+    self.assertTrue(results.converged)
+
+  # TODO:
+  def test_determinism(self):
+    """Tests that the results are determinsitic."""
+    dim = 25
+
+    @_make_val_and_grad_fn
+    def rastrigin(x):
+      """The value and gradient of the Rastrigin function.
+
+      The Rastrigin function is a standard optimization test case. It is a
+      multimodal non-convex function. While it has a large number of local
+      minima, the global minimum is located at the origin, where the function
+      value is zero. The standard search domain for optimization problems is
+      the hypercube [-5.12, 5.12]**d in d dimensions.
+
+      Args:
+        x: Real `Tensor` of shape [dim]. The position at which to evaluate the
+          function.
+
+      Returns:
+        value_and_gradient: A tuple of two `Tensor`s containing
+          value: A scalar `Tensor` of the function value at the supplied point.
+          gradient: A `Tensor` of shape [dim] containing the gradient of the
+            function along each dimension.
+      """
+      return tf.reduce_sum(x**2 - 10.0 * tf.cos(2 * np.pi * x)) + 10.0 * dim
+
+    start_position = np.random.rand(dim) * 2.0 * 5.12 - 5.12
+
+    def get_results():
+      start = tf.constant(start_position)
+      return self.evaluate(tfp.optimizer.lbfgs_minimize(
+          rastrigin, initial_position=start, tolerance=1e-5))
+
+    res1, res2 = get_results(), get_results()
+
+    self.assertTrue(res1.converged)
+    self.assertEqual(res1.converged, res2.converged)
+    self.assertEqual(res1.failed, res2.failed)
+    self.assertEqual(res1.num_objective_evaluations,
+                     res2.num_objective_evaluations)
+    self.assertArrayNear(res1.position, res2.position, 1e-5)
+    self.assertAlmostEqual(res1.objective_value, res2.objective_value)
+    self.assertArrayNear(res1.objective_gradient, res2.objective_gradient, 1e-5)
+    self.assertArrayNear(res1.position_deltas.reshape([-1]),
+                         res2.position_deltas.reshape([-1]), 1e-5)
+    self.assertArrayNear(res1.gradient_deltas.reshape([-1]),
+                         res2.gradient_deltas.reshape([-1]), 1e-5)
+
+  # TODO:
+  def test_compile(self):
+    """Tests that the computation can be XLA-compiled."""
+
+    self.skip_if_no_xla()
+
+    dim = 25
+
+    @_make_val_and_grad_fn
+    def rastrigin(x):
+      """The value and gradient of the Rastrigin function.
+
+      The Rastrigin function is a standard optimization test case. It is a
+      multimodal non-convex function. While it has a large number of local
+      minima, the global minimum is located at the origin, where the function
+      value is zero. The standard search domain for optimization problems is
+      the hypercube [-5.12, 5.12]**d in d dimensions.
+
+      Args:
+        x: Real `Tensor` of shape [dim]. The position at which to evaluate the
+          function.
+
+      Returns:
+        value_and_gradient: A tuple of two `Tensor`s containing
+          value: A scalar `Tensor` of the function value at the supplied point.
+          gradient: A `Tensor` of shape [dim] containing the gradient of the
+            function along each dimension.
+      """
+      return tf.reduce_sum(x**2 - 10.0 * tf.cos(2 * np.pi * x)) + 10.0 * dim
+
+    start_position = np.random.rand(dim) * 2.0 * 5.12 - 5.12
+
+    res = tf.function(tfp.optimizer.lbfgs_minimize, jit_compile=True)(
+        rastrigin,
+        initial_position=tf.constant(start_position),
+        tolerance=1e-5)
+
+    # We simply verify execution & convergence.
+    self.assertTrue(self.evaluate(res.converged))
+
+  # TODO:
+  def test_dynamic_shapes(self):
+    """Can build an lbfgs_op with dynamic shapes in graph mode."""
+    if tf.executing_eagerly(): return
+    ndims = 60
+    minimum = np.ones([ndims], dtype='float64')
+    scales = np.arange(ndims, dtype='float64') + minimum
+
+    @_make_val_and_grad_fn
+    def quadratic(x):
+      return tf.reduce_sum(scales * tf.math.squared_difference(x, minimum))
+
+    # Test with a vector of unknown dimension, and a fully unknown shape.
+    for shape in ([None], None):
+      start = tf1.placeholder(tf.float32, shape=shape)
+      lbfgs_op = tfp.optimizer.lbfgs_minimize(
+          quadratic, initial_position=start, tolerance=1e-8)
+      self.assertFalse(lbfgs_op.position.shape.is_fully_defined())
+
+      start_value = np.arange(ndims, 0, -1, dtype='float64')
+      with self.cached_session() as session:
+        results = session.run(lbfgs_op, feed_dict={start: start_value})
+      self.assertTrue(results.converged)
+      self.assertLessEqual(_norm(results.objective_gradient), 1e-8)
+      self.assertArrayNear(results.position, minimum, 1e-5)
+
+  # TODO:
+  @parameterized.named_parameters(
+      [{'testcase_name': '_from_start', 'start': np.array([0.8, 0.8])},
+       {'testcase_name': '_during_opt', 'start': np.array([0.0, 0.0])},
+       {'testcase_name': '_mixed', 'start': np.array([[0.8, 0.8], [0.0, 0.0]])},
+      ])
+  def test_stop_at_negative_infinity(self, start):
+    """Stops gently when encountering a -inf objective."""
+    minimum = np.array([1.0, 1.0])
+    scales = np.array([2.0, 3.0])
+
+    @_make_val_and_grad_fn
+    def quadratic_with_hole(x):
+      quadratic = tf.reduce_sum(
+          scales * tf.math.squared_difference(x, minimum), axis=-1)
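+      # Replace the objective with -inf on the open square (0.7, 1.3)^2, which
+      # contains the unconstrained minimum at (1, 1).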
+      square_hole = tf.reduce_all(tf.logical_and((x > 0.7), (x < 1.3)), axis=-1)
+      minus_infty = tf.constant(float('-inf'), dtype=quadratic.dtype)
+      answer = tf.where(square_hole, minus_infty, quadratic)
+      return answer
+
+    start = tf.constant(start)
+    results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+        quadratic_with_hole, initial_position=start, tolerance=1e-8))
+    self.assertAllTrue(results.converged)
+    self.assertAllFalse(results.failed)
+    self.assertAllNegativeInf(results.objective_value)
+    self.assertAllFinite(results.position)
+    self.assertAllNegativeInf(quadratic_with_hole(results.position)[0])
+
+  # TODO:
+  @parameterized.named_parameters(
+      [{'testcase_name': '_from_start', 'start': np.array([0.8, 0.8])},
+       {'testcase_name': '_during_opt', 'start': np.array([0.0, 0.0])},
+       {'testcase_name': '_mixed', 'start': np.array([[0.8, 0.8], [0.0, 0.0]])},
+      ])
+  def test_fail_at_non_finite(self, start):
+    """Fails promptly when encountering a non-finite but not -inf objective."""
+    # Meaning, +inf (tested here) and nan (not tested separately due to nearly
+    # identical code paths) objective values cause a "stop with failure".
+    # Actually, there is a further nitpick: +inf is currently treated a little
+    # inconsistently.  To wit, if the outer loop hits a +inf, it gives up and
+    # reports failure, because it assumes the gradient from a +inf value is
+    # garbage and no further progress is possible.  However, if the line search
+    # encounters an intermediate +inf, it in some cases knows to just treat it
+    # as a large finite value and avoid it.  So in principle, minimizing this
+    # test function starting outside the +inf region could stop at the actual
+    # minimum at the edge of said +inf region.  However, currently it happens to
+    # fail.
+    minimum = np.array([1.0, 1.0])
+    scales = np.array([2.0, 3.0])
+
+    @_make_val_and_grad_fn
+    def quadratic_with_spike(x):
+      quadratic = tf.reduce_sum(
+          scales * tf.math.squared_difference(x, minimum), axis=-1)
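+      # Replace the objective with +inf on the open square (0.7, 1.3)^2 around
+      # the unconstrained minimum at (1, 1).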
+      square_hole = tf.reduce_all(tf.logical_and((x > 0.7), (x < 1.3)), axis=-1)
+      infty = tf.constant(float('+inf'), dtype=quadratic.dtype)
+      answer = tf.where(square_hole, infty, quadratic)
+      return answer
+
+    start = tf.constant(start)
+    results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+        quadratic_with_spike, initial_position=start, tolerance=1e-8))
+    self.assertAllFalse(results.converged)
+    self.assertAllTrue(results.failed)
+    self.assertAllFinite(results.position)
+
+
+if __name__ == '__main__':
+  tf.test.main()