Spaces:
Running
Update backend.py
Browse files
backend.py  CHANGED  +178 -136
@@ -1,17 +1,23 @@
-# backend.py
 import sqlite3
 import threading
 import time
 import torch
 from huggingface_hub import whoami
 from datasets import load_dataset
-from transformers import AutoTokenizer
 import os
 
 DB_PATH = "llm_kitchen.db"
 training_queue = []
 active_run_lock = threading.Lock()
 active_run_id = None
 
 # ------------------------------ DATABASE ------------------------------
 
@@ -38,6 +44,7 @@ def init_db():
             logs TEXT DEFAULT '',
             started_at DATETIME,
             completed_at DATETIME,
             FOREIGN KEY (user_id) REFERENCES users(id)
         );
     """)
@@ -46,93 +53,20 @@ def init_db():
 
 init_db()
 
-def get_user_by_token(hf_token):
-    conn = sqlite3.connect(DB_PATH)
-    cursor = conn.cursor()
-    cursor.execute("SELECT id FROM users WHERE hf_token = ?", (hf_token,))
-    row = cursor.fetchone()
-    conn.close()
-    return row[0] if row else None
-
-def create_user(hf_token):
-    conn = sqlite3.connect(DB_PATH)
-    cursor = conn.cursor()
-    cursor.execute("INSERT INTO users (hf_token) VALUES (?)", (hf_token,))
-    user_id = cursor.lastrowid
-    conn.commit()
-    conn.close()
-    return user_id
-
-def create_training_run(user_id, config):
-    conn = sqlite3.connect(DB_PATH)
-    cursor = conn.cursor()
-    cursor.execute("""
-        INSERT INTO training_runs
-        (user_id, arch_type, num_layers, learning_rate, epochs, batch_size)
-        VALUES (?, ?, ?, ?, ?, ?)
-    """, (
-        user_id,
-        config['arch_type'],
-        config['num_layers'],
-        config['learning_rate'],
-        config['epochs'],
-        config['batch_size']
-    ))
-    run_id = cursor.lastrowid
-    conn.commit()
-    conn.close()
-    return run_id
-
-def get_user_runs(user_id):
-    conn = sqlite3.connect(DB_PATH)
-    cursor = conn.cursor()
-    cursor.execute("""
-        SELECT id, arch_type, num_layers, status, started_at
-        FROM training_runs
-        WHERE user_id = ?
-        ORDER BY started_at DESC
-    """, (user_id,))
-    runs = cursor.fetchall()
-    conn.close()
-    return runs
-
-def get_run_logs(run_id):
-    conn = sqlite3.connect(DB_PATH)
-    cursor = conn.cursor()
-    cursor.execute("SELECT logs, status FROM training_runs WHERE id = ?", (run_id,))
-    row = cursor.fetchone()
-    conn.close()
-    return row if row else ("", "unknown")
-
-def update_run_status(run_id, status, logs=""):
-    conn = sqlite3.connect(DB_PATH)
-    cursor = conn.cursor()
-    if status == 'running':
-        cursor.execute("UPDATE training_runs SET status = ?, started_at = CURRENT_TIMESTAMP WHERE id = ?", (status, run_id))
-    elif status in ['completed', 'failed', 'timeout']:
-        cursor.execute("UPDATE training_runs SET status = ?, completed_at = CURRENT_TIMESTAMP WHERE id = ?", (status, run_id))
-    if logs:
-        current_logs = get_run_logs(run_id)[0]
-        cursor.execute("UPDATE training_runs SET logs = ? WHERE id = ?", (current_logs + "\n" + logs, run_id))
-    conn.commit()
-    conn.close()
 
 # ------------------------------ AUTH ------------------------------
 
-def verify_hf_token(token):
-    try:
-        whoami(token=token)
-        user_id = get_user_by_token(token)
-        if not user_id:
-            user_id = create_user(token)
-            return user_id, "Welcome to the LLM Kitchen, Chef! 🍳 Your apron is ready."
-        else:
-            return user_id, "Welcome back, Chef! 👨‍🍳 Your last dish is still warm."
-    except Exception as e:
-        return None, f"Invalid token. Please try again. ({str(e)})"
 
 # ------------------------------ TRAINING QUEUE ------------------------------
 
 def queue_training_run(user_id, config):
     run_id = create_training_run(user_id, config)
     training_queue.append({
@@ -142,11 +76,6 @@ def queue_training_run(user_id, config):
     })
     return run_id
 
-def ram_check_mock():
-    # Mock: Allow 1 run at a time, 1.5GB per run
-    global active_run_id
-    return active_run_id is None
-
 def start_training_if_free():
     global active_run_id
     with active_run_lock:
@@ -154,7 +83,8 @@ def start_training_if_free():
             return False
         if not training_queue:
            return False
-        if not ram_check_mock():
            return False
 
        job = training_queue.pop(0)
@@ -163,72 +93,184 @@ def start_training_if_free():
 
        thread = threading.Thread(target=run_training_job, args=(job,))
        thread.start()
        return True
 
 def run_training_job(job):
     global active_run_id
     run_id = job["run_id"]
     try:
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        log_update(f"Run {run_id}: Device = {device}", run_id)
 
-        # Load
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
-        # Load dataset
-        dataset = load_dataset("voidful/reasoning_gemini_300k", split="train[:
         def tokenize_function(examples):
             texts = [q + " " + a for q, a in zip(examples["message"], examples["answer"])]
             return tokenizer(texts, truncation=True, padding="max_length", max_length=128)
-        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=
 
     except Exception as e:
-        log_update(f"Run {run_id}: FAILED - {str(e)}", run_id)
         update_run_status(run_id, "failed")
     finally:
         with active_run_lock:
-            active_run_id = None
         start_training_if_free()
 
 def log_update(message, run_id):
+# backend.py — REAL VERSION
 import sqlite3
 import threading
 import time
 import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader, Dataset
 from huggingface_hub import whoami
 from datasets import load_dataset
+from transformers import AutoTokenizer
+import psutil
 import os
+import signal
 
 DB_PATH = "llm_kitchen.db"
 training_queue = []
 active_run_lock = threading.Lock()
 active_run_id = None
+RUN_TIMEOUT = 48 * 3600  # 48 hours
+MAX_RAM_PER_RUN_GB = 1.5
 
 # ------------------------------ DATABASE ------------------------------
 
             logs TEXT DEFAULT '',
             started_at DATETIME,
             completed_at DATETIME,
+            model_path TEXT,
             FOREIGN KEY (user_id) REFERENCES users(id)
         );
     """)
 
 init_db()
 
+# ... [KEEP ALL DB HELPER FUNCTIONS: get_user_by_token, create_user, etc. — NO CHANGES] ...
 
 # ------------------------------ AUTH ------------------------------
+# ... [KEEP verify_hf_token — NO CHANGES] ...
 
 # ------------------------------ TRAINING QUEUE ------------------------------
 
+def ram_available():
+    """Check if we can start a new run (1.5GB per run)"""
+    total_ram = psutil.virtual_memory().total / (1024**3)  # GB
+    used_ram = psutil.virtual_memory().used / (1024**3)    # GB
+    available_gb = total_ram - used_ram
+    return available_gb >= MAX_RAM_PER_RUN_GB
+
 def queue_training_run(user_id, config):
     run_id = create_training_run(user_id, config)
     training_queue.append({
     })
     return run_id
 
 def start_training_if_free():
     global active_run_id
     with active_run_lock:
         if active_run_id is not None:
             return False
         if not training_queue:
             return False
+        if not ram_available():
+            log_update("MemoryWarning: Not enough RAM to start new run.", -1)
             return False
 
         job = training_queue.pop(0)
 
         thread = threading.Thread(target=run_training_job, args=(job,))
         thread.start()
+
+        # Start 48h timeout killer
+        timer = threading.Timer(RUN_TIMEOUT, kill_run_timeout, args=[job["run_id"]])
+        timer.start()
+
         return True
 
+def kill_run_timeout(run_id):
+    global active_run_id
+    with active_run_lock:
+        if active_run_id == run_id:
+            log_update(f"Run {run_id}: 🔥 48-HOUR TIMEOUT REACHED. Terminating.", run_id)
+            update_run_status(run_id, "timeout")
+            active_run_id = None
+    start_training_if_free()  # try next
+
+# ------------------------------ CUSTOM MODELS FROM SCRATCH ------------------------------
+
+class CNNLanguageModel(nn.Module):
+    def __init__(self, vocab_size, embed_dim=128, num_layers=4):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        layers = []
+        in_ch = embed_dim
+        for _ in range(num_layers):
+            layers.append(nn.Conv1d(in_ch, in_ch * 2, kernel_size=3, padding=1))
+            layers.append(nn.ReLU())
+            in_ch *= 2
+        self.convs = nn.Sequential(*layers)
+        self.fc = nn.Linear(in_ch, vocab_size)
+
+    def forward(self, x, labels=None):
+        x = self.embedding(x).transpose(1, 2)  # (B, E, L)
+        x = self.convs(x).transpose(1, 2)      # (B, L, E*2^N)
+        logits = self.fc(x)
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+        return {"loss": loss, "logits": logits}
+
+class RNNLanguageModel(nn.Module):
+    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
+        self.fc = nn.Linear(hidden_dim, vocab_size)
+
+    def forward(self, x, labels=None):
+        x = self.embedding(x)
+        output, _ = self.rnn(x)
+        logits = self.fc(output)
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+        return {"loss": loss, "logits": logits}
+
+class TransformerLanguageModel(nn.Module):
+    def __init__(self, vocab_size, embed_dim=128, num_heads=4, num_layers=3):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+        self.fc = nn.Linear(embed_dim, vocab_size)
+
+    def forward(self, x, labels=None):
+        x = self.embedding(x)
+        x = self.transformer(x)
+        logits = self.fc(x)
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+        return {"loss": loss, "logits": logits}
+
+def get_model(arch_type, vocab_size, num_layers):
+    if arch_type == "cnn":
+        return CNNLanguageModel(vocab_size, num_layers=num_layers)
+    elif arch_type == "rnn":
+        return RNNLanguageModel(vocab_size, num_layers=num_layers)
+    elif arch_type == "transformer":
+        return TransformerLanguageModel(vocab_size, num_layers=num_layers)
+    else:
+        raise ValueError(f"Unknown arch: {arch_type}")
+
+# ------------------------------ DATASET ------------------------------
+
+class TextDataset(Dataset):
+    def __init__(self, tokenized_data):
+        self.input_ids = tokenized_data["input_ids"]
+        self.labels = tokenized_data["input_ids"]  # causal LM
+
+    def __len__(self):
+        return len(self.input_ids)
+
+    def __getitem__(self, idx):
+        return {
+            "input_ids": torch.tensor(self.input_ids[idx], dtype=torch.long),
+            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
+        }
+
+# ------------------------------ TRAINING JOB ------------------------------
+
 def run_training_job(job):
     global active_run_id
     run_id = job["run_id"]
     try:
         device = "cuda" if torch.cuda.is_available() else "cpu"
+        log_update(f"Run {run_id}: 🚀 Device = {device} | RAM available: {psutil.virtual_memory().available / (1024**3):.2f} GB", run_id)
 
+        # Load tokenizer (shared for all models)
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
+        vocab_size = len(tokenizer)
+
+        # Build model from scratch
+        model = get_model(job["arch_type"], vocab_size, job["num_layers"]).to(device)
+        log_update(f"Run {run_id}: 🧱 Model initialized: {job['arch_type']} x{job['num_layers']} layers", run_id)
 
+        # Load dataset — full training set (or 100K for speed)
+        dataset = load_dataset("voidful/reasoning_gemini_300k", split="train[:100000]")
         def tokenize_function(examples):
             texts = [q + " " + a for q, a in zip(examples["message"], examples["answer"])]
             return tokenizer(texts, truncation=True, padding="max_length", max_length=128)
+        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
+        train_dataset = TextDataset(tokenized_dataset)
+
+        train_loader = DataLoader(train_dataset, batch_size=job["batch_size"], shuffle=True)
+
+        # Optimizer
+        optimizer = torch.optim.AdamW(model.parameters(), lr=job["learning_rate"])
+
+        # Training loop
+        model.train()
+        log_update(f"Run {run_id}: ▶️ Starting training for {job['epochs']} epochs...", run_id)
+
+        for epoch in range(job["epochs"]):
+            total_loss = 0
+            for step, batch in enumerate(train_loader):
+                input_ids = batch["input_ids"].to(device)
+                labels = batch["labels"].to(device)
+
+                optimizer.zero_grad()
+                outputs = model(input_ids, labels=labels)
+                loss = outputs["loss"]
+                loss.backward()
+                optimizer.step()
+
+                total_loss += loss.item()
+                if step % 50 == 0:
+                    ram_gb = psutil.virtual_memory().used / (1024**3)
+                    log_update(f"Run {run_id}: Epoch {epoch+1} | Step {step} | Loss: {loss.item():.4f} | RAM: {ram_gb:.2f}GB", run_id)
+
+            avg_loss = total_loss / len(train_loader)
+            log_update(f"Run {run_id}: ✅ Epoch {epoch+1} completed | Avg Loss: {avg_loss:.4f}", run_id)
+
+        # Save model
+        model_path = f"./runs/{run_id}"
+        os.makedirs(model_path, exist_ok=True)
+        torch.save(model.state_dict(), f"{model_path}/model.pth")
+        update_run_status(run_id, "completed", f"Model saved to {model_path}")
+        log_update(f"Run {run_id}: 💾 Model checkpoint saved.", run_id)
 
     except Exception as e:
+        log_update(f"Run {run_id}: 💥 FAILED - {str(e)}", run_id)
         update_run_status(run_id, "failed")
     finally:
         with active_run_lock:
+            if active_run_id == run_id:
+                active_run_id = None
         start_training_if_free()
 
 def log_update(message, run_id):
+    timestamp = time.strftime("%H:%M:%S")
+    full_msg = f"[{timestamp}] {message}"
+    print(full_msg)  # Also shows in HF Spaces logs
+    if run_id > 0:
+        update_run_status(run_id, "running", full_msg)
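
A quick way to sanity-check the new pieces outside the Space is to drive them directly. The sketch below is not part of the commit: it assumes backend.py is importable as a module, uses a tiny made-up token list in place of the real tokenized dataset, and skips the SQLite/queue layer entirely. Only get_model, TextDataset, and one AdamW step as defined above are exercised; every other name here (smoke_test.py, fake_tokens, the vocab size) is illustrative.

# smoke_test.py — hypothetical local check, not part of the commit
import torch
from torch.utils.data import DataLoader
from backend import get_model, TextDataset  # assumes backend.py is on the path

vocab_size = 1000  # stand-in; the Space uses len(tokenizer) for GPT-2 (~50k)
fake_tokens = [[i % vocab_size for i in range(128)] for _ in range(8)]
dataset = TextDataset({"input_ids": fake_tokens})
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for arch in ("cnn", "rnn", "transformer"):
    model = get_model(arch, vocab_size, num_layers=2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    batch = next(iter(loader))
    out = model(batch["input_ids"], labels=batch["labels"])
    # each model returns {"loss", "logits"} with logits of shape (B, L, vocab)
    assert out["logits"].shape == (4, 128, vocab_size)
    out["loss"].backward()
    optimizer.step()
    print(arch, "loss:", out["loss"].item())

This checks the one contract the training loop relies on: all three architectures accept (input_ids, labels) and return a dict with a scalar loss and (B, L, vocab) logits.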
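The 48-hour timeout path can be exercised in isolation the same way. A minimal sketch of the Timer-plus-lock pattern used in start_training_if_free and kill_run_timeout, with a shortened timeout and illustrative names (nothing below comes from the commit):

# timer_sketch.py — illustrative only; mirrors the Timer/lock pattern above
import threading, time

active_run_id = 7
lock = threading.Lock()

def kill_on_timeout(run_id):
    global active_run_id
    with lock:
        if active_run_id == run_id:  # only act if this run still holds the slot
            print(f"Run {run_id}: timeout reached, freeing slot")
            active_run_id = None

timer = threading.Timer(0.1, kill_on_timeout, args=[7])  # 0.1s stands in for RUN_TIMEOUT
timer.start()
time.sleep(0.2)
assert active_run_id is None

Note that, as in the backend, the Timer callback only updates the bookkeeping (status and active slot); it does not interrupt the training thread itself, which keeps running until its loop finishes.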