| """Utility functions for videos, plotting and computing performance metrics.""" | |
| import os | |
| import typing | |
| import cv2 # pytype: disable=attribute-error | |
| import matplotlib | |
| import numpy as np | |
| import torch | |
| import tqdm | |
| from . import video | |
| from . import segmentation | |


def loadvideo(filename: str) -> np.ndarray:
    """Loads a video from a file.

    Args:
        filename (str): filename of video

    Returns:
        A np.ndarray with dimensions (channels=3, frames, height, width). The
        values will be uint8's ranging from 0 to 255.

    Raises:
        FileNotFoundError: Could not find `filename`
        ValueError: An error occurred while reading the video
    """
    if not os.path.exists(filename):
        raise FileNotFoundError(filename)
    capture = cv2.VideoCapture(filename)

    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

    v = np.zeros((frame_count, frame_height, frame_width, 3), np.uint8)

    for count in range(frame_count):
        ret, frame = capture.read()
        if not ret:
            raise ValueError("Failed to load frame #{} of {}.".format(count, filename))

        # OpenCV reads frames as BGR; convert to RGB before storing.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        v[count, :, :] = frame

    capture.release()

    # Reorder from (frames, height, width, channels) to (channels, frames, height, width).
    v = v.transpose((3, 0, 1, 2))

    return v
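
# Usage sketch (the filename below is hypothetical, not part of this module):
#
#     x = loadvideo("EchoVideo.avi")   # uint8 array, shape (3, frames, height, width)
#     print(x.shape, x.dtype)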


def savevideo(filename: str, array: np.ndarray, fps: typing.Union[float, int] = 1):
    """Saves a video to a file.

    Args:
        filename (str): filename of video
        array (np.ndarray): video of uint8's with shape (channels=3, frames, height, width)
        fps (float or int): frames per second

    Returns:
        None
    """
    c, _, height, width = array.shape

    if c != 3:
        raise ValueError("savevideo expects array of shape (channels=3, frames, height, width), got shape ({})".format(", ".join(map(str, array.shape))))

    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    out = cv2.VideoWriter(filename, fourcc, fps, (width, height))

    for frame in array.transpose((1, 2, 3, 0)):
        # Frames are stored as RGB; OpenCV expects BGR when writing.
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        out.write(frame)

    out.release()
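
# Usage sketch (round trip; "copy.avi" is a hypothetical output path):
#
#     v = loadvideo("EchoVideo.avi")
#     savevideo("copy.avi", v, fps=50)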


def get_mean_and_std(dataset: torch.utils.data.Dataset,
                     samples: int = 128,
                     batch_size: int = 8,
                     num_workers: int = 4):
    """Computes mean and std from samples from a PyTorch dataset.

    Args:
        dataset (torch.utils.data.Dataset): A PyTorch dataset.
            ``dataset[i][0]'' is expected to be the i-th video in the dataset, which
            should be a ``torch.Tensor'' of dimensions (channels=3, frames, height, width)
        samples (int or None, optional): Number of samples to take from dataset. If ``None'', mean and
            standard deviation are computed over all elements.
            Defaults to 128.
        batch_size (int, optional): how many samples per batch to load
            Defaults to 8.
        num_workers (int, optional): how many subprocesses to use for data
            loading. If 0, the data will be loaded in the main process.
            Defaults to 4.

    Returns:
        A tuple of the mean and standard deviation. Both are represented as np.array's of dimension (channels,).
    """
    if samples is not None and len(dataset) > samples:
        indices = np.random.choice(len(dataset), samples, replace=False)
        dataset = torch.utils.data.Subset(dataset, indices)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)

    n = 0    # number of pixels seen so far across all sampled videos
    s1 = 0.  # sum of elements along channels (ends up as np.array of dimension (channels,))
    s2 = 0.  # sum of squares of elements along channels (ends up as np.array of dimension (channels,))
    for (x, *_) in tqdm.tqdm(dataloader):
        # Flatten each batch to (channels, everything else) before accumulating.
        x = x.transpose(0, 1).contiguous().view(3, -1)
        n += x.shape[1]
        s1 += torch.sum(x, dim=1).numpy()
        s2 += torch.sum(x ** 2, dim=1).numpy()
    mean = s1 / n  # type: np.ndarray
    std = np.sqrt(s2 / n - mean ** 2)  # type: np.ndarray

    mean = mean.astype(np.float32)
    std = std.astype(np.float32)

    return mean, std
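
# Usage sketch (assumes a dataset class, e.g. ``echonet.datasets.Echo'', whose
# items yield a (3, frames, height, width) video tensor as the first element):
#
#     mean, std = get_mean_and_std(dataset, samples=128)
#     # mean and std are float32 arrays of shape (3,), one value per channel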


def bootstrap(a, b, func, samples=10000):
    """Computes a bootstrapped confidence interval for ``func(a, b)''.

    Args:
        a (array_like): first argument to `func`.
        b (array_like): second argument to `func`.
        func (callable): Function to compute confidence intervals for.
        samples (int, optional): Number of bootstrap samples to compute.
            Defaults to 10000.

    Returns:
        A tuple of (`func(a, b)`, estimated 5-th percentile, estimated 95-th percentile).
    """
    a = np.array(a)
    b = np.array(b)

    bootstraps = []
    for _ in range(samples):
        # Resample paired entries of a and b with replacement.
        ind = np.random.choice(len(a), len(a))
        bootstraps.append(func(a[ind], b[ind]))
    bootstraps = sorted(bootstraps)

    return func(a, b), bootstraps[round(0.05 * len(bootstraps))], bootstraps[round(0.95 * len(bootstraps))]
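
# Usage sketch (toy numbers; any callable taking two arrays works, e.g. an R^2
# or AUC metric -- the mean-absolute-error lambda below is just for illustration):
#
#     y_true = [60.0, 55.0, 70.0, 45.0]
#     y_pred = [58.0, 57.0, 66.0, 50.0]
#     point, lower, upper = bootstrap(y_true, y_pred,
#                                     lambda a, b: float(np.mean(np.abs(a - b))))
#     # point estimate plus an approximate 90% interval (5th to 95th percentile)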


def latexify():
    """Sets matplotlib params to appear more like LaTeX.

    Based on https://nipunbatra.github.io/blog/2014/latexify.html
    """
    params = {'backend': 'pdf',
              'axes.titlesize': 8,
              'axes.labelsize': 8,
              'font.size': 8,
              'legend.fontsize': 8,
              'xtick.labelsize': 8,
              'ytick.labelsize': 8,
              'font.family': 'DejaVu Serif',
              'font.serif': 'Computer Modern',
              }
    matplotlib.rcParams.update(params)
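
# Usage sketch: call once before creating figures so the rcParams apply to every plot.
#
#     latexify()
#     # ... build figures with matplotlib.pyplot and save them as PDFs ...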


def dice_similarity_coefficient(inter, union):
    """Computes the dice similarity coefficient.

    Args:
        inter (iterable): iterable of the intersections
        union (iterable): iterable of the unions

    Returns:
        The pooled Dice similarity coefficient,
        2 * sum(inter) / (sum(union) + sum(inter)).
    """
    return 2 * sum(inter) / (sum(union) + sum(inter))
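
# Worked example (toy numbers): with inter = [3] and union = [5], the masks satisfy
# |A| + |B| = |A union B| + |A intersect B| = 5 + 3 = 8, so the coefficient is 2 * 3 / 8 = 0.75:
#
#     dice_similarity_coefficient([3], [5])  # -> 0.75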
| __all__ = ["video", "segmentation", "loadvideo", "savevideo", "get_mean_and_std", "bootstrap", "latexify", "dice_similarity_coefficient"] | |