Add GPU support for inference service
- app/services/inference.py +4 -1
- challenge-cli.py +18 -0
app/services/inference.py CHANGED

@@ -23,7 +23,8 @@ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse])
         self.processor = None
         self._is_loaded = False
         self.model_path = os.path.join("models", model_name)
-
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Initializing ResNet service: {self.model_path} on {self.device}")
 
     def load_model(self) -> None:
         if self._is_loaded:
@@ -48,6 +49,7 @@ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse])
             self.model_path, local_files_only=True
         )
         assert self.model is not None
+        self.model.to(self.device)
 
         self._is_loaded = True
         logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")  # pyright: ignore
@@ -65,6 +67,7 @@ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse])
         image = image.convert('RGB')
 
         inputs = self.processor(image, return_tensors="pt")
+        inputs = inputs.to(self.device)
 
         with torch.no_grad():
             logits = self.model(**inputs).logits.squeeze()  # pyright: ignore
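Taken together, the three hunks follow the usual PyTorch device-placement pattern: choose the device once at construction time, move the model weights to it when the model is loaded, and move each request's input tensors to the same device before the forward pass. A minimal, self-contained sketch of that pattern follows; the microsoft/resnet-50 checkpoint and the image path are illustrative placeholders, not the repository's local model directory:

    # Sketch of the device-placement pattern used above; the checkpoint name and
    # image path are illustrative assumptions, not taken from the repository.
    import torch
    from PIL import Image
    from transformers import AutoImageProcessor, AutoModelForImageClassification

    device = "cuda" if torch.cuda.is_available() else "cpu"

    processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
    model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50")
    model.to(device)   # move the weights once, at load time
    model.eval()

    image = Image.open("example.jpg").convert("RGB")
    inputs = processor(image, return_tensors="pt").to(device)  # inputs must live on the same device as the model

    with torch.no_grad():
        logits = model(**inputs).logits.squeeze()

    print(model.config.id2label[int(logits.argmax(-1))])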
challenge-cli.py CHANGED

@@ -114,6 +114,12 @@ def cli():
     " You can pass the model.id from a previous invocation.",
     metavar="ID",
 )
+@click.option(
+    "--gpu",
+    is_flag=True,
+    default=False,
+    help="Request a GPU (NVIDIA L4) for the inference service.",
+)
 def upload_submission(
     account: str,
     name: str,
@@ -123,6 +129,7 @@ def upload_submission(
     volume_mount: Path | None,
     artifact_id: str | None,
     model_id: str | None,
+    gpu: bool,
 ) -> None:
     dyffapi = Client()
 
@@ -208,6 +215,16 @@ def upload_submission(
     else:
         volumeMounts = None
 
+    accelerator: Accelerator | None = None
+    if gpu:
+        accelerator = Accelerator(
+            kind="GPU",
+            gpu=AcceleratorGPU(
+                hardwareTypes=["nvidia-l4"],
+                count=1,
+            ),
+        )
+
     # Don't change this
     service_request = InferenceServiceCreateRequest(
         account=account,
@@ -218,6 +235,7 @@ def upload_submission(
             imageRef=EntityIdentifier.of(artifact),
             resources=ModelResources(),
             volumeMounts=volumeMounts,
+            accelerator=accelerator,
         ),
         interface=InferenceInterface(
             endpoint=endpoint,
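The --gpu flag defaults to False, so existing invocations keep requesting a CPU-only service; only when the flag is passed does the request attach an Accelerator asking for a single nvidia-l4 device, and the inference service then selects it automatically via torch.cuda.is_available().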