tcmmichaelb139 commited on
Commit
66f1733
·
0 Parent(s):
.DS_Store ADDED
Binary file (8.2 kB). View file
 
.env.example ADDED
@@ -0,0 +1 @@
 
 
1
+ REDIS_URL=""
.gitignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ .env
10
+ .venv
11
+
12
+ /models/
13
+
14
+ frontend/node_modules
15
+ frontend/build
16
+
17
+ .DS_Store
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
Dockerfile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12-slim
2
+ WORKDIR /code
3
+
4
+ RUN pip install uv
5
+
6
+ COPY pyproject.toml uv.lock ./
7
+
8
+ RUN uv export --no-dev | uv pip install --system -r -
9
+
10
+ COPY evolutiontransformer/ ./evolutiontransformer/
README.md ADDED
File without changes
docker-compose.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ api:
3
+ build: .
4
+ command: >
5
+ sh -c "gunicorn evolutiontransformer.api:app -w 4 -k uvicorn.workers.UvicornWorker -b 0.0.0.0:8000"
6
+ ports:
7
+ - "8000:8000"
8
+ depends_on:
9
+ - worker
10
+ env_file:
11
+ - .env
12
+
13
+ worker:
14
+ build: .
15
+ command: >
16
+ sh -c "/usr/local/bin/celery -A evolutiontransformer.worker.celery_app worker --loglevel=info -c 1"
17
+ env_file:
18
+ - .env
evolutiontransformer/.DS_Store ADDED
Binary file (6.15 kB). View file
 
evolutiontransformer/__init__.py ADDED
File without changes
evolutiontransformer/api.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Must be set before transformers/tokenizers are imported anywhere in the process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import uuid
from typing import List, Tuple

from celery import Celery
from dotenv import load_dotenv
from fastapi import FastAPI, Depends, HTTPException, Request, Response
from pydantic import BaseModel

load_dotenv()

# Redis doubles as the Celery broker and the result backend.
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")

celery_app = Celery("tasks", broker=REDIS_URL, backend=REDIS_URL)

app = FastAPI()
21
+
22
+
23
class GenerateRequest(BaseModel):
    """Payload for POST /generate: run text generation on a named model."""

    model_name: str
    prompt: str
    # Generation knobs; a temperature of 0 disables sampling downstream.
    max_new_tokens: int = 512
    temperature: float = 0.7
28
+
29
+
30
class MergeRequest(BaseModel):
    """Payload for POST /merge: combine two session models into a new recipe."""

    model1_name: str
    model2_name: str
    # Per output layer: list of (parent_layer_idx, parent_model_idx, alpha) terms.
    layer_recipe: List[List[Tuple[int, int, float]]]
    # Interpolation weights for embeddings (wte/wpe) and lm_head/ln_f.
    embedding_lambdas: List[float] = [0.5, 0.5]
    linear_lambdas: List[float] = [0.5, 0.5]
    merged_name: str = "merged"
37
+
38
+
39
def get_session_id(request: Request, response: Response):
    """FastAPI dependency: return the caller's session id, minting one on first visit."""
    if session_id := request.cookies.get("session_id"):
        return session_id

    session_id = str(uuid.uuid4())
    response.set_cookie(key="session_id", value=session_id)
    return session_id
47
+
48
+
49
+ @app.post("/generate")
50
+ def generate(request: GenerateRequest, session_id: str = Depends(get_session_id)):
51
+ task = celery_app.send_task(
52
+ "tasks.inference",
53
+ args=[
54
+ session_id,
55
+ request.model_name,
56
+ request.prompt,
57
+ request.max_new_tokens,
58
+ request.temperature,
59
+ ],
60
+ )
61
+ return {"task_id": task.id}
62
+
63
+
64
+ @app.post("/merge")
65
+ def merge(request: MergeRequest, session_id: str = Depends(get_session_id)):
66
+ task = celery_app.send_task(
67
+ "tasks.merge_models",
68
+ args=[
69
+ session_id,
70
+ request.model1_name,
71
+ request.model2_name,
72
+ request.layer_recipe,
73
+ request.embedding_lambdas,
74
+ request.linear_lambdas,
75
+ request.merged_name,
76
+ ],
77
+ )
78
+
79
+ return {"task_id": task.id}
80
+
81
+
82
+ @app.post("/list_models")
83
+ def list_models(session_id: str = Depends(get_session_id)):
84
+ task = celery_app.send_task("tasks.get_all_models", args=[session_id])
85
+ return {"task_id": task.id}
86
+
87
+
88
+ @app.get("/tasks/{task_id}")
89
+ def get_task_status(task_id: str):
90
+ task_result = celery_app.AsyncResult(task_id)
91
+
92
+ if task_result.ready():
93
+ if task_result.status == "FAILURE":
94
+ raise HTTPException(status_code=500, detail=str(task_result.result))
95
+ else:
96
+ return {"status": task_result.status, "result": task_result.result}
97
+ else:
98
+ return {"status": task_result.status}
evolutiontransformer/redis.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import os

from redis import Redis

REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
# Shared module-level client; decode_responses=True so reads come back as str, not bytes.
redis_client = Redis.from_url(REDIS_URL, decode_responses=True)
8
+
9
+
10
def add_model_to_session(session_id: str, model_name: str, ttl_seconds: int = 3600):
    """Record a model name in the session's set and refresh the set's TTL."""
    key = f"session:{session_id}:models"
    redis_client.sadd(key, model_name)
    redis_client.expire(key, ttl_seconds)
14
+
15
+
16
def get_session_models(session_id: str):
    """Return the model names registered for a session (set semantics; order unspecified)."""
    return list(redis_client.smembers(f"session:{session_id}:models"))
20
+
21
+
22
def save_model_recipe(
    session_id: str, model_name: str, recipe: dict, ttl_seconds: int = 3600
):
    """Persist a merge recipe as JSON under a per-session key with a TTL."""
    recipe_key = f"model:{session_id}:{model_name}"
    redis_client.setex(recipe_key, ttl_seconds, json.dumps(recipe))
28
+
29
+
30
def get_model_recipe(session_id: str, model_name: str):
    """Load a stored recipe, or None when the key is missing or expired."""
    raw = redis_client.get(f"model:{session_id}:{model_name}")
    return json.loads(raw) if raw else None
36
+
37
+
38
def delete_session(session_id: str):
    """Drop every recipe belonging to the session, then the session set itself."""
    for model_name in get_session_models(session_id):
        redis_client.delete(f"model:{session_id}:{model_name}")

    redis_client.delete(f"session:{session_id}:models")
evolutiontransformer/worker.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Must be set before transformers/tokenizers load anywhere in this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from typing import List, Tuple

import torch
import torch.nn as nn
from celery import Celery
from celery.exceptions import InvalidTaskError
from dotenv import load_dotenv
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

from evolutiontransformer.redis import (
    add_model_to_session,
    get_session_models,
    save_model_recipe,
    get_model_recipe,
)

load_dotenv()

# Fine-tuned GPT-2 variants that every merge recipe is ultimately expressed over.
BASE_MODELS_NAMES = ["svamp", "tinystories"]
BASE_MODELS = {}  # name -> loaded model, populated lazily per worker process
TOKENIZER = None  # lazy singleton, see get_tokenizer()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")

celery_app = Celery("tasks", broker=REDIS_URL, backend=REDIS_URL)
32
+
33
+
34
def load_base_models_if_needed():
    """Lazily pull the base checkpoints into BASE_MODELS and seed their identity recipes."""
    global BASE_MODELS
    if BASE_MODELS:
        return

    print("WORKER: Loading base models into memory...")
    for model_name in BASE_MODELS_NAMES:
        repo = f"tcmmichaelb139/gpt2-medium-{model_name}"
        BASE_MODELS[model_name] = AutoModelForCausalLM.from_pretrained(repo).to(DEVICE)

        # Seed the identity recipe once under the shared "default" session.
        if get_model_recipe("default", model_name) is None:
            add_model_to_session("default", model_name)
            identity_recipe = {
                # 24 entries; each output layer maps 1:1 onto this base model's layer.
                "layer_recipe": [[(i, model_name, 1.0)] for i in range(24)],
                "embedding_lambdas": [1.0, 1.0],
                "linear_lambdas": [1.0, 1.0],
            }
            save_model_recipe("default", model_name, identity_recipe)

    print("WORKER: Base models loaded.")
56
+
57
+
58
def get_tokenizer():
    """Return the shared GPT-2 tokenizer, creating it on first use."""
    global TOKENIZER
    if TOKENIZER is None:
        print("WORKER: Initializing Tokenizer...")
        TOKENIZER = AutoTokenizer.from_pretrained("gpt2-medium")
    return TOKENIZER
64
+
65
+
66
def inference(model, prompt, max_new_tokens=512, temperature=0.7):
    """Generate a completion for `prompt` with `model` and return the decoded text.

    temperature <= 0 switches to greedy decoding; in that case `temperature`
    is not forwarded, avoiding the transformers warning about sampling
    arguments being ignored when do_sample=False.
    """
    do_sample = temperature > 0
    model = model.to(DEVICE)
    model.eval()
    tokenizer = get_tokenizer()
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    generate_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        # GPT-2 has no pad token; set it explicitly to silence the per-call warning.
        "pad_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        generate_kwargs["temperature"] = temperature

    with torch.no_grad():
        # Note: outputs are already on DEVICE; the old trailing .to(DEVICE) was a no-op.
        outputs = model.generate(**inputs, **generate_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
82
+
83
+
84
def merge_model_recipe(
    model1_recipe: dict,
    model2_recipe: dict,
    layer_recipe: List[List[Tuple[int, int, float]]],
    embedding_lambdas: Tuple[float, float] = (0.5, 0.5),
    linear_lambdas: Tuple[float, float] = (0.5, 0.5),
) -> dict:
    """Compose two merge recipes into one expressed over the base models.

    Each entry of `layer_recipe` describes one output layer as a weighted sum
    of parent layers: a term (idx, model_i, alpha) means "alpha times layer
    `idx` of parent `model_i`" (0 -> model1, 1 -> model2).  Parent layers are
    themselves recipes over the base checkpoints, so terms are expanded and
    like terms (same base layer index, same base model) are accumulated.

    The embedding/linear lambdas of the parents are interpolated element-wise
    as lam * model1 + (1 - lam) * model2.

    Note: the defaults are tuples (not the mutable lists used previously) so
    no shared default object can ever be mutated across calls.
    """
    parents = [model1_recipe, model2_recipe]

    result_layer_recipe = []
    for makeup in layer_recipe:
        accumulated = {}
        for idx, model_i, alpha in makeup:
            # Expand this parent layer into its base-model terms.
            for orig_idx, orig_model, orig_alpha in parents[model_i]["layer_recipe"][idx]:
                key = (orig_idx, orig_model)
                accumulated[key] = accumulated.get(key, 0.0) + alpha * orig_alpha
        result_layer_recipe.append(
            [(base_idx, base_model, weight) for (base_idx, base_model), weight in accumulated.items()]
        )

    def _blend(lambdas, field):
        # Element-wise convex-style interpolation of the parents' scalar weights.
        return [
            lam * model1_recipe[field][i] + (1 - lam) * model2_recipe[field][i]
            for i, lam in enumerate(lambdas)
        ]

    return {
        "layer_recipe": result_layer_recipe,
        "embedding_lambdas": _blend(embedding_lambdas, "embedding_lambdas"),
        "linear_lambdas": _blend(linear_lambdas, "linear_lambdas"),
    }
128
+
129
+
130
def merge_models(
    model_recipe: dict,
    base_model="gpt2-medium",
) -> nn.Module:
    """Materialize a merged model from a recipe over the two base checkpoints."""

    model1_name = "svamp"
    model2_name = "tinystories"

    load_base_models_if_needed()

    def layer_weights(layer_idx, model):
        # State dict of one transformer block of a base model.
        return model.transformer.h[layer_idx].state_dict()

    def merge_layer(recipe: List[Tuple[int, str, float]]):
        # Start from a rescaled copy of the first term, then add the rest in.
        first_idx, first_model, first_alpha = recipe[0]
        merged = layer_weights(first_idx, BASE_MODELS[first_model])
        for key in merged.keys():
            merged[key] = first_alpha * merged[key]
        for layer_idx, model_name, alpha in recipe[1:]:
            contribution = layer_weights(layer_idx, BASE_MODELS[model_name])
            for key in merged.keys():
                merged[key] += alpha * contribution[key]
        return merged

    print("### Merging models... ###")

    layer_recipe = model_recipe["layer_recipe"]
    embedding_lambdas = model_recipe["embedding_lambdas"]
    linear_lambdas = model_recipe["linear_lambdas"]

    # Fresh (randomly initialized) model with exactly as many layers as the recipe.
    config = AutoConfig.from_pretrained(base_model)
    config.n_layer = len(layer_recipe)

    child_model = AutoModelForCausalLM.from_config(config).to(DEVICE)
    child_model.eval()

    m1 = BASE_MODELS[model1_name]
    m2 = BASE_MODELS[model2_name]

    print("Merging embeddings and lm_head...")
    child_model.transformer.wte.weight.data = (
        embedding_lambdas[0] * m1.transformer.wte.weight.data
        + (1 - embedding_lambdas[0]) * m2.transformer.wte.weight.data
    )
    child_model.transformer.wpe.weight.data = (
        embedding_lambdas[1] * m1.transformer.wpe.weight.data
        + (1 - embedding_lambdas[1]) * m2.transformer.wpe.weight.data
    )
    child_model.lm_head.weight.data = (
        linear_lambdas[0] * m1.lm_head.weight.data
        + (1 - linear_lambdas[0]) * m2.lm_head.weight.data
    )
    child_model.transformer.ln_f.weight.data = (
        linear_lambdas[1] * m1.transformer.ln_f.weight.data
        + (1 - linear_lambdas[1]) * m2.transformer.ln_f.weight.data
    )
    child_model.transformer.ln_f.bias.data = (
        linear_lambdas[1] * m1.transformer.ln_f.bias.data
        + (1 - linear_lambdas[1]) * m2.transformer.ln_f.bias.data
    )

    for i, layer in tqdm(enumerate(layer_recipe), desc="Merging layers..."):
        child_model.transformer.h[i].load_state_dict(merge_layer(layer))

    return child_model
196
+
197
+
198
def get_model_recipe_default(session_id: str, model_name: str) -> dict:
    """Fetch a recipe; base-model names are looked up in the shared 'default' session."""
    owner = "default" if model_name in BASE_MODELS_NAMES else session_id
    return get_model_recipe(owner, model_name)
202
+
203
+
204
+ @celery_app.task(name="tasks.inference")
205
+ def inference_task(
206
+ session_id: str, model_name, prompt, max_new_tokens=512, temperature=0.7
207
+ ):
208
+ try:
209
+ model_recipe = get_model_recipe_default(session_id, model_name)
210
+ print("WORKER: Creating merged model...")
211
+ model = merge_models(model_recipe)
212
+ print("WORKER: Model loaded.")
213
+ output = inference(model, prompt, max_new_tokens, temperature)
214
+ return {"response": output}
215
+ except Exception as e:
216
+ raise InvalidTaskError(f"Inference failed: {e}")
217
+
218
+
219
+ @celery_app.task(name="tasks.merge_models")
220
+ def merge_models_task(
221
+ session_id: str,
222
+ model1_name: str,
223
+ model2_name: str,
224
+ layer_recipe: List[List[Tuple[int, int, float]]],
225
+ embedding_lambdas: List[float] = [0.5, 0.5],
226
+ linear_lambdas: List[float] = [0.5, 0.5],
227
+ merged_name: str = "merged",
228
+ ):
229
+ if len(layer_recipe) > 48:
230
+ raise InvalidTaskError("Layer recipe too long. Max 48 layers supported.")
231
+
232
+ session_models = get_session_models(session_id)
233
+
234
+ model1_recipe = get_model_recipe_default(session_id, model1_name)
235
+ model2_recipe = get_model_recipe_default(session_id, model2_name)
236
+ if model1_recipe is None or model2_recipe is None:
237
+ raise InvalidTaskError("One of the models does not exist.")
238
+
239
+ merged_recipe = merge_model_recipe(
240
+ model1_recipe,
241
+ model2_recipe,
242
+ layer_recipe,
243
+ embedding_lambdas,
244
+ linear_lambdas,
245
+ )
246
+
247
+ for i in range(20):
248
+ full_merged_name = f"{merged_name}_{i}"
249
+ if full_merged_name not in session_models:
250
+ add_model_to_session(session_id, full_merged_name)
251
+ save_model_recipe(session_id, full_merged_name, merged_recipe)
252
+ return {"response": full_merged_name}
253
+
254
+ raise InvalidTaskError("Could not find a unique model name.")
255
+
256
+
257
+ @celery_app.task(name="tasks.get_all_models")
258
+ def get_all_models_task(session_id: str) -> List[str]:
259
+ global SESSION_MODELS
260
+ return {
261
+ "response": list((BASE_MODELS | SESSION_MODELS[session_id]).keys()),
262
+ }
263
+
264
+
265
+ @celery_app.task(name="tasks.clear_session_models")
266
+ def clear_session_models_task(session_id: str) -> str:
267
+ global SESSION_MODELS
268
+ if session_id in SESSION_MODELS:
269
+ del SESSION_MODELS[session_id]
270
+ del SESSION_MODELS[session_id]
271
+ return {"response": "SUCCESS"}
finetuning/finetuning.ipynb ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "source": [
22
+ "%pip install evaluate"
23
+ ],
24
+ "metadata": {
25
+ "id": "aqcbe-No3r2r"
26
+ },
27
+ "execution_count": null,
28
+ "outputs": []
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "metadata": {
34
+ "id": "lOwXY3N4tmbr"
35
+ },
36
+ "outputs": [],
37
+ "source": [
38
+ "import numpy as np\n",
39
+ "import matplotlib\n",
40
+ "import torch\n",
41
+ "import torch.nn as nn\n",
42
+ "from torch.utils.data import Dataset, DataLoader\n",
43
+ "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments\n",
44
+ "from datasets import load_dataset\n",
45
+ "import evaluate\n",
46
+ "from copy import deepcopy\n",
47
+ "\n",
48
+ "SEED=42\n",
49
+ "MODEL=\"gpt2-medium\"\n",
50
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "source": [
56
+ "def tokenize(x, tokenizer):\n",
57
+ " output = tokenizer(x[\"text\"], padding=\"max_length\", truncation=True, max_length=512)\n",
58
+ " output[\"label\"] = output[\"input_ids\"].copy()\n",
59
+ " return output\n",
60
+ "\n",
61
+ "def gen_tokenizer(model_name):\n",
62
+ " tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
63
+ " tokenizer.pad_token = tokenizer.eos_token\n",
64
+ " return tokenizer\n",
65
+ "\n",
66
+ "def finetune(config):\n",
67
+ " ds = config[\"ds\"]\n",
68
+ " preprocess_function = config[\"datasets_preprocess\"][config[\"dataset\"]]\n",
69
+ " tokenizer = gen_tokenizer(config[\"model\"])\n",
70
+ "\n",
71
+ " train_dataset = ds[\"train\"].select(range(config[\"max_train_size\"])).map(\n",
72
+ " lambda x: preprocess_function(x, tokenizer),\n",
73
+ " )\n",
74
+ "\n",
75
+ "\n",
76
+ " train_dataset = train_dataset.map(lambda x: tokenize(x, tokenizer), batched=True)\n",
77
+ "\n",
78
+ " model = AutoModelForCausalLM.from_pretrained(config[\"model\"])\n",
79
+ " orig_model = deepcopy(model)\n",
80
+ "\n",
81
+ " trainer = Trainer(\n",
82
+ " model=model,\n",
83
+ " args=config[\"training_args\"],\n",
84
+ " train_dataset=train_dataset,\n",
85
+ " processing_class=tokenizer,\n",
86
+ " )\n",
87
+ "\n",
88
+ " print(\"Starting training\")\n",
89
+ " trainer.train()\n",
90
+ " print(\"Training complete\")\n",
91
+ "\n",
92
+ " return orig_model, model"
93
+ ],
94
+ "metadata": {
95
+ "id": "B3XugMEV5vZF"
96
+ },
97
+ "execution_count": null,
98
+ "outputs": []
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "source": [
103
+ "def gsm8k_preprocess(x, tokenizer):\n",
104
+ " return {\"text\": f\"Question: {x['question']}\\nAnswer: {x['answer']}\" + tokenizer.eos_token}\n",
105
+ "\n",
106
+ "def svamp_preprocess(x, tokenizer):\n",
107
+ " return {\"text\": f\"{x['question_concat']}\\nAnswer: {x['Answer']}\" + tokenizer.eos_token}\n",
108
+ "\n",
109
+ "def tinystories_preprocess(x, tokenizer):\n",
110
+ " return {\"text\": x[\"text\"] + tokenizer.eos_token}\n",
111
+ "\n",
112
+ "datasets_finetune = {\n",
113
+ " \"openai/gsm8k\": gsm8k_preprocess,\n",
114
+ " \"ChilleD/SVAMP\": svamp_preprocess,\n",
115
+ " \"roneneldan/TinyStories\": tinystories_preprocess\n",
116
+ "}\n",
117
+ "\n",
118
+ "def preprocess_test_gsm8k(x):\n",
119
+ " return {\"text\": f\"Question: {x['question']}\\nAnswer:\" }\n",
120
+ "\n",
121
+ "def preprocess_test_svamp(x):\n",
122
+ " return {\"text\": f\"{x['question_concat']}\\nAnswer:\"}\n",
123
+ "\n",
124
+ "def preprocess_test_tinystories(x):\n",
125
+ " return {\"text\": x[\"text\"]}\n",
126
+ "\n",
127
+ "datasets_finetune_test = {\n",
128
+ " \"openai/gsm8k\": preprocess_test_gsm8k,\n",
129
+ " \"ChilleD/SVAMP\": preprocess_test_svamp,\n",
130
+ " \"roneneldan/TinyStories\": preprocess_test_tinystories\n",
131
+ "}"
132
+ ],
133
+ "metadata": {
134
+ "id": "Y09qs3FFxwx1"
135
+ },
136
+ "execution_count": null,
137
+ "outputs": []
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "source": [
142
+ "def test_finetune(dataset, ds, orig_model, model, datasets_preprocess, first_x):\n",
143
+ " tokenizer = gen_tokenizer(MODEL)\n",
144
+ " preprocess_function = datasets_preprocess[dataset]\n",
145
+ " if \"validation\" in ds:\n",
146
+ " ds[\"test\"] = deepcopy(ds[\"validation\"])\n",
147
+ "\n",
148
+ " test_dataset = ds[\"test\"].map(\n",
149
+ " lambda x: preprocess_function(x),\n",
150
+ " )\n",
151
+ "\n",
152
+ " model = model.to(device)\n",
153
+ " orig_model = orig_model.to(device)\n",
154
+ "\n",
155
+ " model.eval()\n",
156
+ " orig_model.eval()\n",
157
+ " xi = 0\n",
158
+ " with torch.no_grad():\n",
159
+ " for x in test_dataset:\n",
160
+ " input_tensor = tokenizer(x[\"text\"], return_tensors=\"pt\")\n",
161
+ " input_tensor[\"input_ids\"] = input_tensor[\"input_ids\"].to(device)\n",
162
+ " input_tensor[\"attention_mask\"] = input_tensor[\"attention_mask\"].to(device)\n",
163
+ "\n",
164
+ " output = orig_model.generate(**input_tensor, max_new_tokens=512)\n",
165
+ "\n",
166
+ " print(\"Original model output\")\n",
167
+ " print(tokenizer.decode(output[0], skip_special_tokens=True))\n",
168
+ "\n",
169
+ " finetuned_output = model.generate(**input_tensor, max_new_tokens=512)\n",
170
+ "\n",
171
+ " print(\"Finetuned model output\")\n",
172
+ " print(tokenizer.decode(finetuned_output[0], skip_special_tokens=True))\n",
173
+ "\n",
174
+ " xi += 1\n",
175
+ " if xi > first_x:\n",
176
+ " break\n"
177
+ ],
178
+ "metadata": {
179
+ "id": "zQqD3dHWDj6H"
180
+ },
181
+ "execution_count": null,
182
+ "outputs": []
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "source": [
187
+ "def generate_config(dataset):\n",
188
+ " return config\n"
189
+ ],
190
+ "metadata": {
191
+ "id": "2hlh0GERDBdC"
192
+ },
193
+ "execution_count": null,
194
+ "outputs": []
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "source": [
199
+ "dataset = \"ChilleD/SVAMP\"\n",
200
+ "ds = load_dataset(dataset, \"default\")\n",
201
+ "ds_1 = dataset.split('/')[1]\n",
202
+ "\n",
203
+ "config = {\n",
204
+ " \"ds\": ds,\n",
205
+ " \"dataset\": dataset,\n",
206
+ " \"datasets_preprocess\": datasets_finetune,\n",
207
+ " \"model\": MODEL,\n",
208
+ " \"max_train_size\": 700,\n",
209
+ " \"training_args\": TrainingArguments(\n",
210
+ " output_dir=f\"./results_{ds_1}\",\n",
211
+ " report_to=\"none\",\n",
212
+ " num_train_epochs=10,\n",
213
+ " per_device_train_batch_size=4,\n",
214
+ " warmup_steps=200,\n",
215
+ " learning_rate=5e-5,\n",
216
+ " weight_decay=0.01,\n",
217
+ " logging_steps=200,\n",
218
+ " save_strategy=\"steps\",\n",
219
+ " metric_for_best_model=\"loss\",\n",
220
+ " greater_is_better=False,\n",
221
+ " seed=SEED,\n",
222
+ " ),\n",
223
+ "}\n",
224
+ "\n",
225
+ "orig_model, model = finetune(config)"
226
+ ],
227
+ "metadata": {
228
+ "id": "WSGsa3Xtx04j"
229
+ },
230
+ "execution_count": null,
231
+ "outputs": []
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "source": [
236
+ "test_finetune(dataset, ds, orig_model, model, datasets_finetune_test, 3)"
237
+ ],
238
+ "metadata": {
239
+ "id": "2Z6kyEGqL7zN"
240
+ },
241
+ "execution_count": null,
242
+ "outputs": []
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "source": [
247
+ "dataset = \"roneneldan/TinyStories\"\n",
248
+ "ds = load_dataset(dataset, \"default\")\n",
249
+ "ds_1 = dataset.split('/')[1]\n",
250
+ "\n",
251
+ "\n",
252
+ "config = {\n",
253
+ " \"ds\": ds,\n",
254
+ " \"dataset\": dataset,\n",
255
+ " \"datasets_preprocess\": datasets_finetune,\n",
256
+ " \"model\": MODEL,\n",
257
+ " \"max_train_size\": 7000,\n",
258
+ " \"training_args\": TrainingArguments(\n",
259
+ " output_dir=f\"./results_{ds_1}\",\n",
260
+ " report_to=\"none\",\n",
261
+ " num_train_epochs=1,\n",
262
+ " per_device_train_batch_size=4,\n",
263
+ " warmup_steps=200,\n",
264
+ " learning_rate=5e-5,\n",
265
+ " weight_decay=0.01,\n",
266
+ " logging_steps=200,\n",
267
+ " save_strategy=\"steps\",\n",
268
+ " metric_for_best_model=\"loss\",\n",
269
+ " greater_is_better=False,\n",
270
+ " seed=SEED,\n",
271
+ " ),\n",
272
+ "}\n",
273
+ "\n",
274
+ "orig_model, model = finetune(generate_config(dataset))"
275
+ ],
276
+ "metadata": {
277
+ "id": "mOzvvJWP_PL1"
278
+ },
279
+ "execution_count": null,
280
+ "outputs": []
281
+ },
282
+ {
283
+ "cell_type": "code",
284
+ "source": [
285
+ "test_finetune(dataset, ds, orig_model, model, datasets_finetune_test, 3)"
286
+ ],
287
+ "metadata": {
288
+ "id": "X6WryZ6p3xGm"
289
+ },
290
+ "execution_count": null,
291
+ "outputs": []
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "source": [
296
+ "from google.colab import files\n",
297
+ "files.download('/content/results_TinyStories/TinyStories-checkpoint-1750.zip')"
298
+ ],
299
+ "metadata": {
300
+ "id": "LBJxFu5oVP29"
301
+ },
302
+ "execution_count": null,
303
+ "outputs": []
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "source": [],
308
+ "metadata": {
309
+ "id": "xpFlk05UW87Q"
310
+ },
311
+ "execution_count": null,
312
+ "outputs": []
313
+ },
314
+ {
315
+ "cell_type": "code",
316
+ "source": [
317
+ "from google.colab import files\n",
318
+ "files.download('/content/results_SVAMP/SVAMP-checkpoint-1750.zip')"
319
+ ],
320
+ "metadata": {
321
+ "id": "jxnCHVVDVO6j"
322
+ },
323
+ "execution_count": null,
324
+ "outputs": []
325
+ },
326
+ {
327
+ "cell_type": "code",
328
+ "source": [],
329
+ "metadata": {
330
+ "id": "TygO0jjlVWG_"
331
+ },
332
+ "execution_count": null,
333
+ "outputs": []
334
+ }
335
+ ]
336
+ }
frontend/.gitignore ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Logs
2
+ logs
3
+ *.log
4
+ npm-debug.log*
5
+ yarn-debug.log*
6
+ yarn-error.log*
7
+ pnpm-debug.log*
8
+ lerna-debug.log*
9
+
10
+ node_modules
11
+ dist
12
+ dist-ssr
13
+ *.local
14
+
15
+ # Editor directories and files
16
+ .vscode/*
17
+ !.vscode/extensions.json
18
+ .idea
19
+ .DS_Store
20
+ *.suo
21
+ *.ntvs*
22
+ *.njsproj
23
+ *.sln
24
+ *.sw?
frontend/README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # React + Vite
2
+
3
+ This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
4
+
5
+ Currently, two official plugins are available:
6
+
7
+ - [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) for Fast Refresh
8
+ - [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
9
+
10
+ ## Expanding the ESLint configuration
11
+
12
+ If you are developing a production application, we recommend using TypeScript with type-aware lint rules enabled. Check out the [TS template](https://github.com/vitejs/vite/tree/main/packages/create-vite/template-react-ts) for information on how to integrate TypeScript and [`typescript-eslint`](https://typescript-eslint.io) in your project.
frontend/eslint.config.js ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import js from '@eslint/js'
2
+ import globals from 'globals'
3
+ import reactHooks from 'eslint-plugin-react-hooks'
4
+ import reactRefresh from 'eslint-plugin-react-refresh'
5
+ import { defineConfig, globalIgnores } from 'eslint/config'
6
+
7
+ export default defineConfig([
8
+ globalIgnores(['dist']),
9
+ {
10
+ files: ['**/*.{js,jsx}'],
11
+ extends: [
12
+ js.configs.recommended,
13
+ reactHooks.configs['recommended-latest'],
14
+ reactRefresh.configs.vite,
15
+ ],
16
+ languageOptions: {
17
+ ecmaVersion: 2020,
18
+ globals: globals.browser,
19
+ parserOptions: {
20
+ ecmaVersion: 'latest',
21
+ ecmaFeatures: { jsx: true },
22
+ sourceType: 'module',
23
+ },
24
+ },
25
+ rules: {
26
+ 'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
27
+ },
28
+ },
29
+ ])
frontend/index.html ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>Vite + React</title>
8
+ </head>
9
+ <body>
10
+ <div id="root"></div>
11
+ <script type="module" src="/src/main.jsx"></script>
12
+ </body>
13
+ </html>
frontend/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
frontend/package.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "frontend",
3
+ "private": true,
4
+ "version": "0.0.0",
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "vite",
8
+ "build": "vite build",
9
+ "lint": "eslint .",
10
+ "preview": "vite preview"
11
+ },
12
+ "dependencies": {
13
+ "@tailwindcss/vite": "^4.1.13",
14
+ "react": "^19.1.1",
15
+ "react-dom": "^19.1.1"
16
+ },
17
+ "devDependencies": {
18
+ "@eslint/js": "^9.35.0",
19
+ "@types/react": "^19.1.13",
20
+ "@types/react-dom": "^19.1.9",
21
+ "@vitejs/plugin-react": "^5.0.2",
22
+ "autoprefixer": "^10.4.21",
23
+ "eslint": "^9.35.0",
24
+ "eslint-plugin-react-hooks": "^5.2.0",
25
+ "eslint-plugin-react-refresh": "^0.4.20",
26
+ "globals": "^16.4.0",
27
+ "postcss": "^8.5.6",
28
+ "tailwindcss": "^4.1.13",
29
+ "vite": "^7.1.6"
30
+ }
31
+ }
frontend/public/vite.svg ADDED
frontend/src/App.css ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @import "tailwindcss";
2
+
3
+ #root {
4
+ max-width: 1280px;
5
+ margin: 0 auto;
6
+ padding: 2rem;
7
+ text-align: center;
8
+ }
9
+
10
+ .logo {
11
+ height: 6em;
12
+ padding: 1.5em;
13
+ will-change: filter;
14
+ transition: filter 300ms;
15
+ }
16
+ .logo:hover {
17
+ filter: drop-shadow(0 0 2em #646cffaa);
18
+ }
19
+ .logo.react:hover {
20
+ filter: drop-shadow(0 0 2em #61dafbaa);
21
+ }
22
+
23
+ @keyframes logo-spin {
24
+ from {
25
+ transform: rotate(0deg);
26
+ }
27
+ to {
28
+ transform: rotate(360deg);
29
+ }
30
+ }
31
+
32
+ @media (prefers-reduced-motion: no-preference) {
33
+ a:nth-of-type(2) .logo {
34
+ animation: logo-spin infinite 20s linear;
35
+ }
36
+ }
37
+
38
+ .card {
39
+ padding: 2em;
40
+ }
41
+
42
+ .read-the-docs {
43
+ color: #888;
44
+ }
frontend/src/App.jsx ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState } from 'react'
2
+ import reactLogo from './assets/react.svg'
3
+ import viteLogo from '/vite.svg'
4
+ import './App.css'
5
+
6
+ function App() {
7
+ const [count, setCount] = useState(0)
8
+
9
+ return (
10
+ <>
11
+ <div>
12
+ <a href="https://vite.dev" target="_blank">
13
+ <img src={viteLogo} className="logo" alt="Vite logo" />
14
+ </a>
15
+ <a href="https://react.dev" target="_blank">
16
+ <img src={reactLogo} className="logo react" alt="React logo" />
17
+ </a>
18
+ </div>
19
+ <h1>Vite + React</h1>
20
+ <div className="card">
21
+ <button onClick={() => setCount((count) => count + 1)}>
22
+ count is {count}
23
+ </button>
24
+ <p>
25
+ Edit <code>src/App.jsx</code> and save to test HMR
26
+ </p>
27
+ </div>
28
+ <p className="read-the-docs">
29
+ Click on the Vite and React logos to learn more
30
+ </p>
31
+ </>
32
+ )
33
+ }
34
+
35
+ export default App
frontend/src/assets/react.svg ADDED
frontend/src/index.css ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Global starter styles from the Vite/React template; dark scheme is the default. */

:root {
  font-family: system-ui, Avenir, Helvetica, Arial, sans-serif;
  line-height: 1.5;
  font-weight: 400;

  color-scheme: light dark;
  color: rgba(255, 255, 255, 0.87);
  background-color: #242424;

  font-synthesis: none;
  text-rendering: optimizeLegibility;
  -webkit-font-smoothing: antialiased;
  -moz-osx-font-smoothing: grayscale;
}

a {
  font-weight: 500;
  color: #646cff;
  text-decoration: inherit;
}
a:hover {
  color: #535bf2;
}

/* Center the app in the viewport. */
body {
  margin: 0;
  display: flex;
  place-items: center;
  min-width: 320px;
  min-height: 100vh;
}

h1 {
  font-size: 3.2em;
  line-height: 1.1;
}

button {
  border-radius: 8px;
  border: 1px solid transparent;
  padding: 0.6em 1.2em;
  font-size: 1em;
  font-weight: 500;
  font-family: inherit;
  background-color: #1a1a1a;
  cursor: pointer;
  transition: border-color 0.25s;
}
button:hover {
  border-color: #646cff;
}
button:focus,
button:focus-visible {
  outline: 4px auto -webkit-focus-ring-color;
}

/* Light-scheme overrides. */
@media (prefers-color-scheme: light) {
  :root {
    color: #213547;
    background-color: #ffffff;
  }
  a:hover {
    color: #747bff;
  }
  button {
    background-color: #f9f9f9;
  }
}
frontend/src/main.jsx ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
// Application entry point: mount <App /> under #root with StrictMode
// development checks enabled.
import { StrictMode } from 'react'
import { createRoot } from 'react-dom/client'
import './index.css'
import App from './App.jsx'

createRoot(document.getElementById('root')).render(
  <StrictMode>
    <App />
  </StrictMode>,
)
frontend/vite.config.js ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
import { defineConfig } from "vite";
import react from "@vitejs/plugin-react";
import tailwindcss from "@tailwindcss/vite";

// https://vite.dev/config/
// react() provides JSX transform + fast refresh; tailwindcss() backs the
// `@import "tailwindcss"` directive used in src/App.css.
export default defineConfig({
  plugins: [react(), tailwindcss()],
});
main.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
"""Smoke-test script: load the fine-tuned GPT-2 model and report load time."""

import time

from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

# perf_counter is monotonic, so the measurement cannot be skewed by a
# wall-clock adjustment happening mid-download (time.time can be).
start = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained("tcmmichaelb139/gpt2-medium-tinystories")
print(model)

print("Loaded model in", time.perf_counter() - start)
pyproject.toml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "evolutiontransformer"
3
+ version = "0.1.0"
4
+ description = "Simulating evolution among LLMs"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "accelerate>=1.10.1",
9
+ "celery>=5.5.3",
10
+ "datasets>=4.1.1",
11
+ "evaluate>=0.4.6",
+ "python-dotenv>=1.0.0",
12
+ "fastapi>=0.116.2",
13
+ "gradio>=5.46.0",
14
+ "gunicorn>=23.0.0",
15
+ "matplotlib>=3.10.6",
16
+ "numpy>=2.3.3",
17
+ "pytest>=8.4.2",
18
+ "redis>=6.4.0",
19
+ "torch>=2.8.0",
20
+ "transformers>=4.56.1",
21
+ "uvicorn[standard]>=0.35.0",
22
+ ]
tests/__init__.py ADDED
File without changes
tests/test_api.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from fastapi.testclient import TestClient
3
+ import time
4
+ import re
5
+
6
+ from evolutiontransformer.api import app
7
+
8
+
9
+ def get_final_answer(text: str) -> int | None:
10
+ numbers = re.findall(r"\d+", text)
11
+ return int(numbers[-1]) if numbers else None
12
+
13
+
14
@pytest.fixture
def client():
    """Yield a TestClient bound to the app; the context manager runs
    FastAPI startup/shutdown events around each test."""
    with TestClient(app) as test_client:
        yield test_client
18
+
19
+
20
def await_task_completion(client, task_id, timeout=60):
    """Poll ``/tasks/{task_id}`` every 2 s until the task finishes.

    Returns the task's ``result`` dict on SUCCESS, or ``{"error": ...}``
    if the status endpoint answers 500. Fails the test if the task does
    not finish within *timeout* seconds.

    Fixes vs. original: uses time.monotonic() (wall-clock jumps cannot
    break the timeout) and drops the unreachable ``return None`` that
    followed pytest.fail(), which raises.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status_response = client.get(f"/tasks/{task_id}")

        print(status_response.json())

        if status_response.status_code == 500:
            return {"error": status_response.json().get("detail", "Unknown error")}
        assert status_response.status_code == 200
        status_data = status_response.json()

        if status_data["status"] == "SUCCESS":
            return status_data["result"]

        time.sleep(2)

    pytest.fail(
        f"Task {task_id} did not complete within the {timeout}-second timeout."
    )
42
+
43
+
44
def test_generate_endpoint_svamp(client):
    """
    Tests inference on svamp
    """
    payload = {
        "model_name": "svamp",
        "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
        "max_new_tokens": 50,
        "temperature": 0.7,
    }
    response = client.post("/generate", json=payload)

    assert response.status_code == 200
    data = response.json()

    assert "task_id" in data

    # Generation runs asynchronously; poll until the worker reports SUCCESS.
    final_result = await_task_completion(client, data["task_id"])

    assert "response" in final_result
    output_text = final_result["response"]

    answer = get_final_answer(output_text)
    assert answer == 14
71
+
72
+
73
def test_merge_then_inference_svamp_1(client):
    """
    Tests merging then inference for svamp dataset

    The recipe takes all 24 layers from parent index 0 (svamp) at weight
    1.0, so the merged model should behave like svamp and still solve a
    simple arithmetic word problem.
    """

    merge_response = client.post(
        "/merge",
        json={
            "model1_name": "svamp",
            "model2_name": "tinystories",
            # presumably entries are (layer_index, parent_index, weight) —
            # TODO confirm against the worker's merge_models implementation
            "layer_recipe": [[(i, 0, 1.0)] for i in range(24)],
            "embedding_lambdas": [1.0, 1.0],
            "linear_lambdas": [1.0, 1.0],
            "merged_name": "svamp_merged",
        },
    )

    assert merge_response.status_code == 200
    merge_data = merge_response.json()
    assert "task_id" in merge_data
    merge_task_id = merge_data["task_id"]

    merge_status_data = await_task_completion(client, merge_task_id)
    # the merge task's result carries the registered name of the new model
    model_name = merge_status_data["response"]

    # give the worker a moment before querying the freshly merged model
    time.sleep(5)

    generate_response = client.post(
        "/generate",
        json={
            "model_name": model_name,
            "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
            "max_new_tokens": 50,
            "temperature": 0.7,
        },
    )

    assert generate_response.status_code == 200
    generate_data = generate_response.json()
    assert "task_id" in generate_data
    generate_task_id = generate_data["task_id"]

    final_result = await_task_completion(client, generate_task_id)

    assert "response" in final_result
    output_text = final_result["response"]
    answer = get_final_answer(output_text)

    assert answer == 14
122
+
123
+
124
def test_merge_then_inference_svamp_2(client):
    """
    Tests merging then inference for svamp dataset

    Chains two merges — the second reuses the first merge's output as a
    parent — then verifies the final model still answers a simple
    arithmetic question.
    """

    # Fix: local variable was misspelled `merge_repsonse`; renamed to match
    # the sibling tests in this module.
    merge_response = client.post(
        "/merge",
        json={
            "model1_name": "svamp",
            "model2_name": "tinystories",
            # 48 entries covering the 24 layers twice: weight 1.0 on the
            # first pass, 0.5 on the second.
            "layer_recipe": [[(i % 24, 0, 1.0 if i < 24 else 0.5)] for i in range(48)],
            "embedding_lambdas": [1.0, 1.0],
            "linear_lambdas": [1.0, 1.0],
            "merged_name": "svamp_merged",
        },
    )

    assert merge_response.status_code == 200
    merge_data = merge_response.json()
    assert "task_id" in merge_data
    merge_task_id = merge_data["task_id"]

    merge_status_data = await_task_completion(client, merge_task_id)

    model_name = merge_status_data["response"]

    # Second merge: pull a small contribution (0.25) from parent index 1.
    merge_response2 = client.post(
        "/merge",
        json={
            "model1_name": model_name,
            "model2_name": "tinystories",
            "layer_recipe": [[(i, 1, 0.25)] for i in range(24)],
            "embedding_lambdas": [0.0, 0.0],
            "linear_lambdas": [0.0, 0.0],
            "merged_name": "svamp_merged",
        },
    )

    assert merge_response2.status_code == 200
    merge_data2 = merge_response2.json()
    assert "task_id" in merge_data2
    merge_task_id2 = merge_data2["task_id"]
    merge_status_data2 = await_task_completion(client, merge_task_id2)
    model_name2 = merge_status_data2["response"]

    # give the worker a moment before querying the freshly merged model
    time.sleep(5)

    generate_response = client.post(
        "/generate",
        json={
            "model_name": model_name2,
            "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
            "max_new_tokens": 50,
            "temperature": 0.7,
        },
    )

    assert generate_response.status_code == 200
    generate_data = generate_response.json()
    assert "task_id" in generate_data
    generate_task_id = generate_data["task_id"]

    final_result = await_task_completion(client, generate_task_id)

    assert "response" in final_result
    output_text = final_result["response"]
    answer = get_final_answer(output_text)

    assert answer == 14
193
+
194
+
195
def test_merge_two_children_then_merge(client):
    """
    Tests creating two children and merging them

    Builds child1 and child2 from the two base models with different
    layer recipes, then merges the children into a final model and checks
    it still answers a simple arithmetic question.
    """

    # child1: first 12 layers weighted from parent 0, last 12 from parent 1
    merge_response1 = client.post(
        "/merge",
        json={
            "model1_name": "svamp",
            "model2_name": "tinystories",
            "layer_recipe": [[(i, 0, 0.8)] for i in range(12)]
            + [[(i, 1, 0.6)] for i in range(12)],
            "embedding_lambdas": [0.7, 0.3],
            "linear_lambdas": [0.8, 0.2],
            "merged_name": "child1",
        },
    )

    assert merge_response1.status_code == 200
    merge_data1 = merge_response1.json()
    assert "task_id" in merge_data1
    merge_task_id1 = merge_data1["task_id"]
    merge_status_data1 = await_task_completion(client, merge_task_id1)
    child1_name = merge_status_data1["response"]

    # child2: a different split — 8 layers from parent 1, then 16 from parent 0
    merge_response2 = client.post(
        "/merge",
        json={
            "model1_name": "svamp",
            "model2_name": "tinystories",
            "layer_recipe": [[(i, 1, 0.9)] for i in range(8)]
            + [[(i, 0, 0.4)] for i in range(16)],
            "embedding_lambdas": [0.2, 0.9],
            "linear_lambdas": [0.3, 0.7],
            "merged_name": "child2",
        },
    )

    assert merge_response2.status_code == 200
    merge_data2 = merge_response2.json()
    assert "task_id" in merge_data2
    merge_task_id2 = merge_data2["task_id"]
    merge_status_data2 = await_task_completion(client, merge_task_id2)
    child2_name = merge_status_data2["response"]

    # final: blend both children per layer (0.6 / 0.4)
    merge_response3 = client.post(
        "/merge",
        json={
            "model1_name": child1_name,
            "model2_name": child2_name,
            "layer_recipe": [[(i, 0, 0.6), (i, 1, 0.4)] for i in range(24)],
            "embedding_lambdas": [0.5, 0.5],
            "linear_lambdas": [0.6, 0.4],
            "merged_name": "final_merged",
        },
    )

    assert merge_response3.status_code == 200
    merge_data3 = merge_response3.json()
    assert "task_id" in merge_data3
    merge_task_id3 = merge_data3["task_id"]
    merge_status_data3 = await_task_completion(client, merge_task_id3)
    final_model_name = merge_status_data3["response"]

    # give the worker a moment before querying the freshly merged model
    time.sleep(5)

    generate_response = client.post(
        "/generate",
        json={
            "model_name": final_model_name,
            "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
            "max_new_tokens": 50,
            "temperature": 0.7,
        },
    )

    assert generate_response.status_code == 200
    generate_data = generate_response.json()
    assert "task_id" in generate_data
    generate_task_id = generate_data["task_id"]

    final_result = await_task_completion(client, generate_task_id)

    assert "response" in final_result
    output_text = final_result["response"]
    answer = get_final_answer(output_text)

    assert answer == 14
283
+
284
+
285
def test_merge_fail(client):
    """
    Tests merging with too many layers

    Submits a 50-entry layer recipe and expects the task result to carry
    a "Layer recipe too long" error instead of a merged-model name.
    """

    # Fix: local variable was misspelled `merge_repsonse`; renamed to match
    # the sibling tests in this module.
    merge_response = client.post(
        "/merge",
        json={
            "model1_name": "svamp",
            "model2_name": "tinystories",
            "layer_recipe": [[(i, 0, 1.0)] for i in range(50)],
            "embedding_lambdas": [1.0, 1.0],
            "linear_lambdas": [1.0, 1.0],
            "merged_name": "svamp_merged",
        },
    )

    # The request itself is accepted; the failure surfaces in the task result.
    assert merge_response.status_code == 200
    merge_data = merge_response.json()
    assert "task_id" in merge_data
    merge_task_id = merge_data["task_id"]

    merge_status_data = await_task_completion(client, merge_task_id)
    assert "response" not in merge_status_data
    assert "error" in merge_status_data
    assert "Layer recipe too long" in merge_status_data["error"]
tests/test_model_actions.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM
3
+ import re
4
+
5
+
6
+ from evolutiontransformer.worker import (
7
+ load_base_models_if_needed,
8
+ BASE_MODELS,
9
+ inference,
10
+ inference_task,
11
+ merge_models,
12
+ )
13
+
14
+
15
+ def get_final_answer(text: str) -> int | None:
16
+ numbers = re.findall(r"\d+", text)
17
+ return int(numbers[-1]) if numbers else None
18
+
19
+
20
def test_inference():
    """Run one inference on the base svamp model and check the numeric answer."""
    session_id = "test_session"

    print("### Testing inference on SVAMP model...")
    prompt = "If there are 3 cars and 2 bikes, how many vehicles are there in total?\nAnswer:"
    output = inference_task(session_id, "svamp", prompt)
    # the last number in the generated text must be the correct total
    assert get_final_answer(output["response"]) == 5
27
+
28
+
29
def test_merge_models():
    """An all-svamp recipe at weight 1.0 must reproduce svamp parameter-for-parameter."""
    load_base_models_if_needed()

    model_recipe = {
        # presumably entries are (layer_index, parent_name, weight) —
        # TODO confirm against merge_models in the worker module
        "layer_recipe": [[(i, "svamp", 1.0)] for i in range(24)],
        "embedding_lambdas": [1.0, 1.0],
        "linear_lambdas": [1.0, 1.0],
    }

    merged_model = merge_models(model_recipe)

    # every merged parameter must match the svamp original elementwise
    for (name1, param1), (name2, param2) in zip(
        BASE_MODELS["svamp"].named_parameters(), merged_model.named_parameters()
    ):
        assert torch.allclose(param1, param2)
44
+
45
+
46
def test_merge_models_with_inference1():
    """Merge with a double-pass 48-entry recipe, then run inference.

    NOTE(review): this test only prints the generation — it has no
    assertion, so it can only fail by raising.
    """
    load_base_models_if_needed()

    model_recipe = {
        # 48 entries cover the 24 layers twice: weight 1.0 first pass, 0.5 second
        "layer_recipe": [
            [(i % 24, "svamp", 1.0 if i < 24 else 0.5)] for i in range(48)
        ],
        "embedding_lambdas": [1.0, 1.0],
        "linear_lambdas": [1.0, 1.0],
    }

    merged_model = merge_models(model_recipe)

    print(
        inference(
            merged_model,
            "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
        )
    )
65
+
66
+
67
def test_merge_models_with_inference2():
    """Merge with an all-tinystories recipe and zeroed lambdas, then run inference.

    NOTE(review): like the sibling test, this only prints the generation —
    no assertion on the output.
    """
    load_base_models_if_needed()

    model_recipe = {
        "layer_recipe": [[(i, "tinystories", 1.0)] for i in range(24)],
        "embedding_lambdas": [0.0, 0.0],
        "linear_lambdas": [0.0, 0.0],
    }

    merged_model = merge_models(model_recipe)

    print(
        inference(
            merged_model,
            "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
        )
    )
uv.lock ADDED
The diff for this file is too large to render. See raw diff