-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathesc50.py
37 lines (27 loc) · 1.34 KB
/
esc50.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import pandas as pd
from torch.utils.data import Dataset
from .transforms import preprocess_audio
class ESC50(Dataset):
    """ESC-50 audio dataset, split into a training or validation subset by fold.

    Expects ``<config.data_root>/audio`` to hold the audio clips and
    ``<config.data_root>/meta/esc50.csv`` to be a metadata table with at least
    ``filename``, ``fold`` and ``target`` columns. Rows whose ``fold`` equals
    ``config.val_fold`` form the validation split; all other rows form the
    training split.

    Args:
        config: Object exposing ``data_root`` (``~`` is expanded),
            ``val_fold`` (integer in 1..5) and ``sample_rate`` (forwarded to
            ``preprocess_audio``).
        mode: ``'train'`` or ``'val'``.

    Raises:
        ValueError: If ``mode`` is not ``'train'``/``'val'`` or
            ``config.val_fold`` is outside 1..5.
        RuntimeError: If the ``audio`` or ``meta`` directory does not exist.
    """

    def __init__(self, config, mode='train'):
        # Explicit checks rather than ``assert``: asserts are removed when
        # Python runs with -O, which would silently accept invalid arguments.
        if mode not in ('train', 'val'):
            raise ValueError(f"`mode` must be 'train' or 'val', got {mode!r}.")
        if not 1 <= config.val_fold <= 5:
            raise ValueError("`config.val_fold` must be between 1 and 5.")

        data_root = os.path.expanduser(config.data_root)
        self.audio_dir = os.path.join(data_root, 'audio')
        meta_dir = os.path.join(data_root, 'meta')
        self.sample_rate = config.sample_rate

        if not os.path.isdir(self.audio_dir):
            raise RuntimeError(f'Audio directory: {self.audio_dir} does not exist.')
        if not os.path.isdir(meta_dir):
            raise RuntimeError(f'Meta directory: {meta_dir} does not exist.')

        metadata = pd.read_csv(os.path.join(meta_dir, 'esc50.csv'))
        # The validation split is the single held-out fold; training uses the rest.
        if mode == 'train':
            select_fold = metadata['fold'] != config.val_fold
        else:
            select_fold = metadata['fold'] == config.val_fold
        # Reset the index so positional indices 0..len-1 map onto the subset.
        self.metadata = metadata[select_fold].reset_index(drop=True)

    def __len__(self):
        """Return the number of clips in the selected split."""
        return len(self.metadata)

    def __getitem__(self, index):
        """Return ``(features, label)`` for the clip at position ``index``.

        ``features`` is whatever ``preprocess_audio`` returns for the clip's
        file at ``self.sample_rate`` (named a mel spectrogram here; its
        shape/dtype are defined by that helper, not by this class).
        ``label`` is the row's ``target`` value.
        """
        row = self.metadata.iloc[index]
        file_path = os.path.join(self.audio_dir, row['filename'])
        label = row['target']
        mel_spectrogram = preprocess_audio(file_path, sample_rate=self.sample_rate)
        return mel_spectrogram, label