import os
import argparse

import torch
import torchvision.transforms as transforms
from moviepy.editor import VideoFileClip
from diffusers.utils import load_image, load_video

from models.pipelines import DiffusionAsShaderPipeline, CameraMotionGenerator, ObjectMotionGenerator

def load_media(media_path, max_frames=49, transform=None):
    """Load video or image frames and convert to tensor
    
    Args:
        media_path (str): Path to video or image file
        max_frames (int): Maximum number of frames to load
        transform (callable): Transform to apply to frames
        
    Returns:
        Tuple[torch.Tensor, float, bool]: Video tensor [T,C,H,W], FPS, and whether the input was a video
    """
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((480, 720)),
            transforms.ToTensor()
        ])
    
    # Determine if input is video or image based on extension
    ext = os.path.splitext(media_path)[1].lower()
    is_video = ext in ['.mp4', '.avi', '.mov']
    
    if is_video:
        frames = load_video(media_path)
        fps = len(frames) / VideoFileClip(media_path).duration
    else:
        # Handle image as single frame
        image = load_image(media_path)
        frames = [image]
        fps = 8  # Default fps for images
    
    # Ensure we have exactly max_frames
    if len(frames) > max_frames:
        frames = frames[:max_frames]
    elif len(frames) < max_frames:
        last_frame = frames[-1]
        while len(frames) < max_frames:
            frames.append(last_frame.copy())
            
    # Convert frames to tensor
    video_tensor = torch.stack([transform(frame) for frame in frames])
    
    return video_tensor, fps, is_video
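
# Example usage of load_media (hypothetical path; illustration only):
#   frames, fps, is_video = load_media("assets/example.mp4")
#   frames.shape  ->  torch.Size([49, 3, 480, 720]) after resizing and last-frame padding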

def inference(das, prompt, checkpoint_path, tracking_path, output_dir, input_path):
    """Run tracking-conditioned generation for a single input (programmatic entry point)."""
    video_tensor, fps, is_video = load_media(input_path)
    tracking_tensor, _, _ = load_media(tracking_path)
    das.apply_tracking(
        video_tensor=video_tensor,
        fps=fps,
        tracking_tensor=tracking_tensor,
        img_cond_tensor=None,
        prompt=prompt,
        checkpoint_path=checkpoint_path
    )
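
# Sketch of programmatic use (hypothetical paths; assumes a valid checkpoint directory):
#   das = DiffusionAsShaderPipeline(gpu_id=0, output_dir="outputs")
#   inference(das, "a cat running on grass", "ckpts/das", "tracking.mp4", "outputs", "input.mp4")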

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default=None, help='Path to input video/image')
    parser.add_argument('--prompt', type=str, required=True, help='Repaint prompt')
    parser.add_argument('--output_dir', type=str, default='outputs', help='Output directory')
    parser.add_argument('--gpu', type=int, default=0, help='GPU device ID')
    parser.add_argument('--checkpoint_path', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--depth_path', type=str, default=None, help='Path to depth image')
    parser.add_argument('--tracking_path', type=str, default=None, help='Path to tracking video; if provided, camera motion and object manipulation are not applied')
    parser.add_argument('--repaint', type=str, default=None, 
                       help='Path to a repainted image, or "true" to perform repainting; if not provided, the original frame is used')
    parser.add_argument('--camera_motion', type=str, default=None, help='Camera motion mode')
    parser.add_argument('--object_motion', type=str, default=None, help='Object motion mode: up/down/left/right')
    parser.add_argument('--object_mask', type=str, default=None, help='Path to object mask image (binary image)')
    parser.add_argument('--tracking_method', type=str, default="spatracker", 
                        help="Tracking method for image input (moge/spatracker); with 'moge', the first frame is extracted for video input")
    parser.add_argument('--coarse_video_path', type=str, default=None, help='Path to coarse video for object motion')
    parser.add_argument('--start_noise_t', type=int, default=10, help='Starting noise timestep; controls the strength of object motion')
    args = parser.parse_args()
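
    # Example invocation (hypothetical script name and paths; illustration only):
    #   python demo.py --input_path input.mp4 --prompt "a cat running on grass" \
    #       --checkpoint_path ckpts/das --tracking_path tracking.mp4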
    
    # Load input video/image
    video_tensor, fps, is_video = load_media(args.input_path)

    # Initialize pipeline
    das = DiffusionAsShaderPipeline(gpu_id=args.gpu, output_dir=args.output_dir)
    
    # Repaint tensor for the first frame (not generated in this script; kept as None)
    repaint_img_tensor = None

    # Generate tracking if not provided
    tracking_tensor = None
    pred_tracks = None
    cam_motion = CameraMotionGenerator(args.camera_motion)

    if args.tracking_path:
        tracking_tensor, _, _ = load_media(args.tracking_path)
    
    coarse_video = load_media(args.coarse_video_path)[0] if args.coarse_video_path else None
    
    das.apply_tracking(
        video_tensor=video_tensor,
        fps=fps,
        tracking_tensor=tracking_tensor,
        img_cond_tensor=repaint_img_tensor,
        prompt=args.prompt,
        checkpoint_path=args.checkpoint_path,
        coarse_video=coarse_video,
        start_noise_t=args.start_noise_t
    )