1
1
import cvxpy as cp
2
2
import numpy as np
3
+ from plotter import SNMFPlotter
3
4
from scipy .optimize import minimize
4
5
from scipy .sparse import coo_matrix , diags
5
6
@@ -73,6 +74,7 @@ def __init__(
73
74
tol = 5e-7 ,
74
75
n_components = None ,
75
76
random_state = None ,
77
+ show_plots = False ,
76
78
):
77
79
"""Initialize an instance of SNMF and run the optimization.
78
80
@@ -112,6 +114,8 @@ def __init__(
112
114
random_state : int Optional Default = None
113
115
The seed for the initial guesses at the matrices (A, X, and Y) created by
114
116
the decomposition.
117
+ show_plots : boolean Optional Default = False
118
+ Enables plotting at each step of the decomposition.
115
119
"""
116
120
117
121
self .source_matrix = source_matrix
@@ -123,6 +127,7 @@ def __init__(
123
127
self .signal_length , self .n_signals = source_matrix .shape
124
128
self .num_updates = 0
125
129
self ._rng = np .random .default_rng (random_state )
130
+ self .plotter = SNMFPlotter () if show_plots else None
126
131
127
132
# Enforce exclusive specification of n_components or init_weights
128
133
if (n_components is None and init_weights is None ) or (
@@ -236,6 +241,13 @@ def normalize_results(self):
236
241
print (f"Objective function after normalize_components: { self .objective_function :.5e} " )
237
242
self ._objective_history .append (self .objective_function )
238
243
self .objective_difference = self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
244
+ if self .plotter is not None :
245
+ self .plotter .update (
246
+ components = self .components ,
247
+ weights = self .weights ,
248
+ stretch = self .stretch ,
249
+ update_tag = "normalize components" ,
250
+ )
239
251
if self .objective_difference < self .objective_function * self .tol and outiter >= 7 :
240
252
break
241
253
@@ -252,6 +264,10 @@ def outer_loop(self):
252
264
if self .objective_function < self .best_objective :
253
265
self .best_objective = self .objective_function
254
266
self .best_matrices = [self .components .copy (), self .weights .copy (), self .stretch .copy ()]
267
+ if self .plotter is not None :
268
+ self .plotter .update (
269
+ components = self .components , weights = self .weights , stretch = self .stretch , update_tag = "components"
270
+ )
255
271
256
272
self .update_weights ()
257
273
self .residuals = self .get_residual_matrix ()
@@ -262,6 +278,10 @@ def outer_loop(self):
262
278
if self .objective_function < self .best_objective :
263
279
self .best_objective = self .objective_function
264
280
self .best_matrices = [self .components .copy (), self .weights .copy (), self .stretch .copy ()]
281
+ if self .plotter is not None :
282
+ self .plotter .update (
283
+ components = self .components , weights = self .weights , stretch = self .stretch , update_tag = "weights"
284
+ )
265
285
266
286
self .objective_difference = self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
267
287
if self ._objective_history [- 3 ] - self .objective_function < self .objective_difference * 1e-3 :
@@ -276,6 +296,10 @@ def outer_loop(self):
276
296
if self .objective_function < self .best_objective :
277
297
self .best_objective = self .objective_function
278
298
self .best_matrices = [self .components .copy (), self .weights .copy (), self .stretch .copy ()]
299
+ if self .plotter is not None :
300
+ self .plotter .update (
301
+ components = self .components , weights = self .weights , stretch = self .stretch , update_tag = "stretch"
302
+ )
279
303
280
304
def get_residual_matrix (self , components = None , weights = None , stretch = None ):
281
305
# Initialize residual matrix as negative of source_matrix
0 commit comments