# convert.py — convert a saved 3D gaussian .ply (from infer.py) into a
# textured mesh, by distilling gaussians -> NeRF -> mesh.
| import os | |
| import tyro | |
| import tqdm | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from lgm.options import AllConfigs, Options | |
| from lgm.gs import GaussianRenderer | |
| import mcubes | |
| import nerfacc | |
| import nvdiffrast.torch as dr | |
| import kiui | |
| from kiui.mesh import Mesh | |
| from kiui.mesh_utils import clean_mesh, decimate_mesh | |
| from kiui.mesh_utils import laplacian_smooth_loss, normal_consistency | |
| from kiui.op import uv_padding, safe_normalize, inverse_sigmoid | |
| from kiui.cam import orbit_camera, get_perspective | |
| from kiui.nn import MLP, trunc_exp | |
| from kiui.gridencoder import GridEncoder | |
def get_rays(pose, h, w, fovy, opengl=True):
    """Build one world-space ray per pixel for a pinhole camera.

    Args:
        pose: [4, 4] camera-to-world matrix (torch tensor).
        h, w: image height and width in pixels.
        fovy: vertical field of view in degrees.
        opengl: if True, use the OpenGL convention (camera looks down -z,
            y points up in the image).

    Returns:
        (rays_o, rays_d): each [h*w, 3]; origins and unit directions.
    """
    device = pose.device
    px, py = torch.meshgrid(
        torch.arange(w, device=device),
        torch.arange(h, device=device),
        indexing="xy",
    )
    px = px.flatten()
    py = py.flatten()

    # focal length from the vertical fov; +0.5 samples pixel centers
    focal = 0.5 * h / np.tan(0.5 * np.deg2rad(fovy))
    sign = -1.0 if opengl else 1.0
    u = (px - 0.5 * w + 0.5) / focal
    v = (py - 0.5 * h + 0.5) / focal * sign

    # append the z component (±1) to form camera-space directions: [hw, 3]
    camera_dirs = F.pad(torch.stack([u, v], dim=-1), (0, 1), value=sign)

    # rotate into world space; all rays share the camera position as origin
    rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1)  # [hw, 3]
    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)  # [hw, 3]

    return rays_o, safe_normalize(rays_d)
| # Triple renderer of gaussians, gaussian, and diso mesh. | |
| # gaussian --> nerf --> mesh | |
class Converter(nn.Module):
    """Triple renderer of the same scene as gaussians, a NeRF, and a mesh.

    Distillation pipeline: gaussians --> nerf --> mesh.
    """

    def __init__(self, opt: Options):
        super().__init__()
        self.opt = opt
        self.device = torch.device("cuda")

        # --- gaussian-splat renderer (the supervision source) ---
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
        # Perspective projection; the [2, 3]/[3, 2] placement implies a
        # row-vector (transposed) convention — render_gs multiplies
        # cam_view @ proj_matrix, matching the gs renderer's expectation.
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[3, 2] = -(opt.zfar * opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[2, 3] = 1
        self.gs_renderer = GaussianRenderer(opt)
        # gaussians saved by infer.py, loaded from the .ply at opt.test_path
        self.gaussians = self.gs_renderer.load_ply(opt.test_path).to(self.device)

        # --- nerf renderer ---
        # nvdiffrast rasterization context: OpenGL by default, CUDA if forced
        if not self.opt.force_cuda_rast:
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()

        self.step = 0  # training step counter for the occupancy-grid updates
        self.render_step_size = 5e-3
        # scene bounds: unit cube [-1, 1]^3
        self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device)
        # occupancy grid used by nerfacc to skip empty space while sampling
        self.estimator = nerfacc.OccGridEstimator(
            roi_aabb=self.aabb, resolution=64, levels=1
        )

        # separate grid-encoder + MLP heads for density and color
        self.encoder_density = GridEncoder(
            num_levels=12
        )  # VMEncoder(output_dim=16, mode='sum')
        self.encoder = GridEncoder(num_levels=12)
        self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False)
        self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False)

        # --- mesh renderer ---
        self.proj = (
            torch.from_numpy(get_perspective(self.opt.fovy)).float().to(self.device)
        )
        self.v = self.f = None    # mesh vertices / faces (set by fit_mesh)
        self.vt = self.ft = None  # UV coords / UV faces (set by fit_mesh_uv)
        self.deform = None        # per-vertex deformation parameter (fit_mesh)
        self.albedo = None        # texture logits parameter (fit_mesh_uv)
| def render_gs(self, pose): | |
| cam_poses = torch.from_numpy(pose).unsqueeze(0).to(self.device) | |
| cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction | |
| # cameras needed by gaussian rasterizer | |
| cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] | |
| cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4] | |
| cam_pos = -cam_poses[:, :3, 3] # [V, 3] | |
| out = self.gs_renderer.render( | |
| self.gaussians.unsqueeze(0), | |
| cam_view.unsqueeze(0), | |
| cam_view_proj.unsqueeze(0), | |
| cam_pos.unsqueeze(0), | |
| ) | |
| image = out["image"].squeeze(1).squeeze(0) # [C, H, W] | |
| alpha = out["alpha"].squeeze(2).squeeze(1).squeeze(0) # [H, W] | |
| return image, alpha | |
| def get_density(self, xs): | |
| # xs: [..., 3] | |
| prefix = xs.shape[:-1] | |
| xs = xs.view(-1, 3) | |
| feats = self.encoder_density(xs) | |
| density = trunc_exp(self.mlp_density(feats)) | |
| density = density.view(*prefix, 1) | |
| return density | |
    def render_nerf(self, pose):
        """Volume-render the NeRF from a single camera pose.

        Args:
            pose: [4, 4] numpy camera-to-world pose.

        Returns:
            (color, alpha): color [3, H, W] composited over a white
            background and alpha [H, W], with H = W = opt.output_size.
        """
        pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)

        # get rays
        resolution = self.opt.output_size
        rays_o, rays_d = get_rays(pose, resolution, resolution, self.opt.fovy)

        # update occ grid (training only): mark cells whose opacity over one
        # step exceeds the threshold so the sampler can skip empty space
        if self.training:
            def occ_eval_fn(xs):
                sigmas = self.get_density(xs)
                return self.render_step_size * sigmas
            self.estimator.update_every_n_steps(
                self.step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8
            )
            self.step += 1

        # render
        def sigma_fn(t_starts, t_ends, ray_indices):
            # density at segment midpoints, used by the sampler (no grad)
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
            sigmas = self.get_density(xs)
            return sigmas.squeeze(-1)

        with torch.no_grad():
            ray_indices, t_starts, t_ends = self.estimator.sampling(
                rays_o,
                rays_d,
                sigma_fn=sigma_fn,
                near_plane=0.01,
                far_plane=100,
                render_step_size=self.render_step_size,
                stratified=self.training,
                cone_angle=0,
            )

        # re-evaluate density and color at the sampled midpoints, this time
        # WITH gradients so the fields can be optimized
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        sigmas = self.get_density(xs).squeeze(-1)
        rgbs = torch.sigmoid(self.mlp(self.encoder(xs)))

        n_rays = rays_o.shape[0]
        weights, trans, alphas = nerfacc.render_weight_from_density(
            t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays
        )
        color = nerfacc.accumulate_along_rays(
            weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays
        )
        alpha = nerfacc.accumulate_along_rays(
            weights, values=None, ray_indices=ray_indices, n_rays=n_rays
        )

        # composite over a white background
        color = color + 1 * (1.0 - alpha)

        color = (
            color.view(resolution, resolution, 3)
            .clamp(0, 1)
            .permute(2, 0, 1)
            .contiguous()
        )
        alpha = alpha.view(resolution, resolution).clamp(0, 1).contiguous()

        return color, alpha
| def fit_nerf(self, iters=512, resolution=128): | |
| self.opt.output_size = resolution | |
| optimizer = torch.optim.Adam( | |
| [ | |
| {"params": self.encoder_density.parameters(), "lr": 1e-2}, | |
| {"params": self.encoder.parameters(), "lr": 1e-2}, | |
| {"params": self.mlp_density.parameters(), "lr": 1e-3}, | |
| {"params": self.mlp.parameters(), "lr": 1e-3}, | |
| ] | |
| ) | |
| print(f"[INFO] fitting nerf...") | |
| pbar = tqdm.trange(iters) | |
| for i in pbar: | |
| ver = np.random.randint(-45, 45) | |
| hor = np.random.randint(-180, 180) | |
| rad = np.random.uniform(1.5, 3.0) | |
| pose = orbit_camera(ver, hor, rad) | |
| image_gt, alpha_gt = self.render_gs(pose) | |
| image_pred, alpha_pred = self.render_nerf(pose) | |
| # if i % 200 == 0: | |
| # kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred) | |
| loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss( | |
| alpha_pred, alpha_gt | |
| ) | |
| loss = loss_mse # + 0.1 * self.encoder_density.tv_loss() #+ 0.0001 * self.encoder_density.density_loss() | |
| loss.backward() | |
| self.encoder_density.grad_total_variation(1e-8) | |
| optimizer.step() | |
| optimizer.zero_grad() | |
| pbar.set_description(f"MSE = {loss_mse.item():.6f}") | |
| print(f"[INFO] finished fitting nerf!") | |
    def render_mesh(self, pose):
        """Rasterize the current (deformed) mesh from a single camera pose.

        Before UV unwrapping (self.albedo is None) surface color is queried
        from the NeRF color field; afterwards it is sampled from the
        optimizable UV texture.

        Args:
            pose: [4, 4] numpy camera-to-world pose.

        Returns:
            (image, alpha): image [3, H, W] composited over white, alpha
            [H, W]; H = W = opt.output_size.
        """
        h = w = self.opt.output_size
        v = self.v + self.deform
        f = self.f

        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)

        # get v_clip and render rgb: homogeneous verts -> camera -> clip space
        v_cam = (
            torch.matmul(
                F.pad(v, pad=(0, 1), mode="constant", value=1.0), torch.inverse(pose).T
            )
            .float()
            .unsqueeze(0)
        )
        v_clip = v_cam @ self.proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))

        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous()  # [1, H, W, 1]
        alpha = (
            dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0)
        )  # [H, W] important to enable gradients!

        if self.albedo is None:
            # no texture yet: query the NeRF color MLP at covered surface
            # positions (positions detached so only color grads flow here)
            xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f)  # [1, H, W, 3]
            xyzs = xyzs.view(-1, 3)
            mask = (alpha > 0).view(-1)
            image = torch.zeros_like(xyzs, dtype=torch.float32)
            if mask.any():
                masked_albedo = torch.sigmoid(
                    self.mlp(self.encoder(xyzs[mask].detach(), bound=1))
                )
                image[mask] = masked_albedo.float()
        else:
            # sample the learned texture; texture stores logits, hence sigmoid
            texc, texc_db = dr.interpolate(
                self.vt.unsqueeze(0), rast, self.ft, rast_db=rast_db, diff_attrs="all"
            )
            image = torch.sigmoid(
                dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db)
            )  # [1, H, W, 3]

        image = image.view(1, h, w, 3)
        # image = dr.antialias(image, rast, v_clip, f).clamp(0, 1)
        image = image.squeeze(0).permute(2, 0, 1).contiguous()  # [3, H, W]
        # composite over a white background
        image = alpha * image + (1 - alpha)

        return image, alpha
| def fit_mesh(self, iters=2048, resolution=512, decimate_target=5e4): | |
| self.opt.output_size = resolution | |
| # init mesh from nerf | |
| grid_size = 256 | |
| sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32) | |
| S = 128 | |
| density_thresh = 10 | |
| X = torch.linspace(-1, 1, grid_size).split(S) | |
| Y = torch.linspace(-1, 1, grid_size).split(S) | |
| Z = torch.linspace(-1, 1, grid_size).split(S) | |
| for xi, xs in enumerate(X): | |
| for yi, ys in enumerate(Y): | |
| for zi, zs in enumerate(Z): | |
| xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing="ij") | |
| pts = torch.cat( | |
| [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], | |
| dim=-1, | |
| ) # [S, 3] | |
| val = self.get_density(pts.to(self.device)) | |
| sigmas[ | |
| xi * S : xi * S + len(xs), | |
| yi * S : yi * S + len(ys), | |
| zi * S : zi * S + len(zs), | |
| ] = ( | |
| val.reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() | |
| ) # [S, 1] --> [x, y, z] | |
| print( | |
| f"[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})" | |
| ) | |
| vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh) | |
| vertices = vertices / (grid_size - 1.0) * 2 - 1 | |
| # clean | |
| vertices = vertices.astype(np.float32) | |
| triangles = triangles.astype(np.int32) | |
| vertices, triangles = clean_mesh( | |
| vertices, triangles, remesh=True, remesh_size=0.01 | |
| ) | |
| if triangles.shape[0] > decimate_target: | |
| vertices, triangles = decimate_mesh( | |
| vertices, triangles, decimate_target, optimalplacement=False | |
| ) | |
| self.v = torch.from_numpy(vertices).contiguous().float().to(self.device) | |
| self.f = torch.from_numpy(triangles).contiguous().int().to(self.device) | |
| self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device) | |
| # fit mesh from gs | |
| lr_factor = 1 | |
| optimizer = torch.optim.Adam( | |
| [ | |
| {"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor}, | |
| {"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor}, | |
| {"params": self.deform, "lr": 1e-4}, | |
| ] | |
| ) | |
| print(f"[INFO] fitting mesh...") | |
| pbar = tqdm.trange(iters) | |
| for i in pbar: | |
| ver = np.random.randint(-10, 10) | |
| hor = np.random.randint(-180, 180) | |
| rad = self.opt.cam_radius # np.random.uniform(1, 2) | |
| pose = orbit_camera(ver, hor, rad) | |
| image_gt, alpha_gt = self.render_gs(pose) | |
| image_pred, alpha_pred = self.render_mesh(pose) | |
| loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss( | |
| alpha_pred, alpha_gt | |
| ) | |
| # loss_lap = laplacian_smooth_loss(self.v + self.deform, self.f) | |
| loss_normal = normal_consistency(self.v + self.deform, self.f) | |
| loss_offsets = (self.deform**2).sum(-1).mean() | |
| loss = loss_mse + 0.001 * loss_normal + 0.1 * loss_offsets | |
| loss.backward() | |
| optimizer.step() | |
| optimizer.zero_grad() | |
| # remesh periodically | |
| if i > 0 and i % 512 == 0: | |
| vertices = (self.v + self.deform).detach().cpu().numpy() | |
| triangles = self.f.detach().cpu().numpy() | |
| vertices, triangles = clean_mesh( | |
| vertices, triangles, remesh=True, remesh_size=0.01 | |
| ) | |
| if triangles.shape[0] > decimate_target: | |
| vertices, triangles = decimate_mesh( | |
| vertices, triangles, decimate_target, optimalplacement=False | |
| ) | |
| self.v = torch.from_numpy(vertices).contiguous().float().to(self.device) | |
| self.f = torch.from_numpy(triangles).contiguous().int().to(self.device) | |
| self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device) | |
| lr_factor *= 0.5 | |
| optimizer = torch.optim.Adam( | |
| [ | |
| {"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor}, | |
| {"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor}, | |
| {"params": self.deform, "lr": 1e-4}, | |
| ] | |
| ) | |
| pbar.set_description(f"MSE = {loss_mse.item():.6f}") | |
| # last clean | |
| vertices = (self.v + self.deform).detach().cpu().numpy() | |
| triangles = self.f.detach().cpu().numpy() | |
| vertices, triangles = clean_mesh(vertices, triangles, remesh=False) | |
| self.v = torch.from_numpy(vertices).contiguous().float().to(self.device) | |
| self.f = torch.from_numpy(triangles).contiguous().int().to(self.device) | |
| self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device)) | |
| print(f"[INFO] finished fitting mesh!") | |
| # uv mesh refine | |
| def fit_mesh_uv( | |
| self, iters=512, resolution=512, texture_resolution=1024, padding=2 | |
| ): | |
| self.opt.output_size = resolution | |
| # unwrap uv | |
| print(f"[INFO] uv unwrapping...") | |
| mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device) | |
| mesh.auto_normal() | |
| mesh.auto_uv() | |
| self.vt = mesh.vt | |
| self.ft = mesh.ft | |
| # render uv maps | |
| h = w = texture_resolution | |
| uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1] | |
| uv = torch.cat( | |
| (uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1 | |
| ) # [N, 4] | |
| rast, _ = dr.rasterize( | |
| self.glctx, uv.unsqueeze(0), mesh.ft, (h, w) | |
| ) # [1, h, w, 4] | |
| xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3] | |
| mask, _ = dr.interpolate( | |
| torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f | |
| ) # [1, h, w, 1] | |
| # masked query | |
| xyzs = xyzs.view(-1, 3) | |
| mask = (mask > 0).view(-1) | |
| albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32) | |
| if mask.any(): | |
| print(f"[INFO] querying texture...") | |
| xyzs = xyzs[mask] # [M, 3] | |
| # batched inference to avoid OOM | |
| batch = [] | |
| head = 0 | |
| while head < xyzs.shape[0]: | |
| tail = min(head + 640000, xyzs.shape[0]) | |
| batch.append( | |
| torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float() | |
| ) | |
| head += 640000 | |
| albedo[mask] = torch.cat(batch, dim=0) | |
| albedo = albedo.view(h, w, -1) | |
| mask = mask.view(h, w) | |
| albedo = uv_padding(albedo, mask, padding) | |
| # optimize texture | |
| self.albedo = nn.Parameter(inverse_sigmoid(albedo)).to(self.device) | |
| optimizer = torch.optim.Adam( | |
| [ | |
| {"params": self.albedo, "lr": 1e-3}, | |
| ] | |
| ) | |
| print(f"[INFO] fitting mesh texture...") | |
| pbar = tqdm.trange(iters) | |
| for i in pbar: | |
| # shrink to front view as we care more about it... | |
| ver = np.random.randint(-5, 5) | |
| hor = np.random.randint(-15, 15) | |
| rad = self.opt.cam_radius # np.random.uniform(1, 2) | |
| pose = orbit_camera(ver, hor, rad) | |
| image_gt, alpha_gt = self.render_gs(pose) | |
| image_pred, alpha_pred = self.render_mesh(pose) | |
| loss_mse = F.mse_loss(image_pred, image_gt) | |
| loss = loss_mse | |
| loss.backward() | |
| optimizer.step() | |
| optimizer.zero_grad() | |
| pbar.set_description(f"MSE = {loss_mse.item():.6f}") | |
| print(f"[INFO] finished fitting mesh texture!") | |
| def export_mesh(self, path): | |
| mesh = Mesh( | |
| v=self.v, | |
| f=self.f, | |
| vt=self.vt, | |
| ft=self.ft, | |
| albedo=torch.sigmoid(self.albedo), | |
| device=self.device, | |
| ) | |
| mesh.auto_normal() | |
| mesh.write(path) | |
def main():
    """CLI entry: load a gaussian .ply saved by infer.py and bake it to a .glb mesh."""
    opt = tyro.cli(AllConfigs)

    # load a saved ply and convert to mesh
    assert opt.test_path.endswith(
        ".ply"
    ), "--test_path must be a .ply file saved by infer.py"

    converter = Converter(opt).cuda()
    converter.fit_nerf()
    converter.fit_mesh()
    converter.fit_mesh_uv()
    # Replace only the trailing extension: str.replace would rewrite the
    # first ".ply" occurring anywhere in the path, not just the suffix.
    converter.export_mesh(opt.test_path[: -len(".ply")] + ".glb")


if __name__ == "__main__":
    main()