import math
import uuid
from typing import Any, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class VLLMWrapper(BaseLLMWrapper):
"""
Async inference backend based on vLLM.
"""
def __init__(
self,
model: str,
tensor_parallel_size: int = 1,
gpu_memory_utilization: float = 0.9,
temperature: float = 0.0,
top_p: float = 1.0,
topk: int = 5,
**kwargs: Any,
):
super().__init__(temperature=temperature, top_p=top_p, **kwargs)
try:
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
except ImportError as exc:
raise ImportError(
"VLLMWrapper requires vllm. Install it with: uv pip install vllm"
) from exc
self.SamplingParams = SamplingParams
engine_args = AsyncEngineArgs(
model=model,
tensor_parallel_size=int(tensor_parallel_size),
gpu_memory_utilization=float(gpu_memory_utilization),
trust_remote_code=kwargs.get("trust_remote_code", True),
disable_log_stats=False,
)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
self.temperature = temperature
self.top_p = top_p
        self.topk = topk

    @staticmethod
def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
msgs = history or []
lines = []
for m in msgs:
if isinstance(m, dict):
role = m.get("role", "")
content = m.get("content", "")
lines.append(f"{role}: {content}")
else:
lines.append(str(m))
lines.append(prompt)
return "\n".join(lines)
async def generate_answer(
self, text: str, history: Optional[List[str]] = None, **extra: Any
) -> str:
full_prompt = self._build_inputs(text, history)
request_id = f"graphgen_req_{uuid.uuid4()}"
        # When temperature is 0, fall back to neutral sampling parameters
        # (temperature=1.0, top_p=1.0) for this request.
        sp = self.SamplingParams(
            temperature=self.temperature if self.temperature > 0 else 1.0,
            top_p=self.top_p if self.temperature > 0 else 1.0,
            max_tokens=extra.get("max_new_tokens", 512),
        )

        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)

        # Drain the async generator; the last yielded item holds the final output.
        final_output = None
        async for request_output in result_generator:
            final_output = request_output
if not final_output or not final_output.outputs:
return ""
        return final_output.outputs[0].text

    async def generate_topk_per_token(
self, text: str, history: Optional[List[str]] = None, **extra: Any
) -> List[Token]:
full_prompt = self._build_inputs(text, history)
request_id = f"graphgen_topk_{uuid.uuid4()}"
        # Greedy, single-token generation; request the top-k logprobs for the
        # next token (and logprobs for the prompt tokens).
        sp = self.SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=self.topk,
            prompt_logprobs=1,
        )

        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)

        final_output = None
        async for request_output in result_generator:
            final_output = request_output
if (
not final_output
or not final_output.outputs
or not final_output.outputs[0].logprobs
):
return []
        # Top-k candidates for the first (and only) generated token.
        top_logprobs = final_output.outputs[0].logprobs[0]
        candidate_tokens = []
        for _, logprob_obj in top_logprobs.items():
            tok_str = (
                logprob_obj.decoded_token.strip()
                if logprob_obj.decoded_token
                else ""
            )
            prob = float(math.exp(logprob_obj.logprob))
            candidate_tokens.append(Token(tok_str, prob))

        # Most probable candidate first.
        candidate_tokens.sort(key=lambda x: -x.prob)
if candidate_tokens:
main_token = Token(
text=candidate_tokens[0].text,
prob=candidate_tokens[0].prob,
top_candidates=candidate_tokens
)
return [main_token]
        return []

    async def generate_inputs_prob(
self, text: str, history: Optional[List[str]] = None, **extra: Any
) -> List[Token]:
        raise NotImplementedError(
            "VLLMWrapper does not support per-token prompt logprobs yet."
        )
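

# Usage sketch (illustrative, not part of the wrapper): drives the async API above
# from a script. The model name below is an assumption; substitute any model your
# vLLM installation and GPU can load, and adjust BaseLLMWrapper kwargs as needed.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        wrapper = VLLMWrapper(model="Qwen/Qwen2.5-7B-Instruct", temperature=0.0)

        answer = await wrapper.generate_answer("What is a knowledge graph?")
        print(answer)

        tokens = await wrapper.generate_topk_per_token("What is a knowledge graph?")
        print([(t.text, t.prob) for t in tokens])

    asyncio.run(_demo())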