from rdkit import Chem
from rdkit.Chem import AllChem, MACCSkeys
from rdkit.Chem.rdmolops import FastFindRings
from rdkit.Chem.rdMolDescriptors import CalcMolFormula
import torch
import numpy as np
import scipy
import scipy.sparse as ss
import scipy.sparse.linalg
import math
import json
import itertools as it
import re
from GNN import featurizer as ft

import rdkit.RDLogger as rkl
logger = rkl.logger()
logger.setLevel(rkl.ERROR)
import rdkit.rdBase as rkrb
rkrb.DisableLog('rdApp.error')

# Morgan fingerprint bit indices with >5% relative abundance across
# ~500,000 metabolites; used by the 'morgan1' option of mol_fp_encoder0.
FPBitIdx = [
    1, 5, 13, 41, 69, 80, 84, 94, 114, 117, 118, 119, 125, 133, 145, 147,
    191, 192, 197, 202, 222, 227, 231, 249, 283, 294, 310, 314, 322, 333,
    352, 361, 378, 387, 389, 392, 401, 406, 441, 478, 486, 489, 519, 521,
    524, 555, 561, 591, 598, 599, 610, 622, 650, 656, 667, 675, 677, 679,
    680, 694, 695, 715, 718, 722, 729, 736, 739, 745, 750, 760, 775, 781,
    787, 794, 798, 802, 807, 811, 823, 835, 841, 849, 869, 872, 874, 875,
    881, 890, 896, 926, 935, 980, 991, 1004, 1009, 1017, 1019, 1027, 1028,
    1035, 1037, 1039, 1057, 1060, 1066, 1070, 1077, 1088, 1097, 1114, 1126,
    1136, 1142, 1143, 1145, 1152, 1154, 1160, 1162, 1171, 1181, 1195, 1199,
    1202, 1218, 1234, 1236, 1243, 1257, 1267, 1274, 1279, 1283, 1292, 1294,
    1309, 1313, 1323, 1325, 1349, 1356, 1357, 1366, 1380, 1381, 1385, 1386,
    1391, 1399, 1436, 1440, 1441, 1444, 1452, 1454, 1457, 1475, 1476, 1477,
    1480, 1487, 1516, 1536, 1544, 1558, 1564, 1573, 1599, 1602, 1604, 1607,
    1619, 1648, 1670, 1683, 1693, 1716, 1722, 1737, 1738, 1745, 1747, 1750,
    1754, 1755, 1764, 1781, 1803, 1808, 1810, 1816, 1838, 1844, 1847, 1855,
    1860, 1866, 1873, 1905, 1911, 1917, 1921, 1923, 1928, 1933, 1950, 1951,
    1970, 1977, 1980, 1984, 1991, 2002, 2033, 2034, 2038]


class ConfigDict(dict):
    '''
    Makes a dictionary behave like an object, with attribute-style access.
    '''
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value

    def save(self, fn):
        with open(fn, 'w') as f:
            json.dump(self, f, indent=2)

    def load_dict(self, dic):
        for k, v in dic.items():
            self[k] = v

    def load(self, fn):
        try:
            with open(fn, 'r') as f:
                self.load_dict(json.load(f))
        except Exception as e:
            print(e)


def conv_out_dim(length_in, kernel, stride, padding, dilation):
    # Standard output-length formula for 1-D convolution/pooling layers.
    length_out = (length_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
    return length_out


def filter_ms(ms, thr=0.05, max_mz=2000):
    '''Drop peaks above max_mz or below thr (relative to the base peak);
    return m/z values and intensities rescaled to a 0-100 range.'''
    mz = []
    intn = []
    maxi = 0
    for m, i in ms:
        if m < max_mz and i > maxi:
            maxi = i
    if maxi == 0:  # no peak survives the m/z cutoff
        return mz, intn
    for m, i in ms:
        if m < max_mz and i / maxi > thr:
            mz.append(m)
            intn.append(round(i / maxi * 100, 2))
    return mz, intn


def calc_nls(ms, thr=0.05, max_mz=2000):
    '''Compute neutral losses (mass differences < 200 Da) between all pairs
    of filtered peaks; each loss gets the mean intensity of its peak pair.'''
    mz, intn = filter_ms(ms, thr=thr, max_mz=max_mz)
    nlmass = []
    nlintn = []
    # Iterate over index pairs so duplicate m/z values are handled correctly.
    for (idxa, a), (idxb, b) in it.combinations(list(enumerate(mz))[::-1], 2):
        nl = a - b
        if 0 < nl < 200:
            nlmass.append(round(nl, 5))
            nlintn.append(round((intn[idxa] + intn[idxb]) / 2., 5))
    nls = sorted(zip(nlmass, nlintn))
    return nls


def ms_binner(ms, nls=None, min_mz=20, max_mz=2000, bin_size=0.05,
              add_nl=False, binary_intn=False):
    """
    Convert a spectrum to a dense, binned torch vector.

    Parameters
    ----------
    ms : list of (float, float)
        The spectrum as (m/z, intensity) peak pairs.
    nls : list of (float, float), optional
        Precomputed neutral losses; computed from ms when omitted.
    min_mz : float
        The minimum m/z to include in the vector.
    max_mz : float
        The maximum m/z to include in the vector.
    bin_size : float
        The bin size in m/z used to divide the m/z range.
    add_nl : bool
        Prepend a neutral-loss channel (losses below 200 Da) to the output.
    binary_intn : bool
        Use 0/1 peak presence instead of normalized intensities.
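# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; the `_demo_` helper and the toy peak
# list below are not part of the original module). It shows the expected
# output sizes of ms_binner: ceil((2000 - 20) / 0.05) = 39600 m/z bins, plus
# ceil(200 / 0.05) = 4000 neutral-loss bins when add_nl=True.
def _demo_ms_binner():
    toy_ms = [(89.05, 120.0), (145.3, 40.0), (290.7, 999.0)]  # made-up peaks
    vec = ms_binner(toy_ms)
    print(vec.shape)     # torch.Size([39600])
    vec_nl = ms_binner(toy_ms, add_nl=True)
    print(vec_nl.shape)  # torch.Size([43600]) = 4000 NL bins + 39600 m/z bins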
    Returns
    -------
    torch.FloatTensor
        The binned spectrum vector; neutral-loss bins come first when
        add_nl is set.
    """
    nltensor = None
    mz, intn = filter_ms(ms, max_mz=max_mz)
    if add_nl:
        if not nls:
            nls = calc_nls(ms, max_mz=max_mz)
        nlmass = []
        nlintn = []
        for m, i in nls:
            if m < 200:
                if binary_intn:
                    i = 1
                nlmass.append(m)
                nlintn.append(i)
        nlmass = np.array(nlmass)
        nlintn = np.array(nlintn)
        if len(nlintn) > 0:
            nlintn = nlintn / nlintn.max()
        num_nlbins = math.ceil(200 / bin_size)
        nlbins = (nlmass / bin_size).astype(np.int32)
        if len(nlmass) > 0:
            vecnl = ss.csr_matrix(
                (nlintn, (np.repeat(0, len(nlintn)), nlbins)),
                shape=(1, num_nlbins), dtype=np.float32)
            vecnl = vecnl / scipy.sparse.linalg.norm(vecnl) * 100
            nltensor = torch.FloatTensor(vecnl.todense()).view(-1)
        else:
            nltensor = torch.zeros(num_nlbins)
    mz = np.array(mz)
    intn = np.array(intn)
    # Peaks below min_mz would map to negative bin indices, so drop them too.
    keepidx = (mz >= min_mz) & (mz <= max_mz)
    mz = mz[keepidx]
    intn = intn[keepidx]
    if binary_intn:
        intn[intn > 0] = 1.0
    elif len(intn) > 0:
        intn = intn / intn.max()
    num_bins = math.ceil((max_mz - min_mz) / bin_size)
    bins = ((mz - min_mz) / bin_size).astype(np.int32)
    if len(mz) > 0:
        vec = ss.csr_matrix(
            (intn, (np.repeat(0, len(intn)), bins)),
            shape=(1, num_bins), dtype=np.float32)
        if not binary_intn:
            vec = vec / scipy.sparse.linalg.norm(vec) * 100
        mstensor = torch.FloatTensor(vec.todense()).view(-1)
    else:
        mstensor = torch.zeros(num_bins)
    if nltensor is not None:
        return torch.cat([nltensor, mstensor], dim=0)
    return mstensor


def formula2vec(formula, elements=['C', 'H', 'O', 'N', 'P', 'S', 'P', 'F', 'Cl', 'Br']):
    '''Parse a molecular formula string into an element-count vector.
    NOTE: 'P' appears twice in the default element list; list.index always
    returns the first occurrence, so the second 'P' slot stays zero. The
    duplicate is kept to preserve the vector length expected downstream.'''
    formula_p = re.findall(r'([A-Z][a-z]*)(\d*)', formula)
    vec = np.zeros(len(elements))
    for ele, num in formula_p:
        num = 1 if num == '' else int(num)
        if ele in elements:
            vec[elements.index(ele)] += num
    return vec


def mol_fp_encoder0(smiles, tp='rdkit', nbits=2048):
    '''Encode a SMILES string as a fingerprint tensor; returns (fp, mol),
    or (None, None) when the SMILES cannot be parsed at all.'''
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Retry without sanitization, then restore the valence/ring info
        # that fingerprinting needs.
        mol = Chem.MolFromSmiles(smiles, sanitize=False)
        if mol is not None:
            mol.UpdatePropertyCache()
            FastFindRings(mol)
    if mol is None:
        return None, None
    if tp == 'morgan':
        fp_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=nbits)
        fp = np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')
        fp = fp.tolist()
    elif tp == 'morgan1':
        # Full 2048-bit Morgan fingerprint reduced to the frequent bits
        # listed in FPBitIdx.
        fp_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
        fp = np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')
        fp = fp[FPBitIdx].tolist()
    elif tp == 'macc':  # MACCS keys
        fp_vec = MACCSkeys.GenMACCSKeys(mol)
        fp = np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')
        fp = fp.tolist()
    elif tp == 'rdkit':
        fp_vec = Chem.RDKFingerprint(mol, nBitsPerHash=1)
        fp = np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')
        fp = fp.tolist()
    else:
        raise ValueError(f'unknown fingerprint type: {tp}')
    return torch.FloatTensor(fp), mol


def mol_fp_encoder(smiles, tp='rdkit', nbits=2048):
    fpenc, _ = mol_fp_encoder0(smiles, tp, nbits)
    return fpenc


def mol_fp_fm_encoder(smiles, tp='rdkit', nbits=2048):
    '''Return (fingerprint tensor, formula-count tensor) for a SMILES.'''
    fmenc = None
    fpenc, mol = mol_fp_encoder0(smiles, tp, nbits)
    if mol is not None:
        fm = CalcMolFormula(mol)
        fmenc = torch.FloatTensor(formula2vec(fm))
    return fpenc, fmenc


def smi2fmvec(smiles):
    '''Molecular-formula count vector for a SMILES, or None if unparsable.'''
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    fm = CalcMolFormula(mol)
    fmenc = torch.FloatTensor(formula2vec(fm))
    return fmenc
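# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; the `_demo_` helper is not part of
# the original API). Each fingerprint type yields a different length:
# 'rdkit' and 'morgan' give 2048 bits, 'morgan1' gives len(FPBitIdx) bits,
# and 'macc' gives the 167 MACCS keys.
def _demo_fp_encoders():
    smiles = 'CCO'  # ethanol, C2H6O
    for tp in ('rdkit', 'morgan', 'morgan1', 'macc'):
        fp = mol_fp_encoder(smiles, tp=tp)
        print(tp, tuple(fp.shape))
    # Element slots are [C, H, O, N, P, S, P, F, Cl, Br]:
    print(formula2vec('C2H6O'))  # [2. 6. 1. 0. 0. 0. 0. 0. 0. 0.]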
def mol_graph_featurizer(smiles):
    # mol_graph = {V, A, mol_size}
    '''mol_graph = ft.calc_data_from_smile(smiles, addh=True,
                                           with_ring_conj=True, with_atom_feats=True,
                                           with_submol_fp=True, radius=2)
    '''
    mol_graph = ft.calc_data_from_smile(smiles, addh=False,
                                        with_ring_conj=True, with_atom_feats=True,
                                        with_submol_fp=False, radius=2)
    return mol_graph


def pad_V(V, max_n):
    '''Zero-pad a node-feature matrix (N, C) to (max_n, C).'''
    N, C = V.shape
    if max_n > N:
        zeros = torch.zeros(max_n - N, C)
        V = torch.cat([V, zeros], dim=0)
    return V


def pad_A(A, max_n):
    '''Zero-pad an adjacency tensor (N, L, N) to (max_n, L, max_n).'''
    N, L, _ = A.shape
    if max_n > N:
        zeros = torch.zeros(N, L, max_n - N)
        A = torch.cat([A, zeros], dim=-1)
        zeros = torch.zeros(max_n - N, L, max_n)
        A = torch.cat([A, zeros], dim=0)
    return A


class AvgMeter:
    '''Running average of a scalar metric (e.g. a training loss).'''
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        self.avg, self.sum, self.count = [0] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        text = f"{self.name}: {self.avg:.4f}"
        return text


def get_lr(optimizer):
    # Learning rate of the optimizer's first parameter group.
    for param_group in optimizer.param_groups:
        return param_group["lr"]


def segment_max(x, size_list):
    size_list = [int(i) for i in size_list]
    return torch.stack([torch.max(v, 0).values for v in torch.split(x, size_list)])


def segment_sum(x, size_list):
    size_list = [int(i) for i in size_list]
    return torch.stack([torch.sum(v, 0) for v in torch.split(x, size_list)])


def segment_softmax(gate, size_list):
    '''Numerically stable softmax over variable-length segments of gate.'''
    segmax = segment_max(gate, size_list)
    # expand segmax shape to gate shape
    segmax_expand = torch.cat([segmax[i].repeat(n, 1) for i, n in enumerate(size_list)], dim=0)
    subtract = gate - segmax_expand
    exp = torch.exp(subtract)
    segsum = segment_sum(exp, size_list)
    # expand segsum shape to gate shape
    segsum_expand = torch.cat([segsum[i].repeat(n, 1) for i, n in enumerate(size_list)], dim=0)
    attention = exp / (segsum_expand + 1e-16)
    return attention


def pad_ms_list(ms_list, thr=0.05, min_mz=20, max_mz=2000):
    '''Normalize, threshold, and zero-pad a batch of ragged peak lists into
    one (batch, max_peaks, 2) tensor plus the true peak count per spectrum.'''
    thr = thr * 100
    mslst = []
    for ms in ms_list:
        # Force a float copy so the in-place normalization below is exact
        # even when the input peaks are integers.
        ms = np.array(ms, dtype=np.float32)
        ms[:, 1] = ms[:, 1] / ms[:, 1].max() * 100
        if thr > 0:
            ms = ms[(ms[:, 1] >= thr)]
        ms = ms[(ms[:, 0] >= min_mz)]
        ms = ms[(ms[:, 0] <= max_mz)]
        mslst.append(ms)
    size_list = [ms.shape[0] for ms in mslst]
    maxlen = max(size_list)
    l = []
    for ms in mslst:
        extn = maxlen - len(ms)
        if extn > 0:
            l.append(np.concatenate([ms, [[0, 0]] * extn], axis=0))
        else:
            l.append(ms)
    return torch.FloatTensor(np.stack(l)), torch.IntTensor(size_list)
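# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; the `_demo_` helper and toy inputs
# are not part of the original module). segment_softmax normalizes scores
# within each variable-length segment, e.g. attention over the atoms of
# differently sized molecules packed into one flat batch.
def _demo_segment_ops():
    gate = torch.tensor([[1.0], [2.0], [0.5], [0.5], [0.5]])
    att = segment_softmax(gate, size_list=[2, 3])
    print(att.squeeze(-1))  # ~[0.2689, 0.7311, 0.3333, 0.3333, 0.3333]
    print(segment_sum(att, [2, 3]).squeeze(-1))  # each segment sums to ~1
    # pad_ms_list pads ragged peak lists to a common length:
    padded, sizes = pad_ms_list([[(100.0, 50.0), (200.0, 100.0)],
                                 [(150.0, 80.0)]])
    print(padded.shape, sizes)  # torch.Size([2, 2, 2]), counts [2, 1]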