| """ | |
| ZeroGPU AoTI (Ahead-of-Time Inductor) compilation module for Z-Image-Turbo. | |
| This module provides the compile_transformer_aoti() function that handles | |
| the Z-Image-Turbo transformer's specific forward signature: | |
| forward(x, t, cap_feats, return_dict=True) | |
| Where: | |
| - x: hidden states / latent sequence (dynamic shape) | |
| - t: timestep | |
| - cap_feats: caption/text features | |
| - return_dict: whether to return dict or tuple | |
| """ | |

import inspect
import logging

import spaces
import torch

logger = logging.getLogger(__name__)


def compile_transformer_aoti(
    pipe,
    example_prompt: str = "example prompt for compilation",
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 1,
    inductor_configs: dict | None = None,
    min_seq_len: int = 15360,
    max_seq_len: int = 65536,
):
| """ | |
| Compile transformer ahead-of-time for 1.3x-1.8x speedup. | |
| This function correctly handles the Z-Image-Turbo transformer's forward | |
| signature which uses positional args (x, t, cap_feats) rather than kwargs. | |
| Args: | |
| pipe: The DiffusionPipeline instance | |
| example_prompt: Prompt to use for capturing example inputs | |
| height: Example image height | |
| width: Example image width | |
| num_inference_steps: Steps for example inference | |
| inductor_configs: PyTorch Inductor configuration dict | |
| min_seq_len: Minimum sequence length for dynamic shapes | |
| max_seq_len: Maximum sequence length for dynamic shapes | |
| Returns: | |
| Compiled model path or None if compilation fails | |
| """ | |
| logger.info("Starting AoTI compilation for transformer...") | |
| if inductor_configs is None: | |
| inductor_configs = {} | |
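    # For example, inductor_configs={"max_autotune": True} turns on Inductor
    # autotuning (slower compile, usually faster kernels). Which configs are
    # worth setting here is an assumption; pass whatever your Space needs.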

    try:
        # Step 1: Capture example inputs
        logger.info("Step 1/5: Capturing example inputs...")
        with spaces.aoti_capture(pipe.transformer) as call:
            pipe(
                example_prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=0.0,
            )
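
        # After the context exits, `call.args` / `call.kwargs` hold the exact
        # inputs the pipeline passed to the transformer (spaces.aoti_capture
        # intercepts that call and records them).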

        # Step 2: Map positional args to parameter names
        logger.info("Step 2/5: Mapping positional args to parameter names...")
        # Inspect the transformer's forward signature to name the positional args
        sig = inspect.signature(pipe.transformer.forward)
        param_names = [p.name for p in sig.parameters.values() if p.name != 'self']
        logger.info(f"Forward signature params: {param_names}")
        logger.info(f"Captured positional args: {len(call.args)}")
        logger.info(f"Captured kwargs keys: {list(call.kwargs.keys())}")

        # Convert positional args to named kwargs
        args_as_kwargs = {}
        for i, arg in enumerate(call.args):
            if i < len(param_names):
                args_as_kwargs[param_names[i]] = arg

        # Merge with the explicitly-passed kwargs (explicit kwargs win)
        combined_kwargs = {**args_as_kwargs, **call.kwargs}
        logger.info(f"Combined kwargs keys: {list(combined_kwargs.keys())}")

        # Step 3: Define dynamic shapes for the sequence dimension
        logger.info("Step 3/5: Configuring dynamic shapes...")
        from torch.export import Dim
        from torch.utils._pytree import tree_map

        # Start from a fully static spec (None for every leaf)
        dynamic_shapes = tree_map(lambda v: None, combined_kwargs)

        # Define dynamic dimensions for batch and sequence length
        batch_dim = Dim("batch", min=1, max=4)
        seq_len_dim = Dim("seq_len", min=min_seq_len, max=max_seq_len)

        # Mark the latent input as dynamic;
        # x shape is typically (batch, seq_len, hidden_dim)
        if 'x' in combined_kwargs:
            x_tensor = combined_kwargs['x']
            if hasattr(x_tensor, 'shape') and len(x_tensor.shape) >= 2:
                dynamic_shapes['x'] = {0: batch_dim, 1: seq_len_dim}
                logger.info(f"Set dynamic shapes for 'x': batch={batch_dim}, seq_len={seq_len_dim}")

        # Step 4: Export the model
        logger.info("Step 4/5: Exporting model with torch.export...")
        # Export with all inputs as kwargs (no positional args)
        exported = torch.export.export(
            pipe.transformer,
            args=(),  # empty: every input is passed via kwargs
            kwargs=combined_kwargs,
            dynamic_shapes=dynamic_shapes,
        )
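
        # `exported` is a torch.export.ExportedProgram: a traced graph of the
        # transformer with the declared dimensions kept symbolic.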

        # Step 5: Compile with Inductor
        logger.info("Step 5/5: Compiling with PyTorch Inductor (this takes several minutes)...")
        compiled = spaces.aoti_compile(exported, inductor_configs)
        logger.info("AoTI compilation completed successfully!")
        return compiled

    except Exception as e:
        logger.error(f"AoTI compilation failed: {type(e).__name__}: {e}")
        import traceback
        logger.error(traceback.format_exc())
        logger.warning("Falling back to non-compiled transformer")
        return None
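

# --- Usage sketch ---
# A minimal, illustrative example of wiring compile_transformer_aoti() into a
# ZeroGPU Space's startup path; it is not invoked by this module. The repo id
# "Tongyi-MAI/Z-Image-Turbo" and the @spaces.GPU duration are assumptions;
# spaces.aoti_apply() is the spaces helper that patches a compiled artifact
# onto a live module.
def _example_startup():
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16
    ).to("cuda")

    @spaces.GPU(duration=1500)  # AoTI compilation itself takes several minutes
    def _compile():
        compiled = compile_transformer_aoti(pipe)
        if compiled is not None:
            spaces.aoti_apply(compiled, pipe.transformer)

    _compile()
    return pipe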