Skip to content

Commit 11564da

Browse files
committed
fix Assign docstring sample shapes
1 parent 9310eb3 commit 11564da

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

python/paddle/nn/initializer/assign.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -209,61 +209,64 @@ class Assign(NumpyArrayInitializer):
209209
>>> data_1 = paddle.ones(shape=[1, 2], dtype='float32')
210210
>>> weight_attr_1 = paddle.ParamAttr(
211211
... name="linear_weight_1",
212-
... initializer=paddle.nn.initializer.Assign(np.array([2, 2])),
212+
... initializer=paddle.nn.initializer.Assign(np.array([[2, 2], [2, 2]])),
213213
... )
214214
>>> bias_attr_1 = paddle.ParamAttr(
215215
... name="linear_bias_1",
216-
... initializer=paddle.nn.initializer.Assign(np.array([2])),
216+
... initializer=paddle.nn.initializer.Assign(np.array([2, 2])),
217217
... )
218218
>>> linear_1 = paddle.nn.Linear(2, 2, weight_attr=weight_attr_1, bias_attr=bias_attr_1)
219219
>>> print(linear_1.weight.numpy())
220-
[2. 2.]
220+
[[2. 2.]
221+
[2. 2.]]
221222
>>> print(linear_1.bias.numpy())
222-
[2.]
223+
[2. 2.]
223224
224225
>>> res_1 = linear_1(data_1)
225226
>>> print(res_1.numpy())
226-
[6.]
227+
[[6. 6.]]
227228
228229
>>> # python list
229230
>>> data_2 = paddle.ones(shape=[1, 2], dtype='float32')
230231
>>> weight_attr_2 = paddle.ParamAttr(
231232
... name="linear_weight_2",
232-
... initializer=paddle.nn.initializer.Assign([2, 2]),
233+
... initializer=paddle.nn.initializer.Assign([[2, 2], [2, 2]]),
233234
... )
234235
>>> bias_attr_2 = paddle.ParamAttr(
235236
... name="linear_bias_2",
236-
... initializer=paddle.nn.initializer.Assign([2]),
237+
... initializer=paddle.nn.initializer.Assign([2, 2]),
237238
... )
238239
>>> linear_2 = paddle.nn.Linear(2, 2, weight_attr=weight_attr_2, bias_attr=bias_attr_2)
239240
>>> print(linear_2.weight.numpy())
240-
[2. 2.]
241+
[[2. 2.]
242+
[2. 2.]]
241243
>>> print(linear_2.bias.numpy())
242-
[2.]
244+
[2. 2.]
243245
244246
>>> res_2 = linear_2(data_2)
245247
>>> print(res_2.numpy())
246-
[6.]
248+
[[6. 6.]]
247249
248250
>>> # tensor
249251
>>> data_3 = paddle.ones(shape=[1, 2], dtype='float32')
250252
>>> weight_attr_3 = paddle.ParamAttr(
251253
... name="linear_weight_3",
252-
... initializer=paddle.nn.initializer.Assign(paddle.full([2], 2)),
254+
... initializer=paddle.nn.initializer.Assign(paddle.full([2, 2], 2)),
253255
... )
254256
>>> bias_attr_3 = paddle.ParamAttr(
255257
... name="linear_bias_3",
256-
... initializer=paddle.nn.initializer.Assign(paddle.full([1], 2)),
258+
... initializer=paddle.nn.initializer.Assign(paddle.full([2], 2)),
257259
... )
258260
>>> linear_3 = paddle.nn.Linear(2, 2, weight_attr=weight_attr_3, bias_attr=bias_attr_3)
259261
>>> print(linear_3.weight.numpy())
260-
[2. 2.]
262+
[[2. 2.]
263+
[2. 2.]]
261264
>>> print(linear_3.bias.numpy())
262-
[2.]
265+
[2. 2.]
263266
264267
>>> res_3 = linear_3(data_3)
265268
>>> print(res_3.numpy())
266-
[6.]
269+
[[6. 6.]]
267270
"""
268271

269272
def __init__(

0 commit comments

Comments
 (0)