Zynara committed on
Commit 90e7d34 · verified · 1 Parent(s): ffb81a0

Rename README.md to Main.js

Files changed (2)
  1. Main.js +990 -0
  2. README.md +0 -3
Main.js ADDED
@@ -0,0 +1,990 @@
+ import os
+ import time
+ import uuid
+ import json
+ import math
+ import shlex
+ import tempfile
+ import hashlib
+ import subprocess
+ from threading import Thread
+ from typing import List, Dict, Any, Tuple, Optional
+ from dotenv import load_dotenv
+
+ load_dotenv()  # loads .env from the same directory
+
+ from fastapi import FastAPI, UploadFile, File, WebSocket, BackgroundTasks, Depends, Header, HTTPException, Response
+ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+
+ # ===============================
+ # Optional heavy deps (guarded)
+ # ===============================
+ try:
+     import torch
+ except Exception:
+     torch = None
+
+ try:
+     from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
+ except Exception:
+     AutoTokenizer = AutoModelForCausalLM = TextIteratorStreamer = BitsAndBytesConfig = None
+
+ try:
+     from sentence_transformers import SentenceTransformer
+ except Exception:
+     SentenceTransformer = None
+
+ try:
+     import chromadb
+ except Exception:
+     chromadb = None
+
+ try:
+     from supabase import Client, create_client
+ except Exception:
+     create_client = None
+
+ try:
+     import redis as redis_lib
+ except Exception:
+     redis_lib = None
+
+ try:
+     from faster_whisper import WhisperModel
+ except Exception:
+     WhisperModel = None
+
+ try:
+     from TTS.api import TTS as CoquiTTS
+ except Exception:
+     CoquiTTS = None
+
+ try:
+     from transformers import Blip2Processor, Blip2ForConditionalGeneration
+ except Exception:
+     Blip2Processor = Blip2ForConditionalGeneration = None
+
+ try:
+     from PIL import Image
+ except Exception:
+     Image = None
+
+ try:
+     from audiocraft.models import musicgen as musicgen_lib
+ except Exception:
+     musicgen_lib = None
+
+ try:
+     from openai import OpenAI
+ except Exception:
+     OpenAI = None
+
+ try:
+     from duckduckgo_search import ddg as ddg_func
+ except Exception:
+     ddg_func = None
+
+ # Prometheus (optional)
+ try:
+     from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
+ except Exception:
+     generate_latest = CONTENT_TYPE_LATEST = None
+
+ # ===============================
+ # Environment / Config
+ # ===============================
+ OPENAI_KEY = os.getenv("OPENAI_API_KEY")
+ SUPABASE_URL = os.getenv("SUPABASE_URL")
+ SUPABASE_KEY = os.getenv("SUPABASE_KEY")
+ HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
+ DEFAULT_MODEL = os.getenv("MODEL_ID", "meta-llama/Meta-Llama-3.1-8B-Instruct")
+ EMBED_MODEL = os.getenv("EMBED_MODEL", "all-MiniLM-L6-v2")
+ FRONTEND_API_KEY = os.getenv("FRONTEND_API_KEY", "changeme")
+ ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "http://localhost:3000,https://zynara.xyz")
+ REDIS_URL = os.getenv("REDIS_URL")
+ ASR_MODEL_SIZE = os.getenv("ASR_MODEL", "small")
+ COQUI_TTS_MODEL = os.getenv("COQUI_TTS_MODEL")  # optional, else auto-pick
+ CDN_BASE_URL = os.getenv("CDN_BASE_URL")
+ DISABLE_MULTIMODAL = os.getenv("DISABLE_MULTIMODAL", "0") == "1"
+
+ # ===============================
+ # App + CORS
+ # ===============================
+ app = FastAPI(title="Billy AI — All-in-one", version="1.0.0")
+ origins = [o.strip() for o in ALLOWED_ORIGINS.split(",") if o.strip()]
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=origins or ["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # ===============================
+ # Clients (OpenAI, Redis, Chroma, Supabase)
+ # ===============================
+ openai_client = OpenAI(api_key=OPENAI_KEY) if (OPENAI_KEY and OpenAI) else None
+
+ redis_client = None
+ if redis_lib and REDIS_URL:
+     try:
+         redis_client = redis_lib.from_url(REDIS_URL)
+         print("✅ Redis connected")
+     except Exception as e:
+         print("⚠️ Redis init failed:", e)
+
+ embedder = None
+ if SentenceTransformer is not None:
+     try:
+         embedder = SentenceTransformer(EMBED_MODEL)
+         print("✅ Embedder loaded")
+     except Exception as e:
+         print("⚠️ Embedder init failed:", e)
+
+ chroma_client = None
+ chroma_collection = None
+ if chromadb is not None:
+     try:
+         chroma_client = chromadb.PersistentClient(path="./billy_rag_db")
+         try:
+             chroma_collection = chroma_client.get_collection("billy_rag")
+         except Exception:
+             chroma_collection = chroma_client.create_collection("billy_rag")
+         print("✅ Chroma ready")
+     except Exception as e:
+         print("⚠️ Chroma init failed:", e)
+
+ supabase_client = None
+ if create_client and SUPABASE_URL and SUPABASE_KEY:
+     try:
+         supabase_client = create_client(SUPABASE_URL, SUPABASE_KEY)
+         print("✅ Supabase client initialized")
+     except Exception as e:
+         print("⚠️ Supabase init failed:", e)
+
+ # ===============================
+ # Helpers (IDs, cosine, cache/RL, moderation)
+ # ===============================
+ def _stable_id(text: str) -> str:
+     return hashlib.sha1(text.encode("utf-8")).hexdigest()
+
+ def _cosine(a: List[float], b: List[float]) -> float:
+     import numpy as np
+     a = np.array(a, dtype=np.float32)
+     b = np.array(b, dtype=np.float32)
+     na = np.linalg.norm(a) or 1.0
+     nb = np.linalg.norm(b) or 1.0
+     return float(np.dot(a, b) / (na * nb))
+
+ async def api_key_auth(x_api_key: Optional[str] = Header(None)):
+     if x_api_key is None:
+         if FRONTEND_API_KEY == "changeme":
+             return True
+         raise HTTPException(status_code=401, detail="Missing API key")
+     if x_api_key != FRONTEND_API_KEY:
+         raise HTTPException(status_code=401, detail="Invalid API key")
+     return True
+
+ def rate_limit(key: str, limit: int = 60, window: int = 60) -> bool:
+     if not redis_client:
+         return True
+     try:
+         p = redis_client.pipeline()
+         p.incr(key)
+         p.expire(key, window)
+         val, _ = p.execute()
+         return int(val) <= limit
+     except Exception:
+         return True
+
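+ # Note: this is an approximate fixed-window limiter; because EXPIRE is re-issued
+ # on every INCR, a steady stream of requests keeps extending the window. A Lua
+ # script or token-bucket scheme would enforce stricter limits if that matters.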
+ def cache_get(key: str):
+     if not redis_client:
+         return None
+     try:
+         v = redis_client.get(key)
+         return json.loads(v) if v else None
+     except Exception:
+         return None
+
+ def cache_set(key: str, value, ttl: int = 300):
+     if not redis_client:
+         return
+     try:
+         redis_client.set(key, json.dumps(value), ex=ttl)
+     except Exception:
+         pass
+
+ def is_safe_message(text: str) -> Tuple[bool, str]:
+     if not text:
+         return True, ""
+     if openai_client is None:
+         # very simple heuristic fallback
+         banned = ["kill", "terror", "bomb", "nuke"]
+         if any(b in text.lower() for b in banned):
+             return False, "Blocked by local safety heuristic."
+         return True, ""
+     try:
+         resp = openai_client.moderations.create(model="omni-moderation-latest", input=text)
+         flagged = bool(resp.results[0].flagged)
+         return (not flagged), ("Blocked by moderation." if flagged else "")
+     except Exception:
+         return True, ""
+
+ # ===============================
+ # RAG storage (Chroma/Supabase/memory)
+ # ===============================
+ memory_store: List[Dict[str, Any]] = []
+
+ def embed_text_local(text: str) -> List[float]:
+     if not embedder:
+         raise RuntimeError("Embedder not loaded.")
+     return embedder.encode(text).tolist()
+
+ def store_knowledge(text: str, user_id: Optional[str] = None):
+     if not text or not text.strip():
+         return
+     try:
+         vec = embed_text_local(text)
+     except Exception:
+         return
+     idx = _stable_id(text)
+     if supabase_client:
+         try:
+             row = {"id": idx, "text": text, "embedding": vec, "source": "user", "created_at": int(time.time())}
+             if user_id:
+                 row["user_id"] = user_id
+             supabase_client.table("knowledge").upsert(row).execute()
+             return
+         except Exception:
+             pass
+     if chroma_collection:
+         try:
+             chroma_collection.add(documents=[text], embeddings=[vec], ids=[idx], metadatas=[{"user_id": user_id}])
+             return
+         except Exception:
+             pass
+     memory_store.append({"text": text, "embedding": vec, "user_id": user_id})
+
+ def retrieve_knowledge(query: str, k: int = 5) -> str:
+     try:
+         qvec = embed_text_local(query)
+     except Exception:
+         return ""
+     if supabase_client:
+         try:
+             resp = supabase_client.table("knowledge").select("text,embedding").execute()
+             data = resp.data or []
+             scored = []
+             for item in data:
+                 emb = item.get("embedding")
+                 if isinstance(emb, list):
+                     scored.append((item["text"], _cosine(qvec, emb)))
+             scored.sort(key=lambda x: x[1], reverse=True)
+             return " ".join([t for t, _ in scored[:k]])
+         except Exception:
+             pass
+     if chroma_collection:
+         try:
+             res = chroma_collection.query(query_embeddings=[qvec], n_results=k)
+             docs = res.get("documents", [])
+             if docs and docs[0]:
+                 return " ".join(docs[0])
+         except Exception:
+             pass
+     if memory_store:
+         scored = []
+         for item in memory_store:
+             scored.append((item["text"], _cosine(qvec, item["embedding"])))
+         scored.sort(key=lambda x: x[1], reverse=True)
+         return " ".join([t for t, _ in scored[:k]])
+     return ""
+
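+ # Note: the Supabase path above fetches every row and scores it client-side,
+ # which is O(n) per query; a pgvector similarity search would scale better
+ # once the knowledge table grows.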
+ def delete_memory_by_id(mem_id: str) -> bool:
+     global memory_store
+     ok = False
+     if supabase_client:
+         try:
+             supabase_client.table("knowledge").delete().eq("id", mem_id).execute()
+             ok = True
+         except Exception:
+             pass
+     if chroma_collection:
+         try:
+             chroma_collection.delete(ids=[mem_id])
+             ok = True
+         except Exception:
+             pass
+     before = len(memory_store)
+     memory_store = [m for m in memory_store if _stable_id(m.get("text", "")) != mem_id]
+     return ok or (len(memory_store) < before)
+
+ # ===============================
+ # Tools & Agent
+ # ===============================
+
+ def save_media_to_supabase(path: str, media_type: str, prompt: str = "") -> Optional[str]:
+     if not supabase_client:
+         return None
+     try:
+         file_name = os.path.basename(path)
+         bucket = "generated_media"
+         with open(path, "rb") as f:
+             supabase_client.storage.from_(bucket).upload(file_name, f, {"upsert": True})
+         return f"{SUPABASE_URL}/storage/v1/object/public/{bucket}/{file_name}"
+     except Exception as e:
+         print("⚠️ save_media_to_supabase failed:", e)
+         return None
+
+ class Tool:
+     name: str
+     description: str
+     def run(self, args: str) -> Dict[str, Any]:
+         raise NotImplementedError
+
+ TOOLS: Dict[str, Tool] = {}
+
+ def register_tool(tool: Tool):
+     TOOLS[tool.name] = tool
+
+ def call_tool(name: str, args: str) -> Dict[str, Any]:
+     tool = TOOLS.get(name)
+     if not tool:
+         return {"ok": False, "error": f"Tool '{name}' not found"}
+     start = time.time()
+     try:
+         res = tool.run(args)
+         return {"ok": True, "result": res, "runtime": time.time() - start}
+     except Exception as e:
+         return {"ok": False, "error": str(e), "runtime": time.time() - start}
+
+ class Calculator(Tool):
+     name = "calculator"
+     description = "Evaluate math expressions using Python's math (e.g., '2+2', 'sin(1)')."
+     def run(self, args: str) -> Dict[str, Any]:
+         allowed = {k: getattr(math, k) for k in dir(math) if not k.startswith("_")}
+         allowed.update({"abs": abs, "round": round, "min": min, "max": max})
+         expr = args.strip()
+         if "__" in expr:
+             raise ValueError("Invalid expression")
+         val = eval(expr, {"__builtins__": {}}, allowed)
+         return {"input": expr, "value": val}
+
+ register_tool(Calculator())
+
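+ # Example: call_tool("calculator", "sin(1) + 2")
+ # -> {"ok": True, "result": {"input": "sin(1) + 2", "value": 2.8414...}, "runtime": ...}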
+ class PythonSandbox(Tool):
+     name = "python_sandbox"
+     description = "Run a short Python script in a subprocess (timeout 2s). For production, isolate via container."
+     def run(self, args: str) -> Dict[str, Any]:
+         code = args
+         with tempfile.TemporaryDirectory() as td:
+             path = os.path.join(td, "script.py")
+             with open(path, "w") as f:
+                 f.write(code)
+             cmd = f"timeout 2 python3 {shlex.quote(path)}"
+             proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+             try:
+                 out, err = proc.communicate(timeout=4)
+             except subprocess.TimeoutExpired:
+                 proc.kill()
+                 return {"stdout": "", "stderr": "Execution timed out", "returncode": 124}
+             return {"stdout": out.decode("utf-8", errors="ignore"), "stderr": err.decode("utf-8", errors="ignore"), "returncode": proc.returncode}
+
+ register_tool(PythonSandbox())
+
+ class WebSearchTool(Tool):
+     name = "web_search"
+     description = "DuckDuckGo search. Returns top snippets (no links)."
+     def run(self, args: str) -> Dict[str, Any]:
+         if not ddg_func:
+             return {"error": "duckduckgo-search not installed"}
+         q = args.strip()
+         try:
+             results = ddg_func(q, max_results=3)
+         except TypeError:
+             results = ddg_func(keywords=q, max_results=3)
+         snippets = []
+         for r in results or []:
+             snippets.append(r.get("body") or r.get("snippet") or r.get("title") or "")
+         return {"query": q, "snippets": [s for s in snippets if s]}
+
+ register_tool(WebSearchTool())
+
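+ # Note: newer duckduckgo_search releases dropped the `ddg` helper in favour of
+ # the DDGS class (e.g. DDGS().text(query, max_results=3)); on those versions the
+ # guarded import above leaves ddg_func as None and this tool reports itself missing.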
+ class NodeSandbox(Tool):
+     name = "node_sandbox"
+     description = "Run short JavaScript code via Node.js (timeout 2s)."
+     def run(self, args: str) -> Dict[str, Any]:
+         code = args
+         with tempfile.TemporaryDirectory() as td:
+             path = os.path.join(td, "script.js")
+             with open(path, "w") as f:
+                 f.write(code)
+             cmd = f"timeout 2 node {shlex.quote(path)}"
+             proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+             try:
+                 out, err = proc.communicate(timeout=3)
+             except subprocess.TimeoutExpired:
+                 proc.kill()
+                 return {"stdout": "", "stderr": "Execution timed out", "returncode": 124}
+             return {"stdout": out.decode("utf-8", errors="ignore"), "stderr": err.decode("utf-8", errors="ignore"), "returncode": proc.returncode}
+
+ register_tool(NodeSandbox())
+
+ class BashSandbox(Tool):
+     name = "bash_sandbox"
+     description = "Run safe shell commands (timeout 2s)."
+     def run(self, args: str) -> Dict[str, Any]:
+         cmd = args.strip()
+         if ";" in cmd or "&&" in cmd or "|" in cmd:
+             return {"stdout": "", "stderr": "Unsafe command detected", "returncode": 1}
+         safe_cmd = shlex.split(cmd)
+         proc = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
+         try:
+             out, err = proc.communicate(timeout=2)
+         except subprocess.TimeoutExpired:
+             proc.kill()
+             return {"stdout": "", "stderr": "Execution timed out", "returncode": 124}
+         return {"stdout": out.decode("utf-8", errors="ignore"), "stderr": err.decode("utf-8", errors="ignore"), "returncode": proc.returncode}
+
+ register_tool(BashSandbox())
+
+ def agent_run(llm_func, system_prompt: str, user_prompt: str, chat_history: Optional[List[Tuple[str, str]]] = None, max_steps: int = 4):
+     chat_history = chat_history or []  # accepted for API symmetry; not yet folded into the context
+     tools_info = "\n".join([f"{name}: {TOOLS[name].description}" for name in TOOLS])
+     agent_hdr = (
+         f"{system_prompt}\n\nAvailable tools:\n{tools_info}\n\n"
+         "When you want to call a tool respond ONLY with a JSON object:\n"
+         '{"action":"tool_name","args":"..."}\n'
+         'When finished respond: {"action":"final","answer":"..."}\n'
+     )
+     context = agent_hdr + f"\nUser: {user_prompt}\n"
+     for _ in range(max_steps):
+         model_out = llm_func(context)
+         try:
+             first_line = model_out.strip().splitlines()[0]
+             action_obj = json.loads(first_line)
+         except Exception:
+             return {"final": model_out}
+         act = action_obj.get("action")
+         if act == "final":
+             return {"final": action_obj.get("answer", "")}
+         args = action_obj.get("args", "")
+         tool_res = call_tool(act, args)
+         context += f"\nToolCall: {json.dumps({'tool': act, 'args': args})}\nToolResult: {json.dumps(tool_res)}\n"
+     return {"final": "Max steps reached. Partial reasoning returned.", "context": context}
+
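+ # Illustrative exchange: the model first emits {"action": "calculator", "args": "2+2"},
+ # the loop appends ToolCall/ToolResult lines to the context, and a later turn
+ # ends with {"action": "final", "answer": "4"}.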
+ # ===============================
+ # Multimodal (ASR / TTS / Vision / Music)
+ # ===============================
+ ASR_MODEL = None
+ def init_asr():
+     global ASR_MODEL
+     if DISABLE_MULTIMODAL or WhisperModel is None:
+         return None
+     if ASR_MODEL is None:
+         try:
+             device = "cuda" if (torch and torch.cuda.is_available()) else "cpu"
+             compute_type = "float16" if device == "cuda" else "int8"
+             ASR_MODEL = WhisperModel(ASR_MODEL_SIZE, device=device, compute_type=compute_type)
+             print(f"✅ ASR model loaded: {ASR_MODEL_SIZE} on {device}")
+         except Exception as e:
+             print("⚠️ ASR init failed:", e)
+             ASR_MODEL = None
+     return ASR_MODEL
+
+ def transcribe_audio(path: str, language: Optional[str] = None, task: str = "transcribe"):
+     if init_asr() is None:
+         return {"text": "asr-disabled"}
+     segments, info = ASR_MODEL.transcribe(path, language=language, task=task)
+     text = " ".join(seg.text for seg in segments)
+     return {"text": text, "duration": getattr(info, "duration", None)}
+
+ TTS_CLIENT = None
+ def init_tts():
+     global TTS_CLIENT
+     if DISABLE_MULTIMODAL or CoquiTTS is None:
+         return None
+     if TTS_CLIENT is None:
+         try:
+             TTS_CLIENT = CoquiTTS(model_name=COQUI_TTS_MODEL) if COQUI_TTS_MODEL else CoquiTTS()
+             print("✅ Coqui TTS initialized")
+         except Exception as e:
+             print("⚠️ TTS init failed:", e)
+             TTS_CLIENT = None
+     return TTS_CLIENT
+
+ def synthesize_to_file(text: str, voice: Optional[str] = None, out_path: Optional[str] = None):
+     out_path = out_path or f"/tmp/tts_{uuid.uuid4().hex}.mp3"
+     if init_tts() is None:
+         open(out_path, "wb").close()
+         return {"path": out_path}
+     try:
+         # Some models require specific speaker names; None often works with single-speaker
+         TTS_CLIENT.tts_to_file(text=text, speaker=voice, file_path=out_path)
+     except Exception as e:
+         print("⚠️ TTS synthesis failed:", e)
+         open(out_path, "wb").close()
+     return {"path": out_path}
+
+ BLIP_PROCESSOR = BLIP_MODEL = None
+ def init_vision():
+     global BLIP_PROCESSOR, BLIP_MODEL
+     if DISABLE_MULTIMODAL or (Blip2Processor is None or Blip2ForConditionalGeneration is None or Image is None):
+         return None, None
+     if BLIP_MODEL is None:
+         try:
+             model_name = "Salesforce/blip2-flan-t5-base"
+             BLIP_PROCESSOR = Blip2Processor.from_pretrained(model_name)
+             BLIP_MODEL = Blip2ForConditionalGeneration.from_pretrained(model_name)
+             device = "cuda" if (torch and torch.cuda.is_available()) else "cpu"
+             BLIP_MODEL.to(device)
+             print(f"✅ BLIP-2 loaded on {device}")
+         except Exception as e:
+             print("⚠️ Vision init failed:", e)
+             BLIP_PROCESSOR = BLIP_MODEL = None
+     return BLIP_PROCESSOR, BLIP_MODEL
+
+ def caption_image(path: str) -> str:
+     proc, model = init_vision()
+     if not proc or not model:
+         return "A photo (caption placeholder)."
+     device = "cuda" if (torch and torch.cuda.is_available()) else "cpu"
+     img = Image.open(path).convert("RGB")
+     inputs = proc(images=img, return_tensors="pt")
+     for k in inputs:
+         inputs[k] = inputs[k].to(device)
+     out_ids = model.generate(**inputs, max_new_tokens=64)
+     return proc.decode(out_ids[0], skip_special_tokens=True)
+
+ def ocr_image(path: str) -> str:
+     # Placeholder (integrate easyocr or pytesseract as needed)
+     return "OCR placeholder text."
+
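+ # A minimal OCR sketch, assuming pytesseract and a system tesseract binary:
+ #     import pytesseract
+ #     return pytesseract.image_to_string(Image.open(path))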
+ MUSIC_MODEL = None
+ def init_music():
+     global MUSIC_MODEL
+     if DISABLE_MULTIMODAL or musicgen_lib is None:
+         return None
+     if MUSIC_MODEL is None:
+         try:
+             device = "cuda" if (torch and torch.cuda.is_available()) else "cpu"
+             # MusicGen takes the device at load time; the wrapper has no .to()
+             MUSIC_MODEL = musicgen_lib.MusicGen.get_pretrained("melody", device=device)
+             print(f"✅ MusicGen loaded on {device}")
+         except Exception as e:
+             print("⚠️ Music init failed:", e)
+             MUSIC_MODEL = None
+     return MUSIC_MODEL
+
+ def generate_music(prompt: str, duration: int = 20) -> Dict[str, Any]:
+     out = f"/tmp/music_{uuid.uuid4().hex}.wav"
+     if init_music() is None:
+         open(out, "wb").close()
+         return {"path": out}
+     try:
+         # duration is configured via set_generation_params, not generate()
+         MUSIC_MODEL.set_generation_params(duration=duration)
+         wav = MUSIC_MODEL.generate([prompt])
+         # audiocraft's write helper changed over time; safest: torchaudio or soundfile
+         try:
+             import torchaudio
+             torchaudio.save(out, wav[0].cpu(), getattr(MUSIC_MODEL, "sample_rate", 32000))
+         except Exception:
+             # fallback empty file
+             open(out, "wb").close()
+     except Exception as e:
+         print("⚠️ Music generation failed:", e)
+         open(out, "wb").close()
+     return {"path": out}
+
+ # ===============================
+ # LLM loading & generation
+ # ===============================
+ MODEL = None
+ TOKENIZER = None
+ MODEL_DEVICE = "cpu"
+
+ def load_llm(model_id: str = DEFAULT_MODEL, use_bnb: bool = True):
+     global MODEL, TOKENIZER, MODEL_DEVICE
+     if AutoTokenizer is None or AutoModelForCausalLM is None or torch is None:
+         raise RuntimeError("transformers and torch are required. pip install transformers torch")
+     TOKENIZER = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
+     if TOKENIZER.pad_token_id is None:
+         TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
+     kwargs = {}
+     if torch.cuda.is_available():
+         MODEL_DEVICE = "cuda"
+         if BitsAndBytesConfig is not None and use_bnb:
+             bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
+                                      bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)
+             kwargs.update(dict(device_map="auto", quantization_config=bnb, token=HF_TOKEN))
+         else:
+             kwargs.update(dict(device_map="auto",
+                                torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
+                                token=HF_TOKEN))
+     else:
+         MODEL_DEVICE = "cpu"
+         kwargs.update(dict(torch_dtype=torch.float32, token=HF_TOKEN))
+     try:
+         MODEL = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
+     except TypeError:
+         kwargs.pop("token", None)
+         MODEL = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=HF_TOKEN, **kwargs)
+     print(f"✅ LLM loaded on {MODEL_DEVICE}")
+
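+ # Note: the 4-bit path assumes a CUDA-enabled bitsandbytes install; on CPU the
+ # model loads in float32, which for an 8B-parameter model needs roughly 32 GB of RAM.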
+ def _get_eos_token_id():
+     if TOKENIZER is None:
+         return None
+     eid = getattr(TOKENIZER, "eos_token_id", None)
+     if isinstance(eid, list) and eid:
+         return eid[0]
+     return eid
+
+ def make_system_prompt(local_knowledge: str) -> str:
+     base = ("You are Billy AI — a helpful, witty, and precise assistant. "
+             "Be concise but thorough; use bullet points; cite assumptions; avoid hallucinations.")
+     if local_knowledge:
+         base += f"\nUseful context: {local_knowledge[:3000]}"
+     return base
+
+ def build_prompt(user_prompt: str, chat_history: List[Tuple[str, str]]) -> str:
+     context = retrieve_knowledge(user_prompt, k=5)
+     system = make_system_prompt(context)
+     hist = ""
+     for u, a in (chat_history or []):
+         if u:
+             hist += f"\nUser: {u}\nAssistant: {a or ''}"
+     return f"<s>[INST]{system}[/INST]</s>\n{hist}\n[INST]User: {user_prompt}\nAssistant:[/INST]"
+
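+ # Note: this hand-rolled [INST] template follows the Llama-2 convention; the
+ # default Llama-3.1 chat model expects a different format, so
+ # TOKENIZER.apply_chat_template(messages, tokenize=False) would be the safer route.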
+ def generate_text_sync(prompt_text: str, max_tokens: int = 600, temperature: float = 0.6, top_p: float = 0.9) -> str:
+     if MODEL is None or TOKENIZER is None:
+         raise RuntimeError("LLM not loaded")
+     inputs = TOKENIZER(prompt_text, return_tensors="pt").to(MODEL_DEVICE)
+     out_ids = MODEL.generate(
+         **inputs,
+         max_new_tokens=min(max_tokens, 2048),
+         do_sample=True,
+         temperature=temperature,
+         top_p=top_p,
+         pad_token_id=TOKENIZER.pad_token_id,
+         eos_token_id=_get_eos_token_id(),
+     )
+     text = TOKENIZER.decode(out_ids[0], skip_special_tokens=True)
+     if text.startswith(prompt_text):
+         return text[len(prompt_text):].strip()
+     return text.strip()
+
+ def stream_generate(prompt_text: str, max_tokens: int = 600, temperature: float = 0.6, top_p: float = 0.9):
+     if MODEL is None or TOKENIZER is None:
+         yield "ERROR: model not loaded"
+         return
+     inputs = TOKENIZER(prompt_text, return_tensors="pt").to(MODEL_DEVICE)
+     streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
+     def _gen():
+         MODEL.generate(
+             **inputs,
+             max_new_tokens=min(max_tokens, 2048),
+             do_sample=True,
+             temperature=temperature,
+             top_p=top_p,
+             pad_token_id=TOKENIZER.pad_token_id,
+             eos_token_id=_get_eos_token_id(),
+             streamer=streamer,
+         )
+     Thread(target=_gen).start()
+     for tok in streamer:
+         yield tok
+
+ # ===============================
+ # Schemas
+ # ===============================
+ class GenerateRequest(BaseModel):
+     prompt: str
+     chat_history: Optional[List[Tuple[str, str]]] = []
+     max_tokens: int = 600
+     temperature: float = 0.6
+     top_p: float = 0.9
+     max_steps: int = 4  # for agent if triggered
+
+ class EmbedRequest(BaseModel):
+     texts: List[str]
+
+ class RememberRequest(BaseModel):
+     text: str
+     user_id: Optional[str] = None
+
+ class SearchRequest(BaseModel):
+     query: str
+     max_results: int = 3
+
+ class MusicRequest(BaseModel):
+     prompt: str
+     style: Optional[str] = None
+     duration: Optional[int] = 20
+
+ class TTSRequest(BaseModel):
+     text: str
+     voice: Optional[str] = "default"
+     format: Optional[str] = "mp3"
+
+ class AgentRequest(BaseModel):
+     prompt: str
+     chat_history: Optional[List[Tuple[str, str]]] = []
+     max_steps: int = 4
+
+ class ForgetRequest(BaseModel):
+     id: str
+
+ # ===============================
+ # Endpoints
+ # ===============================
+ @app.get("/health")
+ def health():
+     return {"status": "ok"}
+
+
+ @app.post("/generate", dependencies=[Depends(api_key_auth)])
+ def generate(req: GenerateRequest):
+     rl_key = f"rl:{hashlib.sha1((req.prompt or '').encode()).hexdigest()}"
+     if not rate_limit(rl_key, limit=120, window=60):
+         raise HTTPException(status_code=429, detail="Rate limit exceeded")
+
+     safe, reason = is_safe_message(req.prompt)
+     if not safe:
+         return JSONResponse({"error": reason or "Unsafe prompt"}, status_code=400)
+
+     if req.prompt.strip().lower().startswith("use tool:") or "CALL_TOOL" in req.prompt:
+         def _llm(p): return generate_text_sync(p, max_tokens=400, temperature=0.2, top_p=0.9)
+         out = agent_run(_llm, make_system_prompt(retrieve_knowledge(req.prompt, k=5)),
+                         req.prompt, req.chat_history or [], max_steps=req.max_steps)
+         return out
+
+     prompt = build_prompt(req.prompt, req.chat_history or [])
+     cache_key = f"resp:{hashlib.sha1(prompt.encode()).hexdigest()}"
+     cached = cache_get(cache_key)
+     if cached:
+         return {"response": cached}
+
+     out = generate_text_sync(prompt, max_tokens=req.max_tokens,
+                              temperature=req.temperature, top_p=req.top_p)
+     safe_out, _ = is_safe_message(out)
+     if not safe_out:
+         return JSONResponse({"error": "Response blocked by moderation."}, status_code=400)
+     cache_set(cache_key, out, ttl=30)
+     return {"response": out}
+
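+ # Example request (hypothetical host; default dev key shown):
+ #   curl -X POST http://localhost:8000/generate \
+ #     -H 'Content-Type: application/json' -H 'x-api-key: changeme' \
+ #     -d '{"prompt": "Explain RAG in two sentences."}'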
+
+ @app.post("/stream", dependencies=[Depends(api_key_auth)])
+ def stream(req: GenerateRequest):
+     safe, reason = is_safe_message(req.prompt)
+     if not safe:
+         return StreamingResponse(iter([reason or "Unsafe prompt"]), media_type="text/plain")
+     prompt = build_prompt(req.prompt, req.chat_history or [])
+     def gen():
+         for chunk in stream_generate(prompt, max_tokens=req.max_tokens,
+                                      temperature=req.temperature, top_p=req.top_p):
+             yield chunk
+     return StreamingResponse(gen(), media_type="text/plain")
+
+
+ @app.post("/agent", dependencies=[Depends(api_key_auth)])
+ def agent_endpoint(req: AgentRequest):
+     def _llm(p): return generate_text_sync(p, max_tokens=400, temperature=0.2, top_p=0.9)
+     out = agent_run(_llm, make_system_prompt(retrieve_knowledge(req.prompt, k=5)),
+                     req.prompt, req.chat_history or [], max_steps=req.max_steps)
+     return out
+
+
+ @app.post("/embed", dependencies=[Depends(api_key_auth)])
+ def embed(req: EmbedRequest):
+     if not embedder:
+         return JSONResponse({"error": "Embedder not loaded."}, status_code=500)
+     vecs = [embed_text_local(t) for t in req.texts]
+     for t in req.texts:
+         store_knowledge(t)
+     return {"embeddings": vecs}
+
+
+ @app.post("/remember", dependencies=[Depends(api_key_auth)])
+ def remember(req: RememberRequest):
+     store_knowledge(req.text, user_id=req.user_id)  # field is declared on the model, so no hasattr guard needed
+     return {"status": "stored"}
+
+
+ @app.post("/search", dependencies=[Depends(api_key_auth)])
+ def web_search(req: SearchRequest):
+     ws = TOOLS.get("web_search")
+     if not ws:
+         return {"ingested": 0, "context_sample": ""}
+     res = ws.run(req.query)
+     count = 0
+     for s in res.get("snippets", []):
+         store_knowledge(s)
+         count += 1
+     ctx = retrieve_knowledge(req.query, k=req.max_results or 3)
+     return {"ingested": count, "context_sample": ctx[:1000]}
+
+
+ @app.post("/music", dependencies=[Depends(api_key_auth)])
+ def music(req: MusicRequest, background_tasks: BackgroundTasks):
+     try:
+         tmp = generate_music(req.prompt, duration=req.duration or 20).get("path")
+         url = save_media_to_supabase(tmp, "audio", prompt=req.prompt)
+         return {
+             "reply": f"Generated music for: {req.prompt}",
+             "audioUrl": url or tmp
+         }
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+ # === Updated TTS with Supabase ===
+ @app.post("/tts", dependencies=[Depends(api_key_auth)])
+ def tts(req: TTSRequest):
+     try:
+         out = synthesize_to_file(req.text, voice=req.voice)
+         url = save_media_to_supabase(out["path"], "audio", prompt=req.text)
+         return {"audioUrl": url or out["path"]}
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+
+ @app.post("/tts_stream", dependencies=[Depends(api_key_auth)])
+ def tts_stream(req: TTSRequest):
+     try:
+         out = synthesize_to_file(req.text, voice=req.voice)
+         def iterfile():
+             with open(out["path"], "rb") as f:
+                 while True:
+                     chunk = f.read(4096)
+                     if not chunk:
+                         break
+                     yield chunk
+         return StreamingResponse(iterfile(), media_type="audio/mpeg")
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+
+ @app.post("/asr", dependencies=[Depends(api_key_auth)])
+ async def asr(file: UploadFile = File(...)):
+     try:
+         tmp = f"/tmp/asr_{uuid.uuid4().hex}_{file.filename}"
+         with open(tmp, "wb") as f:
+             f.write(await file.read())
+         res = transcribe_audio(tmp)
+         return {"transcript": res.get("text", "")}
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+
+ @app.post("/vision", dependencies=[Depends(api_key_auth)])
+ async def vision(file: UploadFile = File(...), task: Optional[str] = "caption"):
+     try:
+         tmp = f"/tmp/vision_{uuid.uuid4().hex}.jpg"
+         with open(tmp, "wb") as f:
+             f.write(await file.read())
+         if (task or "").lower() == "ocr":
+             text = ocr_image(tmp)
+             return {"text": text}
+         caption = caption_image(tmp)
+         return {"caption": caption}
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+
+ @app.websocket("/ws/generate")
+ async def websocket_generate(ws: WebSocket):
+     await ws.accept()
+     try:
+         while True:
+             data = await ws.receive_json()
+             prompt = data.get("prompt", "")
+             chat_history = data.get("chat_history", [])
+             max_tokens = int(data.get("max_tokens", 256))
+             temperature = float(data.get("temperature", 0.6))
+             top_p = float(data.get("top_p", 0.9))
+
+             built = build_prompt(prompt, chat_history or [])
+             inputs = TOKENIZER(built, return_tensors="pt").to(MODEL_DEVICE)
+             streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
+
+             def run_gen():
+                 MODEL.generate(
+                     **inputs,
+                     max_new_tokens=min(max_tokens, 2048),
+                     do_sample=True,
+                     temperature=temperature,
+                     top_p=top_p,
+                     pad_token_id=TOKENIZER.pad_token_id,
+                     eos_token_id=_get_eos_token_id(),
+                     streamer=streamer,
+                 )
+             Thread(target=run_gen).start()
+
+             accumulated = ""
+             # NB: iterating the streamer is blocking and stalls the event loop while
+             # tokens arrive; wrapping the reads in run_in_executor would avoid that.
+             for token in streamer:
+                 accumulated += token
+                 await ws.send_json({"delta": token})
+             safe_out, _ = is_safe_message(accumulated)
+             if not safe_out:
+                 await ws.send_json({"done": True, "final": "⚠️ Response blocked by moderation."})
+             else:
+                 await ws.send_json({"done": True, "final": accumulated})
+     except Exception:
+         await ws.close()
+
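+ # Client sketch (assuming the `websockets` package; hypothetical host):
+ #   async with websockets.connect("ws://localhost:8000/ws/generate") as ws:
+ #       await ws.send(json.dumps({"prompt": "hi"}))
+ #       # read {"delta": ...} frames until a {"done": true} frame arrives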
+
+ @app.get("/admin/memory")
+ def admin_memory():
+     return {"count": len(memory_store)}
+
+
+ @app.post("/forget", dependencies=[Depends(api_key_auth)])
+ def forget(req: ForgetRequest):
+     try:
+         ok = delete_memory_by_id(req.id)
+         if ok:
+             return {"status": "deleted"}
+         return JSONResponse({"error": "Not found"}, status_code=404)
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+
+ @app.get("/metrics")
+ def metrics():
+     if not generate_latest or not CONTENT_TYPE_LATEST:
+         return JSONResponse({"error": "prometheus-client not installed"}, status_code=500)
+     return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
+
+
+ # === New: Library Endpoint ===
+ @app.get("/library", dependencies=[Depends(api_key_auth)])
+ def list_library(page: int = 0, page_size: int = 12):
+     if not supabase_client:
+         return {"items": []}
+     try:
+         start = page * page_size
+         end = start + page_size - 1
+         resp = supabase_client.table("generated_media") \
+             .select("*") \
+             .order("created_at", desc=True) \
+             .range(start, end) \
+             .execute()
+         return {"items": resp.data or []}
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+ @app.on_event("startup")
+ def on_startup():
+     try:
+         load_llm(DEFAULT_MODEL)
+     except Exception as e:
+         print("⚠️ LLM load failed:", e)
+     try:
+         if not DISABLE_MULTIMODAL:
+             init_asr()
+             init_tts()
+             init_vision()
+             init_music()
+     except Exception:
+         pass
+     print("🚀 Billy AI startup complete")
README.md DELETED
@@ -1,3 +0,0 @@
- ---
- license: mit
- ---