Commit 66ddc55

Adds script to generate speech samples
1 parent 44cfacd commit 66ddc55

4 files changed: +193 -40 lines changed

README.md

+8
@@ -6,6 +6,8 @@ of the paper
 
 > R. Scheibler and N. Ono, [*"Fast Independent Vector Extraction by Iterative SINR Maximization,"*](http://arxiv.org/abs/1910.10654) 2019.
 
+Speech samples are available [here](http://robinscheibler.org/icassp2020).
+
 Abstract
 --------
 
@@ -47,6 +49,11 @@ An `environment.yml` file is provided to install the required dependencies.
     # switch to new environment
     conda activate 2019_scheibler_five
 
+Listen
+------
+
+Samples are available [here](http://robinscheibler.org/icassp2020).
+
 Test FIVE
 ---------
 
@@ -226,6 +233,7 @@ Summary of the Files in this Repo
     paper_sim_config.json # simulation configuration file
     paper_plot_figures.py # plots the figures from the paper
     paper_plot_everything.py # plots all the output of paper_simulation.py
+    make_separation_samples.py # create sample separated signals
 
     data # directory containing simulation results
     rrtools # tools for parallel simulation

example.py

+66 -38
@@ -22,30 +22,32 @@
 
 This script requires the `mir_eval` to run, and `tkinter` and `sounddevice` packages for the GUI option.
 """
+import os
 import sys
 import time
+from pathlib import Path
 
 import matplotlib
 import numpy as np
+from mir_eval.separation import bss_eval_sources
+from scipy.io import wavfile
+
 from auxiva_pca import auxiva_pca
 from five import five
-
 # Get the data if needed
 from get_data import get_data, samples_dir
 from ive import ogive
-from mir_eval.separation import bss_eval_sources
 from overiva import overiva
 from pyroomacoustics.bss import projection_back
-from routines import PlaySoundGUI, grid_layout, random_layout, semi_circle_layout
-from scipy.io import wavfile
-
-get_data()
+from routines import (PlaySoundGUI, grid_layout, random_layout,
+                      semi_circle_layout)
 from samples.generate_samples import sampling, wav_read_center
 
 # Once we are sure the data is there, import some methods
 # to select and read samples
 sys.path.append(samples_dir)
 
+SEP_SAMPLES_DIR = Path("separation_samples")
 
 # We concatenate a few samples to make them long enough
 if __name__ == "__main__":
@@ -96,6 +98,15 @@
         action="store_true",
         help="Creates a small GUI for easy playback of the sound samples",
     )
+    parser.add_argument(
+        "--no_plot", action="store_true", help="Do not plot anything",
+    )
+    parser.add_argument(
+        "--sinr", default=5, type=int, help="Signal-to-Interference-and-Noise Ratio",
+    )
+    parser.add_argument(
+        "--seed", default=7284023459, type=int, help="Seed for the simulation",
+    )
     parser.add_argument(
         "--save",
         action="store_true",
@@ -127,13 +138,13 @@
     use_real_R = False
 
     # fix the randomness for repeatability
-    np.random.seed(30)
+    np.random.seed(args.seed)
 
     # set the source powers, the first one is half
     source_std = np.ones(n_sources_target)
     source_std[0] /= np.sqrt(2.0)
 
-    SINR = 5  # signal-to-interference-and-noise ratio
+    SINR = args.sinr  # signal-to-interference-and-noise ratio
     SINR_diffuse_ratio = 0.9999  # ratio of uncorrelated to diffuse noise
 
     # STFT parameters
@@ -161,13 +172,17 @@
     )
     # interferer_locs = grid_layout([3., 5.5], n_sources - n_sources_target, offset=[6.5, 1., 1.7])
     interferer_locs = random_layout(
-        [3.0, 5.5, 1.5], n_sources - n_sources_target, offset=[6.5, 1.0, 0.5], seed=1234
+        [3.0, 5.5, 1.5], n_sources - n_sources_target, offset=[6.5, 1.0, 0.5],
     )
     source_locs = np.concatenate((target_locs, interferer_locs), axis=1)
 
     # Prepare the signals
     wav_files = sampling(
-        1, n_sources, f"{samples_dir}/metadata.json", gender_balanced=True, seed=2222
+        1,
+        n_sources,
+        f"{samples_dir}/metadata.json",
+        gender_balanced=True,
+        seed=args.seed,
     )[0]
     signals = wav_read_center(wav_files, seed=123)
 
@@ -237,7 +252,7 @@ def callback_mix(
 
     # reference is taken at microphone 0
     ref = np.vstack(
-        [separate_recordings[0, :1], np.sum(separate_recordings[1:, :1], axis=0)]
+        [separate_recordings[0, :1], np.sum(separate_recordings[1:, :1], axis=0)]
     )
 
     SDR, SIR, eval_time = [], [], []
@@ -272,7 +287,9 @@ def convergence_callback(Y, **kwargs):
         eval_time.append(t_exit - t_enter)
 
     if args.algo.startswith("ogive"):
-        callback_checkpoints = list(range(1, ogive_iter + ogive_iter // n_iter, ogive_iter // n_iter))
+        callback_checkpoints = list(
+            range(1, ogive_iter + ogive_iter // n_iter, ogive_iter // n_iter)
+        )
     else:
         callback_checkpoints = list(range(1, n_iter + 1))
     if args.no_cb:
@@ -386,44 +403,55 @@ def convergence_callback(Y, **kwargs):
     print(f"SDR: In: {SDR[0, 0]:6.2f} dB -> Out: {SDR[-1, 0]:6.2f} dB")
     print(f"SIR: In: {SIR[0, 0]:6.2f} dB -> Out: {SIR[-1, 0]:6.2f} dB")
 
-    import matplotlib.pyplot as plt
+    if not args.no_plot:
+        import matplotlib.pyplot as plt
 
-    plt.figure()
+        plt.figure()
 
-    plt.subplot(2, 1, 1)
-    plt.specgram(mics_signals[0], NFFT=1024, Fs=room.fs)
-    plt.title("Microphone 0 input")
+        plt.subplot(2, 1, 1)
+        plt.specgram(mics_signals[0], NFFT=1024, Fs=room.fs)
+        plt.title("Microphone 0 input")
 
-    plt.subplot(2, 1, 2)
-    plt.specgram(y_hat[:, 0], NFFT=1024, Fs=room.fs)
-    plt.title("Extracted source")
+        plt.subplot(2, 1, 2)
+        plt.specgram(y_hat[:, 0], NFFT=1024, Fs=room.fs)
+        plt.title("Extracted source")
 
-    plt.tight_layout(pad=0.5)
+        plt.tight_layout(pad=0.5)
 
-    plt.figure()
-    plt.plot([0] + callback_checkpoints, SDR[:, 0], label="SDR", marker="*")
-    plt.plot([0] + callback_checkpoints, SIR[:, 0], label="SIR", marker="o")
-    plt.legend()
-    plt.tight_layout(pad=0.5)
+        if not args.no_cb:
+            plt.figure()
+            plt.plot([0] + callback_checkpoints, SDR[:, 0], label="SDR", marker="*")
+            plt.plot([0] + callback_checkpoints, SIR[:, 0], label="SIR", marker="o")
+            plt.legend()
+            plt.tight_layout(pad=0.5)
 
-    if not args.gui:
-        plt.show()
-    else:
-        plt.show(block=False)
+        if not args.gui:
+            plt.show()
+        else:
+            plt.show(block=False)
 
     if args.save:
-        wavfile.write(
-            "bss_iva_mix.wav",
-            room.fs,
-            pra.normalize(mics_signals[0, :], bits=16).astype(np.int16),
+
+        scale = (0.95 * (2 ** 15)) / np.max(
+            [np.abs(s).max() for s in [mics_signals[0, :], ref, y_hat]]
         )
-        for i, sig in enumerate(y_hat):
+
+        if not SEP_SAMPLES_DIR.exists():
+            os.mkdir(SEP_SAMPLES_DIR)
+
+        def wavsave(type_, fs, audio):
             wavfile.write(
-                "bss_iva_source{}.wav".format(i + 1),
-                room.fs,
-                pra.normalize(sig, bits=16).astype(np.int16),
+                SEP_SAMPLES_DIR
+                / f"sample_{SINR}_{args.seed}_{args.algo}_{args.dist}_{n_mics}_{type_}.wav",
+                fs,
+                (scale * audio).astype(np.int16),
             )
 
+        wavsave("mix", room.fs, mics_signals[0])
+        wavsave("ref", room.fs, ref[0])
+        for i, sig in enumerate(y_hat.T):
+            wavsave(f"source{i}", room.fs, sig)
+
     if args.gui:
 
         from tkinter import Tk
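
Note on the saving code above: rather than normalizing each file independently with pra.normalize, the mixture, the reference, and the separated outputs are now scaled by one common factor before the cast to int16, so their relative levels are preserved across the exported wav files. A minimal standalone sketch of that peak normalization, with an illustrative helper name (normalize_to_int16 is not part of the commit):

    import numpy as np

    def normalize_to_int16(signals):
        # One shared scale so the loudest sample over all signals peaks at
        # 95% of the int16 range; relative loudness between files is kept.
        peak = max(np.abs(np.asarray(s)).max() for s in signals)
        scale = 0.95 * (2 ** 15) / peak
        return [(scale * np.asarray(s)).astype(np.int16) for s in signals]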

get_data.py

+9 -2
@@ -21,6 +21,7 @@
 This script can be used to download the data used in the experiments.
 """
 import os
+
 from pyroomacoustics.datasets.utils import download_uncompress
 
 url_data = "https://zenodo.org/record/3066489/files/cmu_arctic_concat15.tar.gz"
@@ -30,7 +31,11 @@
 
 def get_data():
     if os.path.exists(samples_dir):
-        print("The samples directory " f"{samples_dir}" " seems to exist already. Delete if re-download is needed.")
+        print(
+            "The samples directory "
+            f"{samples_dir}"
+            " seems to exist already. Delete if re-download is needed."
+        )
     else:
         print("Downloading the samples... ", end="")
         download_uncompress(url_data, temp_dir)
@@ -41,5 +46,7 @@ def get_data():
         print("done.")
 
 
+get_data()
+
 if __name__ == "__main__":
-    get_data()
+    pass
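
A side effect of this change: get_data() now runs when the module is imported, so the download check happens as soon as another script imports it; example.py relies on this after dropping its explicit get_data() call (see its import hunk above). A minimal sketch of the resulting usage, using only names that already appear in this commit:

    import sys

    # Importing the module is now enough to make sure the samples are present.
    from get_data import get_data, samples_dir  # triggers the download check at import time

    sys.path.append(samples_dir)  # as done in example.py before reading the samples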

make_separation_samples.py

+110
@@ -0,0 +1,110 @@
+import os
+import subprocess
+from pathlib import Path
+
+SEP_SAMPLES_DIR = Path("separation_samples")
+
+N_ITER = {
+    "five": 3,
+    "overiva": 10,
+    "auxiva": 50,
+}
+
+ALGO_NAMES = {
+    "five": "FIVE",
+    "overiva": "OverIVA",
+    "auxiva": "AuxIVA",
+}
+
+if __name__ == "__main__":
+
+    if not SEP_SAMPLES_DIR.exists():
+        os.mkdir(SEP_SAMPLES_DIR)
+
+    f = open(SEP_SAMPLES_DIR / "table.html", "w")
+
+    print(
+        """<table>
+<tr>
+  <td># mics</td>
+  <td>sample #</td>
+  <td>algo.</td>
+  <td>clean</td>
+  <td>mix</td>
+  <td>output</td>
+  <td>SDR</td>
+  <td>SIR</td>
+  <td>iter.</td>
+  <td>runtime</td>
+</tr>""",
+        file=f,
+    )
+
+    for sinr in [5]:
+        for dist in ["gauss"]:
+            for n_mics in [2, 3, 5, 8]:
+                for i_seed, seed in enumerate(["2785643", "398745627", "58984517"]):
+                    for algo in ["five", "overiva", "auxiva"]:
+
+                        print(f"sinr={sinr} mics={n_mics} seed={seed} algo={algo}")
+
+                        command = [
+                            "python",
+                            "./example.py",
+                            "-m",
+                            str(n_mics),
+                            "-a",
+                            algo,
+                            "-d",
+                            "gauss",
+                            "-n",
+                            str(N_ITER[algo]),
+                            "--seed",
+                            str(seed),
+                            "--save",
+                            "--no_cb",
+                            "--no_plot",
+                        ]
+
+                        out = subprocess.run(command, capture_output=True)
+
+                        if out.returncode != 0:
+                            print("Failed!!")
+                            print("stderr:")
+                            print(out.stderr)
+                            print("stdout:")
+                            print(out.stdout)
+
+                        else:
+                            lines = out.stdout.decode().split("\n")
+
+                            for l in lines:
+                                e = l.split()
+                                if len(e) == 0:
+                                    continue
+                                elif l.startswith("Processing"):
+                                    runtime = e[2]
+                                elif l.startswith("SDR"):
+                                    sdr = e[6]
+                                elif l.startswith("SIR"):
+                                    sir = e[6]
+
+                            print(
+                                f""" <tr>
+  <td>{n_mics}</td>
+  <td>{i_seed + 1}</td>
+  <td>{ALGO_NAMES[algo]}</td>
+  <td><audio controls="controls" type="audio/wav" src="<SEPDIR>/sample_{sinr}_{seed}_{algo}_{dist}_{n_mics}_ref.wav"><a>play</a></audio></td>
+  <td><audio controls="controls" type="audio/wav" src="<SEPDIR>/sample_{sinr}_{seed}_{algo}_{dist}_{n_mics}_mix.wav"><a>play</a></audio></td>
+  <td><audio controls="controls" type="audio/wav" src="<SEPDIR>/sample_{sinr}_{seed}_{algo}_{dist}_{n_mics}_source0.wav"><a>play</a></audio></td>
+  <td>{sdr} dB</td>
+  <td>{sir} dB</td>
+  <td>{N_ITER[algo]}</td>
+  <td>{runtime} s</td>
+</tr>""",
+                                file=f,
+                            )
+
+    print("</table>", file=f)
+
+    f.close()
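
When run from the repository root as python ./make_separation_samples.py, the script fills separation_samples/ with the wav files written by example.py --save and a table.html that links to them. For reference, a minimal sketch of the filename convention shared by the wavsave helper in example.py and the <audio> tags above (sample_path is an illustrative helper, not part of the commit):

    from pathlib import Path

    SEP_SAMPLES_DIR = Path("separation_samples")

    def sample_path(sinr, seed, algo, dist, n_mics, type_):
        # type_ is "mix", "ref", or "source{i}", matching the calls to wavsave,
        # e.g. separation_samples/sample_5_2785643_five_gauss_2_mix.wav
        return SEP_SAMPLES_DIR / f"sample_{sinr}_{seed}_{algo}_{dist}_{n_mics}_{type_}.wav"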
