# app.py """ Gradio word-level attention visualizer with: - Paragraph-style wrapping and semi-transparent backgrounds per word - Proper detokenization to words (regex) - Ability to pick from many causal LMs - Trailing EOS/PAD special tokens removed (no <|endoftext|> shown) - FIX: safely reset Radio with value=None to avoid Gradio choices error """ import re from typing import List, Tuple import gradio as gr from transformers import AutoModelForCausalLM, AutoTokenizer import torch import numpy as np # ========================= # Config # ========================= ALLOWED_MODELS = [ # ---- GPT-2 family "gpt2", "distilgpt2", "gpt2-medium", "gpt2-large", "gpt2-xl", # ---- EleutherAI (Neo/J/NeoX/Pythia) "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B", "EleutherAI/gpt-j-6B", "EleutherAI/gpt-neox-20b", "EleutherAI/pythia-70m", "EleutherAI/pythia-160m", "EleutherAI/pythia-410m", "EleutherAI/pythia-1b", "EleutherAI/pythia-1.4b", "EleutherAI/pythia-2.8b", "EleutherAI/pythia-6.9b", "EleutherAI/pythia-12b", # ---- Meta OPT "facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b", "facebook/opt-13b", "facebook/opt-30b", # ---- Mistral "mistralai/Mistral-7B-v0.1", "mistralai/Mistral-7B-v0.3", "mistralai/Mistral-7B-Instruct-v0.2", # ---- TinyLlama / OpenLLaMA "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", "openlm-research/open_llama_3b", "openlm-research/open_llama_7b", # ---- Microsoft Phi "microsoft/phi-1", "microsoft/phi-1_5", "microsoft/phi-2", # ---- Qwen "Qwen/Qwen1.5-0.5B", "Qwen/Qwen1.5-1.8B", "Qwen/Qwen1.5-4B", "Qwen/Qwen1.5-7B", "Qwen/Qwen2-1.5B", "Qwen/Qwen2-7B", # ---- MPT "mosaicml/mpt-7b", "mosaicml/mpt-7b-instruct", # ---- Falcon "tiiuae/falcon-7b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b", # ---- Cerebras GPT "cerebras/Cerebras-GPT-111M", "cerebras/Cerebras-GPT-256M", "cerebras/Cerebras-GPT-590M", "cerebras/Cerebras-GPT-1.3B", "cerebras/Cerebras-GPT-2.7B", ] device = "cuda" if torch.cuda.is_available() else "cpu" model = None tokenizer = None # Word regex (words + punctuation) WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]") # ========================= # Model loading # ========================= def _safe_set_attn_impl(m): try: m.config._attn_implementation = "eager" except Exception: pass def load_model(model_name: str): """Load tokenizer+model globally.""" global model, tokenizer try: del model torch.cuda.empty_cache() except Exception: pass tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) # Ensure pad token id if tokenizer.pad_token_id is None: if tokenizer.eos_token_id is not None: tokenizer.pad_token_id = tokenizer.eos_token_id else: tokenizer.add_special_tokens({"pad_token": "<|pad|>"}) model = AutoModelForCausalLM.from_pretrained(model_name) _safe_set_attn_impl(model) if hasattr(model, "resize_token_embeddings") and tokenizer.pad_token_id >= model.get_input_embeddings().num_embeddings: model.resize_token_embeddings(len(tokenizer)) model.eval() model.to(device) def model_heads_layers(): try: L = int(getattr(model.config, "num_hidden_layers", 12)) except Exception: L = 12 try: H = int(getattr(model.config, "num_attention_heads", 12)) except Exception: H = 12 return max(1, L), max(1, H) # ========================= # Attention utils # ========================= def get_attention_for_token_layer( attentions, token_index, layer_index, batch_index=0, head_index=0, mean_across_layers=True, mean_across_heads=True, ): """ attentions: tuple length = #generated tokens attentions[t] -> tuple of len = num_layers, each: (batch, heads, q, k) """ token_attention = attentions[token_index] if mean_across_layers: layer_attention = torch.stack(token_attention).mean(dim=0) # (batch, heads, q, k) else: layer_attention = token_attention[int(layer_index)] # (batch, heads, q, k) batch_attention = layer_attention[int(batch_index)] # (heads, q, k) if mean_across_heads: head_attention = batch_attention.mean(dim=0) # (q, k) else: head_attention = batch_attention[int(head_index)] # (q, k) return head_attention.squeeze(0) # q==1 -> (k,) # ========================= # Tokens -> words mapping # ========================= def _words_and_map_from_tokens(gen_token_ids: List[int]) -> Tuple[List[str], List[int]]: """ From *generated* token ids, return: - words: detokenized words (regex-split) - word2tok: list where word2tok[i] = index (relative to generated) of the LAST token that composes that word. """ if not gen_token_ids: return [], [] gen_tokens_str = tokenizer.convert_ids_to_tokens(gen_token_ids) detok_text = tokenizer.convert_tokens_to_string(gen_tokens_str) words = WORD_RE.findall(detok_text) enc = tokenizer(detok_text, return_offsets_mapping=True, add_special_tokens=False) tok_offsets = enc["offset_mapping"] n = min(len(tok_offsets), len(gen_token_ids)) spans = [m.span() for m in re.finditer(WORD_RE, detok_text)] word2tok: List[int] = [] t = 0 for (ws, we) in spans: last_t = None while t < n: ts, te = tok_offsets[t] if not (te <= ws or ts >= we): last_t = t t += 1 else: if te <= ws: t += 1 else: break if last_t is None: last_t = max(0, min(n - 1, t - 1)) word2tok.append(int(last_t)) return words, word2tok # ========================= # Helpers # ========================= def _strip_trailing_special(ids: List[int]) -> List[int]: """Remove trailing EOS/PAD/other special tokens from the generated ids.""" specials = set(getattr(tokenizer, "all_special_ids", []) or []) j = len(ids) while j > 0 and ids[j - 1] in specials: j -= 1 return ids[:j] def clamp01(x: float) -> float: x = float(x) return 0.0 if x < 0 else 1.0 if x > 1 else x # ========================= # Visualization (WORD-LEVEL) # ========================= def generate_word_visualization(words: List[str], abs_word_ends: List[int], attention_values: np.ndarray, selected_token_abs_idx: int) -> str: """ Paragraph-style visualization over words. For each word, aggregate attention over its composing tokens (sum), normalize across words, and render opacity as a semi-transparent background. """ if not words or attention_values is None or len(attention_values) == 0: return ( "
" "
" "
No attention values.
" "
" "
" ) # Start..end spans from ends starts = [] for i, end in enumerate(abs_word_ends): if i == 0: starts.append(0) else: starts.append(min(abs_word_ends[i - 1] + 1, end)) # Sum attention per word word_scores = [] for i, end in enumerate(abs_word_ends): start = starts[i] if start > end: start = end s = max(0, min(start, len(attention_values) - 1)) e = max(0, min(end, len(attention_values) - 1)) if e < s: s, e = e, s word_scores.append(float(attention_values[s:e + 1].sum())) max_attn = max(0.1, float(max(word_scores)) if word_scores else 0.0) # Which word holds the selected token? selected_word_idx = None for i, end in enumerate(abs_word_ends): if selected_token_abs_idx <= end: selected_word_idx = i break if selected_word_idx is None and abs_word_ends: selected_word_idx = len(abs_word_ends) - 1 spans = [] for i, w in enumerate(words): alpha = min(1.0, word_scores[i] / max_attn) if max_attn > 0 else 0.0 bg = f"rgba(66,133,244,{alpha:.3f})" border = "2px solid #fff" if i == selected_word_idx else "1px solid transparent" spans.append( f"" f"{w}" ) return ( "
" "
" "
" f" {''.join(spans)}" "
" "
" "
" ) # ========================= # Core functions # ========================= def run_generation(prompt, max_new_tokens, temperature, top_p): """Generate and prepare word-level selector + initial visualization.""" inputs = tokenizer(prompt or "", return_tensors="pt").to(device) prompt_len = inputs["input_ids"].shape[1] with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=int(max_new_tokens), temperature=float(temperature), top_p=float(top_p), do_sample=True, pad_token_id=tokenizer.pad_token_id, output_attentions=True, return_dict_in_generate=True, ) all_token_ids = outputs.sequences[0].tolist() generated_token_ids = _strip_trailing_special(all_token_ids[prompt_len:]) # Words and map (word -> last generated token index) words, word2tok = _words_and_map_from_tokens(generated_token_ids) display_choices = [(w, i) for i, w in enumerate(words)] if not display_choices: return { state_attentions: None, state_all_token_ids: None, state_prompt_len: 0, state_words: None, state_word2tok: None, # SAFE RADIO RESET radio_word_selector: gr.update(choices=[], value=None), html_visualization: "
No new tokens generated.
", } first_word_idx = 0 html_init = update_visualization( first_word_idx, outputs.attentions, all_token_ids, prompt_len, 0, 0, True, True, words, word2tok, ) return { state_attentions: outputs.attentions, state_all_token_ids: all_token_ids, state_prompt_len: prompt_len, state_words: words, state_word2tok: word2tok, radio_word_selector: gr.update(choices=display_choices, value=first_word_idx), html_visualization: html_init, } def update_visualization( selected_word_index, attentions, all_token_ids, prompt_len, layer, head, mean_layers, mean_heads, words, word2tok, ): """Recompute visualization for the chosen word (maps to its last token).""" if selected_word_index is None or attentions is None or word2tok is None: return "
Generate text first.
" widx = int(selected_word_index) if not (0 <= widx < len(word2tok)): return "
Invalid selection.
" token_index_relative = int(word2tok[widx]) token_index_absolute = int(prompt_len) + token_index_relative token_attn = get_attention_for_token_layer( attentions, token_index=token_index_relative, layer_index=int(layer), head_index=int(head), mean_across_layers=bool(mean_layers), mean_across_heads=bool(mean_heads), ) attn_vals = token_attn.detach().cpu().numpy() # Pad attention to full (prompt + generated) length total_tokens = len(all_token_ids) padded = np.zeros(total_tokens, dtype=float) if attn_vals.ndim == 2: attn_vals = attn_vals[-1] padded[: len(attn_vals)] = attn_vals # Absolute word ends (prompt offset + relative token index) abs_word_ends = [int(prompt_len) + int(t) for t in (word2tok or [])] return generate_word_visualization(words, abs_word_ends, padded, token_index_absolute) def toggle_slider(is_mean): return gr.update(interactive=not bool(is_mean)) # ========================= # Gradio UI # ========================= with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown("# 🤖 Word-Level Attention Visualizer — choose a model & explore") gr.Markdown( "Pick a model, generate text, then select a **generated word** to see where it attends. " "Words wrap in a paragraph; opacity is the summed attention over the word’s tokens. " "EOS tokens are stripped so `<|endoftext|>` doesn’t appear." ) # States state_attentions = gr.State(None) state_all_token_ids = gr.State(None) state_prompt_len = gr.State(None) state_words = gr.State(None) state_word2tok = gr.State(None) state_model_name = gr.State(None) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 0) Model") dd_model = gr.Dropdown( ALLOWED_MODELS, value=ALLOWED_MODELS[0], label="Causal LM", info="Models that work with AutoModelForCausalLM + attentions" ) btn_load = gr.Button("Load / Switch Model", variant="secondary") gr.Markdown("### 1) Generation") txt_prompt = gr.Textbox("In a distant future, humanity", label="Prompt") btn_generate = gr.Button("Generate", variant="primary") slider_max_tokens = gr.Slider(10, 200, value=50, step=10, label="Max New Tokens") slider_temp = gr.Slider(0.0, 1.5, value=0.7, step=0.1, label="Temperature") slider_top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top P") gr.Markdown("### 2) Attention") check_mean_layers = gr.Checkbox(True, label="Mean Across Layers") check_mean_heads = gr.Checkbox(True, label="Mean Across Heads") slider_layer = gr.Slider(0, 11, value=0, step=1, label="Layer", interactive=False) slider_head = gr.Slider(0, 11, value=0, step=1, label="Head", interactive=False) with gr.Column(scale=3): radio_word_selector = gr.Radio( [], label="Select Generated Word to Visualize", info="Click Generate to populate" ) html_visualization = gr.HTML( "
" "Attention visualization will appear here.
" ) # Load/switch model def on_load_model(selected_name, mean_layers, mean_heads): load_model(selected_name) L, H = model_heads_layers() return ( selected_name, # state_model_name gr.update(minimum=0, maximum=L - 1, value=0, interactive=not bool(mean_layers)), gr.update(minimum=0, maximum=H - 1, value=0, interactive=not bool(mean_heads)), # SAFE RADIO RESET (avoid Value: [] not in choices) gr.update(choices=[], value=None), "
Model loaded. Generate to visualize.
", ) btn_load.click( fn=on_load_model, inputs=[dd_model, check_mean_layers, check_mean_heads], outputs=[state_model_name, slider_layer, slider_head, radio_word_selector, html_visualization], ) # Load default model at app start def _init_model(_): load_model(ALLOWED_MODELS[0]) L, H = model_heads_layers() return ( ALLOWED_MODELS[0], gr.update(minimum=0, maximum=L - 1, value=0, interactive=False if check_mean_layers.value else True), gr.update(minimum=0, maximum=H - 1, value=0, interactive=False if check_mean_heads.value else True), # Also ensure radio is clean at start gr.update(choices=[], value=None), ) demo.load(_init_model, inputs=[gr.State(None)], outputs=[state_model_name, slider_layer, slider_head, radio_word_selector]) # Generate btn_generate.click( fn=run_generation, inputs=[txt_prompt, slider_max_tokens, slider_temp, slider_top_p], outputs=[ state_attentions, state_all_token_ids, state_prompt_len, state_words, state_word2tok, radio_word_selector, html_visualization, ], ) # Update viz on any control for control in [radio_word_selector, slider_layer, slider_head, check_mean_layers, check_mean_heads]: control.change( fn=update_visualization, inputs=[ radio_word_selector, state_attentions, state_all_token_ids, state_prompt_len, slider_layer, slider_head, check_mean_layers, check_mean_heads, state_words, state_word2tok, ], outputs=html_visualization, ) # Toggle slider interactivity check_mean_layers.change(toggle_slider, check_mean_layers, slider_layer) check_mean_heads.change(toggle_slider, check_mean_heads, slider_head) if __name__ == "__main__": print(f"Device: {device}") # Ensure a default model is ready load_model(ALLOWED_MODELS[0]) demo.launch(debug=True)