{ "cells": [
  {
   "cell_type": "markdown",
   "id": "3f9c1a20",
   "metadata": {},
   "source": [
    "# Speculative decoding with a draft model\n",
    "\n",
    "A small draft model proposes several tokens ahead and a larger target (\"verify\") model scores all of them in one forward pass, accepting or rejecting each with the speculative-sampling rule. One tokenizer is shared, so both models must use the same vocabulary: draft probabilities $q$ and target probabilities $p$ are compared index-by-index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0ef8d28",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed\n",
    "\n",
    "set_seed(67)\n",
    "\n",
    "# Prefer Apple-silicon GPU, then CUDA, then CPU so the notebook is not tied to one machine.\n",
    "if torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "elif torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "DRAFT_MODEL_ID = \"Qwen/Qwen2.5-Coder-0.5B-Instruct\"  # alt: HuggingFaceTB/SmolLM2-135M-Instruct\n",
    "VERIFY_MODEL_ID = \"Qwen/Qwen2.5-Coder-3B-Instruct\"  # alt: HuggingFaceTB/SmolLM2-1.7B-Instruct\n",
    "\n",
    "# One tokenizer serves both models, so the two checkpoints must share a vocabulary.\n",
    "tokenizer = AutoTokenizer.from_pretrained(DRAFT_MODEL_ID)\n",
    "# `dtype` replaces the deprecated `torch_dtype` keyword.\n",
    "draft_model = AutoModelForCausalLM.from_pretrained(DRAFT_MODEL_ID, dtype=torch.bfloat16).to(device)\n",
    "verify_model = AutoModelForCausalLM.from_pretrained(VERIFY_MODEL_ID, dtype=torch.bfloat16).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b2e4d61",
   "metadata": {},
   "source": [
    "## One round by hand\n",
    "\n",
    "Draft `demo_gamma` tokens by sampling from the draft model, keeping each step's full distribution $q$ for verification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "30d81505",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'The quick brown fox jumps over the lazy dog'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt = \"The quick brown fox\"\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
    "input_ids = inputs[\"input_ids\"]\n",
    "\n",
    "demo_gamma = 5  # number of tokens drafted in this demo\n",
    "\n",
    "generated = input_ids.clone()  # [1, seq_len]\n",
    "draft_probs = []\n",
    "for _ in range(demo_gamma):\n",
    "    with torch.no_grad():\n",
    "        outputs = draft_model(generated)\n",
    "    logits = outputs.logits[:, -1, :]  # [batch, vocab_size] at the last position\n",
    "\n",
    "    # float32: a bf16 softmax is too coarse for the p(x)/q(x) ratio used in verification.\n",
    "    probs = torch.softmax(logits.float(), dim=-1)  # [1, vocab_size]\n",
    "    next_token = torch.multinomial(probs, num_samples=1)\n",
    "\n",
    "    draft_probs.append(probs)\n",
    "    generated = torch.cat([generated, next_token], dim=-1)\n",
    "\n",
    "tokenizer.decode(generated[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4d8e915",
   "metadata": {},
   "source": [
    "Verify the drafts with the target model: for each drafted token $x$, in order, accept it if $q(x) \\le p(x)$, otherwise accept it with probability $p(x)/q(x)$. On the first rejection, replace $x$ with a sample from $\\mathrm{norm}(\\max(0, p - q))$ and stop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2492ee58",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token: ' jumps', q(x)=0.9727, p(x)=0.9023\n",
      " -> Accepted\n",
      "Token: ' over', q(x)=1.0000, p(x)=0.9961\n",
      " -> Accepted\n",
      "Token: ' the', q(x)=0.9023, p(x)=0.9102\n",
      " -> Accepted\n",
      "Token: ' lazy', q(x)=1.0000, p(x)=0.9922\n",
      " -> Accepted\n",
      "Token: ' dog', q(x)=0.9922, p(x)=0.9844\n",
      " -> Accepted\n",
      " jumps over the lazy dog\n",
      "[34208, 916, 279, 15678, 5562]\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    target_outputs = verify_model(generated)\n",
    "# Logits at position t predict token t+1, so the target's predictions for the\n",
    "# demo_gamma drafted tokens sit at positions -(demo_gamma + 1) .. -2.\n",
    "target_logits = target_outputs.logits[:, -demo_gamma - 1:-1, :]\n",
    "\n",
    "target_probs = torch.softmax(target_logits.float(), dim=-1)  # [1, demo_gamma, vocab_size]\n",
    "accepted_tokens = []\n",
    "for i in range(demo_gamma):\n",
    "    # if q(x) <= p(x), keep\n",
    "    # if q(x) > p(x), reject with 1-p(x)/q(x) chance\n",
    "    # if rejected, we sample from norm(max(0, p(x) - q(x)))\n",
    "    q = draft_probs[i]  # [1, vocab_size]\n",
    "    p = target_probs[:, i, :]  # [1, vocab_size]\n",
    "    x = generated[0, i - demo_gamma].item()  # assume unbatched for now\n",
    "    q_x = q[0, x].item()\n",
    "    p_x = p[0, x].item()\n",
    "\n",
    "    print(f\"Token: '{tokenizer.decode(x)}', q(x)={q_x:.4f}, p(x)={p_x:.4f}\")\n",
    "    if q_x <= p_x:\n",
    "        print(\" -> Accepted\")\n",
    "        accepted_tokens.append(x)\n",
    "    else:\n",
    "        r = torch.rand(1).item()\n",
    "        acceptance_rate = p_x / q_x\n",
    "\n",
    "        if r < acceptance_rate:\n",
    "            print(\" -> Accepted\")\n",
    "            accepted_tokens.append(x)\n",
    "        else:\n",
    "            print(\" -> Rejected\")\n",
    "            adjusted = torch.clamp(p - q, min=0)\n",
    "            adjusted = adjusted / adjusted.sum()\n",
    "            new_token = torch.multinomial(adjusted, num_samples=1)[0].item()\n",
    "            accepted_tokens.append(new_token)\n",
    "            break\n",
    "\n",
    "print(tokenizer.decode(accepted_tokens))\n",
    "print(accepted_tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d1f7e42",
   "metadata": {},
   "source": [
    "## Speculative decoding with KV caches\n",
    "\n",
    "`draft` proposes up to `gamma` tokens, stopping early when the draft model's top probability falls below `confidence_threshold`; `verify` scores them with the target model. Both take and return their model's KV cache, so later rounds only feed tokens that are not cached yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "2df0e2a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def draft(input_ids, gamma, confidence_threshold, eos_token, past_kv, sample=False):\n",
    "    \"\"\"Propose up to `gamma` tokens with the draft model.\n",
    "\n",
    "    Args:\n",
    "        input_ids: [1, seq_len] tokens so far (unbatched).\n",
    "        gamma: maximum number of tokens to propose.\n",
    "        confidence_threshold: stop early once the draft model's top probability\n",
    "            falls below this value; at least one token is always proposed.\n",
    "        eos_token: stop right after proposing this token id.\n",
    "        past_kv: draft-model KV cache, or None to prefill from `input_ids`.\n",
    "            When given, only the last token of `input_ids` is fed, so the\n",
    "            cache must already hold every earlier token.\n",
    "        sample: if True, sample each token from the draft distribution;\n",
    "            if False (default), take the argmax.\n",
    "\n",
    "    Returns:\n",
    "        (generated, draft_probs, past_kv): `input_ids` followed by the proposed\n",
    "        tokens; per proposed token, the [1, vocab_size] distribution q it was\n",
    "        drawn from (one-hot in greedy mode); and the updated cache, which holds\n",
    "        every token of `generated` except the last.\n",
    "    \"\"\"\n",
    "    generated = input_ids.clone()  # [1, seq_len]\n",
    "    draft_probs = []\n",
    "    for _ in range(gamma):\n",
    "        with torch.no_grad():\n",
    "            outputs = draft_model(\n",
    "                generated if past_kv is None else generated[:, -1:],\n",
    "                past_key_values=past_kv,\n",
    "                use_cache=True,\n",
    "            )\n",
    "        logits = outputs.logits[:, -1, :]  # [batch, vocab_size] at the last position\n",
    "        past_kv = outputs.past_key_values\n",
    "\n",
    "        # float32: a bf16 softmax is too coarse for the p(x)/q(x) ratio in verify().\n",
    "        probs = torch.softmax(logits.float(), dim=-1)  # [1, vocab_size]\n",
    "\n",
    "        confidence = probs.max().item()  # dynamic speculative decoding\n",
    "        if confidence < confidence_threshold and len(draft_probs) > 0:\n",
    "            # This forward pass already cached the last proposed token; drop it\n",
    "            # so the cache holds every token of `generated` except the last.\n",
    "            past_kv.crop(max_length=generated.shape[-1] - 1)\n",
    "            break\n",
    "\n",
    "        if sample:\n",
    "            next_token = torch.multinomial(probs, num_samples=1)\n",
    "            proposal = probs\n",
    "        else:\n",
    "            next_token = torch.argmax(probs, dim=-1, keepdim=True)\n",
    "            # Greedy drafting picks x deterministically, so the proposal q is a\n",
    "            # point mass on x. Passing the softmax instead would make verify()\n",
    "            # accept x with probability min(1, p(x)/q(x)), which is generally\n",
    "            # larger than p(x), biasing the output away from the target model.\n",
    "            proposal = torch.zeros_like(probs).scatter_(-1, next_token, 1.0)\n",
    "\n",
    "        draft_probs.append(proposal)\n",
    "        generated = torch.cat([generated, next_token], dim=-1)\n",
    "\n",
    "        if next_token.item() == eos_token:\n",
    "            break\n",
    "\n",
    "    return generated, draft_probs, past_kv\n",
    "\n",
    "def verify(drafted, drafted_probs, eos_token, past_kv, verbose=True):\n",
    "    \"\"\"Accept or reject drafted tokens with the target model (speculative sampling).\n",
    "\n",
    "    Args:\n",
    "        drafted: [1, seq_len] sequence from draft(); its last\n",
    "            len(drafted_probs) tokens are the proposals (unbatched).\n",
    "        drafted_probs: per-proposal [1, vocab_size] draft distributions q.\n",
    "        eos_token: stop scanning once this token id has been emitted.\n",
    "        past_kv: target-model KV cache, or None to prefill from `drafted`.\n",
    "            When given, only the last len(drafted_probs) + 1 tokens of\n",
    "            `drafted` are fed, so the cache must hold exactly the tokens\n",
    "            before them.\n",
    "        verbose: print a per-token accept/reject trace.\n",
    "\n",
    "    Returns:\n",
    "        (accepted_tokens, num_accepted, past_kv): token ids to append (the\n",
    "        accepted proposals, plus the resampled replacement if one was\n",
    "        rejected), how many proposals were accepted, and the updated cache,\n",
    "        which holds every token of `drafted`.\n",
    "    \"\"\"\n",
    "    draft_len = len(drafted_probs)  # number of new drafted tokens\n",
    "    with torch.no_grad():\n",
    "        if past_kv is None:\n",
    "            target_outputs = verify_model(drafted, use_cache=True)\n",
    "            target_logits = target_outputs.logits[:, -draft_len - 1:-1, :]\n",
    "        else:\n",
    "            target_outputs = verify_model(\n",
    "                drafted[:, -(draft_len + 1):],  # last kept token + proposals\n",
    "                past_key_values=past_kv,\n",
    "                use_cache=True,\n",
    "            )\n",
    "            target_logits = target_outputs.logits[:, :-1, :]  # drop last (predicts bonus token)\n",
    "\n",
    "    past_kv = target_outputs.past_key_values\n",
    "\n",
    "    # float32 so the p(x)/q(x) ratio is not quantised by bf16.\n",
    "    target_probs = torch.softmax(target_logits.float(), dim=-1)  # [1, draft_len, vocab_size]\n",
    "    accepted_tokens = []\n",
    "    num_accepted = 0  # number of drafted tokens that were accepted\n",
    "    for i in range(draft_len):\n",
    "        # if q(x) <= p(x), keep\n",
    "        # if q(x) > p(x), reject with 1-p(x)/q(x) chance\n",
    "        # if rejected, we sample from norm(max(0, p(x) - q(x)))\n",
    "        q = drafted_probs[i]  # [1, vocab_size]\n",
    "        p = target_probs[:, i, :]  # [1, vocab_size]\n",
    "        x = drafted[0, i - draft_len].item()  # assume unbatched for now\n",
    "        q_x = q[0, x].item()\n",
    "        p_x = p[0, x].item()\n",
    "\n",
    "        if verbose:\n",
    "            print(f\"Token: '{tokenizer.decode(x)}'\", end = \"\")\n",
    "        # `or` short-circuits, so a random draw is only made when q(x) > p(x).\n",
    "        if q_x <= p_x or torch.rand(1).item() < p_x / q_x:\n",
    "            if verbose:\n",
    "                print(\" -> Accepted\")\n",
    "            accepted_tokens.append(x)\n",
    "            num_accepted += 1\n",
    "        else:\n",
    "            if verbose:\n",
    "                print(\" -> Rejected\", end = \"\")\n",
    "            adjusted = torch.clamp(p - q, min=0)\n",
    "            mass = adjusted.sum()\n",
    "            # mass > 0 in exact arithmetic whenever q(x) > p(x); fall back to p if\n",
    "            # round-off ever leaves nothing to sample from.\n",
    "            adjusted = adjusted / mass if mass > 0 else p\n",
    "            new_token = torch.multinomial(adjusted, num_samples=1)[0].item()\n",
    "            accepted_tokens.append(new_token)\n",
    "            if verbose:\n",
    "                print(tokenizer.decode(new_token))\n",
    "            break\n",
    "        if accepted_tokens[-1] == eos_token:\n",
    "            break\n",
    "\n",
    "    return accepted_tokens, num_accepted, past_kv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "5378f5d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token: 'A' -> Accepted\n",
      "Token: ' deal' -> Accepted\n",
      "Token: ' flow' -> Accepted\n",
      "Token: ' in' -> Accepted\n",
      "Token: ' a' -> Accepted\n",
      "Token: ' VC' -> Rejected venture\n",
      "Token: ' capital' -> Accepted\n",
      "Token: ' fund' -> Rejected (\n",
      "Token: 'VC' -> Accepted\n",
      "Token: ')' -> Accepted\n",
      "Token: ' fund' -> Accepted\n",
      "Token: ' is' -> Rejected refers\n",
      "Token: ' to' -> Accepted\n",
      "Token: ' the' -> Accepted\n",
      "Token: ' process' -> Accepted\n",
      "Token: ' of' -> Accepted\n",
      "Token: ' setting' -> Rejected screening\n",
      "Token: ',' -> Accepted\n",
      "Token: ' evaluating' -> Accepted\n",
      "Token: ',' -> Accepted\n",
      "Token: ' and' -> Accepted\n",
"Token: ' selecting' -> Accepted\n", "Token: ' investors' -> Rejected potential\n", "Token: ' investment' -> Accepted\n", "Token: ' in' -> Rejected opportunities\n", "Token: ' for' -> Rejected.\n", "Token: ' It' -> Accepted\n", "Token: ' involves' -> Accepted\n", "Token: ' several' -> Rejected identifying\n", "Token: ' potential' -> Accepted\n", "Token: ' investors' -> Rejected companies\n", "Token: ' that' -> Accepted\n", "Token: ' could' -> Rejected the\n", "Token: ' VC' -> Accepted\n", "Token: ' fund' -> Accepted\n", "Token: 'ers' -> Rejected is\n", "Token: ' interested' -> Accepted\n", "Token: ' in' -> Accepted\n", "Token: ',' -> Accepted\n", "Token: ' assessing' -> Accepted\n", "Token: ' their' -> Accepted\n", "Token: ' financial' -> Rejected growth\n", "Token: ' potential' -> Accepted\n", "Token: ',' -> Accepted\n", "Token: ' and' -> Accepted\n", "Token: ' evaluating' -> Rejected business\n", "Token: ' model' -> Accepted\n", "Token: ',' -> Accepted\n", "Token: ' and' -> Accepted\n", "Token: ',' -> Rejected market\n", "Token: ' demand' -> Accepted\n", "Token: ',' -> Accepted\n", "Token: ' and' -> Accepted\n", "Token: ' then' -> Accepted\n", "Token: ' making' -> Accepted\n", "Token: ' a' -> Accepted\n", "Token: ' decision' -> Accepted\n", "Token: ' on' -> Accepted\n", "Token: ' whether' -> Accepted\n", "Token: ' to' -> Accepted\n", "Token: ' invest' -> Accepted\n", "Token: ' in' -> Accepted\n", "Token: ' those' -> Rejected the\n", "Token: ' VC' -> Rejected fund\n", "Token: ' in' -> Rejected’s\n", "Token: ' capital' -> Accepted\n", "Token: '.' -> Accepted\n", "Token: '<|im_end|>' -> Accepted\n" ] }, { "data": { "text/plain": [ "'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat is a deal flow in a VC fund?<|im_end|>\\n<|im_start|>assistant\\nA deal flow in a venture capital (VC) fund refers to the process of screening, evaluating, and selecting potential investment opportunities. 
It involves identifying potential companies that the VC fund is interested in, assessing their growth potential, and business model, and market demand, and then making a decision on whether to invest in the fund’s capital.<|im_end|>'" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "messages = [{\"role\": \"user\", \"content\": \"What is a deal flow in a VC fund?\"}]\n",
    "prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "\n",
    "max_tokens = 80\n",
    "eos_token = tokenizer.eos_token_id\n",
    "im_end_token = tokenizer.convert_tokens_to_ids(\"<|im_end|>\")\n",
    "gamma = 15\n",
    "confidence_threshold = 0.5\n",
    "\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
    "result = inputs[\"input_ids\"].clone()\n",
    "\n",
    "draft_kv = None\n",
    "verify_kv = None\n",
    "\n",
    "total_drafted = 0\n",
    "total_accepted = 0\n",
    "\n",
    "while result.shape[-1] - inputs[\"input_ids\"].shape[-1] < max_tokens:\n",
    "    drafted, drafted_probs, draft_kv = draft(result, gamma, confidence_threshold, eos_token, draft_kv)\n",
    "    accepted_tokens, num_accepted, verify_kv = verify(drafted, drafted_probs, eos_token, verify_kv)\n",
    "\n",
    "    total_drafted += len(drafted_probs)\n",
    "    total_accepted += num_accepted\n",
    "\n",
    "    result = torch.cat([result, torch.tensor([accepted_tokens], device=device)], dim=-1)\n",
    "\n",
    "    # Next round, draft() and verify() both re-feed the last token of `result`,\n",
    "    # so each cache must hold every token of `result` except that one. Cropping\n",
    "    # to the full length instead would leave the last token both cached and\n",
    "    # re-fed (a duplicated position) whenever every drafted token was accepted.\n",
    "    cache_len = result.shape[-1] - 1\n",
    "    if draft_kv is not None:\n",
    "        draft_kv.crop(max_length=cache_len)\n",
    "    if verify_kv is not None:\n",
    "        verify_kv.crop(max_length=cache_len)\n",
    "\n",
    "    if eos_token in accepted_tokens or im_end_token in accepted_tokens:\n",
    "        break\n",
    "\n",
    "tokenizer.decode(result[0])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a83c6f10",
   "metadata": {},
   "source": [
    "Acceptance rate: the fraction of drafted tokens that the target model accepted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "40c92741",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6071428571428571"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_accepted / total_drafted"
   ]
},
  {
   "cell_type": "markdown",
   "id": "9e6a0b37",
   "metadata": {},
   "source": [
    "## Baseline: target model only\n",
    "\n",
    "Same prompt and `max_tokens`, sampled one token at a time from the target model with its KV cache, for comparison with the speculative output above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "0c661940",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat is a deal flow in a VC fund?<|im_end|>\\n<|im_start|>assistant\\nA deal flow in a VC fund refers to the collection and processing of new investment opportunities presented to the fund for screening, evaluation, and ultimately investment.\\n\\nHere’s a basic overview:\\n\\n* **Deal Flow Collection**: Fund managers typically collect deals through email, cold calls, calendars, meeting notes, investor referrals, and referrals by other investors. They utilize various channels including online searches, investor networks, and industry'"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Separate names so the speculative-decoding `result` above is not overwritten.\n",
    "baseline_inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
    "baseline_result = baseline_inputs[\"input_ids\"].clone()\n",
    "prompt_len = baseline_inputs[\"input_ids\"].shape[-1]\n",
    "\n",
    "baseline_kv = None\n",
    "\n",
    "while baseline_result.shape[-1] - prompt_len < max_tokens:\n",
    "    with torch.no_grad():\n",
    "        output = verify_model(\n",
    "            baseline_result if baseline_kv is None else baseline_result[:, -1:],\n",
    "            past_key_values=baseline_kv,\n",
    "            use_cache=True,\n",
    "        )\n",
    "    logits = output.logits[:, -1, :]  # [batch, vocab_size]\n",
    "    baseline_kv = output.past_key_values\n",
    "\n",
    "    # float32: sampling from bf16 probabilities quantises them.\n",
    "    probs = torch.softmax(logits.float(), dim=-1)\n",
    "    next_token = torch.multinomial(probs, num_samples=1)\n",
    "\n",
    "    baseline_result = torch.cat([baseline_result, next_token], dim=-1)\n",
    "    if eos_token in next_token or im_end_token in next_token:\n",
    "        break\n",
    "\n",
    "tokenizer.decode(baseline_result[0])"
   ]
  }
 ],
 "metadata": { "kernelspec": { "display_name": "speculativedecoding-B0TTdUOs-py3.13", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py",
"mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.5" } }, "nbformat": 4, "nbformat_minor": 5 }