Commit 66f1733
Parent(s): init

Files changed:
- .DS_Store +0 -0
- .env.example +1 -0
- .gitignore +17 -0
- .python-version +1 -0
- Dockerfile +10 -0
- README.md +0 -0
- docker-compose.yaml +18 -0
- evolutiontransformer/.DS_Store +0 -0
- evolutiontransformer/__init__.py +0 -0
- evolutiontransformer/api.py +98 -0
- evolutiontransformer/redis.py +45 -0
- evolutiontransformer/worker.py +271 -0
- finetuning/finetuning.ipynb +336 -0
- frontend/.gitignore +24 -0
- frontend/README.md +12 -0
- frontend/eslint.config.js +29 -0
- frontend/index.html +13 -0
- frontend/package-lock.json +0 -0
- frontend/package.json +31 -0
- frontend/public/vite.svg +1 -0
- frontend/src/App.css +44 -0
- frontend/src/App.jsx +35 -0
- frontend/src/assets/react.svg +1 -0
- frontend/src/index.css +68 -0
- frontend/src/main.jsx +10 -0
- frontend/vite.config.js +8 -0
- main.py +8 -0
- pyproject.toml +22 -0
- tests/__init__.py +0 -0
- tests/test_api.py +310 -0
- tests/test_model_actions.py +83 -0
- uv.lock +0 -0
.DS_Store
ADDED
Binary file (8.2 kB).
.env.example
ADDED
@@ -0,0 +1 @@
+REDIS_URL=""
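Editor's note: REDIS_URL is the only configuration value in this commit; both the API and the worker read it through python-dotenv with a localhost fallback (see api.py and worker.py below). A minimal sketch of that pattern, for reference only:

    # Sketch of how the .env value is consumed (mirrors api.py / worker.py).
    import os
    from dotenv import load_dotenv

    load_dotenv()  # reads .env from the working directory, if present
    REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")  # same fallback as the code below
    print(REDIS_URL)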
.gitignore
ADDED
@@ -0,0 +1,17 @@
+# Python-generated files
+__pycache__/
+*.py[oc]
+build/
+dist/
+wheels/
+*.egg-info
+
+.env
+.venv
+
+/models/
+
+frontend/node_modules
+frontend/build
+
+.DS_Store
.python-version
ADDED
@@ -0,0 +1 @@
+3.12
Dockerfile
ADDED
@@ -0,0 +1,10 @@
+FROM python:3.12-slim
+WORKDIR /code
+
+RUN pip install uv
+
+COPY pyproject.toml uv.lock ./
+
+RUN uv export --no-dev | uv pip install --system -r -
+
+COPY evolutiontransformer/ ./evolutiontransformer/
README.md
ADDED
File without changes
docker-compose.yaml
ADDED
@@ -0,0 +1,18 @@
+services:
+  api:
+    build: .
+    command: >
+      sh -c "gunicorn evolutiontransformer.api:app -w 4 -k uvicorn.workers.UvicornWorker -b 0.0.0.0:8000"
+    ports:
+      - "8000:8000"
+    depends_on:
+      - worker
+    env_file:
+      - .env
+
+  worker:
+    build: .
+    command: >
+      sh -c "/usr/local/bin/celery -A evolutiontransformer.worker.celery_app worker --loglevel=info -c 1"
+    env_file:
+      - .env
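Editor's note: the compose file only starts the API and the worker; nothing here starts Redis, so REDIS_URL in .env must point at an already-running broker/result backend. A quick connectivity check, assuming the redis package from pyproject.toml:

    # Verify the broker both services rely on is reachable (assumes `redis` is installed).
    import os
    from redis import Redis

    redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
    Redis.from_url(redis_url).ping()  # raises if the broker is unreachable
    print("Redis reachable at", redis_url)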
evolutiontransformer/.DS_Store
ADDED
Binary file (6.15 kB).
evolutiontransformer/__init__.py
ADDED
File without changes
evolutiontransformer/api.py
ADDED
@@ -0,0 +1,98 @@
+import os
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+import uuid
+from typing import List, Tuple
+from fastapi import FastAPI, Depends, HTTPException, Request, Response
+from pydantic import BaseModel
+from celery import Celery
+from dotenv import load_dotenv
+
+load_dotenv()
+
+
+REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
+
+
+celery_app = Celery("tasks", broker=REDIS_URL, backend=REDIS_URL)
+
+app = FastAPI()
+
+
+class GenerateRequest(BaseModel):
+    model_name: str
+    prompt: str
+    max_new_tokens: int = 512
+    temperature: float = 0.7
+
+
+class MergeRequest(BaseModel):
+    model1_name: str
+    model2_name: str
+    layer_recipe: List[List[Tuple[int, int, float]]]
+    embedding_lambdas: List[float] = [0.5, 0.5]
+    linear_lambdas: List[float] = [0.5, 0.5]
+    merged_name: str = "merged"
+
+
+def get_session_id(request: Request, response: Response):
+    session_id = request.cookies.get("session_id")
+
+    if not session_id:
+        session_id = str(uuid.uuid4())
+        response.set_cookie(key="session_id", value=session_id)
+
+    return session_id
+
+
+@app.post("/generate")
+def generate(request: GenerateRequest, session_id: str = Depends(get_session_id)):
+    task = celery_app.send_task(
+        "tasks.inference",
+        args=[
+            session_id,
+            request.model_name,
+            request.prompt,
+            request.max_new_tokens,
+            request.temperature,
+        ],
+    )
+    return {"task_id": task.id}
+
+
+@app.post("/merge")
+def merge(request: MergeRequest, session_id: str = Depends(get_session_id)):
+    task = celery_app.send_task(
+        "tasks.merge_models",
+        args=[
+            session_id,
+            request.model1_name,
+            request.model2_name,
+            request.layer_recipe,
+            request.embedding_lambdas,
+            request.linear_lambdas,
+            request.merged_name,
+        ],
+    )
+
+    return {"task_id": task.id}
+
+
+@app.post("/list_models")
+def list_models(session_id: str = Depends(get_session_id)):
+    task = celery_app.send_task("tasks.get_all_models", args=[session_id])
+    return {"task_id": task.id}
+
+
+@app.get("/tasks/{task_id}")
+def get_task_status(task_id: str):
+    task_result = celery_app.AsyncResult(task_id)
+
+    if task_result.ready():
+        if task_result.status == "FAILURE":
+            raise HTTPException(status_code=500, detail=str(task_result.result))
+        else:
+            return {"status": task_result.status, "result": task_result.result}
+    else:
+        return {"status": task_result.status}
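Editor's note: all three POST endpoints return a Celery task id rather than a result, so a client has to poll /tasks/{task_id} until the task finishes; the session cookie set by get_session_id scopes which models a client can see. A minimal client sketch (the requests library and the localhost:8000 address are assumptions, not part of this commit):

    # Hypothetical client for the task-based API above.
    import time
    import requests

    BASE_URL = "http://localhost:8000"  # assumed host:port, matching the compose port mapping

    session = requests.Session()  # keeps the session_id cookie set by get_session_id
    task_id = session.post(
        f"{BASE_URL}/generate",
        json={"model_name": "svamp", "prompt": "1 + 1 =", "max_new_tokens": 20},
    ).json()["task_id"]

    while True:
        r = session.get(f"{BASE_URL}/tasks/{task_id}")
        if r.status_code == 500:  # the API maps task failures to HTTP 500
            raise RuntimeError(r.json()["detail"])
        status = r.json()
        if status["status"] == "SUCCESS":
            print(status["result"]["response"])
            break
        time.sleep(2)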
evolutiontransformer/redis.py
ADDED
@@ -0,0 +1,45 @@
+import os
+from redis import Redis
+import json
+
+
+REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
+redis_client = Redis.from_url(REDIS_URL, decode_responses=True)
+
+
+def add_model_to_session(session_id: str, model_name: str, ttl_seconds: int = 3600):
+    session_key = f"session:{session_id}:models"
+    redis_client.sadd(session_key, model_name)
+    redis_client.expire(session_key, ttl_seconds)
+
+
+def get_session_models(session_id: str):
+    session_key = f"session:{session_id}:models"
+    models = redis_client.smembers(session_key)
+    return list(models)
+
+
+def save_model_recipe(
+    session_id: str, model_name: str, recipe: dict, ttl_seconds: int = 3600
+):
+    recipe_key = f"model:{session_id}:{model_name}"
+    serialized_recipe = json.dumps(recipe)
+    redis_client.setex(recipe_key, ttl_seconds, serialized_recipe)
+
+
+def get_model_recipe(session_id: str, model_name: str):
+    recipe_key = f"model:{session_id}:{model_name}"
+    serialized_recipe = redis_client.get(recipe_key)
+    if serialized_recipe:
+        return json.loads(serialized_recipe)
+    return None
+
+
+def delete_session(session_id: str):
+    model_names = get_session_models(session_id)
+
+    for model_name in model_names:
+        recipe_key = f"model:{session_id}:{model_name}"
+        redis_client.delete(recipe_key)
+
+    redis_client.delete(f"session:{session_id}:models")
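Editor's note: these helpers are the whole persistence layer: a session owns a Redis set of model names plus one JSON-encoded recipe per model, each with a one-hour TTL by default. A small usage sketch (the session id and recipe values are made up for illustration):

    # Illustrative use of the helpers above; "demo-session" and the recipe are placeholders.
    from evolutiontransformer.redis import (
        add_model_to_session,
        save_model_recipe,
        get_model_recipe,
        get_session_models,
    )

    recipe = {
        "layer_recipe": [[(0, "svamp", 1.0)]],
        "embedding_lambdas": [1.0, 1.0],
        "linear_lambdas": [1.0, 1.0],
    }
    add_model_to_session("demo-session", "merged_0")       # track the name in the session set
    save_model_recipe("demo-session", "merged_0", recipe)  # store the recipe as JSON with a TTL
    print(get_session_models("demo-session"))              # -> ["merged_0"]
    print(get_model_recipe("demo-session", "merged_0"))    # tuples come back as lists after the JSON round-trip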
evolutiontransformer/worker.py
ADDED
@@ -0,0 +1,271 @@
+import os
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+from celery import Celery
+from celery.exceptions import InvalidTaskError
+import torch
+import torch.nn as nn
+from dotenv import load_dotenv
+from evolutiontransformer.redis import (
+    add_model_to_session,
+    get_session_models,
+    save_model_recipe,
+    get_model_recipe,
+    delete_session,
+)
+
+
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
+from typing import List, Tuple
+from tqdm import tqdm
+
+load_dotenv()
+
+BASE_MODELS_NAMES = ["svamp", "tinystories"]
+BASE_MODELS = {}
+TOKENIZER = None
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
+
+celery_app = Celery("tasks", broker=REDIS_URL, backend=REDIS_URL)
+
+
+def load_base_models_if_needed():
+    global BASE_MODELS
+    if not BASE_MODELS:
+        print("WORKER: Loading base models into memory...")
+        for model_name in BASE_MODELS_NAMES:
+            model_path = f"tcmmichaelb139/gpt2-medium-{model_name}"
+            model = AutoModelForCausalLM.from_pretrained(model_path)
+            BASE_MODELS[model_name] = model.to(DEVICE)
+
+            if get_model_recipe("default", model_name) is None:
+                add_model_to_session("default", model_name)
+                save_model_recipe(
+                    "default",
+                    model_name,
+                    {
+                        "layer_recipe": [[(i, model_name, 1.0)] for i in range(24)],
+                        "embedding_lambdas": [1.0, 1.0],
+                        "linear_lambdas": [1.0, 1.0],
+                    },
+                )
+
+        print("WORKER: Base models loaded.")
+
+
+def get_tokenizer():
+    global TOKENIZER
+    if TOKENIZER is None:
+        print("WORKER: Initializing Tokenizer...")
+        TOKENIZER = AutoTokenizer.from_pretrained("gpt2-medium")
+    return TOKENIZER
+
+
+def inference(model, prompt, max_new_tokens=512, temperature=0.7):
+    global DEVICE
+
+    do_sample = temperature > 0
+    model = model.to(DEVICE)
+    model.eval()
+    tokenizer = get_tokenizer()
+    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=do_sample,
+            temperature=temperature,
+        ).to(DEVICE)
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+
+def merge_model_recipe(
+    model1_recipe: dict,
+    model2_recipe: dict,
+    layer_recipe: List[List[Tuple[int, int, float]]],
+    embedding_lambdas: List[float] = [0.5, 0.5],
+    linear_lambdas: List[float] = [0.5, 0.5],
+) -> dict:
+    models = [model1_recipe, model2_recipe]
+    result_layer_recipe = []
+    for makeup in layer_recipe:
+        layer_result = {}
+        for comb in makeup:
+            idx, model_i, alpha = comb
+
+            for orig_i, orig_model, orig_a in models[model_i]["layer_recipe"][idx]:
+                if (orig_i, orig_model) in layer_result:
+                    layer_result[(orig_i, orig_model)] += alpha * orig_a
+                else:
+                    layer_result[(orig_i, orig_model)] = alpha * orig_a
+
+        final_layer_result = []
+        for k in layer_result:
+            final_layer_result.append((k[0], k[1], layer_result[k]))
+
+        result_layer_recipe.append(final_layer_result)
+
+    result_embedding_lambdas = [
+        embedding_lambdas[0] * model1_recipe["embedding_lambdas"][0]
+        + (1 - embedding_lambdas[0]) * model2_recipe["embedding_lambdas"][0],
+        embedding_lambdas[1] * model1_recipe["embedding_lambdas"][1]
+        + (1 - embedding_lambdas[1]) * model2_recipe["embedding_lambdas"][1],
+    ]
+    result_linear_lambdas = [
+        linear_lambdas[0] * model1_recipe["linear_lambdas"][0]
+        + (1 - linear_lambdas[0]) * model2_recipe["linear_lambdas"][0],
+        linear_lambdas[1] * model1_recipe["linear_lambdas"][1]
+        + (1 - linear_lambdas[1]) * model2_recipe["linear_lambdas"][1],
+    ]
+
+    return {
+        "layer_recipe": result_layer_recipe,
+        "embedding_lambdas": result_embedding_lambdas,
+        "linear_lambdas": result_linear_lambdas,
+    }
+
+
+def merge_models(
+    model_recipe: dict,
+    base_model="gpt2-medium",
+) -> nn.Module:
+    """Merge two models based on a given recipe."""
+
+    model1_name = "svamp"
+    model2_name = "tinystories"
+
+    load_base_models_if_needed()
+
+    def get_model_layer(layer, model):
+        return model.transformer.h[layer].state_dict()
+
+    def merge_layer(recipe: List[Tuple[int, str, float]]):
+        base = get_model_layer(recipe[0][0], BASE_MODELS[recipe[0][1]])
+        for key in base.keys():
+            base[key] = recipe[0][2] * base[key]
+        for layer in recipe[1:]:
+            layer_data = get_model_layer(layer[0], BASE_MODELS[layer[1]])
+            for key in base.keys():
+                base[key] += layer[2] * layer_data[key]
+        return base
+
+    print("### Merging models... ###")
+
+    layer_recipe = model_recipe["layer_recipe"]
+    embedding_lambdas = model_recipe["embedding_lambdas"]
+    linear_lambdas = model_recipe["linear_lambdas"]
+
+    config = AutoConfig.from_pretrained(base_model)
+    config.n_layer = len(layer_recipe)
+
+    child_model = AutoModelForCausalLM.from_config(config).to(DEVICE)
+    child_model.eval()
+
+    print("Merging embeddings and lm_head...")
+    child_model.transformer.wte.weight.data = (
+        embedding_lambdas[0] * BASE_MODELS[model1_name].transformer.wte.weight.data
+        + (1 - embedding_lambdas[0])
+        * BASE_MODELS[model2_name].transformer.wte.weight.data
+    )
+    child_model.transformer.wpe.weight.data = (
+        embedding_lambdas[1] * BASE_MODELS[model1_name].transformer.wpe.weight.data
+        + (1 - embedding_lambdas[1])
+        * BASE_MODELS[model2_name].transformer.wpe.weight.data
+    )
+    child_model.lm_head.weight.data = (
+        linear_lambdas[0] * BASE_MODELS[model1_name].lm_head.weight.data
+        + (1 - linear_lambdas[0]) * BASE_MODELS[model2_name].lm_head.weight.data
+    )
+    child_model.transformer.ln_f.weight.data = (
+        linear_lambdas[1] * BASE_MODELS[model1_name].transformer.ln_f.weight.data
+        + (1 - linear_lambdas[1])
+        * BASE_MODELS[model2_name].transformer.ln_f.weight.data
+    )
+    child_model.transformer.ln_f.bias.data = (
+        linear_lambdas[1] * BASE_MODELS[model1_name].transformer.ln_f.bias.data
+        + (1 - linear_lambdas[1]) * BASE_MODELS[model2_name].transformer.ln_f.bias.data
+    )
+
+    for i, layer in tqdm(enumerate(layer_recipe), desc="Merging layers..."):
+        merged_layer = merge_layer(layer)
+        child_model.transformer.h[i].load_state_dict(merged_layer)
+
+    return child_model
+
+
+def get_model_recipe_default(session_id: str, model_name: str) -> dict:
+    if model_name in BASE_MODELS_NAMES:
+        return get_model_recipe("default", model_name)
+    return get_model_recipe(session_id, model_name)
+
+
+@celery_app.task(name="tasks.inference")
+def inference_task(
+    session_id: str, model_name, prompt, max_new_tokens=512, temperature=0.7
+):
+    try:
+        model_recipe = get_model_recipe_default(session_id, model_name)
+        print("WORKER: Creating merged model...")
+        model = merge_models(model_recipe)
+        print("WORKER: Model loaded.")
+        output = inference(model, prompt, max_new_tokens, temperature)
+        return {"response": output}
+    except Exception as e:
+        raise InvalidTaskError(f"Inference failed: {e}")
+
+
+@celery_app.task(name="tasks.merge_models")
+def merge_models_task(
+    session_id: str,
+    model1_name: str,
+    model2_name: str,
+    layer_recipe: List[List[Tuple[int, int, float]]],
+    embedding_lambdas: List[float] = [0.5, 0.5],
+    linear_lambdas: List[float] = [0.5, 0.5],
+    merged_name: str = "merged",
+):
+    if len(layer_recipe) > 48:
+        raise InvalidTaskError("Layer recipe too long. Max 48 layers supported.")
+
+    session_models = get_session_models(session_id)
+
+    model1_recipe = get_model_recipe_default(session_id, model1_name)
+    model2_recipe = get_model_recipe_default(session_id, model2_name)
+    if model1_recipe is None or model2_recipe is None:
+        raise InvalidTaskError("One of the models does not exist.")
+
+    merged_recipe = merge_model_recipe(
+        model1_recipe,
+        model2_recipe,
+        layer_recipe,
+        embedding_lambdas,
+        linear_lambdas,
+    )
+
+    for i in range(20):
+        full_merged_name = f"{merged_name}_{i}"
+        if full_merged_name not in session_models:
+            add_model_to_session(session_id, full_merged_name)
+            save_model_recipe(session_id, full_merged_name, merged_recipe)
+            return {"response": full_merged_name}
+
+    raise InvalidTaskError("Could not find a unique model name.")
+
+
+@celery_app.task(name="tasks.get_all_models")
+def get_all_models_task(session_id: str):
+    # Session state lives in Redis: combine the base models with this session's merges.
+    return {
+        "response": BASE_MODELS_NAMES + get_session_models(session_id),
+    }
+
+
+@celery_app.task(name="tasks.clear_session_models")
+def clear_session_models_task(session_id: str):
+    # Drop every recipe and the model-name set stored for this session.
+    delete_session(session_id)
+    return {"response": "SUCCESS"}
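Editor's note: merge_model_recipe never touches weights; it rewrites a requested recipe (expressed against the two parents, indexed 0 and 1) into coefficients over the base models, so merges of merges stay flat and merge_models can always build the child directly from svamp and tinystories. A worked example of that flattening (the numbers are chosen only for illustration):

    # Worked example of the recipe flattening done by merge_model_recipe above.
    from evolutiontransformer.worker import merge_model_recipe

    parent1 = {  # pure svamp, a single layer for brevity
        "layer_recipe": [[(0, "svamp", 1.0)]],
        "embedding_lambdas": [1.0, 1.0],
        "linear_lambdas": [1.0, 1.0],
    }
    parent2 = {  # pure tinystories, a single layer for brevity
        "layer_recipe": [[(0, "tinystories", 1.0)]],
        "embedding_lambdas": [0.0, 0.0],
        "linear_lambdas": [0.0, 0.0],
    }

    # One child layer built from 0.6 x parent1's layer 0 plus 0.4 x parent2's layer 0.
    child = merge_model_recipe(
        parent1,
        parent2,
        layer_recipe=[[(0, 0, 0.6), (0, 1, 0.4)]],
        embedding_lambdas=[0.5, 0.5],
        linear_lambdas=[0.5, 0.5],
    )
    print(child["layer_recipe"])       # [[(0, 'svamp', 0.6), (0, 'tinystories', 0.4)]]
    print(child["embedding_lambdas"])  # [0.5, 0.5]: 0.5*1.0 + 0.5*0.0 per slot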
finetuning/finetuning.ipynb
ADDED
@@ -0,0 +1,336 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "gpuType": "T4"
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU"
+  },
+  "cells": [
+    {
+      "cell_type": "code",
+      "source": [
+        "%pip install evaluate"
+      ],
+      "metadata": {
+        "id": "aqcbe-No3r2r"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "lOwXY3N4tmbr"
+      },
+      "outputs": [],
+      "source": [
+        "import numpy as np\n",
+        "import matplotlib\n",
+        "import torch\n",
+        "import torch.nn as nn\n",
+        "from torch.utils.data import Dataset, DataLoader\n",
+        "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments\n",
+        "from datasets import load_dataset\n",
+        "import evaluate\n",
+        "from copy import deepcopy\n",
+        "\n",
+        "SEED=42\n",
+        "MODEL=\"gpt2-medium\"\n",
+        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def tokenize(x, tokenizer):\n",
+        "    output = tokenizer(x[\"text\"], padding=\"max_length\", truncation=True, max_length=512)\n",
+        "    output[\"label\"] = output[\"input_ids\"].copy()\n",
+        "    return output\n",
+        "\n",
+        "def gen_tokenizer(model_name):\n",
+        "    tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+        "    tokenizer.pad_token = tokenizer.eos_token\n",
+        "    return tokenizer\n",
+        "\n",
+        "def finetune(config):\n",
+        "    ds = config[\"ds\"]\n",
+        "    preprocess_function = config[\"datasets_preprocess\"][config[\"dataset\"]]\n",
+        "    tokenizer = gen_tokenizer(config[\"model\"])\n",
+        "\n",
+        "    train_dataset = ds[\"train\"].select(range(config[\"max_train_size\"])).map(\n",
+        "        lambda x: preprocess_function(x, tokenizer),\n",
+        "    )\n",
+        "\n",
+        "    train_dataset = train_dataset.map(lambda x: tokenize(x, tokenizer), batched=True)\n",
+        "\n",
+        "    model = AutoModelForCausalLM.from_pretrained(config[\"model\"])\n",
+        "    orig_model = deepcopy(model)\n",
+        "\n",
+        "    trainer = Trainer(\n",
+        "        model=model,\n",
+        "        args=config[\"training_args\"],\n",
+        "        train_dataset=train_dataset,\n",
+        "        processing_class=tokenizer,\n",
+        "    )\n",
+        "\n",
+        "    print(\"Starting training\")\n",
+        "    trainer.train()\n",
+        "    print(\"Training complete\")\n",
+        "\n",
+        "    return orig_model, model"
+      ],
+      "metadata": {
+        "id": "B3XugMEV5vZF"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def gsm8k_preprocess(x, tokenizer):\n",
+        "    return {\"text\": f\"Question: {x['question']}\\nAnswer: {x['answer']}\" + tokenizer.eos_token}\n",
+        "\n",
+        "def svamp_preprocess(x, tokenizer):\n",
+        "    return {\"text\": f\"{x['question_concat']}\\nAnswer: {x['Answer']}\" + tokenizer.eos_token}\n",
+        "\n",
+        "def tinystories_preprocess(x, tokenizer):\n",
+        "    return {\"text\": x[\"text\"] + tokenizer.eos_token}\n",
+        "\n",
+        "datasets_finetune = {\n",
+        "    \"openai/gsm8k\": gsm8k_preprocess,\n",
+        "    \"ChilleD/SVAMP\": svamp_preprocess,\n",
+        "    \"roneneldan/TinyStories\": tinystories_preprocess\n",
+        "}\n",
+        "\n",
+        "def preprocess_test_gsm8k(x):\n",
+        "    return {\"text\": f\"Question: {x['question']}\\nAnswer:\" }\n",
+        "\n",
+        "def preprocess_test_svamp(x):\n",
+        "    return {\"text\": f\"{x['question_concat']}\\nAnswer:\"}\n",
+        "\n",
+        "def preprocess_test_tinystories(x):\n",
+        "    return {\"text\": x[\"text\"]}\n",
+        "\n",
+        "datasets_finetune_test = {\n",
+        "    \"openai/gsm8k\": preprocess_test_gsm8k,\n",
+        "    \"ChilleD/SVAMP\": preprocess_test_svamp,\n",
+        "    \"roneneldan/TinyStories\": preprocess_test_tinystories\n",
+        "}"
+      ],
+      "metadata": {
+        "id": "Y09qs3FFxwx1"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def test_finetune(dataset, ds, orig_model, model, datasets_preprocess, first_x):\n",
+        "    tokenizer = gen_tokenizer(MODEL)\n",
+        "    preprocess_function = datasets_preprocess[dataset]\n",
+        "    if \"validation\" in ds:\n",
+        "        ds[\"test\"] = deepcopy(ds[\"validation\"])\n",
+        "\n",
+        "    test_dataset = ds[\"test\"].map(\n",
+        "        lambda x: preprocess_function(x),\n",
+        "    )\n",
+        "\n",
+        "    model = model.to(device)\n",
+        "    orig_model = orig_model.to(device)\n",
+        "\n",
+        "    model.eval()\n",
+        "    orig_model.eval()\n",
+        "    xi = 0\n",
+        "    with torch.no_grad():\n",
+        "        for x in test_dataset:\n",
+        "            input_tensor = tokenizer(x[\"text\"], return_tensors=\"pt\")\n",
+        "            input_tensor[\"input_ids\"] = input_tensor[\"input_ids\"].to(device)\n",
+        "            input_tensor[\"attention_mask\"] = input_tensor[\"attention_mask\"].to(device)\n",
+        "\n",
+        "            output = orig_model.generate(**input_tensor, max_new_tokens=512)\n",
+        "\n",
+        "            print(\"Original model output\")\n",
+        "            print(tokenizer.decode(output[0], skip_special_tokens=True))\n",
+        "\n",
+        "            finetuned_output = model.generate(**input_tensor, max_new_tokens=512)\n",
+        "\n",
+        "            print(\"Finetuned model output\")\n",
+        "            print(tokenizer.decode(finetuned_output[0], skip_special_tokens=True))\n",
+        "\n",
+        "            xi += 1\n",
+        "            if xi > first_x:\n",
+        "                break\n"
+      ],
+      "metadata": {
+        "id": "zQqD3dHWDj6H"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def generate_config(dataset):\n",
+        "    return config\n"
+      ],
+      "metadata": {
+        "id": "2hlh0GERDBdC"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "dataset = \"ChilleD/SVAMP\"\n",
+        "ds = load_dataset(dataset, \"default\")\n",
+        "ds_1 = dataset.split('/')[1]\n",
+        "\n",
+        "config = {\n",
+        "    \"ds\": ds,\n",
+        "    \"dataset\": dataset,\n",
+        "    \"datasets_preprocess\": datasets_finetune,\n",
+        "    \"model\": MODEL,\n",
+        "    \"max_train_size\": 700,\n",
+        "    \"training_args\": TrainingArguments(\n",
+        "        output_dir=f\"./results_{ds_1}\",\n",
+        "        report_to=\"none\",\n",
+        "        num_train_epochs=10,\n",
+        "        per_device_train_batch_size=4,\n",
+        "        warmup_steps=200,\n",
+        "        learning_rate=5e-5,\n",
+        "        weight_decay=0.01,\n",
+        "        logging_steps=200,\n",
+        "        save_strategy=\"steps\",\n",
+        "        metric_for_best_model=\"loss\",\n",
+        "        greater_is_better=False,\n",
+        "        seed=SEED,\n",
+        "    ),\n",
+        "}\n",
+        "\n",
+        "orig_model, model = finetune(config)"
+      ],
+      "metadata": {
+        "id": "WSGsa3Xtx04j"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "test_finetune(dataset, ds, orig_model, model, datasets_finetune_test, 3)"
+      ],
+      "metadata": {
+        "id": "2Z6kyEGqL7zN"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "dataset = \"roneneldan/TinyStories\"\n",
+        "ds = load_dataset(dataset, \"default\")\n",
+        "ds_1 = dataset.split('/')[1]\n",
+        "\n",
+        "config = {\n",
+        "    \"ds\": ds,\n",
+        "    \"dataset\": dataset,\n",
+        "    \"datasets_preprocess\": datasets_finetune,\n",
+        "    \"model\": MODEL,\n",
+        "    \"max_train_size\": 7000,\n",
+        "    \"training_args\": TrainingArguments(\n",
+        "        output_dir=f\"./results_{ds_1}\",\n",
+        "        report_to=\"none\",\n",
+        "        num_train_epochs=1,\n",
+        "        per_device_train_batch_size=4,\n",
+        "        warmup_steps=200,\n",
+        "        learning_rate=5e-5,\n",
+        "        weight_decay=0.01,\n",
+        "        logging_steps=200,\n",
+        "        save_strategy=\"steps\",\n",
+        "        metric_for_best_model=\"loss\",\n",
+        "        greater_is_better=False,\n",
+        "        seed=SEED,\n",
+        "    ),\n",
+        "}\n",
+        "\n",
+        "orig_model, model = finetune(generate_config(dataset))"
+      ],
+      "metadata": {
+        "id": "mOzvvJWP_PL1"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "test_finetune(dataset, ds, orig_model, model, datasets_finetune_test, 3)"
+      ],
+      "metadata": {
+        "id": "X6WryZ6p3xGm"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from google.colab import files\n",
+        "files.download('/content/results_TinyStories/TinyStories-checkpoint-1750.zip')"
+      ],
+      "metadata": {
+        "id": "LBJxFu5oVP29"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [],
+      "metadata": {
+        "id": "xpFlk05UW87Q"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from google.colab import files\n",
+        "files.download('/content/results_SVAMP/SVAMP-checkpoint-1750.zip')"
+      ],
+      "metadata": {
+        "id": "jxnCHVVDVO6j"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [],
+      "metadata": {
+        "id": "TygO0jjlVWG_"
+      },
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
frontend/.gitignore
ADDED
@@ -0,0 +1,24 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+node_modules
+dist
+dist-ssr
+*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
frontend/README.md
ADDED
@@ -0,0 +1,12 @@
+# React + Vite
+
+This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
+
+Currently, two official plugins are available:
+
+- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) for Fast Refresh
+- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
+
+## Expanding the ESLint configuration
+
+If you are developing a production application, we recommend using TypeScript with type-aware lint rules enabled. Check out the [TS template](https://github.com/vitejs/vite/tree/main/packages/create-vite/template-react-ts) for information on how to integrate TypeScript and [`typescript-eslint`](https://typescript-eslint.io) in your project.
frontend/eslint.config.js
ADDED
@@ -0,0 +1,29 @@
+import js from '@eslint/js'
+import globals from 'globals'
+import reactHooks from 'eslint-plugin-react-hooks'
+import reactRefresh from 'eslint-plugin-react-refresh'
+import { defineConfig, globalIgnores } from 'eslint/config'
+
+export default defineConfig([
+  globalIgnores(['dist']),
+  {
+    files: ['**/*.{js,jsx}'],
+    extends: [
+      js.configs.recommended,
+      reactHooks.configs['recommended-latest'],
+      reactRefresh.configs.vite,
+    ],
+    languageOptions: {
+      ecmaVersion: 2020,
+      globals: globals.browser,
+      parserOptions: {
+        ecmaVersion: 'latest',
+        ecmaFeatures: { jsx: true },
+        sourceType: 'module',
+      },
+    },
+    rules: {
+      'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
+    },
+  },
+])
frontend/index.html
ADDED
@@ -0,0 +1,13 @@
+<!doctype html>
+<html lang="en">
+  <head>
+    <meta charset="UTF-8" />
+    <link rel="icon" type="image/svg+xml" href="/vite.svg" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+    <title>Vite + React</title>
+  </head>
+  <body>
+    <div id="root"></div>
+    <script type="module" src="/src/main.jsx"></script>
+  </body>
+</html>
frontend/package-lock.json
ADDED
The diff for this file is too large to render.
frontend/package.json
ADDED
@@ -0,0 +1,31 @@
+{
+  "name": "frontend",
+  "private": true,
+  "version": "0.0.0",
+  "type": "module",
+  "scripts": {
+    "dev": "vite",
+    "build": "vite build",
+    "lint": "eslint .",
+    "preview": "vite preview"
+  },
+  "dependencies": {
+    "@tailwindcss/vite": "^4.1.13",
+    "react": "^19.1.1",
+    "react-dom": "^19.1.1"
+  },
+  "devDependencies": {
+    "@eslint/js": "^9.35.0",
+    "@types/react": "^19.1.13",
+    "@types/react-dom": "^19.1.9",
+    "@vitejs/plugin-react": "^5.0.2",
+    "autoprefixer": "^10.4.21",
+    "eslint": "^9.35.0",
+    "eslint-plugin-react-hooks": "^5.2.0",
+    "eslint-plugin-react-refresh": "^0.4.20",
+    "globals": "^16.4.0",
+    "postcss": "^8.5.6",
+    "tailwindcss": "^4.1.13",
+    "vite": "^7.1.6"
+  }
+}
frontend/public/vite.svg
ADDED
frontend/src/App.css
ADDED
@@ -0,0 +1,44 @@
+@import "tailwindcss";
+
+#root {
+  max-width: 1280px;
+  margin: 0 auto;
+  padding: 2rem;
+  text-align: center;
+}
+
+.logo {
+  height: 6em;
+  padding: 1.5em;
+  will-change: filter;
+  transition: filter 300ms;
+}
+.logo:hover {
+  filter: drop-shadow(0 0 2em #646cffaa);
+}
+.logo.react:hover {
+  filter: drop-shadow(0 0 2em #61dafbaa);
+}
+
+@keyframes logo-spin {
+  from {
+    transform: rotate(0deg);
+  }
+  to {
+    transform: rotate(360deg);
+  }
+}
+
+@media (prefers-reduced-motion: no-preference) {
+  a:nth-of-type(2) .logo {
+    animation: logo-spin infinite 20s linear;
+  }
+}
+
+.card {
+  padding: 2em;
+}
+
+.read-the-docs {
+  color: #888;
+}
frontend/src/App.jsx
ADDED
@@ -0,0 +1,35 @@
+import { useState } from 'react'
+import reactLogo from './assets/react.svg'
+import viteLogo from '/vite.svg'
+import './App.css'
+
+function App() {
+  const [count, setCount] = useState(0)
+
+  return (
+    <>
+      <div>
+        <a href="https://vite.dev" target="_blank">
+          <img src={viteLogo} className="logo" alt="Vite logo" />
+        </a>
+        <a href="https://react.dev" target="_blank">
+          <img src={reactLogo} className="logo react" alt="React logo" />
+        </a>
+      </div>
+      <h1>Vite + React</h1>
+      <div className="card">
+        <button onClick={() => setCount((count) => count + 1)}>
+          count is {count}
+        </button>
+        <p>
+          Edit <code>src/App.jsx</code> and save to test HMR
+        </p>
+      </div>
+      <p className="read-the-docs">
+        Click on the Vite and React logos to learn more
+      </p>
+    </>
+  )
+}
+
+export default App
frontend/src/assets/react.svg
ADDED
frontend/src/index.css
ADDED
@@ -0,0 +1,68 @@
+:root {
+  font-family: system-ui, Avenir, Helvetica, Arial, sans-serif;
+  line-height: 1.5;
+  font-weight: 400;
+
+  color-scheme: light dark;
+  color: rgba(255, 255, 255, 0.87);
+  background-color: #242424;
+
+  font-synthesis: none;
+  text-rendering: optimizeLegibility;
+  -webkit-font-smoothing: antialiased;
+  -moz-osx-font-smoothing: grayscale;
+}
+
+a {
+  font-weight: 500;
+  color: #646cff;
+  text-decoration: inherit;
+}
+a:hover {
+  color: #535bf2;
+}
+
+body {
+  margin: 0;
+  display: flex;
+  place-items: center;
+  min-width: 320px;
+  min-height: 100vh;
+}
+
+h1 {
+  font-size: 3.2em;
+  line-height: 1.1;
+}
+
+button {
+  border-radius: 8px;
+  border: 1px solid transparent;
+  padding: 0.6em 1.2em;
+  font-size: 1em;
+  font-weight: 500;
+  font-family: inherit;
+  background-color: #1a1a1a;
+  cursor: pointer;
+  transition: border-color 0.25s;
+}
+button:hover {
+  border-color: #646cff;
+}
+button:focus,
+button:focus-visible {
+  outline: 4px auto -webkit-focus-ring-color;
+}
+
+@media (prefers-color-scheme: light) {
+  :root {
+    color: #213547;
+    background-color: #ffffff;
+  }
+  a:hover {
+    color: #747bff;
+  }
+  button {
+    background-color: #f9f9f9;
+  }
+}
frontend/src/main.jsx
ADDED
@@ -0,0 +1,10 @@
+import { StrictMode } from 'react'
+import { createRoot } from 'react-dom/client'
+import './index.css'
+import App from './App.jsx'
+
+createRoot(document.getElementById('root')).render(
+  <StrictMode>
+    <App />
+  </StrictMode>,
+)
frontend/vite.config.js
ADDED
@@ -0,0 +1,8 @@
+import { defineConfig } from "vite";
+import react from "@vitejs/plugin-react";
+import tailwindcss from "@tailwindcss/vite";
+
+// https://vite.dev/config/
+export default defineConfig({
+  plugins: [react(), tailwindcss()],
+});
main.py
ADDED
@@ -0,0 +1,8 @@
+import time
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
+
+start = time.time()
+model = AutoModelForCausalLM.from_pretrained("tcmmichaelb139/gpt2-medium-tinystories")
+print(model)
+
+print("Loaded model in", time.time() - start)
pyproject.toml
ADDED
@@ -0,0 +1,22 @@
+[project]
+name = "evolutiontransformer"
+version = "0.1.0"
+description = "Simulating evolution among LLMs"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "accelerate>=1.10.1",
+    "celery>=5.5.3",
+    "datasets>=4.1.1",
+    "evaluate>=0.4.6",
+    "fastapi>=0.116.2",
+    "gradio>=5.46.0",
+    "gunicorn>=23.0.0",
+    "matplotlib>=3.10.6",
+    "numpy>=2.3.3",
+    "pytest>=8.4.2",
+    "redis>=6.4.0",
+    "torch>=2.8.0",
+    "transformers>=4.56.1",
+    "uvicorn[standard]>=0.35.0",
+]
tests/__init__.py
ADDED
File without changes
tests/test_api.py
ADDED
@@ -0,0 +1,310 @@
+import pytest
+from fastapi.testclient import TestClient
+import time
+import re
+
+from evolutiontransformer.api import app
+
+
+def get_final_answer(text: str) -> int | None:
+    numbers = re.findall(r"\d+", text)
+    return int(numbers[-1]) if numbers else None
+
+
+@pytest.fixture
+def client():
+    with TestClient(app) as c:
+        yield c
+
+
+def await_task_completion(client, task_id, timeout=60):
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        status_response = client.get(f"/tasks/{task_id}")
+
+        print(status_response.json())
+
+        if status_response.status_code == 500:
+            return {"error": status_response.json().get("detail", "Unknown error")}
+        assert status_response.status_code == 200
+        status_data = status_response.json()
+
+        if status_data["status"] == "SUCCESS":
+            return status_data["result"]
+
+        time.sleep(2)
+    else:
+        pytest.fail(
+            f"Task {task_id} did not complete within the {timeout}-second timeout."
+        )
+
+    return None
+
+
+def test_generate_endpoint_svamp(client):
+    """
+    Tests inference on svamp
+    """
+    response = client.post(
+        "/generate",
+        json={
+            "model_name": "svamp",
+            "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
+            "max_new_tokens": 50,
+            "temperature": 0.7,
+        },
+    )
+
+    assert response.status_code == 200
+    data = response.json()
+
+    assert "task_id" in data
+    task_id = data["task_id"]
+
+    final_result = await_task_completion(client, task_id)
+
+    assert "response" in final_result
+    output_text = final_result["response"]
+
+    answer = get_final_answer(output_text)
+    assert answer == 14
+
+
+def test_merge_then_inference_svamp_1(client):
+    """
+    Tests merging then inference for svamp dataset
+    """
+
+    merge_response = client.post(
+        "/merge",
+        json={
+            "model1_name": "svamp",
+            "model2_name": "tinystories",
+            "layer_recipe": [[(i, 0, 1.0)] for i in range(24)],
+            "embedding_lambdas": [1.0, 1.0],
+            "linear_lambdas": [1.0, 1.0],
+            "merged_name": "svamp_merged",
+        },
+    )
+
+    assert merge_response.status_code == 200
+    merge_data = merge_response.json()
+    assert "task_id" in merge_data
+    merge_task_id = merge_data["task_id"]
+
+    merge_status_data = await_task_completion(client, merge_task_id)
+    model_name = merge_status_data["response"]
+
+    time.sleep(5)
+
+    generate_response = client.post(
+        "/generate",
+        json={
+            "model_name": model_name,
+            "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
+            "max_new_tokens": 50,
+            "temperature": 0.7,
+        },
+    )
+
+    assert generate_response.status_code == 200
+    generate_data = generate_response.json()
+    assert "task_id" in generate_data
+    generate_task_id = generate_data["task_id"]
+
+    final_result = await_task_completion(client, generate_task_id)
+
+    assert "response" in final_result
+    output_text = final_result["response"]
+    answer = get_final_answer(output_text)
+
+    assert answer == 14
+
+
+def test_merge_then_inference_svamp_2(client):
+    """
+    Tests merging then inference for svamp dataset
+    """
+
+    merge_response = client.post(
+        "/merge",
+        json={
+            "model1_name": "svamp",
+            "model2_name": "tinystories",
+            "layer_recipe": [[(i % 24, 0, 1.0 if i < 24 else 0.5)] for i in range(48)],
+            "embedding_lambdas": [1.0, 1.0],
+            "linear_lambdas": [1.0, 1.0],
+            "merged_name": "svamp_merged",
+        },
+    )
+
+    assert merge_response.status_code == 200
+    merge_data = merge_response.json()
+    assert "task_id" in merge_data
+    merge_task_id = merge_data["task_id"]
+
+    merge_status_data = await_task_completion(client, merge_task_id)
+
+    model_name = merge_status_data["response"]
+
+    merge_response2 = client.post(
+        "/merge",
+        json={
+            "model1_name": model_name,
+            "model2_name": "tinystories",
+            "layer_recipe": [[(i, 1, 0.25)] for i in range(24)],
+            "embedding_lambdas": [0.0, 0.0],
+            "linear_lambdas": [0.0, 0.0],
+            "merged_name": "svamp_merged",
+        },
+    )
+
+    assert merge_response2.status_code == 200
+    merge_data2 = merge_response2.json()
+    assert "task_id" in merge_data2
+    merge_task_id2 = merge_data2["task_id"]
+    merge_status_data2 = await_task_completion(client, merge_task_id2)
+    model_name2 = merge_status_data2["response"]
+
+    time.sleep(5)
+
+    generate_response = client.post(
+        "/generate",
+        json={
+            "model_name": model_name2,
+            "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
+            "max_new_tokens": 50,
+            "temperature": 0.7,
+        },
+    )
+
+    assert generate_response.status_code == 200
+    generate_data = generate_response.json()
+    assert "task_id" in generate_data
+    generate_task_id = generate_data["task_id"]
+
+    final_result = await_task_completion(client, generate_task_id)
+
+    assert "response" in final_result
+    output_text = final_result["response"]
+    answer = get_final_answer(output_text)
+
+    assert answer == 14
+
+
+def test_merge_two_children_then_merge(client):
+    """
+    Tests creating two children and merging them
+    """
+
+    merge_response1 = client.post(
+        "/merge",
+        json={
+            "model1_name": "svamp",
+            "model2_name": "tinystories",
+            "layer_recipe": [[(i, 0, 0.8)] for i in range(12)]
+            + [[(i, 1, 0.6)] for i in range(12)],
+            "embedding_lambdas": [0.7, 0.3],
+            "linear_lambdas": [0.8, 0.2],
+            "merged_name": "child1",
+        },
+    )
+
+    assert merge_response1.status_code == 200
+    merge_data1 = merge_response1.json()
+    assert "task_id" in merge_data1
+    merge_task_id1 = merge_data1["task_id"]
+    merge_status_data1 = await_task_completion(client, merge_task_id1)
+    child1_name = merge_status_data1["response"]
+
+    merge_response2 = client.post(
+        "/merge",
+        json={
+            "model1_name": "svamp",
+            "model2_name": "tinystories",
+            "layer_recipe": [[(i, 1, 0.9)] for i in range(8)]
+            + [[(i, 0, 0.4)] for i in range(16)],
+            "embedding_lambdas": [0.2, 0.9],
+            "linear_lambdas": [0.3, 0.7],
+            "merged_name": "child2",
+        },
+    )
+
+    assert merge_response2.status_code == 200
+    merge_data2 = merge_response2.json()
+    assert "task_id" in merge_data2
+    merge_task_id2 = merge_data2["task_id"]
+    merge_status_data2 = await_task_completion(client, merge_task_id2)
+    child2_name = merge_status_data2["response"]
+
+    merge_response3 = client.post(
+        "/merge",
+        json={
+            "model1_name": child1_name,
+            "model2_name": child2_name,
+            "layer_recipe": [[(i, 0, 0.6), (i, 1, 0.4)] for i in range(24)],
+            "embedding_lambdas": [0.5, 0.5],
+            "linear_lambdas": [0.6, 0.4],
+            "merged_name": "final_merged",
+        },
+    )
+
+    assert merge_response3.status_code == 200
+    merge_data3 = merge_response3.json()
+    assert "task_id" in merge_data3
+    merge_task_id3 = merge_data3["task_id"]
+    merge_status_data3 = await_task_completion(client, merge_task_id3)
+    final_model_name = merge_status_data3["response"]
+
+    time.sleep(5)
+
+    generate_response = client.post(
+        "/generate",
+        json={
+            "model_name": final_model_name,
+            "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
+            "max_new_tokens": 50,
+            "temperature": 0.7,
+        },
+    )
+
+    assert generate_response.status_code == 200
+    generate_data = generate_response.json()
+    assert "task_id" in generate_data
+    generate_task_id = generate_data["task_id"]
+
+    final_result = await_task_completion(client, generate_task_id)
+
+    assert "response" in final_result
+    output_text = final_result["response"]
+    answer = get_final_answer(output_text)
+
+    assert answer == 14
+
+
+def test_merge_fail(client):
+    """
+    Tests merging with too many layers
+    """
+
+    merge_response = client.post(
+        "/merge",
+        json={
+            "model1_name": "svamp",
+            "model2_name": "tinystories",
+            "layer_recipe": [[(i, 0, 1.0)] for i in range(50)],
+            "embedding_lambdas": [1.0, 1.0],
+            "linear_lambdas": [1.0, 1.0],
+            "merged_name": "svamp_merged",
+        },
+    )
+
+    assert merge_response.status_code == 200
+    merge_data = merge_response.json()
+    assert "task_id" in merge_data
+    merge_task_id = merge_data["task_id"]
+
+    merge_status_data = await_task_completion(client, merge_task_id)
+    assert "response" not in merge_status_data
+    assert "error" in merge_status_data
+    assert "Layer recipe too long" in merge_status_data["error"]
tests/test_model_actions.py
ADDED
@@ -0,0 +1,83 @@
+import torch
+from transformers import AutoModelForCausalLM
+import re
+
+
+from evolutiontransformer.worker import (
+    load_base_models_if_needed,
+    BASE_MODELS,
+    inference,
+    inference_task,
+    merge_models,
+)
+
+
+def get_final_answer(text: str) -> int | None:
+    numbers = re.findall(r"\d+", text)
+    return int(numbers[-1]) if numbers else None
+
+
+def test_inference():
+    session_id = "test_session"
+
+    print("### Testing inference on SVAMP model...")
+    prompt = "If there are 3 cars and 2 bikes, how many vehicles are there in total?\nAnswer:"
+    output = inference_task(session_id, "svamp", prompt)
+    assert get_final_answer(output["response"]) == 5
+
+
+def test_merge_models():
+    load_base_models_if_needed()
+
+    model_recipe = {
+        "layer_recipe": [[(i, "svamp", 1.0)] for i in range(24)],
+        "embedding_lambdas": [1.0, 1.0],
+        "linear_lambdas": [1.0, 1.0],
+    }
+
+    merged_model = merge_models(model_recipe)
+
+    for (name1, param1), (name2, param2) in zip(
+        BASE_MODELS["svamp"].named_parameters(), merged_model.named_parameters()
+    ):
+        assert torch.allclose(param1, param2)
+
+
+def test_merge_models_with_inference1():
+    load_base_models_if_needed()
+
+    model_recipe = {
+        "layer_recipe": [
+            [(i % 24, "svamp", 1.0 if i < 24 else 0.5)] for i in range(48)
+        ],
+        "embedding_lambdas": [1.0, 1.0],
+        "linear_lambdas": [1.0, 1.0],
+    }
+
+    merged_model = merge_models(model_recipe)
+
+    print(
+        inference(
+            merged_model,
+            "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
+        )
+    )
+
+
+def test_merge_models_with_inference2():
+    load_base_models_if_needed()
+
+    model_recipe = {
+        "layer_recipe": [[(i, "tinystories", 1.0)] for i in range(24)],
+        "embedding_lambdas": [0.0, 0.0],
+        "linear_lambdas": [0.0, 0.0],
+    }
+
+    merged_model = merge_models(model_recipe)
+
+    print(
+        inference(
+            merged_model,
+            "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
+        )
+    )
uv.lock
ADDED
The diff for this file is too large to render.