|
1 | 1 | """Denoising source separation.""" |
| 2 | +# Authors: Nicolas Barascud <nicolas.barascud@gmail.com> |
| 3 | +# Maciej Szul <maciej.szul@isc.cnrs.fr> |
| 4 | + |
2 | 5 | import numpy as np |
3 | 6 | from scipy import linalg |
| 7 | +from scipy.signal import welch |
4 | 8 |
|
5 | 9 | from .tspca import tsr |
6 | 10 | from .utils import (demean, gaussfilt, mean_over_trials, pca, smooth, |
@@ -230,3 +234,130 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, show=False): |
230 | 234 | p = wpwr(X - y)[0] / wpwr(X)[0] |
231 | 235 | print('Power of components removed by DSS: {:.2f}'.format(p)) |
232 | 236 | return y, artifact |
| 237 | + |
| 238 | + |
def _interp_nans(values):
    """Fill the NaN entries of a 1D array in place by linear interpolation.

    Interpolation is done over sample indices, using the non-NaN entries as
    support points. The (modified) input array is returned.
    """
    missing = np.isnan(values)
    values[missing] = np.interp(np.flatnonzero(missing),
                                np.flatnonzero(~missing),
                                values[~missing])
    return values


def _mean_psd(psd):
    """Average a PSD over every axis except the first (frequency) one.

    Only 2D and 3D PSD arrays are supported; anything else raises ValueError
    instead of leaving the result undefined.
    """
    if psd.ndim == 3:
        return np.mean(psd, axis=(1, 2))
    if psd.ndim == 2:
        return np.mean(psd, axis=1)
    raise ValueError(
        "data must be 2D or 3D with samples on the first axis, "
        "got a {}D power spectrum".format(psd.ndim))


def _save_dss_iter_figure(fname, psd, band_mask, freq_used, baseline,
                          band_psd, fit_line, residuals, fline, score,
                          scores):
    """Save a 2x2 diagnostic figure for one iteration of dss_line_iter."""
    # Imported lazily so matplotlib is only required when plotting.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(2, 2, figsize=(12, 6), facecolor="white")

    # Top-left: PSD inside the fit window, averaged over the last axis when
    # the PSD is 3D, plotted as-is when it is 2D.
    mean_sens = np.mean(psd, axis=2) if psd.ndim == 3 else psd
    ax.flat[0].plot(freq_used, mean_sens[band_mask])
    ax.flat[0].set_title("Mean PSD across trials")

    # Top-right: interpolated baseline (gray), current mean PSD (blue) and
    # the polynomial fit of the baseline (red).
    ax.flat[1].plot(freq_used, baseline, c="gray")
    ax.flat[1].plot(freq_used, band_psd, c="blue")
    ax.flat[1].plot(freq_used, fit_line, c="red")
    ax.flat[1].set_title("Mean PSD across trials and sensors")

    # Bottom-left: residuals, with a marker at the last bin <= fline.
    tf_ix = np.where(freq_used <= fline)[0][-1]
    ax.flat[2].plot(residuals, freq_used)
    color = "red" if score <= 0 else "green"
    ax.flat[2].scatter(residuals[tf_ix], freq_used[tf_ix], c=color)
    ax.flat[2].set_title("Residuals")

    # Bottom-right: score history, one point per iteration so far.
    ax.flat[3].plot(np.arange(len(scores)), scores, marker='o')
    ax.flat[3].set_title("Iterations")

    fig.set_tight_layout(True)
    fig.savefig(fname)
    # Close only the figure created here; closing "all" would also destroy
    # figures the caller has open.
    plt.close(fig)


def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
                  nfft=512, show=False, prefix="dss_iter", n_iter_max=100):
    """Remove power line artifact iteratively.

    This method applies dss_line() until the artifact has been smoothed out
    from the spectrum. A cubic polynomial is fitted to the mean PSD of the
    input around `fline` (with the peak itself replaced by interpolation);
    after each dss_line() pass, the mean difference between the current PSD
    and that fit inside the peak window is used as a score, and iteration
    stops as soon as the score is <= 0.

    Parameters
    ----------
    data : array, shape=(n_samples, n_chans, n_trials)
        Input data. The PSD code here also handles 2D input (samples on the
        first axis); whether dss_line() accepts it is not checked here.
    fline : float
        Line frequency.
    sfreq : float
        Sampling frequency.
    win_sz : float
        Half of the width of the window around the target frequency used to fit
        the polynomial (default=10).
    spot_sz : float
        Half of the width of the window around the target frequency used to
        remove the peak and interpolate (default=2.5).
    nfft : int
        FFT size for the internal PSD calculation (default=512).
    show: bool
        Produce a visual output of each iteration (default=False).
    prefix : str
        Path and first part of the visualisation output file
        "{prefix}_{iteration number}.png" (default="dss_iter").
    n_iter_max : int
        Maximum number of iterations (default=100).

    Returns
    -------
    data : array, shape=(n_samples, n_chans, n_trials)
        Denoised data.
    iterations : int
        Zero-based index of the iteration at which the stopping criterion
        was met (i.e. the number of dss_line() passes minus one).

    Raises
    ------
    ValueError
        If the PSD of `data` is not 2D or 3D, if no PSD bin falls inside the
        fit window or the peak window, or if the peak window covers the whole
        fit window (nothing left to interpolate from).
    RuntimeError
        If the stopping criterion is not met within `n_iter_max` iterations.
    """
    freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0)

    # Bins used for the polynomial fit (wide window around fline).
    band_mask = np.logical_and(freq >= fline - win_sz, freq <= fline + win_sz)
    if not band_mask.any():
        raise ValueError(
            "No PSD bin falls within fline +/- win_sz; check fline, sfreq, "
            "win_sz and nfft.")
    freq_used = freq[band_mask]

    # Bins (relative to the fit window) that hold the line peak itself.
    spot_mask = np.logical_and(freq_used >= fline - spot_sz,
                               freq_used <= fline + spot_sz)
    if not spot_mask.any():
        # Without this check the score below would be NaN, `nan <= 0` would
        # never hold, and dss_line() would be applied n_iter_max times.
        raise ValueError(
            "No PSD bin falls within fline +/- spot_sz; increase spot_sz "
            "or nfft.")
    if spot_mask.all():
        raise ValueError(
            "spot_sz covers the whole fit window; win_sz must be larger "
            "than spot_sz.")

    # Baseline spectrum: mean PSD with the peak cut out and interpolated,
    # then smoothed by a 3rd-order polynomial fit.
    baseline = _mean_psd(psd)[band_mask].copy()
    baseline[spot_mask] = np.nan
    baseline = _interp_nans(baseline)
    clean_fit_line = np.poly1d(np.polyfit(freq_used, baseline, 3))(freq_used)

    aggr_resid = []
    iterations = 0
    while iterations < n_iter_max:
        data, _ = dss_line(data, fline, sfreq, nfft=nfft, nremove=1)
        _, psd = welch(data, fs=sfreq, nfft=nfft, axis=0)
        mean_psd = _mean_psd(psd)[band_mask]

        # Score: how far the peak region still sits above the clean fit.
        residuals = mean_psd - clean_fit_line
        mean_score = np.mean(residuals[spot_mask])
        aggr_resid.append(mean_score)

        print("Iteration {} score: {}".format(iterations, mean_score))

        if show:
            _save_dss_iter_figure(f"{prefix}_{iterations:03}.png", psd,
                                  band_mask, freq_used, baseline, mean_psd,
                                  clean_fit_line, residuals, fline,
                                  mean_score, aggr_resid)

        if mean_score <= 0:
            break

        iterations += 1

    # The counter only reaches n_iter_max when the loop ran out without
    # hitting the break above.
    if iterations == n_iter_max:
        raise RuntimeError('Could not converge. Consider increasing the '
                           'maximum number of iterations')

    return data, iterations
0 commit comments