Spaces:
Build error
Build error
| import math | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from imageio import imwrite | |
| from pydantic import validator | |
| from my.utils import ( | |
| tqdm, EventStorage, HeartBeat, EarlyLoopBreak, | |
| get_event_storage, get_heartbeat, read_stats | |
| ) | |
| from my.config import BaseConf, dispatch, optional_load_config | |
| from my.utils.seed import seed_everything | |
| from adapt import ScoreAdapter, karras_t_schedule | |
| from run_img_sampling import GDDPM, SD, StableDiffusion | |
| from misc import torch_samps_to_imgs | |
| from pose import PoseConfig | |
| from run_nerf import VoxConfig | |
| from voxnerf.utils import every | |
| from voxnerf.render import ( | |
| as_torch_tsrs, rays_from_img, ray_box_intersect, render_ray_bundle | |
| ) | |
| from voxnerf.vis import stitch_vis, bad_vis as nerf_vis | |
# Name of the accelerator the voxel field is moved onto for training/eval.
_GLOBAL_DEVICE_NAME = "cuda"
device_glb = torch.device(_GLOBAL_DEVICE_NAME)
def tsr_stats(tsr):
    """Summarize a tensor as plain Python floats.

    Returns a dict with keys ``"mean"``, ``"std"`` and ``"max"`` (in that
    order), each obtained by calling the same-named tensor reduction and
    converting the scalar result with ``.item()``.
    """
    reducers = ("mean", "std", "max")
    return {name: getattr(tsr, name)().item() for name in reducers}
class SJC(BaseConf):
    """Top-level configuration for an SJC 3D run.

    ``family`` names which score-model sub-config (``gddpm`` or ``sd``) is
    instantiated by ``run()``; the voxel field and camera sampler come from
    ``vox`` and ``pose``. All remaining fields are forwarded to ``sjc_3d``.
    """

    family: str = "sd"
    gddpm: GDDPM = GDDPM()
    sd: SD = SD(
        variant="v1",
        prompt="A high quality photo of a delicious burger",
        scale=100.0
    )
    lr: float = 0.05
    n_steps: int = 10000
    vox: VoxConfig = VoxConfig(
        model_type="V_SD", grid_size=100, density_shift=-1.0, c=3,
        blend_bg_texture=True, bg_texture_hw=4,
        bbox_len=1.0
    )
    pose: PoseConfig = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)

    emptiness_scale: int = 10
    # float, not int: the default is the float literal 1e4, and an int
    # annotation would truncate fractional weights supplied via config.
    emptiness_weight: float = 1e4
    emptiness_step: float = 0.5
    emptiness_multiplier: float = 20.0

    depth_weight: int = 0

    var_red: bool = True

    @validator("vox", always=True)
    def check_vox(cls, vox_cfg, values):
        """Set the voxel field to 4 channels when the ``sd`` family is chosen.

        Registered as a pydantic validator on ``vox``; ``always=True`` makes it
        run for the default ``vox`` value too, not only for user-supplied ones.
        ``family`` is declared before ``vox``, so it is present in ``values``
        unless its own validation failed (hence ``.get``).
        """
        if values.get("family") == "sd":
            vox_cfg.c = 4
        return vox_cfg

    def run(self):
        """Build the score model, voxel field and poser, then call ``sjc_3d``."""
        cfgs = self.dict()
        family = cfgs.pop("family")
        model = getattr(self, family).make()

        cfgs.pop("vox")
        vox = self.vox.make()

        cfgs.pop("pose")
        poser = self.pose.make()

        # Remaining entries (lr, n_steps, loss weights, plus the gddpm/sd
        # sub-configs) are passed as keywords; sjc_3d discards the extras.
        sjc_3d(**cfgs, poser=poser, model=model, vox=vox)
def sjc_3d(
    poser, vox, model: ScoreAdapter,
    lr, n_steps, emptiness_scale, emptiness_weight, emptiness_step, emptiness_multiplier,
    depth_weight, var_red, **kwargs
):
    """Optimize ``vox`` against the score model ``model`` for ``n_steps`` steps.

    Each step renders one training view, adds noise at a randomly chosen
    ``σ`` from ``model.us[30:-10]``, denoises with ``model`` and
    back-propagates the resulting score through the renderer. Optional depth
    and emptiness regularizers are back-propagated on top, and then one
    Adamax step is taken. After the loop, a checkpoint is saved and
    ``evaluate`` is run.

    Args:
        poser: camera sampler providing ``H``, ``W`` and ``sample_train``.
        vox: voxel field; moved to ``device_glb`` and optimized in place.
        model: score model; ``model.samps_centered()`` must be true.
        lr: Adamax learning rate.
        n_steps: number of optimization steps.
        emptiness_scale: scale inside the ``log(1 + scale * ws)`` penalty on
            the render weights.
        emptiness_weight: weight of that penalty.
        emptiness_step: fraction of ``n_steps`` after which the penalty is
            additionally multiplied by ``emptiness_multiplier``.
        emptiness_multiplier: multiplier applied to the penalty after
            ``emptiness_step * n_steps`` steps.
        depth_weight: weight of the center-vs-border depth loss (off when 0).
        var_red: if true the score is ``(D - y) / σ``, else ``(D - z) / σ``.
        **kwargs: extra config entries; ignored.
    """
    del kwargs

    assert model.samps_centered()
    _, target_H, target_W = model.data_shape()
    bs = 1
    aabb = vox.aabb.T.cpu().numpy()
    vox = vox.to(device_glb)
    opt = torch.optim.Adamax(vox.opt_params(), lr=lr)

    H, W = poser.H, poser.W
    Ks, poses, prompt_prefixes = poser.sample_train(n_steps)

    ts = model.us[30:-10]
    fuse = EarlyLoopBreak(5)

    # Unused, but kept: this draw advances the torch RNG, and removing it would
    # change the noise sequence of seeded runs.
    same_noise = torch.randn(1, 4, H, W, device=model.device).repeat(bs, 1, 1, 1)
    del same_noise

    with tqdm(total=n_steps) as pbar, \
            HeartBeat(pbar) as hbeat, \
            EventStorage() as metric:
        for i in range(n_steps):
            if fuse.on_break():
                break

            p = f"{prompt_prefixes[i]} {model.prompt}"
            score_conds = model.prompts_emb([p])

            y, depth, ws = render_one_view(vox, aabb, H, W, Ks[i], poses[i], return_w=True)

            # StableDiffusion consumes the render as-is (it is decoded only for
            # visualization below); other models need it resized to their
            # data resolution.
            if not isinstance(model, StableDiffusion):
                y = torch.nn.functional.interpolate(y, (target_H, target_W), mode='bilinear')

            opt.zero_grad()

            with torch.no_grad():
                chosen_σs = np.random.choice(ts, bs, replace=False)
                chosen_σs = chosen_σs.reshape(-1, 1, 1, 1)
                chosen_σs = torch.as_tensor(chosen_σs, device=model.device, dtype=torch.float32)

                noise = torch.randn(bs, *y.shape[1:], device=model.device)
                zs = y + chosen_σs * noise
                Ds = model.denoise(zs, chosen_σs, **score_conds)

                if var_red:
                    grad = (Ds - y) / chosen_σs
                else:
                    grad = (Ds - zs) / chosen_σs
                grad = grad.mean(0, keepdim=True)

            # Push the score into the renderer's parameters (gradient ascent on
            # grad); the graph is retained for the regularizers below.
            y.backward(-grad, retain_graph=True)

            if depth_weight > 0:
                # -log(center_mean - border_mean): rewards a center region that
                # is deeper than the 7-pixel border around it.
                center_depth = depth[7:-7, 7:-7]
                # Border pixel count derived from the actual depth map instead of
                # the former hard-coded 64*64-50*50 (valid only for 64x64).
                n_border = depth.numel() - center_depth.numel()
                border_depth_mean = (depth.sum() - center_depth.sum()) / n_border
                center_depth_mean = center_depth.mean()
                depth_diff = center_depth_mean - border_depth_mean
                depth_loss = - torch.log(depth_diff + 1e-12)
                depth_loss = depth_weight * depth_loss
                depth_loss.backward(retain_graph=True)

            emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
            emptiness_loss = emptiness_weight * emptiness_loss
            if emptiness_step * n_steps <= i:
                emptiness_loss *= emptiness_multiplier
            emptiness_loss.backward()

            opt.step()

            metric.put_scalars(**tsr_stats(y))

            if every(pbar, percent=1):
                with torch.no_grad():
                    if isinstance(model, StableDiffusion):
                        y = model.decode(y)
                    vis_routine(metric, y, depth)

            metric.step()
            pbar.update()
            pbar.set_description(p)
            hbeat.beat()

        metric.put_artifact(
            "ckpt", ".pt", lambda fn: torch.save(vox.state_dict(), fn)
        )

        with EventStorage("test"):
            evaluate(model, vox, poser)

        metric.step()

        hbeat.done()
@torch.no_grad()
def evaluate(score_model, vox, poser):
    """Render and log the test views of ``poser`` from ``vox``.

    Renders up to 100 views from ``poser.sample_test(100)`` and logs each via
    ``vis_routine`` (decoding first when ``score_model`` is a
    StableDiffusion). It then logs a ``view_seq`` .mp4 stitched from the
    logged ``view`` images. Logging goes to the current event storage.

    Runs without gradient tracking: nothing is optimized here, and the
    visualization path converts tensors to NumPy, which raises for tensors
    that still require grad.
    """
    H, W = poser.H, poser.W
    vox.eval()
    K, poses = poser.sample_test(100)

    fuse = EarlyLoopBreak(5)
    metric = get_event_storage()
    hbeat = get_heartbeat()

    aabb = vox.aabb.T.cpu().numpy()
    vox = vox.to(device_glb)

    for pose in tqdm(poses):
        if fuse.on_break():
            break

        y, depth = render_one_view(vox, aabb, H, W, K, pose)
        if isinstance(score_model, StableDiffusion):
            y = score_model.decode(y)
        vis_routine(metric, y, depth)

        metric.step()
        hbeat.beat()

    metric.flush_history()

    metric.put_artifact(
        "view_seq", ".mp4",
        lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "view")[1])
    )

    metric.step()
def render_one_view(vox, aabb, H, W, K, pose, return_w=False):
    """Render one full H x W view of ``vox`` from camera intrinsics/pose.

    Returns ``(rgbs, depth)`` where ``rgbs`` is rearranged to ``(1, c, H, W)``
    and ``depth`` to ``(H, W)``. When ``return_w`` is truthy, the ``weights``
    returned by ``render_ray_bundle`` are appended as a third element.
    """
    origins, dirs = rays_from_img(H, W, K, pose)
    origins, dirs, near, far = scene_box_filter(origins, dirs, aabb)
    assert len(origins) == H * W, "for now all pixels must be in"

    origins, dirs, near, far = as_torch_tsrs(vox.device, origins, dirs, near, far)
    flat_rgbs, flat_depth, weights = render_ray_bundle(vox, origins, dirs, near, far)

    image = rearrange(flat_rgbs, "(h w) c -> 1 c h w", h=H, w=W)
    depth_map = rearrange(flat_depth, "(h w) 1 -> h w", h=H, w=W)
    return (image, depth_map, weights) if return_w else (image, depth_map)
def scene_box_filter(ro, rd, aabb):
    """Compute non-negative entry/exit distances of rays against ``aabb``.

    ``ro`` and ``rd`` are returned untouched. The ``t_min``/``t_max`` values
    from ``ray_box_intersect`` are clamped at 0 so that nothing behind the
    ray origin is rendered.
    """
    _, t_near, t_far = ray_box_intersect(ro, rd, aabb)
    t_near = np.maximum(t_near, 0)
    t_far = np.maximum(t_far, 0)
    return ro, rd, t_near, t_far
def vis_routine(metric, y, depth):
    """Log one rendered view to ``metric`` as three artifacts.

    * ``view`` (.png): the pane built by ``nerf_vis(y, depth, final_H=256)``.
    * ``img`` (.png): the first image from ``torch_samps_to_imgs(y)``.
    * ``depth`` (.npy): the depth map as a NumPy array.

    Both tensors are detached first, so this works whether or not the caller
    is inside ``torch.no_grad()``. ``Tensor.numpy()`` refuses tensors that
    require grad, and ``.cpu()`` on a tensor already on the CPU returns the
    same tensor, grad flag included.
    """
    y = y.detach()
    depth = depth.detach()

    pane = nerf_vis(y, depth, final_H=256)
    im = torch_samps_to_imgs(y)[0]
    depth = depth.cpu().numpy()

    metric.put_artifact("view", ".png", lambda fn: imwrite(fn, pane))
    metric.put_artifact("img", ".png", lambda fn: imwrite(fn, im))
    metric.put_artifact("depth", ".npy", lambda fn: np.save(fn, depth))
def evaluate_ckpt():
    """Re-run test-view evaluation from the latest checkpoint in the cwd.

    Loads ``full_config.yml``, rebuilds the score model, voxel field and
    poser from it, restores the most recent ``ckpt`` artifact (see
    ``latest_ckpt``) into the voxel field, and calls ``evaluate`` under a
    ``"test"`` EventStorage.
    """
    cfg = optional_load_config(fname="full_config.yml")
    assert len(cfg) > 0, "can't find cfg file"
    mod = SJC(**cfg)

    # Read the validated field instead of popping the raw dict: a config that
    # relies on the default family has no "family" key and would KeyError.
    model: ScoreAdapter = getattr(mod, mod.family).make()
    vox = mod.vox.make()
    poser = mod.pose.make()

    # The progress bar is managed by `with` so it is closed on exit.
    with tqdm(range(1)) as pbar, EventStorage(), HeartBeat(pbar):
        ckpt_fname = latest_ckpt()
        state = torch.load(ckpt_fname, map_location="cpu")
        vox.load_state_dict(state)
        vox.to(device_glb)

        with EventStorage("test"):
            evaluate(model, vox, poser)
def latest_ckpt():
    """Return the last ``ckpt`` entry recorded by ``read_stats`` under ``./``.

    Raises AssertionError when no checkpoint has been recorded.
    """
    _, ckpt_entries = read_stats("./", "ckpt")
    assert len(ckpt_entries) > 0
    return ckpt_entries[-1]
if __name__ == "__main__":
    # Seed with 0 for reproducible runs, then let the config-driven dispatcher
    # build an SJC config and execute it. To re-render from the most recent
    # checkpoint instead, call evaluate_ckpt() here.
    seed_everything(0)
    dispatch(SJC)