|
2 | 2 |
|
3 | 3 | import functools
|
4 | 4 | from concurrent.futures import ThreadPoolExecutor
|
5 |
| -from typing import AbstractSet, Collection, Literal, NoReturn, Sequence |
| 5 | +from typing import TYPE_CHECKING, AbstractSet, Collection, Literal, NoReturn, Sequence |
6 | 6 |
|
7 | 7 | import regex
|
8 | 8 |
|
9 | 9 | from tiktoken import _tiktoken
|
10 | 10 |
|
| 11 | +if TYPE_CHECKING: |
| 12 | + import numpy as np |
| 13 | + import numpy.typing as npt |
| 14 | + |
11 | 15 |
|
12 | 16 | class Encoding:
|
13 | 17 | def __init__(
|
@@ -128,6 +132,32 @@ def encode(
|
128 | 132 | text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
|
129 | 133 | return self._core_bpe.encode(text, allowed_special)
|
130 | 134 |
|
| 135 | + def encode_to_numpy( |
| 136 | + self, |
| 137 | + text: str, |
| 138 | + *, |
| 139 | + allowed_special: Literal["all"] | AbstractSet[str] = set(), # noqa: B006 |
| 140 | + disallowed_special: Literal["all"] | Collection[str] = "all", |
| 141 | + ) -> npt.NDArray[np.uint32]: |
| 142 | + """Encodes a string into tokens, returning a numpy array. |
| 143 | +
|
| 144 | + Avoids the overhead of copying the token buffer into a Python list. |
| 145 | + """ |
| 146 | + if allowed_special == "all": |
| 147 | + allowed_special = self.special_tokens_set |
| 148 | + if disallowed_special == "all": |
| 149 | + disallowed_special = self.special_tokens_set - allowed_special |
| 150 | + if disallowed_special: |
| 151 | + if not isinstance(disallowed_special, frozenset): |
| 152 | + disallowed_special = frozenset(disallowed_special) |
| 153 | + if match := _special_token_regex(disallowed_special).search(text): |
| 154 | + raise_disallowed_special_token(match.group()) |
| 155 | + |
| 156 | + import numpy as np |
| 157 | + |
| 158 | + buffer = self._core_bpe.encode_to_tiktoken_buffer(text, self.special_tokens_set) |
| 159 | + return np.frombuffer(buffer, dtype=np.uint32) |
| 160 | + |
131 | 161 | def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> list[list[int]]:
|
132 | 162 | """Encodes a list of strings into tokens, in parallel, ignoring special tokens.
|
133 | 163 |
|
@@ -332,6 +362,10 @@ def eot_token(self) -> int:
|
332 | 362 | def special_tokens_set(self) -> set[str]:
|
333 | 363 | return set(self._special_tokens.keys())
|
334 | 364 |
|
| 365 | + def is_special_token(self, token: int) -> bool: |
| 366 | + assert isinstance(token, int) |
| 367 | + return token in self._special_token_values |
| 368 | + |
335 | 369 | @property
|
336 | 370 | def n_vocab(self) -> int:
|
337 | 371 | """For backwards compatibility. Prefer to use `enc.max_token_value + 1`."""
|
|
0 commit comments