feat: use the cache payload and overwrite the text_generation func
app.py CHANGED
@@ -257,6 +257,279 @@ class InferenceClientUS(InferenceClient):
                     continue
                 raise
 
+    def text_generation(
+        self,
+        prompt: str,
+        *,
+        details: bool = False,
+        stream: bool = False,
+        model: Optional[str] = None,
+        do_sample: bool = False,
+        max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: bool = False,
+        seed: Optional[int] = None,
+        stop_sequences: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
+        watermark: bool = False,
+        decoder_input_details: bool = False,
+    ) -> Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]:
+        """
+        Given a prompt, generate the following text.
+
+        It is recommended to have Pydantic installed in order to get inputs validated. This is preferable as it
+        allows early failures.
+
+        The API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the
+        go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
+        default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
+        not exactly the same. This method is compatible with both approaches but some parameters are only available for
+        `text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process
+        continues correctly.
+
+        To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
+
+        Args:
+            prompt (`str`):
+                Input text.
+            details (`bool`, *optional*):
+                By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens,
+                probabilities, seed, finish reason, etc.). Only available for models running on the
+                `text-generation-inference` backend.
+            stream (`bool`, *optional*):
+                By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of
+                tokens to be returned. Only available for models running on the `text-generation-inference`
+                backend.
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+            do_sample (`bool`):
+                Activate logits sampling.
+            max_new_tokens (`int`):
+                Maximum number of generated tokens.
+            best_of (`int`):
+                Generate best_of sequences and return the one with the highest token logprobs.
+            repetition_penalty (`float`):
+                The parameter for repetition penalty. 1.0 means no penalty. See [this
+                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            return_full_text (`bool`):
+                Whether to prepend the prompt to the generated text.
+            seed (`int`):
+                Random sampling seed.
+            stop_sequences (`List[str]`):
+                Stop generating tokens if a member of `stop_sequences` is generated.
+            temperature (`float`):
+                The value used to modulate the logits distribution.
+            top_k (`int`):
+                The number of highest probability vocabulary tokens to keep for top-k filtering.
+            top_p (`float`):
+                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+                higher are kept for generation.
+            truncate (`int`):
+                Truncate input tokens to the given size.
+            typical_p (`float`):
+                Typical Decoding mass.
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
+            watermark (`bool`):
+                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).
+            decoder_input_details (`bool`):
+                Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
+                into account. Defaults to `False`.
+
+        Returns:
+            `Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]`:
+            Generated text returned from the server:
+            - if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
+            - if `stream=True` and `details=False`, the generated text is returned token by token as an `Iterable[str]`
+            - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.inference._text_generation.TextGenerationResponse`]
+            - if `details=True` and `stream=True`, the generated text is returned token by token as an iterable of [`~huggingface_hub.inference._text_generation.TextGenerationStreamResponse`]
+
+        Raises:
+            `ValidationError`:
+                If input values are not valid. No HTTP call is made to the server.
+            [`InferenceTimeoutError`]:
+                If the model is unavailable or the request times out.
+            `HTTPError`:
+                If the request fails with an HTTP error status code other than HTTP 503.
+
+        Example:
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient()
+
+        # Case 1: generate text
+        >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
+        '100% open source and built to be easy to use.'
+
+        # Case 2: iterate over the generated tokens. Useful for large generation.
+        >>> for token in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True):
+        ...     print(token)
+        100
+        %
+        open
+        source
+        and
+        built
+        to
+        be
+        easy
+        to
+        use
+        .
+
+        # Case 3: get more details about the generation process.
+        >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True)
+        TextGenerationResponse(
+            generated_text='100% open source and built to be easy to use.',
+            details=Details(
+                finish_reason=<FinishReason.Length: 'length'>,
+                generated_tokens=12,
+                seed=None,
+                prefill=[
+                    InputToken(id=487, text='The', logprob=None),
+                    InputToken(id=53789, text=' hugging', logprob=-13.171875),
+                    (...)
+                    InputToken(id=204, text=' ', logprob=-7.0390625)
+                ],
+                tokens=[
+                    Token(id=1425, text='100', logprob=-1.0175781, special=False),
+                    Token(id=16, text='%', logprob=-0.0463562, special=False),
+                    (...)
+                    Token(id=25, text='.', logprob=-0.5703125, special=False)
+                ],
+                best_of_sequences=None
+            )
+        )
+
+        # Case 4: iterate over the generated tokens with more details.
+        # Last object is more complete, containing the full generated text and the finish reason.
+        >>> for details in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True):
+        ...     print(details)
+        ...
+        TextGenerationStreamResponse(token=Token(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(
+            id=25,
+            text='.',
+            logprob=-0.5703125,
+            special=False),
+            generated_text='100% open source and built to be easy to use.',
+            details=StreamDetails(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=12, seed=None)
+        )
+        ```
+        """
+        # NOTE: Text-generation integration is taken from the text-generation-inference project. It has more features
+        # like input/output validation (if Pydantic is installed). See `_text_generation.py` header for more details.
+
+        if decoder_input_details and not details:
+            warnings.warn(
+                "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that"
+                " the output from the server will be truncated."
+            )
+            decoder_input_details = False
+
+        # Validate parameters
+        parameters = TextGenerationParameters(
+            best_of=best_of,
+            details=details,
+            do_sample=do_sample,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+            return_full_text=return_full_text,
+            seed=seed,
+            stop=stop_sequences if stop_sequences is not None else [],
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
+            watermark=watermark,
+            decoder_input_details=decoder_input_details,
+        )
+        request = TextGenerationRequest(inputs=prompt, stream=stream, parameters=parameters)
+        payload = asdict(request)
+
+        # Add the use_cache option so the Inference API regenerates text instead of replaying a cached response.
+        print(f"payload:{payload}")
+        # setdefault guards against the serialized request not containing an "options" key yet.
+        payload.setdefault("options", {})["use_cache"] = False
+
+        # Remove some parameters if not a TGI server
+        if not _is_tgi_server(model):
+            ignored_parameters = []
+            for key in "watermark", "stop", "details", "decoder_input_details":
+                if payload["parameters"][key] is not None:
+                    ignored_parameters.append(key)
+                del payload["parameters"][key]
+            if len(ignored_parameters) > 0:
+                warnings.warn(
+                    "API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
+                    f" {ignored_parameters}.",
+                    UserWarning,
+                )
+            if details:
+                warnings.warn(
+                    "API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will"
+                    " be ignored meaning only the generated text will be returned.",
+                    UserWarning,
+                )
+                details = False
+            if stream:
+                raise ValueError(
+                    "API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream."
+                    " Please pass `stream=False` as input."
+                )
+
+        # Handle errors separately for more precise error messages
+        try:
+            bytes_output = self.post(json=payload, model=model, task="text-generation", stream=stream)  # type: ignore
+        except HTTPError as e:
+            if isinstance(e, BadRequestError) and "The following `model_kwargs` are not used by the model" in str(e):
+                _set_as_non_tgi(model)
+                return self.text_generation(  # type: ignore
+                    prompt=prompt,
+                    details=details,
+                    stream=stream,
+                    model=model,
+                    do_sample=do_sample,
+                    max_new_tokens=max_new_tokens,
+                    best_of=best_of,
+                    repetition_penalty=repetition_penalty,
+                    return_full_text=return_full_text,
+                    seed=seed,
+                    stop_sequences=stop_sequences,
+                    temperature=temperature,
+                    top_k=top_k,
+                    top_p=top_p,
+                    truncate=truncate,
+                    typical_p=typical_p,
+                    watermark=watermark,
+                    decoder_input_details=decoder_input_details,
+                )
+            raise_text_generation_error(e)
+
+        # Parse output
+        if stream:
+            return _stream_text_generation_response(bytes_output, details)  # type: ignore
+
+        data = _bytes_to_dict(bytes_output)[0]
+        return TextGenerationResponse(**data) if details else data["generated_text"]
+
+
 client = InferenceClientUS(
     API_URL,
     headers={"Authorization": f"Bearer {HF_TOKEN}"},
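
The override above mirrors `InferenceClient.text_generation` from `huggingface_hub`; the substantive change is the injected `use_cache` option (plus a debug `print`), which asks the Inference API to regenerate text rather than replay a cached answer for an identical request. Below is a minimal usage sketch of that behavior; the `API_URL` and `HF_TOKEN` literals are placeholders for values app.py defines elsewhere, not values taken from this commit.

```py
# Usage sketch only -- not part of the commit.
# Placeholders: app.py defines its own API_URL and HF_TOKEN before constructing the client.
API_URL = "https://api-inference.huggingface.co/models/your-model-id"  # placeholder
HF_TOKEN = "hf_xxx"  # placeholder secret

client = InferenceClientUS(API_URL, headers={"Authorization": f"Bearer {HF_TOKEN}"})

# Because every payload now carries options={"use_cache": False}, two identical
# sampled requests are generated fresh by the backend and can return different
# text, instead of the second call being answered from the Inference API cache.
for _ in range(2):
    print(
        client.text_generation(
            "Once upon a time, ",
            max_new_tokens=20,
            do_sample=True,
            temperature=0.9,
        )
    )
```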