Spaces:
Paused
Paused
| import os | |
| import re | |
| import sys | |
| import shutil | |
| import subprocess | |
| from pathlib import Path | |
| from typing import List, Optional | |
| import importlib.util | |
| import requests | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel, HttpUrl, Field | |
| # ----------------------------------------------------------------------------- # | |
| # ๐ง ํ๊ฒฝ ๊ณ ์ : libgomp ๊ฒฝ๊ณ /์๋ฌ ํํผ (invalid OMP_NUM_THREADS) | |
| # ----------------------------------------------------------------------------- # | |
| # ์ผ๋ถ ์ปจํ ์ด๋์์ OMP_NUM_THREADS๊ฐ ๋น์ด์๊ฑฐ๋ ์๋ชป ๋ค์ด๊ฐ๋ฉด libgomp๊ฐ ์๋ฌ๋ฅผ ๋ ๋๋ค. | |
| # ์์ ํ๊ฒ ์ ์๊ฐ์ผ๋ก ๊ฐ์ ์ธํ ํฉ๋๋ค. | |
| os.environ["OMP_NUM_THREADS"] = os.environ.get("OMP_NUM_THREADS", "4") | |
| if not os.environ["OMP_NUM_THREADS"].isdigit(): | |
| os.environ["OMP_NUM_THREADS"] = "4" | |
# ----------------------------------------------------------------------------- #
# Runtime dependency bootstrap (tqdm, einops, scipy, trimesh, ...)
# - Safety net for packages left out of requirements/Dockerfile: checked once
#   when the server starts and installed if absent.
# ----------------------------------------------------------------------------- #
RUNTIME_DEPS = [
    # core
    "tqdm", "einops", "scipy", "trimesh",
    # added later
    "accelerate", "timm",
    # spare packages (filled in automatically when imports fail)
    "networkx", "scikit-image",
]
| def _need_install(mod_name: str) -> bool: | |
| return importlib.util.find_spec(mod_name) is None | |
| def _pip_install(pkgs: List[str]) -> None: | |
| if not pkgs: | |
| return | |
| try: | |
| subprocess.check_call([sys.executable, "-m", "pip", "install", *pkgs]) | |
| except Exception as e: | |
| print(f"[deps] pip install failed for {pkgs}: {e}") | |
def _ensure_runtime_deps() -> None:
    """Install runtime dependencies that are missing from the image.

    Called once at import time as a safety net for packages left out of
    requirements/Dockerfile. Every step is best-effort: problems are printed,
    never raised.
    """
    # numpy 2.x may conflict with scipy and similar packages, so try to pin
    # numpy<2. NOTE: this only affects later interpreter starts; the numpy
    # module already imported into this process stays at its current version.
    try:
        import numpy as _np

        major = int(_np.__version__.split(".", 1)[0])
        if major >= 2:
            print(f"[deps] numpy=={_np.__version__} detected; attempting to pin <2.0")
            _pip_install(["numpy<2"])
    except Exception as e:
        print(f"[deps] numpy check failed: {e}")

    # Fill in the required modules.
    missing = [m for m in RUNTIME_DEPS if _need_install(m)]
    if missing:
        print(f"[deps] installing missing modules: {missing}")
        _pip_install(missing)
        # Import finders cache directory listings; without this, packages pip
        # just installed may still be reported as MISSING below.
        importlib.invalidate_caches()

    # Final status log.
    for m in RUNTIME_DEPS:
        print(f"[deps] {m} -> {'OK' if not _need_install(m) else 'MISSING'}")


_ensure_runtime_deps()
# ----------------------------------------------------------------------------- #
# FastAPI initialization
# ----------------------------------------------------------------------------- #
# The FastAPI application instance for this service.
app = FastAPI(
    version="1.0.0",
    title="Puppeteer API",
)
# ----------------------------------------------------------------------------- #
# Settings
# ----------------------------------------------------------------------------- #
PUPPETEER_SRC = Path(os.environ.get("PUPPETEER_DIR", "/app/Puppeteer"))  # read-only original checkout
PUPPETEER_RUN = Path(os.environ.get("PUPPETEER_RUN", "/tmp/puppeteer_run"))  # writable copy used for runs
RESULT_DIR = Path(os.environ.get("RESULT_DIR", str(PUPPETEER_RUN / "results")))  # default rig output dir
TMP_IN_DIR = Path(os.environ.get("TMP_IN_DIR", "/tmp/in"))  # where downloaded inputs are stored
DOWNLOAD_TIMEOUT = int(os.environ.get("DOWNLOAD_TIMEOUT", "180"))  # seconds, used as the requests timeout
MAX_DOWNLOAD_MB = int(os.environ.get("MAX_DOWNLOAD_MB", "512"))
SAFE_NAME = re.compile(r"[^A-Za-z0-9._-]+")  # runs of these chars are replaced when sanitizing names

# Candidate directories searched broadly for animation/rigging results; the
# download endpoint also uses them as its allow-list. With the default
# settings RESULT_DIR *is* /tmp/puppeteer_run/results, so duplicates are
# dropped (first occurrence kept, order preserved) to avoid listing every
# result twice.
RESULT_BASES = list(
    dict.fromkeys(
        [
            Path("/app/Puppeteer/results"),
            RESULT_DIR,
            Path("/data/results"),
            Path("/tmp/puppeteer_run/results"),
        ]
    )
)
# ----------------------------------------------------------------------------- #
# Auto-download checkpoints (downloaded at runtime when missing)
# ----------------------------------------------------------------------------- #
ckpt_path = Path("/app/Puppeteer/checkpoints")
_RIG_CKPT_URL = "https://github.com/ByteDance-Seed/Puppeteer/releases/download/v1.0.0/rig.ckpt"


def _download_checkpoints_if_missing() -> None:
    """Best-effort download of Puppeteer checkpoints into ``ckpt_path``.

    Does nothing when ``ckpt_path`` is a non-empty directory. Otherwise runs
    ``scripts/download_ckpt.sh``; if that fails, fetches ``rig.ckpt`` from the
    GitHub release with ``wget``. Failures are printed, never raised, so the
    server still starts without checkpoints.
    """
    # is_dir() first: iterdir() on an existing *file* would raise at import.
    if ckpt_path.is_dir() and any(ckpt_path.iterdir()):
        return

    try:
        print("[init] checkpoints missing โ trying runtime download via script...")
        subprocess.run(
            ["bash", "-lc", "cd /app/Puppeteer && ./scripts/download_ckpt.sh"],
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
        print("[init] Puppeteer checkpoints downloaded successfully via script.")
        return
    except subprocess.CalledProcessError as e:
        print("[init] WARNING: checkpoint script failed:", e)
        # The script's combined output was captured; show its tail, otherwise
        # the reason for the failure is lost.
        if e.stdout:
            print(e.stdout[-2000:])
    except Exception as e:
        print("[init] WARNING: checkpoint script failed:", e)

    # Download into a ".part" file next to (not inside) ckpt_path and rename
    # only on success: `wget -O` creates its output file even when the
    # download fails, and a leftover empty/truncated rig.ckpt would make the
    # emptiness check above skip the download on every later start.
    partial = ckpt_path.parent / "rig.ckpt.part"
    try:
        ckpt_path.mkdir(parents=True, exist_ok=True)
        print("[init] trying manual download from GitHub release...")
        subprocess.run(["wget", "-O", str(partial), _RIG_CKPT_URL], check=True)
        partial.replace(ckpt_path / "rig.ckpt")
        print("[init] rig.ckpt downloaded manually.")
    except Exception as e2:
        print("[init] WARNING: manual checkpoint download failed:", e2)
        try:
            partial.unlink()
        except OSError:
            pass


_download_checkpoints_if_missing()
# ----------------------------------------------------------------------------- #
# Schemas
# ----------------------------------------------------------------------------- #
# Request body for rigging: the mesh to fetch, plus an optional work-directory
# name that the rig handler currently accepts but does not use.
class RigIn(BaseModel):
    mesh_url: HttpUrl = Field(default=..., description="Input mesh URL (obj/glb/fbx/โฆ)")
    workdir: Optional[str] = Field(None, description="Optional work directory name")
# Response body of the rig handler.
class RigOut(BaseModel):
    status: str  # "ok" on success
    result_dir: Optional[str] = Field(default=None)  # str(RESULT_DIR)
    files_preview: Optional[List[str]] = Field(default=None)  # a few discovered result files
    detail: Optional[str] = Field(default=None)  # script log tail or "no result files found"
    gpu: Optional[bool] = Field(default=None)  # CUDA availability as reported by torch
    gpu_name: Optional[str] = Field(default=None)  # name of CUDA device 0, if any
# Request body for video-driven animation.
class AnimateIn(BaseModel):
    video_url: HttpUrl = Field(default=..., description="Input video URL (mp4, mov, etc.)")
    # Forwarded to demo_animation.sh as its --mesh argument.
    mesh_path: Optional[str] = Field(
        "/app/Puppeteer/results/rigged.glb",
        description="Path to rigged mesh",
    )
# ----------------------------------------------------------------------------- #
# Utils
# ----------------------------------------------------------------------------- #
def ensure_dirs() -> None:
    """Create the input, run-tree and result directories when they are absent."""
    for directory in (TMP_IN_DIR, PUPPETEER_RUN, RESULT_DIR):
        directory.mkdir(parents=True, exist_ok=True)
def prepare_run_tree() -> None:
    """Sync the read-only Puppeteer checkout into the writable run directory.

    Copies ``PUPPETEER_SRC`` over ``PUPPETEER_RUN`` (merging into an existing
    tree) and makes ``demo_rigging.sh`` executable. This runs on every rig
    request, so files whose size and modification time already match the
    source are not copied again. ``shutil.copy2`` preserves mtimes, so after
    the first sync only changed files are copied instead of the whole tree
    (including its checkpoints directory) each time.

    Raises:
        HTTPException: 500 when ``PUPPETEER_SRC`` does not exist.
    """
    if not PUPPETEER_SRC.exists():
        raise HTTPException(status_code=500, detail=f"Puppeteer not found: {PUPPETEER_SRC}")

    def _copy_if_changed(src: str, dst: str) -> str:
        # copytree copy_function: skip files whose size and mtime (to the
        # second) match; copy everything else with copy2.
        try:
            src_stat, dst_stat = os.stat(src), os.stat(dst)
        except OSError:
            return shutil.copy2(src, dst)
        if src_stat.st_size == dst_stat.st_size and int(src_stat.st_mtime) == int(dst_stat.st_mtime):
            return dst
        return shutil.copy2(src, dst)

    shutil.copytree(PUPPETEER_SRC, PUPPETEER_RUN, dirs_exist_ok=True, copy_function=_copy_if_changed)
    script = PUPPETEER_RUN / "demo_rigging.sh"
    if script.exists():
        script.chmod(0o755)
def safe_basename(url: str) -> str:
    """Derive a filesystem-safe file name from the last path segment of *url*.

    The query string and fragment are ignored and percent-escapes are decoded
    before unsafe characters are replaced via ``SAFE_NAME``. Returns
    ``"input_mesh"`` when nothing usable remains. This includes names made only
    of dots (``"."``/``".."``), which would resolve to a directory when joined
    onto a path.
    """
    from urllib.parse import unquote, urlsplit

    path = unquote(urlsplit(url).path)
    name = SAFE_NAME.sub("_", os.path.basename(path))
    if not name.strip("."):
        return "input_mesh"
    return name
def download_with_limit(url: str, dst: Path) -> None:
    """Stream *url* into *dst*, refusing bodies larger than MAX_DOWNLOAD_MB.

    Raises:
        HTTPException: 413 "File too large" when the advertised
            Content-Length or the number of bytes actually received exceeds
            the limit.
        requests.RequestException: on connection errors, timeouts
            (DOWNLOAD_TIMEOUT seconds) or 4xx/5xx responses.

    If the download fails after *dst* was opened, the partially written file
    is removed so a truncated input is never left behind.
    """
    limit = MAX_DOWNLOAD_MB * 1024 * 1024
    opened = False
    try:
        with requests.get(url, stream=True, timeout=DOWNLOAD_TIMEOUT) as r:
            r.raise_for_status()
            # Fail fast when the server already announces an oversized body;
            # the streamed byte count below remains the authoritative check.
            declared = r.headers.get("Content-Length", "")
            if declared.isascii() and declared.isdigit() and int(declared) > limit:
                raise HTTPException(status_code=413, detail="File too large")
            total = 0
            with open(dst, "wb") as f:
                opened = True
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if not chunk:
                        continue
                    total += len(chunk)
                    if total > limit:
                        raise HTTPException(status_code=413, detail="File too large")
                    f.write(chunk)
    except Exception:
        # Only delete a file this call created/truncated.
        if opened:
            try:
                dst.unlink()
            except OSError:
                pass
        raise
def torch_info() -> tuple[bool, Optional[str]]:
    """Report CUDA availability and, when available, the name of device 0.

    Any failure (torch not installed, CUDA query error) yields ``(False, None)``.
    """
    try:
        import torch

        if not torch.cuda.is_available():
            return False, None
        return True, torch.cuda.get_device_name(0)
    except Exception:
        return False, None
def scan_results(limit: int = 200, bases: Optional[List[Path]] = None) -> List[str]:
    """Collect result-like files (meshes, videos, images, json/txt) under base dirs.

    Args:
        limit: maximum number of paths returned; ``<= 0`` returns ``[]``.
        bases: directories to search recursively; defaults to ``RESULT_BASES``.
            Directories that do not exist are skipped.

    Returns:
        File paths as strings, grouped by base and then by extension in the
        order of ``exts``. A file reachable from several bases (duplicate or
        nested entries) is listed only once.
    """
    if limit <= 0:
        return []
    search = RESULT_BASES if bases is None else bases
    exts = ("*.glb", "*.mp4", "*.fbx", "*.obj", "*.gltf", "*.png", "*.jpg", "*.json", "*.txt")
    files: List[str] = []
    seen = set()
    for base in map(Path, search):
        if not base.exists():
            continue
        for ext in exts:
            for p in base.rglob(ext):
                if not p.is_file():
                    continue
                key = p.resolve()
                if key in seen:
                    continue
                seen.add(key)
                files.append(str(p))
                if len(files) >= limit:
                    return files
    return files
# ----------------------------------------------------------------------------- #
# Routes
# NOTE(review): no route decorators (e.g. @app.get) are visible on the handlers
# in this copy of the file; confirm how they are registered on `app`.
# ----------------------------------------------------------------------------- #
def root():
    """Return a static payload identifying the service as ready."""
    payload = {"status": "ready", "service": "puppeteer-api"}
    return payload
def health():
    """Return an ok status together with CUDA availability and the GPU name."""
    cuda_ok, device_name = torch_info()
    return {"status": "ok", "cuda": cuda_ok, "gpu": device_name}
def rig(inp: RigIn):
    """Download a mesh and run Puppeteer's ``demo_rigging.sh`` on it.

    Creates the working directories, syncs the run tree, downloads
    ``inp.mesh_url`` into TMP_IN_DIR (size-limited), runs the script from
    PUPPETEER_RUN and previews result files found under RESULT_BASES.

    Returns:
        RigOut with up to 10 previewed files. ``detail`` holds the last 4000
        characters of the script log when any result file was found,
        otherwise ``"no result files found"``. The preview comes from a scan of
        all result bases, so it may include files from earlier runs.

    Raises:
        HTTPException: the downloader's own error unchanged (e.g. 413 "File
            too large"); 400 for other download errors; 500 when the script
            is missing, cannot be executed, or exits non-zero.
    """
    ensure_dirs()
    prepare_run_tree()

    mesh_url = str(inp.mesh_url)
    mesh_path = TMP_IN_DIR / safe_basename(mesh_url)
    # inp.workdir is accepted by the schema but not used yet (reserved).

    try:
        download_with_limit(mesh_url, mesh_path)
    except HTTPException:
        # Keep the downloader's status code (413 for oversized files) instead
        # of collapsing it into a generic 400.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Download error: {e}") from e

    script = PUPPETEER_RUN / "demo_rigging.sh"
    if not script.exists():
        # `bash <missing file>` exits non-zero instead of raising
        # FileNotFoundError, so the missing script must be detected up front.
        raise HTTPException(status_code=500, detail="demo_rigging.sh not found")

    cmd = ["bash", str(script), str(mesh_path)]
    try:
        proc = subprocess.run(
            cmd,
            cwd=str(PUPPETEER_RUN),
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        snippet = (e.stdout or "")[-2000:]
        raise HTTPException(status_code=500, detail=f"Puppeteer failed: {snippet}") from e
    except FileNotFoundError as e:
        # Only reachable when the `bash` executable itself is missing.
        raise HTTPException(status_code=500, detail=f"Executable not found: {e.filename}") from e
    run_log = proc.stdout[-4000:]

    preview = scan_results(limit=20)
    gpu, gpu_name = torch_info()
    return RigOut(
        status="ok",
        result_dir=str(RESULT_DIR),
        files_preview=preview[:10],
        detail=run_log if preview else "no result files found",
        gpu=gpu,
        gpu_name=gpu_name,
    )
def animate(inp: AnimateIn):
    """Run Puppeteer's ``demo_animation.sh`` (video-driven animation).

    Inputs: ``inp.video_url`` (e.g. mp4) and ``inp.mesh_path`` (the rigged
    mesh; ``/app/Puppeteer/results/rigged.glb`` by default or when null).

    The video is fetched with ``download_with_limit``, which applies the same
    size limit and timeout as mesh downloads. It is stored in a per-request
    temporary ``.mp4`` file, so concurrent requests do not overwrite one
    shared ``/tmp/video.mp4``. The temporary file is deleted once the script
    has finished.

    Returns:
        dict with ``status``, ``video_used``, the last 4000 characters of the
        script output (``detail``) and up to 10 result files
        (``files_preview``; a global scan, so older results may appear).

    Raises:
        HTTPException: 404 if the script is missing; the downloader's own
            error (e.g. 413) or 400 on download failure; 500 when the script
            fails or cannot be started.
    """
    import tempfile

    pdir = Path("/app/Puppeteer")
    script = pdir / "demo_animation.sh"
    if not script.exists():
        raise HTTPException(status_code=404, detail="demo_animation.sh not found")

    # A client may send "mesh_path": null; fall back to the schema default
    # rather than passing the literal string "None" to the script.
    mesh_path = inp.mesh_path or "/app/Puppeteer/results/rigged.glb"

    fd, tmp_name = tempfile.mkstemp(prefix="video_", suffix=".mp4")
    os.close(fd)
    video_path = Path(tmp_name)
    try:
        # -------- video download (size-limited, via requests) -------- #
        try:
            print(f"[animate] downloading video from {inp.video_url}")
            download_with_limit(str(inp.video_url), video_path)
            print(f"[animate] Video saved to {video_path}")
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Video download failed via requests: {e}"
            ) from e

        # -------- run the Puppeteer animation -------- #
        cmd = [
            "bash", str(script),
            "--mesh", str(mesh_path),
            "--video", str(video_path),
        ]
        try:
            proc = subprocess.run(
                cmd,
                cwd=str(pdir),
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
            )
        except subprocess.CalledProcessError as e:
            # Guard against a missing captured output, as rig() does.
            raise HTTPException(
                status_code=500, detail=f"Animation failed: {(e.stdout or '')[-2000:]}"
            ) from e
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Unexpected error: {e}") from e
        output = proc.stdout[-4000:]
    finally:
        try:
            video_path.unlink()
        except OSError:
            pass

    anim_results = scan_results(limit=20)
    return {
        "status": "ok",
        "video_used": str(inp.video_url),
        "detail": output,
        "files_preview": anim_results[:10],
    }
# -------- Result file listing / download utilities -------- #
def list_results():
    """List up to 500 result files found under the result base directories."""
    found = scan_results(limit=500)
    return {"count": len(found), "files": found}
def download(path: str):
    """Serve a result file, restricted to files inside ``RESULT_BASES``.

    The path is resolved first (collapsing ``..`` and following symlinks) and
    must then lie *inside* one of the base directories. A plain string-prefix
    test is not enough: ``/data/results_private/x`` starts with
    ``/data/results`` yet is outside it.

    Raises:
        HTTPException: 400 for paths outside the allowed bases; 404 when the
            path is not an existing regular file.
    """
    p = Path(path).resolve()
    # Only allow paths inside a known result directory.
    if not any(p.is_relative_to(b.resolve()) for b in RESULT_BASES):
        raise HTTPException(status_code=400, detail="invalid path")
    if not p.is_file():
        raise HTTPException(status_code=404, detail="file not found")
    return FileResponse(str(p), filename=p.name)
def debug():
    """Return a snapshot of the Puppeteer install and the effective settings.

    Reports whether ``demo_rigging.sh``, ``checkpoints/`` and
    ``requirements.txt`` exist under PUPPETEER_SRC (the tree copied into the
    run directory). It also includes up to 15 sample paths from the
    checkpoints tree, the input/result directories in use, and
    OMP_NUM_THREADS.
    """
    from itertools import islice

    pdir = PUPPETEER_SRC
    script = pdir / "demo_rigging.sh"
    ckpt_dir = pdir / "checkpoints"
    req_file = pdir / "requirements.txt"
    return {
        "script_exists": script.exists(),
        "ckpt_dir_exists": ckpt_dir.exists(),
        "req_exists": req_file.exists(),
        # islice stops the recursive glob after 15 hits instead of listing the
        # whole checkpoint tree before slicing.
        "ckpt_samples": [str(p) for p in islice(ckpt_dir.glob("**/*"), 15)],
        # Report the directories actually in use. Re-reading the environment
        # with "/data/..." fallbacks disagreed with the real defaults
        # (/tmp/in and PUPPETEER_RUN/results).
        "tmp_in": str(TMP_IN_DIR),
        "result_dir": str(RESULT_DIR),
        "omp_num_threads": os.environ.get("OMP_NUM_THREADS"),
    }