
Commit 155461f

Support new webllm syntax
Signed-off-by: Jay Wang <[email protected]>
1 parent 7713346 commit 155461f

5 files changed: +89 -101 lines changed
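
The heart of this commit is the move from WebLLM's ChatModule API to the engine-based, OpenAI-compatible API in @mlc-ai/web-llm 0.2.35. A minimal before/after sketch, condensed from the web-llm.ts diff below (the worker plumbing is omitted, and the standalone helper shown here is illustrative, not part of the commit):

import * as webllm from '@mlc-ai/web-llm';

// Old pattern (0.2.18), removed by this commit:
//   const chat = new webllm.ChatModule();
//   await chat.reload('Phi2-q4f16_1', chatOptions, appConfig);
//   const text = await chat.generate(prompt);
//   await chat.resetChat();

// New pattern (0.2.35), as used in web-llm.ts after this commit:
const generateOnce = async (prompt: string) => {
  const engine = await webllm.CreateEngine('Phi2-q4f16_1', {
    initProgressCallback: report => console.log(report.text)
  });
  const response = await engine.chat.completions.create({
    messages: [{ role: 'user', content: prompt }],
    temperature: 0
  });
  await engine.resetChat();
  return response.choices[0].message.content || '';
};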

examples/rag-playground/package.json

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
   },
   "devDependencies": {
     "@floating-ui/dom": "^1.6.1",
-    "@mlc-ai/web-llm": "^0.2.18",
+    "@mlc-ai/web-llm": "0.2.35",
     "@types/d3-array": "^3.2.1",
     "@types/d3-format": "^3.0.4",
     "@types/d3-random": "^3.0.3",

examples/rag-playground/src/components/panel-setting/panel-setting.ts

Lines changed: 2 additions & 2 deletions
@@ -57,8 +57,8 @@ const apiKeyDescriptionMap: Record<ModelFamily, TemplateResult> = {
 const localModelSizeMap: Record<SupportedLocalModel, string> = {
   [SupportedLocalModel['tinyllama-1.1b']]: '630 MB',
   [SupportedLocalModel['llama-2-7b']]: '3.6 GB',
-  [SupportedLocalModel['phi-2']]: '1.5 GB'
-  // [SupportedLocalModel['gpt-2']]: '311 MB'
+  [SupportedLocalModel['phi-2']]: '1.5 GB',
+  [SupportedLocalModel['gemma-2b']]: '1.3 GB'
   // [SupportedLocalModel['mistral-7b-v0.2']]: '3.5 GB'
 };

examples/rag-playground/src/components/playground/playground.ts

Lines changed: 1 addition & 1 deletion
@@ -586,7 +586,7 @@ export class MememoPlayground extends LitElement {
       }

       // case SupportedLocalModel['mistral-7b-v0.2']:
-      // case SupportedLocalModel['gpt-2']:
+      case SupportedLocalModel['gemma-2b']:
       case SupportedLocalModel['phi-2']:
       case SupportedLocalModel['llama-2-7b']:
       case SupportedLocalModel['tinyllama-1.1b']: {

examples/rag-playground/src/components/playground/user-config.ts

Lines changed: 4 additions & 7 deletions
@@ -3,11 +3,10 @@ import { get, set, del, clear } from 'idb-keyval';
 const PREFIX = 'user-config';

 export enum SupportedLocalModel {
+  'gemma-2b' = 'Gemma (2B)',
   'llama-2-7b' = 'Llama 2 (7B)',
-  // 'mistral-7b-v0.2' = 'Mistral (7B)',
   'phi-2' = 'Phi 2 (2.7B)',
   'tinyllama-1.1b' = 'TinyLlama (1.1B)'
-  // 'gpt-2' = 'GPT 2 (124M)'
 }

 export enum SupportedRemoteModel {
@@ -27,9 +26,8 @@ export const supportedModelReverseLookup: Record<
   [SupportedRemoteModel['gemini-pro']]: 'gemini-pro',
   [SupportedLocalModel['tinyllama-1.1b']]: 'tinyllama-1.1b',
   [SupportedLocalModel['llama-2-7b']]: 'llama-2-7b',
-  [SupportedLocalModel['phi-2']]: 'phi-2'
-  // [SupportedLocalModel['gpt-2']]: 'gpt-2'
-  // [SupportedLocalModel['mistral-7b-v0.2']]: 'mistral-7b-v0.2'
+  [SupportedLocalModel['phi-2']]: 'phi-2',
+  [SupportedLocalModel['gemma-2b']]: 'gemma-2b'
 };

 export enum ModelFamily {
@@ -48,8 +46,7 @@ export const modelFamilyMap: Record<
   [SupportedRemoteModel['gemini-pro']]: ModelFamily.google,
   [SupportedLocalModel['tinyllama-1.1b']]: ModelFamily.local,
   [SupportedLocalModel['llama-2-7b']]: ModelFamily.local,
-  // [SupportedLocalModel['gpt-2']]: ModelFamily.local
-  // [SupportedLocalModel['mistral-7b-v0.2']]: ModelFamily.local
+  [SupportedLocalModel['gemma-2b']]: ModelFamily.local,
   [SupportedLocalModel['phi-2']]: ModelFamily.local
 };
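
A note on why one new model fans out across several files: these lookup tables are typed with Record keyed (in whole or in part) by SupportedLocalModel, so adding the 'gemma-2b' enum member turns every table that lacks a matching entry into a compile error. A trimmed illustration (the enum matches the diff; exampleLookup is a stand-in for any of the real tables):

export enum SupportedLocalModel {
  'gemma-2b' = 'Gemma (2B)',
  'llama-2-7b' = 'Llama 2 (7B)',
  'phi-2' = 'Phi 2 (2.7B)',
  'tinyllama-1.1b' = 'TinyLlama (1.1B)'
}

// Record<SupportedLocalModel, string> requires a key for every enum member,
// so deleting the 'gemma-2b' line below would fail to compile. That is why
// localModelSizeMap, supportedModelReverseLookup, modelFamilyMap, and
// modelMap in web-llm.ts all gain a gemma-2b entry in this commit.
const exampleLookup: Record<SupportedLocalModel, string> = {
  [SupportedLocalModel['gemma-2b']]: 'gemma-2b',
  [SupportedLocalModel['llama-2-7b']]: 'llama-2-7b',
  [SupportedLocalModel['phi-2']]: 'phi-2',
  [SupportedLocalModel['tinyllama-1.1b']]: 'tinyllama-1.1b'
};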

examples/rag-playground/src/llms/web-llm.ts

Lines changed: 81 additions & 90 deletions
@@ -30,101 +30,78 @@ export type TextGenLocalWorkerMessage =
 //==========================================================================||
 // Worker Initialization ||
 //==========================================================================||
-const APP_CONFIGS: webllm.AppConfig = {
-  model_list: [
-    {
-      model_url:
-        'https://huggingface.co/mlc-ai/TinyLlama-1.1B-Chat-v0.4-q4f16_1-MLC/resolve/main/',
-      local_id: 'TinyLlama-1.1B-Chat-v0.4-q4f16_1',
-      model_lib_url:
-        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/TinyLlama-1.1B-Chat-v0.4/TinyLlama-1.1B-Chat-v0.4-q4f16_1-ctx1k-webgpu.wasm'
-    },
-    {
-      model_url:
-        'https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC/resolve/main/',
-      local_id: 'Llama-2-7b-chat-hf-q4f16_1',
-      model_lib_url:
-        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-ctx1k-webgpu.wasm'
-    },
-    {
-      model_url: 'https://huggingface.co/mlc-ai/gpt2-q0f16-MLC/resolve/main/',
-      local_id: 'gpt2-q0f16',
-      model_lib_url:
-        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/gpt2/gpt2-q0f16-ctx1k-webgpu.wasm'
-    },
-    {
-      model_url:
-        'https://huggingface.co/mlc-ai/Mistral-7B-Instruct-v0.2-q3f16_1-MLC/resolve/main/',
-      local_id: 'Mistral-7B-Instruct-v0.2-q3f16_1',
-      model_lib_url:
-        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm'
-    },
-    {
-      model_url:
-        'https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC/resolve/main/',
-      local_id: 'Phi2-q4f16_1',
-      model_lib_url:
-        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/phi-2/phi-2-q4f16_1-ctx2k-webgpu.wasm',
-      vram_required_MB: 3053.97,
-      low_resource_required: false,
-      required_features: ['shader-f16']
-    }
-  ]
-};
+enum Role {
+  user = 'user',
+  assistant = 'assistant'
+}

 const CONV_TEMPLATES: Record<
   SupportedLocalModel,
   Partial<ConvTemplateConfig>
 > = {
   [SupportedLocalModel['tinyllama-1.1b']]: {
-    system: '<|im_start|><|im_end|> ',
-    roles: ['<|im_start|>user', '<|im_start|>assistant'],
+    system_template: '<|im_start|><|im_end|> ',
+    roles: {
+      [Role.user]: '<|im_start|>user',
+      [Role.assistant]: '<|im_start|>assistant'
+    },
     offset: 0,
     seps: ['', ''],
-    separator_style: 'Two',
-    stop_str: '<|im_end|>',
-    add_bos: false,
-    stop_tokens: [2]
+    stop_str: ['<|im_end|>'],
+    stop_token_ids: [2]
   },
   [SupportedLocalModel['llama-2-7b']]: {
-    system: '[INST] <<SYS>><</SYS>>\n\n ',
-    roles: ['[INST]', '[/INST]'],
+    system_template: '[INST] <<SYS>><</SYS>>\n\n ',
+    roles: {
+      [Role.user]: '[INST]',
+      [Role.assistant]: '[/INST]'
+    },
     offset: 0,
     seps: [' ', ' '],
-    separator_style: 'Two',
-    stop_str: '[INST]',
-    add_bos: true,
-    stop_tokens: [2]
+    role_content_sep: ' ',
+    role_empty_sep: ' ',
+    stop_str: ['[INST]'],
+    system_prefix_token_ids: [1],
+    stop_token_ids: [2],
+    add_role_after_system_message: false
   },
   [SupportedLocalModel['phi-2']]: {
-    system: '',
-    roles: ['Instruct', 'Output'],
+    system_template: '',
+    system_message: '',
+    roles: {
+      [Role.user]: 'Instruct',
+      [Role.assistant]: 'Output'
+    },
     offset: 0,
     seps: ['\n'],
-    separator_style: 'Two',
-    stop_str: '<|endoftext|>',
-    add_bos: false,
-    stop_tokens: [50256]
+    stop_str: ['<|endoftext|>'],
+    stop_token_ids: [50256]
+  },
+  [SupportedLocalModel['gemma-2b']]: {
+    system_template: '',
+    system_message: '',
+    roles: {
+      [Role.user]: '<start_of_turn>user',
+      [Role.assistant]: '<start_of_turn>model'
+    },
+    offset: 0,
+    seps: ['<end_of_turn>\n', '<end_of_turn>\n'],
+    role_content_sep: '\n',
+    role_empty_sep: '\n',
+    stop_str: ['<end_of_turn>'],
+    system_prefix_token_ids: [2],
+    stop_token_ids: [1, 107]
   }
 };

 const modelMap: Record<SupportedLocalModel, string> = {
   [SupportedLocalModel['tinyllama-1.1b']]: 'TinyLlama-1.1B-Chat-v0.4-q4f16_1',
   [SupportedLocalModel['llama-2-7b']]: 'Llama-2-7b-chat-hf-q4f16_1',
-  [SupportedLocalModel['phi-2']]: 'Phi2-q4f16_1'
-  // [SupportedLocalModel['gpt-2']]: 'gpt2-q0f16'
-  // [SupportedLocalModel['mistral-7b-v0.2']]: 'Mistral-7B-Instruct-v0.2-q3f16_1'
+  [SupportedLocalModel['phi-2']]: 'Phi2-q4f16_1',
+  [SupportedLocalModel['gemma-2b']]: 'gemma-2b-it-q4f16_1'
 };

-const chat = new webllm.ChatModule();
-
-// To reset temperature, WebLLM requires to reload the model. Therefore, we just
-// fix the temperature for now.
-let _temperature = 0.2;
-
-let _modelLoadingComplete: Promise<void> | null = null;
-
-chat.setInitProgressCallback((report: webllm.InitProgressReport) => {
+const initProgressCallback = (report: webllm.InitProgressReport) => {
   // Update the main thread about the progress
   console.log(report.text);
   const message: TextGenLocalWorkerMessage = {
@@ -135,7 +112,9 @@ chat.setInitProgressCallback((report: webllm.InitProgressReport) => {
     }
   };
   postMessage(message);
-});
+};
+
+let engine: Promise<webllm.EngineInterface> | null = null;

 //==========================================================================||
 // Worker Event Handlers ||
@@ -179,15 +158,25 @@ const startLoadModel = async (
   model: SupportedLocalModel,
   temperature: number
 ) => {
-  _temperature = temperature;
   const curModel = modelMap[model];
-  const chatOption: webllm.ChatOptions = {
-    temperature: temperature,
-    conv_config: CONV_TEMPLATES[model],
-    conv_template: 'custom'
-  };
-  _modelLoadingComplete = chat.reload(curModel, chatOption, APP_CONFIGS);
-  await _modelLoadingComplete;
+
+  // Only use custom conv template for Llama to override the pre-included system
+  // prompt from WebLLM
+  let chatOption: webllm.ChatOptions | undefined = undefined;
+
+  if (model === SupportedLocalModel['llama-2-7b']) {
+    chatOption = {
+      conv_config: CONV_TEMPLATES[model],
+      conv_template: 'custom'
+    };
+  }
+
+  engine = webllm.CreateEngine(curModel, {
+    initProgressCallback: initProgressCallback,
+    chatOpts: chatOption
+  });
+
+  await engine;

   try {
     // Send back the data to the main thread
@@ -220,24 +209,26 @@
  */
 const startTextGen = async (prompt: string, temperature: number) => {
   try {
-    if (_modelLoadingComplete) {
-      await _modelLoadingComplete;
-    }
-
-    const truncated = prompt.slice(0, 2000);
-
-    const response = await chat.generate(truncated);
+    const curEngine = await engine!;
+    const response = await curEngine.chat.completions.create({
+      messages: [{ role: 'user', content: prompt }],
+      n: 1,
+      max_gen_len: 2048,
+      // Override temperature to 0 because local models are very unstable
+      temperature: 0
+      // logprobs: false
+    });

     // Reset the chat cache to avoid memorizing previous messages
-    await chat.resetChat();
+    await curEngine.resetChat();

     // Send back the data to the main thread
     const message: TextGenLocalWorkerMessage = {
       command: 'finishTextGen',
       payload: {
         requestID: 'web-llm',
         apiKey: '',
-        result: response,
+        result: response.choices[0].message.content || '',
         prompt: prompt,
         detail: ''
       }
@@ -263,7 +254,7 @@ const startTextGen = async (prompt: string, temperature: number) => {

 export const hasLocalModelInCache = async (model: SupportedLocalModel) => {
   const curModel = modelMap[model];
-  const inCache = await webllm.hasModelInCache(curModel, APP_CONFIGS);
+  const inCache = await webllm.hasModelInCache(curModel);
   return inCache;
 };
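
For reference on the new gemma-2b conversation template above, the field names suggest a single-turn prompt roughly like the sketch below. This is only an illustration of how the template fields combine; the helper is hypothetical and WebLLM's actual prompt builder may differ in detail.

// Hypothetical helper, not part of the commit: shows how the gemma-2b
// ConvTemplateConfig fields above would plausibly combine into a prompt.
const renderGemmaPromptSketch = (userMessage: string): string => {
  const userRole = '<start_of_turn>user'; // roles[Role.user]
  const modelRole = '<start_of_turn>model'; // roles[Role.assistant]
  const roleContentSep = '\n'; // role_content_sep
  const roleEmptySep = '\n'; // role_empty_sep
  const turnSep = '<end_of_turn>\n'; // seps[0]

  // One user turn, then an empty assistant turn for the model to complete.
  // Generation stops at '<end_of_turn>' (stop_str) or token ids 1 and 107
  // (stop_token_ids).
  return (
    `${userRole}${roleContentSep}${userMessage}${turnSep}` +
    `${modelRole}${roleEmptySep}`
  );
};

// renderGemmaPromptSketch('Hi!') returns:
// '<start_of_turn>user\nHi!<end_of_turn>\n<start_of_turn>model\n'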
