Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import argparse | |
| from config.config import config_reenactment | |
| from lib.dataset.Dataset import ReenactmentDataset | |
| from lib.dataset.DataLoaderX import DataLoaderX | |
| from lib.module.GaussianHeadModule import GaussianHeadModule | |
| from lib.module.SuperResolutionModule import SuperResolutionModule | |
| from lib.module.CameraModule import CameraModule | |
| from lib.recorder.Recorder import ReenactmentRecorder | |
| from lib.apps.Reenactment import Reenactment | |
def build_arg_parser():
    """Build the command-line parser for the reenactment script.

    Options:
        --config: path of the config file passed to ``config_reenactment.load``.
        --offline_rendering_param_fpath: when given, rendering goes through
            ``Reenactment.run_for_offline_stitching`` with this path instead of
            ``Reenactment.run``.
        --stop_fid: forwarded to ``Reenactment.run(stop_fid=...)``; presumably
            the frame index at which rendering stops (confirm in
            lib.apps.Reenactment). Defaults to 800, the previously hard-coded value.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/reenactment_N031.yaml')
    parser.add_argument('--offline_rendering_param_fpath', type=str, default=None)
    # Was hard-coded as app.run(stop_fid=800); the default keeps that behaviour.
    parser.add_argument('--stop_fid', type=int, default=800)
    return parser


def resolve_device(gpu_id):
    """Return ``torch.device('cuda:<gpu_id>')``, failing fast if that GPU is unusable.

    Without this check, a missing or out-of-range GPU only surfaces later, as an
    opaque error from ``.to(device)`` after the dataset has already been built.

    Args:
        gpu_id: integer CUDA device index (``cfg.gpu_id``).

    Returns:
        torch.device: the CUDA device ``cuda:<gpu_id>``.

    Raises:
        RuntimeError: if CUDA is unavailable or ``gpu_id`` is not a visible device index.
    """
    if not torch.cuda.is_available():
        raise RuntimeError('CUDA is not available, but reenactment needs GPU %d (cfg.gpu_id).' % gpu_id)
    device_count = torch.cuda.device_count()
    if not 0 <= gpu_id < device_count:
        raise RuntimeError('cfg.gpu_id=%d is out of range: %d CUDA device(s) visible.' % (gpu_id, device_count))
    return torch.device('cuda:%d' % gpu_id)


def main(argv=None):
    """Build the reenactment pipeline from the config file and run it.

    The function performs these steps:
    - load the config;
    - build the ``ReenactmentDataset`` and its batch-size-1 ``DataLoaderX``;
    - restore ``GaussianHeadModule`` and ``SuperResolutionModule`` from their
      checkpoints onto ``cuda:<cfg.gpu_id>``;
    - hand everything to ``Reenactment``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]`` (argparse default).

    Raises:
        RuntimeError: if the configured GPU is not usable (see ``resolve_device``).
    """
    arg = build_arg_parser().parse_args(argv)

    cfg = config_reenactment()
    cfg.load(arg.config)
    cfg = cfg.get_cfg()

    # Validate the GPU before the potentially slow dataset and checkpoint loading.
    device = resolve_device(cfg.gpu_id)

    dataset = ReenactmentDataset(cfg.dataset)
    dataloader = DataLoaderX(dataset, batch_size=1, shuffle=False, pin_memory=True)

    # Checkpoints are deserialised onto the CPU first. The modules are moved to
    # `device` before load_state_dict copies the weights into them.
    gaussianhead_state_dict = torch.load(cfg.load_gaussianhead_checkpoint, map_location='cpu')
    gaussianhead = GaussianHeadModule(cfg.gaussianheadmodule,
                                      xyz=gaussianhead_state_dict['xyz'],
                                      feature=gaussianhead_state_dict['feature'],
                                      landmarks_3d_neutral=gaussianhead_state_dict['landmarks_3d_neutral']).to(device)
    gaussianhead.load_state_dict(gaussianhead_state_dict)
    # Drop our reference to the CPU-side checkpoint; it is not used again.
    del gaussianhead_state_dict

    supres = SuperResolutionModule(cfg.supresmodule).to(device)
    supres.load_state_dict(torch.load(cfg.load_supres_checkpoint, map_location='cpu'))

    camera = CameraModule()
    recorder = ReenactmentRecorder(cfg.recorder)

    app = Reenactment(dataloader, gaussianhead, supres, camera, recorder, cfg.gpu_id, dataset.freeview)
    if arg.offline_rendering_param_fpath is None:
        app.run(stop_fid=arg.stop_fid)
    else:
        app.run_for_offline_stitching(arg.offline_rendering_param_fpath)


if __name__ == '__main__':
    main()