sachin sharma committed on
Commit · da2b98d
Parent(s): 5ddae77

removed verbosity

Browse files:
- app/api/controllers.py +11 -55
- app/api/routes/prediction.py +4 -40
- app/core/app.py +15 -90
- app/services/base.py +7 -109
- app/services/inference.py +43 -152
app/api/controllers.py
CHANGED

@@ -1,11 +1,5 @@
-"""
-Controllers for handling API business logic.
-
-This controller layer orchestrates requests between the API routes and the
-inference service layer. It handles validation and error responses.
-
-The controller is model-agnostic and works with any InferenceService implementation.
-"""
+"""API controllers for request handling and validation."""
+
 from fastapi import HTTPException

 from app.core.logging import logger

@@ -14,66 +8,28 @@ from app.api.models import ImageRequest, PredictionResponse


 class PredictionController:
-    """
-    Controller for ML prediction endpoints.
-
-    This controller works with any InferenceService implementation,
-    making it easy to swap different models without changing the API layer.
-    """
+    """Controller for prediction endpoints."""

     @staticmethod
     async def predict(
         request: ImageRequest,
         service: InferenceService
     ) -> PredictionResponse:
-        """
-        Run inference using the configured model service.
-
-        The controller handles request validation and error handling,
-        while the service handles the actual inference logic.
-
-        Args:
-            request: ImageRequest with base64-encoded image data
-            service: Initialized inference service (can be any model)
-
-        Returns:
-            PredictionResponse with prediction results
-
-        Raises:
-            HTTPException: If service unavailable, invalid input, or inference fails
-        """
+        """Run inference using the configured service."""
         try:
-            if not service:
-                raise HTTPException(
-                    status_code=503,
-                    detail="Service not initialized"
-                )
-
-            if not service.is_loaded:
-                raise HTTPException(
-                    status_code=503,
-                    detail="Model not loaded"
-                )
+            if not service or not service.is_loaded:
+                raise HTTPException(503, "Service not available")

-            # Validate media type
             if not request.image.mediaType.startswith('image/'):
-                raise HTTPException(
-                    status_code=400,
-                    detail=f"Invalid media type: {request.image.mediaType}. Must be image/*"
-                )
+                raise HTTPException(400, f"Invalid media type: {request.image.mediaType}")

-            response = await service.predict(request)
-            return response
+            return await service.predict(request)

         except HTTPException:
             raise
         except ValueError as e:
-            # Service raises ValueError for invalid input
             logger.error(f"Invalid input: {e}")
-            raise HTTPException(
-                status_code=400,
-                detail=str(e)
-            )
+            raise HTTPException(400, str(e))
         except Exception as e:
-            # Unexpected errors
             logger.error(f"Prediction failed: {e}")
-            raise HTTPException(
-                status_code=500,
-                detail="Internal server error"
-            )
+            raise HTTPException(500, "Internal server error")
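A minimal sketch (not part of this commit) of how the condensed error mapping above can be exercised with a stubbed service; `_StubService` and `_demo` are hypothetical names, `PredictionController` is assumed importable from app.api.controllers:

```python
import asyncio

from fastapi import HTTPException

from app.api.controllers import PredictionController


class _StubService:
    """Hypothetical test double standing in for an InferenceService."""

    is_loaded = False  # simulates a service whose model never loaded

    async def predict(self, request):
        raise ValueError("bad image")


async def _demo() -> None:
    try:
        # request is never inspected: the is_loaded check fails first
        await PredictionController.predict(request=None, service=_StubService())
    except HTTPException as exc:
        assert exc.status_code == 503  # mapped to "Service not available"


asyncio.run(_demo())
```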
app/api/routes/prediction.py
CHANGED

@@ -1,9 +1,5 @@
-"""
-ML Prediction routes.
-
-This module defines the HTTP endpoints for running model inference.
-The routes are model-agnostic and work with any InferenceService implementation.
-"""
+"""Prediction API routes."""
+
 from fastapi import APIRouter, Depends

 from app.api.controllers import PredictionController

@@ -20,40 +16,8 @@ async def predict(
     service: InferenceService = Depends(get_inference_service)
 ):
     """
-    Run inference on
+    Run inference on base64-encoded image.

-    The actual model used depends on what was configured during app startup.
-
-    Example Request Body:
-    ```json
-    {
-        "image": {
-            "mediaType": "image/jpeg",
-            "data": "<base64-encoded-image-data>"
-        }
-    }
-    ```
-
-    Example Response:
-    ```json
-    {
-        "prediction": "tabby cat",
-        "confidence": 0.8542,
-        "model": "microsoft/resnet-18",
-        "predicted_label": 281,
-        "mediaType": "image/jpeg"
-    }
-    ```
-
-    Args:
-        request: ImageRequest containing base64-encoded image
-        service: Injected inference service (configured at startup)
-
-    Returns:
-        PredictionResponse with model predictions
-
-    Raises:
-        HTTPException: 400 for invalid input, 503 if service unavailable, 500 for errors
+    Returns prediction, confidence, predicted label, model name, and media type.
     """
     return await PredictionController.predict(request, service)
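For reference, the request/response examples dropped from the route docstring, recast as a hedged client sketch; the `/predict` path and port 8000 are taken from this repo's router example and Settings defaults, while `requests` is an assumed client library, not part of the repo:

```python
import requests

# Request body shape from the removed docstring example
payload = {
    "image": {
        "mediaType": "image/jpeg",
        "data": "<base64-encoded-image-data>",
    }
}

resp = requests.post("http://localhost:8000/predict", json=payload)
print(resp.json())
# Expected response shape (also from the removed docstring):
# {"prediction": "tabby cat", "confidence": 0.8542, "model": "microsoft/resnet-18",
#  "predicted_label": 281, "mediaType": "image/jpeg"}
```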
app/core/app.py
CHANGED

@@ -1,15 +1,5 @@
-"""
-This module consolidates all core application components:
-- Configuration management
-- Global service instance (dependency injection)
-- Application lifecycle (startup/shutdown)
-- FastAPI app creation
-
-By keeping everything in one place, we avoid the complexity of managing
-global variables across multiple modules.
-"""
+"""FastAPI application factory and core infrastructure."""
+
 import warnings
 from contextlib import asynccontextmanager
 from typing import AsyncGenerator, Optional

@@ -25,100 +15,45 @@ from app.api.routes import prediction


 class Settings(BaseSettings):
-    """
-    Application settings with environment variable support.
-
-    Settings can be overridden via environment variables or .env file.
-    """
-    # Basic app settings
-    app_name: str = Field(default="ML Inference Service", description="Application name")
-    app_version: str = Field(default="0.1.0", description="Application version")
-    debug: bool = Field(default=False, description="Debug mode")
+    """Application settings. Override via environment variables or .env file."""
+
+    app_name: str = Field(default="ML Inference Service")
+    app_version: str = Field(default="0.1.0")
+    debug: bool = Field(default=False)
+    host: str = Field(default="0.0.0.0")
+    port: int = Field(default=8000)

     class Config:
-        """Load from .env file if it exists."""
         env_file = ".env"


-# Global settings instance
 settings = Settings()

-
-# Global inference service instance (initialized during startup)
 _inference_service: Optional[InferenceService] = None


 def get_inference_service() -> Optional[InferenceService]:
-    """
-    Get the inference service instance for dependency injection.
-
-    This function is used in FastAPI route handlers via Depends().
-    The service is initialized once during app startup and reused
-    for all requests.
-
-    Returns:
-        The initialized inference service, or None if not yet initialized.
-
-    Example:
-        ```python
-        @router.post("/predict")
-        async def predict(
-            request: ImageRequest,
-            service: InferenceService = Depends(get_inference_service)
-        ):
-            return await service.predict(request)
-        ```
-    """
+    """Get inference service for dependency injection."""
     return _inference_service


 def _set_inference_service(service: InferenceService) -> None:
-    """
-    INTERNAL: Set the global inference service instance.
-
-    Called during application startup to register the service.
-    This is marked as internal (prefixed with _) because it should
-    only be called from the lifespan handler below.
-
-    Args:
-        service: The initialized inference service instance.
-    """
+    """Set inference service. Called internally during startup."""
     global _inference_service
     _inference_service = service


 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
-    """
-    Application lifespan manager.
-
-    Handles startup and shutdown events for the FastAPI application.
-    During startup, it initializes and loads the inference service.
-
-    CUSTOMIZATION POINT FOR GRAD STUDENTS:
-    To use your own model, replace ResNetInferenceService below with
-    your implementation that subclasses InferenceService.
-
-    Example:
-        ```python
-        service = MyCustomService(model_name="my-org/my-model")
-        await service.load_model()
-        _set_inference_service(service)
-        ```
-    """
+    """Application lifecycle: startup/shutdown."""
     logger.info("Starting ML Inference Service...")

     try:
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=FutureWarning)

+            # Replace ResNetInferenceService with your own implementation
+            service = ResNetInferenceService(model_name="microsoft/resnet-18")
             await service.load_model()
             _set_inference_service(service)

@@ -134,17 +69,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:


 def create_app() -> FastAPI:
-    """
-    Create and configure the FastAPI application.
-
-    This is the main entry point for the application. It:
-    1. Creates a FastAPI instance with metadata from settings
-    2. Attaches the lifespan handler for startup/shutdown
-    3. Registers API routes
-
-    Returns:
-        Configured FastAPI application instance.
-    """
+    """Create and configure FastAPI application."""
     app = FastAPI(
         title=settings.app_name,
         description="ML inference service for image classification",

@@ -155,4 +80,4 @@ def create_app() -> FastAPI:

     app.include_router(prediction.router)

     return app
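A hypothetical entry-point sketch (not in this commit) showing how the new host/port settings above could be used to serve the app; uvicorn is an assumed dependency:

```python
import uvicorn

from app.core.app import create_app, settings

app = create_app()

if __name__ == "__main__":
    # Host and port come from the Settings fields added in this commit
    uvicorn.run(app, host=settings.host, port=settings.port)
```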
app/services/base.py
CHANGED

@@ -1,135 +1,33 @@
-"""
-Abstract base class for ML inference services.
-
-This module defines the contract that all inference services must implement.
-Grad students should subclass `InferenceService` and implement the abstract methods
-to integrate their models with the serving infrastructure.
-"""
+"""Abstract base class for ML inference services."""

 from abc import ABC, abstractmethod
 from typing import Generic, TypeVar
-
 from pydantic import BaseModel

-
-# Type variables for request and response models
 TRequest = TypeVar('TRequest', bound=BaseModel)
 TResponse = TypeVar('TResponse', bound=BaseModel)


 class InferenceService(ABC, Generic[TRequest, TResponse]):
     """
-    This class defines the interface that all model serving implementations must follow.
-    By subclassing this and implementing the abstract methods, you can integrate any
-    ML model with the serving infrastructure.
-
-    Type Parameters:
-        TRequest: Pydantic model for input requests (e.g., ImageRequest, TextRequest)
-        TResponse: Pydantic model for prediction responses (e.g., PredictionResponse)
-
-    Example:
-        ```python
-        class MyModelService(InferenceService[MyRequest, MyResponse]):
-            async def load_model(self) -> None:
-                self.model = torch.load("my_model.pt")
-                self._is_loaded = True
-
-            async def predict(self, request: MyRequest) -> MyResponse:
-                # Run inference
-                output = self.model(request.data)
-                return MyResponse(result=output)
-
-            @property
-            def is_loaded(self) -> bool:
-                return self._is_loaded
-        ```
+    Base class for inference services. Subclass this to integrate your model.
+
+    For CPU-intensive inference, offload work to a background thread using
+    asyncio.to_thread() to avoid blocking the event loop.
     """

     @abstractmethod
     async def load_model(self) -> None:
-        """
-        Load the model weights and any required processors/tokenizers.
-
-        This method is called once during application startup (in the lifespan handler).
-        Use this to:
-        - Load model weights from disk
-        - Initialize processors, tokenizers, or other preprocessing components
-        - Set up any required state
-        - Perform model warmup if needed
-
-        Raises:
-            FileNotFoundError: If model files don't exist
-            RuntimeError: If model loading fails
-        """
+        """Load model weights and processors. Called once at startup."""
         pass

     @abstractmethod
     async def predict(self, request: TRequest) -> TResponse:
-        """
-        Run inference on the input request and return a typed response.
-
-        This method is called for each prediction request. It should:
-        1. Extract input data from the request
-        2. Preprocess the input (if needed)
-        3. Run the model inference
-        4. Post-process the output
-        5. Return a Pydantic response model
-
-        Args:
-            request: Input request containing the data to predict on.
-                Type is specified by the TRequest type parameter.
-
-        Returns:
-            Typed Pydantic response model containing predictions.
-            Type is specified by the TResponse type parameter.
-
-        Raises:
-            ValueError: If input data is invalid
-            RuntimeError: If model inference fails
-
-        Important - Background Threading:
-            For CPU-intensive operations (like deep learning inference), you MUST
-            offload computation to a background thread to avoid blocking the event loop.
-
-            Pattern to follow:
-            ```python
-            import asyncio
-
-            def _predict_sync(self, request: TRequest) -> TResponse:
-                # Heavy CPU work here (PyTorch, TensorFlow, etc.)
-                result = self.model(data)
-                return TResponse(result=result)
-
-            async def predict(self, request: TRequest) -> TResponse:
-                # Offload to thread pool
-                return await asyncio.to_thread(self._predict_sync, request)
-            ```
-
-        Why this matters:
-            - Inference can take 1-3+ seconds and will freeze the server
-            - asyncio.to_thread() runs the work in a background thread
-            - The event loop stays responsive to handle other requests
-        """
+        """Run inference and return typed response."""
         pass

     @property
     @abstractmethod
     def is_loaded(self) -> bool:
-        """
-        Check if the model is loaded and ready for inference.
-
-        Returns:
-            True if model is loaded and ready, False otherwise.
-
-        Example:
-            ```python
-            @property
-            def is_loaded(self) -> bool:
-                return self.model is not None and self._is_loaded
-            ```
-        """
+        """Check if model is loaded and ready."""
         pass
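A minimal concrete subclass sketch following the threading pattern described in the removed docstring; EchoRequest, EchoResponse, and EchoService are illustrative names, not part of the repo:

```python
import asyncio

from pydantic import BaseModel

from app.services.base import InferenceService


class EchoRequest(BaseModel):
    text: str


class EchoResponse(BaseModel):
    result: str


class EchoService(InferenceService[EchoRequest, EchoResponse]):
    """Toy service: no real model, just enough to satisfy the contract."""

    def __init__(self) -> None:
        self._is_loaded = False

    async def load_model(self) -> None:
        # A real service would load weights/processors here
        self._is_loaded = True

    def _predict_sync(self, request: EchoRequest) -> EchoResponse:
        # Stand-in for CPU-heavy work (PyTorch, TensorFlow, ...)
        return EchoResponse(result=request.text.upper())

    async def predict(self, request: EchoRequest) -> EchoResponse:
        # Offload to a background thread so the event loop stays responsive
        return await asyncio.to_thread(self._predict_sync, request)

    @property
    def is_loaded(self) -> bool:
        return self._is_loaded
```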
app/services/inference.py
CHANGED

@@ -1,16 +1,5 @@
-"""
-This module provides an EXAMPLE implementation of the InferenceService ABC.
-Grad students should use this as a reference when implementing their own model services.
-
-This example demonstrates:
-- How to load a HuggingFace transformer model
-- How to preprocess image inputs
-- How to return typed Pydantic responses
-- How to use background threading for CPU-intensive inference
-- Proper error handling and logging
-"""
+"""ResNet inference service implementation."""
+
 import os
 import base64
 import asyncio

@@ -25,173 +14,75 @@ from app.api.models import ImageRequest, PredictionResponse


 class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
-    """
-    EXAMPLE: ResNet inference service implementation.
-
-    This is a reference implementation showing how to integrate a HuggingFace
-    image classification model with the serving infrastructure.
-
-    To create your own service:
-    1. Subclass InferenceService[YourRequest, YourResponse]
-    2. Implement load_model() to load your model
-    3. Implement predict() to run inference and return typed response
-    4. Implement the is_loaded property
-
-    This service loads a ResNet-18 model for ImageNet classification.
-    """
+    """ResNet-18 inference service for image classification."""

     def __init__(self, model_name: str = "microsoft/resnet-18"):
-        """
-        Initialize the ResNet service.
-
-        Args:
-            model_name: Model identifier (e.g., "microsoft/resnet-18").
-                Model files must exist in models/{model_name}/ directory.
-                The full org/model structure is preserved.
-
-        Example:
-            For model_name="microsoft/resnet-18", expects files at:
-            models/microsoft/resnet-18/config.json
-            models/microsoft/resnet-18/pytorch_model.bin
-            etc.
-        """
         self.model_name = model_name
         self.model = None
         self.processor = None
         self._is_loaded = False
-
-        # Preserve full org/model path structure
         self.model_path = os.path.join("models", model_name)
+        logger.info(f"Initializing ResNet service: {self.model_path}")

     async def load_model(self) -> None:
-        """
-        Load the ResNet model and processor.
-
-        This method loads the model once during startup and reuses it for all requests.
-        Called by the application lifespan handler.
-        """
         if self._is_loaded:
-            logger.debug("Model already loaded, skipping...")
             return

-        try:
-            if not os.path.exists(self.model_path):
-                raise FileNotFoundError(
-                    f"Model directory not found: {self.model_path}\n"
-                    f"Make sure the model files are downloaded to the correct location."
-                )
-
-            config_path = os.path.join(self.model_path, "config.json")
-            if not os.path.exists(config_path):
-                raise FileNotFoundError(f"Model config not found: {config_path}")
-
-            logger.info(f"Loading ResNet model from: {self.model_path}")
-
-            # Suppress warnings during model loading
-            import warnings
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", category=FutureWarning)
-                warnings.filterwarnings("ignore", message="Could not find image processor class")
-
-                self.processor = AutoImageProcessor.from_pretrained(
-                    self.model_path,
-                    local_files_only=True
-                )
-                self.model = ResNetForImageClassification.from_pretrained(
-                    self.model_path,
-                    local_files_only=True
-                )
-
-            self._is_loaded = True
-            logger.info("ResNet model loaded successfully")
-            logger.info(f"Model architecture: {self.model.config.architectures}")
-            logger.info(f"Model has {len(self.model.config.id2label)} classes")
-
-        except Exception as e:
-            logger.error(f"Failed to load ResNet model: {e}")
-            logger.error(f"Hint: Ensure model files exist at: {self.model_path}")
-            raise
+        if not os.path.exists(self.model_path):
+            raise FileNotFoundError(f"Model not found: {self.model_path}")
+
+        config_path = os.path.join(self.model_path, "config.json")
+        if not os.path.exists(config_path):
+            raise FileNotFoundError(f"Config not found: {config_path}")
+
+        logger.info(f"Loading model from {self.model_path}")
+
+        import warnings
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=FutureWarning)
+            self.processor = AutoImageProcessor.from_pretrained(
+                self.model_path, local_files_only=True
+            )
+            self.model = ResNetForImageClassification.from_pretrained(
+                self.model_path, local_files_only=True
+            )
+
+        self._is_loaded = True
+        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")

     def _predict_sync(self, request: ImageRequest) -> PredictionResponse:
-        """
-        INTERNAL: Synchronous prediction logic that runs in a background thread.
-
-        This method contains all CPU-intensive operations (image decoding,
-        preprocessing, PyTorch inference). It's called from predict() via
-        asyncio.to_thread() to avoid blocking the event loop.
-
-        Returns:
-            PredictionResponse with prediction, confidence, and metadata
-        """
+        """Synchronous inference logic. Runs in background thread."""
+        image_data = base64.b64decode(request.image.data)
+        image = Image.open(BytesIO(image_data))

         if image.mode != 'RGB':
-            logger.debug(f"Converting image from {image.mode} to RGB")
             image = image.convert('RGB')

+        inputs = self.processor(image, return_tensors="pt")
+
+        with torch.no_grad():
+            logits = self.model(**inputs).logits
+
+        predicted_label = logits.argmax(-1).item()
+        predicted_class = self.model.config.id2label[predicted_label]
+        probabilities = torch.nn.functional.softmax(logits, dim=-1)
         confidence = probabilities[0][predicted_label].item()

+        return PredictionResponse(
+            prediction=predicted_class,
+            confidence=round(confidence, 4),
+            model=self.model_name,
             predicted_label=predicted_label,
             mediaType=request.image.mediaType
         )

     async def predict(self, request: ImageRequest) -> PredictionResponse:
-        """
-        Perform inference on an image request.
-
-        This method demonstrates proper async handling for CPU-intensive operations.
-        The actual inference work is offloaded to a background thread using
-        asyncio.to_thread(), which prevents blocking the event loop.
-
-        Args:
-            request: ImageRequest containing base64-encoded image data
-
-        Returns:
-            PredictionResponse with prediction, confidence, and metadata
-
-        Raises:
-            RuntimeError: If model is not loaded
-            ValueError: If image decoding or processing fails
-        """
+        """Run inference with background threading to avoid blocking event loop."""
         if not self._is_loaded:
-            logger.warning("Model not loaded, loading now...")
             await self.load_model()

-        return response
+        return await asyncio.to_thread(self._predict_sync, request)

     @property
     def is_loaded(self) -> bool:
-        """Check if model is loaded."""
         return self._is_loaded