SofiTesfay2010 committed
Commit ecb6e4f · 1 Parent(s): d816e29

Replace the default Gradio ChatInterface template with a FastAPI service exposing a /predict research endpoint backed by Gemini and Google Custom Search

Files changed (1)
  1. app.py +118 -51
app.py CHANGED
@@ -1,64 +1,131 @@
-import gradio as gr
-from huggingface_hub import InferenceClient
-
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
-
-
 if __name__ == "__main__":
-    demo.launch()
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from google import genai
+from googleapiclient.discovery import build
+from googleapiclient.errors import HttpError
+from time import sleep
+from typing import List
+
+# Constants (ensure you secure these in production)
+GOOGLE_SEARCH_API_KEY = "AIzaSyB06LrMInO1PDO6OoUFockguFuBX9EXJM8"
+GOOGLE_SEARCH_ENGINE_ID = "a0172f6639ea44605"
+GEMINI_API_KEY = "AIzaSyDeJRqHsnRYtuCufX2VB8nH7_r35jZxk20"
+MAX_SEARCH_RESULTS = 10
+
+# --- Your original functions (unchanged) ---
+
+def initialize_apis():
+    try:
+        gemini_client = genai.Client(api_key=GEMINI_API_KEY)
+        search_service = build("customsearch", "v1", developerKey=GOOGLE_SEARCH_API_KEY)
+        test_search = search_service.cse().list(q="test", cx=GOOGLE_SEARCH_ENGINE_ID, num=1).execute()
+        if not test_search.get('items'):
+            print("⚠️ Warning: Test search returned no results. Check CX configuration.")
+        return gemini_client, search_service
+    except Exception as e:
+        raise Exception(f"Initialization failed: {str(e)}")
+
+def execute_search(search_service, query: str) -> List[str]:
+    print(f"🔍 Searching for: {query}")
+    try:
+        response = search_service.cse().list(q=query, cx=GOOGLE_SEARCH_ENGINE_ID, num=MAX_SEARCH_RESULTS).execute()
+        print(f"Response keys: {list(response.keys())}")
+        items = response.get('items', [])
+        print(f"Found {len(items)} results")
+        return [item["link"] for item in items]
+    except HttpError as e:
+        print(f"HTTP Error {e.resp.status}: {e._get_reason()}")
+        return []
+    except Exception as e:
+        print(f"Search failed: {str(e)}")
+        return []
+
+def plan_research_strategy(client: genai.Client, research_topic: str) -> List[str]:
+    prompt = f"""Generate 3-5 Google search queries to research: {research_topic}
+- Use general web search terms
+- Avoid special characters
+- Use common terminology
+Format as a numbered list."""
+    try:
+        response = client.models.generate_content(model="gemini-2.0-flash", contents=[prompt])
+        raw_queries = response.text.split("\n")
+        valid_queries = []
+        for q in raw_queries:
+            clean_q = q.split(". ", 1)[-1].strip()
+            if clean_q and len(clean_q) < 150:
+                valid_queries.append(clean_q)
+        print(f"Generated queries: {valid_queries}")
+        return valid_queries
+    except Exception as e:
+        raise Exception(f"Error generating queries: {e}")
+
+def understand_user_request(client: genai.Client, user_request: str) -> str:
+    prompt = f"""You are a research assistant. The user provides: {user_request}.
+First, summarize the request.
+Second, identify ambiguities needing clarification.
+If needed, ask questions. Else confirm understanding.
+Format response as:
+Summary: [summary]
+Clarification Needed: [Yes/No]
+Questions: [questions or None]"""
+    try:
+        response = client.models.generate_content(model="gemini-2.0-flash", contents=[prompt])
+        analysis = response.text
+        print(f"Analysis: {analysis}")
+        summary = analysis.split("Summary:")[1].split("Clarification Needed:")[0].strip()
+        clarification_needed = analysis.split("Clarification Needed:")[1].split("Questions:")[0].strip()
+        questions = analysis.split("Questions:")[1].strip()
+        if clarification_needed.lower() == "yes":
+            raise Exception(f"Clarification needed: {questions}")
+        print("Understood the request.")
+        return summary
+    except Exception as e:
+        raise Exception(f"Error analyzing request: {e}")
+
+def extract_content_from_url(url: str) -> str:
+    print(f"Extracting content from {url} (simulated)...")
+    sleep(0.5)
+    return f"Content from {url} [placeholder]"
+
+def summarize_information(client: genai.Client, information: str) -> str:
+    prompt = f"""Summarize the following into a detailed report:
+{information}"""
+    try:
+        response = client.models.generate_content(model="gemini-2.0-flash", contents=[prompt])
+        return response.text
+    except Exception as e:
+        print(f"Error summarizing: {e}")
+        return "Summary unavailable"
+
+# --- FastAPI app definition ---
+
+app = FastAPI()
+
+class RequestPayload(BaseModel):
+    research_request: str
+
+@app.post("/predict")
+def predict(payload: RequestPayload):
+    try:
+        gemini_client, search_service = initialize_apis()
+        research_topic = understand_user_request(gemini_client, payload.research_request)
+        queries = plan_research_strategy(gemini_client, research_topic)
+        if not queries:
+            raise HTTPException(status_code=400, detail="No valid queries generated")
+        all_content = []
+        for query in queries:
+            results = execute_search(search_service, query)
+            for url in results:
+                content = extract_content_from_url(url)
+                all_content.append(content)
+        if not all_content:
+            raise HTTPException(status_code=400, detail="No content gathered from searches")
+        summary = summarize_information(gemini_client, "\n".join(all_content))
+        return {"summary": summary}
+    except HTTPException:
+        # Re-raise the deliberate 400 responses above instead of masking them as 500s
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+# If running locally (for testing), use:
 if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
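
For anyone smoke-testing the new endpoint, a minimal client call might look like the sketch below. It is not part of this commit: the base URL assumes the local uvicorn server started by `python app.py` (port 7860 as configured above), and the `requests` dependency and example topic are assumptions for illustration.

```python
import requests

# Hypothetical smoke test for the /predict endpoint added in this commit.
# Assumes the server was started locally with `python app.py`.
resp = requests.post(
    "http://localhost:7860/predict",
    json={"research_request": "Recent advances in solid-state batteries"},
    timeout=300,  # the endpoint runs several searches plus summarization, so allow ample time
)
resp.raise_for_status()
print(resp.json()["summary"])
```

As the in-code comment notes, the hardcoded API keys should not ship as-is; in production they would typically be read from the environment (e.g. `os.environ["GEMINI_API_KEY"]`) or a secrets manager rather than committed to the repository.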