|
@@ -22,30 +22,32 @@
 
 This script requires the `mir_eval` package to run, and the `tkinter` and `sounddevice` packages for the GUI option.
 """
+import os
 import sys
 import time
+from pathlib import Path
 
 import matplotlib
 import numpy as np
+from mir_eval.separation import bss_eval_sources
+from scipy.io import wavfile
+
 from auxiva_pca import auxiva_pca
 from five import five
-
 # Get the data if needed
 from get_data import get_data, samples_dir
 from ive import ogive
-from mir_eval.separation import bss_eval_sources
 from overiva import overiva
 from pyroomacoustics.bss import projection_back
-from routines import PlaySoundGUI, grid_layout, random_layout, semi_circle_layout
-from scipy.io import wavfile
-
-get_data()
+from routines import (PlaySoundGUI, grid_layout, random_layout,
+                      semi_circle_layout)
 from samples.generate_samples import sampling, wav_read_center
 
 # Once we are sure the data is there, import some methods
 # to select and read samples
 sys.path.append(samples_dir)
 
+SEP_SAMPLES_DIR = Path("separation_samples")
 
 # We concatenate a few samples to make them long enough
 if __name__ == "__main__":
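Note: `SEP_SAMPLES_DIR` is a `pathlib.Path`, so the output file names built later in this diff use the `/` join operator. A minimal sketch (the file name shown is hypothetical):

from pathlib import Path

SEP_SAMPLES_DIR = Path("separation_samples")
# "/" joins path components; the resulting Path is accepted by wavfile.write
out_file = SEP_SAMPLES_DIR / "sample_5_0_overiva_anechoic_4_mix.wav"
print(out_file)  # separation_samples/sample_5_0_overiva_anechoic_4_mix.wav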
|
|
@@ -96,6 +98,15 @@
         action="store_true",
         help="Creates a small GUI for easy playback of the sound samples",
     )
+    parser.add_argument(
+        "--no_plot", action="store_true", help="Do not plot anything",
+    )
+    parser.add_argument(
+        "--sinr", default=5, type=int, help="Signal-to-Interference-and-Noise Ratio",
+    )
+    parser.add_argument(
+        "--seed", default=7284023459, type=int, help="Seed for the simulation",
+    )
     parser.add_argument(
         "--save",
         action="store_true",
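Note: a standalone sketch of how the three new flags parse, using the same argument definitions (the real script registers many more options; the argv list here is made up):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no_plot", action="store_true", help="Do not plot anything")
parser.add_argument("--sinr", default=5, type=int, help="Signal-to-Interference-and-Noise Ratio")
parser.add_argument("--seed", default=7284023459, type=int, help="Seed for the simulation")

args = parser.parse_args(["--sinr", "10", "--seed", "42"])
print(args.no_plot, args.sinr, args.seed)  # False 10 42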
|
|
@@ -127,13 +138,13 @@
     use_real_R = False
 
     # fix the randomness for repeatability
-    np.random.seed(30)
+    np.random.seed(args.seed)
 
     # set the source powers, the first one is half
     source_std = np.ones(n_sources_target)
     source_std[0] /= np.sqrt(2.0)
 
-    SINR = 5  # signal-to-interference-and-noise ratio
+    SINR = args.sinr  # signal-to-interference-and-noise ratio
     SINR_diffuse_ratio = 0.9999  # ratio of uncorrelated to diffuse noise
 
     # STFT parameters
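Note: replacing the hard-coded `np.random.seed(30)` with `args.seed` keys the simulation to the command line, so two runs with identical flags draw identical rooms, noise, and samples. The underlying guarantee, in a tiny self-contained check:

import numpy as np

np.random.seed(7284023459)  # the new default seed
a = np.random.randn(3)
np.random.seed(7284023459)
b = np.random.randn(3)
assert np.array_equal(a, b)  # same seed -> same draws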
|
|
@@ -161,13 +172,17 @@
     )
     # interferer_locs = grid_layout([3., 5.5], n_sources - n_sources_target, offset=[6.5, 1., 1.7])
     interferer_locs = random_layout(
-        [3.0, 5.5, 1.5], n_sources - n_sources_target, offset=[6.5, 1.0, 0.5], seed=1234
+        [3.0, 5.5, 1.5], n_sources - n_sources_target, offset=[6.5, 1.0, 0.5],
     )
     source_locs = np.concatenate((target_locs, interferer_locs), axis=1)
 
     # Prepare the signals
     wav_files = sampling(
-        1, n_sources, f"{samples_dir}/metadata.json", gender_balanced=True, seed=2222
+        1,
+        n_sources,
+        f"{samples_dir}/metadata.json",
+        gender_balanced=True,
+        seed=args.seed,
     )[0]
     signals = wav_read_center(wav_files, seed=123)
 
|
@@ -237,7 +252,7 @@ def callback_mix(
 
     # reference is taken at microphone 0
     ref = np.vstack(
-        [separate_recordings[0, :1], np.sum(separate_recordings[1:, :1], axis=0)]
+        [separate_recordings[0, :1], np.sum(separate_recordings[1:, :1], axis=0)]
     )
 
     SDR, SIR, eval_time = [], [], []
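Note: the reference stacks the target source and the summed interferers, both as captured at microphone 0. A shape-level sketch, assuming `separate_recordings` is laid out as (n_sources, n_mics, n_samples) (inferred from the indexing, not stated in the diff):

import numpy as np

separate_recordings = np.random.randn(3, 2, 16000)  # hypothetical: 3 sources, 2 mics

target = separate_recordings[0, :1]                   # (1, n_samples): target at mic 0
interf = np.sum(separate_recordings[1:, :1], axis=0)  # (1, n_samples): interferers summed at mic 0
ref = np.vstack([target, interf])                     # (2, n_samples)
assert ref.shape == (2, 16000)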
|
@@ -272,7 +287,9 @@ def convergence_callback(Y, **kwargs):
         eval_time.append(t_exit - t_enter)
 
     if args.algo.startswith("ogive"):
-        callback_checkpoints = list(range(1, ogive_iter + ogive_iter // n_iter, ogive_iter // n_iter))
+        callback_checkpoints = list(
+            range(1, ogive_iter + ogive_iter // n_iter, ogive_iter // n_iter)
+        )
     else:
         callback_checkpoints = list(range(1, n_iter + 1))
     if args.no_cb:
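Note: OGIVE runs many more (cheaper) iterations than the batch algorithms, so its metrics are sampled every `ogive_iter // n_iter` iterations instead of at every iteration. With representative values (assumed here, not taken from the diff):

ogive_iter, n_iter = 4000, 20  # hypothetical settings
step = ogive_iter // n_iter    # 200
callback_checkpoints = list(range(1, ogive_iter + step, step))
print(callback_checkpoints[:3], callback_checkpoints[-1])  # [1, 201, 401] 4001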
|
@@ -386,44 +403,55 @@ def convergence_callback(Y, **kwargs):
     print(f"SDR: In: {SDR[0, 0]:6.2f} dB -> Out: {SDR[-1, 0]:6.2f} dB")
     print(f"SIR: In: {SIR[0, 0]:6.2f} dB -> Out: {SIR[-1, 0]:6.2f} dB")
 
-    import matplotlib.pyplot as plt
+    if not args.no_plot:
+        import matplotlib.pyplot as plt
 
-    plt.figure()
+        plt.figure()
 
-    plt.subplot(2, 1, 1)
-    plt.specgram(mics_signals[0], NFFT=1024, Fs=room.fs)
-    plt.title("Microphone 0 input")
+        plt.subplot(2, 1, 1)
+        plt.specgram(mics_signals[0], NFFT=1024, Fs=room.fs)
+        plt.title("Microphone 0 input")
 
-    plt.subplot(2, 1, 2)
-    plt.specgram(y_hat[:, 0], NFFT=1024, Fs=room.fs)
-    plt.title("Extracted source")
+        plt.subplot(2, 1, 2)
+        plt.specgram(y_hat[:, 0], NFFT=1024, Fs=room.fs)
+        plt.title("Extracted source")
 
-    plt.tight_layout(pad=0.5)
+        plt.tight_layout(pad=0.5)
 
-    plt.figure()
-    plt.plot([0] + callback_checkpoints, SDR[:, 0], label="SDR", marker="*")
-    plt.plot([0] + callback_checkpoints, SIR[:, 0], label="SIR", marker="o")
-    plt.legend()
-    plt.tight_layout(pad=0.5)
+        if not args.no_cb:
+            plt.figure()
+            plt.plot([0] + callback_checkpoints, SDR[:, 0], label="SDR", marker="*")
+            plt.plot([0] + callback_checkpoints, SIR[:, 0], label="SIR", marker="o")
+            plt.legend()
+            plt.tight_layout(pad=0.5)
 
-    if not args.gui:
-        plt.show()
-    else:
-        plt.show(block=False)
+        if not args.gui:
+            plt.show()
+        else:
+            plt.show(block=False)
 
     if args.save:
-        wavfile.write(
-            "bss_iva_mix.wav",
-            room.fs,
-            pra.normalize(mics_signals[0, :], bits=16).astype(np.int16),
+
+        scale = (0.95 * (2 ** 15)) / np.max(
+            [np.abs(s).max() for s in [mics_signals[0, :], ref, y_hat]]
         )
-        for i, sig in enumerate(y_hat):
+
+        if not SEP_SAMPLES_DIR.exists():
+            os.mkdir(SEP_SAMPLES_DIR)
+
+        def wavsave(type_, fs, audio):
             wavfile.write(
-                "bss_iva_source{}.wav".format(i + 1),
-                room.fs,
-                pra.normalize(sig, bits=16).astype(np.int16),
+                SEP_SAMPLES_DIR
+                / f"sample_{SINR}_{args.seed}_{args.algo}_{args.dist}_{n_mics}_{type_}.wav",
+                fs,
+                (scale * audio).astype(np.int16),
             )
 
+        wavsave("mix", room.fs, mics_signals[0])
+        wavsave("ref", room.fs, ref[0])
+        for i, sig in enumerate(y_hat.T):
+            wavsave(f"source{i}", room.fs, sig)
+
     if args.gui:
 
         from tkinter import Tk
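Note: the old code normalized each output file independently with `pra.normalize`, which discards the level differences between mixture, reference, and separated sources. The new single `scale` factor preserves relative levels and keeps about 0.45 dB of headroom below int16 full scale (0.95 x 2^15). A self-contained sketch of the idea, on made-up signals:

import numpy as np

signals = [np.random.randn(8000), 0.1 * np.random.randn(8000)]  # hypothetical audio
scale = (0.95 * (2 ** 15)) / np.max([np.abs(s).max() for s in signals])
ints = [(scale * s).astype(np.int16) for s in signals]
# one shared gain: no clipping, and the 20 dB level difference is preserved
assert max(np.abs(x).max() for x in ints) <= int(0.95 * 2 ** 15)

Incidentally, the existence check plus `os.mkdir` could be written with pathlib alone as `SEP_SAMPLES_DIR.mkdir(exist_ok=True)`.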
|
|