Spaces:
Build error
Build error
| import numpy as np | |
| import torch | |
| import imageio | |
| from my.utils.tqdm import tqdm | |
| from my.utils.event import EventStorage, read_stats, get_event_storage | |
| from my.utils.heartbeat import HeartBeat, get_heartbeat | |
| from my.utils.debug import EarlyLoopBreak | |
| from .utils import PSNR, Scrambler, every, at | |
| from .data import load_blender | |
| from .render import ( | |
| as_torch_tsrs, scene_box_filter, render_ray_bundle, render_one_view, rays_from_img | |
| ) | |
| from .vis import vis, stitch_vis | |
# Device used for all training and evaluation in this module. Prefer CUDA, but
# fall back to the CPU so the module still runs on machines without a GPU
# (moving a model to an unavailable CUDA device raises at `.to()` time).
device_glb = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def all_train_rays(scene):
    """Gather the rays and target pixels of every training image of `scene`.

    Loads the "train" split with `load_blender`, calls `rays_from_img` once
    per image, and flattens each image into rows of 3 values.

    Returns:
        (ro, rd, rgbs): three arrays, each the concatenation along axis 0 of
        the per-image results. `ro` and `rd` are the two outputs of
        `rays_from_img` (presumably ray origins and directions -- confirm
        against `.render`); `rgbs` holds the flattened image pixels.
    """
    imgs, K, poses = load_blender("train", scene)

    origins, directions, colours = [], [], []
    for idx in tqdm(range(len(imgs))):
        frame, cam_pose = imgs[idx], poses[idx]
        height, width = frame.shape[:2]
        frame_ro, frame_rd = rays_from_img(height, width, K, cam_pose)
        origins.append(frame_ro)
        directions.append(frame_rd)
        colours.append(frame.reshape(-1, 3))

    # Merge the per-image chunks into one flat array per quantity.
    ro = np.concatenate(origins, axis=0)
    rd = np.concatenate(directions, axis=0)
    rgbs = np.concatenate(colours, axis=0)
    return ro, rd, rgbs
class OneTestView:
    """Round-robin renderer over the "test" split of a scene.

    Each call to `render` renders the view at the current cursor and then
    advances the cursor, wrapping back to the first view after the last one.
    """

    def __init__(self, scene):
        """Load the test images, `K` and poses via `load_blender`; start at view 0."""
        self.imgs, self.K, self.poses = load_blender("test", scene)
        self.i = 0  # index of the view that the next `render` call will use

    def render(self, model):
        """Render the current test view with `model` and move to the next view.

        Returns:
            (img, rgbs, depth, psnr): the reference image, the two outputs of
            `render_one_view`, and `PSNR.psnr(img, rgbs)`.
        """
        idx = self.i
        gt_img, cam_pose = self.imgs[idx], self.poses[idx]

        # Rendering here is for monitoring only, so no gradients are tracked.
        with torch.no_grad():
            aabb = model.aabb.T.cpu().numpy()
            height, width = gt_img.shape[:2]
            rgbs, depth = render_one_view(
                model, aabb, height, width, self.K, cam_pose
            )
            psnr = PSNR.psnr(gt_img, rgbs)

        # Advance the cursor, wrapping around once all views have been used.
        self.i = (idx + 1) % len(self.imgs)
        return gt_img, rgbs, depth, psnr
def train(
    model, n_epoch=2, bs=4096, lr=0.02, scene="lego"
):
    """Fit `model` to the training rays of `scene`, then run the test routine.

    The training rays are filtered with `scene_box_filter` against the model's
    bounding box, reshuffled every epoch, and consumed in mini-batches of `bs`
    rays with an MSE loss and the Adam optimiser. Along the way the function
    logs scalars and visualisations to the event storage, builds the alpha
    mask, resamples the voxel grid, and decays the learning rate through
    `update_lr`. After training it stores a checkpoint and a video of the
    visualisations, and records the PSNR returned by `test`.

    Args:
        model: the model to optimise. This function uses its `aabb`,
            `d_scale`, `grid_size` and `device` attributes and its
            `make_alpha_mask` / `resample` methods.
        n_epoch: number of passes over the filtered training rays.
        bs: number of rays per mini-batch.
        lr: initial learning rate; also used when the optimiser is rebuilt
            after a grid resample.
        scene: scene name forwarded to the data loader.

    Returns:
        None. Results are reported through the event storage.
    """
    fuse = EarlyLoopBreak(500)

    # `.numpy()` raises on a CUDA tensor, so go through host memory first.
    # This keeps train() usable with a model that already lives on the GPU
    # and matches what test() and OneTestView.render do.
    aabb = model.aabb.T.cpu().numpy()
    model = model.to(device_glb)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    test_view = OneTestView(scene)
    all_ro, all_rd, all_rgbs = all_train_rays(scene)

    # NOTE(review): the bar total is derived from the unfiltered ray count
    # with floor division, while the loop below runs ceil(n / bs) batches per
    # epoch on the filtered rays, so the two can disagree. The percent-based
    # triggers and the LR schedule follow the bar total; left unchanged to
    # preserve the existing training schedule -- confirm this is intended.
    with tqdm(total=(n_epoch * len(all_ro) // bs)) as pbar, \
            HeartBeat(pbar) as hbeat, EventStorage() as metric:

        ro, rd, t_min, t_max, intsct_inds = scene_box_filter(all_ro, all_rd, aabb)
        rgbs = all_rgbs[intsct_inds]

        for _ in range(n_epoch):
            n = len(ro)
            # Reshuffle the rays (and their matching targets) every epoch.
            scrambler = Scrambler(n)
            ro, rd, t_min, t_max, rgbs = scrambler.apply(ro, rd, t_min, t_max, rgbs)

            num_batch = int(np.ceil(n / bs))
            for i in range(num_batch):
                # Only leaves the batch loop; the epoch loop keeps going.
                if fuse.on_break():
                    break

                s = i * bs
                e = min(n, s + bs)

                optim.zero_grad()
                _ro, _rd, _t_min, _t_max, _rgbs = as_torch_tsrs(
                    model.device, ro[s:e], rd[s:e], t_min[s:e], t_max[s:e], rgbs[s:e]
                )
                pred, _, _ = render_ray_bundle(model, _ro, _rd, _t_min, _t_max)
                loss = ((pred - _rgbs) ** 2).mean()
                loss.backward()
                optim.step()

                pbar.update()

                psnr = PSNR.psnr_from_mse(loss.item())
                metric.put_scalars(psnr=psnr, d_scale=model.d_scale.item())

                if every(pbar, step=50):
                    pbar.set_description(f"TRAIN: psnr {psnr:.2f}")

                # Render one test view and log the visualisation. Note that
                # `psnr` is rebound to the test-view PSNR inside this branch.
                if every(pbar, percent=1):
                    gimg, rimg, depth, psnr = test_view.render(model)
                    pane = vis(
                        gimg, rimg, depth,
                        msg=f"psnr: {psnr:.2f}", return_buffer=True
                    )
                    metric.put_artifact(
                        "vis", ".png", lambda fn: imageio.imwrite(fn, pane)
                    )

                if at(pbar, percent=30):
                    model.make_alpha_mask()

                if every(pbar, percent=35):
                    target_xyz = (model.grid_size * 1.328).int().tolist()
                    model.resample(target_xyz)
                    # A fresh optimiser is built from model.parameters()
                    # after the resample (presumably because resampling
                    # replaces the parameter tensors -- confirm).
                    optim = torch.optim.Adam(model.parameters(), lr=lr)
                    print(f"resamp the voxel to {model.grid_size}")

                curr_lr = update_lr(pbar, optim, lr)
                metric.put_scalars(lr=curr_lr)

                metric.step()
                hbeat.beat()

        metric.put_artifact(
            "ckpt", ".pt", lambda fn: torch.save(model.state_dict(), fn)
        )
        # No flushing step is needed here: the test routine below takes the
        # model object directly.

        metric.put_artifact(
            "train_seq", ".mp4",
            lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "vis")[1])
        )

        with EventStorage("test"):
            final_psnr = test(model, scene)
        metric.put("test_psnr", final_psnr)

        metric.step()

        hbeat.done()
def update_lr(pbar, optimizer, init_lr):
    """Apply the exponentially decayed learning rate for the current step.

    The rate is ``init_lr * (0.1 ** (1 / total)) ** n``: it equals `init_lr`
    at step 0 and one tenth of it once `pbar.n` reaches `pbar.total`.

    Args:
        pbar: progress bar exposing `n` (steps done so far) and `total`.
        optimizer: object with a `param_groups` list of dicts; the ``'lr'``
            entry of every group is overwritten.
        init_lr: learning rate at step 0.

    Returns:
        The learning rate written into every parameter group.
    """
    step, total = pbar.n, pbar.total
    if total:
        per_step_factor = 0.1 ** (1 / total)
        lr = init_lr * (per_step_factor ** step)
    else:
        # The schedule length is unknown (None) or zero, so the per-step decay
        # factor is undefined; keep the initial rate instead of raising.
        lr = init_lr

    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr
def last_ckpt():
    """Load the most recent checkpoint recorded under the current directory.

    Returns:
        The object produced by `torch.load` (mapped to the CPU) for the last
        "ckpt" entry reported by `read_stats("./", "ckpt")`, or None when no
        checkpoint is listed.
    """
    iters, paths = read_stats("./", "ckpt")
    if len(paths) == 0:
        return None

    # NOTE(review): torch.load unpickles the file; only use on trusted checkpoints.
    state = torch.load(paths[-1], map_location="cpu")
    print(f"loaded ckpt from iter {iters[-1]}")
    return state
def __evaluate_ckpt(model, scene):
    """Evaluate the latest checkpoint, if any, on the test split of `scene`.

    Meant for an external script that needs to evaluate a checkpoint;
    currently not used. When `last_ckpt` finds no checkpoint, `model` is
    evaluated with whatever weights it already has. The resulting PSNR is
    stored in the current event storage under "test_psnr".
    """
    storage = get_event_storage()

    ckpt_state = last_ckpt()
    if ckpt_state is not None:
        model.load_state_dict(ckpt_state)
    model.to(device_glb)

    # Test-time events go to their own "test" storage; the summary value is
    # written to the storage that was current on entry.
    with EventStorage("test"):
        test_psnr = test(model, scene)
    storage.put("test_psnr", test_psnr)
def test(model, scene):
    """Render the "test" split of `scene` with `model` and report the mean PSNR.

    For every processed view the PSNR and a visualisation are logged to the
    current event storage; afterwards the visualisations are stitched into a
    "test_seq" video and the mean PSNR is stored under "final_psnr".

    Args:
        model: the model to evaluate; it is moved to `device_glb`.
        scene: scene name forwarded to `load_blender`.

    Returns:
        The mean of the per-view PSNR values (`np.mean` over the views that
        were actually processed before the early-break fuse fired, if it did).
    """
    fuse = EarlyLoopBreak(5)
    metric = get_event_storage()
    hbeat = get_heartbeat()

    aabb = model.aabb.T.cpu().numpy()
    model = model.to(device_glb)

    imgs, K, poses = load_blender("test", scene)

    psnrs = []
    pbar = tqdm(range(len(imgs)))
    for idx in pbar:
        if fuse.on_break():
            break

        gt_img, cam_pose = imgs[idx], poses[idx]
        height, width = gt_img.shape[:2]
        rgbs, depth = render_one_view(model, aabb, height, width, K, cam_pose)
        psnr = PSNR.psnr(gt_img, rgbs)

        psnrs.append(psnr)
        metric.put_scalars(psnr=psnr)
        # Show the running mean over the views rendered so far.
        pbar.set_description(f"TEST: mean psnr {np.mean(psnrs):.2f}")

        plot = vis(gt_img, rgbs, depth, msg=f"PSNR: {psnr:.2f}", return_buffer=True)
        metric.put_artifact("test_vis", ".png", lambda fn: imageio.imwrite(fn, plot))
        metric.step()
        hbeat.beat()

    # Assemble the per-view visualisations logged above into one video.
    metric.put_artifact(
        "test_seq", ".mp4",
        lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "test_vis")[1])
    )

    final_psnr = np.mean(psnrs)
    metric.put("final_psnr", final_psnr)
    metric.step()

    return final_psnr