From 66a57bae4017d2b41a3e47654db8ad44c56ee66f Mon Sep 17 00:00:00 2001
From: Shantanu <shantanu@openai.com>
Date: Fri, 9 Feb 2024 14:27:44 -0800
Subject: [PATCH 1/2] Simplify byte_pair_merge

---
 src/lib.rs | 96 ++++++++++++++++++++----------------------------------
 1 file changed, 36 insertions(+), 60 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 2b9e15ff..9d072383 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -15,85 +15,61 @@ use rustc_hash::FxHashMap as HashMap;
 
 type Rank = u32;
 
-fn _byte_pair_merge(
-    ranks: &HashMap<Vec<u8>, Rank>,
-    piece: &[u8],
-) -> Vec<(usize, Rank)> {
+fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
     // This is a vector of (start, rank).
-    // The rank is of the byte pair starting at position start.
-    // The rank of the last item in the vector is not a valid value.
-    let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
+    // The rank is of the pair starting at position start.
+    let mut parts = Vec::with_capacity(piece.len() + 1);
+
+    // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
+    // the way we currently do, this is equivalent. An easy way to break this would be to decouple
+    // merge priority from token index or to prevent specific token merges.
+    let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
+    for i in 0..piece.len() - 1 {
+        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
+        if rank < min_rank.0 {
+            min_rank = (rank, i);
+        }
+        parts.push((i, rank));
+    }
+    parts.push((piece.len() - 1, Rank::MAX));
+    parts.push((piece.len(), Rank::MAX));
 
     let get_rank = {
         #[inline(always)]
-        |parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
-            if (start_idx + skip + 2) < parts.len() {
-                ranks
-                    .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
-                    .copied()
+        |parts: &Vec<(usize, Rank)>, i: usize| {
+            if (i + 3) < parts.len() {
+                // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
+                // parts[i + 1], see comment in the main loop.
+                *ranks
+                    .get(&piece[parts[i].0..parts[i + 3].0])
+                    .unwrap_or(&Rank::MAX)
             } else {
-                None
+                Rank::MAX
             }
         }
     };
 
-    // We look up the ranks once in the beginning and iteratively update
-    // them during each merge, which reduces the number of rank lookups.
-    for i in 0..parts.len() - 2 {
-        match get_rank(&parts, i, 0) {
-            Some(rank) => {
-                // Rank::MAX is a sentinel value and cannot be a valid rank
-                debug_assert!(rank != Rank::MAX);
-                parts[i].1 = rank;
-            }
-            None => {
-                continue;
-            }
-        };
-    }
-
     // If you have n parts and m merges, this does O(mn) work.
     // We could do something with a heap and do O(m log n) work.
-    // It is important to consider that n is often small (<100), and as such
-    // the cache-locality benefits outweigh the algorithmic complexity downsides
-    // of the `parts` vector data structure above.
-
-    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
-    // currently do, this is equivalent. An easy way to break this would be to decouple
-    // merge priority from token index or to prevent specific token merges.
-    loop {
-        if parts.len() == 1 {
-            break;
+    // n is often very small so considerations like cache-locality outweigh the algorithmic
+    // complexity downsides of the `parts` vector.
+    while min_rank.0 != Rank::MAX {
+        let i = min_rank.1;
+        // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
+        // `parts.remove(i + 1)` will thrash the cache.
+        parts[i].1 = get_rank(&parts, i);
+        if i > 0 {
+            parts[i - 1].1 = get_rank(&parts, i - 1);
         }
+        parts.remove(i + 1);
 
-        // Rank::MAX is a sentinel rank value allowing us to
-        // take the min more quickly
-        let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
+        min_rank = (Rank::MAX, 0);
         for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
             if rank < min_rank.0 {
                 min_rank = (rank, i);
             }
         }
-
-        if min_rank.0 != Rank::MAX {
-            let i = min_rank.1;
-
-            // NOTE: We are about to remove parts[i + 1]. We do not do it
-            // yet because there are cache-locality benefits to updating
-            // parts[i] and parts[i-1] before removing, which could thrash
-            // the cache. Thus, we update the rank calculation by skipping over
-            // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
-            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
-            if i > 0 {
-                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
-            }
-
-            parts.remove(i + 1);
-        } else {
-            break;
-        }
     }
-
     parts
 }
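
For reference, here is a self-contained sketch of the function this patch arrives at, exercised on a toy merge table. The `byte_pair_merge_sketch` name, the `toy_ranks` data, the `main` harness, and the use of std's `HashMap` in place of the crate's `FxHashMap` are illustrative assumptions, not part of the patch:

// Illustrative only: a standalone version of the simplified _byte_pair_merge.
use std::collections::HashMap;

type Rank = u32;

fn byte_pair_merge_sketch(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    // (start, rank of the pair beginning at start); the last two entries are sentinels.
    let mut parts = Vec::with_capacity(piece.len() + 1);
    let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
    for i in 0..piece.len() - 1 {
        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
        if rank < min_rank.0 {
            min_rank = (rank, i);
        }
        parts.push((i, rank));
    }
    parts.push((piece.len() - 1, Rank::MAX));
    parts.push((piece.len(), Rank::MAX));

    // Rank of the pair starting at parts[i], looked up as if parts[i + 1] had
    // already been deleted (hence the + 3).
    let get_rank = |parts: &Vec<(usize, Rank)>, i: usize| {
        if i + 3 < parts.len() {
            *ranks.get(&piece[parts[i].0..parts[i + 3].0]).unwrap_or(&Rank::MAX)
        } else {
            Rank::MAX
        }
    };

    // Merge the lowest-ranked pair until nothing mergeable remains.
    while min_rank.0 != Rank::MAX {
        let i = min_rank.1;
        parts[i].1 = get_rank(&parts, i);
        if i > 0 {
            parts[i - 1].1 = get_rank(&parts, i - 1);
        }
        parts.remove(i + 1);

        min_rank = (Rank::MAX, 0);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }
    }
    parts
}

fn main() {
    // Toy merge table: "ab" merges first, then "cd". (Assumed data, not a real vocabulary.)
    let toy_ranks: HashMap<Vec<u8>, Rank> =
        [(b"ab".to_vec(), 0), (b"cd".to_vec(), 1)].into_iter().collect();
    let parts = byte_pair_merge_sketch(&toy_ranks, b"abcd");
    // Surviving boundaries are 0, 2 and 4, i.e. "abcd" splits into "ab" | "cd".
    assert_eq!(parts.iter().map(|&(s, _)| s).collect::<Vec<_>>(), vec![0, 2, 4]);
    println!("{parts:?}");
}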
 

From 2cc09e0776964c30e51f5a6475d9cd6e1572c828 Mon Sep 17 00:00:00 2001
From: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
Date: Sun, 11 Feb 2024 00:15:37 -0800
Subject: [PATCH 2/2] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Lőrinc Pap <1841944+paplorinc@users.noreply.github.com>
---
 src/lib.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 9d072383..b466edd1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -23,7 +23,7 @@ fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize,
     // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
     // the way we currently do, this is equivalent. An easy way to break this would be to decouple
     // merge priority from token index or to prevent specific token merges.
-    let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
+    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
     for i in 0..piece.len() - 1 {
         let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
         if rank < min_rank.0 {
@@ -57,13 +57,13 @@ fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize,
         let i = min_rank.1;
         // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
         // `parts.remove(i + 1)` will thrash the cache.
-        parts[i].1 = get_rank(&parts, i);
         if i > 0 {
             parts[i - 1].1 = get_rank(&parts, i - 1);
         }
+        parts[i].1 = get_rank(&parts, i);
         parts.remove(i + 1);
 
-        min_rank = (Rank::MAX, 0);
+        min_rank = (Rank::MAX, usize::MAX);
         for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
             if rank < min_rank.0 {
                 min_rank = (rank, i);
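
After both patches, `_byte_pair_merge` still returns the surviving `(start, rank)` boundaries rather than token ids. As a hypothetical consumer (not part of these patches; the name `encode_piece` and its shape are assumptions for illustration, relying on the imports and `Rank` alias already in src/lib.rs), adjacent boundaries could be mapped back through `ranks` like this:

// Hypothetical consumer of the (start, rank) boundary list; not taken from the patch.
// Assumes every adjacent boundary pair delimits a byte range present in `ranks`,
// which holds when `ranks` is the same table that drove the merges.
fn encode_piece(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<Rank> {
    assert!(piece.len() > 1);
    _byte_pair_merge(ranks, piece)
        .windows(2)
        .map(|part| ranks[&piece[part[0].0..part[1].0]])
        .collect()
}

With the toy table from the earlier sketch, `encode_piece(&toy_ranks, b"abcd")` would yield `vec![0, 1]`, the ranks of "ab" and "cd".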