# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import logging
import os
import uuid
from pathlib import Path
from threading import Lock
from typing import Any, Dict, Generator, List

import numpy as np
import torch
from app_conf import APP_ROOT, MODEL_SIZE
from inference.data_types import (
    AddMaskRequest,
    AddPointsRequest,
    CancelPorpagateResponse,  # (sic) spelled as defined in inference.data_types
    CancelPropagateInVideoRequest,
    ClearPointsInFrameRequest,
    ClearPointsInVideoRequest,
    ClearPointsInVideoResponse,
    CloseSessionRequest,
    CloseSessionResponse,
    Mask,
    PropagateDataResponse,
    PropagateDataValue,
    PropagateInVideoRequest,
    RemoveObjectRequest,
    RemoveObjectResponse,
    StartSessionRequest,
    StartSessionResponse,
)
from pycocotools.mask import decode as decode_masks, encode as encode_masks
from sam2.build_sam import build_sam2_video_predictor

logger = logging.getLogger(__name__)


class InferenceAPI:
    def __init__(self) -> None:
        super().__init__()

        self.session_states: Dict[str, Any] = {}
        self.score_thresh = 0

        if MODEL_SIZE == "tiny":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_tiny.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
        elif MODEL_SIZE == "small":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_small.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
        elif MODEL_SIZE == "large":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_large.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
        else:  # base_plus (default)
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_base_plus.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
        # select the device for computation
        force_cpu_device = os.environ.get("SAM2_DEMO_FORCE_CPU_DEVICE", "0") == "1"
        if force_cpu_device:
            logger.info("forcing CPU device for SAM 2 demo")
        if torch.cuda.is_available() and not force_cpu_device:
            device = torch.device("cuda")
        elif torch.backends.mps.is_available() and not force_cpu_device:
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        logger.info(f"using device: {device}")

        if device.type == "cuda":
            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
        elif device.type == "mps":
            logger.warning(
                "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
                "give numerically different outputs and sometimes degraded performance on MPS. "
                "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
            )
        self.device = device
        self.predictor = build_sam2_video_predictor(
            model_cfg, checkpoint, device=device
        )
        self.inference_lock = Lock()

    def autocast_context(self):
        """Return a bfloat16 autocast context on CUDA and a no-op context elsewhere."""
        if self.device.type == "cuda":
            return torch.autocast("cuda", dtype=torch.bfloat16)
        else:
            return contextlib.nullcontext()
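
    # The request handlers below generally run under the pattern
    #
    #     with self.autocast_context(), self.inference_lock:
    #         ...
    #
    # so that model calls are autocast to bfloat16 on CUDA and inference is
    # serialized across concurrent requests.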

    def start_session(self, request: StartSessionRequest) -> StartSessionResponse:
        with self.autocast_context(), self.inference_lock:
            session_id = str(uuid.uuid4())
            # for MPS devices, we offload the video frames to CPU by default to avoid
            # memory fragmentation in MPS (which sometimes crashes the entire process)
            offload_video_to_cpu = self.device.type == "mps"
            inference_state = self.predictor.init_state(
                request.path,
                offload_video_to_cpu=offload_video_to_cpu,
            )
            self.session_states[session_id] = {
                "canceled": False,
                "state": inference_state,
            }
            return StartSessionResponse(session_id=session_id)

    def close_session(self, request: CloseSessionRequest) -> CloseSessionResponse:
        is_successful = self.__clear_session_state(request.session_id)
        return CloseSessionResponse(success=is_successful)

    def add_points(self, request: AddPointsRequest) -> PropagateDataResponse:
        """
        Add new input points (clicks) for an object on a specific video frame and
        return the updated mask for that frame.
        """
        with self.autocast_context(), self.inference_lock:
            session = self.__get_session(request.session_id)
            inference_state = session["state"]

            frame_idx = request.frame_index
            obj_id = request.object_id
            points = request.points
            labels = request.labels
            clear_old_points = request.clear_old_points

            # add new prompts and instantly get the output on the same frame
            frame_idx, object_ids, masks = self.predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                points=points,
                labels=labels,
                clear_old_points=clear_old_points,
                normalize_coords=False,
            )

            # the returned masks are logits; thresholding at score_thresh (0)
            # binarizes them, and [:, 0] drops the singleton channel dimension
            masks_binary = (masks > self.score_thresh)[:, 0].cpu().numpy()

            rle_mask_list = self.__get_rle_mask_list(
                object_ids=object_ids, masks=masks_binary
            )

            return PropagateDataResponse(
                frame_index=frame_idx,
                results=rle_mask_list,
            )

    def add_mask(self, request: AddMaskRequest) -> PropagateDataResponse:
        """
        Add a new input mask for an object on a specific video frame.
        - mask is a numpy array of shape [H_im, W_im] (containing 1 for foreground and 0 for background).

        Note: providing an input mask overwrites any previous input points on this frame.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            frame_idx = request.frame_index
            obj_id = request.object_id
            rle_mask = {
                "counts": request.mask.counts,
                "size": request.mask.size,
            }

            mask = decode_masks(rle_mask)

            logger.info(
                f"add mask on frame {frame_idx} in session {session_id}: {obj_id=}, {mask.shape=}"
            )
            session = self.__get_session(session_id)
            inference_state = session["state"]

            frame_idx, obj_ids, video_res_masks = self.predictor.add_new_mask(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                mask=torch.tensor(mask > 0),
            )
            masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()

            rle_mask_list = self.__get_rle_mask_list(
                object_ids=obj_ids, masks=masks_binary
            )

            return PropagateDataResponse(
                frame_index=frame_idx,
                results=rle_mask_list,
            )
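
    # Input-mask construction sketch (illustrative): `binary_mask` is a
    # hypothetical [H_im, W_im] numpy array. pycocotools expects a Fortran-order
    # uint8 array, and `Mask.counts` carries the RLE as a UTF-8 string, mirroring
    # `__get_mask_for_object` below:
    #
    #     rle = encode_masks(np.asfortranarray(binary_mask.astype(np.uint8)))
    #     request_mask = Mask(size=rle["size"], counts=rle["counts"].decode())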

    def clear_points_in_frame(
        self, request: ClearPointsInFrameRequest
    ) -> PropagateDataResponse:
        """
        Remove all input points in a specific frame.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            frame_idx = request.frame_index
            obj_id = request.object_id

            logger.info(
                f"clear inputs on frame {frame_idx} in session {session_id}: {obj_id=}"
            )
            session = self.__get_session(session_id)
            inference_state = session["state"]

            frame_idx, obj_ids, video_res_masks = (
                self.predictor.clear_all_prompts_in_frame(
                    inference_state, frame_idx, obj_id
                )
            )

            masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()

            rle_mask_list = self.__get_rle_mask_list(
                object_ids=obj_ids, masks=masks_binary
            )

            return PropagateDataResponse(
                frame_index=frame_idx,
                results=rle_mask_list,
            )

    def clear_points_in_video(
        self, request: ClearPointsInVideoRequest
    ) -> ClearPointsInVideoResponse:
        """
        Remove all input points in all frames throughout the video.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            logger.info(f"clear all inputs across the video in session {session_id}")
            session = self.__get_session(session_id)
            inference_state = session["state"]
            self.predictor.reset_state(inference_state)
            return ClearPointsInVideoResponse(success=True)

    def remove_object(self, request: RemoveObjectRequest) -> RemoveObjectResponse:
        """
        Remove an object id from the tracking state.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            obj_id = request.object_id
            logger.info(f"remove object in session {session_id}: {obj_id=}")
            session = self.__get_session(session_id)
            inference_state = session["state"]
            new_obj_ids, updated_frames = self.predictor.remove_object(
                inference_state, obj_id
            )

            results = []
            for frame_index, video_res_masks in updated_frames:
                masks = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
                rle_mask_list = self.__get_rle_mask_list(
                    object_ids=new_obj_ids, masks=masks
                )
                results.append(
                    PropagateDataResponse(
                        frame_index=frame_index,
                        results=rle_mask_list,
                    )
                )

            return RemoveObjectResponse(results=results)

    def propagate_in_video(
        self, request: PropagateInVideoRequest
    ) -> Generator[PropagateDataResponse, None, None]:
        """
        Propagate existing input points in all frames to track the object across video.
        """
        session_id = request.session_id
        start_frame_idx = request.start_frame_index
        propagation_direction = "both"
        max_frame_num_to_track = None

        # Note that as this method is a generator, we also need to use autocast_context
        # in the caller of this method to ensure that it's called under the correct context
        # (we've added `autocast_context` to `gen_track_with_mask_stream` in app.py).
        with self.autocast_context(), self.inference_lock:
            logger.info(
                f"propagate in video in session {session_id}: "
                f"{propagation_direction=}, {start_frame_idx=}, {max_frame_num_to_track=}"
            )

            try:
                session = self.__get_session(session_id)
                session["canceled"] = False

                inference_state = session["state"]
                if propagation_direction not in ["both", "forward", "backward"]:
                    raise ValueError(
                        f"invalid propagation direction: {propagation_direction}"
                    )

                # First doing the forward propagation
                if propagation_direction in ["both", "forward"]:
                    for outputs in self.predictor.propagate_in_video(
                        inference_state=inference_state,
                        start_frame_idx=start_frame_idx,
                        max_frame_num_to_track=max_frame_num_to_track,
                        reverse=False,
                    ):
                        if session["canceled"]:
                            return None

                        frame_idx, obj_ids, video_res_masks = outputs
                        masks_binary = (
                            (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
                        )

                        rle_mask_list = self.__get_rle_mask_list(
                            object_ids=obj_ids, masks=masks_binary
                        )

                        yield PropagateDataResponse(
                            frame_index=frame_idx,
                            results=rle_mask_list,
                        )

                # Then doing the backward propagation (reverse in time)
                if propagation_direction in ["both", "backward"]:
                    for outputs in self.predictor.propagate_in_video(
                        inference_state=inference_state,
                        start_frame_idx=start_frame_idx,
                        max_frame_num_to_track=max_frame_num_to_track,
                        reverse=True,
                    ):
                        if session["canceled"]:
                            return None

                        frame_idx, obj_ids, video_res_masks = outputs
                        masks_binary = (
                            (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
                        )

                        rle_mask_list = self.__get_rle_mask_list(
                            object_ids=obj_ids, masks=masks_binary
                        )

                        yield PropagateDataResponse(
                            frame_index=frame_idx,
                            results=rle_mask_list,
                        )
            finally:
                # Log upon completion (so that e.g. we can see if two propagations happen in parallel).
                # Using `finally` here to log even when the tracking is aborted with GeneratorExit.
                logger.info(
                    f"propagation ended in session {session_id}; {self.__get_session_stats()}"
                )
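
    # Illustrative consumption (a sketch; `api` is a hypothetical InferenceAPI
    # instance). Because this method is a generator, the caller must enter the
    # autocast context itself, as noted above:
    #
    #     with api.autocast_context():
    #         for chunk in api.propagate_in_video(request):
    #             ...  # stream chunk.results (per-object RLE masks) to the client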

    def cancel_propagate_in_video(
        self, request: CancelPropagateInVideoRequest
    ) -> CancelPorpagateResponse:
        """
        Cancel an ongoing propagation; the flag is checked between yielded frames
        in `propagate_in_video`, so cancellation takes effect cooperatively.
        """
        session = self.__get_session(request.session_id)
        session["canceled"] = True
        return CancelPorpagateResponse(success=True)

    def __get_rle_mask_list(
        self, object_ids: List[int], masks: np.ndarray
    ) -> List[PropagateDataValue]:
        """
        Return a list of data values, i.e. list of object/mask combos.
        """
        return [
            self.__get_mask_for_object(object_id=object_id, mask=mask)
            for object_id, mask in zip(object_ids, masks)
        ]

    def __get_mask_for_object(
        self, object_id: int, mask: np.ndarray
    ) -> PropagateDataValue:
        """
        Create a data value for an object/mask combo.
        """
        mask_rle = encode_masks(np.array(mask, dtype=np.uint8, order="F"))
        mask_rle["counts"] = mask_rle["counts"].decode()
        return PropagateDataValue(
            object_id=object_id,
            mask=Mask(
                size=mask_rle["size"],
                counts=mask_rle["counts"],
            ),
        )
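
    # Client-side decoding sketch (illustrative; assumes pycocotools on the
    # receiving end, with `value` a hypothetical PropagateDataValue from a
    # response):
    #
    #     rle = {"size": value.mask.size, "counts": value.mask.counts.encode()}
    #     binary_mask = decode_masks(rle)  # uint8 array of shape [H, W]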

    def __get_session(self, session_id: str):
        session = self.session_states.get(session_id, None)
        if session is None:
            raise RuntimeError(
                f"Cannot find session {session_id}; it might have expired"
            )
        return session

    def __get_session_stats(self):
        """Get a statistics string for live sessions and their GPU usage."""
        # print both the session ids and their video frame numbers
        live_session_strs = [
            f"'{session_id}' ({session['state']['num_frames']} frames, "
            f"{len(session['state']['obj_ids'])} objects)"
            for session_id, session in self.session_states.items()
        ]
        session_stats_str = (
            f"live sessions: [{', '.join(live_session_strs)}], GPU memory: "
            f"{torch.cuda.memory_allocated() // 1024**2} MiB used and "
            f"{torch.cuda.memory_reserved() // 1024**2} MiB reserved"
            f" (max over time: {torch.cuda.max_memory_allocated() // 1024**2} MiB used "
            f"and {torch.cuda.max_memory_reserved() // 1024**2} MiB reserved)"
        )
        return session_stats_str

    def __clear_session_state(self, session_id: str) -> bool:
        session = self.session_states.pop(session_id, None)
        if session is None:
            logger.warning(
                f"cannot close session {session_id} as it does not exist (it might have expired); "
                f"{self.__get_session_stats()}"
            )
            return False
        else:
            logger.info(f"removed session {session_id}; {self.__get_session_stats()}")
            return True
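

# Minimal end-to-end usage sketch (illustrative only, not executed). The
# request constructors below assume keyword fields matching the attributes
# read in this module (`path`, `session_id`, `frame_index`, `object_id`,
# `points`, `labels`, `clear_old_points`, `start_frame_index`); the video
# path is hypothetical, and the real dataclasses in inference.data_types may
# require additional fields.
#
#     api = InferenceAPI()
#     session = api.start_session(StartSessionRequest(path="/data/example.mp4"))
#
#     api.add_points(
#         AddPointsRequest(
#             session_id=session.session_id,
#             frame_index=0,
#             object_id=1,
#             points=[[210.0, 350.0]],
#             labels=[1],  # 1 = foreground click, 0 = background click
#             clear_old_points=True,
#         )
#     )
#
#     with api.autocast_context():
#         for chunk in api.propagate_in_video(
#             PropagateInVideoRequest(session_id=session.session_id, start_frame_index=0)
#         ):
#             ...  # stream each frame's per-object RLE masks
#
#     api.close_session(CloseSessionRequest(session_id=session.session_id))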