import torch
import torch.nn as nn
import torchvision.models as models
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


class DeepFakeDetectorConfig(PretrainedConfig):
    """Configuration class for DeepFakeDetector."""

    # PreTrainedModel.__init__ only accepts PretrainedConfig instances, so the
    # configuration must subclass it. The model_type string is an arbitrary
    # identifier for this architecture.
    model_type = "deepfake_detector"

    def __init__(self, num_classes=2, latent_dim=2048, lstm_layers=1, hidden_dim=2048,
                 bidirectional=False, sequence_length=20, im_size=112, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.latent_dim = latent_dim
        self.lstm_layers = lstm_layers
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional
        self.sequence_length = sequence_length
        self.im_size = im_size

    @classmethod
    def from_dict(cls, config_dict):
        """Create a configuration from a dictionary."""
        return cls(**config_dict)

    def to_dict(self):
        """Convert the configuration to a dictionary of model hyperparameters."""
        return {
            "num_classes": self.num_classes,
            "latent_dim": self.latent_dim,
            "lstm_layers": self.lstm_layers,
            "hidden_dim": self.hidden_dim,
            "bidirectional": self.bidirectional,
            "sequence_length": self.sequence_length,
            "im_size": self.im_size
        }
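
# Illustrative round-trip (not executed): the from_dict/to_dict pair above is
# what the from_pretrained override further below relies on.
#
#     cfg = DeepFakeDetectorConfig(sequence_length=40)
#     cfg = DeepFakeDetectorConfig.from_dict(cfg.to_dict())
#     assert cfg.sequence_length == 40
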

class DeepFakeDetectorModel(PreTrainedModel):
    """DeepFake detection model using ResNeXt50 and LSTM."""

    config_class = DeepFakeDetectorConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        self.latent_dim = config.latent_dim
        self.lstm_layers = config.lstm_layers
        self.hidden_dim = config.hidden_dim
        self.bidirectional = config.bidirectional

        # Frame-level feature extractor: ResNeXt50 with the final avgpool and
        # fc layers removed, leaving (N, 2048, H/32, W/32) feature maps.
        # Note: ``pretrained=True`` is deprecated in recent torchvision; the
        # equivalent is ``weights=models.ResNeXt50_32X4D_Weights.DEFAULT``.
        resnext = models.resnext50_32x4d(pretrained=True)
        self.backbone = nn.Sequential(*list(resnext.children())[:-2])

        # Temporal model over per-frame features. ``batch_first=True`` matches
        # the (batch, sequence, features) layout produced in ``forward``.
        self.lstm = nn.LSTM(
            self.latent_dim,
            self.hidden_dim,
            self.lstm_layers,
            bidirectional=self.bidirectional,
            batch_first=True
        )

        self.relu = nn.LeakyReLU()  # defined for completeness; not used in forward
        self.dropout = nn.Dropout(0.4)
        # A bidirectional LSTM concatenates both directions, doubling the width.
        lstm_out_dim = self.hidden_dim * 2 if self.bidirectional else self.hidden_dim
        self.classifier = nn.Linear(lstm_out_dim, self.num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
    def forward(self, x, labels=None):
        """
        Forward pass of the model.

        Args:
            x: Input tensor of shape (batch_size, sequence_length, channels, height, width)
            labels: Optional labels for computing loss

        Returns:
            SequenceClassifierOutput: Model outputs including loss and logits
        """
        batch_size, seq_length, c, h, w = x.shape

        # Fold the time dimension into the batch so the CNN sees individual frames.
        x = x.view(batch_size * seq_length, c, h, w)

        features = self.backbone(x)

        # Global average pooling collapses each frame's feature map to a single vector.
        pooled = self.avgpool(features)
        pooled = pooled.view(batch_size, seq_length, self.latent_dim)

        # Run the LSTM over the frame sequence and keep the last time step.
        lstm_out, _ = self.lstm(pooled, None)
        final = lstm_out[:, -1, :]

        logits = self.classifier(self.dropout(final))

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits
        )
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a pretrained model from a local checkpoint containing a plain state_dict."""
        config_dict = kwargs.pop("config", None)
        if config_dict is None:
            config_dict = {
                "num_classes": 2,
                "latent_dim": 2048,
                "lstm_layers": 1,
                "hidden_dim": 2048,
                "bidirectional": False,
                "sequence_length": 20,
                "im_size": 112
            }

        config = DeepFakeDetectorConfig.from_dict(config_dict)

        model = cls(config, *model_args, **kwargs)

        # The path is expected to point to a checkpoint saved with
        # ``torch.save(model.state_dict(), path)``.
        state_dict = torch.load(pretrained_model_name_or_path, map_location="cpu")
        model.load_state_dict(state_dict)

        return model
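

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the original module):
    # build the model with default hyperparameters and push one dummy clip
    # through it. Instantiation downloads ImageNet weights for the ResNeXt50
    # backbone on first use.
    config = DeepFakeDetectorConfig()
    model = DeepFakeDetectorModel(config)
    model.eval()

    # (batch, sequence_length, channels, im_size, im_size)
    dummy_clip = torch.randn(1, config.sequence_length, 3, config.im_size, config.im_size)
    dummy_labels = torch.tensor([0])

    with torch.no_grad():
        outputs = model(dummy_clip, labels=dummy_labels)

    print("logits:", outputs.logits.shape)  # expected: torch.Size([1, 2])
    print("loss:", outputs.loss.item())

    # Loading trained weights from a local checkpoint (hypothetical path):
    # model = DeepFakeDetectorModel.from_pretrained("checkpoints/deepfake_detector.pt")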