VIDraft Tox21 v3: ChemBERTa + XGBoost + LightGBM Ensemble
Molecular toxicity prediction for 12 Tox21 endpoints.
Goal: outperform DeepTox (mean AUC 0.862, 1st place in 2015) on the JKU Tox21 leaderboard.
Live API: VIDraft/tox21-xgb Space
Model Description
This repository contains 13 model files (12 task-specific ensemble models + 1 stacking DNN) trained on 9,158-dimensional molecular features combining:
- Chemical fingerprints (8,390 dims): ECFP4 (2048) + ECFP6 (2048) + FCFP4 (2048) + a second 2048-dim ECFP4 block (`ecfp4c` in the feature-extraction code) + MACCS (167) + 31 RDKit physicochemical descriptors
- ChemBERTa-2 embeddings (768 dims): seyonec/ChemBERTa-zinc-base-v1 [CLS] token representation
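The block sizes add up to the stated totals: 4 × 2048 + 167 + 31 = 8,390 RDKit dims, plus 768 ChemBERTa dims = 9,158. The sketch below computes each block's column range, which helps when inspecting or debugging feature matrices. The block order follows the `get_features` code in this card; placing the 31 descriptors directly after `ecfp4c` is an assumption, since the card does not say where they sit. `FEATURE_BLOCKS` and `block_slices` are hypothetical names.

```python
# Column layout of the 9,158-dim feature vector (a sketch, not the authors' code).
# Block order mirrors get_features; the position of "desc" is an assumption.
FEATURE_BLOCKS = [
    ("ecfp4", 2048),
    ("ecfp6", 2048),
    ("fcfp4", 2048),
    ("maccs", 167),
    ("ecfp4c", 2048),
    ("desc", 31),
    ("chemberta_cls", 768),
]


def block_slices(blocks):
    """Return ({block_name: slice}, total_width) for consecutive blocks."""
    slices = {}
    start = 0
    for name, size in blocks:
        slices[name] = slice(start, start + size)
        start += size
    return slices, start


slices, total = block_slices(FEATURE_BLOCKS)
print(total)                    # 9158
print(slices["desc"].stop)      # 8390: end of the RDKit part
print(slices["chemberta_cls"])  # slice(8390, 9158, None)
```

With a feature matrix `X`, `X[:, slices["chemberta_cls"]]` then selects only the embedding columns.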
Per-endpoint ensemble: XGBoost (30 random seeds) + LightGBM (20 random seeds); the predicted probabilities are averaged.
Tox21 Endpoints (12)
| Category | Endpoints |
|---|---|
| Nuclear Receptors (NR) | NR-AhR, NR-AR, NR-AR-LBD, NR-Aromatase, NR-ER, NR-ER-LBD, NR-PPAR-gamma |
| Stress Response (SR) | SR-ARE, SR-ATAD5, SR-HSE, SR-MMP, SR-p53 |
Architecture
SMILES Input
    │
    ├── RDKit Fingerprints ──────────────────── 8,390 dims
    │     (ECFP4 + ECFP6 + FCFP4 + MACCS + desc)
    │
    └── ChemBERTa-2 [CLS] embedding ─────────── 768 dims
          (seyonec/ChemBERTa-zinc-base-v1)
                 │
─────────── Concatenate ──────────────────────── 9,158 dims
                 │
     ┌───────────┴───────────┐
     │ XGBoost × 30 seeds    │ per task
     │ LightGBM × 20 seeds   │
     └───────────┬───────────┘
                 │ Mean ensemble
                 ▼
        Toxicity probability
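The per-task "mean ensemble" can be sketched as a flat average of positive-class probabilities over every seeded model. This assumes all 50 models (30 XGBoost + 20 LightGBM) are weighted equally; the card does not say whether the two model families are averaged separately first. `fit_seed_ensemble` and `ensemble_proba` are hypothetical helper names, and the XGBoost/LightGBM usage is shown only in comments.

```python
from typing import Callable, Iterable, List

import numpy as np


def fit_seed_ensemble(make_model: Callable[[int], object], seeds: Iterable[int], X, y) -> List[object]:
    """Fit one model per random seed. make_model(seed) must return an unfitted
    classifier with a scikit-learn style fit()/predict_proba() interface."""
    return [make_model(seed).fit(X, y) for seed in seeds]


def ensemble_proba(models: List[object], X) -> np.ndarray:
    """Equal-weight mean of every model's positive-class probability."""
    return np.mean([m.predict_proba(X)[:, 1] for m in models], axis=0)


# Assumed usage with the real libraries (not executed here):
#   from xgboost import XGBClassifier
#   from lightgbm import LGBMClassifier
#   models = fit_seed_ensemble(lambda s: XGBClassifier(random_state=s), range(30), X, y)
#   models += fit_seed_ensemble(lambda s: LGBMClassifier(random_state=s), range(20), X, y)
#   p_toxic = ensemble_proba(models, X_new)
```

Averaging over seeds mainly reduces the variance introduced by row and column subsampling, which is why many cheap seeds are used instead of one larger model.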
Training hyperparameters (XGBoost, per task):
- n_estimators=500, max_depth=7, learning_rate=0.04
- subsample=0.85, colsample_bytree=0.65
- min_child_weight=3, reg_alpha=0.05, reg_lambda=0.3
- scale_pos_weight=auto (class imbalance correction)
- device=cuda (NVIDIA H200 GPU training)
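The hyperparameters above map onto `xgboost.XGBClassifier` keyword arguments roughly as in this sketch. XGBoost has no literal `scale_pos_weight="auto"` setting, so the sketch assumes the heuristic suggested in the XGBoost documentation, sum(negatives) / sum(positives), computed per task on labelled molecules only (Tox21 labels are sparse; missing labels are stored as NaN here). `xgb_params` is a hypothetical helper.

```python
import numpy as np


def xgb_params(y: np.ndarray, seed: int) -> dict:
    """Per-task XGBClassifier kwargs mirroring the card's hyperparameters.
    y holds 0/1 labels with NaN for untested molecules."""
    labelled = y[~np.isnan(y)]
    n_pos = int((labelled == 1).sum())
    n_neg = int((labelled == 0).sum())
    return {
        "n_estimators": 500,
        "max_depth": 7,
        "learning_rate": 0.04,
        "subsample": 0.85,
        "colsample_bytree": 0.65,
        "min_child_weight": 3,
        "reg_alpha": 0.05,
        "reg_lambda": 0.3,
        # Assumed meaning of "auto": negative/positive ratio (guarded against n_pos == 0)
        "scale_pos_weight": n_neg / max(n_pos, 1),
        "device": "cuda",  # GPU selection parameter in XGBoost >= 2.0
        "random_state": seed,
    }
```

A model for one seed would then be built as `XGBClassifier(**xgb_params(y_task, seed))` and fitted on the rows where `y_task` is not NaN.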
Training data: Tox21 challenge dataset, 12,060 molecules (DeepChem scaffold split)
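A scaffold split groups molecules by Bemis-Murcko scaffold so that no scaffold appears in more than one partition, which makes the test set structurally harder than a random split. The sketch below approximates the assignment rule of DeepChem's `ScaffoldSplitter` (largest scaffold groups fill train first, then valid, then test); the 80/10/10 fractions are DeepChem's defaults and an assumption here, since the card does not state them. The scaffold function is injectable so the grouping logic runs without RDKit; by default it calls RDKit's `MurckoScaffold.MurckoScaffoldSmiles`. `scaffold_split` is a hypothetical helper, not DeepChem's API.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Tuple


def murcko_scaffold(smiles: str) -> str:
    # Lazy import so the splitting logic itself does not require RDKit
    from rdkit.Chem.Scaffolds import MurckoScaffold
    return MurckoScaffold.MurckoScaffoldSmiles(smiles=smiles, includeChirality=False)


def scaffold_split(
    smiles: List[str],
    frac_train: float = 0.8,
    frac_valid: float = 0.1,
    scaffold_fn: Callable[[str], str] = murcko_scaffold,
) -> Tuple[List[int], List[int], List[int]]:
    """Return (train, valid, test) index lists; each scaffold lands in exactly one split."""
    groups: Dict[str, List[int]] = defaultdict(list)
    for i, s in enumerate(smiles):
        groups[scaffold_fn(s)].append(i)
    # Largest scaffold groups first; equal sizes ordered by first index, descending
    scaffold_sets = sorted(groups.values(), key=lambda g: (len(g), g[0]), reverse=True)
    train_cutoff = frac_train * len(smiles)
    valid_cutoff = (frac_train + frac_valid) * len(smiles)
    train: List[int] = []
    valid: List[int] = []
    test: List[int] = []
    for group in scaffold_sets:
        if len(train) + len(group) > train_cutoff:
            if len(train) + len(valid) + len(group) > valid_cutoff:
                test += group
            else:
                valid += group
        else:
            train += group
    return train, valid, test
```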
Files
| File | Description | Size |
|---|---|---|
| NR-AhR.pkl | NR-AhR ensemble model | ~14 MB |
| NR-AR.pkl | NR-AR ensemble model | ~14 MB |
| NR-AR-LBD.pkl | NR-AR-LBD ensemble model | ~13 MB |
| NR-Aromatase.pkl | NR-Aromatase ensemble model | ~14 MB |
| NR-ER.pkl | NR-ER ensemble model | ~14 MB |
| NR-ER-LBD.pkl | NR-ER-LBD ensemble model | ~14 MB |
| NR-PPAR-gamma.pkl | NR-PPAR-gamma ensemble model | ~13 MB |
| SR-ARE.pkl | SR-ARE ensemble model | ~14 MB |
| SR-ATAD5.pkl | SR-ATAD5 ensemble model | ~14 MB |
| SR-HSE.pkl | SR-HSE ensemble model | ~14 MB |
| SR-MMP.pkl | SR-MMP ensemble model | ~15 MB |
| SR-p53.pkl | SR-p53 ensemble model | ~14 MB |
| stacking_dnn.pkl | PyTorch DNN meta-learner | ~1 MB |
Usage
Via Live API (Recommended)
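import requests

response = requests.post(
    "https://vidraft-tox21-xgb.hf.space/predict",
    json={"smiles": ["CCO", "c1ccccc1"]},
    timeout=60,  # a sleeping Space can take a while to respond
)
response.raise_for_status()  # raise on HTTP errors instead of failing on the JSON below
predictions = response.json()["predictions"]
# {"CCO": {"NR-AhR": 0.032, "NR-AR": 0.018, ...}, ...}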
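The response maps each SMILES string to a dictionary of per-endpoint probabilities. A small sketch for turning that into a list of flagged endpoints per molecule; the 0.5 cut-off is an arbitrary assumption for illustration, not a threshold recommended by the model authors, and `flag_endpoints` is a hypothetical helper. Apart from the two CCO values copied from the example response, the numbers in `example` are made up and are not real model output.

```python
from typing import Dict, List


def flag_endpoints(predictions: Dict[str, Dict[str, float]], threshold: float = 0.5) -> Dict[str, List[str]]:
    """For each SMILES, list endpoints whose probability is >= threshold,
    highest probability first."""
    flagged = {}
    for smiles, scores in predictions.items():
        hits = [task for task, p in scores.items() if p >= threshold]
        flagged[smiles] = sorted(hits, key=lambda t: scores[t], reverse=True)
    return flagged


# Illustrative values only (not model output)
example = {
    "CCO": {"NR-AhR": 0.032, "NR-AR": 0.018},
    "c1ccccc1": {"NR-AhR": 0.71, "SR-ARE": 0.55, "NR-AR": 0.10},
}
print(flag_endpoints(example))  # {'CCO': [], 'c1ccccc1': ['NR-AhR', 'SR-ARE']}
```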
Direct Model Loading
import pickle
import numpy as np
from huggingface_hub import hf_hub_download

# Download and load a single task model (only unpickle files from sources you trust)
path = hf_hub_download("VIDraft/tox21-v3-models", "NR-AhR.pkl")
with open(path, "rb") as f:
    model = pickle.load(f)

# X must be a 2-D array of shape (n_molecules, 9158):
# fingerprints (8390) + ChemBERTa [CLS] (768)
# model.predict_proba(X)[:, 1] -> toxicity probability in [0, 1]
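To score all 12 endpoints, the same loading pattern can be repeated per task file. This sketch assumes each `.pkl` exposes `predict_proba` as described above; the loader is passed in as a function so the loop can be exercised without network access. `TOX21_TASKS`, `load_task_model`, and `predict_all_tasks` are hypothetical names.

```python
import pickle
from typing import Callable, Dict

import numpy as np

TOX21_TASKS = [
    "NR-AhR", "NR-AR", "NR-AR-LBD", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
    "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
]


def load_task_model(task: str):
    """Download <task>.pkl from the model repo and unpickle it."""
    from huggingface_hub import hf_hub_download  # imported lazily
    path = hf_hub_download("VIDraft/tox21-v3-models", f"{task}.pkl")
    with open(path, "rb") as f:
        return pickle.load(f)


def predict_all_tasks(X: np.ndarray, load: Callable[[str], object] = load_task_model) -> Dict[str, np.ndarray]:
    """Return {task: positive-class probabilities} for an (n, 9158) matrix X."""
    if X.ndim != 2 or X.shape[1] != 9158:
        raise ValueError(f"expected shape (n, 9158), got {X.shape}")
    return {task: load(task).predict_proba(X)[:, 1] for task in TOX21_TASKS}
```

Each call to `load_task_model` reuses the local Hugging Face cache after the first download, so repeated calls do not re-fetch the files.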
Feature Extraction
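import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, MACCSkeys
from transformers import AutoTokenizer, AutoModel

# Load ChemBERTa once instead of on every call
CHEMBERTA_ID = "seyonec/ChemBERTa-zinc-base-v1"
tok = AutoTokenizer.from_pretrained(CHEMBERTA_ID)
bert = AutoModel.from_pretrained(CHEMBERTA_ID).eval()

# The card does not list the 31 RDKit descriptors or their order.
# Placeholder: fill in the exact rdkit.Chem.Descriptors names used in training.
DESCRIPTOR_NAMES: list[str] = []

def get_features(smiles: str) -> np.ndarray:
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")
    # Fingerprints: Morgan bit vectors (ECFP4, ECFP6, FCFP4) + 167 MACCS keys
    ecfp4 = np.array(AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048), dtype=np.float32)
    ecfp6 = np.array(AllChem.GetMorganFingerprintAsBitVect(mol, 3, nBits=2048), dtype=np.float32)
    fcfp4 = np.array(AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048, useFeatures=True), dtype=np.float32)
    maccs = np.array(MACCSkeys.GenMACCSKeys(mol), dtype=np.float32)
    # As written this block is identical to ecfp4; the "c" suffix may indicate a
    # count-based variant. Match whatever the training pipeline actually used.
    ecfp4c = np.array(AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048), dtype=np.float32)
    # 31 physicochemical descriptors (their position in the vector is an assumption)
    if len(DESCRIPTOR_NAMES) != 31:
        raise ValueError("Set DESCRIPTOR_NAMES to the 31 descriptors used in training")
    desc = np.array([getattr(Descriptors, name)(mol) for name in DESCRIPTOR_NAMES], dtype=np.float32)
    # ChemBERTa [CLS] token embedding (768 dims)
    enc = tok([smiles], return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        cls_emb = bert(**enc).last_hidden_state[:, 0, :].cpu().numpy()[0]
    # 4 * 2048 + 167 + 31 = 8390 fingerprint dims, + 768 = 9158 dims
    return np.concatenate([ecfp4, ecfp6, fcfp4, maccs, ecfp4c, desc, cls_emb])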
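For more than one molecule, the per-SMILES vectors have to be stacked into the `(n, 9158)` matrix that `predict_proba` expects. A minimal sketch, assuming a featurizer with the same signature as `get_features`; SMILES that fail to featurize are skipped and the indices of the kept ones are returned so predictions can be matched back to inputs. `featurize_batch` is a hypothetical helper.

```python
from typing import Callable, List, Sequence, Tuple

import numpy as np


def featurize_batch(
    smiles: Sequence[str],
    featurize: Callable[[str], np.ndarray],
    n_dims: int = 9158,
) -> Tuple[np.ndarray, List[int]]:
    """Stack feature vectors into a float32 matrix; also return the indices
    of the SMILES that were featurized successfully."""
    rows, kept = [], []
    for i, s in enumerate(smiles):
        try:
            vec = featurize(s)
        except Exception:  # e.g. an unparsable SMILES string
            continue
        if vec.shape != (n_dims,):
            raise ValueError(f"{s!r}: expected {n_dims} dims, got {vec.shape}")
        rows.append(vec.astype(np.float32))
        kept.append(i)
    X = np.vstack(rows) if rows else np.empty((0, n_dims), dtype=np.float32)
    return X, kept
```

Typical use would be `X, kept = featurize_batch(smiles_list, get_features)` followed by `model.predict_proba(X)[:, 1]`.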
Performance
| Model | Mean AUC (12 tasks) | Year | Reference |
|---|---|---|---|
| DeepTox (1st place) | 0.862 | 2015 | Mayr et al., Front. Environ. Sci. 2016 |
| SNN | 0.856 | 2017 | - |
| VIDraft Tox21 v3 (this model) | leaderboard pending | 2026 | - |
Results will be updated after JKU leaderboard submission.
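The "Mean AUC (12 tasks)" column is presumably the unweighted mean of the per-endpoint ROC AUCs. A NumPy-only sketch, assuming missing labels are encoded as NaN and skipped per task (the card does not describe the leaderboard's evaluation code). ROC AUC is computed through the rank-sum (Mann-Whitney U) identity, with tied scores given their average rank. `roc_auc` and `mean_task_auc` are hypothetical names.

```python
from typing import List, Tuple

import numpy as np


def roc_auc(y_true, y_score) -> float:
    """ROC AUC via the rank-sum identity; tied scores get average ranks."""
    y_true = np.asarray(y_true, dtype=float)
    y_score = np.asarray(y_score, dtype=float)
    n_pos = int((y_true == 1).sum())
    n_neg = int((y_true == 0).sum())
    if n_pos == 0 or n_neg == 0:
        raise ValueError("ROC AUC needs at least one positive and one negative")
    order = np.argsort(y_score, kind="mergesort")
    sorted_scores = y_score[order]
    ranks = np.empty(len(y_score))
    i = 0
    while i < len(sorted_scores):
        j = i
        while j + 1 < len(sorted_scores) and sorted_scores[j + 1] == sorted_scores[i]:
            j += 1
        ranks[order[i:j + 1]] = (i + j) / 2 + 1  # average 1-based rank of the tie block
        i = j + 1
    return float((ranks[y_true == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg))


def mean_task_auc(Y: np.ndarray, P: np.ndarray) -> Tuple[float, List[float]]:
    """Y: (n, tasks) labels with NaN for untested pairs; P: matching probabilities."""
    aucs = []
    for t in range(Y.shape[1]):
        mask = ~np.isnan(Y[:, t])
        aucs.append(roc_auc(Y[mask, t], P[mask, t]))
    return float(np.mean(aucs)), aucs
```

Averaging per task, rather than pooling all (molecule, task) pairs, keeps endpoints with few labelled molecules from being swamped by the larger ones.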
Training Infrastructure
- GPU: 8 × NVIDIA H200 (NIPA KT Cloud, 2026)
- Python: 3.11
- Key packages: XGBoost 2.1.1, LightGBM 4.3.0, scikit-learn 1.5.2, transformers 4.44.0
- Training date: 2026-06-12
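Because the task models are distributed as pickles, loading them with library versions different from the training environment may fail or behave differently. A small sketch that compares installed versions against the ones listed above, using the standard-library `importlib.metadata`; the `get_version` argument is injectable for testing, and `TRAINING_VERSIONS` / `check_versions` are hypothetical names.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Optional

TRAINING_VERSIONS = {
    "xgboost": "2.1.1",
    "lightgbm": "4.3.0",
    "scikit-learn": "1.5.2",
    "transformers": "4.44.0",
}


def check_versions(
    expected: Dict[str, str] = TRAINING_VERSIONS,
    get_version: Callable[[str], str] = version,
) -> Dict[str, Optional[str]]:
    """Return {package: installed_version_or_None} for every mismatch."""
    mismatches: Dict[str, Optional[str]] = {}
    for pkg, want in expected.items():
        try:
            have: Optional[str] = get_version(pkg)
        except PackageNotFoundError:
            have = None  # package not installed
        if have != want:
            mismatches[pkg] = have
    return mismatches
```

An empty dictionary from `check_versions()` means the environment matches the listed package versions.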
Citation
@misc{vidraft2026tox21,
title = {VIDraft Tox21 v3: ChemBERTa + XGBoost/LightGBM Ensemble for Molecular Toxicity Prediction},
author = {VIDraft},
year = {2026},
url = {https://huggingface.co/VIDraft/tox21-v3-models}
}
License
Apache-2.0