# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random

import numpy as np
import torch
from pytorch3d.io import load_objs_as_meshes, load_obj
from pytorch3d.renderer import TexturesAtlas
from pytorch3d.structures import Meshes

from model import RiggingModel


def prepare_depth(depth_path, input_frames, device, depth_model):
    """Load cached ground-truth depth maps if present, otherwise run the depth model and cache the result."""
    os.makedirs(depth_path, exist_ok=True)
    depth_file = f"{depth_path}/depth_gt_raw.pt"
    if os.path.exists(depth_file):
        print("Loading cached GT depth...")
        depth_gt_raw = torch.load(depth_file, map_location=device)
    else:
        print("Running VideoDepthAnything and caching depth...")
        with torch.no_grad():
            depth_gt_raw = depth_model.get_depth_maps(input_frames)
        torch.save(depth_gt_raw.cpu(), depth_file)
        depth_gt_raw = depth_gt_raw.to(device)
    return depth_gt_raw
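
# Usage sketch (illustrative, not called anywhere in this module): `depth_model`
# is assumed to expose `get_depth_maps(frames)` exactly as used above, and
# "outputs/depth" is a placeholder cache directory.
#
#   depth_gt_raw = prepare_depth("outputs/depth", input_frames, "cuda", depth_model)
#   # The first call runs the depth model and writes outputs/depth/depth_gt_raw.pt;
#   # later calls reload the cached tensor onto the requested device.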


def normalize_vertices(verts):
    """Normalize vertices to a unit cube."""
    vmin, vmax = verts.min(dim=0).values, verts.max(dim=0).values
    center = (vmax + vmin) / 2.0
    scale = (vmax - vmin).max()
    verts_norm = (verts - center) / scale
    return verts_norm, center, scale
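
# Worked example (illustrative only): vertices spanning [0, 2] x [0, 1] x [0, 1]
# have center (1.0, 0.5, 0.5) and largest extent 2.0, so the normalized result
# fits inside [-0.5, 0.5]^3:
#
#   verts = torch.tensor([[0.0, 0.0, 0.0], [2.0, 1.0, 1.0]])
#   verts_norm, center, scale = normalize_vertices(verts)
#   # verts_norm == [[-0.5, -0.25, -0.25], [0.5, 0.25, 0.25]], scale == 2.0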


def build_atlas_texture(obj_path, atlas_size, device):
    """Load OBJ + materials and bake all textures into a single atlas."""
    verts, faces, aux = load_obj(
        obj_path,
        device=device,
        load_textures=True,
        create_texture_atlas=True,
        texture_atlas_size=atlas_size,
        texture_wrap="repeat",
    )
    atlas = aux.texture_atlas  # (F, R, R, 3)

    verts_norm, _, _ = normalize_vertices(verts)
    mesh_atlas = Meshes(
        verts=[verts_norm],
        faces=[faces.verts_idx],
        textures=TexturesAtlas(atlas=[atlas]),
    )
    return mesh_atlas
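
# Usage sketch (illustrative; "character.obj" is a placeholder path). The
# returned mesh carries a per-face texture atlas of shape
# (F, atlas_size, atlas_size, 3):
#
#   mesh_atlas = build_atlas_texture("character.obj", atlas_size=8, device="cuda")
#   atlas = mesh_atlas.textures.atlas_packed()  # (F, 8, 8, 3)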


def read_rig_file(file_path):
    """
    Read a rig from a txt file; the format is the same as RigNet:
        joints joint_name x y z
        root root_joint_name
        skin vertex_idx joint_name weight joint_name weight ...
        hier parent_joint_name child_joint_name
    """
    joints = []
    bones = []
    joint_names = []
    joint_mapping = {}
    joint_index = 0
    skinning_data = {}  # vertex_idx -> [(joint_idx, weight), ...]

    with open(file_path, 'r') as file:
        lines = file.readlines()
        for line in lines:
            parts = line.split()
            if line.startswith('joints'):
                name = parts[1]
                position = [float(parts[2]), float(parts[3]), float(parts[4])]
                joints.append(position)
                joint_names.append(name)
                joint_mapping[name] = joint_index
                joint_index += 1
            elif line.startswith('hier'):
                parent_joint = joint_mapping[parts[1]]
                child_joint = joint_mapping[parts[2]]
                bones.append([parent_joint, child_joint])
            elif line.startswith('root'):
                root = joint_mapping[parts[1]]
            elif line.startswith('skin'):
                vertex_idx = int(parts[1])
                if vertex_idx not in skinning_data:
                    skinning_data[vertex_idx] = []
                for i in range(2, len(parts), 2):
                    if i + 1 < len(parts):
                        joint_name = parts[i]
                        weight = float(parts[i + 1])
                        if joint_name in joint_mapping:
                            joint_idx = joint_mapping[joint_name]
                            skinning_data[vertex_idx].append((joint_idx, weight))

    return np.array(joints), np.array(bones), root, joint_names, skinning_data
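
# Example of the expected rig text format (values are illustrative only):
#
#   joints root 0.0 0.5 0.0
#   joints spine 0.0 0.8 0.0
#   root root
#   skin 0 root 0.7 spine 0.3
#   hier root spine
#
# Parsing this yields joints of shape (2, 3), bones [[0, 1]], root index 0,
# joint_names ['root', 'spine'], and skinning_data {0: [(0, 0.7), (1, 0.3)]}.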


def load_model_from_obj_and_rig(
    mesh_path: str,
    rig_path: str,
    device: str | torch.device = "cuda",
    use_skin_color: bool = True,
    atlas_size: int = 8,
):
    """Load a 3D model from an OBJ file and a rig file."""
    # 1) read raw mesh
    raw_mesh = load_objs_as_meshes([mesh_path], device=device)
    verts_raw = raw_mesh.verts_packed()  # (V, 3)
    faces_idx = raw_mesh.faces_packed()  # (F, 3)

    # 2) read rig data
    joints_np, bones_np, root_idx, joint_names, skinning_data = read_rig_file(rig_path)
    J = joints_np.shape[0]

    # parent indices, default -1 (root has no parent)
    parent_idx = [-1] * J
    for p, c in bones_np:
        parent_idx[c] = p

    verts_norm, center, scale = normalize_vertices(verts_raw)
    joints_t = torch.as_tensor(joints_np, dtype=torch.float32, device=device)
    joints_norm = (joints_t - center) / scale

    # skin weights tensor (V, J)
    V = verts_raw.shape[0]
    skin_weights = torch.zeros(V, J, dtype=torch.float32, device=device)
    for v_idx, lst in skinning_data.items():
        for j_idx, w in lst:
            skin_weights[v_idx, j_idx] = w

    # 3) texture strategy: bake materials into a per-face atlas
    mesh_norm = build_atlas_texture(mesh_path, atlas_size, device)
    tex = mesh_norm.textures

    # 4) pack into RiggingModel
    model = RiggingModel(device=device)
    model.vertices = [mesh_norm.verts_packed()]
    model.faces = [faces_idx]
    model.textures = [tex]

    # rig meta
    model.bones = bones_np  # (B, 2)
    model.parent_indices = parent_idx
    model.root_index = root_idx
    model.skin_weights = [skin_weights]
    model.bind_matrices_inv = torch.eye(4, device=device).unsqueeze(0).expand(J, -1, -1).contiguous()
    model.joints_rest = joints_norm

    return model
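

# Minimal usage sketch. The asset paths are placeholders; any OBJ with materials
# plus a RigNet-style rig txt file in the format read_rig_file expects should work.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model_from_obj_and_rig(
        mesh_path="assets/character.obj",     # placeholder path
        rig_path="assets/character_rig.txt",  # placeholder path
        device=device,
    )
    print(f"Loaded {model.vertices[0].shape[0]} vertices and {len(model.parent_indices)} joints.")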