avans06 committed on
Commit
75c6415
·
1 Parent(s): 87e5e5f

Added tag statistics for all files in the output.

Browse files

Changed the output fields to a collapsible Accordion layout to save screen space.

Files changed (1) hide show
  1. app.py +100 -53
app.py CHANGED
@@ -16,6 +16,7 @@ import time
16
  from datetime import datetime
17
  from collections import defaultdict
18
  from classifyTags import classify_tags
 
19
 
20
  TITLE = "WaifuDiffusion Tagger multiple images/texts"
21
  DESCRIPTION = """
@@ -172,7 +173,7 @@ class Llama3Reorganize:
172
 
173
  Args:
174
  repoId: LLAMA model repo.
175
- device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
176
  ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia).
177
  localFilesOnly: If True, avoid downloading the file and return the path to the
178
  local cached file if it exists.
@@ -264,7 +265,7 @@ class Llama3Reorganize:
264
  except Exception as e:
265
  self.release_vram()
266
  raise e
267
-
268
 
269
  def release_vram(self):
270
  try:
@@ -348,7 +349,7 @@ class Predictor:
348
  canvas = Image.new("RGBA", image.size, (255, 255, 255))
349
  canvas.alpha_composite(image)
350
  image = canvas.convert("RGB")
351
-
352
  # Pad image to square
353
  image_shape = image.size
354
  max_dim = max(image_shape)
@@ -357,14 +358,14 @@ class Predictor:
357
 
358
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
359
  padded_image.paste(image, (pad_left, pad_top))
360
-
361
  # Resize
362
  if max_dim != self.model_target_size:
363
  padded_image = padded_image.resize(
364
  (self.model_target_size, self.model_target_size),
365
  Image.BICUBIC,
366
  )
367
-
368
  # Convert to numpy array
369
  image_array = np.asarray(padded_image, dtype=np.float32)
370
 
@@ -398,11 +399,11 @@ class Predictor:
398
  ):
399
  if not gallery:
400
  gr.Warning("No images in the gallery to process.")
401
- return None, "", "{}", "", "", "", "{}", {}
402
-
403
  gallery_len = len(gallery)
404
  print(f"Predict from images: load model: {model_repo}, gallery length: {gallery_len}")
405
-
406
  timer = Timer() # Create a timer
407
  progressRatio = 0.5 if llama3_reorganize_model_repo else 1
408
  progressTotal = gallery_len + (1 if llama3_reorganize_model_repo else 0) + 1 # +1 for model load
@@ -416,11 +417,14 @@ class Predictor:
416
  # Result
417
  txt_infos = []
418
  output_dir = tempfile.mkdtemp()
419
-
420
  last_sorted_general_strings = ""
421
  last_classified_tags, last_unclassified_tags = {}, {}
422
  last_rating, last_character_res, last_general_res = None, None, None
423
 
 
 
 
424
  llama3_reorganize = None
425
  if llama3_reorganize_model_repo:
426
  print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
@@ -428,7 +432,7 @@ class Predictor:
428
  current_progress += 1 / progressTotal
429
  progress(current_progress, desc="Initialize llama3 model finished")
430
  timer.checkpoint(f"Initialize llama3 model")
431
-
432
  timer.report()
433
 
434
  prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
@@ -457,11 +461,11 @@ class Predictor:
457
  preds = self.model.run([label_name], {input_name: image})[0]
458
 
459
  labels = list(zip(self.tag_names, preds[0].astype(float)))
460
-
461
  # First 4 labels are actually ratings: pick one with argmax
462
  ratings_names = [labels[i] for i in self.rating_indexes]
463
  rating = dict(ratings_names)
464
-
465
  # Then we have general tags: pick any where prediction confidence > threshold
466
  general_names = [labels[i] for i in self.general_indexes]
467
 
@@ -469,7 +473,7 @@ class Predictor:
469
  general_probs = np.array([x[1] for x in general_names])
470
  general_thresh = mcut_threshold(general_probs)
471
  general_res = dict([x for x in general_names if x[1] > general_thresh])
472
-
473
  # Everything else is characters: pick any where prediction confidence > threshold
474
  character_names = [labels[i] for i in self.character_indexes]
475
 
@@ -493,19 +497,22 @@ class Predictor:
493
  final_tags_list = prepend_list + sorted_general_list + append_list
494
  if characters_merge_enabled:
495
  final_tags_list = character_list + final_tags_list
496
-
497
  # Apply removal logic
498
  if remove_list:
499
  remove_set = set(remove_list)
500
  final_tags_list = [tag for tag in final_tags_list if tag not in remove_set]
501
 
 
 
 
502
  sorted_general_strings = ", ".join(final_tags_list).replace("(", "\(").replace(")", "\)")
503
  classified_tags, unclassified_tags = classify_tags(final_tags_list)
504
 
505
  current_progress += progressRatio / progressTotal
506
  progress(current_progress, desc=f"Image {idx+1}/{gallery_len}, predict finished")
507
  timer.checkpoint(f"Image {idx+1}/{gallery_len}, predict finished")
508
-
509
  if llama3_reorganize:
510
  print(f"Starting reorganize with llama3...")
511
  reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
@@ -523,7 +530,7 @@ class Predictor:
523
  txt_infos.append({"path": txt_file, "name": image_name + ".txt"})
524
 
525
  tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
526
-
527
  # Store last result for UI display
528
  last_sorted_general_strings = sorted_general_strings
529
  last_classified_tags = classified_tags
@@ -537,7 +544,7 @@ class Predictor:
537
  print(traceback.format_exc())
538
  print("Error predicting image: " + str(e))
539
  gr.Warning(f"Failed to process image {os.path.basename(value[0])}. Error: {e}")
540
-
541
  # Result
542
  download = []
543
  if txt_infos:
@@ -548,16 +555,20 @@ class Predictor:
548
  # Get file name from lookup
549
  taggers_zip.write(info["path"], arcname=info["name"])
550
  download.append(downloadZipPath)
551
-
552
  if llama3_reorganize:
553
  llama3_reorganize.release_vram()
554
 
555
  progress(1, desc="Image processing completed")
556
  timer.report_all()
557
  print("Image prediction is complete.")
 
 
 
 
 
 
558
 
559
- return download, last_sorted_general_strings, last_classified_tags, last_rating, last_character_res, last_general_res, last_unclassified_tags, tag_results
560
-
561
  # Method to process text files
562
  def predict_from_text(
563
  self,
@@ -570,7 +581,7 @@ class Predictor:
570
  ):
571
  if not text_files:
572
  gr.Warning("No text files uploaded to process.")
573
- return None, "", "{}", "", "", "", "{}", {}
574
 
575
  files_len = len(text_files)
576
  print(f"Predict from text: processing {files_len} files.")
@@ -579,10 +590,13 @@ class Predictor:
579
  progressRatio = 0.5 if llama3_reorganize_model_repo else 1.0
580
  progressTotal = files_len + (1 if llama3_reorganize_model_repo else 0)
581
  current_progress = 0
582
-
583
  txt_infos = []
584
  output_dir = tempfile.mkdtemp()
585
  last_processed_string = ""
 
 
 
586
 
587
  llama3_reorganize = None
588
  if llama3_reorganize_model_repo:
@@ -591,7 +605,7 @@ class Predictor:
591
  current_progress += 1 / progressTotal
592
  progress(current_progress, desc="Initialize llama3 model finished")
593
  timer.checkpoint(f"Initialize llama3 model")
594
-
595
  timer.report()
596
 
597
  prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
@@ -599,7 +613,7 @@ class Predictor:
599
  remove_list = [tag.strip() for tag in tags_to_remove.split(",") if tag.strip()] # Parse remove tags
600
  if prepend_list and append_list:
601
  append_list = [item for item in append_list if item not in prepend_list]
602
-
603
  name_counters = defaultdict(int)
604
  for idx, file_obj in enumerate(text_files):
605
  try:
@@ -614,7 +628,7 @@ class Predictor:
614
 
615
  with open(file_path, 'r', encoding='utf-8') as f:
616
  original_content = f.read()
617
-
618
  # Process tags
619
  tags_list = [tag.strip() for tag in original_content.split(',') if tag.strip()]
620
 
@@ -629,9 +643,12 @@ class Predictor:
629
  if remove_list:
630
  remove_set = set(remove_list)
631
  final_tags_list = [tag for tag in final_tags_list if tag not in remove_set]
 
 
 
632
 
633
  processed_string = ", ".join(final_tags_list)
634
-
635
  current_progress += progressRatio / progressTotal
636
  progress(current_progress, desc=f"File {idx+1}/{files_len}, base processing finished")
637
  timer.checkpoint(f"File {idx+1}/{files_len}, base processing finished")
@@ -644,16 +661,16 @@ class Predictor:
644
  reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
645
  reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
646
  processed_string += "," + reorganize_strings
647
-
648
  current_progress += progressRatio / progressTotal
649
  progress(current_progress, desc=f"File {idx+1}/{files_len}, llama3 reorganize finished")
650
  timer.checkpoint(f"File {idx+1}/{files_len}, llama3 reorganize finished")
651
-
652
  txt_file_path = self.create_file(processed_string, output_dir, output_file_name)
653
  txt_infos.append({"path": txt_file_path, "name": output_file_name})
654
  last_processed_string = processed_string
655
  timer.report()
656
-
657
  except Exception as e:
658
  print(traceback.format_exc())
659
  print("Error processing text file: " + str(e))
@@ -675,8 +692,12 @@ class Predictor:
675
  timer.report_all() # Print all recorded times
676
  print("Text processing is complete.")
677
 
 
 
 
 
678
  # Return values in the same structure as the image path, with placeholders for unused outputs
679
- return download, last_processed_string, "{}", "", "", "", "{}", {}
680
 
681
  def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
682
  if not selected_state:
@@ -703,7 +724,7 @@ def extend_gallery(gallery: list, images):
703
  gallery = []
704
  if not images:
705
  return gallery
706
-
707
  # Combine the new images with the existing gallery images
708
  gallery.extend(images)
709
 
@@ -732,6 +753,18 @@ def main():
732
  text-align: left !important;
733
  width: 55.5% !important;
734
  }
 
 
 
 
 
 
 
 
 
 
 
 
735
  """
736
  args = parse_args()
737
 
@@ -758,7 +791,7 @@ def main():
758
  META_LLAMA_3_3B_REPO,
759
  META_LLAMA_3_8B_REPO,
760
  ]
761
-
762
  # Wrapper function to decide which prediction method to call
763
  def run_prediction(
764
  input_type, gallery, text_files, model_repo, general_thresh,
@@ -785,18 +818,18 @@ def main():
785
  with gr.Blocks(title=TITLE, css=css) as demo:
786
  gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
787
  gr.Markdown(value=DESCRIPTION)
788
-
789
  with gr.Row():
790
  with gr.Column():
791
  submit = gr.Button(value="Submit", variant="primary", size="lg")
792
-
793
  # Input type selector
794
  input_type_radio = gr.Radio(
795
- choices=['Image', 'Text file (.txt)'],
796
- value='Image',
797
  label="Input Type"
798
  )
799
-
800
  # Group for image inputs, initially visible
801
  with gr.Column(visible=True) as image_inputs_group:
802
  with gr.Column(variant="panel"):
@@ -806,7 +839,7 @@ def main():
806
  upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
807
  remove_button = gr.Button("Remove Selected Image", size="sm")
808
  gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
809
-
810
  # Group for text file inputs, initially hidden
811
  with gr.Column(visible=False) as text_inputs_group:
812
  text_files_input = gr.Files(
@@ -856,7 +889,7 @@ def main():
856
  scale=1,
857
  visible=True,
858
  )
859
-
860
  # Common settings
861
  with gr.Row():
862
  llama3_reorganize_model_repo = gr.Dropdown(
@@ -868,11 +901,11 @@ def main():
868
  with gr.Row():
869
  additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
870
  additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
871
-
872
  # Add the remove tags input box
873
  with gr.Row():
874
  tags_to_remove = gr.Text(label="Remove tags (comma split)")
875
-
876
  with gr.Row():
877
  clear = gr.ClearButton(
878
  components=[
@@ -895,13 +928,26 @@ def main():
895
 
896
  with gr.Column(variant="panel"):
897
  download_file = gr.File(label="Output (Download)")
898
- sorted_general_strings = gr.Textbox(label="Output (string)", show_label=True, show_copy_button=True, lines=5)
899
- # Image-specific outputs
900
- categorized = gr.JSON(label="Categorized (tags)", visible=True)
901
- rating = gr.Label(label="Rating", visible=True)
902
- character_res = gr.Label(label="Output (characters)", visible=True)
903
- general_res = gr.Label(label="Output (tags)", visible=True)
904
- unclassified = gr.JSON(label="Unclassified (tags)", visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
905
  clear.add(
906
  [
907
  download_file,
@@ -911,12 +957,13 @@ def main():
911
  character_res,
912
  general_res,
913
  unclassified,
 
914
  ]
915
  )
916
 
917
  tag_results = gr.State({})
918
  selected_image = gr.Textbox(label="Selected Image", visible=False)
919
-
920
  # Event Listeners
921
  # Define the event listener to add the uploaded image to the gallery
922
  image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
@@ -955,7 +1002,7 @@ def main():
955
  categorized, rating, character_res, general_res, unclassified
956
  ]
957
  )
958
-
959
  # submit click now calls the wrapper function
960
  submit.click(
961
  fn=run_prediction,
@@ -975,11 +1022,11 @@ def main():
975
  tags_to_remove,
976
  tag_results,
977
  ],
978
- outputs=[download_file, sorted_general_strings, categorized, rating, character_res, general_res, unclassified, tag_results,],
979
  )
980
-
981
  gr.Examples(
982
- [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
983
  inputs=[
984
  image_input,
985
  model_repo,
 
16
  from datetime import datetime
17
  from collections import defaultdict
18
  from classifyTags import classify_tags
19
+ from collections import Counter # Import Counter for statistics
20
 
21
  TITLE = "WaifuDiffusion Tagger multiple images/texts"
22
  DESCRIPTION = """
 
173
 
174
  Args:
175
  repoId: LLAMA model repo.
176
+ device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
177
  ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia).
178
  localFilesOnly: If True, avoid downloading the file and return the path to the
179
  local cached file if it exists.
 
265
  except Exception as e:
266
  self.release_vram()
267
  raise e
268
+
269
 
270
  def release_vram(self):
271
  try:
 
349
  canvas = Image.new("RGBA", image.size, (255, 255, 255))
350
  canvas.alpha_composite(image)
351
  image = canvas.convert("RGB")
352
+
353
  # Pad image to square
354
  image_shape = image.size
355
  max_dim = max(image_shape)
 
358
 
359
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
360
  padded_image.paste(image, (pad_left, pad_top))
361
+
362
  # Resize
363
  if max_dim != self.model_target_size:
364
  padded_image = padded_image.resize(
365
  (self.model_target_size, self.model_target_size),
366
  Image.BICUBIC,
367
  )
368
+
369
  # Convert to numpy array
370
  image_array = np.asarray(padded_image, dtype=np.float32)
371
 
 
399
  ):
400
  if not gallery:
401
  gr.Warning("No images in the gallery to process.")
402
+ return None, "", "{}", "", "", "", "{}", {}, ""
403
+
404
  gallery_len = len(gallery)
405
  print(f"Predict from images: load model: {model_repo}, gallery length: {gallery_len}")
406
+
407
  timer = Timer() # Create a timer
408
  progressRatio = 0.5 if llama3_reorganize_model_repo else 1
409
  progressTotal = gallery_len + (1 if llama3_reorganize_model_repo else 0) + 1 # +1 for model load
 
417
  # Result
418
  txt_infos = []
419
  output_dir = tempfile.mkdtemp()
420
+
421
  last_sorted_general_strings = ""
422
  last_classified_tags, last_unclassified_tags = {}, {}
423
  last_rating, last_character_res, last_general_res = None, None, None
424
 
425
+ # Initialize counter for statistics
426
+ tag_counter = Counter()
427
+
428
  llama3_reorganize = None
429
  if llama3_reorganize_model_repo:
430
  print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
 
432
  current_progress += 1 / progressTotal
433
  progress(current_progress, desc="Initialize llama3 model finished")
434
  timer.checkpoint(f"Initialize llama3 model")
435
+
436
  timer.report()
437
 
438
  prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
 
461
  preds = self.model.run([label_name], {input_name: image})[0]
462
 
463
  labels = list(zip(self.tag_names, preds[0].astype(float)))
464
+
465
  # First 4 labels are actually ratings: pick one with argmax
466
  ratings_names = [labels[i] for i in self.rating_indexes]
467
  rating = dict(ratings_names)
468
+
469
  # Then we have general tags: pick any where prediction confidence > threshold
470
  general_names = [labels[i] for i in self.general_indexes]
471
 
 
473
  general_probs = np.array([x[1] for x in general_names])
474
  general_thresh = mcut_threshold(general_probs)
475
  general_res = dict([x for x in general_names if x[1] > general_thresh])
476
+
477
  # Everything else is characters: pick any where prediction confidence > threshold
478
  character_names = [labels[i] for i in self.character_indexes]
479
 
 
497
  final_tags_list = prepend_list + sorted_general_list + append_list
498
  if characters_merge_enabled:
499
  final_tags_list = character_list + final_tags_list
500
+
501
  # Apply removal logic
502
  if remove_list:
503
  remove_set = set(remove_list)
504
  final_tags_list = [tag for tag in final_tags_list if tag not in remove_set]
505
 
506
+ # Update counter with the final list of tags for this image
507
+ tag_counter.update(final_tags_list)
508
+
509
  sorted_general_strings = ", ".join(final_tags_list).replace("(", "\(").replace(")", "\)")
510
  classified_tags, unclassified_tags = classify_tags(final_tags_list)
511
 
512
  current_progress += progressRatio / progressTotal
513
  progress(current_progress, desc=f"Image {idx+1}/{gallery_len}, predict finished")
514
  timer.checkpoint(f"Image {idx+1}/{gallery_len}, predict finished")
515
+
516
  if llama3_reorganize:
517
  print(f"Starting reorganize with llama3...")
518
  reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
 
530
  txt_infos.append({"path": txt_file, "name": image_name + ".txt"})
531
 
532
  tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
533
+
534
  # Store last result for UI display
535
  last_sorted_general_strings = sorted_general_strings
536
  last_classified_tags = classified_tags
 
544
  print(traceback.format_exc())
545
  print("Error predicting image: " + str(e))
546
  gr.Warning(f"Failed to process image {os.path.basename(value[0])}. Error: {e}")
547
+
548
  # Result
549
  download = []
550
  if txt_infos:
 
555
  # Get file name from lookup
556
  taggers_zip.write(info["path"], arcname=info["name"])
557
  download.append(downloadZipPath)
558
+
559
  if llama3_reorganize:
560
  llama3_reorganize.release_vram()
561
 
562
  progress(1, desc="Image processing completed")
563
  timer.report_all()
564
  print("Image prediction is complete.")
565
+
566
+ # Format statistics for output
567
+ stats_list = [f"{tag}: {count}" for tag, count in tag_counter.most_common()]
568
+ statistics_output = "\n".join(stats_list)
569
+
570
+ return download, last_sorted_general_strings, last_classified_tags, last_rating, last_character_res, last_general_res, last_unclassified_tags, tag_results, statistics_output
571
 
 
 
572
  # Method to process text files
573
  def predict_from_text(
574
  self,
 
581
  ):
582
  if not text_files:
583
  gr.Warning("No text files uploaded to process.")
584
+ return None, "", "{}", "", "", "", "{}", {}, ""
585
 
586
  files_len = len(text_files)
587
  print(f"Predict from text: processing {files_len} files.")
 
590
  progressRatio = 0.5 if llama3_reorganize_model_repo else 1.0
591
  progressTotal = files_len + (1 if llama3_reorganize_model_repo else 0)
592
  current_progress = 0
593
+
594
  txt_infos = []
595
  output_dir = tempfile.mkdtemp()
596
  last_processed_string = ""
597
+
598
+ # Initialize counter for statistics
599
+ tag_counter = Counter()
600
 
601
  llama3_reorganize = None
602
  if llama3_reorganize_model_repo:
 
605
  current_progress += 1 / progressTotal
606
  progress(current_progress, desc="Initialize llama3 model finished")
607
  timer.checkpoint(f"Initialize llama3 model")
608
+
609
  timer.report()
610
 
611
  prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
 
613
  remove_list = [tag.strip() for tag in tags_to_remove.split(",") if tag.strip()] # Parse remove tags
614
  if prepend_list and append_list:
615
  append_list = [item for item in append_list if item not in prepend_list]
616
+
617
  name_counters = defaultdict(int)
618
  for idx, file_obj in enumerate(text_files):
619
  try:
 
628
 
629
  with open(file_path, 'r', encoding='utf-8') as f:
630
  original_content = f.read()
631
+
632
  # Process tags
633
  tags_list = [tag.strip() for tag in original_content.split(',') if tag.strip()]
634
 
 
643
  if remove_list:
644
  remove_set = set(remove_list)
645
  final_tags_list = [tag for tag in final_tags_list if tag not in remove_set]
646
+
647
+ # Update counter with the final list of tags for this file
648
+ tag_counter.update(final_tags_list)
649
 
650
  processed_string = ", ".join(final_tags_list)
651
+
652
  current_progress += progressRatio / progressTotal
653
  progress(current_progress, desc=f"File {idx+1}/{files_len}, base processing finished")
654
  timer.checkpoint(f"File {idx+1}/{files_len}, base processing finished")
 
661
  reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
662
  reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
663
  processed_string += "," + reorganize_strings
664
+
665
  current_progress += progressRatio / progressTotal
666
  progress(current_progress, desc=f"File {idx+1}/{files_len}, llama3 reorganize finished")
667
  timer.checkpoint(f"File {idx+1}/{files_len}, llama3 reorganize finished")
668
+
669
  txt_file_path = self.create_file(processed_string, output_dir, output_file_name)
670
  txt_infos.append({"path": txt_file_path, "name": output_file_name})
671
  last_processed_string = processed_string
672
  timer.report()
673
+
674
  except Exception as e:
675
  print(traceback.format_exc())
676
  print("Error processing text file: " + str(e))
 
692
  timer.report_all() # Print all recorded times
693
  print("Text processing is complete.")
694
 
695
+ # Format statistics for output
696
+ stats_list = [f"{tag}: {count}" for tag, count in tag_counter.most_common()]
697
+ statistics_output = "\n".join(stats_list)
698
+
699
  # Return values in the same structure as the image path, with placeholders for unused outputs
700
+ return download, last_processed_string, "{}", "", "", "", "{}", {}, statistics_output
701
 
702
  def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
703
  if not selected_state:
 
724
  gallery = []
725
  if not images:
726
  return gallery
727
+
728
  # Combine the new images with the existing gallery images
729
  gallery.extend(images)
730
 
 
753
  text-align: left !important;
754
  width: 55.5% !important;
755
  }
756
+ textarea[rows]:not([rows="1"]) {
757
+ overflow-y: auto !important;
758
+ scrollbar-width: thin !important;
759
+ }
760
+ textarea[rows]:not([rows="1"])::-webkit-scrollbar {
761
+ all: initial !important;
762
+ background: #f1f1f1 !important;
763
+ }
764
+ textarea[rows]:not([rows="1"])::-webkit-scrollbar-thumb {
765
+ all: initial !important;
766
+ background: #a8a8a8 !important;
767
+ }
768
  """
769
  args = parse_args()
770
 
 
791
  META_LLAMA_3_3B_REPO,
792
  META_LLAMA_3_8B_REPO,
793
  ]
794
+
795
  # Wrapper function to decide which prediction method to call
796
  def run_prediction(
797
  input_type, gallery, text_files, model_repo, general_thresh,
 
818
  with gr.Blocks(title=TITLE, css=css) as demo:
819
  gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
820
  gr.Markdown(value=DESCRIPTION)
821
+
822
  with gr.Row():
823
  with gr.Column():
824
  submit = gr.Button(value="Submit", variant="primary", size="lg")
825
+
826
  # Input type selector
827
  input_type_radio = gr.Radio(
828
+ choices=['Image', 'Text file (.txt)'],
829
+ value='Image',
830
  label="Input Type"
831
  )
832
+
833
  # Group for image inputs, initially visible
834
  with gr.Column(visible=True) as image_inputs_group:
835
  with gr.Column(variant="panel"):
 
839
  upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
840
  remove_button = gr.Button("Remove Selected Image", size="sm")
841
  gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
842
+
843
  # Group for text file inputs, initially hidden
844
  with gr.Column(visible=False) as text_inputs_group:
845
  text_files_input = gr.Files(
 
889
  scale=1,
890
  visible=True,
891
  )
892
+
893
  # Common settings
894
  with gr.Row():
895
  llama3_reorganize_model_repo = gr.Dropdown(
 
901
  with gr.Row():
902
  additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
903
  additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
904
+
905
  # Add the remove tags input box
906
  with gr.Row():
907
  tags_to_remove = gr.Text(label="Remove tags (comma split)")
908
+
909
  with gr.Row():
910
  clear = gr.ClearButton(
911
  components=[
 
928
 
929
  with gr.Column(variant="panel"):
930
  download_file = gr.File(label="Output (Download)")
931
+ sorted_general_strings = gr.Textbox(label="Output (string for last processed item)", show_label=True, show_copy_button=True, lines=5)
932
+
933
+ with gr.Accordion("Categorized (tags)", open=False):
934
+ categorized = gr.JSON(label="Categorized")
935
+
936
+ with gr.Accordion("Detailed Output (for last processed item)", open=False):
937
+ rating = gr.Label(label="Rating", visible=True)
938
+ character_res = gr.Label(label="Output (characters)", visible=True)
939
+ general_res = gr.Label(label="Output (tags)", visible=True)
940
+ unclassified = gr.JSON(label="Unclassified (tags)", visible=True)
941
+
942
+ with gr.Accordion("Tags Statistics (All files)", open=False):
943
+ tags_statistics = gr.Text(
944
+ label="Statistics",
945
+ autoscroll=False,
946
+ show_label=False,
947
+ show_copy_button=True,
948
+ lines=10,
949
+ )
950
+
951
  clear.add(
952
  [
953
  download_file,
 
957
  character_res,
958
  general_res,
959
  unclassified,
960
+ tags_statistics,
961
  ]
962
  )
963
 
964
  tag_results = gr.State({})
965
  selected_image = gr.Textbox(label="Selected Image", visible=False)
966
+
967
  # Event Listeners
968
  # Define the event listener to add the uploaded image to the gallery
969
  image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
 
1002
  categorized, rating, character_res, general_res, unclassified
1003
  ]
1004
  )
1005
+
1006
  # submit click now calls the wrapper function
1007
  submit.click(
1008
  fn=run_prediction,
 
1022
  tags_to_remove,
1023
  tag_results,
1024
  ],
1025
+ outputs=[download_file, sorted_general_strings, categorized, rating, character_res, general_res, unclassified, tag_results, tags_statistics],
1026
  )
1027
+
1028
  gr.Examples(
1029
+ [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
1030
  inputs=[
1031
  image_input,
1032
  model_repo,