rzline commited on
Commit
12cfedd
·
verified ·
1 Parent(s): b2afc29

Create utils.py

Browse files
Files changed (1) hide show
  1. app/utils.py +240 -0
app/utils.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from fastapi import HTTPException, Request
3
+ import time
4
+ import re
5
+ from datetime import datetime, timedelta
6
+ from apscheduler.schedulers.background import BackgroundScheduler
7
+ import os
8
+ import requests
9
+ import httpx
10
+ from threading import Lock
11
+ import logging
12
+ import sys
13
+
14
+ DEBUG = os.environ.get("DEBUG", "false").lower() == "true"
15
+ LOG_FORMAT_DEBUG = '%(asctime)s - %(levelname)s - [%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s - %(error_message)s'
16
+ LOG_FORMAT_NORMAL = '[%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s'
17
+
18
+ # 配置 logger
19
+ logger = logging.getLogger("my_logger")
20
+ logger.setLevel(logging.DEBUG)
21
+
22
+ handler = logging.StreamHandler()
23
+ # formatter = logging.Formatter('%(message)s')
24
+ # handler.setFormatter(formatter)
25
+ logger.addHandler(handler)
26
+
27
+ def format_log_message(level, message, extra=None):
28
+ extra = extra or {}
29
+ log_values = {
30
+ 'asctime': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
31
+ 'levelname': level,
32
+ 'key': extra.get('key', 'N/A'),
33
+ 'request_type': extra.get('request_type', 'N/A'),
34
+ 'model': extra.get('model', 'N/A'),
35
+ 'status_code': extra.get('status_code', 'N/A'),
36
+ 'error_message': extra.get('error_message', ''),
37
+ 'message': message
38
+ }
39
+ log_format = LOG_FORMAT_DEBUG if DEBUG else LOG_FORMAT_NORMAL
40
+ return log_format % log_values
41
+
42
+
43
+ class APIKeyManager:
44
+ def __init__(self):
45
+ self.api_keys = re.findall(
46
+ r"AIzaSy[a-zA-Z0-9_-]{33}", os.environ.get('GEMINI_API_KEYS', ""))
47
+ self.key_stack = [] # 初始化密钥栈
48
+ self._reset_key_stack() # 初始化时创建随机密钥栈
49
+ # self.api_key_blacklist = set()
50
+ # self.api_key_blacklist_duration = 60
51
+ self.scheduler = BackgroundScheduler()
52
+ self.scheduler.start()
53
+ self.tried_keys_for_request = set() # 用于跟踪当前请求尝试中已试过的 key
54
+
55
+ def _reset_key_stack(self):
56
+ """创建并随机化密钥栈"""
57
+ shuffled_keys = self.api_keys[:] # 创建 api_keys 的副本以避免直接修改原列表
58
+ random.shuffle(shuffled_keys)
59
+ self.key_stack = shuffled_keys
60
+
61
+
62
+ def get_available_key(self):
63
+ """从栈顶获取密钥,栈空时重新生成 (修改后)"""
64
+ while self.key_stack:
65
+ key = self.key_stack.pop()
66
+ # if key not in self.api_key_blacklist and key not in self.tried_keys_for_request:
67
+ if key not in self.tried_keys_for_request:
68
+ self.tried_keys_for_request.add(key)
69
+ return key
70
+
71
+ if not self.api_keys:
72
+ log_msg = format_log_message('ERROR', "没有配置任何 API 密钥!")
73
+ logger.error(log_msg)
74
+ return None
75
+
76
+ self._reset_key_stack() # 重新生成密钥栈
77
+
78
+ # 再次尝试从新栈中获取密钥 (迭代一次)
79
+ while self.key_stack:
80
+ key = self.key_stack.pop()
81
+ # if key not in self.api_key_blacklist and key not in self.tried_keys_for_request:
82
+ if key not in self.tried_keys_for_request:
83
+ self.tried_keys_for_request.add(key)
84
+ return key
85
+
86
+ return None
87
+
88
+
89
+ def show_all_keys(self):
90
+ log_msg = format_log_message('INFO', f"当前可用API key个数: {len(self.api_keys)} ")
91
+ logger.info(log_msg)
92
+ for i, api_key in enumerate(self.api_keys):
93
+ log_msg = format_log_message('INFO', f"API Key{i}: {api_key[:8]}...{api_key[-3:]}")
94
+ logger.info(log_msg)
95
+
96
+ # def blacklist_key(self, key):
97
+ # log_msg = format_log_message('WARNING', f"{key[:8]} → 暂时禁用 {self.api_key_blacklist_duration} 秒")
98
+ # logger.warning(log_msg)
99
+ # self.api_key_blacklist.add(key)
100
+ # self.scheduler.add_job(lambda: self.api_key_blacklist.discard(key), 'date',
101
+ # run_date=datetime.now() + timedelta(seconds=self.api_key_blacklist_duration))
102
+
103
+ def reset_tried_keys_for_request(self):
104
+ """在新的请求尝试时重置已尝试的 key 集合"""
105
+ self.tried_keys_for_request = set()
106
+
107
+
108
+ def handle_gemini_error(error, current_api_key, key_manager) -> str:
109
+ if isinstance(error, requests.exceptions.HTTPError):
110
+ status_code = error.response.status_code
111
+ if status_code == 400:
112
+ try:
113
+ error_data = error.response.json()
114
+ if 'error' in error_data:
115
+ if error_data['error'].get('code') == "invalid_argument":
116
+ error_message = "无效的 API 密钥"
117
+ extra_log_invalid_key = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
118
+ log_msg = format_log_message('ERROR', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 无效,可能已过期或被删除", extra=extra_log_invalid_key)
119
+ logger.error(log_msg)
120
+ # key_manager.blacklist_key(current_api_key)
121
+
122
+ return error_message
123
+ error_message = error_data['error'].get(
124
+ 'message', 'Bad Request')
125
+ extra_log_400 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
126
+ log_msg = format_log_message('WARNING', f"400 错误请求: {error_message}", extra=extra_log_400)
127
+ logger.warning(log_msg)
128
+ return f"400 错误请求: {error_message}"
129
+ except ValueError:
130
+ error_message = "400 错误请求:响应不是有效的JSON格式"
131
+ extra_log_400_json = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
132
+ log_msg = format_log_message('WARNING', error_message, extra=extra_log_400_json)
133
+ logger.warning(log_msg)
134
+ return error_message
135
+
136
+ elif status_code == 429:
137
+ error_message = "API 密钥配额已用尽或其他原因"
138
+ extra_log_429 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
139
+ log_msg = format_log_message('WARNING', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 429 官方资源耗尽或其他原因", extra=extra_log_429)
140
+ logger.warning(log_msg)
141
+ # key_manager.blacklist_key(current_api_key)
142
+
143
+ return error_message
144
+
145
+ elif status_code == 403:
146
+ error_message = "权限被拒绝"
147
+ extra_log_403 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
148
+ log_msg = format_log_message('ERROR', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 403 权限被拒绝", extra=extra_log_403)
149
+ logger.error(log_msg)
150
+ # key_manager.blacklist_key(current_api_key)
151
+
152
+ return error_message
153
+ elif status_code == 500:
154
+ error_message = "服务器内部错误"
155
+ extra_log_500 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
156
+ log_msg = format_log_message('WARNING', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 500 服务器内部错误", extra=extra_log_500)
157
+ logger.warning(log_msg)
158
+
159
+ return "Gemini API 内部错误"
160
+
161
+ elif status_code == 503:
162
+ error_message = "服务不可用"
163
+ extra_log_503 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
164
+ log_msg = format_log_message('WARNING', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 503 服务不可用", extra=extra_log_503)
165
+ logger.warning(log_msg)
166
+
167
+ return "Gemini API 服务不可用"
168
+ else:
169
+ error_message = f"未知错误: {status_code}"
170
+ extra_log_other = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
171
+ log_msg = format_log_message('WARNING', f"{current_api_key[:8]} ... {current_api_key[-3:]} → {status_code} 未知错误", extra=extra_log_other)
172
+ logger.warning(log_msg)
173
+
174
+ return f"未知错误/模型不可用: {status_code}"
175
+
176
+ elif isinstance(error, requests.exceptions.ConnectionError):
177
+ error_message = "连接错误"
178
+ log_msg = format_log_message('WARNING', error_message, extra={'error_message': error_message})
179
+ logger.warning(log_msg)
180
+ return error_message
181
+
182
+ elif isinstance(error, requests.exceptions.Timeout):
183
+ error_message = "请求超时"
184
+ log_msg = format_log_message('WARNING', error_message, extra={'error_message': error_message})
185
+ logger.warning(log_msg)
186
+ return error_message
187
+ else:
188
+ error_message = f"发生未知错误: {error}"
189
+ log_msg = format_log_message('ERROR', error_message, extra={'error_message': error_message})
190
+ logger.error(log_msg)
191
+ return error_message
192
+
193
+
194
+ async def test_api_key(api_key: str) -> bool:
195
+ """
196
+ 测试 API 密钥是否有效。
197
+ """
198
+ try:
199
+ url = "https://generativelanguage.googleapis.com/v1beta/models?key={}".format(api_key)
200
+ async with httpx.AsyncClient() as client:
201
+ response = await client.get(url)
202
+ response.raise_for_status()
203
+ return True
204
+ except Exception:
205
+ return False
206
+
207
+
208
+ rate_limit_data = {}
209
+ rate_limit_lock = Lock()
210
+
211
+
212
+ def protect_from_abuse(request: Request, max_requests_per_minute: int = 30, max_requests_per_day_per_ip: int = 600):
213
+ now = int(time.time())
214
+ minute = now // 60
215
+ day = now // (60 * 60 * 24)
216
+
217
+ minute_key = f"{request.url.path}:{minute}"
218
+ day_key = f"{request.client.host}:{day}"
219
+
220
+ with rate_limit_lock:
221
+ minute_count, minute_timestamp = rate_limit_data.get(
222
+ minute_key, (0, now))
223
+ if now - minute_timestamp >= 60:
224
+ minute_count = 0
225
+ minute_timestamp = now
226
+ minute_count += 1
227
+ rate_limit_data[minute_key] = (minute_count, minute_timestamp)
228
+
229
+ day_count, day_timestamp = rate_limit_data.get(day_key, (0, now))
230
+ if now - day_timestamp >= 86400:
231
+ day_count = 0
232
+ day_timestamp = now
233
+ day_count += 1
234
+ rate_limit_data[day_key] = (day_count, day_timestamp)
235
+
236
+ if minute_count > max_requests_per_minute:
237
+ raise HTTPException(status_code=429, detail={
238
+ "message": "Too many requests per minute", "limit": max_requests_per_minute})
239
+ if day_count > max_requests_per_day_per_ip:
240
+ raise HTTPException(status_code=429, detail={"message": "Too many requests per day from this IP", "limit": max_requests_per_day_per_ip})