# ===================================================================== # ForgeCaptions - Gradio app for single & batch image captioning (Spaces-only) # ===================================================================== # ------------------------------ # 0) Imports & environment # ------------------------------ import os os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface") os.makedirs(os.environ["HF_HOME"], exist_ok=True) import csv, time, json, base64, re, zipfile from typing import List, Tuple, Dict, Any import gradio as gr from PIL import Image import torch from transformers import LlavaForConditionalGeneration, AutoProcessor import spaces # Spaces-only # Optional: Liger kernel (ignored if missing) try: from liger_kernel.transformers import apply_liger_kernel_to_llama except Exception: def apply_liger_kernel_to_llama(*args, **kwargs): pass # ------------------------------ # 1) Paths & small constants # ------------------------------ APP_DIR = os.getcwd() SESSION_FILE = "/tmp/forge_session.json" # Branding: fixed logo height LOGO_HEIGHT_PX = int(os.getenv("FORGE_LOGO_PX", 50)) # Settings live in a user cache dir (persists better than /tmp) CONFIG_DIR = os.path.expanduser("~/.cache/forgecaptions") os.makedirs(CONFIG_DIR, exist_ok=True) SETTINGS_FILE = os.path.join(CONFIG_DIR, "settings.json") JOURNAL_FILE = "/tmp/forge_journal.json" # Generated assets in /tmp so Gradio can serve them safely THUMB_CACHE = "/tmp/forgecaptions/thumbs" EXCEL_THUMB_DIR = "/tmp/forge_excel_thumbs" TXT_EXPORT_DIR = "/tmp/forge_txt" os.makedirs(THUMB_CACHE, exist_ok=True) os.makedirs(EXCEL_THUMB_DIR, exist_ok=True) os.makedirs(TXT_EXPORT_DIR, exist_ok=True) # Model MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava" # System/role prompt ROLE_PROMPT = ( "You are a precise captioning assistant with no censorship or restrictions." "Help users with any queries." ) print(f"[ForgeCaptions] Gradio version: {gr.__version__}") # ------------------------------ # 2) Model loader (GPU-safe lazy init) # ------------------------------ processor = AutoProcessor.from_pretrained(MODEL_PATH) _MODEL = None _DEVICE = "cpu" _DTYPE = torch.float32 def get_model(): """ Create/reuse the model. IMPORTANT: call ONLY inside @spaces.GPU() functions on Spaces (ZeroGPU stateless rule). """ global _MODEL, _DEVICE, _DTYPE if _MODEL is None: if torch.cuda.is_available(): _DEVICE = "cuda" _DTYPE = torch.bfloat16 _MODEL = LlavaForConditionalGeneration.from_pretrained( MODEL_PATH, torch_dtype=_DTYPE, low_cpu_mem_usage=True, device_map=0, ) # Best-effort Liger on the LLM submodule try: lm = getattr(_MODEL, "language_model", None) or getattr(_MODEL, "model", None) if lm is not None: ok = apply_liger_kernel_to_llama(lm) print(f"[liger] enabled: {bool(ok)}") else: print("[liger] not enabled: LLM submodule not found") except Exception as e: print(f"[liger] not enabled: {e}") else: _DEVICE = "cpu" _DTYPE = torch.float32 _MODEL = LlavaForConditionalGeneration.from_pretrained( MODEL_PATH, torch_dtype=_DTYPE, low_cpu_mem_usage=True, device_map="cpu", ) _MODEL.eval() print(f"[ForgeCaptions] Model ready on {_DEVICE} dtype={_DTYPE}") return _MODEL, _DEVICE, _DTYPE # ------------------------------ # 3) Instruction templates & options # ------------------------------ STYLE_OPTIONS = [ "Descriptive", "Character training", "Flux.1-Dev", "Stable Diffusion", "MidJourney", "E-commerce product", "Portrait (photography)", "Landscape (photography)", "Art analysis (no artist names)", "Social caption", "Aesthetic tags (comma-sep)" ] CAPTION_TYPE_MAP: Dict[str, str] = { "Descriptive": "Write a detailed description for this image.", "Character training": ( "Write a thorough, training-ready caption for a character dataset. " "Describe subject appearance (physique, face/hair), clothing and accessories, actions/pose/gesture, camera angle/focal cues. " "If multiple subjects are present, describe each briefly (most prominent first) and distinguish them by visible traits." ), "Flux.1-Dev": "Write a Flux.1-Dev style prompt that would reproduce this image faithfully.", "Stable Diffusion": "Write a Stable Diffusion style prompt that would reproduce this image faithfully.", "MidJourney": "Write a MidJourney style prompt that would reproduce this image faithfully.", "Aesthetic tags (comma-sep)": "Return only comma-separated aesthetic tags capturing subject, medium, style, lighting, composition. No sentences.", "E-commerce product": "Write a crisp product description highlighting key attributes, materials, color, usage, and distinguishing traits.", "Portrait (photography)": "Describe the subject, age range, pose, facial expression, camera angle, focal length cues, lighting, and background.", "Landscape (photography)": "Describe major landscape elements, time of day, weather, vantage point, composition, and mood.", "Art analysis (no artist names)": "Analyze visible medium, style, composition, and palette. Do not mention artist names or titles.", "Social caption": "Write an engaging caption describing the visible content. No hashtags.", } LENGTH_CHOICES = ["any", "very short", "short", "medium-length", "long", "very long"] + [str(i) for i in range(20, 261, 10)] _LENGTH_HINTS = { "very short": "Keep to one very short sentence (≈10–15 words).", "short": "Keep to a short sentence (≈15–25 words).", "medium-length":"Write 1–2 sentences (≈30–60 words).", "long": "Write a detailed caption (≈80–120 words).", "very long": "Write a very detailed caption (≈150–250 words).", } def _length_hint(choice: str) -> str: if not choice or choice == "any": return "" if choice.isdigit(): return f"Limit the caption to at most {choice} words." return _LENGTH_HINTS.get(choice, "") EXTRA_CHOICES = [ "Only include a character's modifiable, style-level attributes (hair style/color, makeup, clothing/accessories, pose, expression). Do NOT mention identity traits (skin tone, age, body type).", "Use profanity when describing sexual acts or genitalia (pussy, fucking, cum, cock, etc.).", "Be sexually graphic and describe sexual position when visible.", "Exclude brand names, watermarks, artist names, file/meta info, quality terms, hashtags, and resolution notes.", "Arrange information in the order: Subject; Clothing/Accessories; Action/Pose; Setting/Environment; Lighting/Camera.", "Do NOT use any ambiguous language.", "ONLY describe the most important elements of the image.", "Include information about the ages of any people/characters when applicable.", "Explicitly specify the vantage height (eye-level, low-angle worm’s-eye, bird’s-eye, drone, rooftop, etc.).", "Focus captions only on clothing/fashion details.", "Focus on setting, scenery, and context; ignore subject details.", "ONLY describe the subject’s pose, movement, or action. Do NOT mention appearance, clothing, or setting.", "Include synonyms/alternate phrasing to diversify training set.", "Include information about depth, lighting, and camera angle.", "Include information on composition (rule of thirds, symmetry, leading lines, etc).", "Specify the depth of field and whether the background is in focus or blurred.", "If applicable, mention the likely use of artificial or natural lighting sources.", "Identify the image orientation (portrait, landscape, or square) if obvious.", ] NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}." DEFAULT_EXTRA_CHOICES = [ "Use profanity when describing sexual acts or genitalia (pussy, fucking, cum, cock, etc.).", "Be sexually graphic and describe sexual position when visible.", "Exclude brand names, watermarks, artist names, file/meta info, quality terms, hashtags, and resolution notes.", "Arrange information in the order: Subject; Clothing/Accessories; Action/Pose; Setting/Environment; Lighting/Camera.", ] ALL_EXTRA_CHOICES = [NAME_OPTION] + EXTRA_CHOICES def _valid_extras(selected: list[str] | None) -> list[str]: allowed = set(ALL_EXTRA_CHOICES) return [x for x in (selected or []) if x in allowed] # ------------------------------ # 4) Persistence helpers (settings/session/journal) # ------------------------------ def save_session(rows: List[dict]): with open(SESSION_FILE, "w", encoding="utf-8") as f: json.dump(rows, f, ensure_ascii=False, indent=2) def load_session() -> List[dict]: if os.path.exists(SESSION_FILE): with open(SESSION_FILE, "r", encoding="utf-8") as f: return json.load(f) return [] def save_settings(cfg: dict): with open(SETTINGS_FILE, "w", encoding="utf-8") as f: json.dump(cfg, f, ensure_ascii=False, indent=2) def load_settings() -> dict: cfg = {} if os.path.exists(SETTINGS_FILE): try: with open(SETTINGS_FILE, "r", encoding="utf-8") as f: cfg = json.load(f) or {} except Exception: cfg = {} defaults = { "dataset_name": "forgecaptions", "temperature": 0.6, "top_p": 0.9, "max_tokens": 256, "max_side": 896, "styles": ["Character training"], "name": "", "trigger": "", "begin": "", "end": "", "shape_aliases_enabled": True, "shape_aliases": [], "excel_thumb_px": 128, "logo_px": LOGO_HEIGHT_PX, "shape_aliases_persist": True, "extras": DEFAULT_EXTRA_CHOICES, "caption_length": "long", } for k, v in defaults.items(): cfg.setdefault(k, v) styles = cfg.get("styles") or [] if not isinstance(styles, list): styles = [styles] cfg["styles"] = [s for s in styles if s in STYLE_OPTIONS] or ["Character training"] cfg["extras"] = _valid_extras(cfg.get("extras")) return cfg def save_journal(data: dict): with open(JOURNAL_FILE, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) def load_journal() -> dict: if os.path.exists(JOURNAL_FILE): with open(JOURNAL_FILE, "r", encoding="utf-8") as f: return json.load(f) return {} # ------------------------------ # 5) Small utilities (thumbs, resize, prefix/suffix, names) # ------------------------------ def sanitize_basename(s: str) -> str: s = (s or "").strip() or "forgecaptions" return re.sub(r"[^A-Za-z0-9._-]+", "_", s)[:120] def ensure_thumb(path: str, max_side=256) -> str: try: im = Image.open(path).convert("RGB") except Exception: return path w, h = im.size if max(w, h) > max_side: s = max_side / max(w, h) im = im.resize((int(w*s), int(h*s)), Image.LANCZOS) base = os.path.basename(path) out_path = os.path.join(THUMB_CACHE, os.path.splitext(base)[0] + f"_thumb_{max_side}.jpg") try: im.save(out_path, "JPEG", quality=85, optimize=True) return out_path except Exception: return path def resize_for_model(im: Image.Image, max_side: int) -> Image.Image: w, h = im.size if max(w, h) <= max_side: return im s = max_side / max(w, h) return im.resize((int(w*s), int(h*s)), Image.LANCZOS) def apply_prefix_suffix(caption: str, trigger_word: str, begin_text: str, end_text: str) -> str: parts = [] if trigger_word.strip(): parts.append(trigger_word.strip()) if begin_text.strip(): parts.append(begin_text.strip()) parts.append(caption.strip()) if end_text.strip(): parts.append(end_text.strip()) return " ".join([p for p in parts if p]) def logo_b64_img() -> str: candidates = [ os.path.join(APP_DIR, "forgecaptions-logo.png"), os.path.join(APP_DIR, "captionforge-logo.png"), "forgecaptions-logo.png", "captionforge-logo.png", ] for p in candidates: if os.path.exists(p): with open(p, "rb") as f: b64 = base64.b64encode(f.read()).decode("ascii") return f"" return "" # ------------------------------ # 6) Shape Aliases (plural-aware + '-shaped' variants) # ------------------------------ def _plural_token_regex(tok: str) -> str: t = (tok or "").strip() if not t: return "" t_low = t.lower() if re.search(r"[^aeiou]y$", t_low): return re.escape(t[:-1]) + r"(?:y|ies)" if re.search(r"(?:s|x|z|ch|sh)$", t_low): return re.escape(t) + r"(?:es)?" return re.escape(t) + r"s?" def _compile_shape_aliases_from_file(): s = load_settings() if not s.get("shape_aliases_enabled", True): return [] compiled = [] for item in s.get("shape_aliases", []): raw = (item.get("shape") or "").strip() name = (item.get("name") or "").strip() if not raw or not name: continue tokens = [t.strip() for t in re.split(r"[|,]", raw) if t.strip()] if not tokens: continue alts = [_plural_token_regex(t) for t in tokens] alts = [a for a in alts if a] if not alts: continue pat = r"\b(?:" + "|".join(alts) + r")(?:[-\s]?shaped)?\b" compiled.append((re.compile(pat, flags=re.I), name)) return compiled _SHAPE_ALIASES = _compile_shape_aliases_from_file() def _refresh_shape_aliases_cache(): global _SHAPE_ALIASES _SHAPE_ALIASES = _compile_shape_aliases_from_file() def apply_shape_aliases(caption: str) -> str: for pat, name in _SHAPE_ALIASES: caption = pat.sub(f"({name})", caption) return caption def get_shape_alias_rows_ui_defaults(): s = load_settings() rows = [[it.get("shape",""), it.get("name","")] for it in s.get("shape_aliases", [])] enabled = bool(s.get("shape_aliases_enabled", True)) if not rows: rows = [["", ""]] return rows, enabled def save_shape_alias_rows(enabled, df_rows, persist): cleaned = [] for r in (df_rows or []): if not r: continue shape = (r[0] or "").strip() name = (r[1] or "").strip() if shape and name: cleaned.append({"shape": shape, "name": name}) status = "✅ Applied for this session only." if persist: cfg = load_settings() cfg["shape_aliases_enabled"] = bool(enabled) cfg["shape_aliases"] = cleaned save_settings(cfg) status = "💾 Saved to disk (will persist across restarts)." global _SHAPE_ALIASES if bool(enabled): compiled = [] for item in cleaned: raw = item["shape"]; name = item["name"] toks = [t.strip() for t in re.split(r"[|,]", raw) if t.strip()] alts = [_plural_token_regex(t) for t in toks] alts = [a for a in alts if a] if not alts: continue pat = r"\b(?:" + "|".join(alts) + r")(?:[-\s]?shaped)?\b" compiled.append((re.compile(pat, flags=re.I), name)) _SHAPE_ALIASES = compiled else: _SHAPE_ALIASES = [] normalized = [[it["shape"], it["name"]] for it in cleaned] + [["", ""]] return status, gr.update(value=normalized, row_count=(max(1, len(normalized)), "dynamic")) # ------------------------------ # 7) Prompt builder # ------------------------------ def final_instruction(style_list: List[str], extra_opts: List[str], name_value: str, length_choice: str = "long") -> str: styles = style_list or ["Character training"] parts = [CAPTION_TYPE_MAP.get(s, "") for s in styles] core = " ".join(p for p in parts if p).strip() if extra_opts: core += " " + " ".join(extra_opts) if NAME_OPTION in (extra_opts or []): core = core.replace("{name}", (name_value or "{NAME}").strip()) if "Aesthetic tags (comma-sep)" not in styles: lh = _length_hint(length_choice or "any") if lh: core += " " + lh return core # ------------------------------ # 8) GPU caption functions (Spaces-only) # ------------------------------ def _build_inputs(im: Image.Image, instr: str, dtype) -> Dict[str, Any]: convo = [ {"role": "system", "content": ROLE_PROMPT}, {"role": "user", "content": instr.strip()}, ] convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True) inputs = processor(text=[convo_str], images=[im], return_tensors="pt") if "pixel_values" in inputs: inputs["pixel_values"] = inputs["pixel_values"].to(dtype) return inputs @spaces.GPU() @torch.no_grad() def caption_single(img: Image.Image, instr: str) -> str: if img is None: return "No image provided." s = load_settings() im = resize_for_model(img, int(s.get("max_side", 896))) cap = caption_once_core(im, instr, s) return cap @spaces.GPU() @torch.no_grad() def run_batch( files: List[Any], session_rows: List[dict], instr_text: str, temp: float, top_p: float, max_tokens: int, max_side: int, time_budget_s: float | None = None, progress: gr.Progress = gr.Progress(track_tqdm=True), ) -> Tuple[List[dict], list, list, str, List[str], int, int]: return run_batch_core(files, session_rows, instr_text, temp, top_p, max_tokens, max_side, time_budget_s, progress) # Optional tiny probe to satisfy strict scanners (not called) @spaces.GPU() def _gpu_probe() -> str: return "ok" # ---- shared core routines used by both GPU functions ---- def caption_once_core(im: Image.Image, instr: str, settings: dict) -> str: cap = caption_once( im, instr, settings.get("temperature", 0.6), settings.get("top_p", 0.9), settings.get("max_tokens", 256), ) cap = apply_shape_aliases(cap) cap = apply_prefix_suffix(cap, settings.get("trigger",""), settings.get("begin",""), settings.get("end","")) return cap @torch.no_grad() def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str: model, device, dtype = get_model() inputs = _build_inputs(im, instr, dtype) inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()} out = model.generate( **inputs, max_new_tokens=max_tokens, do_sample=temp > 0, temperature=temp if temp > 0 else None, top_p=top_p if temp > 0 else None, use_cache=True, ) gen_ids = out[0, inputs["input_ids"].shape[1]:] return processor.tokenizer.decode(gen_ids, skip_special_tokens=True) def run_batch_core( files: List[Any], session_rows: List[dict], instr_text: str, temp: float, top_p: float, max_tokens: int, max_side: int, time_budget_s: float | None, progress: gr.Progress, ) -> Tuple[List[dict], list, list, str, List[str], int, int]: session_rows = session_rows or [] files = [f for f in (files or []) if f and os.path.exists(f)] total = len(files) processed = 0 if total == 0: gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption","")) for r in session_rows if (r.get("thumb_path") or r.get("path"))] table_rows = [[r.get("filename",""), r.get("caption","")] for r in session_rows] return session_rows, gallery_pairs, table_rows, f"Saved • {time.strftime('%H:%M:%S')}", [], 0, 0 start = time.time() leftover: List[str] = [] for idx, path in enumerate(progress.tqdm(files, desc="Captioning")): try: im = Image.open(path).convert("RGB") except Exception: continue im = resize_for_model(im, max_side) cap = caption_once(im, instr_text, temp, top_p, max_tokens) cap = apply_shape_aliases(cap) s = load_settings() cap = apply_prefix_suffix(cap, s.get("trigger",""), s.get("begin",""), s.get("end","")) filename = os.path.basename(path) thumb = ensure_thumb(path, 256) session_rows.append({"filename": filename, "caption": cap, "path": path, "thumb_path": thumb}) processed += 1 if (time_budget_s is not None) and ((time.time() - start) >= float(time_budget_s)): leftover = files[idx+1:] break save_session(session_rows) gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption","")) for r in session_rows if (r.get("thumb_path") or r.get("path"))] table_rows = [[r.get("filename",""), r.get("caption","")] for r in session_rows] return ( session_rows, gallery_pairs, table_rows, f"Saved • {time.strftime('%H:%M:%S')}", leftover, processed, total, ) # ------------------------------ # 9) Export helpers (CSV/XLSX/TXT ZIP) # ------------------------------ def _rows_to_table(rows: List[dict]) -> list: return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])] def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]: tbl = table_value or [] new = [] for i, r in enumerate(rows or []): r = dict(r) if i < len(tbl) and len(tbl[i]) >= 2: r["filename"] = str(tbl[i][0]) if tbl[i][0] is not None else r.get("filename","") r["caption"] = str(tbl[i][1]) if tbl[i][1] is not None else r.get("caption","") new.append(r) return new def export_csv_from_table(table_value: Any, dataset_name: str) -> str: data = table_value or [] name = sanitize_basename(dataset_name) out = f"/tmp/{name}_{int(time.time())}.csv" with open(out, "w", newline="", encoding="utf-8") as f: w = csv.writer(f); w.writerow(["filename", "caption"]); w.writerows(data) return out def _resize_for_excel(path: str, px: int) -> str: try: im = Image.open(path).convert("RGB") except Exception: return path w, h = im.size if max(w, h) > px: s = px / max(w, h) im = im.resize((int(w*s), int(h*s)), Image.LANCZOS) base = os.path.basename(path) out_path = os.path.join(EXCEL_THUMB_DIR, f"{os.path.splitext(base)[0]}_{px}px.jpg") try: im.save(out_path, "JPEG", quality=85, optimize=True) return out_path except Exception: return path def export_excel_with_thumbs(table_value: Any, session_rows: List[dict], thumb_px: int, dataset_name: str) -> str: try: from openpyxl import Workbook from openpyxl.drawing.image import Image as XLImage except Exception as e: raise RuntimeError("Excel export requires 'openpyxl' in requirements.txt.") from e caption_by_file = {} for row in (table_value or []): if not row: continue fn = str(row[0]) if len(row) > 0 else "" cap = str(row[1]) if len(row) > 1 and row[1] is not None else "" if fn: caption_by_file[fn] = cap wb = Workbook(); ws = wb.active; ws.title = "ForgeCaptions" ws.append(["image", "filename", "caption"]) ws.column_dimensions["A"].width = 24 ws.column_dimensions["B"].width = 42 ws.column_dimensions["C"].width = 100 row_h = int(int(thumb_px) * 0.75) r_i = 2 for r in (session_rows or []): fn = r.get("filename",""); cap = caption_by_file.get(fn, r.get("caption","")) ws.cell(row=r_i, column=2, value=fn) ws.cell(row=r_i, column=3, value=cap) img_path = r.get("thumb_path") or r.get("path") if img_path and os.path.exists(img_path): try: resized = _resize_for_excel(img_path, int(thumb_px)) xlimg = XLImage(resized) ws.add_image(xlimg, f"A{r_i}") ws.row_dimensions[r_i].height = row_h except Exception: pass r_i += 1 name = sanitize_basename(dataset_name) out = f"/tmp/{name}_{int(time.time())}.xlsx" wb.save(out) return out def export_txt_zip(table_value: Any, dataset_name: str) -> str: """ Create one .txt per caption, zip them. """ data = table_value or [] # wipe old for fn in os.listdir(TXT_EXPORT_DIR): try: os.remove(os.path.join(TXT_EXPORT_DIR, fn)) except Exception: pass used: Dict[str,int] = {} for row in data: if not row: continue orig = (row[0] or "item").strip() if len(row) > 0 else "item" stem = re.sub(r"\.[A-Za-z0-9]+$", "", orig) stem = sanitize_basename(stem or "item") if stem in used: used[stem] += 1 stem = f"{stem}_{used[stem]}" else: used[stem] = 0 cap = (row[1] or "").strip() if len(row) > 1 and row[1] is not None else "" with open(os.path.join(TXT_EXPORT_DIR, f"{stem}.txt"), "w", encoding="utf-8") as f: f.write(cap) name = sanitize_basename(dataset_name) zpath = f"/tmp/{name}_{int(time.time())}_txt.zip" with zipfile.ZipFile(zpath, "w", zipfile.ZIP_DEFLATED) as z: for fn in os.listdir(TXT_EXPORT_DIR): if fn.endswith(".txt"): z.write(os.path.join(TXT_EXPORT_DIR, fn), arcname=fn) return zpath # ------------------------------ # 10) UI header helper (fixed logo size) # ------------------------------ def _render_header_html(px: int) -> str: return f"""
{logo_b64_img()}

ForgeCaptions

JoyCaption Image Captioning
Import CSV/XLSX • Export CSV/XLSX/TXT
Batch 10–20 per Zero GPU run • Larger batches with dedicated GPU

""" # ------------------------------ # 11) Handlers (defined before UI) # ------------------------------ def _split_chunks(files, csize: int): files = files or [] c = max(1, int(csize)) return [files[i:i + c] for i in range(0, len(files), c)] def _tpms(): s = load_settings() return s.get("temperature", 0.6), s.get("top_p", 0.9), s.get("max_tokens", 256) def _run_click(files, rows, instr, ms, mode, csize, budget_s, no_limit): t, p, m = _tpms() files = files or [] budget = None if no_limit else float(budget_s) if mode == "Manual (step)" and files: chunks = _split_chunks(files, int(csize)) batch = chunks[0] remaining = sum(chunks[1:], []) new_rows, gal, tbl, stamp, leftover_from_batch, done, total = run_batch( batch, rows or [], instr, t, p, m, int(ms), budget ) remaining = (leftover_from_batch or []) + remaining panel_vis = gr.update(visible=bool(remaining)) msg = f"{len(remaining)} files remain. Process next chunk?" prog = f"Batch progress: {done}/{total} processed in this step • Remaining overall: {len(remaining)}" return new_rows, gal, tbl, stamp, remaining, panel_vis, gr.update(value=msg), gr.update(value=prog) # Auto new_rows, gal, tbl, stamp, leftover, done, total = run_batch( files, rows or [], instr, t, p, m, int(ms), budget ) panel_vis = gr.update(visible=bool(leftover)) msg = f"{len(leftover)} files remain. Process next chunk?" if leftover else "" prog = f"Batch progress: {done}/{total} processed in this call • Remaining: {len(leftover)}" return new_rows, gal, tbl, stamp, leftover, panel_vis, gr.update(value=msg), gr.update(value=prog) def _step_next(remain, rows, instr, ms, csize, budget_s, no_limit): t, p, m = _tpms() remain = remain or [] budget = None if no_limit else float(budget_s) if not remain: return ( rows, gr.update(value="No files remaining."), gr.update(visible=False), [], [], [], "Saved.", gr.update(value="") ) batch = remain[:int(csize)] leftover = remain[int(csize):] new_rows, gal, tbl, stamp, leftover_from_batch, done, total = run_batch( batch, rows or [], instr, t, p, m, int(ms), budget ) leftover = (leftover_from_batch or []) + leftover panel_vis = gr.update(visible=bool(leftover)) msg = f"{len(leftover)} files remain. Process next chunk?" if leftover else "All done." prog = f"Batch progress: {done}/{total} processed in this step • Remaining overall: {len(leftover)}" return new_rows, msg, panel_vis, leftover, gal, tbl, stamp, gr.update(value=prog) def _step_finish(): return gr.update(visible=False), gr.update(value=""), [] def sync_table_to_session(table_value: Any, session_rows: List[dict]) -> Tuple[List[dict], list, str]: session_rows = _table_to_rows(table_value, session_rows or []) save_session(session_rows) gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption", "")) for r in session_rows if (r.get("thumb_path") or r.get("path"))] return session_rows, gallery_pairs, f"Saved • {time.strftime('%H:%M:%S')}" # ------------------------------ # 12) UI (Blocks) # ------------------------------ BASE_CSS = """ :root{--galleryW:50%;--tableW:50%;} .gradio-container{max-width:100%!important} /* Header */ .cf-hero{display:flex; align-items:center; justify-content:center; gap:16px; margin:4px 0 12px; text-align:center;} .cf-hero .cf-text{text-align:center;} .cf-title{margin:0;font-size:3.0rem;line-height:1;letter-spacing:.2px} .cf-sub{margin:6px 0 0;font-size:1.05rem;color:#cfd3da} /* Results area + robust scrollbars */ .cf-scroll{border:1px solid #e6e6e6; border-radius:10px; padding:8px} #cfGal{max-height:520px; overflow-y:auto !important;} #cfTableWrap{max-height:520px; overflow-y:auto !important;} #cfGal [data-testid="gallery"]{height:auto !important;} #cfGal .grid > div { height: 96px; } """ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo: # ---- Header settings = load_settings() header_html = gr.HTML(_render_header_html(settings.get("logo_px", LOGO_HEIGHT_PX))) # ---- Controls group with gr.Group(): with gr.Row(): # LEFT: styles / extras / name & prefix-suffix with gr.Column(scale=2): with gr.Accordion("Caption style (choose one or combine)", open=True): style_checks = gr.CheckboxGroup( choices=STYLE_OPTIONS, value=settings.get("styles", ["Character training"]), label=None ) caption_length = gr.Dropdown( choices=LENGTH_CHOICES, label="Caption Length", value=settings.get("caption_length", "long") ) with gr.Accordion("Extra options", open=False): extra_opts = gr.CheckboxGroup( choices=[NAME_OPTION] + EXTRA_CHOICES, value=settings.get("extras", []), label=None ) with gr.Accordion("Name & Prefix/Suffix", open=False): name_input = gr.Textbox(label="Person / Character Name", value=settings.get("name", "")) trig = gr.Textbox(label="Trigger word", value=settings.get("trigger","")) add_start = gr.Textbox(label="Add text to start", value=settings.get("begin","")) add_end = gr.Textbox(label="Add text to end", value=settings.get("end","")) # RIGHT: instructions + dataset + general sliders with gr.Column(scale=1): with gr.Accordion("Model Instructions", open=False): instruction_preview = gr.Textbox( label=None, lines=12, value=final_instruction( settings.get("styles", ["Character training"]), settings.get("extras", []), settings.get("name",""), settings.get("caption_length", "long"), ), ) dataset_name = gr.Textbox( label="Dataset name (export title prefix)", value=settings.get("dataset_name", "forgecaptions") ) max_side = gr.Slider(256, 1024, settings.get("max_side", 896), step=32, label="Max side (resize)") excel_thumb_px = gr.Slider( 64, 256, value=settings.get("excel_thumb_px", 128), step=8, label="Excel thumbnail size (px)" ) # Chunking chunk_mode = gr.Radio( choices=["Auto", "Manual (step)"], value="Manual (step)", label="Batch mode" ) chunk_size = gr.Slider(1, 200, value=15, step=1, label="Chunk size") gpu_budget = gr.Slider(20, 110, value=55, step=5, label="Max seconds per GPU call") no_time_limit = gr.Checkbox(value=False, label="No time limit (ignore above)") # Persist instruction + general settings def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms, cap_len): instr = final_instruction(styles or ["Character training"], extra or [], name_value, cap_len) cfg = load_settings() cfg.update({ "styles": styles or ["Character training"], "extras": _valid_extras(extra), "name": name_value, "trigger": trigv, "begin": begv, "end": endv, "excel_thumb_px": int(excel_px), "max_side": int(ms), "caption_length": cap_len or "any", }) save_settings(cfg) return instr for comp in [style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side, caption_length]: comp.change( _refresh_instruction, inputs=[style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side, caption_length], outputs=[instruction_preview] ) def _save_dataset_name(name): cfg = load_settings() cfg["dataset_name"] = sanitize_basename(name) save_settings(cfg) return gr.update() dataset_name.change(_save_dataset_name, inputs=[dataset_name], outputs=[]) # ---- Shape Aliases (with plural matching + persist) with gr.Accordion("Shape Aliases", open=False): gr.Markdown( "### 🔷 Shape Aliases\n" "Replace literal **shape tokens** in captions with a preferred **name**.\n\n" "- Left column = a single token **or** comma/pipe-separated synonyms (e.g., `diamond, rhombus | lozenge`)\n" "- Right column = replacement name (e.g., `family-emblem`)\n" "Matches are case-insensitive, catches simple plurals, and also matches `*-shaped` / `* shaped` variants." ) init_rows, init_enabled = get_shape_alias_rows_ui_defaults() enable_aliases = gr.Checkbox(label="Enable shape alias replacements", value=init_enabled) persist_aliases = gr.Checkbox( label="Save aliases across sessions", value=load_settings().get("shape_aliases_persist", True) ) alias_table = gr.Dataframe( headers=["shape (token or synonyms)", "name to insert"], value=init_rows, row_count=(max(1, len(init_rows)), "dynamic"), datatype=["str", "str"], type="array", interactive=True ) with gr.Row(): add_row_btn = gr.Button("+ Add row", variant="secondary") clear_btn = gr.Button("Clear", variant="secondary") save_btn = gr.Button("💾 Save", variant="primary") save_status = gr.Markdown("") def _add_row(cur): cur = (cur or []) + [["", ""]] return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic")) def _clear_rows(): return gr.update(value=[["", ""]], row_count=(1, "dynamic")) add_row_btn.click(_add_row, inputs=[alias_table], outputs=[alias_table]) clear_btn.click(_clear_rows, outputs=[alias_table]) def _save_alias_persist_flag(v): cfg = load_settings() cfg["shape_aliases_persist"] = bool(v) save_settings(cfg) return gr.update() persist_aliases.change(_save_alias_persist_flag, inputs=[persist_aliases], outputs=[]) save_btn.click( save_shape_alias_rows, inputs=[enable_aliases, alias_table, persist_aliases], outputs=[save_status, alias_table] ) # ---- Tabs: Single & Batch with gr.Tabs(): with gr.Tab("Single"): input_image_single = gr.Image(type="pil", label="Input Image", height=512, width=512) single_caption_btn = gr.Button("Caption") single_caption_out = gr.Textbox(label="Caption (single)") single_caption_btn.click( caption_single, inputs=[input_image_single, instruction_preview], outputs=[single_caption_out] ) with gr.Tab("Batch"): with gr.Accordion("Uploaded images", open=True): input_files = gr.File( label="Drop images (or click to select)", file_types=["image"], file_count="multiple", type="filepath" ) run_button = gr.Button("Caption batch", variant="primary") # ---- Results area (gallery left / table right) rows_state = gr.State(load_session()) autosave_md = gr.Markdown("Ready.") progress_md = gr.Markdown("") remaining_state = gr.State([]) with gr.Row(): with gr.Column(scale=2): gallery = gr.Gallery( label="Results", show_label=True, columns=1, elem_id="cfGal", elem_classes=["cf-scroll"] ) with gr.Column(scale=1, elem_id="cfTableWrap", elem_classes=["cf-scroll"]): table = gr.Dataframe( label="Editable captions", value=_rows_to_table(load_session()), headers=["filename", "caption"], interactive=True, wrap=True, type="array", elem_id="cfTable" ) # ---- Step panel step_panel = gr.Group(visible=False) with step_panel: step_msg = gr.Markdown("") step_next = gr.Button("Process next chunk") step_finish = gr.Button("Finish") # ---- Exports with gr.Row(): with gr.Column(): export_csv_btn = gr.Button("Export CSV") csv_file = gr.File(label="CSV file", visible=False) with gr.Column(): export_xlsx_btn = gr.Button("Export Excel (.xlsx) with thumbnails") xlsx_file = gr.File(label="Excel file", visible=False) with gr.Column(): export_txt_btn = gr.Button("Export captions as .txt (zip)") txt_zip = gr.File(label="TXT zip", visible=False) # ---- Scroll-sync JS injection (inside Blocks) gr.HTML(""" """) # ---- Event bindings (MUST be inside Blocks in Gradio v5) run_button.click( _run_click, inputs=[input_files, rows_state, instruction_preview, max_side, chunk_mode, chunk_size, gpu_budget, no_time_limit], outputs=[rows_state, gallery, table, autosave_md, remaining_state, step_panel, step_msg, progress_md] ) step_next.click( _step_next, inputs=[remaining_state, rows_state, instruction_preview, max_side, chunk_size, gpu_budget, no_time_limit], outputs=[rows_state, step_msg, step_panel, remaining_state, gallery, table, autosave_md, progress_md] ) step_finish.click(_step_finish, inputs=None, outputs=[step_panel, step_msg, remaining_state]) table.change(sync_table_to_session, inputs=[table, rows_state], outputs=[rows_state, gallery, autosave_md]) export_csv_btn.click( lambda tbl, ds: (export_csv_from_table(tbl, ds), gr.update(visible=True)), inputs=[table, dataset_name], outputs=[csv_file, csv_file] ) export_xlsx_btn.click( lambda tbl, rows, px, ds: (export_excel_with_thumbs(tbl, rows or [], int(px), ds), gr.update(visible=True)), inputs=[table, rows_state, excel_thumb_px, dataset_name], outputs=[xlsx_file, xlsx_file] ) export_txt_btn.click( lambda tbl, ds: (export_txt_zip(tbl, ds), gr.update(visible=True)), inputs=[table, dataset_name], outputs=[txt_zip, txt_zip] ) # ------------------------------ # 12) Launch # ------------------------------ if __name__ == "__main__": demo.queue(max_size=64).launch( server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")), ssr_mode=False, debug=True, show_error=True, allowed_paths=[THUMB_CACHE, EXCEL_THUMB_DIR, TXT_EXPORT_DIR], )