@@ -203,6 +203,177 @@ def evaluate(self, x):
203203 y [i ] = g (x )
204204 return y
205205
206+ def filter (self , s , method = 'chebyshev' , order = 30 ):
207+ r"""
208+ Filter signals with the filter bank (analysis or synthesis).
209+
210+ A signal is defined as a rank-3 tensor of shape ``(N_NODES, N_SIGNALS,
211+ N_FEATURES)``, where ``N_NODES`` is the number of nodes in the graph,
212+ ``N_SIGNALS`` is the number of independent signals, and ``N_FEATURES``
213+ is the number of features which compose a graph signal, or the
214+ dimensionality of a graph signal. For example if you filter a signal
215+ with a filter bank of 8 filters, you're extracting 8 features and
216+ decomposing your signal into 8 parts. That is called analysis. Your are
217+ thus transforming your signal tensor from ``(G.N, 1, 1)`` to ``(G.N, 1,
218+ 8)``. Now you may want to combine back the features to form an unique
219+ signal. For this you apply again 8 filters, one filter per feature, and
220+ sum the result up. As such you're transforming your ``(G.N, 1, 8)``
221+ tensor signal back to ``(G.N, 1, 1)``. That is known as synthesis. More
222+ generally, you may want to map a set of features to another, though
223+ that is not implemented yet.
224+
225+ The method computes the transform coefficients of a signal :math:`s`,
226+ where the atoms of the transform dictionary are generalized
227+ translations of each graph spectral filter to each vertex on the graph:
228+
229+ .. math:: c = D^* s,
230+
231+ where the columns of :math:`D` are :math:`g_{i,m} = T_i g_m` and
232+ :math:`T_i` is a generalized translation operator applied to each
233+ filter :math:`\hat{g}_m(\cdot)`. Each column of :math:`c` is the
234+ response of the signal to one filter.
235+
236+ In other words, this function is applying the analysis operator
237+ :math:`D^*`, respectively the synthesis operator :math:`D`, associated
238+ with the frame defined by the filter bank to the signals.
239+
240+ Parameters
241+ ----------
242+ s : ndarray
243+ Graph signals, a tensor of shape ``(N_NODES, N_SIGNALS,
244+ N_FEATURES)``, where ``N_NODES`` is the number of nodes in the
245+ graph, ``N_SIGNALS`` the number of independent signals you want to
246+ filter, and ``N_FEATURES`` is either 1 (analysis) or the number of
247+ filters in the filter bank (synthesis).
248+ method : {'exact', 'chebyshev'}
249+ Whether to use the exact method (via the graph Fourier transform)
250+ or the Chebyshev polynomial approximation. A Lanczos
251+ approximation is coming.
252+ order : int
253+ Degree of the Chebyshev polynomials.
254+
255+ Returns
256+ -------
257+ s : ndarray
258+ Graph signals, a tensor of shape ``(N_NODES, N_SIGNALS,
259+ N_FEATURES)``, where ``N_NODES`` and ``N_SIGNALS`` are the number
260+ of nodes and signals of the signal tensor that pas passed in, and
261+ ``N_FEATURES`` is either 1 (synthesis) or the number of filters in
262+ the filter bank (analysis).
263+
264+ References
265+ ----------
266+ See :cite:`hammond2011wavelets` for details on filtering graph signals.
267+
268+ Examples
269+ --------
270+
271+ Create a bunch of smooth signals by low-pass filtering white noise:
272+
273+ >>> import matplotlib.pyplot as plt
274+ >>> G = graphs.Ring(N=60)
275+ >>> G.estimate_lmax()
276+ >>> s = np.random.RandomState(42).uniform(size=(G.N, 10))
277+ >>> taus = [1, 10, 100]
278+ >>> s = filters.Heat(G, taus).filter(s)
279+ >>> s.shape
280+ (60, 10, 3)
281+
282+ Plot the 3 smoothed versions of the 10th signal:
283+
284+ >>> fig, ax = plt.subplots()
285+ >>> G.set_coordinates('line1D') # To visualize multiple signals in 1D.
286+ >>> G.plot_signal(s[:, 9, :], ax=ax)
287+ >>> legend = [r'$\tau={}$'.format(t) for t in taus]
288+ >>> ax.legend(legend) # doctest: +ELLIPSIS
289+ <matplotlib.legend.Legend object at ...>
290+
291+ Low-pass filter a delta to create a localized smooth signal:
292+
293+ >>> G = graphs.Sensor(30, seed=42)
294+ >>> G.compute_fourier_basis() # Reproducible computation of lmax.
295+ >>> s1 = np.zeros(G.N)
296+ >>> s1[13] = 1
297+ >>> s1 = filters.Heat(G, 3).filter(s1)
298+ >>> s1.shape
299+ (30, 1, 1)
300+
301+ Filter and reconstruct our signal:
302+
303+ >>> g = filters.MexicanHat(G, Nf=4)
304+ >>> s2 = g.filter(s1)
305+ >>> s2.shape
306+ (30, 1, 4)
307+ >>> s2 = g.filter(s2)
308+ >>> s2.shape
309+ (30, 1, 1)
310+
311+ Look how well we were able to reconstruct:
312+
313+ >>> fig, axes = plt.subplots(1, 2)
314+ >>> G.plot_signal(s1, ax=axes[0])
315+ >>> G.plot_signal(s2, ax=axes[1])
316+ >>> print('{:.5f}'.format(np.linalg.norm(s1 - s2)))
317+ 0.29620
318+
319+ Perfect reconstruction with Itersine, a tight frame:
320+
321+ >>> g = filters.Itersine(G)
322+ >>> s2 = g.filter(s1, method='exact')
323+ >>> s2 = g.filter(s2, method='exact')
324+ >>> np.linalg.norm(s1 - s2) < 1e-10
325+ True
326+
327+ """
328+ s = self .G .sanitize_signal (s )
329+ N_NODES , N_SIGNALS , N_FEATURES_IN = s .shape
330+
331+ # TODO: generalize to 2D (m --> n) filter banks.
332+ # Only 1 --> Nf (analysis) and Nf --> 1 (synthesis) for now.
333+ if N_FEATURES_IN not in [1 , self .Nf ]:
334+ raise ValueError ('Last dimension (N_FEATURES) should either be '
335+ '1 or the number of filters (Nf), '
336+ 'not {}.' .format (s .shape ))
337+ N_FEATURES_OUT = self .Nf if N_FEATURES_IN == 1 else 1
338+
339+ if method == 'exact' :
340+
341+ axis = 1 if N_FEATURES_IN == 1 else 2
342+ f = self .evaluate (self .G .e )
343+ f = np .expand_dims (f .T , axis )
344+ assert f .shape == (N_NODES , N_FEATURES_IN , N_FEATURES_OUT )
345+
346+ s = self .G .gft2 (s )
347+ s = np .matmul (s , f )
348+ s = self .G .igft2 (s )
349+
350+ elif method == 'chebyshev' :
351+
352+ # TODO: update Chebyshev implementation (after 2D filter banks).
353+ c = approximations .compute_cheby_coeff (self , m = order )
354+
355+ if N_FEATURES_IN == 1 : # Analysis.
356+ s = s .squeeze (axis = 2 )
357+ s = approximations .cheby_op (self .G , c , s )
358+ s = s .reshape ((N_NODES , N_FEATURES_OUT , N_SIGNALS ), order = 'F' )
359+ s = s .swapaxes (1 , 2 )
360+
361+ elif N_FEATURES_IN == self .Nf : # Synthesis.
362+ s = s .swapaxes (1 , 2 )
363+ s_in = s .reshape ((N_NODES * N_FEATURES_IN , N_SIGNALS ), order = 'F' )
364+ s = np .zeros ((N_NODES , N_SIGNALS ))
365+ tmpN = np .arange (N_NODES , dtype = int )
366+ for i in range (N_FEATURES_IN ):
367+ s += approximations .cheby_op (self .G ,
368+ c [i ],
369+ s_in [i * N_NODES + tmpN ])
370+ s = np .expand_dims (s , 2 )
371+
372+ else :
373+ raise ValueError ('Unknown method {}.' .format (method ))
374+
375+ return s
376+
206377 def inverse (self , c ):
207378 r"""
208379 Not implemented yet.
0 commit comments