Commit 1620ce5
Parent: 55db78d
Upload 3 files

Files changed:
- app.py +19 -18
- presets.py +11 -2
- utils.py +65 -49
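In brief, as the diffs below show, this commit: (1) adds a model selection dropdown (选择模型, "select model") backed by a new MODELS list in presets.py and threads the chosen model through get_response(), stream_predict(), predict_all(), predict(), retry(), and reduce_token_size(); (2) splits the old combined error string into per-cause messages (connection timeout, read timeout, proxy error, SSL error) plus a warning for API keys that are not 51 characters long; (3) reworks token bookkeeping around a per-turn all_token_counts list and validates the key length before sending a request; and (4) rearranges the top of the Gradio UI, moving the title, API-key box, and streaming checkbox below the State declarations.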
app.py
CHANGED

@@ -42,14 +42,6 @@ else:
 gr.Chatbot.postprocess = postprocess
 
 with gr.Blocks(css=customCSS) as demo:
-    gr.HTML(title)
-    gr.HTML('''<center><a href="https://huggingface.co/spaces/JohnSmith9982/ChuanhuChatGPT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="复制 Space"></a>强烈建议点击上面的按钮复制一份这个Space,在你自己的Space里运行,响应更迅速、也更安全👆</center>''')
-    with gr.Row():
-        with gr.Column(scale=4):
-            keyTxt = gr.Textbox(show_label=False, placeholder=f"在这里输入你的OpenAI API-key...", value=my_api_key, type="password", visible=not HIDE_MY_KEY).style(container=True)
-        with gr.Column(scale=1):
-            use_streaming_checkbox = gr.Checkbox(label="实时传输回答", value=True, visible=enable_streaming_option)
-    chatbot = gr.Chatbot()  # .style(color_map=("#1D51EE", "#585A5B"))
     history = gr.State([])
     token_count = gr.State([])
     promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))

@@ -57,6 +49,15 @@ with gr.Blocks(css=customCSS) as demo:
     FALSECONSTANT = gr.State(False)
     topic = gr.State("未命名对话历史记录")
 
+    gr.HTML(title)
+    with gr.Row():
+        with gr.Column():
+            keyTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入你的OpenAI API-key...", value=my_api_key, type="password", visible=not HIDE_MY_KEY, label="API-Key")
+        with gr.Column():
+            with gr.Row():
+                model_select_dropdown = gr.Dropdown(label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0])
+                use_streaming_checkbox = gr.Checkbox(label="实时传输回答", value=True, visible=enable_streaming_option)
+    chatbot = gr.Chatbot()  # .style(color_map=("#1D51EE", "#585A5B"))
     with gr.Row():
         with gr.Column(scale=12):
             user_input = gr.Textbox(show_label=False, placeholder="在这里输入").style(

@@ -69,8 +70,9 @@ with gr.Blocks(css=customCSS) as demo:
     delLastBtn = gr.Button("🗑️ 删除最近一条对话")
     reduceTokenBtn = gr.Button("♻️ 总结对话")
     status_display = gr.Markdown("status: ready")
-
-
+
+    systemPromptTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入System Prompt...", label="System prompt", value=initial_prompt).style(container=True)
+
     with gr.Accordion(label="加载Prompt模板", open=False):
         with gr.Column():
             with gr.Row():

@@ -101,28 +103,27 @@ with gr.Blocks(css=customCSS) as demo:
     #inputs, top_p, temperature, top_k, repetition_penalty
     with gr.Accordion("参数", open=False):
         top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05,
-
+                          interactive=True, label="Top-p (nucleus sampling)",)
         temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0,
                                 step=0.1, interactive=True, label="Temperature",)
-
-    #repetition_penalty = gr.Slider( minimum=0.1, maximum=3.0, value=1.03, step=0.01, interactive=True, label="Repetition Penalty", )
+
     gr.Markdown(description)
 
 
-    user_input.submit(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
+    user_input.submit(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
     user_input.submit(reset_textbox, [], [user_input])
 
-    submitBtn.click(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
+    submitBtn.click(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
     submitBtn.click(reset_textbox, [], [user_input])
 
     emptyBtn.click(reset_state, outputs=[chatbot, history, token_count, status_display], show_progress=True)
 
-    retryBtn.click(retry, [keyTxt, systemPromptTxt, history, chatbot, token_count, top_p, temperature, use_streaming_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
+    retryBtn.click(retry, [keyTxt, systemPromptTxt, history, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
 
-    delLastBtn.click(delete_last_conversation, [chatbot, history, token_count
+    delLastBtn.click(delete_last_conversation, [chatbot, history, token_count], [
         chatbot, history, token_count, status_display], show_progress=True)
 
-    reduceTokenBtn.click(reduce_token_size, [keyTxt, systemPromptTxt, history, chatbot, token_count, top_p, temperature, use_streaming_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
+    reduceTokenBtn.click(reduce_token_size, [keyTxt, systemPromptTxt, history, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
 
     saveHistoryBtn.click(save_chat_history, [
         saveFileName, systemPromptTxt, history, chatbot], None, show_progress=True)
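The UI side of the feature is mostly this wiring: Gradio passes each component in an event's inputs list as one positional argument, so appending model_select_dropdown to the predict/retry/reduce_token_size handlers is what delivers the selected model name to the new selected_model parameter in utils.py. A minimal, self-contained sketch of the pattern, assuming the Gradio 3.x-era API used above (the stub function and labels are illustrative, not part of the commit):

    import gradio as gr

    MODELS = ["gpt-3.5-turbo", "gpt-3.5-turbo-0301"]

    def predict_stub(user_input, selected_model):
        # Stand-in for the real predict(); just echoes which model was chosen.
        return f"[{selected_model}] {user_input}"

    with gr.Blocks() as demo:
        # Each component listed in `inputs` below arrives as one positional argument.
        model_select_dropdown = gr.Dropdown(choices=MODELS, value=MODELS[0], label="选择模型")
        user_input = gr.Textbox(label="Input")
        output = gr.Textbox(label="Output")
        submitBtn = gr.Button("Submit")
        submitBtn.click(predict_stub, [user_input, model_select_dropdown], [output])

    demo.launch()

The same mechanism explains why every handler's inputs list in the diff grew by exactly one entry.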
presets.py
CHANGED

@@ -31,9 +31,18 @@ pre code {
 }
 """
 
-standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
-error_retrieve_prompt = "连接超时,无法获取对话。请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
 summarize_prompt = "请总结以上对话,不超过100字。" # 总结对话时的 prompt
+MODELS = ["gpt-3.5-turbo", "gpt-3.5-turbo-0301"] # 可选的模型
+
+# 错误信息
+standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
+error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
+connection_timeout_prompt = "连接超时,无法获取对话。" # 连接超时
+read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
+proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
+ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
+no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
+
 max_token_streaming = 3500 # 流式对话时的最大 token 数
 timeout_streaming = 15 # 流式对话时的超时时间
 max_token_all = 3500 # 非流式对话时的最大 token 数
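The split constants are composed at the call sites in utils.py as prefix + specific cause + generic advice, which keeps one reusable suffix for every kind of network failure. An illustrative composition (values copied from the diff above; comments translate the Chinese):

    # Composing the new error constants from presets.py.
    standard_error_msg = "☹️发生了错误:"                        # "An error occurred:" (standard prefix)
    connection_timeout_prompt = "连接超时,无法获取对话。"         # "Connection timed out; could not fetch the reply."
    error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。"  # "Check your network connection or whether the API key is valid."

    status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
    print(status_text)
    # ☹️发生了错误:连接超时,无法获取对话。请检查网络连接,或者API-Key是否有效。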
utils.py
CHANGED

@@ -99,7 +99,7 @@ def construct_assistant(text):
 def construct_token_message(token, stream=False):
     return f"Token 计数: {token}"
 
-def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream):
+def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model):
     headers = {
         "Content-Type": "application/json",
         "Authorization": f"Bearer {openai_api_key}"

@@ -108,7 +108,7 @@ def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream):
     history = [construct_system(system_prompt), *history]
 
     payload = {
-        "model":
+        "model": selected_model,
         "messages": history, # [{"role": "user", "content": f"{inputs}"}],
         "temperature": temperature, # 1.0,
         "top_p": top_p, # 1.0,

@@ -124,40 +124,40 @@ def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream):
     response = requests.post(API_URL, headers=headers, json=payload, stream=True, timeout=timeout)
     return response
 
-def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot,
+def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model):
     def get_return_value():
-        return chatbot, history, status_text,
+        return chatbot, history, status_text, all_token_counts
 
     print("实时回答模式")
-    token_counter = 0
     partial_words = ""
     counter = 0
     status_text = "开始实时传输回答……"
     history.append(construct_user(inputs))
+    history.append(construct_assistant(""))
+    chatbot.append((parse_text(inputs), ""))
     user_token_count = 0
-    if len(
+    if len(all_token_counts) == 0:
         system_prompt_token_count = count_token(system_prompt)
         user_token_count = count_token(inputs) + system_prompt_token_count
     else:
         user_token_count = count_token(inputs)
+    all_token_counts.append(user_token_count)
     print(f"输入token计数: {user_token_count}")
+    yield get_return_value()
     try:
-        response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True)
+        response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True, selected_model)
     except requests.exceptions.ConnectTimeout:
-
-        status_text = standard_error_msg + "连接超时,无法获取对话。" + error_retrieve_prompt
+        status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
         yield get_return_value()
         return
     except requests.exceptions.ReadTimeout:
-
-        status_text = standard_error_msg + "读取超时,无法获取对话。" + error_retrieve_prompt
+        status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
         yield get_return_value()
         return
 
-    chatbot.append((parse_text(inputs), ""))
     yield get_return_value()
 
-    for chunk in response.iter_lines():
+    for chunk in tqdm(response.iter_lines()):
         if counter == 0:
             counter += 1
             continue

@@ -169,77 +169,93 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
         try:
             chunk = json.loads(chunk[6:])
         except json.JSONDecodeError:
+            print(chunk)
            status_text = f"JSON解析错误。请重置对话。收到的内容: {chunk}"
            yield get_return_value()
-
+            continue
        # decode each line as response data is in bytes
        if chunklength > 6 and "delta" in chunk['choices'][0]:
            finish_reason = chunk['choices'][0]['finish_reason']
-            status_text = construct_token_message(sum(
+            status_text = construct_token_message(sum(all_token_counts), stream=True)
            if finish_reason == "stop":
-                print("生成完毕")
                yield get_return_value()
                break
            try:
                partial_words = partial_words + chunk['choices'][0]["delta"]["content"]
            except KeyError:
-                status_text = standard_error_msg + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: " + str(sum(
+                status_text = standard_error_msg + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: " + str(sum(all_token_counts))
                yield get_return_value()
                break
-
-                history.append(construct_assistant(" " + partial_words))
-            else:
-                history[-1] = construct_assistant(partial_words)
+            history[-1] = construct_assistant(partial_words)
            chatbot[-1] = (parse_text(inputs), parse_text(partial_words))
-
+            all_token_counts[-1] += 1
            yield get_return_value()
 
 
-def predict_all(openai_api_key, system_prompt, history, inputs, chatbot,
+def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model):
     print("一次性回答模式")
     history.append(construct_user(inputs))
+    history.append(construct_assistant(""))
+    chatbot.append((parse_text(inputs), ""))
+    all_token_counts.append(count_token(inputs))
     try:
-        response = get_response(openai_api_key, system_prompt, history, temperature, top_p, False)
+        response = get_response(openai_api_key, system_prompt, history, temperature, top_p, False, selected_model)
     except requests.exceptions.ConnectTimeout:
-        status_text = standard_error_msg + error_retrieve_prompt
-        return chatbot, history, status_text,
+        status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
+        return chatbot, history, status_text, all_token_counts
+    except requests.exceptions.ProxyError:
+        status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
+        return chatbot, history, status_text, all_token_counts
+    except requests.exceptions.SSLError:
+        status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
+        return chatbot, history, status_text, all_token_counts
     response = json.loads(response.text)
     content = response["choices"][0]["message"]["content"]
-    history
+    history[-1] = construct_assistant(content)
     chatbot.append((parse_text(inputs), parse_text(content)))
     total_token_count = response["usage"]["total_tokens"]
-
+    all_token_counts[-1] = total_token_count - sum(all_token_counts)
     status_text = construct_token_message(total_token_count)
-
-    return chatbot, history, status_text, previous_token_count
+    return chatbot, history, status_text, all_token_counts
 
 
-def predict(openai_api_key, system_prompt, history, inputs, chatbot,
+def predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, stream=False, selected_model = MODELS[0], should_check_token_count = True): # repetition_penalty, top_k
     print("输入为:" +colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
+    if len(openai_api_key) != 51:
+        status_text = standard_error_msg + no_apikey_msg
+        print(status_text)
+        history.append(construct_user(inputs))
+        history.append("")
+        chatbot.append((parse_text(inputs), ""))
+        all_token_counts.append(0)
+        yield chatbot, history, status_text, all_token_counts
+        return
+    yield chatbot, history, "开始生成回答……", all_token_counts
     if stream:
         print("使用流式传输")
-        iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot,
-        for chatbot, history, status_text,
-            yield chatbot, history, status_text,
+        iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model)
+        for chatbot, history, status_text, all_token_counts in iter:
+            yield chatbot, history, status_text, all_token_counts
     else:
         print("不使用流式传输")
-        chatbot, history, status_text,
-        yield chatbot, history, status_text,
-    print(f"传输完毕。当前token计数为{
-
+        chatbot, history, status_text, all_token_counts = predict_all(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model)
+        yield chatbot, history, status_text, all_token_counts
+    print(f"传输完毕。当前token计数为{all_token_counts}")
+    if len(history) > 1 and history[-1]['content'] != inputs:
+        print("回答为:" +colorama.Fore.BLUE + f"{history[-1]['content']}" + colorama.Style.RESET_ALL)
     if stream:
         max_token = max_token_streaming
     else:
         max_token = max_token_all
-    if sum(
-        print(f"精简token中{
-        iter = reduce_token_size(openai_api_key, system_prompt, history, chatbot,
-        for chatbot, history, status_text,
+    if sum(all_token_counts) > max_token and should_check_token_count:
+        print(f"精简token中{all_token_counts}/{max_token}")
+        iter = reduce_token_size(openai_api_key, system_prompt, history, chatbot, all_token_counts, top_p, temperature, stream=False, hidden=True)
+        for chatbot, history, status_text, all_token_counts in iter:
             status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
-            yield chatbot, history, status_text,
+            yield chatbot, history, status_text, all_token_counts
 
 
-def retry(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False):
+def retry(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, selected_model = MODELS[0]):
     print("重试中……")
     if len(history) == 0:
         yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count

@@ -247,15 +263,15 @@ def retry(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False):
     history.pop()
     inputs = history.pop()["content"]
     token_count.pop()
-    iter = predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=stream)
+    iter = predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=stream, selected_model=selected_model)
     print("重试完毕")
     for x in iter:
         yield x
 
 
-def reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, hidden=False):
+def reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, hidden=False, selected_model = MODELS[0]):
     print("开始减少token数量……")
-    iter = predict(openai_api_key, system_prompt, history, summarize_prompt, chatbot, token_count, top_p, temperature, stream=stream, should_check_token_count=False)
+    iter = predict(openai_api_key, system_prompt, history, summarize_prompt, chatbot, token_count, top_p, temperature, stream=stream, selected_model = selected_model, should_check_token_count=False)
     for chatbot, history, status_text, previous_token_count in iter:
         history = history[-2:]
         token_count = previous_token_count[-1:]

@@ -265,7 +281,7 @@ def reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, hidden=False):
     print("减少token数量完毕")
 
 
-def delete_last_conversation(chatbot, history, previous_token_count, streaming):
+def delete_last_conversation(chatbot, history, previous_token_count):
     if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
         print("由于包含报错信息,只删除chatbot记录")
         chatbot.pop()

@@ -280,7 +296,7 @@ def delete_last_conversation(chatbot, history, previous_token_count, streaming):
     if len(previous_token_count) > 0:
         print("删除了一组对话的token计数记录")
         previous_token_count.pop()
-    return chatbot, history, previous_token_count, construct_token_message(sum(previous_token_count)
+    return chatbot, history, previous_token_count, construct_token_message(sum(previous_token_count))
 
 
 def save_chat_history(filename, system, history, chatbot):
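For context on the streaming loop in stream_predict above: the OpenAI chat-completion stream is a server-sent-events body in which each non-empty line is "data: " followed by JSON (hence the chunk[6:] slice), the first chunk carries only the assistant role, and finish_reason == "stop" marks the end. A condensed, self-contained sketch of that parse, assuming a requests response created with stream=True as in get_response (accumulate_stream is a hypothetical helper, not from the commit):

    import json

    def accumulate_stream(response):
        """Collect delta fragments from an OpenAI chat-completion SSE stream.

        `response` is a requests.Response created with stream=True.
        Returns the concatenated assistant text.
        """
        partial_words = ""
        for line in response.iter_lines():
            if not line:
                continue                        # skip keep-alive blank lines
            chunk = line.decode("utf-8")
            if not chunk.startswith("data: "):
                continue
            if chunk == "data: [DONE]":         # end-of-stream sentinel
                break
            payload = json.loads(chunk[len("data: "):])
            choice = payload["choices"][0]
            if choice.get("finish_reason") == "stop":
                break
            delta = choice.get("delta", {})
            # The first chunk carries only the role and has no "content" key.
            partial_words += delta.get("content", "")
        return partial_words

Note that counting one token per received delta, as all_token_counts[-1] += 1 does above, is only an approximation; the non-streaming path in predict_all can instead record the exact usage.total_tokens figure that the API returns.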