11"""This module contains all the process for the latent space of the audio autoencoder."""
2+ import random
3+
24import torch
35from torch import nn
6+ from torch .nn import functional as F
47
58from audioenhancer .model .audio_ae .expert import Expert
69from audioenhancer .model .audio_ae .mamba import MambaBlock
@@ -76,7 +79,7 @@ class LatentProcessor(nn.Module):
7679 This module processes the latent space of the audio autoencoder.
7780 """
7881
79- def __init__ (self , in_dim : int , out_dim : int , latent_dim , num_layer , num_expert = 1 ):
82+ def __init__ (self , in_dim : int , out_dim : int , latent_dim , num_layer , noise_grad = 1 ):
8083 super ().__init__ ()
8184 self .latent_dim = latent_dim
8285 self .num_layer = num_layer
@@ -85,9 +88,11 @@ def __init__(self, in_dim: int, out_dim: int, latent_dim, num_layer, num_expert=
8588 self .in_proj = nn .Linear (in_dim , latent_dim )
8689
8790 self .out_proj = nn .Linear (latent_dim , out_dim )
88- self .num_expert = num_expert
8991
9092 self .mambas = nn .ModuleList ([MambaBlock (config ) for _ in range (num_layer )])
93+ self .unknow_noise = nn .Parameter (torch .randn (latent_dim ))
94+ self .noise_embed = nn .Embedding (noise_grad , latent_dim )
95+ self .noise_head = nn .Linear (latent_dim , noise_grad )
9196 # self.pre_process = nn.Sequential(
9297 # MambaBlock(config),
9398 # MambaBlock(config),
@@ -104,14 +109,32 @@ def classify(self, x):
104109 x = self .pre_process (x )
105110 return self .classifier (x )
106111
107- def forward (self , x , classes ):
112+ def forward (self , x , noise , gen_noise = False , noise_label = None ):
113+ bzs = x .size (0 )
108114 h = self .in_proj (x )
115+ if noise is not None and not gen_noise :
116+ noise = self .noise_embed (noise ).reshape (bzs , 1 , - 1 )
117+ h = torch .cat ([h , noise ], dim = 1 )
118+ gen_noise = True
119+ else :
120+ noise = self .unknow_noise .reshape (1 , 1 , - 1 ).repeat (bzs , 1 , 1 )
121+ h = torch .cat ([h , noise ], dim = 1 )
122+
109123 # h = self.pre_process(h)
110124 for mamba in self .mambas :
111- h = mamba (h )
125+ h = mamba (h , gen_noise = gen_noise )
112126 # if classes is not None:
113127 # return x * classes[:, None, None, 0] + self.out_proj(h) * classes[:, None, None, 1]
114- return self .out_proj (h )
128+
129+ logits = self .noise_head (h [:, - 1 ])
130+ h = h [:, :- 1 ]
131+ if noise_label is not None :
132+ if not gen_noise :
133+ return self .out_proj (h ), 0
134+ loss = F .cross_entropy (logits .view (- 1 , logits .size (- 1 )), noise_label .view (- 1 ))
135+ return self .out_proj (h ), loss
136+
137+ return self .out_proj (h ), logits
115138
116139 # expert
117140 # def forward(self, x, expert_id=None):
0 commit comments