TR2-D2 / tr2d2-dna /eval_utils.py
zyc4975matholic
Include DNA training code
303c2e0
import numpy as np
import torch
from scipy.stats import pearsonr
import dataloader_gosai
import oracle
def compare_kmer(kmer1, kmer2, n_sp1, n_sp2):
    """Return the Pearson correlation between two k-mer count tables.

    The correlation is taken over the union of k-mers found in either table.
    A k-mer absent from a table counts as 0. Counts from ``kmer1`` are
    rescaled by ``n_sp2 / n_sp1`` before correlating.

    Args:
        kmer1: mapping k-mer -> count (the reference table at the call site
            in this file).
        kmer2: mapping k-mer -> count to compare against ``kmer1``.
        n_sp1: normalizer for ``kmer1`` (presumably its number of sequences;
            confirm with callers).
        n_sp2: normalizer for ``kmer2`` (the number of samples at the call
            site in this file).

    Returns:
        The correlation coefficient, i.e. element 0 of
        ``scipy.stats.pearsonr``.
    """
    union = set(kmer1.keys()) | set(kmer2.keys())
    # Each row is (count in kmer2, rescaled count in kmer1).
    rows = [
        (
            kmer2[k] if k in kmer2 else 0,
            kmer1[k] * n_sp2 / n_sp1 if k in kmer1 else 0,
        )
        for k in union
    ]
    # reshape keeps the (0, 2) shape when the union is empty.
    table = np.array(rows, dtype=float).reshape(-1, 2)
    observed, reference = table[:, 0], table[:, 1]
    return pearsonr(observed, reference)[0]
def get_eval_matrics(samples, ref_model, gosai_oracle, cal_atac_pred_new_mdl,
                     highexp_kmers_999, n_highexp_kmers_999):
    """Compute summary evaluation metrics for a batch of generated sequences.

    Args:
        samples: token tensor of shape [B, 200].
        ref_model: model whose ``get_likelihood(samples, num_steps=128,
            n_samples=1)`` yields one log-likelihood per sample.
        gosai_oracle: model forwarded to ``oracle.cal_gosai_pred_new``.
        cal_atac_pred_new_mdl: model forwarded to ``oracle.cal_atac_pred_new``.
        highexp_kmers_999: reference k-mer counts to correlate against.
        n_highexp_kmers_999: normalizer for ``highexp_kmers_999``, passed to
            ``compare_kmer`` as ``n_sp1``.

    Returns:
        dict with keys ``'[log-lik-med]'``, ``'[pred-activity-med]'``,
        ``'[atac-acc%]'`` and ``'[3-mer-corr]'``, all Python floats.
    """
    info = {}
    # [B] strings, each of length 200.
    detokenized_samples = dataloader_gosai.batch_dna_detokenize(
        samples.detach().cpu().numpy())

    ref_log_lik = ref_model.get_likelihood(samples, num_steps=128, n_samples=1)  # [B]
    # torch.median returns the *lower* of the two middle values for even-length
    # input, while np.median (used for the activity metric below) averages them.
    # Use np.median for both so every "-med" metric means the same thing.
    # Move to CPU before casting to float64 (exact for float16/bf16/float32;
    # some accelerators lack float64).
    info['[log-lik-med]'] = np.median(
        ref_log_lik.detach().cpu().double().numpy()).item()

    # Column 0 of the GOSAI oracle output is the reported activity.
    preds = oracle.cal_gosai_pred_new(detokenized_samples, gosai_oracle, mode='eval')[:, 0]
    info['[pred-activity-med]'] = np.median(preds).item()

    # Percentage of samples whose ATAC output column 1 exceeds 0.5.
    atac = oracle.cal_atac_pred_new(detokenized_samples, cal_atac_pred_new_mdl)[:, 1]
    info['[atac-acc%]'] = (atac > 0.5).sum().item() / len(samples) * 100

    # Correlate sample k-mer counts with the reference table, putting the
    # reference counts on the scale of this batch size.
    kmer = oracle.count_kmers(detokenized_samples)
    info['[3-mer-corr]'] = compare_kmer(
        highexp_kmers_999, kmer, n_highexp_kmers_999, len(detokenized_samples)).item()
    return info