"""Deferred backpropagation for patch-wise Gaussian-splat rendering.

The forward pass renders each patch under torch.no_grad() so no autograd graph
is kept for the full image; the backward pass re-renders each patch with grad
enabled and backpropagates the incoming gradient through it, so only one
patch's graph is alive at a time (the same idea as gradient checkpointing).
"""
import torch
from torch.utils.checkpoint import _get_autocast_kwargs  # private PyTorch helper: snapshots the current autocast settings

from src.models.gs_render.gs_util import render


class DeferredBP(torch.autograd.Function):

    @staticmethod
    def forward(ctx,
                xyz, rgb, scale, rotation, opacity,
                height, width, C2W, fxfycxcy,
                patch_size, gaussian_model,
                znear, zfar,
                bg_color,
                scaling_modifier,
                render_dn,
                ):
| """ Forward rendering. """ | |
| assert (xyz.dim() == 3) and (rgb.dim() == 3) and (scale.dim() == 3) and (rotation.dim() == 3) | |
| assert height % patch_size == 0 and width % patch_size == 0 | |
| ctx.save_for_backward(xyz, rgb, scale, rotation, opacity) # save tensors for backward | |
| ctx.height = height | |
| ctx.width = width | |
| ctx.C2W = C2W | |
| ctx.fxfycxcy = fxfycxcy | |
| ctx.patch_size = patch_size | |
| ctx.gaussian_model = gaussian_model | |
| ctx.znear = znear | |
| ctx.zfar = zfar | |
| ctx.bg_color = bg_color | |
| ctx.scaling_modifier = scaling_modifier | |
| ctx.render_dn = render_dn | |
| ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() | |
| ctx.manual_seeds = [] | |
        with (
            torch.no_grad(),
            torch.autocast("cuda", **ctx.gpu_autocast_kwargs),
            torch.autocast("cpu", **ctx.cpu_autocast_kwargs),
        ):
            device, (B, V) = C2W.device, C2W.shape[:2]
            images = torch.zeros(B, V, 3, height, width, device=device)
            alphas = torch.zeros(B, V, 1, height, width, device=device)
            depths = torch.zeros(B, V, 1, height, width, device=device)
            normals = torch.zeros(B, V, 3, height, width, device=device)
            for i in range(B):  # batch
                ctx.manual_seeds.append([])
                pc = ctx.gaussian_model.set_data(xyz[i], rgb[i], scale[i], rotation[i], opacity[i])
                for j in range(V):  # view
                    fxfycxcy_ij = fxfycxcy[i, j]
                    fx, fy, cx, cy = fxfycxcy_ij[0], fxfycxcy_ij[1], fxfycxcy_ij[2], fxfycxcy_ij[3]
                    for m in range(ctx.width // ctx.patch_size):       # patch column
                        for n in range(ctx.height // ctx.patch_size):  # patch row
                            # Seeds are recorded here but never fed to torch.manual_seed,
                            # so they currently have no effect on the render (see note below).
                            seed = torch.randint(0, 2**32, (1,)).long().item()
                            ctx.manual_seeds[-1].append(seed)
                            # Remap the (normalized) full-image intrinsics to the patch:
                            # cropping to a 1/scale_x-by-1/scale_y window and rescaling it
                            # to a full frame multiplies the focal lengths and shifts the
                            # principal point relative to the patch center.
                            center_x = (m * ctx.patch_size + ctx.patch_size // 2) / ctx.width
                            center_y = (n * ctx.patch_size + ctx.patch_size // 2) / ctx.height
                            scale_x = ctx.width // ctx.patch_size
                            scale_y = ctx.height // ctx.patch_size
                            trans_x = 0.5 - scale_x * center_x
                            trans_y = 0.5 - scale_y * center_y
                            new_fx = scale_x * fx
                            new_fy = scale_y * fy
                            new_cx = scale_x * cx + trans_x
                            new_cy = scale_y * cy + trans_y
                            new_fxfycxcy = torch.stack([new_fx, new_fy, new_cx, new_cy], dim=0)
                            render_results = render(pc, patch_size, patch_size, C2W[i, j], new_fxfycxcy, znear, zfar, bg_color, scaling_modifier, render_dn)
                            images[i, j, :, n*ctx.patch_size:(n+1)*ctx.patch_size, m*ctx.patch_size:(m+1)*ctx.patch_size] = render_results["image"]
                            alphas[i, j, :, n*ctx.patch_size:(n+1)*ctx.patch_size, m*ctx.patch_size:(m+1)*ctx.patch_size] = render_results["alpha"]
                            depths[i, j, :, n*ctx.patch_size:(n+1)*ctx.patch_size, m*ctx.patch_size:(m+1)*ctx.patch_size] = render_results["depth"]
                            normals[i, j, :, n*ctx.patch_size:(n+1)*ctx.patch_size, m*ctx.patch_size:(m+1)*ctx.patch_size] = render_results["normal"]

        return images, alphas, depths, normals
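
    # Worked example of the intrinsics remap above (illustrative numbers, not
    # from the repo): with width = 512 and patch_size = 128, scale_x = 4, and
    # the patch at m = 1 covers normalized x in [0.25, 0.5) with
    # center_x = (128 + 64) / 512 = 0.375, so trans_x = 0.5 - 4 * 0.375 = -1.0.
    # A full-image principal point cx = 0.5 then maps to
    # new_cx = 4 * 0.5 - 1.0 = 1.0, i.e. the right edge of the patch's frame,
    # which is exactly where the full-image center sits relative to that crop.
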
    @staticmethod
    def backward(ctx, grad_images, grad_alphas, grad_depths, grad_normals):
        """Re-render each patch with grad enabled and backprop the incoming grads through it."""
        xyz, rgb, scale, rotation, opacity = ctx.saved_tensors
        # Detached leaf copies: the patch-wise .backward() calls below accumulate
        # into fresh .grad buffers rather than touching any outer graph.
        xyz_nosync = xyz.detach().clone().requires_grad_(True)
        rgb_nosync = rgb.detach().clone().requires_grad_(True)
        scale_nosync = scale.detach().clone().requires_grad_(True)
        rotation_nosync = rotation.detach().clone().requires_grad_(True)
        opacity_nosync = opacity.detach().clone().requires_grad_(True)

        with (
            torch.enable_grad(),
            torch.autocast("cuda", **ctx.gpu_autocast_kwargs),
            torch.autocast("cpu", **ctx.cpu_autocast_kwargs),
        ):
            B, V = ctx.C2W.shape[:2]
            for i in range(B):
                ctx.manual_seeds.append([])
                pc = ctx.gaussian_model.set_data(xyz_nosync[i], rgb_nosync[i], scale_nosync[i], rotation_nosync[i], opacity_nosync[i])
                for j in range(V):
                    fxfycxcy_ij = ctx.fxfycxcy[i, j]
                    fx, fy, cx, cy = fxfycxcy_ij[0], fxfycxcy_ij[1], fxfycxcy_ij[2], fxfycxcy_ij[3]
                    for m in range(ctx.width // ctx.patch_size):
                        for n in range(ctx.height // ctx.patch_size):
                            grad_images_split = grad_images[i, j, :, n*ctx.patch_size:(n+1)*ctx.patch_size, m*ctx.patch_size:(m+1)*ctx.patch_size]
                            grad_alphas_split = grad_alphas[i, j, :, n*ctx.patch_size:(n+1)*ctx.patch_size, m*ctx.patch_size:(m+1)*ctx.patch_size]
                            grad_depths_split = grad_depths[i, j, :, n*ctx.patch_size:(n+1)*ctx.patch_size, m*ctx.patch_size:(m+1)*ctx.patch_size]
                            grad_normals_split = grad_normals[i, j, :, n*ctx.patch_size:(n+1)*ctx.patch_size, m*ctx.patch_size:(m+1)*ctx.patch_size]
                            seed = torch.randint(0, 2**32, (1,)).long().item()
                            ctx.manual_seeds[-1].append(seed)
                            # Same per-patch intrinsics remap as in forward.
                            center_x = (m * ctx.patch_size + ctx.patch_size // 2) / ctx.width
                            center_y = (n * ctx.patch_size + ctx.patch_size // 2) / ctx.height
                            scale_x = ctx.width // ctx.patch_size
                            scale_y = ctx.height // ctx.patch_size
                            trans_x = 0.5 - scale_x * center_x
                            trans_y = 0.5 - scale_y * center_y
                            new_fx = scale_x * fx
                            new_fy = scale_y * fy
                            new_cx = scale_x * cx + trans_x
                            new_cy = scale_y * cy + trans_y
                            new_fxfycxcy = torch.stack([new_fx, new_fy, new_cx, new_cy], dim=0)
                            # Pass ctx.render_dn so the recompute matches the forward render
                            # (the original call omitted it and fell back to the default).
                            render_results = render(pc, ctx.patch_size, ctx.patch_size, ctx.C2W[i, j], new_fxfycxcy, ctx.znear, ctx.zfar, ctx.bg_color, ctx.scaling_modifier, ctx.render_dn)
                            color_split = render_results["image"]
                            alpha_split = render_results["alpha"]
                            depth_split = render_results["depth"]
                            normal_split = render_results["normal"]
                            # One backward per patch keeps only this patch's graph alive.
                            render_split = torch.cat([color_split, alpha_split, depth_split, normal_split], dim=0)
                            grad_split = torch.cat([grad_images_split, grad_alphas_split, grad_depths_split, grad_normals_split], dim=0)
                            render_split.backward(grad_split)

        # Five gradients for the five tensor inputs, then one None per
        # non-differentiable argument of forward.
        return xyz_nosync.grad, rgb_nosync.grad, scale_nosync.grad, rotation_nosync.grad, opacity_nosync.grad, None, None, None, None, None, None, None, None, None, None, None
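
    # Note (observation, not repo documentation): ctx.manual_seeds is filled in
    # both passes but never consumed via torch.manual_seed, so if render() ever
    # contains stochastic ops, the backward re-render will not reproduce the
    # forward pass exactly. Replaying each recorded seed just before the
    # corresponding render call would make the recompute deterministic.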


def deferred_bp(
    xyz, rgb, scale, rotation, opacity,
    height, width, C2W, fxfycxcy,
    patch_size, gaussian_model,
    znear, zfar,
    bg_color,
    scaling_modifier,
    render_dn,
):
    """Functional wrapper around DeferredBP.apply."""
    return DeferredBP.apply(
        xyz, rgb, scale, rotation, opacity,
        height, width, C2W, fxfycxcy,
        patch_size, gaussian_model,
        znear, zfar,
        bg_color,
        scaling_modifier,
        render_dn,
    )
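

# ---------------------------------------------------------------------------
# Minimal, self-contained sketch of the same recompute-on-backward pattern
# using plain torch ops (no Gaussian renderer). The names `_ToyDeferred` and
# `heavy_op` are illustrative only; they are not part of this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    def heavy_op(x):
        # Stand-in for an expensive renderer call.
        return (x * x).sin()

    class _ToyDeferred(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            # Function.forward already runs without grad tracking, so no
            # graph is kept for this (potentially huge) computation.
            return heavy_op(x)

        @staticmethod
        def backward(ctx, grad_out):
            (x,) = ctx.saved_tensors
            x = x.detach().clone().requires_grad_(True)  # fresh leaf, as in DeferredBP
            with torch.enable_grad():
                y = heavy_op(x)  # recompute, this time building a graph
            y.backward(grad_out)
            return x.grad

    t = torch.randn(4, requires_grad=True)
    _ToyDeferred.apply(t).sum().backward()

    # Reference: ordinary autograd on the same computation gives the same grad.
    t_ref = t.detach().clone().requires_grad_(True)
    heavy_op(t_ref).sum().backward()
    assert torch.allclose(t.grad, t_ref.grad)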