Skip to content

Commit 5af8058

Browse files
author
Lőrinc
committed
Add a _byte_pair_merge_large for worst-case scenarios
We're storing the ranks in a sorted tree of sorted (or linked) trees. Getting the minimum rank is logarithmic and each subsequent occurrence is constant time. To know the previous and next indexes (and the corresponding ranks), we're storing them in arrays (the keys are the indexes). We're updating each after finding the minimum via the tree. We're iterating duplicates without removing them one-by-one, but if they are neighbors, we're skipping them manually.
1 parent aeca532 commit 5af8058

File tree

1 file changed

+84
-1
lines changed

1 file changed

+84
-1
lines changed

src/lib.rs

+84-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
// This check is new and seems buggy (possibly with PyO3 interaction)
22
#![allow(clippy::borrow_deref_ref)]
33

4-
use std::collections::HashSet;
4+
use std::collections::{BTreeMap, BTreeSet, HashSet};
5+
use std::iter::successors;
56
use std::num::NonZeroU64;
67
use std::thread;
78

@@ -15,7 +16,17 @@ use rustc_hash::FxHashMap as HashMap;
1516

1617
type Rank = u32;
1718

19+
const LARGE_ENCODER_CHARACTER_LIMIT: usize = 500;
20+
1821
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
22+
if piece.len() < LARGE_ENCODER_CHARACTER_LIMIT {
23+
_byte_pair_merge_small(ranks, piece) // Quadratic, but lightweight
24+
} else {
25+
_byte_pair_merge_large(ranks, piece) // Linearithmic, but heavy
26+
}
27+
}
28+
29+
fn _byte_pair_merge_small(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
1930
// This is a vector of (start, rank).
2031
// The rank is of the pair starting at position start.
2132
let mut parts = Vec::with_capacity(piece.len() + 1);
@@ -73,6 +84,78 @@ fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize,
7384
parts
7485
}
7586

87+
fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
88+
let mut rank_indexes = BTreeMap::<Rank, BTreeSet<usize>>::new();
89+
let mut index_rank = vec![Rank::MAX; piece.len() + 1];
90+
let mut index_prev = vec![usize::MAX; piece.len() + 1];
91+
let mut index_next = vec![usize::MAX; piece.len() + 1];
92+
93+
let get_rank = |start_idx: usize, end_idx: usize| -> Rank {
94+
*piece.get(start_idx..end_idx)
95+
.and_then(|p| ranks.get(p))
96+
.unwrap_or(&Rank::MAX)
97+
};
98+
99+
let mut prev_node = None;
100+
for i in 0..=piece.len() {
101+
let rank = get_rank(i, i + 2);
102+
index_rank[i] = rank;
103+
if let Some(prev) = prev_node {
104+
index_prev[i] = prev;
105+
index_next[prev] = i;
106+
}
107+
prev_node = Some(i);
108+
109+
rank_indexes.entry(rank).or_default().insert(i);
110+
}
111+
112+
while rank_indexes.len() > 1 {
113+
let mut skip_next = false;
114+
if let Some((_, nodes)) = rank_indexes.pop_first() {
115+
for &min_node in &nodes {
116+
if skip_next {
117+
skip_next = false;
118+
continue;
119+
}
120+
121+
let min_rank = index_rank[min_node];
122+
123+
let prev_node = index_prev[min_node];
124+
let next_node = index_next[min_node];
125+
let next_next_node = index_next[next_node];
126+
let next_next_next_node = index_next[next_next_node];
127+
128+
if prev_node != usize::MAX {
129+
let new_rank = get_rank(prev_node, next_next_node);
130+
if index_rank[prev_node] != new_rank {
131+
rank_indexes.get_mut(&index_rank[prev_node]).unwrap().remove(&prev_node);
132+
index_rank[prev_node] = new_rank;
133+
rank_indexes.entry(new_rank).or_default().insert(prev_node);
134+
}
135+
}
136+
137+
let new_rank = get_rank(min_node, next_next_next_node);
138+
index_rank[min_node] = new_rank;
139+
rank_indexes.entry(new_rank).or_default().insert(min_node);
140+
141+
index_next[min_node] = next_next_node;
142+
index_prev[next_next_node] = min_node;
143+
144+
let next_node_rank = index_rank[next_node];
145+
if next_node_rank == min_rank {
146+
skip_next = true;
147+
} else if next_node_rank != Rank::MAX {
148+
rank_indexes.get_mut(&next_node_rank).unwrap().remove(&next_node);
149+
}
150+
}
151+
}
152+
}
153+
154+
successors(Some(0), |&n| index_next.get(n).filter(|&&x| x != usize::MAX).copied())
155+
.map(|n| (n, Rank::MAX))
156+
.collect()
157+
}
158+
76159
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
77160
assert!(piece.len() > 1);
78161
_byte_pair_merge(&ranks, &piece)

0 commit comments

Comments
 (0)