feat: updated activation checkpointing (#14)
- wrap every layer in a checkpoint (e0da4c55e7a599407614621df650326c11cafd2f)
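For context, the pattern this commit applies, reduced to a minimal standalone sketch. TinyEncoder and its layers are illustrative stand-ins, not classes from modeling_bert.py; only the checkpoint-wrapping shape matches the diff below.

import torch
from torch.utils.checkpoint import checkpoint

class TinyEncoder(torch.nn.Module):
    # Illustrative stand-in for an encoder: a stack of simple feed-forward "layers".
    def __init__(self, dim=64, depth=4):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.GELU())
            for _ in range(depth)
        )
        self._grad_checkpointing = False

    def forward(self, hidden_states):
        for layer in self.layers:
            if self._grad_checkpointing:
                # Drop this layer's activations after the forward pass and recompute
                # them during backward; use_reentrant=False selects the non-reentrant
                # implementation, which also works when the input does not require grad.
                hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states

enc = TinyEncoder()
enc._grad_checkpointing = True
out = enc(torch.randn(2, 16, 64))
out.sum().backward()  # parameter grads are computed with per-layer recomputation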
modeling_bert.py  +38 -7
@@ -81,7 +81,8 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
         return_residual=return_residual,
         use_alibi=True,
         window_size=window_size,
-        qk_norm=use_qk_norm
+        qk_norm=use_qk_norm,
+        checkpointing=False,
     )
     return mixer_cls

@@ -174,8 +175,6 @@ class BertEncoder(nn.Module):
     @gradient_checkpointing.setter
     def gradient_checkpointing(self, value):
         self._grad_checkpointing = value
-        for block in self.layers:
-            block.mixer.checkpointing = value

     def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
         """If subset_mask is not None, we only want output for the subset of the sequence.
@@ -187,7 +186,15 @@ class BertEncoder(nn.Module):
                 {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
             )
             for layer in self.layers:
-                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                if self._grad_checkpointing:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        layer,
+                        hidden_states,
+                        use_reentrant=False,
+                        mixer_kwargs=mixer_kwargs
+                    )
+                else:
+                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
             if subset_mask is not None:
                 hidden_states = hidden_states[subset_mask]
         else:
@@ -198,11 +205,27 @@ class BertEncoder(nn.Module):
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
                 for layer in self.layers:
-                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                    if self._grad_checkpointing:
+                        hidden_states = torch.utils.checkpoint.checkpoint(
+                            layer,
+                            hidden_states,
+                            use_reentrant=False,
+                            mixer_kwargs=mixer_kwargs
+                        )
+                    else:
+                        hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 hidden_states = pad_input(hidden_states, indices, batch, seqlen)
             else:
                 for layer in self.layers[:-1]:
-                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                    if self._grad_checkpointing:
+                        hidden_states = torch.utils.checkpoint.checkpoint(
+                            layer,
+                            hidden_states,
+                            use_reentrant=False,
+                            mixer_kwargs=mixer_kwargs
+                        )
+                    else:
+                        hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 if key_padding_mask is not None:
                     subset_idx = torch.nonzero(
                         subset_mask[key_padding_mask], as_tuple=False
@@ -228,7 +251,15 @@ class BertEncoder(nn.Module):
                     "cu_seqlens_k": cu_seqlens,
                     "max_seqlen_k": max_seqlen_in_batch,
                 }
-                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
+                if self._grad_checkpointing:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        self.layers[-1],
+                        hidden_states_subset,
+                        use_reentrant=False,
+                        mixer_kwargs=mixer_kwargs
+                    )
+                else:
+                    hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
         return hidden_states

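A note on the use_reentrant=False choice: the wrapped layers above are called with mixer_kwargs, and to my understanding only the non-reentrant checkpoint implementation forwards extra keyword arguments to the wrapped callable (the reentrant variant rejects them). A minimal sketch, with block as a hypothetical stand-in for a layer, not code from this repo:

import torch
from torch.utils.checkpoint import checkpoint

def block(hidden_states, mixer_kwargs=None):
    # Hypothetical stand-in for a transformer block that takes extra keyword arguments.
    scale = (mixer_kwargs or {}).get("scale", 1.0)
    return torch.nn.functional.gelu(hidden_states) * scale

x = torch.randn(2, 8, requires_grad=True)
# mixer_kwargs is passed straight through to block; with use_reentrant=True this call
# would fail instead of forwarding the keyword argument.
y = checkpoint(block, x, use_reentrant=False, mixer_kwargs={"scale": 0.5})
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 8])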