Tarun-1999M commited on
Commit
300d24e
·
verified ·
1 Parent(s): 38b881b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# AUTOGENERATED! DO NOT EDIT! File to edit: Project_CUDA_Enabled (1).ipynb.

# %% auto 0
# Public names exported by this module (presumably emitted by nbdev from the
# notebook cell markers above -- confirm before editing by hand).
__all__ = ['model_checkpoint', 'model', 'tokenizer', 'dataset', 'train_dataset', 'iface', 'transform', 'cls_pooling',
           'search_arxiv']
6
+
7
# %% Project_CUDA_Enabled (1).ipynb 52
import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer
12
+

# Sentence-embedding checkpoint; the same model is used here to embed incoming
# queries as was (presumably) used to pre-compute the dataset embeddings --
# they must match for nearest-neighbour search to be meaningful.
model_checkpoint = 'sentence-transformers/all-MiniLM-L6-v2'
model = AutoModel.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


# Load the dataset from Hugging Face.
# NOTE(review): this downloads at import time; only the 'train' split is used.
dataset = load_dataset('Tarun-1999M/arxiv_cs_lg_embeddings')
train_dataset = dataset['train']
23
+
24
+
25
# On-the-fly row transform: stored embeddings come back as plain Python
# lists; nearest-neighbour search needs float32 NumPy arrays.
def transform(example):
    """Cast the 'embeddings' field of *example* to a float32 NumPy array.

    Mutates *example* in place and returns it, as expected by
    ``datasets.Dataset.set_transform``.
    """
    raw = example['embeddings']
    example['embeddings'] = np.array(raw, dtype=np.float32)
    return example
29
+
30
+
31
+
32
+
33
def cls_pooling(model_output):
    """Pool a transformer output down to one vector per sequence.

    Takes the hidden state of the first ([CLS]) token of each sequence,
    i.e. reduces (batch, seq_len, dim) to (batch, dim).
    """
    hidden_states = model_output.last_hidden_state
    return hidden_states[:, 0]
35
+
36
# Register the lazy transform: `transform` runs every time rows are read, so
# embeddings are converted to float32 arrays without materializing a copy.
train_dataset.set_transform(transform)

# Add FAISS index over the embedding column for fast nearest-neighbour lookup.
train_dataset.add_faiss_index(column='embeddings')
40
+
41
+
42
+ # Function to search the ArXiv papers
43
def search_arxiv(query):
    """Find the 5 ArXiv papers most similar to *query* and format them.

    Parameters
    ----------
    query : str
        Free-text search query typed into the Gradio textbox.

    Returns
    -------
    str
        Markdown string with title, abstract and FAISS score for each of
        the 5 nearest papers, best score first.
    """
    # Embed the query with the same tokenizer/model used for the dataset.
    # BUGFIX: the original called `get_embeddings`, which is not defined
    # anywhere in this file and raised NameError on every query; the
    # embedding step is inlined here using the module-level model.
    encoded = tokenizer([query], padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        model_output = model(**encoded)
    question_embedding = cls_pooling(model_output).cpu().detach().numpy()

    # Search for similar papers via the FAISS index built at import time.
    scores, samples = train_dataset.get_nearest_examples("embeddings", question_embedding, k=5)

    # Sort by score, descending (ties broken lexicographically by title).
    # NOTE(review): for an L2 FAISS index the score is a distance, where
    # lower is better -- descending order matches the original code, but
    # confirm the index metric before changing.
    sorted_results = sorted(zip(scores, samples['title'], samples['abstract']), reverse=True)

    # Format each hit as a markdown section.
    results = []
    for score, title, abstract in sorted_results:
        result = f"\n**Title:** {title}\n**Abstract:** {abstract}\n**Score:** {score:.4f}"
        results.append(result)

    return "\n\n".join(results)
60
+
61
# Create the Gradio interface: a one-line textbox in, markdown-rendered
# search results out.
iface = gr.Interface(
    fn=search_arxiv,
    inputs=gr.components.Textbox(lines=1, placeholder="Enter your query..."),
    outputs="markdown",
    title="Semantic Search in ArXiv ML Papers",
    description="Enter a query to find relevant ML papers from the ArXiv dataset."
)

# Launch the interface. NOTE(review): share=True requests a public
# gradio.live tunnel -- likely unnecessary if this runs on HF Spaces; confirm.
iface.launch(share=True)
72
+