Spaces:
Build error
Build error
| #TextAugmentation.py | |
| from transformers import T5Tokenizer, AutoModelForSeq2SeqLM, MarianMTModel, MarianTokenizer | |
| import torch | |
| import os | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| class TextAugmentation: | |
| def __init__(self, | |
| paraphrase_model_name="cointegrated/rut5-base-paraphraser", | |
| ru_en_model_name="Helsinki-NLP/opus-mt-ru-en", | |
| en_ru_model_name="Helsinki-NLP/opus-mt-en-ru"): | |
| # Инициализация модели для перефразирования | |
| self.paraphrase_tokenizer = T5Tokenizer.from_pretrained(paraphrase_model_name, legacy=False) | |
| self.paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(paraphrase_model_name) | |
| # Инициализация моделей для обратного перевода | |
| self.ru_en_tokenizer = MarianTokenizer.from_pretrained(ru_en_model_name) | |
| self.ru_en_model = MarianMTModel.from_pretrained(ru_en_model_name) | |
| self.en_ru_tokenizer = MarianTokenizer.from_pretrained(en_ru_model_name) | |
| self.en_ru_model = MarianMTModel.from_pretrained(en_ru_model_name) | |
| def paraphrase(self, text, num_return_sequences=1): | |
| """ | |
| Перефразирование текста с использованием модели. | |
| Args: | |
| text (str): Исходный текст для перефразирования. | |
| num_return_sequences (int): Количество вариантов перефразирования. | |
| Returns: | |
| list[str]: Список вариантов перефразирования текста. | |
| """ | |
| inputs = self.paraphrase_tokenizer([text], max_length=512, truncation=True, return_tensors="pt") | |
| outputs = self.paraphrase_model.generate( | |
| **inputs, | |
| max_length=128, | |
| num_return_sequences=num_return_sequences, | |
| do_sample=True, | |
| temperature=1.2, | |
| top_k=50, | |
| top_p=0.90 | |
| ) | |
| return [self.paraphrase_tokenizer.decode(output, skip_special_tokens=True) for output in outputs] | |
| def back_translate(self, text): | |
| """ | |
| Выполняет обратный перевод текста: русский -> английский -> русский. | |
| Args: | |
| text (str): Исходный текст для обратного перевода. | |
| Returns: | |
| str: Текст после обратного перевода. | |
| """ | |
| # Перевод с русского на английский | |
| inputs = self.ru_en_tokenizer(text, return_tensors="pt", truncation=True, max_length=512) | |
| with torch.no_grad(): | |
| outputs = self.ru_en_model.generate(**inputs) | |
| translated_text = self.ru_en_tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| # Перевод с английского обратно на русский | |
| inputs = self.en_ru_tokenizer(translated_text, return_tensors="pt", truncation=True, max_length=512) | |
| with torch.no_grad(): | |
| outputs = self.en_ru_model.generate(**inputs) | |
| back_translated_text = self.en_ru_tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| return back_translated_text | |