manu02 committed
Commit 8b72e45 · verified · 1 Parent(s): 3701608

Initial commit

Files changed (5)
  1. .gitignore +207 -0
  2. LICENSE +21 -0
  3. README.md +122 -14
  4. app.py +502 -0
  5. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1,207 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[codz]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py.cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+ #poetry.toml
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+ #pdm.lock
+ #pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # pixi
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+ #pixi.lock
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+ # in the .venv directory. It is recommended not to include this directory in version control.
+ .pixi
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .envrc
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Abstra
+ # Abstra is an AI-powered process automation framework.
+ # Ignore directories containing user credentials, local state, and settings.
+ # Learn more at https://abstra.io/docs
+ .abstra/
+
+ # Visual Studio Code
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
+ # you could uncomment the following to ignore the entire vscode folder
+ # .vscode/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Cursor
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+ # refer to https://docs.cursor.com/context/ignore-files
+ .cursorignore
+ .cursorindexingignore
+
+ # Marimo
+ marimo/_static/
+ marimo/_lsp/
+ __marimo__/
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Emmanuel David Muñiz
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,122 @@
- ---
- title: Token Attention Viewer
- emoji: 📈
- colorFrom: gray
- colorTo: pink
- sdk: gradio
- sdk_version: 5.49.1
- app_file: app.py
- pinned: false
- license: mit
- short_description: Interactive visualization of attention weights in LLMs word-
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Token-Attention-Viewer
+ Token Attention Viewer is an interactive Gradio app that visualizes the self-attention weights inside transformer language models for every generated token. It helps researchers, students, and developers explore how models like GPT-2 or LLaMA focus on different parts of the input as they generate text.
+
+ ## Word-Level Attention Visualizer (Gradio)
+
+ An interactive Gradio app to **generate text with a causal language model** and **visualize attention word-by-word**.
+ The generated continuation is laid out like a paragraph; the **background opacity** behind each word reflects the **sum of attention weights** that the selected (query) word assigns to the context. You can also switch between many popular Hugging Face models.
+
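+ A minimal sketch of the idea (not the full app; it assumes `gpt2` and a short prompt, but any causal LM that returns attentions works the same way):
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ tok = AutoTokenizer.from_pretrained("gpt2")
+ lm = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")
+
+ inputs = tok("In a distant future, humanity", return_tensors="pt")
+ out = lm.generate(**inputs, max_new_tokens=5, do_sample=False,
+                   output_attentions=True, return_dict_in_generate=True)
+
+ # out.attentions[t] is a tuple over layers; each tensor is (batch, heads, q, k)
+ # for generated step t (q == 1 for every step after the first).
+ t = 1                                        # inspect the second generated token
+ step = torch.stack(out.attentions[t])        # (layers, batch, heads, q, k)
+ mean = step.mean(dim=(0, 2))[0]              # mean over layers and heads -> (q, k)
+ weights = mean[-1]                           # attention to every previous position
+ print(weights.shape, float(weights.sum()))   # length k; sums to ~1.0
+ ```
+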
+ ---
+
+ ## What the app does
+
+ * **Generate** a continuation from your prompt using a selected causal LM (GPT-2, OPT, Mistral, etc.).
+ * **Select a generated word** to inspect.
+ * **Visualize attention** as a semi-transparent background behind words (no plotting libraries such as matplotlib).
+ * **Mean across layers/heads** or inspect a specific layer/head.
+ * **Proper detokenization** to real words (regex-based; see the sketch below) and **EOS tokens are stripped** (no `<|endoftext|>` clutter).
+ * **Paragraph wrapping**: words wrap to new lines automatically inside the box.
+
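+ A small sketch of the regex word split used for display. The pattern (`WORD_RE` in `app.py`) matches words, optionally with an apostrophe, plus standalone punctuation:
+
+ ```python
+ import re
+
+ # Same pattern app.py uses to turn detokenized text into display "words".
+ WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]")
+
+ print(WORD_RE.findall("Humanity doesn't give up, ever."))
+ # ['Humanity', "doesn't", 'give', 'up', ',', 'ever', '.']
+ ```
+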
+ ---
+
+ ## 🚀 Quickstart
+
+ ### 1) Clone
+
+ ```bash
+ git clone https://github.com/devMuniz02/Token-Attention-Viewer
+ cd Token-Attention-Viewer
+ ```
+
+ ### 2) (Optional) Create a virtual environment
+
+ **Windows (PowerShell):**
+
+ ```powershell
+ python -m venv venv
+ .\venv\Scripts\Activate.ps1
+ ```
+
+ **macOS / Linux (bash/zsh):**
+
+ ```bash
+ python3 -m venv venv
+ source venv/bin/activate
+ ```
+
+ ### 3) Install requirements
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### 4) Run the app
+
+ ```bash
+ python app.py
+ ```
+
+ You should see Gradio report a local URL similar to:
+
+ ```
+ Running on local URL: http://127.0.0.1:7860
+ ```
+
+ ### 5) Open in your browser
+
+ Open the printed URL (default `http://127.0.0.1:7860`) in your browser.
+
+ ---
+
+ ## 🧭 How to use
+
+ 1. **Model**: pick a model from the dropdown and click **Load / Switch Model**.
+    * Small models (e.g., `distilgpt2`, `gpt2`) run on CPU.
+    * Larger models (e.g., `mistralai/Mistral-7B-v0.1`) generally need a GPU with enough VRAM.
+ 2. **Prompt**: enter your starting text.
+ 3. **Generate**: click **Generate** to produce a continuation.
+ 4. **Inspect**: select any **generated word** (radio buttons).
+    * The paragraph box highlights where that word attends.
+    * Toggle **Mean Across Layers/Heads** or choose a specific **layer/head** (see the sketch after this list).
+ 5. Repeat with different models or prompts.
+
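+ Picking a specific layer/head instead of the mean is just indexing into the attention tuple. A sketch mirroring `get_attention_for_token_layer` in `app.py`, assuming `out` comes from a `generate(..., output_attentions=True, return_dict_in_generate=True)` call like the one in the sketch further up:
+
+ ```python
+ t, layer, head = 1, 0, 0       # generated step, layer index, head index
+ step = out.attentions[t]       # tuple over layers, each (batch, heads, q, k)
+ attn = step[layer][0, head]    # -> (q, k); q == 1 for steps after the first
+ weights = attn[-1]             # attention over all previous positions
+ ```
+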
+ ---
+
+ ## 🧩 Files
+
+ * `app.py` — Gradio application (UI + model loading + attention visualization).
+ * `requirements.txt` — Python dependencies (`transformers`, `gradio`, `torch`).
+ * `README.md` — this file.
+
+ ---
+
+ ## 🛠️ Troubleshooting
+
+ * **Radio/choices error**: If you switch models and see a Gradio “value not in choices” error, ensure the app resets the radio with `value=None` (the included code already does this).
+ * **`<|endoftext|>` shows up**: The app strips **trailing** special tokens from the generated segment, so EOS shouldn’t appear. If you still see it in the middle, your model truly generated it as a token.
+ * **OOM / model too large**:
+   * Try a smaller model (`distilgpt2`, `gpt2`, `facebook/opt-125m`).
+   * Reduce `Max New Tokens`.
+   * Use CPU for smaller models or a GPU with more VRAM for bigger ones.
+ * **Slow generation**: Larger models, or any model running on CPU, will be slower; consider using a GPU and the `accelerate` package.
+ * **Missing tokenizer pad token**: The app sets `pad_token_id = eos_token_id` automatically when needed (see the sketch below).
+
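+ The pad-token fallback is only a few lines; this sketch is lifted from `load_model` in `app.py`:
+
+ ```python
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")  # gpt2 ships without a pad token
+ if tokenizer.pad_token_id is None:
+     if tokenizer.eos_token_id is not None:
+         tokenizer.pad_token_id = tokenizer.eos_token_id  # reuse EOS as PAD
+     else:
+         tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
+ ```
+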
+ ---
+
+ ## 🔒 Access-gated models
+
+ Some families (e.g., **LLaMA**, **Gemma**) require you to accept licenses or request access on Hugging Face. Make sure your Hugging Face account has access before trying to load those models.
+
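+ Once your account has access, authenticate the machine running the app so `from_pretrained` can download the weights. One way, using the `huggingface_hub` library that `transformers` already depends on:
+
+ ```python
+ from huggingface_hub import login
+
+ login()  # prompts for a token from https://huggingface.co/settings/tokens
+ ```
+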
+ ---
+
+ ## 📣 Acknowledgments
+
+ * Built with [Gradio](https://www.gradio.app/) and [Hugging Face Transformers](https://huggingface.co/docs/transformers).
+ * Attention visualization inspired by the standard causal LM attention tensors returned by `generate(..., output_attentions=True)`.
app.py ADDED
@@ -0,0 +1,502 @@
+ # app.py
+ """
+ Gradio word-level attention visualizer with:
+ - Paragraph-style wrapping and semi-transparent backgrounds per word
+ - Proper detokenization to words (regex)
+ - Ability to pick from many causal LMs
+ - Trailing EOS/PAD special tokens removed (no <|endoftext|> shown)
+ - FIX: safely reset Radio with value=None to avoid Gradio choices error
+ """
+
+ import re
+ from typing import List, Tuple
+
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+ import numpy as np
+
+ # =========================
+ # Config
+ # =========================
+ ALLOWED_MODELS = [
+     # ---- GPT-2 family
+     "gpt2", "distilgpt2", "gpt2-medium", "gpt2-large", "gpt2-xl",
+     # ---- EleutherAI (Neo/J/NeoX/Pythia)
+     "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B",
+     "EleutherAI/gpt-j-6B", "EleutherAI/gpt-neox-20b",
+     "EleutherAI/pythia-70m", "EleutherAI/pythia-160m", "EleutherAI/pythia-410m",
+     "EleutherAI/pythia-1b", "EleutherAI/pythia-1.4b", "EleutherAI/pythia-2.8b",
+     "EleutherAI/pythia-6.9b", "EleutherAI/pythia-12b",
+     # ---- Meta OPT
+     "facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b",
+     "facebook/opt-6.7b", "facebook/opt-13b", "facebook/opt-30b",
+     # ---- Mistral
+     "mistralai/Mistral-7B-v0.1", "mistralai/Mistral-7B-v0.3", "mistralai/Mistral-7B-Instruct-v0.2",
+     # ---- TinyLlama / OpenLLaMA
+     "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
+     "openlm-research/open_llama_3b", "openlm-research/open_llama_7b",
+     # ---- Microsoft Phi
+     "microsoft/phi-1", "microsoft/phi-1_5", "microsoft/phi-2",
+     # ---- Qwen
+     "Qwen/Qwen1.5-0.5B", "Qwen/Qwen1.5-1.8B", "Qwen/Qwen1.5-4B", "Qwen/Qwen1.5-7B",
+     "Qwen/Qwen2-1.5B", "Qwen/Qwen2-7B",
+     # ---- MPT
+     "mosaicml/mpt-7b", "mosaicml/mpt-7b-instruct",
+     # ---- Falcon
+     "tiiuae/falcon-7b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b",
+     # ---- Cerebras GPT
+     "cerebras/Cerebras-GPT-111M", "cerebras/Cerebras-GPT-256M",
+     "cerebras/Cerebras-GPT-590M", "cerebras/Cerebras-GPT-1.3B", "cerebras/Cerebras-GPT-2.7B",
+ ]
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = None
+ tokenizer = None
+
+ # Word regex (words + punctuation)
+ WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]")
+
+ # =========================
+ # Model loading
+ # =========================
+ def _safe_set_attn_impl(m):
+     try:
+         m.config._attn_implementation = "eager"
+     except Exception:
+         pass
+
+ def load_model(model_name: str):
+     """Load tokenizer+model globally."""
+     global model, tokenizer
+     try:
+         del model
+         torch.cuda.empty_cache()
+     except Exception:
+         pass
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+     # Ensure pad token id
+     if tokenizer.pad_token_id is None:
+         if tokenizer.eos_token_id is not None:
+             tokenizer.pad_token_id = tokenizer.eos_token_id
+         else:
+             tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
+
+     model = AutoModelForCausalLM.from_pretrained(model_name)
+     _safe_set_attn_impl(model)
+     if hasattr(model, "resize_token_embeddings") and tokenizer.pad_token_id >= model.get_input_embeddings().num_embeddings:
+         model.resize_token_embeddings(len(tokenizer))
+     model.eval()
+     model.to(device)
+
+ def model_heads_layers():
+     try:
+         L = int(getattr(model.config, "num_hidden_layers", 12))
+     except Exception:
+         L = 12
+     try:
+         H = int(getattr(model.config, "num_attention_heads", 12))
+     except Exception:
+         H = 12
+     return max(1, L), max(1, H)
+
+ # =========================
+ # Attention utils
+ # =========================
+ def get_attention_for_token_layer(
+     attentions,
+     token_index,
+     layer_index,
+     batch_index=0,
+     head_index=0,
+     mean_across_layers=True,
+     mean_across_heads=True,
+ ):
+     """
+     attentions: tuple length = #generated tokens
+     attentions[t] -> tuple of len = num_layers, each: (batch, heads, q, k)
+     """
+     token_attention = attentions[token_index]
+
+     if mean_across_layers:
+         layer_attention = torch.stack(token_attention).mean(dim=0)  # (batch, heads, q, k)
+     else:
+         layer_attention = token_attention[int(layer_index)]  # (batch, heads, q, k)
+
+     batch_attention = layer_attention[int(batch_index)]  # (heads, q, k)
+
+     if mean_across_heads:
+         head_attention = batch_attention.mean(dim=0)  # (q, k)
+     else:
+         head_attention = batch_attention[int(head_index)]  # (q, k)
+
+     return head_attention.squeeze(0)  # q==1 -> (k,)
+
+ # =========================
+ # Tokens -> words mapping
+ # =========================
+ def _words_and_map_from_tokens(gen_token_ids: List[int]) -> Tuple[List[str], List[int]]:
+     """
+     From *generated* token ids, return:
+     - words: detokenized words (regex-split)
+     - word2tok: list where word2tok[i] = index (relative to generated) of the
+       LAST token that composes that word.
+     """
+     if not gen_token_ids:
+         return [], []
+
+     gen_tokens_str = tokenizer.convert_ids_to_tokens(gen_token_ids)
+     detok_text = tokenizer.convert_tokens_to_string(gen_tokens_str)
+
+     words = WORD_RE.findall(detok_text)
+
+     enc = tokenizer(detok_text, return_offsets_mapping=True, add_special_tokens=False)
+     tok_offsets = enc["offset_mapping"]
+     n = min(len(tok_offsets), len(gen_token_ids))
+
+     spans = [m.span() for m in re.finditer(WORD_RE, detok_text)]
+
+     word2tok: List[int] = []
+     t = 0
+     for (ws, we) in spans:
+         last_t = None
+         while t < n:
+             ts, te = tok_offsets[t]
+             if not (te <= ws or ts >= we):
+                 last_t = t
+                 t += 1
+             else:
+                 if te <= ws:
+                     t += 1
+                 else:
+                     break
+         if last_t is None:
+             last_t = max(0, min(n - 1, t - 1))
+         word2tok.append(int(last_t))
+
+     return words, word2tok
+
+ # =========================
+ # Helpers
+ # =========================
+ def _strip_trailing_special(ids: List[int]) -> List[int]:
+     """Remove trailing EOS/PAD/other special tokens from the generated ids."""
+     specials = set(getattr(tokenizer, "all_special_ids", []) or [])
+     j = len(ids)
+     while j > 0 and ids[j - 1] in specials:
+         j -= 1
+     return ids[:j]
+
+ def clamp01(x: float) -> float:
+     x = float(x)
+     return 0.0 if x < 0 else 1.0 if x > 1 else x
+
+ # =========================
+ # Visualization (WORD-LEVEL)
+ # =========================
+ def generate_word_visualization(words: List[str],
+                                 abs_word_ends: List[int],
+                                 attention_values: np.ndarray,
+                                 selected_token_abs_idx: int) -> str:
+     """
+     Paragraph-style visualization over words.
+     For each word, aggregate attention over its composing tokens (sum),
+     normalize across words, and render opacity as a semi-transparent background.
+     """
+     if not words or attention_values is None or len(attention_values) == 0:
+         return (
+             "<div style='width:100%;'>"
+             "  <div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'>"
+             "    <div style='color:#ddd;'>No attention values.</div>"
+             "  </div>"
+             "</div>"
+         )
+
+     # Start..end spans from ends
+     starts = []
+     for i, end in enumerate(abs_word_ends):
+         if i == 0:
+             starts.append(0)
+         else:
+             starts.append(min(abs_word_ends[i - 1] + 1, end))
+
+     # Sum attention per word
+     word_scores = []
+     for i, end in enumerate(abs_word_ends):
+         start = starts[i]
+         if start > end:
+             start = end
+         s = max(0, min(start, len(attention_values) - 1))
+         e = max(0, min(end, len(attention_values) - 1))
+         if e < s:
+             s, e = e, s
+         word_scores.append(float(attention_values[s:e + 1].sum()))
+
+     max_attn = max(0.1, float(max(word_scores)) if word_scores else 0.0)
+
+     # Which word holds the selected token?
+     selected_word_idx = None
+     for i, end in enumerate(abs_word_ends):
+         if selected_token_abs_idx <= end:
+             selected_word_idx = i
+             break
+     if selected_word_idx is None and abs_word_ends:
+         selected_word_idx = len(abs_word_ends) - 1
+
+     spans = []
+     for i, w in enumerate(words):
+         alpha = min(1.0, word_scores[i] / max_attn) if max_attn > 0 else 0.0
+         bg = f"rgba(66,133,244,{alpha:.3f})"
+         border = "2px solid #fff" if i == selected_word_idx else "1px solid transparent"
+         spans.append(
+             f"<span style='display:inline-block;background:{bg};border:{border};"
+             f"border-radius:6px;padding:2px 6px;margin:2px 4px 4px 0;color:#fff;'>"
+             f"{w}</span>"
+         )
+
+     return (
+         "<div style='width:100%;'>"
+         "  <div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'>"
+         "    <div style='white-space:normal;line-height:1.8;'>"
+         f"      {''.join(spans)}"
+         "    </div>"
+         "  </div>"
+         "</div>"
+     )
+
+ # =========================
+ # Core functions
+ # =========================
+ def run_generation(prompt, max_new_tokens, temperature, top_p):
+     """Generate and prepare word-level selector + initial visualization."""
+     inputs = tokenizer(prompt or "", return_tensors="pt").to(device)
+     prompt_len = inputs["input_ids"].shape[1]
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=int(max_new_tokens),
+             temperature=float(temperature),
+             top_p=float(top_p),
+             do_sample=True,
+             pad_token_id=tokenizer.pad_token_id,
+             output_attentions=True,
+             return_dict_in_generate=True,
+         )
+
+     all_token_ids = outputs.sequences[0].tolist()
+     generated_token_ids = _strip_trailing_special(all_token_ids[prompt_len:])
+
+     # Words and map (word -> last generated token index)
+     words, word2tok = _words_and_map_from_tokens(generated_token_ids)
+
+     display_choices = [(w, i) for i, w in enumerate(words)]
+     if not display_choices:
+         return {
+             state_attentions: None,
+             state_all_token_ids: None,
+             state_prompt_len: 0,
+             state_words: None,
+             state_word2tok: None,
+             # SAFE RADIO RESET
+             radio_word_selector: gr.update(choices=[], value=None),
+             html_visualization: "<div style='text-align:center;padding:20px;'>No new tokens generated.</div>",
+         }
+
+     first_word_idx = 0
+     html_init = update_visualization(
+         first_word_idx,
+         outputs.attentions,
+         all_token_ids,
+         prompt_len,
+         0, 0, True, True,
+         words,
+         word2tok,
+     )
+
+     return {
+         state_attentions: outputs.attentions,
+         state_all_token_ids: all_token_ids,
+         state_prompt_len: prompt_len,
+         state_words: words,
+         state_word2tok: word2tok,
+         radio_word_selector: gr.update(choices=display_choices, value=first_word_idx),
+         html_visualization: html_init,
+     }
+
+ def update_visualization(
+     selected_word_index,
+     attentions,
+     all_token_ids,
+     prompt_len,
+     layer,
+     head,
+     mean_layers,
+     mean_heads,
+     words,
+     word2tok,
+ ):
+     """Recompute visualization for the chosen word (maps to its last token)."""
+     if selected_word_index is None or attentions is None or word2tok is None:
+         return "<div style='text-align:center;padding:20px;'>Generate text first.</div>"
+
+     widx = int(selected_word_index)
+     if not (0 <= widx < len(word2tok)):
+         return "<div style='text-align:center;padding:20px;'>Invalid selection.</div>"
+
+     token_index_relative = int(word2tok[widx])
+     token_index_absolute = int(prompt_len) + token_index_relative
+
+     token_attn = get_attention_for_token_layer(
+         attentions,
+         token_index=token_index_relative,
+         layer_index=int(layer),
+         head_index=int(head),
+         mean_across_layers=bool(mean_layers),
+         mean_across_heads=bool(mean_heads),
+     )
+
+     attn_vals = token_attn.detach().cpu().numpy()
+
+     # Pad attention to full (prompt + generated) length
+     total_tokens = len(all_token_ids)
+     padded = np.zeros(total_tokens, dtype=float)
+     if attn_vals.ndim == 2:
+         attn_vals = attn_vals[-1]
+     padded[: len(attn_vals)] = attn_vals
+
+     # Absolute word ends (prompt offset + relative token index)
+     abs_word_ends = [int(prompt_len) + int(t) for t in (word2tok or [])]
+
+     return generate_word_visualization(words, abs_word_ends, padded, token_index_absolute)
+
+ def toggle_slider(is_mean):
+     return gr.update(interactive=not bool(is_mean))
+
+ # =========================
+ # Gradio UI
+ # =========================
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# 🤖 Word-Level Attention Visualizer — choose a model & explore")
+     gr.Markdown(
+         "Pick a model, generate text, then select a **generated word** to see where it attends. "
+         "Words wrap in a paragraph; opacity is the summed attention over the word’s tokens. "
+         "EOS tokens are stripped so `<|endoftext|>` doesn’t appear."
+     )
+
+     # States
+     state_attentions = gr.State(None)
+     state_all_token_ids = gr.State(None)
+     state_prompt_len = gr.State(None)
+     state_words = gr.State(None)
+     state_word2tok = gr.State(None)
+     state_model_name = gr.State(None)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("### 0) Model")
+             dd_model = gr.Dropdown(
+                 ALLOWED_MODELS, value=ALLOWED_MODELS[0], label="Causal LM",
+                 info="Models that work with AutoModelForCausalLM + attentions"
+             )
+             btn_load = gr.Button("Load / Switch Model", variant="secondary")
+
+             gr.Markdown("### 1) Generation")
+             txt_prompt = gr.Textbox("In a distant future, humanity", label="Prompt")
+             btn_generate = gr.Button("Generate", variant="primary")
+             slider_max_tokens = gr.Slider(10, 200, value=50, step=10, label="Max New Tokens")
+             slider_temp = gr.Slider(0.0, 1.5, value=0.7, step=0.1, label="Temperature")
+             slider_top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top P")
+
+             gr.Markdown("### 2) Attention")
+             check_mean_layers = gr.Checkbox(True, label="Mean Across Layers")
+             check_mean_heads = gr.Checkbox(True, label="Mean Across Heads")
+             slider_layer = gr.Slider(0, 11, value=0, step=1, label="Layer", interactive=False)
+             slider_head = gr.Slider(0, 11, value=0, step=1, label="Head", interactive=False)
+
+         with gr.Column(scale=3):
+             radio_word_selector = gr.Radio(
+                 [], label="Select Generated Word to Visualize",
+                 info="Click Generate to populate"
+             )
+             html_visualization = gr.HTML(
+                 "<div style='text-align:center;padding:20px;color:#888;border:1px dashed #888;border-radius:8px;'>"
+                 "Attention visualization will appear here.</div>"
+             )
+
+     # Load/switch model
+     def on_load_model(selected_name, mean_layers, mean_heads):
+         load_model(selected_name)
+         L, H = model_heads_layers()
+         return (
+             selected_name,  # state_model_name
+             gr.update(minimum=0, maximum=L - 1, value=0, interactive=not bool(mean_layers)),
+             gr.update(minimum=0, maximum=H - 1, value=0, interactive=not bool(mean_heads)),
+             # SAFE RADIO RESET (avoid Value: [] not in choices)
+             gr.update(choices=[], value=None),
+             "<div style='text-align:center;padding:20px;'>Model loaded. Generate to visualize.</div>",
+         )
+
+     btn_load.click(
+         fn=on_load_model,
+         inputs=[dd_model, check_mean_layers, check_mean_heads],
+         outputs=[state_model_name, slider_layer, slider_head, radio_word_selector, html_visualization],
+     )
+
+     # Load default model at app start
+     def _init_model(_):
+         load_model(ALLOWED_MODELS[0])
+         L, H = model_heads_layers()
+         return (
+             ALLOWED_MODELS[0],
+             gr.update(minimum=0, maximum=L - 1, value=0, interactive=not check_mean_layers.value),
+             gr.update(minimum=0, maximum=H - 1, value=0, interactive=not check_mean_heads.value),
+             # Also ensure radio is clean at start
+             gr.update(choices=[], value=None),
+         )
+     demo.load(_init_model, inputs=[gr.State(None)], outputs=[state_model_name, slider_layer, slider_head, radio_word_selector])
+
+     # Generate
+     btn_generate.click(
+         fn=run_generation,
+         inputs=[txt_prompt, slider_max_tokens, slider_temp, slider_top_p],
+         outputs=[
+             state_attentions,
+             state_all_token_ids,
+             state_prompt_len,
+             state_words,
+             state_word2tok,
+             radio_word_selector,
+             html_visualization,
+         ],
+     )
+
+     # Update viz on any control
+     for control in [radio_word_selector, slider_layer, slider_head, check_mean_layers, check_mean_heads]:
+         control.change(
+             fn=update_visualization,
+             inputs=[
+                 radio_word_selector,
+                 state_attentions,
+                 state_all_token_ids,
+                 state_prompt_len,
+                 slider_layer,
+                 slider_head,
+                 check_mean_layers,
+                 check_mean_heads,
+                 state_words,
+                 state_word2tok,
+             ],
+             outputs=html_visualization,
+         )
+
+     # Toggle slider interactivity
+     check_mean_layers.change(toggle_slider, check_mean_layers, slider_layer)
+     check_mean_heads.change(toggle_slider, check_mean_heads, slider_head)
+
+ if __name__ == "__main__":
+     print(f"Device: {device}")
+     # Ensure a default model is ready
+     load_model(ALLOWED_MODELS[0])
+     demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ transformers
+ gradio
+ torch