Add progress display during the predict phase
app.py
CHANGED

@@ -12,6 +12,7 @@ import tempfile
 import zipfile
 import re
 import ast
+import time
 from datetime import datetime
 from collections import defaultdict
 from classifyTags import classify_tags
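
The new import backs the Timer class added in the next hunk; time.perf_counter() is the usual monotonic, high-resolution clock for interval timing. A tiny illustration (the sleep is only a stand-in for real work):

```python
import time

t0 = time.perf_counter()
time.sleep(0.05)                                    # stand-in for the work being timed
print(f"{time.perf_counter() - t0:.3f} seconds")    # roughly 0.050
```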

@@ -111,6 +112,52 @@ def mcut_threshold(probs):
     thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
     return thresh
 
+class Timer:
+    def __init__(self):
+        self.start_time = time.perf_counter()  # Record the start time
+        self.checkpoints = [("Start", self.start_time)]  # Store checkpoints
+
+    def checkpoint(self, label="Checkpoint"):
+        """Record a checkpoint with a given label."""
+        now = time.perf_counter()
+        self.checkpoints.append((label, now))
+
+    def report(self, is_clear_checkpoints = True):
+        # Determine the max label width for alignment
+        max_label_length = max(len(label) for label, _ in self.checkpoints)
+
+        prev_time = self.checkpoints[0][1]
+        for label, curr_time in self.checkpoints[1:]:
+            elapsed = curr_time - prev_time
+            print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
+            prev_time = curr_time
+
+        if is_clear_checkpoints:
+            self.checkpoints.clear()
+            self.checkpoint()  # Store checkpoints
+
+    def report_all(self):
+        """Print all recorded checkpoints and total execution time with aligned formatting."""
+        print("\n> Execution Time Report:")
+
+        # Determine the max label width for alignment
+        max_label_length = max(len(label) for label, _ in self.checkpoints) if len(self.checkpoints) > 0 else 0
+
+        prev_time = self.start_time
+        for label, curr_time in self.checkpoints[1:]:
+            elapsed = curr_time - prev_time
+            print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
+            prev_time = curr_time
+
+        total_time = self.checkpoints[-1][1] - self.start_time
+        print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
+
+        self.checkpoints.clear()
+
+    def restart(self):
+        self.start_time = time.perf_counter()  # Record the start time
+        self.checkpoints = [("Start", self.start_time)]  # Store checkpoints
+
 class Llama3Reorganize:
     def __init__(
         self,
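
The Timer class is self-contained and can be exercised on its own; the snippet below is a minimal usage sketch (it assumes the class defined above is in scope, and the sleeps are stand-ins for real work):

```python
import time
# Assumes the Timer class from the hunk above is in scope.

timer = Timer()                       # records the "Start" checkpoint
time.sleep(0.2)                       # stand-in for loading a model
timer.checkpoint("Load model")
time.sleep(0.1)                       # stand-in for tagging one image
timer.checkpoint("image01, predict finished")
timer.report()                        # prints each step's duration, then clears the checkpoints

time.sleep(0.1)                       # stand-in for writing the caption files
timer.checkpoint("Write caption files")
timer.report_all()                    # prints a final report including total elapsed time
```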

@@ -355,9 +402,21 @@ class Predictor:
         additional_tags_prepend,
         additional_tags_append,
         tag_results,
+        progress=gr.Progress()
     ):
-
+        gallery_len = len(gallery)
+        print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
+
+        timer = Timer()  # Create a timer
+        progressRatio = 0.5 if llama3_reorganize_model_repo else 1
+        progressTotal = gallery_len + 1
+        current_progress = 0
+
         self.load_model(model_repo)
+        current_progress += progressRatio/progressTotal;
+        progress(current_progress, desc="Initialize wd model finished")
+        timer.checkpoint(f"Initialize wd model")
+
         # Result
         txt_infos = []
         output_dir = tempfile.mkdtemp()
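
The `progress=gr.Progress()` default argument is Gradio's standard injection pattern: when the function runs as an event handler, Gradio swaps in a live tracker that can be called with a fraction and a description. A minimal, self-contained sketch of the same pattern, separate from this app's Predictor (the function and component names are illustrative only):

```python
import time
import gradio as gr

def slow_task(n_steps, progress=gr.Progress()):
    # Gradio replaces the default gr.Progress() with a live tracker
    # when this function is invoked as an event handler.
    for i in range(int(n_steps)):
        time.sleep(0.3)                                      # stand-in for real work
        progress((i + 1) / int(n_steps), desc=f"step {i + 1} finished")
    return "done"

with gr.Blocks() as demo:
    steps = gr.Number(value=5, label="Steps")
    out = gr.Textbox(label="Result")
    gr.Button("Run").click(slow_task, inputs=steps, outputs=out)

# demo.launch()
```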

@@ -372,6 +431,11 @@ class Predictor:
         if llama3_reorganize_model_repo:
             print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
             llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
+            current_progress += progressRatio/progressTotal;
+            progress(current_progress, desc="Initialize llama3 model finished")
+            timer.checkpoint(f"Initialize llama3 model")
+
+        timer.report()
 
         prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
         append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]

@@ -394,7 +458,7 @@ class Predictor:
 
                 input_name = self.model.get_inputs()[0].name
                 label_name = self.model.get_outputs()[0].name
-                print(f"Gallery {idx}: Starting run wd model...")
+                print(f"Gallery {idx:02d}: Starting run wd model...")
                 preds = self.model.run([label_name], {input_name: image})[0]
 
                 labels = list(zip(self.tag_names, preds[0].astype(float)))

@@ -444,6 +508,10 @@ class Predictor:
                 sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + sorted_general_list).replace("(", "\(").replace(")", "\)")
 
                 classified_tags, unclassified_tags = classify_tags(sorted_general_list)
+
+                current_progress += progressRatio/progressTotal;
+                progress(current_progress, desc=f"image{idx:02d}, predict finished")
+                timer.checkpoint(f"image{idx:02d}, predict finished")
 
                 if llama3_reorganize_model_repo:
                     print(f"Starting reorganize with llama3...")

@@ -453,11 +521,15 @@ class Predictor:
                     reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
                     sorted_general_strings += "," + reorganize_strings
 
+                    current_progress += progressRatio/progressTotal;
+                    progress(current_progress, desc=f"image{idx:02d}, llama3 reorganize finished")
+                    timer.checkpoint(f"image{idx:02d}, llama3 reorganize finished")
+
                 txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
                 txt_infos.append({"path":txt_file, "name": image_name + ".txt"})
 
                 tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
-
+                timer.report()
             except Exception as e:
                 print(traceback.format_exc())
                 print("Error predict: " + str(e))

@@ -475,7 +547,9 @@ class Predictor:
         if llama3_reorganize_model_repo:
             llama3_reorganize.release_vram()
             del llama3_reorganize
-
+
+        progress(1, desc=f"Predict completed")
+        timer.report_all()  # Print all recorded times
         print("Predict is complete.")
 
         return download, sorted_general_strings, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
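
To make the progress bookkeeping concrete: progressTotal = gallery_len + 1 gives each gallery image, plus the model-initialization phase, an equal share of the bar, and when a llama3 repo is selected progressRatio is 0.5 so each share is split between the wd prediction and the llama3 reorganize step. The final progress(1, ...) call pins the bar at 100% regardless of rounding. A worked example with illustrative numbers:

```python
# Illustrative numbers only: 4 gallery images, llama3 reorganize enabled.
gallery_len = 4
progressRatio = 0.5              # each share is split between wd predict and llama3 reorganize
progressTotal = gallery_len + 1  # one extra share is reserved for model initialization

step = progressRatio / progressTotal
print(step)                                  # 0.1 -> every progress() call advances the bar by 10%
print(2 * step + gallery_len * 2 * step)     # 1.0 -> init (2 calls) + 4 images x 2 calls fill the bar
```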

@@ -524,6 +598,14 @@ def remove_image_from_gallery(gallery: list, selected_image: str):
 
 
 def main():
+    # Custom CSS to left-align the progress bar's description text
+    css = """
+    div.progress-level div.progress-level-inner {
+        text-align: left !important;
+        width: 55.5% !important;
+    }
+    """
+
     args = parse_args()
 
     predictor = Predictor()

@@ -550,7 +632,7 @@ def main():
         META_LLAMA_3_8B_REPO,
     ]
 
-    with gr.Blocks(title=TITLE) as demo:
+    with gr.Blocks(title=TITLE, css = css) as demo:
        gr.Markdown(
             value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
         )
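
One caveat: Gradio streams progress updates through its queue, so on older Gradio releases the demo (or the specific event) needs queuing enabled explicitly; recent versions enable the queue by default. A hedged sketch of the explicit form, assuming `demo` is the gr.Blocks instance built in main():

```python
# Assumption: an older Gradio release where queuing is not on by default;
# without the queue, gr.Progress updates are not streamed to the browser.
demo.queue().launch()
```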

@@ -561,10 +643,10 @@ def main():
         with gr.Column(variant="panel"):
             # Create an Image component for uploading images
             image_input = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", sources=["upload", "clipboard"], height=150)
-            gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
             with gr.Row():
                 upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
                 remove_button = gr.Button("Remove Selected Image", size="sm")
+            gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
 
             model_repo = gr.Dropdown(
                 dropdown_list,