Nekochu committed on
Commit 99bba35 · verified · 1 Parent(s): d65c6d1

app.py overall upgrade + img gen Chroma

Files changed (1):
  1. app.py +452 -128
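
For context, a minimal sketch (not part of the commit) of the Chroma text-to-image path this change wires in; the model ID, dtype, and pipeline arguments mirror the committed code in the diff below, so treat it as an illustration rather than the Space's exact behavior.

    import torch
    from diffusers import ChromaPipeline

    # Load Chroma1-HD once at startup, as the new app.py does at module level.
    pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    # Render an image from an LLM-produced Stable Diffusion prompt.
    image = pipe(
        prompt="(masterpiece, best quality), 1girl",  # fallback prompt used by parse_sd_metadata()
        negative_prompt="",
        num_inference_steps=25,
        guidance_scale=7.0,
        generator=torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(42),
        width=1024,
        height=1024,
    ).images[0]
    image.save("chroma_out.png")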
app.py CHANGED
@@ -1,146 +1,470 @@
  import os
- from threading import Thread
- from typing import Iterator

- import gradio as gr
  import spaces
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

- DESCRIPTION = """\
- # Nekochu/Luminia-13B-v3
- This Space demonstrates model Nekochu/Luminia-13B-v3 by Nekochu, a Llama 2 model with 13B parameters fine-tuned for SD gen prompt
- """

- LICENSE = """
- <p/>
- ---.
- """

  models_cache = {}

- def load_model(model_id):
-     if model_id not in models_cache:
-         model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
-         tokenizer = AutoTokenizer.from_pretrained(model_id)
-         tokenizer.use_default_system_prompt = False
-         models_cache[model_id] = (model, tokenizer)
-     return models_cache[model_id]

- if not torch.cuda.is_available():
-     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

- if torch.cuda.is_available():
-     model_id = "Nekochu/Luminia-13B-v3"
-     model, tokenizer = load_model(model_id)

- @spaces.GPU(duration=120)
- def generate(
-     message: str,
-     chat_history: list[tuple[str, str]],
-     system_prompt: str,
-     model_id: str = "Nekochu/Luminia-13B-v3",
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     repetition_penalty: float = 1.2,
- ) -> Iterator[str]:
-     model, tokenizer = load_model(model_id)
-     conversation = []
-     if system_prompt:
-         conversation.append({"role": "system", "content": system_prompt})
-     for user, assistant in chat_history:
-         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-     conversation.append({"role": "user", "content": message})
-
-     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-     input_ids = input_ids.to(model.device)
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-     generate_kwargs = dict(
-         {"input_ids": input_ids},
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
-         num_beams=1,
-         repetition_penalty=repetition_penalty,
-     )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         yield "".join(outputs)


- chat_interface = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Textbox(label="System prompt", lines=6),
-         gr.Textbox(label="Model ID", value="Nekochu/Luminia-13B-v3", placeholder="Enter a model ID here, e.g. Nekochu/Llama-2-13B-German-ORPO"),
-         gr.Slider(
-             label="Max new tokens",
-             minimum=1,
-             maximum=MAX_MAX_NEW_TOKENS,
-             step=1,
-             value=DEFAULT_MAX_NEW_TOKENS,
-         ),
-         gr.Slider(
-             label="Temperature",
-             minimum=0.1,
-             maximum=4.0,
-             step=0.1,
-             value=0.6,
-         ),
-         gr.Slider(
-             label="Top-p (nucleus sampling)",
-             minimum=0.05,
-             maximum=1.0,
-             step=0.05,
-             value=0.9,
-         ),
-         gr.Slider(
-             label="Top-k",
-             minimum=1,
-             maximum=1000,
-             step=1,
-             value=50,
-         ),
-         gr.Slider(
-             label="Repetition penalty",
-             minimum=1.0,
-             maximum=2.0,
-             step=0.05,
-             value=1.2,
-         ),
-     ],
-     stop_btn=None,
-     examples=[
-         ["### Instruction: Create stable diffusion metadata based on the given english description. Luminia ### Input: favorites and popular SFW ### Response:"],
-         ["### Instruction: Provide tips on stable diffusion to optimize low token prompts and enhance quality include prompt example. ### Response:"],
-     ],
- )

- with gr.Blocks(css="style.css") as demo:
-     gr.Markdown(DESCRIPTION)
-     with gr.Row():
-         gr.DuplicateButton(value="GPU Ver", elem_id="duplicate-button")
-         gr.HTML("""<a href="https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt/tree/Luminia-13B-v3-GGUF" style="margin:0 0 0 8px; padding:2px 8px; border:1px solid; border-radius:4px; text-decoration:none; font-size:0.9em;">or clone only the GGUF branch for free CPU Ver</a>""")
-     chat_interface.render()
-     gr.Markdown(LICENSE)

- if __name__ == "__main__":
-     demo.queue(max_size=20).launch()

  import os
+ import gc
+ import subprocess
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

  import spaces
+ import gradio as gr
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
+ from threading import Thread, Event
+ import time
+ import uuid
+ import re
+ from diffusers import ChromaPipeline

+ # Pre-load ONLY Chroma (not LLMs, to support custom models)
+ print("Loading Chroma1-HD...")
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Device at module level: {device}")

+ chroma_pipe = ChromaPipeline.from_pretrained(
+     "lodestones/Chroma1-HD",
+     torch_dtype=torch.bfloat16
+ )
+ chroma_pipe = chroma_pipe.to(device)
+ print("✓ Chroma1-HD ready")

+ MODEL_CONFIGS = {
+     "Nekochu/Luminia-13B-v3": {
+         "system": "",
+         "examples": [
+             "### Instruction:\nCreate stable diffusion metadata based on the given english description. Luminia\n\n### Input:\nfavorites and popular SFW",
+             "### Instruction:\nProvide tips on stable diffusion to optimize low token prompts and enhance quality include prompt example."
+         ],
+         "supports_image_gen": True,
+         "sd_temp": 0.3,
+         "sd_top_p": 0.8,
+         "branch": None  # Uses main/default branch
+     },
+     "Nekochu/Luminia-8B-v4-Chan": {
+         "system": "write a response like a 4chan user",
+         "examples": [],
+         "supports_image_gen": False,
+         "branch": "Llama-3-8B-4Chan_SD_QLoRa"
+     },
+     "Nekochu/Luminia-8B-RP": {
+         "system": "You are a knowledgeable and empathetic mental health professional.",
+         "examples": ["How to cope with anxiety?"],
+         "supports_image_gen": False,
+         "branch": None
+     }
+ }

+ DEFAULT_MODELS = list(MODEL_CONFIGS.keys())
  models_cache = {}
+ stop_event = Event()
+ current_thread = None
+ MAX_CACHE_SIZE = 2
+ DEFAULT_MODEL = DEFAULT_MODELS[0]

+ def parse_model_id(model_id_str):
+     """Parse model ID and optional branch (format: 'model_id:branch')"""
+     if ':' in model_id_str:
+         parts = model_id_str.split(':', 1)
+         return parts[0], parts[1]
+
+     if model_id_str in MODEL_CONFIGS:  # Check if it's a known model with a specific branch
+         config = MODEL_CONFIGS[model_id_str]
+         return model_id_str, config.get('branch', None)
+
+     return model_id_str, None

+ def parse_sd_metadata(text: str):
+     """Parse SD metadata"""
+     metadata = {
+         'prompt': '',
+         'negative_prompt': '',
+         'steps': 25,
+         'cfg_scale': 7.0,
+         'seed': 42,
+         'width': 1024,
+         'height': 1024
+     }
+
+     if not text:
+         metadata['prompt'] = '(masterpiece, best quality), 1girl'
+         return metadata
+
+     try:
+         if "Negative prompt:" in text:
+             parts = text.split("Negative prompt:", 1)
+             metadata['prompt'] = parts[0].strip().rstrip('.,;')[:500]
+
+             if len(parts) > 1:
+                 neg_section = parts[1]
+                 param_match = re.search(r'(Steps:|Sampler:|CFG scale:|Seed:|Size:)', neg_section)
+                 if param_match:
+                     metadata['negative_prompt'] = neg_section[:param_match.start()].strip().rstrip('.,;')[:300]
+                 else:
+                     metadata['negative_prompt'] = neg_section.strip().rstrip('.,;')[:300]
+         else:
+             param_match = re.search(r'(Steps:|Sampler:|CFG scale:|Seed:|Size:)', text)
+             if param_match:
+                 metadata['prompt'] = text[:param_match.start()].strip().rstrip('.,;')[:500]
+             else:
+                 metadata['prompt'] = text.strip()[:500]
+
+         patterns = {
+             'Steps': (r'Steps:\s*(\d+)', lambda x: min(int(x), 30)),
+             'CFG scale': (r'CFG scale:\s*([\d.]+)', float),
+             'Seed': (r'Seed:\s*(\d+)', lambda x: int(x) % (2**32)),
+             'Size': (r'Size:\s*(\d+)x(\d+)', None)
+         }
+
+         for key, (pattern, converter) in patterns.items():
+             match = re.search(pattern, text)
+             if match:
+                 try:
+                     if key == 'Size':
+                         metadata['width'] = min(max(int(match.group(1)), 512), 1536)
+                         metadata['height'] = min(max(int(match.group(2)), 512), 1536)
+                     else:
+                         metadata[key.lower().replace(' ', '_')] = converter(match.group(1))
+                 except:
+                     pass
+     except:
+         pass
+
+     if not metadata['prompt']:
+         metadata['prompt'] = '(masterpiece, best quality), 1girl'
+
+     return metadata

+ def clear_old_cache():
+     global models_cache
+     if len(models_cache) >= MAX_CACHE_SIZE:
+         oldest = min(models_cache.items(), key=lambda x: x[1].get('last_used', 0))
+         del models_cache[oldest[0]]
+         gc.collect()
+         torch.cuda.empty_cache()

+ @spaces.GPU(duration=119)
+ def generate_text_gpu(model_id_str, message, history, system, temp, top_p, top_k, max_tokens, rep_penalty):
+     """Text generation with branch support"""
+     global models_cache, stop_event, current_thread
+     stop_event.clear()
+
+     model_id, branch = parse_model_id(model_id_str)  # Parse model ID and branch
+     cache_key = f"{model_id}:{branch}" if branch else model_id
+
+     config = MODEL_CONFIGS.get(model_id, {})
+     if "Luminia-13B-v3" in model_id and ("stable diffusion" in message.lower() or "metadata" in message.lower()):
+         temp = config.get('sd_temp', 0.3)
+         top_p = config.get('sd_top_p', 0.8)
+         print(f"Using SD settings: temp={temp}, top_p={top_p}")
+
+     if cache_key not in models_cache:
+         clear_old_cache()
+         try:
+             yield history + [[message, f"📥 Loading {model_id}{f' ({branch})' if branch else ''}..."]], "Loading..."
+
+             # Load with branch/revision support
+             load_kwargs = {"trust_remote_code": True}
+             if branch:
+                 load_kwargs["revision"] = branch
+                 print(f"Loading from branch: {branch}")
+
+             tokenizer = AutoTokenizer.from_pretrained(model_id, **load_kwargs)
+             tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
+
+             bnb_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_compute_dtype=torch.bfloat16,
+                 bnb_4bit_quant_type="nf4",
+                 bnb_4bit_use_double_quant=True
+             )
+
+             model_kwargs = {
+                 "quantization_config": bnb_config,
+                 "device_map": "auto",
+                 "trust_remote_code": True,
+                 "attn_implementation": "flash_attention_2" if torch.cuda.is_available() else None,
+                 "low_cpu_mem_usage": True
+             }
+             if branch:
+                 model_kwargs["revision"] = branch
+
+             model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+
+             models_cache[cache_key] = {
+                 "model": model,
+                 "tokenizer": tokenizer,
+                 "last_used": time.time()
+             }
+
+         except Exception as e:
+             yield history + [[message, f"❌ Failed: {str(e)[:200]}"]], "Error"
+             return
+
+     models_cache[cache_key]['last_used'] = time.time()
+     model = models_cache[cache_key]["model"]
+     tokenizer = models_cache[cache_key]["tokenizer"]
+
+     prompt = ""
+     if system:
+         prompt = f"{system}\n\n"
+
+     for user_msg, assistant_msg in history:
+         if "### Instruction:" in user_msg:
+             prompt += f"{user_msg}\n### Response:\n{assistant_msg}\n\n"
+         else:
+             prompt += f"### Instruction:\n{user_msg}\n\n### Response:\n{assistant_msg}\n\n"
+
+     if "### Instruction:" in message and "### Response:" not in message:
+         prompt += f"{message}\n### Response:\n"
+     elif "### Instruction:" not in message:
+         prompt += f"### Instruction:\n{message}\n\n### Response:\n"
+     else:
+         prompt += message
+
+     print(f"Prompt ending: ...{prompt[-200:]}")
+
+     try:
+         inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
+         input_tokens = inputs['input_ids'].shape[1]
+         inputs = {k: v.to(model.device) for k, v in inputs.items()}
+     except Exception as e:
+         yield history + [[message, f"❌ Tokenization failed: {str(e)}"]], "Error"
+         return
+
+     print(f"📝 {input_tokens} tokens | Temp: {temp} | Top-p: {top_p}")
+
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=5)
+     gen_kwargs = {
+         **inputs,
+         "streamer": streamer,
+         "max_new_tokens": min(max_tokens, 2048),
+         "temperature": max(temp, 0.01),
+         "top_p": top_p,
+         "top_k": top_k,
+         "repetition_penalty": rep_penalty,
+         "do_sample": temp > 0.01,
+         "pad_token_id": tokenizer.pad_token_id
+     }
+
+     current_thread = Thread(target=model.generate, kwargs=gen_kwargs)
+     current_thread.start()
+
+     start_time = time.time()
+     partial = ""
+     token_count = 0
+
+     try:
+         for text in streamer:
+             if stop_event.is_set():
+                 break
+             partial += text
+             token_count = len(tokenizer.encode(partial, add_special_tokens=False))
+             elapsed = time.time() - start_time
+             if elapsed > 0:
+                 yield history + [[message, partial]], f"⚡ {token_count} @ {token_count/elapsed:.1f} t/s"
+     except:
+         pass
+     finally:
+         if current_thread.is_alive():
+             stop_event.set()
+             current_thread.join(timeout=2)
+
+     final_time = time.time() - start_time
+     yield history + [[message, partial]], f"✅ {token_count} tokens in {final_time:.1f}s"

+ @spaces.GPU()
+ def generate_image_gpu(text_output):
+     """Image generation with pre-loaded Chroma"""
+     global chroma_pipe
+
+     if not text_output or text_output.isspace():
+         return None, "❌ No valid text", gr.update(visible=False)
+
+     try:
+         metadata = parse_sd_metadata(text_output)
+         print(f"Generating: {metadata['width']}x{metadata['height']} | Steps: {metadata['steps']}")
+
+         if torch.cuda.is_available():
+             chroma_pipe = chroma_pipe.to("cuda")
+
+         generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(metadata['seed'])
+
+         image = chroma_pipe(
+             prompt=metadata['prompt'],
+             negative_prompt=metadata['negative_prompt'],
+             generator=generator,
+             num_inference_steps=metadata['steps'],
+             guidance_scale=metadata['cfg_scale'],
+             width=metadata['width'],
+             height=metadata['height']
+         ).images[0]
+
+         status = f"✅ {metadata['width']}x{metadata['height']} | {metadata['steps']} steps | CFG: {metadata['cfg_scale']} | Seed: {metadata['seed']}"
+         return image, status, gr.update(visible=False)
+
+     except Exception as e:
+         import traceback
+         traceback.print_exc()
+         return None, f"❌ Failed: {str(e)[:200]}", gr.update(visible=False)

+ def stop_generation():
+     global stop_event, current_thread
+     stop_event.set()
+     if current_thread and current_thread.is_alive():
+         current_thread.join(timeout=2)
+     return gr.update(visible=True), gr.update(visible=False)

+ css = """
315
+ #chatbot {height: 305px;}
316
+ #input-row {display: flex; gap: 4px;}
317
+ #input-box {flex-grow: 1;}
318
+ #button-group {display: inline-flex; flex-direction: column; gap: 2px; width: 45px;}
319
+ #button-group button {width: 40px; height: 28px; padding: 2px; font-size: 14px;}
320
+ #status {font-size: 11px; color: #666; margin-top: 2px;}
321
+ #image-output {max-height: 400px; margin-top: 8px;}
322
+ #img-loading {font-size: 11px; color: #666; margin: 4px 0;}
323
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
326
+ with gr.Row():
327
+ with gr.Column(scale=4):
328
+ chatbot = gr.Chatbot(elem_id="chatbot", type="tuples")
329
+
330
+ with gr.Row(elem_id="input-row"):
331
+ msg = gr.Textbox(
332
+ label="Instruction",
333
+ lines=3,
334
+ elem_id="input-box",
335
+ value=MODEL_CONFIGS[DEFAULT_MODEL]["examples"][0] if MODEL_CONFIGS[DEFAULT_MODEL]["examples"] else "",
336
+ scale=10
337
+ )
338
+ with gr.Column(elem_id="button-group", scale=1, min_width=45):
339
+ submit = gr.Button("▶", variant="primary", size="sm")
340
+ stop = gr.Button("⏹", variant="stop", size="sm", visible=False)
341
+ undo = gr.Button("↩", size="sm")
342
+ clear = gr.Button("🗑", size="sm")
343
+
344
+ status = gr.Markdown("", elem_id="status")
345
+
346
+ with gr.Row():
347
+ image_btn = gr.Button("🎨 Generate Image using Chroma1-HD", visible=False, variant="secondary")
348
+ last_text = gr.Textbox(visible=False)
349
+
350
+ img_loading = gr.Markdown("", visible=False, elem_id="img-loading")
351
+ image_output = gr.Image(visible=False, elem_id="image-output")
352
+ image_status = gr.Markdown("", visible=False)
353
+
354
+ examples = gr.Examples(
355
+ examples=[[ex] for ex in MODEL_CONFIGS[DEFAULT_MODEL]["examples"] if ex],
356
+ inputs=msg,
357
+ label="Examples"
358
+ )
359
+
360
+ with gr.Column(scale=1):
361
+ model = gr.Dropdown(
362
+ DEFAULT_MODELS,
363
+ value=DEFAULT_MODEL,
364
+ label="Model",
365
+ allow_custom_value=True,
366
+ info="Custom HF ID + optional :branch"
367
+ )
368
+
369
+ with gr.Accordion("Settings", open=False):
370
+ system = gr.Textbox(
371
+ label="System Prompt",
372
+ value=MODEL_CONFIGS[DEFAULT_MODEL]["system"],
373
+ lines=2
374
+ )
375
+ temp = gr.Slider(0.1, 1.0, 0.35, label="Temperature")
376
+ top_p = gr.Slider(0.5, 1.0, 0.85, label="Top-p")
377
+ top_k = gr.Slider(10, 100, 40, label="Top-k")
378
+ rep_penalty = gr.Slider(1.0, 1.5, 1.1, label="Repetition Penalty")
379
+ max_tokens = gr.Slider(256, 2048, 1024, label="Max Tokens")
380
+
381
+ export_btn = gr.Button("💾 Export", size="sm")
382
+ export_file = gr.File(visible=False)
383
+
384
+ def update_ui_on_model_change(model_id_str):
385
+ """Update all UI components when model changes"""
386
+ model_id, branch = parse_model_id(model_id_str)
387
+ config = MODEL_CONFIGS.get(model_id, {"system": "", "examples": [""], "supports_image_gen": False})
388
+ return (
389
+ config["system"],
390
+ config["examples"][0] if config["examples"] else "",
391
+ gr.update(visible=False), # image_btn
392
+ "", # last_text
393
+ None, # image_output (clear image)
394
+ gr.update(visible=False), # image_output visibility
395
+ "", # image_status text
396
+ gr.update(visible=False), # image_status visibility
397
+ gr.update(visible=False) # img_loading visibility
398
+ )
399
+
400
+ def check_image_availability(model_id_str, history):
401
+ model_id, _ = parse_model_id(model_id_str)
402
+ if "Luminia-13B-v3" in model_id and history and len(history) > 0:
403
+ return gr.update(visible=True), history[-1][1]
404
+ return gr.update(visible=False), ""
405
+
406
+ submit.click(
407
+ lambda: (gr.update(visible=False), gr.update(visible=True)),
408
+ None, [submit, stop]
409
+ ).then(
410
+ generate_text_gpu,
411
+ [model, msg, chatbot, system, temp, top_p, top_k, max_tokens, rep_penalty],
412
+ [chatbot, status]
413
+ ).then(
414
+ lambda: (gr.update(visible=True), gr.update(visible=False)),
415
+ None, [submit, stop]
416
+ ).then(
417
+ check_image_availability,
418
+ [model, chatbot],
419
+ [image_btn, last_text]
420
+ )
421
+
422
+ stop.click(stop_generation, None, [submit, stop])
423
+
424
+ image_btn.click(
425
+ lambda: gr.update(value="🎨 Generating...", visible=True),
426
+ None, img_loading
427
+ ).then(
428
+ generate_image_gpu,
429
+ last_text,
430
+ [image_output, image_status, img_loading]
431
+ ).then(
432
+ lambda img: (gr.update(visible=img is not None), gr.update(visible=True)),
433
+ image_output,
434
+ [image_output, image_status]
435
+ )
436
+
437
+ model.change(
438
+ update_ui_on_model_change,
439
+ model,
440
+ [system, msg, image_btn, last_text, image_output, image_output, image_status, image_status, img_loading]
441
+ )
442
+
443
+ undo.click(
444
+ lambda h: h[:-1] if h else h,
445
+ chatbot, chatbot
446
+ ).then(
447
+ check_image_availability,
448
+ [model, chatbot],
449
+ [image_btn, last_text]
450
+ )
451
+
452
+ clear.click(
453
+ lambda: ([], "", "", None, "", gr.update(visible=False), "", gr.update(visible=False)),
454
+ None, [chatbot, msg, status, image_output, image_status, image_btn, last_text, img_loading]
455
+ )
456
+
457
+ def export_chat(history):
458
+ if not history:
459
+ return None
460
+ content = "\n\n".join([f"User: {u}\n\nAssistant: {a}" for u, a in history])
461
+ path = f"chat_{uuid.uuid4().hex[:8]}.txt"
462
+ with open(path, "w", encoding="utf-8") as f:
463
+ f.write(content)
464
+ return path
465
+
466
+ export_btn.click(export_chat, chatbot, export_file).then(
467
+ lambda: gr.update(visible=True), None, export_file
468
+ )
469
 
470
+ demo.queue().launch()