import os
import cv2
import torch
import argparse
import numpy as np
import supervision as sv
from PIL import Image
import gc
import sys

from eval.grounded_sam.florence2.modeling_florence2 import Florence2ForConditionalGeneration
from eval.grounded_sam.florence2.processing_florence2 import Florence2Processor
from eval.grounded_sam.sam2.build_sam import build_sam2, build_sam2_hf
from eval.grounded_sam.sam2.sam2_image_predictor import SAM2ImagePredictor


class FlorenceSAM:
    # official usage: https://huggingface.co/microsoft/Florence-2-large/blob/main/sample_inference.ipynb
    TASK_PROMPT = {
        "original": "<GIVEN>",
        "caption": "<CAPTION>",
        "detailed_caption": "<DETAILED_CAPTION>",
        "more_detailed_caption": "<MORE_DETAILED_CAPTION>",
        "object_detection": "<OD>",
        "dense_region_caption": "<DENSE_REGION_CAPTION>",
        "region_proposal": "<REGION_PROPOSAL>",
        "phrase_grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
        "referring_expression_segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
        "region_to_segmentation": "<REGION_TO_SEGMENTATION>",
        "open_vocabulary_detection": "<OPEN_VOCABULARY_DETECTION>",
        "region_to_category": "<REGION_TO_CATEGORY>",
        "region_to_description": "<REGION_TO_DESCRIPTION>",
        "ocr": "<OCR>",
        "ocr_with_region": "<OCR_WITH_REGION>",
    }
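    # Note: "<GIVEN>" appears to be a sentinel used only by this wrapper, not a Florence-2
    # task token; it tells caption_phrase_grounding_and_segmentation() below that the caller
    # supplies the caption via `original_caption` instead of running a captioning pass.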

    def __init__(self, device):
        """
        Init Florence-2 and SAM 2 models.
        """
        print(f"[{self}] init on device {device}")
        self.device = torch.device(device)
        # with torch.autocast(device_type="cuda", dtype=torch.float32).__enter__()
        # self.torch_dtype = torch.float32
        # self.torch_dtype = torch.float16
        self.torch_dtype = torch.bfloat16
        try:
            if torch.cuda.get_device_properties(0).major >= 8:
                # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
                # self.torch_dtype = torch.bfloat16
            # else:
            #     self.torch_dtype = torch.float16
        except Exception:
            # no CUDA device available (or the capability query failed); keep the bfloat16 default
            self.torch_dtype = torch.bfloat16
        FLORENCE2_MODEL_ID = os.getenv('FLORENCE2_MODEL_PATH', "microsoft/Florence-2-large")
        SAM2_CHECKPOINT = os.getenv('SAM2_MODEL_PATH', "facebook/sam2-hiera-large")
        self.florence2_model = Florence2ForConditionalGeneration.from_pretrained(
            FLORENCE2_MODEL_ID,
            torch_dtype=self.torch_dtype,
        ).eval().to(self.device)
        self.florence2_processor = Florence2Processor.from_pretrained(
            FLORENCE2_MODEL_ID,
        )
        sam2_model = build_sam2_hf(SAM2_CHECKPOINT, device=self.device)
        self.sam2_predictor = SAM2ImagePredictor(sam2_model)

    def __str__(self):
        return "FlorenceSAM"

    def run_florence2(self, task_prompt, text_input, image):
        model = self.florence2_model
        processor = self.florence2_processor
        device = self.device
        assert model is not None, "Florence-2 model must be initialized before calling run_florence2"
        assert processor is not None, "Florence-2 processor must be initialized before calling run_florence2"
        with torch.autocast(device_type="cuda", dtype=torch.float32):
            if text_input is None:
                prompt = task_prompt
            else:
                prompt = task_prompt + text_input
            inputs = processor(
                text=prompt, images=image,
                max_length=1024,
                truncation=True,
                return_tensors="pt",
            ).to(device, self.torch_dtype)
            # inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, self.torch_dtype)
            generated_ids = model.generate(
                input_ids=inputs["input_ids"].to(device),
                pixel_values=inputs["pixel_values"].to(device),
                # max_new_tokens=1024,
                max_new_tokens=768,
                early_stopping=False,
                do_sample=False,
                num_beams=3,
            )
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            parsed_answer = processor.post_process_generation(
                generated_text,
                task=task_prompt,
                image_size=(image.width, image.height)
            )
        return parsed_answer
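
    # Note on the parsed output (inferred from the callers below, not an exhaustive spec):
    # grounding tasks return a dict keyed by the task prompt, e.g.
    #   {"<CAPTION_TO_PHRASE_GROUNDING>": {"bboxes": [[x1, y1, x2, y2], ...], "labels": [...]}}
    # while "<OPEN_VOCABULARY_DETECTION>" uses "bboxes_labels" in place of "labels".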

    def caption(self, image, caption_task_prompt='<CAPTION>'):
        assert caption_task_prompt in ["<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>"]
        caption_results = self.run_florence2(caption_task_prompt, None, image)
        text_input = caption_results[caption_task_prompt]
        caption = text_input
        return caption

    def segmentation(self, image, input_boxes, seg_model="sam"):
        if seg_model == "sam":
            with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
                sam2_predictor = self.sam2_predictor
                sam2_predictor.set_image(np.array(image))
                masks, scores, logits = sam2_predictor.predict(
                    point_coords=None,
                    point_labels=None,
                    box=input_boxes,
                    multimask_output=False,
                )
            if masks.ndim == 4:
                masks = masks.squeeze(1)
            if scores.ndim == 2:
                scores = scores.squeeze(1)
        else:
            raise NotImplementedError()
        return masks, scores
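
    # Shape note (based on the squeezes above): with multimask_output=False and a batch of
    # boxes, SAM 2 is expected to return masks of shape (N, 1, H, W) and scores of shape
    # (N, 1), which are squeezed to (N, H, W) and (N,) before building sv.Detections.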

    def post_process_results(self, image, caption, labels, detections, output_dir=None):
        result_dict = {
            "caption": caption,
            "instance_images": [],
            "instance_labels": [],
            "instance_bboxes": [],
            "instance_mask_scores": [],
        }
        if detections is None:
            return detections, result_dict
        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
        cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        box_annotator = sv.BoxAnnotator()
        annotated_frame = box_annotator.annotate(scene=cv_image.copy(), detections=detections)
        label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)
        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
        if output_dir is not None:
            cv2.imwrite(os.path.join(output_dir, "detections.jpg"), annotated_frame)
        mask_annotator = sv.MaskAnnotator()
        annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
        if output_dir is not None:
            cv2.imwrite(os.path.join(output_dir, "masks.jpg"), annotated_frame)
        for detection in detections:
            xyxy, mask, confidence, class_id, tracker_id, data = detection
            label = labels[class_id]
            cropped_img = sv.crop_image(image=cv_image, xyxy=xyxy)
            if output_dir is not None:
                cv2.imwrite(os.path.join(output_dir, f"cropped_image_{label}.jpg"), cropped_img)
            if mask is None:
                result_dict["instance_mask_scores"].append(0)
                result_dict["instance_images"].append(cropped_img)
            else:
                mask = np.repeat(mask[..., np.newaxis], 3, axis=-1)
                masked_img = np.where(mask, cv_image, 255)
                cropped_masked_img = sv.crop_image(image=masked_img, xyxy=xyxy)
                result_dict["instance_mask_scores"].append(confidence.item())
                result_dict["instance_images"].append(cropped_masked_img)
                if output_dir is not None:
                    cv2.imwrite(os.path.join(output_dir, f"masked_image_{label}.jpg"), cropped_masked_img)
            result_dict["instance_labels"].append(label)
            result_dict["instance_bboxes"].append(xyxy)
        torch.cuda.empty_cache()
        gc.collect()
        return detections, result_dict

    def caption_phrase_grounding_and_segmentation(
        self,
        image,
        seg_model="sam",
        caption_task_prompt='<CAPTION>',
        original_caption=None,
        output_dir=None
    ):
        assert caption_task_prompt in ["<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>", "<GIVEN>", "<OPEN_VOCABULARY_DETECTION>"]
        assert seg_model in ["sam", "florence2"]
        # image caption
        if caption_task_prompt in ["<GIVEN>", "<OPEN_VOCABULARY_DETECTION>"]:
            assert original_caption is not None
            caption = original_caption
        else:
            caption_results = self.run_florence2(caption_task_prompt, None, image)
            text_input = caption_results[caption_task_prompt]
            caption = text_input
        # phrase grounding
        grounding_results = self.run_florence2('<CAPTION_TO_PHRASE_GROUNDING>', caption, image)['<CAPTION_TO_PHRASE_GROUNDING>']
        input_boxes = np.array(grounding_results["bboxes"])
        class_names = grounding_results["labels"]
        class_ids = np.array(list(range(len(class_names))))
        # segmentation
        masks, scores = self.segmentation(image, input_boxes, seg_model)
        labels = [f"{class_name}" for class_name in class_names]
        detections = sv.Detections(
            xyxy=input_boxes,
            mask=masks.astype(bool),
            class_id=class_ids,
            confidence=scores,
        )
        return self.post_process_results(image, caption, labels, detections, output_dir)

    def od_grounding_and_segmentation(
        self,
        image,
        text_input,
        seg_model="sam",
        output_dir=None
    ):
        assert seg_model in ["sam", "florence2"]
        # od grounding
        grounding_results = self.run_florence2('<OPEN_VOCABULARY_DETECTION>', text_input, image)['<OPEN_VOCABULARY_DETECTION>']
        if len(grounding_results["bboxes"]) == 0:
            detections = None
            labels = []
        else:
            input_boxes = np.array(grounding_results["bboxes"])
            class_names = grounding_results["bboxes_labels"]
            class_ids = np.array(list(range(len(class_names))))
            # segmentation
            masks, scores = self.segmentation(image, input_boxes, seg_model)
            labels = [f"{class_name}" for class_name in class_names]
            detections = sv.Detections(
                xyxy=input_boxes,
                mask=masks.astype(bool),
                class_id=class_ids,
                confidence=scores,
            )
        return self.post_process_results(image, text_input, labels, detections, output_dir)

    def od_grounding(
        self,
        image,
        text_input,
        output_dir=None
    ):
        # od grounding
        grounding_results = self.run_florence2('<OPEN_VOCABULARY_DETECTION>', text_input, image)['<OPEN_VOCABULARY_DETECTION>']
        if len(grounding_results["bboxes"]) == 0:
            detections = None
            labels = []
        else:
            input_boxes = np.array(grounding_results["bboxes"])
            class_names = grounding_results["bboxes_labels"]
            class_ids = np.array(list(range(len(class_names))))
            labels = [f"{class_name}" for class_name in class_names]
            detections = sv.Detections(
                xyxy=input_boxes,
                class_id=class_ids,
            )
        return self.post_process_results(image, text_input, labels, detections, output_dir)

    def phrase_grounding_and_segmentation(
        self,
        image,
        text_input,
        seg_model="sam",
        output_dir=None
    ):
        assert seg_model in ["sam", "florence2"]
        # phrase grounding
        grounding_results = self.run_florence2('<CAPTION_TO_PHRASE_GROUNDING>', text_input, image)['<CAPTION_TO_PHRASE_GROUNDING>']
        input_boxes = np.array(grounding_results["bboxes"])
        class_names = grounding_results["labels"]
        # print(f"[phrase_grounding_and_segmentation] input_label={text_input}, output_label={class_names}")
        class_ids = np.array(list(range(len(class_names))))
        # segmentation
        masks, scores = self.segmentation(image, input_boxes, seg_model)
        labels = [f"{class_name}" for class_name in class_names]
        detections = sv.Detections(
            xyxy=input_boxes,
            mask=masks.astype(bool),
            class_id=class_ids,
            confidence=scores,
        )
        return self.post_process_results(image, text_input, labels, detections, output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
    parser.add_argument("--image_path", type=str, default="./notebooks/images/cars.jpg", required=True, help="path to image file")
    parser.add_argument("--caption_type", type=str, default="caption", required=False, help="granularity of caption")
    args = parser.parse_args()
    # IMAGE_PATH = args.image_path
    PIPELINE = "caption_to_phrase_grounding"
    CAPTION_TYPE = args.caption_type
    assert CAPTION_TYPE in ["caption", "detailed_caption", "more_detailed_caption", "original"]
    print(f"Running pipeline: {PIPELINE} now.")
    pipeline = FlorenceSAM("cuda:0")

    from glob import glob
    from tqdm import tqdm
    for image_path in tqdm(glob("/mnt/bn/lq-prompt-alignment/personal/chenbowen/code/IPVerse/prompt_alignment/Grounded-SAM-2/notebooks/images/*") * 3):
    # for image_path in tqdm(glob("/mnt/bn/lq-prompt-alignment/personal/chenbowen/code/IPVerse/prompt_alignment/Grounded-SAM-2/outputs/gcg_pipeline/00001.tar_debug/*.png")):
        print(pipeline.TASK_PROMPT, CAPTION_TYPE)
        image = Image.open(image_path).convert("RGB")
        pipeline.caption_phrase_grounding_and_segmentation(
            image=image,
            seg_model="sam",
            caption_task_prompt=pipeline.TASK_PROMPT[CAPTION_TYPE],
            output_dir=f"./outputs/{os.path.basename(image_path)}"
        )
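
    # Alternative entry point (illustrative sketch, not exercised by this demo): the same
    # pipeline object can also run open-vocabulary detection + segmentation from a text
    # prompt; the class name "car" and the output path below are placeholders, e.g.
    #
    #   detections, result_dict = pipeline.od_grounding_and_segmentation(
    #       image=image,
    #       text_input="car",
    #       seg_model="sam",
    #       output_dir="./outputs/od_demo",
    #   )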