diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py index c739291b32..336d16615f 100644 --- a/timm/layers/patch_embed.py +++ b/timm/layers/patch_embed.py @@ -31,7 +31,7 @@ class PatchEmbed(nn.Module): def __init__( self, - img_size: Optional[int] = 224, + img_size: Union[int, Tuple[int, int]] = 224, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768,