Skip to content

Commit ed28670

Browse files
maciekszul and nbara authored
[NEW] Add dss_line_iter() (#52)
* ADD: added dss_line_iter * add test embryo * ADD: max_iteration argument added to dss_line_iter * ADD: dss_line_iter test data * add example_dss_line skeleton - bump version - fixed doc - add `pytest --noplots` option * finish example Co-authored-by: nbara <10333715+nbara@users.noreply.github.com>
1 parent 3105ad6 commit ed28670

File tree

10 files changed

+390
-75
lines changed

10 files changed

+390
-75
lines changed

doc/modules/meegkit.dss.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
dss0
1717
dss1
1818
dss_line
19+
dss_line_iter
1920

2021

2122

examples/example_dss_line.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""
2+
Remove line noise with ZapLine
3+
==============================
4+
5+
Find a spatial filter to get rid of line noise [1]_.
6+
7+
Uses meegkit.dss_line().
8+
9+
References
10+
----------
11+
.. [1] de Cheveigné, A. (2019). ZapLine: A simple and effective method to
12+
remove power line artifacts [Preprint]. https://doi.org/10.1101/782029
13+
14+
"""
15+
# Authors: Maciej Szul <maciej.szul@isc.cnrs.fr>
16+
# Nicolas Barascud <nicolas.barascud@gmail.com>
17+
import os
18+
19+
import matplotlib.pyplot as plt
20+
import numpy as np
21+
from meegkit import dss
22+
from meegkit.utils import create_line_data, unfold
23+
from scipy import signal
24+
25+
###############################################################################
26+
# Line noise removal
27+
# =============================================================================
28+
29+
###############################################################################
30+
# Remove line noise with dss_line()
31+
# -----------------------------------------------------------------------------
32+
# We first generate some noisy data to work with
33+
sfreq = 250
34+
fline = 50
35+
nsamples = 10000
36+
nchans = 10
37+
data = create_line_data(n_samples=3 * nsamples, n_chans=nchans,
38+
n_trials=1, fline=fline / sfreq, SNR=2)[0]
39+
data = data[..., 0] # only take first trial
40+
41+
# Apply dss_line (ZapLine)
42+
out, _ = dss.dss_line(data, fline, sfreq, nkeep=1)
43+
44+
###############################################################################
45+
# Plot before/after
46+
f, ax = plt.subplots(1, 2, sharey=True)
47+
f, Pxx = signal.welch(data, sfreq, nperseg=500, axis=0, return_onesided=True)
48+
ax[0].semilogy(f, Pxx)
49+
f, Pxx = signal.welch(out, sfreq, nperseg=500, axis=0, return_onesided=True)
50+
ax[1].semilogy(f, Pxx)
51+
ax[0].set_xlabel('frequency [Hz]')
52+
ax[1].set_xlabel('frequency [Hz]')
53+
ax[0].set_ylabel('PSD [V**2/Hz]')
54+
ax[0].set_title('before')
55+
ax[1].set_title('after')
56+
plt.show()
57+
58+
59+
###############################################################################
60+
# Remove line noise with dss_line_iter()
61+
# -----------------------------------------------------------------------------
62+
# We first load some noisy data to work with
63+
data = np.load(os.path.join('..', 'tests', 'data', 'dss_line_data.npy'))
64+
fline = 50
65+
sfreq = 200
66+
print(data.shape) # n_samples, n_chans, n_trials
67+
68+
# Apply dss_line(), removing only one component
69+
out1, _ = dss.dss_line(data, fline, sfreq, nremove=1, nfft=400)
70+
71+
###############################################################################
72+
# Now try dss_line_iter(). This applies dss_line() repeatedly until the
73+
# artifact is gone
74+
out2, iterations = dss.dss_line_iter(data, fline, sfreq, nfft=400)
75+
print(f'Removed {iterations} components')
76+
77+
###############################################################################
78+
# Plot results with dss_line() vs. dss_line_iter()
79+
f, ax = plt.subplots(1, 2, sharey=True)
80+
f, Pxx = signal.welch(unfold(out1), sfreq, nperseg=200, axis=0,
81+
return_onesided=True)
82+
ax[0].semilogy(f, Pxx, lw=.5)
83+
f, Pxx = signal.welch(unfold(out2), sfreq, nperseg=200, axis=0,
84+
return_onesided=True)
85+
ax[1].semilogy(f, Pxx, lw=.5)
86+
ax[0].set_xlabel('frequency [Hz]')
87+
ax[1].set_xlabel('frequency [Hz]')
88+
ax[0].set_ylabel('PSD [V**2/Hz]')
89+
ax[0].set_title('dss_line')
90+
ax[1].set_title('dss_line_iter')
91+
plt.tight_layout()
92+
plt.show()

meegkit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""M/EEG denoising utilities in python."""
2-
__version__ = '0.1.1'
2+
__version__ = '0.1.2'
33

44
from . import asr, cca, detrend, dss, sns, star, ress, trca, tspca, utils
55

meegkit/dss.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""Denoising source separation."""
2+
# Authors: Nicolas Barascud <nicolas.barascud@gmail.com>
3+
# Maciej Szul <maciej.szul@isc.cnrs.fr>
4+
25
import numpy as np
36
from scipy import linalg
7+
from scipy.signal import welch
48

59
from .tspca import tsr
610
from .utils import (demean, gaussfilt, mean_over_trials, pca, smooth,
@@ -230,3 +234,130 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, show=False):
230234
p = wpwr(X - y)[0] / wpwr(X)[0]
231235
print('Power of components removed by DSS: {:.2f}'.format(p))
232236
return y, artifact
237+
238+
239+
def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
                  nfft=512, show=False, prefix="dss_iter", n_iter_max=100):
    """Remove power line artifact iteratively.

    This method applies dss_line() until the artifact has been smoothed out
    from the spectrum.

    Parameters
    ----------
    data : data, shape=(n_samples, n_chans, n_trials)
        Input data.
    fline : float
        Line frequency.
    sfreq : float
        Sampling frequency.
    win_sz : float
        Half of the width of the window around the target frequency used to fit
        the polynomial (default=10).
    spot_sz : float
        Half of the width of the window around the target frequency used to
        remove the peak and interpolate (default=2.5).
    nfft : int
        FFT size for the internal PSD calculation (default=512).
    show: bool
        Produce a visual output of each iteration (default=False).
    prefix : str
        Path and first part of the visualisation output file
        "{prefix}_{iteration number}.png" (default="dss_iter").
    n_iter_max : int
        Maximum number of iterations (default=100).

    Returns
    -------
    data : array, shape=(n_samples, n_chans, n_trials)
        Denoised data.
    iterations : int
        Number of iterations.

    Raises
    ------
    ValueError
        If the PSD of `data` is neither 2D nor 3D.
    RuntimeError
        If `n_iter_max` iterations complete without convergence.
    """

    def nan_basic_interp(array):
        """Nan interpolation."""
        nans, ix = np.isnan(array), lambda x: x.nonzero()[0]
        array[nans] = np.interp(ix(nans), ix(~nans), array[~nans])
        return array

    def avg_psd(psd):
        """Average PSD over channels (and trials, when present)."""
        if psd.ndim == 3:
            return np.mean(psd, axis=(1, 2))
        elif psd.ndim == 2:
            return np.mean(psd, axis=1)
        # Previously an unsupported ndim left `mean_psd` unbound and caused a
        # confusing NameError downstream; fail fast with a clear message.
        raise ValueError(
            'Expected 2D or 3D PSD, got ndim={}'.format(psd.ndim))

    freq_rn = [fline - win_sz, fline + win_sz]
    freq_sp = [fline - spot_sz, fline + spot_sz]
    freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0)

    # Indices of the fitting window, and of the artifact "spot" within it
    freq_rn_ix = np.logical_and(freq >= freq_rn[0], freq <= freq_rn[1])
    freq_used = freq[freq_rn_ix]
    freq_sp_ix = np.logical_and(freq_used >= freq_sp[0],
                                freq_used <= freq_sp[1])

    mean_psd = avg_psd(psd)[freq_rn_ix]

    # Build the "clean" reference spectrum: blank the artifact spot,
    # interpolate over it, then fit a 3rd-order polynomial.
    mean_psd_wospot = mean_psd.copy()
    mean_psd_wospot[freq_sp_ix] = np.nan
    mean_psd_tf = nan_basic_interp(mean_psd_wospot)
    pf = np.polyfit(freq_used, mean_psd_tf, 3)
    p = np.poly1d(pf)
    clean_fit_line = p(freq_used)

    aggr_resid = []
    iterations = 0
    while iterations < n_iter_max:
        # Remove one component per pass; stop once the residual power at the
        # line frequency drops to (or below) the clean fit.
        data, _ = dss_line(data, fline, sfreq, nfft=nfft, nremove=1)
        freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0)
        mean_psd = avg_psd(psd)[freq_rn_ix]

        residuals = mean_psd - clean_fit_line
        mean_score = np.mean(residuals[freq_sp_ix])
        aggr_resid.append(mean_score)

        print("Iteration {} score: {}".format(iterations, mean_score))

        if show:
            import matplotlib.pyplot as plt
            f, ax = plt.subplots(2, 2, figsize=(12, 6), facecolor="white")

            if psd.ndim == 3:
                mean_sens = np.mean(psd, axis=2)
            elif psd.ndim == 2:
                mean_sens = psd

            y = mean_sens[freq_rn_ix]
            ax.flat[0].plot(freq_used, y)
            ax.flat[0].set_title("Mean PSD across trials")

            ax.flat[1].plot(freq_used, mean_psd_tf, c="gray")
            ax.flat[1].plot(freq_used, mean_psd, c="blue")
            ax.flat[1].plot(freq_used, clean_fit_line, c="red")
            ax.flat[1].set_title("Mean PSD across trials and sensors")

            tf_ix = np.where(freq_used <= fline)[0][-1]
            ax.flat[2].plot(residuals, freq_used)
            color = "green"
            if mean_score <= 0:
                color = "red"
            ax.flat[2].scatter(residuals[tf_ix], freq_used[tf_ix], c=color)
            ax.flat[2].set_title("Residuals")

            ax.flat[3].plot(np.arange(iterations + 1), aggr_resid, marker='o')
            ax.flat[3].set_title("Iterations")

            f.set_tight_layout(True)
            plt.savefig(f"{prefix}_{iterations:03}.png")
            plt.close("all")

        if mean_score <= 0:
            break

        iterations += 1

    if iterations == n_iter_max:
        raise RuntimeError('Could not converge. Consider increasing the '
                           'maximum number of iterations')

    return data, iterations

meegkit/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@
1313
spectral_envelope, teager_kaiser)
1414
from .stats import (bootstrap_ci, bootstrap_snr, cronbach, rms, robust_mean,
1515
rolling_corr, snr_spectrum)
16+
from .testing import create_line_data

meegkit/utils/testing.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Synthetic test data."""
2+
import numpy as np
3+
from meegkit.utils import fold, rms, unfold
4+
5+
import matplotlib.pyplot as plt
6+
7+
8+
def create_line_data(n_samples=100 * 3, n_chans=30, n_trials=100, noise_dim=20,
                     n_bad_chans=1, SNR=.1, fline=1, t0=None, show=False):
    """Create synthetic data.

    Parameters
    ----------
    n_samples : int
        Number of samples (default=100*3).
    n_chans : int
        Number of channels (default=30).
    n_trials : int
        Number of trials (default=100).
    noise_dim : int
        Dimensionality of noise (default=20).
    n_bad_chans : int
        Number of bad channels (default=1).
    SNR : float
        Signal-to-noise ratio of the artifact relative to the noise
        (default=.1).
    fline : float
        Normalized frequency of artifact (freq/samplerate), (default=1).
    t0 : int
        Onset sample of artifact. If None, defaults to ``n_samples // 3``.
    show : bool
        If True, plot the source, noise, and mixture (default=False).

    Returns
    -------
    data : ndarray, shape=(n_samples, n_chans, n_trials)
        Mixture of noise and artifact signal.
    source : ndarray, shape=(n_samples, 1)
        Artifact source time course.
    """
    rng = np.random.RandomState(2022)  # fixed seed -> reproducible test data

    if t0 is None:
        t0 = n_samples // 3
    t1 = n_samples - 2 * t0  # artifact duration

    # create source signal: silence -> sinusoidal artifact -> silence
    source = np.hstack((
        np.zeros(t0),
        np.sin(2 * np.pi * fline * np.arange(t1)),
        np.zeros(t0)))
    source = source[:, None]

    # mix source in channels
    s = source * rng.randn(1, n_chans)
    s = s[:, :, np.newaxis]
    s = np.tile(s, (1, 1, n_trials))  # create trials

    # set first `n_bad_chans` to zero
    s[:, :n_bad_chans] = 0.

    # noise
    noise = np.dot(
        unfold(rng.randn(n_samples, noise_dim, n_trials)),
        rng.randn(noise_dim, n_chans))
    noise = fold(noise, n_samples)

    # mix signal and noise, each normalised to unit RMS first
    data = noise / rms(noise.flatten()) + SNR * s / rms(s.flatten())

    if show:
        f, ax = plt.subplots(3)
        ax[0].plot(source.mean(-1), label='source')
        ax[1].plot(noise[:, 1].mean(-1), label='noise (avg over trials)')
        ax[2].plot(data[:, 1].mean(-1), label='mixture (avg over trials)')
        ax[0].legend()
        ax[1].legend()
        ax[2].legend()
        plt.show()

    return data, source

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@
99
author='N Barascud',
1010
author_email='nicolas.barascud@gmail.com',
1111
license='UNLICENSED',
12-
version='0.1.1',
12+
version='0.1.2',
1313
packages=find_packages(exclude=['doc', 'tests']),
1414
zip_safe=False)

tests/conftest.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,33 @@
11
import pytest
2-
import numpy as np
3-
import random as rand
42

3+
import matplotlib.pyplot as plt
54

6-
@pytest.fixture
7-
def random():
8-
rand.seed(9)
9-
np.random.seed(9)
5+
6+
def pytest_addoption(parser):
    """Register extra command-line flags for the test suite."""
    for flag, descr in (("--runslow", "run slow tests"),
                        ("--noplots", "halt on plots")):
        parser.addoption(flag, action="store_true", default=False, help=descr)
20+
21+
22+
def pytest_collection_modifyitems(config, items):
23+
"""Do not skip slow test if option provided."""
24+
if config.getoption("--noplots"):
25+
plt.switch_backend('agg')
26+
27+
if config.getoption("--runslow"):
28+
# --runslow given in cli: do not skip slow tests
29+
return
30+
skip_slow = pytest.mark.skip(reason="need --runslow option to run")
31+
for item in items:
32+
if "slow" in item.keywords:
33+
item.add_marker(skip_slow)

tests/data/dss_line_data.npy

52.1 MB
Binary file not shown.

0 commit comments

Comments
 (0)