Commit 849f949

Filter.filter: generalize analysis and synthesis
Standardize signal shapes as N_NODES x N_SIGNALS x N_FEATURES
1 parent c82bf38 commit 849f949

7 files changed, 271 insertions(+), 0 deletions(-)
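To make the new convention concrete, here is a minimal sketch (not part of the commit) of how the standardized N_NODES x N_SIGNALS x N_FEATURES shapes flow through analysis and synthesis, assuming the API introduced in this diff:

import numpy as np
from pygsp import graphs, filters

G = graphs.Ring(N=60)
G.estimate_lmax()
g = filters.Heat(G, [5, 10, 20])      # a bank of Nf = 3 filters

x = np.random.uniform(size=(G.N, 4))  # 4 signals, 1 feature each
y = g.filter(x)                       # analysis: (60, 4, 1) -> (60, 4, 3)
z = g.filter(y)                       # synthesis: (60, 4, 3) -> (60, 4, 1)
print(y.shape, z.shape)               # (60, 4, 3) (60, 4, 1)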

pygsp/filters/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
 .. autosummary::

     Filter.evaluate
+    Filter.filter
     Filter.analysis
     Filter.synthesis
     Filter.compute_frame

pygsp/filters/filter.py

Lines changed: 171 additions & 0 deletions

@@ -203,6 +203,177 @@ def evaluate(self, x):
             y[i] = g(x)
         return y

+    def filter(self, s, method='chebyshev', order=30):
+        r"""
+        Filter signals with the filter bank (analysis or synthesis).
+
+        A signal is defined as a rank-3 tensor of shape ``(N_NODES, N_SIGNALS,
+        N_FEATURES)``, where ``N_NODES`` is the number of nodes in the graph,
+        ``N_SIGNALS`` is the number of independent signals, and ``N_FEATURES``
+        is the number of features which compose a graph signal, i.e. the
+        dimensionality of a graph signal. For example, if you filter a signal
+        with a filter bank of 8 filters, you are extracting 8 features and
+        decomposing your signal into 8 parts. That is called analysis. You are
+        thus transforming your signal tensor from ``(G.N, 1, 1)`` to
+        ``(G.N, 1, 8)``. Now you may want to combine the features back into a
+        single signal. For that you again apply 8 filters, one filter per
+        feature, and sum the results. As such you are transforming your
+        ``(G.N, 1, 8)`` tensor signal back to ``(G.N, 1, 1)``. That is known
+        as synthesis. More generally, you may want to map one set of features
+        to another, though that is not implemented yet.
+
+        The method computes the transform coefficients of a signal :math:`s`,
+        where the atoms of the transform dictionary are generalized
+        translations of each graph spectral filter to each vertex on the
+        graph:
+
+        .. math:: c = D^* s,
+
+        where the columns of :math:`D` are :math:`g_{i,m} = T_i g_m` and
+        :math:`T_i` is a generalized translation operator applied to each
+        filter :math:`\hat{g}_m(\cdot)`. Each column of :math:`c` is the
+        response of the signal to one filter.
+
+        In other words, this function applies the analysis operator
+        :math:`D^*` (respectively the synthesis operator :math:`D`) associated
+        with the frame defined by the filter bank to the signals.
+
+        Parameters
+        ----------
+        s : ndarray
+            Graph signals, a tensor of shape ``(N_NODES, N_SIGNALS,
+            N_FEATURES)``, where ``N_NODES`` is the number of nodes in the
+            graph, ``N_SIGNALS`` the number of independent signals you want to
+            filter, and ``N_FEATURES`` is either 1 (analysis) or the number of
+            filters in the filter bank (synthesis).
+        method : {'exact', 'chebyshev'}
+            Whether to use the exact method (via the graph Fourier transform)
+            or the Chebyshev polynomial approximation. A Lanczos approximation
+            is coming.
+        order : int
+            Degree of the Chebyshev polynomials.
+
+        Returns
+        -------
+        s : ndarray
+            Graph signals, a tensor of shape ``(N_NODES, N_SIGNALS,
+            N_FEATURES)``, where ``N_NODES`` and ``N_SIGNALS`` are the number
+            of nodes and signals of the tensor that was passed in, and
+            ``N_FEATURES`` is either 1 (synthesis) or the number of filters in
+            the filter bank (analysis).
+
+        References
+        ----------
+        See :cite:`hammond2011wavelets` for details on filtering graph
+        signals.
+
+        Examples
+        --------
+
+        Create a bunch of smooth signals by low-pass filtering white noise:
+
+        >>> import matplotlib.pyplot as plt
+        >>> G = graphs.Ring(N=60)
+        >>> G.estimate_lmax()
+        >>> s = np.random.RandomState(42).uniform(size=(G.N, 10))
+        >>> taus = [1, 10, 100]
+        >>> s = filters.Heat(G, taus).filter(s)
+        >>> s.shape
+        (60, 10, 3)
+
+        Plot the 3 smoothed versions of the 10th signal:
+
+        >>> fig, ax = plt.subplots()
+        >>> G.set_coordinates('line1D')  # To visualize multiple signals in 1D.
+        >>> G.plot_signal(s[:, 9, :], ax=ax)
+        >>> legend = [r'$\tau={}$'.format(t) for t in taus]
+        >>> ax.legend(legend)  # doctest: +ELLIPSIS
+        <matplotlib.legend.Legend object at ...>
+
+        Low-pass filter a delta to create a localized smooth signal:
+
+        >>> G = graphs.Sensor(30, seed=42)
+        >>> G.compute_fourier_basis()  # Reproducible computation of lmax.
+        >>> s1 = np.zeros(G.N)
+        >>> s1[13] = 1
+        >>> s1 = filters.Heat(G, 3).filter(s1)
+        >>> s1.shape
+        (30, 1, 1)
+
+        Filter and reconstruct our signal:
+
+        >>> g = filters.MexicanHat(G, Nf=4)
+        >>> s2 = g.filter(s1)
+        >>> s2.shape
+        (30, 1, 4)
+        >>> s2 = g.filter(s2)
+        >>> s2.shape
+        (30, 1, 1)
+
+        Look how well we were able to reconstruct:
+
+        >>> fig, axes = plt.subplots(1, 2)
+        >>> G.plot_signal(s1, ax=axes[0])
+        >>> G.plot_signal(s2, ax=axes[1])
+        >>> print('{:.5f}'.format(np.linalg.norm(s1 - s2)))
+        0.29620
+
+        Perfect reconstruction with Itersine, a tight frame:
+
+        >>> g = filters.Itersine(G)
+        >>> s2 = g.filter(s1, method='exact')
+        >>> s2 = g.filter(s2, method='exact')
+        >>> np.linalg.norm(s1 - s2) < 1e-10
+        True
+
+        """
+        s = self.G.sanitize_signal(s)
+        N_NODES, N_SIGNALS, N_FEATURES_IN = s.shape
+
+        # TODO: generalize to 2D (m --> n) filter banks.
+        # Only 1 --> Nf (analysis) and Nf --> 1 (synthesis) for now.
+        if N_FEATURES_IN not in [1, self.Nf]:
+            raise ValueError('Last dimension (N_FEATURES) should either be '
+                             '1 or the number of filters (Nf), '
+                             'not {}.'.format(s.shape))
+        N_FEATURES_OUT = self.Nf if N_FEATURES_IN == 1 else 1
+
+        if method == 'exact':
+
+            axis = 1 if N_FEATURES_IN == 1 else 2
+            f = self.evaluate(self.G.e)
+            f = np.expand_dims(f.T, axis)
+            assert f.shape == (N_NODES, N_FEATURES_IN, N_FEATURES_OUT)
+
+            s = self.G.gft2(s)
+            s = np.matmul(s, f)
+            s = self.G.igft2(s)
+
+        elif method == 'chebyshev':
+
+            # TODO: update Chebyshev implementation (after 2D filter banks).
+            c = approximations.compute_cheby_coeff(self, m=order)
+
+            if N_FEATURES_IN == 1:  # Analysis.
+                s = s.squeeze(axis=2)
+                s = approximations.cheby_op(self.G, c, s)
+                s = s.reshape((N_NODES, N_FEATURES_OUT, N_SIGNALS), order='F')
+                s = s.swapaxes(1, 2)
+
+            elif N_FEATURES_IN == self.Nf:  # Synthesis.
+                s = s.swapaxes(1, 2)
+                s_in = s.reshape((N_NODES*N_FEATURES_IN, N_SIGNALS),
+                                 order='F')
+                s = np.zeros((N_NODES, N_SIGNALS))
+                tmpN = np.arange(N_NODES, dtype=int)
+                for i in range(N_FEATURES_IN):
+                    s += approximations.cheby_op(self.G, c[i],
+                                                 s_in[i * N_NODES + tmpN])
+                s = np.expand_dims(s, 2)
+
+        else:
+            raise ValueError('Unknown method {}.'.format(method))
+
+        return s
+
     def inverse(self, c):
         r"""
         Not implemented yet.
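As a side note (not part of the diff), the exact method amounts to evaluating each filter on the graph eigenvalues and applying it in the Fourier domain. A minimal sketch of that equivalence, using only the API shown above (``evaluate``, ``G.U``, ``G.e``):

import numpy as np
from pygsp import graphs, filters

G = graphs.Sensor(30, seed=42)
G.compute_fourier_basis()
g = filters.MexicanHat(G, Nf=4)

x = np.random.RandomState(0).normal(size=G.N)  # one signal, one feature
c = g.filter(x, method='exact')                # coefficients, shape (30, 1, 4)

# Same coefficients from the spectral definition: feature m is
# U diag(g_m(lambda)) U^T x (U is real for a combinatorial Laplacian).
evals = g.evaluate(G.e)                        # shape (Nf, N) = (4, 30)
c_ref = np.stack([G.U @ (evals[m] * (G.U.T @ x)) for m in range(g.Nf)], axis=1)
print(np.allclose(c[:, 0, :], c_ref))          # True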

pygsp/graphs/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -86,6 +86,7 @@
     Graph.set_coordinates
     Graph.subgraph
     Graph.extract_components
+    Graph.sanitize_signal

 Graph models
 ============

pygsp/graphs/fourier.py

Lines changed: 9 additions & 0 deletions

@@ -139,6 +139,11 @@ def gft(self, s):
         """
         return np.dot(np.conjugate(self.U.T), s)  # True Hermitian here.

+    def gft2(self, s):
+        s = self.sanitize_signal(s)
+        U = np.conjugate(self.U)  # True Hermitian. (Although U is often real.)
+        return np.tensordot(U, s, ([0], [0]))
+
     def igft(self, s_hat):
         r"""Compute the inverse graph Fourier transform.

@@ -171,6 +176,10 @@ def igft(self, s_hat):
         """
         return np.dot(self.U, s_hat)

+    def igft2(self, s_hat):
+        s_hat = self.sanitize_signal(s_hat)
+        return np.tensordot(self.U, s_hat, ([1], [0]))
+
     def translate(self, f, i):
         r"""Translate the signal *f* to the node *i*.
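For context (not part of the diff), ``gft2`` and ``igft2`` are the tensor counterparts of ``gft`` and ``igft``: they transform along the node axis only, so a round trip should recover the signal. A minimal sketch, assuming a computed Fourier basis:

import numpy as np
from pygsp import graphs

G = graphs.Logo()
G.compute_fourier_basis()

s = np.random.RandomState(1).normal(size=(G.N, 10, 3))  # 10 signals, 3 features
s_hat = G.gft2(s)             # transform along the node axis only
s_rec = G.igft2(s_hat)        # inverse transform
print(s_hat.shape)            # (1130, 10, 3)
print(np.allclose(s, s_rec))  # True, since U is orthonormal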

pygsp/graphs/graph.py

Lines changed: 51 additions & 0 deletions

@@ -725,6 +725,57 @@ def get_edge_list(self):

         return v_in, v_out, weights

+    def sanitize_signal(self, s):
+        r"""Standardize signal shape.
+
+        Add singleton dimensions at the end and check the resulting shape.
+
+        Parameters
+        ----------
+        s : ndarray
+            Signal tensor of shape ``(N_NODES)``, ``(N_NODES, N_SIGNALS)``, or
+            ``(N_NODES, N_SIGNALS, N_FEATURES)``.
+
+        Returns
+        -------
+        s : ndarray
+            Signal tensor of shape ``(N_NODES, N_SIGNALS, N_FEATURES)``.
+
+        Raises
+        ------
+        ValueError
+            If the passed signal tensor has more than 3 dimensions or if its
+            first dimension's size is not the number of nodes.
+
+        Examples
+        --------
+        >>> G = graphs.Logo()
+        >>> s = np.ones(G.N)  # One signal, one feature.
+        >>> G.sanitize_signal(s).shape
+        (1130, 1, 1)
+        >>> s = np.ones((G.N, 10))  # Ten signals of one feature.
+        >>> G.sanitize_signal(s).shape
+        (1130, 10, 1)
+        >>> s = np.ones((G.N, 10, 5))  # Ten signals of 5 features.
+        >>> G.sanitize_signal(s).shape
+        (1130, 10, 5)
+
+        """
+        if s.ndim == 1:
+            # Single signal, single feature.
+            s = np.expand_dims(s, axis=1)
+
+        if s.ndim == 2:
+            # Multiple signals, single feature.
+            s = np.expand_dims(s, axis=2)
+
+        if s.ndim != 3 or s.shape[0] != self.N:
+            raise ValueError('Signal must have shape N_NODES x N_SIGNALS x '
+                             'N_FEATURES, not {}. Last singleton dimensions '
+                             'may be omitted.'.format(s.shape))
+
+        return s
+
     def modulate(self, f, k):
         r"""Modulate the signal *f* to the frequency *k*.
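A small illustration (not part of the diff) of the error path described in the ``Raises`` section:

import numpy as np
from pygsp import graphs

G = graphs.Logo()
print(G.sanitize_signal(np.ones(G.N)).shape)    # (1130, 1, 1)
try:
    G.sanitize_signal(np.ones((G.N, 2, 3, 4)))  # rank-4 tensor: rejected
except ValueError as e:
    print(e)  # Signal must have shape N_NODES x N_SIGNALS x N_FEATURES, ...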

pygsp/tests/test_filters.py

Lines changed: 21 additions & 0 deletions

@@ -40,9 +40,30 @@ def _test_synthesis(self, f):
         self.assertRaises(NotImplementedError, f.synthesis, S,
                           method='lanczos')

+    def _test_filter(self, f, tight):
+        # Analysis.
+        s2 = f.filter(self._signal, method='exact')
+        s3 = f.filter(self._signal, method='chebyshev', order=100)
+
+        # Synthesis.
+        s4 = f.filter(s2, method='exact')
+        s5 = f.filter(s3, method='chebyshev', order=100)
+
+        if f.Nf < 100:
+            # TODO: does not pass for Gabor.
+            np.testing.assert_allclose(s2, s3, rtol=0.1, atol=0.01)
+            np.testing.assert_allclose(s4, s5, rtol=0.1, atol=0.01)
+
+        if tight:
+            A, _ = f.estimate_frame_bounds(use_eigenvalues=True)
+            np.testing.assert_allclose(s4.squeeze(), A * self._signal)
+            assert np.linalg.norm(s5.squeeze() - A * self._signal) < 0.1
+
     def _test_methods(self, f, tight):
         self.assertIs(f.G, self._G)

+        self._test_filter(f, tight)
+
         c_exact = f.analysis(self._signal, method='exact')
         c_cheby = f.analysis(self._signal, method='chebyshev')
         self.assertEqual(c_exact.shape, c_cheby.shape)
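The tight-frame branch of ``_test_filter`` relies on the fact that, for a tight frame with bound A, analysis followed by synthesis scales the signal by A. A small sketch of that identity (not part of the diff), reusing the Itersine filter bank from the docstring example:

import numpy as np
from pygsp import graphs, filters

G = graphs.Sensor(30, seed=42)
G.compute_fourier_basis()
g = filters.Itersine(G)

A, B = g.estimate_frame_bounds(use_eigenvalues=True)       # A == B: tight frame
x = np.random.RandomState(2).normal(size=G.N)
y = g.filter(g.filter(x, method='exact'), method='exact')  # analysis, then synthesis
print(np.allclose(y.squeeze(), A * x))                     # True: D D* = A I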

pygsp/tests/test_graphs.py

Lines changed: 17 additions & 0 deletions

@@ -119,6 +119,23 @@ def test_set_coordinates(self):
         G.set_coordinates('community2D')
         self.assertRaises(ValueError, G.set_coordinates, 'invalid')

+    def test_sanitize_signal(self):
+        s1 = np.arange(self._G.N)
+        s2 = np.reshape(s1, (self._G.N, 1))
+        s3 = np.reshape(s1, (self._G.N, 1, 1))
+        s4 = np.arange(self._G.N*10).reshape((self._G.N, 10))
+        s5 = np.reshape(s4, (self._G.N, 10, 1))
+        s1 = self._G.sanitize_signal(s1)
+        s2 = self._G.sanitize_signal(s2)
+        s3 = self._G.sanitize_signal(s3)
+        s4 = self._G.sanitize_signal(s4)
+        s5 = self._G.sanitize_signal(s5)
+        np.testing.assert_equal(s2, s1)
+        np.testing.assert_equal(s3, s1)
+        np.testing.assert_equal(s5, s4)
+        self.assertRaises(ValueError, self._G.sanitize_signal,
+                          np.ones((2, 2, 2, 2)))
+
     def test_nngraph(self):
         Xin = np.arange(90).reshape(30, 3)
         dist_types = ['euclidean', 'manhattan', 'max_dist', 'minkowski']
