Rename README.md to Main.js
Main.js
ADDED
@@ -0,0 +1,990 @@
import os
import time
import uuid
import json
import math
import shlex
import tempfile
import hashlib
import subprocess
from threading import Thread
from typing import List, Dict, Any, Tuple, Optional
from dotenv import load_dotenv

load_dotenv()  # loads .env from the same directory

from fastapi import FastAPI, UploadFile, File, WebSocket, BackgroundTasks, Depends, Header, HTTPException, Response
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# ===============================
# Optional heavy deps (guarded)
# ===============================
try:
    import torch
except Exception:
    torch = None

try:
    from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
except Exception:
    AutoTokenizer = AutoModelForCausalLM = TextIteratorStreamer = BitsAndBytesConfig = None

try:
    from sentence_transformers import SentenceTransformer
except Exception:
    SentenceTransformer = None

try:
    import chromadb
except Exception:
    chromadb = None

try:
    from supabase import Client, create_client
except Exception:
    create_client = None

try:
    import redis as redis_lib
except Exception:
    redis_lib = None

try:
    from faster_whisper import WhisperModel
except Exception:
    WhisperModel = None

try:
    from TTS.api import TTS as CoquiTTS
except Exception:
    CoquiTTS = None

try:
    from transformers import Blip2Processor, Blip2ForConditionalGeneration
except Exception:
    Blip2Processor = Blip2ForConditionalGeneration = None

try:
    from PIL import Image
except Exception:
    Image = None

try:
    from audiocraft.models import musicgen as musicgen_lib
except Exception:
    musicgen_lib = None

try:
    from openai import OpenAI
except Exception:
    OpenAI = None

try:
    from duckduckgo_search import ddg as ddg_func
except Exception:
    ddg_func = None

# Prometheus (optional)
try:
    from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
except Exception:
    generate_latest = CONTENT_TYPE_LATEST = None

# ===============================
# Environment / Config
# ===============================
OPENAI_KEY = os.getenv("OPENAI_API_KEY")
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
DEFAULT_MODEL = os.getenv("MODEL_ID", "meta-llama/Meta-Llama-3.1-8B-Instruct")
EMBED_MODEL = os.getenv("EMBED_MODEL", "all-MiniLM-L6-v2")
FRONTEND_API_KEY = os.getenv("FRONTEND_API_KEY", "changeme")
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "http://localhost:3000,https://zynara.xyz")
REDIS_URL = os.getenv("REDIS_URL")
ASR_MODEL_SIZE = os.getenv("ASR_MODEL", "small")
COQUI_TTS_MODEL = os.getenv("COQUI_TTS_MODEL")  # optional, else auto-pick
CDN_BASE_URL = os.getenv("CDN_BASE_URL")
DISABLE_MULTIMODAL = os.getenv("DISABLE_MULTIMODAL", "0") == "1"

# ===============================
# App + CORS
# ===============================
app = FastAPI(title="Billy AI — All-in-one", version="1.0.0")
origins = [o.strip() for o in ALLOWED_ORIGINS.split(",") if o.strip()]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins or ["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ===============================
# Clients (OpenAI, Redis, Chroma, Supabase)
# ===============================
openai_client = OpenAI(api_key=OPENAI_KEY) if (OPENAI_KEY and OpenAI) else None

redis_client = None
if redis_lib and REDIS_URL:
    try:
        redis_client = redis_lib.from_url(REDIS_URL)
        print("✅ Redis connected")
    except Exception as e:
        print("⚠️ Redis init failed:", e)

embedder = None
if SentenceTransformer is not None:
    try:
        embedder = SentenceTransformer(EMBED_MODEL)
        print("✅ Embedder loaded")
    except Exception as e:
        print("⚠️ Embedder init failed:", e)

chroma_client = None
chroma_collection = None
if chromadb is not None:
    try:
        chroma_client = chromadb.PersistentClient(path="./billy_rag_db")
        try:
            chroma_collection = chroma_client.get_collection("billy_rag")
        except Exception:
            chroma_collection = chroma_client.create_collection("billy_rag")
        print("✅ Chroma ready")
    except Exception as e:
        print("⚠️ Chroma init failed:", e)

supabase_client = None
if create_client and SUPABASE_URL and SUPABASE_KEY:
    try:
        supabase_client = create_client(SUPABASE_URL, SUPABASE_KEY)
        print("✅ Supabase client initialized")
    except Exception as e:
        print("⚠️ Supabase init failed:", e)

# ===============================
# Helpers (IDs, cosine, cache/RL, moderation)
# ===============================
def _stable_id(text: str) -> str:
    return hashlib.sha1(text.encode("utf-8")).hexdigest()

def _cosine(a: List[float], b: List[float]) -> float:
    import numpy as np
    a = np.array(a, dtype=np.float32)
    b = np.array(b, dtype=np.float32)
    na = np.linalg.norm(a) or 1.0
    nb = np.linalg.norm(b) or 1.0
    return float(np.dot(a, b) / (na * nb))

async def api_key_auth(x_api_key: Optional[str] = Header(None)):
    if x_api_key is None:
        if FRONTEND_API_KEY == "changeme":
            return True
        raise HTTPException(status_code=401, detail="Missing API key")
    if x_api_key != FRONTEND_API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True

def rate_limit(key: str, limit: int = 60, window: int = 60) -> bool:
    if not redis_client:
        return True
    try:
        p = redis_client.pipeline()
        p.incr(key)
        p.expire(key, window)
        val, _ = p.execute()
        return int(val) <= limit
    except Exception:
        return True

def cache_get(key: str):
    if not redis_client:
        return None
    try:
        v = redis_client.get(key)
        return json.loads(v) if v else None
    except Exception:
        return None

def cache_set(key: str, value, ttl: int = 300):
    if not redis_client:
        return
    try:
        redis_client.set(key, json.dumps(value), ex=ttl)
    except Exception:
        pass

def is_safe_message(text: str) -> Tuple[bool, str]:
    if not text:
        return True, ""
    if openai_client is None:
        # very simple heuristic fallback
        banned = ["kill", "terror", "bomb", "nuke"]
        if any(b in text.lower() for b in banned):
            return False, "Blocked by local safety heuristic."
        return True, ""
    try:
        resp = openai_client.moderations.create(model="omni-moderation-latest", input=text)
        flagged = bool(resp.results[0].flagged)
        return (not flagged), ("Blocked by moderation." if flagged else "")
    except Exception:
        return True, ""

# ===============================
# RAG storage (Chroma/Supabase/memory)
# ===============================
memory_store: List[Dict[str, Any]] = []

def embed_text_local(text: str) -> List[float]:
    if not embedder:
        raise RuntimeError("Embedder not loaded.")
    return embedder.encode(text).tolist()

def store_knowledge(text: str, user_id: Optional[str] = None):
    if not text or not text.strip():
        return
    try:
        vec = embed_text_local(text)
    except Exception:
        return
    idx = _stable_id(text)
    if supabase_client:
        try:
            row = {"id": idx, "text": text, "embedding": vec, "source": "user", "created_at": int(time.time())}
            if user_id:
                row["user_id"] = user_id
            supabase_client.table("knowledge").upsert(row).execute()
            return
        except Exception:
            pass
    if chroma_collection:
        try:
            chroma_collection.add(documents=[text], embeddings=[vec], ids=[idx], metadatas=[{"user_id": user_id}])
            return
        except Exception:
            pass
    memory_store.append({"text": text, "embedding": vec, "user_id": user_id})

def retrieve_knowledge(query: str, k: int = 5) -> str:
    try:
        qvec = embed_text_local(query)
    except Exception:
        return ""
    if supabase_client:
        try:
            resp = supabase_client.table("knowledge").select("text,embedding").execute()
            data = resp.data or []
            scored = []
            for item in data:
                emb = item.get("embedding")
                if isinstance(emb, list):
                    scored.append((item["text"], _cosine(qvec, emb)))
            scored.sort(key=lambda x: x[1], reverse=True)
            return " ".join([t for t, _ in scored[:k]])
        except Exception:
            pass
    if chroma_collection:
        try:
            res = chroma_collection.query(query_embeddings=[qvec], n_results=k)
            docs = res.get("documents", [])
            if docs and docs[0]:
                return " ".join(docs[0])
        except Exception:
            pass
    if memory_store:
        scored = []
        for item in memory_store:
            scored.append((item["text"], _cosine(qvec, item["embedding"])))
        scored.sort(key=lambda x: x[1], reverse=True)
        return " ".join([t for t, _ in scored[:k]])
    return ""

def delete_memory_by_id(mem_id: str) -> bool:
    global memory_store
    ok = False
    if supabase_client:
        try:
            supabase_client.table("knowledge").delete().eq("id", mem_id).execute()
            ok = True
        except Exception:
            pass
    if chroma_collection:
        try:
            chroma_collection.delete(ids=[mem_id])
            ok = True
        except Exception:
            pass
    before = len(memory_store)
    memory_store = [m for m in memory_store if _stable_id(m.get("text", "")) != mem_id]
    return ok or (len(memory_store) < before)

# ===============================
# Tools & Agent
# ===============================

def save_media_to_supabase(path: str, media_type: str, prompt: str = "") -> Optional[str]:
    if not supabase_client:
        return None
    try:
        file_name = os.path.basename(path)
        bucket = "generated_media"
        with open(path, "rb") as f:
            supabase_client.storage.from_(bucket).upload(file_name, f, {"upsert": True})
        return f"{SUPABASE_URL}/storage/v1/object/public/{bucket}/{file_name}"
    except Exception as e:
        print("⚠️ save_media_to_supabase failed:", e)
        return None

class Tool:
    name: str
    description: str

    def run(self, args: str) -> Dict[str, Any]:
        raise NotImplementedError

TOOLS: Dict[str, Tool] = {}

def register_tool(tool: Tool):
    TOOLS[tool.name] = tool

def call_tool(name: str, args: str) -> Dict[str, Any]:
    tool = TOOLS.get(name)
    if not tool:
        return {"ok": False, "error": f"Tool '{name}' not found"}
    start = time.time()
    try:
        res = tool.run(args)
        return {"ok": True, "result": res, "runtime": time.time() - start}
    except Exception as e:
        return {"ok": False, "error": str(e), "runtime": time.time() - start}

class Calculator(Tool):
    name = "calculator"
    description = "Evaluate math expressions using Python's math (e.g., '2+2', 'sin(1)')."

    def run(self, args: str) -> Dict[str, Any]:
        allowed = {k: getattr(math, k) for k in dir(math) if not k.startswith("_")}
        allowed.update({"abs": abs, "round": round, "min": min, "max": max})
        expr = args.strip()
        if "__" in expr:
            raise ValueError("Invalid expression")
        val = eval(expr, {"__builtins__": {}}, allowed)
        return {"input": expr, "value": val}

register_tool(Calculator())

class PythonSandbox(Tool):
    name = "python_sandbox"
    description = "Run a short Python script in a subprocess (timeout 2s). For production, isolate via container."

    def run(self, args: str) -> Dict[str, Any]:
        code = args
        with tempfile.TemporaryDirectory() as td:
            path = os.path.join(td, "script.py")
            with open(path, "w") as f:
                f.write(code)
            cmd = f"timeout 2 python3 {shlex.quote(path)}"
            proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            try:
                out, err = proc.communicate(timeout=4)
            except subprocess.TimeoutExpired:
                proc.kill()
                return {"stdout": "", "stderr": "Execution timed out", "returncode": 124}
            return {"stdout": out.decode("utf-8", errors="ignore"), "stderr": err.decode("utf-8", errors="ignore"), "returncode": proc.returncode}

register_tool(PythonSandbox())

class WebSearchTool(Tool):
    name = "web_search"
    description = "DuckDuckGo search. Returns top snippets (no links)."

    def run(self, args: str) -> Dict[str, Any]:
        if not ddg_func:
            return {"error": "duckduckgo-search not installed"}
        q = args.strip()
        try:
            results = ddg_func(q, max_results=3)
        except TypeError:
            results = ddg_func(keywords=q, max_results=3)
        snippets = []
        for r in results or []:
            snippets.append(r.get("body") or r.get("snippet") or r.get("title") or "")
        return {"query": q, "snippets": [s for s in snippets if s]}

register_tool(WebSearchTool())

class NodeSandbox(Tool):
    name = "node_sandbox"
    description = "Run short JavaScript code via Node.js (timeout 2s)."

    def run(self, args: str) -> Dict[str, Any]:
        code = args
        with tempfile.TemporaryDirectory() as td:
            path = os.path.join(td, "script.js")
            with open(path, "w") as f:
                f.write(code)
            cmd = f"timeout 2 node {shlex.quote(path)}"
            proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            try:
                out, err = proc.communicate(timeout=3)
            except subprocess.TimeoutExpired:
                proc.kill()
                return {"stdout": "", "stderr": "Execution timed out", "returncode": 124}
            return {"stdout": out.decode("utf-8", errors="ignore"), "stderr": err.decode("utf-8", errors="ignore"), "returncode": proc.returncode}

register_tool(NodeSandbox())

class BashSandbox(Tool):
    name = "bash_sandbox"
    description = "Run safe shell commands (timeout 2s)."

    def run(self, args: str) -> Dict[str, Any]:
        cmd = args.strip()
        if ";" in cmd or "&&" in cmd or "|" in cmd:
            return {"stdout": "", "stderr": "Unsafe command detected", "returncode": 1}
        safe_cmd = shlex.split(cmd)
        proc = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
        try:
            out, err = proc.communicate(timeout=2)
        except subprocess.TimeoutExpired:
            proc.kill()
            return {"stdout": "", "stderr": "Execution timed out", "returncode": 124}
        return {"stdout": out.decode("utf-8", errors="ignore"), "stderr": err.decode("utf-8", errors="ignore"), "returncode": proc.returncode}

register_tool(BashSandbox())

def agent_run(llm_func, system_prompt: str, user_prompt: str, chat_history: List[Tuple[str, str]] = None, max_steps: int = 4):
    chat_history = chat_history or []
    tools_info = "\n".join([f"{name}: {TOOLS[name].description}" for name in TOOLS])
    agent_hdr = (
        f"{system_prompt}\n\nAvailable tools:\n{tools_info}\n\n"
        "When you want to call a tool respond ONLY with a JSON object:\n"
        '{"action":"tool_name","args":"..."}\n'
        'When finished respond: {"action":"final","answer":"..."}\n'
    )
    context = agent_hdr + f"\nUser: {user_prompt}\n"
    for _ in range(max_steps):
        model_out = llm_func(context)
        try:
            first_line = model_out.strip().splitlines()[0]
            action_obj = json.loads(first_line)
        except Exception:
            return {"final": model_out}
        act = action_obj.get("action")
        if act == "final":
            return {"final": action_obj.get("answer", "")}
        args = action_obj.get("args", "")
        tool_res = call_tool(act, args)
        context += f"\nToolCall: {json.dumps({'tool': act, 'args': args})}\nToolResult: {json.dumps(tool_res)}\n"
    return {"final": "Max steps reached. Partial reasoning returned.", "context": context}

# ===============================
# Multimodal (ASR / TTS / Vision / Music)
# ===============================
ASR_MODEL = None

def init_asr():
    global ASR_MODEL
    if DISABLE_MULTIMODAL or WhisperModel is None:
        return None
    if ASR_MODEL is None:
        try:
            device = "cuda" if (torch and torch.cuda.is_available()) else "cpu"
            compute_type = "float16" if device == "cuda" else "int8"
            ASR_MODEL = WhisperModel(ASR_MODEL_SIZE, device=device, compute_type=compute_type)
            print(f"✅ ASR model loaded: {ASR_MODEL_SIZE} on {device}")
        except Exception as e:
            print("⚠️ ASR init failed:", e)
            ASR_MODEL = None
    return ASR_MODEL

def transcribe_audio(path: str, language: Optional[str] = None, task: str = "transcribe"):
    if init_asr() is None:
        return {"text": "asr-disabled"}
    segments, info = ASR_MODEL.transcribe(path, language=language, task=task)
    text = " ".join(seg.text for seg in segments)
    return {"text": text, "duration": getattr(info, "duration", None)}

TTS_CLIENT = None

def init_tts():
    global TTS_CLIENT
    if DISABLE_MULTIMODAL or CoquiTTS is None:
        return None
    if TTS_CLIENT is None:
        try:
            TTS_CLIENT = CoquiTTS(model_name=COQUI_TTS_MODEL) if COQUI_TTS_MODEL else CoquiTTS()
            print("✅ Coqui TTS initialized")
        except Exception as e:
            print("⚠️ TTS init failed:", e)
            TTS_CLIENT = None
    return TTS_CLIENT

def synthesize_to_file(text: str, voice: Optional[str] = None, out_path: Optional[str] = None):
    out_path = out_path or f"/tmp/tts_{uuid.uuid4().hex}.mp3"
    if init_tts() is None:
        open(out_path, "wb").close()
        return {"path": out_path}
    try:
        # Some models require specific speaker names; None often works with single-speaker
        TTS_CLIENT.tts_to_file(text=text, speaker=voice, file_path=out_path)
    except Exception as e:
        print("⚠️ TTS synthesis failed:", e)
        open(out_path, "wb").close()
    return {"path": out_path}

BLIP_PROCESSOR = BLIP_MODEL = None

def init_vision():
    global BLIP_PROCESSOR, BLIP_MODEL
    if DISABLE_MULTIMODAL or (Blip2Processor is None or Blip2ForConditionalGeneration is None or Image is None):
        return None, None
    if BLIP_MODEL is None:
        try:
            model_name = "Salesforce/blip2-flan-t5-base"
            BLIP_PROCESSOR = Blip2Processor.from_pretrained(model_name)
            BLIP_MODEL = Blip2ForConditionalGeneration.from_pretrained(model_name)
            device = "cuda" if (torch and torch.cuda.is_available()) else "cpu"
            BLIP_MODEL.to(device)
            print(f"✅ BLIP-2 loaded on {device}")
        except Exception as e:
            print("⚠️ Vision init failed:", e)
            BLIP_PROCESSOR = BLIP_MODEL = None
    return BLIP_PROCESSOR, BLIP_MODEL

def caption_image(path: str) -> str:
    proc, model = init_vision()
    if not proc or not model:
        return "A photo (caption placeholder)."
    device = "cuda" if (torch and torch.cuda.is_available()) else "cpu"
    img = Image.open(path).convert("RGB")
    inputs = proc(images=img, return_tensors="pt")
    for k in inputs:
        inputs[k] = inputs[k].to(device)
    out_ids = model.generate(**inputs, max_new_tokens=64)
    return proc.decode(out_ids[0], skip_special_tokens=True)

def ocr_image(path: str) -> str:
    # Placeholder (integrate easyocr or pytesseract as needed)
    return "OCR placeholder text."

MUSIC_MODEL = None

def init_music():
    global MUSIC_MODEL
    if DISABLE_MULTIMODAL or musicgen_lib is None:
        return None
    if MUSIC_MODEL is None:
        try:
            MUSIC_MODEL = musicgen_lib.MusicGen.get_pretrained("melody")
            device = "cuda" if (torch and torch.cuda.is_available()) else "cpu"
            MUSIC_MODEL.to(device)
            print(f"✅ MusicGen loaded on {device}")
        except Exception as e:
            print("⚠️ Music init failed:", e)
            MUSIC_MODEL = None
    return MUSIC_MODEL

def generate_music(prompt: str, duration: int = 20) -> Dict[str, Any]:
    out = f"/tmp/music_{uuid.uuid4().hex}.wav"
    if init_music() is None:
        open(out, "wb").close()
        return {"path": out}
    try:
        wav = MUSIC_MODEL.generate([prompt], duration=duration)
        # audiocraft write helper changed over time; safest: torchaudio or soundfile
        try:
            import torchaudio
            torchaudio.save(out, wav[0].cpu(), 32000)
        except Exception:
            # fallback empty file
            open(out, "wb").close()
    except Exception as e:
        print("⚠️ Music generation failed:", e)
        open(out, "wb").close()
    return {"path": out}

# ===============================
# LLM loading & generation
# ===============================
MODEL = None
TOKENIZER = None
MODEL_DEVICE = "cpu"

def load_llm(model_id: str = DEFAULT_MODEL, use_bnb: bool = True):
    global MODEL, TOKENIZER, MODEL_DEVICE
    if AutoTokenizer is None or AutoModelForCausalLM is None:
        raise RuntimeError("transformers is required. pip install transformers")
    TOKENIZER = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    if TOKENIZER.pad_token_id is None:
        TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
    kwargs = {}
    if torch is not None and torch.cuda.is_available():
        MODEL_DEVICE = "cuda"
        if BitsAndBytesConfig is not None and use_bnb:
            bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                     bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)
            kwargs.update(dict(device_map="auto", quantization_config=bnb, token=HF_TOKEN))
        else:
            kwargs.update(dict(device_map="auto",
                               torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
                               token=HF_TOKEN))
    else:
        MODEL_DEVICE = "cpu"
        # torch may be unavailable on minimal installs; only pin a dtype when it is present
        kwargs.update(dict(torch_dtype=torch.float32 if torch is not None else None, token=HF_TOKEN))
    try:
        MODEL = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
    except TypeError:
        kwargs.pop("token", None)
        MODEL = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=HF_TOKEN, **kwargs)
    print(f"✅ LLM loaded on {MODEL_DEVICE}")

+
def _get_eos_token_id():
|
| 636 |
+
if TOKENIZER is None:
|
| 637 |
+
return None
|
| 638 |
+
eid = getattr(TOKENIZER, "eos_token_id", None)
|
| 639 |
+
if isinstance(eid, list) and eid:
|
| 640 |
+
return eid[0]
|
| 641 |
+
return eid
|
| 642 |
+
|
| 643 |
+
def make_system_prompt(local_knowledge: str) -> str:
|
| 644 |
+
base = ("You are Billy AI — a helpful, witty, and precise assistant. "
|
| 645 |
+
"Be concise but thorough; use bullet points; cite assumptions; avoid hallucinations.")
|
| 646 |
+
if local_knowledge:
|
| 647 |
+
base += f"\nUseful context: {local_knowledge[:3000]}"
|
| 648 |
+
return base
|
| 649 |
+
|
| 650 |
+
def build_prompt(user_prompt: str, chat_history: List[Tuple[str,str]]) -> str:
|
| 651 |
+
context = retrieve_knowledge(user_prompt, k=5)
|
| 652 |
+
system = make_system_prompt(context)
|
| 653 |
+
hist = ""
|
| 654 |
+
for u, a in (chat_history or []):
|
| 655 |
+
if u:
|
| 656 |
+
hist += f"\nUser: {u}\nAssistant: {a or ''}"
|
| 657 |
+
return f"<s>[INST]{system}[/INST]</s>\n{hist}\n[INST]User: {user_prompt}\nAssistant:[/INST]"
|
| 658 |
+
|
| 659 |
+
def generate_text_sync(prompt_text: str, max_tokens: int = 600, temperature: float = 0.6, top_p: float = 0.9) -> str:
|
| 660 |
+
if MODEL is None or TOKENIZER is None:
|
| 661 |
+
raise RuntimeError("LLM not loaded")
|
| 662 |
+
inputs = TOKENIZER(prompt_text, return_tensors="pt").to(MODEL_DEVICE)
|
| 663 |
+
out_ids = MODEL.generate(
|
| 664 |
+
**inputs,
|
| 665 |
+
max_new_tokens=min(max_tokens, 2048),
|
| 666 |
+
do_sample=True,
|
| 667 |
+
temperature=temperature,
|
| 668 |
+
top_p=top_p,
|
| 669 |
+
pad_token_id=TOKENIZER.pad_token_id,
|
| 670 |
+
eos_token_id=_get_eos_token_id(),
|
| 671 |
+
)
|
| 672 |
+
text = TOKENIZER.decode(out_ids[0], skip_special_tokens=True)
|
| 673 |
+
if text.startswith(prompt_text):
|
| 674 |
+
return text[len(prompt_text):].strip()
|
| 675 |
+
return text.strip()
|
| 676 |
+
|
| 677 |
+
def stream_generate(prompt_text: str, max_tokens: int = 600, temperature: float = 0.6, top_p: float = 0.9):
|
| 678 |
+
if MODEL is None or TOKENIZER is None:
|
| 679 |
+
yield "ERROR: model not loaded"
|
| 680 |
+
return
|
| 681 |
+
inputs = TOKENIZER(prompt_text, return_tensors="pt").to(MODEL_DEVICE)
|
| 682 |
+
streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
|
| 683 |
+
def _gen():
|
| 684 |
+
MODEL.generate(
|
| 685 |
+
**inputs,
|
| 686 |
+
max_new_tokens=min(max_tokens, 2048),
|
| 687 |
+
do_sample=True,
|
| 688 |
+
temperature=temperature,
|
| 689 |
+
top_p=top_p,
|
| 690 |
+
pad_token_id=TOKENIZER.pad_token_id,
|
| 691 |
+
eos_token_id=_get_eos_token_id(),
|
| 692 |
+
streamer=streamer
|
| 693 |
+
)
|
| 694 |
+
Thread(target=_gen).start()
|
| 695 |
+
for tok in streamer:
|
| 696 |
+
yield tok
|
| 697 |
+
|
# ===============================
# Schemas
# ===============================
class GenerateRequest(BaseModel):
    prompt: str
    chat_history: Optional[List[Tuple[str, str]]] = []
    max_tokens: int = 600
    temperature: float = 0.6
    top_p: float = 0.9
    max_steps: int = 4  # for agent if triggered

class EmbedRequest(BaseModel):
    texts: List[str]

class RememberRequest(BaseModel):
    text: str
    user_id: Optional[str] = None

class SearchRequest(BaseModel):
    query: str
    max_results: int = 3

class MusicRequest(BaseModel):
    prompt: str
    style: Optional[str] = None
    duration: Optional[int] = 20

class TTSRequest(BaseModel):
    text: str
    voice: Optional[str] = "default"
    format: Optional[str] = "mp3"

class AgentRequest(BaseModel):
    prompt: str
    chat_history: Optional[List[Tuple[str, str]]] = []
    max_steps: int = 4

class ForgetRequest(BaseModel):
    id: str

# ===============================
# Endpoints
# ===============================
@app.get("/health")
def health():
    return {"status": "ok"}

@app.post("/generate", dependencies=[Depends(api_key_auth)])
def generate(req: GenerateRequest):
    rl_key = f"rl:{hashlib.sha1((req.prompt or '').encode()).hexdigest()}"
    if not rate_limit(rl_key, limit=120, window=60):
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    safe, reason = is_safe_message(req.prompt)
    if not safe:
        return JSONResponse({"error": reason or "Unsafe prompt"}, status_code=400)

    if req.prompt.strip().lower().startswith("use tool:") or "CALL_TOOL" in req.prompt:
        def _llm(p): return generate_text_sync(p, max_tokens=400, temperature=0.2, top_p=0.9)
        out = agent_run(_llm, make_system_prompt(retrieve_knowledge(req.prompt, k=5)),
                        req.prompt, req.chat_history or [], max_steps=req.max_steps)
        return out

    prompt = build_prompt(req.prompt, req.chat_history or [])
    cache_key = f"resp:{hashlib.sha1(prompt.encode()).hexdigest()}"
    cached = cache_get(cache_key)
    if cached:
        return {"response": cached}

    out = generate_text_sync(prompt, max_tokens=req.max_tokens,
                             temperature=req.temperature, top_p=req.top_p)
    safe_out, _ = is_safe_message(out)
    if not safe_out:
        return JSONResponse({"error": "Response blocked by moderation."}, status_code=400)
    cache_set(cache_key, out, ttl=30)
    return {"response": out}

@app.post("/stream", dependencies=[Depends(api_key_auth)])
def stream(req: GenerateRequest):
    safe, reason = is_safe_message(req.prompt)
    if not safe:
        return StreamingResponse(iter([reason or "Unsafe prompt"]), media_type="text/plain")
    prompt = build_prompt(req.prompt, req.chat_history or [])

    def gen():
        for chunk in stream_generate(prompt, max_tokens=req.max_tokens,
                                     temperature=req.temperature, top_p=req.top_p):
            yield chunk

    return StreamingResponse(gen(), media_type="text/plain")

@app.post("/agent", dependencies=[Depends(api_key_auth)])
def agent_endpoint(req: AgentRequest):
    def _llm(p): return generate_text_sync(p, max_tokens=400, temperature=0.2, top_p=0.9)
    out = agent_run(_llm, make_system_prompt(retrieve_knowledge(req.prompt, k=5)),
                    req.prompt, req.chat_history or [], max_steps=req.max_steps)
    return out

@app.post("/embed", dependencies=[Depends(api_key_auth)])
def embed(req: EmbedRequest):
    if not embedder:
        return JSONResponse({"error": "Embedder not loaded."}, status_code=500)
    vecs = [embed_text_local(t) for t in req.texts]
    for t in req.texts:
        store_knowledge(t)
    return {"embeddings": vecs}

@app.post("/remember", dependencies=[Depends(api_key_auth)])
def remember(req: RememberRequest):
    store_knowledge(req.text, user_id=req.user_id)
    return {"status": "stored"}

@app.post("/search", dependencies=[Depends(api_key_auth)])
def web_search(req: SearchRequest):
    ws = TOOLS.get("web_search")
    if not ws:
        return {"ingested": 0, "context_sample": ""}
    res = ws.run(req.query)
    count = 0
    for s in res.get("snippets", []):
        store_knowledge(s)
        count += 1
    ctx = retrieve_knowledge(req.query, k=req.max_results or 3)
    return {"ingested": count, "context_sample": ctx[:1000]}

@app.post("/music", dependencies=[Depends(api_key_auth)])
def music(req: MusicRequest, background_tasks: BackgroundTasks):
    try:
        tmp = generate_music(req.prompt, duration=req.duration or 20).get("path")
        url = save_media_to_supabase(tmp, "audio", prompt=req.prompt)
        return {
            "reply": f"Generated music for: {req.prompt}",
            "audioUrl": url or tmp
        }
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

# === Updated TTS with Supabase ===
@app.post("/tts", dependencies=[Depends(api_key_auth)])
def tts(req: TTSRequest):
    try:
        out = synthesize_to_file(req.text, voice=req.voice)
        url = save_media_to_supabase(out["path"], "audio", prompt=req.text)
        return {"audioUrl": url or out["path"]}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

@app.post("/tts_stream", dependencies=[Depends(api_key_auth)])
def tts_stream(req: TTSRequest):
    try:
        out = synthesize_to_file(req.text, voice=req.voice)

        def iterfile():
            with open(out["path"], "rb") as f:
                while True:
                    chunk = f.read(4096)
                    if not chunk:
                        break
                    yield chunk

        return StreamingResponse(iterfile(), media_type="audio/mpeg")
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

@app.post("/asr", dependencies=[Depends(api_key_auth)])
async def asr(file: UploadFile = File(...)):
    try:
        tmp = f"/tmp/asr_{uuid.uuid4().hex}_{file.filename}"
        with open(tmp, "wb") as f:
            f.write(await file.read())
        res = transcribe_audio(tmp)
        return {"transcript": res.get("text", "")}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

@app.post("/vision", dependencies=[Depends(api_key_auth)])
async def vision(file: UploadFile = File(...), task: Optional[str] = "caption"):
    try:
        tmp = f"/tmp/vision_{uuid.uuid4().hex}.jpg"
        with open(tmp, "wb") as f:
            f.write(await file.read())
        if (task or "").lower() == "ocr":
            text = ocr_image(tmp)
            return {"text": text}
        caption = caption_image(tmp)
        return {"caption": caption}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

@app.websocket("/ws/generate")
async def websocket_generate(ws: WebSocket):
    await ws.accept()
    try:
        while True:
            data = await ws.receive_json()
            prompt = data.get("prompt", "")
            chat_history = data.get("chat_history", [])
            max_tokens = int(data.get("max_tokens", 256))
            temperature = float(data.get("temperature", 0.6))
            top_p = float(data.get("top_p", 0.9))

            built = build_prompt(prompt, chat_history or [])
            inputs = TOKENIZER(built, return_tensors="pt").to(MODEL_DEVICE)
            streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)

            def run_gen():
                MODEL.generate(
                    **inputs,
                    max_new_tokens=min(max_tokens, 2048),
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    pad_token_id=TOKENIZER.pad_token_id,
                    eos_token_id=_get_eos_token_id(),
                    streamer=streamer,
                )

            Thread(target=run_gen).start()

            accumulated = ""
            for token in streamer:
                accumulated += token
                await ws.send_json({"delta": token})
            safe_out, _ = is_safe_message(accumulated)
            if not safe_out:
                await ws.send_json({"done": True, "final": "⚠️ Response blocked by moderation."})
            else:
                await ws.send_json({"done": True, "final": accumulated})
    except Exception:
        await ws.close()

@app.get("/admin/memory")
def admin_memory():
    return {"count": len(memory_store)}

@app.post("/forget", dependencies=[Depends(api_key_auth)])
def forget(req: ForgetRequest):
    try:
        ok = delete_memory_by_id(req.id)
        if ok:
            return {"status": "deleted"}
        return JSONResponse({"error": "Not found"}, status_code=404)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

@app.get("/metrics")
def metrics():
    if not generate_latest or not CONTENT_TYPE_LATEST:
        return JSONResponse({"error": "prometheus-client not installed"}, status_code=500)
    return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)

# === New: Library Endpoint ===
@app.get("/library", dependencies=[Depends(api_key_auth)])
def list_library(page: int = 0, page_size: int = 12):
    if not supabase_client:
        return {"items": []}
    try:
        start = page * page_size
        end = start + page_size - 1
        resp = supabase_client.table("generated_media") \
            .select("*") \
            .order("created_at", desc=True) \
            .range(start, end) \
            .execute()
        return {"items": resp.data or []}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

@app.on_event("startup")
def on_startup():
    try:
        load_llm(DEFAULT_MODEL)
    except Exception as e:
        print("⚠️ LLM load failed:", e)
    try:
        if not DISABLE_MULTIMODAL:
            init_asr()
            init_tts()
            init_vision()
            init_music()
    except Exception:
        pass
    print("🚀 Billy AI startup complete")
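
Since the added file is a FastAPI application, it can be exercised over HTTP once it is being served (for example with uvicorn). A minimal client sketch for the /generate endpoint follows; the local URL, port, and API-key value are illustrative assumptions, not part of this commit:

# Hypothetical client for the /generate endpoint defined above.
# Assumes the app is served locally on port 8000 and that FRONTEND_API_KEY
# matches the value sent in the x-api-key header (both are assumptions).
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    headers={"x-api-key": "changeme"},
    json={"prompt": "Summarize what Billy AI can do.", "max_tokens": 200},
)
print(resp.json())
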
README.md
DELETED
@@ -1,3 +0,0 @@
---
license: mit
---