# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import cv2
from pytorch3d.structures import join_meshes_as_scene, join_meshes_as_batch, Meshes
from pytorch3d.renderer import (
    FoVPerspectiveCameras, look_at_view_transform,
    RasterizationSettings, MeshRenderer, MeshRasterizer,
    SoftPhongShader, PointLights, BlendParams, SoftSilhouetteShader
)

from utils.loss_utils import compute_visibility_mask_igl


def create_camera_from_blender_params(cam_params, device):
    """
    Convert Blender camera parameters to a PyTorch3D camera.

    Args:
        cam_params (dict): Camera parameters from a Blender JSON export,
            containing 'matrix_world' (4x4 camera-to-world transform) and
            'x_fov' (horizontal field of view in radians).
        device: Device to create the camera on.

    Returns:
        FoVPerspectiveCameras: The converted camera.
    """
    # Camera-to-world transform exported by Blender
    matrix_world = torch.tensor(cam_params['matrix_world'], dtype=torch.float32)

    # Field of view (use x_fov, assuming a symmetric FOV), converted from
    # radians to the degrees expected by FoVPerspectiveCameras
    fov = cam_params['x_fov'] * 180 / np.pi

    # Rotate -90 degrees about X to map Blender's Z-up world to the Y-up
    # convention used here
    rotation_matrix = torch.tensor([
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, -1, 0, 0],
        [0, 0, 0, 1]
    ], dtype=torch.float32)

    # Apply the axis change, then invert to obtain a world-to-camera transform
    adjusted_matrix = rotation_matrix @ matrix_world
    world2cam_matrix_tensor = torch.linalg.inv(adjusted_matrix)

    # Flip X and Z to convert from Blender's camera convention (looking down
    # -Z) to PyTorch3D's (looking down +Z)
    aligned_matrix = torch.tensor([
        [-1.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, -1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0]
    ], dtype=torch.float32, device=device)
    world2cam_matrix = aligned_matrix @ world2cam_matrix_tensor.to(device)
    cam2world_matrix = torch.linalg.inv(world2cam_matrix)

    # PyTorch3D uses a row-vector convention, so R is the transpose of the
    # world-to-camera rotation, i.e. the camera-to-world rotation block
    R = cam2world_matrix[:3, :3]
    T = torch.tensor([
        world2cam_matrix[0, 3],
        world2cam_matrix[1, 3],
        world2cam_matrix[2, 3]
    ], device=device, dtype=torch.float32)

    return FoVPerspectiveCameras(
        device=device,
        fov=fov,
        R=R[None],
        T=T[None],
        znear=0.1,
        zfar=100.0
    )
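

# Example: building a camera from a Blender-style dict. The field values below
# are hypothetical placeholders, not real calibration data:
#
#     cam_params = {
#         "matrix_world": [[1.0, 0.0, 0.0, 0.0],
#                          [0.0, 1.0, 0.0, -3.0],
#                          [0.0, 0.0, 1.0, 1.5],
#                          [0.0, 0.0, 0.0, 1.0]],
#         "x_fov": 0.6911,  # ~39.6 degrees, expressed in radians
#     }
#     camera = create_camera_from_blender_params(cam_params, device="cuda")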


class MeshRenderer3D:
    """
    PyTorch3D mesh renderer with support for various rendering modes.

    Features:
    - Standard mesh rendering with Phong shading
    - Silhouette rendering
    - Multi-frame batch rendering
    - Point projection with visibility computation
    """

    def __init__(self, device, image_size=1024, cam_params=None, light_params=None, raster_params=None):
        self.device = device

        # Initialize camera
        self.camera = self._setup_camera(cam_params)

        # Initialize light
        self.light = self._setup_light(light_params)

        # Initialize rasterization settings
        self.raster_settings = self._setup_raster_settings(raster_params, image_size)
        self.camera.image_size = self.raster_settings.image_size

        # Initialize renderers
        self._setup_renderers()

    def _setup_camera(self, cam_params):
        """Set up the camera from parameters, falling back to a default view."""
        if cam_params is None:
            # Default camera
            R, T = look_at_view_transform(3.0, 30, 20, at=[[0.0, 1.0, 0.0]])
            return FoVPerspectiveCameras(device=self.device, R=R, T=T)

        # Check for Blender parameters
        if "matrix_world" in cam_params and "x_fov" in cam_params:
            return create_camera_from_blender_params(cam_params, self.device)
        else:
            raise ValueError(
                "cam_params must be a Blender-style dict with 'matrix_world' and 'x_fov' keys."
            )

    def _setup_light(self, light_params):
        """Set up the light source."""
        if light_params is None:
            return PointLights(device=self.device, location=[[0.0, 0.0, 3.0]])

        location = [[
            light_params.get('light_x', 0.0),
            light_params.get('light_y', 0.0),
            light_params.get('light_z', 3.0)
        ]]
        return PointLights(device=self.device, location=location)

    def _setup_raster_settings(self, raster_params, default_size):
        """Set up rasterization settings."""
        if raster_params is None:
            raster_params = {
                "image_size": [default_size, default_size],
                "blur_radius": 0.0,
                "faces_per_pixel": 1,
                "bin_size": 0,
                "cull_backfaces": False
            }
        return RasterizationSettings(**raster_params)

    def _setup_renderers(self) -> None:
        """Initialize the main and silhouette renderers."""
        rasterizer = MeshRasterizer(
            cameras=self.camera,
            raster_settings=self.raster_settings
        )

        # Main renderer with Phong shading
        self.renderer = MeshRenderer(
            rasterizer=rasterizer,
            shader=SoftPhongShader(
                device=self.device,
                cameras=self.camera,
                lights=self.light
            )
        )

        # Silhouette renderer; sigma controls edge softness and gamma the
        # opacity falloff in the soft silhouette blending
        blend_params = BlendParams(
            sigma=1e-4,
            gamma=1e-4,
            background_color=(0.0, 0.0, 0.0)
        )
        self.silhouette_renderer = MeshRenderer(
            rasterizer=rasterizer,
            shader=SoftSilhouetteShader(blend_params=blend_params)
        )
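
    # Example: constructing the renderer with custom light and rasterization
    # settings. The values below are illustrative, not tuned defaults:
    #
    #     renderer = MeshRenderer3D(
    #         device=torch.device("cuda"),
    #         light_params={"light_x": 1.0, "light_y": 2.0, "light_z": 3.0},
    #         raster_params={
    #             "image_size": [512, 512],
    #             "blur_radius": 0.0,
    #             "faces_per_pixel": 1,
    #             "bin_size": 0,
    #             "cull_backfaces": False,
    #         },
    #     )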

    def render(self, meshes):
        """
        Render meshes with Phong shading.

        Args:
            meshes: Single mesh or list of meshes

        Returns:
            Rendered images tensor of shape (1, H, W, C)
        """
        scene_mesh = self._prepare_scene_mesh(meshes)
        return self.renderer(scene_mesh)

    def _prepare_batch_mesh(self, mesh_list):
        """Join a list of per-frame meshes into a single batched Meshes object."""
        assert isinstance(mesh_list, list)
        return join_meshes_as_batch(
            [self._prepare_scene_mesh(frame_meshes) for frame_meshes in mesh_list]
        )

    def render_batch(self, mesh_list):
        """
        Render multiple frames as a batch.

        Args:
            mesh_list: List of mesh lists (one per frame)

        Returns:
            Batch of rendered images of shape (B, H, W, C)
        """
        return self.renderer(self._prepare_batch_mesh(mesh_list))

    def get_rasterization_fragments(self, mesh_list):
        """
        Get rasterization fragments for a batch of meshes.

        Args:
            mesh_list: List of mesh lists (one per frame)

        Returns:
            Rasterization fragments
        """
        return self.renderer.rasterizer(self._prepare_batch_mesh(mesh_list))

    def render_silhouette_batch(self, mesh_list):
        """
        Render silhouette masks for multiple frames.

        Args:
            mesh_list: List of mesh lists (one per frame)

        Returns:
            Batch of silhouette masks of shape (B, H, W, 1)
        """
        silhouette = self.silhouette_renderer(self._prepare_batch_mesh(mesh_list))
        return silhouette[..., 3:]  # Return alpha channel
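
    # Example: batch rendering usage, assuming `frame0_meshes` and
    # `frame1_meshes` are lists of Meshes for two frames (hypothetical names):
    #
    #     images = renderer.render_batch([frame0_meshes, frame1_meshes])            # (2, H, W, C)
    #     masks = renderer.render_silhouette_batch([frame0_meshes, frame1_meshes])  # (2, H, W, 1)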

    def tensor_to_image(self, tensor):
        """
        Convert a rendered tensor to a numpy image array.

        Args:
            tensor: Rendered tensor of shape (B, H, W, C)

        Returns:
            Numpy array of shape (H, W, 3) with values in [0, 255]
        """
        return (tensor[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)

    def project_points(self, points_3d):
        """
        Project 3D joints/vertices onto the 2D image plane.

        Args:
            points_3d: shape (N, 3) or (B, N, 3) tensor of 3D points

        Returns:
            points_2d: shape (N, 2) or (B, N, 2) tensor of 2D projected points
        """
        if not torch.is_tensor(points_3d):
            points_3d = torch.tensor(points_3d, device=self.device, dtype=torch.float32)

        # Track whether a batch dimension was added so the output shape
        # mirrors the input shape
        squeeze_output = points_3d.dim() == 2
        if squeeze_output:
            points_3d = points_3d.unsqueeze(0)  # (1, N, 3)

        # Project to screen coordinates and keep only the (x, y) components
        projected = self.camera.transform_points_screen(points_3d, image_size=self.raster_settings.image_size)
        points_2d = projected[..., :2]
        return points_2d.squeeze(0) if squeeze_output else points_2d
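
    # Example: projecting joints, assuming `joints_3d` is an (N, 3) tensor in
    # world coordinates (hypothetical name):
    #
    #     joints_2d = renderer.project_points(joints_3d)  # (N, 2) pixel coordinates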

    def render_with_points(self, meshes, points_3d, point_radius=3, for_vertices=False):
        """
        Render the mesh and visualize the joints/vertices on the image.

        Args:
            meshes: Mesh or list of meshes to be rendered
            points_3d: shape (N, 3) tensor of 3D joints/vertices
            point_radius: Radius of the drawn points
            for_vertices: If True, compute visibility for vertices, else for joints

        Returns:
            Image with joints/vertices drawn, visibility mask
        """
        rendered_image = self.render(meshes)

        # Project 3D points to 2D
        points_2d = self.project_points(points_3d)

        image_np = rendered_image[0, ..., :3].cpu().numpy()
        image_with_points = image_np.copy()
        height, width = image_np.shape[:2]

        # Cast one ray from the camera center to each point and test for
        # occlusion against the scene mesh
        scene_mesh = self._prepare_scene_mesh(meshes)
        ray_origins = self.camera.get_camera_center()  # (1, 3)
        ray_origins = np.tile(ray_origins.detach().cpu().numpy(), (points_3d.shape[0], 1))
        verts = scene_mesh.verts_packed().detach().cpu().numpy()
        faces = scene_mesh.faces_packed().detach().cpu().numpy()
        ray_dirs = points_3d.detach().cpu().numpy() - ray_origins  # ray directions
        distances = np.linalg.norm(ray_dirs, axis=1)  # distances from camera to points
        ray_dirs = (ray_dirs.T / distances).T  # normalize to unit vectors
        vis_mask = compute_visibility_mask_igl(ray_origins, ray_dirs, distances, verts, faces, distance_tolerance=1e-6, for_vertices=for_vertices)

        # Draw points: visible points in red, occluded points in blue
        visible_color = (1, 0, 0)
        invisible_color = (0, 0, 1)
        for i, point in enumerate(points_2d):
            x, y = int(point[0].item()), int(point[1].item())
            if 0 <= x < width and 0 <= y < height:
                point_color = visible_color if vis_mask[i] else invisible_color
                cv2.circle(image_with_points, (x, y), point_radius, point_color, -1)

        result = torch.from_numpy(image_with_points).to(self.device)
        result = result.unsqueeze(0)
        if rendered_image.shape[-1] == 4:
            alpha = rendered_image[..., 3:]
            result = torch.cat([result, alpha], dim=-1)
        return result, vis_mask

    def _prepare_scene_mesh(self, meshes):
        """Convert a mesh or list of meshes to a single scene mesh."""
        if isinstance(meshes, Meshes):
            return meshes
        elif isinstance(meshes, list):
            return join_meshes_as_scene(meshes)
        else:
            raise ValueError("meshes must be a Meshes object or a list of Meshes")
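

if __name__ == "__main__":
    # Minimal smoke test: render a single white triangle with the default
    # camera and light. This is an illustrative sketch; real usage would load
    # textured meshes and pass Blender camera parameters.
    from pytorch3d.renderer import TexturesVertex

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A single triangle near the default camera's look-at point (0, 1, 0)
    verts = torch.tensor([[[0.0, 0.5, 0.0], [-0.5, 1.5, 0.0], [0.5, 1.5, 0.0]]], device=device)
    faces = torch.tensor([[[0, 1, 2]]], dtype=torch.int64, device=device)
    textures = TexturesVertex(verts_features=torch.ones_like(verts))
    mesh = Meshes(verts=verts, faces=faces, textures=textures)

    renderer = MeshRenderer3D(device=device, image_size=256)
    image = renderer.tensor_to_image(renderer.render(mesh))  # (256, 256, 3) uint8
    print(image.shape)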