HandsomeSB committed on
Commit
5d1ef4b
·
1 Parent(s): 9aa1873

changing to app.py

Browse files
Files changed (1) hide show
  1. main.py → app.py +40 -38
main.py → app.py RENAMED
@@ -107,29 +107,37 @@ def generate_visual(prompt, max_tokens=50, gamma=15, confidence_threshold=0.5):
107
 
108
  steps = []
109
 
110
- # Track the actual output tokens (for streaming display)
111
- output_tokens = []
112
- # Track metadata for each token: 'accepted', 'rejected', or 'resampled'
113
- token_metadata = []
 
114
 
115
  def build_html():
116
  html = "<div style='font-family: monospace;'>"
117
 
118
- # Final output box - shows the streaming tokens with color coding
119
  html += f"<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
120
- html += f"<b>Final Output:</b><br/>"
121
- if output_tokens:
122
- for i, token_id in enumerate(output_tokens):
 
 
 
 
 
 
 
 
123
  token_text = tokenizer.decode([token_id])
124
  token_display = token_text.replace("<", "&lt;").replace(">", "&gt;")
125
 
126
- # Apply color based on metadata
127
- if i < len(token_metadata):
128
- if token_metadata[i] == 'accepted':
129
  html += f"<span style='background: #66CC66; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token_display}</span>"
130
- elif token_metadata[i] == 'resampled':
131
  html += f"<span style='background: #5AADCC; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token_display}</span>"
132
- elif token_metadata[i] == 'rejected':
133
  html += f"<span style='background: #FF8B9A; padding: 2px 4px; margin: 1px; text-decoration: line-through; border-radius: 3px;'>{token_display}</span>"
134
  else:
135
  html += token_display
@@ -137,7 +145,7 @@ def generate_visual(prompt, max_tokens=50, gamma=15, confidence_threshold=0.5):
137
 
138
  # Acceptance rate
139
  if total_drafted > 0:
140
- html += f"<div style='margin-bottom: 20px; padding: 10px; background: #e0e0e0; border-radius: 5px;'>"
141
  html += f"<b>Acceptance Rate:</b> {total_accepted}/{total_drafted} = {total_accepted/total_drafted*100:.1f}%"
142
  html += "</div>"
143
 
@@ -163,17 +171,15 @@ def generate_visual(prompt, max_tokens=50, gamma=15, confidence_threshold=0.5):
163
  return html
164
 
165
  while result.shape[-1] - inputs["input_ids"].shape[-1] < max_tokens:
166
- # Draft phase
167
  drafted, drafted_probs, draft_kv = draft(result, gamma, confidence_threshold, eos_token, draft_kv)
168
  drafted_token_ids = drafted[0, -len(drafted_probs):].tolist()
169
  drafted_tokens = [tokenizer.decode([t]) for t in drafted_token_ids]
170
 
171
- # Immediately show drafted tokens in output (optimistically)
172
- output_tokens.extend(drafted_token_ids)
173
- # Mark all as accepted initially (will be corrected after verification)
174
- token_metadata.extend(['accepted'] * len(drafted_token_ids))
175
 
176
- # Create a temporary step showing all drafted tokens as accepted
177
  temp_step = {
178
  "drafted": drafted_tokens,
179
  "accepted": len(drafted_tokens),
@@ -182,36 +188,33 @@ def generate_visual(prompt, max_tokens=50, gamma=15, confidence_threshold=0.5):
182
  steps.append(temp_step)
183
  total_drafted += len(drafted_probs)
184
 
185
- # Yield the state with drafted tokens showing
186
  yield build_html()
187
 
188
- # Verify phase
189
  accepted_tokens, num_accepted, verify_kv = verify(drafted, drafted_probs, eos_token, verify_kv)
190
  total_accepted += num_accepted
191
 
192
- # Now update the step with actual acceptance information
193
- # Remove the optimistically added tokens and metadata
194
- output_tokens = output_tokens[:-len(drafted_token_ids)]
195
- token_metadata = token_metadata[:-len(drafted_token_ids)]
196
 
197
- # Add back the actually accepted tokens with correct metadata
198
- for i, token_id in enumerate(accepted_tokens):
199
- output_tokens.append(token_id)
200
  if i < num_accepted:
201
- # This token was accepted from the draft
202
- token_metadata.append('accepted')
203
  else:
204
- # This is the resampled token
205
- token_metadata.append('resampled')
 
 
 
 
 
206
 
207
- # Update the step with real acceptance info
208
  steps[-1] = {
209
  "drafted": drafted_tokens,
210
  "accepted": num_accepted,
211
  "resampled": tokenizer.decode([accepted_tokens[-1]]) if num_accepted < len(accepted_tokens) else None
212
  }
213
 
214
- # Yield the corrected state
215
  yield build_html()
216
 
217
  valid_len = result.shape[-1] + num_accepted
@@ -225,7 +228,6 @@ def generate_visual(prompt, max_tokens=50, gamma=15, confidence_threshold=0.5):
225
  if eos_token in accepted_tokens or im_end_token in accepted_tokens:
226
  break
227
 
228
- # Final yield with complete output
229
  yield build_html()
230
 
231
  demo = gr.Interface(
@@ -252,8 +254,8 @@ demo = gr.Interface(
252
  **Watch the tokens stream in real-time!** Draft tokens appear immediately, then get accepted or rejected by the verify model.
253
  """,
254
  examples=[
255
- ["What is a deal flow in a VC fund?", 80, 15, 0.5],
256
- ["def fibonacci(n):", 50, 15, 0.5],
257
  ["Explain the concept of attention in transformers", 60, 10, 0.6]
258
  ]
259
  )
 
107
 
108
  steps = []
109
 
110
+ # Track the clean output tokens (only accepted/resampled)
111
+ clean_output_tokens = []
112
+ all_tokens = []
113
+ # Metadata for ALL tokens: 'accepted', 'rejected', or 'resampled'
114
+ all_token_metadata = []
115
 
116
  def build_html():
117
  html = "<div style='font-family: monospace;'>"
118
 
119
+ # Clean final output box
120
  html += f"<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
121
+ html += f"<b>Final Output (Clean):</b><br/>"
122
+ if clean_output_tokens:
123
+ clean_text = tokenizer.decode(clean_output_tokens)
124
+ html += clean_text
125
+ html += "</div>"
126
+
127
+ # Detailed output box
128
+ html += f"<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
129
+ html += f"<b>Detailed Output (All Tokens):</b><br/>"
130
+ if all_tokens:
131
+ for i, token_id in enumerate(all_tokens):
132
  token_text = tokenizer.decode([token_id])
133
  token_display = token_text.replace("<", "&lt;").replace(">", "&gt;")
134
 
135
+ if i < len(all_token_metadata):
136
+ if all_token_metadata[i] == 'accepted':
 
137
  html += f"<span style='background: #66CC66; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token_display}</span>"
138
+ elif all_token_metadata[i] == 'resampled':
139
  html += f"<span style='background: #5AADCC; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token_display}</span>"
140
+ elif all_token_metadata[i] == 'rejected':
141
  html += f"<span style='background: #FF8B9A; padding: 2px 4px; margin: 1px; text-decoration: line-through; border-radius: 3px;'>{token_display}</span>"
142
  else:
143
  html += token_display
 
145
 
146
  # Acceptance rate
147
  if total_drafted > 0:
148
+ html += f"<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
149
  html += f"<b>Acceptance Rate:</b> {total_accepted}/{total_drafted} = {total_accepted/total_drafted*100:.1f}%"
150
  html += "</div>"
151
 
 
171
  return html
172
 
173
  while result.shape[-1] - inputs["input_ids"].shape[-1] < max_tokens:
174
+ # Draft
175
  drafted, drafted_probs, draft_kv = draft(result, gamma, confidence_threshold, eos_token, draft_kv)
176
  drafted_token_ids = drafted[0, -len(drafted_probs):].tolist()
177
  drafted_tokens = [tokenizer.decode([t]) for t in drafted_token_ids]
178
 
179
+ clean_output_tokens.extend(drafted_token_ids)
180
+ all_tokens.extend(drafted_token_ids)
181
+ all_token_metadata.extend(['accepted'] * len(drafted_token_ids))
 
182
 
 
183
  temp_step = {
184
  "drafted": drafted_tokens,
185
  "accepted": len(drafted_tokens),
 
188
  steps.append(temp_step)
189
  total_drafted += len(drafted_probs)
190
 
 
191
  yield build_html()
192
 
193
+ # Verify
194
  accepted_tokens, num_accepted, verify_kv = verify(drafted, drafted_probs, eos_token, verify_kv)
195
  total_accepted += num_accepted
196
 
197
+ clean_output_tokens = clean_output_tokens[:-len(drafted_token_ids)]
198
+ all_token_metadata = all_token_metadata[:-len(drafted_token_ids)]
 
 
199
 
200
+ for i, token_id in enumerate(drafted_token_ids):
 
 
201
  if i < num_accepted:
202
+ all_token_metadata.append('accepted')
 
203
  else:
204
+ all_token_metadata.append('rejected')
205
+
206
+ clean_output_tokens.extend(accepted_tokens)
207
+
208
+ if num_accepted < len(accepted_tokens):
209
+ all_tokens.append(accepted_tokens[-1])
210
+ all_token_metadata.append('resampled')
211
 
 
212
  steps[-1] = {
213
  "drafted": drafted_tokens,
214
  "accepted": num_accepted,
215
  "resampled": tokenizer.decode([accepted_tokens[-1]]) if num_accepted < len(accepted_tokens) else None
216
  }
217
 
 
218
  yield build_html()
219
 
220
  valid_len = result.shape[-1] + num_accepted
 
228
  if eos_token in accepted_tokens or im_end_token in accepted_tokens:
229
  break
230
 
 
231
  yield build_html()
232
 
233
  demo = gr.Interface(
 
254
  **Watch the tokens stream in real-time!** Draft tokens appear immediately, then get accepted or rejected by the verify model.
255
  """,
256
  examples=[
257
+ ["What is the capital of France?", 80, 15, 0.5],
258
+ ["Complete the python function \n def fibonacci(n):", 50, 15, 0.5],
259
  ["Explain the concept of attention in transformers", 60, 10, 0.6]
260
  ]
261
  )