Spaces:
Paused
Paused
| # Copyright 2023 MosaicML spaces authors | |
| # SPDX-License-Identifier: Apache-2.0 | |
| from typing import Optional | |
| import datetime | |
| import os | |
| from threading import Event, Thread | |
| from uuid import uuid4 | |
| import gradio as gr | |
| import requests | |
| import torch | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| StoppingCriteria, | |
| StoppingCriteriaList, | |
| TextIteratorStreamer, | |
| ) | |
# Hugging Face Hub id of the checkpoint served by this app.
model_name = "JosephusCheung/Guanaco"
# Upper bound on the number of tokens generated for a single reply.
max_new_tokens = 2048

print(f"Starting to load the model {model_name} into memory")
tok = AutoTokenizer.from_pretrained(model_name)
m = AutoModelForCausalLM.from_pretrained(model_name)
# Put the model in evaluation mode; nn.Module.eval() returns the module
# itself, so calling it as a statement is equivalent to chaining it.
m.eval()
# Token ids that end generation early (checked by StopOnTokens). Only the
# tokenizer's EOS id is used; extra markers such as "<|im_end|>" could be
# added via tok.convert_tokens_to_ids([...]) if the model emits them.
stop_token_ids = [tok.eos_token_id]
print(f"Successfully loaded the model {model_name} into memory")
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when a configured stop token is produced.

    The ids are read from the module-level ``stop_token_ids`` list. Only the
    last token of the first sequence in the batch (``input_ids[0][-1]``) is
    inspected.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # any() short-circuits on the first match, like the explicit loop.
        return any(input_ids[0][-1] == candidate for candidate in stop_token_ids)
# Alpaca-style instruction templates, filled in by generate_input().
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}


def generate_input(instruction: Optional[str] = None, input_str: Optional[str] = None) -> str:
    """Render an Alpaca-style prompt.

    Args:
        instruction: Task text placed under ``### Instruction:``.
        input_str: Optional extra context placed under ``### Input:``.
            ``None`` or an empty string selects the template without an
            input section.

    Returns:
        The complete prompt text, ending with ``### Response:``.
    """
    # Treat "" the same as None so a blank context does not produce an empty
    # "### Input:" section.
    if not input_str:
        return PROMPT_DICT["prompt_no_input"].format_map({"instruction": instruction})
    return PROMPT_DICT["prompt_input"].format_map({"instruction": instruction, "input": input_str})
def convert_history_to_text(history):
    """Build the model prompt from the chat history.

    Only the user message of the newest turn (``history[-1][0]``) is used.
    Earlier turns are not included, so each request is answered as a
    single-turn instruction.
    """
    newest_turn = history[-1]
    return generate_input(newest_turn[0])
def log_conversation(conversation_id, history, messages, generate_kwargs, timeout=10.0):
    """POST one chat exchange to the endpoint named by ``$LOGGING_URL``.

    Logging is best-effort. When ``LOGGING_URL`` is unset or empty, nothing
    is sent. Network and HTTP errors are printed rather than raised, so a
    logging failure never propagates to the caller.

    Args:
        conversation_id: Opaque id identifying the chat session.
        history: Chat turns as ``[user, assistant]`` pairs.
        messages: The prompt text that was sent to the model.
        generate_kwargs: Sampling settings to record with the exchange.
        timeout: Seconds to wait for the logging server. Without a timeout,
            ``requests`` may wait indefinitely, pinning the calling thread.
    """
    logging_url = os.getenv("LOGGING_URL")
    if not logging_url:
        return
    # Naive local time with second resolution.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    data = {
        "conversation_id": conversation_id,
        "timestamp": timestamp,
        "history": history,
        "messages": messages,
        "generate_kwargs": generate_kwargs,
    }
    try:
        response = requests.post(logging_url, json=data, timeout=timeout)
        # Report non-2xx replies as well; HTTPError subclasses RequestException.
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"Error logging conversation: {e}")
def user(message, history):
    """Append the user's message to the chat as a new, unanswered turn.

    Returns ``("", new_history)``. The empty string clears the input box.
    ``new_history`` is a new list ending with ``[message, ""]``, so the
    caller's ``history`` is not mutated. A ``None`` history is treated as an
    empty conversation, e.g. after the chatbot value was reset to ``None``.
    """
    return "", [*(history or []), [message, ""]]
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    """Stream the model's reply to the newest user message.

    ``m.generate`` runs in a worker thread and feeds a TextIteratorStreamer.
    This generator consumes the streamer, writes the growing reply into
    ``history[-1][1]`` and yields ``history`` after each chunk. A second
    thread waits for generation to finish and then passes the exchange to
    ``log_conversation``.

    Args:
        history: Chat turns as ``[user, assistant]`` pairs; the last pair's
            assistant slot is filled in place.
        temperature: Sampling temperature; ``do_sample`` is off when it is 0.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k cutoff (0 disables it); coerced to ``int``.
        repetition_penalty: Penalty for repeated tokens (1.0 disables it).
        conversation_id: Session id forwarded to the logger.

    Yields:
        The same ``history`` list after each newly streamed chunk.
    """
    print(f"history: {history}")
    stop = StopOnTokens()
    # Only the newest user message becomes the prompt (convert_history_to_text
    # reads history[-1][0]); earlier turns are not sent to the model.
    messages = convert_history_to_text(history)
    input_ids = tok(messages, return_tensors="pt").input_ids.to(m.device)
    # timeout bounds how long iteration waits for the next decoded chunk.
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        # Slider values may arrive as floats, but transformers' top-k filter
        # only accepts an int.
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    stream_complete = Event()

    def generate_and_signal_complete():
        # Signal completion even if generate() raises; otherwise the logging
        # thread would block on stream_complete.wait() forever.
        try:
            m.generate(**generate_kwargs)
        finally:
            stream_complete.set()

    def log_after_stream_complete():
        stream_complete.wait()
        log_conversation(
            conversation_id,
            history,
            messages,
            {
                "top_k": top_k,
                "top_p": top_p,
                "temperature": temperature,
                "repetition_penalty": repetition_penalty,
            },
        )

    Thread(target=generate_and_signal_complete).start()
    Thread(target=log_after_stream_complete).start()

    # Accumulate streamed chunks and push each partial reply to the UI.
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history
def get_uuid():
    """Return a new random (version 4) UUID as its canonical string."""
    fresh_id = uuid4()
    return str(fresh_id)
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    # Per-session id for the conversation logger; gr.State calls get_uuid to
    # produce each session's initial value.
    conversation_id = gr.State(get_uuid)
    # Derive the heading from model_name so it always names the checkpoint
    # that is actually loaded ("模型" = "model").
    gr.Markdown(f"## {model_name} 模型")
    chatbot = gr.Chatbot().style(height=500)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="输入您的问题",  # "Enter your question"
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("停止")  # "Stop"
                clear = gr.Button("Clear")
    with gr.Row():
        with gr.Accordion("高级选项:", open=False):  # "Advanced options"
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        temperature = gr.Slider(
                            label="Temperature",
                            value=0.1,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            interactive=True,
                            info="Higher values produce more diverse outputs",
                        )
                with gr.Column():
                    with gr.Row():
                        top_p = gr.Slider(
                            label="Top-p (nucleus sampling)",
                            value=1.0,
                            minimum=0.0,
                            maximum=1,
                            step=0.01,
                            interactive=True,
                            info=(
                                "Sample from the smallest possible set of tokens whose cumulative probability "
                                "exceeds top_p. Set to 1 to disable and sample from all tokens."
                            ),
                        )
                with gr.Column():
                    with gr.Row():
                        top_k = gr.Slider(
                            label="Top-k",
                            value=0,
                            minimum=0.0,
                            maximum=200,
                            step=1,
                            interactive=True,
                            info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                        )
                with gr.Column():
                    with gr.Row():
                        repetition_penalty = gr.Slider(
                            label="Repetition Penalty",
                            value=1.1,
                            minimum=1.0,
                            maximum=2.0,
                            step=0.1,
                            interactive=True,
                            info="Penalize repetition — 1.0 to disable.",
                        )

    def _chain_user_then_bot(listener):
        """Attach the append-message-then-stream-reply chain to `listener`.

        `user` runs unqueued so the message appears immediately; `bot` then
        streams the reply through the queue. Returns the chained event so it
        can be cancelled.
        """
        return listener(
            fn=user,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
            queue=False,
        ).then(
            fn=bot,
            inputs=[
                chatbot,
                temperature,
                top_p,
                top_k,
                repetition_penalty,
                conversation_id,
            ],
            outputs=chatbot,
            queue=True,
        )

    # Pressing Enter in the textbox and clicking Submit behave identically.
    submit_event = _chain_user_then_bot(msg.submit)
    submit_click_event = _chain_user_then_bot(submit.click)
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    # Clearing also cancels an in-flight reply; otherwise the streaming bot
    # keeps yielding the old history and repopulates the chat.
    clear.click(
        lambda: None,
        None,
        chatbot,
        queue=False,
        cancels=[submit_event, submit_click_event],
    )

demo.queue(max_size=128, concurrency_count=2)
demo.launch()