|
|
from propositional_logic.random_gen.evaluation_access import * |
|
|
from vllm import LLM, SamplingParams |
|
|
import time |
|
|
import json |
|
|
import pickle |
|
|
import os
|
|
import re |
|
|
import sys |
|
|
from tqdm import tqdm |
|
|
import argparse |
|
|
from loguru import logger
|
|
import concurrent.futures |
|
|
import random |
|
|
import subprocess |
|
|
|
|
|
import uuid


class DFS: |
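    """Depth-first proof search over a batch of theorems.

    Each round, the current proof state of every open theorem is snapshotted
    to a temp pickle, the vLLM model samples candidate tactics for states
    that have none yet, and worker processes verify tactics against the
    evaluator, backtracking when a tactic fails. Workers write their results
    to per-key pickle files because ProcessPoolExecutor runs each task on a
    *copy* of this object, so in-process mutations would be lost.
    """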
|
|
def __init__( |
|
|
self, |
|
|
num_sampled_tactics: int, |
|
|
temperature, |
|
|
test_theorem_list, |
|
|
max_workers, |
|
|
saved_file_path, |
|
|
experiment_id |
|
|
) -> None: |
|
|
self.max_workers = max_workers |
|
|
self.test_theorem_list = test_theorem_list |
|
|
self.whether_backtrack = {} |
|
|
self.num_sampled_tactics = num_sampled_tactics |
|
|
self.count_lean_dict = {} |
|
|
self.counter_success = 0 |
|
|
self.counter_failed = 0 |
|
|
self.counter_in_process = 0 |
|
|
self.counter_too_long = 0 |
|
|
self.tactic_list_tree = {} |
|
|
self.theorem_object_dict = {} |
|
|
self.prompts_tactic_state_list = {} |
|
|
self.root = {} |
|
|
self.round_count = {} |
|
|
self.parent_node_of_node = {} |
|
|
self.round = 0 |
|
|
self.saved_file_path = saved_file_path |
|
|
self.cstate_round = {} |
|
|
self.experiment_id = experiment_id |
|
|
        self.key_to_be_inferred = []
|
|
self.key_not_finished = [] |
|
|
self.prompts_entered = [] |
|
|
self.counter_failed_with_error = 0 |
|
|
self.temperature = temperature |
|
|
|
|
|
print(f"{len(self.test_theorem_list)} many theorem loaded") |
|
|
for key in tqdm(self.test_theorem_list): |
|
|
sample_eval = SingleTheoremEval(5, int(key)) |
|
|
self.theorem_object_dict[key] = sample_eval |
|
|
init_state = self.theorem_object_dict[key].get_initial_prompt() |
|
|
|
|
|
self.parent_node_of_node[key] = {} |
|
|
self.prompts_tactic_state_list[key] = [init_state] |
|
|
self.root[key] = 'open' |
|
|
self.tactic_list_tree[key] = {} |
|
|
self.tactic_list_tree[key]["state_0:"] = None |
|
|
self.parent_node_of_node[key]["state_0:"] = None |
|
|
self.count_lean_dict[key] = {} |
|
|
self.count_lean_dict[key]['count_lean_multiple_backtrack'] = 0 |
|
|
self.count_lean_dict[key]['count_lean_single_backtrack'] = 0 |
|
|
self.count_lean_dict[key]['count_lean_tactic_success'] = 0 |
|
|
self.round_count[key] = 0 |
|
|
self.whether_backtrack[key] = False |
|
|
|
|
|
print('initialization done') |
|
|
    def get_current_state_number(self, key):
        # Return only the first line of the labeled state, e.g. 'state_3:'.
        string = self.theorem_object_dict[key].get_current_state_with_label()
        return string.split('\n', 1)[0]
|
|
    def get_prev_state_number(self, key):
        # Return only the first line of the previous labeled state.
        string = self.theorem_object_dict[key].get_prev_state_with_label()
        return string.split('\n', 1)[0]
|
|
    def revise_entered_tactic(self, entered_tactic, key):
        # Relabel the tactic so it is anchored to the *current* state.
        assert len(entered_tactic) == 2, 'entered_tactic must be a [label, tactic] pair'
        current_state_label = self.get_current_state_number(key)
        entered_tactic[0] = current_state_label[:-1] + '_tactic_0:'
        return entered_tactic
|
|
def back_track_tactic(self, key): |
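        # Build the backtrack instruction expected by the evaluator, naming
        # the parent state to return to and the child state it led to.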
|
|
        current_state_number = self.get_current_state_number(key)
        match = re.search(r'\d+', current_state_number)
        assert match, 'no number in current state'
        extracted_current_integer = int(match.group())
|
|
|
|
|
previous_state = self.parent_node_of_node[key][current_state_number] |
|
|
previous_state_to_be_checked = self.get_prev_state_number(key) |
|
|
if previous_state != previous_state_to_be_checked: |
|
|
assert False, f'key is {key}, during backtrack, previous state marked and previous state by system are not the same' |
|
|
        match = re.search(r'\d+', previous_state)
        assert match, 'no number in previous state'
        extracted_previous_integer = int(match.group())
|
|
return f"no solution, return to state {extracted_previous_integer} [that leads to state {extracted_current_integer}]" |
|
|
def revise_output_list(self, output_text): |
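        # Parse a raw model completion: keep the first '_tactic_' line plus
        # the line after it (label + tactic). Returns the 'no_tactic'
        # sentinel pair when the completion contains no usable tactic.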
|
|
output_line_list = output_text.split("\n") |
|
|
is_tactic = False |
|
|
for idx_tactic, line in enumerate(output_line_list): |
|
|
if '_tactic_' in line: |
|
|
is_tactic = True |
|
|
break |
|
|
|
|
|
        if not is_tactic:
|
|
|
|
|
return ['no_tactic','no_tactic'] |
|
|
|
|
|
if "::: " in output_line_list[idx_tactic]: |
|
|
output_line_list[idx_tactic] = output_line_list[idx_tactic][4:] |
|
|
entered_tactic_list = output_line_list[idx_tactic:idx_tactic+2] |
|
|
|
|
|
if len(entered_tactic_list) == 1: |
|
|
return ['no_tactic','no_tactic'] |
|
|
return entered_tactic_list |
|
|
def check_if_failure_per_key(self, key): |
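        # The search has failed when it is back at the root state with no
        # untried tactics left there.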
|
|
if len(self.tactic_list_tree[key]['state_0:']) == 0 and self.get_current_state_number(key) == 'state_0:': |
|
|
print('triggered failure') |
|
|
return True |
|
|
else: |
|
|
return False |
|
|
def check_path_length(self,key): |
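        # Sanity check: the chain of parent pointers from the current node
        # back to state_0 must have the same length as the recorded prompt
        # history.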
|
|
current_state_label = self.get_current_state_number(key) |
|
|
previous_state_label = current_state_label |
|
|
theorem_object_length = 1 |
|
|
|
|
|
while True: |
|
|
|
|
|
if previous_state_label == 'state_0:': |
|
|
break |
|
|
previous_state_label = self.parent_node_of_node[key][current_state_label] |
|
|
theorem_object_length += 1 |
|
|
current_state_label = previous_state_label |
|
|
|
|
|
|
|
|
        assert theorem_object_length == len(self.prompts_tactic_state_list[key]), \
            'path length does not match prompt history length'
|
|
def check_if_program_finished(self): |
|
|
stop_signal = True |
|
|
        for key in self.test_theorem_list:
            if self.root[key] == 'open':
                stop_signal = False
        return stop_signal
|
|
def revise_prompt(self, prompts_tactic_state_list_per_key): |
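        # Renumber state labels in order of first appearance so the prompt
        # always shows a linear path state_0, state_1, ... with each tactic
        # relabeled as tactic_0, hiding any backtracking from the model.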
|
|
pattern = r'state_\d+:' |
|
|
matches = re.findall(pattern, prompts_tactic_state_list_per_key) |
|
|
state_order = {} |
|
|
order = 0 |
|
|
for match in matches: |
|
|
if match not in state_order: |
|
|
state_order[match] = order |
|
|
order += 1 |
|
|
|
|
|
        for state, order_idx in state_order.items():
            prompts_tactic_state_list_per_key = prompts_tactic_state_list_per_key.replace(state, f'state_{order_idx}:')
|
|
|
|
|
last_state_id = None |
|
|
output_prompt = [] |
|
|
for line in prompts_tactic_state_list_per_key.split('\n'): |
|
|
            if re.search(r'state_\d+:', line):
                last_state_id = line[6:-1]
            elif re.search(r'state_\d+_tactic_', line):
                line = f'state_{last_state_id}_tactic_0:'
|
|
output_prompt.append(line) |
|
|
|
|
|
|
|
|
        return "\n".join(output_prompt)
|
|
def status_report(self): |
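        # Recount per-key statuses from scratch and print a summary; the five
        # status buckets must cover every theorem exactly once.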
|
|
counter_in_process = 0 |
|
|
counter_success = 0 |
|
|
counter_failed = 0 |
|
|
counter_too_long = 0 |
|
|
counter_failed_with_error = 0 |
|
|
for key in self.test_theorem_list: |
|
|
if self.root[key] == 'open': |
|
|
counter_in_process += 1 |
|
|
if self.root[key] == 'success': |
|
|
counter_success += 1 |
|
|
if self.root[key] == 'failed': |
|
|
counter_failed += 1 |
|
|
if self.root[key] == 'failed, too long': |
|
|
counter_too_long += 1 |
|
|
if self.root[key] == 'failed with error': |
|
|
counter_failed_with_error += 1 |
|
|
self.counter_success = counter_success |
|
|
self.counter_failed = counter_failed |
|
|
self.counter_failed_with_error = counter_failed_with_error |
|
|
self.counter_too_long = counter_too_long |
|
|
self.counter_in_process = counter_in_process |
|
|
        if counter_success + counter_failed + counter_too_long + counter_in_process + counter_failed_with_error != len(self.test_theorem_list):
            assert False, 'status counts do not sum to the total number of theorems'
|
|
print(f'saved_file_path is {self.saved_file_path}') |
|
|
print(f'total number of theorem is {len(self.test_theorem_list)}') |
|
|
print(f'proof success number is {self.counter_success}') |
|
|
print(f'proof failed number is {self.counter_failed}') |
|
|
print(f'proof failed with error number is {self.counter_failed_with_error}') |
|
|
print(f'proof too long number is {self.counter_too_long}') |
|
|
print(f'proof in process number is {self.counter_in_process}') |
|
|
|
|
|
count_lean_single_backtrack = 0 |
|
|
count_lean_multiple_backtrack = 0 |
|
|
count_lean_tactic_success = 0 |
|
|
        for key in self.test_theorem_list:
|
|
count_lean_single_backtrack += self.count_lean_dict[key]['count_lean_single_backtrack'] |
|
|
count_lean_multiple_backtrack += self.count_lean_dict[key]['count_lean_multiple_backtrack'] |
|
|
count_lean_tactic_success += self.count_lean_dict[key]['count_lean_tactic_success'] |
|
|
|
|
|
print(f'total lean count tactic success is {count_lean_tactic_success}') |
|
|
print(f'total lean count single backtrack is {count_lean_single_backtrack}') |
|
|
print(f'total lean count multiple backtrack is {count_lean_multiple_backtrack}') |
|
|
    def collect_inference_result(self, key_to_be_inferred, outputs):
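        # For each prompt, parse the n sampled completions into
        # [label, tactic] pairs, drop malformed ones, dedupe, and attach the
        # list to the current state node of the matching theorem.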
|
|
for idx, output_list in tqdm(enumerate(outputs), total=len(outputs), desc=f"Processing LLM output for Round {self.round}"): |
|
|
            assigned_output_list_per_key = []
            for i in range(self.num_sampled_tactics):
                output_tactic = self.revise_output_list(output_list.outputs[i].text)
                if 'no_tactic' not in output_tactic:
                    assigned_output_list_per_key.append(output_tactic)
|
|
|
|
|
|
|
|
seen = set() |
|
|
unique_assigned_output_list_per_key = [] |
|
|
            for inner_list in assigned_output_list_per_key:
|
|
inner_tuple = tuple(inner_list) |
|
|
if inner_tuple not in seen: |
|
|
seen.add(inner_tuple) |
|
|
unique_assigned_output_list_per_key.append(inner_list) |
|
|
|
|
|
|
|
|
|
|
|
            self.tactic_list_tree[key_to_be_inferred[idx]][self.cstate_round[key_to_be_inferred[idx]]] = unique_assigned_output_list_per_key
|
|
|
|
|
def current_state_obtained_list(self, key): |
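        # Runs inside a worker process: persist the current state label to a
        # temp pickle keyed by (key, round, experiment_id), since mutations
        # made in a child process do not propagate back to the parent.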
|
|
if self.root[key] == 'open': |
|
|
self.cstate_round[key] = self.get_current_state_number(key) |
|
|
cstate = self.cstate_round[key] |
|
|
            # open() does not expand '~'; expanduser is required here.
            path = os.path.expanduser(f'~/leandojo_project/atp_research/DFS/temp/current_state_{key}_{self.round}_{self.experiment_id}.pkl')
            with open(path, 'wb') as f:
                pickle.dump(cstate, f)
|
|
def search(self): |
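        # Main loop, using the module-level `llm` handle. Per round: snapshot
        # current states in parallel, build and length-filter prompts,
        # batch-generate tactics with vLLM, verify them in parallel workers,
        # then merge the workers' pickled reports. Stops when every theorem
        # is closed or after 65 rounds.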
|
|
tokenizer = llm.get_tokenizer() |
|
|
while True: |
|
|
self.round += 1 |
|
|
print(f'Round {self.round}------') |
|
|
if self.check_if_program_finished() or self.round > 65: |
|
|
print('confirmed test theorem finished. exit.') |
|
|
self.status_report() |
|
|
break |
|
|
|
|
|
            self.key_to_be_inferred = []
|
|
self.key_not_finished = [] |
|
|
self.prompts_entered = [] |
|
|
|
|
|
with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor: |
|
|
list(tqdm(executor.map(self.current_state_obtained_list, self.test_theorem_list), total=len(self.test_theorem_list), desc=f"get current state for Round {self.round}")) |
|
|
|
|
|
key_prompt_length_list = [] |
|
|
for key in tqdm(self.test_theorem_list, total=len(self.test_theorem_list), desc=f"check state for theorems for round {self.round}" ): |
|
|
if self.root[key] == 'open': |
|
|
self.round_count[key] = self.round |
|
|
self.key_not_finished.append(key) |
|
|
                    path = os.path.expanduser(f'~/leandojo_project/atp_research/DFS/temp/current_state_{key}_{self.round}_{self.experiment_id}.pkl')
                    with open(path, 'rb') as f:
                        self.cstate_round[key] = pickle.load(f)
|
|
                    if self.tactic_list_tree[key][self.cstate_round[key]] is None:
|
|
prompt_per_key = self.revise_prompt('\n'.join(self.prompts_tactic_state_list[key])) |
|
|
key_prompt_length_list.append(len(prompt_per_key.split())) |
|
|
tokenized_prompt_per_key = tokenizer.encode(prompt_per_key) |
|
|
if len(prompt_per_key.split()) < 1500 and len(tokenized_prompt_per_key) < 4000: |
|
|
                            self.key_to_be_inferred.append(key)
|
|
self.prompts_entered.append(prompt_per_key) |
|
|
else: |
|
|
self.root[key] = 'failed, too long' |
|
|
self.key_not_finished.remove(key) |
|
|
            print(f'prompt word counts for keys needing inference this round: {key_prompt_length_list}')
|
|
|
|
|
            print(f'keys to be inferred: {self.key_to_be_inferred}')

            sampling_params = SamplingParams(n=self.num_sampled_tactics, temperature=self.temperature,
                                             top_p=1, max_tokens=200)
|
|
outputs = llm.generate(self.prompts_entered, sampling_params) |
|
|
|
|
|
|
|
|
print('now we collect the inference') |
|
|
            self.collect_inference_result(self.key_to_be_inferred, outputs)
|
|
print('inference collected') |
|
|
|
|
|
print(f'enter concurrent process with max_workers as {self.max_workers}') |
|
|
|
|
|
with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor: |
|
|
list(tqdm(executor.map(self.search_per_key_per_step, self.key_not_finished), total=len(self.key_not_finished), desc=f"Lean_verifying for Round {self.round}")) |
|
|
|
|
|
            # Serial fallback for debugging:
            # for key in self.key_not_finished:
            #     self.search_per_key_per_step(key)
|
|
|
|
|
|
|
|
for key in tqdm(self.key_not_finished, total=len(self.key_not_finished), desc=f"Retrieve results from concurrent processing for Round {self.round}"): |
|
|
                path = os.path.expanduser(f'~/leandojo_project/atp_research/DFS/temp/temp_file_{key}_{self.round}_{self.experiment_id}.pkl')
                with open(path, 'rb') as f:
                    report_key_file = pickle.load(f)
|
|
self.count_lean_dict[key]['count_lean_single_backtrack'] += report_key_file['count_lean_single_backtrack'] |
|
|
self.count_lean_dict[key]['count_lean_multiple_backtrack'] += report_key_file['count_lean_multiple_backtrack'] |
|
|
self.count_lean_dict[key]['count_lean_tactic_success'] += report_key_file['count_lean_tactic_success'] |
|
|
self.root[key] = report_key_file['key_status'] |
|
|
self.tactic_list_tree[key] = report_key_file['tactic_list_tree'] |
|
|
self.prompts_tactic_state_list[key] = report_key_file['prompts_tactic_state_list'] |
|
|
self.theorem_object_dict[key] = report_key_file['theorem_object_dict'] |
|
|
self.parent_node_of_node[key] = report_key_file['node_relation'] |
|
|
self.whether_backtrack[key] = report_key_file['whether_backtrack'] |
|
|
|
|
|
self.status_report() |
|
|
self.save_outcome() |
|
|
def search_per_key_per_step(self, key): |
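        # One DFS step for one theorem, run inside a worker process: try the
        # top untried tactic at the current state; if it fails, pop it and
        # backtrack until a node with untried tactics (or the exhausted root)
        # is reached. The `finally` block always writes a report pickle for
        # the parent process to merge.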
|
|
|
|
|
|
|
|
try: |
|
|
key_status = 'open' |
|
|
count_lean_tactic = 0 |
|
|
count_lean_single_backtrack = 0 |
|
|
count_lean_multiple_backtrack = 0 |
|
|
whether_backtrack = False |
|
|
|
|
|
tactic_list_at_top_per_key = self.tactic_list_tree[key][self.get_current_state_number(key)] |
|
|
|
|
|
            if tactic_list_at_top_per_key is None:
                assert False, f"tactic_list_at_top_per_key is None, key is {key}\ncurrent tactic list is {self.tactic_list_tree[key]}" \
                              f"\nwhether key in key_to_be_inferred: {key in self.key_to_be_inferred}"
|
|
|
|
|
if len(tactic_list_at_top_per_key) != 0: |
|
|
try: |
|
|
|
|
|
count_lean_tactic += 1 |
|
|
entered_tactic = self.revise_entered_tactic(tactic_list_at_top_per_key[0], key) |
|
|
                    label_before_tactic = self.get_current_state_number(key)
                    lean_output = self.theorem_object_dict[key].provide_tactic(entered_tactic[0], entered_tactic[1])
|
|
                    label_after_tactic = self.get_current_state_number(key)

                    if 'proof is complete' in lean_output[1]:
|
|
key_status = 'success' |
|
|
print(f'key is {key}, search is success!') |
|
|
                        print(f'key is {key}, the successful proof \n{self.theorem_object_dict[key].get_current_lean_proof()}')

                    self.parent_node_of_node[key][label_after_tactic] = label_before_tactic
|
|
self.prompts_tactic_state_list[key].append( |
|
|
f"{entered_tactic[0]}\n{entered_tactic[1]}\n{lean_output[1]}") |
|
|
self.tactic_list_tree[key][self.get_current_state_number(key)] = None |
|
|
del self.tactic_list_tree[key][label_before_tactic][0] |
|
|
|
|
|
self.check_path_length(key) |
|
|
                except Exception:
                    # The tactic failed in Lean: discard it and fall through
                    # to the failure check below.
                    del self.tactic_list_tree[key][self.get_current_state_number(key)][0]
|
|
|
|
|
self.check_path_length(key) |
|
|
if self.check_if_failure_per_key(key): |
|
|
key_status = 'failed' |
|
|
print(f"key is {key}, tactic error then search failed!") |
|
|
|
|
|
else: |
|
|
if self.check_if_failure_per_key(key): |
|
|
key_status = 'failed' |
|
|
print(f"key is {key}, backtrack to zero and no tactic to try, search failed!") |
|
|
|
|
|
|
|
|
print(f'key is {key}, backtrack phase activated') |
|
|
whether_backtrack = True |
|
|
count_lean_single_backtrack += 1 |
|
|
while True: |
|
|
tactic_list_at_intermediate_node = self.tactic_list_tree[key][self.get_current_state_number(key)] |
|
|
|
|
|
|
|
|
if len(tactic_list_at_intermediate_node) != 0: |
|
|
break |
|
|
if self.check_if_failure_per_key(key): |
|
|
key_status = 'failed' |
|
|
print(f"key is {key}, backtrack to zero and no tactic to try, search failed!") |
|
|
|
|
|
break |
|
|
|
|
|
self.check_path_length(key) |
|
|
|
|
|
count_lean_multiple_backtrack += 1 |
|
|
                    lean_output = self.theorem_object_dict[key].do_back_track(self.back_track_tactic(key))

                    # Drop the last prompt entry to mirror the state we just backtracked from.
                    del self.prompts_tactic_state_list[key][-1]
|
|
|
|
|
|
|
|
self.check_path_length(key) |
|
|
except Exception as e: |
|
|
print(f"key is {key}, exception happend") |
|
|
print(e) |
|
|
key_status = 'failed with error' |
|
|
count_lean_tactic = 0 |
|
|
count_lean_single_backtrack = 0 |
|
|
count_lean_multiple_backtrack = 0 |
|
|
|
|
|
finally: |
|
|
report_key_file = {} |
|
|
report_key_file['count_lean_multiple_backtrack'] = count_lean_multiple_backtrack + count_lean_tactic |
|
|
report_key_file['count_lean_single_backtrack'] = count_lean_single_backtrack + count_lean_tactic |
|
|
report_key_file['count_lean_tactic_success'] = count_lean_tactic |
|
|
report_key_file['key_status'] = key_status |
|
|
report_key_file['tactic_list_tree'] = self.tactic_list_tree[key] |
|
|
report_key_file['prompts_tactic_state_list'] = self.prompts_tactic_state_list[key] |
|
|
report_key_file['theorem_object_dict'] = self.theorem_object_dict[key] |
|
|
report_key_file['node_relation'] = self.parent_node_of_node[key] |
|
|
report_key_file['whether_backtrack'] = whether_backtrack |
|
|
            path = os.path.expanduser(f'~/leandojo_project/atp_research/DFS/temp/temp_file_{key}_{self.round}_{self.experiment_id}.pkl')
            with open(path, 'wb') as f:
                pickle.dump(report_key_file, f)
|
|
|
|
|
|
|
|
|
|
|
def save_outcome(self): |
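        # Aggregate per-key counters, statuses, and proofs into one `outcome`
        # dict and pickle it to saved_file_path.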
|
|
counter_in_process = 0 |
|
|
counter_success = 0 |
|
|
counter_failed = 0 |
|
|
counter_too_long = 0 |
|
|
count_lean_single_backtrack = 0 |
|
|
count_lean_multiple_backtrack = 0 |
|
|
count_lean_tactic_success = 0 |
|
|
counter_failed_with_error = 0 |
|
|
|
|
|
proof_dict = {} |
|
|
for key in tqdm(self.test_theorem_list,total=len(self.test_theorem_list),desc='saving results'): |
|
|
count_lean_single_backtrack += self.count_lean_dict[key]['count_lean_single_backtrack'] |
|
|
count_lean_multiple_backtrack += self.count_lean_dict[key]['count_lean_multiple_backtrack'] |
|
|
count_lean_tactic_success += self.count_lean_dict[key]['count_lean_tactic_success'] |
|
|
proof_dict[key] = self.theorem_object_dict[key].get_current_lean_proof() |
|
|
if self.root[key] == 'open': |
|
|
counter_in_process += 1 |
|
|
if self.root[key] == 'success': |
|
|
counter_success += 1 |
|
|
if self.root[key] == 'failed': |
|
|
counter_failed += 1 |
|
|
if self.root[key] == 'failed, too long': |
|
|
counter_too_long += 1 |
|
|
if self.root[key] == 'failed with error': |
|
|
counter_failed_with_error += 1 |
|
|
self.counter_success = counter_success |
|
|
self.counter_failed = counter_failed |
|
|
self.counter_failed_with_error = counter_failed_with_error |
|
|
self.counter_too_long = counter_too_long |
|
|
        if counter_success + counter_failed + counter_too_long + counter_failed_with_error + counter_in_process != len(self.test_theorem_list):
            assert False, 'status counts do not sum to the total number of theorems'
|
|
outcome = {} |
|
|
outcome['stats'] = {} |
|
|
outcome['stats']['total_lean_count_single_backtrack'] = count_lean_single_backtrack |
|
|
outcome['stats']['total_lean_count_multiple_backtrack'] = count_lean_multiple_backtrack |
|
|
outcome['stats']['count_lean_tactic_success'] = count_lean_tactic_success |
|
|
outcome['stats']['num_success'] = self.counter_success |
|
|
outcome['stats']['num_failed'] = self.counter_failed |
|
|
outcome['stats']['num_failed_with_error'] = self.counter_failed_with_error |
|
|
outcome['stats']['num_too_long'] = self.counter_too_long |
|
|
outcome['stats']['num_sampled_tactics'] = self.num_sampled_tactics |
|
|
outcome['stats']['temperature'] = self.temperature |
|
|
outcome['key_final_state'] = self.root |
|
|
outcome['key_lean_count'] = self.count_lean_dict |
|
|
outcome['key_proof'] = proof_dict |
|
|
outcome['tactic_list_tree'] = self.tactic_list_tree |
|
|
outcome['round_count'] = self.round_count |
|
|
|
|
|
pickle.dump(outcome, open(self.saved_file_path, 'wb')) |
|
|
|
|
|
if __name__ == '__main__': |
|
|
'''os.environ["TOKENIZERS_PARALLELISM"] = "true" |
|
|
if os.environ.get("TOKENIZERS_PARALLELISM") == "true": |
|
|
print("TOKENIZERS_PARALLELISM is set to true") |
|
|
else: |
|
|
print("TOKENIZERS_PARALLELISM is not set to true")''' |
|
|
|
|
|
    parser = argparse.ArgumentParser(description='DFS proof search over propositional-logic theorems with a vLLM tactic generator.')
    parser.add_argument('checkpoint_path', type=str, help='path to the model checkpoint')
    parser.add_argument('number_of_gpu', type=int, help='number of GPUs for tensor parallelism')
    parser.add_argument('test_data_path', type=str, help='path to the JSON list of test theorems')
    parser.add_argument('saved_file_path', type=str, help='path for the pickled outcome file')
    parser.add_argument('max_workers', type=int, help='number of worker processes for Lean verification')
    parser.add_argument('num_sampled_tactics', type=int, help='number of tactics sampled per state')
    parser.add_argument('temperature', type=float, help='sampling temperature')
    parser.add_argument('num_test_theorem', type=int, help='number of test theorems to evaluate')
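
    # Example invocation (paths and values are illustrative placeholders, not
    # from the original):
    #   python dfs_search.py /path/to/checkpoint 4 test_theorems.json \
    #       outcome.pkl 32 8 0.7 200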
|
|
|
|
|
args = parser.parse_args() |
|
|
checkpoint = args.checkpoint_path |
|
|
number_of_gpu = args.number_of_gpu |
|
|
test_data_path = args.test_data_path |
|
|
saved_file_path = args.saved_file_path |
|
|
max_workers = args.max_workers |
|
|
num_sampled_tactics = args.num_sampled_tactics |
|
|
temperature = args.temperature |
|
|
num_test_theorem = args.num_test_theorem |
|
|
    swap_space = 100  # vLLM CPU swap space per GPU, in GiB
|
|
|
|
|
print(f'checkpoint is {checkpoint}') |
|
|
print(f'number_of_gpu is {number_of_gpu}') |
|
|
print(f'test_data_path is {test_data_path}') |
|
|
print(f'saved_file_path is {saved_file_path}') |
|
|
print(f'num_test_theorem is {num_test_theorem}') |
|
|
|
|
|
|
|
|
print(f'max_workers is {max_workers}') |
|
|
print(f'number_sampled_tactic is {num_sampled_tactics}') |
|
|
print(f'temperature is {temperature}') |
|
|
|
|
|
print(f'swap_space is {swap_space}') |
|
|
|
|
|
|
|
|
random.seed(42) |
|
|
with open(test_data_path, 'r') as f: |
|
|
test_theorem_list = json.load(f) |
|
|
|
|
|
|
|
|
test_theorem_list = test_theorem_list[:int(num_test_theorem)] |
|
|
|
|
|
|
|
|
|
|
|
experiment_id = uuid.uuid4() |
|
|
    print(f'experiment_id is {experiment_id}')
|
|
|
|
|
|
|
|
|
|
|
llm = LLM(model=checkpoint, tensor_parallel_size=number_of_gpu, swap_space=swap_space) |
|
|
    evaluate_obj = DFS(num_sampled_tactics=int(num_sampled_tactics), temperature=float(temperature),
                       test_theorem_list=test_theorem_list, max_workers=int(max_workers),
                       saved_file_path=saved_file_path, experiment_id=experiment_id)
|
|
evaluate_obj.search() |
|
|
print('Now we start saving') |
|
|
evaluate_obj.save_outcome() |
|
|
print('Now we finish saving. exit') |
|
|
|
|
|
|
|
|
|