VIDraft Tox21 v3: ChemBERTa + XGBoost + LightGBM Ensemble

Molecular toxicity prediction for 12 Tox21 endpoints.

Target: Beat DeepTox (AUC 0.862, 1st place 2015) on the JKU Tox21 Leaderboard.
Live API: VIDraft/tox21-xgb Space


Model Description

This repository contains 13 model files (12 task-specific ensemble models + 1 stacking DNN) trained on 9,158-dimensional molecular features combining:

  • Chemical fingerprints and descriptors (8,390 dims): ECFP4(2048) + ECFP6(2048) + FCFP4(2048) + MACCS(167) + a second ECFP4 vector (2048, named ecfp4c in the feature extraction code) + 31 RDKit physicochemical descriptors
  • ChemBERTa-2 embeddings (768 dims): seyonec/ChemBERTa-zinc-base-v1 [CLS] token representation

Per-endpoint ensemble: XGBoost (30 random seeds) + LightGBM (20 random seeds) → averaged probability.
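The averaging step can be sketched as follows. This is a minimal illustration rather than the repository's own code: `mean_ensemble_proba` is a hypothetical helper, it assumes every fitted member exposes a scikit-learn style `predict_proba`, and it assumes all 50 members are weighted equally (the card only says the probabilities are averaged).

```python
import numpy as np

def mean_ensemble_proba(models, X):
    """Average positive-class probabilities over all ensemble members.

    Assumption: each model has predict_proba(X) returning an array of
    shape (n_samples, 2), as XGBClassifier and LGBMClassifier do.
    """
    X = np.asarray(X, dtype=np.float32)
    # One column per member: shape (n_samples, n_models)
    per_model = np.column_stack([m.predict_proba(X)[:, 1] for m in models])
    # Unweighted mean across members (equal weighting is an assumption)
    return per_model.mean(axis=1)
```

Given a list holding the 30 XGBoost and 20 LightGBM classifiers for one endpoint, `mean_ensemble_proba(members, X)` would return one probability per molecule.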

Tox21 Endpoints (12)

Category                 Endpoints
Nuclear Receptors (NR)   NR-AhR, NR-AR, NR-AR-LBD, NR-Aromatase, NR-ER, NR-ER-LBD, NR-PPAR-gamma
Stress Response (SR)     SR-ARE, SR-ATAD5, SR-HSE, SR-MMP, SR-p53

Architecture

SMILES Input
  │
  ├── RDKit Fingerprints ──────────────────────── 8,390 dims
  │   (ECFP4 + ECFP6 + FCFP4 + MACCS + desc)
  │
  └── ChemBERTa-2 [CLS] embedding ─────────────   768 dims
      (seyonec/ChemBERTa-zinc-base-v1)
  │
  └──────────── Concatenate ─────────────────── 9,158 dims
                    │
        ┌───────────┴───────────┐
        │  XGBoost × 30 seeds   │  per task
        │  LightGBM × 20 seeds  │
        └───────────┬───────────┘
                    │ Mean ensemble
                    ▼
             Toxicity probability

Training hyperparameters (XGBoost, per task):

  • n_estimators=500, max_depth=7, learning_rate=0.04
  • subsample=0.85, colsample_bytree=0.65
  • min_child_weight=3, reg_alpha=0.05, reg_lambda=0.3
  • scale_pos_weight=auto (computed per task to correct class imbalance; XGBoost itself expects a numeric value)
  • device=cuda (NVIDIA H200 GPU training)
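As a sketch of how the settings above could be assembled per task: the helper names `auto_scale_pos_weight` and `xgb_params` are hypothetical, and the negative-to-positive ratio is an assumption about what "auto" means here (it is the heuristic the XGBoost documentation suggests for imbalanced data). Missing labels are assumed to be encoded as NaN.

```python
import numpy as np

def auto_scale_pos_weight(y):
    """Ratio of negatives to positives among the labeled samples of one task."""
    y = np.asarray(y, dtype=float)
    labeled = y[~np.isnan(y)]  # assumption: missing labels are NaN
    n_pos = int((labeled == 1).sum())
    n_neg = int((labeled == 0).sum())
    if n_pos == 0:
        raise ValueError("no positive samples for this task")
    return n_neg / n_pos

def xgb_params(y, seed):
    """Hyperparameters listed above, plus a per-task class weight and seed."""
    return {
        "n_estimators": 500, "max_depth": 7, "learning_rate": 0.04,
        "subsample": 0.85, "colsample_bytree": 0.65,
        "min_child_weight": 3, "reg_alpha": 0.05, "reg_lambda": 0.3,
        "scale_pos_weight": auto_scale_pos_weight(y),
        "device": "cuda",
        "random_state": seed,
    }
```

The resulting dictionary could then be passed as `xgboost.XGBClassifier(**xgb_params(y_task, seed))`, once per seed.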

Training data: Tox21 challenge dataset, 12,060 molecules (DeepChem scaffold split)
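For readers unfamiliar with scaffold splitting, the idea can be sketched in plain Python. This is an illustration of the general greedy technique, not the exact split used for these models: `scaffold_split` is a hypothetical function, the 80/10/10 fractions are assumed defaults, and the scaffold keys (for example Bemis-Murcko scaffold SMILES) are assumed to be computed beforehand.

```python
from collections import defaultdict

def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):
    """Split sample indices so that no scaffold spans two subsets.

    `scaffolds` holds one scaffold key per molecule.
    """
    groups = defaultdict(list)
    for idx, key in enumerate(scaffolds):
        groups[key].append(idx)
    # Largest scaffold groups first, so rare scaffolds land in valid/test
    ordered = sorted(groups.values(), key=lambda g: (len(g), g[0]), reverse=True)

    n = len(scaffolds)
    train_cutoff = frac_train * n
    valid_cutoff = (frac_train + frac_valid) * n
    train, valid, test = [], [], []
    for group in ordered:
        if len(train) + len(group) > train_cutoff:
            if len(train) + len(valid) + len(group) > valid_cutoff:
                test.extend(group)
            else:
                valid.extend(group)
        else:
            train.extend(group)
    return train, valid, test
```

Because whole scaffold groups move together, the held-out molecules are structurally dissimilar from the training set, which usually gives a harder and more realistic evaluation than a random split.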

Files

File                Description                 Size
NR-AhR.pkl          NR-AhR ensemble model       ~14 MB
NR-AR.pkl           NR-AR ensemble model        ~14 MB
NR-AR-LBD.pkl       NR-AR-LBD ensemble model    ~13 MB
NR-Aromatase.pkl    NR-Aromatase ensemble model ~14 MB
NR-ER.pkl           NR-ER ensemble model        ~14 MB
NR-ER-LBD.pkl       NR-ER-LBD ensemble model    ~14 MB
NR-PPAR-gamma.pkl   NR-PPAR-gamma ensemble model ~13 MB
SR-ARE.pkl          SR-ARE ensemble model       ~14 MB
SR-ATAD5.pkl        SR-ATAD5 ensemble model     ~14 MB
SR-HSE.pkl          SR-HSE ensemble model       ~14 MB
SR-MMP.pkl          SR-MMP ensemble model       ~15 MB
SR-p53.pkl          SR-p53 ensemble model       ~14 MB
stacking_dnn.pkl    PyTorch DNN meta-learner    ~1 MB
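Since each task file is named after its endpoint, the twelve filenames can be derived from the endpoint list. The sketch below assumes that naming holds for every task, as the file listing suggests; `model_filename` and `download_all` are hypothetical helpers, while `hf_hub_download` is the real `huggingface_hub` function.

```python
NR_TASKS = ["NR-AhR", "NR-AR", "NR-AR-LBD", "NR-Aromatase",
            "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma"]
SR_TASKS = ["SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"]
TASKS = NR_TASKS + SR_TASKS

def model_filename(task):
    """Return the pickle filename for one Tox21 endpoint."""
    if task not in TASKS:
        raise KeyError(f"unknown Tox21 endpoint: {task}")
    return f"{task}.pkl"

def download_all(repo_id="VIDraft/tox21-v3-models"):
    """Download every task model and return {task: local_path}."""
    # Imported lazily so the helpers above work without huggingface_hub
    from huggingface_hub import hf_hub_download
    return {task: hf_hub_download(repo_id, model_filename(task)) for task in TASKS}
```

Calling `download_all()` fetches roughly 170 MB in total, going by the sizes listed above.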

Usage

Via Live API (Recommended)

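import requests

# Send a batch of SMILES strings to the hosted endpoint
response = requests.post(
    "https://vidraft-tox21-xgb.hf.space/predict",
    json={"smiles": ["CCO", "c1ccccc1"]},
    timeout=60,
)
response.raise_for_status()  # fail loudly on HTTP errors
predictions = response.json()["predictions"]
# {"CCO": {"NR-AhR": 0.032, "NR-AR": 0.018, ...}, ...}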
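One way to post-process the response is to list, per molecule, the endpoints whose probability crosses a cutoff. The sketch below assumes the response shape shown in the comment above; `flag_active` is a hypothetical helper and the 0.5 threshold is an arbitrary illustration, not a calibrated decision boundary.

```python
def flag_active(predictions, threshold=0.5):
    """Map each SMILES to its (endpoint, probability) pairs at or above threshold.

    `predictions` is assumed to look like {"CCO": {"NR-AhR": 0.032, ...}, ...}.
    Hits are sorted from highest to lowest probability.
    """
    flagged = {}
    for smiles, scores in predictions.items():
        hits = [(task, p) for task, p in scores.items() if p >= threshold]
        flagged[smiles] = sorted(hits, key=lambda tp: tp[1], reverse=True)
    return flagged
```

A molecule with no endpoint above the threshold maps to an empty list, which keeps the output shape uniform across the batch.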

Direct Model Loading

import pickle
from huggingface_hub import hf_hub_download

# Download and load a single task model.
# pickle.load executes code from the file, so only load files from a source you trust.
path = hf_hub_download("VIDraft/tox21-v3-models", "NR-AhR.pkl")
with open(path, "rb") as f:
    model = pickle.load(f)

# X must have shape (n_samples, 9158): fingerprints and descriptors (8390) + ChemBERTa [CLS] (768)
# model.predict_proba(X)[:, 1]  ->  toxicity probability in [0, 1]

Feature Extraction

import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, MACCSkeys
from transformers import AutoTokenizer, AutoModel

# Load ChemBERTa once at import time rather than on every call
CHEMBERTA = "seyonec/ChemBERTa-zinc-base-v1"
tok = AutoTokenizer.from_pretrained(CHEMBERTA)
m   = AutoModel.from_pretrained(CHEMBERTA).eval()

def get_features(smiles: str) -> np.ndarray:
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    # Fingerprints (Morgan radius 2 = ECFP4, radius 3 = ECFP6)
    ecfp4  = np.array(AllChem.GetMorganFingerprintAsBitVect(mol, 2, 2048), dtype=np.float32)
    ecfp6  = np.array(AllChem.GetMorganFingerprintAsBitVect(mol, 3, 2048), dtype=np.float32)
    fcfp4  = np.array(AllChem.GetMorganFingerprintAsBitVect(mol, 2, 2048, useFeatures=True), dtype=np.float32)
    maccs  = np.array(MACCSkeys.GenMACCSKeys(mol), dtype=np.float32)  # 167 bits
    # Second ECFP4 vector; as written it is identical to ecfp4
    ecfp4c = np.array(AllChem.GetMorganFingerprintAsBitVect(mol, 2, 2048), dtype=np.float32)
    # ChemBERTa [CLS] embedding (768 dims)
    enc = tok([smiles], return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        cls_emb = m(**enc).last_hidden_state[:, 0, :].cpu().numpy()[0]
    # NOTE: this returns 9,127 dims. The 31 RDKit physicochemical descriptors are
    # not computed here; they must be added, in the order used at training time,
    # to reach the 9,158 dims the models expect.
    return np.concatenate([ecfp4, ecfp6, fcfp4, maccs, ecfp4c, cls_emb])
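Because the models expect exactly 9,158 features, a cheap layout check before prediction can catch a missing block early. The sketch below encodes block sizes that are consistent with the stated totals (8,390 + 768 = 9,158). The position of the 31-descriptor block is a guess, since the card does not document it, and `FEATURE_BLOCKS` and `split_features` are hypothetical names.

```python
import numpy as np

# Block sizes consistent with the stated totals. ASSUMPTION: the descriptor
# block sits between the fingerprints and the embedding; the card does not say.
FEATURE_BLOCKS = [
    ("ecfp4", 2048), ("ecfp6", 2048), ("fcfp4", 2048), ("maccs", 167),
    ("ecfp4c", 2048), ("rdkit_desc", 31), ("chemberta_cls", 768),
]
N_FEATURES = sum(size for _, size in FEATURE_BLOCKS)  # 9158

def split_features(x):
    """Validate a single feature vector and return its named blocks."""
    x = np.asarray(x)
    if x.shape != (N_FEATURES,):
        raise ValueError(f"expected shape ({N_FEATURES},), got {x.shape}")
    blocks, start = {}, 0
    for name, size in FEATURE_BLOCKS:
        blocks[name] = x[start:start + size]
        start += size
    return blocks
```

A vector produced without the descriptor block would have 9,127 entries and be rejected by `split_features` before it reaches `predict_proba`.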

Performance

Model                           Mean AUC (12 tasks)   Year   Reference
DeepTox (1st place)             0.862                 2015   Mayr et al., Front. Environ. Sci. 2016
SNN                             0.856                 2017   -
VIDraft Tox21 v3 (this model)   leaderboard pending   2026   -

Results will be updated after JKU leaderboard submission.
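Until leaderboard numbers are available, the "Mean AUC (12 tasks)" metric can be estimated locally. The sketch below is a NumPy-only illustration with hypothetical function names. It assumes missing labels are NaN and are excluded per task; the leaderboard's exact evaluation protocol may differ.

```python
import numpy as np

def binary_auc(y_true, y_score):
    """ROC AUC as the fraction of positive/negative pairs ranked correctly.

    Ties count as half. Memory is O(n_pos * n_neg), fine for small test sets.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("AUC needs at least one positive and one negative")
    diff = pos[:, None] - neg[None, :]
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)

def mean_task_auc(Y, P):
    """Mean of per-task AUCs; Y and P have shape (n_samples, n_tasks)."""
    Y = np.asarray(Y, dtype=float)
    P = np.asarray(P, dtype=float)
    aucs = []
    for t in range(Y.shape[1]):
        mask = ~np.isnan(Y[:, t])  # skip molecules without a label for task t
        aucs.append(binary_auc(Y[mask, t], P[mask, t]))
    return float(np.mean(aucs))
```

For larger evaluations, `sklearn.metrics.roc_auc_score` computes the same per-task quantity more efficiently.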

Training Infrastructure

  • GPU: NVIDIA H200 × 8 (NIPA KT Cloud, 2026)
  • Python: 3.11
  • Key packages: XGBoost 2.1.1, LightGBM 4.3.0, scikit-learn 1.5.2, transformers 4.44.0
  • Training date: 2026-06-12

Citation

@misc{vidraft2026tox21,
  title  = {VIDraft Tox21 v3: ChemBERTa + XGBoost/LightGBM Ensemble for Molecular Toxicity Prediction},
  author = {VIDraft},
  year   = {2026},
  url    = {https://huggingface.co/VIDraft/tox21-v3-models}
}

License

Apache-2.0
