# Copyright (c) 2023 HuggingFace Team
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache License, Version 2.0 (the "License")
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 1st June 2025
#
# Original file was released under Apache License, Version 2.0 (the "License"), with the full license text
# available at http://www.apache.org/licenses/LICENSE-2.0.
#
# This modified file is released under the same license.

from contextlib import nullcontext
from typing import Literal, Optional, Tuple, Union

import diffusers
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention, SpatialNorm
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.downsampling import Downsample2D
from diffusers.models.lora import LoRACompatibleConv
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.resnet import ResnetBlock2D
from diffusers.models.unets.unet_2d_blocks import DownEncoderBlock2D, UpDecoderBlock2D
from diffusers.models.upsampling import Upsample2D
from diffusers.utils import is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
from einops import rearrange

from common.distributed.advanced import get_sequence_parallel_world_size
from common.logger import get_logger
from models.video_vae_v3.modules.causal_inflation_lib import (
    InflatedCausalConv3d,
    causal_norm_wrapper,
    init_causal_conv3d,
    remove_head,
)
from models.video_vae_v3.modules.context_parallel_lib import (
    causal_conv_gather_outputs,
    causal_conv_slice_inputs,
)
from models.video_vae_v3.modules.global_config import set_norm_limit
from models.video_vae_v3.modules.types import (
    CausalAutoencoderOutput,
    CausalDecoderOutput,
    CausalEncoderOutput,
    MemoryState,
    _inflation_mode_t,
    _memory_device_t,
    _receptive_field_t,
)

logger = get_logger(__name__)  # pylint: disable=invalid-name


class Upsample3D(Upsample2D):
    """A 3D upsampling layer with an optional convolution."""

    def __init__(
        self,
        *args,
        inflation_mode: _inflation_mode_t = "tail",
        temporal_up: bool = False,
        spatial_up: bool = True,
        slicing: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        conv = self.conv if self.name == "conv" else self.Conv2d_0
        assert type(conv) is not nn.ConvTranspose2d
        # Note: lora_layer is not passed into the constructor in the original implementation,
        # so we make a simplification here.
        conv = init_causal_conv3d(
            self.channels,
            self.out_channels,
            3,
            padding=1,
            inflation_mode=inflation_mode,
        )
        self.temporal_up = temporal_up
        self.spatial_up = spatial_up
        self.temporal_ratio = 2 if temporal_up else 1
        self.spatial_ratio = 2 if spatial_up else 1
        self.slicing = slicing
        assert not self.interpolate
        # [Override] MAGViT v2 implementation
        if not self.interpolate:
            upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio
            self.upscale_conv = nn.Conv3d(
                self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0
            )
            identity = (
                torch.eye(self.channels)
                .repeat(upscale_ratio, 1)
                .reshape_as(self.upscale_conv.weight)
            )
            self.upscale_conv.weight.data.copy_(identity)
            nn.init.zeros_(self.upscale_conv.bias)
        if self.name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_size: Optional[int] = None,
        memory_state: MemoryState = MemoryState.DISABLED,
        **kwargs,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels
        if hasattr(self, "norm") and self.norm is not None:
            # [Overridden] change to causal norm.
            hidden_states = causal_norm_wrapper(self.norm, hidden_states)
        if self.use_conv_transpose:
            return self.conv(hidden_states)
        if self.slicing:
            split_size = hidden_states.size(2) // 2
            hidden_states = list(
                hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2)
            )
        else:
            hidden_states = [hidden_states]
        for i in range(len(hidden_states)):
            hidden_states[i] = self.upscale_conv(hidden_states[i])
            hidden_states[i] = rearrange(
                hidden_states[i],
                "b (x y z c) f h w -> b c (f z) (h x) (w y)",
                x=self.spatial_ratio,
                y=self.spatial_ratio,
                z=self.temporal_ratio,
            )
        # [Overridden] For causal temporal conv
        if self.temporal_up and memory_state != MemoryState.ACTIVE:
            hidden_states[0] = remove_head(hidden_states[0])
        if not self.slicing:
            hidden_states = hidden_states[0]
        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states, memory_state=memory_state)
            else:
                hidden_states = self.Conv2d_0(hidden_states, memory_state=memory_state)
        if not self.slicing:
            return hidden_states
        else:
            return torch.cat(hidden_states, dim=2)
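

# A minimal sketch (pure torch + einops, shapes are illustrative) of the MAGViT
# v2 style depth-to-space step used by Upsample3D above. The 1x1x1 upscale_conv
# expands channels by spatial_ratio**2 * temporal_ratio, and its identity
# initialization makes the initial upsample equivalent to nearest-neighbor
# repetition before training:
#
#   c, rs, rt = 4, 2, 2
#   x = torch.randn(1, c * rs * rs * rt, 3, 8, 8)  # (b, (x y z c), f, h, w)
#   y = rearrange(x, "b (x y z c) f h w -> b c (f z) (h x) (w y)", x=rs, y=rs, z=rt)
#   assert y.shape == (1, c, 6, 16, 16)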


class Downsample3D(Downsample2D):
    """A 3D downsampling layer with an optional convolution."""

    def __init__(
        self,
        *args,
        inflation_mode: _inflation_mode_t = "tail",
        spatial_down: bool = False,
        temporal_down: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        conv = self.conv
        self.temporal_down = temporal_down
        self.spatial_down = spatial_down
        self.temporal_ratio = 2 if temporal_down else 1
        self.spatial_ratio = 2 if spatial_down else 1
        self.temporal_kernel = 3 if temporal_down else 1
        self.spatial_kernel = 3 if spatial_down else 1
        if type(conv) in [nn.Conv2d, LoRACompatibleConv]:
            # Note: lora_layer is not passed into the constructor in the original implementation,
            # so we make a simplification here.
            conv = init_causal_conv3d(
                self.channels,
                self.out_channels,
                kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel),
                stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
                padding=(
                    1 if self.temporal_down else 0,
                    self.padding if self.spatial_down else 0,
                    self.padding if self.spatial_down else 0,
                ),
                inflation_mode=inflation_mode,
            )
        elif type(conv) is nn.AvgPool2d:
            assert self.channels == self.out_channels
            conv = nn.AvgPool3d(
                kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
                stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
            )
        else:
            raise NotImplementedError
        if self.name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        else:
            self.conv = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        memory_state: MemoryState = MemoryState.DISABLED,
        **kwargs,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels
        if hasattr(self, "norm") and self.norm is not None:
            # [Overridden] change to causal norm.
            hidden_states = causal_norm_wrapper(self.norm, hidden_states)
        if self.use_conv and self.padding == 0 and self.spatial_down:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states, memory_state=memory_state)
        return hidden_states
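

# A hedged note on the asymmetric padding above (plain torch, illustrative
# shapes): with padding == 0 and spatial downsampling, F.pad(x, (0, 1, 0, 1))
# zero-pads only the right edge of the width axis and the bottom edge of the
# height axis, so a stride-2 conv halves odd spatial sizes the same way the
# diffusers 2D Downsample2D does:
#
#   x = torch.randn(1, 8, 4, 9, 9)     # (b, c, f, h, w)
#   x = F.pad(x, (0, 1, 0, 1))         # -> (1, 8, 4, 10, 10)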


class ResnetBlock3D(ResnetBlock2D):
    def __init__(
        self,
        *args,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        slicing: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.conv1 = init_causal_conv3d(
            self.in_channels,
            self.out_channels,
            kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3),
            stride=1,
            padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1),
            inflation_mode=inflation_mode,
        )
        self.conv2 = init_causal_conv3d(
            self.out_channels,
            self.conv2.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )
        if self.up:
            assert type(self.upsample) is Upsample2D
            self.upsample = Upsample3D(
                self.in_channels,
                use_conv=False,
                inflation_mode=inflation_mode,
                slicing=slicing,
            )
        elif self.down:
            assert type(self.downsample) is Downsample2D
            self.downsample = Downsample3D(
                self.in_channels,
                use_conv=False,
                padding=1,
                name="op",
                inflation_mode=inflation_mode,
            )
        if self.use_in_shortcut:
            self.conv_shortcut = init_causal_conv3d(
                self.in_channels,
                self.conv_shortcut.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=(self.conv_shortcut.bias is not None),
                inflation_mode=inflation_mode,
            )

    def forward(
        self, input_tensor, temb, memory_state: MemoryState = MemoryState.DISABLED, **kwargs
    ):
        hidden_states = input_tensor
        hidden_states = causal_norm_wrapper(self.norm1, hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes.
            # see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor, memory_state=memory_state)
            hidden_states = self.upsample(hidden_states, memory_state=memory_state)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor, memory_state=memory_state)
            hidden_states = self.downsample(hidden_states, memory_state=memory_state)
        hidden_states = self.conv1(hidden_states, memory_state=memory_state)
        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            # Add a third singleton axis so temb broadcasts over the (f, h, w)
            # axes of the 5D hidden states (the 2D parent only unsqueezes twice).
            temb = self.time_emb_proj(temb)[:, :, None, None, None]
        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb
        hidden_states = causal_norm_wrapper(self.norm2, hidden_states)
        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, memory_state=memory_state)
        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state)
        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
        return output_tensor
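

# A hedged note on `time_receptive_field` above; the kernel shapes are from the
# constructor, the receptive-field arithmetic is our own reading:
#
#   conv1: kernel (1, 3, 3), padding (0, 1, 1)   # "half": conv1 is spatial-only
#   conv1: kernel (3, 3, 3), padding (1, 1, 1)   # "full": conv1 is temporal too
#
# conv2 always uses a (3, 3, 3) kernel, so a block spans 3 frames with "half"
# and 5 frames with "full" (before causal padding is taken into account).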


class DownEncoderBlock3D(DownEncoderBlock2D):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_padding: int = 1,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_down: bool = True,
        spatial_down: bool = True,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            dropout=dropout,
            num_layers=num_layers,
            resnet_eps=resnet_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_pre_norm=resnet_pre_norm,
            output_scale_factor=output_scale_factor,
            add_downsample=add_downsample,
            downsample_padding=downsample_padding,
        )
        resnets = []
        temporal_modules = []
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                # [Override] Replace module.
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                )
            )
            temporal_modules.append(nn.Identity())
        self.resnets = nn.ModuleList(resnets)
        self.temporal_modules = nn.ModuleList(temporal_modules)
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    # [Override] Replace module.
                    Downsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                        temporal_down=temporal_down,
                        spatial_down=spatial_down,
                        inflation_mode=inflation_mode,
                    )
                ]
            )
        else:
            self.downsamplers = None

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        memory_state: MemoryState = MemoryState.DISABLED,
        **kwargs,
    ) -> torch.FloatTensor:
        for resnet, temporal in zip(self.resnets, self.temporal_modules):
            hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
            hidden_states = temporal(hidden_states)
        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states, memory_state=memory_state)
        return hidden_states


class UpDecoderBlock3D(UpDecoderBlock2D):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        temb_channels: Optional[int] = None,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_up: bool = True,
        spatial_up: bool = True,
        slicing: bool = False,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            dropout=dropout,
            num_layers=num_layers,
            resnet_eps=resnet_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_pre_norm=resnet_pre_norm,
            output_scale_factor=output_scale_factor,
            add_upsample=add_upsample,
            temb_channels=temb_channels,
        )
        resnets = []
        temporal_modules = []
        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels
            resnets.append(
                # [Override] Replace module.
                ResnetBlock3D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                    slicing=slicing,
                )
            )
            temporal_modules.append(nn.Identity())
        self.resnets = nn.ModuleList(resnets)
        self.temporal_modules = nn.ModuleList(temporal_modules)
        if add_upsample:
            # [Override] Replace module & use learnable upsample.
            self.upsamplers = nn.ModuleList(
                [
                    Upsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        temporal_up=temporal_up,
                        spatial_up=spatial_up,
                        interpolate=False,
                        inflation_mode=inflation_mode,
                        slicing=slicing,
                    )
                ]
            )
        else:
            self.upsamplers = None

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        memory_state: MemoryState = MemoryState.DISABLED,
    ) -> torch.FloatTensor:
        for resnet, temporal in zip(self.resnets, self.temporal_modules):
            hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
            hidden_states = temporal(hidden_states)
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, memory_state=memory_state)
        return hidden_states


class UNetMidBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention
        # There is always at least one resnet.
        resnets = [
            # [Override] Replace module.
            ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
        ]
        attentions = []
        if attention_head_dim is None:
            logger.warning(
                f"It is not recommended to pass `attention_head_dim=None`. "
                f"Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
            )
            attention_head_dim = in_channels
        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    Attention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=(
                            resnet_groups if resnet_time_scale_shift == "default" else None
                        ),
                        spatial_norm_dim=(
                            temb_channels if resnet_time_scale_shift == "spatial" else None
                        ),
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                attentions.append(None)
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                )
            )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None, memory_state: MemoryState = MemoryState.DISABLED):
        video_length = hidden_states.size(2)
        hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                # Attention is spatial-only: fold frames into the batch axis.
                hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
                hidden_states = attn(hidden_states, temb=temb)
                hidden_states = rearrange(
                    hidden_states, "(b f) c h w -> b c f h w", f=video_length
                )
            hidden_states = resnet(hidden_states, temb, memory_state=memory_state)
        return hidden_states
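

# A minimal sketch (pure torch + einops, illustrative shapes) of the layout
# shuffle used by UNetMidBlock3D.forward above: frames are folded into the
# batch axis before attending over (h, w) tokens, then unfolded afterwards:
#
#   x = torch.randn(2, 64, 5, 8, 8)                      # (b, c, f, h, w)
#   x = rearrange(x, "b c f h w -> (b f) c h w")          # (10, 64, 8, 8)
#   x = rearrange(x, "(b f) c h w -> b c f h w", f=5)     # back to video layout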


class Encoder3D(nn.Module):
    r"""
    [Override] Overrides most of the logic to support an extra condition input and causal convs.
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent
    representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock3D",)`):
            The types of down blocks to use.
            See `~diffusers.models.unet_2d_blocks.get_down_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use.
            See `~diffusers.models.activations.get_activation` for available options.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        # [Override] add extra_cond_dim, temporal down num
        temporal_down_num: int = 2,
        extra_cond_dim: Optional[int] = None,
        gradient_checkpoint: bool = False,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.temporal_down_num = temporal_down_num
        self.conv_in = init_causal_conv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )
        self.mid_block = None
        self.down_blocks = nn.ModuleList([])
        self.extra_cond_dim = extra_cond_dim
        self.conv_extra_cond = nn.ModuleList([])

        def zero_module(module):
            # Zero out the parameters of a module and return it.
            for p in module.parameters():
                p.detach().zero_()
            return module

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # [Override] to support the temporal down block design.
            # Note: temporal downsampling is applied in the last blocks that still downsample.
            is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1
            assert down_block_type == "DownEncoderBlock3D"
            down_block = DownEncoderBlock3D(
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                # Note: unclear why the reference sets this to 0;
                # Downsample3D pads asymmetrically in its forward instead.
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                temporal_down=is_temporal_down_block,
                spatial_down=True,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
            self.down_blocks.append(down_block)
            self.conv_extra_cond.append(
                zero_module(
                    nn.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0)
                )
                if self.extra_cond_dim is not None and self.extra_cond_dim > 0
                else None
            )
        # mid
        self.mid_block = UNetMidBlock3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )
        # out
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
        )
        self.conv_act = nn.SiLU()
        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = init_causal_conv3d(
            block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode
        )
        self.gradient_checkpointing = gradient_checkpoint

    def forward(
        self,
        sample: torch.FloatTensor,
        extra_cond=None,
        memory_state: MemoryState = MemoryState.DISABLED,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""
        sample = self.conv_in(sample, memory_state=memory_state)
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # down
            # [Override] add extra block and extra cond
            for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond):
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block), sample, memory_state, use_reentrant=False
                )
                if extra_block is not None:
                    sample = sample + F.interpolate(extra_block(extra_cond), size=sample.shape[2:])
            # middle
            sample = self.mid_block(sample, memory_state=memory_state)
            # sample = torch.utils.checkpoint.checkpoint(
            #     create_custom_forward(self.mid_block), sample, use_reentrant=False
            # )
        else:
            # down
            # [Override] add extra block and extra cond
            for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond):
                sample = down_block(sample, memory_state=memory_state)
                if extra_block is not None:
                    sample = sample + F.interpolate(extra_block(extra_cond), size=sample.shape[2:])
            # middle
            sample = self.mid_block(sample, memory_state=memory_state)
        # post-process
        sample = causal_norm_wrapper(self.conv_norm_out, sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, memory_state=memory_state)
        return sample


class Decoder3D(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation
    into an output sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock3D",)`):
            The types of up blocks to use.
            See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use.
            See `~diffusers.models.activations.get_activation` for available options.
        norm_type (`str`, *optional*, defaults to `"group"`):
            The normalization type to use. Can be either `"group"` or `"spatial"`.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention=True,
        # [Override] add temporal up block
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_up_num: int = 2,
        slicing_up_num: int = 0,
        gradient_checkpoint: bool = False,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.temporal_up_num = temporal_up_num
        self.conv_in = init_causal_conv3d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])
        temb_channels = in_channels if norm_type == "spatial" else None
        # mid
        self.mid_block = UNetMidBlock3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )
        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        logger.info(f"slicing_up_num: {slicing_up_num}")
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            is_temporal_up_block = i < self.temporal_up_num
            is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num
            # Note: keep symmetric with the encoder.
            assert up_block_type == "UpDecoderBlock3D"
            up_block = UpDecoderBlock3D(
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=norm_type,
                temb_channels=temb_channels,
                temporal_up=is_temporal_up_block,
                slicing=is_slicing_up_block,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel
        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
            )
        self.conv_act = nn.SiLU()
        self.conv_out = init_causal_conv3d(
            block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode
        )
        self.gradient_checkpointing = gradient_checkpoint

    # Note: copied from the diffusers `Decoder` forward.
    def forward(
        self,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor] = None,
        memory_state: MemoryState = MemoryState.DISABLED,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""
        sample = self.conv_in(sample, memory_state=memory_state)
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = self.mid_block(sample, latent_embeds, memory_state=memory_state)
                sample = sample.to(upscale_dtype)
                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        latent_embeds,
                        memory_state,
                        use_reentrant=False,
                    )
            else:
                # middle
                sample = self.mid_block(sample, latent_embeds, memory_state=memory_state)
                sample = sample.to(upscale_dtype)
                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block), sample, latent_embeds, memory_state
                    )
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds, memory_state=memory_state)
            sample = sample.to(upscale_dtype)
            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds, memory_state=memory_state)
        # post-process
        sample = causal_norm_wrapper(self.conv_norm_out, sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, memory_state=memory_state)
        return sample


class AutoencoderKL(diffusers.AutoencoderKL):
    """
    We simply inherit the model code from diffusers.
    """

    def __init__(self, attention: bool = True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A hacky way to remove attention.
        if not attention:
            self.encoder.mid_block.attentions = torch.nn.ModuleList([None])
            self.decoder.mid_block.attentions = torch.nn.ModuleList([None])

    def load_state_dict(self, state_dict, strict=True):
        # Newer versions of diffusers changed the model keys, causing incompatibility
        # with old checkpoints. They provide a conversion method, which we call before
        # loading the state dict.
        convert_deprecated_attention_blocks = getattr(
            self, "_convert_deprecated_attention_blocks", None
        )
        if callable(convert_deprecated_attention_blocks):
            convert_deprecated_attention_blocks(state_dict)
        return super().load_state_dict(state_dict, strict)


class VideoAutoencoderKL(diffusers.AutoencoderKL):
    """
    We simply inherit the model code from diffusers.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",),
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        force_upcast: bool = True,
        attention: bool = True,
        temporal_scale_num: int = 2,
        slicing_up_num: int = 0,
        gradient_checkpoint: bool = False,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "full",
        slicing_sample_min_size: int = 32,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        *args,
        **kwargs,
    ):
        extra_cond_dim = kwargs.pop("extra_cond_dim", None)
        self.slicing_sample_min_size = slicing_sample_min_size
        self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            # [Override] make sure it can be normally initialized
            down_block_types=tuple(
                down_block_type.replace("3D", "2D") for down_block_type in down_block_types
            ),
            up_block_types=tuple(
                up_block_type.replace("3D", "2D") for up_block_type in up_block_types
            ),
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            latent_channels=latent_channels,
            norm_num_groups=norm_num_groups,
            sample_size=sample_size,
            scaling_factor=scaling_factor,
            force_upcast=force_upcast,
            *args,
            **kwargs,
        )
        # pass init params to Encoder
        self.encoder = Encoder3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
            extra_cond_dim=extra_cond_dim,
            # [Override] add temporal_down_num parameter
            temporal_down_num=temporal_scale_num,
            gradient_checkpoint=gradient_checkpoint,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )
        # pass init params to Decoder
        self.decoder = Decoder3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            # [Override] add temporal_up_num parameter
            temporal_up_num=temporal_scale_num,
            slicing_up_num=slicing_up_num,
            gradient_checkpoint=gradient_checkpoint,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )
        self.quant_conv = (
            init_causal_conv3d(
                in_channels=2 * latent_channels,
                out_channels=2 * latent_channels,
                kernel_size=1,
                inflation_mode=inflation_mode,
            )
            if use_quant_conv
            else None
        )
        self.post_quant_conv = (
            init_causal_conv3d(
                in_channels=latent_channels,
                out_channels=latent_channels,
                kernel_size=1,
                inflation_mode=inflation_mode,
            )
            if use_post_quant_conv
            else None
        )
        # A hacky way to remove attention.
        if not attention:
            self.encoder.mid_block.attentions = torch.nn.ModuleList([None])
            self.decoder.mid_block.attentions = torch.nn.ModuleList([None])

    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.slicing_encode(x)
        posterior = DiagonalGaussianDistribution(h)
        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(
        self, z: torch.Tensor, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.Tensor]:
        decoded = self.slicing_decode(z)
        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def _encode(
        self, x: torch.Tensor, memory_state: MemoryState = MemoryState.DISABLED
    ) -> torch.Tensor:
        _x = x.to(self.device)
        _x = causal_conv_slice_inputs(_x, self.slicing_sample_min_size, memory_state=memory_state)
        h = self.encoder(_x, memory_state=memory_state)
        if self.quant_conv is not None:
            output = self.quant_conv(h, memory_state=memory_state)
        else:
            output = h
        output = causal_conv_gather_outputs(output)
        return output.to(x.device)

    def _decode(
        self, z: torch.Tensor, memory_state: MemoryState = MemoryState.DISABLED
    ) -> torch.Tensor:
        _z = z.to(self.device)
        _z = causal_conv_slice_inputs(_z, self.slicing_latent_min_size, memory_state=memory_state)
        if self.post_quant_conv is not None:
            _z = self.post_quant_conv(_z, memory_state=memory_state)
        output = self.decoder(_z, memory_state=memory_state)
        output = causal_conv_gather_outputs(output)
        return output.to(z.device)

    def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
        sp_size = get_sequence_parallel_world_size()
        if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
            x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2)
            encoded_slices = [
                self._encode(
                    torch.cat((x[:, :, :1], x_slices[0]), dim=2),
                    memory_state=MemoryState.INITIALIZING,
                )
            ]
            for x_idx in range(1, len(x_slices)):
                encoded_slices.append(
                    self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE)
                )
            return torch.cat(encoded_slices, dim=2)
        else:
            return self._encode(x)

    def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
        sp_size = get_sequence_parallel_world_size()
        if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
            z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
            decoded_slices = [
                self._decode(
                    torch.cat((z[:, :, :1], z_slices[0]), dim=2),
                    memory_state=MemoryState.INITIALIZING,
                )
            ]
            for z_idx in range(1, len(z_slices)):
                decoded_slices.append(
                    self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE)
                )
            return torch.cat(decoded_slices, dim=2)
        else:
            return self._decode(z)
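
    # A hedged sketch of the temporal slicing scheme above, using plain tensor
    # ops with illustrative sizes. The first frame stays with the first slice so
    # it can initialize the causal memory (MemoryState.INITIALIZING); later
    # slices are processed with MemoryState.ACTIVE:
    #
    #   x = torch.randn(1, 3, 17, 8, 8)                   # (b, c, t, h, w)
    #   tail = x[:, :, 1:].split(8, dim=2)                # two chunks of 8 frames
    #   first = torch.cat((x[:, :, :1], tail[0]), dim=2)  # 9 frames, incl. frame 0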

    def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def forward(
        self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs
    ):
        # x: [b c t h w]
        if mode == "encode":
            h = self.encode(x)
            return h.latent_dist
        elif mode == "decode":
            h = self.decode(x)
            return h.sample
        else:
            h = self.encode(x)
            h = self.decode(h.latent_dist.mode())
            return h.sample

    def load_state_dict(self, state_dict, strict=False):
        # Newer versions of diffusers changed the model keys, causing incompatibility
        # with old checkpoints. They provide a conversion method, which we call before
        # loading the state dict.
        convert_deprecated_attention_blocks = getattr(
            self, "_convert_deprecated_attention_blocks", None
        )
        if callable(convert_deprecated_attention_blocks):
            convert_deprecated_attention_blocks(state_dict)
        return super().load_state_dict(state_dict, strict)


class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
    def __init__(
        self,
        *args,
        spatial_downsample_factor: int,
        temporal_downsample_factor: int,
        freeze_encoder: bool,
        **kwargs,
    ):
        self.spatial_downsample_factor = spatial_downsample_factor
        self.temporal_downsample_factor = temporal_downsample_factor
        self.freeze_encoder = freeze_encoder
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.FloatTensor) -> CausalAutoencoderOutput:
        with torch.no_grad() if self.freeze_encoder else nullcontext():
            z, p = self.encode(x)
        x = self.decode(z).sample
        return CausalAutoencoderOutput(x, z, p)

    def encode(self, x: torch.FloatTensor) -> CausalEncoderOutput:
        if x.ndim == 4:
            x = x.unsqueeze(2)
        p = super().encode(x).latent_dist
        z = p.sample().squeeze(2)
        return CausalEncoderOutput(z, p)

    def decode(self, z: torch.FloatTensor) -> CausalDecoderOutput:
        if z.ndim == 4:
            z = z.unsqueeze(2)
        x = super().decode(z).sample.squeeze(2)
        return CausalDecoderOutput(x)

    def preprocess(self, x: torch.Tensor):
        # x should be in [B, C, T, H, W] or [B, C, H, W] layout;
        # video inputs must have a temporal length of 4k + 1.
        assert x.ndim == 4 or x.size(2) % 4 == 1
        return x

    def postprocess(self, x: torch.Tensor):
        # x should be in [B, C, T, H, W] or [B, C, H, W] layout.
        return x

    def set_causal_slicing(
        self,
        *,
        split_size: Optional[int],
        memory_device: _memory_device_t,
    ):
        assert (
            split_size is None or memory_device is not None
        ), "if split_size is set, memory_device must not be None."
        if split_size is not None:
            self.enable_slicing()
            self.slicing_sample_min_size = split_size
            self.slicing_latent_min_size = split_size // self.temporal_downsample_factor
        else:
            self.disable_slicing()
        for module in self.modules():
            if isinstance(module, InflatedCausalConv3d):
                module.set_memory_device(memory_device)

    def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
        set_norm_limit(norm_max_mem)
        for m in self.modules():
            if isinstance(m, InflatedCausalConv3d):
                m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
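

# A hedged usage sketch: the constructor arguments and the memory_device value
# below are assumptions chosen for illustration, not taken from a shipped config.
#
#   vae = VideoAutoencoderKLWrapper(
#       spatial_downsample_factor=8,
#       temporal_downsample_factor=4,
#       freeze_encoder=False,
#   )
#   vae.set_causal_slicing(split_size=16, memory_device="cpu")
#   out = vae(torch.randn(1, 3, 17, 64, 64))  # CausalAutoencoderOutput(x, z, p)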