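"""Gradio app: generate an image from a text prompt with DreamShaper XL v2 Turbo
and score the result with CLIP (raw cosine similarity plus a 2.5x-scaled
"GenEval" score). Runs on CUDA when available and supports Hugging Face
Spaces ZeroGPU."""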
import warnings

import gradio as gr
import torch
from diffusers import DiffusionPipeline
from transformers import CLIPImageProcessor, CLIPModel, CLIPProcessor, CLIPTokenizerFast

warnings.filterwarnings("ignore")

# Global evaluator instance (lazy loaded)
evaluator = None

# Check if running on Hugging Face Spaces with ZeroGPU
try:
    import spaces
    ZERO_GPU_AVAILABLE = True
except ImportError:
    ZERO_GPU_AVAILABLE = False


class TextToImageEvaluator:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
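        # CLIP ViT-L/14 @ 336 px serves as the image-text similarity scorer.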
        clip_model_name = "openai/clip-vit-large-patch14-336"
        tokenizer = CLIPTokenizerFast.from_pretrained(clip_model_name)
        image_processor = CLIPImageProcessor.from_pretrained(clip_model_name)
        self.clip_model = CLIPModel.from_pretrained(clip_model_name)
        self.clip_processor = CLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        print("Loading image generation model...")
        self.generator = DiffusionPipeline.from_pretrained(
            "Lykon/dreamshaper-xl-v2-turbo",
            torch_dtype=self.dtype
        )
        self.clip_model.to(self.device)
        self.generator.to(self.device)

        if self.device == "cuda":
            self.generator.enable_attention_slicing()
            self.generator.enable_vae_slicing()
            # Try to enable xformers if available
            try:
                self.generator.enable_xformers_memory_efficient_attention()
                print("xformers enabled for memory efficient attention")
            except Exception:
                pass
        else:
            # CPU optimizations
            self.generator.enable_attention_slicing(1)

        print(f"Models loaded successfully on {self.device}")

    def generate_image(self, text, num_inference_steps=6, guidance_scale=2):
        """Generate image from text using Stable Diffusion"""
        self.generator.to(self.device)
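        # A fixed seed makes generations reproducible for the same prompt.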
        generator = torch.Generator(device=self.generator.device).manual_seed(42)
        with torch.inference_mode():
            if self.device == "cuda":
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    image = self.generator(
                        text,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator
                    ).images[0]
            else:
                image = self.generator(
                    text,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=generator
                ).images[0]
        if self.device == "cuda":
            torch.cuda.empty_cache()

        return image

    def calculate_clip_score(self, image, text):
        """Calculate CLIPScore between image and text"""
        self.clip_model.to(self.device)
        inputs = self.clip_processor(
            text=[text],
            images=[image],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.inference_mode():
            outputs = self.clip_model(**inputs)

        image_embeds = outputs.image_embeds
        text_embeds = outputs.text_embeds
        # Normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        # Calculate cosine similarity
        similarity = (image_embeds * text_embeds).sum(dim=-1)
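        # Raw cosine similarity lies in [-1, 1]; for CLIP image-text pairs it
        # typically sits around 0.2-0.4 even for well-matched pairs.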
        score = similarity.cpu().item()
        
        if self.device == "cuda":
            torch.cuda.empty_cache()
        return score

    def process_prompt(self, text):
        """Process a single text prompt and return image with scores"""
        if not text or text.strip() == "":
            raise gr.Error("Please enter a prompt")
        text = text.strip()
        print(f"Processing prompt: {text}")
        
        # Generate image
        print("Generating image...")
        generated_image = self.generate_image(text)
        
        # Calculate CLIP score
        print("Calculating similarity scores...")
        clip_score = self.calculate_clip_score(generated_image, text)
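        # The 2.5 multiplier mirrors the w = 2.5 rescaling in the original
        # CLIPScore metric (Hessel et al., 2021), which stretches typical
        # similarities toward a 0-1 range.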
        geneval_score = clip_score * 2.5
        
        return generated_image, round(clip_score, 4), round(geneval_score, 4)


def get_evaluator():
    """Lazy load the evaluator"""
    global evaluator
    if evaluator is None:
        evaluator = TextToImageEvaluator()
    return evaluator


def generate_and_evaluate(prompt):
    """Main function for Gradio interface"""
    eval_instance = get_evaluator()
    image, clip_score, geneval_score = eval_instance.process_prompt(prompt)
    return image, f"{clip_score}", f"{geneval_score}"

# Use ZeroGPU decorator if available on HF Spaces
if ZERO_GPU_AVAILABLE:
    # Store reference to original function before reassignment
    _generate_and_evaluate_impl = generate_and_evaluate
    
    @spaces.GPU
    def generate_and_evaluate_gpu(prompt):
        return _generate_and_evaluate_impl(prompt)
    
    generate_and_evaluate = generate_and_evaluate_gpu

# Create Gradio interface
with gr.Blocks(title="Text-to-Image Generator & Evaluator") as demo:
    gr.Markdown(
        """
        # 🎨 Text-to-Image Generator & Evaluator
        
        Generate images from text prompts and evaluate them using CLIP scores.
        
        - **CLIP Score**: Measures how well the generated image matches the text prompt (CLIP cosine similarity; higher is better, typically well below 1)
        - **GenEval Score**: Scaled evaluation score (CLIP Score × 2.5)
        """
    )
    
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="A beautiful sunset over mountains...",
                lines=3
            )
            generate_btn = gr.Button("🚀 Generate Image", variant="primary")
            
            gr.Markdown("### Evaluation Scores")
            with gr.Row():
                clip_score_output = gr.Textbox(label="CLIP Score", interactive=False)
                geneval_score_output = gr.Textbox(label="GenEval Score", interactive=False)
        
        with gr.Column(scale=1):
            image_output = gr.Image(label="Generated Image", type="pil")
    
    # Example prompts
    gr.Examples(
        examples=[
            ["A futuristic city with flying cars at night"],
            ["A cute cat wearing a wizard hat"],
            ["An astronaut riding a horse on Mars"],
            ["A cozy coffee shop interior with warm lighting"]
        ],
        inputs=prompt_input
    )
    
    # Connect the button to the function
    generate_btn.click(
        fn=generate_and_evaluate,
        inputs=prompt_input,
        outputs=[image_output, clip_score_output, geneval_score_output]
    )
    
    # Also allow Enter key to submit
    prompt_input.submit(
        fn=generate_and_evaluate,
        inputs=prompt_input,
        outputs=[image_output, clip_score_output, geneval_score_output]
    )

if __name__ == "__main__":
    print("TEXT-TO-IMAGE GENERATOR - GRADIO APP")
    print("=" * 60)
    print(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    demo.queue(max_size=10).launch()