File size: 3,303 Bytes
573fbce
 
b9ca81a
573fbce
 
 
b9ca81a
573fbce
 
c2056b1
 
 
573fbce
 
c2056b1
 
ea55fae
f76e9f3
573fbce
 
 
 
 
 
 
 
 
 
f72e3da
573fbce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63bf9c6
573fbce
 
 
 
 
 
c2056b1
 
 
573fbce
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# app.py  ──────────────────────────────────────────────────────────────────────
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

TITLE = "Talk To Me Morty"

DESCRIPTION = """
<p style='text-align:center'>
The bot was trained on a Rick & Morty dialogues dataset with DialoGPT.
</p>
<center>
<img src="https://huggingface.co/spaces/kingabzpro/Rick_and_Morty_Bot/resolve/main/img/rick.png"
     alt="Rick"
     width="150">
</center>
"""

ARTICLE = """
<p style='text-align:center'>
<a href="https://medium.com/geekculture/discord-bot-using-dailogpt-and-huggingface-api-c71983422701"
   target="_blank">Complete Tutorial</a> Β·
<a href="https://dagshub.com/kingabzpro/DailoGPT-RickBot"
   target="_blank">Project on DAGsHub</a>
</p>
"""

# ─── Load model once at start ────────────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained("ericzhou/DialoGPT-Medium-Rick_v2")
model     = AutoModelForCausalLM.from_pretrained("ericzhou/DialoGPT-Medium-Rick_v2")

# ─── Chat handler ────────────────────────────────────────────────────────────
def _history_to_pairs(history_ids: list[list[int]] | None) -> list[tuple[str, str]]:
    """Decode the accumulated token-id history into (user, bot) message pairs.

    DialoGPT delimits turns with the ``<|endoftext|>`` token: even-indexed
    turns are the user's, odd-indexed turns are the bot's.
    """
    if not history_ids:
        return []
    turns = tokenizer.decode(history_ids[0], skip_special_tokens=False) \
                     .split("<|endoftext|>")
    # The history ends with <|endoftext|>, so the final split fragment is an
    # empty string; stopping at len(turns) - 1 drops it.
    return [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]


def chat(user_msg: str, history_ids: list[list[int]] | None):
    """Generate the bot's next reply and return the updated conversation.

    Parameters
    ----------
    user_msg : str
        The user's new message. If empty/falsy, nothing is generated.
    history_ids : list[list[int]] | None
        Token-id history from the previous turn, shaped ``(1, seq_len)`` as a
        nested list (the ``.tolist()`` of the previous ``generate`` output),
        or ``None``/``[]`` on the first turn.

    Returns
    -------
    tuple[list[tuple[str, str]], list[list[int]]]
        ``(pairs for the Chatbot component, updated token-id history)``.
    """
    if not user_msg:
        # Bug fix: the original returned [] here, which wiped the visible
        # conversation on an empty submit even though the history survived.
        # Rebuild the display from the existing history instead.
        return _history_to_pairs(history_ids), history_ids or []

    new_ids = tokenizer.encode(user_msg + tokenizer.eos_token,
                               return_tensors="pt")

    # Prepend the running conversation (if any) so the model sees full context.
    bot_input = (
        torch.cat([torch.LongTensor(history_ids), new_ids], dim=-1)
        if history_ids else new_ids
    )

    history_ids = model.generate(
        bot_input,
        # Allow up to 200 new tokens per reply, hard-capped at 4096 total.
        max_length=min(4096, bot_input.shape[-1] + 200),
        pad_token_id=tokenizer.eos_token_id,  # DialoGPT defines no pad token
        do_sample=True,
        temperature=0.7,
        top_p=0.92,
        top_k=50,
    ).tolist()

    return _history_to_pairs(history_ids), history_ids

# ─── Gradio UI ───────────────────────────────────────────────────────────────
with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown(f"<h1 style='text-align:center'>{TITLE}</h1>")
    gr.Markdown(DESCRIPTION)

    # Chatbot shows (user, bot) pairs; State carries the raw token-id history
    # between turns (the second value returned by chat()).
    chatbot = gr.Chatbot(height=450)
    state   = gr.State([])

    with gr.Row(equal_height=True):
        prompt = gr.Textbox(placeholder="Ask Rick anything…", scale=9, show_label=False)
        send   = gr.Button("Send",scale=1, variant="primary")

    # send on click or ↵
    # Both paths invoke the same handler with identical wiring.
    send.click(chat, inputs=[prompt, state], outputs=[chatbot, state])
    prompt.submit(chat, inputs=[prompt, state], outputs=[chatbot, state])

    # Clickable example prompts that pre-fill the textbox.
    gr.Examples([["How are you, Rick?"], ["Tell me a joke!"]], inputs=prompt)
    gr.Markdown(ARTICLE)

if __name__ == "__main__":
    demo.launch()