""" ColorizeNet model wrapper for image colorization """ import logging import torch import numpy as np from PIL import Image from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline from diffusers.utils import load_image from transformers import pipeline from huggingface_hub import hf_hub_download from app.config import settings logger = logging.getLogger(__name__) class ColorizeModel: """Wrapper for ColorizeNet model""" def __init__(self, model_id: str | None = None): """ Initialize the ColorizeNet model Args: model_id: Hugging Face model ID for ColorizeNet """ if model_id is None: model_id = settings.MODEL_ID self.model_id = model_id self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info("Using device: %s", self.device) self.dtype = torch.float16 if self.device == "cuda" else torch.float32 try: # Try loading as ControlNet with Stable Diffusion logger.info("Attempting to load model as ControlNet: %s", self.model_id) try: # Load ControlNet model self.controlnet = ControlNetModel.from_pretrained( self.model_id, torch_dtype=self.dtype ) # Try SDXL first, fallback to SD 1.5 try: self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", controlnet=self.controlnet, torch_dtype=self.dtype, safety_checker=None, requires_safety_checker=False ) logger.info("Loaded with SDXL base model") except: self.pipe = StableDiffusionControlNetPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, torch_dtype=self.dtype, safety_checker=None, requires_safety_checker=False ) logger.info("Loaded with SD 1.5 base model") self.pipe.to(self.device) # Enable memory efficient attention if available if hasattr(self.pipe, "enable_xformers_memory_efficient_attention"): try: self.pipe.enable_xformers_memory_efficient_attention() logger.info("XFormers memory efficient attention enabled") except Exception as e: logger.warning("Could not enable XFormers: %s", str(e)) logger.info("ColorizeNet model loaded successfully as ControlNet") self.model_type = "controlnet" except Exception as e: logger.warning("Failed to load as ControlNet: %s", str(e)) # Fallback: try as image-to-image pipeline logger.info("Trying to load as image-to-image pipeline...") self.pipe = pipeline( "image-to-image", model=self.model_id, device=0 if self.device == "cuda" else -1, torch_dtype=self.dtype ) logger.info("ColorizeNet model loaded using image-to-image pipeline") self.model_type = "pipeline" except Exception as e: logger.error("Failed to load ColorizeNet model: %s", str(e)) raise RuntimeError(f"Could not load ColorizeNet model: {str(e)}") def preprocess_image(self, image: Image.Image) -> Image.Image: """ Preprocess image for colorization Args: image: PIL Image Returns: Preprocessed PIL Image """ # Convert to grayscale if needed if image.mode != "L": # Convert to grayscale image = image.convert("L") # Convert back to RGB (grayscale image with 3 channels) image = image.convert("RGB") # Resize to standard size (512x512 for SD models) image = image.resize((512, 512), Image.Resampling.LANCZOS) return image def colorize(self, image: Image.Image, num_inference_steps: int = None) -> Image.Image: """ Colorize a grayscale image Args: image: PIL Image (grayscale or color) num_inference_steps: Number of inference steps (auto-adjusted for CPU/GPU) Returns: Colorized PIL Image """ try: # Optimize inference steps based on device if num_inference_steps is None: # Use fewer steps on CPU for faster processing 
                num_inference_steps = 8 if self.device == "cpu" else 20

            # Preprocess image
            control_image = self.preprocess_image(image)
            original_size = image.size

            # Prompts steering the model toward natural, saturated colors
            prompt = (
                "colorize this black and white image, high quality, detailed, "
                "vibrant colors, natural colors"
            )
            negative_prompt = (
                "black and white, grayscale, monochrome, low quality, blurry, "
                "desaturated"
            )

            # Slightly lower guidance on CPU; note that with guidance_scale > 1
            # classifier-free guidance still runs two UNet passes per step, so
            # this mainly tempers the output rather than speeding it up
            guidance_scale = 5.0 if self.device == "cpu" else 7.5

            # Generate the colorized image based on how the model was loaded
            if self.model_type == "controlnet":
                # ControlNet pipeline (fixed seed for reproducible output)
                result = self.pipe(
                    prompt=prompt,
                    image=control_image,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    controlnet_conditioning_scale=1.0,
                    generator=torch.Generator(device=self.device).manual_seed(42),
                )
            else:
                # transformers image-to-image pipeline
                result = self.pipe(
                    control_image,
                    prompt=prompt,
                    num_inference_steps=num_inference_steps,
                )

            # Both pipeline flavors are unpacked the same way
            if isinstance(result, dict) and "images" in result:
                colorized = result["images"][0]
            elif isinstance(result, list) and len(result) > 0:
                colorized = result[0]
            else:
                colorized = result

            # Ensure we end up with a PIL Image
            if not isinstance(colorized, Image.Image):
                if isinstance(colorized, np.ndarray):
                    # numpy array: assume HWC layout, float in [0, 1] or uint8
                    if colorized.dtype != np.uint8:
                        colorized = (colorized * 255).astype(np.uint8)
                    if len(colorized.shape) == 3 and colorized.shape[2] == 3:
                        colorized = Image.fromarray(colorized, "RGB")
                    else:
                        colorized = Image.fromarray(colorized)
                elif torch.is_tensor(colorized):
                    # torch tensor: assume CHW layout, float in [0, 1]
                    colorized = colorized.cpu().permute(1, 2, 0).numpy()
                    colorized = (colorized * 255).astype(np.uint8)
                    colorized = Image.fromarray(colorized, "RGB")
                else:
                    raise ValueError(f"Unexpected output type: {type(colorized)}")

            # Resize back to the original size
            if original_size != (512, 512):
                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)

            return colorized

        except Exception as e:
            logger.error("Error during colorization: %s", str(e))
            raise
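

# Minimal usage sketch. Assumptions (not guaranteed by this module):
# settings.MODEL_ID resolves to a loadable checkpoint, and the input/output
# paths below are hypothetical placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    model = ColorizeModel()
    source = Image.open("input.jpg")  # hypothetical grayscale photo
    colorized = model.colorize(source)
    colorized.save("colorized.jpg")  # hypothetical output path
    logger.info("Saved colorized image (%sx%s)", *colorized.size)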