修改:
1.在源文件261行前加入:
while pad_size > in_f :
x = torch.cat([x, x[..., :]], dim=-1)
pad_size-=in_f
这是由于没考虑pad_size > in_f得情况导致。
2.在源文件293行代码替换为:
if out_x.numel() == 0:
out_x = out_x.view(*x.shape[:-1], out_f)
else :
out_x = out_x.view(*x.shape[:-1], -1)[..., :out_f]
这是由于未考虑 out_x 中element=0得情况。