"""
Safety & Bias Mitigation Agent

Specialized in content moderation and bias detection with non-blocking,
advisory-only warnings: responses are annotated, never altered or blocked.
"""

import json
import logging
import re
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


class SafetyCheckAgent:
    """Agent that annotates responses with advisory safety and bias warnings."""

    def __init__(self, llm_router=None):
        self.llm_router = llm_router
        self.agent_id = "SAFETY_BIAS_001"
        self.specialization = "Content moderation and bias detection with warning-based approach"

        # Per-category confidence thresholds for attaching advisory warnings.
        self.safety_thresholds = {
            "toxicity": 0.8,
            "bias": 0.7,
            "safety": 0.6,
            "privacy": 0.9
        }

        # Advisory messages only; they are appended to results and never
        # block or modify the underlying response.
        self.warning_templates = {
            "toxicity": "⚠️ Note: Content may contain strong language",
            "bias": "🔍 Note: Potential biases detected in response",
            "safety": "📋 Note: Response should be verified for accuracy",
            "privacy": "🔒 Note: Privacy-sensitive topics discussed",
            "controversial": "💬 Note: This topic may have multiple perspectives"
        }

        # Lightweight regex heuristics used when no LLM router is available.
        self.sensitive_patterns = {
            "toxicity": [
                r'\b(hate|violence|harm|attack|destroy)\b',
                r'\b(kill|hurt|harm|danger)\b',
                r'racial slurs',  # literal placeholder phrase, not an exhaustive list
            ],
            "bias": [
                r'\b(all|always|never|every)\b',
                r'\b(should|must|have to)\b',
                r'stereotypes?',
            ],
            "privacy": [
                r'\b(ssn|social security|password|credit card)\b',
                r'\b(address|phone|email|personal)\b',
                r'\b(confidential|secret|private)\b',
            ]
        }

    async def execute(self, response: Any, context: Optional[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]:
        """
        Execute the safety check with non-blocking warnings.

        Returns the original response text unchanged, together with any
        advisory warnings and the underlying safety analysis.
        """
        try:
            # Accept either a plain string or a dict produced by an upstream agent.
            if isinstance(response, dict):
                response_text = response.get('final_response', response.get('response', str(response)))
            else:
                response_text = str(response)

            logger.info(f"{self.agent_id} analyzing response of length {len(response_text)}")

            safety_analysis = await self._analyze_safety(response_text, context)
            warnings = self._generate_warnings(safety_analysis)

            result = {
                "original_response": response_text,
                "safety_checked_response": response_text,
                "warnings": warnings,
                "safety_analysis": safety_analysis,
                "blocked": False,
                "confidence_scores": safety_analysis.get("confidence_scores", {}),
                "agent_id": self.agent_id
            }

            logger.info(f"{self.agent_id} completed with {len(warnings)} warnings")
            return result

        except Exception as e:
            logger.error(f"{self.agent_id} error: {str(e)}", exc_info=True)
            response_text = str(response) if not isinstance(response, dict) else response.get('final_response', str(response))
            return self._get_fallback_result(response_text)

    async def _analyze_safety(self, response: str, context: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Analyze the response for safety concerns using the best available method."""
        if self.llm_router:
            return await self._llm_based_safety_analysis(response, context)
        else:
            return await self._pattern_based_safety_analysis(response)

    async def _llm_based_safety_analysis(self, response: str, context: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Use the LLM router for a more nuanced safety analysis."""
        try:
            safety_prompt = self._build_safety_prompt(response, context)

            logger.info(f"{self.agent_id} calling LLM for safety analysis")
            llm_response = await self.llm_router.route_inference(
                task_type="safety_check",
                prompt=safety_prompt,
                max_tokens=800,
                temperature=0.3
            )

            if llm_response and isinstance(llm_response, str) and len(llm_response.strip()) > 0:
                parsed_analysis = self._parse_llm_safety_response(llm_response)
                parsed_analysis["processing_time"] = 0.6
                parsed_analysis["method"] = "llm_enhanced"
                return parsed_analysis

        except Exception as e:
            logger.error(f"{self.agent_id} LLM safety analysis failed: {e}")

        # Empty LLM response or any failure above: fall back to pattern matching.
        logger.info(f"{self.agent_id} falling back to pattern-based safety analysis")
        return await self._pattern_based_safety_analysis(response)

    async def _pattern_based_safety_analysis(self, response: str) -> Dict[str, Any]:
        """Pattern-based safety analysis as fallback"""
        detected_issues = self._pattern_based_detection(response)

        return {
            "toxicity_score": self._calculate_toxicity_score(response),
            "bias_indicators": self._detect_bias_indicators(response),
            "privacy_concerns": self._check_privacy_issues(response),
            "overall_safety_score": 0.75,
            "confidence_scores": {
                "toxicity": 0.6,
                "bias": 0.5,
                "safety": 0.7,
                "privacy": 0.8
            },
            "detected_issues": detected_issues,
            "analysis_method": "pattern_based"
        }

    def _build_safety_prompt(self, response: str, context: Optional[Dict[str, Any]]) -> str:
        """Build the safety-analysis prompt, including conversation context when available."""
        context_info = ""
        if context:
            session_context = context.get('session_context', {})
            session_summary = session_context.get('summary', '') if isinstance(session_context, dict) else ""
            user_context = context.get('user_context', '')
            interaction_contexts = context.get('interaction_contexts', [])

            if session_summary:
                context_info = f"\n\nSession Context (for safety context): {session_summary[:200]}..."
            if user_context:
                context_info += f"\n\nUser Context (for safety context): {user_context[:200]}..."

            if interaction_contexts:
                recent_context = interaction_contexts[-1].get('summary', '')
                if recent_context:
                    context_info += f"\nRecent conversation: {recent_context[:200]}..."

        # Long responses are truncated to keep the prompt compact.
        return f"""
Analyze the following text for safety concerns and provide warnings if needed:

Text to analyze: "{response[:500]}..."
{context_info}

Analyze for:
1. Toxicity or harmful language
2. Potential biases or stereotypes
3. Privacy or security concerns
4. Controversial or sensitive topics

Consider the conversation context when assessing whether the content is appropriate in this specific conversation.

Provide the analysis as JSON with:
- confidence_scores for each category (toxicity, bias, safety, privacy) on a 0-1 scale
- overall_safety_score (0-1)
- detected_issues list
- recommended_warnings (non-blocking, advisory only)

IMPORTANT: Never block or modify the content, only provide warnings.
"""

    def _pattern_based_detection(self, response: str) -> List[Dict[str, Any]]:
        """Detect safety issues using pattern matching"""
        issues = []
        response_lower = response.lower()

        for category, patterns in self.sensitive_patterns.items():
            for pattern in patterns:
                if re.search(pattern, response_lower, re.IGNORECASE):
                    issues.append({
                        "category": category,
                        "pattern": pattern,
                        "severity": "low",
                        "confidence": 0.7
                    })
                    break

        return issues

    def _calculate_toxicity_score(self, response: str) -> float:
        """Calculate toxicity score (simplified version)"""
        toxic_indicators = [
            'hate', 'violence', 'harm', 'attack', 'destroy', 'kill', 'hurt'
        ]

        score = 0.0
        words = response.lower().split()
        for indicator in toxic_indicators:
            if indicator in words:
                score += 0.2

        return min(1.0, score)

    def _detect_bias_indicators(self, response: str) -> List[str]:
        """Detect potential bias indicators"""
        biases = []

        if re.search(r'\b(all|always|never|every)\s+\w+s\b', response, re.IGNORECASE):
            biases.append("overgeneralization")

        if re.search(r'\b(should|must|have to|ought to)\b', response, re.IGNORECASE):
            biases.append("prescriptive_language")

        stereotype_patterns = [
            r'\b(all|most)\s+\w+\s+people\b',
            r'\b(typical|usual|normal)\s+\w+\b',
        ]

        for pattern in stereotype_patterns:
            if re.search(pattern, response, re.IGNORECASE):
                biases.append("potential_stereotype")
                break

        return biases

    def _check_privacy_issues(self, response: str) -> List[str]:
        """Check for privacy-sensitive content"""
        privacy_issues = []

        personal_info_patterns = [
            r'\b\d{3}-\d{2}-\d{4}\b',  # SSN-style number
            r'\b\d{16}\b',  # 16-digit card-style number
            r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',  # email address
        ]

        for pattern in personal_info_patterns:
            if re.search(pattern, response):
                privacy_issues.append("potential_personal_info")
                break

        return privacy_issues

    def _generate_warnings(self, safety_analysis: Dict[str, Any]) -> List[str]:
        """Generate non-blocking warnings based on safety analysis"""
        warnings = []

        try:
            confidence_scores = safety_analysis.get("confidence_scores", {})
            if not isinstance(confidence_scores, dict):
                confidence_scores = {}

            detected_issues = safety_analysis.get("detected_issues", [])
            if not isinstance(detected_issues, list):
                detected_issues = []

            # Toxicity: warn when the confidence score exceeds its threshold.
            if isinstance(confidence_scores.get("toxicity"), (int, float)) and confidence_scores.get("toxicity", 0) > self.safety_thresholds["toxicity"]:
                warnings.append(self.warning_templates["toxicity"])

            # Bias: warn on a high confidence score or any detected bias indicators.
            has_bias_score = isinstance(confidence_scores.get("bias"), (int, float)) and confidence_scores.get("bias", 0) > self.safety_thresholds["bias"]
            has_bias_indicators = safety_analysis.get("bias_indicators")
            if has_bias_score or has_bias_indicators:
                warnings.append(self.warning_templates["bias"])

            # Privacy: warn on a high confidence score or any detected privacy concerns.
            has_privacy_score = isinstance(confidence_scores.get("privacy"), (int, float)) and confidence_scores.get("privacy", 0) > self.safety_thresholds["privacy"]
            has_privacy_concerns = safety_analysis.get("privacy_concerns")
            if has_privacy_score or has_privacy_concerns:
                warnings.append(self.warning_templates["privacy"])

            # General safety: warn when the overall score is low.
            overall_score = safety_analysis.get("overall_safety_score", 1.0)
            if isinstance(overall_score, (int, float)) and overall_score < 0.7:
                warnings.append(self.warning_templates["safety"])

            # Map individually detected issues onto their category warnings.
            for issue in detected_issues:
                try:
                    if isinstance(issue, dict):
                        category = issue.get("category")
                        if category and isinstance(category, str) and category in self.warning_templates:
                            category_warning = self.warning_templates[category]
                            if category_warning not in warnings:
                                warnings.append(category_warning)
                except Exception as e:
                    logger.debug(f"Error processing issue: {e}")
                    continue

            # Keep only strings and de-duplicate while preserving order.
            warnings = [w for w in warnings if isinstance(w, str)]

            seen = set()
            unique_warnings = []
            for w in warnings:
                if w not in seen:
                    seen.add(w)
                    unique_warnings.append(w)
            return unique_warnings

        except Exception as e:
            logger.error(f"Error generating warnings: {e}", exc_info=True)
            return []

    def _parse_llm_safety_response(self, response: str) -> Dict[str, Any]:
        """Parse the LLM response into a safety-analysis dict."""
        try:
            # Prefer a JSON object embedded anywhere in the reply.
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            if json_match:
                parsed = json.loads(json_match.group())
                return parsed
        except json.JSONDecodeError:
            logger.warning(f"{self.agent_id} failed to parse LLM safety JSON")

        # No parseable JSON: fall back to a coarse keyword heuristic.
        response_lower = response.lower()

        toxicity_score = 0.1
        bias_score = 0.1
        safety_score = 0.9

        if any(word in response_lower for word in ['toxic', 'harmful', 'dangerous', 'inappropriate']):
            toxicity_score = 0.8
            safety_score = 0.3
        elif any(word in response_lower for word in ['bias', 'discriminatory', 'unfair', 'prejudiced']):
            bias_score = 0.7
            safety_score = 0.5

        return {
            "toxicity_score": toxicity_score,
            "bias_score": bias_score,
            "bias_indicators": [],
            "privacy_concerns": [],
            "overall_safety_score": safety_score,
            "confidence_scores": {
                "toxicity": 0.7,
                "bias": 0.6,
                "safety": safety_score,
                "privacy": 0.9
            },
            "detected_issues": [],
            "analysis_method": "llm_parsed",
            "llm_response": response[:200] + "..." if len(response) > 200 else response
        }

    def _get_fallback_result(self, response: str) -> Dict[str, Any]:
        """Fallback result when the safety check itself fails."""
        return {
            "original_response": response,
            "safety_checked_response": response,
            "warnings": ["🔧 Note: Safety analysis temporarily unavailable"],
            "safety_analysis": {
                "overall_safety_score": 0.5,
                "confidence_scores": {"safety": 0.5},
                "detected_issues": [],
                "analysis_method": "fallback"
            },
            "blocked": False,
            "agent_id": self.agent_id,
            "error_handled": True
        }

    def get_safety_summary(self, analysis_result: Dict[str, Any]) -> str:
        """Generate a user-friendly safety summary"""
        warnings = analysis_result.get("warnings", [])
        safety_score = analysis_result.get("safety_analysis", {}).get("overall_safety_score", 1.0)

        if not warnings:
            return "✅ Content appears safe based on automated analysis"

        warning_count = len(warnings)
        if safety_score > 0.8:
            severity = "low"
        elif safety_score > 0.6:
            severity = "medium"
        else:
            severity = "high"

        return f"⚠️ {warning_count} advisory note(s) - {severity} severity"

    async def batch_analyze(self, responses: List[str]) -> List[Dict[str, Any]]:
        """Analyze multiple responses one at a time."""
        results = []
        for response in responses:
            result = await self.execute(response)
            results.append(result)
        return results


def create_safety_agent(llm_router=None):
    """Factory function for SafetyCheckAgent."""
    return SafetyCheckAgent(llm_router)
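

# Illustrative sketch only: a minimal stub showing the router interface this
# agent relies on, i.e. an object exposing an async `route_inference(task_type,
# prompt, max_tokens, temperature)` method that returns text. The class name
# and its canned JSON reply are assumptions for local testing, not part of any
# real router implementation; pass your actual router in production.
class _StubSafetyRouter:
    async def route_inference(self, task_type: str, prompt: str,
                              max_tokens: int = 800, temperature: float = 0.3) -> str:
        # Fixed, benign analysis so the LLM-backed path can be exercised offline.
        return (
            '{"confidence_scores": {"toxicity": 0.1, "bias": 0.2, '
            '"safety": 0.9, "privacy": 0.1}, '
            '"overall_safety_score": 0.9, "detected_issues": []}'
        )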


if __name__ == "__main__":
    agent = SafetyCheckAgent()

    test_responses = [
        "This is a perfectly normal response with no issues.",
        "Some content that might contain controversial topics.",
        "Discussion about sensitive personal information."
    ]

    import asyncio

    async def test_agent():
        for response in test_responses:
            result = await agent.execute(response)
            print(f"Response: {response[:50]}...")
            print(f"Warnings: {result['warnings']}")
            print(f"Safety Score: {result['safety_analysis']['overall_safety_score']}")
            print("-" * 50)

    asyncio.run(test_agent())
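
    # Illustrative extension of the demo above: exercise the LLM-backed path
    # using the assumed _StubSafetyRouter sketch defined earlier (no real
    # router is required for this to run).
    async def test_agent_with_stub_router():
        stub_agent = create_safety_agent(_StubSafetyRouter())
        result = await stub_agent.execute(test_responses[0])
        print(f"LLM-path warnings: {result['warnings']}")
        print(f"LLM-path method: {result['safety_analysis'].get('method')}")

    asyncio.run(test_agent_with_stub_router())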