Skier8402 commited on
Commit
ff7112c
·
verified ·
1 Parent(s): a66e3ec

Upload obj_detection_DETR.py

Browse files
Files changed (1) hide show
  1. obj_detection_DETR.py +123 -0
obj_detection_DETR.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
'''
DETR object detection with Captum interpretability.

Loads a pretrained DETR model, runs detection on a sample COCO image,
picks one high-confidence detection, and visualizes Grad-CAM and
Integrated Gradients attributions for it.
'''
10
+
11
import io

import numpy as np
import requests
import torch
import matplotlib.pyplot as plt
from PIL import Image
from captum.attr import IntegratedGradients
from torchvision.transforms.functional import resize
from transformers import DetrImageProcessor, DetrForObjectDetection
17
+
18
# ---------------- 1. Load DETR ----------------
MODEL_NAME = "facebook/detr-resnet-50"
# The processor handles resizing/normalization; the model goes straight
# into inference mode (eval() returns the module, so it can be chained).
feature_extractor = DetrImageProcessor.from_pretrained(MODEL_NAME)
model = DetrForObjectDetection.from_pretrained(MODEL_NAME).eval()
24
# ---------------- 2. Load an image ----------------
# Sample COCO validation image (two cats on a couch). Reading resp.content
# through BytesIO instead of resp.raw: .raw is the undecoded socket stream
# and breaks when the server applies Content-Encoding (e.g. gzip). The
# timeout keeps the script from hanging forever on a dead connection.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
resp = requests.get(url, timeout=30)
resp.raise_for_status()  # fail loudly on HTTP errors, not with a cryptic PIL error
img = Image.open(io.BytesIO(resp.content)).convert("RGB")
28
# ---------------- 3. Preprocess & forward ----------------
inputs = feature_extractor(images=img, return_tensors="pt")
pixel_values = inputs["pixel_values"]

# This first pass is only used to pick a detection, so no gradients are
# needed; no_grad avoids building an autograd graph here (the Grad-CAM
# section below re-runs the forward with gradients enabled).
with torch.no_grad():
    outputs = model(pixel_values)

# (h, w) of the original image, for rescaling boxes back to image coordinates
target_sizes = torch.tensor([img.size[::-1]])
# threshold=0.0 keeps every query so result indices line up with the raw logits
results = feature_extractor.post_process_object_detection(
    outputs, target_sizes=target_sizes, threshold=0.0
)[0]
37
# ---------------- 4. Pick detection ----------------
keep = results["scores"] > 0.7
boxes, labels, scores = results["boxes"][keep], results["labels"][keep], results["scores"][keep]
# Guard: indexing labels[0] below would raise a bare IndexError if nothing
# clears the confidence threshold — fail with a clear message instead.
if labels.numel() == 0:
    raise RuntimeError("No detections above the 0.7 confidence threshold")
chosen_idx = 0  # take the first confident detection
chosen_label = labels[chosen_idx].item()
chosen_name = model.config.id2label[chosen_label]
# post_process_object_detection always returns Tensors, so .item() suffices
score_val = scores[chosen_idx].item()
print(f"Chosen detection: {chosen_name}, score={score_val:.2f}")
46
# ---------------- 5. Grad-CAM ----------------
# Grad-CAM needs a convolutional feature map; use the LAST Conv2d in the
# backbone (deepest spatial features), falling back to a search over the
# whole model if the backbone attribute is absent in this transformers
# version.
def _last_conv2d(root):
    """Return the last-registered nn.Conv2d under `root`, or None."""
    if root is None:
        return None
    return next(
        (m for _, m in reversed(list(root.named_modules()))
         if isinstance(m, torch.nn.Conv2d)),
        None,
    )

conv_layer = _last_conv2d(getattr(model.model, "backbone", None))
if conv_layer is None:
    conv_layer = _last_conv2d(model)
if conv_layer is None:
    raise RuntimeError("No Conv2d layer found for Grad-CAM")
66
# Hooks capture the conv layer's activations on the forward pass and its
# output gradient on the backward pass; Grad-CAM combines the two.
# Handles are kept so the hooks can be removed later (e.g. in a notebook
# where this cell may be re-run and hooks would otherwise accumulate).
activations, gradients = {}, {}

def forward_hook(m, i, o):
    activations["value"] = o.detach()

def backward_hook(m, grad_in, grad_out):
    gradients["value"] = grad_out[0].detach()

hook_handles = [conv_layer.register_forward_hook(forward_hook)]
# register_full_backward_hook is preferred where available; keep the
# deprecated register_backward_hook only as a fallback for older torch.
if hasattr(conv_layer, "register_full_backward_hook"):
    hook_handles.append(conv_layer.register_full_backward_hook(backward_hook))
else:
    hook_handles.append(conv_layer.register_backward_hook(backward_hook))
76
# Re-run a forward pass with gradients enabled (the hooks above must be
# registered before this forward so they capture activations), then
# backprop on the chosen detection's class logit.
# Map the chosen kept detection back to its DETR query index.
keep_idxs = torch.nonzero(keep).squeeze()
if keep_idxs.dim() == 0:
    chosen_query_idx = int(keep_idxs.item())  # exactly one detection survived
else:
    chosen_query_idx = int(keep_idxs[chosen_idx].item())

# pixel_values must require grad so the backward pass reaches the conv layer
pixel_values_for_grad = pixel_values.clone().detach().requires_grad_(True)
outputs_for_grad = model(pixel_values_for_grad)

# Select the logit for that query & class and backpropagate.
score_for_grad = outputs_for_grad.logits[0, chosen_query_idx, chosen_label]
model.zero_grad()
score_for_grad.backward()

# Grad-CAM: weight each activation channel by its spatially averaged gradient.
acts = activations["value"].squeeze(0)   # (C, H, W)
grads = gradients["value"].squeeze(0)    # (C, H, W)
weights = grads.mean(dim=(1, 2))         # (C,)
cam = torch.relu((weights[:, None, None] * acts).sum(0))
# epsilon guards against division by zero when ReLU zeroes the whole map
cam = cam / (cam.max() + 1e-8)
# upsample the CAM to the original image resolution for overlaying;
# .cpu() keeps the .numpy() call safe if the model is ever moved to GPU
cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].cpu().numpy()
102
# ---------------- 6. Integrated Gradients ----------------
def forward_func(pixel_values):
    """Return the chosen query/class logit, one scalar per batch element."""
    out = model(pixel_values=pixel_values)
    return out.logits[:, chosen_query_idx, chosen_label]

ig = IntegratedGradients(forward_func)
# forward_func already returns one scalar per sample, so no `target` is needed
attributions, _ = ig.attribute(pixel_values, n_steps=25, return_convergence_delta=True)

# Average over the colour channels, then resize to the ORIGINAL image size:
# the raw attribution map is in model-input resolution, and overlaying it on
# `img` without resizing misaligns the heatmap (the Grad-CAM map above is
# resized the same way).
attr = attributions.squeeze(0).mean(0)  # (H, W) at model-input resolution
attr = resize(attr.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0]
attr = attr.detach().cpu().numpy()
attr = (attr - attr.min()) / (attr.max() - attr.min() + 1e-8)
116
# ---------------- 7. Visualize ----------------
# Three panels: original image, Grad-CAM overlay, Integrated Gradients overlay.
fig, axs = plt.subplots(1, 3, figsize=(16, 6))
panels = [
    (None, None, f"Original: {chosen_name}"),
    (cam_resized, "jet", "Grad-CAM heatmap"),
    (attr, "hot", "Integrated Gradients"),
]
for ax, (heatmap, cmap, title) in zip(axs, panels):
    ax.imshow(img)
    if heatmap is not None:
        ax.imshow(heatmap, cmap=cmap, alpha=0.5)
    ax.set_title(title)
    ax.axis("off")
plt.show()