M-Arjun committed on
Commit
b1dcc43
·
verified ·
1 Parent(s): 51272d4

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +110 -368
model.py CHANGED
@@ -6,448 +6,190 @@ from urllib.parse import urlparse
6
  import numpy as np
7
  import onnxruntime as ort
8
 
9
- import config
10
  from utils import preprocess_text
11
 
12
# Lazily populated by load_models(); None means "not loaded yet".
_binary_session = None
_category_session = None
_metadata = None

# Keywords that count as spam signals even in very short messages.
SPAM_HINT_PATTERN = re.compile(
    r"(http|www|win|winner|claim|click|offer|bonus|urgent|verify|password|"
    r"account|bank|deposit|earn|investment|crypto|btc|telegram|airdrop|giveaway|jackpot|prize)",
    flags=re.IGNORECASE,
)

# "won/win/winner" followed later by a sports/competition word.
BENIGN_WIN_CONTEXT_PATTERN = re.compile(
    r"\b(won|win|winner)\b.*\b(match|game|tournament|league|race|finals|team|football|cricket|basketball)\b",
    flags=re.IGNORECASE,
)

# Call-to-action / payout vocabulary; its presence cancels the benign checks.
SCAM_ACTION_PATTERN = re.compile(
    r"(claim|click|prize|reward|link|http|www|money|cash|gift|airdrop|crypto|account|verify|urgent)",
    flags=re.IGNORECASE,
)

# Explicit "you won <prize>" / "claim <prize>" phrasing.
GIVEAWAY_OVERRIDE_PATTERN = re.compile(
    r"(\b(won|winner|jackpot|lucky draw)\b.*\b(lambo|lamboo|prize|reward|gift|voucher|tesla|iphone|cash)\b)|"
    r"(\b(claim|redeem)\b.*\b(prize|reward|gift|voucher)\b)",
    flags=re.IGNORECASE,
)

# Splits after sentence-ending punctuation or on newlines.
SENTENCE_SPLIT_PATTERN = re.compile(r"(?<=[.!?])\s+|\n+")

# A token that is entirely a URL vs. a URL anywhere in the text.
URL_TOKEN_PATTERN = re.compile(r"^(https?://\S+|www\.\S+)$", flags=re.IGNORECASE)
URL_ANY_PATTERN = re.compile(r"(https?://\S+|www\.\S+)", flags=re.IGNORECASE)

# Vocabulary that makes a message containing a link worth classifying.
LINK_SPAM_CUE_PATTERN = re.compile(
    r"(claim|verify|password|bank|urgent|winner|prize|reward|bonus|airdrop|crypto|"
    r"deposit|investment|gift card|limited time|act now|suspended|login|otp|kyc)",
    flags=re.IGNORECASE,
)

# Age-related phrasing that is not adult spam.
BENIGN_ADULT_CONTEXT_PATTERN = re.compile(
    r"(older than 18|over 18|under 18|age requirement|adult supervision|age limit|"
    r"content rating|parental guidance|legal age|years old)",
    flags=re.IGNORECASE,
)

# Everyday software-team vocabulary.
BENIGN_WORK_CONTEXT_PATTERN = re.compile(
    r"(pull request|code review|deployment|sprint|bug fix|qa|release note|"
    r"project update|meeting notes|standup|ticket|merge request|ci pipeline)",
    flags=re.IGNORECASE,
)

# "<Brand>: <short message>" on a single line (case-sensitive, no flags).
SHORT_BRAND_ALERT_PATTERN = re.compile(
    r"^[A-Za-z0-9&'._+-]{2,32}\s*:\s*[^:]{2,90}[.!?]?$"
)

# Income promises typical of job/money scams.
MONEY_JOB_SCAM_PATTERN = re.compile(
    r"(\$\s?\d[\d,]*(?:\.\d+)?\s*/?\s*(day|week|month))|"
    r"(earn\s+\$?\s?\d[\d,]*)|"
    r"(get rich quick)|"
    r"(no experience needed)",
    flags=re.IGNORECASE,
)
72
 
 
73
 
74
def _load_metadata():
    """Return the threshold metadata dict.

    Reads the JSON file at ``config.METADATA_PATH`` when it exists;
    otherwise builds a dict from the threshold defaults in ``config``.
    """
    if not os.path.exists(config.METADATA_PATH):
        # No metadata file: fall back to the values configured in config.
        return {
            "spam_threshold": config.SPAM_THRESHOLD,
            "short_text_word_count": config.SHORT_TEXT_WORD_COUNT,
            "short_text_threshold": config.SHORT_TEXT_THRESHOLD,
            "very_short_text_word_count": config.VERY_SHORT_TEXT_WORD_COUNT,
            "very_short_text_threshold": config.VERY_SHORT_TEXT_THRESHOLD,
        }

    with open(config.METADATA_PATH, "r", encoding="utf-8") as handle:
        return json.load(handle)
86
 
87
 
88
def load_models():
    """Create the ONNX sessions and load metadata, each at most once.

    Every module-level slot that is still ``None`` is filled; slots that
    are already populated are left untouched.
    """
    global _binary_session, _category_session, _metadata

    if _binary_session is None:
        _binary_session = ort.InferenceSession(
            os.path.join(config.MODEL_DIR, "binary_model.onnx"),
            providers=['CPUExecutionProvider'],
        )

    if _category_session is None:
        _category_session = ort.InferenceSession(
            os.path.join(config.MODEL_DIR, "category_model.onnx"),
            providers=['CPUExecutionProvider'],
        )

    if _metadata is None:
        _metadata = _load_metadata()
102
 
103
 
104
def _effective_threshold(raw_text, cleaned_text):
    """Return the spam threshold to apply to one message.

    Starts from the metadata ``spam_threshold`` (falling back to config).
    When the raw text contains no spam-hint keyword, short and very short
    messages get a threshold that is at least the corresponding short-text
    value.
    """
    base = float(_metadata.get("spam_threshold", config.SPAM_THRESHOLD))
    short_limit = int(
        _metadata.get("short_text_word_count", config.SHORT_TEXT_WORD_COUNT)
    )
    short_floor = float(
        _metadata.get("short_text_threshold", config.SHORT_TEXT_THRESHOLD)
    )
    tiny_limit = int(
        _metadata.get("very_short_text_word_count", config.VERY_SHORT_TEXT_WORD_COUNT)
    )
    tiny_floor = float(
        _metadata.get("very_short_text_threshold", config.VERY_SHORT_TEXT_THRESHOLD)
    )

    # Count non-empty pieces after splitting on single spaces.
    n_words = sum(1 for token in cleaned_text.split(" ") if token)

    # A spam hint disables the short-text leniency entirely.
    if SPAM_HINT_PATTERN.search(raw_text or ""):
        return base

    if n_words <= tiny_limit:
        return max(base, tiny_floor)
    if n_words <= short_limit:
        return max(base, short_floor)
    return base
129
 
130
 
131
def _is_benign_win_context(raw_text):
    """Return True when the text reads like a sports/competition win.

    Requires a BENIGN_WIN_CONTEXT_PATTERN match and no
    SCAM_ACTION_PATTERN match; empty or missing text is never benign.
    """
    if not raw_text:
        return False
    if BENIGN_WIN_CONTEXT_PATTERN.search(raw_text) is None:
        return False
    return SCAM_ACTION_PATTERN.search(raw_text) is None
137
-
138
-
139
def _is_benign_context(raw_text):
    """Return True for age-related or workplace phrasing without scam cues.

    Any SCAM_ACTION_PATTERN match, or empty/missing text, yields False.
    """
    if not raw_text or SCAM_ACTION_PATTERN.search(raw_text):
        return False
    benign_patterns = (BENIGN_ADULT_CONTEXT_PATTERN, BENIGN_WORK_CONTEXT_PATTERN)
    return any(pattern.search(raw_text) for pattern in benign_patterns)
147
-
148
-
149
def _extract_url_domains(raw_text: str) -> list[str]:
    """Return the lower-cased host of every URL found in ``raw_text``.

    URLs are located with URL_ANY_PATTERN. Compared with reading
    ``netloc`` directly, this uses ``hostname`` so that ports and
    ``user@`` prefixes do not end up in the result, and it strips
    sentence punctuation that the ``\\S+`` URL match swallows (for
    example the final dot in "see www.example.com."). A leading "m."
    (mobile sub-domain) is removed, as before.

    Args:
        raw_text: Text to scan; empty or ``None`` yields an empty list.

    Returns:
        Hosts in order of appearance (duplicates are kept).
    """
    if not raw_text:
        return []

    domains = []
    for match in URL_ANY_PATTERN.finditer(raw_text):
        # Drop trailing punctuation that belongs to the sentence, not the URL.
        url = match.group(0).strip().rstrip(".,;:!?)\"'")
        if url.lower().startswith("www."):
            # urlparse only fills in the host when a scheme is present.
            url = "https://" + url
        try:
            host = urlparse(url).hostname or ""
        except ValueError:
            # Malformed URL (e.g. unbalanced IPv6 brackets): skip it.
            continue
        host = host.strip().rstrip(".")
        if not host:
            continue
        if host.startswith("m."):
            host = host[2:]
        domains.append(host)
    return domains
168
-
169
-
170
def _has_blocked_domain(raw_text: str) -> bool:
    """Return True when any URL host in the text is blocklisted.

    The blocklist is ``config.BLOCKED_URL_DOMAINS`` (empty when the
    attribute is absent). A host matches when it equals a blocked domain
    or is a sub-domain of one.
    """
    blocked = set(getattr(config, "BLOCKED_URL_DOMAINS", set()))
    if not blocked:
        return False

    # str.endswith accepts a tuple, covering every ".<blocked>" suffix at once.
    subdomain_suffixes = tuple("." + base for base in blocked)
    return any(
        host in blocked or host.endswith(subdomain_suffixes)
        for host in _extract_url_domains(raw_text)
    )
184
-
185
-
186
def _contains_url(raw_text: str) -> bool:
    """Return True when the text contains an http(s):// or www. URL."""
    return URL_ANY_PATTERN.search(raw_text or "") is not None
188
-
189
-
190
def _has_link_spam_cues(raw_text: str) -> bool:
    """Return True when the text contains a LINK_SPAM_CUE_PATTERN keyword."""
    return LINK_SPAM_CUE_PATTERN.search(raw_text or "") is not None
192
-
193
-
194
def _split_long_text(text: str) -> list[str]:
    """Split long text into sentence-aligned chunks for per-chunk scoring.

    Sentences (per SENTENCE_SPLIT_PATTERN) are packed together until the
    next one would push a chunk past ``config.CHUNK_MAX_WORDS`` words
    (default 40). A single sentence longer than that is cut into
    fixed-size word windows. At most ``config.MAX_CHUNKS`` chunks
    (default 24) are returned; any remaining text is dropped.

    Fix over the previous version: sentences already accumulated are now
    flushed *before* an over-long sentence is windowed, so chunks stay in
    document order and non-adjacent sentences are no longer merged.

    Args:
        text: The raw message; empty or ``None`` yields an empty list.

    Returns:
        The chunks, in document order.
    """
    max_words = int(getattr(config, "CHUNK_MAX_WORDS", 40))
    max_chunks = int(getattr(config, "MAX_CHUNKS", 24))

    sentences = [
        part.strip()
        for part in SENTENCE_SPLIT_PATTERN.split(text or "")
        if part.strip()
    ]

    chunks = []
    pending = []        # sentences collected for the chunk being built
    pending_words = 0

    for sentence in sentences:
        words = sentence.split()

        if len(words) > max_words:
            # Emit what came before this sentence first to keep order.
            if pending:
                chunks.append(" ".join(pending))
                pending, pending_words = [], 0
            for start in range(0, len(words), max_words):
                chunks.append(" ".join(words[start:start + max_words]))
        elif pending and pending_words + len(words) > max_words:
            chunks.append(" ".join(pending))
            pending, pending_words = [sentence], len(words)
        else:
            pending.append(sentence)
            pending_words += len(words)

        if len(chunks) >= max_chunks:
            return chunks[:max_chunks]

    if pending:
        chunks.append(" ".join(pending))

    return chunks[:max_chunks]
231
-
232
-
233
def _predict_single(raw_text: str, cleaned_text: str) -> dict:
    """Classify one message (or chunk) and return its prediction dict.

    Order of decisions:
      1. URL shortcuts: a blocklisted domain is spam (0.99); a URL with no
         link-spam cue is normal (0.05). Neither runs the model.
      2. Binary ONNX model probability versus the effective threshold.
      3. Benign-context vetoes, the giveaway override, the "join now"
         exception, then the money/job and brand-alert rescues.
      4. Category from the heuristics, else from the category model.

    Returns:
        Dict with ``is_spam``, ``confidence``, ``category`` and
        ``threshold_used``.
    """

    def shortcut(is_spam, confidence, category):
        # Fixed-confidence result; the base threshold is looked up lazily.
        return {
            "is_spam": is_spam,
            "confidence": confidence,
            "category": category,
            "threshold_used": float(_metadata.get("spam_threshold", config.SPAM_THRESHOLD)),
        }

    if _contains_url(raw_text):
        if _has_blocked_domain(raw_text):
            return shortcut(True, 0.99, "spam")
        if not _has_link_spam_cues(raw_text):
            return shortcut(False, 0.05, "normal")

    # The model takes a single string in a [1, 1] object array.
    features = np.array([[cleaned_text]], dtype=object)

    binary_feed = {_binary_session.get_inputs()[0].name: features}
    # Second output, first row: a mapping from class label to probability;
    # the label key may be the int 1 or the string '1'.
    class_probs = _binary_session.run(None, binary_feed)[1][0]
    spam_prob = float(class_probs.get(1, class_probs.get('1', 0.0)))

    threshold = _effective_threshold(raw_text, cleaned_text)
    is_spam = spam_prob >= threshold

    # Benign contexts veto only moderately confident spam verdicts.
    if is_spam and spam_prob < 0.92 and _is_benign_win_context(raw_text):
        is_spam = False
    if is_spam and spam_prob < 0.85 and _is_benign_context(raw_text):
        is_spam = False

    has_giveaway_override = bool(GIVEAWAY_OVERRIDE_PATTERN.search(raw_text or ""))
    if not is_spam and has_giveaway_override and not _is_benign_win_context(raw_text):
        is_spam = True

    short_brand_alert = bool(SHORT_BRAND_ALERT_PATTERN.match((raw_text or "").strip()))
    money_job_scam = bool(MONEY_JOB_SCAM_PATTERN.search(raw_text or ""))

    if cleaned_text.strip().lower() == "join now":
        is_spam = False

    # Heuristic matches lower the bar the model score has to clear.
    if not is_spam:
        if money_job_scam and spam_prob >= max(0.55, threshold - 0.20):
            is_spam = True
        elif short_brand_alert and spam_prob >= max(0.50, threshold - 0.16):
            is_spam = True

    if not is_spam:
        category = "normal"
    elif money_job_scam:
        category = "job_scam"
    elif short_brand_alert:
        category = "phishing"
    elif has_giveaway_override:
        category = "giveaway"
    else:
        category_feed = {_category_session.get_inputs()[0].name: features}
        category = str(_category_session.run(None, category_feed)[0][0])

    return {
        "is_spam": bool(is_spam),
        "confidence": float(spam_prob),
        "category": str(category),
        "threshold_used": float(threshold),
    }
310
 
311
 
312
def validate_message(text: str) -> tuple[bool, str]:
    """Check that ``text`` is a usable message.

    The previous version ended with two extra checks (no alphanumeric
    characters, length below 2) that both returned ``(True, "")``, the
    same as the final return. They were no-ops and have been removed;
    results are unchanged for every input.

    Args:
        text: Candidate message.

    Returns:
        ``(True, "")`` when valid, otherwise ``(False, reason)``.
    """
    if text is None:
        return False, "Input is required."
    if not isinstance(text, str):
        return False, "Input must be a string."
    if not text.strip():
        return False, "Input cannot be empty."
    # Symbol-only and single-character messages are accepted on purpose.
    return True, ""
325
-
326
 
327
def predict_message(text: str) -> dict:
    """Classify a message, scoring long texts chunk by chunk.

    Texts of at most ``config.LONG_TEXT_WORD_THRESHOLD`` cleaned words
    (default 80) are scored in one pass. Longer texts are split with
    ``_split_long_text``; the message is spam if any chunk is spam.

    Returns:
        Dict with ``is_spam``, ``confidence``, ``category``,
        ``threshold_used`` and ``chunked``; chunked results also carry
        ``chunk_count``.
    """
    load_models()
    cleaned_text = preprocess_text(text)

    def whole_message_result():
        # Single-pass prediction with rounded scores.
        pred = _predict_single(text, cleaned_text)
        return {
            "is_spam": pred["is_spam"],
            "confidence": round(pred["confidence"], 4),
            "category": pred["category"],
            "threshold_used": round(pred["threshold_used"], 4),
            "chunked": False,
        }

    word_count = len([w for w in cleaned_text.split(" ") if w])
    long_threshold = int(getattr(config, "LONG_TEXT_WORD_THRESHOLD", 80))
    if word_count <= long_threshold:
        return whole_message_result()

    chunks = _split_long_text(text)
    if not chunks:
        return whole_message_result()

    chunk_predictions = [
        _predict_single(chunk, preprocess_text(chunk)) for chunk in chunks
    ]

    def by_confidence(prediction):
        return prediction["confidence"]

    highest = max(chunk_predictions, key=by_confidence)
    spam_chunks = [cp for cp in chunk_predictions if cp["is_spam"]]
    is_spam = len(spam_chunks) > 0

    # Category/threshold come from the strongest spam chunk when there is
    # one, otherwise from the highest-scoring chunk overall.
    representative = max(spam_chunks, key=by_confidence) if is_spam else highest

    return {
        "is_spam": bool(is_spam),
        "confidence": round(float(highest["confidence"]), 4),
        "category": str(representative["category"] if is_spam else "normal"),
        "threshold_used": round(float(representative["threshold_used"]), 4),
        "chunked": True,
        "chunk_count": len(chunks),
    }
376
 
 
 
 
 
 
377
 
378
def run_model(text: str) -> dict:
    """Validate ``text`` and return an API-style response envelope.

    Returns:
        On failure: ``{"ok": False, "error": <reason>, "input": text}``.
        On success: ``{"ok": True, "input": <stripped text>,
        "result": <prediction dict>}``.
    """
    is_valid, reason = validate_message(text)
    if is_valid:
        prediction = predict_message(text)
        return {
            "ok": True,
            "input": text.strip(),
            "result": prediction,
        }

    return {
        "ok": False,
        "error": reason,
        "input": text,
    }
392
 
393
 
394
def update_model(text: str, label: int, category: str):
    """Append one feedback record to ``dataset/feedback.jsonl``.

    Does nothing when ``text`` is None. The ``dataset`` directory is
    created on demand, relative to the current working directory.
    """
    if text is None:
        return

    os.makedirs("dataset", exist_ok=True)
    record = {
        "text": text.strip(),
        "label": int(label),
        "category": str(category),
    }
    # One JSON object per line, appended so earlier feedback is kept.
    with open(os.path.join("dataset", "feedback.jsonl"), "a", encoding="utf-8") as out:
        out.write(json.dumps(record, ensure_ascii=True) + "\n")
406
-
407
-
408
def get_model_specs() -> dict:
    """Collect configuration, file-presence and load status for diagnostics.

    Model loading is attempted; a failure is reported in the result
    (``loaded`` False plus ``load_error``) instead of being raised.
    """
    model_dir = config.MODEL_DIR
    file_status = {
        "binary_onnx_exists": os.path.exists(os.path.join(model_dir, "binary_model.onnx")),
        "category_onnx_exists": os.path.exists(os.path.join(model_dir, "category_model.onnx")),
        "metadata_exists": os.path.exists(config.METADATA_PATH),
    }
    specs = {
        "model_dir": model_dir,
        "binary_model_path": config.BINARY_MODEL_PATH,
        "category_model_path": config.CATEGORY_MODEL_PATH,
        "metadata_path": config.METADATA_PATH,
        "spam_threshold": config.SPAM_THRESHOLD,
        "word_max_features": getattr(config, "WORD_MAX_FEATURES", None),
        "char_max_features": getattr(config, "CHAR_MAX_FEATURES", None),
        "files": file_status,
    }

    try:
        load_models()
    except Exception as exc:
        # Diagnostic boundary: surface the failure rather than crash.
        specs["loaded"] = False
        specs["load_error"] = str(exc)
    else:
        specs["loaded"] = True
        specs["runtime_threshold"] = _metadata.get("spam_threshold", config.SPAM_THRESHOLD)

    return specs
431
-
432
-
433
def print_model_specs() -> None:
    """Print a short human-readable summary of ``get_model_specs()``."""
    specs = get_model_specs()
    files = specs['files']

    print("Model Specs (ONNX)")
    print(f"- Model dir: {specs['model_dir']}")
    print(f"- Base threshold: {specs['spam_threshold']}")
    print(
        "- Files exist: "
        f"binary_onnx={files['binary_onnx_exists']}, "
        f"category_onnx={files['category_onnx_exists']}, "
        f"metadata={files['metadata_exists']}"
    )

    if not specs.get("loaded"):
        print("- Loaded: False")
        print(f"- Load error: {specs.get('load_error', 'unknown error')}")
        return

    print("- Loaded: True")
    print(f"- Runtime threshold: {specs['runtime_threshold']}")


# Running the module directly prints the diagnostics.
if __name__ == "__main__":
    print_model_specs()
 
 
 
 
6
  import numpy as np
7
  import onnxruntime as ort
8
 
 
9
  from utils import preprocess_text
10
 
11
# --------- DEFAULT CONFIG ---------

# Directory that holds binary_model.onnx and category_model.onnx.
MODEL_DIR = "models"

# Base probability at or above which a message counts as spam.
SPAM_THRESHOLD = 0.5

# Short messages need a higher score before they are flagged.
SHORT_TEXT_WORD_COUNT = 6
SHORT_TEXT_THRESHOLD = 0.65

VERY_SHORT_TEXT_WORD_COUNT = 3
VERY_SHORT_TEXT_THRESHOLD = 0.75

# Messages with more cleaned words than this are scored chunk by chunk.
LONG_TEXT_WORD_THRESHOLD = 80

CHUNK_MAX_WORDS = 40
MAX_CHUNKS = 24

BLOCKED_URL_DOMAINS = set()

# NOTE(review): this resolves to "metadata.json" at the filesystem root,
# not next to the models in MODEL_DIR -- confirm that is intended.
METADATA_PATH = os.path.join("/", "metadata.json")

# --------- GLOBALS ---------

# Lazily populated by load_models(); None means "not loaded yet".
_binary_session = None
_category_session = None
_metadata = None

# --------- REGEX ---------

# Keywords treated as spam signals.
SPAM_HINT_PATTERN = re.compile(
    r"(http|www|win|winner|claim|click|offer|bonus|urgent|verify|password|"
    r"account|bank|deposit|earn|investment|crypto|btc|telegram|airdrop|giveaway|jackpot|prize)",
    flags=re.IGNORECASE,
)

# Call-to-action / payout vocabulary.
SCAM_ACTION_PATTERN = re.compile(
    r"(claim|click|prize|reward|link|http|www|money|cash|gift|airdrop|crypto|account|verify|urgent)",
    flags=re.IGNORECASE,
)

# Explicit "you won <prize>" phrasing; forces a spam verdict.
GIVEAWAY_OVERRIDE_PATTERN = re.compile(
    r"(\b(won|winner|jackpot)\b.*\b(prize|reward|gift|voucher|iphone|cash)\b)",
    flags=re.IGNORECASE,
)

# An http(s):// or www. URL anywhere in the text.
URL_ANY_PATTERN = re.compile(r"(https?://\S+|www\.\S+)", flags=re.IGNORECASE)

# Vocabulary that makes a message containing a link suspicious.
LINK_SPAM_CUE_PATTERN = re.compile(
    r"(claim|verify|password|bank|urgent|winner|prize|reward|bonus|airdrop|crypto|deposit)",
    flags=re.IGNORECASE,
)
62
 
63
+ # --------- LOADERS ---------
64
 
65
def _load_metadata():
    """Return the metadata dict used for thresholds.

    Reads the JSON file at METADATA_PATH when it exists; otherwise
    returns a dict holding only the default ``spam_threshold``.
    """
    if not os.path.exists(METADATA_PATH):
        return {
            "spam_threshold": SPAM_THRESHOLD
        }

    with open(METADATA_PATH, "r", encoding="utf-8") as handle:
        return json.load(handle)
73
 
74
 
75
def load_models():
    """Create the ONNX sessions and load metadata, each at most once.

    Every module-level slot that is still ``None`` is filled; populated
    slots are left untouched, so repeated calls are cheap.

    The execution provider is pinned to CPU explicitly. ONNX Runtime
    builds that ship additional providers require the ``providers``
    argument and raise if it is omitted.

    Raises:
        Whatever ``onnxruntime.InferenceSession`` or ``_load_metadata``
        raise, e.g. when a model file is missing or invalid.
    """
    global _binary_session, _category_session, _metadata

    if _binary_session is None:
        binary_path = os.path.join(MODEL_DIR, "binary_model.onnx")
        _binary_session = ort.InferenceSession(
            binary_path, providers=["CPUExecutionProvider"]
        )

    if _category_session is None:
        category_path = os.path.join(MODEL_DIR, "category_model.onnx")
        _category_session = ort.InferenceSession(
            category_path, providers=["CPUExecutionProvider"]
        )

    if _metadata is None:
        _metadata = _load_metadata()
88
 
89
 
90
+ # --------- HELPERS ---------
91
+
92
def _contains_url(text):
    """Return True when ``text`` contains an http(s):// or www. URL."""
    return URL_ANY_PATTERN.search(text or "") is not None
94
+
95
+
96
def _effective_threshold(text):
    """Return the spam threshold to apply to ``text``.

    The base threshold is raised for short messages: at or below the
    very-short word count it is at least the very-short threshold, and at
    or below the short word count at least the short threshold.

    Each setting is read from the loaded metadata and falls back to the
    module constant, the same way ``spam_threshold`` already was; before,
    the short-text settings in a metadata file were silently ignored.
    Missing metadata (``load_models`` not called yet) now means "use the
    defaults" instead of raising AttributeError.

    Args:
        text: The (cleaned) message; ``None`` counts as zero words.

    Returns:
        The threshold as a float.
    """
    meta = _metadata or {}

    threshold = float(meta.get("spam_threshold", SPAM_THRESHOLD))
    very_short_count = int(
        meta.get("very_short_text_word_count", VERY_SHORT_TEXT_WORD_COUNT)
    )
    very_short_threshold = float(
        meta.get("very_short_text_threshold", VERY_SHORT_TEXT_THRESHOLD)
    )
    short_count = int(meta.get("short_text_word_count", SHORT_TEXT_WORD_COUNT))
    short_threshold = float(
        meta.get("short_text_threshold", SHORT_TEXT_THRESHOLD)
    )

    word_count = len((text or "").split())
    if word_count <= very_short_count:
        return max(threshold, very_short_threshold)
    if word_count <= short_count:
        return max(threshold, short_threshold)
    return threshold
106
 
107
 
108
def _split_text(text, max_words=None, max_chunks=None):
    """Split ``text`` into consecutive chunks of whitespace-separated words.

    Args:
        text: Text to split; ``None`` or empty yields an empty list.
        max_words: Words per chunk. Defaults to CHUNK_MAX_WORDS.
        max_chunks: Stop once this many chunks exist; the remaining text
            is dropped. Defaults to MAX_CHUNKS.

    Returns:
        The chunks in order, each with its words joined by single spaces.

    Raises:
        ValueError: If ``max_words`` is smaller than 1.
    """
    # Resolve defaults at call time so changes to the constants apply.
    if max_words is None:
        max_words = CHUNK_MAX_WORDS
    if max_chunks is None:
        max_chunks = MAX_CHUNKS
    if max_words < 1:
        raise ValueError("max_words must be a positive integer")

    words = (text or "").split()
    chunks = []

    for start in range(0, len(words), max_words):
        chunks.append(" ".join(words[start:start + max_words]))
        # Cap the amount of model work for very long inputs.
        if len(chunks) >= max_chunks:
            break

    return chunks
 
 
120
 
 
 
121
 
122
+ # --------- CORE ---------
123
+
124
def _predict_single(raw_text, cleaned_text):
    """Classify one message (or chunk) with the ONNX models.

    The binary model scores ``cleaned_text``; the message is spam when
    the score reaches the effective threshold or when ``raw_text``
    matches GIVEAWAY_OVERRIDE_PATTERN. Spam is then labelled by the
    category model.

    Returns:
        Dict with ``is_spam``, ``confidence``, ``category`` and
        ``threshold_used``.
    """
    # The models take a single string in a [1, 1] object array.
    features = np.array([[cleaned_text]], dtype=object)

    # Binary
    binary_feed = {_binary_session.get_inputs()[0].name: features}
    # Second output, first row: a mapping from class label to probability;
    # the label key may be the int 1 or the string '1'.
    class_probs = _binary_session.run(None, binary_feed)[1][0]
    spam_prob = float(class_probs.get(1, class_probs.get('1', 0.0)))

    threshold = _effective_threshold(cleaned_text)

    # Heuristics: the giveaway pattern is only consulted when the score
    # alone did not already flag the message.
    is_spam = spam_prob >= threshold or bool(
        GIVEAWAY_OVERRIDE_PATTERN.search(raw_text or "")
    )

    # Category
    category = "normal"
    if is_spam:
        category_feed = {_category_session.get_inputs()[0].name: features}
        category = str(_category_session.run(None, category_feed)[0][0])

    return {
        "is_spam": is_spam,
        "confidence": spam_prob,
        "category": category,
        "threshold_used": threshold
    }
156
 
157
 
158
+ # --------- PUBLIC API ---------
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
def predict_message(text):
    """Classify a message, scoring long texts chunk by chunk.

    Texts of at most LONG_TEXT_WORD_THRESHOLD cleaned words are scored in
    one pass. Longer texts are split with ``_split_text``; the message is
    spam if any chunk is spam.

    Fixes over the previous version:
      * Category and threshold now come from the strongest *spam* chunk.
        Before, they came from the highest-confidence chunk overall, so a
        message could be returned as ``is_spam=True`` with category
        ``"normal"`` when a lower-scoring chunk was the one flagged
        (for example through the giveaway override).
      * An empty chunk list falls back to a single-pass prediction
        instead of raising from ``max()`` on an empty sequence.

    Args:
        text: Raw message text.

    Returns:
        Dict with ``is_spam``, ``confidence`` (rounded to 4 decimals),
        ``category`` and ``threshold_used``.
    """
    load_models()

    cleaned = preprocess_text(text)
    word_count = len(cleaned.split())

    chunks = _split_text(text) if word_count > LONG_TEXT_WORD_THRESHOLD else []
    if not chunks:
        pred = _predict_single(text, cleaned)
        pred["confidence"] = round(pred["confidence"], 4)
        return pred

    preds = [_predict_single(chunk, preprocess_text(chunk)) for chunk in chunks]

    def by_confidence(pred):
        return pred["confidence"]

    highest = max(preds, key=by_confidence)
    spam_preds = [p for p in preds if p["is_spam"]]
    # Describe the verdict with the chunk that actually produced it.
    representative = max(spam_preds, key=by_confidence) if spam_preds else highest

    return {
        "is_spam": bool(spam_preds),
        "confidence": round(highest["confidence"], 4),
        "category": representative["category"] if spam_preds else "normal",
        "threshold_used": representative["threshold_used"]
    }
183
 
184
 
185
def run_model(text):
    """Validate ``text`` and return an API-style response envelope.

    Whitespace-only strings are now rejected as well. Before, only the
    empty string was caught, so input such as ``"   "`` was sent to the
    model and echoed back as an empty ``input``.

    Args:
        text: Candidate message.

    Returns:
        ``{"ok": False, "error": "Invalid input"}`` when ``text`` is not
        a non-blank string; otherwise ``{"ok": True, "input": <stripped
        text>, "result": <prediction dict>}``.
    """
    if not isinstance(text, str) or not text.strip():
        return {"ok": False, "error": "Invalid input"}

    result = predict_message(text)

    return {
        "ok": True,
        "input": text.strip(),
        "result": result
    }