Spaces:
Running
on
Zero
Running
on
Zero
| #!/usr/bin/env python | |
| # coding=utf-8 | |
| import numpy as np | |
| import open3d as o3d | |
| import os | |
| import argparse | |
| import torch | |
| import trimesh | |
| import pyrender | |
| import copy | |
| from copy import deepcopy | |
| import torch.nn.functional as F | |
| from help_func import auto_orient_and_center_poses | |
| import cv2 | |
def extract_depth_from_mesh(mesh,
                            c2w_list,
                            H, W, fx, fy, cx, cy,
                            far=20.0,):
    """Render one depth map of ``mesh`` per camera pose with pyrender.

    Adapted from Go-Surf: https://github.com/JingwenWang95/go-surf

    Args:
        mesh: mesh object accepted by ``pyrender.Mesh.from_trimesh``.
        c2w_list: iterable of 4x4 camera-to-world poses, either torch tensors
            (any device) or array-likes.
        H, W: height and width of the rendered depth maps.
        fx, fy, cx, cy: pinhole intrinsics given to ``pyrender.IntrinsicsCamera``.
        far: far clipping plane of the camera (the near plane is fixed at 0.01).

    Returns:
        list with one torch tensor per pose, holding the depth buffer returned
        by the offscreen renderer.
    """
    # Selects the EGL backend for headless, GPU-accelerated rendering.
    # NOTE(review): pyrender's docs ask for this variable to be set before
    # pyrender/PyOpenGL are imported; this file imports pyrender at the top,
    # so confirm the setting takes effect in your environment.
    os.environ['PYOPENGL_PLATFORM'] = 'egl'

    scene = pyrender.Scene()
    scene.add(pyrender.Mesh.from_trimesh(mesh))

    camera = pyrender.IntrinsicsCamera(fx=fx, fy=fy, cx=cx, cy=cy, znear=0.01, zfar=far)
    camera_node = pyrender.Node(camera=camera, matrix=np.eye(4))
    scene.add_node(camera_node)

    renderer = pyrender.OffscreenRenderer(W, H)
    flags = (pyrender.RenderFlags.OFFSCREEN
             | pyrender.RenderFlags.DEPTH_ONLY
             | pyrender.RenderFlags.SKIP_CULL_FACES)

    depths = []
    try:
        for c2w in c2w_list:
            if isinstance(c2w, torch.Tensor):
                # .cpu() makes this work for poses that live on a GPU as well.
                c2w = c2w.detach().cpu().numpy()
            # Private copy so the scene never aliases the caller's pose buffer.
            # No axis flip is applied: the poses are passed to pyrender as-is
            # (the original author notes that nerfstudio's .json poses are
            # already in the OpenGL convention pyrender expects).
            pose = np.array(c2w)
            scene.set_pose(camera_node, pose)
            depth = renderer.render(scene, flags)
            depths.append(torch.from_numpy(depth))
    finally:
        # Always release the offscreen GL context, even if rendering fails.
        renderer.delete()
    return depths
class Mesher(object):
    """Cull a mesh down to the geometry that a set of cameras actually observes.

    Calling an instance with a mesh path and a list of camera poses loads the
    mesh, renders a depth map per camera, keeps the faces whose vertices are
    visible in enough views and writes the result next to the input file.
    """

    def __init__(self, H, W, fx, fy, cx, cy, far, points_batch_size=5e5):
        """
        Args:
            H, W: (int), image height and width used for rendering/projection.
            fx, fy, cx, cy: (float), pinhole camera intrinsics.
            far: (float), far clipping plane used when rendering depth maps.
            points_batch_size: (int), maximum number of points processed in
                one batch. Used to alleviate GPU memory usage. Defaults to 5e5.
        """
        self.points_batch_size = int(points_batch_size)
        self.scale = 1.0
        # Fall back to the CPU so the tool still runs on machines without CUDA.
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.forecast_radius = 0
        self.H, self.W, self.fx, self.fy, self.cx, self.cy = H, W, fx, fy, cx, cy
        self.far = far
        self.resolution = 256
        self.level_set = 0.0
        self.remove_small_geometry_threshold = 0.2
        self.get_largest_components = True
        self.verbose = True
        # Depth tolerance of the occlusion test; should be tuned for the scene.
        self.depth_eps = 0.005
        # A point is kept only if at least this many cameras see it.
        self.min_visible_views = 20

    @staticmethod
    def _culled_path(mesh_path):
        """Return the output path: ``_cull`` inserted before the file extension."""
        root, ext = os.path.splitext(mesh_path)
        return f"{root}_cull{ext}"

    def point_masks(self,
                    input_points,
                    depth_list,
                    estimate_c2w_list):
        """
        Classify points by how many cameras see them.

        Each point is projected into every camera; a view counts when the
        point lands inside the image, has positive depth and is not behind
        the rendered depth (within ``self.depth_eps``). Pixels with
        non-positive rendered depth never occlude a point.

        Args:
            input_points: (Tensor or ndarray), [N, 3] points.
            depth_list: (list), one depth map per camera, reshaped to [H, W].
            estimate_c2w_list: (list), one 4x4 camera-to-world tensor per camera.

        Returns:
            mask: (ndarray of bool, [N]), True where the point is seen in at
                least ``self.min_visible_views`` views.
            forecast_mask: (ndarray of bool, [N]), True where the point is seen
                in at least one view of the image enlarged by
                ``self.forecast_radius`` pixels.

        Raises:
            ValueError: if the number of depth maps and poses differ.
        """
        if len(depth_list) != len(estimate_c2w_list):
            raise ValueError(
                f"got {len(depth_list)} depth maps for {len(estimate_c2w_list)} camera poses")

        H, W = self.H, self.W
        device = self.device
        r = self.forecast_radius
        eps = self.depth_eps

        if not isinstance(input_points, torch.Tensor):
            input_points = torch.from_numpy(np.asarray(input_points))
        input_points = input_points.clone().detach().float()
        if input_points.shape[0] == 0:
            return np.zeros(0, dtype=bool), np.zeros(0, dtype=bool)

        # Loop-invariant camera quantities, built once instead of per batch/view.
        K = torch.tensor(
            [[self.fx, 0.0, self.cx],
             [0.0, self.fy, self.cy],
             [0.0, 0.0, 1.0]],
            dtype=torch.float32, device=device)
        w2c_list = [torch.inverse(c2w.to(device).float()) for c2w in estimate_c2w_list]

        mask = []
        forecast_mask = []
        for pnts in torch.split(input_points, self.points_batch_size, dim=0):
            points = pnts.to(device)
            n_pts = points.shape[0]
            homo_points = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)  # [N, 4]
            valid_num = torch.zeros(n_pts, dtype=torch.int32, device=device)
            valid_forecast = torch.zeros(n_pts, dtype=torch.bool, device=device)

            for w2c, depth in zip(w2c_list, depth_list):
                # NOTE(review): this is a pinhole projection in which points in
                # front of the camera have z > 0, while extract_depth_from_mesh
                # hands the very same c2w matrices to pyrender without any axis
                # flip. Confirm both steps agree on the camera axis convention
                # for your pose source.
                cam_cord = (homo_points @ w2c.T)[:, :3]  # [N, 3]
                proj = cam_cord @ K.T                    # [N, 3]
                z = proj[:, 2] + 1e-8                    # avoid division by zero
                u = proj[:, 0] / z
                v = proj[:, 1] / z

                in_frustum = (u >= 0) & (u <= W - 1) & (v >= 0) & (v <= H - 1) & (z > 0)
                forecast_frustum = (u >= -r) & (u <= W - 1 + r) & (v >= -r) & (v <= H - 1 + r) & (z > 0)

                # Bilinearly sample the rendered depth at the projected pixel;
                # grid_sample expects coordinates normalized to [-1, 1].
                grid = torch.stack(
                    [u / (W - 1) * 2.0 - 1.0, v / (H - 1) * 2.0 - 1.0], dim=-1
                ).reshape(1, 1, -1, 2)
                depth = depth.to(device).float().reshape(1, 1, H, W)
                depth_sample = F.grid_sample(
                    depth, grid, padding_mode='border', align_corners=True
                ).reshape(-1)

                # Where a surface was rendered, the point must not lie behind it.
                is_front_face = torch.where(
                    depth_sample > 0.0,
                    z < (depth_sample + eps),
                    torch.ones_like(z).bool())

                in_frustum = in_frustum & is_front_face
                valid_num = valid_num + in_frustum.int()
                valid_forecast = valid_forecast | in_frustum | (forecast_frustum & is_front_face)

            valid = valid_num >= self.min_visible_views
            mask.append(valid.cpu().numpy())
            forecast_mask.append(valid_forecast.cpu().numpy())

        mask = np.concatenate(mask, axis=0)
        forecast_mask = np.concatenate(forecast_mask, axis=0)
        return mask, forecast_mask

    def get_connected_mesh(self, mesh, get_largest_components=False):
        """
        Filter a mesh by its connected components.

        Args:
            mesh: mesh providing ``split(only_watertight=False)`` and ``area``.
            get_largest_components: (bool), if True return only the component
                with the largest area; otherwise concatenate every component
                whose area exceeds ``self.remove_small_geometry_threshold``
                times the area of the whole mesh.

        Returns:
            the filtered mesh (the input itself when it has no components).
        """
        print("split")
        components = mesh.split(only_watertight=False)
        print("split completed")
        if len(components) == 0:
            # Nothing to choose from; argmax over an empty array would raise.
            return mesh
        if get_largest_components:
            # Builtin float: the np.float alias was removed in NumPy 1.24.
            areas = np.array([c.area for c in components], dtype=float)
            return components[areas.argmax()]
        min_area = self.remove_small_geometry_threshold * mesh.area
        new_components = [comp for comp in components if comp.area > min_area]
        return trimesh.util.concatenate(new_components)

    def cull_mesh(self,
                  mesh,
                  estimate_c2w_list):
        """
        Remove the faces of ``mesh`` that are not observed by the cameras.

        Args:
            mesh: mesh to cull (not modified; copies are returned).
            estimate_c2w_list: (list), 4x4 camera-to-world tensors. The same
                poses are used for rendering the depth maps and for
                projecting the vertices.

        Returns:
            cull_mesh: mesh restricted to faces whose three vertices are all
                marked visible by ``point_masks``.
            forecast_mesh: mesh built from the forecast mask when
                ``self.forecast_radius`` is non-zero, otherwise a copy of
                ``cull_mesh``.
        """
        step = 1
        print('Start Mesh Culling', end='')

        # cull with 3d projection
        print(f' --->> {step}(Projection)', end='')
        forward_depths = extract_depth_from_mesh(
            mesh, estimate_c2w_list, H=self.H, W=self.W,
            fx=self.fx, fy=self.fy, cx=self.cx, cy=self.cy, far=self.far
        )
        print("after forward depth")
        depth_list = forward_depths

        print("in point masks")
        vertices = mesh.vertices[:, :3]
        mask, forecast_mask = self.point_masks(
            vertices, depth_list, estimate_c2w_list
        )
        print(mask.shape, forecast_mask.shape, mask.mean())

        # Keep a face only when all of its vertices are visible.
        face_mask = mask[mesh.faces].all(axis=1)
        mesh_with_hole = deepcopy(mesh)
        mesh_with_hole.update_faces(face_mask)
        mesh_with_hole.remove_unreferenced_vertices()
        step += 1

        print("compute componet")
        # Connected-component filtering is currently disabled; the projected
        # result is used directly.
        print(f' --->> {step}(Component)', end='')
        cull_mesh = mesh_with_hole
        print("after compute componet")
        step += 1

        if abs(self.forecast_radius) > 0:
            # for forecasting
            print(f' --->> {step}(Forecast:{self.forecast_radius})', end='')
            forecast_face_mask = forecast_mask[mesh.faces].all(axis=1)
            forecast_mesh = deepcopy(mesh)
            forecast_mesh.update_faces(forecast_face_mask)
            forecast_mesh.remove_unreferenced_vertices()

            # Restrict the forecast mesh to the oriented bounding box of the
            # culled mesh.
            cull_pc = o3d.geometry.PointCloud(
                o3d.utility.Vector3dVector(np.array(cull_mesh.vertices))
            )
            bbox = cull_pc.get_oriented_bounding_box()
            indices = bbox.get_point_indices_within_bounding_box(
                o3d.utility.Vector3dVector(np.array(forecast_mesh.vertices))
            )
            bound_mask = np.zeros(len(forecast_mesh.vertices), dtype=np.bool_)
            bound_mask[indices] = True
            forecast_face_mask = bound_mask[forecast_mesh.faces].all(axis=1)
            forecast_mesh.update_faces(forecast_face_mask)
            forecast_mesh.remove_unreferenced_vertices()

            forecast_mesh = self.get_connected_mesh(forecast_mesh, self.get_largest_components)
            step += 1
        else:
            forecast_mesh = deepcopy(cull_mesh)

        print(' --->> Done!')
        return cull_mesh, forecast_mesh

    def __call__(self, mesh_path, estimate_c2w_list):
        """
        Load a mesh, cull it and save the result beside the input.

        Args:
            mesh_path: (str), path of the mesh to load. The culled mesh is
                written to the same path with ``_cull`` inserted before the
                extension, so the input file is never overwritten.
            estimate_c2w_list: (list), 4x4 camera-to-world tensors.
        """
        print(f'Loading mesh from {mesh_path}...')
        mesh = trimesh.load(mesh_path, process=True)
        mesh.merge_vertices()

        mesh_out_file = self._culled_path(mesh_path)
        cull_mesh, forecast_mesh = self.cull_mesh(
            mesh=mesh,
            estimate_c2w_list=estimate_c2w_list,
        )
        cull_mesh.export(mesh_out_file)
        if self.verbose:
            print("\nINFO: Save mesh at {}!\n".format(mesh_out_file))
        torch.cuda.empty_cache()
def read_trajectory(filename):
    """Read 4x4 matrices from a text trajectory file.

    The file is a sequence of records, each made of one metadata line
    followed by four lines holding the whitespace-separated rows of a 4x4
    matrix. The metadata line is skipped; blank lines between or after
    records are ignored.

    Args:
        filename: path of the trajectory file.

    Returns:
        list of (4, 4) float numpy arrays, in file order.

    Raises:
        ValueError: if a matrix row is missing or does not hold exactly
            four numbers.
    """
    traj = []
    with open(filename, "r") as f:
        while True:
            metastr = f.readline()
            if not metastr:
                break  # end of file
            if not metastr.strip():
                # A blank (e.g. trailing) line is not a record header.
                continue
            mat = np.zeros(shape=(4, 4))
            for i in range(4):
                values = f.readline().split()
                if len(values) != 4:
                    raise ValueError(
                        f"{filename}: expected 4 values in matrix row {i} of "
                        f"pose {len(traj)}, got {len(values)}")
                mat[i, :] = [float(v) for v in values]
            traj.append(mat)
    return traj
def get_traj(traj_path):
    """Load camera poses and return them as a list of 4x4 float tensors.

    The format is chosen from the file extension:

    * ``.npy``  - array of poses, used as stored.
    * ``.json`` - instant-ngp / sdfstudio style file with a ``frames`` list;
      the poses are re-oriented and centred with
      ``auto_orient_and_center_poses`` and their translations are scaled so
      the largest absolute coordinate becomes 1.
    * anything else - text trajectory parsed by ``read_trajectory``.

    Args:
        traj_path: path of the trajectory file.

    Returns:
        list of float32 torch tensors of shape [4, 4]; 3x4 inputs get the
        row ``[0, 0, 0, 1]`` appended.
    """
    print(f'Load trajectory from {traj_path}.')
    if traj_path.endswith('.npy'):
        poses = list(np.load(traj_path))
    elif traj_path.endswith('.json'):  # instant-ngp or sdfstudio format
        import json
        with open(traj_path, encoding='UTF-8') as f:
            meta = json.load(f)
        poses_dict = {}
        for frame in meta['frames']:
            filepath = frame['file_path']
            # The 1-based frame number is read from a fixed character range.
            # NOTE(review): presumably paths look like 'images/frame_00001.png'
            # (13-character prefix, 5 digits) - confirm for your dataset.
            new_i = int(filepath[13:18]) - 1
            poses_dict[new_i] = np.array(frame['transform_matrix'])
        # Order by frame index; sorting the keys also copes with gaps in the
        # numbering, which indexing 0..N-1 would not.
        ordered = [poses_dict[key] for key in sorted(poses_dict)]
        pose_tensor = torch.from_numpy(np.array(ordered).astype(np.float32))
        pose_tensor, _ = auto_orient_and_center_poses(pose_tensor, method='up', center_poses=True)
        max_extent = float(torch.max(torch.abs(pose_tensor[:, :3, 3])))
        if max_extent > 0:
            # Skipped when every translation is zero to avoid dividing by zero.
            pose_tensor[:, :3, 3] *= 1.0 / max_extent
        poses = list(pose_tensor.numpy())
    else:
        poses = read_trajectory(traj_path)

    traj_to_register = []
    for pose in poses:
        c2w = torch.from_numpy(np.asarray(pose)).float()
        if c2w.shape == (3, 4):
            # Promote a 3x4 pose to a homogeneous 4x4 matrix.
            c2w = torch.cat([
                c2w,
                torch.tensor([[0, 0, 0, 1]]).float()
            ], dim=0)
        traj_to_register.append(c2w)  # [4, 4]
    print(f'Trajectory loaded from {traj_path}, including {len(traj_to_register)} camera views.')
    return traj_to_register
if __name__ == "__main__":
    # Command-line entry point: load the camera trajectory, cull the mesh
    # against it and write the culled mesh next to the input file.
    print('Start culling...')
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--traj-path",
        type=str,
        required=True,
        help=
        "path to trajectory file. See `convert_to_logfile.py` to create this file.",
    )
    parser.add_argument(
        "--ply-path",
        type=str,
        required=True,
        help="path to reconstruction ply file",
    )
    # Image size and pinhole intrinsics. The defaults are the values that were
    # previously hard-coded for the TanksandTemples dataset, so running the
    # script without these flags behaves as before.
    parser.add_argument("--height", type=int, default=1080, help="image height in pixels")
    parser.add_argument("--width", type=int, default=1920, help="image width in pixels")
    parser.add_argument("--fx", type=float, default=1163.8678928442187, help="focal length along x")
    parser.add_argument("--fy", type=float, default=1172.793101201448, help="focal length along y")
    parser.add_argument("--cx", type=float, default=962.3120628412543, help="principal point x")
    parser.add_argument("--cy", type=float, default=542.0667209577691, help="principal point y")
    args = parser.parse_args()

    estimate_c2w_list = get_traj(args.traj_path)

    far = 20.0
    mesher = Mesher(args.height, args.width, args.fx, args.fy, args.cx, args.cy, far,
                    points_batch_size=5e5)
    mesher(args.ply_path, estimate_c2w_list)
    print('Done!')