| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
import math
import warnings

import torch
from torch import nn

from .spherical_armonics import SH as SH_analytic
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class SphericalHarmonics(nn.Module): | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    Spherical Harmonics locaiton encoder | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def __init__(self, legendre_polys: int = 10, harmonics_calculation="analytic"): | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        legendre_polys: determines the number of legendre polynomials. | 
					
					
						
						| 
							 | 
						                        more polynomials lead more fine-grained resolutions | 
					
					
						
						| 
							 | 
						        calculation of spherical harmonics: | 
					
					
						
						| 
							 | 
						            analytic uses pre-computed equations. This is exact, but works only up to degree 50, | 
					
					
						
						| 
							 | 
						            closed-form uses one equation but is computationally slower (especially for high degrees) | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        super(SphericalHarmonics, self).__init__() | 
					
					
						
						| 
							 | 
						        self.L, self.M = int(legendre_polys), int(legendre_polys) | 
					
					
						
						| 
							 | 
						        self.embedding_dim = self.L * self.M | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if harmonics_calculation == "closed-form": | 
					
					
						
						| 
							 | 
						            self.SH = SH_closed_form | 
					
					
						
						| 
							 | 
						        elif harmonics_calculation == "analytic": | 
					
					
						
						| 
							 | 
						            self.SH = SH_analytic | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def forward(self, lonlat): | 
					
					
						
						| 
							 | 
						        lon, lat = lonlat[:, 0], lonlat[:, 1] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        phi = torch.deg2rad(lon + 180) | 
					
					
						
						| 
							 | 
						        theta = torch.deg2rad(lat + 90) | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        greater_than_50 = (lon > 50).any() or (lat > 50).any() | 
					
					
						
						| 
							 | 
						        if greater_than_50: | 
					
					
						
						| 
							 | 
						            SH = SH_closed_form | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            SH = SH_analytic | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        SH = self.SH | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        Y = [] | 
					
					
						
						| 
							 | 
						        for l in range(self.L): | 
					
					
						
						| 
							 | 
						            for m in range(-l, l + 1): | 
					
					
						
						| 
							 | 
						                y = SH(m, l, phi, theta) | 
					
					
						
						| 
							 | 
						                if isinstance(y, float): | 
					
					
						
						| 
							 | 
						                    y = y * torch.ones_like(phi) | 
					
					
						
						| 
							 | 
						                if y.isnan().any(): | 
					
					
						
						| 
							 | 
						                    print(m, l, y) | 
					
					
						
						| 
							 | 
						                Y.append(y) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        return torch.stack(Y, dim=-1) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						def associated_legendre_polynomial(l, m, x): | 
					
					
						
						| 
							 | 
						    pmm = torch.ones_like(x) | 
					
					
						
						| 
							 | 
						    if m > 0: | 
					
					
						
						| 
							 | 
						        somx2 = torch.sqrt((1 - x) * (1 + x)) | 
					
					
						
						| 
							 | 
						        fact = 1.0 | 
					
					
						
						| 
							 | 
						        for i in range(1, m + 1): | 
					
					
						
						| 
							 | 
						            pmm = pmm * (-fact) * somx2 | 
					
					
						
						| 
							 | 
						            fact += 2.0 | 
					
					
						
						| 
							 | 
						    if l == m: | 
					
					
						
						| 
							 | 
						        return pmm | 
					
					
						
						| 
							 | 
						    pmmp1 = x * (2.0 * m + 1.0) * pmm | 
					
					
						
						| 
							 | 
						    if l == m + 1: | 
					
					
						
						| 
							 | 
						        return pmmp1 | 
					
					
						
						| 
							 | 
						    pll = torch.zeros_like(x) | 
					
					
						
						| 
							 | 
						    for ll in range(m + 2, l + 1): | 
					
					
						
						| 
							 | 
						        pll = ((2.0 * ll - 1.0) * x * pmmp1 - (ll + m - 1.0) * pmm) / (ll - m) | 
					
					
						
						| 
							 | 
						        pmm = pmmp1 | 
					
					
						
						| 
							 | 
						        pmmp1 = pll | 
					
					
						
						| 
							 | 
						    return pll | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def SH_renormalization(l, m): | 
					
					
						
						| 
							 | 
						    return math.sqrt( | 
					
					
						
						| 
							 | 
						        (2.0 * l + 1.0) * math.factorial(l - m) / (4 * math.pi * math.factorial(l + m)) | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def SH_closed_form(m, l, phi, theta): | 
					
					
						
						| 
							 | 
						    if m == 0: | 
					
					
						
						| 
							 | 
						        return SH_renormalization(l, m) * associated_legendre_polynomial( | 
					
					
						
						| 
							 | 
						            l, m, torch.cos(theta) | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						    elif m > 0: | 
					
					
						
						| 
							 | 
						        return ( | 
					
					
						
						| 
							 | 
						            math.sqrt(2.0) | 
					
					
						
						| 
							 | 
						            * SH_renormalization(l, m) | 
					
					
						
						| 
							 | 
						            * torch.cos(m * phi) | 
					
					
						
						| 
							 | 
						            * associated_legendre_polynomial(l, m, torch.cos(theta)) | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        return ( | 
					
					
						
						| 
							 | 
						            math.sqrt(2.0) | 
					
					
						
						| 
							 | 
						            * SH_renormalization(l, -m) | 
					
					
						
						| 
							 | 
						            * torch.sin(-m * phi) | 
					
					
						
						| 
							 | 
						            * associated_legendre_polynomial(l, -m, torch.cos(theta)) | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						
 |