# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import torch
import torch.nn as nn
from addict import Dict
from omegaconf import DictConfig, OmegaConf

from depth_anything_3.cfg import create_object
from depth_anything_3.model.utils.transform import pose_encoding_to_extri_intri
from depth_anything_3.utils.alignment import (
    apply_metric_scaling,
    compute_alignment_mask,
    compute_sky_mask,
    least_squares_scale_scalar,
    sample_tensor_for_quantile,
    set_sky_regions_to_max_depth,
)
from depth_anything_3.utils.geometry import affine_inverse, as_homogeneous, map_pdf_to_opacity


def _wrap_cfg(cfg_obj):
    return OmegaConf.create(cfg_obj)


class DepthAnything3Net(nn.Module):
    """
    Depth Anything 3 network for depth estimation and camera pose estimation.

    This network consists of:
    - Backbone: DINOv2 feature extractor
    - Head: DPT or DualDPT for depth prediction
    - Optional camera encoder/decoder for pose conditioning and estimation
    - Optional GSDPT head for 3DGS prediction

    Args:
        net: Backbone module, or a config resolvable by `create_object`
        head: Depth head module, or a config resolvable by `create_object`
        cam_dec: Optional camera decoder (module or config) for pose estimation
        cam_enc: Optional camera encoder (module or config) for pose conditioning
        gs_head: Optional 3DGS head (module or config)
        gs_adapter: Optional 3DGS adapter (module or config)

    Returns (from `forward`):
        Dictionary containing:
        - depth: Predicted depth map (B, H, W)
        - depth_conf: Depth confidence map (B, H, W)
        - extrinsics: Camera extrinsics (B, N, 4, 4)
        - intrinsics: Camera intrinsics (B, N, 3, 3)
        - gaussians: 3D Gaussian Splats (world space), type: model.gs_adapter.Gaussians
        - aux: Auxiliary features for specified layers
    """

    # Patch size for feature extraction
    PATCH_SIZE = 14

    def __init__(self, net, head, cam_dec=None, cam_enc=None, gs_head=None, gs_adapter=None):
        """
        Initialize DepthAnything3Net from modules or their yaml-initialized configurations.
        """
        super().__init__()
        self.backbone = net if isinstance(net, nn.Module) else create_object(_wrap_cfg(net))
        self.head = head if isinstance(head, nn.Module) else create_object(_wrap_cfg(head))
        self.cam_dec, self.cam_enc = None, None
        if cam_dec is not None:
            self.cam_dec = (
                cam_dec if isinstance(cam_dec, nn.Module) else create_object(_wrap_cfg(cam_dec))
            )
            self.cam_enc = (
                cam_enc if isinstance(cam_enc, nn.Module) else create_object(_wrap_cfg(cam_enc))
            )
        self.gs_adapter, self.gs_head = None, None
        if gs_head is not None and gs_adapter is not None:
            self.gs_adapter = (
                gs_adapter
                if isinstance(gs_adapter, nn.Module)
                else create_object(_wrap_cfg(gs_adapter))
            )
            gs_out_dim = self.gs_adapter.d_in + 1
            if isinstance(gs_head, nn.Module):
                assert (
                    gs_head.out_dim == gs_out_dim
                ), f"gs_head.out_dim should be {gs_out_dim}, got {gs_head.out_dim}"
                self.gs_head = gs_head
            else:
                assert (
                    gs_head["output_dim"] == gs_out_dim
                ), f"gs_head output_dim should be set to {gs_out_dim}, got {gs_head['output_dim']}"
                self.gs_head = create_object(_wrap_cfg(gs_head))

    def forward(
        self,
        x: torch.Tensor,
        extrinsics: torch.Tensor | None = None,
        intrinsics: torch.Tensor | None = None,
        export_feat_layers: list[int] | None = None,
        infer_gs: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the network.

        Args:
            x: Input images (B, N, 3, H, W)
            extrinsics: Optional camera extrinsics (B, N, 4, 4); when provided, they are
                encoded into a camera token that conditions the backbone
            intrinsics: Optional camera intrinsics (B, N, 3, 3), encoded together with
                the extrinsics
            export_feat_layers: List of backbone layer indices to export auxiliary
                features from
            infer_gs: Whether to additionally predict 3D Gaussian Splats

        Returns:
            Dictionary containing predictions and auxiliary features
        """
        # `None` means no auxiliary feature export
        export_feat_layers = export_feat_layers if export_feat_layers is not None else []
        # Encode known camera poses into a token that conditions the backbone
        if extrinsics is not None:
            with torch.autocast(device_type=x.device.type, enabled=False):
                cam_token = self.cam_enc(extrinsics, intrinsics, x.shape[-2:])
        else:
            cam_token = None
        # Extract features using backbone
        feats, aux_feats = self.backbone(
            x, cam_token=cam_token, export_feat_layers=export_feat_layers
        )
        # feats = [[item for item in feat] for feat in feats]
        H, W = x.shape[-2], x.shape[-1]
        # Process features through depth head
        with torch.autocast(device_type=x.device.type, enabled=False):
            output = self._process_depth_head(feats, H, W)
            output = self._process_camera_estimation(feats, H, W, output)
            if infer_gs:
                output = self._process_gs_head(feats, H, W, output, x, extrinsics, intrinsics)
        # Extract auxiliary features if requested
        output.aux = self._extract_auxiliary_features(aux_feats, export_feat_layers, H, W)
        return output
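
    # Example (hedged): out = net(x, export_feat_layers=[5, 11]) would populate
    # out.aux["feat_layer_5"] and out.aux["feat_layer_11"], reshaped to spatial
    # (H // 14, W // 14) grids by _extract_auxiliary_features below.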

    def _process_depth_head(
        self, feats: list[torch.Tensor], H: int, W: int
    ) -> Dict[str, torch.Tensor]:
        """Process features through the depth prediction head."""
        return self.head(feats, H, W, patch_start_idx=0)

    def _process_camera_estimation(
        self, feats: list[torch.Tensor], H: int, W: int, output: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Process camera pose estimation if the camera decoder is available."""
        if self.cam_dec is not None:
            pose_enc = self.cam_dec(feats[-1][1])
            # Remove ray information as it's not needed for pose estimation
            if "ray" in output:
                del output.ray
            if "ray_conf" in output:
                del output.ray_conf
            # Convert pose encoding to extrinsics and intrinsics
            c2w, ixt = pose_encoding_to_extri_intri(pose_enc, (H, W))
            output.extrinsics = affine_inverse(c2w)
            output.intrinsics = ixt
        return output
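
    # Note (a hedged reading from the naming): pose_encoding_to_extri_intri appears to
    # return camera-to-world poses (`c2w`), so the affine_inverse above stores
    # world-to-camera matrices in output.extrinsics.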

    def _process_gs_head(
        self,
        feats: list[torch.Tensor],
        H: int,
        W: int,
        output: Dict[str, torch.Tensor],
        in_images: torch.Tensor,
        extrinsics: torch.Tensor | None = None,
        intrinsics: torch.Tensor | None = None,
    ) -> Dict[str, torch.Tensor]:
        """Process 3DGS parameter estimation if the 3DGS head is available."""
        if self.gs_head is None or self.gs_adapter is None:
            return output
        assert output.get("depth", None) is not None, "must provide MV depth for the GS head."
        # If GT camera poses are provided, use them
        if extrinsics is not None and intrinsics is not None:
            ctx_extr = extrinsics
            ctx_intr = intrinsics
        else:
            ctx_extr = output.get("extrinsics", None)
            ctx_intr = output.get("intrinsics", None)
            assert (
                ctx_extr is not None and ctx_intr is not None
            ), "must process camera info first if GT is not available"
        gt_extr = extrinsics
        # Homogenize the extrinsics (to 4x4) if needed
        ctx_extr = as_homogeneous(ctx_extr)
        if gt_extr is not None:
            gt_extr = as_homogeneous(gt_extr)
        # Forward through the GS DPT head to get camera-space parameters
        gs_outs = self.gs_head(
            feats=feats,
            H=H,
            W=W,
            patch_start_idx=0,
            images=in_images,
        )
        raw_gaussians = gs_outs.raw_gs
        densities = gs_outs.raw_gs_conf
        # Convert to world-space 3DGS parameters, ready to export and render.
        # gt_extr may be None; when available it is used to align the pose scale.
        gs_world = self.gs_adapter(
            extrinsics=ctx_extr,
            intrinsics=ctx_intr,
            depths=output.depth,
            opacities=map_pdf_to_opacity(densities),
            raw_gaussians=raw_gaussians,
            image_shape=(H, W),
            gt_extrinsics=gt_extr,
        )
        output.gaussians = gs_world
        return output
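
    # Sketch (illustrative): for a GS-enabled model, out = net(images, infer_gs=True)
    # adds out.gaussians, i.e. world-space splats that the adapter builds from the
    # per-pixel depth, the raw Gaussian parameters, and densities turned into
    # opacities via map_pdf_to_opacity.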

    def _extract_auxiliary_features(
        self, feats: list[torch.Tensor], feat_layers: list[int], H: int, W: int
    ) -> Dict[str, torch.Tensor]:
        """Extract auxiliary features from specified layers."""
        aux_features = Dict()
        assert len(feats) == len(feat_layers)
        for feat, feat_layer in zip(feats, feat_layers):
            # Reshape features to spatial dimensions
            feat_reshaped = feat.reshape(
                [
                    feat.shape[0],
                    feat.shape[1],
                    H // self.PATCH_SIZE,
                    W // self.PATCH_SIZE,
                    feat.shape[-1],
                ]
            )
            aux_features[f"feat_layer_{feat_layer}"] = feat_reshaped
        return aux_features
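
    # Worked example (assuming token features shaped (B, N, tokens, C)): with
    # H = W = 518 and PATCH_SIZE = 14, each view has 37 * 37 = 1369 patch tokens,
    # and the reshape above yields (B, N, 37, 37, C).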


class NestedDepthAnything3Net(nn.Module):
    """
    Nested Depth Anything 3 network with metric scaling capabilities.

    This network combines two DepthAnything3Net branches:
    - Main (any-view) branch: standard depth estimation
    - Metric branch: metric depth estimation for scaling alignment

    The network aligns depth across the branches using least-squares scaling
    and handles sky-region masking for improved depth estimation.

    Args:
        anyview: Configuration for the main depth estimation branch
        metric: Configuration for the metric depth branch
    """

    def __init__(self, anyview: DictConfig, metric: DictConfig):
        """
        Initialize NestedDepthAnything3Net with two branches.

        Args:
            anyview: Configuration for the main depth estimation branch
            metric: Configuration for the metric depth branch
        """
        super().__init__()
        self.da3 = create_object(anyview)
        self.da3_metric = create_object(metric)

    def forward(
        self,
        x: torch.Tensor,
        extrinsics: torch.Tensor | None = None,
        intrinsics: torch.Tensor | None = None,
        export_feat_layers: list[int] | None = None,
        infer_gs: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through both branches with metric scaling alignment.

        Args:
            x: Input images (B, N, 3, H, W)
            extrinsics: Optional camera extrinsics (B, N, 4, 4), forwarded to the main branch
            intrinsics: Optional camera intrinsics (B, N, 3, 3), forwarded to the main branch
            export_feat_layers: List of backbone layer indices to export auxiliary
                features from
            infer_gs: Whether to additionally predict 3D Gaussian Splats

        Returns:
            Dictionary containing aligned depth predictions and camera parameters
        """
        # Get predictions from both branches
        output = self.da3(
            x, extrinsics, intrinsics, export_feat_layers=export_feat_layers, infer_gs=infer_gs
        )
        metric_output = self.da3_metric(x, infer_gs=infer_gs)
        # Apply metric scaling and alignment
        output = self._apply_metric_scaling(output, metric_output)
        output = self._apply_depth_alignment(output, metric_output)
        output = self._handle_sky_regions(output, metric_output)
        return output

    def _apply_metric_scaling(
        self, output: Dict[str, torch.Tensor], metric_output: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Apply metric scaling to the metric depth output."""
        # Scale metric depth based on camera intrinsics
        metric_output.depth = apply_metric_scaling(
            metric_output.depth,
            output.intrinsics,
        )
        return output

    def _apply_depth_alignment(
        self, output: Dict[str, torch.Tensor], metric_output: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Apply depth alignment using least squares scaling."""
        # Compute non-sky mask
        non_sky_mask = compute_sky_mask(metric_output.sky, threshold=0.3)
        # Ensure we have enough non-sky pixels
        assert non_sky_mask.sum() > 10, "Insufficient non-sky pixels for alignment"
        # Sample depth confidence for quantile computation
        depth_conf_ns = output.depth_conf[non_sky_mask]
        depth_conf_sampled = sample_tensor_for_quantile(depth_conf_ns, max_samples=100000)
        median_conf = torch.quantile(depth_conf_sampled, 0.5)
        # Compute alignment mask
        align_mask = compute_alignment_mask(
            output.depth_conf, non_sky_mask, output.depth, metric_output.depth, median_conf
        )
        # Compute scale factor using least squares
        valid_depth = output.depth[align_mask]
        valid_metric_depth = metric_output.depth[align_mask]
        scale_factor = least_squares_scale_scalar(valid_metric_depth, valid_depth)
        # Apply scaling to depth and extrinsics
        output.depth *= scale_factor
        output.extrinsics[:, :, :3, 3] *= scale_factor
        output.is_metric = 1
        output.scale_factor = scale_factor.item()
        return output
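
    # For reference: the scalar s minimizing || s * d - d_metric ||^2 has the closed
    # form s = <d, d_metric> / <d, d>; whether least_squares_scale_scalar uses exactly
    # this form or a robust variant is up to depth_anything_3.utils.alignment (see the
    # __main__ demo at the end of this file).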

    def _handle_sky_regions(
        self,
        output: Dict[str, torch.Tensor],
        metric_output: Dict[str, torch.Tensor],
        sky_depth_def: float = 200.0,
    ) -> Dict[str, torch.Tensor]:
        """Handle sky regions by setting them to maximum depth."""
        non_sky_mask = compute_sky_mask(metric_output.sky, threshold=0.3)
        # Compute maximum depth for non-sky regions
        # Use sampling to safely compute quantile on large tensors
        non_sky_depth = output.depth[non_sky_mask]
        if non_sky_depth.numel() > 100000:
            idx = torch.randint(0, non_sky_depth.numel(), (100000,), device=non_sky_depth.device)
            sampled_depth = non_sky_depth[idx]
        else:
            sampled_depth = non_sky_depth
        non_sky_max = min(torch.quantile(sampled_depth, 0.99), sky_depth_def)
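        # i.e. sky pixels will be pushed to the 99th-percentile scene depth or to
        # sky_depth_def (200.0 by default), whichever is smaller, so sky does not
        # blow up the depth range.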
        # Set sky regions to maximum depth and high confidence
        output.depth, output.depth_conf = set_sky_regions_to_max_depth(
            output.depth, output.depth_conf, non_sky_mask, max_depth=non_sky_max
        )
        return output
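

if __name__ == "__main__":
    # Hedged smoke demo (not part of the library): check that the closed-form scalar
    # least-squares scale s = <d, m> / <d, d> recovers a known factor, mirroring the
    # alignment step in NestedDepthAnything3Net._apply_depth_alignment. The library's
    # least_squares_scale_scalar may differ in robustness details; this only
    # illustrates the idea.
    torch.manual_seed(0)
    d = torch.rand(10_000) + 0.1       # stand-in for relative depth predictions
    m = 2.5 * d                        # stand-in "metric" depth, scaled by 2.5
    s = (d * m).sum() / (d * d).sum()  # closed-form scalar least squares
    print(f"recovered scale: {s.item():.4f} (expected 2.5)")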