diff --git a/src/diffpy/snmf/snmf_class.py b/src/diffpy/snmf/snmf_class.py index 83fdb04..2983e57 100644 --- a/src/diffpy/snmf/snmf_class.py +++ b/src/diffpy/snmf/snmf_class.py @@ -303,18 +303,48 @@ def outer_loop(self): ) def get_residual_matrix(self, components=None, weights=None, stretch=None): - # Initialize residual matrix as negative of source_matrix + """ + Return the residuals (difference) between the source matrix and its reconstruction + from the given components, weights, and stretch factors. + + Each component profile is stretched, interpolated to fractional positions, + weighted per signal, and summed to form the reconstruction. The residuals + are the source matrix minus this reconstruction. + + Parameters + ---------- + components : (signal_len, n_components) array, optional + weights : (n_components, n_signals) array, optional + stretch : (n_components, n_signals) array, optional + + Returns + ------- + residuals : (signal_len, n_signals) array + """ + if components is None: components = self.components if weights is None: weights = self.weights if stretch is None: stretch = self.stretch + residuals = -self.source_matrix.copy() - # Compute transformed components for all (k, m) pairs - for k in range(weights.shape[0]): # K - stretched_components, _, _ = apply_interpolation(stretch[k, :], components[:, k]) # Only use Ax - residuals += weights[k, :] * stretched_components # Element-wise scaling and sum + sample_indices = np.arange(components.shape[0]) # (signal_len,) + + for comp in range(components.shape[1]): # loop over components + residuals += ( + np.interp( + sample_indices[:, None] + / stretch[comp][None, :], # fractional positions (signal_len, n_signals) + sample_indices, # (signal_len,) + components[:, comp], # component profile (signal_len,) + left=components[0, comp], + right=components[-1, comp], + ) + * weights[comp][None, :] # broadcast (n_signals,) over rows + ) + return residuals def get_objective_function(self, residuals=None, stretch=None): @@ -579,42 +609,47 @@ def update_components(self): def update_weights(self): """ - Updates weights using matrix operations, solving a quadratic program to do so. + Updates weights by building the stretched component matrix `stretched_comps` with np.interp + and solving a quadratic program for each signal. """ - signal_length = self.signal_length - n_signals = self.n_signals - - for m in range(n_signals): - t = np.zeros((signal_length, self.n_components)) - - # Populate t using apply_interpolation - for k in range(self.n_components): - t[:, k] = apply_interpolation(self.stretch[k, m], self.components[:, k])[0].squeeze() - - # Solve quadratic problem for y - y = self.solve_quadratic_program(t=t, m=m) + sample_indices = np.arange(self.signal_length) + for signal in range(self.n_signals): + # Stretch factors for this signal across components: + this_stretch = self.stretch[:, signal] + # Build stretched_comps[:, k] by interpolating component at frac. pos. index / this_stretch[comp] + stretched_comps = np.empty((self.signal_length, self.n_components), dtype=self.components.dtype) + for comp in range(self.n_components): + pos = sample_indices / this_stretch[comp] + stretched_comps[:, comp] = np.interp( + pos, + sample_indices, + self.components[:, comp], + left=self.components[0, comp], + right=self.components[-1, comp], + ) - # Update Y - self.weights[:, m] = y + # Solve quadratic problem for a given signal and update its weight + new_weight = self.solve_quadratic_program(t=stretched_comps, m=signal) + self.weights[:, signal] = new_weight def regularize_function(self, stretch=None): if stretch is None: stretch = self.stretch - K = self.n_components - M = self.n_signals - N = self.signal_length - stretched_components, d_stretch_comps, dd_stretch_comps = self.apply_interpolation_matrix(stretch=stretch) - intermediate = stretched_components.flatten(order="F").reshape((N * M, K), order="F") - residuals = intermediate.sum(axis=1).reshape((N, M), order="F") - self.source_matrix + intermediate = stretched_components.flatten(order="F").reshape( + (self.signal_length * self.n_signals, self.n_components), order="F" + ) + residuals = ( + intermediate.sum(axis=1).reshape((self.signal_length, self.n_signals), order="F") - self.source_matrix + ) fun = self.get_objective_function(residuals, stretch) - tiled_res = np.tile(residuals, (1, K)) + tiled_res = np.tile(residuals, (1, self.n_components)) grad_flat = np.sum(d_stretch_comps * tiled_res, axis=0) - gra = grad_flat.reshape((M, K), order="F").T + gra = grad_flat.reshape((self.n_signals, self.n_components), order="F").T gra += self.rho * stretch @ (self._spline_smooth_operator.T @ self._spline_smooth_operator) # Hessian would go here @@ -623,10 +658,10 @@ def regularize_function(self, stretch=None): def update_stretch(self): """ - Updates matrix A using constrained optimization (equivalent to fmincon in MATLAB). + Updates stretching matrix using constrained optimization (equivalent to fmincon in MATLAB). """ - # Flatten A for compatibility with the optimizer (since SciPy expects 1D inputs) + # Flatten stretch for compatibility with the optimizer (since SciPy expects 1D input) stretch_flat_initial = self.stretch.flatten() # Define the optimization function @@ -648,7 +683,7 @@ def objective(stretch_vec): bounds=bounds, ) - # Update A with the optimized values + # Update stretch with the optimized values self.stretch = result.x.reshape(self.stretch.shape) @@ -683,48 +718,3 @@ def cubic_largest_real_root(p, q): y = np.max(real_roots, axis=0) * (delta < 0) # Keep only real roots when delta < 0 return y - - -def apply_interpolation(a, x): - """ - Applies an interpolation-based transformation to `x` based on scaling `a`. - Also computes first (`d_intr_x`) and second (`dd_intr_x`) derivatives. - """ - x_len = len(x) - - # Ensure `a` is an array and reshape for broadcasting - a = np.atleast_1d(np.asarray(a)) # Ensures a is at least 1D - - # Compute fractional indices, broadcasting over `a` - fractional_indices = np.arange(x_len)[:, None] / a # Shape (N, M) - - integer_indices = np.floor(fractional_indices).astype(int) # Integer part (still (N, M)) - valid_mask = integer_indices < (x_len - 1) # Ensure indices are within bounds - - # Apply valid_mask to keep correct indices - idx_int = np.where(valid_mask, integer_indices, x_len - 2) # Prevent out-of-bounds indexing (previously "I") - idx_frac = np.where(valid_mask, fractional_indices, integer_indices) # Keep aligned (previously "i") - - # Ensure x is a 1D array - x = np.asarray(x).ravel() - - # Compute interpolated_x (linear interpolation) - interpolated_x = x[idx_int] * (1 - idx_frac + idx_int) + x[np.minimum(idx_int + 1, x_len - 1)] * ( - idx_frac - idx_int - ) - - # Fill the tail with the last valid value - intr_x_tail = np.full((x_len - len(idx_int), interpolated_x.shape[1]), interpolated_x[-1, :]) - interpolated_x = np.vstack([interpolated_x, intr_x_tail]) - - # Compute first derivative (d_intr_x) - di = -idx_frac / a - d_intr_x = x[idx_int] * (-di) + x[np.minimum(idx_int + 1, x_len - 1)] * di - d_intr_x = np.vstack([d_intr_x, np.zeros((x_len - len(idx_int), d_intr_x.shape[1]))]) - - # Compute second derivative (dd_intr_x) - ddi = -di / a + idx_frac * a**-2 - dd_intr_x = x[idx_int] * (-ddi) + x[np.minimum(idx_int + 1, x_len - 1)] * ddi - dd_intr_x = np.vstack([dd_intr_x, np.zeros((x_len - len(idx_int), dd_intr_x.shape[1]))]) - - return interpolated_x, d_intr_x, dd_intr_x