Spaces:
Paused
Paused
| # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import argparse | |
| import json | |
| import numpy as np | |
| import logging | |
| import glob | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from renderer import MeshRenderer3D | |
| from model import RiggingModel | |
| from utils.quat_utils import ( | |
| compute_rest_local_positions, quat_inverse, quat_log, quat_multiply | |
| ) | |
| from utils.loss_utils import ( | |
| DepthModule, compute_reprojection_loss, geodesic_loss, root_motion_reg, | |
| calculate_flow_loss, compute_depth_loss_normalized, joint_motion_coherence | |
| ) | |
| from utils.data_loader import load_model_from_obj_and_rig, prepare_depth | |
| from utils.save_utils import ( | |
| save_args, visualize_joints_on_mesh, save_final_video, | |
| save_and_smooth_results, visualize_points_on_mesh, save_track_points | |
| ) | |
| from utils.misc import warmup_then_decay | |
| from third_partys.co_tracker.save_track import save_track | |
class AnimationOptimizer:
    """Optimize per-frame joint rotations and global root motion so a rendered,
    rigged mesh matches a guidance video (RGB, mask, depth, optical flow and
    2D track supervision).

    Frame 0 is held fixed at the rest pose; frames 1..B-1 are learned.
    """

    def __init__(self, args, device='cuda:0'):
        """
        Args:
            args: Parsed CLI namespace (see ``get_parser``) providing loss
                target ratios, learning rates and smoothness weights.
            device: Torch device string used for all optimization tensors.
        """
        self.args = args
        self.device = device
        self.logger = self._setup_logger()
        # Training parameters
        self.reinit_patience_threshold = 20  # iters without improvement before reset
        self.loss_divergence_factor = 2.0    # reset when loss > best * this factor
        self.gradient_clip_norm = 1.0
        # Desired importance of each data term relative to the RGB loss.
        # The actual multiplicative weights below are calibrated once at
        # iteration 0 (see pre_calibrate_loss_weights).
        self.target_ratios = {
            'rgb': args.rgb_wt,
            'flow': args.flow_wt,
            'proj_joint': args.proj_joint_wt,
            'proj_vert': args.proj_vert_wt,
            'depth': args.depth_wt,
            'mask': args.mask_wt
        }
        self.loss_weights = {
            'rgb': 1.0,
            'flow': 1.0,
            'proj_joint': 1.0,
            'proj_vert': 1.0,
            'depth': 1.0,
            'mask': 1.0
        }

    def _setup_logger(self):
        """Set up (or reuse) the shared console logger for this module."""
        logger = logging.getLogger("animation_optimizer")
        logger.setLevel(logging.INFO)
        # Avoid duplicate handlers if several optimizers are constructed.
        if not logger.handlers:
            formatter = logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            )
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            logger.addHandler(console_handler)
        return logger

    def _add_file_handler(self, log_path):
        """Mirror log records to ``log_path`` in addition to the console."""
        file_handler = logging.FileHandler(log_path)
        formatter = logging.Formatter("%(asctime)s %(message)s")
        file_handler.setFormatter(formatter)
        self.logger.addHandler(file_handler)

    def _initialize_parameters(self, batch_size, num_joints):
        """Create fixed frame-0 tensors and learnable per-frame parameters.

        Quaternions are stored (w, x, y, z) and initialized to identity.

        Returns:
            Tuple of (learnable joint quats [B-1, J, 4],
                      learnable global quats [B-1, 4],
                      learnable global translations [B-1, 3],
                      fixed frame-0 joint quats [1, J, 4],
                      fixed frame-0 global quat [1, 4],
                      fixed frame-0 global translation [1, 3]).
        """
        # Fixed first frame quaternions (identity)
        fixed_quat_0 = torch.zeros((1, num_joints, 4), device=self.device)
        fixed_quat_0[..., 0] = 1.0
        # Learnable identity-initialized quaternions for frames 1..B-1
        learn_quats_init = torch.zeros((batch_size - 1, num_joints, 4), device=self.device)
        learn_quats_init[..., 0] = 1.0
        quats_to_optimize = learn_quats_init.clone().detach().requires_grad_(True)
        # Fixed frame-0 global transform: identity rotation, zero translation
        fixed_global_quat_0 = torch.zeros((1, 4), device=self.device)
        fixed_global_quat_0[:, 0] = 1.0
        fixed_global_trans_0 = torch.zeros((1, 3), device=self.device)
        # Learnable global transforms for frames 1..B-1
        global_quats_init = torch.zeros((batch_size - 1, 4), device=self.device)
        global_quats_init[:, 0] = 1.0
        global_trans_init = torch.zeros((batch_size - 1, 3), device=self.device)
        global_quats = global_quats_init.clone().detach().requires_grad_(True)
        global_trans = global_trans_init.clone().detach().requires_grad_(True)
        return quats_to_optimize, global_quats, global_trans, fixed_quat_0, fixed_global_quat_0, fixed_global_trans_0

    def _setup_optimizer_and_scheduler(self, quats_to_optimize, global_quats, global_trans, n_iters):
        """Build an AdamW optimizer over the three parameter groups plus a
        warmup-then-decay LR scheduler spanning ``n_iters`` steps."""
        base_lr = self.args.warm_lr
        max_lr = self.args.lr
        warmup_steps = 20
        min_lr = self.args.min_lr
        quat_lr = base_lr  # same as base; kept separate in case quats need a boost
        optimizer = torch.optim.AdamW([
            {'params': quats_to_optimize, 'lr': quat_lr},
            {'params': global_quats, 'lr': quat_lr},
            {'params': global_trans, 'lr': base_lr}
        ])
        scheduler = warmup_then_decay(
            optimizer=optimizer,
            total_steps=n_iters,
            warmup_steps=warmup_steps,
            max_lr=max_lr,
            min_lr=min_lr,
            base_lr=base_lr
        )
        return optimizer, scheduler

    def _compute_smoothness_losses(self, quats_normed, all_global_quats_normed, all_global_trans, model):
        """Compute temporal smoothness regularizers.

        Returns:
            Tuple of (first-order rotation smoothness, second-order rotation
            smoothness, parent-child joint coherence, combined root motion
            smoothness).
        """
        # First-order rotation smoothness: geodesic angle between consecutive frames
        theta = geodesic_loss(quats_normed[1:], quats_normed[:-1])
        rot_smoothness_loss = (theta ** 2).mean()
        # Second-order smoothness (angular acceleration) via the quaternion log map
        omega = quat_log(quat_multiply(quat_inverse(quats_normed[:-1]), quats_normed[1:]))
        rot_acc = omega[1:] - omega[:-1]
        rot_acc_smoothness_loss = rot_acc.pow(2).mean()
        # Joint motion coherence loss (parent-child relative motion smoothness)
        joint_coherence_loss = joint_motion_coherence(quats_normed, model.parent_indices)
        # Root motion regularization
        root_pos_smooth_loss, root_quat_smooth_loss = root_motion_reg(
            all_global_quats_normed, all_global_trans
        )
        return rot_smoothness_loss, rot_acc_smoothness_loss, joint_coherence_loss, root_pos_smooth_loss + root_quat_smooth_loss

    def pre_calibrate_loss_weights(self, loss_components, target_ratios=None):
        """Rescale ``self.loss_weights`` so each loss term's magnitude matches
        its target ratio relative to the RGB loss.

        Args:
            loss_components: Mapping of loss name -> scalar tensor; must
                contain an ``'rgb'`` entry.
            target_ratios: Optional mapping of loss name -> desired ratio
                w.r.t. the RGB loss; missing names default to 1.0.
        """
        # Guard the default: the previous code dereferenced ``target_ratios``
        # unconditionally and raised AttributeError when it was omitted.
        target_ratios = target_ratios or {}
        loss_for_ratio = {name: loss.detach().clone() for name, loss in loss_components.items()}
        rgb_loss = loss_for_ratio['rgb'].item()
        for name, loss_val in loss_for_ratio.items():
            if name == 'rgb':
                continue
            # Skip near-zero terms so the division cannot blow up the weight.
            if loss_val > 1e-8:
                scale_factor = rgb_loss / loss_val.item()
                target_ratio = target_ratios.get(name, 1.0)
                new_weight = self.loss_weights.get(name, 1.0) * scale_factor * target_ratio
                self.loss_weights[name] = new_weight

    def _compute_losses(
        self,
        model,
        renderer,
        images_batch,
        tracked_joints_2d,
        joint_vis_mask,
        track_verts_2d,
        vert_vis_mask,
        sampled_vertex_indices,
        track_indices,
        flow_dirs,
        depth_gt_raw,
        mask,
        out_dir,
        iteration
    ):
        """Compute all data-term losses for the current animated state.

        Note: ``out_dir`` and ``iteration`` are currently unused but kept for
        interface stability (callers pass them for potential debug dumps).

        Returns:
            Dict mapping loss name ('rgb', 'proj_joint', 'proj_vert',
            'depth', 'flow', 'mask') to a scalar tensor.
        """
        batch_size = images_batch.shape[0]
        meshes = [model.get_mesh(t) for t in range(batch_size)]
        pred_images_all = renderer.render_batch(meshes)
        # 2D projection losses: joints and tracked surface points
        pred_joints_3d = model.joint_positions
        proj_joint_loss = compute_reprojection_loss(
            renderer, joint_vis_mask, pred_joints_3d,
            tracked_joints_2d, self.args.img_size
        )
        pred_points_3d = model.deformed_vertices[0]
        proj_vert_loss = compute_reprojection_loss(
            renderer, vert_vis_mask,
            pred_points_3d[:, sampled_vertex_indices],
            track_verts_2d[:, track_indices],
            self.args.img_size
        )
        # RGB loss: masked MSE normalized by the number of foreground pixels
        pred_rgb = pred_images_all[..., :3]
        real_rgb = images_batch[..., :3]
        diff_rgb_masked = (pred_rgb - real_rgb) * mask.unsqueeze(-1)
        mse_rgb_num = (diff_rgb_masked ** 2).sum()
        mse_rgb_den = mask.sum() * 3
        rgb_loss = mse_rgb_num / mse_rgb_den.clamp_min(1e-8)
        # Mask loss: BCE between the soft silhouette and the foreground mask
        silhouette_soft = renderer.render_silhouette_batch(meshes).squeeze()
        mask_loss = F.binary_cross_entropy(silhouette_soft, mask)
        # Depth loss against the (estimated or cached) ground-truth depth
        fragments = renderer.get_rasterization_fragments(meshes)
        zbuf_depths = fragments.zbuf[..., 0]
        depth_loss = compute_depth_loss_normalized(depth_gt_raw, zbuf_depths, mask)
        # Optical flow loss
        flow_loss = calculate_flow_loss(flow_dirs, self.device, mask, renderer, model)
        loss_components = {
            'rgb': rgb_loss,
            'proj_joint': proj_joint_loss,
            'proj_vert': proj_vert_loss,
            'depth': depth_loss,
            'flow': flow_loss,
            'mask': mask_loss
        }
        return loss_components

    def optimization(
        self,
        images_batch,
        model,
        renderer,
        tracked_joints_2d,
        joint_vis_mask,
        track_verts_2d,
        vert_vis_mask,
        sampled_vertex_indices,
        track_indices,
        flow_dirs,
        n_iters,
        out_dir):
        """
        Optimize animation parameters with a fixed first frame.

        Args:
            images_batch: Guidance video frames, shape [B, H, W, 4] in [0, 1].
            model: Rigged mesh model (animated in place each iteration).
            renderer: Main-view differentiable renderer.
            tracked_joints_2d: Per-frame 2D joint tracks.
            joint_vis_mask: Joint visibility mask.
            track_verts_2d: Per-frame 2D vertex tracks.
            vert_vis_mask: Vertex visibility mask (subsampled).
            sampled_vertex_indices: Indices of tracked mesh vertices.
            track_indices: Indices into ``track_verts_2d`` matching the samples.
            flow_dirs: Directory holding precomputed optical flow.
            n_iters: Number of optimization iterations.
            out_dir: Output directory for logs.

        Returns:
            Tuple of (joint quats [B, J, 4], root quats [B, 4],
            root translations [B, 3]), all normalized, frame 0 fixed.
        """
        # NOTE(review): anomaly detection adds significant overhead; presumably
        # kept on intentionally for debugging -- consider disabling in production.
        torch.autograd.set_detect_anomaly(True)
        batch_size, _, _, _ = images_batch.shape
        num_joints = model.joints_rest.shape[0]
        # Setup output directory and logging
        os.makedirs(out_dir, exist_ok=True)
        log_path = os.path.join(out_dir, "optimization.log")
        self._add_file_handler(log_path)
        # Initialize parameters
        (quats_to_optimize, global_quats, global_trans,
         fixed_quat_0, fixed_global_quat_0, fixed_global_trans_0) = self._initialize_parameters(batch_size, num_joints)
        # Setup rest positions and bind matrices
        rest_local_pos = compute_rest_local_positions(model.joints_rest, model.parent_indices)
        model.initialize_bind_matrices(rest_local_pos)
        # Setup optimizer and scheduler
        optimizer, scheduler = self._setup_optimizer_and_scheduler(
            quats_to_optimize, global_quats, global_trans, n_iters
        )
        # Monocular depth estimator used to produce/caches GT depth
        depth_module = DepthModule(
            encoder='vitl',
            device=self.device,
            input_size=images_batch.shape[1],
            fp32=False
        )
        # Foreground mask: pixels where any RGB channel is below the
        # near-white threshold (assumes a white background).
        real_rgb = images_batch[..., :3]
        threshold = 0.95
        with torch.no_grad():
            background_mask = (real_rgb > threshold).all(dim=-1)
            mask = (~background_mask).float()
            depth_gt_raw = prepare_depth(
                flow_dirs.replace('flow', 'depth'), real_rgb, self.device, depth_module
            )
        # Optimization tracking for adaptive restarts
        best_loss = float('inf')
        patience = 0
        best_params = None
        pbar = tqdm(total=n_iters, desc="Optimizing animation")
        for iteration in range(n_iters):
            # Combine fixed frame 0 with the learnable frames
            quats_all = torch.cat([fixed_quat_0, quats_to_optimize], dim=0)
            # Normalize quaternions to unit length
            reshaped = quats_all.reshape(-1, 4)
            norm = torch.norm(reshaped, dim=1, keepdim=True).clamp_min(1e-8)
            quats_normed = (reshaped / norm).reshape(batch_size, num_joints, 4)
            # Global transformations
            all_global_quats = torch.cat([fixed_global_quat_0, global_quats], dim=0)
            all_global_trans = torch.cat([fixed_global_trans_0, global_trans], dim=0)
            all_global_quats_normed = all_global_quats / torch.norm(
                all_global_quats, dim=-1, keepdim=True
            ).clamp_min(1e-8)
            # Compute smoothness losses
            (rot_smoothness_loss, rot_acc_smoothness_loss, joint_coherence_loss,
             root_smooth_loss) = self._compute_smoothness_losses(
                quats_normed, all_global_quats_normed, all_global_trans, model
            )
            # Animate the model with the current parameters
            model.animate(quats_normed, all_global_quats_normed, all_global_trans)
            # Verify the fixed first frame was not perturbed by animate().
            # Raise instead of assert so the check survives ``python -O``.
            verts0 = model.vertices[0]
            de0 = model.deformed_vertices[0][0]
            if not torch.allclose(de0, verts0, atol=1e-2):
                raise RuntimeError("First frame vertices have changed!")
            # Compute all data-term losses
            loss_components = self._compute_losses(
                model, renderer, images_batch, tracked_joints_2d, joint_vis_mask,
                track_verts_2d, vert_vis_mask, sampled_vertex_indices, track_indices,
                flow_dirs, depth_gt_raw, mask, out_dir, iteration
            )
            # Acceleration smoothness weighted 10x relative to velocity smoothness
            total_smoothness_loss = rot_smoothness_loss + rot_acc_smoothness_loss * 10
            if iteration == 0:
                self.pre_calibrate_loss_weights(loss_components, self.target_ratios)
            total_loss = (
                loss_components['rgb'] +
                self.loss_weights['mask'] * loss_components['mask'] +
                self.loss_weights['flow'] * loss_components['flow'] +
                self.loss_weights['proj_joint'] * loss_components['proj_joint'] +
                self.loss_weights['proj_vert'] * loss_components['proj_vert'] +
                self.loss_weights['depth'] * loss_components['depth'] +
                self.args.smooth_weight * total_smoothness_loss +
                self.args.coherence_weight * joint_coherence_loss +
                self.args.root_smooth_weight * root_smooth_loss
            )
            # Optimization step with gradient clipping
            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                [quats_to_optimize, global_quats, global_trans],
                max_norm=self.gradient_clip_norm
            )
            optimizer.step()
            scheduler.step()
            # Update progress bar and logging
            loss_desc = (
                f"Loss: {total_loss.item():.4f}, "
                f"RGB: {loss_components['rgb'].item():.4f}, "
                f"Mask: {self.loss_weights['mask'] * loss_components['mask'].item():.4f}, "
                f"Flow: {self.loss_weights['flow'] * loss_components['flow'].item():.4f}, "
                f"Proj_joint: {self.loss_weights['proj_joint'] * loss_components['proj_joint'].item():.4f}, "
                f"Proj_vert: {self.loss_weights['proj_vert'] * loss_components['proj_vert'].item():.4f}, "
                f"Depth: {self.loss_weights['depth'] * loss_components['depth'].item():.4f}, "
                f"Smooth: {self.args.smooth_weight * total_smoothness_loss.item():.4f}, "
                f"Joint smooth: {self.args.coherence_weight * joint_coherence_loss.item():.4f}, "
                f"Root smooth: {self.args.root_smooth_weight * root_smooth_loss.item():.4f}"
            )
            pbar.set_description(loss_desc)
            if iteration % 5 == 0:
                self.logger.info(f"Iter {iteration}: {loss_desc}")
            # Adaptive reinitialization: snapshot the best state, and reset to
            # it when the loss diverges or stalls for too long.
            current_loss = total_loss.item()
            if current_loss < best_loss:
                best_loss = current_loss
                best_params = {
                    'quats': quats_to_optimize.clone().detach(),
                    'global_quats': global_quats.clone().detach(),
                    'global_trans': global_trans.clone().detach()
                }
                patience = 0
            elif best_params is not None and (
                    current_loss > best_loss * self.loss_divergence_factor or
                    patience > self.reinit_patience_threshold * 2):
                # Reinitialize with best parameters. The ``best_params`` guard
                # above prevents a crash when no finite loss was ever recorded
                # (e.g. NaN losses make both comparisons above False forever).
                quats_to_optimize = best_params['quats'].clone().requires_grad_(True)
                global_quats = best_params['global_quats'].clone().requires_grad_(True)
                global_trans = best_params['global_trans'].clone().requires_grad_(True)
                optimizer, scheduler = self._setup_optimizer_and_scheduler(
                    quats_to_optimize, global_quats, global_trans, n_iters
                )
                patience = 0
                self.logger.info(f'Adaptive reset at iteration {iteration} with best loss: {best_loss:.6f}')
            else:
                patience += 1
            pbar.update(1)
        pbar.close()
        if best_params is None:
            # No finite loss was ever recorded (n_iters == 0 or NaN losses
            # throughout) -- fall back to the current parameters so result
            # assembly below does not crash.
            best_params = {
                'quats': quats_to_optimize.clone().detach(),
                'global_quats': global_quats.clone().detach(),
                'global_trans': global_trans.clone().detach()
            }
        # Prepare final results: prepend the fixed frame 0 and re-normalize
        quats_final = torch.cat([fixed_quat_0, best_params['quats']], dim=0)
        reshaped = quats_final.reshape(-1, 4)
        norm = torch.norm(reshaped, dim=1, keepdim=True).clamp_min(1e-8)
        quats_final = (reshaped / norm).reshape(batch_size, num_joints, 4)
        global_quats_final = torch.cat([fixed_global_quat_0, best_params['global_quats']], dim=0)
        global_trans_final = torch.cat([fixed_global_trans_0, best_params['global_trans']], dim=0)
        global_quats_final = global_quats_final / torch.norm(
            global_quats_final, dim=-1, keepdim=True
        ).clamp_min(1e-8)
        return quats_final, global_quats_final, global_trans_final
def load_and_prepare_data(args):
    """Load the rigged mesh and the guidance video frames for one sequence.

    Expects the canonical layout ``<input_path>/<seq_name>/{objs,imgs,flow}``.

    Returns:
        Tuple of (rigged model, frame batch [B, H, W, 4] in [0, 1] on
        ``args.device``, flow directory path, image directory path).
    """
    base_path = f'{args.input_path}/{args.seq_name}'
    mesh_path = f'{base_path}/objs/mesh.obj'
    rig_path = f'{base_path}/objs/rig.txt'
    img_path = f'{base_path}/imgs'
    flow_dirs = f'{base_path}/flow'
    # Load the rigged mesh model
    model = load_model_from_obj_and_rig(mesh_path, rig_path, device=args.device)
    # Stack every PNG frame into a single float RGBA batch in [0, 1]
    frame_files = sorted(glob.glob(os.path.join(img_path, "*.png")))
    frames = [
        torch.from_numpy(
            np.array(Image.open(path).convert("RGBA"), dtype=np.float32) / 255.0
        ).to(args.device)
        for path in frame_files
    ]
    images_batch = torch.stack(frames, dim=0)
    return model, images_batch, flow_dirs, img_path
def setup_renderers(args):
    """Instantiate the main renderer plus up to three additional camera views.

    Args:
        args: Namespace with ``main_renderer``, ``additional_renderers``
            (comma-separated view names), ``device`` and ``img_size``.

    Returns:
        Tuple of (main ``MeshRenderer3D``, dict mapping
        ``"<view>_renderer"`` -> ``MeshRenderer3D`` for each valid extra view).

    Raises:
        ValueError: If ``args.main_renderer`` is not a known camera view.
    """
    available_views = [
        "front", "back", "left", "right",
        "front_left", "front_right", "back_left", "back_right"
    ]
    if args.main_renderer not in available_views:
        raise ValueError(f"Main renderer '{args.main_renderer}' not found in available cameras: {available_views}")

    def _load_cam_config(view_name):
        # Context manager closes the descriptor promptly; the previous
        # ``json.load(open(...))`` pattern leaked file handles.
        with open(f"utils/cameras/{view_name}.json") as f:
            return json.load(f)

    main_cam_config = _load_cam_config(args.main_renderer)
    main_renderer = MeshRenderer3D(args.device, image_size=args.img_size, cam_params=main_cam_config)
    additional_views = [view.strip() for view in args.additional_renderers.split(',') if view.strip()]
    if len(additional_views) > 3:
        print(f"Warning: Only first 3 additional renderers will be used. Got: {additional_views}")
        additional_views = additional_views[:3]
    additional_renderers = {}
    for view_name in additional_views:
        if view_name in available_views and view_name != args.main_renderer:
            cam_config = _load_cam_config(view_name)
            renderer = MeshRenderer3D(args.device, image_size=args.img_size, cam_params=cam_config)
            additional_renderers[f"{view_name}_renderer"] = renderer
        elif view_name == args.main_renderer:
            print(f"Warning: '{view_name}' is already the main renderer, skipping...")
        elif view_name not in available_views:
            print(f"Warning: Camera view '{view_name}' not found, skipping...")
    return main_renderer, additional_renderers
def get_parser():
    """Build the CLI argument parser with all configuration options.

    Returns:
        ``argparse.ArgumentParser`` with Training, Learning Rates,
        Loss Weights, Data and Output, and Flags groups.
    """
    parser = argparse.ArgumentParser(description="3D Rigging Optimization")

    # -- Training --------------------------------------------------------
    grp_train = parser.add_argument_group('Training')
    grp_train.add_argument("--iter", type=int, default=500, help="Number of training iterations")
    grp_train.add_argument("--img_size", type=int, default=512, help="Image resolution")
    grp_train.add_argument("--device", type=str, default="cuda:0", help="Device to use")
    grp_train.add_argument("--img_fps", type=int, default=15, help="Image frame rate")
    grp_train.add_argument('--main_renderer', type=str, default='front', help='Main renderer camera view (default: front)')
    grp_train.add_argument('--additional_renderers', type=str, default="back, right, left", help='Additional renderer views (max 3), comma-separated (e.g., "back,left,right"). ')

    # -- Learning rates --------------------------------------------------
    grp_lr = parser.add_argument_group('Learning Rates')
    grp_lr.add_argument("--lr", type=float, default=2e-3, help="Base learning rate")
    grp_lr.add_argument("--min_lr", type=float, default=1e-5, help="Minimum learning rate")
    grp_lr.add_argument("--warm_lr", type=float, default=1e-5, help="Warmup learning rate")

    # -- Loss weights (target ratios relative to the RGB loss) -----------
    grp_loss = parser.add_argument_group('Loss Weights')
    grp_loss.add_argument("--smooth_weight", type=float, default=0.2)
    grp_loss.add_argument("--root_smooth_weight", type=float, default=1.0)
    grp_loss.add_argument("--coherence_weight", type=float, default=10)
    grp_loss.add_argument("--rgb_wt", type=float, default=1.0, help="RGB loss target ratio (relative importance)")
    grp_loss.add_argument("--mask_wt", type=float, default=1.0, help="Mask loss target ratio")
    grp_loss.add_argument("--proj_joint_wt", type=float, default=1.5, help="Joint projection loss target ratio")
    grp_loss.add_argument("--proj_vert_wt", type=float, default=3.0, help="Point projection loss target ratio")
    grp_loss.add_argument("--depth_wt", type=float, default=0.8, help="Depth loss target ratio")
    grp_loss.add_argument("--flow_wt", type=float, default=0.8, help="Flow loss target ratio")

    # -- Data and output -------------------------------------------------
    grp_io = parser.add_argument_group('Data and Output')
    grp_io.add_argument("--input_path", type=str, default="inputs")
    grp_io.add_argument("--save_path", type=str, default="results")
    grp_io.add_argument("--save_name", type=str, default="results")
    grp_io.add_argument("--seq_name", type=str, default=None)

    # -- Flags -----------------------------------------------------------
    grp_flags = parser.add_argument_group('Flags')
    grp_flags.add_argument('--gauss_filter', action='store_true', default=False)

    return parser
def main():
    """Entry point: parse args, load data, run the optimization, save outputs."""
    args = get_parser().parse_args()

    # Per-sequence output directory; also snapshots the arguments there.
    out_dir = f'{args.save_path}/{args.seq_name}/{args.save_name}'
    save_args(args, out_dir)

    anim_opt = AnimationOptimizer(args, device=args.device)
    main_view, extra_views = setup_renderers(args)
    model, images_batch, flow_dirs, img_path = load_and_prepare_data(args)

    # Joint visibility and initial 2D projections used to seed the tracker.
    joint_vis = visualize_joints_on_mesh(model, main_view, args.seq_name, out_dir=out_dir)
    joint_vis = torch.from_numpy(joint_vis).float().to(args.device)
    joints_2d_init = main_view.project_points(model.joints_rest)

    # Reuse cached joint tracks when present; otherwise run the tracker.
    track_2d_path = img_path.replace('imgs', 'track_2d_joints')
    os.makedirs(track_2d_path, exist_ok=True)
    if not os.listdir(track_2d_path):
        print("Generating joint tracks")
        joint_tracks = save_track(args.seq_name, joints_2d_init, img_path, track_2d_path, out_dir)
    else:
        print("Loading existing joint tracks")
        joint_tracks = np.load(f'{track_2d_path}/pred_tracks.npy')

    # Dense (subsampled) vertex tracks for the point reprojection loss.
    vert_vis = visualize_points_on_mesh(model, main_view, args.seq_name, out_dir=out_dir)
    vert_vis = torch.from_numpy(vert_vis).float().to(args.device)
    vert_tracks, track_indices, sampled_vertex_indices = save_track_points(
        vert_vis, main_view, model, img_path, out_dir, args
    )
    vert_vis = vert_vis[sampled_vertex_indices]

    print(f"Starting optimization")
    final_quats, root_quats, root_pos = anim_opt.optimization(
        images_batch=images_batch,
        model=model,
        renderer=main_view,
        tracked_joints_2d=joint_tracks,
        joint_vis_mask=joint_vis,
        track_verts_2d=vert_tracks,
        vert_vis_mask=vert_vis,
        sampled_vertex_indices=sampled_vertex_indices,
        track_indices=track_indices,
        flow_dirs=flow_dirs,
        n_iters=args.iter,
        out_dir=out_dir
    )

    # Persist smoothed animation results and the final rendered video.
    save_and_smooth_results(
        args, model, main_view, final_quats, root_quats, root_pos,
        out_dir, extra_views, fps=10
    )
    print("Optimization completed successfully")
    save_final_video(args)


if __name__ == "__main__":
    main()