@@ -10,9 +10,10 @@
 #from dct.dct_native import DCT_2N_native, IDCT_2N_native
 import torchaudio.functional as aF
 
+
 class Audio2Spectro(torch.nn.Module):
     def __init__(self, opt) -> None:
-        super(Audio2Spectro,self).__init__()
+        super(Audio2Spectro, self).__init__()
         opt_dict = vars(opt)
         for k, v in opt_dict.items():
             setattr(self, k, v)
@@ -27,7 +28,7 @@ def __init__(self, opt) -> None:
         self._imdct = IMDCT4(n_fft=self.n_fft, hop_length=self.hop_length,
                              win_length=self.win_length, window=self.window, device=self.device)
 
-    def to_spectro(self, audio:torch.Tensor, mask:bool=False, mask_size:int=-1):
+    def to_spectro(self, audio: torch.Tensor, mask: bool = False, mask_size: int = -1):
         # Forward Transformation (MDCT)
         spectro, frames = self._mdct(audio.to(self.device), True)
         spectro = spectro.unsqueeze(1)
@@ -59,12 +60,14 @@ def to_spectro(self, audio:torch.Tensor, mask:bool=False, mask_size:int=-1):
                 mask_size = int(size[3]*(1-1/self.up_ratio))
 
             # fill the blank mask with noise
-            _noise = torch.randn(size[0], size[1], size[2], mask_size, device=self.device)
+            _noise = torch.randn(
+                size[0], size[1], size[2], mask_size, device=self.device)
             _noise_min = _noise.min()
             _noise_max = _noise.max()
 
             if self.fit_residual:
-                _noise = torch.zeros(size[0], size[1], size[2], mask_size, device=self.device)
+                _noise = torch.zeros(
+                    size[0], size[1], size[2], mask_size, device=self.device)
             else:
                 # fill empty with randn noise, single peak, centered at 0
                 _noise = _noise/(_noise_max-_noise_min)
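
(Context: this hunk only re-wraps long lines; the noise fill itself is unchanged. For reference, a minimal standalone sketch of that logic — the 4-D spectrogram shape and the fit_residual flag come from the surrounding code, the function name is illustrative:)

import torch

def fill_masked_band(size, mask_size, fit_residual=False, device='cpu'):
    # Gaussian noise for the masked high-frequency bins
    noise = torch.randn(size[0], size[1], size[2], mask_size, device=device)
    if fit_residual:
        # residual fitting: keep the masked band empty instead
        return torch.zeros(size[0], size[1], size[2], mask_size, device=device)
    # rescale so the noise spans at most a unit range, single peak centered near 0
    return noise / (noise.max() - noise.min())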
@@ -108,16 +111,18 @@ def normalize(self, spectro):
             audio_min = log_spectro.flatten(-2).min(dim=-
                                                     1).values[:, :, None, None].float()
         else:
-            audio_min = torch.tensor([self.src_range[0]])[None,None,None,:].to(self.device)
-            audio_max = torch.tensor([self.src_range[1]])[None,None,None,:].to(self.device)
+            audio_min = torch.tensor([self.src_range[0]])[
+                None, None, None, :].to(self.device)
+            audio_max = torch.tensor([self.src_range[1]])[
+                None, None, None, :].to(self.device)
         log_spectro = (log_spectro - audio_min)/(audio_max - audio_min)
         log_spectro = log_spectro * \
             (self.norm_range[1]-self.norm_range[0]
              )+self.norm_range[0]
 
         return log_spectro, audio_max, audio_min, mean, std
 
-    def denormalize(self, log_spectro:torch.Tensor, min:torch.Tensor, max:torch.Tensor):
+    def denormalize(self, log_spectro: torch.Tensor, min: torch.Tensor, max: torch.Tensor):
         log_spectro = (
             log_spectro.to(torch.float64)-self.norm_range[0])/(self.norm_range[1]-self.norm_range[0])
         log_spectro = log_spectro*(max-min)+min
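
(Context: normalize maps the log-spectrogram linearly from [audio_min, audio_max] into norm_range, and denormalize is its exact inverse computed in float64. A round-trip sketch, assuming norm_range is a 2-tuple such as (-1, 1):)

# forward, as in normalize:
x = (log_spectro - audio_min) / (audio_max - audio_min)    # -> [0, 1]
x = x * (norm_range[1] - norm_range[0]) + norm_range[0]    # -> norm_range
# inverse, as in denormalize:
y = (x - norm_range[0]) / (norm_range[1] - norm_range[0])  # -> [0, 1]
y = y * (audio_max - audio_min) + audio_min                # recovers log_spectro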
@@ -127,8 +132,9 @@ def denormalize(self, log_spectro:torch.Tensor, min:torch.Tensor, max:torch.Tens
         else:
             return aF.DB_to_amplitude(log_spectro.to(self.device), 10.0, 0.5)-self.min_value
 
-    def to_audio(self, log_spectro:torch.Tensor, norm_param:Dict[str,torch.Tensor], pha:torch.Tensor):
-        spectro = self.denormalize(log_spectro, norm_param['min'], norm_param['max'])
+    def to_audio(self, log_spectro: torch.Tensor, norm_param: Dict[str, torch.Tensor], pha: torch.Tensor):
+        spectro = self.denormalize(
+            log_spectro, norm_param['min'], norm_param['max'])
         if self.explicit_encoding:
             spectro = (spectro[..., 0, :, :] -
                        spectro[..., 1, :, :])/(2*self.alpha-1)
@@ -151,7 +157,8 @@ def to_audio(self, log_spectro:torch.Tensor, norm_param:Dict[str,torch.Tensor],
         return audio
 
     def to_frames(self, log_spectro, norm_param):
-        spectro = self.denormalize(log_spectro, norm_param['min'],norm_param['max'])
+        spectro = self.denormalize(
+            log_spectro, norm_param['min'], norm_param['max'])
         if self.explicit_encoding:
             spectro = (spectro[..., 0, :, :] -
                        spectro[..., 1, :, :])/(2*self.alpha-1)
@@ -165,21 +172,22 @@ def norm_frames(self, frames):
         frames = frames / frames.max()
         return frames*(self.norm_range[1]-self.norm_range[0]) + self.norm_range[0]
 
-    def forward(self, lr_audio:torch.Tensor):
+    def forward(self, lr_audio: torch.Tensor):
         # low-res audio for training
         with torch.no_grad():
             lr_spectro, lr_pha, lr_norm_param = self.to_spectro(
                 lr_audio, mask=self.mask)
         return lr_spectro, lr_pha, lr_norm_param
 
-    def hr_forward(self, hr_audio:torch.Tensor):
+    def hr_forward(self, hr_audio: torch.Tensor):
         # high-res audio for training
         with torch.no_grad():
             hr_spectro, hr_pha, hr_norm_param = self.to_spectro(hr_audio, mask=self.mask_hr, mask_size=int(
                 self.n_fft*(1-self.sr_sampling_rate/self.hr_sampling_rate)//2))
 
         return hr_spectro, hr_pha, hr_norm_param
 
+
 class Pix2PixHDModel(BaseModel):
     def name(self):
         return 'Pix2PixHDModel'
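
(Context: in hr_forward, the mask hides the MDCT bins above the source bandwidth. An illustrative calculation, assuming n_fft=1024, sr_sampling_rate=16000, hr_sampling_rate=32000, and an MDCT that yields n_fft//2 bins per frame:)

mask_size = int(1024 * (1 - 16000 / 32000) // 2)  # = 256 of the 512 bins masked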
@@ -376,7 +384,8 @@ def discriminate_hifi(self, input, norm_param=None, pha=None, is_spectro=True):
     def forward(self, lr_audio, hr_audio):
         # Encode Inputs
         lr_spectro, lr_pha, lr_norm_param = self.preprocess.forward(lr_audio)
-        hr_spectro, hr_pha, hr_norm_param = self.preprocess.hr_forward(hr_audio)
+        hr_spectro, hr_pha, hr_norm_param = self.preprocess.hr_forward(
+            hr_audio)
         #### G Forward ####
         if self.abs_spectro and self.arcsinh_transform:
             lr_input = lr_spectro.abs()*2+self.norm_range[0]
@@ -395,11 +404,14 @@ def forward(self, lr_audio, hr_audio):
         return sr_spectro, sr_pha, hr_spectro, hr_pha, hr_norm_param, lr_spectro, lr_pha, lr_norm_param
 
     def _forward(self, lr_audio, hr_audio, infer=False):
-        sr_spectro, sr_pha, hr_spectro, hr_pha, hr_norm_param, lr_spectro, lr_pha, lr_norm_param = self.forward(lr_audio, hr_audio)
+        sr_spectro, sr_pha, hr_spectro, hr_pha, hr_norm_param, lr_spectro, lr_pha, lr_norm_param = self.forward(
+            lr_audio, hr_audio)
         # Fake Detection and Loss
         if self.abs_spectro and self.arcsinh_transform:
-            sr_input = torch.cat((sr_spectro, sr_spectro.abs()*2+self.norm_range[0]), dim=1)
-            hr_input = torch.cat((hr_spectro, hr_spectro.abs()*2+self.norm_range[0]), dim=1)
+            sr_input = torch.cat(
+                (sr_spectro, sr_spectro.abs()*2+self.norm_range[0]), dim=1)
+            hr_input = torch.cat(
+                (hr_spectro, hr_spectro.abs()*2+self.norm_range[0]), dim=1)
         else:
             sr_input = sr_spectro
             hr_input = hr_spectro
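
(Context: with abs_spectro and arcsinh_transform set, the discriminator sees a second channel holding the magnitude mapped back into the normalized range; if norm_range is assumed to be (-1, 1), abs() lies in [0, 1] and abs()*2 + norm_range[0] lies in [-1, 1]. The same construction as a one-line sketch:)

sr_input = torch.cat((sr_spectro, sr_spectro.abs()*2 + norm_range[0]), dim=1)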
@@ -584,7 +596,7 @@ def inference(self, lr_audio):
         # Encode Inputs
         with torch.no_grad():
             lr_spectro, lr_pha, lr_norm_param = self.preprocess.forward(
-            lr_audio)
+                lr_audio)
 
         if self.abs_spectro and self.arcsinh_transform:
             lr_input = lr_spectro.abs()*2+self.norm_range[0]
@@ -673,4 +685,4 @@ def get_current_visuals(self):
 
 class InferenceModel(Pix2PixHDModel):
     def forward(self, lr_audio):
-        return self.inference(lr_audio)
+        return self.inference(lr_audio)
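
(Context: InferenceModel reroutes forward to the no-grad inference path, so a plain call runs super-resolution end to end. A minimal call-pattern sketch; the initialize hook is assumed from the pix2pixHD-style BaseModel:)

model = InferenceModel()
model.initialize(opt)        # assumed BaseModel-style setup
sr_audio = model(lr_audio)   # dispatches to Pix2PixHDModel.inference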