Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| from spiga.models.gnn.layers import MLP | |
| from spiga.models.gnn.gat import GAT | |
class StepRegressor(nn.Module):
    """Stack of GAT layers followed by an MLP offset decoder.

    Args:
        input_dim: Feature size given to the first GAT layer.
        feature_dim: Feature size produced by every GAT layer and fed to the
            decoder.
        nstack: Number of stacked GAT layers; must be positive.
        decoding: Hidden layer sizes for the OffsetDecoder MLP. Any iterable
            of ints is accepted; the default is a tuple so it can never be
            mutated and shared between instances.

    Raises:
        ValueError: If ``nstack`` is not positive.
    """

    def __init__(self, input_dim: int, feature_dim: int, nstack=4, decoding=(256, 128, 64, 32)):
        super(StepRegressor, self).__init__()
        # Explicit check rather than `assert`, which `python -O` strips out.
        if nstack <= 0:
            raise ValueError(f"nstack must be positive, got {nstack}")
        self.nstack = nstack
        # First layer maps input_dim -> feature_dim; later layers keep feature_dim.
        # NOTE(review): the literal 4 is GAT's third argument, presumably the
        # number of attention heads - confirm against GAT's signature.
        self.gat = nn.ModuleList([GAT(input_dim, feature_dim, 4)])
        for _ in range(nstack - 1):
            self.gat.append(GAT(feature_dim, feature_dim, 4))
        # Pass a fresh list: OffsetDecoder builds its MLP spec from `decoding`,
        # and a copy keeps the caller's (or the default) sequence untouched.
        self.decoder = OffsetDecoder(feature_dim, list(decoding))

    def forward(self, embedded, prob_list=None):
        """Run the GAT stack, then decode offsets.

        Args:
            embedded: Tensor whose last two dims are swapped before the GAT
                stack and swapped back on the decoder output (assumed
                (batch, features, nodes) - TODO confirm with caller).
            prob_list: Optional list; the second output of each GAT layer is
                appended to it. A new list is created when omitted, so
                separate calls never share state.

        Returns:
            Tuple ``(offset, prob_list)``: the decoder output with its last two
            dims swapped back, and the list of per-layer GAT second outputs.
        """
        # A `[]` default would be one list shared by every call; create one per call.
        if prob_list is None:
            prob_list = []
        embedded = embedded.transpose(-1, -2)
        for i in range(self.nstack):
            embedded, prob = self.gat[i](embedded)
            prob_list.append(prob)
        offset = self.decoder(embedded)
        return offset.transpose(-1, -2), prob_list
class OffsetDecoder(nn.Module):
    """MLP head applied to the features produced by the GAT stack.

    Args:
        feature_dim: Size of the input features (first entry of the MLP spec).
        layers: Hidden layer sizes; any iterable of ints (list or tuple).
    """

    def __init__(self, feature_dim, layers):
        super().__init__()
        # MLP spec is [feature_dim, *hidden, 2]. The final 2 is presumably a
        # 2-D offset per node - confirm against MLP. Unpacking accepts tuples
        # as well as lists; the previous `+` concatenation only accepted lists.
        self.decoder = MLP([feature_dim, *layers, 2])

    def forward(self, embedded):
        """Apply the MLP to ``embedded`` and return its output."""
        return self.decoder(embedded)
class RelativePositionEncoder(nn.Module):
    """MLP encoder applied along the last-but-one axis of its input.

    Args:
        input_dim: Size of the input features (first entry of the MLP spec).
        feature_dim: Size of the encoded features (last entry of the MLP spec).
        layers: Hidden layer sizes; any iterable of ints (list or tuple).
    """

    def __init__(self, input_dim, feature_dim, layers):
        super().__init__()
        # Unpacking accepts tuples as well as lists; the previous `+`
        # concatenation only accepted lists.
        self.encoder = MLP([input_dim, *layers, feature_dim])

    def forward(self, feature):
        """Swap the last two dims, run the MLP, then swap them back.

        The swap presumably moves the feature axis last for the MLP (assumed
        input layout (batch, features, nodes) - TODO confirm with caller).
        """
        feature = feature.transpose(-1, -2)
        return self.encoder(feature).transpose(-1, -2)