from typing import *

from PIL.Image import Image as PILImage
from numpy import ndarray
from torch import Tensor
from wandb import Image as WandbImage

from PIL import Image
import numpy as np
import torch
from einops import rearrange
import wandb

# Per-channel statistics used for ImageNet-style normalization.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def wandb_mvimage_log(outputs: Dict[str, Tensor], max_num: int = 4, max_view: int = 8) -> List[WandbImage]:
    """Organize multi-view images in the `outputs` dict for wandb logging.

    Only values whose keys contain the word "images" are processed; they are
    expected to have shape (B, V, 3, H, W). At most `max_num` batch items and
    `max_view` views are kept and tiled into a single image per key.
    """
    formatted_images = []
    for k in outputs.keys():
        if "images" in k and outputs[k] is not None:  # (B, V, 3, H, W)
            assert outputs[k].ndim == 5
            num, view = outputs[k].shape[:2]
            num, view = min(num, max_num), min(view, max_view)
            # Tile batch items along the height axis and views along the width axis.
            mvimages = rearrange(outputs[k][:num, :view], "b v c h w -> c (b h) (v w)")
            formatted_images.append(
                wandb.Image(
                    tensor_to_image(mvimages.detach()),
                    caption=k,
                )
            )
    return formatted_images
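
# Example usage (an illustrative sketch, not part of the module; the key name
# "rgb_images" and the tensor shapes below are assumptions for this example only):
#
#     outputs = {"rgb_images": torch.rand(8, 6, 3, 256, 256)}  # (B, V, 3, H, W) in [0, 1]
#     wandb.log({"images": wandb_mvimage_log(outputs)})
#
# Each logged image then has min(B, max_num) * H rows and min(V, max_view) * W columns.
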
def tensor_to_image(tensor: Tensor, return_pil: bool = False) -> Union[ndarray, PILImage]:
    """Convert a (B, C, H, W) or (C, H, W) tensor with values in [0, 1] to a uint8 image."""
    if tensor.ndim == 4:  # (B, C, H, W): tile the batch along the width axis
        tensor = rearrange(tensor, "b c h w -> c h (b w)")
    assert tensor.ndim == 3  # (C, H, W)
    assert tensor.shape[0] in [1, 3]  # grayscale or RGB (RGBA is not considered here)
    if tensor.shape[0] == 1:  # replicate the grayscale channel to get RGB
        tensor = tensor.repeat(3, 1, 1)

    image = (tensor.permute(1, 2, 0).cpu().float().numpy() * 255).astype(np.uint8)  # (H, W, C)
    if return_pil:
        image = Image.fromarray(image)
    return image
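
# Example (an illustrative assumption: a batch of two RGB tensors with values in [0, 1]):
#
#     batch = torch.rand(2, 3, 128, 128)             # (B, C, H, W)
#     img = tensor_to_image(batch, return_pil=True)  # batch tiled side by side -> 256 x 128 (W x H) PIL image
#
# Values are expected to lie in [0, 1]; ImageNet-normalized tensors should be
# de-normalized before conversion, otherwise the uint8 cast will clip or wrap.
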
def load_image(image_path: str, rgba: bool = False, imagenet_norm: bool = False) -> Tensor:
    """Load an image from disk as a float tensor of shape (C, H, W) in [0, 1]."""
    image = Image.open(image_path)
    tensor_image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.  # (C, H, W) in [0, 1]
    if not rgba and tensor_image.shape[0] == 4:
        # Composite RGBA onto a white background using the alpha channel as the mask.
        mask = tensor_image[3:4]
        tensor_image = tensor_image[:3] * mask + (1. - mask)
    if imagenet_norm:
        mean = torch.tensor(IMAGENET_MEAN, dtype=tensor_image.dtype, device=tensor_image.device).view(3, 1, 1)
        std = torch.tensor(IMAGENET_STD, dtype=tensor_image.dtype, device=tensor_image.device).view(3, 1, 1)
        tensor_image = (tensor_image - mean) / std
    return tensor_image  # (C, H, W)
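

# Minimal smoke test (an illustrative sketch, not part of the original module):
# it builds a random multi-view batch, formats it for wandb logging, and converts
# a single view to a PIL image. The key name "pred_images" and all shapes are
# assumptions made for this example only; no wandb run is started and nothing
# is written to disk.
if __name__ == "__main__":
    dummy = {"pred_images": torch.rand(2, 4, 3, 64, 64)}  # (B, V, 3, H, W) in [0, 1]
    logged = wandb_mvimage_log(dummy)
    print(f"Prepared {len(logged)} wandb image(s) for logging.")

    single = tensor_to_image(dummy["pred_images"][0, 0], return_pil=True)
    print(f"Converted one view to a {single.size} PIL image.")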