import einops
from diffusers import StableDiffusionXLPipeline, IFPipeline
from typing import List, Dict, Callable, Union
import torch

from .hooked_scheduler import HookedNoiseScheduler

import spaces


def retrieve(io):
    if isinstance(io, tuple):
        if len(io) == 1:
            return io[0]
        else:
            raise ValueError("A tuple should have length of 1")
    elif isinstance(io, torch.Tensor):
        return io
    else:
        raise ValueError("Input/Output must be a tensor, or 1-element tuple")


class HookedDiffusionAbstractPipeline:
    parent_cls = None
    pipe = None

    def __init__(self, pipe: parent_cls, use_hooked_scheduler: bool = False):
        if use_hooked_scheduler:
            pipe.scheduler = HookedNoiseScheduler(pipe.scheduler)
        # Store the wrapped pipeline directly in __dict__ to bypass __setattr__,
        # which forwards attribute assignment to the wrapped pipeline.
        self.__dict__['pipe'] = pipe
        self.use_hooked_scheduler = use_hooked_scheduler

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        return cls(cls.parent_cls.from_pretrained(*args, **kwargs))

    def run_with_hooks(self,
        *args,
        position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
        **kwargs
    ):
        '''
        Run the pipeline with hooks at specified positions.
        Returns the final output.

        Args:
            *args: Arguments to pass to the pipeline.
            position_hook_dict: A dictionary mapping positions to hooks.
                The keys are positions in the pipeline where the hooks should be registered.
                The values are either a single hook or a list of hooks to be registered at the specified position.
                Each hook should be a callable that takes three arguments: (module, input, output).
            **kwargs: Keyword arguments to pass to the pipeline.
        '''
        hooks = []
        for position, hook in position_hook_dict.items():
            if isinstance(hook, list):
                for h in hook:
                    hooks.append(self._register_general_hook(position, h))
            else:
                hooks.append(self._register_general_hook(position, hook))
        hooks = [hook for hook in hooks if hook is not None]

        try:
            output = self.pipe(*args, **kwargs)
        finally:
            for hook in hooks:
                hook.remove()
            if self.use_hooked_scheduler:
                self.pipe.scheduler.pre_hooks = []
                self.pipe.scheduler.post_hooks = []

        return output
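
    # Illustrative use of run_with_hooks (a sketch: the position string below is a hypothetical
    # SDXL UNet module path, and the prompt/arguments depend on the wrapped pipeline):
    #
    #   def halve_output(module, input, output):
    #       return output * 0.5  # returning a value from a forward hook replaces the module output
    #
    #   images = hooked_pipe.run_with_hooks(
    #       "a photo of a cat",
    #       position_hook_dict={"unet.mid_block": halve_output},
    #   )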

    def run_with_cache(self,
        *args,
        positions_to_cache: List[str],
        save_input: bool = False,
        save_output: bool = True,
        **kwargs
    ):
        '''
        Run the pipeline with caching at specified positions.

        This method allows you to cache the intermediate inputs and/or outputs of the pipeline
        at certain positions. The final output of the pipeline and a dictionary of cached values
        are returned.

        Args:
            *args: Arguments to pass to the pipeline.
            positions_to_cache (List[str]): A list of positions in the pipeline where intermediate
                inputs/outputs should be cached.
            save_input (bool, optional): If True, caches the input at each specified position.
                Defaults to False.
            save_output (bool, optional): If True, caches the output at each specified position.
                Defaults to True.
            **kwargs: Keyword arguments to pass to the pipeline.

        Returns:
            final_output: The final output of the pipeline after execution.
            cache_dict (Dict[str, Dict[str, torch.Tensor]]): A dictionary with the keys 'input' and/or
                'output' (depending on the flags `save_input` and `save_output`); each maps the cached
                positions to tensors of intermediate values stacked along dim=1.
        '''
        cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
        hooks = [
            self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
        ]
        hooks = [hook for hook in hooks if hook is not None]

        output = self.pipe(*args, **kwargs)

        for hook in hooks:
            hook.remove()
        if self.use_hooked_scheduler:
            self.pipe.scheduler.pre_hooks = []
            self.pipe.scheduler.post_hooks = []

        cache_dict = {}
        if save_input:
            for position, block in cache_input.items():
                cache_input[position] = torch.stack(block, dim=1)
            cache_dict['input'] = cache_input
        if save_output:
            for position, block in cache_output.items():
                cache_output[position] = torch.stack(block, dim=1)
            cache_dict['output'] = cache_output

        return output, cache_dict
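
    # Illustrative use of run_with_cache (a sketch: the position strings are hypothetical and
    # must name real submodules of the wrapped pipeline):
    #
    #   images, cache = hooked_pipe.run_with_cache(
    #       "a photo of a cat",
    #       positions_to_cache=["unet.mid_block", "unet.up_blocks.0"],
    #   )
    #   # cache['output']['unet.mid_block'] -> activations stacked over denoising steps (dim=1)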

    def run_with_hooks_and_cache(self,
        *args,
        position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
        positions_to_cache: List[str] = [],
        save_input: bool = False,
        save_output: bool = True,
        **kwargs
    ):
        '''
        Run the pipeline with hooks and caching at specified positions.

        This method allows you to register hooks at certain positions in the pipeline and
        cache intermediate inputs and/or outputs at specified positions. Hooks can be used
        for inspecting or modifying the pipeline's execution, and caching stores intermediate
        values for later inspection or use.

        Args:
            *args: Arguments to pass to the pipeline.
            position_hook_dict (Dict[str, Union[Callable, List[Callable]]]):
                A dictionary where the keys are the positions in the pipeline, and the values
                are hooks (either a single hook or a list of hooks) to be registered at those positions.
                Each hook should be a callable that accepts three arguments: (module, input, output).
            positions_to_cache (List[str], optional): A list of positions in the pipeline where
                intermediate inputs/outputs should be cached. Defaults to an empty list.
            save_input (bool, optional): If True, caches the input at each specified position.
                Defaults to False.
            save_output (bool, optional): If True, caches the output at each specified position.
                Defaults to True.
            **kwargs: Additional keyword arguments to pass to the pipeline.

        Returns:
            final_output: The final output of the pipeline after execution.
            cache_dict (Dict[str, Dict[str, torch.Tensor]]): A dictionary with the keys 'input' and/or
                'output' (depending on the flags `save_input` and `save_output`); each maps the cached
                positions to tensors of intermediate values stacked along dim=1.
        '''
        cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
        hooks = [
            self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
        ]
        for position, hook in position_hook_dict.items():
            if isinstance(hook, list):
                for h in hook:
                    hooks.append(self._register_general_hook(position, h))
            else:
                hooks.append(self._register_general_hook(position, hook))
        hooks = [hook for hook in hooks if hook is not None]

        output = self.pipe(*args, **kwargs)

        for hook in hooks:
            hook.remove()
        if self.use_hooked_scheduler:
            self.pipe.scheduler.pre_hooks = []
            self.pipe.scheduler.post_hooks = []

        cache_dict = {}
        if save_input:
            for position, block in cache_input.items():
                cache_input[position] = torch.stack(block, dim=1)
            cache_dict['input'] = cache_input
        if save_output:
            for position, block in cache_output.items():
                cache_output[position] = torch.stack(block, dim=1)
            cache_dict['output'] = cache_output

        return output, cache_dict

    def _locate_block(self, position: str):
        '''
        Locate the block at the specified position in the pipeline.
        '''
        block = self.pipe
        for step in position.split('.'):
            if step.isdigit():
                step = int(step)
                block = block[step]
            else:
                block = getattr(block, step)
        return block
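
    # For example, a (hypothetical) position string "unet.down_blocks.1.attentions.0" resolves to
    # self.pipe.unet.down_blocks[1].attentions[0]: digit segments are used as indices, the rest
    # as attribute names.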

    def _register_cache_hook(self, position: str, cache_input: Dict, cache_output: Dict):
        if position.endswith('$self_attention') or position.endswith('$cross_attention'):
            return self._register_cache_attention_hook(position, cache_output)

        if position == 'noise':
            def hook(model_output, timestep, sample, generator):
                if position not in cache_output:
                    cache_output[position] = []
                cache_output[position].append(sample)

            if self.use_hooked_scheduler:
                self.pipe.scheduler.post_hooks.append(hook)
            else:
                raise ValueError('Cannot cache noise without using hooked scheduler')
            return

        block = self._locate_block(position)

        def hook(module, input, kwargs, output):
            if cache_input is not None:
                if position not in cache_input:
                    cache_input[position] = []
                cache_input[position].append(retrieve(input))
            if cache_output is not None:
                if position not in cache_output:
                    cache_output[position] = []
                cache_output[position].append(retrieve(output))

        return block.register_forward_hook(hook, with_kwargs=True)

    def _register_cache_attention_hook(self, position, cache):
        attn_block = self._locate_block(position.split('$')[0])
        if position.endswith('$self_attention'):
            attn_block = attn_block.attn1
        elif position.endswith('$cross_attention'):
            attn_block = attn_block.attn2
        else:
            raise ValueError('Wrong attention type')

        def hook(module, args, kwargs, output):
            hidden_states = args[0]
            encoder_hidden_states = kwargs['encoder_hidden_states']
            attention_mask = kwargs['attention_mask']
            batch_size, sequence_length, _ = hidden_states.shape
            attention_mask = attn_block.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            query = attn_block.to_q(hidden_states)

            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif attn_block.norm_cross is not None:
                encoder_hidden_states = attn_block.norm_cross(encoder_hidden_states)

            key = attn_block.to_k(encoder_hidden_states)
            value = attn_block.to_v(encoder_hidden_states)

            query = attn_block.head_to_batch_dim(query)
            key = attn_block.head_to_batch_dim(key)
            value = attn_block.head_to_batch_dim(value)

            attention_probs = attn_block.get_attention_scores(query, key, attention_mask)
            attention_probs = attention_probs.view(
                batch_size,
                attention_probs.shape[0] // batch_size,
                attention_probs.shape[1],
                attention_probs.shape[2]
            )
            if position not in cache:
                cache[position] = []
            cache[position].append(attention_probs)

        return attn_block.register_forward_hook(hook, with_kwargs=True)
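
    # Attention caching positions use a '$' suffix on a transformer-block path, e.g. the hypothetical
    # "unet.down_blocks.1.attentions.0.transformer_blocks.0$cross_attention", which caches the
    # softmaxed attention probabilities of that block's attn2 module (attn1 for '$self_attention').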

    def _register_general_hook(self, position, hook):
        if position == 'scheduler_pre':
            if not self.use_hooked_scheduler:
                raise ValueError('Cannot register hooks on scheduler without using hooked scheduler')
            self.pipe.scheduler.pre_hooks.append(hook)
            return
        elif position == 'scheduler_post':
            if not self.use_hooked_scheduler:
                raise ValueError('Cannot register hooks on scheduler without using hooked scheduler')
            self.pipe.scheduler.post_hooks.append(hook)
            return

        block = self._locate_block(position)
        return block.register_forward_hook(hook)

    def to(self, *args, **kwargs):
        # Assign via __dict__ so the rebound pipeline is stored on the wrapper itself rather than
        # being forwarded to the wrapped pipeline by __setattr__.
        self.__dict__['pipe'] = self.pipe.to(*args, **kwargs)
        return self

    def __getattr__(self, name):
        return getattr(self.pipe, name)

    def __setattr__(self, name, value):
        return setattr(self.pipe, name, value)

    def __call__(self, *args, **kwargs):
        return self.pipe(*args, **kwargs)


class HookedStableDiffusionXLPipeline(HookedDiffusionAbstractPipeline):
    parent_cls = StableDiffusionXLPipeline


class HookedIFPipeline(HookedDiffusionAbstractPipeline):
    parent_cls = IFPipeline
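

# End-to-end usage sketch (illustrative only, not executed here; the model id, dtype, and cached
# position are assumptions and should be adapted to the available hardware and pipeline):
#
#   pipe = HookedStableDiffusionXLPipeline.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
#   )
#   pipe.to("cuda")
#   images, cache = pipe.run_with_cache(
#       "a watercolor painting of a fox",
#       positions_to_cache=["unet.mid_block"],
#       num_inference_steps=30,
#   )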