# fine-tuning with the WDCE loss
import numpy as np
import torch
import wandb
import os
from finetune_utils import loss_wdce
from tqdm import tqdm
import pandas as pd
from plotting import plot_data_with_distribution_seaborn, plot_data

def finetune(args, cfg, policy_model, reward_model, mcts=None, pretrained=None, filename=None, prot_name=None, eps=1e-5):
    """
    Finetuning with WDCE loss
    """
    base_path = args.base_path
    dt = (1 - eps) / args.total_num_steps  # step size: the interval [eps, 1] split into total_num_steps steps
    
    if args.no_mcts:
        assert pretrained is not None, "a pretrained model is required when running without MCTS"
    else:
        assert mcts is not None, "an MCTS instance is required when using MCTS search"
        
    # set model to train mode
    policy_model.train()
    torch.set_grad_enabled(True)
    optim = torch.optim.AdamW(policy_model.parameters(), lr=args.learning_rate)
    
    # record training losses
    batch_losses = []
    
    # initialize the final seqs and log_rnd of the trajectories that generated those seqs
    x_saved, log_rnd_saved, final_rewards_saved = None, None, None
    
    # per-epoch evaluation logs
    valid_fraction_log = []
    affinity_log = []
    sol_log = []
    hemo_log = []
    nf_log = []
    permeability_log = []

    ### Start of Fine-Tuning Loop ###
    pbar = tqdm(range(args.num_epochs))
    
    for epoch in pbar:
        # per-epoch metrics
        losses = []
        
        policy_model.train()
        
        # regenerate trajectories every `resample_every_n_step` epochs (search and
        # sampling run without gradients); otherwise replay the saved batch
        with torch.no_grad():
            if x_saved is None or epoch % args.resample_every_n_step == 0:
                # compute final sequences and trajectory log_rnd
                if args.no_mcts:
                    x_final, log_rnd, final_rewards = policy_model.sample_finetuned_with_rnd(args, reward_model, pretrained)
                else:
                    # decide whether to reset the search tree
                    if epoch % args.reset_every_n_step == 0:
                        x_final, log_rnd, final_rewards, _, _ = mcts.forward(resetTree=True)
                    else:
                        x_final, log_rnd, final_rewards, _, _ = mcts.forward(resetTree=False)
                
                # save for next iteration
                x_saved, log_rnd_saved, final_rewards_saved = x_final, log_rnd, final_rewards
            else:
                x_final, log_rnd, final_rewards = x_saved, log_rnd_saved, final_rewards_saved
                
        # compute wdce loss
        loss = loss_wdce(policy_model, log_rnd, x_final, num_replicates=args.wdce_num_replicates, centering=args.centering)
        
        # backpropagate
        loss.backward()
        
        # optional gradient clipping before the optimizer step
        if args.grad_clip:
            torch.nn.utils.clip_grad_norm_(policy_model.parameters(), args.gradnorm_clip)
        
        optim.step()
        optim.zero_grad()
        
        pbar.set_postfix(loss=loss.item())
                
        # sample an eval batch with the updated policy to evaluate rewards
        x_eval, affinity, sol, hemo, nf, permeability, valid_fraction = policy_model.sample_finetuned(args, reward_model, batch_size=50, dataframe=False)

        # append to log
        affinity_log.append(affinity)
        sol_log.append(sol)
        hemo_log.append(hemo)
        nf_log.append(nf)
        permeability_log.append(permeability)
        valid_fraction_log.append(valid_fraction)
        
        loss_value = loss.detach().cpu().numpy()
        batch_losses.append(loss_value)
        
        losses.append(loss_value)
        losses = np.array(losses)
        
        # final_rewards is a torch tensor when sampled directly from the policy,
        # and a numpy array when produced by the MCTS search
        if args.no_mcts:
            mean_reward_search = final_rewards.mean().item()
            min_reward_search = final_rewards.min().item()
            max_reward_search = final_rewards.max().item()
            median_reward_search = final_rewards.median().item()
        else:
            mean_reward_search = np.mean(final_rewards)
            min_reward_search = np.min(final_rewards)
            max_reward_search = np.max(final_rewards)
            median_reward_search = np.median(final_rewards)
        
        print("epoch %d"%epoch, "affinity %f"%np.mean(affinity), "sol %f"%np.mean(sol), "hemo %f"%np.mean(hemo), "nf %f"%np.mean(nf), "permeability %f"%np.mean(permeability), "mean loss %f"%np.mean(losses))
        
        wandb.log({"epoch": epoch, "affinity": np.mean(affinity), "sol": np.mean(sol), "hemo": np.mean(hemo), "nf": np.mean(nf), "permeability": np.mean(permeability),
                   "mean_loss": np.mean(losses),
                   "mean_reward_search": mean_reward_search, "min_reward_search": min_reward_search,
                   "max_reward_search": max_reward_search, "median_reward_search": median_reward_search})
        
        if (epoch+1) % args.save_every_n_epochs == 0:
            model_path = os.path.join(args.save_path, f'model_{epoch}.ckpt')
            torch.save(policy_model.state_dict(), model_path)
            print(f"model saved at epoch {epoch}")
    
    ### End of Fine-Tuning Loop ###
    
    wandb.finish()
    
    # save logs and plot
    plot_path = f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}'
    os.makedirs(plot_path, exist_ok=True)
    output_log_path = os.path.join(plot_path, f'log_{filename}.csv')
    save_logs_to_file(valid_fraction_log, affinity_log, 
                      sol_log, hemo_log, nf_log, 
                      permeability_log, output_log_path)
    
    plot_data(valid_fraction_log, 
            save_path=os.path.join(plot_path, f'valid_{filename}.png'))
    
    plot_data_with_distribution_seaborn(log1=affinity_log,
            save_path=os.path.join(plot_path, f'binding_{filename}.png'),
            label1=f"Average Binding Affinity to {prot_name}",
            title=f"Average Binding Affinity to {prot_name} Over Iterations")
    
    plot_data_with_distribution_seaborn(log1=sol_log,
            save_path=os.path.join(plot_path, f'sol_{filename}.png'),
            label1="Average Solubility Score",
            title="Average Solubility Score Over Iterations")
    plot_data_with_distribution_seaborn(log1=hemo_log,
            save_path=os.path.join(plot_path, f'hemo_{filename}.png'),
            label1="Average Hemolysis Score",
            title="Average Hemolysis Score Over Iterations")
    plot_data_with_distribution_seaborn(log1=nf_log,
            save_path=os.path.join(plot_path, f'nf_{filename}.png'),
            label1="Average Nonfouling Score",
            title="Average Nonfouling Score Over Iterations")
    plot_data_with_distribution_seaborn(log1=permeability_log,
            save_path=os.path.join(plot_path, f'perm_{filename}.png'),
            label1="Average Permeability Score",
            title="Average Permeability Score Over Iterations")
    
    x_eval, affinity, sol, hemo, nf, permeability, valid_fraction, df = policy_model.sample_finetuned(args, reward_model, batch_size=200, dataframe=True)
    df.to_csv(os.path.join(plot_path, f'{prot_name}_generation_results.csv'), index=False)

    return batch_losses
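
# Illustrative only (added example, not part of the training pipeline): the
# `args` attributes that `finetune` reads directly, collected with placeholder
# values. These are NOT the project's defaults, and the samplers, MCTS, and
# reward model may read additional attributes of their own. Wrap with
# argparse.Namespace(**EXAMPLE_FINETUNE_ARGS) to get the attribute access
# that `finetune` expects.
EXAMPLE_FINETUNE_ARGS = dict(
    base_path=".",                  # root used to build result/plot paths
    run_name="example_run",
    save_path="./checkpoints",
    learning_rate=1e-4,
    num_epochs=100,
    total_num_steps=128,            # used to compute the step size dt above
    resample_every_n_step=5,        # regenerate trajectories every N epochs
    reset_every_n_step=10,          # rebuild the MCTS tree every N epochs
    no_mcts=False,                  # if True, sample from `pretrained` instead of MCTS
    wdce_num_replicates=8,
    centering=True,
    grad_clip=True,
    gradnorm_clip=1.0,
    save_every_n_epochs=10,
)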

def save_logs_to_file(valid_fraction_log, affinity_log, 
                      sol_log, hemo_log, nf_log, 
                      permeability_log, output_path):
    """
    Saves the logs (valid_fraction_log, affinity1_log, and permeability_log) to a CSV file.
    
    Parameters:
        valid_fraction_log (list): Log of valid fractions over iterations.
        affinity1_log (list): Log of binding affinity over iterations.
        permeability_log (list): Log of membrane permeability over iterations.
        output_path (str): Path to save the log CSV file.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    
    # Combine logs into a DataFrame
    log_data = {
        "Iteration": list(range(1, len(valid_fraction_log) + 1)),
        "Valid Fraction": valid_fraction_log,
        "Binding Affinity": affinity_log,
        "Solubility": sol_log,
        "Hemolysis": hemo_log, 
        "Nonfouling": nf_log,
        "Permeability": permeability_log
    }
        
    df = pd.DataFrame(log_data)
    
    # Save to CSV
    df.to_csv(output_path, index=False)
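
# --- Conceptual sketch (added for illustration) ---
# `loss_wdce` is imported from finetune_utils and its implementation is not
# shown in this file. The helper below is only an assumption about the general
# idea behind such an objective: weighting a per-sequence cross-entropy by
# self-normalized importance weights derived from the per-trajectory log
# Radon-Nikodym values (`log_rnd`). It is NOT the actual WDCE loss used above.
def _example_importance_weighted_ce(logits, targets, log_rnd):
    """Toy example: per-sequence cross-entropy weighted by softmax(log_rnd)."""
    # per-token cross-entropy, no reduction: shape (batch, seq_len)
    per_token = torch.nn.functional.cross_entropy(
        logits.transpose(1, 2), targets, reduction="none"
    )
    # average over the sequence -> one loss value per sequence: shape (batch,)
    per_seq = per_token.mean(dim=-1)
    # self-normalized importance weights over the batch of trajectories
    weights = torch.softmax(log_rnd, dim=0)
    return (weights * per_seq).sum()


if __name__ == "__main__":
    # tiny smoke test of the illustrative helper with random tensors
    batch, seq_len, vocab = 4, 10, 20
    example_logits = torch.randn(batch, seq_len, vocab)
    example_targets = torch.randint(0, vocab, (batch, seq_len))
    example_log_rnd = torch.randn(batch)
    print(_example_importance_weighted_ce(example_logits, example_targets, example_log_rnd))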