import os
import re
from typing import Any, Dict

import torch
from dotenv import load_dotenv
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from langgraph.graph import END, START, Graph, StateGraph
from typing_extensions import TypedDict

load_dotenv()
base_url = os.getenv("OLLAMA_BASE_URL")


class AgentState(TypedDict):
    """State for the final answer validation graph."""

    question: str
    answer: str
    final_answer: str | None
    agent_memory: Any
    valid_answer: bool


def extract_answer(state: AgentState) -> Dict:
    """Extract and format the final answer from the state.

    Args:
        state: The state of the agent.

    Returns:
        A dictionary with the formatted final answer.
    """
    # Extract the text after the separator if it exists.
    sep_token = "FINAL ANSWER:"
    raw_answer = state["answer"]
    if sep_token in raw_answer:
        formatted_answer = raw_answer.split(sep_token)[1].strip()
    else:
        formatted_answer = raw_answer.strip()

    # Remove any brackets from lists.
    formatted_answer = formatted_answer.replace("[", "").replace("]", "")

    # Remove units unless the question explicitly asks for them.
    question = state.get("question", "").lower()
    if not any(unit in question for unit in ["$", "%", "dollars", "percent"]):
        formatted_answer = formatted_answer.replace("$", "").replace("%", "")

    # Remove thousands separators from numbers ("1,234" -> "1234") while
    # keeping the commas that separate list elements, then normalize the
    # spacing around list separators.
    formatted_answer = re.sub(r"(?<=\d),(?=\d{3})", "", formatted_answer)
    parts = [part.strip() for part in formatted_answer.split(",")]
    formatted_answer = ", ".join(parts)

    return {"final_answer": formatted_answer}


def reasoning_check(state: AgentState) -> Dict:
    """Check the reasoning behind the final answer.

    Args:
        state: The state of the agent.

    Returns:
        A dictionary with the reasoning check result.
    """
    model = ChatOllama(
        model="hf.co/lmstudio-community/Qwen2.5-14B-Instruct-GGUF:Q6_K",
        base_url=base_url,
        temperature=0.2,
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a strict validator of answers. Your job is to check
whether the reasoning and results are correct. You should have >90% confidence
that the answer is correct to pass it. First list reasons why yes/no, then
write your final decision: PASS in all caps if it is satisfactory, FAIL if it
is not.""",
            ),
            (
                "human",
                """Here is a user-given task and the agent steps: {agent_memory}

Now here is the answer that was given: {final_answer}

Please check that the reasoning process and results are correct: do they
correctly answer the given task?""",
            ),
        ]
    )

    chain = prompt | model | StrOutputParser()
    output = chain.invoke(
        {
            "agent_memory": state["agent_memory"],
            "final_answer": state["final_answer"],
        }
    )
    print("Reasoning Feedback: ", output)

    # Free GPU memory before returning, regardless of the verdict.
    torch.cuda.empty_cache()

    if "FAIL" in output:
        return {"valid_answer": False}
    return {"valid_answer": True}


def formatting_check(state: AgentState) -> Dict:
    """Check the formatting of the final answer.

    Args:
        state: The state of the agent.

    Returns:
        A dictionary with the formatting check result.
    """
    # Skip the formatting check if the reasoning check already failed, so a
    # passing format cannot overwrite a failed reasoning verdict.
    if not state["valid_answer"]:
        return {"valid_answer": False}

    model = ChatOllama(
        model="hf.co/lmstudio-community/Qwen2.5-14B-Instruct-GGUF:Q6_K",
        base_url=base_url,
        temperature=0.2,
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a strict validator of answer formatting. A FINAL
ANSWER should be a number OR as few words as possible OR a comma separated
list of numbers and/or strings. A number should not contain commas or units
such as $ or a percent sign unless the task asks for them. A string should not
contain articles or abbreviations (e.g. for cities), and digits should be
written in plain text unless specified otherwise. For a comma separated list,
apply the above rules to each element depending on whether it is a number or a
string. First list reasons why yes/no, then write your final decision: PASS in
all caps if the formatting is satisfactory, FAIL if it is not.""",
            ),
            (
                "human",
                """Here is a user-given task and the agent steps: {agent_memory}

Now here is the FINAL ANSWER that was given: {final_answer}

Ensure the FINAL ANSWER is in the format asked for by the task.""",
            ),
        ]
    )

    chain = prompt | model | StrOutputParser()
    output = chain.invoke(
        {
            "agent_memory": state["agent_memory"],
            "final_answer": state["final_answer"],
        }
    )
    print("Formatting Feedback: ", output)

    # Free GPU memory before returning, regardless of the verdict.
    torch.cuda.empty_cache()

    if "FAIL" in output:
        return {"valid_answer": False}
    return {"valid_answer": True}


def create_final_answer_graph() -> Graph:
    """Create a graph that validates the final answer.

    Returns:
        A compiled graph that validates the final answer.
    """
    # Create the graph.
    workflow = StateGraph(AgentState)

    # Add nodes.
    workflow.add_node("extract_answer", extract_answer)
    workflow.add_node("reasoning_check", reasoning_check)
    workflow.add_node("formatting_check", formatting_check)

    # Add edges: extract the answer, check reasoning, then check formatting.
    workflow.add_edge(START, "extract_answer")
    workflow.add_edge("extract_answer", "reasoning_check")
    workflow.add_edge("reasoning_check", "formatting_check")
    workflow.add_edge("formatting_check", END)

    # Compile the graph.
    return workflow.compile()  # type: ignore


def validate_answer(
    graph: Graph, answer: str, agent_memory: Any, question: str = ""
) -> Dict:
    """Validate the answer using the LangGraph workflow.

    Args:
        graph: The compiled validation graph.
        answer: The answer to validate.
        agent_memory: The agent's memory.
        question: The original task, used to decide whether units are kept.

    Returns:
        A dictionary with validation results.
    """
    try:
        # Initialize state.
        initial_state = {
            "question": question,
            "answer": answer,
            "final_answer": None,
            "agent_memory": agent_memory,
            "valid_answer": False,
        }

        # Run the graph.
        result = graph.invoke(initial_state)  # type: ignore
        return {
            "valid_answer": result.get("valid_answer", False),
            "final_answer": result.get("final_answer", None),
        }
    except Exception as e:
        print(f"Validation failed: {e}")
        return {"valid_answer": False, "final_answer": None}
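

# Example usage (a minimal sketch): build the validation graph once and run it
# on a raw agent answer. The sample task, answer, and agent transcript below
# are hypothetical placeholders, and running this requires a reachable Ollama
# server at OLLAMA_BASE_URL serving the model named above.
if __name__ == "__main__":
    validation_graph = create_final_answer_graph()
    result = validate_answer(
        validation_graph,
        answer="The census tables show growth of 4.2%. FINAL ANSWER: 4.2%",
        agent_memory="Task: By what percent did the population grow? ...",
        question="By what percent did the population grow?",
    )
    # Expected shape: {"valid_answer": <bool>, "final_answer": <str | None>},
    # e.g. {"valid_answer": True, "final_answer": "4.2%"} if both checks pass.
    print(result)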