
Commit e5dbe4d

Author: neoncloud
Message: Format code
Parent: ad2aed9

File tree

4 files changed (+45, -36 lines)


models/networks.py (+13, -11)
@@ -116,15 +116,17 @@ def get_target_tensor(self, input, target_is_real):
             create_label = ((self.real_label_var is None) or
                             (self.real_label_var.shape != input.shape))
             if create_label:
-                real_tensor = torch.Tensor(input.size()).fill_(self.real_label).to(self.device)
+                real_tensor = torch.Tensor(input.size()).fill_(
+                    self.real_label).to(self.device)
                 self.real_label_var = Variable(
                     real_tensor, requires_grad=False)
             target_tensor = self.real_label_var
         else:
             create_label = ((self.fake_label_var is None) or
                             (self.fake_label_var.shape != input.shape))
             if create_label:
-                fake_tensor = torch.Tensor(input.size()).fill_(self.fake_label).to(self.device)
+                fake_tensor = torch.Tensor(input.size()).fill_(
+                    self.fake_label).to(self.device)
                 self.fake_label_var = Variable(
                     fake_tensor, requires_grad=False)
             target_tensor = self.fake_label_var
@@ -213,41 +215,41 @@ def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_
         model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
                             norm_layer(ngf_global), nn.ReLU(True),
                             downsample_layer(ngf_global, ngf_global * 2,
-                                kernel_size=3, stride=2, padding=1),
+                                             kernel_size=3, stride=2, padding=1),
                             norm_layer(ngf_global * 2), nn.ReLU(True)]
         # residual blocks
         model_upsample = []
         for i in range(n_blocks_local):
             model_upsample += [ResnetBlock(ngf_global * 2,
-                                padding_type=padding_type, norm_layer=norm_layer)]
+                                           padding_type=padding_type, norm_layer=norm_layer)]
         # attention bottleneck
         if n_attn_l > 0:
             middle = n_blocks_local//2
             # 8x downsample
             down = [downsample_layer(ngf_global * 2, ngf_global,
-                        kernel_size=3, stride=2, padding=1),
+                                     kernel_size=3, stride=2, padding=1),
                     norm_layer(ngf_global), nn.ReLU(True)]
             down += [downsample_layer(ngf_global, ngf_global,
-                        kernel_size=3, stride=2, padding=1),
-                    norm_layer(ngf_global), nn.ReLU(True)]*2
+                                      kernel_size=3, stride=2, padding=1),
+                     norm_layer(ngf_global), nn.ReLU(True)]*2
             down = nn.Sequential(*down)
             model_upsample.insert(middle, down)

             middle += 1
             input_size = tuple(map(lambda x: x//16, input_size))
             from bottleneck_transformer_pytorch import BottleStack
             attn_block = BottleStack(dim=ngf_global, fmap_size=input_size, dim_out=ngf_global*2, num_layers=n_attn_l, proj_factor=proj_factor_l,
-                                    downsample=False, heads=heads_l, dim_head=dim_head_l, activation=nn.ReLU(True), rel_pos_emb=False)
+                                     downsample=False, heads=heads_l, dim_head=dim_head_l, activation=nn.ReLU(True), rel_pos_emb=False)
             model_upsample.insert(middle, attn_block)
             model_upsample += [upsample_layer(in_channels=ngf_global*2, out_channels=ngf_global*2, kernel_size=3, stride=2, padding=1, output_padding=1),
-                              norm_layer(ngf_global), nn.ReLU(True)]*3
+                               norm_layer(ngf_global), nn.ReLU(True)]*3

         model_upsample += [upsample_layer(in_channels=ngf_global*2, out_channels=ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1),
-                          norm_layer(ngf_global), nn.ReLU(True)]
+                           norm_layer(ngf_global), nn.ReLU(True)]

         # final convolution
         model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(
-                ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
+            ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]

         self.model1_1 = nn.Sequential(*model_downsample)
         self.model1_2 = nn.Sequential(*model_upsample)
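
Note on the wrapped calls in get_target_tensor above: they build a constant-valued tensor with the same shape as the discriminator output, to be used as the GAN target. A minimal standalone sketch of the same idea, independent of this repository (function and variable names are illustrative; torch.full is used here as a more direct equivalent of Tensor(size).fill_(value).to(device)):

import torch

def make_gan_target(pred: torch.Tensor, label_value: float = 1.0) -> torch.Tensor:
    # Constant tensor shaped like the discriminator prediction, on the same device.
    return torch.full(pred.shape, label_value, device=pred.device)

pred = torch.randn(2, 1, 16, 16)            # stand-in for a discriminator output
target_real = make_gan_target(pred, 1.0)    # target for "real" samples
target_fake = make_gan_target(pred, 0.0)    # target for "fake" samples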

models/pix2pixHD_model.py (+30, -18)
@@ -10,9 +10,10 @@
 #from dct.dct_native import DCT_2N_native, IDCT_2N_native
 import torchaudio.functional as aF

+
 class Audio2Spectro(torch.nn.Module):
     def __init__(self, opt) -> None:
-        super(Audio2Spectro,self).__init__()
+        super(Audio2Spectro, self).__init__()
         opt_dict = vars(opt)
         for k, v in opt_dict.items():
             setattr(self, k, v)
@@ -27,7 +28,7 @@ def __init__(self, opt) -> None:
         self._imdct = IMDCT4(n_fft=self.n_fft, hop_length=self.hop_length,
                              win_length=self.win_length, window=self.window, device=self.device)

-    def to_spectro(self, audio:torch.Tensor, mask:bool=False, mask_size:int=-1):
+    def to_spectro(self, audio: torch.Tensor, mask: bool = False, mask_size: int = -1):
         # Forward Transformation (MDCT)
         spectro, frames = self._mdct(audio.to(self.device), True)
         spectro = spectro.unsqueeze(1)
@@ -59,12 +60,14 @@ def to_spectro(self, audio:torch.Tensor, mask:bool=False, mask_size:int=-1):
             mask_size = int(size[3]*(1-1/self.up_ratio))

         # fill the blank mask with noise
-        _noise = torch.randn(size[0], size[1], size[2], mask_size, device=self.device)
+        _noise = torch.randn(
+            size[0], size[1], size[2], mask_size, device=self.device)
         _noise_min = _noise.min()
         _noise_max = _noise.max()

         if self.fit_residual:
-            _noise = torch.zeros(size[0], size[1], size[2], mask_size, device=self.device)
+            _noise = torch.zeros(
+                size[0], size[1], size[2], mask_size, device=self.device)
         else:
             # fill empty with randn noise, single peak, centered at 0
             _noise = _noise/(_noise_max - _noise_min)
@@ -108,16 +111,18 @@ def normalize(self, spectro):
             audio_min = log_spectro.flatten(-2).min(dim=-
                                                     1).values[:, :, None, None].float()
         else:
-            audio_min = torch.tensor([self.src_range[0]])[None,None,None,:].to(self.device)
-            audio_max = torch.tensor([self.src_range[1]])[None,None,None,:].to(self.device)
+            audio_min = torch.tensor([self.src_range[0]])[
+                None, None, None, :].to(self.device)
+            audio_max = torch.tensor([self.src_range[1]])[
+                None, None, None, :].to(self.device)
         log_spectro = (log_spectro-audio_min)/(audio_max-audio_min)
         log_spectro = log_spectro * \
             (self.norm_range[1]-self.norm_range[0]
              )+self.norm_range[0]

         return log_spectro, audio_max, audio_min, mean, std

-    def denormalize(self, log_spectro:torch.Tensor, min:torch.Tensor, max:torch.Tensor):
+    def denormalize(self, log_spectro: torch.Tensor, min: torch.Tensor, max: torch.Tensor):
         log_spectro = (
             log_spectro.to(torch.float64)-self.norm_range[0])/(self.norm_range[1]-self.norm_range[0])
         log_spectro = log_spectro*(max-min)+min
@@ -127,8 +132,9 @@ def denormalize(self, log_spectro:torch.Tensor, min:torch.Tensor, max:torch.Tensor):
         else:
             return aF.DB_to_amplitude(log_spectro.to(self.device), 10.0, 0.5)-self.min_value

-    def to_audio(self, log_spectro:torch.Tensor, norm_param:Dict[str,torch.Tensor], pha:torch.Tensor):
-        spectro = self.denormalize(log_spectro, norm_param['min'], norm_param['max'])
+    def to_audio(self, log_spectro: torch.Tensor, norm_param: Dict[str, torch.Tensor], pha: torch.Tensor):
+        spectro = self.denormalize(
+            log_spectro, norm_param['min'], norm_param['max'])
         if self.explicit_encoding:
             spectro = (spectro[..., 0, :, :] -
                        spectro[..., 1, :, :])/(2*self.alpha-1)
@@ -151,7 +157,8 @@ def to_audio(self, log_spectro:torch.Tensor, norm_param:Dict[str,torch.Tensor], pha:torch.Tensor):
         return audio

     def to_frames(self, log_spectro, norm_param):
-        spectro = self.denormalize(log_spectro, norm_param['min'],norm_param['max'])
+        spectro = self.denormalize(
+            log_spectro, norm_param['min'], norm_param['max'])
         if self.explicit_encoding:
             spectro = (spectro[..., 0, :, :] -
                        spectro[..., 1, :, :])/(2*self.alpha-1)
@@ -165,21 +172,22 @@ def norm_frames(self, frames):
         frames = frames / frames.max()
         return frames * (self.norm_range[1]-self.norm_range[0]) + self.norm_range[0]

-    def forward(self, lr_audio:torch.Tensor):
+    def forward(self, lr_audio: torch.Tensor):
         # low-res audio for training
         with torch.no_grad():
             lr_spectro, lr_pha, lr_norm_param = self.to_spectro(
                 lr_audio, mask=self.mask)
         return lr_spectro, lr_pha, lr_norm_param

-    def hr_forward(self, hr_audio:torch.Tensor):
+    def hr_forward(self, hr_audio: torch.Tensor):
         # high-res audio for training
         with torch.no_grad():
             hr_spectro, hr_pha, hr_norm_param = self.to_spectro(hr_audio, mask=self.mask_hr, mask_size=int(
                 self.n_fft*(1-self.sr_sampling_rate/self.hr_sampling_rate)//2))

         return hr_spectro, hr_pha, hr_norm_param

+
 class Pix2PixHDModel(BaseModel):
     def name(self):
         return 'Pix2PixHDModel'
@@ -376,7 +384,8 @@ def discriminate_hifi(self, input, norm_param=None, pha=None, is_spectro=True):
     def forward(self, lr_audio, hr_audio):
         # Encode Inputs
         lr_spectro, lr_pha, lr_norm_param = self.preprocess.forward(lr_audio)
-        hr_spectro, hr_pha, hr_norm_param = self.preprocess.hr_forward(hr_audio)
+        hr_spectro, hr_pha, hr_norm_param = self.preprocess.hr_forward(
+            hr_audio)
         #### G Forward ####
         if self.abs_spectro and self.arcsinh_transform:
             lr_input = lr_spectro.abs()*2+self.norm_range[0]
@@ -395,11 +404,14 @@ def forward(self, lr_audio, hr_audio):
         return sr_spectro, sr_pha, hr_spectro, hr_pha, hr_norm_param, lr_spectro, lr_pha, lr_norm_param

     def _forward(self, lr_audio, hr_audio, infer=False):
-        sr_spectro, sr_pha, hr_spectro, hr_pha, hr_norm_param, lr_spectro, lr_pha, lr_norm_param = self.forward(lr_audio, hr_audio)
+        sr_spectro, sr_pha, hr_spectro, hr_pha, hr_norm_param, lr_spectro, lr_pha, lr_norm_param = self.forward(
+            lr_audio, hr_audio)
         # Fake Detection and Loss
         if self.abs_spectro and self.arcsinh_transform:
-            sr_input = torch.cat((sr_spectro, sr_spectro.abs()*2+self.norm_range[0]), dim=1)
-            hr_input = torch.cat((hr_spectro, hr_spectro.abs()*2+self.norm_range[0]), dim=1)
+            sr_input = torch.cat(
+                (sr_spectro, sr_spectro.abs()*2+self.norm_range[0]), dim=1)
+            hr_input = torch.cat(
+                (hr_spectro, hr_spectro.abs()*2+self.norm_range[0]), dim=1)
         else:
             sr_input = sr_spectro
             hr_input = hr_spectro
@@ -584,7 +596,7 @@ def inference(self, lr_audio):
         # Encode Inputs
         with torch.no_grad():
             lr_spectro, lr_pha, lr_norm_param = self.preprocess.forward(
-                    lr_audio)
+                lr_audio)

         if self.abs_spectro and self.arcsinh_transform:
             lr_input = lr_spectro.abs()*2+self.norm_range[0]
@@ -673,4 +685,4 @@ def get_current_visuals(self):

 class InferenceModel(Pix2PixHDModel):
     def forward(self, lr_audio):
-        return self.inference(lr_audio)
+        return self.inference(lr_audio)
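
For context on the reformatted normalize/denormalize calls above: the preprocessing module min-max scales the log spectrogram into a fixed target range and later inverts that scaling before resynthesis. A minimal self-contained sketch of such a round trip, assuming a simple per-tensor min/max (the norm_range value and function names are illustrative, not the repository's exact API):

import torch

norm_range = (-1.0, 1.0)  # illustrative; the real value comes from the parsed options

def normalize(log_spectro: torch.Tensor):
    lo, hi = log_spectro.min(), log_spectro.max()
    x = (log_spectro - lo) / (hi - lo)                              # scale to [0, 1]
    x = x * (norm_range[1] - norm_range[0]) + norm_range[0]        # then to norm_range
    return x, lo, hi

def denormalize(x: torch.Tensor, lo: torch.Tensor, hi: torch.Tensor):
    x = (x - norm_range[0]) / (norm_range[1] - norm_range[0])
    return x * (hi - lo) + lo                                       # back to original range

spec = torch.randn(1, 1, 64, 64)
scaled, lo, hi = normalize(spec)
assert torch.allclose(denormalize(scaled, lo, hi), spec, atol=1e-5)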

run_script.py (-6)
@@ -4,13 +4,10 @@

 from options.train_options import TrainOptions
 from data.data_loader import CreateDataLoader
-from util.visualizer import Visualizer
-from util.spectro_img import compute_visuals
 from util.util import compute_matrics

 # Initilize the setup
 opt = TrainOptions().parse()
-visualizer = Visualizer(opt)
 data_loader = CreateDataLoader(opt)
 dataset = data_loader.load_data()
 dataset_size = len(data_loader)
@@ -47,9 +44,6 @@
 print('MSE: %.4f' % _mse)
 print('SNR_SR: %.4f' % _snr_sr)
 print('SNR_LR: %.4f' % _snr_lr)
-#print('SSNR_SR: %.4f' % _ssnr_sr)
-#print('SSNR_LR: %.4f' % _ssnr_lr)
-#print('PESQ: %.4f' % _pesq)
 print('LSD: %.4f' % _lsd)
 with open(os.path.join(opt.checkpoints_dir, opt.name, 'metric.txt'), 'w') as f:
     f.write('MSE,SNR_SR,LSD\n')

save_model.py (+2, -1)
@@ -9,4 +9,5 @@
 opt.isTrain = False
 model = create_model(opt)
 model_scripted = torch.jit.script(model)
-torch.jit.save(model_scripted,os.path.join(opt.checkpoints_dir, opt.name, 'model_scripted.pt'))
+torch.jit.save(model_scripted,os.path.join(opt.checkpoints_dir, opt.name, 'model_scripted.pt'))
+torch.save(opt,os.path.join(opt.checkpoints_dir, opt.name, 'opt.pt'))
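
save_model.py now stores two artifacts: the TorchScript-compiled model and the parsed options object. A minimal sketch of how they could be loaded back for inference (the checkpoint path and variable names are illustrative, not part of the repository):

import os
import torch

ckpt_dir = './checkpoints/experiment_name'   # illustrative path

# TorchScript module: self-contained, loadable without the original class definitions.
model = torch.jit.load(os.path.join(ckpt_dir, 'model_scripted.pt'))
model.eval()

# Pickled options object; on PyTorch >= 2.6 pass weights_only=False, since it is an
# arbitrary Python object rather than a plain state dict.
opt = torch.load(os.path.join(ckpt_dir, 'opt.pt'))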
