Keeby-smilyai committed
Commit ff5d079 · verified · 1 Parent(s): d05f823

Update backend.py

Files changed (1):
  1. backend.py +110 -101

backend.py CHANGED
@@ -1,4 +1,4 @@
- # backend.py — PARALLEL PROCESSING VERSION
  import sqlite3
  import threading
  import time
@@ -11,24 +11,42 @@ from transformers import AutoTokenizer
  import psutil
  import os
  import shutil

  DB_PATH = "llm_kitchen.db"
  training_queue = []
- # --- NEW STATE MANAGEMENT FOR PARALLELISM ---
- active_runs = set()  # Stores run_ids of currently running jobs
- active_users = set()  # Stores user_ids of users with a currently running job
- scheduler_lock = threading.Lock()  # Protects access to the queue and active sets
- # --- CONSTANTS ---
- RUN_TIMEOUT = 48 * 3600  # 48 hours
  MAX_RAM_PER_RUN_GB = 1.5

- # ------------------------------ DATABASE (No Changes Needed) ------------------------------
  def init_db():
      conn = sqlite3.connect(DB_PATH, check_same_thread=False)
      cursor = conn.cursor()
      cursor.executescript("""
-         CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY AUTOINCREMENT, hf_token TEXT UNIQUE NOT NULL, created_at DATETIME DEFAULT CURRENT_TIMESTAMP);
-         CREATE TABLE IF NOT EXISTS training_runs (id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL, arch_type TEXT NOT NULL, num_layers INTEGER NOT NULL, learning_rate REAL NOT NULL, epochs INTEGER NOT NULL, batch_size INTEGER NOT NULL, status TEXT DEFAULT 'queued', logs TEXT DEFAULT '', started_at DATETIME, completed_at DATETIME, FOREIGN KEY (user_id) REFERENCES users(id));
      """)
      conn.close()
@@ -44,111 +62,96 @@ def db_query(query, params=()):
      conn.close()
      return res, last_id

- def get_user_by_token(hf_token):
-     rows, _ = db_query("SELECT id FROM users WHERE hf_token = ?", (hf_token,))
-     return rows[0][0] if rows else None
-
- def create_user(hf_token):
-     _, user_id = db_query("INSERT INTO users (hf_token) VALUES (?)", (hf_token,))
-     return user_id
-
- def create_training_run(user_id, config):
-     _, run_id = db_query("INSERT INTO training_runs (user_id, arch_type, num_layers, learning_rate, epochs, batch_size) VALUES (?, ?, ?, ?, ?, ?)", (user_id, config['arch_type'], config['num_layers'], config['learning_rate'], config['epochs'], config['batch_size']))
-     return run_id
-
- def get_user_runs(user_id):
-     rows, _ = db_query("SELECT id, arch_type, num_layers, status, started_at FROM training_runs WHERE user_id = ? ORDER BY id DESC", (user_id,))
-     return rows

- def get_run_logs(run_id):
-     rows, _ = db_query("SELECT logs, status FROM training_runs WHERE id = ?", (run_id,))
-     return rows[0] if rows else ("", "unknown")
-
- def update_run_status(run_id, status):
-     if status == 'running':
-         db_query("UPDATE training_runs SET status = ?, started_at = CURRENT_TIMESTAMP WHERE id = ?", (status, run_id))
-     elif status in ['completed', 'failed', 'timeout']:
-         db_query("UPDATE training_runs SET status = ?, completed_at = CURRENT_TIMESTAMP WHERE id = ?", (status, run_id))
-     else:
-         db_query("UPDATE training_runs SET status = ? WHERE id = ?", (status, run_id))

- def log_update(message, run_id):
-     timestamp = time.strftime("%H:%M:%S")
-     full_msg = f"[{timestamp}] {message}"
-     print(full_msg)
-     if run_id > 0:
-         db_query("UPDATE training_runs SET logs = logs || ? || ? WHERE id = ?", ('\n', full_msg, run_id))

- # ------------------------------ AUTH (No Changes Needed) ------------------------------
- def verify_hf_token(token):
-     try:
-         whoami(token=token)
-         user_id = get_user_by_token(token)
-         if not user_id:
-             user_id = create_user(token)
-             return user_id, "Welcome to the LLM Kitchen, Chef! 🍳 Your apron is ready."
-         return user_id, "Welcome back, Chef! 👨‍🍳 Your last dish is still warm."
-     except Exception as e:
-         return None, f"Invalid token. Please try again. ({str(e)})"

- # ------------------------------ NEW PARALLEL TRAINING QUEUE ------------------------------

  def ram_available():
      return (psutil.virtual_memory().available / (1024**3)) >= MAX_RAM_PER_RUN_GB

  def queue_training_run(user_id, config):
-     run_id = create_training_run(user_id, config)
      training_queue.append({"run_id": run_id, "user_id": user_id, **config})
      return run_id

  def start_training_if_free():
-     """
-     The new scheduler. Tries to start as many jobs as possible from the queue
-     based on available RAM and the one-run-per-user constraint.
-     """
      with scheduler_lock:
-         # Iterate through a copy of the queue as we might modify it
          for job in list(training_queue):
-             # 1. Check for global resource constraint (RAM)
              if not ram_available():
                  log_update("MemoryWarning: Not enough RAM for new runs. Waiting.", -1)
-                 break  # Stop trying to schedule if we're out of RAM
-
-             # 2. Check for per-user constraint
              if job["user_id"] in active_users:
-                 continue  # Skip this job, user already has a run. Check next job.
-
-             # --- If we get here, we can start the job ---
              log_update(f"Scheduler: Starting run #{job['run_id']} for user #{job['user_id']}", -1)
-
-             # Update state to reflect the new running job
              active_runs.add(job["run_id"])
              active_users.add(job["user_id"])
              training_queue.remove(job)
-
-             # Update database and start the training thread
              update_run_status(job["run_id"], "running")
              log_update("🍳 Starting kitchen process...", job["run_id"])
-
              thread = threading.Thread(target=run_training_job, args=(job,))
              thread.start()
              threading.Timer(RUN_TIMEOUT, kill_run_timeout, args=[job]).start()

  def kill_run_timeout(job):
-     run_id = job["run_id"]
-     user_id = job["user_id"]
      with scheduler_lock:
          if run_id in active_runs:
-             log_update(f"Run {run_id}: 💥 48-HOUR TIMEOUT REACHED. Terminating.", run_id)
              update_run_status(run_id, "timeout")
-             # Free up resources
              active_runs.discard(run_id)
              active_users.discard(user_id)
-             # Try to schedule a new job now that resources are free
              start_training_if_free()

- # ------------------------------ MODELS & DATASET (No Changes Needed) -------------------------
- # ... (All model and dataset classes are unchanged) ...
  class CNNLanguageModel(nn.Module):
      def __init__(self, vocab_size, embed_dim=128, num_layers=4):
          super().__init__()
@@ -164,6 +167,7 @@ class CNNLanguageModel(nn.Module):
          logits = self.fc(x)
          loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1)) if labels is not None else None
          return {"loss": loss, "logits": logits}
  class RNNLanguageModel(nn.Module):
      def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
          super().__init__()
@@ -176,6 +180,7 @@ class RNNLanguageModel(nn.Module):
          logits = self.fc(output)
          loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1)) if labels is not None else None
          return {"loss": loss, "logits": logits}
  class TransformerLanguageModel(nn.Module):
      def __init__(self, vocab_size, embed_dim=128, num_heads=4, num_layers=3):
          super().__init__()
@@ -189,10 +194,11 @@ class TransformerLanguageModel(nn.Module):
          logits = self.fc(x)
          loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1)) if labels is not None else None
          return {"loss": loss, "logits": logits}
  def get_model(arch_type, vocab_size, num_layers):
      models = {"cnn": CNNLanguageModel, "rnn": RNNLanguageModel, "transformer": TransformerLanguageModel}
-     if arch_type not in models: raise ValueError(f"Unknown arch: {arch_type}")
      return models[arch_type](vocab_size, num_layers=num_layers)
  class TextDataset(Dataset):
      def __init__(self, tokenized_data):
          self.data = tokenized_data["input_ids"]
@@ -201,23 +207,18 @@ class TextDataset(Dataset):
      def __getitem__(self, idx):
          return {"input_ids": torch.tensor(self.data[idx]), "labels": torch.tensor(self.data[idx])}

- # ------------------------------ TRAINING JOB (Updated `finally` block) -----------------------
-
  def run_training_job(job):
-     run_id = job["run_id"]
-     user_id = job["user_id"]  # Get user_id for state management
      try:
          device = "cuda" if torch.cuda.is_available() else "cpu"
-         log_update(f"🚀 Device = {device} | RAM available: {psutil.virtual_memory().available / (1024**3):.2f} GB", run_id)
-         # (The core training logic remains the same)
          tokenizer = AutoTokenizer.from_pretrained("gpt2")
          tokenizer.pad_token = tokenizer.eos_token
          tokenizer_save_path = f"./runs/{run_id}/tokenizer"
          os.makedirs(tokenizer_save_path, exist_ok=True)
          tokenizer.save_pretrained(tokenizer_save_path)
-         log_update(f"💾 Tokenizer saved to {tokenizer_save_path}", run_id)
          model = get_model(job["arch_type"], len(tokenizer), job["num_layers"]).to(device)
-         log_update(f"🧱 Model initialized: {job['arch_type']} x{job['num_layers']} layers", run_id)
          dataset = load_dataset("voidful/reasoning_gemini_300k", split="train[:5000]")
          tokenized_dataset = dataset.map(lambda ex: tokenizer([q + " " + a for q, a in zip(ex["message"], ex["answer"])], truncation=True, padding="max_length", max_length=128), batched=True, remove_columns=dataset.column_names)
          train_loader = DataLoader(TextDataset(tokenized_dataset), batch_size=job["batch_size"], shuffle=True)
@@ -239,33 +240,27 @@ def run_training_job(job):
          model_path = f"./runs/{run_id}"
          os.makedirs(model_path, exist_ok=True)
          torch.save(model.state_dict(), f"{model_path}/pytorch_model.bin")
-         log_update(f"💾 Model checkpoint saved successfully.", run_id)
-
      except Exception as e:
-         error_message = f"💥 FAILED - {str(e)}"
-         log_update(error_message, run_id)
          update_run_status(run_id, "failed")
      else:
-         success_message = "🎉 Cooking complete! Your model is ready."
-         log_update(success_message, run_id)
          update_run_status(run_id, "completed")
      finally:
-         # --- NEW: Free up resources and trigger scheduler ---
          with scheduler_lock:
              active_runs.discard(run_id)
              active_users.discard(user_id)
              start_training_if_free()

- # ------------------------------ INFERENCE & PUBLISH (No Changes Needed) --------------------
- # ... (run_inference and publish_run_to_hub are unchanged) ...
  def run_inference(run_id, prompt):
      model_path = f"./runs/{run_id}/pytorch_model.bin"
      tokenizer_path = f"./runs/{run_id}/tokenizer"
      if not (os.path.exists(model_path) and os.path.exists(tokenizer_path)):
-         return "ModelError: Model or tokenizer files not found."
      tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
      rows, _ = db_query("SELECT arch_type, num_layers FROM training_runs WHERE id = ?", (run_id,))
-     if not rows: return "ModelError: Run not found in database."
      arch_type, num_layers = rows[0]
      model = get_model(arch_type, len(tokenizer), num_layers)
      model.load_state_dict(torch.load(model_path, map_location="cpu"))
@@ -277,15 +272,29 @@ def run_inference(run_id, prompt):
      logits = outputs["logits"]
      generated_ids = torch.argmax(logits, dim=-1)
      return f"🧑‍🍳 Model says:\n{tokenizer.decode(generated_ids[0], skip_special_tokens=True)}"
  def publish_run_to_hub(run_id, hf_token, repo_name, user_description=""):
      local_dir = f"./runs/{run_id}/hub_upload"
      shutil.rmtree(local_dir, ignore_errors=True)
      os.makedirs(local_dir, exist_ok=True)
      shutil.copy(f"./runs/{run_id}/pytorch_model.bin", f"{local_dir}/pytorch_model.bin")
      shutil.copytree(f"./runs/{run_id}/tokenizer", f"{local_dir}/tokenizer", dirs_exist_ok=True)
      readme_content = user_description.strip() or f"# Model from LLM Kitchen - Run #{run_id}"
-     with open(f"{local_dir}/README.md", "w") as f: f.write(readme_content)
      api = HfApi()
-     repo_url = api.create_repo(repo_id=repo_name, token=hf_token, exist_ok=True).repo_id
      api.upload_folder(folder_path=local_dir, repo_id=repo_url, token=hf_token)
      return f"https://huggingface.co/{repo_url}"
 
+ # backend.py — USERNAME/PASSWORD & PARALLEL PROCESSING VERSION
  import sqlite3
  import threading
  import time
  ...
  import psutil
  import os
  import shutil
+ from werkzeug.security import generate_password_hash, check_password_hash

  DB_PATH = "llm_kitchen.db"
  training_queue = []
+ active_runs = set()
+ active_users = set()
+ scheduler_lock = threading.Lock()
+ RUN_TIMEOUT = 48 * 3600
  MAX_RAM_PER_RUN_GB = 1.5

+ # ------------------------------ DATABASE (NEW SCHEMA) ------------------------------
+
  def init_db():
      conn = sqlite3.connect(DB_PATH, check_same_thread=False)
      cursor = conn.cursor()
      cursor.executescript("""
+         CREATE TABLE IF NOT EXISTS users (
+             id INTEGER PRIMARY KEY AUTOINCREMENT,
+             username TEXT UNIQUE NOT NULL,
+             password_hash TEXT NOT NULL,
+             created_at DATETIME DEFAULT CURRENT_TIMESTAMP
+         );
+         CREATE TABLE IF NOT EXISTS training_runs (
+             id INTEGER PRIMARY KEY AUTOINCREMENT,
+             user_id INTEGER NOT NULL,
+             arch_type TEXT NOT NULL,
+             num_layers INTEGER NOT NULL,
+             learning_rate REAL NOT NULL,
+             epochs INTEGER NOT NULL,
+             batch_size INTEGER NOT NULL,
+             status TEXT DEFAULT 'queued',
+             logs TEXT DEFAULT '',
+             started_at DATETIME,
+             completed_at DATETIME,
+             FOREIGN KEY (user_id) REFERENCES users(id)
+         );
      """)
      conn.close()
  ...
      conn.close()
      return res, last_id
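For orientation: db_query itself falls outside every hunk of this diff; only its two closing lines are visible above. A minimal sketch consistent with those lines and with its call sites (every caller unpacks a (rows, last_insert_id) pair) might look like this; the body is an assumption, not the committed code:

    def db_query(query, params=()):
        # Assumed shape: open a short-lived connection per call and return
        # (rows, last_insert_id), matching how callers unpack the result.
        conn = sqlite3.connect(DB_PATH, check_same_thread=False)
        cursor = conn.cursor()
        cursor.execute(query, params)
        res = cursor.fetchall()
        conn.commit()
        last_id = cursor.lastrowid
        conn.close()
        return res, last_id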

+ def get_user_by_username(username):
+     rows, _ = db_query("SELECT id, password_hash FROM users WHERE username = ?", (username,))
+     return rows[0] if rows else None

+ # ------------------------------ NEW AUTHENTICATION ------------------------------

+ def signup_user(username, password):
+     if not username or not password:
+         return None, "Username and password cannot be empty."
+     if get_user_by_username(username):
+         return None, "Username already exists. Please choose another."
+
+     password_hash = generate_password_hash(password)
+     _, user_id = db_query("INSERT INTO users (username, password_hash) VALUES (?, ?)", (username, password_hash))
+     return user_id, f"Welcome, {username}! Your account is ready. Please log in."

+ def login_user(username, password):
+     user = get_user_by_username(username)
+     if user and check_password_hash(user[1], password):
+         user_id = user[0]
+         return user_id, f"Welcome back, {username}!"
+     return None, "Invalid username or password."
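These two functions replace the old HF-token login with salted password hashing from werkzeug. A minimal round-trip, assuming init_db() has been run; the credentials are illustrative:

    init_db()

    user_id, msg = signup_user("chef_ada", "s3cret-sauce")  # hypothetical user
    print(msg)           # Welcome, chef_ada! Your account is ready. Please log in.

    user_id, msg = login_user("chef_ada", "s3cret-sauce")
    print(user_id, msg)  # e.g. 1 Welcome back, chef_ada!

    user_id, msg = login_user("chef_ada", "wrong-password")
    print(user_id, msg)  # None Invalid username or password.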
 
+ # ------------------------------ PARALLEL TRAINING QUEUE ------------------------------

  def ram_available():
      return (psutil.virtual_memory().available / (1024**3)) >= MAX_RAM_PER_RUN_GB

  def queue_training_run(user_id, config):
+     _, run_id = db_query("INSERT INTO training_runs (user_id, arch_type, num_layers, learning_rate, epochs, batch_size) VALUES (?, ?, ?, ?, ?, ?)", (user_id, config['arch_type'], config['num_layers'], config['learning_rate'], config['epochs'], config['batch_size']))
      training_queue.append({"run_id": run_id, "user_id": user_id, **config})
+     # Trigger the scheduler every time a new job is added
+     start_training_if_free()
      return run_id

  def start_training_if_free():
      with scheduler_lock:
          for job in list(training_queue):
              if not ram_available():
                  log_update("MemoryWarning: Not enough RAM for new runs. Waiting.", -1)
+                 break
              if job["user_id"] in active_users:
+                 continue
+
              log_update(f"Scheduler: Starting run #{job['run_id']} for user #{job['user_id']}", -1)
              active_runs.add(job["run_id"])
              active_users.add(job["user_id"])
              training_queue.remove(job)
+
              update_run_status(job["run_id"], "running")
              log_update("🍳 Starting kitchen process...", job["run_id"])
              thread = threading.Thread(target=run_training_job, args=(job,))
              thread.start()
              threading.Timer(RUN_TIMEOUT, kill_run_timeout, args=[job]).start()

  def kill_run_timeout(job):
+     run_id, user_id = job["run_id"], job["user_id"]
      with scheduler_lock:
          if run_id in active_runs:
+             log_update(f"Run {run_id}: 💥 48-HOUR TIMEOUT. Terminating.", run_id)
              update_run_status(run_id, "timeout")
              active_runs.discard(run_id)
              active_users.discard(user_id)
              start_training_if_free()
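With the scheduler above, submitting a job is a single call whose config keys mirror the training_runs columns. One caveat in the committed code: start_training_if_free() acquires scheduler_lock, yet kill_run_timeout() (and the finally block of run_training_job() below) appear to call it while already holding that lock, which would deadlock a plain threading.Lock. A sketch of a queueing call, plus the reentrant-lock variant that would avoid the deadlock; the user id and hyperparameters are illustrative:

    import threading

    # An RLock can be re-acquired by the thread that already holds it, so
    # start_training_if_free() could safely run inside the locked sections.
    scheduler_lock = threading.RLock()

    config = {
        "arch_type": "transformer",  # one of: cnn, rnn, transformer
        "num_layers": 3,             # illustrative hyperparameters
        "learning_rate": 3e-4,
        "epochs": 1,
        "batch_size": 8,
    }
    run_id = queue_training_run(user_id=1, config=config)  # inserts the row, queues, triggers the scheduler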
 
+ def get_user_runs(user_id):
+     rows, _ = db_query("SELECT id, arch_type, num_layers, status, started_at FROM training_runs WHERE user_id = ? ORDER BY id DESC", (user_id,))
+     return rows
+
+ def get_run_logs(run_id):
+     rows, _ = db_query("SELECT logs, status FROM training_runs WHERE id = ?", (run_id,))
+     return rows[0] if rows else ("", "unknown")
+
+ def update_run_status(run_id, status):
+     if status == 'running':
+         db_query("UPDATE training_runs SET status = ?, started_at = CURRENT_TIMESTAMP WHERE id = ?", (status, run_id))
+     elif status in ['completed', 'failed', 'timeout']:
+         db_query("UPDATE training_runs SET status = ?, completed_at = CURRENT_TIMESTAMP WHERE id = ?", (status, run_id))
+     else:
+         db_query("UPDATE training_runs SET status = ? WHERE id = ?", (status, run_id))
+
+ def log_update(message, run_id):
+     timestamp = time.strftime("%H:%M:%S")
+     full_msg = f"[{timestamp}] {message}"
+     print(full_msg)
+     if run_id > 0:
+         db_query("UPDATE training_runs SET logs = logs || ? || ? WHERE id = ?", ('\n', full_msg, run_id))
+
+ # ------------------------------ MODELS & TRAINING ------------------------------
+
  class CNNLanguageModel(nn.Module):
      def __init__(self, vocab_size, embed_dim=128, num_layers=4):
          super().__init__()
  ...
          logits = self.fc(x)
          loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1)) if labels is not None else None
          return {"loss": loss, "logits": logits}
+
  class RNNLanguageModel(nn.Module):
      def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
          super().__init__()
  ...
          logits = self.fc(output)
          loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1)) if labels is not None else None
          return {"loss": loss, "logits": logits}
+
  class TransformerLanguageModel(nn.Module):
      def __init__(self, vocab_size, embed_dim=128, num_heads=4, num_layers=3):
          super().__init__()
  ...
          logits = self.fc(x)
          loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1)) if labels is not None else None
          return {"loss": loss, "logits": logits}
+
  def get_model(arch_type, vocab_size, num_layers):
      models = {"cnn": CNNLanguageModel, "rnn": RNNLanguageModel, "transformer": TransformerLanguageModel}
      return models[arch_type](vocab_size, num_layers=num_layers)
+
  class TextDataset(Dataset):
      def __init__(self, tokenized_data):
          self.data = tokenized_data["input_ids"]
  ...
      def __getitem__(self, idx):
          return {"input_ids": torch.tensor(self.data[idx]), "labels": torch.tensor(self.data[idx])}
  def run_training_job(job):
+     run_id, user_id = job["run_id"], job["user_id"]
      try:
          device = "cuda" if torch.cuda.is_available() else "cpu"
+         log_update(f"🚀 Device = {device}", run_id)
          tokenizer = AutoTokenizer.from_pretrained("gpt2")
          tokenizer.pad_token = tokenizer.eos_token
          tokenizer_save_path = f"./runs/{run_id}/tokenizer"
          os.makedirs(tokenizer_save_path, exist_ok=True)
          tokenizer.save_pretrained(tokenizer_save_path)
          model = get_model(job["arch_type"], len(tokenizer), job["num_layers"]).to(device)
+         log_update(f"🧱 Model: {job['arch_type']} x{job['num_layers']} layers", run_id)
          dataset = load_dataset("voidful/reasoning_gemini_300k", split="train[:5000]")
          tokenized_dataset = dataset.map(lambda ex: tokenizer([q + " " + a for q, a in zip(ex["message"], ex["answer"])], truncation=True, padding="max_length", max_length=128), batched=True, remove_columns=dataset.column_names)
          train_loader = DataLoader(TextDataset(tokenized_dataset), batch_size=job["batch_size"], shuffle=True)
  ...
          model_path = f"./runs/{run_id}"
          os.makedirs(model_path, exist_ok=True)
          torch.save(model.state_dict(), f"{model_path}/pytorch_model.bin")
      except Exception as e:
+         log_update(f"💥 FAILED - {str(e)}", run_id)
          update_run_status(run_id, "failed")
      else:
+         log_update("🎉 Cooking complete!", run_id)
          update_run_status(run_id, "completed")
      finally:
          with scheduler_lock:
              active_runs.discard(run_id)
              active_users.discard(user_id)
              start_training_if_free()

  def run_inference(run_id, prompt):
      model_path = f"./runs/{run_id}/pytorch_model.bin"
      tokenizer_path = f"./runs/{run_id}/tokenizer"
      if not (os.path.exists(model_path) and os.path.exists(tokenizer_path)):
+         return "ModelError: Files not found."
      tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
      rows, _ = db_query("SELECT arch_type, num_layers FROM training_runs WHERE id = ?", (run_id,))
+     if not rows:
+         return "ModelError: Run not found."
      arch_type, num_layers = rows[0]
      model = get_model(arch_type, len(tokenizer), num_layers)
      model.load_state_dict(torch.load(model_path, map_location="cpu"))
  ...
      logits = outputs["logits"]
      generated_ids = torch.argmax(logits, dim=-1)
      return f"🧑‍🍳 Model says:\n{tokenizer.decode(generated_ids[0], skip_special_tokens=True)}"
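Note that run_inference() performs a single forward pass with a greedy argmax at each position rather than autoregressive generation, so the reply contains one predicted token per prompt token. An illustrative call; run id 1 is hypothetical and must refer to a completed run:

    print(run_inference(1, "How do I caramelize onions?"))
    # 🧑‍🍳 Model says:
    # <one predicted token per prompt position>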
+
+ # ------------------------------ PUBLISH (HF Token passed as argument) ------------------------------
+
  def publish_run_to_hub(run_id, hf_token, repo_name, user_description=""):
+     try:
+         user_info = whoami(token=hf_token)
+         hf_username = user_info['name']
+     except Exception as e:
+         raise ValueError(f"Invalid Hugging Face Token. Error: {e}")
+
+     final_repo_name = f"{hf_username}/{repo_name}"
      local_dir = f"./runs/{run_id}/hub_upload"
      shutil.rmtree(local_dir, ignore_errors=True)
      os.makedirs(local_dir, exist_ok=True)
+
      shutil.copy(f"./runs/{run_id}/pytorch_model.bin", f"{local_dir}/pytorch_model.bin")
      shutil.copytree(f"./runs/{run_id}/tokenizer", f"{local_dir}/tokenizer", dirs_exist_ok=True)
+
      readme_content = user_description.strip() or f"# Model from LLM Kitchen - Run #{run_id}"
+     with open(f"{local_dir}/README.md", "w") as f:
+         f.write(readme_content)
+
      api = HfApi()
+     repo_url = api.create_repo(repo_id=final_repo_name, token=hf_token, exist_ok=True).repo_id
      api.upload_folder(folder_path=local_dir, repo_id=repo_url, token=hf_token)
      return f"https://huggingface.co/{repo_url}"
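Publishing now derives the repository owner from the token via whoami(), so callers pass a bare repo name. A hedged usage sketch; the run id, token, and names are placeholders:

    url = publish_run_to_hub(
        run_id=1,                      # must have a checkpoint under ./runs/1/
        hf_token="hf_xxx",             # placeholder write token
        repo_name="llm-kitchen-demo",  # bare name; owner prefix comes from whoami()
        user_description="# My first LLM Kitchen model",
    )
    print(url)  # e.g. https://huggingface.co/<username>/llm-kitchen-demo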