Make sure to always load the highest trained safetensors file for all cases (#36)
Browse files- Make sure to always load the highest trained safetensors file for all cases (c576d4ca850474f1cbb84f4ca2ff8ff449ea1f68)
Co-authored-by: Sylvain Filoni <[email protected]>
app.py
CHANGED
|
@@ -12,6 +12,7 @@ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_d
|
|
| 12 |
import copy
|
| 13 |
import random
|
| 14 |
import time
|
|
|
|
| 15 |
|
| 16 |
# Load LoRAs from JSON file
|
| 17 |
with open('loras.json', 'r') as f:
|
|
@@ -172,30 +173,73 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
|
|
| 172 |
yield final_image, seed, gr.update(value=progress_bar, visible=False)
|
| 173 |
|
| 174 |
def get_huggingface_safetensors(link):
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
def check_custom_model(link):
|
| 201 |
if(link.startswith("https://")):
|
|
|
|
| 12 |
import copy
|
| 13 |
import random
|
| 14 |
import time
|
| 15 |
+
import re
|
| 16 |
|
| 17 |
# Load LoRAs from JSON file
|
| 18 |
with open('loras.json', 'r') as f:
|
|
|
|
| 173 |
yield final_image, seed, gr.update(value=progress_bar, visible=False)
|
| 174 |
|
| 175 |
def get_huggingface_safetensors(link):
|
| 176 |
+
split_link = link.split("/")
|
| 177 |
+
if len(split_link) != 2:
|
| 178 |
+
raise Exception("Invalid Hugging Face repository link format.")
|
| 179 |
+
|
| 180 |
+
# Load model card
|
| 181 |
+
model_card = ModelCard.load(link)
|
| 182 |
+
base_model = model_card.data.get("base_model")
|
| 183 |
+
print(base_model)
|
| 184 |
+
|
| 185 |
+
# Validate model type
|
| 186 |
+
if base_model not in {"black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"}:
|
| 187 |
+
raise Exception("Not a FLUX LoRA!")
|
| 188 |
+
|
| 189 |
+
# Extract image and trigger word
|
| 190 |
+
image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
|
| 191 |
+
trigger_word = model_card.data.get("instance_prompt", "")
|
| 192 |
+
image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
|
| 193 |
+
|
| 194 |
+
# Initialize Hugging Face file system
|
| 195 |
+
fs = HfFileSystem()
|
| 196 |
+
try:
|
| 197 |
+
list_of_files = fs.ls(link, detail=False)
|
| 198 |
+
|
| 199 |
+
# Initialize variables for safetensors selection
|
| 200 |
+
safetensors_name = None
|
| 201 |
+
highest_trained_file = None
|
| 202 |
+
highest_steps = -1
|
| 203 |
+
last_safetensors_file = None
|
| 204 |
+
step_pattern = re.compile(r"_0{3,}\d+") # Detects step count `_000...`
|
| 205 |
+
|
| 206 |
+
for file in list_of_files:
|
| 207 |
+
filename = file.split("/")[-1]
|
| 208 |
+
|
| 209 |
+
# Select safetensors file
|
| 210 |
+
if filename.endswith(".safetensors"):
|
| 211 |
+
last_safetensors_file = filename # Track last encountered file
|
| 212 |
+
|
| 213 |
+
match = step_pattern.search(filename)
|
| 214 |
+
if not match:
|
| 215 |
+
# Found a full model without step numbers, return immediately
|
| 216 |
+
safetensors_name = filename
|
| 217 |
+
break
|
| 218 |
+
else:
|
| 219 |
+
# Extract step count and track highest
|
| 220 |
+
steps = int(match.group().lstrip("_"))
|
| 221 |
+
if steps > highest_steps:
|
| 222 |
+
highest_trained_file = filename
|
| 223 |
+
highest_steps = steps
|
| 224 |
+
|
| 225 |
+
# Select an image file if not found in model card
|
| 226 |
+
if not image_url and filename.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
|
| 227 |
+
image_url = f"https://huggingface.co/{link}/resolve/main/{filename}"
|
| 228 |
+
|
| 229 |
+
# If no full model found, fall back to the most trained safetensors file
|
| 230 |
+
if not safetensors_name:
|
| 231 |
+
safetensors_name = highest_trained_file if highest_trained_file else last_safetensors_file
|
| 232 |
+
|
| 233 |
+
# If still no safetensors file found, raise an exception
|
| 234 |
+
if not safetensors_name:
|
| 235 |
+
raise Exception("No valid *.safetensors file found in the repository.")
|
| 236 |
+
|
| 237 |
+
except Exception as e:
|
| 238 |
+
print(e)
|
| 239 |
+
raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")
|
| 240 |
+
|
| 241 |
+
return split_link[1], link, safetensors_name, trigger_word, image_url
|
| 242 |
+
|
| 243 |
|
| 244 |
def check_custom_model(link):
|
| 245 |
if(link.startswith("https://")):
|