import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from collections import OrderedDict
from utils import dummy_context_mgr
class CLIP_IMG_ENCODER(nn.Module):
    """
    CLIP_IMG_ENCODER module for encoding images with CLIP's visual transformer.
    """
    def __init__(self, CLIP):
        """
        Initialize the CLIP_IMG_ENCODER module.
        Args:
            CLIP (CLIP): Pre-trained CLIP model.
        """
        super(CLIP_IMG_ENCODER, self).__init__()
        model = CLIP.visual
        self.define_module(model)
        # Freeze the parameters of the CLIP visual encoder
        for param in self.parameters():
            param.requires_grad = False

    def define_module(self, model):
        """
        Register the individual layers and modules of the CLIP visual transformer.
        Args:
            model (nn.Module): CLIP visual transformer model.
        """
        # Extract the required modules from the CLIP visual model
        self.conv1 = model.conv1                                 # Patch-embedding convolution
        self.class_embedding = model.class_embedding             # Class embedding
        self.positional_embedding = model.positional_embedding   # Positional embedding
        self.ln_pre = model.ln_pre                               # Layer normalization before the transformer
        self.transformer = model.transformer                     # Transformer blocks
        self.ln_post = model.ln_post                             # Layer normalization after the transformer
        self.proj = model.proj                                   # Projection matrix

    @property
    def dtype(self):
        """
        Data type of the convolutional layer weights.
        """
        return self.conv1.weight.dtype

    def transf_to_CLIP_input(self, inputs):
        """
        Transform input images to the format expected by CLIP.
        Args:
            inputs (torch.Tensor): Input images.
        Returns:
            torch.Tensor: Transformed images.
        """
        device = inputs.device
        # Check the size of the input image tensor
        if len(inputs.size()) != 4:
            raise ValueError('Expect the (B, C, X, Y) tensor.')
        else:
            # Resize to 224x224 and normalize with CLIP's channel-wise statistics
            mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
            var = torch.tensor([0.26862954, 0.26130258, 0.27577711]).unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
            inputs = F.interpolate(inputs * 0.5 + 0.5, size=(224, 224))
            inputs = ((inputs + 1) * 0.5 - mean) / var
            return inputs

    def forward(self, img: torch.Tensor):
        """
        Forward pass of the CLIP_IMG_ENCODER module.
        Args:
            img (torch.Tensor): Input images.
        Returns:
            torch.Tensor: Local features extracted from the image.
            torch.Tensor: Encoded image embeddings.
        """
        # Transform input images to the format expected by CLIP and cast to the encoder's dtype
        x = self.transf_to_CLIP_input(img)
        x = x.type(self.dtype)
        # Pass the image through the patch-embedding convolution
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        grid = x.size(-1)
        # Reshape and permute the tensor for transformer input
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Add class and positional embeddings
        x = torch.cat(
            [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        # NLD (Batch - Length - Dimension) -> LND (Length - Batch - Dimension)
        x = x.permute(1, 0, 2)
        # Extract local features from selected transformer blocks
        selected = [1, 4, 8]
        local_features = []
        for i in range(12):
            x = self.transformer.resblocks[i](x)
            if i in selected:
                local_features.append(
                    x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(-1, 768, grid, grid).contiguous().type(
                        img.dtype))
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = x @ self.proj  # Project the class token into the joint CLIP embedding space
        return torch.stack(local_features, dim=1), x.type(img.dtype)
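# Usage sketch (assumption, not part of this module): encoding a batch of generated images with a
# ViT-B/32 CLIP backbone loaded via OpenAI's `clip` package. With ViT-B/32 the patch grid is 7x7 and
# the visual transformer width is 768, so `local_features` has shape [B, 3, 768, 7, 7] (one map per
# selected block) and `img_emb` has shape [B, 512].
#
#   import clip
#   clip_model, _ = clip.load("ViT-B/32", device="cuda")
#   img_encoder = CLIP_IMG_ENCODER(clip_model).cuda().eval()
#   fake_imgs = torch.randn(4, 3, 224, 224, device="cuda")   # images scaled to [-1, 1]
#   with torch.no_grad():
#       local_features, img_emb = img_encoder(fake_imgs)
#   # local_features: [4, 3, 768, 7, 7], img_emb: [4, 512]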
class CLIP_TXT_ENCODER(nn.Module):
    """
    CLIP_TXT_ENCODER module for encoding text inputs with CLIP's transformer.
    """
    def __init__(self, CLIP):
        """
        Initialize the CLIP_TXT_ENCODER module.
        Args:
            CLIP (CLIP): Pre-trained CLIP model.
        """
        super(CLIP_TXT_ENCODER, self).__init__()
        self.define_module(CLIP)
        # Freeze the parameters of the CLIP model
        for param in self.parameters():
            param.requires_grad = False

    def define_module(self, CLIP):
        """
        Register the individual modules of the CLIP text transformer.
        Args:
            CLIP (CLIP): Pre-trained CLIP model.
        """
        self.transformer = CLIP.transformer                     # Transformer blocks
        self.vocab_size = CLIP.vocab_size                       # Vocabulary size of the tokenizer
        self.token_embedding = CLIP.token_embedding             # Token embedding
        self.positional_embedding = CLIP.positional_embedding   # Positional embedding
        self.ln_final = CLIP.ln_final                           # Final layer normalization
        self.text_projection = CLIP.text_projection             # Projection matrix for text

    @property
    def dtype(self):
        """
        Data type of the first transformer block's weights.
        """
        return self.transformer.resblocks[0].mlp.c_fc.weight.dtype

    def forward(self, text):
        """
        Forward pass of the CLIP_TXT_ENCODER module.
        Args:
            text (torch.Tensor): Input text tokens.
        Returns:
            torch.Tensor: Encoded sentence embeddings.
            torch.Tensor: Transformer output for the input text.
        """
        # Embed input text tokens
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        # Add positional embeddings
        x = x + self.positional_embedding.type(self.dtype)
        # Permute dimensions for transformer input
        x = x.permute(1, 0, 2)  # NLD -> LND
        # Pass the input through the transformer
        x = self.transformer(x)
        # Permute dimensions back to the original shape
        x = x.permute(1, 0, 2)  # LND -> NLD
        # Apply the final layer normalization
        x = self.ln_final(x).type(self.dtype)  # shape = [batch_size, n_ctx, transformer.width]
        # Take the features at the end-of-text (EOT) token, which has the highest token id in each sequence
        sent_emb = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        # Return the sentence embedding and the transformer output
        return sent_emb, x
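# Usage sketch (assumption): encoding tokenized captions with the same CLIP backbone as above. For
# ViT-B/32, `sent_emb` is the projected end-of-text embedding of shape [B, 512] and `word_emb` keeps
# the per-token transformer features of shape [B, 77, 512].
#
#   import clip
#   clip_model, _ = clip.load("ViT-B/32", device="cuda")
#   txt_encoder = CLIP_TXT_ENCODER(clip_model).cuda().eval()
#   tokens = clip.tokenize(["a small bird with a red head"]).cuda()
#   with torch.no_grad():
#       sent_emb, word_emb = txt_encoder(tokens)
#   # sent_emb: [1, 512], word_emb: [1, 77, 512]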
class CLIP_Mapper(nn.Module):
    """
    CLIP_Mapper module for mapping image features, modulated by prompt tokens, through CLIP's visual transformer.
    """
    def __init__(self, CLIP):
        """
        Initialize the CLIP_Mapper module.
        Args:
            CLIP (CLIP): Pre-trained CLIP model.
        """
        super(CLIP_Mapper, self).__init__()
        model = CLIP.visual
        self.define_module(model)
        # Freeze the parameters of the CLIP visual model
        for param in model.parameters():
            param.requires_grad = False

    def define_module(self, model):
        """
        Register the individual modules of the CLIP visual model.
        Args:
            model: Pre-trained CLIP visual model.
        """
        self.conv1 = model.conv1
        self.class_embedding = model.class_embedding
        self.positional_embedding = model.positional_embedding
        self.ln_pre = model.ln_pre
        self.transformer = model.transformer

    @property
    def dtype(self):
        """
        Data type of the first convolutional layer's weights.
        """
        return self.conv1.weight.dtype

    def forward(self, img: torch.Tensor, prompts: torch.Tensor):
        """
        Forward pass of the CLIP_Mapper module.
        Args:
            img (torch.Tensor): Input feature map of shape [batch, width, grid, grid].
            prompts (torch.Tensor): Prompt tokens for mapping.
        Returns:
            torch.Tensor: Mapped features from the CLIP model.
        """
        # Convert the input features and prompts to the appropriate data type
        x = img.type(self.dtype)
        prompts = prompts.type(self.dtype)
        grid = x.size(-1)
        # Flatten the spatial dimensions of the input feature map
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Prepend the class embedding to the input tokens
        x = torch.cat(
            [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
             x],
            dim=1
        )  # shape = [*, grid ** 2 + 1, width]
        # Add the positional embeddings
        x = x + self.positional_embedding.to(x.dtype)
        # Apply layer normalization
        x = self.ln_pre(x)
        # NLD -> LND
        x = x.permute(1, 0, 2)
        # Inject one prompt token into each selected transformer block
        selected = [1, 2, 3, 4, 5, 6, 7, 8]
        begin, end = 0, 12
        prompt_idx = 0
        for i in range(begin, end):
            if i in selected:
                # Append the prompt token, run the block, then drop the prompt token again
                prompt = prompts[:, prompt_idx, :].unsqueeze(0)
                prompt_idx = prompt_idx + 1
                x = torch.cat((x, prompt), dim=0)
                x = self.transformer.resblocks[i](x)
                x = x[:-1, :, :]
            else:
                x = self.transformer.resblocks[i](x)
        # Reshape the patch tokens back to a spatial feature map and return
        return x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(-1, 768, grid, grid).contiguous().type(img.dtype)
class CLIP_Adapter(nn.Module):
    """
    CLIP_Adapter module for adapting features from the generator to match the CLIP model's input space.
    """
    def __init__(self, in_ch, mid_ch, out_ch, G_ch, CLIP_ch, cond_dim, k, s, p, map_num, CLIP):
        """
        Initialize the CLIP_Adapter module.
        Args:
            in_ch (int): Number of input channels.
            mid_ch (int): Number of channels in the intermediate layers.
            out_ch (int): Number of output channels.
            G_ch (int): Number of channels in the generator's output.
            CLIP_ch (int): Number of channels in the CLIP model's input.
            cond_dim (int): Dimension of the conditioning vector.
            k (int): Kernel size for convolutional layers.
            s (int): Stride for convolutional layers.
            p (int): Padding for convolutional layers.
            map_num (int): Number of mapping blocks.
            CLIP: Pre-trained CLIP model.
        """
        super(CLIP_Adapter, self).__init__()
        self.CLIP_ch = CLIP_ch
        self.FBlocks = nn.ModuleList([])
        # Define the mapping blocks (M_Block) and add them to the feature-block list (FBlocks)
        self.FBlocks.append(M_Block(in_ch, mid_ch, out_ch, cond_dim, k, s, p))
        for i in range(map_num - 1):
            self.FBlocks.append(M_Block(out_ch, mid_ch, out_ch, cond_dim, k, s, p))
        # Convolutional layer to fuse the adapted features
        self.conv_fuse = nn.Conv2d(out_ch, CLIP_ch, 5, 1, 2)
        # CLIP_Mapper module to map the adapted features through CLIP's visual transformer
        self.CLIP_ViT = CLIP_Mapper(CLIP)
        # Convolutional layer to further process the mapped features
        self.conv = nn.Conv2d(768, G_ch, 5, 1, 2)
        # Fully connected layer that turns the conditioning vector into prompt tokens
        self.fc_prompt = nn.Linear(cond_dim, CLIP_ch * 8)

    def forward(self, out, c):
        """
        Forward pass of the CLIP_Adapter module. Takes the generator features and the conditioning vector,
        adapts the features with the mapping blocks, fuses them, maps them through CLIP's visual transformer,
        and returns the processed features.
        Args:
            out (torch.Tensor): Output features from the generator.
            c (torch.Tensor): Conditioning vector.
        Returns:
            torch.Tensor: Adapted and mapped features for the generator.
        """
        # Generate prompt tokens from the conditioning vector
        prompts = self.fc_prompt(c).view(c.size(0), -1, self.CLIP_ch)
        # Pass the features through the mapping blocks
        for FBlock in self.FBlocks:
            out = FBlock(out, c)
        # Fuse the adapted features
        fuse_feat = self.conv_fuse(out)
        # Map the fused features through CLIP's visual transformer
        map_feat = self.CLIP_ViT(fuse_feat, prompts)
        # Combine fused and mapped features, project to the generator channel width, and return
        return self.conv(fuse_feat + 0.1 * map_feat)
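# Shape sketch (assumption, mirroring how NetG wires the adapter below): the generator's 7x7 noise
# code is refined by `map_num` M_Blocks, lifted to CLIP's 768-channel token width, routed through the
# frozen ViT with 8 learned prompt tokens, and projected to the generator width (ngf * 8). The
# hyperparameter values here are illustrative.
#
#   adapter = CLIP_Adapter(in_ch=64, mid_ch=32, out_ch=64, G_ch=512, CLIP_ch=768,
#                          cond_dim=612, k=3, s=1, p=1, map_num=4, CLIP=clip_model).cuda()
#   code = torch.randn(4, 64, 7, 7, device="cuda")   # reshaped output of NetG.fc_code
#   cond = torch.randn(4, 612, device="cuda")        # [noise, sent_emb] concatenation (nz=100 + cond_dim=512)
#   feat = adapter(code, cond)                       # -> [4, 512, 7, 7]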
class NetG(nn.Module):
    """
    Generator network for synthesizing images conditioned on text and noise.
    """
    def __init__(self, ngf, nz, cond_dim, imsize, ch_size, mixed_precision, CLIP):
        """
        Initialize the generator network.
        Args:
            ngf (int): Number of generator filters.
            nz (int): Dimensionality of the input noise vector.
            cond_dim (int): Dimensionality of the conditioning vector.
            imsize (int): Size of the generated images.
            ch_size (int): Number of output channels for the generated images.
            mixed_precision (bool): Whether to use mixed precision training.
            CLIP: CLIP model for feature adaptation.
        """
        super(NetG, self).__init__()
        # Define attributes
        self.ngf = ngf
        self.mixed_precision = mixed_precision
        # Build the CLIP mapper
        self.code_sz, self.code_ch, self.mid_ch = 7, 64, 32
        self.CLIP_ch = 768
        # Fully connected layer that turns the noise vector into a (code_ch, code_sz, code_sz) feature map
        self.fc_code = nn.Linear(nz, self.code_sz * self.code_sz * self.code_ch)
        self.mapping = CLIP_Adapter(self.code_ch, self.mid_ch, self.code_ch, ngf * 8, self.CLIP_ch, cond_dim + nz, 3, 1,
                                    1, 4, CLIP)
        # Build the GBlocks
        self.GBlocks = nn.ModuleList([])
        in_out_pairs = list(get_G_in_out_chs(ngf, imsize))
        imsize = 4
        for idx, (in_ch, out_ch) in enumerate(in_out_pairs):
            if idx < (len(in_out_pairs) - 1):
                imsize = imsize * 2
            else:
                imsize = 224
            self.GBlocks.append(G_Block(cond_dim + nz, in_ch, out_ch, imsize))
        # Convert features to an RGB image with a LeakyReLU activation followed by a convolution
        self.to_rgb = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_ch, ch_size, 3, 1, 1),
        )

    def forward(self, noise, c, eval=False):  # noise: latent vector, c: sentence embedding
        """
        Forward pass of the generator network.
        Args:
            noise (torch.Tensor): Input noise vector.
            c (torch.Tensor): Conditioning information, typically a sentence embedding describing the output.
            eval (bool, optional): Flag indicating whether the network is in evaluation mode. Defaults to False.
        Returns:
            torch.Tensor: Generated RGB images.
        """
        # Context manager enabling automatic mixed precision when requested
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else dummy_context_mgr() as mp:
            # Concatenate noise and conditioning information
            cond = torch.cat((noise, c), dim=1)
            # Map the noise to a feature map and adapt it with the CLIP adapter
            out = self.mapping(self.fc_code(noise).view(noise.size(0), self.code_ch, self.code_sz, self.code_sz), cond)
            # Apply the GBlocks to progressively upsample and fuse text and visual features
            for GBlock in self.GBlocks:
                out = GBlock(out, cond)
            # Convert the final feature representation to RGB images
            out = self.to_rgb(out)
        return out
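# Usage sketch (assumption): synthesizing 224x224 images from noise and a CLIP sentence embedding.
# The hyperparameters mirror the widths hard-coded above (7x7 code, 768 CLIP channels); ngf, nz and
# cond_dim are illustrative values, not mandated by this file. `clip_model` and `sent_emb` come from
# the earlier sketches; `.float()` matches the fp32 generator weights.
#
#   netG = NetG(ngf=64, nz=100, cond_dim=512, imsize=224, ch_size=3,
#               mixed_precision=False, CLIP=clip_model).cuda()
#   noise = torch.randn(4, 100, device="cuda")
#   with torch.no_grad():
#       fake_imgs = netG(noise, sent_emb.repeat(4, 1).float(), eval=True)
#   # fake_imgs: [4, 3, 224, 224]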
class NetD(nn.Module):
    """
    Discriminator network for evaluating the realism of images.
    Attributes:
        DBlocks (nn.ModuleList): List of D_Block modules for processing feature maps.
        main (D_Block): Main D_Block module for final processing.
    """
    def __init__(self, ndf, imsize, ch_size, mixed_precision):
        """
        Initialize the discriminator network.
        Args:
            ndf (int): Number of channels in the initial features.
            imsize (int): Size of the input images (assumed square).
            ch_size (int): Number of channels in the output feature maps.
            mixed_precision (bool): Flag indicating whether to use mixed precision training.
        """
        super(NetD, self).__init__()
        self.mixed_precision = mixed_precision
        # Define the DBlocks
        self.DBlocks = nn.ModuleList([
            D_Block(768, 768, 3, 1, 1, res=True, CLIP_feat=True),
            D_Block(768, 768, 3, 1, 1, res=True, CLIP_feat=True),
        ])
        # Define the main DBlock for the final processing
        self.main = D_Block(768, 512, 3, 1, 1, res=True, CLIP_feat=False)

    def forward(self, h):
        """
        Forward pass of the discriminator network.
        Args:
            h (torch.Tensor): Stacked CLIP feature maps of shape [batch, num_maps, 768, grid, grid].
        Returns:
            torch.Tensor: Discriminator output.
        """
        with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc:
            # Initial feature map
            out = h[:, 0]
            # Pass the features through each DBlock, injecting the remaining CLIP feature maps
            for idx in range(len(self.DBlocks)):
                out = self.DBlocks[idx](out, h[:, idx + 1])
            # Final processing through the main DBlock
            out = self.main(out)
            return out
class NetC(nn.Module):
    """
    Classifier/comparator network that scores the joint features of an image and its text condition.
    Attributes:
        cond_dim (int): Dimensionality of the conditioning information.
        mixed_precision (bool): Flag indicating whether to use mixed precision training.
        joint_conv (nn.Sequential): Sequential module defining the classifier layers.
    """
    def __init__(self, ndf, cond_dim, mixed_precision):
        """
        Initialize the classifier network.
        Args:
            ndf (int): Number of channels in the input features.
            cond_dim (int): Dimensionality of the conditioning information.
            mixed_precision (bool): Flag indicating whether to use mixed precision training.
        """
        super(NetC, self).__init__()
        self.cond_dim = cond_dim
        self.mixed_precision = mixed_precision
        # Define the classifier layers: 2D convolutions with a LeakyReLU activation in between
        self.joint_conv = nn.Sequential(
            nn.Conv2d(512 + 512, 128, 4, 1, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, 4, 1, 0, bias=False),
        )

    def forward(self, out, cond):
        """
        Forward pass of the classifier network.
        Args:
            out (torch.Tensor): Image feature map (the output of NetD).
            cond (torch.Tensor): Conditioning information vector.
        Returns:
            torch.Tensor: Classification score map.
        """
        with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc:
            # Reshape and tile the conditioning vector to match the feature map size
            cond = cond.view(-1, self.cond_dim, 1, 1)
            cond = cond.repeat(1, 1, 7, 7)
            # Concatenate the feature map and the conditioning information
            h_c_code = torch.cat((out, cond), 1)
            # Pass through the classifier layers
            out = self.joint_conv(h_c_code)
            return out
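# Usage sketch (assumption): scoring an image-text pair by feeding the stacked CLIP local features
# into NetD and the resulting 512-channel map plus the sentence embedding into NetC. `ndf` is an
# illustrative value; the CLIP widths (768/512) are fixed inside NetD/NetC. `img_encoder`, `fake_imgs`
# and `sent_emb` come from the earlier sketches.
#
#   netD = NetD(ndf=64, imsize=224, ch_size=3, mixed_precision=False).cuda()
#   netC = NetC(ndf=64, cond_dim=512, mixed_precision=False).cuda()
#   local_features, _ = img_encoder(fake_imgs)                         # [B, 3, 768, 7, 7]
#   feat = netD(local_features.float())                                # [B, 512, 7, 7]
#   score = netC(feat, sent_emb.repeat(fake_imgs.size(0), 1).float())  # [B, 1, 1, 1]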
class M_Block(nn.Module):
    """
    Multi-scale mapping block consisting of convolutional layers with conditioning.
    Attributes:
        conv1 (nn.Conv2d): First convolutional layer.
        fuse1 (DFBLK): Conditioning block after the first convolutional layer.
        conv2 (nn.Conv2d): Second convolutional layer.
        fuse2 (DFBLK): Conditioning block after the second convolutional layer.
        learnable_sc (bool): Flag indicating whether the shortcut connection is learnable.
        c_sc (nn.Conv2d): Convolutional layer for the shortcut connection.
    """
    def __init__(self, in_ch, mid_ch, out_ch, cond_dim, k, s, p):
        """
        Initializes the multi-scale block.
        Args:
            in_ch (int): Number of input channels.
            mid_ch (int): Number of channels in the intermediate layers.
            out_ch (int): Number of output channels.
            cond_dim (int): Dimensionality of the conditioning information.
            k (int): Kernel size for convolutional layers.
            s (int): Stride for convolutional layers.
            p (int): Padding for convolutional layers.
        """
        super(M_Block, self).__init__()
        # Define convolutional layers and conditioning blocks
        self.conv1 = nn.Conv2d(in_ch, mid_ch, k, s, p)
        self.fuse1 = DFBLK(cond_dim, mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, k, s, p)
        self.fuse2 = DFBLK(cond_dim, out_ch)
        # Learnable shortcut connection
        self.learnable_sc = in_ch != out_ch
        if self.learnable_sc:
            self.c_sc = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)

    def shortcut(self, x):
        """
        Defines the shortcut connection.
        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            torch.Tensor: Shortcut connection output.
        """
        if self.learnable_sc:
            x = self.c_sc(x)
        return x

    def residual(self, h, text):
        """
        Defines the residual path with conditioning.
        Args:
            h (torch.Tensor): Input tensor.
            text (torch.Tensor): Conditioning information.
        Returns:
            torch.Tensor: Residual path output.
        """
        h = self.conv1(h)
        h = self.fuse1(h, text)
        h = self.conv2(h)
        h = self.fuse2(h, text)
        return h

    def forward(self, h, c):
        """
        Forward pass of the multi-scale block.
        Args:
            h (torch.Tensor): Input tensor.
            c (torch.Tensor): Conditioning information.
        Returns:
            torch.Tensor: Output tensor.
        """
        return self.shortcut(h) + self.residual(h, c)
class G_Block(nn.Module):
    """
    Generator block consisting of convolutional layers and conditioning.
    Attributes:
        imsize (int): Size of the output image.
        learnable_sc (bool): Flag indicating whether the shortcut connection is learnable.
        c1 (nn.Conv2d): First convolutional layer.
        c2 (nn.Conv2d): Second convolutional layer.
        fuse1 (DFBLK): Conditioning block for the first convolutional layer.
        fuse2 (DFBLK): Conditioning block for the second convolutional layer.
        c_sc (nn.Conv2d): Convolutional layer for the shortcut connection.
    """
    def __init__(self, cond_dim, in_ch, out_ch, imsize):
        """
        Initialize the generator block.
        Args:
            cond_dim (int): Dimensionality of the conditioning information.
            in_ch (int): Number of input channels.
            out_ch (int): Number of output channels.
            imsize (int): Size of the output image.
        """
        super(G_Block, self).__init__()
        # Initialize attributes
        self.imsize = imsize
        self.learnable_sc = in_ch != out_ch
        # Define convolutional layers and conditioning blocks
        self.c1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.c2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1)
        self.fuse1 = DFBLK(cond_dim, in_ch)
        self.fuse2 = DFBLK(cond_dim, out_ch)
        # Learnable shortcut connection
        if self.learnable_sc:
            self.c_sc = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)

    def shortcut(self, x):
        """
        Defines the shortcut connection.
        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            torch.Tensor: Shortcut connection output.
        """
        if self.learnable_sc:
            x = self.c_sc(x)
        return x

    def residual(self, h, y):
        """
        Defines the residual path with conditioning.
        Args:
            h (torch.Tensor): Input tensor.
            y (torch.Tensor): Conditioning information.
        Returns:
            torch.Tensor: Residual path output.
        """
        h = self.fuse1(h, y)
        h = self.c1(h)
        h = self.fuse2(h, y)
        h = self.c2(h)
        return h

    def forward(self, h, y):
        """
        Forward pass of the generator block.
        Args:
            h (torch.Tensor): Input tensor.
            y (torch.Tensor): Conditioning information.
        Returns:
            torch.Tensor: Output tensor.
        """
        h = F.interpolate(h, size=(self.imsize, self.imsize))
        return self.shortcut(h) + self.residual(h, y)
class D_Block(nn.Module):
    """
    Discriminator block.
    """
    def __init__(self, fin, fout, k, s, p, res, CLIP_feat):
        """
        Initializes the Discriminator block.
        Args:
            - fin (int): Number of input channels.
            - fout (int): Number of output channels.
            - k (int): Kernel size for convolutional layers.
            - s (int): Stride for convolutional layers.
            - p (int): Padding for convolutional layers.
            - res (bool): Whether to use a residual connection.
            - CLIP_feat (bool): Whether to incorporate CLIP features.
        """
        super(D_Block, self).__init__()
        self.res, self.CLIP_feat = res, CLIP_feat
        self.learned_shortcut = (fin != fout)
        # Convolutional layers for the residual path
        self.conv_r = nn.Sequential(
            nn.Conv2d(fin, fout, k, s, p, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(fout, fout, k, s, p, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Convolutional layer for the shortcut connection
        self.conv_s = nn.Conv2d(fin, fout, 1, stride=1, padding=0)
        # Learnable weights for the residual and CLIP feature contributions
        if self.res == True:
            self.gamma = nn.Parameter(torch.zeros(1))
        if self.CLIP_feat == True:
            self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x, CLIP_feat=None):
        """
        Forward pass of the discriminator block.
        Args:
            - x (torch.Tensor): Input tensor.
            - CLIP_feat (torch.Tensor): Optional CLIP features tensor.
        Returns:
            - torch.Tensor: Output tensor.
        """
        # Compute the residual features
        res = self.conv_r(x)
        # Compute the shortcut connection
        if self.learned_shortcut:
            x = self.conv_s(x)
        # Incorporate the learned residual and CLIP features if enabled
        if (self.res == True) and (self.CLIP_feat == True):
            return x + self.gamma * res + self.beta * CLIP_feat
        elif (self.res == True) and (self.CLIP_feat != True):
            return x + self.gamma * res
        elif (self.res != True) and (self.CLIP_feat == True):
            return x + self.beta * CLIP_feat
        else:
            return x
class DFBLK(nn.Module):
    """
    Conditional feature block that fuses conditioning information into feature maps
    through two conditional affine transformations.
    """
    def __init__(self, cond_dim, in_ch):
        """
        Initializes the conditional feature block.
        Args:
            - cond_dim (int): Dimensionality of the conditional input.
            - in_ch (int): Number of input channels.
        """
        super(DFBLK, self).__init__()
        # Define conditional affine transformations
        self.affine0 = Affine(cond_dim, in_ch)
        self.affine1 = Affine(cond_dim, in_ch)

    def forward(self, x, y=None):
        """
        Forward pass of the conditional feature block.
        Args:
            - x (torch.Tensor): Input tensor.
            - y (torch.Tensor, optional): Conditional input tensor. Default is None.
        Returns:
            - torch.Tensor: Output tensor.
        """
        # Apply the first affine transformation and activation function
        h = self.affine0(x, y)
        h = nn.LeakyReLU(0.2, inplace=True)(h)
        # Apply the second affine transformation and activation function
        h = self.affine1(h, y)
        h = nn.LeakyReLU(0.2, inplace=True)(h)
        return h
class QuickGELU(nn.Module):
    """
    Fast approximation of the GELU activation, as used in CLIP.
    """
    def forward(self, x: torch.Tensor):
        """
        Forward pass of the QuickGELU activation function.
        Args:
            - x (torch.Tensor): Input tensor.
        Returns:
            - torch.Tensor: Output tensor.
        """
        # Sigmoid-based approximation of GELU
        return x * torch.sigmoid(1.702 * x)
# Taken from the RAT-GAN repository
class Affine(nn.Module):
    """
    Affine transformation module that applies conditional scaling and shifting to input features,
    giving the network additional control over the generated output based on the input condition.
    """
    def __init__(self, cond_dim, num_features):
        """
        Initialize the affine transformation module.
        Args:
            cond_dim (int): Dimensionality of the conditioning information.
            num_features (int): Number of input features.
        """
        super(Affine, self).__init__()
        # Define two fully connected networks to compute the gamma (scale) and beta (shift) parameters,
        # each with two linear layers and a ReLU activation in between
        self.fc_gamma = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(cond_dim, num_features)),
            ('relu1', nn.ReLU(inplace=True)),
            ('linear2', nn.Linear(num_features, num_features)),
        ]))
        self.fc_beta = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(cond_dim, num_features)),
            ('relu1', nn.ReLU(inplace=True)),
            ('linear2', nn.Linear(num_features, num_features)),
        ]))
        # Initialize the weights and biases of the networks
        self._initialize()

    def _initialize(self):
        """
        Initializes the final linear layers that compute gamma and beta so that the transformation
        starts as the identity (gamma = 1, beta = 0).
        """
        nn.init.zeros_(self.fc_gamma.linear2.weight.data)
        nn.init.ones_(self.fc_gamma.linear2.bias.data)
        nn.init.zeros_(self.fc_beta.linear2.weight.data)
        nn.init.zeros_(self.fc_beta.linear2.bias.data)

    def forward(self, x, y=None):
        """
        Forward pass of the affine transformation module.
        Args:
            x (torch.Tensor): Input tensor.
            y (torch.Tensor, optional): Conditioning information tensor. Default is None.
        Returns:
            torch.Tensor: Transformed tensor after applying the affine transformation.
        """
        # Compute the gamma and beta parameters
        weight = self.fc_gamma(y)
        bias = self.fc_beta(y)
        # Ensure proper shape for the weight and bias tensors
        if weight.dim() == 1:
            weight = weight.unsqueeze(0)
        if bias.dim() == 1:
            bias = bias.unsqueeze(0)
        # Expand the weight and bias tensors to match the input tensor shape
        size = x.size()
        weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
        bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
        # Apply the affine transformation
        return weight * x + bias
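# Minimal sketch (assumption) of the conditional modulation performed by Affine/DFBLK: the condition
# vector y is mapped to a per-channel scale (gamma) and shift (beta), which are broadcast over the
# spatial dimensions. At initialization gamma = 1 and beta = 0, so the block starts as an identity.
#
#   affine = Affine(cond_dim=612, num_features=256)
#   x = torch.randn(4, 256, 16, 16)
#   y = torch.randn(4, 612)
#   out = affine(x, y)              # gamma(y)[:, :, None, None] * x + beta(y)[:, :, None, None]
#   assert torch.allclose(out, x)   # holds right after initialization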
def get_G_in_out_chs(nf, imsize):
    """
    Compute input-output channel pairs for the generator blocks from the base channel count and image size.
    Args:
        nf (int): Base number of feature channels.
        imsize (int): Size of the generated image.
    Returns:
        zip: Iterator of (in_channels, out_channels) pairs for the generator blocks.
    """
    # Determine the number of layers based on the image size
    layer_num = int(np.log2(imsize)) - 1
    # Compute the number of channels for each layer (capped at 8 * nf)
    channel_nums = [nf * min(2 ** idx, 8) for idx in range(layer_num)]
    # Reverse the channel numbers to start with the highest channel count
    channel_nums = channel_nums[::-1]
    # Generate input-output channel pairs for the generator blocks
    in_out_pairs = zip(channel_nums[:-1], channel_nums[1:])
    return in_out_pairs
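# Worked example (nf = 64 is an illustrative value): for imsize = 224, layer_num = int(log2(224)) - 1 = 6,
# so the reversed channel list is [512, 512, 512, 256, 128, 64] and the generator gets five blocks:
#
#   list(get_G_in_out_chs(64, 224))
#   # -> [(512, 512), (512, 512), (512, 256), (256, 128), (128, 64)]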
def get_D_in_out_chs(nf, imsize):
    """
    Compute input-output channel pairs for the discriminator blocks from the base channel count and image size.
    Args:
        nf (int): Base number of feature channels.
        imsize (int): Size of the input image.
    Returns:
        zip: Iterator of (in_channels, out_channels) pairs for the discriminator blocks.
    """
    # Determine the number of layers based on the image size
    layer_num = int(np.log2(imsize)) - 1
    # Compute the number of channels for each layer (capped at 8 * nf)
    channel_nums = [nf * min(2 ** idx, 8) for idx in range(layer_num)]
    # Generate input-output channel pairs for the discriminator blocks
    in_out_pairs = zip(channel_nums[:-1], channel_nums[1:])
    return in_out_pairs