# abdelrhman145's picture
# Update app.py
# 872c147 verified
# app.py
import torch
import torch.nn.functional as F
import timm
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from facenet_pytorch import MTCNN
import gradio as gr
from torchvision import transforms
# ----------------- CONFIG -----------------
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Checkpoint file holding the classifier weights.
MODEL_PATH = "best_model_stage3.pth"

# Side length (pixels) of the square face crop fed to the classifier.
IMG_SIZE = 224

# Emotion labels; position i names the class for classifier output index i.
CLASS_NAMES = [
    "Anger",
    "Contempt",
    "Disgust",
    "Fear",
    "Happy",
    "Neutral",
    "Sad",
    "Surprise",
]
# ----------------- LOAD MODEL -----------------
def load_model():
    """Build the EfficientNet-B4 classifier and load its weights from MODEL_PATH.

    The checkpoint may be a bare state dict or a dict that wraps one under
    the ``'model_state_dict'`` key.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    model = timm.create_model(
        "efficientnet_b4", pretrained=False, num_classes=len(CLASS_NAMES)
    )

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point MODEL_PATH at checkpoints you trust.
    state = torch.load(MODEL_PATH, map_location="cpu")
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]

    # strict=False tolerates key mismatches, but it does so silently: any
    # layer without a matching key keeps its random initialisation
    # (pretrained=False above). Surface that instead of hiding it.
    load_result = model.load_state_dict(state, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Warning: checkpoint {MODEL_PATH!r} does not fully match the model: "
            f"{len(load_result.missing_keys)} missing key(s), "
            f"{len(load_result.unexpected_keys)} unexpected key(s)."
        )

    model.to(DEVICE)
    model.eval()
    return model
# Load the classifier once at import time; predict_on_image reuses it.
model = load_model()

# ----------------- FACE DETECTOR -----------------
mtcnn = MTCNN(device=DEVICE, keep_all=True)

# ----------------- TRANSFORMS -----------------
# Per-channel statistics applied by Normalize after ToTensor.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Resize the face crop to the classifier's input size, convert it to a
# tensor, then normalise each channel.
_preprocess_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
# ----------------- FUNCTIONS -----------------
def draw_results_on_image(pil_img, boxes, preds, probs):
    """Return a copy of ``pil_img`` with a labelled box drawn for each face.

    Args:
        pil_img: PIL image to annotate. It is copied, not modified.
        boxes: Sequence of ``(x1, y1, x2, y2)`` boxes in pixel coordinates.
        preds: Label text for each box, in the same order as ``boxes``.
        probs: Probability for each box, in the same order as ``boxes``.

    Returns:
        The annotated copy of the image.
    """
    annotated = pil_img.copy()
    draw = ImageDraw.Draw(annotated)
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=18)
    except (OSError, ImportError):
        # Font file not found or FreeType support unavailable: fall back to
        # Pillow's built-in font rather than failing the whole request.
        font = ImageFont.load_default()

    label_height = 25  # height of the filled strip behind the label text

    for box, label, prob in zip(boxes, preds, probs):
        x1, y1, x2, y2 = (int(v) for v in box)
        draw.rectangle([x1, y1, x2, y2], outline="lime", width=3)

        text = f"{label} {prob:.2f}"
        left, _, right, _ = draw.textbbox((x1, y1), text, font=font)
        text_width = right - left

        # The label normally sits just above the box. When the box touches
        # the top (or left) edge that position is off-canvas and the label
        # would be invisible, so move it inside the image instead.
        label_left = max(x1, 0)
        if y1 >= label_height:
            label_top = y1 - label_height
        else:
            label_top = max(y1, 0)

        draw.rectangle(
            [label_left, label_top, label_left + text_width + 6, label_top + label_height],
            fill="lime",
        )
        draw.text((label_left + 3, label_top + 3), text, fill="black", font=font)

    return annotated
def _clip_box(box, width, height):
    """Clamp a box to the image bounds; return None if no area remains."""
    x1, y1, x2, y2 = (int(v) for v in box)
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(width, x2), min(height, y2)
    if x2 <= x1 or y2 <= y1:
        return None
    return x1, y1, x2, y2


def predict_on_image(image):
    """Detect faces in ``image``, classify each one and return an annotated image.

    Args:
        image: A PIL image or a NumPy array, or ``None`` when nothing was
            uploaded.

    Returns:
        A PIL image with a labelled box per face, or with the text
        "No faces detected" when no usable face is found. Returns ``None``
        when ``image`` is ``None`` or when processing fails (the error is
        printed).
    """
    import traceback

    if image is None:
        return None

    try:
        if isinstance(image, np.ndarray):
            pil_img = Image.fromarray(np.uint8(image)).convert("RGB")
        else:
            pil_img = image.convert("RGB")

        boxes, _ = mtcnn.detect(pil_img)
        width, height = pil_img.size

        # Keep only boxes that still cover some pixels after clamping to the
        # image. A zero-area crop would make the resize step raise and abort
        # the request for every face, not just the bad one.
        usable = []
        if boxes is not None:
            for box in boxes:
                clipped = _clip_box(box, width, height)
                if clipped is not None:
                    usable.append((box, clipped))

        if not usable:
            draw = ImageDraw.Draw(pil_img)
            try:
                font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=18)
            except (OSError, ImportError):
                # Font file not found or FreeType support unavailable.
                font = ImageFont.load_default()
            draw.text((10, 10), "No faces detected", fill="red", font=font)
            return pil_img

        kept_boxes, preds, probs = [], [], []
        for box, clipped in usable:
            face = pil_img.crop(clipped)
            img_t = preprocess(face).unsqueeze(0).to(DEVICE)
            with torch.no_grad():
                logits = model(img_t)
            probs_t = F.softmax(logits, dim=1).cpu().numpy()[0]
            top_idx = int(np.argmax(probs_t))

            # Draw with the detector's original box so the annotation is
            # unchanged; all three lists stay aligned with each other.
            kept_boxes.append(box)
            preds.append(CLASS_NAMES[top_idx])
            probs.append(float(probs_t[top_idx]))

        return draw_results_on_image(pil_img, kept_boxes, preds, probs)
    except Exception as e:
        # UI boundary: report the failure with its traceback and return an
        # empty result rather than letting the exception escape the handler.
        print("Error:", e)
        traceback.print_exc()
        return None
# ----------------- GRADIO UI -----------------
# Text shown at the top of the Gradio page.
title = "Face Emotion Recognition (EfficientNet-B4)"
description = """
Upload an image and the model will **detect faces**, classify their **emotions**,
and draw bounding boxes with labels.
**Model:** EfficientNet-B4 fine-tuned with Focal Loss, MixUp & CutMix to reduce overfitting.
"""

# Input and output widgets both exchange PIL images with predict_on_image.
_image_in = gr.Image(type="pil", label="Upload an Image")
_image_out = gr.Image(type="pil", label="Detected Emotions")

demo = gr.Interface(
    fn=predict_on_image,
    inputs=_image_in,
    outputs=_image_out,
    title=title,
    description=description,
    allow_flagging="never",
)

if __name__ == "__main__":
    # Start the web server only when the file is run as a script.
    demo.launch(server_port=7860, server_name="0.0.0.0")