Add FxHash and ShortStringOptimization. #1733

Open · wants to merge 12 commits into main
4 changes: 2 additions & 2 deletions bindings/node/src/models.rs
@@ -4,7 +4,7 @@ use crate::trainers::Trainer;
 use napi::bindgen_prelude::*;
 use napi_derive::napi;
 use serde::{Deserialize, Serialize};
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use std::path::{Path, PathBuf};
 use std::sync::{Arc, RwLock};
 use tokenizers as tk;
@@ -95,7 +95,7 @@ impl tk::Model for Model {
         self.model.as_ref()?.read().unwrap().id_to_token(id)
     }

-    fn get_vocab(&self) -> HashMap<String, u32> {
+    fn get_vocab(&self) -> FxHashMap<String, u32> {
         self
             .model
             .as_ref()
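For context on this swap: rustc-hash's `FxHashMap` is a drop-in alias for std's `HashMap` that replaces SipHash with the much cheaper Fx hash, trading hash-DoS resistance for speed on short keys such as tokens. A minimal sketch of the substitution (illustrative, not code from this PR):

```rust
use rustc_hash::FxHashMap;

fn main() {
    // Same API as std::collections::HashMap, only the hasher differs:
    // the cheap Fx function instead of DoS-resistant SipHash.
    let mut vocab: FxHashMap<String, u32> = FxHashMap::default();
    vocab.insert("hello".to_string(), 0);
    vocab.insert("##ing".to_string(), 1);
    assert_eq!(vocab.get("hello"), Some(&0));
}
```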
4 changes: 2 additions & 2 deletions bindings/node/src/tokenizer.rs
@@ -6,7 +6,7 @@ use crate::pre_tokenizers::PreTokenizer;
 use crate::processors::Processor;
 use crate::tasks::tokenizer::{DecodeBatchTask, DecodeTask, EncodeBatchTask, EncodeTask};
 use crate::trainers::Trainer;
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use tokenizers::Model as ModelTrait;

 use napi::bindgen_prelude::*;
@@ -433,7 +433,7 @@ impl Tokenizer {
     }

     #[napi]
-    pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> HashMap<String, u32> {
+    pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> FxHashMap<String, u32> {
         let with_added_tokens = with_added_tokens.unwrap_or(true);
         self.tokenizer.read().unwrap().get_vocab(with_added_tokens)
     }
2 changes: 2 additions & 0 deletions bindings/python/Cargo.toml
@@ -18,6 +18,8 @@ pyo3 = { version = "0.23", features = ["abi3", "abi3-py39", "py-clone"] }
 numpy = "0.23"
 ndarray = "0.16"
 itertools = "0.12"
+rustc-hash = "2.1.1"
+compact_str = { version = "0.8.1", features = ["serde"] }

 [dependencies.tokenizers]
 path = "../../tokenizers"
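These two dependencies carry the PR: rustc-hash supplies the faster hasher, and compact_str supplies the short-string optimization named in the title. On 64-bit targets a `CompactString` stores up to 24 bytes inline (the size of a `String`), so typical tokens avoid heap allocation entirely. A quick sketch of that behavior (assumptions mine, not from the diff):

```rust
use compact_str::{CompactString, ToCompactString};

fn main() {
    // Up to 24 bytes are stored inline; only longer strings
    // fall back to a heap allocation.
    let short = CompactString::new("##ing");
    let long = "a token that is comfortably longer than twenty-four bytes"
        .to_compact_string();
    assert!(!short.is_heap_allocated());
    assert!(long.is_heap_allocated());
}
```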
94 changes: 68 additions & 26 deletions bindings/python/src/decoders.rs
@@ -3,6 +3,7 @@ use std::sync::{Arc, RwLock};
 use crate::pre_tokenizers::from_string;
 use crate::tokenizer::PyTokenizer;
 use crate::utils::PyPattern;
+use compact_str::ToCompactString;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
@@ -91,7 +92,10 @@ impl PyDecoder {
 }

 impl Decoder for PyDecoder {
-    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode_chain<T: ToCompactString>(
+        &self,
+        tokens: Vec<T>,
+    ) -> tk::Result<Vec<impl ToCompactString>> {
         self.decoder.decode_chain(tokens)
     }
 }
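Throughout this file, `decode`/`decode_chain` become generic over compact_str's `ToCompactString` trait, so callers can pass `Vec<String>`, `Vec<&str>`, or `Vec<CompactString>` alike. A small sketch of the pattern (the helper below is illustrative, not part of tokenizers):

```rust
use compact_str::{CompactString, ToCompactString};

// Illustrative helper: accepts any string-like token type and
// normalizes to CompactString, mirroring the new trait bounds.
fn to_compact_chain<T: ToCompactString>(tokens: Vec<T>) -> Vec<CompactString> {
    tokens.iter().map(|t| t.to_compact_string()).collect()
}

fn main() {
    let from_strs = to_compact_chain(vec!["Hel", "lo"]);
    let from_strings = to_compact_chain(vec![String::from("Hel"), String::from("lo")]);
    assert_eq!(from_strs, from_strings);
}
```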
@@ -139,7 +143,12 @@ impl PyDecoder {
     /// :obj:`str`: The decoded string
     #[pyo3(text_signature = "(self, tokens)")]
     fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
-        ToPyResult(self.decoder.decode(tokens)).into()
+        ToPyResult(
+            self.decoder
+                .decode(tokens)
+                .map(|t| t.to_compact_string().to_string()),
+        )
+        .into()
     }

     fn __repr__(&self) -> PyResult<String> {
@@ -235,12 +244,12 @@ pub struct PyWordPieceDec {}
 impl PyWordPieceDec {
     #[getter]
     fn get_prefix(self_: PyRef<Self>) -> String {
-        getter!(self_, WordPiece, prefix.clone())
+        getter!(self_, WordPiece, prefix.clone().to_string())
     }

     #[setter]
     fn set_prefix(self_: PyRef<Self>, prefix: String) {
-        setter!(self_, WordPiece, prefix, prefix);
+        setter!(self_, WordPiece, prefix, prefix.to_compact_string());
     }

     #[getter]
@@ -256,7 +265,10 @@ impl PyWordPieceDec {
     #[new]
     #[pyo3(signature = (prefix = String::from("##"), cleanup = true), text_signature = "(self, prefix=\"##\", cleanup=True)")]
     fn new(prefix: String, cleanup: bool) -> (Self, PyDecoder) {
-        (PyWordPieceDec {}, WordPiece::new(prefix, cleanup).into())
+        (
+            PyWordPieceDec {},
+            WordPiece::new(prefix.to_compact_string(), cleanup).into(),
+        )
     }
 }

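The setters in this file reach `CompactString` by two routes: an explicit `.to_compact_string()` call here, and `.into()` in the decoders below. Both should land on the same representation, `.into()` going through `From<String>`. A small sketch of that equivalence (my reading of compact_str's API, not code from this PR):

```rust
use compact_str::{CompactString, ToCompactString};

fn main() {
    let s = String::from("##");
    // Two equivalent routes from String to CompactString:
    let a: CompactString = s.clone().into(); // via From<String>
    let b = s.to_compact_string();           // via ToCompactString
    assert_eq!(a, b);
}
```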
@@ -412,12 +424,12 @@ pub struct PyBPEDecoder {}
 impl PyBPEDecoder {
     #[getter]
     fn get_suffix(self_: PyRef<Self>) -> String {
-        getter!(self_, BPE, suffix.clone())
+        getter!(self_, BPE, suffix.to_string())
     }

     #[setter]
     fn set_suffix(self_: PyRef<Self>, suffix: String) {
-        setter!(self_, BPE, suffix, suffix);
+        setter!(self_, BPE, suffix, suffix.into());
     }

     #[new]
@@ -443,22 +455,27 @@ pub struct PyCTCDecoder {}
 impl PyCTCDecoder {
     #[getter]
     fn get_pad_token(self_: PyRef<Self>) -> String {
-        getter!(self_, CTC, pad_token.clone())
+        getter!(self_, CTC, pad_token.to_string())
     }

     #[setter]
     fn set_pad_token(self_: PyRef<Self>, pad_token: String) {
-        setter!(self_, CTC, pad_token, pad_token);
+        setter!(self_, CTC, pad_token, pad_token.into());
     }

     #[getter]
     fn get_word_delimiter_token(self_: PyRef<Self>) -> String {
-        getter!(self_, CTC, word_delimiter_token.clone())
+        getter!(self_, CTC, word_delimiter_token.clone()).to_string()
     }

     #[setter]
     fn set_word_delimiter_token(self_: PyRef<Self>, word_delimiter_token: String) {
-        setter!(self_, CTC, word_delimiter_token, word_delimiter_token);
+        setter!(
+            self_,
+            CTC,
+            word_delimiter_token,
+            word_delimiter_token.into()
+        );
     }

     #[getter]
@@ -526,22 +543,33 @@ impl CustomDecoder {
 }

 impl Decoder for CustomDecoder {
-    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
+    fn decode<T: ToCompactString>(&self, tokens: Vec<T>) -> tk::Result<impl ToCompactString> {
+        let tokens: Vec<String> = tokens
+            .into_iter()
+            .map(|t| t.to_compact_string().to_string())
+            .collect();
         Python::with_gil(|py| {
             let decoded = self
                 .inner
                 .call_method(py, "decode", (tokens,), None)?
-                .extract(py)?;
+                .extract::<String>(py)?;
             Ok(decoded)
         })
     }

-    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode_chain<T: ToCompactString>(
+        &self,
+        tokens: Vec<T>,
+    ) -> tk::Result<Vec<impl ToCompactString>> {
+        let tokens: Vec<String> = tokens
+            .into_iter()
+            .map(|t| t.to_compact_string().to_string())
+            .collect();
         Python::with_gil(|py| {
             let decoded = self
                 .inner
                 .call_method(py, "decode_chain", (tokens,), None)?
-                .extract(py)?;
+                .extract::<Vec<String>>(py)?;
             Ok(decoded)
         })
     }
@@ -595,10 +623,21 @@ where
 }

 impl Decoder for PyDecoderWrapper {
-    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode_chain<T: ToCompactString>(
+        &self,
+        tokens: Vec<T>,
+    ) -> tk::Result<Vec<impl ToCompactString>> {
         match self {
-            PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode_chain(tokens),
-            PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode_chain(tokens),
+            PyDecoderWrapper::Wrapped(inner) => inner
+                .read()
+                .unwrap()
+                .decode_chain(tokens)
+                .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()),
+            PyDecoderWrapper::Custom(inner) => inner
+                .read()
+                .unwrap()
+                .decode_chain(tokens)
+                .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()),
         }
     }
 }
@@ -663,14 +702,17 @@ impl PyDecodeStream {

     #[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")]
     fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> {
-        ToPyResult(tk::tokenizer::step_decode_stream(
-            &tokenizer.tokenizer,
-            id,
-            self.skip_special_tokens,
-            &mut self.ids,
-            &mut self.prefix,
-            &mut self.prefix_index,
-        ))
+        ToPyResult(
+            tk::tokenizer::step_decode_stream(
+                &tokenizer.tokenizer,
+                id,
+                self.skip_special_tokens,
+                &mut self.ids,
+                &mut self.prefix.to_compact_string(),
+                &mut self.prefix_index,
+            )
+            .map(|o| o.map(|s| s.to_string())),
+        )
         .into()
     }
 }
6 changes: 5 additions & 1 deletion bindings/python/src/encoding.rs
@@ -127,7 +127,11 @@ impl PyEncoding {
     /// :obj:`List[str]`: The list of tokens
     #[getter]
     fn get_tokens(&self) -> Vec<String> {
-        self.encoding.get_tokens().to_vec()
+        self.encoding
+            .get_tokens()
+            .iter()
+            .map(|x| x.to_string())
+            .collect()
     }

     /// The generated word indices.
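This binding change suggests `Encoding::get_tokens` now yields `CompactString` tokens internally, which must be materialized as owned `String`s for the Python side. A hypothetical mirror of that conversion (the helper name is mine, not from the PR):

```rust
use compact_str::CompactString;

// Hypothetical mirror of the binding's conversion: CompactString
// tokens are copied out as owned Strings for Python.
fn tokens_to_strings(tokens: &[CompactString]) -> Vec<String> {
    tokens.iter().map(|t| t.to_string()).collect()
}

fn main() {
    let toks = [CompactString::new("Hel"), CompactString::new("lo")];
    assert_eq!(tokens_to_strings(&toks), vec!["Hel", "lo"]);
}
```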