Commit b3889ab

Add copy to a DecodeStream (#1930)

* Fix warnings: remove a print and remove some deprecation warnings (#1924)
* make sure the warning is just a warning
* update
* nits
* fix tests
* add copy test and prefill copy
1 parent b874abe commit b3889ab

File tree

2 files changed: +29 −0 lines

bindings/python/src/decoders.rs

Lines changed: 7 additions & 0 deletions
```diff
@@ -685,6 +685,13 @@ impl PyDecodeStream {
         ))
         .into()
     }
+
+    fn __copy__(&self) -> Self {
+        self.clone()
+    }
+
+    fn __deepcopy__(&self, _memo: &Bound<'_, PyDict>) -> Self {
+        self.clone()
+    }
 }
 
 #[cfg(test)]
```
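With these hooks in place, a `DecodeStream` can be duplicated from Python through the standard `copy` module. A minimal sketch of the enabled usage, reusing the toy tokenizer from the test below (the expected chunks follow the same decoding behaviour the test asserts):

```python
import copy

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())
tokenizer.add_tokens(["my", "name", "is", "john"])

stream = DecodeStream(skip_special_tokens=False)
assert stream.step(tokenizer, 0) == "my"

# copy.copy() now routes through the new __copy__ binding and clones the
# stream's internal state, so the fork decodes independently of the original.
fork = copy.copy(stream)
assert stream.step(tokenizer, 1) == " name"
assert fork.step(tokenizer, 1) == " name"
```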

bindings/python/tests/bindings/test_tokenizer.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -1,4 +1,5 @@
 import pickle
+import copy
 import concurrent.futures
 import pytest
 import numpy as np
@@ -374,6 +375,27 @@ def test_decode(self):
         stream = DecodeStream(ids=[0, 1, 2])
         assert stream.step(tokenizer, 3) == " john"
 
+    def test_decode_stream_copy_and_prefix_ids(self):
+        tokenizer = Tokenizer(BPE())
+        tokenizer.add_tokens(["my", "name", "is", "john"])
+        token_ids = [0, 1, 2, 3]
+
+        stream = DecodeStream(skip_special_tokens=False)
+        assert stream.step(tokenizer, token_ids[0]) == "my"
+        assert stream.step(tokenizer, token_ids[1]) == " name"
+        stream_copy = copy.copy(stream)
+        assert stream.step(tokenizer, token_ids[2]) == " is"
+        assert stream_copy.step(tokenizer, token_ids[2]) == " is"
+        assert stream.step(tokenizer, token_ids[3]) == " john"
+        assert stream_copy.step(tokenizer, token_ids[3]) == " john"
+
+        stream_steps = DecodeStream([])
+        last_chunk = None
+        for tid in token_ids:
+            last_chunk = stream_steps.step(tokenizer, tid)
+        stream_prefill = DecodeStream(token_ids[:-1])
+        assert stream_prefill.step(tokenizer, token_ids[-1]) == last_chunk
+
     def test_decode_stream_fallback(self):
         tokenizer = Tokenizer.from_pretrained("gpt2")
         # tokenizer.decode([255]) fails because its a fallback
```