# app.py ──────────────────────────────────────────────────────────────────────
"""Gradio chat demo for a DialoGPT model fine-tuned on Rick & Morty dialogue."""

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

TITLE = "Talk To Me Morty"
DESCRIPTION = """

The bot was trained on a Rick & Morty dialogues dataset with DialoGPT.

Rick
"""
ARTICLE = """

Complete Tutorial · Project on DAGsHub

"""

MODEL_NAME = "ericzhou/DialoGPT-Medium-Rick_v2"
# Upper bound on tokens generated per reply (the original code allowed +200).
MAX_NEW_TOKENS = 200

# ─── Load model once at start ────────────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# GPT-2 style models cannot attend past their position-embedding table.
# NOTE(review): 1024 is the stock DialoGPT/GPT-2 window and is only used as a
# fallback when the config does not expose the value - confirm for this model.
CONTEXT_WINDOW = getattr(model.config, "max_position_embeddings", None) or 1024
# Tokens available for history + the new user message, leaving room to reply.
PROMPT_BUDGET = max(1, CONTEXT_WINDOW - MAX_NEW_TOKENS)


# ─── Chat handler ────────────────────────────────────────────────────────────
def _history_to_pairs(history_ids: list[list[int]] | None) -> list[tuple[str, str]]:
    """Decode the token history into (user, bot) pairs for the Chatbot component."""
    if not history_ids:
        return []
    turns = tokenizer.decode(history_ids[0], skip_special_tokens=False).split(
        tokenizer.eos_token
    )
    # Turns alternate user, bot, user, bot, ...; a trailing unpaired fragment
    # (normally the empty string after the last EOS) is dropped by zip().
    return list(zip(turns[::2], turns[1::2]))


def _trim_history(token_ids: list[int], budget: int) -> list[int]:
    """Drop the oldest whole exchanges until *token_ids* fits in *budget* tokens."""
    if budget <= 0:
        return []
    eos_id = tokenizer.eos_token_id
    while len(token_ids) > budget:
        try:
            user_end = token_ids.index(eos_id)
            bot_end = token_ids.index(eos_id, user_end + 1)
        except ValueError:
            # No complete (user, bot) exchange left to remove; cutting mid-turn
            # would misalign the pairs, so start over with an empty history.
            return []
        token_ids = token_ids[bot_end + 1:]
    return token_ids


def chat(user_msg: str, history_ids: list[list[int]] | None):
    """Generate the bot's reply to *user_msg* and return the updated conversation.

    Args:
        user_msg: Text typed by the user.
        history_ids: Token ids of the conversation so far as a single-row nested
            list (the ``.tolist()`` of the previous ``generate`` output), or an
            empty list / ``None`` when the conversation is new.

    Returns:
        A ``(pairs, history_ids)`` tuple: the (user, bot) message pairs to show
        in the Chatbot, and the updated token history to store in the state.
    """
    if not user_msg or not user_msg.strip():
        # Nothing to send: keep showing the existing conversation instead of
        # clearing the chat window.
        return _history_to_pairs(history_ids), history_ids or []

    eos_id = tokenizer.eos_token_id
    new_ids = tokenizer.encode(user_msg + tokenizer.eos_token)
    if len(new_ids) > PROMPT_BUDGET:
        # An over-long message is cut but still terminated with EOS.
        new_ids = new_ids[: PROMPT_BUDGET - 1] + [eos_id]

    past_ids = _trim_history(
        list(history_ids[0]) if history_ids else [],
        PROMPT_BUDGET - len(new_ids),
    )
    bot_input = torch.tensor([past_ids + new_ids], dtype=torch.long)

    output_ids = model.generate(
        bot_input,
        # pad == eos here, so generate() cannot infer the mask by itself.
        attention_mask=torch.ones_like(bot_input),
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=eos_id,
        do_sample=True,
        temperature=0.7,
        top_p=0.92,
        top_k=50,
    )[0].tolist()

    # A reply that hit the length limit has no EOS; add one so the next user
    # message starts a new turn instead of being glued onto this reply.
    if output_ids[-1] != eos_id:
        output_ids.append(eos_id)

    history_ids = [output_ids]
    return _history_to_pairs(history_ids), history_ids


# ─── Gradio UI ───────────────────────────────────────────────────────────────
with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown(f"\n\n{TITLE}\n\n")
    gr.Markdown(DESCRIPTION)

    chatbot = gr.Chatbot(height=450)
    state = gr.State([])

    with gr.Row(equal_height=True):
        prompt = gr.Textbox(placeholder="Ask Rick anything…", scale=9, show_label=False)
        send = gr.Button("Send", scale=1, variant="primary")

    # send on click or ↵
    send.click(chat, inputs=[prompt, state], outputs=[chatbot, state])
    prompt.submit(chat, inputs=[prompt, state], outputs=[chatbot, state])

    gr.Examples([["How are you, Rick?"], ["Tell me a joke!"]], inputs=prompt)
    gr.Markdown(ARTICLE)


if __name__ == "__main__":
    demo.launch()