sachin sharma committed on
Commit
5ddae77
·
1 Parent(s): ebbcd26

refactored codebase

README.md CHANGED
@@ -1,89 +1,94 @@
1
  # ML Inference Service (FastAPI)
2
 
3
- A production-ready **FastAPI** web service that serves **image classification** models.
4
- This repo ships with a working example using **ResNet-18** (downloaded from Hugging Face) under `models/resnet-18/` and exposes a simple **REST** endpoint.
5
 
6
- ---
7
 
8
- ## ✨ What you get
9
 
10
- - FastAPI application with clean layering (routes β†’ controller β†’ service)
11
- - Hot-loaded model on startup (single instance reused per request)
12
- - Hugging Face–compatible local model folder (`config.json`, weights, preprocessor, etc.)
13
- - Example endpoint: `POST /predict/resnet` that accepts a base64 image and returns:
14
- - `prediction` (class label)
15
- - `confidence` (softmax probability)
16
- - `predicted_label` (class index)
17
- - `model` (model id)
18
- - `mediaType` (echoed)
19
 
20
- ---
21
 
22
- ## 🧭 Project Layout
23
 
24
  ```
25
  ml-inference-service/
26
- β”œβ”€ main.py
27
  β”œβ”€ app/
28
- β”‚ β”œβ”€ __init__.py
29
  β”‚ β”œβ”€ core/
30
- β”‚ β”‚ β”œβ”€ app.py # App factory & router wiring
31
- β”‚ β”‚ β”œβ”€ config.py # Settings (app name/version/debug)
32
- β”‚ β”‚ β”œβ”€ dependencies.py # DI for model services
33
- β”‚ β”‚ β”œβ”€ lifespan.py # Startup: load model & register service
34
  β”‚ β”‚ └─ logging.py # Logger setup
35
  β”‚ β”œβ”€ api/
36
- β”‚ β”‚ β”œβ”€ models.py # Pydantic request/response
37
  β”‚ β”‚ β”œβ”€ controllers.py # HTTP β†’ service orchestration
38
  β”‚ β”‚ └─ routes/
39
- β”‚ β”‚ β”œβ”€ prediction.py # `POST /predict/resnet`
40
- β”‚ β”‚ └─ resnet_service_manager.py (legacy, unused)
41
  β”‚ └─ services/
42
- β”‚ └─ inference.py # ResNetInferenceService (load/predict)
 
43
  β”œβ”€ models/
44
- β”‚ └─ resnet-18/ # Sample HF-style model folder
 
45
  β”œβ”€ scripts/
46
- β”‚ β”œβ”€ model_download.bash # One-liner to snapshot HF weights locally
47
- β”‚ β”œβ”€ generate_test_datasets.py # Generate PyArrow datasets for testing
48
- β”‚ β”œβ”€ test_datasets.py # Test generated datasets against API
49
- β”‚ └─ test_datasets/ # Generated PyArrow test datasets (100 files)
50
- β”œβ”€ requirements.in / requirements.txt
51
- └─ test_main.http # Example request you can run from IDEs
52
  ```
53
 
54
- ---
55
 
56
- ## πŸš€ Quickstart
57
 
58
- ### 1) Install dependencies (Python 3.9+)
59
  ```bash
60
  python -m venv .venv
61
  source .venv/bin/activate # Windows: .venv\Scripts\activate
62
  pip install -r requirements.txt
63
  ```
64
 
65
- ### 2) Download the sample model (ResNet‑18) locally
66
  ```bash
67
  bash scripts/model_download.bash
68
  ```
69
- This populates `models/resnet-18/` with Hugging Face artifacts (`config.json`, weights, `preprocessor_config.json`, etc.).
70
 
71
- ### 3) Run the server
72
  ```bash
73
  uvicorn main:app --reload
74
  ```
75
- Server listens on `http://127.0.0.1:8000`.
76
 
77
- ### 4) Call the API
78
- - Use `test_main.http` from your IDE (VSCode/IntelliJ) **or** curl:
 
79
 
80
  ```bash
81
- curl -X POST http://127.0.0.1:8000/predict/resnet -H "Content-Type: application/json" -d '{
82
- "image": { "mediaType": "image/jpeg", "data": "<base64-encoded-bytes>" }
 
 
 
 
 
83
  }'
84
  ```
85
 
86
- **Response (example):**
87
  ```json
88
  {
89
  "prediction": "tiger cat",
@@ -94,117 +99,121 @@ curl -X POST http://127.0.0.1:8000/predict/resnet -H "Content-Type: applicatio
94
  }
95
  ```
96
 
97
- ---
98
-
99
- ## 🧩 Bring Your Own Model (BYOM)
100
-
101
- There are **two** ways to integrate your own model.
102
-
103
- ### Option A β€” *Drop-in replacement (zero code changes)*
104
-
105
- If your model is a **Hugging Face image classification** model that works with
106
- `AutoImageProcessor` and `ResNetForImageClassification` **or** a compatible
107
- `*ForImageClassification` class from `transformers`, you can simply place the
108
- model folder alongside `resnet-18` and point the service at it.
109
-
110
- 1. Put your HF-style folder under `models/<your-model-name>/` containing at least:
111
- - `config.json`
112
- - weights (e.g., `pytorch_model.bin` or `model.safetensors`)
113
- - `preprocessor_config.json` / `image_processor` files
114
-
115
- 2. **Choose one** of these approaches:
116
- - **Simplest**: Replace the contents of `models/resnet-18/` with your model files *but keep the folder name*. The existing `/predict/resnet` endpoint will now serve your model.
117
- - **Preferred**: Change the model id used at startup:
118
- - Open `app/core/lifespan.py` and modify the service initialization:
119
- ```python
120
- resnet_service = ResNetInferenceService(
121
- model_name="your-org/your-model", # used for local folder name
122
- use_local_model=True # loads from models/your-model/
123
- )
124
- ```
125
- - Ensure your local folder is `models/your-model/`.
126
-
127
- > How folder naming works: when `use_local_model=True`, the service derives the
128
- > local directory as `models/<last-segment-of-model_name>`. For
129
- > `"microsoft/resnet-18"` that becomes `models/resnet-18`. For
130
- > `"your-org/awesome-vit-base"`, it becomes `models/awesome-vit-base`.
131
-
132
- That’s it. No code changes elsewhere if your model is a standard image classifier.
133
-
134
- ---
135
-
136
- ### Option B β€” *New task/model type (minimal code: new service + route)*
137
-
138
- If you are **not** serving a Hugging Face image classifier (e.g., object detection,
139
- segmentation, text models), implement a small service class and a route mirroring
140
- the `ResNetInferenceService` flow.
141
-
142
- 1. **Create your service** (copy and adapt `ResNetInferenceService`):
143
- - File: `app/services/<your_model>_service.py`
144
- - Responsibilities you must implement:
145
- - `__init__(model_name: str, use_local_model: bool)` β†’ set `self.model_path`
146
- - `load_model()` β†’ load weights & preprocessor
147
- - `predict(image: PIL.Image.Image) -> Dict[str, Any]` β†’ run inference and return a dict with:
148
- ```python
149
- {
150
- "prediction": "<your label or structured result>",
151
- "confidence": <float 0..1>,
152
- "predicted_label": <int or meaningful code>,
153
- "model": "<model id>"
154
- }
155
- ```
156
- *Feel free to extend the payload; just update the API schema accordingly.*
157
-
158
- 2. **Wire the dependency**:
159
- - Register your service at startup in `app/core/lifespan.py` similar to ResNet:
160
- ```python
161
- from app.core.dependencies import set_resnet_service # or create your own set/get
162
- from app.services.your_model_service import YourModelService
163
-
164
- svc = YourModelService(model_name="your-org/your-model", use_local_model=True)
165
- svc.load_model()
166
- set_resnet_service(svc) # or create set_your_model_service(...)
167
- ```
168
- - Optionally create **new getters/setters** in `app/core/dependencies.py` if you serve multiple models in parallel (one getter per model).
169
-
170
- 3. **Add a route**:
171
- - Create `app/api/routes/your_model.py` analogous to `prediction.py`:
172
- ```python
173
- from fastapi import APIRouter, Depends
174
- from app.api.controllers import PredictionController
175
- from app.api.models import ImageRequest, PredictionResponse
176
- from app.core.dependencies import get_resnet_service # or your getter
177
- from app.services.your_model_service import YourModelService
178
-
179
- router = APIRouter()
180
-
181
- @router.post("/predict/your-model", response_model=PredictionResponse)
182
- async def predict_image(request: ImageRequest, service: YourModelService = Depends(get_resnet_service)):
183
- controller = PredictionController(service) # reuse the controller
184
- return await controller.predict(request)
185
- ```
186
- - Register the router in `app/core/app.py`:
187
- ```python
188
- from app.api.routes import your_model as your_model_routes
189
- app.include_router(your_model_routes.router)
190
- ```
191
-
192
- 4. **Adjust schemas if needed**:
193
- - The default `PredictionResponse` in `app/api/models.py` is for single-label classification. For other tasks, either extend it or define a new response model and use it in your route’s `response_model=`.
194
-
195
- > **Tip**: Keep your controller thin and push all model-specific logic into your service class. The server glue (DI + routes) stays identical across models.
196
-
197
- ---
198
-
199
- ## πŸ§ͺ Validating your setup
200
-
201
- - **Startup logs** should include: `Initializing ResNet service with local model: models/<folder>` and `Model and processor loaded successfully`.
202
- - Hitting your endpoint should return a **200** with a JSON body like the example above.
203
- - If you see `Local model directory not found`, check your `models/<name>/` path and filenames.
204
-
205
- ---
206
-
207
- ## 🔌 Request & Response Shapes
208
 
209
  ### Request
210
  ```json
@@ -227,58 +236,51 @@ the `ResNetInferenceService` flow.
227
  }
228
  ```
229
 
230
- ---
231
-
232
- ## βš™οΈ Configuration
233
 
234
- Basic settings live in `app/core/config.py`. Out of the box we keep it simple:
235
- - `app_name`, `app_version`, `debug`
 
 
 
 
236
 
237
- If you want to make the **model** configurable without touching code, extend `Settings` with a `model_name` env var and consume it in `lifespan.py` when creating your service instance.
238
 
239
- Example:
240
  ```python
241
- # app/core/config.py
242
- from pydantic_settings import BaseSettings
243
- from pydantic import Field
244
-
245
  class Settings(BaseSettings):
246
- app_name: str = Field("ML Inference Service")
247
- app_version: str = Field("0.1.0")
248
- debug: bool = Field(False)
249
- model_name: str = Field("microsoft/resnet-18", description="HF model id used at startup")
250
-
251
- settings = Settings()
252
 
253
- # app/core/lifespan.py
254
- from app.core.config import settings
255
- svc = ResNetInferenceService(model_name=settings.model_name, use_local_model=True)
256
  ```
257
 
258
- Then set `MODEL_NAME=your-org/your-model` in your environment (Pydantic will map `model_name` from `MODEL_NAME`).
259
-
260
- ---
261
 
262
- ## 📦 Packaging & Deployment
263
 
264
- - **Dev**: `uvicorn main:app --reload`
265
- - **Prod**: Use a process manager (e.g., `gunicorn -k uvicorn.workers.UvicornWorker`) and add health checks.
266
- - **Containerize**: Copy only `requirements.txt` and source, install wheels, and bake the `models/` folder into the image or mount it as a volume.
267
- - **CPU vs GPU**: This example uses CPU by default. If you have CUDA, install a CUDA-enabled PyTorch build and set device placement in your service.
268
 
269
- ---
270
 
271
- ## πŸ§ͺ PyArrow Test Datasets
272
 
273
  This project includes a comprehensive **PyArrow-based dataset generation system** designed specifically for academic challenges and ML model validation. The system generates **100 standardized test datasets** that allow participants to validate their models against consistent, reproducible test cases.
274
 
275
- ### πŸ—οΈ Why Both? `.parquet` + `_metadata.json`
276
  ```
277
  standard_test_001.parquet # Actual test data (images, requests, responses)
278
  standard_test_001_metadata.json # Human-readable description and stats
279
  ```
280
 
281
- ### πŸ“Š Dataset Categories (25 each = 100 total)
282
 
283
  #### 1. **Standard Test Cases** (`standard_test_*.parquet`)
284
  **Purpose**: Baseline functionality validation
@@ -321,7 +323,7 @@ standard_test_001_metadata.json # Human-readable description and stats
321
  - **Comparative Analysis**: Enables direct performance comparison between models
322
  - **Expected Behavior**: Architecture-specific but structurally consistent responses
323
 
324
- ### πŸ› οΈ Generation Process
325
 
326
  The dataset generation follows a **deterministic, reproducible approach**:
327
 
@@ -378,7 +380,7 @@ table = pa.table({
378
  })
379
  ```
380
 
381
- ### πŸš€ Usage Guide
382
 
383
 
384
  **1. Generate Test Datasets**
@@ -408,12 +410,12 @@ python scripts/test_datasets.py --category edge_case
408
  python scripts/test_datasets.py --category performance
409
  ```
410
 
411
- ### πŸ“ˆ Testing Output and Metrics
412
 
413
  The test runner provides comprehensive validation metrics:
414
 
415
  ```
416
- 🏁 DATASET TESTING SUMMARY
417
  ============================================================
418
  Datasets tested: 100
419
  Successful datasets: 95
 
1
  # ML Inference Service (FastAPI)
2
 
3
+ A FastAPI-based inference server designed to make it easy to serve your ML models. The repo includes a complete working example using ResNet-18 for image classification, but the architecture is built to be model-agnostic. You implement a simple abstract base class, and everything else just works.
 
4
 
5
+ Key features:
6
+ - Abstract InferenceService class that you subclass for your model
7
+ - Example ResNet-18 implementation showing how to do it
8
+ - FastAPI application with clean separation (routes → controller → service)
9
+ - Model loaded once at startup and reused across requests
10
+ - Background threading for inference so the server stays responsive
11
+ - Type-safe request/response handling with Pydantic
12
+ - Single generic endpoint that works with any model
13
 
14
+ ## What you get
15
 
16
+ The service exposes a single endpoint `POST /predict` that accepts a base64-encoded image and returns:
17
+ - `prediction` - the predicted class label
18
+ - `confidence` - softmax probability for the prediction
19
+ - `predicted_label` - numeric class index
20
+ - `model` - identifier for which model produced this prediction
21
+ - `mediaType` - echoed from the request
 
 
 
22
 
23
+ The inference runs in a background thread using asyncio so long-running model predictions don't block the server from handling other requests.
24
 
25
+ ## Project Layout
26
 
27
  ```
28
  ml-inference-service/
29
+ ├─ main.py # Entry point
30
  ├─ app/
31
  │ ├─ core/
32
+ │ │ ├─ app.py # Everything: config, DI, lifespan, app factory
33
  │ │ └─ logging.py # Logger setup
34
  │ ├─ api/
35
+ │ │ ├─ models.py # Pydantic request/response schemas
36
  │ │ ├─ controllers.py # HTTP → service orchestration
37
  │ │ └─ routes/
38
+ │ │ └─ prediction.py # POST /predict endpoint
39
  │ └─ services/
40
+ │ ├─ base.py # Abstract InferenceService class
41
+ │ └─ inference.py # ResNetInferenceService (example implementation)
42
  ├─ models/
43
+ │ └─ microsoft/
44
+ │ └─ resnet-18/ # Model files (preserves org structure)
45
  ├─ scripts/
46
+ │ ├─ generate_test_datasets.py
47
+ │ ├─ test_datasets.py
48
+ │ └─ test_datasets/
49
+ ├─ requirements.txt
50
+ └─ test_main.http # Example HTTP request
 
51
  ```
52
 
53
+ The key change from a typical FastAPI app is that `app/core/app.py` consolidates configuration, dependency injection, lifecycle management, and the app factory into one file. This avoids the complexity of managing global variables across multiple modules.
54
 
55
+ ## Quickstart
56
 
57
+ 1) Install dependencies (Python 3.9+)
58
  ```bash
59
  python -m venv .venv
60
  source .venv/bin/activate # Windows: .venv\Scripts\activate
61
  pip install -r requirements.txt
62
  ```
63
 
64
+ 2) Download the example model
65
  ```bash
66
  bash scripts/model_download.bash
67
  ```
68
+ This downloads ResNet-18 from Hugging Face and saves it to `models/microsoft/resnet-18/` (note the org structure is preserved).
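
If you prefer to fetch the weights from Python rather than the bash script, something equivalent to the sketch below should produce the same layout (the exact contents of `scripts/model_download.bash` are not shown here, so treat this as an assumption):

```python
# Hedged alternative to scripts/model_download.bash: snapshot the model with
# huggingface_hub. Requires `pip install huggingface_hub`.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="microsoft/resnet-18",
    local_dir="models/microsoft/resnet-18",  # keep the org/model folder structure
)
```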
69
 
70
+ 3) Run the server
71
  ```bash
72
  uvicorn main:app --reload
73
  ```
74
+ Server starts on `http://127.0.0.1:8000`.
75
 
76
+ 4) Test the API
77
+
78
+ Use `test_main.http` from your IDE or curl:
79
 
80
  ```bash
81
+ curl -X POST http://127.0.0.1:8000/predict \
82
+ -H "Content-Type: application/json" \
83
+ -d '{
84
+ "image": {
85
+ "mediaType": "image/jpeg",
86
+ "data": "<base64-encoded-bytes>"
87
+ }
88
  }'
89
  ```
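
To produce the `data` field, base64-encode the raw image bytes. A minimal sketch (the filename `cat.jpg` is just an example):

```python
# Build the request body for POST /predict from a local image file.
import base64
import json

with open("cat.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

payload = {"image": {"mediaType": "image/jpeg", "data": encoded}}
print(json.dumps(payload))  # pipe this into curl, or send it with requests/httpx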
90
 
91
+ Example response:
92
  ```json
93
  {
94
  "prediction": "tiger cat",
 
99
  }
100
  ```
101
 
102
+ ## Integrating Your Own Model
103
+
104
+ To use your own model, you implement the `InferenceService` abstract base class. The rest of the infrastructure (API routes, controllers, dependency injection) is already generic and works with any implementation.
105
+
106
+ ### Step 1: Implement the InferenceService ABC
107
+
108
+ Create a new file `app/services/your_model_service.py`:
109
+
110
+ ```python
111
+ import asyncio
+ import os
+
+ from app.services.base import InferenceService
112
+ from app.api.models import ImageRequest, PredictionResponse
113
+
114
+ class YourModelService(InferenceService[ImageRequest, PredictionResponse]):
115
+ def __init__(self, model_name: str):
116
+ self.model_name = model_name
117
+ self.model_path = os.path.join("models", model_name)
118
+ self.model = None
119
+ self._is_loaded = False
120
+
121
+ async def load_model(self) -> None:
122
+ # Load your model here
123
+ self.model = load_your_model(self.model_path)
124
+ self._is_loaded = True
125
+
126
+ async def predict(self, request: ImageRequest) -> PredictionResponse:
127
+ # Offload to background thread (important for performance)
128
+ return await asyncio.to_thread(self._predict_sync, request)
129
+
130
+ def _predict_sync(self, request: ImageRequest) -> PredictionResponse:
131
+ # Decode image, run inference, return typed response
132
+ image = decode_base64_image(request.image.data)
133
+ result = self.model(image)
134
+ return PredictionResponse(
135
+ prediction=result.label,
136
+ confidence=result.confidence,
137
+ predicted_label=result.class_id,
138
+ model=self.model_name,
139
+ mediaType=request.image.mediaType
140
+ )
141
+
142
+ @property
143
+ def is_loaded(self) -> bool:
144
+ return self._is_loaded
145
+ ```
146
+
147
+ The key points:
148
+ - Subclass `InferenceService[RequestType, ResponseType]` with your request/response types
149
+ - Implement three methods: `load_model()`, `predict()`, and `is_loaded` property
150
+ - Use `asyncio.to_thread()` to offload CPU-intensive inference to a background thread
151
+ - Return typed Pydantic models, not dicts
152
+
153
+ ### Step 2: Register your service at startup
154
+
155
+ Edit `app/core/app.py` and find the lifespan function (around line 134):
156
+
157
+ ```python
158
+ # Replace this:
159
+ service = ResNetInferenceService(model_name="microsoft/resnet-18")
160
+
161
+ # With this:
162
+ service = YourModelService(model_name="your-org/your-model")
163
+ ```
164
+
165
+ That's it. The same `/predict` endpoint now serves your model.
166
+
167
+ ### Model file structure
168
+
169
+ Your model files should be organized as:
170
+ ```
171
+ models/
172
+ └── your-org/
173
+ └── your-model/
174
+ β”œβ”€β”€ config.json
175
+ β”œβ”€β”€ weights.bin
176
+ └── ... other files
177
+ ```
178
+
179
+ The full org/model structure is preserved - no more dropping the org prefix.
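
The path the service loads from follows directly from the model name, mirroring the ResNet example service:

```python
import os

# How the example service derives its local path from the model name:
model_name = "your-org/your-model"
model_path = os.path.join("models", model_name)
print(model_path)  # models/your-org/your-model
```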
180
+
181
+ ### Example: Swapping ResNet for ViT
182
+
183
+ ```python
184
+ # app/services/vit_service.py
185
+ from transformers import ViTForImageClassification, ViTImageProcessor
+
+ from app.services.base import InferenceService
+ from app.api.models import ImageRequest, PredictionResponse
186
+
187
+ class ViTService(InferenceService[ImageRequest, PredictionResponse]):
188
+ async def load_model(self) -> None:
189
+ self.processor = ViTImageProcessor.from_pretrained(self.model_path)
190
+ self.model = ViTForImageClassification.from_pretrained(self.model_path)
191
+ self._is_loaded = True
192
+
193
+ # ... implement predict() following the pattern above
194
+ ```
195
+
196
+ Then in `app/core/app.py`:
197
+ ```python
198
+ service = ViTService(model_name="google/vit-base-patch16-224")
199
+ ```
200
+
201
+ No other changes needed - the routes, controller, and dependency injection are all model-agnostic.
202
+
203
+ ## Validating your setup
204
+
205
+ When you start the server, the logs should show:
206
+ ```
207
+ INFO: Starting ML Inference Service...
208
+ INFO: Initializing ResNet service with local model: models/microsoft/resnet-18
209
+ INFO: Loading ResNet model from: models/microsoft/resnet-18
210
+ INFO: ResNet model loaded successfully
211
+ INFO: Startup completed successfully
212
+ ```
213
+
214
+ If you see errors like `Model directory not found`, check that your model files exist at the expected path with the full org/model structure.
215
+
216
+ ## Request & Response Shapes
217
 
218
  ### Request
219
  ```json
 
236
  }
237
  ```
238
 
239
+ ## Configuration
 
 
240
 
241
+ Settings are defined in `app/core/app.py` in the `Settings` class. The defaults are:
242
+ - `app_name` - "ML Inference Service"
243
+ - `app_version` - "0.1.0"
244
+ - `debug` - False
245
+ - `host` - "0.0.0.0"
246
+ - `port` - 8000
247
 
248
+ You can override these via environment variables or a `.env` file. If you want to make the model configurable via environment variable, add it to the Settings class:
249
 
 
250
  ```python
 
 
 
 
251
  class Settings(BaseSettings):
252
+ # ... existing fields ...
253
+ model_name: str = Field("microsoft/resnet-18")
 
 
 
 
254
 
255
+ # Then in the lifespan function:
256
+ service = ResNetInferenceService(model_name=settings.model_name)
 
257
  ```
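
Pydantic settings map environment variables to field names case-insensitively, so once `model_name` is added you can override it without touching code. A small sketch (assuming the `Settings` class above, with the `model_name` field added):

```python
# Sketch: MODEL_NAME from the environment (or .env) populates Settings.model_name.
# Normally you would export MODEL_NAME in your shell or .env rather than set it here.
import os

from app.core.app import Settings

os.environ["MODEL_NAME"] = "your-org/your-model"
print(Settings().model_name)  # -> "your-org/your-model"
```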
258
 
259
+ ## Deployment
 
 
260
 
261
+ For development:
262
+ ```bash
263
+ uvicorn main:app --reload
264
+ ```
265
 
266
+ For production, use gunicorn with uvicorn workers:
267
+ ```bash
268
+ gunicorn main:app -w 4 -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
269
+ ```
270
 
271
+ The service runs on CPU by default. For GPU inference, install CUDA-enabled PyTorch and modify your service to move tensors to the GPU device.
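
A hedged sketch of what GPU placement might look like inside a service such as `ResNetInferenceService` (assumes a CUDA-enabled PyTorch build and falls back to CPU otherwise):

```python
import torch

# Pick the device once, e.g. in __init__ or load_model().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# In load_model():
#     self.model = ResNetForImageClassification.from_pretrained(self.model_path).to(device)
# In _predict_sync():
#     inputs = self.processor(image, return_tensors="pt")
#     inputs = {k: v.to(device) for k, v in inputs.items()}
#     with torch.no_grad():
#         logits = self.model(**inputs).logits
```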
272
 
273
+ ## PyArrow Test Datasets
274
 
275
  This project includes a comprehensive **PyArrow-based dataset generation system** designed specifically for academic challenges and ML model validation. The system generates **100 standardized test datasets** that allow participants to validate their models against consistent, reproducible test cases.
276
 
277
+ ### File Structure
278
  ```
279
  standard_test_001.parquet # Actual test data (images, requests, responses)
280
  standard_test_001_metadata.json # Human-readable description and stats
281
  ```
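
To inspect a generated dataset locally, a minimal PyArrow sketch (the exact columns are an assumption; the accompanying `_metadata.json` describes the real schema):

```python
# Load one generated test dataset and peek at its shape and schema.
import json
import pyarrow.parquet as pq

table = pq.read_table("scripts/test_datasets/standard_test_001.parquet")
print(table.num_rows, table.schema.names)

with open("scripts/test_datasets/standard_test_001_metadata.json") as f:
    print(json.load(f))  # human-readable description and stats
```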
282
 
283
+ ### Dataset Categories (25 each = 100 total)
284
 
285
  #### 1. **Standard Test Cases** (`standard_test_*.parquet`)
286
  **Purpose**: Baseline functionality validation
 
323
  - **Comparative Analysis**: Enables direct performance comparison between models
324
  - **Expected Behavior**: Architecture-specific but structurally consistent responses
325
 
326
+ ### Generation Process
327
 
328
  The dataset generation follows a **deterministic, reproducible approach**:
329
 
 
380
  })
381
  ```
382
 
383
+ ### Usage Guide
384
 
385
 
386
  **1. Generate Test Datasets**
 
410
  python scripts/test_datasets.py --category performance
411
  ```
412
 
413
+ ### Testing Output and Metrics
414
 
415
  The test runner provides comprehensive validation metrics:
416
 
417
  ```
418
+ DATASET TESTING SUMMARY
419
  ============================================================
420
  Datasets tested: 100
421
  Successful datasets: 95
app/api/controllers.py CHANGED
@@ -1,75 +1,79 @@
1
  """
2
  Controllers for handling API business logic.
3
- """
4
- import base64
5
- import io
6
 
 
 
 
 
 
7
  from fastapi import HTTPException
8
- from PIL import Image
9
 
10
  from app.core.logging import logger
11
- from app.services.inference import ResNetInferenceService
12
  from app.api.models import ImageRequest, PredictionResponse
13
 
14
 
15
  class PredictionController:
16
- """Controller for ML prediction endpoints."""
 
 
 
 
 
17
 
18
  @staticmethod
19
- async def predict_resnet(
20
  request: ImageRequest,
21
- resnet_service: ResNetInferenceService
22
  ) -> PredictionResponse:
23
  """
24
- Classify an image using ResNet-18 from base64 encoded data.
25
  """
26
  try:
27
  # Validate service availability
28
- if not resnet_service:
29
  raise HTTPException(
30
  status_code=503,
31
  detail="Service not initialized"
32
  )
33
 
34
- # Validate media type
35
- if not request.image.mediaType.startswith('image/'):
36
  raise HTTPException(
37
- status_code=400,
38
- detail=f"Invalid media type: {request.image.mediaType}"
39
- )
40
-
41
- # Decode base64 image data
42
- try:
43
- image_data = base64.b64decode(request.image.data)
44
- except Exception as decode_error:
45
- raise HTTPException(
46
- status_code=400,
47
- detail=f"Invalid base64 data: {str(decode_error)}"
48
  )
49
 
50
- # Load and validate image
51
- try:
52
- image = Image.open(io.BytesIO(image_data))
53
- except Exception as img_error:
54
  raise HTTPException(
55
  status_code=400,
56
- detail=f"Invalid image file: {str(img_error)}"
57
  )
58
 
59
- # Perform prediction
60
- result = resnet_service.predict(image)
61
-
62
- # Return structured response
63
- return PredictionResponse(
64
- prediction=result["prediction"],
65
- confidence=result["confidence"],
66
- model=result["model"],
67
- predicted_label=result["predicted_label"],
68
- mediaType=request.image.mediaType
69
- )
70
 
71
  except HTTPException:
72
  raise
 
 
 
 
73
  except Exception as e:
 
74
  logger.error(f"Prediction failed: {e}")
75
- raise HTTPException(status_code=500, detail=str(e))
 
1
  """
2
  Controllers for handling API business logic.
 
 
 
3
 
4
+ This controller layer orchestrates requests between the API routes and the
5
+ inference service layer. It handles validation and error responses.
6
+
7
+ The controller is model-agnostic and works with any InferenceService implementation.
8
+ """
9
  from fastapi import HTTPException
 
10
 
11
  from app.core.logging import logger
12
+ from app.services.base import InferenceService
13
  from app.api.models import ImageRequest, PredictionResponse
14
 
15
 
16
  class PredictionController:
17
+ """
18
+ Controller for ML prediction endpoints.
19
+
20
+ This controller works with any InferenceService implementation,
21
+ making it easy to swap different models without changing the API layer.
22
+ """
23
 
24
  @staticmethod
25
+ async def predict(
26
  request: ImageRequest,
27
+ service: InferenceService
28
  ) -> PredictionResponse:
29
  """
30
+ Run inference using the configured model service.
31
+
32
+ The controller handles request validation and error handling,
33
+ while the service handles the actual inference logic.
34
+
35
+ Args:
36
+ request: ImageRequest with base64-encoded image data
37
+ service: Initialized inference service (can be any model)
38
+
39
+ Returns:
40
+ PredictionResponse with prediction results
41
+
42
+ Raises:
43
+ HTTPException: If service unavailable, invalid input, or inference fails
44
  """
45
  try:
46
  # Validate service availability
47
+ if not service:
48
  raise HTTPException(
49
  status_code=503,
50
  detail="Service not initialized"
51
  )
52
 
53
+ if not service.is_loaded:
 
54
  raise HTTPException(
55
+ status_code=503,
56
+ detail="Model not loaded"
 
 
 
 
 
 
 
 
 
57
  )
58
 
59
+ # Validate media type
60
+ if not request.image.mediaType.startswith('image/'):
 
 
61
  raise HTTPException(
62
  status_code=400,
63
+ detail=f"Invalid media type: {request.image.mediaType}. Must be image/*"
64
  )
65
 
66
+ # Call service - it handles decoding and returns typed response
67
+ response = await service.predict(request)
68
+ return response
 
 
 
 
 
 
 
 
69
 
70
  except HTTPException:
71
  raise
72
+ except ValueError as e:
73
+ # Service raises ValueError for invalid input
74
+ logger.error(f"Invalid input: {e}")
75
+ raise HTTPException(status_code=400, detail=str(e))
76
  except Exception as e:
77
+ # Unexpected errors
78
  logger.error(f"Prediction failed: {e}")
79
+ raise HTTPException(status_code=500, detail="Internal server error")
app/api/routes/prediction.py CHANGED
@@ -1,20 +1,59 @@
1
  """
2
  ML Prediction routes.
 
 
 
3
  """
4
  from fastapi import APIRouter, Depends
5
 
6
  from app.api.controllers import PredictionController
7
  from app.api.models import ImageRequest, PredictionResponse
8
- from app.core.dependencies import get_resnet_service
9
- from app.services.inference import ResNetInferenceService
10
 
11
  router = APIRouter()
12
 
13
 
14
- @router.post("/predict/resnet", response_model=PredictionResponse)
15
- async def predict_image(
16
  request: ImageRequest,
17
- resnet_service: ResNetInferenceService = Depends(get_resnet_service)
18
  ):
19
- """Classify an image using ResNet-18 from base64 encoded data."""
20
- return await PredictionController.predict_resnet(request, resnet_service)
1
  """
2
  ML Prediction routes.
3
+
4
+ This module defines the HTTP endpoints for running model inference.
5
+ The routes are model-agnostic and work with any InferenceService implementation.
6
  """
7
  from fastapi import APIRouter, Depends
8
 
9
  from app.api.controllers import PredictionController
10
  from app.api.models import ImageRequest, PredictionResponse
11
+ from app.core.app import get_inference_service
12
+ from app.services.base import InferenceService
13
 
14
  router = APIRouter()
15
 
16
 
17
+ @router.post("/predict", response_model=PredictionResponse)
18
+ async def predict(
19
  request: ImageRequest,
20
+ service: InferenceService = Depends(get_inference_service)
21
  ):
22
+ """
23
+ Run inference on an image using the configured model.
24
+
25
+ This endpoint works with any model that implements the InferenceService interface.
26
+ The actual model used depends on what was configured during app startup.
27
+
28
+ Example Request Body:
29
+ ```json
30
+ {
31
+ "image": {
32
+ "mediaType": "image/jpeg",
33
+ "data": "<base64-encoded-image-data>"
34
+ }
35
+ }
36
+ ```
37
+
38
+ Example Response:
39
+ ```json
40
+ {
41
+ "prediction": "tabby cat",
42
+ "confidence": 0.8542,
43
+ "model": "microsoft/resnet-18",
44
+ "predicted_label": 281,
45
+ "mediaType": "image/jpeg"
46
+ }
47
+ ```
48
+
49
+ Args:
50
+ request: ImageRequest containing base64-encoded image
51
+ service: Injected inference service (configured at startup)
52
+
53
+ Returns:
54
+ PredictionResponse with model predictions
55
+
56
+ Raises:
57
+ HTTPException: 400 for invalid input, 503 if service unavailable, 500 for errors
58
+ """
59
+ return await PredictionController.predict(request, service)
app/api/routes/resnet_service_manager.py DELETED
@@ -1,19 +0,0 @@
1
- # """
2
- # Dependency injection for FastAPI.
3
- # """
4
- # from typing import Optional
5
- # from app.services.inference import ResNetInferenceService
6
- #
7
- # # Global service instance
8
- # _resnet_service: Optional[ResNetInferenceService] = None
9
- #
10
- #
11
- # def get_resnet_service() -> Optional[ResNetInferenceService]:
12
- # """Get the ResNet service instance."""
13
- # return _resnet_service
14
- #
15
- #
16
- # def set_resnet_service(service: ResNetInferenceService) -> None:
17
- # """Set the global ResNet service instance."""
18
- # global _resnet_service
19
- # _resnet_service = service
app/core/app.py CHANGED
@@ -1,16 +1,150 @@
1
  """
2
- FastAPI application factory.
3
  """
 
 
 
 
4
  from fastapi import FastAPI
 
 
5
 
6
- from app.core.config import settings
7
- from app.core.lifespan import lifespan
 
8
  from app.api.routes import prediction
9
 
10
 
11
  def create_app() -> FastAPI:
12
- """Application factory."""
 
 
 
 
 
 
13
 
 
 
 
14
  app = FastAPI(
15
  title=settings.app_name,
16
  description="ML inference service for image classification",
@@ -19,7 +153,6 @@ def create_app() -> FastAPI:
19
  lifespan=lifespan
20
  )
21
 
22
- # Include only prediction router
23
  app.include_router(prediction.router)
24
 
25
  return app
 
1
  """
2
+ FastAPI application factory and core infrastructure.
3
+
4
+ This module consolidates all core application components:
5
+ - Configuration management
6
+ - Global service instance (dependency injection)
7
+ - Application lifecycle (startup/shutdown)
8
+ - FastAPI app creation
9
+
10
+ By keeping everything in one place, we avoid the complexity of managing
11
+ global variables across multiple modules.
12
  """
13
+ import warnings
14
+ from contextlib import asynccontextmanager
15
+ from typing import AsyncGenerator, Optional
16
+
17
  from fastapi import FastAPI
18
+ from pydantic import Field
19
+ from pydantic_settings import BaseSettings
20
 
21
+ from app.core.logging import logger
22
+ from app.services.base import InferenceService
23
+ from app.services.inference import ResNetInferenceService
24
  from app.api.routes import prediction
25
 
26
 
27
+ class Settings(BaseSettings):
28
+ """
29
+ Application settings with environment variable support.
30
+
31
+ Settings can be overridden via environment variables or .env file.
32
+ """
33
+ # Basic app settings
34
+ app_name: str = Field(default="ML Inference Service", description="Application name")
35
+ app_version: str = Field(default="0.1.0", description="Application version")
36
+ debug: bool = Field(default=False, description="Debug mode")
37
+
38
+ # Server settings
39
+ host: str = Field(default="0.0.0.0", description="Server host")
40
+ port: int = Field(default=8000, description="Server port")
41
+
42
+ class Config:
43
+ """Load from .env file if it exists."""
44
+ env_file = ".env"
45
+
46
+
47
+ # Global settings instance
48
+ settings = Settings()
49
+
50
+
51
+ # Global inference service instance (initialized during startup)
52
+ _inference_service: Optional[InferenceService] = None
53
+
54
+
55
+ def get_inference_service() -> Optional[InferenceService]:
56
+ """
57
+ Get the inference service instance for dependency injection.
58
+
59
+ This function is used in FastAPI route handlers via Depends().
60
+ The service is initialized once during app startup and reused
61
+ for all requests.
62
+
63
+ Returns:
64
+ The initialized inference service, or None if not yet initialized.
65
+
66
+ Example:
67
+ ```python
68
+ @router.post("/predict")
69
+ async def predict(
70
+ request: ImageRequest,
71
+ service: InferenceService = Depends(get_inference_service)
72
+ ):
73
+ return await service.predict(request)
74
+ ```
75
+ """
76
+ return _inference_service
77
+
78
+
79
+ def _set_inference_service(service: InferenceService) -> None:
80
+ """
81
+ INTERNAL: Set the global inference service instance.
82
+
83
+ Called during application startup to register the service.
84
+ This is marked as internal (prefixed with _) because it should
85
+ only be called from the lifespan handler below.
86
+
87
+ Args:
88
+ service: The initialized inference service instance.
89
+ """
90
+ global _inference_service
91
+ _inference_service = service
92
+
93
+
94
+ @asynccontextmanager
95
+ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
96
+ """
97
+ Application lifespan manager.
98
+
99
+ Handles startup and shutdown events for the FastAPI application.
100
+ During startup, it initializes and loads the inference service.
101
+
102
+ CUSTOMIZATION POINT FOR GRAD STUDENTS:
103
+ To use your own model, replace ResNetInferenceService below with
104
+ your implementation that subclasses InferenceService.
105
+
106
+ Example:
107
+ ```python
108
+ service = MyCustomService(model_name="my-org/my-model")
109
+ await service.load_model()
110
+ _set_inference_service(service)
111
+ ```
112
+ """
113
+ logger.info("Starting ML Inference Service...")
114
+
115
+ try:
116
+ with warnings.catch_warnings():
117
+ warnings.filterwarnings("ignore", category=FutureWarning)
118
+
119
+ service = ResNetInferenceService(
120
+ model_name="microsoft/resnet-18"
121
+ )
122
+ await service.load_model()
123
+ _set_inference_service(service)
124
+
125
+ logger.info("Startup completed successfully")
126
+
127
+ except Exception as e:
128
+ logger.error(f"Startup failed: {e}")
129
+ raise
130
+
131
+ yield
132
+
133
+ logger.info("Shutting down...")
134
+
135
+
136
  def create_app() -> FastAPI:
137
+ """
138
+ Create and configure the FastAPI application.
139
+
140
+ This is the main entry point for the application. It:
141
+ 1. Creates a FastAPI instance with metadata from settings
142
+ 2. Attaches the lifespan handler for startup/shutdown
143
+ 3. Registers API routes
144
 
145
+ Returns:
146
+ Configured FastAPI application instance.
147
+ """
148
  app = FastAPI(
149
  title=settings.app_name,
150
  description="ML inference service for image classification",
 
153
  lifespan=lifespan
154
  )
155
 
 
156
  app.include_router(prediction.router)
157
 
158
  return app
app/core/config.py DELETED
@@ -1,29 +0,0 @@
1
- """
2
- Basic configuration management.
3
-
4
- Starting simple - just app settings. We'll expand as needed.
5
- """
6
-
7
- from pydantic import Field
8
- from pydantic_settings import BaseSettings # Changed import
9
-
10
-
11
- class Settings(BaseSettings):
12
- """Application settings with environment variable support."""
13
-
14
- # Basic app settings
15
- app_name: str = Field(default="ML Inference Service", description="Application name")
16
- app_version: str = Field(default="0.1.0", description="Application version")
17
- debug: bool = Field(default=False, description="Debug mode")
18
-
19
- # Server settings
20
- host: str = Field(default="0.0.0.0", description="Server host")
21
- port: int = Field(default=8000, description="Server port")
22
-
23
- class Config:
24
- """Load from .env file if it exists."""
25
- env_file = ".env"
26
-
27
-
28
- # Global settings instance
29
- settings = Settings()
app/core/dependencies.py DELETED
@@ -1,19 +0,0 @@
1
- """
2
- Dependency injection for FastAPI.
3
- """
4
- from typing import Optional
5
- from app.services.inference import ResNetInferenceService
6
-
7
- # Global service instance
8
- _resnet_service: Optional[ResNetInferenceService] = None
9
-
10
-
11
- def get_resnet_service() -> Optional[ResNetInferenceService]:
12
- """Get the ResNet service instance."""
13
- return _resnet_service
14
-
15
-
16
- def set_resnet_service(service: ResNetInferenceService) -> None:
17
- """Set the global ResNet service instance."""
18
- global _resnet_service
19
- _resnet_service = service
app/core/lifespan.py DELETED
@@ -1,43 +0,0 @@
1
- """
2
- Application lifespan management.
3
- """
4
- import warnings
5
- from contextlib import asynccontextmanager
6
- from typing import AsyncGenerator
7
-
8
- from fastapi import FastAPI
9
-
10
- from app.core.logging import logger
11
- from app.core.dependencies import set_resnet_service
12
- from app.services.inference import ResNetInferenceService
13
-
14
-
15
- @asynccontextmanager
16
- async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
17
- """Application lifespan manager."""
18
-
19
- # Startup
20
- logger.info("Starting ML Inference Service...")
21
-
22
- try:
23
- with warnings.catch_warnings():
24
- warnings.filterwarnings("ignore", category=FutureWarning)
25
-
26
- # Initialize and load ResNet service
27
- resnet_service = ResNetInferenceService(
28
- model_name="microsoft/resnet-18",
29
- use_local_model=True
30
- )
31
- resnet_service.load_model()
32
- set_resnet_service(resnet_service)
33
-
34
- logger.info("Startup completed successfully")
35
-
36
- except Exception as e:
37
- logger.error(f"Startup failed: {e}")
38
- raise
39
-
40
- yield # App runs here
41
-
42
- # Shutdown
43
- logger.info("Shutting down...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services/base.py ADDED
@@ -0,0 +1,135 @@
1
+ """
2
+ Abstract base class for ML inference services.
3
+
4
+ This module defines the contract that all inference services must implement.
5
+ Grad students should subclass `InferenceService` and implement the abstract methods
6
+ to integrate their models with the serving infrastructure.
7
+ """
8
+
9
+ from abc import ABC, abstractmethod
10
+ from typing import Generic, TypeVar
11
+
12
+ from pydantic import BaseModel
13
+
14
+
15
+ # Type variables for request and response models
16
+ TRequest = TypeVar('TRequest', bound=BaseModel)
17
+ TResponse = TypeVar('TResponse', bound=BaseModel)
18
+
19
+
20
+ class InferenceService(ABC, Generic[TRequest, TResponse]):
21
+ """
22
+ Abstract base class for ML inference services.
23
+
24
+ This class defines the interface that all model serving implementations must follow.
25
+ By subclassing this and implementing the abstract methods, you can integrate any
26
+ ML model with the serving infrastructure.
27
+
28
+ Type Parameters:
29
+ TRequest: Pydantic model for input requests (e.g., ImageRequest, TextRequest)
30
+ TResponse: Pydantic model for prediction responses (e.g., PredictionResponse)
31
+
32
+ Example:
33
+ ```python
34
+ class MyModelService(InferenceService[MyRequest, MyResponse]):
35
+
36
+ async def load_model(self) -> None:
37
+ # Load your model here
38
+ self.model = torch.load("my_model.pt")
39
+ self._is_loaded = True
40
+
41
+ async def predict(self, request: MyRequest) -> MyResponse:
42
+ # Run inference
43
+ output = self.model(request.data)
44
+ return MyResponse(result=output)
45
+
46
+ @property
47
+ def is_loaded(self) -> bool:
48
+ return self._is_loaded
49
+ ```
50
+ """
51
+
52
+ @abstractmethod
53
+ async def load_model(self) -> None:
54
+ """
55
+ Load the model weights and any required processors/tokenizers.
56
+
57
+ This method is called once during application startup (in the lifespan handler).
58
+ Use this to:
59
+ - Load model weights from disk
60
+ - Initialize processors, tokenizers, or other preprocessing components
61
+ - Set up any required state
62
+ - Perform model warmup if needed
63
+
64
+ Raises:
65
+ FileNotFoundError: If model files don't exist
66
+ RuntimeError: If model loading fails
67
+ """
68
+ pass
69
+
70
+ @abstractmethod
71
+ async def predict(self, request: TRequest) -> TResponse:
72
+ """
73
+ Run inference on the input request and return a typed response.
74
+
75
+ This method is called for each prediction request. It should:
76
+ 1. Extract input data from the request
77
+ 2. Preprocess the input (if needed)
78
+ 3. Run the model inference
79
+ 4. Post-process the output
80
+ 5. Return a Pydantic response model
81
+
82
+ Args:
83
+ request: Input request containing the data to predict on.
84
+ Type is specified by the TRequest type parameter.
85
+
86
+ Returns:
87
+ Typed Pydantic response model containing predictions.
88
+ Type is specified by the TResponse type parameter.
89
+
90
+ Raises:
91
+ ValueError: If input data is invalid
92
+ RuntimeError: If model inference fails
93
+
94
+ Important - Background Threading:
95
+ For CPU-intensive operations (like deep learning inference), you MUST
96
+ offload computation to a background thread to avoid blocking the event loop.
97
+
98
+ Pattern to follow:
99
+ ```python
100
+ import asyncio
101
+
102
+ def _predict_sync(self, request: TRequest) -> TResponse:
103
+ # Heavy CPU work here (PyTorch, TensorFlow, etc.)
104
+ result = self.model(data)
105
+ return TResponse(result=result)
106
+
107
+ async def predict(self, request: TRequest) -> TResponse:
108
+ # Offload to thread pool
109
+ return await asyncio.to_thread(self._predict_sync, request)
110
+ ```
111
+
112
+ Why this matters:
113
+ - Inference can take 1-3+ seconds and will freeze the server
114
+ - asyncio.to_thread() runs the work in a background thread
115
+ - The event loop stays responsive to handle other requests
116
+ """
117
+ pass
118
+
119
+ @property
120
+ @abstractmethod
121
+ def is_loaded(self) -> bool:
122
+ """
123
+ Check if the model is loaded and ready for inference.
124
+
125
+ Returns:
126
+ True if model is loaded and ready, False otherwise.
127
+
128
+ Example:
129
+ ```python
130
+ @property
131
+ def is_loaded(self) -> bool:
132
+ return self.model is not None and self._is_loaded
133
+ ```
134
+ """
135
+ pass
app/services/inference.py CHANGED
@@ -1,68 +1,92 @@
1
  """
2
- Inference service for machine learning models.
3
 
4
- This service handles the business logic for ML inference,
5
- following the Single Responsibility Principle.
 
 
 
 
 
 
 
6
  """
7
  import os
8
- from typing import Dict, Any
 
 
9
  import torch
10
  from PIL import Image
11
  from transformers import AutoImageProcessor, ResNetForImageClassification
12
 
13
  from app.core.logging import logger
 
 
14
 
15
 
16
- class ResNetInferenceService:
17
  """
18
- ResNet inference service.
 
 
 
19
 
20
- Handles loading and inference for ResNet models.
21
- Follows the Singleton pattern - loads model once.
 
 
 
 
 
22
  """
23
 
24
- def __init__(self, model_name: str = "microsoft/resnet-18", use_local_model: bool = True):
25
  """
26
  Initialize the ResNet service.
27
 
28
  Args:
29
- model_name: HuggingFace model identifier
 
 
 
 
 
 
 
 
30
  """
31
  self.model_name = model_name
32
- self.use_local_model = use_local_model
33
  self.model = None
34
  self.processor = None
35
  self._is_loaded = False
36
 
37
- if use_local_model:
38
- self.model_path = os.path.join("models", model_name.split("/")[-1])
39
- logger.info(f"Initializing ResNet service with local model: {self.model_path}")
40
- else:
41
- self.model_path = model_name
42
- logger.info(f"Initializing ResNet service with remote model: {model_name}")
43
 
44
- def load_model(self) -> None:
45
  """
46
  Load the ResNet model and processor.
47
 
48
- This method loads the model once and reuses it for all requests.
 
49
  """
50
  if self._is_loaded:
51
  logger.debug("Model already loaded, skipping...")
52
  return
53
 
54
  try:
55
- if self.use_local_model:
56
- if not os.path.exists(self.model_path):
57
- raise FileNotFoundError(f"Local model directory not found: {self.model_path}")
 
 
58
 
59
- config_path = os.path.join(self.model_path, "config.json")
60
- if not os.path.exists(config_path):
61
- raise FileNotFoundError(f"Model config not found: {config_path}")
62
 
63
- logger.info(f"Loading ResNet model from local directory: {self.model_path}")
64
- else:
65
- logger.info(f"Loading ResNet model from HuggingFace Hub: {self.model_name}")
66
 
67
  # Suppress warnings during model loading
68
  import warnings
@@ -70,17 +94,15 @@ class ResNetInferenceService:
70
  warnings.filterwarnings("ignore", category=FutureWarning)
71
  warnings.filterwarnings("ignore", message="Could not find image processor class")
72
 
73
- # Load processor and model from local directory or remote
74
  self.processor = AutoImageProcessor.from_pretrained(
75
  self.model_path,
76
- local_files_only=self.use_local_model
77
  )
78
  self.model = ResNetForImageClassification.from_pretrained(
79
  self.model_path,
80
- local_files_only=self.use_local_model
81
  )
82
 
83
-
84
  self._is_loaded = True
85
  logger.info("ResNet model loaded successfully")
86
  logger.info(f"Model architecture: {self.model.config.architectures}")
@@ -88,64 +110,87 @@ class ResNetInferenceService:
88
 
89
  except Exception as e:
90
  logger.error(f"Failed to load ResNet model: {e}")
91
- if self.use_local_model:
92
- logger.error("Hint: Make sure the model was downloaded correctly with dwl.bash")
93
  raise
94
 
95
 
96
- def predict(self, image: Image.Image) -> Dict[str, Any]:
97
  """
98
- Perform inference on an image.
 
 
 
 
99
 
100
  Args:
101
- image: PIL Image to classify
102
 
103
  Returns:
104
- Dictionary containing prediction results
105
 
106
  Raises:
107
- RuntimeError: If model is not loaded
108
- ValueError: If image processing fails
109
  """
110
- if not self._is_loaded:
111
- logger.info("Model not loaded, loading now...")
112
- self.load_model()
113
-
114
  try:
115
- logger.debug("Starting ResNet inference")
 
 
 
116
 
117
  if image.mode != 'RGB':
 
118
  image = image.convert('RGB')
119
- logger.debug(f"Converted image from {image.mode} to RGB")
120
 
121
  inputs = self.processor(image, return_tensors="pt")
122
 
123
- # Perform inference
124
  with torch.no_grad():
125
  logits = self.model(**inputs).logits
126
 
127
- # Get prediction
128
  predicted_label = logits.argmax(-1).item()
129
  predicted_class = self.model.config.id2label[predicted_label]
130
 
131
- # Calculate confidence score
132
  probabilities = torch.nn.functional.softmax(logits, dim=-1)
133
  confidence = probabilities[0][predicted_label].item()
134
 
135
- result = {
136
- "prediction": predicted_class,
137
- "confidence": round(confidence, 4),
138
- "model": self.model_name,
139
- "predicted_label": predicted_label
140
- }
141
-
142
  logger.debug(f"Inference completed: {predicted_class} (confidence: {confidence:.4f})")
143
- return result
 
 
 
 
 
 
 
144
 
145
  except Exception as e:
146
  logger.error(f"Inference failed: {e}")
147
  raise ValueError(f"Failed to process image: {str(e)}")
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  @property
150
  def is_loaded(self) -> bool:
151
  """Check if model is loaded."""
 
1
  """
2
+ Inference service for ResNet image classification models.
3
 
4
+ This module provides an EXAMPLE implementation of the InferenceService ABC.
5
+ Grad students should use this as a reference when implementing their own model services.
6
+
7
+ This example demonstrates:
8
+ - How to load a HuggingFace transformer model
9
+ - How to preprocess image inputs
10
+ - How to return typed Pydantic responses
11
+ - How to use background threading for CPU-intensive inference
12
+ - Proper error handling and logging
13
  """
14
  import os
15
+ import base64
16
+ import asyncio
17
+ from io import BytesIO
18
  import torch
19
  from PIL import Image
20
  from transformers import AutoImageProcessor, ResNetForImageClassification
21
 
22
  from app.core.logging import logger
23
+ from app.services.base import InferenceService
24
+ from app.api.models import ImageRequest, PredictionResponse
25
 
26
 
27
+ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
28
  """
29
+ EXAMPLE: ResNet inference service implementation.
30
+
31
+ This is a reference implementation showing how to integrate a HuggingFace
32
+ image classification model with the serving infrastructure.
33
 
34
+ To create your own service:
35
+ 1. Subclass InferenceService[YourRequest, YourResponse]
36
+ 2. Implement load_model() to load your model
37
+ 3. Implement predict() to run inference and return typed response
38
+ 4. Implement the is_loaded property
39
+
40
+ This service loads a ResNet-18 model for ImageNet classification.
41
  """
42
 
43
+ def __init__(self, model_name: str = "microsoft/resnet-18"):
44
  """
45
  Initialize the ResNet service.
46
 
47
  Args:
48
+ model_name: Model identifier (e.g., "microsoft/resnet-18").
49
+ Model files must exist in models/{model_name}/ directory.
50
+ The full org/model structure is preserved.
51
+
52
+ Example:
53
+ For model_name="microsoft/resnet-18", expects files at:
54
+ models/microsoft/resnet-18/config.json
55
+ models/microsoft/resnet-18/pytorch_model.bin
56
+ etc.
57
  """
58
  self.model_name = model_name
 
59
  self.model = None
60
  self.processor = None
61
  self._is_loaded = False
62
 
63
+ # Preserve full org/model path structure
64
+ self.model_path = os.path.join("models", model_name)
65
+ logger.info(f"Initializing ResNet service with local model: {self.model_path}")
 
 
 
66
 
67
+ async def load_model(self) -> None:
68
  """
69
  Load the ResNet model and processor.
70
 
71
+ This method loads the model once during startup and reuses it for all requests.
72
+ Called by the application lifespan handler.
73
  """
74
  if self._is_loaded:
75
  logger.debug("Model already loaded, skipping...")
76
  return
77
 
78
  try:
79
+ if not os.path.exists(self.model_path):
80
+ raise FileNotFoundError(
81
+ f"Model directory not found: {self.model_path}\n"
82
+ f"Make sure the model files are downloaded to the correct location."
83
+ )
84
 
85
+ config_path = os.path.join(self.model_path, "config.json")
86
+ if not os.path.exists(config_path):
87
+ raise FileNotFoundError(f"Model config not found: {config_path}")
88
 
89
+ logger.info(f"Loading ResNet model from: {self.model_path}")
 
 
90
 
91
  # Suppress warnings during model loading
92
  import warnings
 
94
  warnings.filterwarnings("ignore", category=FutureWarning)
95
  warnings.filterwarnings("ignore", message="Could not find image processor class")
96
 
 
97
  self.processor = AutoImageProcessor.from_pretrained(
98
  self.model_path,
99
+ local_files_only=True
100
  )
101
  self.model = ResNetForImageClassification.from_pretrained(
102
  self.model_path,
103
+ local_files_only=True
104
  )
105
 
 
106
  self._is_loaded = True
107
  logger.info("ResNet model loaded successfully")
108
  logger.info(f"Model architecture: {self.model.config.architectures}")
 
110
 
111
  except Exception as e:
112
  logger.error(f"Failed to load ResNet model: {e}")
113
+ logger.error(f"Hint: Ensure model files exist at: {self.model_path}")
 
114
  raise
115
 
116
 
117
+ def _predict_sync(self, request: ImageRequest) -> PredictionResponse:
118
  """
119
+ INTERNAL: Synchronous prediction logic that runs in a background thread.
120
+
121
+ This method contains all CPU-intensive operations (image decoding,
122
+ preprocessing, PyTorch inference). It's called from predict() via
123
+ asyncio.to_thread() to avoid blocking the event loop.
124
 
125
  Args:
126
+ request: ImageRequest containing base64-encoded image data
127
 
128
  Returns:
129
+ PredictionResponse with prediction, confidence, and metadata
130
 
131
  Raises:
132
+ ValueError: If image decoding or processing fails
 
133
  """
 
 
 
 
134
  try:
135
+ logger.debug("Starting ResNet inference in background thread")
136
+
137
+ image_data = base64.b64decode(request.image.data)
138
+ image = Image.open(BytesIO(image_data))
139
 
140
  if image.mode != 'RGB':
141
+ logger.debug(f"Converting image from {image.mode} to RGB")
142
  image = image.convert('RGB')
 
143
 
144
  inputs = self.processor(image, return_tensors="pt")
145
 
 
146
  with torch.no_grad():
147
  logits = self.model(**inputs).logits
148
 
 
149
  predicted_label = logits.argmax(-1).item()
150
  predicted_class = self.model.config.id2label[predicted_label]
151
 
 
152
  probabilities = torch.nn.functional.softmax(logits, dim=-1)
153
  confidence = probabilities[0][predicted_label].item()
154
 
 
 
 
 
 
 
 
155
  logger.debug(f"Inference completed: {predicted_class} (confidence: {confidence:.4f})")
156
+
157
+ return PredictionResponse(
158
+ prediction=predicted_class,
159
+ confidence=round(confidence, 4),
160
+ model=self.model_name,
161
+ predicted_label=predicted_label,
162
+ mediaType=request.image.mediaType
163
+ )
164
 
165
  except Exception as e:
166
  logger.error(f"Inference failed: {e}")
167
  raise ValueError(f"Failed to process image: {str(e)}")
168
 
169
+ async def predict(self, request: ImageRequest) -> PredictionResponse:
170
+ """
171
+ Perform inference on an image request.
172
+
173
+ This method demonstrates proper async handling for CPU-intensive operations.
174
+ The actual inference work is offloaded to a background thread using
175
+ asyncio.to_thread(), which prevents blocking the event loop.
176
+
177
+ Args:
178
+ request: ImageRequest containing base64-encoded image data
179
+
180
+ Returns:
181
+ PredictionResponse with prediction, confidence, and metadata
182
+
183
+ Raises:
184
+ RuntimeError: If model is not loaded
185
+ ValueError: If image decoding or processing fails
186
+ """
187
+ if not self._is_loaded:
188
+ logger.warning("Model not loaded, loading now...")
189
+ await self.load_model()
190
+
191
+ response = await asyncio.to_thread(self._predict_sync, request)
192
+ return response
193
+
194
  @property
195
  def is_loaded(self) -> bool:
196
  """Check if model is loaded."""
test_main.http CHANGED
@@ -1,6 +1,7 @@
1
- # Test ResNet Prediction Endpoint
 
2
 
3
- POST http://127.0.0.1:8000/predict/resnet
4
  Content-Type: application/json
5
 
6
  {
 
1
+ # Test Prediction Endpoint
2
+ # Works with any model configured at startup (default: ResNet-18)
3
 
4
+ POST http://127.0.0.1:8000/predict
5
  Content-Type: application/json
6
 
7
  {