| import os | |
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | |
| import torch | |
| from torchvision.utils import make_grid | |
| from torch.utils.tensorboard import SummaryWriter | |
| from dot.utils.plot import to_rgb | |
| def detach(tensor): | |
| if isinstance(tensor, torch.Tensor): | |
| return tensor.detach().cpu() | |
| return tensor | |
| def number(tensor): | |
| if isinstance(tensor, torch.Tensor) and tensor.isnan().any(): | |
| return torch.zeros_like(tensor) | |
| return tensor | |
| class Logger(): | |
| def __init__(self, args): | |
| self.writer = SummaryWriter(args.log_path) | |
| self.factor = args.log_factor | |
| self.world_size = args.world_size | |
| def log_scalar(self, name, scalar, global_iter): | |
| if scalar is not None: | |
| if type(scalar) == list: | |
| for i, x in enumerate(scalar): | |
| self.log_scalar(f"{name}_{i}", x, global_iter) | |
| else: | |
| self.writer.add_scalar(name, number(detach(scalar)), global_iter) | |
| def log_scalars(self, name, scalars, global_iter): | |
| for s in scalars: | |
| self.log_scalar(f"{name}/{s}", scalars[s], global_iter) | |
| def log_image(self, name, tensor, mode, nrow, global_iter, pos=None, occ=None): | |
| tensor = detach(tensor) | |
| tensor = to_rgb(tensor, mode, pos, occ) | |
| grid = make_grid(tensor, nrow=nrow, normalize=False, value_range=[0, 1], pad_value=0) | |
| grid = torch.nn.functional.interpolate(grid[None], scale_factor=self.factor)[0] | |
| self.writer.add_image(name, grid, global_iter) | |
| def log_video(self, name, tensor, mode, nrow, global_iter, fps=4, pos=None, occ=None): | |
| tensor = detach(tensor) | |
| tensor = to_rgb(tensor, mode, pos, occ, is_video=True) | |
| grid = [] | |
| for i in range(tensor.shape[1]): | |
| grid.append(make_grid(tensor[:, i], nrow=nrow, normalize=False, value_range=[0, 1], pad_value=0)) | |
| grid = torch.stack(grid, dim=0) | |
| grid = torch.nn.functional.interpolate(grid, scale_factor=self.factor) | |
| grid = grid[None] | |
| self.writer.add_video(name, grid, global_iter, fps=fps) |