from typing import Optional, Tuple

import torch
import torch.nn.functional as tF
from torch import Tensor

def normalize_normals(normals: Tensor, C2W: Tensor, i: int = 0) -> Tensor:
    """Normalize a batch of multi-view `normals` by the `i`-th view.

    Inputs:
        - `normals`: (B, V, 3, H, W)
        - `C2W`: (B, V, 4, 4)
        - `i`: the index of the view to normalize by

    Outputs:
        - `normalized_normals`: (B, V, 3, H, W)
    """
    _, _, R, C = C2W.shape  # (B, V, 4, 4)
    assert R == C == 4
    _, _, CC, _, _ = normals.shape  # (B, V, 3, H, W)
    assert CC == 3

    dtype = normals.dtype
    normals = normals.clone().float()
    transform = torch.inverse(C2W[:, i, :3, :3])  # (B, 3, 3)
    return torch.einsum("brc,bvchw->bvrhw", transform, normals).to(dtype)  # (B, V, 3, H, W)
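
# A minimal usage sketch (this demo helper is NOT part of the original module;
# shapes and poses are made up for illustration): with identity poses, the
# view-0 rotation is the identity, so the normals come back unchanged.
def _demo_normalize_normals() -> None:
    B, V, H, W = 2, 4, 32, 32
    normals = tF.normalize(torch.randn(B, V, 3, H, W), dim=2)  # unit normals
    C2W = torch.eye(4).repeat(B, V, 1, 1)  # (B, V, 4, 4), identity poses
    out = normalize_normals(normals, C2W, i=0)
    assert out.shape == (B, V, 3, H, W)
    torch.testing.assert_close(out, normals)  # identity rotation leaves normals unchanged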

def normalize_C2W(C2W: Tensor, i: int = 0, norm_radius: float = 0.) -> Tensor:
    """Normalize a batch of multi-view `C2W` by the `i`-th view.

    Inputs:
        - `C2W`: (B, V, 4, 4)
        - `i`: the index of the view to normalize by
        - `norm_radius`: the normalization radius; if `0.`, camera distances are left unscaled

    Outputs:
        - `normalized_C2W`: (B, V, 4, 4)
    """
    _, _, R, C = C2W.shape  # (B, V, 4, 4)
    assert R == C == 4

    device, dtype = C2W.device, C2W.dtype
    C2W = C2W.clone().float()

    if abs(norm_radius) > 0.:
        # Rescale all camera centers so that the `i`-th view lies at `norm_radius`
        radius = torch.norm(C2W[:, i, :3, 3], dim=1)  # (B,)
        C2W[:, :, :3, 3] *= norm_radius / radius[:, None, None]
    # Map the `i`-th view to a canonical pose (the reference view);
    # all other views are transformed rigidly along with it
    transform = torch.tensor([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, norm_radius],
        [0, 0, 0, 1],  # canonical C2W in the OpenGL world convention
    ], dtype=torch.float32, device=device) @ torch.inverse(C2W[:, i, ...])  # (B, 4, 4)
    return (transform.unsqueeze(1) @ C2W).to(dtype)  # (B, V, 4, 4)
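
# A minimal usage sketch (demo helper, NOT part of the original module; it uses
# `orbit_camera` defined further below): after normalization, the reference
# view's pose becomes the canonical OpenGL camera at (0, 0, norm_radius).
def _demo_normalize_C2W() -> None:
    B, V = 2, 4
    pose = orbit_camera(torch.zeros(B), torch.full((B,), 30.), radius=torch.full((B,), 2.))  # (B, 4, 4)
    C2W = pose.unsqueeze(1).repeat(1, V, 1, 1)  # (B, V, 4, 4); all views share one pose here
    out = normalize_C2W(C2W, i=0, norm_radius=2.)
    canonical = torch.eye(4)
    canonical[2, 3] = 2.  # camera at (0, 0, 2) in the OpenGL world convention
    torch.testing.assert_close(out[:, 0], canonical.expand(B, 4, 4))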

def unproject_depth(depth_map: Tensor, C2W: Tensor, fxfycxcy: Tensor) -> Tensor:
    """Unproject a depth map to 3D world coordinates.

    Inputs:
        - `depth_map`: (B, V, H, W)
        - `C2W`: (B, V, 4, 4)
        - `fxfycxcy`: (B, V, 4); (fx, fy, cx, cy), normalized by image size to match the [0, 1] pixel grid below

    Outputs:
        - `xyz_world`: (B, V, 3, H, W)
    """
    device, dtype = depth_map.device, depth_map.dtype
    B, V, H, W = depth_map.shape
    depth_map = depth_map.reshape(B*V, H, W).float()
    C2W = C2W.reshape(B*V, 4, 4).float()
    fxfycxcy = fxfycxcy.reshape(B*V, 4).float()

    # Assemble intrinsic matrices from (fx, fy, cx, cy)
    K = torch.zeros(B*V, 3, 3, dtype=torch.float32, device=device)
    K[:, 0, 0] = fxfycxcy[:, 0]
    K[:, 1, 1] = fxfycxcy[:, 1]
    K[:, 0, 2] = fxfycxcy[:, 2]
    K[:, 1, 2] = fxfycxcy[:, 3]
    K[:, 2, 2] = 1

    y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")  # OpenCV/COLMAP camera convention
    y = y.to(device).unsqueeze(0).repeat(B*V, 1, 1) / (H-1)
    x = x.to(device).unsqueeze(0).repeat(B*V, 1, 1) / (W-1)
    # NOTE: To align with `plucker_ray(bug=False)`, this should be:
    # y = (y.to(device).unsqueeze(0).repeat(B*V, 1, 1) + 0.5) / H
    # x = (x.to(device).unsqueeze(0).repeat(B*V, 1, 1) + 0.5) / W
    xyz_map = torch.stack([x, y, torch.ones_like(x)], dim=-1) * depth_map[..., None]
    xyz = xyz_map.view(B*V, -1, 3)

    # Get point positions in camera coordinates
    xyz = torch.matmul(xyz, torch.transpose(torch.inverse(K), 1, 2))
    xyz_map = xyz.view(B*V, H, W, 3)

    # Transform points from camera to world coordinates
    xyz_homo = torch.ones((B*V, H, W, 4), device=device)
    xyz_homo[..., :3] = xyz_map
    xyz_world = torch.bmm(C2W, xyz_homo.reshape(B*V, -1, 4).permute(0, 2, 1))[:, :3, ...].to(dtype)  # (B*V, 3, H*W)
    xyz_world = xyz_world.reshape(B, V, 3, H, W)
    return xyz_world
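
# A minimal usage sketch (demo helper, NOT part of the original module): with
# an identity pose, the z coordinate of every unprojected point equals its depth.
def _demo_unproject_depth() -> None:
    B, V, H, W = 1, 2, 16, 16
    depth = torch.ones(B, V, H, W)
    C2W = torch.eye(4).repeat(B, V, 1, 1)
    fxfycxcy = torch.tensor([1., 1., 0.5, 0.5]).repeat(B, V, 1)  # normalized intrinsics
    xyz = unproject_depth(depth, C2W, fxfycxcy)
    assert xyz.shape == (B, V, 3, H, W)
    torch.testing.assert_close(xyz[:, :, 2], depth)  # z == depth under an identity pose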

def plucker_ray(h: int, w: int, C2W: Tensor, fxfycxcy: Tensor, bug: bool = True) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    """Get Plucker ray embeddings.

    Inputs:
        - `h`: image height
        - `w`: image width
        - `C2W`: (B, V, 4, 4)
        - `fxfycxcy`: (B, V, 4)
        - `bug`: whether to reproduce the known off-by-half pixel grid bug (see the comment below)

    Outputs:
        - `plucker`: (B, V, 6, `h`, `w`)
        - `ray_o`: (B, V, 3, `h`, `w`)
        - `ray_d`: (B, V, 3, `h`, `w`)
    """
    device, dtype = C2W.device, C2W.dtype
    B, V = C2W.shape[:2]
    C2W = C2W.reshape(B*V, 4, 4).float()
    fxfycxcy = fxfycxcy.reshape(B*V, 4).float()

    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")  # OpenCV/COLMAP camera convention
    y, x = y.to(device), x.to(device)
    if bug:  # BUG!!! the same one as in https://github.com/camenduru/GRM/blob/master/model/visual_encoder/vit_gs.py#L85
        y = y[None, :, :].expand(B*V, -1, -1).reshape(B*V, -1) / (h - 1)
        x = x[None, :, :].expand(B*V, -1, -1).reshape(B*V, -1) / (w - 1)
        x = (x + 0.5 - fxfycxcy[:, 2:3]) / fxfycxcy[:, 0:1]
        y = (y + 0.5 - fxfycxcy[:, 3:4]) / fxfycxcy[:, 1:2]
    else:
        y = (y[None, :, :].expand(B*V, -1, -1).reshape(B*V, -1) + 0.5) / h
        x = (x[None, :, :].expand(B*V, -1, -1).reshape(B*V, -1) + 0.5) / w
        x = (x - fxfycxcy[:, 2:3]) / fxfycxcy[:, 0:1]
        y = (y - fxfycxcy[:, 3:4]) / fxfycxcy[:, 1:2]
    z = torch.ones_like(x)

    ray_d = torch.stack([x, y, z], dim=2)  # (B*V, h*w, 3)
    ray_d = torch.bmm(ray_d, C2W[:, :3, :3].transpose(1, 2))  # (B*V, h*w, 3); rotate camera-space directions into world space
    ray_d = ray_d / torch.norm(ray_d, dim=2, keepdim=True)  # (B*V, h*w, 3)
    ray_o = C2W[:, :3, 3][:, None, :].expand_as(ray_d)  # (B*V, h*w, 3); camera centers

    ray_o = ray_o.reshape(B, V, h, w, 3).permute(0, 1, 4, 2, 3)
    ray_d = ray_d.reshape(B, V, h, w, 3).permute(0, 1, 4, 2, 3)
    plucker = torch.cat([torch.cross(ray_o, ray_d, dim=2).to(dtype), ray_d.to(dtype)], dim=2)
    return plucker, (ray_o, ray_d)
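
# A minimal usage sketch (demo helper, NOT part of the original module): ray
# directions are unit-norm and, for an identity pose, all rays start at the origin.
def _demo_plucker_ray() -> None:
    B, V, h, w = 1, 2, 8, 8
    C2W = torch.eye(4).repeat(B, V, 1, 1)
    fxfycxcy = torch.tensor([1., 1., 0.5, 0.5]).repeat(B, V, 1)  # normalized intrinsics
    plucker, (ray_o, ray_d) = plucker_ray(h, w, C2W, fxfycxcy, bug=False)
    assert plucker.shape == (B, V, 6, h, w)
    torch.testing.assert_close(ray_d.norm(dim=2), torch.ones(B, V, h, w))  # unit directions
    assert ray_o.abs().max() == 0.  # identity pose: all ray origins at the world origin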

def orbit_camera(
    elevs: Tensor, azims: Tensor, radius: Optional[Tensor] = None,
    is_degree: bool = True,
    target: Optional[Tensor] = None,
    opengl: bool = True,
) -> Tensor:
    """Construct camera pose matrices orbiting a target with elevation & azimuth angles.

    Inputs:
        - `elevs`: (B,); elevation in (-90, 90); sweeping from -90 to 90 moves the camera from +y to -y
        - `azims`: (B,); azimuth in (-180, 180); sweeping from 0 to 90 moves the camera from +z to +x
        - `radius`: (B,); camera radius; if None, defaults to 1
        - `is_degree`: bool; whether the input angles are in degrees
        - `target`: (B, 3); look-at target position
        - `opengl`: bool; whether to use the OpenGL convention

    Outputs:
        - `C2W`: (B, 4, 4); camera pose matrices
    """
    device, dtype = elevs.device, elevs.dtype

    if radius is None:
        radius = torch.ones_like(elevs)
    assert elevs.shape == azims.shape == radius.shape
    if target is None:
        target = torch.zeros(elevs.shape[0], 3, device=device, dtype=dtype)
    if is_degree:
        elevs = torch.deg2rad(elevs)
        azims = torch.deg2rad(azims)

    x = radius * torch.cos(elevs) * torch.sin(azims)
    y = -radius * torch.sin(elevs)
    z = radius * torch.cos(elevs) * torch.cos(azims)
    camposes = torch.stack([x, y, z], dim=1) + target  # (B, 3)

    R = look_at(camposes, target, opengl=opengl)  # (B, 3, 3)
    C2W = torch.cat([R, camposes[:, :, None]], dim=2)  # (B, 3, 4)
    C2W = torch.cat([C2W, torch.zeros_like(C2W[:, :1, :])], dim=1)  # (B, 4, 4)
    C2W[:, 3, 3] = 1.
    return C2W

def look_at(camposes: Tensor, targets: Tensor, opengl: bool = True) -> Tensor:
    """Construct batched pose rotation matrices by look-at.

    Inputs:
        - `camposes`: (B, 3); camera positions
        - `targets`: (B, 3); look-at targets
        - `opengl`: whether to use the OpenGL convention

    Outputs:
        - `R`: (B, 3, 3); normalized camera pose rotation matrices
    """
    device, dtype = camposes.device, camposes.dtype

    if not opengl:  # OpenCV convention
        # Forward points from camera to target
        forward_vectors = tF.normalize(targets - camposes, dim=-1)
        up_vectors = torch.tensor([0., 1., 0.], device=device, dtype=dtype)[None, :].expand_as(forward_vectors)
        right_vectors = tF.normalize(torch.cross(forward_vectors, up_vectors, dim=-1), dim=-1)
        up_vectors = tF.normalize(torch.cross(right_vectors, forward_vectors, dim=-1), dim=-1)
    else:
        # Forward points from target to camera
        forward_vectors = tF.normalize(camposes - targets, dim=-1)
        up_vectors = torch.tensor([0., 1., 0.], device=device, dtype=dtype)[None, :].expand_as(forward_vectors)
        right_vectors = tF.normalize(torch.cross(up_vectors, forward_vectors, dim=-1), dim=-1)
        up_vectors = tF.normalize(torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1)
    R = torch.stack([right_vectors, up_vectors, forward_vectors], dim=-1)
    return R
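
# A minimal usage sketch (demo helper, NOT part of the original module): at
# zero elevation and azimuth the camera sits on +z at the given radius, and in
# the OpenGL convention its forward column (+z axis) points away from the target.
def _demo_orbit_camera() -> None:
    B = 3
    C2W = orbit_camera(torch.zeros(B), torch.zeros(B), radius=torch.full((B,), 2.))
    torch.testing.assert_close(C2W[:, :3, 3], torch.tensor([0., 0., 2.]).expand(B, 3))  # camera center
    torch.testing.assert_close(C2W[:, :3, 2], torch.tensor([0., 0., 1.]).expand(B, 3))  # forward = +z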