import os, re, time, traceback, warnings, requests
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg") # 👈 headless backend for HF Spaces
import matplotlib.pyplot as plt
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import mean_squared_error, r2_score, roc_auc_score, average_precision_score
from sklearn.metrics.pairwise import cosine_similarity
from xgboost import XGBRegressor
warnings.filterwarnings("ignore")
# -----------------------------
# Config
# -----------------------------
DATA_CANDIDATES = [
os.getenv("TEM1_DATA_PATH", "tem1_clean.csv"),
"data/tem1_clean.csv",
"/data/tem1_clean.csv",
]
UNIPROT_ID = "P62593" # TEM-1 beta-lactamase
PAFF_BINDER_THRESHOLD = 6.0 # pAff >= 6 corresponds to Kd <= 1 µM
# -----------------------------
# Small helpers
# -----------------------------
def pAff_to_nM(p):
# p = -log10(Kd M) -> Kd (nM) = 10**(9-p)
return 10.0 ** (9.0 - float(p))
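# Example: pAff 6.0 -> Kd = 10**(9 - 6.0) = 1000 nM = 1 µM; pAff 9.0 -> 1 nM.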
def fmt_conc(nM):
if nM < 1e-3: return f"{nM*1e3:.2f} pM"
if nM < 1: return f"{nM:.2f} nM"
if nM < 1e3: return f"{nM/1e3:.2f} µM"
return f"{nM/1e6:.2f} mM"
def conf_label(p):
if p >= 0.80: return "Likely"
if p >= 0.60: return "Uncertain"
return "Unlikely"
def conf_emoji(p):
if p >= 0.80: return "🟢"
if p >= 0.60: return "🟡"
return "🔴"
def _parse_smiles_block(text, limit=100):
items = [s.strip() for s in re.split(r'[\n,;]+', str(text or "")) if s.strip()]
return items[:limit]
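# Example: _parse_smiles_block("CCO, c1ccccc1; CC#N") -> ["CCO", "c1ccccc1", "CC#N"]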
# -----------------------------
# Load TEM-1 protein and embed
# -----------------------------
print("[boot] Fetching TEM-1 (UniProt %s)" % UNIPROT_ID)
resp = requests.get(f"https://rest.uniprot.org/uniprotkb/{UNIPROT_ID}.fasta", timeout=30)
resp.raise_for_status()
TEM1_SEQ = "".join(line.strip() for line in resp.text.splitlines() if not line.startswith(">"))
TEM1_SEQ = re.sub(r"[^ACDEFGHIKLMNPQRSTVWY]", "", TEM1_SEQ.upper())
print("[boot] TEM-1 length:", len(TEM1_SEQ))
device = "cuda" if torch.cuda.is_available() else "cpu"
print("[boot] Using device:", device)
print("[boot] Loading ESM-2 35M ...")
tok_p = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
mdl_p = AutoModel.from_pretrained("facebook/esm2_t12_35M_UR50D").to(device).eval()
print("[boot] Loading ChemBERTa ...")
tok_l = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
mdl_l = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(device).eval()
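# Embed the full TEM-1 sequence once: mean-pool the per-residue hidden states,
# dropping the BOS/EOS special tokens the ESM-2 tokenizer adds (hence [0, 1:-1, :]).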
with torch.inference_mode():
toks = tok_p(TEM1_SEQ, return_tensors="pt", add_special_tokens=True).to(device)
rep = mdl_p(**toks).last_hidden_state[0, 1:-1, :].mean(dim=0).cpu().numpy()
prot_vec = rep.astype(np.float32) # ~480-D
print("[boot] Protein embedding:", prot_vec.shape)
def _embed_ligands(smiles_list, batch_size=64, max_length=256):
vecs = []
for i in range(0, len(smiles_list), batch_size):
batch = smiles_list[i:i+batch_size]
enc = tok_l(batch, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device)
with torch.inference_mode():
out = mdl_l(**enc).last_hidden_state
cls = out[:, 0, :].detach().cpu().numpy().astype(np.float32)
vecs.append(cls)
return np.vstack(vecs) if vecs else np.zeros((0, mdl_l.config.hidden_size), dtype=np.float32)
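# Usage sketch (output width depends on the ligand model's hidden size):
#   vecs = _embed_ligands(["CCO", "c1ccccc1"])  # -> float32 array of shape (2, hidden_size)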
# -----------------------------
# Try to load training data
# -----------------------------
df = None
for p in DATA_CANDIDATES:
if os.path.exists(p):
        try:
            cand = pd.read_csv(p)
            if {"smiles", "pAff"}.issubset(cand.columns):
                df = cand
                print(f"[boot] Loaded dataset: {p} -> {cand.shape}")
                break
            print(f"[boot] Skipping {p}: missing 'smiles'/'pAff' columns")
        except Exception as e:
            print("[boot] Failed reading", p, e)
have_data = df is not None
# Placeholders initialized below
reg = None
clf = None
clf_cal = None
bins = None
q90_table = None
lig_tr = None
metrics_md = "*(Train a model or upload tem1_clean.csv to populate metrics here.)*"
def _train_models_from_df(df):
global reg, clf, clf_cal, bins, q90_table, lig_tr, metrics_md
df = df.dropna(subset=["smiles","pAff"]).reset_index(drop=True)
# Ligand embeddings
t0 = time.time()
lig_X = _embed_ligands(df["smiles"].tolist())
print(f"[train] Ligand embed {lig_X.shape} in {time.time()-t0:.1f}s")
# Joint features with protein
prot_X = np.repeat(prot_vec.reshape(1, -1), len(df), axis=0)
X = np.hstack([prot_X, lig_X]).astype(np.float32)
# Targets
y = df["pAff"].astype(np.float32).values
y_bin = (y >= PAFF_BINDER_THRESHOLD).astype(int)
    # Group-wise split over k-means clusters of the ligand embeddings
    # (a cheap stand-in for a scaffold split)
    k = min(max(5, min(50, len(df) // 50)), len(df))  # keep n_clusters <= n_samples
    km = KMeans(n_clusters=k, random_state=7, n_init=10)
groups = km.fit_predict(lig_X)
# custom split that holds out whole clusters
def groupwise_split(groups, test_frac=0.2, seed=7):
rng = np.random.default_rng(seed)
keys = list(set(groups))
rng.shuffle(keys)
N = len(groups)
target = int(N*test_frac)
taken, test_idx = 0, []
for key in keys:
idx = np.where(groups==key)[0].tolist()
test_idx.extend(idx)
taken += len(idx)
if taken >= target:
break
train_idx = sorted(set(range(N)) - set(test_idx))
return np.array(train_idx), np.array(test_idx)
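    # Holding out whole clusters keeps near-duplicate ligands from straddling
    # the train/test boundary, so the eval numbers below are less optimistic
    # than a random row-wise split would be.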
tr_idx, te_idx = groupwise_split(groups, test_frac=0.2, seed=7)
X_tr, X_te = X[tr_idx], X[te_idx]
y_tr, y_te = y[tr_idx], y[te_idx]
yb_tr, yb_te = y_bin[tr_idx], y_bin[te_idx]
# Heads
reg = XGBRegressor(
n_estimators=600, max_depth=6, learning_rate=0.05,
subsample=0.8, colsample_bytree=0.8, n_jobs=-1
).fit(X_tr, y_tr)
clf = LogisticRegression(max_iter=2000).fit(X_tr, yb_tr)
# Metrics
pred = reg.predict(X_te)
try:
rmse = mean_squared_error(y_te, pred, squared=False)
except TypeError:
rmse = mean_squared_error(y_te, pred) ** 0.5
r2 = r2_score(y_te, pred)
p_bin = clf.predict_proba(X_te)[:, 1]
roc = roc_auc_score(yb_te, p_bin)
pr = average_precision_score(yb_te, p_bin)
# conditional q90 by predicted bin
bins = np.linspace(float(pred.min()), float(pred.max()), 8)
bin_idx = np.digitize(pred, bins)
abs_err = np.abs(y_te - pred)
q90_table = np.zeros(len(bins)+1, dtype=np.float32)
for i in range(len(q90_table)):
vals = abs_err[bin_idx==i]
q90_table[i] = np.quantile(vals, 0.90) if len(vals)>0 else float(np.quantile(abs_err, 0.90))
# calibration & similarity
clf_cal = CalibratedClassifierCV(clf, method="isotonic", cv=3).fit(X_tr, yb_tr)
lig_tr = lig_X[tr_idx]
metrics_md = (
f"**Eval (held-out)** — RMSE: {rmse:.2f} pAff (≈×{10**rmse:.1f}), "
f"R²: {r2:.2f}, ROC-AUC: {roc:.2f}, PR-AUC: {pr:.2f}"
)
print("[train] done.")
def q90_for(p):
i = int(np.digitize([p], bins)[0]) if bins is not None else 0
i = max(0, min(i, len(q90_table)-1)) if q90_table is not None else 0
return q90_table[i] if q90_table is not None else 0.75 # conservative fallback
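# q90_for yields a conditional interval half-width: the 90th percentile of
# held-out absolute error within the same predicted-pAff bin, so intervals
# widen where the regressor was less accurate.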
# Try real training; otherwise install heuristic heads
if have_data:
_train_models_from_df(df)
else:
print("[boot] No dataset found — using heuristic heads (demo mode).")
    def _heuristic_sim(X):
        # X: [B, Dp+Dl]; take the ligand half and score it against the protein
        # vector over their overlapping dimensions (ESM-2 35M is 480-d,
        # ChemBERTa-base is 768-d, so both are truncated to the shorter length).
        Dp = prot_vec.shape[0]
        lig = X[:, Dp:]
        d = min(Dp, lig.shape[1])
        pv = prot_vec[:d]
        pv = pv / (np.linalg.norm(pv) + 1e-8)
        lig_d = lig[:, :d]
        lig_n = lig_d / (np.linalg.norm(lig_d, axis=1, keepdims=True) + 1e-8)
        return lig_n @ pv
    class HeuristicReg:
        def predict(self, X):
            sim = _heuristic_sim(X)
            return 5.5 + 2.0 * (sim.clip(-1, 1) + 1) / 2.0  # maps to [5.5, 7.5]
    class HeuristicClf:
        def predict_proba(self, X):
            sim = _heuristic_sim(X)
            z = (sim - sim.min()) / (sim.max() - sim.min() + 1e-8)
            p = 1 / (1 + np.exp(-4 * (z - 0.5)))
            return np.vstack([1 - p, p]).T
reg = HeuristicReg()
clf = HeuristicClf()
clf_cal = clf
bins = np.linspace(4.0, 8.0, 8)
q90_table = np.full(len(bins)+1, 0.75, dtype=np.float32)
lig_tr = np.zeros((1, mdl_l.config.hidden_size), dtype=np.float32)
metrics_md = "*(Demo mode — upload tem1_clean.csv to train real heads.)*"
# -----------------------------
# Prediction helpers
# -----------------------------
def train_similarity(smiles):
enc = tok_l([smiles], padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
with torch.inference_mode():
lig = mdl_l(**enc).last_hidden_state[:,0,:].cpu().numpy().astype(np.float32)
if lig_tr is None or lig_tr.shape[0]==0:
return 0.0
sim = cosine_similarity(lig, lig_tr)[0]
return float(sim.max())
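# Max cosine similarity to the training ligands is a cheap applicability-domain
# check: low values flag chemotypes the heads never saw, where predictions
# deserve extra skepticism.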
def _blank_fig(width=3.6, height=0.6):
fig = plt.figure(figsize=(width, height))
plt.axis("off")
return fig
def predict_smiles(smiles: str):
try:
# Empty input → friendly message + blank fig
        if not smiles or not smiles.strip():
            return "Please enter a SMILES string.", _blank_fig()
# 1) ligand embedding
enc = tok_l([smiles], padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
with torch.inference_mode():
out = mdl_l(**enc).last_hidden_state
lig = out[:, 0, :].detach().cpu().numpy().astype(np.float32)
# 2) joint feature
fx = np.hstack([prot_vec.reshape(1, -1), lig]).astype(np.float32)
# 3) regression + interval
p_aff = float(reg.predict(fx)[0])
q90 = q90_for(p_aff)
p_lo, p_hi = p_aff - q90, p_aff + q90
nM_center = pAff_to_nM(p_aff)
nM_hi, nM_lo = pAff_to_nM(p_hi), pAff_to_nM(p_lo)
        # 4) calibrated binder probability (fall back to the raw head if calibration is unavailable)
        try:
            p_cal = float(clf_cal.predict_proba(fx)[0, 1])
        except Exception:
            p_cal = float(clf.predict_proba(fx)[0, 1])
label = conf_label(p_cal); mark = conf_emoji(p_cal)
badge = " (≤1 µM)" if p_aff >= PAFF_BINDER_THRESHOLD else ""
# 5) similarity
sim = train_similarity(smiles)
sim_note = (f"\nNearest-set similarity: {sim:.2f}"
if sim >= 0.60 else
f"\n⚠️ Low similarity to training set: {sim:.2f}")
md = (
f"**Predicted pAff:** {p_aff:.2f} (−log10 M){badge} → **Kd ≈ {fmt_conc(nM_center)}**\n\n"
f"**90% interval:** {p_lo:.2f}{p_hi:.2f} (≈ {fmt_conc(nM_hi)} to {fmt_conc(nM_lo)})\n\n"
f"**Binder confidence:** {mark} {label} ({p_cal:.2f}){sim_note}\n"
)
# Mini bar to visualize P(binder)
fig = plt.figure(figsize=(3.6, 0.6))
ax = fig.add_axes([0.07, 0.35, 0.86, 0.35])
ax.barh([0], [p_cal], height=0.6)
ax.set_xlim(0, 1)
ax.set_yticks([])
ax.set_xticks([0, 0.5, 1.0])
ax.set_title("P(binder)")
for spine in ax.spines.values():
spine.set_visible(False)
return md, fig
except Exception as e:
# Show the error inline so we can debug without checking logs
tb = traceback.format_exc(limit=5)
msg = f"❌ **Error:** {e}\n\n```\n{tb}\n```"
return msg, _blank_fig()
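# Usage sketch (aspirin SMILES shown as an arbitrary example):
#   md, fig = predict_smiles("CC(=O)Oc1ccccc1C(=O)O")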
def batch_predict(smiles_text):
smi = _parse_smiles_block(smiles_text)
if not smi:
return [], np.array([]), np.array([])
lig = _embed_ligands(smi) # (L, Dl)
P = np.repeat(prot_vec.reshape(1, -1), len(smi), 0) # (L, Dp)
X = np.hstack([P, lig]).astype(np.float32) # (L, Dp+Dl)
    p_aff = reg.predict(X)
    try:
        p_bind = clf_cal.predict_proba(X)[:, 1]  # calibrated, matching the plot labels
    except Exception:
        p_bind = clf.predict_proba(X)[:, 1]
return smi, p_aff, p_bind
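# batch_predict returns parallel sequences: SMILES strings, predicted pAff
# values, and calibrated P(binder), consumed by the plotting helpers below.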
def plot_paff_bars(names, paff, paff_thr=PAFF_BINDER_THRESHOLD):
names = list(names); paff = np.array(paff, dtype=float)
fig, ax = plt.subplots(figsize=(max(6, len(names)*0.6), 3.2))
ax.bar(range(len(names)), paff)
ax.axhline(paff_thr, linestyle="--")
ax.set_xticks(range(len(names)))
ax.set_xticklabels([n[:16]+("…" if len(n)>16 else "") for n in names], rotation=45, ha="right")
ax.set_ylabel("Predicted pAff (−log10 M)"); ax.set_title("Batch predictions — pAff")
plt.tight_layout()
return fig
def plot_paff_vs_pbind(names, paff, pbind, hi=0.80, mid=0.60, paff_thr=PAFF_BINDER_THRESHOLD):
names = list(names); paff = np.array(paff, dtype=float); pbind = np.array(pbind, dtype=float)
fig, ax = plt.subplots(figsize=(5.8, 4.2))
ax.scatter(paff, pbind, s=36)
ax.axvline(paff_thr, linestyle="--"); ax.axhline(hi, linestyle="--"); ax.axhline(mid, linestyle="--")
top = np.argsort(-(paff + pbind))[:10]
for i in top:
lbl = names[i][:18] + ("…" if len(names[i]) > 18 else "")
ax.annotate(lbl, (paff[i], pbind[i]), xytext=(4, 4), textcoords="offset points")
ax.set_xlabel("Predicted pAff (−log10 M)"); ax.set_ylabel("Calibrated P(binder)")
ax.set_title("Batch predictions"); plt.tight_layout()
return fig
def heatmap_predict(smiles_block):
smi_list = _parse_smiles_block(smiles_block)
if not smi_list:
fig = plt.figure(figsize=(4, 2))
plt.axis("off")
plt.text(0.5, 0.5, "No SMILES provided", ha="center", va="center")
return fig
# Embed ligands
ligs = _embed_ligands(smi_list)
# Joint features (protein + ligands)
pv_rep = np.repeat(prot_vec.reshape(1, -1), len(smi_list), axis=0)
fx = np.hstack([pv_rep, ligs]).astype(np.float32)
# Predict pAff (single protein row)
p_affs = reg.predict(fx) # shape (L,)
M = p_affs.reshape(1, -1) # 1 x L
fig, ax = plt.subplots(figsize=(max(6, len(smi_list)*0.8), 2.8))
im = ax.imshow(M, aspect="auto")
ax.set_xticks(range(len(smi_list)))
ax.set_xticklabels([s[:14] + ("…" if len(s) > 14 else "") for s in smi_list],
rotation=45, ha="right")
ax.set_yticks([0]); ax.set_yticklabels(["TEM-1 (WT)"])
cbar = fig.colorbar(im, ax=ax); cbar.set_label("Predicted pAff")
# Mark predicted binders (>= threshold)
for j in range(M.shape[1]):
if M[0, j] >= PAFF_BINDER_THRESHOLD:
ax.text(j, 0, "★", ha="center", va="center", color="white", fontsize=12)
ax.set_title("Heatmap — predicted pAff (higher is better)")
plt.tight_layout()
return fig
# -----------------------------
# Gradio UI
# -----------------------------
with gr.Blocks(title="Antibiotic Resistance Target Finder — TEM-1") as demo:
gr.Markdown("""\
# Antibiotic Resistance Target Finder — TEM-1
**Goal:** Predict how tightly a small molecule binds **TEM-1 β-lactamase** variants.
**How to use (2 steps):**
1) Paste a **SMILES** string and click **Submit** to get a prediction.
2) (Optional) Paste multiple SMILES for batch plots and a heatmap.
*Protein embeddings:* ESM-2 (35M) · *Ligand embeddings:* ChemBERTa · *Models:* XGBoost + LogisticRegression
""")
with gr.Row():
smi_in = gr.Textbox(label="SMILES", placeholder="e.g., CC1=CC(=O)C=CC1=O", lines=1)
btn = gr.Button("Submit", variant="primary")
out_md = gr.Markdown()
out_plot = gr.Plot()
btn.click(fn=predict_smiles, inputs=smi_in, outputs=[out_md, out_plot])
gr.Markdown("""---
### Batch mode (paste 1–100 SMILES separated by newlines, commas, or semicolons)
""")
smi_batch = gr.Textbox(label="Batch SMILES", lines=6, placeholder="SMILES per line ...")
with gr.Row():
btn_bars = gr.Button("Bar chart (pAff)")
btn_scatter = gr.Button("Scatter (pAff vs P(binder))")
btn_heat = gr.Button("Heatmap")
plot1 = gr.Plot()
plot2 = gr.Plot()
plot3 = gr.Plot()
def _bars(smiblock):
names, paff, pbind = batch_predict(smiblock)
return plot_paff_bars(names, paff)
def _scatter(smiblock):
names, paff, pbind = batch_predict(smiblock)
return plot_paff_vs_pbind(names, paff, pbind)
def _heat(smiblock):
return heatmap_predict(smiblock)
btn_bars.click(_bars, inputs=smi_batch, outputs=plot1)
btn_scatter.click(_scatter, inputs=smi_batch, outputs=plot2)
btn_heat.click(_heat, inputs=smi_batch, outputs=plot3)
with gr.Accordion("Model card: assumptions, metrics & limits", open=False):
gr.Markdown("""\
**Compute footprint:** small (≤50M-parameter embedding models + lightweight heads). Runs on CPU in Spaces.
%s
**Assumptions / caveats**
- Trained on **TEM-1** datasets; predictions for very dissimilar chemotypes are less certain.
- Reported “confidence” is **calibrated** (isotonic regression, 3-fold CV on the training split); it is not a substitute for wet-lab validation.
- Use as a **ranking/triage** tool, not as a definitive activity claim.
**pAff** is −log10(Kd in molar). Bigger is better. Example: 1 µM → pAff=6; 100 nM → 7; 10 nM → 8.
""" % metrics_md)
demo.launch()