Commit 789798d · 1 Parent(s): e9a47ca
Committed by cc-dsri

Add GPU support for inference service

Files changed (2):
  1. app/services/inference.py +4 -1
  2. challenge-cli.py +18 -0
app/services/inference.py CHANGED
@@ -23,7 +23,8 @@ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse])
         self.processor = None
         self._is_loaded = False
         self.model_path = os.path.join("models", model_name)
-        logger.info(f"Initializing ResNet service: {self.model_path}")
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Initializing ResNet service: {self.model_path} on {self.device}")
 
     def load_model(self) -> None:
         if self._is_loaded:
@@ -48,6 +49,7 @@ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse])
             self.model_path, local_files_only=True
         )
         assert self.model is not None
+        self.model.to(self.device)
 
         self._is_loaded = True
         logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")  # pyright: ignore
@@ -65,6 +67,7 @@ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse])
         image = image.convert('RGB')
 
         inputs = self.processor(image, return_tensors="pt")
+        inputs = inputs.to(self.device)
 
         with torch.no_grad():
             logits = self.model(**inputs).logits.squeeze()  # pyright: ignore
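
The service-side change follows the standard single-device PyTorch pattern: select the device once in __init__, move the model weights once in load_model, and move each request's input tensors to the same device before the forward pass. A minimal standalone sketch of that pattern, assuming the torch and transformers packages and using microsoft/resnet-50 as a stand-in checkpoint (the service itself loads from a local models/ path):

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

device = "cuda" if torch.cuda.is_available() else "cpu"  # same selection as the diff

processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50")
model.to(device)   # move weights once, at load time
model.eval()       # inference mode: disable dropout etc.

def predict(image: Image.Image) -> str:
    # Inputs must live on the same device as the weights, or the forward pass raises.
    inputs = processor(image.convert("RGB"), return_tensors="pt").to(device)
    with torch.no_grad():  # no autograd bookkeeping during inference
        logits = model(**inputs).logits.squeeze()
    return model.config.id2label[int(logits.argmax())]

The single .to(device) on the processor output works because it returns a BatchFeature, whose .to() moves every tensor it contains; that is why the diff can relocate all inputs with one call.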
challenge-cli.py CHANGED
@@ -114,6 +114,12 @@ def cli():
     " You can pass the model.id from a previous invocation.",
     metavar="ID",
 )
+@click.option(
+    "--gpu",
+    is_flag=True,
+    default=False,
+    help="Request a GPU (NVIDIA L4) for the inference service.",
+)
 def upload_submission(
     account: str,
     name: str,
@@ -123,6 +129,7 @@
     volume_mount: Path | None,
     artifact_id: str | None,
     model_id: str | None,
+    gpu: bool,
 ) -> None:
     dyffapi = Client()
 
@@ -208,6 +215,16 @@
     else:
         volumeMounts = None
 
+    accelerator: Accelerator | None = None
+    if gpu:
+        accelerator = Accelerator(
+            kind="GPU",
+            gpu=AcceleratorGPU(
+                hardwareTypes=["nvidia-l4"],
+                count=1,
+            ),
+        )
+
     # Don't change this
     service_request = InferenceServiceCreateRequest(
         account=account,
@@ -218,6 +235,7 @@
             imageRef=EntityIdentifier.of(artifact),
             resources=ModelResources(),
             volumeMounts=volumeMounts,
+            accelerator=accelerator,
         ),
         interface=InferenceInterface(
             endpoint=endpoint,
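
On the CLI side, the new flag flows from the click option through the upload_submission signature into an optional Accelerator spec: with --gpu absent, accelerator stays None and the service request is unchanged. The diff adds no import hunk, so Accelerator and AcceleratorGPU are presumably already imported in challenge-cli.py from the Dyff schema package, alongside InferenceServiceCreateRequest and ModelResources. A hypothetical invocation, assuming click derives the subcommand name from the decorated function and with the remaining required options elided:

python challenge-cli.py upload-submission --gpu ...

Requesting hardwareTypes=["nvidia-l4"] with count=1 asks for exactly one NVIDIA L4; the plural field suggests that listing several hardware types would widen the set of acceptable GPUs rather than request more of them.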