cleaner app
Browse files
app.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from transformers import pipeline
|
| 3 |
from haystack.document_stores import FAISSDocumentStore
|
| 4 |
-
from haystack.nodes import EmbeddingRetriever
|
| 5 |
import numpy as np
|
| 6 |
import openai
|
| 7 |
import os
|
|
@@ -15,12 +14,21 @@ from utils import (
|
|
| 15 |
get_random_string,
|
| 16 |
)
|
| 17 |
|
| 18 |
-
|
| 19 |
-
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
| 20 |
system_template = {"role": os.environ["role"], "content": os.environ["content"]}
|
| 21 |
|
| 22 |
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
"""return (answer:str, history:list[dict], sources:str)
|
| 25 |
|
| 26 |
Args:
|
|
@@ -31,57 +39,32 @@ def gen_conv(query: str, report_type, history=[system_template], ipcc=True):
|
|
| 31 |
Returns:
|
| 32 |
_type_: _description_
|
| 33 |
"""
|
| 34 |
-
if report_type == "IPCC only":
|
| 35 |
-
document_store = FAISSDocumentStore.load(
|
| 36 |
-
index_path="./documents/climate_gpt_only_giec.faiss",
|
| 37 |
-
config_path="./documents/climate_gpt_only_giec.json",
|
| 38 |
-
)
|
| 39 |
-
else:
|
| 40 |
-
document_store = FAISSDocumentStore.load(
|
| 41 |
-
index_path="./documents/climate_gpt.faiss",
|
| 42 |
-
config_path="./documents/climate_gpt.json",
|
| 43 |
-
)
|
| 44 |
|
| 45 |
dense = EmbeddingRetriever(
|
| 46 |
-
document_store=document_store,
|
| 47 |
embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
|
| 48 |
model_format="sentence_transformers",
|
| 49 |
)
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
sources = ""
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
[os.environ["sources"]]
|
| 62 |
-
+ [
|
| 63 |
-
f"{d.meta['file_name']} Page {d.meta['page_number']}\n{d.content}"
|
| 64 |
-
for d in docs
|
| 65 |
-
]
|
| 66 |
-
)
|
| 67 |
-
messages.append({"role": "system", "content": sources})
|
| 68 |
-
|
| 69 |
-
answer = openai.ChatCompletion.create(
|
| 70 |
-
model="gpt-3.5-turbo",
|
| 71 |
-
messages=messages,
|
| 72 |
-
temperature=0.2,
|
| 73 |
-
)["choices"][0]["message"]["content"]
|
| 74 |
-
|
| 75 |
-
if retrieve:
|
| 76 |
-
messages.pop()
|
| 77 |
-
sources = "\n\n".join(
|
| 78 |
-
f"{d.meta['file_name']} Page {d.meta['page_number']}:\n{d.content}"
|
| 79 |
-
for d in docs
|
| 80 |
-
)
|
| 81 |
else:
|
|
|
|
| 82 |
sources = "No environmental report was used to provide this answer."
|
| 83 |
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
gradio_format = make_pairs([a["content"] for a in messages[1:]])
|
| 86 |
|
| 87 |
return gradio_format, messages, sources
|
|
@@ -123,20 +106,18 @@ with gr.Blocks(title="π ClimateGPT Ekimetrics", css=css_code) as demo:
|
|
| 123 |
|
| 124 |
with gr.Column(scale=1, variant="panel"):
|
| 125 |
gr.Markdown("### Sources")
|
| 126 |
-
sources_textbox = gr.Textbox(
|
| 127 |
-
interactive=False, show_label=False, max_lines=50
|
| 128 |
-
)
|
| 129 |
|
| 130 |
ask.submit(
|
| 131 |
fn=gen_conv,
|
| 132 |
inputs=[
|
| 133 |
ask,
|
|
|
|
| 134 |
gr.inputs.Dropdown(
|
| 135 |
["IPCC only", "All available"],
|
| 136 |
default="All available",
|
| 137 |
label="Select reports",
|
| 138 |
),
|
| 139 |
-
state,
|
| 140 |
],
|
| 141 |
outputs=[chatbot, state, sources_textbox],
|
| 142 |
)
|
|
@@ -153,12 +134,8 @@ with gr.Blocks(title="π ClimateGPT Ekimetrics", css=css_code) as demo:
|
|
| 153 |
lines=1,
|
| 154 |
type="password",
|
| 155 |
)
|
| 156 |
-
openai_api_key_textbox.change(
|
| 157 |
-
|
| 158 |
-
)
|
| 159 |
-
openai_api_key_textbox.submit(
|
| 160 |
-
set_openai_api_key, inputs=[openai_api_key_textbox]
|
| 161 |
-
)
|
| 162 |
|
| 163 |
with gr.Tab("Information"):
|
| 164 |
gr.Markdown(
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
| 2 |
from haystack.document_stores import FAISSDocumentStore
|
| 3 |
+
from haystack.nodes import EmbeddingRetriever
|
| 4 |
import numpy as np
|
| 5 |
import openai
|
| 6 |
import os
|
|
|
|
| 14 |
get_random_string,
|
| 15 |
)
|
| 16 |
|
|
|
|
|
|
|
| 17 |
system_template = {"role": os.environ["role"], "content": os.environ["content"]}
|
| 18 |
|
| 19 |
|
| 20 |
+
only_ipcc_document_store = FAISSDocumentStore.load(
|
| 21 |
+
index_path="./documents/climate_gpt_only_giec.faiss",
|
| 22 |
+
config_path="./documents/climate_gpt_only_giec.json",
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
document_store = FAISSDocumentStore.load(
|
| 26 |
+
index_path="./documents/climate_gpt.faiss",
|
| 27 |
+
config_path="./documents/climate_gpt.json",
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def gen_conv(query: str, history=[system_template], report_type="All available", threshold=0.56):
|
| 32 |
"""return (answer:str, history:list[dict], sources:str)
|
| 33 |
|
| 34 |
Args:
|
|
|
|
| 39 |
Returns:
|
| 40 |
_type_: _description_
|
| 41 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
dense = EmbeddingRetriever(
|
| 44 |
+
document_store=document_store if report_type == "All available" else only_ipcc_document_store,
|
| 45 |
embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
|
| 46 |
model_format="sentence_transformers",
|
| 47 |
)
|
| 48 |
|
| 49 |
+
messages = history + [{"role": "user", "content": query}]
|
| 50 |
+
docs = dense.retrieve(query=query, top_k=10)
|
| 51 |
+
sources = "\n\n".join(
|
| 52 |
+
f"doc {i}: {d.meta['file_name']} page {d.meta['page_number']}\n{d.content}"
|
| 53 |
+
for i, d in enumerate(docs, 1)
|
| 54 |
+
if d.score > threshold
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
if sources:
|
| 58 |
+
messages.append({"role": "system", "content": f"{os.environ['sources']}\n\n{sources}"})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
else:
|
| 60 |
+
messages.append({"role": "system", "content": "no relevant document available."})
|
| 61 |
sources = "No environmental report was used to provide this answer."
|
| 62 |
|
| 63 |
+
answer = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages, temperature=0.2,)["choices"][0][
|
| 64 |
+
"message"
|
| 65 |
+
]["content"]
|
| 66 |
+
|
| 67 |
+
messages[-1] = {"role": "assistant", "content": answer}
|
| 68 |
gradio_format = make_pairs([a["content"] for a in messages[1:]])
|
| 69 |
|
| 70 |
return gradio_format, messages, sources
|
|
|
|
| 106 |
|
| 107 |
with gr.Column(scale=1, variant="panel"):
|
| 108 |
gr.Markdown("### Sources")
|
| 109 |
+
sources_textbox = gr.Textbox(interactive=False, show_label=False, max_lines=50)
|
|
|
|
|
|
|
| 110 |
|
| 111 |
ask.submit(
|
| 112 |
fn=gen_conv,
|
| 113 |
inputs=[
|
| 114 |
ask,
|
| 115 |
+
state,
|
| 116 |
gr.inputs.Dropdown(
|
| 117 |
["IPCC only", "All available"],
|
| 118 |
default="All available",
|
| 119 |
label="Select reports",
|
| 120 |
),
|
|
|
|
| 121 |
],
|
| 122 |
outputs=[chatbot, state, sources_textbox],
|
| 123 |
)
|
|
|
|
| 134 |
lines=1,
|
| 135 |
type="password",
|
| 136 |
)
|
| 137 |
+
openai_api_key_textbox.change(set_openai_api_key, inputs=[openai_api_key_textbox])
|
| 138 |
+
openai_api_key_textbox.submit(set_openai_api_key, inputs=[openai_api_key_textbox])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
with gr.Tab("Information"):
|
| 141 |
gr.Markdown(
|