sachin sharma committed
Commit da2b98d · 1 Parent(s): 5ddae77

removed verbosity
app/api/controllers.py CHANGED
@@ -1,11 +1,5 @@
-"""
-Controllers for handling API business logic.
-
-This controller layer orchestrates requests between the API routes and the
-inference service layer. It handles validation and error responses.
-
-The controller is model-agnostic and works with any InferenceService implementation.
-"""
+"""API controllers for request handling and validation."""
+
 from fastapi import HTTPException
 
 from app.core.logging import logger
@@ -14,66 +8,28 @@ from app.api.models import ImageRequest, PredictionResponse
 
 
 class PredictionController:
-    """
-    Controller for ML prediction endpoints.
-
-    This controller works with any InferenceService implementation,
-    making it easy to swap different models without changing the API layer.
-    """
+    """Controller for prediction endpoints."""
 
     @staticmethod
     async def predict(
-        request: ImageRequest,
-        service: InferenceService
+        request: ImageRequest,
+        service: InferenceService
     ) -> PredictionResponse:
-        """
-        Run inference using the configured model service.
-
-        The controller handles request validation and error handling,
-        while the service handles the actual inference logic.
-
-        Args:
-            request: ImageRequest with base64-encoded image data
-            service: Initialized inference service (can be any model)
-
-        Returns:
-            PredictionResponse with prediction results
-
-        Raises:
-            HTTPException: If service unavailable, invalid input, or inference fails
-        """
+        """Run inference using the configured service."""
         try:
-            # Validate service availability
-            if not service:
-                raise HTTPException(
-                    status_code=503,
-                    detail="Service not initialized"
-                )
-
-            if not service.is_loaded:
-                raise HTTPException(
-                    status_code=503,
-                    detail="Model not loaded"
-                )
+            if not service or not service.is_loaded:
+                raise HTTPException(503, "Service not available")
 
-            # Validate media type
             if not request.image.mediaType.startswith('image/'):
-                raise HTTPException(
-                    status_code=400,
-                    detail=f"Invalid media type: {request.image.mediaType}. Must be image/*"
-                )
+                raise HTTPException(400, f"Invalid media type: {request.image.mediaType}")
 
-            # Call service - it handles decoding and returns typed response
-            response = await service.predict(request)
-            return response
+            return await service.predict(request)
 
         except HTTPException:
             raise
         except ValueError as e:
-            # Service raises ValueError for invalid input
             logger.error(f"Invalid input: {e}")
-            raise HTTPException(status_code=400, detail=str(e))
+            raise HTTPException(400, str(e))
         except Exception as e:
-            # Unexpected errors
             logger.error(f"Prediction failed: {e}")
-            raise HTTPException(status_code=500, detail="Internal server error")
+            raise HTTPException(500, "Internal server error")
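Note (not part of the commit): the error-mapping behavior is unchanged, only the docstrings and inline comments were dropped: 503 when the service is missing or not loaded, 400 for non-image media types and ValueErrors from the service, 500 for anything else. A minimal sketch of exercising the trimmed controller with a stub service follows; the pytest dependency, the stub class, and the dict-style ImageRequest construction are assumptions for illustration, not part of this repo.

```python
# Illustrative only: checks that a ValueError from the service surfaces as a 400.
import asyncio
import base64

import pytest
from fastapi import HTTPException

from app.api.controllers import PredictionController
from app.api.models import ImageRequest


class _StubService:
    """Stands in for an InferenceService implementation; always reports loaded."""

    is_loaded = True

    async def predict(self, request):
        # Simulate the service rejecting an undecodable payload.
        raise ValueError("could not decode image")


def test_value_error_maps_to_400():
    # Assumes ImageRequest accepts a nested dict for its `image` field.
    request = ImageRequest(image={
        "mediaType": "image/png",
        "data": base64.b64encode(b"not a real png").decode("ascii"),
    })

    with pytest.raises(HTTPException) as exc:
        asyncio.run(PredictionController.predict(request, _StubService()))

    assert exc.value.status_code == 400
```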
app/api/routes/prediction.py CHANGED
@@ -1,9 +1,5 @@
-"""
-ML Prediction routes.
+"""Prediction API routes."""
 
-This module defines the HTTP endpoints for running model inference.
-The routes are model-agnostic and work with any InferenceService implementation.
-"""
 from fastapi import APIRouter, Depends
 
 from app.api.controllers import PredictionController
@@ -20,40 +16,8 @@ async def predict(
     service: InferenceService = Depends(get_inference_service)
 ):
     """
-    Run inference on an image using the configured model.
+    Run inference on base64-encoded image.
 
-    This endpoint works with any model that implements the InferenceService interface.
-    The actual model used depends on what was configured during app startup.
-
-    Example Request Body:
-    ```json
-    {
-        "image": {
-            "mediaType": "image/jpeg",
-            "data": "<base64-encoded-image-data>"
-        }
-    }
-    ```
-
-    Example Response:
-    ```json
-    {
-        "prediction": "tabby cat",
-        "confidence": 0.8542,
-        "model": "microsoft/resnet-18",
-        "predicted_label": 281,
-        "mediaType": "image/jpeg"
-    }
-    ```
-
-    Args:
-        request: ImageRequest containing base64-encoded image
-        service: Injected inference service (configured at startup)
-
-    Returns:
-        PredictionResponse with model predictions
-
-    Raises:
-        HTTPException: 400 for invalid input, 503 if service unavailable, 500 for errors
+    Returns prediction, confidence, predicted label, model name, and media type.
     """
-    return await PredictionController.predict(request, service)
+    return await PredictionController.predict(request, service)
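Note (not part of the commit): the request/response examples deleted from the docstring still describe the endpoint's contract. A client-side sketch is below; the /predict path comes from the route example removed elsewhere in this commit, localhost:8000 from the Settings defaults, and the requests dependency and cat.jpg file are assumptions.

```python
# Illustrative client call against the prediction endpoint.
import base64

import requests

with open("cat.jpg", "rb") as f:
    payload = {
        "image": {
            "mediaType": "image/jpeg",
            "data": base64.b64encode(f.read()).decode("ascii"),
        }
    }

resp = requests.post("http://localhost:8000/predict", json=payload, timeout=30)
resp.raise_for_status()
print(resp.json())
# Shape of the response, per the removed docstring example:
# {"prediction": "tabby cat", "confidence": 0.8542, "model": "microsoft/resnet-18",
#  "predicted_label": 281, "mediaType": "image/jpeg"}
```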
app/core/app.py CHANGED
@@ -1,15 +1,5 @@
-"""
-FastAPI application factory and core infrastructure.
-
-This module consolidates all core application components:
-- Configuration management
-- Global service instance (dependency injection)
-- Application lifecycle (startup/shutdown)
-- FastAPI app creation
-
-By keeping everything in one place, we avoid the complexity of managing
-global variables across multiple modules.
-"""
+"""FastAPI application factory and core infrastructure."""
+
 import warnings
 from contextlib import asynccontextmanager
 from typing import AsyncGenerator, Optional
@@ -25,100 +15,45 @@ from app.api.routes import prediction
 
 
 class Settings(BaseSettings):
-    """
-    Application settings with environment variable support.
-
-    Settings can be overridden via environment variables or .env file.
-    """
-    # Basic app settings
-    app_name: str = Field(default="ML Inference Service", description="Application name")
-    app_version: str = Field(default="0.1.0", description="Application version")
-    debug: bool = Field(default=False, description="Debug mode")
+    """Application settings. Override via environment variables or .env file."""
 
-    # Server settings
-    host: str = Field(default="0.0.0.0", description="Server host")
-    port: int = Field(default=8000, description="Server port")
+    app_name: str = Field(default="ML Inference Service")
+    app_version: str = Field(default="0.1.0")
+    debug: bool = Field(default=False)
+    host: str = Field(default="0.0.0.0")
+    port: int = Field(default=8000)
 
     class Config:
-        """Load from .env file if it exists."""
         env_file = ".env"
 
 
-# Global settings instance
 settings = Settings()
 
-
-# Global inference service instance (initialized during startup)
 _inference_service: Optional[InferenceService] = None
 
 
 def get_inference_service() -> Optional[InferenceService]:
-    """
-    Get the inference service instance for dependency injection.
-
-    This function is used in FastAPI route handlers via Depends().
-    The service is initialized once during app startup and reused
-    for all requests.
-
-    Returns:
-        The initialized inference service, or None if not yet initialized.
-
-    Example:
-        ```python
-        @router.post("/predict")
-        async def predict(
-            request: ImageRequest,
-            service: InferenceService = Depends(get_inference_service)
-        ):
-            return await service.predict(request)
-        ```
-    """
+    """Get inference service for dependency injection."""
     return _inference_service
 
 
 def _set_inference_service(service: InferenceService) -> None:
-    """
-    INTERNAL: Set the global inference service instance.
-
-    Called during application startup to register the service.
-    This is marked as internal (prefixed with _) because it should
-    only be called from the lifespan handler below.
-
-    Args:
-        service: The initialized inference service instance.
-    """
+    """Set inference service. Called internally during startup."""
     global _inference_service
     _inference_service = service
 
 
 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
-    """
-    Application lifespan manager.
-
-    Handles startup and shutdown events for the FastAPI application.
-    During startup, it initializes and loads the inference service.
-
-    CUSTOMIZATION POINT FOR GRAD STUDENTS:
-    To use your own model, replace ResNetInferenceService below with
-    your implementation that subclasses InferenceService.
-
-    Example:
-        ```python
-        service = MyCustomService(model_name="my-org/my-model")
-        await service.load_model()
-        _set_inference_service(service)
-        ```
-    """
+    """Application lifecycle: startup/shutdown."""
     logger.info("Starting ML Inference Service...")
 
     try:
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=FutureWarning)
 
-            service = ResNetInferenceService(
-                model_name="microsoft/resnet-18"
-            )
+            # Replace ResNetInferenceService with your own implementation
+            service = ResNetInferenceService(model_name="microsoft/resnet-18")
             await service.load_model()
             _set_inference_service(service)
 
@@ -134,17 +69,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
 
 
 def create_app() -> FastAPI:
-    """
-    Create and configure the FastAPI application.
-
-    This is the main entry point for the application. It:
-    1. Creates a FastAPI instance with metadata from settings
-    2. Attaches the lifespan handler for startup/shutdown
-    3. Registers API routes
-
-    Returns:
-        Configured FastAPI application instance.
-    """
+    """Create and configure FastAPI application."""
     app = FastAPI(
         title=settings.app_name,
         description="ML inference service for image classification",
@@ -155,4 +80,4 @@ def create_app() -> FastAPI:
 
     app.include_router(prediction.router)
 
-    return app
+    return app
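Note (not part of the commit): with the module docstring trimmed, how the factory and settings are meant to be wired together is no longer spelled out in app.py itself. A minimal serving sketch, assuming uvicorn is installed and that no existing entry point in the repo already does this:

```python
# Illustrative entry point using the factory and settings defined above.
import uvicorn

from app.core.app import create_app, settings

app = create_app()

if __name__ == "__main__":
    # host/port default to 0.0.0.0:8000; override via environment variables
    # (e.g. PORT=9000) or a .env file, per the Settings class above.
    uvicorn.run(app, host=settings.host, port=settings.port)
```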
app/services/base.py CHANGED
@@ -1,135 +1,33 @@
-"""
-Abstract base class for ML inference services.
-
-This module defines the contract that all inference services must implement.
-Grad students should subclass `InferenceService` and implement the abstract methods
-to integrate their models with the serving infrastructure.
-"""
+"""Abstract base class for ML inference services."""
 
 from abc import ABC, abstractmethod
 from typing import Generic, TypeVar
-
 from pydantic import BaseModel
 
-
-# Type variables for request and response models
 TRequest = TypeVar('TRequest', bound=BaseModel)
 TResponse = TypeVar('TResponse', bound=BaseModel)
 
 
 class InferenceService(ABC, Generic[TRequest, TResponse]):
     """
-    Abstract base class for ML inference services.
-
-    This class defines the interface that all model serving implementations must follow.
-    By subclassing this and implementing the abstract methods, you can integrate any
-    ML model with the serving infrastructure.
-
-    Type Parameters:
-        TRequest: Pydantic model for input requests (e.g., ImageRequest, TextRequest)
-        TResponse: Pydantic model for prediction responses (e.g., PredictionResponse)
-
-    Example:
-        ```python
-        class MyModelService(InferenceService[MyRequest, MyResponse]):
-
-            async def load_model(self) -> None:
-                # Load your model here
-                self.model = torch.load("my_model.pt")
-                self._is_loaded = True
-
-            async def predict(self, request: MyRequest) -> MyResponse:
-                # Run inference
-                output = self.model(request.data)
-                return MyResponse(result=output)
-
-            @property
-            def is_loaded(self) -> bool:
-                return self._is_loaded
-        ```
+    Base class for inference services. Subclass this to integrate your model.
+
+    For CPU-intensive inference, offload work to a background thread using
+    asyncio.to_thread() to avoid blocking the event loop.
     """
 
     @abstractmethod
     async def load_model(self) -> None:
-        """
-        Load the model weights and any required processors/tokenizers.
-
-        This method is called once during application startup (in the lifespan handler).
-        Use this to:
-        - Load model weights from disk
-        - Initialize processors, tokenizers, or other preprocessing components
-        - Set up any required state
-        - Perform model warmup if needed
-
-        Raises:
-            FileNotFoundError: If model files don't exist
-            RuntimeError: If model loading fails
-        """
+        """Load model weights and processors. Called once at startup."""
         pass
 
     @abstractmethod
     async def predict(self, request: TRequest) -> TResponse:
-        """
-        Run inference on the input request and return a typed response.
-
-        This method is called for each prediction request. It should:
-        1. Extract input data from the request
-        2. Preprocess the input (if needed)
-        3. Run the model inference
-        4. Post-process the output
-        5. Return a Pydantic response model
-
-        Args:
-            request: Input request containing the data to predict on.
-                Type is specified by the TRequest type parameter.
-
-        Returns:
-            Typed Pydantic response model containing predictions.
-            Type is specified by the TResponse type parameter.
-
-        Raises:
-            ValueError: If input data is invalid
-            RuntimeError: If model inference fails
-
-        Important - Background Threading:
-            For CPU-intensive operations (like deep learning inference), you MUST
-            offload computation to a background thread to avoid blocking the event loop.
-
-            Pattern to follow:
-            ```python
-            import asyncio
-
-            def _predict_sync(self, request: TRequest) -> TResponse:
-                # Heavy CPU work here (PyTorch, TensorFlow, etc.)
-                result = self.model(data)
-                return TResponse(result=result)
-
-            async def predict(self, request: TRequest) -> TResponse:
-                # Offload to thread pool
-                return await asyncio.to_thread(self._predict_sync, request)
-            ```
-
-        Why this matters:
-            - Inference can take 1-3+ seconds and will freeze the server
-            - asyncio.to_thread() runs the work in a background thread
-            - The event loop stays responsive to handle other requests
-        """
+        """Run inference and return typed response."""
         pass
 
     @property
     @abstractmethod
     def is_loaded(self) -> bool:
-        """
-        Check if the model is loaded and ready for inference.
-
-        Returns:
-            True if model is loaded and ready, False otherwise.
-
-        Example:
-            ```python
-            @property
-            def is_loaded(self) -> bool:
-                return self.model is not None and self._is_loaded
-            ```
-        """
+        """Check if model is loaded and ready."""
        pass
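Note (not part of the commit): the deleted docstrings carried a worked subclass example and the asyncio.to_thread() guidance that the shortened class docstring now only summarizes. A self-contained sketch reconstructing that example; MyRequest, MyResponse, and MyModelService are placeholder names, and the lambda stands in for a real model.

```python
# Sketch of an InferenceService subclass, following the pattern from the
# removed docstrings: implement load_model(), predict(), and is_loaded, and
# push CPU-heavy work onto a background thread.
import asyncio

from pydantic import BaseModel

from app.services.base import InferenceService


class MyRequest(BaseModel):
    data: str


class MyResponse(BaseModel):
    result: str


class MyModelService(InferenceService[MyRequest, MyResponse]):
    def __init__(self) -> None:
        self.model = None
        self._is_loaded = False

    async def load_model(self) -> None:
        # Load real weights here (e.g. torch.load("my_model.pt")).
        self.model = lambda text: text.upper()  # stand-in for a real model
        self._is_loaded = True

    def _predict_sync(self, request: MyRequest) -> MyResponse:
        # CPU-intensive work runs here, off the event loop.
        return MyResponse(result=self.model(request.data))

    async def predict(self, request: MyRequest) -> MyResponse:
        return await asyncio.to_thread(self._predict_sync, request)

    @property
    def is_loaded(self) -> bool:
        return self._is_loaded


async def _demo() -> None:
    service = MyModelService()
    await service.load_model()
    print(await service.predict(MyRequest(data="hello")))  # result='HELLO'


if __name__ == "__main__":
    asyncio.run(_demo())
```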
app/services/inference.py CHANGED
@@ -1,16 +1,5 @@
-"""
-Inference service for ResNet image classification models.
-
-This module provides an EXAMPLE implementation of the InferenceService ABC.
-Grad students should use this as a reference when implementing their own model services.
-
-This example demonstrates:
-- How to load a HuggingFace transformer model
-- How to preprocess image inputs
-- How to return typed Pydantic responses
-- How to use background threading for CPU-intensive inference
-- Proper error handling and logging
-"""
+"""ResNet inference service implementation."""
+
 import os
 import base64
 import asyncio
@@ -25,173 +14,75 @@ from app.api.models import ImageRequest, PredictionResponse
 
 
 class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
-    """
-    EXAMPLE: ResNet inference service implementation.
-
-    This is a reference implementation showing how to integrate a HuggingFace
-    image classification model with the serving infrastructure.
-
-    To create your own service:
-    1. Subclass InferenceService[YourRequest, YourResponse]
-    2. Implement load_model() to load your model
-    3. Implement predict() to run inference and return typed response
-    4. Implement the is_loaded property
-
-    This service loads a ResNet-18 model for ImageNet classification.
-    """
+    """ResNet-18 inference service for image classification."""
 
     def __init__(self, model_name: str = "microsoft/resnet-18"):
-        """
-        Initialize the ResNet service.
-
-        Args:
-            model_name: Model identifier (e.g., "microsoft/resnet-18").
-                Model files must exist in models/{model_name}/ directory.
-                The full org/model structure is preserved.
-
-        Example:
-            For model_name="microsoft/resnet-18", expects files at:
-            models/microsoft/resnet-18/config.json
-            models/microsoft/resnet-18/pytorch_model.bin
-            etc.
-        """
         self.model_name = model_name
         self.model = None
         self.processor = None
         self._is_loaded = False
-
-        # Preserve full org/model path structure
         self.model_path = os.path.join("models", model_name)
-        logger.info(f"Initializing ResNet service with local model: {self.model_path}")
+        logger.info(f"Initializing ResNet service: {self.model_path}")
 
     async def load_model(self) -> None:
-        """
-        Load the ResNet model and processor.
-
-        This method loads the model once during startup and reuses it for all requests.
-        Called by the application lifespan handler.
-        """
         if self._is_loaded:
-            logger.debug("Model already loaded, skipping...")
             return
 
-        try:
-            if not os.path.exists(self.model_path):
-                raise FileNotFoundError(
-                    f"Model directory not found: {self.model_path}\n"
-                    f"Make sure the model files are downloaded to the correct location."
-                )
-
-            config_path = os.path.join(self.model_path, "config.json")
-            if not os.path.exists(config_path):
-                raise FileNotFoundError(f"Model config not found: {config_path}")
-
-            logger.info(f"Loading ResNet model from: {self.model_path}")
-
-            # Suppress warnings during model loading
-            import warnings
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", category=FutureWarning)
-                warnings.filterwarnings("ignore", message="Could not find image processor class")
-
-                self.processor = AutoImageProcessor.from_pretrained(
-                    self.model_path,
-                    local_files_only=True
-                )
-                self.model = ResNetForImageClassification.from_pretrained(
-                    self.model_path,
-                    local_files_only=True
-                )
-
-            self._is_loaded = True
-            logger.info("ResNet model loaded successfully")
-            logger.info(f"Model architecture: {self.model.config.architectures}")
-            logger.info(f"Model has {len(self.model.config.id2label)} classes")
-
-        except Exception as e:
-            logger.error(f"Failed to load ResNet model: {e}")
-            logger.error(f"Hint: Ensure model files exist at: {self.model_path}")
-            raise
-
-
-    def _predict_sync(self, request: ImageRequest) -> PredictionResponse:
-        """
-        INTERNAL: Synchronous prediction logic that runs in a background thread.
-
-        This method contains all CPU-intensive operations (image decoding,
-        preprocessing, PyTorch inference). It's called from predict() via
-        asyncio.to_thread() to avoid blocking the event loop.
-
-        Args:
-            request: ImageRequest containing base64-encoded image data
-
-        Returns:
-            PredictionResponse with prediction, confidence, and metadata
-
-        Raises:
-            ValueError: If image decoding or processing fails
-        """
-        try:
-            logger.debug("Starting ResNet inference in background thread")
-
-            image_data = base64.b64decode(request.image.data)
-            image = Image.open(BytesIO(image_data))
-
-            if image.mode != 'RGB':
-                logger.debug(f"Converting image from {image.mode} to RGB")
-                image = image.convert('RGB')
-
-            inputs = self.processor(image, return_tensors="pt")
-
-            with torch.no_grad():
-                logits = self.model(**inputs).logits
-
-            predicted_label = logits.argmax(-1).item()
-            predicted_class = self.model.config.id2label[predicted_label]
-
-            probabilities = torch.nn.functional.softmax(logits, dim=-1)
-            confidence = probabilities[0][predicted_label].item()
-
-            logger.debug(f"Inference completed: {predicted_class} (confidence: {confidence:.4f})")
-
-            return PredictionResponse(
-                prediction=predicted_class,
-                confidence=round(confidence, 4),
-                model=self.model_name,
-                predicted_label=predicted_label,
-                mediaType=request.image.mediaType
-            )
-
-        except Exception as e:
-            logger.error(f"Inference failed: {e}")
-            raise ValueError(f"Failed to process image: {str(e)}")
+        if not os.path.exists(self.model_path):
+            raise FileNotFoundError(f"Model not found: {self.model_path}")
+
+        config_path = os.path.join(self.model_path, "config.json")
+        if not os.path.exists(config_path):
+            raise FileNotFoundError(f"Config not found: {config_path}")
+
+        logger.info(f"Loading model from {self.model_path}")
+
+        import warnings
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=FutureWarning)
+            self.processor = AutoImageProcessor.from_pretrained(
+                self.model_path, local_files_only=True
+            )
+            self.model = ResNetForImageClassification.from_pretrained(
+                self.model_path, local_files_only=True
+            )
+
+        self._is_loaded = True
+        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")
+
+    def _predict_sync(self, request: ImageRequest) -> PredictionResponse:
+        """Synchronous inference logic. Runs in background thread."""
+        image_data = base64.b64decode(request.image.data)
+        image = Image.open(BytesIO(image_data))
+
+        if image.mode != 'RGB':
+            image = image.convert('RGB')
+
+        inputs = self.processor(image, return_tensors="pt")
+
+        with torch.no_grad():
+            logits = self.model(**inputs).logits
+
+        predicted_label = logits.argmax(-1).item()
+        predicted_class = self.model.config.id2label[predicted_label]
+        probabilities = torch.nn.functional.softmax(logits, dim=-1)
+        confidence = probabilities[0][predicted_label].item()
+
+        return PredictionResponse(
+            prediction=predicted_class,
+            confidence=round(confidence, 4),
+            model=self.model_name,
+            predicted_label=predicted_label,
+            mediaType=request.image.mediaType
+        )
 
     async def predict(self, request: ImageRequest) -> PredictionResponse:
-        """
-        Perform inference on an image request.
-
-        This method demonstrates proper async handling for CPU-intensive operations.
-        The actual inference work is offloaded to a background thread using
-        asyncio.to_thread(), which prevents blocking the event loop.
-
-        Args:
-            request: ImageRequest containing base64-encoded image data
-
-        Returns:
-            PredictionResponse with prediction, confidence, and metadata
-
-        Raises:
-            RuntimeError: If model is not loaded
-            ValueError: If image decoding or processing fails
-        """
+        """Run inference with background threading to avoid blocking event loop."""
         if not self._is_loaded:
-            logger.warning("Model not loaded, loading now...")
             await self.load_model()
 
-        response = await asyncio.to_thread(self._predict_sync, request)
-        return response
+        return await asyncio.to_thread(self._predict_sync, request)
 
     @property
     def is_loaded(self) -> bool:
-        """Check if model is loaded."""
         return self._is_loaded
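Note (not part of the commit): a quick offline smoke test for the trimmed service. It assumes model files are already present under models/microsoft/resnet-18/, that sample.jpg is any local image, and that ImageRequest accepts a nested dict for its image field.

```python
# Illustrative smoke test: load the local ResNet-18 and classify one image.
import asyncio
import base64

from app.api.models import ImageRequest
from app.services.inference import ResNetInferenceService


async def main() -> None:
    service = ResNetInferenceService(model_name="microsoft/resnet-18")
    await service.load_model()

    with open("sample.jpg", "rb") as f:
        request = ImageRequest(image={
            "mediaType": "image/jpeg",
            "data": base64.b64encode(f.read()).decode("ascii"),
        })

    response = await service.predict(request)
    print(response.prediction, response.confidence)


if __name__ == "__main__":
    asyncio.run(main())
```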