Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
913c55b
updated loading in llama 2 demo to use transformer bridge
degenfabian Aug 18, 2025
c9bc71e
Merge remote-tracking branch 'origin/dev-3.x' into llama_2_demo_trans…
bryce13950 Aug 20, 2025
3b42901
Merge remote-tracking branch 'origin/dev-3.x' into llama_2_demo_trans…
bryce13950 Aug 22, 2025
52f7b8a
Merge remote-tracking branch 'origin/dev-3.x' into llama_2_demo_trans…
bryce13950 Aug 26, 2025
20cc89c
Merge remote-tracking branch 'origin/dev-3.x' into llama_2_demo_trans…
bryce13950 Sep 4, 2025
8319cc2
Merge remote-tracking branch 'origin/dev-3.x' into llama_2_demo_trans…
bryce13950 Sep 5, 2025
e26da3d
Merge remote-tracking branch 'origin/dev-3.x' into llama_2_demo_trans…
bryce13950 Sep 6, 2025
95aa144
Merge remote-tracking branch 'origin/dev-3.x' into llama_2_demo_trans…
bryce13950 Sep 7, 2025
d5e1327
Merge remote-tracking branch 'origin/dev-3.x' into llama_2_demo_trans…
bryce13950 Sep 10, 2025
ea0f147
Merge remote-tracking branch 'origin/dev-3.x' into llama_2_demo_trans…
bryce13950 Sep 10, 2025
bd85cb6
Merge remote-tracking branch 'origin/dev-3.x' into llama_2_demo_trans…
bryce13950 Sep 12, 2025
c419c4f
Merge remote-tracking branch 'origin/dev-3.x' into llama_2_demo_trans…
bryce13950 Sep 12, 2025
a030fdf
Merge remote-tracking branch 'origin/dev-3.x' into llama_2_demo_trans…
bryce13950 Sep 12, 2025
5dd19d1
Merge remote-tracking branch 'origin/dev-3.x-folding' into llama_2_de…
bryce13950 Oct 10, 2025
23efa9a
Merge remote-tracking branch 'origin/dev-3.x-folding' into llama_2_de…
bryce13950 Oct 13, 2025
65d501a
Merge remote-tracking branch 'origin/dev-3.x-folding' into llama_2_de…
bryce13950 Oct 14, 2025
ef8951b
Merge remote-tracking branch 'origin/dev-3.x-folding' into llama_2_de…
bryce13950 Oct 14, 2025
57fb25e
Merge remote-tracking branch 'origin/dev-3.x-folding' into llama_2_de…
bryce13950 Oct 15, 2025
d6dc710
Merge remote-tracking branch 'origin/dev-3.x-folding' into llama_2_de…
bryce13950 Oct 15, 2025
dd571c1
Merge remote-tracking branch 'origin/dev-3.x-folding' into llama_2_de…
bryce13950 Oct 15, 2025
e595bab
Merge remote-tracking branch 'origin/dev-3.x-folding' into llama_2_de…
bryce13950 Oct 16, 2025
0066317
Merge remote-tracking branch 'origin/dev-3.x-folding' into llama_2_de…
bryce13950 Oct 16, 2025
b2af109
Merge remote-tracking branch 'origin/dev-3.x-folding' into llama_2_de…
bryce13950 Oct 16, 2025
5bd42bf
Merge remote-tracking branch 'origin/dev-3.x-folding' into llama_2_de…
bryce13950 Oct 16, 2025
5b588af
Merge remote-tracking branch 'origin/dev-3.x-folding' into llama_2_de…
bryce13950 Oct 16, 2025
cbbdd0b
Merge remote-tracking branch 'origin/dev-3.x-folding' into llama_2_de…
bryce13950 Oct 17, 2025
d7bc861
Merge remote-tracking branch 'origin/dev-3.x-folding' into llama_2_de…
bryce13950 Oct 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ jobs:
# - "Head_Detector_Demo"
# - "Interactive_Neuroscope"
# - "LLaMA"
# - "LLaMA2_GPU_Quantized"
- "LLaMA2_GPU_Quantized"
- "Main_Demo"
# - "No_Position_Experiment"
- "Othello_GPT"
Expand Down
14 changes: 8 additions & 6 deletions demos/LLaMA2_GPU_Quantized.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {
"id": "P8zS3MPkCUsR"
},
Expand All @@ -232,7 +232,7 @@
"from transformer_lens.hook_points import (\n",
" HookPoint,\n",
") # Hooking utilities\n",
"from transformer_lens import HookedTransformer\n",
"from transformer_lens.model_bridge import TransformerBridge\n",
"\n",
"torch.set_grad_enabled(False)\n",
"\n",
Expand Down Expand Up @@ -291,7 +291,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {
"id": "RdJ0AuW_CUsS"
},
Expand All @@ -303,7 +303,8 @@
" tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)\n",
" hf_model = LlamaForCausalLM.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True)\n",
"\n",
" model = HookedTransformer.from_pretrained(\"llama-7b\", hf_model=hf_model, device=\"cpu\", fold_ln=False, center_writing_weights=False, center_unembed=False, tokenizer=tokenizer)\n",
" model = TransformerBridge.boot_transformers(\"llama-7b\", hf_model=hf_model, device=\"cpu\", fold_ln=False, center_writing_weights=False, center_unembed=False, tokenizer=tokenizer)\n",
" model.enable_compatibility_mode()\n",
"\n",
" model = model.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
" model.generate(\"The capital of Germany is\", max_new_tokens=20, temperature=0)"
Expand Down Expand Up @@ -406,7 +407,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand Down Expand Up @@ -730,14 +731,15 @@
"\n",
"tokenizer = AutoTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH)\n",
"\n",
"model = HookedTransformer.from_pretrained(LLAMA_2_7B_CHAT_PATH,\n",
"model = TransformerBridge.boot_transformers(LLAMA_2_7B_CHAT_PATH,\n",
" hf_model=hf_model,\n",
" dtype=inference_dtype,\n",
" fold_ln=False,\n",
" fold_value_biases=False,\n",
" center_writing_weights=False,\n",
" center_unembed=False,\n",
" tokenizer=tokenizer)\n",
"model.enable_compatibility_mode()\n",
"\n",
"model.generate(\"The capital of Germany is\", max_new_tokens=2, temperature=0)\n",
"\n"
Expand Down
Loading