mereith committed on
Commit
a1207af
·
1 Parent(s): 8a9e633

feat: stop chat

Files changed (4)
  1. .gitignore +1 -0
  2. README.md +3 -150
  3. api.py +12 -0
  4. app.py +64 -17
.gitignore CHANGED
@@ -1 +1,2 @@
1
  venv
 
 
1
  venv
2
+ local-dev.sh
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: ODR Demo - Async SSE Chat Interface
3
  emoji: 🤖
4
  colorFrom: blue
5
  colorTo: purple
@@ -8,7 +8,7 @@ sdk_version: 5.43.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: A demo project with asynchronous SSE (Server-Sent Events) API requests, supporting a real-time streaming chat experience
12
  tags:
13
  - chatbot
14
  - sse
@@ -19,151 +19,4 @@ tags:
19
  suggested_hardware: cpu-basic
20
  hf_oauth: false
21
  disable_embedding: false
22
- ---
23
-
24
- # ODR Demo
25
-
26
- A demo project with asynchronous SSE (Server-Sent Events) API requests, supporting a real-time streaming chat experience.
27
-
28
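- ## 🌟 Features
29
-
30
- - Asynchronous SSE streaming requests
31
- - Custom parameters (deep thinking mode, debug mode, etc.)
32
- - Returns an async iterator for real-time data processing
33
- - Supports both raw data and parsed JSON data modes
34
- - Structured event parsing with `event` and `data` fields
35
- - Correct SSE event type detection and handling
36
- - Real-time Markdown rendering and collapsible tool-call display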
37
-
38
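- ## 🚀 Quick Start
39
-
40
- This app is deployed on Hugging Face Spaces, so you can use the web interface directly:
41
-
42
- 1. Enter your question in English in the input box
43
- 2. Click the "Run" button to start processing
44
- 3. Watch the agent and tool calls as they happen
45
- 4. Click the "Stop" button at any time to stop processing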
46
-
47
- ## 📦 Local Installation
48
-
49
- To run this project locally:
50
-
51
- ```bash
52
- # Clone the project
53
- git clone <repository-url>
54
- cd odr-demo
55
-
56
- # Create a virtual environment
57
- python -m venv venv
58
- source venv/bin/activate # On Windows: venv\Scripts\activate
59
-
60
- # Install dependencies
61
- pip install gradio aiohttp
62
-
63
- # Set environment variables
64
- export API_ENDPOINT="your-api-endpoint-here"
65
-
66
- # Run the app
67
- python app.py
68
- ```
69
-
70
- ## 💻 Basic Usage
71
-
72
- ### Raw SSE Events
73
-
74
- ```python
75
- import asyncio
76
- from api import request_sse_stream
77
-
78
- async def main():
79
- query = "Hello, please introduce Python"
80
-
81
- async for event_data in request_sse_stream(query):
82
- event_type = event_data.get('event', 'unknown')
83
- data_content = event_data.get('data', '')
84
- print(f"Event: {event_type}")
85
- print(f"Data: {data_content}")
86
-
87
- asyncio.run(main())
88
- ```
89
-
90
- ### Parsed Data
91
-
92
- ```python
93
- import asyncio
94
- from api import request_sse_stream_parsed
95
-
96
- async def main():
97
- query = "What is machine learning?"
98
-
99
- async for event_data in request_sse_stream_parsed(query):
100
- event_type = event_data.get('event', 'unknown')
101
- parsed_data = event_data.get('data', {})
102
- print(f"Event: {event_type}")
103
- print(f"Parsed Data: {parsed_data}")
104
-
105
- asyncio.run(main())
106
- ```
107
-
108
- ### Using the Class Methods (More Control)
109
-
110
- ```python
111
- import asyncio
112
- from api import SSEClient
113
-
114
- async def main():
115
- client = SSEClient()
116
-
117
- async for event_data in client.stream_chat(
118
- query="Explain deep learning",
119
- deep_thinking_mode=True, # enable deep thinking
120
- debug=True, # enable debug mode
121
- chat_id="my_custom_id" # custom chat ID
122
- ):
123
- event_type = event_data.get('event', 'unknown')
124
- data_content = event_data.get('data', '')
125
- print(f"Event: {event_type}")
126
- print(f"Data: {data_content}")
127
-
128
- asyncio.run(main())
129
- ```
130
-
131
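- ## ⚙️ API Parameters
132
-
133
- - `query`: required; the user's query
134
- - `deep_thinking_mode`: optional; whether to enable deep thinking mode (default: False)
135
- - `search_before_planning`: optional; whether to search before planning (default: False)
136
- - `debug`: optional; whether to enable debug mode (default: False)
137
- - `chat_id`: optional; chat ID, generated automatically if not provided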
138
-
139
- ## 📊 Data Structure
140
-
141
- All functions return an async iterator that yields dictionaries with the following structure:
142
-
143
- ```python
144
- {
145
- "event": "message", # SSE event type (e.g. "message", "error", "data")
146
- "data": "..." # event data content
147
- }
148
- ```
149
-
150
- - `request_sse_stream()`: returns raw data; the `data` field contains the raw string
151
- - `request_sse_stream_parsed()`: returns parsed data; the `data` field contains a JSON object (when possible)
152
-
153
- ## 📁 Files
154
-
155
- - `api.py`: the main SSE client implementation
156
- - `app.py`: the Gradio web application
157
- - `utils.py`: utility functions
158
- - `README.md`: project documentation
159
-
160
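- ## 🎯 Notes
161
-
162
- - Only English input is supported for now
163
- - A valid `API_ENDPOINT` environment variable must be set
164
- - The live display includes agent processing steps and tool-call details
165
- - Supports error handling and interruption
166
-
167
- ## 📄 License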
168
-
169
- MIT License
 
1
  ---
2
+ title: ODR Demo
3
  emoji: 🤖
4
  colorFrom: blue
5
  colorTo: purple
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: a demo
12
  tags:
13
  - chatbot
14
  - sse
 
19
  suggested_hardware: cpu-basic
20
  hf_oauth: false
21
  disable_embedding: false
22
+ ---
api.py CHANGED
@@ -4,6 +4,9 @@ import os
4
  import uuid
5
  from typing import AsyncIterator, Dict, Any
6
  import aiohttp
 
 
 
7
 
8
 
9
  class SSEClient:
@@ -164,6 +167,15 @@ async def request_sse_stream_parsed(query: str, **kwargs) -> AsyncIterator[Dict[
164
  yield event_data
165
 
166
 
 
 
 
 
 
 
 
 
 
167
  # Example usage
168
  async def main():
169
  """Example usage method"""
 
4
  import uuid
5
  from typing import AsyncIterator, Dict, Any
6
  import aiohttp
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
 
11
 
12
  class SSEClient:
 
167
  yield event_data
168
 
169
 
170
+ async def stop_chat(chat_id: str):
171
+ url = f"{os.getenv('STOP_CHAT_API_ENDPOINT')}"
172
+ async with aiohttp.ClientSession() as session:
173
+ async with session.post(url, json={"chatId": chat_id}) as response:
174
+ if response.status != 200:
175
+ logger.error(f"Request failed with status code: {response.status}")
176
+ raise Exception(f"Request failed with status code: {response.status}")
177
+ return await response.json()
178
+
179
  # Example usage
180
  async def main():
181
  """Example usage method"""
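
Review note on `stop_chat`: the URL is built with `f"{os.getenv('STOP_CHAT_API_ENDPOINT')}"`, so an unset variable would produce the literal string `"None"` rather than an early failure. A minimal sketch of one way to validate the inputs before the request is sent is shown below. The helper name `build_stop_request` and its `env` parameter are hypothetical and not part of this commit; only the `STOP_CHAT_API_ENDPOINT` variable and the `{"chatId": ...}` payload shape come from the diff.

```python
import os
from typing import Dict, Mapping, Optional, Tuple


def build_stop_request(
    chat_id: str, env: Optional[Mapping[str, str]] = None
) -> Tuple[str, Dict[str, str]]:
    """Return (url, json_payload) for the stop call, or raise if misconfigured.

    Hypothetical helper: it only prepares the request and does not send it.
    """
    # Allow tests to pass a fake environment; default to the real one.
    env = os.environ if env is None else env
    url = env.get("STOP_CHAT_API_ENDPOINT", "").strip()
    if not url:
        raise RuntimeError("STOP_CHAT_API_ENDPOINT is not set")
    if not chat_id:
        raise ValueError("chat_id must be a non-empty string")
    # Same payload shape as the diff's session.post(url, json={"chatId": chat_id}).
    return url, {"chatId": chat_id}
```

With a helper like this, `stop_chat` could call `url, payload = build_stop_request(chat_id)` and pass both to `session.post(url, json=payload)`, so a missing endpoint surfaces as a clear error instead of a failed request to `"None"`.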
app.py CHANGED
@@ -1,15 +1,18 @@
1
  import json
2
  import logging
3
  import uuid
 
4
  from typing import Optional
5
  import gradio as gr
6
 
7
- from api import request_sse_stream_parsed
8
 
9
  from utils import contains_chinese, replace_chinese_punctuation
10
 
11
  logger = logging.getLogger(__name__)
12
 
 
 
13
 
14
  from typing import Optional
15
 
@@ -89,10 +92,14 @@ def _render_markdown(state: dict) -> str:
89
  lines.append("\n---\n")
90
  return "\n".join(lines) if lines else "Waiting..."
91
 
92
- def _update_state_with_event(state: dict, message: dict):
93
  event = message.get("event")
94
  data = message.get("data", {})
95
- if event == "start_of_agent":
 
 
 
 
96
  agent_id = data.get("agent_id")
97
  agent_name = data.get("agent_name", "unknown")
98
  if agent_id and agent_id not in state["agents"]:
@@ -189,15 +196,17 @@ async def gradio_run(query: str, ui_state: Optional[dict]):
189
  "we only support English input for the time being.",
190
  gr.update(interactive=True),
191
  gr.update(interactive=False),
192
- ui_state or {"task_id": None}
193
  )
194
  return
195
- task_id = str(uuid.uuid4())
196
- if not ui_state:
197
- ui_state = {"task_id": task_id}
198
- else:
199
- ui_state = {**ui_state, "task_id": task_id}
200
  state = _init_render_state()
 
 
 
 
 
 
201
  # Initial: disable Run, enable Stop, and show spinner at bottom of text
202
  yield (
203
  _render_markdown(state) + _spinner_markup(True),
@@ -205,15 +214,36 @@ async def gradio_run(query: str, ui_state: Optional[dict]):
205
  gr.update(interactive=True),
206
  ui_state
207
  )
208
- async for message in request_sse_stream_parsed(query):
209
- state = _update_state_with_event(state, message)
210
- md = _render_markdown(state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  yield (
212
- md + _spinner_markup(True),
213
- gr.update(interactive=False),
214
  gr.update(interactive=True),
 
215
  ui_state
216
  )
 
 
 
 
 
217
  # End: enable Run, disable Stop, remove spinner
218
  yield (
219
  _render_markdown(state),
@@ -222,8 +252,25 @@ async def gradio_run(query: str, ui_state: Optional[dict]):
222
  ui_state
223
  )
224
 
225
- def stop_current(ui_state: Optional[dict]):
226
- # Immediately switch button availability: enable Run, disable Stop
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  return (
228
  gr.update(interactive=True),
229
  gr.update(interactive=False),
@@ -244,7 +291,7 @@ def build_demo():
244
  run_btn = gr.Button("Run")
245
  stop_btn = gr.Button("Stop", variant="stop", interactive=False)
246
  out_md = gr.Markdown("", elem_id="log-view")
247
- ui_state = gr.State({"task_id": None})
248
  # run: outputs -> markdown, run_btn(update), stop_btn(update), ui_state
249
  run_btn.click(fn=gradio_run, inputs=[inp, ui_state], outputs=[out_md, run_btn, stop_btn, ui_state])
250
  # stop: outputs -> run_btn(update), stop_btn(update)
 
1
  import json
2
  import logging
3
  import uuid
4
+ import asyncio
5
  from typing import Optional
6
  import gradio as gr
7
 
8
+ from api import request_sse_stream_parsed, stop_chat
9
 
10
  from utils import contains_chinese, replace_chinese_punctuation
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
+ running_tasks = {}
15
+
16
 
17
  from typing import Optional
18
 
 
92
  lines.append("\n---\n")
93
  return "\n".join(lines) if lines else "Waiting..."
94
 
95
+ def _update_state_with_event(ui_state: dict, state: dict, message: dict):
96
  event = message.get("event")
97
  data = message.get("data", {})
98
+ if event == "job_started":
99
+ chat_id = data.get("chat_id")
100
+ state["chat_id"] = chat_id
101
+ ui_state["chat_id"] = chat_id
102
+ elif event == "start_of_agent":
103
  agent_id = data.get("agent_id")
104
  agent_name = data.get("agent_name", "unknown")
105
  if agent_id and agent_id not in state["agents"]:
 
196
  "we only support English input for the time being.",
197
  gr.update(interactive=True),
198
  gr.update(interactive=False),
199
+ ui_state or {"chat_id": None}
200
  )
201
  return
202
+
 
 
 
 
203
  state = _init_render_state()
204
+
205
+ task_id = str(uuid.uuid4())
206
+ if ui_state is None:
207
+ ui_state = {"chat_id": None}
208
+ ui_state["task_id"] = task_id
209
+
210
  # Initial: disable Run, enable Stop, and show spinner at bottom of text
211
  yield (
212
  _render_markdown(state) + _spinner_markup(True),
 
214
  gr.update(interactive=True),
215
  ui_state
216
  )
217
+
218
+ try:
219
+ current_task = asyncio.current_task()
220
+ running_tasks[task_id] = current_task
221
+
222
+ async for message in request_sse_stream_parsed(query):
223
+ if current_task.cancelled():
224
+ break
225
+
226
+ state = _update_state_with_event(ui_state, state, message)
227
+ md = _render_markdown(state)
228
+ yield (
229
+ md + _spinner_markup(True),
230
+ gr.update(interactive=False),
231
+ gr.update(interactive=True),
232
+ ui_state
233
+ )
234
+ except asyncio.CancelledError:
235
+ state.setdefault("errors", []).append("Task has been cancelled")
236
  yield (
237
+ _render_markdown(state),
 
238
  gr.update(interactive=True),
239
+ gr.update(interactive=False),
240
  ui_state
241
  )
242
+ return
243
+ finally:
244
+ if task_id in running_tasks:
245
+ del running_tasks[task_id]
246
+
247
  # End: enable Run, disable Stop, remove spinner
248
  yield (
249
  _render_markdown(state),
 
252
  ui_state
253
  )
254
 
255
+ async def stop_current(ui_state: Optional[dict]):
256
+ if ui_state is None:
257
+ ui_state = {}
258
+
259
+ task_id = ui_state.get("task_id")
260
+ if task_id and task_id in running_tasks:
261
+ task = running_tasks[task_id]
262
+ if task and not task.done():
263
+ task.cancel()
264
+ logger.info(f"Task has been cancelled: {task_id}")
265
+
266
+ chat_id = ui_state.get("chat_id")
267
+ if chat_id:
268
+ try:
269
+ res = await stop_chat(chat_id)
270
+ logger.info(f"Chat has been stopped: {chat_id}, res: {res}")
271
+ except Exception as e:
272
+ logger.error(f"Stop chat API call failed: {e}")
273
+
274
  return (
275
  gr.update(interactive=True),
276
  gr.update(interactive=False),
 
291
  run_btn = gr.Button("Run")
292
  stop_btn = gr.Button("Stop", variant="stop", interactive=False)
293
  out_md = gr.Markdown("", elem_id="log-view")
294
+ ui_state = gr.State({"chat_id": None})
295
  # run: outputs -> markdown, run_btn(update), stop_btn(update), ui_state
296
  run_btn.click(fn=gradio_run, inputs=[inp, ui_state], outputs=[out_md, run_btn, stop_btn, ui_state])
297
  # stop: outputs -> run_btn(update), stop_btn(update)
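
Review note on the cancellation flow in `app.py`: the commit registers `asyncio.current_task()` in `running_tasks` and has `stop_current` call `task.cancel()`. In asyncio, `Task.cancelled()` only becomes true once the task has finished with a `CancelledError`, so the in-loop `current_task.cancelled()` check probably never fires; the `except asyncio.CancelledError` branch is what actually handles the stop. The sketch below isolates that registry pattern in a plain coroutine rather than the Gradio async generator from the diff. `stream_events`, `stop`, and `demo` are hypothetical stand-ins, and the `asyncio.sleep` call stands in for waiting on the SSE stream.

```python
import asyncio
import uuid

# Registry of in-flight tasks keyed by task_id, as in the diff.
running_tasks: dict = {}


async def stream_events(task_id: str, log: list) -> None:
    """Stand-in for the streaming loop: registers itself, then consumes events."""
    running_tasks[task_id] = asyncio.current_task()
    try:
        for i in range(1000):
            # Cancellation is delivered as CancelledError at an await point.
            await asyncio.sleep(0.01)
            log.append(i)
    except asyncio.CancelledError:
        log.append("cancelled")
        raise  # re-raise so the task is actually marked as cancelled
    finally:
        # Always unregister, whether the task finished or was cancelled.
        running_tasks.pop(task_id, None)


def stop(task_id: str) -> bool:
    """Request cancellation of a registered task; return True if one was running."""
    task = running_tasks.get(task_id)
    if task is not None and not task.done():
        task.cancel()
        return True
    return False


async def demo() -> list:
    log: list = []
    task_id = str(uuid.uuid4())
    task = asyncio.create_task(stream_events(task_id, log))
    await asyncio.sleep(0.05)  # let a few events arrive
    assert stop(task_id)
    try:
        await task
    except asyncio.CancelledError:
        pass  # expected: the streaming task was cancelled
    return log
```

Running `asyncio.run(demo())` returns a log whose last entry is `"cancelled"`, and the registry is empty afterwards because the `finally` block always unregisters the task.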