moevis committed on
Commit e8b5c5e · verified · 1 Parent(s): 6b94939

Update app.py

Files changed (1)
  1. app.py +132 -23
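
This commit adds a per-conversation audio budget to the demo: each session may send at most 10MB of audio, measured after conversion to WAV, tracked in a gr.State value, and surfaced to the user as a live usage line. It also makes the debug log print human-readable base64 payload sizes and adds timing logs around the vLLM API call.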
app.py CHANGED
@@ -7,6 +7,7 @@ import base64
 import json
 import os
 import io
+import time
 from pydub import AudioSegment
 import re
 
@@ -17,6 +18,46 @@ API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:9999/v1")
 MODEL_NAME = os.getenv("MODEL_NAME", "Step-Audio-R1")
 SECRET = os.getenv("API_SECRET", "")
 
+# Audio size limit (10MB)
+MAX_AUDIO_SIZE_MB = 10
+MAX_AUDIO_SIZE_BYTES = MAX_AUDIO_SIZE_MB * 1024 * 1024
+
+def get_wav_size(audio_path):
+    """Calculate the size of audio after converting to wav (in bytes)"""
+    if not audio_path or not os.path.exists(audio_path):
+        return 0
+    try:
+        audio = AudioSegment.from_file(audio_path)
+        buffer = io.BytesIO()
+        audio.export(buffer, format="wav")
+        return len(buffer.getvalue())
+    except Exception as e:
+        print(f"[ERROR] Failed to calculate wav size: {e}")
+        return 0
+
+def get_audio_size_info(used_size_bytes, current_audio_path=None):
+    """Get audio size usage info message"""
+    current_size = 0
+    if current_audio_path and os.path.exists(current_audio_path):
+        current_size = get_wav_size(current_audio_path)
+
+    remaining = MAX_AUDIO_SIZE_BYTES - used_size_bytes
+
+    used_mb = used_size_bytes / (1024 * 1024)
+    remaining_mb = remaining / (1024 * 1024)
+    current_mb = current_size / (1024 * 1024)
+
+    if used_size_bytes == 0 and current_size == 0:
+        return f"📊 Audio limit: {MAX_AUDIO_SIZE_MB}MB total available"
+    elif current_size > 0:
+        new_remaining = remaining - current_size
+        new_remaining_mb = new_remaining / (1024 * 1024)
+        if new_remaining < 0:
+            return f"📊 ⚠️ Current audio ({current_mb:.2f}MB) exceeds remaining limit ({remaining_mb:.2f}MB)"
+        return f"📊 Audio: {used_mb:.2f}MB used + {current_mb:.2f}MB pending = {new_remaining_mb:.2f}MB remaining"
+    else:
+        return f"📊 Audio limit: {used_mb:.2f}MB used, {remaining_mb:.2f}MB remaining (max {MAX_AUDIO_SIZE_MB}MB)"
+
 def escape_html(text):
     """Escape HTML special characters to prevent XSS"""
     if not isinstance(text, str):
@@ -200,16 +241,35 @@ def format_messages(system, history, user_text, audio_data_list=None):
 
     return messages
 
-def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature, top_p, show_thinking=True, model_name=None):
+def chat(system_prompt, user_text, audio_file, history, used_audio_size, max_tokens, temperature, top_p, show_thinking=True, model_name=None):
     """Chat function"""
     # If model is not specified, use global configuration
     if model_name is None:
         model_name = MODEL_NAME
 
+    # Initialize the used audio size
+    if used_audio_size is None:
+        used_audio_size = 0
+
     if not user_text and not audio_file:
-        yield history or []
+        yield history or [], used_audio_size, get_audio_size_info(used_audio_size, None)
         return
 
+    # Check the audio size limit
+    current_audio_size = 0
+    if audio_file:
+        current_audio_size = get_wav_size(audio_file)
+        total_size = used_audio_size + current_audio_size
+
+        if total_size > MAX_AUDIO_SIZE_BYTES:
+            history = history or []
+            remaining_mb = (MAX_AUDIO_SIZE_BYTES - used_audio_size) / (1024 * 1024)
+            current_mb = current_audio_size / (1024 * 1024)
+            error_msg = f"❌ Audio size limit exceeded! Current audio is {current_mb:.2f}MB, but only {max(0, remaining_mb):.2f}MB remaining (max {MAX_AUDIO_SIZE_MB}MB)"
+            history.append({"role": "assistant", "content": error_msg})
+            yield history, used_audio_size, get_audio_size_info(used_audio_size, None)
+            return
+
     # Ensure history is a list and formatted correctly
     history = history or []
     clean_history = []
@@ -228,7 +288,7 @@ def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature,
 
     messages = format_messages(system_prompt, history, user_text, audio_data_list)
     if not messages:
-        yield history or []
+        yield history or [], used_audio_size, get_audio_size_info(used_audio_size, None)
         return
 
     # Debug: Print message format
@@ -242,7 +302,14 @@ def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature,
             if "input_audio" in item_copy:
                 audio_info = item_copy["input_audio"].copy()
                 if "data" in audio_info:
-                    audio_info["data"] = f"[BASE64_AUDIO_DATA_LEN_{len(audio_info['data'])}]"
+                    data_len = len(audio_info['data'])
+                    if data_len >= 1024 * 1024:
+                        human_size = f"{data_len / (1024 * 1024):.2f} MB"
+                    elif data_len >= 1024:
+                        human_size = f"{data_len / 1024:.2f} KB"
+                    else:
+                        human_size = f"{data_len} B"
+                    audio_info["data"] = f"[BASE64_AUDIO_DATA: {human_size} ({data_len} bytes)]"
                 item_copy["input_audio"] = audio_info
                 new_content.append(item_copy)
             else:
@@ -268,6 +335,9 @@ def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature,
         # Audio only
         history.append({"role": "user", "content": gr.Audio(audio_file)})
 
+    # Update the used audio size
+    new_used_audio_size = used_audio_size + current_audio_size
+
     # Add thinking placeholder
     if show_thinking:
         history.append({
@@ -279,16 +349,19 @@ def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature,
                 '</div>'
             )
         })
-        yield history
+        yield history, new_used_audio_size, get_audio_size_info(new_used_audio_size, None)
     else:
        history.append({
             "role": "assistant",
             "content": "⏳ Generating response..."
         })
-        yield history
+        yield history, new_used_audio_size, get_audio_size_info(new_used_audio_size, None)
 
     try:
         # Disable proxy to access the internal API
+        start_time = time.time()
+        print(f"[API] Starting request to {API_BASE_URL}/chat/completions ...")
+
         with httpx.Client(base_url=API_BASE_URL, timeout=120) as client:
             response = client.post("/chat/completions", json={
                 "model": model_name,
@@ -304,6 +377,8 @@ def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature,
             })
 
             if response.status_code != 200:
+                elapsed_time = time.time() - start_time
+                print(f"[API] ❌ FAILED - Status: {response.status_code}, Time: {elapsed_time:.2f}s")
                 error_msg = f"❌ API Error {response.status_code}"
                 if response.status_code == 404:
                     error_msg += " - vLLM service not ready"
@@ -313,7 +388,7 @@ def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature,
                     error_msg += f" - Model error ({response.text})"
                 # Update the last message with error
                 history[-1]["content"] = error_msg
-                yield history
+                yield history, new_used_audio_size, get_audio_size_info(new_used_audio_size, None)
                 return
 
             # Process streaming response
@@ -340,17 +415,17 @@ def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature,
                         if 'content' in delta:
                             content = delta['content']
                             buffer += content
-
+
                             if is_thinking:
                                 if "</think>" in buffer:
                                     is_thinking = False
                                     parts = buffer.split("</think>", 1)
                                     think_content = parts[0]
                                     response_content = parts[1]
-
+
                                     if think_content.startswith("<think>"):
                                         think_content = think_content[len("<think>"):].strip()
-
+
                                     if show_thinking:
                                         # Format thinking with custom styled block (escape HTML for safety)
                                         escaped_think = escape_html(think_content)
@@ -384,10 +459,10 @@ def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature,
                                 parts = buffer.split("</think>", 1)
                                 think_content = parts[0]
                                 response_content = parts[1]
-
+
                                 if think_content.startswith("<think>"):
                                     think_content = think_content[len("<think>"):].strip()
-
+
                                 if show_thinking:
                                     # Update with formatted thinking + response
                                     escaped_think = escape_html(think_content)
@@ -402,18 +477,26 @@ def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature,
                                 else:
                                     # Only show response
                                     history[-1]["content"] = response_content
-
-                            yield history
-
+
+                            yield history, new_used_audio_size, get_audio_size_info(new_used_audio_size, None)
+
                     except json.JSONDecodeError:
                         continue
 
+            # Request completed successfully
+            elapsed_time = time.time() - start_time
+            print(f"[API] ✅ SUCCESS - Time: {elapsed_time:.2f}s")
+
     except httpx.ConnectError:
+        elapsed_time = time.time() - start_time
+        print(f"[API] ❌ FAILED - Connection error, Time: {elapsed_time:.2f}s")
         history[-1]["content"] = "❌ Cannot connect to vLLM API"
-        yield history
+        yield history, new_used_audio_size, get_audio_size_info(new_used_audio_size, None)
     except Exception as e:
+        elapsed_time = time.time() - start_time
+        print(f"[API] ❌ FAILED - Error: {str(e)}, Time: {elapsed_time:.2f}s")
        history[-1]["content"] = f"❌ Error: {str(e)}"
-        yield history
+        yield history, new_used_audio_size, get_audio_size_info(new_used_audio_size, None)
 
 # Custom CSS for better UI
 custom_css = """
@@ -571,7 +654,6 @@ h3 {
     margin-top: 1rem;
     gap: 0.5rem;
 }
-
 /* Dark Mode Support */
 .dark .message.bot {
     background: #1f2937 !important;
@@ -604,7 +686,6 @@ h3 {
 .dark h3 {
     color: #e5e7eb;
 }
-
 /* Scrollbar styling */
 ::-webkit-scrollbar {
     width: 8px;
@@ -692,6 +773,12 @@ with gr.Blocks(title="Step Audio R1", css=custom_css, theme=gr.themes.Soft()) as
                 show_label=True
             )
 
+            # Audio size limit info
+            audio_size_info = gr.Markdown(
+                value=f"📊 Audio limit: {MAX_AUDIO_SIZE_MB}MB total available",
+                elem_classes=["audio-size-info"]
+            )
+
             # Buttons
             with gr.Row():
                 clear_btn = gr.Button("🗑️ Clear", scale=1, size="lg")
@@ -729,15 +816,37 @@ def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature,
                 bubble_full_width=False
             )
 
+    # State to track used audio size (in bytes)
+    used_audio_size = gr.State(value=0)
+
     submit_btn.click(
         fn=chat,
-        inputs=[system_prompt, user_text, audio_file, chatbot, max_tokens, temperature, top_p, show_thinking],
-        outputs=[chatbot]
+        inputs=[system_prompt, user_text, audio_file, chatbot, used_audio_size, max_tokens, temperature, top_p, show_thinking],
+        outputs=[chatbot, used_audio_size, audio_size_info]
     )
 
     clear_btn.click(
-        fn=lambda: ([], "", None),
-        outputs=[chatbot, user_text, audio_file]
+        fn=lambda: ([], 0, "", None, f"📊 Audio limit: {MAX_AUDIO_SIZE_MB}MB total available"),
+        outputs=[chatbot, used_audio_size, user_text, audio_file, audio_size_info]
+    )
+
+    # Update the audio size info when the audio file changes
+    audio_file.change(
+        fn=lambda audio, used_size: get_audio_size_info(used_size, audio),
+        inputs=[audio_file, used_audio_size],
+        outputs=[audio_size_info]
+    )
+
+    # Also listen to upload and stop_recording events
+    audio_file.upload(
+        fn=lambda audio, used_size: get_audio_size_info(used_size, audio),
+        inputs=[audio_file, used_audio_size],
+        outputs=[audio_size_info]
    )
+    audio_file.stop_recording(
+        fn=lambda audio, used_size: get_audio_size_info(used_size, audio),
+        inputs=[audio_file, used_audio_size],
+        outputs=[audio_size_info]
     )
 
 if __name__ == "__main__":
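Why the limit is measured after WAV conversion rather than on the uploaded file: get_wav_size decodes the clip with pydub and re-exports it as WAV, which is what actually gets base64-encoded and sent to the API. A minimal sketch of the same measurement, assuming only pydub and the standard library (the sample path is hypothetical):

import io
from pydub import AudioSegment

def wav_size(path):
    # Decode whatever container the user uploaded, then re-encode as WAV
    # and measure the encoded bytes - mirrors get_wav_size in this commit.
    audio = AudioSegment.from_file(path)
    buf = io.BytesIO()
    audio.export(buf, format="wav")
    return len(buf.getvalue())

print(wav_size("clip.mp3"))  # hypothetical file; the PCM WAV is typically much larger than the MP3

A compressed upload can expand several-fold as PCM, so checking os.path.getsize on the original file would under-count against the 10MB budget.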
 
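chat() now threads a running byte count through the conversation and rejects a turn before calling the API when the new clip would push the total past MAX_AUDIO_SIZE_BYTES. Condensed to a pure function (check_fits is a hypothetical name; the commit inlines this logic):

MAX_AUDIO_SIZE_BYTES = 10 * 1024 * 1024

def check_fits(used_bytes, new_bytes):
    # Mirrors the guard in chat(): a turn is rejected only when the
    # running total would exceed the cap (total_size > MAX_AUDIO_SIZE_BYTES).
    return used_bytes + new_bytes <= MAX_AUDIO_SIZE_BYTES

assert check_fits(0, 4 * 1024 * 1024)                    # first 4MB clip fits
assert not check_fits(8 * 1024 * 1024, 3 * 1024 * 1024)  # 8MB used + 3MB is rejected

Note that the counter only advances (new_used_audio_size = used_audio_size + current_audio_size) once the turn is accepted, so a rejected upload does not consume budget.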
 
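The debug log no longer prints a bare byte count for base64 audio payloads; it formats the length as B/KB/MB. Factored out of the inlined branch, the formatting rule is (format_bytes is a hypothetical helper name):

def format_bytes(n):
    # Same thresholds as the commit's inlined debug formatting.
    if n >= 1024 * 1024:
        return f"{n / (1024 * 1024):.2f} MB"
    elif n >= 1024:
        return f"{n / 1024:.2f} KB"
    return f"{n} B"

print(format_bytes(512))      # 512 B
print(format_bytes(2048))     # 2.00 KB
print(format_bytes(3145728))  # 3.00 MB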
 
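Every yield in chat() grows from one value to three: the chat history, the updated byte counter, and the human-readable usage string. This matches the new outputs=[chatbot, used_audio_size, audio_size_info] wiring; a Gradio generator must yield one value per declared output on every yield, including early exits and error paths, which is why the commit touches every yield site. A stripped-down sketch of that contract (names are illustrative, not the app's):

def chat_sketch(history, used_audio_size):
    history = (history or []) + [{"role": "assistant", "content": "⏳ ..."}]
    new_used = used_audio_size  # would grow by the accepted clip's WAV size
    info = f"{new_used / (1024 * 1024):.2f}MB used"
    yield history, new_used, info  # placeholder frame: all three outputs
    history[-1]["content"] = "done"
    yield history, new_used, info  # final frame: still all three outputs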
 
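The streaming handler buffers tokens and, once </think> appears, splits the buffer into hidden reasoning and the visible answer, stripping a leading <think> tag. The core split, isolated from the streaming loop (the sample buffer is made up):

buffer = "<think>inspect the audio first</think>The clip says hello."
if "</think>" in buffer:
    think_content, response_content = buffer.split("</think>", 1)
    if think_content.startswith("<think>"):
        think_content = think_content[len("<think>"):].strip()
print(think_content)     # inspect the audio first
print(response_content)  # The clip says hello.

split("</think>", 1) caps the split at one, so a literal tag appearing later in the answer text is left intact.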
 
 
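The running byte count lives in gr.State, so it persists across turns within one session without a module-level global, and it is both an input to and an output of the chat handler. A minimal, self-contained sketch of the same round-trip pattern, assuming standard Gradio APIs (component names are illustrative, not the app's):

import gradio as gr

MB = 1024 * 1024

with gr.Blocks() as demo:
    used = gr.State(value=0)        # like used_audio_size = gr.State(value=0)
    info = gr.Markdown("0.00MB used")
    btn = gr.Button("Consume 1MB")

    def bump(used_bytes):
        # State arrives as a plain Python value and goes back out updated.
        used_bytes += MB
        return used_bytes, f"{used_bytes / MB:.2f}MB used"

    btn.click(fn=bump, inputs=[used], outputs=[used, info])

if __name__ == "__main__":
    demo.launch()

In the commit, the clear handler resets this state to 0 alongside the chatbot, and the change/upload/stop_recording listeners recompute the preview string whenever the pending clip changes.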