Spaces:
Build error
Build error
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| from my.registry import Registry | |
| VOXRF_REGISTRY = Registry("VoxRF") | |
def to_grid_samp_coords(xyz_sampled, aabb):
    """Rescale points from the bounding box to the [-1, 1] range.

    `aabb[0]` is used as the corner mapped to -1 and `aabb[1]` as the corner
    mapped to +1; the result is what `F.grid_sample` takes as coordinates.
    """
    lower, upper = aabb[0], aabb[1]
    extent = upper - lower
    unit_coords = (xyz_sampled - lower) / extent  # 0 at lower, 1 at upper
    return unit_coords * 2 - 1
def add_non_state_tsr(nn_module, key, val):
    """Register `val` on `nn_module` as a buffer named `key`.

    The buffer is non-persistent, so it is left out of the module's
    state_dict.
    """
    nn_module.register_buffer(name=key, tensor=val, persistent=False)
class VoxRF(nn.Module):
    """Radiance field stored on dense voxel grids.

    Holds a learnable single-channel density volume, a `c`-channel
    color/feature volume and a 2D background texture. The volumes are laid
    out as [1, C, z, y, x] so they can be queried with `F.grid_sample`, while
    the `grid_size` buffer keeps the resolution in xyz order.
    """

    def __init__(
        self, aabb, grid_size, step_ratio=0.5,
        density_shift=-10, ray_march_weight_thres=0.0001, c=3,
        blend_bg_texture=True, bg_texture_hw=64
    ):
        """Allocate the grids.

        Args:
            aabb: array of shape (2, 3); `aabb[1] - aabb[0]` is used as the
                extent of the bounding box.
            grid_size: three ints, the grid resolution in xyz order.
            step_ratio: ray-march step length as a fraction of the mean
                voxel edge length (see `get_num_samples`).
            density_shift: offset added to the scaled raw density before the
                softplus in `compute_density_feats`.
            ray_march_weight_thres: stored on the instance; not read inside
                this class.
            c: number of channels of the color volume and bg texture.
            blend_bg_texture: stored on the instance; not read inside this
                class.
            bg_texture_hw: side length of the square background texture.
        """
        assert aabb.shape == (2, 3)
        # The resolution lives on as the `grid_size` buffer registered below;
        # the local name is dropped right away.
        xyz = grid_size
        del grid_size

        super().__init__()
        add_non_state_tsr(self, "aabb", torch.tensor(aabb, dtype=torch.float32))
        add_non_state_tsr(self, "grid_size", torch.LongTensor(xyz))

        self.density_shift = density_shift
        self.ray_march_weight_thres = ray_march_weight_thres
        self.step_ratio = step_ratio

        # grid_sample addresses volumes as (D, H, W) = (z, y, x)
        zyx = xyz[::-1]
        self.density = torch.nn.Parameter(
            torch.zeros((1, 1, *zyx))
        )
        self.color = torch.nn.Parameter(
            torch.randn((1, c, *zyx))
        )

        self.blend_bg_texture = blend_bg_texture
        self.bg = torch.nn.Parameter(
            torch.randn((1, c, bg_texture_hw, bg_texture_hw))
        )

        self.c = c
        self.alphaMask = None
        self.feats2color = lambda feats: torch.sigmoid(feats)

        # log-scale multiplier applied to the raw density grid values
        self.d_scale = torch.nn.Parameter(torch.tensor(0.0))

    def device(self):
        """Return the device holding the density grid (plain method, not a property)."""
        return self.density.device

    def compute_density_feats(self, xyz_sampled):
        """Sample the density σ at `n` points given as an [n, 3] tensor.

        Returns a tensor of shape [n].
        """
        xyz_sampled = to_grid_samp_coords(xyz_sampled, self.aabb)
        n = xyz_sampled.shape[0]
        xyz_sampled = xyz_sampled.reshape(1, n, 1, 1, 3)
        # align_corners=False is grid_sample's default; spelled out to match
        # the voxel-centre convention of get_per_voxel_length.
        σ = F.grid_sample(self.density, xyz_sampled, align_corners=False).view(n)
        # We notice that DreamFusion also uses an exp scaling on densities.
        # The technique here is developed BEFORE DreamFusion came out,
        # and forms part of our upcoming technical report discussing invariant
        # scaling for volume rendering. The research was presented to our
        # funding agency (TRI) on Aug. 25th, and discussed with a few researcher friends
        # during the period.
        σ = σ * torch.exp(self.d_scale)
        σ = F.softplus(σ + self.density_shift)
        return σ

    def compute_app_feats(self, xyz_sampled):
        """Sample the color volume at `n` points ([n, 3]); returns [n, c]."""
        xyz_sampled = to_grid_samp_coords(xyz_sampled, self.aabb)
        n = xyz_sampled.shape[0]
        xyz_sampled = xyz_sampled.reshape(1, n, 1, 1, 3)
        feats = F.grid_sample(self.color, xyz_sampled, align_corners=False).view(self.c, n)
        feats = feats.T
        return feats

    def compute_bg(self, uv):
        """Sample the background texture at `n` locations ([n, 2]); returns [n, c].

        `uv` is handed to `F.grid_sample` unchanged, so it has to be in that
        function's normalized coordinate range already.
        """
        n = uv.shape[0]
        uv = uv.reshape(1, n, 1, 2)
        feats = F.grid_sample(self.bg, uv, align_corners=False).view(self.c, n)
        feats = feats.T
        return feats

    def get_per_voxel_length(self):
        """Return the xyz edge lengths of one voxel as a 3-element tensor."""
        aabb_size = self.aabb[1] - self.aabb[0]
        # NOTE I am not -1 on grid_size here;
        # I interpret a voxel as a square and val sits at the center; like pixel
        # this is consistent with align_corners=False
        vox_xyz_length = aabb_size / self.grid_size
        return vox_xyz_length

    def get_num_samples(self, max_size=None):
        """Return `(num_samples, step_size)` for marching a ray.

        Args:
            max_size: length to cover, as a Python number or a 0-dim tensor.
                Defaults to the diagonal of the bounding box.

        Returns:
            num_samples: `int(max_size / step_size) + 1`.
            step_size: float, mean voxel edge length times `step_ratio`.
        """
        # funny way to set step size; whatever
        unit = torch.mean(self.get_per_voxel_length())
        step_size = unit * self.step_ratio
        step_size = step_size.item()  # get the float

        if max_size is None:
            aabb_size = self.aabb[1] - self.aabb[0]
            max_size = torch.norm(aabb_size)

        ratio = max_size / step_size
        if torch.is_tensor(ratio):
            # tensor path; a plain number has no .item() and is used directly
            ratio = ratio.item()
        num_samples = int(ratio) + 1
        return num_samples, step_size

    def resample(self, target_xyz: list):
        """Resize the density and color grids to `target_xyz` (xyz order).

        The grids are replaced by freshly created Parameters and the
        `grid_size` buffer is updated to the new resolution.
        """
        zyx = target_xyz[::-1]
        self.density = self._resamp_param(self.density, zyx)
        self.color = self._resamp_param(self.color, zyx)
        target_xyz = torch.LongTensor(target_xyz).to(self.aabb.device)
        add_non_state_tsr(self, "grid_size", target_xyz)

    @staticmethod
    def _resamp_param(param, target_size):
        """Return a new Parameter holding `param` trilinearly resized to `target_size`.

        Declared static because it takes no `self`; `resample` calls it
        through the instance.
        """
        return torch.nn.Parameter(F.interpolate(
            param.data, size=target_size, mode="trilinear", align_corners=False
        ))

    def compute_volume_alpha(self):
        """Evaluate opacity α at every voxel centre.

        Returns a contiguous tensor of shape [1, 1, z, y, x].
        """
        xyz = self.grid_size.tolist()
        unit_xyz = self.get_per_voxel_length()
        xs, ys, zs = torch.meshgrid(
            *[torch.arange(nd) for nd in xyz], indexing="ij"
        )
        pts = torch.stack([xs, ys, zs], dim=-1).to(unit_xyz.device)  # [nx, ny, nz, 3]
        # +0.5 moves from the voxel index to the voxel centre
        pts = self.aabb[0] + (pts + 0.5) * unit_xyz
        pts = pts.reshape(-1, 3)
        # could potentially filter with alpha mask itself if exists
        σ = self.compute_density_feats(pts)
        # opacity accumulated over one mean voxel edge length
        d = torch.mean(unit_xyz)
        α = 1 - torch.exp(-σ * d)
        α = rearrange(α.view(xyz), "x y z -> 1 1 z y x")
        α = α.contiguous()
        return α

    def make_alpha_mask(self):
        """Build a binary occupancy mask from the current density and store it in `alphaMask`."""
        # The thresholded mask carries no gradient, so skip building a graph.
        with torch.no_grad():
            α = self.compute_volume_alpha()
            ks = 3
            # stride-1 max-pool dilates occupied regions by one voxel
            α = F.max_pool3d(α, kernel_size=ks, padding=ks // 2, stride=1)
            α = (α > 0.08).float()
        vol_mask = AlphaMask(self.aabb, α)
        self.alphaMask = vol_mask

    def state_dict(self, *args, **kwargs):
        """Return the module state, plus an 'alpha_mask' entry when a mask exists."""
        state = super().state_dict(*args, **kwargs)
        if self.alphaMask is not None:
            state['alpha_mask'] = self.alphaMask.export_state()
        return state

    def load_state_dict(self, state_dict, strict=True):
        """Load module state, restoring the alpha mask if one was saved.

        The 'alpha_mask' entry, when present, is popped from `state_dict`
        (the caller's mapping is modified in place) before the remaining
        entries are handed to `nn.Module.load_state_dict`.
        """
        if 'alpha_mask' in state_dict:
            state = state_dict.pop("alpha_mask")
            self.alphaMask = AlphaMask.from_state(state)
        return super().load_state_dict(state_dict, strict=strict)
class V_SJC(VoxRF):
    """
    VoxRF variant for SJC: colors are rendered in [-1, 1] and per-parameter
    optimizer groups are provided.

    NOTE(review): the earlier description mentioned adding a gaussian ball
    offset when sampling density σ; no such offset is applied in this class
    body.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # score models all operate on centered images, so render in [-1, 1]
        self.feats2color = lambda x: torch.sigmoid(x) * 2 - 1

    def opt_params(self):
        """Return one optimizer param group per named parameter.

        Only "bg" carries an explicit learning rate (0.0001); every other
        group has no "lr" key.
        """
        param_groups = []
        for pname, tensor in self.named_parameters():
            entry = {"params": tensor}
            if pname == "bg":
                entry["lr"] = 0.0001
            param_groups.append(entry)
        return param_groups

    def annealed_opt_params(self, base_lr, σ):
        """Return one optimizer param group per named parameter, with lr set.

        "d_scale" gets lr 0 and "bg" a fixed 0.01; every other parameter
        (including "density" and "color") gets `base_lr * σ`.
        """
        annealed_lr = base_lr * σ
        fixed_lrs = {"d_scale": 0., "bg": 0.01}
        return [
            {"params": tensor, "lr": fixed_lrs.get(pname, annealed_lr)}
            for pname, tensor in self.named_parameters()
        ]
class V_SD(V_SJC):
    """V_SJC variant that renders raw features, with no squashing applied."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # rendering happens in feature space, so features pass through as-is
        self.feats2color = lambda x: x
class AlphaMask(nn.Module):
    """Occupancy volume that can be sampled at points and packed for saving.

    The mask, its resolution and its bounding box are kept as non-persistent
    buffers; serialization goes through `export_state` / `from_state`.
    """

    def __init__(self, aabb, alphas):
        """Store the mask.

        Args:
            aabb: bounding box tensor, registered as the `aabb` buffer.
            alphas: tensor whose last three dims are (z, y, x); it is viewed
                as [1, 1, z, y, x].
        """
        super().__init__()
        zyx = list(alphas.shape[-3:])
        add_non_state_tsr(self, "alphas", alphas.view(1, 1, *zyx))
        xyz = zyx[::-1]
        add_non_state_tsr(self, "grid_size", torch.LongTensor(xyz))
        add_non_state_tsr(self, "aabb", aabb)

    def sample_alpha(self, xyz_pts):
        """Sample the mask at points with a trailing dim of 3; returns a flat tensor."""
        xyz_pts = to_grid_samp_coords(xyz_pts, self.aabb)
        xyz_pts = xyz_pts.view(1, -1, 1, 1, 3)
        # align_corners=False is grid_sample's default, spelled out here
        α = F.grid_sample(self.alphas, xyz_pts, align_corners=False).view(-1)
        return α

    def export_state(self):
        """Return a dict with the bit-packed mask, its shape and the bounding box.

        Keys: 'shape' (shape of the mask array), 'mask' (flattened boolean
        mask packed with `np.packbits`), 'aabb' (bounding box on the CPU).
        """
        state = {}
        alphas = self.alphas.bool().cpu().numpy()
        state['shape'] = alphas.shape
        state['mask'] = np.packbits(alphas.reshape(-1))
        state['aabb'] = self.aabb.cpu()
        return state

    @classmethod
    def from_state(cls, state):
        """Rebuild a mask from a dict produced by `export_state`.

        Declared as a classmethod so it can be called on the class itself,
        e.g. `AlphaMask.from_state(state)`.
        """
        shape = state['shape']
        mask = state['mask']
        aabb = state['aabb']

        length = np.prod(shape)
        # packbits pads to a whole byte; drop the padding before reshaping
        alphas = torch.from_numpy(
            np.unpackbits(mask)[:length].reshape(shape)
        )
        amask = cls(aabb, alphas.float())
        return amask
def test():
    """Smoke test: build a small VoxRF, move it to a device and print its shapes."""
    # Prefer the second GPU as before, but fall back so the script also runs
    # on machines with one GPU or none.
    if torch.cuda.device_count() > 1:
        device = torch.device("cuda:1")
    elif torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    aabb = 1.5 * np.array([
        [-1, -1, -1],
        [1, 1, 1]
    ])
    model = VoxRF(aabb, [10, 20, 30])
    model.to(device)
    print(model.density.shape)
    print(model.grid_size)


if __name__ == "__main__":
    test()