| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import math | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						from einops import rearrange | 
					
					
						
						| 
							 | 
						from torch import nn | 
					
					
						
						| 
							 | 
						from torch.nn import functional as F | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from .positional_encoding import SphericalHarmonics | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class LocationEncoder(nn.Module): | 
					
					
						
						| 
							 | 
						    def __init__( | 
					
					
						
						| 
							 | 
						        self, | 
					
					
						
						| 
							 | 
						        dim_hidden: int, | 
					
					
						
						| 
							 | 
						        num_layers: int, | 
					
					
						
						| 
							 | 
						        dim_out: int, | 
					
					
						
						| 
							 | 
						        legendre_polys: int = 10, | 
					
					
						
						| 
							 | 
						    ): | 
					
					
						
						| 
							 | 
						        super().__init__() | 
					
					
						
						| 
							 | 
						        self.posenc = SphericalHarmonics(legendre_polys=legendre_polys) | 
					
					
						
						| 
							 | 
						        self.nnet = SirenNet( | 
					
					
						
						| 
							 | 
						            dim_in=self.posenc.embedding_dim, | 
					
					
						
						| 
							 | 
						            dim_hidden=dim_hidden, | 
					
					
						
						| 
							 | 
						            num_layers=num_layers, | 
					
					
						
						| 
							 | 
						            dim_out=dim_out, | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def forward(self, x): | 
					
					
						
						| 
							 | 
						        x = self.posenc(x) | 
					
					
						
						| 
							 | 
						        return self.nnet(x) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class SirenNet(nn.Module): | 
					
					
						
						| 
							 | 
						    """Sinusoidal Representation Network (SIREN)""" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def __init__( | 
					
					
						
						| 
							 | 
						        self, | 
					
					
						
						| 
							 | 
						        dim_in, | 
					
					
						
						| 
							 | 
						        dim_hidden, | 
					
					
						
						| 
							 | 
						        dim_out, | 
					
					
						
						| 
							 | 
						        num_layers, | 
					
					
						
						| 
							 | 
						        w0=1.0, | 
					
					
						
						| 
							 | 
						        w0_initial=30.0, | 
					
					
						
						| 
							 | 
						        use_bias=True, | 
					
					
						
						| 
							 | 
						        final_activation=None, | 
					
					
						
						| 
							 | 
						        degreeinput=False, | 
					
					
						
						| 
							 | 
						        dropout=True, | 
					
					
						
						| 
							 | 
						    ): | 
					
					
						
						| 
							 | 
						        super().__init__() | 
					
					
						
						| 
							 | 
						        self.num_layers = num_layers | 
					
					
						
						| 
							 | 
						        self.dim_hidden = dim_hidden | 
					
					
						
						| 
							 | 
						        self.degreeinput = degreeinput | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        self.layers = nn.ModuleList([]) | 
					
					
						
						| 
							 | 
						        for ind in range(num_layers): | 
					
					
						
						| 
							 | 
						            is_first = ind == 0 | 
					
					
						
						| 
							 | 
						            layer_w0 = w0_initial if is_first else w0 | 
					
					
						
						| 
							 | 
						            layer_dim_in = dim_in if is_first else dim_hidden | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            self.layers.append( | 
					
					
						
						| 
							 | 
						                Siren( | 
					
					
						
						| 
							 | 
						                    dim_in=layer_dim_in, | 
					
					
						
						| 
							 | 
						                    dim_out=dim_hidden, | 
					
					
						
						| 
							 | 
						                    w0=layer_w0, | 
					
					
						
						| 
							 | 
						                    use_bias=use_bias, | 
					
					
						
						| 
							 | 
						                    is_first=is_first, | 
					
					
						
						| 
							 | 
						                    dropout=dropout, | 
					
					
						
						| 
							 | 
						                ) | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        final_activation = ( | 
					
					
						
						| 
							 | 
						            nn.Identity() if not exists(final_activation) else final_activation | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						        self.last_layer = Siren( | 
					
					
						
						| 
							 | 
						            dim_in=dim_hidden, | 
					
					
						
						| 
							 | 
						            dim_out=dim_out, | 
					
					
						
						| 
							 | 
						            w0=w0, | 
					
					
						
						| 
							 | 
						            use_bias=use_bias, | 
					
					
						
						| 
							 | 
						            activation=final_activation, | 
					
					
						
						| 
							 | 
						            dropout=False, | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def forward(self, x, mods=None): | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        if self.degreeinput: | 
					
					
						
						| 
							 | 
						            x = torch.deg2rad(x) - torch.pi | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        mods = cast_tuple(mods, self.num_layers) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        for layer, mod in zip(self.layers, mods): | 
					
					
						
						| 
							 | 
						            x = layer(x) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            if exists(mod): | 
					
					
						
						| 
							 | 
						                x *= rearrange(mod, "d -> () d") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        return self.last_layer(x) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class Sine(nn.Module): | 
					
					
						
						| 
							 | 
						    def __init__(self, w0=1.0): | 
					
					
						
						| 
							 | 
						        super().__init__() | 
					
					
						
						| 
							 | 
						        self.w0 = w0 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def forward(self, x): | 
					
					
						
						| 
							 | 
						        return torch.sin(self.w0 * x) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class Siren(nn.Module): | 
					
					
						
						| 
							 | 
						    def __init__( | 
					
					
						
						| 
							 | 
						        self, | 
					
					
						
						| 
							 | 
						        dim_in, | 
					
					
						
						| 
							 | 
						        dim_out, | 
					
					
						
						| 
							 | 
						        w0=1.0, | 
					
					
						
						| 
							 | 
						        c=6.0, | 
					
					
						
						| 
							 | 
						        is_first=False, | 
					
					
						
						| 
							 | 
						        use_bias=True, | 
					
					
						
						| 
							 | 
						        activation=None, | 
					
					
						
						| 
							 | 
						        dropout=False, | 
					
					
						
						| 
							 | 
						    ): | 
					
					
						
						| 
							 | 
						        super().__init__() | 
					
					
						
						| 
							 | 
						        self.dim_in = dim_in | 
					
					
						
						| 
							 | 
						        self.is_first = is_first | 
					
					
						
						| 
							 | 
						        self.dim_out = dim_out | 
					
					
						
						| 
							 | 
						        self.dropout = dropout | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        weight = torch.zeros(dim_out, dim_in) | 
					
					
						
						| 
							 | 
						        bias = torch.zeros(dim_out) if use_bias else None | 
					
					
						
						| 
							 | 
						        self.init_(weight, bias, c=c, w0=w0) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        self.weight = nn.Parameter(weight) | 
					
					
						
						| 
							 | 
						        self.bias = nn.Parameter(bias) if use_bias else None | 
					
					
						
						| 
							 | 
						        self.activation = Sine(w0) if activation is None else activation | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def init_(self, weight, bias, c, w0): | 
					
					
						
						| 
							 | 
						        dim = self.dim_in | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0) | 
					
					
						
						| 
							 | 
						        weight.uniform_(-w_std, w_std) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if exists(bias): | 
					
					
						
						| 
							 | 
						            bias.uniform_(-w_std, w_std) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def forward(self, x): | 
					
					
						
						| 
							 | 
						        out = F.linear(x, self.weight, self.bias) | 
					
					
						
						| 
							 | 
						        if self.dropout: | 
					
					
						
						| 
							 | 
						            out = F.dropout(out, training=self.training) | 
					
					
						
						| 
							 | 
						        out = self.activation(out) | 
					
					
						
						| 
							 | 
						        return out | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def exists(val): | 
					
					
						
						| 
							 | 
						    return val is not None | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def cast_tuple(val, repeat=1): | 
					
					
						
						| 
							 | 
						    return val if isinstance(val, tuple) else ((val,) * repeat) | 
					
					
						
						| 
							 | 
						
 |