from __future__ import annotations import spaces import gradio as gr from threading import Thread from transformers import TextIteratorStreamer import hashlib import os from transformers import AutoModel, AutoProcessor import torch import sys import subprocess from PIL import Image import time subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'packaging']) subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'ninja']) subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'mamba-ssm']) subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'causal-conv1d']) from cobra import load vlm = load("cobra+3b") if torch.cuda.is_available(): DEVICE = "cuda" DTYPE = torch.float32 else: DEVICE = "cpu" DTYPE = torch.float32 vlm.enable_mixed_precision_training = False vlm.to(DEVICE, dtype=DTYPE) prompt_builder = vlm.get_prompt_builder() @spaces.GPU def bot_streaming(message, history, temperature, top_k, max_new_tokens): if len(history) == 0: prompt_builder.prompt, prompt_builder.turn_count = "", 0 if message["files"]: image = message["files"][-1]["path"] else: # if there's no image uploaded for this turn, look for images in the past turns # kept inside tuples, take the last one for hist in history: if type(hist[0])==tuple: image = hist[0][0] image = Image.open(image).convert("RGB") prompt_builder.add_turn(role="human", message=message['text']) prompt_text = prompt_builder.get_prompt() # Generate from the VLM with torch.no_grad(): generated_text = vlm.generate( image, prompt_text, cg=True, do_sample=True, temperature=temperature, top_k=top_k, max_new_tokens=max_new_tokens, ) prompt_builder.add_turn(role="gpt", message=generated_text) time.sleep(0.04) yield generated_text demo = gr.ChatInterface(fn=bot_streaming, additional_inputs=[gr.Slider(0, 1, value=0.2, label="Temperature"), gr.Slider(1, 3, value=1, step=1, label="Top k"), gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens")], title="Cobra", description="Try [Cobra](https://huggingface.co/papers/2403.14520) in this demo. Upload an image and start chatting about it.", stop_btn="Stop Generation", multimodal=True) demo.launch(debug=True)