CAPIMAC / anchors.py
bestow136's picture
Upload 13 files
8ffcfd0 verified
from SRW_KNN_greedy import *
import torch
def get_anchors(h0,h1,map_pairs,num_unique_labels):
#又一个点,提前预训练,统一样本的潜在空间
# print(h0.shape[0],num_unique_labels,'ghjhggjf')
#初始化随机游走参数
num_walks0, walk_length0, similarity_threshold0, num_anchors0, coverage_radius0 = para(h0,h0.shape[0],num_unique_labels)
num_walks1, walk_length1, similarity_threshold1, num_anchors1, coverage_radius1 = para(h1, h1.shape[0],num_unique_labels)
transition_matrix0,transition_matrix1 = visit(h0),visit(h1)#转移概率矩阵
#访问矩阵
visit_matrix0,visit_matrix1 = random_walk_batch_paths(transition_matrix0, num_walks0, walk_length0), random_walk_batch_paths(transition_matrix1, num_walks1, walk_length1)
#
node_importance0, node_importance1 = visit_matrix0.sum(dim=0),visit_matrix1.sum(dim=0)
# # 使用贪心覆盖算法选择锚点
anchor_indices0 = greedy_cover_with_importance(h0, node_importance0, coverage_radius0, num_anchors0)
anchor_indices1 = greedy_cover_with_importance(h1, node_importance1, coverage_radius1, num_anchors1)
combined_indices = torch.cat((anchor_indices0, anchor_indices1))
unique_indices = torch.unique(combined_indices)#合并索引去重
len_indices=len(unique_indices)
mapdata0,mapdata1=torch.tensor(map_pairs[0]),torch.tensor(map_pairs[1])
anchors0,anchors1 = mapdata0[unique_indices].float(),mapdata1[unique_indices].float()# 提取锚点(降维前)
return anchors0,anchors1,len_indices