@@ -107,6 +107,34 @@ def test_entropy_list_clear(self):
107107
108108 self .assertAlmostEqual (share_inputs ["entropy_list" ][2 ][0 ], 0.0003187173861078918 , places = 6 )
109109
110+ def test_negative_inf_clip (self ):
111+ share_inputs = {
112+ "seq_lens_this_time" : paddle .to_tensor ([[1 ], [0 ], [15 ]], dtype = "int32" ),
113+ "seq_lens_encoder" : paddle .to_tensor ([[0 ], [0 ], [15 ]], dtype = "int32" ),
114+ "seq_lens_decoder" : paddle .to_tensor ([[30 ], [0 ], [15 ]], dtype = "int32" ),
115+ "entropy_list" : [[], [], []],
116+ "stop_flags" : paddle .to_tensor ([[False ], [True ], [False ]], dtype = "bool" ),
117+ "req_ids" : ["req_1" , "req_2" , "req_3" ],
118+ }
119+
120+ logits = paddle .to_tensor (
121+ [
122+ [10.0 , 1.0 , - float ("inf" )],
123+ [1.0 , 1.0 , - float ("inf" )],
124+ ],
125+ dtype = "float32" ,
126+ )
127+ temperature = paddle .ones ([3 ], dtype = "float32" )
128+
129+ calculate_logits_entropy (logits , share_inputs , temperature )
130+
131+ self .assertEqual (len (share_inputs ["entropy_list" ][0 ]), 1 )
132+ self .assertEqual (len (share_inputs ["entropy_list" ][1 ]), 0 )
133+ self .assertEqual (len (share_inputs ["entropy_list" ][2 ]), 1 )
134+
135+ self .assertAlmostEqual (share_inputs ["entropy_list" ][0 ][0 ], 0.0017332095885649323 , places = 6 )
136+ self .assertAlmostEqual (share_inputs ["entropy_list" ][2 ][0 ], 1.017357349395752 , places = 6 )
137+
110138
111139class TestSpeculateCalculateLogitsEntropy (unittest .TestCase ):
112140
@@ -207,6 +235,34 @@ def test_entropy_list_clear(self):
207235
208236 self .assertAlmostEqual (share_inputs ["entropy_list" ][1 ][0 ], 0.0024676250759512186 , places = 6 )
209237
238+ def test_negative_inf_clip (self ):
239+ share_inputs = {
240+ "seq_lens_this_time" : paddle .to_tensor ([[1 ], [0 ], [15 ]], dtype = "int32" ),
241+ "seq_lens_encoder" : paddle .to_tensor ([[0 ], [0 ], [15 ]], dtype = "int32" ),
242+ "seq_lens_decoder" : paddle .to_tensor ([[30 ], [0 ], [15 ]], dtype = "int32" ),
243+ "entropy_list" : [[], [], []],
244+ "stop_flags" : paddle .to_tensor ([[False ], [True ], [False ]], dtype = "bool" ),
245+ "req_ids" : ["req_1" , "req_2" , "req_3" ],
246+ }
247+
248+ logits = paddle .to_tensor (
249+ [
250+ [10.0 , 1.0 , - float ("inf" )],
251+ [1.0 , 1.0 , - float ("inf" )],
252+ ],
253+ dtype = "float32" ,
254+ )
255+ temperature = paddle .ones ([3 ], dtype = "float32" )
256+
257+ calculate_logits_entropy (logits , share_inputs , temperature )
258+
259+ self .assertEqual (len (share_inputs ["entropy_list" ][0 ]), 1 )
260+ self .assertEqual (len (share_inputs ["entropy_list" ][1 ]), 0 )
261+ self .assertEqual (len (share_inputs ["entropy_list" ][2 ]), 1 )
262+
263+ self .assertAlmostEqual (share_inputs ["entropy_list" ][0 ][0 ], 0.0017332095885649323 , places = 6 )
264+ self .assertAlmostEqual (share_inputs ["entropy_list" ][2 ][0 ], 1.017357349395752 , places = 6 )
265+
210266
211267if __name__ == "__main__" :
212268 unittest .main ()
0 commit comments