| import os | |
| import argparse | |
| from PIL import Image | |
| from glob import glob | |
| import numpy as np | |
| import json | |
| import torch | |
| import torchvision | |
| from torch.nn import functional as F | |
def create_folder(path, verbose=False, exist_ok=True, safe=True):
    """Create the directory ``path``, including any missing parents.

    Args:
        path: Directory to create.
        verbose: If True, print a message when the directory is created.
        exist_ok: If True, an already existing directory is not an error.
        safe: If True, failures are reported by returning False; if False,
            they are raised as ``OSError``.

    Returns:
        True if the directory was created, False otherwise (it already
        existed, or creation failed while ``safe`` is True).

    Raises:
        OSError: If ``safe`` is False and either the directory already exists
            while ``exist_ok`` is False, or it could not be created.
    """
    if os.path.isdir(path):
        if exist_ok:
            # An existing directory is explicitly allowed: there is nothing
            # to create and nothing to report, even when ``safe`` is False.
            return False
        if not safe:
            raise OSError(f"Folder already exists: {path}")
        return False
    try:
        os.makedirs(path, exist_ok=exist_ok)
    except OSError as err:
        # Covers an empty path, a regular file in the way, permissions, ...
        # Only OSError is handled so KeyboardInterrupt etc. still propagate.
        if not safe:
            raise OSError(f"Could not create folder: {path}") from err
        return False
    if verbose:
        print(f"Created folder: {path}")
    return True
def read_video(path, start_step=0, time_steps=None, channels="first", exts=("jpg", "png"), resolution=None):
    """Read a video from an ``.mp4`` file or from a folder of frame images.

    Args:
        path: Path ending in ``.mp4``, or otherwise a folder of frames.
        start_step: Index of the first frame to read.
        time_steps: Number of frames to read; None reads until the end.
        channels: ``"first"`` or ``"last"``, forwarded to the reader.
        exts: Image extensions considered when reading from a folder.
        resolution: Optional target size forwarded to the reader.

    Returns:
        Whatever the selected reader returns for the requested clip.
    """
    if not path.endswith(".mp4"):
        # Anything that is not an mp4 file is treated as a folder of frames.
        return read_video_from_folder(path, start_step, time_steps, channels, resolution, exts)
    return read_video_from_file(path, start_step, time_steps, channels, resolution)
def read_video_from_file(path, start_step, time_steps, channels, resolution):
    """Read a clip from a video file and scale its values by 1/255.

    Args:
        path: Path of the video file.
        start_step: Index of the first frame to keep.
        time_steps: Number of frames to keep; None keeps every frame from
            ``start_step`` to the end.
        channels: ``"last"`` returns frames as (T, H, W, C); any other value
            keeps the decoded (T, C, H, W) layout.
        resolution: Optional output size passed to ``F.interpolate``.

    Returns:
        Floating-point tensor holding the selected frames divided by 255.
    """
    video, _, _ = torchvision.io.read_video(path, output_format="TCHW", pts_unit="sec")
    if time_steps is None:
        time_steps = len(video) - start_step
    video = video[start_step: start_step + time_steps]
    # Scale to floating point *before* resizing: bilinear interpolation of the
    # raw integer frames is not supported by every PyTorch version and, where
    # it is, quantises the result. This also matches read_frame, which
    # normalises before interpolating.
    video = video / 255.
    if resolution is not None:
        video = F.interpolate(video, size=resolution, mode="bilinear")
    if channels == "last":
        video = video.permute(0, 2, 3, 1)
    return video
def read_video_from_folder(path, start_step, time_steps, channels, resolution, exts):
    """Read a clip from a folder of frame images.

    Frames are ordered by natural sort of their file names, so numbering
    without zero padding ("2.png" before "10.png") is read back in the order
    it was written.

    Args:
        path: Folder containing the frame images.
        start_step: Index of the first frame to read.
        time_steps: Number of frames to read; None reads every frame from
            ``start_step`` to the end.
        channels: Channel layout forwarded to ``read_frame``.
        resolution: Optional target size forwarded to ``read_frame``.
        exts: Iterable of file extensions (without the dot) to collect.

    Returns:
        Tensor with the frames stacked along a new first dimension.

    Raises:
        IndexError: If more frames are requested than are available.
    """
    import re  # local import: only needed for the natural-order sort key

    def natural_key(frame_path):
        # re.split with a capturing group alternates text / digit-run chunks,
        # so odd positions are always digit runs and are compared as ints.
        chunks = re.split(r"(\d+)", os.path.basename(frame_path))
        key = [int(chunk) if i % 2 else chunk for i, chunk in enumerate(chunks)]
        # The full path breaks ties such as "1.png" vs "01.png".
        return key, frame_path

    paths = []
    for ext in exts:
        paths += glob(os.path.join(path, f"*.{ext}"))
    paths.sort(key=natural_key)
    if time_steps is None:
        time_steps = len(paths) - start_step
    # Index explicitly (rather than slice) so an out-of-range request fails.
    video = [
        read_frame(paths[step], resolution, channels)
        for step in range(start_step, start_step + time_steps)
    ]
    return torch.stack(video)
def read_frame(path, resolution=None, channels="first"):
    """Read one image as a float32 RGB tensor scaled by 1/255.

    Args:
        path: Path of the image file.
        resolution: Optional output size passed to ``F.interpolate``.
        channels: ``"last"`` returns (H, W, C); any other value returns
            (C, H, W).

    Returns:
        The image as a torch tensor.
    """
    image = Image.open(path).convert('RGB')
    pixels = np.array(image).astype(np.float32) / 255
    # HWC -> CHW, the layout F.interpolate expects once a batch dim is added.
    tensor = torch.from_numpy(pixels).permute(2, 0, 1)
    if resolution is not None:
        resized = F.interpolate(tensor.unsqueeze(0), size=resolution, mode="bilinear")
        tensor = resized[0]
    return tensor.permute(1, 2, 0) if channels == "last" else tensor
def write_video(video, path, channels="first", zero_padded=True, ext="png", dtype="torch"):
    """Write a video to an ``.mp4`` file or to a folder of frame images.

    Args:
        video: The frames; a NumPy array when ``dtype`` is ``"numpy"``,
            otherwise passed through unchanged.
        path: Path ending in ``.mp4``, or otherwise the output folder.
        channels: Channel layout of ``video``, forwarded to the writer.
        zero_padded: Whether frame file names are zero padded (folder only).
        ext: Image extension for frame files (folder only).
        dtype: ``"numpy"`` converts ``video`` with ``torch.from_numpy`` first.
    """
    frames = torch.from_numpy(video) if dtype == "numpy" else video
    if not path.endswith(".mp4"):
        # Anything that is not an mp4 path is treated as an output folder.
        write_video_to_folder(frames, path, channels, zero_padded, ext)
        return
    write_video_to_file(frames, path, channels)
def write_video_to_file(video, path, channels):
    """Encode ``video`` as an h264 file at 24 fps.

    Args:
        video: Tensor of frames; presumably floats in [0, 1] — values are
            multiplied by 255 before the uint8 cast (confirm with callers).
        path: Output file path; its parent folder is created if needed.
        channels: ``"first"`` means (T, C, H, W) input, which is permuted to
            (T, H, W, C) before encoding.

    Returns:
        The uint8 (T, H, W, C) tensor that was handed to the encoder.
    """
    create_folder(os.path.dirname(path))
    if channels == "first":
        video = video.permute(0, 2, 3, 1)
    # Round and clamp before casting: a bare cast truncates (254.9 -> 254) and
    # out-of-range values overflow. This matches what write_frame does.
    video = (video.cpu() * 255.).round().clamp(0, 255).to(torch.uint8)
    torchvision.io.write_video(path, video, 24, "h264", options={"pix_fmt": "yuv420p", "crf": "23"})
    return video
def write_video_to_folder(video, path, channels, zero_padded, ext):
    """Write every frame of ``video`` as an image file inside ``path``.

    Files are named ``<index>.<ext>``. With ``zero_padded`` the index is
    left-padded with zeros to the number of digits of the frame count.

    Args:
        video: Frames indexed along the first dimension.
        path: Output folder, created if needed.
        channels: Channel layout of each frame, forwarded to ``write_frame``.
        zero_padded: Whether to zero pad the frame index in file names.
        ext: Image file extension (without the dot).
    """
    create_folder(path)
    num_frames = video.shape[0]
    # A width of 0 makes zfill a no-op, i.e. no padding at all.
    width = len(str(num_frames)) if zero_padded else 0
    for index in range(num_frames):
        file_name = f"{str(index).zfill(width)}.{ext}"
        write_frame(video[index], os.path.join(path, file_name), channels)
def write_frame(frame, path, channels="first"):
    """Save a single frame tensor as an image file.

    Args:
        frame: Tensor holding the frame; values are multiplied by 255,
            rounded and clipped to [0, 255] before saving.
        path: Output image path; its parent folder is created if needed.
        channels: ``"first"`` means (C, H, W) input, which is transposed to
            (H, W, C) before saving.
    """
    create_folder(os.path.dirname(path))
    pixels = frame.cpu().numpy()
    if channels == "first":
        pixels = pixels.transpose(1, 2, 0)
    pixels = np.round(pixels * 255).clip(0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(path)
def read_tracks(path):
    """Load tracks from ``path`` with ``np.load`` and return them."""
    tracks = np.load(path)
    return tracks
def write_tracks(tracks, path):
    """Store ``tracks`` at ``path`` with ``np.save``."""
    np.save(file=path, arr=tracks)
def read_config(path):
    """Load a JSON file and expose its top-level keys as attributes.

    Args:
        path: Path of the JSON configuration file.

    Returns:
        An ``argparse.Namespace`` built from the decoded JSON object.
    """
    with open(path, 'r') as handle:
        return argparse.Namespace(**json.load(handle))