import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm

from .shelf import RAFT
from .interpolation import interpolate
from dot.utils.io import read_config
from dot.utils.torch import get_grid, get_sobel_kernel


class OpticalFlow(nn.Module):
    def __init__(self, height, width, config, load_path):
        super().__init__()
        model_args = read_config(config)
        model_dict = {"raft": RAFT}
        self.model = model_dict[model_args.name](model_args)
        self.name = model_args.name
        if load_path is not None:
            device = next(self.model.parameters()).device
            self.model.load_state_dict(torch.load(load_path, map_location=device))
        coarse_height, coarse_width = height // model_args.patch_size, width // model_args.patch_size
        self.register_buffer("coarse_grid", get_grid(coarse_height, coarse_width))

    def forward(self, data, mode, **kwargs):
        if mode == "flow_with_tracks_init":
            return self.get_flow_with_tracks_init(data, **kwargs)
        elif mode == "motion_boundaries":
            return self.get_motion_boundaries(data, **kwargs)
        elif mode == "feats":
            return self.get_feats(data, **kwargs)
        elif mode == "tracks_for_queries":
            return self.get_tracks_for_queries(data, **kwargs)
        elif mode == "tracks_from_first_to_every_other_frame":
            return self.get_tracks_from_first_to_every_other_frame(data, **kwargs)
        elif mode == "flow_from_last_to_first_frame":
            return self.get_flow_from_last_to_first_frame(data, **kwargs)
        else:
            raise ValueError(f"Unknown mode {mode}")

    def get_motion_boundaries(self, data, boundaries_size=1, boundaries_dilation=4, boundaries_thresh=0.025, **kwargs):
        eps = 1e-12
        src_frame, tgt_frame = data["src_frame"], data["tgt_frame"]
        K = boundaries_size * 2 + 1
        D = boundaries_dilation
        B, _, H, W = src_frame.shape
        reflect = torch.nn.ReflectionPad2d(K // 2)
        sobel_kernel = get_sobel_kernel(K).to(src_frame.device)
        flow, _ = self.model(src_frame, tgt_frame)
        # Normalize flow to the frame size so gradient magnitudes are resolution-independent
        norm_flow = torch.stack([flow[..., 0] / (W - 1), flow[..., 1] / (H - 1)], dim=-1)
        norm_flow = norm_flow.permute(0, 3, 1, 2).reshape(-1, 1, H, W)
        # Spatial gradients of the flow field: large magnitudes indicate motion boundaries
        boundaries = F.conv2d(reflect(norm_flow), sobel_kernel)
        boundaries = ((boundaries ** 2).sum(dim=1, keepdim=True) + eps).sqrt()
        boundaries = boundaries.view(-1, 2, H, W).mean(dim=1, keepdim=True)
        if boundaries_dilation > 1:
            # Thicken the boundaries with max pooling, then crop back to the input size
            boundaries = torch.nn.functional.max_pool2d(boundaries, kernel_size=D * 2, stride=1, padding=D)
            boundaries = boundaries[:, :, -H:, -W:]
        boundaries = boundaries[:, 0]
        # Min-max normalize per batch element before thresholding
        boundaries = boundaries - boundaries.reshape(B, -1).min(dim=1)[0].reshape(B, 1, 1)
        boundaries = boundaries / boundaries.reshape(B, -1).max(dim=1)[0].reshape(B, 1, 1)
        boundaries = boundaries > boundaries_thresh
        return {"motion_boundaries": boundaries, "flow": flow}

    def get_feats(self, data, **kwargs):
        video = data["video"]
        feats = []
        for step in tqdm(range(video.size(1)), desc="Extract feats for frame", leave=False):
            feats.append(self.model.encode(video[:, step]))
        feats = torch.stack(feats, dim=1)
        return {"feats": feats}

    def get_flow_with_tracks_init(self, data, is_train=False, interpolation_version="torch3d", alpha_thresh=0.8, **kwargs):
        # Initialize coarse flow / visibility estimates by interpolating sparse track correspondences
        coarse_flow, coarse_alpha = interpolate(data["src_points"], data["tgt_points"], self.coarse_grid,
                                                version=interpolation_version)
        # Refine with the flow model; reuse precomputed features when available
        flow, alpha = self.model(src_frame=data["src_frame"] if "src_feats" not in data else None,
                                 tgt_frame=data["tgt_frame"] if "tgt_feats" not in data else None,
                                 src_feats=data["src_feats"] if "src_feats" in data else None,
                                 tgt_feats=data["tgt_feats"] if "tgt_feats" in data else None,
                                 coarse_flow=coarse_flow,
                                 coarse_alpha=coarse_alpha,
                                 is_train=is_train)
        if not is_train:
            # Binarize the visibility mask at inference time
            alpha = (alpha > alpha_thresh).float()
        return {"flow": flow, "alpha": alpha, "coarse_flow": coarse_flow, "coarse_alpha": coarse_alpha}

    def get_tracks_for_queries(self, data, **kwargs):
        raise NotImplementedError
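

if __name__ == "__main__":
    # Minimal usage sketch of the mode-based forward() dispatch, for illustration only.
    # It assumes a RAFT config file exists at the (hypothetical) path below; pass a
    # checkpoint path as load_path to use pretrained weights instead of random ones.
    model = OpticalFlow(height=512, width=512,
                        config="configs/raft.json",  # hypothetical config path
                        load_path=None)              # or a path to pretrained RAFT weights
    data = {"src_frame": torch.zeros(1, 3, 512, 512),
            "tgt_frame": torch.zeros(1, 3, 512, 512)}
    with torch.no_grad():
        out = model(data, mode="motion_boundaries")
    # flow: (B, H, W, 2); motion_boundaries: (B, H, W) boolean mask
    print(out["flow"].shape, out["motion_boundaries"].shape)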