Spaces:
Sleeping
Sleeping
Commit
·
ba377fa
1
Parent(s):
5daa6d0
improve visualization code
Browse files
utils/visualize_att_maps.py
CHANGED
|
@@ -26,7 +26,6 @@ class VisualizeAttentionMaps:
|
|
| 26 |
|
| 27 |
self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
|
| 28 |
self.num_parts = num_parts
|
| 29 |
-
self.req_colors = colors[:num_parts]
|
| 30 |
self.figs_size = (10, 10)
|
| 31 |
|
| 32 |
@torch.no_grad()
|
|
@@ -45,7 +44,11 @@ class VisualizeAttentionMaps:
|
|
| 45 |
map_argmax = torch.nn.functional.interpolate(maps.clone().detach(), size=self.save_resolution,
|
| 46 |
mode='bilinear',
|
| 47 |
align_corners=True).argmax(dim=1).cpu().numpy()
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
bg_label=self.bg_label, alpha=self.alpha)
|
| 51 |
return curr_map
|
|
|
|
| 26 |
|
| 27 |
self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
|
| 28 |
self.num_parts = num_parts
|
|
|
|
| 29 |
self.figs_size = (10, 10)
|
| 30 |
|
| 31 |
@torch.no_grad()
|
|
|
|
| 44 |
map_argmax = torch.nn.functional.interpolate(maps.clone().detach(), size=self.save_resolution,
|
| 45 |
mode='bilinear',
|
| 46 |
align_corners=True).argmax(dim=1).cpu().numpy()
|
| 47 |
+
# Select colors for parts which are present
|
| 48 |
+
parts_present = np.unique(map_argmax).tolist()
|
| 49 |
+
if self.bg_label in parts_present:
|
| 50 |
+
parts_present.remove(self.bg_label)
|
| 51 |
+
colors_present = [colors[i] for i in parts_present]
|
| 52 |
+
curr_map = skimage.color.label2rgb(label=map_argmax[0], image=ims[0], colors=colors_present,
|
| 53 |
bg_label=self.bg_label, alpha=self.alpha)
|
| 54 |
return curr_map
|