Skip to content

Commit e41d434

Browse files
authored
[Bugfix] Fix entropy calculation bugs (#5941)
* fix entropy bugs
1 parent b9663e5 commit e41d434

File tree

3 files changed

+70
-16
lines changed

3 files changed

+70
-16
lines changed

fastdeploy/model_executor/entropy_utils.py

Lines changed: 13 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,19 @@
1919
from fastdeploy.utils import data_processor_logger
2020

2121

22+
def get_entropy(logits):
23+
# Check for -inf values in logits
24+
if paddle.any(paddle.isinf(logits) & (logits < 0)):
25+
data_processor_logger.debug("Detected -inf values in logits, clipping to minimum value")
26+
logits = paddle.clip(logits, min=1e-9)
27+
28+
a0 = logits - paddle.max(logits, axis=-1, keepdim=True)
29+
ea0 = paddle.exp(a0)
30+
z0 = paddle.sum(ea0, axis=-1, keepdim=True)
31+
p0 = ea0 / z0
32+
return paddle.sum(p0 * (paddle.log(z0) - a0), axis=-1)
33+
34+
2235
def calculate_logits_entropy(logits, share_inputs, temperature):
2336
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
2437
real_seq_lens = paddle.where(
@@ -27,13 +40,6 @@ def calculate_logits_entropy(logits, share_inputs, temperature):
2740
share_inputs["seq_lens_this_time"].squeeze(1),
2841
)
2942

30-
def get_entropy(logits):
31-
a0 = logits - paddle.max(logits, axis=-1, keepdim=True)
32-
ea0 = paddle.exp(a0)
33-
z0 = paddle.sum(ea0, axis=-1, keepdim=True)
34-
p0 = ea0 / z0
35-
return paddle.sum(p0 * (paddle.log(z0) - a0), axis=-1)
36-
3743
batch_indices = paddle.arange(real_bsz, dtype="int32")
3844
batch_id_per_token = paddle.repeat_interleave(batch_indices, real_seq_lens)
3945
for i in range(logits.shape[0]):
@@ -77,13 +83,6 @@ def speculate_calculate_logits_entropy(logits, share_inputs, temperature):
7783
for i in range(total_accepted_num):
7884
accepted_logits[i] = logits[accepted_idx[i]]
7985

80-
def get_entropy(logits):
81-
a0 = logits - paddle.max(logits, axis=-1, keepdim=True)
82-
ea0 = paddle.exp(a0)
83-
z0 = paddle.sum(ea0, axis=-1, keepdim=True)
84-
p0 = ea0 / z0
85-
return paddle.sum(p0 * (paddle.log(z0) - a0), axis=-1)
86-
8786
batch_indices = paddle.arange(share_inputs["accept_num"].shape[0], dtype="int32")
8887
batch_id_per_token = paddle.repeat_interleave(batch_indices, share_inputs["accept_num"])
8988
for i in range(accepted_logits.shape[0]):

fastdeploy/worker/gpu_model_runner.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1875,15 +1875,14 @@ def _dummy_sampler_run(
18751875
group=self.parallel_config.tp_group,
18761876
)
18771877
else:
1878-
self.sampler(
1878+
sampler_output = self.sampler(
18791879
logits,
18801880
self.sampling_metadata,
18811881
self.model_config.max_model_len,
18821882
self.share_inputs,
18831883
accept_all_drafts,
18841884
reject_all_drafts,
18851885
)
1886-
sampler_output = None
18871886
if self.parallel_config.tensor_parallel_size > 1:
18881887
paddle.distributed.broadcast(
18891888
self.share_inputs["accept_tokens"],

tests/model_executor/test_entropy_utils.py

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,34 @@ def test_entropy_list_clear(self):
107107

108108
self.assertAlmostEqual(share_inputs["entropy_list"][2][0], 0.0003187173861078918, places=6)
109109

110+
def test_negative_inf_clip(self):
111+
share_inputs = {
112+
"seq_lens_this_time": paddle.to_tensor([[1], [0], [15]], dtype="int32"),
113+
"seq_lens_encoder": paddle.to_tensor([[0], [0], [15]], dtype="int32"),
114+
"seq_lens_decoder": paddle.to_tensor([[30], [0], [15]], dtype="int32"),
115+
"entropy_list": [[], [], []],
116+
"stop_flags": paddle.to_tensor([[False], [True], [False]], dtype="bool"),
117+
"req_ids": ["req_1", "req_2", "req_3"],
118+
}
119+
120+
logits = paddle.to_tensor(
121+
[
122+
[10.0, 1.0, -float("inf")],
123+
[1.0, 1.0, -float("inf")],
124+
],
125+
dtype="float32",
126+
)
127+
temperature = paddle.ones([3], dtype="float32")
128+
129+
calculate_logits_entropy(logits, share_inputs, temperature)
130+
131+
self.assertEqual(len(share_inputs["entropy_list"][0]), 1)
132+
self.assertEqual(len(share_inputs["entropy_list"][1]), 0)
133+
self.assertEqual(len(share_inputs["entropy_list"][2]), 1)
134+
135+
self.assertAlmostEqual(share_inputs["entropy_list"][0][0], 0.0017332095885649323, places=6)
136+
self.assertAlmostEqual(share_inputs["entropy_list"][2][0], 1.017357349395752, places=6)
137+
110138

111139
class TestSpeculateCalculateLogitsEntropy(unittest.TestCase):
112140

@@ -207,6 +235,34 @@ def test_entropy_list_clear(self):
207235

208236
self.assertAlmostEqual(share_inputs["entropy_list"][1][0], 0.0024676250759512186, places=6)
209237

238+
def test_negative_inf_clip(self):
239+
share_inputs = {
240+
"seq_lens_this_time": paddle.to_tensor([[1], [0], [15]], dtype="int32"),
241+
"seq_lens_encoder": paddle.to_tensor([[0], [0], [15]], dtype="int32"),
242+
"seq_lens_decoder": paddle.to_tensor([[30], [0], [15]], dtype="int32"),
243+
"entropy_list": [[], [], []],
244+
"stop_flags": paddle.to_tensor([[False], [True], [False]], dtype="bool"),
245+
"req_ids": ["req_1", "req_2", "req_3"],
246+
}
247+
248+
logits = paddle.to_tensor(
249+
[
250+
[10.0, 1.0, -float("inf")],
251+
[1.0, 1.0, -float("inf")],
252+
],
253+
dtype="float32",
254+
)
255+
temperature = paddle.ones([3], dtype="float32")
256+
257+
calculate_logits_entropy(logits, share_inputs, temperature)
258+
259+
self.assertEqual(len(share_inputs["entropy_list"][0]), 1)
260+
self.assertEqual(len(share_inputs["entropy_list"][1]), 0)
261+
self.assertEqual(len(share_inputs["entropy_list"][2]), 1)
262+
263+
self.assertAlmostEqual(share_inputs["entropy_list"][0][0], 0.0017332095885649323, places=6)
264+
self.assertAlmostEqual(share_inputs["entropy_list"][2][0], 1.017357349395752, places=6)
265+
210266

211267
if __name__ == "__main__":
212268
unittest.main()

0 commit comments

Comments (0)