Spaces:
Paused
Paused
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| import torch | |
| from typing import List, Optional, Tuple, Union | |
| from collections import deque | |
| from pytorch3d.structures import Meshes, join_meshes_as_scene | |
| from pytorch3d.renderer import TexturesVertex, TexturesUV | |
| from utils.quat_utils import quat_to_transform_matrix, quat_multiply, quat_rotate_vector | |
class RiggingModel:
    """
    A 3D rigged model supporting skeletal animation.

    Handles mesh geometry, skeletal hierarchy, skinning weights, and
    linear blend skinning (LBS) deformation.

    The class itself only moves ``rest_local_positions`` to ``device``; every
    other tensor assigned to its attributes is used as provided.
    """

    def __init__(self, device = "cuda:0"):
        """
        Create an empty model.

        Args:
            device: torch device used for the rest-pose data created by
                ``initialize_bind_matrices``.
        """
        self.device = device

        # Mesh data (one entry per sub-mesh)
        self.vertices: List[torch.Tensor] = []
        self.faces: List[torch.Tensor] = []
        self.textures: List[Union[TexturesVertex, TexturesUV]] = []

        # Skeletal data
        self.bones: Optional[torch.Tensor] = None  # (N, 2) [parent, child] pairs
        self.parent_indices: Optional[torch.Tensor] = None  # (J,) parent index for each joint
        self.root_index: Optional[int] = None  # Root joint index
        self.joints_rest: Optional[torch.Tensor] = None  # (J, 3) rest pose positions
        self.skin_weights: List[torch.Tensor] = []  # List of (V_i, J) skinning weights

        # Fixed local positions (joint offsets relative to the parent joint)
        self.rest_local_positions: Optional[torch.Tensor] = None  # (J, 3)

        # Computed data
        self.bind_matrices_inv: Optional[torch.Tensor] = None  # (J, 4, 4) inverse bind matrices
        self.deformed_vertices: Optional[List[torch.Tensor]] = None  # List of (T, V_i, 3)
        self.joint_positions: Optional[torch.Tensor] = None  # (T, J, 3) current joint positions

        # Validation flags
        self._bind_matrices_initialized = False

    def initialize_bind_matrices(self, rest_local_pos):
        """
        Initialize inverse bind matrices and store rest local positions.

        Runs forward kinematics with identity rotations to obtain the rest
        pose, converts it to 4x4 matrices and stores their inverses in
        ``bind_matrices_inv``.

        Args:
            rest_local_pos: (J, 3) rest-pose joint positions, local to the
                parent joint. Stored on ``self.device``.

        Raises:
            RuntimeError: if ``parent_indices`` or ``root_index`` is unset.
        """
        if self.parent_indices is None or self.root_index is None:
            raise RuntimeError(
                "Skeleton not set. Assign parent_indices and root_index before "
                "calling initialize_bind_matrices()."
            )

        self.rest_local_positions = rest_local_pos.to(self.device)
        J = rest_local_pos.shape[0]

        # Identity rotation for every joint, written as (1, 0, 0, 0).
        identity_quats = torch.zeros(1, J, 4, device=self.device)
        identity_quats[..., 0] = 1.0

        rest_global_quats, rest_global_pos = self.forward_kinematics(
            identity_quats,
            self.parent_indices,
            self.root_index,
        )
        bind_matrices = quat_to_transform_matrix(rest_global_quats, rest_global_pos)  # (1, J, 4, 4)
        self.bind_matrices_inv = torch.inverse(bind_matrices.squeeze(0))  # (J, 4, 4)
        self._bind_matrices_initialized = True

    def animate(self, local_quaternions, root_quaternion = None, root_position = None):
        """
        Animate the model using local joint transformations.

        Results are stored in ``joint_positions`` and ``deformed_vertices``;
        nothing is returned.

        Args:
            local_quaternions: (T, J, 4) local rotations per frame
            root_quaternion: (T, 4) global root rotation. When omitted while
                ``root_position`` is given, the identity rotation is used.
            root_position: (T, 3) global root translation. When omitted while
                ``root_quaternion`` is given, zero translation is used.

        Raises:
            RuntimeError: if ``initialize_bind_matrices`` has not been called.
        """
        if not self._bind_matrices_initialized:
            raise RuntimeError("Bind matrices not initialized. Call initialize_bind_matrices() first.")

        # Forward kinematics
        global_quats, global_pos = self.forward_kinematics(
            local_quaternions,
            self.parent_indices,
            self.root_index,
        )
        joint_transforms = quat_to_transform_matrix(global_quats, global_pos)  # (T, J, 4, 4)
        joint_positions = global_pos

        # Apply the global root transformation if any part of it was provided.
        # A missing half defaults to "no change" so that a lone rotation or a
        # lone translation is not silently ignored.
        if root_quaternion is not None or root_position is not None:
            if root_quaternion is None:
                root_quaternion = root_position.new_zeros((*root_position.shape[:-1], 4))
                root_quaternion[..., 0] = 1.0  # identity, same layout as the rest pose
            if root_position is None:
                root_position = root_quaternion.new_zeros((*root_quaternion.shape[:-1], 3))
            root_transform = quat_to_transform_matrix(root_quaternion, root_position)
            joint_transforms = root_transform[:, None] @ joint_transforms
            joint_positions = joint_transforms[..., :3, 3]

        self.joint_positions = joint_positions

        # Linear blend skinning, one sub-mesh at a time
        self.deformed_vertices = [
            self._linear_blend_skinning(
                vertices,
                joint_transforms,
                self.skin_weights[i],
                self.bind_matrices_inv,
            )
            for i, vertices in enumerate(self.vertices)
        ]

    def get_mesh(self, frame_idx=None):
        """
        Build a single scene mesh from all sub-meshes.

        Args:
            frame_idx: index into the animated frames. When ``None``, or when
                ``animate`` has not produced ``deformed_vertices`` yet, the
                rest-pose vertices are used instead.

        Returns:
            The result of ``join_meshes_as_scene`` over one ``Meshes`` object
            per sub-mesh. Sub-meshes are untextured when ``textures`` is empty.
        """
        use_rest_pose = frame_idx is None or self.deformed_vertices is None

        meshes = []
        for i, rest_vertices in enumerate(self.vertices):
            verts = rest_vertices if use_rest_pose else self.deformed_vertices[i][frame_idx]
            meshes.append(
                Meshes(
                    verts=[verts],
                    faces=[self.faces[i]],
                    textures=self.textures[i] if self.textures else None,
                )
            )
        return join_meshes_as_scene(meshes)

    def _linear_blend_skinning(self, vertices, joint_transforms, skin_weights, bind_matrices_inv):
        """
        Apply linear blend skinning to vertices.

        Args:
            vertices: (V, 3) vertex positions
            joint_transforms: (T, J, 4, 4) joint transformation matrices
            skin_weights: (V, J) per-vertex joint weights
            bind_matrices_inv: (J, 4, 4) inverse bind matrices

        Returns:
            (T, V, 3) deformed vertices
        """
        # Per-joint skinning matrices: current pose composed with inverse bind pose
        transforms = torch.matmul(joint_transforms, bind_matrices_inv)  # (T, J, 4, 4)

        # Blend the joint matrices per vertex using the skin weights
        weighted_transforms = torch.einsum('vj,tjab->tvab', skin_weights, transforms)  # (T, V, 4, 4)

        # Homogeneous coordinates; new_ones keeps the dtype and device of
        # `vertices` so reduced-precision inputs are not promoted.
        ones = vertices.new_ones((vertices.shape[0], 1))
        vertices_hom = torch.cat([vertices, ones], dim=-1)  # (V, 4)

        deformed = torch.matmul(weighted_transforms, vertices_hom.unsqueeze(-1)).squeeze(-1)
        return deformed[..., :3]

    def forward_kinematics(self, local_quaternions, parent_indices, root_index = 0):
        """
        Compute global joint transformations from local ones.

        Joints are visited breadth-first from the root so that a parent is
        always resolved before its children. The root keeps its local rotation
        and rest local position; each child composes its parent's global
        rotation with its own local rotation and offsets its rotated rest
        local position by the parent's global position.

        Args:
            local_quaternions: (B, J, 4) local rotations
            parent_indices: (J,) parent index for each joint; negative values
                mark joints without a parent
            root_index: Root joint index

        Returns:
            Tuple of (global_quaternions, global_positions) with shapes
            (B, J, 4) and (B, J, 3)

        Raises:
            RuntimeError: if ``rest_local_positions`` has not been set.
            ValueError: if some joint cannot be reached from ``root_index``.
        """
        if self.rest_local_positions is None:
            raise RuntimeError(
                "Rest local positions not set. Call initialize_bind_matrices() first."
            )

        B, J = local_quaternions.shape[:2]
        local_positions = self.rest_local_positions.unsqueeze(0).expand(B, -1, -1)

        # Convert the hierarchy to plain ints once instead of indexing a
        # tensor element per joint.
        raw_parents = parent_indices.tolist() if hasattr(parent_indices, "tolist") else parent_indices
        parents = [int(p) for p in raw_parents]
        root = int(root_index)

        # Build children mapping
        children = [[] for _ in range(J)]
        for child_idx in range(J):
            parent_idx = parents[child_idx]
            if parent_idx >= 0:
                children[parent_idx].append(child_idx)

        # Initialize storage and seed it with the root
        global_quats = [None] * J
        global_positions = [None] * J
        global_quats[root] = local_quaternions[:, root]
        global_positions[root] = local_positions[:, root]

        # Breadth-first traversal from root
        queue = deque([root])
        visited = {root}
        while queue:
            current = queue.popleft()
            current_quat = global_quats[current]
            current_pos = global_positions[current]
            for child in children[current]:
                if child in visited:
                    continue
                visited.add(child)
                queue.append(child)
                # Transform child to global space
                global_quats[child] = quat_multiply(current_quat, local_quaternions[:, child])
                global_positions[child] = (
                    quat_rotate_vector(current_quat, local_positions[:, child]) + current_pos
                )

        unreachable = [j for j in range(J) if j not in visited]
        if unreachable:
            raise ValueError(
                f"Joints {unreachable} are not reachable from root joint {root}; "
                "check parent_indices and root_index."
            )

        return torch.stack(global_quats, dim=1), torch.stack(global_positions, dim=1)