Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer | |
| import requests | |
| from PIL import Image | |
| import torch | |
CHECKPOINT = "adalbertojunior/image_captioning_portuguese"


def _cache_resource(func):
    """Cache *func*'s return value across Streamlit reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the checkpoint was reloaded from disk on every rerun and
    every button click. Uses ``st.cache_resource`` when the installed
    Streamlit provides it and falls back to the legacy
    ``st.cache(allow_output_mutation=True)`` otherwise.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def get_model():
    """Return the ``VisionEncoderDecoderModel`` for ``CHECKPOINT`` in eval mode.

    The instance is cached and shared between reruns; callers must not
    mutate it.
    """
    model = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
    model.eval()
    return model


@_cache_resource
def _load_preprocessors():
    """Return the ``(feature_extractor, tokenizer)`` pair for ``CHECKPOINT``."""
    return (
        AutoFeatureExtractor.from_pretrained(CHECKPOINT),
        AutoTokenizer.from_pretrained(CHECKPOINT),
    )


# Module-level names kept for the rest of the script (generate_caption uses them).
feature_extractor, tokenizer = _load_preprocessors()
st.title("Image Captioning with ViT & GPT2 🇧🇷")

# Sidebar controls for ``model.generate``. Streamlit reruns the script on every
# interaction, so these module-level names always hold the current widget values.
st.sidebar.markdown("## Generation parameters")
max_length = st.sidebar.number_input("Max length", value=20, min_value=1)
no_repeat_ngram_size = st.sidebar.number_input("no repeat ngrams size", value=2, min_value=1)
num_return_sequences = st.sidebar.number_input("Generated sequences", value=3, min_value=1)

gen_mode = st.sidebar.selectbox("Generation mode", ["beam search", "sampling"])
if gen_mode == "beam search":
    num_beams = st.sidebar.number_input("Beam size", value=5, min_value=1)
    early_stopping = st.sidebar.checkbox("Early stopping", value=True)
    gen_params = {
        "num_beams": num_beams,
        "early_stopping": early_stopping,
    }
else:  # "sampling" -- the only other selectbox option; `else` keeps gen_params always bound
    do_sample = True
    top_k = st.sidebar.number_input("top_k", value=30, min_value=0)
    # top_p is a cumulative probability: the widget must be float-typed and
    # bounded to [0, 1]. The former int widget (value=0) could not express
    # fractional values, and its default of 0 is either rejected by `generate`
    # (older transformers require a float) or shrinks the nucleus to one token.
    # 1.0 disables top-p filtering.
    top_p = st.sidebar.number_input("top_p", value=1.0, min_value=0.0, max_value=1.0)
    # `generate` requires a strictly positive temperature.
    temperature = st.sidebar.number_input("temperature", value=0.7, min_value=0.01)
    gen_params = {
        "do_sample": do_sample,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
    }
def generate_caption(url):
    """Download the image at *url* and return a caption generated for it.

    Generation uses the current sidebar settings read from module scope
    (``max_length``, ``no_repeat_ngram_size``, ``num_return_sequences`` and
    ``gen_params``). Only the first decoded sequence is returned.

    Raises:
        requests.RequestException: the download failed, timed out, or the
            server answered with an HTTP error status.
        PIL.UnidentifiedImageError: the response body is not a decodable image.
    """
    import io

    # The timeout stops a dead host from hanging the app indefinitely.
    # raise_for_status keeps an HTML error page from reaching PIL.
    # Reading `.content` instead of `.raw` lets requests undo any gzip/deflate
    # Content-Encoding and releases the connection once the body is consumed.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content)).convert("RGB")
    inputs = feature_extractor(image, return_tensors="pt")

    model = get_model()
    model.eval()

    n_sequences = int(num_return_sequences)
    if not gen_params.get("do_sample", False):
        # Deterministic (beam/greedy) search can return at most `num_beams`
        # sequences; `generate` raises a ValueError if asked for more.
        n_sequences = min(n_sequences, int(gen_params.get("num_beams", 1)))

    # Pass the sidebar values through instead of hard-coded constants, so the
    # "Max length", "no repeat ngrams size" and "Generated sequences" widgets
    # actually take effect.
    generated_ids = model.generate(
        inputs["pixel_values"],
        max_length=int(max_length),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        num_return_sequences=n_sequences,
        **gen_params,
    )
    captions = tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
    )
    return captions[0]
url = st.text_input(
    "Insert your URL", "https://static.cdn.pleno.news/2017/09/avi%C3%A3o-e1572374124339.jpg"
)

# An empty URL would make st.image and generate_caption fail, so only render
# the preview and the button once something has been entered.
if url:
    st.image(url)
    if st.button("Run captioning"):
        with st.spinner("Processing image..."):
            try:
                caption = generate_caption(url)
            except (requests.RequestException, OSError) as err:
                # requests errors cover network/HTTP failures; PIL raises an
                # OSError subclass when the bytes are not a readable image.
                st.error(f"Could not caption this image: {err}")
            else:
                st.text(caption)

st.text("Research supported with Cloud TPUs from Google's TPU Research Cloud (TRC)")