from typing import Tuple

import torch
import torch.nn.functional as F


def get_rigid_transform(
    A: torch.Tensor, B: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Estimate the rigid body transformation between two sets of 3D points
    (Kabsch algorithm: centre both sets, SVD of the cross-covariance).

    The transform maps points of ``A`` onto the corresponding points of ``B``:
    ``target = R @ source + t`` with ``source`` of shape [3, 1].

    Args:
        A, B: [..., N, 3] tensors of 3D points (typically [batch, N, 3]);
            row i of A corresponds to row i of B.

    Returns:
        R: [..., 3, 3] rotation matrices with determinant +1.
        t: [..., 3, 1] translation column vectors.

    Raises:
        AssertionError: if the shapes differ or the last dimension is not 3.
    """
    assert A.shape == B.shape, "Input matrices must have the same shape"
    assert A.shape[-1] == 3, "Input matrices must have 3 columns (x, y, z coordinates)"

    # Centroids over the point dimension. [..., 1, 3]
    centroid_A = torch.mean(A, dim=-2, keepdim=True)
    centroid_B = torch.mean(B, dim=-2, keepdim=True)

    # Center the point sets so that only the rotation remains to be solved.
    A_centered = A - centroid_A
    B_centered = B - centroid_B

    # Cross-covariance matrix. [..., 3, 3]
    H = A_centered.transpose(-2, -1) @ B_centered

    # SVD acts on the last two dimensions; singular values are not needed.
    U, _, Vt = torch.linalg.svd(H)
    Ut = U.transpose(-2, -1)

    # V @ U^T may be a reflection (det < 0). In that case negate the last row
    # of Vt before assembling R, so that the result is a proper rotation.
    det = torch.det(Vt.transpose(-2, -1) @ Ut)
    last_row_sign = 1.0 - 2.0 * (det < 0).to(Vt.dtype)  # -1 where reflected, else +1
    row_signs = torch.stack(
        [torch.ones_like(last_row_sign), torch.ones_like(last_row_sign), last_row_sign],
        dim=-1,
    )  # [..., 3]
    Vt = Vt * row_signs[..., None]

    R = Vt.transpose(-2, -1) @ Ut

    # t = centroid_B - R @ centroid_A, expressed as a column vector. [..., 3, 1]
    t = centroid_B.transpose(-2, -1) - R @ centroid_A.transpose(-2, -1)

    return R, t


def _test_rigid_transform():
    """Print the transform recovered from a synthetic rotation plus offset."""
    A = torch.tensor([[1, 2, 3], [4, 5, 6], [9, 8, 10], [10, -5, 1]]) * 1.0
    # 180 degree rotation about the x axis (determinant +1).
    R_synthesized = torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) * 1.0
    B = (R_synthesized @ A.T).T + 2.0  # Just an example offset

    R, t = get_rigid_transform(A[None, ...], B[None, ...])

    print("Rotation matrix R:")
    print(R)
    print("\nTranslation vector t:")
    print(t)


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    # Only strictly positive entries go through sqrt; the rest stay exactly 0.
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret


def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    from pytorch3d. Based on trace_method like:
    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L205

    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
        The result is not standardized: the real part may be negative.

    Raises:
        ValueError: if the last two dimensions of ``matrix`` are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    # Twice the absolute value of each quaternion component (r, i, j, k).
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    best = F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5
    return quat_candidates[best, :].reshape(batch_dim + (4,))


def quternion_to_matrix(r: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternions (real part first) to rotation matrices.

    The input is normalized first, so it does not have to be unit length.
    The result lives on the same device and has the same dtype as the input.

    Args:
        r: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    norm = torch.sqrt((r * r).sum(dim=-1, keepdim=True))
    q = r / norm
    w, x, y, z = torch.unbind(q, dim=-1)

    # Row-major entries of the rotation matrix.
    entries = torch.stack(
        [
            1 - 2 * (y * y + z * z),
            2 * (x * y - w * z),
            2 * (x * z + w * y),
            2 * (x * y + w * z),
            1 - 2 * (x * x + z * z),
            2 * (y * z - w * x),
            2 * (x * z - w * y),
            2 * (y * z + w * x),
            1 - 2 * (x * x + y * y),
        ],
        dim=-1,
    )
    return entries.reshape(q.shape[:-1] + (3, 3))


def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    from Pytorch3d
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    # q and -q encode the same rotation; keep the one with real part >= 0.
    return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)


def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    From pytorch3d
    Multiply two quaternions.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions shape (..., 4),
        standardized so that the real part is non negative.
    """
    aw, ax, ay, az = torch.unbind(a, -1)
    bw, bx, by, bz = torch.unbind(b, -1)
    # Hamilton product, component by component.
    ow = aw * bw - ax * bx - ay * by - az * bz
    ox = aw * bx + ax * bw + ay * bz - az * by
    oy = aw * by - ax * bz + ay * bw + az * bx
    oz = aw * bz + ax * by - ay * bx + az * bw
    return standardize_quaternion(torch.stack((ow, ox, oy, oz), -1))


def _test_matrix_to_quaternion():
    """Round-trip random unit quaternions through matrices and print the errors."""
    # Use the GPU when there is one, otherwise run the check on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # init a random batch of quaternion
    r = torch.randn((10, 4)).to(device)
    norm = torch.sqrt((r * r).sum(dim=-1))
    q = standardize_quaternion(r / norm[:, None])

    R = quternion_to_matrix(q)

    # A rotation matrix times its transpose must be the identity.
    I_rec = R @ R.transpose(-2, -1)
    I_rec_error = torch.abs(I_rec - torch.eye(3, device=device)[None, ...]).max()

    q_recovered = matrix_to_quaternion(R)
    norm_ = torch.linalg.norm(q_recovered, dim=-1)
    q_recovered = standardize_quaternion(q_recovered / norm_[..., None])

    print(q_recovered.shape, q.shape, R.shape)

    rec = (q - q_recovered).abs().max()
    print("rotation to I error:", I_rec_error, "quant rec error: ", rec)


def _test_matrix_to_quaternion_2():
    """Print the matrices recovered from two hand-written rotations (transposed)."""
    R = (
        torch.tensor(
            [[[1, 0, 0], [0, -1, 0], [0, 0, -1]], [[1, 0, 0], [0, 0, 1], [0, -1, 0]]]
        )
        * 1.0
    )

    q_rec = matrix_to_quaternion(R.transpose(-2, -1))
    R_rec = quternion_to_matrix(q_rec)
    print(R_rec)


def _apply_neighbor_motion(
    query_points: torch.Tensor,
    query_rotation: torch.Tensor,
    source_points: torch.Tensor,
    target_points: torch.Tensor,
    displacement: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Shift queries by the mean neighbour displacement and rotate them by the
    rigid rotation fitted from source_points to target_points."""
    ret_points = query_points + displacement.mean(dim=1)  # [n, 3]
    R, _ = get_rigid_transform(source_points, target_points)
    new_rotation = quaternion_multiply(matrix_to_quaternion(R), query_rotation)
    return ret_points, new_rotation


def interpolate_points_w_R(
    query_points, query_rotation, drive_origin_pts, drive_displacement, top_k_index
):
    """
    Move query points and their rotations according to the displacement of
    their top-k driving points.

    Args:
        query_points: [n, 3]
        query_rotation: quaternions with last dimension 4 (real part first),
            broadcastable against [n, 4]
        drive_origin_pts: [m, 3]
        drive_displacement: [m, 3]
        top_k_index: [n, top_k] < m

    Returns:
        ret_points: [n, 3] query points shifted by the mean displacement of
            their top-k driving points.
        new_rotation: [n, 4] quaternion of the fitted per-query rotation
            composed (on the left) with query_rotation.

    NOTE: the original docstring points to
    ``apply_discrete_offset_filds_with_R(self, origin_points, offsets, topk=6)``
    (origin_points: (N_r, 3), offsets: (N_r, 3) in rendering) as an alternative
    entry point; that function is not defined in this file.
    """
    top_k_disp = drive_displacement[top_k_index]  # [n, topk, 3]
    source_points = drive_origin_pts[top_k_index]  # [n, topk, 3]

    return _apply_neighbor_motion(
        query_points,
        query_rotation,
        source_points,
        source_points + top_k_disp,
        top_k_disp,
    )


def interpolate_points(
    query_points, query_rotation, drive_origin_pts, drive_current_points, top_k_index
):
    """
    Same as ``interpolate_points_w_R`` but driven by current point positions
    instead of displacements.

    Args:
        query_points: [n, 3]
        query_rotation: quaternions with last dimension 4 (real part first),
            broadcastable against [n, 4]
        drive_origin_pts: [m, 3]
        drive_current_points: [m, 3]
        top_k_index: [n, top_k] < m

    Returns:
        ret_points: [n, 3] query points shifted by the mean displacement of
            their top-k driving points.
        new_rotation: [n, 4] quaternion of the fitted per-query rotation
            composed (on the left) with query_rotation.
    """
    source_points = drive_origin_pts[top_k_index]  # [n, topk, 3]
    target_points = drive_current_points[top_k_index]  # [n, topk, 3]

    return _apply_neighbor_motion(
        query_points,
        query_rotation,
        source_points,
        target_points,
        target_points - source_points,
    )