sachin sharma committed
Commit d481329 · 1 parent: eff5773

fix for async methods as per PR
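The change inverts the threading responsibility: the service interface becomes synchronous, and async callers offload the blocking calls with asyncio.to_thread(). A minimal sketch of the pattern, with a hypothetical SlowService standing in for the real inference service:

import asyncio


class SlowService:
    def predict(self, text: str) -> str:
        # Blocking, CPU-bound work; awaiting it directly would stall the event loop.
        return text.upper()


async def handler() -> str:
    service = SlowService()
    # asyncio.to_thread() runs the sync method on a worker thread,
    # keeping the event loop free to serve other requests.
    return await asyncio.to_thread(service.predict, "hello")


print(asyncio.run(handler()))  # HELLO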

app/api/controllers.py CHANGED
@@ -1,5 +1,6 @@
 """API controllers for request handling and validation."""
 
+import asyncio
 from fastapi import HTTPException
 
 from app.core.logging import logger
@@ -23,7 +24,7 @@ class PredictionController:
             if not request.image.mediaType.startswith('image/'):
                 raise HTTPException(400, f"Invalid media type: {request.image.mediaType}")
 
-            return await service.predict(request)
+            return await asyncio.to_thread(service.predict, request)
 
         except HTTPException:
             raise
app/core/app.py CHANGED
@@ -1,5 +1,6 @@
 """FastAPI application factory and core infrastructure."""
 
+import asyncio
 import warnings
 from contextlib import asynccontextmanager
 from typing import AsyncGenerator, Optional
@@ -41,7 +42,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
 
     # Replace ResNetInferenceService with your own implementation
     service = ResNetInferenceService(model_name="microsoft/resnet-18")
-    await service.load_model()
+    await asyncio.to_thread(service.load_model)
    set_inference_service(service)
 
     logger.info("Startup completed successfully")
app/services/base.py CHANGED
@@ -11,18 +11,15 @@ TResponse = TypeVar('TResponse', bound=BaseModel)
 class InferenceService(ABC, Generic[TRequest, TResponse]):
     """
     Base class for inference services. Subclass this to integrate your model.
-
-    For CPU-intensive inference, offload work to a background thread using
-    asyncio.to_thread() to avoid blocking the event loop.
     """
 
     @abstractmethod
-    async def load_model(self) -> None:
+    def load_model(self) -> None:
         """Load model weights and processors. Called once at startup."""
         pass
 
     @abstractmethod
-    async def predict(self, request: TRequest) -> TResponse:
+    def predict(self, request: TRequest) -> TResponse:
         """Run inference and return typed response."""
         pass
 
app/services/inference.py CHANGED
@@ -2,7 +2,6 @@
 
 import os
 import base64
-import asyncio
 from io import BytesIO
 import torch
 from PIL import Image
@@ -24,7 +23,7 @@ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
         self.model_path = os.path.join("models", model_name)
         logger.info(f"Initializing ResNet service: {self.model_path}")
 
-    async def load_model(self) -> None:
+    def load_model(self) -> None:
         if self._is_loaded:
             return
 
@@ -50,8 +49,7 @@ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
         self._is_loaded = True
         logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")
 
-    def _predict_sync(self, request: ImageRequest) -> PredictionResponse:
-        """Synchronous inference logic. Runs in background thread."""
+    def predict(self, request: ImageRequest) -> PredictionResponse:
         image_data = base64.b64decode(request.image.data)
         image = Image.open(BytesIO(image_data))
 
@@ -76,13 +74,6 @@ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
             mediaType=request.image.mediaType
         )
 
-    async def predict(self, request: ImageRequest) -> PredictionResponse:
-        """Run inference with background threading to avoid blocking event loop."""
-        if not self._is_loaded:
-            await self.load_model()
-
-        return await asyncio.to_thread(self._predict_sync, request)
-
     @property
     def is_loaded(self) -> bool:
         return self._is_loaded
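With the base class now synchronous, a custom service needs no async code at all; blocking is handled at the call sites shown above. A minimal sketch of a conforming subclass (EchoRequest, EchoResponse, and EchoService are hypothetical examples, not part of this repo):

from pydantic import BaseModel

from app.services.base import InferenceService


class EchoRequest(BaseModel):
    text: str


class EchoResponse(BaseModel):
    text: str


class EchoService(InferenceService[EchoRequest, EchoResponse]):
    def load_model(self) -> None:
        # Nothing to load for this toy service.
        pass

    def predict(self, request: EchoRequest) -> EchoResponse:
        # Plain synchronous method; the controller wraps it in asyncio.to_thread().
        return EchoResponse(text=request.text)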