import os
import re
from typing import Any, Dict

import torch
from dotenv import load_dotenv
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from langgraph.graph import END, START, Graph, StateGraph
from typing_extensions import TypedDict

load_dotenv()
base_url = os.getenv("OLLAMA_BASE_URL")


class AgentState(TypedDict):
    """State for the final answer validation graph."""

    question: str
    answer: str
    final_answer: str | None
    agent_memory: Any
    valid_answer: bool
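
# Example of a populated state (hypothetical values):
# {
#     "question": "What is 2 + 2?",
#     "answer": "The sum is 4. FINAL ANSWER: 4",
#     "final_answer": None,   # filled in by extract_answer
#     "agent_memory": "...",  # the agent's reasoning trace
#     "valid_answer": False,  # set by the check nodes
# }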


def extract_answer(state: AgentState) -> Dict:
    """Extract and format the final answer from the state.

    Args:
        state: The state of the agent.

    Returns:
        A dictionary with the formatted final answer.
    """
    sep_token = "FINAL ANSWER:"
    raw_answer = state["answer"]

    # Keep only the text after the separator, if present
    if sep_token in raw_answer:
        formatted_answer = raw_answer.split(sep_token)[1].strip()
    else:
        formatted_answer = raw_answer.strip()

    # Remove any brackets from lists
    formatted_answer = formatted_answer.replace("[", "").replace("]", "")

    # Strip unit symbols unless the answer spells the unit out in words
    if not any(unit in formatted_answer.lower() for unit in ["dollars", "percent"]):
        formatted_answer = formatted_answer.replace("$", "").replace("%", "")

    # Remove thousands separators from a pure number (e.g. "1,234,567"),
    # but normalize the spacing of comma-separated lists instead
    if re.fullmatch(r"\d{1,3}(,\d{3})+(\.\d+)?", formatted_answer):
        formatted_answer = formatted_answer.replace(",", "")
    else:
        parts = [part.strip() for part in formatted_answer.split(",")]
        formatted_answer = ", ".join(parts)

    return {"final_answer": formatted_answer}


def reasoning_check(state: AgentState) -> Dict:
    """Node that checks the reasoning of the final answer.

    Args:
        state: The state of the agent.

    Returns:
        A dictionary with the reasoning check result.
    """
    model = ChatOllama(
        model="hf.co/lmstudio-community/Qwen2.5-14B-Instruct-GGUF:Q6_K",
        base_url=base_url,
        temperature=0.2,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a strict validator of answers. Your job is to check whether the reasoning and results are correct.
You should have >90% confidence that the answer is correct to pass it.
First list reasons why yes/no, then write your final decision: PASS in capital letters if it is satisfactory, FAIL if it is not.""",
            ),
            (
                "human",
                """
Here is a user-given task and the agent steps: {agent_memory}

Now here is the answer that was given: {final_answer}

Please check that the reasoning process and results are correct: do they correctly answer the given task?
""",
            ),
        ]
    )
    chain = prompt | model | StrOutputParser()
    output = chain.invoke(
        {
            "agent_memory": state["agent_memory"],
            "final_answer": state["final_answer"],
        }
    )
    print("Reasoning Feedback: ", output)

    # Free GPU memory before returning, regardless of the verdict
    torch.cuda.empty_cache()

    if "FAIL" in output:
        return {"valid_answer": False}
    return {"valid_answer": True}


def formatting_check(state: AgentState) -> Dict:
    """Node that checks the formatting of the final answer.

    Args:
        state: The state of the agent.

    Returns:
        A dictionary with the formatting check result.
    """
    model = ChatOllama(
        model="hf.co/lmstudio-community/Qwen2.5-14B-Instruct-GGUF:Q6_K",
        base_url=base_url,
        temperature=0.2,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a strict validator of answer formatting. The expected format is: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma-separated list of numbers and/or strings. If the task asks for a number, it should not contain commas or units such as $ or a percent sign unless specified otherwise. If the task asks for a string, it should not contain articles or abbreviations (e.g. for cities), and digits should be written in plain text unless specified otherwise. If the task asks for a comma-separated list, apply the above rules to each element depending on whether it is a number or a string.""",
            ),
            (
                "human",
                """
Here is a user-given task and the agent steps: {agent_memory}

Now here is the FINAL ANSWER that was given: {final_answer}

Check that the FINAL ANSWER is in the right format as asked for by the task. Write PASS in capital letters if it is, FAIL if it is not.
""",
            ),
        ]
    )
    chain = prompt | model | StrOutputParser()
    output = chain.invoke(
        {
            "agent_memory": state["agent_memory"],
            "final_answer": state["final_answer"],
        }
    )
    print("Formatting Feedback: ", output)

    # Free GPU memory before returning, regardless of the verdict
    torch.cuda.empty_cache()

    if "FAIL" in output:
        return {"valid_answer": False}
    return {"valid_answer": True}


def create_final_answer_graph() -> Graph:
    """Create a graph that validates the final answer.

    Flow: START -> extract_answer -> reasoning_check -> formatting_check -> END,
    short-circuiting to END when the reasoning check fails.

    Returns:
        A compiled graph that validates the final answer.
    """
    # Create the graph
    workflow = StateGraph(AgentState)

    # Add nodes
    workflow.add_node("extract_answer", extract_answer)
    workflow.add_node("reasoning_check", reasoning_check)
    workflow.add_node("formatting_check", formatting_check)

    # Add edges
    workflow.add_edge(START, "extract_answer")
    workflow.add_edge("extract_answer", "reasoning_check")
    # Skip the formatting check when reasoning already failed, so a later
    # PASS cannot overwrite the failed verdict
    workflow.add_conditional_edges(
        "reasoning_check",
        lambda state: "formatting_check" if state["valid_answer"] else END,
    )
    workflow.add_edge("formatting_check", END)

    # Compile the graph
    return workflow.compile()  # type: ignore


def validate_answer(graph: Graph, answer: str, agent_memory: Any) -> Dict:
    """Validate the answer using the LangGraph workflow.

    Args:
        graph: The compiled validation graph.
        answer: The answer to validate.
        agent_memory: The agent's memory.

    Returns:
        A dictionary with validation results.
    """
    try:
        # Initialize state ("question" is part of the schema but unused here)
        initial_state = {
            "question": "",
            "answer": answer,
            "final_answer": None,
            "agent_memory": agent_memory,
            "valid_answer": False,
        }

        # Run the graph
        result = graph.invoke(initial_state)  # type: ignore
        return {
            "valid_answer": result.get("valid_answer", False),
            "final_answer": result.get("final_answer", None),
        }
    except Exception as e:
        print(f"Validation failed: {e}")
        return {"valid_answer": False, "final_answer": None}