import logging
from collections import namedtuple

import tiktoken
from langchain.llms import OpenAI

LLM_NAME = "text-davinci-003"

# Tokenizer encoding used by text-davinci-003
ENCODING_NAME = "p50k_base"
ENCODING = tiktoken.get_encoding(ENCODING_NAME)

# Maximum context length for text-davinci-003
LLM_MAX_TOKENS = 4096

# Logit bias strengths as specified in the HuggingGPT paper
TASK_PLANNING_LOGIT_BIAS = 0.1
MODEL_SELECTION_LOGIT_BIAS = 5

logger = logging.getLogger(__name__)

# One LLM per stage of the HuggingGPT pipeline.
LLMs = namedtuple(
    "LLMs",
    [
        "task_planning_llm",
        "model_selection_llm",
        "model_inference_llm",
        "response_generation_llm",
        "output_fixing_llm",
    ],
)

def create_llms() -> LLMs:
    """Create the LLM agents according to the HuggingGPT paper's specifications."""
    logger.info(f"Creating {LLM_NAME} LLMs")

    task_parsing_highlight_ids = get_token_ids_for_task_parsing()
    choose_model_highlight_ids = get_token_ids_for_choose_model()

    # Nudge the planner toward the tokens of the task-parsing vocabulary.
    task_planning_llm = OpenAI(
        model_name=LLM_NAME,
        temperature=0,
        logit_bias={
            token_id: TASK_PLANNING_LOGIT_BIAS
            for token_id in task_parsing_highlight_ids
        },
    )
    # Push model selection strongly toward the expected JSON response keys.
    model_selection_llm = OpenAI(
        model_name=LLM_NAME,
        temperature=0,
        logit_bias={
            token_id: MODEL_SELECTION_LOGIT_BIAS
            for token_id in choose_model_highlight_ids
        },
    )
    # The remaining stages use the plain model with no logit bias.
    model_inference_llm = OpenAI(model_name=LLM_NAME, temperature=0)
    response_generation_llm = OpenAI(model_name=LLM_NAME, temperature=0)
    output_fixing_llm = OpenAI(model_name=LLM_NAME, temperature=0)

    return LLMs(
        task_planning_llm=task_planning_llm,
        model_selection_llm=model_selection_llm,
        model_inference_llm=model_inference_llm,
        response_generation_llm=response_generation_llm,
        output_fixing_llm=output_fixing_llm,
    )

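# Debugging sketch, a hypothetical helper that is not part of the original
# module: it maps each biased token id back to its text so the logit_bias
# dicts built above can be inspected by eye. It relies on tiktoken's
# Encoding.decode_single_token_bytes, which returns the raw bytes of one token.
def describe_highlight_ids(token_ids: list[int]) -> dict[int, str]:
    """Map each biased token id to its decoded text, for inspection."""
    return {
        token_id: ENCODING.decode_single_token_bytes(token_id).decode(
            "utf-8", errors="replace"
        )
        for token_id in token_ids
    }
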
def get_token_ids_for_task_parsing() -> list[int]:
    """Token ids for the task-parsing vocabulary: task names plus argument keys."""
    text = """{"task": "text-classification", "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "visual-question-answering", "document-question-answering", "image-segmentation", "text-to-speech", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "args", "text", "path", "dep", "id", "<GENERATED>-"}"""
    # Deduplicate: the same token id can appear in several vocabulary entries.
    return list(set(ENCODING.encode(text)))

def get_token_ids_for_choose_model() -> list[int]:
    """Token ids for the keys expected in the model-selection response."""
    text = """{"id": "reason"}"""
    return list(set(ENCODING.encode(text)))

def count_tokens(text: str) -> int:
    """Return the number of tokens `text` occupies under the model's encoding."""
    return len(ENCODING.encode(text))
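
# Usage sketch: illustrative only, not part of the original module. Running it
# requires an OPENAI_API_KEY in the environment and a pre-0.1 langchain whose
# OpenAI wrapper accepts the model_name and logit_bias keyword arguments.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    llms = create_llms()
    prompt = "Draw a picture of a cat and describe it."
    logger.info(
        f"Prompt uses {count_tokens(prompt)} of {LLM_MAX_TOKENS} tokens; "
        f"planner biases {len(get_token_ids_for_task_parsing())} token ids"
    )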