# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F # ----------------------------------------------------------------------------- # Activation functions # ----------------------------------------------------------------------------- def activate_head_gs(out, activation="norm_exp", conf_activation="expp1", conf_dim=None): """ Process network output to extract GS params and density values. Density could be view-dependent as SH coefficient Args: out: Network output tensor (B, C, H, W) activation: Activation type for 3D points conf_activation: Activation type for confidence values Returns: Tuple of (3D points tensor, confidence tensor) """ # Move channels from last dim to the 4th dimension => (B, H, W, C) fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected # Split into xyz (first C-1 channels) and confidence (last channel) conf_dim = 1 if conf_dim is None else conf_dim xyz = fmap[:, :, :, :-conf_dim] conf = fmap[:, :, :, -1] if conf_dim == 1 else fmap[:, :, :, -conf_dim:] if activation == "norm_exp": d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) xyz_normed = xyz / d pts3d = xyz_normed * torch.expm1(d) elif activation == "norm": pts3d = xyz / xyz.norm(dim=-1, keepdim=True) elif activation == "exp": pts3d = torch.exp(xyz) elif activation == "relu": pts3d = F.relu(xyz) elif activation == "sigmoid": pts3d = torch.sigmoid(xyz) elif activation == "linear": pts3d = xyz else: raise ValueError(f"Unknown activation: {activation}") if conf_activation == "expp1": conf_out = 1 + conf.exp() elif conf_activation == "expp0": conf_out = conf.exp() elif conf_activation == "sigmoid": conf_out = torch.sigmoid(conf) elif conf_activation == "linear": conf_out = conf else: raise ValueError(f"Unknown conf_activation: {conf_activation}") return pts3d, conf_out # ----------------------------------------------------------------------------- # Other utilities # ----------------------------------------------------------------------------- class Permute(nn.Module): """nn.Module wrapper around Tensor.permute for cleaner nn.Sequential usage.""" dims: Tuple[int, ...] def __init__(self, dims: Tuple[int, ...]) -> None: super().__init__() self.dims = dims def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override] return x.permute(*self.dims) def position_grid_to_embed( pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100 ) -> torch.Tensor: """ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) Args: pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates embed_dim: Output channel dimension for embeddings Returns: Tensor of shape (H, W, embed_dim) with positional embeddings """ H, W, grid_dim = pos_grid.shape assert grid_dim == 2 pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) # Process x and y coordinates separately emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] # Combine and reshape emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] return emb.view(H, W, embed_dim) # [H, W, D] def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: """ This function generates a 1D positional embedding from a given grid using sine and cosine functions. # noqa Args: - embed_dim: The embedding dimension. - pos: The position to generate the embedding from. Returns: - emb: The generated 1D positional embedding. """ assert embed_dim % 2 == 0 omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device) omega /= embed_dim / 2.0 omega = 1.0 / omega_0**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = torch.sin(out) # (M, D/2) emb_cos = torch.cos(out) # (M, D/2) emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) return emb.float() # Inspired by https://github.com/microsoft/moge def create_uv_grid( width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None, ) -> torch.Tensor: """ Create a normalized UV grid of shape (width, height, 2). The grid spans horizontally and vertically according to an aspect ratio, ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right corner is at (x_span, y_span), normalized by the diagonal of the plane. Args: width (int): Number of points horizontally. height (int): Number of points vertically. aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. dtype (torch.dtype, optional): Data type of the resulting tensor. device (torch.device, optional): Device on which the tensor is created. Returns: torch.Tensor: A (width, height, 2) tensor of UV coordinates. """ # Derive aspect ratio if not explicitly provided if aspect_ratio is None: aspect_ratio = float(width) / float(height) # Compute normalized spans for X and Y diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 span_x = aspect_ratio / diag_factor span_y = 1.0 / diag_factor # Establish the linspace boundaries left_x = -span_x * (width - 1) / width right_x = span_x * (width - 1) / width top_y = -span_y * (height - 1) / height bottom_y = span_y * (height - 1) / height # Generate 1D coordinates x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) # Create 2D meshgrid (width x height) and stack into UV uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") uv_grid = torch.stack((uu, vv), dim=-1) return uv_grid # ----------------------------------------------------------------------------- # Interpolation (safe interpolation, avoid INT_MAX overflow) # ----------------------------------------------------------------------------- def custom_interpolate( x: torch.Tensor, size: Union[Tuple[int, int], None] = None, scale_factor: Union[float, None] = None, mode: str = "bilinear", align_corners: bool = True, ) -> torch.Tensor: """ Safe interpolation implementation to avoid INT_MAX overflow in torch.nn.functional.interpolate. """ if size is None: assert scale_factor is not None, "Either size or scale_factor must be provided." size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) INT_MAX = 1610612736 total = size[0] * size[1] * x.shape[0] * x.shape[1] if total > INT_MAX: chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0) outs = [ nn.functional.interpolate(c, size=size, mode=mode, align_corners=align_corners) for c in chunks ] return torch.cat(outs, dim=0).contiguous() return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)