seungminkwak's picture
reset: clean history (purge leaked token)
08b23ce
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import torch
import random
from pytorch3d.io import load_objs_as_meshes, load_obj
from pytorch3d.renderer import TexturesAtlas
from pytorch3d.structures import Meshes
from model import RiggingModel
def prepare_depth(depth_path, input_frames, device, depth_model):
os.makedirs(depth_path, exist_ok=True)
depth_path = f"{depth_path}/depth_gt_raw.pt"
if os.path.exists(depth_path):
print("load GT depth...")
depth_gt_raw = torch.load(depth_path, map_location=device)
else:
print("run VideoDepthAnything and save.")
with torch.no_grad():
depth_gt_raw = depth_model.get_depth_maps(input_frames)
torch.save(depth_gt_raw.cpu(), depth_path)
depth_gt_raw = depth_gt_raw.to(device)
return depth_gt_raw
def normalize_vertices(verts):
"""Normalize vertices to a unit cube."""
vmin, vmax = verts.min(dim=0).values, verts.max(dim=0).values
center = (vmax + vmin) / 2.0
scale = (vmax - vmin).max()
verts_norm = (verts - center) / scale
return verts_norm, center, scale
def build_atlas_texture(obj_path, atlas_size, device):
"""Load OBJ + materials and bake all textures into a single atlas."""
verts, faces, aux = load_obj(
obj_path,
device=device,
load_textures=True,
create_texture_atlas=True,
texture_atlas_size=atlas_size,
texture_wrap="repeat",
)
atlas = aux.texture_atlas # (F, R, R, 3)
verts_norm, _, _ = normalize_vertices(verts)
mesh_atlas = Meshes(
verts=[verts_norm],
faces=[faces.verts_idx],
textures=TexturesAtlas(atlas=[atlas]),
)
return mesh_atlas
def read_rig_file(file_path):
"""
Read rig from txt file, our format is the same as RigNet:
joints joint_name x y z
root root_joint_name
skin vertex_idx joint_name weight joint_name weight ...
hier parent_joint_name child_joint_name
"""
joints = []
bones = []
joint_names = []
joint_mapping = {}
joint_index = 0
skinning_data = {} # Dictionary to store vertex index -> [(joint_idx, weight), ...]
with open(file_path, 'r') as file:
lines = file.readlines()
for line in lines:
parts = line.split()
if line.startswith('joints'):
name = parts[1]
position = [float(parts[2]), float(parts[3]), float(parts[4])]
joints.append(position)
joint_names.append(name)
joint_mapping[name] = joint_index
joint_index += 1
elif line.startswith('hier'):
parent_joint = joint_mapping[parts[1]]
child_joint = joint_mapping[parts[2]]
bones.append([parent_joint, child_joint])
elif line.startswith('root'):
root = joint_mapping[parts[1]]
elif line.startswith('skin'):
vertex_idx = int(parts[1])
if vertex_idx not in skinning_data:
skinning_data[vertex_idx] = []
for i in range(2, len(parts), 2):
if i+1 < len(parts):
joint_name = parts[i]
weight = float(parts[i+1])
if joint_name in joint_mapping:
joint_idx = joint_mapping[joint_name]
skinning_data[vertex_idx].append((joint_idx, weight))
return np.array(joints), np.array(bones), root, joint_names, skinning_data
def load_model_from_obj_and_rig(
mesh_path: str,
rig_path: str,
device: str | torch.device = "cuda",
use_skin_color: bool = True,
atlas_size: int = 8,
):
"""Load a 3D model from OBJ and rig files."""
# 1) read raw mesh
raw_mesh = load_objs_as_meshes([mesh_path], device=device)
verts_raw = raw_mesh.verts_packed() # (V,3)
faces_idx = raw_mesh.faces_packed() # (F,3)
# 2) read rig data
joints_np, bones_np, root_idx, joint_names, skinning_data = read_rig_file(rig_path)
J = joints_np.shape[0]
# parent indices, default -1
parent_idx = [-1] * J
for p, c in bones_np:
parent_idx[c] = p
verts_norm, center, scale = normalize_vertices(verts_raw)
joints_t = torch.as_tensor(joints_np, dtype=torch.float32, device=device)
joints_norm = (joints_t - center) / scale
# skin weights tensor (V,J)
V = verts_raw.shape[0]
skin_weights = torch.zeros(V, J, dtype=torch.float32, device=device)
for v_idx, lst in skinning_data.items():
for j_idx, w in lst:
skin_weights[v_idx, j_idx] = w
# 3) texture strategy
mesh_norm = build_atlas_texture(mesh_path, atlas_size, device)
tex = mesh_norm.textures
# 4) pack into Model class
model = RiggingModel(device=device)
model.vertices = [mesh_norm.verts_packed()]
model.faces = [faces_idx]
model.textures = [tex]
# rig meta
model.bones = bones_np # (B,2)
model.parent_indices = parent_idx
model.root_index = root_idx
model.skin_weights = [skin_weights]
model.bind_matrices_inv = torch.eye(4, device=device).unsqueeze(0).expand(J, -1, -1).contiguous()
model.joints_rest = joints_norm
return model