 
 from __future__ import annotations
 
+import inspect
 import numbers
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
+
+from typing_extensions import overload
 
 import paddle
 from paddle import _C_ops, in_dynamic_mode
@@ -681,6 +684,7 @@ def local_response_norm(
     return res
 
 
+@overload
 def group_norm(
     x: Tensor,
     num_groups: int,
@@ -689,16 +693,40 @@ def group_norm(
     bias: Tensor | None = None,
     data_format: DataLayout1D | DataLayout2D | DataLayout3D = 'NCHW',
     name: str | None = None,
-) -> Tensor:
+) -> Tensor: ...
+
+
+@overload
+def group_norm(
+    input: Tensor,
+    num_groups: int,
+    weight: Tensor | None = None,
+    bias: Tensor | None = None,
+    eps: float = 1e-05,
+) -> Tensor: ...
+
+
+def group_norm(*args: Any, **kwargs: Any) -> Tensor:
     """
     nn.GroupNorm is recommended.
     For more information, please refer to :ref:`api_paddle_nn_GroupNorm` .
 
+    This function has two functionalities, depending on the parameters passed:
+
+    1. ``group_norm(Tensor input, int num_groups, Tensor weight = None, Tensor bias = None, float eps = 1e-05)``:
+       PyTorch compatible group_norm.
+
+    2. ``group_norm(Tensor x, int num_groups, float epsilon = 1e-05, Tensor weight = None, Tensor bias = None,
+       DataLayout1D | DataLayout2D | DataLayout3D data_format = 'NCHW', str | None name = None)``:
+       The original paddle.nn.functional.group_norm, see the following docs.
+
     Parameters:
         x(Tensor): Input Tensor with shape: attr:`(batch, num_features, *)`.
+            alias: ``input``.
         num_groups(int): The number of groups that divided from channels.
         epsilon(float, optional): The small value added to the variance to prevent
             division by zero. Default: 1e-05.
+            alias: ``eps``.
         weight(Tensor, optional): The weight Tensor of group_norm, with shape: attr:`[num_channels]`.
            Default: None.
         bias(Tensor, optional): The bias Tensor of group_norm, with shape: attr:`[num_channels]`.
@@ -744,6 +772,44 @@ def group_norm(
           [[-1.34163547, -0.44721183],
            [ 0.44721183,  1.34163547]]]])
     """
+
+    len_args = len(args)
+    if len_args + len(kwargs) < 2:
+        raise TypeError(
+            f"Too few arguments in the function call: {len_args}, {len(kwargs)}. Expect one of: \n"
+            " - (Tensor input, int num_groups, Tensor weight = None, Tensor bias = None, float eps = 1e-05)\n"
+            " - (Tensor x, int num_groups, float epsilon = 1e-05, Tensor weight = None, Tensor bias = None, "
+            "DataLayout1D | DataLayout2D | DataLayout3D data_format = 'NCHW', str | None name = None)"
+        )
+
+    def safe_set_param(key: str, value: Any):
+        if key in kwargs:
+            raise TypeError(f"got multiple values for argument '{key}'")
+        kwargs[key] = value
+
+    # Map the PyTorch-style alias names onto the Paddle parameter names.
+    if 'input' in kwargs:
+        safe_set_param('x', kwargs.pop('input'))
+
+    if 'eps' in kwargs:
+        safe_set_param('epsilon', kwargs.pop('eps'))
+
+    # A third positional argument that is not a float cannot be `epsilon`, so the
+    # call uses the PyTorch ordering (weight, bias, eps); remap the positionals.
+    if len_args >= 3 and not isinstance(args[2], float):
+        param_keys = ["weight", "bias", "epsilon"]
+        for idx in range(min(len_args - 2, len(param_keys))):
+            safe_set_param(param_keys[idx], args[idx + 2])
+        args = args[:2]
+    return _group_norm_wrapper(*args, **kwargs)
+
+
+def _group_norm_wrapper(
+    x: Tensor,
+    num_groups: int,
+    epsilon: float = 1e-05,
+    weight: Tensor | None = None,
+    bias: Tensor | None = None,
+    data_format: DataLayout1D | DataLayout2D | DataLayout3D = 'NCHW',
+    name: str | None = None,
+) -> Tensor:
     if data_format not in ['NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC']:
         raise ValueError("unsupported data layout:" + data_format)
 
@@ -794,3 +860,6 @@ def group_norm(
     )
 
     return helper.append_activation(group_norm_out)
+
+
+group_norm.__signature__ = inspect.signature(_group_norm_wrapper)
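
For reference, a minimal usage sketch of the two calling conventions described in the updated docstring. The tensor shapes, group count, and variable names below are illustrative assumptions, not part of the patch:

import paddle
import paddle.nn.functional as F

x = paddle.rand([2, 6, 4, 4])   # NCHW input with 6 channels
w = paddle.ones([6])
b = paddle.zeros([6])

# Original Paddle signature: epsilon is the third positional parameter.
out_paddle = F.group_norm(x, 3, epsilon=1e-05, weight=w, bias=b, data_format='NCHW')

# PyTorch-compatible aliases: `input` is mapped to `x`, `eps` to `epsilon`.
out_aliased = F.group_norm(input=x, num_groups=3, weight=w, bias=b, eps=1e-05)

# Positional PyTorch-style call: args[2] is a Tensor rather than a float, so the
# dispatcher assigns the trailing positionals to (weight, bias) instead of epsilon.
out_positional = F.group_norm(x, 3, w, b)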
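
A small sketch of what the final __signature__ assignment buys: inspect.signature reads the __signature__ attribute before falling back to the wrapped callable, so introspection keeps reporting the Paddle-style parameter list rather than the generic dispatcher signature.

import inspect
import paddle.nn.functional as F

# Reports x, num_groups, epsilon, weight, bias, data_format, name
# rather than (*args: Any, **kwargs: Any).
print(inspect.signature(F.group_norm))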