Spaces:
Runtime error
Update search
app.py
CHANGED
@@ -2,34 +2,52 @@ import gradio as gr
 import random
 from datasets import load_dataset
 from sentence_transformers import SentenceTransformer, util
+import logging
+from PIL import Image
+
+# Create a custom logger
+logger = logging.getLogger(__name__)
+
+# Set the level of this logger. INFO means that it will log all INFO, WARNING, ERROR, and CRITICAL messages.
+logger.setLevel(logging.INFO)
+
+# Create handlers
+c_handler = logging.StreamHandler()
+c_handler.setLevel(logging.INFO)
+
+# Create formatters and add it to handlers
+c_format = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
+c_handler.setFormatter(c_format)
+
+# Add handlers to the logger
+logger.addHandler(c_handler)
+
+class SearchEngine:
+    def __init__(self):
+        self.model = SentenceTransformer('clip-ViT-B-32')
+        self.embedding_dataset = load_dataset("JLD/unsplash25k-image-embeddings", trust_remote_code=True, split="train").with_format("torch", device="cuda:0")
+        image_dataset = load_dataset("jamescalam/unsplash-25k-photos", trust_remote_code=True, revision="refs/pr/3")
+        self.image_dataset = {image["photo_id"]: image["photo_image_url"] for image in image_dataset["train"]}
+
+    def get_candidates(self, query_embedding, top_k=5):
+        logger.info("Getting candidates")
+        candidates = util.semantic_search(query_embeddings=query_embedding.unsqueeze(0), corpus_embeddings=self.embedding_dataset["image_embedding"].squeeze(1), top_k=top_k)[0]
+        return [self.image_dataset.get(self.embedding_dataset[candidate["corpus_id"]]["image_id"], "https://upload.wikimedia.org/wikipedia/commons/6/69/NASA-HS201427a-HubbleUltraDeepField2014-20140603.jpg") for candidate in candidates]
+
+    def search_images_from_text(self, text):
+        logger.info("Searching images from text")
+        emb = self.model.encode(text, convert_to_tensor=True, device="cuda:0")
+        return self.get_candidates(query_embedding=emb)
+
+    def search_images_from_image(self, image):
+        logger.info("Searching images from image")
+        emb = self.model.encode(Image.fromarray(image), convert_to_tensor=True, device="cuda:0")
+        return self.get_candidates(query_embedding=emb)

 def main():
+    logger.info("Loading dataset")
+    search_engine = SearchEngine()
+    text_to_image_iface = gr.Interface(fn=search_engine.search_images_from_text, inputs="text", outputs="gallery")
+    image_to_image_iface = gr.Interface(fn=search_engine.search_images_from_image, inputs="image", outputs="gallery")
     demo = gr.TabbedInterface([text_to_image_iface, image_to_image_iface], ["Text query", "Image query"])
     demo.launch()
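The banner at the top of the Space reads "Runtime error". One plausible cause for code like this is the hard-coded "cuda:0" device: both load_dataset(...).with_format("torch", device="cuda:0") and model.encode(..., device="cuda:0") raise at startup when the Space runs on CPU-only hardware. A minimal sketch of a device fallback (the DEVICE constant and the fallback logic are my own illustration, not part of this commit):

import torch
from sentence_transformers import SentenceTransformer

# Hypothetical fallback: use the GPU when one is visible, otherwise stay on CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

model = SentenceTransformer('clip-ViT-B-32', device=DEVICE)
# The same DEVICE value would then be passed to .with_format("torch", device=DEVICE)
# and to model.encode(..., device=DEVICE) inside SearchEngine.

On CPU-only hardware this keeps the app booting, at the cost of slower encoding.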