1
+ import torch
1
2
from torch import nn
2
3
import torch .nn .functional as F
3
4
import numpy as np
4
-
5
+ import os . path
5
6
6
7
def new_size_conv(size, kernel, stride=1, padding=0):
    """Return the spatial size produced by a conv layer.

    Standard convolution arithmetic:
    floor((size + 2*padding - (kernel - 1) - 1) / stride + 1).

    Args:
        size: input spatial size
        kernel: convolution kernel size
        stride: convolution stride (default 1)
        padding: zero padding on each side (default 0)
    """
    effective = size + 2 * padding - (kernel - 1) - 1
    return np.floor(effective / stride + 1)
@@ -290,7 +291,85 @@ def forward(self, x):
290
291
291
292
return out
292
293
293
-
294
+
295
class audio_cnn_block(nn.Module):
    """One downsampling stage for the audio CNN classifiers.

    Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(4, 4); the pooling stage
    reduces the temporal length by a factor of 4.

    Args:
        n_input: input channels
        n_out: output channels
        kernel_size: convolution kernel size (padding is fixed at 1)
    """

    def __init__(self, n_input, n_out, kernel_size):
        super(audio_cnn_block, self).__init__()
        stages = [
            nn.Conv1d(n_input, n_out, kernel_size, padding=1),
            nn.BatchNorm1d(n_out),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4),
        ]
        # Same attribute name and Sequential indices as before, so
        # state_dict keys are unchanged.
        self.cnn_block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block to a (batch, channels, length) tensor."""
        return self.cnn_block(x)
314
+
315
+
316
class audio_tiny_cnn(nn.Module):
    """Template for convolutional audio classifiers.

    Three conv blocks, one hidden fully-connected layer, and a linear
    classification head.
    """

    def __init__(self, cnn_sizes, n_hidden, kernel_size, n_classes):
        """
        Args:
            cnn_sizes: list of five sizes; [0..3] are the channel widths of
                the three conv blocks, [4] is the flattened feature size
                feeding the fully connected layer.
            n_hidden: hidden units in the first fully connected layer
            kernel_size: convolution kernel size
            n_classes: number of speakers to classify
        """
        super(audio_tiny_cnn, self).__init__()
        self.down_path = nn.ModuleList()
        # Channel widths are taken pairwise from cnn_sizes[0..3],
        # producing blocks (0->1), (1->2), (2->3).
        for n_in, n_out in zip(cnn_sizes[:3], cnn_sizes[1:4]):
            self.down_path.append(audio_cnn_block(n_in, n_out, kernel_size))
        self.fc = nn.Sequential(
            nn.Linear(cnn_sizes[4], n_hidden),
            nn.ReLU(),
        )
        self.out = nn.Linear(n_hidden, n_classes)

    def forward(self, x):
        """Run the conv stack, flatten, and classify."""
        for block in self.down_path:
            x = block(x)
        flat = x.view(x.size(0), -1)
        return self.out(self.fc(flat))
349
+
350
+
351
def MFCC_cnn_classifier(n_classes):
    """Build a speaker classifier that ingests MFCC features.

    Args:
        n_classes: number of speakers the classifier distinguishes.

    Returns:
        An audio_tiny_cnn with 20 input channels and a 512-unit hidden layer.
    """
    in_size = 20
    n_hidden = 512
    sizes_list = [in_size, 2 * in_size, 4 * in_size, 8 * in_size, 8 * in_size]
    # BUG FIX: the function previously ignored its n_classes argument and
    # hard-coded n_classes=125; pass the caller's value through.
    return audio_tiny_cnn(cnn_sizes=sizes_list, n_hidden=n_hidden,
                          kernel_size=3, n_classes=n_classes)
360
+
361
+
362
def ft_cnn_classifer(n_classes):
    """Build a speaker classifier that ingests |FFT| features.

    Args:
        n_classes: number of speakers the classifier distinguishes.

    Returns:
        An audio_tiny_cnn with 94 input channels and a 512-unit hidden layer.
    """
    in_size = 94
    n_hidden = 512
    sizes_list = [in_size, in_size, 2 * in_size, 4 * in_size, 14 * 4 * in_size]
    # BUG FIX: the function previously ignored its n_classes argument and
    # hard-coded n_classes=125; pass the caller's value through.
    return audio_tiny_cnn(cnn_sizes=sizes_list, n_hidden=n_hidden,
                          kernel_size=7, n_classes=n_classes)
371
+
372
+
294
373
def weights_init (m ):
295
374
if isinstance (m , nn .Conv2d ):
296
375
nn .init .kaiming_normal_ (m .weight , mode = 'fan_out' , nonlinearity = 'relu' )
@@ -302,4 +381,30 @@ def weights_init(m):
302
381
elif isinstance (m , nn .Linear ):
303
382
nn .init .xavier_normal_ (m .weight .data )
304
383
nn .init .constant_ (m .bias , 0 )
305
-
384
+
385
def save_checkpoint(model=None, optimizer=None, epoch=None,
                    data_descriptor=None, loss=None, accuracy=None, path='./',
                    filename='checkpoint', ext='.pth.tar'):
    """Serialize training state to ``<path>/<filename><ext>``.

    Args:
        model: module whose state_dict is saved (must not be None)
        optimizer: optimizer whose state_dict is saved (must not be None)
        epoch: epoch number recorded in the checkpoint
        data_descriptor: free-form tag describing the dataset
        loss: loss value recorded in the checkpoint
        accuracy: accuracy value recorded in the checkpoint
        path: destination directory
        filename: base file name
        ext: file extension
    """
    state = {
        'epoch': epoch,
        # NOTE(review): str(model.type) is the repr of the bound .type
        # method; load_checkpoint compares against the same expression,
        # so it works as an architecture tag for identical module reprs.
        'arch': str(model.type),
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'loss': loss,
        'accuracy': accuracy,
        'dataset': data_descriptor,
    }
    # ROBUSTNESS FIX: os.path.join works whether or not `path` ends with a
    # separator; the old `path + filename + ext` concatenation silently
    # produced a wrong file name for e.g. path='ckpts'.
    torch.save(state, os.path.join(path, filename + ext))
398
+
399
+
400
def load_checkpoint(model=None, optimizer=None, checkpoint=None):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        model: module to load weights into; its str(model.type) tag must
            match the 'arch' recorded in the checkpoint.
        optimizer: optional optimizer whose state is restored when given.
        checkpoint: path to a file written by save_checkpoint.

    Raises:
        AssertionError: if the file is missing or the architecture tag
            does not match (kept as asserts for backward compatibility).
    """
    assert os.path.isfile(checkpoint), 'Checkpoint not found, aborting load'
    chpt = torch.load(checkpoint)
    # Typo fixed in the message ('arquitecture') and the stray whitespace
    # from a backslash continuation inside the literal removed.
    assert str(model.type) == chpt['arch'], \
        'Model architecture mismatch, aborting load'
    model.load_state_dict(chpt['state_dict'])
    if optimizer is not None:
        # BUG FIX: load_state_dict is a method; the original subscripted it
        # (optimizer.load_state_dict['optimizer']), which raises
        # "TypeError: 'method' object is not subscriptable" on every call.
        optimizer.load_state_dict(chpt['optimizer'])
    print('Successfully loaded checkpoint\nDataset: %s\nEpoch: %s\nLoss: %s'
          '\nAccuracy: %s' % (chpt['dataset'], chpt['epoch'], chpt['loss'],
                              chpt['accuracy']))
0 commit comments