Alovestocode committed on
Commit
beb83c9
·
verified ·
1 Parent(s): f4c85cb

Fix app.load() - move inside Blocks context

Browse files
Files changed (1) hide show
  1. app.py +44 -45
app.py CHANGED
@@ -411,54 +411,53 @@ with gr.Blocks(title="Router Model API - ZeroGPU") as gradio_app:
411
  inputs=[prompt_input, max_tokens_input, temp_input, top_p_input],
412
  outputs=output,
413
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
 
415
  # Set app to Gradio Blocks for Spaces - ZeroGPU requires Gradio SDK
416
  app = gradio_app
417
 
418
- # Add FastAPI routes directly to Gradio's app after Blocks context exits
419
- # Use load event to ensure app is ready
420
- def add_fastapi_routes():
421
- """Add FastAPI routes to Gradio's underlying FastAPI app."""
422
- try:
423
- from fastapi.responses import JSONResponse
424
-
425
- @app.app.post("/v1/generate")
426
- async def fastapi_generate_endpoint(request):
427
- """FastAPI endpoint mounted in Gradio app."""
428
- from fastapi import Request
429
- try:
430
- data = await request.json()
431
- payload = GeneratePayload(**data)
432
- text = _generate_with_gpu(
433
- prompt=payload.prompt,
434
- max_new_tokens=payload.max_new_tokens or MAX_NEW_TOKENS,
435
- temperature=payload.temperature or DEFAULT_TEMPERATURE,
436
- top_p=payload.top_p or DEFAULT_TOP_P,
437
- )
438
- return JSONResponse(content={"text": text})
439
- except Exception as exc:
440
- from fastapi import HTTPException
441
- raise HTTPException(status_code=500, detail=str(exc))
442
-
443
- @app.app.get("/gradio")
444
- async def fastapi_gradio_ui():
445
- """FastAPI /gradio endpoint."""
446
- return HTMLResponse(interactive_ui())
447
-
448
- @app.app.get("/")
449
- async def fastapi_healthcheck():
450
- """FastAPI healthcheck endpoint."""
451
- return {
452
- "status": "ok",
453
- "model": MODEL_ID,
454
- "strategy": ACTIVE_STRATEGY or "pending",
455
- }
456
- print("FastAPI routes added successfully")
457
- except Exception as e:
458
- print(f"Warning: Could not add FastAPI routes: {e}")
459
-
460
- # Use load event to add routes when app is ready
461
- app.load(add_fastapi_routes)
462
-
463
  if __name__ == "__main__": # pragma: no cover
464
  app.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
411
  inputs=[prompt_input, max_tokens_input, temp_input, top_p_input],
412
  outputs=output,
413
  )
414
+
415
+ # Add FastAPI routes using load event (must be inside Blocks context)
416
+ def add_fastapi_routes():
417
+ """Add FastAPI routes to Gradio's underlying FastAPI app."""
418
+ try:
419
+ from fastapi.responses import JSONResponse
420
+
421
+ @gradio_app.app.post("/v1/generate")
422
+ async def fastapi_generate_endpoint(request):
423
+ """FastAPI endpoint mounted in Gradio app."""
424
+ from fastapi import Request
425
+ try:
426
+ data = await request.json()
427
+ payload = GeneratePayload(**data)
428
+ text = _generate_with_gpu(
429
+ prompt=payload.prompt,
430
+ max_new_tokens=payload.max_new_tokens or MAX_NEW_TOKENS,
431
+ temperature=payload.temperature or DEFAULT_TEMPERATURE,
432
+ top_p=payload.top_p or DEFAULT_TOP_P,
433
+ )
434
+ return JSONResponse(content={"text": text})
435
+ except Exception as exc:
436
+ from fastapi import HTTPException
437
+ raise HTTPException(status_code=500, detail=str(exc))
438
+
439
+ @gradio_app.app.get("/gradio")
440
+ async def fastapi_gradio_ui():
441
+ """FastAPI /gradio endpoint."""
442
+ return HTMLResponse(interactive_ui())
443
+
444
+ @gradio_app.app.get("/")
445
+ async def fastapi_healthcheck():
446
+ """FastAPI healthcheck endpoint."""
447
+ return {
448
+ "status": "ok",
449
+ "model": MODEL_ID,
450
+ "strategy": ACTIVE_STRATEGY or "pending",
451
+ }
452
+ print("FastAPI routes added successfully")
453
+ except Exception as e:
454
+ print(f"Warning: Could not add FastAPI routes: {e}")
455
+
456
+ # Use load event to add routes when app is ready (must be inside Blocks context)
457
+ gradio_app.load(add_fastapi_routes)
458
 
459
  # Set app to Gradio Blocks for Spaces - ZeroGPU requires Gradio SDK
460
  app = gradio_app
461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462
  if __name__ == "__main__": # pragma: no cover
463
  app.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))