Alovestocode committed
Commit 1910748 · verified · 1 parent: c924012

Refactor: Use APIRouter with include_router, improve Gradio UI with status messages, remove HTML console

Files changed (1): app.py (+56, -38)
app.py CHANGED
@@ -5,7 +5,7 @@ from functools import lru_cache
 from typing import List, Optional, Tuple
 
 import torch
-from fastapi import FastAPI, HTTPException
+from fastapi import APIRouter, HTTPException
 from pydantic import BaseModel
 
 try:
@@ -289,14 +289,10 @@ def _generate_with_gpu(
     temperature: float = DEFAULT_TEMPERATURE,
     top_p: float = DEFAULT_TOP_P,
 ) -> str:
-    """Generate function wrapped with ZeroGPU decorator. Must be defined before FastAPI app for ZeroGPU detection."""
+    """Generate function wrapped with ZeroGPU decorator."""
     return _generate(prompt, max_new_tokens, temperature, top_p)
 
 
-fastapi_app = FastAPI(title="Router Model API", version="1.0.0")
-
-
-@fastapi_app.get("/health")
 def healthcheck() -> dict[str, str]:
     return {
         "status": "ok",
@@ -305,16 +301,11 @@ def healthcheck() -> dict[str, str]:
     }
 
 
-@fastapi_app.on_event("startup")
 def warm_start() -> None:
     """Warm start is disabled for ZeroGPU - model loads on first request."""
-    # ZeroGPU functions decorated with @spaces.GPU cannot be called during startup.
-    # They must be called within request handlers. Skip warm start for ZeroGPU.
     print("Warm start skipped for ZeroGPU. Model will load on first request.")
-    return
 
 
-@fastapi_app.post("/v1/generate", response_model=GenerateResponse)
 def generate_endpoint(payload: GeneratePayload) -> GenerateResponse:
     try:
         text = _generate_with_gpu(
@@ -333,15 +324,41 @@ def generate_endpoint(payload: GeneratePayload) -> GenerateResponse:
 # Gradio interface for ZeroGPU detection - ZeroGPU requires Gradio SDK
 import gradio as gr
 
-@spaces.GPU(duration=300)
-def gradio_generate(
+STATUS_IDLE = "Status: awaiting prompt."
+
+
+def _format_status(message: str, *, success: bool) -> str:
+    prefix = "✅" if success else "❌"
+    return f"{prefix} {message}"
+
+
+def gradio_generate_handler(
     prompt: str,
     max_new_tokens: int = MAX_NEW_TOKENS,
     temperature: float = DEFAULT_TEMPERATURE,
     top_p: float = DEFAULT_TOP_P,
-) -> str:
-    """Gradio interface function with GPU decorator for ZeroGPU detection."""
-    return _generate(prompt, max_new_tokens, temperature, top_p)
+) -> tuple[str, str]:
+    """Wrapper used by the Gradio UI with friendly status messages."""
+    if not prompt.strip():
+        return (
+            "ERROR: Prompt must not be empty.",
+            _format_status("Prompt required before generating.", success=False),
+        )
+    try:
+        # Reuse the same GPU-decorated generator as the API so behaviour matches.
+        text = _generate_with_gpu(
+            prompt=prompt,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=top_p,
+        )
+    except Exception as exc:  # pragma: no cover - runtime/hardware dependent
+        print(f"UI generation failed: {exc}")
+        return (
+            f"ERROR: {exc}",
+            _format_status("Generation failed. Check logs for details.", success=False),
+        )
+    return text, _format_status("Plan generated successfully.", success=True)
 
 # Create Gradio Blocks app to mount FastAPI routes properly
 with gr.Blocks(
@@ -428,10 +445,11 @@ with gr.Blocks(
         gr.Markdown("### 📤 Output")
         output = gr.Textbox(
            label="Generated Response",
-            lines=20,
+            lines=22,
            placeholder="Generated response will appear here...",
            show_copy_button=True,
         )
+        status_display = gr.Markdown(STATUS_IDLE)
 
     with gr.Accordion("📚 API Information", open=False):
         gr.Markdown("""
@@ -453,39 +471,39 @@ with gr.Blocks(
 
     # Event handlers
     generate_btn.click(
-        fn=gradio_generate,
+        fn=gradio_generate_handler,
         inputs=[prompt_input, max_tokens_input, temp_input, top_p_input],
-        outputs=output,
+        outputs=[output, status_display],
     )
 
     clear_btn.click(
-        fn=lambda: ("", ""),
-        outputs=[prompt_input, output],
+        fn=lambda: ("", "", STATUS_IDLE),
+        outputs=[prompt_input, output, status_display],
     )
 
 # Note: API routes will be added after Blocks context to avoid interfering with Gradio's static assets
 
+# Attach API routes directly onto Gradio's FastAPI instance
+api_router = APIRouter()
+
+
+@api_router.get("/health")
+def api_health() -> dict[str, str]:
+    return healthcheck()
+
+
+@api_router.post("/v1/generate", response_model=GenerateResponse)
+def api_generate(payload: GeneratePayload) -> GenerateResponse:
+    return generate_endpoint(payload)
+
+
+gradio_app.app.include_router(api_router)
+warm_start()
+
 # Enable queued execution so ZeroGPU can schedule GPU work reliably
 gradio_app.queue(max_size=8)
 
-# Mount FastAPI routes onto Gradio's underlying FastAPI app
-# This allows API endpoints to work alongside Gradio UI
-# We mount FastAPI as a sub-application to avoid conflicts
-try:
-    from starlette.routing import Mount
-    # Mount FastAPI app at root - Starlette will check routes in order
-    # Gradio's routes (like /_app/*) will be checked first, then FastAPI routes
-    gradio_app.app.mount("/", fastapi_app)
-    print("FastAPI routes mounted onto Gradio app successfully")
-except Exception as e:
-    print(f"Warning: Could not mount FastAPI routes: {e}")
-    import traceback
-    traceback.print_exc()
-
-# Set app to Gradio for Spaces compatibility (sdk: gradio requires Gradio app)
-# Spaces will handle running the server automatically
 app = gradio_app
 
 if __name__ == "__main__":  # pragma: no cover
-    # For local testing only - Spaces handles server startup
     app.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
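
After this refactor the JSON API and the Gradio UI share one FastAPI instance, so the routes can be exercised directly once the Space is up. A minimal client sketch, assuming a hypothetical BASE_URL for the Space and that GeneratePayload's fields mirror the keyword arguments passed to _generate_with_gpu (prompt, max_new_tokens, temperature, top_p); the field names are inferred from the diff rather than shown in it:

import requests

BASE_URL = "https://your-space.hf.space"  # hypothetical host; substitute the real Space URL

# Health route registered on the APIRouter attached via include_router.
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

# Generation route; the JSON keys below are assumed to match GeneratePayload.
resp = requests.post(
    f"{BASE_URL}/v1/generate",
    json={
        "prompt": "Write a short haiku about routers.",
        "max_new_tokens": 128,
        "temperature": 0.7,
        "top_p": 0.9,
    },
    timeout=120,  # first call may be slow: the model loads on first request
)
resp.raise_for_status()
print(resp.json())

Compared with the removed gradio_app.app.mount("/", fastapi_app) approach, include_router registers the endpoints on Gradio's existing app rather than nesting a second ASGI application, so there is no catch-all mount competing with Gradio's own routes and static assets.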