In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

set_seed(67)

device = "mps"

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-0.5B-Instruct") #HuggingFaceTB/SmolLM2-135M-Instruct
draft_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Coder-0.5B-Instruct", torch_dtype=torch.bfloat16).to(device)
verify_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Coder-3B-Instruct", torch_dtype=torch.bfloat16).to(device) #HuggingFaceTB/SmolLM2-1.7B-Instruct

 from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
Fetching 2 files: 100%|██████████| 2/2 [00:58<00:00, 29.32s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 27.91it/s]


In [8]:
prompt = "The quick brown fox"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
input_ids = inputs["input_ids"]

generated = input_ids.clone() # [1, seq_len]
draft_probs = []
for _ in range(5): # gamma = 5
 with torch.no_grad():
 outputs = draft_model(generated)
 logits = outputs.logits[:, -1, :] # batch, seq_len, vocab_size
 
 probs = torch.softmax(logits, dim=-1) # [1, 50257]
 next_token = torch.multinomial(probs, num_samples=1)

 draft_probs.append(probs)
 generated = torch.cat([generated, next_token], dim=-1)

tokenizer.decode(generated[0])

'The quick brown fox jumps over the lazy dog'

In [9]:
with torch.no_grad():
 target_outputs = verify_model(generated)
 target_logits = target_outputs.logits[:, -6:-1, :]

target_probs = torch.softmax(target_logits, dim=-1) # [1, 5, 50257]
accepted_tokens = []
for i in range(5):
 # if q(x) <= p(x), keep
 # if q(x) > p(x), reject with 1-p(x)/q(x) chance
 # if rejected, we sample from norm(max(0, p(x) - q(x)))
 q = draft_probs[i] # [1, 50257]
 p = target_probs[:, i, :] # [1, 50257]
 token = generated[:, i - 5] # [1]
 # assume unbatched for now
 x = token[0].item()
 q_x = q[0, x].item()
 p_x = p[0, x].item()

 print(f"Token: '{tokenizer.decode(x)}', q(x)={q_x:.4f}, p(x)={p_x:.4f}")
 if q_x <= p_x:
 print(" -> Accepted")
 accepted_tokens.append(x)
 else:
 r = torch.rand(1).item()
 acceptance_rate = p_x / q_x

 if r < acceptance_rate:
 print(" -> Accepted")
 accepted_tokens.append(x)
 else:
 print(" -> Rejected")
 adjusted = torch.clamp(p-q, min=0)
 adjusted = adjusted / adjusted.sum()
 new_token = torch.multinomial(adjusted, num_samples=1)[0].item()
 accepted_tokens.append(new_token)
 break

print(tokenizer.decode(accepted_tokens))
print(accepted_tokens)

Token: ' jumps', q(x)=0.9727, p(x)=0.9023
 -> Accepted
Token: ' over', q(x)=1.0000, p(x)=0.9961
 -> Accepted
Token: ' the', q(x)=0.9023, p(x)=0.9102
 -> Accepted
Token: ' lazy', q(x)=1.0000, p(x)=0.9922
 -> Accepted
Token: ' dog', q(x)=0.9922, p(x)=0.9844
 -> Accepted
 jumps over the lazy dog
[34208, 916, 279, 15678, 5562]


In [20]:
def draft(input_ids, gamma, confidence_threshold, eos_token, past_kv):
 generated = input_ids.clone() # [1, seq_len]
 draft_probs = []
 for _ in range(gamma): 
 with torch.no_grad():
 outputs = draft_model(
 generated if past_kv is None else generated[:, -1:],
 past_key_values=past_kv,
 use_cache=True
 )
 logits = outputs.logits[:, -1, :] # batch, seq_len, vocab_size
 past_kv = outputs.past_key_values
 
 probs = torch.softmax(logits, dim=-1) # [1, 50257]

 confidence = probs.max().item() # dynamic speculative decoding
 if confidence < confidence_threshold and len(draft_probs) > 0:
 break 

 next_token = torch.argmax(probs, dim=-1, keepdim=True)

 draft_probs.append(probs)
 generated = torch.cat([generated, next_token], dim=-1)

 if next_token.item() == eos_token:
 break;

 return generated, draft_probs, past_kv

def verify(drafted, drafted_probs, eos_token, past_kv):
 draft_len = len(drafted_probs) # number of new drafted tokens
 with torch.no_grad():
 if past_kv is None:
 target_outputs = verify_model(drafted, use_cache=True)
 target_logits = target_outputs.logits[:, -draft_len - 1:-1, :]
 else:
 target_outputs = verify_model(
 drafted[:, -(draft_len + 1):], # extra token 
 past_key_values=past_kv,
 use_cache=True
 )
 target_logits = target_outputs.logits[:, :-1, :] # Drop last (predicts bonus token)

 past_kv = target_outputs.past_key_values

 target_probs = torch.softmax(target_logits, dim=-1) # [1, 5, 50257]
 accepted_tokens = []
 num_accepted = 0 # number of tokens from drafted that is accepted
 for i in range(draft_len):
 # if q(x) <= p(x), keep
 # if q(x) > p(x), reject with 1-p(x)/q(x) chance
 # if rejected, we sample from norm(max(0, p(x) - q(x)))
 q = drafted_probs[i] # [1, 50257]
 p = target_probs[:, i, :] # [1, 50257]
 token = drafted[:, i - draft_len] # [1]
 # assume unbatched for now
 x = token[0].item()
 q_x = q[0, x].item()
 p_x = p[0, x].item()

 print(f"Token: '{tokenizer.decode(x)}'", end = "")
 if q_x <= p_x:
 print(" -> Accepted")
 accepted_tokens.append(x)
 num_accepted+=1
 else:
 r = torch.rand(1).item()
 acceptance_rate = p_x / q_x

 if r < acceptance_rate:
 print(" -> Accepted")
 accepted_tokens.append(x)
 num_accepted+=1
 else:
 print(" -> Rejected", end = "")
 adjusted = torch.clamp(p-q, min=0)
 adjusted = adjusted / adjusted.sum()
 new_token = torch.multinomial(adjusted, num_samples=1)[0].item()
 accepted_tokens.append(new_token)
 print(tokenizer.decode(new_token))
 break
 if accepted_tokens[-1] == eos_token:
 break

 return accepted_tokens, num_accepted, past_kv

In [24]:
messages = [{"role": "user", "content": "What is a deal flow in a VC fund?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

max_tokens = 80
eos_token = tokenizer.eos_token_id
im_end_token = tokenizer.convert_tokens_to_ids("<|im_end|>")
gamma = 15
confidence_threshold = 0.5

inputs = tokenizer(prompt, return_tensors="pt").to(device)
result = inputs["input_ids"].clone()

draft_kv = None
verify_kv = None

total_drafted = 0
total_accepted = 0

while result.shape[-1] - inputs["input_ids"].shape[-1] < max_tokens:
 drafted, drafted_probs, draft_kv = draft(result, gamma, confidence_threshold, eos_token, draft_kv)
 accepted_tokens, num_accepted, verify_kv = verify(drafted, drafted_probs, eos_token, verify_kv)

 total_drafted += len(drafted_probs)
 total_accepted += num_accepted
 
 valid_len = result.shape[-1] + num_accepted
 result = torch.cat([result, torch.tensor([accepted_tokens], device=device)], dim=-1)

 if draft_kv is not None:
 draft_kv.crop(max_length=valid_len)
 if verify_kv is not None:
 verify_kv.crop(max_length=valid_len)

 if eos_token in accepted_tokens or im_end_token in accepted_tokens:
 break

tokenizer.decode(result[0])


Token: 'A' -> Accepted
Token: ' deal' -> Accepted
Token: ' flow' -> Accepted
Token: ' in' -> Accepted
Token: ' a' -> Accepted
Token: ' VC' -> Rejected venture
Token: ' capital' -> Accepted
Token: ' fund' -> Rejected (
Token: 'VC' -> Accepted
Token: ')' -> Accepted
Token: ' fund' -> Accepted
Token: ' is' -> Rejected refers
Token: ' to' -> Accepted
Token: ' the' -> Accepted
Token: ' process' -> Accepted
Token: ' of' -> Accepted
Token: ' setting' -> Rejected screening
Token: ',' -> Accepted
Token: ' evaluating' -> Accepted
Token: ',' -> Accepted
Token: ' and' -> Accepted
Token: ' selecting' -> Accepted
Token: ' investors' -> Rejected potential
Token: ' investment' -> Accepted
Token: ' in' -> Rejected opportunities
Token: ' for' -> Rejected.
Token: ' It' -> Accepted
Token: ' involves' -> Accepted
Token: ' several' -> Rejected identifying
Token: ' potential' -> Accepted
Token: ' investors' -> Rejected companies
Token: ' that' -> Accepted
Token: ' could' -> Rejected the
Token: ' VC' -> Accep

'<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is a deal flow in a VC fund?<|im_end|>\n<|im_start|>assistant\nA deal flow in a venture capital (VC) fund refers to the process of screening, evaluating, and selecting potential investment opportunities. It involves identifying potential companies that the VC fund is interested in, assessing their growth potential, and business model, and market demand, and then making a decision on whether to invest in the fund’s capital.<|im_end|>'

In [25]:
total_accepted / total_drafted

0.6071428571428571

In [26]:
inputs = tokenizer(prompt, return_tensors="pt").to(device)
result = inputs["input_ids"].clone()

past_kv = None

while result.shape[-1] - inputs["input_ids"].shape[-1] < max_tokens:
 with torch.no_grad():
 output = verify_model(
 result if past_kv is None else result[:, -1:],
 past_key_values=past_kv,
 use_cache=True
 )
 logits = output.logits[:, -1, :] # batch, vocab
 
 past_kv = output.past_key_values
 probs = torch.softmax(logits, dim=-1)
 next_token = torch.multinomial(probs, num_samples=1)

 result = torch.cat([result, next_token], dim=-1)
 if eos_token in next_token or im_end_token in next_token:
 break

tokenizer.decode(result[0])

'<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is a deal flow in a VC fund?<|im_end|>\n<|im_start|>assistant\nA deal flow in a VC fund refers to the collection and processing of new investment opportunities presented to the fund for screening, evaluation, and ultimately investment.\n\nHere’s a basic overview:\n\n* **Deal Flow Collection**: Fund managers typically collect deals through email, cold calls, calendars, meeting notes, investor referrals, and referrals by other investors. They utilize various channels including online searches, investor networks, and industry'