import asyncio
import re
import unicodedata
from typing import List

import bs4
import faiss
import gradio as gr
import nltk
import numpy as np
import requests
import torch
from nltk import sent_tokenize
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| device | |
base_model_id = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    trust_remote_code=True
)
# Tokenizer for the base model; get_answer() below relies on it.
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
# Attach the "yurezsml/phi2_chan" PEFT (LoRA) adapter on top of phi-2;
# offload_dir gives PEFT somewhere to offload weights if memory runs short.
ft_model = PeftModel.from_pretrained(model, "yurezsml/phi2_chan", offload_dir="./")
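# --- Text preprocessing ---
# The helpers below normalise the scraped Wikipedia text before it is
# chunked and embedded for retrieval.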
def remove_accents(input_str):
    nfkd_form = unicodedata.normalize('NFKD', input_str)
    return "".join(c for c in nfkd_form if not unicodedata.combining(c))
def preprocess(text):
    text = text.lower()
    text = remove_accents(text)
    text = text.replace('\xa0', ' ')
    text = text.replace('\n\n', '\n')
    text = text.replace('()', '')
    text = text.replace('[]', '')
    # Drop bracketed asides, e.g. Wikipedia citation markers like [12].
    text = re.sub(r"[\(\[].*?[\)\]]", "", text)
    text = text.replace('а́', 'а')  # Cyrillic stressed "а" -> plain "а"
    return text
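# e.g. preprocess("Héllo\xa0[1] world") -> "hello  world"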
def split_text(text: str, n=2) -> List[str]:
    text = preprocess(text)
    all_sentences = sent_tokenize(text)
    # Group consecutive sentences into passages of n sentences each.
    return [' '.join(all_sentences[i : i + n]) for i in range(0, len(all_sentences), n)]
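# e.g. with n=2, "First one. Second one. Third one." chunks into
# ["first one. second one.", "third one."].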
def split_documents(documents: List[str]) -> List[str]:
    texts = []
    for text in documents:
        if text is not None:
            for passage in split_text(text):
                texts.append(passage)
    return texts
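# embed() turns a batch of texts into fixed-size vectors via masked mean
# pooling: average the encoder's token embeddings, counting only non-padding
# tokens (the usual recipe for sentence embeddings from a vanilla BERT).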
def embed(text, model, tokenizer):
    encoded_input = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt').to(model.device)
    with torch.no_grad():
        model_output = model(**encoded_input)
    token_embeddings = model_output[0]  # first element holds all token embeddings
    input_mask_expanded = encoded_input['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask
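# For the DeepPavlov BERT loaded below this yields 768-dim vectors, e.g.
# embed(["hi"], fact_coh_model, fact_coh_tokenizer).shape == torch.Size([1, 768]).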
response = requests.get("https://en.wikipedia.org/wiki/Chandler_Bing")
base_text = ''
if response:  # a Response is truthy for status codes < 400
    html = bs4.BeautifulSoup(response.text, 'html.parser')
    title = html.select("#firstHeading")[0].text
    paragraphs = html.select("p")
    for para in paragraphs:
        base_text = base_text + para.text
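# The plain paragraph text of the Wikipedia article on Chandler Bing becomes
# the retrieval corpus for the RAG option.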
# Multilingual sentence-level BERT, used purely as the retrieval encoder.
fact_coh_tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/bert-base-multilingual-cased-sentence")
fact_coh_model = AutoModel.from_pretrained("DeepPavlov/bert-base-multilingual-cased-sentence")
fact_coh_model.to(device)
nltk.download('punkt')  # sentence tokenizer model required by sent_tokenize
subsample_documents = split_documents([base_text])
batch_size = 8
total_batches = len(subsample_documents) // batch_size + (0 if len(subsample_documents) % batch_size == 0 else 1)
base = []
for i in tqdm(range(0, len(subsample_documents), batch_size), total=total_batches, desc="Processing Batches"):
    batch_texts = subsample_documents[i:i + batch_size]
    base.extend(embed(batch_texts, fact_coh_model, fact_coh_tokenizer))
base = np.array([vector.cpu().numpy() for vector in base])
index = faiss.IndexFlatL2(base.shape[1])
index.add(base)
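# IndexFlatL2 does exact brute-force search over raw L2 distance, which is
# fine at this corpus size. Note the embeddings are not L2-normalised, so
# this is not equivalent to cosine similarity.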
async def get_context(subsample_documents, query, index, model, tokenizer):
    k = 5
    xq = embed(query.lower(), model, tokenizer).cpu().numpy()
    D, I = index.search(xq.reshape(1, -1), k)
    return subsample_documents[I[0][0]]
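# Although k=5 neighbours are retrieved, only the single closest passage
# (I[0][0]) is injected into the prompt.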
async def get_prompt(question, use_rag, answers_history: list[str]):
    eval_prompt = '###system: answer the question as Chandler. '
    # History alternates user question / bot answer, starting with a question.
    for idx, text in enumerate(answers_history):
        if idx % 2 == 0:
            eval_prompt = eval_prompt + f' ###question: {text}'
        else:
            eval_prompt = eval_prompt + f' ###answer: {text} '
    if use_rag:
        context = await asyncio.wait_for(get_context(subsample_documents, question, index, fact_coh_model, fact_coh_tokenizer), timeout=60)
        eval_prompt = eval_prompt + f' Chandler. {context}'
    eval_prompt = eval_prompt + f' ###question: {question} '
    eval_prompt = ' '.join(eval_prompt.split())  # collapse runs of whitespace
    return eval_prompt
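# An assembled prompt looks roughly like (whitespace-collapsed to one line):
#   ###system: answer the question as Chandler. Chandler. <retrieved passage>
#   ###question: <question>
# and the model is expected to continue with "###answer: ...".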
async def get_answer(question, use_rag, answers_history: list[str]):
    eval_prompt = await asyncio.wait_for(get_prompt(question, use_rag, answers_history), timeout=60)
    model_input = tokenizer(eval_prompt, return_tensors="pt").to(device)
    ft_model.eval()
    with torch.no_grad():
        answer = tokenizer.decode(ft_model.generate(**model_input, max_new_tokens=30, repetition_penalty=1.11)[0], skip_special_tokens=True) + '\n'
    answer = ' '.join(answer.split())
    # Strip the echoed prompt, then keep only the generated answer part.
    if eval_prompt in answer:
        answer = answer.replace(eval_prompt, '')
    if '###answer' in answer:
        answer = answer.split('###answer')[1]
    dialog = ''
    for idx, text in enumerate(answers_history):
        if idx % 2 == 0:
            dialog = dialog + f'you: {text}\n'
        else:
            dialog = dialog + f'Chandler: {text}\n'
    dialog = dialog + f'you: {question}\n'
    dialog = dialog + f'Chandler: {answer}\n'
    answers_history.append(question)
    answers_history.append(answer)
    return dialog, answers_history
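# Gradio handler: wraps get_answer with a hard 60-second timeout and reports
# a timeout message in the chat box instead of raising.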
async def async_proc(question, use_rag, answers_history: list[str]):
    try:
        return await asyncio.wait_for(get_answer(question, use_rag, answers_history), timeout=60)
    except asyncio.TimeoutError:
        return "Processing timed out.", answers_history
gr.Interface(
    fn=async_proc,
    inputs=[
        gr.Textbox(
            label="Question",
        ),
        gr.Checkbox(label="Use RAG", info="Enable RAG to improve factual coherence"),
        gr.State(value=[]),
    ],
    outputs=[
        gr.Textbox(
            label="Chat"
        ),
        gr.State(),
    ],
    title="Asynchronous chatbot service for the TV series Friends",
    concurrency_limit=5
).queue().launch(share=True, debug=True)