Spaces:
Runtime error
Runtime error
| # app.py ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| TITLE = "Talk To Me Morty" | |
| DESCRIPTION = """ | |
| <p style='text-align:center'> | |
| The bot was trained on a Rick & Morty dialogues dataset with DialoGPT. | |
| </p> | |
| <center> | |
| <img src="https://huggingface.co/spaces/kingabzpro/Rick_and_Morty_Bot/resolve/main/img/rick.png" | |
| alt="Rick" | |
| width="150"> | |
| </center> | |
| """ | |
| ARTICLE = """ | |
| <p style='text-align:center'> | |
| <a href="https://medium.com/geekculture/discord-bot-using-dailogpt-and-huggingface-api-c71983422701" | |
| target="_blank">Complete Tutorial</a> Β· | |
| <a href="https://dagshub.com/kingabzpro/DailoGPT-RickBot" | |
| target="_blank">Project on DAGsHub</a> | |
| </p> | |
| """ | |
| # βββ Load model once at start ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| tokenizer = AutoTokenizer.from_pretrained("ericzhou/DialoGPT-Medium-Rick_v2") | |
| model = AutoModelForCausalLM.from_pretrained("ericzhou/DialoGPT-Medium-Rick_v2") | |
| # βββ Chat handler ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def chat(user_msg: str, history_ids: list[int] | None): | |
| if not user_msg: | |
| return [], history_ids or [] | |
| new_ids = tokenizer.encode(user_msg + tokenizer.eos_token, | |
| return_tensors="pt") | |
| bot_input = ( | |
| torch.cat([torch.LongTensor(history_ids), new_ids], dim=-1) | |
| if history_ids else new_ids | |
| ) | |
| history_ids = model.generate( | |
| bot_input, | |
| max_length=min(4096, bot_input.shape[-1] + 200), | |
| pad_token_id=tokenizer.eos_token_id, | |
| do_sample=True, | |
| temperature=0.7, | |
| top_p=0.92, | |
| top_k=50, | |
| ).tolist() | |
| turns = tokenizer.decode(history_ids[0], skip_special_tokens=False) \ | |
| .split("<|endoftext|>") | |
| # pack into (user, bot) pairs for Chatbot component | |
| pairs = [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)] | |
| return pairs, history_ids | |
| # βββ Gradio UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Blocks(theme=gr.themes.Base()) as demo: | |
| gr.Markdown(f"<h1 style='text-align:center'>{TITLE}</h1>") | |
| gr.Markdown(DESCRIPTION) | |
| chatbot = gr.Chatbot(height=450) | |
| state = gr.State([]) | |
| with gr.Row(equal_height=True): | |
| prompt = gr.Textbox(placeholder="Ask Rick anythingβ¦", scale=9, show_label=False) | |
| send = gr.Button("Send",scale=1, variant="primary") | |
| # send on click or β΅ | |
| send.click(chat, inputs=[prompt, state], outputs=[chatbot, state]) | |
| prompt.submit(chat, inputs=[prompt, state], outputs=[chatbot, state]) | |
| gr.Examples([["How are you, Rick?"], ["Tell me a joke!"]], inputs=prompt) | |
| gr.Markdown(ARTICLE) | |
| if __name__ == "__main__": | |
| demo.launch() |