# seungminkwak's picture
# reset: clean history (purge leaked token)
# 08b23ce
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import List, Tuple, Optional
# Default lower bound for norms before they are used as divisors
# (default `eps` of normalize_quaternion, quat_inverse, quat_to_rotation_matrix).
EPS: float = 1e-8
def normalize_quaternion(quat: torch.Tensor, eps: float = EPS) -> torch.Tensor:
    """
    Rescale quaternions so their last-axis Euclidean norm is 1.

    Args:
        quat: Tensor of shape (..., 4) holding (w, x, y, z) quaternions.
        eps: Floor applied to the norm before dividing, so a zero-length
            quaternion does not trigger a division by zero.

    Returns:
        Tensor of the same shape as ``quat``, equal to
        ``quat / max(||quat||, eps)`` along the last axis.
    """
    length = quat.norm(dim=-1, keepdim=True)
    return quat / length.clamp(min=eps)
def quat_multiply(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """
    Hamilton product ``q1 * q2`` of (w, x, y, z) quaternions.

    Both inputs have 4 components on their last axis; the leading axes
    broadcast against each other under normal tensor arithmetic rules.

    Returns:
        Tensor whose last axis holds the (w, x, y, z) product.
    """
    a_w, a_x, a_y, a_z = q1.unbind(-1)
    b_w, b_x, b_y, b_z = q2.unbind(-1)

    prod_w = a_w * b_w - a_x * b_x - a_y * b_y - a_z * b_z
    prod_x = a_w * b_x + a_x * b_w + a_y * b_z - a_z * b_y
    prod_y = a_w * b_y - a_x * b_z + a_y * b_w + a_z * b_x
    prod_z = a_w * b_z + a_x * b_y - a_y * b_x + a_z * b_w

    return torch.stack([prod_w, prod_x, prod_y, prod_z], dim=-1)
def quat_conjugate(quat: torch.Tensor) -> torch.Tensor:
    """
    Return the conjugate of (w, x, y, z) quaternions.

    The scalar part ``w`` is kept and the vector part (x, y, z) is negated.
    Output has the same shape as ``quat``.
    """
    scalar_part = quat[..., 0:1]
    vector_part = quat[..., 1:]
    return torch.cat((scalar_part, vector_part.neg()), dim=-1)
def quat_inverse(quat: torch.Tensor, eps: float = EPS) -> torch.Tensor:
    """
    Multiplicative inverse of (w, x, y, z) quaternions.

    Computed as ``conj(q) / |q|^2``, with the squared norm floored at
    ``eps`` so a zero quaternion does not cause a division by zero.

    Args:
        quat: Tensor with 4 components on its last axis.
        eps: Floor for the squared norm.

    Returns:
        Tensor of the same shape as ``quat``.
    """
    squared_length = (quat * quat).sum(dim=-1, keepdim=True)
    return quat_conjugate(quat) / squared_length.clamp(min=eps)
def quat_log(quat: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    Compute the vector part of the quaternion logarithm.

    For a unit quaternion ``q = (cos(a), sin(a) * axis)`` this returns
    ``a * axis``, where ``a`` is the HALF rotation angle. The axis-angle
    rotation vector of ``q`` is therefore twice this result.

    The sign of ``w`` is folded into the result and ``a`` is taken in
    [0, pi/2]. As a result, ``q`` and ``-q``, which encode the same rotation,
    map to the same output whenever ``w != 0``.

    Args:
        quat: Quaternion tensor of shape (..., 4) in (w, x, y, z) order. It is
            normalized first, so it need not have unit length.
        eps: Floor for the norms used as divisors. Also the threshold on
            ``|xyz|`` below which the normalized vector part is returned
            unscaled.

    Returns:
        Tensor of shape (..., 3).
    """
    q_norm = torch.sqrt(torch.sum(quat * quat, dim=-1, keepdim=True))
    quat_norm = quat / torch.clamp(q_norm, min=eps)
    w = quat_norm[..., 0:1]  # scalar part
    xyz = quat_norm[..., 1:]  # vector part
    xyz_norm = torch.norm(xyz, dim=-1, keepdim=True)
    safe_xyz_norm = torch.clamp(xyz_norm, min=eps)

    # atan2(|xyz|, |w|) is accurate at every angle. By contrast,
    # acos(|w|) with w clamped to 1 - eps saturates at ~sqrt(2 * eps) near
    # the identity, which inflates small rotations by orders of magnitude.
    # The clamped norm is used so atan2 never sees (0, 0), which would give
    # NaN gradients for a zero quaternion.
    half_angle = torch.atan2(safe_xyz_norm, torch.abs(w))

    # For unit quaternions, half_angle / |xyz| -> 1 as |xyz| -> 0, so use 1 there.
    scale = torch.where(
        xyz_norm < eps,
        torch.ones_like(xyz_norm),
        half_angle / safe_xyz_norm,
    )
    # Fold the sign of w into the result (q and -q are the same rotation).
    sign = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w))
    return sign * scale * xyz
def quat_rotate_vector(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
    """
    Rotate 3D vectors by (w, x, y, z) quaternions.

    Uses ``v' = v + 2 w (u x v) + 2 u x (u x v)``, where ``u`` is the vector
    part of the quaternion. The formula is a pure rotation only for unit
    quaternions, and the input is NOT normalized here.

    Args:
        quat: Tensor with 4 components (w, x, y, z) on its last axis.
        vec: Tensor with 3 components on its last axis.

    Returns:
        The rotated vectors.
    """
    w = quat[..., :1]
    u = quat[..., 1:]
    u_cross_v = torch.cross(u, vec, dim=-1)
    u_cross_u_cross_v = torch.cross(u, u_cross_v, dim=-1)
    return vec + 2.0 * w * u_cross_v + 2.0 * u_cross_u_cross_v
def quat_to_rotation_matrix(quat: torch.Tensor, eps: float = EPS) -> torch.Tensor:
    """
    Build 3x3 rotation matrices from (w, x, y, z) quaternions.

    The input is first rescaled with ``normalize_quaternion`` (norm floored at
    ``eps``), so the result does not depend on the quaternion's scale.

    Args:
        quat: Tensor of shape (..., 4).
        eps: Norm floor passed to ``normalize_quaternion``.

    Returns:
        Tensor of shape (..., 3, 3).
    """
    unit = normalize_quaternion(quat, eps)
    qw, qx, qy, qz = unit.unbind(-1)

    # Pairwise products, each reused by several matrix entries.
    x2, y2, z2 = qx * qx, qy * qy, qz * qz
    xy, xz, yz = qx * qy, qx * qz, qy * qz
    wx, wy, wz = qw * qx, qw * qy, qw * qz

    first_row = torch.stack([1.0 - 2.0 * (y2 + z2), 2.0 * (xy - wz), 2.0 * (xz + wy)], dim=-1)
    second_row = torch.stack([2.0 * (xy + wz), 1.0 - 2.0 * (x2 + z2), 2.0 * (yz - wx)], dim=-1)
    third_row = torch.stack([2.0 * (xz - wy), 2.0 * (yz + wx), 1.0 - 2.0 * (x2 + y2)], dim=-1)
    return torch.stack([first_row, second_row, third_row], dim=-2)
def quat_to_transform_matrix(quat: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    """
    Assemble 4x4 homogeneous transforms from a rotation and a translation.

    Args:
        quat: Tensor of shape (..., 4) holding (w, x, y, z) quaternions.
        pos: Translations with 3 components on the last axis. They must be
            assignable (broadcastable) into the (..., 3) translation column.

    Returns:
        Tensor of shape (..., 4, 4) with dtype/device of the rotation matrix.
        The upper-left 3x3 block is the rotation, the last column holds
        ``pos``, and the bottom row is (0, 0, 0, 1).
    """
    rot = quat_to_rotation_matrix(quat)
    out = rot.new_zeros(rot.shape[:-2] + (4, 4))
    out[..., :3, :3] = rot
    out[..., :3, 3] = pos
    out[..., 3, 3] = 1.0
    return out
def compute_rest_local_positions(
    joint_positions: torch.Tensor,
    parent_indices: List[int]
) -> torch.Tensor:
    """
    Convert global joint positions into offsets relative to each parent joint.

    A joint counts as a root when its parent index is negative, equals its
    own index, or is out of range (``>= num_joints``). Roots keep their
    global position. Every other joint becomes
    ``position[child] - position[parent]``.

    Args:
        joint_positions: Tensor whose first axis indexes joints.
        parent_indices: Parent index per joint. Only the first
            ``joint_positions.shape[0]`` entries are read.

    Returns:
        New tensor with the same shape as ``joint_positions``. The input is
        not modified.
    """
    count = joint_positions.shape[0]
    # Start from a copy so root joints already hold their global position.
    offsets = joint_positions.clone()
    for child in range(count):
        parent = parent_indices[child]
        if 0 <= parent < count and parent != child:
            offsets[child] = joint_positions[child] - joint_positions[parent]
    return offsets