Skip to content

Replace 1D apply_interpolation with np.interp #168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 66 additions & 76 deletions src/diffpy/snmf/snmf_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,18 +303,48 @@ def outer_loop(self):
)

def get_residual_matrix(self, components=None, weights=None, stretch=None):
# Initialize residual matrix as negative of source_matrix
"""
Return the residuals (difference) between the source matrix and its reconstruction
from the given components, weights, and stretch factors.

Each component profile is stretched, interpolated to fractional positions,
weighted per signal, and summed to form the reconstruction. The residuals
are the source matrix minus this reconstruction.

Parameters
----------
components : (signal_len, n_components) array, optional
weights : (n_components, n_signals) array, optional
stretch : (n_components, n_signals) array, optional

Returns
-------
residuals : (signal_len, n_signals) array
"""

if components is None:
components = self.components
if weights is None:
weights = self.weights
if stretch is None:
stretch = self.stretch

residuals = -self.source_matrix.copy()
# Compute transformed components for all (k, m) pairs
for k in range(weights.shape[0]): # K
stretched_components, _, _ = apply_interpolation(stretch[k, :], components[:, k]) # Only use Ax
residuals += weights[k, :] * stretched_components # Element-wise scaling and sum
sample_indices = np.arange(components.shape[0]) # (signal_len,)

for comp in range(components.shape[1]): # loop over components
residuals += (
np.interp(
sample_indices[:, None]
/ stretch[comp][None, :], # fractional positions (signal_len, n_signals)
sample_indices, # (signal_len,)
components[:, comp], # component profile (signal_len,)
left=components[0, comp],
right=components[-1, comp],
)
* weights[comp][None, :] # broadcast (n_signals,) over rows
)

return residuals

def get_objective_function(self, residuals=None, stretch=None):
Expand Down Expand Up @@ -579,42 +609,47 @@ def update_components(self):

def update_weights(self):
"""
Updates weights using matrix operations, solving a quadratic program to do so.
Updates weights by building the stretched component matrix `stretched_comps` with np.interp
and solving a quadratic program for each signal.
"""

signal_length = self.signal_length
n_signals = self.n_signals

for m in range(n_signals):
t = np.zeros((signal_length, self.n_components))

# Populate t using apply_interpolation
for k in range(self.n_components):
t[:, k] = apply_interpolation(self.stretch[k, m], self.components[:, k])[0].squeeze()

# Solve quadratic problem for y
y = self.solve_quadratic_program(t=t, m=m)
sample_indices = np.arange(self.signal_length)
for signal in range(self.n_signals):
# Stretch factors for this signal across components:
this_stretch = self.stretch[:, signal]
# Build stretched_comps[:, k] by interpolating component at frac. pos. index / this_stretch[comp]
stretched_comps = np.empty((self.signal_length, self.n_components), dtype=self.components.dtype)
for comp in range(self.n_components):
pos = sample_indices / this_stretch[comp]
stretched_comps[:, comp] = np.interp(
pos,
sample_indices,
self.components[:, comp],
left=self.components[0, comp],
right=self.components[-1, comp],
)

# Update Y
self.weights[:, m] = y
# Solve quadratic problem for a given signal and update its weight
new_weight = self.solve_quadratic_program(t=stretched_comps, m=signal)
self.weights[:, signal] = new_weight

def regularize_function(self, stretch=None):
if stretch is None:
stretch = self.stretch

K = self.n_components
M = self.n_signals
N = self.signal_length

stretched_components, d_stretch_comps, dd_stretch_comps = self.apply_interpolation_matrix(stretch=stretch)
intermediate = stretched_components.flatten(order="F").reshape((N * M, K), order="F")
residuals = intermediate.sum(axis=1).reshape((N, M), order="F") - self.source_matrix
intermediate = stretched_components.flatten(order="F").reshape(
(self.signal_length * self.n_signals, self.n_components), order="F"
)
residuals = (
intermediate.sum(axis=1).reshape((self.signal_length, self.n_signals), order="F") - self.source_matrix
)

fun = self.get_objective_function(residuals, stretch)

tiled_res = np.tile(residuals, (1, K))
tiled_res = np.tile(residuals, (1, self.n_components))
grad_flat = np.sum(d_stretch_comps * tiled_res, axis=0)
gra = grad_flat.reshape((M, K), order="F").T
gra = grad_flat.reshape((self.n_signals, self.n_components), order="F").T
gra += self.rho * stretch @ (self._spline_smooth_operator.T @ self._spline_smooth_operator)

# Hessian would go here
Expand All @@ -623,10 +658,10 @@ def regularize_function(self, stretch=None):

def update_stretch(self):
"""
Updates matrix A using constrained optimization (equivalent to fmincon in MATLAB).
Updates stretching matrix using constrained optimization (equivalent to fmincon in MATLAB).
"""

# Flatten A for compatibility with the optimizer (since SciPy expects 1D inputs)
# Flatten stretch for compatibility with the optimizer (since SciPy expects 1D input)
stretch_flat_initial = self.stretch.flatten()

# Define the optimization function
Expand All @@ -648,7 +683,7 @@ def objective(stretch_vec):
bounds=bounds,
)

# Update A with the optimized values
# Update stretch with the optimized values
self.stretch = result.x.reshape(self.stretch.shape)


Expand Down Expand Up @@ -683,48 +718,3 @@ def cubic_largest_real_root(p, q):
y = np.max(real_roots, axis=0) * (delta < 0) # Keep only real roots when delta < 0

return y


def apply_interpolation(a, x):
"""
Applies an interpolation-based transformation to `x` based on scaling `a`.
Also computes first (`d_intr_x`) and second (`dd_intr_x`) derivatives.
"""
x_len = len(x)

# Ensure `a` is an array and reshape for broadcasting
a = np.atleast_1d(np.asarray(a)) # Ensures a is at least 1D

# Compute fractional indices, broadcasting over `a`
fractional_indices = np.arange(x_len)[:, None] / a # Shape (N, M)

integer_indices = np.floor(fractional_indices).astype(int) # Integer part (still (N, M))
valid_mask = integer_indices < (x_len - 1) # Ensure indices are within bounds

# Apply valid_mask to keep correct indices
idx_int = np.where(valid_mask, integer_indices, x_len - 2) # Prevent out-of-bounds indexing (previously "I")
idx_frac = np.where(valid_mask, fractional_indices, integer_indices) # Keep aligned (previously "i")

# Ensure x is a 1D array
x = np.asarray(x).ravel()

# Compute interpolated_x (linear interpolation)
interpolated_x = x[idx_int] * (1 - idx_frac + idx_int) + x[np.minimum(idx_int + 1, x_len - 1)] * (
idx_frac - idx_int
)

# Fill the tail with the last valid value
intr_x_tail = np.full((x_len - len(idx_int), interpolated_x.shape[1]), interpolated_x[-1, :])
interpolated_x = np.vstack([interpolated_x, intr_x_tail])

# Compute first derivative (d_intr_x)
di = -idx_frac / a
d_intr_x = x[idx_int] * (-di) + x[np.minimum(idx_int + 1, x_len - 1)] * di
d_intr_x = np.vstack([d_intr_x, np.zeros((x_len - len(idx_int), d_intr_x.shape[1]))])

# Compute second derivative (dd_intr_x)
ddi = -di / a + idx_frac * a**-2
dd_intr_x = x[idx_int] * (-ddi) + x[np.minimum(idx_int + 1, x_len - 1)] * ddi
dd_intr_x = np.vstack([dd_intr_x, np.zeros((x_len - len(idx_int), dd_intr_x.shape[1]))])

return interpolated_x, d_intr_x, dd_intr_x