Skip to content

Commit d232c67

Browse files
author
John Halloran
committed
fix: using working quadratic solver for weights
1 parent 4ff1eb0 commit d232c67

File tree

1 file changed

+29
-37
lines changed

1 file changed

+29
-37
lines changed

src/diffpy/snmf/snmf_class.py

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# import cvxpy as cp
1+
import cvxpy as cp
22
import numpy as np
33
from scipy.optimize import minimize
4-
from scipy.sparse import coo_matrix, csc_matrix, diags
4+
from scipy.sparse import coo_matrix, diags
55

66

77
class SNMFOptimizer:
@@ -451,56 +451,48 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
451451

452452
return stretch_transformed
453453

454-
def solve_quadratic_program(self, t, m, alg="trust-constr"):
454+
def solve_quadratic_program(self, t, m):
455455
"""
456-
Solves the quadratic program for updating y in stretched NMF using scipy.optimize:
456+
Solves the quadratic program for updating y in stretched NMF:
457457
458-
min J(y) = 0.5 * y^T Q y + d^T y
458+
min J(y) = 0.5 * y^T q y + d^T y
459459
subject to: 0 ≤ y ≤ 1
460460
461-
Uses the 'trust-constr' solver with the analytical gradient and Hessian.
462-
Alternatively, can use scipy's L-BFGS-B algorithm, which supports bound
463-
constraints.
464-
465461
Parameters:
466-
- t: (N, K) ndarray
467-
Matrix computed from getAfun(A(k, m), X[:, k]).
468-
- m: int
469-
Index of the current column in source_matrix.
462+
- t: (N, k) ndarray
463+
- source_matrix_col: (N,) column of source_matrix for the corresponding m
470464
471465
Returns:
472-
- y: (k,) ndarray
473-
Optimal solution for y, clipped to ensure non-negativity.
466+
- y: (k,) optimal solution
474467
"""
468+
475469
source_matrix_col = self.source_matrix[:, m]
476-
q = t.T @ t
477-
d = -t.T @ source_matrix_col
478-
k = q.shape[0]
479-
reg_factor = 1e-8 * np.linalg.norm(q, ord="fro")
480-
q += np.eye(k) * reg_factor
481470

482-
def objective(y):
483-
return 0.5 * y @ q @ y + d @ y
471+
# Compute q and d
472+
q = t.T @ t # Gram matrix (k x k)
473+
d = -t.T @ source_matrix_col # Linear term (k,)
484474

485-
def grad(y):
486-
return q @ y + d
475+
k = q.shape[0] # Number of variables
487476

488-
if alg == "trust-constr":
477+
# Regularize q to ensure positive semi-definiteness
478+
reg_factor = 1e-8 * np.linalg.norm(q, ord="fro") # Adaptive regularization, original was fixed
479+
q += np.eye(k) * reg_factor
489480

490-
def hess(y):
491-
return csc_matrix(q) # sparse format for efficiency
481+
# Define optimization variable
482+
y = cp.Variable(k)
492483

493-
bounds = [(0, 1)] * k
494-
y0 = np.clip(-np.linalg.solve(q + np.eye(k) * 1e-5, d), 0, 1)
495-
result = minimize(
496-
objective, y0, method="trust-constr", jac=grad, hess=hess, bounds=bounds, options={"verbose": 0}
497-
)
498-
elif alg == "L-BFGS-B":
499-
bounds = [(0, 1) for _ in range(k)] # per-variable bounds
500-
y0 = np.clip(-np.linalg.solve(q + np.eye(k) * 1e-5, d), 0, 1) # Initial guess
501-
result = minimize(objective, y0, method="L-BFGS-B", jac=grad, bounds=bounds)
484+
# Define quadratic objective
485+
objective = cp.Minimize(0.5 * cp.quad_form(y, q) + d.T @ y)
486+
487+
# Define constraints (0 ≤ y ≤ 1)
488+
constraints = [y >= 0, y <= 1]
489+
490+
# Solve using a QP solver
491+
prob = cp.Problem(objective, constraints)
492+
prob.solve(solver=cp.OSQP, verbose=False)
502493

503-
return np.maximum(result.x, 0)
494+
# Get the solution
495+
return np.maximum(y.value, 0) # Ensure non-negative values in case of solver tolerance issues
504496

505497
def update_components(self):
506498
"""

0 commit comments

Comments
 (0)