Spaces:
Runtime error
Runtime error
| import os | |
| import random | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| import torchvision.transforms as transforms | |
def get_nonorm_transform(resolution):
    """Build a resize-then-to-tensor pipeline with no normalization step.

    Args:
        resolution: side length, in pixels, of the square output image.

    Returns:
        A ``transforms.Compose`` that bilinearly resizes its input to
        ``(resolution, resolution)`` and then applies ``transforms.ToTensor()``.
        No ``Normalize`` step is included.
    """
    square_size = (resolution, resolution)
    pipeline_steps = [
        transforms.Resize(
            square_size,
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
        transforms.ToTensor(),
    ]
    return transforms.Compose(pipeline_steps)
class FontDataset(Dataset):
    """Dataset for font generation: (content, style, target) glyph image triples.

    Directory layout read by this class:
        {root}/{phase}/TargetImage/{style}/{style}+{content}.<ext>
        {root}/{phase}/ContentImage/{content}.jpg
    """

    def __init__(self, args, phase, transforms=None):
        """Index the target images of one dataset split.

        Args:
            args: object exposing ``data_root`` (dataset root directory) and
                ``resolution`` (passed to ``get_nonorm_transform``).
            phase: name of the sub-directory under ``data_root`` to read
                (presumably e.g. "train"/"test" — confirm with callers).
            transforms: optional indexable of three callables. Index 0 is
                applied to the content image, 1 to the style image and 2 to
                the target image. Note that this parameter shadows the module-level
                ``transforms`` import inside this method only.
        """
        super().__init__()
        self.root = args.data_root
        self.phase = phase
        # Populate self.target_images and self.style_to_images.
        self.get_path()
        self.transforms = transforms
        self.nonorm_transforms = get_nonorm_transform(args.resolution)

    def get_path(self):
        """Collect every target image path and group the paths by style directory.

        Sets:
            self.target_images: flat list of all target image paths.
            self.style_to_images: maps a style directory name to that
                style's image paths.
        """
        self.target_images = []
        self.style_to_images = {}
        target_image_dir = f"{self.root}/{self.phase}/TargetImage"
        # Sort the listings so the index -> sample mapping is reproducible;
        # os.listdir order is arbitrary and filesystem dependent.
        for style in sorted(os.listdir(target_image_dir)):
            style_dir = f"{target_image_dir}/{style}"
            # Skip stray files (e.g. OS metadata) next to the style folders.
            if not os.path.isdir(style_dir):
                continue
            images_related_style = [
                f"{style_dir}/{img}" for img in sorted(os.listdir(style_dir))
            ]
            self.target_images.extend(images_related_style)
            self.style_to_images[style] = images_related_style

    @staticmethod
    def _parse_target_name(image_path):
        """Split a '<...>/<style>+<content>.<ext>' path into (style, content).

        Only the extension is stripped, and only the first '+' is split on.
        This keeps multi-dot names and a '+' (or '.') content glyph intact.
        """
        stem = os.path.splitext(os.path.basename(image_path))[0]
        style, content = stem.split('+', 1)
        return style, content

    @staticmethod
    def _load_rgb(image_path):
        """Open an image, convert it to RGB and close the underlying file."""
        with Image.open(image_path) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        """Return one training sample.

        Returns:
            Tuple ``(content_image, style_image, target_image,
            nonorm_target_image, target_image_path)``. The first three are
            RGB PIL images, or the outputs of ``self.transforms[0..2]`` when
            transforms were given. ``nonorm_target_image`` is the output of
            ``self.nonorm_transforms`` on the RGB target image.
        """
        target_image_path = self.target_images[index]
        style, content = self._parse_target_name(target_image_path)

        # Read the content (glyph shape) image.
        content_image_path = f"{self.root}/{self.phase}/ContentImage/{content}.jpg"
        content_image = self._load_rgb(content_image_path)

        # Randomly pick a different image of the same style as the style reference.
        candidates = [
            path for path in self.style_to_images[style]
            if path != target_image_path
        ]
        # A style with a single image has no other reference. Fall back to the
        # target itself instead of crashing with random.choice([]).
        style_image_path = (
            random.choice(candidates) if candidates else target_image_path
        )
        style_image = self._load_rgb(style_image_path)

        # Read the target image.
        target_image = self._load_rgb(target_image_path)
        nonorm_target_image = self.nonorm_transforms(target_image)

        if self.transforms is not None:
            content_image = self.transforms[0](content_image)
            style_image = self.transforms[1](style_image)
            target_image = self.transforms[2](target_image)

        return content_image, style_image, target_image, nonorm_target_image, target_image_path

    def __len__(self):
        """Number of target images in this split."""
        return len(self.target_images)