Upload 4 files
- .gitattributes +1 -0
- location_encoder.py +158 -0
- modeling_closp.py +202 -0
- positional_encoding.py +110 -0
- spherical_armonics.py +3 -0
 
    	
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+spherical_armonics.py filter=lfs diff=lfs merge=lfs -text
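The added rule routes spherical_armonics.py through the Git LFS clean/smudge filter, so only a small pointer file is committed in its place while the actual ~11 MB payload lives in LFS storage. This is why spherical_armonics.py appears at the end of this commit as a three-line pointer rather than Python source.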
    	
location_encoder.py ADDED
@@ -0,0 +1,158 @@
+# Copyright (c) Microsoft Corporation.
+
+import math
+
+import torch
+from einops import rearrange
+from torch import nn
+from torch.nn import functional as F
+
+from .positional_encoding import SphericalHarmonics
+
+
+class LocationEncoder(nn.Module):
+    def __init__(
+        self,
+        dim_hidden: int,
+        num_layers: int,
+        dim_out: int,
+        legendre_polys: int = 10,
+    ):
+        super().__init__()
+        self.posenc = SphericalHarmonics(legendre_polys=legendre_polys)
+        self.nnet = SirenNet(
+            dim_in=self.posenc.embedding_dim,
+            dim_hidden=dim_hidden,
+            num_layers=num_layers,
+            dim_out=dim_out,
+        )
+
+    def forward(self, x):
+        x = self.posenc(x)
+        return self.nnet(x)
+
+
+class SirenNet(nn.Module):
+    """Sinusoidal Representation Network (SIREN)"""
+
+    def __init__(
+        self,
+        dim_in,
+        dim_hidden,
+        dim_out,
+        num_layers,
+        w0=1.0,
+        w0_initial=30.0,
+        use_bias=True,
+        final_activation=None,
+        degreeinput=False,
+        dropout=True,
+    ):
+        super().__init__()
+        self.num_layers = num_layers
+        self.dim_hidden = dim_hidden
+        self.degreeinput = degreeinput
+
+        self.layers = nn.ModuleList([])
+        for ind in range(num_layers):
+            is_first = ind == 0
+            layer_w0 = w0_initial if is_first else w0
+            layer_dim_in = dim_in if is_first else dim_hidden
+
+            self.layers.append(
+                Siren(
+                    dim_in=layer_dim_in,
+                    dim_out=dim_hidden,
+                    w0=layer_w0,
+                    use_bias=use_bias,
+                    is_first=is_first,
+                    dropout=dropout,
+                )
+            )
+
+        final_activation = (
+            nn.Identity() if not exists(final_activation) else final_activation
+        )
+        self.last_layer = Siren(
+            dim_in=dim_hidden,
+            dim_out=dim_out,
+            w0=w0,
+            use_bias=use_bias,
+            activation=final_activation,
+            dropout=False,
+        )
+
+    def forward(self, x, mods=None):
+        # do some normalization to bring degrees in a -pi to pi range
+        if self.degreeinput:
+            x = torch.deg2rad(x) - torch.pi
+
+        mods = cast_tuple(mods, self.num_layers)
+
+        for layer, mod in zip(self.layers, mods):
+            x = layer(x)
+
+            if exists(mod):
+                x *= rearrange(mod, "d -> () d")
+
+        return self.last_layer(x)
+
+
+class Sine(nn.Module):
+    def __init__(self, w0=1.0):
+        super().__init__()
+        self.w0 = w0
+
+    def forward(self, x):
+        return torch.sin(self.w0 * x)
+
+
+class Siren(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_out,
+        w0=1.0,
+        c=6.0,
+        is_first=False,
+        use_bias=True,
+        activation=None,
+        dropout=False,
+    ):
+        super().__init__()
+        self.dim_in = dim_in
+        self.is_first = is_first
+        self.dim_out = dim_out
+        self.dropout = dropout
+
+        weight = torch.zeros(dim_out, dim_in)
+        bias = torch.zeros(dim_out) if use_bias else None
+        self.init_(weight, bias, c=c, w0=w0)
+
+        self.weight = nn.Parameter(weight)
+        self.bias = nn.Parameter(bias) if use_bias else None
+        self.activation = Sine(w0) if activation is None else activation
+
+    def init_(self, weight, bias, c, w0):
+        dim = self.dim_in
+
+        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
+        weight.uniform_(-w_std, w_std)
+
+        if exists(bias):
+            bias.uniform_(-w_std, w_std)
+
+    def forward(self, x):
+        out = F.linear(x, self.weight, self.bias)
+        if self.dropout:
+            out = F.dropout(out, training=self.training)
+        out = self.activation(out)
+        return out
+
+
+def exists(val):
+    return val is not None
+
+
+def cast_tuple(val, repeat=1):
+    return val if isinstance(val, tuple) else ((val,) * repeat)
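A minimal usage sketch for the location encoder (an illustration, not part of the commit). It assumes the files above are importable as a package, here hypothetically named closp, and that inputs are (lon, lat) pairs in degrees, which is what SphericalHarmonics expects:

import torch

from closp.location_encoder import LocationEncoder  # "closp" is a hypothetical package name

# 512 hidden units, 2 SIREN layers, 256-dim embeddings, 10 Legendre polynomials
encoder = LocationEncoder(dim_hidden=512, num_layers=2, dim_out=256, legendre_polys=10).eval()

# A batch of (lon, lat) pairs in degrees: Rome, New York, Tokyo, Paris
coords = torch.tensor([[12.49, 41.89], [-74.01, 40.71], [139.69, 35.68], [2.35, 48.86]])

with torch.no_grad():
    embeddings = encoder(coords)  # spherical-harmonics features -> SIREN MLP
print(embeddings.shape)  # torch.Size([4, 256])

Calling eval() here is deliberate: SirenNet constructs its hidden Siren layers with dropout=True, so embeddings are stochastic in training mode.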
    	
modeling_closp.py ADDED
@@ -0,0 +1,202 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm import create_model
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoTokenizer,
+    PretrainedConfig,
+    PreTrainedModel,
+)
+from transformers.utils import ModelOutput
+
+from .location_encoder import LocationEncoder
+
+
+class CLOSPConfig(PretrainedConfig):
+    """
+    Configuration class for CLOSPModel.
+
+    This class stores the configuration of a CLOSPModel, which is used to instantiate the model
+    according to the specified parameters.
+    """
+
+    model_type = "closp"
+
+    def __init__(
+        self,
+        # Vision model parameters
+        vision_model_key: str = "vit-s",
+        s1_embedding_dim: int = 384,
+        s2_embedding_dim: int = 384,
+        s1_head_dim: int = 0,
+        s2_head_dim: int = 0,
+        # Text model parameters
+        text_model_name_or_path: str = "distilbert-base-uncased",
+        # Location encoder parameters (optional)
+        use_location_encoder: bool = True,
+        location_embedding_dim: int = 512,
+        # General model parameters
+        projection_dim: int = 768,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.vision_model_key = vision_model_key
+        self.s1_embedding_dim = s1_embedding_dim
+        self.s2_embedding_dim = s2_embedding_dim
+        self.text_model_name_or_path = text_model_name_or_path
+        self.use_location_encoder = use_location_encoder
+        self.location_embedding_dim = location_embedding_dim
+        self.projection_dim = projection_dim
+        self.s1_head_dim = s1_head_dim
+        self.s2_head_dim = s2_head_dim
+
+
+# --- Structured Model Output ---
+@dataclass
+class CLOSPOutput(ModelOutput):
+    """
+    Base class for CLOSP model's outputs.
+    """
+
+    loss: torch.FloatTensor = None
+    logits_per_image: torch.FloatTensor = None
+    logits_per_text: torch.FloatTensor = None
+    logits_per_loc_img: torch.FloatTensor = None
+    logits_per_img_loc: torch.FloatTensor = None
+    image_embeds: torch.FloatTensor = None
+    text_embeds: torch.FloatTensor = None
+    location_embeds: torch.FloatTensor = None
+
+
+class CLOSPModel(PreTrainedModel):
+    config_class = CLOSPConfig
+
+    def __init__(self, config: CLOSPConfig):
+        super().__init__(config)
+        # --- Vision Encoders ---
+        self.s1_encoder = create_model(
+            config.vision_model_key,
+            in_chans=2,
+            num_classes=config.s1_head_dim,
+            pretrained=False,
+        )
+        self.s2_encoder = create_model(
+            config.vision_model_key,
+            in_chans=13,
+            num_classes=config.s2_head_dim,
+            pretrained=False,
+        )
+        self.s1_projection = nn.Linear(config.s1_embedding_dim, config.projection_dim)
+        self.s2_projection = nn.Linear(config.s2_embedding_dim, config.projection_dim)
+
+        # --- Text Encoder ---
+        self.text_model = AutoModel.from_config(
+            AutoConfig.from_pretrained(config.text_model_name_or_path)
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)
+
+        # --- Location Encoder ---
+        if config.use_location_encoder:
+            self.location_encoder = LocationEncoder(512, 2, 256, 10)
+            self.location_projection = nn.Linear(
+                config.location_embedding_dim, config.projection_dim
+            )
+
+    def tokenize_text(self, text: str):
+        """Tokenizes input text using the model's tokenizer."""
+        return self.tokenizer(
+            text,
+            padding="max_length",
+            truncation=True,
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+
+    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
+        """Encodes an image tensor into features."""
+        image = image.float()
+        if image.shape[1] == 2:  # Sentinel-1
+            image_features = self.s1_projection(self.s1_encoder(image))
+        else:  # Sentinel-2
+            image_features = self.s2_projection(self.s2_encoder(image))
+
+        return F.normalize(image_features, p=2, dim=-1)
+
+    def get_text_features(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Encodes text tokens into features."""
+        text_outputs = self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_hidden_states=True,
+        )
+        text_features = text_outputs.last_hidden_state[:, 0, :]
+        return F.normalize(text_features, p=2, dim=-1)
+
+    def get_location_features(self, coords: torch.Tensor) -> torch.Tensor:
+        """Encodes coordinates into features."""
+        if not self.config.use_location_encoder:
+            raise ValueError(
+                "Location encoder is not enabled for this model. Set `use_location_encoder=True` in config."
+            )
+        location_features = self.location_encoder(coords)
+        location_features = self.location_projection(location_features)
+        return F.normalize(location_features, p=2, dim=-1)
+
+    def forward(
+        self,
+        image: torch.Tensor,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        coords: torch.Tensor = None,
+        return_loss: bool = False,
+    ) -> CLOSPOutput:
+        image_embeds = self.get_image_features(image)
+        text_embeds = self.get_text_features(input_ids, attention_mask)
+
+        # Cosine similarity as logits
+        logits_per_image = image_embeds @ text_embeds.T
+        logits_per_text = logits_per_image.T
+
+        # --- Optional Location Logic ---
+        location_embeds = None
+        logits_per_loc_img = None
+        logits_per_img_loc = None
+
+        if self.config.use_location_encoder:
+            if coords is None:
+                raise ValueError(
+                    "Coordinates must be provided when use_location_encoder is True."
+                )
+            location_embeds = self.get_location_features(coords)
+            logits_per_loc_img = location_embeds @ image_embeds.T
+            logits_per_img_loc = image_embeds @ location_embeds.T
+
+        # --- Optional Loss Calculation ---
+        loss = None
+        if return_loss:
+            outputs = [
+                logits_per_image,
+                logits_per_text,
+                logits_per_loc_img,
+                logits_per_img_loc,
+            ]
+            ground_truth = torch.arange(len(input_ids)).to(self.device)
+            loss = [F.cross_entropy(o, ground_truth) for o in outputs if o is not None]
+            loss = sum(loss) / len(loss)
+
+        return CLOSPOutput(
+            loss=loss,
+            logits_per_image=logits_per_image,
+            logits_per_text=logits_per_text,
+            logits_per_loc_img=logits_per_loc_img,
+            logits_per_img_loc=logits_per_img_loc,
+            image_embeds=image_embeds,
+            text_embeds=text_embeds,
+            location_embeds=location_embeds,
+        )
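A hedged end-to-end sketch of CLOSPModel (again not part of the commit, and assuming the hypothetical closp package). Two substitutions are assumptions on my part: the config default vision_model_key="vit-s" is not a name that timm's create_model resolves, so a concrete timm key is used instead, and location_embedding_dim is set to 256 so location_projection matches the dim_out=256 hardcoded into LocationEncoder above:

import torch

from closp.modeling_closp import CLOSPConfig, CLOSPModel  # hypothetical package name

config = CLOSPConfig(
    vision_model_key="vit_small_patch16_224",  # assumption: a real timm key in place of "vit-s"
    location_embedding_dim=256,  # must equal LocationEncoder's hardcoded dim_out=256
)
model = CLOSPModel(config).eval()  # fetches the distilbert-base-uncased tokenizer/config on first use

images = torch.randn(2, 13, 224, 224)  # 13 channels -> routed to the Sentinel-2 encoder
tokens = model.tokenize_text(["a coastal city", "a dense forest"])
coords = torch.tensor([[12.49, 41.89], [2.35, 48.86]])  # (lon, lat) in degrees

with torch.no_grad():
    out = model(
        images,
        tokens["input_ids"],
        tokens["attention_mask"],
        coords=coords,
        return_loss=True,
    )
print(out.logits_per_image.shape)  # torch.Size([2, 2])
print(out.loss)  # mean of the CLIP-style cross-entropies over the non-None logit matrices

With the defaults s1_head_dim=0 and s2_head_dim=0, timm strips the classification head and the ViT-Small encoders emit 384-dim pooled features, matching s1_embedding_dim and s2_embedding_dim.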
    	
positional_encoding.py ADDED
@@ -0,0 +1,110 @@
+# Copyright (c) Microsoft Corporation.
+
+import math
+
+import torch
+from torch import nn
+
+from .spherical_armonics import SH as SH_analytic
+
+
+class SphericalHarmonics(nn.Module):
+    """
+    Spherical harmonics location encoder
+    """
+
+    def __init__(self, legendre_polys: int = 10, harmonics_calculation="analytic"):
+        """
+        legendre_polys: determines the number of Legendre polynomials;
+                        more polynomials lead to more fine-grained resolutions
+        calculation of spherical harmonics:
+            analytic uses pre-computed equations. This is exact, but works only up to degree 50;
+            closed-form uses one equation but is computationally slower (especially for high degrees)
+        """
+        super(SphericalHarmonics, self).__init__()
+        self.L, self.M = int(legendre_polys), int(legendre_polys)
+        self.embedding_dim = self.L * self.M
+
+        if harmonics_calculation == "closed-form":
+            self.SH = SH_closed_form
+        elif harmonics_calculation == "analytic":
+            self.SH = SH_analytic
+
+    def forward(self, lonlat):
+        lon, lat = lonlat[:, 0], lonlat[:, 1]
+
+        # convert degrees to radians
+        phi = torch.deg2rad(lon + 180)
+        theta = torch.deg2rad(lat + 90)
+        """
+        greater_than_50 = (lon > 50).any() or (lat > 50).any()
+        if greater_than_50:
+            SH = SH_closed_form
+        else:
+            SH = SH_analytic
+        """
+        SH = self.SH
+
+        Y = []
+        for l in range(self.L):
+            for m in range(-l, l + 1):
+                y = SH(m, l, phi, theta)
+                if isinstance(y, float):
+                    y = y * torch.ones_like(phi)
+                if y.isnan().any():
+                    print(m, l, y)
+                Y.append(y)
+
+        return torch.stack(Y, dim=-1)
+
+
+####################### Spherical Harmonics utilities ########################
+# Code copied from https://github.com/BachiLi/redner/blob/master/pyredner/utils.py
+# Code adapted from "Spherical Harmonic Lighting: The Gritty Details", Robin Green
+# http://silviojemma.com/public/papers/lighting/spherical-harmonic-lighting.pdf
+def associated_legendre_polynomial(l, m, x):
+    pmm = torch.ones_like(x)
+    if m > 0:
+        somx2 = torch.sqrt((1 - x) * (1 + x))
+        fact = 1.0
+        for i in range(1, m + 1):
+            pmm = pmm * (-fact) * somx2
+            fact += 2.0
+    if l == m:
+        return pmm
+    pmmp1 = x * (2.0 * m + 1.0) * pmm
+    if l == m + 1:
+        return pmmp1
+    pll = torch.zeros_like(x)
+    for ll in range(m + 2, l + 1):
+        pll = ((2.0 * ll - 1.0) * x * pmmp1 - (ll + m - 1.0) * pmm) / (ll - m)
+        pmm = pmmp1
+        pmmp1 = pll
+    return pll
+
+
+def SH_renormalization(l, m):
+    return math.sqrt(
+        (2.0 * l + 1.0) * math.factorial(l - m) / (4 * math.pi * math.factorial(l + m))
+    )
+
+
+def SH_closed_form(m, l, phi, theta):
+    if m == 0:
+        return SH_renormalization(l, m) * associated_legendre_polynomial(
+            l, m, torch.cos(theta)
+        )
+    elif m > 0:
+        return (
+            math.sqrt(2.0)
+            * SH_renormalization(l, m)
+            * torch.cos(m * phi)
+            * associated_legendre_polynomial(l, m, torch.cos(theta))
+        )
+    else:
+        return (
+            math.sqrt(2.0)
+            * SH_renormalization(l, -m)
+            * torch.sin(-m * phi)
+            * associated_legendre_polynomial(l, -m, torch.cos(theta))
+        )
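A quick sanity check of the closed-form path (a sketch, run with the definitions above in scope). The l = 0, m = 0 harmonic is the constant 1/sqrt(4*pi), and the encoder stacks every (l, m) pair with l < legendre_polys, i.e. the sum of (2l + 1) for l = 0..L-1 = L**2 values, which is why embedding_dim = L * M:

import math

import torch

# Y_0^0 is constant on the sphere: 1 / sqrt(4 * pi)
phi = torch.tensor([0.0, 1.5])    # azimuth in radians
theta = torch.tensor([0.5, 2.0])  # polar angle in radians
y00 = SH_closed_form(0, 0, phi, theta)
assert torch.allclose(y00, torch.full_like(phi, 1.0 / math.sqrt(4.0 * math.pi)))

# Full encoding of two (lon, lat) pairs in degrees
enc = SphericalHarmonics(legendre_polys=10, harmonics_calculation="closed-form")
lonlat = torch.tensor([[12.49, 41.89], [-74.01, 40.71]])
print(enc(lonlat).shape)  # torch.Size([2, 100]) == (batch, 10 * 10)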
    	
spherical_armonics.py ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1fc4e9b49abb4e81411376fc6d09b1281aa8ed96cef64b7aa95cc4aeeccb97a4
+size 10994723
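These three lines are a Git LFS pointer, not the module itself: version names the pointer spec, oid is the SHA-256 of the real file, and size (10,994,723 bytes, roughly 11 MB) is its length. The pointer stands in for the pre-computed analytic spherical-harmonics module that positional_encoding.py imports as SH_analytic; cloning with git-lfs installed (or running git lfs pull afterwards) replaces it with the actual contents.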