|
|
from rdkit import Chem
|
|
|
from rdkit.Chem import AllChem, MACCSkeys
|
|
|
from rdkit.Chem.rdmolops import FastFindRings
|
|
|
from rdkit.Chem.rdMolDescriptors import CalcMolFormula
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
import scipy
|
|
|
import scipy.sparse as ss
|
|
|
import scipy.sparse.linalg
|
|
|
import math
|
|
|
import json
|
|
|
import itertools as it
|
|
|
import re
|
|
|
from GNN import featurizer as ft
|
|
|
|
|
|
# Silence RDKit's noisy per-molecule diagnostics: raise the Python-side
# logger threshold to ERROR, and disable the C++-side 'rdApp.error' stream
# (needed because some SMILES below are deliberately parsed unsanitized).
import rdkit.RDLogger as rkl

logger = rkl.logger()
logger.setLevel(rkl.ERROR)

import rdkit.rdBase as rkrb

rkrb.DisableLog('rdApp.error')
|
|
|
|
|
|
|
|
|
# Fixed subset of bit positions into a 2048-bit radius-2 Morgan fingerprint;
# used by mol_fp_encoder0 (tp='morgan1') to reduce the full fingerprint to
# just these bits. Presumably selected offline (e.g. by feature importance
# on a training set) — TODO confirm provenance.
FPBitIdx = [1, 5, 13, 41, 69, 80, 84, 94, 114, 117, 118, 119, 125, 133, 145,
            147, 191, 192, 197, 202, 222, 227, 231, 249, 283, 294, 310, 314,
            322, 333, 352, 361, 378, 387, 389, 392, 401, 406, 441, 478, 486,
            489, 519, 521, 524, 555, 561, 591, 598, 599, 610, 622, 650, 656,
            667, 675, 677, 679, 680, 694, 695, 715, 718, 722, 729, 736, 739,
            745, 750, 760, 775, 781, 787, 794, 798, 802, 807, 811, 823, 835,
            841, 849, 869, 872, 874, 875, 881, 890, 896, 926, 935, 980, 991,
            1004, 1009, 1017, 1019, 1027, 1028, 1035, 1037, 1039, 1057, 1060,
            1066, 1070, 1077, 1088, 1097, 1114, 1126, 1136, 1142, 1143, 1145,
            1152, 1154, 1160, 1162, 1171, 1181, 1195, 1199, 1202, 1218, 1234,
            1236, 1243, 1257, 1267, 1274, 1279, 1283, 1292, 1294, 1309, 1313,
            1323, 1325, 1349, 1356, 1357, 1366, 1380, 1381, 1385, 1386, 1391,
            1399, 1436, 1440, 1441, 1444, 1452, 1454, 1457, 1475, 1476, 1477,
            1480, 1487, 1516, 1536, 1544, 1558, 1564, 1573, 1599, 1602, 1604,
            1607, 1619, 1648, 1670, 1683, 1693, 1716, 1722, 1737, 1738, 1745,
            1747, 1750, 1754, 1755, 1764, 1781, 1803, 1808, 1810, 1816, 1838,
            1844, 1847, 1855, 1860, 1866, 1873, 1905, 1911, 1917, 1921, 1923,
            1928, 1933, 1950, 1951, 1970, 1977, 1980, 1984, 1991, 2002, 2033, 2034, 2038]
|
|
|
|
|
|
class ConfigDict(dict):
    """A dict subclass whose keys are also reachable as attributes.

    ``cfg.lr`` is equivalent to ``cfg['lr']``; missing keys raise
    AttributeError so ``hasattr``/``getattr`` behave as expected.
    """

    def __getattr__(self, name):
        # Catch only KeyError: the previous bare `except:` would also have
        # converted unrelated bugs (and KeyboardInterrupt) into
        # AttributeError, hiding real failures.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        # Attribute assignment writes straight into the dict storage.
        self[name] = value

    def save(self, fn):
        """Serialize this config to *fn* as pretty-printed JSON."""
        # Context manager guarantees the file handle is closed (the old
        # code passed an anonymous open() and leaked the handle).
        with open(fn, 'w') as f:
            json.dump(self, f, indent=2)

    def load_dict(self, dic):
        """Merge every key/value pair from *dic* into this config."""
        for k, v in dic.items():
            self[k] = v

    def load(self, fn):
        """Load JSON from *fn* into this config.

        Best-effort: a missing or corrupt file is printed, not raised,
        matching the original behavior.
        """
        try:
            with open(fn, 'r') as f:
                d = json.load(f)
            self.load_dict(d)
        except Exception as e:
            print(e)
|
|
|
|
|
|
def conv_out_dim(length_in, kernel, stride, padding, dilation):
    """Output length of a 1-D convolution (same formula as torch.nn.Conv1d)."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (length_in + 2 * padding - effective_kernel) // stride + 1
|
|
|
|
|
|
def filter_ms(ms, thr=0.05, max_mz=2000):
    """Filter a spectrum to significant peaks below *max_mz*.

    Parameters
    ----------
    ms : iterable of (mz, intensity) pairs
    thr : float
        Relative-intensity threshold (fraction of the base peak); only
        peaks strictly above it are kept.
    max_mz : float
        Peaks at or above this m/z are dropped.

    Returns
    -------
    (list, list)
        Kept m/z values and their intensities rescaled to percent of the
        base peak (rounded to 2 decimals), in input order.
    """
    # Base-peak intensity among peaks inside the m/z window.
    maxi = max((i for m, i in ms if m < max_mz), default=0)
    if maxi <= 0:
        # Empty spectrum or all-zero intensities: nothing can pass the
        # relative threshold (the old code divided by zero here).
        return [], []

    mz = []
    intn = []
    for m, i in ms:
        if m < max_mz and i / maxi > thr:
            mz.append(m)
            intn.append(round(i / maxi * 100, 2))

    return mz, intn
|
|
|
|
|
|
def calc_nls(ms, thr=0.05, max_mz=2000):
    """Compute neutral-loss peaks (pairwise m/z differences) of a spectrum.

    Parameters
    ----------
    ms : iterable of (mz, intensity) pairs
    thr : float
        Relative-intensity threshold forwarded to filter_ms.
    max_mz : float
        Upper m/z bound forwarded to filter_ms.

    Returns
    -------
    list of (loss_mass, intensity)
        Losses in the open interval (0, 200) Da, each with the mean of the
        two contributing peak intensities, sorted by loss mass.
    """
    # BUG FIX: thr and max_mz used to be hard-coded to 0.05 / 2000 here,
    # silently ignoring the caller's arguments.
    mz, intn = filter_ms(ms, thr=thr, max_mz=max_mz)

    nlmass = []
    nlintn = []
    # Pair peak positions high-to-low so mz[ra] - mz[rb] is the positive
    # loss; indexing by position avoids the O(n) list.index per pair.
    for ra, rb in it.combinations(range(len(mz) - 1, -1, -1), 2):
        nl = mz[ra] - mz[rb]
        if 0 < nl < 200:
            nlmass.append(round(nl, 5))
            nlintn.append(round((intn[ra] + intn[rb]) / 2., 5))

    return sorted(zip(nlmass, nlintn))
|
|
|
|
|
|
def ms_binner(ms, nls=[], min_mz=20, max_mz=2000, bin_size=0.05, add_nl=False, binary_intn=False):
    """
    Bin a spectrum into a fixed-length dense tensor, optionally prepending
    a binned neutral-loss vector.

    Parameters
    ----------
    ms : iterable of (mz, intensity) pairs
        The spectrum to vectorize.
    nls : list of (loss_mass, intensity) pairs, optional
        Precomputed neutral losses; computed via calc_nls when empty and
        add_nl is True.  NOTE(review): mutable default argument — it is
        never mutated here, but `nls=None` would be safer.
    min_mz : float
        Lower bound of the binned m/z range.
    max_mz : float
        Upper bound of the m/z range; peaks above it are dropped.
    bin_size : float
        Width (in m/z) of each bin.
    add_nl : bool
        When True, prepend ceil(200 / bin_size) neutral-loss bins.
    binary_intn : bool
        When True, use 0/1 peak presence instead of normalized intensity.

    Returns
    -------
    torch.FloatTensor
        1-D binned spectrum of length ceil((max_mz - min_mz) / bin_size),
        with the neutral-loss bins concatenated in front when add_nl.
    """
    if add_nl and not nls:
        nls = calc_nls(ms, max_mz=max_mz)

    nltensor = None
    # NOTE(review): filter_ms is called with its default max_mz (2000),
    # not this function's max_mz argument; peaks above the argument are
    # only removed by the keepidx mask further down.
    mz, intn = filter_ms(ms)

    if add_nl:
        nlmass = []
        nlintn = []

        # Redundant with the check at the top (nls is already populated
        # whenever add_nl is True by this point).
        if not nls:
            nls = calc_nls(ms, max_mz=max_mz)

        # Keep only losses below 200 Da — the neutral-loss bin range.
        for m, i in nls:
            if m < 200:
                if binary_intn:
                    i = 1
                nlmass.append(m)
                nlintn.append(i)

        nlmass = np.array(nlmass)
        nlintn = np.array(nlintn)
        # Rescale loss intensities to [0, 1] relative to the strongest loss.
        if len(nlintn) > 0:
            nlintn = nlintn/nlintn.max()
        num_nlbins = math.ceil((200) / bin_size)

        nlbins = (nlmass / bin_size).astype(np.int32)

        if len(nlmass) > 0:
            # Single-row sparse vector; scipy sums values that land in the
            # same (row, col) bin.
            vecnl = ss.csr_matrix(
                (nlintn,
                 (np.repeat(0, len(nlintn)), nlbins)),
                shape=(1, num_nlbins),
                dtype=np.float32)

            # L2-normalize, then scale to 100.
            vecnl = (vecnl / scipy.sparse.linalg.norm(vecnl)*100)
            nltensor = torch.FloatTensor(vecnl.todense()).view(-1)
        else:
            # No qualifying losses: all-zero neutral-loss segment.
            nltensor = torch.zeros(num_nlbins)

    # Drop peaks above the requested upper bound.
    mz = np.array(mz)
    keepidx = (mz <= max_mz)
    mz = mz[keepidx]
    intn = np.array(intn)
    intn = intn[keepidx]

    if binary_intn:
        intn[intn > 0] = 1.0
    elif len(intn) > 0:
        intn = intn/intn.max()

    num_bins = math.ceil((max_mz - min_mz) / bin_size)

    # NOTE(review): peaks below min_mz are NOT filtered out, so
    # (mz - min_mz) / bin_size can yield negative bin indices here —
    # confirm callers guarantee mz >= min_mz.
    bins = ((mz - min_mz) / bin_size).astype(np.int32)

    if len(mz) > 0:
        vec = ss.csr_matrix(
            (intn,
             (np.repeat(0, len(intn)), bins)),
            shape=(1, num_bins),
            dtype=np.float32)

        # Binary vectors stay un-normalized (pure 0/1 presence flags).
        if not binary_intn:
            vec = (vec / scipy.sparse.linalg.norm(vec)*100)

        mstensor = torch.FloatTensor(vec.todense()).view(-1)
    else:
        mstensor = torch.zeros(num_bins)

    if not nltensor is None:
        return torch.cat([nltensor, mstensor], dim=0)

    return mstensor
|
|
|
|
|
|
def formula2vec(formula, elements=['C', 'H', 'O', 'N', 'P', 'S', 'P', 'F', 'Cl', 'Br']):
    """Encode a molecular-formula string as a vector of element counts.

    Each (symbol, count) pair parsed from *formula* is accumulated into the
    slot of its first occurrence in *elements*; symbols not listed are
    ignored, and a missing count means 1 (e.g. the H in 'CH4' vs 'C2H6').

    NOTE(review): the default ``elements`` list contains 'P' twice, so the
    second 'P' slot can never be non-zero. Left as-is because downstream
    code depends on the 10-element vector length.
    """
    vec = np.zeros(len(elements))
    for symbol, count in re.findall(r'([A-Z][a-z]*)(\d*)', formula):
        if symbol not in elements:
            continue
        vec[elements.index(symbol)] += int(count) if count else 1
    return np.array(vec)
|
|
|
|
|
|
def mol_fp_encoder0(smiles, tp='rdkit', nbits=2048):
    """Encode a SMILES string as a fingerprint tensor plus the parsed Mol.

    Parameters
    ----------
    smiles : str
        Molecule as SMILES.
    tp : str
        Fingerprint type: 'morgan' (radius-2, nbits wide), 'morgan1'
        (radius-2, 2048 bits reduced to the FPBitIdx subset), 'macc'
        (MACCS keys), or 'rdkit' (RDKit topological fingerprint).
    nbits : int
        Bit-vector width; only used by the 'morgan' type.

    Returns
    -------
    (torch.FloatTensor, Mol) or (None, None) when the SMILES is unparsable.

    Raises
    ------
    ValueError
        For an unrecognized ``tp``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Retry without sanitization, then patch up the minimum state
        # (valence cache + ring info) that fingerprinting requires.
        mol = Chem.MolFromSmiles(smiles, sanitize=False)
        if mol is not None:
            mol.UpdatePropertyCache()
            FastFindRings(mol)

    if mol is None:
        return None, None

    if tp == 'morgan':
        fp_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=nbits)
        fp = (np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')).tolist()
    elif tp == 'morgan1':
        # Fixed 2048-bit Morgan fingerprint reduced to the curated subset.
        fp_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
        fp = (np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0'))[FPBitIdx].tolist()
    elif tp == 'macc':
        fp_vec = MACCSkeys.GenMACCSKeys(mol)
        fp = (np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')).tolist()
    elif tp == 'rdkit':
        fp_vec = Chem.RDKFingerprint(mol, nBitsPerHash=1)
        fp = (np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')).tolist()
    else:
        # BUG FIX: an unknown fingerprint type used to fall through and
        # crash with NameError on the undefined `fp`.
        raise ValueError(f"unknown fingerprint type: {tp!r}")

    return torch.FloatTensor(fp), mol
|
|
|
|
|
|
def mol_fp_encoder(smiles, tp='rdkit', nbits=2048):
    """Fingerprint-only wrapper around mol_fp_encoder0 (discards the Mol)."""
    encoding, _mol = mol_fp_encoder0(smiles, tp, nbits)
    return encoding
|
|
|
|
|
|
def mol_fp_fm_encoder(smiles, tp='rdkit', nbits=2048):
    """Return (fingerprint tensor, formula-count tensor) for a SMILES.

    The formula tensor is None when the SMILES cannot be parsed at all.
    """
    fpenc, mol = mol_fp_encoder0(smiles, tp, nbits)
    if mol is None:
        return fpenc, None
    formula = CalcMolFormula(mol)
    return fpenc, torch.FloatTensor(formula2vec(formula))
|
|
|
|
|
|
def smi2fmvec(smiles):
    """Molecular-formula count tensor for a SMILES (None if unparsable)."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return torch.FloatTensor(formula2vec(CalcMolFormula(mol)))
|
|
|
|
|
|
def mol_graph_featurizer(smiles):
    """Build a molecular graph from a SMILES string via the GNN featurizer.

    Uses no explicit hydrogens and no sub-molecule fingerprints, but keeps
    ring/conjugation annotations and per-atom features (radius 2).
    """
    # Removed a stale commented-out variant (addh=True, with_submol_fp=True)
    # that shadowed this call and invited confusion.
    return ft.calc_data_from_smile(smiles,
                                   addh=False,
                                   with_ring_conj=True,
                                   with_atom_feats=True,
                                   with_submol_fp=False,
                                   radius=2)
|
|
|
|
|
|
def pad_V(V, max_n):
    """Zero-pad the node-feature matrix V (N, C) along dim 0 up to max_n rows.

    Returns V unchanged when it already has max_n or more rows.
    """
    num_nodes, num_feats = V.shape
    deficit = max_n - num_nodes
    if deficit > 0:
        V = torch.cat([V, torch.zeros(deficit, num_feats)], dim=0)
    return V
|
|
|
|
|
|
def pad_A(A, max_n):
    """Zero-pad the adjacency tensor A (N, L, N) out to (max_n, L, max_n).

    Returns A unchanged when it already covers max_n or more nodes.
    """
    num_nodes, num_rels, _ = A.shape
    deficit = max_n - num_nodes
    if deficit > 0:
        # Widen the last axis first, then append all-zero node slices.
        A = torch.cat([A, torch.zeros(num_nodes, num_rels, deficit)], dim=-1)
        A = torch.cat([A, torch.zeros(deficit, num_rels, max_n)], dim=0)
    return A
|
|
|
|
|
|
class AvgMeter:
    """Running-average tracker for a scalar metric (e.g. a training loss)."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Fold in *val* observed *count* times and refresh the average."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"
|
|
|
|
|
|
def get_lr(optimizer):
    """Learning rate of the optimizer's first param group (None if empty)."""
    groups = optimizer.param_groups
    if groups:
        return groups[0]["lr"]
    return None
|
|
|
|
|
|
def segment_max(x, size_list):
    """Row-wise max within consecutive segments of x sized by size_list."""
    sizes = [int(s) for s in size_list]
    segment_maxes = [seg.max(dim=0).values for seg in torch.split(x, sizes)]
    return torch.stack(segment_maxes)
|
|
|
|
|
|
def segment_sum(x, size_list):
    """Row-wise sum within consecutive segments of x sized by size_list."""
    sizes = [int(s) for s in size_list]
    return torch.stack([seg.sum(dim=0) for seg in torch.split(x, sizes)])
|
|
|
|
|
|
def segment_softmax(gate, size_list):
    """Softmax of ``gate`` computed independently within each segment.

    Segments are consecutive row groups of sizes ``size_list``. Uses the
    standard max-subtraction trick for numerical stability; the 1e-16
    guard avoids division by zero.
    """
    def _expand(per_segment):
        # Broadcast one row of per-segment statistics back to every
        # member row of that segment.
        pieces = [per_segment[idx].repeat(int(cnt), 1)
                  for idx, cnt in enumerate(size_list)]
        return torch.cat(pieces, dim=0)

    shifted = torch.exp(gate - _expand(segment_max(gate, size_list)))
    totals = segment_sum(shifted, size_list)
    return shifted / (_expand(totals) + 1e-16)
|
|
|
|
|
|
def pad_ms_list(ms_list, thr=0.05, min_mz=20, max_mz=2000):
    """Normalize, threshold, and right-pad a batch of spectra to one length.

    Each spectrum is an iterable of (mz, intensity) rows. Intensities are
    rescaled to percent of the spectrum's base peak; peaks below ``thr``
    (relative) or outside [min_mz, max_mz] are dropped. Shorter spectra
    are padded with (0, 0) rows.

    Returns a (batch, max_len, 2) FloatTensor and an IntTensor of the
    true (unpadded) lengths.
    """
    pct_thr = thr * 100
    filtered = []
    for spectrum in ms_list:
        arr = np.array(spectrum)
        # Rescale intensities to percent of the base peak.
        arr[:, 1] = arr[:, 1] / arr[:, 1].max() * 100
        mask = (arr[:, 0] >= min_mz) & (arr[:, 0] <= max_mz)
        if pct_thr > 0:
            mask &= arr[:, 1] >= pct_thr
        filtered.append(arr[mask])

    lengths = [arr.shape[0] for arr in filtered]
    target = max(lengths)

    padded = []
    for arr in filtered:
        deficit = target - len(arr)
        if deficit > 0:
            arr = np.concatenate([arr, np.zeros((deficit, 2))], axis=0)
        padded.append(arr)

    return torch.FloatTensor(np.stack(padded)), torch.IntTensor(lengths)
|
|
|
|