Spaces: Running on Zero
Miroslav Purkrabek committed · commit a249588 · parent 4b8d5c5
add code

This view is limited to 50 files because it contains too many changes. See raw diff.
- CITATION.cff +25 -0
- LICENSE +674 -0
- app.py +262 -0
- configs/README.md +30 -0
- configs/bmp_D3.yaml +37 -0
- configs/bmp_J1.yaml +39 -0
- demo/bmp_demo.py +250 -0
- demo/demo_utils.py +705 -0
- demo/mm_utils.py +106 -0
- demo/posevis_lite.py +507 -0
- demo/sam2_utils.py +714 -0
- mmpose/__init__.py +27 -0
- mmpose/apis/__init__.py +16 -0
- mmpose/apis/inference.py +280 -0
- mmpose/apis/inference_3d.py +360 -0
- mmpose/apis/inference_tracking.py +103 -0
- mmpose/apis/inferencers/__init__.py +11 -0
- mmpose/apis/inferencers/base_mmpose_inferencer.py +691 -0
- mmpose/apis/inferencers/hand3d_inferencer.py +344 -0
- mmpose/apis/inferencers/mmpose_inferencer.py +250 -0
- mmpose/apis/inferencers/pose2d_inferencer.py +262 -0
- mmpose/apis/inferencers/pose3d_inferencer.py +457 -0
- mmpose/apis/inferencers/utils/__init__.py +5 -0
- mmpose/apis/inferencers/utils/default_det_models.py +36 -0
- mmpose/apis/inferencers/utils/get_model_alias.py +37 -0
- mmpose/apis/visualization.py +132 -0
- mmpose/codecs/__init__.py +25 -0
- mmpose/codecs/annotation_processors.py +100 -0
- mmpose/codecs/associative_embedding.py +522 -0
- mmpose/codecs/base.py +81 -0
- mmpose/codecs/decoupled_heatmap.py +274 -0
- mmpose/codecs/edpose_label.py +153 -0
- mmpose/codecs/hand_3d_heatmap.py +202 -0
- mmpose/codecs/image_pose_lifting.py +280 -0
- mmpose/codecs/integral_regression_label.py +121 -0
- mmpose/codecs/megvii_heatmap.py +147 -0
- mmpose/codecs/motionbert_label.py +240 -0
- mmpose/codecs/msra_heatmap.py +153 -0
- mmpose/codecs/onehot_heatmap.py +263 -0
- mmpose/codecs/regression_label.py +108 -0
- mmpose/codecs/simcc_label.py +311 -0
- mmpose/codecs/spr.py +306 -0
- mmpose/codecs/udp_heatmap.py +263 -0
- mmpose/codecs/utils/__init__.py +32 -0
- mmpose/codecs/utils/camera_image_projection.py +102 -0
- mmpose/codecs/utils/gaussian_heatmap.py +433 -0
- mmpose/codecs/utils/instance_property.py +111 -0
- mmpose/codecs/utils/offset_heatmap.py +143 -0
- mmpose/codecs/utils/oks_map.py +97 -0
- mmpose/codecs/utils/post_processing.py +530 -0
CITATION.cff
ADDED
@@ -0,0 +1,25 @@
+# CITATION.cff file for Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle
+# This file provides metadata for the software and its preferred citation format.
+cff-version: 1.2.0
+message: "If you use this software, please cite it as below."
+authors:
+  - family-names: Purkrabek
+    given-names: Miroslav
+  - family-names: Matas
+    given-names: Jiri
+title: "Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle"
+version: 1.0.0
+date-released: 2025-06-20
+preferred-citation:
+  type: conference-paper
+  authors:
+    - family-names: Purkrabek
+      given-names: Miroslav
+    - family-names: Matas
+      given-names: Jiri
+  collection-title: "Proceedings of the IEEE/CVF International Conference on Computer Vision"
+  month: 10
+  start: 1 # First page number
+  end: 8 # Last page number
+  title: "Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle"
+  year: 2025
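The added file is plain YAML following the CFF 1.2 schema. As a quick sanity check, a minimal sketch (assuming only PyYAML is installed; the key names are taken from the file above) that loads the metadata and prints the preferred citation:

import yaml  # PyYAML; CFF files are plain YAML

with open("CITATION.cff") as f:
    cff = yaml.safe_load(f)

pref = cff["preferred-citation"]
authors = " and ".join(
    f"{a['given-names']} {a['family-names']}" for a in pref["authors"]
)
# -> Miroslav Purkrabek and Jiri Matas. "Detection, Pose Estimation and
#    Segmentation for Multiple Bodies: ..." ICCV proceedings, 2025.
print(f'{authors}. "{pref["title"]}". {pref["collection-title"]}, {pref["year"]}.')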
LICENSE
ADDED
@@ -0,0 +1,674 @@
+                    GNU GENERAL PUBLIC LICENSE
+                       Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+                            Preamble
+
+  The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+  The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+  When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+  To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+  For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+  Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+  For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+  Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+  Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+  The precise terms and conditions for copying, distribution and
+modification follow.
+
+                       TERMS AND CONDITIONS
+
+  0. Definitions.
+
+  "This License" refers to version 3 of the GNU General Public License.
+
+  "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+  "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+  To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+  A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+  To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+  To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+  An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+  1. Source Code.
+
+  The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+  A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+  The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+  The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+  The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+  The Corresponding Source for a work in source code form is that
+same work.
+
+  2. Basic Permissions.
+
+  All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+  You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+  Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+  3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+  No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+  When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+  4. Conveying Verbatim Copies.
+
+  You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+  You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+  5. Conveying Modified Source Versions.
+
+  You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+    a) The work must carry prominent notices stating that you modified
+    it, and giving a relevant date.
+
+    b) The work must carry prominent notices stating that it is
+    released under this License and any conditions added under section
+    7. This requirement modifies the requirement in section 4 to
+    "keep intact all notices".
+
+    c) You must license the entire work, as a whole, under this
+    License to anyone who comes into possession of a copy. This
+    License will therefore apply, along with any applicable section 7
+    additional terms, to the whole of the work, and all its parts,
+    regardless of how they are packaged. This License gives no
+    permission to license the work in any other way, but it does not
+    invalidate such permission if you have separately received it.
+
+    d) If the work has interactive user interfaces, each must display
+    Appropriate Legal Notices; however, if the Program has interactive
+    interfaces that do not display Appropriate Legal Notices, your
+    work need not make them do so.
+
+  A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+  6. Conveying Non-Source Forms.
+
+  You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+    a) Convey the object code in, or embodied in, a physical product
+    (including a physical distribution medium), accompanied by the
+    Corresponding Source fixed on a durable physical medium
+    customarily used for software interchange.
+
+    b) Convey the object code in, or embodied in, a physical product
+    (including a physical distribution medium), accompanied by a
+    written offer, valid for at least three years and valid for as
+    long as you offer spare parts or customer support for that product
+    model, to give anyone who possesses the object code either (1) a
+    copy of the Corresponding Source for all the software in the
+    product that is covered by this License, on a durable physical
+    medium customarily used for software interchange, for a price no
+    more than your reasonable cost of physically performing this
+    conveying of source, or (2) access to copy the
+    Corresponding Source from a network server at no charge.
+
+    c) Convey individual copies of the object code with a copy of the
+    written offer to provide the Corresponding Source. This
+    alternative is allowed only occasionally and noncommercially, and
+    only if you received the object code with such an offer, in accord
+    with subsection 6b.
+
+    d) Convey the object code by offering access from a designated
+    place (gratis or for a charge), and offer equivalent access to the
+    Corresponding Source in the same way through the same place at no
+    further charge. You need not require recipients to copy the
+    Corresponding Source along with the object code. If the place to
+    copy the object code is a network server, the Corresponding Source
+    may be on a different server (operated by you or a third party)
+    that supports equivalent copying facilities, provided you maintain
+    clear directions next to the object code saying where to find the
+    Corresponding Source. Regardless of what server hosts the
+    Corresponding Source, you remain obligated to ensure that it is
+    available for as long as needed to satisfy these requirements.
+
+    e) Convey the object code using peer-to-peer transmission, provided
+    you inform other peers where the object code and Corresponding
+    Source of the work are being offered to the general public at no
+    charge under subsection 6d.
+
+  A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+  A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+  "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+  If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+  The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+  Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+  7. Additional Terms.
+
+  "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+  When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+  Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+    a) Disclaiming warranty or limiting liability differently from the
+    terms of sections 15 and 16 of this License; or
+
+    b) Requiring preservation of specified reasonable legal notices or
+    author attributions in that material or in the Appropriate Legal
+    Notices displayed by works containing it; or
+
+    c) Prohibiting misrepresentation of the origin of that material, or
+    requiring that modified versions of such material be marked in
+    reasonable ways as different from the original version; or
+
+    d) Limiting the use for publicity purposes of names of licensors or
+    authors of the material; or
+
+    e) Declining to grant rights under trademark law for use of some
+    trade names, trademarks, or service marks; or
+
+    f) Requiring indemnification of licensors and authors of that
+    material by anyone who conveys the material (or modified versions of
+    it) with contractual assumptions of liability to the recipient, for
+    any liability that these contractual assumptions directly impose on
+    those licensors and authors.
+
+  All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+  If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+  Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+  8. Termination.
+
+  You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+  However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+  Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+  Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+  9. Acceptance Not Required for Having Copies.
+
+  You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+  10. Automatic Licensing of Downstream Recipients.
+
+  Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+  An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+  You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+  11. Patents.
+
+  A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+  A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+  Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+  In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+  If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+  If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+  A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+  Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+  12. No Surrender of Others' Freedom.
+
+  If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+  13. Use with the GNU Affero General Public License.
+
+  Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+  14. Revised Versions of this License.
+
+  The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+  Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+  If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+  Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+  15. Disclaimer of Warranty.
+
+  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+  16. Limitation of Liability.
+
+  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+  17. Interpretation of Sections 15 and 16.
+
+  If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+                     END OF TERMS AND CONDITIONS
+
+            How to Apply These Terms to Your New Programs
+
+  If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+  To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year> <name of author>
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+  If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+    <program> Copyright (C) <year> <name of author>
+    This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+    This is free software, and you are welcome to redistribute it
+    under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+  You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+  The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
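The closing "How to Apply These Terms" section above prescribes a per-file notice. Applied to this project it would look roughly like the sketch below; the year and copyright holder are assumptions filled in from CITATION.cff, and the commit itself does not add such headers to its source files:

# Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle
# Copyright (C) 2025 Miroslav Purkrabek (assumed holder, per CITATION.cff)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.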
app.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import spaces
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import yaml
|
| 8 |
+
from demo.demo_utils import DotDict, concat_instances, filter_instances, pose_nms, visualize_demo
|
| 9 |
+
from demo.mm_utils import run_MMDetector, run_MMPose
|
| 10 |
+
from mmdet.apis import init_detector
|
| 11 |
+
from demo.sam2_utils import prepare_model as prepare_sam2_model
|
| 12 |
+
from demo.sam2_utils import process_image_with_SAM
|
| 13 |
+
|
| 14 |
+
from mmpose.apis import init_model as init_pose_estimator
|
| 15 |
+
from mmpose.utils import adapt_mmdet_pipeline
|
| 16 |
+
|
| 17 |
+
# Default thresholds
|
| 18 |
+
DEFAULT_CAT_ID: int = 0
|
| 19 |
+
|
| 20 |
+
DEFAULT_BBOX_THR: float = 0.3
|
| 21 |
+
DEFAULT_NMS_THR: float = 0.3
|
| 22 |
+
DEFAULT_KPT_THR: float = 0.3
|
| 23 |
+
|
| 24 |
+
# Global models variable
|
| 25 |
+
det_model = None
|
| 26 |
+
pose_model = None
|
| 27 |
+
sam2_model = None
|
| 28 |
+
|
| 29 |
+
def _parse_yaml_config(yaml_path: Path) -> DotDict:
|
| 30 |
+
"""
|
| 31 |
+
Load BMP configuration from a YAML file.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
yaml_path (Path): Path to YAML config.
|
| 35 |
+
Returns:
|
| 36 |
+
DotDict: Nested config dictionary.
|
| 37 |
+
"""
|
| 38 |
+
with open(yaml_path, "r") as f:
|
| 39 |
+
cfg = yaml.safe_load(f)
|
| 40 |
+
return DotDict(cfg)
|
| 41 |
+
|
| 42 |
+
def load_models(bmp_config):
|
| 43 |
+
device = 'cuda:0'
|
| 44 |
+
|
| 45 |
+
global det_model, pose_model, sam2_model
|
| 46 |
+
|
| 47 |
+
# build detectors
|
| 48 |
+
det_model = init_detector(bmp_config.detector.det_config, bmp_config.detector.det_checkpoint, device='cpu') # Detect with CPU because of installation issues on HF
|
| 49 |
+
det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# build pose estimator
|
| 53 |
+
pose_model = init_pose_estimator(
|
| 54 |
+
bmp_config.pose_estimator.pose_config,
|
| 55 |
+
bmp_config.pose_estimator.pose_checkpoint,
|
| 56 |
+
device=device,
|
| 57 |
+
cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
sam2_model = prepare_sam2_model(
|
| 61 |
+
model_cfg=bmp_config.sam2.sam2_config,
|
| 62 |
+
model_checkpoint=bmp_config.sam2.sam2_checkpoint,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
return det_model, pose_model, sam2_model
|
| 66 |
+
|
| 67 |
+
@spaces.GPU(duration=60)
|
| 68 |
+
def process_image_with_BMP(
|
| 69 |
+
img: np.ndarray
|
| 70 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 71 |
+
"""
|
| 72 |
+
Run the full BMP pipeline on a single image: detection, pose, SAM mask refinement, and visualization.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
args (Namespace): Parsed CLI arguments.
|
| 76 |
+
bmp_config (DotDict): Configuration parameters.
|
| 77 |
+
img_path (Path): Path to the input image.
|
| 78 |
+
detector: Primary MMDetection model.
|
| 79 |
+
detector_prime: Secondary MMDetection model for iterations.
|
| 80 |
+
pose_estimator: MMPose model for keypoint estimation.
|
| 81 |
+
sam2_model: SAM model for mask refinement.
|
| 82 |
+
Returns:
|
| 83 |
+
InstanceData: Final merged detections and refined masks.
|
| 84 |
+
"""
    bmp_config = _parse_yaml_config(Path("configs/bmp_D3.yaml"))
    load_models(bmp_config)

    # img: RGB -> BGR
    img = img[..., ::-1]

    img_for_detection = img.copy()
    rtmdet_result = None
    all_detections = None
    for iteration in range(bmp_config.num_bmp_iters):

        # Step 1: Detection
        det_instances = run_MMDetector(
            det_model,
            img_for_detection,
            det_cat_id=DEFAULT_CAT_ID,
            bbox_thr=DEFAULT_BBOX_THR,
            nms_thr=DEFAULT_NMS_THR,
        )
        if len(det_instances.bboxes) == 0:
            continue

        # Step 2: Pose estimation
        pose_instances = run_MMPose(
            pose_model,
            img.copy(),
            detections=det_instances,
            kpt_thr=DEFAULT_KPT_THR,
        )

        # Restrict to first 17 COCO keypoints
        pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
        pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
        pose_instances.keypoints = np.concatenate(
            [pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
        )

        # Step 3: Pose-NMS and SAM refinement
        all_keypoints = (
            pose_instances.keypoints
            if all_detections is None
            else np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
        )
        all_bboxes = (
            pose_instances.bboxes
            if all_detections is None
            else np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
        )
        num_valid_kpts = np.sum(all_keypoints[:, :, 2] > bmp_config.sam2.prompting.confidence_thr, axis=1)
        keep_indices = pose_nms(
            DotDict({"confidence_thr": bmp_config.sam2.prompting.confidence_thr, "oks_thr": bmp_config.oks_nms_thr}),
            image_kpts=all_keypoints,
            image_bboxes=all_bboxes,
            num_valid_kpts=num_valid_kpts,
        )
        keep_indices = sorted(keep_indices)  # Sort by original index
        num_old_detections = 0 if all_detections is None else len(all_detections.bboxes)
        keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
        keep_old_indices = [i for i in keep_indices if i < num_old_detections]
        if len(keep_new_indices) == 0:
            continue
        # filter new detections and compute scores
        new_dets = filter_instances(pose_instances, keep_new_indices)
        new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)
        old_dets = None
        if len(keep_old_indices) > 0:
            old_dets = filter_instances(all_detections, keep_old_indices)

        new_detections = process_image_with_SAM(
            DotDict(bmp_config.sam2.prompting),
            img.copy(),
            sam2_model,
            new_dets,
            old_dets if old_dets is not None else None,
        )

        # Merge detections
        if all_detections is None:
            all_detections = new_detections
        else:
            all_detections = concat_instances(all_detections, new_dets)

        # Step 4: Visualization
        img_for_detection, rtmdet_r, _ = visualize_demo(
            img.copy(),
            all_detections,
        )

        if iteration == 0:
            rtmdet_result = rtmdet_r

    _, _, bmp_result = visualize_demo(
        img.copy(),
        all_detections,
    )

    # img: BGR -> RGB
    rtmdet_result = rtmdet_result[..., ::-1]
    bmp_result = bmp_result[..., ::-1]

    return rtmdet_result, bmp_result

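
# A minimal sketch of calling the pipeline without the Gradio UI (assumes an RGB
# image; cv2.imread returns BGR, hence the channel flip below):
#
#   import cv2
#   rgb = cv2.imread("examples/004806.jpg")[..., ::-1]
#   rtmdet_vis, bmp_vis = process_image_with_BMP(rgb)
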
with gr.Blocks() as app:
    gr.Markdown("# BBoxMaskPose Image Demo")
    gr.Markdown(
        "Official demo for the paper **Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle** [ICCV 2025]."
    )
    gr.Markdown(
        "For details, see the [project website](https://mirapurkrabek.github.io/BBox-Mask-Pose/) or the [arXiv paper](https://arxiv.org/abs/2412.01562). "
        "The demo showcases the capabilities of the BBoxMaskPose framework on any image. "
        "If you want to play around with the parameters, use the [GitHub demo](https://github.com/MiraPurkrabek/BBoxMaskPose). "
        "Please note that due to HuggingFace restrictions, the demo runs much slower than the GitHub implementation."
    )

    with gr.Row():
        with gr.Column():
            original_image_input = gr.Image(type="numpy", label="Original Image")
            submit_button = gr.Button("Run Inference")

        with gr.Column():
            output_standard = gr.Image(type="numpy", label="RTMDet-L + MaskPose-B")

        with gr.Column():
            output_sahi_sliced = gr.Image(type="numpy", label="BBoxMaskPose")

    gr.Examples(
        label="OCHuman examples",
        examples=[
            ["examples/004806.jpg"],
            ["examples/005056.jpg"],
            ["examples/004981.jpg"],
            ["examples/004655.jpg"],
            ["examples/004684.jpg"],
            ["examples/004974.jpg"],
            ["examples/004983.jpg"],
            ["examples/005017.jpg"],
            ["examples/004849.jpg"],
        ],
        inputs=[
            original_image_input,
        ],
        outputs=[output_standard, output_sahi_sliced],
        fn=process_image_with_BMP,
        cache_examples=True,
    )
    gr.Examples(
        label="In-the-wild examples",
        examples=[
            ["examples/prochazka_MMA.jpg"],
            ["examples/riner_judo.jpg"],
            ["examples/tackle3.jpg"],
            ["examples/tackle1.jpg"],
            ["examples/tackle2.jpg"],
            ["examples/tackle5.jpg"],
            ["examples/floorball_SKV_3.jpg"],
            ["examples/santa_o_crop.jpg"],
            ["examples/floorball_SKV_2.jpg"],
        ],
        inputs=[
            original_image_input,
        ],
        outputs=[output_standard, output_sahi_sliced],
        fn=process_image_with_BMP,
        cache_examples=True,
    )

    submit_button.click(
        fn=process_image_with_BMP,
        inputs=[
            original_image_input,
        ],
        outputs=[output_standard, output_sahi_sliced],
    )

# Launch the demo
app.launch()
configs/README.md
ADDED
@@ -0,0 +1,30 @@

# Configuration Files Overview

This directory contains configuration files for reproducing experiments and running inference across different components of the BBoxMaskPose project.

## Which configs are available?

Here you can find the configs that set up the hyperparameters of the whole BMP loop.
These are mainly:
- How to prompt SAM
- Which models to use (detection, pose, SAM)
- How to chain the models
- ...

For easier reference, the configs have the same names as in the supplementary material of the ICCV paper.
So for example, the config [**bmp_D3.yaml**](bmp_D3.yaml) is the prompting experiment used in the BMP loop.
For details, see Tabs. 6-8 of the supplementary.

## Where are the appropriate configs?

- **/configs** (this folder)
  - Hyperparameter configurations for the BMP loop experiments. Use these files to reproduce training and evaluation settings.

- **/mmpose/configs**
  - Configuration files for MMPose, following the same format and structure as MMPose v1.3.1. Supports models, datasets, and training pipelines.

- **/sam2/configs**
  - Configuration files for SAM2, matching the format and directory layout of the original SAM v2.1 repository. Use these for prompt-driven segmentation and related tasks.
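
These YAML files are plain key-value configs, so any YAML loader works. A minimal sketch (the demo itself uses `parse_yaml_config` in `demo/bmp_demo.py`, which wraps the result in a `DotDict`):

```python
import yaml

# Load the D3 hyperparameters used for the BMP loop.
with open("configs/bmp_D3.yaml", "r") as f:
    cfg = yaml.safe_load(f)

print(cfg["num_bmp_iters"])                        # 2
print(cfg["sam2"]["prompting"]["confidence_thr"])  # 0.3
```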
configs/bmp_D3.yaml
ADDED
@@ -0,0 +1,37 @@

# BBoxMaskPose hyperparameters from experiment D3.
# For details, see the paper: https://arxiv.org/abs/2412.01562, Tab. 8 in the supplementary.

# This configuration is good for the BMP loop, as it was used for most of the experiments.
detector:
  det_config: 'mmpose/configs/mmdet/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py'
  det_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/rtmdet-ins-l-mask.pth'

  # Detectors D and D' could be different.
  det_prime_config: null
  det_prime_checkpoint: null

pose_estimator:
  pose_config: 'mmpose/configs/MaskPose/ViTb-multi_mask.py'
  pose_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/MaskPose-b.pth'

sam2:
  sam2_config: 'configs/samurai/sam2.1_hiera_b+.yaml'  # Use SAMURAI as it has img_size 1024 (SAM-2.1 has 512)
  sam2_checkpoint: 'models/SAM/sam2.1_hiera_base_plus.pt'
  prompting:
    batch: False
    use_bbox: False
    num_pos_keypoints: 6
    num_pos_keypoints_if_crowd: 6
    num_neg_keypoints: 0
    confidence_thr: 0.3
    visibility_thr: 0.3
    selection_method: 'distance+confidence'
    extend_bbox: False
    pose_mask_consistency: False
    crowd_by_max_iou: False  # Determines whether the instance is in a multi-body scenario. If yes, use a different number of keypoints and NO BBOX. If no, use the bbox according to the 'use_bbox' argument.
    crop: False
    exclusive_masks: True
    ignore_small_bboxes: False

num_bmp_iters: 2
oks_nms_thr: 0.8
configs/bmp_J1.yaml
ADDED
@@ -0,0 +1,39 @@

# BBoxMaskPose hyperparameters from experiment J1.
# For details, see the paper: https://arxiv.org/abs/2412.01562, Tab. 8 in the supplementary.

# This configuration is good for getting extra AP points when the estimates are already good.
# It is not recommended for the whole loop (as done here -- this is for the demo) but rather for
# the det-pose-sam-pose pipeline studied in Tab. 4.
detector:
  det_config: 'mmpose/configs/mmdet/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py'
  det_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/rtmdet-ins-l-mask.pth'

  # Detectors D and D' could be different.
  det_prime_config: null
  det_prime_checkpoint: null

pose_estimator:
  pose_config: 'mmpose/configs/MaskPose/ViTb-multi_mask.py'
  pose_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/MaskPose-b.pth'

sam2:
  sam2_config: 'configs/samurai/sam2.1_hiera_b+.yaml'  # Use SAMURAI as it has img_size 1024 (SAM-2.1 has 512)
  sam2_checkpoint: 'models/SAM/sam2.1_hiera_base_plus.pt'
  prompting:
    batch: True
    use_bbox: False
    num_pos_keypoints: 4
    num_pos_keypoints_if_crowd: 6
    num_neg_keypoints: 0
    confidence_thr: 0.5
    visibility_thr: 0.5
    selection_method: 'distance+confidence'
    extend_bbox: False
    pose_mask_consistency: False
    crowd_by_max_iou: 0.5  # Determines whether the instance is in a multi-body scenario (here by an IoU threshold). If yes, use a different number of keypoints and NO BBOX. If no, use the bbox according to the 'use_bbox' argument.
    crop: False
    exclusive_masks: True
    ignore_small_bboxes: False

num_bmp_iters: 2
oks_nms_thr: 0.8
demo/bmp_demo.py
ADDED
@@ -0,0 +1,250 @@

# Copyright (c) OpenMMLab. All rights reserved.
"""
BMP Demo script: sequentially runs detection, pose estimation, SAM-based mask refinement, and visualization.
Usage:
    python bmp_demo.py <config.yaml> <input_image> [--output-root <dir>]
"""

import os
import shutil
from argparse import ArgumentParser, Namespace
from pathlib import Path

import mmcv
import mmengine
import numpy as np
import yaml
from demo_utils import DotDict, concat_instances, create_GIF, filter_instances, pose_nms, visualize_itteration
from mm_utils import run_MMDetector, run_MMPose
from mmdet.apis import init_detector
from mmengine.logging import print_log
from mmengine.structures import InstanceData
from sam2_utils import prepare_model as prepare_sam2_model
from sam2_utils import process_image_with_SAM

from mmpose.apis import init_model as init_pose_estimator
from mmpose.utils import adapt_mmdet_pipeline

# Default thresholds
DEFAULT_DET_CAT_ID: int = 0  # "person"
DEFAULT_BBOX_THR: float = 0.3
DEFAULT_NMS_THR: float = 0.3
DEFAULT_KPT_THR: float = 0.3


def parse_args() -> Namespace:
    """
    Parse command-line arguments for BMP demo.

    Returns:
        Namespace: Contains bmp_config (Path), input (Path), output_root (Path), device (str).
    """
    parser = ArgumentParser(description="BBoxMaskPose demo")
    parser.add_argument("bmp_config", type=Path, help="Path to BMP YAML config file")
    parser.add_argument("input", type=Path, help="Input image file")
    parser.add_argument("--output-root", type=Path, default=None, help="Directory to save outputs (default: ./outputs)")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device for inference (e.g., cuda:0 or cpu)")
    parser.add_argument("--create-gif", action="store_true", default=False, help="Create GIF of all BMP iterations")
    args = parser.parse_args()
    if args.output_root is None:
        args.output_root = os.path.join(Path(__file__).parent, "outputs")
    return args


def parse_yaml_config(yaml_path: Path) -> DotDict:
    """
    Load BMP configuration from a YAML file.

    Args:
        yaml_path (Path): Path to YAML config.
    Returns:
        DotDict: Nested config dictionary.
    """
    with open(yaml_path, "r") as f:
        cfg = yaml.safe_load(f)
    return DotDict(cfg)


def process_one_image(
    args: Namespace,
    bmp_config: DotDict,
    img_path: Path,
    detector: object,
    detector_prime: object,
    pose_estimator: object,
    sam2_model: object,
) -> InstanceData:
    """
    Run the full BMP pipeline on a single image: detection, pose, SAM mask refinement, and visualization.

    Args:
        args (Namespace): Parsed CLI arguments.
        bmp_config (DotDict): Configuration parameters.
        img_path (Path): Path to the input image.
        detector: Primary MMDetection model.
        detector_prime: Secondary MMDetection model for iterations.
        pose_estimator: MMPose model for keypoint estimation.
        sam2_model: SAM model for mask refinement.
    Returns:
        InstanceData: Final merged detections and refined masks.
    """
    # Load image
    img = mmcv.imread(str(img_path), channel_order="bgr")
    if img is None:
        raise ValueError("Failed to read image from {}.".format(img_path))

    # Prepare output directory
    output_dir = os.path.join(args.output_root, img_path.stem)
    shutil.rmtree(str(output_dir), ignore_errors=True)
    mmengine.mkdir_or_exist(str(output_dir))

    img_for_detection = img.copy()
    all_detections = None
    for iteration in range(bmp_config.num_bmp_iters):
        print_log("BMP Iteration {}/{} started".format(iteration + 1, bmp_config.num_bmp_iters), logger="current")

        # Step 1: Detection
        det_instances = run_MMDetector(
            detector if iteration == 0 else detector_prime,
            img_for_detection,
            det_cat_id=DEFAULT_DET_CAT_ID,
            bbox_thr=DEFAULT_BBOX_THR,
            nms_thr=DEFAULT_NMS_THR,
        )
        print_log("Detected {} instances".format(len(det_instances.bboxes)), logger="current")
        if len(det_instances.bboxes) == 0:
            print_log("No detections found, skipping.", logger="current")
            continue

        # Step 2: Pose estimation
        pose_instances = run_MMPose(
            pose_estimator,
            img.copy(),
            detections=det_instances,
            kpt_thr=DEFAULT_KPT_THR,
        )
        # Restrict to first 17 COCO keypoints
        pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
        pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
        pose_instances.keypoints = np.concatenate(
            [pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
        )

        # Step 3: Pose-NMS and SAM refinement
        all_keypoints = (
            pose_instances.keypoints
            if all_detections is None
            else np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
        )
        all_bboxes = (
            pose_instances.bboxes
            if all_detections is None
            else np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
        )
        num_valid_kpts = np.sum(all_keypoints[:, :, 2] > bmp_config.sam2.prompting.confidence_thr, axis=1)
        keep_indices = pose_nms(
            DotDict({"confidence_thr": bmp_config.sam2.prompting.confidence_thr, "oks_thr": bmp_config.oks_nms_thr}),
            image_kpts=all_keypoints,
            image_bboxes=all_bboxes,
            num_valid_kpts=num_valid_kpts,
        )
        keep_indices = sorted(keep_indices)  # Sort by original index
        num_old_detections = 0 if all_detections is None else len(all_detections.bboxes)
        keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
        keep_old_indices = [i for i in keep_indices if i < num_old_detections]
        if len(keep_new_indices) == 0:
            print_log("No new instances passed pose NMS, skipping SAM refinement.", logger="current")
            continue
        # filter new detections and compute scores
        new_dets = filter_instances(pose_instances, keep_new_indices)
        new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)
        old_dets = None
        if len(keep_old_indices) > 0:
            old_dets = filter_instances(all_detections, keep_old_indices)
        print_log(
            "Pose NMS reduced instances to {:d} ({:d}+{:d}) instances".format(
                len(new_dets.bboxes) + num_old_detections, num_old_detections, len(new_dets.bboxes)
            ),
            logger="current",
        )

        new_detections = process_image_with_SAM(
            DotDict(bmp_config.sam2.prompting),
            img.copy(),
            sam2_model,
            new_dets,
            old_dets if old_dets is not None else None,
        )

        # Merge detections
        if all_detections is None:
            all_detections = new_detections
        else:
            all_detections = concat_instances(all_detections, new_dets)

        # Step 4: Visualization
        img_for_detection = visualize_itteration(
            img.copy(),
            all_detections,
            iteration_idx=iteration,
            output_root=str(output_dir),
            img_name=img_path.stem,
        )
        print_log("Iteration {} completed".format(iteration + 1), logger="current")

    # Create GIF of iterations if requested
    if args.create_gif:
        image_file = os.path.join(output_dir, "{:s}.jpg".format(img_path.stem))
        create_GIF(
            img_path=str(image_file),
            output_root=str(output_dir),
            bmp_x=bmp_config.num_bmp_iters,
        )
    return all_detections


def main() -> None:
    """
    Entry point for the BMP demo: loads models and processes one image.
    """
    args = parse_args()
    bmp_config = parse_yaml_config(args.bmp_config)

    # Ensure output root exists
    mmengine.mkdir_or_exist(str(args.output_root))

    # build detectors
    detector = init_detector(bmp_config.detector.det_config, bmp_config.detector.det_checkpoint, device=args.device)
    detector.cfg = adapt_mmdet_pipeline(detector.cfg)
    if (
        bmp_config.detector.det_config == bmp_config.detector.det_prime_config
        and bmp_config.detector.det_checkpoint == bmp_config.detector.det_prime_checkpoint
    ) or (bmp_config.detector.det_prime_config is None or bmp_config.detector.det_prime_checkpoint is None):
        print_log("Using the same detector as D and D'", logger="current")
        detector_prime = detector
    else:
        detector_prime = init_detector(
            bmp_config.detector.det_prime_config, bmp_config.detector.det_prime_checkpoint, device=args.device
        )
        detector_prime.cfg = adapt_mmdet_pipeline(detector_prime.cfg)
        print_log("Using a different detector for D'", logger="current")

    # build pose estimator
    pose_estimator = init_pose_estimator(
        bmp_config.pose_estimator.pose_config,
        bmp_config.pose_estimator.pose_checkpoint,
        device=args.device,
        cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
    )

    sam2 = prepare_sam2_model(
        model_cfg=bmp_config.sam2.sam2_config,
        model_checkpoint=bmp_config.sam2.sam2_checkpoint,
    )

    # Run inference on one image
    _ = process_one_image(args, bmp_config, args.input, detector, detector_prime, pose_estimator, sam2)


if __name__ == "__main__":
    main()
demo/demo_utils.py
ADDED
@@ -0,0 +1,705 @@

"""
Utilities for the BMP demo:
- Visualization of detections, masks, and poses
- Mask and bounding-box processing
- Pose non-maximum suppression (NMS)
- Animated GIF creation of demo iterations
"""

import logging
import os
import shutil
import subprocess
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
from mmengine.logging import print_log
from mmengine.structures import InstanceData
from pycocotools import mask as Mask
from sam2.distinctipy import get_colors
from tqdm import tqdm

### Visualization hyperparameters
MIN_CONTOUR_AREA: int = 50
BBOX_WEIGHT: float = 0.9
MASK_WEIGHT: float = 0.6
BACK_MASK_WEIGHT: float = 0.6
POSE_WEIGHT: float = 0.8


"""
posevis is our custom visualization library for pose estimation. For compatibility, we also provide a lite version that has fewer features but still reproduces the visualization from the paper.
"""
try:
    from posevis import pose_visualization
except ImportError:
    from .posevis_lite import pose_visualization


class DotDict(dict):
    """Dictionary with attribute access and nested dict wrapping."""

    def __getattr__(self, name: str) -> Any:
        if name in self:
            val = self[name]
            if isinstance(val, dict):
                val = DotDict(val)
            self[name] = val
            return val
        raise AttributeError("No attribute named {!r}".format(name))

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        if name in self:
            del self[name]
        else:
            raise AttributeError("No attribute named {!r}".format(name))

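
# A minimal usage sketch of DotDict (illustrative only): nested plain dicts are
# wrapped on attribute access, so YAML configs can be read with dot notation.
#
#   cfg = DotDict({"sam2": {"prompting": {"confidence_thr": 0.3}}})
#   assert cfg.sam2.prompting.confidence_thr == 0.3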

def filter_instances(instances: InstanceData, indices):
    """
    Return a new InstanceData containing only the entries of 'instances' at the given indices.
    """
    if instances is None:
        return None
    data = {}
    # Attributes to filter
    for attr in [
        "bboxes",
        "bbox_scores",
        "keypoints",
        "keypoint_scores",
        "scores",
        "pred_masks",
        "refined_masks",
        "sam_scores",
        "sam_kpts",
    ]:
        if hasattr(instances, attr):
            arr = getattr(instances, attr)
            data[attr] = arr[indices] if arr is not None else None
    return InstanceData(**data)


def concat_instances(instances1: InstanceData, instances2: InstanceData):
    """
    Concatenate two InstanceData objects along the first axis, preserving order.
    If instances1 or instances2 is None, returns the other.
    """
    if instances1 is None:
        return instances2
    if instances2 is None:
        return instances1
    data = {}
    for attr in [
        "bboxes",
        "bbox_scores",
        "keypoints",
        "keypoint_scores",
        "scores",
        "pred_masks",
        "refined_masks",
        "sam_scores",
        "sam_kpts",
    ]:
        arr1 = getattr(instances1, attr, None)
        arr2 = getattr(instances2, attr, None)
        if arr1 is None and arr2 is None:
            continue
        if arr1 is None:
            data[attr] = arr2
        elif arr2 is None:
            data[attr] = arr1
        else:
            data[attr] = np.concatenate([arr1, arr2], axis=0)
    return InstanceData(**data)


def _visualize_predictions(
    img: np.ndarray,
    bboxes: np.ndarray,
    scores: np.ndarray,
    masks: List[Optional[List[np.ndarray]]],
    poses: List[Optional[np.ndarray]],
    vis_type: str = "mask",
    mask_is_binary: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Render bounding boxes, segmentation masks, and poses on the input image.

    Args:
        img (np.ndarray): BGR image of shape (H, W, 3).
        bboxes (np.ndarray): Array of bounding boxes [x, y, w, h].
        scores (np.ndarray): Confidence scores for each bbox.
        masks (List[Optional[List[np.ndarray]]]): Polygon masks per instance.
        poses (List[Optional[np.ndarray]]): Keypoint arrays per instance.
        vis_type (str): Flags for visualization types separated by '+'.
        mask_is_binary (bool): Whether input masks are binary arrays.

    Returns:
        Tuple[np.ndarray, np.ndarray]: The visualized image and color map.
    """
    vis_types = vis_type.split("+")

    # # Filter-out small detections to make the visualization more clear
    # new_bboxes = []
    # new_scores = []
    # new_masks = []
    # new_poses = []
    # size_thr = img.shape[0] * img.shape[1] * 0.01
    # for bbox, score, mask, pose in zip(bboxes, scores, masks, poses):
    #     area = mask.sum()  # Assume binary mask. OK for demo purposes
    #     if area > size_thr:
    #         new_bboxes.append(bbox)
    #         new_scores.append(score)
    #         new_masks.append(mask)
    #         new_poses.append(pose)
    # bboxes = np.array(new_bboxes)
    # scores = np.array(new_scores)
    # masks = new_masks
    # poses = new_poses

    if mask_is_binary:
        poly_masks: List[Optional[List[np.ndarray]]] = []
        for binary_mask in masks:
            if binary_mask is not None:
                contours, _ = cv2.findContours(
                    (binary_mask * 255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
                )
                polys = [cnt.flatten() for cnt in contours if cv2.contourArea(cnt) >= MIN_CONTOUR_AREA]
            else:
                polys = None
            poly_masks.append(polys)
        masks = poly_masks  # type: ignore

    # Exclude green, gray, black, and white from the palette as they are not distinctive
    colors = (
        np.array(get_colors(len(bboxes), exclude_colors=[(0, 1, 0), (0.5, 0.5, 0.5), (0, 0, 0), (1, 1, 1)], rng=0)) * 255
    ).astype(int)

    if "inv-mask" in vis_types:
        stencil = np.zeros_like(img)

    for bbox, score, mask_poly, pose, color in zip(bboxes, scores, masks, poses, colors):
        bbox = _update_bbox_by_mask(list(map(int, bbox)), mask_poly, img.shape)
        color_list = color.tolist()
        img_copy = img.copy()

        if "bbox" in vis_types:
            x, y, w, h = bbox
            cv2.rectangle(img_copy, (x, y), (x + w, y + h), color_list, 2)
            img = cv2.addWeighted(img, 1 - BBOX_WEIGHT, img_copy, BBOX_WEIGHT, 0)

        if mask_poly is not None and "mask" in vis_types:
            for seg in mask_poly:
                seg_pts = np.array(seg).reshape(-1, 1, 2).astype(int)
                cv2.fillPoly(img_copy, [seg_pts], color_list)
            img = cv2.addWeighted(img, 1 - MASK_WEIGHT, img_copy, MASK_WEIGHT, 0)

        if mask_poly is not None and "mask-out" in vis_types:
            for seg in mask_poly:
                seg_pts = np.array(seg).reshape(-1, 1, 2).astype(int)
                cv2.fillPoly(img, [seg_pts], (0, 0, 0))

        if mask_poly is not None and "inv-mask" in vis_types:
            for seg in mask_poly:
                seg = np.array(seg).reshape(-1, 1, 2).astype(int)
                if cv2.contourArea(seg) < MIN_CONTOUR_AREA:
                    continue
                cv2.fillPoly(stencil, [seg], (255, 255, 255))

        if pose is not None and "pose" in vis_types:
            vis_img = pose_visualization(
                img.copy(),
                pose.reshape(-1, 3),
                width_multiplier=8,
                differ_individuals=True,
                color=color_list,
                keep_image_size=True,
            )
            img = cv2.addWeighted(img, 1 - POSE_WEIGHT, vis_img, POSE_WEIGHT, 0)

    if "inv-mask" in vis_types:
        img = cv2.addWeighted(img, 1 - BACK_MASK_WEIGHT, cv2.bitwise_and(img, stencil), BACK_MASK_WEIGHT, 0)

    return img, colors

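
# Supported vis_type flags (combined with '+'), as consumed above: "bbox",
# "mask", "mask-out" (blackens the masked regions), "inv-mask" (dims everything
# outside the masks), and "pose". For example, vis_type="bbox+mask+pose" draws
# all three overlays in one pass.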

def visualize_itteration(
    img: np.ndarray, detections: Any, iteration_idx: int, output_root: Path, img_name: str, with_text: bool = True
) -> Optional[np.ndarray]:
    """
    Generate and save visualization images for each BMP iteration.

    Args:
        img (np.ndarray): Original input image.
        detections: InstanceData containing bboxes, scores, masks, keypoints.
        iteration_idx (int): Current iteration index (0-based).
        output_root (Path): Directory to save output images.
        img_name (str): Base name of the image without extension.
        with_text (bool): Whether to overlay text labels.

    Returns:
        Optional[np.ndarray]: The masked-out image if generated, else None.
    """
    bboxes = detections.bboxes
    scores = detections.scores
    pred_masks = detections.pred_masks
    refined_masks = detections.refined_masks
    keypoints = detections.keypoints
    sam_kpts = detections.sam_kpts

    masked_out = None
    for vis_def in [
        {"type": "bbox+mask", "masks": pred_masks, "label": "Detector (out)"},
        {"type": "inv-mask", "masks": pred_masks, "label": "MaskPose (in)"},
        {"type": "inv-mask+pose", "masks": pred_masks, "label": "MaskPose (out)"},
        {"type": "mask", "masks": refined_masks, "label": "SAM Masks"},
        {"type": "mask-out", "masks": refined_masks, "label": "Mask-Out"},
        {"type": "pose", "masks": refined_masks, "label": "Final Poses"},
    ]:
        vis_img, colors = _visualize_predictions(
            img.copy(), bboxes, scores, vis_def["masks"], keypoints, vis_type=vis_def["type"], mask_is_binary=True
        )
        if vis_def["type"] == "mask-out":
            masked_out = vis_img
        if with_text:
            label = "BMP {:d}x: {}".format(iteration_idx + 1, vis_def["label"])
            cv2.putText(vis_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)
            cv2.putText(vis_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
        out_path = os.path.join(
            output_root, "{}_iter{}_{}.jpg".format(img_name, iteration_idx + 1, vis_def["label"].replace(" ", "_"))
        )
        cv2.imwrite(str(out_path), vis_img)

    # Show prompting keypoints
    tmp_img = img.copy()
    for i, _ in enumerate(bboxes):
        if len(sam_kpts[i]) > 0:
            instance_color = colors[i].astype(int).tolist()
            for kpt in sam_kpts[i]:
                cv2.drawMarker(
                    tmp_img,
                    (int(kpt[0]), int(kpt[1])),
                    instance_color,
                    markerType=cv2.MARKER_CROSS,
                    markerSize=20,
                    thickness=3,
                )
                # Write the keypoint confidence next to the marker
                cv2.putText(
                    tmp_img,
                    f"{kpt[2]:.2f}",
                    (int(kpt[0]) + 10, int(kpt[1]) - 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    instance_color,
                    1,
                    cv2.LINE_AA,
                )
    if with_text:
        text = "BMP {:d}x: SAM prompts".format(iteration_idx + 1)
        cv2.putText(tmp_img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3, cv2.LINE_AA)
        cv2.putText(tmp_img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2, cv2.LINE_AA)
    cv2.imwrite("{:s}/{:s}_iter{:d}_prompting_kpts.jpg".format(output_root, img_name, iteration_idx + 1), tmp_img)

    return masked_out


def visualize_demo(
    img: np.ndarray, detections: Any,
) -> List[np.ndarray]:
    """
    Generate the three demo visualizations for the current detections.

    Args:
        img (np.ndarray): Original input image.
        detections: InstanceData containing bboxes, scores, masks, keypoints.

    Returns:
        List[np.ndarray]: [masked-out image, detector visualization, BMP visualization].
    """
    bboxes = detections.bboxes
    scores = detections.scores
    pred_masks = detections.pred_masks
    refined_masks = detections.refined_masks
    keypoints = detections.keypoints

    returns = []
    for vis_def in [
        {"type": "mask-out", "masks": refined_masks, "label": ""},
        {"type": "mask+pose", "masks": pred_masks, "label": "RTMDet-L"},
        {"type": "mask+pose", "masks": refined_masks, "label": "BMP"},
    ]:
        vis_img, colors = _visualize_predictions(
            img.copy(), bboxes, scores, vis_def["masks"], keypoints, vis_type=vis_def["type"], mask_is_binary=True
        )
        returns.append(vis_img)

    return returns


def create_GIF(
    img_path: Path,
    output_root: Path,
    bmp_x: int = 2,
) -> None:
    """
    Compile iteration images into an animated GIF using ffmpeg.

    Args:
        img_path (Path): Path to a sample iteration image.
        output_root (Path): Directory to save the GIF.
        bmp_x (int): Number of BMP iterations.

    Logs a warning and returns early if ffmpeg is not available or images are missing.
    """
    display_dur = 1.5  # seconds
    fade_dur = 1.0
    fps = 10
    scale_width = 300  # Resize width for GIF; height is auto-scaled to maintain aspect ratio

    # Check if ffmpeg is installed. If not, raise a warning and return
    if shutil.which("ffmpeg") is None:
        print_log("FFmpeg is not installed. GIF creation will be skipped.", logger="current", level=logging.WARNING)
        return
    print_log("Creating GIF with FFmpeg...", logger="current")

    dirname, filename = os.path.split(img_path)
    img_name_wo_ext, _ = os.path.splitext(filename)

    gif_image_names = [
        "Detector_(out)",
        "MaskPose_(in)",
        "MaskPose_(out)",
        "prompting_kpts",
        "SAM_Masks",
        "Mask-Out",
    ]

    # Create a black image of the same size as the last image
    last_img_path = os.path.join(dirname, "{}_iter1_{}".format(img_name_wo_ext, gif_image_names[0]) + ".jpg")
    last_img = cv2.imread(last_img_path)
    if last_img is None:
        print_log("Could not read image {}.".format(last_img_path), logger="current", level=logging.ERROR)
        return
    black_img = np.zeros_like(last_img)
    cv2.imwrite(os.path.join(dirname, "black_image.jpg"), black_img)

    gif_images = []
    for iter in range(bmp_x):
        iter_img_path = os.path.join(dirname, "{}_iter{}_".format(img_name_wo_ext, iter + 1))
        for img_name in gif_image_names:

            if iter + 1 == bmp_x and img_name == "Mask-Out":
                # Skip the last iteration's Mask-Out image
                continue

            img_file = "{}{}.jpg".format(iter_img_path, img_name)
            if not os.path.exists(img_file):
                print_log("{} does not exist, skipping.".format(img_file), logger="current", level=logging.WARNING)
                continue
            gif_images.append(img_file)

    if len(gif_images) == 0:
        print_log("No images found for GIF creation.", logger="current", level=logging.WARNING)
        return

    # Add 'before' and 'after' images
    after1_img = os.path.join(dirname, "{}_iter{}_Final_Poses.jpg".format(img_name_wo_ext, bmp_x))
    after2_img = os.path.join(dirname, "{}_iter{}_SAM_Masks.jpg".format(img_name_wo_ext, bmp_x))
    # gif_images.append(os.path.join(dirname, "black_image.jpg"))  # Add black image at the end
    gif_images.append(after1_img)
    gif_images.append(after2_img)
    gif_images.append(os.path.join(dirname, "black_image.jpg"))  # Add black image at the end

    # Create a GIF from the images
    gif_output_path = os.path.join(output_root, "{}_bmp_{}x.gif".format(img_name_wo_ext, bmp_x))

    # 0. Make sure images exist and are divisible by 2
    for img in gif_images:
        if not os.path.exists(img):
            print_log("Image {} does not exist, skipping GIF creation.".format(img), logger="current", level=logging.WARNING)
            return
        # Check if image dimensions are divisible by 2
        img_data = cv2.imread(img)
        if img_data.shape[1] % 2 != 0 or img_data.shape[0] % 2 != 0:
            print_log(
                "Image {} dimensions are not divisible by 2, resizing.".format(img),
                logger="current",
                level=logging.WARNING,
            )
            resized_img = cv2.resize(img_data, (img_data.shape[1] // 2 * 2, img_data.shape[0] // 2 * 2))
            cv2.imwrite(img, resized_img)

    # 1. inputs
    in_args = []
    for p in gif_images:
        in_args += ["-loop", "1", "-t", str(display_dur), "-i", p]

    # 2. build xfade chain
    n = len(gif_images)
    parts = []
    for i in range(1, n):
        # left label: first is input [0:v], then [v1], [v2], ...
        left = "[{}:v]".format(i - 1) if i == 1 else "[v{}]".format(i - 1)
        right = "[{}:v]".format(i)
        out = "[v{}]".format(i)
        offset = (i - 1) * (display_dur + fade_dur) + display_dur
        parts.append(
            "{}{}xfade=transition=fade:".format(left, right)
            + "duration={}:offset={:.3f}{}".format(fade_dur, offset, out)
        )
    filter_complex = ";".join(parts)

    # 3. make MP4 slideshow
    mp4 = "slideshow.mp4"
    cmd1 = [
        "ffmpeg",
        "-loglevel",
        "error",
        "-v",
        "quiet",
        "-hide_banner",
        "-y",
        *in_args,
        "-filter_complex",
        filter_complex,
        "-map",
        "[v{}]".format(n - 1),
        "-c:v",
        "libx264",
        "-pix_fmt",
        "yuv420p",
        mp4,
    ]
    subprocess.run(cmd1, check=True)

    # 4. palette
    palette = "palette.png"
    vf = "fps={}".format(fps)
    if scale_width:
        vf += ",scale={}:-1:flags=lanczos".format(scale_width)

    # 5. generate palette
    subprocess.run(
        [
            "ffmpeg",
            "-loglevel",
            "error",
            "-v",
            "quiet",
            "-hide_banner",
            "-y",
            "-i",
            mp4,
            "-vf",
            vf + ",palettegen",
            palette,
        ],
        check=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.PIPE,
    )

    # 6. build final GIF
    subprocess.run(
        [
            "ffmpeg",
            "-loglevel",
            "error",
            "-v",
            "quiet",
            "-hide_banner",
            "-y",
            "-i",
            mp4,
            "-i",
            palette,
            "-lavfi",
            vf + "[x];[x][1:v]paletteuse",
            gif_output_path,
        ],
        check=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.PIPE,
    )

    # Clean up temporary files
    os.remove(mp4)
    os.remove(palette)
    os.remove(os.path.join(dirname, "black_image.jpg"))

    print_log(f"GIF saved as '{gif_output_path}'", logger="current")

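
# For reference, a sketch of the filtergraph the loop above generates for three
# inputs (display_dur=1.5, fade_dur=1.0):
#
#   [0:v][1:v]xfade=transition=fade:duration=1.0:offset=1.500[v1];
#   [v1][2:v]xfade=transition=fade:duration=1.0:offset=4.000[v2]
#
# i.e., each crossfade starts after the previous frame has been displayed for
# display_dur seconds, hence offset = (i - 1) * (display_dur + fade_dur) + display_dur.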

def _update_bbox_by_mask(
    bbox: List[int], mask_poly: Optional[List[List[int]]], image_shape: Tuple[int, int, int]
) -> List[int]:
    """
    Adjust the bounding box to tightly fit the mask polygon.

    Args:
        bbox (List[int]): Original [x, y, w, h].
        mask_poly (Optional[List[List[int]]]): Polygon coordinates.
        image_shape (Tuple[int, int, int]): Image shape (H, W, C).

    Returns:
        List[int]: Updated [x, y, w, h] bounding box.
    """
    if mask_poly is None or len(mask_poly) == 0:
        return bbox

    mask_rle = Mask.frPyObjects(mask_poly, image_shape[0], image_shape[1])
    mask_rle = Mask.merge(mask_rle)
    bbox_segm_xywh = Mask.toBbox(mask_rle)
    bbox_segm_xyxy = np.array(
        [
            bbox_segm_xywh[0],
            bbox_segm_xywh[1],
            bbox_segm_xywh[0] + bbox_segm_xywh[2],
            bbox_segm_xywh[1] + bbox_segm_xywh[3],
        ]
    )

    bbox = bbox_segm_xywh

    return bbox.astype(int).tolist()


def pose_nms(config: Any, image_kpts: np.ndarray, image_bboxes: np.ndarray, num_valid_kpts: np.ndarray) -> np.ndarray:
    """
    Perform OKS-based non-maximum suppression on detected poses.

    Args:
        config (Any): Configuration with confidence_thr and oks_thr.
        image_kpts (np.ndarray): Detected keypoints of shape (N, K, 3).
        image_bboxes (np.ndarray): Corresponding bboxes (N, 4).
        num_valid_kpts (np.ndarray): Count of valid keypoints per instance.

    Returns:
        np.ndarray: Indices of kept instances.
    """
    # Sort image kpts by average score - lowest first
    # scores = image_kpts[:, :, 2].mean(axis=1)
    # sort_idx = np.argsort(scores)
    # image_kpts = image_kpts[sort_idx, :, :]

    # Compute OKS between all pairs of poses
    oks_matrix = np.zeros((image_kpts.shape[0], image_kpts.shape[0]))
    for i in range(image_kpts.shape[0]):
        for j in range(image_kpts.shape[0]):
            gt_bbox_xywh = image_bboxes[i].copy()
            gt_bbox_xyxy = gt_bbox_xywh.copy()
            gt_bbox_xyxy[2:] += gt_bbox_xyxy[:2]
            gt = {
                "keypoints": image_kpts[i].copy(),
                "bbox": gt_bbox_xyxy,
                "area": gt_bbox_xywh[2] * gt_bbox_xywh[3],
            }
            dt = {"keypoints": image_kpts[j].copy(), "bbox": gt_bbox_xyxy}
            gt["keypoints"][:, 2] = (gt["keypoints"][:, 2] > config.confidence_thr) * 2
            oks = compute_oks(gt, dt)
            if oks > 1:
                # OKS is bounded by 1; clip numerical overshoot instead of halting the demo.
                oks = 1.0
            oks_matrix[i, j] = oks

    np.fill_diagonal(oks_matrix, -1)
    is_subset = oks_matrix > config.oks_thr

    remove_instances = []
    while is_subset.any():
        # Find the pair with the highest OKS
        i, j = np.unravel_index(np.argmax(oks_matrix), oks_matrix.shape)

        # Keep the one with the higher number of valid keypoints
        if num_valid_kpts[i] > num_valid_kpts[j]:
            remove_idx = j
        else:
            remove_idx = i

        # Remove the column from is_subset
        oks_matrix[:, remove_idx] = 0
        oks_matrix[remove_idx, j] = 0
        remove_instances.append(remove_idx)
        is_subset = oks_matrix > config.oks_thr

    keep_instances = np.setdiff1d(np.arange(image_kpts.shape[0]), remove_instances)

    return keep_instances

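
# A minimal usage sketch of pose_nms (illustrative): the config only needs the
# two thresholds, matching how app.py and bmp_demo.py call it.
#
#   nms_cfg = DotDict({"confidence_thr": 0.3, "oks_thr": 0.8})
#   num_valid = np.sum(kpts[:, :, 2] > nms_cfg.confidence_thr, axis=1)  # kpts: (N, 17, 3)
#   keep = pose_nms(nms_cfg, image_kpts=kpts, image_bboxes=bboxes_xywh, num_valid_kpts=num_valid)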

def compute_oks(gt: Dict[str, Any], dt: Dict[str, Any], use_area: bool = True, per_kpt: bool = False) -> float:
    """
    Compute Object Keypoint Similarity (OKS) between ground-truth and detected poses.

    Args:
        gt (Dict): Ground-truth keypoints and bbox info.
        dt (Dict): Detected keypoints and bbox info.
        use_area (bool): Whether to normalize by GT area.
        per_kpt (bool): Whether to return per-keypoint OKS array.

    Returns:
        float: OKS score or mean OKS.
    """
    sigmas = (
        np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
        / 10.0
    )
    vars = (sigmas * 2) ** 2
    k = len(sigmas)
    visibility_condition = lambda x: x > 0
    g = np.array(gt["keypoints"]).reshape(k, 3)
    xg = g[:, 0]
    yg = g[:, 1]
    vg = g[:, 2]
    k1 = np.count_nonzero(visibility_condition(vg))
    bb = gt["bbox"]
    x0 = bb[0] - bb[2]
    x1 = bb[0] + bb[2] * 2
    y0 = bb[1] - bb[3]
    y1 = bb[1] + bb[3] * 2

    d = np.array(dt["keypoints"]).reshape((k, 3))
    xd = d[:, 0]
    yd = d[:, 1]

    if k1 > 0:
        # measure the per-keypoint distance if keypoints visible
        dx = xd - xg
        dy = yd - yg

    else:
        # measure minimum distance to keypoints in (x0,y0) & (x1,y1)
        z = np.zeros((k))
        dx = np.max((z, x0 - xd), axis=0) + np.max((z, xd - x1), axis=0)
        dy = np.max((z, y0 - yd), axis=0) + np.max((z, yd - y1), axis=0)

    if use_area:
        e = (dx**2 + dy**2) / vars / (gt["area"] + np.spacing(1)) / 2
    else:
        tmparea = gt["bbox"][3] * gt["bbox"][2] * 0.53
        e = (dx**2 + dy**2) / vars / (tmparea + np.spacing(1)) / 2

    if per_kpt:
        oks = np.exp(-e)
        if k1 > 0:
            oks[~visibility_condition(vg)] = 0

    else:
        if k1 > 0:
            e = e[visibility_condition(vg)]
        oks = np.sum(np.exp(-e)) / e.shape[0]

    return oks
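
# A self-contained sanity check of compute_oks on synthetic data (not part of the
# original demo; keypoint coordinates and the 2-pixel jitter below are made up).
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    gt_kpts = np.concatenate(
        [rng.uniform(0, 100, size=(17, 2)), np.full((17, 1), 2.0)], axis=1
    )  # columns: (x, y, visibility=2)
    dt_kpts = gt_kpts.copy()
    dt_kpts[:, :2] += rng.normal(scale=2.0, size=(17, 2))  # jitter detections by ~2 px
    gt = {"keypoints": gt_kpts, "bbox": [0, 0, 100, 100], "area": 100 * 100}
    dt = {"keypoints": dt_kpts, "bbox": [0, 0, 100, 100]}
    print("OKS of a jittered copy:", compute_oks(gt, dt))  # close to 1.0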
demo/mm_utils.py
ADDED
@@ -0,0 +1,106 @@

"""
This module provides high-level interfaces to run MMDetection and MMPose
models sequentially. Users can call run_MMDetector and run_MMPose from
other scripts (e.g., bmp_demo.py) to perform object detection and
pose estimation in a clean, modular fashion.
"""

import numpy as np
from mmdet.apis import inference_detector
from mmengine.structures import InstanceData

from mmpose.apis import inference_topdown
from mmpose.evaluation.functional import nms
from mmpose.structures import merge_data_samples


def run_MMDetector(detector, image, det_cat_id: int = 0, bbox_thr: float = 0.3, nms_thr: float = 0.3) -> InstanceData:
    """
    Run an MMDetection model to detect bounding boxes (and masks) in an image.

    Args:
        detector: An initialized MMDetection detector model.
        image: Input image as file path or BGR numpy array.
        det_cat_id: Category ID to filter detections (default is 0 for 'person').
        bbox_thr: Minimum bounding box score threshold.
        nms_thr: IoU threshold for Non-Maximum Suppression (NMS).

    Returns:
        InstanceData: A structure containing filtered bboxes, bbox_scores, and masks (if available).
    """
    # Run detection
    det_result = inference_detector(detector, image)
    pred_instances = det_result.pred_instances.cpu().numpy()

    # Aggregate bboxes and scores into an (N, 5) array
    bboxes_all = np.concatenate((pred_instances.bboxes, pred_instances.scores[:, None]), axis=1)

    # Filter by category and score
    keep_mask = np.logical_and(pred_instances.labels == det_cat_id, pred_instances.scores > bbox_thr)
    if not np.any(keep_mask):
        # Return empty structure if nothing passes threshold
        return InstanceData(bboxes=np.zeros((0, 4)), bbox_scores=np.zeros((0,)), masks=np.zeros((0, 1, 1)))

    bboxes = bboxes_all[keep_mask]
    masks = getattr(pred_instances, "masks", None)
    if masks is not None:
        masks = masks[keep_mask]

    # Sort detections by descending score
    order = np.argsort(bboxes[:, 4])[::-1]
    bboxes = bboxes[order]
    if masks is not None:
        masks = masks[order]

    # Apply Non-Maximum Suppression
    keep_indices = nms(bboxes, nms_thr)
    bboxes = bboxes[keep_indices]
    if masks is not None:
        masks = masks[keep_indices]

    # Construct InstanceData to return
    det_instances = InstanceData(bboxes=bboxes[:, :4], bbox_scores=bboxes[:, 4], masks=masks)
    return det_instances
|
| 66 |
+
def run_MMPose(pose_estimator, image, detections: InstanceData, kpt_thr: float = 0.3) -> InstanceData:
|
| 67 |
+
"""
|
| 68 |
+
Run an MMPose top-down model to estimate human pose given detected bounding boxes.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
pose_estimator: An initialized MMPose model.
|
| 72 |
+
image: Input image as file path or RGB/BGR numpy array.
|
| 73 |
+
detections: InstanceData from run_MMDetector containing bboxes and masks.
|
| 74 |
+
kpt_thr: Minimum keypoint score threshold to filter low-confidence joints.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
InstanceData: A structure containing estimated keypoints, keypoint_scores,
|
| 78 |
+
original bboxes, and masks (if provided).
|
| 79 |
+
"""
|
| 80 |
+
# Extract bounding boxes
|
| 81 |
+
bboxes = detections.bboxes
|
| 82 |
+
if bboxes.shape[0] == 0:
|
| 83 |
+
# No detections => empty pose data
|
| 84 |
+
return InstanceData(
|
| 85 |
+
keypoints=np.zeros((0, 17, 3)),
|
| 86 |
+
keypoint_scores=np.zeros((0, 17)),
|
| 87 |
+
bboxes=bboxes,
|
| 88 |
+
bbox_scores=detections.bbox_scores,
|
| 89 |
+
masks=detections.masks,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# Run top-down pose estimation
|
| 93 |
+
pose_results = inference_topdown(pose_estimator, image, bboxes, masks=detections.masks)
|
| 94 |
+
data_samples = merge_data_samples(pose_results)
|
| 95 |
+
|
| 96 |
+
# Attach masks back into the data_samples if available
|
| 97 |
+
if detections.masks is not None:
|
| 98 |
+
data_samples.pred_instances.pred_masks = detections.masks
|
| 99 |
+
|
| 100 |
+
# Filter out low-confidence keypoints
|
| 101 |
+
kp_scores = data_samples.pred_instances.keypoint_scores
|
| 102 |
+
kp_mask = kp_scores >= kpt_thr
|
| 103 |
+
# data_samples.pred_instances.keypoints[~kp_mask] = [0, 0, 0]
|
| 104 |
+
|
| 105 |
+
# Return final InstanceData for poses
|
| 106 |
+
return data_samples.pred_instances
|
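Together, run_MMDetector and run_MMPose form the detect-then-pose stage that bmp_demo.py builds on. A minimal sketch of chaining them; the config and checkpoint paths below are placeholders rather than files shipped with this commit:

import cv2
from mmdet.apis import init_detector
from mmpose.apis import init_model

from demo.mm_utils import run_MMDetector, run_MMPose

# Placeholder paths; substitute real MMDetection/MMPose configs and weights.
detector = init_detector("det_config.py", "det_checkpoint.pth", device="cuda:0")
pose_estimator = init_model("pose_config.py", "pose_checkpoint.pth", device="cuda:0")

image = cv2.imread("demo/posevis_test.jpg")  # any BGR test image works
dets = run_MMDetector(detector, image, det_cat_id=0, bbox_thr=0.3, nms_thr=0.3)
poses = run_MMPose(pose_estimator, image, dets, kpt_thr=0.3)
print(len(poses.bboxes), "people detected")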
demo/posevis_lite.py
ADDED
@@ -0,0 +1,507 @@
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np

NEUTRAL_COLOR = (52, 235, 107)

LEFT_ARM_COLOR = (216, 235, 52)
LEFT_LEG_COLOR = (235, 107, 52)
LEFT_SIDE_COLOR = (245, 188, 113)
LEFT_FACE_COLOR = (235, 52, 107)

RIGHT_ARM_COLOR = (52, 235, 216)
RIGHT_LEG_COLOR = (52, 107, 235)
RIGHT_SIDE_COLOR = (52, 171, 235)
RIGHT_FACE_COLOR = (107, 52, 235)

COCO_MARKERS = [
    ["nose", cv2.MARKER_CROSS, NEUTRAL_COLOR],
    ["left_eye", cv2.MARKER_SQUARE, LEFT_FACE_COLOR],
    ["right_eye", cv2.MARKER_SQUARE, RIGHT_FACE_COLOR],
    ["left_ear", cv2.MARKER_CROSS, LEFT_FACE_COLOR],
    ["right_ear", cv2.MARKER_CROSS, RIGHT_FACE_COLOR],
    ["left_shoulder", cv2.MARKER_TRIANGLE_UP, LEFT_ARM_COLOR],
    ["right_shoulder", cv2.MARKER_TRIANGLE_UP, RIGHT_ARM_COLOR],
    ["left_elbow", cv2.MARKER_SQUARE, LEFT_ARM_COLOR],
    ["right_elbow", cv2.MARKER_SQUARE, RIGHT_ARM_COLOR],
    ["left_wrist", cv2.MARKER_CROSS, LEFT_ARM_COLOR],
    ["right_wrist", cv2.MARKER_CROSS, RIGHT_ARM_COLOR],
    ["left_hip", cv2.MARKER_TRIANGLE_UP, LEFT_LEG_COLOR],
    ["right_hip", cv2.MARKER_TRIANGLE_UP, RIGHT_LEG_COLOR],
    ["left_knee", cv2.MARKER_SQUARE, LEFT_LEG_COLOR],
    ["right_knee", cv2.MARKER_SQUARE, RIGHT_LEG_COLOR],
    ["left_ankle", cv2.MARKER_TILTED_CROSS, LEFT_LEG_COLOR],
    ["right_ankle", cv2.MARKER_TILTED_CROSS, RIGHT_LEG_COLOR],
]


COCO_SKELETON = [
    [[16, 14], LEFT_LEG_COLOR],  # Left ankle - Left knee
    [[14, 12], LEFT_LEG_COLOR],  # Left knee - Left hip
    [[17, 15], RIGHT_LEG_COLOR],  # Right ankle - Right knee
    [[15, 13], RIGHT_LEG_COLOR],  # Right knee - Right hip
    [[12, 13], NEUTRAL_COLOR],  # Left hip - Right hip
    [[6, 12], LEFT_SIDE_COLOR],  # Left hip - Left shoulder
    [[7, 13], RIGHT_SIDE_COLOR],  # Right hip - Right shoulder
    [[6, 7], NEUTRAL_COLOR],  # Left shoulder - Right shoulder
    [[6, 8], LEFT_ARM_COLOR],  # Left shoulder - Left elbow
    [[7, 9], RIGHT_ARM_COLOR],  # Right shoulder - Right elbow
    [[8, 10], LEFT_ARM_COLOR],  # Left elbow - Left wrist
    [[9, 11], RIGHT_ARM_COLOR],  # Right elbow - Right wrist
    [[2, 3], NEUTRAL_COLOR],  # Left eye - Right eye
    [[1, 2], LEFT_FACE_COLOR],  # Nose - Left eye
    [[1, 3], RIGHT_FACE_COLOR],  # Nose - Right eye
    [[2, 4], LEFT_FACE_COLOR],  # Left eye - Left ear
    [[3, 5], RIGHT_FACE_COLOR],  # Right eye - Right ear
    [[4, 6], LEFT_FACE_COLOR],  # Left ear - Left shoulder
    [[5, 7], RIGHT_FACE_COLOR],  # Right ear - Right shoulder
]


def _draw_line(
    img: np.ndarray,
    start: Tuple[float, float],
    stop: Tuple[float, float],
    color: Tuple[int, int, int],
    line_type: str,
    thickness: int = 1,
) -> np.ndarray:
    """
    Draw a line segment on an image, supporting solid, dashed, or dotted styles.

    Args:
        img (np.ndarray): BGR image of shape (H, W, 3).
        start (tuple of float): (x, y) start coordinates.
        stop (tuple of float): (x, y) end coordinates.
        color (tuple of int): BGR color values.
        line_type (str): One of 'solid', 'dashed', or 'doted'.
        thickness (int): Line thickness in pixels.

    Returns:
        np.ndarray: Image with the line drawn.
    """
    start = np.array(start)[:2]
    stop = np.array(stop)[:2]
    if line_type.lower() == "solid":
        # Draw a black outline first, then the colored line on top
        img = cv2.line(
            img,
            (int(start[0]), int(start[1])),
            (int(stop[0]), int(stop[1])),
            color=(0, 0, 0),
            thickness=thickness + 1,
            lineType=cv2.LINE_AA,
        )
        img = cv2.line(
            img,
            (int(start[0]), int(start[1])),
            (int(stop[0]), int(stop[1])),
            color=color,
            thickness=thickness,
            lineType=cv2.LINE_AA,
        )
    elif line_type.lower() == "dashed":
        delta = stop - start
        length = np.linalg.norm(delta)
        frac = np.linspace(0, 1, num=int(length / 5), endpoint=True)
        for i in range(0, len(frac) - 1, 2):
            s = start + frac[i] * delta
            e = start + frac[i + 1] * delta
            img = cv2.line(
                img,
                (int(s[0]), int(s[1])),
                (int(e[0]), int(e[1])),
                color=color,
                thickness=thickness,
                lineType=cv2.LINE_AA,
            )
    elif line_type.lower() == "doted":
        delta = stop - start
        length = np.linalg.norm(delta)
        frac = np.linspace(0, 1, num=int(length / 5), endpoint=True)
        for i in range(0, len(frac)):
            s = start + frac[i] * delta
            img = cv2.circle(
                img,
                (int(s[0]), int(s[1])),
                radius=max(thickness // 2, 1),
                color=color,
                thickness=-1,
                lineType=cv2.LINE_AA,
            )
    return img


def pose_visualization(
    img: Union[str, np.ndarray],
    keypoints: Union[Dict[str, Any], np.ndarray],
    format: str = "COCO",
    greyness: float = 1.0,
    show_markers: bool = True,
    show_bones: bool = True,
    line_type: str = "solid",
    width_multiplier: float = 1.0,
    bbox_width_multiplier: float = 1.0,
    show_bbox: bool = False,
    differ_individuals: bool = False,
    confidence_thr: float = 0.3,
    errors: Optional[np.ndarray] = None,
    color: Optional[Tuple[int, int, int]] = None,
    keep_image_size: bool = False,
    return_padding: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, List[int]]]:
    """
    Overlay pose keypoints and skeleton on an image.

    Args:
        img (str or np.ndarray): Path to image file or BGR image array.
        keypoints (dict or np.ndarray): Either a dict with 'bbox' and 'keypoints' or
            an array of shape (17, 2 or 3) or multiple poses stacked.
        format (str): Keypoint format, currently only 'COCO'.
        greyness (float): Factor for bone/marker color intensity (0.0-1.0).
        show_markers (bool): Whether to draw keypoint markers.
        show_bones (bool): Whether to draw skeleton bones.
        line_type (str): One of 'solid', 'dashed', 'doted' for bone style.
        width_multiplier (float): Line width scaling factor for bones.
        bbox_width_multiplier (float): Line width scaling factor for bounding box.
        show_bbox (bool): Whether to draw bounding box around keypoints.
        differ_individuals (bool): Use distinct color per individual pose.
        confidence_thr (float): Confidence threshold for keypoint visibility.
        errors (np.ndarray or None): Optional array of per-kpt errors (17, 1).
        color (tuple or None): Override color for markers and bones.
        keep_image_size (bool): Prevent image padding for out-of-bounds keypoints.
        return_padding (bool): If True, also return padding offsets [top, bottom, left, right].

    Returns:
        np.ndarray or (np.ndarray, list of int): Annotated image, and optional
            padding offsets if `return_padding` is True.
    """

    bbox = None
    if isinstance(keypoints, dict):
        try:
            bbox = np.array(keypoints["bbox"]).flatten()
        except KeyError:
            pass
        keypoints = np.array(keypoints["keypoints"])

    # If keypoints is a list of poses, draw them all
    if len(keypoints) % 17 != 0 or keypoints.ndim == 3:

        if color is not None:
            if not isinstance(color, (list, tuple)):
                color = [color for keypoint in keypoints]
        else:
            color = [None for keypoint in keypoints]

        max_padding = [0, 0, 0, 0]
        for keypoint, clr in zip(keypoints, color):
            img = pose_visualization(
                img,
                keypoint,
                format=format,
                greyness=greyness,
                show_markers=show_markers,
                show_bones=show_bones,
                line_type=line_type,
                width_multiplier=width_multiplier,
                bbox_width_multiplier=bbox_width_multiplier,
                show_bbox=show_bbox,
                differ_individuals=differ_individuals,
                color=clr,
                confidence_thr=confidence_thr,
                keep_image_size=keep_image_size,
                return_padding=return_padding,
            )
            if return_padding:
                img, padding = img
                max_padding = [max(max_padding[i], int(padding[i])) for i in range(4)]

        if return_padding:
            return img, max_padding
        else:
            return img

    keypoints = np.array(keypoints).reshape(17, -1)
    # If keypoint visibility is not provided, assume all keypoints are visible
    if keypoints.shape[1] == 2:
        keypoints = np.hstack([keypoints, np.ones((17, 1)) * 2])

    assert keypoints.shape[1] == 3, "Keypoints should be in the format (x, y, visibility)"
    assert keypoints.shape[0] == 17, "Keypoints should be in the format (x, y, visibility)"

    if errors is not None:
        errors = np.array(errors).reshape(17, -1)
        assert errors.shape[1] == 1, "Errors should be in the format (K, r)"
        assert errors.shape[0] == 17, "Errors should be in the format (K, r)"
    else:
        errors = np.ones((17, 1)) * np.nan

    # If keypoint visibility is a float between 0 and 1, it is a detection:
    # if conf < confidence_thr, set conf = 1; otherwise set conf = 2
    vis_is_float = np.any(np.logical_and(keypoints[:, -1] > 0, keypoints[:, -1] < 1))
    if keypoints.shape[1] == 3 and vis_is_float:
        lower_idx = keypoints[:, -1] < confidence_thr
        keypoints[lower_idx, -1] = 1
        keypoints[~lower_idx, -1] = 2

    # All visibility values should be ints
    keypoints[:, -1] = keypoints[:, -1].astype(int)

    if isinstance(img, str):
        img = cv2.imread(img)

    if img is None:
        if return_padding:
            return None, [0, 0, 0, 0]
        else:
            return None

    if not (keypoints[:, 2] > 0).any():
        if return_padding:
            return img, [0, 0, 0, 0]
        else:
            return img

    valid_kpts = (keypoints[:, 0] > 0) & (keypoints[:, 1] > 0)
    num_valid_kpts = np.sum(valid_kpts)

    if num_valid_kpts == 0:
        if return_padding:
            return img, [0, 0, 0, 0]
        else:
            return img

    min_x_kpts = np.min(keypoints[keypoints[:, 2] > 0, 0])
    min_y_kpts = np.min(keypoints[keypoints[:, 2] > 0, 1])
    max_x_kpts = np.max(keypoints[keypoints[:, 2] > 0, 0])
    max_y_kpts = np.max(keypoints[keypoints[:, 2] > 0, 1])
    if bbox is None:
        min_x = min_x_kpts
        min_y = min_y_kpts
        max_x = max_x_kpts
        max_y = max_y_kpts
    else:
        min_x = bbox[0]
        min_y = bbox[1]
        max_x = bbox[2]
        max_y = bbox[3]

    max_area = (max_x - min_x) * (max_y - min_y)
    diagonal = np.sqrt((max_x - min_x) ** 2 + (max_y - min_y) ** 2)
    line_width = max(int(np.sqrt(max_area) / 500 * width_multiplier), 1)
    bbox_line_width = max(int(np.sqrt(max_area) / 500 * bbox_width_multiplier), 1)
    marker_size = max(int(np.sqrt(max_area) / 80), 1)
    invisible_marker_size = max(int(np.sqrt(max_area) / 100), 1)
    marker_thickness = max(int(np.sqrt(max_area) / 100), 1)

    if differ_individuals:
        if color is not None:
            instance_color = color
        else:
            instance_color = np.random.randint(0, 255, size=(3,)).tolist()
        instance_color = tuple(instance_color)

    # Default padding; also returned when keep_image_size is True
    # (the original code left `padding` undefined in that case)
    padding = [0, 0, 0, 0]

    # Pad image with dark gray if keypoints are outside the image
    if not keep_image_size:
        padding = [
            max(0, -min_y_kpts),
            max(0, max_y_kpts - img.shape[0]),
            max(0, -min_x_kpts),
            max(0, max_x_kpts - img.shape[1]),
        ]
        padding = [int(p) for p in padding]
        img = cv2.copyMakeBorder(
            img,
            padding[0],
            padding[1],
            padding[2],
            padding[3],
            cv2.BORDER_CONSTANT,
            value=(80, 80, 80),
        )

        # Add padding to bbox and kpts
        value_x_to_add = max(0, -min_x_kpts)
        value_y_to_add = max(0, -min_y_kpts)
        keypoints[keypoints[:, 2] > 0, 0] += value_x_to_add
        keypoints[keypoints[:, 2] > 0, 1] += value_y_to_add
        if bbox is not None:
            bbox[0] += value_x_to_add
            bbox[1] += value_y_to_add
            bbox[2] += value_x_to_add
            bbox[3] += value_y_to_add

    if show_bbox and bbox is not None:
        pts = [
            (bbox[0], bbox[1]),
            (bbox[0], bbox[3]),
            (bbox[2], bbox[3]),
            (bbox[2], bbox[1]),
            (bbox[0], bbox[1]),
        ]
        for i in range(len(pts) - 1):
            if differ_individuals:
                img = _draw_line(img, pts[i], pts[i + 1], instance_color, "doted", thickness=bbox_line_width)
            else:
                img = _draw_line(img, pts[i], pts[i + 1], (0, 255, 0), line_type, thickness=bbox_line_width)

    if show_markers:
        for kpt, marker_info, err in zip(keypoints, COCO_MARKERS, errors):
            if kpt[0] == 0 and kpt[1] == 0:
                continue

            if kpt[2] != 2:
                color = (140, 140, 140)
            elif differ_individuals:
                color = instance_color
            else:
                color = marker_info[2]

            if kpt[2] == 1:
                # Draw low-confidence keypoints semi-transparently
                img_overlay = img.copy()
                img_overlay = cv2.drawMarker(
                    img_overlay,
                    (int(kpt[0]), int(kpt[1])),
                    color=color,
                    markerType=marker_info[1],
                    markerSize=marker_size,
                    thickness=marker_thickness,
                )
                img = cv2.addWeighted(img_overlay, 0.4, img, 0.6, 0)
            else:
                img = cv2.drawMarker(
                    img,
                    (int(kpt[0]), int(kpt[1])),
                    color=color,
                    markerType=marker_info[1],
                    markerSize=invisible_marker_size if kpt[2] == 1 else marker_size,
                    thickness=marker_thickness,
                )

            if not np.isnan(err).any():
                radius = err * diagonal
                clr = (0, 0, 255) if "solid" in line_type else (0, 255, 0)
                plus = 1 if "solid" in line_type else -1
                img = cv2.circle(
                    img,
                    (int(kpt[0]), int(kpt[1])),
                    radius=int(radius),
                    color=clr,
                    thickness=1,
                    lineType=cv2.LINE_AA,
                )
                dx = np.sqrt(radius**2 / 2)
                img = cv2.line(
                    img,
                    (int(kpt[0]), int(kpt[1])),
                    (int(kpt[0] + plus * dx), int(kpt[1] - dx)),
                    color=clr,
                    thickness=1,
                    lineType=cv2.LINE_AA,
                )

    if show_bones:
        for bone_info in COCO_SKELETON:
            # Skeleton indices are 1-based; keypoint rows are 0-based
            kp1 = keypoints[bone_info[0][0] - 1, :]
            kp2 = keypoints[bone_info[0][1] - 1, :]

            if (kp1[0] == 0 and kp1[1] == 0) or (kp2[0] == 0 and kp2[1] == 0):
                continue

            dashed = kp1[2] == 1 or kp2[2] == 1

            if differ_individuals:
                color = np.array(instance_color)
            else:
                color = np.array(bone_info[1])
            color = (color * greyness).astype(int).tolist()

            if dashed:
                img_overlay = img.copy()
                img_overlay = _draw_line(img_overlay, kp1, kp2, color, line_type, thickness=line_width)
                img = cv2.addWeighted(img_overlay, 0.4, img, 0.6, 0)
            else:
                img = _draw_line(img, kp1, kp2, color, line_type, thickness=line_width)

    if return_padding:
        return img, padding
    else:
        return img


if __name__ == "__main__":
    kpts = np.array(
        [
            344, 222, 2,
            356, 211, 2,
            330, 211, 2,
            372, 220, 2,
            309, 224, 2,
            413, 279, 2,
            274, 300, 2,
            444, 372, 2,
            261, 396, 2,
            398, 359, 2,
            316, 372, 2,
            407, 489, 2,
            185, 580, 2,
            0, 0, 0,
            0, 0, 0,
            0, 0, 0,
            0, 0, 0,
        ]
    )

    kpts = kpts.reshape(-1, 3)
    kpts[:, -1] = np.random.randint(1, 3, size=(17,))

    img = pose_visualization("demo/posevis_test.jpg", kpts, show_markers=True, line_type="solid")

    kpts2 = kpts.copy()
    kpts2[kpts2[:, 1] > 0, :2] += 10
    img = pose_visualization(img, kpts2, show_markers=False, line_type="doted")

    os.makedirs("demo/outputs", exist_ok=True)
    cv2.imwrite("demo/outputs/posevis_test_out.jpg", img)
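Besides the single-pose demo in the __main__ block above, pose_visualization also dispatches over a stack of poses, recursing once per person. A small sketch, assuming an (N, 17, 3) array of poses and reusing the demo asset referenced in the __main__ block:

import numpy as np
from demo.posevis_lite import pose_visualization

# Two synthetic people; coordinates are arbitrary but kept inside a 640x480 frame.
people = np.random.rand(2, 17, 3)
people[..., 0] *= 640
people[..., 1] *= 480
people[..., 2] = 2  # mark every joint as visible

img = pose_visualization(
    "demo/posevis_test.jpg",
    people,                   # ndim == 3 triggers the per-person recursion above
    differ_individuals=True,  # one random color per person
    keep_image_size=True,     # skip padding for out-of-frame joints
)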
demo/sam2_utils.py
ADDED
@@ -0,0 +1,714 @@
"""
SAM2 utilities for BMP demo:
- Build and prepare SAM model
- Convert poses to segmentation
- Compute mask-pose consistency
"""

from typing import Any, List, Optional, Tuple

import numpy as np
import torch
from mmengine.structures import InstanceData
from pycocotools import mask as Mask
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Threshold for keypoint validity in mask-pose consistency
STRICT_KPT_THRESHOLD: float = 0.5


def _validate_sam_args(sam_args):
    """
    Validate that all required sam_args attributes are present.
    """
    required = [
        "crop",
        "use_bbox",
        "confidence_thr",
        "ignore_small_bboxes",
        "num_pos_keypoints",
        "num_pos_keypoints_if_crowd",
        "crowd_by_max_iou",
        "batch",
        "exclusive_masks",
        "extend_bbox",
        "pose_mask_consistency",
        "visibility_thr",
    ]
    for param in required:
        if not hasattr(sam_args, param):
            raise AttributeError(f"Missing required arg {param} in sam_args")


def _get_max_ious(bboxes: List[np.ndarray]) -> np.ndarray:
    """
    Compute the maximum IoU of each bbox against all others.
    """
    is_crowd = [0] * len(bboxes)
    ious = Mask.iou(bboxes, bboxes, is_crowd)
    mat = np.array(ious)
    np.fill_diagonal(mat, 0)
    return mat.max(axis=1)


def _compute_one_mask_pose_consistency(
    mask: np.ndarray, pos_keypoints: Optional[np.ndarray] = None, neg_keypoints: Optional[np.ndarray] = None
) -> float:
    """
    Compute a consistency score between a mask and given keypoints.

    Args:
        mask (np.ndarray): Binary mask of shape (H, W).
        pos_keypoints (Optional[np.ndarray]): Positive keypoints array (N, 3).
        neg_keypoints (Optional[np.ndarray]): Negative keypoints array (M, 3).

    Returns:
        float: Weighted mean of positive and negative keypoint consistency.
    """
    if mask is None:
        return 0.0

    def _mean_inside(points: np.ndarray) -> float:
        if points.size == 0:
            return 0.0
        pts_int = np.floor(points[:, :2]).astype(int)
        pts_int[:, 0] = np.clip(pts_int[:, 0], 0, mask.shape[1] - 1)
        pts_int[:, 1] = np.clip(pts_int[:, 1], 0, mask.shape[0] - 1)
        vals = mask[pts_int[:, 1], pts_int[:, 0]]
        return vals.mean() if vals.size > 0 else 0.0

    pos_mean = 0.0
    if pos_keypoints is not None:
        valid = pos_keypoints[:, 2] > STRICT_KPT_THRESHOLD
        pos_mean = _mean_inside(pos_keypoints[valid])

    neg_mean = 0.0
    if neg_keypoints is not None:
        valid = neg_keypoints[:, 2] > STRICT_KPT_THRESHOLD
        pts = neg_keypoints[valid][:, :2]
        inside = mask[np.floor(pts[:, 1]).astype(int), np.floor(pts[:, 0]).astype(int)]
        neg_mean = (~inside.astype(bool)).mean() if inside.size > 0 else 0.0

    return 0.5 * pos_mean + 0.5 * neg_mean


def _select_keypoints(
    args: Any,
    kpts: np.ndarray,
    num_visible: int,
    bbox: Optional[Tuple[float, float, float, float]] = None,
    method: Optional[str] = "distance+confidence",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Select and order keypoints for SAM prompting based on the specified method.

    Args:
        args: Configuration object with selection_method and visibility_thr attributes.
        kpts (np.ndarray): Keypoints array of shape (K, 3).
        num_visible (int): Number of keypoints above the visibility threshold.
        bbox (Optional[Tuple]): Optional bbox for distance-based methods.
        method (Optional[str]): Override selection method.

    Returns:
        Tuple[np.ndarray, np.ndarray]: Selected keypoint coordinates (N, 2) and confidences (N,).

    Raises:
        ValueError: If an unknown method is specified.
    """
    if num_visible == 0:
        return kpts[:, :2], kpts[:, 2]

    methods = ["confidence", "distance", "distance+confidence", "closest"]
    sel_method = method or args.selection_method
    if sel_method not in methods:
        raise ValueError("Unknown method for keypoint selection: {}".format(sel_method))

    # Select at most one keypoint from the face
    facial_kpts = kpts[:3, :]
    facial_conf = kpts[:3, 2]
    facial_point = facial_kpts[np.argmax(facial_conf)]
    if facial_point[-1] >= args.visibility_thr:
        kpts = np.concatenate([facial_point[None, :], kpts[3:]], axis=0)

    conf = kpts[:, 2]
    vis_mask = conf >= args.visibility_thr
    coords = kpts[vis_mask, :2]
    confs = conf[vis_mask]

    if sel_method == "confidence":
        order = np.argsort(confs)[::-1]
        coords = coords[order]
        confs = confs[order]
    elif sel_method == "distance":
        if bbox is None:
            bbox_center = np.array([coords[:, 0].mean(), coords[:, 1].mean()])
        else:
            bbox_center = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2])
        dists = np.linalg.norm(coords[:, :2] - bbox_center, axis=1)
        dist_matrix = np.linalg.norm(coords[:, None, :2] - coords[None, :, :2], axis=2)
        np.fill_diagonal(dist_matrix, np.inf)
        min_inter_dist = np.min(dist_matrix, axis=1)
        order = np.argsort(dists + 3 * min_inter_dist)[::-1]
        coords = coords[order, :2]
        confs = confs[order]
    elif sel_method == "distance+confidence":
        order = np.argsort(confs)[::-1]
        confidences = kpts[order, 2]
        coords = coords[order, :2]
        confs = confs[order]

        dist_matrix = np.linalg.norm(coords[:, None, :2] - coords[None, :, :2], axis=2)

        # Greedily pick the keypoint farthest from those already selected,
        # restricted to the most confident candidates
        selected_idx = [0]
        confidences[0] = -1
        for _ in range(coords.shape[0] - 1):
            min_dist = np.min(dist_matrix[:, selected_idx], axis=1)
            min_dist[confidences < np.percentile(confidences, 80)] = -1

            next_idx = np.argmax(min_dist)
            selected_idx.append(next_idx)
            confidences[next_idx] = -1

        coords = coords[selected_idx]
        confs = confs[selected_idx]
    elif sel_method == "closest":
        coords = coords[confs > STRICT_KPT_THRESHOLD, :]
        confs = confs[confs > STRICT_KPT_THRESHOLD]
        if bbox is None:
            bbox_center = np.array([coords[:, 0].mean(), coords[:, 1].mean()])
        else:
            bbox_center = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2])
        dists = np.linalg.norm(coords[:, :2] - bbox_center, axis=1)
        order = np.argsort(dists)
        coords = coords[order, :2]
        confs = confs[order]

    return coords, confs


def prepare_model(model_cfg: Any, model_checkpoint: str) -> SAM2ImagePredictor:
    """
    Build and return a SAM2ImagePredictor model on the appropriate device.

    Args:
        model_cfg: Configuration for SAM2 model.
        model_checkpoint (str): Path to model checkpoint.

    Returns:
        SAM2ImagePredictor: Initialized SAM2 image predictor.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    sam2 = build_sam2(model_cfg, model_checkpoint, device=device, apply_postprocessing=True)
    model = SAM2ImagePredictor(
        sam2,
        max_hole_area=10.0,
        max_sprinkle_area=50.0,
    )
    return model
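A sketch of wiring prepare_model into the rest of this module. The SAM2 config name and checkpoint path follow the upstream facebookresearch/sam2 repository's conventions and are assumptions here; sam_args simply enumerates every attribute _validate_sam_args checks, with illustrative values:

from types import SimpleNamespace

from demo.sam2_utils import prepare_model

# Assumed config/checkpoint names from the upstream SAM2 repository.
predictor = prepare_model("configs/sam2.1/sam2.1_hiera_l.yaml", "checkpoints/sam2.1_hiera_large.pt")

# Illustrative values only; num_neg_keypoints and selection_method are also
# read elsewhere in this module, so they are included for completeness.
sam_args = SimpleNamespace(
    crop=True, use_bbox=True, confidence_thr=0.3, visibility_thr=0.3,
    ignore_small_bboxes=False, num_pos_keypoints=6, num_pos_keypoints_if_crowd=4,
    crowd_by_max_iou=0.9, batch=False, exclusive_masks=True, extend_bbox=True,
    pose_mask_consistency=False, num_neg_keypoints=6, selection_method="distance+confidence",
)
# refined = process_image_with_SAM(sam_args, image, predictor, new_dets)  # defined below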
def _compute_mask_pose_consistency(masks: List[np.ndarray], keypoints_list: List[np.ndarray]) -> np.ndarray:
    """
    Compute the mask-pose consistency score for each mask-keypoints pair.

    Args:
        masks (List[np.ndarray]): List of binary masks.
        keypoints_list (List[np.ndarray]): List of keypoint arrays per instance.

    Returns:
        np.ndarray: Consistency scores array of shape (N,).
    """
    scores: List[float] = []
    # Enumerate so the current instance's index is defined (the original loop
    # referenced an undefined `idx`); all other instances act as negatives.
    for idx, (mask, kpts) in enumerate(zip(masks, keypoints_list)):
        other_kpts = np.concatenate([keypoints_list[:idx], keypoints_list[idx + 1 :]], axis=0).reshape(-1, 3)
        score = _compute_one_mask_pose_consistency(mask, kpts, other_kpts)
        scores.append(score)

    return np.array(scores)


def _pose2seg(
    args: Any,
    model: SAM2ImagePredictor,
    bbox_xyxy: Optional[List[float]] = None,
    pos_kpts: Optional[np.ndarray] = None,
    neg_kpts: Optional[np.ndarray] = None,
    image: Optional[np.ndarray] = None,
    gt_mask: Optional[Any] = None,
    num_pos_keypoints: Optional[int] = None,
    gt_mask_is_binary: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
    """
    Run SAM segmentation conditioned on pose keypoints and an optional ground-truth mask.

    Args:
        args: Configuration object with prompting settings.
        model (SAM2ImagePredictor): Prepared SAM2 model.
        bbox_xyxy (Optional[List[float]]): Bounding box coordinates in xyxy format.
        pos_kpts (Optional[np.ndarray]): Positive keypoints array.
        neg_kpts (Optional[np.ndarray]): Negative keypoints array.
        image (Optional[np.ndarray]): Input image array.
        gt_mask (Optional[Any]): Ground-truth mask (optional).
        num_pos_keypoints (Optional[int]): Number of positive keypoints to use.
        gt_mask_is_binary (bool): Flag indicating whether the ground-truth mask is binary.

    Returns:
        Tuple of (mask, pos_kpts_backup, neg_kpts_backup, score).
    """
    num_pos_keypoints = args.num_pos_keypoints if num_pos_keypoints is None else num_pos_keypoints

    # Filter out un-annotated and invisible keypoints
    if pos_kpts is not None:
        pos_kpts = pos_kpts.reshape(-1, 3)
        valid_kpts = pos_kpts[:, 2] > args.visibility_thr

        pose_bbox = np.array([pos_kpts[:, 0].min(), pos_kpts[:, 1].min(), pos_kpts[:, 0].max(), pos_kpts[:, 1].max()])
        pos_kpts, conf = _select_keypoints(args, pos_kpts, num_visible=valid_kpts.sum(), bbox=bbox_xyxy)

        pos_kpts_backup = np.concatenate([pos_kpts, conf[:, None]], axis=1)

        if pos_kpts.shape[0] > num_pos_keypoints:
            pos_kpts = pos_kpts[:num_pos_keypoints, :]
            pos_kpts_backup = pos_kpts_backup[:num_pos_keypoints, :]
    else:
        pose_bbox = None
        pos_kpts = np.empty((0, 2), dtype=np.float32)
        pos_kpts_backup = np.empty((0, 3), dtype=np.float32)

    if neg_kpts is not None:
        neg_kpts = neg_kpts.reshape(-1, 3)
        valid_kpts = neg_kpts[:, 2] > args.visibility_thr

        neg_kpts, conf = _select_keypoints(
            args, neg_kpts, num_visible=valid_kpts.sum(), bbox=bbox_xyxy, method="closest"
        )
        selected_neg_kpts = neg_kpts
        neg_kpts_backup = np.concatenate([neg_kpts, conf[:, None]], axis=1)

        if neg_kpts.shape[0] > args.num_neg_keypoints:
            selected_neg_kpts = neg_kpts[: args.num_neg_keypoints, :]
    else:
        selected_neg_kpts = np.empty((0, 2), dtype=np.float32)
        neg_kpts_backup = np.empty((0, 3), dtype=np.float32)

    # Concatenate positive and negative keypoints
    kpts = np.concatenate([pos_kpts, selected_neg_kpts], axis=0)
    kpts_labels = np.concatenate([np.ones(pos_kpts.shape[0]), np.zeros(selected_neg_kpts.shape[0])], axis=0)

    bbox = bbox_xyxy if args.use_bbox else None

    if args.extend_bbox and bbox is not None:
        # Expand the bbox such that it contains all positive keypoints
        pose_bbox = np.array(
            [pos_kpts[:, 0].min() - 2, pos_kpts[:, 1].min() - 2, pos_kpts[:, 0].max() + 2, pos_kpts[:, 1].max() + 2]
        )
        expanded_bbox = np.array(bbox)
        expanded_bbox[:2] = np.minimum(bbox[:2], pose_bbox[:2])
        expanded_bbox[2:] = np.maximum(bbox[2:], pose_bbox[2:])
        bbox = expanded_bbox

    if args.crop and args.use_bbox and image is not None:
        # Crop the image to 1.5x the bbox size
        crop_bbox = np.array(bbox)
        bbox_center = np.array([(crop_bbox[0] + crop_bbox[2]) / 2, (crop_bbox[1] + crop_bbox[3]) / 2])
        bbox_size = np.array([crop_bbox[2] - crop_bbox[0], crop_bbox[3] - crop_bbox[1]])
        bbox_size = 1.5 * bbox_size
        crop_bbox = np.array(
            [
                bbox_center[0] - bbox_size[0] / 2,
                bbox_center[1] - bbox_size[1] / 2,
                bbox_center[0] + bbox_size[0] / 2,
                bbox_center[1] + bbox_size[1] / 2,
            ]
        )
        crop_bbox = np.round(crop_bbox).astype(int)
        crop_bbox = np.clip(crop_bbox, 0, [image.shape[1], image.shape[0], image.shape[1], image.shape[0]])
        original_image_size = image.shape[:2]
        image = image[crop_bbox[1] : crop_bbox[3], crop_bbox[0] : crop_bbox[2], :]

        # Update the keypoints and bbox to crop coordinates
        kpts = kpts - crop_bbox[:2]
        bbox[:2] = bbox[:2] - crop_bbox[:2]
        bbox[2:] = bbox[2:] - crop_bbox[:2]

        model.set_image(image)

    masks, scores, logits = model.predict(
        point_coords=kpts,
        point_labels=kpts_labels,
        box=bbox,
        multimask_output=False,
    )
    mask = masks[0]
    scores = scores[0]

    if args.crop and args.use_bbox and image is not None:
        # Pad the mask back to the original image size
        mask_padded = np.zeros(original_image_size, dtype=np.uint8)
        mask_padded[crop_bbox[1] : crop_bbox[3], crop_bbox[0] : crop_bbox[2]] = mask
        mask = mask_padded

        bbox[:2] = bbox[:2] + crop_bbox[:2]
        bbox[2:] = bbox[2:] + crop_bbox[:2]

    if args.pose_mask_consistency:
        if gt_mask_is_binary:
            gt_mask_binary = gt_mask
        else:
            gt_mask_binary = Mask.decode(gt_mask).astype(bool) if gt_mask is not None else None

        gt_mask_pose_consistency = _compute_one_mask_pose_consistency(gt_mask_binary, pos_kpts_backup, neg_kpts_backup)
        dt_mask_pose_consistency = _compute_one_mask_pose_consistency(mask, pos_kpts_backup, neg_kpts_backup)

        tol = 0.1
        dt_is_same = np.abs(dt_mask_pose_consistency - gt_mask_pose_consistency) < tol
        if dt_is_same:
            mask = gt_mask_binary if gt_mask_binary.sum() < mask.sum() else mask
        else:
            mask = gt_mask_binary if gt_mask_pose_consistency > dt_mask_pose_consistency else mask

    return mask, pos_kpts_backup, neg_kpts_backup, scores


def process_image_with_SAM(
    sam_args: Any,
    image: np.ndarray,
    model: SAM2ImagePredictor,
    new_dets: InstanceData,
    old_dets: Optional[InstanceData] = None,
) -> InstanceData:
    """
    Wrapper that validates args and routes to single or batch processing.
    """
    _validate_sam_args(sam_args)
    if sam_args.batch:
        return _process_image_batch(sam_args, image, model, new_dets, old_dets)
    return _process_image_single(sam_args, image, model, new_dets, old_dets)


def _process_image_single(
    sam_args: Any,
    image: np.ndarray,
    model: SAM2ImagePredictor,
    new_dets: InstanceData,
    old_dets: Optional[InstanceData] = None,
) -> InstanceData:
    """
    Refine instance segmentation masks using SAM2 with pose-conditioned prompts.

    Args:
        sam_args (Any): DotDict containing required SAM parameters:
            crop (bool), use_bbox (bool), confidence_thr (float),
            ignore_small_bboxes (bool), num_pos_keypoints (int),
            num_pos_keypoints_if_crowd (int), crowd_by_max_iou (Optional[float]),
            batch (bool), exclusive_masks (bool), extend_bbox (bool), pose_mask_consistency (bool).
        image (np.ndarray): BGR image array of shape (H, W, 3).
        model (SAM2ImagePredictor): Initialized SAM2 predictor.
        new_dets (InstanceData): New detections with attributes:
            bboxes, pred_masks, keypoints, bbox_scores.
        old_dets (Optional[InstanceData]): Previous detections for negative prompts.

    Returns:
        InstanceData: `new_dets` updated in-place with
            `.refined_masks`, `.sam_scores`, and `.sam_kpts`.
    """
    _validate_sam_args(sam_args)

    if not (sam_args.crop and sam_args.use_bbox):
        model.set_image(image)

    # Ignore all keypoints with confidence below the threshold
    new_keypoints = new_dets.keypoints.copy()
    for kpts in new_keypoints:
        conf_mask = kpts[:, 2] < sam_args.confidence_thr
        kpts[conf_mask, :] = 0
    n_new_dets = len(new_dets.bboxes)
    n_old_dets = 0
    if old_dets is not None:
        n_old_dets = len(old_dets.bboxes)
        old_keypoints = old_dets.keypoints.copy()
        for kpts in old_keypoints:
            conf_mask = kpts[:, 2] < sam_args.confidence_thr
            kpts[conf_mask, :] = 0

    all_bboxes = new_dets.bboxes.copy()
    if old_dets is not None:
        all_bboxes = np.concatenate([all_bboxes, old_dets.bboxes], axis=0)

    max_ious = _get_max_ious(all_bboxes)

    gt_bboxes = []
    new_dets.refined_masks = np.zeros((n_new_dets, image.shape[0], image.shape[1]), dtype=np.uint8)
    new_dets.sam_scores = np.zeros_like(new_dets.bbox_scores)
    new_dets.sam_kpts = np.zeros((len(new_dets.bboxes), sam_args.num_pos_keypoints, 3), dtype=np.float32)
    for instance_idx in range(len(new_dets.bboxes)):
        bbox_xywh = new_dets.bboxes[instance_idx]
        bbox_area = bbox_xywh[2] * bbox_xywh[3]

        if sam_args.ignore_small_bboxes and bbox_area < 100 * 100:
            continue
        dt_mask = new_dets.pred_masks[instance_idx] if new_dets.pred_masks is not None else None

        bbox_xyxy = [bbox_xywh[0], bbox_xywh[1], bbox_xywh[0] + bbox_xywh[2], bbox_xywh[1] + bbox_xywh[3]]
        gt_bboxes.append(bbox_xyxy)
        this_kpts = new_keypoints[instance_idx].reshape(1, -1, 3)
        other_kpts = None
        if old_dets is not None:
            other_kpts = old_keypoints.copy().reshape(n_old_dets, -1, 3)
        if len(new_keypoints) > 1:
            other_new_kpts = np.concatenate([new_keypoints[:instance_idx], new_keypoints[instance_idx + 1 :]], axis=0)
            other_kpts = (
                np.concatenate([other_kpts, other_new_kpts], axis=0) if other_kpts is not None else other_new_kpts
            )

        num_pos_keypoints = sam_args.num_pos_keypoints
        if sam_args.crowd_by_max_iou is not None and max_ious[instance_idx] > sam_args.crowd_by_max_iou:
            bbox_xyxy = None
            num_pos_keypoints = sam_args.num_pos_keypoints_if_crowd

        dt_mask, pos_kpts, neg_kpts, scores = _pose2seg(
            sam_args,
            model,
            bbox_xyxy,
            pos_kpts=this_kpts,
            neg_kpts=other_kpts,
            image=image if (sam_args.crop and sam_args.use_bbox) else None,
            gt_mask=dt_mask,
            num_pos_keypoints=num_pos_keypoints,
            gt_mask_is_binary=True,
        )

        new_dets.refined_masks[instance_idx] = dt_mask
        new_dets.sam_scores[instance_idx] = scores

        # If the number of positive keypoints is less than the required number, fill the rest with zeros
        if len(pos_kpts) != sam_args.num_pos_keypoints:
            pos_kpts = np.concatenate(
                [pos_kpts, np.zeros((sam_args.num_pos_keypoints - len(pos_kpts), 3), dtype=np.float32)], axis=0
            )
        new_dets.sam_kpts[instance_idx] = pos_kpts

    n_masks = len(new_dets.refined_masks) + (len(old_dets.refined_masks) if old_dets is not None else 0)

    if sam_args.exclusive_masks and n_masks > 1:
        all_masks = (
            np.concatenate([new_dets.refined_masks, old_dets.refined_masks], axis=0)
            if old_dets is not None
            else new_dets.refined_masks
        )
        all_scores = (
            np.concatenate([new_dets.sam_scores, old_dets.sam_scores], axis=0)
            if old_dets is not None
            else new_dets.sam_scores
        )
        refined_masks = _apply_exclusive_masks(all_masks, all_scores)
        new_dets.refined_masks = refined_masks[: len(new_dets.refined_masks)]

    return new_dets


def _process_image_batch(
    sam_args: Any,
    image: np.ndarray,
    model: SAM2ImagePredictor,
    new_dets: InstanceData,
    old_dets: Optional[InstanceData] = None,
) -> InstanceData:
    """
    Batch-process multiple detection instances with SAM2 refinement.

    Args:
        sam_args (Any): DotDict of SAM parameters (same as `process_image_with_SAM`).
        image (np.ndarray): Input BGR image.
        model (SAM2ImagePredictor): Prepared SAM2 predictor.
        new_dets (InstanceData): New detection instances.
        old_dets (Optional[InstanceData]): Previous detections for negative prompts.

    Returns:
        InstanceData: `new_dets` updated as in `process_image_with_SAM`.
    """
    n_new_dets = len(new_dets.bboxes)

    model.set_image(image)

    image_kpts = []
    image_bboxes = []
    num_valid_kpts = []
    for instance_idx in range(len(new_dets.bboxes)):
        bbox_xywh = new_dets.bboxes[instance_idx].copy()
        bbox_area = bbox_xywh[2] * bbox_xywh[3]
        if sam_args.ignore_small_bboxes and bbox_area < 100 * 100:
            continue

        this_kpts = new_dets.keypoints[instance_idx].copy().reshape(-1, 3)
        kpts_vis = np.array(this_kpts[:, 2])
        visible_kpts = (kpts_vis > sam_args.visibility_thr) & (this_kpts[:, 2] > sam_args.confidence_thr)
        num_visible = (visible_kpts).sum()
        if num_visible <= 0:
            continue
        num_valid_kpts.append(num_visible)
        image_bboxes.append(np.array(bbox_xywh))
        this_kpts[~visible_kpts, :2] = 0
        this_kpts[:, 2] = visible_kpts
        image_kpts.append(this_kpts)
    if old_dets is not None:
        for instance_idx in range(len(old_dets.bboxes)):
            bbox_xywh = old_dets.bboxes[instance_idx].copy()
            bbox_area = bbox_xywh[2] * bbox_xywh[3]
            if sam_args.ignore_small_bboxes and bbox_area < 100 * 100:
                continue
            this_kpts = old_dets.keypoints[instance_idx].reshape(-1, 3)
            kpts_vis = np.array(this_kpts[:, 2])
            visible_kpts = (kpts_vis > sam_args.visibility_thr) & (this_kpts[:, 2] > sam_args.confidence_thr)
            num_visible = (visible_kpts).sum()
            if num_visible <= 0:
                continue
            num_valid_kpts.append(num_visible)
            image_bboxes.append(np.array(bbox_xywh))
            this_kpts[~visible_kpts, :2] = 0
            this_kpts[:, 2] = visible_kpts
            image_kpts.append(this_kpts)

    image_kpts = np.array(image_kpts)
    image_bboxes = np.array(image_bboxes)
    num_valid_kpts = np.array(num_valid_kpts)

    image_kpts_backup = image_kpts.copy()

    # Prepare keypoints such that all instances have the same number of keypoints:
    # first sort keypoints by their distance to the center of the bounding box;
    # if some are missing, duplicate the last one
    prepared_kpts = []
    prepared_kpts_backup = []
    for bbox, kpts, num_visible in zip(image_bboxes, image_kpts, num_valid_kpts):
        this_kpts, this_conf = _select_keypoints(sam_args, kpts, num_visible, bbox)

        # Duplicate the last keypoint if some are missing
        if this_kpts.shape[0] < num_valid_kpts.max():
            this_kpts = np.concatenate(
                [this_kpts, np.tile(this_kpts[-1], (num_valid_kpts.max() - this_kpts.shape[0], 1))], axis=0
            )
            this_conf = np.concatenate(
                [this_conf, np.tile(this_conf[-1], (num_valid_kpts.max() - this_conf.shape[0],))], axis=0
            )

        prepared_kpts.append(this_kpts)
        prepared_kpts_backup.append(np.concatenate([this_kpts, this_conf[:, None]], axis=1))
    image_kpts = np.array(prepared_kpts)
    image_kpts_backup = np.array(prepared_kpts_backup)
    kpts_labels = np.ones(image_kpts.shape[:2])

    # Compute IoUs between all bounding boxes
    max_ious = _get_max_ious(image_bboxes)
    num_pos_keypoints = sam_args.num_pos_keypoints
    use_bbox = sam_args.use_bbox
    # NOTE: instance_idx is the leftover index from the loops above, so the crowd
    # decision for the whole batch is driven by the last instance's max IoU
    if sam_args.crowd_by_max_iou is not None and max_ious[instance_idx] > sam_args.crowd_by_max_iou:
        use_bbox = False
        num_pos_keypoints = sam_args.num_pos_keypoints_if_crowd

    # Threshold the number of positive keypoints
    if num_pos_keypoints > 0 and num_pos_keypoints < image_kpts.shape[1]:
        image_kpts = image_kpts[:, :num_pos_keypoints, :]
        kpts_labels = kpts_labels[:, :num_pos_keypoints]
        image_kpts_backup = image_kpts_backup[:, :num_pos_keypoints, :]
    elif num_pos_keypoints == 0:
        image_kpts = None
        kpts_labels = None
        image_kpts_backup = np.empty((0, 3), dtype=np.float32)

    image_bboxes_xyxy = None
    if use_bbox:
        image_bboxes_xyxy = np.array(image_bboxes)
        image_bboxes_xyxy[:, 2:] += image_bboxes_xyxy[:, :2]

        # Expand the bbox to include the positive keypoints
        if sam_args.extend_bbox:
            pose_bbox = np.stack(
                [
                    np.min(image_kpts[:, :, 0], axis=1) - 2,
                    np.min(image_kpts[:, :, 1], axis=1) - 2,
                    np.max(image_kpts[:, :, 0], axis=1) + 2,
                    np.max(image_kpts[:, :, 1], axis=1) + 2,
                ],
                axis=1,
            )
            expanded_bbox = np.array(image_bboxes_xyxy)
            expanded_bbox[:, :2] = np.minimum(expanded_bbox[:, :2], pose_bbox[:, :2])
            expanded_bbox[:, 2:] = np.maximum(expanded_bbox[:, 2:], pose_bbox[:, 2:])
            # bbox_expanded = (np.abs(expanded_bbox - image_bboxes_xyxy) > 1e-4).any(axis=1)
            image_bboxes_xyxy = expanded_bbox

    # Process even old detections to get their 'negative' keypoints
    masks, scores, logits = model.predict(
        point_coords=image_kpts,
        point_labels=kpts_labels,
        box=image_bboxes_xyxy,
        multimask_output=False,
    )

    # Reshape the masks to (N, C, H, W). If the model outputs (C, H, W), add a number-of-masks dimension
    if len(masks.shape) == 3:
        masks = masks[None, :, :, :]
    masks = masks[:, 0, :, :]
    N = masks.shape[0]
    scores = scores.reshape(N)

    if sam_args.exclusive_masks and N > 1:
        # Make sure the masks are non-overlapping
        # If two masks overlap, set the pixel to the one with the highest score
|
| 671 |
+
masks = _apply_exclusive_masks(masks, scores)
|
| 672 |
+
|
| 673 |
+
gt_masks = new_dets.pred_masks.copy() if new_dets.pred_masks is not None else None
|
| 674 |
+
if sam_args.pose_mask_consistency and gt_masks is not None:
|
| 675 |
+
# Measure 'mask-pose_conistency' by computing number of keypoints inside the mask
|
| 676 |
+
# Compute for both gt (if available) and predicted masks and then choose the one with higher consistency
|
| 677 |
+
dt_mask_pose_consistency = _compute_mask_pose_consistency(masks, image_kpts_backup)
|
| 678 |
+
gt_mask_pose_consistency = _compute_mask_pose_consistency(gt_masks, image_kpts_backup)
|
| 679 |
+
|
| 680 |
+
dt_masks_area = np.array([m.sum() for m in masks])
|
| 681 |
+
gt_masks_area = np.array([m.sum() for m in gt_masks]) if gt_masks is not None else np.zeros_like(dt_masks_area)
|
| 682 |
+
|
| 683 |
+
# If PM-c is approx the same, prefer the smaller mask
|
| 684 |
+
tol = 0.1
|
| 685 |
+
pmc_is_equal = np.isclose(dt_mask_pose_consistency, gt_mask_pose_consistency, atol=tol)
|
| 686 |
+
dt_is_worse = (dt_mask_pose_consistency < (gt_mask_pose_consistency - tol)) | pmc_is_equal & (
|
| 687 |
+
dt_masks_area > gt_masks_area
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
new_masks = []
|
| 691 |
+
for dt_mask, gt_mask, dt_worse in zip(masks, gt_masks, dt_is_worse):
|
| 692 |
+
if dt_worse:
|
| 693 |
+
new_masks.append(gt_mask)
|
| 694 |
+
else:
|
| 695 |
+
new_masks.append(dt_mask)
|
| 696 |
+
masks = np.array(new_masks)
|
| 697 |
+
|
| 698 |
+
new_dets.refined_masks = masks[:n_new_dets]
|
| 699 |
+
new_dets.sam_scores = scores[:n_new_dets]
|
| 700 |
+
new_dets.sam_kpts = image_kpts_backup[:n_new_dets]
|
| 701 |
+
|
| 702 |
+
return new_dets
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
def _apply_exclusive_masks(masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
|
| 706 |
+
"""
|
| 707 |
+
Ensure masks are non-overlapping by keeping at each pixel the mask with the highest score.
|
| 708 |
+
"""
|
| 709 |
+
no_mask = masks.sum(axis=0) == 0
|
| 710 |
+
masked_scores = masks * scores[:, None, None]
|
| 711 |
+
argmax_masks = np.argmax(masked_scores, axis=0)
|
| 712 |
+
new_masks = argmax_masks[None, :, :] == (np.arange(masks.shape[0])[:, None, None])
|
| 713 |
+
new_masks[:, no_mask] = 0
|
| 714 |
+
return new_masks
|
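A minimal usage sketch of `_apply_exclusive_masks` (illustrative only, not part of the committed file): on toy boolean masks, every contested pixel is assigned to the higher-scoring instance.

import numpy as np

masks = np.zeros((2, 4, 4), dtype=bool)
masks[0, :, :2] = True                 # instance 0 claims the left two columns
masks[1, :, 1:] = True                 # instance 1 claims the right three columns
scores = np.array([0.9, 0.6])          # instance 0 is more confident

exclusive = _apply_exclusive_masks(masks, scores)
# column 1 was contested; it now belongs only to instance 0
assert not (exclusive[0] & exclusive[1]).any()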
mmpose/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import mmengine
from mmengine.utils import digit_version

from .version import __version__, short_version

mmcv_minimum_version = '2.0.0rc4'
mmcv_maximum_version = '2.3.0'
mmcv_version = digit_version(mmcv.__version__)

mmengine_minimum_version = '0.6.0'
mmengine_maximum_version = '1.0.0'
mmengine_version = digit_version(mmengine.__version__)

assert (mmcv_version >= digit_version(mmcv_minimum_version)
        and mmcv_version <= digit_version(mmcv_maximum_version)), \
    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'

assert (mmengine_version >= digit_version(mmengine_minimum_version)
        and mmengine_version <= digit_version(mmengine_maximum_version)), \
    f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
    f'Please install mmengine>={mmengine_minimum_version}, ' \
    f'<={mmengine_maximum_version}.'

__all__ = ['__version__', 'short_version']
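As a side note, `digit_version` converts version strings into comparable tuples, which is what makes the range checks above well-defined; a small illustration (assuming mmengine is installed):

from mmengine.utils import digit_version

assert digit_version('2.0.0rc4') < digit_version('2.1.0') <= digit_version('2.3.0')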
mmpose/apis/__init__.py
ADDED
@@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import (collect_multi_frames, inference_bottomup,
                        inference_topdown, init_model)
from .inference_3d import (collate_pose_sequence, convert_keypoint_definition,
                           extract_pose_sequence, inference_pose_lifter_model)
from .inference_tracking import _compute_iou, _track_by_iou, _track_by_oks
from .inferencers import MMPoseInferencer, Pose2DInferencer
from .visualization import visualize

__all__ = [
    'init_model', 'inference_topdown', 'inference_bottomup',
    'collect_multi_frames', 'Pose2DInferencer', 'MMPoseInferencer',
    '_track_by_iou', '_track_by_oks', '_compute_iou',
    'inference_pose_lifter_model', 'extract_pose_sequence',
    'convert_keypoint_definition', 'collate_pose_sequence', 'visualize'
]
mmpose/apis/inference.py
ADDED
@@ -0,0 +1,280 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from mmengine.config import Config
from mmengine.dataset import Compose, pseudo_collate
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
from PIL import Image

from mmpose.datasets.datasets.utils import parse_pose_metainfo
from mmpose.models.builder import build_pose_estimator
from mmpose.structures import PoseDataSample
from mmpose.structures.bbox import bbox_xywh2xyxy

import cv2


def dataset_meta_from_config(config: Config,
                             dataset_mode: str = 'train') -> Optional[dict]:
    """Get dataset metainfo from the model config.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object.
        dataset_mode (str): Specify the dataset of which to get the metainfo.
            Options are ``'train'``, ``'val'`` and ``'test'``. Defaults to
            ``'train'``

    Returns:
        dict, optional: The dataset metainfo. See
        ``mmpose.datasets.datasets.utils.parse_pose_metainfo`` for details.
        Return ``None`` if failing to get dataset metainfo from the config.
    """
    try:
        if dataset_mode == 'train':
            dataset_cfg = config.train_dataloader.dataset
        elif dataset_mode == 'val':
            dataset_cfg = config.val_dataloader.dataset
        elif dataset_mode == 'test':
            dataset_cfg = config.test_dataloader.dataset
        else:
            raise ValueError(
                f'Invalid dataset {dataset_mode} to get metainfo. '
                'Should be one of "train", "val", or "test".')

        if 'metainfo' in dataset_cfg:
            metainfo = dataset_cfg.metainfo
        else:
            import mmpose.datasets.datasets  # noqa: F401, F403
            from mmpose.registry import DATASETS

            dataset_class = dataset_cfg.type if isinstance(
                dataset_cfg.type, type) else DATASETS.get(dataset_cfg.type)
            metainfo = dataset_class.METAINFO

        metainfo = parse_pose_metainfo(metainfo)

    except AttributeError:
        metainfo = None

    return metainfo


def init_model(config: Union[str, Path, Config],
               checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               cfg_options: Optional[dict] = None) -> nn.Module:
    """Initialize a pose estimator from a config file.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights. Defaults to ``None``
        device (str): The device where the anchors will be put on.
            Defaults to ``'cuda:0'``.
        cfg_options (dict, optional): Options to override some settings in
            the used config. Defaults to ``None``

    Returns:
        nn.Module: The constructed pose estimator.
    """

    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    elif 'init_cfg' in config.model.backbone:
        config.model.backbone.init_cfg = None
    config.model.train_cfg = None

    # register all modules in mmpose into the registries
    scope = config.get('default_scope', 'mmpose')
    if scope is not None:
        init_default_scope(scope)

    model = build_pose_estimator(config.model)
    model = revert_sync_batchnorm(model)
    # get dataset_meta in this priority: checkpoint > config > default (COCO)
    dataset_meta = None

    if checkpoint is not None:
        ckpt = load_checkpoint(model, checkpoint, map_location='cpu')

        if 'dataset_meta' in ckpt.get('meta', {}):
            # checkpoint from mmpose 1.x
            dataset_meta = ckpt['meta']['dataset_meta']

    if dataset_meta is None:
        dataset_meta = dataset_meta_from_config(config, dataset_mode='train')

    if dataset_meta is None:
        warnings.simplefilter('once')
        warnings.warn('Can not load dataset_meta from the checkpoint or the '
                      'model config. Use COCO metainfo by default.')
        dataset_meta = parse_pose_metainfo(
            dict(from_file='configs/_base_/datasets/coco.py'))

    model.dataset_meta = dataset_meta

    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def inference_topdown(model: nn.Module,
                      img: Union[np.ndarray, str],
                      bboxes: Optional[Union[List, np.ndarray]] = None,
                      masks: Optional[Union[List, np.ndarray]] = None,
                      bbox_format: str = 'xyxy') -> List[PoseDataSample]:
    """Inference image with a top-down pose estimator.

    Args:
        model (nn.Module): The top-down pose estimator
        img (np.ndarray | str): The loaded image or image file to inference
        bboxes (np.ndarray, optional): The bboxes in shape (N, 4), each row
            represents a bbox. If not given, the entire image will be regarded
            as a single bbox area. Defaults to ``None``
        masks (np.ndarray | List, optional): Binary instance masks aligned
            with ``bboxes``; they are converted to polygon format before
            entering the pipeline. Defaults to ``None``
        bbox_format (str): The bbox format indicator. Options are ``'xywh'``
            and ``'xyxy'``. Defaults to ``'xyxy'``

    Returns:
        List[:obj:`PoseDataSample`]: The inference results. Specifically, the
        predicted keypoints and scores are saved at
        ``data_sample.pred_instances.keypoints`` and
        ``data_sample.pred_instances.keypoint_scores``.
    """
    scope = model.cfg.get('default_scope', 'mmpose')
    if scope is not None:
        init_default_scope(scope)
    pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)

    if bboxes is None or len(bboxes) == 0:
        # get bbox from the image size
        if isinstance(img, str):
            w, h = Image.open(img).size
        else:
            h, w = img.shape[:2]

        bboxes = np.array([[0, 0, w, h]], dtype=np.float32)
    else:
        if isinstance(bboxes, list):
            bboxes = np.array(bboxes)

        assert bbox_format in {'xyxy', 'xywh'}, \
            f'Invalid bbox_format "{bbox_format}".'

        if bbox_format == 'xywh':
            bboxes = bbox_xywh2xyxy(bboxes)

    if masks is None or len(masks) == 0:
        masks = np.zeros((bboxes.shape[0], img.shape[0], img.shape[1]),
                         dtype=np.uint8)

    # Masks are expected in polygon format
    poly_masks = []
    for mask in masks:
        if np.sum(mask) == 0:
            poly_masks.append(None)
        else:
            contours, _ = cv2.findContours((mask * 255).astype(np.uint8),
                                           cv2.RETR_EXTERNAL,
                                           cv2.CHAIN_APPROX_SIMPLE)
            polygons = [contour.flatten() for contour in contours if len(contour) > 3]
            poly_masks.append(polygons if polygons else None)

    # construct batch data samples
    data_list = []
    for bbox, pmask in zip(bboxes, poly_masks):
        if isinstance(img, str):
            data_info = dict(img_path=img)
        else:
            data_info = dict(img=img)
        data_info['bbox'] = bbox[None]  # shape (1, 4)
        data_info['segmentation'] = pmask
        data_info['bbox_score'] = np.ones(1, dtype=np.float32)  # shape (1,)
        data_info.update(model.dataset_meta)
        data_list.append(pipeline(data_info))

    if data_list:
        # collate data list into a batch, which is a dict with following keys:
        # batch['inputs']: a list of input images
        # batch['data_samples']: a list of :obj:`PoseDataSample`
        batch = pseudo_collate(data_list)
        with torch.no_grad():
            results = model.test_step(batch)
    else:
        results = []

    return results
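A hypothetical end-to-end sketch of `init_model` plus `inference_topdown` (the config and checkpoint paths below are placeholders, not files shipped with this commit):

import numpy as np
from mmpose.apis import init_model, inference_topdown

model = init_model('path/to/pose_config.py', 'path/to/checkpoint.pth', device='cpu')
img = np.zeros((480, 640, 3), dtype=np.uint8)              # stand-in for a real BGR image
bboxes = np.array([[10, 20, 200, 400]], dtype=np.float32)  # one person box in xyxy format
results = inference_topdown(model, img, bboxes, bbox_format='xyxy')
keypoints = results[0].pred_instances.keypoints            # per-instance keypoints, shape (1, K, 2)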


def inference_bottomup(model: nn.Module, img: Union[np.ndarray, str]):
    """Inference image with a bottom-up pose estimator.

    Args:
        model (nn.Module): The bottom-up pose estimator
        img (np.ndarray | str): The loaded image or image file to inference

    Returns:
        List[:obj:`PoseDataSample`]: The inference results. Specifically, the
        predicted keypoints and scores are saved at
        ``data_sample.pred_instances.keypoints`` and
        ``data_sample.pred_instances.keypoint_scores``.
    """
    pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)

    # prepare data batch
    if isinstance(img, str):
        data_info = dict(img_path=img)
    else:
        data_info = dict(img=img)
    data_info.update(model.dataset_meta)
    data = pipeline(data_info)
    batch = pseudo_collate([data])

    with torch.no_grad():
        results = model.test_step(batch)

    return results


def collect_multi_frames(video, frame_id, indices, online=False):
    """Collect multi frames from the video.

    Args:
        video (mmcv.VideoReader): A VideoReader of the input video file.
        frame_id (int): index of the current frame
        indices (list(int)): index offsets of the frames to collect
        online (bool): inference mode, if set to True, can not use future
            frame information.

    Returns:
        list(ndarray): multi frames collected from the input video file.
    """
    num_frames = len(video)
    frames = []
    # put the current frame at first
    frames.append(video[frame_id])
    # use multi frames for inference
    for idx in indices:
        # skip current frame
        if idx == 0:
            continue
        support_idx = frame_id + idx
        # online mode, can not use future frame information
        if online:
            support_idx = np.clip(support_idx, 0, frame_id)
        else:
            support_idx = np.clip(support_idx, 0, num_frames - 1)
        frames.append(video[support_idx])

    return frames
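Illustrative sketch of `collect_multi_frames` (the video path is a placeholder): gather the current frame plus one past and one future frame for multi-frame inference.

import mmcv

video = mmcv.VideoReader('path/to/video.mp4')
frames = collect_multi_frames(video, frame_id=10, indices=[-1, 0, 1], online=False)
# 3 frames: frame 10 comes first, then frames 9 and 11 (the 0 offset is skipped)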
mmpose/apis/inference_3d.py
ADDED
@@ -0,0 +1,360 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmengine.dataset import Compose, pseudo_collate
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData

from mmpose.structures import PoseDataSample


def convert_keypoint_definition(keypoints, pose_det_dataset,
                                pose_lift_dataset):
    """Convert pose det dataset keypoints definition to pose lifter dataset
    keypoints definition, so that they are compatible with the definitions
    required for 3D pose lifting.

    Args:
        keypoints (ndarray[N, K, 2 or 3]): 2D keypoints to be transformed.
        pose_det_dataset (str): Name of the dataset for 2D pose detector.
        pose_lift_dataset (str): Name of the dataset for pose lifter model.

    Returns:
        ndarray[K, 2 or 3]: the transformed 2D keypoints.
    """
    assert pose_lift_dataset in [
        'h36m', 'h3wb'], '`pose_lift_dataset` should be ' \
        f'`h36m` or `h3wb`, but got {pose_lift_dataset}.'

    keypoints_new = np.zeros((keypoints.shape[0], 17, keypoints.shape[2]),
                             dtype=keypoints.dtype)
    if pose_lift_dataset in ['h36m', 'h3wb']:
        if pose_det_dataset in ['h36m', 'coco_wholebody']:
            keypoints_new = keypoints
        elif pose_det_dataset in ['coco', 'posetrack18']:
            # pelvis (root) is in the middle of l_hip and r_hip
            keypoints_new[:, 0] = (keypoints[:, 11] + keypoints[:, 12]) / 2
            # thorax is in the middle of l_shoulder and r_shoulder
            keypoints_new[:, 8] = (keypoints[:, 5] + keypoints[:, 6]) / 2
            # spine is in the middle of thorax and pelvis
            keypoints_new[:,
                          7] = (keypoints_new[:, 0] + keypoints_new[:, 8]) / 2
            # in COCO, head is in the middle of l_eye and r_eye
            # in PoseTrack18, head is in the middle of head_bottom and head_top
            keypoints_new[:, 10] = (keypoints[:, 1] + keypoints[:, 2]) / 2
            # rearrange other keypoints
            keypoints_new[:, [1, 2, 3, 4, 5, 6, 9, 11, 12, 13, 14, 15, 16]] = \
                keypoints[:, [12, 14, 16, 11, 13, 15, 0, 5, 7, 9, 6, 8, 10]]
        elif pose_det_dataset in ['aic']:
            # pelvis (root) is in the middle of l_hip and r_hip
            keypoints_new[:, 0] = (keypoints[:, 9] + keypoints[:, 6]) / 2
            # thorax is in the middle of l_shoulder and r_shoulder
            keypoints_new[:, 8] = (keypoints[:, 3] + keypoints[:, 0]) / 2
            # spine is in the middle of thorax and pelvis
            keypoints_new[:,
                          7] = (keypoints_new[:, 0] + keypoints_new[:, 8]) / 2
            # neck base (top end of neck) is 1/4 the way from
            # neck (bottom end of neck) to head top
            keypoints_new[:, 9] = (3 * keypoints[:, 13] + keypoints[:, 12]) / 4
            # head (spherical centre of head) is 7/12 the way from
            # neck (bottom end of neck) to head top
            keypoints_new[:, 10] = (5 * keypoints[:, 13] +
                                    7 * keypoints[:, 12]) / 12

            keypoints_new[:, [1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16]] = \
                keypoints[:, [6, 7, 8, 9, 10, 11, 3, 4, 5, 0, 1, 2]]
        elif pose_det_dataset in ['crowdpose']:
            # pelvis (root) is in the middle of l_hip and r_hip
            keypoints_new[:, 0] = (keypoints[:, 6] + keypoints[:, 7]) / 2
            # thorax is in the middle of l_shoulder and r_shoulder
            keypoints_new[:, 8] = (keypoints[:, 0] + keypoints[:, 1]) / 2
            # spine is in the middle of thorax and pelvis
            keypoints_new[:,
                          7] = (keypoints_new[:, 0] + keypoints_new[:, 8]) / 2
            # neck base (top end of neck) is 1/4 the way from
            # neck (bottom end of neck) to head top
            keypoints_new[:, 9] = (3 * keypoints[:, 13] + keypoints[:, 12]) / 4
            # head (spherical centre of head) is 7/12 the way from
            # neck (bottom end of neck) to head top
            keypoints_new[:, 10] = (5 * keypoints[:, 13] +
                                    7 * keypoints[:, 12]) / 12

            keypoints_new[:, [1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16]] = \
                keypoints[:, [7, 9, 11, 6, 8, 10, 0, 2, 4, 1, 3, 5]]
        else:
            raise NotImplementedError(
                f'unsupported conversion between {pose_lift_dataset} and '
                f'{pose_det_dataset}')

    return keypoints_new
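A minimal sketch of `convert_keypoint_definition` on dummy data: remap COCO-ordered 2D keypoints of shape (N, 17, C) into the H36M joint order expected by the lifter.

import numpy as np

coco_kpts = np.random.rand(2, 17, 2).astype(np.float32)    # two people, COCO joint order
h36m_kpts = convert_keypoint_definition(coco_kpts, 'coco', 'h36m')
print(h36m_kpts.shape)  # (2, 17, 2): same shape, joints remapped and interpolated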


def extract_pose_sequence(pose_results, frame_idx, causal, seq_len, step=1):
    """Extract the target frame from 2D pose results, and pad the sequence to a
    fixed length.

    Args:
        pose_results (List[List[:obj:`PoseDataSample`]]): Multi-frame pose
            detection results stored in a list.
        frame_idx (int): The index of the frame in the original video.
        causal (bool): If True, the target frame is the last frame in
            a sequence. Otherwise, the target frame is in the middle of
            a sequence.
        seq_len (int): The number of frames in the input sequence.
        step (int): Step size to extract frames from the video.

    Returns:
        List[List[:obj:`PoseDataSample`]]: Multi-frame pose detection results
        stored in a nested list with a length of seq_len.
    """
    if causal:
        frames_left = seq_len - 1
        frames_right = 0
    else:
        frames_left = (seq_len - 1) // 2
        frames_right = frames_left
    num_frames = len(pose_results)

    # get the padded sequence
    pad_left = max(0, frames_left - frame_idx // step)
    pad_right = max(0, frames_right - (num_frames - 1 - frame_idx) // step)
    start = max(frame_idx % step, frame_idx - frames_left * step)
    end = min(num_frames - (num_frames - 1 - frame_idx) % step,
              frame_idx + frames_right * step + 1)
    pose_results_seq = [pose_results[0]] * pad_left + \
        pose_results[start:end:step] + [pose_results[-1]] * pad_right
    return pose_results_seq
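A quick illustration of the padding logic (the frame entries are stand-ins for per-frame lists of PoseDataSample): with a causal 5-frame window and only two frames of history, the earliest frame is replicated on the left.

history = [['frame0'], ['frame1']]
seq = extract_pose_sequence(history, frame_idx=1, causal=True, seq_len=5, step=1)
# seq == [['frame0'], ['frame0'], ['frame0'], ['frame0'], ['frame1']]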


def collate_pose_sequence(pose_results_2d,
                          with_track_id=True,
                          target_frame=-1):
    """Reorganize multi-frame pose detection results into individual pose
    sequences.

    Note:
        - The temporal length of the pose detection results: T
        - The number of the person instances: N
        - The number of the keypoints: K
        - The channel number of each keypoint: C

    Args:
        pose_results_2d (List[List[:obj:`PoseDataSample`]]): Multi-frame pose
            detection results stored in a nested list. Each element of the
            outer list is the pose detection results of a single frame, and
            each element of the inner list is the pose information of one
            person, which contains:

                - keypoints (ndarray[K, 2 or 3]): x, y, [score]
                - track_id (int): unique id of each person, required when
                  ``with_track_id==True``

        with_track_id (bool): If True, the element in pose_results is expected
            to contain "track_id", which will be used to gather the pose
            sequence of a person from multiple frames. Otherwise, the pose
            results in each frame are expected to have a consistent number and
            order of identities. Default is True.
        target_frame (int): The index of the target frame. Default: -1.

    Returns:
        List[:obj:`PoseDataSample`]: Individual pose sequences, with length N.
    """
    T = len(pose_results_2d)
    assert T > 0

    target_frame = (T + target_frame) % T  # convert negative index to positive

    N = len(
        pose_results_2d[target_frame])  # use identities in the target frame
    if N == 0:
        return []

    B, K, C = pose_results_2d[target_frame][0].pred_instances.keypoints.shape

    track_ids = None
    if with_track_id:
        track_ids = [res.track_id for res in pose_results_2d[target_frame]]

    pose_sequences = []
    for idx in range(N):
        pose_seq = PoseDataSample()
        pred_instances = InstanceData()

        gt_instances = pose_results_2d[target_frame][idx].gt_instances.clone()
        pred_instances = pose_results_2d[target_frame][
            idx].pred_instances.clone()
        pose_seq.pred_instances = pred_instances
        pose_seq.gt_instances = gt_instances

        if not with_track_id:
            pose_seq.pred_instances.keypoints = np.stack([
                frame[idx].pred_instances.keypoints
                for frame in pose_results_2d
            ], axis=1)
        else:
            keypoints = np.zeros((B, T, K, C), dtype=np.float32)
            keypoints[:, target_frame] = pose_results_2d[target_frame][
                idx].pred_instances.keypoints
            # find the left most frame containing track_ids[idx]
            for frame_idx in range(target_frame - 1, -1, -1):
                contains_idx = False
                for res in pose_results_2d[frame_idx]:
                    if res.track_id == track_ids[idx]:
                        keypoints[:, frame_idx] = res.pred_instances.keypoints
                        contains_idx = True
                        break
                if not contains_idx:
                    # replicate the left most frame
                    keypoints[:, :frame_idx + 1] = keypoints[:, frame_idx + 1]
                    break
            # find the right most frame containing track_ids[idx]
            for frame_idx in range(target_frame + 1, T):
                contains_idx = False
                for res in pose_results_2d[frame_idx]:
                    if res.track_id == track_ids[idx]:
                        keypoints[:, frame_idx] = res.pred_instances.keypoints
                        contains_idx = True
                        break
                if not contains_idx:
                    # replicate the right most frame
                    keypoints[:, frame_idx + 1:] = keypoints[:, frame_idx]
                    break
            pose_seq.pred_instances.set_field(keypoints, 'keypoints')
        pose_sequences.append(pose_seq)

    return pose_sequences


def inference_pose_lifter_model(model,
                                pose_results_2d,
                                with_track_id=True,
                                image_size=None,
                                norm_pose_2d=False):
    """Inference 3D pose from 2D pose sequences using a pose lifter model.

    Args:
        model (nn.Module): The loaded pose lifter model
        pose_results_2d (List[List[:obj:`PoseDataSample`]]): The 2D pose
            sequences stored in a nested list.
        with_track_id: If True, the element in pose_results_2d is expected to
            contain "track_id", which will be used to gather the pose sequence
            of a person from multiple frames. Otherwise, the pose results in
            each frame are expected to have a consistent number and order of
            identities. Default is True.
        image_size (tuple|list): image width, image height. If None, image size
            will not be contained in dict ``data``.
        norm_pose_2d (bool): If True, scale the bbox (along with the 2D
            pose) to the average bbox scale of the dataset, and move the bbox
            (along with the 2D pose) to the average bbox center of the dataset.

    Returns:
        List[:obj:`PoseDataSample`]: 3D pose inference results. Specifically,
        the predicted keypoints and scores are saved at
        ``data_sample.pred_instances.keypoints_3d``.
    """
    init_default_scope(model.cfg.get('default_scope', 'mmpose'))
    pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)

    causal = model.cfg.test_dataloader.dataset.get('causal', False)
    target_idx = -1 if causal else len(pose_results_2d) // 2

    dataset_info = model.dataset_meta
    if dataset_info is not None:
        if 'stats_info' in dataset_info:
            bbox_center = dataset_info['stats_info']['bbox_center']
            bbox_scale = dataset_info['stats_info']['bbox_scale']
        else:
            if norm_pose_2d:
                # compute the average bbox center and scale from the
                # datasamples in pose_results_2d
                bbox_center = np.zeros((1, 2), dtype=np.float32)
                bbox_scale = 0
                num_bbox = 0
                for pose_res in pose_results_2d:
                    for data_sample in pose_res:
                        for bbox in data_sample.pred_instances.bboxes:
                            bbox_center += np.array([[(bbox[0] + bbox[2]) / 2,
                                                      (bbox[1] + bbox[3]) / 2]
                                                     ])
                            bbox_scale += max(bbox[2] - bbox[0],
                                              bbox[3] - bbox[1])
                            num_bbox += 1
                bbox_center /= num_bbox
                bbox_scale /= num_bbox
            else:
                bbox_center = None
                bbox_scale = None

    pose_results_2d_copy = []
    for i, pose_res in enumerate(pose_results_2d):
        pose_res_copy = []
        for j, data_sample in enumerate(pose_res):
            data_sample_copy = PoseDataSample()
            data_sample_copy.gt_instances = data_sample.gt_instances.clone()
            data_sample_copy.pred_instances = data_sample.pred_instances.clone()
            data_sample_copy.track_id = data_sample.track_id
            kpts = data_sample.pred_instances.keypoints
            bboxes = data_sample.pred_instances.bboxes
            keypoints = []
            for k in range(len(kpts)):
                kpt = kpts[k]
                if norm_pose_2d:
                    bbox = bboxes[k]
                    center = np.array([[(bbox[0] + bbox[2]) / 2,
                                        (bbox[1] + bbox[3]) / 2]])
                    scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
                    keypoints.append((kpt[:, :2] - center) / scale *
                                     bbox_scale + bbox_center)
                else:
                    keypoints.append(kpt[:, :2])
            data_sample_copy.pred_instances.set_field(
                np.array(keypoints), 'keypoints')
            pose_res_copy.append(data_sample_copy)
        pose_results_2d_copy.append(pose_res_copy)

    pose_sequences_2d = collate_pose_sequence(pose_results_2d_copy,
                                              with_track_id, target_idx)

    if not pose_sequences_2d:
        return []

    data_list = []
    for i, pose_seq in enumerate(pose_sequences_2d):
        data_info = dict()

        keypoints_2d = pose_seq.pred_instances.keypoints
        keypoints_2d = np.squeeze(
            keypoints_2d, axis=0) if keypoints_2d.ndim == 4 else keypoints_2d

        T, K, C = keypoints_2d.shape

        data_info['keypoints'] = keypoints_2d
        data_info['keypoints_visible'] = np.ones((T, K), dtype=np.float32)
        data_info['lifting_target'] = np.zeros((1, K, 3), dtype=np.float32)
        data_info['factor'] = np.zeros((T, ), dtype=np.float32)
        data_info['lifting_target_visible'] = np.ones((1, K, 1),
                                                      dtype=np.float32)

        if image_size is not None:
            assert len(image_size) == 2
            data_info['camera_param'] = dict(w=image_size[0], h=image_size[1])

        data_info.update(model.dataset_meta)
        data_list.append(pipeline(data_info))

    if data_list:
        # collate data list into a batch, which is a dict with following keys:
        # batch['inputs']: a list of input images
        # batch['data_samples']: a list of :obj:`PoseDataSample`
        batch = pseudo_collate(data_list)
        with torch.no_grad():
            results = model.test_step(batch)
    else:
        results = []

    return results
mmpose/apis/inference_tracking.py
ADDED
@@ -0,0 +1,103 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np

from mmpose.evaluation.functional.nms import oks_iou


def _compute_iou(bboxA, bboxB):
    """Compute the Intersection over Union (IoU) between two bboxes.

    Args:
        bboxA (list): The first bbox info (left, top, right, bottom, score).
        bboxB (list): The second bbox info (left, top, right, bottom, score).

    Returns:
        float: The IoU value.
    """

    x1 = max(bboxA[0], bboxB[0])
    y1 = max(bboxA[1], bboxB[1])
    x2 = min(bboxA[2], bboxB[2])
    y2 = min(bboxA[3], bboxB[3])

    inter_area = max(0, x2 - x1) * max(0, y2 - y1)

    bboxA_area = (bboxA[2] - bboxA[0]) * (bboxA[3] - bboxA[1])
    bboxB_area = (bboxB[2] - bboxB[0]) * (bboxB[3] - bboxB[1])
    union_area = float(bboxA_area + bboxB_area - inter_area)
    if union_area == 0:
        union_area = 1e-5
        warnings.warn('union_area=0 is unexpected')

    iou = inter_area / union_area

    return iou
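A quick sanity check (illustrative only) on two partially overlapping boxes in (left, top, right, bottom, score) format:

boxA = [0, 0, 10, 10, 1.0]
boxB = [5, 5, 15, 15, 1.0]
print(_compute_iou(boxA, boxB))  # 25 / (100 + 100 - 25) ≈ 0.143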


def _track_by_iou(res, results_last, thr):
    """Get track id using IoU tracking greedily."""

    bbox = list(np.squeeze(res.pred_instances.bboxes, axis=0))

    max_iou_score = -1
    max_index = -1
    match_result = {}
    for index, res_last in enumerate(results_last):
        bbox_last = list(np.squeeze(res_last.pred_instances.bboxes, axis=0))

        iou_score = _compute_iou(bbox, bbox_last)
        if iou_score > max_iou_score:
            max_iou_score = iou_score
            max_index = index

    if max_iou_score > thr:
        track_id = results_last[max_index].track_id
        match_result = results_last[max_index]
        del results_last[max_index]
    else:
        track_id = -1

    return track_id, results_last, match_result


def _track_by_oks(res, results_last, thr, sigmas=None):
    """Get track id using OKS tracking greedily."""
    keypoint = np.concatenate((res.pred_instances.keypoints,
                               res.pred_instances.keypoint_scores[:, :, None]),
                              axis=2)
    keypoint = np.squeeze(keypoint, axis=0).reshape((-1))
    area = np.squeeze(res.pred_instances.areas, axis=0)
    max_index = -1
    match_result = {}

    if len(results_last) == 0:
        return -1, results_last, match_result

    keypoints_last = np.array([
        np.squeeze(
            np.concatenate(
                (res_last.pred_instances.keypoints,
                 res_last.pred_instances.keypoint_scores[:, :, None]),
                axis=2),
            axis=0).reshape((-1)) for res_last in results_last
    ])
    area_last = np.array([
        np.squeeze(res_last.pred_instances.areas, axis=0)
        for res_last in results_last
    ])

    oks_score = oks_iou(
        keypoint, keypoints_last, area, area_last, sigmas=sigmas)

    max_index = np.argmax(oks_score)

    if oks_score[max_index] > thr:
        track_id = results_last[max_index].track_id
        match_result = results_last[max_index]
        del results_last[max_index]
    else:
        track_id = -1

    return track_id, results_last, match_result
mmpose/apis/inferencers/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hand3d_inferencer import Hand3DInferencer
from .mmpose_inferencer import MMPoseInferencer
from .pose2d_inferencer import Pose2DInferencer
from .pose3d_inferencer import Pose3DInferencer
from .utils import get_model_aliases

__all__ = [
    'Pose2DInferencer', 'MMPoseInferencer', 'get_model_aliases',
    'Pose3DInferencer', 'Hand3DInferencer'
]
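A hypothetical usage sketch of the re-exported inferencer (the model alias and image path are assumptions, not assets of this commit):

from mmpose.apis.inferencers import MMPoseInferencer

inferencer = MMPoseInferencer('human')             # built-in 2D human-pose alias
result_generator = inferencer('path/to/image.jpg', show=False)
result = next(result_generator)                    # dict with 'visualization' and 'predictions'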
mmpose/apis/inferencers/base_mmpose_inferencer.py
ADDED
@@ -0,0 +1,691 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import logging
import mimetypes
import os
from collections import defaultdict
from typing import (Callable, Dict, Generator, Iterable, List, Optional,
                    Sequence, Tuple, Union)

import cv2
import mmcv
import mmengine
import numpy as np
import torch.nn as nn
from mmengine.config import Config, ConfigDict
from mmengine.dataset import Compose
from mmengine.fileio import (get_file_backend, isdir, join_path,
                             list_dir_or_file)
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.logging import print_log
from mmengine.registry import init_default_scope
from mmengine.runner.checkpoint import _load_checkpoint_to_model
from mmengine.structures import InstanceData
from mmengine.utils import mkdir_or_exist
from rich.progress import track

from mmpose.apis.inference import dataset_meta_from_config
from mmpose.registry import DATASETS
from mmpose.structures import PoseDataSample, split_instances
from .utils import default_det_models

try:
    from mmdet.apis.det_inferencer import DetInferencer
    has_mmdet = True
except (ImportError, ModuleNotFoundError):
    has_mmdet = False

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ConfigType = Union[Config, ConfigDict]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]


class BaseMMPoseInferencer(BaseInferencer):
    """The base class for MMPose inferencers."""

    preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'}
    forward_kwargs: set = set()
    visualize_kwargs: set = {
        'return_vis', 'show', 'wait_time', 'draw_bbox', 'radius', 'thickness',
        'kpt_thr', 'vis_out_dir', 'black_background'
    }
    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

    def __init__(self,
                 model: Union[ModelType, str, None] = None,
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = None,
                 show_progress: bool = False) -> None:
        super().__init__(
            model, weights, device, scope, show_progress=show_progress)

    def _init_detector(
        self,
        det_model: Optional[Union[ModelType, str]] = None,
        det_weights: Optional[str] = None,
        det_cat_ids: Optional[Union[int, Tuple]] = None,
        device: Optional[str] = None,
    ):
        object_type = DATASETS.get(self.cfg.dataset_type).__module__.split(
            'datasets.')[-1].split('.')[0].lower()

        if det_model in ('whole_image', 'whole-image') or \
                (det_model is None and
                 object_type not in default_det_models):
            self.detector = None

        else:
            det_scope = 'mmdet'
            if det_model is None:
                det_info = default_det_models[object_type]
                det_model, det_weights, det_cat_ids = det_info[
                    'model'], det_info['weights'], det_info['cat_ids']
            elif os.path.exists(det_model):
                det_cfg = Config.fromfile(det_model)
                det_scope = det_cfg.default_scope

            if has_mmdet:
                det_kwargs = dict(
                    model=det_model,
                    weights=det_weights,
                    device=device,
                    scope=det_scope,
                )
                # for compatibility with low version of mmdet
                if 'show_progress' in inspect.signature(
                        DetInferencer).parameters:
                    det_kwargs['show_progress'] = False

                self.detector = DetInferencer(**det_kwargs)
            else:
                raise RuntimeError(
                    'MMDetection (v3.0.0 or above) is required to build '
                    'inferencers for top-down pose estimation models.')

        if isinstance(det_cat_ids, (tuple, list)):
            self.det_cat_ids = det_cat_ids
        else:
            self.det_cat_ids = (det_cat_ids, )

    def _load_weights_to_model(self, model: nn.Module,
                               checkpoint: Optional[dict],
                               cfg: Optional[ConfigType]) -> None:
        """Loading model weights and meta information from cfg and checkpoint.

        Subclasses could override this method to load extra meta information
        from ``checkpoint`` and ``cfg`` to model.

        Args:
            model (nn.Module): Model to load weights and meta information.
            checkpoint (dict, optional): The loaded checkpoint.
            cfg (Config or ConfigDict, optional): The loaded config.
        """
        if checkpoint is not None:
            _load_checkpoint_to_model(model, checkpoint)
            checkpoint_meta = checkpoint.get('meta', {})
            # save the dataset_meta in the model for convenience
            if 'dataset_meta' in checkpoint_meta:
                # mmpose 1.x
                model.dataset_meta = checkpoint_meta['dataset_meta']
            else:
                print_log(
                    'dataset_meta are not saved in the checkpoint\'s '
                    'meta data, load via config.',
                    logger='current',
                    level=logging.WARNING)
                model.dataset_meta = dataset_meta_from_config(
                    cfg, dataset_mode='train')
        else:
            print_log(
                'Checkpoint is not loaded, and the inference '
                'result is calculated by the randomly initialized '
                'model!',
                logger='current',
                level=logging.WARNING)
            model.dataset_meta = dataset_meta_from_config(
                cfg, dataset_mode='train')

    def _inputs_to_list(self, inputs: InputsType) -> Iterable:
        """Preprocess the inputs to a list.

        Preprocess inputs to a list according to its type:

        - list or tuple: return inputs
        - str:
            - Directory path: return all files in the directory
            - other cases: return a list containing the string. The string
              could be a path to file, a url or other types of string
              according to the task.

        Args:
            inputs (InputsType): Inputs for the inferencer.

        Returns:
            list: List of input for the :meth:`preprocess`.
        """
        self._video_input = False

        if isinstance(inputs, str):
            backend = get_file_backend(inputs)
            if hasattr(backend, 'isdir') and isdir(inputs):
                # Backends like HttpsBackend do not implement `isdir`, so only
                # those backends that implement `isdir` could accept the
                # inputs as a directory
                filepath_list = [
                    join_path(inputs, fname)
                    for fname in list_dir_or_file(inputs, list_dir=False)
                ]
                inputs = []
                for filepath in filepath_list:
                    input_type = mimetypes.guess_type(filepath)[0].split(
                        '/')[0]
                    if input_type == 'image':
                        inputs.append(filepath)
                inputs.sort()
            else:
                # if inputs is a path to a video file, it will be converted
                # to a list containing separated frame filenames
                input_type = mimetypes.guess_type(inputs)[0].split('/')[0]
                if input_type == 'video':
                    self._video_input = True
                    video = mmcv.VideoReader(inputs)
                    self.video_info = dict(
                        fps=video.fps,
                        name=os.path.basename(inputs),
                        writer=None,
                        width=video.width,
                        height=video.height,
                        predictions=[])
                    inputs = video
                elif input_type == 'image':
                    inputs = [inputs]
                else:
                    raise ValueError(f'Expected input to be an image, video, '
                                     f'or folder, but received {inputs} of '
                                     f'type {input_type}.')

        elif isinstance(inputs, np.ndarray):
            inputs = [inputs]

        return inputs

    def _get_webcam_inputs(self, inputs: str) -> Generator:
        """Sets up and returns a generator function that reads frames from a
        webcam input. The generator function returns a new frame each time it
        is iterated over.

        Args:
            inputs (str): A string describing the webcam input, in the format
                "webcam:id".

        Returns:
            A generator function that yields frames from the webcam input.

        Raises:
            ValueError: If the inputs string is not in the expected format.
        """

        # Ensure the inputs string is in the expected format.
        inputs = inputs.lower()
        assert inputs.startswith('webcam'), f'Expected input to start with ' \
            f'"webcam", but got "{inputs}"'

        # Parse the camera ID from the inputs string.
        inputs_ = inputs.split(':')
        if len(inputs_) == 1:
            camera_id = 0
        elif len(inputs_) == 2 and str.isdigit(inputs_[1]):
            camera_id = int(inputs_[1])
        else:
            raise ValueError(
                f'Expected webcam input to have format "webcam:id", '
                f'but got "{inputs}"')

        # Attempt to open the video capture object.
        vcap = cv2.VideoCapture(camera_id)
        if not vcap.isOpened():
            print_log(
                f'Cannot open camera (ID={camera_id})',
                logger='current',
                level=logging.WARNING)
            return []

        # Set video input flag and metadata.
        self._video_input = True
        (major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.')
if int(major_ver) < 3:
|
| 262 |
+
fps = vcap.get(cv2.cv.CV_CAP_PROP_FPS)
|
| 263 |
+
width = vcap.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)
|
| 264 |
+
height = vcap.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT)
|
| 265 |
+
else:
|
| 266 |
+
fps = vcap.get(cv2.CAP_PROP_FPS)
|
| 267 |
+
width = vcap.get(cv2.CAP_PROP_FRAME_WIDTH)
|
| 268 |
+
height = vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
| 269 |
+
self.video_info = dict(
|
| 270 |
+
fps=fps,
|
| 271 |
+
name='webcam.mp4',
|
| 272 |
+
writer=None,
|
| 273 |
+
width=width,
|
| 274 |
+
height=height,
|
| 275 |
+
predictions=[])
|
| 276 |
+
|
| 277 |
+
def _webcam_reader() -> Generator:
|
| 278 |
+
while True:
|
| 279 |
+
if cv2.waitKey(5) & 0xFF == 27:
|
| 280 |
+
vcap.release()
|
| 281 |
+
break
|
| 282 |
+
|
| 283 |
+
ret_val, frame = vcap.read()
|
| 284 |
+
if not ret_val:
|
| 285 |
+
break
|
| 286 |
+
|
| 287 |
+
yield frame
|
| 288 |
+
|
| 289 |
+
return _webcam_reader()
|
| 290 |
+
|
| 291 |
+
def _init_pipeline(self, cfg: ConfigType) -> Callable:
|
| 292 |
+
"""Initialize the test pipeline.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
cfg (ConfigType): model config path or dict
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
A pipeline to handle various input data, such as ``str``,
|
| 299 |
+
``np.ndarray``. The returned pipeline will be used to process
|
| 300 |
+
a single data.
|
| 301 |
+
"""
|
| 302 |
+
scope = cfg.get('default_scope', 'mmpose')
|
| 303 |
+
if scope is not None:
|
| 304 |
+
init_default_scope(scope)
|
| 305 |
+
return Compose(cfg.test_dataloader.dataset.pipeline)
|
| 306 |
+
|
| 307 |
+
def update_model_visualizer_settings(self, **kwargs):
|
| 308 |
+
"""Update the settings of models and visualizer according to inference
|
| 309 |
+
arguments."""
|
| 310 |
+
|
| 311 |
+
pass
|
| 312 |
+
|
| 313 |
+
def preprocess(self,
|
| 314 |
+
inputs: InputsType,
|
| 315 |
+
batch_size: int = 1,
|
| 316 |
+
bboxes: Optional[List] = None,
|
| 317 |
+
bbox_thr: float = 0.3,
|
| 318 |
+
nms_thr: float = 0.3,
|
| 319 |
+
**kwargs):
|
| 320 |
+
"""Process the inputs into a model-feedable format.
|
| 321 |
+
|
| 322 |
+
Args:
|
| 323 |
+
inputs (InputsType): Inputs given by user.
|
| 324 |
+
batch_size (int): batch size. Defaults to 1.
|
| 325 |
+
bbox_thr (float): threshold for bounding box detection.
|
| 326 |
+
Defaults to 0.3.
|
| 327 |
+
nms_thr (float): IoU threshold for bounding box NMS.
|
| 328 |
+
Defaults to 0.3.
|
| 329 |
+
|
| 330 |
+
Yields:
|
| 331 |
+
Any: Data processed by the ``pipeline`` and ``collate_fn``.
|
| 332 |
+
List[str or np.ndarray]: List of original inputs in the batch
|
| 333 |
+
"""
|
| 334 |
+
|
| 335 |
+
# One-stage pose estimators perform prediction filtering within the
|
| 336 |
+
# head's `predict` method. Here, we set the arguments for filtering
|
| 337 |
+
if self.cfg.model.type == 'BottomupPoseEstimator':
|
| 338 |
+
# 1. init with default arguments
|
| 339 |
+
test_cfg = self.model.head.test_cfg.copy()
|
| 340 |
+
# 2. update the score_thr and nms_thr in the test_cfg of the head
|
| 341 |
+
if 'score_thr' in test_cfg:
|
| 342 |
+
test_cfg['score_thr'] = bbox_thr
|
| 343 |
+
if 'nms_thr' in test_cfg:
|
| 344 |
+
test_cfg['nms_thr'] = nms_thr
|
| 345 |
+
self.model.test_cfg = test_cfg
|
| 346 |
+
|
| 347 |
+
for i, input in enumerate(inputs):
|
| 348 |
+
bbox = bboxes[i] if bboxes else []
|
| 349 |
+
data_infos = self.preprocess_single(
|
| 350 |
+
input,
|
| 351 |
+
index=i,
|
| 352 |
+
bboxes=bbox,
|
| 353 |
+
bbox_thr=bbox_thr,
|
| 354 |
+
nms_thr=nms_thr,
|
| 355 |
+
**kwargs)
|
| 356 |
+
# only supports inference with batch size 1
|
| 357 |
+
yield self.collate_fn(data_infos), [input]
|
| 358 |
+
|
| 359 |
+
def __call__(
|
| 360 |
+
self,
|
| 361 |
+
inputs: InputsType,
|
| 362 |
+
return_datasamples: bool = False,
|
| 363 |
+
batch_size: int = 1,
|
| 364 |
+
out_dir: Optional[str] = None,
|
| 365 |
+
**kwargs,
|
| 366 |
+
) -> dict:
|
| 367 |
+
"""Call the inferencer.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
inputs (InputsType): Inputs for the inferencer.
|
| 371 |
+
return_datasamples (bool): Whether to return results as
|
| 372 |
+
:obj:`BaseDataElement`. Defaults to False.
|
| 373 |
+
batch_size (int): Batch size. Defaults to 1.
|
| 374 |
+
out_dir (str, optional): directory to save visualization
|
| 375 |
+
results and predictions. Will be overoden if vis_out_dir or
|
| 376 |
+
pred_out_dir are given. Defaults to None
|
| 377 |
+
**kwargs: Key words arguments passed to :meth:`preprocess`,
|
| 378 |
+
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
|
| 379 |
+
Each key in kwargs should be in the corresponding set of
|
| 380 |
+
``preprocess_kwargs``, ``forward_kwargs``,
|
| 381 |
+
``visualize_kwargs`` and ``postprocess_kwargs``.
|
| 382 |
+
|
| 383 |
+
Returns:
|
| 384 |
+
dict: Inference and visualization results.
|
| 385 |
+
"""
|
| 386 |
+
if out_dir is not None:
|
| 387 |
+
if 'vis_out_dir' not in kwargs:
|
| 388 |
+
kwargs['vis_out_dir'] = f'{out_dir}/visualizations'
|
| 389 |
+
if 'pred_out_dir' not in kwargs:
|
| 390 |
+
kwargs['pred_out_dir'] = f'{out_dir}/predictions'
|
| 391 |
+
|
| 392 |
+
(
|
| 393 |
+
preprocess_kwargs,
|
| 394 |
+
forward_kwargs,
|
| 395 |
+
visualize_kwargs,
|
| 396 |
+
postprocess_kwargs,
|
| 397 |
+
) = self._dispatch_kwargs(**kwargs)
|
| 398 |
+
|
| 399 |
+
self.update_model_visualizer_settings(**kwargs)
|
| 400 |
+
|
| 401 |
+
# preprocessing
|
| 402 |
+
if isinstance(inputs, str) and inputs.startswith('webcam'):
|
| 403 |
+
inputs = self._get_webcam_inputs(inputs)
|
| 404 |
+
batch_size = 1
|
| 405 |
+
if not visualize_kwargs.get('show', False):
|
| 406 |
+
print_log(
|
| 407 |
+
'The display mode is closed when using webcam '
|
| 408 |
+
'input. It will be turned on automatically.',
|
| 409 |
+
logger='current',
|
| 410 |
+
level=logging.WARNING)
|
| 411 |
+
visualize_kwargs['show'] = True
|
| 412 |
+
else:
|
| 413 |
+
inputs = self._inputs_to_list(inputs)
|
| 414 |
+
|
| 415 |
+
# check the compatibility between inputs/outputs
|
| 416 |
+
if not self._video_input and len(inputs) > 0:
|
| 417 |
+
vis_out_dir = visualize_kwargs.get('vis_out_dir', None)
|
| 418 |
+
if vis_out_dir is not None:
|
| 419 |
+
_, file_extension = os.path.splitext(vis_out_dir)
|
| 420 |
+
assert not file_extension, f'the argument `vis_out_dir` ' \
|
| 421 |
+
f'should be a folder while the input contains multiple ' \
|
| 422 |
+
f'images, but got {vis_out_dir}'
|
| 423 |
+
|
| 424 |
+
if 'bbox_thr' in self.forward_kwargs:
|
| 425 |
+
forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1)
|
| 426 |
+
inputs = self.preprocess(
|
| 427 |
+
inputs, batch_size=batch_size, **preprocess_kwargs)
|
| 428 |
+
|
| 429 |
+
preds = []
|
| 430 |
+
|
| 431 |
+
for proc_inputs, ori_inputs in (track(inputs, description='Inference')
|
| 432 |
+
if self.show_progress else inputs):
|
| 433 |
+
preds = self.forward(proc_inputs, **forward_kwargs)
|
| 434 |
+
|
| 435 |
+
visualization = self.visualize(ori_inputs, preds,
|
| 436 |
+
**visualize_kwargs)
|
| 437 |
+
results = self.postprocess(
|
| 438 |
+
preds,
|
| 439 |
+
visualization,
|
| 440 |
+
return_datasamples=return_datasamples,
|
| 441 |
+
**postprocess_kwargs)
|
| 442 |
+
yield results
|
| 443 |
+
|
| 444 |
+
if self._video_input:
|
| 445 |
+
self._finalize_video_processing(
|
| 446 |
+
postprocess_kwargs.get('pred_out_dir', ''))
|
| 447 |
+
|
| 448 |
+
# In 3D Inferencers, some intermediate results (e.g. 2d keypoints)
|
| 449 |
+
# will be temporarily stored in `self._buffer`. It's essential to
|
| 450 |
+
# clear this information to prevent any interference with subsequent
|
| 451 |
+
# inferences.
|
| 452 |
+
if hasattr(self, '_buffer'):
|
| 453 |
+
self._buffer.clear()
|
| 454 |
+
|
| 455 |
+
def visualize(self,
|
| 456 |
+
inputs: list,
|
| 457 |
+
preds: List[PoseDataSample],
|
| 458 |
+
return_vis: bool = False,
|
| 459 |
+
show: bool = False,
|
| 460 |
+
draw_bbox: bool = False,
|
| 461 |
+
wait_time: float = 0,
|
| 462 |
+
radius: int = 3,
|
| 463 |
+
thickness: int = 1,
|
| 464 |
+
kpt_thr: float = 0.3,
|
| 465 |
+
vis_out_dir: str = '',
|
| 466 |
+
window_name: str = '',
|
| 467 |
+
black_background: bool = False,
|
| 468 |
+
**kwargs) -> List[np.ndarray]:
|
| 469 |
+
"""Visualize predictions.
|
| 470 |
+
|
| 471 |
+
Args:
|
| 472 |
+
inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
|
| 473 |
+
preds (Any): Predictions of the model.
|
| 474 |
+
return_vis (bool): Whether to return images with predicted results.
|
| 475 |
+
show (bool): Whether to display the image in a popup window.
|
| 476 |
+
Defaults to False.
|
| 477 |
+
wait_time (float): The interval of show (ms). Defaults to 0
|
| 478 |
+
draw_bbox (bool): Whether to draw the bounding boxes.
|
| 479 |
+
Defaults to False
|
| 480 |
+
radius (int): Keypoint radius for visualization. Defaults to 3
|
| 481 |
+
thickness (int): Link thickness for visualization. Defaults to 1
|
| 482 |
+
kpt_thr (float): The threshold to visualize the keypoints.
|
| 483 |
+
Defaults to 0.3
|
| 484 |
+
vis_out_dir (str, optional): Directory to save visualization
|
| 485 |
+
results w/o predictions. If left as empty, no file will
|
| 486 |
+
be saved. Defaults to ''.
|
| 487 |
+
window_name (str, optional): Title of display window.
|
| 488 |
+
black_background (bool, optional): Whether to plot keypoints on a
|
| 489 |
+
black image instead of the input image. Defaults to False.
|
| 490 |
+
|
| 491 |
+
Returns:
|
| 492 |
+
List[np.ndarray]: Visualization results.
|
| 493 |
+
"""
|
| 494 |
+
if (not return_vis) and (not show) and (not vis_out_dir):
|
| 495 |
+
return
|
| 496 |
+
|
| 497 |
+
if getattr(self, 'visualizer', None) is None:
|
| 498 |
+
raise ValueError('Visualization needs the "visualizer" term'
|
| 499 |
+
'defined in the config, but got None.')
|
| 500 |
+
|
| 501 |
+
self.visualizer.radius = radius
|
| 502 |
+
self.visualizer.line_width = thickness
|
| 503 |
+
|
| 504 |
+
results = []
|
| 505 |
+
|
| 506 |
+
for single_input, pred in zip(inputs, preds):
|
| 507 |
+
if isinstance(single_input, str):
|
| 508 |
+
img = mmcv.imread(single_input, channel_order='rgb')
|
| 509 |
+
elif isinstance(single_input, np.ndarray):
|
| 510 |
+
img = mmcv.bgr2rgb(single_input)
|
| 511 |
+
else:
|
| 512 |
+
raise ValueError('Unsupported input type: '
|
| 513 |
+
f'{type(single_input)}')
|
| 514 |
+
if black_background:
|
| 515 |
+
img = img * 0
|
| 516 |
+
|
| 517 |
+
img_name = os.path.basename(pred.metainfo['img_path'])
|
| 518 |
+
window_name = window_name if window_name else img_name
|
| 519 |
+
|
| 520 |
+
# since visualization and inference utilize the same process,
|
| 521 |
+
# the wait time is reduced when a video input is utilized,
|
| 522 |
+
# thereby eliminating the issue of inference getting stuck.
|
| 523 |
+
wait_time = 1e-5 if self._video_input else wait_time
|
| 524 |
+
|
| 525 |
+
visualization = self.visualizer.add_datasample(
|
| 526 |
+
window_name,
|
| 527 |
+
img,
|
| 528 |
+
pred,
|
| 529 |
+
draw_gt=False,
|
| 530 |
+
draw_bbox=draw_bbox,
|
| 531 |
+
show=show,
|
| 532 |
+
wait_time=wait_time,
|
| 533 |
+
kpt_thr=kpt_thr,
|
| 534 |
+
**kwargs)
|
| 535 |
+
results.append(visualization)
|
| 536 |
+
|
| 537 |
+
if vis_out_dir:
|
| 538 |
+
self.save_visualization(
|
| 539 |
+
visualization,
|
| 540 |
+
vis_out_dir,
|
| 541 |
+
img_name=img_name,
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
if return_vis:
|
| 545 |
+
return results
|
| 546 |
+
else:
|
| 547 |
+
return []
|
| 548 |
+
|
| 549 |
+
def save_visualization(self, visualization, vis_out_dir, img_name=None):
|
| 550 |
+
out_img = mmcv.rgb2bgr(visualization)
|
| 551 |
+
_, file_extension = os.path.splitext(vis_out_dir)
|
| 552 |
+
if file_extension:
|
| 553 |
+
dir_name = os.path.dirname(vis_out_dir)
|
| 554 |
+
file_name = os.path.basename(vis_out_dir)
|
| 555 |
+
else:
|
| 556 |
+
dir_name = vis_out_dir
|
| 557 |
+
file_name = None
|
| 558 |
+
mkdir_or_exist(dir_name)
|
| 559 |
+
|
| 560 |
+
if self._video_input:
|
| 561 |
+
|
| 562 |
+
if self.video_info['writer'] is None:
|
| 563 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 564 |
+
if file_name is None:
|
| 565 |
+
file_name = os.path.basename(self.video_info['name'])
|
| 566 |
+
out_file = join_path(dir_name, file_name)
|
| 567 |
+
self.video_info['output_file'] = out_file
|
| 568 |
+
self.video_info['writer'] = cv2.VideoWriter(
|
| 569 |
+
out_file, fourcc, self.video_info['fps'],
|
| 570 |
+
(visualization.shape[1], visualization.shape[0]))
|
| 571 |
+
self.video_info['writer'].write(out_img)
|
| 572 |
+
|
| 573 |
+
else:
|
| 574 |
+
if file_name is None:
|
| 575 |
+
file_name = img_name if img_name else 'visualization.jpg'
|
| 576 |
+
|
| 577 |
+
out_file = join_path(dir_name, file_name)
|
| 578 |
+
mmcv.imwrite(out_img, out_file)
|
| 579 |
+
print_log(
|
| 580 |
+
f'the output image has been saved at {out_file}',
|
| 581 |
+
logger='current',
|
| 582 |
+
level=logging.INFO)
|
| 583 |
+
|
| 584 |
+
def postprocess(
|
| 585 |
+
self,
|
| 586 |
+
preds: List[PoseDataSample],
|
| 587 |
+
visualization: List[np.ndarray],
|
| 588 |
+
return_datasample=None,
|
| 589 |
+
return_datasamples=False,
|
| 590 |
+
pred_out_dir: str = '',
|
| 591 |
+
) -> dict:
|
| 592 |
+
"""Process the predictions and visualization results from ``forward``
|
| 593 |
+
and ``visualize``.
|
| 594 |
+
|
| 595 |
+
This method should be responsible for the following tasks:
|
| 596 |
+
|
| 597 |
+
1. Convert datasamples into a json-serializable dict if needed.
|
| 598 |
+
2. Pack the predictions and visualization results and return them.
|
| 599 |
+
3. Dump or log the predictions.
|
| 600 |
+
|
| 601 |
+
Args:
|
| 602 |
+
preds (List[Dict]): Predictions of the model.
|
| 603 |
+
visualization (np.ndarray): Visualized predictions.
|
| 604 |
+
return_datasamples (bool): Whether to return results as
|
| 605 |
+
datasamples. Defaults to False.
|
| 606 |
+
pred_out_dir (str): Directory to save the inference results w/o
|
| 607 |
+
visualization. If left as empty, no file will be saved.
|
| 608 |
+
Defaults to ''.
|
| 609 |
+
|
| 610 |
+
Returns:
|
| 611 |
+
dict: Inference and visualization results with key ``predictions``
|
| 612 |
+
and ``visualization``
|
| 613 |
+
|
| 614 |
+
- ``visualization (Any)``: Returned by :meth:`visualize`
|
| 615 |
+
- ``predictions`` (dict or DataSample): Returned by
|
| 616 |
+
:meth:`forward` and processed in :meth:`postprocess`.
|
| 617 |
+
If ``return_datasamples=False``, it usually should be a
|
| 618 |
+
json-serializable dict containing only basic data elements such
|
| 619 |
+
as strings and numbers.
|
| 620 |
+
"""
|
| 621 |
+
if return_datasample is not None:
|
| 622 |
+
print_log(
|
| 623 |
+
'The `return_datasample` argument is deprecated '
|
| 624 |
+
'and will be removed in future versions. Please '
|
| 625 |
+
'use `return_datasamples`.',
|
| 626 |
+
logger='current',
|
| 627 |
+
level=logging.WARNING)
|
| 628 |
+
return_datasamples = return_datasample
|
| 629 |
+
|
| 630 |
+
result_dict = defaultdict(list)
|
| 631 |
+
|
| 632 |
+
result_dict['visualization'] = visualization
|
| 633 |
+
for pred in preds:
|
| 634 |
+
if not return_datasamples:
|
| 635 |
+
# convert datasamples to list of instance predictions
|
| 636 |
+
pred = split_instances(pred.pred_instances)
|
| 637 |
+
result_dict['predictions'].append(pred)
|
| 638 |
+
|
| 639 |
+
if pred_out_dir != '':
|
| 640 |
+
for pred, data_sample in zip(result_dict['predictions'], preds):
|
| 641 |
+
if self._video_input:
|
| 642 |
+
# For video or webcam input, predictions for each frame
|
| 643 |
+
# are gathered in the 'predictions' key of 'video_info'
|
| 644 |
+
# dictionary. All frame predictions are then stored into
|
| 645 |
+
# a single file after processing all frames.
|
| 646 |
+
self.video_info['predictions'].append(pred)
|
| 647 |
+
else:
|
| 648 |
+
# For non-video inputs, predictions are stored in separate
|
| 649 |
+
# JSON files. The filename is determined by the basename
|
| 650 |
+
# of the input image path with a '.json' extension. The
|
| 651 |
+
# predictions are then dumped into this file.
|
| 652 |
+
fname = os.path.splitext(
|
| 653 |
+
os.path.basename(
|
| 654 |
+
data_sample.metainfo['img_path']))[0] + '.json'
|
| 655 |
+
mmengine.dump(
|
| 656 |
+
pred, join_path(pred_out_dir, fname), indent=' ')
|
| 657 |
+
|
| 658 |
+
return result_dict
|
| 659 |
+
|
| 660 |
+
def _finalize_video_processing(
|
| 661 |
+
self,
|
| 662 |
+
pred_out_dir: str = '',
|
| 663 |
+
):
|
| 664 |
+
"""Finalize video processing by releasing the video writer and saving
|
| 665 |
+
predictions to a file.
|
| 666 |
+
|
| 667 |
+
This method should be called after completing the video processing. It
|
| 668 |
+
releases the video writer, if it exists, and saves the predictions to a
|
| 669 |
+
JSON file if a prediction output directory is provided.
|
| 670 |
+
"""
|
| 671 |
+
|
| 672 |
+
# Release the video writer if it exists
|
| 673 |
+
if self.video_info['writer'] is not None:
|
| 674 |
+
out_file = self.video_info['output_file']
|
| 675 |
+
print_log(
|
| 676 |
+
f'the output video has been saved at {out_file}',
|
| 677 |
+
logger='current',
|
| 678 |
+
level=logging.INFO)
|
| 679 |
+
self.video_info['writer'].release()
|
| 680 |
+
|
| 681 |
+
# Save predictions
|
| 682 |
+
if pred_out_dir:
|
| 683 |
+
fname = os.path.splitext(
|
| 684 |
+
os.path.basename(self.video_info['name']))[0] + '.json'
|
| 685 |
+
predictions = [
|
| 686 |
+
dict(frame_id=i, instances=pred)
|
| 687 |
+
for i, pred in enumerate(self.video_info['predictions'])
|
| 688 |
+
]
|
| 689 |
+
|
| 690 |
+
mmengine.dump(
|
| 691 |
+
predictions, join_path(pred_out_dir, fname), indent=' ')
|
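Note that ``__call__`` above is a generator: nothing runs until the caller iterates it, and each iteration yields one result dict with ``predictions`` and ``visualization`` keys. Below is a minimal usage sketch of this interface; the model alias 'human' and the image path are illustrative assumptions, not part of this commit.

# Minimal usage sketch for the inferencer interface defined above.
# Assumptions (not from this commit): mmpose and mmdet are installed,
# 'human' is a valid model alias, and 'demo.jpg' exists on disk.
from mmpose.apis.inferencers import Pose2DInferencer

inferencer = Pose2DInferencer(model='human')

# __call__ is a generator; iterating it is what drives inference.
# out_dir expands to out_dir/visualizations and out_dir/predictions.
for result in inferencer('demo.jpg', out_dir='outputs'):
    instances = result['predictions'][0]  # instance dicts for one image
    print(f'{len(instances)} instance(s) detected')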
mmpose/apis/inferencers/hand3d_inferencer.py ADDED
@@ -0,0 +1,344 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
import torch
from mmengine.config import Config, ConfigDict
from mmengine.infer.infer import ModelType
from mmengine.logging import print_log
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData

from mmpose.evaluation.functional import nms
from mmpose.registry import INFERENCERS
from mmpose.structures import PoseDataSample, merge_data_samples
from .base_mmpose_inferencer import BaseMMPoseInferencer

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ConfigType = Union[Config, ConfigDict]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]


@INFERENCERS.register_module()
class Hand3DInferencer(BaseMMPoseInferencer):
    """The inferencer for 3D hand pose estimation.

    Args:
        model (str, optional): Pretrained 2D pose estimation algorithm.
            It's the path to the config file or the model name defined in
            metafile. For example, it could be:

            - model alias, e.g. ``'body'``,
            - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
            - config path

            Defaults to ``None``.
        weights (str, optional): Path to the checkpoint. If it is not
            specified and "model" is a model name of metafile, the weights
            will be loaded from metafile. Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to "mmpose".
        det_model (str, optional): Config path or alias of detection model.
            Defaults to None.
        det_weights (str, optional): Path to the checkpoints of detection
            model. Defaults to None.
        det_cat_ids (int or list[int], optional): Category id for
            detection model. Defaults to None.
    """

    preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'}
    forward_kwargs: set = {'disable_rebase_keypoint'}
    visualize_kwargs: set = {
        'return_vis',
        'show',
        'wait_time',
        'draw_bbox',
        'radius',
        'thickness',
        'kpt_thr',
        'vis_out_dir',
        'num_instances',
    }
    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

    def __init__(self,
                 model: Union[ModelType, str],
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = 'mmpose',
                 det_model: Optional[Union[ModelType, str]] = None,
                 det_weights: Optional[str] = None,
                 det_cat_ids: Optional[Union[int, Tuple]] = None,
                 show_progress: bool = False) -> None:

        init_default_scope(scope)
        super().__init__(
            model=model,
            weights=weights,
            device=device,
            scope=scope,
            show_progress=show_progress)
        self.model = revert_sync_batchnorm(self.model)

        # assign dataset metainfo to self.visualizer
        self.visualizer.set_dataset_meta(self.model.dataset_meta)

        # initialize hand detector
        self._init_detector(
            det_model=det_model,
            det_weights=det_weights,
            det_cat_ids=det_cat_ids,
            device=device,
        )

        self._video_input = False
        self._buffer = defaultdict(list)

    def preprocess_single(self,
                          input: InputType,
                          index: int,
                          bbox_thr: float = 0.3,
                          nms_thr: float = 0.3,
                          bboxes: Union[List[List], List[np.ndarray],
                                        np.ndarray] = []):
        """Process a single input into a model-feedable format.

        Args:
            input (InputType): Input given by user.
            index (int): index of the input
            bbox_thr (float): threshold for bounding box detection.
                Defaults to 0.3.
            nms_thr (float): IoU threshold for bounding box NMS.
                Defaults to 0.3.

        Yields:
            Any: Data processed by the ``pipeline`` and ``collate_fn``.
        """

        if isinstance(input, str):
            data_info = dict(img_path=input)
        else:
            data_info = dict(
                img=input, img_path=f'{index}.jpg'.rjust(10, '0'))
        data_info.update(self.model.dataset_meta)

        if self.detector is not None:
            try:
                det_results = self.detector(
                    input, return_datasamples=True)['predictions']
            except ValueError:
                print_log(
                    'Support for mmpose and mmdet versions up to 3.1.0 '
                    'will be discontinued in upcoming releases. To '
                    'ensure ongoing compatibility, please upgrade to '
                    'mmdet version 3.2.0 or later.',
                    logger='current',
                    level=logging.WARNING)
                det_results = self.detector(
                    input, return_datasample=True)['predictions']
            pred_instance = det_results[0].pred_instances.cpu().numpy()
            bboxes = np.concatenate(
                (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)

            label_mask = np.zeros(len(bboxes), dtype=np.uint8)
            for cat_id in self.det_cat_ids:
                label_mask = np.logical_or(label_mask,
                                           pred_instance.labels == cat_id)

            bboxes = bboxes[np.logical_and(label_mask,
                                           pred_instance.scores > bbox_thr)]
            bboxes = bboxes[nms(bboxes, nms_thr)]

        data_infos = []
        if len(bboxes) > 0:
            for bbox in bboxes:
                inst = data_info.copy()
                inst['bbox'] = bbox[None, :4]
                inst['bbox_score'] = bbox[4:5]
                data_infos.append(self.pipeline(inst))
        else:
            inst = data_info.copy()

            # get bbox from the image size
            if isinstance(input, str):
                input = mmcv.imread(input)
            h, w = input.shape[:2]

            inst['bbox'] = np.array([[0, 0, w, h]], dtype=np.float32)
            inst['bbox_score'] = np.ones(1, dtype=np.float32)
            data_infos.append(self.pipeline(inst))

        return data_infos

    @torch.no_grad()
    def forward(self,
                inputs: Union[dict, tuple],
                disable_rebase_keypoint: bool = False):
        """Performs a forward pass through the model.

        Args:
            inputs (Union[dict, tuple]): The input data to be processed. Can
                be either a dictionary or a tuple.
            disable_rebase_keypoint (bool, optional): Flag to disable rebasing
                the height of the keypoints. Defaults to False.

        Returns:
            A list of data samples with prediction instances.
        """
        data_samples = self.model.test_step(inputs)
        data_samples_2d = []

        for idx, res in enumerate(data_samples):
            pred_instances = res.pred_instances
            keypoints = pred_instances.keypoints
            rel_root_depth = pred_instances.rel_root_depth
            scores = pred_instances.keypoint_scores
            hand_type = pred_instances.hand_type

            res_2d = PoseDataSample()
            gt_instances = res.gt_instances.clone()
            pred_instances = pred_instances.clone()
            res_2d.gt_instances = gt_instances
            res_2d.pred_instances = pred_instances

            # add relative root depth to left hand joints
            keypoints[:, 21:, 2] += rel_root_depth

            # set joint scores according to hand type
            scores[:, :21] *= hand_type[:, [0]]
            scores[:, 21:] *= hand_type[:, [1]]
            # normalize kpt score
            if scores.max() > 1:
                scores /= 255

            res_2d.pred_instances.set_field(keypoints[..., :2].copy(),
                                            'keypoints')

            # rotate the keypoint to make z-axis correspondent to height
            # for better visualization
            vis_R = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
            keypoints[..., :3] = keypoints[..., :3] @ vis_R

            # rebase height (z-axis)
            if not disable_rebase_keypoint:
                valid = scores > 0
                keypoints[..., 2] -= np.min(
                    keypoints[valid, 2], axis=-1, keepdims=True)

            data_samples[idx].pred_instances.keypoints = keypoints
            data_samples[idx].pred_instances.keypoint_scores = scores
            data_samples_2d.append(res_2d)

        data_samples = [merge_data_samples(data_samples)]
        data_samples_2d = merge_data_samples(data_samples_2d)

        self._buffer['pose2d_results'] = data_samples_2d

        return data_samples

    def visualize(
        self,
        inputs: list,
        preds: List[PoseDataSample],
        return_vis: bool = False,
        show: bool = False,
        draw_bbox: bool = False,
        wait_time: float = 0,
        radius: int = 3,
        thickness: int = 1,
        kpt_thr: float = 0.3,
        num_instances: int = 1,
        vis_out_dir: str = '',
        window_name: str = '',
    ) -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
            preds (Any): Predictions of the model.
            return_vis (bool): Whether to return images with predicted
                results.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (ms). Defaults to 0
            draw_bbox (bool): Whether to draw the bounding boxes.
                Defaults to False
            radius (int): Keypoint radius for visualization. Defaults to 3
            thickness (int): Link thickness for visualization. Defaults to 1
            kpt_thr (float): The threshold to visualize the keypoints.
                Defaults to 0.3
            vis_out_dir (str, optional): Directory to save visualization
                results w/o predictions. If left as empty, no file will
                be saved. Defaults to ''.
            window_name (str, optional): Title of display window.

        Returns:
            List[np.ndarray]: Visualization results.
        """
        if (not return_vis) and (not show) and (not vis_out_dir):
            return

        if getattr(self, 'visualizer', None) is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        self.visualizer.radius = radius
        self.visualizer.line_width = thickness

        results = []

        for single_input, pred in zip(inputs, preds):
            if isinstance(single_input, str):
                img = mmcv.imread(single_input, channel_order='rgb')
            elif isinstance(single_input, np.ndarray):
                img = mmcv.bgr2rgb(single_input)
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')
            img_name = os.path.basename(pred.metainfo['img_path'])

            # since visualization and inference utilize the same process,
            # the wait time is reduced when a video input is utilized,
            # thereby eliminating the issue of inference getting stuck.
            wait_time = 1e-5 if self._video_input else wait_time

            if num_instances < 0:
                num_instances = len(pred.pred_instances)

            visualization = self.visualizer.add_datasample(
                window_name,
                img,
                data_sample=pred,
                det_data_sample=self._buffer['pose2d_results'],
                draw_gt=False,
                draw_bbox=draw_bbox,
                show=show,
                wait_time=wait_time,
                convert_keypoint=False,
                axis_azimuth=-115,
                axis_limit=200,
                axis_elev=15,
                kpt_thr=kpt_thr,
                num_instances=num_instances)
            results.append(visualization)

            if vis_out_dir:
                self.save_visualization(
                    visualization,
                    vis_out_dir,
                    img_name=img_name,
                )

        if return_vis:
            return results
        else:
            return []
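A short sketch of how the class above might be driven; the 'hand3d' alias, the device string, and the input path are assumptions for illustration, not taken from this commit.

# Usage sketch for Hand3DInferencer (alias and paths are assumptions).
from mmpose.apis.inferencers import Hand3DInferencer

inferencer = Hand3DInferencer(model='hand3d', device='cuda:0')

# forward() rotates keypoints so the z-axis corresponds to height and,
# unless disable_rebase_keypoint=True, shifts them so min(z) == 0.
result = next(inferencer('hand_demo.jpg', vis_out_dir='vis_results'))
print(result['predictions'][0][0]['keypoints'][:3])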
mmpose/apis/inferencers/mmpose_inferencer.py ADDED
@@ -0,0 +1,250 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import torch
from mmengine.config import Config, ConfigDict
from mmengine.infer.infer import ModelType
from mmengine.structures import InstanceData
from rich.progress import track

from .base_mmpose_inferencer import BaseMMPoseInferencer
from .hand3d_inferencer import Hand3DInferencer
from .pose2d_inferencer import Pose2DInferencer
from .pose3d_inferencer import Pose3DInferencer

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ConfigType = Union[Config, ConfigDict]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]


class MMPoseInferencer(BaseMMPoseInferencer):
    """MMPose Inferencer. It's a unified inferencer interface for the pose
    estimation task, currently including Pose2D, and it can be used to
    perform 2D keypoint detection.

    Args:
        pose2d (str, optional): Pretrained 2D pose estimation algorithm.
            It's the path to the config file or the model name defined in
            metafile. For example, it could be:

            - model alias, e.g. ``'body'``,
            - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
            - config path

            Defaults to ``None``.
        pose2d_weights (str, optional): Path to the custom checkpoint file of
            the selected pose2d model. If it is not specified and "pose2d" is
            a model name of metafile, the weights will be loaded from
            metafile. Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to "mmpose".
        det_model (str, optional): Config path or alias of detection model.
            Defaults to None.
        det_weights (str, optional): Path to the checkpoints of detection
            model. Defaults to None.
        det_cat_ids (int or list[int], optional): Category id for
            detection model. Defaults to None.
        output_heatmaps (bool, optional): Flag to visualize predicted
            heatmaps. If set to None, the default setting from the model
            config will be used. Default is None.
    """

    preprocess_kwargs: set = {
        'bbox_thr', 'nms_thr', 'bboxes', 'use_oks_tracking', 'tracking_thr',
        'disable_norm_pose_2d'
    }
    forward_kwargs: set = {
        'merge_results', 'disable_rebase_keypoint', 'pose_based_nms'
    }
    visualize_kwargs: set = {
        'return_vis', 'show', 'wait_time', 'draw_bbox', 'radius', 'thickness',
        'kpt_thr', 'vis_out_dir', 'skeleton_style', 'draw_heatmap',
        'black_background', 'num_instances'
    }
    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

    def __init__(self,
                 pose2d: Optional[str] = None,
                 pose2d_weights: Optional[str] = None,
                 pose3d: Optional[str] = None,
                 pose3d_weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: str = 'mmpose',
                 det_model: Optional[Union[ModelType, str]] = None,
                 det_weights: Optional[str] = None,
                 det_cat_ids: Optional[Union[int, List]] = None,
                 show_progress: bool = False) -> None:

        self.visualizer = None
        self.show_progress = show_progress
        if pose3d is not None:
            if 'hand3d' in pose3d:
                self.inferencer = Hand3DInferencer(pose3d, pose3d_weights,
                                                   device, scope, det_model,
                                                   det_weights, det_cat_ids,
                                                   show_progress)
            else:
                self.inferencer = Pose3DInferencer(pose3d, pose3d_weights,
                                                   pose2d, pose2d_weights,
                                                   device, scope, det_model,
                                                   det_weights, det_cat_ids,
                                                   show_progress)
        elif pose2d is not None:
            self.inferencer = Pose2DInferencer(pose2d, pose2d_weights, device,
                                               scope, det_model, det_weights,
                                               det_cat_ids, show_progress)
        else:
            raise ValueError('Either 2d or 3d pose estimation algorithm '
                             'should be provided.')

    def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
        """Process the inputs into a model-feedable format.

        Args:
            inputs (InputsType): Inputs given by user.
            batch_size (int): batch size. Defaults to 1.

        Yields:
            Any: Data processed by the ``pipeline`` and ``collate_fn``.
            List[str or np.ndarray]: List of original inputs in the batch
        """
        for data in self.inferencer.preprocess(inputs, batch_size, **kwargs):
            yield data

    @torch.no_grad()
    def forward(self, inputs: InputType, **forward_kwargs) -> PredType:
        """Forward the inputs to the model.

        Args:
            inputs (InputsType): The inputs to be forwarded.

        Returns:
            Dict: The prediction results. Possibly with keys "pose2d".
        """
        return self.inferencer.forward(inputs, **forward_kwargs)

    def __call__(
        self,
        inputs: InputsType,
        return_datasamples: bool = False,
        batch_size: int = 1,
        out_dir: Optional[str] = None,
        **kwargs,
    ) -> dict:
        """Call the inferencer.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            return_datasamples (bool): Whether to return results as
                :obj:`BaseDataElement`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            out_dir (str, optional): directory to save visualization
                results and predictions. Will be overridden if vis_out_dir
                or pred_out_dir are given. Defaults to None
            **kwargs: Key words arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``,
                ``visualize_kwargs`` and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results.
        """
        if out_dir is not None:
            if 'vis_out_dir' not in kwargs:
                kwargs['vis_out_dir'] = f'{out_dir}/visualizations'
            if 'pred_out_dir' not in kwargs:
                kwargs['pred_out_dir'] = f'{out_dir}/predictions'

        kwargs = {
            key: value
            for key, value in kwargs.items()
            if key in set.union(self.inferencer.preprocess_kwargs,
                                self.inferencer.forward_kwargs,
                                self.inferencer.visualize_kwargs,
                                self.inferencer.postprocess_kwargs)
        }
        (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        ) = self._dispatch_kwargs(**kwargs)

        self.inferencer.update_model_visualizer_settings(**kwargs)

        # preprocessing
        if isinstance(inputs, str) and inputs.startswith('webcam'):
            inputs = self.inferencer._get_webcam_inputs(inputs)
            batch_size = 1
            if not visualize_kwargs.get('show', False):
                warnings.warn('The display mode is closed when using webcam '
                              'input. It will be turned on automatically.')
                visualize_kwargs['show'] = True
        else:
            inputs = self.inferencer._inputs_to_list(inputs)
        self._video_input = self.inferencer._video_input
        if self._video_input:
            self.video_info = self.inferencer.video_info

        inputs = self.preprocess(
            inputs, batch_size=batch_size, **preprocess_kwargs)

        # forward
        if 'bbox_thr' in self.inferencer.forward_kwargs:
            forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1)

        preds = []

        for proc_inputs, ori_inputs in (track(inputs, description='Inference')
                                        if self.show_progress else inputs):
            preds = self.forward(proc_inputs, **forward_kwargs)

            visualization = self.visualize(ori_inputs, preds,
                                           **visualize_kwargs)
            results = self.postprocess(
                preds,
                visualization,
                return_datasamples=return_datasamples,
                **postprocess_kwargs)
            yield results

        if self._video_input:
            self._finalize_video_processing(
                postprocess_kwargs.get('pred_out_dir', ''))

    def visualize(self, inputs: InputsType, preds: PredType,
                  **kwargs) -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
            preds (Any): Predictions of the model.
            return_vis (bool): Whether to return images with predicted
                results.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            show_interval (int): The interval of show (s). Defaults to 0
            radius (int): Keypoint radius for visualization. Defaults to 3
            thickness (int): Link thickness for visualization. Defaults to 1
            kpt_thr (float): The threshold to visualize the keypoints.
                Defaults to 0.3
            vis_out_dir (str, optional): directory to save visualization
                results w/o predictions. If left as empty, no file will
                be saved. Defaults to ''.

        Returns:
            List[np.ndarray]: Visualization results.
        """
        window_name = ''
        if self.inferencer._video_input:
            window_name = self.inferencer.video_info['name']

        return self.inferencer.visualize(
            inputs, preds, window_name=window_name, **kwargs)
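Because MMPoseInferencer is a thin facade, the constructor arguments fully determine which concrete inferencer does the work: a ``pose3d`` value containing 'hand3d' selects Hand3DInferencer, any other ``pose3d`` selects Pose3DInferencer, and ``pose2d`` alone selects Pose2DInferencer. A dispatch sketch follows; the model aliases are illustrative assumptions.

# Dispatch sketch for MMPoseInferencer (aliases are assumptions).
from mmpose.apis.inferencers import MMPoseInferencer

inferencer = MMPoseInferencer(pose2d='human')      # -> Pose2DInferencer
# inferencer = MMPoseInferencer(pose3d='human3d')  # -> Pose3DInferencer
# inferencer = MMPoseInferencer(pose3d='hand3d')   # -> Hand3DInferencer

# Webcam input forces show=True and batch_size=1; pressing Esc in the
# display window releases the capture (see _webcam_reader above).
for result in inferencer('webcam:0'):
    pass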
mmpose/apis/inferencers/pose2d_inferencer.py ADDED
@@ -0,0 +1,262 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
import torch
from mmengine.config import Config, ConfigDict
from mmengine.infer.infer import ModelType
from mmengine.logging import print_log
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData

from mmpose.evaluation.functional import nearby_joints_nms, nms
from mmpose.registry import INFERENCERS
from mmpose.structures import merge_data_samples
from .base_mmpose_inferencer import BaseMMPoseInferencer

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ConfigType = Union[Config, ConfigDict]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]


@INFERENCERS.register_module(name='pose-estimation')
@INFERENCERS.register_module()
class Pose2DInferencer(BaseMMPoseInferencer):
    """The inferencer for 2D pose estimation.

    Args:
        model (str, optional): Pretrained 2D pose estimation algorithm.
            It's the path to the config file or the model name defined in
            metafile. For example, it could be:

            - model alias, e.g. ``'body'``,
            - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
            - config path

            Defaults to ``None``.
        weights (str, optional): Path to the checkpoint. If it is not
            specified and "model" is a model name of metafile, the weights
            will be loaded from metafile. Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to "mmpose".
        det_model (str, optional): Config path or alias of detection model.
            Defaults to None.
        det_weights (str, optional): Path to the checkpoints of detection
            model. Defaults to None.
        det_cat_ids (int or list[int], optional): Category id for
            detection model. Defaults to None.
    """

    preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'}
    forward_kwargs: set = {'merge_results', 'pose_based_nms'}
    visualize_kwargs: set = {
        'return_vis',
        'show',
        'wait_time',
        'draw_bbox',
        'radius',
        'thickness',
        'kpt_thr',
        'vis_out_dir',
        'skeleton_style',
        'draw_heatmap',
        'black_background',
    }
    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

    def __init__(self,
                 model: Union[ModelType, str],
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = 'mmpose',
                 det_model: Optional[Union[ModelType, str]] = None,
                 det_weights: Optional[str] = None,
                 det_cat_ids: Optional[Union[int, Tuple]] = None,
                 show_progress: bool = False) -> None:

        init_default_scope(scope)
        super().__init__(
            model=model,
            weights=weights,
            device=device,
            scope=scope,
            show_progress=show_progress)
        self.model = revert_sync_batchnorm(self.model)

        # assign dataset metainfo to self.visualizer
        self.visualizer.set_dataset_meta(self.model.dataset_meta)

        # initialize detector for top-down models
        if self.cfg.data_mode == 'topdown':
            self._init_detector(
                det_model=det_model,
                det_weights=det_weights,
                det_cat_ids=det_cat_ids,
                device=device,
            )

        self._video_input = False

    def update_model_visualizer_settings(self,
                                         draw_heatmap: bool = False,
                                         skeleton_style: str = 'mmpose',
                                         **kwargs) -> None:
        """Update the settings of models and visualizer according to inference
        arguments.

        Args:
            draw_heatmap (bool, optional): Flag to visualize predicted
                heatmaps. If not provided, it defaults to False.
            skeleton_style (str, optional): Skeleton style selection. Valid
                options are 'mmpose' and 'openpose'. Defaults to 'mmpose'.
        """
        self.model.test_cfg['output_heatmaps'] = draw_heatmap

        if skeleton_style not in ['mmpose', 'openpose']:
            raise ValueError('`skeleton_style` must be either \'mmpose\' '
                             'or \'openpose\'')

        if skeleton_style == 'openpose':
            self.visualizer.set_dataset_meta(self.model.dataset_meta,
                                             skeleton_style)

    def preprocess_single(self,
                          input: InputType,
                          index: int,
                          bbox_thr: float = 0.3,
                          nms_thr: float = 0.3,
                          bboxes: Union[List[List], List[np.ndarray],
                                        np.ndarray] = []):
        """Process a single input into a model-feedable format.

        Args:
            input (InputType): Input given by user.
            index (int): index of the input
            bbox_thr (float): threshold for bounding box detection.
                Defaults to 0.3.
            nms_thr (float): IoU threshold for bounding box NMS.
                Defaults to 0.3.

        Yields:
            Any: Data processed by the ``pipeline`` and ``collate_fn``.
        """

        if isinstance(input, str):
            data_info = dict(img_path=input)
        else:
            data_info = dict(
                img=input, img_path=f'{index}.jpg'.rjust(10, '0'))
        data_info.update(self.model.dataset_meta)

        if self.cfg.data_mode == 'topdown':
            bboxes = []
            if self.detector is not None:
                try:
                    det_results = self.detector(
                        input, return_datasamples=True)['predictions']
                except ValueError:
                    print_log(
                        'Support for mmpose and mmdet versions up to 3.1.0 '
                        'will be discontinued in upcoming releases. To '
                        'ensure ongoing compatibility, please upgrade to '
                        'mmdet version 3.2.0 or later.',
                        logger='current',
                        level=logging.WARNING)
                    det_results = self.detector(
                        input, return_datasample=True)['predictions']
                pred_instance = det_results[0].pred_instances.cpu().numpy()
                bboxes = np.concatenate(
                    (pred_instance.bboxes, pred_instance.scores[:, None]),
                    axis=1)

                label_mask = np.zeros(len(bboxes), dtype=np.uint8)
                for cat_id in self.det_cat_ids:
                    label_mask = np.logical_or(label_mask,
                                               pred_instance.labels == cat_id)

                bboxes = bboxes[np.logical_and(
                    label_mask, pred_instance.scores > bbox_thr)]
                bboxes = bboxes[nms(bboxes, nms_thr)]

            data_infos = []
            if len(bboxes) > 0:
                for bbox in bboxes:
                    inst = data_info.copy()
                    inst['bbox'] = bbox[None, :4]
                    inst['bbox_score'] = bbox[4:5]
                    data_infos.append(self.pipeline(inst))
            else:
                inst = data_info.copy()

                # get bbox from the image size
                if isinstance(input, str):
                    input = mmcv.imread(input)
                h, w = input.shape[:2]

                inst['bbox'] = np.array([[0, 0, w, h]], dtype=np.float32)
                inst['bbox_score'] = np.ones(1, dtype=np.float32)
                data_infos.append(self.pipeline(inst))

        else:  # bottom-up
            data_infos = [self.pipeline(data_info)]

        return data_infos

    @torch.no_grad()
    def forward(self,
                inputs: Union[dict, tuple],
                merge_results: bool = True,
                bbox_thr: float = -1,
                pose_based_nms: bool = False):
        """Performs a forward pass through the model.

        Args:
            inputs (Union[dict, tuple]): The input data to be processed. Can
                be either a dictionary or a tuple.
            merge_results (bool, optional): Whether to merge data samples,
                default to True. This is only applicable when the data_mode
                is 'topdown'.
            bbox_thr (float, optional): A threshold for the bounding box
|
| 227 |
+
scores. Bounding boxes with scores greater than this value
|
| 228 |
+
will be retained. Default value is -1 which retains all
|
| 229 |
+
bounding boxes.
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
A list of data samples with prediction instances.
|
| 233 |
+
"""
|
| 234 |
+
data_samples = self.model.test_step(inputs)
|
| 235 |
+
if self.cfg.data_mode == 'topdown' and merge_results:
|
| 236 |
+
data_samples = [merge_data_samples(data_samples)]
|
| 237 |
+
|
| 238 |
+
if bbox_thr > 0:
|
| 239 |
+
for ds in data_samples:
|
| 240 |
+
if 'bbox_scores' in ds.pred_instances:
|
| 241 |
+
ds.pred_instances = ds.pred_instances[
|
| 242 |
+
ds.pred_instances.bbox_scores > bbox_thr]
|
| 243 |
+
|
| 244 |
+
if pose_based_nms:
|
| 245 |
+
for ds in data_samples:
|
| 246 |
+
if len(ds.pred_instances) == 0:
|
| 247 |
+
continue
|
| 248 |
+
|
| 249 |
+
kpts = ds.pred_instances.keypoints
|
| 250 |
+
scores = ds.pred_instances.bbox_scores
|
| 251 |
+
num_keypoints = kpts.shape[-2]
|
| 252 |
+
|
| 253 |
+
kept_indices = nearby_joints_nms(
|
| 254 |
+
[
|
| 255 |
+
dict(keypoints=kpts[i], score=scores[i])
|
| 256 |
+
for i in range(len(kpts))
|
| 257 |
+
],
|
| 258 |
+
num_nearby_joints_thr=num_keypoints // 3,
|
| 259 |
+
)
|
| 260 |
+
ds.pred_instances = ds.pred_instances[kept_indices]
|
| 261 |
+
|
| 262 |
+
return data_samples
|
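A minimal usage sketch for the inferencer above (not part of this commit;
the image path is a placeholder and the 'human' alias is resolved through
the mmpose metafile):

    from mmpose.apis.inferencers import Pose2DInferencer

    # build a top-down 2D inferencer; the detector is initialized from
    # default_det_models when the data_mode is 'topdown'
    inferencer = Pose2DInferencer(model='human')

    # __call__ is a generator that yields one result dict per input
    result = next(inferencer('path/to/image.jpg', bbox_thr=0.3, nms_thr=0.3))
    instances = result['predictions'][0]   # list of per-person dicts
    keypoints = instances[0]['keypoints']  # (K, 2) coordinates
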
mmpose/apis/inferencers/pose3d_inferencer.py
ADDED
@@ -0,0 +1,457 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from collections import defaultdict
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
import torch
from mmengine.config import Config, ConfigDict
from mmengine.infer.infer import ModelType
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData

from mmpose.apis import (_track_by_iou, _track_by_oks, collate_pose_sequence,
                         convert_keypoint_definition, extract_pose_sequence)
from mmpose.registry import INFERENCERS
from mmpose.structures import PoseDataSample, merge_data_samples
from .base_mmpose_inferencer import BaseMMPoseInferencer
from .pose2d_inferencer import Pose2DInferencer

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ConfigType = Union[Config, ConfigDict]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]


@INFERENCERS.register_module(name='pose-estimation-3d')
@INFERENCERS.register_module()
class Pose3DInferencer(BaseMMPoseInferencer):
    """The inferencer for 3D pose estimation.

    Args:
        model (str, optional): Pretrained 3D pose estimation algorithm.
            It's the path to the config file or the model name defined in
            metafile. For example, it could be:

            - model alias, e.g. ``'body'``,
            - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
            - config path

            Defaults to ``None``.
        weights (str, optional): Path to the checkpoint. If it is not
            specified and "model" is a model name of metafile, the weights
            will be loaded from metafile. Defaults to None.
        pose2d_model (str, optional): Config path or alias of the 2D pose
            estimation model. Defaults to None, in which case the 'human'
            alias is used.
        pose2d_weights (str, optional): Path to the checkpoint of the 2D
            pose estimation model. Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to "mmpose".
        det_model (str, optional): Config path or alias of detection model.
            Defaults to None.
        det_weights (str, optional): Path to the checkpoints of detection
            model. Defaults to None.
        det_cat_ids (int or list[int], optional): Category id for
            detection model. Defaults to None.
        show_progress (bool, optional): Whether to display a progress bar
            during inference. Defaults to False.
    """

    preprocess_kwargs: set = {
        'bbox_thr', 'nms_thr', 'bboxes', 'use_oks_tracking', 'tracking_thr',
        'disable_norm_pose_2d'
    }
    forward_kwargs: set = {'disable_rebase_keypoint'}
    visualize_kwargs: set = {
        'return_vis',
        'show',
        'wait_time',
        'draw_bbox',
        'radius',
        'thickness',
        'num_instances',
        'kpt_thr',
        'vis_out_dir',
    }
    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

    def __init__(self,
                 model: Union[ModelType, str],
                 weights: Optional[str] = None,
                 pose2d_model: Optional[Union[ModelType, str]] = None,
                 pose2d_weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = 'mmpose',
                 det_model: Optional[Union[ModelType, str]] = None,
                 det_weights: Optional[str] = None,
                 det_cat_ids: Optional[Union[int, Tuple]] = None,
                 show_progress: bool = False) -> None:

        init_default_scope(scope)
        super().__init__(
            model=model,
            weights=weights,
            device=device,
            scope=scope,
            show_progress=show_progress)
        self.model = revert_sync_batchnorm(self.model)

        # assign dataset metainfo to self.visualizer
        self.visualizer.set_dataset_meta(self.model.dataset_meta)

        # initialize 2d pose estimator
        self.pose2d_model = Pose2DInferencer(
            pose2d_model if pose2d_model else 'human', pose2d_weights, device,
            scope, det_model, det_weights, det_cat_ids)

        # helper functions
        self._keypoint_converter = partial(
            convert_keypoint_definition,
            pose_det_dataset=self.pose2d_model.model.
            dataset_meta['dataset_name'],
            pose_lift_dataset=self.model.dataset_meta['dataset_name'],
        )

        self._pose_seq_extractor = partial(
            extract_pose_sequence,
            causal=self.cfg.test_dataloader.dataset.get('causal', False),
            seq_len=self.cfg.test_dataloader.dataset.get('seq_len', 1),
            step=self.cfg.test_dataloader.dataset.get('seq_step', 1))

        self._video_input = False
        self._buffer = defaultdict(list)

    def preprocess_single(self,
                          input: InputType,
                          index: int,
                          bbox_thr: float = 0.3,
                          nms_thr: float = 0.3,
                          bboxes: Union[List[List], List[np.ndarray],
                                        np.ndarray] = [],
                          use_oks_tracking: bool = False,
                          tracking_thr: float = 0.3,
                          disable_norm_pose_2d: bool = False):
        """Process a single input into a model-feedable format.

        Args:
            input (InputType): The input provided by the user.
            index (int): The index of the input.
            bbox_thr (float, optional): The threshold for bounding box
                detection. Defaults to 0.3.
            nms_thr (float, optional): The Intersection over Union (IoU)
                threshold for bounding box Non-Maximum Suppression (NMS).
                Defaults to 0.3.
            bboxes (Union[List[List], List[np.ndarray], np.ndarray]):
                The bounding boxes to use. Defaults to [].
            use_oks_tracking (bool, optional): A flag that indicates
                whether OKS-based tracking should be used. Defaults to False.
            tracking_thr (float, optional): The threshold for tracking.
                Defaults to 0.3.
            disable_norm_pose_2d (bool, optional): A flag that indicates
                whether 2D pose normalization should be used.
                Defaults to False.

        Returns:
            list: The data processed by the pipeline and ``collate_fn``.

        This method first calculates 2D keypoints using the provided
        pose2d_model. The method also performs instance matching, which
        can use either OKS-based tracking or IoU-based tracking.
        """

        # calculate 2d keypoints
        results_pose2d = next(
            self.pose2d_model(
                input,
                bbox_thr=bbox_thr,
                nms_thr=nms_thr,
                bboxes=bboxes,
                merge_results=False,
                return_datasamples=True))['predictions']

        for ds in results_pose2d:
            ds.pred_instances.set_field(
                (ds.pred_instances.bboxes[..., 2:] -
                 ds.pred_instances.bboxes[..., :2]).prod(-1), 'areas')

        if not self._video_input:
            height, width = results_pose2d[0].metainfo['ori_shape']

            # Clear the buffer if inputs are individual images to prevent
            # carryover effects from previous images
            self._buffer.clear()

        else:
            height = self.video_info['height']
            width = self.video_info['width']
        img_path = results_pose2d[0].metainfo['img_path']

        # instance matching
        if use_oks_tracking:
            _track = partial(_track_by_oks)
        else:
            _track = _track_by_iou

        for result in results_pose2d:
            track_id, self._buffer['results_pose2d_last'], _ = _track(
                result, self._buffer['results_pose2d_last'], tracking_thr)
            if track_id == -1:
                pred_instances = result.pred_instances.cpu().numpy()
                keypoints = pred_instances.keypoints
                if np.count_nonzero(keypoints[:, :, 1]) >= 3:
                    next_id = self._buffer.get('next_id', 0)
                    result.set_field(next_id, 'track_id')
                    self._buffer['next_id'] = next_id + 1
                else:
                    # If the number of keypoints detected is small,
                    # delete that person instance.
                    result.pred_instances.keypoints[..., 1] = -10
                    result.pred_instances.bboxes *= 0
                    result.set_field(-1, 'track_id')
            else:
                result.set_field(track_id, 'track_id')
        self._buffer['pose2d_results'] = merge_data_samples(results_pose2d)

        # convert keypoints
        results_pose2d_converted = [ds.cpu().numpy() for ds in results_pose2d]
        for ds in results_pose2d_converted:
            ds.pred_instances.keypoints = self._keypoint_converter(
                ds.pred_instances.keypoints)
        self._buffer['pose_est_results_list'].append(results_pose2d_converted)

        # extract and pad input pose2d sequence
        pose_results_2d = self._pose_seq_extractor(
            self._buffer['pose_est_results_list'],
            frame_idx=index if self._video_input else 0)
        causal = self.cfg.test_dataloader.dataset.get('causal', False)
        target_idx = -1 if causal else len(pose_results_2d) // 2

        stats_info = self.model.dataset_meta.get('stats_info', {})
        bbox_center = stats_info.get('bbox_center', None)
        bbox_scale = stats_info.get('bbox_scale', None)

        pose_results_2d_copy = []
        for pose_res in pose_results_2d:
            pose_res_copy = []
            for data_sample in pose_res:

                data_sample_copy = PoseDataSample()
                data_sample_copy.gt_instances = \
                    data_sample.gt_instances.clone()
                data_sample_copy.pred_instances = \
                    data_sample.pred_instances.clone()
                data_sample_copy.track_id = data_sample.track_id

                kpts = data_sample.pred_instances.keypoints
                bboxes = data_sample.pred_instances.bboxes
                keypoints = []
                for k in range(len(kpts)):
                    kpt = kpts[k]
                    if not disable_norm_pose_2d:
                        bbox = bboxes[k]
                        center = np.array([[(bbox[0] + bbox[2]) / 2,
                                            (bbox[1] + bbox[3]) / 2]])
                        scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
                        keypoints.append((kpt[:, :2] - center) / scale *
                                         bbox_scale + bbox_center)
                    else:
                        keypoints.append(kpt[:, :2])
                data_sample_copy.pred_instances.set_field(
                    np.array(keypoints), 'keypoints')
                pose_res_copy.append(data_sample_copy)

            pose_results_2d_copy.append(pose_res_copy)
        pose_sequences_2d = collate_pose_sequence(pose_results_2d_copy, True,
                                                  target_idx)
        if not pose_sequences_2d:
            return []

        data_list = []
        for i, pose_seq in enumerate(pose_sequences_2d):
            data_info = dict()

            keypoints_2d = pose_seq.pred_instances.keypoints
            keypoints_2d = np.squeeze(
                keypoints_2d,
                axis=0) if keypoints_2d.ndim == 4 else keypoints_2d

            T, K, C = keypoints_2d.shape

            data_info['keypoints'] = keypoints_2d
            data_info['keypoints_visible'] = np.ones((T, K),
                                                     dtype=np.float32)
            data_info['lifting_target'] = np.zeros((1, K, 3), dtype=np.float32)
            data_info['factor'] = np.zeros((T, ), dtype=np.float32)
            data_info['lifting_target_visible'] = np.ones((1, K, 1),
                                                          dtype=np.float32)
            data_info['camera_param'] = dict(w=width, h=height)

            data_info.update(self.model.dataset_meta)
            data_info = self.pipeline(data_info)
            data_info['data_samples'].set_field(
                img_path, 'img_path', field_type='metainfo')
            data_list.append(data_info)

        return data_list

    @torch.no_grad()
    def forward(self,
                inputs: Union[dict, tuple],
                disable_rebase_keypoint: bool = False):
        """Perform forward pass through the model and process the results.

        Args:
            inputs (Union[dict, tuple]): The inputs for the model.
            disable_rebase_keypoint (bool, optional): Flag to disable rebasing
                the height of the keypoints. Defaults to False.

        Returns:
            list: A list of data samples, each containing the model's output
                results.
        """
        pose_lift_results = self.model.test_step(inputs)

        # Post-processing of pose estimation results
        pose_est_results_converted = self._buffer['pose_est_results_list'][-1]
        for idx, pose_lift_res in enumerate(pose_lift_results):
            # Update track_id from the pose estimation results
            pose_lift_res.track_id = pose_est_results_converted[idx].get(
                'track_id', 1e4)

            # align the shape of output keypoints coordinates and scores
            keypoints = pose_lift_res.pred_instances.keypoints
            keypoint_scores = pose_lift_res.pred_instances.keypoint_scores
            if keypoint_scores.ndim == 3:
                pose_lift_results[idx].pred_instances.keypoint_scores = \
                    np.squeeze(keypoint_scores, axis=1)
            if keypoints.ndim == 4:
                keypoints = np.squeeze(keypoints, axis=1)

            # Invert x and z values of the keypoints
            keypoints = keypoints[..., [0, 2, 1]]
            keypoints[..., 0] = -keypoints[..., 0]
            keypoints[..., 2] = -keypoints[..., 2]

            # If rebase_keypoint_height is True, adjust z-axis values
            if not disable_rebase_keypoint:
                keypoints[..., 2] -= np.min(
                    keypoints[..., 2], axis=-1, keepdims=True)

            pose_lift_results[idx].pred_instances.keypoints = keypoints

        pose_lift_results = sorted(
            pose_lift_results, key=lambda x: x.get('track_id', 1e4))

        data_samples = [merge_data_samples(pose_lift_results)]
        return data_samples

    def visualize(self,
                  inputs: list,
                  preds: List[PoseDataSample],
                  return_vis: bool = False,
                  show: bool = False,
                  draw_bbox: bool = False,
                  wait_time: float = 0,
                  radius: int = 3,
                  thickness: int = 1,
                  kpt_thr: float = 0.3,
                  num_instances: int = 1,
                  vis_out_dir: str = '',
                  window_name: str = '',
                  window_close_event_handler: Optional[Callable] = None
                  ) -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
            preds (Any): Predictions of the model.
            return_vis (bool): Whether to return images with predicted
                results.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (ms). Defaults to 0.
            draw_bbox (bool): Whether to draw the bounding boxes.
                Defaults to False.
            radius (int): Keypoint radius for visualization. Defaults to 3.
            thickness (int): Link thickness for visualization. Defaults to 1.
            kpt_thr (float): The threshold to visualize the keypoints.
                Defaults to 0.3.
            vis_out_dir (str, optional): Directory to save visualization
                results w/o predictions. If left as empty, no file will
                be saved. Defaults to ''.
            window_name (str, optional): Title of display window.
            window_close_event_handler (callable, optional):

        Returns:
            List[np.ndarray]: Visualization results.
        """
        if (not return_vis) and (not show) and (not vis_out_dir):
            return

        if getattr(self, 'visualizer', None) is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        self.visualizer.radius = radius
        self.visualizer.line_width = thickness
        det_kpt_color = self.pose2d_model.visualizer.kpt_color
        det_dataset_skeleton = self.pose2d_model.visualizer.skeleton
        det_dataset_link_color = self.pose2d_model.visualizer.link_color
        self.visualizer.det_kpt_color = det_kpt_color
        self.visualizer.det_dataset_skeleton = det_dataset_skeleton
        self.visualizer.det_dataset_link_color = det_dataset_link_color

        results = []

        for single_input, pred in zip(inputs, preds):
            if isinstance(single_input, str):
                img = mmcv.imread(single_input, channel_order='rgb')
            elif isinstance(single_input, np.ndarray):
                img = mmcv.bgr2rgb(single_input)
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')

            # since visualization and inference utilize the same process,
            # the wait time is reduced when a video input is utilized,
            # thereby eliminating the issue of inference getting stuck.
            wait_time = 1e-5 if self._video_input else wait_time

            if num_instances < 0:
                num_instances = len(pred.pred_instances)

            visualization = self.visualizer.add_datasample(
                window_name,
                img,
                data_sample=pred,
                det_data_sample=self._buffer['pose2d_results'],
                draw_gt=False,
                draw_bbox=draw_bbox,
                show=show,
                wait_time=wait_time,
                dataset_2d=self.pose2d_model.model.
                dataset_meta['dataset_name'],
                dataset_3d=self.model.dataset_meta['dataset_name'],
                kpt_thr=kpt_thr,
                num_instances=num_instances)
            results.append(visualization)

            if vis_out_dir:
                img_name = os.path.basename(pred.metainfo['img_path']) \
                    if 'img_path' in pred.metainfo else None
                self.save_visualization(
                    visualization,
                    vis_out_dir,
                    img_name=img_name,
                )

        if return_vis:
            return results
        else:
            return []

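A usage sketch for Pose3DInferencer (illustrative, not from this commit;
'human3d' is the mmpose alias for a pose-lifting model and the video path
is a placeholder):

    from mmpose.apis.inferencers import Pose3DInferencer

    # 2D poses come from the internal Pose2DInferencer; tracking links
    # instances across frames before the sequence is lifted to 3D
    inferencer = Pose3DInferencer(model='human3d')
    for result in inferencer('path/to/video.mp4', use_oks_tracking=True):
        pred = result['predictions'][0]  # merged 3D result per frame
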
mmpose/apis/inferencers/utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .default_det_models import default_det_models
from .get_model_alias import get_model_aliases

__all__ = ['default_det_models', 'get_model_aliases']

mmpose/apis/inferencers/utils/default_det_models.py
ADDED
@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

from mmengine.config.utils import MODULE2PACKAGE
from mmengine.utils import get_installed_path

mmpose_path = get_installed_path(MODULE2PACKAGE['mmpose'])

default_det_models = dict(
    human=dict(
        model=osp.join(
            mmpose_path, '.mim', 'demo/mmdetection_cfg/'
            'rtmdet_m_640-8xb32_coco-person.py'),
        weights='https://download.openmmlab.com/mmpose/v1/projects/'
        'rtmposev1/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth',
        cat_ids=(0, )),
    face=dict(
        model=osp.join(mmpose_path, '.mim',
                       'demo/mmdetection_cfg/yolox-s_8xb8-300e_coco-face.py'),
        weights='https://download.openmmlab.com/mmpose/mmdet_pretrained/'
        'yolo-x_8xb8-300e_coco-face_13274d7c.pth',
        cat_ids=(0, )),
    hand=dict(
        model=osp.join(mmpose_path, '.mim', 'demo/mmdetection_cfg/'
                       'rtmdet_nano_320-8xb32_hand.py'),
        weights='https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/'
        'rtmdet_nano_8xb32-300e_hand-267f9c8f.pth',
        cat_ids=(0, )),
    animal=dict(
        model='rtmdet-m',
        weights=None,
        cat_ids=(15, 16, 17, 18, 19, 20, 21, 22, 23)),
)

default_det_models['body'] = default_det_models['human']
default_det_models['wholebody'] = default_det_models['human']

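How this table is typically consumed (a sketch based on the inferencer
code above, not code from the commit): the top-down inferencer looks up
the default detector for an object category when no explicit detection
model is given.

    det_info = default_det_models['body']  # alias of the 'human' entry
    det_model, det_weights = det_info['model'], det_info['weights']
    det_cat_ids = det_info['cat_ids']      # detection category ids to keep
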
mmpose/apis/inferencers/utils/get_model_alias.py
ADDED
@@ -0,0 +1,37 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

from mmengine.infer import BaseInferencer


def get_model_aliases(scope: str = 'mmpose') -> Dict[str, str]:
    """Retrieve model aliases and their corresponding configuration names.

    Args:
        scope (str, optional): The scope for the model aliases. Defaults
            to 'mmpose'.

    Returns:
        Dict[str, str]: A dictionary containing model aliases as keys and
        their corresponding configuration names as values.
    """

    # Get a list of model configurations from the metafile
    repo_or_mim_dir = BaseInferencer._get_repo_or_mim_dir(scope)
    model_cfgs = BaseInferencer._get_models_from_metafile(repo_or_mim_dir)

    model_alias_dict = dict()
    for model_cfg in model_cfgs:
        if 'Alias' in model_cfg:
            if isinstance(model_cfg['Alias'], str):
                model_alias_dict[model_cfg['Alias']] = model_cfg['Name']
            elif isinstance(model_cfg['Alias'], list):
                for alias in model_cfg['Alias']:
                    model_alias_dict[alias] = model_cfg['Name']
            else:
                raise ValueError(
                    'Encountered an unexpected alias type. Please raise an '
                    'issue at https://github.com/open-mmlab/mmpose/issues '
                    'to let us know.')

    return model_alias_dict

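Example use of the helper above (the printed config name depends on the
installed mmpose metafiles, so it is only indicative):

    aliases = get_model_aliases('mmpose')
    print(aliases['human'])  # e.g. an RTMPose COCO config name
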
mmpose/apis/visualization.py
ADDED
@@ -0,0 +1,132 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Union

import mmcv
import numpy as np
from mmengine.structures import InstanceData

from mmpose.datasets.datasets.utils import parse_pose_metainfo
from mmpose.structures import PoseDataSample
from mmpose.visualization import PoseLocalVisualizer

# from posevis import pose_visualization

# def visualize(
#     img: Union[np.ndarray, str],
#     keypoints: np.ndarray,
#     keypoint_score: np.ndarray = None,
#     metainfo: Union[str, dict] = None,
#     visualizer: PoseLocalVisualizer = None,
#     show_kpt_idx: bool = False,
#     skeleton_style: str = 'mmpose',
#     show: bool = False,
#     kpt_thr: float = 0.3,
# ):
#     """Visualize 2d keypoints on an image.

#     Args:
#         img (str | np.ndarray): The image to be displayed.
#         keypoints (np.ndarray): The keypoint to be displayed.
#         keypoint_score (np.ndarray): The score of each keypoint.
#         metainfo (str | dict): The metainfo of dataset.
#         visualizer (PoseLocalVisualizer): The visualizer.
#         show_kpt_idx (bool): Whether to show the index of keypoints.
#         skeleton_style (str): Skeleton style. Options are 'mmpose' and
#             'openpose'.
#         show (bool): Whether to show the image.
#         wait_time (int): Value of waitKey param.
#         kpt_thr (float): Keypoint threshold.
#     """
#     kpts = keypoints.reshape(-1, 2)
#     kpts = np.concatenate([kpts, keypoint_score[:, None]], axis=1)
#     kpts[kpts[:, 2] < kpt_thr, :] = 0
#     pose_results = [{
#         'keypoints': kpts,
#     }]

#     img = pose_visualization(
#         img,
#         pose_results,
#         format="COCO",
#         greyness=1.0,
#         show_markers=True,
#         show_bones=True,
#         line_type="solid",
#         width_multiplier=1.0,
#         bbox_width_multiplier=1.0,
#         show_bbox=False,
#         differ_individuals=False,
#     )
#     return img


def visualize(
    img: Union[np.ndarray, str],
    keypoints: np.ndarray,
    keypoint_score: np.ndarray = None,
    metainfo: Union[str, dict] = None,
    visualizer: PoseLocalVisualizer = None,
    show_kpt_idx: bool = False,
    skeleton_style: str = 'mmpose',
    show: bool = False,
    kpt_thr: float = 0.3,
):
    """Visualize 2d keypoints on an image.

    Args:
        img (str | np.ndarray): The image to be displayed.
        keypoints (np.ndarray): The keypoints to be displayed.
        keypoint_score (np.ndarray): The score of each keypoint.
        metainfo (str | dict): The metainfo of dataset.
        visualizer (PoseLocalVisualizer): The visualizer.
        show_kpt_idx (bool): Whether to show the index of keypoints.
        skeleton_style (str): Skeleton style. Options are 'mmpose' and
            'openpose'.
        show (bool): Whether to show the image.
        kpt_thr (float): Keypoint threshold.
    """
    assert skeleton_style in [
        'mmpose', 'openpose'
    ], "Only support skeleton style in ['mmpose', 'openpose']"

    if visualizer is None:
        visualizer = PoseLocalVisualizer()
    else:
        visualizer = deepcopy(visualizer)

    if isinstance(metainfo, str):
        metainfo = parse_pose_metainfo(dict(from_file=metainfo))
    elif isinstance(metainfo, dict):
        metainfo = parse_pose_metainfo(metainfo)

    if metainfo is not None:
        visualizer.set_dataset_meta(metainfo, skeleton_style=skeleton_style)

    if isinstance(img, str):
        img = mmcv.imread(img, channel_order='rgb')
    elif isinstance(img, np.ndarray):
        img = mmcv.bgr2rgb(img)

    if keypoint_score is None:
        keypoint_score = np.ones(keypoints.shape[0])

    tmp_instances = InstanceData()
    tmp_instances.keypoints = keypoints
    tmp_instances.keypoint_score = keypoint_score

    tmp_datasample = PoseDataSample()
    tmp_datasample.pred_instances = tmp_instances

    visualizer.add_datasample(
        'visualization',
        img,
        tmp_datasample,
        show_kpt_idx=show_kpt_idx,
        skeleton_style=skeleton_style,
        show=show,
        wait_time=0,
        kpt_thr=kpt_thr)

    return visualizer.get_image()

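A short usage sketch for `visualize` (illustrative values; the metainfo
path is the COCO dataset config shipped with mmpose, the image path is a
placeholder):

    import numpy as np
    from mmpose.apis.visualization import visualize

    keypoints = np.random.rand(17, 2) * 256   # 17 COCO keypoints
    scores = np.ones(17)
    canvas = visualize(
        'path/to/image.jpg',
        keypoints,
        keypoint_score=scores,
        metainfo='configs/_base_/datasets/coco.py',
        kpt_thr=0.3)                           # returns an RGB array
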
mmpose/codecs/__init__.py
ADDED
@@ -0,0 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .annotation_processors import YOLOXPoseAnnotationProcessor
from .associative_embedding import AssociativeEmbedding
from .decoupled_heatmap import DecoupledHeatmap
from .edpose_label import EDPoseLabel
from .hand_3d_heatmap import Hand3DHeatmap
from .image_pose_lifting import ImagePoseLifting
from .integral_regression_label import IntegralRegressionLabel
from .megvii_heatmap import MegviiHeatmap
from .motionbert_label import MotionBERTLabel
from .msra_heatmap import MSRAHeatmap
from .regression_label import RegressionLabel
from .simcc_label import SimCCLabel
from .spr import SPR
from .udp_heatmap import UDPHeatmap
from .video_pose_lifting import VideoPoseLifting
from .onehot_heatmap import OneHotHeatmap

__all__ = [
    'MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap', 'RegressionLabel',
    'SimCCLabel', 'IntegralRegressionLabel', 'AssociativeEmbedding', 'SPR',
    'DecoupledHeatmap', 'VideoPoseLifting', 'ImagePoseLifting',
    'MotionBERTLabel', 'YOLOXPoseAnnotationProcessor', 'EDPoseLabel',
    'Hand3DHeatmap', 'OneHotHeatmap'
]

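Codecs registered here are normally built from config dicts through the
KEYPOINT_CODECS registry; a minimal sketch (the parameter values are
illustrative):

    from mmpose.registry import KEYPOINT_CODECS

    codec = KEYPOINT_CODECS.build(
        dict(
            type='MSRAHeatmap',
            input_size=(192, 256),
            heatmap_size=(48, 64),
            sigma=2))
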
mmpose/codecs/annotation_processors.py
ADDED
@@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple

import numpy as np

from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec

INF = 1e6
NEG_INF = -1e6


class BaseAnnotationProcessor(BaseKeypointCodec):
    """Base class for annotation processors."""

    def decode(self, *args, **kwargs):
        pass


@KEYPOINT_CODECS.register_module()
class YOLOXPoseAnnotationProcessor(BaseAnnotationProcessor):
    """Convert dataset annotations to the input format of YOLOX-Pose.

    This processor expands bounding boxes and converts category IDs to labels.

    Args:
        extend_bbox (bool, optional): Whether to expand the bounding box
            to include all keypoints. Defaults to False.
        input_size (tuple, optional): The size of the input image for the
            model, formatted as (h, w). This argument is necessary for the
            codec in deployment but is not actually used.
    """

    auxiliary_encode_keys = {'category_id', 'bbox'}
    label_mapping_table = dict(
        bbox='bboxes',
        bbox_labels='labels',
        keypoints='keypoints',
        keypoints_visible='keypoints_visible',
        area='areas',
    )
    instance_mapping_table = dict(
        bbox='bboxes',
        bbox_score='bbox_scores',
        keypoints='keypoints',
        keypoints_visible='keypoints_visible',
        # remove 'bbox_scales' in default instance_mapping_table to avoid
        # length mismatch during training with multiple datasets
    )

    def __init__(self,
                 extend_bbox: bool = False,
                 input_size: Optional[Tuple] = None):
        super().__init__()
        self.extend_bbox = extend_bbox

    def encode(self,
               keypoints: Optional[np.ndarray] = None,
               keypoints_visible: Optional[np.ndarray] = None,
               bbox: Optional[np.ndarray] = None,
               category_id: Optional[List[int]] = None
               ) -> Dict[str, np.ndarray]:
        """Encode keypoints, bounding boxes, and category IDs.

        Args:
            keypoints (np.ndarray, optional): Keypoints array. Defaults
                to None.
            keypoints_visible (np.ndarray, optional): Visibility array for
                keypoints. Defaults to None.
            bbox (np.ndarray, optional): Bounding box array. Defaults to None.
            category_id (List[int], optional): List of category IDs. Defaults
                to None.

        Returns:
            Dict[str, np.ndarray]: Encoded annotations.
        """
        results = {}

        if self.extend_bbox and bbox is not None:
            # Handle keypoints visibility
            if keypoints_visible.ndim == 3:
                keypoints_visible = keypoints_visible[..., 0]

            # Expand bounding box to include keypoints
            kpts_min = keypoints.copy()
            kpts_min[keypoints_visible == 0] = INF
            bbox[..., :2] = np.minimum(bbox[..., :2], kpts_min.min(axis=1))

            kpts_max = keypoints.copy()
            kpts_max[keypoints_visible == 0] = NEG_INF
            bbox[..., 2:] = np.maximum(bbox[..., 2:], kpts_max.max(axis=1))

            results['bbox'] = bbox

        if category_id is not None:
            # Convert category IDs to labels
            bbox_labels = np.array(category_id).astype(np.int8) - 1
            results['bbox_labels'] = bbox_labels

        return results

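A worked example of the bbox-extension logic above (hypothetical numbers;
the box grows just enough to cover every visible keypoint):

    import numpy as np

    proc = YOLOXPoseAnnotationProcessor(extend_bbox=True)
    kpts = np.array([[[5., 5.], [60., 40.]]])   # 1 instance, 2 keypoints
    vis = np.array([[1., 1.]])
    bbox = np.array([[10., 10., 50., 50.]])     # x1, y1, x2, y2
    out = proc.encode(kpts, vis, bbox, category_id=[1])
    # out['bbox'] -> [[5., 5., 60., 50.]], out['bbox_labels'] -> [0]
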
mmpose/codecs/associative_embedding.py
ADDED
@@ -0,0 +1,522 @@
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import product
from typing import Any, List, Optional, Tuple

import numpy as np
import torch
from munkres import Munkres
from torch import Tensor

from mmpose.registry import KEYPOINT_CODECS
from mmpose.utils.tensor_utils import to_numpy
from .base import BaseKeypointCodec
from .utils import (batch_heatmap_nms, generate_gaussian_heatmaps,
                    generate_udp_gaussian_heatmaps, refine_keypoints,
                    refine_keypoints_dark_udp)


def _py_max_match(scores):
    """Apply the Munkres algorithm to get the best match.

    Args:
        scores (np.ndarray): cost matrix.

    Returns:
        np.ndarray: best match.
    """
    m = Munkres()
    tmp = m.compute(scores)
    tmp = np.array(tmp).astype(int)
    return tmp


def _group_keypoints_by_tags(vals: np.ndarray,
                             tags: np.ndarray,
                             locs: np.ndarray,
                             keypoint_order: List[int],
                             val_thr: float,
                             tag_thr: float = 1.0,
                             max_groups: Optional[int] = None) -> np.ndarray:
    """Group the keypoints by tags using the Munkres algorithm.

    Note:

        - keypoint number: K
        - candidate number: M
        - tag dimension: L
        - coordinate dimension: D
        - group number: G

    Args:
        vals (np.ndarray): The heatmap response values of keypoints in shape
            (K, M)
        tags (np.ndarray): The tags of the keypoint candidates in shape
            (K, M, L)
        locs (np.ndarray): The locations of the keypoint candidates in shape
            (K, M, D)
        keypoint_order (List[int]): The grouping order of the keypoints.
            The grouping usually starts from keypoints around the head and
            torso, and gradually moves out to the limbs
        val_thr (float): The threshold of the keypoint response value
        tag_thr (float): The maximum allowed tag distance when matching a
            keypoint to a group. A keypoint with larger tag distance to any
            of the existing groups will initialize a new group
        max_groups (int, optional): The maximum group number. ``None`` means
            no limitation. Defaults to ``None``

    Returns:
        np.ndarray: grouped keypoints in shape (G, K, D+1), where the last
        dimension is the concatenated keypoint coordinates and scores.
    """

    tag_k, loc_k, val_k = tags, locs, vals
    K, M, D = locs.shape
    assert vals.shape == tags.shape[:2] == (K, M)
    assert len(keypoint_order) == K

    default_ = np.zeros((K, 3 + tag_k.shape[2]), dtype=np.float32)

    joint_dict = {}
    tag_dict = {}
    for i in range(K):
        idx = keypoint_order[i]

        tags = tag_k[idx]
        joints = np.concatenate((loc_k[idx], val_k[idx, :, None], tags), 1)
        mask = joints[:, 2] > val_thr
        tags = tags[mask]  # shape: [M, L]
        joints = joints[mask]  # shape: [M, 3 + L], 3: x, y, val

        if joints.shape[0] == 0:
            continue

        if i == 0 or len(joint_dict) == 0:
            for tag, joint in zip(tags, joints):
                key = tag[0]
                joint_dict.setdefault(key, np.copy(default_))[idx] = joint
                tag_dict[key] = [tag]
        else:
            # shape: [M]
            grouped_keys = list(joint_dict.keys())
            # shape: [M, L]
            grouped_tags = [np.mean(tag_dict[i], axis=0) for i in grouped_keys]

            # shape: [M, M, L]
            diff = joints[:, None, 3:] - np.array(grouped_tags)[None, :, :]
            # shape: [M, M]
            diff_normed = np.linalg.norm(diff, ord=2, axis=2)
            diff_saved = np.copy(diff_normed)
            diff_normed = np.round(diff_normed) * 100 - joints[:, 2:3]

            num_added = diff.shape[0]
            num_grouped = diff.shape[1]

            if num_added > num_grouped:
                diff_normed = np.concatenate(
                    (diff_normed,
                     np.zeros((num_added, num_added - num_grouped),
                              dtype=np.float32) + 1e10),
                    axis=1)

            pairs = _py_max_match(diff_normed)
            for row, col in pairs:
                if (row < num_added and col < num_grouped
                        and diff_saved[row][col] < tag_thr):
                    key = grouped_keys[col]
                    joint_dict[key][idx] = joints[row]
                    tag_dict[key].append(tags[row])
                else:
                    key = tags[row][0]
                    joint_dict.setdefault(key, np.copy(default_))[idx] = \
                        joints[row]
                    tag_dict[key] = [tags[row]]

    joint_dict_keys = list(joint_dict.keys())[:max_groups]

    if joint_dict_keys:
        results = np.array([joint_dict[i]
                            for i in joint_dict_keys]).astype(np.float32)
        results = results[..., :D + 1]
    else:
        results = np.empty((0, K, D + 1), dtype=np.float32)
    return results

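# --- Editor's note (not part of the commit): a toy illustration of the
# Munkres matching used by _group_keypoints_by_tags. Munkres picks the
# assignment with minimum total tag distance between new candidates
# (rows) and existing groups (columns):
#
#     from munkres import Munkres
#     cost = [[0.1, 2.0],
#             [1.8, 0.2]]
#     Munkres().compute(cost)  # -> [(0, 0), (1, 1)]
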
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@KEYPOINT_CODECS.register_module()
|
| 146 |
+
class AssociativeEmbedding(BaseKeypointCodec):
|
| 147 |
+
"""Encode/decode keypoints with the method introduced in "Associative
|
| 148 |
+
Embedding". This is an asymmetric codec, where the keypoints are
|
| 149 |
+
represented as gaussian heatmaps and position indices during encoding, and
|
| 150 |
+
restored from predicted heatmaps and group tags.
|
| 151 |
+
|
| 152 |
+
See the paper `Associative Embedding: End-to-End Learning for Joint
|
| 153 |
+
Detection and Grouping`_ by Newell et al (2017) for details
|
| 154 |
+
|
| 155 |
+
Note:
|
| 156 |
+
|
| 157 |
+
- instance number: N
|
| 158 |
+
- keypoint number: K
|
| 159 |
+
- keypoint dimension: D
|
| 160 |
+
- embedding tag dimension: L
|
| 161 |
+
- image size: [w, h]
|
| 162 |
+
- heatmap size: [W, H]
|
| 163 |
+
|
| 164 |
+
Encoded:
|
| 165 |
+
|
| 166 |
+
- heatmaps (np.ndarray): The generated heatmap in shape (K, H, W)
|
| 167 |
+
where [W, H] is the `heatmap_size`
|
| 168 |
+
- keypoint_indices (np.ndarray): The keypoint position indices in shape
|
| 169 |
+
            (N, K, 2). Each keypoint's index is [i, v], where i is the
            position index in the heatmap (:math:`i=y*w+x`) and v is the
            visibility
        - keypoint_weights (np.ndarray): The target weights in shape (N, K)

    Args:
        input_size (tuple): Image size in [w, h]
        heatmap_size (tuple): Heatmap size in [W, H]
        sigma (float): The sigma value of the Gaussian heatmap
        use_udp (bool): Whether to use unbiased data processing. See
            `UDP (CVPR 2020)`_ for details. Defaults to ``False``
        decode_keypoint_order (List[int]): The grouping order of the
            keypoint indices. The grouping usually starts from the keypoints
            around the head and torso, and gradually moves out to the limbs
        decode_keypoint_thr (float): The threshold of keypoint response value
            in heatmaps. Defaults to 0.1
        decode_tag_thr (float): The maximum allowed tag distance when matching
            a keypoint to a group. A keypoint with a larger tag distance to
            every existing group will initialize a new group. Defaults to 1.0
        decode_nms_kernel (int): The kernel size of the NMS during decoding,
            which should be an odd integer. Defaults to 5
        decode_gaussian_kernel (int): The kernel size of the Gaussian blur
            during decoding, which should be an odd integer. It is only used
            when ``self.use_udp==True``. Defaults to 3
        decode_topk (int): The number of top-k candidates of each keypoint
            that will be retrieved from the heatmaps during decoding.
            Defaults to 30
        decode_max_instances (int, optional): The maximum number of instances
            to decode. ``None`` means no limit on the instance number.
            Defaults to ``None``

    .. _`Associative Embedding: End-to-End Learning for Joint Detection and
        Grouping`: https://arxiv.org/abs/1611.05424
    .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        heatmap_size: Tuple[int, int],
        sigma: Optional[float] = None,
        use_udp: bool = False,
        decode_keypoint_order: List[int] = [],
        decode_nms_kernel: int = 5,
        decode_gaussian_kernel: int = 3,
        decode_keypoint_thr: float = 0.1,
        decode_tag_thr: float = 1.0,
        decode_topk: int = 30,
        decode_center_shift=0.0,
        decode_max_instances: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.heatmap_size = heatmap_size
        self.use_udp = use_udp
        self.decode_nms_kernel = decode_nms_kernel
        self.decode_gaussian_kernel = decode_gaussian_kernel
        self.decode_keypoint_thr = decode_keypoint_thr
        self.decode_tag_thr = decode_tag_thr
        self.decode_topk = decode_topk
        self.decode_center_shift = decode_center_shift
        self.decode_max_instances = decode_max_instances
        self.decode_keypoint_order = decode_keypoint_order.copy()

        if self.use_udp:
            self.scale_factor = ((np.array(input_size) - 1) /
                                 (np.array(heatmap_size) - 1)).astype(
                                     np.float32)
        else:
            self.scale_factor = (np.array(input_size) /
                                 heatmap_size).astype(np.float32)

        if sigma is None:
            sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 64
        self.sigma = sigma
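
    # Note on the default sigma above: it scales with the heatmap
    # resolution, e.g. a 64x64 heatmap yields
    # sigma = (64 * 64) ** 0.5 / 64 = 1.0 and a 128x128 heatmap yields 2.0
    # (illustrative values, not from the diff).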

    def encode(
        self,
        keypoints: np.ndarray,
        keypoints_visible: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Encode keypoints into heatmaps and position indices. Note that the
        original keypoint coordinates should be in the input image space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K)

        Returns:
            dict:
            - heatmaps (np.ndarray): The generated heatmap in shape
                (K, H, W) where [W, H] is the `heatmap_size`
            - keypoint_indices (np.ndarray): The keypoint position indices
                in shape (N, K, 2). Each keypoint's index is [i, v], where i
                is the position index in the heatmap (:math:`i=y*w+x`) and v
                is the visibility
            - keypoint_weights (np.ndarray): The target weights in shape
                (N, K)
        """

        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        # keypoint coordinates in heatmap
        _keypoints = keypoints / self.scale_factor

        if self.use_udp:
            heatmaps, keypoint_weights = generate_udp_gaussian_heatmaps(
                heatmap_size=self.heatmap_size,
                keypoints=_keypoints,
                keypoints_visible=keypoints_visible,
                sigma=self.sigma)
        else:
            heatmaps, keypoint_weights = generate_gaussian_heatmaps(
                heatmap_size=self.heatmap_size,
                keypoints=_keypoints,
                keypoints_visible=keypoints_visible,
                sigma=self.sigma)

        keypoint_indices = self._encode_keypoint_indices(
            heatmap_size=self.heatmap_size,
            keypoints=_keypoints,
            keypoints_visible=keypoints_visible)

        encoded = dict(
            heatmaps=heatmaps,
            keypoint_indices=keypoint_indices,
            keypoint_weights=keypoint_weights)

        return encoded

    def _encode_keypoint_indices(self, heatmap_size: Tuple[int, int],
                                 keypoints: np.ndarray,
                                 keypoints_visible: np.ndarray) -> np.ndarray:
        w, h = heatmap_size
        N, K, _ = keypoints.shape
        keypoint_indices = np.zeros((N, K, 2), dtype=np.int64)

        for n, k in product(range(N), range(K)):
            x, y = (keypoints[n, k] + 0.5).astype(np.int64)
            index = y * w + x
            vis = (keypoints_visible[n, k] > 0.5 and 0 <= x < w and 0 <= y < h)
            keypoint_indices[n, k] = [index, vis]

        return keypoint_indices
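
    # Note: for a heatmap of width w=4, a visible keypoint at (x=2, y=1)
    # encodes to the flat index i = y * w + x = 6 with v = 1, i.e.
    # keypoint_indices[n, k] == [6, 1]; a keypoint that is occluded or falls
    # outside the heatmap keeps its index but gets v = 0.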

    def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]:
        raise NotImplementedError()

    def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor,
                        k: int):
        """Get top-k response values from the heatmaps and corresponding tag
        values from the tagging heatmaps.

        Args:
            batch_heatmaps (Tensor): Keypoint detection heatmaps in shape
                (B, K, H, W)
            batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where
                the tag dim C is 2*K when using flip testing, or K otherwise
            k (int): The number of top responses to get

        Returns:
            tuple:
            - topk_vals (Tensor): Top-k response values of each heatmap in
                shape (B, K, Topk)
            - topk_tags (Tensor): The corresponding embedding tags of the
                top-k responses, in shape (B, K, Topk, L)
            - topk_locs (Tensor): The location of the top-k responses in each
                heatmap, in shape (B, K, Topk, 2) where the last dimension
                represents x and y coordinates
        """
        B, K, H, W = batch_heatmaps.shape
        L = batch_tags.shape[1] // K

        # shape of topk_vals, topk_indices: (B, K, TopK)
        topk_vals, topk_indices = batch_heatmaps.flatten(-2, -1).topk(
            k, dim=-1)

        topk_tags_per_kpts = [
            torch.gather(_tag, dim=2, index=topk_indices)
            for _tag in torch.unbind(batch_tags.view(B, L, K, H * W), dim=1)
        ]

        topk_tags = torch.stack(topk_tags_per_kpts, dim=-1)  # (B, K, TopK, L)
        topk_locs = torch.stack([topk_indices % W, topk_indices // W],
                                dim=-1)  # (B, K, TopK, 2)

        return topk_vals, topk_tags, topk_locs
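
    # Note: the flat indices returned by ``topk`` are converted back to 2D
    # locations via (x, y) = (index % W, index // W), the inverse of the
    # i = y * w + x encoding used in `_encode_keypoint_indices`.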

    def _group_keypoints(self, batch_vals: np.ndarray, batch_tags: np.ndarray,
                         batch_locs: np.ndarray):
        """Group keypoints into groups (each represents an instance) by tags.

        Args:
            batch_vals (Tensor): Heatmap response values of keypoint
                candidates in shape (B, K, Topk)
            batch_tags (Tensor): Tags of keypoint candidates in shape
                (B, K, Topk, L)
            batch_locs (Tensor): Locations of keypoint candidates in shape
                (B, K, Topk, 2)

        Returns:
            List[np.ndarray]: Grouping results of a batch, each element is a
            np.ndarray (in shape [N, K, D+1]) that contains the groups
            detected in an image, including both keypoint coordinates and
            scores.
        """

        def _group_func(inputs: Tuple):
            vals, tags, locs = inputs
            return _group_keypoints_by_tags(
                vals,
                tags,
                locs,
                keypoint_order=self.decode_keypoint_order,
                val_thr=self.decode_keypoint_thr,
                tag_thr=self.decode_tag_thr,
                max_groups=self.decode_max_instances)

        _results = map(_group_func, zip(batch_vals, batch_tags, batch_locs))
        results = list(_results)
        return results
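
    # Note: per the `decode_tag_thr` documentation above,
    # `_group_keypoints_by_tags` visits keypoint types in
    # `decode_keypoint_order` and assigns each candidate to the instance
    # whose tag is within `tag_thr`; a candidate farther than that from
    # every existing instance starts a new group.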

    def _fill_missing_keypoints(self, keypoints: np.ndarray,
                                keypoint_scores: np.ndarray,
                                heatmaps: np.ndarray, tags: np.ndarray):
        """Fill the missing keypoints in the initial predictions.

        Args:
            keypoints (np.ndarray): Keypoint predictions in shape (N, K, D)
            keypoint_scores (np.ndarray): Keypoint score predictions in shape
                (N, K), in which 0 means the corresponding keypoint is
                missing in the initial prediction
            heatmaps (np.ndarray): Heatmaps in shape (K, H, W)
            tags (np.ndarray): Tagging heatmaps in shape (C, H, W) where
                C=L*K

        Returns:
            tuple:
            - keypoints (np.ndarray): Keypoint predictions with missing
                ones filled
            - keypoint_scores (np.ndarray): Keypoint score predictions with
                missing ones filled
        """

        N, K = keypoints.shape[:2]
        H, W = heatmaps.shape[1:]
        L = tags.shape[0] // K
        keypoint_tags = [tags[k::K] for k in range(K)]

        for n in range(N):
            # Calculate the instance tag (mean tag of detected keypoints)
            _tag = []
            for k in range(K):
                if keypoint_scores[n, k] > 0:
                    x, y = keypoints[n, k, :2].astype(np.int64)
                    x = np.clip(x, 0, W - 1)
                    y = np.clip(y, 0, H - 1)
                    _tag.append(keypoint_tags[k][:, y, x])

            tag = np.mean(_tag, axis=0)
            tag = tag.reshape(L, 1, 1)
            # Search maximum response of the missing keypoints
            for k in range(K):
                if keypoint_scores[n, k] > 0:
                    continue
                dist_map = np.linalg.norm(
                    keypoint_tags[k] - tag, ord=2, axis=0)
                cost_map = np.round(dist_map) * 100 - heatmaps[k]  # H, W
                y, x = np.unravel_index(np.argmin(cost_map), shape=(H, W))
                keypoints[n, k] = [x, y]
                keypoint_scores[n, k] = heatmaps[k, y, x]

        return keypoints, keypoint_scores
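
    # Note: in the cost map ``np.round(dist_map) * 100 - heatmaps[k]`` above,
    # the x100 scaling makes the (rounded) tag distance dominate, so the
    # heatmap response only breaks ties among pixels whose tags are equally
    # consistent with the instance.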

    def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
                     ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Decode the keypoint coordinates from a batch of heatmaps and
        tagging heatmaps. The decoded keypoint coordinates are in the input
        image space.

        Args:
            batch_heatmaps (Tensor): Keypoint detection heatmaps in shape
                (B, K, H, W)
            batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where
                :math:`C=L*K`

        Returns:
            tuple:
            - batch_keypoints (List[np.ndarray]): Decoded keypoint coordinates
                of the batch, each is in shape (N, K, D)
            - batch_scores (List[np.ndarray]): Decoded keypoint scores of the
                batch, each is in shape (N, K). It usually represents the
                confidence of the keypoint prediction
        """
        B, _, H, W = batch_heatmaps.shape
        assert batch_tags.shape[0] == B and batch_tags.shape[2:4] == (H, W), (
            f'Mismatched shapes of heatmap ({batch_heatmaps.shape}) and '
            f'tagging map ({batch_tags.shape})')

        # Heatmap NMS
        batch_heatmaps_peak = batch_heatmap_nms(batch_heatmaps,
                                                self.decode_nms_kernel)

        # Get top-k in each heatmap and convert to numpy
        batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy(
            self._get_batch_topk(
                batch_heatmaps_peak, batch_tags, k=self.decode_topk))

        # Group keypoint candidates into groups (instances)
        batch_groups = self._group_keypoints(batch_topk_vals, batch_topk_tags,
                                             batch_topk_locs)

        # Convert to numpy
        batch_heatmaps_np = to_numpy(batch_heatmaps)
        batch_tags_np = to_numpy(batch_tags)

        # Refine the keypoint prediction
        batch_keypoints = []
        batch_keypoint_scores = []
        batch_instance_scores = []
        for i, (groups, heatmaps, tags) in enumerate(
                zip(batch_groups, batch_heatmaps_np, batch_tags_np)):

            keypoints, scores = groups[..., :-1], groups[..., -1]
            instance_scores = scores.mean(axis=-1)

            if keypoints.size > 0:
                # refine keypoint coordinates according to heatmap
                # distribution
                if self.use_udp:
                    keypoints = refine_keypoints_dark_udp(
                        keypoints,
                        heatmaps,
                        blur_kernel_size=self.decode_gaussian_kernel)
                else:
                    keypoints = refine_keypoints(keypoints, heatmaps)
                    keypoints += self.decode_center_shift * \
                        (scores > 0).astype(keypoints.dtype)[..., None]

                # identify missing keypoints
                keypoints, scores = self._fill_missing_keypoints(
                    keypoints, scores, heatmaps, tags)

            batch_keypoints.append(keypoints)
            batch_keypoint_scores.append(scores)
            batch_instance_scores.append(instance_scores)

        # restore keypoint scale
        batch_keypoints = [
            kpts * self.scale_factor for kpts in batch_keypoints
        ]

        return batch_keypoints, batch_keypoint_scores, batch_instance_scores
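For orientation, a minimal sketch of driving this codec end to end. It assumes the class in this file is the ``AssociativeEmbedding`` codec exported from ``mmpose.codecs`` (as in upstream MMPose), and it feeds random tensors in place of real network outputs; the 17-keypoint decoding order is illustrative, not prescribed by the file:

    import torch
    from mmpose.codecs import AssociativeEmbedding

    codec = AssociativeEmbedding(
        input_size=(512, 512),
        heatmap_size=(128, 128),
        decode_keypoint_order=list(range(17)))  # illustrative order

    # stand-ins for model outputs: B=1, K=17 heatmaps plus K tagging maps
    heatmaps = torch.rand(1, 17, 128, 128)
    tags = torch.rand(1, 17, 128, 128)

    kpts, kpt_scores, inst_scores = codec.batch_decode(heatmaps, tags)
    print(kpts[0].shape)  # (N, 17, 2), already in input-image coordinates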
mmpose/codecs/base.py
ADDED
@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Tuple

import numpy as np
from mmengine.utils import is_method_overridden


class BaseKeypointCodec(metaclass=ABCMeta):
    """The base class of the keypoint codec.

    A keypoint codec is a module to encode keypoint coordinates to a specific
    representation (e.g. heatmap) and vice versa. A subclass should implement
    the methods :meth:`encode` and :meth:`decode`.
    """

    # pass additional encoding arguments to the `encode` method, beyond the
    # mandatory `keypoints` and `keypoints_visible` arguments.
    auxiliary_encode_keys = set()

    field_mapping_table = dict()
    instance_mapping_table = dict()
    label_mapping_table = dict()

    @abstractmethod
    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None) -> dict:
        """Encode keypoints.

        Note:

            - instance number: N
            - keypoint number: K
            - keypoint dimension: D

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibility in shape
                (N, K, D)

        Returns:
            dict: Encoded items.
        """

    @abstractmethod
    def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoints.

        Args:
            encoded (any): Encoded keypoint representation using the codec

        Returns:
            tuple:
            - keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            - keypoints_visible (np.ndarray): Keypoint visibility in shape
                (N, K, D)
        """

    def batch_decode(self, batch_encoded: Any
                     ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Decode keypoints.

        Args:
            batch_encoded (any): A batch of encoded keypoint
                representations

        Returns:
            tuple:
            - batch_keypoints (List[np.ndarray]): Each element is keypoint
                coordinates in shape (N, K, D)
            - batch_keypoints_visible (List[np.ndarray]): Each element is
                keypoint visibility in shape (N, K)
        """
        raise NotImplementedError()

    @property
    def support_batch_decoding(self) -> bool:
        """Return whether the codec supports decoding from batch data."""
        return is_method_overridden('batch_decode', BaseKeypointCodec,
                                    self.__class__)
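As a quick illustration of the contract this base class defines, a toy subclass (the ``ToyCodec`` name and its identity encoding are invented for this sketch; only the two abstract methods need to be implemented):

    import numpy as np
    from mmpose.codecs.base import BaseKeypointCodec

    class ToyCodec(BaseKeypointCodec):
        """Identity codec: stores keypoints as-is and decodes them back."""

        def encode(self, keypoints, keypoints_visible=None):
            if keypoints_visible is None:
                keypoints_visible = np.ones(keypoints.shape[:2], np.float32)
            return dict(labels=keypoints, weights=keypoints_visible)

        def decode(self, encoded):
            return encoded['labels'], encoded['weights']

    codec = ToyCodec()
    encoded = codec.encode(np.zeros((2, 17, 2), np.float32))
    kpts, vis = codec.decode(encoded)
    print(codec.support_batch_decoding)  # False: batch_decode not overridden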
mmpose/codecs/decoupled_heatmap.py
ADDED
@@ -0,0 +1,274 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Optional, Tuple

import numpy as np

from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec
from .utils import (generate_gaussian_heatmaps, get_diagonal_lengths,
                    get_instance_bbox, get_instance_root)
from .utils.post_processing import get_heatmap_maximum
from .utils.refinement import refine_keypoints


@KEYPOINT_CODECS.register_module()
class DecoupledHeatmap(BaseKeypointCodec):
    """Encode/decode keypoints with the method introduced in the paper CID.

    See the paper `CID`_ (Contextual Instance Decoupling for Robust
    Multi-Person Pose Estimation) by Wang et al (2022) for details

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - image size: [w, h]
        - heatmap size: [W, H]

    Encoded:
        - heatmaps (np.ndarray): The coupled heatmap in shape
            (1+K, H, W) where [W, H] is the `heatmap_size`.
        - instance_heatmaps (np.ndarray): The decoupled heatmap in shape
            (M*K, H, W) where M is the number of instances.
        - keypoint_weights (np.ndarray): The weight for heatmaps in shape
            (M*K).
        - instance_coords (np.ndarray): The coordinates of instance roots
            in shape (M, 2)

    Args:
        input_size (tuple): Image size in [w, h]
        heatmap_size (tuple): Heatmap size in [W, H]
        root_type (str): The method to generate the instance root. Options
            are:

            - ``'kpt_center'``: Average coordinate of all visible keypoints.
            - ``'bbox_center'``: Center point of bounding boxes outlined by
                all visible keypoints.

            Defaults to ``'kpt_center'``

        heatmap_min_overlap (float): Minimum overlap rate among instances.
            Used when calculating sigmas for instances. Defaults to 0.7
        background_weight (float): Loss weight of background pixels.
            Defaults to 0.1
        encode_max_instances (int): The maximum number of instances
            to encode for each sample. Defaults to 30

    .. _`CID`: https://openaccess.thecvf.com/content/CVPR2022/html/Wang_
        Contextual_Instance_Decoupling_for_Robust_Multi-Person_Pose_Estimation_
        CVPR_2022_paper.html
    """

    # DecoupledHeatmap requires bounding boxes to determine the size of each
    # instance, so that it can assign varying sigmas based on their size
    auxiliary_encode_keys = {'bbox'}

    label_mapping_table = dict(
        keypoint_weights='keypoint_weights',
        instance_coords='instance_coords',
    )
    field_mapping_table = dict(
        heatmaps='heatmaps',
        instance_heatmaps='instance_heatmaps',
    )

    def __init__(
        self,
        input_size: Tuple[int, int],
        heatmap_size: Tuple[int, int],
        root_type: str = 'kpt_center',
        heatmap_min_overlap: float = 0.7,
        encode_max_instances: int = 30,
    ):
        super().__init__()

        self.input_size = input_size
        self.heatmap_size = heatmap_size
        self.root_type = root_type
        self.encode_max_instances = encode_max_instances
        self.heatmap_min_overlap = heatmap_min_overlap

        self.scale_factor = (np.array(input_size) /
                             heatmap_size).astype(np.float32)

    def _get_instance_wise_sigmas(
        self,
        bbox: np.ndarray,
    ) -> np.ndarray:
        """Get sigma values for each instance according to their size.

        Args:
            bbox (np.ndarray): Bounding box in shape (N, 4, 2)

        Returns:
            np.ndarray: Array containing the sigma values for each instance.
        """
        sigmas = np.zeros((bbox.shape[0], ), dtype=np.float32)

        heights = np.sqrt(np.power(bbox[:, 0] - bbox[:, 1], 2).sum(axis=-1))
        widths = np.sqrt(np.power(bbox[:, 0] - bbox[:, 2], 2).sum(axis=-1))

        for i in range(bbox.shape[0]):
            h, w = heights[i], widths[i]

            # compute sigma for each instance
            # condition 1
            a1, b1 = 1, h + w
            c1 = w * h * (1 - self.heatmap_min_overlap) / (
                1 + self.heatmap_min_overlap)
            sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
            r1 = (b1 + sq1) / 2

            # condition 2
            a2 = 4
            b2 = 2 * (h + w)
            c2 = (1 - self.heatmap_min_overlap) * w * h
            sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
            r2 = (b2 + sq2) / 2

            # condition 3
            a3 = 4 * self.heatmap_min_overlap
            b3 = -2 * self.heatmap_min_overlap * (h + w)
            c3 = (self.heatmap_min_overlap - 1) * w * h
            sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
            r3 = (b3 + sq3) / 2

            sigmas[i] = min(r1, r2, r3) / 3

        return sigmas

    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None,
               bbox: Optional[np.ndarray] = None) -> dict:
        """Encode keypoints into heatmaps.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K)
            bbox (np.ndarray): Bounding box in shape (N, 8) which includes
                coordinates of 4 corners.

        Returns:
            dict:
            - heatmaps (np.ndarray): The coupled heatmap in shape
                (1+K, H, W) where [W, H] is the `heatmap_size`.
            - instance_heatmaps (np.ndarray): The decoupled heatmap in shape
                (N*K, H, W) where N is the number of instances.
            - keypoint_weights (np.ndarray): The weight for heatmaps in shape
                (N*K).
            - instance_coords (np.ndarray): The coordinates of instance roots
                in shape (N, 2)
        """

        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
        if bbox is None:
            # generate pseudo bbox via visible keypoints
            bbox = get_instance_bbox(keypoints, keypoints_visible)
            bbox = np.tile(bbox, 2).reshape(-1, 4, 2)
            # corner order: left_top, left_bottom, right_top, right_bottom
            bbox[:, 1:3, 0] = bbox[:, 0:2, 0]

        # keypoint coordinates in heatmap
        _keypoints = keypoints / self.scale_factor
        _bbox = bbox.reshape(-1, 4, 2) / self.scale_factor

        # compute the root and scale of each instance
        roots, roots_visible = get_instance_root(_keypoints,
                                                 keypoints_visible,
                                                 self.root_type)

        sigmas = self._get_instance_wise_sigmas(_bbox)

        # generate global heatmaps
        heatmaps, keypoint_weights = generate_gaussian_heatmaps(
            heatmap_size=self.heatmap_size,
            keypoints=np.concatenate((_keypoints, roots[:, None]), axis=1),
            keypoints_visible=np.concatenate(
                (keypoints_visible, roots_visible[:, None]), axis=1),
            sigma=sigmas)
        roots_visible = keypoint_weights[:, -1]

        # select instances
        inst_roots, inst_indices = [], []
        diagonal_lengths = get_diagonal_lengths(_keypoints, keypoints_visible)
        for i in np.argsort(diagonal_lengths):
            if roots_visible[i] < 1:
                continue
            # rand root point in 3x3 grid
            x, y = roots[i] + np.random.randint(-1, 2, (2, ))
            x = max(0, min(x, self.heatmap_size[0] - 1))
            y = max(0, min(y, self.heatmap_size[1] - 1))
            if (x, y) not in inst_roots:
                inst_roots.append((x, y))
                inst_indices.append(i)
        if len(inst_indices) > self.encode_max_instances:
            rand_indices = random.sample(
                range(len(inst_indices)), self.encode_max_instances)
            inst_roots = [inst_roots[i] for i in rand_indices]
            inst_indices = [inst_indices[i] for i in rand_indices]

        # generate instance-wise heatmaps
        inst_heatmaps, inst_heatmap_weights = [], []
        for i in inst_indices:
            inst_heatmap, inst_heatmap_weight = generate_gaussian_heatmaps(
                heatmap_size=self.heatmap_size,
                keypoints=_keypoints[i:i + 1],
                keypoints_visible=keypoints_visible[i:i + 1],
                sigma=sigmas[i].item())
            inst_heatmaps.append(inst_heatmap)
            inst_heatmap_weights.append(inst_heatmap_weight)

        if len(inst_indices) > 0:
            inst_heatmaps = np.concatenate(inst_heatmaps)
            inst_heatmap_weights = np.concatenate(inst_heatmap_weights)
            inst_roots = np.array(inst_roots, dtype=np.int32)
        else:
            inst_heatmaps = np.empty((0, *self.heatmap_size[::-1]))
            inst_heatmap_weights = np.empty((0, ))
            inst_roots = np.empty((0, 2), dtype=np.int32)

        encoded = dict(
            heatmaps=heatmaps,
            instance_heatmaps=inst_heatmaps,
            keypoint_weights=inst_heatmap_weights,
            instance_coords=inst_roots)

        return encoded

    def decode(self, instance_heatmaps: np.ndarray,
               instance_scores: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoint coordinates from decoupled heatmaps. The decoded
        keypoint coordinates are in the input image space.

        Args:
            instance_heatmaps (np.ndarray): Heatmaps in shape (N, K, H, W)
            instance_scores (np.ndarray): Confidence of instance roots
                prediction in shape (N, 1)

        Returns:
            tuple:
            - keypoints (np.ndarray): Decoded keypoint coordinates in shape
                (N, K, D)
            - scores (np.ndarray): The keypoint scores in shape (N, K). It
                usually represents the confidence of the keypoint prediction
        """
        keypoints, keypoint_scores = [], []

        for i in range(instance_heatmaps.shape[0]):
            heatmaps = instance_heatmaps[i].copy()
            kpts, scores = get_heatmap_maximum(heatmaps)
            keypoints.append(refine_keypoints(kpts[None], heatmaps))
            keypoint_scores.append(scores[None])

        keypoints = np.concatenate(keypoints)
        # Restore the keypoint scale
        keypoints = keypoints * self.scale_factor

        keypoint_scores = np.concatenate(keypoint_scores)
        keypoint_scores *= instance_scores

        return keypoints, keypoint_scores
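The three quadratics in `_get_instance_wise_sigmas` mirror the Gaussian-radius computation popularized by CornerNet: each solves for the largest center shift r that keeps at least `heatmap_min_overlap` IoU between the shifted and original boxes, and the per-instance sigma is the smallest of the three radii divided by 3. A standalone numeric check of condition 1 (the height/width values are illustrative):

    import numpy as np

    min_overlap = 0.7
    h, w = 40.0, 20.0  # instance extent on the heatmap, illustrative

    a1, b1 = 1, h + w
    c1 = w * h * (1 - min_overlap) / (1 + min_overlap)
    r1 = (b1 + np.sqrt(b1**2 - 4 * a1 * c1)) / 2
    print(r1 / 3)  # the radius that condition 1 contributes to sigma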
mmpose/codecs/edpose_label.py
ADDED
@@ -0,0 +1,153 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import numpy as np

from mmpose.registry import KEYPOINT_CODECS
from mmpose.structures import bbox_cs2xyxy, bbox_xyxy2cs
from .base import BaseKeypointCodec


@KEYPOINT_CODECS.register_module()
class EDPoseLabel(BaseKeypointCodec):
    r"""Generate keypoint and label coordinates for `ED-Pose`_ by
    Yang J. et al (2023).

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - image size: [w, h]

    Encoded:

        - keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        - keypoints_visible (np.ndarray): Keypoint visibility in shape
            (N, K, D)
        - area (np.ndarray): Area in shape (N)
        - bbox (np.ndarray): Bbox in shape (N, 4)

    Args:
        num_select (int): The number of candidate instances
        num_keypoints (int): The number of keypoints
    """

    auxiliary_encode_keys = {'area', 'bboxes', 'img_shape'}
    instance_mapping_table = dict(
        bbox='bboxes',
        keypoints='keypoints',
        keypoints_visible='keypoints_visible',
        area='areas',
    )

    def __init__(self, num_select: int = 100, num_keypoints: int = 17):
        super().__init__()

        self.num_select = num_select
        self.num_keypoints = num_keypoints

    def encode(
        self,
        img_shape,
        keypoints: np.ndarray,
        keypoints_visible: Optional[np.ndarray] = None,
        area: Optional[np.ndarray] = None,
        bboxes: Optional[np.ndarray] = None,
    ) -> dict:
        """Encode keypoints, area and bbox from input image space to
        normalized space.

        Args:
            img_shape (Sequence[int]): The shape of image in the format
                of (width, height).
            keypoints (np.ndarray): Keypoint coordinates in
                shape (N, K, D).
            keypoints_visible (np.ndarray): Keypoint visibility in shape
                (N, K)
            area (np.ndarray): Area of instances in shape (N, ).
            bboxes (np.ndarray): Bounding boxes in shape (N, 4).

        Returns:
            encoded (dict): Contains the following items:

            - keypoint_labels (np.ndarray): The processed keypoints in
                shape like (N, K, D).
            - keypoints_visible (np.ndarray): Keypoint visibility in shape
                (N, K, D)
            - area_labels (np.ndarray): The processed target
                area in shape (N).
            - bboxes_labels: The processed target bbox in
                shape (N, 4).
        """
        w, h = img_shape

        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        if bboxes is not None:
            bboxes = np.concatenate(bbox_xyxy2cs(bboxes), axis=-1)
            bboxes = bboxes / np.array([w, h, w, h], dtype=np.float32)

        if area is not None:
            area = area / float(w * h)

        if keypoints is not None:
            keypoints = keypoints / np.array([w, h], dtype=np.float32)

        encoded = dict(
            keypoints=keypoints,
            area=area,
            bbox=bboxes,
            keypoints_visible=keypoints_visible)

        return encoded

    def decode(self, input_shapes: np.ndarray, pred_logits: np.ndarray,
               pred_boxes: np.ndarray, pred_keypoints: np.ndarray):
        """Select the final top-k keypoints, and decode the results from
        normalized size to the original input size.

        Args:
            input_shapes (Tensor): The size of the resized input image.
            pred_logits (Tensor): The predicted scores.
            pred_boxes (Tensor): The predicted bboxes.
            pred_keypoints (Tensor): The predicted keypoints.

        Returns:
            tuple: Decoded boxes, keypoints, and keypoint scores.
        """

        # Initialization
        num_keypoints = self.num_keypoints
        prob = pred_logits.reshape(-1)

        # Select top-k instances based on prediction scores
        topk_indexes = np.argsort(-prob)[:self.num_select]
        topk_values = np.take_along_axis(prob, topk_indexes, axis=0)
        scores = np.tile(topk_values[:, np.newaxis], [1, num_keypoints])

        # Decode bounding boxes
        topk_boxes = topk_indexes // pred_logits.shape[1]
        boxes = bbox_cs2xyxy(*np.split(pred_boxes, [2], axis=-1))
        boxes = np.take_along_axis(
            boxes, np.tile(topk_boxes[:, np.newaxis], [1, 4]), axis=0)

        # Convert from relative to absolute coordinates
        img_h, img_w = np.split(input_shapes, 2, axis=0)
        scale_fct = np.hstack([img_w, img_h, img_w, img_h])
        boxes = boxes * scale_fct[np.newaxis, :]

        # Decode keypoints
        topk_keypoints = topk_indexes // pred_logits.shape[1]
        keypoints = np.take_along_axis(
            pred_keypoints,
            np.tile(topk_keypoints[:, np.newaxis], [1, num_keypoints * 3]),
            axis=0)
        keypoints = keypoints[:, :(num_keypoints * 2)]
        keypoints = keypoints * np.tile(
            np.hstack([img_w, img_h]), [num_keypoints])[np.newaxis, :]
        keypoints = keypoints.reshape(-1, num_keypoints, 2)

        return boxes, keypoints, scores
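A small sketch of the coordinate normalization that `encode` performs (keypoint branch only; the image size and keypoint values are illustrative):

    import numpy as np

    w, h = 640, 480                            # img_shape as (width, height)
    keypoints = np.array([[[320., 240.]]])     # one instance, one keypoint
    normalized = keypoints / np.array([w, h], dtype=np.float32)
    print(normalized)                          # [[[0.5 0.5]]]: image center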
mmpose/codecs/hand_3d_heatmap.py
ADDED
@@ -0,0 +1,202 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import numpy as np

from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec
from .utils.gaussian_heatmap import generate_3d_gaussian_heatmaps
from .utils.post_processing import get_heatmap_3d_maximum


@KEYPOINT_CODECS.register_module()
class Hand3DHeatmap(BaseKeypointCodec):
    r"""Generate target 3d heatmap and relative root depth for hand datasets.

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D

    Args:
        image_size (tuple): Size of image. Default: ``[256, 256]``.
        root_heatmap_size (int): Size of heatmap of root head.
            Default: 64.
        heatmap_size (tuple): Size of heatmap. Default: ``[64, 64, 64]``.
        heatmap3d_depth_bound (float): Boundary for 3d heatmap depth.
            Default: 400.0.
        heatmap_size_root (int): Size of 3d heatmap root. Default: 64.
        depth_size (int): Number of depth discretization bins, used for
            decoding. Defaults to 64.
        root_depth_bound (float): Boundary for 3d heatmap root depth.
            Default: 400.0.
        use_different_joint_weights (bool): Whether to use different joint
            weights. Default: ``False``.
        sigma (int): Sigma of heatmap gaussian. Default: 2.
        joint_indices (list, optional): Indices of joints used for heatmap
            generation. If None (default) is given, all joints will be used.
            Default: ``None``.
        max_bound (float): The maximal value of heatmap. Default: 1.0.
    """

    auxiliary_encode_keys = {
        'dataset_keypoint_weights', 'rel_root_depth', 'rel_root_valid',
        'hand_type', 'hand_type_valid', 'focal', 'principal_pt'
    }

    instance_mapping_table = {
        'keypoints': 'keypoints',
        'keypoints_visible': 'keypoints_visible',
        'keypoints_cam': 'keypoints_cam',
    }

    label_mapping_table = {
        'keypoint_weights': 'keypoint_weights',
        'root_depth_weight': 'root_depth_weight',
        'type_weight': 'type_weight',
        'root_depth': 'root_depth',
        'type': 'type'
    }

    def __init__(self,
                 image_size: Tuple[int, int] = [256, 256],
                 root_heatmap_size: int = 64,
                 heatmap_size: Tuple[int, int, int] = [64, 64, 64],
                 heatmap3d_depth_bound: float = 400.0,
                 heatmap_size_root: int = 64,
                 root_depth_bound: float = 400.0,
                 depth_size: int = 64,
                 use_different_joint_weights: bool = False,
                 sigma: int = 2,
                 joint_indices: Optional[list] = None,
                 max_bound: float = 1.0):
        super().__init__()

        self.image_size = np.array(image_size)
        self.root_heatmap_size = root_heatmap_size
        self.heatmap_size = np.array(heatmap_size)
        self.heatmap3d_depth_bound = heatmap3d_depth_bound
        self.heatmap_size_root = heatmap_size_root
        self.root_depth_bound = root_depth_bound
        self.depth_size = depth_size
        self.use_different_joint_weights = use_different_joint_weights

        self.sigma = sigma
        self.joint_indices = joint_indices
        self.max_bound = max_bound
        self.scale_factor = (np.array(image_size) /
                             heatmap_size[:-1]).astype(np.float32)

    def encode(
        self,
        keypoints: np.ndarray,
        keypoints_visible: Optional[np.ndarray],
        dataset_keypoint_weights: Optional[np.ndarray],
        rel_root_depth: np.float32,
        rel_root_valid: np.float32,
        hand_type: np.ndarray,
        hand_type_valid: np.ndarray,
        focal: np.ndarray,
        principal_pt: np.ndarray,
    ) -> dict:
        """Encode keypoints from input image space into 3d heatmaps.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D).
            keypoints_visible (np.ndarray, optional): Keypoint visibilities in
                shape (N, K).
            dataset_keypoint_weights (np.ndarray, optional): Keypoints weight
                in shape (K, ).
            rel_root_depth (np.float32): Relative root depth.
            rel_root_valid (float): Validity of relative root depth.
            hand_type (np.ndarray): Type of hand encoded as an array.
            hand_type_valid (np.ndarray): Validity of hand type.
            focal (np.ndarray): Focal length of camera.
            principal_pt (np.ndarray): Principal point of camera.

        Returns:
            encoded (dict): Contains the following items:

            - heatmaps (np.ndarray): The generated heatmap in shape
                (K * D, H, W) where [W, H, D] is the `heatmap_size`
            - keypoint_weights (np.ndarray): The target weights in shape
                (N, K)
            - root_depth (np.ndarray): Encoded relative root depth
            - root_depth_weight (np.ndarray): The weights of relative root
                depth
            - type (np.ndarray): Encoded hand type
            - type_weight (np.ndarray): The weights of hand type
        """
        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:-1],
                                        dtype=np.float32)

        if self.use_different_joint_weights:
            assert dataset_keypoint_weights is not None, 'To use different ' \
                'joint weights,`dataset_keypoint_weights` cannot be None.'

        heatmaps, keypoint_weights = generate_3d_gaussian_heatmaps(
            heatmap_size=self.heatmap_size,
            keypoints=keypoints,
            keypoints_visible=keypoints_visible,
            sigma=self.sigma,
            image_size=self.image_size,
            heatmap3d_depth_bound=self.heatmap3d_depth_bound,
            joint_indices=self.joint_indices,
            max_bound=self.max_bound,
            use_different_joint_weights=self.use_different_joint_weights,
            dataset_keypoint_weights=dataset_keypoint_weights)

        rel_root_depth = (rel_root_depth / self.root_depth_bound +
                          0.5) * self.heatmap_size_root
        rel_root_valid = rel_root_valid * (rel_root_depth >= 0) * (
            rel_root_depth <= self.heatmap_size_root)

        encoded = dict(
            heatmaps=heatmaps,
            keypoint_weights=keypoint_weights,
            root_depth=rel_root_depth * np.ones(1, dtype=np.float32),
            type=hand_type,
            type_weight=hand_type_valid,
            root_depth_weight=rel_root_valid * np.ones(1, dtype=np.float32))
        return encoded

    def decode(self, heatmaps: np.ndarray, root_depth: np.ndarray,
               hand_type: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoint coordinates from heatmaps. The decoded keypoint
        coordinates are in the input image space.

        Args:
            heatmaps (np.ndarray): Heatmaps in shape (K, D, H, W)
            root_depth (np.ndarray): Root depth prediction.
            hand_type (np.ndarray): Hand type prediction.

        Returns:
            tuple:
            - keypoints (np.ndarray): Decoded keypoint coordinates in shape
                (N, K, D)
            - scores (np.ndarray): The keypoint scores in shape (N, K). It
                usually represents the confidence of the keypoint prediction
        """
        heatmap3d = heatmaps.copy()

        keypoints, scores = get_heatmap_3d_maximum(heatmap3d)

        # transform keypoint depth to camera space
        keypoints[..., 2] = (keypoints[..., 2] / self.depth_size -
                             0.5) * self.heatmap3d_depth_bound

        # Unsqueeze the instance dimension for single-instance results
        keypoints, scores = keypoints[None], scores[None]

        # Restore the keypoint scale
        keypoints[..., :2] = keypoints[..., :2] * self.scale_factor

        # decode relative hand root depth
        # transform relative root depth to camera space
        rel_root_depth = ((root_depth / self.root_heatmap_size - 0.5) *
                          self.root_depth_bound)

        hand_type = (hand_type > 0).reshape(1, -1).astype(int)

        return keypoints, scores, rel_root_depth, hand_type
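The relative root depth is mapped from metric space into root-heatmap bins during encoding and back during decoding; with the defaults above, `heatmap_size_root` (used in `encode`) and `root_heatmap_size` (used in `decode`) are both 64, so the two transforms are exact inverses. A standalone check under those defaults:

    root_depth_bound, root_heatmap_size = 400.0, 64

    def encode_depth(d):
        # mirrors `encode`: shift by half the bound, scale to heatmap bins
        return (d / root_depth_bound + 0.5) * root_heatmap_size

    def decode_depth(e):
        # mirrors `decode`: the inverse mapping back to metric depth
        return (e / root_heatmap_size - 0.5) * root_depth_bound

    assert abs(decode_depth(encode_depth(120.0)) - 120.0) < 1e-9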
mmpose/codecs/image_pose_lifting.py
ADDED
@@ -0,0 +1,280 @@
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from typing import List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from mmpose.registry import KEYPOINT_CODECS
|
| 7 |
+
from .base import BaseKeypointCodec
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@KEYPOINT_CODECS.register_module()
|
| 11 |
+
class ImagePoseLifting(BaseKeypointCodec):
|
| 12 |
+
r"""Generate keypoint coordinates for pose lifter.
|
| 13 |
+
|
| 14 |
+
Note:
|
| 15 |
+
|
| 16 |
+
- instance number: N
|
| 17 |
+
- keypoint number: K
|
| 18 |
+
- keypoint dimension: D
|
| 19 |
+
- pose-lifitng target dimension: C
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
num_keypoints (int): The number of keypoints in the dataset.
|
| 23 |
+
root_index (Union[int, List]): Root keypoint index in the pose.
|
| 24 |
+
remove_root (bool): If true, remove the root keypoint from the pose.
|
| 25 |
+
Default: ``False``.
|
| 26 |
+
save_index (bool): If true, store the root position separated from the
|
| 27 |
+
original pose. Default: ``False``.
|
| 28 |
+
reshape_keypoints (bool): If true, reshape the keypoints into shape
|
| 29 |
+
(-1, N). Default: ``True``.
|
| 30 |
+
concat_vis (bool): If true, concat the visibility item of keypoints.
|
| 31 |
+
Default: ``False``.
|
| 32 |
+
keypoints_mean (np.ndarray, optional): Mean values of keypoints
|
| 33 |
+
coordinates in shape (K, D).
|
| 34 |
+
keypoints_std (np.ndarray, optional): Std values of keypoints
|
| 35 |
+
coordinates in shape (K, D).
|
| 36 |
+
target_mean (np.ndarray, optional): Mean values of pose-lifitng target
|
| 37 |
+
coordinates in shape (K, C).
|
| 38 |
+
target_std (np.ndarray, optional): Std values of pose-lifitng target
|
| 39 |
+
coordinates in shape (K, C).
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
auxiliary_encode_keys = {'lifting_target', 'lifting_target_visible'}
|
| 43 |
+
|
| 44 |
+
instance_mapping_table = dict(
|
| 45 |
+
lifting_target='lifting_target',
|
| 46 |
+
lifting_target_visible='lifting_target_visible',
|
| 47 |
+
)
|
| 48 |
+
label_mapping_table = dict(
|
| 49 |
+
trajectory_weights='trajectory_weights',
|
| 50 |
+
lifting_target_label='lifting_target_label',
|
| 51 |
+
lifting_target_weight='lifting_target_weight')
|
| 52 |
+
|
| 53 |
+
def __init__(self,
|
| 54 |
+
num_keypoints: int,
|
| 55 |
+
root_index: Union[int, List] = 0,
|
| 56 |
+
remove_root: bool = False,
|
| 57 |
+
save_index: bool = False,
|
| 58 |
+
reshape_keypoints: bool = True,
|
| 59 |
+
concat_vis: bool = False,
|
| 60 |
+
keypoints_mean: Optional[np.ndarray] = None,
|
| 61 |
+
keypoints_std: Optional[np.ndarray] = None,
|
| 62 |
+
target_mean: Optional[np.ndarray] = None,
|
| 63 |
+
target_std: Optional[np.ndarray] = None,
|
| 64 |
+
additional_encode_keys: Optional[List[str]] = None):
|
| 65 |
+
super().__init__()
|
| 66 |
+
|
| 67 |
+
self.num_keypoints = num_keypoints
|
| 68 |
+
if isinstance(root_index, int):
|
| 69 |
+
root_index = [root_index]
|
| 70 |
+
self.root_index = root_index
|
| 71 |
+
self.remove_root = remove_root
|
| 72 |
+
self.save_index = save_index
|
| 73 |
+
self.reshape_keypoints = reshape_keypoints
|
| 74 |
+
self.concat_vis = concat_vis
|
| 75 |
+
if keypoints_mean is not None:
|
| 76 |
+
assert keypoints_std is not None, 'keypoints_std is None'
|
| 77 |
+
keypoints_mean = np.array(
|
| 78 |
+
keypoints_mean,
|
| 79 |
+
dtype=np.float32).reshape(1, num_keypoints, -1)
|
| 80 |
+
keypoints_std = np.array(
|
| 81 |
+
keypoints_std, dtype=np.float32).reshape(1, num_keypoints, -1)
|
| 82 |
+
|
| 83 |
+
assert keypoints_mean.shape == keypoints_std.shape, (
|
| 84 |
+
f'keypoints_mean.shape {keypoints_mean.shape} != '
|
| 85 |
+
f'keypoints_std.shape {keypoints_std.shape}')
|
| 86 |
+
if target_mean is not None:
|
| 87 |
+
assert target_std is not None, 'target_std is None'
|
| 88 |
+
target_dim = num_keypoints - 1 if remove_root else num_keypoints
|
| 89 |
+
target_mean = np.array(
|
| 90 |
+
target_mean, dtype=np.float32).reshape(1, target_dim, -1)
|
| 91 |
+
target_std = np.array(
|
| 92 |
+
target_std, dtype=np.float32).reshape(1, target_dim, -1)
|
| 93 |
+
|
| 94 |
+
assert target_mean.shape == target_std.shape, (
|
| 95 |
+
f'target_mean.shape {target_mean.shape} != '
|
| 96 |
+
f'target_std.shape {target_std.shape}')
|
| 97 |
+
self.keypoints_mean = keypoints_mean
|
| 98 |
+
self.keypoints_std = keypoints_std
|
| 99 |
+
self.target_mean = target_mean
|
| 100 |
+
self.target_std = target_std
|
| 101 |
+
|
| 102 |
+
if additional_encode_keys is not None:
|
| 103 |
+
self.auxiliary_encode_keys.update(additional_encode_keys)
|
| 104 |
+
|
| 105 |
+
def encode(self,
|
| 106 |
+
keypoints: np.ndarray,
|
| 107 |
+
keypoints_visible: Optional[np.ndarray] = None,
|
| 108 |
+
lifting_target: Optional[np.ndarray] = None,
|
| 109 |
+
lifting_target_visible: Optional[np.ndarray] = None) -> dict:
|
| 110 |
+
"""Encoding keypoints from input image space to normalized space.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D).
|
| 114 |
+
keypoints_visible (np.ndarray, optional): Keypoint visibilities in
|
| 115 |
+
shape (N, K).
|
| 116 |
+
lifting_target (np.ndarray, optional): 3d target coordinate in
|
| 117 |
+
shape (T, K, C).
|
| 118 |
+
lifting_target_visible (np.ndarray, optional): Target coordinate in
|
| 119 |
+
shape (T, K, ).
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
encoded (dict): Contains the following items:
|
| 123 |
+
|
| 124 |
+
- keypoint_labels (np.ndarray): The processed keypoints in
|
| 125 |
+
shape like (N, K, D) or (K * D, N).
|
| 126 |
+
- keypoint_labels_visible (np.ndarray): The processed
|
| 127 |
+
keypoints' weights in shape (N, K, ) or (N-1, K, ).
|
| 128 |
+
- lifting_target_label: The processed target coordinate in
|
| 129 |
+
shape (K, C) or (K-1, C).
|
| 130 |
+
                - lifting_target_weight (np.ndarray): The target weights in
                  shape (K, ) or (K-1, ).
                - trajectory_weights (np.ndarray): The trajectory weights in
                  shape (K, ).
                - target_root (np.ndarray): The root coordinate of target in
                  shape (C, ).

            In addition, there are some optional items it may contain:

                - target_root (np.ndarray): The root coordinate of target in
                  shape (C, ). Exists if ``zero_center`` is ``True``.
                - target_root_removed (bool): Indicates whether the root of
                  the pose-lifting target is removed. Exists if
                  ``remove_root`` is ``True``.
                - target_root_index (int): An integer indicating the index of
                  root. Exists if ``remove_root`` and ``save_index``
                  are ``True``.
        """
        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        if lifting_target is None:
            lifting_target = [keypoints[0]]

        # set initial value for `lifting_target_weight`
        # and `trajectory_weights`
        if lifting_target_visible is None:
            lifting_target_visible = np.ones(
                lifting_target.shape[:-1], dtype=np.float32)
            lifting_target_weight = lifting_target_visible
            trajectory_weights = (1 / lifting_target[:, 2])
        else:
            valid = lifting_target_visible > 0.5
            lifting_target_weight = np.where(valid, 1., 0.).astype(np.float32)
            trajectory_weights = lifting_target_weight

        encoded = dict()

        # Zero-center the target pose around a given root keypoint
        assert (lifting_target.ndim >= 2 and
                lifting_target.shape[-2] > max(self.root_index)), \
            f'Got invalid joint shape {lifting_target.shape}'

        root = np.mean(
            lifting_target[..., self.root_index, :], axis=-2, dtype=np.float32)
        lifting_target_label = lifting_target - root[np.newaxis, ...]

        if self.remove_root and len(self.root_index) == 1:
            root_index = self.root_index[0]
            lifting_target_label = np.delete(
                lifting_target_label, root_index, axis=-2)
            lifting_target_visible = np.delete(
                lifting_target_visible, root_index, axis=-2)
            assert lifting_target_weight.ndim in {
                2, 3
            }, (f'lifting_target_weight.ndim {lifting_target_weight.ndim} '
                'is not in {2, 3}')

            axis_to_remove = -2 if lifting_target_weight.ndim == 3 else -1
            lifting_target_weight = np.delete(
                lifting_target_weight, root_index, axis=axis_to_remove)
            # Add a flag to avoid later transforms that rely on the root
            # joint or the original joint index
            encoded['target_root_removed'] = True

            # Save the root index which is necessary to restore the global pose
            if self.save_index:
                encoded['target_root_index'] = root_index

        # Normalize the 2D keypoint coordinate with mean and std
        keypoint_labels = keypoints.copy()

        if self.keypoints_mean is not None:
            assert self.keypoints_mean.shape[1:] == keypoints.shape[1:], (
                f'self.keypoints_mean.shape[1:] {self.keypoints_mean.shape[1:]} '  # noqa
                f'!= keypoints.shape[1:] {keypoints.shape[1:]}')
            encoded['keypoints_mean'] = self.keypoints_mean.copy()
            encoded['keypoints_std'] = self.keypoints_std.copy()

            keypoint_labels = (keypoint_labels -
                               self.keypoints_mean) / self.keypoints_std
        if self.target_mean is not None:
            assert self.target_mean.shape == lifting_target_label.shape, (
                f'self.target_mean.shape {self.target_mean.shape} '
                f'!= lifting_target_label.shape {lifting_target_label.shape}'  # noqa
            )
            encoded['target_mean'] = self.target_mean.copy()
            encoded['target_std'] = self.target_std.copy()

            lifting_target_label = (lifting_target_label -
                                    self.target_mean) / self.target_std

        # Generate reshaped keypoint coordinates
        assert keypoint_labels.ndim in {
            2, 3
        }, (f'keypoint_labels.ndim {keypoint_labels.ndim} is not in {2, 3}')
        if keypoint_labels.ndim == 2:
            keypoint_labels = keypoint_labels[None, ...]

        if self.concat_vis:
            keypoints_visible_ = keypoints_visible
            if keypoints_visible.ndim == 2:
                keypoints_visible_ = keypoints_visible[..., None]
            keypoint_labels = np.concatenate(
                (keypoint_labels, keypoints_visible_), axis=2)

        if self.reshape_keypoints:
            N = keypoint_labels.shape[0]
            keypoint_labels = keypoint_labels.transpose(1, 2, 0).reshape(-1, N)

        encoded['keypoint_labels'] = keypoint_labels
        encoded['keypoint_labels_visible'] = keypoints_visible
        encoded['lifting_target_label'] = lifting_target_label
        encoded['lifting_target_weight'] = lifting_target_weight
        encoded['trajectory_weights'] = trajectory_weights
        encoded['target_root'] = root

        return encoded

    def decode(self,
               encoded: np.ndarray,
               target_root: Optional[np.ndarray] = None
               ) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoint coordinates from normalized space to input image
        space.

        Args:
            encoded (np.ndarray): Coordinates in shape (N, K, C).
            target_root (np.ndarray, optional): The target root coordinate.
                Default: ``None``.

        Returns:
            keypoints (np.ndarray): Decoded coordinates in shape (N, K, C).
            scores (np.ndarray): The keypoint scores in shape (N, K).
        """
        keypoints = encoded.copy()

        if self.target_mean is not None and self.target_std is not None:
            assert self.target_mean.shape == keypoints.shape, (
                f'self.target_mean.shape {self.target_mean.shape} '
                f'!= keypoints.shape {keypoints.shape}')
            keypoints = keypoints * self.target_std + self.target_mean

        if target_root is not None and target_root.size > 0:
            keypoints = keypoints + target_root
            if self.remove_root and len(self.root_index) == 1:
                keypoints = np.insert(
                    keypoints, self.root_index, target_root, axis=1)
        scores = np.ones(keypoints.shape[:-1], dtype=np.float32)

        return keypoints, scores
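For reference, a minimal round-trip sketch of the codec above. It assumes the class is mmpose's ImagePoseLifting with its upstream constructor (num_keypoints, root_index, ...) and that it is re-exported from mmpose.codecs; neither detail appears in this hunk.

# Hedged sketch; constructor arguments are assumptions, not shown in this diff.
import numpy as np
from mmpose.codecs import ImagePoseLifting

codec = ImagePoseLifting(num_keypoints=17, root_index=0)
keypoints = np.random.rand(1, 17, 2).astype(np.float32)       # 2D inputs (N, K, D)
lifting_target = np.random.rand(1, 17, 3).astype(np.float32)  # 3D targets (T, K, C)

encoded = codec.encode(keypoints, lifting_target=lifting_target)
# Re-applying the stored root restores the zero-centered labels.
decoded, scores = codec.decode(
    encoded['lifting_target_label'], target_root=encoded['target_root'])
assert np.allclose(decoded, lifting_target, atol=1e-5)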
mmpose/codecs/integral_regression_label.py
ADDED
@@ -0,0 +1,121 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Optional, Tuple

import numpy as np

from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec
from .msra_heatmap import MSRAHeatmap
from .regression_label import RegressionLabel


@KEYPOINT_CODECS.register_module()
class IntegralRegressionLabel(BaseKeypointCodec):
    """Generate keypoint coordinates and normalized heatmaps. See the paper:
    `DSNT`_ by Nibali et al. (2018).

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - image size: [w, h]

    Encoded:

        - keypoint_labels (np.ndarray): The normalized regression labels in
          shape (N, K, D) where D is 2 for 2d coordinates
        - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W) where
          [W, H] is the `heatmap_size`
        - keypoint_weights (np.ndarray): The target weights in shape (N, K)

    Args:
        input_size (tuple): Input image size in [w, h]
        heatmap_size (tuple): Heatmap size in [W, H]
        sigma (float): The sigma value of the Gaussian heatmap
        unbiased (bool): Whether to use the unbiased method (DarkPose) in
            ``'msra'`` encoding. See `Dark Pose`_ for details. Defaults to
            ``False``
        blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
            modulation in DarkPose. The kernel size and sigma should follow
            the empirical formula :math:`sigma = 0.3*((ks-1)*0.5-1)+0.8`.
            Defaults to 11
        normalize (bool): Whether to normalize the heatmaps. Defaults to True.

    .. _`DSNT`: https://arxiv.org/abs/1801.07372
    """

    label_mapping_table = dict(
        keypoint_labels='keypoint_labels',
        keypoint_weights='keypoint_weights',
    )
    field_mapping_table = dict(heatmaps='heatmaps', )

    def __init__(self,
                 input_size: Tuple[int, int],
                 heatmap_size: Tuple[int, int],
                 sigma: float,
                 unbiased: bool = False,
                 blur_kernel_size: int = 11,
                 normalize: bool = True) -> None:
        super().__init__()

        self.heatmap_codec = MSRAHeatmap(input_size, heatmap_size, sigma,
                                         unbiased, blur_kernel_size)
        self.keypoint_codec = RegressionLabel(input_size)
        self.normalize = normalize

    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None) -> dict:
        """Encoding keypoints to regression labels and heatmaps.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K)

        Returns:
            dict:
            - keypoint_labels (np.ndarray): The normalized regression labels
              in shape (N, K, D) where D is 2 for 2d coordinates
            - heatmaps (np.ndarray): The generated heatmap in shape
              (K, H, W) where [W, H] is the `heatmap_size`
            - keypoint_weights (np.ndarray): The target weights in shape
              (N, K)
        """
        encoded_hm = self.heatmap_codec.encode(keypoints, keypoints_visible)
        encoded_kp = self.keypoint_codec.encode(keypoints, keypoints_visible)

        heatmaps = encoded_hm['heatmaps']
        keypoint_labels = encoded_kp['keypoint_labels']
        keypoint_weights = encoded_kp['keypoint_weights']

        if self.normalize:
            val_sum = heatmaps.sum(axis=(-1, -2)).reshape(-1, 1, 1) + 1e-24
            heatmaps = heatmaps / val_sum

        encoded = dict(
            keypoint_labels=keypoint_labels,
            heatmaps=heatmaps,
            keypoint_weights=keypoint_weights)

        return encoded

    def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoint coordinates from normalized space to input image
        space.

        Args:
            encoded (np.ndarray): Coordinates in shape (N, K, D)

        Returns:
            tuple:
            - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D)
            - scores (np.ndarray): The keypoint scores in shape (N, K).
              It usually represents the confidence of the keypoint prediction
        """

        keypoints, scores = self.keypoint_codec.decode(encoded)

        return keypoints, scores
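For reference, a minimal sketch of this codec, which simply chains MSRAHeatmap and RegressionLabel. It assumes the class is re-exported from mmpose.codecs; sizes are illustrative.

import numpy as np
from mmpose.codecs import IntegralRegressionLabel

codec = IntegralRegressionLabel(
    input_size=(192, 256), heatmap_size=(48, 64), sigma=2.0)
keypoints = np.array([[[96., 128.]]], dtype=np.float32)  # (N=1, K=1, D=2)

encoded = codec.encode(keypoints)
# With normalize=True, every non-empty heatmap sums to ~1, so it can act as a
# probability mass for DSNT-style integral regression.
print(encoded['heatmaps'].sum(axis=(-1, -2)))
decoded, scores = codec.decode(encoded['keypoint_labels'])  # back to image space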
mmpose/codecs/megvii_heatmap.py
ADDED
@@ -0,0 +1,147 @@
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import product
from typing import Optional, Tuple

import cv2
import numpy as np

from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec
from .utils import gaussian_blur, get_heatmap_maximum


@KEYPOINT_CODECS.register_module()
class MegviiHeatmap(BaseKeypointCodec):
    """Represent keypoints as heatmaps via the "Megvii" approach. See `MSPN`_
    (2019) and `CPN`_ (2018) for details.

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - image size: [w, h]
        - heatmap size: [W, H]

    Encoded:

        - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W)
          where [W, H] is the `heatmap_size`
        - keypoint_weights (np.ndarray): The target weights in shape (N, K)

    Args:
        input_size (tuple): Image size in [w, h]
        heatmap_size (tuple): Heatmap size in [W, H]
        kernel_size (int): The kernel size of the heatmap Gaussian blur

    .. _`MSPN`: https://arxiv.org/abs/1901.00148
    .. _`CPN`: https://arxiv.org/abs/1711.07319
    """

    label_mapping_table = dict(keypoint_weights='keypoint_weights', )
    field_mapping_table = dict(heatmaps='heatmaps', )

    def __init__(
        self,
        input_size: Tuple[int, int],
        heatmap_size: Tuple[int, int],
        kernel_size: int,
    ) -> None:

        super().__init__()
        self.input_size = input_size
        self.heatmap_size = heatmap_size
        self.kernel_size = kernel_size
        self.scale_factor = (np.array(input_size) /
                             heatmap_size).astype(np.float32)

    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None) -> dict:
        """Encode keypoints into heatmaps. Note that the original keypoint
        coordinates should be in the input image space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K)

        Returns:
            dict:
            - heatmaps (np.ndarray): The generated heatmap in shape
              (K, H, W) where [W, H] is the `heatmap_size`
            - keypoint_weights (np.ndarray): The target weights in shape
              (N, K)
        """

        N, K, _ = keypoints.shape
        W, H = self.heatmap_size

        assert N == 1, (
            f'{self.__class__.__name__} only support single-instance '
            'keypoint encoding')

        heatmaps = np.zeros((K, H, W), dtype=np.float32)
        keypoint_weights = keypoints_visible.copy()

        for n, k in product(range(N), range(K)):
            # skip unlabeled keypoints
            if keypoints_visible[n, k] < 0.5:
                continue

            # get center coordinates
            kx, ky = (keypoints[n, k] / self.scale_factor).astype(np.int64)
            if kx < 0 or kx >= W or ky < 0 or ky >= H:
                keypoint_weights[n, k] = 0
                continue

            heatmaps[k, ky, kx] = 1.
            kernel_size = (self.kernel_size, self.kernel_size)
            heatmaps[k] = cv2.GaussianBlur(heatmaps[k], kernel_size, 0)

            # normalize the heatmap
            heatmaps[k] = heatmaps[k] / heatmaps[k, ky, kx] * 255.

        encoded = dict(heatmaps=heatmaps, keypoint_weights=keypoint_weights)

        return encoded

    def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoint coordinates from heatmaps. The decoded keypoint
        coordinates are in the input image space.

        Args:
            encoded (np.ndarray): Heatmaps in shape (K, H, W)

        Returns:
            tuple:
            - keypoints (np.ndarray): Decoded keypoint coordinates in shape
              (K, D)
            - scores (np.ndarray): The keypoint scores in shape (K,). It
              usually represents the confidence of the keypoint prediction
        """
        heatmaps = gaussian_blur(encoded.copy(), self.kernel_size)
        K, H, W = heatmaps.shape

        keypoints, scores = get_heatmap_maximum(heatmaps)

        for k in range(K):
            heatmap = heatmaps[k]
            px = int(keypoints[k, 0])
            py = int(keypoints[k, 1])
            if 1 < px < W - 1 and 1 < py < H - 1:
                diff = np.array([
                    heatmap[py][px + 1] - heatmap[py][px - 1],
                    heatmap[py + 1][px] - heatmap[py - 1][px]
                ])
                keypoints[k] += (np.sign(diff) * 0.25 + 0.5)

        scores = scores / 255.0 + 0.5

        # Unsqueeze the instance dimension for single-instance results
        # and restore the keypoint scales
        keypoints = keypoints[None] * self.scale_factor
        scores = scores[None]

        return keypoints, scores
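For reference, a minimal sketch of the Megvii-style round trip, assuming the codec is re-exported from mmpose.codecs. The kernel size must be odd for cv2.GaussianBlur, and visibilities must be passed explicitly since encode copies them.

import numpy as np
from mmpose.codecs import MegviiHeatmap

codec = MegviiHeatmap(
    input_size=(192, 256), heatmap_size=(48, 64), kernel_size=11)
keypoints = np.array([[[96., 128.], [40., 200.]]], dtype=np.float32)  # (1, 2, 2)
visible = np.ones((1, 2), dtype=np.float32)

encoded = codec.encode(keypoints, visible)
# Peaks were scaled to 255 in encode, hence `scores / 255 + 0.5` in decode.
decoded, scores = codec.decode(encoded['heatmaps'])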
mmpose/codecs/motionbert_label.py
ADDED
@@ -0,0 +1,240 @@
# Copyright (c) OpenMMLab. All rights reserved.

from copy import deepcopy
from typing import Optional, Tuple

import numpy as np

from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec
from .utils import camera_to_image_coord


@KEYPOINT_CODECS.register_module()
class MotionBERTLabel(BaseKeypointCodec):
    r"""Generate keypoint and label coordinates for `MotionBERT`_ by Zhu et al
    (2022).

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - pose-lifting target dimension: C

    Args:
        num_keypoints (int): The number of keypoints in the dataset.
        root_index (int): Root keypoint index in the pose. Default: 0.
        remove_root (bool): If true, remove the root keypoint from the pose.
            Default: ``False``.
        save_index (bool): If true, store the root position separated from the
            original pose, only takes effect if ``remove_root`` is ``True``.
            Default: ``False``.
        concat_vis (bool): If true, concat the visibility item of keypoints.
            Default: ``False``.
        rootrel (bool): If true, the root keypoint will be set to the
            coordinate origin. Default: ``False``.
        mode (str): Indicating whether the current mode is 'train' or 'test'.
            Default: ``'test'``.
    """

    auxiliary_encode_keys = {
        'lifting_target', 'lifting_target_visible', 'camera_param', 'factor'
    }

    instance_mapping_table = dict(
        lifting_target='lifting_target',
        lifting_target_visible='lifting_target_visible',
    )
    label_mapping_table = dict(
        trajectory_weights='trajectory_weights',
        lifting_target_label='lifting_target_label',
        lifting_target_weight='lifting_target_weight')

    def __init__(self,
                 num_keypoints: int,
                 root_index: int = 0,
                 remove_root: bool = False,
                 save_index: bool = False,
                 concat_vis: bool = False,
                 rootrel: bool = False,
                 mode: str = 'test'):
        super().__init__()

        self.num_keypoints = num_keypoints
        self.root_index = root_index
        self.remove_root = remove_root
        self.save_index = save_index
        self.concat_vis = concat_vis
        self.rootrel = rootrel
        assert mode.lower() in {'train', 'test'
                                }, (f'Unsupported mode {mode}, '
                                    'mode should be one of ("train", "test").')
        self.mode = mode.lower()

    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None,
               lifting_target: Optional[np.ndarray] = None,
               lifting_target_visible: Optional[np.ndarray] = None,
               camera_param: Optional[dict] = None,
               factor: Optional[np.ndarray] = None) -> dict:
        """Encoding keypoints from input image space to normalized space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (B, T, K, D).
            keypoints_visible (np.ndarray, optional): Keypoint visibilities in
                shape (B, T, K).
            lifting_target (np.ndarray, optional): 3d target coordinate in
                shape (T, K, C).
            lifting_target_visible (np.ndarray, optional): Target coordinate
                visibilities in shape (T, K, ).
            camera_param (dict, optional): The camera parameter dictionary.
            factor (np.ndarray, optional): The factor mapping camera and image
                coordinate in shape (T, ).

        Returns:
            encoded (dict): Contains the following items:

                - keypoint_labels (np.ndarray): The processed keypoints in
                  shape like (N, K, D).
                - keypoint_labels_visible (np.ndarray): The processed
                  keypoints' weights in shape (N, K, ) or (N, K-1, ).
                - lifting_target_label: The processed target coordinate in
                  shape (K, C) or (K-1, C).
                - lifting_target_weight (np.ndarray): The target weights in
                  shape (K, ) or (K-1, ).
                - factor (np.ndarray): The factor mapping camera and image
                  coordinate in shape (T, 1).
        """
        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        # set initial value for `lifting_target_weight`
        if lifting_target_visible is None:
            lifting_target_visible = np.ones(
                lifting_target.shape[:-1], dtype=np.float32)
            lifting_target_weight = lifting_target_visible
        else:
            valid = lifting_target_visible > 0.5
            lifting_target_weight = np.where(valid, 1., 0.).astype(np.float32)

        if camera_param is None:
            camera_param = dict()

        encoded = dict()

        assert lifting_target is not None
        lifting_target_label = lifting_target.copy()
        keypoint_labels = keypoints.copy()

        assert keypoint_labels.ndim in {
            2, 3
        }, (f'Keypoint labels should have 2 or 3 dimensions, '
            f'but got {keypoint_labels.ndim}.')
        if keypoint_labels.ndim == 2:
            keypoint_labels = keypoint_labels[None, ...]

        # Normalize the 2D keypoint coordinate with image width and height
        _camera_param = deepcopy(camera_param)
        assert 'w' in _camera_param and 'h' in _camera_param, (
            'Camera parameters should contain "w" and "h".')
        w, h = _camera_param['w'], _camera_param['h']
        keypoint_labels[
            ..., :2] = keypoint_labels[..., :2] / w * 2 - [1, h / w]

        # convert target to image coordinate
        T = keypoint_labels.shape[0]
        factor_ = np.array([4] * T, dtype=np.float32).reshape(T, )
        if 'f' in _camera_param and 'c' in _camera_param:
            lifting_target_label, factor_ = camera_to_image_coord(
                self.root_index, lifting_target_label, _camera_param)
        if self.mode == 'train':
            w, h = w / 1000, h / 1000
            lifting_target_label[
                ..., :2] = lifting_target_label[..., :2] / w * 2 - [1, h / w]
            lifting_target_label[..., 2] = lifting_target_label[..., 2] / w * 2
        lifting_target_label[..., :, :] = lifting_target_label[
            ..., :, :] - lifting_target_label[...,
                                              self.root_index:self.root_index +
                                              1, :]
        if factor is None or factor[0] == 0:
            factor = factor_
        if factor.ndim == 1:
            factor = factor[:, None]
        if self.mode == 'test':
            lifting_target_label *= factor[..., None]

        if self.concat_vis:
            keypoints_visible_ = keypoints_visible
            if keypoints_visible.ndim == 2:
                keypoints_visible_ = keypoints_visible[..., None]
            keypoint_labels = np.concatenate(
                (keypoint_labels, keypoints_visible_), axis=2)

        encoded['keypoint_labels'] = keypoint_labels
        encoded['keypoint_labels_visible'] = keypoints_visible
        encoded['lifting_target_label'] = lifting_target_label
        encoded['lifting_target_weight'] = lifting_target_weight
        encoded['lifting_target'] = lifting_target_label
        encoded['lifting_target_visible'] = lifting_target_visible
        encoded['factor'] = factor

        return encoded

    def decode(
        self,
        encoded: np.ndarray,
        w: Optional[np.ndarray] = None,
        h: Optional[np.ndarray] = None,
        factor: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoint coordinates from normalized space to input image
        space.

        Args:
            encoded (np.ndarray): Coordinates in shape (N, K, C).
            w (np.ndarray, optional): The image widths in shape (N, ).
                Default: ``None``.
            h (np.ndarray, optional): The image heights in shape (N, ).
                Default: ``None``.
            factor (np.ndarray, optional): The factor for projection in shape
                (N, ). Default: ``None``.

        Returns:
            keypoints (np.ndarray): Decoded coordinates in shape (N, K, C).
            scores (np.ndarray): The keypoint scores in shape (N, K).
        """
        keypoints = encoded.copy()
        scores = np.ones(keypoints.shape[:-1], dtype=np.float32)

        if self.rootrel:
            keypoints[..., 0, :] = 0

        if w is not None and w.size > 0:
            assert w.shape == h.shape, (f'w and h should have the same shape, '
                                        f'but got {w.shape} and {h.shape}.')
            assert w.shape[0] == keypoints.shape[0], (
                f'w and h should have the same batch size, '
                f'but got {w.shape[0]} and {keypoints.shape[0]}.')
            assert w.ndim in {1,
                              2}, (f'w and h should have 1 or 2 dimensions, '
                                   f'but got {w.ndim}.')
            if w.ndim == 1:
                w = w[:, None]
                h = h[:, None]
            trans = np.append(
                np.ones((w.shape[0], 1)), h / w, axis=1)[:, None, :]
            keypoints[..., :2] = (keypoints[..., :2] + trans) * w[:, None] / 2
            keypoints[..., 2:] = keypoints[..., 2:] * w[:, None] / 2

        if factor is not None and factor.size > 0:
            assert factor.shape[0] == keypoints.shape[0], (
                f'factor should have the same batch size, '
                f'but got {factor.shape[0]} and {keypoints.shape[0]}.')
            keypoints *= factor[..., None]

        keypoints[..., :, :] = keypoints[..., :, :] - keypoints[
            ..., self.root_index:self.root_index + 1, :]
        keypoints /= 1000.
        return keypoints, scores
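For reference, a minimal encode sketch for the codec above. The camera dict below is a made-up 1000x1000 camera without intrinsics, so the camera-to-image conversion is skipped and the default factor of 4 is used; it assumes the class is re-exported from mmpose.codecs.

import numpy as np
from mmpose.codecs import MotionBERTLabel

codec = MotionBERTLabel(num_keypoints=17)  # default mode='test'
kpts_2d = (np.random.rand(1, 17, 2) * 1000).astype(np.float32)
target_3d = np.random.rand(1, 17, 3).astype(np.float32)

encoded = codec.encode(
    kpts_2d, lifting_target=target_3d, camera_param=dict(w=1000, h=1000))
# In 'test' mode the root-centered 3D labels are scaled by the factor.
print(encoded['factor'].shape)  # (1, 1)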
mmpose/codecs/msra_heatmap.py
ADDED
@@ -0,0 +1,153 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import numpy as np

from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec
from .utils.gaussian_heatmap import (generate_gaussian_heatmaps,
                                     generate_unbiased_gaussian_heatmaps)
from .utils.post_processing import get_heatmap_maximum
from .utils.refinement import refine_keypoints, refine_keypoints_dark


@KEYPOINT_CODECS.register_module()
class MSRAHeatmap(BaseKeypointCodec):
    """Represent keypoints as heatmaps via the "MSRA" approach. See the paper:
    `Simple Baselines for Human Pose Estimation and Tracking`_ by Xiao et al
    (2018) for details.

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - image size: [w, h]
        - heatmap size: [W, H]

    Encoded:

        - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W)
          where [W, H] is the `heatmap_size`
        - keypoint_weights (np.ndarray): The target weights in shape (N, K)

    Args:
        input_size (tuple): Image size in [w, h]
        heatmap_size (tuple): Heatmap size in [W, H]
        sigma (float): The sigma value of the Gaussian heatmap
        unbiased (bool): Whether to use the unbiased method (DarkPose) in
            ``'msra'`` encoding. See `Dark Pose`_ for details. Defaults to
            ``False``
        blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
            modulation in DarkPose. The kernel size and sigma should follow
            the empirical formula :math:`sigma = 0.3*((ks-1)*0.5-1)+0.8`.
            Defaults to 11

    .. _`Simple Baselines for Human Pose Estimation and Tracking`:
        https://arxiv.org/abs/1804.06208
    .. _`Dark Pose`: https://arxiv.org/abs/1910.06278
    """

    label_mapping_table = dict(keypoint_weights='keypoint_weights', )
    field_mapping_table = dict(heatmaps='heatmaps', )

    def __init__(self,
                 input_size: Tuple[int, int],
                 heatmap_size: Tuple[int, int],
                 sigma: float,
                 unbiased: bool = False,
                 blur_kernel_size: int = 11) -> None:
        super().__init__()
        self.input_size = input_size
        self.heatmap_size = heatmap_size
        self.sigma = sigma
        self.unbiased = unbiased

        # The Gaussian blur kernel size of the heatmap modulation
        # in DarkPose and the sigma value follows the empirical
        # formula :math:`sigma = 0.3*((ks-1)*0.5-1)+0.8`
        # which gives:
        # sigma~=3 if ks=17
        # sigma=2 if ks=11;
        # sigma~=1.5 if ks=7;
        # sigma~=1 if ks=3;
        self.blur_kernel_size = blur_kernel_size
        self.scale_factor = (np.array(input_size) /
                             heatmap_size).astype(np.float32)

    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None) -> dict:
        """Encode keypoints into heatmaps. Note that the original keypoint
        coordinates should be in the input image space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K)

        Returns:
            dict:
            - heatmaps (np.ndarray): The generated heatmap in shape
              (K, H, W) where [W, H] is the `heatmap_size`
            - keypoint_weights (np.ndarray): The target weights in shape
              (N, K)
        """

        assert keypoints.shape[0] == 1, (
            f'{self.__class__.__name__} only support single-instance '
            'keypoint encoding')

        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        if self.unbiased:
            heatmaps, keypoint_weights = generate_unbiased_gaussian_heatmaps(
                heatmap_size=self.heatmap_size,
                keypoints=keypoints / self.scale_factor,
                keypoints_visible=keypoints_visible,
                sigma=self.sigma)
        else:
            heatmaps, keypoint_weights = generate_gaussian_heatmaps(
                heatmap_size=self.heatmap_size,
                keypoints=keypoints / self.scale_factor,
                keypoints_visible=keypoints_visible,
                sigma=self.sigma)

        encoded = dict(heatmaps=heatmaps, keypoint_weights=keypoint_weights)

        return encoded

    def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoint coordinates from heatmaps. The decoded keypoint
        coordinates are in the input image space.

        Args:
            encoded (np.ndarray): Heatmaps in shape (K, H, W)

        Returns:
            tuple:
            - keypoints (np.ndarray): Decoded keypoint coordinates in shape
              (N, K, D)
            - scores (np.ndarray): The keypoint scores in shape (N, K). It
              usually represents the confidence of the keypoint prediction
        """
        heatmaps = encoded.copy()
        K, H, W = heatmaps.shape

        keypoints, scores = get_heatmap_maximum(heatmaps)

        # Unsqueeze the instance dimension for single-instance results
        keypoints, scores = keypoints[None], scores[None]

        if self.unbiased:
            # Alleviate biased coordinate
            keypoints = refine_keypoints_dark(
                keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size)

        else:
            keypoints = refine_keypoints(keypoints, heatmaps)

        # Restore the keypoint scale
        keypoints = keypoints * self.scale_factor

        return keypoints, scores
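For reference, a minimal round-trip sketch for the codec above, assuming it is re-exported from mmpose.codecs: a single instance is encoded to Gaussian heatmaps and decoded back to roughly the same coordinates.

import numpy as np
from mmpose.codecs import MSRAHeatmap

codec = MSRAHeatmap(input_size=(192, 256), heatmap_size=(48, 64), sigma=2.0)
keypoints = np.array([[[96., 128.]]], dtype=np.float32)  # (N=1, K=1, D=2)

encoded = codec.encode(keypoints)
decoded, scores = codec.decode(encoded['heatmaps'])
# Quantization error is bounded by the 4x scale factor between the
# 192x256 input and the 48x64 heatmap.
print(np.abs(decoded - keypoints).max())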
mmpose/codecs/onehot_heatmap.py
ADDED
@@ -0,0 +1,263 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import cv2
import numpy as np

from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec
from .utils import (generate_offset_heatmap, generate_onehot_heatmaps,
                    get_heatmap_maximum, refine_keypoints_dark_udp)


@KEYPOINT_CODECS.register_module()
class OneHotHeatmap(BaseKeypointCodec):
    r"""Generate keypoint heatmaps by Unbiased Data Processing (UDP).
    See the paper: `The Devil is in the Details: Delving into Unbiased Data
    Processing for Human Pose Estimation`_ by Huang et al (2020) for details.

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - image size: [w, h]
        - heatmap size: [W, H]

    Encoded:

        - heatmap (np.ndarray): The generated heatmap in shape (C_out, H, W)
          where [W, H] is the `heatmap_size`, and C_out is the output
          channel number which depends on the `heatmap_type`. If
          `heatmap_type=='gaussian'`, C_out equals the keypoint number K;
          if `heatmap_type=='combined'`, C_out equals K*3
          (x_offset, y_offset and class label)
        - keypoint_weights (np.ndarray): The target weights in shape (K,)

    Args:
        input_size (tuple): Image size in [w, h]
        heatmap_size (tuple): Heatmap size in [W, H]
        heatmap_type (str): The heatmap type to encode the keypoints. Options
            are:

            - ``'gaussian'``: Gaussian heatmap
            - ``'combined'``: Combination of a binary label map and offset
              maps for X and Y axes.

        sigma (float): The sigma value of the Gaussian heatmap when
            ``heatmap_type=='gaussian'``. Defaults to 2.0
        radius_factor (float): The radius factor of the binary label
            map when ``heatmap_type=='combined'``. The positive region is
            defined as the neighbor of the keypoint with the radius
            :math:`r=radius_factor*max(W, H)`. Defaults to 0.0546875
        blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
            modulation in DarkPose. Defaults to 11

    .. _`The Devil is in the Details: Delving into Unbiased Data Processing for
    Human Pose Estimation`: https://arxiv.org/abs/1911.07524
    """

    label_mapping_table = dict(keypoint_weights='keypoint_weights', )
    field_mapping_table = dict(heatmaps='heatmaps', )

    def __init__(self,
                 input_size: Tuple[int, int],
                 heatmap_size: Tuple[int, int],
                 heatmap_type: str = 'gaussian',
                 sigma: float = 2.,
                 radius_factor: float = 0.0546875,
                 blur_kernel_size: int = 11,
                 increase_sigma_with_padding=False,
                 amap_scale: float = 1.0,
                 normalize=None,
                 ) -> None:
        super().__init__()
        self.input_size = np.array(input_size)
        self.heatmap_size = np.array(heatmap_size)
        self.sigma = sigma
        self.radius_factor = radius_factor
        self.heatmap_type = heatmap_type
        self.blur_kernel_size = blur_kernel_size
        self.increase_sigma_with_padding = increase_sigma_with_padding
        self.normalize = normalize

        self.amap_size = self.input_size * amap_scale
        self.scale_factor = ((self.amap_size - 1) /
                             (self.heatmap_size - 1)).astype(np.float32)
        self.input_center = self.input_size / 2
        self.top_left = self.input_center - self.amap_size / 2

        if self.heatmap_type not in {'gaussian', 'combined'}:
            raise ValueError(
                f'{self.__class__.__name__} got invalid `heatmap_type` value'
                f'{self.heatmap_type}. Should be one of '
                '{"gaussian", "combined"}')

    def _kpts_to_activation_pts(self, keypoints: np.ndarray) -> np.ndarray:
        """Transform the keypoint coordinates to the activation space.

        In the original UDPHeatmap, the activation map shares the input image
        space at a different resolution; here the activation map may have a
        different (padded) size than the input image space. Centers of the
        activation map and the input image space are aligned.
        """
        transformed_keypoints = keypoints - self.top_left
        transformed_keypoints = transformed_keypoints / self.scale_factor
        return transformed_keypoints

    def _activation_pts_to_kpts(self, keypoints: np.ndarray) -> np.ndarray:
        """Transform points in the activation map to keypoint coordinates.

        In the original UDPHeatmap, the activation map shares the input image
        space at a different resolution; here the activation map may have a
        different (padded) size than the input image space. Centers of the
        activation map and the input image space are aligned.
        """
        W, H = self.heatmap_size
        transformed_keypoints = keypoints / [W - 1, H - 1] * self.amap_size
        transformed_keypoints += self.top_left
        return transformed_keypoints

    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None,
               id_similarity: Optional[float] = 0.0,
               keypoints_visibility: Optional[np.ndarray] = None) -> dict:
        """Encode keypoints into heatmaps. Note that the original keypoint
        coordinates should be in the input image space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K)
            id_similarity (float): The usefulness of the identity information
                for the whole pose. Defaults to 0.0
            keypoints_visibility (np.ndarray): The visibility bit for each
                keypoint (N, K). Defaults to None

        Returns:
            dict:
            - heatmap (np.ndarray): The generated heatmap in shape
              (C_out, H, W) where [W, H] is the `heatmap_size`, and
              C_out is the output channel number which depends on the
              `heatmap_type`. If `heatmap_type=='gaussian'`, C_out equals
              the keypoint number K; if `heatmap_type=='combined'`, C_out
              equals K*3 (x_offset, y_offset and class label)
            - keypoint_weights (np.ndarray): The target weights in shape
              (K,)
        """
        assert keypoints.shape[0] == 1, (
            f'{self.__class__.__name__} only support single-instance '
            'keypoint encoding')

        if keypoints_visibility is None:
            keypoints_visibility = np.zeros(
                keypoints.shape[:2], dtype=np.float32)

        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        if self.heatmap_type == 'gaussian':
            heatmaps, keypoint_weights = generate_onehot_heatmaps(
                heatmap_size=self.heatmap_size,
                keypoints=self._kpts_to_activation_pts(keypoints),
                keypoints_visible=keypoints_visible,
                sigma=self.sigma,
                keypoints_visibility=keypoints_visibility,
                increase_sigma_with_padding=self.increase_sigma_with_padding)
        elif self.heatmap_type == 'combined':
            heatmaps, keypoint_weights = generate_offset_heatmap(
                heatmap_size=self.heatmap_size,
                keypoints=self._kpts_to_activation_pts(keypoints),
                keypoints_visible=keypoints_visible,
                radius_factor=self.radius_factor)
        else:
            raise ValueError(
                f'{self.__class__.__name__} got invalid `heatmap_type` value'
                f'{self.heatmap_type}. Should be one of '
                '{"gaussian", "combined"}')

        if self.normalize is not None:
            heatmaps_sum = np.sum(heatmaps, axis=(1, 2), keepdims=False)
            mask = heatmaps_sum > 0
            heatmaps[mask, :, :] = heatmaps[mask, :, :] / (
                heatmaps_sum[mask, None, None] + np.finfo(np.float32).eps)
            heatmaps = heatmaps * self.normalize

        annotated = keypoints_visible > 0

        heatmap_keypoints = self._kpts_to_activation_pts(keypoints)
        in_image = np.logical_and(
            heatmap_keypoints[:, :, 0] >= 0,
            heatmap_keypoints[:, :, 0] < self.heatmap_size[0],
        )
        in_image = np.logical_and(
            in_image,
            heatmap_keypoints[:, :, 1] >= 0,
        )
        in_image = np.logical_and(
            in_image,
            heatmap_keypoints[:, :, 1] < self.heatmap_size[1],
        )

        encoded = dict(
            heatmaps=heatmaps,
            keypoint_weights=keypoint_weights,
            annotated=annotated,
            in_image=in_image,
            keypoints_scaled=keypoints,
            heatmap_keypoints=heatmap_keypoints,
            identification_similarity=id_similarity,
        )

        return encoded

    def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoint coordinates from heatmaps. The decoded keypoint
        coordinates are in the input image space.

        Args:
            encoded (np.ndarray): Heatmaps in shape (K, H, W)

        Returns:
            tuple:
            - keypoints (np.ndarray): Decoded keypoint coordinates in shape
              (N, K, D)
            - scores (np.ndarray): The keypoint scores in shape (N, K). It
              usually represents the confidence of the keypoint prediction
        """
        heatmaps = encoded.copy()

        if self.heatmap_type == 'gaussian':
            keypoints, scores = get_heatmap_maximum(heatmaps)
            # unsqueeze the instance dimension for single-instance results
            keypoints = keypoints[None]
            scores = scores[None]

            keypoints = refine_keypoints_dark_udp(
                keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size)

        elif self.heatmap_type == 'combined':
            _K, H, W = heatmaps.shape
            K = _K // 3

            for cls_heatmap in heatmaps[::3]:
                # Apply Gaussian blur on classification maps
                ks = 2 * self.blur_kernel_size + 1
                cv2.GaussianBlur(cls_heatmap, (ks, ks), 0, cls_heatmap)

            # valid radius
            radius = self.radius_factor * max(W, H)

            x_offset = heatmaps[1::3].flatten() * radius
            y_offset = heatmaps[2::3].flatten() * radius
            keypoints, scores = get_heatmap_maximum(heatmaps=heatmaps[::3])
            index = (keypoints[..., 0] + keypoints[..., 1] * W).flatten()
            index += W * H * np.arange(0, K)
            index = index.astype(int)
            keypoints += np.stack((x_offset[index], y_offset[index]), axis=-1)
            # unsqueeze the instance dimension for single-instance results
            keypoints = keypoints[None].astype(np.float32)
            scores = scores[None]

        keypoints = self._activation_pts_to_kpts(keypoints)

        return keypoints, scores
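For reference, a minimal encode sketch for the codec above. It assumes OneHotHeatmap is re-exported from mmpose.codecs like the other codecs in this commit; amap_scale=1.25 is an illustrative value that pads the activation map around the input crop.

import numpy as np
from mmpose.codecs import OneHotHeatmap  # assumed re-export

codec = OneHotHeatmap(
    input_size=(192, 256), heatmap_size=(48, 64), sigma=2.0, amap_scale=1.25)
keypoints = np.array([[[96., 128.]]], dtype=np.float32)

encoded = codec.encode(keypoints)
# `in_image` flags which keypoints land inside the (possibly padded) map.
print(encoded['in_image'], encoded['heatmaps'].shape)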
mmpose/codecs/regression_label.py
ADDED
@@ -0,0 +1,108 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Optional, Tuple

import numpy as np

from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec


@KEYPOINT_CODECS.register_module()
class RegressionLabel(BaseKeypointCodec):
    r"""Generate keypoint coordinates.

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - image size: [w, h]

    Encoded:

        - keypoint_labels (np.ndarray): The normalized regression labels in
          shape (N, K, D) where D is 2 for 2d coordinates
        - keypoint_weights (np.ndarray): The target weights in shape (N, K)

    Args:
        input_size (tuple): Input image size in [w, h]

    """

    label_mapping_table = dict(
        keypoint_labels='keypoint_labels',
        keypoint_weights='keypoint_weights',
    )

    def __init__(self, input_size: Tuple[int, int]) -> None:
        super().__init__()

        self.input_size = input_size

    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None) -> dict:
        """Encoding keypoints from input image space to normalized space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K)

        Returns:
            dict:
            - keypoint_labels (np.ndarray): The normalized regression labels
              in shape (N, K, D) where D is 2 for 2d coordinates
            - keypoint_weights (np.ndarray): The target weights in shape
              (N, K)
        """
        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        w, h = self.input_size
        valid = ((keypoints >= 0) &
                 (keypoints <= [w - 1, h - 1])).all(axis=-1) & (
                     keypoints_visible > 0.5)

        keypoint_labels = (keypoints / np.array([w, h])).astype(np.float32)
        keypoint_weights = np.where(valid, 1., 0.).astype(np.float32)

        encoded = dict(
            keypoint_labels=keypoint_labels, keypoint_weights=keypoint_weights)

        return encoded

    def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoint coordinates from normalized space to input image
        space.

        Args:
            encoded (np.ndarray): Coordinates in shape (N, K, D)

        Returns:
            tuple:
            - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D)
            - scores (np.ndarray): The keypoint scores in shape (N, K).
              It usually represents the confidence of the keypoint prediction
        """

        if encoded.shape[-1] == 2:
            N, K, _ = encoded.shape
            normalized_coords = encoded.copy()
            scores = np.ones((N, K), dtype=np.float32)
        elif encoded.shape[-1] == 4:
            # split coords and sigma if outputs contain output_sigma
            normalized_coords = encoded[..., :2].copy()
            output_sigma = encoded[..., 2:4].copy()

            scores = (1 - output_sigma).mean(axis=-1)
        else:
            raise ValueError(
                'Keypoint dimension should be 2 or 4 (with sigma), '
                f'but got {encoded.shape[-1]}')

        w, h = self.input_size
        keypoints = normalized_coords * np.array([w, h])

        return keypoints, scores
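For reference, a minimal round-trip sketch for the codec above, assuming it is re-exported from mmpose.codecs: coordinates are divided by the input size on encode and mapped back on decode.

import numpy as np
from mmpose.codecs import RegressionLabel

codec = RegressionLabel(input_size=(192, 256))
keypoints = np.array([[[96., 128.]]], dtype=np.float32)

encoded = codec.encode(keypoints)
print(encoded['keypoint_labels'])  # [[[0.5, 0.5]]]
decoded, scores = codec.decode(encoded['keypoint_labels'])  # back to [[[96., 128.]]]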
mmpose/codecs/simcc_label.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import product
from typing import Optional, Tuple, Union

import numpy as np

from mmpose.codecs.utils import get_simcc_maximum
from mmpose.codecs.utils.refinement import refine_simcc_dark
from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec


@KEYPOINT_CODECS.register_module()
class SimCCLabel(BaseKeypointCodec):
    r"""Generate keypoint representation via "SimCC" approach.
    See the paper: `SimCC: a Simple Coordinate Classification Perspective for
    Human Pose Estimation`_ by Li et al (2022) for more details.
    Old name: SimDR

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - image size: [w, h]

    Encoded:

        - keypoint_x_labels (np.ndarray): The generated SimCC label for x-axis.
            The label shape is (N, K, Wx) if ``smoothing_type=='gaussian'``
            and (N, K) if ``smoothing_type=='standard'``, where
            :math:`Wx=w*simcc_split_ratio`
        - keypoint_y_labels (np.ndarray): The generated SimCC label for y-axis.
            The label shape is (N, K, Wy) if ``smoothing_type=='gaussian'``
            and (N, K) if ``smoothing_type=='standard'``, where
            :math:`Wy=h*simcc_split_ratio`
        - keypoint_weights (np.ndarray): The target weights in shape (N, K)

    Args:
        input_size (tuple): Input image size in [w, h]
        smoothing_type (str): The SimCC label smoothing strategy. Options are
            ``'gaussian'`` and ``'standard'``. Defaults to ``'gaussian'``
        sigma (float | int | tuple): The sigma value in the Gaussian SimCC
            label. Defaults to 6.0
        simcc_split_ratio (float): The ratio of the label size to the input
            size. For example, if the input width is ``w``, the x label size
            will be :math:`w*simcc_split_ratio`. Defaults to 2.0
        label_smooth_weight (float): Label Smoothing weight. Defaults to 0.0
        normalize (bool): Whether to normalize the heatmaps. Defaults to True.
        use_dark (bool): Whether to use the DARK post processing. Defaults to
            False.
        decode_visibility (bool): Whether to decode the visibility. Defaults
            to False.
        decode_beta (float): The beta value for decoding visibility. Defaults
            to 150.0.

    .. _`SimCC: a Simple Coordinate Classification Perspective for Human Pose
    Estimation`: https://arxiv.org/abs/2107.03332
    """

    label_mapping_table = dict(
        keypoint_x_labels='keypoint_x_labels',
        keypoint_y_labels='keypoint_y_labels',
        keypoint_weights='keypoint_weights',
    )

    def __init__(
        self,
        input_size: Tuple[int, int],
        smoothing_type: str = 'gaussian',
        sigma: Union[float, int, Tuple[float]] = 6.0,
        simcc_split_ratio: float = 2.0,
        label_smooth_weight: float = 0.0,
        normalize: bool = True,
        use_dark: bool = False,
        decode_visibility: bool = False,
        decode_beta: float = 150.0,
    ) -> None:
        super().__init__()

        self.input_size = input_size
        self.smoothing_type = smoothing_type
        self.simcc_split_ratio = simcc_split_ratio
        self.label_smooth_weight = label_smooth_weight
        self.normalize = normalize
        self.use_dark = use_dark
        self.decode_visibility = decode_visibility
        self.decode_beta = decode_beta

        if isinstance(sigma, (float, int)):
            self.sigma = np.array([sigma, sigma])
        else:
            self.sigma = np.array(sigma)

        if self.smoothing_type not in {'gaussian', 'standard'}:
            raise ValueError(
                f'{self.__class__.__name__} got invalid `smoothing_type` '
                f'value {self.smoothing_type}. Should be one of '
                '{"gaussian", "standard"}')

        if self.smoothing_type == 'gaussian' and self.label_smooth_weight > 0:
            raise ValueError('Attribute `label_smooth_weight` is only '
                             'used for `standard` mode.')

        if self.label_smooth_weight < 0.0 or self.label_smooth_weight > 1.0:
            raise ValueError('`label_smooth_weight` should be in range [0, 1]')

    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None) -> dict:
        """Encoding keypoints into SimCC labels. Note that the original
        keypoint coordinates should be in the input image space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K)

        Returns:
            dict:
            - keypoint_x_labels (np.ndarray): The generated SimCC label for
                x-axis. The label shape is (N, K, Wx) if
                ``smoothing_type=='gaussian'`` and (N, K) if
                ``smoothing_type=='standard'``, where
                :math:`Wx=w*simcc_split_ratio`
            - keypoint_y_labels (np.ndarray): The generated SimCC label for
                y-axis. The label shape is (N, K, Wy) if
                ``smoothing_type=='gaussian'`` and (N, K) if
                ``smoothing_type=='standard'``, where
                :math:`Wy=h*simcc_split_ratio`
            - keypoint_weights (np.ndarray): The target weights in shape
                (N, K)
        """
        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        if self.smoothing_type == 'gaussian':
            x_labels, y_labels, keypoint_weights = self._generate_gaussian(
                keypoints, keypoints_visible)
        elif self.smoothing_type == 'standard':
            x_labels, y_labels, keypoint_weights = self._generate_standard(
                keypoints, keypoints_visible)
        else:
            raise ValueError(
                f'{self.__class__.__name__} got invalid `smoothing_type` '
                f'value {self.smoothing_type}. Should be one of '
                '{"gaussian", "standard"}')

        encoded = dict(
            keypoint_x_labels=x_labels,
            keypoint_y_labels=y_labels,
            keypoint_weights=keypoint_weights)

        return encoded

    def decode(self, simcc_x: np.ndarray,
               simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoint coordinates from SimCC representations. The decoded
        coordinates are in the input image space.

        Args:
            simcc_x (np.ndarray): SimCC label for x-axis
            simcc_y (np.ndarray): SimCC label for y-axis

        Returns:
            tuple:
            - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D)
            - scores (np.ndarray): The keypoint scores in shape (N, K).
                It usually represents the confidence of the keypoint prediction
        """

        keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)

        # Unsqueeze the instance dimension for single-instance results
        if keypoints.ndim == 2:
            keypoints = keypoints[None, :]
            scores = scores[None, :]

        if self.use_dark:
            x_blur = int((self.sigma[0] * 20 - 7) // 3)
            y_blur = int((self.sigma[1] * 20 - 7) // 3)
            x_blur -= int((x_blur % 2) == 0)
            y_blur -= int((y_blur % 2) == 0)
            keypoints[:, :, 0] = refine_simcc_dark(keypoints[:, :, 0], simcc_x,
                                                   x_blur)
            keypoints[:, :, 1] = refine_simcc_dark(keypoints[:, :, 1], simcc_y,
                                                   y_blur)

        keypoints /= self.simcc_split_ratio

        if self.decode_visibility:
            _, visibility = get_simcc_maximum(
                simcc_x * self.decode_beta * self.sigma[0],
                simcc_y * self.decode_beta * self.sigma[1],
                apply_softmax=True)
            return keypoints, (scores, visibility)
        else:
            return keypoints, scores

    def _map_coordinates(
        self,
        keypoints: np.ndarray,
        keypoints_visible: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Mapping keypoint coordinates into SimCC space."""

        keypoints_split = keypoints.copy()
        keypoints_split = np.around(keypoints_split * self.simcc_split_ratio)
        keypoints_split = keypoints_split.astype(np.int64)
        keypoint_weights = keypoints_visible.copy()

        return keypoints_split, keypoint_weights

    def _generate_standard(
        self,
        keypoints: np.ndarray,
        keypoints_visible: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Encoding keypoints into SimCC labels with Standard Label Smoothing
        strategy.

        Labels will be one-hot vectors if self.label_smooth_weight==0.0
        """

        N, K, _ = keypoints.shape
        w, h = self.input_size
        W = np.around(w * self.simcc_split_ratio).astype(int)
        H = np.around(h * self.simcc_split_ratio).astype(int)

        keypoints_split, keypoint_weights = self._map_coordinates(
            keypoints, keypoints_visible)

        target_x = np.zeros((N, K, W), dtype=np.float32)
        target_y = np.zeros((N, K, H), dtype=np.float32)

        for n, k in product(range(N), range(K)):
            # skip unlabeled keypoints
            if keypoints_visible[n, k] < 0.5:
                continue

            # get center coordinates
            mu_x, mu_y = keypoints_split[n, k].astype(np.int64)

            # detect abnormal coords and assign the weight 0
            if mu_x >= W or mu_y >= H or mu_x < 0 or mu_y < 0:
                keypoint_weights[n, k] = 0
                continue

            if self.label_smooth_weight > 0:
                target_x[n, k] = self.label_smooth_weight / (W - 1)
                target_y[n, k] = self.label_smooth_weight / (H - 1)

            target_x[n, k, mu_x] = 1.0 - self.label_smooth_weight
            target_y[n, k, mu_y] = 1.0 - self.label_smooth_weight

        return target_x, target_y, keypoint_weights

    def _generate_gaussian(
        self,
        keypoints: np.ndarray,
        keypoints_visible: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Encoding keypoints into SimCC labels with Gaussian Label Smoothing
        strategy."""

        N, K, _ = keypoints.shape
        w, h = self.input_size
        W = np.around(w * self.simcc_split_ratio).astype(int)
        H = np.around(h * self.simcc_split_ratio).astype(int)

        keypoints_split, keypoint_weights = self._map_coordinates(
            keypoints, keypoints_visible)

        target_x = np.zeros((N, K, W), dtype=np.float32)
        target_y = np.zeros((N, K, H), dtype=np.float32)

        # 3-sigma rule
        radius = self.sigma * 3

        # xy grid
        x = np.arange(0, W, 1, dtype=np.float32)
        y = np.arange(0, H, 1, dtype=np.float32)

        for n, k in product(range(N), range(K)):
            # skip unlabeled keypoints
            if keypoints_visible[n, k] < 0.5:
                continue

            mu = keypoints_split[n, k]

            # check that the gaussian has in-bounds part
            left, top = mu - radius
            right, bottom = mu + radius + 1

            if left >= W or top >= H or right < 0 or bottom < 0:
                keypoint_weights[n, k] = 0
                continue

            mu_x, mu_y = mu

            target_x[n, k] = np.exp(-((x - mu_x)**2) / (2 * self.sigma[0]**2))
            target_y[n, k] = np.exp(-((y - mu_y)**2) / (2 * self.sigma[1]**2))

        if self.normalize:
            norm_value = self.sigma * np.sqrt(np.pi * 2)
            target_x /= norm_value[0]
            target_y /= norm_value[1]

        return target_x, target_y, keypoint_weights
mmpose/codecs/spr.py
ADDED
@@ -0,0 +1,306 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor

from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec
from .utils import (batch_heatmap_nms, generate_displacement_heatmap,
                    generate_gaussian_heatmaps, get_diagonal_lengths,
                    get_instance_root)


@KEYPOINT_CODECS.register_module()
class SPR(BaseKeypointCodec):
    """Encode/decode keypoints with Structured Pose Representation (SPR).

    See the paper `Single-stage multi-person pose machines`_
    by Nie et al (2019) for details

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - image size: [w, h]
        - heatmap size: [W, H]

    Encoded:

        - heatmaps (np.ndarray): The generated heatmap in shape (1, H, W)
            where [W, H] is the `heatmap_size`. If the keypoint heatmap is
            generated together, the output heatmap shape is (K+1, H, W)
        - heatmap_weights (np.ndarray): The target weights for heatmaps which
            has same shape with heatmaps.
        - displacements (np.ndarray): The dense keypoint displacement in
            shape (K*2, H, W).
        - displacement_weights (np.ndarray): The target weights for
            displacements which has same shape with displacements.

    Args:
        input_size (tuple): Image size in [w, h]
        heatmap_size (tuple): Heatmap size in [W, H]
        sigma (float or tuple, optional): The sigma values of the Gaussian
            heatmaps. If sigma is a tuple, it includes both sigmas for root
            and keypoint heatmaps. ``None`` means the sigmas are computed
            automatically from the heatmap size. Defaults to ``None``
        generate_keypoint_heatmaps (bool): Whether to generate Gaussian
            heatmaps for each keypoint. Defaults to ``False``
        root_type (str): The method to generate the instance root. Options
            are:

            - ``'kpt_center'``: Average coordinate of all visible keypoints.
            - ``'bbox_center'``: Center point of bounding boxes outlined by
                all visible keypoints.

            Defaults to ``'kpt_center'``

        minimal_diagonal_length (int or float): The threshold of diagonal
            length of instance bounding box. Small instances will not be
            used in training. Defaults to 5
        background_weight (float): Loss weight of background pixels.
            Defaults to 0.1
        decode_thr (float): The threshold of keypoint response value in
            heatmaps. Defaults to 0.01
        decode_nms_kernel (int): The kernel size of the NMS during decoding,
            which should be an odd integer. Defaults to 5
        decode_max_instances (int): The maximum number of instances
            to decode. Defaults to 30

    .. _`Single-stage multi-person pose machines`:
        https://arxiv.org/abs/1908.09220
    """

    field_mapping_table = dict(
        heatmaps='heatmaps',
        heatmap_weights='heatmap_weights',
        displacements='displacements',
        displacement_weights='displacement_weights',
    )

    def __init__(
        self,
        input_size: Tuple[int, int],
        heatmap_size: Tuple[int, int],
        sigma: Optional[Union[float, Tuple[float]]] = None,
        generate_keypoint_heatmaps: bool = False,
        root_type: str = 'kpt_center',
        minimal_diagonal_length: Union[int, float] = 5,
        background_weight: float = 0.1,
        decode_nms_kernel: int = 5,
        decode_max_instances: int = 30,
        decode_thr: float = 0.01,
    ):
        super().__init__()

        self.input_size = input_size
        self.heatmap_size = heatmap_size
        self.generate_keypoint_heatmaps = generate_keypoint_heatmaps
        self.root_type = root_type
        self.minimal_diagonal_length = minimal_diagonal_length
        self.background_weight = background_weight
        self.decode_nms_kernel = decode_nms_kernel
        self.decode_max_instances = decode_max_instances
        self.decode_thr = decode_thr

        self.scale_factor = (np.array(input_size) /
                             heatmap_size).astype(np.float32)

        if sigma is None:
            sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 32
            if generate_keypoint_heatmaps:
                # sigma for root heatmap and keypoint heatmaps
                self.sigma = (sigma, sigma // 2)
            else:
                self.sigma = (sigma, )
        else:
            if not isinstance(sigma, (tuple, list)):
                sigma = (sigma, )
            if generate_keypoint_heatmaps:
                assert len(sigma) == 2, 'sigma for keypoints must be given ' \
                                        'if `generate_keypoint_heatmaps` ' \
                                        'is True. e.g. sigma=(4, 2)'
            self.sigma = sigma

    def _get_heatmap_weights(self,
                             heatmaps,
                             fg_weight: float = 1,
                             bg_weight: float = 0):
        """Generate weight array for heatmaps.

        Args:
            heatmaps (np.ndarray): Root and keypoint (optional) heatmaps
            fg_weight (float): Weight for foreground pixels. Defaults to 1.0
            bg_weight (float): Weight for background pixels. Defaults to 0.0

        Returns:
            np.ndarray: Heatmap weight array in the same shape with heatmaps
        """
        heatmap_weights = np.ones(heatmaps.shape, dtype=np.float32) * bg_weight
        heatmap_weights[heatmaps > 0] = fg_weight
        return heatmap_weights

    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None) -> dict:
        """Encode keypoints into root heatmaps and keypoint displacement
        fields. Note that the original keypoint coordinates should be in the
        input image space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K)

        Returns:
            dict:
            - heatmaps (np.ndarray): The generated heatmap in shape
                (1, H, W) where [W, H] is the `heatmap_size`. If keypoint
                heatmaps are generated together, the shape is (K+1, H, W)
            - heatmap_weights (np.ndarray): The pixel-wise weight for heatmaps
                which has same shape with `heatmaps`
            - displacements (np.ndarray): The generated displacement fields in
                shape (K*D, H, W). The vector on each pixel represents the
                displacement of keypoints belonging to the associated instance
                from this pixel.
            - displacement_weights (np.ndarray): The pixel-wise weight for
                displacements which has same shape with `displacements`
        """

        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        # keypoint coordinates in heatmap
        _keypoints = keypoints / self.scale_factor

        # compute the root and scale of each instance
        roots, roots_visible = get_instance_root(_keypoints, keypoints_visible,
                                                 self.root_type)
        diagonal_lengths = get_diagonal_lengths(_keypoints, keypoints_visible)

        # discard the small instances
        roots_visible[diagonal_lengths < self.minimal_diagonal_length] = 0

        # generate heatmaps
        heatmaps, _ = generate_gaussian_heatmaps(
            heatmap_size=self.heatmap_size,
            keypoints=roots[:, None],
            keypoints_visible=roots_visible[:, None],
            sigma=self.sigma[0])
        heatmap_weights = self._get_heatmap_weights(
            heatmaps, bg_weight=self.background_weight)

        if self.generate_keypoint_heatmaps:
            keypoint_heatmaps, _ = generate_gaussian_heatmaps(
                heatmap_size=self.heatmap_size,
                keypoints=_keypoints,
                keypoints_visible=keypoints_visible,
                sigma=self.sigma[1])

            keypoint_heatmaps_weights = self._get_heatmap_weights(
                keypoint_heatmaps, bg_weight=self.background_weight)

            heatmaps = np.concatenate((keypoint_heatmaps, heatmaps), axis=0)
            heatmap_weights = np.concatenate(
                (keypoint_heatmaps_weights, heatmap_weights), axis=0)

        # generate displacements
        displacements, displacement_weights = \
            generate_displacement_heatmap(
                self.heatmap_size,
                _keypoints,
                keypoints_visible,
                roots,
                roots_visible,
                diagonal_lengths,
                self.sigma[0],
            )

        encoded = dict(
            heatmaps=heatmaps,
            heatmap_weights=heatmap_weights,
            displacements=displacements,
            displacement_weights=displacement_weights)

        return encoded

    def decode(self, heatmaps: Tensor,
               displacements: Tensor) -> Tuple[np.ndarray, np.ndarray]:
        """Decode the keypoint coordinates from heatmaps and displacements.
        The decoded keypoint coordinates are in the input image space.

        Args:
            heatmaps (Tensor): Encoded root and keypoints (optional) heatmaps
                in shape (1, H, W) or (K+1, H, W)
            displacements (Tensor): Encoded keypoints displacement fields
                in shape (K*D, H, W)

        Returns:
            tuple:
            - keypoints (Tensor): Decoded keypoint coordinates in shape
                (N, K, D)
            - scores (tuple):
                - root_scores (Tensor): The root scores in shape (N, )
                - keypoint_scores (Tensor): The keypoint scores in
                    shape (N, K). If keypoint heatmaps are not generated,
                    `keypoint_scores` will be `None`
        """
        _k, h, w = displacements.shape
        k = _k // 2
        displacements = displacements.view(k, 2, h, w)

        # convert displacements to a dense keypoint prediction
        y, x = torch.meshgrid(torch.arange(h), torch.arange(w))
        regular_grid = torch.stack([x, y], dim=0).to(displacements)
        posemaps = (regular_grid[None] + displacements).flatten(2)

        # find local maximum on root heatmap
        root_heatmap_peaks = batch_heatmap_nms(heatmaps[None, -1:],
                                               self.decode_nms_kernel)
        root_scores, pos_idx = root_heatmap_peaks.flatten().topk(
            self.decode_max_instances)
        mask = root_scores > self.decode_thr
        root_scores, pos_idx = root_scores[mask], pos_idx[mask]

        keypoints = posemaps[:, :, pos_idx].permute(2, 0, 1).contiguous()

        if self.generate_keypoint_heatmaps and heatmaps.shape[0] == 1 + k:
            # compute scores for each keypoint
            keypoint_scores = self.get_keypoint_scores(heatmaps[:k], keypoints)
        else:
            keypoint_scores = None

        keypoints = torch.cat([
            kpt * self.scale_factor[i]
            for i, kpt in enumerate(keypoints.split(1, -1))
        ],
                              dim=-1)
        return keypoints, (root_scores, keypoint_scores)

    def get_keypoint_scores(self, heatmaps: Tensor, keypoints: Tensor):
        """Calculate the keypoint scores with keypoints heatmaps and
        coordinates.

        Args:
            heatmaps (Tensor): Keypoint heatmaps in shape (K, H, W)
            keypoints (Tensor): Keypoint coordinates in shape (N, K, D)

        Returns:
            Tensor: Keypoint scores in [N, K]
        """
        k, h, w = heatmaps.shape
        keypoints = torch.stack((
            keypoints[..., 0] / (w - 1) * 2 - 1,
            keypoints[..., 1] / (h - 1) * 2 - 1,
        ),
                                dim=-1)
        keypoints = keypoints.transpose(0, 1).unsqueeze(1).contiguous()

        keypoint_scores = torch.nn.functional.grid_sample(
            heatmaps.unsqueeze(1), keypoints,
            padding_mode='border').view(k, -1).transpose(0, 1).contiguous()

        return keypoint_scores
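Again not part of the commit, a sketch of encoding and decoding with the SPR codec above; the sizes and random keypoints are illustrative, and `SPR` is assumed to be exported from `mmpose.codecs`.

import numpy as np
import torch

from mmpose.codecs import SPR  # assumed export path

codec = SPR(input_size=(512, 512), heatmap_size=(128, 128), sigma=(4, 2),
            generate_keypoint_heatmaps=True)
# two instances with 17 keypoints each, in input-image coordinates (made up)
keypoints = (np.random.rand(2, 17, 2) * 512).astype(np.float32)
targets = codec.encode(keypoints)
# heatmaps: (K+1, 128, 128); displacements: (K*2, 128, 128)
heatmaps = torch.from_numpy(targets['heatmaps'])
displacements = torch.from_numpy(targets['displacements'])
decoded, (root_scores, kpt_scores) = codec.decode(heatmaps, displacements)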
mmpose/codecs/udp_heatmap.py
ADDED
@@ -0,0 +1,263 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import cv2
import numpy as np

from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec
from .utils import (generate_offset_heatmap, generate_udp_gaussian_heatmaps,
                    get_heatmap_maximum, refine_keypoints_dark_udp)


@KEYPOINT_CODECS.register_module()
class UDPHeatmap(BaseKeypointCodec):
    r"""Generate keypoint heatmaps by Unbiased Data Processing (UDP).
    See the paper: `The Devil is in the Details: Delving into Unbiased Data
    Processing for Human Pose Estimation`_ by Huang et al (2020) for details.

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - image size: [w, h]
        - heatmap size: [W, H]

    Encoded:

        - heatmap (np.ndarray): The generated heatmap in shape (C_out, H, W)
            where [W, H] is the `heatmap_size`, and the C_out is the output
            channel number which depends on the `heatmap_type`. If
            `heatmap_type=='gaussian'`, C_out equals to keypoint number K;
            if `heatmap_type=='combined'`, C_out equals to K*3
            (x_offset, y_offset and class label)
        - keypoint_weights (np.ndarray): The target weights in shape (K,)

    Args:
        input_size (tuple): Image size in [w, h]
        heatmap_size (tuple): Heatmap size in [W, H]
        heatmap_type (str): The heatmap type to encode the keypoints. Options
            are:

            - ``'gaussian'``: Gaussian heatmap
            - ``'combined'``: Combination of a binary label map and offset
                maps for X and Y axes.

        sigma (float): The sigma value of the Gaussian heatmap when
            ``heatmap_type=='gaussian'``. Defaults to 2.0
        radius_factor (float): The radius factor of the binary label
            map when ``heatmap_type=='combined'``. The positive region is
            defined as the neighbor of the keypoint with the radius
            :math:`r=radius_factor*max(W, H)`. Defaults to 0.0546875
        blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
            modulation in DarkPose. Defaults to 11

    .. _`The Devil is in the Details: Delving into Unbiased Data Processing for
    Human Pose Estimation`: https://arxiv.org/abs/1911.07524
    """

    label_mapping_table = dict(keypoint_weights='keypoint_weights', )
    field_mapping_table = dict(heatmaps='heatmaps', )

    def __init__(self,
                 input_size: Tuple[int, int],
                 heatmap_size: Tuple[int, int],
                 heatmap_type: str = 'gaussian',
                 sigma: float = 2.,
                 radius_factor: float = 0.0546875,
                 blur_kernel_size: int = 11,
                 increase_sigma_with_padding=False,
                 amap_scale: float = 1.0,
                 normalize=None,
                 ) -> None:
        super().__init__()
        self.input_size = np.array(input_size)
        self.heatmap_size = np.array(heatmap_size)
        self.sigma = sigma
        self.radius_factor = radius_factor
        self.heatmap_type = heatmap_type
        self.blur_kernel_size = blur_kernel_size
        self.increase_sigma_with_padding = increase_sigma_with_padding
        self.normalize = normalize

        self.amap_size = self.input_size * amap_scale
        self.scale_factor = ((self.amap_size - 1) /
                             (self.heatmap_size - 1)).astype(np.float32)
        self.input_center = self.input_size / 2
        self.top_left = self.input_center - self.amap_size / 2

        if self.heatmap_type not in {'gaussian', 'combined'}:
            raise ValueError(
                f'{self.__class__.__name__} got invalid `heatmap_type` '
                f'value {self.heatmap_type}. Should be one of '
                '{"gaussian", "combined"}')

    def _kpts_to_activation_pts(self, keypoints: np.ndarray) -> np.ndarray:
        """Transform the keypoint coordinates to the activation space.

        In the original UDPHeatmap, the activation map covers the same area as
        the input image at a different resolution. Here the activation map may
        also have a different (padded) size than the input image space; the
        centers of the activation map and the input image space are aligned.
        """
        transformed_keypoints = keypoints - self.top_left
        transformed_keypoints = transformed_keypoints / self.scale_factor
        return transformed_keypoints

    def _activation_pts_to_kpts(self, keypoints: np.ndarray) -> np.ndarray:
        """Transform points in the activation map to keypoint coordinates.

        In the original UDPHeatmap, the activation map covers the same area as
        the input image at a different resolution. Here the activation map may
        also have a different (padded) size than the input image space; the
        centers of the activation map and the input image space are aligned.
        """
        W, H = self.heatmap_size
        transformed_keypoints = keypoints / [W - 1, H - 1] * self.amap_size
        transformed_keypoints += self.top_left
        return transformed_keypoints

    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None,
               id_similarity: Optional[float] = 0.0,
               keypoints_visibility: Optional[np.ndarray] = None) -> dict:
        """Encode keypoints into heatmaps. Note that the original keypoint
        coordinates should be in the input image space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K)
            id_similarity (float): The usefulness of the identity information
                for the whole pose. Defaults to 0.0
            keypoints_visibility (np.ndarray): The visibility bit for each
                keypoint (N, K). Defaults to None

        Returns:
            dict:
            - heatmap (np.ndarray): The generated heatmap in shape
                (C_out, H, W) where [W, H] is the `heatmap_size`, and the
                C_out is the output channel number which depends on the
                `heatmap_type`. If `heatmap_type=='gaussian'`, C_out equals to
                keypoint number K; if `heatmap_type=='combined'`, C_out
                equals to K*3 (x_offset, y_offset and class label)
            - keypoint_weights (np.ndarray): The target weights in shape
                (K,)
        """
        assert keypoints.shape[0] == 1, (
            f'{self.__class__.__name__} only supports single-instance '
            'keypoint encoding')

        if keypoints_visibility is None:
            keypoints_visibility = np.zeros(
                keypoints.shape[:2], dtype=np.float32)

        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        if self.heatmap_type == 'gaussian':
            heatmaps, keypoint_weights = generate_udp_gaussian_heatmaps(
                heatmap_size=self.heatmap_size,
                keypoints=self._kpts_to_activation_pts(keypoints),
                keypoints_visible=keypoints_visible,
                sigma=self.sigma,
                keypoints_visibility=keypoints_visibility,
                increase_sigma_with_padding=self.increase_sigma_with_padding)
        elif self.heatmap_type == 'combined':
            heatmaps, keypoint_weights = generate_offset_heatmap(
                heatmap_size=self.heatmap_size,
                keypoints=self._kpts_to_activation_pts(keypoints),
                keypoints_visible=keypoints_visible,
                radius_factor=self.radius_factor)
        else:
            raise ValueError(
                f'{self.__class__.__name__} got invalid `heatmap_type` '
                f'value {self.heatmap_type}. Should be one of '
                '{"gaussian", "combined"}')

        if self.normalize is not None:
            heatmaps_sum = np.sum(heatmaps, axis=(1, 2), keepdims=False)
            mask = heatmaps_sum > 0
            heatmaps[mask, :, :] = heatmaps[mask, :, :] / (
                heatmaps_sum[mask, None, None] + np.finfo(np.float32).eps)
            heatmaps = heatmaps * self.normalize

        annotated = keypoints_visible > 0

        heatmap_keypoints = self._kpts_to_activation_pts(keypoints)
        in_image = np.logical_and(
            heatmap_keypoints[:, :, 0] >= 0,
            heatmap_keypoints[:, :, 0] < self.heatmap_size[0],
        )
        in_image = np.logical_and(
            in_image,
            heatmap_keypoints[:, :, 1] >= 0,
        )
        in_image = np.logical_and(
            in_image,
            heatmap_keypoints[:, :, 1] < self.heatmap_size[1],
        )

        encoded = dict(
            heatmaps=heatmaps,
            keypoint_weights=keypoint_weights,
            annotated=annotated,
            in_image=in_image,
            keypoints_scaled=keypoints,
            heatmap_keypoints=heatmap_keypoints,
            identification_similarity=id_similarity,
        )

        return encoded

    def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoint coordinates from heatmaps. The decoded keypoint
        coordinates are in the input image space.

        Args:
            encoded (np.ndarray): Heatmaps in shape (K, H, W)

        Returns:
            tuple:
            - keypoints (np.ndarray): Decoded keypoint coordinates in shape
                (N, K, D)
            - scores (np.ndarray): The keypoint scores in shape (N, K). It
                usually represents the confidence of the keypoint prediction
        """
        heatmaps = encoded.copy()

        if self.heatmap_type == 'gaussian':
            keypoints, scores = get_heatmap_maximum(heatmaps)
            # unsqueeze the instance dimension for single-instance results
            keypoints = keypoints[None]
            scores = scores[None]

            keypoints = refine_keypoints_dark_udp(
                keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size)

        elif self.heatmap_type == 'combined':
            _K, H, W = heatmaps.shape
            K = _K // 3

            for cls_heatmap in heatmaps[::3]:
                # Apply Gaussian blur on classification maps
                ks = 2 * self.blur_kernel_size + 1
                cv2.GaussianBlur(cls_heatmap, (ks, ks), 0, cls_heatmap)

            # valid radius
            radius = self.radius_factor * max(W, H)

            x_offset = heatmaps[1::3].flatten() * radius
            y_offset = heatmaps[2::3].flatten() * radius
            keypoints, scores = get_heatmap_maximum(heatmaps=heatmaps[::3])
            index = (keypoints[..., 0] + keypoints[..., 1] * W).flatten()
            index += W * H * np.arange(0, K)
            index = index.astype(int)
            keypoints += np.stack((x_offset[index], y_offset[index]), axis=-1)
            # unsqueeze the instance dimension for single-instance results
            keypoints = keypoints[None].astype(np.float32)
            scores = scores[None]

        keypoints = self._activation_pts_to_kpts(keypoints)

        return keypoints, scores
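A minimal sketch of the codec above (not part of the commit); the sizes and random keypoints are illustrative, and `UDPHeatmap` is assumed to be exported from `mmpose.codecs`.

import numpy as np

from mmpose.codecs import UDPHeatmap  # assumed export path

codec = UDPHeatmap(input_size=(192, 256), heatmap_size=(48, 64), sigma=2.0)
# a single instance (the codec asserts N == 1), made-up coordinates
keypoints = (np.random.rand(1, 17, 2) * [192, 256]).astype(np.float32)
encoded = codec.encode(keypoints)
# decoding the ground-truth heatmaps approximately recovers the keypoints
decoded, scores = codec.decode(encoded['heatmaps'])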
mmpose/codecs/utils/__init__.py
ADDED
@@ -0,0 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .camera_image_projection import (camera_to_image_coord, camera_to_pixel,
                                      pixel_to_camera)
from .gaussian_heatmap import (generate_3d_gaussian_heatmaps,
                               generate_gaussian_heatmaps,
                               generate_udp_gaussian_heatmaps,
                               generate_unbiased_gaussian_heatmaps,
                               generate_onehot_heatmaps)
from .instance_property import (get_diagonal_lengths, get_instance_bbox,
                                get_instance_root)
from .offset_heatmap import (generate_displacement_heatmap,
                             generate_offset_heatmap)
from .post_processing import (batch_heatmap_nms, gaussian_blur,
                              gaussian_blur1d, get_heatmap_3d_maximum,
                              get_heatmap_maximum, get_simcc_maximum,
                              get_simcc_normalized, get_heatmap_expected_value)
from .refinement import (refine_keypoints, refine_keypoints_dark,
                         refine_keypoints_dark_udp, refine_simcc_dark)
from .oks_map import generate_oks_maps

__all__ = [
    'generate_gaussian_heatmaps', 'generate_udp_gaussian_heatmaps',
    'generate_unbiased_gaussian_heatmaps', 'gaussian_blur',
    'get_heatmap_maximum', 'get_simcc_maximum', 'generate_offset_heatmap',
    'batch_heatmap_nms', 'refine_keypoints', 'refine_keypoints_dark',
    'refine_keypoints_dark_udp', 'generate_displacement_heatmap',
    'refine_simcc_dark', 'gaussian_blur1d', 'get_diagonal_lengths',
    'get_instance_root', 'get_instance_bbox', 'get_simcc_normalized',
    'camera_to_image_coord', 'camera_to_pixel', 'pixel_to_camera',
    'get_heatmap_3d_maximum', 'generate_3d_gaussian_heatmaps',
    'generate_oks_maps', 'get_heatmap_expected_value', 'generate_onehot_heatmaps'
]
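The re-exported utilities can also be used standalone; a small, hypothetical sketch with made-up inputs (it exercises two helpers defined further down in this commit):

import numpy as np

from mmpose.codecs.utils import (generate_gaussian_heatmaps,
                                 get_heatmap_maximum)

kpts = np.array([[[24.0, 32.0], [40.0, 10.0]]], dtype=np.float32)  # (1, 2, 2)
vis = np.ones((1, 2), dtype=np.float32)
heatmaps, weights = generate_gaussian_heatmaps(
    heatmap_size=(48, 64), keypoints=kpts, keypoints_visible=vis, sigma=2.0)
locs, vals = get_heatmap_maximum(heatmaps)  # peaks land at the keypoints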
mmpose/codecs/utils/camera_image_projection.py
ADDED
@@ -0,0 +1,102 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple

import numpy as np


def camera_to_image_coord(root_index: int, kpts_3d_cam: np.ndarray,
                          camera_param: Dict) -> Tuple[np.ndarray, np.ndarray]:
    """Project keypoints from camera space to image space and calculate
    factor.

    Args:
        root_index (int): Index for root keypoint.
        kpts_3d_cam (np.ndarray): Keypoint coordinates in camera space in
            shape (N, K, D).
        camera_param (dict): Parameters for the camera.

    Returns:
        tuple:
        - kpts_3d_image (np.ndarray): Keypoint coordinates in image space in
            shape (N, K, D).
        - factor (np.ndarray): The scaling factor that maps keypoints from
            image space to camera space in shape (N, ).
    """

    root = kpts_3d_cam[..., root_index, :]
    tl_kpt = root.copy()
    tl_kpt[..., :2] -= 1.0
    br_kpt = root.copy()
    br_kpt[..., :2] += 1.0
    tl_kpt = np.reshape(tl_kpt, (-1, 3))
    br_kpt = np.reshape(br_kpt, (-1, 3))
    fx, fy = camera_param['f'] / 1000.
    cx, cy = camera_param['c'] / 1000.

    tl2d = camera_to_pixel(tl_kpt, fx, fy, cx, cy)
    br2d = camera_to_pixel(br_kpt, fx, fy, cx, cy)

    rectangle_3d_size = 2.0
    kpts_3d_image = np.zeros_like(kpts_3d_cam)
    kpts_3d_image[..., :2] = camera_to_pixel(kpts_3d_cam.copy(), fx, fy, cx,
                                             cy)
    ratio = (br2d[..., 0] - tl2d[..., 0] + 0.001) / rectangle_3d_size
    factor = rectangle_3d_size / (br2d[..., 0] - tl2d[..., 0] + 0.001)
    kpts_3d_depth = ratio[:, None] * (
        kpts_3d_cam[..., 2] - kpts_3d_cam[..., root_index:root_index + 1, 2])
    kpts_3d_image[..., 2] = kpts_3d_depth
    return kpts_3d_image, factor


def camera_to_pixel(kpts_3d: np.ndarray,
                    fx: float,
                    fy: float,
                    cx: float,
                    cy: float,
                    shift: bool = False) -> np.ndarray:
    """Project keypoints from camera space to image space.

    Args:
        kpts_3d (np.ndarray): Keypoint coordinates in camera space.
        fx (float): x-coordinate of camera's focal length.
        fy (float): y-coordinate of camera's focal length.
        cx (float): x-coordinate of image center.
        cy (float): y-coordinate of image center.
        shift (bool): Whether to shift the depth by 1e-8 to avoid division
            by zero.

    Returns:
        pose_2d (np.ndarray): Projected keypoint coordinates in image space.
    """
    if not shift:
        pose_2d = kpts_3d[..., :2] / kpts_3d[..., 2:3]
    else:
        pose_2d = kpts_3d[..., :2] / (kpts_3d[..., 2:3] + 1e-8)
    pose_2d[..., 0] *= fx
    pose_2d[..., 1] *= fy
    pose_2d[..., 0] += cx
    pose_2d[..., 1] += cy
    return pose_2d


def pixel_to_camera(kpts_3d: np.ndarray, fx: float, fy: float, cx: float,
                    cy: float) -> np.ndarray:
    """Back-project keypoints from image (pixel) space to camera space.

    Args:
        kpts_3d (np.ndarray): Keypoint coordinates in pixel space with depth.
        fx (float): x-coordinate of camera's focal length.
        fy (float): y-coordinate of camera's focal length.
        cx (float): x-coordinate of image center.
        cy (float): y-coordinate of image center.

    Returns:
        pose_3d (np.ndarray): Back-projected keypoint coordinates in camera
            space.
    """
    pose_2d = kpts_3d.copy()
    pose_2d[..., 0] -= cx
    pose_2d[..., 1] -= cy
    pose_2d[..., 0] /= fx
    pose_2d[..., 1] /= fy
    pose_2d[..., 0] *= kpts_3d[..., 2]
    pose_2d[..., 1] *= kpts_3d[..., 2]
    return pose_2d
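A quick round-trip sanity check for the two projection helpers above (not part of the commit); the intrinsics and points are made up, and within this module the functions can be called directly:

import numpy as np

# made-up intrinsics and camera-space points for illustration
fx, fy, cx, cy = 1145.0, 1144.0, 512.0, 515.0
kpts_cam = np.array([[[0.3, -0.2, 4.0], [0.1, 0.5, 3.8]]])  # (1, 2, 3)

kpts_pix = kpts_cam.copy()
kpts_pix[..., :2] = camera_to_pixel(kpts_cam, fx, fy, cx, cy)
# back-projection restores the camera-space coordinates (depth is kept)
kpts_back = pixel_to_camera(kpts_pix, fx, fy, cx, cy)
assert np.allclose(kpts_back, kpts_cam)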
mmpose/codecs/utils/gaussian_heatmap.py
ADDED
@@ -0,0 +1,433 @@
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import product
from typing import Optional, Tuple, Union

import numpy as np
from scipy.spatial.distance import cdist


def generate_3d_gaussian_heatmaps(
    heatmap_size: Tuple[int, int, int],
    keypoints: np.ndarray,
    keypoints_visible: np.ndarray,
    sigma: Union[float, Tuple[float], np.ndarray],
    image_size: Tuple[int, int],
    heatmap3d_depth_bound: float = 400.0,
    joint_indices: Optional[list] = None,
    max_bound: float = 1.0,
    use_different_joint_weights: bool = False,
    dataset_keypoint_weights: Optional[np.ndarray] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """Generate 3d gaussian heatmaps of keypoints.

    Args:
        heatmap_size (Tuple[int, int, int]): Heatmap size in [W, H, D]
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, C)
        keypoints_visible (np.ndarray): Keypoint visibilities in shape
            (N, K)
        sigma (float or List[float]): A list of sigma values of the Gaussian
            heatmap for each instance. If sigma is given as a single float
            value, it will be expanded into a tuple
        image_size (Tuple[int, int]): Size of input image.
        heatmap3d_depth_bound (float): Boundary for 3d heatmap depth.
            Default: 400.0.
        joint_indices (List[int], optional): Indices of joints used for
            heatmap generation. If None (default) is given, all joints will
            be used. Default: ``None``.
        max_bound (float): The maximal value of heatmap. Default: 1.0.
        use_different_joint_weights (bool): Whether to use different joint
            weights. Default: ``False``.
        dataset_keypoint_weights (np.ndarray, optional): Keypoints weight in
            shape (K, ).

    Returns:
        tuple:
        - heatmaps (np.ndarray): The generated heatmap in shape
            (K * D, H, W) where [W, H, D] is the `heatmap_size`
        - keypoint_weights (np.ndarray): The target weights in shape
            (N, K)
    """

    W, H, D = heatmap_size

    # select the joints used for target generation
    if joint_indices is not None:
        keypoints = keypoints[:, joint_indices, ...]
        keypoints_visible = keypoints_visible[:, joint_indices, ...]
    N, K, _ = keypoints.shape

    heatmaps = np.zeros([K, D, H, W], dtype=np.float32)
    keypoint_weights = keypoints_visible.copy()

    if isinstance(sigma, (int, float)):
        sigma = (sigma, ) * N

    for n in range(N):
        # 3-sigma rule
        radius = sigma[n] * 3

        # joint location in heatmap coordinates
        mu_x = keypoints[n, :, 0] * W / image_size[0]  # (K, )
        mu_y = keypoints[n, :, 1] * H / image_size[1]
        mu_z = (keypoints[n, :, 2] / heatmap3d_depth_bound + 0.5) * D

        keypoint_weights[n, ...] = keypoint_weights[n, ...] * (mu_z >= 0) * (
            mu_z < D)
        if use_different_joint_weights:
            keypoint_weights[
                n] = keypoint_weights[n] * dataset_keypoint_weights
        # xy grid
        gaussian_size = 2 * radius + 1

        # get neighboring voxels coordinates
        x = y = z = np.arange(gaussian_size, dtype=np.float32) - radius
        zz, yy, xx = np.meshgrid(z, y, x)

        xx = np.expand_dims(xx, axis=0)
        yy = np.expand_dims(yy, axis=0)
        zz = np.expand_dims(zz, axis=0)
        mu_x = np.expand_dims(mu_x, axis=(-1, -2, -3))
        mu_y = np.expand_dims(mu_y, axis=(-1, -2, -3))
        mu_z = np.expand_dims(mu_z, axis=(-1, -2, -3))

        xx, yy, zz = xx + mu_x, yy + mu_y, zz + mu_z
        local_size = xx.shape[1]

        # round the coordinates
        xx = xx.round().clip(0, W - 1)
        yy = yy.round().clip(0, H - 1)
        zz = zz.round().clip(0, D - 1)

        # compute the target value near joints
        gaussian = np.exp(-((xx - mu_x)**2 + (yy - mu_y)**2 + (zz - mu_z)**2) /
                          (2 * sigma[n]**2))

        # put the local target value to the full target heatmap
        idx_joints = np.tile(
            np.expand_dims(np.arange(K), axis=(-1, -2, -3)),
            [1, local_size, local_size, local_size])
        idx = np.stack([idx_joints, zz, yy, xx],
                       axis=-1).astype(int).reshape(-1, 4)

        heatmaps[idx[:, 0], idx[:, 1], idx[:, 2], idx[:, 3]] = np.maximum(
            heatmaps[idx[:, 0], idx[:, 1], idx[:, 2], idx[:, 3]],
            gaussian.reshape(-1))

    heatmaps = (heatmaps * max_bound).reshape(-1, H, W)

    return heatmaps, keypoint_weights


def generate_gaussian_heatmaps(
    heatmap_size: Tuple[int, int],
    keypoints: np.ndarray,
    keypoints_visible: np.ndarray,
    sigma: Union[float, Tuple[float], np.ndarray],
) -> Tuple[np.ndarray, np.ndarray]:
    """Generate gaussian heatmaps of keypoints.

    Args:
        heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        keypoints_visible (np.ndarray): Keypoint visibilities in shape
            (N, K)
        sigma (float or List[float]): A list of sigma values of the Gaussian
            heatmap for each instance. If sigma is given as a single float
            value, it will be expanded into a tuple

    Returns:
        tuple:
        - heatmaps (np.ndarray): The generated heatmap in shape
            (K, H, W) where [W, H] is the `heatmap_size`
        - keypoint_weights (np.ndarray): The target weights in shape
            (N, K)
    """

    N, K, _ = keypoints.shape
    W, H = heatmap_size

    heatmaps = np.zeros((K, H, W), dtype=np.float32)
    keypoint_weights = keypoints_visible.copy()

    if isinstance(sigma, (int, float)):
        sigma = (sigma, ) * N

    for n in range(N):
        # 3-sigma rule
        radius = sigma[n] * 3

        # xy grid
        gaussian_size = 2 * radius + 1
        x = np.arange(0, gaussian_size, 1, dtype=np.float32)
        y = x[:, None]
        x0 = y0 = gaussian_size // 2

        for k in range(K):
            # skip unlabeled keypoints
            if keypoints_visible[n, k] < 0.5:
                continue

            # get gaussian center coordinates
            mu = (keypoints[n, k] + 0.5).astype(np.int64)

            # check that the gaussian has in-bounds part
            left, top = (mu - radius).astype(np.int64)
            right, bottom = (mu + radius + 1).astype(np.int64)

            if left >= W or top >= H or right < 0 or bottom < 0:
                keypoint_weights[n, k] = 0
                continue

            # The gaussian is not normalized,
            # we want the center value to equal 1
            gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma[n]**2))

            # valid range in gaussian
            g_x1 = max(0, -left)
            g_x2 = min(W, right) - left
            g_y1 = max(0, -top)
            g_y2 = min(H, bottom) - top

            # valid range in heatmap
            h_x1 = max(0, left)
            h_x2 = min(W, right)
            h_y1 = max(0, top)
            h_y2 = min(H, bottom)

            heatmap_region = heatmaps[k, h_y1:h_y2, h_x1:h_x2]
            gaussian_region = gaussian[g_y1:g_y2, g_x1:g_x2]

            _ = np.maximum(
                heatmap_region, gaussian_region, out=heatmap_region)

    return heatmaps, keypoint_weights


def generate_unbiased_gaussian_heatmaps(
    heatmap_size: Tuple[int, int],
    keypoints: np.ndarray,
    keypoints_visible: np.ndarray,
    sigma: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Generate gaussian heatmaps of keypoints using `Dark Pose`_.

    Args:
        heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        keypoints_visible (np.ndarray): Keypoint visibilities in shape
            (N, K)
        sigma (float): The sigma value of the Gaussian heatmap

    Returns:
        tuple:
        - heatmaps (np.ndarray): The generated heatmap in shape
            (K, H, W) where [W, H] is the `heatmap_size`
        - keypoint_weights (np.ndarray): The target weights in shape
            (N, K)

    .. _`Dark Pose`: https://arxiv.org/abs/1910.06278
    """

    N, K, _ = keypoints.shape
    W, H = heatmap_size

    heatmaps = np.zeros((K, H, W), dtype=np.float32)
    keypoint_weights = keypoints_visible.copy()

    # 3-sigma rule
    radius = sigma * 3

    # xy grid
    x = np.arange(0, W, 1, dtype=np.float32)
    y = np.arange(0, H, 1, dtype=np.float32)[:, None]

    for n, k in product(range(N), range(K)):
        # skip unlabeled keypoints
        if keypoints_visible[n, k] < 0.5:
            continue

        mu = keypoints[n, k]
        # check that the gaussian has in-bounds part
        left, top = mu - radius
        right, bottom = mu + radius + 1

        if left >= W or top >= H or right < 0 or bottom < 0:
            keypoint_weights[n, k] = 0
            continue

        gaussian = np.exp(-((x - mu[0])**2 + (y - mu[1])**2) / (2 * sigma**2))

        _ = np.maximum(gaussian, heatmaps[k], out=heatmaps[k])

    return heatmaps, keypoint_weights


def generate_udp_gaussian_heatmaps(
    heatmap_size: Tuple[int, int],
    keypoints: np.ndarray,
    keypoints_visible: np.ndarray,
    sigma,
    keypoints_visibility: np.ndarray,
    increase_sigma_with_padding: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Generate gaussian heatmaps of keypoints using `UDP`_.

    Args:
        heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
|
| 276 |
+
keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
|
| 277 |
+
keypoints_visible (np.ndarray): Keypoint visibilities in shape
|
| 278 |
+
(N, K)
|
| 279 |
+
sigma (float): The sigma value of the Gaussian heatmap
|
| 280 |
+
keypoints_visibility (np.ndarray): The visibility bit for each keypoint (N, K)
|
| 281 |
+
increase_sigma_with_padding (bool): Whether to increase the sigma
|
| 282 |
+
value with padding. Default: False
|
| 283 |
+
|
| 284 |
+
Returns:
|
| 285 |
+
tuple:
|
| 286 |
+
- heatmaps (np.ndarray): The generated heatmap in shape
|
| 287 |
+
(K, H, W) where [W, H] is the `heatmap_size`
|
| 288 |
+
- keypoint_weights (np.ndarray): The target weights in shape
|
| 289 |
+
(N, K)
|
| 290 |
+
|
| 291 |
+
.. _`UDP`: https://arxiv.org/abs/1911.07524
|
| 292 |
+
"""
|
| 293 |
+
|
| 294 |
+
N, K, _ = keypoints.shape
|
| 295 |
+
W, H = heatmap_size
|
| 296 |
+
|
| 297 |
+
heatmaps = np.zeros((K, H, W), dtype=np.float32)
|
| 298 |
+
keypoint_weights = keypoints_visible.copy()
|
| 299 |
+
|
| 300 |
+
if isinstance(sigma, (int, float)):
|
| 301 |
+
scaled_sigmas = sigma * np.ones((N, K), dtype=np.float32)
|
| 302 |
+
sigmas = np.array([sigma] * K).reshape(1, -1).repeat(N, axis=0)
|
| 303 |
+
else:
|
| 304 |
+
scaled_sigmas = np.array(sigma).reshape(1, -1).repeat(N, axis=0)
|
| 305 |
+
sigmas = np.array(sigma).reshape(1, -1).repeat(N, axis=0)
|
| 306 |
+
|
| 307 |
+
scales_arr = np.ones((N, K), dtype=np.float32)
|
| 308 |
+
if increase_sigma_with_padding:
|
| 309 |
+
diag = np.sqrt(W**2 + H**2)
|
| 310 |
+
for n in range(N):
|
| 311 |
+
image_kpts = keypoints[n, :].squeeze()
|
| 312 |
+
vis_kpts = image_kpts[keypoints_visibility[n, :] > 0.5]
|
| 313 |
+
|
| 314 |
+
# Compute the distance between img_kpts and visible_kpts
|
| 315 |
+
if vis_kpts.size == 0:
|
| 316 |
+
min_dists = np.ones(image_kpts.shape[0]) * diag
|
| 317 |
+
else:
|
| 318 |
+
dists = cdist(image_kpts, vis_kpts, metric='euclidean')
|
| 319 |
+
min_dists = np.min(dists, axis=1)
|
| 320 |
+
|
| 321 |
+
scales = min_dists / diag * 2.0 # Maximum distance (diagonal) results in .0*sigma
|
| 322 |
+
scales_arr[n, :] = scales
|
| 323 |
+
scaled_sigmas[n, :] = sigma * (1+scales)
|
| 324 |
+
|
| 325 |
+
# print(scales_arr)
|
| 326 |
+
# print(scaled_sigmas)
|
| 327 |
+
|
| 328 |
+
for n, k in product(range(N), range(K)):
|
| 329 |
+
scaled_sigma = scaled_sigmas[n, k]
|
| 330 |
+
# skip unlabled keypoints
|
| 331 |
+
if keypoints_visible[n, k] < 0.5:
|
| 332 |
+
continue
|
| 333 |
+
|
| 334 |
+
# 3-sigma rule
|
| 335 |
+
radius = scaled_sigma * 3
|
| 336 |
+
|
| 337 |
+
# xy grid
|
| 338 |
+
gaussian_size = 2 * radius + 1
|
| 339 |
+
x = np.arange(0, gaussian_size, 1, dtype=np.float32)
|
| 340 |
+
y = x[:, None]
|
| 341 |
+
|
| 342 |
+
mu = (keypoints[n, k] + 0.5).astype(np.int64)
|
| 343 |
+
# check that the gaussian has in-bounds part
|
| 344 |
+
left, top = (mu - radius).round().astype(np.int64)
|
| 345 |
+
right, bottom = (mu + radius + 1).round().astype(np.int64)
|
| 346 |
+
# left, top = (mu - radius).astype(np.int64)
|
| 347 |
+
# right, bottom = (mu + radius + 1).astype(np.int64)
|
| 348 |
+
|
| 349 |
+
if left >= W or top >= H or right < 0 or bottom < 0:
|
| 350 |
+
keypoint_weights[n, k] = 0
|
| 351 |
+
continue
|
| 352 |
+
|
| 353 |
+
mu_ac = keypoints[n, k]
|
| 354 |
+
x0 = y0 = gaussian_size // 2
|
| 355 |
+
x0 += mu_ac[0] - mu[0]
|
| 356 |
+
y0 += mu_ac[1] - mu[1]
|
| 357 |
+
gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * scaled_sigma**2))
|
| 358 |
+
|
| 359 |
+
# Normalize Gaussian such that scaled_sigma = sigma is the norm
|
| 360 |
+
gaussian = gaussian / (scaled_sigma / sigmas[n, k])
|
| 361 |
+
|
| 362 |
+
# valid range in gaussian
|
| 363 |
+
g_x1 = max(0, -left)
|
| 364 |
+
g_x2 = min(W, right) - left
|
| 365 |
+
g_y1 = max(0, -top)
|
| 366 |
+
g_y2 = min(H, bottom) - top
|
| 367 |
+
|
| 368 |
+
# valid range in heatmap
|
| 369 |
+
h_x1 = max(0, left)
|
| 370 |
+
h_x2 = min(W, right)
|
| 371 |
+
h_y1 = max(0, top)
|
| 372 |
+
h_y2 = min(H, bottom)
|
| 373 |
+
|
| 374 |
+
# breakpoint()
|
| 375 |
+
|
| 376 |
+
heatmap_region = heatmaps[k, h_y1:h_y2, h_x1:h_x2]
|
| 377 |
+
gaussian_regsion = gaussian[g_y1:g_y2, g_x1:g_x2]
|
| 378 |
+
|
| 379 |
+
_ = np.maximum(heatmap_region, gaussian_regsion, out=heatmap_region)
|
| 380 |
+
|
| 381 |
+
return heatmaps, keypoint_weights
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def generate_onehot_heatmaps(
|
| 385 |
+
heatmap_size: Tuple[int, int],
|
| 386 |
+
keypoints: np.ndarray,
|
| 387 |
+
keypoints_visible: np.ndarray,
|
| 388 |
+
sigma,
|
| 389 |
+
keypoints_visibility: np.ndarray,
|
| 390 |
+
increase_sigma_with_padding: bool = False,
|
| 391 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 392 |
+
"""Generate gaussian heatmaps of keypoints using `UDP`_.
|
| 393 |
+
|
| 394 |
+
Args:
|
| 395 |
+
heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
|
| 396 |
+
keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
|
| 397 |
+
keypoints_visible (np.ndarray): Keypoint visibilities in shape
|
| 398 |
+
(N, K)
|
| 399 |
+
sigma (float): The sigma value of the Gaussian heatmap
|
| 400 |
+
keypoints_visibility (np.ndarray): The visibility bit for each keypoint (N, K)
|
| 401 |
+
increase_sigma_with_padding (bool): Whether to increase the sigma
|
| 402 |
+
value with padding. Default: False
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
tuple:
|
| 406 |
+
- heatmaps (np.ndarray): The generated heatmap in shape
|
| 407 |
+
(K, H, W) where [W, H] is the `heatmap_size`
|
| 408 |
+
- keypoint_weights (np.ndarray): The target weights in shape
|
| 409 |
+
(N, K)
|
| 410 |
+
|
| 411 |
+
.. _`UDP`: https://arxiv.org/abs/1911.07524
|
| 412 |
+
"""
|
| 413 |
+
|
| 414 |
+
N, K, _ = keypoints.shape
|
| 415 |
+
W, H = heatmap_size
|
| 416 |
+
|
| 417 |
+
heatmaps = np.zeros((K, H, W), dtype=np.float32)
|
| 418 |
+
keypoint_weights = keypoints_visible.copy()
|
| 419 |
+
|
| 420 |
+
for n, k in product(range(N), range(K)):
|
| 421 |
+
# skip unlabled keypoints
|
| 422 |
+
if keypoints_visible[n, k] < 0.5:
|
| 423 |
+
continue
|
| 424 |
+
|
| 425 |
+
mu = (keypoints[n, k] + 0.5).astype(np.int64)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
if mu[0] < 0 or mu[0] >= W or mu[1] < 0 or mu[1] >= H:
|
| 429 |
+
keypoint_weights[n, k] = 0
|
| 430 |
+
continue
|
| 431 |
+
|
| 432 |
+
heatmaps[k, mu[1], mu[0]] = 1
|
| 433 |
+
return heatmaps, keypoint_weights
|
mmpose/codecs/utils/instance_property.py
ADDED
@@ -0,0 +1,111 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import numpy as np


def get_instance_root(keypoints: np.ndarray,
                      keypoints_visible: Optional[np.ndarray] = None,
                      root_type: str = 'kpt_center') -> np.ndarray:
    """Calculate the coordinates and visibility of instance roots.

    Args:
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        keypoints_visible (np.ndarray): Keypoint visibilities in shape
            (N, K)
        root_type (str): Calculation of instance roots which should
            be one of the following options:

            - ``'kpt_center'``: The roots' coordinates are the mean
                coordinates of visible keypoints
            - ``'bbox_center'``: The roots are the centers of bounding
                boxes outlined by visible keypoints

            Defaults to ``'kpt_center'``

    Returns:
        tuple
        - roots_coordinate(np.ndarray): Coordinates of instance roots in
            shape [N, D]
        - roots_visible(np.ndarray): Visibility of instance roots in
            shape [N]
    """

    roots_coordinate = np.zeros((keypoints.shape[0], 2), dtype=np.float32)
    roots_visible = np.ones((keypoints.shape[0]), dtype=np.float32) * 2

    for i in range(keypoints.shape[0]):

        # collect visible keypoints
        if keypoints_visible is not None:
            visible_keypoints = keypoints[i][keypoints_visible[i] > 0]
        else:
            visible_keypoints = keypoints[i]
        if visible_keypoints.size == 0:
            roots_visible[i] = 0
            continue

        # compute the instance root with visible keypoints
        if root_type == 'kpt_center':
            roots_coordinate[i] = visible_keypoints.mean(axis=0)
            roots_visible[i] = 1
        elif root_type == 'bbox_center':
            roots_coordinate[i] = (visible_keypoints.max(axis=0) +
                                   visible_keypoints.min(axis=0)) / 2.0
            roots_visible[i] = 1
        else:
            raise ValueError(
                f'the value of `root_type` must be \'kpt_center\' or '
                f'\'bbox_center\', but got \'{root_type}\'')

    return roots_coordinate, roots_visible


def get_instance_bbox(keypoints: np.ndarray,
                      keypoints_visible: Optional[np.ndarray] = None
                      ) -> np.ndarray:
    """Calculate the pseudo instance bounding box from visible keypoints. The
    bounding boxes are in the xyxy format.

    Args:
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        keypoints_visible (np.ndarray): Keypoint visibilities in shape
            (N, K)

    Returns:
        np.ndarray: bounding boxes in [N, 4]
    """
    bbox = np.zeros((keypoints.shape[0], 4), dtype=np.float32)
    for i in range(keypoints.shape[0]):
        if keypoints_visible is not None:
            visible_keypoints = keypoints[i][keypoints_visible[i] > 0]
        else:
            visible_keypoints = keypoints[i]
        if visible_keypoints.size == 0:
            continue

        bbox[i, :2] = visible_keypoints.min(axis=0)
        bbox[i, 2:] = visible_keypoints.max(axis=0)
    return bbox


def get_diagonal_lengths(keypoints: np.ndarray,
                         keypoints_visible: Optional[np.ndarray] = None
                         ) -> np.ndarray:
    """Calculate the diagonal length of the instance bounding box from
    visible keypoints.

    Args:
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        keypoints_visible (np.ndarray): Keypoint visibilities in shape
            (N, K)

    Returns:
        np.ndarray: bounding box diagonal length in [N]
    """
    pseudo_bbox = get_instance_bbox(keypoints, keypoints_visible)
    pseudo_bbox = pseudo_bbox.reshape(-1, 2, 2)
    h_w_diff = pseudo_bbox[:, 1] - pseudo_bbox[:, 0]
    diagonal_length = np.sqrt(np.power(h_w_diff, 2).sum(axis=1))

    return diagonal_length
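A small sketch of how these helpers compose on dummy data (the coordinates below are made up for illustration):

import numpy as np

keypoints = np.array([[[10., 10.], [30., 10.], [20., 40.]]])
visible = np.array([[1., 1., 1.]])

roots, roots_vis = get_instance_root(keypoints, visible,
                                     root_type='kpt_center')
# roots[0] == [20., 20.], the mean of the visible keypoints
bboxes = get_instance_bbox(keypoints, visible)
# bboxes[0] == [10., 10., 30., 40.] in xyxy format
diags = get_diagonal_lengths(keypoints, visible)
# diags[0] == sqrt(20**2 + 30**2) ~ 36.06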
mmpose/codecs/utils/offset_heatmap.py
ADDED
@@ -0,0 +1,143 @@
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import product
from typing import Tuple

import numpy as np


def generate_offset_heatmap(
    heatmap_size: Tuple[int, int],
    keypoints: np.ndarray,
    keypoints_visible: np.ndarray,
    radius_factor: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Generate offset heatmaps of keypoints, where each keypoint is
    represented by 3 maps: one pixel-level class label map (1 for keypoint and
    0 for non-keypoint) and 2 pixel-level offset maps for x and y directions
    respectively.

    Args:
        heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        keypoints_visible (np.ndarray): Keypoint visibilities in shape
            (N, K)
        radius_factor (float): The radius factor of the binary label
            map. The positive region is defined as the neighbor of the
            keypoint with the radius :math:`r=radius_factor*max(W, H)`

    Returns:
        tuple:
        - heatmap (np.ndarray): The generated heatmap in shape
            (K*3, H, W) where [W, H] is the `heatmap_size`
        - keypoint_weights (np.ndarray): The target weights in shape
            (K,)
    """

    N, K, _ = keypoints.shape
    W, H = heatmap_size

    heatmaps = np.zeros((K, 3, H, W), dtype=np.float32)
    keypoint_weights = keypoints_visible.copy()

    # xy grid
    x = np.arange(0, W, 1)
    y = np.arange(0, H, 1)[:, None]

    # positive area radius in the classification map
    radius = radius_factor * max(W, H)

    for n, k in product(range(N), range(K)):
        if keypoints_visible[n, k] < 0.5:
            continue

        mu = keypoints[n, k]

        x_offset = (mu[0] - x) / radius
        y_offset = (mu[1] - y) / radius

        heatmaps[k, 0] = np.where(x_offset**2 + y_offset**2 <= 1, 1., 0.)
        heatmaps[k, 1] = x_offset
        heatmaps[k, 2] = y_offset

    heatmaps = heatmaps.reshape(K * 3, H, W)

    return heatmaps, keypoint_weights


def generate_displacement_heatmap(
    heatmap_size: Tuple[int, int],
    keypoints: np.ndarray,
    keypoints_visible: np.ndarray,
    roots: np.ndarray,
    roots_visible: np.ndarray,
    diagonal_lengths: np.ndarray,
    radius: float,
):
    """Generate displacement heatmaps of keypoints, where each keypoint is
    represented by 2 pixel-level displacement maps for the x and y directions
    respectively, filled in around each instance root.

    Args:
        heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        keypoints_visible (np.ndarray): Keypoint visibilities in shape
            (N, K)
        roots (np.ndarray): Coordinates of instance centers in shape (N, D).
            The displacement fields of each instance will locate around its
            center.
        roots_visible (np.ndarray): Roots visibilities in shape (N,)
        diagonal_lengths (np.ndarray): Diagonal length of the bounding boxes
            of each instance in shape (N,)
        radius (float): The radius of the region around each instance root
            in which the displacement values are assigned

    Returns:
        tuple:
        - displacements (np.ndarray): The generated displacement map in
            shape (K*2, H, W) where [W, H] is the `heatmap_size`
        - displacement_weights (np.ndarray): The target weights in shape
            (K*2, H, W)
    """
    N, K, _ = keypoints.shape
    W, H = heatmap_size

    displacements = np.zeros((K * 2, H, W), dtype=np.float32)
    displacement_weights = np.zeros((K * 2, H, W), dtype=np.float32)
    instance_size_map = np.zeros((H, W), dtype=np.float32)

    for n in range(N):
        if (roots_visible[n] < 1 or (roots[n, 0] < 0 or roots[n, 1] < 0)
                or (roots[n, 0] >= W or roots[n, 1] >= H)):
            continue

        diagonal_length = diagonal_lengths[n]

        for k in range(K):
            if keypoints_visible[n, k] < 1 or keypoints[n, k, 0] < 0 \
                    or keypoints[n, k, 1] < 0 or keypoints[n, k, 0] >= W \
                    or keypoints[n, k, 1] >= H:
                continue

            start_x = max(int(roots[n, 0] - radius), 0)
            start_y = max(int(roots[n, 1] - radius), 0)
            end_x = min(int(roots[n, 0] + radius), W)
            end_y = min(int(roots[n, 1] + radius), H)

            for x in range(start_x, end_x):
                for y in range(start_y, end_y):
                    if (displacements[2 * k, y, x] != 0
                            or displacements[2 * k + 1, y, x] != 0):
                        if diagonal_length > instance_size_map[y, x]:
                            # keep the gt displacement of the smaller instance
                            continue

                    displacement_weights[2 * k:2 * k + 2, y,
                                         x] = 1 / diagonal_length
                    displacements[2 * k:2 * k + 2, y,
                                  x] = keypoints[n, k] - [x, y]
                    instance_size_map[y, x] = diagonal_length

    return displacements, displacement_weights
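A minimal sketch of encoding with `generate_offset_heatmap` and reading a keypoint back out of the three per-keypoint maps; the decode step below is an illustration of the representation, not a function defined in this file:

import numpy as np

keypoints = np.array([[[24.0, 31.0]]])  # N=1, K=1, D=2
visible = np.array([[1.0]])
W, H = 48, 64

maps, weights = generate_offset_heatmap((W, H), keypoints, visible,
                                        radius_factor=0.1)
cls_map, x_off, y_off = maps[0], maps[1], maps[2]

# at any pixel inside the positive region, the keypoint is recovered as
# pixel + offset * radius
radius = 0.1 * max(W, H)
ys, xs = np.nonzero(cls_map)
y, x = ys[0], xs[0]
print(x + x_off[y, x] * radius, y + y_off[y, x] * radius)  # ~ (24.0, 31.0)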
mmpose/codecs/utils/oks_map.py
ADDED
@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import product
from typing import Tuple

import numpy as np


def generate_oks_maps(
    heatmap_size: Tuple[int, int],
    keypoints: np.ndarray,
    keypoints_visible: np.ndarray,
    keypoints_visibility: np.ndarray,
    sigma: float = 0.55,
    increase_sigma_with_padding: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Generate OKS-shaped heatmaps of keypoints, where each visible keypoint
    is represented by a map of its Object Keypoint Similarity (OKS) with
    every pixel, rescaled so the peak value is 1.

    Args:
        heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        keypoints_visible (np.ndarray): Keypoint visibilities in shape
            (N, K)
        keypoints_visibility (np.ndarray): The visibility bit for each
            keypoint (N, K); currently unused
        sigma (float): If positive, used as a fixed spread for all keypoints
            instead of the per-keypoint OKS-based value. Default: 0.55
        increase_sigma_with_padding (bool): Whether to increase the sigma
            value with padding; currently unused. Default: False

    Returns:
        tuple:
        - heatmaps (np.ndarray): The generated heatmap in shape
            (K, H, W) where [W, H] is the `heatmap_size`
        - keypoint_weights (np.ndarray): The target weights in shape
            (N, K)
    """

    N, K, _ = keypoints.shape
    W, H = heatmap_size

    # The default per-keypoint sigmas are the COCO OKS sigmas.
    sigmas = np.array([
        2.6, 2.5, 2.5, 3.5, 3.5, 7.9, 7.9, 7.2, 7.2, 6.2, 6.2, 10.7, 10.7,
        8.7, 8.7, 8.9, 8.9
    ]) / 100

    heatmaps = np.zeros((K, H, W), dtype=np.float32)
    keypoint_weights = keypoints_visible.copy()

    # approximate the instance area from the heatmap size, assuming the
    # standard 1.25 bbox padding
    bbox_area = np.sqrt(H / 1.25 * W / 1.25)

    for n, k in product(range(N), range(K)):
        kpt_sigma = sigmas[k]
        # skip unlabeled keypoints
        if keypoints_visible[n, k] < 0.5:
            continue

        y_idx, x_idx = np.indices((H, W))
        dx = x_idx - keypoints[n, k, 0]
        dy = y_idx - keypoints[n, k, 1]
        dist = np.sqrt(dx**2 + dy**2)

        # OKS-style variance, clipped, optionally overridden by `sigma`
        vars = (kpt_sigma * 2)**2
        s = vars * bbox_area * 2
        s = np.clip(s, 0.55, 3.0)
        if sigma is not None and sigma > 0:
            s = sigma
        e_map = dist**2 / (2 * s)
        oks_map = np.exp(-e_map)

        keypoint_weights[n, k] = (oks_map.max() > 0).astype(int)

        # Scale such that there is always 1 at the maximum
        if oks_map.max() > 1e-3:
            oks_map = oks_map / oks_map.max()

        heatmaps[k] = oks_map

    return heatmaps, keypoint_weights
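A quick sketch of calling `generate_oks_maps`; note that the per-keypoint sigmas are hard-coded for the 17 COCO keypoints, so K must be 17 here. The sizes and random coordinates are illustrative assumptions:

import numpy as np

N, K = 1, 17
rng = np.random.default_rng(0)
keypoints = rng.uniform(5, 45, size=(N, K, 2)).astype(np.float32)
visible = np.ones((N, K), dtype=np.float32)

heatmaps, weights = generate_oks_maps(
    heatmap_size=(48, 64),
    keypoints=keypoints,
    keypoints_visible=visible,
    keypoints_visibility=visible,  # accepted but unused here
    sigma=0.55,  # fixed spread; <= 0 falls back to per-keypoint OKS variances
)
assert heatmaps.shape == (17, 64, 48)
assert np.allclose(heatmaps.max(axis=(1, 2)), 1.0)  # rescaled to peak at 1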
mmpose/codecs/utils/post_processing.py
ADDED
@@ -0,0 +1,530 @@
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import product
from typing import Tuple

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor

from scipy.signal import convolve2d


def get_simcc_normalized(batch_pred_simcc, sigma=None):
    """Normalize the predicted SimCC.

    Args:
        batch_pred_simcc (torch.Tensor): The predicted SimCC.
        sigma (float): The sigma of the Gaussian distribution.

    Returns:
        torch.Tensor: The normalized SimCC.
    """
    B, K, _ = batch_pred_simcc.shape

    # Scale and clamp the tensor
    if sigma is not None:
        batch_pred_simcc = batch_pred_simcc / (sigma * np.sqrt(np.pi * 2))
    batch_pred_simcc = batch_pred_simcc.clamp(min=0)

    # Compute the binary mask
    mask = (batch_pred_simcc.amax(dim=-1) > 1).reshape(B, K, 1)

    # Normalize the tensor using the maximum value
    norm = (batch_pred_simcc / batch_pred_simcc.amax(dim=-1).reshape(B, K, 1))

    # Apply normalization
    batch_pred_simcc = torch.where(mask, norm, batch_pred_simcc)

    return batch_pred_simcc


def get_simcc_maximum(simcc_x: np.ndarray,
                      simcc_y: np.ndarray,
                      apply_softmax: bool = False
                      ) -> Tuple[np.ndarray, np.ndarray]:
    """Get maximum response location and value from simcc representations.

    Note:
        instance number: N
        num_keypoints: K
        heatmap height: H
        heatmap width: W

    Args:
        simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
        simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
        apply_softmax (bool): whether to apply softmax on the heatmap.
            Defaults to False.

    Returns:
        tuple:
        - locs (np.ndarray): locations of maximum heatmap responses in shape
            (K, 2) or (N, K, 2)
        - vals (np.ndarray): values of maximum heatmap responses in shape
            (K,) or (N, K)
    """

    assert isinstance(simcc_x, np.ndarray), ('simcc_x should be numpy.ndarray')
    assert isinstance(simcc_y, np.ndarray), ('simcc_y should be numpy.ndarray')
    assert simcc_x.ndim == 2 or simcc_x.ndim == 3, (
        f'Invalid shape {simcc_x.shape}')
    assert simcc_y.ndim == 2 or simcc_y.ndim == 3, (
        f'Invalid shape {simcc_y.shape}')
    assert simcc_x.ndim == simcc_y.ndim, (
        f'{simcc_x.shape} != {simcc_y.shape}')

    if simcc_x.ndim == 3:
        N, K, Wx = simcc_x.shape
        simcc_x = simcc_x.reshape(N * K, -1)
        simcc_y = simcc_y.reshape(N * K, -1)
    else:
        N = None

    if apply_softmax:
        simcc_x = simcc_x - np.max(simcc_x, axis=1, keepdims=True)
        simcc_y = simcc_y - np.max(simcc_y, axis=1, keepdims=True)
        ex, ey = np.exp(simcc_x), np.exp(simcc_y)
        simcc_x = ex / np.sum(ex, axis=1, keepdims=True)
        simcc_y = ey / np.sum(ey, axis=1, keepdims=True)

    x_locs = np.argmax(simcc_x, axis=1)
    y_locs = np.argmax(simcc_y, axis=1)
    locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
    max_val_x = np.amax(simcc_x, axis=1)
    max_val_y = np.amax(simcc_y, axis=1)

    # keep the smaller of the two axis confidences
    mask = max_val_x > max_val_y
    max_val_x[mask] = max_val_y[mask]
    vals = max_val_x
    locs[vals <= 0.] = -1

    if N:
        locs = locs.reshape(N, K, 2)
        vals = vals.reshape(N, K)

    return locs, vals


def get_heatmap_3d_maximum(heatmaps: np.ndarray
                           ) -> Tuple[np.ndarray, np.ndarray]:
    """Get maximum response location and value from heatmaps.

    Note:
        batch_size: B
        num_keypoints: K
        heatmap dimension: D
        heatmap height: H
        heatmap width: W

    Args:
        heatmaps (np.ndarray): Heatmaps in shape (K, D, H, W) or
            (B, K, D, H, W)

    Returns:
        tuple:
        - locs (np.ndarray): locations of maximum heatmap responses in shape
            (K, 3) or (B, K, 3)
        - vals (np.ndarray): values of maximum heatmap responses in shape
            (K,) or (B, K)
    """
    assert isinstance(heatmaps,
                      np.ndarray), ('heatmaps should be numpy.ndarray')
    assert heatmaps.ndim == 4 or heatmaps.ndim == 5, (
        f'Invalid shape {heatmaps.shape}')

    if heatmaps.ndim == 4:
        K, D, H, W = heatmaps.shape
        B = None
        heatmaps_flatten = heatmaps.reshape(K, -1)
    else:
        B, K, D, H, W = heatmaps.shape
        heatmaps_flatten = heatmaps.reshape(B * K, -1)

    z_locs, y_locs, x_locs = np.unravel_index(
        np.argmax(heatmaps_flatten, axis=1), shape=(D, H, W))
    locs = np.stack((x_locs, y_locs, z_locs), axis=-1).astype(np.float32)
    vals = np.amax(heatmaps_flatten, axis=1)
    locs[vals <= 0.] = -1

    if B:
        locs = locs.reshape(B, K, 3)
        vals = vals.reshape(B, K)

    return locs, vals


def get_heatmap_maximum(heatmaps: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Get maximum response location and value from heatmaps.

    Note:
        batch_size: B
        num_keypoints: K
        heatmap height: H
        heatmap width: W

    Args:
        heatmaps (np.ndarray): Heatmaps in shape (K, H, W) or (B, K, H, W)

    Returns:
        tuple:
        - locs (np.ndarray): locations of maximum heatmap responses in shape
            (K, 2) or (B, K, 2)
        - vals (np.ndarray): values of maximum heatmap responses in shape
            (K,) or (B, K)
    """
    assert isinstance(heatmaps,
                      np.ndarray), ('heatmaps should be numpy.ndarray')
    assert heatmaps.ndim == 3 or heatmaps.ndim == 4, (
        f'Invalid shape {heatmaps.shape}')

    if heatmaps.ndim == 3:
        K, H, W = heatmaps.shape
        B = None
        heatmaps_flatten = heatmaps.reshape(K, -1)
    else:
        B, K, H, W = heatmaps.shape
        heatmaps_flatten = heatmaps.reshape(B * K, -1)

    y_locs, x_locs = np.unravel_index(
        np.argmax(heatmaps_flatten, axis=1), shape=(H, W))
    locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
    vals = np.amax(heatmaps_flatten, axis=1)
    locs[vals <= 0.] = -1

    if B:
        locs = locs.reshape(B, K, 2)
        vals = vals.reshape(B, K)

    return locs, vals


def gaussian_blur(heatmaps: np.ndarray, kernel: int = 11) -> np.ndarray:
    """Modulate heatmap distribution with Gaussian.

    Note:
        - num_keypoints: K
        - heatmap height: H
        - heatmap width: W

    Args:
        heatmaps (np.ndarray[K, H, W]): model predicted heatmaps.
        kernel (int): Gaussian kernel size for modulation, which should
            match the heatmap gaussian sigma when training.
            kernel=17 for sigma=3 and kernel=11 for sigma=2.

    Returns:
        np.ndarray ([K, H, W]): Modulated heatmap distribution.
    """
    assert kernel % 2 == 1

    border = (kernel - 1) // 2
    K, H, W = heatmaps.shape

    for k in range(K):
        origin_max = np.max(heatmaps[k])
        dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
        dr[border:-border, border:-border] = heatmaps[k].copy()
        dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
        heatmaps[k] = dr[border:-border, border:-border].copy()
        heatmaps[k] *= origin_max / (np.max(heatmaps[k]) + 1e-12)
    return heatmaps


def gaussian_blur1d(simcc: np.ndarray, kernel: int = 11) -> np.ndarray:
    """Modulate simcc distribution with Gaussian.

    Note:
        - num_instances: N
        - num_keypoints: K
        - simcc length: Wx

    Args:
        simcc (np.ndarray[N, K, Wx]): model predicted simcc.
        kernel (int): Gaussian kernel size for modulation, which should
            match the simcc gaussian sigma when training.
            kernel=17 for sigma=3 and kernel=11 for sigma=2.

    Returns:
        np.ndarray ([N, K, Wx]): Modulated simcc distribution.
    """
    assert kernel % 2 == 1

    border = (kernel - 1) // 2
    N, K, Wx = simcc.shape

    for n, k in product(range(N), range(K)):
        origin_max = np.max(simcc[n, k])
        dr = np.zeros((1, Wx + 2 * border), dtype=np.float32)
        dr[0, border:-border] = simcc[n, k].copy()
        dr = cv2.GaussianBlur(dr, (kernel, 1), 0)
        simcc[n, k] = dr[0, border:-border].copy()
        simcc[n, k] *= origin_max / np.max(simcc[n, k])
    return simcc


def batch_heatmap_nms(batch_heatmaps: Tensor, kernel_size: int = 5):
    """Apply NMS on a batch of heatmaps.

    Args:
        batch_heatmaps (Tensor): batch heatmaps in shape (B, K, H, W)
        kernel_size (int): The kernel size of the NMS which should be
            an odd integer. Defaults to 5

    Returns:
        Tensor: The batch heatmaps after NMS.
    """

    assert isinstance(kernel_size, int) and kernel_size % 2 == 1, \
        f'The kernel_size should be an odd integer, got {kernel_size}'

    padding = (kernel_size - 1) // 2

    maximum = F.max_pool2d(
        batch_heatmaps, kernel_size, stride=1, padding=padding)
    maximum_indicator = torch.eq(batch_heatmaps, maximum)
    batch_heatmaps = batch_heatmaps * maximum_indicator.float()

    return batch_heatmaps


def get_heatmap_expected_value(heatmaps: np.ndarray,
                               parzen_size: float = 0.1,
                               return_heatmap: bool = False
                               ) -> Tuple[np.ndarray, np.ndarray]:
    """Get the expected keypoint locations and values from heatmaps.

    Each heatmap is convolved with an OKS-shaped kernel; the maximum of the
    convolved map is then refined to sub-pixel precision.

    Note:
        batch_size: B
        num_keypoints: K
        heatmap height: H
        heatmap width: W

    Args:
        heatmaps (np.ndarray): Heatmaps in shape (K, H, W) or (B, K, H, W)
        parzen_size (float): Relative Parzen-window size in [0, 1]. It is
            validated but not used by the current implementation.
        return_heatmap (bool): Whether to also return the convolved heatmaps

    Returns:
        tuple:
        - locs (np.ndarray): locations of maximum heatmap responses in shape
            (K, 2) or (B, K, 2)
        - vals (np.ndarray): values of maximum heatmap responses in shape
            (K,) or (B, K)
    """
    assert isinstance(heatmaps,
                      np.ndarray), ('heatmaps should be numpy.ndarray')
    assert heatmaps.ndim == 3 or heatmaps.ndim == 4, (
        f'Invalid shape {heatmaps.shape}')

    assert parzen_size >= 0.0 and parzen_size <= 1.0, (
        f'Invalid parzen_size {parzen_size}')

    if heatmaps.ndim == 3:
        K, H, W = heatmaps.shape
        B = 1
        heatmaps_flatten = heatmaps.reshape(1, K, H, W)
    else:
        B, K, H, W = heatmaps.shape
        heatmaps_flatten = heatmaps.reshape(B, K, H, W)

    # approximate the instance area from the heatmap size, assuming the
    # standard 1.25 bbox padding
    bbox_area = np.sqrt(H / 1.25 * W / 1.25)

    # per-keypoint COCO OKS sigmas
    kpt_sigmas = np.array([
        2.6, 2.5, 2.5, 3.5, 3.5, 7.9, 7.9, 7.2, 7.2, 6.2, 6.2, 10.7, 10.7,
        8.7, 8.7, 8.9, 8.9
    ]) / 100

    heatmaps_convolved = np.zeros_like(heatmaps_flatten)
    for k in range(K):
        # build an OKS kernel whose spread follows the keypoint sigma
        vars = (kpt_sigmas[k] * 2)**2
        s = vars * bbox_area * 2
        s = np.clip(s, 0.55, 3.0)
        radius = int(np.ceil(s * 3))
        diameter = 2 * radius + 1
        center = diameter // 2
        dist_x = np.arange(diameter) - center
        dist_y = np.arange(diameter) - center
        dist_x, dist_y = np.meshgrid(dist_x, dist_y)
        dist = np.sqrt(dist_x**2 + dist_y**2)
        oks_kernel = np.exp(-dist**2 / (2 * s))
        oks_kernel = oks_kernel / oks_kernel.sum()

        htm = heatmaps_flatten[:, k, :, :].reshape(-1, H, W)
        htm_conv = np.zeros_like(htm)
        for b in range(B):
            # symmetric boundary avoids attenuating peaks near the border
            htm_conv[b, :, :] = convolve2d(
                htm[b, :, :], oks_kernel, mode='same', boundary='symm')
        heatmaps_convolved[:, k, :, :] = htm_conv

    heatmaps_convolved = heatmaps_convolved.reshape(B * K, H * W)
    y_locs, x_locs = np.unravel_index(
        np.argmax(heatmaps_convolved, axis=1), shape=(H, W))
    locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)

    # refine the maxima to sub-pixel precision
    locs = _get_subpixel_maximums(
        heatmaps_convolved.reshape(B * K, H, W), locs)

    # read the raw heatmap value at each (batch, keypoint) integer maximum
    x_locs_int = np.clip(np.round(x_locs).astype(int), 0, W - 1)
    y_locs_int = np.clip(np.round(y_locs).astype(int), 0, H - 1)
    b_idx = np.repeat(np.arange(B), K)
    k_idx = np.tile(np.arange(K), B)
    vals = heatmaps_flatten[b_idx, k_idx, y_locs_int, x_locs_int]

    if B > 1:
        locs = locs.reshape(B, K, 2)
        vals = vals.reshape(B, K)
        heatmaps_convolved = heatmaps_convolved.reshape(B, K, H, W)
    else:
        locs = locs.reshape(K, 2)
        vals = vals.reshape(K)
        heatmaps_convolved = heatmaps_convolved.reshape(K, H, W)

    if return_heatmap:
        return locs, vals, heatmaps_convolved
    else:
        return locs, vals


def _get_subpixel_maximums(heatmaps, locs):
    """Refine integer peak locations with a one-step Newton update along each
    axis."""
    # Extract integer peak locations
    x_locs = locs[:, 0].astype(np.int32)
    y_locs = locs[:, 1].astype(np.int32)

    # Ensure we are not near the boundaries (avoid boundary issues)
    valid_mask = (x_locs > 0) & (x_locs < heatmaps.shape[2] - 1) & \
                 (y_locs > 0) & (y_locs < heatmaps.shape[1] - 1)

    # Initialize the output array with the integer locations
    subpixel_locs = locs.copy()

    if np.any(valid_mask):
        # Extract valid locations
        x_locs_valid = x_locs[valid_mask]
        y_locs_valid = y_locs[valid_mask]

        # Compute gradients (dx, dy) and second derivatives (dxx, dyy)
        dx = (heatmaps[valid_mask, y_locs_valid, x_locs_valid + 1] -
              heatmaps[valid_mask, y_locs_valid, x_locs_valid - 1]) / 2.0
        dy = (heatmaps[valid_mask, y_locs_valid + 1, x_locs_valid] -
              heatmaps[valid_mask, y_locs_valid - 1, x_locs_valid]) / 2.0
        dxx = heatmaps[valid_mask, y_locs_valid, x_locs_valid + 1] + \
            heatmaps[valid_mask, y_locs_valid, x_locs_valid - 1] - \
            2 * heatmaps[valid_mask, y_locs_valid, x_locs_valid]
        dyy = heatmaps[valid_mask, y_locs_valid + 1, x_locs_valid] + \
            heatmaps[valid_mask, y_locs_valid - 1, x_locs_valid] - \
            2 * heatmaps[valid_mask, y_locs_valid, x_locs_valid]

        # Avoid division by zero with a minimum threshold on the curvature
        dxx = np.where(dxx != 0, dxx, 1e-6)
        dyy = np.where(dyy != 0, dyy, 1e-6)

        # Calculate the sub-pixel shift
        subpixel_x_shift = -dx / dxx
        subpixel_y_shift = -dy / dyy

        # Update subpixel locations for valid indices
        subpixel_locs[valid_mask, 0] += subpixel_x_shift
        subpixel_locs[valid_mask, 1] += subpixel_y_shift

    return subpixel_locs
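To close the loop, a short decoding sketch using the utilities above on a synthetically encoded target. It assumes these helpers are re-exported by `mmpose.codecs.utils` (an `__init__.py` for this package is part of this commit); otherwise import them from their defining modules:

import numpy as np
import torch

from mmpose.codecs.utils import (batch_heatmap_nms,
                                 generate_gaussian_heatmaps,
                                 get_heatmap_maximum)

keypoints = np.array([[[12.0, 20.0], [30.0, 40.0]]], dtype=np.float32)
visible = np.ones((1, 2), dtype=np.float32)

heatmaps, _ = generate_gaussian_heatmaps((48, 64), keypoints, visible,
                                         sigma=2.0)

# argmax decoding recovers the (rounded) encoded locations
locs, vals = get_heatmap_maximum(heatmaps)  # locs: (K, 2), vals: (K,)
# locs ~ [[12., 20.], [30., 40.]]

# for bottom-up decoding, keep only local maxima before peak picking
batch = torch.from_numpy(heatmaps)[None]    # (B=1, K, H, W)
peaks = batch_heatmap_nms(batch, kernel_size=5)
# non-peak responses in `peaks` are zeroed out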