Spaces:
Running
on
Zero
Running
on
Zero
# Adapted from https://github.com/autonomousvision/gaussian-opacity-fields/blob/main/extract_mesh.py
| import torch | |
| from scene import Scene | |
| import os | |
| from os import makedirs | |
| from gaussian_renderer import render, integrate | |
| import random | |
| from tqdm import tqdm | |
| from argparse import ArgumentParser | |
| from arguments import ModelParams, PipelineParams, get_combined_args | |
| from gaussian_renderer import GaussianModel | |
| import numpy as np | |
| import trimesh | |
| from tetranerf.utils.extension import cpp | |
| from utils.tetmesh import marching_tetrahedra | |
def evaluage_alpha(points, views, gaussians, pipeline, background, kernel_size):
    """Return one minus the smallest integrated alpha seen for each point across all views.

    Args:
        points: query points handed to ``integrate``; ``points.shape[0]`` fixes
            the length of the result.
        views: iterable of cameras; ``integrate`` is called once per camera.
        gaussians, pipeline, background, kernel_size: forwarded unchanged to
            ``integrate`` (``kernel_size`` as a keyword argument).

    Returns:
        float32 CUDA tensor of length ``points.shape[0]`` holding
        ``1 - min(1, min over views of ret["alpha_integrated"])``.
    """
    # Running per-point minimum of the integrated alpha; seeded with 1.
    lowest_alpha = torch.ones((points.shape[0]), dtype=torch.float32, device="cuda")
    with torch.no_grad():
        for camera in tqdm(views, desc="Rendering progress"):
            integrated = integrate(points, camera, gaussians, pipeline, background, kernel_size=kernel_size)
            lowest_alpha = torch.minimum(lowest_alpha, integrated["alpha_integrated"])
        result = 1 - lowest_alpha
    return result
def evaluage_cull_alpha(points, views, masks, gaussians, pipeline, background, kernel_size):
    """Return a per-point signed value for marching tetrahedra, culled by image masks.

    For every view the points are integrated. A point counts as observed in that
    view when the combined mask (rendered mask x optional ``view.gt_mask`` x
    optional ``masks[cam_id]``), sampled with ``grid_sample`` at the point's
    projected location, exceeds 0.5. Per point, the minimum ``alpha_integrated``
    over the views that observe it is kept.

    Args:
        points: query points handed to ``integrate``; ``points.shape[0]`` fixes
            the length of the result.
        views: sequence of cameras; each must expose ``image_width``,
            ``image_height`` and ``gt_mask`` (``gt_mask`` may be ``None``).
        masks: ``None`` or an indexable of extra per-view masks, indexed by the
            camera's position in ``views``.
        gaussians, pipeline, background, kernel_size: forwarded to ``integrate``.

    Returns:
        Tensor of length ``points.shape[0]``: ``0.5 - min_alpha`` for points
        observed in at least one view, ``-100`` for points never observed.
    """
    # Running minimum of integrated alpha over the views that observe each point.
    final_sdf = torch.ones((points.shape[0]), dtype=torch.float32, device="cuda")
    # Number of views in which each point passed the mask test.
    weight = torch.zeros((points.shape[0]), dtype=torch.int32, device="cuda")
    with torch.no_grad():
        for cam_id, view in enumerate(tqdm(views, desc="Rendering progress")):
            # Release cached allocator blocks between views to limit peak memory.
            torch.cuda.empty_cache()
            ret = integrate(points, view, gaussians, pipeline, background, kernel_size)
            alpha_integrated = ret["alpha_integrated"]
            # Map point_coordinate (presumably projected pixel coordinates) into
            # grid_sample's normalized [-1, 1] range.
            # NOTE(review): with align_corners=False the usual inverse of the
            # pixel-center convention is (2p + 1) / size - 1; dividing by
            # (size - 1) pushes samples slightly outward (the last row/column can
            # land outside [-1, 1] and read the zero padding). Kept as-is to
            # preserve current output -- confirm against the convention that
            # ``integrate`` uses for "point_coordinate".
            point_coordinate = ret["point_coordinate"]
            point_coordinate[:, 0] = (point_coordinate[:, 0] * 2 + 1) / (view.image_width - 1) - 1
            point_coordinate[:, 1] = (point_coordinate[:, 1] * 2 + 1) / (view.image_height - 1) - 1
            # Channel 7 of the render output serves as the rendered mask
            # (presumably an opacity/alpha channel -- TODO confirm).
            mask = ret["render"][7][None]
            if view.gt_mask is not None:
                mask = mask * view.gt_mask
            if masks is not None:
                mask = mask * masks[cam_id]
            valid_point_prob = torch.nn.functional.grid_sample(
                mask.type(torch.float32)[None],
                point_coordinate[None, None],
                padding_mode="zeros",
                align_corners=False,
            )[0, 0, 0]
            valid_point = valid_point_prob > 0.5
            final_sdf = torch.where(valid_point, torch.min(alpha_integrated, final_sdf), final_sdf)
            weight = torch.where(valid_point, weight + 1, weight)
        # Points never observed by any view get a large negative value.
        final_sdf = torch.where(weight > 0, 0.5 - final_sdf, -100)
    return final_sdf
def marching_tetrahedra_with_binary_search(
    model_path,
    name,
    iteration,
    views,
    gaussians: GaussianModel,
    pipeline,
    background,
    kernel_size,
    *,
    n_binary_steps: int = 8,
):
    """Extract a mesh from the Gaussians and export it to ``<model_path>/recon.ply``.

    Steps:
        1. Triangulate ``gaussians.get_tetra_points()`` into tetrahedra.
        2. Evaluate ``evaluage_cull_alpha`` at the tetra vertices.
        3. Run ``marching_tetrahedra`` on CPU.
        4. Refine each output vertex by bisecting its edge ``n_binary_steps``
           times.
        5. Drop vertices whose original edge is longer than the sum of its two
           endpoint scales, together with every face touching them.

    Args:
        model_path: directory the mesh file is written to.
        name, iteration: accepted for interface compatibility; not used here.
        views: cameras used to evaluate the signed value.
        gaussians, pipeline, background, kernel_size: forwarded to
            ``evaluage_cull_alpha``.
        n_binary_steps: number of bisection refinements per edge (default 8, the
            previously hard-coded value). With 0, vertices sit at edge midpoints.
    """
    # Tetrahedralize the points derived from the Gaussians.
    points, points_scale = gaussians.get_tetra_points()
    cells = cpp.triangulate(points)
    # No extra per-view masks; culling relies on rendered and GT masks only.
    mask = None
    sdf = evaluage_cull_alpha(points, views, mask, gaussians, pipeline, background, kernel_size)
    torch.cuda.empty_cache()
    # marching_tetrahedra needs a lot of memory, so it runs on CPU.
    verts_list, scale_list, faces_list, _ = marching_tetrahedra(
        points.cpu()[None], cells.cpu().long(), sdf[None].cpu(), points_scale[None].cpu()
    )
    # Free the dense tetra inputs before the GPU-side refinement.
    del points, points_scale, cells

    # Each output vertex lies on an edge; both endpoints (and their signed values
    # and scales) are stacked along dim 1.
    end_points, end_sdf = verts_list[0]
    end_scales = scale_list[0]
    end_points, end_sdf, end_scales = end_points.cuda(), end_sdf.cuda(), end_scales.cuda()
    faces = faces_list[0].cpu().numpy()

    # Basic slices: views into end_points / end_sdf, updated in place below.
    left_points = end_points[:, 0, :]
    right_points = end_points[:, 1, :]
    left_sdf = end_sdf[:, 0, :]
    right_sdf = end_sdf[:, 1, :]
    left_scale = end_scales[:, 0, 0]
    right_scale = end_scales[:, 1, 0]
    # Edge length and scale budget are measured before bisection moves the ends.
    distance = torch.norm(left_points - right_points, dim=-1)
    scale = left_scale + right_scale

    for step in range(n_binary_steps):
        print(f"binary search in step {step}")
        mid_points = (left_points + right_points) / 2
        mid_sdf = evaluage_cull_alpha(mid_points, views, mask, gaussians, pipeline, background, kernel_size)
        mid_sdf = mid_sdf.unsqueeze(-1)
        # Midpoint with the same (strict) sign as the left end replaces the left
        # end; otherwise (including mid_sdf == 0) it replaces the right end.
        ind_low = ((mid_sdf < 0) & (left_sdf < 0)) | ((mid_sdf > 0) & (left_sdf > 0))
        low_rows = ind_low.flatten()
        left_sdf[ind_low] = mid_sdf[ind_low]
        right_sdf[~ind_low] = mid_sdf[~ind_low]
        left_points[low_rows] = mid_points[low_rows]
        right_points[~low_rows] = mid_points[~low_rows]
    # Final vertex positions: midpoints of the (possibly refined) edges.
    points = (left_points + right_points) / 2

    mesh = trimesh.Trimesh(vertices=points.cpu().numpy(), faces=faces, process=False)
    # Keep vertices whose edge is no longer than the summed endpoint scales.
    # A face survives only if all its vertices do. The face mask is computed from
    # the original face indices, before update_vertices changes the mesh.
    vertex_mask = (distance <= scale).cpu().numpy()
    face_mask = vertex_mask[faces].all(axis=1)
    mesh.update_vertices(vertex_mask)
    mesh.update_faces(face_mask)
    mesh.export(os.path.join(model_path, "recon.ply"))
def extract_mesh(dataset: ModelParams, iteration: int, pipeline: PipelineParams):
    """Load the Gaussians saved at ``iteration`` and export a mesh from the training cameras.

    Args:
        dataset: model parameters; ``sh_degree``, ``model_path``,
            ``white_background`` and ``kernel_size`` are read.
        iteration: checkpoint iteration whose point cloud is loaded.
        pipeline: pipeline parameters forwarded to the mesh extraction.
    """
    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
        # Explicitly load the point cloud saved for this iteration.
        # NOTE(review): Scene(load_iteration=...) may already load it, and this
        # path assumes ``iteration`` is a concrete number -- confirm.
        ply_path = os.path.join(dataset.model_path, "point_cloud", f"iteration_{iteration}", "point_cloud.ply")
        gaussians.load_ply(ply_path)
        if dataset.white_background:
            bg_color = [1, 1, 1]
        else:
            bg_color = [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
        kernel_size = dataset.kernel_size
        train_cameras = scene.getTrainCameras()
        marching_tetrahedra_with_binary_search(
            dataset.model_path,
            "test",
            iteration,
            train_cameras,
            gaussians,
            pipeline,
            background,
            kernel_size,
        )
if __name__ == "__main__":
    # Command-line interface: model and pipeline parameter groups plus the
    # checkpoint iteration to mesh.
    cli = ArgumentParser(description="Testing script parameters")
    model_params = ModelParams(cli, sentinel=True)
    pipeline_params = PipelineParams(cli)
    cli.add_argument("--iteration", default=30000, type=int)
    cli.add_argument("--quiet", action="store_true")  # parsed but not used below
    args = get_combined_args(cli)
    print("Rendering " + args.model_path)

    # Seed the Python, NumPy and PyTorch RNGs (same order as before), then pin
    # the CUDA device.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(0)
    torch.cuda.set_device(torch.device("cuda:0"))

    extract_mesh(model_params.extract(args), args.iteration, pipeline_params.extract(args))