Skip to content

Commit 1b5b6ef

Browse files
committed
27/10/21_debug_spec_aug
1 parent 6bbf8c5 commit 1b5b6ef

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

model.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -95,20 +95,20 @@ def forward(self, input: torch.tensor) -> torch.tensor:
9595

9696
class FbankAug(nn.Module):
9797

98-
def __init__(self, time_mask_width = (0, 8), freq_mask_width = (0, 10)):
98+
def __init__(self, freq_mask_width = (0, 8), time_mask_width = (0, 10)):
9999
self.time_mask_width = time_mask_width
100100
self.freq_mask_width = freq_mask_width
101101
super().__init__()
102102

103103
def mask_along_axis(self, x, dim):
104104
original_size = x.shape
105-
batch, time, fea = x.shape
105+
batch, fea, time = x.shape
106106
if dim == 1:
107-
D = time
108-
width_range = self.time_mask_width
109-
else:
110107
D = fea
111108
width_range = self.freq_mask_width
109+
else:
110+
D = time
111+
width_range = self.time_mask_width
112112

113113
mask_len = torch.randint(width_range[0], width_range[1], (batch, 1), device=x.device).unsqueeze(2)
114114
mask_pos = torch.randint(0, max(1, D - mask_len.max()), (batch, 1), device=x.device).unsqueeze(2)

0 commit comments

Comments
 (0)