remove predictor from state

app.py CHANGED
@@ -87,21 +87,45 @@ def get_video_fps(video_path):
     return fps
 
 
+def reset_state(inference_state):
+    for v in inference_state["point_inputs_per_obj"].values():
+        v.clear()
+    for v in inference_state["mask_inputs_per_obj"].values():
+        v.clear()
+    for v in inference_state["output_dict_per_obj"].values():
+        v["cond_frame_outputs"].clear()
+        v["non_cond_frame_outputs"].clear()
+    for v in inference_state["temp_output_dict_per_obj"].values():
+        v["cond_frame_outputs"].clear()
+        v["non_cond_frame_outputs"].clear()
+    inference_state["output_dict"]["cond_frame_outputs"].clear()
+    inference_state["output_dict"]["non_cond_frame_outputs"].clear()
+    inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
+    inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
+    inference_state["tracking_has_started"] = False
+    inference_state["frames_already_tracked"].clear()
+    inference_state["obj_id_to_idx"].clear()
+    inference_state["obj_idx_to_id"].clear()
+    inference_state["obj_ids"].clear()
+    inference_state["point_inputs_per_obj"].clear()
+    inference_state["mask_inputs_per_obj"].clear()
+    inference_state["output_dict_per_obj"].clear()
+    inference_state["temp_output_dict_per_obj"].clear()
+    return inference_state
+
+
 def reset(
     first_frame,
     all_frames,
     input_points,
     input_labels,
     inference_state,
-    predictor,
 ):
     first_frame = None
     all_frames = None
     input_points = []
     input_labels = []
-    if inference_state and predictor:
-        predictor.reset_state(inference_state)
     inference_state = None
     return (
         None,
@@ -114,7 +138,6 @@ def reset(
         input_points,
         input_labels,
         inference_state,
-        predictor,
     )
 
 
@@ -124,12 +147,11 @@ def clear_points(
     input_points,
     input_labels,
     inference_state,
-    predictor,
 ):
     input_points = []
     input_labels = []
-    if inference_state and predictor:
-        predictor.reset_state(inference_state)
+    if inference_state and inference_state["tracking_has_started"]:
+        inference_state = reset_state(inference_state)
     return (
         first_frame,
         None,
@@ -139,7 +161,6 @@ def clear_points(
         input_points,
         input_labels,
         inference_state,
-        predictor,
     )
 
 
@@ -150,7 +171,6 @@ def preprocess_video_in(
     input_points,
     input_labels,
     inference_state,
-    predictor,
 ):
     if video_path is None:
         return (
@@ -163,7 +183,6 @@ def preprocess_video_in(
             input_points,
             input_labels,
             inference_state,
-            predictor,
         )
 
     # Read the first frame
@@ -180,12 +199,8 @@ def preprocess_video_in(
             input_points,
             input_labels,
             inference_state,
-            predictor,
         )
 
-    if predictor is None:
-        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-
     frame_number = 0
     _first_frame = None
     all_frames = []
@@ -207,10 +222,19 @@ def preprocess_video_in(
 
     cap.release()
     first_frame = copy.deepcopy(_first_frame)
-    inference_state = predictor.init_state(video_path=video_path)
     input_points = []
     input_labels = []
 
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+    if torch.cuda.is_available():
+        predictor.to("cuda")
+        inference_state["device"] = "cuda"
+        if torch.cuda.get_device_properties(0).major >= 8:
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+        torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+    inference_state = predictor.init_state(video_path=video_path)
+
     return [
         gr.update(open=False),  # video_in_drawer
         first_frame,  # points_map
@@ -221,7 +245,6 @@ def preprocess_video_in(
        input_points,
        input_labels,
        inference_state,
-        predictor,
    ]
 
 
@@ -232,9 +255,9 @@ def segment_with_points(
     input_points,
     input_labels,
     inference_state,
-    predictor,
     evt: gr.SelectData,
 ):
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
     if torch.cuda.is_available():
         predictor.to("cuda")
         inference_state["device"] = "cuda"
@@ -299,7 +322,6 @@ def segment_with_points(
         input_points,
         input_labels,
         inference_state,
-        predictor,
     )
 
 
@@ -325,8 +347,8 @@ def propagate_to_all(
     input_points,
     input_labels,
     inference_state,
-    predictor,
 ):
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
     if torch.cuda.is_available():
         predictor.to("cuda")
         inference_state["device"] = "cuda"
@@ -383,15 +405,15 @@ def propagate_to_all(
         input_points,
         input_labels,
         inference_state,
-        predictor,
     )
 
 
 try:
     from spaces import GPU
 
-
-
+    preprocess_video_in = GPU(preprocess_video_in, duration=10)
+    segment_with_points = GPU(segment_with_points, duration=5)
+    propagate_to_all = GPU(propagate_to_all, duration=30)
 except:
     print("spaces unavailable")
 
@@ -406,7 +428,6 @@ with gr.Blocks() as demo:
     input_points = gr.State([])
     input_labels = gr.State([])
     inference_state = gr.State()
-    predictor = gr.State()
 
     with gr.Column():
         # Title
@@ -461,7 +482,6 @@ with gr.Blocks() as demo:
             input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        outputs=[
            video_in_drawer,  # Accordion to hide uploaded video player
@@ -473,7 +493,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        queue=False,
    )
@@ -487,7 +506,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        outputs=[
            video_in_drawer,  # Accordion to hide uploaded video player
@@ -499,7 +517,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        queue=False,
    )
@@ -514,7 +531,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        outputs=[
            points_map,  # updated image with points
@@ -524,7 +540,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        queue=False,
    )
@@ -538,7 +553,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        outputs=[
            points_map,
@@ -549,7 +563,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        queue=False,
    )
@@ -562,7 +575,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        outputs=[
            video_in,
@@ -575,7 +587,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        queue=False,
    )
@@ -594,7 +605,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        outputs=[
            output_video,
@@ -603,7 +613,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        concurrency_limit=10,
        queue=False,