| """Template Demo for IBM Granite Hugging Face spaces.""" | |
| from collections.abc import Iterator | |
| from datetime import datetime | |
| from pathlib import Path | |
| from threading import Thread | |
| import gradio as gr | |
| import PIL | |
| import spaces | |
| import torch | |
| from PIL.Image import Image as PILImage | |
| from PIL.Image import Resampling | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoModelForVision2Seq, | |
| AutoProcessor, | |
| AutoTokenizer, | |
| LlavaNextForConditionalGeneration, | |
| LlavaNextProcessor, | |
| TextIteratorStreamer, | |
| ) | |
| from themes.research_monochrome import theme | |

dir_ = Path(__file__).parent.parent
today_date = datetime.today().strftime("%B %-d, %Y")  # noqa: DTZ002

MODEL_ID = "ibm-granite/granite-vision-3.2-2b"
MODEL_ID_PREVIEW = "ibm-granite/granite-vision-3.1-2b-preview"
# SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
# Today's Date: {today_date}.
# You are Granite, developed by IBM. You are a helpful AI assistant"""
TITLE = "IBM Granite Vision 3.2 2b"
DESCRIPTION = "Try one of the sample prompts below or write your own. Remember, \
AI models can make mistakes."
MAX_INPUT_TOKEN_LENGTH = 4096
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05
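# Decoding defaults: TEMPERATURE rescales the logits before sampling, TOP_P
# restricts sampling to the smallest token set with probability mass >= 0.85,
# TOP_K caps that set at 50 candidates, and REPETITION_PENALTY > 1 discourages
# repeating earlier tokens.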

sample_data = [
    [
        "https://www.ibm.com/design/language/static/159e89b3d8d6efcb5db43f543df36b23/a5df1/rebusgallery_tshirt.png",
        ["What are the three symbols on the tshirt?"],
    ],
    [
        str(dir_ / "data" / "p2-report.png"),
        [
            "What's the difference in rental income between 2020 and 2019?",
            "Which table entries are less in 2020 than 2019?",
        ],
    ],
    [
        "https://www.ibm.com/design/language/static/159e89b3d8d6efcb5db43f543df36b23/a5df1/rebusgallery_tshirt.png",
        ["What's this?"],
    ],
]
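
# Prefer CUDA, then Apple-silicon MPS, then fall back to CPU.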
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

# Model, processor and the currently selected image are populated lazily.
processor: LlavaNextProcessor | None = None
model: LlavaNextForConditionalGeneration | None = None
selected_image: PILImage | None = None

def image_changed(im: PILImage | None) -> None:
    """Keep a downscaled copy of the image currently selected in the UI."""
    global selected_image
    if im is None:
        selected_image = None
    else:
        selected_image = im.copy()
        # thumbnail() resizes in place to fit within 800x800, keeping the aspect ratio.
        selected_image.thumbnail((800, 800))

def create_single_turn(image: PILImage | None, text: str) -> dict:
    """Build a single user turn in the chat-template conversation format."""
    if image is None:
        return {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
            ],
        }
    return {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": text},
        ],
    }
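
# For example, create_single_turn(img, "What's this?") returns
#   {"role": "user", "content": [{"type": "image", "image": img},
#                                {"type": "text", "text": "What's this?"}]},
# which is the conversation format expected by processor.apply_chat_template.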

def generate(
    image: PILImage | None,
    message: str,
    chat_history: list[dict],
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
):
    """Generate function for the chat demo.

    Args:
        image: The image the conversation is about, or None for a text-only turn.
        message (str): The latest input message from the user.
        chat_history (list[dict]): Previous chat turns, each a dict with 'role' and 'content'.
        temperature (float): Sampling temperature.
        repetition_penalty (float): Penalty (> 1) applied to already-generated tokens.
        top_p (float): Nucleus-sampling probability mass.
        top_k (int): Number of highest-probability tokens considered at each step.
        max_new_tokens (int): Maximum number of tokens to generate.

    Returns:
        list[gr.ChatMessage]: The user message and the assistant's response.
    """
    # Build the conversation
    conversation = []
    # TODO: maybe add back custom sys prompt
    # conversation.append({"role": "system", "content": SYS_PROMPT})
    conversation += chat_history
    conversation.append(create_single_turn(image, message))
    # Convert messages to prompt format
    inputs = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
    ).to(device)
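    # With tokenize=True and return_dict=True the processor returns a dict-like
    # BatchFeature holding input_ids, attention_mask and the preprocessed image
    # tensors, ready to be unpacked into model.generate.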
    # TODO: This might cut out the image tokens -- find a better strategy
    # if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
    #     input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    #     gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    output = model.generate(**inputs, **generate_kwargs)
    # The decoded sequence contains the whole conversation; keep only the text
    # after the final assistant marker.
    out = processor.decode(output[0], skip_special_tokens=True)
    out_s = out.strip().split("<|assistant|>")
    return [gr.ChatMessage(role="user", content=message), gr.ChatMessage(role="assistant", content=out_s[-1])]

def multimodal_generate_v2(
    msg: str,
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
):
    """Answer a question about the currently selected image, loading the model lazily."""
    global model, processor
    # Lazy loading: defer the expensive model download/initialization to the
    # first request instead of paying for it at import time.
    if model is None:
        processor = AutoProcessor.from_pretrained(MODEL_ID)
        # device_map="auto" already places the weights on the available device,
        # so no additional .to(device) call is needed here.
        model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, device_map="auto")
    return generate(
        selected_image,
        msg,
        [],
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
    )
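
# A typical invocation from the UI handlers below is
# multimodal_generate_v2("What's this?", 0.7, 1.05, 0.85, 50, 1024), which
# returns a [user ChatMessage, assistant ChatMessage] pair for the chatbot.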
# Advanced settings (displayed in an Accordion)
temperature_slider = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=TEMPERATURE,
    step=0.1,
    label="Temperature",
    elem_classes=["gr_accordion_element"],
    interactive=True,
)
top_p_slider = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=TOP_P,
    step=0.05,
    label="Top P",
    elem_classes=["gr_accordion_element"],
    interactive=True,
)
top_k_slider = gr.Slider(
    minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"], interactive=True
)
repetition_penalty_slider = gr.Slider(
    minimum=0,
    maximum=2.0,
    value=REPETITION_PENALTY,
    step=0.05,
    label="Repetition Penalty",
    elem_classes=["gr_accordion_element"],
    interactive=True,
)
max_new_tokens_slider = gr.Slider(
    minimum=1,
    maximum=2000,
    value=MAX_NEW_TOKENS,
    step=1,
    label="Max New Tokens",
    elem_classes=["gr_accordion_element"],
    interactive=True,
)
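
# The sliders above (and the chatbot below) are created outside the Blocks
# context, so they stay unrendered until .render() is called at the desired
# spot in the layout.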
chatbot = gr.Chatbot(examples=[{"text": "Hello World!"}], type="messages", label="Q&A about selected document")

css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"

with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
    is_in_edit_mode = gr.State(True)  # created inside the Blocks context so it stays reactive
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            # Create the sample image component for reference; render it later
            image_x = gr.Image(
                type="pil",
                label="Example image",
                render=False,
                interactive=False,
                show_label=False,
                show_fullscreen_button=False,
                height=800,
            )
            image_x.change(fn=image_changed, inputs=image_x)

            # Create the Dataset object and render it
            ds = gr.Dataset(label="Select one document", samples=sample_data, components=[gr.Image(render=False)])

            def sample_image_selected(d: gr.SelectData, dx):
                # dx is the selected sample row: [image, [suggested questions]]
                return gr.Image(dx[0]), gr.update(examples=[{"text": x} for x in dx[1]])

            ds.select(lambda: [], outputs=[chatbot])
            ds.select(sample_image_selected, inputs=[ds], outputs=[image_x, chatbot])

            # Render the image component after the Dataset
            image_x.render()
        with gr.Column():
            # Render the chatbot
            chatbot.render()

            # Define behavior for example selection
            def update_user_chat_x(x: gr.SelectData):
                return [gr.ChatMessage(role="user", content=x.value["text"])]

            def send_generate_x(x: gr.SelectData, temperature, repetition_penalty, top_p, top_k, max_new_tokens):
                txt = x.value["text"]
                return multimodal_generate_v2(txt, temperature, repetition_penalty, top_p, top_k, max_new_tokens)
            chatbot.example_select(lambda: False, outputs=is_in_edit_mode)
            chatbot.example_select(update_user_chat_x, outputs=[chatbot])
            chatbot.example_select(
                send_generate_x,
                inputs=[
                    temperature_slider,
                    repetition_penalty_slider,
                    top_p_slider,
                    top_k_slider,
                    max_new_tokens_slider,
                ],
                outputs=[chatbot],
            )
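            # Note: the three example_select listeners above (and the two
            # tbb.submit listeners below) appear to run in registration order:
            # leave edit mode, show the user turn, then generate the answer.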
            # Create the user chat textbox and reset button
            tbb = gr.Textbox(submit_btn=True, show_label=False)
            fb = gr.Button("Reset Chat", visible=False)
            fb.click(lambda: [], outputs=[chatbot])

            # Handle toggling between edit and non-edit mode
            def textbox_switch(emode):
                if not emode:
                    return [gr.update(visible=False), gr.update(visible=True)]
                return [gr.update(visible=True), gr.update(visible=False)]

            tbb.submit(lambda: False, outputs=[is_in_edit_mode])
            fb.click(lambda: True, outputs=[is_in_edit_mode])
            is_in_edit_mode.change(textbox_switch, inputs=[is_in_edit_mode], outputs=[tbb, fb])
            # Submit the user question
            tbb.submit(lambda x: [gr.ChatMessage(role="user", content=x)], inputs=tbb, outputs=chatbot)
            tbb.submit(
                multimodal_generate_v2,
                inputs=[
                    tbb,
                    temperature_slider,
                    repetition_penalty_slider,
                    top_p_slider,
                    top_k_slider,
                    max_new_tokens_slider,
                ],
                outputs=[chatbot],
            )
            # Extra model parameters
            with gr.Accordion("Advanced Settings", open=False):
                max_new_tokens_slider.render()
                temperature_slider.render()
                top_k_slider.render()
                top_p_slider.render()
                repetition_penalty_slider.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
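
# A minimal sketch of querying the same checkpoint without the UI (a hedged
# example, assuming a local "doc.png" exists; not executed by the Space):
#
#   from PIL import Image
#   processor = AutoProcessor.from_pretrained(MODEL_ID)
#   model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, device_map="auto")
#   conv = [create_single_turn(Image.open("doc.png"), "Summarize this page.")]
#   inputs = processor.apply_chat_template(
#       conv, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
#   ).to(model.device)
#   out = model.generate(**inputs, max_new_tokens=128)
#   print(processor.decode(out[0], skip_special_tokens=True))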