# Residue from the hosting site's file viewer (not part of the module):
#   seungminkwak's picture
#   reset: clean history (purge leaked token)
#   08b23ce
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import List, Optional, Tuple, Union
from collections import deque
from pytorch3d.structures import Meshes, join_meshes_as_scene
from pytorch3d.renderer import TexturesVertex, TexturesUV
from utils.quat_utils import quat_to_transform_matrix, quat_multiply, quat_rotate_vector
class RiggingModel:
    """
    A 3D rigged model supporting skeletal animation.

    Handles mesh geometry, skeletal hierarchy, skinning weights, and
    linear blend skinning (LBS) deformation.

    Nothing in this class populates the mesh or skeleton attributes; the
    caller assigns them, calls ``initialize_bind_matrices()`` once, and then
    calls ``animate()``. Identity rotations are written as ``[1, 0, 0, 0]``
    (scalar component first).
    """

    def __init__(self, device: Union[str, torch.device] = "cuda:0"):
        """
        Create an empty model.

        Args:
            device: Device the rest pose and bind matrices are placed on.
        """
        self.device = device

        # Mesh data (one entry per sub-mesh)
        self.vertices: List[torch.Tensor] = []
        self.faces: List[torch.Tensor] = []
        self.textures: List[Union[TexturesVertex, TexturesUV]] = []

        # Skeletal data
        self.bones: Optional[torch.Tensor] = None  # (N, 2) [parent, child] pairs
        self.parent_indices: Optional[torch.Tensor] = None  # (J,) parent index for each joint
        self.root_index: Optional[int] = None  # Root joint index
        self.joints_rest: Optional[torch.Tensor] = None  # (J, 3) rest pose positions
        self.skin_weights: List[torch.Tensor] = []  # List of (V_i, J) skinning weights

        # Fixed local positions
        self.rest_local_positions: Optional[torch.Tensor] = None  # (J, 3)

        # Computed data
        self.bind_matrices_inv: Optional[torch.Tensor] = None  # (J, 4, 4) inverse bind matrices
        self.deformed_vertices: Optional[List[torch.Tensor]] = None  # List of (T, V_i, 3)
        self.joint_positions: Optional[torch.Tensor] = None  # (T, J, 3) current joint positions

        # Validation flags
        self._bind_matrices_initialized = False

    def initialize_bind_matrices(self, rest_local_pos: torch.Tensor) -> None:
        """
        Initialize bind matrices and store rest local positions.

        Runs forward kinematics with identity rotations to obtain the rest
        pose in global space, then stores the inverse of the resulting
        per-joint transforms.

        Args:
            rest_local_pos: (J, 3) joint offsets, moved to ``self.device``.
        """
        self.rest_local_positions = rest_local_pos.to(self.device)
        num_joints = self.rest_local_positions.shape[0]

        # Keep the identity rotations in the same floating dtype as the rest
        # positions so the quaternion helpers never see mixed precision.
        if self.rest_local_positions.is_floating_point():
            dtype = self.rest_local_positions.dtype
        else:
            dtype = torch.get_default_dtype()
        identity_quats = torch.zeros(1, num_joints, 4, dtype=dtype, device=self.device)
        identity_quats[..., 0] = 1.0  # unit quaternion

        rest_global_quats, rest_global_pos = self.forward_kinematics(
            identity_quats,
            self.parent_indices,
            self.root_index,
        )
        bind_matrices = quat_to_transform_matrix(rest_global_quats, rest_global_pos)  # (1, J, 4, 4)
        self.bind_matrices_inv = torch.inverse(bind_matrices.squeeze(0))  # (J, 4, 4)
        self._bind_matrices_initialized = True

    def animate(
        self,
        local_quaternions: torch.Tensor,
        root_quaternion: Optional[torch.Tensor] = None,
        root_position: Optional[torch.Tensor] = None,
    ) -> None:
        """
        Animate the model using local joint transformations.

        Results are stored on the instance: ``joint_positions`` and
        ``deformed_vertices`` (one tensor per sub-mesh).

        Args:
            local_quaternions: (T, J, 4) local rotations per frame
            root_quaternion: (T, 4) global root rotation
            root_position: (T, 3) global root translation

        If only one of ``root_quaternion`` / ``root_position`` is given, the
        other defaults to the identity rotation / zero translation.

        Raises:
            RuntimeError: If ``initialize_bind_matrices()`` was not called.
            ValueError: If there are fewer skin-weight tensors than meshes.
        """
        if not self._bind_matrices_initialized:
            raise RuntimeError("Bind matrices not initialized. Call initialize_bind_matrices() first.")
        if len(self.skin_weights) < len(self.vertices):
            raise ValueError(
                f"Expected skin weights for {len(self.vertices)} meshes, "
                f"got {len(self.skin_weights)}."
            )

        # Forward kinematics
        global_quats, global_pos = self.forward_kinematics(
            local_quaternions,
            self.parent_indices,
            self.root_index,
        )
        self.joint_positions = global_pos
        joint_transforms = quat_to_transform_matrix(global_quats, global_pos)  # (T, J, 4, 4)

        # Apply global root transformation if any part of it is provided
        if root_quaternion is not None or root_position is not None:
            if root_quaternion is None:
                root_quaternion = root_position.new_zeros((*root_position.shape[:-1], 4))
                root_quaternion[..., 0] = 1.0  # identity rotation
            elif root_position is None:
                root_position = root_quaternion.new_zeros((*root_quaternion.shape[:-1], 3))
            root_transform = quat_to_transform_matrix(root_quaternion, root_position)
            joint_transforms = root_transform[:, None] @ joint_transforms
            self.joint_positions = joint_transforms[..., :3, 3]

        # Linear blend skinning, one sub-mesh at a time
        self.deformed_vertices = [
            self._linear_blend_skinning(
                mesh_vertices,
                joint_transforms,
                mesh_weights,
                self.bind_matrices_inv,
            )
            for mesh_vertices, mesh_weights in zip(self.vertices, self.skin_weights)
        ]

    def get_mesh(self, frame_idx: Optional[int] = None) -> "Meshes":
        """
        Build a single scene mesh from all sub-meshes.

        Args:
            frame_idx: Frame of ``deformed_vertices`` to use. When it is None,
                or ``animate()`` has not produced deformed vertices yet, the
                rest-pose vertices are used.

        Returns:
            The sub-meshes joined with ``join_meshes_as_scene``.
        """
        use_rest_pose = frame_idx is None or self.deformed_vertices is None
        meshes = []
        for i, rest_vertices in enumerate(self.vertices):
            verts = rest_vertices if use_rest_pose else self.deformed_vertices[i][frame_idx]
            # A sub-mesh without a registered texture is built untextured
            # instead of failing with an IndexError.
            textures = self.textures[i] if i < len(self.textures) else None
            meshes.append(Meshes(verts=[verts], faces=[self.faces[i]], textures=textures))
        return join_meshes_as_scene(meshes)

    def _linear_blend_skinning(
        self,
        vertices: torch.Tensor,
        joint_transforms: torch.Tensor,
        skin_weights: torch.Tensor,
        bind_matrices_inv: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply linear blend skinning to vertices.

        Args:
            vertices: (V, 3) vertex positions
            joint_transforms: (T, J, 4, 4) joint transformation matrices
            skin_weights: (V, J) per-vertex joint weights
            bind_matrices_inv: (J, 4, 4) inverse bind matrices

        Returns:
            (T, V, 3) deformed vertices
        """
        # Compute final transformation matrices
        transforms = torch.matmul(joint_transforms, bind_matrices_inv)  # (T, J, 4, 4)

        # Weight and blend transformations
        weighted_transforms = torch.einsum('vj,tjab->tvab', skin_weights, transforms)  # (T, V, 4, 4)

        # Apply to vertices; new_ones keeps the vertices' dtype and device so
        # the homogeneous coordinates never get promoted to another precision.
        ones = vertices.new_ones(vertices.shape[0], 1)
        vertices_hom = torch.cat([vertices, ones], dim=-1)  # (V, 4)
        deformed = torch.matmul(weighted_transforms, vertices_hom.unsqueeze(-1)).squeeze(-1)
        return deformed[..., :3]

    def forward_kinematics(
        self,
        local_quaternions: torch.Tensor,
        parent_indices,
        root_index=0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute global joint transformations from local ones.

        Joint offsets are taken from ``self.rest_local_positions``; only the
        rotations vary per batch entry.

        Args:
            local_quaternions: (B, J, 4) local rotations
            parent_indices: (J,) parent index for each joint; negative
                entries mark joints without a parent
            root_index: Root joint index

        Returns:
            Tuple of (global_quaternions, global_positions), stacked along
            dim 1 in joint order.

        Raises:
            RuntimeError: If rest local positions have not been set.
            ValueError: If ``parent_indices`` or ``root_index`` is None, or
                if some joint cannot be reached from the root.
        """
        if self.rest_local_positions is None:
            raise RuntimeError("Rest local positions not set. Call initialize_bind_matrices() first.")
        if parent_indices is None or root_index is None:
            raise ValueError("parent_indices and root_index must be set before forward kinematics.")

        B, J = local_quaternions.shape[:2]
        local_positions = self.rest_local_positions.unsqueeze(0).expand(B, -1, -1)

        # Convert the hierarchy to plain Python ints once: indexing with and
        # comparing 0-dim tensors per joint forces a device sync on GPU, and
        # a tensor root would never match the integer children in `visited`.
        if hasattr(parent_indices, "tolist"):
            parents = [int(p) for p in parent_indices.tolist()]
        else:
            parents = [int(p) for p in parent_indices]
        root = int(root_index)

        # Build children mapping
        children = [[] for _ in range(J)]
        for child_idx in range(J):
            parent_idx = parents[child_idx]
            if parent_idx >= 0:
                children[parent_idx].append(child_idx)

        # Initialize storage and seed it with the root joint
        global_quats = [None] * J
        global_positions = [None] * J
        global_quats[root] = local_quaternions[:, root]
        global_positions[root] = local_positions[:, root]

        # Breadth-first traversal from root guarantees a parent's global
        # transform exists before any of its children are processed.
        queue = deque([root])
        visited = {root}
        while queue:
            current = queue.popleft()
            current_quat = global_quats[current]
            current_pos = global_positions[current]

            for child in children[current]:
                if child in visited:
                    continue
                visited.add(child)
                queue.append(child)

                # Transform child to global space
                global_quats[child] = quat_multiply(current_quat, local_quaternions[:, child])
                global_positions[child] = (
                    quat_rotate_vector(current_quat, local_positions[:, child]) + current_pos
                )

        unreachable = [j for j, quat in enumerate(global_quats) if quat is None]
        if unreachable:
            raise ValueError(
                f"Joints {unreachable} are not reachable from root joint {root}; "
                "check parent_indices."
            )

        return torch.stack(global_quats, dim=1), torch.stack(global_positions, dim=1)