import logging

from langchain import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts import load_prompt

from hugginggpt.exceptions import TaskPlanningException, wrap_exceptions
from hugginggpt.history import ConversationHistory
from hugginggpt.llm_factory import LLM_MAX_TOKENS, count_tokens
from hugginggpt.resources import get_prompt_resource
from hugginggpt.task_parsing import Task, parse_tasks

logger = logging.getLogger(__name__)

MAIN_PROMPT_TOKENS = 800
MAX_HISTORY_TOKENS = LLM_MAX_TOKENS - MAIN_PROMPT_TOKENS
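# For example, assuming LLM_MAX_TOKENS = 4096 (a common context window for the
# models this project targets; the actual value comes from hugginggpt.llm_factory),
# the history budget is 4096 - 800 = 3296 tokens.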

# Surface any failure during planning as a TaskPlanningException
@wrap_exceptions(TaskPlanningException, "Failed to plan tasks")
def plan_tasks(
    user_input: str, history: ConversationHistory, llm: BaseLLM
) -> list[Task]:
    """Use the LLM agent to plan tasks in order to solve the user's request."""
    logger.info("Starting task planning")
    task_planning_prompt_template = load_prompt(
        get_prompt_resource("task-planning-few-shot-prompt.json")
    )
    llm_chain = LLMChain(prompt=task_planning_prompt_template, llm=llm)
    history_truncated = truncate_history(history)
    output = llm_chain.predict(
        user_input=user_input, history=history_truncated, stop=["<im_end>"]
    )
    logger.info(f"Task planning raw output: {output}")
    tasks = parse_tasks(output)
    return tasks
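
# Illustrative only: the few-shot prompt steers the LLM towards emitting a JSON
# list of tasks, e.g. (hypothetical raw output for "describe image.jpg"):
#   [{"task": "image-to-text", "id": 0, "dep": [-1], "args": {"image": "image.jpg"}}]
# which parse_tasks then converts into Task objects.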

def truncate_history(history: ConversationHistory) -> ConversationHistory:
    """Truncate history to fit within the max token limit for the task planning LLM."""
    example_prompt_template = load_prompt(
        get_prompt_resource("task-planning-example-prompt.json")
    )
    token_counter = 0
    n_messages = 0
    # Iterate through history backwards in pairs, so the most recent messages are
    # kept. Assumes history strictly alternates user/assistant messages and
    # therefore has even length.
    for i in range(0, len(history), 2):
        user_message = history[-(i + 2)]
        assistant_message = history[-(i + 1)]
        # Render the message pair with the example prompt template, to count the
        # tokens it will occupy in the final prompt
        history_text = example_prompt_template.format(
            example_input=user_message["content"],
            example_output=assistant_message["content"],
        )
        n_tokens = count_tokens(history_text)
        if token_counter + n_tokens <= MAX_HISTORY_TOKENS:
            n_messages += 2
            token_counter += n_tokens
        else:
            break
    # Keep only the most recent n_messages that fit within the token budget
    start = len(history) - n_messages
    return history[start:]
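
# Minimal usage sketch, not part of the module itself. Assumes
# ConversationHistory() constructs an empty history and that an OpenAI API key
# is configured; OpenAI here is just one convenient BaseLLM implementation.
if __name__ == "__main__":
    from langchain.llms import OpenAI

    llm = OpenAI(temperature=0)
    history = ConversationHistory()
    tasks = plan_tasks("Generate an image of a cat wearing a hat", history, llm)
    for task in tasks:
        print(task)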