@@ -42,15 +42,15 @@ def test_embedding_layer_with_token_type(self):
4242 output = layer (input_seq , token_type )
4343 output_shape = output .shape .as_list ()
4444 expected_shape = [1 , 4 , 16 ]
45- self .assertListEqual (output_shape , expected_shape , msg = None )
45+ self .assertListEqual (output_shape , expected_shape )
4646
def test_embedding_layer_without_token_type(self):
  """Embedding lookup without token_type ids yields a [batch, seq, hidden] tensor."""
  embedding_layer = mobile_bert_layers.MobileBertEmbedding(10, 8, 2, 16)
  word_ids = tf.Variable([[2, 3, 4, 5]])
  embeddings = embedding_layer(word_ids)
  # Batch of 1, sequence length 4, output embedding width 16.
  self.assertListEqual(embeddings.shape.as_list(), [1, 4, 16])
5454
5555 def test_embedding_layer_get_config (self ):
5656 layer = mobile_bert_layers .MobileBertEmbedding (
@@ -72,7 +72,7 @@ def test_no_norm(self):
7272 output = layer (feature )
7373 output_shape = output .shape .as_list ()
7474 expected_shape = [2 , 3 , 4 ]
75- self .assertListEqual (output_shape , expected_shape , msg = None )
75+ self .assertListEqual (output_shape , expected_shape )
7676
7777 @parameterized .named_parameters (('with_kq_shared_bottleneck' , False ),
7878 ('without_kq_shared_bottleneck' , True ))
@@ -83,7 +83,17 @@ def test_transfomer_kq_shared_bottleneck(self, is_kq_shared):
8383 output = layer (feature )
8484 output_shape = output .shape .as_list ()
8585 expected_shape = [2 , 3 , 512 ]
86- self .assertListEqual (output_shape , expected_shape , msg = None )
86+ self .assertListEqual (output_shape , expected_shape )
87+
def test_transformer_with_squared_relu(self):
  """A transformer block built with 'squared_relu' activation preserves input shape."""
  hidden_states = tf.random.uniform([2, 3, 512])
  transformer = mobile_bert_layers.MobileBertTransformer(
      intermediate_act_fn='squared_relu')
  outputs = transformer(hidden_states)
  # Output keeps the [batch, seq_len, hidden] shape of the input.
  self.assertListEqual(outputs.shape.as_list(), [2, 3, 512])
8797
8898 def test_transfomer_with_mask (self ):
8999 feature = tf .random .uniform ([2 , 3 , 512 ])
@@ -94,7 +104,7 @@ def test_transfomer_with_mask(self):
94104 output = layer (feature , input_mask )
95105 output_shape = output .shape .as_list ()
96106 expected_shape = [2 , 3 , 512 ]
97- self .assertListEqual (output_shape , expected_shape , msg = None )
107+ self .assertListEqual (output_shape , expected_shape )
98108
99109 def test_transfomer_return_attention_score (self ):
100110 sequence_length = 5
@@ -104,8 +114,7 @@ def test_transfomer_return_attention_score(self):
104114 num_attention_heads = num_attention_heads )
105115 _ , attention_score = layer (feature , return_attention_scores = True )
106116 expected_shape = [2 , num_attention_heads , sequence_length , sequence_length ]
107- self .assertListEqual (
108- attention_score .shape .as_list (), expected_shape , msg = None )
117+ self .assertListEqual (attention_score .shape .as_list (), expected_shape )
109118
110119 def test_transformer_get_config (self ):
111120 layer = mobile_bert_layers .MobileBertTransformer (
0 commit comments