Victor Shirasuna
		
	committed on
		
		
					Commit 
							
							·
						
						f6401dc
	
1
								Parent(s):
							
							8abbc76
								
Restore best_vloss in finetune
Browse files
- smi-ted/finetune/trainers.py +10 -8
 
    	
        smi-ted/finetune/trainers.py
    CHANGED
    
    | 
         @@ -47,6 +47,8 @@ class Trainer: 
     | 
|
| 47 | 
         
             
                    self.save_every_epoch = save_every_epoch
         
     | 
| 48 | 
         
             
                    self.save_ckpt = save_ckpt
         
     | 
| 49 | 
         
             
                    self.device = device
         
     | 
| 
         | 
|
| 
         | 
|
| 50 | 
         
             
                    self._set_seed(seed)
         
     | 
| 51 | 
         | 
| 52 | 
         
             
                def _prepare_data(self):
         
     | 
| 
         @@ -89,8 +91,6 @@ class Trainer: 
     | 
|
| 89 | 
         
             
                        print('Checkpoint restored!')
         
     | 
| 90 | 
         | 
| 91 | 
         
             
                def fit(self, max_epochs=500):
         
     | 
| 92 | 
         
            -
                    best_vloss = float('inf')
         
     | 
| 93 | 
         
            -
             
     | 
| 94 | 
         
             
                    for epoch in range(self.start_epoch, max_epochs+1):
         
     | 
| 95 | 
         
             
                        print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
         
     | 
| 96 | 
         | 
| 
         @@ -106,22 +106,22 @@ class Trainer: 
     | 
|
| 106 | 
         
             
                            print(f"[VALID] Evaluation {m.upper()}: {round(val_metrics[m], 4)}")
         
     | 
| 107 | 
         | 
| 108 | 
         
             
                        ############################### Save Finetune checkpoint #######################################
         
     | 
| 109 | 
         
            -
                        if ((val_loss < best_vloss) or self.save_every_epoch) and self.save_ckpt:
         
     | 
| 110 | 
         
             
                            # remove old checkpoint
         
     | 
| 111 | 
         
            -
                            if  
     | 
| 112 | 
         
             
                                os.remove(os.path.join(self.checkpoints_folder, self.last_filename))
         
     | 
| 113 | 
         | 
| 114 | 
         
             
                            # filename
         
     | 
| 115 | 
         
             
                            model_name = f'{str(self.model)}-Finetune'
         
     | 
| 116 | 
         
            -
                            self.last_filename = f"{model_name} 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 117 | 
         | 
| 118 | 
         
             
                            # save checkpoint
         
     | 
| 119 | 
         
             
                            print('Saving checkpoint...')
         
     | 
| 120 | 
         
             
                            self._save_checkpoint(epoch, self.last_filename)
         
     | 
| 121 | 
         | 
| 122 | 
         
            -
                            # update best loss
         
     | 
| 123 | 
         
            -
                            best_vloss = val_loss
         
     | 
| 124 | 
         
            -
             
     | 
| 125 | 
         
             
                def evaluate(self, verbose=True):
         
     | 
| 126 | 
         
             
                    if verbose:
         
     | 
| 127 | 
         
             
                        print("\n=====Test Evaluation=====")
         
     | 
| 
         @@ -189,6 +189,7 @@ class Trainer: 
     | 
|
| 189 | 
         
             
                    ckpt_dict = torch.load(ckpt_path, map_location='cpu')
         
     | 
| 190 | 
         
             
                    self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
         
     | 
| 191 | 
         
             
                    self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1
         
     | 
| 
         | 
|
| 192 | 
         | 
| 193 | 
         
             
                def _save_checkpoint(self, current_epoch, filename):
         
     | 
| 194 | 
         
             
                    if not os.path.exists(self.checkpoints_folder):
         
     | 
| 
         @@ -209,6 +210,7 @@ class Trainer: 
     | 
|
| 209 | 
         
             
                            'train_size': self.df_train.shape[0],
         
     | 
| 210 | 
         
             
                            'valid_size': self.df_valid.shape[0],
         
     | 
| 211 | 
         
             
                            'test_size': self.df_test.shape[0],
         
     | 
| 
         | 
|
| 212 | 
         
             
                        },
         
     | 
| 213 | 
         
             
                        'seed': self.seed,
         
     | 
| 214 | 
         
             
                    }
         
     | 
| 
         | 
|
| 47 | 
         
             
                    self.save_every_epoch = save_every_epoch
         
     | 
| 48 | 
         
             
                    self.save_ckpt = save_ckpt
         
     | 
| 49 | 
         
             
                    self.device = device
         
     | 
| 50 | 
         
            +
                    self.best_vloss = float('inf')
         
     | 
| 51 | 
         
            +
                    self.last_filename = None
         
     | 
| 52 | 
         
             
                    self._set_seed(seed)
         
     | 
| 53 | 
         | 
| 54 | 
         
             
                def _prepare_data(self):
         
     | 
| 
         | 
|
| 91 | 
         
             
                        print('Checkpoint restored!')
         
     | 
| 92 | 
         | 
| 93 | 
         
             
                def fit(self, max_epochs=500):
         
     | 
| 
         | 
|
| 
         | 
|
| 94 | 
         
             
                    for epoch in range(self.start_epoch, max_epochs+1):
         
     | 
| 95 | 
         
             
                        print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
         
     | 
| 96 | 
         | 
| 
         | 
|
| 106 | 
         
             
                            print(f"[VALID] Evaluation {m.upper()}: {round(val_metrics[m], 4)}")
         
     | 
| 107 | 
         | 
| 108 | 
         
             
                        ############################### Save Finetune checkpoint #######################################
         
     | 
| 109 | 
         
            +
                        if ((val_loss < self.best_vloss) or self.save_every_epoch) and self.save_ckpt:
         
     | 
| 110 | 
         
             
                            # remove old checkpoint
         
     | 
| 111 | 
         
            +
                            if (self.last_filename != None) and (not self.save_every_epoch):
         
     | 
| 112 | 
         
             
                                os.remove(os.path.join(self.checkpoints_folder, self.last_filename))
         
     | 
| 113 | 
         | 
| 114 | 
         
             
                            # filename
         
     | 
| 115 | 
         
             
                            model_name = f'{str(self.model)}-Finetune'
         
     | 
| 116 | 
         
            +
                            self.last_filename = f"{model_name}_seed{self.seed}_{self.dataset_name}_epoch={epoch}_valloss={round(val_loss, 4)}.pt"
         
     | 
| 117 | 
         
            +
             
     | 
| 118 | 
         
            +
                            # update best loss
         
     | 
| 119 | 
         
            +
                            self.best_vloss = val_loss
         
     | 
| 120 | 
         | 
| 121 | 
         
             
                            # save checkpoint
         
     | 
| 122 | 
         
             
                            print('Saving checkpoint...')
         
     | 
| 123 | 
         
             
                            self._save_checkpoint(epoch, self.last_filename)
         
     | 
| 124 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 125 | 
         
             
                def evaluate(self, verbose=True):
         
     | 
| 126 | 
         
             
                    if verbose:
         
     | 
| 127 | 
         
             
                        print("\n=====Test Evaluation=====")
         
     | 
| 
         | 
|
| 189 | 
         
             
                    ckpt_dict = torch.load(ckpt_path, map_location='cpu')
         
     | 
| 190 | 
         
             
                    self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
         
     | 
| 191 | 
         
             
                    self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1
         
     | 
| 192 | 
         
            +
                    self.best_vloss = ckpt_dict['finetune_info']['best_vloss']
         
     | 
| 193 | 
         | 
| 194 | 
         
             
                def _save_checkpoint(self, current_epoch, filename):
         
     | 
| 195 | 
         
             
                    if not os.path.exists(self.checkpoints_folder):
         
     | 
| 
         | 
|
| 210 | 
         
             
                            'train_size': self.df_train.shape[0],
         
     | 
| 211 | 
         
             
                            'valid_size': self.df_valid.shape[0],
         
     | 
| 212 | 
         
             
                            'test_size': self.df_test.shape[0],
         
     | 
| 213 | 
         
            +
                            'best_vloss': self.best_vloss,
         
     | 
| 214 | 
         
             
                        },
         
     | 
| 215 | 
         
             
                        'seed': self.seed,
         
     | 
| 216 | 
         
             
                    }
         
     |