carefully move between CPU and GPU
- app.py +6 -11
- sam2/modeling/sam2_base.py +2 -2
- sam2/sam2_video_predictor.py +20 -12
app.py
CHANGED
@@ -165,7 +165,6 @@ def clear_points(
     )
 
 
-@spaces.GPU(duration=10)
 def preprocess_video_in(
     video_path,
     first_frame,
@@ -227,16 +226,12 @@ def preprocess_video_in(
     input_points = []
     input_labels = []
 
-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="…
-    …
-        offload_video_to_cpu=True,
-        offload_state_to_cpu=True,
-        video_path=video_path,
-    )
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+    inference_state = predictor.init_state(
+        offload_video_to_cpu=True,
+        offload_state_to_cpu=True,
+        video_path=video_path,
+    )
 
     return [
         gr.update(open=False), # video_in_drawer
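
The net effect of this change is that `preprocess_video_in` no longer claims a GPU: the `@spaces.GPU(duration=10)` decorator is dropped, the predictor is built on CPU, and both the decoded frames and the tracking state stay in CPU memory. A minimal sketch of that pattern, assuming the stock `build_sam2_video_predictor` entry point and hypothetical config, checkpoint, and video paths:

from sam2.build_sam import build_sam2_video_predictor

# Hypothetical paths; the Space defines its own model_cfg / sam2_checkpoint.
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"

# Build the predictor on CPU so preprocessing can run outside a @spaces.GPU context.
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")

# Keep the decoded frames and the per-frame tracking state in CPU memory;
# tensors are moved to the GPU only inside the functions that run inference.
inference_state = predictor.init_state(
    video_path="example.mp4",  # hypothetical input video
    offload_video_to_cpu=True,
    offload_state_to_cpu=True,
)
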
sam2/modeling/sam2_base.py
CHANGED
@@ -617,7 +617,7 @@ class SAM2Base(torch.nn.Module):
                             if self.use_signed_tpos_enc_to_obj_ptrs
                             else abs(frame_idx - t)
                         ),
-                        out["obj_ptr"],
+                        out["obj_ptr"].to(device),
                     )
                     for t, out in ptr_cond_outputs.items()
                 ]
@@ -630,7 +630,7 @@ class SAM2Base(torch.nn.Module):
                         t, unselected_cond_outputs.get(t, None)
                     )
                     if out is not None:
-                        pos_and_ptrs.append((t_diff, out["obj_ptr"]))
+                        pos_and_ptrs.append((t_diff, out["obj_ptr"].to(device)))
                 # If we have at least one object pointer, add them to the across attention
                 if len(pos_and_ptrs) > 0:
                     pos_list, ptrs_list = zip(*pos_and_ptrs)
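
With the state offloaded, the per-frame `obj_ptr` tensors live on the CPU storage device while memory attention runs on the GPU, so each pointer is moved with `.to(device)` as it is collected. A small self-contained sketch of that gather-then-move step (shapes and frame indices are made up for illustration):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretend these are per-frame outputs whose pointers were stored on CPU.
ptr_cond_outputs = {t: {"obj_ptr": torch.randn(1, 256)} for t in range(3)}

frame_idx = 5
# Move each stored pointer onto the compute device while collecting it,
# mirroring the out["obj_ptr"].to(device) calls in the hunks above.
pos_and_ptrs = [
    (abs(frame_idx - t), out["obj_ptr"].to(device))
    for t, out in ptr_cond_outputs.items()
]

pos_list, ptrs_list = zip(*pos_and_ptrs)
obj_ptrs = torch.stack(ptrs_list, dim=0)  # safe: every pointer is on one device
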
sam2/sam2_video_predictor.py
CHANGED
@@ -107,7 +107,7 @@ class SAM2VideoPredictor(SAM2Base):
         inference_state["tracking_has_started"] = False
         inference_state["frames_already_tracked"] = {}
         # Warm up the visual backbone and cache the image feature on frame 0
-        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
+        # self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
         return inference_state
 
     @classmethod
@@ -470,7 +470,7 @@ class SAM2VideoPredictor(SAM2Base):
                 size=(batch_size, self.hidden_dim),
                 fill_value=NO_OBJ_SCORE,
                 dtype=torch.float32,
-                device=inference_state["device"],
+                device=inference_state["storage_device"],
             ),
             "object_score_logits": torch.full(
                 size=(batch_size, 1),
@@ -478,7 +478,7 @@ class SAM2VideoPredictor(SAM2Base):
                 # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
                 fill_value=10.0,
                 dtype=torch.float32,
-                device=inference_state["device"],
+                device=inference_state["storage_device"],
             ),
         }
         empty_mask_ptr = None
@@ -545,7 +545,9 @@ class SAM2VideoPredictor(SAM2Base):
                 frame_idx=frame_idx,
                 batch_size=batch_size,
                 high_res_masks=high_res_masks,
-                object_score_logits=consolidated_out["object_score_logits"],
+                object_score_logits=consolidated_out["object_score_logits"].to(
+                    device, non_blocking=True
+                ),
                 is_mask_from_pts=True, # these frames are what the user interacted with
             )
             consolidated_out["maskmem_features"] = maskmem_features
@@ -879,9 +881,10 @@ class SAM2VideoPredictor(SAM2Base):
     def _get_image_feature(self, inference_state, frame_idx, batch_size):
         """Compute the image features on a given frame."""
         # Look up in the cache first
-        image, backbone_out = inference_state["cached_features"].get(
-            frame_idx, (None, None)
-        )
+        # image, backbone_out = inference_state["cached_features"].get(
+        #     frame_idx, (None, None)
+        # )
+        image, backbone_out = None, None
         if backbone_out is None:
             # Cache miss -- we will run inference on a single image
             device = inference_state["device"]
@@ -889,7 +892,7 @@ class SAM2VideoPredictor(SAM2Base):
             backbone_out = self.forward_image(image)
             # Cache the most recent frame's feature (for repeated interactions with
             # a frame; we can use an LRU cache for more frames in the future).
-            inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
+            # inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
 
         # expand the features to have the same dimension as the number of objects
         expanded_image = image.expand(batch_size, -1, -1, -1)
@@ -964,9 +967,11 @@ class SAM2VideoPredictor(SAM2Base):
         pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
         # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
         maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
-        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
-        obj_ptr = current_out["obj_ptr"]
-        object_score_logits = current_out["object_score_logits"]
+        # object pointer is a small tensor, so we always keep it on GPU memory for fast access (modified for ZeroGPU)
+        obj_ptr = current_out["obj_ptr"].to(storage_device, non_blocking=True)
+        object_score_logits = current_out["object_score_logits"].to(
+            storage_device, non_blocking=True
+        )
         # make a compact version of this frame's output to reduce the state size
         compact_current_out = {
             "maskmem_features": maskmem_features,
@@ -1018,6 +1023,7 @@ class SAM2VideoPredictor(SAM2Base):
         `maskmem_pos_enc` is the same across frames and objects, so we cache it as
         a constant in the inference session to reduce session storage size.
         """
+        storage_device = inference_state["storage_device"]
         model_constants = inference_state["constants"]
         # "out_maskmem_pos_enc" should be either a list of tensors or None
         out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
@@ -1026,7 +1032,9 @@ class SAM2VideoPredictor(SAM2Base):
             assert isinstance(out_maskmem_pos_enc, list)
             # only take the slice for one object, since it's same across objects
             maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
-            model_constants["maskmem_pos_enc"] = maskmem_pos_enc
+            model_constants["maskmem_pos_enc"] = maskmem_pos_enc.to(
+                storage_device, non_blocking=True
+            )
         else:
             maskmem_pos_enc = model_constants["maskmem_pos_enc"]
             # expand the cached maskmem_pos_enc to the actual batch size
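
Taken together, the predictor changes follow one rule: large or per-frame outputs are parked on `inference_state["storage_device"]` (CPU when `offload_state_to_cpu=True`) and are moved back to the compute device only at the point of use. A rough sketch of that storage-device pattern, with hypothetical helper names and dummy tensors standing in for real model outputs:

import torch

def park_on_storage(inference_state, current_out):
    """Sketch: keep a frame's outputs on the storage device (CPU when offloaded)."""
    storage_device = inference_state["storage_device"]
    return {
        # non_blocking lets the copy overlap with other work when possible
        key: value.to(storage_device, non_blocking=True)
        for key, value in current_out.items()
    }

def fetch_for_compute(inference_state, stored_out):
    """Sketch: bring stored tensors back to the compute device right before use."""
    device = inference_state["device"]
    return {key: value.to(device, non_blocking=True) for key, value in stored_out.items()}

# Example usage (both devices set to CPU so the sketch runs anywhere).
state = {"device": torch.device("cpu"), "storage_device": torch.device("cpu")}
frame_out = {
    "pred_masks": torch.zeros(1, 1, 256, 256),
    "obj_ptr": torch.zeros(1, 256),
    "object_score_logits": torch.zeros(1, 1),
}
stored = park_on_storage(state, frame_out)
ready = fetch_for_compute(state, stored)
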