Spaces:
Running
Running
Commit
·
9ee88e4
1
Parent(s):
104e60a
Update: switch to new models with character support
Browse files
app.py
CHANGED
|
@@ -20,7 +20,12 @@ from Utils import dbimutils
|
|
| 20 |
|
| 21 |
TITLE = "WaifuDiffusion v1.4 Tags"
|
| 22 |
DESCRIPTION = """
|
| 23 |
-
Demo for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
|
| 26 |
Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
|
|
@@ -31,8 +36,9 @@ Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
|
|
| 31 |
"""
|
| 32 |
|
| 33 |
HF_TOKEN = os.environ["HF_TOKEN"]
|
| 34 |
-
|
| 35 |
-
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger"
|
|
|
|
| 36 |
MODEL_FILENAME = "model.onnx"
|
| 37 |
LABEL_FILENAME = "selected_tags.csv"
|
| 38 |
|
|
@@ -40,7 +46,8 @@ LABEL_FILENAME = "selected_tags.csv"
|
|
| 40 |
def parse_args() -> argparse.Namespace:
|
| 41 |
parser = argparse.ArgumentParser()
|
| 42 |
parser.add_argument("--score-slider-step", type=float, default=0.05)
|
| 43 |
-
parser.add_argument("--score-threshold", type=float, default=0.35)
|
|
|
|
| 44 |
parser.add_argument("--share", action="store_true")
|
| 45 |
return parser.parse_args()
|
| 46 |
|
|
@@ -53,12 +60,31 @@ def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
|
|
| 53 |
return model
|
| 54 |
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
def load_labels() -> list[str]:
|
| 57 |
path = huggingface_hub.hf_hub_download(
|
| 58 |
-
|
| 59 |
)
|
| 60 |
-
df = pd.read_csv(path)
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
|
| 64 |
def plaintext_to_html(text):
|
|
@@ -70,14 +96,22 @@ def plaintext_to_html(text):
|
|
| 70 |
|
| 71 |
def predict(
|
| 72 |
image: PIL.Image.Image,
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
| 77 |
):
|
|
|
|
|
|
|
| 78 |
rawimage = image
|
| 79 |
|
| 80 |
-
model =
|
|
|
|
|
|
|
|
|
|
| 81 |
_, height, width, _ = model.get_inputs()[0].shape
|
| 82 |
|
| 83 |
# Alpha to white
|
|
@@ -99,18 +133,23 @@ def predict(
|
|
| 99 |
label_name = model.get_outputs()[0].name
|
| 100 |
probs = model.run([label_name], {input_name: image})[0]
|
| 101 |
|
| 102 |
-
labels = list(zip(
|
| 103 |
|
| 104 |
# First 4 labels are actually ratings: pick one with argmax
|
| 105 |
-
ratings_names = labels[
|
| 106 |
rating = dict(ratings_names)
|
| 107 |
|
| 108 |
-
#
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
b = dict(sorted(
|
| 114 |
a = (
|
| 115 |
", ".join(list(b.keys()))
|
| 116 |
.replace("_", " ")
|
|
@@ -167,40 +206,57 @@ def predict(
|
|
| 167 |
message = "Nothing found in the image."
|
| 168 |
info = f"<div><p>{message}<p></div>"
|
| 169 |
|
| 170 |
-
return (a, c, rating,
|
| 171 |
|
| 172 |
|
| 173 |
def main():
|
|
|
|
|
|
|
|
|
|
| 174 |
args = parse_args()
|
| 175 |
-
vit_model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
|
| 176 |
-
conv_model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
|
| 177 |
-
labels = load_labels()
|
| 178 |
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
-
func = functools.partial(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
gr.Interface(
|
| 184 |
fn=func,
|
| 185 |
inputs=[
|
| 186 |
gr.Image(type="pil", label="Input"),
|
| 187 |
-
gr.Radio(["
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
gr.Slider(
|
| 189 |
0,
|
| 190 |
1,
|
| 191 |
step=args.score_slider_step,
|
| 192 |
-
value=args.
|
| 193 |
-
label="
|
| 194 |
),
|
| 195 |
],
|
| 196 |
outputs=[
|
| 197 |
gr.Textbox(label="Output (string)"),
|
| 198 |
gr.Textbox(label="Output (raw string)"),
|
| 199 |
gr.Label(label="Rating"),
|
| 200 |
-
gr.Label(label="Output (
|
|
|
|
| 201 |
gr.HTML(),
|
| 202 |
],
|
| 203 |
-
examples=[["power.jpg", "
|
| 204 |
title=TITLE,
|
| 205 |
description=DESCRIPTION,
|
| 206 |
allow_flagging="never",
|
|
|
|
| 20 |
|
| 21 |
TITLE = "WaifuDiffusion v1.4 Tags"
|
| 22 |
DESCRIPTION = """
|
| 23 |
+
Demo for:
|
| 24 |
+
- [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
|
| 25 |
+
- [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
|
| 26 |
+
- [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
|
| 27 |
+
|
| 28 |
+
Includes "ready to copy" prompt and a prompt analyzer.
|
| 29 |
|
| 30 |
Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
|
| 31 |
Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
|
|
|
|
| 36 |
"""
|
| 37 |
|
| 38 |
HF_TOKEN = os.environ["HF_TOKEN"]
|
| 39 |
+
SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
|
| 40 |
+
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
| 41 |
+
VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
|
| 42 |
MODEL_FILENAME = "model.onnx"
|
| 43 |
LABEL_FILENAME = "selected_tags.csv"
|
| 44 |
|
|
|
|
| 46 |
def parse_args() -> argparse.Namespace:
|
| 47 |
parser = argparse.ArgumentParser()
|
| 48 |
parser.add_argument("--score-slider-step", type=float, default=0.05)
|
| 49 |
+
parser.add_argument("--score-general-threshold", type=float, default=0.35)
|
| 50 |
+
parser.add_argument("--score-character-threshold", type=float, default=0.85)
|
| 51 |
parser.add_argument("--share", action="store_true")
|
| 52 |
return parser.parse_args()
|
| 53 |
|
|
|
|
| 60 |
return model
|
| 61 |
|
| 62 |
|
| 63 |
+
def change_model(model_name):
|
| 64 |
+
global loaded_models
|
| 65 |
+
|
| 66 |
+
if model_name == "SwinV2":
|
| 67 |
+
model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
|
| 68 |
+
elif model_name == "ConvNext":
|
| 69 |
+
model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
|
| 70 |
+
elif model_name == "ViT":
|
| 71 |
+
model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
|
| 72 |
+
|
| 73 |
+
loaded_models[model_name] = model
|
| 74 |
+
return loaded_models[model_name]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
def load_labels() -> list[str]:
|
| 78 |
path = huggingface_hub.hf_hub_download(
|
| 79 |
+
SWIN_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
|
| 80 |
)
|
| 81 |
+
df = pd.read_csv(path)
|
| 82 |
+
|
| 83 |
+
tag_names = df["name"].tolist()
|
| 84 |
+
rating_indexes = list(np.where(df["category"] == 9)[0])
|
| 85 |
+
general_indexes = list(np.where(df["category"] == 0)[0])
|
| 86 |
+
character_indexes = list(np.where(df["category"] == 4)[0])
|
| 87 |
+
return tag_names, rating_indexes, general_indexes, character_indexes
|
| 88 |
|
| 89 |
|
| 90 |
def plaintext_to_html(text):
|
|
|
|
| 96 |
|
| 97 |
def predict(
|
| 98 |
image: PIL.Image.Image,
|
| 99 |
+
model_name: str,
|
| 100 |
+
general_threshold: float,
|
| 101 |
+
character_threshold: float,
|
| 102 |
+
tag_names: list[str],
|
| 103 |
+
rating_indexes: list[np.int64],
|
| 104 |
+
general_indexes: list[np.int64],
|
| 105 |
+
character_indexes: list[np.int64],
|
| 106 |
):
|
| 107 |
+
global loaded_models
|
| 108 |
+
|
| 109 |
rawimage = image
|
| 110 |
|
| 111 |
+
model = loaded_models[model_name]
|
| 112 |
+
if model is None:
|
| 113 |
+
model = change_model(model_name)
|
| 114 |
+
|
| 115 |
_, height, width, _ = model.get_inputs()[0].shape
|
| 116 |
|
| 117 |
# Alpha to white
|
|
|
|
| 133 |
label_name = model.get_outputs()[0].name
|
| 134 |
probs = model.run([label_name], {input_name: image})[0]
|
| 135 |
|
| 136 |
+
labels = list(zip(tag_names, probs[0].astype(float)))
|
| 137 |
|
| 138 |
# First 4 labels are actually ratings: pick one with argmax
|
| 139 |
+
ratings_names = [labels[i] for i in rating_indexes]
|
| 140 |
rating = dict(ratings_names)
|
| 141 |
|
| 142 |
+
# Then we have general tags: pick any where prediction confidence > threshold
|
| 143 |
+
general_names = [labels[i] for i in general_indexes]
|
| 144 |
+
general_res = [x for x in general_names if x[1] > general_threshold]
|
| 145 |
+
general_res = dict(general_res)
|
| 146 |
+
|
| 147 |
+
# Everything else is characters: pick any where prediction confidence > threshold
|
| 148 |
+
character_names = [labels[i] for i in character_indexes]
|
| 149 |
+
character_res = [x for x in character_names if x[1] > character_threshold]
|
| 150 |
+
character_res = dict(character_res)
|
| 151 |
|
| 152 |
+
b = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True))
|
| 153 |
a = (
|
| 154 |
", ".join(list(b.keys()))
|
| 155 |
.replace("_", " ")
|
|
|
|
| 206 |
message = "Nothing found in the image."
|
| 207 |
info = f"<div><p>{message}<p></div>"
|
| 208 |
|
| 209 |
+
return (a, c, rating, character_res, general_res, info)
|
| 210 |
|
| 211 |
|
| 212 |
def main():
|
| 213 |
+
global loaded_models
|
| 214 |
+
loaded_models = {"SwinV2": None, "ConvNext": None, "ViT": None}
|
| 215 |
+
|
| 216 |
args = parse_args()
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
+
swin_model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
|
| 219 |
+
loaded_models["SwinV2"] = swin_model
|
| 220 |
+
|
| 221 |
+
tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
|
| 222 |
|
| 223 |
+
func = functools.partial(
|
| 224 |
+
predict,
|
| 225 |
+
tag_names=tag_names,
|
| 226 |
+
rating_indexes=rating_indexes,
|
| 227 |
+
general_indexes=general_indexes,
|
| 228 |
+
character_indexes=character_indexes,
|
| 229 |
+
)
|
| 230 |
|
| 231 |
gr.Interface(
|
| 232 |
fn=func,
|
| 233 |
inputs=[
|
| 234 |
gr.Image(type="pil", label="Input"),
|
| 235 |
+
gr.Radio(["SwinV2", "ConvNext", "ViT"], value="SwinV2", label="Model"),
|
| 236 |
+
gr.Slider(
|
| 237 |
+
0,
|
| 238 |
+
1,
|
| 239 |
+
step=args.score_slider_step,
|
| 240 |
+
value=args.score_general_threshold,
|
| 241 |
+
label="General Tags Threshold",
|
| 242 |
+
),
|
| 243 |
gr.Slider(
|
| 244 |
0,
|
| 245 |
1,
|
| 246 |
step=args.score_slider_step,
|
| 247 |
+
value=args.score_character_threshold,
|
| 248 |
+
label="Character Tags Threshold",
|
| 249 |
),
|
| 250 |
],
|
| 251 |
outputs=[
|
| 252 |
gr.Textbox(label="Output (string)"),
|
| 253 |
gr.Textbox(label="Output (raw string)"),
|
| 254 |
gr.Label(label="Rating"),
|
| 255 |
+
gr.Label(label="Output (characters)"),
|
| 256 |
+
gr.Label(label="Output (tags)"),
|
| 257 |
gr.HTML(),
|
| 258 |
],
|
| 259 |
+
examples=[["power.jpg", "SwinV2", 0.5]],
|
| 260 |
title=TITLE,
|
| 261 |
description=DESCRIPTION,
|
| 262 |
allow_flagging="never",
|