# tr2d2-dna/oracle.py
import io
import os
from typing import Callable, List, Union

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from scipy.linalg import sqrtm
from scipy.stats import pearsonr

import grelu.data.dataset
from grelu.lightning import LightningModel

import dataloader_gosai
base_path = "" # Fill in directory of the pretrained checkpoints, e.g., "...../data_and_model/"
def _load_patched_checkpoint(ckpt_path, device):
    """Load a LightningModel checkpoint, patching metadata fields that older
    checkpoints may be missing, then move the model to `device`."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    hp = ckpt.get("hyper_parameters", {})
    # Add safe defaults so loading succeeds even if these fields are absent.
    ckpt.setdefault("data_params", hp.get("data_params", {}))
    ckpt.setdefault("performance", {})
    if not ckpt["performance"]:
        ckpt["performance"] = {
            "best_step": ckpt.get("global_step", 0),
            "best_metric": None,
        }
    # Load the model from the patched in-memory checkpoint (no temp file needed).
    buffer = io.BytesIO()
    torch.save(ckpt, buffer)
    buffer.seek(0)  # reset buffer position to the beginning
    model_load = LightningModel.load_from_checkpoint(buffer, map_location="cpu")
    model_load.to(device)
    return model_load

def get_cal_atac_oracle(device=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
    ckpt_path = os.path.join(base_path, 'mdlm/gosai_data/binary_atac_cell_lines.ckpt')
    model_load = _load_patched_checkpoint(ckpt_path, device)
    model_load.train_params['logger'] = None
    return model_load
def get_gosai_oracle(mode='train', device=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
    if mode == 'train':
        ckpt_path = os.path.join(base_path, "mdlm/outputs_gosai/lightning_logs/reward_oracle_ft.ckpt")
    elif mode == 'eval':
        ckpt_path = os.path.join(base_path, "mdlm/outputs_gosai/lightning_logs/reward_oracle_eval.ckpt")
    else:
        raise ValueError(f"Unknown mode {mode!r}; expected 'train' or 'eval'")
    model_load = _load_patched_checkpoint(ckpt_path, device)
    model_load.train_params['logger'] = None
    return model_load
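# Example (sketch): loading the two oracles. Assumes `base_path` above has
# been filled in and the referenced checkpoints exist on disk.
#
#   gosai_oracle = get_gosai_oracle(mode='eval')  # expression reward oracle
#   atac_oracle = get_cal_atac_oracle()           # binary ATAC cell-line model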
def cal_gosai_pred(seqs, model=None, mode='eval'):
"""
seqs: list of sequences (detokenized ACGT...)
"""
if model is None:
model = get_gosai_oracle(mode=mode)
df_seqs = pd.DataFrame(seqs, columns=['seq'])
pred_dataset = grelu.data.dataset.DFSeqDataset(df_seqs)
preds = model.predict_on_dataset(pred_dataset, devices=[0])
return preds.squeeze() # numpy array with shape [n_seqs, 3]
def cal_gosai_pred_new(seqs, model=None, mode='eval'):
"""
seqs: list of sequences (detokenized ACGT...)
"""
if model is None:
model = get_gosai_oracle(mode=mode)
    model.eval()
    tokens = dataloader_gosai.batch_dna_tokenize(seqs)
    tokens = torch.tensor(tokens).long().to(model.device)
    onehot_tokens = F.one_hot(tokens, num_classes=4).float()
    # Model expects channels-first input: [n_seqs, 4, seq_len]
    preds = model(onehot_tokens.transpose(1, 2)).detach().cpu().numpy()
    return preds.squeeze()  # numpy array with shape [n_seqs, 3]
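# Example (sketch): scoring sequences with the oracle. The 200-bp length is
# illustrative; use the length the oracle was trained on.
#
#   oracle = get_gosai_oracle(mode='eval')
#   preds = cal_gosai_pred_new(['ACGT' * 50], model=oracle)  # shape [3] after squeeze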
def cal_atac_pred(seqs, model=None):
"""
seqs: list of sequences (detokenized ACGT...)
"""
if model is None:
model = LightningModel.load_from_checkpoint(os.path.join(base_path, 'mdlm/gosai_data/binary_atac_cell_lines.ckpt'), map_location='cuda')
df_seqs = pd.DataFrame(seqs, columns=['seq'])
pred_dataset = grelu.data.dataset.DFSeqDataset(df_seqs)
preds = model.predict_on_dataset(pred_dataset, devices=[0])
return preds.squeeze() # numpy array with shape [n_seqs, 7]
def cal_atac_pred_new(seqs, model=None):
"""
seqs: list of sequences (detokenized ACGT...)
"""
if model is None:
model = LightningModel.load_from_checkpoint(os.path.join(base_path, 'mdlm/gosai_data/binary_atac_cell_lines.ckpt'), map_location='cuda')
    model.eval()
    tokens = dataloader_gosai.batch_dna_tokenize(seqs)
    tokens = torch.tensor(tokens).long().to(model.device)
    onehot_tokens = F.one_hot(tokens, num_classes=4).float()
    # Model expects channels-first input: [n_seqs, 4, seq_len]
    preds = model(onehot_tokens.transpose(1, 2)).detach().cpu().numpy()
    return preds.squeeze()  # numpy array with shape [n_seqs, 7]
def count_kmers(seqs, k=3):
    """Count all length-k substrings across a list of sequences."""
    counts = {}
    for seq in seqs:
        for i in range(len(seq) - k + 1):
            subseq = seq[i:i + k]
            counts[subseq] = counts.get(subseq, 0) + 1
    return counts
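# Quick sanity check for count_kmers:
#   count_kmers(["ACGTA"], k=3) == {'ACG': 1, 'CGT': 1, 'GTA': 1}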
def subset_for_eval(n=5000, seed=0):
train_set = dataloader_gosai.get_datasets_gosai()
np.random.seed(seed)
torch.manual_seed(seed)
train_set_sp = torch.utils.data.Subset(train_set, np.random.choice(len(train_set), n, replace=False))
return train_set_sp
def subset_eval_groundtruth(sets_sp):
train_set_sp = sets_sp
train_set_sp_clss = train_set_sp.dataset.clss[train_set_sp.indices]
return train_set_sp_clss
def subset_eval_preds(sets_sp, oracle_model=None):
train_set_sp = sets_sp
train_preds = cal_gosai_pred(
dataloader_gosai.batch_dna_detokenize(train_set_sp.dataset.seqs[train_set_sp.indices].numpy()), oracle_model)
return train_preds
def subset_eval_kmers(sets_sp, k=3):
train_set_sp = sets_sp
train_seqs = dataloader_gosai.batch_dna_detokenize(train_set_sp.dataset.seqs[train_set_sp.indices].numpy())
train_kmers = count_kmers(train_seqs, k)
return train_kmers
def subset_eval_embs(sets_sp, oracle_model=None):
train_set_sp = sets_sp
train_sp_emb = cal_gosai_emb(
dataloader_gosai.batch_dna_detokenize(train_set_sp.dataset.seqs[train_set_sp.indices].numpy()), oracle_model)
return train_sp_emb
def cal_emb_pca(sets_sp, n_components=50, oracle_model=None):
train_set_sp = sets_sp
train_sp_emb = cal_gosai_emb(
dataloader_gosai.batch_dna_detokenize(train_set_sp.dataset.seqs[train_set_sp.indices].numpy()), oracle_model)
from sklearn.decomposition import PCA
pca = PCA(n_components=n_components)
pca.fit(train_sp_emb.reshape(train_sp_emb.shape[0], -1))
return pca
def subset_eval_embs_pca(sets_sp, pca, oracle_model=None):
train_sp_emb = subset_eval_embs(sets_sp, oracle_model)
train_sp_emb_pca = pca.transform(train_sp_emb.reshape(train_sp_emb.shape[0], -1))
return train_sp_emb_pca
# https://github.com/HannesStark/dirichlet-flow-matching/blob/main/utils/flow_utils.py
def get_wasserstein_dist(embeds1, embeds2):
    """Frechet (2-Wasserstein) distance between Gaussian fits of two embedding sets."""
    if np.isnan(embeds2).any() or np.isnan(embeds1).any() or len(embeds1) == 0 or len(embeds2) == 0:
        return float('nan')
    mu1, sigma1 = embeds1.mean(axis=0), np.cov(embeds1, rowvar=False)
    mu2, sigma2 = embeds2.mean(axis=0), np.cov(embeds2, rowvar=False)
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    dist = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return dist
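# Example (sketch): sanity check on synthetic embeddings; the random inputs
# below are illustrative placeholders, not real model embeddings.
#
#   rng = np.random.default_rng(0)
#   a = rng.normal(size=(500, 32))
#   b = rng.normal(loc=0.5, size=(500, 32))
#   get_wasserstein_dist(a, b)  # grows as the two distributions separate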
def embed_on_dataset(
    model,
    dataset: torch.utils.data.Dataset,
    devices: Union[str, int, List[int]] = "cpu",
    num_workers: int = 1,
    batch_size: int = 256,
):
    """
    Return embeddings for a dataset of sequences.

    Args:
        dataset: Dataset object that yields one-hot encoded sequences
        devices: Device IDs to use
        num_workers: Number of workers for data loader
        batch_size: Batch size for data loader

    Returns:
        Numpy array of shape (B, T, L) containing embeddings.
    """
torch.set_float32_matmul_precision("medium")
# Make dataloader
dataloader = model.make_predict_loader(
dataset, num_workers=num_workers, batch_size=batch_size
)
# Get device
orig_device = model.device
device = model.parse_devices(devices)[1]
if isinstance(device, list):
device = device[0]
model.to(device)
# Get embeddings
preds = []
model.model = model.model.eval()
for batch in iter(dataloader):
batch = batch.to(device)
preds.append(model.model.embedding(batch).detach().cpu())
# Return to original device
model.to(orig_device)
return torch.vstack(preds).numpy()
def cal_gosai_emb(seqs, model=None):
"""
seqs: list of sequences (detokenized ACGT...)
"""
if model is None:
model = get_gosai_oracle()
df_seqs = pd.DataFrame(seqs, columns=['seq'])
pred_dataset = grelu.data.dataset.DFSeqDataset(df_seqs)
embs = embed_on_dataset(model, pred_dataset, devices=[0])
return embs # numpy array with shape [n_seqs, 3072, 2]
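# Example (sketch): distributional comparison in PCA-reduced embedding space,
# combining the helpers above. `gen_seqs` is a hypothetical list of generated
# ACGT strings.
#
#   oracle = get_gosai_oracle(mode='eval')
#   train_sp = subset_for_eval(n=5000)
#   pca = cal_emb_pca(train_sp, n_components=50, oracle_model=oracle)
#   ref_pca = subset_eval_embs_pca(train_sp, pca, oracle_model=oracle)
#   gen_emb = cal_gosai_emb(gen_seqs, oracle)
#   gen_pca = pca.transform(gen_emb.reshape(gen_emb.shape[0], -1))
#   dist = get_wasserstein_dist(gen_pca, ref_pca)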
def cal_highexp_kmers(k=3, return_clss=False):
train_set = dataloader_gosai.get_datasets_gosai()
exp_threshold = np.quantile(train_set.clss[:, 0].numpy(), 0.99) # 4.56
highexp_indices = [i for i, data in enumerate(train_set) if data['clss'][0] > exp_threshold]
highexp_set_sp = torch.utils.data.Subset(train_set, highexp_indices)
highexp_seqs = dataloader_gosai.batch_dna_detokenize(highexp_set_sp.dataset.seqs[highexp_set_sp.indices].numpy())
highexp_kmers_99 = count_kmers(highexp_seqs, k=k)
n_highexp_kmers_99 = len(highexp_indices)
exp_threshold = np.quantile(train_set.clss[:, 0].numpy(), 0.999) # 6.27
highexp_indices = [i for i, data in enumerate(train_set) if data['clss'][0] > exp_threshold]
highexp_set_sp = torch.utils.data.Subset(train_set, highexp_indices)
highexp_seqs = dataloader_gosai.batch_dna_detokenize(highexp_set_sp.dataset.seqs[highexp_set_sp.indices].numpy())
highexp_kmers_999 = count_kmers(highexp_seqs, k=k)
n_highexp_kmers_999 = len(highexp_indices)
if return_clss:
highexp_set_sp_clss_999 = highexp_set_sp.dataset.clss[highexp_set_sp.indices]
highexp_preds_999 = cal_gosai_pred_new(
dataloader_gosai.batch_dna_detokenize(highexp_set_sp.dataset.seqs[highexp_set_sp.indices].numpy()))
return highexp_kmers_99, n_highexp_kmers_99, highexp_kmers_999, n_highexp_kmers_999, highexp_set_sp_clss_999, highexp_preds_999, highexp_seqs
return highexp_kmers_99, n_highexp_kmers_99, highexp_kmers_999, n_highexp_kmers_999
def cal_kmer_corr(model, highexp_kmers, n_highexp_kmers, n_sample=128):
    model.eval()
    all_detokenized_samples = []
    for _ in range(10):
        samples = model._sample(eval_sp_size=n_sample).detach().cpu().numpy()
        detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples)
        all_detokenized_samples.extend(detokenized_samples)
    generated_kmer = count_kmers(all_detokenized_samples)
    kmer_set = set(highexp_kmers.keys()) | set(generated_kmer.keys())
    counts = np.zeros((len(kmer_set), 2))
    for i, kmer in enumerate(kmer_set):
        if kmer in highexp_kmers:
            # Rescale reference counts to a comparable total; the scale factor
            # does not affect the Pearson correlation below.
            counts[i][1] = highexp_kmers[kmer] * len(generated_kmer) / n_highexp_kmers
        if kmer in generated_kmer:
            counts[i][0] = generated_kmer[kmer]
    corr = pearsonr(counts[:, 0], counts[:, 1])[0]
    return corr
def cal_avg_likelihood(model, old_model, n_sample=128):
model.eval()
old_model.eval()
all_raw_samples = []
for _ in range(10):
samples = model._sample(eval_sp_size=n_sample)
all_raw_samples.append(samples)
all_raw_samples = torch.concat(all_raw_samples)
avg_likelihood = old_model._forward_pass_diffusion(all_raw_samples).sum(-1).mean().item()
return avg_likelihood
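# Example (sketch): end-to-end evaluation of a fine-tuned diffusion model.
# `finetuned_model` and `pretrained_model` are hypothetical MDLM-style models
# exposing the `_sample` and `_forward_pass_diffusion` methods used above.
#
#   _, _, highexp_kmers_999, n_999 = cal_highexp_kmers(k=3)
#   corr = cal_kmer_corr(finetuned_model, highexp_kmers_999, n_999)
#   avg_ll = cal_avg_likelihood(finetuned_model, pretrained_model)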