Spaces:
Sleeping
Sleeping
| import colorcet as cc | |
| import numpy as np | |
| import skimage | |
| import torch | |
| from utils.transform_utils import inverse_normalize_w_resize | |
# Categorical palette used to colour the attention maps. VisualizeAttentionMaps.show_maps
# indexes this sequence by part label (colors[i] for part i).
colors = cc.glasbey_category10
class VisualizeAttentionMaps:
    """Overlay part-attention maps on images as colour-coded label masks."""

    def __init__(self, snapshot_dir="", save_resolution=(256, 256), alpha=0.5, bg_label=0, num_parts=15):
        """
        Plot attention maps and optionally landmark centroids on images.
        :param snapshot_dir: Directory to save the visualization results
        :param save_resolution: Size of the images to save; also the size the attention maps are
            interpolated to before taking the per-pixel argmax
        :param alpha: The transparency of the attention maps
        :param bg_label: The background label index in the attention maps
        :param num_parts: The number of parts in the attention maps
        """
        self.save_resolution = save_resolution
        self.alpha = alpha
        self.bg_label = bg_label
        self.snapshot_dir = snapshot_dir
        # Project helper; used in show_maps to bring the input images to save_resolution.
        # NOTE(review): presumably also undoes the input normalisation (values back in [0, 1]) - confirm.
        self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
        self.num_parts = num_parts
        self.figs_size = (10, 10)

    def show_maps(self, ims, maps):
        """
        Overlay the attention maps of the first image in the batch on that image.

        Only the first element of the batch is rendered; the remaining elements are ignored.

        Parameters
        ----------
        ims: Tensor, [batch_size, 3, width_im, height_im]
            Input images on which to show the attention maps
        maps: Tensor, [batch_size, number of parts + 1, width_map, height_map]
            The attention maps to display

        Returns
        -------
        The overlay produced by ``skimage.color.label2rgb`` for the first image, where every
        non-background part is blended onto the image with transparency ``alpha``.
        """
        ims = self.resize_unnorm(ims)
        # Channels-last uint8 images, as expected for the overlay background.
        ims = (ims.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)

        # Upsample the maps to the output resolution, then pick the strongest part per pixel.
        map_argmax = torch.nn.functional.interpolate(
            maps.clone().detach(),
            size=self.save_resolution,
            mode='bilinear',
            align_corners=True,
        ).argmax(dim=1).cpu().numpy()

        first_map = map_argmax[0]

        # label2rgb assigns the given colours by the rank of the labels present in the array it
        # receives, so the colour list must be built from the parts present in *this* image.
        # Using the whole batch here would shift the colours whenever another batch element
        # contains a part that the first image lacks.
        parts_present = [part for part in np.unique(first_map).tolist() if part != self.bg_label]
        colors_present = [colors[part] for part in parts_present]

        return skimage.color.label2rgb(
            label=first_map,
            image=ims[0],
            colors=colors_present,
            bg_label=self.bg_label,
            alpha=self.alpha,
        )