import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor


# Check whether flash_attn is installed so the optimized attention path can be used.
def is_flash_attn_available():
    try:
        import flash_attn  # noqa: F401
        return True
    except ImportError:
        return False


# Load model, tokenizer, and image processor.
@torch.inference_mode()
def load_model():
    # The optimized (flash-attention) path needs both a CUDA device and flash_attn.
    use_optimized = torch.cuda.is_available() and is_flash_attn_available()
    model = AutoModel.from_pretrained(
        "visheratin/mexma-siglip2",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        optimized=use_optimized,
    )
    if torch.cuda.is_available():
        model = model.to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip2")
    processor = AutoImageProcessor.from_pretrained("visheratin/mexma-siglip2")
    return model, tokenizer, processor


model, tokenizer, processor = load_model()
device = "cuda" if torch.cuda.is_available() else "cpu"


def classify_image(image, text_queries):
    if image is None or not text_queries.strip():
        return None

    # Process the image.
    processed_image = processor(images=image, return_tensors="pt")["pixel_values"]
    processed_image = processed_image.to(torch.bfloat16).to(device)

    # Process the text queries (one per line).
    queries = [q.strip() for q in text_queries.split("\n") if q.strip()]
    if not queries:
        return None
    text_inputs = tokenizer(queries, return_tensors="pt", padding=True).to(device)

    # Get predictions.
    with torch.inference_mode():
        image_logits, _ = model.get_logits(
            text_inputs["input_ids"],
            text_inputs["attention_mask"],
            processed_image,
        )
    probs = F.softmax(image_logits, dim=-1)[0].cpu().tolist()

    # gr.Label expects a mapping from label to a numeric confidence.
    return dict(zip(queries, probs))


# Create the Gradio interface.
with gr.Blocks(title="Mexma-SigLIP2 Zero-Shot Classification") as demo:
    gr.Markdown("# Mexma-SigLIP2 Zero-Shot Classification Demo")
    gr.Markdown("""
This demo showcases the zero-shot classification capabilities of Mexma-SigLIP2, a state-of-the-art model for multilingual zero-shot image classification.

### Instructions:
1. Upload or select an image.
2. Enter text queries (one per line) to classify the image.
3. Click 'Submit' to see the classification probabilities.

The model supports multilingual queries (e.g., English, Russian, Hindi).
""")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                placeholder="Enter text queries (one per line)\nExample:\na cat\na dog\nEiffel Tower",
                label="Text Queries",
                lines=5,
            )
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            output = gr.Label(label="Classification Results")

    submit_btn.click(
        fn=classify_image,
        inputs=[image_input, text_input],
        outputs=output,
    )

    gr.Examples(
        [
            [
                "https://static.independent.co.uk/s3fs-public/thumbnails/image/2014/03/25/12/eiffel.jpg",
                "Eiffel Tower\nStatue of Liberty\nTaj Mahal\nкошка\nएफिल टॉवर",
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg",
                "a cat\na dog\na bird\nкошка\nсобака",
            ],
        ],
        inputs=[image_input, text_input],
    )


# Launch the demo.
if __name__ == "__main__":
    demo.launch()
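
# Optional: a quick sanity check of classify_image without launching the UI.
# This is a minimal sketch; it assumes a local image file such as "cat.jpg"
# exists (the filename and queries below are illustrative only).
#
#   from PIL import Image
#   test_image = Image.open("cat.jpg").convert("RGB")
#   print(classify_image(test_image, "a cat\na dog\na bird"))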