Spaces:
Runtime error
Runtime error
| # 1. Import dependencies | |
| import gradio as gr | |
| import torch | |
| # import spaces # for GPU usage | |
| from PIL import Image, ImageDraw, ImageFont | |
| from transformers import AutoImageProcessor, AutoModelForObjectDetection | |
# 2. Setup preprocessing and model functions
# Hub repo id of the fine-tuned RT-DETRv2 trashify detector.
model_save_path = "Sairii/rt_detrv2_finetuned_trashify_box_detector_v1"

# Preprocessor bundled with the checkpoint; its resize target is overridden so
# every input image is resized to 640x640 before it reaches the model.
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
image_processor.size = dict(height=640, width=640)

# Detection model weights from the same checkpoint.
model = AutoModelForObjectDetection.from_pretrained(model_save_path)

# Run on a CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
# (A Space can get GPU access through ZeroGPU:
#  https://huggingface.co/docs/hub/spaces-zerogpu)
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model.to(device)

# Class-id -> class-name lookup, read from the model's config.
id2label = model.config.id2label

# Box outline colour per class name; every "not_*" class is drawn in red.
color_dict = {
    "bin": "green",
    "trash": "blue",
    "hand": "purple",
    "trash_arm": "yellow",
    **dict.fromkeys(("not_trash", "not_bin", "not_hand"), "red"),
}
# 3. Create a function to predict on image
def _results_to_cpu(results):
    """Return *results* with every tensor value moved to the CPU.

    Non-tensor values are passed through unchanged.
    """
    return {
        key: value.cpu() if isinstance(value, torch.Tensor) else value
        for key, value in results.items()
    }


def _draw_predictions(image, results):
    """Draw each detection's box and "label (score)" caption onto *image* in place.

    Args:
        image: PIL image to draw on (modified in place).
        results: one post-processed detection dict with "boxes", "scores"
            and "labels" entries, already on the CPU.

    Returns:
        List of detected class names, one entry per drawn box (duplicates kept).
    """
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    detected_labels = []
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        x, y, x2, y2 = box.tolist()
        label_name = id2label[label.item()]
        detected_labels.append(label_name)
        # Classes missing from the colour table fall back to red instead of
        # raising KeyError halfway through drawing.
        box_color = color_dict.get(label_name, "red")
        draw.rectangle(xy=(x, y, x2, y2), outline=box_color, width=3)
        draw.text(
            xy=(x, y),
            text=f"{label_name} ({round(score.item(), 3)})",
            fill="white",
            font=font,
        )
    return detected_labels


def _build_message(detected_labels, conf_threshold):
    """Return the status message shown next to the annotated image.

    Args:
        detected_labels: class names detected in the image (may repeat).
        conf_threshold: confidence threshold that was used, echoed to the user.
    """
    target_items = {"bin", "trash", "hand"}
    detected_items = set(detected_labels)

    # None of the target classes were found.
    if not detected_items & target_items:
        message = (
            f"No trash, bin or hand detected at confidence threshold of {conf_threshold}. "
            "Try another image or lowering the confidence threshold."
        )
        print(message)
        return message

    # Some, but not all, target classes were found: say what is missing.
    missing_items = target_items - detected_items
    if missing_items:
        message = (
            f"Detected the following items: {sorted(detected_labels)}. "
            f"Missing the following items: {sorted(missing_items)}. "
            "In order to get +1 points, all target items must be detected."
        )
        print(message)
        return message

    # Every target class was found.
    return (
        f"+1: Found the following items: {sorted(detected_items)}, "
        "thank you for cleaning up your local area"
    )


def predict_on_image(image, conf_threshold):
    """Run the detector on *image*, draw its predictions, and report the result.

    Args:
        image: PIL image from the Gradio input (None when nothing was uploaded).
        conf_threshold: minimum score a detection needs to be kept.

    Returns:
        Tuple of (annotated image, status message string). The input image is
        drawn on in place and returned.
    """
    if image is None:
        return None, "Please upload an image to run the detector on."

    model.eval()
    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        model_outputs = model(**inputs.to(device))

    # PIL's Image.size is (width, height), but post-processing expects
    # (height, width), one row per image in the batch.
    target_sizes = torch.tensor([[image.size[1], image.size[0]]])

    # The processor's first parameter is named `outputs` (not `model_outputs`),
    # so pass the model output positionally.
    results = image_processor.post_process_object_detection(
        model_outputs,
        threshold=conf_threshold,
        target_sizes=target_sizes,
    )[0]
    results = _results_to_cpu(results)

    detected_labels = _draw_predictions(image, results)
    return image, _build_message(detected_labels, conf_threshold)
### 6. Setup the demo application to take in image, make a prediction with our model, return the image with drawn predictions ###
# Markdown description shown at the top of the demo page.
# NOTE(review): the last sentence calls this "v1" yet mentions "the first three
# versions" and links the trashify_demo_v4 README -- confirm which is intended.
description = """
Help clean up your local area! Upload an image and get +1 if all of the following items are detected: trash, bin, hand.
Model is a fine-tuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the [Trashify dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
See the full data loading and training code on [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).
This version is v1 because the first three versions were using a different model and did not perform as well, see the [README](https://huggingface.co/spaces/mrdbourke/trashify_demo_v4/blob/main/README.md) for more.
"""

# Each example is [image path, confidence threshold], matching the inputs below.
examples = [
    [f"trashify_examples/trashify_example_{i}.jpeg", 0.3] for i in range(1, 4)
]

# Create the Gradio interface
demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Target Input Image"),
        gr.Slider(0, 1, value=0.3, label="Confidence Threshold set higher for more confident boxes"),
    ],
    outputs=[
        gr.Image(type="pil", label="Target Image Output"),
        gr.Text(label="Text output"),
    ],
    description=description,
    title="Trashify Demo V1",
    examples=examples,
    # Example outputs are pre-computed by running predict_on_image on them
    # (see Gradio docs), so the example files must exist and the prediction
    # function must not raise on any of them.
    cache_examples=True,
)

# Launch the demo
demo.launch(debug=True)