varunv2004 committed on
Commit
c70c930
·
verified ·
1 Parent(s): eeef6a1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -0
app.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import gc
import torch
import numpy as np
from PIL import Image
import imageio
import gradio as gr
from huggingface_hub import hf_hub_download

# ComfyUI imports (assumes ComfyUI folder is dedicated in repo)
from comfy import model_management  # may be needed for plugin system
from nodes import (
    CheckpointLoaderSimple,
    CLIPLoader,
    CLIPTextEncode,
    VAELoader,
    VAEDecode,
    KSampler,
)
from custom_nodes.ComfyUI_GGUF.nodes import UnetLoaderGGUF
from comfy_extras.nodes_hunyuan import EmptyHunyuanLatentVideo
from comfy_extras.nodes_images import SaveAnimatedWEBP
from comfy_extras.nodes_video import SaveWEBM

# Globals: ComfyUI node instances shared across the pipeline.
# They stay None until imports_initialization() is run (the "Initialize
# Models" button in the UI), after which generate_video() uses them.
unet_loader = None            # UnetLoaderGGUF — loads the GGUF-quantized UNet
clip_loader = None            # CLIPLoader — loads the text encoder
clip_encode_positive = None   # CLIPTextEncode for the positive prompt
clip_encode_negative = None   # CLIPTextEncode for the negative prompt
vae_loader = None             # VAELoader — loads the VAE weights
empty_latent_video = None     # EmptyHunyuanLatentVideo — builds the empty latent
ksampler = None               # KSampler — diffusion sampling loop
vae_decode = None             # VAEDecode — latents -> image frames
# ✅ Ensure models are available via HF hub or local
def ensure_model(repo_id, filename, folder):
    """Return the local path of a model file under ComfyUI/models/<folder>/,
    downloading it from the Hugging Face hub first if it is not present."""
    target_dir = os.path.join("ComfyUI", "models", folder)
    os.makedirs(target_dir, exist_ok=True)
    target = os.path.join(target_dir, filename)
    if not os.path.isfile(target):
        # Download straight into the models folder so ComfyUI can find it.
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=target_dir)
    return target
# 1️⃣ Initialize imports and model loader utilities
def imports_initialization():
    """Instantiate every ComfyUI node object the pipeline needs and publish
    them as module globals for generate_video() to reuse."""
    # Map each global slot name to its node class, then instantiate them all.
    # Node objects are cheap to construct; model weights load later.
    node_classes = {
        "unet_loader": UnetLoaderGGUF,
        "clip_loader": CLIPLoader,
        "clip_encode_positive": CLIPTextEncode,
        "clip_encode_negative": CLIPTextEncode,
        "vae_loader": VAELoader,
        "empty_latent_video": EmptyHunyuanLatentVideo,
        "ksampler": KSampler,
        "vae_decode": VAEDecode,
    }
    for slot, node_cls in node_classes.items():
        globals()[slot] = node_cls()

    return "✅ Imports done and models initialized."
# Clean GPU memory
def clear_memory():
    """Free Python garbage and release cached CUDA memory between stages.

    Called after each heavy pipeline step so that only one large model is
    resident at a time.

    NOTE(review): the original version also looped over globals() and
    `del obj`-ed tensors, but `del` on the loop variable only unbinds the
    local name and never releases the global reference — the loop (and its
    bare `except:`) was a no-op, so it has been removed.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # return cached allocator blocks to the driver
        torch.cuda.ipc_collect()   # reclaim memory held by dead IPC handles
    gc.collect()
# Save utility functions
def save_as_mp4(images, prefix, fps):
    """Write a sequence of float tensor frames to output/<prefix>.mp4.

    Args:
        images: iterable of HxWxC float tensors, values nominally in [0, 1]
                (assumed from the `* 255` scaling — TODO confirm upstream range).
        prefix: output file name without extension.
        fps: playback frame rate.

    Returns:
        Path of the written mp4 file.
    """
    os.makedirs("output", exist_ok=True)
    path = f"output/{prefix}.mp4"
    writer = imageio.get_writer(path, fps=fps)
    try:
        for img in images:
            # Clamp before the uint8 cast: values outside [0, 1] would
            # otherwise wrap around and produce speckle artifacts.
            frame = np.clip(img.cpu().numpy() * 255, 0, 255).astype(np.uint8)
            writer.append_data(frame)
    finally:
        # Always release the writer, even if a frame conversion fails.
        writer.close()
    return path
def save_as_webm(images, prefix, fps):
    """Write a sequence of float tensor frames to output/<prefix>.webm (VP9).

    Args:
        images: iterable of HxWxC float tensors, values nominally in [0, 1]
                (assumed from the `* 255` scaling — TODO confirm upstream range).
        prefix: output file name without extension.
        fps: playback frame rate.

    Returns:
        Path of the written webm file.
    """
    os.makedirs("output", exist_ok=True)
    path = f"output/{prefix}.webm"
    writer = imageio.get_writer(
        path, format='FFMPEG', fps=fps,
        codec='vp9', quality=20
    )
    try:
        for img in images:
            # Clamp before the uint8 cast: values outside [0, 1] would
            # otherwise wrap around and produce speckle artifacts.
            frame = np.clip(img.cpu().numpy() * 255, 0, 255).astype(np.uint8)
            writer.append_data(frame)
    finally:
        # Always release the writer, even if a frame conversion fails.
        writer.close()
    return path
def save_as_image(img, prefix):
    """Save a single float tensor frame as output/<prefix>.png.

    Args:
        img: HxWxC float tensor, values nominally in [0, 1] (assumed from
             the `* 255` scaling — TODO confirm upstream range).
        prefix: output file name without extension.

    Returns:
        Path of the written png file.
    """
    os.makedirs("output", exist_ok=True)
    path = f"output/{prefix}.png"
    # Clamp before the uint8 cast so out-of-range values don't wrap around.
    frame = np.clip(img.cpu().numpy() * 255, 0, 255).astype(np.uint8)
    Image.fromarray(frame).save(path)
    return path
# 2️⃣ Text-to-Video generation pipeline
def generate_video(
    positive_prompt, negative_prompt,
    width, height, seed, steps, cfg_scale,
    sampler_name, scheduler, frames, fps, output_format, use_q6
):
    """Run the full text-to-video pipeline and return (log_text, output_path).

    Stages: download/locate model files, encode prompts, build an empty
    latent video, sample with the GGUF UNet, decode with the VAE, and save
    the result. Each model is loaded, used, then freed before the next one
    so the pipeline fits in limited GPU memory.
    """
    # Robustness: if the user skipped the "Initialize" tab, the node globals
    # are still None — set them up here instead of crashing.
    if unet_loader is None:
        imports_initialization()

    # Gradio Number/Slider widgets can deliver floats; coerce the values the
    # samplers index with to ints.
    width, height = int(width), int(height)
    seed, steps, frames = int(seed), int(steps), int(frames)

    log = []

    # 2a. Download or load model files
    unet_file = ensure_model(
        "city96/Wan2.1-T2V-14B-gguf",
        "wan2.1-t2v-14b-Q6_K.gguf" if use_q6 else "wan2.1-t2v-14b-Q5_0.gguf",
        "unet"
    )
    text_enc_file = ensure_model(
        "Comfy-Org/Wan_2.1_ComfyUI_repackaged",
        "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
        "text_encoders"
    )
    vae_file = ensure_model(
        "Comfy-Org/Wan_2.1_ComfyUI_repackaged",
        "wan_2.1_vae.safetensors",
        "vae"
    )

    # 2b. Encode text prompts, then drop the text encoder before loading the UNet.
    log.append("🔧 Encoding prompts...")
    clip_model = clip_loader.load_clip(text_enc_file, "wan", "default")[0]
    pos = clip_encode_positive.encode(clip_model, positive_prompt)[0]
    neg = clip_encode_negative.encode(clip_model, negative_prompt)[0]
    del clip_model
    clear_memory()

    # 2c. Setup latent video (batch size 1)
    latent = empty_latent_video.generate(width, height, frames, 1)[0]

    # 2d. Sample using UNet, then free it before VAE decode.
    model = unet_loader.load_unet(unet_file)[0]
    log.append("🎥 Sampling latents...")
    sampled = ksampler.sample(
        model=model,
        seed=seed,
        steps=steps,
        cfg=cfg_scale,
        sampler_name=sampler_name,
        scheduler=scheduler,
        positive=pos,
        negative=neg,
        latent_image=latent
    )[0]
    del model
    clear_memory()

    # 2e. Decode via VAE
    log.append("🔓 Decoding with VAE...")
    vae_model = vae_loader.load_vae(vae_file)[0]
    decoded = vae_decode.decode(vae_model, sampled)[0]
    del vae_model
    clear_memory()

    # 2f. Save output: a single frame becomes a PNG, otherwise a video.
    filename = "hf_gen"
    if frames == 1:
        log.append("💾 Saving single frame...")
        out = save_as_image(decoded[0], filename)
    else:
        if output_format == "webm":
            log.append("💾 Saving as WEBM...")
            out = save_as_webm(decoded, filename, fps)
        else:
            log.append("💾 Saving as MP4...")
            out = save_as_mp4(decoded, filename, fps)

    log.append(f"✅ Saved: {out}")
    clear_memory()
    return "\n".join(log), out
# 3️⃣ Gradio UI

app = gr.Blocks()
with app:
    gr.Markdown("# ComfyUI Text‑to‑Video on Hugging Face Spaces")

    # Tab 1: one-time node setup.
    with gr.Tab("Initialize"):
        initialize_button = gr.Button("Initialize Models")
        initialize_status = gr.Textbox(lines=3, interactive=False, label="Status")
        initialize_button.click(imports_initialization, None, initialize_status)

    # Tab 2: the generation form.
    with gr.Tab("Generate"):
        with gr.Row():
            positive_box = gr.Textbox(label="Positive Prompt", value="lion")
            negative_box = gr.Textbox(label="Negative Prompt", value="")
        with gr.Row():
            width_slider = gr.Slider(64, 1024, step=8, value=400, label="Width")
            height_slider = gr.Slider(64, 1024, step=8, value=400, label="Height")
        with gr.Row():
            seed_number = gr.Number(label="Seed", value=0)
            steps_slider = gr.Slider(1, 100, value=10, label="Steps")
            cfg_slider = gr.Slider(1, 20, step=0.1, value=3, label="CFG Scale")
        with gr.Row():
            sampler_dropdown = gr.Dropdown(["uni_pc", "euler", "dpmpp_2m", "ddim", "lms"], value="uni_pc", label="Sampler")
            scheduler_dropdown = gr.Dropdown(["simple", "normal", "karras", "exponential"], value="normal", label="Scheduler")
        with gr.Row():
            frames_slider = gr.Slider(1, 60, value=2, label="Frames")
            fps_slider = gr.Slider(1, 60, value=10, label="FPS")
            format_radio = gr.Radio(["mp4", "webm"], value="webm", label="Output Format")
        use_q6_checkbox = gr.Checkbox(label="Use Q6 UNet model", value=False)
        generate_button = gr.Button("Generate")
        generate_log = gr.Textbox(lines=10, interactive=False, label="Log")
        generate_output = gr.Video(label="Output Video/Image")

        # Input order must match generate_video()'s parameter order exactly.
        generate_button.click(
            fn=generate_video,
            inputs=[
                positive_box, negative_box, width_slider, height_slider,
                seed_number, steps_slider, cfg_slider, sampler_dropdown,
                scheduler_dropdown, frames_slider, fps_slider, format_radio,
                use_q6_checkbox,
            ],
            outputs=[generate_log, generate_output],
        )

if __name__ == "__main__":
    app.launch()