|
|
import os |
|
|
import sys |
|
|
import time |
|
|
from functools import wraps |
|
|
from typing import Any, Literal |
|
|
|
|
|
from gradio import ChatMessage |
|
|
from gradio.components.chatbot import Message |
|
|
|
|
|
# URL suffix for a model's community page.
# NOTE(review): presumably appended to a model's MODEL_HF_URL -- confirm at the call site.
COMMUNITY_POSTFIX_URL = "/discussions"

# When enabled, log_debug() emits output and check_format() validates its input.
DEBUG_MODE = False

# NOTE(review): not read anywhere in the visible part of this module -- confirm usage.
DEBUG_MODEL = False
|
|
|
|
|
# Registry of per-model settings, keyed by the model key that
# get_model_config() receives as ``model_name``.
models_config = {
    "Apriel-1.6-15B-Thinker": {
        # Display name and model card link.
        "MODEL_DISPLAY_NAME": "Apriel-1.6-15B-Thinker",
        "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/Apriel-1.6-15b-Thinker",
        # Serving endpoint settings. get_model_config() requires MODEL_NAME and
        # VLLM_API_URL to be non-empty.
        # NOTE(review): "PLACEHOLDER" values look like stand-ins to be replaced
        # per deployment -- confirm.
        "MODEL_NAME": "PLACEHOLDER",
        "VLLM_API_URL": "PLACEHOLDER",
        "VLLM_API_URL_LIST": None,
        "AUTH_TOKEN": None,
        # Capability flags and sampling settings.
        "REASONING": True,
        "MULTIMODAL": True,
        "TEMPERATURE": 1.0,
        # Markers used around the model output.
        "OUTPUT_TAG_START": "[BEGIN FINAL RESPONSE]",
        "OUTPUT_TAG_END": "",
        "OUTPUT_STOP_TOKEN": "<|end|>",
    },
}
|
|
|
|
|
|
|
|
def get_model_config(model_name: str) -> dict:
    """Return the configuration entry for ``model_name`` from ``models_config``.

    The returned dict is the registry entry itself (not a copy); it is
    annotated in place with a ``MODEL_KEY`` item holding ``model_name``.

    Args:
        model_name: Key of the model in ``models_config``.

    Returns:
        The model's configuration dict, including ``MODEL_KEY``.

    Raises:
        ValueError: If the model is unknown, or its entry has an empty or
            missing ``MODEL_NAME`` or ``VLLM_API_URL``.
    """
    config = models_config.get(model_name)

    # Check existence before writing to the entry: for an unknown model
    # ``config`` is None and item assignment would raise TypeError instead of
    # the intended ValueError.
    if config is None:
        raise ValueError(f"Model {model_name} not found in models_config")

    config["MODEL_KEY"] = model_name

    if not config.get("MODEL_NAME"):
        raise ValueError(f"Model name not found in config for {model_name}")
    if not config.get("VLLM_API_URL"):
        raise ValueError(f"VLLM API URL not found in config for {model_name}")

    return config
|
|
|
|
|
|
|
|
def _log_message(prefix, message, icon=""):
    """Print one timestamped log line to stdout.

    The line has the form ``<local time>: <prefix> [<icon> ]<message>``; the
    icon and its trailing space appear only when ``icon`` is non-empty.

    Args:
        prefix: Level label placed after the timestamp.
        message: Text of the log entry.
        icon: Optional marker printed before the message.
    """
    stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    marker = f"{icon} " if len(icon) > 0 else icon
    print(f"{stamp}: {prefix} {marker}{message}")
|
|
|
|
|
|
|
|
def log_debug(message):
    """Log ``message`` at DEBUG level, but only when ``DEBUG_MODE`` is exactly True."""
    if DEBUG_MODE is not True:
        return
    _log_message("DEBUG", message)
|
|
|
|
|
|
|
|
def log_info(message):
    """Log ``message`` at INFO level (no icon)."""
    _log_message(prefix="INFO ", message=message)
|
|
|
|
|
|
|
|
def log_warning(message):
    """Log ``message`` at WARN level, prefixed with a warning icon."""
    _log_message(prefix="WARN ", message=message, icon="⚠️")
|
|
|
|
|
|
|
|
def log_error(message):
    """Log ``message`` at ERROR level, prefixed with an alert icon."""
    _log_message(prefix="ERROR", message=message, icon="‼️")
|
|
|
|
|
|
|
|
|
|
|
def check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
    """Validate the shape of chat history data; does nothing unless ``DEBUG_MODE`` is set.

    Args:
        messages: Iterable of chat entries to validate.
        type: ``"messages"`` expects each entry to be a dict with ``role`` and
            ``content`` keys, or a ``ChatMessage`` / ``Message`` instance. Any
            other value expects each entry to be a tuple or list of length 2.

    Raises:
        Exception: If any entry does not match the expected format. In
            ``"messages"`` mode the first offending entry found is also
            printed to stderr.
    """
    if not DEBUG_MODE:
        return

    def _is_message(item) -> bool:
        # A role/content dict or one of the gradio message types.
        if isinstance(item, dict) and "role" in item and "content" in item:
            return True
        return isinstance(item, ChatMessage | Message)

    def _is_pair(item) -> bool:
        # A two-element tuple or list.
        return isinstance(item, (tuple, list)) and len(item) == 2

    if type != "messages":
        if all(_is_pair(item) for item in messages):
            return
        raise Exception(
            "Data incompatible with tuples format. Each message should be a list of length 2."
        )

    if all(_is_message(item) for item in messages):
        return

    # Second pass: report the first offending entry before failing.
    first_bad = next(
        ((idx, item) for idx, item in enumerate(messages) if not _is_message(item)),
        None,
    )
    if first_bad is not None:
        idx, item = first_bad
        print(f"_check_format() --> Invalid message at index {idx}: {item}\n", file=sys.stderr)

    raise Exception(
        "Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
    )
|
|
|
|
|
|
|
|
|
|
|
def logged_event_handler(log_msg='', event_handler=None, log_timer=None, clear_timer=False):
    """Wrap ``event_handler`` with debug logging and optional timer steps.

    Args:
        log_msg: Label used in the debug log lines and timer step names.
        event_handler: Callable to wrap; it is invoked with the wrapper's
            positional and keyword arguments.
        log_timer: Optional object providing ``clear()`` and
            ``add_step(str)``. When given, a "Start" step is recorded before
            the handler runs and a "Completed" step after it returns.
        clear_timer: When true (and ``log_timer`` is given), ``log_timer`` is
            cleared before the "Start" step is recorded.

    Returns:
        A wrapper that carries ``event_handler``'s metadata (via
        ``functools.wraps``) and returns the handler's result.
    """
    @wraps(event_handler)
    def wrapped_event_handler(*args, **kwargs):
        if log_timer:
            if clear_timer:
                log_timer.clear()
            # Label the step with the event message, mirroring the
            # "Completed" step recorded after the handler returns.
            log_timer.add_step(f"Start: {log_msg}")
        log_debug(f"::: Before event: {log_msg}")

        result = event_handler(*args, **kwargs)

        if log_timer:
            log_timer.add_step(f"Completed: {log_msg}")
        log_debug(f"::: After event: {log_msg}")

        return result

    return wrapped_event_handler
|
|
|