lulavc committed on
Commit
9cd69ee
·
verified ·
1 Parent(s): b9d8080

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +22 -64
app.py CHANGED
@@ -773,6 +773,9 @@ OUTPUT THE PROMPT NOW (nothing else):"""
773
  # ZEROGPU AOTI CONFIGURATION
774
  # =============================================================================
775
 
 
 
 
776
  # Inductor configuration optimized for diffusion transformers
777
  INDUCTOR_CONFIGS = {
778
  "conv_1x1_as_mm": True,
@@ -791,7 +794,9 @@ MIN_SEQ_LEN = 15360 # 1536x640 -> 192x80 -> 15,360
791
  MAX_SEQ_LEN = 65536 # 2048x2048 -> 256x256 -> 65,536
792
 
793
  # Environment variable to enable/disable AoTI compilation
794
- ENABLE_AOTI = os.environ.get("ENABLE_AOTI", "true").lower() == "true"
 
 
795
 
796
  logger.info("Loading Z-Image-Turbo pipeline...")
797
 
@@ -809,76 +814,26 @@ except Exception as e:
809
  logger.warning(f"FA3 not available, using default SDPA attention: {e}")
810
 
811
  # =============================================================================
812
- # AOTI COMPILATION FUNCTION
813
  # =============================================================================
814
 
815
- @spaces.GPU(duration=1500)
816
- def compile_transformer_aoti():
817
- """
818
- Compile transformer ahead-of-time for 1.3x-1.8x speedup.
819
- Runs once at Space startup, takes ~5-10 minutes.
820
- """
821
- logger.info("Starting AoTI compilation for transformer...")
822
-
823
  try:
824
- # Step 1: Capture example inputs
825
- logger.info("Step 1/4: Capturing example inputs...")
826
- with spaces.aoti_capture(pipe_t2i.transformer) as call:
827
- pipe_t2i(
828
- "example prompt for compilation",
 
829
  height=1024,
830
  width=1024,
831
  num_inference_steps=1,
832
- guidance_scale=0.0,
 
 
833
  )
834
 
835
- # Step 2: Define dynamic shapes for multi-resolution support
836
- logger.info("Step 2/4: Configuring dynamic shapes...")
837
- logger.info(f"Captured kwargs keys: {list(call.kwargs.keys())}")
838
- from torch.export import Dim
839
- from torch.utils._pytree import tree_map
840
-
841
- # Define dynamic dimensions
842
- batch_dim = Dim("batch", min=1, max=4)
843
- seq_len_dim = Dim("seq_len", min=MIN_SEQ_LEN, max=MAX_SEQ_LEN)
844
-
845
- # Create dynamic shapes from captured kwargs
846
- dynamic_shapes = tree_map(lambda v: None, call.kwargs)
847
-
848
- # Apply dynamic dims to variable-size tensors
849
- if "hidden_states" in call.kwargs:
850
- dynamic_shapes["hidden_states"] = {0: batch_dim, 1: seq_len_dim}
851
- if "img_ids" in call.kwargs:
852
- dynamic_shapes["img_ids"] = {0: batch_dim, 1: seq_len_dim}
853
-
854
- # Step 3: Export the model
855
- logger.info("Step 3/4: Exporting model with torch.export...")
856
- exported = torch.export.export(
857
- pipe_t2i.transformer,
858
- args=call.args,
859
- kwargs=call.kwargs,
860
- dynamic_shapes=dynamic_shapes,
861
- )
862
-
863
- # Step 4: Compile with inductor
864
- logger.info("Step 4/4: Compiling with PyTorch Inductor (this takes several minutes)...")
865
- compiled = spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
866
-
867
- logger.info("AoTI compilation completed successfully!")
868
- return compiled
869
-
870
- except Exception as e:
871
- logger.error(f"AoTI compilation failed: {type(e).__name__}: {str(e)}")
872
- logger.warning("Falling back to non-compiled transformer")
873
- return None
874
-
875
- # =============================================================================
876
- # APPLY AOTI COMPILATION
877
- # =============================================================================
878
-
879
- if ENABLE_AOTI:
880
- try:
881
- compiled_transformer = compile_transformer_aoti()
882
  if compiled_transformer is not None:
883
  spaces.aoti_apply(compiled_transformer, pipe_t2i.transformer)
884
  logger.info("AoTI transformer applied successfully (1.3x-1.8x speedup expected)")
@@ -909,7 +864,10 @@ pipe_i2i = ZImageImg2ImgPipeline(
909
  scheduler=pipe_t2i.scheduler,
910
  )
911
 
912
- logger.info("Pipelines ready! (TF32 + FA3 + AoTI Transformer + VAE compile)")
 
 
 
913
 
914
  STYLES = ["None", "Photorealistic", "Cinematic", "Anime", "Digital Art",
915
  "Oil Painting", "Watercolor", "3D Render", "Fantasy", "Sci-Fi"]
 
773
  # ZEROGPU AOTI CONFIGURATION
774
  # =============================================================================
775
 
776
+ # Import the corrected AoTI compilation function
777
+ from aoti import compile_transformer_aoti
778
+
779
  # Inductor configuration optimized for diffusion transformers
780
  INDUCTOR_CONFIGS = {
781
  "conv_1x1_as_mm": True,
 
794
  MAX_SEQ_LEN = 65536 # 2048x2048 -> 256x256 -> 65,536
795
 
796
  # Environment variable to enable/disable AoTI compilation
797
+ # Disabled by default - Z-Image-Turbo transformer uses positional args (x, t, cap_feats)
798
+ # which requires special handling in torch.export. Enable with ENABLE_AOTI=true once fixed.
799
+ ENABLE_AOTI = os.environ.get("ENABLE_AOTI", "false").lower() == "true"
800
 
801
  logger.info("Loading Z-Image-Turbo pipeline...")
802
 
 
814
  logger.warning(f"FA3 not available, using default SDPA attention: {e}")
815
 
816
  # =============================================================================
817
+ # APPLY AOTI COMPILATION
818
  # =============================================================================
819
 
820
+ if ENABLE_AOTI:
 
 
 
 
 
 
 
821
  try:
822
+ # Use the corrected compile function that handles positional args properly
823
+ @spaces.GPU(duration=1500)
824
+ def _compile_wrapper():
825
+ return compile_transformer_aoti(
826
+ pipe=pipe_t2i,
827
+ example_prompt="example prompt for compilation",
828
  height=1024,
829
  width=1024,
830
  num_inference_steps=1,
831
+ inductor_configs=INDUCTOR_CONFIGS,
832
+ min_seq_len=MIN_SEQ_LEN,
833
+ max_seq_len=MAX_SEQ_LEN,
834
  )
835
 
836
+ compiled_transformer = _compile_wrapper()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
837
  if compiled_transformer is not None:
838
  spaces.aoti_apply(compiled_transformer, pipe_t2i.transformer)
839
  logger.info("AoTI transformer applied successfully (1.3x-1.8x speedup expected)")
 
864
  scheduler=pipe_t2i.scheduler,
865
  )
866
 
867
+ if ENABLE_AOTI:
868
+ logger.info("Pipelines ready! (TF32 + FA3 + AoTI Transformer + VAE compile)")
869
+ else:
870
+ logger.info("Pipelines ready! (TF32 + FA3 + VAE compile) - AoTI disabled")
871
 
872
  STYLES = ["None", "Photorealistic", "Cinematic", "Anime", "Digital Art",
873
  "Oil Painting", "Watercolor", "3D Render", "Fantasy", "Sci-Fi"]