import math

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from config import CFG
import utils
from cliplayers import QuickGELU, Transformer as MSTsfmEncoder
from GNN import layers as gly


class MolGNNEncoder(nn.Module):
    """Graph-convolutional molecule encoder: stacked graph conv blocks,
    multi-head global attention pooling, then a small MLP readout."""

    def __init__(self,
                 outdim,
                 n_feats=74,
                 n_filters_list=(256, 256, 256),
                 n_head=4,
                 mols=1,
                 adj_chans=6,
                 readout_layers=2,
                 bias=True):
        super().__init__()

        n_filters_list = [nf for nf in n_filters_list if nf is not None]

        # Stack graph conv blocks; each maps the previous feature width
        # (starting from n_feats) to the next entry in n_filters_list.
        lys = []
        prev_nf = n_feats
        for nf in n_filters_list:
            lys.append(gly.GConvBlockNoGF(prev_nf, nf, mols, adj_chans, bias))
            prev_nf = nf
        self.block_layers = nn.ModuleList(lys)

        # Global attention pooling over atoms; with concat=True the head
        # outputs are concatenated into a (prev_nf * n_head)-dim graph vector.
        self.attention_layer = gly.MultiHeadGlobalAttention(prev_nf, n_head=n_head, concat=True, bias=bias)

        # MLP readout: (prev_nf * n_head) -> outdim, followed by
        # (readout_layers - 1) further outdim -> outdim layers.
        self.readout_layers = nn.ModuleList(
            [nn.Linear(prev_nf * n_head, outdim, bias=bias)]
            + [nn.Linear(outdim, outdim) for _ in range(readout_layers - 1)])
        self.gelu = QuickGELU()

    def forward(self, batch):
        V = batch['V']                # node (atom) features
        A = batch['A']                # adjacency tensor(s)
        mol_size = batch['mol_size']  # atoms per molecule, for pooling

        for ly in self.block_layers:
            V = ly(V, A)

        X = self.attention_layer(V, mol_size)

        for ly in self.readout_layers:
            X = self.gelu(ly(X))

        return X
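

# A minimal smoke-test sketch for MolGNNEncoder. The tensor shapes below are
# assumptions: they presume gly.GConvBlockNoGF takes dense per-molecule node
# features V of shape (batch, n_atoms, n_feats) and a stacked adjacency A of
# shape (batch, adj_chans, n_atoms, n_atoms), and that
# gly.MultiHeadGlobalAttention pools over the atom axis using mol_size.
# Check GNN/layers.py for the actual conventions before relying on this.
def _demo_mol_gnn_encoder():
    enc = MolGNNEncoder(outdim=128)
    batch = {
        'V': torch.randn(8, 32, 74),       # hypothetical: 8 mols, 32 atoms, 74 features
        'A': torch.randn(8, 6, 32, 32),    # hypothetical: 6 adjacency channels
        'mol_size': torch.full((8,), 32),  # hypothetical: atom counts per molecule
    }
    print(enc(batch).shape)  # expected: (8, 128)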


class ProjectionHead(nn.Module):
    """Projects an embedding into the shared space, optionally refining it
    with a transformer encoder and/or an LSTM before dropout."""

    def __init__(self,
                 embedding_dim,
                 projection_dim,
                 cfg,
                 transformer=True,
                 lstm=False):
        super().__init__()

        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()

        self.transformer = None
        if transformer:
            self.transformer = MSTsfmEncoder(projection_dim, cfg.tsfm_layers, cfg.tsfm_heads)

        self.lstm = None
        if lstm:
            self.lstm = nn.LSTM(input_size=projection_dim,
                                hidden_size=projection_dim,
                                num_layers=cfg.lstm_layers,
                                batch_first=True)

        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        projected = self.projection(x)
        if self.transformer is None:
            x = self.gelu(projected)
        else:
            x = self.transformer(projected)
        if self.lstm is not None:
            x, _ = self.lstm(x)
        x = self.dropout(x)
        return x
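

# A minimal usage sketch for ProjectionHead with the transformer and LSTM
# branches disabled, so it reduces to Linear -> GELU -> Dropout and needs
# neither cliplayers nor the full CFG. The SimpleNamespace below is a
# hypothetical stand-in for the real config; only `dropout` is read on
# this path.
def _demo_projection_head():
    from types import SimpleNamespace
    cfg = SimpleNamespace(dropout=0.1)
    head = ProjectionHead(embedding_dim=256, projection_dim=128, cfg=cfg,
                          transformer=False, lstm=False)
    print(head(torch.randn(8, 256)).shape)  # (8, 128)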


class FragSimiModel(nn.Module):
    """CLIP-style two-tower model that maps spectra (MS) and molecules into
    a shared, L2-normalized projection space."""

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        self.mol_gnn_encoder = None
        mol_embedding_dim = cfg.mol_embedding_dim

        if 'gnn' in self.cfg.mol_encoder:
            self.mol_gnn_encoder = MolGNNEncoder(outdim=cfg.mol_embedding_dim,
                                                 n_filters_list=cfg.molgnn_n_filters_list,
                                                 n_head=cfg.molgnn_nhead,
                                                 readout_layers=cfg.molgnn_readout_layers)

        # Fingerprint features double the molecule feature width, and the
        # formula vector appends 10 more dimensions.
        if 'fp' in self.cfg.mol_encoder:
            mol_embedding_dim = 2 * cfg.mol_embedding_dim
        if 'fm' in self.cfg.mol_encoder:
            mol_embedding_dim += 10

        self.ms_projection = ProjectionHead(cfg.ms_embedding_dim,
                                            cfg.projection_dim,
                                            cfg,
                                            cfg.tsfm_in_ms,
                                            cfg.lstm_in_ms)
        self.mol_projection = ProjectionHead(mol_embedding_dim,
                                             cfg.projection_dim,
                                             cfg,
                                             cfg.tsfm_in_mol,
                                             cfg.lstm_in_mol)

    def forward(self, batch):
        ms_features = batch["ms_bins"]

        # Assemble molecule features from whichever encoders are enabled.
        mol_feat_list = []
        if 'gnn' in self.cfg.mol_encoder:
            mol_feat_list.append(self.mol_gnn_encoder(batch))
        if 'fp' in self.cfg.mol_encoder:
            mol_feat_list.append(batch["mol_fps"])
        if 'fm' in self.cfg.mol_encoder:
            mol_feat_list.append(batch["mol_fmvec"])

        if len(mol_feat_list) > 1:
            mol_features = torch.cat(mol_feat_list, dim=1)
        else:
            mol_features = mol_feat_list[0]

        ms_embeddings = self.ms_projection(ms_features)
        mol_embeddings = self.mol_projection(mol_features)

        # L2-normalize so that dot products are cosine similarities.
        mol_embeddings = F.normalize(mol_embeddings, p=2, dim=1)
        ms_embeddings = F.normalize(ms_embeddings, p=2, dim=1)

        return mol_embeddings, ms_embeddings
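

# A minimal end-to-end sketch exercising only the fingerprint branch
# (mol_encoder='fp'), so neither the GNN stack nor cliplayers is needed.
# The SimpleNamespace is a hypothetical stand-in for the real CFG; field
# names follow their usage above. Note that with 'fp' alone, batch["mol_fps"]
# must be 2 * mol_embedding_dim wide to match the head built in __init__.
def _demo_frag_simi_model():
    from types import SimpleNamespace
    cfg = SimpleNamespace(mol_encoder='fp', mol_embedding_dim=64,
                          ms_embedding_dim=100, projection_dim=32,
                          dropout=0.1, tsfm_in_ms=False, lstm_in_ms=False,
                          tsfm_in_mol=False, lstm_in_mol=False)
    model = FragSimiModel(cfg)
    batch = {"ms_bins": torch.randn(8, 100), "mol_fps": torch.randn(8, 128)}
    mol_emb, ms_emb = model(batch)
    print(mol_emb.shape, ms_emb.shape)  # (8, 32) (8, 32)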


# Reference for the (externally computed) training objective: symmetric
# cross-entropy over the similarity logits, CLIP-style. `loss_func` is not
# defined in this module.
'''logits = mol_embeddings @ ms_embeddings.t()

ground_truth = torch.arange(ms_features.shape[0], dtype=torch.long, device=self.cfg.device)

ms_loss = loss_func(logits, ground_truth)
mol_loss = loss_func(logits.t(), ground_truth)
loss = (ms_loss + mol_loss) / 2.0  # shape: (batch_size)

return loss.mean()'''
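

# A self-contained, runnable take on the commented-out objective above, with
# F.cross_entropy standing in for the undefined `loss_func` (the usual choice
# for this CLIP-style symmetric loss). A sketch of how the embeddings returned
# by FragSimiModel.forward could be trained, not necessarily the project's
# exact training code.
def clip_style_loss(mol_embeddings, ms_embeddings):
    # Embeddings are already L2-normalized, so the matmul gives cosine
    # similarities: logits[i, j] = sim(mol_i, ms_j).
    logits = mol_embeddings @ ms_embeddings.t()
    # The matching pair for row i sits on the diagonal.
    ground_truth = torch.arange(logits.shape[0], device=logits.device)
    ms_loss = F.cross_entropy(logits, ground_truth)
    mol_loss = F.cross_entropy(logits.t(), ground_truth)
    return (ms_loss + mol_loss) / 2.0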