# Hugging Face Space: "Talk To Me Morty" — a Rick & Morty chatbot built on
# DialoGPT and served with Gradio. (Scraped page metadata removed.)
# app.py ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# UI copy: page title, header blurb (with Rick image), and footer links.
# NOTE: the HTML below is rendered verbatim by gr.Markdown — keep as-is.
TITLE = "Talk To Me Morty"
DESCRIPTION = """
<p style='text-align:center'>
The bot was trained on a Rick & Morty dialogues dataset with DialoGPT.
</p>
<center>
<img src="https://huggingface.co/spaces/kingabzpro/Rick_and_Morty_Bot/resolve/main/img/rick.png"
alt="Rick"
width="150">
</center>
"""
# Footer with tutorial / project links, rendered at the bottom of the page.
ARTICLE = """
<p style='text-align:center'>
<a href="https://medium.com/geekculture/discord-bot-using-dailogpt-and-huggingface-api-c71983422701"
target="_blank">Complete Tutorial</a> Β·
<a href="https://dagshub.com/kingabzpro/DailoGPT-RickBot"
target="_blank">Project on DAGsHub</a>
</p>
"""
# βββ Load model once at start ββββββββββββββββββββββββββββββββββββββββββββββββ
# Downloads (or reads from the HF cache) the fine-tuned DialoGPT checkpoint at
# import time, so the first request does not pay the load cost.
tokenizer = AutoTokenizer.from_pretrained("ericzhou/DialoGPT-Medium-Rick_v2")
model = AutoModelForCausalLM.from_pretrained("ericzhou/DialoGPT-Medium-Rick_v2")
# βββ Chat handler ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def _pairs(history_ids: list) -> list[tuple[str, str]]:
    """Decode token-id history into (user, bot) pairs for the Chatbot widget.

    ``history_ids`` is the ``.tolist()`` of a (1, seq_len) tensor, i.e. a
    nested ``list[list[int]]``; an empty history yields an empty list.
    """
    if not history_ids:
        return []
    # Turns are delimited by DialoGPT's EOS token; keep special tokens so the
    # split marker is present in the decoded text.
    turns = tokenizer.decode(history_ids[0], skip_special_tokens=False) \
                     .split("<|endoftext|>")
    # Pack alternating turns into (user, bot) pairs.
    return [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]


def chat(user_msg: str, history_ids: list | None):
    """Generate the bot's next reply and return updated UI state.

    Args:
        user_msg: the user's latest message (may be empty).
        history_ids: nested ``list[list[int]]`` of token ids from previous
            turns (``model.generate(...).tolist()``), or ``None``/``[]``.

    Returns:
        (pairs, history_ids): chat pairs for the Chatbot component and the
        updated token-id history for gr.State.
    """
    history_ids = history_ids or []
    if not user_msg:
        # Bug fix: previously returned [], which wiped the visible
        # conversation; re-render the existing history instead.
        return _pairs(history_ids), history_ids
    new_ids = tokenizer.encode(user_msg + tokenizer.eos_token,
                               return_tensors="pt")
    # Prepend prior context (if any) so the model sees the whole dialogue.
    bot_input = (
        torch.cat([torch.LongTensor(history_ids), new_ids], dim=-1)
        if history_ids else new_ids
    )
    history_ids = model.generate(
        bot_input,
        # Allow up to 200 new tokens, capped at the 4096-token window.
        max_length=min(4096, bot_input.shape[-1] + 200),
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7,
        top_p=0.92,
        top_k=50,
    ).tolist()
    return _pairs(history_ids), history_ids
# βββ Gradio UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown(f"<h1 style='text-align:center'>{TITLE}</h1>")
    gr.Markdown(DESCRIPTION)
    chatbot = gr.Chatbot(height=450)
    # Per-session token-id history; chat() reads and rewrites it each turn.
    state = gr.State([])
    with gr.Row(equal_height=True):
        prompt = gr.Textbox(placeholder="Ask Rick anythingβ¦", scale=9, show_label=False)
        send = gr.Button("Send",scale=1, variant="primary")
    # send on click or β΅
    # NOTE(review): the textbox is not cleared after submit — confirm
    # whether that is intentional.
    send.click(chat, inputs=[prompt, state], outputs=[chatbot, state])
    prompt.submit(chat, inputs=[prompt, state], outputs=[chatbot, state])
    gr.Examples([["How are you, Rick?"], ["Tell me a joke!"]], inputs=prompt)
    gr.Markdown(ARTICLE)
if __name__ == "__main__":
    # Bug fix: removed a stray trailing "|" token (scrape artifact) that made
    # this line a SyntaxError. Launch the Gradio app (blocking call).
    demo.launch()