Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| from einops import rearrange | |
| import tqdm | |
| from pytorch3d.ops.knn import knn_gather, knn_points | |
| from pytorch3d.transforms import so3_exponential_map | |
| from pytorch3d.transforms.rotation_conversions import quaternion_to_matrix, matrix_to_quaternion | |
| from simple_knn._C import distCUDA2 | |
| from GHA.lib.network.MLP import MLP | |
| from GHA.lib.network.PositionalEmbedding import get_embedder | |
| from GHA.lib.utils.general_utils import inverse_sigmoid | |
class GaussianHeadModule(nn.Module):
    """Learnable set of 3D Gaussians for a head, deformed by expression and pose.

    Per-point positions, features, log-scales, quaternion rotations and logit
    opacities are stored as parameters. Six MLPs predict colour, attribute
    offsets (scale / rotation / opacity) and position offsets, conditioned
    either on expression coefficients or on the head pose.

    Points close to the neutral 3D landmarks are driven by the expression
    branch; points far from them are driven by the pose branch. The two
    branches are linearly blended in between (see ``generate``).
    """

    def __init__(self, cfg, xyz, feature, landmarks_3d_neutral, add_mouth_points=False):
        """Create the Gaussian parameters and the conditioning MLPs.

        Args:
            cfg: config object. Attributes read here: ``num_add_mouth_points``
                (only when ``add_mouth_points`` is true), ``exp_color_mlp``,
                ``pose_color_mlp``, ``exp_attributes_mlp``,
                ``pose_attributes_mlp``, ``exp_deform_mlp``,
                ``pose_deform_mlp``, ``pos_freq``, ``exp_coeffs_dim``,
                ``dist_threshold_near``, ``dist_threshold_far``,
                ``deform_scale`` and ``attributes_scale``.
            xyz: initial point positions, ``(N, 3)``.
            feature: initial per-point features, ``(N, C)``.
            landmarks_3d_neutral: neutral-face 3D landmarks, ``(L, 3)``.
                Rows 48:66 are treated as mouth keypoints (presumably the
                68-point landmark convention — TODO confirm).
            add_mouth_points: if true and ``cfg.num_add_mouth_points > 0``,
                seed that many extra random points in a box around the mouth,
                with zero-initialised features.
        """
        super(GaussianHeadModule, self).__init__()

        if add_mouth_points and cfg.num_add_mouth_points > 0:
            num_add = cfg.num_add_mouth_points
            mouth_keypoints = landmarks_3d_neutral[48:66]
            mouth_center = torch.mean(mouth_keypoints, dim=0, keepdim=True)
            # Move the centre's z to the minimum z over the mouth keypoints.
            mouth_center[:, 2] = mouth_keypoints[:, 2].min()
            max_dist = (mouth_keypoints - mouth_center).abs().max(0)[0]
            # Uniform samples in a box spanning 1.6x the keypoints' per-axis
            # extent around the centre. Sample on the landmarks' device, then
            # move to xyz's device, so GPU inputs don't cause a device mismatch
            # in torch.cat. On CPU this is identical to sampling with defaults.
            points_add = (torch.rand([num_add, 3], device=mouth_center.device) - 0.5) * 1.6 * max_dist + mouth_center
            xyz = torch.cat([xyz, points_add.to(xyz.device)])
            # New points start with zero features on feature's device/dtype.
            feature = torch.cat([feature, feature.new_zeros([num_add, feature.shape[1]])])

        self.xyz = nn.Parameter(xyz)
        self.feature = nn.Parameter(feature)
        self.register_buffer('landmarks_3d_neutral', landmarks_3d_neutral)

        # Initial isotropic log-scale from distCUDA2: presumably a per-point
        # squared nearest-neighbour distance, as in 3D Gaussian Splatting
        # (confirm). Clamped to avoid log(0). Note the result is moved to CPU.
        dist2 = torch.clamp_min(distCUDA2(self.xyz.cuda()), 0.0000001).cpu()
        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
        self.scales = nn.Parameter(scales)

        # Identity quaternions (real part first, the convention used by
        # pytorch3d's quaternion_to_matrix in generate()).
        rots = torch.zeros((xyz.shape[0], 4), device=xyz.device)
        rots[:, 0] = 1
        self.rotation = nn.Parameter(rots)

        # Stored as a logit; generate() applies sigmoid, so initial opacity is 0.3.
        self.opacity = nn.Parameter(inverse_sigmoid(0.3 * torch.ones((xyz.shape[0], 1))))

        # Colour / attribute heads have no output activation; the deformation
        # heads are bounded to [-1, 1] by Tanh and later scaled by deform_scale.
        self.exp_color_mlp = MLP(cfg.exp_color_mlp, last_op=None)
        self.pose_color_mlp = MLP(cfg.pose_color_mlp, last_op=None)
        self.exp_attributes_mlp = MLP(cfg.exp_attributes_mlp, last_op=None)
        self.pose_attributes_mlp = MLP(cfg.pose_attributes_mlp, last_op=None)
        self.exp_deform_mlp = MLP(cfg.exp_deform_mlp, last_op=nn.Tanh())
        self.pose_deform_mlp = MLP(cfg.pose_deform_mlp, last_op=nn.Tanh())

        self.pos_embedding, _ = get_embedder(cfg.pos_freq)
        self.exp_coeffs_dim = cfg.exp_coeffs_dim
        self.dist_threshold_near = cfg.dist_threshold_near
        self.dist_threshold_far = cfg.dist_threshold_far
        self.deform_scale = cfg.deform_scale
        self.attributes_scale = cfg.attributes_scale

    def generate(self, data):
        """Produce deformed Gaussians for a batch and write them into ``data``.

        Args:
            data: dict read for these keys:
                ``'exp_coeff'`` — ``(B, E)`` expression coefficients.
                ``'pose'`` — ``(B, P)``. ``[:, :3]`` is fed to
                    ``so3_exponential_map`` and ``[:, 3:]`` is added as a
                    translation (presumably ``P == 6``: axis-angle then
                    translation — confirm).
                ``'scale'`` — indexed ``[:, :, None]`` and multiplied into
                    positions and scales (presumably ``(B, 1)`` — confirm).

        Returns:
            The same ``data`` dict, mutated in place with these keys:
            ``'xyz'`` ``(B, N, 3)``, ``'color'``, ``'scales'``, ``'rotation'``
            (unit quaternions), ``'opacity'`` (in ``(0, 1)``) and
            ``'exp_deform'``. ``'exp_deform'`` holds only the LAST batch
            sample's raw expression deformation for its expression-controlled
            points, not the whole batch.
        """
        B = data['exp_coeff'].shape[0]

        xyz = self.xyz.unsqueeze(0).repeat(B, 1, 1)
        feature = torch.tanh(self.feature).unsqueeze(0).repeat(B, 1, 1)

        # pytorch3d's knn_points returns SQUARED distances to the nearest
        # landmark (K=1), so both thresholds are in squared units.
        dists, _, _ = knn_points(xyz, self.landmarks_3d_neutral.unsqueeze(0).repeat(B, 1, 1))
        # Weights for the two branches, assuming far > near:
        #   dist <= near        -> exp weight 1
        #   dist >= far         -> exp weight 0
        #   in between          -> linear blend; pose weight is the complement.
        exp_weights = torch.clamp((self.dist_threshold_far - dists) / (self.dist_threshold_far - self.dist_threshold_near), 0.0, 1.0)
        pose_weights = 1 - exp_weights
        # Only points with a non-zero weight are fed through each branch.
        exp_controlled = (dists < self.dist_threshold_far).squeeze(-1)
        pose_controlled = (dists > self.dist_threshold_near).squeeze(-1)

        color = torch.zeros([B, xyz.shape[1], self.exp_color_mlp.dims[-1]], device=xyz.device)
        delta_xyz = torch.zeros_like(xyz)
        # Attribute offsets packed as [scales | rotation quaternion | opacity].
        delta_attributes = torch.zeros([B, xyz.shape[1], self.scales.shape[1] + self.rotation.shape[1] + self.opacity.shape[1]], device=xyz.device)

        for b in range(B):
            # Per-sample invariants, computed once instead of per branch.
            exp_mask = exp_controlled[b]
            pose_mask = pose_controlled[b]
            exp_w = exp_weights[b, exp_mask, :]
            pose_w = pose_weights[b, pose_mask, :]
            exp_code = data['exp_coeff'][b].unsqueeze(-1)
            pose_code = self.pos_embedding(data['pose'][b]).unsqueeze(-1)

            # MLP inputs are built channels-first as (1, C, n_points), with the
            # conditioning code broadcast to every selected point (presumably
            # the MLP uses 1x1 convolutions — confirm in GHA.lib.network.MLP).
            # Outputs are transposed back to (n_points, C_out).

            # Colour from features + expression / pose conditioning.
            feature_exp_controlled = feature[b, exp_mask, :]
            exp_color_input = torch.cat([feature_exp_controlled.t(),
                                         exp_code.repeat(1, feature_exp_controlled.shape[0])], 0)[None]
            exp_color = self.exp_color_mlp(exp_color_input)[0].t()
            color[b, exp_mask, :] += exp_color * exp_w

            feature_pose_controlled = feature[b, pose_mask, :]
            pose_color_input = torch.cat([feature_pose_controlled.t(),
                                          pose_code.repeat(1, feature_pose_controlled.shape[0])], 0)[None]
            pose_color = self.pose_color_mlp(pose_color_input)[0].t()
            color[b, pose_mask, :] += pose_color * pose_w

            # Attribute offsets reuse the colour-branch inputs.
            exp_delta_attributes = self.exp_attributes_mlp(exp_color_input)[0].t()
            delta_attributes[b, exp_mask, :] += exp_delta_attributes * exp_w

            pose_delta_attributes = self.pose_attributes_mlp(pose_color_input)[0].t()
            delta_attributes[b, pose_mask, :] += pose_delta_attributes * pose_w

            # Position offsets from positionally-embedded xyz + conditioning.
            xyz_exp_controlled = xyz[b, exp_mask, :]
            exp_deform_input = torch.cat([self.pos_embedding(xyz_exp_controlled).t(),
                                          exp_code.repeat(1, xyz_exp_controlled.shape[0])], 0)[None]
            exp_deform = self.exp_deform_mlp(exp_deform_input)[0].t()
            delta_xyz[b, exp_mask, :] += exp_deform * exp_w

            xyz_pose_controlled = xyz[b, pose_mask, :]
            pose_deform_input = torch.cat([self.pos_embedding(xyz_pose_controlled).t(),
                                           pose_code.repeat(1, xyz_pose_controlled.shape[0])], 0)[None]
            pose_deform = self.pose_deform_mlp(pose_deform_input)[0].t()
            delta_xyz[b, pose_mask, :] += pose_deform * pose_w

        xyz = xyz + delta_xyz * self.deform_scale

        # Unpack the attribute offsets and apply them in parameter space.
        delta_scales = delta_attributes[:, :, 0:3]
        scales = self.scales.unsqueeze(0).repeat(B, 1, 1) + delta_scales * self.attributes_scale
        scales = torch.exp(scales)

        delta_rotation = delta_attributes[:, :, 3:7]
        rotation = self.rotation.unsqueeze(0).repeat(B, 1, 1) + delta_rotation * self.attributes_scale
        rotation = torch.nn.functional.normalize(rotation, dim=2)

        delta_opacity = delta_attributes[:, :, 7:8]
        opacity = self.opacity.unsqueeze(0).repeat(B, 1, 1) + delta_opacity * self.attributes_scale
        opacity = torch.sigmoid(opacity)

        # NOTE(review): data['pose'] is already read unconditionally in the
        # loop above, so this guard is effectively always true when reached.
        if 'pose' in data:
            R = so3_exponential_map(data['pose'][:, :3])
            T = data['pose'][:, None, 3:]
            S = data['scale'][:, :, None]
            # Row-vector form of x' = R (S * x) + T.
            xyz = torch.bmm(xyz * S, R.permute(0, 2, 1)) + T

            # Left-compose the global rotation onto every Gaussian's orientation.
            rotation_matrix = quaternion_to_matrix(rotation)
            rotation_matrix = rearrange(rotation_matrix, 'b n x y -> (b n) x y')
            R = rearrange(R.unsqueeze(1).repeat(1, rotation.shape[1], 1, 1), 'b n x y -> (b n) x y')
            rotation_matrix = rearrange(torch.bmm(R, rotation_matrix), '(b n) x y -> b n x y', b=B)
            rotation = matrix_to_quaternion(rotation_matrix)

            scales = scales * S

        # Only the last loop iteration's value (kept for backward compatibility).
        data['exp_deform'] = exp_deform
        data['xyz'] = xyz
        data['color'] = color
        data['scales'] = scales
        data['rotation'] = rotation
        data['opacity'] = opacity
        return data