Skip to content

Commit 0bf62a8

Browse files
john-halloranJohn Halloran
andauthored
feat: add live plotting of updates (#166)
* feat: add live plotting of updates * style: make plotting vars lowercase --------- Co-authored-by: John Halloran <[email protected]>
1 parent 1b49701 commit 0bf62a8

File tree

3 files changed

+80
-0
lines changed

3 files changed

+80
-0
lines changed

src/diffpy/snmf/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
init_weights=init_weights_file,
1313
init_components=init_components_file,
1414
init_stretch=init_stretch_file,
15+
show_plots=True,
1516
)
1617

1718
print("Done")

src/diffpy/snmf/plotter.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
4+
5+
class SNMFPlotter:
6+
def __init__(self, figsize=(12, 4)):
7+
plt.ion()
8+
self.fig, self.axes = plt.subplots(1, 3, figsize=figsize)
9+
titles = ["Components", "Weights (rows as series)", "Stretch (rows as series)"]
10+
for ax, t in zip(self.axes, titles):
11+
ax.set_title(t)
12+
self.lines = {"components": [], "weights": [], "stretch": []}
13+
self._layout_done = False
14+
plt.show()
15+
16+
def _ensure_lines(self, ax, key, n_series):
17+
cur = self.lines[key]
18+
if len(cur) != n_series:
19+
ax.cla()
20+
ax.set_title(ax.get_title())
21+
self.lines[key] = [ax.plot([], [])[0] for _ in range(n_series)]
22+
return self.lines[key]
23+
24+
def _update_series(self, ax, key, data_2d):
25+
# Expect rows = separate series for components
26+
data_2d = np.atleast_2d(data_2d)
27+
n_series, n_pts = data_2d.shape
28+
lines = self._ensure_lines(ax, key, n_series)
29+
x = np.arange(n_pts)
30+
for ln, y in zip(lines, data_2d):
31+
ln.set_data(x, y)
32+
ax.relim()
33+
ax.autoscale_view()
34+
35+
def update(self, components, weights, stretch, update_tag=None):
36+
# Components: transpose before plotting
37+
c = np.asarray(components).T
38+
self._update_series(self.axes[0], "components", c)
39+
40+
w = np.asarray(weights)
41+
self._update_series(self.axes[1], "weights", w)
42+
43+
s = np.asarray(stretch)
44+
self._update_series(self.axes[2], "stretch", s)
45+
46+
if update_tag is not None:
47+
self.fig.suptitle(f"Updated: {update_tag}", fontsize=14)
48+
49+
if not self._layout_done:
50+
self.fig.tight_layout()
51+
self._layout_done = True
52+
53+
self.fig.canvas.draw()
54+
self.fig.canvas.flush_events()
55+
plt.pause(0.001)

src/diffpy/snmf/snmf_class.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import cvxpy as cp
22
import numpy as np
3+
from plotter import SNMFPlotter
34
from scipy.optimize import minimize
45
from scipy.sparse import coo_matrix, diags
56

@@ -73,6 +74,7 @@ def __init__(
7374
tol=5e-7,
7475
n_components=None,
7576
random_state=None,
77+
show_plots=False,
7678
):
7779
"""Initialize an instance of SNMF and run the optimization.
7880
@@ -112,6 +114,8 @@ def __init__(
112114
random_state : int Optional Default = None
113115
The seed for the initial guesses at the matrices (A, X, and Y) created by
114116
the decomposition.
117+
show_plots : boolean Optional Default = False
118+
Enables plotting at each step of the decomposition.
115119
"""
116120

117121
self.source_matrix = source_matrix
@@ -123,6 +127,7 @@ def __init__(
123127
self.signal_length, self.n_signals = source_matrix.shape
124128
self.num_updates = 0
125129
self._rng = np.random.default_rng(random_state)
130+
self.plotter = SNMFPlotter() if show_plots else None
126131

127132
# Enforce exclusive specification of n_components or init_weights
128133
if (n_components is None and init_weights is None) or (
@@ -236,6 +241,13 @@ def normalize_results(self):
236241
print(f"Objective function after normalize_components: {self.objective_function:.5e}")
237242
self._objective_history.append(self.objective_function)
238243
self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
244+
if self.plotter is not None:
245+
self.plotter.update(
246+
components=self.components,
247+
weights=self.weights,
248+
stretch=self.stretch,
249+
update_tag="normalize components",
250+
)
239251
if self.objective_difference < self.objective_function * self.tol and outiter >= 7:
240252
break
241253

@@ -252,6 +264,10 @@ def outer_loop(self):
252264
if self.objective_function < self.best_objective:
253265
self.best_objective = self.objective_function
254266
self.best_matrices = [self.components.copy(), self.weights.copy(), self.stretch.copy()]
267+
if self.plotter is not None:
268+
self.plotter.update(
269+
components=self.components, weights=self.weights, stretch=self.stretch, update_tag="components"
270+
)
255271

256272
self.update_weights()
257273
self.residuals = self.get_residual_matrix()
@@ -262,6 +278,10 @@ def outer_loop(self):
262278
if self.objective_function < self.best_objective:
263279
self.best_objective = self.objective_function
264280
self.best_matrices = [self.components.copy(), self.weights.copy(), self.stretch.copy()]
281+
if self.plotter is not None:
282+
self.plotter.update(
283+
components=self.components, weights=self.weights, stretch=self.stretch, update_tag="weights"
284+
)
265285

266286
self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
267287
if self._objective_history[-3] - self.objective_function < self.objective_difference * 1e-3:
@@ -276,6 +296,10 @@ def outer_loop(self):
276296
if self.objective_function < self.best_objective:
277297
self.best_objective = self.objective_function
278298
self.best_matrices = [self.components.copy(), self.weights.copy(), self.stretch.copy()]
299+
if self.plotter is not None:
300+
self.plotter.update(
301+
components=self.components, weights=self.weights, stretch=self.stretch, update_tag="stretch"
302+
)
279303

280304
def get_residual_matrix(self, components=None, weights=None, stretch=None):
281305
# Initialize residual matrix as negative of source_matrix

0 commit comments

Comments
 (0)