Commit 8a18bda

train to small denoising

1 parent 1f37f2b

10 files changed: +174 -44 lines changed

audioenhancer/constants.py

Lines changed: 3 additions & 3 deletions
@@ -3,11 +3,11 @@
 """

 SAMPLING_RATE = 44100
-MAX_AUDIO_LENGTH = 5
-BATCH_SIZE = 8
+MAX_AUDIO_LENGTH = 6
+BATCH_SIZE = 3
 EPOCH = 1
 LOGGING_STEPS = 10
-GRADIENT_ACCUMULATION_STEPS = 2
+GRADIENT_ACCUMULATION_STEPS = 3
 SAVE_STEPS = 100
 EVAL_STEPS = 100
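
With these values the effective batch becomes BATCH_SIZE × GRADIENT_ACCUMULATION_STEPS = 3 × 3 = 9 sequences per optimizer step (down from 8 × 2 = 16), trading optimizer batch size for the longer 6-second clips at the same 44.1 kHz sampling rate.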

audioenhancer/dataset/loader.py

Lines changed: 14 additions & 10 deletions
@@ -62,7 +62,6 @@ def __init__(
             if os.path.isdir(f)
         ]

-        self.codecs += [audio_dir,]

         self._pad_length_input = 2 ** math.ceil(math.log2(max_duration * input_freq))
         self._pad_length_output = 2 ** math.ceil(math.log2(max_duration * output_freq))
@@ -174,22 +173,27 @@ def __getitem__(self, index: int) -> tuple:
             base_waveform
         )

-        noisy_waveform = torch.zeros_like(encoded_compressed_waveform)
-        for i in range(encoded_compressed_waveform.shape[1]):
-            if random.random() < 0.3:
-                noisy_waveform[:, i] = encoded_compressed_waveform[:, i]
-            else:
-                noisy_waveform[:, i] = encoded_base_waveform[:, i]
+        noise = encoded_base_waveform - encoded_compressed_waveform
+        noise_levels = []
+        noise_level = random.random()
+        encoded_compressed_waveform[:] = encoded_base_waveform[:] + noise[:] * noise_level
+        if noise_level <= 0.10:
+            t_noise_level = 0
+        else:
+            t_noise_level = noise_level - 0.10
+        encoded_base_waveform[:] = encoded_base_waveform[:] + noise[:] * t_noise_level
+        noise_levels.append(t_noise_level)

         # class_id = [0]
         # if "dac" in codec or "encodec" in codec or "opus" in codec or use_transform:
         #     class_id = [1]

-        # class_id = torch.tensor(class_id).cuda()
+        noise_levels = torch.tensor(noise_levels)

         return (
-            noisy_waveform,
+            encoded_compressed_waveform,
             encoded_base_waveform,
             base_waveform,
-            # class_id,
+            noise_levels,
         )
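
The rewritten __getitem__ is the heart of the "small denoising" idea in the commit title: the codec residual is treated as additive noise, the input latent carries a random noise level in [0, 1), and the target keeps all but one 0.10-wide step of that noise, so the model only ever has to undo a single step. A minimal self-contained sketch of the schedule (the function and its names are illustrative, not part of the repository):

import random

import torch


def make_denoising_pair(clean: torch.Tensor, compressed: torch.Tensor):
    """Build one (input, target, level) triple, mirroring the diff above."""
    noise = clean - compressed              # codec residual, as in the loader
    level = random.random()                 # input noise level in [0, 1)
    target_level = max(0.0, level - 0.10)   # one 0.10-wide step less
    noisy_input = clean + noise * level
    target = clean + noise * target_level
    return noisy_input, target, torch.tensor([target_level])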

audioenhancer/inference.py

Lines changed: 14 additions & 1 deletion
@@ -10,16 +10,22 @@
 from einops import rearrange

 from audioenhancer.model.audio_ae.model import mamba_model as model
+from audioenhancer.model.audio_ae.model import disc_model as disc_model


 class Inference:
     def __init__(self, model_path: str, sampling_rate: int):
         self.model = model
+        self.disc_model = disc_model
         self.device = torch.device("cuda")
         self.model = self.model.to(self.device)
         self.model.load_state_dict(torch.load(model_path))
         self.model.eval()

+        # self.disc_model = self.disc_model.to(self.device)
+        # self.disc_model.load_state_dict(torch.load("data/model/disc_model_1400.pt"))
+        # self.disc_model.eval()
+
         self._sampling_rate = sampling_rate

         autoencoder_path = dac.utils.download(model_type="44khz")
@@ -84,7 +90,14 @@ def inference(self, audio_path: str, chunk_duration: int = 5):
         c, d = encoded.shape[1], encoded.shape[2]
         encoded = rearrange(encoded, "b c d t -> b (t c) d")

-        pred = self.model(encoded, None)
+        noise_level = torch.tensor([10]).to(self.device)
+        for i in range(1, 15):
+            pred, _ = self.model(encoded, None, False, None)
+            # _, logits = self.disc_model(pred, None, True, None)
+            noise_level -= 1
+            print(f"Noise level: {noise_level.item()}")
+            if noise_level == 0:
+                break

         pred = rearrange(pred, "b (t c) d -> b c d t", c=c, d=d)
         pred = pred.squeeze(0)
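
At inference time the schedule runs in reverse: the model is applied repeatedly, each pass meant to strip one noise step, with a counter ticking down from level 10 (matching the noise_grad=11 levels, 0 through 10). Note that the committed loop feeds the original encoded latent into every pass; the sketch below instead chains each prediction into the next pass, as one plausible reading of the intended schedule, and is an illustration rather than the repository's code:

import torch


def iterative_denoise(model, encoded: torch.Tensor, start_level: int = 10,
                      max_steps: int = 14) -> torch.Tensor:
    """Apply a one-step denoiser until the tracked noise level hits zero."""
    pred = encoded
    level = start_level
    for _ in range(max_steps):
        # The LatentProcessor returns (denoised_latent, noise_logits).
        pred, _ = model(pred, None, False, None)
        level -= 1
        if level == 0:
            break
    return pred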

audioenhancer/model/audio_ae/latent.py

Lines changed: 28 additions & 5 deletions
@@ -1,6 +1,9 @@
 """This module contains all the process for the latent space of the audio autoencoder."""
+import random
+
 import torch
 from torch import nn
+from torch.nn import functional as F

 from audioenhancer.model.audio_ae.expert import Expert
 from audioenhancer.model.audio_ae.mamba import MambaBlock
@@ -76,7 +79,7 @@ class LatentProcessor(nn.Module):
     This module processes the latent space of the audio autoencoder.
     """

-    def __init__(self, in_dim: int, out_dim: int, latent_dim, num_layer, num_expert=1):
+    def __init__(self, in_dim: int, out_dim: int, latent_dim, num_layer, noise_grad=1):
         super().__init__()
         self.latent_dim = latent_dim
         self.num_layer = num_layer
@@ -85,9 +88,11 @@ def __init__(self, in_dim: int, out_dim: int, latent_dim, num_layer, num_expert=1):
         self.in_proj = nn.Linear(in_dim, latent_dim)

         self.out_proj = nn.Linear(latent_dim, out_dim)
-        self.num_expert = num_expert

         self.mambas = nn.ModuleList([MambaBlock(config) for _ in range(num_layer)])
+        self.unknow_noise = nn.Parameter(torch.randn(latent_dim))
+        self.noise_embed = nn.Embedding(noise_grad, latent_dim)
+        self.noise_head = nn.Linear(latent_dim, noise_grad)
         # self.pre_process = nn.Sequential(
         #     MambaBlock(config),
         #     MambaBlock(config),
@@ -104,14 +109,32 @@ def classify(self, x):
         x = self.pre_process(x)
         return self.classifier(x)

-    def forward(self, x, classes):
+    def forward(self, x, noise, gen_noise=False, noise_label=None):
+        bzs = x.size(0)
         h = self.in_proj(x)
+        if noise is not None and not gen_noise:
+            noise = self.noise_embed(noise).reshape(bzs, 1, -1)
+            h = torch.cat([h, noise], dim=1)
+            gen_noise = True
+        else:
+            noise = self.unknow_noise.reshape(1, 1, -1).repeat(bzs, 1, 1)
+            h = torch.cat([h, noise], dim=1)
+
         # h = self.pre_process(h)
         for mamba in self.mambas:
-            h = mamba(h)
+            h = mamba(h, gen_noise=gen_noise)
         # if classes is not None:
         #     return x * classes[:, None, None, 0] + self.out_proj(h) * classes[:, None, None, 1]
-        return self.out_proj(h)
+
+        logits = self.noise_head(h[:, -1])
+        h = h[:, :-1]
+        if noise_label is not None:
+            if not gen_noise:
+                return self.out_proj(h), 0
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), noise_label.view(-1))
+            return self.out_proj(h), loss
+
+        return self.out_proj(h), logits

        # expert
        # def forward(self, x, expert_id=None):
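
The reworked forward treats the noise level as one extra sequence position: a learned embedding for a known level (or the shared unknow_noise vector when the level is hidden) is concatenated after the audio tokens, the Mamba stack processes the extended sequence, and noise_head classifies the last position over the noise_grad levels while out_proj decodes the rest. A condensed sketch of just that token plumbing, with hypothetical names rather than the repository's API:

import torch
from torch import nn


class NoiseTokenHead(nn.Module):
    """Append a noise-level token to a latent sequence and read it back.

    Illustrative sketch of the mechanism inside LatentProcessor.
    """

    def __init__(self, latent_dim: int = 2048, noise_grad: int = 11):
        super().__init__()
        self.noise_embed = nn.Embedding(noise_grad, latent_dim)  # known levels
        self.unknown = nn.Parameter(torch.randn(latent_dim))     # hidden level
        self.noise_head = nn.Linear(latent_dim, noise_grad)      # level classifier

    def forward(self, h, level=None):
        bzs = h.size(0)
        if level is not None:
            tok = self.noise_embed(level).reshape(bzs, 1, -1)
        else:
            tok = self.unknown.reshape(1, 1, -1).repeat(bzs, 1, 1)
        h = torch.cat([h, tok], dim=1)      # noise token rides at the end
        # ... the Mamba layers would process the extended sequence here ...
        logits = self.noise_head(h[:, -1])  # read the level off the last token
        return h[:, :-1], logits            # strip the token before out_proj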

audioenhancer/model/audio_ae/mamba.py

Lines changed: 13 additions & 6 deletions
@@ -212,7 +212,8 @@ def __init__(self, config):
         self.mixer_rev = MambaMixer(config)
         # self.mlp = MLP(config)

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, gen_noise):
+        bzs, _, h_dim = hidden_states.shape
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         if self.residual_in_fp32:
@@ -223,11 +224,17 @@
             self.mixer_rev.to(torch.float32)

         out = self.mixer(hidden_states)
-        out_rev = self.mixer_rev(
-            hidden_states.flip(dims=(1,))
-        ).flip(dims=(1,))
-        hidden_states = out + out_rev
-
+        if gen_noise:
+            out_rev = self.mixer_rev(
+                hidden_states.flip(dims=(1,))[..., 1:, :]
+            ).flip(dims=(1,))
+            hidden_states = out + torch.cat(
+                [out_rev, torch.zeros([bzs, 1, h_dim]).to(device=out.device, dtype=out.dtype)],
+                dim=1
+            )
+        else:
+            out_rev = self.mixer_rev(hidden_states.flip(dims=(1,))).flip(dims=(1,))
+            hidden_states = out + out_rev
         if self.residual_in_fp32:
             hidden_states = hidden_states.to(original_dtype)
             residual = residual.to(original_dtype)
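
In the gen_noise branch the reverse-direction mixer drops the first element of the flipped sequence, which is exactly the appended noise token, so that token is only ever updated by the forward scan; the shorter reverse output is then zero-padded back into the token's slot. A toy shape check of that alignment (standalone illustration, not repository code):

import torch

bzs, seq, dim = 2, 5, 8
hidden = torch.randn(bzs, seq, dim)   # last position holds the noise token

flipped = hidden.flip(dims=(1,))      # noise token is now at index 0
rev_in = flipped[..., 1:, :]          # exclude it from the reverse scan
out_rev = rev_in.flip(dims=(1,))      # stand-in for the mixer_rev output
padded = torch.cat([out_rev, torch.zeros(bzs, 1, dim)], dim=1)
assert padded.shape == hidden.shape   # zeros land on the noise-token slot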

audioenhancer/model/audio_ae/model.py

Lines changed: 17 additions & 1 deletion
@@ -16,6 +16,7 @@
 from x_transformers import ContinuousTransformerWrapper, Decoder, Encoder

 from audioenhancer.model.audio_ae.latent import LatentProcessor
+from audioenhancer.model.audio_ae.transformer import Transformer
 from audioenhancer.model.audio_ae.vdiffusion import CustomVDiffusion

 model = DiffusionModel(
@@ -197,10 +198,25 @@
     ),
 )

+transformer = Transformer(
+    in_dim=1024,
+    out_dim=1024,
+    latent_dim=1024,
+    num_layer=6,
+)
+
 mamba_model = LatentProcessor(
     in_dim=1024,
     out_dim=1024,
     latent_dim=2048,
     num_layer=6,
-    num_expert=1,
+    noise_grad=11,
+)
+
+disc_model = LatentProcessor(
+    in_dim=1024,
+    out_dim=1024,
+    latent_dim=2048,
+    num_layer=6,
+    noise_grad=11,
 )

audioenhancer/model/audio_ae/transformer.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+import warnings
+
+import torch
+from torch import nn
+from x_transformers import XTransformer
+from xformers.factory import xFormerConfig, xFormer
+
+
+class Transformer(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int, latent_dim, num_layer):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.latent_dim = latent_dim
+        self.num_layer = num_layer
+
+        transformer = XTransformer(
+            dim=latent_dim,
+            enc_depth=num_layer,
+            enc_heads=16,
+            enc_max_seq_len=0,
+            enc_attn_flash=True,
+            enc_num_tokens=256,
+            enc_cross_attend=False,
+            enc_ff_glu=True,
+            enc_rotary_pos_emb=True,
+            enc_use_scalenorm=True,
+            enc_zero_init_branch_output=True,
+            dec_num_tokens=256,
+            dec_depth=num_layer,
+            dec_heads=16,
+            dec_ff_glu=True,
+            dec_rotary_pos_emb=True,
+            dec_use_scalenorm=True,
+            dec_attn_flash=True,
+            dec_max_seq_len=0,
+            dec_zero_init_branch_output=True,
+        )
+        self.embed = nn.Parameter(torch.randn(latent_dim))
+        self.encoders = transformer.encoder.attn_layers
+        self.decoders = transformer.decoder.net.attn_layers
+        self.in_proj = nn.Linear(in_dim, latent_dim)
+        self.out_proj = nn.Linear(latent_dim, out_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.in_proj(x)
+
+        x = self.encoders(x)
+
+        h = self.embed.expand(x.shape)
+        h = self.decoders(h, context=x)
+        h = self.out_proj(h)
+        return x - h
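
Shape-wise, the new module is a residual denoiser: the input is projected and encoded, a learned query embedding is decoded against that context, and the decoder's output is subtracted, so the network effectively predicts the noise component. A hedged usage sketch, assuming the model.py configuration (in_dim = out_dim = latent_dim = 1024) and arbitrary batch/sequence sizes:

import torch

from audioenhancer.model.audio_ae.transformer import Transformer

net = Transformer(in_dim=1024, out_dim=1024, latent_dim=1024, num_layer=6)
latents = torch.randn(2, 128, 1024)  # (batch, tokens, dim)
denoised = net(latents)              # forward returns x - predicted_noise
assert denoised.shape == latents.shape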

scripts/gradio_demo.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
     "--model_path",
     type=str,
     required=False,
-    default="data/model/model_1000.pt",
+    default="data/model/model_3700.pt",
     help="The path to the model",
 )

scripts/inference.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 parser = argparse.ArgumentParser()
 parser.add_argument(
     "--audio",
-    default="../media/works/dataset/opus/5700_part2.mp3", # ../media/works/dataset/dac/5700_part2.mp3
+    default="../media/works/dataset/encodec/5700_part2.mp3", # ../media/works/dataset/dac/5700_part2.mp3
     type=str,
     required=False,
     help="The path to the audio file to enhance",
@@ -20,7 +20,7 @@
     "--model_path",
     type=str,
     required=False,
-    default="data/model/model_300.pt",
+    default="data/model/model_200.pt",
     help="The path to the model",
 )
