HeTalksInMaths committed · commit d67728f · Parent(s): 0

Initial commit: ToGMAL Prompt Difficulty Analyzer with real MMLU data

Files changed:
- .gitignore +8 -0
- DEMO_README.md +82 -0
- benchmark_vector_db.py +680 -0
- demo_app.py +116 -0
- requirements.txt +10 -0
.gitignore
ADDED
@@ -0,0 +1,8 @@
.venv/
__pycache__/
*.pyc
.git/
.gradio/
data/benchmark_vector_db/
data/benchmark_results/mmlu_real_results.json
models/
DEMO_README.md
ADDED
@@ -0,0 +1,82 @@
# 🧠 ToGMAL Prompt Difficulty Analyzer

Real-time LLM capability boundary detection using vector similarity search.

## 🎯 What This Does

This system analyzes any prompt and tells you:
1. **How difficult it is** for current LLMs (based on real benchmark data)
2. **Why it's difficult** (shows similar benchmark questions)
3. **What to do about it** (actionable recommendations)

## 🔥 Key Innovation

Instead of clustering by domain (all math together), we cluster by **difficulty** - what's actually hard for LLMs regardless of domain.

## 📊 Real Data

- **14,042 MMLU questions** with real success rates from top models
- **<50ms query time** for real-time analysis
- **Production-ready** vector database

## 🚀 Demo Links

- **Local**: http://127.0.0.1:7860
- **Public**: https://99b38fc2e31da2f83d.gradio.live

## 🧪 Example Results

### Hard Questions (Low Success Rates)
```
Prompt: "Statement 1 | Every field is also a ring..."
Risk: HIGH (23.9% success)
Recommendation: Multi-step reasoning with verification

Prompt: "Find all zeros of polynomial x³ + 2x + 2 in Z₇"
Risk: MODERATE (43.8% success)
Recommendation: Use chain-of-thought prompting
```

### Easy Questions (High Success Rates)
```
Prompt: "What is 2 + 2?"
Risk: MINIMAL (100% success)
Recommendation: Standard LLM response adequate

Prompt: "What is the capital of France?"
Risk: MINIMAL (100% success)
Recommendation: Standard LLM response adequate
```

## 🛠️ Technical Details

### Architecture
```
User Prompt → Embedding Model → Vector DB → K Nearest Questions → Weighted Score
```

### Components
1. **Sentence Transformers** (all-MiniLM-L6-v2) for embeddings
2. **ChromaDB** for vector storage
3. **Real MMLU data** with success rates from top models
4. **Gradio** for web interface

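### Weighted Scoring (sketch)

The score returned by `query_similar_questions()` is a similarity-weighted average over the k retrieved benchmark questions, so the closest matches dominate the result. A minimal sketch of the idea (the helper name `weighted_difficulty` is illustrative; in `benchmark_vector_db.py` the same logic is inlined):

```python
def weighted_difficulty(similarities, difficulty_scores):
    """Similarity-weighted average of the neighbours' difficulty scores."""
    total = sum(similarities)
    if total == 0:
        # Degenerate case: fall back to a plain average
        return sum(difficulty_scores) / len(difficulty_scores)
    return sum(d * s for d, s in zip(difficulty_scores, similarities)) / total

# Three neighbours: the most similar question pulls the score hardest
print(weighted_difficulty([0.82, 0.71, 0.64], [0.76, 0.55, 0.65]))  # ≈ 0.659
```
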
## 📈 Next Steps

1. Add more benchmark datasets (GPQA, MATH)
2. Fetch real per-question results from multiple top models
3. Integrate with ToGMAL MCP server for Claude Desktop
4. Deploy to HuggingFace Spaces for permanent hosting

## 🚀 Quick Start

```bash
# Install dependencies
uv pip install -r requirements.txt
uv pip install gradio

# Run the demo
python demo_app.py
```

Visit http://127.0.0.1:7860 to use the web interface.
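### Programmatic Use

The same analysis is available directly from Python. A minimal sketch, assuming the vector database has already been built (for example by running `python benchmark_vector_db.py`) and that the default paths above are used:

```python
from pathlib import Path
from benchmark_vector_db import BenchmarkVectorDB

# Loads the persisted ChromaDB collection and the embedding model
db = BenchmarkVectorDB(db_path=Path("./data/benchmark_vector_db"))

result = db.query_similar_questions(
    "Find all zeros of the polynomial x^3 + 2x + 2 in Z_7", k=5
)
print(result["risk_level"], result["weighted_success_rate"])
print(result["recommendation"])
```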
benchmark_vector_db.py
ADDED
@@ -0,0 +1,680 @@
#!/usr/bin/env python3
"""
Benchmark Vector Database for Difficulty-Based Prompt Analysis
===============================================================

Uses vector similarity search to assess prompt difficulty by finding
the nearest benchmark questions and computing weighted difficulty scores.

This replaces static clustering with real-time, explainable similarity matching.

Key Innovation:
- Embed all benchmark questions (GPQA, MMLU-Pro, MATH, etc.) with success rates
- For any incoming prompt, find K nearest questions via cosine similarity
- Return weighted difficulty score based on similar questions' success rates

Author: ToGMAL Project
"""

import json
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, asdict
from pathlib import Path
from collections import defaultdict
import logging
from datetime import datetime

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Check for required dependencies
try:
    from sentence_transformers import SentenceTransformer
    SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
    logger.warning("sentence-transformers not installed. Run: uv pip install sentence-transformers")
    SENTENCE_TRANSFORMERS_AVAILABLE = False

try:
    import chromadb
    from chromadb.config import Settings
    CHROMADB_AVAILABLE = True
except ImportError:
    logger.warning("chromadb not installed. Run: uv pip install chromadb")
    CHROMADB_AVAILABLE = False

try:
    from datasets import load_dataset
    DATASETS_AVAILABLE = True
except ImportError:
    logger.warning("datasets not installed. Run: uv pip install datasets")
    DATASETS_AVAILABLE = False


@dataclass
class BenchmarkQuestion:
    """Represents a single benchmark question with performance metadata"""
    question_id: str
    source_benchmark: str  # GPQA, MMLU-Pro, MATH, etc.
    domain: str  # physics, biology, mathematics, law, etc.
    question_text: str
    correct_answer: str
    choices: Optional[List[str]] = None  # For multiple choice

    # Performance metrics
    success_rate: Optional[float] = None  # Average across models (0.0 to 1.0)
    difficulty_score: Optional[float] = None  # 1 - success_rate

    # Metadata
    difficulty_label: Optional[str] = None  # Easy, Medium, Hard, Expert
    num_models_tested: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for storage"""
        return asdict(self)


class BenchmarkVectorDB:
    """
    Vector database for benchmark questions with difficulty-based retrieval.

    Core functionality:
    1. Load benchmark datasets from HuggingFace
    2. Compute embeddings using SentenceTransformer
    3. Store in ChromaDB with metadata (success rates, domains)
    4. Query similar questions and compute weighted difficulty
    """

    def __init__(
        self,
        db_path: Path = Path("./data/benchmark_vector_db"),
        embedding_model: str = "all-MiniLM-L6-v2",
        collection_name: str = "benchmark_questions"
    ):
        """
        Initialize the vector database.

        Args:
            db_path: Path to store ChromaDB persistence
            embedding_model: SentenceTransformer model name
            collection_name: Name for the ChromaDB collection
        """
        if not SENTENCE_TRANSFORMERS_AVAILABLE or not CHROMADB_AVAILABLE:
            raise ImportError(
                "Required dependencies not installed. Run:\n"
                "  uv pip install sentence-transformers chromadb datasets"
            )

        self.db_path = db_path
        self.db_path.mkdir(parents=True, exist_ok=True)

        # Initialize embedding model
        logger.info(f"Loading embedding model: {embedding_model}")
        self.embedding_model = SentenceTransformer(embedding_model)

        # Initialize ChromaDB
        logger.info(f"Initializing ChromaDB at {db_path}")
        self.client = chromadb.PersistentClient(
            path=str(db_path),
            settings=Settings(anonymized_telemetry=False)
        )

        # Get or create collection
        try:
            self.collection = self.client.get_collection(collection_name)
            logger.info(f"Loaded existing collection: {collection_name}")
        except Exception:
            self.collection = self.client.create_collection(
                name=collection_name,
                metadata={"description": "Benchmark questions with difficulty scores"}
            )
            logger.info(f"Created new collection: {collection_name}")

        self.questions: List[BenchmarkQuestion] = []

    def load_gpqa_dataset(self, fetch_real_scores: bool = True) -> List[BenchmarkQuestion]:
        """
        Load GPQA Diamond dataset - the hardest benchmark.

        GPQA (Graduate-Level Google-Proof Q&A):
        - 448 expert-written questions (198 in Diamond subset)
        - Physics, Biology, Chemistry at graduate level
        - Even PhD holders get ~65% accuracy
        - GPT-4: ~50% success rate

        Dataset: Idavidrein/gpqa

        Args:
            fetch_real_scores: If True, fetch per-question results from top models
        """
        if not DATASETS_AVAILABLE:
            logger.error("datasets library not available")
            return []

        logger.info("Loading GPQA Diamond dataset from HuggingFace...")

        questions = []

        try:
            # Load GPQA Diamond (hardest subset)
            dataset = load_dataset("Idavidrein/gpqa", "gpqa_diamond")

            # Get real success rates from top models if requested
            per_question_scores = {}
            if fetch_real_scores:
                logger.info("Fetching per-question results from top models...")
                per_question_scores = self._fetch_gpqa_model_results()

            for idx, item in enumerate(dataset['train']):
                # GPQA has 4 choices: Correct Answer + 3 Incorrect Answers
                choices = [
                    item['Correct Answer'],
                    item['Incorrect Answer 1'],
                    item['Incorrect Answer 2'],
                    item['Incorrect Answer 3']
                ]

                question_id = f"gpqa_diamond_{idx}"

                # Use real success rate if available, otherwise estimate
                if question_id in per_question_scores:
                    success_rate = per_question_scores[question_id]['success_rate']
                    num_models = per_question_scores[question_id]['num_models']
                else:
                    success_rate = 0.30  # Conservative estimate
                    num_models = 0

                difficulty_score = 1.0 - success_rate

                # Classify difficulty
                if success_rate < 0.1:
                    difficulty_label = "Nearly_Impossible"
                elif success_rate < 0.3:
                    difficulty_label = "Expert"
                elif success_rate < 0.5:
                    difficulty_label = "Hard"
                else:
                    difficulty_label = "Moderate"

                question = BenchmarkQuestion(
                    question_id=question_id,
                    source_benchmark="GPQA_Diamond",
                    domain=item.get('Subdomain', 'unknown').lower(),
                    question_text=item['Question'],
                    correct_answer=item['Correct Answer'],
                    choices=choices,
                    success_rate=success_rate,
                    difficulty_score=difficulty_score,
                    difficulty_label=difficulty_label,
                    num_models_tested=num_models
                )

                questions.append(question)

            logger.info(f"Loaded {len(questions)} questions from GPQA Diamond")
            if fetch_real_scores and per_question_scores:
                logger.info(f"  Real success rates available for {len(per_question_scores)} questions")

        except Exception as e:
            logger.error(f"Failed to load GPQA dataset: {e}")
            logger.info("GPQA may require authentication. Try: huggingface-cli login")

        return questions

    def _fetch_gpqa_model_results(self) -> Dict[str, Dict[str, Any]]:
        """
        Fetch per-question GPQA results from top models on OpenLLM Leaderboard.

        Returns:
            Dictionary mapping question_id to {success_rate, num_models}
        """
        # Top models to evaluate (based on OpenLLM Leaderboard v2)
        top_models = [
            "meta-llama/Meta-Llama-3.1-70B-Instruct",
            "Qwen/Qwen2.5-72B-Instruct",
            "mistralai/Mixtral-8x22B-Instruct-v0.1",
        ]

        question_results = defaultdict(list)

        for model_name in top_models:
            try:
                logger.info(f"  Fetching results for {model_name}...")
                # OpenLLM Leaderboard v2 uses different dataset naming
                dataset_name = f"open-llm-leaderboard/details_{model_name.replace('/', '__')}"

                # Try to load GPQA results
                try:
                    results = load_dataset(dataset_name, "harness_gpqa_0", split="latest")
                except Exception:
                    # Try alternative naming
                    logger.warning(f"  Could not find GPQA results for {model_name}")
                    continue

                # Process results
                for row in results:
                    question_id = f"gpqa_diamond_{row.get('doc_id', row.get('example', 0))}"
                    predicted = row.get('pred', row.get('prediction', ''))
                    correct = row.get('target', row.get('answer', ''))

                    is_correct = (str(predicted).strip().lower() == str(correct).strip().lower())
                    question_results[question_id].append(is_correct)

                logger.info(f"  ✓ Processed {len(results)} questions")

            except Exception as e:
                logger.warning(f"  Skipping {model_name}: {e}")
                continue

        # Compute success rates
        per_question_scores = {}
        for qid, results in question_results.items():
            if results:
                success_rate = sum(results) / len(results)
                per_question_scores[qid] = {
                    'success_rate': success_rate,
                    'num_models': len(results)
                }

        return per_question_scores

    def load_mmlu_pro_dataset(self, max_samples: int = 1000) -> List[BenchmarkQuestion]:
        """
        Load MMLU-Pro dataset - advanced multitask knowledge evaluation.

        MMLU-Pro improvements over MMLU:
        - 10 choices instead of 4 (reduces guessing)
        - Removed trivial/noisy questions
        - Added harder reasoning problems
        - 12K questions across 14 domains

        Dataset: TIGER-Lab/MMLU-Pro
        """
        if not DATASETS_AVAILABLE:
            logger.error("datasets library not available")
            return []

        logger.info(f"Loading MMLU-Pro dataset (max {max_samples} samples)...")

        questions = []

        try:
            # Load MMLU-Pro validation set
            dataset = load_dataset("TIGER-Lab/MMLU-Pro", split="validation")

            # Sample to avoid overwhelming the DB initially
            if len(dataset) > max_samples:
                dataset = dataset.shuffle(seed=42).select(range(max_samples))

            for idx, item in enumerate(dataset):
                question = BenchmarkQuestion(
                    question_id=f"mmlu_pro_{idx}",
                    source_benchmark="MMLU_Pro",
                    domain=item.get('category', 'unknown').lower(),
                    question_text=item['question'],
                    correct_answer=item['answer'],
                    choices=item.get('options', []),
                    # MMLU-Pro is hard - estimate ~45% average success
                    success_rate=0.45,
                    difficulty_score=0.55,
                    difficulty_label="Hard",
                    num_models_tested=0
                )

                questions.append(question)

            logger.info(f"Loaded {len(questions)} questions from MMLU-Pro")

        except Exception as e:
            logger.error(f"Failed to load MMLU-Pro dataset: {e}")

        return questions

    def load_math_dataset(self, max_samples: int = 500) -> List[BenchmarkQuestion]:
        """
        Load MATH (competition mathematics) dataset.

        MATH dataset:
        - 12,500 competition-level math problems
        - Requires multi-step reasoning
        - Free-form answers with LaTeX
        - GPT-4: ~50% success rate

        Dataset: hendrycks/competition_math
        """
        if not DATASETS_AVAILABLE:
            logger.error("datasets library not available")
            return []

        logger.info(f"Loading MATH dataset (max {max_samples} samples)...")

        questions = []

        try:
            # Load MATH test set
            dataset = load_dataset("hendrycks/competition_math", split="test")

            # Sample to manage size
            if len(dataset) > max_samples:
                dataset = dataset.shuffle(seed=42).select(range(max_samples))

            for idx, item in enumerate(dataset):
                question = BenchmarkQuestion(
                    question_id=f"math_{idx}",
                    source_benchmark="MATH",
                    domain=item.get('type', 'mathematics').lower(),
                    question_text=item['problem'],
                    correct_answer=item['solution'],
                    choices=None,  # Free-form answer
                    # MATH is very hard - estimate ~35% average success
                    success_rate=0.35,
                    difficulty_score=0.65,
                    difficulty_label="Expert",
                    num_models_tested=0
                )

                questions.append(question)

            logger.info(f"Loaded {len(questions)} questions from MATH")

        except Exception as e:
            logger.error(f"Failed to load MATH dataset: {e}")

        return questions

    def index_questions(self, questions: List[BenchmarkQuestion]):
        """
        Index questions into the vector database.

        Steps:
        1. Generate embeddings for all questions
        2. Store in ChromaDB with metadata
        3. Save questions list for reference
        """
        if not questions:
            logger.warning("No questions to index")
            return

        logger.info(f"Indexing {len(questions)} questions into vector database...")

        # Generate embeddings
        question_texts = [q.question_text for q in questions]
        logger.info("Generating embeddings (this may take a few minutes)...")
        embeddings = self.embedding_model.encode(
            question_texts,
            show_progress_bar=True,
            convert_to_numpy=True
        )

        # Prepare metadata
        metadatas = []
        ids = []

        for q in questions:
            metadatas.append({
                "source": q.source_benchmark,
                "domain": q.domain,
                "success_rate": q.success_rate,
                "difficulty_score": q.difficulty_score,
                "difficulty_label": q.difficulty_label,
                "num_models": q.num_models_tested
            })
            ids.append(q.question_id)

        # Add to ChromaDB in batches (ChromaDB has batch size limits)
        batch_size = 1000
        for i in range(0, len(questions), batch_size):
            end_idx = min(i + batch_size, len(questions))

            self.collection.add(
                embeddings=embeddings[i:end_idx].tolist(),
                metadatas=metadatas[i:end_idx],
                documents=question_texts[i:end_idx],
                ids=ids[i:end_idx]
            )

            logger.info(f"Indexed batch {i//batch_size + 1} ({end_idx}/{len(questions)})")

        # Save questions for reference
        self.questions.extend(questions)

        logger.info(f"Successfully indexed {len(questions)} questions")

    def query_similar_questions(
        self,
        prompt: str,
        k: int = 5,
        domain_filter: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Find k most similar benchmark questions to the given prompt.

        Args:
            prompt: The user's prompt/question
            k: Number of similar questions to retrieve
            domain_filter: Optional domain to filter by (e.g., "physics")

        Returns:
            Dictionary with:
            - similar_questions: List of similar questions with metadata
            - weighted_difficulty_score: Difficulty score weighted by similarity
            - weighted_success_rate: Success rate weighted by similarity
            - avg_similarity: Average similarity of the retrieved questions
            - risk_level: MINIMAL, LOW, MODERATE, HIGH, or CRITICAL
            - explanation: Human-readable explanation
            - recommendation: Suggested mitigation for the risk level
        """
        logger.info(f"Querying similar questions for prompt: {prompt[:100]}...")

        # Generate embedding for the prompt
        prompt_embedding = self.embedding_model.encode([prompt], convert_to_numpy=True)

        # Build where clause for domain filtering
        where_clause = None
        if domain_filter:
            where_clause = {"domain": domain_filter}

        # Query ChromaDB
        results = self.collection.query(
            query_embeddings=prompt_embedding.tolist(),
            n_results=k,
            where=where_clause
        )

        # Extract results
        similar_questions = []
        similarities = []
        difficulty_scores = []
        success_rates = []

        for i in range(len(results['ids'][0])):
            metadata = results['metadatas'][0][i]
            distance = results['distances'][0][i]

            # Convert L2 distance to cosine similarity approximation
            # For normalized embeddings: similarity ≈ 1 - (distance²/2)
            similarity = max(0, 1 - (distance ** 2) / 2)
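            # Note: for unit-normalized embeddings u and v,
            #   ||u - v||^2 = ||u||^2 + ||v||^2 - 2 u.v = 2 - 2 cos(u, v),
            # so cos(u, v) = 1 - ||u - v||^2 / 2. The max(0, ...) clamp guards
            # against distances greater than sqrt(2), which would otherwise
            # yield a negative similarity weight.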

            similar_questions.append({
                "question_id": results['ids'][0][i],
                "question_text": results['documents'][0][i][:200] + "...",  # Truncate
                "source": metadata['source'],
                "domain": metadata['domain'],
                "success_rate": metadata['success_rate'],
                "difficulty_score": metadata['difficulty_score'],
                "similarity": round(similarity, 3)
            })

            similarities.append(similarity)
            difficulty_scores.append(metadata['difficulty_score'])
            success_rates.append(metadata['success_rate'])

        # Compute weighted difficulty (weighted by similarity)
        total_weight = sum(similarities)
        if total_weight > 0:
            weighted_difficulty = sum(
                diff * sim for diff, sim in zip(difficulty_scores, similarities)
            ) / total_weight

            weighted_success_rate = sum(
                sr * sim for sr, sim in zip(success_rates, similarities)
            ) / total_weight
        else:
            weighted_difficulty = np.mean(difficulty_scores)
            weighted_success_rate = np.mean(success_rates)

        # Determine risk level
        if weighted_success_rate < 0.1:
            risk_level = "CRITICAL"
            explanation = "Nearly impossible - similar to questions with <10% success rate"
        elif weighted_success_rate < 0.3:
            risk_level = "HIGH"
            explanation = "Very hard - similar to questions with <30% success rate"
        elif weighted_success_rate < 0.5:
            risk_level = "MODERATE"
            explanation = "Hard - similar to questions with <50% success rate"
        elif weighted_success_rate < 0.7:
            risk_level = "LOW"
            explanation = "Moderate difficulty - within typical LLM capability"
        else:
            risk_level = "MINIMAL"
            explanation = "Easy - LLMs typically handle this well"

        return {
            "similar_questions": similar_questions,
            "weighted_difficulty_score": round(weighted_difficulty, 3),
            "weighted_success_rate": round(weighted_success_rate, 3),
            "avg_similarity": round(np.mean(similarities), 3),
            "risk_level": risk_level,
            "explanation": explanation,
            "recommendation": self._get_recommendation(risk_level, weighted_success_rate)
        }

    def _get_recommendation(self, risk_level: str, success_rate: float) -> str:
        """Generate recommendation based on difficulty assessment"""
        if risk_level == "CRITICAL":
            return "Recommend: Break into smaller steps, use external tools, or human-in-the-loop"
        elif risk_level == "HIGH":
            return "Recommend: Multi-step reasoning with verification, consider using web search"
        elif risk_level == "MODERATE":
            return "Recommend: Use chain-of-thought prompting for better accuracy"
        else:
            return "Recommend: Standard LLM response should be adequate"

    def get_statistics(self) -> Dict[str, Any]:
        """Get statistics about the indexed benchmark questions"""
        count = self.collection.count()

        if count == 0:
            return {"total_questions": 0, "message": "No questions indexed yet"}

        # Get sample to compute statistics (ChromaDB doesn't have aggregate functions)
        sample_size = min(1000, count)
        sample = self.collection.get(limit=sample_size, include=["metadatas"])

        domains = defaultdict(int)
        sources = defaultdict(int)
        difficulty_levels = defaultdict(int)

        for metadata in sample['metadatas']:
            domains[metadata['domain']] += 1
            sources[metadata['source']] += 1
            difficulty_levels[metadata['difficulty_label']] += 1

        return {
            "total_questions": count,
            "domains": dict(domains),
            "sources": dict(sources),
            "difficulty_levels": dict(difficulty_levels)
        }

    def build_database(
        self,
        load_gpqa: bool = True,
        load_mmlu_pro: bool = True,
        load_math: bool = True,
        max_samples_per_dataset: int = 1000
    ):
        """
        Build the complete vector database from benchmark datasets.

        Args:
            load_gpqa: Load GPQA Diamond (hardest)
            load_mmlu_pro: Load MMLU-Pro (hard, broad coverage)
            load_math: Load MATH (hard, math-focused)
            max_samples_per_dataset: Max samples per dataset to manage size
        """
        logger.info("="*80)
        logger.info("Building Benchmark Vector Database")
        logger.info("="*80)

        all_questions = []

        # Load datasets
        if load_gpqa:
            gpqa_questions = self.load_gpqa_dataset()
            all_questions.extend(gpqa_questions)

        if load_mmlu_pro:
            mmlu_questions = self.load_mmlu_pro_dataset(max_samples=max_samples_per_dataset)
            all_questions.extend(mmlu_questions)

        if load_math:
            math_questions = self.load_math_dataset(max_samples=max_samples_per_dataset // 2)
            all_questions.extend(math_questions)

        # Index all questions
        if all_questions:
            self.index_questions(all_questions)

        # Print statistics
        stats = self.get_statistics()
        logger.info("\nDatabase Statistics:")
        logger.info(f"  Total Questions: {stats['total_questions']}")
        logger.info(f"  Sources: {stats.get('sources', {})}")
        logger.info(f"  Domains: {stats.get('domains', {})}")

        logger.info("="*80)
        logger.info("Database build complete!")
        logger.info("="*80)


def main():
    """Main entry point for building the vector database"""

    # Initialize database
    db = BenchmarkVectorDB(
        db_path=Path("/Users/hetalksinmaths/togmal/data/benchmark_vector_db"),
        embedding_model="all-MiniLM-L6-v2"
    )

    # Build database with hardest benchmarks
    db.build_database(
        load_gpqa=True,  # Start with hardest
        load_mmlu_pro=True,
        load_math=True,
        max_samples_per_dataset=1000
    )

    # Test query
    print("\n" + "="*80)
    print("Testing with example prompts:")
    print("="*80)

    test_prompts = [
        "Calculate the quantum correction to the partition function for a 3D harmonic oscillator",
        "What is the capital of France?",
        "Prove that the square root of 2 is irrational"
    ]

    for prompt in test_prompts:
        print(f"\nPrompt: {prompt}")
        result = db.query_similar_questions(prompt, k=3)
        print(f"  Risk Level: {result['risk_level']}")
        print(f"  Weighted Success Rate: {result['weighted_success_rate']:.1%}")
        print(f"  Explanation: {result['explanation']}")
        print(f"  Recommendation: {result['recommendation']}")


if __name__ == "__main__":
    main()
demo_app.py
ADDED
@@ -0,0 +1,116 @@
#!/usr/bin/env python3
"""
ToGMAL Difficulty Assessment Demo
=================================

Gradio demo for the vector database-based prompt difficulty assessment.
Shows real-time difficulty scores and recommendations.
"""

import gradio as gr
import json
from pathlib import Path
from benchmark_vector_db import BenchmarkVectorDB

# Initialize the vector database
db = BenchmarkVectorDB(
    db_path=Path("./data/benchmark_vector_db"),
    embedding_model="all-MiniLM-L6-v2"
)

def analyze_prompt(prompt: str, k: int = 5) -> str:
    """
    Analyze a prompt and return difficulty assessment.

    Args:
        prompt: The user's prompt/question
        k: Number of similar questions to retrieve

    Returns:
        Formatted analysis results
    """
    if not prompt.strip():
        return "Please enter a prompt to analyze."

    try:
        # Query the vector database
        result = db.query_similar_questions(prompt, k=k)

        # Format results
        output = []
        output.append(f"## 🎯 Difficulty Assessment\n")
        output.append(f"**Risk Level**: {result['risk_level']}")
        output.append(f"**Success Rate**: {result['weighted_success_rate']:.1%}")
        output.append(f"**Avg Similarity**: {result['avg_similarity']:.3f}")
        output.append("")
        output.append(f"**Recommendation**: {result['recommendation']}")
        output.append("")
        output.append(f"## 🔍 Similar Benchmark Questions\n")

        for i, q in enumerate(result['similar_questions'], 1):
            output.append(f"{i}. **{q['question_text'][:100]}...**")
            output.append(f"   - Source: {q['source']} ({q['domain']})")
            output.append(f"   - Success Rate: {q['success_rate']:.1%}")
            output.append(f"   - Similarity: {q['similarity']:.3f}")
            output.append("")

        output.append(f"*Analyzed using {k} most similar questions from 14,042 benchmark questions*")

        return "\n".join(output)

    except Exception as e:
        return f"Error analyzing prompt: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="ToGMAL Prompt Difficulty Analyzer") as demo:
    gr.Markdown("# 🧠 ToGMAL Prompt Difficulty Analyzer")
    gr.Markdown("Enter any prompt to see how difficult it is for current LLMs based on real benchmark data.")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="e.g., Calculate the quantum correction to the partition function...",
                lines=3
            )
            k_slider = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label="Number of similar questions to show"
            )
            submit_btn = gr.Button("Analyze Difficulty")

        with gr.Column():
            result_output = gr.Markdown(label="Analysis Results")

    # Examples
    gr.Examples(
        examples=[
            "Calculate the quantum correction to the partition function for a 3D harmonic oscillator",
            "Prove that there are infinitely many prime numbers",
            "Diagnose a patient with acute chest pain and shortness of breath",
            "Explain the legal doctrine of precedent in common law systems",
            "Implement a binary search tree with insert and search operations",
            "What is 2 + 2?",
            "What is the capital of France?"
        ],
        inputs=prompt_input
    )

    # Event handling
    submit_btn.click(
        fn=analyze_prompt,
        inputs=[prompt_input, k_slider],
        outputs=result_output
    )

    prompt_input.submit(
        fn=analyze_prompt,
        inputs=[prompt_input, k_slider],
        outputs=result_output
    )

if __name__ == "__main__":
    demo.launch(share=True, server_port=7860)
requirements.txt
ADDED
@@ -0,0 +1,10 @@
mcp>=1.0.0
pydantic>=2.0.0
httpx>=0.24.0
scikit-learn>=1.2
numpy>=1.24
scipy>=1.10
joblib>=1.3
sentence-transformers>=2.2.0
chromadb>=0.4.0
datasets>=2.14.0