Added gradient checkpointing and fixed bugs
smi-ted/training/trainer.py (+36 -2)
@@ -2,12 +2,16 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
 from torch.utils.data import DataLoader
 from torch.nn.parallel import DistributedDataParallel as DDP
+from fast_transformers.masking import LengthMask
 
 # Standard library
 from tqdm import tqdm
 import pandas as pd
+import numpy as np
+import random
 import os
 
 
@@ -41,6 +45,7 @@ class Trainer:
         self.model = DDP(self.model, device_ids=[self.local_rank])
 
     def _load_checkpoint(self, checkpoint_path):
+        opt_dict = None
         loc = f"cuda:{self.local_rank}"
         ckpt_dict = torch.load(checkpoint_path, map_location=loc)
         if os.path.exists(os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt')):
@@ -262,6 +267,12 @@ class TrainerEncoderDecoder(Trainer):
         if self.local_rank == 0:
             loss_list.to_csv(os.path.join(self.config.save_checkpoint_path, f'training_loss_epoch{epoch}.csv'), index=False)
 
+    def custom(self, module):
+        def custom_forward(*inputs):
+            inputs = module(inputs[0])
+            return inputs
+        return custom_forward
+
     def _run_batch(self, batch_idx, bucket_idx_masked, bucket_targets, bucket_idx_not_masked):
         self.optimE.zero_grad(set_to_none=True)
         self.optimD.zero_grad(set_to_none=True)
@@ -292,7 +303,13 @@ class TrainerEncoderDecoder(Trainer):
             for param in self.model.module.decoder.parameters():
                 param.requires_grad = False
 
-
+            # encoder forward
+            x = self.model.module.encoder.tok_emb(idx_masked)
+            x = self.model.module.encoder.drop(x)
+            x = checkpoint.checkpoint(self.custom(self.model.module.encoder.blocks), x)
+            logits = self.model.module.encoder.lang_model(x)
+
+            # loss function
             logits = logits.view(-1, logits.size(-1))
             targets = targets.view(-1)
             errorE_tmp = self.criterionC(logits, targets) / len(bucket_idx_masked)
@@ -370,6 +387,12 @@ class TrainerDirectDecoder(Trainer):
         self.criterionC = nn.CrossEntropyLoss(ignore_index=-100)
         self.criterionR = nn.MSELoss()
 
+    def custom(self, module):
+        def custom_forward(*inputs):
+            inputs = module(inputs[0], length_mask=inputs[1])
+            return inputs
+        return custom_forward
+
     def _run_batch(self, bucket_idx_masked, bucket_targets, bucket_idx_not_masked):
         padding_idx = 2
         error = torch.zeros(1).to(self.local_rank)
@@ -385,7 +408,18 @@ class TrainerDirectDecoder(Trainer):
             mask = (idx_masked != padding_idx)
 
             # encoder forward
-
+            x = self.model.module.encoder.tok_emb(idx_masked)
+            x = self.model.module.encoder.drop(x)
+            x = checkpoint.checkpoint(self.custom(self.model.module.encoder.blocks), x, LengthMask(mask.sum(-1), max_len=idx_masked.shape[1]))
+
+            # mean pooling
+            input_masked_expanded = mask.unsqueeze(-1).expand(x.size()).float()
+            sum_embeddings = torch.sum(x*input_masked_expanded, 1)
+            sum_mask = torch.clamp(input_masked_expanded.sum(1), min=1e-9)
+            true_set = sum_embeddings/sum_mask
+            true_cte = x
+            del x
+            torch.cuda.empty_cache()
 
             # add padding
             input_mask_expanded = mask.unsqueeze(-1).expand(true_cte.size()).float()