adding the files to run the app
- app.py +161 -0
- model.py +75 -0
- requirements.txt +3 -0
- search.py +105 -0
- utils.py +34 -0
app.py
ADDED
@@ -0,0 +1,161 @@
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from search import search_similarity, process_image_for_encoder_gradio
from utils import str_to_bytes
from io import BytesIO

def add_ranking_number(image, rank):
    """Add a ranking number badge to the image."""
    img_with_rank = image.copy()
    draw = ImageDraw.Draw(img_with_rank)

    width, height = image.size
    circle_radius = min(width, height) // 15
    circle_position = (circle_radius + 10, circle_radius + 10)

    draw.ellipse(
        [(circle_position[0] - circle_radius, circle_position[1] - circle_radius),
         (circle_position[0] + circle_radius, circle_position[1] + circle_radius)],
        fill='white',
        outline='black'
    )

    font_size = circle_radius
    try:
        font = ImageFont.truetype("Arial.ttf", font_size)
    except OSError:
        font = ImageFont.load_default()

    text = str(rank + 1)
    text_bbox = draw.textbbox((0, 0), text, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    text_height = text_bbox[3] - text_bbox[1]
    text_position = (
        circle_position[0] - text_width // 2,
        circle_position[1] - text_height // 2
    )

    draw.text(text_position, text, fill='black', font=font)
    return img_with_rank

def process_image_result(image_str, rank):
    """Convert a base64-encoded image string into a PIL Image with a ranking badge."""
    try:
        img = Image.open(BytesIO(str_to_bytes(image_str)))
        return add_ranking_number(img, rank)
    except Exception as e:
        print(f"Error processing image: {e}")
        return None

def interface_fn(mode, input_text, input_image, top_k):
    results = None
    try:
        # Decide which input to use based on the mode
        if mode == "text":
            if not input_text.strip():
                return [], "Please enter a text to search with."
            input_data = input_text
        else:  # mode == "image"
            if input_image is None:
                return [], "Please upload an image to search with."
            input_data = process_image_for_encoder_gradio(input_image, is_bytes=False)

        # Show the input data
        print(f"Input data: {input_data}")  # for debugging

        # Run the similarity search
        results = search_similarity(input_data, mode, int(top_k))

        # Format the results depending on the mode
        if mode == "text":  # returns images
            processed_images = []
            # If results is a list of lists, flatten it
            if results and isinstance(results[0], list):
                print("Received a list of lists, flattening...")  # for debugging
                results = [item for sublist in results for item in sublist]

            for idx, img_str in enumerate(results):
                img = process_image_result(img_str, idx)
                if img is not None:
                    processed_images.append(img)

            if not processed_images:
                return [], "The images could not be processed"
            return processed_images, None

        else:  # mode == "image" - returns texts
            if isinstance(results, list):
                numbered_texts = [f"{i+1}. {text}" for i, text in enumerate(results)]
                return [], "\n\n".join(numbered_texts)
            else:
                return [], str(results)

    except Exception as e:
        print(f"Error in interface_fn: {str(e)}")
        print(f"Type of results: {type(results)}")  # for debugging
        return [], f"Error during the search: {str(e)}"


def search_text(input_text, top_k):
    try:
        if not input_text.strip():
            return []

        # Run the similarity search
        results = search_similarity(input_text, "text", int(top_k))

        processed_images = []
        # If results is a list of lists, flatten it
        if results and isinstance(results[0], list):
            results = [item for sublist in results for item in sublist]

        for idx, img_str in enumerate(results):
            img = process_image_result(img_str, idx)
            if img is not None:
                processed_images.append(img)

        return processed_images

    except Exception as e:
        print(f"Error in search_text: {str(e)}")
        return []

with gr.Blocks() as demo:
    gr.Markdown("# Text Similarity Search")

    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Search text",
                placeholder="Type your text here...",
                lines=3
            )

            top_k = gr.Slider(
                minimum=1,
                maximum=20,
                value=5,
                step=1,
                label="Number of results",
                info="How many similar results do you want to see?"
            )

            search_button = gr.Button("Search")

        with gr.Column(scale=1):
            output_gallery = gr.Gallery(
                label="Similar images",
                columns=3,
                height="auto"
            )

    search_button.click(
        fn=search_text,
        inputs=[input_text, top_k],
        outputs=output_gallery
    )

if __name__ == "__main__":
    from multiprocessing import freeze_support
    freeze_support()
    demo.launch()
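
Assuming the dependencies are installed and a PINECONE_API_KEY is available in the environment (the Pinecone() client in search.py reads it from there), the demo can also be launched locally with:

python app.py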
model.py
ADDED
@@ -0,0 +1,75 @@
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class TextEncoderHead(nn.Module):
    def __init__(self, model):
        super(TextEncoderHead, self).__init__()
        self.model = model
        for param in self.model.parameters():
            param.requires_grad = False
        # uncomment this for chemberta
        # self.seq1 = nn.Sequential(
        #     nn.Flatten(),
        #     nn.Linear(767*256, 2000),
        #     nn.Dropout(0.3),
        #     nn.ReLU(),
        #     nn.Linear(2000, 512),
        #     nn.LayerNorm(512)
        # )
        self.seq1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(768*256, 2000),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(2000, 512),
            nn.LayerNorm(512)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # uncomment this for chemberta
        # outputs = outputs.logits
        outputs = outputs.last_hidden_state
        outputs = self.seq1(outputs)
        return outputs.contiguous()

class ImageEncoderHead(nn.Module):
    def __init__(self, model):
        super(ImageEncoderHead, self).__init__()
        self.model = model
        for param in self.model.parameters():
            param.requires_grad = False
        # for resnet model
        # self.seq1 = nn.Sequential(
        #     nn.Flatten(),
        #     nn.Linear(512*7*7, 1000),
        #     nn.Linear(1000, 512),
        #     nn.LayerNorm(512)
        # )
        # for vit model
        self.seq1 = nn.Sequential(
            nn.Linear(768, 1000),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(1000, 512),
            nn.LayerNorm(512)
        )

    def forward(self, pixel_values):
        outputs = self.model(pixel_values)
        outputs = outputs.last_hidden_state.mean(dim=1)
        outputs = self.seq1(outputs)
        return outputs.contiguous()

class CLIPChemistryModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, text_encoder, image_encoder):
        super(CLIPChemistryModel, self).__init__()
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder

    def forward(self, image, input_ids, attention_mask):
        # calculate the embeddings
        ie = self.image_encoder(image)
        te = self.text_encoder(input_ids, attention_mask)
        return ie, te
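
A minimal shape check for the text head, as a sketch: with distilbert-base-uncased (hidden size 768) and the max_length=256 tokenization used in search.py, the Flatten + Linear(768*256, 2000) stack reduces the encoder output to a 512-dimensional embedding. The dummy inputs below are illustrative only.

import torch
from transformers import DistilBertModel
from model import TextEncoderHead

enc = TextEncoderHead(DistilBertModel.from_pretrained("distilbert-base-uncased"))
dummy_ids = torch.zeros(1, 256, dtype=torch.long)   # 256 = max_length used by the tokenizer in search.py
dummy_mask = torch.ones(1, 256, dtype=torch.long)
print(enc(dummy_ids, dummy_mask).shape)             # expected: torch.Size([1, 512])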
requirements.txt
ADDED
@@ -0,0 +1,3 @@
pinecone==5.4.2
transformers==4.47.0
huggingface-hub==0.26.5
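
app.py, model.py, and search.py also import gradio, torch, numpy, and Pillow; on a Space these typically need to be listed in requirements.txt as well (versions left unpinned here as an assumption):

gradio
torch
numpy
Pillow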
search.py
ADDED
@@ -0,0 +1,105 @@
from transformers import ViTModel, AutoModelForMaskedLM, AutoTokenizer, ViTImageProcessor, DistilBertModel
from pinecone import Pinecone
import torch

from io import BytesIO
import base64
from PIL import Image

import sys

sys.path.append('../src')

from model import CLIPChemistryModel, TextEncoderHead, ImageEncoderHead


pc = Pinecone()
index = pc.Index("clipmodel")


ENCODER_BASE = DistilBertModel.from_pretrained("distilbert-base-uncased")
IMAGE_BASE = ViTModel.from_pretrained("google/vit-base-patch16-224")
text_encoder = TextEncoderHead(model=ENCODER_BASE)
image_encoder = ImageEncoderHead(model=IMAGE_BASE)

clip_model = CLIPChemistryModel(text_encoder=text_encoder, image_encoder=image_encoder)

clip_model.load_state_dict(torch.load('/Users/sebastianalejandrosarastizambonino/Documents/projects/CLIP_Pytorch/src/best_model_fashion.pth', map_location=torch.device('cpu')))

te_final = clip_model.text_encoder
ie_final = clip_model.image_encoder

def process_text_for_encoder(text, model):
    # tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    encoded_input = tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=256)
    input_ids = encoded_input['input_ids']
    attention_mask = encoded_input['attention_mask']
    output = model(input_ids=input_ids, attention_mask=attention_mask)
    return output.detach().numpy().tolist()[0]

def process_image_for_encoder(image, model):
    # image = Image.open(BytesIO(image))
    print(type(image))
    image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    image_tensor = image_processor(image,
                                   return_tensors="pt",
                                   do_resize=True
                                   )['pixel_values']
    output = model(pixel_values=image_tensor)
    return output.detach().numpy().tolist()[0]

def search_similarity(input, mode, top_k=5):
    if mode == 'text':
        output = process_text_for_encoder(input, model=te_final)
    else:
        output = input

    if mode == 'text':
        mode_search = 'image'
        response = index.query(
            namespace="space-" + mode_search + "-fashion",
            vector=output,
            top_k=top_k,
            include_values=True,
            include_metadata=True
        )
        similar_images = [value['metadata']['image'] for value in response['matches']]
        return similar_images
    elif mode == 'image':
        mode_search = 'text'
        response = index.query(
            namespace="space-" + mode_search + "-fashion",
            vector=output,
            top_k=top_k,
            include_values=True,
            include_metadata=True
        )
        similar_text = [value['metadata']['text'] for value in response['matches']]
        return similar_text
    else:
        raise ValueError("mode must be either 'text' or 'image'")

def process_image_for_encoder_gradio(image, is_bytes=True):
    """Handle both images passed as raw bytes and PIL Image objects."""
    try:
        if is_bytes:
            # The image arrives as raw bytes
            image = Image.open(BytesIO(image))
        else:
            # The image is already a PIL Image or comes from Gradio
            if not isinstance(image, Image.Image):
                # Coming from Gradio, it may be a numpy array
                image = Image.fromarray(image)

        image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
        image_tensor = image_processor(image,
                                       return_tensors="pt",
                                       do_resize=True
                                       )['pixel_values']
        output = ie_final(pixel_values=image_tensor)
        return output.detach().numpy().tolist()[0]
    except Exception as e:
        print(f"Error in process_image_for_encoder: {e}")
        raise
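
A minimal usage sketch, assuming the "clipmodel" Pinecone index is reachable and populated and that the checkpoint path above resolves; for a text query the function returns the base64-encoded image strings stored in the match metadata. The query string is illustrative.

from search import search_similarity

matches = search_similarity("red summer dress", mode="text", top_k=3)
print(len(matches))  # up to 3 base64-encoded image strings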
utils.py
ADDED
@@ -0,0 +1,34 @@
from model import CLIPChemistryModel, TextEncoderHead, ImageEncoderHead
from transformers import ViTModel, AutoModelForMaskedLM, AutoTokenizer, ViTImageProcessor

from io import BytesIO
import base64

from PIL import Image

def bytes_to_str(bytes_data):
    return base64.b64encode(bytes_data).decode('utf-8')

def str_to_bytes(str_data):
    return base64.b64decode(str_data)

def push_embeddings_to_pine_cone(index, embeddings, df, mode, length):
    records = []
    for i in range(length):
        if mode == 'text':
            records.append({
                "id": str(mode) + str(i),
                "values": embeddings[i],
                "metadata": {str(mode): df[mode].iloc[i]}})
        elif mode == 'image':
            records.append({
                "id": str(mode) + str(i),
                "values": embeddings[i],
                "metadata": {str(mode): bytes_to_str(df[mode].iloc[i]['bytes'])}})
        else:
            raise ValueError("mode must be either 'text' or 'image'")

    index.upsert(
        vectors=records,
        namespace="space-" + mode
    )
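
A sketch of how push_embeddings_to_pine_cone might be called to populate the index, assuming a pandas DataFrame with a "text" column and one 512-dimensional vector per row; the toy data below is illustrative only.

import pandas as pd
from pinecone import Pinecone
from utils import push_embeddings_to_pine_cone

df = pd.DataFrame({"text": ["a red summer dress", "a blue denim jacket"]})  # toy captions
embeddings = [[0.1] * 512, [0.1] * 512]                                     # placeholder 512-d vectors

pc = Pinecone()                  # reads PINECONE_API_KEY from the environment
index = pc.Index("clipmodel")
push_embeddings_to_pine_cone(index, embeddings, df, mode="text", length=len(df))
# records are upserted into the "space-text" namespace with ids "text0", "text1", ...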