import spaces
from accelerate import dispatch_model
from fastapi import FastAPI, HTTPException, UploadFile, File
from typing import Optional, Dict, Any
import inspect
import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    AutoPipelineForText2Image
)
import gradio as gr
from PIL import Image
import gc
from io import BytesIO
import base64

app = FastAPI()
# Comprehensive model registry
MODELS = {
    "SDXL-Base": {
        "model_id": "stabilityai/stable-diffusion-xl-base-1.0",
        "pipeline": StableDiffusionXLPipeline,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 100, "default": 50},
            "guidance_scale": {"min": 1, "max": 15, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    },
| "SDXL-Turbo": { | |
| "model_id": "stabilityai/sdxl-turbo", | |
| "pipeline": AutoPipelineForText2Image, | |
| "supports_img2img": True, | |
| "parameters": { | |
| "num_inference_steps": {"min": 1, "max": 50, "default": 1}, | |
| "guidance_scale": {"min": 0.0, "max": 20.0, "default": 7.5}, | |
| "width": {"min": 256, "max": 1024, "default": 512, "step": 64}, | |
| "height": {"min": 256, "max": 1024, "default": 512, "step": 64} | |
| } | |
| }, | |
| "SD-1.5": { | |
| "model_id": "runwayml/stable-diffusion-v1-5", | |
| "pipeline": StableDiffusionPipeline, | |
| "supports_img2img": True, | |
| "parameters": { | |
| "num_inference_steps": {"min": 1, "max": 50, "default": 30}, | |
| "guidance_scale": {"min": 1, "max": 20, "default": 7.5}, | |
| "width": {"min": 256, "max": 1024, "default": 512, "step": 64}, | |
| "height": {"min": 256, "max": 1024, "default": 512, "step": 64} | |
| } | |
| }, | |
| "Waifu-Diffusion": { | |
| "model_id": "hakurei/waifu-diffusion", | |
| "pipeline": StableDiffusionPipeline, | |
| "supports_img2img": True, | |
| "parameters": { | |
| "num_inference_steps": {"min": 1, "max": 100, "default": 50}, | |
| "guidance_scale": {"min": 1, "max": 15, "default": 7.5}, | |
| "width": {"min": 256, "max": 1024, "default": 512, "step": 64}, | |
| "height": {"min": 256, "max": 1024, "default": 512, "step": 64} | |
| } | |
| }, | |
| "Flux": { | |
| "model_id": "black-forest-labs/flux-1-1-dev", | |
| "pipeline": AutoPipelineForText2Image, | |
| "supports_img2img": True, | |
| "parameters": { | |
| "num_inference_steps": {"min": 1, "max": 50, "default": 25}, | |
| "guidance_scale": {"min": 1, "max": 15, "default": 7.5}, | |
| "width": {"min": 256, "max": 1024, "default": 512, "step": 64}, | |
| "height": {"min": 256, "max": 1024, "default": 512, "step": 64} | |
| } | |
| } | |
| } | |
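
# Example (sketch): pull a model's default generation parameters out of the
# registry above, e.g. to prefill a request; the model name is illustrative.
#   defaults = {k: v["default"] for k, v in MODELS["SD-1.5"]["parameters"].items()}
#   # -> {"num_inference_steps": 30, "guidance_scale": 7.5, "width": 512, "height": 512}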
class ModelManager:
    def __init__(self):
        self.current_model = None
        self.current_pipeline = None
        self.model_cache: Dict[str, Any] = {}
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._dtype = torch.float16 if self._device == "cuda" else torch.float32

    def _clear_memory(self):
        """Clear CUDA memory and garbage collect"""
        if self.current_pipeline is not None:
            del self.current_pipeline
            self.current_pipeline = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()
    def get_model_config(self, model_id: str, pipeline_class):
        """Instantiate a pipeline for the given model id."""
        return pipeline_class.from_pretrained(
            model_id,
            torch_dtype=self._dtype,
            # Note: not every repo publishes an fp16 variant; if loading
            # fails, drop the variant argument.
            variant="fp16" if self._device == "cuda" else None,
            device_map="balanced"
        )
    def load_model(self, model_name: str):
        """Load model with memory optimization"""
        if self.current_model != model_name:
            self._clear_memory()
            try:
                model_info = MODELS[model_name]
                self.current_pipeline = self.get_model_config(
                    model_info["model_id"],
                    model_info["pipeline"]
                )
                if hasattr(self.current_pipeline, 'enable_xformers_memory_efficient_attention'):
                    self.current_pipeline.enable_xformers_memory_efficient_attention()
                # if self._device == "cuda":
                #     self.current_pipeline.enable_model_cpu_offload()
                self.current_model = model_name
            except Exception as e:
                self._clear_memory()
                raise RuntimeError(f"Failed to load model {model_name}: {str(e)}")
        return self.current_pipeline
    def unload_current_model(self):
        """Explicitly unload current model"""
        self._clear_memory()
        self.current_model = None

    def get_memory_status(self):
        """Get current memory usage status (values in GB)"""
        if not torch.cuda.is_available():
            return {"status": "CPU Mode"}
        return {
            "total": torch.cuda.get_device_properties(0).total_memory / 1e9,
            "allocated": torch.cuda.memory_allocated() / 1e9,
            "cached": torch.cuda.memory_reserved() / 1e9,
            "free": (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()) / 1e9
        }
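
# Illustrative return value of get_memory_status() on a 16 GB card (values
# in GB; the numbers here are invented for the example):
#   {"total": 16.0, "allocated": 4.2, "cached": 5.1, "free": 11.8}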
class ModelContext:
    """Context manager that loads a model on entry and unloads it on error."""
    def __init__(self, model_name: str):
        self.model_name = model_name

    def __enter__(self):
        pipeline = model_manager.load_model(self.model_name)
        if hasattr(pipeline, 'reset_device_map'):
            pipeline.reset_device_map()
        # dispatch_model expects an nn.Module; diffusers pipelines are not
        # modules and have no state_dict, so this guard makes the call a no-op.
        if hasattr(pipeline, 'state_dict'):
            dispatch_model(pipeline, device_map="auto")
        return pipeline

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Only unload on failure; on success the model stays cached for reuse.
        if exc_type is not None:
            model_manager.unload_current_model()
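
# Usage sketch: ModelContext guarantees the model is unloaded if generation
# raises, so a failed request does not leave GPU memory pinned:
#   with ModelContext("SD-1.5") as pipe:
#       image = pipe("a lighthouse at dusk").images[0]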
model_manager = ModelManager()

# On Hugging Face ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration
# of the call; outside ZeroGPU the decorator has no effect.
@spaces.GPU
def generate_image(
    model_name: str,
    prompt: str,
    height: int = 512,
    width: int = 512,
    num_inference_steps: Optional[int] = None,
    guidance_scale: Optional[float] = None,
    reference_image: Optional[Image.Image] = None
) -> dict:
    try:
        with ModelContext(model_name) as pipeline:
            pre_mem = model_manager.get_memory_status()
            # Process reference image if provided
            if reference_image is not None and MODELS[model_name]["supports_img2img"]:
                reference_image = reference_image.resize((width, height))
            # Fall back to registry defaults explicitly; `x or default` would
            # wrongly discard legitimate zero values (e.g. guidance_scale=0.0).
            params = MODELS[model_name]["parameters"]
            generation_params = {
                "prompt": prompt,
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps if num_inference_steps is not None else params["num_inference_steps"]["default"],
                "guidance_scale": guidance_scale if guidance_scale is not None else params["guidance_scale"]["default"]
            }
            # Only pass the reference image if this pipeline's __call__
            # actually accepts an `image` argument; text-to-image pipelines
            # would raise a TypeError otherwise.
            if reference_image is not None and "image" in inspect.signature(pipeline.__call__).parameters:
                generation_params["image"] = reference_image
            image = pipeline(**generation_params).images[0]
            # Convert to base64
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
            post_mem = model_manager.get_memory_status()
            return {
                "status": "success",
                "image_base64": img_str,
                "memory": {
                    "before": pre_mem,
                    "after": post_mem
                }
            }
    except Exception as e:
        model_manager.unload_current_model()
        raise HTTPException(status_code=500, detail=str(e))
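
# Sketch: decoding the base64 payload returned by generate_image back into a
# PIL image (this mirrors what the Gradio callback below does):
#   img = Image.open(BytesIO(base64.b64decode(result["image_base64"])))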
# Register the API routes (the handlers below were otherwise never attached
# to `app`; the paths are chosen here, rename to taste).
@app.post("/generate")
async def generate_image_endpoint(
    model_name: str,
    prompt: str,
    height: int = 512,
    width: int = 512,
    num_inference_steps: Optional[int] = None,
    guidance_scale: Optional[float] = None,
    reference_image: UploadFile = File(None)
):
    ref_img = None
    if reference_image:
        content = await reference_image.read()
        ref_img = Image.open(BytesIO(content))
    return generate_image(
        model_name=model_name,
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        reference_image=ref_img
    )

@app.get("/memory")
async def get_memory_status():
    return model_manager.get_memory_status()

@app.post("/unload")
async def unload_model():
    model_manager.unload_current_model()
    return {"status": "success", "message": "Model unloaded"}
def create_gradio_interface() -> gr.Blocks:
    with gr.Blocks() as interface:
        gr.Markdown("# Text-to-Image Generation Interface")
        with gr.Row():
            with gr.Column(scale=2):
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value=list(MODELS.keys())[0],
                    label="Select Model"
                )
                prompt = gr.Textbox(
                    lines=3,
                    label="Prompt",
                    placeholder="Enter your image description here..."
                )
                with gr.Row():
                    height = gr.Slider(
                        minimum=256,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Height"
                    )
                    width = gr.Slider(
                        minimum=256,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Width"
                    )
                with gr.Row():
                    num_steps = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Number of Inference Steps"
                    )
                    guidance = gr.Slider(
                        minimum=1,
                        maximum=15,
                        value=7.5,
                        step=0.1,
                        label="Guidance Scale"
                    )
                reference_image = gr.Image(
                    type="pil",
                    label="Reference Image (optional)"
                )
                with gr.Row():
                    generate_btn = gr.Button("Generate", variant="primary")
                    unload_btn = gr.Button("Unload Model")
            with gr.Column(scale=2):
                output_image = gr.Image(label="Generated Image")
                memory_status = gr.JSON(
                    label="Memory Status",
                    value=model_manager.get_memory_status()
                )
        def update_params(model_name: str) -> list:
            model_config = MODELS[model_name]["parameters"]
            return [
                gr.update(
                    minimum=model_config["height"]["min"],
                    maximum=model_config["height"]["max"],
                    value=model_config["height"]["default"],
                    step=model_config["height"]["step"]
                ),
                gr.update(
                    minimum=model_config["width"]["min"],
                    maximum=model_config["width"]["max"],
                    value=model_config["width"]["default"],
                    step=model_config["width"]["step"]
                ),
                gr.update(
                    minimum=model_config["num_inference_steps"]["min"],
                    maximum=model_config["num_inference_steps"]["max"],
                    value=model_config["num_inference_steps"]["default"]
                ),
                gr.update(
                    minimum=model_config["guidance_scale"]["min"],
                    maximum=model_config["guidance_scale"]["max"],
                    value=model_config["guidance_scale"]["default"]
                )
            ]

        def generate(model_name: str, prompt_text: str, h: int, w: int, steps: int, guide_scale: float, ref_img: Optional[Image.Image]) -> Image.Image:
            response = generate_image(
                model_name=model_name,
                prompt=prompt_text,
                height=h,
                width=w,
                num_inference_steps=steps,
                guidance_scale=guide_scale,
                reference_image=ref_img
            )
            return Image.open(BytesIO(base64.b64decode(response["image_base64"])))
        model_dropdown.change(
            update_params,
            inputs=[model_dropdown],
            outputs=[height, width, num_steps, guidance]
        )
        generate_btn.click(
            generate,
            inputs=[
                model_dropdown,
                prompt,
                height,
                width,
                num_steps,
                guidance,
                reference_image
            ],
            outputs=[output_image]
        )

        def unload_and_report():
            # A single output component expects a single return value, so
            # unload first and return only the refreshed memory snapshot.
            model_manager.unload_current_model()
            return model_manager.get_memory_status()

        unload_btn.click(
            unload_and_report,
            outputs=[memory_status]
        )
    return interface
| if __name__ == "__main__": | |
| import uvicorn | |
| from threading import Thread | |
| # Launch Gradio interface | |
| interface = create_gradio_interface() | |
| gradio_thread = Thread( | |
| target=interface.launch, | |
| kwargs={ | |
| "server_name": "0.0.0.0", | |
| "server_port": 7860, | |
| "share": False | |
| } | |
| ) | |
| gradio_thread.start() | |
| # Launch FastAPI | |
| uvicorn.run(app, host="0.0.0.0", port=8000) |
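
# Alternative (sketch): serve both apps from one port by mounting the Gradio
# Blocks onto FastAPI instead of launching it in a separate thread:
#   app = gr.mount_gradio_app(app, interface, path="/ui")
#   uvicorn.run(app, host="0.0.0.0", port=8000)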