Final_Assignment_Template / src /final_answer.py
AK47-M4A4's picture
v1
ae1d0b9
raw
history blame
7.07 kB
import os
from typing import Any, Dict, Optional
import torch
from dotenv import load_dotenv
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from langgraph.graph import END, START, Graph, StateGraph
from typing_extensions import TypedDict
# Load variables from a .env file (python-dotenv) before the environment is read.
load_dotenv()
# Ollama endpoint handed to the ChatOllama instances created in this module;
# None when OLLAMA_BASE_URL is not set.
base_url = os.getenv("OLLAMA_BASE_URL")
class AgentState(TypedDict):
    """State for the final answer validation graph."""

    # Task text. Declared here, but no node in the code visible in this file
    # reads it.
    question: str
    # Raw agent output; may contain a "FINAL ANSWER:" marker.
    answer: str
    # Cleaned answer produced by extract_answer; validate_answer seeds it
    # with None. Optional[...] (already imported) is used instead of
    # ``str | None`` so the class body also evaluates on Python 3.9.
    final_answer: Optional[str]
    # Agent steps, interpolated verbatim into the validator prompts.
    agent_memory: Any
    # Verdict written by the check nodes.
    valid_answer: bool
def extract_answer(state: AgentState) -> Dict:
    """Extract and format the final answer from the state.

    The text after the last ``FINAL ANSWER:`` marker is used, or the whole
    answer when the marker is absent. Square brackets are dropped, thousands
    separators inside numbers are removed ("1,234" -> "1234") and
    comma-separated items are re-joined with ", ".

    Args:
        state: The state of the agent; only ``state["answer"]`` is read.

    Returns:
        A dictionary with the formatted answer under ``"final_answer"``.
    """
    import re  # local import so this block does not depend on file-level imports

    sep_token = "FINAL ANSWER:"
    raw_answer = state["answer"]

    # rpartition yields the text after the LAST marker and the whole string
    # when the marker is missing. split(...)[1] returned the text between the
    # first and second marker whenever the marker was repeated.
    formatted_answer = raw_answer.rpartition(sep_token)[2].strip()

    # Remove any brackets from lists
    formatted_answer = formatted_answer.replace("[", "").replace("]", "")

    # "$" and "%" are left untouched: the previous unit-stripping branch only
    # ran when neither symbol was present, so it never changed anything, and
    # this function cannot tell whether the task asked for units.

    # Collapse thousands separators: 1-3 digits followed by ",ddd" groups with
    # no whitespace. Sequences that mix group sizes (e.g. "1,000,20") are left
    # alone and treated as lists below.
    formatted_answer = re.sub(
        r"(?<![\d.,])\d{1,3}(?:,\d{3})+(?!\d|,\d)",
        lambda match: match.group(0).replace(",", ""),
        formatted_answer,
    )

    # Normalise list spacing: "a,b , c" -> "a, b, c"
    items = [item.strip() for item in formatted_answer.split(",")]
    return {"final_answer": ", ".join(items)}
def reasoning_check(state: AgentState) -> Dict:
    """
    Node that checks the reasoning of the final answer.

    Sends the agent steps and the extracted answer to the validator model and
    treats any occurrence of "FAIL" in the reply as a rejection.

    Args:
        state: The state of the agent; ``agent_memory`` and ``final_answer``
            are read.
    Returns:
        A dictionary with the reasoning check result under ``"valid_answer"``.
    """
    model = ChatOllama(
        model="hf.co/lmstudio-community/Qwen2.5-14B-Instruct-GGUF:Q6_K",
        base_url=base_url,
        temperature=0.2,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a strict validator of answers. Your job is to check if the reasoning and results are correct.\n"
                "You should have >90% confidence that the answer is correct to pass it.\n"
                "First list reasons why yes/no, then write your final decision: PASS in caps lock if it is satisfactory, FAIL if it is not.",
            ),
            (
                "human",
                "\n"
                "Here is a user-given task and the agent steps: {agent_memory}\n"
                "Now here is the answer that was given: {final_answer}\n"
                "Please check that the reasoning process and results are correct: do they correctly answer the given task?\n",
            ),
        ]
    )
    chain = prompt | model | StrOutputParser()
    try:
        output = chain.invoke(
            {
                "agent_memory": state["agent_memory"],
                "final_answer": state["final_answer"],
            }
        )
    finally:
        # Release PyTorch's cached CUDA memory on every path. Previously the
        # early return on FAIL skipped this call.
        torch.cuda.empty_cache()
    print("Reasoning Feedback: ", output)
    # A reply without "FAIL" counts as a pass, including one with no verdict.
    return {"valid_answer": "FAIL" not in output}
def formatting_check(state: AgentState) -> Dict:
    """
    Node that checks the formatting of the final answer.

    Sends the agent steps and the extracted answer to the model together with
    the answer-format rules and treats any occurrence of "FAIL" in the reply
    as a rejection.

    Args:
        state: The state of the agent; ``agent_memory`` and ``final_answer``
            are read.
    Returns:
        A dictionary with the formatting check result under ``"valid_answer"``.
    """
    model = ChatOllama(
        model="hf.co/lmstudio-community/Qwen2.5-14B-Instruct-GGUF:Q6_K",
        base_url=base_url,
        temperature=0.2,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a general AI assistant. I will ask you a question. "
                "Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. "
                "YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. "
                "If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. "
                "If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. "
                "If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.\n",
            ),
            (
                "human",
                "\n"
                "Here is a user-given task and the agent steps: {agent_memory}\n"
                "Now here is the FINAL ANSWER that was given: {final_answer}\n"
                "Ensure the FINAL ANSWER is in the right format as asked for by the task.\n"
                # The verdict instruction below is new: the prompt never asked
                # for PASS/FAIL, so the "FAIL" test at the end of this function
                # had nothing to match.
                "First list reasons why yes/no, then write your final decision: PASS in caps lock if the format is right, FAIL if it is not.\n",
            ),
        ]
    )
    chain = prompt | model | StrOutputParser()
    try:
        output = chain.invoke(
            {
                "agent_memory": state["agent_memory"],
                "final_answer": state["final_answer"],
            }
        )
    finally:
        # Release PyTorch's cached CUDA memory on every path. Previously the
        # early return on FAIL skipped this call.
        torch.cuda.empty_cache()
    print("Formatting Feedback: ", output)
    # A reply without "FAIL" counts as a pass, including one with no verdict.
    return {"valid_answer": "FAIL" not in output}
def create_final_answer_graph() -> Graph:
    """Create a graph that validates the final answer.

    Flow: extract_answer -> reasoning_check -> formatting_check -> END, with
    the run ending straight after reasoning_check when that check fails.

    Returns:
        A compiled graph that validates the final answer.
    """

    def route_after_reasoning(state: AgentState) -> str:
        # A failed reasoning check is final. With an unconditional edge,
        # formatting_check ran next and its own verdict overwrote
        # valid_answer=False.
        return "formatting_check" if state["valid_answer"] else END

    # Create the graph
    workflow = StateGraph(AgentState)
    # Add nodes
    workflow.add_node("extract_answer", extract_answer)
    workflow.add_node("reasoning_check", reasoning_check)
    workflow.add_node("formatting_check", formatting_check)
    # Add edges
    workflow.add_edge(START, "extract_answer")
    workflow.add_edge("extract_answer", "reasoning_check")
    workflow.add_conditional_edges(
        "reasoning_check",
        route_after_reasoning,
        {"formatting_check": "formatting_check", END: END},
    )
    workflow.add_edge("formatting_check", END)
    # Compile the graph
    return workflow.compile()  # type: ignore
def validate_answer(graph: StateGraph, answer: str, agent_memory: Any) -> Dict:
    """Run an answer through the validation graph and report the outcome.

    Args:
        graph: The validation graph. Its ``invoke`` method is called with the
            initial state. NOTE(review): annotated as StateGraph, but this is
            presumably the compiled graph returned by
            create_final_answer_graph; confirm.
        answer: The raw answer to validate.
        agent_memory: The agent's memory, passed through to the graph.
    Returns:
        A dictionary with ``"valid_answer"`` and ``"final_answer"``. Any
        exception raised while running the graph or reading its result is
        printed and reported as ``False`` / ``None``.
    """
    # Seed state: nothing extracted yet and not validated.
    seed_state = {
        "answer": answer,
        "final_answer": None,
        "agent_memory": agent_memory,
        "valid_answer": False,
    }
    try:
        outcome = graph.invoke(seed_state)  # type:ignore
        # Read the result inside the try so a malformed result is also
        # reported as a failed validation.
        verdict = outcome.get("valid_answer", False)
        cleaned = outcome.get("final_answer")
    except Exception as e:
        print(f"Validation failed: {e}")
        return {"valid_answer": False, "final_answer": None}
    return {"valid_answer": verdict, "final_answer": cleaned}