# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from typing import List, Tuple, Optional EPS = 1e-8 def normalize_quaternion(quat: torch.Tensor, eps: float = EPS) -> torch.Tensor: """ Normalize quaternions to unit length. Args: quat: Quaternion tensor of shape (..., 4) with (w, x, y, z) format eps: Small value for numerical stability Returns: Normalized quaternions of same shape """ norm = torch.norm(quat, dim=-1, keepdim=True) return quat / torch.clamp(norm, min=eps) def quat_multiply(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: """ Multiply two quaternions using Hamilton product. """ w1, x1, y1, z1 = torch.unbind(q1, dim=-1) w2, x2, y2, z2 = torch.unbind(q2, dim=-1) w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 return torch.stack((w, x, y, z), dim=-1) def quat_conjugate(quat: torch.Tensor) -> torch.Tensor: """ Compute quaternion conjugate. """ w, xyz = quat[..., :1], quat[..., 1:] return torch.cat([w, -xyz], dim=-1) def quat_inverse(quat: torch.Tensor, eps: float = EPS) -> torch.Tensor: """ Compute quaternion inverse. """ conjugate = quat_conjugate(quat) norm_squared = torch.sum(quat * quat, dim=-1, keepdim=True) return conjugate / torch.clamp(norm_squared, min=eps) def quat_log(quat: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: """ Compute quaternion logarithm, mapping to rotation vectors (axis-angle). """ # quat_norm = normalize_quaternion(quat, eps) q_norm = torch.sqrt(torch.sum(quat * quat, dim=-1, keepdim=True)) quat_norm = quat / torch.clamp(q_norm, min=eps) w = quat_norm[..., 0:1] # Scalar part xyz = quat_norm[..., 1:] # Vector part xyz_norm = torch.norm(xyz, dim=-1, keepdim=True) w_clamped = torch.clamp(w, min=-1.0 + eps, max=1.0 - eps) # half-angle half_angle = torch.acos(torch.abs(w_clamped)) safe_xyz_norm = torch.clamp(xyz_norm, min=eps) # Scale factor scale = torch.where( xyz_norm < eps, torch.ones_like(xyz_norm), half_angle / safe_xyz_norm ) # Handle quaternion sign ambiguity (q and -q represent same rotation) sign = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)) rotation_vector = sign * scale * xyz return rotation_vector def quat_rotate_vector(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: """ Rotate a 3D vector by a quaternion. """ q_vec = quat[..., 1:] # vector part q_w = quat[..., 0:1] # scalar part cross1 = torch.cross(q_vec, vec, dim=-1) cross2 = torch.cross(q_vec, cross1, dim=-1) # Apply the rotation formula rotated_vec = vec + 2.0 * q_w * cross1 + 2.0 * cross2 return rotated_vec def quat_to_rotation_matrix(quat: torch.Tensor, eps: float = EPS) -> torch.Tensor: """ Convert quaternions to rotation matrices. """ quat_norm = normalize_quaternion(quat, eps) w, x, y, z = torch.unbind(quat_norm, dim=-1) xx, yy, zz = x * x, y * y, z * z xy, xz, yz = x * y, x * z, y * z wx, wy, wz = w * x, w * y, w * z r00 = 1.0 - 2.0 * (yy + zz) r01 = 2.0 * (xy - wz) r02 = 2.0 * (xz + wy) r10 = 2.0 * (xy + wz) r11 = 1.0 - 2.0 * (xx + zz) r12 = 2.0 * (yz - wx) r20 = 2.0 * (xz - wy) r21 = 2.0 * (yz + wx) r22 = 1.0 - 2.0 * (xx + yy) rotation_matrix = torch.stack([ r00, r01, r02, r10, r11, r12, r20, r21, r22 ], dim=-1) return rotation_matrix.reshape(quat.shape[:-1] + (3, 3)) def quat_to_transform_matrix(quat: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: """ Convert quaternion and position to 4x4 transformation matrix. """ # rotation part rotation = quat_to_rotation_matrix(quat) batch_shape = rotation.shape[:-2] # homogeneous transformation matrix transform = torch.zeros(batch_shape + (4, 4), dtype=rotation.dtype, device=rotation.device) transform[..., :3, :3] = rotation transform[..., :3, 3] = pos transform[..., 3, 3] = 1.0 return transform def compute_rest_local_positions( joint_positions: torch.Tensor, parent_indices: List[int] ) -> torch.Tensor: """ Compute local positions relative to parent joints from global joint positions. """ num_joints = joint_positions.shape[0] local_positions = torch.zeros_like(joint_positions) for j in range(num_joints): parent_idx = parent_indices[j] if parent_idx >= 0 and parent_idx != j and parent_idx < num_joints: # Child joint: local offset = global_pos - parent_global_pos local_positions[j] = joint_positions[j] - joint_positions[parent_idx] else: # Root joint: use global position as local position local_positions[j] = joint_positions[j] return local_positions