| | import torch |
| | import networkx as nx |
| | from torch_geometric.utils import from_networkx |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
def extract_patches(feature_map, patch_size=(4, 4)):
    """
    Cut a CNN feature map into a grid of non-overlapping patches (graph nodes).

    Patches are enumerated in row-major order over the patch grid: all patches
    of the first grid row, left to right, then the second row, and so on.
    Rows/columns that do not fill a complete patch at the bottom/right edge
    are dropped by the unfold operation.

    Parameters:
    - feature_map (Tensor): Feature map of shape (B, C, H', W').
    - patch_size (tuple): (patch_height, patch_width) of every patch.

    Returns:
    - patches (Tensor): Tensor of shape (B, N, C, patch_h, patch_w), where N is
      the number of patches per image.
    """
    batch, channels, _height, _width = feature_map.shape
    tile_h, tile_w = patch_size

    # Window size == step, so consecutive windows never overlap.
    row_tiles = feature_map.unfold(2, tile_h, tile_h)
    tiles = row_tiles.unfold(3, tile_w, tile_w)

    # (B, C, grid_h, grid_w, tile_h, tile_w) -> (B, grid_h, grid_w, C, tile_h, tile_w)
    # so that the two grid axes can be flattened into a single patch axis.
    tiles = tiles.permute(0, 2, 3, 1, 4, 5).contiguous()
    return tiles.view(batch, -1, channels, tile_h, tile_w)
| |
|
def construct_graph_from_patch(patch_index, patch_shape, image_shape):
    """
    Build the local neighbourhood graph of a single patch on the patch grid.

    The feature map is viewed as a grid of
    (image_height // patch_height) x (image_width // patch_width) patches,
    numbered in row-major order. The given patch is linked to each of its
    up to eight surrounding grid cells (horizontal, vertical and diagonal)
    that lie inside the grid.

    Parameters:
    - patch_index (int): Row-major index of the patch on the grid.
    - patch_shape (tuple): (patch_height, patch_width).
    - image_shape (tuple): (height, width) of the feature map.

    Returns:
    - G (nx.Graph): Graph holding the patch node plus one edge per in-grid
      neighbour (meant to be merged into a global graph by the caller).
    """
    rows_in_grid = image_shape[0] // patch_shape[0]
    cols_in_grid = image_shape[1] // patch_shape[1]

    # Recover the (row, col) grid position from the row-major index.
    row, col = divmod(patch_index, cols_in_grid)

    star = nx.Graph()
    star.add_node(patch_index)

    # 4-neighbours first, then the diagonals. The order determines the
    # insertion order of neighbour nodes in the returned graph.
    offsets = (
        (-1, 0), (1, 0), (0, -1), (0, 1),
        (-1, -1), (-1, 1), (1, -1), (1, 1),
    )
    for d_row, d_col in offsets:
        n_row, n_col = row + d_row, col + d_col
        inside = 0 <= n_row < rows_in_grid and 0 <= n_col < cols_in_grid
        if not inside:
            continue
        star.add_edge(patch_index, n_row * cols_in_grid + n_col)

    return star
| |
|
def build_graph_from_patches(feature_map, patch_size=(4, 4)):
    """
    Builds a global graph for each image in the batch, where each node corresponds
    to a patch and edges represent spatial adjacency between patches on the grid.

    Node ``i`` of every returned graph corresponds to ``patches[:, i]``, and the
    nodes are inserted in increasing index order. This matters for converters
    that number nodes by iteration order (as torch_geometric's ``from_networkx``
    does): inserting nodes in the order they are discovered through edges would
    attach patch features to the wrong graph nodes.

    Parameters:
    - feature_map (Tensor): CNN output (B, C, H', W').
    - patch_size (tuple): Size of each patch (patch_h, patch_w).

    Returns:
    - G_global_batch (list): A list of NetworkX graphs, one independent graph
      object per image in the batch.
    - patches (Tensor): The extracted patches (B, N, C, patch_h, patch_w).
    """
    patches = extract_patches(feature_map, patch_size)
    batch_size = patches.size(0)

    image_shape = (feature_map.size(2), feature_map.size(3))
    grid_height = image_shape[0] // patch_size[0]
    grid_width = image_shape[1] // patch_size[1]
    num_patches = grid_height * grid_width

    # The topology depends only on the grid dimensions, not on the image
    # content, so it is built once and copied for every batch element.
    template = nx.Graph()

    # Seed all nodes in index order *before* adding edges so that node
    # iteration order equals patch index order.
    template.add_nodes_from(range(num_patches))

    for patch_idx in range(num_patches):
        G_patch = construct_graph_from_patch(
            patch_index=patch_idx,
            patch_shape=patch_size,
            image_shape=image_shape,
        )
        # Add edges in place; nx.compose would copy the whole accumulated
        # graph on every iteration.
        template.add_edges_from(G_patch.edges(data=True))

    # Each image gets its own graph object so callers can mutate one
    # without affecting the others.
    G_global_batch = [template.copy() for _ in range(batch_size)]

    return G_global_batch, patches
| |
|
def build_graph_data_from_patches(G_global_batch, patches):
    """
    Converts NetworkX graphs and associated patches into PyTorch Geometric Data objects.
    Each node corresponds to a patch flattened into a feature vector.

    Node labels are expected to be patch indices, so that node ``i`` is
    described by ``patches[b, i]``. Before conversion each graph is rebuilt
    with its nodes sorted by label: ``from_networkx`` numbers nodes by their
    iteration order, and a graph assembled edge-by-edge generally does not
    iterate its nodes in index order. Without the reordering, rows of
    ``data.x`` would not line up with the node ids used in ``data.edge_index``.

    Parameters:
    - G_global_batch (list): List of global graphs (one per image) in NetworkX form.
      Node labels must be mutually sortable (e.g. integer patch indices).
    - patches (Tensor): (B, N, C, patch_h, patch_w) patch tensor.

    Returns:
    - data_list (list): List of PyTorch Geometric Data objects, where data.x are node
      features of shape (N, C * patch_h * patch_w) and data.edge_index is the
      adjacency from the constructed graph.
    """
    # Unpacking all five dimensions also rejects tensors of the wrong rank.
    _, num_patches, _, _, _ = patches.size()

    data_list = []
    for batch_idx, G_global in enumerate(G_global_batch):
        # reshape (rather than view) also accepts non-contiguous patch tensors.
        node_features = patches[batch_idx].reshape(num_patches, -1)

        # Rebuild the graph with nodes in sorted label order, keeping the
        # graph type and any graph/node/edge attributes.
        ordered = G_global.__class__()
        ordered.graph.update(G_global.graph)
        ordered.add_nodes_from(
            sorted(G_global.nodes(data=True), key=lambda item: item[0])
        )
        ordered.add_edges_from(G_global.edges(data=True))

        G_pygeom = from_networkx(ordered)
        G_pygeom.x = node_features
        data_list.append(G_pygeom)

    return data_list
| |
|