|
|
import os.path as osp
|
|
|
import matplotlib
|
|
|
import matplotlib.pyplot as plt
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
|
|
|
from dot.utils.io import create_folder
|
|
|
|
|
|
|
|
|
def to_rgb(tensor, mode, tracks=None, is_video=False, to_torch=True, reshape_as_video=False):
    """Convert a flow, mask or image tensor into RGB frames, optionally overlaying tracks.

    Args:
        tensor: A torch tensor, or a list of torch tensors (stacked along a new
            leading axis). For ``mode == "flow"`` the last three axes are read as
            (height, width, 2); for every other mode the last two axes are read
            as (height, width).
        mode: ``"flow"`` colorizes the input with ``flow_to_rgb``; ``"mask"``
            repeats a single channel three times; any other value treats the
            input as already holding three channels.
        tracks: Optional torch tensor whose second-to-last axis indexes samples
            and whose last axis has three entries. The first two entries are used
            as the point position and ``1 - third`` as the occlusion value
            (the third entry is presumably a visibility flag -- confirm with callers).
        is_video: When True, the first two axes of ``tensor`` are read as
            (batch, time) and tracks are drawn frame by frame together with the
            trajectory accumulated up to that frame.
        to_torch: When True, return a torch tensor; otherwise a numpy array.
        reshape_as_video: When True together with ``is_video``, the result is
            shaped (batch, time, 3, height, width); otherwise frames are
            flattened to (-1, 3, height, width).

    Returns:
        The RGB frames as a torch tensor or numpy array (see ``to_torch``).
    """
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    # detach() first: numpy() raises on tensors that still require grad.
    tensor = tensor.detach().cpu().numpy()
    if is_video:
        batch_size, time_steps = tensor.shape[:2]

    # Bring every supported input to a flat stack of (3, height, width) frames.
    if mode == "flow":
        height, width = tensor.shape[-3: -1]
        tensor = np.reshape(tensor, (-1, height, width, 2))
        tensor = flow_to_rgb(tensor)
    elif mode == "mask":
        height, width = tensor.shape[-2:]
        tensor = np.reshape(tensor, (-1, 1, height, width))
        tensor = np.repeat(tensor, 3, axis=1)
    else:
        height, width = tensor.shape[-2:]
        tensor = np.reshape(tensor, (-1, 3, height, width))

    if tracks is not None:
        samples = tracks.size(-2)
        tracks = tracks.detach().cpu().numpy()
        tracks = np.reshape(tracks, (-1, samples, 3))
        traj, occ = tracks[..., :2], 1 - tracks[..., 2]
        if is_video:
            tensor = np.reshape(tensor, (-1, time_steps, 3, height, width))
            traj = np.reshape(traj, (-1, time_steps, samples, 2))
            occ = np.reshape(occ, (-1, time_steps, samples))
            new_tensor = []
            for t in range(time_steps):
                pos_t = traj[:, t]
                occ_t = occ[:, t]
                # Pass the trajectory up to and including frame t so the path so far is drawn.
                new_tensor.append(plot_tracks(tensor[:, t], pos_t, occ_t, tracks=traj[:, :t + 1]))
            tensor = np.stack(new_tensor, axis=1)
        else:
            tensor = plot_tracks(tensor, traj, occ)

    if is_video and reshape_as_video:
        tensor = np.reshape(tensor, (batch_size, time_steps, 3, height, width))
    else:
        tensor = np.reshape(tensor, (-1, 3, height, width))
    if to_torch:
        tensor = torch.from_numpy(tensor)
    return tensor
|
|
|
|
|
|
|
|
|
def flow_to_rgb(flow, transparent=False):
    """Colorize a flow field through the HSV color space.

    The hue encodes the flow direction and the value encodes the flow
    magnitude, scaled by ``20 / image diagonal`` and clipped to [0, 1].

    Args:
        flow: Array whose last three axes are (height, width, components);
            components 0 and 1 give the direction, the norm is taken over the
            whole last axis.
        transparent: When True, colors are rendered at full brightness and the
            clipped magnitude is appended as a fourth channel.

    Returns:
        Array with the color channel moved to axis -3: three channels, or four
        when ``transparent`` is set.
    """
    field = np.copy(flow)
    rows, cols = field.shape[-3: -1]

    # Normalize magnitudes relative to the image diagonal.
    diagonal = (rows ** 2 + cols ** 2) ** 0.5
    gain = 20. / diagonal

    # Map the angle from [-pi, pi] onto a hue in [0, 1].
    hue = (np.arctan2(field[..., 0], field[..., 1]) + np.pi) / (2 * np.pi)
    strength = np.clip(np.linalg.norm(field, axis=-1) * gain, 0., 1.)
    full = np.ones_like(hue)

    brightness = np.ones_like(strength) if transparent else strength
    hsv = np.stack([hue, full, brightness], axis=-1)
    colored = np.moveaxis(matplotlib.colors.hsv_to_rgb(hsv), -1, -3)

    if not transparent:
        return colored
    alpha = np.expand_dims(strength, axis=-3)
    return np.concatenate([colored, alpha], axis=-3)
|
|
|
|
|
|
|
|
|
def plot_tracks(rgb, points, occluded, tracks=None, trackgroup=None):
    """Plot tracks with matplotlib.

    Adapted from: https://github.com/google-research/kubric/blob/main/challenges/point_tracking/dataset.py

    Args:
        rgb: Array indexed as (frame, channel, height, width); it is transposed
            to channel-last before being shown.
        points: Array indexed as (frame, point, 2) holding x at index 0 and y at
            index 1, in pixel units. Points on or outside the image border are
            not drawn.
        occluded: Optional array indexed as (frame, point). ``1 - occluded`` is
            used as the alpha of the small filled dots and ``occluded`` as the
            alpha of the larger hollow rings.
        tracks: Optional array indexed as (frame, step, point, 2); when given,
            one polyline per point is drawn through its in-bounds steps.
        trackgroup: Optional per-point integer group ids; points with the same
            id share a color. Defaults to one group per point.

    Returns:
        Float array indexed as (frame, 3, height, width) with values in [0, 1],
        read back from the rendered canvases.
    """
    rgb = rgb.transpose(0, 2, 3, 1)
    _, height, width, _ = rgb.shape
    points = points.transpose(1, 0, 2).copy()  # -> (point, frame, 2)

    if tracks is not None:
        tracks = tracks.copy()

    if occluded is not None:
        occluded = occluded.transpose(1, 0)  # -> (point, frame)

    disp = []
    cmap = plt.cm.hsv

    z_list = np.arange(points.shape[0]) if trackgroup is None else np.array(trackgroup)

    # Shuffle the color assignment with a fixed seed so colors are stable across
    # calls. A local RandomState(0) gives the same permutation as seeding the
    # global generator, without resetting the caller's global numpy RNG state.
    rng = np.random.RandomState(0)
    z_list = rng.permutation(np.max(z_list) + 1)[z_list]
    colors = cmap(z_list / (np.max(z_list) + 1))
    figure_dpi = 64

    # Largest coordinates still considered inside the image.
    max_x = width - 1
    max_y = height - 1

    for i in range(rgb.shape[0]):
        fig = plt.figure(
            figsize=(width / figure_dpi, height / figure_dpi),
            dpi=figure_dpi,
            frameon=False,
            facecolor='w')
        ax = fig.add_subplot()
        ax.axis('off')
        ax.imshow(rgb[i])

        valid = points[:, i, 0] > 0
        valid = np.logical_and(valid, points[:, i, 0] < max_x)
        valid = np.logical_and(valid, points[:, i, 1] > 0)
        valid = np.logical_and(valid, points[:, i, 1] < max_y)

        if occluded is not None:
            colalpha = np.concatenate([colors[:, :-1], 1 - occluded[:, i:i + 1]], axis=1)
        else:
            colalpha = colors[:, :-1]

        ax.scatter(
            points[valid, i, 0] - 0.5,
            points[valid, i, 1] - 0.5,
            s=3,
            c=colalpha[valid],
        )

        if tracks is not None:
            for j in range(tracks.shape[2]):
                track_color = colors[j]
                x = tracks[i, :, j, 0]
                y = tracks[i, :, j, 1]
                valid_track = x > 0
                valid_track = np.logical_and(valid_track, x < max_x)
                valid_track = np.logical_and(valid_track, y > 0)
                valid_track = np.logical_and(valid_track, y < max_y)
                ax.plot(x[valid_track] - 0.5, y[valid_track] - 0.5, color=track_color, marker=None)

        if occluded is not None:
            occ2 = occluded[:, i:i + 1]

            colalpha = np.concatenate([colors[:, :-1], occ2], axis=1)

            ax.scatter(
                points[valid, i, 0],
                points[valid, i, 1],
                s=20,
                facecolors='none',
                edgecolors=colalpha[valid],
            )

        # Remove all padding so the canvas holds only the image.
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)
        fig.canvas.draw()

        # buffer_rgba() replaces tostring_rgb(), which newer Matplotlib releases
        # no longer provide; drop the alpha channel to keep an RGB image.
        img = np.asarray(fig.canvas.buffer_rgba())[..., :3]
        disp.append(np.copy(img))

        # Close only the figure created here, leaving the caller's figures open.
        plt.close(fig)

    return np.stack(disp, axis=0).astype(float).transpose(0, 3, 1, 2) / 255
|
|
|
|
|
|
|
|
|
def plot_points(src_frame, tgt_frame, src_points, tgt_points, save_path, max_points=256):
    """Save a side-by-side figure of two frames with lines joining sampled point pairs.

    Args:
        src_frame: Torch tensor indexed as (channel, height, width), shown on the left.
        tgt_frame: Torch tensor indexed as (channel, height, width), shown on the right.
        src_points: Torch tensor whose last axis holds (x, y, flag); x and y are
            multiplied by (W - 1) and (H - 1), so they are presumably normalized
            to [0, 1] -- confirm with callers.
        tgt_points: Same layout as ``src_points``, for the target frame.
        save_path: Output file path; its parent folder is created first.
        max_points: Upper bound on the number of randomly sampled pairs drawn.

    A point whose flag equals 1 is drawn as a black cross; any other point is
    drawn as a black circle with a white center.
    """
    _, H, W = src_frame.shape
    left_img = src_frame.permute(1, 2, 0).cpu().numpy()
    right_img = tgt_frame.permute(1, 2, 0).cpu().numpy()
    src_arr = src_points.cpu().numpy()
    tgt_arr = tgt_points.cpu().numpy()

    def to_pixels(xy):
        # Scale the two coordinates to pixel units of a W x H frame.
        return np.stack([xy[..., 0] * (W - 1), xy[..., 1] * (H - 1)], axis=-1)

    src_alpha = src_arr[..., 2]
    tgt_alpha = tgt_arr[..., 2]
    src_pos = to_pixels(src_arr[..., :2])
    tgt_pos = to_pixels(tgt_arr[..., :2])

    plt.figure()
    ax = plt.gca()
    P = 10  # width in pixels of the separator strip between the two frames
    separator = np.ones_like(left_img[:, :P])
    plt.imshow(np.concatenate((left_img, separator, right_img), axis=1))

    def draw_marker(x, y, flag):
        # Cross when the flag equals 1, otherwise a black circle with a white center.
        if flag == 1:
            ax.scatter(x, y, s=5, c="black", marker='x')
            return
        ax.scatter(x, y, s=5, linewidths=1.5, c="black", marker='o')
        ax.scatter(x, y, s=2.5, c="white", marker='o')

    num_points = len(src_pos)
    indices = np.random.choice(num_points, size=min(max_points, num_points), replace=False)
    for i in indices:
        src_x, src_y = src_pos[i, 0], src_pos[i, 1]
        # Target coordinates are shifted past the source frame and the separator.
        tgt_x, tgt_y = W + P + tgt_pos[i, 0], tgt_pos[i, 1]

        draw_marker(src_x, src_y, src_alpha[i])
        draw_marker(tgt_x, tgt_y, tgt_alpha[i])
        ax.plot([src_x, tgt_x], [src_y, tgt_y], linewidth=0.5, c="black")

    ax.axis('off')
    plt.tight_layout()
    plt.subplots_adjust(wspace=0)
    create_folder(osp.dirname(save_path))
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close()