fix: remove cleaving (#13)
- Revert "feat: cleave off layers from encoder (#11)" (99b812dc5c29ac777aa7b5164d3ec7399d520e06)
- modeling_bert.py +4 -23
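For reference, the reverted feature (#11) let callers drop trailing encoder layers by setting two properties on BertEncoder. Below is a minimal self-contained sketch of that arithmetic; the property bodies are copied from the removed lines in the diff, while the TinyEncoder wrapper itself is hypothetical, for illustration only:

import torch.nn as nn

class TinyEncoder(nn.Module):
    # Hypothetical stand-in for BertEncoder; only the reverted
    # last_layer_idx / cleaved_layers bookkeeping is shown.
    def __init__(self, num_hidden_layers: int = 12):
        super().__init__()
        self.layers = nn.ModuleList([nn.Identity() for _ in range(num_hidden_layers)])
        self._last_layer_idx = len(self.layers) - 1  # default: run every layer

    @property
    def last_layer_idx(self):
        return self._last_layer_idx

    @last_layer_idx.setter
    def last_layer_idx(self, idx: int):
        assert 0 <= idx < len(self.layers)
        self._last_layer_idx = idx

    @property
    def cleaved_layers(self):
        # number of trailing layers that forward() would skip
        return len(self.layers) - self.last_layer_idx - 1

    @cleaved_layers.setter
    def cleaved_layers(self, n: int):
        assert 0 <= n < len(self.layers)
        self.last_layer_idx = len(self.layers) - n - 1

enc = TinyEncoder(12)
enc.cleaved_layers = 2           # cleave off the last two layers
assert enc.last_layer_idx == 9   # 12 - 2 - 1

After this revert both properties are gone, so any caller that set cleaved_layers must slice self.layers itself.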
modeling_bert.py CHANGED
@@ -166,25 +166,6 @@ class BertEncoder(nn.Module):
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
         self._grad_checkpointing = False
-        self._last_layer_idx = len(self.layers) - 1
-
-    @property
-    def last_layer_idx(self):
-        return self._last_layer_idx
-
-    @last_layer_idx.setter
-    def last_layer_idx(self, idx: int):
-        assert 0 <= idx < len(self.layers)
-        self._last_layer_idx = idx
-
-    @property
-    def cleaved_layers(self):
-        return len(self.layers) - self.last_layer_idx - 1
-
-    @cleaved_layers.setter
-    def cleaved_layers(self, n: int):
-        assert 0 <= n < len(self.layers)
-        self.last_layer_idx = len(self.layers) - n - 1

     @property
     def gradient_checkpointing(self):
@@ -205,7 +186,7 @@ class BertEncoder(nn.Module):
             mixer_kwargs = (
                 {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
             )
-            for layer in self.layers[: self.last_layer_idx + 1]:
+            for layer in self.layers:
                 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
             if subset_mask is not None:
                 hidden_states = hidden_states[subset_mask]
@@ -216,11 +197,11 @@ class BertEncoder(nn.Module):
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
-                for layer in self.layers[: self.last_layer_idx + 1]:
+                for layer in self.layers:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 hidden_states = pad_input(hidden_states, indices, batch, seqlen)
             else:
-                for layer in self.layers[: self.last_layer_idx]:
+                for layer in self.layers[:-1]:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 if key_padding_mask is not None:
                     subset_idx = torch.nonzero(
@@ -247,7 +228,7 @@ class BertEncoder(nn.Module):
                     "cu_seqlens_k": cu_seqlens,
                     "max_seqlen_k": max_seqlen_in_batch,
                 }
-                hidden_states = self.layers[self.last_layer_idx](hidden_states_subset, mixer_kwargs=mixer_kwargs)
+                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
         return hidden_states
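The restored control flow in the subset_mask branch is the stock pattern: run every layer except the last over the full unpadded sequence, then run the final layer only on the selected subset (the real code also passes the full hidden states as keys/values via mixer_kwargs). A stripped-down sketch of that shape, ignoring unpadding and the key/value plumbing; encode_subset and the plain-callable layers are illustrative, not the repo's API:

import torch

def encode_subset(layers, hidden_states, subset_mask):
    # All layers except the last see the full (num_tokens, hidden) input.
    for layer in layers[:-1]:
        hidden_states = layer(hidden_states)
    # Only the rows picked out by subset_mask (e.g. masked-LM positions)
    # go through the final layer.
    hidden_states_subset = hidden_states[subset_mask]
    return layers[-1](hidden_states_subset)

layers = [torch.nn.Linear(8, 8) for _ in range(4)]
x = torch.randn(16, 8)                 # 16 unpadded tokens
mask = torch.zeros(16, dtype=torch.bool)
mask[:3] = True                        # score only 3 positions
out = encode_subset(layers, x, mask)   # shape: (3, 8)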