|
|
""" |
|
|
Face detection utility using MTCNN from facenet_pytorch. |
|
|
|
|
|
This module exposes a simple function to detect faces in a PIL Image. It |
|
|
returns bounding boxes for all detected faces. The detection model is |
|
|
constructed lazily on the first call to avoid unnecessary GPU/CPU |
|
|
initialisation when the module is imported. |
|
|
""" |
|
|
|
|
|
from typing import Dict, List, Optional, Tuple

import numpy as np
from PIL import Image
|
|
|
|
|
# facenet_pytorch is a hard requirement of this module. Re-raise the import
# failure with an actionable install hint; ``from exc`` keeps the original
# exception chained for debugging.
try:
    from facenet_pytorch import MTCNN
except ImportError as exc:
    raise ImportError(
        "facenet_pytorch is required for face detection. Install it with `pip install facenet-pytorch`."
    ) from exc
|
|
|
|
|
# Detectors are cached per device. A later request for a different device
# (e.g. "cuda" after "cpu") therefore gets a detector on that device instead
# of silently reusing whichever detector happened to be built first.
_mtcnn_by_device: Dict[str, MTCNN] = {}

# Most recently requested detector. Kept so code that inspected the former
# module-level singleton continues to find a detector here.
_mtcnn: Optional[MTCNN] = None


def _get_mtcnn(device: str = "cpu") -> MTCNN:
    """Return a cached MTCNN detector for ``device``, creating it on first use.

    Parameters
    ----------
    device: str, optional
        PyTorch device on which to run the detector. Defaults to ``"cpu"``.
        The cache is keyed on ``str(device)``, so a ``torch.device`` object
        also works as a key.

    Returns
    -------
    MTCNN
        A detector built with ``image_size=160``, ``margin=0`` and
        ``keep_all=True`` on the requested device.
    """
    global _mtcnn

    key = str(device)
    detector = _mtcnn_by_device.get(key)
    if detector is None:
        # Lazy construction: model weights are only loaded when a device is
        # first actually used.
        detector = MTCNN(image_size=160, margin=0, keep_all=True, device=device)
        _mtcnn_by_device[key] = detector

    _mtcnn = detector
    return detector
|
|
|
|
|
|
|
|
def detect_faces(image: Image.Image, device: str = "cpu") -> List[Tuple[float, float, float, float]]:
    """Detect faces in a PIL image.

    Parameters
    ----------
    image: PIL.Image.Image
        The input image in which to detect faces. Images not already in
        ``"RGB"`` mode (e.g. ``"L"``, ``"P"``, ``"RGBA"``) are converted to
        RGB before detection; the caller's image object is not modified.
    device: str, optional
        Device on which to run the detector (``"cpu"`` or ``"cuda"``). Defaults to ``"cpu"``.

    Returns
    -------
    List[Tuple[float, float, float, float]]
        A list of bounding boxes (x1, y1, x2, y2) for each detected face. If
        no faces are found, returns an empty list.
    """
    # MTCNN expects three-channel RGB input. Grayscale, palette or
    # alpha-channel images would otherwise fail inside the detector.
    # ``convert`` returns a new image, so the caller's object is untouched.
    if image.mode != "RGB":
        image = image.convert("RGB")

    mtcnn = _get_mtcnn(device)

    # ``detect`` returns (boxes, probabilities); boxes is None when nothing
    # is found.
    boxes, _ = mtcnn.detect(image)
    if boxes is None:
        return []

    # Convert numpy scalars to plain Python floats for a framework-neutral
    # return value. ``asarray`` avoids copying when boxes is already an array.
    return [tuple(map(float, box)) for box in np.asarray(boxes)]
|
|
|
|
|
|
|
|
# Public API: only ``detect_faces`` is exported; ``_get_mtcnn`` and the
# detector cache are internal.
__all__ = ["detect_faces"]