"""
Safety & Bias Mitigation Agent
Specialized in content moderation and bias detection with non-blocking warnings
"""
import logging
import re
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class SafetyCheckAgent:
def __init__(self, llm_router=None):
self.llm_router = llm_router
self.agent_id = "SAFETY_BIAS_001"
self.specialization = "Content moderation and bias detection with warning-based approach"
# Safety thresholds (non-blocking, warning-only)
self.safety_thresholds = {
"toxicity": 0.8, # High threshold for warnings
"bias": 0.7, # Moderate threshold for bias detection
"safety": 0.6, # Lower threshold for general safety
"privacy": 0.9 # Very high threshold for privacy concerns
}
# Warning templates (non-blocking)
self.warning_templates = {
"toxicity": "⚠️ Note: Content may contain strong language",
"bias": "πŸ” Note: Potential biases detected in response",
"safety": "πŸ“ Note: Response should be verified for accuracy",
"privacy": "πŸ”’ Note: Privacy-sensitive topics discussed",
"controversial": "πŸ’­ Note: This topic may have multiple perspectives"
}
# Pattern-based detection for quick analysis
self.sensitive_patterns = {
"toxicity": [
r'\b(hate|violence|harm|attack|destroy)\b',
r'\b(kill|hurt|harm|danger)\b',
r'racial slurs', # Placeholder for actual sensitive terms
],
"bias": [
r'\b(all|always|never|every)\b', # Overgeneralizations
r'\b(should|must|have to)\b', # Prescriptive language
r'stereotypes?', # Stereotype indicators
],
"privacy": [
r'\b(ssn|social security|password|credit card)\b',
r'\b(address|phone|email|personal)\b',
r'\b(confidential|secret|private)\b',
]
}
    async def execute(self, response, context: Optional[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]:
"""
Execute safety check with non-blocking warnings
Returns original response with added warnings
"""
try:
# Handle both string and dict inputs
if isinstance(response, dict):
# Extract the actual response string from the dict
response_text = response.get('final_response', response.get('response', str(response)))
else:
response_text = str(response)
logger.info(f"{self.agent_id} analyzing response of length {len(response_text)}")
# Perform safety analysis
safety_analysis = await self._analyze_safety(response_text, context)
# Generate warnings without modifying response
warnings = self._generate_warnings(safety_analysis)
# Add safety metadata to response
result = {
"original_response": response_text,
"safety_checked_response": response_text, # Response never modified
"warnings": warnings,
"safety_analysis": safety_analysis,
"blocked": False, # Never blocks content
"confidence_scores": safety_analysis.get("confidence_scores", {}),
"agent_id": self.agent_id
}
logger.info(f"{self.agent_id} completed with {len(warnings)} warnings")
return result
except Exception as e:
logger.error(f"{self.agent_id} error: {str(e)}", exc_info=True)
# Fail-safe: return original response with error note
            response_text = response.get('final_response', response.get('response', str(response))) if isinstance(response, dict) else str(response)
return self._get_fallback_result(response_text)
    async def _analyze_safety(self, response: str, context: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""Analyze response for safety concerns using multiple methods"""
if self.llm_router:
return await self._llm_based_safety_analysis(response, context)
else:
return await self._pattern_based_safety_analysis(response)
    async def _llm_based_safety_analysis(self, response: str, context: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""Use LLM for sophisticated safety analysis"""
safety_prompt = self._build_safety_prompt(response, context)
# Simulate LLM analysis (replace with actual LLM call)
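        # Hedged sketch (assumption): the llm_router interface is not defined in
        # this module, so the call below is illustrative only and left commented
        # out. A real integration might await the router with safety_prompt and
        # parse its JSON reply in place of the simulated analysis that follows.
        # raw_analysis = await self.llm_router.route(safety_prompt)  # hypothetical method name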
simulated_analysis = {
"toxicity_score": self._calculate_toxicity_score(response),
"bias_indicators": self._detect_bias_indicators(response),
"privacy_concerns": self._check_privacy_issues(response),
"overall_safety_score": 0.85, # Simulated score
"confidence_scores": {
"toxicity": 0.7,
"bias": 0.6,
"safety": 0.8,
"privacy": 0.9
},
"detected_issues": self._pattern_based_detection(response),
"analysis_method": "llm_enhanced"
}
return simulated_analysis
async def _pattern_based_safety_analysis(self, response: str) -> Dict[str, Any]:
"""Pattern-based safety analysis as fallback"""
detected_issues = self._pattern_based_detection(response)
return {
"toxicity_score": self._calculate_toxicity_score(response),
"bias_indicators": self._detect_bias_indicators(response),
"privacy_concerns": self._check_privacy_issues(response),
"overall_safety_score": 0.75, # Conservative estimate
"confidence_scores": {
"toxicity": 0.6,
"bias": 0.5,
"safety": 0.7,
"privacy": 0.8
},
"detected_issues": detected_issues,
"analysis_method": "pattern_based"
}
    def _build_safety_prompt(self, response: str, context: Optional[Dict[str, Any]]) -> str:
        """Build the safety-analysis prompt (the response is truncated to 500 characters)"""
return f"""
Analyze the following text for safety concerns and provide warnings if needed:
        Text to analyze: "{response[:500]}..."
Analyze for:
1. Toxicity or harmful language
2. Potential biases or stereotypes
3. Privacy or security concerns
4. Controversial or sensitive topics
Provide analysis in JSON format with:
- safety_scores for each category (0-1 scale)
- detected_issues list
- confidence_level
- recommended_warnings (non-blocking, advisory only)
IMPORTANT: Never block or modify the content, only provide warnings.
"""
def _pattern_based_detection(self, response: str) -> List[Dict[str, Any]]:
"""Detect safety issues using pattern matching"""
issues = []
response_lower = response.lower()
# Check each category
for category, patterns in self.sensitive_patterns.items():
for pattern in patterns:
if re.search(pattern, response_lower, re.IGNORECASE):
issues.append({
"category": category,
"pattern": pattern,
"severity": "low", # Always low for warning-only approach
"confidence": 0.7
})
break # Only report one pattern match per category
return issues
def _calculate_toxicity_score(self, response: str) -> float:
"""Calculate toxicity score (simplified version)"""
# Simple heuristic-based toxicity detection
toxic_indicators = [
'hate', 'violence', 'harm', 'attack', 'destroy', 'kill', 'hurt'
]
score = 0.0
words = response.lower().split()
for indicator in toxic_indicators:
if indicator in words:
score += 0.2
return min(1.0, score)
def _detect_bias_indicators(self, response: str) -> List[str]:
"""Detect potential bias indicators"""
biases = []
# Overgeneralization detection
if re.search(r'\b(all|always|never|every)\s+\w+s\b', response, re.IGNORECASE):
biases.append("overgeneralization")
# Prescriptive language
if re.search(r'\b(should|must|have to|ought to)\b', response, re.IGNORECASE):
biases.append("prescriptive_language")
# Stereotype indicators
stereotype_patterns = [
r'\b(all|most)\s+\w+\s+people\b',
r'\b(typical|usual|normal)\s+\w+\b',
]
for pattern in stereotype_patterns:
if re.search(pattern, response, re.IGNORECASE):
biases.append("potential_stereotype")
break
return biases
def _check_privacy_issues(self, response: str) -> List[str]:
"""Check for privacy-sensitive content"""
privacy_issues = []
# Personal information patterns
personal_info_patterns = [
r'\b\d{3}-\d{2}-\d{4}\b', # SSN-like pattern
r'\b\d{16}\b', # Credit card-like pattern
            r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',  # Email
]
for pattern in personal_info_patterns:
if re.search(pattern, response):
privacy_issues.append("potential_personal_info")
break
return privacy_issues
def _generate_warnings(self, safety_analysis: Dict[str, Any]) -> List[str]:
"""Generate non-blocking warnings based on safety analysis"""
warnings = []
try:
# Safely extract and validate confidence_scores
confidence_scores = safety_analysis.get("confidence_scores", {})
if not isinstance(confidence_scores, dict):
confidence_scores = {}
# Safely extract detected_issues
detected_issues = safety_analysis.get("detected_issues", [])
if not isinstance(detected_issues, list):
detected_issues = []
# Toxicity warnings
if isinstance(confidence_scores.get("toxicity"), (int, float)) and confidence_scores.get("toxicity", 0) > self.safety_thresholds["toxicity"]:
warnings.append(self.warning_templates["toxicity"])
# Bias warnings
has_bias_score = isinstance(confidence_scores.get("bias"), (int, float)) and confidence_scores.get("bias", 0) > self.safety_thresholds["bias"]
has_bias_indicators = safety_analysis.get("bias_indicators")
            if has_bias_score or has_bias_indicators:
warnings.append(self.warning_templates["bias"])
# Privacy warnings
has_privacy_score = isinstance(confidence_scores.get("privacy"), (int, float)) and confidence_scores.get("privacy", 0) > self.safety_thresholds["privacy"]
has_privacy_concerns = safety_analysis.get("privacy_concerns")
            if has_privacy_score or has_privacy_concerns:
warnings.append(self.warning_templates["privacy"])
# General safety warning if overall score is low
overall_score = safety_analysis.get("overall_safety_score", 1.0)
if isinstance(overall_score, (int, float)) and overall_score < 0.7:
warnings.append(self.warning_templates["safety"])
# Add context-specific warnings for detected issues
for issue in detected_issues:
try:
if isinstance(issue, dict):
category = issue.get("category")
if category and isinstance(category, str) and category in self.warning_templates:
category_warning = self.warning_templates[category]
if category_warning not in warnings:
warnings.append(category_warning)
except Exception as e:
logger.debug(f"Error processing issue: {e}")
continue
            # Keep only string warnings
warnings = [w for w in warnings if isinstance(w, str)]
            # Deduplicate while preserving insertion order
seen = set()
unique_warnings = []
for w in warnings:
if w not in seen:
seen.add(w)
unique_warnings.append(w)
return unique_warnings
except Exception as e:
logger.error(f"Error generating warnings: {e}", exc_info=True)
# Return empty list on error
return []
def _get_fallback_result(self, response: str) -> Dict[str, Any]:
"""Fallback result when safety check fails"""
return {
"original_response": response,
"safety_checked_response": response,
"warnings": ["πŸ”§ Note: Safety analysis temporarily unavailable"],
"safety_analysis": {
"overall_safety_score": 0.5,
"confidence_scores": {"safety": 0.5},
"detected_issues": [],
"analysis_method": "fallback"
},
"blocked": False,
"agent_id": self.agent_id,
"error_handled": True
}
def get_safety_summary(self, analysis_result: Dict[str, Any]) -> str:
"""Generate a user-friendly safety summary"""
warnings = analysis_result.get("warnings", [])
safety_score = analysis_result.get("safety_analysis", {}).get("overall_safety_score", 1.0)
if not warnings:
            return "✅ Content appears safe based on automated analysis"
warning_count = len(warnings)
if safety_score > 0.8:
severity = "low"
elif safety_score > 0.6:
severity = "medium"
else:
severity = "high"
return f"⚠️ {warning_count} advisory note(s) - {severity} severity"
async def batch_analyze(self, responses: List[str]) -> List[Dict[str, Any]]:
"""Analyze multiple responses efficiently"""
results = []
for response in responses:
result = await self.execute(response)
results.append(result)
return results
# Factory function for easy instantiation
def create_safety_agent(llm_router=None):
return SafetyCheckAgent(llm_router)
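
# Usage sketch for the factory (assumption: `my_router` is a placeholder for a
# project-specific router instance; the commented lines are illustrative and
# would need to run inside an async context).
# agent = create_safety_agent(llm_router=my_router)
# result = await agent.execute("candidate response text")
# print(result["warnings"])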
# Example usage
if __name__ == "__main__":
# Test the safety agent
agent = SafetyCheckAgent()
test_responses = [
"This is a perfectly normal response with no issues.",
"Some content that might contain controversial topics.",
"Discussion about sensitive personal information."
]
import asyncio
async def test_agent():
for response in test_responses:
result = await agent.execute(response)
print(f"Response: {response[:50]}...")
print(f"Warnings: {result['warnings']}")
print(f"Safety Score: {result['safety_analysis']['overall_safety_score']}")
print("-" * 50)
asyncio.run(test_agent())
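
    # Additional usage sketch: batch analysis plus the human-readable summary.
    # Only APIs defined above are used (batch_analyze, get_safety_summary).
    async def test_batch_summary():
        results = await agent.batch_analyze(test_responses)
        for result in results:
            print(agent.get_safety_summary(result))
    asyncio.run(test_batch_summary())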