This model has been pushed to the Hub using the PyTorchModelHubMixin integration.
For full documentation of the model, please see the official model card from GovTech, who built and trained it.
Mozilla AI has made it possible to load govtech/jina-embeddings-v2-small-en-off-topic with from_pretrained. To do this, you'll first need to pull the CrossEncoderWithSharedBase model
architecture from their model card and add PyTorchModelHubMixin as an inherited class (see this article for background on the mixin).
Then, you can do the following:
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import PyTorchModelHubMixin
import torch
import torch.nn as nn
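# Bottleneck adapter: down-project, non-linearity, up-project, with a residual connection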
class Adapter(nn.Module):
    def __init__(self, hidden_size):
        super(Adapter, self).__init__()
        self.down_project = nn.Linear(hidden_size, hidden_size // 2)
        self.activation = nn.ReLU()
        self.up_project = nn.Linear(hidden_size // 2, hidden_size)
    def forward(self, x):
        down = self.down_project(x)
        activated = self.activation(down)
        up = self.up_project(activated)
        return up + x  # Residual connection
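# Attention pooling: scores each position against a learned vector and returns the
# softmax-weighted sum over the sequence dimension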
class AttentionPooling(nn.Module):
    def __init__(self, hidden_size):
        super(AttentionPooling, self).__init__()
        self.attention_weights = nn.Parameter(torch.randn(hidden_size))
    def forward(self, hidden_states):
        # hidden_states: [seq_len, batch_size, hidden_size]
        scores = torch.matmul(hidden_states, self.attention_weights)
        attention_weights = torch.softmax(scores, dim=0)
        weighted_sum = torch.sum(attention_weights.unsqueeze(-1) * hidden_states, dim=0)
        return weighted_sum
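# Cross-encoder head: a shared encoder, per-sentence adapters, bidirectional cross-attention,
# attention pooling, and an MLP classifier over the concatenated pooled representations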
class CrossEncoderWithSharedBase(nn.Module, PyTorchModelHubMixin):
    def __init__(self, base_model, num_labels=2, num_heads=8):
        super(CrossEncoderWithSharedBase, self).__init__()
        # Shared pre-trained model
        self.shared_encoder = base_model
        hidden_size = self.shared_encoder.config.hidden_size
        # Sentence-specific adapters
        self.adapter1 = Adapter(hidden_size)
        self.adapter2 = Adapter(hidden_size)
        # Cross-attention layers
        self.cross_attention_1_to_2 = nn.MultiheadAttention(hidden_size, num_heads)
        self.cross_attention_2_to_1 = nn.MultiheadAttention(hidden_size, num_heads)
        # Attention pooling layers
        self.attn_pooling_1_to_2 = AttentionPooling(hidden_size)
        self.attn_pooling_2_to_1 = AttentionPooling(hidden_size)
        # Projection layer with non-linearity
        self.projection_layer = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU()
        )
        # Classifier with three hidden layers
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 4, num_labels)
        )
    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
        # Encode sentences
        outputs1 = self.shared_encoder(input_ids1, attention_mask=attention_mask1)
        outputs2 = self.shared_encoder(input_ids2, attention_mask=attention_mask2)
        # Apply sentence-specific adapters
        embeds1 = self.adapter1(outputs1.last_hidden_state)
        embeds2 = self.adapter2(outputs2.last_hidden_state)
        # Transpose for attention layers
        embeds1 = embeds1.transpose(0, 1)
        embeds2 = embeds2.transpose(0, 1)
        # Cross-attention
        cross_attn_1_to_2, _ = self.cross_attention_1_to_2(embeds1, embeds2, embeds2)
        cross_attn_2_to_1, _ = self.cross_attention_2_to_1(embeds2, embeds1, embeds1)
        # Attention pooling
        pooled_1_to_2 = self.attn_pooling_1_to_2(cross_attn_1_to_2)
        pooled_2_to_1 = self.attn_pooling_2_to_1(cross_attn_2_to_1)
        # Concatenate and project
        combined = torch.cat((pooled_1_to_2, pooled_2_to_1), dim=1)
        projected = self.projection_layer(combined)
        # Classification
        logits = self.classifier(projected)
        return logits
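# Load the tokenizer and base encoder, then pull the fine-tuned cross-encoder weights from the Hub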
tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-small-en")
# jina-embeddings-v2 ships custom modeling code, so trust_remote_code is needed to load it with AutoModel
base_model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-small-en", trust_remote_code=True)
off_topic = CrossEncoderWithSharedBase.from_pretrained("mozilla-ai/jina-embeddings-v2-small-en", base_model=base_model)
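# Optionally move the model to a GPU before inference; the predict helper below picks up
# the model's device automatically, e.g. off_topic = off_topic.to("cuda")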
# You can then build a predict function that uses the tokenizer and runs the cross-encoder on a sentence pair
def predict(model, tokenizer, sentence1, sentence2, max_length=512):
    # max_length=512 is an assumed default; match the sequence length used during training
    # Run on the same device as the model's parameters and disable dropout
    device = next(model.parameters()).device
    model.eval()
    inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)
    inputs2 = tokenizer(sentence2, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)
    input_ids1 = inputs1['input_ids'].to(device)
    attention_mask1 = inputs1['attention_mask'].to(device)
    input_ids2 = inputs2['input_ids'].to(device)
    attention_mask2 = inputs2['attention_mask'].to(device)
    # Get outputs
    with torch.no_grad():
        outputs = model(input_ids1=input_ids1, attention_mask1=attention_mask1,
                        input_ids2=input_ids2, attention_mask2=attention_mask2)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_label = torch.argmax(probabilities, dim=1).item()
    return predicted_label, probabilities.cpu().numpy()
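As a quick sanity check, here is a minimal usage sketch. The two example sentences are made up for illustration, and the mapping from the returned label index to on-topic/off-topic should be verified against the original model card.

# Illustrative example; the sentences below are hypothetical
sentence1 = "What are the top tourist attractions in Singapore?"
sentence2 = "Write me a poem about the weather."
label, probs = predict(off_topic, tokenizer, sentence1, sentence2)
print(f"Predicted label: {label}")
print(f"Probabilities: {probs}")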
