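"""Optical flow estimation module for the `dot` package.

Wraps a flow backbone (currently RAFT) behind a single nn.Module whose
forward pass dispatches on a `mode` string, e.g. "motion_boundaries",
"feats" or "flow_with_tracks_init".
"""
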
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm

from .shelf import RAFT
from .interpolation import interpolate
from dot.utils.io import read_config
from dot.utils.torch import get_grid, get_sobel_kernel


class OpticalFlow(nn.Module):
    def __init__(self, height, width, config, load_path):
        super().__init__()
        model_args = read_config(config)
        model_dict = {"raft": RAFT}
        self.model = model_dict[model_args.name](model_args)
        self.name = model_args.name
        if load_path is not None:
            device = next(self.model.parameters()).device
            self.model.load_state_dict(torch.load(load_path, map_location=device))
        # Regular grid of coarse coordinates, registered as a buffer so it
        # moves with the module across devices.
        coarse_height, coarse_width = height // model_args.patch_size, width // model_args.patch_size
        self.register_buffer("coarse_grid", get_grid(coarse_height, coarse_width))

    def forward(self, data, mode, **kwargs):
        if mode == "flow_with_tracks_init":
            return self.get_flow_with_tracks_init(data, **kwargs)
        elif mode == "motion_boundaries":
            return self.get_motion_boundaries(data, **kwargs)
        elif mode == "feats":
            return self.get_feats(data, **kwargs)
        elif mode == "tracks_for_queries":
            return self.get_tracks_for_queries(data, **kwargs)
        elif mode == "tracks_from_first_to_every_other_frame":
            return self.get_tracks_from_first_to_every_other_frame(data, **kwargs)
        elif mode == "flow_from_last_to_first_frame":
            return self.get_flow_from_last_to_first_frame(data, **kwargs)
        else:
            raise ValueError(f"Unknown mode {mode}")

    def get_motion_boundaries(self, data, boundaries_size=1, boundaries_dilation=4, boundaries_thresh=0.025, **kwargs):
        eps = 1e-12
        src_frame, tgt_frame = data["src_frame"], data["tgt_frame"]
        K = boundaries_size * 2 + 1
        D = boundaries_dilation
        B, _, H, W = src_frame.shape
        reflect = torch.nn.ReflectionPad2d(K // 2)
        sobel_kernel = get_sobel_kernel(K).to(src_frame.device)
        flow, _ = self.model(src_frame, tgt_frame)
        # Normalize flow by the image dimensions so that gradient magnitudes
        # are resolution-independent.
        norm_flow = torch.stack([flow[..., 0] / (W - 1), flow[..., 1] / (H - 1)], dim=-1)
        norm_flow = norm_flow.permute(0, 3, 1, 2).reshape(-1, 1, H, W)
        # Spatial gradients of the flow: high magnitude marks motion boundaries.
        boundaries = F.conv2d(reflect(norm_flow), sobel_kernel)
        boundaries = ((boundaries ** 2).sum(dim=1, keepdim=True) + eps).sqrt()
        boundaries = boundaries.view(-1, 2, H, W).mean(dim=1, keepdim=True)
        # Thicken the boundaries via max-pooling, then crop back to H x W.
        if boundaries_dilation > 1:
            boundaries = torch.nn.functional.max_pool2d(boundaries, kernel_size=D * 2, stride=1, padding=D)
            boundaries = boundaries[:, :, -H:, -W:]
        boundaries = boundaries[:, 0]
        # Min-max normalize per sample, then binarize at the threshold.
        boundaries = boundaries - boundaries.reshape(B, -1).min(dim=1)[0].reshape(B, 1, 1)
        boundaries = boundaries / boundaries.reshape(B, -1).max(dim=1)[0].reshape(B, 1, 1)
        boundaries = boundaries > boundaries_thresh
        return {"motion_boundaries": boundaries, "flow": flow}

    def get_feats(self, data, **kwargs):
        video = data["video"]
        feats = []
        for step in tqdm(range(video.size(1)), desc="Extract feats for frame", leave=False):
            feats.append(self.model.encode(video[:, step]))
        feats = torch.stack(feats, dim=1)
        return {"feats": feats}

    def get_flow_with_tracks_init(self, data, is_train=False, interpolation_version="torch3d", alpha_thresh=0.8, **kwargs):
        # Interpolate sparse track correspondences onto the coarse grid to
        # initialize the flow estimation.
        coarse_flow, coarse_alpha = interpolate(data["src_points"], data["tgt_points"], self.coarse_grid,
                                                version=interpolation_version)
        # Reuse precomputed features when available, otherwise pass the frames.
        flow, alpha = self.model(src_frame=data["src_frame"] if "src_feats" not in data else None,
                                 tgt_frame=data["tgt_frame"] if "tgt_feats" not in data else None,
                                 src_feats=data["src_feats"] if "src_feats" in data else None,
                                 tgt_feats=data["tgt_feats"] if "tgt_feats" in data else None,
                                 coarse_flow=coarse_flow,
                                 coarse_alpha=coarse_alpha,
                                 is_train=is_train)
        if not is_train:
            # Binarize the visibility map at inference time.
            alpha = (alpha > alpha_thresh).float()
        return {"flow": flow, "alpha": alpha, "coarse_flow": coarse_flow, "coarse_alpha": coarse_alpha}

    def get_tracks_for_queries(self, data, **kwargs):
        raise NotImplementedError
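

# Minimal usage sketch for the mode-dispatching forward pass. The config and
# checkpoint paths below are hypothetical placeholders, and frames are assumed
# to be float tensors of shape (B, 3, H, W) matching the height/width the
# module was built with:
#
#   model = OpticalFlow(height=512, width=512,
#                       config="path/to/raft_config.yaml",
#                       load_path="path/to/raft_checkpoint.pth")
#   out = model({"src_frame": src, "tgt_frame": tgt}, mode="motion_boundaries")
#   out["flow"]               # dense flow from src_frame to tgt_frame
#   out["motion_boundaries"]  # boolean mask of flow discontinuities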