Keeby-smilyai committed on
Commit
e0790fc
·
verified ·
1 Parent(s): 06078e5

Create backend.py

Browse files
Files changed (1) hide show
  1. backend.py +234 -0
backend.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # backend.py
2
+ import sqlite3
3
+ import threading
4
+ import time
5
+ import torch
6
+ from huggingface_hub import whoami
7
+ from datasets import load_dataset
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
9
+ import os
10
+
11
# Path of the SQLite file that stores users and training runs.
DB_PATH: str = "llm_kitchen.db"
# Pending jobs in submission order; entries are dicts appended by queue_training_run().
training_queue: list = []
# Guards active_run_id in start_training_if_free() and run_training_job().
active_run_lock = threading.Lock()
# run_id of the job that is currently training, or None when nothing is running.
active_run_id = None
15
+
16
+ # ------------------------------ DATABASE ------------------------------
17
+
18
def init_db():
    """Create the SQLite schema used by the backend if it is missing.

    Connects to ``DB_PATH`` and runs ``CREATE TABLE IF NOT EXISTS`` for the
    ``users`` and ``training_runs`` tables, so calling it repeatedly is safe.
    """
    # Deliberately no "file already exists" shortcut: the DDL below is
    # idempotent, and skipping it would leave an empty or half-initialised
    # database file without its tables forever.
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                hf_token TEXT UNIQUE NOT NULL,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP
            );
            CREATE TABLE IF NOT EXISTS training_runs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER NOT NULL,
                arch_type TEXT NOT NULL,
                num_layers INTEGER NOT NULL,
                learning_rate REAL NOT NULL,
                epochs INTEGER NOT NULL,
                batch_size INTEGER NOT NULL,
                status TEXT DEFAULT 'queued',
                logs TEXT DEFAULT '',
                started_at DATETIME,
                completed_at DATETIME,
                FOREIGN KEY (user_id) REFERENCES users(id)
            );
        """)
        conn.commit()
    finally:
        # Release the connection even when the DDL fails.
        conn.close()
46
+
47
# Runs at import time so the tables exist before any helper below is used.
init_db()
48
+
49
def get_user_by_token(hf_token):
    """Look up the ``users.id`` stored for *hf_token*.

    Returns:
        The integer user id, or ``None`` when no row carries that token.
    """
    connection = sqlite3.connect(DB_PATH)
    match = connection.execute(
        "SELECT id FROM users WHERE hf_token = ?", (hf_token,)
    ).fetchone()
    connection.close()
    if not match:
        return None
    return match[0]
56
+
57
def create_user(hf_token):
    """Insert a new ``users`` row for *hf_token* and return its id.

    Args:
        hf_token: Token string stored in the ``hf_token`` column.

    Returns:
        The ``lastrowid`` of the inserted row.

    Raises:
        sqlite3.IntegrityError: If the token violates the column's
            ``UNIQUE NOT NULL`` constraint (e.g. it is already stored).
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.execute("INSERT INTO users (hf_token) VALUES (?)", (hf_token,))
        user_id = cursor.lastrowid
        conn.commit()
        return user_id
    finally:
        # Close even when the INSERT raises, so a rejected duplicate does not
        # leave an open connection (and its pending transaction) behind.
        conn.close()
65
+
66
def create_training_run(user_id, config):
    """Insert a ``training_runs`` row for *user_id* and return its id.

    The row gets the table defaults for the remaining columns (status
    ``'queued'``, empty logs).

    Args:
        user_id: Value for the ``user_id`` column.
        config: Mapping that provides ``'arch_type'``, ``'num_layers'``,
            ``'learning_rate'``, ``'epochs'`` and ``'batch_size'``.

    Returns:
        The ``lastrowid`` of the inserted row.

    Raises:
        KeyError: If *config* lacks one of the required keys.
    """
    # Resolve the config before opening the database so a missing key fails
    # fast without leaving a connection open.
    params = (
        user_id,
        config['arch_type'],
        config['num_layers'],
        config['learning_rate'],
        config['epochs'],
        config['batch_size'],
    )
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.execute("""
            INSERT INTO training_runs
            (user_id, arch_type, num_layers, learning_rate, epochs, batch_size)
            VALUES (?, ?, ?, ?, ?, ?)
        """, params)
        run_id = cursor.lastrowid
        conn.commit()
        return run_id
    finally:
        # Always release the connection, including on constraint errors.
        conn.close()
85
+
86
def get_user_runs(user_id):
    """Return every training run that belongs to *user_id*.

    Returns:
        A list of ``(id, arch_type, num_layers, status, started_at)`` tuples
        ordered by ``started_at`` descending.
    """
    query = """
        SELECT id, arch_type, num_layers, status, started_at
        FROM training_runs
        WHERE user_id = ?
        ORDER BY started_at DESC
    """
    db = sqlite3.connect(DB_PATH)
    rows = db.execute(query, (user_id,)).fetchall()
    db.close()
    return rows
98
+
99
def get_run_logs(run_id):
    """Fetch the stored logs and status of a training run.

    Returns:
        The ``(logs, status)`` row for *run_id*, or ``("", "unknown")`` when
        no such run exists.
    """
    db = sqlite3.connect(DB_PATH)
    record = db.execute(
        "SELECT logs, status FROM training_runs WHERE id = ?", (run_id,)
    ).fetchone()
    db.close()
    if record:
        return record
    return ("", "unknown")
106
+
107
def update_run_status(run_id, status, logs=""):
    """Update a run's status/timestamps and optionally append to its logs.

    Args:
        run_id: Id of the ``training_runs`` row to update.
        status: ``'running'`` sets the status and stamps ``started_at``;
            ``'completed'``, ``'failed'`` or ``'timeout'`` set the status and
            stamp ``completed_at``. Any other value leaves the status column
            untouched.
        logs: Text appended to the stored logs on a new line; nothing is
            appended when it is empty.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        if status == 'running':
            # COALESCE keeps the first start time: log_update() calls this
            # with 'running' for every log line, which previously reset
            # started_at each time.
            conn.execute(
                "UPDATE training_runs SET status = ?, "
                "started_at = COALESCE(started_at, CURRENT_TIMESTAMP) WHERE id = ?",
                (status, run_id),
            )
        elif status in ('completed', 'failed', 'timeout'):
            conn.execute(
                "UPDATE training_runs SET status = ?, completed_at = CURRENT_TIMESTAMP WHERE id = ?",
                (status, run_id),
            )
        if logs:
            # Append inside SQL so the read-modify-write is a single statement
            # on this connection; COALESCE tolerates a NULL logs column.
            conn.execute(
                "UPDATE training_runs SET logs = COALESCE(logs, '') || ? WHERE id = ?",
                ("\n" + logs, run_id),
            )
        conn.commit()
    finally:
        # Release the connection even if one of the statements fails.
        conn.close()
119
+
120
+ # ------------------------------ AUTH ------------------------------
121
+
122
def verify_hf_token(token):
    """Validate a Hugging Face token and map it to a local user id.

    The token is checked with ``whoami``; on success the matching ``users``
    row is looked up and created when it does not exist yet.

    Args:
        token: Token string supplied by the user.

    Returns:
        ``(user_id, message)`` on success, ``(None, message)`` when the token
        is empty or any step raises.
    """
    # Reject blank input up front. NOTE(review): per the huggingface_hub docs a
    # missing token makes whoami() fall back to the locally saved credentials,
    # which would let an empty submission pass as the host's own account.
    if not isinstance(token, str) or not token.strip():
        return None, "Invalid token. Please try again. (empty token)"
    try:
        whoami(token=token)
        user_id = get_user_by_token(token)
        if user_id is None:
            user_id = create_user(token)
            return user_id, "Welcome to the LLM Kitchen, Chef! 🍳 Your apron is ready."
        return user_id, "Welcome back, Chef! 👨‍🍳 Your last dish is still warm."
    except Exception as e:
        # Boundary handler: callers expect a (None, message) tuple, not a raise.
        return None, f"Invalid token. Please try again. ({str(e)})"
133
+
134
+ # ------------------------------ TRAINING QUEUE ------------------------------
135
+
136
def queue_training_run(user_id, config):
    """Persist a new training run and append it to the in-memory queue.

    Args:
        user_id: Owner of the run.
        config: Mapping of run settings; it is passed to
            ``create_training_run`` and its items are copied into the queued
            job dict.

    Returns:
        The id of the newly created run.
    """
    run_id = create_training_run(user_id, config)
    # Spread config first so the ids assigned here always win, even if the
    # caller-supplied config happens to contain "run_id" or "user_id" keys.
    training_queue.append({
        **config,
        "run_id": run_id,
        "user_id": user_id,
    })
    return run_id
144
+
145
def ram_check_mock():
    """Mock capacity check: report room only while no run is active."""
    # Mock: Allow 1 run at a time, 1.5GB per run
    no_active_run = active_run_id is None
    return no_active_run
149
+
150
def start_training_if_free():
    """Start the next queued job when nothing is currently training.

    Under ``active_run_lock`` the oldest job is popped from
    ``training_queue``, recorded as the active run, marked as running in the
    database, and handed to a new thread running ``run_training_job``.

    Returns:
        ``True`` when a job was started; ``False`` when a run is already
        active, the queue is empty, or the capacity check fails.

    Raises:
        Exception: Whatever the status update or thread start raises, after
            the slot has been freed and the job put back in the queue.
    """
    global active_run_id
    with active_run_lock:
        if active_run_id is not None or not training_queue or not ram_check_mock():
            return False

        job = training_queue.pop(0)
        active_run_id = job["run_id"]
        try:
            update_run_status(active_run_id, "running", "🍳 Starting kitchen process...")
            worker = threading.Thread(target=run_training_job, args=(job,))
            worker.start()
        except Exception:
            # Without this rollback a failed start would leave active_run_id
            # set forever and silently drop the job, stalling the queue.
            active_run_id = None
            training_queue.insert(0, job)
            raise
    return True
167
+
168
def run_training_job(job):
    """Run one queued training job and record its outcome.

    Fine-tunes ``distilgpt2`` on a 1% slice of
    ``voidful/reasoning_gemini_300k`` with the epochs, batch size and
    learning rate taken from *job*. Progress goes through ``log_update``, and
    the run is marked ``completed`` or ``failed`` at the end. The active-run
    slot is always released and the next queued job is attempted.

    Args:
        job: Queue entry providing ``"run_id"``, ``"epochs"``,
            ``"batch_size"`` and ``"learning_rate"``.
    """
    global active_run_id
    run_id = job["run_id"]
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        log_update(f"Run {run_id}: Device = {device}", run_id)

        # Load tiny model for demo (replace with custom later)
        model_name = "distilgpt2"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            # Padding below needs a pad token; reuse EOS when none is set.
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        log_update(f"Run {run_id}: Model loaded", run_id)

        # Load dataset
        dataset = load_dataset("voidful/reasoning_gemini_300k", split="train[:1%]")  # Tiny slice for demo

        def tokenize_function(examples):
            # NOTE(review): assumes the dataset exposes "message" and "answer"
            # text columns - confirm against the dataset card.
            texts = [q + " " + a for q, a in zip(examples["message"], examples["answer"])]
            return tokenizer(texts, truncation=True, padding="max_length", max_length=128)

        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["message", "answer"])
        log_update(f"Run {run_id}: Dataset tokenized", run_id)

        # Hold out a small slice: Trainer.evaluate() needs an eval dataset,
        # and without one every finished run ended up marked as failed.
        splits = tokenized_dataset.train_test_split(test_size=0.05, seed=0)

        # Training args
        training_args = TrainingArguments(
            output_dir=f"./runs/{run_id}",
            num_train_epochs=job["epochs"],
            per_device_train_batch_size=job["batch_size"],
            learning_rate=job["learning_rate"],
            save_strategy="no",
            logging_steps=1,
            report_to="none",
            fp16=False,
            no_cuda=(device == "cpu"),
        )

        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=splits["train"],
            eval_dataset=splits["test"],
            data_collator=data_collator,
        )

        log_update(f"Run {run_id}: Starting training...", run_id)
        trainer.train()

        # Simulate 48h timeout with short sleep for demo
        # NOTE(review): real training already ran above; this only delays completion.
        time.sleep(10)  # Replace with real training

        eval_results = trainer.evaluate()
        log_update(f"Run {run_id}: Training complete. Loss = {eval_results.get('eval_loss', 'N/A')}", run_id)
        update_run_status(run_id, "completed")

    except Exception as e:
        # Thread boundary: record the failure instead of letting it vanish.
        log_update(f"Run {run_id}: FAILED - {str(e)}", run_id)
        update_run_status(run_id, "failed")
    finally:
        with active_run_lock:
            active_run_id = None
        # Try starting next queued job. This must happen after the lock is
        # released: start_training_if_free() takes active_run_lock itself and
        # threading.Lock is not reentrant.
        start_training_if_free()
231
+
232
def log_update(message, run_id):
    """Print *message* and append it to the run's logs with status "running"."""
    line = f"[LOG] {message}"
    print(line)  # Also print to Spaces logs
    update_run_status(run_id, "running", message)