# 1. Import dependencies
import gradio as gr
import torch
# import spaces # for GPU usage
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoImageProcessor, AutoModelForObjectDetection
# 2. Setup preprocessing and model functions
model_save_path = "Sairii/rt_detrv2_finetuned_trashify_box_detector_v1"
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
image_processor.size = {"height": 640,
                        "width": 640}
model = AutoModelForObjectDetection.from_pretrained(model_save_path)
# Setup the target device (use GPU if it's accessible)
# Note if you want to use a GPU in your Space, you can use a ZeroGPU: https://huggingface.co/docs/hub/spaces-zerogpu
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
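# If this app were running on a ZeroGPU Space, you would uncomment `import spaces`
# above and decorate the prediction function below with @spaces.GPU so a GPU is
# allocated for the duration of each call.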
# Get the id2label dictionary from the model
id2label = model.config.id2label
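# e.g. {0: "bin", 1: "trash", ...} (the exact index to name mapping comes from the model config)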
# Setup a color dictionary
color_dict = {
"bin": "green",
"trash": "blue",
"hand": "purple",
"trash_arm": "yellow",
"not_trash": "red",
"not_bin": "red",
"not_hand": "red"
}
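# Note: the "not_*" classes all share red so they're easy to tell apart from the target items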
# 3. Create a function to predict on image
def predict_on_image(image, conf_threshold):
model.eval()
# Make a prediction on target image
with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
model_outputs = model(**inputs.to(device))
# Get original size of image
# PIL.Image.size = width, height
# But post_process_object_detection requires height, width
target_sizes = torch.tensor([[image.size[1], image.size[0]]]) # -> [batch_size, height, width]
# Post process the raw outputs from the model
    results = image_processor.post_process_object_detection(
        outputs=model_outputs,
        threshold=conf_threshold,
        target_sizes=target_sizes
    )[0]
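    # results is a dictionary with "scores", "labels" and "boxes" keys,
    # where each box is in (xmin, ymin, xmax, ymax) format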
    # Move the results to the CPU if they aren't already there (they'll be on the GPU if one was available)
    for key, value in results.items():
        if isinstance(value, torch.Tensor):
            results[key] = value.cpu() # scalars/non-tensors don't have .cpu(), so check the type first
    ### 4. Draw the predictions on the target image ###
draw = ImageDraw.Draw(image)
# Get a font to write on our image
font = ImageFont.load_default()
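    # (ImageFont.truetype() could be used instead if a .ttf file is available and larger text is wanted)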
    # Create an empty list to store the detected class names
    detected_class_names_text_labels = []
    # Iterate through the predictions of the model and draw them on the target image
for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
# Create coordinates
x, y, x2, y2 = tuple(box.tolist())
# Get label_name
label_name = id2label[label.item()]
targ_color = color_dict[label_name]
detected_class_names_text_labels.append(label_name)
# Draw the rectangle
draw.rectangle(xy=(x, y, x2, y2),
outline=targ_color,
width=3)
# Create a text string to display
text_string_to_show = f"{label_name} ({round(score.item(), 3)})"
# Draw the text on the image
draw.text(xy=(x, y),
text=text_string_to_show,
fill="white",
font=font)
    # Remove the draw object once drawing is finished
    del draw
    ### 5. Create logic for outputting an information message ###
    # Set up the set of target items to detect
target_items = {"bin", "trash", "hand"}
detected_items = set(detected_class_names_text_labels)
    # If none of the target items are detected, return a notification
    if not detected_items & target_items:
        return_string = (
            f"No trash, bin or hand detected at confidence threshold of {conf_threshold}. "
            "Try another image or lowering the confidence threshold."
        )
        print(return_string)
        return image, return_string
    # If some target items are missing, report what's needed for the +1
    missing_items = target_items - detected_items
    if missing_items:
        return_string = (
            f"Detected the following items: {sorted(detected_class_names_text_labels)}. "
            f"Missing the following items: {sorted(missing_items)}. "
            "In order to get +1 points, all target items must be detected."
        )
print(return_string)
return image, return_string
# Final case, all items are detected
return_string = f"+1: Found the following items: {sorted(detected_items)}, thank you for cleaning up your local area"
return image, return_string
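# Quick local sanity check (commented out, assumes one of the example images below exists on disk):
# test_image = Image.open("trashify_examples/trashify_example_1.jpeg")
# annotated_image, message = predict_on_image(test_image, conf_threshold=0.3)
# print(message)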
### 6. Setup the demo application to take in an image, make a prediction with our model and return the image with drawn predictions ###
# Write a description for our demo application
description = """
Help clean up your local area! Upload an image and get +1 if all of the following items are detected: trash, bin, hand.
Model is a fine-tuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the [Trashify dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
See the full data loading and training code on [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).
This version is v1 because the first three versions were using a different model and did not perform as well, see the [README](https://huggingface.co/spaces/mrdbourke/trashify_demo_v4/blob/main/README.md) for more.
"""
# Create the Gradio interface
demo = gr.Interface(
fn = predict_on_image,
inputs = [
gr.Image(type="pil", label="Target Input Image"),
        gr.Slider(0, 1, value=0.3, label="Confidence Threshold (set higher for more confident boxes)")
],
outputs = [
gr.Image(type="pil", label = "Target Image Output"),
gr.Text(label="Text output")
],
description = description,
title = "Trashify Demo V1",
examples=[
["trashify_examples/trashify_example_1.jpeg", 0.3],
["trashify_examples/trashify_example_2.jpeg", 0.3],
["trashify_examples/trashify_example_3.jpeg", 0.3],
],
cache_examples = True
)
# Launch the demo
demo.launch(debug=True)
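# Note: when running locally, demo.launch(share=True) also creates a temporary public
# URL; on Hugging Face Spaces, launch() on its own is enough.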