import os
import time
from PIL import Image
from typing import List, Tuple, Optional, Union
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import torch
import torchvision.transforms as transforms
from accelerate.utils import set_seed

from src import (
    FontDiffuserDPMPipeline,
    FontDiffuserModelDPM,
    build_ddpm_scheduler,
    build_unet,
    build_content_encoder,
    build_style_encoder,
)
from utils import (
    ttf2im,
    load_ttf,
    is_char_in_font,
    save_args_to_yaml,
    save_single_image,
    save_image_with_content_style,
)


class BatchProcessor:
    """Handles batch processing logic for FontDiffuser."""

    def __init__(self, args):
        self.args = args
        self.device = args.device
        self.max_batch_size = getattr(args, "max_batch_size", 8)
        self.num_workers = getattr(args, "num_workers", 4)

    def batch_image_process(
        self,
        content_inputs: List[Union[str, Image.Image]],
        style_inputs: List[Union[str, Image.Image]],
        content_characters: Optional[List[str]] = None,
    ) -> Tuple[
        Optional[torch.Tensor], Optional[torch.Tensor], List[Optional[Image.Image]]
    ]:
        """
        Process multiple images in a batch.

        Args:
            content_inputs: List of content image paths or PIL Images.
            style_inputs: List of style image paths or PIL Images.
            content_characters: List of characters if using character input mode.

        Returns:
            Tuple of (content_batch, style_batch, content_pil_images).
            The tensor entries are None when no input pair could be processed.
        """
        batch_size = len(content_inputs)
        assert len(style_inputs) == batch_size, (
            "Content and style inputs must have the same length"
        )
        if content_characters:
            assert len(content_characters) == batch_size, (
                "Content characters must match batch size"
            )
        # Transform setup
        content_inference_transforms = transforms.Compose(
            [
                transforms.Resize(
                    self.args.content_image_size,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        style_inference_transforms = transforms.Compose(
            [
                transforms.Resize(
                    self.args.style_image_size,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Initialize ordered lists for results
        content_tensors = [None] * batch_size
        style_tensors = [None] * batch_size
        content_pil_images = [None] * batch_size

        # Process in parallel using ThreadPoolExecutor for I/O-bound work
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            # Submit content processing tasks
            content_futures = []
            for i, content_input in enumerate(content_inputs):
                if content_characters and i < len(content_characters):
                    future = executor.submit(
                        self._process_content_character,
                        content_characters[i],
                        content_inference_transforms,
                    )
                else:
                    future = executor.submit(
                        self._process_content_image,
                        content_input,
                        content_inference_transforms,
                    )
                content_futures.append((i, future))

            # Submit style processing tasks
            style_futures = []
            for i, style_input in enumerate(style_inputs):
                future = executor.submit(
                    self._process_style_image, style_input, style_inference_transforms
                )
                style_futures.append((i, future))

            # Collect results in submission order
            for i, future in content_futures:
                try:
                    content_tensor, content_pil = future.result()
                    if content_tensor is not None:
                        content_tensors[i] = content_tensor
                        content_pil_images[i] = content_pil
                except Exception as e:
                    print(f"Error processing content at index {i}: {e}")
                    continue
            for i, future in style_futures:
                try:
                    style_tensor = future.result()
                    if style_tensor is not None:
                        style_tensors[i] = style_tensor
                except Exception as e:
                    print(f"Error processing style at index {i}: {e}")
                    continue
        # Keep only indices where BOTH content and style succeeded, so the
        # stacked batches stay aligned pairwise. Dropping Nones from each list
        # independently could silently pair content i with style j.
        valid_indices = [
            i
            for i in range(batch_size)
            if content_tensors[i] is not None and style_tensors[i] is not None
        ]
        if not valid_indices:
            return None, None, []
        content_batch = torch.stack([content_tensors[i] for i in valid_indices])
        style_batch = torch.stack([style_tensors[i] for i in valid_indices])
        content_pil_images = [content_pil_images[i] for i in valid_indices]
        return content_batch, style_batch, content_pil_images

    def _process_content_character(
        self, character: str, transform
    ) -> Tuple[Optional[torch.Tensor], Optional[Image.Image]]:
        """Render a content character with the reference font and transform it."""
        try:
            if not is_char_in_font(font_path=self.args.ttf_path, char=character):
                print(f"Character '{character}' not found in font")
                return None, None
            font = load_ttf(ttf_path=self.args.ttf_path)
            content_image = ttf2im(font=font, char=character)
            content_image_pil = content_image.copy()
            content_tensor = transform(content_image)
            return content_tensor, content_image_pil
        except Exception as e:
            print(f"Error processing content character '{character}': {e}")
            return None, None

    def _process_content_image(
        self, image_input: Union[str, Image.Image], transform
    ) -> Tuple[Optional[torch.Tensor], None]:
        """Process a content image (path or PIL) into a tensor."""
        try:
            if isinstance(image_input, str):
                content_image = Image.open(image_input).convert("RGB")
            else:
                content_image = image_input.convert("RGB")
            content_tensor = transform(content_image)
            return content_tensor, None
        except Exception as e:
            print(f"Error processing content image: {e}")
            return None, None

    def _process_style_image(
        self, image_input: Union[str, Image.Image], transform
    ) -> Optional[torch.Tensor]:
        """Process a style image (path or PIL) into a tensor."""
        try:
            if isinstance(image_input, str):
                style_image = Image.open(image_input).convert("RGB")
            else:
                style_image = image_input.convert("RGB")
            style_tensor = transform(style_image)
            return style_tensor
        except Exception as e:
            print(f"Error processing style image: {e}")
            return None
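

# Example (illustrative): preprocessing a small batch by hand. `args` is the
# namespace returned by arg_parse() below; the image paths are placeholders.
#
#     processor = BatchProcessor(args)
#     content_batch, style_batch, content_pils = processor.batch_image_process(
#         content_inputs=["content_a.png", "content_b.png"],
#         style_inputs=["style.png", "style.png"],
#     )
#     # content_batch and style_batch are (N, 3, H, W) tensors ready for the pipeline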


def arg_parse():
    from configs.fontdiffuser import get_parser

    parser = get_parser()
    parser.add_argument("--ckpt_dir", type=str, default=None)
    parser.add_argument("--demo", action="store_true")
    parser.add_argument(
        "--controlnet",
        type=bool,  # NOTE: argparse's type=bool treats any non-empty string as True
        default=False,
        help="Enable ControlNet (demo mode only).",
    )
    parser.add_argument("--character_input", action="store_true")
    parser.add_argument("--content_character", type=str, default=None)
    parser.add_argument("--content_image_path", type=str, default=None)
    parser.add_argument("--style_image_path", type=str, default=None)
    parser.add_argument("--save_image", action="store_true")
    parser.add_argument(
        "--save_image_dir", type=str, default=None, help="The saving directory."
    )
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--ttf_path", type=str, default="ttf/KaiXinSongA.ttf")
    # Batch processing arguments
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Batch size for processing multiple images",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=8,
        help="Maximum batch size based on GPU memory",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of workers for parallel image loading",
    )
    parser.add_argument(
        "--batch_content_paths",
        type=str,
        nargs="+",
        default=None,
        help="List of content image paths for batch processing",
    )
    parser.add_argument(
        "--batch_style_paths",
        type=str,
        nargs="+",
        default=None,
        help="List of style image paths for batch processing",
    )
    parser.add_argument(
        "--batch_characters",
        type=str,
        nargs="+",
        default=None,
        help="List of characters for batch processing",
    )
    parser.add_argument(
        "--adaptive_batch_size",
        action="store_true",
        help="Automatically adjust batch size based on GPU memory",
    )
    args = parser.parse_args()

    # The model expects square inputs, so expand the scalar sizes to (H, W).
    style_image_size = args.style_image_size
    content_image_size = args.content_image_size
    args.style_image_size = (style_image_size, style_image_size)
    args.content_image_size = (content_image_size, content_image_size)
    return args
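

# Example invocations (illustrative; the script name and all paths are placeholders):
#
#   Single character (original behavior):
#     python sample_batch.py --ckpt_dir ckpt/ --character_input \
#         --content_character "龙" --style_image_path style.png \
#         --save_image --save_image_dir outputs/
#
#   Batch of characters sharing one style reference:
#     python sample_batch.py --ckpt_dir ckpt/ --character_input \
#         --batch_characters 龙 凤 虎 --style_image_path style.png \
#         --save_image --save_image_dir outputs/ --adaptive_batch_size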


def get_optimal_batch_size(args) -> int:
    """Determine a batch size heuristically from total GPU memory."""
    if not torch.cuda.is_available():
        return 1
    # Get GPU memory in GB
    gpu_memory = torch.cuda.get_device_properties(args.device).total_memory / (
        1024**3
    )
    # Estimate batch size based on GPU memory (rough heuristic)
    if gpu_memory >= 24:  # RTX 4090, A100, etc.
        optimal_batch = min(16, args.max_batch_size)
    elif gpu_memory >= 12:  # RTX 3080 Ti, RTX 4070 Ti, etc.
        optimal_batch = min(8, args.max_batch_size)
    elif gpu_memory >= 8:  # RTX 3070, RTX 4060 Ti, etc.
        optimal_batch = min(4, args.max_batch_size)
    else:  # Lower-end GPUs
        optimal_batch = min(2, args.max_batch_size)
    return optimal_batch
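

# A sketch of a more direct alternative: size the batch from currently *free*
# memory rather than total capacity. torch.cuda.mem_get_info returns
# (free_bytes, total_bytes); the 1.5 GB-per-sample figure is an assumption
# that would need tuning for the actual model.
#
#     def batch_size_from_free_memory(args, gb_per_sample=1.5) -> int:
#         if not torch.cuda.is_available():
#             return 1
#         free_bytes, _ = torch.cuda.mem_get_info(torch.device(args.device))
#         estimate = int(free_bytes / (gb_per_sample * 1024**3))
#         return max(1, min(args.max_batch_size, estimate))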


def load_fontdiffuser_pipeline(args):
    """Load the FontDiffuser pipeline (unchanged from the original script)."""
    # Load the model state_dicts
    unet = build_unet(args=args)
    unet.load_state_dict(torch.load(f"{args.ckpt_dir}/unet.pth"))
    style_encoder = build_style_encoder(args=args)
    style_encoder.load_state_dict(torch.load(f"{args.ckpt_dir}/style_encoder.pth"))
    content_encoder = build_content_encoder(args=args)
    content_encoder.load_state_dict(torch.load(f"{args.ckpt_dir}/content_encoder.pth"))
    model = FontDiffuserModelDPM(
        unet=unet, style_encoder=style_encoder, content_encoder=content_encoder
    )
    model.to(args.device)
    print("Loaded the model state_dict successfully!")

    # Load the training DDPM scheduler.
    train_scheduler = build_ddpm_scheduler(args=args)
    print("Loaded training DDPM scheduler successfully!")

    # Load the DPM-Solver to generate samples.
    pipe = FontDiffuserDPMPipeline(
        model=model,
        ddpm_train_scheduler=train_scheduler,
        model_type=args.model_type,
        guidance_type=args.guidance_type,
        guidance_scale=args.guidance_scale,
    )
    print("Loaded dpm_solver pipeline successfully!")
    return pipe
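

# Expected checkpoint directory layout (inferred from the load calls above):
#
#   <ckpt_dir>/
#     unet.pth
#     style_encoder.pth
#     content_encoder.pth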


def batch_sampling(
    args,
    pipe,
    content_inputs: List[Union[str, Image.Image]],
    style_inputs: List[Union[str, Image.Image]],
    content_characters: Optional[List[str]] = None,
) -> List[Image.Image]:
    """
    Perform batch sampling with FontDiffuser.

    Args:
        args: Parsed arguments.
        pipe: FontDiffuser pipeline.
        content_inputs: List of content images/paths.
        style_inputs: List of style images/paths.
        content_characters: List of characters (if using character input).

    Returns:
        List of generated images.
    """
    if not args.demo:
        os.makedirs(args.save_image_dir, exist_ok=True)
        save_args_to_yaml(
            args=args, output_file=f"{args.save_image_dir}/sampling_config.yaml"
        )
    if args.seed:
        set_seed(seed=args.seed)

    # Determine optimal batch size
    if args.adaptive_batch_size:
        optimal_batch_size = get_optimal_batch_size(args)
        print(f"Using adaptive batch size: {optimal_batch_size}")
    else:
        optimal_batch_size = args.batch_size

    batch_processor = BatchProcessor(args)
    total_samples = len(content_inputs)
    all_generated_images = []
    print(f"Processing {total_samples} samples in batches of {optimal_batch_size}")

    # Process in batches
    for batch_start in range(0, total_samples, optimal_batch_size):
        batch_end = min(batch_start + optimal_batch_size, total_samples)
        batch_content = content_inputs[batch_start:batch_end]
        batch_style = style_inputs[batch_start:batch_end]
        batch_chars = (
            content_characters[batch_start:batch_end] if content_characters else None
        )
        num_batches = (total_samples + optimal_batch_size - 1) // optimal_batch_size
        print(f"Processing batch {batch_start // optimal_batch_size + 1}/{num_batches}")

        # Preprocess the batch
        content_batch, style_batch, content_pil_images = (
            batch_processor.batch_image_process(batch_content, batch_style, batch_chars)
        )
        if content_batch is None or style_batch is None:
            print("Skipping batch due to processing errors")
            continue
        current_batch_size = content_batch.shape[0]
        with torch.no_grad():
            content_batch = content_batch.to(args.device)
            style_batch = style_batch.to(args.device)
            print(f"Generating {current_batch_size} images with DPM-Solver++...")
            start_time = time.time()
            try:
                # Generate the batch
                images = pipe.generate(
                    content_images=content_batch,
                    style_images=style_batch,
                    batch_size=current_batch_size,
                    order=args.order,
                    num_inference_step=args.num_inference_steps,
                    content_encoder_downsample_size=args.content_encoder_downsample_size,
                    t_start=args.t_start,
                    t_end=args.t_end,
                    dm_size=args.content_image_size,
                    algorithm_type=args.algorithm_type,
                    skip_type=args.skip_type,
                    method=args.method,
                    correcting_x0_fn=args.correcting_x0_fn,
                )
                end_time = time.time()
                print(f"Batch generation completed in {end_time - start_time:.2f}s")
                # Save images if requested
                if args.save_image:
                    save_batch_images(
                        args,
                        images,
                        content_pil_images,
                        batch_content,
                        batch_style,
                        batch_start,
                    )
                all_generated_images.extend(images)
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print(
                        f"GPU out of memory with batch size {current_batch_size}, "
                        "trying smaller batch..."
                    )
                    torch.cuda.empty_cache()
                    # Retry in sub-batches of half the size
                    smaller_batch_size = max(1, current_batch_size // 2)
                    for sub_batch_start in range(
                        0, current_batch_size, smaller_batch_size
                    ):
                        sub_batch_end = min(
                            sub_batch_start + smaller_batch_size, current_batch_size
                        )
                        sub_content = content_batch[sub_batch_start:sub_batch_end]
                        sub_style = style_batch[sub_batch_start:sub_batch_end]
                        sub_images = pipe.generate(
                            content_images=sub_content,
                            style_images=sub_style,
                            batch_size=sub_batch_end - sub_batch_start,
                            order=args.order,
                            num_inference_step=args.num_inference_steps,
                            content_encoder_downsample_size=args.content_encoder_downsample_size,
                            t_start=args.t_start,
                            t_end=args.t_end,
                            dm_size=args.content_image_size,
                            algorithm_type=args.algorithm_type,
                            skip_type=args.skip_type,
                            method=args.method,
                            correcting_x0_fn=args.correcting_x0_fn,
                        )
                        # Save the retried sub-batch as well
                        if args.save_image:
                            save_batch_images(
                                args,
                                sub_images,
                                content_pil_images[sub_batch_start:sub_batch_end],
                                batch_content[sub_batch_start:sub_batch_end],
                                batch_style[sub_batch_start:sub_batch_end],
                                batch_start + sub_batch_start,
                            )
                        all_generated_images.extend(sub_images)
                else:
                    print(f"Error during generation: {e}")
                    continue

        # Clear GPU cache between batches
        torch.cuda.empty_cache()

    print(f"Batch processing completed! Generated {len(all_generated_images)} images.")
    return all_generated_images
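

# Example (illustrative): pair every content image in a directory with a single
# style reference. load_images_from_directory is defined further below; the
# paths are placeholders.
#
#     contents = load_images_from_directory("data/content")
#     styles = ["data/style/ref.png"] * len(contents)
#     images = batch_sampling(args, pipe, contents, styles)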


def save_batch_images(
    args, images, content_pil_images, batch_content, batch_style, batch_offset
):
    """Save a batch of generated images."""
    for i, image in enumerate(images):
        # Create a unique filename for each image
        image_idx = batch_offset + i
        save_single_image(
            save_dir=args.save_image_dir, image=image, suffix=f"_{image_idx:04d}"
        )
        # Save with content and style context if available
        if args.character_input and i < len(content_pil_images):
            save_image_with_content_style(
                save_dir=args.save_image_dir,
                image=image,
                content_image_pil=content_pil_images[i],
                content_image_path=None,
                style_image_path=batch_style[i]
                if isinstance(batch_style[i], str)
                else None,
                resolution=args.resolution,
                suffix=f"_{image_idx:04d}",
            )
        elif not args.character_input:
            save_image_with_content_style(
                save_dir=args.save_image_dir,
                image=image,
                content_image_pil=None,
                content_image_path=batch_content[i]
                if isinstance(batch_content[i], str)
                else None,
                style_image_path=batch_style[i]
                if isinstance(batch_style[i], str)
                else None,
                resolution=args.resolution,
                suffix=f"_{image_idx:04d}",
            )


def sampling(args, pipe, content_image=None, style_image=None):
    """Original single-image sampling function (kept for backward compatibility)."""
    if not args.demo:
        os.makedirs(args.save_image_dir, exist_ok=True)
        save_args_to_yaml(
            args=args, output_file=f"{args.save_image_dir}/sampling_config.yaml"
        )
    if args.seed:
        set_seed(seed=args.seed)
    # Route a single sample through the batch path
    if args.character_input:
        characters = (
            [args.content_character]
            if getattr(args, "content_character", None)
            else ["A"]
        )
        style_inputs = [style_image or args.style_image_path]
        # Pass the characters as content_inputs as well, since content_inputs
        # drives the batch size inside batch_sampling.
        result = batch_sampling(args, pipe, characters, style_inputs, characters)
    else:
        content_inputs = [content_image or args.content_image_path]
        style_inputs = [style_image or args.style_image_path]
        result = batch_sampling(args, pipe, content_inputs, style_inputs)
    return result[0] if result else None


# Additional utility functions for batch processing
def load_images_from_directory(
    directory_path: str,
    extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png", ".bmp"),
) -> List[str]:
    """Collect all image paths from a directory, sorted and deduplicated."""
    directory = Path(directory_path)
    image_paths = set()
    for ext in extensions:
        image_paths.update(directory.glob(f"*{ext}"))
        image_paths.update(directory.glob(f"*{ext.upper()}"))
    # The set guards against duplicates on case-insensitive filesystems,
    # where *.jpg and *.JPG match the same files.
    return [str(path) for path in sorted(image_paths)]
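

# Example (illustrative; the directory path is a placeholder):
#
#     paths = load_images_from_directory("data/content")
#     print(f"Found {len(paths)} images")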


def create_batch_from_config(
    config_file: str,
) -> Tuple[List[str], List[str], List[str]]:
    """Create batch inputs from a JSON configuration file."""
    import json

    with open(config_file, "r") as f:
        config = json.load(f)
    content_inputs = config.get("content_images", [])
    style_inputs = config.get("style_images", [])
    characters = config.get("characters", [])
    return content_inputs, style_inputs, characters
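

# Example config file (illustrative; the keys are the ones read above):
#
#   {
#     "content_images": ["content/a.png", "content/b.png"],
#     "style_images": ["style/ref.png", "style/ref.png"],
#     "characters": []
#   }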


if __name__ == "__main__":
    args = arg_parse()
    # Load the FontDiffuser pipeline
    pipe = load_fontdiffuser_pipeline(args=args)
    # Check whether batch processing is requested
    if args.batch_content_paths or args.batch_style_paths or args.batch_characters:
        # Batch processing mode
        content_inputs = args.batch_content_paths or []
        style_inputs = args.batch_style_paths or []
        characters = args.batch_characters or None
        if characters and args.character_input:
            # Character-based batch processing. The characters are passed as
            # content_inputs as well, since content_inputs drives the batch
            # size; a single style reference is broadcast across the batch.
            if not style_inputs:
                style_inputs = [args.style_image_path] * len(characters)
            elif len(style_inputs) == 1 and len(characters) > 1:
                style_inputs = style_inputs * len(characters)
            generated_images = batch_sampling(
                args, pipe, characters, style_inputs, characters
            )
        else:
            # Image-based batch processing
            if len(content_inputs) != len(style_inputs):
                print("Error: Number of content and style images must match")
                exit(1)
            generated_images = batch_sampling(args, pipe, content_inputs, style_inputs)
        print(f"Batch processing completed! Generated {len(generated_images)} images.")
    else:
        # Single-image processing (original behavior)
        out_image = sampling(args=args, pipe=pipe)