Upload modeling_moment.py
Browse files- modeling_moment.py +2 -1
modeling_moment.py
CHANGED
|
@@ -448,13 +448,14 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
|
|
| 448 |
attention_mask = patch_view_mask.repeat_interleave(n_channels, dim=0)
|
| 449 |
outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
|
| 450 |
enc_out = outputs.last_hidden_state
|
|
|
|
| 451 |
|
| 452 |
enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
|
| 453 |
# [batch_size x n_channels x n_patches x d_model]
|
| 454 |
|
| 455 |
# For Mists model
|
| 456 |
# [batch_size, n_channels x n_patches, d_model]
|
| 457 |
-
hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
|
| 458 |
|
| 459 |
if reduction == "mean":
|
| 460 |
enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
|
|
|
|
| 448 |
attention_mask = patch_view_mask.repeat_interleave(n_channels, dim=0)
|
| 449 |
outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
|
| 450 |
enc_out = outputs.last_hidden_state
|
| 451 |
+
hidden_states = outputs.hidden_states # hidden_statesを取得
|
| 452 |
|
| 453 |
enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
|
| 454 |
# [batch_size x n_channels x n_patches x d_model]
|
| 455 |
|
| 456 |
# For Mists model
|
| 457 |
# [batch_size, n_channels x n_patches, d_model]
|
| 458 |
+
# hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
|
| 459 |
|
| 460 |
if reduction == "mean":
|
| 461 |
enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
|