Spaces:
Runtime error
Runtime error
| import math | |
| import random | |
| from typing import Any | |
| import torch | |
| import numpy as np | |
| import collections | |
| from itertools import repeat | |
| from torch import conv2d, conv_transpose2d | |
def make_beta_schedule(device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Build a diffusion noise (beta) schedule.

    Args:
        device: Torch device used for the intermediate tensors of the ``"cosine"``
            schedule. The result is always returned as a host (CPU) numpy array.
        schedule: One of ``"linear"``, ``"cosine"``, ``"sqrt_linear"`` or ``"sqrt"``.
        n_timestep: Number of steps; the length of the returned array.
        linear_start: Start endpoint parameterising the ``linear``, ``sqrt_linear``
            and ``sqrt`` schedules.
        linear_end: End endpoint parameterising the same three schedules.
        cosine_s: Offset added to the normalised timesteps of the cosine schedule.

    Returns:
        ``np.ndarray`` of ``float64`` with shape ``(n_timestep,)``.

    Raises:
        ValueError: If ``schedule`` is not one of the names above.
    """
    if schedule == "linear":
        # Linear in sqrt(beta) space, then squared.
        betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    elif schedule == "cosine":
        timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s).to(device)
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # torch.clamp keeps this a tensor on whatever device it lives on
        # (np.clip on a torch tensor relied on numpy's duck-typed dispatch).
        betas = torch.clamp(betas, min=0, max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    # .cpu() is required: Tensor.numpy() raises for tensors on an accelerator device.
    return betas.cpu().numpy()
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """Derive the per-step alphas and sigma (noise scale) values for DDIM sampling.

    The sigma formula follows https://arxiv.org/abs/2010.02502.

    Args:
        alphacums: Cumulative alpha products indexed by DDPM timestep; must support
            fancy indexing and ``.tolist()`` (presumably a numpy array -- confirm
            against callers).
        ddim_timesteps: Integer index array of the selected DDIM steps.
        eta: Multiplier applied to every sigma (``eta=0`` yields all-zero sigmas).
        verbose: When True, print the selected alphas and sigmas.

    Returns:
        Tuple ``(sigmas, alphas, alphas_prev)``.
    """
    a_t = alphacums[ddim_timesteps]
    # The "previous" alpha of the first selected step is alphacums[0]; every later
    # step uses the alpha of the step selected just before it.
    prev_values = [alphacums[0]]
    prev_values.extend(alphacums[ddim_timesteps[:-1]].tolist())
    a_prev = np.asarray(prev_values)

    sigmas = eta * np.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))

    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {a_t}; a_(t-1): {a_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, a_t, a_prev
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """Choose which DDPM timesteps a DDIM sampler visits.

    Args:
        ddim_discr_method: ``'uniform'`` (evenly strided) or ``'quad'``
            (quadratically spaced over the first 80% of the DDPM range).
        num_ddim_timesteps: Requested number of DDIM steps. In ``'uniform'`` mode
            the stride is ``num_ddpm_timesteps // num_ddim_timesteps``, so more
            steps than requested are returned when it does not divide evenly.
        num_ddpm_timesteps: Total number of DDPM timesteps.
        verbose: When True, print the selected timesteps.

    Returns:
        1-D integer numpy array of the selected timesteps, each shifted by +1.

    Raises:
        NotImplementedError: For an unknown ``ddim_discr_method``.
    """
    if ddim_discr_method not in ('uniform', 'quad'):
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        selected = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
    else:
        upper = np.sqrt(num_ddpm_timesteps * .8)
        selected = (np.linspace(0, upper, num_ddim_timesteps) ** 2).astype(int)

    # Shift by one so the final alpha values line up (per the original author:
    # "the ones from first scale to data during sampling").
    shifted = selected + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {shifted}')
    return shifted
def noise_like(shape, device, repeat=False):
    """Sample standard-normal noise of ``shape`` on ``device``.

    With ``repeat=True`` a single sample of shape ``(1, *shape[1:])`` is drawn and
    tiled along the first axis, so every batch element gets identical noise.
    """
    if not repeat:
        return torch.randn(shape, device=device)
    tile_rest = (1,) * (len(shape) - 1)
    single = torch.randn((1, *shape[1:]), device=device)
    return single.repeat(shape[0], *tile_rest)
def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param device: device on which the frequency table is created.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoidal encoding and return each
                        timestep repeated ``dim`` times along the last axis
                        (the flag used to be accepted but ignored).
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return timesteps[:, None].repeat(1, dim)
    half = dim // 2
    # Geometric frequency ladder from 1 down to roughly 1 / max_period.
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dim: append one zero column so the width is exactly dim.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
| ###### MAT and FcF ####### | |
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Rescale ``x`` so its mean square along ``dim`` is approximately one.

    ``eps`` is added to the mean square before the reciprocal square root so an
    all-zero slice does not divide by zero.
    """
    mean_square = x.square().mean(dim=dim, keepdim=True)
    return x * torch.rsqrt(mean_square + eps)
class EasyDict(dict):
    """A ``dict`` whose items can also be read, set and deleted as attributes.

    ``d.key`` reads ``d["key"]``, ``d.key = v`` stores ``d["key"] = v`` and
    ``del d.key`` removes the item. A missing key raises ``AttributeError``.
    """

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails, so real dict methods
        # (keys, items, ...) still take precedence over same-named items.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    def __setattr__(self, name: str, value: Any) -> None:
        self.__setitem__(name, value)

    def __delattr__(self, name: str) -> None:
        self.__delitem__(name)
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Reference implementation of `bias_act()` built from standard PyTorch ops.

    Steps: add the optional 1-D bias ``b`` along axis ``dim``, apply
    ``activation_funcs[act].func``, multiply by ``gain``, then clamp to
    ``[-clamp, clamp]``. ``alpha`` and ``gain`` default to the activation's
    ``def_alpha`` / ``def_gain``; ``clamp=None`` disables clamping.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    # -1 is the "no clamping" sentinel; user clamps are asserted to be >= 0 above.
    clamp = float(clamp if clamp is not None else -1)

    # Add bias, broadcast along `dim`.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    x = spec.func(x, alpha=alpha)

    # Scale by gain (skipped when it would be a no-op).
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp)
    return x
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'):
    r"""Add bias, apply an activation function, scale by gain and optionally clamp.

    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
    scales the result by `gain` and clamps it. Each of the steps is optional.
    In this module the computation is always delegated to `_bias_act_ref()`,
    which is built from standard PyTorch ops; no fused kernel is used.

    Args:
        x:      Input activation tensor. Can be of any shape.
        b:      Bias vector, or `None` to disable. Must be a 1D tensor whose length
                matches the dimension of `x` corresponding to `dim`.
        dim:    The dimension in `x` corresponding to the elements of `b`.
                The value of `dim` is ignored if `b` is not specified.
        act:    Name of the activation function to evaluate, or `"linear"` to disable.
                Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
                See `activation_funcs` for a full list. `None` is not allowed.
        alpha:  Shape parameter for the activation function, or `None` to use the default.
        gain:   Scaling factor for the output tensor, or `None` to use default.
                See `activation_funcs` for the default scaling of each activation function.
                If unsure, consider specifying 1.
        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
                the clamping (default).
        impl:   Must be `"ref"` (default) or `"cuda"`. The value is only validated;
                both use the reference implementation.

    Returns:
        Tensor of the same shape as `x` (dtype follows PyTorch type promotion of
        `x` and `b`).
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
def _get_filter_size(f):
    """Return ``(width, height)`` of a FIR filter tensor; ``None`` counts as 1x1.

    For a 1-D (separable) filter both values equal the number of taps.
    """
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    width, height = int(f.shape[-1]), int(f.shape[0])
    assert width >= 1 and height >= 1
    return width, height
def _get_weight_shape(w):
    """Return the shape of ``w`` as a list of Python ints."""
    return list(map(int, w.shape))
def _parse_scaling(scaling):
    """Normalize a scaling factor (int or ``[x, y]`` pair) to ``(sx, sy)``.

    Both factors must be ints >= 1 (enforced with asserts).
    """
    pair = [scaling, scaling] if isinstance(scaling, int) else scaling
    assert isinstance(pair, (list, tuple))
    assert all(isinstance(v, int) for v in pair)
    sx, sy = pair
    assert sx >= 1 and sy >= 1
    return sx, sy
def _parse_padding(padding):
    """Normalize padding to ``(padx0, padx1, pady0, pady1)``.

    Accepts an int, an ``[x, y]`` pair, or ``[x_before, x_after, y_before, y_after]``;
    all entries must be ints (enforced with asserts).
    """
    values = [padding, padding] if isinstance(padding, int) else padding
    assert isinstance(values, (list, tuple))
    assert all(isinstance(v, int) for v in values)
    if len(values) == 2:
        pad_x, pad_y = values
        values = [pad_x, pad_x, pad_y, pad_y]
    padx0, padx1, pady0, pady1 = values
    return padx0, padx1, pady0, pady1
def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
    r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.

    Args:
        f:           Torch tensor, numpy array, or python list of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable),
                     `[]` (impulse), or
                     `None` (identity).
        device:      Result device (default: cpu).
        normalize:   Normalize the filter so that it retains the magnitude
                     for constant input signal (DC)? (default: True).
        flip_filter: Flip the filter? (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        separable:   Return a separable filter? (default: select automatically).

    Returns:
        Float32 tensor of the shape
        `[filter_height, filter_width]` (non-separable) or
        `[filter_taps]` (separable). The input `f` is never modified.
    """
    # Validate.
    if f is None:
        f = 1
    f = torch.as_tensor(f, dtype=torch.float32)
    assert f.ndim in [0, 1, 2]
    assert f.numel() > 0
    if f.ndim == 0:
        f = f[np.newaxis]

    # Separable? Short 1-D filters are expanded to their 2-D outer product.
    if separable is None:
        separable = (f.ndim == 1 and f.numel() >= 8)
    if f.ndim == 1 and not separable:
        f = torch.outer(f, f)
    assert f.ndim == (1 if separable else 2)

    # Apply normalize, flip, gain, and device.
    if normalize:
        # Out-of-place on purpose: torch.as_tensor() may return the caller's own
        # tensor (or a view of their numpy array), so `/=` would mutate the input.
        f = f / f.sum()
    if flip_filter:
        f = f.flip(list(range(f.ndim)))
    f = f * (gain ** (f.ndim / 2))
    f = f.to(device=device)
    return f
def _ntuple(n):
    """Return a parser that expands a non-iterable value into an ``n``-tuple.

    Iterables (including strings) are returned unchanged; any other value ``x``
    becomes ``(x,) * n``.
    """
    # Import the submodule explicitly: a bare ``import collections`` does not
    # guarantee that the ``collections.abc`` attribute has been loaded.
    from collections.abc import Iterable

    def parse(x):
        if isinstance(x, Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


# Expand a scalar into a pair, e.g. to_2tuple(3) == (3, 3).
to_2tuple = _ntuple(2)
# Registry of activation functions used by `_bias_act_ref()` and Conv2dLayer.
# Each entry holds:
#   func       -- callable taking the tensor plus keyword arguments
#                 (only 'lrelu' reads `alpha`; the rest swallow it via **_)
#   def_alpha  -- default shape parameter
#   def_gain   -- default output gain
#   cuda_idx, ref, has_2nd_grad -- metadata not read by `_bias_act_ref()`
activation_funcs = dict(
    linear=EasyDict(
        func=lambda x, **_: x,
        def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False,
    ),
    relu=EasyDict(
        func=lambda x, **_: torch.nn.functional.relu(x),
        def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False,
    ),
    lrelu=EasyDict(
        func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
        def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False,
    ),
    tanh=EasyDict(
        func=lambda x, **_: torch.tanh(x),
        def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True,
    ),
    sigmoid=EasyDict(
        func=lambda x, **_: torch.sigmoid(x),
        def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True,
    ),
    elu=EasyDict(
        func=lambda x, **_: torch.nn.functional.elu(x),
        def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True,
    ),
    selu=EasyDict(
        func=lambda x, **_: torch.nn.functional.selu(x),
        def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True,
    ),
    softplus=EasyDict(
        func=lambda x, **_: torch.nn.functional.softplus(x),
        def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True,
    ),
    swish=EasyDict(
        func=lambda x, **_: torch.sigmoid(x) * x,
        def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True,
    ),
)
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Pad, upsample, filter, and downsample a batch of 2D images.

    Performs the following sequence of operations for each channel:

    1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
    2. Pad the image with the specified number of zeros on each side (`padding`).
       Negative padding corresponds to cropping the image.
    3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
       so that the footprint of all output pixels lies within the input image.
    4. Downsample the image by keeping every Nth pixel (`down`).

    This sequence of operations bears close resemblance to scipy.signal.upfirdn().
    In this module the work is always done by `_upfirdn2d_ref()`, a reference
    implementation built from standard PyTorch ops (autograd provides gradients);
    no fused kernel is used.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        up:          Integer upsampling factor (default: 1).
        down:        Integer downsampling factor (default: 1).
        padding:     Padding with respect to the upsampled image, forwarded unchanged
                     to `_upfirdn2d_ref()`. A 4-element list
                     `[x_before, x_after, y_before, y_after]` is always accepted
                     (default: 0; see `_upfirdn2d_ref()` for the other forms).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Accepted for API compatibility only; it is ignored.

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.

    `up` / `down` may be a scalar or an `[x, y]` pair. `padding` may be a scalar,
    an `[x, y]` pair, or `[x_before, x_after, y_before, y_after]`. The 4-element
    form behaves exactly as before; the scalar form (including the default 0)
    and the pair form used to raise TypeError / IndexError.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    if f is None:
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    assert f.dtype == torch.float32 and not f.requires_grad
    batch_size, num_channels, in_height, in_width = x.shape

    # Normalize scaling and padding. Deliberately lenient (no int-type asserts,
    # unlike _parse_scaling/_parse_padding) so int-like scalars keep working.
    upx, upy = (up, up) if np.ndim(up) == 0 else (up[0], up[1])
    downx, downy = (down, down) if np.ndim(down) == 0 else (down[0], down[1])
    if np.ndim(padding) == 0:
        padx0 = padx1 = pady0 = pady1 = padding
    elif len(padding) == 2:
        padx0 = padx1 = padding[0]
        pady0 = pady1 = padding[1]
    else:
        padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3]

    # Upsample by inserting zeros.
    x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
    x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
    x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])

    # Pad (positive values) or crop (negative values).
    x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
    x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)]

    # Setup filter.
    f = f * (gain ** (f.ndim / 2))
    f = f.to(x.dtype)
    if not flip_filter:
        # conv2d() correlates, so flip to obtain a true convolution.
        f = f.flip(list(range(f.ndim)))

    # Convolve with the filter, depthwise (one copy of the filter per channel).
    f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
    if f.ndim == 4:
        x = conv2d(input=x, weight=f, groups=num_channels)
    else:
        # Separable filter: horizontal pass, then vertical pass.
        x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
        x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)

    # Downsample by throwing away pixels.
    x = x[:, :, ::downy, ::downx]
    return x
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Downsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a fraction of the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        down:        Integer downsampling factor (default: 2).
        padding:     Padding with respect to the input. Can be a single number or a
                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Forwarded to `upfirdn2d()`, which currently ignores it.

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    downx, downy = _parse_scaling(down)
    # Accept the three documented padding forms. Only the scalar form worked
    # previously; lists raised TypeError on `list + int` below. No int-type
    # asserts are applied, so int-like scalars keep working.
    if np.ndim(padding) == 0:
        padx0 = padx1 = pady0 = pady1 = padding
    elif len(padding) == 2:
        padx0 = padx1 = padding[0]
        pady0 = pady1 = padding[1]
    else:
        padx0, padx1, pady0, pady1 = padding
    fw, fh = _get_filter_size(f)
    # Extra padding so the filter footprint is centered before decimation.
    p = [
        padx0 + (fw - downx + 1) // 2,
        padx1 + (fw - downx) // 2,
        pady0 + (fh - downy + 1) // 2,
        pady1 + (fh - downy) // 2,
    ]
    return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Upsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a multiple of the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        up:          Integer upsampling factor (default: 2).
        padding:     Padding with respect to the output. Can be a single int or a
                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     of ints (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
                     It is multiplied by `upx * upy` before being forwarded.
        impl:        Forwarded to `upfirdn2d()`, which currently ignores it.

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    upx, upy = _parse_scaling(up)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    # Extra padding so the filter footprint is centered on the upsampled grid.
    p = [
        padx0 + (fw + upx - 1) // 2,
        padx1 + (fw - upx) // 2,
        pady0 + (fh + upy - 1) // 2,
        pady1 + (fh - upy) // 2,
    ]
    return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain * upx * upy, impl=impl)
class MinibatchStdLayer(torch.nn.Module):
    """Append minibatch standard-deviation statistics as extra feature channels.

    Args:
        group_size: Number of samples per statistics group, capped at the batch
            size; `None` uses the whole batch as one group. The batch size must be
            divisible by the resulting group size, otherwise the reshape fails.
        num_channels: Number of statistic channels appended. The input channels
            are split into this many groups, one statistic per group.
    """

    def __init__(self, group_size, num_channels=1):
        super().__init__()
        self.group_size = group_size
        self.num_channels = num_channels

    def forward(self, x):
        batch, chans, height, width = x.shape
        if self.group_size is None:
            grp = batch
        else:
            # torch.min on tensors rather than Python min(); kept exactly as written.
            grp = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(batch))
        n_stat = self.num_channels
        per_stat = chans // n_stat

        # [G, n, F, c, H, W]: sample index g*n + j, channel groups of size c.
        stats = x.reshape(grp, -1, n_stat, per_stat, height, width)
        stats = stats - stats.mean(dim=0)         # center within each group
        stats = stats.square().mean(dim=0)        # variance over the group -> [n, F, c, H, W]
        stats = (stats + 1e-8).sqrt()             # standard deviation
        stats = stats.mean(dim=[2, 3, 4])         # average over channels and pixels -> [n, F]
        stats = stats.reshape(-1, n_stat, 1, 1)   # [n, F, 1, 1]
        stats = stats.repeat(grp, 1, height, width)  # [N, F, H, W]
        return torch.cat([x, stats], dim=1)
class FullyConnectedLayer(torch.nn.Module):
    """Linear layer with runtime weight/bias scaling and an optional activation.

    Parameters are stored divided by `lr_multiplier` and multiplied back at
    runtime by `weight_gain` (`lr_multiplier / sqrt(in_features)`) and
    `bias_gain` (`lr_multiplier`).
    """

    def __init__(self,
                 in_features,           # Number of input features.
                 out_features,          # Number of output features.
                 bias=True,             # Apply additive bias before the activation function?
                 activation='linear',   # Activation function: 'relu', 'lrelu', etc.
                 lr_multiplier=1,       # Learning rate multiplier.
                 bias_init=0,           # Initial value for the additive bias.
                 ):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        if bias:
            self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
        else:
            self.bias = None
        self.activation = activation
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        weight = self.weight * self.weight_gain
        bias = self.bias
        if bias is not None and self.bias_gain != 1:
            bias = bias * self.bias_gain
        x = x.matmul(weight.t())
        if self.activation != 'linear' or bias is None:
            return bias_act(x, bias, act=self.activation, dim=x.ndim - 1)
        # Linear with bias: add the bias directly, broadcast along the last axis.
        bias_shape = [-1 if axis == x.ndim - 1 else 1 for axis in range(x.ndim)]
        return x + bias.reshape(bias_shape)
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
    """Dispatch to `conv2d()` / `conv_transpose2d()`, with a 1x1 channels_last special case.

    `conv2d()` computes cross-correlation, so `flip_weight=True` (the default)
    means correlation and `flip_weight=False` flips the kernel to obtain a true
    convolution.
    """
    out_channels, in_channels_per_group, kh, kw = (int(s) for s in w.shape)

    if not flip_weight:
        w = w.flip([2, 3])

    # Workaround for a performance pitfall in cuDNN 8.0.5, triggered by
    # 1x1 kernel + memory_format=channels_last + fewer than 64 channels.
    pointwise = (kw == 1 and kh == 1 and stride == 1
                 and padding in [0, [0, 0], (0, 0)] and not transpose)
    if pointwise and x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
        if out_channels <= 4 and groups == 1:
            # Tiny output: a plain matrix product over the flattened pixels.
            n, _, h, wd = x.shape
            flat = x.reshape([n, in_channels_per_group, -1])
            y = w.squeeze(3).squeeze(2) @ flat
            return y.reshape([n, out_channels, h, wd])
        # Otherwise convolve in contiguous layout, then restore channels_last.
        y = conv2d(x.to(memory_format=torch.contiguous_format),
                   w.to(memory_format=torch.contiguous_format), groups=groups)
        return y.to(memory_format=torch.channels_last)

    conv_fn = conv_transpose2d if transpose else conv2d
    return conv_fn(x, w, stride=stride, padding=padding, groups=groups)
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
    r"""2D convolution with optional up/downsampling.

    Padding is performed only once at the beginning, not between the operations.

    Args:
        x:              Input tensor of shape
                        `[batch_size, in_channels, in_height, in_width]`.
        w:              Weight tensor of shape
                        `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
        f:              Low-pass filter for up/downsampling. Must be prepared beforehand by
                        calling setup_filter(). None = identity (default).
        up:             Integer upsampling factor (default: 1).
        down:           Integer downsampling factor (default: 1).
        padding:        Padding with respect to the upsampled image. Can be a single number
                        or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                        (default: 0).
        groups:         Split input channels into N groups (default: 1).
        flip_weight:    False = convolution, True = correlation (default: True).
        flip_filter:    False = convolution, True = correlation (default: False).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    # NOTE(review): an int assert on `groups` was disabled in this file, apparently
    # because some caller passes a non-int value -- confirm before re-adding it.
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)

    # Accept the three documented padding forms (only the scalar form used to work).
    if np.ndim(padding) == 0:
        px0 = px1 = py0 = py1 = padding
    elif len(padding) == 2:
        px0 = px1 = padding[0]
        py0 = py1 = padding[1]
    else:
        px0, px1, py0, py1 = padding

    # Adjust padding to account for up/downsampling.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
        py0 += (fh + up - 1) // 2
        py1 += (fh - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
        py0 += (fh - down + 1) // 2
        py1 += (fh - down) // 2

    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
    if kw == 1 and kh == 1 and (down > 1 and up == 1):
        x = upfirdn2d(x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
    if kw == 1 and kh == 1 and (up > 1 and down == 1):
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        x = upfirdn2d(x=x, f=f, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter)
        return x

    # Fast path: downsampling only => use strided convolution.
    if down > 1 and up == 1:
        x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: upsampling with optional downsampling => use transpose strided convolution.
    if up > 1:
        # conv_transpose2d expects [in, out/groups, kh, kw]; swap the channel axes per group.
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        # Let the transpose conv absorb as much symmetric padding as it can.
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt, pxt], groups=groups, transpose=True,
                            flip_weight=(not flip_weight))
        x = upfirdn2d(x=x, f=f, padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], gain=up ** 2,
                      flip_filter=flip_filter)
        if down > 1:
            # Explicit zero padding: the reference upfirdn2d path indexes padding as a
            # 4-element list, so relying on its scalar default of 0 crashed here.
            x = upfirdn2d(x=x, f=f, down=down, padding=[0, 0, 0, 0], flip_filter=flip_filter)
        return x

    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
    if up == 1 and down == 1:
        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
            return _conv2d_wrapper(x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight)

    # Fallback: Generic reference implementation.
    x = upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0, px1, py0, py1], gain=up ** 2,
                  flip_filter=flip_filter)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
    if down > 1:
        x = upfirdn2d(x=x, f=f, down=down, padding=[0, 0, 0, 0], flip_filter=flip_filter)
    return x
class Conv2dLayer(torch.nn.Module):
    """Convolution layer with optional up/downsampling, bias, activation and clamping.

    Weights are stored unscaled and multiplied at runtime by `weight_gain`
    (`1 / sqrt(in_channels * kernel_size**2)`). The activation output is scaled
    by the activation's default gain times the `gain` passed to `forward()`.
    """

    def __init__(self,
                 in_channels,                    # Number of input channels.
                 out_channels,                   # Number of output channels.
                 kernel_size,                    # Width and height of the convolution kernel.
                 bias=True,                      # Apply additive bias before the activation function?
                 activation='linear',            # Activation function: 'relu', 'lrelu', etc.
                 up=1,                           # Integer upsampling factor.
                 down=1,                         # Integer downsampling factor.
                 resample_filter=(1, 3, 3, 1),   # Low-pass filter to apply when resampling activations.
                 conv_clamp=None,                # Clamp the output to +-X, None = disable clamping.
                 channels_last=False,            # Expect the input to have memory_format=channels_last?
                 trainable=True,                 # Update the weights of this layer during training?
                 ):
        super().__init__()
        self.activation = activation
        self.up = up
        self.down = down
        # Tuple default (was a list) avoids a shared mutable default argument;
        # setup_filter() converts either form into a fresh float32 tensor.
        self.register_buffer('resample_filter', setup_filter(resample_filter))
        self.conv_clamp = conv_clamp
        # Half-kernel padding, forwarded to conv2d_resample().
        self.padding = kernel_size // 2
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
        self.act_gain = activation_funcs[activation].def_gain

        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
        bias = torch.zeros([out_channels]) if bias else None
        if trainable:
            self.weight = torch.nn.Parameter(weight)
            self.bias = torch.nn.Parameter(bias) if bias is not None else None
        else:
            # Frozen layer: register buffers instead of Parameters so they are not
            # returned by parameters() and receive no optimizer updates.
            self.register_buffer('weight', weight)
            if bias is not None:
                self.register_buffer('bias', bias)
            else:
                self.bias = None

    def forward(self, x, gain=1):
        w = self.weight * self.weight_gain
        x = conv2d_resample(x=x, w=w, f=self.resample_filter, up=self.up, down=self.down,
                            padding=self.padding)
        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
        return out
def torch_gc():
    """Release cached CUDA memory and IPC handles; does nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
def set_seed(seed: int):
    """Seed Python's `random`, NumPy's global RNG, torch's CPU RNG and all CUDA RNGs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)