|
|
"""Persistent Audit System for Guardrails MCP""" |
|
|
|
|
|
import hashlib
import json
import secrets
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional, Dict, Any, List
|
|
|
|
|
# Audit database lives one directory above this module's package directory.
DB_PATH = Path(__file__).parent.parent / "audit_logs.db"
|
|
|
|
|
def init_database() -> None:
    """Create the audit_logs table and its indexes if they do not exist.

    Opens a short-lived connection to DB_PATH, creates the schema
    idempotently (CREATE ... IF NOT EXISTS), enables WAL journal mode for
    better concurrent reader/writer behavior, and always closes the
    connection — even if schema creation raises.
    """
    conn = sqlite3.connect(str(DB_PATH))
    try:
        cursor = conn.cursor()

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS audit_logs (
                id TEXT PRIMARY KEY,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                tool_name TEXT NOT NULL,
                agent_id TEXT,
                input_hash TEXT,
                input_summary TEXT,
                result_summary TEXT,
                risk_level TEXT,
                decision TEXT,
                detection_details JSON,
                session_id TEXT,
                ip_address TEXT,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # Indexes for the filter columns used by query_audit_logs.
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_timestamp ON audit_logs(timestamp)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_agent_id ON audit_logs(agent_id)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_risk_level ON audit_logs(risk_level)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_tool_name ON audit_logs(tool_name)")

        # WAL mode lets readers proceed while a writer holds the log.
        cursor.execute("PRAGMA journal_mode=WAL")

        conn.commit()
    finally:
        # Previously the connection leaked if any execute() raised.
        conn.close()
|
|
|
|
|
def generate_audit_id(tool_prefix: str) -> str:
    """Generate a unique audit ID like 'inj_20251126_143022_abc123'.

    Args:
        tool_prefix: Short tag identifying the tool (e.g. 'inj').

    Returns:
        '<prefix>_<UTC yyyymmdd_HHMMSS>_<6 random hex chars>'.

    The suffix comes from ``secrets.token_hex`` — the previous MD5 of the
    current timestamp produced identical suffixes for calls landing on the
    same timestamp tick. Also uses timezone-aware UTC, since
    ``datetime.utcnow()`` is deprecated.
    """
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    random_suffix = secrets.token_hex(3)  # 3 bytes -> 6 hex characters
    return f"{tool_prefix}_{timestamp}_{random_suffix}"
|
|
|
|
|
def log_to_db(
    audit_id: str,
    tool_name: str,
    input_data: Dict[str, Any],
    result: Dict[str, Any],
    agent_id: Optional[str] = None,
    session_id: Optional[str] = None,
    ip_address: Optional[str] = None
) -> None:
    """Write one audit entry to the SQLite database.

    Best-effort: any error is printed and swallowed so audit logging never
    breaks the calling tool. The connection is now always closed — the
    original leaked it whenever an exception fired mid-write.

    Args:
        audit_id: Primary key, e.g. from generate_audit_id().
        tool_name: Name of the guardrail tool that produced the result.
        input_data: Raw tool input; hashed and summarized, not stored whole.
        result: Tool output dict; stored as JSON in detection_details.
        agent_id / session_id / ip_address: Optional caller metadata.
    """
    conn = None
    try:
        conn = sqlite3.connect(str(DB_PATH))
        cursor = conn.cursor()

        # Stable content hash of the full input for dedup/forensics without
        # persisting the raw payload.
        input_str = json.dumps(input_data, sort_keys=True)
        input_hash = hashlib.sha256(input_str.encode()).hexdigest()

        # Compact, human-scannable summaries; input capped at 200 chars.
        input_summary = str(input_data.get('input_text', input_data.get('action', '')))[:200]
        result_summary = str(result.get('decision', result.get('recommendation', '')))
        risk_level = result.get('risk_level', result.get('severity', 'unknown'))
        decision = result.get('decision', result.get('recommendation', ''))

        cursor.execute("""
            INSERT INTO audit_logs
            (id, tool_name, agent_id, input_hash, input_summary, result_summary,
             risk_level, decision, detection_details, session_id, ip_address)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            audit_id,
            tool_name,
            agent_id,
            input_hash,
            input_summary,
            result_summary,
            risk_level,
            decision,
            json.dumps(result),
            session_id,
            ip_address,
        ))

        conn.commit()
    except Exception as e:
        # Deliberate best-effort: report and continue rather than crash.
        print(f"Error logging to database: {e}")
    finally:
        if conn is not None:
            conn.close()
|
|
|
|
|
def query_audit_logs(
    count: int = 50,
    tool_name: Optional[str] = None,
    risk_level: Optional[str] = None,
    agent_id: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Query recent audit logs, newest first, with optional equality filters.

    Args:
        count: Maximum number of rows to return.
        tool_name / risk_level / agent_id: When set, restrict results to
            rows matching that value exactly.

    Returns:
        A list of row dicts (detection_details parsed from JSON), or []
        on any error (the error is printed, not raised).
    """
    conn = None
    try:
        conn = sqlite3.connect(str(DB_PATH))
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()

        # Build the WHERE clause incrementally; parameters stay bound (?)
        # so no value is ever interpolated into the SQL string.
        query = "SELECT * FROM audit_logs WHERE 1=1"
        params: List[Any] = []

        if tool_name:
            query += " AND tool_name = ?"
            params.append(tool_name)

        if risk_level:
            query += " AND risk_level = ?"
            params.append(risk_level)

        if agent_id:
            query += " AND agent_id = ?"
            params.append(agent_id)

        query += " ORDER BY timestamp DESC LIMIT ?"
        params.append(count)

        cursor.execute(query, params)
        rows = cursor.fetchall()

        return [
            {
                'id': row['id'],
                'timestamp': row['timestamp'],
                'tool_name': row['tool_name'],
                'agent_id': row['agent_id'],
                'input_summary': row['input_summary'],
                'result_summary': row['result_summary'],
                'risk_level': row['risk_level'],
                'decision': row['decision'],
                'detection_details': json.loads(row['detection_details']) if row['detection_details'] else {},
            }
            for row in rows
        ]
    except Exception as e:
        # Best-effort read path: report and return an empty result set.
        print(f"Error querying audit logs: {e}")
        return []
    finally:
        # Previously the connection leaked on the error path.
        if conn is not None:
            conn.close()
|
|
|
|
|
|
|
|
def get_recent_audit_logs(limit: int = 100, **kwargs) -> List[Dict[str, Any]]:
    """Fetch the most recent audit entries.

    Thin alias over :func:`query_audit_logs`: ``limit`` maps onto its
    ``count`` parameter and any extra keyword filters (tool_name,
    risk_level, agent_id) are forwarded unchanged.
    """
    return query_audit_logs(count=limit, **kwargs)
|
|
|
|
|
|
|
|
# Ensure the schema exists as an import-time side effect, so any importer
# can log immediately. NOTE(review): this runs on every import of the
# module — confirm that is intended for tools that import without logging.
init_database()
|
|
|