Skip to content

Commit 8f5dd7d

Browse files
author
Lőrinc
committed
Simplify and optimize _byte_pair_merge
We're already calculating the min during construction of the first pairs. Because of this minimum calculation is moved to the end of the loop. Since we've filtered out single tokens, we can safely exit when the parts length is already small enough
1 parent 7398253 commit 8f5dd7d

File tree

1 file changed

+24
-65
lines changed

1 file changed

+24
-65
lines changed

src/lib.rs

+24-65
Original file line numberDiff line numberDiff line change
@@ -19,78 +19,37 @@ fn _byte_pair_merge(
1919
ranks: &HashMap<Vec<u8>, Rank>,
2020
piece: &[u8],
2121
) -> Vec<(usize, Rank)> {
22-
// This is a vector of (start, rank).
23-
// The rank is of the byte pair starting at position start.
24-
// The rank of the last item in the vector is not a valid value.
25-
let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
26-
27-
let get_rank = {
28-
#[inline(always)]
29-
|parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
30-
if (start_idx + skip + 2) < parts.len() {
31-
ranks
32-
.get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
33-
.copied()
34-
} else {
35-
None
36-
}
37-
}
22+
let get_rank = |parts: &Vec<(usize, _)>, start_idx: usize, end_idx: usize| {
23+
*parts.get(end_idx)
24+
.map(|e| parts.get(start_idx).unwrap().0..e.0)
25+
.and_then(|r| piece.get(r))
26+
.filter(|p| p.len() < piece.len())
27+
.and_then(|p| ranks.get(p))
28+
.unwrap_or(&Rank::MAX)
3829
};
3930

40-
// We look up the ranks once in the beginning and iteratively update
41-
// them during each merge, which reduces the number of rank lookups.
42-
for i in 0..parts.len() - 2 {
43-
match get_rank(&parts, i, 0) {
44-
Some(rank) => {
45-
// Rank::MAX is a sentinel value and cannot be a valid rank
46-
debug_assert!(rank != Rank::MAX);
47-
parts[i].1 = rank;
48-
}
49-
None => {
50-
continue;
51-
}
52-
};
53-
}
54-
55-
// If you have n parts and m merges, this does O(mn) work.
56-
// We could do something with a heap and do O(m log n) work.
57-
// It is important to consider that n is often small (<100), and as such
58-
// the cache-locality benefits outweigh the algorithmic complexity downsides
59-
// of the `parts` vector data structure above.
60-
61-
// Note that we hash bytes, not token pairs. As long as we train BPE the way we
62-
// currently do, this is equivalent. An easy way to break this would be to decouple
63-
// merge priority from token index or to prevent specific token merges.
64-
loop {
65-
if parts.len() == 1 {
66-
break;
31+
let (mut min_rank_index, mut min_rank) = (0, Rank::MAX);
32+
let mut parts = Vec::with_capacity(piece.len() + 1);
33+
for i in 0..piece.len() + 1 {
34+
let part = (i, *piece.get(i..i + 2).and_then(|p| ranks.get(p)).unwrap_or(&Rank::MAX));
35+
if part.1 < min_rank {
36+
(min_rank_index, min_rank) = part;
6737
}
38+
parts.push(part);
39+
}
6840

69-
// Rank::MAX is a sentinel rank value allowing us to
70-
// take the min more quickly
71-
let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
72-
for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
73-
if rank < min_rank.0 {
74-
min_rank = (rank, i);
75-
}
41+
while parts.len() > 3 && min_rank != Rank::MAX {
42+
if min_rank_index > 0 {
43+
parts[min_rank_index - 1].1 = get_rank(&parts, min_rank_index - 1, min_rank_index + 2);
7644
}
45+
parts[min_rank_index].1 = get_rank(&parts, min_rank_index, min_rank_index + 3);
46+
parts.remove(min_rank_index + 1);
7747

78-
if min_rank.0 != Rank::MAX {
79-
let i = min_rank.1;
80-
81-
// NOTE: We are about to remove parts[i + 1]. We do not do it
82-
// yet because there are cache-locality benefits to updating
83-
// parts[i] and parts[i-1] before removing, which could thrash
84-
// the cache. Thus, we update the rank calculation by skipping over
85-
// parts[i + 1], by invoking `get_rank!` with `skip = 1`.
86-
parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
87-
if i > 0 {
88-
parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
48+
(min_rank_index, min_rank) = (0, parts[0].1);
49+
for i in 1..parts.len() - 2 {
50+
if parts[i].1 < min_rank {
51+
(min_rank_index, min_rank) = (i, parts[i].1);
8952
}
90-
91-
parts.remove(i + 1);
92-
} else {
93-
break;
9453
}
9554
}
9655

0 commit comments

Comments
 (0)