"""Factory for the OpenAI LLM agents used in a HuggingGPT-style pipeline."""
import logging
from collections import namedtuple

import tiktoken
from langchain.llms import OpenAI

LLM_NAME = "text-davinci-003"
# Encoding for text-davinci-003
ENCODING_NAME = "p50k_base"
ENCODING = tiktoken.get_encoding(ENCODING_NAME)
# Context window (prompt + completion tokens) for text-davinci-003
LLM_MAX_TOKENS = 4096

# Logit bias values as specified in the HuggingGPT paper (Shen et al., 2023)
TASK_PLANNING_LOGIT_BIAS = 0.1
MODEL_SELECTION_LOGIT_BIAS = 5

logger = logging.getLogger(__name__)

LLMs = namedtuple(
    "LLMs",
    [
        "task_planning_llm",
        "model_selection_llm",
        "model_inference_llm",
        "response_generation_llm",
        "output_fixing_llm",
    ],
)


def create_llms() -> LLMs:
    """Create various LLM agents according to the huggingGPT paper's specifications."""
    logger.info(f"Creating {LLM_NAME} LLMs")

    task_parsing_highlight_ids = get_token_ids_for_task_parsing()
    choose_model_highlight_ids = get_token_ids_for_choose_model()

    # Temperature 0 keeps every agent deterministic. A small positive logit
    # bias nudges the planner towards the controlled vocabulary of the
    # task-parsing format without forcing it.
    task_planning_llm = OpenAI(
        model_name=LLM_NAME,
        temperature=0,
        logit_bias={
            token_id: TASK_PLANNING_LOGIT_BIAS
            for token_id in task_parsing_highlight_ids
        },
    )
    # A much stronger bias steers the selector firmly towards the "id" and
    # "reason" keys of the model-choice response format.
    model_selection_llm = OpenAI(
        model_name=LLM_NAME,
        temperature=0,
        logit_bias={
            token_id: MODEL_SELECTION_LOGIT_BIAS
            for token_id in choose_model_highlight_ids
        },
    )
    # The remaining agents need no token biasing.
    model_inference_llm = OpenAI(model_name=LLM_NAME, temperature=0)
    response_generation_llm = OpenAI(model_name=LLM_NAME, temperature=0)
    output_fixing_llm = OpenAI(model_name=LLM_NAME, temperature=0)
    return LLMs(
        task_planning_llm=task_planning_llm,
        model_selection_llm=model_selection_llm,
        model_inference_llm=model_inference_llm,
        response_generation_llm=response_generation_llm,
        output_fixing_llm=output_fixing_llm,
    )


def get_token_ids_for_task_parsing() -> list[int]:
    """Return the token ids of the task-parsing vocabulary: supported task
    names plus the JSON keys of the task-planning output format."""
    text = """{"task": "text-classification",  "token-classification", "text2text-generation", "summarization", "translation",  "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "visual-question-answering", "document-question-answering", "image-segmentation", "text-to-speech", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "args", "text", "path", "dep", "id", "<GENERATED>-"}"""
    res = ENCODING.encode(text)
    # Deduplicate: logit_bias needs each token id at most once.
    return list(set(res))


def get_token_ids_for_choose_model() -> list[int]:
    """Return the token ids of the keys expected in the model-selection response."""
    text = """{"id": "reason"}"""
    res = ENCODING.encode(text)
    return list(set(res))


def count_tokens(text: str) -> int:
    """Count the tokens in text under the text-davinci-003 encoding."""
    return len(ENCODING.encode(text))
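

# Usage sketch, not part of the original module: assumes OPENAI_API_KEY is set
# in the environment and that the account can still call text-davinci-003.
# The prompt string below is purely illustrative.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    llms = create_llms()
    prompt = "Parse this request into tasks: describe the objects in example.jpg"
    # Guard against overflowing the context window before calling the model.
    if count_tokens(prompt) < LLM_MAX_TOKENS:
        print(llms.task_planning_llm(prompt))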