Added tag statistics for all files in the output.
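The statistics are accumulated with collections.Counter: each processed image or text file contributes its final tag list, and the counter is rendered as one "tag: count" line per tag, most frequent first. Below is a minimal standalone sketch of that flow; the tag lists are placeholder data, not output from the tagger itself.

from collections import Counter

# Stand-in for the per-file tag lists the tagger produces (hypothetical data).
final_tag_lists = [
    ["1girl", "solo", "long_hair", "smile"],
    ["1girl", "short_hair", "smile"],
    ["2girls", "long_hair"],
]

tag_counter = Counter()
for final_tags_list in final_tag_lists:
    # One update per processed file, mirroring tag_counter.update(final_tags_list) in the diff.
    tag_counter.update(final_tags_list)

# Format the statistics the way the diff does: "tag: count", sorted by frequency.
stats_list = [f"{tag}: {count}" for tag, count in tag_counter.most_common()]
statistics_output = "\n".join(stats_list)
print(statistics_output)  # e.g. "1girl: 2" and "long_hair: 2" first, then the single-count tags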
Changed the output fields to a collapsible Accordion layout to save screen space.
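A reduced sketch of that layout pattern, assuming a recent Gradio release; the components here are left unwired, whereas the app binds them to the prediction outputs shown in the diff below.

import gradio as gr

with gr.Blocks() as demo:
    download_file = gr.File(label="Output (Download)")
    sorted_general_strings = gr.Textbox(label="Output (string for last processed item)", show_copy_button=True, lines=5)

    # Accordions start collapsed (open=False), so the output column stays short until expanded.
    with gr.Accordion("Categorized (tags)", open=False):
        categorized = gr.JSON(label="Categorized")

    with gr.Accordion("Detailed Output (for last processed item)", open=False):
        rating = gr.Label(label="Rating")
        character_res = gr.Label(label="Output (characters)")
        general_res = gr.Label(label="Output (tags)")
        unclassified = gr.JSON(label="Unclassified (tags)")

    with gr.Accordion("Tags Statistics (All files)", open=False):
        tags_statistics = gr.Text(label="Statistics", show_label=False, show_copy_button=True, lines=10)

demo.launch()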
app.py
CHANGED
@@ -16,6 +16,7 @@ import time
 from datetime import datetime
 from collections import defaultdict
 from classifyTags import classify_tags
+from collections import Counter # Import Counter for statistics

 TITLE = "WaifuDiffusion Tagger multiple images/texts"
 DESCRIPTION = """
@@ -172,7 +173,7 @@ class Llama3Reorganize:

 Args:
 repoId: LLAMA model repo.
-device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
+device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
 ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia).
 localFilesOnly: If True, avoid downloading the file and return the path to the
 local cached file if it exists.
@@ -264,7 +265,7 @@ class Llama3Reorganize:
 except Exception as e:
 self.release_vram()
 raise e
-
+

 def release_vram(self):
 try:
@@ -348,7 +349,7 @@ class Predictor:
 canvas = Image.new("RGBA", image.size, (255, 255, 255))
 canvas.alpha_composite(image)
 image = canvas.convert("RGB")
-
+
 # Pad image to square
 image_shape = image.size
 max_dim = max(image_shape)
@@ -357,14 +358,14 @@ class Predictor:

 padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
 padded_image.paste(image, (pad_left, pad_top))
-
+
 # Resize
 if max_dim != self.model_target_size:
 padded_image = padded_image.resize(
 (self.model_target_size, self.model_target_size),
 Image.BICUBIC,
 )
-
+
 # Convert to numpy array
 image_array = np.asarray(padded_image, dtype=np.float32)

@@ -398,11 +399,11 @@ class Predictor:
 ):
 if not gallery:
 gr.Warning("No images in the gallery to process.")
-return None, "", "{}", "", "", "", "{}", {}
-
+return None, "", "{}", "", "", "", "{}", {}, ""
+
 gallery_len = len(gallery)
 print(f"Predict from images: load model: {model_repo}, gallery length: {gallery_len}")
-
+
 timer = Timer() # Create a timer
 progressRatio = 0.5 if llama3_reorganize_model_repo else 1
 progressTotal = gallery_len + (1 if llama3_reorganize_model_repo else 0) + 1 # +1 for model load
@@ -416,11 +417,14 @@ class Predictor:
 # Result
 txt_infos = []
 output_dir = tempfile.mkdtemp()
-
+
 last_sorted_general_strings = ""
 last_classified_tags, last_unclassified_tags = {}, {}
 last_rating, last_character_res, last_general_res = None, None, None

+# Initialize counter for statistics
+tag_counter = Counter()
+
 llama3_reorganize = None
 if llama3_reorganize_model_repo:
 print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
@@ -428,7 +432,7 @@ class Predictor:
 current_progress += 1 / progressTotal
 progress(current_progress, desc="Initialize llama3 model finished")
 timer.checkpoint(f"Initialize llama3 model")
-
+
 timer.report()

 prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
@@ -457,11 +461,11 @@ class Predictor:
 preds = self.model.run([label_name], {input_name: image})[0]

 labels = list(zip(self.tag_names, preds[0].astype(float)))
-
+
 # First 4 labels are actually ratings: pick one with argmax
 ratings_names = [labels[i] for i in self.rating_indexes]
 rating = dict(ratings_names)
-
+
 # Then we have general tags: pick any where prediction confidence > threshold
 general_names = [labels[i] for i in self.general_indexes]

@@ -469,7 +473,7 @@ class Predictor:
 general_probs = np.array([x[1] for x in general_names])
 general_thresh = mcut_threshold(general_probs)
 general_res = dict([x for x in general_names if x[1] > general_thresh])
-
+
 # Everything else is characters: pick any where prediction confidence > threshold
 character_names = [labels[i] for i in self.character_indexes]

@@ -493,19 +497,22 @@ class Predictor:
 final_tags_list = prepend_list + sorted_general_list + append_list
 if characters_merge_enabled:
 final_tags_list = character_list + final_tags_list
-
+
 # Apply removal logic
 if remove_list:
 remove_set = set(remove_list)
 final_tags_list = [tag for tag in final_tags_list if tag not in remove_set]

+# Update counter with the final list of tags for this image
+tag_counter.update(final_tags_list)
+
 sorted_general_strings = ", ".join(final_tags_list).replace("(", "\(").replace(")", "\)")
 classified_tags, unclassified_tags = classify_tags(final_tags_list)

 current_progress += progressRatio / progressTotal
 progress(current_progress, desc=f"Image {idx+1}/{gallery_len}, predict finished")
 timer.checkpoint(f"Image {idx+1}/{gallery_len}, predict finished")
-
+
 if llama3_reorganize:
 print(f"Starting reorganize with llama3...")
 reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
@@ -523,7 +530,7 @@ class Predictor:
 txt_infos.append({"path": txt_file, "name": image_name + ".txt"})

 tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
-
+
 # Store last result for UI display
 last_sorted_general_strings = sorted_general_strings
 last_classified_tags = classified_tags
@@ -537,7 +544,7 @@ class Predictor:
 print(traceback.format_exc())
 print("Error predicting image: " + str(e))
 gr.Warning(f"Failed to process image {os.path.basename(value[0])}. Error: {e}")
-
+
 # Result
 download = []
 if txt_infos:
@@ -548,16 +555,20 @@ class Predictor:
 # Get file name from lookup
 taggers_zip.write(info["path"], arcname=info["name"])
 download.append(downloadZipPath)
-
+
 if llama3_reorganize:
 llama3_reorganize.release_vram()

 progress(1, desc="Image processing completed")
 timer.report_all()
 print("Image prediction is complete.")
+
+# Format statistics for output
+stats_list = [f"{tag}: {count}" for tag, count in tag_counter.most_common()]
+statistics_output = "\n".join(stats_list)
+
+return download, last_sorted_general_strings, last_classified_tags, last_rating, last_character_res, last_general_res, last_unclassified_tags, tag_results, statistics_output

-return download, last_sorted_general_strings, last_classified_tags, last_rating, last_character_res, last_general_res, last_unclassified_tags, tag_results
-
 # Method to process text files
 def predict_from_text(
 self,
@@ -570,7 +581,7 @@ class Predictor:
 ):
 if not text_files:
 gr.Warning("No text files uploaded to process.")
-return None, "", "{}", "", "", "", "{}", {}
+return None, "", "{}", "", "", "", "{}", {}, ""

 files_len = len(text_files)
 print(f"Predict from text: processing {files_len} files.")
@@ -579,10 +590,13 @@ class Predictor:
 progressRatio = 0.5 if llama3_reorganize_model_repo else 1.0
 progressTotal = files_len + (1 if llama3_reorganize_model_repo else 0)
 current_progress = 0
-
+
 txt_infos = []
 output_dir = tempfile.mkdtemp()
 last_processed_string = ""
+
+# Initialize counter for statistics
+tag_counter = Counter()

 llama3_reorganize = None
 if llama3_reorganize_model_repo:
@@ -591,7 +605,7 @@ class Predictor:
 current_progress += 1 / progressTotal
 progress(current_progress, desc="Initialize llama3 model finished")
 timer.checkpoint(f"Initialize llama3 model")
-
+
 timer.report()

 prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
@@ -599,7 +613,7 @@ class Predictor:
 remove_list = [tag.strip() for tag in tags_to_remove.split(",") if tag.strip()] # Parse remove tags
 if prepend_list and append_list:
 append_list = [item for item in append_list if item not in prepend_list]
-
+
 name_counters = defaultdict(int)
 for idx, file_obj in enumerate(text_files):
 try:
@@ -614,7 +628,7 @@ class Predictor:

 with open(file_path, 'r', encoding='utf-8') as f:
 original_content = f.read()
-
+
 # Process tags
 tags_list = [tag.strip() for tag in original_content.split(',') if tag.strip()]

@@ -629,9 +643,12 @@ class Predictor:
 if remove_list:
 remove_set = set(remove_list)
 final_tags_list = [tag for tag in final_tags_list if tag not in remove_set]
+
+# Update counter with the final list of tags for this file
+tag_counter.update(final_tags_list)

 processed_string = ", ".join(final_tags_list)
-
+
 current_progress += progressRatio / progressTotal
 progress(current_progress, desc=f"File {idx+1}/{files_len}, base processing finished")
 timer.checkpoint(f"File {idx+1}/{files_len}, base processing finished")
@@ -644,16 +661,16 @@ class Predictor:
 reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
 reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
 processed_string += "," + reorganize_strings
-
+
 current_progress += progressRatio / progressTotal
 progress(current_progress, desc=f"File {idx+1}/{files_len}, llama3 reorganize finished")
 timer.checkpoint(f"File {idx+1}/{files_len}, llama3 reorganize finished")
-
+
 txt_file_path = self.create_file(processed_string, output_dir, output_file_name)
 txt_infos.append({"path": txt_file_path, "name": output_file_name})
 last_processed_string = processed_string
 timer.report()
-
+
 except Exception as e:
 print(traceback.format_exc())
 print("Error processing text file: " + str(e))
@@ -675,8 +692,12 @@ class Predictor:
 timer.report_all() # Print all recorded times
 print("Text processing is complete.")

+# Format statistics for output
+stats_list = [f"{tag}: {count}" for tag, count in tag_counter.most_common()]
+statistics_output = "\n".join(stats_list)
+
 # Return values in the same structure as the image path, with placeholders for unused outputs
-return download, last_processed_string, "{}", "", "", "", "{}", {}
+return download, last_processed_string, "{}", "", "", "", "{}", {}, statistics_output

 def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
 if not selected_state:
@@ -703,7 +724,7 @@ def extend_gallery(gallery: list, images):
 gallery = []
 if not images:
 return gallery
-
+
 # Combine the new images with the existing gallery images
 gallery.extend(images)

@@ -732,6 +753,18 @@ def main():
 text-align: left !important;
 width: 55.5% !important;
 }
+textarea[rows]:not([rows="1"]) {
+overflow-y: auto !important;
+scrollbar-width: thin !important;
+}
+textarea[rows]:not([rows="1"])::-webkit-scrollbar {
+all: initial !important;
+background: #f1f1f1 !important;
+}
+textarea[rows]:not([rows="1"])::-webkit-scrollbar-thumb {
+all: initial !important;
+background: #a8a8a8 !important;
+}
 """
 args = parse_args()

@@ -758,7 +791,7 @@ def main():
 META_LLAMA_3_3B_REPO,
 META_LLAMA_3_8B_REPO,
 ]
-
+
 # Wrapper function to decide which prediction method to call
 def run_prediction(
 input_type, gallery, text_files, model_repo, general_thresh,
@@ -785,18 +818,18 @@ def main():
 with gr.Blocks(title=TITLE, css=css) as demo:
 gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
 gr.Markdown(value=DESCRIPTION)
-
+
 with gr.Row():
 with gr.Column():
 submit = gr.Button(value="Submit", variant="primary", size="lg")
-
+
 # Input type selector
 input_type_radio = gr.Radio(
-choices=['Image', 'Text file (.txt)'],
-value='Image',
+choices=['Image', 'Text file (.txt)'],
+value='Image',
 label="Input Type"
 )
-
+
 # Group for image inputs, initially visible
 with gr.Column(visible=True) as image_inputs_group:
 with gr.Column(variant="panel"):
@@ -806,7 +839,7 @@ def main():
 upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
 remove_button = gr.Button("Remove Selected Image", size="sm")
 gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
-
+
 # Group for text file inputs, initially hidden
 with gr.Column(visible=False) as text_inputs_group:
 text_files_input = gr.Files(
@@ -856,7 +889,7 @@ def main():
 scale=1,
 visible=True,
 )
-
+
 # Common settings
 with gr.Row():
 llama3_reorganize_model_repo = gr.Dropdown(
@@ -868,11 +901,11 @@ def main():
 with gr.Row():
 additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
 additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
-
+
 # Add the remove tags input box
 with gr.Row():
 tags_to_remove = gr.Text(label="Remove tags (comma split)")
-
+
 with gr.Row():
 clear = gr.ClearButton(
 components=[
@@ -895,13 +928,26 @@ def main():

 with gr.Column(variant="panel"):
 download_file = gr.File(label="Output (Download)")
-sorted_general_strings = gr.Textbox(label="Output (string)", show_label=True, show_copy_button=True, lines=5)
-
-
-
-
-
-
+sorted_general_strings = gr.Textbox(label="Output (string for last processed item)", show_label=True, show_copy_button=True, lines=5)
+
+with gr.Accordion("Categorized (tags)", open=False):
+categorized = gr.JSON(label="Categorized")
+
+with gr.Accordion("Detailed Output (for last processed item)", open=False):
+rating = gr.Label(label="Rating", visible=True)
+character_res = gr.Label(label="Output (characters)", visible=True)
+general_res = gr.Label(label="Output (tags)", visible=True)
+unclassified = gr.JSON(label="Unclassified (tags)", visible=True)
+
+with gr.Accordion("Tags Statistics (All files)", open=False):
+tags_statistics = gr.Text(
+label="Statistics",
+autoscroll=False,
+show_label=False,
+show_copy_button=True,
+lines=10,
+)
+
 clear.add(
 [
 download_file,
@@ -911,12 +957,13 @@ def main():
 character_res,
 general_res,
 unclassified,
+tags_statistics,
 ]
 )

 tag_results = gr.State({})
 selected_image = gr.Textbox(label="Selected Image", visible=False)
-
+
 # Event Listeners
 # Define the event listener to add the uploaded image to the gallery
 image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
@@ -955,7 +1002,7 @@ def main():
 categorized, rating, character_res, general_res, unclassified
 ]
 )
-
+
 # submit click now calls the wrapper function
 submit.click(
 fn=run_prediction,
@@ -975,11 +1022,11 @@ def main():
 tags_to_remove,
 tag_results,
 ],
-outputs=[download_file, sorted_general_strings, categorized, rating, character_res, general_res, unclassified, tag_results,],
+outputs=[download_file, sorted_general_strings, categorized, rating, character_res, general_res, unclassified, tag_results, tags_statistics],
 )
-
+
 gr.Examples(
-[["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
+[["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
 inputs=[
 image_input,
 model_repo,