rajyalakshmijampani commited on
Commit
0edb773
·
1 Parent(s): 6367efb

use hf_hub_download in place of direct URL download

Browse files
Files changed (1) hide show
  1. app.py +8 -12
app.py CHANGED
@@ -12,7 +12,7 @@ import re
12
  from sentence_transformers import SentenceTransformer
13
  from sklearn.metrics.pairwise import cosine_similarity
14
  from tavily import TavilyClient
15
- from huggingface_hub import InferenceClient
16
 
17
  class CLIPImageClassifier(nn.Module):
18
  def __init__(self):
@@ -39,6 +39,7 @@ HF_TOKEN = None
39
  embed_model = SentenceTransformer("all-MiniLM-L6-v2")
40
  explain_model = "meta-llama/Llama-3.1-8B-Instruct"
41
  text_model = "rajyalakshmijampani/fever_finetuned_deberta"
 
42
  wiki = wikipediaapi.Wikipedia(language='en', user_agent='fact-checker/1.0')
43
  image_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
44
 
@@ -51,18 +52,12 @@ def get_text_classifier():
51
  return text_classifier
52
 
53
  def get_image_classifier():
54
- global image_classifier
 
55
  if image_classifier is None:
56
- url = "https://huggingface.co/rajyalakshmijampani/finetuned_clip/resolve/main/best_clip_finetuned_classifier.pth"
57
- path = "best_clip_finetuned_classifier.pth"
58
-
59
- if not os.path.exists(path):
60
- r = requests.get(url)
61
- with open(path, "wb") as f:
62
- f.write(r.content)
63
-
64
  image_classifier = CLIPImageClassifier()
65
- state = torch.load(path, map_location="cpu",weights_only=False)
66
  clean_state = OrderedDict(
67
  (k[7:], v) if k.startswith("module.") else (k, v)
68
  for k, v in state.items()
@@ -276,8 +271,9 @@ with gr.Blocks() as demo:
276
 
277
  with gr.Tab("Image Detector"):
278
  img_input = gr.Image(type="pil", label="Upload Image")
279
- img_output = gr.Markdown(label="Model Output", value="Results will appear here...")
280
  img_button = gr.Button("Classify Image")
 
 
281
  img_button.click(classify_image, inputs=img_input, outputs=img_output)
282
 
283
  demo.launch()
 
12
  from sentence_transformers import SentenceTransformer
13
  from sklearn.metrics.pairwise import cosine_similarity
14
  from tavily import TavilyClient
15
+ from huggingface_hub import InferenceClient, hf_hub_download
16
 
17
  class CLIPImageClassifier(nn.Module):
18
  def __init__(self):
 
39
  embed_model = SentenceTransformer("all-MiniLM-L6-v2")
40
  explain_model = "meta-llama/Llama-3.1-8B-Instruct"
41
  text_model = "rajyalakshmijampani/fever_finetuned_deberta"
42
+ image_model = "rajyalakshmijampani/finetuned_clip"
43
  wiki = wikipediaapi.Wikipedia(language='en', user_agent='fact-checker/1.0')
44
  image_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
45
 
 
52
  return text_classifier
53
 
54
  def get_image_classifier():
55
+ global image_classifier, image_model
56
+ filename = "finetuned_clip.pth"
57
  if image_classifier is None:
58
+ model_path = hf_hub_download(repo_id=image_model, filename=filename)
 
 
 
 
 
 
 
59
  image_classifier = CLIPImageClassifier()
60
+ state = torch.load(model_path, map_location="cpu",weights_only=False)
61
  clean_state = OrderedDict(
62
  (k[7:], v) if k.startswith("module.") else (k, v)
63
  for k, v in state.items()
 
271
 
272
  with gr.Tab("Image Detector"):
273
  img_input = gr.Image(type="pil", label="Upload Image")
 
274
  img_button = gr.Button("Classify Image")
275
+ img_output = gr.Markdown(label="Model Output", value="Results will appear here...")
276
+
277
  img_button.click(classify_image, inputs=img_input, outputs=img_output)
278
 
279
  demo.launch()