
Commit c9be901

[FFI] Refactored WebLLM per new TVM ffi changes
1 parent d8b25fe commit c9be901

File tree

7 files changed: +286 −6 lines changed


examples/model-tests/README.md

Whitespace-only changes.

examples/model-tests/package.json

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
{
  "name": "model-tests",
  "version": "0.1.0",
  "private": true,
  "scripts": {
    "start": "parcel src/model_tests.html --port 8889",
    "build": "parcel build src/model_tests.html --dist-dir lib"
  },
  "devDependencies": {
    "buffer": "^5.7.1",
    "parcel": "^2.8.3",
    "process": "^0.11.10",
    "tslib": "^2.3.1",
    "typescript": "^4.9.5",
    "url": "^0.11.3"
  },
  "dependencies": {
    "@mlc-ai/web-llm": "file:../../"
  }
}

examples/model-tests/src/model_tests.html

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
<!doctype html>
<html>
  <script>
    webLLMGlobal = {};
  </script>
  <body>
    <h2>WebLLM Model Tester</h2>
    Open console to see output
    <br />
    <br />
    <label id="init-label"> </label>

    <h3>Current Model</h3>
    <label id="current-model-label"> </label>

    <h3>Progress</h3>
    <label id="progress-label"> </label>

    <h3>Latest Response</h3>
    <label id="response-label"> </label>
    <br />
    <label id="stats-label"> </label>

    <script type="module" src="./model_tests.ts"></script>
  </body>
</html>

examples/model-tests/src/model_tests.ts

Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
import * as webllm from "@mlc-ai/web-llm";

function setLabel(id: string, text: string) {
  const label = document.getElementById(id);
  if (label == null) {
    throw Error("Cannot find label " + id);
  }
  label.innerText = text;
}

// Models to test: uncomment the specific ones you want to test
const TEST_MODELS = [
  // Llama 2 7B
  // "Llama-2-7b-chat-hf-q4f16_1-MLC",
  // "Llama-2-7b-chat-hf-q4f32_1-MLC",

  // // Llama 3 8B
  // "Llama-3-8B-Instruct-q4f16_1-MLC",
  // "Llama-3-8B-Instruct-q4f32_1-MLC",

  // // Llama 3.1 8B
  // "Llama-3.1-8B-Instruct-q4f16_1-MLC",
  // "Llama-3.1-8B-Instruct-q4f32_1-MLC",

  // // Llama 3.2 1B, 3B
  // "Llama-3.2-1B-Instruct-q4f16_1-MLC",
  // "Llama-3.2-1B-Instruct-q4f32_1-MLC",
  // "Llama-3.2-3B-Instruct-q4f16_1-MLC",
  // "Llama-3.2-3B-Instruct-q4f32_1-MLC",

  // // Mistral 7B v0.3
  // "Mistral-7B-Instruct-v0.3-q4f16_1-MLC",
  // "Mistral-7B-Instruct-v0.3-q4f32_1-MLC",

  // // Phi models
  // "phi-1_5-q4f16_1-MLC",
  // "phi-1_5-q4f32_1-MLC",
  // "phi-2-q4f16_1-MLC",
  // "phi-2-q4f32_1-MLC",
  // "Phi-3-mini-4k-instruct-q4f16_1-MLC",
  // "Phi-3-mini-4k-instruct-q4f32_1-MLC",
  // "Phi-3.5-mini-instruct-q4f16_1-MLC",
  // "Phi-3.5-mini-instruct-q4f32_1-MLC",

  // // Qwen2
  "Qwen2-0.5B-Instruct-q4f16_1-MLC",
  // "Qwen2-0.5B-Instruct-q4f32_1-MLC",
  // "Qwen2-1.5B-Instruct-q4f16_1-MLC",
  // "Qwen2-1.5B-Instruct-q4f32_1-MLC",

  // // Qwen2.5
  // "Qwen2.5-3B-Instruct-q4f16_1-MLC",
  // "Qwen2.5-3B-Instruct-q4f32_1-MLC",

  // // Qwen3 (including q0 for 0.6B)
  // "Qwen3-0.6B-q4f16_1-MLC",
  // "Qwen3-0.6B-q4f32_1-MLC",
  // "Qwen3-0.6B-q0f32-MLC",
  // "Qwen3-1.7B-q4f16_1-MLC",
  // "Qwen3-1.7B-q4f32_1-MLC",
  // "Qwen3-4B-q4f16_1-MLC",
  // "Qwen3-4B-q4f32_1-MLC",
  // "Qwen3-8B-q4f16_1-MLC",
  // "Qwen3-8B-q4f32_1-MLC",

  // // RedPajama
  // "RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC",
  // "RedPajama-INCITE-Chat-3B-v1-q4f32_1-MLC",

  // // SmolLM2 (including q0 for smaller ones)
  // "SmolLM2-135M-Instruct-q0f16-MLC",
  // "SmolLM2-135M-Instruct-q0f32-MLC",
  // "SmolLM2-360M-Instruct-q0f16-MLC",
  // "SmolLM2-360M-Instruct-q0f32-MLC",
  // "SmolLM2-1.7B-Instruct-q4f16_1-MLC",
  // "SmolLM2-1.7B-Instruct-q4f32_1-MLC",

  // // TinyLlama v1.0
  // "TinyLlama-1.1B-Chat-v1.0-q4f16_1-MLC",
  // "TinyLlama-1.1B-Chat-v1.0-q4f32_1-MLC",

  // // Gemma models
  // "gemma-2b-it-q4f16_1-MLC",
  // "gemma-2b-it-q4f32_1-MLC",
  // "gemma-2-2b-it-q4f16_1-MLC",
  // "gemma-2-2b-it-q4f32_1-MLC",
  // "gemma-2-9b-it-q4f16_1-MLC",
  // "gemma-2-9b-it-q4f32_1-MLC",

  // // StableLM
  // "stablelm-2-zephyr-1_6b-q4f16_1-MLC",
  // "stablelm-2-zephyr-1_6b-q4f32_1-MLC",
];

const TEST_PROMPT = "Tell me a joke.";

const initProgressCallback = (report: webllm.InitProgressReport) => {
  setLabel("init-label", report.text);
};

async function testModel(
  modelId: string,
  modelIndex: number,
  totalModels: number,
): Promise<boolean> {
  try {
    // print output into console
    console.log(
      `\n=== Testing Model ${modelIndex + 1}/${totalModels}: ${modelId} ===`,
    );
    setLabel(
      "current-model-label",
      `${modelId} (${modelIndex + 1}/${totalModels})`,
    );
    setLabel("progress-label", `Loading model...`);
    setLabel("response-label", "");

    const startTime = Date.now();

    const appConfig = webllm.prebuiltAppConfig;
    appConfig.useIndexedDBCache = true;

    const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
      modelId,
      {
        initProgressCallback: initProgressCallback,
        appConfig: appConfig,
        logLevel: "ERROR",
      },
    );

    const loadTime = Date.now() - startTime;
    console.log(`Model loaded in ${(loadTime / 1000).toFixed(1)}s`);
    setLabel(
      "progress-label",
      `Model loaded in ${(loadTime / 1000).toFixed(1)}s. Generating...`,
    );

    // Test chat completion
    const generateStart = Date.now();
    const reply = await engine.chat.completions.create({
      messages: [{ role: "user", content: TEST_PROMPT }],
      temperature: 0.1,
      max_tokens: 500,
    });

    const generateTime = Date.now() - generateStart;
    const response = reply.choices[0]?.message?.content || "No response";

    console.log(`Generated response in ${(generateTime / 1000).toFixed(1)}s`);
    console.log(`Response: "${response}"`);

    setLabel(
      "response-label",
      response.substring(0, 200) + (response.length > 200 ? "..." : ""),
    );
    setLabel(
      "stats-label",
      `Load: ${(loadTime / 1000).toFixed(1)}s, Generate: ${(generateTime / 1000).toFixed(1)}s, Tokens: ${reply.usage?.completion_tokens || "?"}`,
    );

    // Clear cache for this model
    setLabel("progress-label", `Clearing cache...`);
    await webllm.deleteModelAllInfoInCache(modelId, appConfig);
    console.log(`Cleared cache for ${modelId}`);

    return true;
  } catch (error) {
    console.error(`Error testing ${modelId}:`, error);
    setLabel("response-label", `Error: ${error.message}`);
    setLabel("progress-label", `Error with ${modelId}`);

    // Still try to clear cache even if test failed
    try {
      const appConfig = webllm.prebuiltAppConfig;
      appConfig.useIndexedDBCache = true;
      await webllm.deleteModelAllInfoInCache(modelId, appConfig);
      console.log(`Cleared cache for ${modelId} (after error)`);
    } catch (clearError) {
      console.error(`Failed to clear cache for ${modelId}:`, clearError);
    }

    return false;
  }
}

async function main() {
  console.log("Starting WebLLM Model Testing");
  console.log(`Testing ${TEST_MODELS.length} chat models`);

  const results = {
    passed: 0,
    failed: 0,
    total: TEST_MODELS.length,
  };

  setLabel("current-model-label", "Starting tests...");
  setLabel("progress-label", `0/${TEST_MODELS.length} models tested`);

  for (let i = 0; i < TEST_MODELS.length; i++) {
    const modelId = TEST_MODELS[i];
    const success = await testModel(modelId, i, TEST_MODELS.length);

    if (success) {
      results.passed++;
    } else {
      results.failed++;
    }

    setLabel(
      "progress-label",
      `${i + 1}/${TEST_MODELS.length} models tested (${results.passed} passed, ${results.failed} failed)`,
    );

    await new Promise((resolve) => setTimeout(resolve, 1000));
  }

  console.log(`\nTesting completed!`);
  console.log(
    `Results: ${results.passed}/${results.total} models passed (${Math.round((results.passed / results.total) * 100)}%)`,
  );
  console.log(`Passed: ${results.passed}`);
  console.log(`Failed: ${results.failed}`);

  setLabel("current-model-label", "All tests completed!");
  setLabel(
    "progress-label",
    `Final: ${results.passed}/${results.total} passed (${Math.round((results.passed / results.total) * 100)}%)`,
  );
  setLabel("response-label", "Check console for full results");
  setLabel("stats-label", `${results.passed} passed, ${results.failed} failed`);
}

main();
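
As a small illustration of the timing data the script already records, a helper like the following (hypothetical, not part of the committed test script) could turn `reply.usage?.completion_tokens` and the measured `generateTime` into a decode-throughput figure:

// Hypothetical helper (not in the commit): derive tokens/second from the
// values testModel() already records.
function tokensPerSecond(
  completionTokens: number | undefined,
  generateTimeMs: number,
): string {
  if (!completionTokens || generateTimeMs <= 0) return "?";
  return (completionTokens / (generateTimeMs / 1000)).toFixed(1);
}

// Example use inside testModel(), alongside the existing stats-label update:
// `Tok/s: ${tokensPerSecond(reply.usage?.completion_tokens, generateTime)}`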

src/config.ts

Lines changed: 4 additions & 4 deletions
@@ -287,7 +287,7 @@ export interface AppConfig {
  * @note The model version does not have to match the npm version, since not each npm update
  * requires an update of the model libraries.
  */
-export const modelVersion = "v0_2_48";
+export const modelVersion = "v0_2_80";
 export const modelLibURLPrefix =
   "https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/web-llm-models/";

@@ -1190,7 +1190,7 @@ export const prebuiltAppConfig: AppConfig = {
       model_lib:
         modelLibURLPrefix +
         modelVersion +
-        "/Qwen2-0.5B-Instruct-q4f16_1-ctx4k_cs1k-webgpu.wasm",
+        "/Qwen2-0.5B-Instruct-testtokenizer-q4f16_1-ctx4k_cs1k-webgpu.wasm",
       low_resource_required: true,
       vram_required_MB: 944.62,
       overrides: {

@@ -1322,7 +1322,7 @@ export const prebuiltAppConfig: AppConfig = {
       model_lib:
         modelLibURLPrefix +
         modelVersion +
-        "/Qwen2-0.5B-Instruct-q4f16_1-ctx4k_cs1k-webgpu.wasm",
+        "/Qwen2-0.5B-Instruct-testtokenizer-q4f16_1-ctx4k_cs1k-webgpu.wasm",
       low_resource_required: true,
       vram_required_MB: 944.62,
       overrides: {

@@ -1677,7 +1677,7 @@ export const prebuiltAppConfig: AppConfig = {
       model_lib:
         modelLibURLPrefix +
         modelVersion +
-        "/Qwen2-0.5B-Instruct-q4f16_1-ctx4k_cs1k-webgpu.wasm",
+        "/Qwen2-0.5B-Instruct-testtokenizer-q4f16_1-ctx4k_cs1k-webgpu.wasm",
       low_resource_required: true,
       vram_required_MB: 944.62,
       overrides: {
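
For reference, a minimal sketch of how the updated model_lib URL resolves after this change (values copied from the diff above; variable names mirror src/config.ts):

// Sketch only: reproduces the concatenation from src/config.ts with the
// values shown in this diff.
const modelLibURLPrefix =
  "https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/web-llm-models/";
const modelVersion = "v0_2_80";

const qwen2TestTokenizerLib =
  modelLibURLPrefix +
  modelVersion +
  "/Qwen2-0.5B-Instruct-testtokenizer-q4f16_1-ctx4k_cs1k-webgpu.wasm";
// Resolves to .../web-llm-models/v0_2_80/Qwen2-0.5B-Instruct-testtokenizer-q4f16_1-ctx4k_cs1k-webgpu.wasm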

src/embedding.ts

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ export class EmbeddingPipeline {
     // 2. Get json stored in the vm's metadata function
     const fgetMetadata = this.vm.getFunction("_metadata");
     const ret_value = fgetMetadata();
-    const metadataStr = this.tvm.detachFromCurrentScope(ret_value).toString();
+    const metadataStr = ret_value.toString();
     const metadata = JSON.parse(metadataStr);

     // 3. Load parameters by name

src/llm_chat.ts

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ export class LLMChatPipeline {
     // 2. Get json stored in the vm's metadata function
     const fgetMetadata = this.vm.getFunction("_metadata");
     const ret_value = fgetMetadata();
-    const metadataStr = this.tvm.detachFromCurrentScope(ret_value).toString();
+    const metadataStr = ret_value.toString();
     const metadata = JSON.parse(metadataStr);

     // 3. Load parameters by name
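
Both pipelines now consume the `_metadata` function's return value directly. A minimal sketch of the pattern after the refactor, with the VM parameter typed loosely as `any` since the tvmjs types are not part of this diff:

// Sketch of the metadata read after the ffi refactor; `vm` stands in for the
// pipeline's already-instantiated relax VM.
function readPipelineMetadata(vm: any): Record<string, unknown> {
  const fgetMetadata = vm.getFunction("_metadata");
  const ret_value = fgetMetadata();
  // Under the new TVM ffi the value can be stringified directly; the previous
  // this.tvm.detachFromCurrentScope(ret_value) wrapper is no longer needed.
  return JSON.parse(ret_value.toString());
}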
