# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from third_partys.Video_Depth_Anything.video_depth_anything.video_depth import VideoDepthAnything import torch import torch.nn as nn import numpy as np import igl import cv2 import time import torch.nn.functional as F from utils.quat_utils import quat_inverse, quat_log, quat_multiply, normalize_quaternion from pytorch3d.structures import join_meshes_as_scene, join_meshes_as_batch import os from pathlib import Path class DepthModule: def __init__(self, encoder='vitl', device='cuda', input_size=518, fp32=False): """ Initialize the depth loss module with Video Depth Anything Args: encoder: 'vitl' or 'vits' device: device to run the model on input_size: input size for the model fp32: whether to use float32 for inference """ self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.input_size = input_size self.fp32 = fp32 # Initialize model configuration model_configs = { 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, } # Load Video Depth Anything model self.video_depth_model = VideoDepthAnything(**model_configs[encoder]) self.video_depth_model.load_state_dict( torch.load(f'./third_partys/Video_Depth_Anything/ckpt/video_depth_anything_{encoder}.pth', map_location='cpu'), strict=True ) self.video_depth_model = self.video_depth_model.to(self.device).eval() for param in self.video_depth_model.parameters(): param.requires_grad = False def get_depth_maps(self, frames, target_fps=30): """ Get depth maps for video frames """ depths, _ = self.video_depth_model.infer_video_depth( frames, target_fps, input_size=self.input_size, device=self.device, fp32=self.fp32 ) return depths def save_depth_as_images(depth_np, output_dir='./depth_images'): """save depth images""" os.makedirs(output_dir, exist_ok=True) for i, depth_map in enumerate(depth_np): depth_map = depth_map.detach().cpu().numpy() valid_mask = (depth_map > 0) if not valid_mask.any(): continue valid_min = depth_map[valid_mask].min() valid_max = depth_map[valid_mask].max() normalized = np.zeros_like(depth_map) normalized[valid_mask] = 255.0 * (depth_map[valid_mask] - valid_min) / (valid_max - valid_min) depth_img = normalized.astype(np.uint8) cv2.imwrite(os.path.join(output_dir, f'depth_{i:04d}.png'), depth_img) print(f"Save {len(depth_np)} depth images to {output_dir}") def compute_visibility_mask_igl(ray_origins, ray_dirs, distances, verts, faces, distance_tolerance=1e-6, for_vertices=False): """ Compute visibility mask using IGL ray-mesh intersection. """ num_rays = ray_origins.shape[0] visibility_mask = np.ones(num_rays, dtype=bool) for i in range(num_rays): ray_origin = ray_origins[i].reshape(1, 3) ray_dir = ray_dirs[i].reshape(1, 3) intersections = igl.ray_mesh_intersect(ray_origin, ray_dir, verts, faces) if intersections: # Count intersections that occur before the target point count = sum(1 for h in intersections if h[4] < distances[i] - distance_tolerance) # count=0 → ray completely missed the mesh; count=1 → ray stops exactly at the face containing the joint # count>1 → ray was blocked by other faces along the way if for_vertices: if count != 1: visibility_mask[i] = False else: # for joints if count > 2: visibility_mask[i] = False return visibility_mask def compute_reprojection_loss(renderer, vis_mask, predicted_joints, tracked_joints_2d, image_size): """ Compute reprojection loss between predicted 3D points and tracked 2D points. """ if predicted_joints.dim() != 3: raise ValueError(f"predicted_joints must be 3D tensor, got shape {predicted_joints.shape}") B, J, _ = predicted_joints.shape device = predicted_joints.device # Project 3D joints to 2D screen coordinates projected = renderer.camera.transform_points_screen( predicted_joints, image_size=[image_size, image_size] ) projected_2d = projected[..., :2] # (B, J, 2) # Convert and validate tracked joints if not isinstance(tracked_joints_2d, torch.Tensor): tracked_joints_2d = torch.from_numpy(tracked_joints_2d).float() tracked_joints_2d = tracked_joints_2d.to(device) if tracked_joints_2d.dim() == 2: tracked_joints_2d = tracked_joints_2d.unsqueeze(0).expand(B, -1, -1) vis_mask = vis_mask.to(device).float() num_visible = vis_mask.sum() if num_visible == 0: # No visible joints - return zero loss return torch.tensor(0.0, device=device, requires_grad=True) squared_diff = (projected_2d - tracked_joints_2d).pow(2).sum(dim=-1) # (B, J) vis_mask_expanded = vis_mask.unsqueeze(0) # (1, J) masked_loss = squared_diff * vis_mask_expanded # (B, J) per_frame_loss = masked_loss.sum(dim=1) / num_visible # (B,) final_loss = per_frame_loss.mean() # scalar return final_loss def geodesic_loss(q1, q2, eps=1e-6): """ Compute geodesic distance loss between batches of quaternions for rot smooth loss. """ q1_norm = normalize_quaternion(q1, eps=eps) q2_norm = normalize_quaternion(q2, eps=eps) dot_product = (q1_norm * q2_norm).sum(dim=-1, keepdim=True) q2_corrected = torch.where(dot_product < 0, -q2_norm, q2_norm) inner_product = (q1_norm * q2_corrected).sum(dim=-1) # Clamp to valid range for arccos to avoid numerical issues inner_product_clamped = torch.clamp(inner_product, min=-1.0 + eps, max=1.0 - eps) theta = 2.0 * torch.arccos(torch.abs(inner_product_clamped)) return theta def root_motion_reg(root_quats, root_pos): return ((root_pos[1:] - root_pos[:-1])**2).mean(), (geodesic_loss(root_quats[1:], root_quats[:-1])**2).mean() def joint_motion_coherence(quats_normed, parent_idx): """ Compute joint motion coherence loss to enforce smooth relative motion between parent-child joints. """ coherence_loss = 0 for j, parent in enumerate(parent_idx): if parent != -1: # Skip root joint parent_rot = quats_normed[:, parent] # (T, 4) child_rot = quats_normed[:, j] # (T, 4) # Compute relative rotation of child w.r.t. parent's local frame # local_rot = parent_rot^(-1) * child_rot local_rot = quat_multiply(quat_inverse(parent_rot), child_rot) local_rot_velocity = local_rot[1:] - local_rot[:-1] # (T-1, 4) coherence_loss += local_rot_velocity.pow(2).mean() return coherence_loss def read_flo_file(file_path): """ Read optical flow from .flo format file. """ with open(file_path, 'rb') as f: magic = np.fromfile(f, np.float32, count=1) if len(magic) == 0 or magic[0] != 202021.25: raise ValueError(f'Invalid .flo file format: magic number {magic}') w = np.fromfile(f, np.int32, count=1)[0] h = np.fromfile(f, np.int32, count=1)[0] data = np.fromfile(f, np.float32, count=2*w*h) flow = data.reshape(h, w, 2) return flow def load_optical_flows(flow_dir, num_frames): """ Load sequence of optical flow files. """ flow_dir = Path(flow_dir) flows = [] for i in range(num_frames - 1): flow_path = flow_dir / f'flow_{i:04d}.flo' if flow_path.exists(): flow = read_flo_file(flow_path) flows.append(flow) else: raise ValueError("No flow files found") return np.stack(flows, axis=0) def rasterize_vertex_flow(flow_vertices, meshes, faces, image_size, renderer, eps = 1e-8): """ Rasterize per-vertex flow to dense flow field using barycentric interpolation. """ B, V, _ = flow_vertices.shape device = flow_vertices.device if isinstance(image_size, int): H = W = image_size else: H, W = image_size batch_meshes = join_meshes_as_batch([join_meshes_as_scene(m) for m in meshes]).to(device) fragments = renderer.renderer.rasterizer(batch_meshes) pix_to_face = fragments.pix_to_face # (B, H, W, K) bary_coords = fragments.bary_coords # (B, H, W, K, 3) flow_scene_list = [] for mesh_idx in range(B): mesh = meshes[mesh_idx] V_mesh = mesh.verts_packed().shape[0] if V_mesh > flow_vertices.shape[1]: raise ValueError(f"Mesh {mesh_idx} has {V_mesh} vertices but flow has {flow_vertices.shape[1]}") flow_scene_list.append(flow_vertices[mesh_idx, :V_mesh]) flow_vertices_scene = torch.cat(flow_scene_list, dim=0).to(device) faces_scene = batch_meshes.faces_packed() flow_pred = torch.zeros(B, H, W, 2, device=device) valid = pix_to_face[..., 0] >= 0 for b in range(B): b_valid = valid[b] # (H,W) if torch.count_nonzero(b_valid) == 0: print(f"No valid pixels found for batch {b}") continue valid_indices = torch.nonzero(b_valid, as_tuple=True) h_indices, w_indices = valid_indices face_idxs = pix_to_face[b, h_indices, w_indices, 0] # (N,) bary = bary_coords[b, h_indices, w_indices, 0] # (N,3) max_face_idx = faces_scene.shape[0] - 1 if face_idxs.max() > max_face_idx: raise RuntimeError(f"Face index {face_idxs.max()} exceeds max {max_face_idx}") face_verts = faces_scene[face_idxs] # (N, 3) f0, f1, f2 = face_verts.unbind(-1) # Each (N,) max_vert_idx = flow_vertices_scene.shape[0] - 1 if max(f0.max(), f1.max(), f2.max()) > max_vert_idx: raise RuntimeError(f"Vertex index exceeds flow_vertices_scene size {max_vert_idx}") v0_flow = flow_vertices_scene[f0] # (N, 2) v1_flow = flow_vertices_scene[f1] # (N, 2) v2_flow = flow_vertices_scene[f2] # (N, 2) # Interpolate using barycentric coordinates b0, b1, b2 = bary.unbind(-1) # Each (N,) # Ensure barycentric coordinates sum to 1 (numerical stability) bary_sum = b0 + b1 + b2 b0 = b0 / (bary_sum + eps) b1 = b1 / (bary_sum + eps) b2 = b2 / (bary_sum + eps) flow_interpolated = ( b0.unsqueeze(-1) * v0_flow + b1.unsqueeze(-1) * v1_flow + b2.unsqueeze(-1) * v2_flow ) # (N, 2) # Update flow prediction flow_pred[b, h_indices, w_indices] = flow_interpolated return flow_pred def calculate_flow_loss(flow_dir, device, mask, renderer, model): """ Calculate optical flow loss with improved error handling and flexibility. """ if device is None: device = mask.device T = mask.shape[0] H, W = mask.shape[1:3] if mask.shape[0] == T: flow_mask = mask[1:] # Use frames 1 to T-1 else: flow_mask = mask flows_np = load_optical_flows(flow_dir, T) flow_gt = torch.from_numpy(flows_np).float().to(device) # [T-1, H, W, 2] vertices = model.deformed_vertices[0] # (T,V,3) # Project vertices to get 2D flow proj_t = renderer.project_points(vertices[:-1]) # (T-1,V,2) in pixels proj_tp = renderer.project_points(vertices[1:]) vertex_flow = proj_tp - proj_t # (T-1,V,2) Δx,Δy meshes = [model.get_mesh(t) for t in range(T)] flow_pred = rasterize_vertex_flow(vertex_flow, meshes, model.faces[0], (H,W), renderer) # (B,H,W,2) eps = 1e-3 diff = (flow_pred - flow_gt) * flow_mask.unsqueeze(-1) # (T-1, H, W, 2) loss = torch.sqrt(diff.pow(2).sum(dim=-1) + eps**2) # Charbonnier loss loss = loss.sum() / (flow_mask.sum() + 1e-6) return loss def normalize_depth_from_reference(depth_maps, reference_idx=0, invalid_value=-1.0, invert=False, eps = 1e-8): """ Normalize depth maps based on a reference frame with improved robustness. """ if depth_maps.dim() != 3: raise ValueError(f"Expected depth_maps with 3 dimensions, got {depth_maps.dim()}") T, H, W = depth_maps.shape device = depth_maps.device reference_depth = depth_maps[reference_idx] valid_mask = ( (reference_depth != invalid_value) & (reference_depth > 1e-8) & # Avoid very small positive values torch.isfinite(reference_depth) # Exclude inf/nan ) valid_values = reference_depth[valid_mask] min_depth = torch.quantile(valid_values, 0.01) # 1st percentile max_depth = torch.quantile(valid_values, 0.99) # 99th percentile depth_range = max_depth - min_depth if depth_range < eps: logger.warning(f"Very small depth range ({depth_range:.6f}), using fallback normalization") min_depth = valid_values.min() max_depth = valid_values.max() depth_range = max(max_depth - min_depth, eps) scale = 1.0 / (max_depth - min_depth) offset = -min_depth * scale all_valid_mask = ( (depth_maps != invalid_value) & (depth_maps > eps) & torch.isfinite(depth_maps) ) normalized_depths = torch.full_like(depth_maps, invalid_value) if all_valid_mask.any(): normalized_values = depth_maps[all_valid_mask] * scale + offset if invert: normalized_values = 1.0 - normalized_values normalized_depths[all_valid_mask] = normalized_values return normalized_depths, scale.item(), offset.item() def compute_depth_loss_normalized(mono_depths, zbuf_depths, mask): """ Compute normalized depth loss. """ device = zbuf_depths.device # Normalize both depth types zbuf_norm, z_scale, z_offset = normalize_depth_from_reference(zbuf_depths) mono_norm, m_scale, m_offset = normalize_depth_from_reference(mono_depths, invert=True) valid_zbuf = (zbuf_norm >= 0) & (zbuf_norm <= 1) valid_mono = (mono_norm >= 0) & (mono_norm <= 1) if mask.dtype != torch.bool: mask = mask > 0.5 combined_mask = mask & valid_zbuf & valid_mono num_valid = combined_mask.sum().item() if num_valid == 0: print("No valid pixels for depth loss computation") return torch.tensor(0.0, device=device, requires_grad=True) depth_diff = (zbuf_norm - mono_norm) * combined_mask.float() loss = (depth_diff**2).sum() / num_valid return loss