Spaces: Runtime error
Commit 0cc37e5 (parent: f47bc1e): Update app.py

app.py CHANGED
@@ -8,9 +8,6 @@ import gradio as gr
 
 from loguru import logger
 
-# os.system("pip install diffuser==0.6.0")
-# os.system("pip install transformers==4.29.1")
-
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 
 if os.environ.get('IS_MY_DEBUG') is None:
@@ -69,7 +66,10 @@ ckpt_repo_id = "ShilongLiu/GroundingDINO"
 ckpt_filenmae = "groundingdino_swint_ogc.pth"
 sam_checkpoint = './sam_vit_h_4b8939.pth'
 output_dir = "outputs"
-
+if os.environ.get('IS_MY_DEBUG') is None:
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+else:
+    device = 'cpu'
 
 os.makedirs(output_dir, exist_ok=True)
 groundingdino_model = None
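Note: this new top-level block hard-codes the device at import time with the same IS_MY_DEBUG switch used throughout the file: GPU when the variable is unset and CUDA is available, CPU otherwise. A small illustrative check (not part of app.py; the variable name below is hypothetical):

import os
import torch

# Illustrative only: reproduce the Space's device choice and flag CPU-only runs,
# since the fp16 Stable Diffusion weights loaded later expect a CUDA device.
picked_device = 'cuda' if (os.environ.get('IS_MY_DEBUG') is None and torch.cuda.is_available()) else 'cpu'
if picked_device == 'cpu':
    print('CPU selected: the torch.float16 inpainting pipeline will be slow or unsupported.')
print(f'selected device: {picked_device}')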
@@ -77,8 +77,9 @@ sam_device = None
 sam_model = None
 sam_predictor = None
 sam_mask_generator = None
-
+sd_model = None
 lama_cleaner_model= None
+lama_cleaner_model_device = device
 ram_model = None
 
 def get_sam_vit_h_4b8939():
@@ -165,16 +166,6 @@ def load_image(image_path):
     image, _ = transform(image_pil, None) # 3, h, w
     return image_pil, image
 
-def load_model(model_config_path, model_checkpoint_path, device):
-    args = SLConfig.fromfile(model_config_path)
-    args.device = device
-    model = build_model(args)
-    checkpoint = torch.load(model_checkpoint_path, map_location=device) #"cpu")
-    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
-    print(load_res)
-    _ = model.eval()
-    return model
-
 def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
     caption = caption.lower()
     caption = caption.strip()
@@ -258,18 +249,21 @@ def mix_masks(imgs):
     return Image.fromarray(np.uint8(255*re_img))
 
 def set_device():
-    device
-
+    global device
+    if os.environ.get('IS_MY_DEBUG') is None:
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    else:
+        device = 'cpu'
 
 def load_groundingdino_model():
     # initialize groundingdino model
     global groundingdino_model
     logger.info(f"initialize groundingdino model...")
-    groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
+    groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device='cpu')
 
 def load_sam_model():
     # initialize SAM
-    global sam_model, sam_predictor, sam_mask_generator, sam_device
+    global sam_model, sam_predictor, sam_mask_generator, sam_device, device
     logger.info(f"initialize SAM model...")
     sam_device = device
     sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
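load_model_hf is defined elsewhere in app.py and is not shown in this diff; the change above pins GroundingDINO to CPU via a new device argument while SAM keeps following the global device. As a rough orientation, a loader of this kind usually looks like the sketch below, which mirrors the load_model() removed earlier in this commit; the function name and import paths are assumptions, not code from this repo:

import torch
from huggingface_hub import hf_hub_download
# Standard GroundingDINO utilities (assumed to match the imports at the top of app.py).
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict

def load_model_hf_sketch(model_config_path, repo_id, filename, device='cpu'):
    # Build the model from its config, fetch the checkpoint from the HF Hub,
    # and load the weights onto the requested device.
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    model = build_model(args)
    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location=device)
    model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    model.eval()
    return model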
@@ -278,26 +272,26 @@ def load_sam_model():
 
 def load_sd_model():
     # initialize stable-diffusion-inpainting
-    global
+    global sd_model, device
     logger.info(f"initialize stable-diffusion-inpainting...")
-
+    sd_model = None
     if os.environ.get('IS_MY_DEBUG') is None:
-
+        sd_model = StableDiffusionInpaintPipeline.from_pretrained(
             "runwayml/stable-diffusion-inpainting",
             revision="fp16",
             # "stabilityai/stable-diffusion-2-inpainting",
             torch_dtype=torch.float16,
         )
-
+        sd_model = sd_model.to(device)
 
 def load_lama_cleaner_model():
     # initialize lama_cleaner
-    global lama_cleaner_model
+    global lama_cleaner_model, device
    logger.info(f"initialize lama_cleaner...")
 
     lama_cleaner_model = ModelManager(
         name='lama',
-        device=
+        device=lama_cleaner_model_device,
     )
 
 def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
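For reference, a minimal standalone version of the pipeline configured above; the repo id, revision and dtype come from the diff, while the file paths and prompt are placeholders:

import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

# Same pipeline the Space builds in load_sd_model(), used once on a 512x512 pair.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

init_image = Image.open("input.png").convert("RGB").resize((512, 512))  # placeholder path
mask_image = Image.open("mask.png").convert("RGB").resize((512, 512))   # placeholder path
result = pipe(prompt="a red sofa", image=init_image, mask_image=mask_image).images[0]
result.save("inpainted.png")

The lama_cleaner ModelManager, by contrast, now receives lama_cleaner_model_device, which is captured from the global device at import time rather than resolved when the model is loaded.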
@@ -517,6 +511,7 @@ mask_source_segment = "type what to detect below"
 
 def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
                       iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, cleaner_size_limit=1080):
+
     if (task_type == 'relate anything'):
         output_images = relate_anything(input_image['image'], num_relation)
         return output_images, gr.Gallery.update(label='relate images')
@@ -566,7 +561,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
         groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
     )
     if boxes_filt.size(0) == 0:
-        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]
+        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
         return [], gr.Gallery.update(label='No objects detected, please try others.ππππ')
     boxes_filt_ori = copy.deepcopy(boxes_filt)
 
@@ -640,7 +635,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
         # inpainting pipeline
         image_source_for_inpaint = image_pil.resize((512, 512))
         image_mask_for_inpaint = mask_pil.resize((512, 512))
-        image_inpainting =
+        image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
     else:
         # remove from mask
         logger.info(f'run_anything_task_[{file_temp}]_{task_type}_5_')
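The two inputs are resized to 512x512 before the call, so the pipeline also returns a 512x512 image, and sd_model stays None when IS_MY_DEBUG is set (load_sd_model only builds the pipeline otherwise). A hypothetical, slightly hardened version of the same call, reusing the names from the diff:

# Hypothetical hardening (not in this commit): guard the debug case and restore
# the source resolution after inpainting.
if sd_model is None:
    raise RuntimeError("stable-diffusion-inpainting is not loaded (IS_MY_DEBUG is set)")
image_inpainting = sd_model(prompt=inpaint_prompt,
                            image=image_source_for_inpaint,
                            mask_image=image_mask_for_inpaint).images[0]
image_inpainting = image_inpainting.resize(image_pil.size)  # back to the original size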
@@ -707,6 +702,8 @@ def change_radio_display(task_type, mask_source_radio):
 
 def get_model_device(module):
     try:
+        if module is None:
+            return 'None'
         if isinstance(module, torch.nn.DataParallel):
             module = module.module
         for submodule in module.children():
@@ -714,8 +711,9 @@ def get_model_device(module):
             parameters = submodule._parameters
             if "weight" in parameters:
                 return parameters["weight"].device
+        return 'UnKnown'
     except Exception as e:
-        return '
+        return 'Error'
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
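Together with the previous hunk, get_model_device now reports 'None' for models that were never loaded, the device of the first weighted submodule otherwise, 'UnKnown' when no weights are found, and 'Error' on any exception. A consolidated sketch of the resulting helper (reconstructed from the visible hunk lines; one unchanged line between the two hunks is not shown in the diff and is omitted here):

import torch

def get_model_device_sketch(module):
    # Return a printable device/status string instead of raising.
    try:
        if module is None:
            return 'None'
        if isinstance(module, torch.nn.DataParallel):
            module = module.module
        for submodule in module.children():
            parameters = submodule._parameters
            if "weight" in parameters:
                return parameters["weight"].device
        return 'UnKnown'
    except Exception:
        return 'Error'

print(get_model_device_sketch(None))                                        # None
print(get_model_device_sketch(torch.nn.Sequential(torch.nn.Linear(2, 2))))  # cpu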
@@ -732,10 +730,12 @@ if __name__ == "__main__":
     load_lama_cleaner_model()
     load_ram_model()
 
-    os.
+    if os.environ.get('IS_MY_DEBUG') is None:
+        os.system("pip list")
+
     print(f'groundingdino_model__{get_model_device(groundingdino_model)}')
     print(f'sam_model__{get_model_device(sam_model)}')
-    print(f'sd_model__{get_model_device(
+    print(f'sd_model__{get_model_device(sd_model)}')
     print(f'lama_cleaner_model__{get_model_device(lama_cleaner_model)}')
     print(f'ram_model__{get_model_device(ram_model)}')
 
@@ -790,3 +790,4 @@ if __name__ == "__main__":
 
     computer_info()
     block.launch(server_name='0.0.0.0', debug=args.debug, share=args.share)
+