Commit 94dc00e
shaocongma committed
Parent(s): 328e8d0

Add a generator wrapper using a configuration file. Edit the logic of searching for references. Add a Gradio UI for testing the knowledge database.
Files changed:
- api_wrapper.py            +0   -42   (deleted)
- app.py                    +54  -36
- assets/idealab.png        +0   -0    (deleted)
- auto_backgrounds.py → auto_generators.py  +11 -17  (renamed)
- configurations/default.yaml  +29 -0   (added)
- cyber-supervisor-openai.py   +1  -1
- idealab.py                +0   -144  (deleted)
- kdb_test.py               +39  -5
- references_generator.py   +0   -86   (deleted)
- utils/knowledge.py        +1   -1
- utils/references.py       +233 -136
- worker.py                 +1   -1
- wrapper.py                +57  -0    (added)
api_wrapper.py
DELETED
@@ -1,42 +0,0 @@
-'''
-This script is used to wrap all generation methods together.
-
-todo:
-A worker keeps running on the server. Monitor the Amazon SQS. Once it receives a new message, do the following:
-    Download the corresponding configuration files from S3.
-    Change the task status from Pending to Running.
-    Call `generator_wrapper` and wait for the outputs.
-    If `generator_wrapper` returns results:
-        evaluate the results; compile them; upload the results to S3 ... Change the task status from Running to Completed.
-        If anything goes wrong, raise an error.
-    If `generator_wrapper` returns nothing, times out, or raises any error:
-        Change the task status from Running to Failed.
-'''
-import os.path
-
-from auto_backgrounds import generate_draft
-import json, time
-from utils.file_operations import make_archive
-
-
-GENERATOR_MAPPING = {"fake": None,            # a fake generator
-                     "draft": generate_draft  # generate academic paper
-                     }
-
-
-def generator_wrapper(config):
-    generator = GENERATOR_MAPPING[config["generator"]]
-
-
-def generator_wrapper_from_json(path_to_config_json):
-    # Read the configuration file and call the corresponding function
-    with open(path_to_config_json, "r", encoding='utf-8') as f:
-        config = json.load(f)
-    print("Configuration:", config)
-    # generator = GENERATOR_MAPPING.get(config["generator"])
-    generator = None
-    if generator is None:
-        # generate a fake ZIP file and upload it
-        time.sleep(150)
-    zip_path = os.path.splitext(path_to_config_json)[0] + ".zip"
-    return make_archive(path_to_config_json, zip_path)
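The docstring removed above describes a queue-driven worker that this file never implemented. A minimal sketch of that loop, assuming boto3, SQS messages whose JSON body carries a task id plus the S3 key of the config file, and a hypothetical update_task_status helper; the queue URL, bucket name, and the assumption that generator_wrapper returns a path to the generated archive are all placeholders, not part of this commit:

import json
import os

import boto3

from wrapper import generator_wrapper  # the entry point this commit moves into wrapper.py

QUEUE_URL = "https://sqs.us-east-1.amazonaws.com/000000000000/auto-draft"  # placeholder
BUCKET = "auto-draft-tasks"  # placeholder
sqs = boto3.client("sqs")
s3 = boto3.client("s3")


def update_task_status(task_id, status):
    # Hypothetical helper: persist Pending/Running/Completed/Failed somewhere durable.
    print(task_id, status)


def poll_once():
    resp = sqs.receive_message(QueueUrl=QUEUE_URL, MaxNumberOfMessages=1, WaitTimeSeconds=20)
    for msg in resp.get("Messages", []):
        task = json.loads(msg["Body"])  # assumed shape: {"task_id": ..., "config_key": ...}
        local_config = task["config_key"]
        s3.download_file(BUCKET, task["config_key"], local_config)
        update_task_status(task["task_id"], "Running")
        try:
            with open(local_config, "r", encoding="utf-8") as f:
                config = json.load(f)
            output = generator_wrapper(config)  # assumed to return a path to the result
            s3.upload_file(output, BUCKET, os.path.basename(output))
            update_task_status(task["task_id"], "Completed")
        except Exception:
            update_task_status(task["task_id"], "Failed")
        finally:
            sqs.delete_message(QueueUrl=QUEUE_URL, ReceiptHandle=msg["ReceiptHandle"])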
app.py
CHANGED
@@ -2,9 +2,10 @@ import uuid
 import gradio as gr
 import os
 import openai
-
+import yaml
 from utils.file_operations import list_folders, urlify
 from huggingface_hub import snapshot_download
+from wrapper import generator_wrapper
 
 # todo:
 # 6. get logs when the procedure is not completed. *
@@ -22,8 +23,10 @@ from huggingface_hub import snapshot_download
 # OPENAI_API_BASE: (Optional) Support alternative OpenAI minors
 # GPT4_ENABLE: (Optional) Set it to 1 to enable GPT-4 model.
 
-# AWS_ACCESS_KEY_ID: (Optional)
-#
+# AWS_ACCESS_KEY_ID: (Optional)
+#     Access AWS cloud storage (you need to edit `BUCKET_NAME` in `utils/storage.py` if you need to use this function)
+# AWS_SECRET_ACCESS_KEY: (Optional)
+#     Access AWS cloud storage (you need to edit `BUCKET_NAME` in `utils/storage.py` if you need to use this function)
 # KDB_REPO: (Optional) A Huggingface dataset hosting Knowledge Databases
 # HF_TOKEN: (Optional) Access to KDB_REPO
@@ -34,7 +37,7 @@ openai_key = os.getenv("OPENAI_API_KEY")
 openai_api_base = os.getenv("OPENAI_API_BASE")
 if openai_api_base is not None:
     openai.api_base = openai_api_base
-GPT4_ENABLE = os.getenv("GPT4_ENABLE")
+GPT4_ENABLE = os.getenv("GPT4_ENABLE")  # disable GPT-4 for public repo
 
 access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
 secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
@@ -124,7 +127,7 @@ REFERENCES = """## 一键搜索相关论文
 REFERENCES_INSTRUCTION = """### References
 这一栏用于定义AI如何选取参考文献. 目前是两种方式混合:
 1. GPT自动根据标题生成关键字,使用Semantic Scholar搜索引擎搜索文献,利用Specter获取Paper Embedding来自动选取最相关的文献作为GPT的参考资料.
-2.
+2. 用户通过输入文章标题(用英文逗号隔开), AI会自动搜索文献作为参考资料.
 关于有希望利用本地文件来供GPT参考的功能将在未来实装.
 """
@@ -140,7 +143,7 @@ OUTPUTS_INSTRUCTION = """### Outputs
 这一栏用于定义输出的内容:
 * Template: 用于填装内容的LaTeX模板.
 * Models: 使用GPT-4或者GPT-3.5-Turbo生成内容.
-* Prompts模式: 不生成内容, 而是生成用于生成内容的Prompts. 可以手动复制到网页版或者其他语言模型中进行使用.
+* Prompts模式: 不生成内容, 而是生成用于生成内容的Prompts. 可以手动复制到网页版或者其他语言模型中进行使用. (放在输出的ZIP文件的prompts.json文件中)
 """
 
 OTHERS_INSTRUCTION = """### Others
@@ -164,18 +167,34 @@ def clear_inputs(*args):
 def clear_inputs_refs(*args):
     return "", 5
 
+
 def wrapped_generator(
         paper_title, paper_description,  # main input
-        openai_api_key=None,
-        tldr=True, max_kw_refs=10,
+        openai_api_key=None,  # key
+        tldr=True, max_kw_refs=10, refs=None, max_tokens_ref=2048,  # references
        knowledge_database=None, max_tokens_kd=2048, query_counts=10,  # domain knowledge
        paper_template="ICLR2022", selected_sections=None, model="gpt-4", prompts_mode=False,  # outputs parameters
        cache_mode=IS_CACHE_AVAILABLE  # handle cache mode
 ):
-    # if `cache_mode` is True, then always upload the generated content to my S3.
     file_name_upload = urlify(paper_title) + "_" + uuid.uuid1().hex + ".zip"
-
-
+    # load the default configuration file
+    with open("configurations/default.yaml", 'r') as file:
+        config = yaml.safe_load(file)
+    config["paper"]["title"] = paper_title
+    config["paper"]["description"] = paper_description
+    config["references"]["tldr"] = tldr
+    config["references"]["max_kw_refs"] = max_kw_refs
+    config["references"]["refs"] = refs
+    config["references"]["max_tokens_ref"] = max_tokens_ref
+    config["domain_knowledge"]["knowledge_database"] = knowledge_database
+    config["domain_knowledge"]["max_tokens_kd"] = max_tokens_kd
+    config["domain_knowledge"]["query_counts"] = query_counts
+    config["output"]["selected_sections"] = selected_sections
+    config["output"]["model"] = model
+    config["output"]["template"] = paper_template
+    config["output"]["prompts_mode"] = prompts_mode
+
     if openai_api_key is not None:
         openai.api_key = openai_api_key
     try:
@@ -183,12 +202,7 @@ def wrapped_generator(
     except Exception as e:
         raise gr.Error(f"Key错误. Error: {e}")
     try:
-        output = generate_draft(
-            paper_title, description=paper_description,  # main input
-            tldr=tldr, max_kw_refs=max_kw_refs, bib_refs=bib_refs, max_tokens_ref=max_tokens_ref,  # references
-            knowledge_database=knowledge_database, max_tokens_kd=max_tokens_kd, query_counts=query_counts,  # domain knowledge
-            sections=selected_sections, model=model, template=paper_template, prompts_mode=prompts_mode,  # outputs parameters
-        )
+        output = generator_wrapper(config)
         if cache_mode:
             from utils.storage import upload_file
             upload_file(output, target_name=file_name_upload)
@@ -204,8 +218,6 @@ with gr.Blocks(theme=theme) as demo:
     with gr.Column(scale=2):
         key = gr.Textbox(value=openai_key, lines=1, max_lines=1, label="OpenAI Key",
                          visible=not IS_OPENAI_API_KEY_AVAILABLE)
-        url = gr.Textbox(value=None, lines=1, max_lines=1, label="URL",
-                         visible=False)
        # 每个功能做一个tab
        with gr.Tab("学术论文"):
            gr.Markdown(ACADEMIC_PAPER)
@@ -230,8 +242,8 @@ with gr.Blocks(theme=theme) as demo:
                                             interactive=GPT4_INTERACTIVE,
                                             info="生成论文用到的语言模型.")
             prompts_mode = gr.Checkbox(value=False, visible=True, interactive=True,
+                                       label="Prompts模式",
+                                       info="只输出用于生成论文的Prompts, 可以复制到别的地方生成论文.")
 
             sections = gr.CheckboxGroup(
                 choices=["introduction", "related works", "backgrounds", "methodology", "experiments",
@@ -245,21 +257,27 @@ with gr.Blocks(theme=theme) as demo:
 
             with gr.Column(scale=2):
                 max_kw_ref_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1,
+                                              interactive=True, label="MAX_KW_REFS",
+                                              info="每个Keyword搜索几篇参考文献", visible=False)
 
                 max_tokens_ref_slider = gr.Slider(minimum=256, maximum=8192, value=2048, step=2,
+                                                  interactive=True, label="MAX_TOKENS",
+                                                  info="参考文献内容占用Prompts中的Token数")
 
                 tldr_checkbox = gr.Checkbox(value=True, label="TLDR;",
                                             info="选择此筐表示将使用Semantic Scholar的TLDR作为文献的总结.",
                                             interactive=True)
+
+                text_ref = gr.Textbox(lines=5, label="References (Optional)", visible=True,
+                                      info="交给AI参考的文献的标题, 用英文逗号`,`隔开.")
+
+                gr.Examples(
+                    examples=["Understanding the Impact of Model Incoherence on Convergence of Incremental SGD with Random Reshuffle,"
+                              "Variance-Reduced Off-Policy TDC Learning: Non-Asymptotic Convergence Analysis,"
+                              "Greedy-GQ with Variance Reduction: Finite-time Analysis and Improved Complexity"],
+                    inputs=text_ref,
+                    cache_examples=False
+                )
 
             with gr.Row():
                 with gr.Column(scale=1):
@@ -267,11 +285,11 @@ with gr.Blocks(theme=theme) as demo:
 
             with gr.Column(scale=2):
                 query_counts_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1,
+                                                interactive=True, label="QUERY_COUNTS",
+                                                info="从知识库内检索多少条内容", visible=False)
                 max_tokens_kd_slider = gr.Slider(minimum=256, maximum=8192, value=2048, step=2,
+                                                 interactive=True, label="MAX_TOKENS",
+                                                 info="知识库内容占用Prompts中的Token数")
                 domain_knowledge = gr.Dropdown(label="预载知识库",
                                                choices=ALL_DATABASES,
                                                value="(None)",
@@ -296,8 +314,8 @@ with gr.Blocks(theme=theme) as demo:
     json_output = gr.JSON(label="References")
     clear_button_pp.click(fn=clear_inputs, inputs=[title, description_pp], outputs=[title, description_pp])
     submit_button_pp.click(fn=wrapped_generator,
-                           inputs=[title, description_pp, key,
-                                   tldr_checkbox, max_kw_ref_slider,
+                           inputs=[title, description_pp, key,
+                                   tldr_checkbox, max_kw_ref_slider, text_ref, max_tokens_ref_slider,
                                    domain_knowledge, max_tokens_kd_slider, query_counts_slider,
                                    template, sections, model_selection, prompts_mode], outputs=file_output)
assets/idealab.png
DELETED
Binary file (52.1 kB)
auto_backgrounds.py → auto_generators.py
RENAMED
@@ -40,7 +40,7 @@ def log_usage(usage, generating_target, print_out=True):
 
 
 def _generation_setup(title, description="", template="ICLR2022",
-                      tldr=False, max_kw_refs=10,
+                      tldr=False, max_kw_refs=10, refs=None, max_tokens_ref=2048,  # generating references
                       knowledge_database=None, max_tokens_kd=2048, query_counts=10,  # querying from knowledge database
                       debug=True):
     """
@@ -115,7 +115,7 @@ def _generation_setup(title, description="", template="ICLR2022",
 
     print("Keywords: \n", keywords)
     # todo: in some rare situations, collected papers will be an empty list. handle this issue
-    ref = References(title,
+    ref = References(title, load_papers=refs)
     ref.collect_papers(keywords, tldr=tldr)
     references = ref.to_prompts(max_tokens=max_tokens_ref)
     all_paper_ids = ref.to_bibtex(bibtex_path)
@@ -200,7 +200,7 @@ def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-
 
 
 def generate_draft(title, description="",  # main input
-                   tldr=True, max_kw_refs=10,
+                   tldr=True, max_kw_refs=10, refs=None, max_tokens_ref=2048,  # references
                    knowledge_database=None, max_tokens_kd=2048, query_counts=10,  # domain knowledge
                    sections=None, model="gpt-4", template="ICLR2022", prompts_mode=False,  # outputs parameters
                    ):
@@ -245,7 +245,7 @@ def generate_draft(title, description="",  # main input
                     "abstract"]
     else:
         sections = _filter_sections(sections)
-    paper, destination_folder, _ = _generation_setup(title, description, template, tldr, max_kw_refs,
+    paper, destination_folder, _ = _generation_setup(title, description, template, tldr, max_kw_refs, refs,
                                                      max_tokens_ref=max_tokens_ref, max_tokens_kd=max_tokens_kd,
                                                      query_counts=query_counts,
                                                      knowledge_database=knowledge_database)
@@ -254,11 +254,10 @@ def generate_draft(title, description="",  # main input
     prompts_dict = {}
     print(f"================PROCESSING================")
     for section in sections:
+        prompts = generate_paper_prompts(paper, section)
+        prompts_dict[section] = prompts
         if prompts_mode:
-            prompts = generate_paper_prompts(paper, section)
-            prompts_dict[section] = prompts
             continue
-
         print(f"Generate {section} part...")
         max_attempts = 4
         attempts_count = 0
@@ -274,21 +273,16 @@ def generate_draft(title, description="",  # main input
             logging.info(message)
             attempts_count += 1
             time.sleep(15)
-
     # post-processing
     print("================POST-PROCESSING================")
     create_copies(destination_folder)
-
-
+    filename = "prompts.json"
+    with open(os.path.join(destination_folder, filename), "w") as f:
+        json.dump(prompts_dict, f)
     print("\nMission completed.\n")
+    return destination_folder
 
-    if prompts_mode:
-        filename = hash_name(input_dict) + ".json"
-        with open(filename, "w") as f:
-            json.dump(prompts_dict, f)
-        return filename
-    else:
-        return make_archive(destination_folder, filename)
+    # return make_archive(destination_folder, filename)
 
 
 if __name__ == "__main__":
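For orientation, a minimal sketch of calling the renamed generate_draft with the new refs parameter. The argument values are illustrative; judging by the References loader notes in utils/references.py below, refs plausibly accepts a comma-separated string of paper titles, and with prompts_mode=True the prompts now land in prompts.json inside the returned folder:

import os

import openai

from auto_generators import generate_draft

openai.api_key = os.getenv("OPENAI_API_KEY")

folder = generate_draft(
    "Playing Atari with Deep Reinforcement Learning",
    refs="Human-level control through deep reinforcement learning",  # assumed: comma-separated titles
    tldr=True, max_kw_refs=10, max_tokens_ref=2048,
    sections=["introduction"], model="gpt-3.5-turbo", template="ICLR2022",
    prompts_mode=True,
)
print(os.path.join(folder, "prompts.json"))  # generate_draft now returns the destination folder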
configurations/default.yaml
ADDED
@@ -0,0 +1,29 @@
+date: 2023-07-11
+
+generator: "auto_draft"
+
+paper:
+  title: "playing atari game with deep reinforcement learning"
+  description: ""
+
+references:
+  tldr: True
+  max_kw_refs: 10
+  max_tokens_ref: 2048
+  refs: null
+
+domain_knowledge:
+  knowledge_database: null
+  max_tokens_kd: 2048
+  query_counts: 10
+
+output:
+  template: "default"
+  model: "gpt-4"
+  selected_sections: null
+  prompts_mode: False
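wrapper.py (+57 lines) is added by this commit but not shown on this page. A minimal sketch of the config-driven entry point it plausibly exposes, wired to the YAML above; the GENERATOR_MAPPING name and the kwarg plumbing are assumptions inferred from the `generator: "auto_draft"` field and the generate_draft signature in auto_generators.py:

import yaml

from auto_generators import generate_draft

# Assumed: map the YAML `generator` field to a callable.
GENERATOR_MAPPING = {"auto_draft": generate_draft}


def generator_wrapper(config: dict):
    generator = GENERATOR_MAPPING[config["generator"]]
    # The references/domain_knowledge sub-dicts match generate_draft's keyword arguments.
    return generator(
        config["paper"]["title"],
        description=config["paper"]["description"],
        **config["references"],
        **config["domain_knowledge"],
        sections=config["output"]["selected_sections"],
        model=config["output"]["model"],
        template=config["output"]["template"],
        prompts_mode=config["output"]["prompts_mode"],
    )


if __name__ == "__main__":
    with open("configurations/default.yaml", "r") as f:
        config = yaml.safe_load(f)
    print(generator_wrapper(config))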
cyber-supervisor-openai.py
CHANGED
@@ -3,7 +3,7 @@ import openai
 import ast
 from tools import functions, TOOLS
 
-MAX_ITER =
+MAX_ITER = 99
 
 openai.api_key = os.getenv("OPENAI_API_KEY")
 default_model = os.getenv("DEFAULT_MODEL")
idealab.py
DELETED
@@ -1,144 +0,0 @@
-import gradio as gr
-import os
-import openai
-from utils.references import References
-from utils.gpt_interaction import GPTModel
-from utils.prompts import SYSTEM
-
-openai_key = os.getenv("OPENAI_API_KEY")
-default_model = os.getenv("DEFAULT_MODEL")
-if default_model is None:
-    # default_model = "gpt-3.5-turbo-16k"
-    default_model = "gpt-4"
-
-openai.api_key = openai_key
-
-paper_system_prompt = '''You are an assistant designed to propose choices of research direction.
-The user will input questions or some keywords of a field. You need to generate some paper titles and main contributions. Ensure you follow these instructions:
-Instruction:
-- Your response should follow the JSON format.
-- Your response should have the following structure:
-{
-  "your suggested paper title":
-    {
-      "summary": "an overview introducing what this paper will include",
-      "contributions": {
-        "contribution1": {"statement": "briefly describe this contribution", "reason": "reason why this contribution can make this paper outstanding"},
-        "contribution2": {"statement": "briefly describe this contribution", "reason": "reason why this contribution can make this paper outstanding"},
-        ...
-      }
-    }
-  "your suggested paper title":
-    {
-      "summary": "an overview introducing what this paper will include",
-      "contributions": {
-        "contribution1": {"statement": "briefly describe this contribution", "reason": "reason why this contribution can make this paper outstanding"},
-        "contribution2": {"statement": "briefly describe this contribution", "reason": "reason why this contribution can make this paper outstanding"},
-        ...
-      }
-    }
-  ...
-}
-- Please list three to five suggested titles and at least three contributions for each paper.
-'''
-
-
-contribution_system_prompt = '''You are an assistant designed to criticize the contributions of a paper. You will be provided the paper's title, references, and contributions. Ensure you follow these instructions:
-Instruction:
-- Your response should follow the JSON format.
-- Your response should have the following structure:
-{
-  "title": "the title provided by the user",
-  "comment": "your thoughts on whether this title clearly reflects the key ideas of this paper, and why",
-  "contributions": {
-    "contribution1": {"statement": "briefly describe what the contribution is",
-                      "reason": "reason why the user claims it is a contribution",
-                      "judge": "your thoughts about whether this is a novel contribution, and why",
-                      "suggestion": "your suggestion on how to modify the research direction to enhance the novelty"},
-    "contribution2": {"statement": "briefly describe what the contribution is",
-                      "reason": "reason why the user claims it is a contribution",
-                      "judge": "your thoughts about whether this is a novel contribution, and why",
-                      "suggestion": "your suggestion on how to modify the research direction to enhance the novelty"},
-    ...
-  }
-}
-- You need to carefully check whether the claimed contribution has already been made in the provided references, which would make the contribution not novel.
-- You also need to raise concerns if any of the contributions could be incremental or just a mild modification of an existing work.
-'''
-
-
-ANNOUNCEMENT = """
-<h1 style="text-align: center"><img src='/file=assets/idealab.png' width=36px style="display: inline"/>灵感实验室IdeaLab</h1>
-
-<p>灵感实验室IdeaLab可以为你选择你下一篇论文的研究方向! 输入你的研究领域或者任何想法, 灵感实验室会自动生成若干个论文标题+论文的主要贡献供你选择. </p>
-
-<p>除此之外, 输入你的论文标题+主要贡献, 它会自动搜索相关文献, 来验证这个想法是不是有人做过了.</p>
-"""
-
-
-def criticize_my_idea(title, contributions, max_tokens=4096):
-    ref = References(title=title, description=f"{contributions}")
-    keywords, _ = llm(systems=SYSTEM["keywords"], prompts=title, return_json=True)
-    keywords = {keyword: 10 for keyword in keywords}
-    ref.collect_papers(keywords)
-    ref_prompt = ref.to_prompts(max_tokens=max_tokens)
-
-    prompt = f"Title: {title}\n References: {ref_prompt}\n Contributions: {contributions}"
-    output, _ = llm(systems=contribution_system_prompt, prompts=prompt, return_json=True)
-    return output, ref_prompt
-
-def paste_title(suggestions):
-    if suggestions:
-        title = suggestions['title']['new title']
-        contributions = suggestions['contributions']
-
-        return title, contributions, {}, {}, {}
-    else:
-        return "", "", {}, {}, {}
-
-def generate_choices(thoughts):
-    output, _ = llm(systems=paper_system_prompt, prompts=thoughts, return_json=True)
-    return output
-
-
-# def translate_json(json_input):
-#     system_prompt = "You are a translation bot. The user will input a JSON format string. You need to translate it into Chinese and return it in the same format."
-#     output, _ = llm(systems=system_prompt, prompts=str(json_input), return_json=True)
-#     return output
-
-
-with gr.Blocks() as demo:
-    llm = GPTModel(model=default_model)
-
-    gr.HTML(ANNOUNCEMENT)
-    with gr.Row():
-        with gr.Tab("生成论文想法 (Generate Paper Ideas)"):
-            thoughts_input = gr.Textbox(label="Thoughts")
-            with gr.Accordion("Show prompts", open=False):
-                prompts_1 = gr.Textbox(label="Prompts", interactive=False, value=paper_system_prompt)
-
-            with gr.Row():
-                button_generate_idea = gr.Button("Make it an idea!", variant="primary")
-
-        with gr.Tab("验证想法可行性 (Validate Feasibility)"):
-            title_input = gr.Textbox(label="Title")
-            contribution_input = gr.Textbox(label="Contributions", lines=5)
-            with gr.Accordion("Show prompts", open=False):
-                prompts_2 = gr.Textbox(label="Prompts", interactive=False, value=contribution_system_prompt)
-
-            with gr.Row():
-                button_submit = gr.Button("Criticize my idea!", variant="primary")
-
-        with gr.Tab("生成论文 (Generate Paper)"):
-            gr.Markdown("...")
-
-        with gr.Column(scale=1):
-            contribution_output = gr.JSON(label="Contributions")
-            # cn_output = gr.JSON(label="主要贡献")
-            with gr.Accordion("References", open=False):
-                references_output = gr.JSON(label="References")
-
-    button_submit.click(fn=criticize_my_idea, inputs=[title_input, contribution_input], outputs=[contribution_output, references_output])
-    button_generate_idea.click(fn=generate_choices, inputs=thoughts_input, outputs=contribution_output)#.success(translate_json, contribution_output, cn_output)
-demo.queue(concurrency_count=1, max_size=5, api_open=False)
-demo.launch(show_error=True)
kdb_test.py
CHANGED
@@ -6,11 +6,15 @@ import gradio as gr
 import os
 import json
 from models import EMBEDDINGS
+from utils.gpt_interaction import GPTModel
+from utils.prompts import SYSTEM
+import openai
 
+llm = GPTModel(model="gpt-3.5-turbo")
+openai.api_key = os.getenv("OPENAI_API_KEY")
 
-HF_TOKEN =
-REPO_ID =
+HF_TOKEN = os.getenv("HF_TOKEN")
+REPO_ID = os.getenv("KDB_REPO")
 if HF_TOKEN is not None and REPO_ID is not None:
     snapshot_download(REPO_ID, repo_type="dataset", local_dir="knowledge_databases/",
                       local_dir_use_symlinks=False, token=HF_TOKEN)
@@ -50,6 +54,29 @@ def query_from_kdb(input, kdb, query_counts):
        raise RuntimeError(f"Failed to query from FAISS.")
     return domain_knowledge, ""
 
+
+def query_from_kdb_llm(title, contributions, kdb, query_counts):
+    if kdb == "(None)":
+        return {"knowledge_database": "(None)", "title": title, "contributions": contributions, "output": ""}, "", {}
+
+    db_path = f"knowledge_databases/{kdb}"
+    db_config_path = os.path.join(db_path, "db_meta.json")
+    db_index_path = os.path.join(db_path, "faiss_index")
+    if os.path.isdir(db_path):
+        # load configuration file
+        with open(db_config_path, "r", encoding="utf-8") as f:
+            db_config = json.load(f)
+        model_name = db_config["embedding_model"]
+        embeddings = EMBEDDINGS[model_name]
+        db = FAISS.load_local(db_index_path, embeddings)
+        knowledge = Knowledge(db=db)
+        prompts = f"Title: {title}\n Contributions: {contributions}"
+        preliminaries_kw, _ = llm(systems=SYSTEM["preliminaries"], prompts=prompts, return_json=True)
+        knowledge.collect_knowledge(preliminaries_kw, max_query=query_counts)
+        domain_knowledge = knowledge.to_json()
+    else:
+        raise RuntimeError(f"Failed to query from FAISS.")
+    return domain_knowledge, "", preliminaries_kw
+
+
 with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column():
@@ -76,9 +103,16 @@ with gr.Blocks() as demo:
                                         interactive=True, label="QUERY_COUNTS",
                                         info="How many contents will be retrieved from the vector database.")
 
+        with gr.Column():
+            retrieval_output = gr.JSON(label="Output")
+            llm_kws = gr.JSON(label="Keywords generated by LLM")
 
-    button_retrieval.click(fn=query_from_kdb,
+    button_retrieval.click(fn=query_from_kdb,
+                           inputs=[user_input, kdb_dropdown, query_counts_slider],
+                           outputs=[retrieval_output, user_input])
+    button_retrieval_2.click(fn=query_from_kdb_llm,
+                             inputs=[title_input, contribution_input, kdb_dropdown, query_counts_slider],
+                             outputs=[retrieval_output, user_input, llm_kws])
 
 demo.queue(concurrency_count=1, max_size=5, api_open=False)
 demo.launch(show_error=True)
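query_from_kdb_llm expects each knowledge database under knowledge_databases/<name>/ with a db_meta.json naming the embedding model and a faiss_index directory. A minimal sketch of building a compatible database with LangChain's FAISS wrapper; the HuggingFaceEmbeddings class is an assumption (any embedding object works), and the model name must also be a key in the project's models.EMBEDDINGS mapping:

import json
import os

from langchain.embeddings import HuggingFaceEmbeddings  # assumption: any LangChain embedding works
from langchain.vectorstores import FAISS

kdb_dir = "knowledge_databases/my_kdb"
os.makedirs(kdb_dir, exist_ok=True)

model_name = "sentence-transformers/all-MiniLM-L6-v2"  # must also be a key in models.EMBEDDINGS
embeddings = HuggingFaceEmbeddings(model_name=model_name)

texts = ["Q-learning is an off-policy TD control algorithm.",
         "Experience replay stores transitions for reuse."]
db = FAISS.from_texts(texts, embeddings)
db.save_local(os.path.join(kdb_dir, "faiss_index"))

# db_meta.json tells kdb_test.py which embedding model to load at query time.
with open(os.path.join(kdb_dir, "db_meta.json"), "w", encoding="utf-8") as f:
    json.dump({"embedding_model": model_name}, f)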
references_generator.py
DELETED
@@ -1,86 +0,0 @@
-'''
-This script is used to generate the most relevant papers for a given title.
-- Search for as many references as possible. For 10~15 keywords, 10 references each.
-- Sort the results from most relevant to least relevant.
-- Return the most relevant using token size.
-
-Note: we do not use this function in the auto-draft function. It has been integrated into that.
-'''
-
-import os.path
-import json
-from utils.references import References
-from section_generator import keywords_generation  # section_generation_bg, #, figures_generation, section_generation
-import itertools
-from gradio_client import Client
-
-
-def generate_raw_references(title, description="",
-                            bib_refs=None, tldr=False, max_kw_refs=10,
-                            save_to="ref.bib"):
-    # load pre-provided references
-    ref = References(title, bib_refs)
-
-    # generate multiple keywords for searching
-    input_dict = {"title": title, "description": description}
-    keywords, usage = keywords_generation(input_dict)
-    keywords = list(keywords)
-    comb_keywords = list(itertools.combinations(keywords, 2))
-    for comb_keyword in comb_keywords:
-        keywords.append(" ".join(comb_keyword))
-    keywords = {keyword: max_kw_refs for keyword in keywords}
-    print(f"keywords: {keywords}\n\n")
-
-    ref.collect_papers(keywords, tldr=tldr)
-    paper_json = ref.to_json()
-
-    with open(save_to, "w") as f:
-        json.dump(paper_json, f)
-
-    return save_to, ref  # paper_json
-
-def generate_top_k_references(title, description="",
-                              bib_refs=None, tldr=False, max_kw_refs=10, save_to="ref.bib", top_k=5):
-    json_path, ref_raw = generate_raw_references(title, description, bib_refs, tldr, max_kw_refs, save_to)
-    json_content = ref_raw.to_json()
-
-    client = Client("https://shaocongma-evaluate-specter-embeddings.hf.space/")
-    result = client.predict(
-        title,  # str in 'Title' Textbox component
-        json_path,  # str (filepath or URL to file) in 'Papers JSON (as string)' File component
-        top_k,  # int | float (numeric value between 1 and 50) in 'Top-k Relevant Papers' Slider component
-        api_name="/get_k_relevant_papers"
-    )
-    with open(result) as f:
-        result = json.load(f)
-    return result
-
-
-if __name__ == "__main__":
-    import openai
-    openai.api_key = os.getenv("OPENAI_API_KEY")
-
-    title = "Using interpretable boosting algorithms for modeling environmental and agricultural data"
-    description = ""
-    save_to = "paper.json"
-    save_to, paper_json = generate_raw_references(title, description, save_to=save_to)
-
-    print("`paper.json` has been generated. Now evaluating its similarity...")
-
-    k = 5
-    client = Client("https://shaocongma-evaluate-specter-embeddings.hf.space/")
-    result = client.predict(
-        title,  # str in 'Title' Textbox component
-        save_to,  # str (filepath or URL to file) in 'Papers JSON (as string)' File component
-        k,  # int | float (numeric value between 1 and 50) in 'Top-k Relevant Papers' Slider component
-        api_name="/get_k_relevant_papers"
-    )
-
-    with open(result) as f:
-        result = json.load(f)
-
-    print(result)
-
-    save_to = "paper2.json"
-    with open(save_to, "w") as f:
-        json.dump(result, f)
utils/knowledge.py
CHANGED
@@ -16,7 +16,7 @@ class Knowledge:
         self.db = db
         self.contents = []
 
-    def collect_knowledge(self, keywords_dict, max_query):
+    def collect_knowledge(self, keywords_dict: dict, max_query: int):
         """
         keywords_dict:
             {"machine learning": 5, "language model": 2};
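A minimal usage sketch of the annotated method, following the docstring's keywords_dict format. The FAISS store and embeddings object are assumptions (e.g. built as in the earlier sketch); each keyword is searched against the vector store, with max_query capping the total retrieved contents:

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

from utils.knowledge import Knowledge

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
db = FAISS.load_local("knowledge_databases/my_kdb/faiss_index", embeddings)

knowledge = Knowledge(db=db)
# {"query string": count}; see the docstring above for the expected shape.
knowledge.collect_knowledge({"machine learning": 5, "language model": 2}, max_query=10)
print(knowledge.to_json())  # to_json() is how kdb_test.py reads the results back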
utils/references.py
CHANGED
@@ -3,52 +3,68 @@
 #
 # Generate references:
 # `Reference` class:
-#    1.
+#    1. Two methods to load papers:
+#       1.1. Read a given string including paper titles separated by `,`
+#       1.2. Read a .bib file
 #    2. Given some keywords; use Semantic Scholar API to find papers.
 #    3. Generate bibtex from the selected papers. --> to_bibtex()
 #    4. Generate prompts from the selected papers: --> to_prompts()
 #       A sample prompt: {"paper_id": "paper summary"}
+#    5. Generate json from the selected papers. --> to_json()
 
-
-# add all citations to `bib_papers`
-# add all citedby to `bib_papers`
-# use Semantic Scholar to find their embeddings
-# (2) separate references:
-#     divide references into different groups to reduce the tokens count
-#     for generating different paragraph of related works, use different set of references
-from typing import Dict, List
-import requests
+import itertools
+import json
 import re
+import uuid
+from typing import Dict, List, Optional, Union
+
+import arxiv
 import bibtexparser
-import random
-from scholarly import scholarly
-from scholarly import ProxyGenerator
-import tiktoken
-import itertools, uuid, json
-from gradio_client import Client
-import time
 import numpy as np
+import requests
+import tiktoken
 from numpy.linalg import norm
+from scholarly import ProxyGenerator
+from scholarly import scholarly
 
-
+# used to evaluate embeddings
 URL = "https://model-apis.semanticscholar.org/specter/v1/invoke"
 MAX_BATCH_SIZE = 16
 MAX_ATTEMPTS = 20
 
+# `tokenizer`: used to count how many tokens
+tokenizer_name = tiktoken.encoding_for_model('gpt-4')
+tokenizer = tiktoken.get_encoding(tokenizer_name.name)
+
+
 ######################################################################################################################
 # Some basic tools
 ######################################################################################################################
+def remove_special_characters(s):
+    return ''.join(c for c in s if c.isalnum() or c.isspace() or c == ',')
+
+
+def remove_newlines(serie):
+    # This function is applied to the abstract of each paper to reduce the length of prompts.
+    serie = serie.replace('\n', ' ')
+    serie = serie.replace('\\n', ' ')
+    serie = serie.replace('  ', ' ')
+    serie = serie.replace('  ', ' ')
+    return serie
+
+
 def evaluate_cosine_similarity(v1, v2):
     try:
-        return np.dot(v1, v2)/(norm(v1)*norm(v2))
+        return np.dot(v1, v2) / (norm(v1) * norm(v2))
     except ValueError:
         return 0.0
 
+
 def chunks(lst, chunk_size=MAX_BATCH_SIZE):
     """Splits a longer list to respect batch size"""
     for i in range(0, len(lst), chunk_size):
-        yield lst[i
+        yield lst[i: i + chunk_size]
+
 
 def embed(papers):
     embeddings_by_paper_id: Dict[str, List[float]] = {}
@@ -64,6 +80,7 @@ def embed(papers):
 
     return embeddings_by_paper_id
 
+
 def get_embeddings(paper_title, paper_description):
     output = [{"title": paper_title, "abstract": paper_description, "paper_id": "target_paper"}]
     emb_vector = embed(output)["target_paper"]
@@ -71,9 +88,17 @@ def get_embeddings(paper_title, paper_description):
     target_paper["embeddings"] = emb_vector
     return target_paper
 
+
+def get_embeddings_vector(paper_title, paper_description):
+    output = [{"title": paper_title, "abstract": paper_description, "paper_id": "target_paper"}]
+    emb_vector = embed(output)["target_paper"]
+    return emb_vector
+
+
 def get_top_k(papers_dict, paper_title, paper_description, k=None):
+    # returns the top-k papers most similar to the target paper
     target_paper = get_embeddings(paper_title, paper_description)
-    papers = papers_dict
+    papers = papers_dict  # must include embeddings
 
     # if k < len(papers_json), return k most relevant papers
     # if k >= len(papers_json) or k is None, return all papers
@@ -88,7 +113,7 @@ def get_top_k(papers_dict, paper_title, paper_description, k=None):
     for k in papers:
         v = papers[k]
         embedding_vector = v["embeddings"]
-        cos_sim
+        cos_sim = evaluate_cosine_similarity(embedding_vector, target_embedding_vector)
         papers[k]["cos_sim"] = cos_sim
 
     # return the best k papers
@@ -97,14 +122,6 @@ def get_top_k(papers_dict, paper_title, paper_description, k=None):
         sorted_papers[key].pop("embeddings", None)
     return sorted_papers
 
-def remove_newlines(serie):
-    # This function is applied to the abstract of each paper to reduce the length of prompts.
-    serie = serie.replace('\n', ' ')
-    serie = serie.replace('\\n', ' ')
-    serie = serie.replace('  ', ' ')
-    serie = serie.replace('  ', ' ')
-    return serie
-
 
 def search_paper_abstract(title):
     pg = ProxyGenerator()
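For context, embed() above batches papers against the SPECTER endpoint. A standalone sketch of one such call; the response shape ({"preds": [...]}) is inferred from the public SPECTER model API, not confirmed by this diff:

import requests

URL = "https://model-apis.semanticscholar.org/specter/v1/invoke"

papers = [{"paper_id": "dqn2013",
           "title": "Playing Atari with Deep Reinforcement Learning",
           "abstract": "We present the first deep learning model to learn control policies "
                       "directly from high-dimensional sensory input using reinforcement learning."}]
response = requests.post(URL, json=papers, timeout=30)
response.raise_for_status()
# Assumed response shape: {"preds": [{"paper_id": ..., "embedding": [...]}, ...]}
vectors = {p["paper_id"]: p["embedding"] for p in response.json()["preds"]}
print(len(vectors["dqn2013"]))  # SPECTER vectors are 768-dimensional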
|
|
|
| 3 |
#
|
| 4 |
# Generate references:
|
| 5 |
# `Reference` class:
|
| 6 |
+
# 1. Two methods to load papers:
|
| 7 |
+
# 1.1. Read a given string including paper titles separated by `,`
|
| 8 |
+
# 1.2. Read a .bib file
|
| 9 |
# 2. Given some keywords; use Semantic Scholar API to find papers.
|
| 10 |
# 3. Generate bibtex from the selected papers. --> to_bibtex()
|
| 11 |
# 4. Generate prompts from the selected papers: --> to_prompts()
|
| 12 |
# A sample prompt: {"paper_id": "paper summary"}
|
| 13 |
+
# 5. Generate json from the selected papers. --> to_json()
|
| 14 |
|
| 15 |
+
import itertools
|
| 16 |
+
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
import re
|
| 18 |
+
import uuid
|
| 19 |
+
from typing import Dict, List, Optional, Union
|
| 20 |
+
|
| 21 |
+
import arxiv
|
| 22 |
import bibtexparser
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
import numpy as np
|
| 24 |
+
import requests
|
| 25 |
+
import tiktoken
|
| 26 |
from numpy.linalg import norm
|
| 27 |
+
from scholarly import ProxyGenerator
|
| 28 |
+
from scholarly import scholarly
|
| 29 |
|
| 30 |
+
# used to evaluate embeddings
|
| 31 |
URL = "https://model-apis.semanticscholar.org/specter/v1/invoke"
|
| 32 |
MAX_BATCH_SIZE = 16
|
| 33 |
MAX_ATTEMPTS = 20
|
| 34 |
|
| 35 |
+
# `tokenizer`: used to count how many tokens
|
| 36 |
+
tokenizer_name = tiktoken.encoding_for_model('gpt-4')
|
| 37 |
+
tokenizer = tiktoken.get_encoding(tokenizer_name.name)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
######################################################################################################################
|
| 41 |
# Some basic tools
|
| 42 |
######################################################################################################################
|
| 43 |
+
def remove_special_characters(s):
|
| 44 |
+
return ''.join(c for c in s if c.isalnum() or c.isspace() or c == ',')
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def remove_newlines(serie):
|
| 48 |
+
# This function is applied to the abstract of each paper to reduce the length of prompts.
|
| 49 |
+
serie = serie.replace('\n', ' ')
|
| 50 |
+
serie = serie.replace('\\n', ' ')
|
| 51 |
+
serie = serie.replace(' ', ' ')
|
| 52 |
+
serie = serie.replace(' ', ' ')
|
| 53 |
+
return serie
|
| 54 |
+
|
| 55 |
+
|
| 56 |
def evaluate_cosine_similarity(v1, v2):
|
| 57 |
try:
|
| 58 |
+
return np.dot(v1, v2) / (norm(v1) * norm(v2))
|
| 59 |
except ValueError:
|
| 60 |
return 0.0
|
| 61 |
|
| 62 |
+
|
| 63 |
def chunks(lst, chunk_size=MAX_BATCH_SIZE):
|
| 64 |
"""Splits a longer list to respect batch size"""
|
| 65 |
for i in range(0, len(lst), chunk_size):
|
| 66 |
+
yield lst[i: i + chunk_size]
|
| 67 |
+
|
| 68 |
|
| 69 |
def embed(papers):
|
| 70 |
embeddings_by_paper_id: Dict[str, List[float]] = {}
|
|
|
|
 80
 81        return embeddings_by_paper_id
 82
 83 +
 84    def get_embeddings(paper_title, paper_description):
 85        output = [{"title": paper_title, "abstract": paper_description, "paper_id": "target_paper"}]
 86        emb_vector = embed(output)["target_paper"]
 88        target_paper["embeddings"] = emb_vector
 89        return target_paper
 90
 91 +
 92 +  def get_embeddings_vector(paper_title, paper_description):
 93 +      output = [{"title": paper_title, "abstract": paper_description, "paper_id": "target_paper"}]
 94 +      emb_vector = embed(output)["target_paper"]
 95 +      return emb_vector
 96 +
 97 +
 98    def get_top_k(papers_dict, paper_title, paper_description, k=None):
 99 +      # returns the top k papers most similar to the target paper
100        target_paper = get_embeddings(paper_title, paper_description)
101 +      papers = papers_dict  # must include embeddings
102
103        # if k < len(papers_json), return k most relevant papers
104        # if k >= len(papers_json) or k is None, return all papers
113        for k in papers:
114            v = papers[k]
115            embedding_vector = v["embeddings"]
116 +          cos_sim = evaluate_cosine_similarity(embedding_vector, target_embedding_vector)
117            papers[k]["cos_sim"] = cos_sim
118
119        # return the best k papers
122            sorted_papers[key].pop("embeddings", None)
123        return sorted_papers
124
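get_top_k depends on the SPECTER endpoint above for the target embedding, but its ranking step is plain cosine similarity. A self-contained sketch of that step with made-up vectors; in the real flow, target comes from get_embeddings_vector and each paper's "embeddings" from embed:

import numpy as np
from numpy.linalg import norm

def cos_sim(v1, v2):
    return float(np.dot(v1, v2) / (norm(v1) * norm(v2)))

papers = {
    "smith2020deep":  {"embeddings": [0.9, 0.1, 0.0]},   # toy vectors, not real SPECTER output
    "jones2019graph": {"embeddings": [0.1, 0.9, 0.0]},
}
target = [1.0, 0.0, 0.0]
ranked = sorted(papers, key=lambda pid: cos_sim(papers[pid]["embeddings"], target), reverse=True)
print(ranked)  # ['smith2020deep', 'jones2019graph']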
125
126    def search_paper_abstract(title):
127        pg = ProxyGenerator()
140        return remove_newlines(found_paper['bib']['abstract'])
141
142
143 +  def tiktoken_len(text):
144 +      # evaluate how many tokens for the given text
145 +      tokens = tokenizer.encode(text, disallowed_special=())
146 +      return len(tokens)
147 +
148 +
149 +  ######################################################################################################################
150 +  # Academic search tools
151 +  ######################################################################################################################
152 +  def externalIds2link(externalIds):
153 +      # Sample externalIds:
154 +      #   "{'MAG': '2932819148', 'DBLP': 'conf/icml/HaarnojaZAL18', 'ArXiv': '1801.01290', 'CorpusId': 28202810}"
155 +      if externalIds:
156 +          # Supports ArXiv, MAG, ACL, PubMed, Medline, PubMedCentral, DBLP, DOI
157 +          # priority: DBLP > arXiv > (todo: MAG > CorpusId > DOI > ACL > PubMed > Medline > PubMedCentral)
158 +          # DBLP
159 +          dblp_id = externalIds.get('DBLP')
160 +          if dblp_id is not None:
161 +              dblp_link = f"dblp.org/rec/{dblp_id}"
162 +              return dblp_link
163 +          # arXiv
164 +          arxiv_id = externalIds.get('ArXiv')
165 +          if arxiv_id is not None:
166 +              arxiv_link = f"arxiv.org/abs/{arxiv_id}"
167 +              return arxiv_link
168 +          return ""
169 +      else:
170 +          # if this is an empty dictionary, return an empty string
171 +          return ""
172 +
173 +
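Applied to the sample externalIds from the comment above, the priority order plays out as follows (a usage sketch, not part of the commit):

ids = {'MAG': '2932819148', 'DBLP': 'conf/icml/HaarnojaZAL18', 'ArXiv': '1801.01290', 'CorpusId': 28202810}
print(externalIds2link(ids))                      # dblp.org/rec/conf/icml/HaarnojaZAL18 (DBLP outranks arXiv)
print(externalIds2link({'ArXiv': '1801.01290'}))  # arxiv.org/abs/1801.01290
print(externalIds2link({}))                       # empty string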
174 +  def search_paper_arxiv(title):
175 +      search = arxiv.Search(
176 +          query=title,
177 +          max_results=1,
178 +          sort_by=arxiv.SortCriterion.Relevance
179 +      )
180 +      try:
181 +          # (1) paper_id (2) title (3) authors (4) year (5) link (6) abstract (7) journal (8) embeddings
182 +          result = next(search.results())
183 +          title = result.title
184 +          authors = " and ".join([author.name for author in result.authors])
185 +          year = str(result.updated.year)
186 +          link = result.pdf_url
187 +          abstract = result.summary
188 +          journal = f"Arxiv: {result.entry_id}"
189 +          paper_id = result.authors[0].name.replace(" ", "")[:4] + year + title[:6].replace(" ", "")
190 +          paper_id = paper_id.lower()
191 +
192 +          paper = {"paper_id": paper_id,
193 +                   "title": title,
194 +                   "authors": authors,
195 +                   "year": year,
196 +                   "link": link,
197 +                   "abstract": abstract,
198 +                   "journal": journal}
199 +      except StopIteration:
200 +          paper = {}
201 +      return paper
202 +
203 +
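A usage sketch for the arXiv fallback (requires network access and the arxiv package; the returned dict follows the shape built above, and the hit depends on arXiv's relevance ranking):

paper = search_paper_arxiv("Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning")
if paper:
    print(paper["paper_id"])   # lowercase id built from first author + year + title prefix
    print(paper["link"])       # direct PDF url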
204 +  def search_paper_ss(title):
205 +      fields = ["title", "abstract", "venue", "year", "authors", "tldr", "externalIds"]
206 +      limit = 1
207 +      url = f'https://api.semanticscholar.org/graph/v1/paper/search?query={title}&limit={limit}&fields={",".join(fields)}'
208 +      # headers = {"Accept": "*/*", "x-api-key": constants.S2_KEY}
209 +      headers = {"Accept": "*/*"}
210 +      response = requests.get(url, headers=headers, timeout=30)
211 +      results = response.json()
212 +      if results['total'] == 0:
213 +          return {}
214 +      raw_paper = results['data'][0]
215 +      if raw_paper['tldr'] is not None:
216 +          abstract = raw_paper['tldr']['text']
217 +      elif raw_paper['abstract'] is not None:
218 +          abstract = remove_newlines(raw_paper['abstract'])
219 +      else:
220 +          abstract = ""
221 +
222 +      authors = [author['name'] for author in raw_paper['authors']]
223 +      authors_str = " and ".join(authors)
224 +      year_str = str(raw_paper['year'])
225 +      title = raw_paper['title']
226 +
227 +      paper_id = authors_str.replace(" ", "")[:4] + year_str + title[:6].replace(" ", "")
228 +
229 +      # some journal may contain &; replace it. e.g. journal={IEEE Power & Energy Society General Meeting}
230 +      journal = remove_special_characters(raw_paper['venue'])
231 +      if not journal:
232 +          journal = "arXiv preprint"
233 +      link = externalIds2link(raw_paper['externalIds'])
234 +      paper = {
235 +          "paper_id": paper_id,
236 +          "title": title,
237 +          "abstract": abstract,
238 +          "link": link,
239 +          "authors": authors_str,
240 +          "year": year_str,
241 +          "journal": journal
242 +      }
243 +      return paper
244 +
245 +
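The Graph API response that search_paper_ss unpacks has a total count and a data list; a rough, abridged shape (values illustrative, fields per the request above):

results = {
    "total": 1,
    "data": [{
        "title": "Soft Actor-Critic: ...",
        "abstract": "...",
        "venue": "ICML",
        "year": 2018,
        "authors": [{"name": "Tuomas Haarnoja"}, {"name": "..."}],
        "tldr": {"text": "..."},     # preferred over the abstract when present
        "externalIds": {"DBLP": "conf/icml/HaarnojaZAL18", "ArXiv": "1801.01290"},
    }],
}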
246 +  def search_paper_scrape(title):
247 +      pg = ProxyGenerator()
248 +      success = pg.ScraperAPI("921b16f94d701308b9d9b4456ddde155")
249 +      if success:
250 +          try:
251 +              scholarly.use_proxy(pg)
252 +              # input the title of a paper, return its abstract
253 +              search_query = scholarly.search_pubs(title)
254 +              found_paper = next(search_query)
255 +              url = found_paper['pub_url']
256 +
257 +              result = found_paper['bib']
258 +
259 +              title = result['title']
260 +              authors = " and ".join(result['author'])
261 +              year = str(result['pub_year'])
262 +              journal = result['venue']
263 +              abstract = result['abstract']
264 +
265 +              paper_id = authors.replace(" ", "")[:4] + year + title[:6].replace(" ", "")
266 +              paper = {
267 +                  "paper_id": paper_id,
268 +                  "title": title,
269 +                  "abstract": abstract,
270 +                  "link": url,
271 +                  "authors": authors,
272 +                  "year": year,
273 +                  "journal": journal
274 +              }
275 +              return paper
276 +          except StopIteration:
277 +              return {}
278 +      return {}  # proxy setup failed
279 +
280 +  def search_paper(title, verbose=True):
281 +      if verbose:
282 +          print(f"Searching {title}...")
283 +      # try Semantic Scholar first
284 +      paper = search_paper_ss(title)
285 +      if not paper:
286 +          paper = search_paper_arxiv(title)
287 +      if not paper:
288 +          paper = search_paper_scrape(title)
289 +      if paper:
290 +          paper["embeddings"] = get_embeddings_vector(paper_title=paper['title'], paper_description=paper['abstract'])
291 +      if verbose:
292 +          print(f"Search result: {paper}.")
293 +      return paper
294 +
295 +
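The cascade tries Semantic Scholar, then arXiv, then scraping, and finally attaches a SPECTER embedding. A network-dependent usage sketch:

paper = search_paper("Attention Is All You Need", verbose=False)
if paper:
    print(paper["paper_id"], paper["journal"])
    print(len(paper["embeddings"]))   # SPECTER vector added by get_embeddings_vector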
296    def load_papers_from_bibtex(bib_file_path):
297        with open(bib_file_path) as bibtex_file:
298            bib_database = bibtexparser.load(bibtex_file)
324            bib_papers.append(result)
325        return bib_papers
326
327
328 +  def load_papers_from_text(text):
329 +      # split text by comma
330 +      titles = [part.strip() for part in text.split(',')]
331 +      titles = [remove_special_characters(title) for title in titles]
332 +      papers = []
333 +      if len(titles) > 0:
334 +          for title in titles:
335 +              paper = search_paper(title)
336 +              if paper:
337 +                  papers.append(paper)
338 +          return papers
339 +      else:
340 +          return []
341
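A quick sketch of the comma-separated entry point (network-dependent; each title is cleaned by remove_special_characters before searching):

papers = load_papers_from_text("Attention Is All You Need, Soft Actor-Critic")
print([p["paper_id"] for p in papers])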
342
343    ######################################################################################################################
349        fields = ["title", "abstract", "venue", "year", "authors", "tldr", "embedding", "externalIds"]
350        keywords = keywords.lower()
351        keywords = keywords.replace(" ", "+")
352 +      url = f'https://api.semanticscholar.org/graph/v1/paper/search?query={keywords}&limit={limit}&fields={",".join(fields)}'
353        # headers = {"Accept": "*/*", "x-api-key": constants.S2_KEY}
354        headers = {"Accept": "*/*"}
355
358
359
360    def _collect_papers_ss(keyword, counts=3, tldr=False):
361        def extract_paper_id(last_name, year_str, title):
362            pattern = r'^\w+'
363            words = re.findall(pattern, title)
443    ######################################################################################################################
444
445    class References:
446 +      def __init__(self,
447 +                   title: str,
448 +                   load_papers: Optional[str] = None,
449 +                   load_bibtex: Optional[str] = None,
450 +                   description: str = ""
451 +                   ):
452 +          self.papers = {}
453 +          if load_bibtex is not None:
454 +              self.papers["load_from_bibtex"] = load_papers_from_bibtex(load_bibtex)
455            if load_papers is not None:
456 +              self.papers["load_from_text"] = load_papers_from_text(load_papers)
457 +
458            self.title = title
459            self.description = description
460
461 +      def generate_keywords_dict(self) -> Dict[str, int]:
462            keywords_dict = {}
463            for k in self.papers:
464                keywords_dict[k] = len(self.papers[k])
465            return keywords_dict
466
467 +      def collect_papers(self, keywords_dict: Dict[str, int], tldr: bool = False) -> None:
468            """
469            Collect as many papers as possible
478                keywords.append(" ".join(comb_keyword))
479            for key in keywords:
480                self.papers[key] = _collect_papers_ss(key, 10, tldr)
481
482 +      def to_bibtex(self, path_to_bibtex: str = "ref.bib") -> List[str]:
483            """
484            Turn the saved paper list into bibtex file "ref.bib". Return a list of all `paper_id`.
485            """
486            papers = self._get_papers(keyword="_all")
487
488 +          num_papers = len(papers)
489 +          print(f"{num_papers} papers will be added to `ref.bib`.")
490            # clear the bibtex file
491            with open(path_to_bibtex, "w", encoding="utf-8") as file:
492                file.write("")
524            papers = self.papers[keyword]
525            return papers
526
527 +      def to_prompts(self, keyword: str = "_all", max_tokens: int = 2048):
528            # `prompts`:
529            #   {"paper1_bibtex_id": "paper_1_abstract", "paper2_bibtex_id": "paper2_abstract"}
530            # this will be used to instruct GPT model to cite the correct bibtex entry.
536            papers_json = self.to_json()
537            with open(json_path, "w") as f:
538                json.dump(papers_json, f)
539            try:
540                # Use external API to obtain the most relevant papers
541                title = self.title
542                description = self.description
543                result = get_top_k(papers_json, title, description)
544                result = [item for key, item in result.items()]
545            except Exception as e:
546                print(f"Error occurred while calling the external API: {e}\n")
559                    break
560            return prompts
561
562 +      def to_json(self, keyword: str = "_all"):
563            papers = self._get_papers(keyword)
564            papers_json = {}
565            for paper in papers:
566                papers_json[paper["paper_id"]] = paper
567            return papers_json
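Pulling the class together, a rough end-to-end sketch mirroring the demo that this commit removed from the bottom of the file (keyword dict values are illustrative; network access assumed):

refs = References(title="Playing Atari with Deep Reinforcement Learning", description="")
refs.collect_papers({"deep reinforcement learning": 10, "atari games": 10}, tldr=True)
refs.to_bibtex("ref.bib")                   # writes the .bib file and returns the paper ids
prompts = refs.to_prompts(max_tokens=2048)  # {"paper_id": "tldr/abstract", ...}
for k in prompts:
    print(f"{k}: {prompts[k]}\n")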
worker.py
CHANGED

@@ -3,7 +3,7 @@ This script is only used for service-side host.
  3    '''
  4    import boto3
  5    import os, time
  6 -  from
  6 +  from wrapper import generator_wrapper
  7    from sqlalchemy import create_engine, Table, MetaData, update, select
  8    from sqlalchemy.orm import sessionmaker
  9    from sqlalchemy import inspect
wrapper.py
ADDED

@@ -0,0 +1,57 @@
  1 +  """
  2 +  This script is used to wrap all generation methods together.
  3 +
  4 +  todo:
  5 +  A worker keeps running on the server. Monitor the Amazon SQS. Once it receives a new message, do the following:
  6 +      Download the corresponding configuration files from S3.
  7 +      Change Task status from Pending to Running.
  8 +      Call `generator_wrapper` and wait for the outputs.
  9 +      If `generator_wrapper` returns results:
 10 +          evaluate the results; compile them; upload results to S3 ... Change Task status from Running to Completed.
 11 +          If anything goes wrong, raise an Error.
 12 +      If `generator_wrapper` returns nothing, times out, or raises any error:
 13 +          Change Task status from Running to Failed.
 14 +  """
 15 +  from auto_generators import generate_draft
 16 +  from utils.file_operations import make_archive
 17 +  import yaml
 18 +  import uuid
 19 +
 20 +
 21 +  def remove_special_characters(s):
 22 +      return ''.join(c for c in s if c.isalnum() or c.isspace() or c == ',')
 23 +
 24 +
 25 +  def generator_wrapper(config):
 26 +      if not isinstance(config, dict):
 27 +          with open(config, "r") as file:
 28 +              config = yaml.safe_load(file)
 29 +      title = config["paper"]["title"]
 30 +      generator = config["generator"]
 31 +      if generator == "auto_draft":
 32 +          folder = generate_draft(title, config["paper"]["description"],
 33 +                                  tldr=config["references"]["tldr"],
 34 +                                  max_kw_refs=config["references"]["max_kw_refs"],
 35 +                                  refs=config["references"]["refs"],
 36 +                                  max_tokens_ref=config["references"]["max_tokens_ref"],
 37 +                                  knowledge_database=config["domain_knowledge"]["knowledge_database"],
 38 +                                  max_tokens_kd=config["domain_knowledge"]["max_tokens_kd"],
 39 +                                  query_counts=config["domain_knowledge"]["query_counts"],
 40 +                                  sections=config["output"]["selected_sections"],
 41 +                                  model=config["output"]["model"],
 42 +                                  template=config["output"]["template"],
 43 +                                  prompts_mode=config["output"]["prompts_mode"],
 44 +                                  )
 45 +      else:
 46 +          raise NotImplementedError(f"The generator {generator} is not supported yet.")
 47 +      # todo: post-processing: translate to Chinese, compile PDF ...
 48 +      filename = remove_special_characters(title).replace(" ", "_") + uuid.uuid1().hex + ".zip"
 49 +      return make_archive(folder, filename)
 50 +
 51 +
 52 +  if __name__ == "__main__":
 53 +      pass
 54 +      # with open("configurations/default.yaml", 'r') as file:
 55 +      #     config = yaml.safe_load(file)
 56 +      # print(config)
 57 +      # generator_wrapper(config)
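For reference, a minimal sketch of the nested config that generator_wrapper expects; the key paths are read off the code above, while the values below are placeholders rather than the actual contents of configurations/default.yaml:

config = {
    "generator": "auto_draft",
    "paper": {"title": "Playing Atari with Deep Reinforcement Learning", "description": ""},
    "references": {"tldr": True, "max_kw_refs": 10, "refs": None, "max_tokens_ref": 2048},
    "domain_knowledge": {"knowledge_database": None, "max_tokens_kd": 2048, "query_counts": 10},
    "output": {"selected_sections": None, "model": "gpt-4", "template": "ICLR2022", "prompts_mode": False},
}
zip_file = generator_wrapper(config)  # archives the generated draft and returns the zip filename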