"""Gradio chat demo for google/gemma-2-9b-it with streamed generation.

Runs on Hugging Face ZeroGPU Spaces: the model is loaded at import time and
each generation request is dispatched to a GPU worker via ``@spaces.GPU``.
"""

import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DESCRIPTION = """\
# Gemma 2 9B IT

Gemma 2 is Google's latest iteration of open LLMs. This is a demo of [`google/gemma-2-9b-it`](https://huggingface.co/google/gemma-2-9b-it), fine-tuned for instruction following. For more details, please check [our post](https://huggingface.co/blog/gemma2).

👉 Looking for a larger and more powerful version? Try the 27B version in [HuggingChat](https://huggingface.co/chat/models/google/gemma-2-27b-it).
"""

# Hard ceiling for the "Max new tokens" slider.
MAX_NEW_TOKENS_LIMIT = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Total context budget (prompt + generated tokens); overridable via env var.
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "4096"))

MODEL_ID = "google/gemma-2-9b-it"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# Gemma 2 uses interleaved sliding-window attention; pin the window explicitly.
model.config.sliding_window = 4096
model.eval()


@spaces.GPU(duration=90)
def _generate_on_gpu(
    input_ids: torch.Tensor,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
) -> Iterator[str]:
    """Run model.generate on the GPU worker, yielding the partial text so far.

    Generation runs in a background thread so the streamer can be consumed
    here; any exception raised inside generate() is captured and re-raised
    as a gr.Error after the stream drains.
    """
    input_ids = input_ids.to(model.device)
    # skip_prompt: stream only newly generated tokens, not the echoed prompt.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
        # NOTE(review): presumably disables torch.compile inside generate()
        # on the ZeroGPU worker — confirm against the transformers version in use.
        "disable_compile": True,
    }

    # generate() runs in a worker thread; surface its failure to the caller
    # instead of letting it die silently while the streamer times out.
    exception_holder: list[Exception] = []

    def _generate() -> None:
        try:
            model.generate(**generate_kwargs)
        except Exception as e:  # noqa: BLE001
            exception_holder.append(e)

    thread = Thread(target=_generate)
    thread.start()

    chunks: list[str] = []
    for text in streamer:
        chunks.append(text)
        # Yield the accumulated text so the UI shows a growing response.
        yield "".join(chunks)
    thread.join()

    if exception_holder:
        msg = f"Generation failed: {exception_holder[0]}"
        raise gr.Error(msg)


def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Validate the request, build the prompt, and stream the model's reply.

    Args:
        message: The new user message.
        chat_history: Prior turns in Gradio "messages" format; each entry has
            a ``role`` and a ``content`` that is either a string or a list of
            typed parts (only ``text`` parts are kept).
        max_new_tokens / temperature / top_p / top_k / repetition_penalty:
            Sampling controls bound to the UI sliders.

    Raises:
        gr.Error: On empty input, an over-long prompt, or a full context window.
    """
    if not message or not message.strip():
        raise gr.Error("Please enter a message.")

    # Flatten multimodal-style content lists down to their text parts.
    conversation = []
    for msg in chat_history:
        content = msg["content"]
        if isinstance(content, list):
            text = "".join(part["text"] for part in content if part.get("type") == "text")
        else:
            text = content
        conversation.append({"role": msg["role"], "content": text})
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt", return_dict=True
    ).input_ids
    n_input_tokens = input_ids.shape[1]
    if n_input_tokens > MAX_INPUT_TOKENS:
        err_msg = f"Input too long ({n_input_tokens} tokens). Maximum is {MAX_INPUT_TOKENS} tokens."
        raise gr.Error(err_msg)

    # Keep prompt + generation within the overall context budget.
    max_new_tokens = min(max_new_tokens, MAX_INPUT_TOKENS - n_input_tokens)
    if max_new_tokens <= 0:
        raise gr.Error("Input uses the entire context window. No room to generate new tokens.")

    yield from _generate_on_gpu(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )


demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_NEW_TOKENS_LIMIT,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=False,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
    cache_examples=False,
    description=DESCRIPTION,
    fill_height=True,
)

if __name__ == "__main__":
    demo.launch(css_paths="style.css")