Skip to content

Commit dc4d738

Browse files
charlesbeattiecopybara-github
authored andcommitted
Internal change.
PiperOrigin-RevId: 941222035
1 parent ee6026f commit dc4d738

2 files changed

Lines changed: 174 additions & 7 deletions

File tree

python/google/protobuf/internal/message_test.py

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -531,16 +531,85 @@ def testAssignRepeatedField(self, message_module):
531531
self.assertEqual([1, 2, 3, 4], msg.payload.repeated_int32)
532532

533533
def testRepeatedFieldSelfSliceAssignment(self, message_module):
534-
msg = message_module.NestedTestAllTypes()
535-
msg.payload.repeated_int32[:] = [1, 2, 3, 4]
536-
msg.payload.repeated_int32[:] = msg.payload.repeated_int32
534+
msg = message_module.NestedTestAllTypes()
535+
for field_name in [
536+
'repeated_int32',
537+
'repeated_int64',
538+
'repeated_uint32',
539+
'repeated_uint64',
540+
'repeated_sint32',
541+
'repeated_sint64',
542+
'repeated_fixed32',
543+
'repeated_fixed64',
544+
'repeated_sfixed32',
545+
'repeated_sfixed64',
546+
]:
547+
field = getattr(msg.payload, field_name)
548+
field[:] = [1, 2, 3, 4]
549+
field[:] = field
550+
self.assertEqual([1, 2, 3, 4], field)
551+
field[:] = field[1:-1]
552+
self.assertEqual([2, 3], field)
553+
for field_name in [
554+
'repeated_float',
555+
'repeated_double',
556+
]:
557+
field = getattr(msg.payload, field_name)
558+
field[:] = [1.25, 2.25, 3.25, 4.25]
559+
field[:] = field
560+
self.assertEqual([1.25, 2.25, 3.25, 4.25], field)
561+
562+
def testRepeatedFieldExtendWithPartialSuccess(self, message_module):
563+
msg = message_module.NestedTestAllTypes()
564+
msg.payload.repeated_int32[:] = [1, 2, 3, 4]
565+
with self.assertRaises(ValueError):
566+
msg.payload.repeated_int32.extend([4, 5, 6, 2**34])
567+
if api_implementation.Type() == 'cpp':
568+
self.assertEqual([1, 2, 3, 4, 4, 5, 6], msg.payload.repeated_int32)
569+
else:
537570
self.assertEqual([1, 2, 3, 4], msg.payload.repeated_int32)
538571

572+
def testRepeatedFieldSubSliceAssignment(self, message_module):
573+
msg = message_module.NestedTestAllTypes()
574+
msg.payload.repeated_int32[:] = range(1, 6)
575+
msg.payload.repeated_int32[1:3] = msg.payload.repeated_int32[2:4]
576+
self.assertEqual([1, 3, 4, 4, 5], msg.payload.repeated_int32)
577+
msg.payload.repeated_int32.extend(msg.payload.repeated_int32[1:3])
578+
self.assertEqual([1, 3, 4, 4, 5, 3, 4], msg.payload.repeated_int32)
579+
580+
def testRepeatedFieldDifferentTypeSliceAssignment(self, message_module):
581+
msg1 = message_module.NestedTestAllTypes()
582+
msg2 = message_module.NestedTestAllTypes()
583+
# int64 -> int32
584+
msg2.payload.repeated_int64[:] = [1, 2, 3, 4]
585+
msg1.payload.repeated_int32[:] = msg2.payload.repeated_int64
586+
self.assertEqual([1, 2, 3, 4], msg1.payload.repeated_int32)
587+
# int32 -> int64
588+
msg2.payload.repeated_int32[:] = [1, 2, 3, 4]
589+
msg1.payload.repeated_int64[:] = msg2.payload.repeated_int32
590+
self.assertEqual([1, 2, 3, 4], msg1.payload.repeated_int64)
591+
# int64 overflow -> int32
592+
msg2.payload.repeated_int64[:] = [1, 2, 3, 2**35]
593+
with self.assertRaises((ValueError, OverflowError, TypeError)):
594+
msg1.payload.repeated_int32[:] = msg2.payload.repeated_int64
595+
# double -> float
596+
msg2.payload.repeated_double[:] = [1.5, 2.5, 3.5]
597+
msg1.payload.repeated_float[:] = msg2.payload.repeated_double
598+
self.assertEqual([1.5, 2.5, 3.5], msg1.payload.repeated_float)
599+
# float -> double
600+
msg2.payload.repeated_float[:] = [1.5, 2.5, 3.5]
601+
msg1.payload.repeated_double[:] = msg2.payload.repeated_float
602+
self.assertEqual([1.5, 2.5, 3.5], msg1.payload.repeated_double)
603+
604+
msg2.payload.repeated_double[:] = [1.5, 2.5, 1e300]
605+
msg1.payload.repeated_float[:] = msg2.payload.repeated_double
606+
self.assertEqual([1.5, 2.5, float('inf')], msg1.payload.repeated_float)
607+
539608
def testRepeatedFieldSelfExtend(self, message_module):
540-
msg = message_module.NestedTestAllTypes()
541-
msg.payload.repeated_int32[:] = [1, 2, 3, 4]
542-
msg.payload.repeated_int32.extend(msg.payload.repeated_int32)
543-
self.assertEqual([1, 2, 3, 4] * 2, msg.payload.repeated_int32)
609+
msg = message_module.NestedTestAllTypes()
610+
msg.payload.repeated_int32[:] = [1, 2, 3, 4]
611+
msg.payload.repeated_int32.extend(msg.payload.repeated_int32)
612+
self.assertEqual([1, 2, 3, 4] * 2, msg.payload.repeated_int32)
544613

545614
def testAssignOutOfRange(self, message_module):
546615
msg = message_module.NestedTestAllTypes()

python/google/protobuf/internal/numpy/numpy_test.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,32 @@ def testNumpyFloatArrayToScalar_RaisesTypeError(self):
220220
with self.assertRaises(TypeError):
221221
message.optional_int64 = np_22_float_array
222222

223+
def testRepeatedFieldSelfSliceAssignment(self):
224+
msg = unittest_pb2.NestedTestAllTypes()
225+
msg.payload.repeated_int32[:] = np.arange(4, dtype=np.int32)
226+
msg.payload.repeated_int32[:] = np.asarray(msg.payload.repeated_int32)
227+
self.assertEqual([0, 1, 2, 3], msg.payload.repeated_int32)
228+
229+
def testNumpyArrayIsMutableCopy(self):
230+
msg = unittest_pb2.NestedTestAllTypes()
231+
msg.payload.repeated_int32[:] = np.arange(4, dtype=np.int32)
232+
arr = np.asarray(msg.payload.repeated_int32)
233+
arr[0] = 100
234+
self.assertEqual([0, 1, 2, 3], msg.payload.repeated_int32)
235+
np.testing.assert_equal([100, 1, 2, 3], arr)
236+
237+
def testNumpyDifferentIntTypeSliceAssignment(self):
238+
msg = unittest_pb2.NestedTestAllTypes()
239+
# int64 -> int32
240+
msg.payload.repeated_int32[:] = np.arange(4, dtype=np.int64)
241+
self.assertEqual([0, 1, 2, 3], msg.payload.repeated_int32)
242+
# int32 -> int64
243+
msg.payload.repeated_int64[:] = np.arange(4, dtype=np.int32)
244+
self.assertEqual([0, 1, 2, 3], msg.payload.repeated_int64)
245+
# int64 overflow -> int32
246+
with self.assertRaises((ValueError, OverflowError, TypeError)):
247+
msg.payload.repeated_int32[:] = np.array([0, 1, 2, 2**35], dtype=np.int64)
248+
223249

224250
@testing_refleaks.TestCase
225251
class NumpyFloatProtoTest(unittest.TestCase):
@@ -274,6 +300,20 @@ def testNumpyObjectArrayToScalar_RaisesTypeError(self):
274300
with self.assertRaises(TypeError):
275301
message.optional_float = np_22_object_array_float
276302

303+
def testNumpyDifferentFloatTypeSliceAssignment(self):
304+
msg = unittest_pb2.NestedTestAllTypes()
305+
# float64 -> float32
306+
msg.payload.repeated_float[:] = np.array([1.5, 2.5, 3.5], dtype=np.float64)
307+
self.assertEqual([1.5, 2.5, 3.5], msg.payload.repeated_float)
308+
# float32 -> float64
309+
msg.payload.repeated_double[:] = np.array([1.5, 2.5, 3.5], dtype=np.float32)
310+
self.assertEqual([1.5, 2.5, 3.5], msg.payload.repeated_double)
311+
# float64 overflow -> float32
312+
msg.payload.repeated_float[:] = np.array(
313+
[1.5, 2.5, 1e300], dtype=np.float64
314+
)
315+
self.assertEqual([1.5, 2.5, float('inf')], msg.payload.repeated_float)
316+
277317

278318
@testing_refleaks.TestCase
279319
class NumpyBoolProtoTest(unittest.TestCase):
@@ -749,6 +789,64 @@ def test_nparray_order(self, message_module):
749789
arr = np.array(m.repeated_int32, order='F')
750790
np.testing.assert_equal(arr, np.array([1, 2, 3]))
751791

792+
@parameterized.product(
793+
message_module=[unittest_pb2, unittest_proto3_arena_pb2],
794+
field_name=[
795+
'repeated_int32',
796+
'repeated_int64',
797+
'repeated_uint32',
798+
'repeated_uint64',
799+
'repeated_sint32',
800+
'repeated_sint64',
801+
'repeated_fixed32',
802+
'repeated_fixed64',
803+
'repeated_sfixed32',
804+
'repeated_sfixed64',
805+
],
806+
dtype=[
807+
np.int8,
808+
np.int16,
809+
np.int32,
810+
np.int64,
811+
np.uint8,
812+
np.uint16,
813+
np.uint32,
814+
np.uint64,
815+
],
816+
)
817+
def test_assign_integer_numpy_array_to_repeated(
818+
self, message_module, field_name, dtype
819+
):
820+
m = message_module.TestAllTypes()
821+
field = getattr(m, field_name)
822+
arr = np.array([0, 1, 2, 3], dtype=dtype)
823+
field[:] = arr
824+
self.assertEqual([0, 1, 2, 3], field)
825+
field[1:-1] = arr
826+
self.assertEqual([0, 0, 1, 2, 3, 3], field)
827+
828+
@parameterized.product(
829+
message_module=[unittest_pb2, unittest_proto3_arena_pb2],
830+
field_name=[
831+
'repeated_float',
832+
'repeated_double',
833+
],
834+
dtype=[
835+
np.float32,
836+
np.float64,
837+
],
838+
)
839+
def test_assign_float_numpy_array_to_repeated(
840+
self, message_module, field_name, dtype
841+
):
842+
m = message_module.TestAllTypes()
843+
field = getattr(m, field_name)
844+
arr = np.array([1.5, 2.5, 3.5], dtype=dtype)
845+
field[:] = arr
846+
self.assertEqual([1.5, 2.5, 3.5], field)
847+
field[1:-1] = arr
848+
self.assertEqual([1.5, 1.5, 2.5, 3.5, 3.5], field)
849+
752850

753851
if __name__ == '__main__':
754852
unittest.main()

0 commit comments

Comments
 (0)