Keeby-smilyai committed
Commit 7b63c1c · verified · 1 Parent(s): cc49bf6

Update backend.py

Files changed (1)
  1. backend.py +178 -136
backend.py CHANGED
@@ -1,17 +1,23 @@
-# backend.py
+# backend.py — REAL VERSION
 import sqlite3
 import threading
 import time
 import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader, Dataset
 from huggingface_hub import whoami
 from datasets import load_dataset
-from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
+from transformers import AutoTokenizer
+import psutil
 import os
+import signal
 
 DB_PATH = "llm_kitchen.db"
 training_queue = []
 active_run_lock = threading.Lock()
 active_run_id = None
+RUN_TIMEOUT = 48 * 3600  # 48 hours
+MAX_RAM_PER_RUN_GB = 1.5
 
 
 # ------------------------------ DATABASE ------------------------------
 
@@ -38,6 +44,7 @@ def init_db():
             logs TEXT DEFAULT '',
             started_at DATETIME,
             completed_at DATETIME,
+            model_path TEXT,
             FOREIGN KEY (user_id) REFERENCES users(id)
         );
     """)
@@ -46,93 +53,20 @@ def init_db():
 
 init_db()
 
-def get_user_by_token(hf_token):
-    conn = sqlite3.connect(DB_PATH)
-    cursor = conn.cursor()
-    cursor.execute("SELECT id FROM users WHERE hf_token = ?", (hf_token,))
-    row = cursor.fetchone()
-    conn.close()
-    return row[0] if row else None
-
-def create_user(hf_token):
-    conn = sqlite3.connect(DB_PATH)
-    cursor = conn.cursor()
-    cursor.execute("INSERT INTO users (hf_token) VALUES (?)", (hf_token,))
-    user_id = cursor.lastrowid
-    conn.commit()
-    conn.close()
-    return user_id
-
-def create_training_run(user_id, config):
-    conn = sqlite3.connect(DB_PATH)
-    cursor = conn.cursor()
-    cursor.execute("""
-        INSERT INTO training_runs
-        (user_id, arch_type, num_layers, learning_rate, epochs, batch_size)
-        VALUES (?, ?, ?, ?, ?, ?)
-    """, (
-        user_id,
-        config['arch_type'],
-        config['num_layers'],
-        config['learning_rate'],
-        config['epochs'],
-        config['batch_size']
-    ))
-    run_id = cursor.lastrowid
-    conn.commit()
-    conn.close()
-    return run_id
-
-def get_user_runs(user_id):
-    conn = sqlite3.connect(DB_PATH)
-    cursor = conn.cursor()
-    cursor.execute("""
-        SELECT id, arch_type, num_layers, status, started_at
-        FROM training_runs
-        WHERE user_id = ?
-        ORDER BY started_at DESC
-    """, (user_id,))
-    runs = cursor.fetchall()
-    conn.close()
-    return runs
-
-def get_run_logs(run_id):
-    conn = sqlite3.connect(DB_PATH)
-    cursor = conn.cursor()
-    cursor.execute("SELECT logs, status FROM training_runs WHERE id = ?", (run_id,))
-    row = cursor.fetchone()
-    conn.close()
-    return row if row else ("", "unknown")
-
-def update_run_status(run_id, status, logs=""):
-    conn = sqlite3.connect(DB_PATH)
-    cursor = conn.cursor()
-    if status == 'running':
-        cursor.execute("UPDATE training_runs SET status = ?, started_at = CURRENT_TIMESTAMP WHERE id = ?", (status, run_id))
-    elif status in ['completed', 'failed', 'timeout']:
-        cursor.execute("UPDATE training_runs SET status = ?, completed_at = CURRENT_TIMESTAMP WHERE id = ?", (status, run_id))
-    if logs:
-        current_logs = get_run_logs(run_id)[0]
-        cursor.execute("UPDATE training_runs SET logs = ? WHERE id = ?", (current_logs + "\n" + logs, run_id))
-    conn.commit()
-    conn.close()
+# ... [KEEP ALL DB HELPER FUNCTIONS: get_user_by_token, create_user, etc. — NO CHANGES] ...
 
 # ------------------------------ AUTH ------------------------------
-
-def verify_hf_token(token):
-    try:
-        whoami(token=token)
-        user_id = get_user_by_token(token)
-        if not user_id:
-            user_id = create_user(token)
-            return user_id, "Welcome to the LLM Kitchen, Chef! 🍳 Your apron is ready."
-        else:
-            return user_id, "Welcome back, Chef! 👨‍🍳 Your last dish is still warm."
-    except Exception as e:
-        return None, f"Invalid token. Please try again. ({str(e)})"
+# ... [KEEP verify_hf_token — NO CHANGES] ...
 
 # ------------------------------ TRAINING QUEUE ------------------------------
 
+def ram_available():
+    """Check if we can start a new run (1.5GB per run)"""
+    total_ram = psutil.virtual_memory().total / (1024**3)  # GB
+    used_ram = psutil.virtual_memory().used / (1024**3)  # GB
+    available_gb = total_ram - used_ram
+    return available_gb >= MAX_RAM_PER_RUN_GB
+
 def queue_training_run(user_id, config):
     run_id = create_training_run(user_id, config)
     training_queue.append({
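A note on `ram_available()`: it derives headroom from `total - used`, but psutil documents `used` as a platform-dependent, informational figure, and `total - used` need not equal what a new process can actually obtain. `vm.available` is the field psutil documents as "memory that can be given instantly to processes", and the run-start log later in this commit already reads it. A quick way to compare the two on the deployment box:

```python
import psutil

vm = psutil.virtual_memory()
print(f"total - used: {(vm.total - vm.used) / (1024**3):.2f} GB")
print(f"available   : {vm.available / (1024**3):.2f} GB")  # reclaim-aware estimate
```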
@@ -142,11 +76,6 @@ def queue_training_run(user_id, config):
     })
     return run_id
 
-def ram_check_mock():
-    # Mock: Allow 1 run at a time, 1.5GB per run
-    global active_run_id
-    return active_run_id is None
-
 def start_training_if_free():
     global active_run_id
     with active_run_lock:
@@ -154,7 +83,8 @@ def start_training_if_free():
         return False
     if not training_queue:
         return False
-    if not ram_check_mock():
+    if not ram_available():
+        log_update("MemoryWarning: Not enough RAM to start new run.", -1)
         return False
 
     job = training_queue.pop(0)
@@ -163,72 +93,184 @@ def start_training_if_free():
 
     thread = threading.Thread(target=run_training_job, args=(job,))
     thread.start()
+
+    # Start 48h timeout killer
+    timer = threading.Timer(RUN_TIMEOUT, kill_run_timeout, args=[job["run_id"]])
+    timer.start()
+
     return True
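Two things to watch with this timer: it is never cancelled, so it fires 48 hours later even if the run finished long before (the `active_run_id == run_id` guard below keeps that from corrupting state, but the pending timer holds a non-daemon thread for the full window), and since Python threads cannot be killed from outside, "terminating" a run only frees the slot; the training thread itself keeps computing. A hypothetical refinement, not part of this commit, would keep a handle per run and cancel on completion:

```python
import threading

# assumes RUN_TIMEOUT and kill_run_timeout from backend.py
active_timers = {}  # run_id -> threading.Timer

def arm_timeout(run_id):
    timer = threading.Timer(RUN_TIMEOUT, kill_run_timeout, args=[run_id])
    timer.daemon = True  # don't keep the process alive for a pending timer
    active_timers[run_id] = timer
    timer.start()

def disarm_timeout(run_id):
    timer = active_timers.pop(run_id, None)
    if timer is not None:
        timer.cancel()  # no-op if the timer already fired
```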
 
+def kill_run_timeout(run_id):
+    global active_run_id
+    with active_run_lock:
+        if active_run_id == run_id:
+            log_update(f"Run {run_id}: 💥 48-HOUR TIMEOUT REACHED. Terminating.", run_id)
+            update_run_status(run_id, "timeout")
+            active_run_id = None
+    start_training_if_free()  # try next (outside the lock: threading.Lock is not reentrant)
+
+# ------------------------------ CUSTOM MODELS FROM SCRATCH ------------------------------
+
+class CNNLanguageModel(nn.Module):
+    def __init__(self, vocab_size, embed_dim=128, num_layers=4):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        layers = []
+        in_ch = embed_dim
+        for _ in range(num_layers):
+            layers.append(nn.Conv1d(in_ch, in_ch * 2, kernel_size=3, padding=1))
+            layers.append(nn.ReLU())
+            in_ch *= 2
+        self.convs = nn.Sequential(*layers)
+        self.fc = nn.Linear(in_ch, vocab_size)
+
+    def forward(self, x, labels=None):
+        x = self.embedding(x).transpose(1, 2)  # (B, E, L)
+        x = self.convs(x).transpose(1, 2)  # (B, L, E*2^N)
+        logits = self.fc(x)
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+        return {"loss": loss, "logits": logits}
+
+class RNNLanguageModel(nn.Module):
+    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
+        self.fc = nn.Linear(hidden_dim, vocab_size)
+
+    def forward(self, x, labels=None):
+        x = self.embedding(x)
+        output, _ = self.rnn(x)
+        logits = self.fc(output)
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+        return {"loss": loss, "logits": logits}
+
+class TransformerLanguageModel(nn.Module):
+    def __init__(self, vocab_size, embed_dim=128, num_heads=4, num_layers=3):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+        self.fc = nn.Linear(embed_dim, vocab_size)
+
+    def forward(self, x, labels=None):
+        x = self.embedding(x)
+        x = self.transformer(x)
+        logits = self.fc(x)
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+        return {"loss": loss, "logits": logits}
+
+def get_model(arch_type, vocab_size, num_layers):
+    if arch_type == "cnn":
+        return CNNLanguageModel(vocab_size, num_layers=num_layers)
+    elif arch_type == "rnn":
+        return RNNLanguageModel(vocab_size, num_layers=num_layers)
+    elif arch_type == "transformer":
+        return TransformerLanguageModel(vocab_size, num_layers=num_layers)
+    else:
+        raise ValueError(f"Unknown arch: {arch_type}")
+
+# ------------------------------ DATASET ------------------------------
+
+class TextDataset(Dataset):
+    def __init__(self, tokenized_data):
+        self.input_ids = tokenized_data["input_ids"]
+        self.labels = tokenized_data["input_ids"]  # causal LM
+
+    def __len__(self):
+        return len(self.input_ids)
+
+    def __getitem__(self, idx):
+        return {
+            "input_ids": torch.tensor(self.input_ids[idx], dtype=torch.long),
+            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
+        }
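Note that `labels` here is `input_ids` unshifted, and none of the three `forward()` methods shift internally, so position `t`'s logits are scored against token `t` itself; the objective rewards copying the input rather than predicting the next token. (Hugging Face causal LMs do this shift inside the model, which may be where the pattern came from.) The CNN's symmetric padding and the unmasked `TransformerEncoder` can also see the very token being scored, so a causal mask would be needed as well. A sketch of the usual shifted loss, as it could look inside each `forward()`:

```python
# score position t against token t+1, dropping the final step
if labels is not None:
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )
```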
+
+# ------------------------------ TRAINING JOB ------------------------------
+
 def run_training_job(job):
     global active_run_id
     run_id = job["run_id"]
     try:
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        log_update(f"Run {run_id}: Device = {device}", run_id)
+        log_update(f"Run {run_id}: 🚀 Device = {device} | RAM available: {psutil.virtual_memory().available / (1024**3):.2f} GB", run_id)
 
-        # Load tiny model for demo (replace with custom later)
-        model_name = "distilgpt2"
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        # Load tokenizer (shared for all models)
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
-        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
-        log_update(f"Run {run_id}: Model loaded", run_id)
+        vocab_size = len(tokenizer)
+
+        # Build model from scratch
+        model = get_model(job["arch_type"], vocab_size, job["num_layers"]).to(device)
+        log_update(f"Run {run_id}: 🧱 Model initialized: {job['arch_type']} x{job['num_layers']} layers", run_id)
 
-        # Load dataset
-        dataset = load_dataset("voidful/reasoning_gemini_300k", split="train[:1%]")  # Tiny slice for demo
+        # Load dataset — full training set (or 100K for speed)
+        dataset = load_dataset("voidful/reasoning_gemini_300k", split="train[:100000]")
         def tokenize_function(examples):
             texts = [q + " " + a for q, a in zip(examples["message"], examples["answer"])]
             return tokenizer(texts, truncation=True, padding="max_length", max_length=128)
-        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["message", "answer"])
-        log_update(f"Run {run_id}: Dataset tokenized", run_id)
-
-        # Training args
-        training_args = TrainingArguments(
-            output_dir=f"./runs/{run_id}",
-            num_train_epochs=job["epochs"],
-            per_device_train_batch_size=job["batch_size"],
-            learning_rate=job["learning_rate"],
-            save_strategy="no",
-            logging_steps=1,
-            report_to="none",
-            fp16=False,
-            no_cuda=(device == "cpu")
-        )
-
-        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=tokenized_dataset,
-            data_collator=data_collator,
-        )
-
-        log_update(f"Run {run_id}: Starting training...", run_id)
-        trainer.train()
-
-        # Simulate 48h timeout with short sleep for demo
-        time.sleep(10)  # Replace with real training
-
-        eval_results = trainer.evaluate()
-        log_update(f"Run {run_id}: Training complete. Loss = {eval_results.get('eval_loss', 'N/A')}", run_id)
-        update_run_status(run_id, "completed")
+        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
+        train_dataset = TextDataset(tokenized_dataset)
+
+        train_loader = DataLoader(train_dataset, batch_size=job["batch_size"], shuffle=True)
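Every example is padded out to the full 128 tokens with `eos_token` (set as `pad_token` above), and nothing masks those positions out of the loss, so a sizeable share of each batch's gradient comes from predicting padding. One hedged fix, assuming it runs where `tokenizer`, `train_dataset`, and `job` are in scope: mark pad positions with `-100`, which `nn.CrossEntropyLoss` ignores by default. The caveat is that `pad_token == eos_token` here, so real end-of-text tokens get masked too.

```python
import torch
from torch.utils.data import DataLoader

pad_id = tokenizer.pad_token_id

def collate_masked(batch):
    input_ids = torch.stack([item["input_ids"] for item in batch])
    labels = input_ids.clone()
    labels[labels == pad_id] = -100  # CrossEntropyLoss's default ignore_index
    return {"input_ids": input_ids, "labels": labels}

train_loader = DataLoader(train_dataset, batch_size=job["batch_size"],
                          shuffle=True, collate_fn=collate_masked)
```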
+
+        # Optimizer
+        optimizer = torch.optim.AdamW(model.parameters(), lr=job["learning_rate"])
+
+        # Training loop
+        model.train()
+        log_update(f"Run {run_id}: ▶️ Starting training for {job['epochs']} epochs...", run_id)
+
+        for epoch in range(job["epochs"]):
+            total_loss = 0
+            for step, batch in enumerate(train_loader):
+                input_ids = batch["input_ids"].to(device)
+                labels = batch["labels"].to(device)
+
+                optimizer.zero_grad()
+                outputs = model(input_ids, labels=labels)
+                loss = outputs["loss"]
+                loss.backward()
+                optimizer.step()
+
+                total_loss += loss.item()
+                if step % 50 == 0:
+                    ram_gb = psutil.virtual_memory().used / (1024**3)
+                    log_update(f"Run {run_id}: Epoch {epoch+1} | Step {step} | Loss: {loss.item():.4f} | RAM: {ram_gb:.2f}GB", run_id)
+
+            avg_loss = total_loss / len(train_loader)
+            log_update(f"Run {run_id}: ✅ Epoch {epoch+1} completed | Avg Loss: {avg_loss:.4f}", run_id)
+
+        # Save model
+        model_path = f"./runs/{run_id}"
+        os.makedirs(model_path, exist_ok=True)
+        torch.save(model.state_dict(), f"{model_path}/model.pth")
+        update_run_status(run_id, "completed", f"Model saved to {model_path}")
+        log_update(f"Run {run_id}: 💾 Model checkpoint saved.", run_id)
 
     except Exception as e:
-        log_update(f"Run {run_id}: FAILED - {str(e)}", run_id)
+        log_update(f"Run {run_id}: 💥 FAILED - {str(e)}", run_id)
         update_run_status(run_id, "failed")
     finally:
         with active_run_lock:
-            active_run_id = None
-        # Try starting next queued job
+            if active_run_id == run_id:
+                active_run_id = None
         start_training_if_free()
 
 def log_update(message, run_id):
-    print(f"[LOG] {message}")  # Also print to Spaces logs
-    update_run_status(run_id, "running", message)
+    timestamp = time.strftime("%H:%M:%S")
+    full_msg = f"[{timestamp}] {message}"
+    print(full_msg)  # Also shows in HF Spaces logs
+    if run_id > 0:
+        update_run_status(run_id, "running", full_msg)
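For completeness, reloading a finished run: the commit saves only a `state_dict`, so inference needs the architecture rebuilt with the same `arch_type` and `num_layers` the run was queued with (recoverable from the run's DB row, alongside the new `model_path` column). A minimal inference sketch; the `backend` import path and the run id are assumptions:

```python
import torch
from transformers import AutoTokenizer
from backend import get_model  # assumed import path

run_id = 1  # hypothetical finished run, trained as a 3-layer transformer
tokenizer = AutoTokenizer.from_pretrained("gpt2")

model = get_model("transformer", vocab_size=len(tokenizer), num_layers=3)
model.load_state_dict(torch.load(f"./runs/{run_id}/model.pth", map_location="cpu"))
model.eval()

ids = tokenizer("Why is the sky blue?", return_tensors="pt")["input_ids"]
with torch.no_grad():
    logits = model(ids)["logits"]
print(tokenizer.decode(logits[0, -1].argmax().item()))  # greedy next token
```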