Spaces:
Build error
Build error
| #EmbeddingGenerator.py | |
| from transformers import AutoTokenizer, AutoModel | |
| from sentence_transformers import SentenceTransformer | |
| import torch | |
| import numpy as np | |
| import warnings | |
| warnings.filterwarnings("ignore", category=UserWarning, module="transformers.models.bert") | |
| import os | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| class EmbeddingGenerator: | |
| def __init__(self, pavlov_model_name="DeepPavlov/rubert-base-cased", sentence_transformer_model_name="cointegrated/rubert-tiny2"): | |
| """ | |
| Инициализирует токенизатор и модели для генерации эмбеддингов. | |
| Args: | |
| pavlov_model_name (str): Название модели для загрузки Pavlov модели. | |
| sentence_transformer_model_name (str): Название модели SentenceTransformer для генерации эмбеддингов. | |
| """ | |
| self.pavlov_tokenizer = AutoTokenizer.from_pretrained(pavlov_model_name, ignore_mismatched_sizes=True) | |
| self.pavlov_model = AutoModel.from_pretrained(pavlov_model_name, ignore_mismatched_sizes=True) | |
| self.sentence_transformer_model = SentenceTransformer(sentence_transformer_model_name) | |
| def generate_embeddings(self, texts, method="pavlov"): | |
| """ | |
| Генерирует эмбеддинги для списка текстов с использованием выбранного метода. | |
| Args: | |
| texts (list of str): Список текстов для генерации эмбеддингов. | |
| method (str): Метод генерации эмбеддингов: "pavlov" или "rubert_tiny2". | |
| Returns: | |
| np.ndarray: Эмбеддинги текстов. | |
| """ | |
| if method == "pavlov": | |
| # Генерация эмбеддингов с использованием Pavlov модели | |
| inputs = self.pavlov_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512) | |
| with torch.no_grad(): | |
| outputs = self.pavlov_model(**inputs) | |
| # Mean pooling | |
| embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy() | |
| elif method == "rubert_tiny2": | |
| # Генерация эмбеддингов с использованием SentenceTransformer | |
| embeddings = self.sentence_transformer_model.encode(texts, show_progress_bar=False) | |
| else: | |
| raise ValueError("Unsupported method. Choose 'pavlov' or 'rubert_tiny2'.") | |
| return embeddings | |