Uploading patch

modeling_gpt_bert.py  CHANGED  (+9 -10)
@@ -138,7 +138,7 @@ class Attention(nn.Module):
             - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
         position_indices: torch.Tensor = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
         position_indices = config.position_bucket_size - 1 + position_indices
-        self.register_buffer("position_indices", position_indices, persistent=
+        self.register_buffer("position_indices", position_indices, persistent=False)

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.scale: float = 1.0 / math.sqrt(3 * self.head_size)
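The only change in this hunk is on the register_buffer call: position_indices is now registered with persistent=False, so the precomputed relative-position index table stays on the module (and moves with .to(device)) but is no longer written into the state_dict (register_buffer defaults to persistent=True). A minimal sketch of that PyTorch behavior, using a hypothetical toy module rather than the Attention class itself:

import torch
import torch.nn as nn

class Toy(nn.Module):
    """Hypothetical stand-in for the buffer registration pattern above."""
    def __init__(self, size: int = 4):
        super().__init__()
        # Derived, recomputable tensor: kept on the module but not serialized.
        indices = torch.arange(size).unsqueeze(1) - torch.arange(size).unsqueeze(0)
        self.register_buffer("position_indices", indices, persistent=False)

toy = Toy()
print(toy.position_indices.shape)     # torch.Size([4, 4])
print(list(toy.state_dict().keys()))  # [] -- non-persistent buffers are not saved

Because the table is rebuilt in __init__ from the config, excluding it from checkpoints loses nothing on load and keeps the saved weights smaller.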
@@ -301,18 +301,17 @@ class GPTBERT(GPTBERTPreTrainedModel):
         batch_size, seq_length = input_shape

         if attention_mask is None:
-            attention_mask = input_ids.
-
-        if attention_mask is not None:
+            attention_mask = input_ids.new_zeros((batch_size, seq_length), dtype=torch.bool).unsqueeze(1).unsqueeze(2)
+        else:
             attention_mask = ~attention_mask.bool()

-
-
-
-
+            if len(attention_mask.size()) == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            elif len(attention_mask.size()) == 3:
+                attention_mask = attention_mask.unsqueeze(1)

-
-
+        if self.is_causal:
+            attention_mask = attention_mask | input_ids.new_ones((seq_length, seq_length), dtype=torch.bool).triu(1).unsqueeze(0).unsqueeze(0)

         static_embeddings, relative_embeddings = self.embedding(input_ids.t())
         contextualized_embeddings = [static_embeddings]
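The second hunk rewrites the attention-mask setup: when no mask is passed, an all-False boolean mask of shape [batch, 1, 1, seq] is built (nothing blocked); a user-supplied mask is inverted (the Hugging Face convention is 1 = attend, 0 = padding, while this code wants True = blocked) and broadcast to 4D; and when self.is_causal is set, a strictly upper-triangular matrix is OR-ed in so each position attends only to itself and earlier tokens. A standalone sketch of the same shape logic, with toy sizes and a made-up padding mask instead of real model inputs:

import torch

batch_size, seq_length = 2, 5
input_ids = torch.randint(0, 100, (batch_size, seq_length))
attention_mask = torch.tensor([[1, 1, 1, 1, 0],   # last position is padding
                               [1, 1, 1, 0, 0]])  # last two positions are padding
is_causal = True  # stand-in for self.is_causal

if attention_mask is None:
    # No padding information: an all-False mask blocks nothing.
    attention_mask = input_ids.new_zeros((batch_size, seq_length), dtype=torch.bool).unsqueeze(1).unsqueeze(2)
else:
    # Invert the 1/0 convention so that True means "do not attend here".
    attention_mask = ~attention_mask.bool()
    if len(attention_mask.size()) == 2:
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, S]
    elif len(attention_mask.size()) == 3:
        attention_mask = attention_mask.unsqueeze(1)               # [B, 1, S, S]

if is_causal:
    # Strict upper triangle marks future positions; OR combines it with padding.
    causal = input_ids.new_ones((seq_length, seq_length), dtype=torch.bool).triu(1)
    attention_mask = attention_mask | causal.unsqueeze(0).unsqueeze(0)

print(attention_mask.shape)        # torch.Size([2, 1, 5, 5])
print(attention_mask[0, 0].int())  # 1 where attention is blocked (padding or future)

The resulting boolean mask broadcasts against attention scores of shape [batch, heads, seq, seq], which is presumably how the model consumes it downstream (that part is not shown in this patch).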