Spaces:
Runtime error
Runtime error
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch import Tensor
| # __all__ = [ | |
| # "ResidualConvBlock", | |
| # "Discriminator", "Generator", | |
| # ] | |
class ResidualConvBlock(nn.Module):
    """Residual block computing ``x + BN(conv(PReLU(BN(conv(x)))))``.

    Both 3x3 convolutions use stride 1 and padding 1 and keep the channel
    count, so the branch output has the same shape as the input and can be
    summed with it directly.

    Args:
        channels (int): Number of feature channels (input and output).
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        conv_args = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        # Convs carry no bias because each one is immediately followed by BatchNorm.
        self.rcb = nn.Sequential(
            nn.Conv2d(channels, channels, **conv_args),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, **conv_args),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        # Residual branch first, then add the untouched input (skip connection).
        return self.rcb(x) + x
class Discriminator(nn.Module):
    """SRGAN-style discriminator producing one unnormalized score per image.

    The feature extractor is a 3x3 conv + LeakyReLU, followed by seven
    conv -> BatchNorm -> LeakyReLU stages. Four of those stages use stride 2,
    halving the spatial size each time (96 -> 48 -> 24 -> 12 -> 6 for a
    96 x 96 input). The classifier's fixed input width (512 * 6 * 6) ties
    the model to inputs that reduce to 6 x 6, e.g. 3 x 96 x 96 images.
    No sigmoid is applied to the output.
    """

    def __init__(self) -> None:
        super().__init__()
        layers = [
            # input size. (3) x 96 x 96
            nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=False),
            nn.LeakyReLU(0.2, True),
        ]
        # (in_channels, out_channels, stride) for each conv -> BN -> LeakyReLU stage.
        stage_specs = (
            (64, 64, 2),    # -> (64) x 48 x 48
            (64, 128, 1),
            (128, 128, 2),  # -> (128) x 24 x 24
            (128, 256, 1),
            (256, 256, 2),  # -> (256) x 12 x 12
            (256, 512, 1),
            (512, 512, 2),  # -> (512) x 6 x 6
        )
        for in_ch, out_ch, step in stage_specs:
            layers.extend([
                nn.Conv2d(in_ch, out_ch, (3, 3), (step, step), (1, 1), bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, True),
            ])
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        # Flatten everything except the batch dimension before the MLP head.
        flat = self.features(x).flatten(1)
        return self.classifier(flat)
class Generator(nn.Module):
    """SRGAN-style generator that upsamples a 3-channel image by a factor of 4.

    Pipeline:
      1. 9x9 conv (3 -> 64 channels) + PReLU.
      2. 16 residual blocks, then 3x3 conv + BatchNorm.
      3. The result of step 2 is added to the output of step 1 (global skip).
      4. Two (conv 64 -> 256, PixelShuffle(2), PReLU) stages, each doubling
         height and width.
      5. 9x9 conv back to 3 channels.

    ``forward`` optionally applies channel dropout to the output of step 2
    (see ``dop``).
    """

    def __init__(self) -> None:
        super().__init__()
        # Shallow features: 3 -> 64 channels, spatial size preserved.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 64, (9, 9), (1, 1), (4, 4)),
            nn.PReLU(),
        )
        # Deep features: a stack of 16 residual blocks at 64 channels.
        self.trunk = nn.Sequential(*(ResidualConvBlock(64) for _ in range(16)))
        # Projection of the trunk output before the global skip addition.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(64),
        )
        # Two x2 sub-pixel stages: the conv expands to 256 = 64 * 2 * 2
        # channels, and PixelShuffle(2) folds them back into 64 channels at
        # twice the height and width.
        upsample_layers = []
        for _ in range(2):
            upsample_layers.extend([
                nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
                nn.PixelShuffle(2),
                nn.PReLU(),
            ])
        self.upsampling = nn.Sequential(*upsample_layers)
        # Reconstruction: 64 -> 3 channels.
        self.conv_block3 = nn.Conv2d(64, 3, (9, 9), (1, 1), (4, 4))
        self._initialize_weights()

    def forward(self, x: Tensor, dop=None) -> Tensor:
        """Run the generator.

        Args:
            x: Input image batch (3 channels).
            dop: Channel-dropout probability applied to the trunk output.
                Any falsy value (``None``, ``0``) disables dropout.
        """
        if dop:
            return self._forward_w_dop_impl(x, dop)
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        shallow = self.conv_block1(x)
        deep = self.conv_block2(self.trunk(shallow))
        return self.conv_block3(self.upsampling(shallow + deep))

    def _forward_w_dop_impl(self, x: Tensor, dop) -> Tensor:
        shallow = self.conv_block1(x)
        # F.dropout2d defaults to training=True, so channels are dropped even
        # when the module is in eval() mode (NOTE(review): presumably
        # intentional, for Monte-Carlo dropout at inference; confirm).
        deep = F.dropout2d(self.conv_block2(self.trunk(shallow)), p=dop)
        return self.conv_block3(self.upsampling(shallow + deep))

    def _initialize_weights(self) -> None:
        # Kaiming-normal conv weights with zero bias; BatchNorm scale set to 1.
        # BatchNorm bias and PReLU slopes keep their constructor defaults.
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.ones_(layer.weight)
| #### BayesCap | |
class BayesCap(nn.Module):
    """Network predicting three per-pixel maps ``(mu, alpha, beta)`` from an image.

    Shared feature extractor:
      * 9x9 conv (``in_channels`` -> 64) + PReLU;
      * 16 residual blocks, then 3x3 conv + BatchNorm;
      * global skip: the result is added to the first conv's output.

    Heads (all convs use stride 1 with "same" padding, so every output keeps
    the input's height and width):
      * ``mu``:    one 9x9 conv to ``out_channels`` channels, no activation;
      * ``alpha``: 64 -> 64 -> 64 -> 1 channels, ending in ReLU (>= 0);
      * ``beta``:  same architecture as ``alpha``, with separate weights.

    NOTE(review): the names suggest generalized-Gaussian mean/scale/shape
    parameters; that interpretation lives in the training code, not here.
    ReLU can output exact zeros, so any consumer that divides by alpha or
    beta should guard against that.

    Args:
        in_channels (int): Channels of the input image. Default: 3.
        out_channels (int): Channels of the ``mu`` output. Default: 3.
    """

    def __init__(self, in_channels=3, out_channels=3) -> None:
        super(BayesCap, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
        )
        # Features trunk blocks.
        self.trunk = nn.Sequential(*[ResidualConvBlock(64) for _ in range(16)])
        # Second conv layer (no bias: followed by BatchNorm).
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
        )
        # Output heads. Construction order (mu, alpha, beta) is kept so that
        # parameter registration and RNG consumption match earlier versions.
        self.conv_block3_mu = nn.Conv2d(
            64, out_channels=out_channels, kernel_size=9, stride=1, padding=4
        )
        self.conv_block3_alpha = self._make_scalar_head()
        self.conv_block3_beta = self._make_scalar_head()
        # Initialize neural network weights.
        self._initialize_weights()

    @staticmethod
    def _make_scalar_head() -> nn.Sequential:
        """Build a 64 -> 1 channel head: 2 x (9x9 conv + PReLU), 9x9 conv, ReLU."""
        return nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
            nn.Conv2d(64, 1, kernel_size=9, stride=1, padding=4),
            nn.ReLU(),
        )

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Return ``(mu, alpha, beta)`` for the input batch ``x``."""
        return self._forward_impl(x)

    # Support torch.script function: the return annotation must match the
    # returned 3-tuple, because TorchScript checks the return value against it.
    def _forward_impl(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = out1 + out2  # global skip connection
        out_mu = self.conv_block3_mu(out)
        out_alpha = self.conv_block3_alpha(out)
        out_beta = self.conv_block3_beta(out)
        return out_mu, out_alpha, out_beta

    def _initialize_weights(self) -> None:
        # Kaiming-normal conv weights with zero bias; BatchNorm scale set to 1.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
class BayesCap_noID(nn.Module):
    """Variant of BayesCap that predicts only the ``(alpha, beta)`` maps.

    It uses the same shared feature extractor:
      * 9x9 conv (``in_channels`` -> 64) + PReLU;
      * 16 residual blocks, then 3x3 conv + BatchNorm;
      * global skip: the result is added to the first conv's output.

    It also has the same two 64 -> 64 -> 64 -> 1 channel heads, each ending
    in ReLU. No ``mu`` head is built. All convs use stride 1 with "same"
    padding, so both outputs keep the input's height and width.

    Args:
        in_channels (int): Channels of the input image. Default: 3.
        out_channels (int): Accepted but unused, because this variant has no
            ``mu`` head. NOTE(review): presumably kept for constructor
            parity with BayesCap.
    """

    def __init__(self, in_channels=3, out_channels=3) -> None:
        super(BayesCap_noID, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
        )
        # Features trunk blocks.
        self.trunk = nn.Sequential(*[ResidualConvBlock(64) for _ in range(16)])
        # Second conv layer (no bias: followed by BatchNorm).
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
        )
        # Output heads: only alpha and beta. The mu head of BayesCap is
        # intentionally omitted, which is why ``out_channels`` goes unused.
        self.conv_block3_alpha = self._make_scalar_head()
        self.conv_block3_beta = self._make_scalar_head()
        # Initialize neural network weights.
        self._initialize_weights()

    @staticmethod
    def _make_scalar_head() -> nn.Sequential:
        """Build a 64 -> 1 channel head: 2 x (9x9 conv + PReLU), 9x9 conv, ReLU."""
        return nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
            nn.Conv2d(64, 1, kernel_size=9, stride=1, padding=4),
            nn.ReLU(),
        )

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(alpha, beta)`` for the input batch ``x``."""
        return self._forward_impl(x)

    # Support torch.script function: the return annotation must match the
    # returned 2-tuple, because TorchScript checks the return value against it.
    def _forward_impl(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = out1 + out2  # global skip connection
        out_alpha = self.conv_block3_alpha(out)
        out_beta = self.conv_block3_beta(out)
        return out_alpha, out_beta

    def _initialize_weights(self) -> None:
        # Kaiming-normal conv weights with zero bias; BatchNorm scale set to 1.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)