Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import lightning.pytorch as pl | |
| from .dataloader import Demo_Dataset, Demo_Remesh_Dataset, Correspondence_Demo_Dataset | |
| from torch.utils.data import DataLoader | |
| from partfield.model.UNet.model import ResidualUNet3D | |
| from partfield.model.triplane import TriplaneTransformer, get_grid_coord #, sample_from_planes, Voxel2Triplane | |
| from partfield.model.model_utils import VanillaMLP | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| import os | |
| import trimesh | |
| import skimage | |
| import numpy as np | |
| import h5py | |
| import torch.distributed as dist | |
| from partfield.model.PVCNN.encoder_pc import TriPlanePC2Encoder, sample_triplane_feat | |
| import json | |
| import gc | |
| import time | |
| from plyfile import PlyData, PlyElement | |
| class Model(pl.LightningModule): | |
| def __init__(self, cfg): | |
| super().__init__() | |
| self.save_hyperparameters() | |
| self.cfg = cfg | |
| self.automatic_optimization = False | |
| self.triplane_resolution = cfg.triplane_resolution | |
| self.triplane_channels_low = cfg.triplane_channels_low | |
| self.triplane_transformer = TriplaneTransformer( | |
| input_dim=cfg.triplane_channels_low * 2, | |
| transformer_dim=1024, | |
| transformer_layers=6, | |
| transformer_heads=8, | |
| triplane_low_res=32, | |
| triplane_high_res=128, | |
| triplane_dim=cfg.triplane_channels_high, | |
| ) | |
| self.sdf_decoder = VanillaMLP(input_dim=64, | |
| output_dim=1, | |
| out_activation="tanh", | |
| n_neurons=64, #64 | |
| n_hidden_layers=6) #6 | |
| self.use_pvcnn = cfg.use_pvcnnonly | |
| self.use_2d_feat = cfg.use_2d_feat | |
| if self.use_pvcnn: | |
| self.pvcnn = TriPlanePC2Encoder( | |
| cfg.pvcnn, | |
| device="cuda", | |
| shape_min=-1, | |
| shape_length=2, | |
| use_2d_feat=self.use_2d_feat) #.cuda() | |
| self.logit_scale = nn.Parameter(torch.tensor([1.0], requires_grad=True)) | |
| self.grid_coord = get_grid_coord(256) | |
| self.mse_loss = torch.nn.MSELoss() | |
| self.l1_loss = torch.nn.L1Loss(reduction='none') | |
| if cfg.regress_2d_feat: | |
| self.feat_decoder = VanillaMLP(input_dim=64, | |
| output_dim=192, | |
| out_activation="GELU", | |
| n_neurons=64, #64 | |
| n_hidden_layers=6) #6 | |
| def predict_dataloader(self): | |
| if self.cfg.remesh_demo: | |
| dataset = Demo_Remesh_Dataset(self.cfg) | |
| elif self.cfg.correspondence_demo: | |
| dataset = Correspondence_Demo_Dataset(self.cfg) | |
| else: | |
| dataset = Demo_Dataset(self.cfg) | |
| dataloader = DataLoader(dataset, | |
| num_workers=self.cfg.dataset.val_num_workers, | |
| batch_size=self.cfg.dataset.val_batch_size, | |
| shuffle=False, | |
| pin_memory=True, | |
| drop_last=False) | |
| return dataloader | |
| def predict_step(self, batch, batch_idx): | |
| save_dir = f"{self.cfg.result_name}" | |
| os.makedirs(save_dir, exist_ok=True) | |
| uid = batch['uid'][0] | |
| view_id = 0 | |
| starttime = time.time() | |
| if uid == "car" or uid == "complex_car": | |
| # if uid == "complex_car": | |
| print("Skipping this for now.") | |
| print(uid) | |
| return | |
| ### Skip if model already processed | |
| if os.path.exists(f'{save_dir}/part_feat_{uid}_{view_id}.npy') or os.path.exists(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy'): | |
| print("Already processed "+uid) | |
| return | |
| N = batch['pc'].shape[0] | |
| assert N == 1 | |
| if self.use_2d_feat: | |
| print("ERROR. Dataloader not implemented with input 2d feat.") | |
| exit() | |
| else: | |
| pc_feat = self.pvcnn(batch['pc'], batch['pc']) | |
| planes = pc_feat | |
| planes = self.triplane_transformer(planes) | |
| sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2) | |
| if self.cfg.is_pc: | |
| tensor_vertices = batch['pc'].reshape(1, -1, 3).cuda().to(torch.float16) | |
| point_feat = sample_triplane_feat(part_planes, tensor_vertices) # N, M, C | |
| point_feat = point_feat.cpu().detach().numpy().reshape(-1, 448) | |
| np.save(f'{save_dir}/part_feat_{uid}_{view_id}.npy', point_feat) | |
| print(f"Exported part_feat_{uid}_{view_id}.npy") | |
| ########### | |
| from sklearn.decomposition import PCA | |
| data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True) | |
| pca = PCA(n_components=3) | |
| data_reduced = pca.fit_transform(data_scaled) | |
| data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min()) | |
| colors_255 = (data_reduced * 255).astype(np.uint8) | |
| points = batch['pc'].squeeze().detach().cpu().numpy() | |
| if colors_255 is None: | |
| colors_255 = np.full_like(points, 255) # Default to white color (255,255,255) | |
| else: | |
| assert colors_255.shape == points.shape, "Colors must have the same shape as points" | |
| # Convert to structured array for PLY format | |
| vertex_data = np.array( | |
| [(*point, *color) for point, color in zip(points, colors_255)], | |
| dtype=[("x", "f4"), ("y", "f4"), ("z", "f4"), ("red", "u1"), ("green", "u1"), ("blue", "u1")] | |
| ) | |
| # Create PLY element | |
| el = PlyElement.describe(vertex_data, "vertex") | |
| # Write to file | |
| filename = f'{save_dir}/feat_pca_{uid}_{view_id}.ply' | |
| PlyData([el], text=True).write(filename) | |
| print(f"Saved PLY file: {filename}") | |
| ############ | |
| else: | |
| use_cuda_version = True | |
| if use_cuda_version: | |
| def sample_points(vertices, faces, n_point_per_face): | |
| # Generate random barycentric coordinates | |
| # borrowed from Kaolin https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/mesh/trianglemesh.py#L43 | |
| n_f = faces.shape[0] | |
| u = torch.sqrt(torch.rand((n_f, n_point_per_face, 1), | |
| device=vertices.device, | |
| dtype=vertices.dtype)) | |
| v = torch.rand((n_f, n_point_per_face, 1), | |
| device=vertices.device, | |
| dtype=vertices.dtype) | |
| w0 = 1 - u | |
| w1 = u * (1 - v) | |
| w2 = u * v | |
| face_v_0 = torch.index_select(vertices, 0, faces[:, 0].reshape(-1)) | |
| face_v_1 = torch.index_select(vertices, 0, faces[:, 1].reshape(-1)) | |
| face_v_2 = torch.index_select(vertices, 0, faces[:, 2].reshape(-1)) | |
| points = w0 * face_v_0.unsqueeze(dim=1) + w1 * face_v_1.unsqueeze(dim=1) + w2 * face_v_2.unsqueeze(dim=1) | |
| return points | |
| def sample_and_mean_memory_save_version(part_planes, tensor_vertices, n_point_per_face): | |
| n_sample_each = self.cfg.n_sample_each # we iterate over this to avoid OOM | |
| n_v = tensor_vertices.shape[1] | |
| n_sample = n_v // n_sample_each + 1 | |
| all_sample = [] | |
| for i_sample in range(n_sample): | |
| sampled_feature = sample_triplane_feat(part_planes, tensor_vertices[:, i_sample * n_sample_each: i_sample * n_sample_each + n_sample_each,]) | |
| assert sampled_feature.shape[1] % n_point_per_face == 0 | |
| sampled_feature = sampled_feature.reshape(1, -1, n_point_per_face, sampled_feature.shape[-1]) | |
| sampled_feature = torch.mean(sampled_feature, axis=-2) | |
| all_sample.append(sampled_feature) | |
| return torch.cat(all_sample, dim=1) | |
| if self.cfg.vertex_feature: | |
| tensor_vertices = batch['vertices'][0].reshape(1, -1, 3).to(torch.float32) | |
| point_feat = sample_and_mean_memory_save_version(part_planes, tensor_vertices, 1) | |
| else: | |
| n_point_per_face = self.cfg.n_point_per_face | |
| tensor_vertices = sample_points(batch['vertices'][0], batch['faces'][0], n_point_per_face) | |
| tensor_vertices = tensor_vertices.reshape(1, -1, 3).to(torch.float32) | |
| point_feat = sample_and_mean_memory_save_version(part_planes, tensor_vertices, n_point_per_face) # N, M, C | |
| #### Take mean feature in the triangle | |
| print("Time elapsed for feature prediction: " + str(time.time() - starttime)) | |
| point_feat = point_feat.reshape(-1, 448).cpu().numpy() | |
| np.save(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy', point_feat) | |
| print(f"Exported part_feat_{uid}_{view_id}.npy") | |
| ########### | |
| from sklearn.decomposition import PCA | |
| data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True) | |
| pca = PCA(n_components=3) | |
| data_reduced = pca.fit_transform(data_scaled) | |
| data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min()) | |
| colors_255 = (data_reduced * 255).astype(np.uint8) | |
| V = batch['vertices'][0].cpu().numpy() | |
| F = batch['faces'][0].cpu().numpy() | |
| if self.cfg.vertex_feature: | |
| colored_mesh = trimesh.Trimesh(vertices=V, faces=F, vertex_colors=colors_255, process=False) | |
| else: | |
| colored_mesh = trimesh.Trimesh(vertices=V, faces=F, face_colors=colors_255, process=False) | |
| colored_mesh.export(f'{save_dir}/feat_pca_{uid}_{view_id}.ply') | |
| ############ | |
| torch.cuda.empty_cache() | |
| else: | |
| ### Mesh input (obj file) | |
| V = batch['vertices'][0].cpu().numpy() | |
| F = batch['faces'][0].cpu().numpy() | |
| ##### Loop through faces ##### | |
| num_samples_per_face = self.cfg.n_point_per_face | |
| all_point_feats = [] | |
| for face in F: | |
| # Get the vertices of the current face | |
| v0, v1, v2 = V[face] | |
| # Generate random barycentric coordinates | |
| u = np.random.rand(num_samples_per_face, 1) | |
| v = np.random.rand(num_samples_per_face, 1) | |
| is_prob = (u+v) >1 | |
| u[is_prob] = 1 - u[is_prob] | |
| v[is_prob] = 1 - v[is_prob] | |
| w = 1 - u - v | |
| # Calculate points in Cartesian coordinates | |
| points = u * v0 + v * v1 + w * v2 | |
| tensor_vertices = torch.from_numpy(points.copy()).reshape(1, -1, 3).cuda().to(torch.float32) | |
| point_feat = sample_triplane_feat(part_planes, tensor_vertices) # N, M, C | |
| #### Take mean feature in the triangle | |
| point_feat = torch.mean(point_feat, axis=1).cpu().detach().numpy() | |
| all_point_feats.append(point_feat) | |
| ############################## | |
| all_point_feats = np.array(all_point_feats).reshape(-1, 448) | |
| point_feat = all_point_feats | |
| np.save(f'{save_dir}/part_feat_{uid}_{view_id}.npy', point_feat) | |
| print(f"Exported part_feat_{uid}_{view_id}.npy") | |
| ########### | |
| from sklearn.decomposition import PCA | |
| data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True) | |
| pca = PCA(n_components=3) | |
| data_reduced = pca.fit_transform(data_scaled) | |
| data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min()) | |
| colors_255 = (data_reduced * 255).astype(np.uint8) | |
| colored_mesh = trimesh.Trimesh(vertices=V, faces=F, face_colors=colors_255, process=False) | |
| colored_mesh.export(f'{save_dir}/feat_pca_{uid}_{view_id}.ply') | |
| ############ | |
| print("Time elapsed: " + str(time.time()-starttime)) | |
| return |