@@ -30,101 +30,78 @@ export type TextGenLocalWorkerMessage =
 //==========================================================================||
 // Worker Initialization                                                    ||
 //==========================================================================||
-const APP_CONFIGS: webllm.AppConfig = {
-  model_list: [
-    {
-      model_url:
-        'https://huggingface.co/mlc-ai/TinyLlama-1.1B-Chat-v0.4-q4f16_1-MLC/resolve/main/',
-      local_id: 'TinyLlama-1.1B-Chat-v0.4-q4f16_1',
-      model_lib_url:
-        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/TinyLlama-1.1B-Chat-v0.4/TinyLlama-1.1B-Chat-v0.4-q4f16_1-ctx1k-webgpu.wasm'
-    },
-    {
-      model_url:
-        'https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC/resolve/main/',
-      local_id: 'Llama-2-7b-chat-hf-q4f16_1',
-      model_lib_url:
-        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-ctx1k-webgpu.wasm'
-    },
-    {
-      model_url: 'https://huggingface.co/mlc-ai/gpt2-q0f16-MLC/resolve/main/',
-      local_id: 'gpt2-q0f16',
-      model_lib_url:
-        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/gpt2/gpt2-q0f16-ctx1k-webgpu.wasm'
-    },
-    {
-      model_url:
-        'https://huggingface.co/mlc-ai/Mistral-7B-Instruct-v0.2-q3f16_1-MLC/resolve/main/',
-      local_id: 'Mistral-7B-Instruct-v0.2-q3f16_1',
-      model_lib_url:
-        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm'
-    },
-    {
-      model_url:
-        'https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC/resolve/main/',
-      local_id: 'Phi2-q4f16_1',
-      model_lib_url:
-        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/phi-2/phi-2-q4f16_1-ctx2k-webgpu.wasm',
-      vram_required_MB: 3053.97,
-      low_resource_required: false,
-      required_features: ['shader-f16']
-    }
-  ]
-};
+enum Role {
+  user = 'user',
+  assistant = 'assistant'
+}

 const CONV_TEMPLATES: Record<
   SupportedLocalModel,
   Partial<ConvTemplateConfig>
 > = {
   [SupportedLocalModel['tinyllama-1.1b']]: {
-    system: '<|im_start|><|im_end|> ',
-    roles: ['<|im_start|>user', '<|im_start|>assistant'],
+    system_template: '<|im_start|><|im_end|> ',
+    roles: {
+      [Role.user]: '<|im_start|>user',
+      [Role.assistant]: '<|im_start|>assistant'
+    },
     offset: 0,
     seps: ['', ''],
-    separator_style: 'Two',
-    stop_str: '<|im_end|>',
-    add_bos: false,
-    stop_tokens: [2]
+    stop_str: ['<|im_end|>'],
+    stop_token_ids: [2]
   },
   [SupportedLocalModel['llama-2-7b']]: {
-    system: '[INST] <<SYS>><</SYS>>\n\n ',
-    roles: ['[INST]', '[/INST]'],
+    system_template: '[INST] <<SYS>><</SYS>>\n\n ',
+    roles: {
+      [Role.user]: '[INST]',
+      [Role.assistant]: '[/INST]'
+    },
     offset: 0,
     seps: [' ', ' '],
-    separator_style: 'Two',
-    stop_str: '[INST]',
-    add_bos: true,
-    stop_tokens: [2]
+    role_content_sep: ' ',
+    role_empty_sep: ' ',
+    stop_str: ['[INST]'],
+    system_prefix_token_ids: [1],
+    stop_token_ids: [2],
+    add_role_after_system_message: false
   },
   [SupportedLocalModel['phi-2']]: {
-    system: '',
-    roles: ['Instruct', 'Output'],
+    system_template: '',
+    system_message: '',
+    roles: {
+      [Role.user]: 'Instruct',
+      [Role.assistant]: 'Output'
+    },
     offset: 0,
     seps: ['\n'],
-    separator_style: 'Two',
-    stop_str: '<|endoftext|>',
-    add_bos: false,
-    stop_tokens: [50256]
+    stop_str: ['<|endoftext|>'],
+    stop_token_ids: [50256]
+  },
+  [SupportedLocalModel['gemma-2b']]: {
+    system_template: '',
+    system_message: '',
+    roles: {
+      [Role.user]: '<start_of_turn>user',
+      [Role.assistant]: '<start_of_turn>model'
+    },
+    offset: 0,
+    seps: ['<end_of_turn>\n', '<end_of_turn>\n'],
+    role_content_sep: '\n',
+    role_empty_sep: '\n',
+    stop_str: ['<end_of_turn>'],
+    system_prefix_token_ids: [2],
+    stop_token_ids: [1, 107]
   }
 };

 const modelMap: Record<SupportedLocalModel, string> = {
   [SupportedLocalModel['tinyllama-1.1b']]: 'TinyLlama-1.1B-Chat-v0.4-q4f16_1',
   [SupportedLocalModel['llama-2-7b']]: 'Llama-2-7b-chat-hf-q4f16_1',
-  [SupportedLocalModel['phi-2']]: 'Phi2-q4f16_1'
-  // [SupportedLocalModel['gpt-2']]: 'gpt2-q0f16'
-  // [SupportedLocalModel['mistral-7b-v0.2']]: 'Mistral-7B-Instruct-v0.2-q3f16_1'
+  [SupportedLocalModel['phi-2']]: 'Phi2-q4f16_1',
+  [SupportedLocalModel['gemma-2b']]: 'gemma-2b-it-q4f16_1'
 };

-const chat = new webllm.ChatModule();
-
-// To reset temperature, WebLLM requires to reload the model. Therefore, we just
-// fix the temperature for now.
-let _temperature = 0.2;
-
-let _modelLoadingComplete: Promise<void> | null = null;
-
-chat.setInitProgressCallback((report: webllm.InitProgressReport) => {
+const initProgressCallback = (report: webllm.InitProgressReport) => {
   // Update the main thread about the progress
   console.log(report.text);
   const message: TextGenLocalWorkerMessage = {
@@ -135,7 +112,9 @@ chat.setInitProgressCallback((report: webllm.InitProgressReport) => {
     }
   };
   postMessage(message);
-});
+};
+
+let engine: Promise<webllm.EngineInterface> | null = null;

 //==========================================================================||
 // Worker Event Handlers                                                    ||
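Holding the engine as a nullable promise replaces the old _modelLoadingComplete bookkeeping: any later handler can simply await whichever load is currently in flight. A minimal sketch of that pattern, using only names from this diff plus a hypothetical ensureEngine helper:

// Hypothetical helper (not part of this diff): fail loudly if no load has been
// started, otherwise await the stored promise, which also covers a load that is
// still in flight.
const ensureEngine = async (): Promise<webllm.EngineInterface> => {
  if (engine === null) {
    throw new Error('startLoadModel must be called before generating text.');
  }
  return await engine;
};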
@@ -179,15 +158,25 @@ const startLoadModel = async (
   model: SupportedLocalModel,
   temperature: number
 ) => {
-  _temperature = temperature;
   const curModel = modelMap[model];
-  const chatOption: webllm.ChatOptions = {
-    temperature: temperature,
-    conv_config: CONV_TEMPLATES[model],
-    conv_template: 'custom'
-  };
-  _modelLoadingComplete = chat.reload(curModel, chatOption, APP_CONFIGS);
-  await _modelLoadingComplete;
+
+  // Only use custom conv template for Llama to override the pre-included system
+  // prompt from WebLLM
+  let chatOption: webllm.ChatOptions | undefined = undefined;
+
+  if (model === SupportedLocalModel['llama-2-7b']) {
+    chatOption = {
+      conv_config: CONV_TEMPLATES[model],
+      conv_template: 'custom'
+    };
+  }
+
+  engine = webllm.CreateEngine(curModel, {
+    initProgressCallback: initProgressCallback,
+    chatOpts: chatOption
+  });
+
+  await engine;

   try {
     // Send back the data to the main thread
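webllm.CreateEngine downloads the weights and compiles the model for WebGPU, so a caller may want to probe for WebGPU support before starting a load. A rough sketch, not part of this diff, built on the standard navigator.gpu.requestAdapter() API and typed loosely to avoid a dependency on @webgpu/types:

// Sketch only: report whether a WebGPU adapter can be obtained before loading.
const webGPUAvailable = async (): Promise<boolean> => {
  const gpu = (navigator as unknown as {
    gpu?: { requestAdapter: () => Promise<unknown | null> };
  }).gpu;
  if (gpu === undefined) return false;
  return (await gpu.requestAdapter()) !== null;
};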
@@ -220,24 +209,26 @@ const startLoadModel = async (
  */
 const startTextGen = async (prompt: string, temperature: number) => {
   try {
-    if (_modelLoadingComplete) {
-      await _modelLoadingComplete;
-    }
-
-    const truncated = prompt.slice(0, 2000);
-
-    const response = await chat.generate(truncated);
+    const curEngine = await engine!;
+    const response = await curEngine.chat.completions.create({
+      messages: [{ role: 'user', content: prompt }],
+      n: 1,
+      max_gen_len: 2048,
+      // Override temperature to 0 because local models are very unstable
+      temperature: 0
+      // logprobs: false
+    });

     // Reset the chat cache to avoid memorizing previous messages
-    await chat.resetChat();
+    await curEngine.resetChat();

     // Send back the data to the main thread
     const message: TextGenLocalWorkerMessage = {
       command: 'finishTextGen',
       payload: {
         requestID: 'web-llm',
         apiKey: '',
-        result: response,
+        result: response.choices[0].message.content || '',
         prompt: prompt,
         detail: ''
       }
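For context, this worker is driven from the main thread through postMessage. The sketch below is illustrative only: the worker file name and the startTextGen request shape are assumptions, while the finishTextGen response mirrors the payload built in this hunk.

// Hypothetical main-thread caller; the file name and request command are assumed.
const worker = new Worker(new URL('./text-gen-local-worker.ts', import.meta.url), {
  type: 'module'
});

worker.onmessage = (e: MessageEvent<TextGenLocalWorkerMessage>) => {
  if (e.data.command === 'finishTextGen') {
    console.log(e.data.payload.result);
  }
};

worker.postMessage({
  command: 'startTextGen',
  payload: { prompt: 'Rewrite this sentence to be more concise.', temperature: 0.2 }
});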
@@ -263,7 +254,7 @@ const startTextGen = async (prompt: string, temperature: number) => {

 export const hasLocalModelInCache = async (model: SupportedLocalModel) => {
   const curModel = modelMap[model];
-  const inCache = await webllm.hasModelInCache(curModel, APP_CONFIGS);
+  const inCache = await webllm.hasModelInCache(curModel);
   return inCache;
 };
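Without APP_CONFIGS, hasModelInCache presumably falls back to WebLLM's bundled prebuilt model records, keyed by the same ids listed in modelMap. A small usage sketch, assuming only the exported function above:

// Usage sketch (not in the diff): decide whether to warn about a first-time download.
const checkCacheBeforeLoad = async (model: SupportedLocalModel) => {
  const cached = await hasLocalModelInCache(model);
  if (!cached) {
    console.log('Weights are not cached yet; the first load will download them.');
  }
  return cached;
};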