File size: 5,764 Bytes
4724018
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import argparse
import rembg
import numpy as np
import math
import torch
import json

from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers import AutoencoderKL, EulerDiscreteScheduler
from diffusers.utils import export_to_gif
from diffusers_sv3d import SV3DUNetSpatioTemporalConditionModel, StableVideo3DDiffusionPipeline
from kiui.cam import orbit_camera

# Hugging Face Hub repository that hosts the SV3D weights in diffusers layout.
SV3D_DIFFUSERS = "chenguolin/sv3d-diffusers"

# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# Environment variables are not shell-expanded, so resolve "~" here instead of
# relying on every consumer of HF_HOME to expand it.
# NOTE(review): this assignment runs after transformers/diffusers have been
# imported; confirm the Hugging Face libraries still honour it at this point.
os.environ["HF_HOME"] = os.path.expanduser("~/.cache/huggingface")

def construct_camera(azimuths_rad, elevation_rad, output_dir, res=576, radius=2, fov=33.8):
    """Write ``transforms_{train,val,test}.json`` camera files for a set of views.

    One frame entry is produced per (elevation, azimuth) pair, so the number
    of frames follows the length of the inputs instead of a fixed count.
    The same content is written to all three files, which are placed in the
    parent directory of ``output_dir``.

    Args:
        azimuths_rad: Per-view azimuth angles, in radians.
        elevation_rad: Per-view elevation angles, in radians. Paired with
            ``azimuths_rad`` element by element; extra entries in the longer
            sequence are ignored.
        output_dir: Directory of the rendered frames. It must already exist,
            because the JSON files are addressed as ``{output_dir}/..``.
        res: Accepted for interface compatibility; not used.
        radius: Orbit radius forwarded to ``orbit_camera``.
        fov: Field of view in degrees; stored as ``camera_angle_x`` in radians.
    """
    transforms = {
        "camera_angle_x": math.radians(fov),
        "frames": [],
    }

    for i, (elevation, azimuth) in enumerate(zip(elevation_rad, azimuths_rad)):
        transforms["frames"].append({
            "file_path": f"data/{i:03d}",
            # Angles are already in radians, hence is_degree=False.
            "transform_matrix": orbit_camera(elevation, azimuth, radius, is_degree=False).tolist(),
        })

    # The train/val/test splits all share the same set of cameras.
    for split in ("train", "val", "test"):
        with open(f"{output_dir}/../transforms_{split}.json", "w") as f:
            json.dump(transforms, f, indent=4)

def recenter(image, h_begin=100, w_begin=220, res=256):
    """Cut a ``res`` x ``res`` window out of ``image`` starting at an offset.

    Args:
        image: Anything ``np.array`` can turn into an ``(H, W[, C])`` array,
            e.g. a PIL image.
        h_begin: Row of the window's top edge in image coordinates. May be
            zero or negative, in which case the window is padded.
        w_begin: Column of the window's left edge, same convention.
        res: Side length of the window in pixels.

    Returns:
        A PIL image. When both offsets are strictly positive this is a plain
        slice of the input (numpy truncates it if the window runs past the
        image edge). Otherwise it is a ``res`` x ``res`` 4-channel uint8
        canvas, zero-filled, with the part of the image that falls inside
        the window copied in.
    """
    pixels = np.array(image)
    h_image, w_image = pixels.shape[:2]

    if h_begin > 0 and w_begin > 0:
        return Image.fromarray(pixels[h_begin:h_begin + res, w_begin:w_begin + res])

    # Padding path. The canvas is 4-channel, so the input is expected to be
    # RGBA here (NOTE(review): other modes will fail on assignment - confirm
    # callers only reach this path with RGBA images).
    canvas = np.zeros((res, res, 4), dtype=np.uint8)

    # Overlap of the requested window with the image, in image coordinates.
    src_top, src_left = max(h_begin, 0), max(w_begin, 0)
    src_bottom = min(h_begin + res, h_image)
    src_right = min(w_begin + res, w_image)

    if src_bottom > src_top and src_right > src_left:
        # Shift the overlap into canvas coordinates (window origin -> 0, 0).
        dst_top, dst_left = src_top - h_begin, src_left - w_begin
        dst_bottom = dst_top + (src_bottom - src_top)
        dst_right = dst_left + (src_right - src_left)
        canvas[dst_top:dst_bottom, dst_left:dst_right] = pixels[src_top:src_bottom, src_left:src_right]

    return Image.fromarray(canvas)

def main():
    """Generate an orbit of views of a single image with SV3D and save them.

    The input is read from ``<base-dir>/<data-name>/<data-name>.png``,
    cropped, stripped of its background with rembg and composited onto white.
    The pipeline output is written to ``<output-dir>/<data-name>/data`` as
    numbered PNG frames plus ``animation.gif``; the background-removed input
    is saved next to that directory as ``<data-name>_alpha.png``. Four of the
    frames are additionally copied to ``../LGM/workspace_test``.

    The pipeline is moved to ``"cuda"``, so a CUDA device is required.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-dir", default="../../data", type=str, help="Base dir")
    parser.add_argument("--output-dir", default="../../data", type=str, help="Output filepath")
    parser.add_argument("--data-name", default="chair", type=str, help="Data Name")
    parser.add_argument("--elevation", default=0, type=float, help="Camera elevation of the input image")
    parser.add_argument("--half-precision", action="store_true", help="Use fp16 half precision")
    parser.add_argument("--seed", default=-1, type=int, help="Random seed")
    args = parser.parse_args()

    image_path = f'{args.base_dir}/{args.data_name}/{args.data_name}.png'
    output_dir = f'{args.output_dir}/{args.data_name}/data'
    lgm_dir = "../LGM/workspace_test"
    # Create every output location up front so a missing directory cannot
    # fail the run after the expensive generation step.
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(lgm_dir, exist_ok=True)

    num_frames, sv3d_res = 20, 576
    elevations_deg = [args.elevation] * num_frames
    # Only consumed by the optional construct_camera call below.
    elevations_rad = [np.deg2rad(e) for e in elevations_deg]
    polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
    azimuths_deg = np.linspace(0, 360, num_frames + 1)[1:] % 360
    # Azimuths are expressed relative to the last view.
    azimuths_rad = [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
    # Slice assignment, because `azimuths_rad[:-1].sort()` only sorts a
    # temporary copy of the list and leaves the original untouched.
    azimuths_rad[:-1] = sorted(azimuths_rad[:-1])

    # print(f"Elevation: {elevations_rad}")
    print(f"Azimuth: {np.rad2deg(azimuths_rad)}")
    # construct_camera(azimuths_rad, elevations_rad, output_dir=output_dir)

    bg_remover = rembg.new_session()

    # Prepare the conditioning image before loading the diffusion models so a
    # missing or unreadable input fails fast.
    h_begin, w_begin, res = 180, 190, 280
    with Image.open(image_path) as source_image:
        image = recenter(source_image, h_begin, w_begin, res)
    image = rembg.remove(image, session=bg_remover)  # [H, W, 4]
    image.save(f"{output_dir}/../{args.data_name}_alpha.png")
    if len(image.split()) == 4:  # RGBA
        input_image = Image.new("RGB", image.size, (255, 255, 255))  # pure white bg
        input_image.paste(image, mask=image.split()[3])  # 3rd is the alpha channel
    else:
        input_image = image

    unet = SV3DUNetSpatioTemporalConditionModel.from_pretrained(SV3D_DIFFUSERS, subfolder="unet")
    vae = AutoencoderKL.from_pretrained(SV3D_DIFFUSERS, subfolder="vae")
    scheduler = EulerDiscreteScheduler.from_pretrained(SV3D_DIFFUSERS, subfolder="scheduler")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(SV3D_DIFFUSERS, subfolder="image_encoder")
    feature_extractor = CLIPImageProcessor.from_pretrained(SV3D_DIFFUSERS, subfolder="feature_extractor")

    pipeline = StableVideo3DDiffusionPipeline(
        image_encoder=image_encoder, feature_extractor=feature_extractor,
        unet=unet, vae=vae,
        scheduler=scheduler,
    )
    pipeline = pipeline.to("cuda")

    # float32 is not a valid CUDA autocast dtype, so full precision is
    # selected by disabling autocast rather than by passing float32.
    with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16, enabled=args.half_precision):
        video_frames = pipeline(
            input_image.resize((sv3d_res, sv3d_res)),
            height=sv3d_res,
            width=sv3d_res,
            num_frames=num_frames,
            decode_chunk_size=8,  # smaller to save memory
            polars_rad=polars_rad,
            azimuths_rad=azimuths_rad,
            generator=torch.manual_seed(args.seed) if args.seed >= 0 else None,
        ).frames[0]

    export_to_gif(video_frames, f"{output_dir}/animation.gif", fps=7)
    for i, frame in enumerate(video_frames):
        # frame = frame.resize((res, res))
        frame.save(f"{output_dir}/{i:03d}.png")

    # Four frames, five steps apart, starting from the last one, are exported
    # as views 0..3 for the LGM workspace.
    for view, frame_idx in enumerate((19, 4, 9, 14)):
        video_frames[frame_idx].save(f"{lgm_dir}/{args.data_name}_{view}.png")
    
    
# Run the generation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()