# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
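"""Video-guided animation optimization.

Optimizes per-joint rotation quaternions and a global root transform for a
rigged mesh so that its renders match a guidance video, using RGB, silhouette,
depth, optical-flow, and 2D joint/point tracking losses.
"""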
import os
import argparse
import json
import numpy as np
import logging
import glob
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from renderer import MeshRenderer3D
from model import RiggingModel
from utils.quat_utils import (
compute_rest_local_positions, quat_inverse, quat_log, quat_multiply
)
from utils.loss_utils import (
DepthModule, compute_reprojection_loss, geodesic_loss, root_motion_reg,
calculate_flow_loss, compute_depth_loss_normalized, joint_motion_coherence
)
from utils.data_loader import load_model_from_obj_and_rig, prepare_depth
from utils.save_utils import (
save_args, visualize_joints_on_mesh, save_final_video,
save_and_smooth_results, visualize_points_on_mesh, save_track_points
)
from utils.misc import warmup_then_decay
from third_partys.co_tracker.save_track import save_track
class AnimationOptimizer:
"""Main class for animation optimization with video guidance."""
    def __init__(self, args, device='cuda:0'):
self.args = args
self.device = device
self.logger = self._setup_logger()
# Training parameters
self.reinit_patience_threshold = 20
self.loss_divergence_factor = 2.0
self.gradient_clip_norm = 1.0
# Loss weights
self.target_ratios = {
'rgb': args.rgb_wt,
'flow': args.flow_wt,
'proj_joint': args.proj_joint_wt,
'proj_vert': args.proj_vert_wt,
'depth': args.depth_wt,
'mask': args.mask_wt
}
self.loss_weights = {
'rgb': 1.0,
'flow': 1.0,
'proj_joint': 1.0,
'proj_vert': 1.0,
'depth': 1.0,
'mask': 1.0
}
def _setup_logger(self):
"""Set up logging configuration."""
logger = logging.getLogger("animation_optimizer")
logger.setLevel(logging.INFO)
if not logger.handlers:
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
return logger
def _add_file_handler(self, log_path):
"""Add file handler to logger."""
file_handler = logging.FileHandler(log_path)
formatter = logging.Formatter("%(asctime)s %(message)s")
file_handler.setFormatter(formatter)
self.logger.addHandler(file_handler)
def _initialize_parameters(self, batch_size, num_joints):
"""Initialize optimization parameters."""
# Fixed first frame quaternions (identity)
fixed_quat_0 = torch.zeros((1, num_joints, 4), device=self.device)
fixed_quat_0[..., 0] = 1.0
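        # Component 0 is the scalar (w) part, so (1, 0, 0, 0) is the identity rotation.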
# Initialize learnable quaternions for frames 1 to B-1
learn_quats_init = torch.zeros((batch_size - 1, num_joints, 4), device=self.device)
learn_quats_init[..., 0] = 1.0
quats_to_optimize = learn_quats_init.clone().detach().requires_grad_(True)
# Initialize global transformations
fixed_global_quat_0 = torch.zeros((1, 4), device=self.device)
fixed_global_quat_0[:, 0] = 1.0
fixed_global_trans_0 = torch.zeros((1, 3), device=self.device)
# Initialize learnable global transformations
global_quats_init = torch.zeros((batch_size - 1, 4), device=self.device)
global_quats_init[:, 0] = 1.0
global_trans_init = torch.zeros((batch_size - 1, 3), device=self.device)
global_quats = global_quats_init.clone().detach().requires_grad_(True)
global_trans = global_trans_init.clone().detach().requires_grad_(True)
return quats_to_optimize, global_quats, global_trans, fixed_quat_0, fixed_global_quat_0, fixed_global_trans_0
def _setup_optimizer_and_scheduler(self, quats_to_optimize, global_quats, global_trans, n_iters):
"""Set up optimizer and learning rate scheduler."""
base_lr = self.args.warm_lr
max_lr = self.args.lr
warmup_steps = 20
min_lr = self.args.min_lr
        quat_lr = base_lr
optimizer = torch.optim.AdamW([
{'params': quats_to_optimize, 'lr': quat_lr},
{'params': global_quats, 'lr': quat_lr},
{'params': global_trans, 'lr': base_lr}
])
scheduler = warmup_then_decay(
optimizer=optimizer,
total_steps=n_iters,
warmup_steps=warmup_steps,
max_lr=max_lr,
min_lr=min_lr,
base_lr=base_lr
)
return optimizer, scheduler
def _compute_smoothness_losses(self, quats_normed, all_global_quats_normed, all_global_trans, model):
"""Compute various smoothness losses."""
# Rotation smoothness loss using geodesic distance
theta = geodesic_loss(quats_normed[1:], quats_normed[:-1])
rot_smoothness_loss = (theta ** 2).mean()
# Second-order rotation smoothness (acceleration)
omega = quat_log(quat_multiply(quat_inverse(quats_normed[:-1]), quats_normed[1:]))
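        # The quaternion log of the relative rotation between consecutive frames
        # approximates angular velocity; its finite difference penalizes angular acceleration.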
rot_acc = omega[1:] - omega[:-1]
rot_acc_smoothness_loss = rot_acc.pow(2).mean()
# Joint motion coherence loss (parent-child relative motion smoothness)
joint_coherence_loss = joint_motion_coherence(quats_normed, model.parent_indices)
# Root motion regularization
root_pos_smooth_loss, root_quat_smooth_loss = root_motion_reg(
all_global_quats_normed, all_global_trans
)
return rot_smoothness_loss, rot_acc_smoothness_loss, joint_coherence_loss, root_pos_smooth_loss + root_quat_smooth_loss
    def pre_calibrate_loss_weights(self, loss_components, target_ratios=None):
        """Rescale each loss weight so its weighted term matches target_ratio * rgb_loss."""
        if target_ratios is None:
            target_ratios = self.target_ratios
        loss_for_ratio = {name: loss.detach().clone() for name, loss in loss_components.items()}
rgb_loss = loss_for_ratio['rgb'].item()
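        # RGB is the reference term (implicit weight 1.0): after calibration,
        # loss_weights[k] * loss_k == target_ratios[k] * rgb_loss for every other term k.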
for name, loss_val in loss_for_ratio.items():
if name == 'rgb':
continue
if loss_val > 1e-8:
scale_factor = rgb_loss / loss_val.item()
target_ratio = target_ratios.get(name, 1.0)
new_weight = self.loss_weights.get(name, 1.0) * scale_factor * target_ratio
self.loss_weights[name] = new_weight
def _compute_losses(
self,
model,
renderer,
images_batch,
tracked_joints_2d,
joint_vis_mask,
track_verts_2d,
vert_vis_mask,
sampled_vertex_indices,
track_indices,
flow_dirs,
depth_gt_raw,
mask,
out_dir,
iteration
):
"""Compute all losses for the optimization."""
batch_size = images_batch.shape[0]
meshes = [model.get_mesh(t) for t in range(batch_size)]
pred_images_all = renderer.render_batch(meshes)
# 2D projection losses
pred_joints_3d = model.joint_positions
proj_joint_loss = compute_reprojection_loss(
renderer, joint_vis_mask, pred_joints_3d,
tracked_joints_2d, self.args.img_size
)
pred_points_3d = model.deformed_vertices[0]
proj_vert_loss = compute_reprojection_loss(
renderer, vert_vis_mask,
pred_points_3d[:, sampled_vertex_indices],
track_verts_2d[:, track_indices],
self.args.img_size
)
# RGB loss
pred_rgb = pred_images_all[..., :3]
real_rgb = images_batch[..., :3]
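        # Masked MSE over foreground pixels only (3 channels per masked pixel).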
diff_rgb_masked = (pred_rgb - real_rgb) * mask.unsqueeze(-1)
mse_rgb_num = (diff_rgb_masked ** 2).sum()
mse_rgb_den = mask.sum() * 3
rgb_loss = mse_rgb_num / mse_rgb_den.clamp_min(1e-8)
# Mask loss
silhouette_soft = renderer.render_silhouette_batch(meshes).squeeze()
mask_loss = F.binary_cross_entropy(silhouette_soft, mask)
# Depth losses
fragments = renderer.get_rasterization_fragments(meshes)
zbuf_depths = fragments.zbuf[..., 0]
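        # zbuf[..., 0] holds the depth of the closest rasterized face at each pixel.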
depth_loss = compute_depth_loss_normalized(depth_gt_raw, zbuf_depths, mask)
# Flow losses
flow_loss = calculate_flow_loss(flow_dirs, self.device, mask, renderer, model)
loss_components = {
'rgb': rgb_loss,
'proj_joint': proj_joint_loss,
'proj_vert': proj_vert_loss,
'depth': depth_loss,
'flow': flow_loss,
'mask': mask_loss
}
return loss_components
def optimization(
self,
images_batch,
model,
renderer,
tracked_joints_2d,
joint_vis_mask,
track_verts_2d,
vert_vis_mask,
sampled_vertex_indices,
track_indices,
flow_dirs,
n_iters,
out_dir):
"""
Optimize animation parameters with fixed first frame.
"""
torch.autograd.set_detect_anomaly(True)
batch_size, _, _, _ = images_batch.shape
num_joints = model.joints_rest.shape[0]
# Setup output directory and logging
os.makedirs(out_dir, exist_ok=True)
log_path = os.path.join(out_dir, "optimization.log")
self._add_file_handler(log_path)
# Initialize parameters
(quats_to_optimize, global_quats, global_trans,
fixed_quat_0, fixed_global_quat_0, fixed_global_trans_0) = self._initialize_parameters(batch_size, num_joints)
# Setup rest positions and bind matrices
rest_local_pos = compute_rest_local_positions(model.joints_rest, model.parent_indices)
model.initialize_bind_matrices(rest_local_pos)
# Setup optimizer and scheduler
optimizer, scheduler = self._setup_optimizer_and_scheduler(
quats_to_optimize, global_quats, global_trans, n_iters
)
# Initialize depth module and flow weights
depth_module = DepthModule(
encoder='vitl',
device=self.device,
input_size=images_batch.shape[1],
fp32=False
)
# Prepare masks
real_rgb = images_batch[..., :3]
threshold = 0.95
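        # Treat near-white pixels (all RGB channels above the threshold) as background;
        # the foreground mask is its complement.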
with torch.no_grad():
background_mask = (real_rgb > threshold).all(dim=-1)
mask = (~background_mask).float()
depth_gt_raw = prepare_depth(
flow_dirs.replace('flow', 'depth'), real_rgb, self.device, depth_module
)
# Optimization tracking
best_loss = float('inf')
patience = 0
best_params = None
pbar = tqdm(total=n_iters, desc="Optimizing animation")
for iteration in range(n_iters):
# Combine fixed and learnable parameters
quats_all = torch.cat([fixed_quat_0, quats_to_optimize], dim=0)
# Normalize quaternions
reshaped = quats_all.reshape(-1, 4)
norm = torch.norm(reshaped, dim=1, keepdim=True).clamp_min(1e-8)
quats_normed = (reshaped / norm).reshape(batch_size, num_joints, 4)
# Global transformations
all_global_quats = torch.cat([fixed_global_quat_0, global_quats], dim=0)
all_global_trans = torch.cat([fixed_global_trans_0, global_trans], dim=0)
all_global_quats_normed = all_global_quats / torch.norm(
all_global_quats, dim=-1, keepdim=True
).clamp_min(1e-8)
# Compute smoothness losses
(rot_smoothness_loss, rot_acc_smoothness_loss, joint_coherence_loss,
root_smooth_loss) = self._compute_smoothness_losses(
quats_normed, all_global_quats_normed, all_global_trans, model
)
# animate model
model.animate(quats_normed, all_global_quats_normed, all_global_trans)
# Verify first frame hasn't changed
verts0 = model.vertices[0]
de0 = model.deformed_vertices[0][0]
assert torch.allclose(de0, verts0, atol=1e-2), "First frame vertices have changed!"
# Compute all losses
loss_components = self._compute_losses(
model, renderer, images_batch, tracked_joints_2d, joint_vis_mask,
track_verts_2d, vert_vis_mask, sampled_vertex_indices, track_indices,
flow_dirs, depth_gt_raw, mask, out_dir, iteration
)
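            # First-order (velocity) smoothness plus a 10x-weighted second-order (acceleration) term.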
total_smoothness_loss = rot_smoothness_loss + rot_acc_smoothness_loss * 10
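            # Calibrate loss weights once, from the raw loss magnitudes at the first iteration.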
if iteration == 0:
self.pre_calibrate_loss_weights(loss_components, self.target_ratios)
total_loss = (
loss_components['rgb'] +
self.loss_weights['mask'] * loss_components['mask'] +
self.loss_weights['flow'] * loss_components['flow'] +
self.loss_weights['proj_joint'] * loss_components['proj_joint'] +
self.loss_weights['proj_vert'] * loss_components['proj_vert'] +
self.loss_weights['depth'] * loss_components['depth'] +
self.args.smooth_weight * total_smoothness_loss +
self.args.coherence_weight * joint_coherence_loss +
self.args.root_smooth_weight * root_smooth_loss
)
# Optimization step
optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(
[quats_to_optimize, global_quats, global_trans],
max_norm=self.gradient_clip_norm
)
optimizer.step()
scheduler.step()
# Update progress bar and logging
loss_desc = (
f"Loss: {total_loss.item():.4f}, "
f"RGB: {loss_components['rgb'].item():.4f}, "
f"Mask: {self.loss_weights['mask'] * loss_components['mask'].item():.4f}, "
f"Flow: {self.loss_weights['flow'] * loss_components['flow'].item():.4f}, "
f"Proj_joint: {self.loss_weights['proj_joint'] * loss_components['proj_joint'].item():.4f}, "
f"Proj_vert: {self.loss_weights['proj_vert'] * loss_components['proj_vert'].item():.4f}, "
f"Depth: {self.loss_weights['depth'] * loss_components['depth'].item():.4f}, "
f"Smooth: {self.args.smooth_weight * total_smoothness_loss.item():.4f}, "
f"Joint smooth: {self.args.coherence_weight * joint_coherence_loss.item():.4f}, "
f"Root smooth: {self.args.root_smooth_weight * root_smooth_loss.item():.4f}"
)
pbar.set_description(loss_desc)
if iteration % 5 == 0:
self.logger.info(f"Iter {iteration}: {loss_desc}")
            # Adaptive reinitialization: snapshot the best parameters and restart from them
            # if the loss diverges past loss_divergence_factor * best or patience is exhausted.
current_loss = total_loss.item()
if current_loss < best_loss:
best_loss = current_loss
best_params = {
'quats': quats_to_optimize.clone().detach(),
'global_quats': global_quats.clone().detach(),
'global_trans': global_trans.clone().detach()
}
patience = 0
elif (current_loss > best_loss * self.loss_divergence_factor or
patience > self.reinit_patience_threshold * 2):
# Reinitialize with best parameters
quats_to_optimize = best_params['quats'].clone().requires_grad_(True)
global_quats = best_params['global_quats'].clone().requires_grad_(True)
global_trans = best_params['global_trans'].clone().requires_grad_(True)
optimizer, scheduler = self._setup_optimizer_and_scheduler(
quats_to_optimize, global_quats, global_trans, n_iters
)
patience = 0
self.logger.info(f'Adaptive reset at iteration {iteration} with best loss: {best_loss:.6f}')
else:
patience += 1
pbar.update(1)
pbar.close()
# Prepare final results
quats_final = torch.cat([fixed_quat_0, best_params['quats']], dim=0)
# Final normalization
reshaped = quats_final.reshape(-1, 4)
norm = torch.norm(reshaped, dim=1, keepdim=True).clamp_min(1e-8)
quats_final = (reshaped / norm).reshape(batch_size, num_joints, 4)
global_quats_final = torch.cat([fixed_global_quat_0, best_params['global_quats']], dim=0)
global_trans_final = torch.cat([fixed_global_trans_0, best_params['global_trans']], dim=0)
global_quats_final = global_quats_final / torch.norm(
global_quats_final, dim=-1, keepdim=True
).clamp_min(1e-8)
return quats_final, global_quats_final, global_trans_final
def load_and_prepare_data(args):
"""Load and prepare all necessary data for optimization."""
# Define paths
base_path = f'{args.input_path}/{args.seq_name}'
mesh_path = f'{base_path}/objs/mesh.obj'
rig_path = f'{base_path}/objs/rig.txt'
img_path = f'{base_path}/imgs'
flow_dirs = f'{base_path}/flow'
# Load model
model = load_model_from_obj_and_rig(mesh_path, rig_path, device=args.device)
# Load images
img_files = sorted(glob.glob(os.path.join(img_path, "*.png")))
images = []
for f in img_files:
img = Image.open(f).convert("RGBA")
arr = np.array(img, dtype=np.float32) / 255.0
t = torch.from_numpy(arr).to(args.device)
images.append(t)
images_batch = torch.stack(images, dim=0)
return model, images_batch, flow_dirs, img_path
def setup_renderers(args):
"""Setup multiple renderers for different camera views."""
available_views = [
"front", "back", "left", "right",
"front_left", "front_right", "back_left", "back_right"
]
if args.main_renderer not in available_views:
raise ValueError(f"Main renderer '{args.main_renderer}' not found in available cameras: {available_views}")
    with open(f"utils/cameras/{args.main_renderer}.json") as f:
        main_cam_config = json.load(f)
    main_renderer = MeshRenderer3D(args.device, image_size=args.img_size, cam_params=main_cam_config)
additional_views = [view.strip() for view in args.additional_renderers.split(',') if view.strip()]
if len(additional_views) > 3:
print(f"Warning: Only first 3 additional renderers will be used. Got: {additional_views}")
additional_views = additional_views[:3]
additional_renderers = {}
for view_name in additional_views:
if view_name in available_views and view_name != args.main_renderer:
            with open(f"utils/cameras/{view_name}.json") as f:
                cam_config = json.load(f)
            renderer = MeshRenderer3D(args.device, image_size=args.img_size, cam_params=cam_config)
additional_renderers[f"{view_name}_renderer"] = renderer
elif view_name == args.main_renderer:
print(f"Warning: '{view_name}' is already the main renderer, skipping...")
elif view_name not in available_views:
print(f"Warning: Camera view '{view_name}' not found, skipping...")
return main_renderer, additional_renderers
def get_parser():
"""Create argument parser with all configuration options."""
parser = argparse.ArgumentParser(description="3D Rigging Optimization")
# Training parameters
training_group = parser.add_argument_group('Training')
training_group.add_argument("--iter", type=int, default=500, help="Number of training iterations")
training_group.add_argument("--img_size", type=int, default=512, help="Image resolution")
training_group.add_argument("--device", type=str, default="cuda:0", help="Device to use")
training_group.add_argument("--img_fps", type=int, default=15, help="Image frame rate")
training_group.add_argument('--main_renderer', type=str, default='front', help='Main renderer camera view (default: front)')
    training_group.add_argument('--additional_renderers', type=str, default="back,right,left", help='Additional renderer views (max 3), comma-separated (e.g., "back,left,right")')
# Learning rates
lr_group = parser.add_argument_group('Learning Rates')
lr_group.add_argument("--lr", type=float, default=2e-3, help="Base learning rate")
lr_group.add_argument("--min_lr", type=float, default=1e-5, help="Minimum learning rate")
lr_group.add_argument("--warm_lr", type=float, default=1e-5, help="Warmup learning rate")
# Loss weights
loss_group = parser.add_argument_group('Loss Weights')
loss_group.add_argument("--smooth_weight", type=float, default=0.2)
loss_group.add_argument("--root_smooth_weight", type=float, default=1.0)
loss_group.add_argument("--coherence_weight", type=float, default=10)
loss_group.add_argument("--rgb_wt", type=float, default=1.0, help="RGB loss target ratio (relative importance)")
loss_group.add_argument("--mask_wt", type=float, default=1.0, help="Mask loss target ratio")
loss_group.add_argument("--proj_joint_wt", type=float, default=1.5, help="Joint projection loss target ratio")
loss_group.add_argument("--proj_vert_wt", type=float, default=3.0, help="Point projection loss target ratio")
loss_group.add_argument("--depth_wt", type=float, default=0.8, help="Depth loss target ratio")
loss_group.add_argument("--flow_wt", type=float, default=0.8, help="Flow loss target ratio")
# Data and output
data_group = parser.add_argument_group('Data and Output')
data_group.add_argument("--input_path", type=str, default="inputs")
data_group.add_argument("--save_path", type=str, default="results")
data_group.add_argument("--save_name", type=str, default="results")
data_group.add_argument("--seq_name", type=str, default=None)
# Flags
flag_group = parser.add_argument_group('Flags')
    flag_group.add_argument('--gauss_filter', action='store_true')
return parser
def main():
parser = get_parser()
args = parser.parse_args()
# Setup output directory
out_dir = f'{args.save_path}/{args.seq_name}/{args.save_name}'
save_args(args, out_dir)
# Initialize optimizer
ani_optimizer = AnimationOptimizer(args, device=args.device)
# Setup renderers
renderer, additional_renderers = setup_renderers(args)
# Load and prepare data
model, images_batch, flow_dirs, img_path = load_and_prepare_data(args)
# Setup tracking
joint_vis_mask = visualize_joints_on_mesh(model, renderer, args.seq_name, out_dir=out_dir)
joint_vis_mask = torch.from_numpy(joint_vis_mask).float().to(args.device)
joint_project_2d = renderer.project_points(model.joints_rest)
# Setup track paths
track_2d_path = img_path.replace('imgs', 'track_2d_joints')
os.makedirs(track_2d_path, exist_ok=True)
# Load or generate tracks
if not os.listdir(track_2d_path):
print("Generating joint tracks")
tracked_joints_2d = save_track(args.seq_name, joint_project_2d, img_path, track_2d_path, out_dir)
else:
print("Loading existing joint tracks")
tracked_joints_2d = np.load(f'{track_2d_path}/pred_tracks.npy')
# Setup point tracking
vert_vis_mask = visualize_points_on_mesh(model, renderer, args.seq_name, out_dir=out_dir)
vert_vis_mask = torch.from_numpy(vert_vis_mask).float().to(args.device)
track_verts_2d, track_indices, sampled_vertex_indices = save_track_points(
vert_vis_mask, renderer, model, img_path, out_dir, args
)
vert_vis_mask = vert_vis_mask[sampled_vertex_indices]
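    # Keep visibility entries only for the vertices actually sampled for tracking.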
# Run optimization
print(f"Starting optimization")
final_quats, root_quats, root_pos = ani_optimizer.optimization(
images_batch=images_batch,
model=model,
renderer=renderer,
tracked_joints_2d=tracked_joints_2d,
joint_vis_mask=joint_vis_mask,
track_verts_2d=track_verts_2d,
vert_vis_mask=vert_vis_mask,
sampled_vertex_indices=sampled_vertex_indices,
track_indices=track_indices,
flow_dirs=flow_dirs,
n_iters=args.iter,
out_dir=out_dir
)
# Save results
save_and_smooth_results(
args, model, renderer, final_quats, root_quats, root_pos,
out_dir, additional_renderers, fps=10
)
print("Optimization completed successfully")
save_final_video(args)
if __name__ == "__main__":
main()