#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2021 Henrique Morimitsu
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache License 2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025.09.04
#
# Original file was released under Apache License 2.0, with the full license text
# available at https://github.com/hmorimitsu/ptlflow/blob/main/LICENSE.
#
# This modified file is released under the same license.

"""
This module processes PNG frame sequences to generate optical flow using PTLFlow,
with support for visualization and video generation.
"""

import argparse
import os
import re
import subprocess
import shutil
import logging
from pathlib import Path
from typing import List, Tuple, Optional, Union

import cv2 as cv
import torch
import numpy as np
from tqdm import tqdm

from third_partys.ptlflow.ptlflow.utils import flow_utils
from third_partys.ptlflow.ptlflow.utils.io_adapter import IOAdapter
import third_partys.ptlflow.ptlflow as ptlflow

# Splits a filename into alternating text / digit-run chunks. Because the
# pattern is a capture group, re.split() keeps the digit runs at odd indices.
_DIGIT_RUN_RE = re.compile(r'([0-9]+)')


def _natural_key(name: str) -> List[Union[int, str]]:
    """Return a sort key that orders embedded numbers numerically.

    Chunks at odd indices are the captured ASCII digit runs and are compared
    as integers; the remaining chunks are compared as lower-cased text.
    """
    return [
        int(chunk) if idx % 2 else chunk.lower()
        for idx, chunk in enumerate(_DIGIT_RUN_RE.split(name))
    ]


class OpticalFlowProcessor:
    """Handles optical flow computation and visualization."""

    def __init__(
        self,
        model_name: str = 'dpflow',
        checkpoint: str = 'sintel',
        device: Optional[str] = None,
        resize_to: Optional[Tuple[int, int]] = None
    ):
        """
        Initialize optical flow processor.

        Args:
            model_name: Name of the flow model to use
            checkpoint: Checkpoint/dataset name for the model
            device: Device to run on (auto-detect if None)
            resize_to: Optional (width, height) to resize frames
        """
        self.model_name = model_name
        self.checkpoint = checkpoint
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.resize_to = resize_to

        # Initialize model
        self.model = ptlflow.get_model(model_name, ckpt_path=checkpoint).to(self.device).eval()
        print(f"Loaded {model_name} model on {self.device}")

        # The IO adapter is built lazily, once the frame size is known.
        self.io_adapter = None
        self._io_adapter_size = None

    def load_frame_sequence(self, frames_dir: Union[str, Path]) -> Tuple[List[np.ndarray], List[Path]]:
        """
        Load PNG frame sequence from directory.

        Args:
            frames_dir: Directory containing the ``*.png`` frames

        Returns:
            Tuple of (frames converted from BGR to RGB, PNG paths), both in
            natural filename order

        Raises:
            FileNotFoundError: If ``frames_dir`` does not exist
            ValueError: If fewer than 2 PNG files are found, or a file cannot
                be decoded as an image
        """
        frames_dir = Path(frames_dir)
        if not frames_dir.exists():
            raise FileNotFoundError(f"Frames directory not found: {frames_dir}")

        # Natural sorting for proper frame order (frame_2 before frame_10)
        png_files = sorted(
            frames_dir.glob('*.png'),
            key=lambda path: self._natural_sort_key(path.name)
        )
        if len(png_files) < 2:
            raise ValueError(f"Need at least 2 PNG frames, found {len(png_files)} in {frames_dir}")

        frames = []
        for png_path in tqdm(png_files, desc="Loading frames"):
            # Load image in color
            img_bgr = cv.imread(str(png_path), cv.IMREAD_COLOR)
            if img_bgr is None:
                # cv.imread signals failure by returning None instead of raising
                raise ValueError(f"Failed to read image: {png_path}")

            if self.resize_to:
                # interpolation must be passed by keyword: the third positional
                # parameter of cv.resize is the destination array, not the flag
                img_bgr = cv.resize(img_bgr, self.resize_to, interpolation=cv.INTER_LINEAR)

            frames.append(cv.cvtColor(img_bgr, cv.COLOR_BGR2RGB))

        return frames, png_files

    def _natural_sort_key(self, filename: str) -> List[Union[int, str]]:
        """Natural sorting key for filenames with numbers."""
        return _natural_key(filename)

    def compute_optical_flow_sequence(
        self,
        frames: List[np.ndarray],
        flow_vis_dir: Union[str, Path],
        flow_save_dir: Optional[Union[str, Path]] = None,
        save_visualizations: bool = True
    ) -> List[torch.Tensor]:
        """
        Compute optical flow for entire frame sequence.

        Args:
            frames: Frames in temporal order; flow is computed between each
                pair of consecutive frames
            flow_vis_dir: Directory for the flow visualization PNGs
            flow_save_dir: Directory for the raw ``.flo`` files (falls back to
                ``flow_vis_dir`` when not given)
            save_visualizations: Whether to write the ``.flo`` and PNG outputs

        Returns:
            One flow tensor per consecutive frame pair (``len(frames) - 1``
            entries), left on the model's device

        Raises:
            ValueError: If fewer than 2 frames are given
        """
        if len(frames) < 2:
            raise ValueError("Need at least 2 frames for optical flow")

        flow_vis_dir = Path(flow_vis_dir)
        flow_save_dir = Path(flow_save_dir) if flow_save_dir else flow_vis_dir
        if save_visualizations:
            # Make sure the output locations exist before anything is written
            flow_vis_dir.mkdir(parents=True, exist_ok=True)
            flow_save_dir.mkdir(parents=True, exist_ok=True)

        frame_size = tuple(frames[0].shape[:2])  # (H, W)

        # Initialize IO adapter; rebuild it if a previous call used another size
        if self.io_adapter is None or getattr(self, '_io_adapter_size', None) != frame_size:
            self.io_adapter = IOAdapter(self.model, frame_size)
            self._io_adapter_size = frame_size

        flows = []
        frame_pairs = zip(frames[:-1], frames[1:])
        progress = tqdm(frame_pairs, total=len(frames) - 1, desc="Computing optical flow")
        for i, (frame_a, frame_b) in enumerate(progress):
            raw_inputs = self.io_adapter.prepare_inputs([frame_a, frame_b])

            # Drop the adapter's leading dimension, keep the two frames of the
            # pair, then add a single batch dimension back for the model.
            # NOTE(review): presumably this yields (1, 2, 3, H, W) - confirm
            # against IOAdapter.prepare_inputs.
            imgs = raw_inputs['images'][0]
            pair_tensor = imgs[:2].unsqueeze(0)
            pair_tensor = pair_tensor.to(self.device, non_blocking=True).contiguous()

            with torch.no_grad():
                flow_result = self.model({'images': pair_tensor})

            flow = flow_result['flows'][0]  # presumably (1, 2, H, W) - TODO confirm
            flows.append(flow)

            if save_visualizations:
                self._save_flow_outputs(flow, i, flow_vis_dir, flow_save_dir)

        return flows

    def _save_flow_outputs(
        self,
        flow_tensor: torch.Tensor,
        frame_idx: int,
        viz_dir: Path,
        flow_dir: Path
    ) -> None:
        """
        Save flow outputs in both .flo and visualization formats.

        Args:
            flow_tensor: Flow for one frame pair, indexed as (1, 2, H, W)
            frame_idx: Index of the pair, used in the output filenames
            viz_dir: Directory receiving ``flow_viz_XXXX.png``
            flow_dir: Directory receiving ``flow_XXXX.flo``
        """
        # Save raw flow (.flo format)
        flow_hw2 = flow_tensor[0]  # (2, H, W)
        flow_np = flow_hw2.permute(1, 2, 0).cpu().numpy()  # (H, W, 2)
        flow_path = flow_dir / f'flow_{frame_idx:04d}.flo'
        flow_utils.flow_write(flow_path, flow_np)

        # Save visualization
        flow_rgb = flow_utils.flow_to_rgb(flow_tensor)[0]  # Remove batch dimension
        if flow_rgb.dim() == 4:  # (Npred, 3, H, W)
            flow_rgb = flow_rgb[0]

        flow_rgb_np = (flow_rgb * 255).byte().permute(1, 2, 0).cpu().numpy()  # (H, W, 3)
        viz_bgr = cv.cvtColor(flow_rgb_np, cv.COLOR_RGB2BGR)
        viz_path = viz_dir / f'flow_viz_{frame_idx:04d}.png'
        if not cv.imwrite(str(viz_path), viz_bgr):
            # cv.imwrite reports failure through its return value only
            print(f"Warning: failed to write flow visualization: {viz_path}")


def create_flow_video(
    image_dir: Union[str, Path],
    output_filename: str = 'flow.mp4',
    fps: int = 10,
    pattern: str = 'flow_viz_*.png',
    cleanup_temp: bool = True
) -> bool:
    """
    Create MP4 video from flow visualization images.

    The matching images are copied into ``image_dir/temp_sequence`` under
    sequential names so that ffmpeg can read them with a ``%05d`` pattern.

    Args:
        image_dir: Directory containing the visualization images; the video
            is written into the same directory
        output_filename: Name of the video file to create
        fps: Frame rate of the output video
        pattern: Glob pattern selecting the images to include
        cleanup_temp: Whether to delete the temporary sequence directory

    Returns:
        True if ffmpeg produced the video, False otherwise
    """
    image_dir = Path(image_dir)
    if not image_dir.exists():
        print(f"Image directory not found: {image_dir}")
        return False

    # Natural order keeps frames correct even past the zero-padding width
    image_files = sorted(image_dir.glob(pattern), key=lambda path: _natural_key(path.name))
    if not image_files:
        print(f"No images found matching pattern '{pattern}' in {image_dir}")
        return False

    temp_dir = image_dir / 'temp_sequence'
    try:
        # Start from an empty directory so frames left over from an earlier
        # run (e.g. with cleanup_temp=False) cannot leak into this video
        if temp_dir.exists():
            shutil.rmtree(temp_dir)
        temp_dir.mkdir()

        # Copy files with sequential naming
        for i, img_file in enumerate(image_files):
            shutil.copy2(img_file, temp_dir / f'frame_{i:05d}.png')

        # Create video using ffmpeg
        output_path = image_dir / output_filename
        cmd = [
            'ffmpeg', '-y',
            '-framerate', str(fps),
            '-i', str(temp_dir / 'frame_%05d.png'),
            # libx264 with yuv420p needs even dimensions; pad odd sizes by 1px
            '-vf', 'pad=ceil(iw/2)*2:ceil(ih/2)*2',
            '-c:v', 'libx264',
            '-pix_fmt', 'yuv420p',
            str(output_path)
        ]
        subprocess.run(cmd, capture_output=True, text=True, check=True)
        return True

    except subprocess.CalledProcessError as e:
        print(f"Video creation failed: {e}")
        if e.stderr:
            # ffmpeg explains the actual failure on stderr
            print(e.stderr.strip())
        return False
    except OSError as e:
        # Covers a missing ffmpeg binary as well as copy / mkdir failures
        print(f"Video creation failed: {e}")
        return False
    finally:
        if cleanup_temp and temp_dir.exists():
            shutil.rmtree(temp_dir)


def main(
    frames_dir: Union[str, Path],
    flow_vis_dir: Union[str, Path] = 'flow_out',
    flow_save_dir: Optional[Union[str, Path]] = None,
    resize_to: Optional[Tuple[int, int]] = None,
    model_name: str = 'dpflow',
    checkpoint: str = 'sintel',
    fps: int = 10
) -> bool:
    """
    Run the full pipeline: load frames, compute flow, write outputs and video.

    Args:
        frames_dir: Directory containing the input PNG frames
        flow_vis_dir: Directory for flow visualizations and the output video
        flow_save_dir: Directory for raw ``.flo`` files (falls back to
            ``flow_vis_dir`` when not given)
        resize_to: Optional (width, height) to resize frames
        model_name: Name of the flow model to use
        checkpoint: Checkpoint/dataset name for the model
        fps: Frame rate of the output video

    Returns:
        True if the flow video was created, False otherwise
    """
    # Initialize processor
    processor = OpticalFlowProcessor(
        model_name=model_name,
        checkpoint=checkpoint,
        resize_to=resize_to
    )

    # Load frames
    frames, _ = processor.load_frame_sequence(frames_dir)

    # Compute optical flow (outputs are written as a side effect)
    processor.compute_optical_flow_sequence(
        frames=frames,
        flow_vis_dir=flow_vis_dir,
        flow_save_dir=flow_save_dir,
        save_visualizations=True
    )

    # Create video
    return create_flow_video(flow_vis_dir, fps=fps)


def get_parser():
    """Build the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(description="Optical flow inference on frame sequences")
    parser.add_argument('--input_path', type=str, help="base input path")
    parser.add_argument('--seq_name', type=str, help="sequence name")
    parser.add_argument('--model_name', type=str, default='dpflow',
                        help="Optical flow model to use")
    parser.add_argument('--checkpoint', type=str, default='sintel',
                        help="Model checkpoint/dataset name")
    parser.add_argument('--resize_width', type=int, default=None,
                        help="Resize frame width (must specify both width and height)")
    parser.add_argument('--resize_height', type=int, default=None,
                        help="Resize frame height (must specify both width and height)")
    parser.add_argument('--fps', type=int, default=10,
                        help="Frame rate for output video")
    return parser


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()

    # Fail early instead of building a literal "None/None/imgs" path
    if args.input_path is None or args.seq_name is None:
        parser.error("--input_path and --seq_name are required")
    if bool(args.resize_width) != bool(args.resize_height):
        parser.error("--resize_width and --resize_height must be specified together")

    # Path: build sibling directories explicitly, so an "imgs" substring that
    # happens to appear in the input path or sequence name is left untouched
    seq_dir = Path(args.input_path) / args.seq_name
    frames_dir = seq_dir / 'imgs'
    flow_vis_dir = seq_dir / 'flow_vis'
    flow_save_dir = seq_dir / 'flow'

    os.makedirs(flow_vis_dir, exist_ok=True)
    os.makedirs(flow_save_dir, exist_ok=True)

    # Prepare resize parameter
    resize_to = None
    if args.resize_width and args.resize_height:
        resize_to = (args.resize_width, args.resize_height)

    # Process optical flow
    success = main(
        frames_dir=frames_dir,
        flow_vis_dir=flow_vis_dir,
        flow_save_dir=flow_save_dir,
        resize_to=resize_to,
        model_name=args.model_name,
        checkpoint=args.checkpoint,
        fps=args.fps
    )

    if success:
        print("Optical flow processing completed successfully")
    else:
        print("Optical flow outputs were written, but the flow video could not be created")