31 commits
07fd8c9
updated loading in attribution patching demo to use transformer bridge
degenfabian Aug 18, 2025
6b0831d
updated loading in bert demo to use transformer bridge
degenfabian Aug 18, 2025
f06e259
Merge remote-tracking branch 'origin/dev-3.x' into bert_demo_transfor…
bryce13950 Aug 20, 2025
9372054
Merge remote-tracking branch 'origin/dev-3.x' into bert_demo_transfor…
bryce13950 Aug 22, 2025
35759e4
Merge remote-tracking branch 'origin/dev-3.x' into bert_demo_transfor…
bryce13950 Aug 26, 2025
0c1abde
Merge remote-tracking branch 'origin/dev-3.x' into bert_demo_transfor…
bryce13950 Sep 4, 2025
09e2a76
Merge remote-tracking branch 'origin/dev-3.x' into bert_demo_transfor…
bryce13950 Sep 5, 2025
75bdc2e
Merge remote-tracking branch 'origin/dev-3.x' into bert_demo_transfor…
bryce13950 Sep 6, 2025
2a589af
Merge remote-tracking branch 'origin/dev-3.x' into bert_demo_transfor…
bryce13950 Sep 7, 2025
97fab2f
Merge remote-tracking branch 'origin/dev-3.x' into bert_demo_transfor…
bryce13950 Sep 10, 2025
c9b44f2
Merge remote-tracking branch 'origin/dev-3.x' into bert_demo_transfor…
bryce13950 Sep 10, 2025
f124e03
Merge remote-tracking branch 'origin/dev-3.x' into bert_demo_transfor…
bryce13950 Sep 12, 2025
7a5f0e3
Merge remote-tracking branch 'origin/dev-3.x' into bert_demo_transfor…
bryce13950 Sep 12, 2025
f67c8d6
Merge remote-tracking branch 'origin/dev-3.x' into bert_demo_transfor…
bryce13950 Sep 12, 2025
22509d6
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Oct 10, 2025
6730dc4
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Oct 13, 2025
96bf976
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Oct 14, 2025
15ff003
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Oct 14, 2025
638c199
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Oct 15, 2025
a6cc320
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Oct 15, 2025
592d56d
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Oct 15, 2025
5a60db3
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Oct 16, 2025
5e292f2
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Oct 16, 2025
e527b45
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Oct 16, 2025
4d58274
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Oct 16, 2025
5f02fe3
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Oct 16, 2025
68f13f1
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Oct 17, 2025
43f1f43
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Oct 23, 2025
3cf3c00
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Nov 12, 2025
49329a2
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Nov 12, 2025
d23b42e
Merge remote-tracking branch 'origin/dev-3.x-folding' into bert_demo_…
bryce13950 Nov 12, 2025
2 changes: 1 addition & 1 deletion .github/workflows/checks.yml
@@ -225,7 +225,7 @@ jobs:
matrix:
notebook:
# - "Activation_Patching_in_TL_Demo"
- # - "Attribution_Patching_Demo"
+ - "Attribution_Patching_Demo"
- "ARENA_Content"
- "BERT"
- "Exploratory_Analysis_Demo"
3,761 changes: 3,760 additions & 1 deletion demos/Attribution_Patching_Demo.ipynb

Large diffs are not rendered by default.

17 changes: 8 additions & 9 deletions demos/BERT.ipynb
@@ -148,7 +148,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -157,7 +157,7 @@
"\n",
"from transformers import AutoTokenizer\n",
"\n",
- "from transformer_lens import HookedEncoder, BertNextSentencePrediction"
+ "from transformer_lens.model_bridge import TransformerBridge"
]
},
{
@@ -192,7 +192,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -214,7 +214,8 @@
],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
- "bert = HookedEncoder.from_pretrained(\"bert-base-cased\")\n",
+ "bert = TransformerBridge.boot_transformers(\"bert-base-cased\")\n",
+ "bert.enable_compatibility_mode()\n",
"tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")"
]
},
@@ -287,14 +288,13 @@
"metadata": {},
"source": [
"## Next Sentence Prediction\n",
- "To carry out Next Sentence Prediction, you have to use the class BertNextSentencePrediction, and pass a HookedEncoder in its constructor. \n",
- "Then, create a list with the two sentences you want to perform NSP on as elements and use that as input to the forward function. \n",
+ "To carry out Next Sentence Prediction create a list with the two sentences you want to perform NSP on as elements and use that as input to the forward function. \n",
"The model will then predict the probability of the sentence at position 1 following (i.e. being the next sentence) to the sentence at position 0."
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -308,13 +308,12 @@
}
],
"source": [
- "nsp = BertNextSentencePrediction(bert)\n",
"sentence_a = \"A man walked into a grocery store.\"\n",
"sentence_b = \"He bought an apple.\"\n",
"\n",
"input = [sentence_a, sentence_b]\n",
"\n",
- "predictions = nsp(input, return_type=\"predictions\")\n",
+ "predictions = bert(input, return_type=\"predictions\")\n",
"\n",
"print(f\"Sentence A: {sentence_a}\")\n",
"print(f\"Sentence B: {sentence_b}\")\n",
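Note on the Next Sentence Prediction cell: the notebook text says the model predicts the probability that the sentence at position 1 follows the sentence at position 0. As a rough illustration of what that probability means, the sketch below turns a pair of NSP logits into probabilities with a softmax. It does not call TransformerBridge at all. The helper name `nsp_probabilities` and the logit values are hypothetical, and the index convention (index 0 means "B follows A") is assumed from Hugging Face's BERT NSP head rather than confirmed by this diff.

```python
import math


def nsp_probabilities(logits):
    """Convert a pair of NSP logits into probabilities using a softmax.

    Hypothetical helper for illustration only; it is not part of
    transformer_lens or of this pull request.
    """
    # Subtract the maximum logit first for numerical stability.
    highest = max(logits)
    exps = [math.exp(value - highest) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


# Made-up logits, not produced by a real model.
# Assumed convention: index 0 = "B follows A", index 1 = "B is unrelated".
example_logits = [2.0, -1.0]
probs = nsp_probabilities(example_logits)

# Under that assumed convention, B is predicted to follow A when the
# first probability is the larger of the two.
is_next = probs[0] > probs[1]
print(f"P(B follows A) = {probs[0]:.3f}, predicted next sentence: {is_next}")
```

With two classes the softmax reduces to a sigmoid of the logit difference, so only the gap between the two logits affects the result.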