seungminkwak's picture
reset: clean history (purge leaked token)
08b23ce
# Copyright (c) 2021 Henrique Morimitsu
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache License 2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025.09.04
#
# Original file was released under Apache License 2.0, with the full license text
# available at https://github.com/hmorimitsu/ptlflow/blob/main/LICENSE.
#
# This modified file is released under the same license.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This module processes PNG frame sequences to generate optical flow using PTLFlow,
with support for visualization and video generation.
"""
import argparse
import os
import subprocess
import shutil
import logging
from pathlib import Path
from typing import List, Tuple, Optional, Union
import cv2 as cv
import torch
import numpy as np
from tqdm import tqdm
from third_partys.ptlflow.ptlflow.utils import flow_utils
from third_partys.ptlflow.ptlflow.utils.io_adapter import IOAdapter
import third_partys.ptlflow.ptlflow as ptlflow
class OpticalFlowProcessor:
    """Compute and save optical flow for a sequence of frames.

    Wraps a model obtained from ``ptlflow.get_model`` and offers helpers to
    load a directory of PNG frames, run the model on every consecutive frame
    pair, and write each result as a ``.flo`` file plus a PNG visualization.
    """

    def __init__(
        self,
        model_name: str = 'dpflow',
        checkpoint: str = 'sintel',
        device: Optional[str] = None,
        resize_to: Optional[Tuple[int, int]] = None
    ):
        """
        Initialize optical flow processor.

        Args:
            model_name: Name of the flow model to use
            checkpoint: Checkpoint/dataset name for the model (passed to
                ``ptlflow.get_model`` as ``ckpt_path``)
            device: Device to run on; when None, 'cuda' is used if available,
                otherwise 'cpu'
            resize_to: Optional (width, height) to resize frames on load
        """
        self.model_name = model_name
        self.checkpoint = checkpoint
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.resize_to = resize_to

        # Load the model once, move it to the target device, switch to eval mode.
        self.model = ptlflow.get_model(model_name, ckpt_path=checkpoint).to(self.device).eval()
        print(f"Loaded {model_name} model on {self.device}")

        # Built lazily on the first call to compute_optical_flow_sequence.
        self.io_adapter = None

    def load_frame_sequence(self, frames_dir: Union[str, Path]) -> Tuple[List[np.ndarray], List[Path]]:
        """
        Load PNG frame sequence from directory.

        Args:
            frames_dir: Directory containing the ``*.png`` frames.

        Returns:
            ``(frames, png_files)``: the frames as RGB arrays and the matching
            file paths, both in natural (numeric-aware) filename order.

        Raises:
            FileNotFoundError: If ``frames_dir`` does not exist.
            ValueError: If fewer than two PNG files are found, or if a file
                cannot be decoded as an image.
        """
        frames_dir = Path(frames_dir)
        if not frames_dir.exists():
            raise FileNotFoundError(f"Frames directory not found: {frames_dir}")

        png_files = list(frames_dir.glob('*.png'))
        if len(png_files) < 2:
            raise ValueError(f"Need at least 2 PNG frames, found {len(png_files)} in {frames_dir}")

        # Natural sorting so that e.g. frame_2 comes before frame_10.
        png_files.sort(key=lambda x: self._natural_sort_key(x.name))

        frames = []
        for png_path in tqdm(png_files, desc="Loading frames"):
            img_bgr = cv.imread(str(png_path), cv.IMREAD_COLOR)
            # cv.imread signals failure by returning None rather than raising;
            # fail here with the offending path instead of deep inside OpenCV.
            if img_bgr is None:
                raise ValueError(f"Failed to read image: {png_path}")
            if self.resize_to:
                # interpolation must be passed by keyword: the third positional
                # parameter of cv.resize is `dst`, not the interpolation flag.
                img_bgr = cv.resize(img_bgr, self.resize_to, interpolation=cv.INTER_LINEAR)
            frames.append(cv.cvtColor(img_bgr, cv.COLOR_BGR2RGB))
        return frames, png_files

    def _natural_sort_key(self, filename: str) -> List[Union[int, str]]:
        """Natural sorting key for filenames with numbers.

        Splits the name into alternating text / digit runs; digit runs become
        ints and text runs are lower-cased, so ordering is numeric-aware and
        case-insensitive.
        """
        import re
        return [int(text) if text.isdigit() else text.lower()
                for text in re.split('([0-9]+)', filename)]

    def compute_optical_flow_sequence(
        self,
        frames: List[np.ndarray],
        flow_vis_dir: Union[str, Path],
        flow_save_dir: Optional[Union[str, Path]] = None,
        save_visualizations: bool = True
    ) -> List[torch.Tensor]:
        """
        Compute optical flow for entire frame sequence.

        Args:
            frames: Frames in temporal order (as returned by
                ``load_frame_sequence``).
            flow_vis_dir: Directory for the PNG flow visualizations.
            flow_save_dir: Directory for the ``.flo`` files; falls back to
                ``flow_vis_dir`` when not given.
            save_visualizations: When True, both the ``.flo`` file and the PNG
                visualization are written for every frame pair; when False
                nothing is written to disk.

        Returns:
            One flow tensor per consecutive frame pair (``len(frames) - 1``
            entries), left on the model's device.

        Raises:
            ValueError: If fewer than two frames are supplied.
        """
        if len(frames) < 2:
            raise ValueError("Need at least 2 frames for optical flow")

        flow_vis_dir = Path(flow_vis_dir)
        flow_save_dir = Path(flow_save_dir) if flow_save_dir else flow_vis_dir
        if save_visualizations:
            # Make sure the output locations exist before the first write.
            flow_vis_dir.mkdir(parents=True, exist_ok=True)
            flow_save_dir.mkdir(parents=True, exist_ok=True)

        H, W = frames[0].shape[:2]
        # NOTE: the adapter is created once, from the size of the first frame
        # of the first sequence processed, and reused on later calls.
        if self.io_adapter is None:
            self.io_adapter = IOAdapter(self.model, (H, W))

        flows = []
        for i in tqdm(range(len(frames) - 1), desc="Computing optical flow"):
            frame_pair = [frames[i], frames[i + 1]]
            raw_inputs = self.io_adapter.prepare_inputs(frame_pair)
            imgs = raw_inputs['images'][0]  # expected (2, 3, H, W) -- TODO confirm against IOAdapter
            # Given that shape, stack+squeeze yields (2, 3, H, W) again; the
            # batch dimension is only added by unsqueeze(0) in the model call.
            pair_tensor = torch.stack((imgs[0:1], imgs[1:2]), dim=1).squeeze(0)
            pair_tensor = pair_tensor.to(self.device, non_blocking=True).contiguous()

            with torch.no_grad():
                flow_result = self.model({'images': pair_tensor.unsqueeze(0)})

            flow = flow_result['flows'][0]  # expected (1, 2, H, W) -- TODO confirm
            flows.append(flow)

            if save_visualizations:
                self._save_flow_outputs(flow, i, flow_vis_dir, flow_save_dir)
        return flows

    def _save_flow_outputs(
        self,
        flow_tensor: torch.Tensor,
        frame_idx: int,
        viz_dir: Path,
        flow_dir: Path
    ) -> None:
        """Save flow outputs in both .flo and visualization formats.

        Args:
            flow_tensor: Flow for one frame pair; indexed as ``flow_tensor[0]``
                to obtain the per-pixel flow (expected (1, 2, H, W)).
            frame_idx: Index of the first frame of the pair, used in the
                output file names.
            viz_dir: Directory receiving ``flow_viz_XXXX.png``.
            flow_dir: Directory receiving ``flow_XXXX.flo``.
        """
        # Raw flow (.flo format): channels-last numpy array on the CPU.
        flow_hw2 = flow_tensor[0]
        flow_np = flow_hw2.permute(1, 2, 0).cpu().numpy()
        flow_path = flow_dir / f'flow_{frame_idx:04d}.flo'
        flow_utils.flow_write(flow_path, flow_np)

        # Visualization: convert to RGB, drop the leading dimension(s).
        flow_rgb = flow_utils.flow_to_rgb(flow_tensor)[0]
        if flow_rgb.dim() == 4:  # an extra leading prediction dimension is present
            flow_rgb = flow_rgb[0]
        flow_rgb_np = (flow_rgb * 255).byte().permute(1, 2, 0).cpu().numpy()
        viz_bgr = cv.cvtColor(flow_rgb_np, cv.COLOR_RGB2BGR)  # OpenCV writes BGR
        viz_path = viz_dir / f'flow_viz_{frame_idx:04d}.png'
        # cv.imwrite reports failure through its return value; surface it.
        if not cv.imwrite(str(viz_path), viz_bgr):
            print(f"Warning: failed to write flow visualization: {viz_path}")
def create_flow_video(
    image_dir: Union[str, Path],
    output_filename: str = 'flow.mp4',
    fps: int = 10,
    pattern: str = 'flow_viz_*.png',
    cleanup_temp: bool = True
) -> bool:
    """
    Create MP4 video from flow visualization images.

    The matching images are copied into ``<image_dir>/temp_sequence`` under
    sequential names (``frame_00000.png``, ...) so that ffmpeg can read them
    with a printf-style pattern, then encoded with libx264 into
    ``<image_dir>/<output_filename>``.

    Args:
        image_dir: Directory containing the visualization images.
        output_filename: Name of the video file written inside ``image_dir``.
        fps: Input frame rate passed to ffmpeg.
        pattern: Glob pattern selecting the images; they are used in sorted
            path order.
        cleanup_temp: Remove the temporary ``temp_sequence`` directory when
            done (whether or not encoding succeeded).

    Returns:
        True if ffmpeg completed successfully; False if the directory is
        missing, no image matches, or anything fails while staging/encoding.
    """
    image_dir = Path(image_dir)
    if not image_dir.exists():
        print(f"Image directory not found: {image_dir}")
        return False

    image_files = sorted(image_dir.glob(pattern))
    if not image_files:
        print(f"No images found matching pattern '{pattern}' in {image_dir}")
        return False

    temp_dir = image_dir / 'temp_sequence'
    try:
        # Start from an empty staging directory so frames left over from an
        # earlier run (cleanup_temp=False) cannot leak into this video.
        if temp_dir.exists():
            shutil.rmtree(temp_dir)
        temp_dir.mkdir()

        # Copy files with sequential naming
        for i, img_file in enumerate(image_files):
            shutil.copy2(img_file, temp_dir / f'frame_{i:05d}.png')

        # Create video using ffmpeg ('-y' overwrites an existing output file).
        output_path = image_dir / output_filename
        cmd = [
            'ffmpeg', '-y',
            '-framerate', str(fps),
            '-i', str(temp_dir / 'frame_%05d.png'),
            '-c:v', 'libx264',
            '-pix_fmt', 'yuv420p',
            str(output_path)
        ]
        subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=True
        )
        return True
    except subprocess.CalledProcessError as e:
        print(f"Video creation failed: {e}")
        # The exception text only carries the exit status; ffmpeg's own
        # diagnostics are in the captured stderr.
        if e.stderr:
            print(e.stderr.strip())
        return False
    except Exception as e:
        # Best-effort step: report (e.g. ffmpeg not installed) and signal failure.
        print(f"Video creation failed: {e}")
        return False
    finally:
        if cleanup_temp and temp_dir.exists():
            shutil.rmtree(temp_dir)
def main(
    frames_dir: Union[str, Path],
    flow_vis_dir: Union[str, Path] = 'flow_out',
    flow_save_dir: Optional[Union[str, Path]] = None,
    resize_to: Optional[Tuple[int, int]] = None,
    model_name: str = 'dpflow',
    checkpoint: str = 'sintel',
    fps: int = 10
) -> bool:
    """
    Run the full pipeline: load frames, compute flow, write outputs, make video.

    Args:
        frames_dir: Directory with the input PNG frames.
        flow_vis_dir: Directory for flow visualizations and the output video.
        flow_save_dir: Directory for ``.flo`` files (the processor falls back
            to ``flow_vis_dir`` when None).
        resize_to: Optional (width, height) applied to frames on load.
        model_name: Optical flow model name.
        checkpoint: Model checkpoint/dataset name.
        fps: Frame rate of the generated video.

    Returns:
        The result of ``create_flow_video``: True if the video was created.
        Errors while loading frames or computing flow propagate as exceptions.
    """
    # Ensure output locations exist so this also works when called as a library.
    Path(flow_vis_dir).mkdir(parents=True, exist_ok=True)
    if flow_save_dir:
        Path(flow_save_dir).mkdir(parents=True, exist_ok=True)

    processor = OpticalFlowProcessor(
        model_name=model_name,
        checkpoint=checkpoint,
        resize_to=resize_to
    )

    frames, _ = processor.load_frame_sequence(frames_dir)

    # The returned flow tensors are not needed here (outputs are written to
    # disk), so they are not kept alive.
    processor.compute_optical_flow_sequence(
        frames=frames,
        flow_vis_dir=flow_vis_dir,
        flow_save_dir=flow_save_dir,
        save_visualizations=True
    )

    return create_flow_video(flow_vis_dir, fps=fps)


def get_parser():
    """Build the command-line argument parser for the script."""
    parser = argparse.ArgumentParser(description="Optical flow inference on frame sequences")
    # Both are needed to locate the frames; without them the path is meaningless.
    parser.add_argument('--input_path', type=str, required=True, help="base input path")
    parser.add_argument('--seq_name', type=str, required=True, help="sequence name")
    parser.add_argument('--model_name', type=str, default='dpflow', help="Optical flow model to use")
    parser.add_argument('--checkpoint', type=str, default='sintel', help="Model checkpoint/dataset name")
    parser.add_argument('--resize_width', type=int, default=None, help="Resize frame width (must specify both width and height)")
    parser.add_argument('--resize_height', type=int, default=None, help="Resize frame height (must specify both width and height)")
    parser.add_argument('--fps', type=int, default=10, help="Frame rate for output video")
    return parser


def derive_sequence_dirs(input_path: Union[str, Path], seq_name: str) -> Tuple[Path, Path, Path]:
    """Return ``(frames_dir, flow_vis_dir, flow_save_dir)`` for one sequence.

    The three directories are siblings under ``<input_path>/<seq_name>``:
    ``imgs``, ``flow_vis`` and ``flow``. They are built from path components
    rather than by string replacement, so an "imgs" substring elsewhere in
    the path is left untouched.
    """
    seq_dir = Path(input_path) / seq_name
    return seq_dir / 'imgs', seq_dir / 'flow_vis', seq_dir / 'flow'


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()

    # Path
    frames_dir, flow_vis_dir, flow_save_dir = derive_sequence_dirs(args.input_path, args.seq_name)

    # Prepare resize parameter: both dimensions or neither.
    width, height = args.resize_width, args.resize_height
    if (width is None) != (height is None):
        parser.error("--resize_width and --resize_height must be specified together")
    resize_to = None
    if width is not None:
        if width <= 0 or height <= 0:
            parser.error("--resize_width and --resize_height must be positive")
        resize_to = (width, height)

    # Process optical flow (output directories are created inside main).
    success = main(
        frames_dir=frames_dir,
        flow_vis_dir=flow_vis_dir,
        flow_save_dir=flow_save_dir,
        resize_to=resize_to,
        model_name=args.model_name,
        checkpoint=args.checkpoint,
        fps=args.fps
    )
    if success:
        print("Optical flow processing completed successfully")
    else:
        # Flow files were written; only the optional video step failed.
        print("Optical flow computed, but the flow video could not be created")