from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from PIL.Image import Image as PILImage
from skimage.metrics import structural_similarity as calculate_ssim
from torch import Tensor
from transformers import (
    CLIPImageProcessor, CLIPVisionModelWithProjection,
    CLIPTokenizer, CLIPTextModelWithProjection,
)

import ImageReward as RM
from kiui.lpips import LPIPS

class TextConditionMetrics:
    """Text-to-image metrics: CLIP similarity, CLIP R-Precision, and ImageReward."""

    def __init__(self,
        clip_name: str = "openai/clip-vit-base-patch32",
        rm_name: str = "ImageReward-v1.0",
        device_idx: int = 0,
    ):
        self.image_processor = CLIPImageProcessor.from_pretrained(clip_name)
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_name).to(f"cuda:{device_idx}").eval()
        self.tokenizer = CLIPTokenizer.from_pretrained(clip_name)
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_name).to(f"cuda:{device_idx}").eval()
        self.rm_model = RM.load(rm_name)  # ImageReward model; loads on its default device
        self.device = f"cuda:{device_idx}"
    @torch.no_grad()
    def evaluate(self,
        image: Union[PILImage, List[PILImage]],
        text: Union[str, List[str]],
    ) -> Tuple[float, float, float]:
        """Return (CLIP similarity, CLIP R-Precision, ImageReward) averaged over the batch."""
        if isinstance(image, PILImage):
            image = [image]
        if isinstance(text, str):
            text = [text]
        assert len(image) == len(text)

        image_inputs = self.image_processor(image, return_tensors="pt").pixel_values.to(self.device)
        image_embeds = self.image_encoder(image_inputs).image_embeds.float()  # (N, D)
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        text_inputs = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.device)
        text_embeds = self.text_encoder(text_input_ids).text_embeds.float()  # (N, D)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        assert image_embeds.shape == text_embeds.shape
        clip_scores = image_embeds @ text_embeds.T  # (N, N) cosine similarities

        # 1. CLIP similarity: mean cosine similarity of the matched image-text pairs
        clip_sim = clip_scores.diag().mean().item()

        # 2. CLIP R-Precision: fraction of images whose best-matching caption is their own
        clip_rprec = (clip_scores.argmax(dim=1) == torch.arange(len(text)).to(self.device)).float().mean().item()

        # 3. ImageReward: learned human-preference score, averaged over the batch
        rm_scores = []
        for img, txt in zip(image, text):
            rm_scores.append(self.rm_model.score(txt, img))
        rm_scores = torch.tensor(rm_scores, device=self.device)
        rm_score = rm_scores.mean().item()

        return clip_sim, clip_rprec, rm_score
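
# A minimal usage sketch for TextConditionMetrics, assuming a CUDA device is
# available; the file name and prompt below are hypothetical illustrations,
# not part of the original module.
def demo_text_condition_metrics():
    from PIL import Image
    metrics = TextConditionMetrics()
    images = [Image.open("cat.png").convert("RGB")]  # hypothetical sample image
    prompts = ["a photo of a cat"]                   # hypothetical prompt
    clip_sim, clip_rprec, rm_score = metrics.evaluate(images, prompts)
    print(f"CLIP sim: {clip_sim:.4f} | R-Prec: {clip_rprec:.4f} | ImageReward: {rm_score:.4f}")
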
class ImageConditionMetrics:
    """Image-reconstruction metrics against ground truth: PSNR, SSIM, and LPIPS."""

    def __init__(self,
        lpips_net: str = "vgg",
        lpips_res: int = 256,
        device_idx: int = 0,
    ):
        self.lpips_loss = LPIPS(net=lpips_net).to(f"cuda:{device_idx}").eval()
        self.lpips_res = lpips_res
        self.device = f"cuda:{device_idx}"
    @torch.no_grad()
    def evaluate(self,
        image: Union[Tensor, PILImage, List[PILImage]],
        gt: Union[Tensor, PILImage, List[PILImage]],
        chunk_size: Optional[int] = None,
        input_tensor: bool = False,
    ) -> Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]]:
        """Return ((PSNR mean, std), (SSIM mean, std), (LPIPS mean, std)) over the batch.

        With `input_tensor=True`, inputs are expected as (N, 3, H, W) tensors in [0, 1].
        """
        if not input_tensor:
            if isinstance(image, PILImage):
                image = [image]
            if isinstance(gt, PILImage):
                gt = [gt]
            assert len(image) == len(gt)

            def image_to_tensor(img: PILImage):
                # (H, W, 3) uint8 -> (1, 3, H, W) float in [0, 1]
                return torch.tensor(np.array(img).transpose(2, 0, 1) / 255., device=self.device).unsqueeze(0).float()

            image_pt = torch.cat([image_to_tensor(img) for img in image], dim=0)
            gt_pt = torch.cat([image_to_tensor(img) for img in gt], dim=0)
        else:
            image_pt = image.to(device=self.device)
            gt_pt = gt.to(device=self.device)
        if chunk_size is None:
            # Default to a single chunk. This must happen outside the PIL branch,
            # otherwise tensor inputs hit a `None` step in the LPIPS loop below.
            chunk_size = len(image)
        # 1. LPIPS: resize to a fixed resolution and map [0, 1] -> [-1, 1] as the network expects
        lpips = []
        for i in range(0, len(image), chunk_size):
            _lpips = self.lpips_loss(
                F.interpolate(
                    image_pt[i:i+chunk_size] * 2. - 1.,
                    (self.lpips_res, self.lpips_res), mode="bilinear", align_corners=False
                ),
                F.interpolate(
                    gt_pt[i:i+chunk_size] * 2. - 1.,
                    (self.lpips_res, self.lpips_res), mode="bilinear", align_corners=False
                )
            )
            lpips.append(_lpips)
        lpips = torch.cat(lpips)
        lpips_mean, lpips_std = lpips.mean().item(), lpips.std().item()

        # 2. PSNR: images are in [0, 1], so the peak value is 1 and PSNR = -10 * log10(MSE)
        psnr = -10. * torch.log10((gt_pt - image_pt).pow(2).mean(dim=[1, 2, 3]))
        psnr_mean, psnr_std = psnr.mean().item(), psnr.std().item()

        # 3. SSIM: computed per image on uint8 arrays (data_range inferred from dtype)
        ssim = []
        for i in range(len(image)):
            _ssim = calculate_ssim(
                (image_pt[i].cpu().float().numpy() * 255.).astype(np.uint8),
                (gt_pt[i].cpu().float().numpy() * 255.).astype(np.uint8),
                channel_axis=0,
            )
            ssim.append(_ssim)
        ssim = np.array(ssim)
        ssim_mean, ssim_std = ssim.mean(), ssim.std()

        return (psnr_mean, psnr_std), (ssim_mean, ssim_std), (lpips_mean, lpips_std)
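
# A minimal usage sketch for ImageConditionMetrics, assuming (N, 3, H, W) tensors
# in [0, 1]; the random tensors below are purely illustrative stand-ins.
def demo_image_condition_metrics():
    metrics = ImageConditionMetrics()
    pred = torch.rand(4, 3, 256, 256)  # stand-in for rendered/generated images
    gt = torch.rand(4, 3, 256, 256)    # stand-in for ground-truth images
    (psnr_m, psnr_s), (ssim_m, ssim_s), (lpips_m, lpips_s) = metrics.evaluate(
        pred, gt, chunk_size=2, input_tensor=True,
    )
    print(f"PSNR: {psnr_m:.2f}±{psnr_s:.2f} | SSIM: {ssim_m:.4f}±{ssim_s:.4f} | LPIPS: {lpips_m:.4f}±{lpips_s:.4f}")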