|
|
""" |
|
|
Multi-Agent RAG-Enhanced LLM System |
|
|
κ°λ
μ(Supervisor) -> μ°½μμ± μμ±μ(Creative) -> λΉνμ(Critic) -> κ°λ
μ(Final) |
|
|
4λ¨κ³ νμ΄νλΌμΈμ ν΅ν κ³ νμ§ λ΅λ³ μμ± μμ€ν
|
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import asyncio |
|
|
import time |
|
|
from typing import Optional, List, Dict, Any, Tuple |
|
|
from contextlib import asynccontextmanager |
|
|
from datetime import datetime |
|
|
from enum import Enum |
|
|
|
|
|
import requests |
|
|
import uvicorn |
|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel, Field |
|
|
import gradio as gr |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AgentRole(Enum):
    """Identifies the four pipeline agents, in execution order."""

    SUPERVISOR = "supervisor"   # stage 1: outlines the answer's structure
    CREATIVE = "creative"       # stage 2: drafts the answer
    CRITIC = "critic"           # stage 3: reviews the draft
    FINALIZER = "finalizer"     # stage 4: synthesizes the final answer
|
|
|
|
|
|
|
|
class Message(BaseModel):
    """A single chat message in the OpenAI-style ``{role, content}`` shape."""

    role: str     # e.g. "user", "assistant", "system"
    content: str  # message text
|
|
|
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request body for POST /api/chat."""

    # Full conversation; only the last message's content is used as the query
    # (see chat_endpoint).
    messages: List[Message]
    # Default model id served through Fireworks.ai.
    model: str = "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507"
    max_tokens: int = Field(default=4096, ge=1, le=8192)
    temperature: float = Field(default=0.6, ge=0, le=2)
    top_p: float = Field(default=1.0, ge=0, le=1)
    top_k: int = Field(default=40, ge=1, le=100)
    # Whether to ground the answer with Brave web-search results.
    use_search: bool = Field(default=True)
|
|
|
|
|
|
|
|
class AgentResponse(BaseModel):
    """Output of a single agent stage."""

    role: AgentRole  # which agent produced this content
    content: str     # the agent's raw LLM output
    # Never populated in this module; reserved for extra per-stage info.
    metadata: Optional[Dict] = None
|
|
|
|
|
|
|
|
class FinalResponse(BaseModel):
    """Aggregate result of the whole 4-stage pipeline."""

    final_answer: str                     # the finalizer agent's answer
    agent_responses: List[AgentResponse]  # all four stage outputs, in order
    # Simplified Brave search hits that were fed to the agents, if any.
    search_results: Optional[List[Dict]] = None
    processing_time: float                # wall-clock seconds for the pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BraveSearchClient:
    """Thin wrapper around the Brave web-search REST API.

    Degrades gracefully: with no API key, ``search`` returns an empty list
    instead of raising.
    """

    def __init__(self, api_key: Optional[str] = None):
        # Fall back to the environment when no key is passed explicitly.
        self.api_key = api_key or os.getenv("BRAVE_SEARCH_API_KEY")
        if not self.api_key:
            print("β οΈ Warning: Brave Search API key not found. Search disabled.")

        self.base_url = "https://api.search.brave.com/res/v1/web/search"
        if self.api_key:
            self.headers = {
                "Accept": "application/json",
                "X-Subscription-Token": self.api_key
            }
        else:
            self.headers = {}

    def search(self, query: str, count: int = 5) -> List[Dict]:
        """Run a Korean-locale web search and return up to *count* hits.

        Each hit is a plain dict with title/url/description/age keys.
        Any network or parse failure is logged and yields an empty list.
        """
        if not self.api_key:
            return []

        params = {
            "q": query,
            "count": count,
            "text_decorations": False,
            "search_lang": "ko",
            "country": "KR"
        }

        try:
            resp = requests.get(
                self.base_url,
                headers=self.headers,
                params=params,
                timeout=10
            )
            resp.raise_for_status()
            payload = resp.json()

            hits = payload.get("web", {}).get("results", [])
            return [
                {
                    "title": hit.get("title", ""),
                    "url": hit.get("url", ""),
                    "description": hit.get("description", ""),
                    "age": hit.get("age", "")
                }
                for hit in hits[:count]
            ]

        except Exception as e:
            # Best effort: degrade to "no results" on any failure.
            print(f"Search error: {str(e)}")
            return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FireworksClient:
    """Minimal synchronous client for the Fireworks.ai chat-completions API."""

    def __init__(self, api_key: Optional[str] = None):
        """Read the API key from the argument or the FIREWORKS_API_KEY env var.

        Raises:
            ValueError: if no key is available — the client is unusable without one.
        """
        self.api_key = api_key or os.getenv("FIREWORKS_API_KEY")
        if not self.api_key:
            raise ValueError("FIREWORKS_API_KEY is required!")

        self.base_url = "https://api.fireworks.ai/inference/v1/chat/completions"
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

    def chat(self, messages: List[Dict], **kwargs) -> str:
        """Send a chat request and return the first choice's content.

        On any transport/HTTP/JSON error an error *string* (not an exception)
        is returned, so callers never need their own try/except.
        """
        payload = {
            "model": kwargs.get("model", "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507"),
            "messages": messages,
            "max_tokens": kwargs.get("max_tokens", 4096),
            "temperature": kwargs.get("temperature", 0.7),
            "top_p": kwargs.get("top_p", 1.0),
            "top_k": kwargs.get("top_k", 40)
        }

        try:
            # json= lets requests serialize the payload itself (fix: previously
            # used data=json.dumps(...) — equivalent bytes, but this is the
            # idiomatic form and keeps serialization in one place).
            response = requests.post(
                self.base_url,
                headers=self.headers,
                json=payload,
                timeout=60
            )
            response.raise_for_status()
            data = response.json()

            choices = data.get("choices") or []
            if choices:
                return choices[0]["message"]["content"]
            # Mojibake-damaged Korean fallback message, preserved as-is.
            return "μλ΅μ μμ±ν μ μμ΅λλ€."

        except Exception as e:
            # Error-as-string contract; preserved as-is.
            return f"μ€λ₯ λ°μ: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiAgentSystem:
    """4-stage multi-agent processing system.

    Pipeline (one LLM call per stage, each fed the previous stages' output
    plus the formatted web-search context):
      1. SUPERVISOR - outlines the answer's direction and structure
      2. CREATIVE   - drafts the answer from the outline
      3. CRITIC     - reviews the draft and suggests improvements
      4. FINALIZER  - synthesizes everything into the final answer

    NOTE(review): the Korean prompt/format strings below are mojibake-damaged
    in this copy of the file; they are preserved as-is, not re-worded.
    """

    def __init__(self, llm_client: FireworksClient, search_client: BraveSearchClient):
        # The search client is stored but not used by this class; callers
        # perform the search and pass results into process_with_agents().
        self.llm = llm_client
        self.search = search_client
        self.agent_configs = self._initialize_agent_configs()

    def _initialize_agent_configs(self) -> Dict:
        """Return per-agent settings: sampling temperature + system prompt.

        Temperatures are deliberate: low for analytical roles (supervisor 0.3,
        critic 0.2), high for the creative drafter (0.9), middling for the
        finalizer (0.5).
        """
        return {
            AgentRole.SUPERVISOR: {
                "temperature": 0.3,
                "system_prompt": """λΉμ μ κ°λμ μμ΄μ νΈμλλ€.
μ¬μ©μμ μ§λ¬Έκ³Ό κ²μ κ²°κ³Όλ₯Ό λΆμνμ¬ λ΅λ³μ μ 체μ μΈ λ°©ν₯μ±κ³Ό ꡬ쑰λ₯Ό μ μν΄μΌ ν©λλ€.

μν :
1. μ§λ¬Έμ ν΅μ¬ μλ νμ
2. κ²μ κ²°κ³Όμμ ν΅μ¬ μ 보 μΆμΆ
3. λ΅λ³μ΄ ν¬ν¨ν΄μΌ ν μ£Όμ μμλ€ μ μ
4. λΌλ¦¬μ νλ¦κ³Ό ꡬ쑰 μ μ

μΆλ ₯ νμ:
- μ§λ¬Έ λΆμ: [ν΅μ¬ μλ]
- μ£Όμ ν¬ν¨ μ¬ν: [νλͺ©λ€]
- λ΅λ³ ꡬ쑰: [λΌλ¦¬μ νλ¦]
- κ²μ κ²°κ³Ό νμ© λ°©μ: [μ΄λ€ μ 보λ₯Ό μ΄λ»κ² νμ©ν μ§]"""
            },

            AgentRole.CREATIVE: {
                "temperature": 0.9,
                "system_prompt": """λΉμ μ μ°½μμ± μμ±μ μμ΄μ νΈμλλ€.
κ°λμμ μ§μΉ¨μ λ°νμΌλ‘ μ°½μμ μ΄κ³ ν₯λ―Έλ‘μ΄ λ΅λ³μ μμ±ν΄μΌ ν©λλ€.

μν :
1. κ°λμμ ꡬ쑰λ₯Ό λ°λ₯΄λ μ°½μμ μΌλ‘ νμ₯
2. μμ, λΉμ , μ€ν 리νλ§ νμ©
3. μ¬μ©μ κ΄μ μμ μ΄ν΄νκΈ° μ¬μ΄ μ€λͺμΆκ°
4. μ€μ©μ μ΄κ³ ꡬ체μ μΈ μ‘°μΈ ν¬ν¨
5. λμ°½μ μΈ κ΄μ κ³Ό ν΅μ°° μ 곡

μ£Όμμ¬ν:
- μ νμ±μ ν΄μΉμ§ μλ μ μμ μ°½μμ± λ°ν
- κ²μ κ²°κ³Όλ₯Ό μ°½μμ μΌλ‘ μ¬κ΅¬μ±
- μ¬μ©μ μ°Έμ¬λ₯Ό μ λνλ λ΄μ© ν¬ν¨"""
            },

            AgentRole.CRITIC: {
                "temperature": 0.2,
                "system_prompt": """λΉμ μ λΉνμ μμ΄μ νΈμλλ€.
μ°½μμ± μμ±μμ λ΅λ³μ κ²ν νκ³ κ°μ μ μ μ μν΄μΌ ν©λλ€.

μν :
1. μ¬μ€ κ΄κ³ κ²μ¦
2. λΌλ¦¬μ μΌκ΄μ± νμΈ
3. μ€ν΄μ μμ§κ° μλ νν μ§μ
4. λλ½λ μ€μ μ 보 νμΈ
5. κ°μ λ°©ν₯ ꡬ체μ μ μ

νκ° κΈ°μ€:
- μ νμ±: μ¬μ€κ³Ό λ°μ΄ν°μ μ νμ±
- μμ μ±: μ§λ¬Έμ λν μΆ©λΆν λ΅λ³ μ¬λΆ
- λͺνμ±: μ΄ν΄νκΈ° μ¬μ΄ μ€λͺμΈμ§
- μ μ©μ±: μ€μ λ‘ λμμ΄ λλ μ 보μΈμ§
- μ λ’°μ±: κ²μ¦ κ°λ₯ν μΆμ² ν¬ν¨ μ¬λΆ

μΆλ ₯ νμ:
β κΈμ μ μΈ‘λ©΄: [μλ μ λ€]
β οΈ κ°μ νμ: [λ¬Έμ μ κ³Ό κ°μ λ°©μ]
π‘ μΆκ° μ μ: [보μν λ΄μ©]"""
            },

            AgentRole.FINALIZER: {
                "temperature": 0.5,
                "system_prompt": """λΉμ μ μ΅μ’κ°λμμλλ€.
λͺ¨λ μμ΄μ νΈμ μ견μ μ’ν©νμ¬ μ΅μ’λ΅λ³μ μμ±ν΄μΌ ν©λλ€.

μν :
1. μ°½μμ± μμ±μμ λ΅λ³μ κΈ°λ°μΌλ‘
2. λΉνμμ νΌλλ°±μ λ°μνμ¬
3. κ°λμμ μ΄κΈ° ꡬ쑰λ₯Ό μ μ§νλ©°
4. λΌλ¦¬μ μ΄κ³ μ΄ν΄νκΈ° μ¬μ΄ μ΅μ’λ΅λ³ μμ±

μ΅μ’λ΅λ³ κΈ°μ€:
- μ νμ±κ³Ό μ°½μμ±μ κ· ν
- λͺνν ꡬ쑰μ λΌλ¦¬μ νλ¦
- μ€μ©μ μ΄κ³ μ μ©ν μ 보
- μ¬μ©μ μΉνμ μΈ ν€
- κ²μ κ²°κ³Ό μΆμ² λͺμ

λ°λμ ν¬ν¨ν μμ:
1. ν΅μ¬ λ΅λ³ (μ§μ μ μΈ μλ΅)
2. μμΈ μ€λͺ(λ°°κ²½κ³Ό λ§₯λ½)
3. μ€μ©μ μ‘°μΈ (ν΄λΉ μ)
4. μ°Έκ³ μλ£ (κ²μ κ²°κ³Ό κΈ°λ°)"""
            }
        }

    def _format_search_results(self, results: List[Dict]) -> str:
        """Render search hits into a plain-text block for prompt injection."""
        if not results:
            return "κ²μ κ²°κ³Ό μμ"

        formatted = []
        # 1-based numbering so the LLM can cite "[κ²μκ²°κ³Ό N]".
        for i, result in enumerate(results, 1):
            formatted.append(f"""
[κ²μκ²°κ³Ό {i}]
μ λͺ©: {result.get('title', 'N/A')}
URL: {result.get('url', 'N/A')}
λ΄μ©: {result.get('description', 'N/A')}
κ²μ: {result.get('age', 'N/A')}""")

        return "\n".join(formatted)

    async def process_with_agents(
        self,
        query: str,
        search_results: List[Dict],
        config: Dict
    ) -> FinalResponse:
        """Run the 4-stage pipeline sequentially and return a FinalResponse.

        config: only "max_tokens" is read here; per-stage temperatures come
        from self.agent_configs.

        NOTE(review): declared async, but every self.llm.chat call is a
        blocking `requests` call — this blocks the event loop for the whole
        pipeline. Consider `await asyncio.to_thread(self.llm.chat, ...)`.
        Left unchanged here.
        """

        start_time = time.time()
        agent_responses = []
        search_context = self._format_search_results(search_results)

        # --- Stage 1: SUPERVISOR — outline direction and structure ---
        supervisor_prompt = f"""
μ¬μ©μ μ§λ¬Έ: {query}

κ²μ κ²°κ³Ό:
{search_context}

μ μ 보λ₯Ό λ°νμΌλ‘ λ΅λ³μ λ°©ν₯μ±κ³Ό ꡬ쑰λ₯Ό μ μνμΈμ."""

        supervisor_response = self.llm.chat(
            messages=[
                {"role": "system", "content": self.agent_configs[AgentRole.SUPERVISOR]["system_prompt"]},
                {"role": "user", "content": supervisor_prompt}
            ],
            temperature=self.agent_configs[AgentRole.SUPERVISOR]["temperature"],
            max_tokens=config.get("max_tokens", 1000)
        )

        agent_responses.append(AgentResponse(
            role=AgentRole.SUPERVISOR,
            content=supervisor_response
        ))

        # --- Stage 2: CREATIVE — draft the answer from the outline ---
        creative_prompt = f"""
μ¬μ©μ μ§λ¬Έ: {query}

κ°λμ μ§μΉ¨:
{supervisor_response}

κ²μ κ²°κ³Ό:
{search_context}

μ μ§μΉ¨κ³Ό μ 보λ₯Ό λ°νμΌλ‘ μ°½μμ μ΄κ³ μ μ©ν λ΅λ³μ μμ±νμΈμ."""

        creative_response = self.llm.chat(
            messages=[
                {"role": "system", "content": self.agent_configs[AgentRole.CREATIVE]["system_prompt"]},
                {"role": "user", "content": creative_prompt}
            ],
            temperature=self.agent_configs[AgentRole.CREATIVE]["temperature"],
            max_tokens=config.get("max_tokens", 2000)
        )

        agent_responses.append(AgentResponse(
            role=AgentRole.CREATIVE,
            content=creative_response
        ))

        # --- Stage 3: CRITIC — review the draft ---
        critic_prompt = f"""
μλ³Έ μ§λ¬Έ: {query}

μ°½μμ± μμ±μμ λ΅λ³:
{creative_response}

κ²μ κ²°κ³Ό:
{search_context}

μ λ΅λ³μ κ²ν νκ³ κ°μ μ μ μ μνμΈμ."""

        critic_response = self.llm.chat(
            messages=[
                {"role": "system", "content": self.agent_configs[AgentRole.CRITIC]["system_prompt"]},
                {"role": "user", "content": critic_prompt}
            ],
            temperature=self.agent_configs[AgentRole.CRITIC]["temperature"],
            max_tokens=config.get("max_tokens", 1000)
        )

        agent_responses.append(AgentResponse(
            role=AgentRole.CRITIC,
            content=critic_response
        ))

        # --- Stage 4: FINALIZER — synthesize draft + critique + outline ---
        final_prompt = f"""
μ¬μ©μ μ§λ¬Έ: {query}

μ°½μμ± μμ±μμ λ΅λ³:
{creative_response}

λΉνμμ νΌλλ°±:
{critic_response}

μ΄κΈ° κ°λμ μ§μΉ¨:
{supervisor_response}

κ²μ κ²°κ³Ό:
{search_context}

λͺ¨λ μ견μ μ’ν©νμ¬ μ΅μ’λ΅λ³μ μμ±νμΈμ.
λΉνμμ νΌλλ°±μ λ°μνμ¬ κ°μ λ λ²μ μ λ§λ€μ΄μ£ΌμΈμ."""

        final_response = self.llm.chat(
            messages=[
                {"role": "system", "content": self.agent_configs[AgentRole.FINALIZER]["system_prompt"]},
                {"role": "user", "content": final_prompt}
            ],
            temperature=self.agent_configs[AgentRole.FINALIZER]["temperature"],
            max_tokens=config.get("max_tokens", 3000)
        )

        agent_responses.append(AgentResponse(
            role=AgentRole.FINALIZER,
            content=final_response
        ))

        processing_time = time.time() - start_time

        return FinalResponse(
            final_answer=final_response,
            agent_responses=agent_responses,
            search_results=search_results,
            processing_time=processing_time
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_gradio_interface(multi_agent_system: MultiAgentSystem, search_client: BraveSearchClient):
    """Build the Gradio Blocks UI wired to the multi-agent pipeline.

    Returns the Blocks app; the caller mounts it under the FastAPI app.
    NOTE(review): the Korean UI label strings are mojibake-damaged in this
    copy of the file; they are preserved as-is.
    """

    async def process_query(
        message: str,
        history: List[List[str]],
        use_search: bool,
        show_agent_thoughts: bool,
        search_count: int,
        temperature: float,
        max_tokens: int
    ):
        """Handle one chat turn: search, run the agents, format side panels.

        Returns (cleared textbox, updated history, agent-thoughts markdown,
        search-sources markdown).
        """

        if not message:
            return "", history, "", ""

        try:
            # Optional web search for grounding.
            search_results = []
            if use_search and search_client.api_key:
                search_results = search_client.search(message, count=search_count)

            config = {
                "temperature": temperature,
                "max_tokens": max_tokens
            }

            response = await multi_agent_system.process_with_agents(
                query=message,
                search_results=search_results,
                config=config
            )

            # "Agent thoughts" panel: first 500 chars of each stage's output.
            agent_thoughts = ""
            if show_agent_thoughts:
                agent_thoughts = "## π€ μμ΄μ νΈ μ¬κ³ κ³Όμ \n\n"

                for agent_resp in response.agent_responses:
                    role_emoji = {
                        AgentRole.SUPERVISOR: "π",
                        AgentRole.CREATIVE: "π¨",
                        AgentRole.CRITIC: "π",
                        AgentRole.FINALIZER: "β"
                    }

                    role_name = {
                        AgentRole.SUPERVISOR: "κ°λμ (μ΄κΈ° ꡬ쑰ν)",
                        AgentRole.CREATIVE: "μ°½μμ± μμ±μ",
                        AgentRole.CRITIC: "λΉνμ",
                        AgentRole.FINALIZER: "μ΅μ’κ°λμ"
                    }

                    agent_thoughts += f"### {role_emoji[agent_resp.role]} {role_name[agent_resp.role]}\n"
                    agent_thoughts += f"{agent_resp.content[:500]}...\n\n"

            # Sources panel built from the raw search hits.
            search_display = ""
            if search_results:
                search_display = "## π μ°Έκ³ μλ£\n\n"
                for i, result in enumerate(search_results, 1):
                    search_display += f"**{i}. [{result['title']}]({result['url']})**\n"
                    search_display += f" {result['description'][:100]}...\n\n"

            # Append a processing-time footer to the answer.
            final_answer = response.final_answer
            final_answer += f"\n\n---\nβ±οΈ *μ²λ¦¬ μκ°: {response.processing_time:.2f}μ΄*"

            history.append([message, final_answer])

            return "", history, agent_thoughts, search_display

        except Exception as e:
            # Surface the error in the chat instead of crashing the UI.
            error_msg = f"β μ€λ₯ λ°μ: {str(e)}"
            history.append([message, error_msg])
            return "", history, "", ""

    with gr.Blocks(
        title="Multi-Agent RAG System",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1400px !important;
            margin: auto !important;
        }
        #chatbot {
            height: 600px !important;
        }
        """
    ) as demo:
        gr.Markdown("""
# π§ Multi-Agent RAG System
### 4λ¨κ³ μμ΄μ νΈ νμμ ν΅ν κ³ νμ§ λ΅λ³ μμ±

**μ²λ¦¬ κ³Όμ :** κ°λμ(ꡬ쑰ν) β μ°½μμ± μμ±μ(μ°½μμ λ΅λ³) β λΉνμ(κ²μ¦) β μ΅μ’κ°λμ(μ’ν©)
""")

        with gr.Row():

            # Left column: chat area + collapsible detail panels.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    height=500,
                    label="π¬ λν",
                    elem_id="chatbot"
                )

                msg = gr.Textbox(
                    label="μ§λ¬Έ μλ ₯",
                    placeholder="μ§λ¬Έμ μλ ₯νμΈμ... (λ©ν° μμ΄μ νΈκ° νμνμ¬ λ΅λ³ν©λλ€)",
                    lines=3
                )

                with gr.Row():
                    submit = gr.Button("π μ μ‘", variant="primary")
                    clear = gr.Button("π μ΄κΈ°ν")

                with gr.Accordion("π€ μμ΄μ νΈ μ¬κ³ κ³Όμ ", open=False):
                    agent_thoughts = gr.Markdown()

                with gr.Accordion("π κ²μ μμ€", open=False):
                    search_sources = gr.Markdown()

            # Right column: settings.
            with gr.Column(scale=1):
                gr.Markdown("### βοΈ μ€μ ")

                with gr.Group():
                    use_search = gr.Checkbox(
                        label="π μΉ κ²μ μ¬μ©",
                        value=True
                    )

                    show_agent_thoughts = gr.Checkbox(
                        label="π§ μμ΄μ νΈ μ¬κ³ κ³Όμ νμ",
                        value=True
                    )

                    search_count = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="κ²μ κ²°κ³Ό μ"
                    )

                    temperature = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.6,
                        step=0.1,
                        label="Temperature"
                    )

                    max_tokens = gr.Slider(
                        minimum=500,
                        maximum=4000,
                        value=2000,
                        step=100,
                        label="Max Tokens"
                    )

                gr.Markdown("""
### π μμ€ν μ 보

**μμ΄μ νΈ μν :**
- π **κ°λμ**: ꡬ쑰 μ€κ³
- π¨ **μ°½μμ±**: μ°½μμ μμ±
- π **λΉνμ**: κ²μ¦/κ°μ
- β **μ΅μ’**: μ’ν©/μμ±
""")

        gr.Examples(
            examples=[
                "μμ μ»΄ν¨ν°μ μ리λ₯Ό μ΄λ±νμλ μ΄ν΄ν μ μκ² μ€λͺν΄μ€",
                "2024λ AI κΈ°μ νΈλ λμ λ―Έλ μ λ§μ?",
                "ν¨κ³Όμ μΈ νλ‘κ·Έλλ° νμ΅ λ°©λ²μ λ¨κ³λ³λ‘ μλ €μ€",
                "κΈ°ν λ³νκ° νκ΅ κ²½μ μ λ―ΈμΉλ μν₯ λΆμν΄μ€",
                "μ€ννΈμ μ°½μμ κ³ λ €ν΄μΌ ν ν΅μ¬ μμλ€μ?"
            ],
            inputs=msg
        )

        # Both the button and Enter-in-textbox trigger the same handler.
        submit.click(
            process_query,
            inputs=[msg, chatbot, use_search, show_agent_thoughts,
                    search_count, temperature, max_tokens],
            outputs=[msg, chatbot, agent_thoughts, search_sources]
        )

        msg.submit(
            process_query,
            inputs=[msg, chatbot, use_search, show_agent_thoughts,
                    search_count, temperature, max_tokens],
            outputs=[msg, chatbot, agent_thoughts, search_sources]
        )

        # Reset chat history and both detail panels.
        clear.click(
            lambda: (None, None, None),
            None,
            [chatbot, agent_thoughts, search_sources]
        )

    return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: log a banner on startup and on shutdown."""
    banner = "=" * 60
    print("\n" + banner)
    print("π Multi-Agent RAG System Starting...")
    print(banner)
    yield  # application serves requests here
    print("\nπ Shutting down...")
|
|
|
|
|
|
|
|
# FastAPI application; startup/shutdown banners come from `lifespan`.
app = FastAPI(
    title="Multi-Agent RAG System API",
    description="4-Stage Agent Collaboration System with RAG",
    version="3.0.0",
    lifespan=lifespan
)

# Wide-open CORS for local/demo use.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by the CORS spec (browsers ignore the wildcard with credentials);
# tighten origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)
|
|
|
|
|
|
|
|
# Best-effort global initialization at import time. On failure (most likely a
# missing FIREWORKS_API_KEY — FireworksClient raises ValueError) the server
# still starts; endpoints report "not configured" via /health instead of
# crashing, and __main__ offers interactive key entry.
try:
    llm_client = FireworksClient()
    search_client = BraveSearchClient()
    multi_agent_system = MultiAgentSystem(llm_client, search_client)
except Exception as e:
    print(f"β οΈ Initialization error: {e}")
    llm_client = None
    search_client = None
    multi_agent_system = None
|
|
|
|
|
|
|
|
|
|
|
@app.get("/") |
|
|
async def root(): |
|
|
"""λ£¨νΈ μλν¬μΈνΈ""" |
|
|
return { |
|
|
"name": "Multi-Agent RAG System", |
|
|
"version": "3.0.0", |
|
|
"status": "running", |
|
|
"ui": "http://localhost:8000/ui", |
|
|
"docs": "http://localhost:8000/docs" |
|
|
} |
|
|
|
|
|
|
|
|
@app.post("/api/chat") |
|
|
async def chat_endpoint(request: ChatRequest): |
|
|
"""λ©ν° μμ΄μ νΈ μ±ν
API""" |
|
|
if not multi_agent_system: |
|
|
raise HTTPException(status_code=500, detail="System not initialized") |
|
|
|
|
|
try: |
|
|
|
|
|
search_results = [] |
|
|
if request.use_search and search_client.api_key: |
|
|
last_message = request.messages[-1].content if request.messages else "" |
|
|
search_results = search_client.search(last_message, count=5) |
|
|
|
|
|
|
|
|
response = await multi_agent_system.process_with_agents( |
|
|
query=request.messages[-1].content, |
|
|
search_results=search_results, |
|
|
config={ |
|
|
"temperature": request.temperature, |
|
|
"max_tokens": request.max_tokens |
|
|
} |
|
|
) |
|
|
|
|
|
return response |
|
|
|
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
|
|
|
|
@app.get("/health") |
|
|
async def health_check(): |
|
|
"""ν¬μ€ 체ν¬""" |
|
|
return { |
|
|
"status": "healthy", |
|
|
"timestamp": datetime.now().isoformat(), |
|
|
"services": { |
|
|
"llm": "ready" if llm_client else "not configured", |
|
|
"search": "ready" if search_client and search_client.api_key else "not configured", |
|
|
"multi_agent": "ready" if multi_agent_system else "not configured" |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
# Mount the Gradio UI at /ui only when import-time initialization succeeded;
# otherwise the bare API still serves / and /health with degraded status.
if multi_agent_system:
    gradio_app = create_gradio_interface(multi_agent_system, search_client)
    app = gr.mount_gradio_app(app, gradio_app, path="/ui")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
print(""" |
|
|
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
|
|
β π§ Multi-Agent RAG-Enhanced LLM System π§ β |
|
|
β β |
|
|
β κ°λ
μ β μ°½μμ± μμ±μ β λΉνμ β μ΅μ’
κ°λ
μ β |
|
|
β 4λ¨κ³ νμ
μ ν΅ν κ³ νμ§ λ΅λ³ μμ± β |
|
|
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
|
|
""") |
|
|
|
|
|
|
|
|
if not os.getenv("FIREWORKS_API_KEY"): |
|
|
print("\nβ οΈ FIREWORKS_API_KEYκ° μ€μ λμ§ μμμ΅λλ€.") |
|
|
key = input("Fireworks API Key μ
λ ₯: ").strip() |
|
|
if key: |
|
|
os.environ["FIREWORKS_API_KEY"] = key |
|
|
llm_client = FireworksClient(key) |
|
|
|
|
|
if not os.getenv("BRAVE_SEARCH_API_KEY"): |
|
|
print("\nβ οΈ BRAVE_SEARCH_API_KEYκ° μ€μ λμ§ μμμ΅λλ€.") |
|
|
print(" (μ νμ¬ν: κ²μ κΈ°λ₯μ μ¬μ©νλ €λ©΄ μ
λ ₯)") |
|
|
key = input("Brave Search API Key μ
λ ₯ (Enter=건λλ°κΈ°): ").strip() |
|
|
if key: |
|
|
os.environ["BRAVE_SEARCH_API_KEY"] = key |
|
|
search_client = BraveSearchClient(key) |
|
|
|
|
|
|
|
|
if llm_client: |
|
|
multi_agent_system = MultiAgentSystem(llm_client, search_client) |
|
|
gradio_app = create_gradio_interface(multi_agent_system, search_client) |
|
|
app = gr.mount_gradio_app(app, gradio_app, path="/ui") |
|
|
|
|
|
print("\n" + "="*60) |
|
|
print("β
μμ€ν
μ€λΉ μλ£!") |
|
|
print("="*60) |
|
|
print("\nπ μ μ μ£Όμ:") |
|
|
print(" π¨ Gradio UI: http://localhost:8000/ui") |
|
|
print(" π API Docs: http://localhost:8000/docs") |
|
|
print(" π§ Chat API: POST http://localhost:8000/api/chat") |
|
|
print("\nπ‘ Ctrl+Cλ₯Ό λλ¬ μ’
λ£") |
|
|
print("="*60 + "\n") |
|
|
|
|
|
uvicorn.run( |
|
|
app, |
|
|
host="0.0.0.0", |
|
|
port=8000, |
|
|
reload=False, |
|
|
log_level="info" |
|
|
) |