Spaces:
Paused
Paused
doublecheck inference_state["device"]
Browse files
app.py
CHANGED
|
@@ -257,7 +257,8 @@ def segment_with_points(
|
|
| 257 |
evt: gr.SelectData,
|
| 258 |
):
|
| 259 |
predictor.to("cpu")
|
| 260 |
-
inference_state
|
|
|
|
| 261 |
input_points.append(evt.index)
|
| 262 |
print(f"TRACKING INPUT POINT: {input_points}")
|
| 263 |
|
|
@@ -344,7 +345,8 @@ def propagate_to_all(
|
|
| 344 |
torch.backends.cudnn.allow_tf32 = True
|
| 345 |
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 346 |
predictor.to("cuda")
|
| 347 |
-
inference_state
|
|
|
|
| 348 |
|
| 349 |
if len(input_points) == 0 or video_in is None or inference_state is None:
|
| 350 |
return None
|
|
|
|
| 257 |
evt: gr.SelectData,
|
| 258 |
):
|
| 259 |
predictor.to("cpu")
|
| 260 |
+
if inference_state:
|
| 261 |
+
inference_state["device"] = predictor.device
|
| 262 |
input_points.append(evt.index)
|
| 263 |
print(f"TRACKING INPUT POINT: {input_points}")
|
| 264 |
|
|
|
|
| 345 |
torch.backends.cudnn.allow_tf32 = True
|
| 346 |
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 347 |
predictor.to("cuda")
|
| 348 |
+
if inference_state:
|
| 349 |
+
inference_state["device"] = predictor.device
|
| 350 |
|
| 351 |
if len(input_points) == 0 or video_in is None or inference_state is None:
|
| 352 |
return None
|