tox21_snn_classifier / preprocess.py
antoniaebner's picture
upload code
35189e2
raw
history blame
1.78 kB
"""
This file contains the data preprocessing for Tox21.
As an input it takes a list of SMILES and it outputs a nested dictionary with
SMILES and target names as keys.
"""
import os
import json
import argparse
import numpy as np
from src.preprocess import create_descriptors, get_tox21_split
from src.utils import TASKS, HF_TOKEN, create_dir, normalize_config
# Command-line interface of the script: a single optional ``--config`` flag
# holding the path to the JSON configuration file.
parser = argparse.ArgumentParser(description="Data preprocessing script for the Tox21 dataset")
parser.add_argument("--config", default="config/config.json", type=str)
def main(config):
    """Create molecule descriptors for HF Tox21 dataset.

    For each of the "train" and "validation" splits obtained from
    ``get_tox21_split``, descriptors are computed from the split's
    ``"smiles"`` column, the label columns named in ``TASKS`` are stacked
    along axis 1, and the result is written to
    ``<data_folder>/tox21_<split>_cv4.npz``.

    Args:
        config: Mapping that provides the keys ``"cvfold"``,
            ``"descriptors"``, ``"ecfp"`` (forwarded as keyword arguments
            to ``create_descriptors``) and ``"data_folder"``.
    """
    dataset = get_tox21_split(HF_TOKEN, cvfold=config["cvfold"])

    for split in ("train", "validation"):
        print(f"Preprocess {split} molecules")
        split_data = dataset[split]

        features, clean_mol_mask = create_descriptors(
            list(split_data["smiles"]), config["descriptors"], **config["ecfp"]
        )

        # One label column per task, in TASKS order.
        labels = np.stack(
            [split_data[task].to_numpy() for task in TASKS], axis=1
        )

        # NOTE(review): the "cv4" suffix is fixed and does not follow
        # config["cvfold"] -- confirm this is intended before changing it,
        # since other code may load the file by this exact name.
        save_path = os.path.join(config["data_folder"], f"tox21_{split}_cv4.npz")
        with open(save_path, "wb") as out_file:
            # `features` is unpacked into keyword arguments, so each of its
            # entries becomes its own named array in the archive.
            np.savez(
                out_file,
                clean_mol_mask=clean_mol_mask,
                labels=labels,
                **features,
            )
        print(f"Saved preprocessed {split} split under {save_path}")

    print("Preprocessing finished successfully")
if __name__ == "__main__":
    # Script entry point: read the JSON config named by --config, normalize
    # it, prepare the output folder and run the preprocessing.
    args = parser.parse_args()
    # JSON text is UTF-8; pin the encoding so decoding does not depend on the
    # platform's locale default.
    with open(args.config, "r", encoding="utf-8") as f:
        config = json.load(f)
    config = normalize_config(config)
    # main() writes its .npz files into this folder; create_dir is expected
    # to make sure it exists beforehand.
    create_dir(config["data_folder"])
    main(config)