# HonestAI / flask_api_standalone.py
# Author: JatsTheAIGen
# Commit 96e6d20: Add input validation for chat endpoint - length limits and type checking
#!/usr/bin/env python3
"""
Pure Flask API for Hugging Face Spaces
No Gradio - Just Flask REST API
Uses local GPU models for inference
"""
from flask import Flask, request, jsonify
from flask_cors import CORS
import logging
import sys
import os
import asyncio
from pathlib import Path
# Setup logging: timestamped, module-tagged records at INFO level
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Add project root to path so the `src.*` imports inside
# initialize_orchestrator() resolve regardless of the working directory
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
# Create Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all origins (required for browser clients on HF Spaces)
# Global orchestrator state shared by all request handlers:
# `orchestrator` holds the MVPOrchestrator instance once initialized;
# `orchestrator_available` is the readiness flag checked by the endpoints.
orchestrator = None
orchestrator_available = False
def initialize_orchestrator():
    """Set up the global AI orchestrator backed by local GPU models.

    Updates the module-level ``orchestrator`` and ``orchestrator_available``
    globals and returns True on success, False on any failure.
    """
    global orchestrator, orchestrator_available
    separator = "=" * 60
    try:
        logger.info(separator)
        logger.info("INITIALIZING AI ORCHESTRATOR (Local GPU Models)")
        logger.info(separator)
        # Project-local imports are deferred to call time so a broken model
        # stack does not prevent the Flask app itself from starting.
        from src.agents.intent_agent import create_intent_agent
        from src.agents.synthesis_agent import create_synthesis_agent
        from src.agents.safety_agent import create_safety_agent
        from src.agents.skills_identification_agent import create_skills_identification_agent
        from src.llm_router import LLMRouter
        from src.orchestrator_engine import MVPOrchestrator
        from src.context_manager import EfficientContextManager
        logger.info("✓ Imports successful")
        hf_token = os.getenv('HF_TOKEN', '')
        if not hf_token:
            logger.warning("HF_TOKEN not set - API fallback will be used if local models fail")
        # The router is built first: every agent shares this one LLM layer.
        logger.info("Initializing LLM Router with local GPU model loading...")
        llm_router = LLMRouter(hf_token, use_local_models=True)
        logger.info("Initializing Agents...")
        agent_factories = {
            'intent_recognition': create_intent_agent,
            'response_synthesis': create_synthesis_agent,
            'safety_check': create_safety_agent,
            'skills_identification': create_skills_identification_agent,
        }
        agents = {name: build(llm_router) for name, build in agent_factories.items()}
        logger.info("Initializing Context Manager...")
        context_manager = EfficientContextManager(llm_router=llm_router)
        logger.info("Initializing Orchestrator...")
        orchestrator = MVPOrchestrator(llm_router, context_manager, agents)
        orchestrator_available = True
        for line in (
            separator,
            "✓ AI ORCHESTRATOR READY",
            " - Local GPU models enabled",
            " - MAX_WORKERS: 4",
            separator,
        ):
            logger.info(line)
        return True
    except Exception as e:
        logger.error(f"Failed to initialize: {e}", exc_info=True)
        orchestrator_available = False
        return False
# Root endpoint
@app.route('/', methods=['GET'])
def root():
    """Describe the API: identity, readiness flag, and available endpoints."""
    payload = {
        'name': 'AI Assistant Flask API',
        'version': '1.0',
        'status': 'running',
        'orchestrator_ready': orchestrator_available,
    }
    payload['features'] = {
        'local_gpu_models': True,
        'max_workers': 4,
        'hardware': 'NVIDIA T4 Medium',
    }
    payload['endpoints'] = {
        'health': 'GET /api/health',
        'chat': 'POST /api/chat',
        'initialize': 'POST /api/initialize',
    }
    return jsonify(payload)
# Health check
@app.route('/api/health', methods=['GET'])
def health_check():
    """Lightweight readiness probe for load balancers and clients."""
    ready = orchestrator_available
    status = 'healthy' if ready else 'initializing'
    return jsonify({'status': status, 'orchestrator_ready': ready})
# Chat endpoint
@app.route('/api/chat', methods=['POST'])
def chat():
    """
    Process chat message
    POST /api/chat
    {
        "message": "user message",
        "history": [[user, assistant], ...],
        "session_id": "session-123",
        "user_id": "user-456"
    }
    Returns:
    {
        "success": true,
        "message": "AI response",
        "history": [...],
        "reasoning": {...},
        "performance": {...}
    }
    Error responses: 400 (invalid input), 503 (orchestrator not ready),
    500 (unexpected failure) — all as JSON with success=false.
    """
    try:
        # silent=True: malformed JSON yields None and our uniform JSON 400
        # below, instead of Flask's default HTML BadRequest page.
        data = request.get_json(silent=True)
        if not data or 'message' not in data:
            return jsonify({
                'success': False,
                'error': 'Message is required'
            }), 400
        message = data['message']
        # Input validation
        if not isinstance(message, str):
            return jsonify({
                'success': False,
                'error': 'Message must be a string'
            }), 400
        # Strip whitespace and validate
        message = message.strip()
        if not message:
            return jsonify({
                'success': False,
                'error': 'Message cannot be empty'
            }), 400
        # Length limit (prevent abuse)
        MAX_MESSAGE_LENGTH = 10000  # 10KB limit
        if len(message) > MAX_MESSAGE_LENGTH:
            return jsonify({
                'success': False,
                'error': f'Message too long. Maximum length is {MAX_MESSAGE_LENGTH} characters'
            }), 400
        history = data.get('history', [])
        # Validate history type up front: a non-list value would otherwise
        # crash with a 500 at the concatenation step below.
        if not isinstance(history, list):
            return jsonify({
                'success': False,
                'error': 'History must be a list'
            }), 400
        session_id = data.get('session_id')
        user_id = data.get('user_id', 'anonymous')
        logger.info(f"Chat request - User: {user_id}, Session: {session_id}")
        logger.info(f"Message length: {len(message)} chars, preview: {message[:100]}...")
        if not orchestrator_available or orchestrator is None:
            return jsonify({
                'success': False,
                'error': 'Orchestrator not ready',
                'message': 'AI system is initializing. Please try again in a moment.'
            }), 503
        # Set user_id for session tracking
        if session_id:
            orchestrator.set_user_id(session_id, user_id)
        # asyncio.run() creates, runs, and tears down a fresh event loop per
        # request. The previous new_event_loop()/set_event_loop()/close()
        # pattern left the worker thread's "current" loop pointing at a
        # closed loop, breaking any later asyncio use on that thread.
        result = asyncio.run(
            orchestrator.process_request(
                session_id=session_id or f"session-{user_id}",
                user_input=message
            )
        )
        # Extract response — orchestrator may return a dict or a bare value
        if isinstance(result, dict):
            response_text = result.get('response', '')
            reasoning = result.get('reasoning', {})
            performance = result.get('performance', {})
        else:
            response_text = str(result)
            reasoning = {}
            performance = {}
        updated_history = history + [[message, response_text]]
        logger.info(f"✓ Response generated (length: {len(response_text)})")
        return jsonify({
            'success': True,
            'message': response_text,
            'history': updated_history,
            'reasoning': reasoning,
            'performance': performance
        })
    except Exception as e:
        logger.error(f"Chat error: {e}", exc_info=True)
        return jsonify({
            'success': False,
            'error': str(e),
            'message': 'Error processing your request. Please try again.'
        }), 500
# Manual initialization endpoint
@app.route('/api/initialize', methods=['POST'])
def initialize():
    """Re-run orchestrator startup on demand (e.g. after a failed boot)."""
    if initialize_orchestrator():
        return jsonify({
            'success': True,
            'message': 'Orchestrator initialized successfully'
        })
    return jsonify({
        'success': False,
        'message': 'Initialization failed. Check logs for details.'
    }), 500
# Initialize on startup
if __name__ == '__main__':
    separator = "=" * 60
    logger.info(separator)
    logger.info("STARTING PURE FLASK API")
    logger.info(separator)
    # Warm the orchestrator before serving; a failure here is logged and
    # can be retried later via POST /api/initialize.
    initialize_orchestrator()
    port = int(os.getenv('PORT', 7860))
    logger.info(f"Starting Flask on port {port}")
    logger.info("Endpoints available:")
    for route in (" GET /",
                  " GET /api/health",
                  " POST /api/chat",
                  " POST /api/initialize"):
        logger.info(route)
    logger.info(separator)
    # threaded=True: enable threading for concurrent requests
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)