@@ -19,78 +19,37 @@ fn _byte_pair_merge(
     ranks: &HashMap<Vec<u8>, Rank>,
     piece: &[u8],
 ) -> Vec<(usize, Rank)> {
-    // This is a vector of (start, rank).
-    // The rank is of the byte pair starting at position start.
-    // The rank of the last item in the vector is not a valid value.
-    let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
-
-    let get_rank = {
-        #[inline(always)]
-        |parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
-            if (start_idx + skip + 2) < parts.len() {
-                ranks
-                    .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
-                    .copied()
-            } else {
-                None
-            }
-        }
+    let get_rank = |parts: &Vec<(usize, _)>, start_idx: usize, end_idx: usize| {
+        *parts.get(end_idx)
+            .map(|e| parts.get(start_idx).unwrap().0..e.0)
+            .and_then(|r| piece.get(r))
+            .filter(|p| p.len() < piece.len())
+            .and_then(|p| ranks.get(p))
+            .unwrap_or(&Rank::MAX)
     };

-    // We look up the ranks once in the beginning and iteratively update
-    // them during each merge, which reduces the number of rank lookups.
-    for i in 0..parts.len() - 2 {
-        match get_rank(&parts, i, 0) {
-            Some(rank) => {
-                // Rank::MAX is a sentinel value and cannot be a valid rank
-                debug_assert!(rank != Rank::MAX);
-                parts[i].1 = rank;
-            }
-            None => {
-                continue;
-            }
-        };
-    }
-
-    // If you have n parts and m merges, this does O(mn) work.
-    // We could do something with a heap and do O(m log n) work.
-    // It is important to consider that n is often small (<100), and as such
-    // the cache-locality benefits outweigh the algorithmic complexity downsides
-    // of the `parts` vector data structure above.
-
-    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
-    // currently do, this is equivalent. An easy way to break this would be to decouple
-    // merge priority from token index or to prevent specific token merges.
-    loop {
-        if parts.len() == 1 {
-            break;
+    let (mut min_rank_index, mut min_rank) = (0, Rank::MAX);
+    let mut parts = Vec::with_capacity(piece.len() + 1);
+    for i in 0..piece.len() + 1 {
+        let part = (i, *piece.get(i..i + 2).and_then(|p| ranks.get(p)).unwrap_or(&Rank::MAX));
+        if part.1 < min_rank {
+            (min_rank_index, min_rank) = part;
         }
+        parts.push(part);
+    }

-        // Rank::MAX is a sentinel rank value allowing us to
-        // take the min more quickly
-        let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
-        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
-            if rank < min_rank.0 {
-                min_rank = (rank, i);
-            }
+    while parts.len() > 3 && min_rank != Rank::MAX {
+        if min_rank_index > 0 {
+            parts[min_rank_index - 1].1 = get_rank(&parts, min_rank_index - 1, min_rank_index + 2);
         }
+        parts[min_rank_index].1 = get_rank(&parts, min_rank_index, min_rank_index + 3);
+        parts.remove(min_rank_index + 1);

-        if min_rank.0 != Rank::MAX {
-            let i = min_rank.1;
-
-            // NOTE: We are about to remove parts[i + 1]. We do not do it
-            // yet because there are cache-locality benefits to updating
-            // parts[i] and parts[i-1] before removing, which could thrash
-            // the cache. Thus, we update the rank calculation by skipping over
-            // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
-            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
-            if i > 0 {
-                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
+        (min_rank_index, min_rank) = (0, parts[0].1);
+        for i in 1..parts.len() - 2 {
+            if parts[i].1 < min_rank {
+                (min_rank_index, min_rank) = (i, parts[i].1);
             }
-
-            parts.remove(i + 1);
-        } else {
-            break;
         }
     }
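
For context, a minimal sketch of how the returned (start, rank) boundaries are typically consumed. The `Rank` alias and the `byte_pair_encode` wrapper below are assumptions based on the surrounding tiktoken code, not part of this diff: each adjacent pair of boundaries in the merged `parts` vector spans one final token, which is then looked up in `ranks`.

use std::collections::HashMap;

// Assumed alias: ranks/token ids are represented as unsigned integers.
pub type Rank = u32;

// Hypothetical caller: adjacent boundary pairs delimit the final tokens.
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
    assert!(piece.len() > 1);
    _byte_pair_merge(ranks, piece)
        .windows(2)
        .map(|part| ranks[&piece[part[0].0..part[1].0]])
        .collect()
}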
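
Continuing that sketch, a toy example with a hypothetical rank map: "ab" has the lowest rank, so b"abab" merges the pairs starting at byte offsets 0 and 2 and encodes to two "ab" tokens.

fn main() {
    // Hypothetical ranks, for illustration only.
    let ranks: HashMap<Vec<u8>, Rank> =
        [(b"ab".to_vec(), 0), (b"ba".to_vec(), 1)].into_iter().collect();

    // Only the surviving boundaries remain; the last entry is the
    // end-of-piece boundary at index piece.len().
    let parts = _byte_pair_merge(&ranks, b"abab");
    assert_eq!(parts.iter().map(|&(start, _)| start).collect::<Vec<_>>(), vec![0, 2, 4]);

    // Two "ab" tokens, each with rank 0.
    assert_eq!(byte_pair_encode(b"abab", &ranks), vec![0, 0]);
}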