Add first app
app.py ADDED
@@ -0,0 +1,89 @@
import gradio as gr
from PIL import Image
import numpy as np
from scipy.fftpack import dct
from datasets import load_dataset
from multiprocessing import cpu_count


def perceptual_hash_color(image):
    image = image.convert("RGB")  # Keep the three color channels
    image = image.resize((32, 32), Image.LANCZOS)  # Resize to 32x32
    image_array = np.asarray(image)  # Convert to numpy array
    hashes = []
    for i in range(3):
        channel = image_array[:, :, i]
        dct_coef = dct(dct(channel, axis=0), axis=1)  # Compute DCT
        dct_reduced_coef = dct_coef[:8, :8]  # Retain top-left 8x8 DCT coefficients
        # Median of DCT coefficients excluding the DC term (0th term)
        median_coef_val = np.median(np.ndarray.flatten(dct_reduced_coef)[1:])
        # Mask of all coefficients greater than median of coefficients
        hashes.append((dct_reduced_coef >= median_coef_val).flatten() * 1)
    return np.concatenate(hashes)

def hamming_distance(array_1, array_2):
    return len([1 for el_1, el_2 in zip(array_1, array_2) if el_1 != el_2])

def search_closest_examples(hash_refs, img_dataset):
    # Distance from every reference hash to every dataset hash; the modulo maps
    # a flat index in `distances` back to a row index of the dataset
    distances = []
    for hash_ref in hash_refs:
        distances.extend([hamming_distance(hash_ref, img_dataset[idx]["hash"]) for idx in range(img_dataset.num_rows)])
    closests = [i.item() % len(img_dataset) for i in np.argsort(distances)[:9]]
    return closests, [distances[c] for c in closests]

def find_closest_images(images, img_dataset):
    if not isinstance(images, (list, tuple)):
        images = [images]
    hashes = [perceptual_hash_color(img) for img in images]
    closest_idx, distances = search_closest_examples(hashes, img_dataset)
    return closest_idx, distances

def compute_hash_from_image(img):
    img = img.convert("L")  # Convert to grayscale
    img = img.resize((32, 32), Image.LANCZOS)  # Resize to 32x32
    img_array = np.asarray(img)  # Convert to numpy array
    dct_coef = dct(dct(img_array, axis=0), axis=1)  # Compute DCT
    dct_reduced_coef = dct_coef[:8, :8]  # Retain top-left 8x8 DCT coefficients
    # Median of DCT coefficients excluding the DC term (0th term)
    median_coef_val = np.median(np.ndarray.flatten(dct_reduced_coef)[1:])
    # Mask of all coefficients greater than median of coefficients
    img_hash = (dct_reduced_coef >= median_coef_val).flatten() * 1
    return img_hash


def process_dataset(dataset_name, dataset_split, dataset_column_image):
    img_dataset = load_dataset(dataset_name)[dataset_split]

    def add_hash(example):
        example["hash"] = perceptual_hash_color(example[dataset_column_image])
        return example

    # Compute hash of every image in the dataset
    img_dataset = img_dataset.map(add_hash, num_proc=cpu_count())
    return img_dataset


def compute(dataset_name, dataset_split, dataset_column_image, img):
    img_dataset = process_dataset(dataset_name, dataset_split, dataset_column_image)
    closest_idx, distances = find_closest_images(img, img_dataset)
    # Return the images themselves so the Gallery can display them
    return [img_dataset[i][dataset_column_image] for i in closest_idx]


with gr.Blocks() as demo:
    gr.Markdown("# Find if your images are in a public dataset!")
    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            dataset_name = gr.Textbox(label="Enter the name of a dataset containing images")
            dataset_split = gr.Textbox(label="Enter the split of this dataset to consider")
            dataset_column_image = gr.Textbox(label="Enter the name of the column of this dataset that contains images")
            img = gr.Image(label="Input your image that will be compared against images of the dataset", type="pil")
            btn = gr.Button("Find").style(full_width=True)

        with gr.Column(scale=2, min_width=600):
            gallery_similar = gr.Gallery(label="similar images")

    event = btn.click(compute, [dataset_name, dataset_split, dataset_column_image, img], gallery_similar)


demo.launch()
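As a quick local sanity check of the hashing helpers (not part of the app itself), the sketch below assumes a placeholder image file "sample.jpg" and that perceptual_hash_color and hamming_distance from app.py are already defined in the current session (importing app.py directly would call demo.launch()):

# Hypothetical check; "sample.jpg" is a placeholder path, not part of the app.
from PIL import Image

original = Image.open("sample.jpg")
slightly_edited = original.rotate(2)  # small perturbation of the same picture

h1 = perceptual_hash_color(original)
h2 = perceptual_hash_color(slightly_edited)

print(len(h1))                   # 192 bits: 3 color channels x an 8x8 coefficient mask
print(hamming_distance(h1, h2))  # a small distance means the images are perceptually similar

A near-duplicate should yield a Hamming distance close to 0 out of 192 bits, while unrelated images land much higher; this is the same comparison the app runs against every hashed row of the chosen dataset.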