|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import timm |
|
|
import numpy as np |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
from facenet_pytorch import MTCNN |
|
|
import gradio as gr |
|
|
from torchvision import transforms |
|
|
|
|
|
|
|
|
# Compute device for the classifier and the face detector: the GPU when PyTorch
# can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Fine-tuned checkpoint read by load_model().
MODEL_PATH = "best_model_stage3.pth"

# Side length, in pixels, of the square input each face crop is resized to.
IMG_SIZE = 224

# Emotion labels indexed by classifier output position (CLASS_NAMES[i] names
# logit i). NOTE(review): must match the label order used at training time.
CLASS_NAMES = [
    "Anger",
    "Contempt",
    "Disgust",
    "Fear",
    "Happy",
    "Neutral",
    "Sad",
    "Surprise",
]
|
|
|
|
|
|
|
|
def load_model(model_path=None, device=None):
    """Build the EfficientNet-B4 emotion classifier and load fine-tuned weights.

    Args:
        model_path: Checkpoint file to load; defaults to ``MODEL_PATH``. The file
            may hold a bare ``state_dict`` or a dict wrapping one under
            ``'model_state_dict'`` (or ``'state_dict'``).
        device: Device to move the model to; defaults to ``DEVICE``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    # Resolve defaults at call time so later changes to the module constants apply.
    model_path = MODEL_PATH if model_path is None else model_path
    device = DEVICE if device is None else device

    model = timm.create_model("efficientnet_b4", pretrained=False, num_classes=len(CLASS_NAMES))

    # NOTE(review): torch.load unpickles the file - only point this at trusted checkpoints.
    state = torch.load(model_path, map_location="cpu")
    if isinstance(state, dict):
        # Training scripts often save {'model_state_dict': ..., 'optimizer...': ...}.
        for wrapper_key in ("model_state_dict", "state_dict"):
            if wrapper_key in state:
                state = state[wrapper_key]
                break

    # Weights saved from an nn.DataParallel / DistributedDataParallel wrapper carry a
    # "module." prefix on every key; strip it so the keys match the bare model.
    prefix = "module."
    if isinstance(state, dict) and state and all(
        isinstance(k, str) and k.startswith(prefix) for k in state
    ):
        state = {k[len(prefix):]: v for k, v in state.items()}

    # strict=False keeps loading tolerant of extra entries. Because the model is built
    # with pretrained=False, any missing key keeps its random initialisation, so report
    # mismatches loudly instead of silently serving an untrained layer.
    incompatible = model.load_state_dict(state, strict=False)
    if incompatible.missing_keys:
        print(
            f"Warning: {len(incompatible.missing_keys)} model weights missing from "
            f"checkpoint {model_path!r}, e.g. {incompatible.missing_keys[:5]}"
        )
    if incompatible.unexpected_keys:
        print(
            f"Warning: {len(incompatible.unexpected_keys)} unexpected keys in "
            f"checkpoint {model_path!r}, e.g. {incompatible.unexpected_keys[:5]}"
        )

    model.to(device)
    model.eval()
    return model
|
|
|
|
|
# Module-level singletons, built once at import time and reused by every request.
model = load_model()

# Face detector on the same device as the classifier. keep_all=True is set so that
# every detected face is returned (predict_on_image iterates over all boxes).
mtcnn = MTCNN(device=DEVICE, keep_all=True)

# Per-face preprocessing: resize to the square network input, convert to a tensor,
# then normalise per channel with the usual ImageNet mean/std values
# (NOTE(review): presumably the same as in the training pipeline).
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(std=[0.229, 0.224, 0.225], mean=[0.485, 0.456, 0.406]),
    ]
)
|
|
|
|
|
|
|
|
def draw_results_on_image(pil_img, boxes, preds, probs):
    """Return a copy of ``pil_img`` with one labelled box drawn per face.

    Args:
        pil_img: PIL image to annotate; it is copied, never modified in place.
        boxes: Sequence of ``(x1, y1, x2, y2)`` pixel coordinates, one per face.
        preds: Label text for each box, in the same order as ``boxes``.
        probs: Confidence for each box, rendered with two decimals.

    Returns:
        The annotated copy of the image.
    """
    pil_img = pil_img.copy()
    draw = ImageDraw.Draw(pil_img)
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=18)
    except (OSError, ImportError):
        # Font file not found (or Pillow built without FreeType): use the built-in font.
        font = ImageFont.load_default()

    label_h = 25  # height of the filled label strip, in pixels
    pad = 3       # gap between the strip's edge and the text

    for box, label, prob in zip(boxes, preds, probs):
        x1, y1, x2, y2 = [int(b) for b in box]
        draw.rectangle([x1, y1, x2, y2], outline="lime", width=3)

        text = f"{label} {prob:.2f}"
        left, _, right, _ = draw.textbbox((x1, y1), text, font=font)
        text_w = right - left

        # The strip normally sits just above the box; when the face touches the top
        # edge it would be drawn off-canvas, so move it just inside the box instead.
        strip_x = max(0, x1)
        strip_top = y1 - label_h if y1 >= label_h else max(0, y1)
        draw.rectangle(
            [strip_x, strip_top, strip_x + text_w + 2 * pad, strip_top + label_h],
            fill="lime",
        )
        draw.text((strip_x + pad, strip_top + pad), text, fill="black", font=font)

    return pil_img
|
|
|
|
|
def _load_label_font():
    """Return the font used for the on-image status text, with a built-in fallback."""
    try:
        return ImageFont.truetype("DejaVuSans-Bold.ttf", size=18)
    except (OSError, ImportError):
        # Font file not found (or Pillow built without FreeType).
        return ImageFont.load_default()


def _crop_faces(pil_img, boxes):
    """Clamp each detection box to the image and crop it.

    Returns ``(kept_boxes, faces)``: the original (unclamped) boxes that still
    cover at least one pixel after clamping, and the matching crops. Empty boxes
    are dropped so they cannot break preprocessing, and both lists stay aligned.
    """
    w, h = pil_img.size
    kept_boxes, faces = [], []
    for box in boxes:
        x1, y1, x2, y2 = [int(b) for b in box]
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(w, x2), min(h, y2)
        if x2 <= x1 or y2 <= y1:
            continue  # box lies outside the image or has zero area
        kept_boxes.append(box)
        faces.append(pil_img.crop((x1, y1, x2, y2)))
    return kept_boxes, faces


def _classify_faces(faces, batch_size=16):
    """Run the emotion model on face crops; return ``(labels, confidences)``.

    Faces are processed in batches of ``batch_size`` rather than one forward pass
    per face. load_model() puts the model in eval mode, so each face's scores do
    not depend on which other faces share its batch.
    """
    chunks = []
    with torch.no_grad():
        for start in range(0, len(faces), batch_size):
            batch = torch.stack([preprocess(f) for f in faces[start:start + batch_size]])
            chunks.append(F.softmax(model(batch.to(DEVICE)), dim=1).cpu())
    scores = torch.cat(chunks).numpy()
    top_idx = scores.argmax(axis=1)
    labels = [CLASS_NAMES[int(i)] for i in top_idx]
    confidences = [float(row[i]) for row, i in zip(scores, top_idx)]
    return labels, confidences


def predict_on_image(image):
    """Detect faces, classify each face's emotion, and return an annotated image.

    This is the Gradio callback.

    Args:
        image: PIL image, or an array-like accepted by ``Image.fromarray`` after a
            uint8 cast; may be None.

    Returns:
        The annotated image. When no usable face is found, the input is returned
        with a red "No faces detected" note. Returns None when there is no input
        or processing fails; in that case the error and its traceback are printed.
    """
    import traceback

    try:
        if image is None:
            return None

        if isinstance(image, np.ndarray):
            pil_img = Image.fromarray(np.uint8(image)).convert("RGB")
        else:
            # convert() returns a new image, so the caller's object is not drawn on.
            pil_img = image.convert("RGB")

        boxes, _ = mtcnn.detect(pil_img)
        kept_boxes, faces = _crop_faces(pil_img, [] if boxes is None else boxes)

        if not faces:
            draw = ImageDraw.Draw(pil_img)
            draw.text((10, 10), "No faces detected", fill="red", font=_load_label_font())
            return pil_img

        preds, probs = _classify_faces(faces)
        return draw_results_on_image(pil_img, kept_boxes, preds, probs)

    except Exception as e:
        # Top-level boundary of the UI callback: keep the app alive and return an
        # empty output, but keep the full traceback for debugging.
        print("Error:", e)
        traceback.print_exc()
        return None
|
|
|
|
|
|
|
|
# Page header text for the Gradio UI; both strings are passed to gr.Interface
# (title=..., description=...). The description is rendered as Markdown.
title = "Face Emotion Recognition (EfficientNet-B4)"
description = """
Upload an image and the model will **detect faces**, classify their **emotions**,
and draw bounding boxes with labels.

**Model:** EfficientNet-B4 fine-tuned with Focal Loss, MixUp & CutMix to reduce overfitting.
"""
|
|
|
|
|
# Single-function Gradio UI: one uploaded PIL image in, one annotated PIL image out.
# NOTE(review): newer Gradio releases rename allow_flagging to flagging_mode;
# check against the installed version before upgrading.
demo = gr.Interface(
    predict_on_image,
    gr.Image(type="pil", label="Upload an Image"),
    gr.Image(type="pil", label="Detected Emotions"),
    allow_flagging="never",
    description=description,
    title=title,
)

if __name__ == "__main__":
    # Bind to all network interfaces (0.0.0.0) on port 7860 so the app is reachable
    # from outside the host, e.g. when running in a container.
    demo.launch(server_port=7860, server_name="0.0.0.0")
|
|
|