1
1
// This check is new and seems buggy (possibly with PyO3 interaction)
2
2
#![ allow( clippy:: borrow_deref_ref) ]
3
3
4
- use std:: collections:: HashSet ;
4
+ use std:: collections:: { BTreeMap , BTreeSet , HashSet } ;
5
+ use std:: iter:: successors;
5
6
use std:: num:: NonZeroU64 ;
6
7
use std:: thread;
7
8
@@ -15,9 +16,22 @@ use rustc_hash::FxHashMap as HashMap;
15
16
16
17
type Rank = u32 ;
17
18
19
// Length threshold used by `_byte_pair_merge`: pieces whose `piece.len()` is below this value take
// the small merge path, all others take the large one. Despite the name, the comparison is made
// against the length of a `&[u8]`, i.e. it counts bytes rather than characters.
+ const LARGE_ENCODER_CHARACTER_LIMIT : usize = 500 ;
20
+
18
21
fn _byte_pair_merge (
19
22
ranks : & HashMap < Vec < u8 > , Rank > ,
20
23
piece : & [ u8 ] ,
24
+ ) -> Vec < ( usize , Rank ) > {
25
+ if piece. len ( ) < LARGE_ENCODER_CHARACTER_LIMIT {
26
+ _byte_pair_merge_small ( ranks, piece) // Quadratic, but lightweight
27
+ } else {
28
+ _byte_pair_merge_large ( ranks, piece) // Linearithmic, but heavy
29
+ }
30
+ }
31
+
32
+ fn _byte_pair_merge_small (
33
+ ranks : & HashMap < Vec < u8 > , Rank > ,
34
+ piece : & [ u8 ] ,
21
35
) -> Vec < ( usize , Rank ) > {
22
36
let get_rank = |parts : & Vec < ( usize , _ ) > , start_idx : usize , end_idx : usize | {
23
37
* parts. get ( end_idx)
@@ -56,6 +70,85 @@ fn _byte_pair_merge(
56
70
parts
57
71
}
58
72
73
+ fn _byte_pair_merge_large (
74
+ ranks : & HashMap < Vec < u8 > , Rank > ,
75
+ piece : & [ u8 ] ,
76
+ ) -> Vec < ( usize , Rank ) > {
77
+ let get_rank = |start_idx : usize , end_idx : usize | {
78
+ * piece. get ( start_idx..end_idx)
79
+ . filter ( |p| p. len ( ) < piece. len ( ) )
80
+ . and_then ( |p| ranks. get ( p) )
81
+ . unwrap_or ( & Rank :: MAX )
82
+ } ;
83
+
84
+ let mut rank_indexes = BTreeMap :: < Rank , BTreeSet < usize > > :: new ( ) ;
85
+ let mut index_rank = vec ! [ Rank :: MAX ; piece. len( ) + 1 ] ;
86
+ let mut index_prev = vec ! [ usize :: MAX ; piece. len( ) + 1 ] ;
87
+ let mut index_next = vec ! [ usize :: MAX ; piece. len( ) + 1 ] ;
88
+
89
+ let mut prev_node = None ;
90
+ for i in 0 ..=piece. len ( ) {
91
+ let rank = get_rank ( i, i + 2 ) ;
92
+ index_rank[ i] = rank;
93
+ if let Some ( prev) = prev_node {
94
+ index_prev[ i] = prev;
95
+ index_next[ prev] = i;
96
+ }
97
+ prev_node = Some ( i) ;
98
+
99
+ rank_indexes. entry ( rank) . or_default ( ) . insert ( i) ;
100
+ }
101
+
102
+ let mut token_count = piece. len ( ) ;
103
+ while token_count > 2 && rank_indexes. len ( ) > 1 {
104
+ let mut skip_next = false ;
105
+ if let Some ( ( _, nodes) ) = rank_indexes. pop_first ( ) {
106
+ for & min_node in & nodes {
107
+ if skip_next {
108
+ skip_next = false ;
109
+ continue ;
110
+ }
111
+
112
+ let min_rank = index_rank[ min_node] ;
113
+
114
+ let prev_node = index_prev[ min_node] ;
115
+ let next_node = index_next[ min_node] ;
116
+ let next_next_node = index_next[ next_node] ;
117
+ let next_next_next_node = index_next[ next_next_node] ;
118
+
119
+ if prev_node != usize:: MAX {
120
+ let new_rank = get_rank ( prev_node, next_next_node) ;
121
+ if index_rank[ prev_node] != new_rank {
122
+ rank_indexes. get_mut ( & index_rank[ prev_node] ) . unwrap ( ) . remove ( & prev_node) ;
123
+ index_rank[ prev_node] = new_rank;
124
+ rank_indexes. entry ( new_rank) . or_default ( ) . insert ( prev_node) ;
125
+ }
126
+ }
127
+
128
+ let new_rank = get_rank ( min_node, next_next_next_node) ;
129
+ index_rank[ min_node] = new_rank;
130
+ rank_indexes. entry ( new_rank) . or_default ( ) . insert ( min_node) ;
131
+
132
+ index_next[ min_node] = next_next_node;
133
+ index_prev[ next_next_node] = min_node;
134
+ let next_node_rank = index_rank[ next_node] ;
135
+ if next_node_rank != Rank :: MAX {
136
+ skip_next = next_node_rank == min_rank;
137
+ if !skip_next {
138
+ rank_indexes. get_mut ( & next_node_rank) . unwrap ( ) . remove ( & next_node) ;
139
+ }
140
+ }
141
+
142
+ token_count -= 1 ;
143
+ }
144
+ }
145
+ }
146
+
147
+ successors ( Some ( 0 ) , |& n| index_next. get ( n) . filter ( |& x| * x != usize:: MAX ) . copied ( ) )
148
+ . map ( |n| ( n, Rank :: MAX ) )
149
+ . collect ( )
150
+ }
151
+
59
152
pub fn byte_pair_encode ( piece : & [ u8 ] , ranks : & HashMap < Vec < u8 > , Rank > ) -> Vec < Rank > {
60
153
assert ! ( piece. len( ) > 1 ) ;
61
154
_byte_pair_merge ( & ranks, & piece)
0 commit comments