sachin sharma committed · Commit 5ddae77 · 1 Parent(s): ebbcd26
refactored codebase
Browse files
- README.md +193 -191
- app/api/controllers.py +44 -40
- app/api/routes/prediction.py +46 -7
- app/api/routes/resnet_service_manager.py +0 -19
- app/core/app.py +138 -5
- app/core/config.py +0 -29
- app/core/dependencies.py +0 -19
- app/core/lifespan.py +0 -43
- app/services/base.py +135 -0
- app/services/inference.py +102 -57
- test_main.http +3 -2
README.md
CHANGED
|
@@ -1,89 +1,94 @@
(Previous README content, shown removed. The old introduction described the bundled ResNet-18 example, listed response fields (`predicted_label` as class index, `model` as model id, `mediaType` echoed), and documented a layout with separate `core/config.py` (settings), `core/dependencies.py` (DI for model services), `core/lifespan.py` (startup: load model & register service), and a legacy, unused `routes/resnet_service_manager.py`. Its configuration example constructed the service with `ResNetInferenceService(model_name=settings.model_name, use_local_model=True)`.)
@@ -94,117 +99,121 @@ curl -X POST http://127.0.0.1:8000/predict/resnet -H "Content-Type: application/json"
@@ -227,58 +236,51 @@ the `ResNetInferenceService` flow.
@@ -321,7 +323,7 @@
@@ -378,7 +380,7 @@ table = pa.table({
@@ -408,12 +410,12 @@ python scripts/test_datasets.py --category edge_case
(The unchanged quickstart commands and PyArrow test-dataset sections appear in full in the updated README below.)
| 1 |
# ML Inference Service (FastAPI)
|
| 2 |
|
| 3 |
+
A FastAPI-based inference server designed to make it easy to serve your ML models. The repo includes a complete working example using ResNet-18 for image classification, but the architecture is built to be model-agnostic. You implement a simple abstract base class, and everything else just works.
|
|
|
|
| 4 |
|
| 5 |
+
Key features:
|
| 6 |
+
- Abstract InferenceService class that you subclass for your model
|
| 7 |
+
- Example ResNet-18 implementation showing how to do it
|
| 8 |
+
- FastAPI application with clean separation (routes → controller → service)
|
| 9 |
+
- Model loaded once at startup and reused across requests
|
| 10 |
+
- Background threading for inference so the server stays responsive
|
| 11 |
+
- Type-safe request/response handling with Pydantic
|
| 12 |
+
- Single generic endpoint that works with any model
|
| 13 |
|
| 14 |
+
## What you get
|
| 15 |
|
| 16 |
+
The service exposes a single endpoint `POST /predict` that accepts a base64-encoded image and returns:
|
| 17 |
+
- `prediction` - the predicted class label
|
| 18 |
+
- `confidence` - softmax probability for the prediction
|
| 19 |
+
- `predicted_label` - numeric class index
|
| 20 |
+
- `model` - identifier for which model produced this prediction
|
| 21 |
+
- `mediaType` - echoed from the request
|
| 22 |
|
| 23 |
+
The inference runs in a background thread using asyncio so long-running model predictions don't block the server from handling other requests.
|
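A minimal sketch of that pattern (the `run_model` function here is a stand-in, not code from this repo):

```python
import asyncio

def run_model(data):
    # Stand-in for a CPU-heavy synchronous inference call (e.g. a PyTorch forward pass)
    return {"prediction": "tiger cat"}

async def predict(data):
    # Offload the blocking work to a worker thread so the event loop stays responsive
    return await asyncio.to_thread(run_model, data)
```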
| 24 |
|
| 25 |
+
## Project Layout
|
| 26 |
|
| 27 |
```
ml-inference-service/
├── main.py                      # Entry point
├── app/
│   ├── core/
│   │   ├── app.py               # Everything: config, DI, lifespan, app factory
│   │   └── logging.py           # Logger setup
│   ├── api/
│   │   ├── models.py            # Pydantic request/response schemas
│   │   ├── controllers.py       # HTTP → service orchestration
│   │   └── routes/
│   │       └── prediction.py    # POST /predict endpoint
│   └── services/
│       ├── base.py              # Abstract InferenceService class
│       └── inference.py         # ResNetInferenceService (example implementation)
├── models/
│   └── microsoft/
│       └── resnet-18/           # Model files (preserves org structure)
├── scripts/
│   ├── generate_test_datasets.py
│   ├── test_datasets.py
│   └── test_datasets/
├── requirements.txt
└── test_main.http               # Example HTTP request
```
|
| 52 |
|
| 53 |
+
The key change from a typical FastAPI app is that `app/core/app.py` consolidates configuration, dependency injection, lifecycle management, and the app factory into one file. This avoids the complexity of managing global variables across multiple modules.
|
| 54 |
|
| 55 |
+
## Quickstart
|
| 56 |
|
| 57 |
+
1) Install dependencies (Python 3.9+)
|
| 58 |
```bash
|
| 59 |
python -m venv .venv
|
| 60 |
source .venv/bin/activate # Windows: .venv\Scripts\activate
|
| 61 |
pip install -r requirements.txt
|
| 62 |
```
|
| 63 |
|
| 64 |
+
2) Download the example model
|
| 65 |
```bash
|
| 66 |
bash scripts/model_download.bash
|
| 67 |
```
|
| 68 |
+
This downloads ResNet-18 from Hugging Face and saves it to `models/microsoft/resnet-18/` (note the org structure is preserved).
|
| 69 |
|
| 70 |
+
3) Run the server
|
| 71 |
```bash
|
| 72 |
uvicorn main:app --reload
|
| 73 |
```
|
| 74 |
+
Server starts on `http://127.0.0.1:8000`.
|
| 75 |
|
| 76 |
+
4) Test the API
|
| 77 |
+
|
| 78 |
+
Use `test_main.http` from your IDE or curl:
|
| 79 |
|
| 80 |
```bash
|
| 81 |
+
curl -X POST http://127.0.0.1:8000/predict \
|
| 82 |
+
-H "Content-Type: application/json" \
|
| 83 |
+
-d '{
|
| 84 |
+
"image": {
|
| 85 |
+
"mediaType": "image/jpeg",
|
| 86 |
+
"data": "<base64-encoded-bytes>"
|
| 87 |
+
}
|
| 88 |
}'
|
| 89 |
```
|
| 90 |
|
| 91 |
+
Example response:
|
| 92 |
```json
|
| 93 |
{
|
| 94 |
"prediction": "tiger cat",
|
|
|
|
| 99 |
}
|
| 100 |
```
|
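If you'd rather call the endpoint from Python than curl, a small client sketch (assuming the `requests` package is installed and a local `cat.jpg` exists) could look like this:

```python
import base64
import requests

# Base64-encode the image file to match the request schema shown above
with open("cat.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

payload = {"image": {"mediaType": "image/jpeg", "data": encoded}}
resp = requests.post("http://127.0.0.1:8000/predict", json=payload)
print(resp.json())  # e.g. {"prediction": "tiger cat", ...}
```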
| 101 |
|
| 102 |
+
## Integrating Your Own Model
|
| 103 |
+
|
| 104 |
+
To use your own model, you implement the `InferenceService` abstract base class. The rest of the infrastructure (API routes, controllers, dependency injection) is already generic and works with any implementation.
|
| 105 |
+
|
| 106 |
+
### Step 1: Implement the InferenceService ABC
|
| 107 |
+
|
| 108 |
+
Create a new file `app/services/your_model_service.py`:
|
| 109 |
+
|
| 110 |
+
```python
|
| 111 |
+
from app.services.base import InferenceService
|
| 112 |
+
from app.api.models import ImageRequest, PredictionResponse
|
| 113 |
+
|
| 114 |
+
class YourModelService(InferenceService[ImageRequest, PredictionResponse]):
|
| 115 |
+
def __init__(self, model_name: str):
|
| 116 |
+
self.model_name = model_name
|
| 117 |
+
self.model_path = os.path.join("models", model_name)
|
| 118 |
+
self.model = None
|
| 119 |
+
self._is_loaded = False
|
| 120 |
+
|
| 121 |
+
async def load_model(self) -> None:
|
| 122 |
+
# Load your model here
|
| 123 |
+
self.model = load_your_model(self.model_path)
|
| 124 |
+
self._is_loaded = True
|
| 125 |
+
|
| 126 |
+
async def predict(self, request: ImageRequest) -> PredictionResponse:
|
| 127 |
+
# Offload to background thread (important for performance)
|
| 128 |
+
return await asyncio.to_thread(self._predict_sync, request)
|
| 129 |
+
|
| 130 |
+
def _predict_sync(self, request: ImageRequest) -> PredictionResponse:
|
| 131 |
+
# Decode image, run inference, return typed response
|
| 132 |
+
image = decode_base64_image(request.image.data)
|
| 133 |
+
result = self.model(image)
|
| 134 |
+
return PredictionResponse(
|
| 135 |
+
prediction=result.label,
|
| 136 |
+
confidence=result.confidence,
|
| 137 |
+
predicted_label=result.class_id,
|
| 138 |
+
model=self.model_name,
|
| 139 |
+
mediaType=request.image.mediaType
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
@property
|
| 143 |
+
def is_loaded(self) -> bool:
|
| 144 |
+
return self._is_loaded
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
The key points:
|
| 148 |
+
- Subclass `InferenceService[RequestType, ResponseType]` with your request/response types
|
| 149 |
+
- Implement three methods: `load_model()`, `predict()`, and `is_loaded` property
|
| 150 |
+
- Use `asyncio.to_thread()` to offload CPU-intensive inference to a background thread
|
| 151 |
+
- Return typed Pydantic models, not dicts
|
| 152 |
+
|
| 153 |
+
### Step 2: Register your service at startup
|
| 154 |
+
|
| 155 |
+
Edit `app/core/app.py` and find the lifespan function (around line 134):
|
| 156 |
+
|
| 157 |
+
```python
|
| 158 |
+
# Replace this:
|
| 159 |
+
service = ResNetInferenceService(model_name="microsoft/resnet-18")
|
| 160 |
+
|
| 161 |
+
# With this:
|
| 162 |
+
service = YourModelService(model_name="your-org/your-model")
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
That's it. The same `/predict` endpoint now serves your model.
|
| 166 |
+
|
| 167 |
+
### Model file structure
|
| 168 |
+
|
| 169 |
+
Your model files should be organized as:
|
| 170 |
+
```
models/
└── your-org/
    └── your-model/
        ├── config.json
        ├── weights.bin
        └── ... other files
```
|
| 178 |
+
|
| 179 |
+
The full org/model structure is preserved - no more dropping the org prefix.
|
| 180 |
+
|
| 181 |
+
### Example: Swapping ResNet for ViT
|
| 182 |
+
|
| 183 |
+
```python
|
| 184 |
+
# app/services/vit_service.py
|
| 185 |
+
from transformers import ViTForImageClassification, ViTImageProcessor
|
| 186 |
+
|
| 187 |
+
class ViTService(InferenceService[ImageRequest, PredictionResponse]):
|
| 188 |
+
async def load_model(self) -> None:
|
| 189 |
+
self.processor = ViTImageProcessor.from_pretrained(self.model_path)
|
| 190 |
+
self.model = ViTForImageClassification.from_pretrained(self.model_path)
|
| 191 |
+
self._is_loaded = True
|
| 192 |
+
|
| 193 |
+
# ... implement predict() following the pattern above
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
Then in `app/core/app.py`:
|
| 197 |
+
```python
|
| 198 |
+
service = ViTService(model_name="google/vit-base-patch16-224")
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
No other changes needed - the routes, controller, and dependency injection are all model-agnostic.
|
| 202 |
+
|
| 203 |
+
## Validating your setup
|
| 204 |
+
|
| 205 |
+
When you start the server, the logs should show:
|
| 206 |
+
```
|
| 207 |
+
INFO: Starting ML Inference Service...
|
| 208 |
+
INFO: Initializing ResNet service with local model: models/microsoft/resnet-18
|
| 209 |
+
INFO: Loading ResNet model from: models/microsoft/resnet-18
|
| 210 |
+
INFO: ResNet model loaded successfully
|
| 211 |
+
INFO: Startup completed successfully
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
If you see errors like `Model directory not found`, check that your model files exist at the expected path with the full org/model structure.
|
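A quick way to confirm the files are where the service expects them (path shown for the bundled example model):

```python
from pathlib import Path

model_dir = Path("models/microsoft/resnet-18")
if model_dir.exists():
    # Should list config.json plus the model weight files
    print(sorted(p.name for p in model_dir.iterdir()))
else:
    print(f"Missing model directory: {model_dir}")
```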
| 215 |
+
|
| 216 |
+
## Request & Response Shapes
|
| 217 |
|
| 218 |
### Request
|
| 219 |
```json
|
|
|
|
| 236 |
}
|
| 237 |
```
|
| 238 |
|
| 239 |
+
## Configuration
|
|
|
|
|
|
|
| 240 |
|
| 241 |
+
Settings are defined in `app/core/app.py` in the `Settings` class. The defaults are:
|
| 242 |
+
- `app_name` - "ML Inference Service"
|
| 243 |
+
- `app_version` - "0.1.0"
|
| 244 |
+
- `debug` - False
|
| 245 |
+
- `host` - "0.0.0.0"
|
| 246 |
+
- `port` - 8000
|
| 247 |
|
| 248 |
+
You can override these via environment variables or a `.env` file. If you want to make the model configurable via environment variable, add it to the Settings class:
|
| 249 |
|
|
|
|
| 250 |
```python
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
class Settings(BaseSettings):
|
| 252 |
+
# ... existing fields ...
|
| 253 |
+
model_name: str = Field("microsoft/resnet-18")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
+
# Then in the lifespan function:
|
| 256 |
+
service = ResNetInferenceService(model_name=settings.model_name)
|
|
|
|
| 257 |
```
|
| 258 |
|
| 259 |
+
## Deployment
|
|
|
|
|
|
|
| 260 |
|
| 261 |
+
For development:
|
| 262 |
+
```bash
|
| 263 |
+
uvicorn main:app --reload
|
| 264 |
+
```
|
| 265 |
|
| 266 |
+
For production, use gunicorn with uvicorn workers:
|
| 267 |
+
```bash
|
| 268 |
+
gunicorn main:app -w 4 -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
|
| 269 |
+
```
|
| 270 |
|
| 271 |
+
The service runs on CPU by default. For GPU inference, install CUDA-enabled PyTorch and modify your service to move tensors to the GPU device.
|
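As a rough sketch of that change (reusing the HuggingFace processor/model objects from the example service; adapt it to your own code):

```python
import torch

def predict_on_gpu(model, processor, image):
    """Sketch: run one forward pass on the GPU when available."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)  # in practice, move the model once at load time
    inputs = processor(image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}  # inputs must live on the same device
    with torch.no_grad():
        logits = model(**inputs).logits
    return logits.argmax(-1).item()
```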
| 272 |
|
| 273 |
+
## PyArrow Test Datasets
|
| 274 |
|
| 275 |
This project includes a comprehensive **PyArrow-based dataset generation system** designed specifically for academic challenges and ML model validation. The system generates **100 standardized test datasets** that allow participants to validate their models against consistent, reproducible test cases.
|
| 276 |
|
| 277 |
+
### File Structure
|
| 278 |
```
|
| 279 |
standard_test_001.parquet # Actual test data (images, requests, responses)
|
| 280 |
standard_test_001_metadata.json # Human-readable description and stats
|
| 281 |
```
|
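A minimal sketch of loading one dataset pair with PyArrow (the actual column names come from the generation script, so treat them as illustrative):

```python
import json
import pyarrow.parquet as pq

# Load the test cases and the accompanying human-readable metadata
table = pq.read_table("standard_test_001.parquet")
with open("standard_test_001_metadata.json") as f:
    metadata = json.load(f)

print(table.num_rows, table.column_names)
print(metadata.get("description", "no description"))
```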
| 282 |
|
| 283 |
+
### Dataset Categories (25 each = 100 total)
|
| 284 |
|
| 285 |
#### 1. **Standard Test Cases** (`standard_test_*.parquet`)
|
| 286 |
**Purpose**: Baseline functionality validation
|
|
|
|
| 323 |
- **Comparative Analysis**: Enables direct performance comparison between models
|
| 324 |
- **Expected Behavior**: Architecture-specific but structurally consistent responses
|
| 325 |
|
| 326 |
+
### Generation Process
|
| 327 |
|
| 328 |
The dataset generation follows a **deterministic, reproducible approach**:
|
| 329 |
|
|
|
|
| 380 |
})
|
| 381 |
```
|
| 382 |
|
| 383 |
+
### Usage Guide
|
| 384 |
|
| 385 |
|
| 386 |
**1. Generate Test Datasets**
|
|
|
|
| 410 |
python scripts/test_datasets.py --category performance
|
| 411 |
```
|
| 412 |
|
| 413 |
+
### Testing Output and Metrics
|
| 414 |
|
| 415 |
The test runner provides comprehensive validation metrics:
|
| 416 |
|
| 417 |
```
|
| 418 |
+
DATASET TESTING SUMMARY
|
| 419 |
============================================================
|
| 420 |
Datasets tested: 100
|
| 421 |
Successful datasets: 95
|
app/api/controllers.py
CHANGED
|
@@ -1,75 +1,79 @@
(Previous controller, shown removed: it imported `base64`, `io`, and PIL, decoded the base64 payload and opened the image itself, validated the media type, and built the `PredictionResponse` field-by-field from a result dict returned by the ResNet service.)
| 1 |
"""
|
| 2 |
Controllers for handling API business logic.
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
+
This controller layer orchestrates requests between the API routes and the
|
| 5 |
+
inference service layer. It handles validation and error responses.
|
| 6 |
+
|
| 7 |
+
The controller is model-agnostic and works with any InferenceService implementation.
|
| 8 |
+
"""
|
| 9 |
from fastapi import HTTPException
|
|
|
|
| 10 |
|
| 11 |
from app.core.logging import logger
|
| 12 |
+
from app.services.base import InferenceService
|
| 13 |
from app.api.models import ImageRequest, PredictionResponse
|
| 14 |
|
| 15 |
|
| 16 |
class PredictionController:
|
| 17 |
+
"""
|
| 18 |
+
Controller for ML prediction endpoints.
|
| 19 |
+
|
| 20 |
+
This controller works with any InferenceService implementation,
|
| 21 |
+
making it easy to swap different models without changing the API layer.
|
| 22 |
+
"""
|
| 23 |
|
| 24 |
@staticmethod
|
| 25 |
+
async def predict(
|
| 26 |
request: ImageRequest,
|
| 27 |
+
service: InferenceService
|
| 28 |
) -> PredictionResponse:
|
| 29 |
"""
|
| 30 |
+
Run inference using the configured model service.
|
| 31 |
+
|
| 32 |
+
The controller handles request validation and error handling,
|
| 33 |
+
while the service handles the actual inference logic.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
request: ImageRequest with base64-encoded image data
|
| 37 |
+
service: Initialized inference service (can be any model)
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
PredictionResponse with prediction results
|
| 41 |
+
|
| 42 |
+
Raises:
|
| 43 |
+
HTTPException: If service unavailable, invalid input, or inference fails
|
| 44 |
"""
|
| 45 |
try:
|
| 46 |
# Validate service availability
|
| 47 |
+
if not service:
|
| 48 |
raise HTTPException(
|
| 49 |
status_code=503,
|
| 50 |
detail="Service not initialized"
|
| 51 |
)
|
| 52 |
|
| 53 |
+
if not service.is_loaded:
|
|
|
|
| 54 |
raise HTTPException(
|
| 55 |
+
status_code=503,
|
| 56 |
+
detail="Model not loaded"
|
| 57 |
)
|
| 58 |
|
| 59 |
+
# Validate media type
|
| 60 |
+
if not request.image.mediaType.startswith('image/'):
|
|
|
|
|
|
|
| 61 |
raise HTTPException(
|
| 62 |
status_code=400,
|
| 63 |
+
detail=f"Invalid media type: {request.image.mediaType}. Must be image/*"
|
| 64 |
)
|
| 65 |
|
| 66 |
+
# Call service - it handles decoding and returns typed response
|
| 67 |
+
response = await service.predict(request)
|
| 68 |
+
return response
|
| 69 |
|
| 70 |
except HTTPException:
|
| 71 |
raise
|
| 72 |
+
except ValueError as e:
|
| 73 |
+
# Service raises ValueError for invalid input
|
| 74 |
+
logger.error(f"Invalid input: {e}")
|
| 75 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 76 |
except Exception as e:
|
| 77 |
+
# Unexpected errors
|
| 78 |
logger.error(f"Prediction failed: {e}")
|
| 79 |
+
raise HTTPException(status_code=500, detail="Internal server error")
|
app/api/routes/prediction.py
CHANGED
|
@@ -1,20 +1,59 @@
(Previous route module, shown removed: it imported ResNet-specific dependencies from `app.core` and `app.services` and registered a `/predict…` handler with a minimal docstring.)
| 1 |
"""
|
| 2 |
ML Prediction routes.
|
| 3 |
+
|
| 4 |
+
This module defines the HTTP endpoints for running model inference.
|
| 5 |
+
The routes are model-agnostic and work with any InferenceService implementation.
|
| 6 |
"""
|
| 7 |
from fastapi import APIRouter, Depends
|
| 8 |
|
| 9 |
from app.api.controllers import PredictionController
|
| 10 |
from app.api.models import ImageRequest, PredictionResponse
|
| 11 |
+
from app.core.app import get_inference_service
|
| 12 |
+
from app.services.base import InferenceService
|
| 13 |
|
| 14 |
router = APIRouter()
|
| 15 |
|
| 16 |
|
| 17 |
+
@router.post("/predict", response_model=PredictionResponse)
|
| 18 |
+
async def predict(
|
| 19 |
request: ImageRequest,
|
| 20 |
+
service: InferenceService = Depends(get_inference_service)
|
| 21 |
):
|
| 22 |
+
"""
|
| 23 |
+
Run inference on an image using the configured model.
|
| 24 |
+
|
| 25 |
+
This endpoint works with any model that implements the InferenceService interface.
|
| 26 |
+
The actual model used depends on what was configured during app startup.
|
| 27 |
+
|
| 28 |
+
Example Request Body:
|
| 29 |
+
```json
|
| 30 |
+
{
|
| 31 |
+
"image": {
|
| 32 |
+
"mediaType": "image/jpeg",
|
| 33 |
+
"data": "<base64-encoded-image-data>"
|
| 34 |
+
}
|
| 35 |
+
}
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
Example Response:
|
| 39 |
+
```json
|
| 40 |
+
{
|
| 41 |
+
"prediction": "tabby cat",
|
| 42 |
+
"confidence": 0.8542,
|
| 43 |
+
"model": "microsoft/resnet-18",
|
| 44 |
+
"predicted_label": 281,
|
| 45 |
+
"mediaType": "image/jpeg"
|
| 46 |
+
}
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
request: ImageRequest containing base64-encoded image
|
| 51 |
+
service: Injected inference service (configured at startup)
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
PredictionResponse with model predictions
|
| 55 |
+
|
| 56 |
+
Raises:
|
| 57 |
+
HTTPException: 400 for invalid input, 503 if service unavailable, 500 for errors
|
| 58 |
+
"""
|
| 59 |
+
return await PredictionController.predict(request, service)
|
app/api/routes/resnet_service_manager.py
DELETED
|
@@ -1,19 +0,0 @@
# """
# Dependency injection for FastAPI.
# """
# from typing import Optional
# from app.services.inference import ResNetInferenceService
#
# # Global service instance
# _resnet_service: Optional[ResNetInferenceService] = None
#
#
# def get_resnet_service() -> Optional[ResNetInferenceService]:
#     """Get the ResNet service instance."""
#     return _resnet_service
#
#
# def set_resnet_service(service: ResNetInferenceService) -> None:
#     """Set the global ResNet service instance."""
#     global _resnet_service
#     _resnet_service = service
app/core/app.py
CHANGED
|
@@ -1,16 +1,150 @@
(Previous module, shown removed: a minimal "FastAPI application factory" that imported settings and the lifespan handler from the separate `app.core` modules and only defined `create_app()`.)
@@ -19,7 +153,6 @@ def create_app() -> FastAPI:
(Removed the "# Include only prediction router" comment; the unchanged `FastAPI(...)` construction, router registration, and `return app` appear in the new version below.)
| 1 |
"""
|
| 2 |
+
FastAPI application factory and core infrastructure.
|
| 3 |
+
|
| 4 |
+
This module consolidates all core application components:
|
| 5 |
+
- Configuration management
|
| 6 |
+
- Global service instance (dependency injection)
|
| 7 |
+
- Application lifecycle (startup/shutdown)
|
| 8 |
+
- FastAPI app creation
|
| 9 |
+
|
| 10 |
+
By keeping everything in one place, we avoid the complexity of managing
|
| 11 |
+
global variables across multiple modules.
|
| 12 |
"""
|
| 13 |
+
import warnings
|
| 14 |
+
from contextlib import asynccontextmanager
|
| 15 |
+
from typing import AsyncGenerator, Optional
|
| 16 |
+
|
| 17 |
from fastapi import FastAPI
|
| 18 |
+
from pydantic import Field
|
| 19 |
+
from pydantic_settings import BaseSettings
|
| 20 |
|
| 21 |
+
from app.core.logging import logger
|
| 22 |
+
from app.services.base import InferenceService
|
| 23 |
+
from app.services.inference import ResNetInferenceService
|
| 24 |
from app.api.routes import prediction
|
| 25 |
|
| 26 |
|
| 27 |
+
class Settings(BaseSettings):
|
| 28 |
+
"""
|
| 29 |
+
Application settings with environment variable support.
|
| 30 |
+
|
| 31 |
+
Settings can be overridden via environment variables or .env file.
|
| 32 |
+
"""
|
| 33 |
+
# Basic app settings
|
| 34 |
+
app_name: str = Field(default="ML Inference Service", description="Application name")
|
| 35 |
+
app_version: str = Field(default="0.1.0", description="Application version")
|
| 36 |
+
debug: bool = Field(default=False, description="Debug mode")
|
| 37 |
+
|
| 38 |
+
# Server settings
|
| 39 |
+
host: str = Field(default="0.0.0.0", description="Server host")
|
| 40 |
+
port: int = Field(default=8000, description="Server port")
|
| 41 |
+
|
| 42 |
+
class Config:
|
| 43 |
+
"""Load from .env file if it exists."""
|
| 44 |
+
env_file = ".env"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Global settings instance
|
| 48 |
+
settings = Settings()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Global inference service instance (initialized during startup)
|
| 52 |
+
_inference_service: Optional[InferenceService] = None
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_inference_service() -> Optional[InferenceService]:
|
| 56 |
+
"""
|
| 57 |
+
Get the inference service instance for dependency injection.
|
| 58 |
+
|
| 59 |
+
This function is used in FastAPI route handlers via Depends().
|
| 60 |
+
The service is initialized once during app startup and reused
|
| 61 |
+
for all requests.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
The initialized inference service, or None if not yet initialized.
|
| 65 |
+
|
| 66 |
+
Example:
|
| 67 |
+
```python
|
| 68 |
+
@router.post("/predict")
|
| 69 |
+
async def predict(
|
| 70 |
+
request: ImageRequest,
|
| 71 |
+
service: InferenceService = Depends(get_inference_service)
|
| 72 |
+
):
|
| 73 |
+
return await service.predict(request)
|
| 74 |
+
```
|
| 75 |
+
"""
|
| 76 |
+
return _inference_service
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _set_inference_service(service: InferenceService) -> None:
|
| 80 |
+
"""
|
| 81 |
+
INTERNAL: Set the global inference service instance.
|
| 82 |
+
|
| 83 |
+
Called during application startup to register the service.
|
| 84 |
+
This is marked as internal (prefixed with _) because it should
|
| 85 |
+
only be called from the lifespan handler below.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
service: The initialized inference service instance.
|
| 89 |
+
"""
|
| 90 |
+
global _inference_service
|
| 91 |
+
_inference_service = service
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@asynccontextmanager
|
| 95 |
+
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
| 96 |
+
"""
|
| 97 |
+
Application lifespan manager.
|
| 98 |
+
|
| 99 |
+
Handles startup and shutdown events for the FastAPI application.
|
| 100 |
+
During startup, it initializes and loads the inference service.
|
| 101 |
+
|
| 102 |
+
CUSTOMIZATION POINT FOR GRAD STUDENTS:
|
| 103 |
+
To use your own model, replace ResNetInferenceService below with
|
| 104 |
+
your implementation that subclasses InferenceService.
|
| 105 |
+
|
| 106 |
+
Example:
|
| 107 |
+
```python
|
| 108 |
+
service = MyCustomService(model_name="my-org/my-model")
|
| 109 |
+
await service.load_model()
|
| 110 |
+
_set_inference_service(service)
|
| 111 |
+
```
|
| 112 |
+
"""
|
| 113 |
+
logger.info("Starting ML Inference Service...")
|
| 114 |
+
|
| 115 |
+
try:
|
| 116 |
+
with warnings.catch_warnings():
|
| 117 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 118 |
+
|
| 119 |
+
service = ResNetInferenceService(
|
| 120 |
+
model_name="microsoft/resnet-18"
|
| 121 |
+
)
|
| 122 |
+
await service.load_model()
|
| 123 |
+
_set_inference_service(service)
|
| 124 |
+
|
| 125 |
+
logger.info("Startup completed successfully")
|
| 126 |
+
|
| 127 |
+
except Exception as e:
|
| 128 |
+
logger.error(f"Startup failed: {e}")
|
| 129 |
+
raise
|
| 130 |
+
|
| 131 |
+
yield
|
| 132 |
+
|
| 133 |
+
logger.info("Shutting down...")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
def create_app() -> FastAPI:
|
| 137 |
+
"""
|
| 138 |
+
Create and configure the FastAPI application.
|
| 139 |
+
|
| 140 |
+
This is the main entry point for the application. It:
|
| 141 |
+
1. Creates a FastAPI instance with metadata from settings
|
| 142 |
+
2. Attaches the lifespan handler for startup/shutdown
|
| 143 |
+
3. Registers API routes
|
| 144 |
|
| 145 |
+
Returns:
|
| 146 |
+
Configured FastAPI application instance.
|
| 147 |
+
"""
|
| 148 |
app = FastAPI(
|
| 149 |
title=settings.app_name,
|
| 150 |
description="ML inference service for image classification",
|
|
|
|
| 153 |
lifespan=lifespan
|
| 154 |
)
|
| 155 |
|
|
|
|
| 156 |
app.include_router(prediction.router)
|
| 157 |
|
| 158 |
return app
|
app/core/config.py
DELETED
|
@@ -1,29 +0,0 @@
"""
Basic configuration management.

Starting simple - just app settings. We'll expand as needed.
"""

from pydantic import Field
from pydantic_settings import BaseSettings  # Changed import


class Settings(BaseSettings):
    """Application settings with environment variable support."""

    # Basic app settings
    app_name: str = Field(default="ML Inference Service", description="Application name")
    app_version: str = Field(default="0.1.0", description="Application version")
    debug: bool = Field(default=False, description="Debug mode")

    # Server settings
    host: str = Field(default="0.0.0.0", description="Server host")
    port: int = Field(default=8000, description="Server port")

    class Config:
        """Load from .env file if it exists."""
        env_file = ".env"


# Global settings instance
settings = Settings()
app/core/dependencies.py
DELETED
|
@@ -1,19 +0,0 @@
"""
Dependency injection for FastAPI.
"""
from typing import Optional
from app.services.inference import ResNetInferenceService

# Global service instance
_resnet_service: Optional[ResNetInferenceService] = None


def get_resnet_service() -> Optional[ResNetInferenceService]:
    """Get the ResNet service instance."""
    return _resnet_service


def set_resnet_service(service: ResNetInferenceService) -> None:
    """Set the global ResNet service instance."""
    global _resnet_service
    _resnet_service = service
app/core/lifespan.py
DELETED
|
@@ -1,43 +0,0 @@
"""
Application lifespan management.
"""
import warnings
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from fastapi import FastAPI

from app.core.logging import logger
from app.core.dependencies import set_resnet_service
from app.services.inference import ResNetInferenceService


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application lifespan manager."""

    # Startup
    logger.info("Starting ML Inference Service...")

    try:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)

            # Initialize and load ResNet service
            resnet_service = ResNetInferenceService(
                model_name="microsoft/resnet-18",
                use_local_model=True
            )
            resnet_service.load_model()
            set_resnet_service(resnet_service)

        logger.info("Startup completed successfully")

    except Exception as e:
        logger.error(f"Startup failed: {e}")
        raise

    yield  # App runs here

    # Shutdown
    logger.info("Shutting down...")
app/services/base.py
ADDED
|
@@ -0,0 +1,135 @@
| 1 |
+
"""
|
| 2 |
+
Abstract base class for ML inference services.
|
| 3 |
+
|
| 4 |
+
This module defines the contract that all inference services must implement.
|
| 5 |
+
Grad students should subclass `InferenceService` and implement the abstract methods
|
| 6 |
+
to integrate their models with the serving infrastructure.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from abc import ABC, abstractmethod
|
| 10 |
+
from typing import Generic, TypeVar
|
| 11 |
+
|
| 12 |
+
from pydantic import BaseModel
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Type variables for request and response models
|
| 16 |
+
TRequest = TypeVar('TRequest', bound=BaseModel)
|
| 17 |
+
TResponse = TypeVar('TResponse', bound=BaseModel)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class InferenceService(ABC, Generic[TRequest, TResponse]):
|
| 21 |
+
"""
|
| 22 |
+
Abstract base class for ML inference services.
|
| 23 |
+
|
| 24 |
+
This class defines the interface that all model serving implementations must follow.
|
| 25 |
+
By subclassing this and implementing the abstract methods, you can integrate any
|
| 26 |
+
ML model with the serving infrastructure.
|
| 27 |
+
|
| 28 |
+
Type Parameters:
|
| 29 |
+
TRequest: Pydantic model for input requests (e.g., ImageRequest, TextRequest)
|
| 30 |
+
TResponse: Pydantic model for prediction responses (e.g., PredictionResponse)
|
| 31 |
+
|
| 32 |
+
Example:
|
| 33 |
+
```python
|
| 34 |
+
class MyModelService(InferenceService[MyRequest, MyResponse]):
|
| 35 |
+
|
| 36 |
+
async def load_model(self) -> None:
|
| 37 |
+
# Load your model here
|
| 38 |
+
self.model = torch.load("my_model.pt")
|
| 39 |
+
self._is_loaded = True
|
| 40 |
+
|
| 41 |
+
async def predict(self, request: MyRequest) -> MyResponse:
|
| 42 |
+
# Run inference
|
| 43 |
+
output = self.model(request.data)
|
| 44 |
+
return MyResponse(result=output)
|
| 45 |
+
|
| 46 |
+
@property
|
| 47 |
+
def is_loaded(self) -> bool:
|
| 48 |
+
return self._is_loaded
|
| 49 |
+
```
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
@abstractmethod
|
| 53 |
+
async def load_model(self) -> None:
|
| 54 |
+
"""
|
| 55 |
+
Load the model weights and any required processors/tokenizers.
|
| 56 |
+
|
| 57 |
+
This method is called once during application startup (in the lifespan handler).
|
| 58 |
+
Use this to:
|
| 59 |
+
- Load model weights from disk
|
| 60 |
+
- Initialize processors, tokenizers, or other preprocessing components
|
| 61 |
+
- Set up any required state
|
| 62 |
+
- Perform model warmup if needed
|
| 63 |
+
|
| 64 |
+
Raises:
|
| 65 |
+
FileNotFoundError: If model files don't exist
|
| 66 |
+
RuntimeError: If model loading fails
|
| 67 |
+
"""
|
| 68 |
+
pass
|
| 69 |
+
|
| 70 |
+
@abstractmethod
|
| 71 |
+
async def predict(self, request: TRequest) -> TResponse:
|
| 72 |
+
"""
|
| 73 |
+
Run inference on the input request and return a typed response.
|
| 74 |
+
|
| 75 |
+
This method is called for each prediction request. It should:
|
| 76 |
+
1. Extract input data from the request
|
| 77 |
+
2. Preprocess the input (if needed)
|
| 78 |
+
3. Run the model inference
|
| 79 |
+
4. Post-process the output
|
| 80 |
+
5. Return a Pydantic response model
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
request: Input request containing the data to predict on.
|
| 84 |
+
Type is specified by the TRequest type parameter.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
Typed Pydantic response model containing predictions.
|
| 88 |
+
Type is specified by the TResponse type parameter.
|
| 89 |
+
|
| 90 |
+
Raises:
|
| 91 |
+
ValueError: If input data is invalid
|
| 92 |
+
RuntimeError: If model inference fails
|
| 93 |
+
|
| 94 |
+
Important - Background Threading:
|
| 95 |
+
For CPU-intensive operations (like deep learning inference), you MUST
|
| 96 |
+
offload computation to a background thread to avoid blocking the event loop.
|
| 97 |
+
|
| 98 |
+
Pattern to follow:
|
| 99 |
+
```python
|
| 100 |
+
import asyncio
|
| 101 |
+
|
| 102 |
+
def _predict_sync(self, request: TRequest) -> TResponse:
|
| 103 |
+
# Heavy CPU work here (PyTorch, TensorFlow, etc.)
|
| 104 |
+
result = self.model(data)
|
| 105 |
+
return TResponse(result=result)
|
| 106 |
+
|
| 107 |
+
async def predict(self, request: TRequest) -> TResponse:
|
| 108 |
+
# Offload to thread pool
|
| 109 |
+
return await asyncio.to_thread(self._predict_sync, request)
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
Why this matters:
|
| 113 |
+
- Inference can take 1-3+ seconds and will freeze the server
|
| 114 |
+
- asyncio.to_thread() runs the work in a background thread
|
| 115 |
+
- The event loop stays responsive to handle other requests
|
| 116 |
+
"""
|
| 117 |
+
pass
|
| 118 |
+
|
| 119 |
+
@property
|
| 120 |
+
@abstractmethod
|
| 121 |
+
def is_loaded(self) -> bool:
|
| 122 |
+
"""
|
| 123 |
+
Check if the model is loaded and ready for inference.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
True if model is loaded and ready, False otherwise.
|
| 127 |
+
|
| 128 |
+
Example:
|
| 129 |
+
```python
|
| 130 |
+
@property
|
| 131 |
+
def is_loaded(self) -> bool:
|
| 132 |
+
return self.model is not None and self._is_loaded
|
| 133 |
+
```
|
| 134 |
+
"""
|
| 135 |
+
pass
|
app/services/inference.py
CHANGED
|
@@ -1,68 +1,92 @@
(Previous implementation, shown removed: `ResNetInferenceService` was a standalone class with a `use_local_model` flag, a synchronous `load_model()`, and a prediction method that returned a plain result dict with `prediction`, `confidence`, `model`, and `predicted_label` keys.)
@@ -70,17 +94,15 @@ class ResNetInferenceService:
(Removed the old `local_files_only=` wiring and the "Load processor and model from local directory or remote" comment.)
@@ -88,64 +110,87 @@ class ResNetInferenceService:
(Removed the old dict-building prediction body and its "dwl.bash" hint; the unchanged preprocessing, inference, and confidence computation appear in the new version below.)
| 1 |
"""
|
| 2 |
+
Inference service for ResNet image classification models.
|
| 3 |
|
| 4 |
+
This module provides an EXAMPLE implementation of the InferenceService ABC.
|
| 5 |
+
Grad students should use this as a reference when implementing their own model services.
|
| 6 |
+
|
| 7 |
+
This example demonstrates:
|
| 8 |
+
- How to load a HuggingFace transformer model
|
| 9 |
+
- How to preprocess image inputs
|
| 10 |
+
- How to return typed Pydantic responses
|
| 11 |
+
- How to use background threading for CPU-intensive inference
|
| 12 |
+
- Proper error handling and logging
|
| 13 |
"""
|
| 14 |
import os
|
| 15 |
+
import base64
|
| 16 |
+
import asyncio
|
| 17 |
+
from io import BytesIO
|
| 18 |
import torch
|
| 19 |
from PIL import Image
|
| 20 |
from transformers import AutoImageProcessor, ResNetForImageClassification
|
| 21 |
|
| 22 |
from app.core.logging import logger
|
| 23 |
+
from app.services.base import InferenceService
|
| 24 |
+
from app.api.models import ImageRequest, PredictionResponse
|
| 25 |
|
| 26 |
|
| 27 |
+
class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
|
| 28 |
"""
|
| 29 |
+
EXAMPLE: ResNet inference service implementation.
|
| 30 |
+
|
| 31 |
+
This is a reference implementation showing how to integrate a HuggingFace
|
| 32 |
+
image classification model with the serving infrastructure.
|
| 33 |
|
| 34 |
+
To create your own service:
|
| 35 |
+
1. Subclass InferenceService[YourRequest, YourResponse]
|
| 36 |
+
2. Implement load_model() to load your model
|
| 37 |
+
3. Implement predict() to run inference and return typed response
|
| 38 |
+
4. Implement the is_loaded property
|
| 39 |
+
|
| 40 |
+
This service loads a ResNet-18 model for ImageNet classification.
|
| 41 |
"""
|
| 42 |
|
| 43 |
+
def __init__(self, model_name: str = "microsoft/resnet-18"):
|
| 44 |
"""
|
| 45 |
Initialize the ResNet service.
|
| 46 |
|
| 47 |
Args:
|
| 48 |
+
model_name: Model identifier (e.g., "microsoft/resnet-18").
|
| 49 |
+
Model files must exist in models/{model_name}/ directory.
|
| 50 |
+
The full org/model structure is preserved.
|
| 51 |
+
|
| 52 |
+
Example:
|
| 53 |
+
For model_name="microsoft/resnet-18", expects files at:
|
| 54 |
+
models/microsoft/resnet-18/config.json
|
| 55 |
+
models/microsoft/resnet-18/pytorch_model.bin
|
| 56 |
+
etc.
|
| 57 |
"""
|
| 58 |
self.model_name = model_name
|
|
|
|
| 59 |
self.model = None
|
| 60 |
self.processor = None
|
| 61 |
self._is_loaded = False
|
| 62 |
|
| 63 |
+
# Preserve full org/model path structure
|
| 64 |
+
self.model_path = os.path.join("models", model_name)
|
| 65 |
+
logger.info(f"Initializing ResNet service with local model: {self.model_path}")
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
+
async def load_model(self) -> None:
|
| 68 |
"""
|
| 69 |
Load the ResNet model and processor.
|
| 70 |
|
| 71 |
+
This method loads the model once during startup and reuses it for all requests.
|
| 72 |
+
Called by the application lifespan handler.
|
| 73 |
"""
|
| 74 |
if self._is_loaded:
|
| 75 |
logger.debug("Model already loaded, skipping...")
|
| 76 |
return
|
| 77 |
|
| 78 |
try:
|
| 79 |
+
if not os.path.exists(self.model_path):
|
| 80 |
+
raise FileNotFoundError(
|
| 81 |
+
f"Model directory not found: {self.model_path}\n"
|
| 82 |
+
f"Make sure the model files are downloaded to the correct location."
|
| 83 |
+
)
|
| 84 |
|
| 85 |
+
config_path = os.path.join(self.model_path, "config.json")
|
| 86 |
+
if not os.path.exists(config_path):
|
| 87 |
+
raise FileNotFoundError(f"Model config not found: {config_path}")
|
| 88 |
|
| 89 |
+
logger.info(f"Loading ResNet model from: {self.model_path}")
|
|
|
|
|
|
|
| 90 |
|
| 91 |
# Suppress warnings during model loading
|
| 92 |
import warnings
|
|
|
|
| 94 |
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 95 |
warnings.filterwarnings("ignore", message="Could not find image processor class")
|
| 96 |
|
|
|
|
| 97 |
self.processor = AutoImageProcessor.from_pretrained(
|
| 98 |
self.model_path,
|
| 99 |
+
local_files_only=True
|
| 100 |
)
|
| 101 |
self.model = ResNetForImageClassification.from_pretrained(
|
| 102 |
self.model_path,
|
| 103 |
+
local_files_only=True
|
| 104 |
)
|
| 105 |
|
|
|
|
| 106 |
self._is_loaded = True
|
| 107 |
logger.info("ResNet model loaded successfully")
|
| 108 |
logger.info(f"Model architecture: {self.model.config.architectures}")
|
|
|
|
| 110 |
|
| 111 |
except Exception as e:
|
| 112 |
logger.error(f"Failed to load ResNet model: {e}")
|
| 113 |
+
logger.error(f"Hint: Ensure model files exist at: {self.model_path}")
|
|
|
|
| 114 |
raise
|
| 115 |
|
| 116 |
|
| 117 |
+
def _predict_sync(self, request: ImageRequest) -> PredictionResponse:
|
| 118 |
"""
|
| 119 |
+
INTERNAL: Synchronous prediction logic that runs in a background thread.
|
| 120 |
+
|
| 121 |
+
This method contains all CPU-intensive operations (image decoding,
|
| 122 |
+
preprocessing, PyTorch inference). It's called from predict() via
|
| 123 |
+
asyncio.to_thread() to avoid blocking the event loop.
|
| 124 |
|
| 125 |
Args:
|
| 126 |
+
request: ImageRequest containing base64-encoded image data
|
| 127 |
|
| 128 |
Returns:
|
| 129 |
+
PredictionResponse with prediction, confidence, and metadata
|
| 130 |
|
| 131 |
Raises:
|
| 132 |
+
ValueError: If image decoding or processing fails
|
|
|
|
| 133 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
try:
|
| 135 |
+
logger.debug("Starting ResNet inference in background thread")
|
| 136 |
+
|
| 137 |
+
image_data = base64.b64decode(request.image.data)
|
| 138 |
+
image = Image.open(BytesIO(image_data))
|
| 139 |
|
| 140 |
if image.mode != 'RGB':
|
| 141 |
+
logger.debug(f"Converting image from {image.mode} to RGB")
|
| 142 |
image = image.convert('RGB')
|
|
|
|
| 143 |
|
| 144 |
inputs = self.processor(image, return_tensors="pt")
|
| 145 |
|
|
|
|
| 146 |
with torch.no_grad():
|
| 147 |
logits = self.model(**inputs).logits
|
| 148 |
|
|
|
|
| 149 |
predicted_label = logits.argmax(-1).item()
|
| 150 |
predicted_class = self.model.config.id2label[predicted_label]
|
| 151 |
|
|
|
|
| 152 |
probabilities = torch.nn.functional.softmax(logits, dim=-1)
|
| 153 |
confidence = probabilities[0][predicted_label].item()
|
| 155 |
logger.debug(f"Inference completed: {predicted_class} (confidence: {confidence:.4f})")
|
| 156 |
+
|
| 157 |
+
return PredictionResponse(
|
| 158 |
+
prediction=predicted_class,
|
| 159 |
+
confidence=round(confidence, 4),
|
| 160 |
+
model=self.model_name,
|
| 161 |
+
predicted_label=predicted_label,
|
| 162 |
+
mediaType=request.image.mediaType
|
| 163 |
+
)
|
| 164 |
|
| 165 |
except Exception as e:
|
| 166 |
logger.error(f"Inference failed: {e}")
|
| 167 |
raise ValueError(f"Failed to process image: {str(e)}")
|
| 168 |
|
| 169 |
+
async def predict(self, request: ImageRequest) -> PredictionResponse:
|
| 170 |
+
"""
|
| 171 |
+
Perform inference on an image request.
|
| 172 |
+
|
| 173 |
+
This method demonstrates proper async handling for CPU-intensive operations.
|
| 174 |
+
The actual inference work is offloaded to a background thread using
|
| 175 |
+
asyncio.to_thread(), which prevents blocking the event loop.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
request: ImageRequest containing base64-encoded image data
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
PredictionResponse with prediction, confidence, and metadata
|
| 182 |
+
|
| 183 |
+
Raises:
|
| 184 |
+
RuntimeError: If model is not loaded
|
| 185 |
+
ValueError: If image decoding or processing fails
|
| 186 |
+
"""
|
| 187 |
+
if not self._is_loaded:
|
| 188 |
+
logger.warning("Model not loaded, loading now...")
|
| 189 |
+
await self.load_model()
|
| 190 |
+
|
| 191 |
+
response = await asyncio.to_thread(self._predict_sync, request)
|
| 192 |
+
return response
|
| 193 |
+
|
| 194 |
@property
|
| 195 |
def is_loaded(self) -> bool:
|
| 196 |
"""Check if model is loaded."""
|
test_main.http
CHANGED
|
@@ -1,6 +1,7 @@
(Previous version, shown removed: a bare "# Test…" header comment above the same POST /predict request.)
| 1 |
+
# Test Prediction Endpoint
|
| 2 |
+
# Works with any model configured at startup (default: ResNet-18)
|
| 3 |
|
| 4 |
+
POST http://127.0.0.1:8000/predict
|
| 5 |
Content-Type: application/json
|
| 6 |
|
| 7 |
{
|