@@ -1,4 +1,5 @@
 import pickle
+import copy
 import concurrent.futures
 import pytest
 import numpy as np
@@ -374,6 +375,27 @@ def test_decode(self): |
         stream = DecodeStream(ids=[0, 1, 2])
         assert stream.step(tokenizer, 3) == " john"

+    def test_decode_stream_copy_and_prefix_ids(self):
+        tokenizer = Tokenizer(BPE())
+        tokenizer.add_tokens(["my", "name", "is", "john"])
+        token_ids = [0, 1, 2, 3]
+
+        stream = DecodeStream(skip_special_tokens=False)
+        assert stream.step(tokenizer, token_ids[0]) == "my"
+        assert stream.step(tokenizer, token_ids[1]) == " name"
+        stream_copy = copy.copy(stream)
+        assert stream.step(tokenizer, token_ids[2]) == " is"
+        assert stream_copy.step(tokenizer, token_ids[2]) == " is"
+        assert stream.step(tokenizer, token_ids[3]) == " john"
+        assert stream_copy.step(tokenizer, token_ids[3]) == " john"
+
+        stream_steps = DecodeStream([])
+        last_chunk = None
+        for tid in token_ids:
+            last_chunk = stream_steps.step(tokenizer, tid)
+        stream_prefill = DecodeStream(token_ids[:-1])
+        assert stream_prefill.step(tokenizer, token_ids[-1]) == last_chunk
+
     def test_decode_stream_fallback(self):
         tokenizer = Tokenizer.from_pretrained("gpt2")
         # tokenizer.decode([255]) fails because it's a fallback
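The added test pins down two behaviours of DecodeStream: copy.copy() on a mid-stream instance yields an independent stream with identical decode state, and constructing a stream from prefix ids is equivalent to stepping through those ids one at a time. A minimal usage sketch combining both, assuming the tokenizers.decoders.DecodeStream import path and the positional ids argument exercised in the diff above:

import copy

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tokenizer = Tokenizer.from_pretrained("gpt2")

# Prefill: the stream behaves as if the prompt ids had already been stepped
# through, so only newly generated ids need decoding.
prompt_ids = tokenizer.encode("my name is").ids
stream = DecodeStream(prompt_ids)

# Fork: an independent copy with the same state, e.g. to decode two
# candidate continuations of one shared prefix.
fork = copy.copy(stream)

for tid in tokenizer.encode(" john").ids:
    chunk = stream.step(tokenizer, tid)  # may be None until a printable chunk forms
    if chunk is not None:
        print(chunk, end="")

Stepping the fork with a different continuation leaves the original stream untouched, which is exactly what the interleaved assertions in the test verify.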