Keeby-smilyai commited on
Commit
ba941c2
Β·
verified Β·
1 Parent(s): 547c56f

Update backend.py

Browse files
Files changed (1) hide show
  1. backend.py +90 -21
backend.py CHANGED
@@ -1,4 +1,4 @@
1
- # backend.py β€” REAL VERSION
2
  import sqlite3
3
  import threading
4
  import time
@@ -53,17 +53,96 @@ def init_db():
53
 
54
  init_db()
55
 
56
- # ... [KEEP ALL DB HELPER FUNCTIONS: get_user_by_token, create_user, etc. β€” NO CHANGES] ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # ------------------------------ AUTH ------------------------------
59
- # ... [KEEP verify_hf_token β€” NO CHANGES] ...
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  # ------------------------------ TRAINING QUEUE ------------------------------
62
 
63
  def ram_available():
64
- """Check if we can start a new run (1.5GB per run)"""
65
- total_ram = psutil.virtual_memory().total / (1024**3) # GB
66
- used_ram = psutil.virtual_memory().used / (1024**3) # GB
67
  available_gb = total_ram - used_ram
68
  return available_gb >= MAX_RAM_PER_RUN_GB
69
 
@@ -94,7 +173,6 @@ def start_training_if_free():
94
  thread = threading.Thread(target=run_training_job, args=(job,))
95
  thread.start()
96
 
97
- # Start 48h timeout killer
98
  timer = threading.Timer(RUN_TIMEOUT, kill_run_timeout, args=[job["run_id"]])
99
  timer.start()
100
 
@@ -107,7 +185,7 @@ def kill_run_timeout(run_id):
107
  log_update(f"Run {run_id}: πŸ’₯ 48-HOUR TIMEOUT REACHED. Terminating.", run_id)
108
  update_run_status(run_id, "timeout")
109
  active_run_id = None
110
- start_training_if_free() # try next
111
 
112
  # ------------------------------ CUSTOM MODELS FROM SCRATCH ------------------------------
113
 
@@ -125,8 +203,8 @@ class CNNLanguageModel(nn.Module):
125
  self.fc = nn.Linear(in_ch, vocab_size)
126
 
127
  def forward(self, x, labels=None):
128
- x = self.embedding(x).transpose(1, 2) # (B, E, L)
129
- x = self.convs(x).transpose(1, 2) # (B, L, E*2^N)
130
  logits = self.fc(x)
131
  loss = None
132
  if labels is not None:
@@ -184,7 +262,7 @@ def get_model(arch_type, vocab_size, num_layers):
184
  class TextDataset(Dataset):
185
  def __init__(self, tokenized_data):
186
  self.input_ids = tokenized_data["input_ids"]
187
- self.labels = tokenized_data["input_ids"] # causal LM
188
 
189
  def __len__(self):
190
  return len(self.input_ids)
@@ -204,18 +282,14 @@ def run_training_job(job):
204
  device = "cuda" if torch.cuda.is_available() else "cpu"
205
  log_update(f"Run {run_id}: πŸš€ Device = {device} | RAM available: {psutil.virtual_memory().available / (1024**3):.2f} GB", run_id)
206
 
207
- # Load tokenizer (shared for all models)
208
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
209
  if tokenizer.pad_token is None:
210
  tokenizer.pad_token = tokenizer.eos_token
211
 
212
  vocab_size = len(tokenizer)
213
-
214
- # Build model from scratch
215
  model = get_model(job["arch_type"], vocab_size, job["num_layers"]).to(device)
216
  log_update(f"Run {run_id}: 🧱 Model initialized: {job['arch_type']} x{job['num_layers']} layers", run_id)
217
 
218
- # Load dataset β€” full training set (or 100K for speed)
219
  dataset = load_dataset("voidful/reasoning_gemini_300k", split="train[:100000]")
220
  def tokenize_function(examples):
221
  texts = [q + " " + a for q, a in zip(examples["message"], examples["answer"])]
@@ -224,11 +298,8 @@ def run_training_job(job):
224
  train_dataset = TextDataset(tokenized_dataset)
225
 
226
  train_loader = DataLoader(train_dataset, batch_size=job["batch_size"], shuffle=True)
227
-
228
- # Optimizer
229
  optimizer = torch.optim.AdamW(model.parameters(), lr=job["learning_rate"])
230
 
231
- # Training loop
232
  model.train()
233
  log_update(f"Run {run_id}: ▢️ Starting training for {job['epochs']} epochs...", run_id)
234
 
@@ -252,7 +323,6 @@ def run_training_job(job):
252
  avg_loss = total_loss / len(train_loader)
253
  log_update(f"Run {run_id}: βœ… Epoch {epoch+1} completed | Avg Loss: {avg_loss:.4f}", run_id)
254
 
255
- # Save model
256
  model_path = f"./runs/{run_id}"
257
  os.makedirs(model_path, exist_ok=True)
258
  torch.save(model.state_dict(), f"{model_path}/model.pth")
@@ -271,11 +341,10 @@ def run_training_job(job):
271
  def log_update(message, run_id):
272
  timestamp = time.strftime("%H:%M:%S")
273
  full_msg = f"[{timestamp}] {message}"
274
- print(full_msg) # Also shows in HF Spaces logs
275
  if run_id > 0:
276
  update_run_status(run_id, "running", full_msg)
277
 
278
-
279
  # ------------------------------ PUBLIC API ------------------------------
280
 
281
  __all__ = [
 
1
+ # backend.py β€” REAL, FULL, WORKING VERSION
2
  import sqlite3
3
  import threading
4
  import time
 
53
 
54
  init_db()
55
 
56
+ def get_user_by_token(hf_token):
57
+ conn = sqlite3.connect(DB_PATH)
58
+ cursor = conn.cursor()
59
+ cursor.execute("SELECT id FROM users WHERE hf_token = ?", (hf_token,))
60
+ row = cursor.fetchone()
61
+ conn.close()
62
+ return row[0] if row else None
63
+
64
+ def create_user(hf_token):
65
+ conn = sqlite3.connect(DB_PATH)
66
+ cursor = conn.cursor()
67
+ cursor.execute("INSERT INTO users (hf_token) VALUES (?)", (hf_token,))
68
+ user_id = cursor.lastrowid
69
+ conn.commit()
70
+ conn.close()
71
+ return user_id
72
+
73
+ def create_training_run(user_id, config):
74
+ conn = sqlite3.connect(DB_PATH)
75
+ cursor = conn.cursor()
76
+ cursor.execute("""
77
+ INSERT INTO training_runs
78
+ (user_id, arch_type, num_layers, learning_rate, epochs, batch_size)
79
+ VALUES (?, ?, ?, ?, ?, ?)
80
+ """, (
81
+ user_id,
82
+ config['arch_type'],
83
+ config['num_layers'],
84
+ config['learning_rate'],
85
+ config['epochs'],
86
+ config['batch_size']
87
+ ))
88
+ run_id = cursor.lastrowid
89
+ conn.commit()
90
+ conn.close()
91
+ return run_id
92
+
93
+ def get_user_runs(user_id):
94
+ conn = sqlite3.connect(DB_PATH)
95
+ cursor = conn.cursor()
96
+ cursor.execute("""
97
+ SELECT id, arch_type, num_layers, status, started_at
98
+ FROM training_runs
99
+ WHERE user_id = ?
100
+ ORDER BY started_at DESC
101
+ """, (user_id,))
102
+ runs = cursor.fetchall()
103
+ conn.close()
104
+ return runs
105
+
106
+ def get_run_logs(run_id):
107
+ conn = sqlite3.connect(DB_PATH)
108
+ cursor = conn.cursor()
109
+ cursor.execute("SELECT logs, status FROM training_runs WHERE id = ?", (run_id,))
110
+ row = cursor.fetchone()
111
+ conn.close()
112
+ return row if row else ("", "unknown")
113
+
114
+ def update_run_status(run_id, status, logs=""):
115
+ conn = sqlite3.connect(DB_PATH)
116
+ cursor = conn.cursor()
117
+ if status == 'running':
118
+ cursor.execute("UPDATE training_runs SET status = ?, started_at = CURRENT_TIMESTAMP WHERE id = ?", (status, run_id))
119
+ elif status in ['completed', 'failed', 'timeout']:
120
+ cursor.execute("UPDATE training_runs SET status = ?, completed_at = CURRENT_TIMESTAMP WHERE id = ?", (status, run_id))
121
+ if logs:
122
+ current_logs = get_run_logs(run_id)[0]
123
+ cursor.execute("UPDATE training_runs SET logs = ? WHERE id = ?", (current_logs + "\n" + logs, run_id))
124
+ conn.commit()
125
+ conn.close()
126
 
127
  # ------------------------------ AUTH ------------------------------
128
+
129
+ def verify_hf_token(token):
130
+ try:
131
+ whoami(token=token)
132
+ user_id = get_user_by_token(token)
133
+ if not user_id:
134
+ user_id = create_user(token)
135
+ return user_id, "Welcome to the LLM Kitchen, Chef! 🍳 Your apron is ready."
136
+ else:
137
+ return user_id, "Welcome back, Chef! πŸ‘¨β€πŸ³ Your last dish is still warm."
138
+ except Exception as e:
139
+ return None, f"Invalid token. Please try again. ({str(e)})"
140
 
141
  # ------------------------------ TRAINING QUEUE ------------------------------
142
 
143
  def ram_available():
144
+ total_ram = psutil.virtual_memory().total / (1024**3)
145
+ used_ram = psutil.virtual_memory().used / (1024**3)
 
146
  available_gb = total_ram - used_ram
147
  return available_gb >= MAX_RAM_PER_RUN_GB
148
 
 
173
  thread = threading.Thread(target=run_training_job, args=(job,))
174
  thread.start()
175
 
 
176
  timer = threading.Timer(RUN_TIMEOUT, kill_run_timeout, args=[job["run_id"]])
177
  timer.start()
178
 
 
185
  log_update(f"Run {run_id}: πŸ’₯ 48-HOUR TIMEOUT REACHED. Terminating.", run_id)
186
  update_run_status(run_id, "timeout")
187
  active_run_id = None
188
+ start_training_if_free()
189
 
190
  # ------------------------------ CUSTOM MODELS FROM SCRATCH ------------------------------
191
 
 
203
  self.fc = nn.Linear(in_ch, vocab_size)
204
 
205
  def forward(self, x, labels=None):
206
+ x = self.embedding(x).transpose(1, 2)
207
+ x = self.convs(x).transpose(1, 2)
208
  logits = self.fc(x)
209
  loss = None
210
  if labels is not None:
 
262
  class TextDataset(Dataset):
263
  def __init__(self, tokenized_data):
264
  self.input_ids = tokenized_data["input_ids"]
265
+ self.labels = tokenized_data["input_ids"]
266
 
267
  def __len__(self):
268
  return len(self.input_ids)
 
282
  device = "cuda" if torch.cuda.is_available() else "cpu"
283
  log_update(f"Run {run_id}: πŸš€ Device = {device} | RAM available: {psutil.virtual_memory().available / (1024**3):.2f} GB", run_id)
284
 
 
285
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
286
  if tokenizer.pad_token is None:
287
  tokenizer.pad_token = tokenizer.eos_token
288
 
289
  vocab_size = len(tokenizer)
 
 
290
  model = get_model(job["arch_type"], vocab_size, job["num_layers"]).to(device)
291
  log_update(f"Run {run_id}: 🧱 Model initialized: {job['arch_type']} x{job['num_layers']} layers", run_id)
292
 
 
293
  dataset = load_dataset("voidful/reasoning_gemini_300k", split="train[:100000]")
294
  def tokenize_function(examples):
295
  texts = [q + " " + a for q, a in zip(examples["message"], examples["answer"])]
 
298
  train_dataset = TextDataset(tokenized_dataset)
299
 
300
  train_loader = DataLoader(train_dataset, batch_size=job["batch_size"], shuffle=True)
 
 
301
  optimizer = torch.optim.AdamW(model.parameters(), lr=job["learning_rate"])
302
 
 
303
  model.train()
304
  log_update(f"Run {run_id}: ▢️ Starting training for {job['epochs']} epochs...", run_id)
305
 
 
323
  avg_loss = total_loss / len(train_loader)
324
  log_update(f"Run {run_id}: βœ… Epoch {epoch+1} completed | Avg Loss: {avg_loss:.4f}", run_id)
325
 
 
326
  model_path = f"./runs/{run_id}"
327
  os.makedirs(model_path, exist_ok=True)
328
  torch.save(model.state_dict(), f"{model_path}/model.pth")
 
341
  def log_update(message, run_id):
342
  timestamp = time.strftime("%H:%M:%S")
343
  full_msg = f"[{timestamp}] {message}"
344
+ print(full_msg)
345
  if run_id > 0:
346
  update_run_status(run_id, "running", full_msg)
347
 
 
348
  # ------------------------------ PUBLIC API ------------------------------
349
 
350
  __all__ = [