
Layer Norm doesn't work. #30

@omar-abdelgawad

Description


Describe the bug
The ConvBlock class in img2img/nn/blocks.py has a norm_layer attribute that can be batch norm, instance norm, or layer norm. Unfortunately, the built-in PyTorch LayerNorm class that we use needs an argument called normalized_shape, which means something different from the channel-count argument passed to BatchNorm2d and InstanceNorm2d.

    def _normalization_selector(
        self, normalization_type: NormalizationType
    ) -> nn.Module:
        if normalization_type == NormalizationType.BATCH:
            return nn.BatchNorm2d(self.norm_dim)
        elif normalization_type == NormalizationType.INSTANCE:
            return nn.InstanceNorm2d(self.norm_dim)
        elif normalization_type == NormalizationType.LAYER:
            return nn.LayerNorm(self.norm_dim)  # This doesn't work.
        elif normalization_type == NormalizationType.NONE:
            return nn.Identity()
        else:
            raise NotImplementedError(
                f"Normalization type {normalization_type} is not implemented."
            )
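The mismatch is easy to demonstrate. In the sketch below, the channel count (64) and spatial size (32) are illustrative, not values from our actual config:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 64, 32, 32)  # (B, C, H, W)

# BatchNorm2d and InstanceNorm2d take the channel count C and just work:
assert nn.BatchNorm2d(64)(x).shape == x.shape
assert nn.InstanceNorm2d(64)(x).shape == x.shape

# LayerNorm interprets its argument as the *trailing* shape to normalize
# over, so LayerNorm(64) expects the last dimension to be 64:
try:
    nn.LayerNorm(64)(x)  # last dim is 32, not 64
    raised = False
except RuntimeError:
    raised = True
assert raised

# It only works when given the full (C, H, W) shape up front:
assert nn.LayerNorm([64, 32, 32])(x).shape == x.shape
```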

Most other papers we have seen avoid the built-in LayerNorm by providing their own LayerNorm class, such as:

import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, affine=True):
        super().__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            # Per-channel affine parameters (2 * num_features weights).
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        shape = [-1] + [1] * (x.dim() - 1)
        if x.size(0) == 1:
            # These two lines run much faster in pytorch 0.4 than the two lines listed below.
            mean = x.view(-1).mean().view(*shape)
            std = x.view(-1).std().view(*shape)
        else:
            mean = x.view(x.size(0), -1).mean(1).view(*shape)
            std = x.view(x.size(0), -1).std(1).view(*shape)

        x = (x - mean) / (std + self.eps)

        if self.affine:
            shape = [1, -1] + [1] * (x.dim() - 2)
            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x

but this approach is undesirable, as using the PyTorch builtins seems like a better option (the custom class also normalizes in a slightly different way and carries a different set of affine weights). Whatever the solution is, LayerNorm should be a viable option in the cfg and should work.
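One builtin alternative worth considering (my suggestion, not something the current code does) is nn.GroupNorm with a single group: it computes per-sample statistics over all of (C, H, W), which is what LayerNorm does in this setting, but it only needs the channel count at construction time and keeps a per-channel affine (2C weights, like the custom class above):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 64, 32, 32)  # (B, C, H, W)

# num_groups=1 -> statistics are computed per sample over all of (C, H, W),
# yet only the channel count is needed, like BatchNorm2d/InstanceNorm2d:
gn = nn.GroupNorm(num_groups=1, num_channels=64)
assert gn(x).shape == x.shape

# With affine disabled, the output matches manual per-sample normalization:
gn_plain = nn.GroupNorm(1, 64, affine=False)
flat = x.view(2, -1)
mean = flat.mean(dim=1).view(2, 1, 1, 1)
var = flat.var(dim=1, unbiased=False).view(2, 1, 1, 1)
manual = (x - mean) / torch.sqrt(var + gn_plain.eps)
assert torch.allclose(gn_plain(x), manual, atol=1e-5)
```

If those semantics are acceptable, the LAYER branch of the selector could simply return nn.GroupNorm(1, self.norm_dim).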

To Reproduce
This bug was discovered during UNIT training, as the original implementation used LayerNorm in the decoder/generator.

Additional context
In vision, the input is usually a batch of 3-channel images with shape (B, C, H, W). BatchNorm and InstanceNorm need the number of channels C as an argument to determine how many parameters to allocate. The built-in LayerNorm, on the other hand, needs the full normalized shape, which is essentially (C, H, W); that is not always feasible to provide, since a Conv2d layer doesn't know its input or output (H, W), and having them calculated on the first forward pass is hard to achieve (and probably not the best idea).
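If the GroupNorm route is not wanted, another option is to defer the shape to forward time via the functional F.layer_norm, reading normalized_shape off the input itself. The wrapper below is a hypothetical sketch (the class name and the per-channel affine are my assumptions, not code from the repo):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelsFirstLayerNorm(nn.Module):
    """Per-sample LayerNorm over (C, H, W) that only needs C up front."""

    def __init__(self, num_channels: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Per-channel affine (like the custom class above), so H and W
        # never need to be known at construction time.
        self.gamma = nn.Parameter(torch.ones(num_channels))
        self.beta = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over everything but the batch dim; the normalized
        # shape is read off the input at forward time.
        x = F.layer_norm(x, list(x.shape[1:]), eps=self.eps)
        shape = (1, -1) + (1,) * (x.dim() - 2)
        return x * self.gamma.view(shape) + self.beta.view(shape)


x = torch.randn(2, 3, 8, 8)
ln = ChannelsFirstLayerNorm(3)
y = ln(x)
assert y.shape == x.shape
# At init (gamma=1, beta=0) each sample is normalized to ~zero mean:
assert torch.allclose(y.view(2, -1).mean(dim=1), torch.zeros(2), atol=1e-5)
```

Note this still differs slightly from the custom class (it divides by sqrt(var + eps) rather than (std + eps)), so the two are not bit-identical.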
