Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,15 +1,22 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
|
| 4 |
-
import os
|
| 5 |
-
|
| 6 |
-
import
|
|
|
|
|
|
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
-
import urllib.request
|
| 9 |
import uuid
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
#
|
| 13 |
models = [
|
| 14 |
"umm-maybe/AI-image-detector",
|
| 15 |
"Organika/sdxl-detector",
|
|
@@ -21,11 +28,78 @@ pipe1 = pipeline("image-classification", f"{models[1]}")
|
|
| 21 |
pipe2 = pipeline("image-classification", f"{models[2]}")
|
| 22 |
|
| 23 |
fin_sum = []
|
|
|
|
| 24 |
|
|
|
|
| 25 |
def softmax(vector):
|
| 26 |
e = exp(vector - vector.max()) # for numerical stability
|
| 27 |
return e / e.sum()
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def image_classifier0(image):
|
| 30 |
labels = ["AI", "Real"]
|
| 31 |
outputs = pipe0(image)
|
|
@@ -70,8 +144,8 @@ def aiornot0(image):
|
|
| 70 |
html_out = f"""
|
| 71 |
<h1>This image is likely: {label}</h1><br><h3>
|
| 72 |
Probabilities:<br>
|
| 73 |
-
Real: {float(px[1][0])}<br>
|
| 74 |
-
AI: {float(px[0][0])}"""
|
| 75 |
|
| 76 |
results = {
|
| 77 |
"Real": float(px[1][0]),
|
|
@@ -97,8 +171,8 @@ def aiornot1(image):
|
|
| 97 |
html_out = f"""
|
| 98 |
<h1>This image is likely: {label}</h1><br><h3>
|
| 99 |
Probabilities:<br>
|
| 100 |
-
Real: {float(px[1][0])}<br>
|
| 101 |
-
AI: {float(px[0][0])}"""
|
| 102 |
|
| 103 |
results = {
|
| 104 |
"Real": float(px[1][0]),
|
|
@@ -124,8 +198,8 @@ def aiornot2(image):
|
|
| 124 |
html_out = f"""
|
| 125 |
<h1>This image is likely: {label}</h1><br><h3>
|
| 126 |
Probabilities:<br>
|
| 127 |
-
Real: {float(px[1][0])}<br>
|
| 128 |
-
AI: {float(px[0][0])}"""
|
| 129 |
|
| 130 |
results = {
|
| 131 |
"Real": float(px[1][0]),
|
|
@@ -149,8 +223,8 @@ def tot_prob():
|
|
| 149 |
fin_out = sum([result["Real"] for result in fin_sum]) / len(fin_sum)
|
| 150 |
fin_sub = 1 - fin_out
|
| 151 |
out = {
|
| 152 |
-
"Real": f"{fin_out}",
|
| 153 |
-
"AI": f"{fin_sub}"
|
| 154 |
}
|
| 155 |
return out
|
| 156 |
except Exception as e:
|
|
@@ -167,50 +241,56 @@ def upd(image):
|
|
| 167 |
out = Image.open(f"{rand_im}-vid_tmp_proc.png")
|
| 168 |
return out
|
| 169 |
|
|
|
|
| 170 |
with gr.Blocks() as app:
|
| 171 |
-
gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)""")
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
with gr.
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
|
| 4 |
+
import os
|
| 5 |
+
import zipfile
|
| 6 |
+
import shutil
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
from PIL import Image
|
|
|
|
| 11 |
import uuid
|
| 12 |
+
import tempfile
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from numpy import exp
|
| 15 |
+
import numpy as np
|
| 16 |
+
from sklearn.metrics import ConfusionMatrixDisplay
|
| 17 |
+
import urllib.request
|
| 18 |
|
| 19 |
+
# Define models
|
| 20 |
models = [
|
| 21 |
"umm-maybe/AI-image-detector",
|
| 22 |
"Organika/sdxl-detector",
|
|
|
|
| 28 |
pipe2 = pipeline("image-classification", f"{models[2]}")
|
| 29 |
|
| 30 |
fin_sum = []
|
| 31 |
+
uid = uuid.uuid4()
|
| 32 |
|
| 33 |
+
# Softmax function
|
| 34 |
def softmax(vector):
|
| 35 |
e = exp(vector - vector.max()) # for numerical stability
|
| 36 |
return e / e.sum()
|
| 37 |
|
| 38 |
+
# Function to extract images from zip
|
| 39 |
+
def extract_zip(zip_file):
|
| 40 |
+
temp_dir = tempfile.mkdtemp() # Temporary directory
|
| 41 |
+
with zipfile.ZipFile(zip_file, 'r') as z:
|
| 42 |
+
z.extractall(temp_dir)
|
| 43 |
+
return temp_dir
|
| 44 |
+
|
| 45 |
+
# Function to classify images in a folder
|
| 46 |
+
def classify_images(image_dir, model_pipeline):
|
| 47 |
+
images = []
|
| 48 |
+
labels = []
|
| 49 |
+
preds = []
|
| 50 |
+
for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
|
| 51 |
+
folder_path = os.path.join(image_dir, folder_name)
|
| 52 |
+
if not os.path.exists(folder_path):
|
| 53 |
+
continue
|
| 54 |
+
for img_name in os.listdir(folder_path):
|
| 55 |
+
img_path = os.path.join(folder_path, img_name)
|
| 56 |
+
try:
|
| 57 |
+
img = Image.open(img_path).convert("RGB")
|
| 58 |
+
pred = model_pipeline(img)
|
| 59 |
+
pred_label = np.argmax([x['score'] for x in pred])
|
| 60 |
+
preds.append(pred_label)
|
| 61 |
+
labels.append(ground_truth_label)
|
| 62 |
+
images.append(img_name)
|
| 63 |
+
except Exception as e:
|
| 64 |
+
print(f"Error processing image {img_name}: {e}")
|
| 65 |
+
return labels, preds, images
|
| 66 |
+
|
| 67 |
+
# Function to generate evaluation metrics
|
| 68 |
+
def evaluate_model(labels, preds):
|
| 69 |
+
cm = confusion_matrix(labels, preds)
|
| 70 |
+
accuracy = accuracy_score(labels, preds)
|
| 71 |
+
roc_score = roc_auc_score(labels, preds)
|
| 72 |
+
report = classification_report(labels, preds)
|
| 73 |
+
fpr, tpr, _ = roc_curve(labels, preds)
|
| 74 |
+
roc_auc = auc(fpr, tpr)
|
| 75 |
+
|
| 76 |
+
fig, ax = plt.subplots()
|
| 77 |
+
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["AI", "Real"])
|
| 78 |
+
disp.plot(cmap=plt.cm.Blues, ax=ax)
|
| 79 |
+
plt.close(fig)
|
| 80 |
+
|
| 81 |
+
fig_roc, ax_roc = plt.subplots()
|
| 82 |
+
ax_roc.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
|
| 83 |
+
ax_roc.plot([0, 1], [0, 1], color='gray', linestyle='--')
|
| 84 |
+
ax_roc.set_xlim([0.0, 1.0])
|
| 85 |
+
ax_roc.set_ylim([0.0, 1.05])
|
| 86 |
+
ax_roc.set_xlabel('False Positive Rate')
|
| 87 |
+
ax_roc.set_ylabel('True Positive Rate')
|
| 88 |
+
ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
|
| 89 |
+
ax_roc.legend(loc="lower right")
|
| 90 |
+
plt.close(fig_roc)
|
| 91 |
+
|
| 92 |
+
return accuracy, roc_score, report, fig, fig_roc
|
| 93 |
+
|
| 94 |
+
# Gradio function for batch image processing
|
| 95 |
+
def process_zip(zip_file):
|
| 96 |
+
extracted_dir = extract_zip(zip_file.name)
|
| 97 |
+
labels, preds, images = classify_images(extracted_dir, pipe0) # You can switch to pipe1 or pipe2
|
| 98 |
+
accuracy, roc_score, report, cm_fig, roc_fig = evaluate_model(labels, preds)
|
| 99 |
+
shutil.rmtree(extracted_dir) # Clean up extracted files
|
| 100 |
+
return accuracy, roc_score, report, cm_fig, roc_fig
|
| 101 |
+
|
| 102 |
+
# Single image classification functions
|
| 103 |
def image_classifier0(image):
|
| 104 |
labels = ["AI", "Real"]
|
| 105 |
outputs = pipe0(image)
|
|
|
|
| 144 |
html_out = f"""
|
| 145 |
<h1>This image is likely: {label}</h1><br><h3>
|
| 146 |
Probabilities:<br>
|
| 147 |
+
Real: {float(px[1][0]):.4f}<br>
|
| 148 |
+
AI: {float(px[0][0]):.4f}"""
|
| 149 |
|
| 150 |
results = {
|
| 151 |
"Real": float(px[1][0]),
|
|
|
|
| 171 |
html_out = f"""
|
| 172 |
<h1>This image is likely: {label}</h1><br><h3>
|
| 173 |
Probabilities:<br>
|
| 174 |
+
Real: {float(px[1][0]):.4f}<br>
|
| 175 |
+
AI: {float(px[0][0]):.4f}"""
|
| 176 |
|
| 177 |
results = {
|
| 178 |
"Real": float(px[1][0]),
|
|
|
|
| 198 |
html_out = f"""
|
| 199 |
<h1>This image is likely: {label}</h1><br><h3>
|
| 200 |
Probabilities:<br>
|
| 201 |
+
Real: {float(px[1][0]):.4f}<br>
|
| 202 |
+
AI: {float(px[0][0]):.4f}"""
|
| 203 |
|
| 204 |
results = {
|
| 205 |
"Real": float(px[1][0]),
|
|
|
|
| 223 |
fin_out = sum([result["Real"] for result in fin_sum]) / len(fin_sum)
|
| 224 |
fin_sub = 1 - fin_out
|
| 225 |
out = {
|
| 226 |
+
"Real": f"{fin_out:.4f}",
|
| 227 |
+
"AI": f"{fin_sub:.4f}"
|
| 228 |
}
|
| 229 |
return out
|
| 230 |
except Exception as e:
|
|
|
|
| 241 |
out = Image.open(f"{rand_im}-vid_tmp_proc.png")
|
| 242 |
return out
|
| 243 |
|
| 244 |
+
# Set up Gradio app
|
| 245 |
with gr.Blocks() as app:
|
| 246 |
+
gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)</h4></h1></center>""")
|
| 247 |
+
|
| 248 |
+
with gr.Tabs():
|
| 249 |
+
# Tab for single image detection
|
| 250 |
+
with gr.Tab("Single Image Detection"):
|
| 251 |
+
with gr.Column():
|
| 252 |
+
inp = gr.Image(type='pil')
|
| 253 |
+
in_url = gr.Textbox(label="Image URL")
|
| 254 |
+
with gr.Row():
|
| 255 |
+
load_btn = gr.Button("Load URL")
|
| 256 |
+
btn = gr.Button("Detect AI")
|
| 257 |
+
mes = gr.HTML("""""")
|
| 258 |
+
|
| 259 |
+
with gr.Group():
|
| 260 |
+
with gr.Row():
|
| 261 |
+
fin = gr.Label(label="Final Probability")
|
| 262 |
+
with gr.Row():
|
| 263 |
+
for i, model in enumerate(models):
|
| 264 |
+
with gr.Box():
|
| 265 |
+
gr.HTML(f"""<b>Testing on Model {i}: <a href='https://huggingface.co/{model}'>{model}</a></b>""")
|
| 266 |
+
globals()[f'outp{i}'] = gr.HTML("""""")
|
| 267 |
+
globals()[f'n_out{i}'] = gr.Label(label="Output")
|
| 268 |
+
|
| 269 |
+
btn.click(fin_clear, None, fin, show_progress=False)
|
| 270 |
+
load_btn.click(load_url, in_url, [inp, mes])
|
| 271 |
+
|
| 272 |
+
btn.click(aiornot0, [inp], [outp0, n_out0]).then(
|
| 273 |
+
aiornot1, [inp], [outp1, n_out1]).then(
|
| 274 |
+
aiornot2, [inp], [outp2, n_out2]).then(
|
| 275 |
+
tot_prob, None, fin, show_progress=False)
|
| 276 |
+
|
| 277 |
+
btn.click(image_classifier0, [inp], [n_out0]).then(
|
| 278 |
+
image_classifier1, [inp], [n_out1]).then(
|
| 279 |
+
image_classifier2, [inp], [n_out2]).then(
|
| 280 |
+
tot_prob, None, fin, show_progress=False)
|
| 281 |
+
|
| 282 |
+
# Tab for batch processing
|
| 283 |
+
with gr.Tab("Batch Image Processing"):
|
| 284 |
+
zip_file = gr.File(label="Upload Zip (two folders: real, ai)")
|
| 285 |
+
output_acc = gr.Label(label="Accuracy")
|
| 286 |
+
output_roc = gr.Label(label="ROC Score")
|
| 287 |
+
output_report = gr.Textbox(label="Classification Report", lines=10)
|
| 288 |
+
output_cm = gr.Plot(label="Confusion Matrix")
|
| 289 |
+
output_roc_plot = gr.Plot(label="ROC Curve")
|
| 290 |
+
|
| 291 |
+
batch_btn = gr.Button("Process Batch")
|
| 292 |
+
|
| 293 |
+
# Connect batch processing
|
| 294 |
+
batch_btn.click(process_zip, zip_file, [output_acc, output_roc, output_report, output_cm, output_roc_plot])
|
| 295 |
+
|
| 296 |
+
app.launch(show_api=False, max_threads=24)
|