| | from SRW_KNN_greedy import *
|
| | import torch
|
| | def get_anchors(h0,h1,map_pairs,num_unique_labels):
|
| |
|
| |
|
| |
|
| | num_walks0, walk_length0, similarity_threshold0, num_anchors0, coverage_radius0 = para(h0,h0.shape[0],num_unique_labels)
|
| | num_walks1, walk_length1, similarity_threshold1, num_anchors1, coverage_radius1 = para(h1, h1.shape[0],num_unique_labels)
|
| | transition_matrix0,transition_matrix1 = visit(h0),visit(h1)
|
| |
|
| | visit_matrix0,visit_matrix1 = random_walk_batch_paths(transition_matrix0, num_walks0, walk_length0), random_walk_batch_paths(transition_matrix1, num_walks1, walk_length1)
|
| |
|
| | node_importance0, node_importance1 = visit_matrix0.sum(dim=0),visit_matrix1.sum(dim=0)
|
| |
|
| | anchor_indices0 = greedy_cover_with_importance(h0, node_importance0, coverage_radius0, num_anchors0)
|
| | anchor_indices1 = greedy_cover_with_importance(h1, node_importance1, coverage_radius1, num_anchors1)
|
| | combined_indices = torch.cat((anchor_indices0, anchor_indices1))
|
| | unique_indices = torch.unique(combined_indices)
|
| | len_indices=len(unique_indices)
|
| | mapdata0,mapdata1=torch.tensor(map_pairs[0]),torch.tensor(map_pairs[1])
|
| | anchors0,anchors1 = mapdata0[unique_indices].float(),mapdata1[unique_indices].float()
|
| | return anchors0,anchors1,len_indices |