Z-Image-Turbo / aoti.py
lulavc's picture
Add custom AoTI helper for pre-compiled blocks
d0bf03c verified
raw
history blame
1.27 kB
"""
AoTI (Ahead-of-Time Inductor) helper for ZeroGPU.

Loads pre-compiled transformer blocks from the Hugging Face Hub.
"""
from typing import cast
import torch
from huggingface_hub import hf_hub_download
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
from spaces.zero.torch.aoti import ZeroGPUWeights
def aoti_blocks_load(module: torch.nn.Module, repo_id: str, variant: str | None = None):
"""
Load pre-compiled AoTI blocks from Hub repository.
Args:
module: The transformer module containing layers to replace
repo_id: HuggingFace repo with pre-compiled blocks (e.g., 'zerogpu-aoti/Z-Image')
variant: Optional variant like 'fa3' for FlashAttention-3 compiled blocks
"""
repeated_blocks = cast(list[str], module._repeated_blocks)
aoti_files = {name: hf_hub_download(
repo_id=repo_id,
filename='package.pt2',
subfolder=name if variant is None else f'{name}.{variant}',
) for name in repeated_blocks}
for block_name, aoti_file in aoti_files.items():
for block in module.modules():
if block.__class__.__name__ == block_name:
weights = ZeroGPUWeights(block.state_dict())
block.forward = ZeroGPUCompiledModel(aoti_file, weights)