| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from tqdm import tqdm | |
| from .shelf import RAFT | |
| from .interpolation import interpolate | |
| from dot.utils.io import read_config | |
| from dot.utils.torch import get_grid, get_sobel_kernel | |
| class OpticalFlow(nn.Module): | |
def __init__(self, height, width, config, load_path):
    """Build the flow backbone, optionally load weights, and cache a coarse grid.

    Args:
        height: Value floor-divided by the configured ``patch_size`` to get
            the coarse grid height.
        width: Value floor-divided by the configured ``patch_size`` to get
            the coarse grid width.
        config: Passed unchanged to ``read_config``; the result must expose
            ``name`` and ``patch_size``.
        load_path: Checkpoint location handed to ``torch.load``, or ``None``
            to keep the freshly constructed weights.

    Raises:
        KeyError: If the configured ``name`` is not a registered backbone.
    """
    super().__init__()
    args = read_config(config)

    # Registry of supported backbones, keyed by the config's ``name`` field.
    backbones = {"raft": RAFT}
    self.model = backbones[args.name](args)
    self.name = args.name

    if load_path is not None:
        # Map the checkpoint onto whichever device the parameters already use.
        # NOTE(review): depending on the torch version, torch.load may unpickle
        # arbitrary objects -- only point this at trusted checkpoints.
        target_device = next(self.model.parameters()).device
        state = torch.load(load_path, map_location=target_device)
        self.model.load_state_dict(state)

    # Registered as a buffer so the grid follows the module across devices.
    patch = args.patch_size
    grid = get_grid(height // patch, width // patch)
    self.register_buffer("coarse_grid", grid)
def forward(self, data, mode, **kwargs):
    """Route ``data`` to the ``get_<mode>`` method selected by ``mode``.

    Args:
        data: Input forwarded untouched to the selected method.
        mode: One of ``"flow_with_tracks_init"``, ``"motion_boundaries"``,
            ``"feats"``, ``"tracks_for_queries"``,
            ``"tracks_from_first_to_every_other_frame"`` or
            ``"flow_from_last_to_first_frame"``.
        **kwargs: Extra keyword arguments forwarded to the selected method.

    Returns:
        Whatever the selected method returns.

    Raises:
        ValueError: If ``mode`` matches none of the names above.
    """
    supported_modes = (
        "flow_with_tracks_init",
        "motion_boundaries",
        "feats",
        "tracks_for_queries",
        "tracks_from_first_to_every_other_frame",
        "flow_from_last_to_first_frame",
    )
    for candidate in supported_modes:
        if mode == candidate:
            # Resolve the handler lazily so that only the requested method
            # needs to exist on the instance.
            handler = getattr(self, "get_" + candidate)
            return handler(data, **kwargs)
    raise ValueError(f"Unknown mode {mode}")
def get_motion_boundaries(self, data, boundaries_size=1, boundaries_dilation=4, boundaries_thresh=0.025, **kwargs):
    """Threshold a normalized flow-gradient map into a boolean boundary mask.

    The flow predicted by ``self.model`` for the source/target pair is
    rescaled per component, filtered with the kernel from
    ``get_sobel_kernel``, optionally dilated by max-pooling, min-max
    normalized per sample and finally thresholded.

    Args:
        data: Mapping with ``"src_frame"`` and ``"tgt_frame"``. The source
            frame must be 4-D; its last two axes are used as height and width.
        boundaries_size: Filter half-size; the kernel size is ``2 * size + 1``.
        boundaries_dilation: Max-pool padding (the pooling window is twice
            this value). Dilation is skipped unless this is greater than 1.
        boundaries_thresh: Cut-off applied to the normalized map.
        **kwargs: Accepted for interface parity; unused.

    Returns:
        dict with ``"motion_boundaries"`` (boolean, one ``(H, W)`` map per
        batch sample) and ``"flow"`` (the model's flow output, unmodified).
    """
    source, target = data["src_frame"], data["tgt_frame"]
    batch, _, height, width = source.shape
    kernel_size = 2 * boundaries_size + 1
    pad = kernel_size // 2
    tiny = 1e-12  # keeps the square root away from an exact zero

    kernel = get_sobel_kernel(kernel_size).to(source.device)
    predicted_flow, _ = self.model(source, target)

    # Component 0 is scaled by the width extent and component 1 by the height
    # extent, so both are expressed relative to the image size.
    flow_0 = predicted_flow[..., 0] / (width - 1)
    flow_1 = predicted_flow[..., 1] / (height - 1)
    scaled = torch.stack([flow_0, flow_1], dim=-1)

    # Fold the two components into the batch axis so that each one is
    # filtered as an independent single-channel image.
    planes = scaled.permute(0, 3, 1, 2).reshape(-1, 1, height, width)

    # Reflection padding keeps the filtered map at the input resolution.
    padded = F.pad(planes, (pad, pad, pad, pad), mode="reflect")
    response = F.conv2d(padded, kernel)

    # Collapse the kernel's output channels into one magnitude, then average
    # the magnitudes of the two flow components.
    magnitude = torch.sqrt(response.pow(2).sum(dim=1, keepdim=True) + tiny)
    edge_map = magnitude.view(-1, 2, height, width).mean(dim=1, keepdim=True)

    if boundaries_dilation > 1:
        # The even-sized window yields one extra row/column; keep the
        # trailing H x W region.
        pooled = F.max_pool2d(
            edge_map,
            kernel_size=boundaries_dilation * 2,
            stride=1,
            padding=boundaries_dilation,
        )
        edge_map = pooled[:, :, -height:, -width:]

    edge_map = edge_map[:, 0]

    # Per-sample min-max normalization.
    # NOTE(review): a constant map gives 0 / 0 = NaN here; NaN compares False
    # against the threshold, so such a sample yields an all-False mask.
    lowest = edge_map.reshape(batch, -1).min(dim=1).values
    edge_map = edge_map - lowest.reshape(batch, 1, 1)
    highest = edge_map.reshape(batch, -1).max(dim=1).values
    edge_map = edge_map / highest.reshape(batch, 1, 1)

    mask = edge_map > boundaries_thresh
    return {"motion_boundaries": mask, "flow": predicted_flow}
def get_feats(self, data, **kwargs):
    """Encode every frame of ``data["video"]`` with the model's encoder.

    Args:
        data: Mapping with a ``"video"`` tensor whose axis 1 indexes frames.
        **kwargs: Accepted for interface parity; unused.

    Returns:
        dict with ``"feats"``: the per-frame ``self.model.encode`` outputs
        stacked along axis 1.
    """
    clip = data["video"]
    # Transient progress bar (leave=False) over the frame axis.
    frame_indices = tqdm(range(clip.size(1)), desc="Extract feats for frame", leave=False)
    per_frame = [self.model.encode(clip[:, index]) for index in frame_indices]
    return {"feats": torch.stack(per_frame, dim=1)}
def get_flow_with_tracks_init(self, data, is_train=False, interpolation_version="torch3d", alpha_thresh=0.8, **kwargs):
    """Run the model with a coarse flow interpolated from point matches.

    Args:
        data: Mapping with ``"src_points"`` and ``"tgt_points"``, plus for
            each side either a frame (``"src_frame"`` / ``"tgt_frame"``) or
            features (``"src_feats"`` / ``"tgt_feats"``).
        is_train: Forwarded to the model; when false the returned alpha is
            binarized with ``alpha_thresh``.
        interpolation_version: Passed to ``interpolate`` as ``version``.
        alpha_thresh: Cut-off used to binarize alpha outside training.
        **kwargs: Accepted for interface parity; unused.

    Returns:
        dict with ``"flow"``, ``"alpha"``, ``"coarse_flow"`` and
        ``"coarse_alpha"``.
    """
    coarse_flow, coarse_alpha = interpolate(
        data["src_points"],
        data["tgt_points"],
        self.coarse_grid,
        version=interpolation_version,
    )

    # Features take precedence: when a side supplies "*_feats", its raw frame
    # is withheld from the model (passed as None).
    src_has_feats = "src_feats" in data
    tgt_has_feats = "tgt_feats" in data
    src_frame = None if src_has_feats else data["src_frame"]
    tgt_frame = None if tgt_has_feats else data["tgt_frame"]
    src_feats = data["src_feats"] if src_has_feats else None
    tgt_feats = data["tgt_feats"] if tgt_has_feats else None

    flow, alpha = self.model(
        src_frame=src_frame,
        tgt_frame=tgt_frame,
        src_feats=src_feats,
        tgt_feats=tgt_feats,
        coarse_flow=coarse_flow,
        coarse_alpha=coarse_alpha,
        is_train=is_train,
    )

    if not is_train:
        # Outside training, turn alpha into a hard 0/1 float mask.
        alpha = (alpha > alpha_thresh).float()

    return {
        "flow": flow,
        "alpha": alpha,
        "coarse_flow": coarse_flow,
        "coarse_alpha": coarse_alpha,
    }
def get_tracks_for_queries(self, data, **kwargs):
    """Placeholder for the ``"tracks_for_queries"`` mode.

    Args:
        data: Ignored.
        **kwargs: Ignored.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()