import torch
from torch import nn
import torch.nn.functional as F
from config import CFG
import utils
import math
import numpy as np
from cliplayers import QuickGELU, Transformer as MSTsfmEncoder
from GNN import layers as gly


class MolGNNEncoder(nn.Module):
    """Graph-convolutional molecule encoder.

    Stacks ``GConvBlockNoGF`` layers over the per-node features, pools the
    nodes with multi-head global attention, and maps the pooled vector to
    ``outdim`` through a small MLP readout (``QuickGELU`` after every readout
    layer, including the last one).
    """

    def __init__(self,
                 outdim,
                 n_feats=74,  # per-node input feature size (was 330 = 74 + 256 Morgan bits)
                 n_filters_list=None,
                 n_head=4,
                 mols=1,
                 adj_chans=6,
                 readout_layers=2,
                 bias=True):
        """
        Args:
            outdim: dimensionality of the final molecule embedding.
            n_feats: number of input features per graph node.
            n_filters_list: output width of each graph-conv block; ``None``
                entries are skipped (default ``[256, 256, 256]``).
            n_head: attention heads in the global pooling layer.
            mols / adj_chans / bias: passed through to ``gly.GConvBlockNoGF``.
            readout_layers: depth of the Linear readout MLP (>= 1).
        """
        super().__init__()
        # Avoid a mutable default argument; behavior matches the old
        # [256, 256, 256] default exactly.
        if n_filters_list is None:
            n_filters_list = [256, 256, 256]
        n_filters_list = [nf for nf in n_filters_list if nf is not None]
        if not n_filters_list:
            # Previously this fell through to an UnboundLocalError on `nf`.
            raise ValueError("n_filters_list must contain at least one filter size")

        conv_blocks = []
        in_feats = n_feats
        for n_filters in n_filters_list:
            conv_blocks.append(gly.GConvBlockNoGF(in_feats, n_filters, mols, adj_chans, bias))
            in_feats = n_filters  # next block consumes this block's output width
        self.block_layers = nn.ModuleList(conv_blocks)

        # Global attention pooling over nodes; concat=True makes the pooled
        # vector n_head * in_feats wide, which the first readout layer consumes.
        self.attention_layer = gly.MultiHeadGlobalAttention(
            in_feats, n_head=n_head, concat=True, bias=bias)
        self.readout_layers = nn.ModuleList(
            [nn.Linear(in_feats * n_head, outdim, bias=bias)]
            + [nn.Linear(outdim, outdim) for _ in range(readout_layers - 1)])
        self.gelu = QuickGELU()

    def forward(self, batch):
        """Encode a batch dict with keys ``'V'`` (node features), ``'A'``
        (adjacency tensor) and ``'mol_size'`` (nodes per molecule) into a
        ``(batch, outdim)`` embedding tensor."""
        V = batch['V']
        A = batch['A']
        for conv in self.block_layers:
            V = conv(V, A)
        X = self.attention_layer(V, batch['mol_size'])
        for readout in self.readout_layers:
            X = self.gelu(readout(X))
        return X


class ProjectionHead(nn.Module):
    """Project embeddings into the shared space, optionally refining them
    with a transformer encoder and/or an LSTM, then applying dropout.

    GELU is applied only on the non-transformer path (original behavior).
    """

    def __init__(self, embedding_dim, projection_dim, cfg, transformer=True, lstm=False):
        """
        Args:
            embedding_dim: input feature width.
            projection_dim: shared-space width.
            cfg: config object providing ``tsfm_layers``, ``tsfm_heads``,
                ``lstm_layers`` and ``dropout``.
            transformer: if True, refine with a transformer encoder.
            lstm: if True, additionally refine with an LSTM.
        """
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()  # NOTE: plain GELU here, unlike QuickGELU in the GNN readout
        self.transformer = (
            MSTsfmEncoder(projection_dim, cfg.tsfm_layers, cfg.tsfm_heads)
            if transformer else None)
        self.lstm = (
            nn.LSTM(input_size=projection_dim, hidden_size=projection_dim,
                    num_layers=cfg.lstm_layers, batch_first=True)
            if lstm else None)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        projected = self.projection(x)
        if self.transformer is None:
            x = self.gelu(projected)
        else:
            x = self.transformer(projected)
        if self.lstm is not None:  # was the unidiomatic `not self.lstm is None`
            x, _ = self.lstm(x)  # discard (h_n, c_n); keep the output sequence
        return self.dropout(x)


# New name in paper is CMSSPModel
class FragSimiModel(nn.Module):
    """CLIP-style dual encoder for mass spectra and molecules
    (renamed ``CMSSPModel`` in the paper).

    The molecule feature vector is built from any combination of a GNN
    encoder (``'gnn'`` in ``cfg.mol_encoder``), precomputed fingerprints
    (``'fp'``) and a formula vector (``'fm'``); both modalities are then
    projected into a shared, L2-normalized embedding space. The contrastive
    loss is computed by the caller from the returned embedding pair.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: config object providing ``mol_encoder``, the embedding /
                projection dimensions, and the ProjectionHead settings.
        """
        super().__init__()
        self.cfg = cfg

        self.mol_gnn_encoder = None
        mol_embedding_dim = cfg.mol_embedding_dim
        if 'gnn' in self.cfg.mol_encoder:
            self.mol_gnn_encoder = MolGNNEncoder(
                outdim=cfg.mol_embedding_dim,
                n_filters_list=cfg.molgnn_n_filters_list,
                n_head=cfg.molgnn_nhead,
                readout_layers=cfg.molgnn_readout_layers)
        if 'fp' in self.cfg.mol_encoder:
            # NOTE(review): assumes batch["mol_fps"] is cfg.mol_embedding_dim
            # wide; with 'fp' alone (no 'gnn') this still doubles the input
            # width of the projection head — confirm against the data pipeline.
            mol_embedding_dim = 2 * cfg.mol_embedding_dim
        if 'fm' in self.cfg.mol_encoder:
            mol_embedding_dim += 10  # formula vector contributes 10 features

        self.ms_projection = ProjectionHead(
            cfg.ms_embedding_dim, cfg.projection_dim, cfg,
            cfg.tsfm_in_ms, cfg.lstm_in_ms)
        self.mol_projection = ProjectionHead(
            mol_embedding_dim, cfg.projection_dim, cfg,
            cfg.tsfm_in_mol, cfg.lstm_in_mol)

    def forward(self, batch):
        """Return ``(mol_embeddings, ms_embeddings)``, both L2-normalized
        along dim 1 with shape ``(batch, projection_dim)``."""
        ms_features = batch["ms_bins"]

        mol_feat_list = []
        if 'gnn' in self.cfg.mol_encoder:
            mol_feat_list.append(self.mol_gnn_encoder(batch))
        if 'fp' in self.cfg.mol_encoder:
            mol_feat_list.append(batch["mol_fps"])
        if 'fm' in self.cfg.mol_encoder:
            mol_feat_list.append(batch["mol_fmvec"])
        if not mol_feat_list:
            # Previously this crashed with a bare IndexError on mol_feat_list[0].
            raise ValueError(
                "cfg.mol_encoder enables no molecule features "
                "(expected 'gnn', 'fp' and/or 'fm')")
        if len(mol_feat_list) > 1:
            mol_features = torch.cat(mol_feat_list, dim=1)
        else:
            mol_features = mol_feat_list[0]

        # Project each modality into the shared space, then normalize so that
        # dot products between the two are cosine similarities.
        ms_embeddings = F.normalize(self.ms_projection(ms_features), p=2, dim=1)
        mol_embeddings = F.normalize(self.mol_projection(mol_features), p=2, dim=1)
        return mol_embeddings, ms_embeddings