GeoLLM / response_to_json.py
Ciallo0d00's picture
Upload folder using huggingface_hub
badcf3c verified
import json
import os
def parse_llm_response(response):
    """
    Parse an LLM response and convert it to the standard triple format.

    The JSON payload is located in two steps: first the body of a markdown
    code fence is taken (a ```json fence is preferred over a bare ``` fence),
    then the text is narrowed to the outermost ``[...]`` span if one exists.

    Args:
        response: A string, a dict with a 'content' key, or an object with a
            ``content`` attribute.

    Returns:
        ``{"triple_list": [[entity1, relation, entity2], ...]}`` built from
        the items that carry all three keys, or ``None`` when the content is
        not a string, is not valid JSON, or is not a JSON list/object.
    """
    # Resolve the textual content from the supported response shapes.
    if isinstance(response, str):
        content = response
    elif isinstance(response, dict) and 'content' in response:
        content = response['content']
    else:
        content = response.content

    if not isinstance(content, str):
        print(f"JSON解析错误,原始内容: {content}")
        print(f"错误详情: 响应内容不是字符串: {type(content).__name__}")
        return None

    # Take the fenced block body. Searching for the fence markers directly
    # (instead of splitting on "```json\n") tolerates CRLF line endings,
    # trailing spaces after the opener, and payloads on the fence line.
    fence = '```'
    opener = '```json' if '```json' in content else fence
    open_at = content.find(opener)
    if open_at != -1:
        body_start = open_at + len(opener)
        close_at = content.find(fence, body_start)
        if close_at == -1:
            # Unclosed fence: keep everything after the opener.
            content = content[body_start:]
        else:
            content = content[body_start:close_at]

    # Narrow to the outermost bracketed array; only when the brackets are in
    # the right order, otherwise the slice would be empty or nonsensical.
    first = content.find('[')
    last = content.rfind(']')
    if first != -1 and last > first:
        content = content[first:last + 1]

    content = content.strip()

    try:
        triples = json.loads(content)
    except json.JSONDecodeError as e:
        print(f"JSON解析错误,原始内容: {content}")
        print(f"错误详情: {str(e)}")
        return None

    # A lone JSON object is treated as a single candidate triple.
    if isinstance(triples, dict):
        triples = [triples]
    elif not isinstance(triples, list):
        print(f"JSON解析错误,原始内容: {content}")
        print(f"错误详情: JSON内容不是列表或对象: {type(triples).__name__}")
        return None

    required_keys = ("entity1", "relation", "entity2")
    # Items that are not dicts or that miss any required key are skipped.
    return {
        "triple_list": [
            [item[key] for key in required_keys]
            for item in triples
            if isinstance(item, dict) and all(key in item for key in required_keys)
        ]
    }
def save_to_json(text, formatted_triples, model_series, output_dir='./output'):
    """
    Append one text/triples record to ``{output_dir}/{model_series}.json``.

    The file holds a JSON list of ``{"text": ..., "triple_list": ...}``
    records. It is created if missing; an existing file is read, extended
    with the new record, and rewritten.

    Args:
        text: The source text the triples were extracted from.
        formatted_triples: A dict with a "triple_list" key, or ``None``
            (recorded as an empty triple list).
        model_series: Used as the output file stem; may contain '/'.
        output_dir: Directory that receives the output file.
    """
    # Make sure the output directory exists.
    os.makedirs(output_dir, exist_ok=True)
    output_path = f'{output_dir}/{model_series}.json'
    # model_series may contain '/' (e.g. "org/model"), which places the file
    # in a subdirectory of output_dir that must exist as well.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Load previously saved records; an empty file counts as "no records".
    existing_data = []
    if os.path.exists(output_path):
        with open(output_path, 'r', encoding='utf-8') as f:
            raw = f.read()
        if raw.strip():
            existing_data = json.loads(raw)

    # A file holding a single record object is promoted to a list.
    if isinstance(existing_data, dict):
        existing_data = [existing_data]

    # None means there are no triples to record for this text.
    triple_list = [] if formatted_triples is None else formatted_triples["triple_list"]
    existing_data.append({
        "text": text,
        "triple_list": triple_list
    })

    # Serialize before touching the file, then swap the finished file into
    # place, so a failure cannot leave a truncated output file behind.
    serialized = json.dumps(existing_data, ensure_ascii=False, indent=4)
    tmp_path = f'{output_path}.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.write(serialized)
    os.replace(tmp_path, output_path)
# Save the raw (unparsed) response; modeled on save_to_json.
def save_raw_response(response, prompt, model_series, output_dir='./output/two_shot_raw'):
    """
    Append one prompt/raw-response record to ``{output_dir}/{model_series}.json``.

    The file holds a JSON list of ``{"prompt": ..., "response": ...}``
    records. It is created if missing; an existing file is read, extended
    with the new record, and rewritten.

    Args:
        response: The raw model response. A string has its literal
            backslash-n sequences turned into real newlines; any other
            object is stored as ``str(response)``.
        prompt: The prompt that produced the response.
        model_series: Used as the output file stem; may contain '/'.
        output_dir: Directory that receives the output file.
    """
    # Make sure the output directory exists.
    os.makedirs(output_dir, exist_ok=True)
    output_path = f'{output_dir}/{model_series}.json'
    # model_series may contain '/' (e.g. "org/model"), which places the file
    # in a subdirectory of output_dir that must exist as well.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Load previously saved records; an empty file counts as "no records".
    existing_data = []
    if os.path.exists(output_path):
        with open(output_path, 'r', encoding='utf-8') as f:
            raw = f.read()
        if raw.strip():
            existing_data = json.loads(raw)

    # A file holding a single record object is promoted to a list, so that
    # appending works instead of failing on a dict.
    if isinstance(existing_data, dict):
        existing_data = [existing_data]

    # Non-string responses (e.g. dicts) have no .replace, so they are
    # stringified as-is.
    if isinstance(response, str):
        response_text = response.replace('\\n', '\n')
    else:
        response_text = str(response)

    existing_data.append({
        "prompt": prompt,
        "response": response_text
    })

    # Serialize before touching the file, then swap the finished file into
    # place, so a failure cannot leave a truncated output file behind.
    serialized = json.dumps(existing_data, ensure_ascii=False, indent=4)
    tmp_path = f'{output_path}.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.write(serialized)
    os.replace(tmp_path, output_path)