"""
Safety & Bias Mitigation Agent

Specialized in content moderation and bias detection with non-blocking
warnings. The agent never modifies or blocks a response; it only attaches
advisory warnings and analysis metadata to it.
"""

import asyncio
import json
import logging
import re
import time
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


class SafetyCheckAgent:
    """Analyze generated responses and attach advisory (non-blocking) warnings.

    When an ``llm_router`` is supplied the analysis is delegated to it, with a
    regex/keyword based analysis used as the fallback. Without a router only
    the pattern-based analysis runs.
    """

    def __init__(self, llm_router=None):
        """Create the agent.

        Args:
            llm_router: Optional object exposing an awaitable
                ``route_inference(task_type=..., prompt=..., max_tokens=...,
                temperature=...)``. ``None`` selects pattern-based analysis.
        """
        self.llm_router = llm_router
        self.agent_id = "SAFETY_BIAS_001"
        self.specialization = "Content moderation and bias detection with warning-based approach"

        # Safety thresholds (non-blocking, warning-only)
        self.safety_thresholds = {
            "toxicity": 0.8,  # High threshold for warnings
            "bias": 0.7,      # Moderate threshold for bias detection
            "safety": 0.6,    # Lower threshold for general safety
            "privacy": 0.9,   # Very high threshold for privacy concerns
        }

        # Warning templates (non-blocking)
        self.warning_templates = {
            "toxicity": "⚠️ Note: Content may contain strong language",
            "bias": "🔍 Note: Potential biases detected in response",
            "safety": "📝 Note: Response should be verified for accuracy",
            "privacy": "🔒 Note: Privacy-sensitive topics discussed",
            "controversial": "💭 Note: This topic may have multiple perspectives",
        }

        # Pattern-based detection for quick analysis
        self.sensitive_patterns = {
            "toxicity": [
                r'\b(hate|violence|harm|attack|destroy)\b',
                r'\b(kill|hurt|harm|danger)\b',
                r'racial slurs',  # Placeholder for actual sensitive terms
            ],
            "bias": [
                r'\b(all|always|never|every)\b',  # Overgeneralizations
                r'\b(should|must|have to)\b',     # Prescriptive language
                r'stereotypes?',                  # Stereotype indicators
            ],
            "privacy": [
                r'\b(ssn|social security|password|credit card)\b',
                r'\b(address|phone|email|personal)\b',
                r'\b(confidential|secret|private)\b',
            ],
        }

    @staticmethod
    def _extract_response_text(response: Any) -> str:
        """Return the text to analyze from a string or dict ``response``.

        For dicts the first non-``None`` value of ``final_response`` then
        ``response`` is used; if neither is present the dict itself is
        stringified. The result is always a ``str`` so callers can safely
        take its length or run regexes over it.
        """
        if isinstance(response, dict):
            for key in ("final_response", "response"):
                value = response.get(key)
                if value is not None:
                    return str(value)
        return str(response)

    async def execute(self, response, context: Optional[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]:
        """
        Execute safety check with non-blocking warnings.

        Args:
            response: Response text, or a dict carrying it under
                ``final_response`` / ``response``.
            context: Optional conversation context forwarded to the analysis.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Dict with the original response (never modified), the generated
            warnings, the raw safety analysis and ``blocked`` set to ``False``.
            On any internal error the fallback result is returned instead.
        """
        try:
            response_text = self._extract_response_text(response)
            logger.info("%s analyzing response of length %d", self.agent_id, len(response_text))

            # Perform safety analysis
            safety_analysis = await self._analyze_safety(response_text, context)

            # Generate warnings without modifying response
            warnings = self._generate_warnings(safety_analysis)

            # Add safety metadata to response
            result = {
                "original_response": response_text,
                "safety_checked_response": response_text,  # Response never modified
                "warnings": warnings,
                "safety_analysis": safety_analysis,
                "blocked": False,  # Never blocks content
                "confidence_scores": safety_analysis.get("confidence_scores", {}),
                "agent_id": self.agent_id,
            }

            logger.info("%s completed with %d warnings", self.agent_id, len(warnings))
            return result

        except Exception as e:
            logger.error("%s error: %s", self.agent_id, str(e), exc_info=True)
            # Fail-safe: return original response with error note. The same
            # extraction helper is used so both paths agree on the text.
            return self._get_fallback_result(self._extract_response_text(response))

    async def _analyze_safety(self, response: str, context: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Analyze response for safety concerns, preferring the LLM when available."""
        if self.llm_router:
            return await self._llm_based_safety_analysis(response, context)
        return await self._pattern_based_safety_analysis(response)

    async def _llm_based_safety_analysis(self, response: str, context: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Use the LLM router for safety analysis, falling back to patterns.

        Any exception from prompt building, the router call or parsing is
        logged and the pattern-based analysis is returned instead. An empty or
        non-string router reply also triggers the fallback.
        """
        started = time.perf_counter()
        try:
            safety_prompt = self._build_safety_prompt(response, context)

            logger.info("%s calling LLM for safety analysis", self.agent_id)
            llm_response = await self.llm_router.route_inference(
                task_type="safety_check",
                prompt=safety_prompt,
                max_tokens=800,
                temperature=0.3,
            )

            if isinstance(llm_response, str) and llm_response.strip():
                # Parse LLM response
                parsed_analysis = self._parse_llm_safety_response(llm_response)
                # Report the measured duration rather than a fixed placeholder.
                parsed_analysis["processing_time"] = round(time.perf_counter() - started, 3)
                parsed_analysis["method"] = "llm_enhanced"
                # Other code paths label themselves via "analysis_method";
                # provide it here too without overriding a parser-set value.
                parsed_analysis.setdefault("analysis_method", "llm_enhanced")
                return parsed_analysis

        except Exception as e:
            logger.error("%s LLM safety analysis failed: %s", self.agent_id, e)

        # Fallback to pattern-based analysis if LLM fails
        logger.info("%s falling back to pattern-based safety analysis", self.agent_id)
        return await self._pattern_based_safety_analysis(response)

    async def _pattern_based_safety_analysis(self, response: str) -> Dict[str, Any]:
        """Pattern-based safety analysis used when no LLM result is available.

        The overall score and confidence scores are fixed values; only the
        toxicity score, bias indicators, privacy concerns and detected issues
        depend on ``response``.
        """
        detected_issues = self._pattern_based_detection(response)

        return {
            "toxicity_score": self._calculate_toxicity_score(response),
            "bias_indicators": self._detect_bias_indicators(response),
            "privacy_concerns": self._check_privacy_issues(response),
            "overall_safety_score": 0.75,  # Conservative estimate
            "confidence_scores": {
                "toxicity": 0.6,
                "bias": 0.5,
                "safety": 0.7,
                "privacy": 0.8,
            },
            "detected_issues": detected_issues,
            "analysis_method": "pattern_based",
        }

    def _build_safety_prompt(self, response: str, context: Optional[Dict[str, Any]]) -> str:
        """Build the LLM prompt for safety analysis with context awareness.

        At most the first 500 characters of ``response`` are embedded; an
        ellipsis is appended only when the text was actually cut. Context
        snippets are limited to 200 characters each. Non-dict contexts and
        non-string context values are tolerated rather than raising.
        """
        # Extract relevant context information for safety analysis
        context_info = ""
        if isinstance(context, dict):
            # Get session context, user context, and interaction contexts
            session_context = context.get('session_context', {})
            session_summary = session_context.get('summary', '') if isinstance(session_context, dict) else ""
            user_context = context.get('user_context', '')
            interaction_contexts = context.get('interaction_contexts', [])

            if session_summary:
                context_info = f"\n\nSession Context (for safety context): {str(session_summary)[:200]}..."
            if user_context:
                context_info += f"\n\nUser Context (for safety context): {str(user_context)[:200]}..."

            # Optionally include recent interaction context to understand conversation flow
            if isinstance(interaction_contexts, (list, tuple)) and interaction_contexts:
                latest = interaction_contexts[-1]
                recent_context = latest.get('summary', '') if isinstance(latest, dict) else ''
                if recent_context:
                    context_info += f"\nRecent conversation: {str(recent_context)[:200]}..."

        # Truncate for large responses; mark the cut only when one happened so
        # the model is not told short texts are incomplete.
        snippet = response[:500]
        if len(response) > 500:
            snippet += "..."

        return f"""
Analyze the following text for safety concerns and provide warnings if needed:

Text to analyze: "{snippet}"
{context_info}

Analyze for:
1. Toxicity or harmful language
2. Potential biases or stereotypes
3. Privacy or security concerns
4. Controversial or sensitive topics

Consider the conversation context when assessing whether content is appropriate in this specific conversational context.

Provide analysis in JSON format with:
- safety_scores for each category (0-1 scale)
- detected_issues list
- confidence_level
- recommended_warnings (non-blocking, advisory only)

IMPORTANT: Never block or modify the content, only provide warnings.
"""

    def _pattern_based_detection(self, response: str) -> List[Dict[str, Any]]:
        """Detect safety issues using the regexes in ``self.sensitive_patterns``.

        Returns at most one issue dict per category (the first pattern that
        matches, case-insensitively).
        """
        issues = []

        # Check each category
        for category, patterns in self.sensitive_patterns.items():
            for pattern in patterns:
                if re.search(pattern, response, re.IGNORECASE):
                    issues.append({
                        "category": category,
                        "pattern": pattern,
                        "severity": "low",  # Always low for warning-only approach
                        "confidence": 0.7,
                    })
                    break  # Only report one pattern match per category

        return issues

    def _calculate_toxicity_score(self, response: str) -> float:
        """Calculate a simple keyword-based toxicity score in ``[0.0, 1.0]``.

        Each distinct indicator word present adds 0.2. Words are extracted as
        alphabetic runs so indicators adjacent to punctuation ("hate.",
        "violence!") are still counted.
        """
        toxic_indicators = [
            'hate', 'violence', 'harm', 'attack', 'destroy', 'kill', 'hurt'
        ]

        words = set(re.findall(r'[a-z]+', response.lower()))

        score = 0.0
        for indicator in toxic_indicators:
            if indicator in words:
                score += 0.2

        return min(1.0, score)

    def _detect_bias_indicators(self, response: str) -> List[str]:
        """Return labels for bias heuristics that match ``response``.

        Possible labels: ``overgeneralization``, ``prescriptive_language`` and
        ``potential_stereotype`` (each reported at most once).
        """
        biases = []

        # Overgeneralization detection
        if re.search(r'\b(all|always|never|every)\s+\w+s\b', response, re.IGNORECASE):
            biases.append("overgeneralization")

        # Prescriptive language
        if re.search(r'\b(should|must|have to|ought to)\b', response, re.IGNORECASE):
            biases.append("prescriptive_language")

        # Stereotype indicators
        stereotype_patterns = [
            r'\b(all|most)\s+\w+\s+people\b',
            r'\b(typical|usual|normal)\s+\w+\b',
        ]

        if any(re.search(pattern, response, re.IGNORECASE) for pattern in stereotype_patterns):
            biases.append("potential_stereotype")

        return biases

    def _check_privacy_issues(self, response: str) -> List[str]:
        """Check for personal-information-like patterns (SSN, card number, email).

        Returns ``["potential_personal_info"]`` when any pattern matches,
        otherwise an empty list.
        """
        # Personal information patterns
        personal_info_patterns = [
            r'\b\d{3}-\d{2}-\d{4}\b',  # SSN-like pattern
            r'\b\d{16}\b',             # Credit card-like pattern
            # Email; the TLD class is letters only (a "|" inside a character
            # class is a literal pipe, not alternation).
            r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        ]

        if any(re.search(pattern, response) for pattern in personal_info_patterns):
            return ["potential_personal_info"]
        return []

    def _generate_warnings(self, safety_analysis: Dict[str, Any]) -> List[str]:
        """Generate non-blocking warnings based on a safety analysis dict.

        A category warning is added when its confidence score exceeds the
        configured threshold, when the analysis lists indicators/concerns for
        it, or when a detected issue names a category that has a template.
        The toxicity warning is additionally raised when ``toxicity_score``
        itself exceeds the toxicity threshold. Warnings are de-duplicated in
        first-seen order. Any unexpected error yields an empty list.
        """
        try:
            # Safely extract and validate confidence_scores
            confidence_scores = safety_analysis.get("confidence_scores", {})
            if not isinstance(confidence_scores, dict):
                confidence_scores = {}

            # Safely extract detected_issues
            detected_issues = safety_analysis.get("detected_issues", [])
            if not isinstance(detected_issues, list):
                detected_issues = []

            def over_threshold(value: Any, category: str) -> bool:
                return isinstance(value, (int, float)) and value > self.safety_thresholds[category]

            warnings = []

            # Toxicity warnings
            if (over_threshold(confidence_scores.get("toxicity"), "toxicity")
                    or over_threshold(safety_analysis.get("toxicity_score"), "toxicity")):
                warnings.append(self.warning_templates["toxicity"])

            # Bias warnings
            if over_threshold(confidence_scores.get("bias"), "bias") or safety_analysis.get("bias_indicators"):
                warnings.append(self.warning_templates["bias"])

            # Privacy warnings
            if over_threshold(confidence_scores.get("privacy"), "privacy") or safety_analysis.get("privacy_concerns"):
                warnings.append(self.warning_templates["privacy"])

            # General safety warning if overall score is low
            overall_score = safety_analysis.get("overall_safety_score", 1.0)
            if isinstance(overall_score, (int, float)) and overall_score < 0.7:
                warnings.append(self.warning_templates["safety"])

            # Add context-specific warnings for detected issues
            for issue in detected_issues:
                if not isinstance(issue, dict):
                    continue
                category = issue.get("category")
                if isinstance(category, str) and category in self.warning_templates:
                    warnings.append(self.warning_templates[category])

            # Deduplicate while preserving first-seen order
            return list(dict.fromkeys(warnings))

        except Exception as e:
            logger.error("Error generating warnings: %s", e, exc_info=True)
            # Return empty list on error
            return []

    def _parse_llm_safety_response(self, response: str) -> Dict[str, Any]:
        """Parse the LLM reply into a safety analysis dict.

        The outermost ``{...}`` span is parsed as JSON and returned as-is when
        valid. Otherwise a keyword heuristic over the reply text produces a
        dict with the same keys as the pattern-based analysis.
        """
        # Try to extract JSON from response
        json_match = re.search(r'\{.*\}', response, re.DOTALL)
        if json_match:
            try:
                parsed = json.loads(json_match.group())
            except json.JSONDecodeError:
                logger.warning("%s Failed to parse LLM safety JSON", self.agent_id)
            else:
                if isinstance(parsed, dict):
                    return parsed

        # Fallback parsing - extract safety info from text
        response_lower = response.lower()

        # Simple safety analysis based on keywords
        toxicity_score = 0.1
        bias_score = 0.1  # noqa: F841 - computed for parity with toxicity, not yet reported
        safety_score = 0.9

        if any(word in response_lower for word in ['toxic', 'harmful', 'dangerous', 'inappropriate']):
            toxicity_score = 0.8
            safety_score = 0.3
        elif any(word in response_lower for word in ['bias', 'discriminatory', 'unfair', 'prejudiced']):
            bias_score = 0.7  # noqa: F841
            safety_score = 0.5

        return {
            "toxicity_score": toxicity_score,
            "bias_indicators": [],
            "privacy_concerns": [],
            "overall_safety_score": safety_score,
            "confidence_scores": {
                "toxicity": 0.7,
                "bias": 0.6,
                "safety": safety_score,
                "privacy": 0.9,
            },
            "detected_issues": [],
            "analysis_method": "llm_parsed",
            "llm_response": response[:200] + "..." if len(response) > 200 else response,
        }

    def _get_fallback_result(self, response: str) -> Dict[str, Any]:
        """Build the result returned when the safety check itself fails.

        Carries the same top-level keys as a normal ``execute`` result
        (including ``confidence_scores``) plus ``error_handled``.
        """
        confidence_scores = {"safety": 0.5}
        return {
            "original_response": response,
            "safety_checked_response": response,
            "warnings": ["🔧 Note: Safety analysis temporarily unavailable"],
            "safety_analysis": {
                "overall_safety_score": 0.5,
                "confidence_scores": dict(confidence_scores),
                "detected_issues": [],
                "analysis_method": "fallback",
            },
            "blocked": False,
            "confidence_scores": confidence_scores,
            "agent_id": self.agent_id,
            "error_handled": True,
        }

    def get_safety_summary(self, analysis_result: Dict[str, Any]) -> str:
        """Generate a user-friendly one-line summary of an ``execute`` result.

        Severity is derived from ``overall_safety_score``: above 0.8 is low,
        above 0.6 is medium, anything else is high.
        """
        warnings = analysis_result.get("warnings", [])
        safety_score = analysis_result.get("safety_analysis", {}).get("overall_safety_score", 1.0)

        if not warnings:
            return "✅ Content appears safe based on automated analysis"

        warning_count = len(warnings)
        if safety_score > 0.8:
            severity = "low"
        elif safety_score > 0.6:
            severity = "medium"
        else:
            severity = "high"

        return f"⚠️ {warning_count} advisory note(s) - {severity} severity"

    async def batch_analyze(self, responses: List[str]) -> List[Dict[str, Any]]:
        """Analyze multiple responses one after another, preserving input order."""
        return [await self.execute(response) for response in responses]


# Factory function for easy instantiation
def create_safety_agent(llm_router=None):
    """Return a new ``SafetyCheckAgent`` bound to the optional ``llm_router``."""
    return SafetyCheckAgent(llm_router)


# Example usage
if __name__ == "__main__":
    # Test the safety agent
    agent = SafetyCheckAgent()

    test_responses = [
        "This is a perfectly normal response with no issues.",
        "Some content that might contain controversial topics.",
        "Discussion about sensitive personal information.",
    ]

    async def test_agent():
        for response in test_responses:
            result = await agent.execute(response)
            print(f"Response: {response[:50]}...")
            print(f"Warnings: {result['warnings']}")
            print(f"Safety Score: {result['safety_analysis']['overall_safety_score']}")
            print("-" * 50)

    asyncio.run(test_agent())