Spaces:
Runtime error
Runtime error
| from calendar import c | |
| import os | |
| # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' | |
| # os.environ['TORCH_USE_CUDA_DSA'] = '1' | |
| os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' | |
| import yaml | |
| import shutil | |
| import collections | |
| import torch | |
| import torch.utils.data | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import cv2 as cv | |
| import glob | |
| import datetime | |
| import trimesh | |
| from torch.utils.tensorboard import SummaryWriter | |
| from tqdm import tqdm | |
| import importlib | |
| # import config | |
| from omegaconf import OmegaConf | |
| import json | |
| # AnimatableGaussians part | |
| from AnimatableGaussians.network.lpips import LPIPS | |
| from AnimatableGaussians.dataset.dataset_pose import PoseDataset | |
| import AnimatableGaussians.utils.net_util as net_util | |
| import AnimatableGaussians.utils.visualize_util as visualize_util | |
| from AnimatableGaussians.utils.renderer import Renderer | |
| from AnimatableGaussians.utils.net_util import to_cuda | |
| from AnimatableGaussians.utils.obj_io import save_mesh_as_ply | |
| from AnimatableGaussians.gaussians.obj_io import save_gaussians_as_ply | |
| import AnimatableGaussians.config as ag_config | |
| # Gaussian-Head-Avatar part | |
| from GHA.config.config import config_reenactment | |
| from GHA.lib.dataset.Dataset import ReenactmentDataset | |
| from GHA.lib.dataset.DataLoaderX import DataLoaderX | |
| from GHA.lib.module.GaussianHeadModule import GaussianHeadModule | |
| from GHA.lib.module.SuperResolutionModule import SuperResolutionModule | |
| from GHA.lib.module.CameraModule import CameraModule | |
| from GHA.lib.recorder.Recorder import ReenactmentRecorder | |
| from GHA.lib.apps.Reenactment import Reenactment | |
| # cat utils | |
| from calc_offline_rendering_param import calc_offline_rendering_param | |
| import ipdb | |
class Avatar:
    """Full-body avatar renderer built from AnimatableGaussians (body) and
    Gaussian-Head-Avatar (head) models.

    ``test_body`` renders the body sequence and its auxiliary masks to disk,
    ``test_head`` runs GHA reenactment, and ``cal_cat_param`` computes the
    parameters used to blend body and head renders.
    """

    def __init__(self, config):
        """Build the body network and load the head / stitching configuration.

        Args:
            config: OmegaConf config with ``animatablegaussians``, ``gha`` and
                ``cat`` sections.
        """
        self.config = config
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # --- AnimatableGaussians (body) ---
        self.body = config.animatablegaussians
        self.body.mode = 'test'
        ag_config.set_opt(self.body)
        avatar_module = self.body['model'].get('module', 'AnimatableGaussians.network.avatar')
        print('Import AvatarNet from %s' % avatar_module)
        AvatarNet = importlib.import_module(avatar_module).AvatarNet
        self.avatar_net = AvatarNet(self.body.model).to(self.device)
        self.random_bg_color = self.body['train'].get('random_bg_color', True)
        self.bg_color = (1., 1., 1.)
        self.bg_color_cuda = torch.from_numpy(np.asarray(self.bg_color)).to(torch.float32).to(self.device)
        self.loss_weight = self.body['train']['loss_weight']
        self.finetune_color = self.body['train']['finetune_color']
        print('# Parameter number of AvatarNet is %d' % (sum([p.numel() for p in self.avatar_net.parameters()])))

        # --- Gaussian-Head-Avatar (head) ---
        self.head = config.gha
        self.head_config = config_reenactment()
        self.head_config.load(self.head.config_path)
        self.head_config = self.head_config.get_cfg()

        # --- body/head stitching ("cat") utilities ---
        self.cat = config.cat

    # ------------------------------------------------------------------
    # Camera helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _synthetic_intrinsics(img_scale):
        """Return ``(intr, img_h, img_w)`` of the 1024x1024 virtual camera
        (focal 1100 px, principal point 512, 512) scaled by ``img_scale``."""
        intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
        intr[:2] *= img_scale
        return intr, int(1024 * img_scale), int(1024 * img_scale)

    @staticmethod
    def _oscillating_rot_y(idx, frame_per_cycle, max_degree):
        """Yaw in radians, wrapped into [0, 2*pi), for frame ``idx`` of a sweep
        that goes linearly from -max_degree to +max_degree during the first
        half of each ``frame_per_cycle`` cycle and back during the second half.
        """
        frame_half_cycle = frame_per_cycle // 2
        step = 2 * max_degree / frame_half_cycle
        if idx % frame_per_cycle < frame_per_cycle / 2:
            rot_y = -max_degree + step * (idx % frame_half_cycle)
        else:
            rot_y = max_degree - step * (idx % frame_half_cycle)
        rot_y = rot_y * np.pi / 180
        if rot_y < 0:
            rot_y = rot_y + 2 * np.pi
        return rot_y

    def _compute_view(self, idx, view_setting, img_scale, object_center, global_orient):
        """Return ``(extr, intr, img_h, img_w)`` of the render camera for frame ``idx``.

        ``view_setting`` values:
          * ``'camera'``     -- dataset camera ``test.render_view_idx``;
          * ``'free*'``      -- one full orbit every 360 frames;
          * ``'degree120*'`` -- yaw oscillating within +-60 degrees (480-frame cycle);
          * ``'degree90*'``  -- yaw oscillating within +-45 degrees (360-frame cycle);
          * ``'front*'`` / ``'back*'`` / ``'moving*'`` -- fixed yaw 0 / pi / 0;
          * ``'cano*'``      -- fixed camera looking at the canonical-space bounds.
        Orbit settings whose name ends in ``'bird'`` also get a non-zero rot_X.

        Raises:
            ValueError: if ``view_setting`` is not recognised.
        """
        # The subject's (yaw-only) global orientation is passed to the orbit
        # camera only when test.global_orient is enabled.
        orient = global_orient if self.body['test'].get('global_orient', False) else None
        bird = view_setting.endswith('bird')

        if view_setting == 'camera':
            cam_id = self.body['test']['render_view_idx']
            intr = self.dataset.intr_mats[cam_id].copy()
            intr[:2] *= img_scale
            extr = self.dataset.extr_mats[cam_id].copy()
            img_h = int(self.dataset.img_heights[cam_id] * img_scale)
            img_w = int(self.dataset.img_widths[cam_id] * img_scale)
            return extr, intr, img_h, img_w
        elif view_setting.startswith('free'):
            frame_num_per_circle = 360
            rot_y = (idx % frame_num_per_circle) / float(frame_num_per_circle) * 2 * np.pi
            rot_x = 0.3 if bird else 0.
        elif view_setting.startswith('degree120'):
            rot_y = self._oscillating_rot_y(idx, 480, 60)
            rot_x = 0.3 if bird else 0.
        elif view_setting.startswith('degree90'):
            rot_y = self._oscillating_rot_y(idx, 360, 45)
            rot_x = 0.3 if bird else 0.
        elif view_setting.startswith('front'):
            rot_y = 0.
            rot_x = 0.3 if bird else 0.
        elif view_setting.startswith('back'):
            rot_y = np.pi
            rot_x = 0.5 * np.pi / 4. if bird else 0.
        elif view_setting.startswith('moving'):
            # Same pose as 'front'; the orbit centre follows the subject (see test_body).
            rot_y = 0.
            rot_x = 0.3 if bird else 0.
        elif view_setting.startswith('cano'):
            # Camera centred on the canonical bounds, flipped 180 degrees about x,
            # pushed back along z and using a long focal length.
            # NOTE(review): this mode renders live (posed) data with a camera aimed at
            # canonical bounds; confirm that is intended.
            cano_center = self.dataset.cano_bounds.mean(0)
            extr = np.identity(4, np.float32)
            extr[:3, 3] = -cano_center
            rot_x_mat = np.identity(4, np.float32)
            rot_x_mat[:3, :3] = cv.Rodrigues(np.array([np.pi, 0, 0], np.float32))[0]
            extr = rot_x_mat @ extr
            f_len = 5000
            extr[2, 3] += f_len / 512
            intr = np.array([[f_len, 0, 512], [0, f_len, 512], [0, 0, 1]], np.float32)
            return extr, intr, 1024, 1024
        else:
            raise ValueError('Invalid view setting for animation!')

        extr = visualize_util.calc_free_mv(object_center,
                                           tar_pos = np.array([0, 0, 2.5]),
                                           rot_Y = rot_y,
                                           rot_X = rot_x,
                                           global_orient = orient)
        intr, img_h, img_w = self._synthetic_intrinsics(img_scale)
        return extr, intr, img_h, img_w

    # ------------------------------------------------------------------
    # Per-frame output helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _render_skeleton(output_dir, item, geo_renderer):
        """Render the posed skeleton of ``item`` into ``live_skeleton/``.

        ``geo_renderer`` is created lazily on the first call and returned so
        the caller can reuse it for later frames.
        """
        from AnimatableGaussians.utils.visualize_skeletons import construct_skeletons
        skel_vertices, skel_faces = construct_skeletons(item['joints'].cpu().numpy(), item['kin_parent'].cpu().numpy())
        skel_mesh = trimesh.Trimesh(skel_vertices, skel_faces, process = False)
        if geo_renderer is None:
            geo_renderer = Renderer(item['img_w'], item['img_h'], shader_name = 'phong_geometry', bg_color = (1, 1, 1))
        geo_renderer.set_camera(item['extr'], item['intr'])
        face_idx = skel_faces.reshape(-1)
        geo_renderer.set_model(skel_vertices[face_idx], skel_mesh.vertex_normals.astype(np.float32)[face_idx])
        skel_img = geo_renderer.render()[:, :, :3]
        skel_img = (skel_img * 255).astype(np.uint8)
        cv.imwrite(output_dir + '/live_skeleton/%08d.jpg' % item['data_idx'], skel_img)
        return geo_renderer

    @staticmethod
    def _color_mask(rgb_map, color):
        """Boolean mask of pixels whose RGB (last of three dims) lies within a
        squared distance of 0.01 from ``color``."""
        diff = rgb_map - torch.tensor(color, device = rgb_map.device)
        return (diff * diff).sum(dim = 2) < 0.01

    def _save_hand_masks(self, output_dir, data_idx, mask_output):
        """Write the colour-coded mask renders and the masks derived from them.

        The code treats pure red as the right hand and pure blue as the left
        hand in ``full_body_rgb_map`` / ``hand_only_rgb_map``.

        Returns:
            ``(r_hand_mask, l_hand_mask)``, uint8 tensors with values {0, 255}
            marking the hand pixels visible in the full-body render, or
            ``None`` when either map was not rendered.
        """
        if 'full_body_rgb_map' in mask_output:
            full_body_mask = mask_output['full_body_rgb_map']
            full_body_mask.clip_(0., 1.)  # in place: the colour tests below read the clipped map
            full_body_mask = (full_body_mask * 255).to(torch.uint8)
            cv.imwrite(output_dir + '/full_body_mask/%08d.png' % data_idx, full_body_mask.cpu().numpy())
        if 'hand_only_rgb_map' in mask_output:
            hand_only_mask = mask_output['hand_only_rgb_map']
            hand_only_mask.clip_(0., 1.)
            hand_only_mask = (hand_only_mask * 255).to(torch.uint8)
            cv.imwrite(output_dir + '/hand_only_mask/%08d.png' % data_idx, hand_only_mask.cpu().numpy())
        if 'full_body_rgb_map' not in mask_output or 'hand_only_rgb_map' not in mask_output:
            return None

        full_body = mask_output['full_body_rgb_map']
        hand_only = mask_output['hand_only_rgb_map']
        body_red_mask = self._color_mask(full_body, [1., 0., 0.])
        hand_red_mask = self._color_mask(hand_only, [1., 0., 0.])
        body_blue_mask = self._color_mask(full_body, [0., 0., 1.])
        hand_blue_mask = self._color_mask(hand_only, [0., 0., 1.])

        # A hand is flagged as masked when the relative difference between its
        # pixel counts in the full-body and hand-only renders exceeds 0.95.
        if_mask_r_hand = (abs(body_red_mask.sum() - hand_red_mask.sum()) / hand_red_mask.sum() > 0.95).cpu().numpy()
        if_mask_l_hand = (abs(body_blue_mask.sum() - hand_blue_mask.sum()) / hand_blue_mask.sum() > 0.95).cpu().numpy()

        # Hand pixels present in the hand-only render but hidden in the full-body render.
        red_mask = hand_red_mask ^ (hand_red_mask & body_red_mask)
        blue_mask = hand_blue_mask ^ (hand_blue_mask & body_blue_mask)
        all_mask = red_mask | blue_mask

        for sub_dir in ('hand_mask', 'r_hand_mask', 'l_hand_mask', 'hand_visual'):
            os.makedirs(output_dir + '/' + sub_dir, exist_ok = True)
        all_mask = (all_mask * 255).to(torch.uint8)
        cv.imwrite(output_dir + '/hand_mask/%08d.png' % data_idx, all_mask.cpu().numpy())
        r_hand_mask = (body_red_mask * 255).to(torch.uint8)
        cv.imwrite(output_dir + '/r_hand_mask/%08d.png' % data_idx, r_hand_mask.cpu().numpy())
        l_hand_mask = (body_blue_mask * 255).to(torch.uint8)
        cv.imwrite(output_dir + '/l_hand_mask/%08d.png' % data_idx, l_hand_mask.cpu().numpy())
        with open(output_dir + '/hand_visual/%08d.npy' % data_idx, 'wb') as f:
            np.save(f, [if_mask_r_hand, if_mask_l_hand])

        # Sleeve masks depend on the hand masks above, so they are only built here.
        if 'left_hand_rgb_map' in mask_output and 'right_hand_rgb_map' in mask_output:
            self._save_sleeve_masks(output_dir, data_idx, mask_output, r_hand_mask, l_hand_mask, all_mask)
        return r_hand_mask, l_hand_mask

    def _save_sleeve_masks(self, output_dir, data_idx, mask_output, r_hand_mask, l_hand_mask, all_mask):
        """Save left/right sleeve masks: the non-white pixels of each arm render
        minus a dilated union of the hand masks."""
        os.makedirs(output_dir + '/left_sleeve_mask', exist_ok = True)
        os.makedirs(output_dir + '/right_sleeve_mask', exist_ok = True)
        hands = (r_hand_mask > 128) | (l_hand_mask > 128) | (all_mask > 128)
        hands = hands.cpu().numpy().astype(np.uint8)
        # Grow the hand region (5x5 kernel, 3 iterations) to keep sleeves clear of hand edges.
        hands = cv.dilate(hands, np.ones((5, 5), np.uint8), iterations = 3)
        hands = torch.tensor(hands).to(self.device).bool()
        for key, sub_dir in (('left_hand_rgb_map', 'left_sleeve_mask'),
                             ('right_hand_rgb_map', 'right_sleeve_mask')):
            arm_map = mask_output[key]
            arm_map.clip_(0., 1.)
            # Anything that is not (near-)white background belongs to the arm render.
            diff = torch.tensor([1., 1., 1.], device = arm_map.device) - arm_map
            sleeve = ((diff * diff).sum(dim = 2) > 0.01) & ~hands
            sleeve = (sleeve * 255).to(torch.uint8)
            cv.imwrite(output_dir + '/' + sub_dir + '/%08d.png' % data_idx, sleeve.cpu().numpy())

    @staticmethod
    def _save_rgb_wo_hand(output_dir, data_idx, rgb_map_wo_hand, rgb_map, r_hand_mask, l_hand_mask):
        """Inside the dilated hand region take pixels from the hand-free render,
        elsewhere from ``rgb_map`` (uint8 array), and save the composite."""
        rgb_map_wo_hand.clip_(0., 1.)
        rgb_map_wo_hand = (rgb_map_wo_hand * 255).to(torch.uint8).cpu().numpy()
        hands = ((r_hand_mask > 128).cpu().numpy() | (l_hand_mask > 128).cpu().numpy()).astype(np.uint8)
        hands = cv.dilate(hands, np.ones((5, 5), np.uint8), iterations = 3).astype(np.bool_)
        hands = np.expand_dims(hands, axis = 2)
        mix = np.where(hands, rgb_map_wo_hand, rgb_map)
        cv.imwrite(output_dir + '/rgb_map_wo_hand/%08d.png' % data_idx, mix)

    # ------------------------------------------------------------------
    # Entry points
    # ------------------------------------------------------------------
    def test_body(self):
        """Render the body avatar for every frame of the test pose sequence.

        Uses the ``test`` section of the AnimatableGaussians config. Per frame it
        writes the RGB render, a composite with hands taken from the hand-free
        render, colour-coded body/hand/sleeve masks and, when enabled, skeleton
        renders, texture maps and posed Gaussians. After all frames are rendered,
        the per-frame cameras are written to ``cfg_args``, ``cameras.json`` and
        ``camera_info.json`` in the output directory.

        Raises:
            ValueError: if the test config has no ``pose_data`` or ``output_dir``,
                or ``view_setting`` is not recognised.
        """
        self.avatar_net.eval()
        test_cfg = self.body['test']

        dataset_module = self.body.get('dataset', 'MvRgbDatasetAvatarReX')
        MvRgbDataset = importlib.import_module('AnimatableGaussians.dataset.dataset_mv_rgb').__getattribute__(dataset_module)
        training_dataset = MvRgbDataset(**self.body['train']['data'], training = False)
        use_pca = test_cfg.get('n_pca', -1) >= 1
        # One sigma value drives both the PCA projection and the output dir name.
        sigma_pca = float(test_cfg.get('sigma_pca', 2.)) if use_pca else None
        if use_pca:
            training_dataset.compute_pca(n_components = test_cfg['n_pca'])
        if 'pose_data' not in test_cfg:
            raise ValueError('No pose data in test config')
        self.dataset = PoseDataset(**test_cfg['pose_data'], smpl_shape = training_dataset.smpl_data['betas'][0])

        # NOTE(review): no checkpoint is loaded here (the upstream
        # `self.load_ckpt(test.prev_ckpt)` call is disabled and this class has no
        # `load_ckpt`); confirm AvatarNet weights are loaded elsewhere.
        output_dir = test_cfg.get('output_dir', None)
        if output_dir is None:
            raise ValueError('No output_dir in test config')
        if use_pca:
            output_dir += '/pca_%d_sigma_%.2f' % (test_cfg.get('n_pca', -1), sigma_pca)
        else:
            output_dir += '/vanilla'
        print('# Output dir: \033[1;31m%s\033[0m' % output_dir)
        for sub_dir in ('live_skeleton', 'rgb_map', 'rgb_map_wo_hand', 'torso_map', 'mask_map',
                        'posed_gaussians', 'posed_params', 'full_body_mask', 'hand_only_mask'):
            os.makedirs(output_dir + '/' + sub_dir, exist_ok = True)

        geo_renderer = None
        item_0 = self.dataset.getitem(0, training = False)
        object_center = item_0['live_bounds'].mean(0)
        global_orient = item_0['global_orient'].cpu().numpy() if isinstance(item_0['global_orient'], torch.Tensor) else item_0['global_orient']
        # Keep only the y component of the axis-angle orientation (x and z zeroed).
        global_orient[0] = 0
        global_orient[2] = 0
        global_orient = cv.Rodrigues(global_orient)[0]

        log_time = test_cfg.get('log_time', False)
        if log_time:
            time_start = torch.cuda.Event(enable_timing = True)
            time_start_all = torch.cuda.Event(enable_timing = True)
            time_end = torch.cuda.Event(enable_timing = True)

        def lap(message):
            """Print the time since the last checkpoint and start a new one."""
            time_end.record()
            torch.cuda.synchronize()
            print(message % (time_start.elapsed_time(time_end) / 1000.))
            time_start.record()

        if test_cfg.get('fix_hand', False):
            self.avatar_net.generate_mean_hands()

        img_scale = test_cfg.get('img_scale', 1.0)
        view_setting = test_cfg.get('view_setting', 'free')
        if view_setting.startswith('degree120'):
            print('we render 120 degree')
        elif view_setting.startswith('degree90'):
            print('we render 90 degree')
        getitem_func = self.dataset.getitem_fast if hasattr(self.dataset, 'getitem_fast') else self.dataset.getitem

        extr_list, intr_list, img_h_list, img_w_list = [], [], [], []
        data_num = len(self.dataset)
        for idx in tqdm(range(data_num), desc = 'Rendering avatars...'):
            if log_time:
                time_start.record()
                time_start_all.record()

            extr, intr, img_h, img_w = self._compute_view(idx, view_setting, img_scale, object_center, global_orient)
            extr_list.append(extr)
            intr_list.append(intr)
            img_h_list.append(img_h)
            img_w_list.append(img_w)

            item = getitem_func(idx, training = False, extr = extr, intr = intr, img_w = img_w, img_h = img_h)
            items = to_cuda(item, add_batch = False)
            data_idx = item['data_idx']

            if view_setting.startswith('moving') or view_setting == 'free_moving':
                # Follow the subject horizontally: only the x coordinate of the
                # orbit centre tracks the live bounds; y and z stay fixed.
                current_center = items['live_bounds'].cpu().numpy().mean(0)
                delta = current_center - object_center
                object_center[0] += delta[0]
            if log_time:
                lap('Loading data costs %.4f secs')

            if test_cfg.get('render_skeleton', False):
                geo_renderer = self._render_skeleton(output_dir, item, geo_renderer)
            if log_time:
                lap('Rendering skeletons costs %.4f secs')

            if 'smpl_pos_map' not in items:
                self.avatar_net.get_pose_map(items)
            if use_pca:
                # Project the front half of the position map onto the training PCA basis.
                mask = training_dataset.pos_map_mask
                live_pos_map = items['smpl_pos_map'].permute(1, 2, 0).cpu().numpy()
                front_live_pos_map, back_live_pos_map = np.split(live_pos_map, [3], 2)
                pose_conds = front_live_pos_map[mask]
                new_pose_conds = training_dataset.transform_pca(pose_conds, sigma_pca = sigma_pca)
                front_live_pos_map[mask] = new_pose_conds
                live_pos_map = np.concatenate([front_live_pos_map, back_live_pos_map], 2)
                items.update({
                    'smpl_pos_map_pca': torch.from_numpy(live_pos_map).to(self.device).permute(2, 0, 1)
                })
            if log_time:
                lap('Rendering pose conditions costs %.4f secs')

            output = self.avatar_net.render(items, bg_color = self.bg_color, use_pca = use_pca)
            output_wo_hand = self.avatar_net.render_wo_hand(items, bg_color = self.bg_color, use_pca = use_pca)
            mask_output = self.avatar_net.render_mask(items, bg_color = self.bg_color, use_pca = use_pca)
            if log_time:
                lap('Rendering avatar costs %.4f secs')

            hand_masks = self._save_hand_masks(output_dir, data_idx, mask_output)

            rgb_map = output['rgb_map']
            rgb_map.clip_(0., 1.)
            rgb_map = (rgb_map * 255).to(torch.uint8).cpu().numpy()
            cv.imwrite(output_dir + '/rgb_map/%08d.jpg' % data_idx, rgb_map)

            if 'rgb_map' in output_wo_hand and hand_masks is not None:
                self._save_rgb_wo_hand(output_dir, data_idx, output_wo_hand['rgb_map'], rgb_map, *hand_masks)

            if 'torso_map' in output:
                torso_map = output['torso_map'][:, :, 0]
                torso_map.clip_(0., 1.)
                torso_map = (torso_map * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/torso_map/%08d.png' % data_idx, torso_map.cpu().numpy())

            if 'mask_map' in output:
                mask_map = output['mask_map'][:, :, 0]
                mask_map.clip_(0., 1.)
                mask_map = (mask_map * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/mask_map/%08d.png' % data_idx, mask_map.cpu().numpy())

            if test_cfg.get('save_tex_map', False):
                os.makedirs(output_dir + '/cano_tex_map', exist_ok = True)
                cano_tex_map = output['cano_tex_map']
                cano_tex_map.clip_(0., 1.)
                cano_tex_map = (cano_tex_map * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/cano_tex_map/%08d.png' % data_idx, cano_tex_map.cpu().numpy())

            if test_cfg.get('save_ply', False):
                posed_gaussians = output['posed_gaussians']
                # A .ply of the Gaussians is written for the first frame only;
                # every frame gets an .npz of the Gaussians and the SMPL params.
                if data_idx == 0:
                    save_gaussians_as_ply(output_dir + '/posed_gaussians/%08d.ply' % data_idx, posed_gaussians)
                for k, v in posed_gaussians.items():
                    if isinstance(v, torch.Tensor):
                        posed_gaussians[k] = v.detach().cpu().numpy()
                np.savez(output_dir + '/posed_gaussians/%08d.npz' % data_idx, **posed_gaussians)
                np.savez(output_dir + ('/posed_params/%08d.npz' % data_idx),
                         betas = training_dataset.smpl_data['betas'].reshape([-1]).detach().cpu().numpy(),
                         global_orient = item['global_orient'].reshape([-1]).detach().cpu().numpy(),
                         transl = item['transl'].reshape([-1]).detach().cpu().numpy(),
                         body_pose = item['body_pose'].reshape([-1]).detach().cpu().numpy())

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Saving images costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                print('Animating one frame costs %.4f secs' % (time_start_all.elapsed_time(time_end) / 1000.))
            torch.cuda.empty_cache()

        # Write the cameras of all rendered frames once, after the loop.
        self.dump_renderer_info(output_dir, extr_list, intr_list, img_h_list, img_w_list)
        camera_info = [
            {'extr': e.tolist(), 'intr': k.tolist(), 'img_h': h, 'img_w': w}
            for e, k, h, w in zip(extr_list, intr_list, img_h_list, img_w_list)
        ]
        with open(os.path.join(output_dir, 'camera_info.json'), 'w') as fp:
            json.dump(camera_info, fp)

    def dump_renderer_info(self, dump_dir, extrs, intrs, img_heights, img_widths):
        """Write ``cfg_args`` and ``cameras.json`` describing the render cameras.

        ``cfg_args`` is a ``Namespace(...)`` string (sh_degree, source_path,
        model_path, ...) and ``cameras.json`` holds one entry per camera with
        its camera-to-world position/rotation, image size and focal lengths.

        Args:
            dump_dir: directory the two files are written into.
            extrs: world-to-camera matrices (square; presumably 4x4), one per frame.
            intrs: 3x3 intrinsic matrices, one per frame.
            img_heights: image heights in pixels.
            img_widths: image widths in pixels.
        """
        with open(os.path.join(dump_dir, 'cfg_args'), 'w') as fp:
            fp.write("Namespace(sh_degree=%d, source_path='%s', model_path='%s', images='images', resolution=-1, "
                     "white_background=False, data_device='cuda', eval=False)" % (
                         3, self.body['train']['data']['data_dir'], dump_dir))

        cam_jsons = []
        for ci, (w2c, intr, img_h, img_w) in enumerate(zip(extrs, intrs, img_heights, img_widths)):
            c2w = np.linalg.inv(w2c)
            cam_jsons.append({
                'id': ci,
                'img_name': '%08d' % ci,
                'width': int(img_w),
                'height': int(img_h),
                'position': c2w[:3, 3].tolist(),
                'rotation': [row.tolist() for row in c2w[:3, :3]],
                'fy': float(intr[1, 1]),
                'fx': float(intr[0, 0]),
            })
        with open(os.path.join(dump_dir, 'cameras.json'), 'w') as fp:
            json.dump(cam_jsons, fp)

    def test_head(self):
        """Run Gaussian-Head-Avatar reenactment with the loaded head config.

        Loads the GaussianHead and super-resolution checkpoints named in the
        head config. It then either renders the reenactment sequence (up to
        ``gha.stop_fid``, default 800) or, when
        ``gha.offline_rendering_param_fpath`` is set, renders for offline
        body/head stitching.
        """
        dataset = ReenactmentDataset(self.head_config.dataset)
        dataloader = DataLoaderX(dataset, batch_size=1, shuffle=False, pin_memory=True)
        device = torch.device('cuda:%d' % self.head_config.gpu_id)

        # map_location returns the storage unchanged, i.e. tensors stay on CPU;
        # the modules are moved to `device` afterwards.
        gaussianhead_state_dict = torch.load(self.head_config.load_gaussianhead_checkpoint, map_location=lambda storage, loc: storage)
        gaussianhead = GaussianHeadModule(self.head_config.gaussianheadmodule,
                                          xyz=gaussianhead_state_dict['xyz'],
                                          feature=gaussianhead_state_dict['feature'],
                                          landmarks_3d_neutral=gaussianhead_state_dict['landmarks_3d_neutral']).to(device)
        gaussianhead.load_state_dict(gaussianhead_state_dict)
        supres = SuperResolutionModule(self.head_config.supresmodule).to(device)
        supres.load_state_dict(torch.load(self.head_config.load_supres_checkpoint, map_location=lambda storage, loc: storage))

        camera = CameraModule()
        recorder = ReenactmentRecorder(self.head_config.recorder)
        app = Reenactment(dataloader, gaussianhead, supres, camera, recorder, self.head_config.gpu_id, dataset.freeview)
        if self.head.offline_rendering_param_fpath is None:
            app.run(stop_fid=self.head.get('stop_fid', 800))
        else:
            app.run_for_offline_stitching(self.head.offline_rendering_param_fpath)

    def cal_cat_param(self):
        """Compute the body/head blending parameters from the ``cat`` config paths."""
        calc_offline_rendering_param(
            self.cat.body_gaussian_root_dir,
            self.cat.ref_head_gaussian_path,
            self.cat.ref_head_param_path,
            self.cat.render_cam_fpath,
            self.cat.body_head_blending_param_path
        )
if __name__ == '__main__':
    import sys
    # An optional first CLI argument overrides the default config path.
    config_path = sys.argv[1] if len(sys.argv) > 1 else 'configs/example.yaml'
    conf = OmegaConf.load(config_path)
    avatar = Avatar(conf)
    avatar.test_body()
    # The head pipeline can be run as well with: avatar.test_head()