import io
import os
import tempfile
from typing import Callable, List, Union

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from scipy.linalg import sqrtm
from scipy.stats import pearsonr

import grelu
import grelu.data.dataset
import grelu.data.preprocess
from grelu.lightning import LightningModel

import dataloader_gosai

# Root directory containing the mdlm checkpoints and gosai data.
base_path = ""


def get_cal_atac_oracle(device=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
    ckpt_path = os.path.join(base_path, 'mdlm/gosai_data/binary_atac_cell_lines.ckpt')

    ckpt = torch.load(ckpt_path, map_location="cpu")

    # Backfill entries that newer grelu LightningModel checkpoints carry but
    # older ones may lack, so load_from_checkpoint does not fail on them.
    hp = ckpt.get("hyper_parameters", {})
    ckpt.setdefault("data_params", hp.get("data_params", {}))
    ckpt.setdefault("performance", {})
    if not ckpt["performance"]:
        ckpt["performance"] = {
            "best_step": ckpt.get("global_step", 0),
            "best_metric": None,
        }

    # Round-trip the patched checkpoint through an in-memory buffer so we can
    # reuse load_from_checkpoint without writing a temporary file.
    buffer = io.BytesIO()
    torch.save(ckpt, buffer)
    buffer.seek(0)

    model_load = LightningModel.load_from_checkpoint(buffer, map_location="cpu")
    model_load.to(device)

    model_load.train_params['logger'] = None

    return model_load


def get_gosai_oracle(mode='train', device=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
    if mode == 'train':
        ckpt_path = os.path.join(base_path, "mdlm/outputs_gosai/lightning_logs/reward_oracle_ft.ckpt")
    elif mode == 'eval':
        ckpt_path = os.path.join(base_path, "mdlm/outputs_gosai/lightning_logs/reward_oracle_eval.ckpt")
    else:
        raise ValueError(f"unknown mode {mode!r}; expected 'train' or 'eval'")

    ckpt = torch.load(ckpt_path, map_location="cpu")

    # Same legacy-checkpoint patching as in get_cal_atac_oracle.
    hp = ckpt.get("hyper_parameters", {})
    ckpt.setdefault("data_params", hp.get("data_params", {}))
    ckpt.setdefault("performance", {})
    if not ckpt["performance"]:
        ckpt["performance"] = {
            "best_step": ckpt.get("global_step", 0),
            "best_metric": None,
        }

    buffer = io.BytesIO()
    torch.save(ckpt, buffer)
    buffer.seek(0)

    model_load = LightningModel.load_from_checkpoint(buffer, map_location="cpu")
    model_load.to(device)

    model_load.train_params['logger'] = None

    return model_load
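
# Illustrative usage (a sketch; assumes the checkpoints above exist under
# base_path and that seqs have the length the oracle was trained on):
#     oracle = get_gosai_oracle(mode='eval')
#     preds = cal_gosai_pred_new(seqs, model=oracle)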


def cal_gosai_pred(seqs, model=None, mode='eval'):
    """
    seqs: list of sequences (detokenized ACGT...)
    """
    if model is None:
        model = get_gosai_oracle(mode=mode)
    df_seqs = pd.DataFrame(seqs, columns=['seq'])
    pred_dataset = grelu.data.dataset.DFSeqDataset(df_seqs)
    preds = model.predict_on_dataset(pred_dataset, devices=[0])
    return preds.squeeze()


def cal_gosai_pred_new(seqs, model=None, mode='eval'):
    """
    seqs: list of sequences (detokenized ACGT...)
    """
    if model is None:
        model = get_gosai_oracle(mode=mode)
    model.eval()
    tokens = dataloader_gosai.batch_dna_tokenize(seqs)
    tokens = torch.tensor(tokens).long().to(model.device)
    # One-hot encode, then transpose to the (batch, 4, length) channel layout
    # the oracle expects.
    onehot_tokens = F.one_hot(tokens, num_classes=4).float()
    preds = model(onehot_tokens.transpose(1, 2)).detach().cpu().numpy()
    return preds.squeeze()
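
# Note: cal_gosai_pred routes through grelu's Lightning prediction loop, while
# cal_gosai_pred_new tokenizes in-process and calls the model directly; the two
# should agree up to batching/precision, with the "_new" variant avoiding the
# overhead of spawning a prediction job.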


def cal_atac_pred(seqs, model=None):
    """
    seqs: list of sequences (detokenized ACGT...)
    """
    if model is None:
        model = LightningModel.load_from_checkpoint(
            os.path.join(base_path, 'mdlm/gosai_data/binary_atac_cell_lines.ckpt'),
            map_location='cuda',
        )
    df_seqs = pd.DataFrame(seqs, columns=['seq'])
    pred_dataset = grelu.data.dataset.DFSeqDataset(df_seqs)
    preds = model.predict_on_dataset(pred_dataset, devices=[0])
    return preds.squeeze()


def cal_atac_pred_new(seqs, model=None):
    """
    seqs: list of sequences (detokenized ACGT...)
    """
    if model is None:
        model = LightningModel.load_from_checkpoint(
            os.path.join(base_path, 'mdlm/gosai_data/binary_atac_cell_lines.ckpt'),
            map_location='cuda',
        )
    model.eval()
    tokens = dataloader_gosai.batch_dna_tokenize(seqs)
    tokens = torch.tensor(tokens).long().to(model.device)
    onehot_tokens = F.one_hot(tokens, num_classes=4).float()
    preds = model(onehot_tokens.transpose(1, 2)).detach().cpu().numpy()
    return preds.squeeze()


def count_kmers(seqs, k=3):
    """Count all overlapping k-mers across a list of sequences."""
    counts = {}
    for seq in seqs:
        for i in range(len(seq) - k + 1):
            subseq = seq[i:i + k]
            counts[subseq] = counts.get(subseq, 0) + 1
    return counts
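
# Example: overlapping 3-mers over two toy sequences.
#     count_kmers(["ACGT", "ACGA"], k=3)
#     -> {'ACG': 2, 'CGT': 1, 'CGA': 1}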


def subset_for_eval(n=5000, seed=0):
    train_set = dataloader_gosai.get_datasets_gosai()
    np.random.seed(seed)
    torch.manual_seed(seed)
    indices = np.random.choice(len(train_set), n, replace=False)
    return torch.utils.data.Subset(train_set, indices)


def subset_eval_groundtruth(sets_sp):
    return sets_sp.dataset.clss[sets_sp.indices]


def subset_eval_preds(sets_sp, oracle_model=None):
    seqs = dataloader_gosai.batch_dna_detokenize(sets_sp.dataset.seqs[sets_sp.indices].numpy())
    return cal_gosai_pred(seqs, oracle_model)


def subset_eval_kmers(sets_sp, k=3):
    seqs = dataloader_gosai.batch_dna_detokenize(sets_sp.dataset.seqs[sets_sp.indices].numpy())
    return count_kmers(seqs, k)


def subset_eval_embs(sets_sp, oracle_model=None):
    seqs = dataloader_gosai.batch_dna_detokenize(sets_sp.dataset.seqs[sets_sp.indices].numpy())
    return cal_gosai_emb(seqs, oracle_model)


def cal_emb_pca(sets_sp, n_components=50, oracle_model=None):
    # Local import keeps sklearn an optional dependency for this module.
    from sklearn.decomposition import PCA

    train_sp_emb = subset_eval_embs(sets_sp, oracle_model)
    pca = PCA(n_components=n_components)
    pca.fit(train_sp_emb.reshape(train_sp_emb.shape[0], -1))
    return pca


def subset_eval_embs_pca(sets_sp, pca, oracle_model=None):
    train_sp_emb = subset_eval_embs(sets_sp, oracle_model)
    return pca.transform(train_sp_emb.reshape(train_sp_emb.shape[0], -1))
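
# How the subset helpers compose in an evaluation pipeline (a sketch; assumes
# the gosai dataset and oracle checkpoint are available):
#     train_sp = subset_for_eval(n=5000)
#     pca = cal_emb_pca(train_sp, n_components=50)
#     ref_pca = subset_eval_embs_pca(train_sp, pca)
#     # ...project generated-sequence embeddings with the same pca, then
#     # compare the two sets with get_wasserstein_dist (below).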


def get_wasserstein_dist(embeds1, embeds2):
    """Frechet (Gaussian 2-Wasserstein) distance between two embedding sets."""
    if np.isnan(embeds2).any() or np.isnan(embeds1).any() or len(embeds1) == 0 or len(embeds2) == 0:
        return float('nan')
    mu1, sigma1 = embeds1.mean(axis=0), np.cov(embeds1, rowvar=False)
    mu2, sigma2 = embeds2.mean(axis=0), np.cov(embeds2, rowvar=False)
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        # sqrtm can pick up a small imaginary component from numerical error.
        covmean = covmean.real
    dist = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return dist
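
# The quantity above is the squared Frechet distance
#     d^2 = ||mu1 - mu2||^2 + Tr(Sigma1 + Sigma2 - 2 (Sigma1 Sigma2)^(1/2)),
# the same statistic FID uses. Sanity check (runnable standalone):
#     rng = np.random.default_rng(0)
#     x = rng.normal(size=(1000, 8))
#     assert abs(get_wasserstein_dist(x, x.copy())) < 1e-6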


def embed_on_dataset(
    model,
    dataset: Callable,
    devices: Union[str, int, List[int]] = "cpu",
    num_workers: int = 1,
    batch_size: int = 256,
):
    """
    Return embeddings for a dataset of sequences.

    Args:
        model: LightningModel whose backbone provides the embeddings
        dataset: Dataset object that yields one-hot encoded sequences
        devices: Device IDs to use
        num_workers: Number of workers for data loader
        batch_size: Batch size for data loader

    Returns:
        Numpy array of shape (B, T, L) containing embeddings.
    """
    torch.set_float32_matmul_precision("medium")

    dataloader = model.make_predict_loader(
        dataset, num_workers=num_workers, batch_size=batch_size
    )

    # Move the model to the requested device, remembering the original one.
    orig_device = model.device
    device = model.parse_devices(devices)[1]
    if isinstance(device, list):
        device = device[0]
    model.to(device)

    # Collect per-batch embeddings on the CPU.
    preds = []
    model.model = model.model.eval()
    for batch in iter(dataloader):
        batch = batch.to(device)
        preds.append(model.model.embedding(batch).detach().cpu())

    model.to(orig_device)
    return torch.vstack(preds).numpy()


def cal_gosai_emb(seqs, model=None):
    """
    seqs: list of sequences (detokenized ACGT...)
    """
    if model is None:
        model = get_gosai_oracle()
    df_seqs = pd.DataFrame(seqs, columns=['seq'])
    pred_dataset = grelu.data.dataset.DFSeqDataset(df_seqs)
    embs = embed_on_dataset(model, pred_dataset, devices=[0])
    return embs


def cal_highexp_kmers(k=3, return_clss=False):
    train_set = dataloader_gosai.get_datasets_gosai()

    def top_quantile_subset(q):
        # Select training sequences whose first label exceeds the q-quantile.
        threshold = np.quantile(train_set.clss[:, 0].numpy(), q)
        indices = [i for i, data in enumerate(train_set) if data['clss'][0] > threshold]
        subset = torch.utils.data.Subset(train_set, indices)
        seqs = dataloader_gosai.batch_dna_detokenize(subset.dataset.seqs[subset.indices].numpy())
        return subset, seqs, count_kmers(seqs, k=k), len(indices)

    _, _, highexp_kmers_99, n_highexp_kmers_99 = top_quantile_subset(0.99)
    highexp_set_sp, highexp_seqs, highexp_kmers_999, n_highexp_kmers_999 = top_quantile_subset(0.999)

    if return_clss:
        highexp_set_sp_clss_999 = highexp_set_sp.dataset.clss[highexp_set_sp.indices]
        highexp_preds_999 = cal_gosai_pred_new(highexp_seqs)
        return (highexp_kmers_99, n_highexp_kmers_99, highexp_kmers_999, n_highexp_kmers_999,
                highexp_set_sp_clss_999, highexp_preds_999, highexp_seqs)

    return highexp_kmers_99, n_highexp_kmers_99, highexp_kmers_999, n_highexp_kmers_999


def cal_kmer_corr(model, highexp_kmers, n_highexp_kmers, n_sample=128):
    model.eval()
    all_detokenized_samples = []
    for _ in range(10):
        samples = model._sample(eval_sp_size=n_sample).detach().cpu().numpy()
        detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples)
        all_detokenized_samples.extend(detokenized_samples)
    generated_kmer = count_kmers(all_detokenized_samples)

    kmer_set = set(highexp_kmers.keys()) | set(generated_kmer.keys())
    counts = np.zeros((len(kmer_set), 2))
    for i, kmer in enumerate(kmer_set):
        if kmer in highexp_kmers:
            # Rescale the reference counts from n_highexp_kmers sequences to
            # the number of generated sequences so the two columns are on the
            # same scale (Pearson correlation is itself scale-invariant).
            counts[i][1] = highexp_kmers[kmer] * len(all_detokenized_samples) / n_highexp_kmers
        if kmer in generated_kmer:
            counts[i][0] = generated_kmer[kmer]

    corr = pearsonr(counts[:, 0], counts[:, 1])[0]
    return corr


def cal_avg_likelihood(model, old_model, n_sample=128):
    model.eval()
    old_model.eval()
    all_raw_samples = []
    for _ in range(10):
        samples = model._sample(eval_sp_size=n_sample)
        all_raw_samples.append(samples)
    all_raw_samples = torch.cat(all_raw_samples)
    # Evaluate the reference model's diffusion objective on the samples,
    # summed over sequence positions and averaged over samples.
    avg_likelihood = old_model._forward_pass_diffusion(all_raw_samples).sum(-1).mean().item()
    return avg_likelihood
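

if __name__ == "__main__":
    # Smoke test (a sketch, not part of the evaluation pipeline): summarize the
    # high-expression k-mer reference sets. Assumes the gosai dataset is
    # reachable through dataloader_gosai.
    kmers_99, n_99, kmers_999, n_999 = cal_highexp_kmers(k=3)
    print(f"top 1% subset: {n_99} sequences, {len(kmers_99)} distinct 3-mers")
    print(f"top 0.1% subset: {n_999} sequences, {len(kmers_999)} distinct 3-mers")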