
Commit a26435f

[API compatibility] update paddle group_norm api (#76149)
* [API compatibility] update paddle group_norm api
* update: compatible with old parameter order
* update: remove decorator
1 parent 41c476b commit a26435f

File tree

2 files changed, +196 -2 lines changed


python/paddle/nn/functional/norm.py

Lines changed: 71 additions & 2 deletions
```diff
@@ -14,8 +14,11 @@
 
 from __future__ import annotations
 
+import inspect
 import numbers
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
+
+from typing_extensions import overload
 
 import paddle
 from paddle import _C_ops, in_dynamic_mode
@@ -681,6 +684,7 @@ def local_response_norm(
     return res
 
 
+@overload
 def group_norm(
     x: Tensor,
     num_groups: int,
```
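Here `overload` (from `typing_extensions`) only declares an alternative signature for static type checkers: stub bodies written as `...` are never executed, and the single undecorated `def group_norm(*args, **kwargs)` added in the next hunk carries the runtime behavior. A minimal sketch of the pattern, with illustrative names not taken from this patch:

```python
from typing_extensions import overload

@overload
def double(x: int) -> int: ...
@overload
def double(x: str) -> str: ...

def double(x):
    # The one runtime implementation behind both declared signatures.
    return x + x
```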
```diff
@@ -689,16 +693,40 @@ def group_norm(
     bias: Tensor | None = None,
     data_format: DataLayout1D | DataLayout2D | DataLayout3D = 'NCHW',
     name: str | None = None,
-) -> Tensor:
+) -> Tensor: ...
+
+
+@overload
+def group_norm(
+    input: Tensor,
+    num_groups: int,
+    weight: Tensor | None = None,
+    bias: Tensor | None = None,
+    eps: float = 1e-05,
+) -> Tensor: ...
+
+
+def group_norm(*args: Any, **kwargs: Any) -> Tensor:
     """
     nn.GroupNorm is recommended.
     For more information, please refer to :ref:`api_paddle_nn_GroupNorm` .
 
+    This function has two functionalities, depending on the parameters passed:
+
+    1. ``group_norm(Tensor input, int num_groups, Tensor weight = None, Tensor bias = None, float eps = 1e-05)``:
+        PyTorch compatible group_norm.
+
+    2. ``group_norm(Tensor x, int num_groups, float epsilon = 1e-05, Tensor weight = None, Tensor bias = None,
+        DataLayout1D | DataLayout2D | DataLayout3D data_format = 'NCHW', str | None name = None)``:
+        The original paddle.nn.functional.group_norm, see the following docs.
+
     Parameters:
         x(Tensor): Input Tensor with shape: attr:`(batch, num_features, *)`.
+            alias: ``input``.
         num_groups(int): The number of groups that divided from channels.
         epsilon(float, optional): The small value added to the variance to prevent
             division by zero. Default: 1e-05.
+            alias: ``eps``.
         weight(Tensor, optional): The weight Tensor of group_norm, with shape: attr:`[num_channels]`.
             Default: None.
         bias(Tensor, optional): The bias Tensor of group_norm, with shape: attr:`[num_channels]`.
```
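Both documented call forms are exercised by the new tests further down; a minimal usage sketch, assuming a working Paddle install with this patch applied:

```python
import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 6, 4, 4])
w = paddle.ones([6])
b = paddle.zeros([6])

# PyTorch-compatible form: (input, num_groups, weight, bias, eps)
y1 = F.group_norm(input=x, num_groups=3, weight=w, bias=b, eps=1e-5)

# Original paddle form: (x, num_groups, epsilon, weight, bias, ...)
y2 = F.group_norm(x, 3, 1e-5, weight=w, bias=b)
```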
```diff
@@ -744,6 +772,44 @@ def group_norm(
              [[-1.34163547, -0.44721183],
               [ 0.44721183, 1.34163547]]]])
     """
+
+    len_args = len(args)
+    if len_args + len(kwargs) < 2:
+        raise TypeError(
+            f"Too few arguments in the function call: {len_args}, {len(kwargs)}. Expect one of: \n"
+            " - (Tensor input, int num_groups, Tensor weight = None, Tensor bias = None, float eps = 1e-05)\n"
+            " - (Tensor x, int num_groups, float epsilon = 1e-05, Tensor weight = None, Tensor bias = None, "
+            "DataLayout1D | DataLayout2D | DataLayout3D data_format = 'NCHW', str | None name = None)"
+        )
+
+    def safe_set_param(key: str, value: Any):
+        if key in kwargs:
+            raise TypeError(f"got multiple values for argument '{key}'")
+        kwargs[key] = value
+
+    if 'input' in kwargs:
+        safe_set_param('x', kwargs.pop('input'))
+
+    if 'eps' in kwargs:
+        safe_set_param('epsilon', kwargs.pop('eps'))
+
+    if len_args >= 3 and not isinstance(args[2], float):
+        param_keys = ["weight", "bias", "epsilon"]
+        for idx in range(min(len_args - 2, len(param_keys))):
+            safe_set_param(param_keys[idx], args[idx + 2])
+        args = args[:2]
+    return _group_norm_wrapper(*args, **kwargs)
+
+
+def _group_norm_wrapper(
+    x: Tensor,
+    num_groups: int,
+    epsilon: float = 1e-05,
+    weight: Tensor | None = None,
+    bias: Tensor | None = None,
+    data_format: DataLayout1D | DataLayout2D | DataLayout3D = 'NCHW',
+    name: str | None = None,
+) -> Tensor:
     if data_format not in ['NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC']:
         raise ValueError("unsupported data layout:" + data_format)
 
```
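The dispatch rests on one heuristic: in the PyTorch order the third positional argument is a weight tensor, while in the paddle order it is the float `epsilon`, so `isinstance(args[2], float)` distinguishes the two. A standalone, paddle-free sketch of that remapping, with `dispatch` and `_wrapper` as hypothetical stand-ins for `group_norm` and `_group_norm_wrapper`:

```python
from typing import Any

def _wrapper(x, num_groups, epsilon=1e-05, weight=None, bias=None):
    # Stand-in for _group_norm_wrapper: just report the canonical arguments.
    return dict(x=x, num_groups=num_groups, epsilon=epsilon,
                weight=weight, bias=bias)

def dispatch(*args: Any, **kwargs: Any):
    def safe_set_param(key: str, value: Any):
        if key in kwargs:  # alias and canonical name both supplied
            raise TypeError(f"got multiple values for argument '{key}'")
        kwargs[key] = value

    if 'input' in kwargs:  # torch-style alias for x
        safe_set_param('x', kwargs.pop('input'))
    if 'eps' in kwargs:    # torch-style alias for epsilon
        safe_set_param('epsilon', kwargs.pop('eps'))
    # Third positional arg not a float => torch order (weight, bias, eps).
    if len(args) >= 3 and not isinstance(args[2], float):
        for key, value in zip(["weight", "bias", "epsilon"], args[2:]):
            safe_set_param(key, value)
        args = args[:2]
    return _wrapper(*args, **kwargs)

print(dispatch('X', 3, 'W', 'B', 1e-4))    # torch order: epsilon=1e-4
print(dispatch('X', 3, 1e-4, weight='W'))  # paddle order: epsilon=1e-4
```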

```diff
@@ -794,3 +860,6 @@ def group_norm(
     )
 
     return helper.append_activation(group_norm_out)
+
+
+group_norm.__signature__ = inspect.signature(_group_norm_wrapper)
```
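Because the public `group_norm` now has a bare `(*args, **kwargs)` signature, the last line copies the wrapper's signature onto it so `inspect.signature()` and `help()` keep reporting the canonical parameters. A paddle-independent sketch of the mechanism:

```python
import inspect

def _impl(x, num_groups, epsilon=1e-05):
    return (x, num_groups, epsilon)

def api(*args, **kwargs):
    return _impl(*args, **kwargs)

# inspect.signature() honors a manually attached __signature__.
api.__signature__ = inspect.signature(_impl)
print(inspect.signature(api))  # (x, num_groups, epsilon=1e-05)
```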

test/legacy_test/test_group_norm_op_v2.py

Lines changed: 125 additions & 0 deletions
```diff
@@ -618,5 +618,130 @@ def test_group_norm_cpu_with_optional_grad_nhwc(self):
         np.testing.assert_equal(dx.numpy(), dx_ref.numpy())
 
 
+class TestGroupNormParam(unittest.TestCase):
+    def setUp(self):
+        self.x_tensor = paddle.randn([2, 6, 4, 4], dtype='float32')
+        self.weight_tensor = paddle.randn([6], dtype='float32')
+        self.bias_tensor = paddle.randn([6], dtype='float32')
+
+    def test_alias_input_for_x(self):
+        """test parameter alias input/x"""
+        out_with_input = paddle.nn.functional.group_norm(
+            input=self.x_tensor,
+            num_groups=3,
+            weight=self.weight_tensor,
+            bias=self.bias_tensor,
+            eps=1e-5,
+        )
+        out_with_x = paddle.nn.functional.group_norm(
+            x=self.x_tensor,
+            num_groups=3,
+            weight=self.weight_tensor,
+            bias=self.bias_tensor,
+            eps=1e-5,
+        )
+
+        np.testing.assert_array_equal(
+            out_with_input.numpy(), out_with_x.numpy()
+        )
+
+    def test_params_consistency(self):
+        """test both paddle and torch formats work."""
+        out_old = paddle.nn.functional.group_norm(
+            self.x_tensor,
+            3,
+            1e-5,
+            weight=self.weight_tensor,
+            bias=self.bias_tensor,
+        )
+
+        out_new = paddle.nn.functional.group_norm(
+            x=self.x_tensor,
+            num_groups=3,
+            weight=self.weight_tensor,
+            bias=self.bias_tensor,
+            eps=1e-5,
+        )
+
+        np.testing.assert_array_equal(out_old.numpy(), out_new.numpy())
+
+    def test_params_1(self):
+        """test all args with torch format"""
+        try:
+            out = paddle.nn.functional.group_norm(
+                self.x_tensor,
+                3,
+                self.weight_tensor,
+                self.bias_tensor,
+                1e-5,
+            )
+            self.assertTrue(True, "Function call succeeded without error")
+        except Exception as e:
+            self.fail(f"Function raised an unexpected exception: {e}")
+
+    def test_params_2(self):
+        """test all kwargs with torch format"""
+        try:
+            out = paddle.nn.functional.group_norm(
+                input=self.x_tensor,
+                num_groups=3,
+                weight=self.weight_tensor,
+                bias=self.bias_tensor,
+                epsilon=1e-5,
+            )
+            self.assertTrue(True, "Function call succeeded without error")
+        except Exception as e:
+            self.fail(f"Function raised an unexpected exception: {e}")
+
+    def test_params_3(self):
+        """test of passing both args and kwargs parameters"""
+        try:
+            out1 = paddle.nn.functional.group_norm(
+                self.x_tensor,
+                3,
+                weight=self.weight_tensor,
+                bias=self.bias_tensor,
+                epsilon=1e-5,
+            )
+            out2 = paddle.nn.functional.group_norm(
+                self.x_tensor,
+                3,
+                1e-5,
+                weight=self.weight_tensor,
+                bias=self.bias_tensor,
+            )
+            self.assertTrue(True, "Function call succeeded without error")
+        except Exception as e:
+            self.fail(f"Function raised an unexpected exception: {e}")
+
+    def test_params_4(self):
+        """test default parameters"""
+        try:
+            out1 = paddle.nn.functional.group_norm(
+                self.x_tensor,
+                3,
+                self.weight_tensor,
+            )
+            out2 = paddle.nn.functional.group_norm(self.x_tensor, 3, 1e-5)
+            self.assertTrue(True, "Function call succeeded without error")
+        except Exception as e:
+            self.fail(f"Function raised an unexpected exception: {e}")
+
+    def test_params_5(self):
+        """test duplicate parameters"""
+        with self.assertRaises(TypeError):
+            out_1 = paddle.nn.functional.group_norm(
+                x=self.x_tensor,
+                input=self.x_tensor,
+                num_groups=3,
+            )
+        with self.assertRaises(TypeError):
+            out_2 = paddle.nn.functional.group_norm(
+                self.x_tensor,
+                input=self.x_tensor,
+                num_groups=3,
+            )
+
+
 if __name__ == '__main__':
     unittest.main()
```
