Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import shlex | |
| import subprocess | |
| import imageio | |
| import numpy as np | |
| import gradio as gr | |
| import spaces | |
| import sys | |
| from loguru import logger | |
# Absolute directory that contains this script; outputs are written beneath it.
current_path = os.path.split(os.path.abspath(__file__))[0]

# Disabled: install the rasterizer extension at start-up when it is missing.
# try:
#     import diff_gaussian_rasterization  # noqa: F401
# except ImportError:
#     @spaces.GPU
#     def install_diff_gaussian_rasterization():
#         os.system("pip install ./extensions/RaDe-GS/submodules/diff-gaussian-rasterization")
#     install_diff_gaussian_rasterization()

# Largest value representable by a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Working directory for uploaded images and inference results.
TMP_DIR = os.path.join(current_path, 'out')
os.makedirs(TMP_DIR, exist_ok=True)

# Output tags per backbone, as [text-conditioned tag, image-conditioned tag].
TAG = {
    # Best efficiency
    "SD15": [
        "gsdiff_gobj83k_sd15__render",
        "gsdiff_gobj83k_sd15_image__render",
    ],
    "PixArt-Sigma": [
        "gsdiff_gobj83k_pas_fp16__render",
        "gsdiff_gobj83k_pas_fp16_image__render",
    ],
    # Best performance
    "SD3": [
        "gsdiff_gobj83k_sd35m__render",
        "gsdiff_gobj83k_sd35m_image__render",
    ],
}

# Key into TAG selecting the backbone used by this app.
MODEL_TYPE = "PixArt-Sigma"
def _download_checkpoints(model_type):
    """Run ``download_ckpt.py`` for the text- and image-conditioned checkpoints.

    A failed download is logged as a warning instead of being silently
    ignored; start-up continues either way, as it did before.
    """
    for extra_args in ([], ["--image_cond"]):  # txt condition first, then img condition
        command = ["python3", "download_ckpt.py", "--model_type", model_type, *extra_args]
        completed = subprocess.run(command)
        if completed.returncode != 0:
            logger.warning(
                "Checkpoint download exited with code {}: {}",
                completed.returncode,
                " ".join(command),
            )


# for PixArt-Sigma
_download_checkpoints("pas")

# Shell command templates filled in with str.format() by `process`.
# Placeholders, in order: tag, guidance scale, image path, elevation, prompt, seed.
img_commands = (
    "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_pas.py configs/gsdiff_pas.yaml {} "
    "--rembg_and_center --triangle_cfg_scaling --save_ply --output_video_type mp4 --guidance_scale {} "
    "--image_path {} --elevation {} --prompt {} --seed {}"
)
# Placeholders, in order: tag, prompt, seed.
txt_commands = (
    "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_pas.py configs/gsdiff_pas.yaml {} "
    "--save_ply --output_video_type mp4 "
    "--prompt {} --seed {}"
)

# for SD1.5
# NOTE(review): the SD1.5 text template below takes five placeholders, while
# `process` formats txt_commands with three values; align them before enabling.
# subprocess.run(shlex.split("python3 download_ckpt.py --model_type sd15"))  # for txt condition
# subprocess.run(shlex.split("python3 download_ckpt.py --model_type sd15 --image_cond"))  # for img condition
# img_commands = "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_sd.py configs/gsdiff_sd15.yaml {} \
#     --rembg_and_center --triangle_cfg_scaling --save_ply --output_video_type mp4 --guidance_scale {} \
#     --image_path {} --elevation {} --prompt {} --seed {}"
# txt_commands = "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_sd.py configs/gsdiff_sd15.yaml {} \
#     --save_ply --output_video_type mp4 --guidance_scale {} \
#     --elevation {} --prompt {} --seed {}"
# process function
def process(input_image, prompt='a_high_quality_3D_asset', prompt_neg='poor_quality', input_elevation=20, guidance_scale=2., input_seed=0):
    """Run the inference script and return the generated results.

    When ``input_image`` is given, the image-conditioned command is used;
    otherwise the text-conditioned command is used, which takes neither
    ``guidance_scale`` nor ``input_elevation``.

    Args:
        input_image: Image object with a ``save(path)`` method, or ``None``
            for text-only generation.
        prompt: Text prompt. Whitespace is converted to underscores before
            it is placed on the command line.
        prompt_neg: Negative prompt. Accepted for interface compatibility but
            currently not forwarded to the inference command.
        input_elevation: Value for ``--elevation`` (image mode only).
        guidance_scale: Value for ``--guidance_scale`` (image mode only).
        input_seed: Value for ``--seed``.

    Returns:
        A tuple ``(output_image, output_video_path, output_ply_path)`` where
        ``output_image`` is the array read from the ``*_gs.png`` result.

    Raises:
        gr.Error: If the inference command exits with a non-zero status.
    """
    import uuid

    # fail to install RaDe-GS
    # subprocess.run("cd extensions/RaDe-GS/submodules && pip3 install diff-gaussian-rasterization", shell=True)
    # subprocess.run("cd extensions/RaDe-GS/submodules/diff-gaussian-rasterization && python3 setup.py bdist_wheel ", shell=True)

    # The prompt is user input and ends up in a shell command line. Collapse
    # whitespace to the underscore form the app already uses so it stays one
    # argument, and quote every interpolated value against shell injection.
    cli_prompt = "_".join((prompt or "").split())

    if input_image is not None:
        image_path = os.path.join(TMP_DIR, f"{uuid.uuid4()}.png")
        image_name = os.path.splitext(os.path.basename(image_path))[0] + "_rgba"
        input_image.save(image_path)
        TAG_DEST = TAG[MODEL_TYPE][1]
        full_command = img_commands.format(
            TAG_DEST,
            shlex.quote(str(guidance_scale)),
            shlex.quote(image_path),
            shlex.quote(str(input_elevation)),
            shlex.quote(cli_prompt),
            shlex.quote(str(input_seed)),
        )
    else:
        TAG_DEST = TAG[MODEL_TYPE][0]
        # without guidance_scale and input_elevation
        full_command = txt_commands.format(
            TAG_DEST,
            shlex.quote(cli_prompt),
            shlex.quote(str(input_seed)),
        )
        image_name = ""

    exit_status = os.system(full_command)
    if exit_status != 0:
        # Without this check a failed run would surface as a confusing
        # missing-file error, or return stale files from an earlier run.
        logger.error("Inference command failed with status {}: {}", exit_status, full_command)
        raise gr.Error("Generation failed, please try again.")

    # save video and ply files
    ckpt_dir = os.path.join(TMP_DIR, TAG_DEST, "checkpoints")
    infer_from_iter = int(sorted(os.listdir(ckpt_dir))[-1])

    # The file stem below is kept exactly as before; it presumably mirrors the
    # naming used by the inference script -- confirm before changing it.
    MAX_NAME_LEN = 20  # TODO: make `20` configurable
    display_prompt = cli_prompt.replace("_", " ")
    prompt_name = display_prompt[:MAX_NAME_LEN] + "..." if display_prompt[:MAX_NAME_LEN] != "" else display_prompt
    name = f"[{image_name}]_[{prompt_name}]_{infer_from_iter:06d}"

    inference_dir = os.path.join(TMP_DIR, TAG_DEST, "inference")
    output_video_path = os.path.join(inference_dir, name + ".mp4")
    output_ply_path = os.path.join(inference_dir, name + ".ply")
    output_img_path = os.path.join(inference_dir, name + "_gs.png")

    # loguru formats the first argument with str.format(); use an explicit
    # template so the paths are actually logged and braces in the command
    # cannot break formatting.
    logger.info("{} -> video: {}, ply: {}", full_command, output_video_path, output_ply_path)

    output_image = imageio.imread(output_img_path)
    return output_image, output_video_path, output_ply_path
# gradio UI
# Page title; also rendered as the top-level heading.
_TITLE = '''DiffSplat: Repurposing Image Diffusion Models for Scalable Gaussian Splat Generation'''

# Markdown/HTML shown below the title: status banner, project links and usage notes.
_DESCRIPTION = '''
<strong style="color:red;">This space is currently under maintenance. If you experience slow or failed generations, please try again. We apologize for any inconvenience.</strong>

### If you find our work helpful, please consider citing our paper or giving the repo a star

<div>
<a style="display:inline-block; margin-left: .5em" href="https://chenguolin.github.io/projects/DiffSplat"><img src='https://img.shields.io/badge/Project-Page-brightgreen'/></a>
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2501.16764"><img src='https://img.shields.io/badge/arXiv-2501.16764-b31b1b.svg?logo=arXiv'/></a>
<a style="display:inline-block; margin-left: .5em" href="https://github.com/chenguolin/DiffSplat"><img src='https://img.shields.io/github/stars/chenguolin/DiffSplat?style=social'/></a>
<a style="display:inline-block; margin-left: .5em" href="https://huggingface.co/chenguolin/DiffSplat"><img src='https://img.shields.io/badge/HF-Model-yellow'/></a>
</div>

* Input can be only text, only image, or both image and text.
* If you find the generated 3D asset satisfactory, download it from the "3D Gaussians (ply format)" file output.
* Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background.
'''
# NOTE(review): the indentation of this section was reconstructed; confirm the
# nesting of the description, click handler and examples against the intended layout.
block = gr.Blocks(title=_TITLE)
with block:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + _TITLE)
    gr.Markdown(_DESCRIPTION)

    with gr.Row(variant='panel'):
        with gr.Column(scale=1):
            # input image
            input_image = gr.Image(label="image", type='pil')
            # input prompt
            input_text = gr.Textbox(label="prompt", value="a_high_quality_3D_asset")
            # negative prompt
            input_neg_text = gr.Textbox(label="negative prompt", value="ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate")
            # guidance_scale
            guidance_scale = gr.Slider(label="guidance scale", minimum=1., maximum=7.5, step=0.5, value=2.0)
            # elevation
            input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=10)
            # random seed
            input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
            # gen button
            button_gen = gr.Button("Generate")

        # Gradio expects an integer `scale`; the former 0.8 is rounded to 1.
        with gr.Column(scale=1):
            with gr.Tab("Video"):
                # final video results
                output_video = gr.Video(label="video")
                # ply file
                output_file = gr.File(label="3D Gaussians (ply format)")
            with gr.Tab("Splatter Images"):
                output_image = gr.Image(interactive=False, show_label=False)

        # The input order must match the positional parameters of `process`.
        button_gen.click(
            process,
            inputs=[input_image, input_text, input_neg_text, input_elevation, guidance_scale, input_seed],
            outputs=[output_image, output_video, output_file],
        )

        gr.Examples(
            examples=[
                [f'assets/diffsplat/{image}', "a_high_quality_3D_asset"]
                # sorted() keeps the gallery order stable; os.listdir order is arbitrary.
                for image in sorted(os.listdir("assets/diffsplat")) if image.endswith('.png')
            ],
            inputs=[input_image, input_text],
            label='Image-to-3D Examples'
        )

        gr.Examples(
            examples=[
                ["a_toy_robot", None],
                ["a_cute_panda", None],
                ["an_ancient_leather-bound_book", None]
            ],
            inputs=[input_text, input_image],
            label='Text-to-3D Examples'
        )

# Launch the Gradio app
if __name__ == "__main__":
    block.launch(share=True)