import math
import uuid
from typing import Any, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class VLLMWrapper(BaseLLMWrapper):
    """
    Async inference backend based on vLLM.
    """

    def __init__(
        self,
        model: str,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        temperature: float = 0.0,
        top_p: float = 1.0,
        topk: int = 5,
        **kwargs: Any,
    ):
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)
        try:
            from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
        except ImportError as exc:
            raise ImportError(
                "VLLMWrapper requires vllm. Install it with: uv pip install vllm"
            ) from exc

        # Keep a handle on SamplingParams so per-request sampling settings can be
        # built in the generate_* methods below.
        self.SamplingParams = SamplingParams

        engine_args = AsyncEngineArgs(
            model=model,
            tensor_parallel_size=int(tensor_parallel_size),
            gpu_memory_utilization=float(gpu_memory_utilization),
            trust_remote_code=kwargs.get("trust_remote_code", True),
            disable_log_stats=False,
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        self.temperature = temperature
        self.top_p = top_p
        self.topk = topk

    @staticmethod
    def _build_inputs(prompt: str, history: Optional[List[Any]] = None) -> str:
        """Flatten an optional chat history plus the new prompt into a single
        newline-separated prompt string. History items may be plain strings or
        OpenAI-style dicts with "role" and "content" keys."""
        msgs = history or []
        lines = []
        for m in msgs:
            if isinstance(m, dict):
                role = m.get("role", "")
                content = m.get("content", "")
                lines.append(f"{role}: {content}")
            else:
                lines.append(str(m))
        lines.append(prompt)
        return "\n".join(lines)

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        full_prompt = self._build_inputs(text, history)
        request_id = f"graphgen_req_{uuid.uuid4()}"

        # Note: vLLM interprets temperature == 0 as greedy decoding; this wrapper
        # instead falls back to plain sampling (temperature=1.0, top_p=1.0) when
        # no temperature is configured.
        sp = self.SamplingParams(
            temperature=self.temperature if self.temperature > 0 else 1.0,
            top_p=self.top_p if self.temperature > 0 else 1.0,
            max_tokens=extra.get("max_new_tokens", 512),
        )

        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)

        final_output = None
        async for request_output in result_generator:
            final_output = request_output

        if not final_output or not final_output.outputs:
            return ""

        return final_output.outputs[0].text

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        full_prompt = self._build_inputs(text, history)
        request_id = f"graphgen_topk_{uuid.uuid4()}"

        # Greedy single-step decode: generate exactly one token and request the
        # top-k candidate logprobs for that position.
        sp = self.SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=self.topk,
            prompt_logprobs=1,
        )

        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)

        final_output = None
        async for request_output in result_generator:
            final_output = request_output

        if (
            not final_output
            or not final_output.outputs
            or not final_output.outputs[0].logprobs
        ):
            return []

        # logprobs[0] maps candidate token ids to their Logprob objects for the
        # single generated position.
        top_logprobs = final_output.outputs[0].logprobs[0]

        candidate_tokens = []
        for _, logprob_obj in top_logprobs.items():
            tok_str = (
                logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
            )
            prob = float(math.exp(logprob_obj.logprob))
            candidate_tokens.append(Token(text=tok_str, prob=prob))

        # Highest-probability candidate first.
        candidate_tokens.sort(key=lambda x: -x.prob)

        if candidate_tokens:
            main_token = Token(
                text=candidate_tokens[0].text,
                prob=candidate_tokens[0].prob,
                top_candidates=candidate_tokens,
            )
            return [main_token]
        return []

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError(
            "VLLMWrapper does not support per-token prompt logprobs yet."
        )
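

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the wrapper). It assumes vLLM
# is installed and a GPU is available; the model identifier below is only an
# example and should be replaced with whatever checkpoint you actually serve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        wrapper = VLLMWrapper(
            model="Qwen/Qwen2.5-0.5B-Instruct",  # example model id, replace as needed
            tensor_parallel_size=1,
            gpu_memory_utilization=0.9,
        )

        answer = await wrapper.generate_answer("What is a knowledge graph?")
        print(answer)

        tokens = await wrapper.generate_topk_per_token("The capital of France is")
        for tok in tokens:
            print(tok.text, tok.prob)

    asyncio.run(_demo())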