import numpy as np
import torch
from scipy.stats import pearsonr
import dataloader_gosai
import oracle


def compare_kmer(kmer1, kmer2, n_sp1, n_sp2):
    """Pearson correlation between two k-mer count dicts; kmer1 (counted over
    n_sp1 sequences) is rescaled by n_sp2 / n_sp1 to match kmer2 (over n_sp2)."""
    kmer_set = set(kmer1.keys()) | set(kmer2.keys())
    counts = np.zeros((len(kmer_set), 2))
    for i, kmer in enumerate(kmer_set):
        if kmer in kmer1:
            counts[i, 1] = kmer1[kmer] * n_sp2 / n_sp1
        if kmer in kmer2:
            counts[i, 0] = kmer2[kmer]
    return pearsonr(counts[:, 0], counts[:, 1])[0]


def get_eval_matrics(samples, ref_model, gosai_oracle, cal_atac_pred_new_mdl, highexp_kmers_999, n_highexp_kmers_999):
    """Compute evaluation metrics for generated sequences.

    samples: [B, 200] token tensor of generated DNA sequences.
    """
    info = {}
    # Detokenize to [B] DNA strings of length 200
    detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy())
    # Log-likelihood of each sample under the reference model, [B]
    ref_log_lik = ref_model.get_likelihood(samples, num_steps=128, n_samples=1)
    info['[log-lik-med]'] = torch.median(ref_log_lik).item()
    # Predicted activity from the Gosai oracle (first output column)
    preds = oracle.cal_gosai_pred_new(detokenized_samples, gosai_oracle, mode='eval')[:, 0]
    info['[pred-activity-med]'] = np.median(preds).item()
    # ATAC prediction (second output column); fraction above 0.5, reported as a percentage
    atac = oracle.cal_atac_pred_new(detokenized_samples, cal_atac_pred_new_mdl)[:, 1]
    info['[atac-acc%]'] = (atac > 0.5).sum().item() / len(samples) * 100
    # 3-mer frequency correlation against the high-expression reference k-mer counts
    kmer = oracle.count_kmers(detokenized_samples)
    info['[3-mer-corr]'] = compare_kmer(highexp_kmers_999, kmer, n_highexp_kmers_999, len(detokenized_samples)).item()
    return info
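

# Minimal, self-contained sketch of compare_kmer on toy k-mer count dicts.
# The dict contents and sequence counts below are made up for illustration;
# in practice the counts come from oracle.count_kmers over generated and
# high-expression reference sequence sets.
if __name__ == "__main__":
    ref_kmers = {'AAA': 40, 'ACG': 25, 'GGC': 10}   # counts from 100 reference sequences
    gen_kmers = {'AAA': 18, 'ACG': 14, 'TTT': 5}    # counts from 50 generated sequences
    corr = compare_kmer(ref_kmers, gen_kmers, n_sp1=100, n_sp2=50)
    print(f'k-mer Pearson correlation: {corr:.3f}')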