Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import argparse | |
| from huggingface_hub import snapshot_download | |
| def download_ckpt(): | |
| parser = argparse.ArgumentParser(description="Download checkpoints from HuggingFace Hub") | |
| parser.add_argument( | |
| "--local_dir", | |
| type=str, | |
| default="./out", | |
| help="Local directory to save the checkpoints" | |
| ) | |
| parser.add_argument( | |
| "--model_type", | |
| type=str, | |
| default="sd15", | |
| choices=["sd15", "pas", "sd35m", "depth", "normal", "canny", "elevest"], | |
| help="Model type to download" | |
| ) | |
| parser.add_argument( | |
| "--image_cond", | |
| action="store_true", | |
| help="Whether to download image-conditioned models" | |
| ) | |
| args = parser.parse_args() | |
| repo_id, local_dir = "chenguolin/DiffSplat", args.local_dir | |
| os.makedirs(local_dir, exist_ok=True) | |
| model_type, image_cond = args.model_type, args.image_cond | |
| suffix = "_image" if image_cond else "" | |
| # DiffSplat (SD1.5) | |
| if model_type == "sd15": | |
| snapshot_download( | |
| repo_id=repo_id, | |
| local_dir=local_dir, | |
| allow_patterns=[ | |
| "gsrecon_gobj265k_cnp_even4/*", # `GSRecon` | |
| "gsvae_gobj265k_sd/*", # `GSVAE (SD)` | |
| f"gsdiff_gobj83k_sd15{suffix}__render/*", # `DiffSplat (SD)` | |
| ] | |
| ) | |
| # DiffSplat (PixArt-Sigma) | |
| elif model_type == "pas": | |
| snapshot_download( | |
| repo_id=repo_id, | |
| local_dir=local_dir, | |
| allow_patterns=[ | |
| "gsrecon_gobj265k_cnp_even4/*", # `GSRecon` | |
| "gsvae_gobj265k_sdxl_fp16/*", # `GSVAE (SDXL)` | |
| f"gsdiff_gobj83k_pas_fp16{suffix}__render/*", # `DiffSplat (PixArt-Sigma)` | |
| ] | |
| ) | |
| # DiffSplat (SD3.5m) | |
| elif model_type == "sd35m": | |
| snapshot_download( | |
| repo_id=repo_id, | |
| local_dir=local_dir, | |
| allow_patterns=[ | |
| "gsrecon_gobj265k_cnp_even4/*", # `GSRecon` | |
| "gsvae_gobj265k_sd3/*", # `GSVAE (SD3)` | |
| f"gsdiff_gobj83k_sd35m{suffix}__render/*", # `DiffSplat (SD3.5m)` | |
| ] | |
| ) | |
| # DiffSplat ControlNet (SD1.5) | |
| elif model_type in ["depth", "normal", "canny"]: | |
| snapshot_download( | |
| repo_id=repo_id, | |
| local_dir=local_dir, | |
| allow_patterns=[ | |
| f"gsdiff_gobj83k_sd15__render__{model_type}/*", # `DiffSplat ControlNet (SD1.5)` | |
| ] | |
| ) | |
| # Elevation Estimation | |
| elif model_type == "elevest": | |
| snapshot_download( | |
| repo_id=repo_id, | |
| local_dir=local_dir, | |
| allow_patterns=[ | |
| "elevest_gobj265k_b_C25/*", | |
| ] | |
| ) | |
| else: | |
| raise ValueError(f"Choose from ['sd15', 'pas', 'sd35m', 'depth', 'normal', 'canny', 'elevest'], but got [{model_type}]") | |
if __name__ == "__main__":
    # Run as a script: parse the command line and fetch the requested checkpoints.
    download_ckpt()