---
title: Tox21 SNN Classifier
emoji: 🌖
colorFrom: green
colorTo: pink
sdk: docker
pinned: false
license: cc-by-nc-4.0
short_description: Self-Normalizing Neural Network Baseline for Tox21
---

# Tox21 SNN Classifier

This repository hosts a Hugging Face Space that provides an API for submitting models to the [Tox21 Leaderboard](https://huggingface.co/spaces/ml-jku/tox21_leaderboard).

Here, a [self-normalizing neural network (SNN)](https://arxiv.org/abs/1706.02515) is trained on the Tox21 dataset, and the trained models are provided for inference. The model input is the SMILES string of a small molecule, and the output is 12 numeric values, one for each of the 12 toxic effects (targets) in the Tox21 dataset.

**Important:** For leaderboard submission, your Space needs to include training code. The file `train.py` should train the model using the config specified inside the `config/` folder and save the final model parameters to a file inside the `checkpoints/` folder. The model should be trained using the [Tox21 dataset](https://huggingface.co/datasets/tschouis/tox21) provided on Hugging Face. The dataset splits can be loaded like this:

```python
import os

from datasets import load_dataset

# Assumes your Hugging Face access token is stored in the HF_TOKEN environment variable
token = os.environ.get("HF_TOKEN")

ds = load_dataset("ml-jku/tox21", token=token)
train_df = ds["train"].to_pandas()
val_df = ds["validation"].to_pandas()
```

Additionally, the Space needs to implement inference in the `predict()` function inside `predict.py`. The `predict()` function must keep the provided skeleton: it takes a list of SMILES strings as input and returns a nested prediction dictionary, with SMILES strings as keys and dictionaries of target-name/prediction pairs as values. Therefore, any preprocessing of SMILES strings must be executed on the fly during inference.

# Repository Structure

- `predict.py` - Defines the `predict()` function required by the leaderboard (entry point for inference).
- `app.py` - FastAPI application wrapper (can be used as-is).
- `preprocess.py` - Preprocesses SMILES strings to generate feature descriptors and saves the results as NPZ files in `data/`.
- `train.py` - Trains and saves a model using the config in the `config/` folder.
- `config/` - The config file used by `train.py`.
- `logs/` - All logs of `train.py`, the saved model, and predictions on the validation set.
- `data/` - The SNN uses numerical features. Running `preprocess.py` creates two NPZ files containing molecule features and saves them here.
- `checkpoints/` - The saved model used by `predict.py`.
- `src/` - Core model and preprocessing logic:
  - `preprocess.py` - SMILES preprocessing logic
  - `model.py` - SNN model class with processing, saving, and loading logic
  - `utils.py` - Utility functions

# Quickstart with Spaces

You can easily adapt this project under your own Hugging Face account:

- Open this Space on Hugging Face.
- Click "Duplicate this Space" (top-right corner).
- Modify `src/` for your preprocessing pipeline and model class.
- Modify `predict()` inside `predict.py` to perform model inference, keeping the function skeleton unchanged so it remains compatible with the leaderboard.
- Modify `train.py` and/or `preprocess.py` according to your model and preprocessing pipeline.
- Modify the file inside `config/` so that it contains all hyperparameters set in `train.py`.

That's it: your model will be available as an API endpoint for the Tox21 Leaderboard.

# Installation

To run (and train) the SNN, clone the repository and install the dependencies:

```bash
git clone https://huggingface.co/spaces/ml-jku/tox21_snn_classifier
cd tox21_snn_classifier

conda create -n tox21_snn_cls python=3.11
conda activate tox21_snn_cls
pip install -r requirements.txt
```

# Inference

For inference, you only need `predict.py` (together with the saved model in `checkpoints/` and the code in `src/`).
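To illustrate the contract of the `predict()` skeleton, here is a minimal, self-contained sketch of a function that preprocesses each SMILES string on the fly and returns the nested dictionary. It is not this Space's implementation: `featurize()` and `dummy_model()` are hypothetical stand-ins for the descriptor pipeline in `src/` and the trained SNN, and the 12 target names are the standard Tox21 assay names, assumed (not confirmed) to match the dataset's column names.

```python
from typing import Dict, List

# Standard Tox21 assay names; assumed to match the dataset's target columns.
TARGETS = [
    "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
    "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
]


def featurize(smiles: str) -> List[float]:
    """Hypothetical stand-in for the real descriptor pipeline in src/preprocess.py."""
    # Toy features only: string length and number of aromatic carbon symbols.
    return [float(len(smiles)), float(smiles.count("c"))]


def dummy_model(features: List[float]) -> List[float]:
    """Hypothetical stand-in for the trained SNN: a constant score per target."""
    return [0.5] * len(TARGETS)


def predict(smiles_list: List[str]) -> Dict[str, Dict[str, float]]:
    """Map each SMILES string to a {target_name: prediction} dictionary."""
    predictions: Dict[str, Dict[str, float]] = {}
    for smiles in smiles_list:
        # Preprocessing runs here, at inference time, because only raw SMILES arrive.
        features = featurize(smiles)
        scores = dummy_model(features)
        predictions[smiles] = dict(zip(TARGETS, scores))
    return predictions
```

In a real `predict.py`, the model would typically be loaded once (for example, at module import) rather than on every call, and `featurize()` would be replaced by the same preprocessing used during training.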
Example usage inside Python:

```python
from predict import predict

smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]
results = predict(smiles_list)
print(results)
```

The output will be a nested dictionary in the following format, where `target1` to `target12` stand for the 12 Tox21 target names:

```python
{
    "CCO": {"target1": 0, "target2": 1, ..., "target12": 0},
    "c1ccccc1": {"target1": 1, "target2": 0, ..., "target12": 1},
    "CC(=O)O": {"target1": 0, "target2": 0, ..., "target12": 0}
}
```

# Notes

- Adapting `predict.py`, `train.py`, `config/`, and `checkpoints/` is required for leaderboard submission.
- Preprocessing must be done inside `predict.py`, not only in `train.py`.
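Before submitting, a quick structural check of the `predict()` output can catch skeleton mismatches early. The helper below, `check_predictions()`, is hypothetical and not part of this repository; it only checks shape: every input SMILES string is a key, each value is a dictionary with the same 12 target names, and every prediction is numeric.

```python
from numbers import Real
from typing import Dict, List


def check_predictions(
    results: Dict[str, Dict[str, float]],
    smiles_list: List[str],
    n_targets: int = 12,
) -> None:
    """Raise ValueError if results do not follow the SMILES -> {target: value} format."""
    missing = [s for s in smiles_list if s not in results]
    if missing:
        raise ValueError(f"missing predictions for: {missing}")

    expected_targets = None
    for smiles in smiles_list:
        preds = results[smiles]
        if not isinstance(preds, dict) or len(preds) != n_targets:
            raise ValueError(f"{smiles!r}: expected a dict with {n_targets} targets")
        # All molecules must share the same set of target names.
        if expected_targets is None:
            expected_targets = set(preds)
        elif set(preds) != expected_targets:
            raise ValueError(f"{smiles!r}: target names differ from other molecules")
        for target, value in preds.items():
            # bool is a subclass of int, so reject it explicitly.
            if isinstance(value, bool) or not isinstance(value, Real):
                raise ValueError(f"{smiles!r}/{target}: prediction is not numeric")
```

A typical call would be `check_predictions(predict(smiles_list), smiles_list)`, which returns `None` when the format is valid.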