|
|
""" |
|
|
ColorizeNet model wrapper for image colorization |
|
|
""" |
|
|
import logging |
|
|
import os |
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline, StableDiffusionImg2ImgPipeline |
|
|
from diffusers.utils import load_image |
|
|
from transformers import pipeline |
|
|
from huggingface_hub import hf_hub_download |
|
|
from app.config import settings |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class ColorizeModel:
    """Wrapper that loads a ColorizeNet checkpoint and colorizes PIL images.

    The checkpoint is loaded either as a ControlNet attached to a Stable
    Diffusion base pipeline (when the model id contains ``"control"``) or as a
    plain image-to-image pipeline. ``model_type`` records which of the two was
    used: ``"controlnet"`` or ``"pipeline"``.
    """

    # Working resolution (width, height) every input is resized to before it
    # is handed to the diffusion pipeline.
    TARGET_SIZE = (512, 512)

    # Base checkpoints tried, in this order, when attaching a ControlNet.
    SDXL_BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
    SD15_BASE_MODEL = "runwayml/stable-diffusion-v1-5"

    # Fixed seed so repeated calls with the same input use the same noise.
    SEED = 42

    def __init__(self, model_id: str | None = None):
        """
        Initialize the ColorizeNet model.

        Args:
            model_id: Hugging Face model ID for ColorizeNet. Defaults to
                ``settings.MODEL_ID`` when omitted.

        Raises:
            RuntimeError: If the model cannot be loaded by any strategy. The
                original exception is attached as ``__cause__``.
        """
        if model_id is None:
            model_id = settings.MODEL_ID
        self.model_id = model_id

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Using device: %s", self.device)
        # Half precision on GPU, full precision on CPU.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        # An empty HF_TOKEN is treated the same as an unset one.
        self.hf_token = os.getenv("HF_TOKEN") or None

        # Defined up front so the attribute exists on every load path.
        self.controlnet = None

        # Cache location: honour HF_HOME when set, otherwise use ./hf_cache.
        # The same directory is also passed explicitly as ``cache_dir`` to
        # every ``from_pretrained`` call made by the loaders.
        hf_cache_dir = os.getenv("HF_HOME", "./hf_cache")
        os.environ.setdefault("HF_HOME", hf_cache_dir)
        os.environ.setdefault("HUGGINGFACE_HUB_CACHE", hf_cache_dir)
        os.environ.setdefault("TRANSFORMERS_CACHE", hf_cache_dir)
        os.makedirs(hf_cache_dir, exist_ok=True)

        # NOTE(review): torch is already imported when this runs; confirm the
        # setting still has the intended effect at this point.
        os.environ.setdefault("OMP_NUM_THREADS", "1")

        try:
            loaded = False
            if "control" in self.model_id.lower():
                logger.info(
                    "Attempting to load model as ControlNet: %s", self.model_id
                )
                try:
                    self._load_controlnet_pipeline(hf_cache_dir)
                    loaded = True
                except Exception as e:
                    # Best effort: fall through to the image-to-image loader.
                    logger.warning("Failed to load as ControlNet: %s", e)
                    # Drop anything that was partially loaded.
                    self.controlnet = None
                    self.pipe = None

            if not loaded:
                self._load_img2img_pipeline(hf_cache_dir)
        except Exception as e:
            logger.error("Failed to load ColorizeNet model: %s", e)
            raise RuntimeError(f"Could not load ColorizeNet model: {e}") from e

    def _load_controlnet_pipeline(self, cache_dir: str) -> None:
        """Load ``model_id`` as a ControlNet on top of an SD base pipeline.

        Tries the SDXL base first and falls back to SD 1.5 when that raises.
        Sets ``controlnet``, ``pipe`` and ``model_type``.
        """
        self.controlnet = ControlNetModel.from_pretrained(
            self.model_id,
            torch_dtype=self.dtype,
            token=self.hf_token,
            cache_dir=cache_dir,
        )

        pipeline_kwargs = {
            "controlnet": self.controlnet,
            "torch_dtype": self.dtype,
            "safety_checker": None,
            "requires_safety_checker": False,
            "token": self.hf_token,
            "cache_dir": cache_dir,
        }

        # NOTE(review): a ControlNet/base mismatch may load fine and only fail
        # at inference time - confirm which base this checkpoint targets.
        try:
            self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
                self.SDXL_BASE_MODEL, **pipeline_kwargs
            )
            logger.info("Loaded with SDXL base model")
        except Exception as e:
            # Record why SDXL was rejected instead of discarding the reason.
            logger.warning(
                "Could not load SDXL base model (%s); falling back to SD 1.5", e
            )
            self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
                self.SD15_BASE_MODEL, **pipeline_kwargs
            )
            logger.info("Loaded with SD 1.5 base model")

        self.pipe.to(self.device)

        # Optional optimisation; a failure here is logged and otherwise ignored.
        if hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
                logger.info("XFormers memory efficient attention enabled")
            except Exception as e:
                logger.warning("Could not enable XFormers: %s", e)

        logger.info("ColorizeNet model loaded successfully as ControlNet")
        self.model_type = "controlnet"

    def _load_img2img_pipeline(self, cache_dir: str) -> None:
        """Load ``model_id`` as an image-to-image pipeline.

        Sets ``pipe`` and ``model_type``.
        """
        logger.info("Trying to load as image-to-image pipeline...")
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            self.model_id,
            torch_dtype=self.dtype,
            safety_checker=None,
            requires_safety_checker=False,
            use_safetensors=True,
            cache_dir=cache_dir,
            token=self.hf_token,
        ).to(self.device)
        logger.info("ColorizeNet model loaded using image-to-image pipeline")
        self.model_type = "pipeline"

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """
        Preprocess image for colorization.

        The image is reduced to grayscale, expanded back to three identical
        RGB channels, and resized to ``TARGET_SIZE``.

        Args:
            image: PIL Image

        Returns:
            Preprocessed PIL Image (mode ``RGB``, size ``TARGET_SIZE``)
        """
        # Discard any existing colour information first.
        if image.mode != "L":
            image = image.convert("L")

        # The pipelines are fed a 3-channel image.
        image = image.convert("RGB")

        return image.resize(self.TARGET_SIZE, Image.Resampling.LANCZOS)

    def colorize(
        self, image: Image.Image, num_inference_steps: int | None = None
    ) -> Image.Image:
        """
        Colorize a grayscale image.

        Args:
            image: PIL Image (grayscale or color)
            num_inference_steps: Number of inference steps. When ``None``, 8
                is used on CPU and 20 otherwise.

        Returns:
            Colorized PIL Image, resized back to the size of ``image``

        Raises:
            ValueError: If ``num_inference_steps`` is smaller than 1, the
                pipeline returns no images, or the output cannot be converted
                to a PIL image.
        """
        try:
            if num_inference_steps is None:
                # Fewer steps on CPU.
                num_inference_steps = 8 if self.device == "cpu" else 20
            elif num_inference_steps < 1:
                raise ValueError(
                    f"num_inference_steps must be >= 1, got {num_inference_steps}"
                )

            original_size = image.size
            control_image = self.preprocess_image(image)

            call_kwargs = {
                "prompt": (
                    "colorize this black and white image, high quality, "
                    "detailed, vibrant colors, natural colors"
                ),
                # Passed by keyword on purpose: the first positional parameter
                # of the diffusers pipelines is ``prompt``, not the image.
                "image": control_image,
                "negative_prompt": (
                    "black and white, grayscale, monochrome, low quality, "
                    "blurry, desaturated"
                ),
                "num_inference_steps": num_inference_steps,
                "guidance_scale": 5.0 if self.device == "cpu" else 7.5,
                "generator": torch.Generator(device=self.device).manual_seed(
                    self.SEED
                ),
            }
            if self.model_type == "controlnet":
                call_kwargs["controlnet_conditioning_scale"] = 1.0

            result = self.pipe(**call_kwargs)
            colorized = self._to_pil(self._first_image(result))

            # Restore the caller's resolution.
            if colorized.size != original_size:
                colorized = colorized.resize(
                    original_size, Image.Resampling.LANCZOS
                )

            return colorized

        except Exception as e:
            logger.error("Error during colorization: %s", e)
            raise

    @staticmethod
    def _first_image(result):
        """Return the first image from a pipeline result.

        Accepts objects exposing ``.images``, dicts with an ``"images"`` key,
        and non-empty lists; anything else is returned unchanged.
        """
        images = getattr(result, "images", None)
        if images is None and isinstance(result, dict):
            images = result.get("images")
        if images is not None:
            if len(images) == 0:
                raise ValueError("Pipeline returned no images")
            return images[0]
        if isinstance(result, list) and result:
            return result[0]
        return result

    @staticmethod
    def _to_pil(output) -> Image.Image:
        """Convert a PIL image, numpy array, or torch tensor to a PIL image.

        Tensors are expected channel-first (C, H, W). Non-``uint8`` data is
        treated as lying in ``[0, 1]`` and rescaled to ``[0, 255]``.
        """
        if isinstance(output, Image.Image):
            return output

        if torch.is_tensor(output):
            output = output.detach().cpu()
            if output.is_floating_point():
                # numpy conversion from float32 avoids half-precision issues.
                output = output.float()
            # (C, H, W) -> (H, W, C)
            output = output.permute(1, 2, 0).numpy()

        if isinstance(output, np.ndarray):
            if output.dtype != np.uint8:
                # Clip first so out-of-range values cannot wrap around.
                output = (np.clip(output, 0.0, 1.0) * 255).round().astype(np.uint8)
            if output.ndim == 3 and output.shape[2] == 1:
                # Single-channel (H, W, 1) -> (H, W)
                output = output[:, :, 0]
            return Image.fromarray(output)

        raise ValueError(f"Unexpected output type: {type(output)}")
|
|
|
|
|
|