import copy
import json
from typing import Any

import numpy as np
import pandas as pd
from datasets import load_dataset
from rdkit import Chem, DataStructs
from rdkit.Chem import Descriptors, MACCSkeys, rdFingerprintGenerator
from rdkit.Chem.rdchem import Mol
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.feature_selection import VarianceThreshold
from sklearn.preprocessing import FunctionTransformer, StandardScaler
from statsmodels.distributions.empirical_distribution import ECDF

from .utils import USED_200_DESCR, TOX_SMARTS_PATH, Standardizer, FeatureDictMixin


def _quantile_dtype(values: np.ndarray) -> np.dtype:
    """Return a floating dtype suitable for storing quantiles of `values`.

    Floating inputs keep their dtype; any other dtype (int, bool, ...) is mapped
    to float64 so that quantiles in [0, 1] are not truncated on assignment.
    """
    if np.issubdtype(values.dtype, np.floating):
        return values.dtype
    return np.dtype(np.float64)


class SquashScaler(TransformerMixin, BaseEstimator):
    """
    Scaler that performs sequential standardization, nonlinearity (tanh), and
    re-standardization. Inspired by DeepTox (Mayr et al., 2016)
    """

    def __init__(self):
        self.scaler1 = StandardScaler()
        self.scaler2 = StandardScaler()

    def fit(self, X, y=None):
        """Fit both standard scalers; `y` is accepted for API compatibility and ignored."""
        _X = X.copy()
        _X = self.scaler1.fit_transform(_X)
        _X = np.tanh(_X)
        # The second scaler is fitted on the squashed values.
        self.scaler2.fit(_X)
        self.is_fitted_ = True
        return self

    def transform(self, X):
        """Apply standardization -> tanh -> standardization with the fitted scalers."""
        _X = X.copy()
        _X = self.scaler1.transform(_X)
        _X = np.tanh(_X)
        return self.scaler2.transform(_X)


# Maps the `scaler` option of `FeaturePreprocessor` to a transformer class.
# `FunctionTransformer` is instantiated without arguments for the `None` entry.
SCALER_REGISTRY = {
    None: FunctionTransformer,
    "standard": StandardScaler,
    "squash": SquashScaler,
}


class SubSampler(TransformerMixin, BaseEstimator):
    """
    Preprocessor that randomly samples `max_samples` from data.

    Args:
        max_samples (int): Maximum allowed samples. If -1, all samples are retained.

    Input:
        np.ndarray: A 2D NumPy array of shape (n_samples, n_features).

    Output:
        np.ndarray: Subsampled array of shape (min(n_samples, max_samples), n_features).
    """

    def __init__(self, *, max_samples=-1):
        self.max_samples = max_samples
        self.is_fitted_ = True

    def fit(self, X: np.ndarray, y: np.ndarray | None = None):
        """No-op; the sampler has nothing to learn."""
        return self

    def transform(
        self, X: np.ndarray, y: np.ndarray | None = None
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Return a copy of `X` (and `y` if given), subsampled to `max_samples` rows.

        Rows are only resampled when `max_samples` is positive and `X` has more
        rows than that. The same row indices are applied to `X` and `y`.

        Returns:
            `X` alone when `y` is None, otherwise the tuple `(X, y)`.
        """
        _X = X.copy()
        _y = y.copy() if y is not None else None

        n_samples = _X.shape[0]
        if self.max_samples > 0 and n_samples > self.max_samples:
            # NOTE(review): indices are drawn WITH replacement, so the result may
            # contain duplicated rows. Confirm this is intended before changing.
            resample_idxs = np.random.choice(
                n_samples, size=(self.max_samples,), replace=True
            )
            _X = _X[resample_idxs]
            if _y is not None:
                _y = _y[resample_idxs]

        if _y is None:
            return _X
        return _X, _y


class FeatureSelector(FeatureDictMixin, TransformerMixin, BaseEstimator):
    """
    Preprocessor that performs feature selection based on variance and correlation.

    This transformer selects features that:
    1. Have variance above a specified threshold.
    2. Are below a given pairwise correlation threshold.
    3. Among the remaining features, keeps only the top `max_features` with the
       highest variance.

    The three filters are applied in that order, each one refining the mask
    produced by the previous one. The input and output are both dictionaries
    mapping feature types to their corresponding feature matrices.

    Args:
        min_var (float): Minimum variance required for a feature to be retained.
            The filter is skipped unless `min_var > 0`.
        max_corr (float): Maximum allowed correlation between features. Features
            exceeding this threshold with others are removed. The filter is
            skipped unless `max_corr < 1`.
        max_features (int): Maximum number of features to keep after filtering.
            If -1, all remaining features are retained. 0 is rejected in `fit`.
        feature_keys (list[str]): Features to apply feature selection to.
        <filter>__feature_keys (list[str]): Feature types the given filter
            (`min_var`, `max_corr` or `max_features`) is applied to.
        <filter>__independent_keys (bool): If True, the given filter is applied
            to each feature type separately instead of to all listed types jointly.

    Input:
        dict[str, np.ndarray]: A dictionary where each key corresponds to a feature
        type and each value is a 2D NumPy array of shape (n_samples, n_features).

    Output:
        dict[str, np.ndarray]: A dictionary with the same keys as the input,
        containing only the selected features for each feature type.
    """

    def __init__(
        self,
        *,
        min_var=0.0,
        max_corr=1.0,
        max_features=-1,
        feature_keys=None,
        min_var__feature_keys=None,
        max_corr__feature_keys=None,
        max_features__feature_keys=None,
        min_var__independent_keys=False,
        max_corr__independent_keys=False,
        max_features__independent_keys=False,
    ):
        self.min_var = min_var
        self.max_corr = max_corr
        self.max_features = max_features

        self.min_var__feature_keys = min_var__feature_keys
        self.max_corr__feature_keys = max_corr__feature_keys
        self.max_features__feature_keys = max_features__feature_keys

        self.min_var__independent_keys = min_var__independent_keys
        self.max_corr__independent_keys = max_corr__independent_keys
        self.max_features__independent_keys = max_features__independent_keys

        super().__init__(feature_keys=feature_keys)

    def _get_min_var_mask(self, X: np.ndarray, *args) -> np.ndarray:
        """Return a boolean mask of the columns of `X` passing the variance threshold.

        Extra positional arguments are accepted (and ignored) so that all
        `_get_*_mask` methods can be called uniformly from `apply_filter`.
        """
        var_thresh = VarianceThreshold(threshold=self.min_var)
        return var_thresh.fit(X).get_support()  # mask

    def _get_max_corr_mask(
        self, X: np.ndarray, prev_feature_mask: np.ndarray
    ) -> np.ndarray:
        """Drop previously selected columns that are too correlated with an earlier one.

        Among the columns selected by `prev_feature_mask`, column `j` is removed
        if its correlation with any column `i < j` is strictly greater than
        `max_corr` (only the upper triangle of the correlation matrix is used,
        and the signed correlation is compared, not its absolute value).

        Returns:
            A new boolean mask with the same length as `prev_feature_mask`.
        """
        new_mask = prev_feature_mask.copy()

        # With fewer than two columns there is no pair to compare; np.corrcoef
        # would also return a 0-d result that np.triu cannot handle.
        if int(new_mask.sum()) < 2:
            return new_mask

        corr_matrix = np.corrcoef(X[:, new_mask], rowvar=False)
        upper_tri = np.triu(corr_matrix, k=1)

        # Column-wise "any" over the upper triangle: is there an earlier column
        # whose correlation with this one exceeds the threshold?
        to_keep = ~(upper_tri > self.max_corr).any(axis=0)

        new_mask[new_mask] = to_keep
        return new_mask

    def _get_max_features_mask(
        self, X: np.ndarray, prev_feature_mask: np.ndarray
    ) -> np.ndarray:
        """Keep the `max_features` previously selected columns with the highest variance.

        Returns:
            A new boolean mask with the same length as `prev_feature_mask` that
            is True only for the retained columns.
        """
        # Positions (in X) of the columns that survived the previous filters.
        selected_idx = np.flatnonzero(prev_feature_mask)

        feature_vars = np.nanvar(X[:, selected_idx], axis=0)
        # Positions within the *selected subset*, highest variance first.
        top_in_subset = np.argsort(feature_vars)[::-1][: self.max_features]

        # Map subset positions back to positions in X before building the mask;
        # using the subset positions directly would select the wrong columns
        # whenever an earlier filter removed something.
        new_mask = np.zeros(len(prev_feature_mask), dtype=bool)
        new_mask[selected_idx[top_in_subset]] = True
        return new_mask

    def apply_filter(self, filter, X, prev_feature_mask):
        """Apply one named filter and return the refined feature mask.

        Args:
            filter (str): Filter name; `_get_{filter}_mask`,
                `{filter}__feature_keys` and `{filter}__independent_keys` are
                looked up on the instance.
            X (np.ndarray): Feature matrix whose columns align with
                `self._curr_keys`.
            prev_feature_mask (np.ndarray): Boolean mask from the previous filters.

        Returns:
            np.ndarray: Updated copy of `prev_feature_mask`. Columns whose key is
            not listed for this filter are left untouched.
        """
        mask = prev_feature_mask.copy()
        func = getattr(self, f"_get_{filter}_mask")
        feature_keys = getattr(self, f"{filter}__feature_keys")

        if getattr(self, f"{filter}__independent_keys"):
            # Filter each feature type on its own columns only.
            for key in feature_keys:
                key_mask = self._curr_keys == key
                mask[key_mask] = func(X[:, key_mask], mask[key_mask])
        else:
            # Filter all listed feature types jointly.
            feature_key_mask = np.isin(self._curr_keys, feature_keys)
            mask[feature_key_mask] = func(
                X[:, feature_key_mask], mask[feature_key_mask]
            )
        return mask

    def fit(self, X: dict[str, np.ndarray], y=None):
        """Determine which features to keep; `y` is ignored.

        Raises:
            ValueError: If `max_features` is 0.
        """
        _X = self.dict_to_array(X)
        feature_mask = np.ones((_X.shape[1],), dtype=bool)

        # keep features with at least min_var variance
        if self.min_var > 0.0:
            feature_mask = self.apply_filter("min_var", _X, feature_mask)

        # drop features correlated above max_corr with an earlier feature
        if self.max_corr < 1.0:
            feature_mask = self.apply_filter("max_corr", _X, feature_mask)

        # keep at most max_features features, preferring high variance
        if self.max_features == 0:
            raise ValueError(
                f"max_features (={self.max_features}) must be -1 or larger 0."
            )
        if self.max_features > 0:
            feature_mask = self.apply_filter("max_features", _X, feature_mask)

        self._feature_mask = feature_mask
        self.is_fitted_ = True
        return self

    def transform(self, X: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Return `X` restricted to the features selected in `fit`."""
        _X = self.dict_to_array(X)
        _X = _X[:, self._feature_mask]
        # Keep the per-column keys aligned with the reduced matrix before it is
        # converted back into a dictionary.
        self._curr_keys = self._curr_keys[self._feature_mask]
        return self.array_to_dict(_X)


class QuantileCreator(FeatureDictMixin, TransformerMixin, BaseEstimator):
    """
    Preprocessor that transforms features into empirical quantiles using ECDFs.

    This transformer applies an Empirical Cumulative Distribution Function (ECDF)
    to each feature and replaces feature values with their corresponding quantile
    ranks. One ECDF is fitted per feature column.

    Both input and output are dictionaries mapping feature types to their
    corresponding feature matrices.

    Args:
        feature_keys (list[str]): Features to apply quantile creation to.

    Input:
        dict[str, np.ndarray]: A dictionary where each key corresponds to a feature
        type and each value is a 2D NumPy array of shape (n_samples, n_features).

    Output:
        dict[str, np.ndarray]: A dictionary with the same keys as the input,
        where each feature value is replaced by its corresponding ECDF quantile rank.
    """

    def __init__(self, *, feature_keys=None):
        self._ecdfs = None
        super().__init__(feature_keys=feature_keys)

    def fit(self, X: dict[str, np.ndarray], y=None):
        """Fit one ECDF per feature column; `y` is ignored."""
        _X = self.dict_to_array(X)
        self._ecdfs = [
            ECDF(_X[:, column].reshape(-1)) for column in range(_X.shape[1])
        ]
        self.is_fitted_ = True
        return self

    def transform(self, X: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Replace every feature value by its quantile under the fitted ECDFs."""
        _X = self.dict_to_array(X)
        quantiles = np.zeros(_X.shape, dtype=_quantile_dtype(_X))
        for column, ecdf in zip(range(_X.shape[1]), self._ecdfs):
            quantiles[:, column] = ecdf(_X[:, column].reshape(-1))
        return self.array_to_dict(quantiles)


class FeaturePreprocessor(TransformerMixin, BaseEstimator):
    """This class implements the feature preprocessing from a dictionary of molecule features.

    The pipeline is: optional quantile creation -> optional feature selection ->
    concatenation of the descriptor blocks (in the order of `descriptors`) ->
    scaling -> optional subsampling (in `transform` only).

    Args:
        feature_selection_config (dict): Keyword arguments for `FeatureSelector`
            plus a mandatory `"use"` entry that switches the step on or off.
        feature_quantilization_config (dict): Keyword arguments for
            `QuantileCreator` plus a mandatory `"use"` entry.
        descriptors (list[str]): Descriptor keys to concatenate; also passed to
            the feature selector as `feature_keys`.
        max_samples (int): Forwarded to `SubSampler`.
        scaler (str | None): Key into `SCALER_REGISTRY`.
    """

    def __init__(
        self,
        feature_selection_config: dict[str, Any],
        feature_quantilization_config: dict[str, Any],
        descriptors: list[str],
        max_samples: int = -1,
        scaler: str = "standard",
    ):
        self.descriptors = descriptors

        # The configs are deep-copied so the caller's dictionaries are not
        # mutated when the "use" flag is popped off below.
        self.feature_quantilization_config = copy.deepcopy(
            feature_quantilization_config
        )
        self.use_feat_quant = self.feature_quantilization_config.pop("use")
        self.quantile_creator = QuantileCreator(**self.feature_quantilization_config)

        self.feature_selection_config = copy.deepcopy(feature_selection_config)
        self.use_feat_selec = self.feature_selection_config.pop("use")
        self.feature_selection_config["feature_keys"] = descriptors
        self.feature_selector = FeatureSelector(**self.feature_selection_config)

        self.max_samples = max_samples
        self.sub_sampler = SubSampler(max_samples=max_samples)

        self.scaler = SCALER_REGISTRY[scaler]()

    def __getstate__(self):
        """Return the instance state with sub-transformers replaced by their own states."""
        state = super().__getstate__()
        state["quantile_creator"] = self.quantile_creator.__getstate__()
        state["feature_selector"] = self.feature_selector.__getstate__()
        state["sub_sampler"] = self.sub_sampler.__getstate__()
        state["scaler"] = self.scaler.__getstate__()
        return state

    def __setstate__(self, state):
        """Restore a state produced by `__getstate__` into this instance.

        NOTE(review): the sub-transformers must already exist on `self`, so this
        appears to work only on an instance created through `__init__` (not on
        a bare object as created by pickle/deepcopy) - confirm with callers.
        """
        _state = copy.deepcopy(state)
        self.quantile_creator.__setstate__(_state.pop("quantile_creator"))
        self.feature_selector.__setstate__(_state.pop("feature_selector"))
        self.sub_sampler.__setstate__(_state.pop("sub_sampler"))
        self.scaler.__setstate__(_state.pop("scaler"))
        super().__setstate__(_state)

    def get_state(self):
        """Public alias of `__getstate__`."""
        return self.__getstate__()

    def set_state(self, state):
        """Public alias of `__setstate__`."""
        return self.__setstate__(state)

    def fit(self, X: dict[str, np.ndarray], y=None):
        """Fit the processor transformers; `y` is ignored."""
        _X = copy.deepcopy(X)

        if self.use_feat_quant:
            _X = self.quantile_creator.fit_transform(_X)

        if self.use_feat_selec:
            _X = self.feature_selector.fit_transform(_X)

        _X = np.concatenate([_X[descr] for descr in self.descriptors], axis=1)

        self.scaler.fit(_X)
        return self

    def transform(
        self, X: dict[str, np.ndarray], y: np.ndarray | None = None
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Transform a feature dictionary into one scaled (and subsampled) matrix.

        Returns:
            The feature matrix alone when `y` is None, otherwise `(features, y)`
            with the same subsampling applied to both.
        """
        _X = X.copy()
        _y = y.copy() if y is not None else None

        if self.use_feat_quant:
            _X = self.quantile_creator.transform(_X)

        if self.use_feat_selec:
            _X = self.feature_selector.transform(_X)

        _X = np.concatenate([_X[descr] for descr in self.descriptors], axis=1)
        _X = self.scaler.transform(_X)

        if _y is None:
            return self.sub_sampler.transform(_X)

        _X, _y = self.sub_sampler.transform(_X, _y)
        return _X, _y


def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
    """This function creates cleaned RDKit mol objects from a list of SMILES.

    Taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
    Modification by Antonia Ebner:
        - skip uncleanable molecules
        - return clean molecule mask

    Args:
        smiles (list[str]): list of SMILES

    Returns:
        list[Mol]: list of cleaned molecules
        np.ndarray[bool]: mask that contains False at index `i`, if molecule in `smiles` at
            index `i` could not be cleaned and was removed.
    """
    sm = Standardizer(canon_taut=True)

    clean_mol_mask = []
    mols = []
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)

        # An unparsable SMILES yields None; treat it as "could not be cleaned"
        # instead of handing None to the standardizer.
        standardized_mol = None
        if mol is not None:
            standardized_mol, _ = sm.standardize_mol(mol)

        # Round-trip through SMILES to obtain the canonical molecule. If that
        # fails the molecule is skipped as well, so `mols` never contains None.
        can_mol = None
        if standardized_mol is not None:
            can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))

        is_cleaned = can_mol is not None
        clean_mol_mask.append(is_cleaned)
        if is_cleaned:
            mols.append(can_mol)

    # Explicit bool dtype so that `~mask` also works for an empty input list.
    return mols, np.array(clean_mol_mask, dtype=bool)


def create_ecfp_fps(mols: list[Mol], radius=3, fpsize=2048, **kwargs) -> np.ndarray:
    """This function creates ECFP fingerprints for a list of molecules.

    Inspired by https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py

    Args:
        mols (list[Mol]): list of molecules
        radius (int): radius passed to the Morgan fingerprint generator
        fpsize (int): fingerprint size passed to the Morgan fingerprint generator
        **kwargs: accepted for call compatibility; currently unused

    Returns:
        np.ndarray: ECFP fingerprints of molecules
    """
    # The generator does not depend on the molecule, so build it once.
    gen = rdFingerprintGenerator.GetMorganGenerator(
        countSimulation=True, fpSize=fpsize, radius=radius
    )

    ecfps = []
    for mol in mols:
        fp_vec = gen.GetCountFingerprint(mol)
        # ConvertToNumpyArray fills the given array with the fingerprint values.
        fp = np.zeros((0,), np.int8)
        DataStructs.ConvertToNumpyArray(fp_vec, fp)
        ecfps.append(fp)

    return np.array(ecfps)


def create_maccs_keys(mols: list[Mol]) -> np.ndarray:
    """This function creates MACCS keys for a list of molecules.

    Args:
        mols (list[Mol]): list of molecules

    Returns:
        np.ndarray: MACCS keys of molecules
    """
    maccs = [MACCSkeys.GenMACCSKeys(mol) for mol in mols]
    return np.array(maccs)


def get_tox_patterns(filepath: str):
    """This retrieves the tox features defined in filepath.

    Each entry of the JSON file is expected to hold its pattern string at index 1.
    A pattern string is either a single SMARTS, or several SMARTS joined by
    ' AND ' or by ' OR '; each sub-pattern may be prefixed with 'NOT '.

    Args:
        filepath (str): Path to the JSON file containing the tox patterns.

    Returns:
        list[tuple[list, list[bool], bool]]: One `(patterns, negations, merge_any)`
        triple per entry: the parsed sub-patterns, whether each sub-pattern's
        match result has to be negated, and whether the results are combined
        with `any` (True) or `all` (False).

    Raises:
        ValueError: If an entry contains both "AND" and "OR".
    """
    # load patterns
    with open(filepath, encoding="utf-8") as f:
        smarts_list = [s[1] for s in json.load(f)]

    # Mixed AND/OR expressions are not supported by the parsing below.
    # Raised explicitly (not via assert) so the check survives `python -O`.
    mixed = [s for s in smarts_list if ("AND" in s) and ("OR" in s)]
    if mixed:
        raise ValueError(
            f"Tox patterns combining AND and OR are not supported: {mixed}"
        )

    # Chem.MolFromSmarts takes a long time so it pays off to parse all the smarts first
    # and then use them for all molecules. This gives a huge speedup over existing code.
    # a list of patterns, whether to negate the match result and how to join them to
    # obtain one boolean value
    all_patterns = []
    for smarts in smarts_list:
        if " AND " in smarts:
            # If an ' AND ' is found all 'subsmarts' have to match
            sub_smarts = smarts.split(" AND ")
            merge_any = False
        else:
            # If there is an ' OR ' present it's enough if any of the 'subsmarts' match.
            # This also accumulates smarts where neither ' OR ' nor ' AND ' occur
            sub_smarts = smarts.split(" OR ")
            merge_any = True

        patterns = []  # list of smarts-patterns
        negations = []  # per pattern: negate its match result later?
        # for all subsmarts check if they are preceded by 'NOT '
        for s in sub_smarts:
            neg = s.startswith("NOT ")
            if neg:
                s = s[4:]
            patterns.append(Chem.MolFromSmarts(s))
            negations.append(neg)

        all_patterns.append((patterns, negations, merge_any))
    return all_patterns


def create_tox_features(mols: list[Mol], patterns: list) -> np.ndarray:
    """Matches the tox patterns against a list of molecules.

    Args:
        mols (list[Mol]): list of molecules
        patterns (list): `(patterns, negations, merge_any)` triples as returned
            by `get_tox_patterns`

    Returns:
        np.ndarray: boolean array with one row per molecule and one column per
        entry of `patterns`
    """
    tox_data = []
    for mol in mols:
        mol_features = []
        for patts, negations, merge_any in patterns:
            # `m != n` flips the match result where a negation is requested.
            matches = [
                mol.HasSubstructMatch(p) != n for p, n in zip(patts, negations)
            ]
            mol_features.append(any(matches) if merge_any else all(matches))

        tox_data.append(np.array(mol_features))

    return np.array(tox_data)


def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
    """This function creates RDKit descriptors for a list of molecules.

    Taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py

    Args:
        mols (list[Mol]): list of molecules

    Returns:
        np.ndarray: RDKit descriptors of molecules
    """
    rdkit_descriptors = []
    for mol in mols:
        # Evaluate every descriptor RDKit lists, then keep the subset selected
        # by USED_200_DESCR.
        descrs = np.array(
            [descr_calc_fn(mol) for _, descr_calc_fn in Descriptors._descList]
        )
        rdkit_descriptors.append(descrs[USED_200_DESCR])

    return np.array(rdkit_descriptors)


def create_quantiles(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
    """Create quantile values for given features using the columns

    Taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py

    Args:
        raw_features (np.ndarray): values to put into quantiles
        ecdfs (list): ECDFs to use, one per column of `raw_features`

    Returns:
        np.ndarray: computed quantiles, same shape as `raw_features`. The dtype
        is that of `raw_features` if it is floating, float64 otherwise.
    """
    # Do not inherit an integer/bool dtype: quantiles lie in [0, 1] and would be
    # truncated when written into such an array.
    quantiles = np.zeros(raw_features.shape, dtype=_quantile_dtype(raw_features))

    for column in range(raw_features.shape[1]):
        raw_values = raw_features[:, column].reshape(-1)
        quantiles[:, column] = ecdfs[column](raw_values)

    return quantiles


def fill(features, mask, value=np.nan):
    """Scatter `features` into a full-size array, padding the masked rows.

    Args:
        features (np.ndarray): 2D array with one row per entry of `mask` that is False.
        mask (np.ndarray): boolean array; True marks rows that have NO features
            and are filled with `value`.
        value: fill value for the masked rows.

    Returns:
        np.ndarray: float array of shape (len(mask), features.shape[1]).
    """
    n_mols = len(mask)
    n_features = features.shape[1]

    data = np.full((n_mols, n_features), value, dtype=float)
    data[~mask] = features
    return data


def create_descriptors(
    smiles,
    descriptors,
    **ecfp_kwargs,
):
    """Generate molecular descriptors for multiple SMILES strings.

    Inspired by https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py

    Each SMILES is processed and sanitized using RDKit.
    SMILES that cannot be sanitized are encoded with NaNs, and a corresponding boolean mask
    is returned to indicate which inputs were successfully processed.

    Args:
        smiles (list[str]): List of SMILES strings for which to generate descriptors.
        descriptors (list[str]): List of descriptor types to compute.
            Supported values include: ['ecfps', 'tox', 'maccs', 'rdkit_descrs'].
        **ecfp_kwargs: Forwarded to `create_ecfp_fps`.

    Returns:
        tuple[dict[str, np.ndarray], np.ndarray]:
            - A dictionary mapping descriptor names to their computed arrays.
            - A boolean mask of shape (len(smiles),) indicating which SMILES were
              successfully sanitized and processed.

    Raises:
        KeyError: If `descriptors` contains an unsupported descriptor type.
    """
    supported = ("ecfps", "tox", "maccs", "rdkit_descrs")
    unsupported = [descr for descr in descriptors if descr not in supported]
    if unsupported:
        raise KeyError(
            f"Unsupported descriptor type(s) {unsupported}; supported: {list(supported)}"
        )

    # Create cleaned rdkit mol objects
    mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
    print(f"Cleaned molecules, {(~clean_mol_mask).sum()} could not be sanitized")

    # `fill` expects True for the rows that have no features.
    failed_mask = ~clean_mol_mask

    # Create fingerprints and descriptors
    computed = {}
    if "ecfps" in descriptors:
        ecfps = create_ecfp_fps(mols, **ecfp_kwargs)
        computed["ecfps"] = fill(ecfps, failed_mask)
        print("Created ECFP fingerprints")

    if "tox" in descriptors:
        tox_patterns = get_tox_patterns(TOX_SMARTS_PATH)
        tox = create_tox_features(mols, tox_patterns)
        computed["tox"] = fill(tox, failed_mask)
        print("Created Tox features")

    if "maccs" in descriptors:
        maccs = create_maccs_keys(mols)
        computed["maccs"] = fill(maccs, failed_mask)
        print("Created MACCS keys")

    if "rdkit_descrs" in descriptors:
        rdkit_descrs = create_rdkit_descriptors(mols)
        computed["rdkit_descrs"] = fill(rdkit_descrs, failed_mask)
        print("Created RDKit descriptors")

    # collect features in the order requested by the caller
    features = {descr: computed[descr] for descr in descriptors}

    return features, clean_mol_mask


def get_tox21_split(token, cvfold=None):
    """Retrieve Tox21 splits from HuggingFace with respect to given cvfold.

    Args:
        token: Access token forwarded to `load_dataset`.
        cvfold: If None, the dataset's own train/validation split is returned.
            Otherwise the rows whose `CVfold` equals `float(cvfold)` form the
            validation split and all remaining rows the training split.

    Returns:
        dict[str, pd.DataFrame]: The "train" and "validation" data frames.
    """
    ds = load_dataset("ml-jku/tox21", token=token)

    train_df = ds["train"].to_pandas()
    val_df = ds["validation"].to_pandas()

    if cvfold is None:
        return {"train": train_df, "validation": val_df}

    combined_df = pd.concat([train_df, val_df], ignore_index=True)

    # create new splits
    cvfold = float(cvfold)
    train_df = combined_df[combined_df.CVfold != cvfold]
    val_df = combined_df[combined_df.CVfold == cvfold]

    # exclude train mols that occur in the validation split
    val_inchikeys = set(val_df["inchikey"])
    train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]

    return {
        "train": train_df.reset_index(drop=True),
        "validation": val_df.reset_index(drop=True),
    }