# full_gaussian_avatar/GHA/preprocess/lib/MultiviewDataset.py
# pengc02 · commit ec9a6bc · message: "all"
import torch
import torchvision as tv
import numpy as np
import glob
import os
import random
from torch.utils.data import Dataset
class MultiviewDataset(Dataset):
    """Dataset of multi-view image samples with per-camera calibration.

    Every sub-directory of ``data_folder`` is one sample. For each of the first
    ``num_view`` camera ids, a sample directory is expected to contain
    ``image_<id>.jpg`` and ``camera_<id>.npz``. Each ``.npz`` archive must hold
    the arrays ``'intrinsic'`` and ``'extrinsic'``.
    """

    # Camera serial numbers used when the caller does not pass ``camera_ids``.
    DEFAULT_CAMERA_IDS = (
        '220700191', '221501007', '222200036', '222200037',
        '222200038', '222200039', '222200040', '222200041',
        '222200042', '222200043', '222200044', '222200045',
        '222200046', '222200047', '222200048', '222200049',
    )

    def __init__(self, data_folder, image_size, num_view, camera_ids=None):
        """Index the sample directories found under ``data_folder``.

        Args:
            data_folder: directory whose sub-directories are the samples.
            image_size: size argument forwarded to
                ``torchvision.transforms.Resize``.
            num_view: how many cameras to load per sample; the first
                ``num_view`` entries of ``camera_ids`` are used.
            camera_ids: optional sequence of camera id strings. Defaults to
                ``DEFAULT_CAMERA_IDS``.

        Raises:
            ValueError: if ``num_view`` is not between 1 and
                ``len(camera_ids)``.
        """
        super().__init__()
        self.image_size = image_size
        self.num_view = num_view
        self.camera_ids = list(
            self.DEFAULT_CAMERA_IDS if camera_ids is None else camera_ids
        )
        # Fail at construction time rather than with an IndexError (or an
        # empty torch.stack) on the first __getitem__ call.
        if not 1 <= num_view <= len(self.camera_ids):
            raise ValueError(
                f'num_view must be between 1 and {len(self.camera_ids)}, '
                f'got {num_view}'
            )
        self.loader = tv.datasets.folder.default_loader
        self.transform = tv.transforms.Compose(
            [tv.transforms.Resize(self.image_size), tv.transforms.ToTensor()]
        )
        # Only directories are samples; stray files in data_folder would
        # otherwise be indexed and then crash when loaded.
        self.folders = sorted(
            path for path in glob.glob(os.path.join(data_folder, '*'))
            if os.path.isdir(path)
        )

    @staticmethod
    def _load_camera(folder, cam_id):
        """Return the ``(intrinsic, extrinsic)`` arrays stored for one camera."""
        # Open the archive once per camera and close it deterministically:
        # np.load on an .npz returns an NpzFile that keeps a file handle open.
        # Indexing it materialises the arrays, so they stay valid after close.
        with np.load(os.path.join(folder, 'camera_%s.npz' % cam_id)) as cam:
            return cam['intrinsic'], cam['extrinsic']

    def get_item(self, index):
        """Alias for ``self[index]``."""
        data = self.__getitem__(index)
        return data

    def __getitem__(self, index):
        """Load every view of sample ``index``.

        Returns:
            dict with keys:
              'images': per-view outputs of ``self.transform``, stacked along
                  a new leading view dimension.
              'intrinsics', 'extrinsics': the per-view ``.npz`` arrays,
                  stacked along a new leading view dimension and cast to
                  float32.
              'exp_id': int parsed from the last four characters of the
                  sample directory's name.
        """
        folder = self.folders[index]
        view_ids = [self.camera_ids[v] for v in range(self.num_view)]
        images = torch.stack([
            self.transform(
                self.loader(os.path.join(folder, 'image_%s.jpg' % cam_id))
            )
            for cam_id in view_ids
        ])
        cameras = [self._load_camera(folder, cam_id) for cam_id in view_ids]
        intrinsics = torch.stack(
            [torch.from_numpy(intr) for intr, _ in cameras]
        ).float()
        extrinsics = torch.stack(
            [torch.from_numpy(extr) for _, extr in cameras]
        ).float()
        return {'images': images,
                'intrinsics': intrinsics,
                'extrinsics': extrinsics,
                # Slice the directory name, not the full path, so short names
                # (e.g. "7") do not pull in a path separator.
                'exp_id': int(os.path.basename(folder)[-4:])}

    def __len__(self):
        """Number of sample directories."""
        return len(self.folders)