Spaces:
Runtime error
Runtime error
| import random | |
| from typing import Any, Optional | |
| import numpy as np | |
| import os | |
| import cv2 | |
| from glob import glob | |
| from PIL import Image, ImageDraw | |
| from tqdm import tqdm | |
| import kornia | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import albumentations as albu | |
| import functools | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
| import torchvision as tv | |
| import torchvision.models as models | |
| from torchvision import transforms | |
| from torchvision.transforms import functional as F | |
| from losses import TempCombLoss | |
# Checkpoints that can be fetched from Google Drive when missing locally.
# Keys are the local checkpoint filenames; values are the Google Drive URLs
# they are downloaded from (see ensure_checkpoint_exists).
google_drive_paths = dict(
    [
        (
            "BayesCap_SRGAN.pth",
            "https://drive.google.com/uc?id=1d_5j1f8-vN79htZTfRUqP1ddHZIYsNvL",
        ),
        (
            "BayesCap_ckpt.pth",
            "https://drive.google.com/uc?id=1Vg1r6gKgQ1J3M51n6BeKXYS8auT9NhA9",
        ),
    ]
)
def ensure_checkpoint_exists(model_weights_filename):
    """Make sure a model checkpoint file is present on local disk.

    If the file is missing and its name is a key of ``google_drive_paths``,
    it is downloaded from Google Drive with ``gdown``. If it is missing and
    not a known checkpoint, a hint to download it manually is printed.

    Args:
        model_weights_filename (str): Local path of the checkpoint; also used
            as the lookup key into ``google_drive_paths``.

    Returns:
        bool: ``True`` if the file exists when this function returns,
        ``False`` otherwise. Callers that ignore the return value (the
        original behaviour returned ``None``) are unaffected.
    """
    if os.path.isfile(model_weights_filename):
        return True

    gdrive_url = google_drive_paths.get(model_weights_filename)
    if gdrive_url is None:
        print(
            model_weights_filename,
            " not found, you may need to manually download the model weights."
        )
        return False

    # Guard only the import: a ModuleNotFoundError raised *inside* the download
    # must not be misreported as "gdown is not installed".
    try:
        from gdown import download as drive_download
    except ModuleNotFoundError:
        print(
            "gdown module not found.",
            "pip3 install gdown or, manually download the checkpoint file:",
            gdrive_url
        )
        return False

    drive_download(gdrive_url, model_weights_filename, quiet=False)

    # The download call may return without having created the file; report
    # that instead of failing silently later when the checkpoint is loaded.
    if not os.path.isfile(model_weights_filename):
        print(
            model_weights_filename,
            " could not be downloaded, manually download the checkpoint file:",
            gdrive_url
        )
        return False
    return True
def normalize(image: np.ndarray) -> np.ndarray:
    """Scale image data read by ``OpenCV.imread`` / ``skimage.io.imread`` into [0, 1].

    Args:
        image (np.ndarray): Pixel array, typically 8-bit values in [0, 255].

    Returns:
        np.ndarray: A new float64 array equal to ``image / 255``.
    """
    as_float = image.astype(np.float64)
    return np.divide(as_float, 255.0)
def unnormalize(image: np.ndarray) -> np.ndarray:
    """Undo ``normalize``: scale [0, 1] image data back to the [0, 255] range.

    Args:
        image (np.ndarray): Normalized pixel array, typically in [0, 1].

    Returns:
        np.ndarray: A new float64 array equal to ``image * 255``. No rounding,
        clipping or integer conversion is applied.
    """
    as_float = image.astype(np.float64)
    return np.multiply(as_float, 255.0)
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
    """Turn an image into a ``torch.Tensor`` via ``torchvision``'s ``to_tensor``.

    Args:
        image (np.ndarray): Image accepted by ``F.to_tensor`` (the example
            below passes a ``PIL.Image``).
        range_norm (bool): If true, remap the [0, 1] output to [-1, 1].
        half (bool): If true, cast the result to ``torch.half``.

    Returns:
        torch.Tensor: The converted image tensor.

    Examples:
        >>> image = Image.open("image.bmp")
        >>> tensor_image = image2tensor(image, range_norm=False, half=False)
    """
    result = F.to_tensor(image)

    if range_norm:
        # x -> 2x - 1, done in place on the freshly created tensor.
        result = result.mul_(2.0).sub_(1.0)

    return result.half() if half else result
def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> np.ndarray:
    """Convert an image tensor to an HWC ``uint8`` numpy array (usable with PIL).

    The input tensor is NOT modified. The original implementation used
    in-place ops (``add_``, ``squeeze_``, ``mul_``, ``clamp_``), which
    overwrote the caller's data and dropped its batch dimension.

    Args:
        tensor (torch.Tensor): Image tensor laid out as (1, C, H, W) or
            (C, H, W); a leading dimension of size 1 is squeezed away.
        range_norm (bool): If true, remap [-1, 1] data to [0, 1] first.
        half (bool): If true, cast to ``torch.half`` before scaling to [0, 255].

    Returns:
        np.ndarray: (H, W, C) array of dtype ``uint8``. Values are scaled by
        255, clamped to [0, 255] and truncated (not rounded).

    Examples:
        >>> tensor = torch.randn([1, 3, 128, 128])
        >>> image = tensor2image(tensor, range_norm=False, half=False)
    """
    # detach() so .numpy() also works on tensors that require grad; every op
    # below is out-of-place, so the caller's storage is never written.
    image = tensor.detach()
    if range_norm:
        image = (image + 1.0) / 2.0
    if half:
        image = image.half()
    image = image.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255)
    return image.cpu().numpy().astype("uint8")