Skip to content

Commit d24b67b

Browse files
author
Lőrinc
committed
Add a _byte_pair_merge_large for worst-case scenarios
We store the ranks in a sorted tree of sorted (or linked) trees: retrieving the minimum rank is logarithmic, and each subsequent occurrence of that rank is constant time. To find the previous and next indexes (and their corresponding ranks), we keep them in arrays keyed by index, updating the affected entries after each minimum is found via the tree. Duplicates of the minimum rank are iterated in one batch rather than removed from the tree one by one; when two duplicates are neighbors, the consumed one is skipped manually.
1 parent 8f5dd7d commit d24b67b

File tree

1 file changed

+94
-1
lines changed

1 file changed

+94
-1
lines changed

src/lib.rs

+94-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
// This check is new and seems buggy (possibly with PyO3 interaction)
22
#![allow(clippy::borrow_deref_ref)]
33

4-
use std::collections::HashSet;
4+
use std::collections::{BTreeMap, BTreeSet, HashSet};
5+
use std::iter::successors;
56
use std::num::NonZeroU64;
67
use std::thread;
78

@@ -15,9 +16,22 @@ use rustc_hash::FxHashMap as HashMap;
1516

1617
type Rank = u32;
1718

19+
/// Pieces shorter than this are merged with the quadratic-but-lightweight
/// `_byte_pair_merge_small`; pieces at or above it use the linearithmic
/// `_byte_pair_merge_large`, whose heavier bookkeeping only pays off on
/// long inputs.
const LARGE_ENCODER_CHARACTER_LIMIT: usize = 500;
20+
1821
fn _byte_pair_merge(
1922
ranks: &HashMap<Vec<u8>, Rank>,
2023
piece: &[u8],
24+
) -> Vec<(usize, Rank)> {
25+
if piece.len() < LARGE_ENCODER_CHARACTER_LIMIT {
26+
_byte_pair_merge_small(ranks, piece) // Quadratic, but lightweight
27+
} else {
28+
_byte_pair_merge_large(ranks, piece) // Linearithmic, but heavy
29+
}
30+
}
31+
32+
fn _byte_pair_merge_small(
33+
ranks: &HashMap<Vec<u8>, Rank>,
34+
piece: &[u8],
2135
) -> Vec<(usize, Rank)> {
2236
let get_rank = |parts: &Vec<(usize, _)>, start_idx: usize, end_idx: usize| {
2337
*parts.get(end_idx)
@@ -56,6 +70,85 @@ fn _byte_pair_merge(
5670
parts
5771
}
5872

73+
/// Linearithmic byte-pair merge, used for pieces of at least
/// `LARGE_ENCODER_CHARACTER_LIMIT` bytes.
///
/// Pair ranks live in a `BTreeMap<Rank, BTreeSet<usize>>` — a sorted tree of
/// sorted trees — so the minimum rank is found in logarithmic time and every
/// further boundary holding that rank in (amortized) constant time. The live
/// token boundaries form a doubly-linked list stored in two flat arrays
/// (`index_prev` / `index_next`) keyed by byte index, with `usize::MAX` as
/// the "no neighbour" sentinel.
///
/// Returns the surviving boundary indexes, each paired with a placeholder
/// `Rank::MAX` (only the indexes are meaningful to the caller).
fn _byte_pair_merge_large(
    ranks: &HashMap<Vec<u8>, Rank>,
    piece: &[u8],
) -> Vec<(usize, Rank)> {
    // Rank of piece[start_idx..end_idx]; Rank::MAX when the range is out of
    // bounds, would span the entire piece, or is not in the vocabulary.
    let get_rank = |start_idx: usize, end_idx: usize| {
        *piece.get(start_idx..end_idx)
            .filter(|p| p.len() < piece.len())
            .and_then(|p| ranks.get(p))
            .unwrap_or(&Rank::MAX)
    };

    // rank -> set of boundary indexes currently starting a pair of that rank.
    let mut rank_indexes = BTreeMap::<Rank, BTreeSet<usize>>::new();
    // Per-boundary state, indexed by byte position; one extra slot for the
    // end sentinel (i == piece.len()).
    let mut index_rank = vec![Rank::MAX; piece.len() + 1];
    let mut index_prev = vec![usize::MAX; piece.len() + 1];
    let mut index_next = vec![usize::MAX; piece.len() + 1];

    // Seed every boundary with the rank of the two-byte pair starting there
    // and thread the prev/next links. The sentinel slot keeps Rank::MAX
    // (get_rank is out of range there) but is still linked and inserted.
    let mut prev_node = None;
    for i in 0..=piece.len() {
        let rank = get_rank(i, i + 2);
        index_rank[i] = rank;
        if let Some(prev) = prev_node {
            index_prev[i] = prev;
            index_next[prev] = i;
        }
        prev_node = Some(i);

        rank_indexes.entry(rank).or_default().insert(i);
    }

    // Merge while more than two tokens remain and some bucket other than the
    // Rank::MAX one is left to pop (pop_first takes the smallest key, and the
    // len() > 1 guard keeps the MAX bucket from ever being popped).
    let mut token_count = piece.len();
    while token_count > 2 && rank_indexes.len() > 1 {
        let mut skip_next = false;
        // Pop every boundary sharing the current minimum rank in one batch
        // instead of deleting them from the tree one at a time.
        if let Some((_, nodes)) = rank_indexes.pop_first() {
            for &min_node in &nodes {
                // A neighbour consumed by the previous merge in this batch is
                // flagged rather than removed from the popped set; skip it.
                if skip_next {
                    skip_next = false;
                    continue;
                }

                // NOTE(review): earlier merges in this same batch can mutate
                // index_rank entries of nodes still pending in `nodes`;
                // confirm min_rank cannot go stale relative to the popped key.
                let min_rank = index_rank[min_node];

                let prev_node = index_prev[min_node];
                let next_node = index_next[min_node];
                let next_next_node = index_next[next_node];
                let next_next_next_node = index_next[next_next_node];

                // Merging (min_node, next_node) changes the pair starting at
                // the left neighbour; re-key it in the rank tree if its rank
                // actually changed.
                if prev_node != usize::MAX {
                    let new_rank = get_rank(prev_node, next_next_node);
                    if index_rank[prev_node] != new_rank {
                        rank_indexes.get_mut(&index_rank[prev_node]).unwrap().remove(&prev_node);
                        index_rank[prev_node] = new_rank;
                        rank_indexes.entry(new_rank).or_default().insert(prev_node);
                    }
                }

                // The merged node now starts a pair reaching past the
                // swallowed neighbour, to the boundary after the survivor.
                let new_rank = get_rank(min_node, next_next_next_node);
                index_rank[min_node] = new_rank;
                rank_indexes.entry(new_rank).or_default().insert(min_node);

                // Unlink next_node from the boundary list.
                index_next[min_node] = next_next_node;
                index_prev[next_next_node] = min_node;
                let next_node_rank = index_rank[next_node];
                if next_node_rank != Rank::MAX {
                    // If the consumed neighbour shares the popped minimum
                    // rank, it still sits later in `nodes`; flag it to be
                    // skipped rather than mutating the set being iterated.
                    // Otherwise drop it from its bucket in the tree.
                    skip_next = next_node_rank == min_rank;
                    if !skip_next {
                        rank_indexes.get_mut(&next_node_rank).unwrap().remove(&next_node);
                    }
                }

                token_count -= 1;
            }
        }
    }

    // Walk the surviving linked list from boundary 0; the rank component is a
    // dummy Rank::MAX — callers consume only the indexes.
    successors(Some(0), |&n| index_next.get(n).filter(|&x| *x != usize::MAX).copied())
        .map(|n| (n, Rank::MAX))
        .collect()
}
151+
59152
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
60153
assert!(piece.len() > 1);
61154
_byte_pair_merge(&ranks, &piece)

0 commit comments

Comments
 (0)