File size: 11,487 Bytes
e673944
2473931
 
1fcab49
 
663212e
 
1fcab49
 
 
 
663212e
f828cc2
2473931
663212e
 
 
 
ef52cd8
1fcab49
 
 
ef52cd8
 
 
 
2473931
ef52cd8
3d01d22
1fcab49
 
f828cc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef52cd8
663212e
 
8c2280a
663212e
 
 
 
 
 
ef52cd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
663212e
1fcab49
663212e
 
 
 
1fcab49
663212e
1fcab49
ef52cd8
663212e
ef52cd8
 
663212e
ef52cd8
663212e
 
 
 
ef52cd8
663212e
 
 
 
ef52cd8
663212e
 
 
1fcab49
663212e
ef52cd8
663212e
 
 
ef52cd8
663212e
 
 
 
 
 
 
 
1fcab49
 
663212e
 
1fcab49
 
663212e
 
 
1fcab49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfc1d04
2473931
 
1fcab49
 
 
 
 
 
 
 
 
2473931
 
1fcab49
 
 
8c2280a
1fcab49
 
 
 
8c2280a
663212e
 
 
 
2473931
1fcab49
 
 
 
 
 
 
 
1a91398
663212e
 
 
 
 
 
 
 
8c2280a
4ed82d8
663212e
 
 
 
8c2280a
663212e
2473931
 
4ed82d8
 
b01fe58
8c2280a
663212e
1fcab49
 
b01fe58
1fcab49
4ed82d8
1fcab49
 
b01fe58
2473931
 
 
663212e
8c2280a
2473931
 
 
 
663212e
 
2e25444
2473931
1fcab49
2473931
a6f47af
1fcab49
4e2a429
663212e
a6f47af
663212e
 
a6f47af
2473931
a6f47af
663212e
a6f47af
2473931
1fcab49
663212e
 
 
2e25444
 
663212e
 
2e25444
 
663212e
2e25444
 
a6f47af
663212e
 
 
a6f47af
2e25444
a6f47af
663212e
2e25444
a6f47af
663212e
a6f47af
2473931
 
 
581cbb1
2473931
663212e
a6f47af
663212e
a6f47af
8c2280a
a6f47af
 
2473931
1fcab49
663212e
1fcab49
 
 
 
663212e
1fcab49
 
8c2280a
663212e
1fcab49
 
 
663212e
1fcab49
 
663212e
 
 
 
 
 
 
 
 
 
 
 
 
1fcab49
 
663212e
 
 
1fcab49
663212e
 
 
1fcab49
663212e
 
 
 
 
1fcab49
663212e
1fcab49
663212e
 
 
 
 
32fd425
 
a6f47af
663212e
32fd425
 
663212e
 
2473931
 
663212e
 
 
8c2280a
663212e
8c2280a
2473931
 
 
581cbb1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
import spaces
import gradio as gr
import torch
import numpy as np
import random
import time
import os
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, Qwen3ForCausalLM
from controlnet_aux.processor import Processor
from PIL import Image
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, snapshot_download

# Import pipeline and model
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.models import ZImageControlTransformer2DModel

# --- Configuration & Paths ---
# Upper bound for random seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Maximum width/height (in pixels) that rescale_image will produce.
MAX_IMAGE_SIZE = 1280

# Hugging Face Repo IDs
# Base repo: transformer, VAE, tokenizer, text encoder and scheduler are all loaded from it.
MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
# Repo and file holding the ControlNet Union weights that are merged into the transformer.
CONTROLNET_REPO = "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union"
CONTROLNET_FILENAME = "Z-Image-Turbo-Fun-Controlnet-Union.safetensors"

print(f"Loading Z-Image Turbo from {MODEL_REPO}...")
# Prefer the GPU when present; the models below are loaded/cast in bfloat16.
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.bfloat16

# --- FIX: Download Transformer Config & Weights Locally ---
# Only the transformer/ subfolder of MODEL_REPO is fetched, into ./models/transformer.
# NOTE(review): `local_dir_use_symlinks` is deprecated (ignored with a warning) in recent
# huggingface_hub releases — confirm against the pinned version before relying on it.
print("Downloading transformer files...")
transformer_path = snapshot_download(
    repo_id=MODEL_REPO,
    allow_patterns=["transformer/*"],
    local_dir="models/transformer",
    local_dir_use_symlinks=False
)
# The repo layout is mirrored under local_dir, so the config is expected in <root>/transformer.
local_transformer_path = os.path.join(transformer_path, "transformer")

# Fall back to the download root if config.json is not in the expected subfolder.
if not os.path.exists(os.path.join(local_transformer_path, "config.json")):
    local_transformer_path = transformer_path

print(f"Transformer files located at: {local_transformer_path}")

# --- 1. Load Transformer ---
# The control-specific kwargs configure the ControlNet branch of the transformer.
# NOTE(review): presumably "control_layers_places" lists the transformer layer indices that
# receive control injections and "control_in_dim" the control-latent channel count — TODO
# confirm against videox_fun's ZImageControlTransformer2DModel.
print("Initializing Transformer...")
transformer = ZImageControlTransformer2DModel.from_pretrained(
    local_transformer_path,
    transformer_additional_kwargs={
        "control_layers_places": [0, 5, 10, 15, 20, 25],
        "control_in_dim": 16
    },
).to(device, weight_dtype)

# --- 2. Download & Load ControlNet Weights ---
# Use CONTROLNET_FILENAME from the working directory if it exists; otherwise download it
# from CONTROLNET_REPO. A failed download is non-fatal: CONTROLNET_WEIGHTS becomes None.
if not os.path.exists(CONTROLNET_FILENAME):
    print(f"Downloading ControlNet weights from {CONTROLNET_REPO}...")
    try:
        CONTROLNET_WEIGHTS = hf_hub_download(
            repo_id=CONTROLNET_REPO,
            filename=CONTROLNET_FILENAME
        )
    except Exception as e:
        print(f"Failed to download ControlNet weights: {e}")
        CONTROLNET_WEIGHTS = None
else:
    CONTROLNET_WEIGHTS = CONTROLNET_FILENAME

if CONTROLNET_WEIGHTS:
    print(f"Loading ControlNet weights from {CONTROLNET_WEIGHTS}")
    try:
        state_dict = load_file(CONTROLNET_WEIGHTS)
        # Unwrap a nested "state_dict" entry if present; otherwise keep the dict as loaded.
        state_dict = state_dict.get("state_dict", state_dict)
        # strict=False: key mismatches are counted and reported instead of raising.
        m, u = transformer.load_state_dict(state_dict, strict=False)
        print(f"ControlNet Weights Loaded - Missing keys: {len(m)}, Unexpected keys: {len(u)}")
    except Exception as e:
        # Best-effort: on failure the app continues with the transformer as loaded above.
        print(f"Error loading ControlNet weights: {e}")
else:
    print("Warning: Running without explicit ControlNet weights.")

# --- 3. Load Core Components ---
# VAE, tokenizer, Qwen3 text encoder and flow-matching scheduler all come from
# subfolders of MODEL_REPO.
print("Loading VAE, Tokenizer, and Text Encoder...")
vae = AutoencoderKL.from_pretrained(
    MODEL_REPO,
    subfolder="vae",
).to(device, weight_dtype)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_REPO, 
    subfolder="tokenizer"
)

# Loaded directly in weight_dtype (torch_dtype), then moved to the target device.
text_encoder = Qwen3ForCausalLM.from_pretrained(
    MODEL_REPO, 
    subfolder="text_encoder", 
    torch_dtype=weight_dtype,
).to(device)

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    MODEL_REPO, 
    subfolder="scheduler"
)

# --- 4. Assemble Pipeline ---
# `pipe` is the single module-level pipeline used by generate_image.
pipe = ZImageControlPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)
# Final device/dtype placement of the assembled pipeline (most components were already
# moved above). NOTE(review): exact semantics depend on the pipeline's .to() — confirm.
pipe.to(device, weight_dtype)
print(f"Model loaded successfully on {device}!")

# --- Helper Functions ---

def rescale_image(image, scale, divisible_by=16):
    """Rescale ``image`` by ``scale`` and snap its size to a multiple of ``divisible_by``.

    The scaled size is shrunk uniformly (aspect ratio preserved) so that neither
    side exceeds ``MAX_IMAGE_SIZE``. Each side is then rounded down to a multiple
    of ``divisible_by``, but never below ``divisible_by`` itself.

    Args:
        image: PIL image to resize, or ``None``.
        scale: Multiplicative resize factor applied to both sides.
        divisible_by: Each output side is a multiple of this value.

    Returns:
        ``(resized_image, width, height)``; ``(None, 1024, 1024)`` when
        ``image`` is ``None``.
    """
    if image is None:
        return None, 1024, 1024

    width, height = image.size
    target_width = width * scale
    target_height = height * scale

    # Clamp the longer side to MAX_IMAGE_SIZE and shrink the other side by the
    # same ratio; clamping each side independently would distort the image.
    if target_width >= target_height and target_width > MAX_IMAGE_SIZE:
        target_height = target_height * MAX_IMAGE_SIZE / target_width
        target_width = MAX_IMAGE_SIZE
    elif target_height > MAX_IMAGE_SIZE:
        target_width = target_width * MAX_IMAGE_SIZE / target_height
        target_height = MAX_IMAGE_SIZE

    # Round down to the required multiple *after* clamping so the result stays
    # divisible even if MAX_IMAGE_SIZE is not. Keep at least one multiple so
    # very small inputs never request a zero-sized resize.
    new_width = max(divisible_by, int(target_width) // divisible_by * divisible_by)
    new_height = max(divisible_by, int(target_height) // divisible_by * divisible_by)

    resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    return resized, new_width, new_height

@spaces.GPU()
def generate_image(
    prompt,
    negative_prompt="blurry, ugly, bad quality",
    input_image=None,
    control_mode="Canny",
    control_context_scale=0.75,
    image_scale=1.0,
    num_inference_steps=9,
    guidance_scale=1.0,
    seed=42,
    randomize_seed=True,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate an image guided by a structure map extracted from ``input_image``.

    Args:
        prompt: Text prompt; must be non-empty.
        negative_prompt: Forwarded to the pipeline as ``negative_prompt``.
        input_image: PIL control image; required.
        control_mode: One of "Canny", "HED", "Depth", "MLSD", "Pose"; unknown
            values (or a processor that fails to load) fall back to Canny.
        control_context_scale: Forwarded to the pipeline as ``control_context_scale``.
        image_scale: Resize factor applied to ``input_image`` via ``rescale_image``.
        num_inference_steps: Forwarded to the pipeline as ``num_inference_steps``.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        seed: Seed used when ``randomize_seed`` is False.
        randomize_seed: If True, draw a fresh seed in ``[0, MAX_SEED]``.
        progress: Gradio progress tracker.

    Returns:
        ``(generated_image, seed_used, preprocessed_control_image)``.

    Raises:
        gr.Error: On an empty prompt, a missing control image, or a pipeline failure.
    """
    # Guard against None as well as empty/whitespace-only prompts.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt to generate an image.")

    # 1. Set Seed
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders may deliver floats; torch.Generator.manual_seed requires an int.
    seed = int(seed)
    generator = torch.Generator(device).manual_seed(seed)

    # 2. Process Control Image
    if input_image is None:
        raise gr.Error("Please upload a control image.")

    progress(0.2, desc=f"Processing {control_mode}...")

    # UI label -> controlnet_aux processor id.
    processor_map = {
        'Canny': 'canny',
        'HED': 'softedge_hed',
        'Depth': 'depth_midas',
        'MLSD': 'mlsd',
        'Pose': 'openpose_full'
    }
    processor_id = processor_map.get(control_mode, 'canny')

    try:
        processor = Processor(processor_id)
    except Exception as e:
        # Best-effort: a processor that fails to load degrades to Canny instead of failing.
        print(f"Failed to load processor {processor_id}, falling back to Canny. Error: {e}")
        processor = Processor('canny')

    control_image_rescaled, width, height = rescale_image(input_image, image_scale, 16)

    # Run Processor
    # We resize to 1024 temporarily for the preprocessor to work best, then resize back to target
    temp_image = control_image_rescaled.resize((1024, 1024))
    processed_image_pil = processor(temp_image, to_pil=True)
    processed_image_pil = processed_image_pil.resize((width, height))

    # 3. Generate
    progress(0.5, desc="Generating...")

    try:
        # The processed PIL image is passed directly; the pipeline encodes it internally.
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            generator=generator,
            guidance_scale=guidance_scale,
            control_image=processed_image_pil,
            num_inference_steps=num_inference_steps,
            control_context_scale=control_context_scale,
        )

        image = result.images[0]
        progress(1.0, desc="Complete!")

        return image, seed, processed_image_pil

    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback chained.
        raise gr.Error(f"Generation failed: {str(e)}") from e

# --- UI Configuration (Apple Style) ---
# Custom stylesheet passed to demo.launch(css=...). Its class selectors target the
# header/footer HTML (header-container, info-badge, main-title, subtitle, footer-text)
# and the Generate button (elem_classes="primary") defined in the Blocks layout.

apple_css = """
.gradio-container {
    max-width: 1200px !important;
    margin: 0 auto !important;
    padding: 48px 20px !important;
    font-family: -apple-system, BlinkMacSystemFont, 'Inter', 'Segoe UI', sans-serif !important;
}
.header-container { text-align: center; margin-bottom: 48px; }
.main-title {
    font-size: 56px !important; font-weight: 600 !important;
    letter-spacing: -0.02em !important; color: #1d1d1f !important;
    margin: 0 0 12px 0 !important;
}
.subtitle {
    font-size: 21px !important; color: #6e6e73 !important;
    margin: 0 0 24px 0 !important;
}
.info-badge {
    display: inline-block; background: #0071e3; color: white;
    padding: 6px 16px; border-radius: 20px; font-size: 14px;
    font-weight: 500; margin-bottom: 16px;
}
textarea {
    font-size: 17px !important; border-radius: 12px !important;
    border: 1px solid #d2d2d7 !important; padding: 12px 16px !important;
}
textarea:focus {
    border-color: #0071e3 !important; box-shadow: 0 0 0 4px rgba(0, 113, 227, 0.15) !important;
    outline: none !important;
}
button.primary {
    font-size: 17px !important; padding: 12px 32px !important;
    border-radius: 980px !important; background: #0071e3 !important;
    border: none !important; color: #ffffff !important;
    transition: all 0.2s ease !important;
}
button.primary:hover {
    background: #0077ed !important; transform: scale(1.02) !important;
}
.footer-text {
    text-align: center; margin-top: 48px; font-size: 14px !important;
    color: #86868b !important;
}
"""

with gr.Blocks(title="Z-Image Turbo ControlNet") as demo:
    
    # Page header; its CSS classes are styled by apple_css.
    gr.HTML("""
        <div class="header-container">
            <div class="info-badge">✓ ControlNet Union</div>
            <h1 class="main-title">Z-Image Turbo</h1>
            <p class="subtitle">Multi-Control Generation</p>
        </div>
    """)
    
    with gr.Row():
        # Left Input Column
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the image you want to create...",
                lines=3
            )
            
            # When checked, generate_image ignores the Seed slider and draws a random seed.
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)

            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="blurry, ugly, bad quality",
                lines=1
            )
            
            # Delivered to generate_image as a PIL image; generation refuses to run without it.
            input_image = gr.Image(
                label="Control Image (Required)",
                type="pil",
                sources=['upload', 'clipboard'],
                height=300
            )
            
            # Choices must match the keys of generate_image's processor_map.
            control_mode = gr.Radio(
                choices=["Canny", "Depth", "HED", "MLSD", "Pose"],
                value="Canny",
                label="Control Mode",
                info="Select the type of structure to extract"
            )
            
            # Defaults mirror generate_image's keyword defaults.
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=30, step=1, value=9)
                    guidance_scale = gr.Slider(label="Guidance", minimum=0.0, maximum=10.0, step=0.1, value=1.0)
                
                with gr.Row():
                    control_context_scale = gr.Slider(label="Control Strength", minimum=0.0, maximum=1.0, step=0.01, value=0.75)
                    image_scale = gr.Slider(label="Image Scale", minimum=0.5, maximum=2.0, step=0.1, value=1.0)
                
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)

            generate_btn = gr.Button("Generate Image", variant="primary", elem_classes="primary")

        # Right Output Column
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Image", type="pil")
            
            # Shows the seed actually used and the preprocessor's control map.
            with gr.Accordion("Details & Debug", open=True):
                with gr.Row():
                    seed_output = gr.Number(label="Seed Used", precision=0)
                control_output = gr.Image(label="Preprocessor Output", type="pil")

    # Footer
    gr.HTML("""
        <div class="footer-text">
            Powered by Z-Image Turbo • VideoX-Fun • Tongyi-MAI
        </div>
    """)

    # Event Wiring
    # Inputs are passed positionally, so their order must match generate_image's
    # parameter order (progress keeps its default). Outputs match its return tuple.
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt, negative_prompt, input_image, control_mode,
            control_context_scale, image_scale, num_inference_steps,
            guidance_scale, seed, randomize_seed
        ],
        outputs=[output_image, seed_output, control_output]
    )

if __name__ == "__main__":
    # Serve without a public share link, applying the custom stylesheet.
    launch_options = dict(share=False, css=apple_css)
    demo.launch(**launch_options)