Fix keypoint visualization code (#5)
Browse files- Fix keypoint visualization code (6ddc5897875a79401ea9fb590c59dec578398192)
Co-authored-by: Merve Noyan <[email protected]>
README.md
CHANGED
|
@@ -78,23 +78,47 @@ model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/sup
|
|
| 78 |
|
| 79 |
inputs = processor(images, return_tensors="pt")
|
| 80 |
outputs = model(**inputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
for i in range(len(images)):
|
|
|
|
|
|
|
|
|
|
| 83 |
image_mask = outputs.mask[i]
|
| 84 |
image_indices = torch.nonzero(image_mask).squeeze()
|
| 85 |
-
image_keypoints = outputs.keypoints[i][image_indices]
|
| 86 |
-
image_scores = outputs.scores[i][image_indices]
|
| 87 |
-
image_descriptors = outputs.descriptors[i][image_indices]
|
| 88 |
-
```
|
| 89 |
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
```
|
| 99 |
|
| 100 |
This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
|
|
|
|
| 78 |
|
| 79 |
inputs = processor(images, return_tensors="pt")
|
| 80 |
outputs = model(**inputs)
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
We can now visualize the keypoints.
|
| 84 |
+
|
| 85 |
+
```
|
| 86 |
+
import matplotlib.pyplot as plt
|
| 87 |
+
import torch
|
| 88 |
|
| 89 |
for i in range(len(images)):
|
| 90 |
+
image = images[i]
|
| 91 |
+
image_width, image_height = image.size
|
| 92 |
+
|
| 93 |
image_mask = outputs.mask[i]
|
| 94 |
image_indices = torch.nonzero(image_mask).squeeze()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
+
image_scores = outputs.scores[i][image_indices]
|
| 97 |
+
image_keypoints = outputs.keypoints[i][image_indices]
|
| 98 |
+
|
| 99 |
+
keypoints = image_keypoints.detach().numpy()
|
| 100 |
+
scores = image_scores.detach().numpy()
|
| 101 |
+
|
| 102 |
+
valid_keypoints = [
|
| 103 |
+
(kp, score) for kp, score in zip(keypoints, scores)
|
| 104 |
+
if 0 <= kp[0] < image_width and 0 <= kp[1] < image_height
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
valid_keypoints, valid_scores = zip(*valid_keypoints)
|
| 108 |
+
valid_keypoints = torch.tensor(valid_keypoints)
|
| 109 |
+
valid_scores = torch.tensor(valid_scores)
|
| 110 |
+
|
| 111 |
+
print(valid_keypoints.shape)
|
| 112 |
+
|
| 113 |
+
plt.axis('off')
|
| 114 |
+
plt.imshow(image)
|
| 115 |
+
plt.scatter(
|
| 116 |
+
valid_keypoints[:, 0],
|
| 117 |
+
valid_keypoints[:, 1],
|
| 118 |
+
s=valid_scores * 100,
|
| 119 |
+
c='red'
|
| 120 |
+
)
|
| 121 |
+
plt.show()
|
| 122 |
```
|
| 123 |
|
| 124 |
This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
|