Victor Shirasuna
committed on
Commit
·
8abbc76
1
Parent(s):
8c39e88
Added restart checkpoint in finetune
Browse files
smi-ted/finetune/args.py
CHANGED
|
@@ -304,6 +304,7 @@ def get_parser(parser=None):
|
|
| 304 |
# parser.add_argument("--patience_epochs", type=int, required=True)
|
| 305 |
parser.add_argument("--model_path", type=str, default="./smi_ted/")
|
| 306 |
parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt")
|
|
|
|
| 307 |
# parser.add_argument('--n_output', type=int, default=1)
|
| 308 |
parser.add_argument("--save_every_epoch", type=int, default=0)
|
| 309 |
parser.add_argument("--save_ckpt", type=int, default=1)
|
|
|
|
| 304 |
# parser.add_argument("--patience_epochs", type=int, required=True)
|
| 305 |
parser.add_argument("--model_path", type=str, default="./smi_ted/")
|
| 306 |
parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt")
|
| 307 |
+
parser.add_argument("--restart_filename", type=str, default="")
|
| 308 |
# parser.add_argument('--n_output', type=int, default=1)
|
| 309 |
parser.add_argument("--save_every_epoch", type=int, default=0)
|
| 310 |
parser.add_argument("--save_ckpt", type=int, default=1)
|
smi-ted/finetune/finetune_classification.py
CHANGED
|
@@ -48,6 +48,7 @@ def main(config):
|
|
| 48 |
seed=config.start_seed,
|
| 49 |
smi_ted_version=config.smi_ted_version,
|
| 50 |
checkpoints_folder=config.checkpoints_folder,
|
|
|
|
| 51 |
device=device,
|
| 52 |
save_every_epoch=bool(config.save_every_epoch),
|
| 53 |
save_ckpt=bool(config.save_ckpt)
|
|
|
|
| 48 |
seed=config.start_seed,
|
| 49 |
smi_ted_version=config.smi_ted_version,
|
| 50 |
checkpoints_folder=config.checkpoints_folder,
|
| 51 |
+
restart_filename=config.restart_filename,
|
| 52 |
device=device,
|
| 53 |
save_every_epoch=bool(config.save_every_epoch),
|
| 54 |
save_ckpt=bool(config.save_ckpt)
|
smi-ted/finetune/finetune_classification_multitask.py
CHANGED
|
@@ -81,6 +81,7 @@ def main(config):
|
|
| 81 |
seed=config.start_seed,
|
| 82 |
smi_ted_version=config.smi_ted_version,
|
| 83 |
checkpoints_folder=config.checkpoints_folder,
|
|
|
|
| 84 |
device=device,
|
| 85 |
save_every_epoch=bool(config.save_every_epoch),
|
| 86 |
save_ckpt=bool(config.save_ckpt)
|
|
|
|
| 81 |
seed=config.start_seed,
|
| 82 |
smi_ted_version=config.smi_ted_version,
|
| 83 |
checkpoints_folder=config.checkpoints_folder,
|
| 84 |
+
restart_filename=config.restart_filename,
|
| 85 |
device=device,
|
| 86 |
save_every_epoch=bool(config.save_every_epoch),
|
| 87 |
save_ckpt=bool(config.save_ckpt)
|
smi-ted/finetune/finetune_regression.py
CHANGED
|
@@ -50,6 +50,7 @@ def main(config):
|
|
| 50 |
seed=config.start_seed,
|
| 51 |
smi_ted_version=config.smi_ted_version,
|
| 52 |
checkpoints_folder=config.checkpoints_folder,
|
|
|
|
| 53 |
device=device,
|
| 54 |
save_every_epoch=bool(config.save_every_epoch),
|
| 55 |
save_ckpt=bool(config.save_ckpt)
|
|
|
|
| 50 |
seed=config.start_seed,
|
| 51 |
smi_ted_version=config.smi_ted_version,
|
| 52 |
checkpoints_folder=config.checkpoints_folder,
|
| 53 |
+
restart_filename=config.restart_filename,
|
| 54 |
device=device,
|
| 55 |
save_every_epoch=bool(config.save_every_epoch),
|
| 56 |
save_ckpt=bool(config.save_ckpt)
|
smi-ted/finetune/trainers.py
CHANGED
|
@@ -26,7 +26,7 @@ from utils import RMSE, sensitivity, specificity
|
|
| 26 |
class Trainer:
|
| 27 |
|
| 28 |
def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
|
| 29 |
-
target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
|
| 30 |
# data
|
| 31 |
self.df_train = raw_data[0]
|
| 32 |
self.df_valid = raw_data[1]
|
|
@@ -42,6 +42,8 @@ class Trainer:
|
|
| 42 |
self.seed = seed
|
| 43 |
self.smi_ted_version = smi_ted_version
|
| 44 |
self.checkpoints_folder = checkpoints_folder
|
|
|
|
|
|
|
| 45 |
self.save_every_epoch = save_every_epoch
|
| 46 |
self.save_ckpt = save_ckpt
|
| 47 |
self.device = device
|
|
@@ -82,11 +84,14 @@ class Trainer:
|
|
| 82 |
self.optimizer = optimizer
|
| 83 |
self.loss_fn = loss_fn
|
| 84 |
self._print_configuration()
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
def fit(self, max_epochs=500):
|
| 87 |
best_vloss = float('inf')
|
| 88 |
|
| 89 |
-
for epoch in range(1, max_epochs+1):
|
| 90 |
print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
|
| 91 |
|
| 92 |
# training
|
|
@@ -183,6 +188,7 @@ class Trainer:
|
|
| 183 |
ckpt_path = os.path.join(self.checkpoints_folder, filename)
|
| 184 |
ckpt_dict = torch.load(ckpt_path, map_location='cpu')
|
| 185 |
self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
|
|
|
|
| 186 |
|
| 187 |
def _save_checkpoint(self, current_epoch, filename):
|
| 188 |
if not os.path.exists(self.checkpoints_folder):
|
|
@@ -229,9 +235,9 @@ class Trainer:
|
|
| 229 |
class TrainerRegressor(Trainer):
|
| 230 |
|
| 231 |
def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
|
| 232 |
-
target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
|
| 233 |
super().__init__(raw_data, dataset_name, target, batch_size, hparams,
|
| 234 |
-
target_metric, seed, smi_ted_version, checkpoints_folder, save_every_epoch, save_ckpt, device)
|
| 235 |
|
| 236 |
def _train_one_epoch(self):
|
| 237 |
running_loss = 0.0
|
|
@@ -320,9 +326,9 @@ class TrainerRegressor(Trainer):
|
|
| 320 |
class TrainerClassifier(Trainer):
|
| 321 |
|
| 322 |
def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
|
| 323 |
-
target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
|
| 324 |
super().__init__(raw_data, dataset_name, target, batch_size, hparams,
|
| 325 |
-
target_metric, seed, smi_ted_version, checkpoints_folder, save_every_epoch, save_ckpt, device)
|
| 326 |
|
| 327 |
def _train_one_epoch(self):
|
| 328 |
running_loss = 0.0
|
|
@@ -427,9 +433,9 @@ class TrainerClassifier(Trainer):
|
|
| 427 |
class TrainerClassifierMultitask(Trainer):
|
| 428 |
|
| 429 |
def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
|
| 430 |
-
target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
|
| 431 |
super().__init__(raw_data, dataset_name, target, batch_size, hparams,
|
| 432 |
-
target_metric, seed, smi_ted_version, checkpoints_folder, save_every_epoch, save_ckpt, device)
|
| 433 |
|
| 434 |
def _prepare_data(self):
|
| 435 |
# normalize dataset
|
|
|
|
| 26 |
class Trainer:
|
| 27 |
|
| 28 |
def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
|
| 29 |
+
target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
|
| 30 |
# data
|
| 31 |
self.df_train = raw_data[0]
|
| 32 |
self.df_valid = raw_data[1]
|
|
|
|
| 42 |
self.seed = seed
|
| 43 |
self.smi_ted_version = smi_ted_version
|
| 44 |
self.checkpoints_folder = checkpoints_folder
|
| 45 |
+
self.restart_filename = restart_filename
|
| 46 |
+
self.start_epoch = 1
|
| 47 |
self.save_every_epoch = save_every_epoch
|
| 48 |
self.save_ckpt = save_ckpt
|
| 49 |
self.device = device
|
|
|
|
| 84 |
self.optimizer = optimizer
|
| 85 |
self.loss_fn = loss_fn
|
| 86 |
self._print_configuration()
|
| 87 |
+
if self.restart_filename:
|
| 88 |
+
self._load_checkpoint(self.restart_filename)
|
| 89 |
+
print('Checkpoint restored!')
|
| 90 |
|
| 91 |
def fit(self, max_epochs=500):
|
| 92 |
best_vloss = float('inf')
|
| 93 |
|
| 94 |
+
for epoch in range(self.start_epoch, max_epochs+1):
|
| 95 |
print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
|
| 96 |
|
| 97 |
# training
|
|
|
|
| 188 |
ckpt_path = os.path.join(self.checkpoints_folder, filename)
|
| 189 |
ckpt_dict = torch.load(ckpt_path, map_location='cpu')
|
| 190 |
self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
|
| 191 |
+
self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1
|
| 192 |
|
| 193 |
def _save_checkpoint(self, current_epoch, filename):
|
| 194 |
if not os.path.exists(self.checkpoints_folder):
|
|
|
|
| 235 |
class TrainerRegressor(Trainer):
|
| 236 |
|
| 237 |
def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
|
| 238 |
+
target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
|
| 239 |
super().__init__(raw_data, dataset_name, target, batch_size, hparams,
|
| 240 |
+
target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
|
| 241 |
|
| 242 |
def _train_one_epoch(self):
|
| 243 |
running_loss = 0.0
|
|
|
|
| 326 |
class TrainerClassifier(Trainer):
|
| 327 |
|
| 328 |
def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
|
| 329 |
+
target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
|
| 330 |
super().__init__(raw_data, dataset_name, target, batch_size, hparams,
|
| 331 |
+
target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
|
| 332 |
|
| 333 |
def _train_one_epoch(self):
|
| 334 |
running_loss = 0.0
|
|
|
|
| 433 |
class TrainerClassifierMultitask(Trainer):
|
| 434 |
|
| 435 |
def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
|
| 436 |
+
target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
|
| 437 |
super().__init__(raw_data, dataset_name, target, batch_size, hparams,
|
| 438 |
+
target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
|
| 439 |
|
| 440 |
def _prepare_data(self):
|
| 441 |
# normalize dataset
|