diff --git a/python/Demonstration.ipynb b/python/Demonstration.ipynb index 52ac7c12..84e9cd18 100644 --- a/python/Demonstration.ipynb +++ b/python/Demonstration.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# CircuitsVis Demonstration" + "# CircuitsVis Demonstration\n" ] }, { @@ -15,7 +15,7 @@ "source": [ "## Setup/Imports\n", "\n", - "__Note:__ To run Jupyter directly within this repo, you may need to run `poetry run pip install jupyter`." + "**Note:** To run Jupyter directly within this repo, you may need to run `poetry run pip install jupyter`.\n" ] }, { @@ -27,8 +27,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" + "In dev mode: True\n" ] } ], @@ -39,12 +38,14 @@ "\n", "# Imports\n", "import numpy as np\n", - "from circuitsvis.attention import attention_patterns, attention_pattern\n", + "from circuitsvis.attention import attention_patterns, attention_pattern, attention_heads\n", "from circuitsvis.activations import text_neuron_activations\n", - "from circuitsvis.examples import hello\n", "from circuitsvis.tokens import colored_tokens\n", "from circuitsvis.topk_tokens import topk_tokens\n", - "from circuitsvis.topk_samples import topk_samples" + "from circuitsvis.topk_samples import topk_samples\n", + "\n", + "from circuitsvis.utils.render import is_in_dev_mode\n", + "print(\"In dev mode:\", is_in_dev_mode())" ] }, { @@ -53,82 +54,9906 @@ "tags": [] }, "source": [ - "## Built In Visualizations" + "## Built In Visualizations\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Activations\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Text Neuron Activations (single sample)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example']\n", + "n_layers = 3\n", + "n_neurons_per_layer = 4\n", + "activations = np.random.normal(size=(len(tokens), n_layers, n_neurons_per_layer))\n", + "activations = np.exp(activations) / np.exp(activations).sum(axis=0, keepdims=True)\n", + "text_neuron_activations(tokens=tokens, activations=activations)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Text Neuron Activations (multiple samples)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokens = [['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example'], ['This', ' is', ' another', ' example', ' of', ' colored', ' tokens'], ['And', ' here', ' another', ' example', ' of', ' colored', ' tokens', ' with', ' more', ' words.'], ['This', ' is', ' another', ' example', ' of', ' tokens.']]\n", + "n_layers = 3\n", + "n_neurons_per_layer = 4\n", + "activations = []\n", + "for sample in tokens:\n", + " sample_activations = np.random.normal(size=(len(sample), n_layers, n_neurons_per_layer)) * 5\n", + " activations.append(sample_activations)\n", + "text_neuron_activations(tokens=tokens, activations=activations)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Activations" + "### Attention\n" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "#### Text Neuron Activations (single sample)" + "#### Attention Heads\n" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 2, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example']\n", - "n_layers = 3\n", - "n_neurons_per_layer = 4\n", - "activations = np.random.normal(size=(len(tokens), n_layers, n_neurons_per_layer))\n", - "activations = np.exp(activations) / np.exp(activations).sum(axis=0, keepdims=True) \n", - "text_neuron_activations(tokens=tokens, activations=activations)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Text Neuron Activations (multiple samples)" + "n_tokens = len(tokens)\n", + "single_head_attention = np.tril(np.random.normal(loc=0.3, scale=0.2, size=(n_tokens, n_tokens)))\n", + "attention_heads(tokens=tokens, attention=single_head_attention, show_tokens=True)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 3, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "tokens = [['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example'], ['This', ' is', ' another', ' example', ' of', ' colored', ' tokens'], ['And', ' here', ' another', ' example', ' of', ' colored', ' tokens', ' with', ' more', ' words.'], ['This', ' is', ' another', ' example', ' of', ' tokens.']]\n", "n_layers = 3\n", - "n_neurons_per_layer = 4\n", - "activations = []\n", - "for sample in tokens:\n", - " sample_activations = np.random.normal(size=(len(sample), n_layers, n_neurons_per_layer)) * 5\n", - " activations.append(sample_activations)\n", - "text_neuron_activations(tokens=tokens, activations=activations)\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Attention" + "n_heads = 4\n", + "\n", + "multi_head_attention = np.tril(np.random.normal(loc=0.3, scale=0.2, size=(n_layers * n_heads, n_tokens, n_tokens)))\n", + "head_names = [f\"L{layer}H{head}\" for layer in range(n_layers) for head in range(n_heads)]\n", + "attention_heads(tokens=tokens, attention=multi_head_attention, attention_head_names=head_names, show_tokens=True, match_color=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Attention Pattern (single head)" + "#### Attention Pattern (single head)\n" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 4, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -14797,67 +24602,67 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### Attention Patterns" + "#### Attention Patterns [deprecated]\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 5, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -19701,74 +29506,74 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Tokens" + "### Tokens\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Colored Tokens" + "#### Colored Tokens\n" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 6, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -24613,67 +34418,67 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Topk Tokens Table" + "### Topk Tokens Table\n" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -29527,67 +39332,67 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Topk Samples" + "### Topk Samples\n" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 8, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -34437,25 +44242,24 @@ "activations = []\n", "for neuron in range(len(tokens)):\n", " neuron_acts = []\n", - " \n", + "\n", " for k in range(len(tokens[0])):\n", " acts = (np.random.normal(size=(len(tokens[neuron][k]))) * 5).tolist()\n", " neuron_acts.append(acts)\n", " activations.append(neuron_acts)\n", - " \n", + "\n", "# Assume we have an arbitrary selection of neurons\n", "neuron_labels = [2, 7, 9]\n", "# Wrap tokens and activations in an outer list to represent the single layer\n", - "topk_samples(tokens=[tokens], activations=[activations], zeroth_dimension_name=\"Layer\", first_dimension_name=\"Neuron\", first_dimension_labels=neuron_labels)\n", - "\n" + "topk_samples(tokens=[tokens], activations=[activations], zeroth_dimension_name=\"Layer\", first_dimension_name=\"Neuron\", first_dimension_labels=neuron_labels)" ] } ], "metadata": { "kernelspec": { - "display_name": "circuitsvis-env", + "display_name": ".venv", "language": "python", - "name": "circuitsvis-env" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -34467,12 +44271,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" - }, - "vscode": { - "interpreter": { - "hash": "ada5ea967828749ea6c7f5c93ea14cd73d82db7939f837b7070fa8806da132ee" - } + "version": "3.12.3" } }, "nbformat": 4, diff --git a/python/circuitsvis/attention.py b/python/circuitsvis/attention.py index 3da78d4e..317902b9 100644 --- a/python/circuitsvis/attention.py +++ b/python/circuitsvis/attention.py @@ -3,6 +3,7 @@ import numpy as np import torch + from circuitsvis.utils.render import RenderedHTML, render @@ -15,6 +16,8 @@ def attention_heads( negative_color: Optional[str] = None, positive_color: Optional[str] = None, mask_upper_tri: Optional[bool] = None, + show_tokens: Optional[bool] = None, + match_color: Optional[bool] = None, ) -> RenderedHTML: """Attention Heads @@ -24,8 +27,8 @@ def attention_heads( is then shown in full size. Args: - attention: Attention head activations of the shape [dest_tokens x - src_tokens] + attention: Attention head activations of the shape [heads x dest_tokens x src_tokens] + or [dest_tokens x src_tokens] (will be expanded to single head) tokens: List of tokens (e.g. `["A", "person"]`). Must be the same length as the list of values. max_value: Maximum value. Used to determine how dark the token color is @@ -41,12 +44,43 @@ def attention_heads( mask_upper_tri: Whether or not to mask the upper triangular portion of the attention patterns. Should be true for causal attention, false for bidirectional attention. + show_tokens: Whether to show interactive token visualization where + hovering over tokens shows attention strength to other tokens. + match_color: Whether to match colors between attention patterns, token + visualization, and head headers for visual consistency. 
Returns: Html: Attention pattern visualization """ + + # Convert attention to numpy array if it's not already + attention_np: np.ndarray + if isinstance(attention, torch.Tensor): + attention_np = attention.detach().cpu().numpy() + elif isinstance(attention, np.ndarray): + attention_np = attention + else: + attention_np = np.array(attention) + + # Ensure attention is 3D (num_heads, dest_len, src_len) + if attention_np.ndim == 2: + attention_np = attention_np[np.newaxis, :, :] + elif attention_np.ndim != 3: + raise ValueError( + f"Attention tensor must be 2D or 3D, got {attention_np.ndim}D tensor." + ) + + num_heads, dest_len, src_len = attention_np.shape + + # Validate token count matches attention dimensions + if len(tokens) != dest_len or len(tokens) != src_len: + raise ValueError( + f"Token count ({len(tokens)}) doesn't match attention dimensions " + f"(dest: {dest_len}, src: {src_len}). For causal attention, these should all be equal." + ) + kwargs = { - "attention": attention, + "attention": attention_np, "attentionHeadNames": attention_head_names, "maxValue": max_value, "minValue": min_value, @@ -54,6 +88,8 @@ def attention_heads( "positiveColor": positive_color, "tokens": tokens, "maskUpperTri": mask_upper_tri, + "showTokens": show_tokens, + "matchColor": match_color, } return render( diff --git a/react/src/attention/AttentionHeads.tsx b/react/src/attention/AttentionHeads.tsx index 67251047..abc43893 100644 --- a/react/src/attention/AttentionHeads.tsx +++ b/react/src/attention/AttentionHeads.tsx @@ -1,6 +1,8 @@ -import React from "react"; +import React, { useMemo, useState } from "react"; import { Col, Container, Row } from "react-grid-system"; import { AttentionPattern } from "./AttentionPattern"; +import { colorAttentionTensors } from "./AttentionPatterns"; +import { Tokens, TokensView } from "./components/AttentionTokens"; import { useHoverLock, UseHoverLockState } from "./components/useHoverLock"; /** @@ -35,6 +37,7 @@ export function AttentionHeadsSelector({ onMouseLeave, positiveColor, maskUpperTri, + matchColor, tokens }: AttentionHeadsProps & { attentionHeadNames: string[]; @@ -88,8 +91,12 @@ export function AttentionHeadsSelector({ showAxisLabels={false} maxValue={maxValue} minValue={minValue} - negativeColor={negativeColor} - positiveColor={positiveColor} + negativeColor={matchColor ? undefined : negativeColor} + positiveColor={ + matchColor + ? attentionHeadColor(idx, attention.length) + : positiveColor + } maskUpperTri={maskUpperTri} />
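The `attention.py` changes above mean callers no longer need to reshape single-head patterns by hand, and shape mistakes now fail fast. A minimal sketch of the new behaviour, assuming `circuitsvis` is installed (the token strings are illustrative):

```python
import numpy as np

from circuitsvis.attention import attention_heads

tokens = ["The", " cat", " sat"]

# A 2D [dest_tokens x src_tokens] pattern is now accepted and expanded to a
# single-head [1 x dest_tokens x src_tokens] array before rendering.
pattern_2d = np.tril(np.random.rand(len(tokens), len(tokens)))
attention_heads(tokens=tokens, attention=pattern_2d)

# Mismatched token/attention dimensions now raise a ValueError up front
# instead of rendering a broken visualization.
try:
    attention_heads(tokens=tokens[:2], attention=pattern_2d)
except ValueError as err:
    print(err)
```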
@@ -115,14 +122,41 @@ export function AttentionHeads({ negativeColor, positiveColor, maskUpperTri = true, + showTokens = true, + matchColor = false, tokens }: AttentionHeadsProps) { // Attention head focussed state const { focused, onClick, onMouseEnter, onMouseLeave } = useHoverLock(0); + // State for the token view type + const [tokensView, setTokensView] = useState( + TokensView.DESTINATION_TO_SOURCE + ); + + // State for which token is focussed + const { + focused: focussedToken, + onClick: onClickToken, + onMouseEnter: onMouseEnterToken, + onMouseLeave: onMouseLeaveToken + } = useHoverLock(); + const headNames = attentionHeadNames || attention.map((_, idx) => `Head ${idx}`); + // Color the attention values (by head) for interactive tokens + const coloredAttention = useMemo(() => { + if (!showTokens || !attention || attention.length === 0) return null; + const numHeads = attention.length; + const numDestTokens = attention[0]?.length || 0; + const numSrcTokens = attention[0]?.[0]?.length || 0; + + if (numDestTokens === 0 || numSrcTokens === 0 || numHeads === 0) + return null; + return colorAttentionTensors(attention); + }, [attention, showTokens]); + return (
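The `coloredAttention` memo above, together with the `attentionHeadColor` calls in the surrounding hunks, gives each head its own color and weights it by attention strength for the token strip. A rough Python sketch of that idea; the evenly spaced HSL hues and the weight-to-alpha mapping are illustrative assumptions, not the exact formulas used by `colorAttentionTensors`:

```python
import colorsys

import numpy as np


def color_attention_by_head(attention: np.ndarray) -> np.ndarray:
    """Toy per-head attention coloring.

    attention: [heads, dest_tokens, src_tokens] weights, roughly in [0, 1].
    Returns [heads, dest_tokens, src_tokens, 4] RGBA floats where each head
    gets an evenly spaced hue and the attention weight sets the alpha.
    """
    n_heads = attention.shape[0]
    rgba = np.zeros((*attention.shape, 4))
    for head in range(n_heads):
        # Evenly spaced hues around the color wheel (assumed scheme)
        r, g, b = colorsys.hls_to_rgb(head / n_heads, 0.5, 0.7)
        rgba[head, ..., :3] = (r, g, b)
        # Attention weight drives opacity
        rgba[head, ..., 3] = np.clip(attention[head], 0.0, 1.0)
    return rgba
```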
@@ -141,6 +175,7 @@ export function AttentionHeads({ onMouseLeave={onMouseLeave} positiveColor={positiveColor} maskUpperTri={maskUpperTri} + matchColor={matchColor} tokens={tokens} /> @@ -166,8 +201,12 @@ export function AttentionHeads({ attention={attention[focused]} maxValue={maxValue} minValue={minValue} - negativeColor={negativeColor} - positiveColor={positiveColor} + negativeColor={matchColor ? undefined : negativeColor} + positiveColor={ + matchColor + ? attentionHeadColor(focused, attention.length) + : positiveColor + } zoomed={true} maskUpperTri={maskUpperTri} tokens={tokens} @@ -176,6 +215,42 @@ export function AttentionHeads({ + {showTokens && coloredAttention && ( + + +
+        [JSX stripped in extraction: this block renders a "Tokens" header
+        with a "(click to focus)" hint, a <select> that toggles `tokensView`
+        between the two TokensView directions, and the <Tokens> component
+        (from ./components/AttentionTokens) wired to `coloredAttention`,
+        `focused`, `focussedToken` and the onClickToken / onMouseEnterToken /
+        onMouseLeaveToken handlers declared above]
+      )}
+
    </Container>
  );
@@ -262,6 +337,24 @@ export interface AttentionHeadsProps {
   */
  showAxisLabels?: boolean;

+  /**
+   * Show interactive tokens
+   *
+   * Whether to show the interactive token visualization, where hovering over a token shows the attention strength to the other tokens.
+   *
+   * @default true
+   */
+  showTokens?: boolean;
+
+  /**
+   * Match colors
+   *
+   * Whether to match colors between attention patterns, the token visualization, and head headers for visual consistency.
+   *
+   * @default false
+   */
+  matchColor?: boolean;
+
  /**
   * List of tokens
   *
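Finally, a few hypothetical pytest cases, not part of this diff, that would pin down the validation behaviour added in `attention.py` (test names and assertions are assumptions):

```python
import numpy as np
import pytest

from circuitsvis.attention import attention_heads
from circuitsvis.utils.render import RenderedHTML


def test_2d_pattern_is_expanded_to_one_head():
    # A [dest x src] pattern is treated as a single head and still renders
    html = attention_heads(tokens=["a", "b"], attention=np.ones((2, 2)))
    assert isinstance(html, RenderedHTML)


def test_token_count_mismatch_raises():
    with pytest.raises(ValueError):
        attention_heads(tokens=["a"], attention=np.ones((2, 2)))


def test_higher_rank_input_rejected():
    # Anything other than 2D or 3D input is rejected up front
    with pytest.raises(ValueError):
        attention_heads(tokens=["a", "b"], attention=np.ones((1, 1, 2, 2)))
```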