import argparse
import os
from queue import SimpleQueue
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from gradio import Chatbot
from image_utils import ImageStitcher
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          TextIteratorStreamer)

from StreamDiffusionIO import LatentConsistencyModelStreamIO
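
# Token limits for generation; MAX_INPUT_TOKEN_LENGTH can be overridden via an
# environment variable of the same name.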
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

DESCRIPTION = """\
# Kanji-Streaming Chat

This Space is adapted from the [Llama-2-7b-chat](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat) Space, demonstrating how to "chat" with an LLM via [Kanji-Streaming](https://github.com/AgainstEntropy/kanji).

The technique behind Kanji-Streaming is [StreamDiffusionIO](https://github.com/AgainstEntropy/StreamDiffusionIO), which is based on [StreamDiffusion](https://github.com/cumulo-autumn/StreamDiffusion) *but additionally allows rendering text streams into image streams*.

For more details about Kanji-Streaming, take a look at the [GitHub repository](https://github.com/AgainstEntropy/kanji).
"""

LICENSE = """
<p/>

---
As a derivative work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""
parser = argparse.ArgumentParser(description="Gradio launcher for Kanji-Streaming.")
parser.add_argument(
    "--llama_model_id_or_path",
    type=str,
    default="meta-llama/Llama-2-7b-chat-hf",
    required=False,
    help="Path to a downloaded llama-chat-hf model or a model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--sd_model_id_or_path",
    type=str,
    default="stable-diffusion-v1-5/stable-diffusion-v1-5",
    required=False,
    help="Path to a downloaded sd-1-5 model or a model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--lora_path",
    type=str,
    default="AgainstEntropy/kanji-lora-sd-v1-5",
    required=False,
    help="Path to downloaded LoRA weights or a model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--lcm_lora_path",
    type=str,
    default="AgainstEntropy/kanji-lcm-lora-sd-v1-5",
    required=False,
    help="Path to downloaded LCM-LoRA weights or a model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--img_res",
    type=int,
    default=64,
    required=False,
    help="Image resolution for displaying Kanji characters in the Chatbot.",
)
parser.add_argument(
    "--img_per_line",
    type=int,
    default=16,
    required=False,
    help="Number of Kanji characters to display in a single line.",
)
parser.add_argument(
    "--tmp_dir",
    type=str,
    default="./tmp",
    required=False,
    help="Path to save temporary images generated by StreamDiffusionIO.",
)
args = parser.parse_args()

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo works best on GPU.</p>"
DESCRIPTION += "\n<p>Kanji streaming works best when this demo is accessed via localhost (or SSH port forwarding) rather than through a Gradio share link.</p>"

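# Load the chat model in half precision; device_map="auto" places it on the
# available GPU (or falls back to CPU).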
model = AutoModelForCausalLM.from_pretrained(
    args.llama_model_id_or_path, torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(args.llama_model_id_or_path)
tokenizer.use_default_system_prompt = False
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

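# Diffusion stream from StreamDiffusionIO: it consumes text pieces and emits a
# Kanji-style glyph image per piece, combining the Kanji LoRA with an LCM-LoRA
# for few-step latent-consistency sampling.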
lcm_stream = LatentConsistencyModelStreamIO(
    model_id_or_path=args.sd_model_id_or_path,
    lcm_lora_path=args.lcm_lora_path,
    lora_dict={args.lora_path: 1},
    resolution=128,
    device=device,
    use_xformers=True,
    verbose=True,
)

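# Each response writes its images into its own numbered temporary directory;
# the stitcher tiles individual glyphs into rows of `img_per_line` characters
# at `img_res` resolution.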
tmp_dir_template = f"{args.tmp_dir}/%d"
response_num = 0

stitcher = ImageStitcher(
    tmp_dir=tmp_dir_template % response_num,
    img_res=args.img_res,
    img_per_line=args.img_per_line,
    verbose=True,
)

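# Chat callback: streams tokens from the LLM in a background thread, feeds each
# text piece through the diffusion stream, and yields stitched image paths so
# the Chatbot updates as the Kanji response grows.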
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    show_original_response: bool,
    seed: int,
    system_prompt: str = '',
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str | tuple]:
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        if isinstance(assistant, tuple):
            # Image responses are stored as tuples; the text part is last.
            assistant = assistant[-1]
        else:
            assistant = str(assistant)
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": message})
    print(conversation)

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
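    # Keep only the most recent tokens so the prompt fits the model's context
    # window.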
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
    )
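
    # Generation runs in a background thread; `streamer` yields decoded text
    # pieces as soon as they are produced.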
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    prompt_queue = SimpleQueue()

    lcm_stream.reset(seed)
    stitcher.reset()
    global response_num
    response_num += 1
    stitcher.update_tmp_dir(tmp_dir_template % response_num)
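
    # Producer: push each streamed text piece into the queue; `None` marks the
    # end of the stream.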
    def append_to_queue():
        for text in streamer:
            outputs.append(text)
            prompt = text.strip()
            if prompt:
                if prompt.endswith("."):
                    prompt = prompt[:-1]
                prompt_queue.put(prompt)
        prompt_queue.put(None)

    append_thread = Thread(target=append_to_queue)
    append_thread.start()
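
    # Feed one text piece into the diffusion stream (or flush it when called
    # without a prompt) and return the stitched image path, if one is ready.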
    def show_image(prompt: str | None = None):
        image, text = lcm_stream(prompt)
        img_path = None
        if image is not None:
            img_path = stitcher.add(image, text)
        return img_path
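
    # Consumer: render an image for every queued text piece and yield its path
    # so the chat UI updates incrementally.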
    while True:
        prompt = prompt_queue.get()
        if prompt is None:
            break
        img_path = show_image(prompt)
        if img_path is not None:
            yield (img_path,)

    # Continue to display the remaining images
    while True:
        img_path = show_image()
        if img_path is not None:
            yield (img_path, ''.join(outputs))
        if lcm_stream.stop():
            break

    print(outputs)
    if show_original_response:
        yield ''.join(outputs)
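
# Chat UI wiring: `generate` streams image paths (and optionally the raw text
# response) back into the Chatbot.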
chat_interface = gr.ChatInterface(
    fn=generate,
    chatbot=Chatbot(height=400),
    additional_inputs=[
        gr.Checkbox(
            label="Show original response",
            value=False,
        ),
        gr.Number(
            label="Seed",
            info="Random seed for Kanji generation (maybe some kind of accent)",
            step=1,
            value=1026,
        ),
        gr.Textbox(label="System prompt", lines=4),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", share=False)
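
# Example launch (assuming this file is saved as app.py and the dependencies
# from the linked repositories are installed):
#   python app.py --img_res 64 --img_per_line 16 --tmp_dir ./tmp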