# modeling_simple_classifier.py
"""A small two-hidden-layer MLP classifier packaged as a Hugging Face model."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel


class SimpleClassifierConfig(PretrainedConfig):
    """Configuration for :class:`SimpleClassifier`.

    Subclassing ``PretrainedConfig`` is required: ``PreTrainedModel.__init__``
    rejects any config that is not a ``PretrainedConfig`` instance, and
    ``save_pretrained`` / ``from_pretrained`` rely on its (de)serialization.

    Args:
        input_dim: Size of the last dimension of the input features.
        num_classes: Number of output logits.
        p_dropout: Dropout probability applied after each hidden layer.
        hidden_size_1: Width of the first hidden layer (defaults to 256, the
            width the model previously hard-coded).
        hidden_size_2: Width of the second hidden layer (defaults to 128, the
            width the model previously hard-coded).
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.

    NOTE: the ``input_dim`` / ``num_classes`` / ``p_dropout`` defaults are
    placeholders that let the class be constructed with no arguments (which
    ``PretrainedConfig`` does internally); set them for your data.
    """

    model_type = "simple_classifier"

    def __init__(
        self,
        input_dim: int = 128,
        num_classes: int = 2,
        p_dropout: float = 0.1,
        hidden_size_1: int = 256,
        hidden_size_2: int = 128,
        **kwargs,
    ):
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.p_dropout = p_dropout
        self.hidden_size_1 = hidden_size_1
        self.hidden_size_2 = hidden_size_2
        super().__init__(**kwargs)


class SimpleClassifier(PreTrainedModel):
    """MLP classifier: input_dim -> hidden_size_1 -> hidden_size_2 -> num_classes.

    Each hidden layer is Linear -> LayerNorm -> GELU -> Dropout, followed by a
    final Linear projection to unnormalized logits. The submodule names
    (``linear1``, ``ln1``, ``linear2``, ``ln2``, ``linear_out``) are kept
    stable because they are the state-dict keys of saved checkpoints.
    """

    config_class = SimpleClassifierConfig
    # forward() takes its input as ``x`` rather than the library default
    # ``input_ids``; tell PreTrainedModel utilities which argument to use.
    main_input_name = "x"

    def __init__(self, config: SimpleClassifierConfig):
        super().__init__(config)
        # Fall back to the original hard-coded widths when the config does
        # not define them, so older config objects build the same network.
        hidden_1 = getattr(config, "hidden_size_1", 256)
        hidden_2 = getattr(config, "hidden_size_2", 128)

        self.linear1 = nn.Linear(config.input_dim, hidden_1)
        self.ln1 = nn.LayerNorm(hidden_1)
        # A single Dropout module is shared by both hidden layers; it holds
        # no parameters, so reuse is equivalent to two separate instances.
        self.dropout = nn.Dropout(config.p_dropout)
        self.linear2 = nn.Linear(hidden_1, hidden_2)
        self.ln2 = nn.LayerNorm(hidden_2)
        self.linear_out = nn.Linear(hidden_2, config.num_classes)

        # Runs the library's weight-initialization / finalization hooks.
        self.post_init()

    def _hidden_layer(
        self, linear: nn.Linear, norm: nn.LayerNorm, x: torch.Tensor
    ) -> torch.Tensor:
        """Apply one hidden layer: Linear -> LayerNorm -> GELU -> Dropout."""
        return self.dropout(F.gelu(norm(linear(x))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class logits for ``x``.

        Args:
            x: Float tensor whose last dimension is ``config.input_dim``; any
                leading (e.g. batch) dimensions are preserved, since every
                layer here operates on the last dimension only.

        Returns:
            Tensor of shape ``(*x.shape[:-1], config.num_classes)``.
        """
        x = self._hidden_layer(self.linear1, self.ln1, x)
        x = self._hidden_layer(self.linear2, self.ln2, x)
        return self.linear_out(x)