"""Few-shot triple extraction driver.

Loads geological description texts from ``./data/train_triples.json``, builds a
prompt made of the task instructions plus examples chosen by
``generate_prompt_with_best_matches``, sends it to the configured LLM and stores
both the parsed triples and the raw response.
"""

import json

import pandas as pd

MODEL_SERIES = 'deepSeek'

# Alternative models tried previously (pricing notes as recorded by the author):
#   'deepseek-ai/DeepSeek-V3'   prompt $0.5 / 1M tokens, completion $1 / 1M tokens
#   'deepseek-ai/DeepSeek-R1'   prompt $1 / 1M tokens,   completion $4 / 1M tokens
#   'meta-llama/Meta-Llama-3.1-405B-Instruct'
MODEL_NAME = 'Qwen/Qwen2.5-72B-Instruct'

DATA_PATH = './data/train_triples.json'
PARSED_OUTPUT_DIR = './output/knn/three_shot/'
RAW_OUTPUT_DIR = './output/knn/three_shot_raw/'

# Number of leading samples to process and number of examples per prompt.
SAMPLE_LIMIT = 500
NUM_EXAMPLES = 3

# Task instructions sent in front of every request. The text is runtime data and
# is kept character-for-character; it is split into adjacent literals only so
# that the trailing space before the embedded newline cannot be lost.
# NOTE(review): the line breaks of this prompt look like they were flattened to
# spaces at some point - confirm against the intended layout before changing it.
PROMPT = (
    ' 你是一名专业经验丰富的工程地质领域专家,你的任务是从给定的输入文本中提取"实体-关系-实体"三元组。'
    '关系类型包括24种:"出露于"、"位于"、"整合接触"、"不整合接触"、"假整合接触"、"断层接触"、'
    '"分布形态"、"大地构造位置"、"地层区划"、"出露地层"、"岩性"、"厚度"、"面积"、"坐标"、"长度"、'
    '"含有"、"所属年代"、"行政区划"、"发育"、"古生物"、"海拔"、"属于"、"吞噬"、"侵入"。'
    '提取过程请按照以下规范: 1. 输出格式: 严格遵循JSON数组,无额外文本,每个元素包含: '
    '[ { "entity1": "实体1", "relation": "关系", "entity2": "实体2" } ] 2. \n'
    '复杂关系处理: - 若同一实体参与多个关系,需分别列出不同三元组 '
)


def sample_indices(start: int, total: int, limit: int = SAMPLE_LIMIT) -> range:
    """Return the sample indices to process.

    Args:
        start: Index of the first sample to process.
        total: Number of samples available in the dataset.
        limit: Exclusive upper bound on the index (defaults to ``SAMPLE_LIMIT``).

    Returns:
        ``range(start, min(limit, total))``, so an index is never produced for
        a sample that does not exist. The range is empty when ``start`` is
        already at or past the bound.
    """
    return range(start, min(limit, total))


def build_query(examples: str, text: str) -> str:
    """Assemble the full request: instructions, examples, then the target text.

    Args:
        examples: Example block produced by the prompt generator.
        text: Geological description to extract triples from.

    Returns:
        The five parts joined with single newlines.
    """
    return '\n'.join([
        PROMPT,
        '以下是地质描述文本和三元组提取样例',
        examples,
        '请根据样例提取三元组',
        text,
    ])


def main() -> None:
    """Run few-shot extraction over the leading samples of the training file."""
    # Project-local modules are imported lazily so that importing this file
    # (rather than running it) has no side effects and no extra requirements.
    from response_to_json import parse_llm_response, save_to_json, save_raw_response
    from LLM import zero_shot
    from prompt_generate import generate_prompt_with_examples as generate_prompt  # noqa: F401
    from prompt_generate import generate_prompt_with_best_matches as generate_prompt_b

    # Load the geological descriptions and their labelled triples.
    with open(DATA_PATH, 'r', encoding='utf-8') as f:
        data = json.load(f)

    df = pd.DataFrame(data)
    text = df['text']
    label = df['triple_list']

    # Progress counters: entries already present in the parsed / raw outputs.
    parsed_done = 0
    raw_done = 0
    # The original script derived them from earlier output (kept for reference):
    # json_path = PARSED_OUTPUT_DIR + MODEL_NAME + '.json'
    # parsed_done = len(json.load(open(json_path, 'r', encoding='utf-8')))
    # raw_done = len(json.load(open(RAW_OUTPUT_DIR + MODEL_NAME + '.json', 'r', encoding='utf-8')))

    # Only continue when both outputs are in step with each other.
    if raw_done != parsed_done:
        print('q!=j')
        return

    print(parsed_done)
    for i in sample_indices(parsed_done, len(text)):
        sample = text[i]

        # Alternative kept from the original: generate_prompt(text, label, 3),
        # described there as picking complete text/triple_list examples at
        # random from entries 500-1000.
        examples = generate_prompt_b(text, label, sample, NUM_EXAMPLES)

        # Zero-shot variant used previously: zero_shot(..., PROMPT + sample)
        response = zero_shot(MODEL_SERIES, MODEL_NAME, build_query(examples, sample))

        # Parse the reply into triples and persist them.
        formatted_triples = parse_llm_response(response)
        save_to_json(sample, formatted_triples,
                     model_series=MODEL_NAME, output_dir=PARSED_OUTPUT_DIR)

        # Also keep the untouched response as a JSON file.
        save_raw_response(response, sample,
                          model_series=MODEL_NAME, output_dir=RAW_OUTPUT_DIR)


if __name__ == '__main__':
    main()