diff --git a/python/Demonstration.ipynb b/python/Demonstration.ipynb
index 52ac7c12..84e9cd18 100644
--- a/python/Demonstration.ipynb
+++ b/python/Demonstration.ipynb
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# CircuitsVis Demonstration"
+ "# CircuitsVis Demonstration\n"
]
},
{
@@ -15,7 +15,7 @@
"source": [
"## Setup/Imports\n",
"\n",
- "__Note:__ To run Jupyter directly within this repo, you may need to run `poetry run pip install jupyter`."
+ "**Note:** To run Jupyter directly within this repo, you may need to run `poetry run pip install jupyter`.\n"
]
},
{
@@ -27,8 +27,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "The autoreload extension is already loaded. To reload it, use:\n",
- " %reload_ext autoreload\n"
+ "In dev mode: True\n"
]
}
],
@@ -39,12 +38,14 @@
"\n",
"# Imports\n",
"import numpy as np\n",
- "from circuitsvis.attention import attention_patterns, attention_pattern\n",
+ "from circuitsvis.attention import attention_patterns, attention_pattern, attention_heads\n",
"from circuitsvis.activations import text_neuron_activations\n",
- "from circuitsvis.examples import hello\n",
"from circuitsvis.tokens import colored_tokens\n",
"from circuitsvis.topk_tokens import topk_tokens\n",
- "from circuitsvis.topk_samples import topk_samples"
+ "from circuitsvis.topk_samples import topk_samples\n",
+ "\n",
+ "from circuitsvis.utils.render import is_in_dev_mode\n",
+ "print(\"In dev mode:\", is_in_dev_mode())"
]
},
{
@@ -53,82 +54,9906 @@
"tags": []
},
"source": [
- "## Built In Visualizations"
+ "## Built In Visualizations\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Activations\n"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Text Neuron Activations (single sample)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example']\n",
+ "n_layers = 3\n",
+ "n_neurons_per_layer = 4\n",
+ "activations = np.random.normal(size=(len(tokens), n_layers, n_neurons_per_layer))\n",
+ "activations = np.exp(activations) / np.exp(activations).sum(axis=0, keepdims=True)\n",
+ "text_neuron_activations(tokens=tokens, activations=activations)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Text Neuron Activations (multiple samples)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokens = [['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example'], ['This', ' is', ' another', ' example', ' of', ' colored', ' tokens'], ['And', ' here', ' another', ' example', ' of', ' colored', ' tokens', ' with', ' more', ' words.'], ['This', ' is', ' another', ' example', ' of', ' tokens.']]\n",
+ "n_layers = 3\n",
+ "n_neurons_per_layer = 4\n",
+ "activations = []\n",
+ "for sample in tokens:\n",
+ " sample_activations = np.random.normal(size=(len(sample), n_layers, n_neurons_per_layer)) * 5\n",
+ " activations.append(sample_activations)\n",
+ "text_neuron_activations(tokens=tokens, activations=activations)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Activations"
+ "### Attention\n"
]
},
{
- "attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
- "#### Text Neuron Activations (single sample)"
+ "#### Attention Heads\n"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 2,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example']\n",
- "n_layers = 3\n",
- "n_neurons_per_layer = 4\n",
- "activations = np.random.normal(size=(len(tokens), n_layers, n_neurons_per_layer))\n",
- "activations = np.exp(activations) / np.exp(activations).sum(axis=0, keepdims=True) \n",
- "text_neuron_activations(tokens=tokens, activations=activations)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### Text Neuron Activations (multiple samples)"
+ "n_tokens = len(tokens)\n",
+ "single_head_attention = np.tril(np.random.normal(loc=0.3, scale=0.2, size=(n_tokens, n_tokens)))\n",
+ "attention_heads(tokens=tokens, attention=single_head_attention, show_tokens=True)"
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 3,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "tokens = [['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example'], ['This', ' is', ' another', ' example', ' of', ' colored', ' tokens'], ['And', ' here', ' another', ' example', ' of', ' colored', ' tokens', ' with', ' more', ' words.'], ['This', ' is', ' another', ' example', ' of', ' tokens.']]\n",
"n_layers = 3\n",
- "n_neurons_per_layer = 4\n",
- "activations = []\n",
- "for sample in tokens:\n",
- " sample_activations = np.random.normal(size=(len(sample), n_layers, n_neurons_per_layer)) * 5\n",
- " activations.append(sample_activations)\n",
- "text_neuron_activations(tokens=tokens, activations=activations)\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Attention"
+ "n_heads = 4\n",
+ "\n",
+ "multi_head_attention = np.tril(np.random.normal(loc=0.3, scale=0.2, size=(n_layers * n_heads, n_tokens, n_tokens)))\n",
+ "head_names = [f\"L{layer}H{head}\" for layer in range(n_layers) for head in range(n_heads)]\n",
+ "attention_heads(tokens=tokens, attention=multi_head_attention, attention_head_names=head_names, show_tokens=True, match_color=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "#### Attention Pattern (single head)"
+ "#### Attention Pattern (single head)\n"
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 4,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -14797,67 +24602,67 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "#### Attention Patterns"
+ "#### Attention Patterns [deprecated]\n"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 5,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -19701,74 +29506,74 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Tokens"
+ "### Tokens\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "#### Colored Tokens"
+ "#### Colored Tokens\n"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 6,
+ "execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -24613,67 +34418,67 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Topk Tokens Table"
+ "### Topk Tokens Table\n"
]
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 7,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -29527,67 +39332,67 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Topk Samples"
+ "### Topk Samples\n"
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 8,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@@ -34437,25 +44242,24 @@
"activations = []\n",
"for neuron in range(len(tokens)):\n",
" neuron_acts = []\n",
- " \n",
+ "\n",
" for k in range(len(tokens[0])):\n",
" acts = (np.random.normal(size=(len(tokens[neuron][k]))) * 5).tolist()\n",
" neuron_acts.append(acts)\n",
" activations.append(neuron_acts)\n",
- " \n",
+ "\n",
"# Assume we have an arbitrary selection of neurons\n",
"neuron_labels = [2, 7, 9]\n",
"# Wrap tokens and activations in an outer list to represent the single layer\n",
- "topk_samples(tokens=[tokens], activations=[activations], zeroth_dimension_name=\"Layer\", first_dimension_name=\"Neuron\", first_dimension_labels=neuron_labels)\n",
- "\n"
+ "topk_samples(tokens=[tokens], activations=[activations], zeroth_dimension_name=\"Layer\", first_dimension_name=\"Neuron\", first_dimension_labels=neuron_labels)"
]
}
],
"metadata": {
"kernelspec": {
- "display_name": "circuitsvis-env",
+ "display_name": ".venv",
"language": "python",
- "name": "circuitsvis-env"
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -34467,12 +44271,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.8"
- },
- "vscode": {
- "interpreter": {
- "hash": "ada5ea967828749ea6c7f5c93ea14cd73d82db7939f837b7070fa8806da132ee"
- }
+ "version": "3.12.3"
}
},
"nbformat": 4,
diff --git a/python/circuitsvis/attention.py b/python/circuitsvis/attention.py
index 3da78d4e..317902b9 100644
--- a/python/circuitsvis/attention.py
+++ b/python/circuitsvis/attention.py
@@ -3,6 +3,7 @@
import numpy as np
import torch
+
from circuitsvis.utils.render import RenderedHTML, render
@@ -15,6 +16,8 @@ def attention_heads(
negative_color: Optional[str] = None,
positive_color: Optional[str] = None,
mask_upper_tri: Optional[bool] = None,
+ show_tokens: Optional[bool] = None,
+ match_color: Optional[bool] = None,
) -> RenderedHTML:
"""Attention Heads
@@ -24,8 +27,8 @@ def attention_heads(
is then shown in full size.
Args:
- attention: Attention head activations of the shape [dest_tokens x
- src_tokens]
+ attention: Attention head activations of the shape [heads x dest_tokens x src_tokens]
+ or [dest_tokens x src_tokens] (will be expanded to single head)
tokens: List of tokens (e.g. `["A", "person"]`). Must be the same length
as the list of values.
max_value: Maximum value. Used to determine how dark the token color is
@@ -41,12 +44,43 @@ def attention_heads(
mask_upper_tri: Whether or not to mask the upper triangular portion of
the attention patterns. Should be true for causal attention, false for
bidirectional attention.
+ show_tokens: Whether to show interactive token visualization where
+ hovering over tokens shows attention strength to other tokens.
+ match_color: Whether to match colors between attention patterns, token
+ visualization, and head headers for visual consistency.
Returns:
Html: Attention pattern visualization
"""
+
+ # Convert attention to numpy array if it's not already
+ attention_np: np.ndarray
+ if isinstance(attention, torch.Tensor):
+ attention_np = attention.detach().cpu().numpy()
+ elif isinstance(attention, np.ndarray):
+ attention_np = attention
+ else:
+ attention_np = np.array(attention)
+
+ # Ensure attention is 3D (num_heads, dest_len, src_len)
+ if attention_np.ndim == 2:
+ attention_np = attention_np[np.newaxis, :, :]
+ elif attention_np.ndim != 3:
+ raise ValueError(
+ f"Attention tensor must be 2D or 3D, got {attention_np.ndim}D tensor."
+ )
+
+ num_heads, dest_len, src_len = attention_np.shape
+
+ # Validate token count matches attention dimensions
+ if len(tokens) != dest_len or len(tokens) != src_len:
+ raise ValueError(
+ f"Token count ({len(tokens)}) doesn't match attention dimensions "
+ f"(dest: {dest_len}, src: {src_len}). For causal attention, these should all be equal."
+ )
+
kwargs = {
- "attention": attention,
+ "attention": attention_np,
"attentionHeadNames": attention_head_names,
"maxValue": max_value,
"minValue": min_value,
@@ -54,6 +88,8 @@ def attention_heads(
"positiveColor": positive_color,
"tokens": tokens,
"maskUpperTri": mask_upper_tri,
+ "showTokens": show_tokens,
+ "matchColor": match_color,
}
return render(
diff --git a/react/src/attention/AttentionHeads.tsx b/react/src/attention/AttentionHeads.tsx
index 67251047..abc43893 100644
--- a/react/src/attention/AttentionHeads.tsx
+++ b/react/src/attention/AttentionHeads.tsx
@@ -1,6 +1,8 @@
-import React from "react";
+import React, { useMemo, useState } from "react";
import { Col, Container, Row } from "react-grid-system";
import { AttentionPattern } from "./AttentionPattern";
+import { colorAttentionTensors } from "./AttentionPatterns";
+import { Tokens, TokensView } from "./components/AttentionTokens";
import { useHoverLock, UseHoverLockState } from "./components/useHoverLock";
/**
@@ -35,6 +37,7 @@ export function AttentionHeadsSelector({
onMouseLeave,
positiveColor,
maskUpperTri,
+ matchColor,
tokens
}: AttentionHeadsProps & {
attentionHeadNames: string[];
@@ -88,8 +91,12 @@ export function AttentionHeadsSelector({
showAxisLabels={false}
maxValue={maxValue}
minValue={minValue}
- negativeColor={negativeColor}
- positiveColor={positiveColor}
+ negativeColor={matchColor ? undefined : negativeColor}
+ positiveColor={
+ matchColor
+ ? attentionHeadColor(idx, attention.length)
+ : positiveColor
+ }
maskUpperTri={maskUpperTri}
/>
@@ -115,14 +122,41 @@ export function AttentionHeads({
negativeColor,
positiveColor,
maskUpperTri = true,
+ showTokens = true,
+ matchColor = false,
tokens
}: AttentionHeadsProps) {
// Attention head focussed state
const { focused, onClick, onMouseEnter, onMouseLeave } = useHoverLock(0);
+ // State for the token view type
+ const [tokensView, setTokensView] = useState(
+ TokensView.DESTINATION_TO_SOURCE
+ );
+
+ // State for which token is focussed
+ const {
+ focused: focussedToken,
+ onClick: onClickToken,
+ onMouseEnter: onMouseEnterToken,
+ onMouseLeave: onMouseLeaveToken
+ } = useHoverLock();
+
const headNames =
attentionHeadNames || attention.map((_, idx) => `Head ${idx}`);
+ // Color the attention values (by head) for interactive tokens
+ const coloredAttention = useMemo(() => {
+ if (!showTokens || !attention || attention.length === 0) return null;
+ const numHeads = attention.length;
+ const numDestTokens = attention[0]?.length || 0;
+ const numSrcTokens = attention[0]?.[0]?.length || 0;
+
+ if (numDestTokens === 0 || numSrcTokens === 0 || numHeads === 0)
+ return null;
+ return colorAttentionTensors(attention);
+ }, [attention, showTokens]);
+
return (
@@ -141,6 +175,7 @@ export function AttentionHeads({
onMouseLeave={onMouseLeave}
positiveColor={positiveColor}
maskUpperTri={maskUpperTri}
+ matchColor={matchColor}
tokens={tokens}
/>
@@ -166,8 +201,12 @@ export function AttentionHeads({
attention={attention[focused]}
maxValue={maxValue}
minValue={minValue}
- negativeColor={negativeColor}
- positiveColor={positiveColor}
+ negativeColor={matchColor ? undefined : negativeColor}
+ positiveColor={
+ matchColor
+ ? attentionHeadColor(focused, attention.length)
+ : positiveColor
+ }
zoomed={true}
maskUpperTri={maskUpperTri}
tokens={tokens}
@@ -176,6 +215,42 @@ export function AttentionHeads({
+ {showTokens && coloredAttention && (
+
+
+
+
+ Tokens
+ (click to focus)
+
+
+
+
+
+
+
+
+ )}
+
);
@@ -262,6 +337,24 @@ export interface AttentionHeadsProps {
*/
showAxisLabels?: boolean;
+ /**
+ * Show interactive tokens
+ *
+ * Whether to show interactive token visualization where hovering over tokens shows attention strength to other tokens.
+ *
+ * @default true
+ */
+ showTokens?: boolean;
+
+ /**
+ * Match colors
+ *
+ * Whether to match colors between attention patterns, token visualization, and head headers for visual consistency.
+ *
+ * @default false
+ */
+ matchColor?: boolean;
+
/**
* List of tokens
*