Spaces:
Sleeping
Sleeping
Update model.py
Browse files
model.py
CHANGED
|
@@ -6,448 +6,190 @@ from urllib.parse import urlparse
|
|
| 6 |
import numpy as np
|
| 7 |
import onnxruntime as ort
|
| 8 |
|
| 9 |
-
import config
|
| 10 |
from utils import preprocess_text
|
| 11 |
|
| 12 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
_binary_session = None
|
| 14 |
_category_session = None
|
| 15 |
_metadata = None
|
| 16 |
|
| 17 |
-
#
|
|
|
|
| 18 |
SPAM_HINT_PATTERN = re.compile(
|
| 19 |
r"(http|www|win|winner|claim|click|offer|bonus|urgent|verify|password|"
|
| 20 |
r"account|bank|deposit|earn|investment|crypto|btc|telegram|airdrop|giveaway|jackpot|prize)",
|
| 21 |
re.IGNORECASE,
|
| 22 |
)
|
| 23 |
|
| 24 |
-
BENIGN_WIN_CONTEXT_PATTERN = re.compile(
|
| 25 |
-
r"\b(won|win|winner)\b.*\b(match|game|tournament|league|race|finals|team|football|cricket|basketball)\b",
|
| 26 |
-
re.IGNORECASE,
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
SCAM_ACTION_PATTERN = re.compile(
|
| 30 |
r"(claim|click|prize|reward|link|http|www|money|cash|gift|airdrop|crypto|account|verify|urgent)",
|
| 31 |
re.IGNORECASE,
|
| 32 |
)
|
| 33 |
|
| 34 |
GIVEAWAY_OVERRIDE_PATTERN = re.compile(
|
| 35 |
-
r"(\b(won|winner|jackpot
|
| 36 |
-
r"(\b(claim|redeem)\b.*\b(prize|reward|gift|voucher)\b)",
|
| 37 |
re.IGNORECASE,
|
| 38 |
)
|
| 39 |
|
| 40 |
-
SENTENCE_SPLIT_PATTERN = re.compile(r"(?<=[.!?])\s+|\n+")
|
| 41 |
-
URL_TOKEN_PATTERN = re.compile(r"^(https?://\S+|www\.\S+)$", re.IGNORECASE)
|
| 42 |
URL_ANY_PATTERN = re.compile(r"(https?://\S+|www\.\S+)", re.IGNORECASE)
|
| 43 |
-
LINK_SPAM_CUE_PATTERN = re.compile(
|
| 44 |
-
r"(claim|verify|password|bank|urgent|winner|prize|reward|bonus|airdrop|crypto|"
|
| 45 |
-
r"deposit|investment|gift card|limited time|act now|suspended|login|otp|kyc)",
|
| 46 |
-
re.IGNORECASE,
|
| 47 |
-
)
|
| 48 |
-
|
| 49 |
-
BENIGN_ADULT_CONTEXT_PATTERN = re.compile(
|
| 50 |
-
r"(older than 18|over 18|under 18|age requirement|adult supervision|age limit|"
|
| 51 |
-
r"content rating|parental guidance|legal age|years old)",
|
| 52 |
-
re.IGNORECASE,
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
BENIGN_WORK_CONTEXT_PATTERN = re.compile(
|
| 56 |
-
r"(pull request|code review|deployment|sprint|bug fix|qa|release note|"
|
| 57 |
-
r"project update|meeting notes|standup|ticket|merge request|ci pipeline)",
|
| 58 |
-
re.IGNORECASE,
|
| 59 |
-
)
|
| 60 |
-
|
| 61 |
-
SHORT_BRAND_ALERT_PATTERN = re.compile(
|
| 62 |
-
r"^[A-Za-z0-9&'._+-]{2,32}\s*:\s*[^:]{2,90}[.!?]?$"
|
| 63 |
-
)
|
| 64 |
|
| 65 |
-
|
| 66 |
-
r"(
|
| 67 |
-
r"(earn\s+\$?\s?\d[\d,]*)|"
|
| 68 |
-
r"(get rich quick)|"
|
| 69 |
-
r"(no experience needed)",
|
| 70 |
re.IGNORECASE,
|
| 71 |
)
|
| 72 |
|
|
|
|
| 73 |
|
| 74 |
def _load_metadata():
|
| 75 |
-
if os.path.exists(
|
| 76 |
-
with open(
|
| 77 |
return json.load(f)
|
| 78 |
|
| 79 |
return {
|
| 80 |
-
"spam_threshold":
|
| 81 |
-
"short_text_word_count": config.SHORT_TEXT_WORD_COUNT,
|
| 82 |
-
"short_text_threshold": config.SHORT_TEXT_THRESHOLD,
|
| 83 |
-
"very_short_text_word_count": config.VERY_SHORT_TEXT_WORD_COUNT,
|
| 84 |
-
"very_short_text_threshold": config.VERY_SHORT_TEXT_THRESHOLD,
|
| 85 |
}
|
| 86 |
|
| 87 |
|
| 88 |
def load_models():
|
| 89 |
-
"""Loads the model sessions from disk only once."""
|
| 90 |
global _binary_session, _category_session, _metadata
|
| 91 |
|
| 92 |
if _binary_session is None:
|
| 93 |
-
|
| 94 |
-
_binary_session = ort.InferenceSession(
|
| 95 |
-
|
| 96 |
if _category_session is None:
|
| 97 |
-
|
| 98 |
-
_category_session = ort.InferenceSession(
|
| 99 |
-
|
| 100 |
if _metadata is None:
|
| 101 |
_metadata = _load_metadata()
|
| 102 |
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
)
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
)
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
words = [w for w in cleaned_text.split(" ") if w]
|
| 120 |
-
has_spam_hint = bool(SPAM_HINT_PATTERN.search(raw_text or ""))
|
| 121 |
-
|
| 122 |
-
if not has_spam_hint:
|
| 123 |
-
if len(words) <= very_short_word_count:
|
| 124 |
-
threshold = max(threshold, very_short_threshold)
|
| 125 |
-
elif len(words) <= short_word_count:
|
| 126 |
-
threshold = max(threshold, short_threshold)
|
| 127 |
|
| 128 |
return threshold
|
| 129 |
|
| 130 |
|
| 131 |
-
def
|
| 132 |
-
|
| 133 |
-
return False
|
| 134 |
-
return bool(BENIGN_WIN_CONTEXT_PATTERN.search(raw_text)) and not bool(
|
| 135 |
-
SCAM_ACTION_PATTERN.search(raw_text)
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def _is_benign_context(raw_text):
|
| 140 |
-
if not raw_text:
|
| 141 |
-
return False
|
| 142 |
-
if bool(SCAM_ACTION_PATTERN.search(raw_text)):
|
| 143 |
-
return False
|
| 144 |
-
return bool(BENIGN_ADULT_CONTEXT_PATTERN.search(raw_text)) or bool(
|
| 145 |
-
BENIGN_WORK_CONTEXT_PATTERN.search(raw_text)
|
| 146 |
-
)
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
def _extract_url_domains(raw_text: str) -> list[str]:
|
| 150 |
-
if not raw_text:
|
| 151 |
-
return []
|
| 152 |
-
domains = []
|
| 153 |
-
for m in URL_ANY_PATTERN.finditer(raw_text):
|
| 154 |
-
url = m.group(0).strip()
|
| 155 |
-
if url.lower().startswith("www."):
|
| 156 |
-
url = "https://" + url
|
| 157 |
-
try:
|
| 158 |
-
parsed = urlparse(url)
|
| 159 |
-
host = (parsed.netloc or "").lower().strip()
|
| 160 |
-
except Exception:
|
| 161 |
-
continue
|
| 162 |
-
if not host:
|
| 163 |
-
continue
|
| 164 |
-
if host.startswith("m."):
|
| 165 |
-
host = host[2:]
|
| 166 |
-
domains.append(host)
|
| 167 |
-
return domains
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
def _has_blocked_domain(raw_text: str) -> bool:
|
| 171 |
-
blocked = set(getattr(config, "BLOCKED_URL_DOMAINS", set()))
|
| 172 |
-
if not blocked:
|
| 173 |
-
return False
|
| 174 |
-
domains = _extract_url_domains(raw_text)
|
| 175 |
-
if not domains:
|
| 176 |
-
return False
|
| 177 |
-
for host in domains:
|
| 178 |
-
if host in blocked:
|
| 179 |
-
return True
|
| 180 |
-
for base in blocked:
|
| 181 |
-
if host.endswith('.' + base):
|
| 182 |
-
return True
|
| 183 |
-
return False
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
def _contains_url(raw_text: str) -> bool:
|
| 187 |
-
return bool(URL_ANY_PATTERN.search(raw_text or ""))
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
def _has_link_spam_cues(raw_text: str) -> bool:
|
| 191 |
-
return bool(LINK_SPAM_CUE_PATTERN.search(raw_text or ""))
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
def _split_long_text(text: str) -> list[str]:
|
| 195 |
-
max_words = int(getattr(config, "CHUNK_MAX_WORDS", 40))
|
| 196 |
-
max_chunks = int(getattr(config, "MAX_CHUNKS", 24))
|
| 197 |
-
|
| 198 |
-
parts = [p.strip() for p in SENTENCE_SPLIT_PATTERN.split(text or "") if p.strip()]
|
| 199 |
chunks = []
|
| 200 |
-
current = []
|
| 201 |
-
current_words = 0
|
| 202 |
-
|
| 203 |
-
for part in parts:
|
| 204 |
-
words = part.split()
|
| 205 |
-
if not words:
|
| 206 |
-
continue
|
| 207 |
-
if len(words) > max_words:
|
| 208 |
-
for i in range(0, len(words), max_words):
|
| 209 |
-
piece = " ".join(words[i : i + max_words]).strip()
|
| 210 |
-
if piece:
|
| 211 |
-
chunks.append(piece)
|
| 212 |
-
if len(chunks) >= max_chunks:
|
| 213 |
-
return chunks[:max_chunks]
|
| 214 |
-
continue
|
| 215 |
-
|
| 216 |
-
if current_words + len(words) > max_words and current:
|
| 217 |
-
chunks.append(" ".join(current).strip())
|
| 218 |
-
current = [part]
|
| 219 |
-
current_words = len(words)
|
| 220 |
-
else:
|
| 221 |
-
current.append(part)
|
| 222 |
-
current_words += len(words)
|
| 223 |
-
|
| 224 |
-
if len(chunks) >= max_chunks:
|
| 225 |
-
return chunks[:max_chunks]
|
| 226 |
-
|
| 227 |
-
if current and len(chunks) < max_chunks:
|
| 228 |
-
chunks.append(" ".join(current).strip())
|
| 229 |
-
|
| 230 |
-
return chunks[:max_chunks]
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
def _predict_single(raw_text: str, cleaned_text: str) -> dict:
|
| 234 |
-
if _contains_url(raw_text):
|
| 235 |
-
if _has_blocked_domain(raw_text):
|
| 236 |
-
return {
|
| 237 |
-
"is_spam": True,
|
| 238 |
-
"confidence": 0.99,
|
| 239 |
-
"category": "spam",
|
| 240 |
-
"threshold_used": float(_metadata.get("spam_threshold", config.SPAM_THRESHOLD)),
|
| 241 |
-
}
|
| 242 |
-
if not _has_link_spam_cues(raw_text):
|
| 243 |
-
return {
|
| 244 |
-
"is_spam": False,
|
| 245 |
-
"confidence": 0.05,
|
| 246 |
-
"category": "normal",
|
| 247 |
-
"threshold_used": float(_metadata.get("spam_threshold", config.SPAM_THRESHOLD)),
|
| 248 |
-
}
|
| 249 |
-
|
| 250 |
-
# Prepare input for ONNX
|
| 251 |
-
# Input name was set to 'input' in conversion script.
|
| 252 |
-
# It expects StringTensorType([None, 1])
|
| 253 |
-
onnx_input = np.array([[cleaned_text]], dtype=object)
|
| 254 |
-
|
| 255 |
-
# Binary prediction
|
| 256 |
-
binary_inputs = {_binary_session.get_inputs()[0].name: onnx_input}
|
| 257 |
-
# Output names are usually 'label' and 'probabilities'
|
| 258 |
-
binary_outputs = _binary_session.run(None, binary_inputs)
|
| 259 |
-
|
| 260 |
-
# binary_outputs[1] is a list of dictionaries like [{'0': 0.9, '1': 0.1}]
|
| 261 |
-
# Let's verify the actual output format.
|
| 262 |
-
# usually it's [labels, [{0: prob, 1: prob}]]
|
| 263 |
-
probs = binary_outputs[1][0]
|
| 264 |
-
spam_prob = float(probs.get(1, probs.get('1', 0.0)))
|
| 265 |
|
| 266 |
-
|
| 267 |
-
|
|
|
|
| 268 |
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
if is_spam and spam_prob < 0.85 and _is_benign_context(raw_text):
|
| 272 |
-
is_spam = False
|
| 273 |
|
| 274 |
-
|
| 275 |
-
if not is_spam and has_giveaway_override and not _is_benign_win_context(raw_text):
|
| 276 |
-
is_spam = True
|
| 277 |
|
| 278 |
-
short_brand_alert = bool(SHORT_BRAND_ALERT_PATTERN.match((raw_text or "").strip()))
|
| 279 |
-
money_job_scam = bool(MONEY_JOB_SCAM_PATTERN.search(raw_text or ""))
|
| 280 |
|
| 281 |
-
|
| 282 |
-
|
|
|
|
|
|
|
| 283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
if not is_spam:
|
| 285 |
-
if
|
| 286 |
-
is_spam = True
|
| 287 |
-
elif short_brand_alert and spam_prob >= max(0.50, threshold - 0.16):
|
| 288 |
is_spam = True
|
| 289 |
|
|
|
|
| 290 |
if is_spam:
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
category = "phishing"
|
| 295 |
-
elif has_giveaway_override:
|
| 296 |
-
category = "giveaway"
|
| 297 |
-
else:
|
| 298 |
-
category_inputs = {_category_session.get_inputs()[0].name: onnx_input}
|
| 299 |
-
category_outputs = _category_session.run(None, category_inputs)
|
| 300 |
-
category = str(category_outputs[0][0])
|
| 301 |
else:
|
| 302 |
category = "normal"
|
| 303 |
|
| 304 |
return {
|
| 305 |
-
"is_spam":
|
| 306 |
-
"confidence":
|
| 307 |
-
"category":
|
| 308 |
-
"threshold_used":
|
| 309 |
}
|
| 310 |
|
| 311 |
|
| 312 |
-
|
| 313 |
-
if text is None:
|
| 314 |
-
return False, "Input is required."
|
| 315 |
-
if not isinstance(text, str):
|
| 316 |
-
return False, "Input must be a string."
|
| 317 |
-
normalized = text.strip()
|
| 318 |
-
if not normalized:
|
| 319 |
-
return False, "Input cannot be empty."
|
| 320 |
-
if not any(ch.isalnum() for ch in normalized):
|
| 321 |
-
return True, ""
|
| 322 |
-
if len(normalized) < 2:
|
| 323 |
-
return True, ""
|
| 324 |
-
return True, ""
|
| 325 |
-
|
| 326 |
|
| 327 |
-
def predict_message(text
|
| 328 |
load_models()
|
| 329 |
-
cleaned_text = preprocess_text(text)
|
| 330 |
-
word_count = len([w for w in cleaned_text.split(" ") if w])
|
| 331 |
-
long_threshold = int(getattr(config, "LONG_TEXT_WORD_THRESHOLD", 80))
|
| 332 |
-
|
| 333 |
-
if word_count <= long_threshold:
|
| 334 |
-
pred = _predict_single(text, cleaned_text)
|
| 335 |
-
return {
|
| 336 |
-
"is_spam": pred["is_spam"],
|
| 337 |
-
"confidence": round(pred["confidence"], 4),
|
| 338 |
-
"category": pred["category"],
|
| 339 |
-
"threshold_used": round(pred["threshold_used"], 4),
|
| 340 |
-
"chunked": False,
|
| 341 |
-
}
|
| 342 |
-
|
| 343 |
-
chunks = _split_long_text(text)
|
| 344 |
-
if not chunks:
|
| 345 |
-
pred = _predict_single(text, cleaned_text)
|
| 346 |
-
return {
|
| 347 |
-
"is_spam": pred["is_spam"],
|
| 348 |
-
"confidence": round(pred["confidence"], 4),
|
| 349 |
-
"category": pred["category"],
|
| 350 |
-
"threshold_used": round(pred["threshold_used"], 4),
|
| 351 |
-
"chunked": False,
|
| 352 |
-
}
|
| 353 |
-
|
| 354 |
-
chunk_predictions = []
|
| 355 |
-
for chunk in chunks:
|
| 356 |
-
cp = _predict_single(chunk, preprocess_text(chunk))
|
| 357 |
-
chunk_predictions.append(cp)
|
| 358 |
-
|
| 359 |
-
highest = max(chunk_predictions, key=lambda x: x["confidence"])
|
| 360 |
-
spam_chunks = [cp for cp in chunk_predictions if cp["is_spam"]]
|
| 361 |
-
is_spam = len(spam_chunks) > 0
|
| 362 |
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
else:
|
| 366 |
-
representative = highest
|
| 367 |
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
"confidence"
|
| 371 |
-
|
| 372 |
-
"threshold_used": round(float(representative["threshold_used"]), 4),
|
| 373 |
-
"chunked": True,
|
| 374 |
-
"chunk_count": len(chunks),
|
| 375 |
-
}
|
| 376 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
|
| 378 |
-
def run_model(text: str) -> dict:
|
| 379 |
-
ok, error = validate_message(text)
|
| 380 |
-
if not ok:
|
| 381 |
-
return {
|
| 382 |
-
"ok": False,
|
| 383 |
-
"error": error,
|
| 384 |
-
"input": text,
|
| 385 |
-
}
|
| 386 |
-
prediction = predict_message(text)
|
| 387 |
return {
|
| 388 |
-
"
|
| 389 |
-
"
|
| 390 |
-
"
|
|
|
|
| 391 |
}
|
| 392 |
|
| 393 |
|
| 394 |
-
def
|
| 395 |
-
if text
|
| 396 |
-
return
|
| 397 |
-
os.makedirs("dataset", exist_ok=True)
|
| 398 |
-
feedback_path = os.path.join("dataset", "feedback.jsonl")
|
| 399 |
-
payload = {
|
| 400 |
-
"text": text.strip(),
|
| 401 |
-
"label": int(label),
|
| 402 |
-
"category": str(category),
|
| 403 |
-
}
|
| 404 |
-
with open(feedback_path, "a", encoding="utf-8") as f:
|
| 405 |
-
f.write(json.dumps(payload, ensure_ascii=True) + "\n")
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
def get_model_specs() -> dict:
|
| 409 |
-
specs = {
|
| 410 |
-
"model_dir": config.MODEL_DIR,
|
| 411 |
-
"binary_model_path": config.BINARY_MODEL_PATH,
|
| 412 |
-
"category_model_path": config.CATEGORY_MODEL_PATH,
|
| 413 |
-
"metadata_path": config.METADATA_PATH,
|
| 414 |
-
"spam_threshold": config.SPAM_THRESHOLD,
|
| 415 |
-
"word_max_features": getattr(config, "WORD_MAX_FEATURES", None),
|
| 416 |
-
"char_max_features": getattr(config, "CHAR_MAX_FEATURES", None),
|
| 417 |
-
"files": {
|
| 418 |
-
"binary_onnx_exists": os.path.exists(os.path.join(config.MODEL_DIR, "binary_model.onnx")),
|
| 419 |
-
"category_onnx_exists": os.path.exists(os.path.join(config.MODEL_DIR, "category_model.onnx")),
|
| 420 |
-
"metadata_exists": os.path.exists(config.METADATA_PATH),
|
| 421 |
-
},
|
| 422 |
-
}
|
| 423 |
-
try:
|
| 424 |
-
load_models()
|
| 425 |
-
specs["loaded"] = True
|
| 426 |
-
specs["runtime_threshold"] = _metadata.get("spam_threshold", config.SPAM_THRESHOLD)
|
| 427 |
-
except Exception as exc:
|
| 428 |
-
specs["loaded"] = False
|
| 429 |
-
specs["load_error"] = str(exc)
|
| 430 |
-
return specs
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
def print_model_specs() -> None:
|
| 434 |
-
specs = get_model_specs()
|
| 435 |
-
print("Model Specs (ONNX)")
|
| 436 |
-
print(f"- Model dir: {specs['model_dir']}")
|
| 437 |
-
print(f"- Base threshold: {specs['spam_threshold']}")
|
| 438 |
-
print(
|
| 439 |
-
"- Files exist: "
|
| 440 |
-
f"binary_onnx={specs['files']['binary_onnx_exists']}, "
|
| 441 |
-
f"category_onnx={specs['files']['category_onnx_exists']}, "
|
| 442 |
-
f"metadata={specs['files']['metadata_exists']}"
|
| 443 |
-
)
|
| 444 |
-
if specs.get("loaded"):
|
| 445 |
-
print("- Loaded: True")
|
| 446 |
-
print(f"- Runtime threshold: {specs['runtime_threshold']}")
|
| 447 |
-
else:
|
| 448 |
-
print("- Loaded: False")
|
| 449 |
-
print(f"- Load error: {specs.get('load_error', 'unknown error')}")
|
| 450 |
|
|
|
|
| 451 |
|
| 452 |
-
|
| 453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import onnxruntime as ort
|
| 8 |
|
|
|
|
| 9 |
from utils import preprocess_text
|
| 10 |
|
| 11 |
+
# --------- DEFAULT CONFIG ---------
|
| 12 |
+
|
| 13 |
+
MODEL_DIR = "models"
|
| 14 |
+
|
| 15 |
+
SPAM_THRESHOLD = 0.5
|
| 16 |
+
|
| 17 |
+
SHORT_TEXT_WORD_COUNT = 6
|
| 18 |
+
SHORT_TEXT_THRESHOLD = 0.65
|
| 19 |
+
|
| 20 |
+
VERY_SHORT_TEXT_WORD_COUNT = 3
|
| 21 |
+
VERY_SHORT_TEXT_THRESHOLD = 0.75
|
| 22 |
+
|
| 23 |
+
LONG_TEXT_WORD_THRESHOLD = 80
|
| 24 |
+
|
| 25 |
+
CHUNK_MAX_WORDS = 40
|
| 26 |
+
MAX_CHUNKS = 24
|
| 27 |
+
|
| 28 |
+
BLOCKED_URL_DOMAINS = set()
|
| 29 |
+
|
| 30 |
+
METADATA_PATH = os.path.join("/", "metadata.json")
|
| 31 |
+
|
| 32 |
+
# --------- GLOBALS ---------
|
| 33 |
+
|
| 34 |
_binary_session = None
|
| 35 |
_category_session = None
|
| 36 |
_metadata = None
|
| 37 |
|
| 38 |
+
# --------- REGEX ---------
|
| 39 |
+
|
| 40 |
SPAM_HINT_PATTERN = re.compile(
|
| 41 |
r"(http|www|win|winner|claim|click|offer|bonus|urgent|verify|password|"
|
| 42 |
r"account|bank|deposit|earn|investment|crypto|btc|telegram|airdrop|giveaway|jackpot|prize)",
|
| 43 |
re.IGNORECASE,
|
| 44 |
)
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
SCAM_ACTION_PATTERN = re.compile(
|
| 47 |
r"(claim|click|prize|reward|link|http|www|money|cash|gift|airdrop|crypto|account|verify|urgent)",
|
| 48 |
re.IGNORECASE,
|
| 49 |
)
|
| 50 |
|
| 51 |
GIVEAWAY_OVERRIDE_PATTERN = re.compile(
|
| 52 |
+
r"(\b(won|winner|jackpot)\b.*\b(prize|reward|gift|voucher|iphone|cash)\b)",
|
|
|
|
| 53 |
re.IGNORECASE,
|
| 54 |
)
|
| 55 |
|
|
|
|
|
|
|
| 56 |
URL_ANY_PATTERN = re.compile(r"(https?://\S+|www\.\S+)", re.IGNORECASE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
+
LINK_SPAM_CUE_PATTERN = re.compile(
|
| 59 |
+
r"(claim|verify|password|bank|urgent|winner|prize|reward|bonus|airdrop|crypto|deposit)",
|
|
|
|
|
|
|
|
|
|
| 60 |
re.IGNORECASE,
|
| 61 |
)
|
| 62 |
|
| 63 |
+
# --------- LOADERS ---------
|
| 64 |
|
| 65 |
def _load_metadata():
|
| 66 |
+
if os.path.exists(METADATA_PATH):
|
| 67 |
+
with open(METADATA_PATH, "r", encoding="utf-8") as f:
|
| 68 |
return json.load(f)
|
| 69 |
|
| 70 |
return {
|
| 71 |
+
"spam_threshold": SPAM_THRESHOLD
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
}
|
| 73 |
|
| 74 |
|
| 75 |
def load_models():
|
|
|
|
| 76 |
global _binary_session, _category_session, _metadata
|
| 77 |
|
| 78 |
if _binary_session is None:
|
| 79 |
+
binary_path = os.path.join(MODEL_DIR, "binary_model.onnx")
|
| 80 |
+
_binary_session = ort.InferenceSession(binary_path)
|
| 81 |
+
|
| 82 |
if _category_session is None:
|
| 83 |
+
category_path = os.path.join(MODEL_DIR, "category_model.onnx")
|
| 84 |
+
_category_session = ort.InferenceSession(category_path)
|
| 85 |
+
|
| 86 |
if _metadata is None:
|
| 87 |
_metadata = _load_metadata()
|
| 88 |
|
| 89 |
|
| 90 |
+
# --------- HELPERS ---------
|
| 91 |
+
|
| 92 |
+
def _contains_url(text):
|
| 93 |
+
return bool(URL_ANY_PATTERN.search(text or ""))
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _effective_threshold(text):
|
| 97 |
+
threshold = float(_metadata.get("spam_threshold", SPAM_THRESHOLD))
|
| 98 |
+
|
| 99 |
+
words = text.split()
|
| 100 |
+
if len(words) <= VERY_SHORT_TEXT_WORD_COUNT:
|
| 101 |
+
threshold = max(threshold, VERY_SHORT_TEXT_THRESHOLD)
|
| 102 |
+
elif len(words) <= SHORT_TEXT_WORD_COUNT:
|
| 103 |
+
threshold = max(threshold, SHORT_TEXT_THRESHOLD)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
return threshold
|
| 106 |
|
| 107 |
|
| 108 |
+
def _split_text(text):
|
| 109 |
+
words = text.split()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
chunks = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
+
for i in range(0, len(words), CHUNK_MAX_WORDS):
|
| 113 |
+
chunk = " ".join(words[i:i + CHUNK_MAX_WORDS])
|
| 114 |
+
chunks.append(chunk)
|
| 115 |
|
| 116 |
+
if len(chunks) >= MAX_CHUNKS:
|
| 117 |
+
break
|
|
|
|
|
|
|
| 118 |
|
| 119 |
+
return chunks
|
|
|
|
|
|
|
| 120 |
|
|
|
|
|
|
|
| 121 |
|
| 122 |
+
# --------- CORE ---------
|
| 123 |
+
|
| 124 |
+
def _predict_single(raw_text, cleaned_text):
|
| 125 |
+
onnx_input = np.array([[cleaned_text]], dtype=object)
|
| 126 |
|
| 127 |
+
# Binary
|
| 128 |
+
inputs = {_binary_session.get_inputs()[0].name: onnx_input}
|
| 129 |
+
outputs = _binary_session.run(None, inputs)
|
| 130 |
+
|
| 131 |
+
probs = outputs[1][0]
|
| 132 |
+
spam_prob = float(probs.get(1, probs.get('1', 0.0)))
|
| 133 |
+
|
| 134 |
+
threshold = _effective_threshold(cleaned_text)
|
| 135 |
+
is_spam = spam_prob >= threshold
|
| 136 |
+
|
| 137 |
+
# Heuristics
|
| 138 |
if not is_spam:
|
| 139 |
+
if GIVEAWAY_OVERRIDE_PATTERN.search(raw_text or ""):
|
|
|
|
|
|
|
| 140 |
is_spam = True
|
| 141 |
|
| 142 |
+
# Category
|
| 143 |
if is_spam:
|
| 144 |
+
cat_inputs = {_category_session.get_inputs()[0].name: onnx_input}
|
| 145 |
+
cat_outputs = _category_session.run(None, cat_inputs)
|
| 146 |
+
category = str(cat_outputs[0][0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
else:
|
| 148 |
category = "normal"
|
| 149 |
|
| 150 |
return {
|
| 151 |
+
"is_spam": is_spam,
|
| 152 |
+
"confidence": spam_prob,
|
| 153 |
+
"category": category,
|
| 154 |
+
"threshold_used": threshold
|
| 155 |
}
|
| 156 |
|
| 157 |
|
| 158 |
+
# --------- PUBLIC API ---------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
+
def predict_message(text):
|
| 161 |
load_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
+
cleaned = preprocess_text(text)
|
| 164 |
+
word_count = len(cleaned.split())
|
|
|
|
|
|
|
| 165 |
|
| 166 |
+
if word_count <= LONG_TEXT_WORD_THRESHOLD:
|
| 167 |
+
pred = _predict_single(text, cleaned)
|
| 168 |
+
pred["confidence"] = round(pred["confidence"], 4)
|
| 169 |
+
return pred
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
+
chunks = _split_text(text)
|
| 172 |
+
|
| 173 |
+
preds = [_predict_single(chunk, preprocess_text(chunk)) for chunk in chunks]
|
| 174 |
+
|
| 175 |
+
best = max(preds, key=lambda x: x["confidence"])
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
return {
|
| 178 |
+
"is_spam": any(p["is_spam"] for p in preds),
|
| 179 |
+
"confidence": round(best["confidence"], 4),
|
| 180 |
+
"category": best["category"] if best["is_spam"] else "normal",
|
| 181 |
+
"threshold_used": best["threshold_used"]
|
| 182 |
}
|
| 183 |
|
| 184 |
|
| 185 |
+
def run_model(text):
|
| 186 |
+
if not text or not isinstance(text, str):
|
| 187 |
+
return {"ok": False, "error": "Invalid input"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
+
result = predict_message(text)
|
| 190 |
|
| 191 |
+
return {
|
| 192 |
+
"ok": True,
|
| 193 |
+
"input": text.strip(),
|
| 194 |
+
"result": result
|
| 195 |
+
}
|