Arrcttacsrks committed on
Commit
8d280db
·
verified ·
1 Parent(s): 34e5c18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +342 -79
app.py CHANGED
@@ -6,7 +6,11 @@ import gradio as gr
6
  import torch
7
  import numpy as np
8
  from torchvision.utils import save_image
9
-
 
 
 
 
10
 
11
  # Import files from the local folder
12
  root_path = os.path.abspath('.')
@@ -14,6 +18,18 @@ sys.path.append(root_path)
14
  from test_code.inference import super_resolve_img
15
  from test_code.test_utils import load_grl, load_rrdb, load_dat
16
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def auto_download_if_needed(weight_path):
19
  if os.path.exists(weight_path):
@@ -39,59 +55,252 @@ def auto_download_if_needed(weight_path):
39
  os.system("mv 4x_APISR_DAT_GAN_generator.pth pretrained")
40
 
41
 
42
- def inference(img_path, model_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
 
 
 
 
 
44
  try:
45
- # Load the model
46
- if model_name == "4xGRL":
47
- weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
48
- auto_download_if_needed(weight_path)
49
- generator = load_grl(weight_path, scale=4)
50
- generator = generator.to(device='cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- elif model_name == "4xRRDB":
53
- weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
54
- auto_download_if_needed(weight_path)
55
- generator = load_rrdb(weight_path, scale=4)
56
- generator = generator.to(device='cpu')
57
 
58
- elif model_name == "2xRRDB":
59
- weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
60
- auto_download_if_needed(weight_path)
61
- generator = load_rrdb(weight_path, scale=2)
62
- generator = generator.to(device='cpu')
63
 
64
- elif model_name == "4xDAT":
65
- weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
66
- auto_download_if_needed(weight_path)
67
- generator = load_dat(weight_path, scale=4)
68
- generator = generator.to(device='cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- else:
71
- raise gr.Error("We don't support such Model")
 
 
 
 
 
 
 
72
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- print("We are processing ", img_path)
75
- print("The time now is ", datetime.datetime.now(pytz.timezone('US/Eastern')))
76
 
77
- # In default, we will automatically use crop to match 4x size
78
- super_resolved_img = super_resolve_img(generator, img_path, output_path=None, downsample_threshold=720, crop_for_4x=True)
79
- store_name = str(time.time()) + ".png"
80
- save_image(super_resolved_img, store_name)
81
- outputs = cv2.imread(store_name)
82
- outputs = cv2.cvtColor(outputs, cv2.COLOR_RGB2BGR)
83
- os.remove(store_name)
84
-
85
- return outputs
86
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- except Exception as error:
89
- raise gr.Error(f"global exception: {error}")
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  if __name__ == '__main__':
94
 
 
 
 
 
95
  MARKDOWN = \
96
  """
97
  ## <p style='text-align: center'> APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) </p>
@@ -100,50 +309,104 @@ if __name__ == '__main__':
100
 
101
  APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.
102
 
103
- ### Note: Due to memory restriction, all images whose short side is over 720 pixel will be downsampled to 720 pixel with the same aspect ratio. E.g., 1920x1080 -> 1280x720
104
- ### Note: Please check [Model Zoo](https://github.com/Kiteretsu77/APISR/blob/main/docs/model_zoo.md) for the description of each weight and [Here](https://imgsli.com/MjU0MjI0) for model comparisons.
105
-
106
- ### If APISR is helpful, please help star the [GitHub Repo](https://github.com/Kiteretsu77/APISR). Thanks! ###
107
  """
108
 
109
  block = gr.Blocks().queue(max_size=10)
110
  with block:
111
- with gr.Row():
112
- gr.Markdown(MARKDOWN)
113
- with gr.Row(elem_classes=["container"]):
114
- with gr.Column(scale=2):
115
- input_image = gr.Image(type="filepath", label="Input")
116
- model_name = gr.Dropdown(
117
- [
118
- "2xRRDB",
119
- "4xRRDB",
120
- "4xGRL",
121
- "4xDAT",
122
- ],
123
- type="value",
124
- value="4xGRL",
125
- label="model",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  )
127
- run_btn = gr.Button(value="Submit")
128
-
129
- with gr.Column(scale=3):
130
- output_image = gr.Image(type="numpy", label="Output image")
131
-
132
- with gr.Row(elem_classes=["container"]):
133
- gr.Examples(
134
- [
135
- ["__assets__/lr_inputs/image-00277.png"],
136
- ["__assets__/lr_inputs/image-00542.png"],
137
- ["__assets__/lr_inputs/41.png"],
138
- ["__assets__/lr_inputs/f91.jpg"],
139
- ["__assets__/lr_inputs/image-00440.png"],
140
- ["__assets__/lr_inputs/image-00164.jpg"],
141
- ["__assets__/lr_inputs/img_eva.jpeg"],
142
- ["__assets__/lr_inputs/naruto.jpg"],
143
- ],
144
- [input_image],
145
- )
146
-
147
- run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- block.launch()
 
6
  import torch
7
  import numpy as np
8
  from torchvision.utils import save_image
9
+ import json
10
+ import threading
11
+ from queue import Queue
12
+ from pathlib import Path
13
+ import shutil
14
 
15
  # Import files from the local folder
16
  root_path = os.path.abspath('.')
 
18
  from test_code.inference import super_resolve_img
19
  from test_code.test_utils import load_grl, load_rrdb, load_dat
20
 
21
# ---------------------------------------------------------------------------
# Module-wide settings and shared state
# ---------------------------------------------------------------------------
OUTPUT_DIR = "outputs"                  # root folder for every generated file
HISTORY_FILE = "history.json"           # JSON log of finished jobs
VIDEO_QUEUE_FILE = "video_queue.json"   # NOTE(review): not referenced in the visible code
video_queue = Queue()                   # pending video jobs for the worker thread
processing_status = {}                  # task id -> latest status dict

# Make sure the output tree exists before anything tries to write into it
# (creating a sub-folder also creates OUTPUT_DIR itself).
for _subfolder in ("images", "videos"):
    os.makedirs(os.path.join(OUTPUT_DIR, _subfolder), exist_ok=True)
del _subfolder
32
+
33
 
34
  def auto_download_if_needed(weight_path):
35
  if os.path.exists(weight_path):
 
55
  os.system("mv 4x_APISR_DAT_GAN_generator.pth pretrained")
56
 
57
 
58
def load_history():
    """Load the processing history from ``HISTORY_FILE``.

    Returns:
        list: The stored records, or an empty list when the file does not
        exist, cannot be parsed as JSON, or does not hold a JSON list.
    """
    try:
        with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
            history = json.load(f)
    except FileNotFoundError:
        return []
    except (json.JSONDecodeError, UnicodeDecodeError):
        # A truncated or hand-edited file must not break every later caller;
        # treat it the same as "no history yet".
        return []

    # Callers insert into and slice the result, so anything but a list is unusable.
    return history if isinstance(history, list) else []
64
+
65
+
66
def save_history(history):
    """Write the processing history to ``HISTORY_FILE`` as indented JSON.

    The data is written to a temporary file in the same directory and then
    moved over the target with ``os.replace``, so a serialisation error or a
    crash mid-write leaves the previous history file untouched.

    Args:
        history: JSON-serialisable list of history records.

    Raises:
        TypeError: If ``history`` contains values ``json`` cannot serialise.
        OSError: If the file cannot be written.
    """
    import tempfile

    target_dir = os.path.dirname(os.path.abspath(HISTORY_FILE))
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, prefix=".history_", suffix=".tmp")
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(history, f, indent=2)
        os.replace(tmp_path, HISTORY_FILE)
    except BaseException:
        # Do not leave the partial temp file behind.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
70
+
71
+
72
# Serialises the load -> insert -> save cycle below. add_to_history is called
# both from the background video worker thread and from the image handler, and
# two interleaved cycles would silently drop one of the records.
_HISTORY_LOCK = threading.Lock()


def add_to_history(input_path, output_path, model_name, process_type, status="completed"):
    """Prepend one record to the persisted processing history.

    Args:
        input_path: Path of the source image or video.
        output_path: Path of the generated result.
        model_name: Name of the model that produced the result.
        process_type: Kind of job, e.g. "image" or "video".
        status: Outcome stored with the record (defaults to "completed").
    """
    record = {
        "timestamp": datetime.datetime.now().isoformat(),
        "input_path": input_path,
        "output_path": output_path,
        "model_name": model_name,
        "process_type": process_type,
        "status": status
    }
    with _HISTORY_LOCK:
        history = load_history()
        history.insert(0, record)  # newest record first
        save_history(history)
85
+
86
+
87
def load_generator(model_name):
    """Build the generator network for ``model_name`` and move it to the CPU.

    Supported names are "4xGRL", "4xRRDB", "2xRRDB" and "4xDAT". The matching
    checkpoint path under ``pretrained/`` is handed to
    ``auto_download_if_needed`` before the weights are loaded.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # (model name, checkpoint path, architecture, upscale factor)
    catalogue = (
        ("4xGRL", "pretrained/4x_APISR_GRL_GAN_generator.pth", "grl", 4),
        ("4xRRDB", "pretrained/4x_APISR_RRDB_GAN_generator.pth", "rrdb", 4),
        ("2xRRDB", "pretrained/2x_APISR_RRDB_GAN_generator.pth", "rrdb", 2),
        ("4xDAT", "pretrained/4x_APISR_DAT_GAN_generator.pth", "dat", 4),
    )

    chosen = next((entry for entry in catalogue if model_name == entry[0]), None)
    if chosen is None:
        raise ValueError(f"Model {model_name} not supported")

    _, weight_path, architecture, scale = chosen
    auto_download_if_needed(weight_path)

    if architecture == "grl":
        build = load_grl
    elif architecture == "rrdb":
        build = load_rrdb
    else:
        build = load_dat

    return build(weight_path, scale=scale).to(device='cpu')
112
+
113
+
114
def inference_image(img_path, model_name):
    """Super-resolve a single image, save it and return it for display.

    Args:
        img_path: Filesystem path of the input image; ``None`` when the user
            has not uploaded anything.
        model_name: Model identifier passed to ``load_generator``.

    Returns:
        tuple: ``(image, message)`` where ``image`` is the saved result read
        back with OpenCV and converted from BGR to RGB channel order, and
        ``message`` states where the file was written.

    Raises:
        gr.Error: If no image was provided or any processing step fails.
    """
    if not img_path:
        # Fail with a clear message instead of an obscure error from the model code.
        raise gr.Error("Please upload an image first")

    try:
        generator = load_generator(model_name)

        print("Processing image:", img_path)
        print("Time:", datetime.datetime.now(pytz.timezone('US/Eastern')))

        # Process image
        super_resolved_img = super_resolve_img(
            generator, img_path, output_path=None,
            downsample_threshold=720, crop_for_4x=True
        )

        # Save output under a millisecond-timestamped name
        timestamp = int(time.time() * 1000)
        output_name = f"image_{timestamp}.png"
        output_path = os.path.join(OUTPUT_DIR, "images", output_name)
        save_image(super_resolved_img, output_path)

        # Load and convert for display (cv2.imread yields None on failure)
        outputs = cv2.imread(output_path)
        if outputs is None:
            raise ValueError(f"Cannot read back saved image: {output_path}")
        outputs = cv2.cvtColor(outputs, cv2.COLOR_BGR2RGB)

        # Add to history
        add_to_history(img_path, output_path, model_name, "image")

        return outputs, f"✅ Saved to: {output_path}"

    except Exception as error:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"Error processing image: {error}") from error
145
+
146
+
147
def process_video_frame_by_frame(video_path, model_name, task_id):
    """Upscale a video one frame at a time and encode the result with ffmpeg.

    Progress and the final outcome are published in
    ``processing_status[task_id]``. Failures are recorded there (status
    "error") and printed rather than raised.

    Only the upscaled frames are fed to ffmpeg, so the output has no audio
    input.

    Args:
        video_path: Path of the input video file.
        model_name: Model identifier passed to ``load_generator``.
        task_id: Key under which this job reports its status.
    """
    import subprocess

    cap = None
    temp_dir = None
    try:
        processing_status[task_id] = {"status": "processing", "progress": 0}

        # Load model
        generator = load_generator(model_name)

        # Open video
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Cannot open video file")

        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if not fps > 0:
            # Also true for NaN; ffmpeg cannot encode without a frame rate.
            raise ValueError("Cannot determine video frame rate")

        # Prepare output
        timestamp = int(time.time() * 1000)
        output_name = f"video_{timestamp}.mp4"
        output_path = os.path.join(OUTPUT_DIR, "videos", output_name)

        # Create temporary directory for frames
        temp_dir = f"temp_frames_{timestamp}"
        os.makedirs(temp_dir, exist_ok=True)

        # Process frames
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # super_resolve_img takes a file path, so each frame goes via disk
            temp_frame_path = os.path.join(temp_dir, f"frame_{frame_count:06d}.png")
            cv2.imwrite(temp_frame_path, frame)

            # Super resolve frame
            super_resolved_img = super_resolve_img(
                generator, temp_frame_path, output_path=None,
                downsample_threshold=720, crop_for_4x=True
            )

            # Save processed frame
            output_frame_path = os.path.join(temp_dir, f"output_{frame_count:06d}.png")
            save_image(super_resolved_img, output_frame_path)

            frame_count += 1
            # The reported frame count may be 0 or too small; never divide by
            # zero and never report more than 100%.
            if total_frames > 0:
                progress = min(100, int((frame_count / total_frames) * 100))
            else:
                progress = 0
            processing_status[task_id] = {"status": "processing", "progress": progress}

            print(f"Task {task_id}: Processed frame {frame_count}/{total_frames} ({progress}%)")

        cap.release()
        cap = None

        if frame_count == 0:
            raise ValueError("No frames could be read from the video")

        # Combine frames into video using ffmpeg
        print(f"Task {task_id}: Combining frames into video...")
        processing_status[task_id] = {"status": "encoding", "progress": 100}

        # Argument list (no shell) and check=True: a failed or missing ffmpeg
        # raises and is reported as an error instead of "completed".
        subprocess.run(
            [
                "ffmpeg", "-y",
                "-framerate", str(fps),
                "-i", os.path.join(temp_dir, "output_%06d.png"),
                "-c:v", "libx264",
                "-pix_fmt", "yuv420p",
                output_path,
            ],
            check=True,
            stdin=subprocess.DEVNULL,
        )

        processing_status[task_id] = {"status": "completed", "progress": 100, "output": output_path}
        add_to_history(video_path, output_path, model_name, "video")

        print(f"Task {task_id}: Completed! Output: {output_path}")

    except Exception as error:
        processing_status[task_id] = {"status": "error", "error": str(error)}
        print(f"Task {task_id}: Error - {error}")

    finally:
        # Clean up on every path so failed jobs do not leak handles or frames.
        if cap is not None:
            cap.release()
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
221
+
222
+
223
def video_queue_worker():
    """Consume ``video_queue`` forever, handling one video job at a time.

    Each queue item is a ``(task_id, video_path, model_name)`` tuple; a
    ``None`` item tells the worker to stop. Every item taken from the queue,
    including the stop marker, is acknowledged with ``task_done()``.
    """
    print("Video queue worker started...")
    keep_running = True
    while keep_running:
        try:
            job = video_queue.get()
            if job is not None:
                job_id, source_path, chosen_model = job
                print(f"Starting task {job_id}...")
                process_video_frame_by_frame(source_path, chosen_model, job_id)
            else:
                # Stop marker received: leave the loop after acknowledging it
                keep_running = False
        except Exception as exc:
            print(f"Worker error: {exc}")
        finally:
            video_queue.task_done()
240
+
241
 
242
def submit_video(video_path, model_name):
    """Queue a video for background processing.

    Args:
        video_path: Path of the uploaded video, or ``None`` when nothing was
            uploaded.
        model_name: Model identifier to process the video with.

    Returns:
        tuple: ``(None, message)``. The first element is returned to the video
        input component; ``message`` reports the task id or the reason the
        submission was rejected.
    """
    if not video_path:
        return None, "❌ Please upload a video first"

    task_id = f"task_{int(time.time() * 1000)}"
    # Record the "queued" state BEFORE enqueueing: once the task is in the
    # queue the worker thread may start and write "processing", which a later
    # write here would clobber.
    processing_status[task_id] = {"status": "queued", "progress": 0}
    video_queue.put((task_id, video_path, model_name))

    return None, f"✅ Video submitted to queue! Task ID: {task_id}\nCheck status in the monitoring section."
252
 
 
 
253
 
254
def get_queue_status():
    """Return a text summary of the video queue and every known task.

    Returns:
        str: The number of queued videos followed by one section per entry in
        ``processing_status`` (status, progress and, when present, output path
        or error), or "No active tasks" when there are none.
    """
    parts = [
        "📊 **Queue Status**\n\n",
        f"Videos in queue: {video_queue.qsize()}\n\n",
    ]

    # Snapshot first: other threads add tasks while this runs, and iterating
    # the live dict could raise "dictionary changed size during iteration".
    tasks = list(processing_status.items())

    if tasks:
        parts.append("**Active Tasks:**\n")
        for task_id, status in tasks:
            parts.append(f"\n🎬 {task_id}:\n")
            parts.append(f" Status: {status['status']}\n")
            parts.append(f" Progress: {status.get('progress', 0)}%\n")
            if 'output' in status:
                parts.append(f" Output: {status['output']}\n")
            if 'error' in status:
                parts.append(f" Error: {status['error']}\n")
    else:
        parts.append("No active tasks")

    return "".join(parts)
 
273
 
274
 
275
def get_history_display():
    """Format the stored processing history for display.

    Returns:
        str: One numbered entry per record (type, timestamp, model, status and
        output path) for the first 50 records returned by ``load_history``, or
        "No history available" when there is nothing to show.
    """
    # Skip entries that are not dicts so one bad record cannot break the view.
    records = [r for r in load_history() if isinstance(r, dict)][:50]
    if not records:
        return "No history available"

    parts = ["📜 **Processing History**\n\n"]
    for idx, record in enumerate(records, start=1):
        # .get keeps the view working for records written with missing fields.
        process_type = str(record.get('process_type', 'unknown')).upper()
        parts.append(f"**{idx}. {process_type}** - {record.get('timestamp', 'unknown')}\n")
        parts.append(f" Model: {record.get('model_name', 'unknown')}\n")
        parts.append(f" Status: {record.get('status', 'unknown')}\n")
        parts.append(f" Output: {record.get('output_path', 'unknown')}\n\n")

    return "".join(parts)
289
+
290
+
291
def clear_history():
    """Delete the history file, if there is one.

    Returns:
        str: Confirmation message for the UI.
    """
    try:
        os.remove(HISTORY_FILE)
    except FileNotFoundError:
        # Already gone (never created, or removed by a concurrent request).
        pass
    return "✅ History cleared!"
296
+
297
 
298
  if __name__ == '__main__':
299
 
300
+ # Start background worker thread
301
+ worker_thread = threading.Thread(target=video_queue_worker, daemon=True)
302
+ worker_thread.start()
303
+
304
  MARKDOWN = \
305
  """
306
  ## <p style='text-align: center'> APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) </p>
 
309
 
310
  APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.
311
 
312
+ ### ⚠️ Note: Images with short side > 720px will be downsampled to 720px (e.g., 1920x1080 β†’ 1280x720)
313
+ ### πŸ“Ή New: Video processing runs in background queue - you can close the browser and it continues!
 
 
314
  """
315
 
316
  block = gr.Blocks().queue(max_size=10)
317
  with block:
318
+ gr.Markdown(MARKDOWN)
319
+
320
+ with gr.Tabs():
321
+ # Tab 1: Image Processing
322
+ with gr.Tab("πŸ–ΌοΈ Image Processing"):
323
+ with gr.Row():
324
+ with gr.Column(scale=2):
325
+ input_image = gr.Image(type="filepath", label="Input Image")
326
+ image_model = gr.Dropdown(
327
+ ["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
328
+ value="4xGRL",
329
+ label="Model"
330
+ )
331
+ image_btn = gr.Button("πŸš€ Process Image", variant="primary")
332
+
333
+ with gr.Column(scale=3):
334
+ output_image = gr.Image(type="numpy", label="Output Image")
335
+ image_status = gr.Textbox(label="Status", lines=2)
336
+
337
+ with gr.Row():
338
+ gr.Examples(
339
+ [
340
+ ["__assets__/lr_inputs/image-00277.png"],
341
+ ["__assets__/lr_inputs/image-00542.png"],
342
+ ["__assets__/lr_inputs/41.png"],
343
+ ["__assets__/lr_inputs/f91.jpg"],
344
+ ],
345
+ [input_image],
346
+ )
347
+
348
+ image_btn.click(
349
+ inference_image,
350
+ inputs=[input_image, image_model],
351
+ outputs=[output_image, image_status]
352
  )
353
+
354
+ # Tab 2: Video Processing
355
+ with gr.Tab("🎬 Video Processing"):
356
+ gr.Markdown("""
357
+ ### Video Processing Queue
358
+ Videos are processed in the background. You can submit multiple videos and close the browser - processing continues!
359
+ """)
360
+
361
+ with gr.Row():
362
+ with gr.Column():
363
+ input_video = gr.Video(label="Input Video")
364
+ video_model = gr.Dropdown(
365
+ ["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
366
+ value="4xGRL",
367
+ label="Model"
368
+ )
369
+ video_btn = gr.Button("πŸ“€ Submit to Queue", variant="primary")
370
+ video_status = gr.Textbox(label="Submission Status", lines=3)
371
+
372
+ with gr.Column():
373
+ gr.Markdown("### πŸ“Š Queue Monitor")
374
+ queue_status = gr.Textbox(label="Queue Status", lines=15)
375
+ refresh_btn = gr.Button("πŸ”„ Refresh Status")
376
+
377
+ video_btn.click(
378
+ submit_video,
379
+ inputs=[input_video, video_model],
380
+ outputs=[input_video, video_status]
381
+ )
382
+
383
+ refresh_btn.click(
384
+ get_queue_status,
385
+ outputs=[queue_status]
386
+ )
387
+
388
+ # Tab 3: History
389
+ with gr.Tab("πŸ“œ History"):
390
+ gr.Markdown("### Processing History")
391
+
392
+ with gr.Row():
393
+ refresh_history_btn = gr.Button("πŸ”„ Refresh History")
394
+ clear_history_btn = gr.Button("πŸ—‘οΈ Clear History", variant="stop")
395
+
396
+ history_display = gr.Textbox(label="History", lines=20)
397
+ clear_status = gr.Textbox(label="Status", lines=1)
398
+
399
+ refresh_history_btn.click(
400
+ get_history_display,
401
+ outputs=[history_display]
402
+ )
403
+
404
+ clear_history_btn.click(
405
+ clear_history,
406
+ outputs=[clear_status]
407
+ )
408
+
409
+ # Auto-load history on tab open
410
+ block.load(get_history_display, outputs=[history_display])
411
 
412
+ block.launch(server_name="0.0.0.0", server_port=7860)