|
|
import numpy as np |
|
|
import torch |
|
|
from scipy.stats import pearsonr |
|
|
import dataloader_gosai |
|
|
import oracle |
|
|
|
|
|
|
|
|
def compare_kmer(kmer1, kmer2, n_sp1, n_sp2):
    """Return the Pearson correlation between two k-mer count tables.

    Counts from ``kmer1`` are rescaled by ``n_sp2 / n_sp1`` before the
    comparison so that both tables refer to the same number of samples.
    K-mers that appear in only one table count as zero in the other.

    Args:
        kmer1: Mapping from k-mer to count for the first collection.
        kmer2: Mapping from k-mer to count for the second collection.
        n_sp1: Size of the first collection (presumably its number of
            sequences -- confirm against the caller).
        n_sp2: Size of the second collection (same caveat).

    Returns:
        The first element of ``scipy.stats.pearsonr`` (the correlation
        coefficient) computed over the union of k-mers.
    """
    all_kmers = set(kmer1.keys()) | set(kmer2.keys())

    # Column 0 holds the kmer2 counts, column 1 the rescaled kmer1 counts.
    table = np.zeros((len(all_kmers), 2))
    for row, key in zip(table, all_kmers):
        # ``row`` is a view into ``table``, so the writes land in place.
        if key in kmer2:
            row[0] = kmer2[key]
        if key in kmer1:
            row[1] = kmer1[key] * n_sp2 / n_sp1

    result = pearsonr(table[:, 0], table[:, 1])
    return result[0]
|
|
|
|
|
|
|
|
def get_eval_matrics(samples, ref_model, gosai_oracle, cal_atac_pred_new_mdl, highexp_kmers_999, n_highexp_kmers_999):
    """Compute summary evaluation metrics for a batch of generated samples.

    Args:
        samples: Token tensor, [B, 200]. It is detached, moved to CPU and
            detokenized via ``dataloader_gosai.batch_dna_detokenize``.
        ref_model: Object exposing ``get_likelihood(samples, num_steps,
            n_samples)``; the median of its output is reported.
        gosai_oracle: Model handed to ``oracle.cal_gosai_pred_new``.
        cal_atac_pred_new_mdl: Model handed to ``oracle.cal_atac_pred_new``.
        highexp_kmers_999: Reference k-mer count mapping passed to
            ``compare_kmer`` as the first table.
        n_highexp_kmers_999: Size of the reference collection used to
            rescale ``highexp_kmers_999`` inside ``compare_kmer``.

    Returns:
        dict with the keys ``'[log-lik-med]'``, ``'[pred-activity-med]'``,
        ``'[atac-acc%]'`` and ``'[3-mer-corr]'``, each mapped to a Python
        scalar.
    """
    token_array = samples.detach().cpu().numpy()
    sequences = dataloader_gosai.batch_dna_detokenize(token_array)

    # Median likelihood of the samples under the reference model.
    log_lik = ref_model.get_likelihood(samples, num_steps=128, n_samples=1)
    median_log_lik = torch.median(log_lik).item()

    # Median of output column 0 of the activity oracle.
    activity = oracle.cal_gosai_pred_new(sequences, gosai_oracle, mode='eval')[:, 0]
    median_activity = np.median(activity).item()

    # Percentage of samples whose ATAC output (column 1) exceeds 0.5.
    atac_scores = oracle.cal_atac_pred_new(sequences, cal_atac_pred_new_mdl)[:, 1]
    n_above_threshold = (atac_scores > 0.5).sum().item()
    atac_percent = n_above_threshold / len(samples) * 100

    # Correlation of the samples' k-mer counts with the reference counts.
    sample_kmers = oracle.count_kmers(sequences)
    kmer_corr = compare_kmer(
        highexp_kmers_999,
        sample_kmers,
        n_highexp_kmers_999,
        len(sequences),
    ).item()

    return {
        '[log-lik-med]': median_log_lik,
        '[pred-activity-med]': median_activity,
        '[atac-acc%]': atac_percent,
        '[3-mer-corr]': kmer_corr,
    }