Miroslav Purkrabek committed
Commit a249588 · 1 Parent(s): 4b8d5c5
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. CITATION.cff +25 -0
  2. LICENSE +674 -0
  3. app.py +262 -0
  4. configs/README.md +30 -0
  5. configs/bmp_D3.yaml +37 -0
  6. configs/bmp_J1.yaml +39 -0
  7. demo/bmp_demo.py +250 -0
  8. demo/demo_utils.py +705 -0
  9. demo/mm_utils.py +106 -0
  10. demo/posevis_lite.py +507 -0
  11. demo/sam2_utils.py +714 -0
  12. mmpose/__init__.py +27 -0
  13. mmpose/apis/__init__.py +16 -0
  14. mmpose/apis/inference.py +280 -0
  15. mmpose/apis/inference_3d.py +360 -0
  16. mmpose/apis/inference_tracking.py +103 -0
  17. mmpose/apis/inferencers/__init__.py +11 -0
  18. mmpose/apis/inferencers/base_mmpose_inferencer.py +691 -0
  19. mmpose/apis/inferencers/hand3d_inferencer.py +344 -0
  20. mmpose/apis/inferencers/mmpose_inferencer.py +250 -0
  21. mmpose/apis/inferencers/pose2d_inferencer.py +262 -0
  22. mmpose/apis/inferencers/pose3d_inferencer.py +457 -0
  23. mmpose/apis/inferencers/utils/__init__.py +5 -0
  24. mmpose/apis/inferencers/utils/default_det_models.py +36 -0
  25. mmpose/apis/inferencers/utils/get_model_alias.py +37 -0
  26. mmpose/apis/visualization.py +132 -0
  27. mmpose/codecs/__init__.py +25 -0
  28. mmpose/codecs/annotation_processors.py +100 -0
  29. mmpose/codecs/associative_embedding.py +522 -0
  30. mmpose/codecs/base.py +81 -0
  31. mmpose/codecs/decoupled_heatmap.py +274 -0
  32. mmpose/codecs/edpose_label.py +153 -0
  33. mmpose/codecs/hand_3d_heatmap.py +202 -0
  34. mmpose/codecs/image_pose_lifting.py +280 -0
  35. mmpose/codecs/integral_regression_label.py +121 -0
  36. mmpose/codecs/megvii_heatmap.py +147 -0
  37. mmpose/codecs/motionbert_label.py +240 -0
  38. mmpose/codecs/msra_heatmap.py +153 -0
  39. mmpose/codecs/onehot_heatmap.py +263 -0
  40. mmpose/codecs/regression_label.py +108 -0
  41. mmpose/codecs/simcc_label.py +311 -0
  42. mmpose/codecs/spr.py +306 -0
  43. mmpose/codecs/udp_heatmap.py +263 -0
  44. mmpose/codecs/utils/__init__.py +32 -0
  45. mmpose/codecs/utils/camera_image_projection.py +102 -0
  46. mmpose/codecs/utils/gaussian_heatmap.py +433 -0
  47. mmpose/codecs/utils/instance_property.py +111 -0
  48. mmpose/codecs/utils/offset_heatmap.py +143 -0
  49. mmpose/codecs/utils/oks_map.py +97 -0
  50. mmpose/codecs/utils/post_processing.py +530 -0
CITATION.cff ADDED
@@ -0,0 +1,25 @@
+ # CITATION.cff file for Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle
+ # This file provides metadata for the software and its preferred citation format.
+ cff-version: 1.2.0
+ message: "If you use this software, please cite it as below."
+ authors:
+   - family-names: Purkrabek
+     given-names: Miroslav
+   - family-names: Matas
+     given-names: Jiri
+ title: "Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle"
+ version: 1.0.0
+ date-released: 2025-06-20
+ preferred-citation:
+   type: conference-paper
+   authors:
+     - family-names: Purkrabek
+       given-names: Miroslav
+     - family-names: Matas
+       given-names: Jiri
+   collection-title: "Proceedings of the IEEE/CVF International Conference on Computer Vision"
+   month: 10
+   start: 1 # First page number
+   end: 8 # Last page number
+   title: "Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle"
+   year: 2025
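Since CITATION.cff is plain YAML, the citation metadata above can also be read programmatically. A minimal sketch, assuming PyYAML is installed and the file sits in the repository root:

```python
# Minimal sketch: read the citation metadata from CITATION.cff with PyYAML.
# Assumes the file lives in the repository root; adjust the path otherwise.
import yaml

with open("CITATION.cff", "r") as f:
    cff = yaml.safe_load(f)

authors = ", ".join(f"{a['given-names']} {a['family-names']}" for a in cff["authors"])
print(f"{authors}: {cff['title']} (v{cff['version']}, released {cff['date-released']})")
```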
LICENSE ADDED
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ [Lines 8-674: the remainder of the standard, unmodified GNU GPL v3 text (Preamble, Terms and Conditions 0-17, and "How to Apply These Terms to Your New Programs"), as published at https://www.gnu.org/licenses/gpl-3.0.txt]
app.py ADDED
@@ -0,0 +1,262 @@
+ import gradio as gr
+ import spaces
+
+ from pathlib import Path
+
+ import numpy as np
+ import yaml
+ from demo.demo_utils import DotDict, concat_instances, filter_instances, pose_nms, visualize_demo
+ from demo.mm_utils import run_MMDetector, run_MMPose
+ from mmdet.apis import init_detector
+ from demo.sam2_utils import prepare_model as prepare_sam2_model
+ from demo.sam2_utils import process_image_with_SAM
+
+ from mmpose.apis import init_model as init_pose_estimator
+ from mmpose.utils import adapt_mmdet_pipeline
+
+ # Default thresholds
+ DEFAULT_CAT_ID: int = 0
+
+ DEFAULT_BBOX_THR: float = 0.3
+ DEFAULT_NMS_THR: float = 0.3
+ DEFAULT_KPT_THR: float = 0.3
+
+ # Global model variables
+ det_model = None
+ pose_model = None
+ sam2_model = None
+
+ def _parse_yaml_config(yaml_path: Path) -> DotDict:
+     """
+     Load BMP configuration from a YAML file.
+
+     Args:
+         yaml_path (Path): Path to YAML config.
+     Returns:
+         DotDict: Nested config dictionary.
+     """
+     with open(yaml_path, "r") as f:
+         cfg = yaml.safe_load(f)
+     return DotDict(cfg)
+
+ def load_models(bmp_config):
+     device = 'cuda:0'
+
+     global det_model, pose_model, sam2_model
+
+     # build detector
+     det_model = init_detector(bmp_config.detector.det_config, bmp_config.detector.det_checkpoint, device='cpu')  # Detect on CPU because of installation issues on HF
+     det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)
+
+     # build pose estimator
+     pose_model = init_pose_estimator(
+         bmp_config.pose_estimator.pose_config,
+         bmp_config.pose_estimator.pose_checkpoint,
+         device=device,
+         cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
+     )
+
+     sam2_model = prepare_sam2_model(
+         model_cfg=bmp_config.sam2.sam2_config,
+         model_checkpoint=bmp_config.sam2.sam2_checkpoint,
+     )
+
+     return det_model, pose_model, sam2_model
+
+ @spaces.GPU(duration=60)
+ def process_image_with_BMP(
+     img: np.ndarray
+ ) -> tuple[np.ndarray, np.ndarray]:
+     """
+     Run the full BMP pipeline on a single image: detection, pose, SAM mask refinement, and visualization.
+
+     Args:
+         img (np.ndarray): Input image as an RGB array.
+     Returns:
+         tuple[np.ndarray, np.ndarray]: RGB visualizations of the first-iteration
+         RTMDet + MaskPose result and of the final BBoxMaskPose result.
+     """
+     bmp_config = _parse_yaml_config(Path("configs/bmp_D3.yaml"))
+     load_models(bmp_config)
+
+     # img: RGB -> BGR
+     img = img[..., ::-1]
+
+     img_for_detection = img.copy()
+     rtmdet_result = None
+     all_detections = None
+     for iteration in range(bmp_config.num_bmp_iters):
+
+         # Step 1: Detection
+         det_instances = run_MMDetector(
+             det_model,
+             img_for_detection,
+             det_cat_id=DEFAULT_CAT_ID,
+             bbox_thr=DEFAULT_BBOX_THR,
+             nms_thr=DEFAULT_NMS_THR,
+         )
+         if len(det_instances.bboxes) == 0:
+             continue
+
+         # Step 2: Pose estimation
+         pose_instances = run_MMPose(
+             pose_model,
+             img.copy(),
+             detections=det_instances,
+             kpt_thr=DEFAULT_KPT_THR,
+         )
+
+         # Restrict to the first 17 COCO keypoints
+         pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
+         pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
+         pose_instances.keypoints = np.concatenate(
+             [pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
+         )
+
+         # Step 3: Pose-NMS and SAM refinement
+         all_keypoints = (
+             pose_instances.keypoints
+             if all_detections is None
+             else np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
+         )
+         all_bboxes = (
+             pose_instances.bboxes
+             if all_detections is None
+             else np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
+         )
+         num_valid_kpts = np.sum(all_keypoints[:, :, 2] > bmp_config.sam2.prompting.confidence_thr, axis=1)
+         keep_indices = pose_nms(
+             DotDict({"confidence_thr": bmp_config.sam2.prompting.confidence_thr, "oks_thr": bmp_config.oks_nms_thr}),
+             image_kpts=all_keypoints,
+             image_bboxes=all_bboxes,
+             num_valid_kpts=num_valid_kpts,
+         )
+         keep_indices = sorted(keep_indices)  # Sort by original index
+         num_old_detections = 0 if all_detections is None else len(all_detections.bboxes)
+         keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
+         keep_old_indices = [i for i in keep_indices if i < num_old_detections]
+         if len(keep_new_indices) == 0:
+             continue
+         # Filter new detections and compute scores
+         new_dets = filter_instances(pose_instances, keep_new_indices)
+         new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)
+         old_dets = None
+         if len(keep_old_indices) > 0:
+             old_dets = filter_instances(all_detections, keep_old_indices)
+
+         new_detections = process_image_with_SAM(
+             DotDict(bmp_config.sam2.prompting),
+             img.copy(),
+             sam2_model,
+             new_dets,
+             old_dets if old_dets is not None else None,
+         )
+
+         # Merge detections
+         if all_detections is None:
+             all_detections = new_detections
+         else:
+             all_detections = concat_instances(all_detections, new_dets)
+
+         # Step 4: Visualization
+         img_for_detection, rtmdet_r, _ = visualize_demo(
+             img.copy(),
+             all_detections,
+         )
+
+         if iteration == 0:
+             rtmdet_result = rtmdet_r
+
+     _, _, bmp_result = visualize_demo(
+         img.copy(),
+         all_detections,
+     )
+
+     # img: BGR -> RGB
+     rtmdet_result = rtmdet_result[..., ::-1]
+     bmp_result = bmp_result[..., ::-1]
+
+     return rtmdet_result, bmp_result
+
+
+ with gr.Blocks() as app:
+     gr.Markdown("# BBoxMaskPose Image Demo")
+     gr.Markdown(
+         "Official demo for the paper **Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle** [ICCV 2025]."
+     )
+     gr.Markdown(
+         "For details, see the [project website](https://mirapurkrabek.github.io/BBox-Mask-Pose/) or the [arXiv paper](https://arxiv.org/abs/2412.01562). "
+         "The demo showcases the capabilities of the BBoxMaskPose framework on any image. "
+         "If you want to play around with parameters, use the [GitHub demo](https://github.com/MiraPurkrabek/BBoxMaskPose). "
+         "Please note that due to HuggingFace restrictions, the demo runs much slower than the GitHub implementation."
+     )
+
+     with gr.Row():
+         with gr.Column():
+             original_image_input = gr.Image(type="numpy", label="Original Image")
+             submit_button = gr.Button("Run Inference")
+
+         with gr.Column():
+             output_standard = gr.Image(type="numpy", label="RTMDet-L + MaskPose-B")
+
+         with gr.Column():
+             output_sahi_sliced = gr.Image(type="numpy", label="BBoxMaskPose")
+
+     gr.Examples(
+         label="OCHuman examples",
+         examples=[
+             ["examples/004806.jpg"],
+             ["examples/005056.jpg"],
+             ["examples/004981.jpg"],
+             ["examples/004655.jpg"],
+             ["examples/004684.jpg"],
+             ["examples/004974.jpg"],
+             ["examples/004983.jpg"],
+             ["examples/005017.jpg"],
+             ["examples/004849.jpg"],
+         ],
+         inputs=[
+             original_image_input,
+         ],
+         outputs=[output_standard, output_sahi_sliced],
+         fn=process_image_with_BMP,
+         cache_examples=True,
+     )
+     gr.Examples(
+         label="In-the-wild examples",
+         examples=[
+             ["examples/prochazka_MMA.jpg"],
+             ["examples/riner_judo.jpg"],
+             ["examples/tackle3.jpg"],
+             ["examples/tackle1.jpg"],
+             ["examples/tackle2.jpg"],
+             ["examples/tackle5.jpg"],
+             ["examples/floorball_SKV_3.jpg"],
+             ["examples/santa_o_crop.jpg"],
+             ["examples/floorball_SKV_2.jpg"],
+         ],
+         inputs=[
+             original_image_input,
+         ],
+         outputs=[output_standard, output_sahi_sliced],
+         fn=process_image_with_BMP,
+         cache_examples=True,
+     )
+
+     submit_button.click(
+         fn=process_image_with_BMP,
+         inputs=[
+             original_image_input,
+         ],
+         outputs=[output_standard, output_sahi_sliced],
+     )
+
+ # Launch the demo
+ app.launch()
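The Space wires process_image_with_BMP to a Gradio button, but the function itself only needs an RGB numpy array and returns two RGB visualizations. A rough sketch of calling it directly, assuming the repository root is the working directory, Pillow is available, and the app.launch() call at the bottom of app.py is guarded or commented out so the import does not start the UI:

```python
# Rough sketch: run the BMP pipeline from app.py without the Gradio UI.
# Assumes app.launch() is guarded/commented out before importing app.
import numpy as np
from PIL import Image

from app import process_image_with_BMP  # hypothetical direct import; see note above

img = np.array(Image.open("examples/004806.jpg").convert("RGB"))  # RGB in
rtmdet_vis, bmp_vis = process_image_with_BMP(img)                 # RGB out

Image.fromarray(rtmdet_vis).save("rtmdet_plus_maskpose.jpg")
Image.fromarray(bmp_vis).save("bboxmaskpose.jpg")
```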
configs/README.md ADDED
@@ -0,0 +1,30 @@
+ # Configuration Files Overview
+
+ This directory contains configuration files for reproducing experiments and running inference across different components of the BBoxMaskPose project.
+
+ ## Which configs are available?
+
+ Here you can find configs setting up the hyperparameters of the whole loop.
+ These are mainly:
+ - How to prompt SAM
+ - Which models to use (detection, pose, SAM)
+ - How to chain the models
+ - ...
+
+ For easier reference, the configs have the same names as in the supplementary material of the ICCV paper.
+ So, for example, config [**bmp_D3.yaml**](bmp_D3.yaml) is the prompting experiment used in the BMP loop.
+ For details, see Tabs. 6-8 of the supplementary.
+
+ ## Where are the appropriate configs?
+
+ - **/configs** (this folder)
+   - Hyperparameter configurations for the BMP loop experiments. Use these files to reproduce training and evaluation settings.
+
+ - **/mmpose/configs**
+   - Configuration files for MMPose, following the same format and structure as MMPose v1.3.1. Supports models, datasets, and training pipelines.
+
+ - **/sam2/configs**
+   - Configuration files for SAM2, matching the format and directory layout of the original SAM v2.1 repository. Use these for prompt-driven segmentation and related tasks.
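The configs in this folder are consumed as nested dictionaries with attribute access; both app.py and demo/bmp_demo.py load them with yaml.safe_load wrapped in DotDict. A small sketch of the same pattern, assuming the repository root as the working directory:

```python
# Small sketch: load a BMP loop config the same way app.py / demo/bmp_demo.py do.
from pathlib import Path

import yaml
from demo.demo_utils import DotDict  # attribute-style access to the nested dict

cfg = DotDict(yaml.safe_load(Path("configs/bmp_D3.yaml").read_text()))

print(cfg.num_bmp_iters)                     # 2
print(cfg.sam2.prompting.num_pos_keypoints)  # 6
```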
configs/bmp_D3.yaml ADDED
@@ -0,0 +1,37 @@
+ # BBoxMaskPose Hyperparameters from Experiment D3.
+ # For details, see the paper: https://arxiv.org/abs/2412.01562, Tab. 8 in the supplementary.
+
+ # This configuration is good for the BMP loop and was used for most of the experiments.
+ detector:
+   det_config: 'mmpose/configs/mmdet/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py'
+   det_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/rtmdet-ins-l-mask.pth'
+
+   # Detectors D and D' could be different.
+   det_prime_config: null
+   det_prime_checkpoint: null
+
+ pose_estimator:
+   pose_config: 'mmpose/configs/MaskPose/ViTb-multi_mask.py'
+   pose_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/MaskPose-b.pth'
+
+ sam2:
+   sam2_config: 'configs/samurai/sam2.1_hiera_b+.yaml' # Use SAMURAI as it has img_size 1024 (SAM-2.1 has 512)
+   sam2_checkpoint: 'models/SAM/sam2.1_hiera_base_plus.pt'
+   prompting:
+     batch: False
+     use_bbox: False
+     num_pos_keypoints: 6
+     num_pos_keypoints_if_crowd: 6
+     num_neg_keypoints: 0
+     confidence_thr: 0.3
+     visibility_thr: 0.3
+     selection_method: 'distance+confidence'
+     extend_bbox: False
+     pose_mask_consistency: False
+     crowd_by_max_iou: False # Determines whether the instance is in a multi-body scenario. If yes, use a different number of keypoints and NO BBOX. If not, use the bbox according to the 'use_bbox' argument.
+     crop: False
+     exclusive_masks: True
+     ignore_small_bboxes: False
+
+ num_bmp_iters: 2
+ oks_nms_thr: 0.8
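Inside the loop, the prompting block above is wrapped in a DotDict and handed straight to the SAM refinement step. The fragment below mirrors the corresponding call in app.py; img, sam2_model, new_dets, and old_dets are assumed to already exist as they do in that script's loop:

```python
# Excerpt-style sketch mirroring app.py: the D3 prompting block drives SAM refinement.
# img (BGR), sam2_model, new_dets, and old_dets are assumed to be in scope as in app.py.
from demo.demo_utils import DotDict
from demo.sam2_utils import process_image_with_SAM

prompt_cfg = DotDict(bmp_config.sam2.prompting)  # use_bbox, num_pos_keypoints, ...
new_detections = process_image_with_SAM(
    prompt_cfg,
    img.copy(),   # BGR image, as prepared in app.py
    sam2_model,
    new_dets,     # detections kept by pose-NMS in this iteration
    old_dets,     # previously accepted detections, or None
)
```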
configs/bmp_J1.yaml ADDED
@@ -0,0 +1,39 @@
1
+ # BBoxMaskPose Hyperparameters from Experiment J1.
2
+ # For details, see the paper: https://arxiv.org/abs/2412.01562, Tab 8. in the supplementary.
3
+
4
+ # This configuration is good for getting extra AP points when the estimates are already good.
5
+ # It is not recommended for the whole loop (as done here -- this is for the demo) but rather for
6
+ # the det-pose-sam-pose studied in Tab. 4.
7
+ detector:
8
+ det_config: 'mmpose/configs/mmdet/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py'
9
+ det_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/rtmdet-ins-l-mask.pth'
10
+
11
+ # Detectors D and D' could be different.
12
+ det_prime_config: null
13
+ det_prime_checkpoint: null
14
+
15
+ pose_estimator:
16
+ pose_config: 'mmpose/configs/MaskPose/ViTb-multi_mask.py'
17
+ pose_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/MaskPose-b.pth'
18
+
19
+ sam2:
20
+ sam2_config: 'configs/samurai/sam2.1_hiera_b+.yaml' # Use SAMURAI as it has img_size 1024 (SAM-2.1 has 512)
21
+ sam2_checkpoint: 'models/SAM/sam2.1_hiera_base_plus.pt'
22
+ prompting:
23
+ batch: True
24
+ use_bbox: False
25
+ num_pos_keypoints: 4
26
+ num_pos_keypoints_if_crowd: 6
27
+ num_neg_keypoints: 0
28
+ confidence_thr: 0.5
29
+ visibility_thr: 0.5
30
+ selection_method: 'distance+confidence'
31
+ extend_bbox: False
32
+ pose_mask_consistency: False
33
+ crowd_by_max_iou: 0.5 # Determines whether the instance is in a multi-body scenario. If yes, use a different number of keypoints and NO BBOX. If no, use the bbox according to the 'use_bbox' argument.
34
+ crop: False
35
+ exclusive_masks: True
36
+ ignore_small_bboxes: False
37
+
38
+ num_bmp_iters: 2
39
+ oks_nms_thr: 0.8
demo/bmp_demo.py ADDED
@@ -0,0 +1,250 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ """
3
+ BMP Demo script: sequentially runs detection, pose estimation, SAM-based mask refinement, and visualization.
4
+ Usage:
5
+ python bmp_demo.py <config.yaml> <input_image> [--output-root <dir>]
6
+ """
7
+
8
+ import os
9
+ import shutil
10
+ from argparse import ArgumentParser, Namespace
11
+ from pathlib import Path
12
+
13
+ import mmcv
14
+ import mmengine
15
+ import numpy as np
16
+ import yaml
17
+ from demo_utils import DotDict, concat_instances, create_GIF, filter_instances, pose_nms, visualize_itteration
18
+ from mm_utils import run_MMDetector, run_MMPose
19
+ from mmdet.apis import init_detector
20
+ from mmengine.logging import print_log
21
+ from mmengine.structures import InstanceData
22
+ from sam2_utils import prepare_model as prepare_sam2_model
23
+ from sam2_utils import process_image_with_SAM
24
+
25
+ from mmpose.apis import init_model as init_pose_estimator
26
+ from mmpose.utils import adapt_mmdet_pipeline
27
+
28
+ # Default thresholds
29
+ DEFAULT_DET_CAT_ID: int = 0 # "person"
30
+ DEFAULT_BBOX_THR: float = 0.3
31
+ DEFAULT_NMS_THR: float = 0.3
32
+ DEFAULT_KPT_THR: float = 0.3
33
+
34
+
35
+ def parse_args() -> Namespace:
36
+ """
37
+ Parse command-line arguments for BMP demo.
38
+
39
+ Returns:
40
+ Namespace: Contains bmp_config (Path), input (Path), output_root (Path), device (str).
41
+ """
42
+ parser = ArgumentParser(description="BBoxMaskPose demo")
43
+ parser.add_argument("bmp_config", type=Path, help="Path to BMP YAML config file")
44
+ parser.add_argument("input", type=Path, help="Input image file")
45
+ parser.add_argument("--output-root", type=Path, default=None, help="Directory to save outputs (default: ./outputs)")
46
+ parser.add_argument("--device", type=str, default="cuda:0", help="Device for inference (e.g., cuda:0 or cpu)")
47
+ parser.add_argument("--create-gif", action="store_true", default=False, help="Create GIF of all BMP iterations")
48
+ args = parser.parse_args()
49
+ if args.output_root is None:
50
+ args.output_root = Path(__file__).parent / "outputs"
51
+ return args
52
+
53
+
54
+ def parse_yaml_config(yaml_path: Path) -> DotDict:
55
+ """
56
+ Load BMP configuration from a YAML file.
57
+
58
+ Args:
59
+ yaml_path (Path): Path to YAML config.
60
+ Returns:
61
+ DotDict: Nested config dictionary.
62
+ """
63
+ with open(yaml_path, "r") as f:
64
+ cfg = yaml.safe_load(f)
65
+ return DotDict(cfg)
66
+
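+ # Usage sketch (illustrative path): cfg = parse_yaml_config(Path("configs/bmp_D3.yaml"))
+ # exposes nested settings as attributes, e.g. cfg.num_bmp_iters or cfg.sam2.prompting.num_pos_keypoints.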
67
+
68
+ def process_one_image(
69
+ args: Namespace,
70
+ bmp_config: DotDict,
71
+ img_path: Path,
72
+ detector: object,
73
+ detector_prime: object,
74
+ pose_estimator: object,
75
+ sam2_model: object,
76
+ ) -> InstanceData:
77
+ """
78
+ Run the full BMP pipeline on a single image: detection, pose, SAM mask refinement, and visualization.
79
+
80
+ Args:
81
+ args (Namespace): Parsed CLI arguments.
82
+ bmp_config (DotDict): Configuration parameters.
83
+ img_path (Path): Path to the input image.
84
+ detector: Primary MMDetection model.
85
+ detector_prime: Secondary MMDetection model for iterations.
86
+ pose_estimator: MMPose model for keypoint estimation.
87
+ sam2_model: SAM model for mask refinement.
88
+ Returns:
89
+ InstanceData: Final merged detections and refined masks.
90
+ """
91
+ # Load image
92
+ img = mmcv.imread(str(img_path), channel_order="bgr")
93
+ if img is None:
94
+ raise ValueError("Failed to read image from {}.".format(img_path))
95
+
96
+ # Prepare output directory
97
+ output_dir = os.path.join(args.output_root, img_path.stem)
98
+ shutil.rmtree(str(output_dir), ignore_errors=True)
99
+ mmengine.mkdir_or_exist(str(output_dir))
100
+
101
+ img_for_detection = img.copy()
102
+ all_detections = None
103
+ for iteration in range(bmp_config.num_bmp_iters):
104
+ print_log("BMP Iteration {}/{} started".format(iteration + 1, bmp_config.num_bmp_iters), logger="current")
105
+
106
+ # Step 1: Detection
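+ # The first pass uses detector D on the original image; later BMP iterations run D' on the image
+ # produced by the previous iteration, in which already-accepted instances have been masked out.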
107
+ det_instances = run_MMDetector(
108
+ detector if iteration == 0 else detector_prime,
109
+ img_for_detection,
110
+ det_cat_id=DEFAULT_DET_CAT_ID,
111
+ bbox_thr=DEFAULT_BBOX_THR,
112
+ nms_thr=DEFAULT_NMS_THR,
113
+ )
114
+ print_log("Detected {} instances".format(len(det_instances.bboxes)), logger="current")
115
+ if len(det_instances.bboxes) == 0:
116
+ print_log("No detections found, skipping.", logger="current")
117
+ continue
118
+
119
+ # Step 2: Pose estimation
120
+ pose_instances = run_MMPose(
121
+ pose_estimator,
122
+ img.copy(),
123
+ detections=det_instances,
124
+ kpt_thr=DEFAULT_KPT_THR,
125
+ )
126
+ # Restrict to first 17 COCO keypoints
127
+ pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
128
+ pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
129
+ pose_instances.keypoints = np.concatenate(
130
+ [pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
131
+ )
132
+
133
+ # Step 3: Pose-NMS and SAM refinement
134
+ all_keypoints = (
135
+ pose_instances.keypoints
136
+ if all_detections is None
137
+ else np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
138
+ )
139
+ all_bboxes = (
140
+ pose_instances.bboxes
141
+ if all_detections is None
142
+ else np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
143
+ )
144
+ num_valid_kpts = np.sum(all_keypoints[:, :, 2] > bmp_config.sam2.prompting.confidence_thr, axis=1)
145
+ keep_indices = pose_nms(
146
+ DotDict({"confidence_thr": bmp_config.sam2.prompting.confidence_thr, "oks_thr": bmp_config.oks_nms_thr}),
147
+ image_kpts=all_keypoints,
148
+ image_bboxes=all_bboxes,
149
+ num_valid_kpts=num_valid_kpts,
150
+ )
151
+ keep_indices = sorted(keep_indices) # Sort by original index
152
+ num_old_detections = 0 if all_detections is None else len(all_detections.bboxes)
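+ # Indices below num_old_detections refer to instances kept from previous iterations; the remaining
+ # indices are shifted back so they index into this iteration's pose_instances.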
153
+ keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
154
+ keep_old_indices = [i for i in keep_indices if i < num_old_detections]
155
+ if len(keep_new_indices) == 0:
156
+ print_log("No new instances passed pose NMS, skipping SAM refinement.", logger="current")
157
+ continue
158
+ # filter new detections and compute scores
159
+ new_dets = filter_instances(pose_instances, keep_new_indices)
160
+ new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)
161
+ old_dets = None
162
+ if len(keep_old_indices) > 0:
163
+ old_dets = filter_instances(all_detections, keep_old_indices)
164
+ print_log(
165
+ "Pose NMS reduced instances to {:d} ({:d}+{:d}) instances".format(
166
+ len(new_dets.bboxes) + num_old_detections, num_old_detections, len(new_dets.bboxes)
167
+ ),
168
+ logger="current",
169
+ )
170
+
171
+ new_detections = process_image_with_SAM(
172
+ DotDict(bmp_config.sam2.prompting),
173
+ img.copy(),
174
+ sam2_model,
175
+ new_dets,
176
+ old_dets if old_dets is not None else None,
177
+ )
178
+
179
+ # Merge detections
180
+ if all_detections is None:
181
+ all_detections = new_detections
182
+ else:
183
+ all_detections = concat_instances(all_detections, new_detections)
184
+
185
+ # Step 4: Visualization
186
+ img_for_detection = visualize_itteration(
187
+ img.copy(),
188
+ all_detections,
189
+ iteration_idx=iteration,
190
+ output_root=str(output_dir),
191
+ img_name=img_path.stem,
192
+ )
193
+ print_log("Iteration {} completed".format(iteration + 1), logger="current")
194
+
195
+ # Create GIF of iterations if requested
196
+ if args.create_gif:
197
+ image_file = os.path.join(output_dir, "{:s}.jpg".format(img_path.stem))
198
+ create_GIF(
199
+ img_path=str(image_file),
200
+ output_root=str(output_dir),
201
+ bmp_x=bmp_config.num_bmp_iters,
202
+ )
203
+ return all_detections
204
+
205
+
206
+ def main() -> None:
207
+ """
208
+ Entry point for the BMP demo: loads models and processes one image.
209
+ """
210
+ args = parse_args()
211
+ bmp_config = parse_yaml_config(args.bmp_config)
212
+
213
+ # Ensure output root exists
214
+ mmengine.mkdir_or_exist(str(args.output_root))
215
+
216
+ # build detectors
217
+ detector = init_detector(bmp_config.detector.det_config, bmp_config.detector.det_checkpoint, device=args.device)
218
+ detector.cfg = adapt_mmdet_pipeline(detector.cfg)
219
+ if (
220
+ bmp_config.detector.det_config == bmp_config.detector.det_prime_config
221
+ and bmp_config.detector.det_checkpoint == bmp_config.detector.det_prime_checkpoint
222
+ ) or (bmp_config.detector.det_prime_config is None or bmp_config.detector.det_prime_checkpoint is None):
223
+ print_log("Using the same detector as D and D'", logger="current")
224
+ detector_prime = detector
225
+ else:
226
+ detector_prime = init_detector(
227
+ bmp_config.detector.det_prime_config, bmp_config.detector.det_prime_checkpoint, device=args.device
228
+ )
229
+ detector_prime.cfg = adapt_mmdet_pipeline(detector_prime.cfg)
230
+ print_log("Using a different detector for D'", logger="current")
231
+
232
+ # build pose estimator
233
+ pose_estimator = init_pose_estimator(
234
+ bmp_config.pose_estimator.pose_config,
235
+ bmp_config.pose_estimator.pose_checkpoint,
236
+ device=args.device,
237
+ cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
238
+ )
239
+
240
+ sam2 = prepare_sam2_model(
241
+ model_cfg=bmp_config.sam2.sam2_config,
242
+ model_checkpoint=bmp_config.sam2.sam2_checkpoint,
243
+ )
244
+
245
+ # Run inference on one image
246
+ _ = process_one_image(args, bmp_config, args.input, detector, detector_prime, pose_estimator, sam2)
247
+
248
+
249
+ if __name__ == "__main__":
250
+ main()
demo/demo_utils.py ADDED
@@ -0,0 +1,705 @@
1
+ """
2
+ Utilities for the BMP demo:
3
+ - Visualization of detections, masks, and poses
4
+ - Mask and bounding-box processing
5
+ - Pose non-maximum suppression (NMS)
6
+ - Animated GIF creation of demo iterations
7
+ """
8
+
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import subprocess
13
+ from pathlib import Path
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ import cv2
17
+ import numpy as np
18
+ from mmengine.logging import print_log
19
+ from mmengine.structures import InstanceData
20
+ from pycocotools import mask as Mask
21
+ from sam2.distinctipy import get_colors
22
+ from tqdm import tqdm
23
+
24
+ ### Visualization hyperparameters
25
+ MIN_CONTOUR_AREA: int = 50
26
+ BBOX_WEIGHT: float = 0.9
27
+ MASK_WEIGHT: float = 0.6
28
+ BACK_MASK_WEIGHT: float = 0.6
29
+ POSE_WEIGHT: float = 0.8
30
+
31
+
32
+ """
33
+ posevis is our custom visualization library for pose estimation. For compatibility, we also provide a lite version that has fewer features but still reproduces the visualizations from the paper.
34
+ """
35
+ try:
36
+ from posevis import pose_visualization
37
+ except ImportError:
38
+ from posevis_lite import pose_visualization  # absolute import, as the demo scripts load demo_utils as a top-level module
39
+
40
+
41
+ class DotDict(dict):
42
+ """Dictionary with attribute access and nested dict wrapping."""
43
+
44
+ def __getattr__(self, name: str) -> any:
45
+ if name in self:
46
+ val = self[name]
47
+ if isinstance(val, dict):
48
+ val = DotDict(val)
49
+ self[name] = val
50
+ return val
51
+ raise AttributeError("No attribute named {!r}".format(name))
52
+
53
+ def __setattr__(self, name: str, value: any) -> None:
54
+ self[name] = value
55
+
56
+ def __delattr__(self, name: str) -> None:
57
+ if name in self:
58
+ del self[name]
59
+ else:
60
+ raise AttributeError("No attribute named {!r}".format(name))
61
+
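+ # Example: DotDict({"sam2": {"prompting": {"batch": False}}}).sam2.prompting.batch evaluates to False;
+ # nested plain dicts are wrapped into DotDict lazily on first attribute access.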
62
+
63
+ def filter_instances(instances: InstanceData, indices):
64
+ """
65
+ Return a new InstanceData containing only the entries of 'instances' at the given indices.
66
+ """
67
+ if instances is None:
68
+ return None
69
+ data = {}
70
+ # Attributes to filter
71
+ for attr in [
72
+ "bboxes",
73
+ "bbox_scores",
74
+ "keypoints",
75
+ "keypoint_scores",
76
+ "scores",
77
+ "pred_masks",
78
+ "refined_masks",
79
+ "sam_scores",
80
+ "sam_kpts",
81
+ ]:
82
+ if hasattr(instances, attr):
83
+ arr = getattr(instances, attr)
84
+ data[attr] = arr[indices] if arr is not None else None
85
+ return InstanceData(**data)
86
+
87
+
88
+ def concat_instances(instances1: InstanceData, instances2: InstanceData):
89
+ """
90
+ Concatenate two InstanceData objects along the first axis, preserving order.
91
+ If instances1 or instances2 is None, returns the other.
92
+ """
93
+ if instances1 is None:
94
+ return instances2
95
+ if instances2 is None:
96
+ return instances1
97
+ data = {}
98
+ for attr in [
99
+ "bboxes",
100
+ "bbox_scores",
101
+ "keypoints",
102
+ "keypoint_scores",
103
+ "scores",
104
+ "pred_masks",
105
+ "refined_masks",
106
+ "sam_scores",
107
+ "sam_kpts",
108
+ ]:
109
+ arr1 = getattr(instances1, attr, None)
110
+ arr2 = getattr(instances2, attr, None)
111
+ if arr1 is None and arr2 is None:
112
+ continue
113
+ if arr1 is None:
114
+ data[attr] = arr2
115
+ elif arr2 is None:
116
+ data[attr] = arr1
117
+ else:
118
+ data[attr] = np.concatenate([arr1, arr2], axis=0)
119
+ return InstanceData(**data)
120
+
121
+
122
+ def _visualize_predictions(
123
+ img: np.ndarray,
124
+ bboxes: np.ndarray,
125
+ scores: np.ndarray,
126
+ masks: List[Optional[List[np.ndarray]]],
127
+ poses: List[Optional[np.ndarray]],
128
+ vis_type: str = "mask",
129
+ mask_is_binary: bool = False,
130
+ ) -> Tuple[np.ndarray, np.ndarray]:
131
+ """
132
+ Render bounding boxes, segmentation masks, and poses on the input image.
133
+
134
+ Args:
135
+ img (np.ndarray): BGR image of shape (H, W, 3).
136
+ bboxes (np.ndarray): Array of bounding boxes [x, y, w, h].
137
+ scores (np.ndarray): Confidence scores for each bbox.
138
+ masks (List[Optional[List[np.ndarray]]]): Polygon masks per instance.
139
+ poses (List[Optional[np.ndarray]]): Keypoint arrays per instance.
140
+ vis_type (str): Flags for visualization types separated by '+'.
141
+ mask_is_binary (bool): Whether input masks are binary arrays.
142
+
143
+ Returns:
144
+ Tuple[np.ndarray, np.ndarray]: The visualized image and color map.
145
+ """
146
+ vis_types = vis_type.split("+")
147
+
148
+ # # Filter-out small detections to make the visualization more clear
149
+ # new_bboxes = []
150
+ # new_scores = []
151
+ # new_masks = []
152
+ # new_poses = []
153
+ # size_thr = img.shape[0] * img.shape[1] * 0.01
154
+ # for bbox, score, mask, pose in zip(bboxes, scores, masks, poses):
155
+ # area = mask.sum() # Assume binary mask. OK for demo purposes
156
+ # if area > size_thr:
157
+ # new_bboxes.append(bbox)
158
+ # new_scores.append(score)
159
+ # new_masks.append(mask)
160
+ # new_poses.append(pose)
161
+ # bboxes = np.array(new_bboxes)
162
+ # scores = np.array(new_scores)
163
+ # masks = new_masks
164
+ # poses = new_poses
165
+
166
+ if mask_is_binary:
167
+ poly_masks: List[Optional[List[np.ndarray]]] = []
168
+ for binary_mask in masks:
169
+ if binary_mask is not None:
170
+ contours, _ = cv2.findContours(
171
+ (binary_mask * 255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
172
+ )
173
+ polys = [cnt.flatten() for cnt in contours if cv2.contourArea(cnt) >= MIN_CONTOUR_AREA]
174
+ else:
175
+ polys = None
176
+ poly_masks.append(polys)
177
+ masks = poly_masks # type: ignore
178
+
179
+ # Exclude green, gray, black, and white from the palette as they are not distinctive
180
+ colors = (np.array(get_colors(len(bboxes), exclude_colors=[(0, 1, 0), (.5, .5, .5), (0, 0, 0), (1, 1, 1)], rng=0)) * 255).astype(
181
+ int
182
+ )
183
+
184
+
185
+ if "inv-mask" in vis_types:
186
+ stencil = np.zeros_like(img)
187
+
188
+ for bbox, score, mask_poly, pose, color in zip(bboxes, scores, masks, poses, colors):
189
+ bbox = _update_bbox_by_mask(list(map(int, bbox)), mask_poly, img.shape)
190
+ color_list = color.tolist()
191
+ img_copy = img.copy()
192
+
193
+ if "bbox" in vis_types:
194
+ x, y, w, h = bbox
195
+ cv2.rectangle(img_copy, (x, y), (x + w, y + h), color_list, 2)
196
+ img = cv2.addWeighted(img, 1 - BBOX_WEIGHT, img_copy, BBOX_WEIGHT, 0)
197
+
198
+ if mask_poly is not None and "mask" in vis_types:
199
+ for seg in mask_poly:
200
+ seg_pts = np.array(seg).reshape(-1, 1, 2).astype(int)
201
+ cv2.fillPoly(img_copy, [seg_pts], color_list)
202
+ img = cv2.addWeighted(img, 1 - MASK_WEIGHT, img_copy, MASK_WEIGHT, 0)
203
+
204
+ if mask_poly is not None and "mask-out" in vis_types:
205
+ for seg in mask_poly:
206
+ seg_pts = np.array(seg).reshape(-1, 1, 2).astype(int)
207
+ cv2.fillPoly(img, [seg_pts], (0, 0, 0))
208
+
209
+ if mask_poly is not None and "inv-mask" in vis_types:
210
+ for seg in mask_poly:
211
+ seg = np.array(seg).reshape(-1, 1, 2).astype(int)
212
+ if cv2.contourArea(seg) < MIN_CONTOUR_AREA:
213
+ continue
214
+ cv2.fillPoly(stencil, [seg], (255, 255, 255))
215
+
216
+ if pose is not None and "pose" in vis_types:
217
+ vis_img = pose_visualization(
218
+ img.copy(),
219
+ pose.reshape(-1, 3),
220
+ width_multiplier=8,
221
+ differ_individuals=True,
222
+ color=color_list,
223
+ keep_image_size=True,
224
+ )
225
+ img = cv2.addWeighted(img, 1 - POSE_WEIGHT, vis_img, POSE_WEIGHT, 0)
226
+
227
+ if "inv-mask" in vis_types:
228
+ img = cv2.addWeighted(img, 1 - BACK_MASK_WEIGHT, cv2.bitwise_and(img, stencil), BACK_MASK_WEIGHT, 0)
229
+
230
+ return img, colors
231
+
232
+
233
+ def visualize_itteration(
234
+ img: np.ndarray, detections: Any, iteration_idx: int, output_root: Path, img_name: str, with_text: bool = True
235
+ ) -> Optional[np.ndarray]:
236
+ """
237
+ Generate and save visualization images for each BMP iteration.
238
+
239
+ Args:
240
+ img (np.ndarray): Original input image.
241
+ detections: InstanceData containing bboxes, scores, masks, keypoints.
242
+ iteration_idx (int): Current iteration index (0-based).
243
+ output_root (Path): Directory to save output images.
244
+ img_name (str): Base name of the image without extension.
245
+ with_text (bool): Whether to overlay text labels.
246
+
247
+ Returns:
248
+ Optional[np.ndarray]: The masked-out image if generated, else None.
249
+ """
250
+ bboxes = detections.bboxes
251
+ scores = detections.scores
252
+ pred_masks = detections.pred_masks
253
+ refined_masks = detections.refined_masks
254
+ keypoints = detections.keypoints
255
+ sam_kpts = detections.sam_kpts
256
+
257
+ masked_out = None
258
+ for vis_def in [
259
+ {"type": "bbox+mask", "masks": pred_masks, "label": "Detector (out)"},
260
+ {"type": "inv-mask", "masks": pred_masks, "label": "MaskPose (in)"},
261
+ {"type": "inv-mask+pose", "masks": pred_masks, "label": "MaskPose (out)"},
262
+ {"type": "mask", "masks": refined_masks, "label": "SAM Masks"},
263
+ {"type": "mask-out", "masks": refined_masks, "label": "Mask-Out"},
264
+ {"type": "pose", "masks": refined_masks, "label": "Final Poses"},
265
+ ]:
266
+ vis_img, colors = _visualize_predictions(
267
+ img.copy(), bboxes, scores, vis_def["masks"], keypoints, vis_type=vis_def["type"], mask_is_binary=True
268
+ )
269
+ if vis_def["type"] == "mask-out":
270
+ masked_out = vis_img
271
+ if with_text:
272
+ label = "BMP {:d}x: {}".format(iteration_idx + 1, vis_def["label"])
273
+ cv2.putText(vis_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)
274
+ cv2.putText(vis_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
275
+ out_path = os.path.join(
276
+ output_root, "{}_iter{}_{}.jpg".format(img_name, iteration_idx + 1, vis_def["label"].replace(" ", "_"))
277
+ )
278
+ cv2.imwrite(str(out_path), vis_img)
279
+
280
+ # Show prompting keypoints
281
+ tmp_img = img.copy()
282
+ for i, _ in enumerate(bboxes):
283
+ if len(sam_kpts[i]) > 0:
284
+ instance_color = colors[i].astype(int).tolist()
285
+ for kpt in sam_kpts[i]:
286
+ cv2.drawMarker(
287
+ tmp_img,
288
+ (int(kpt[0]), int(kpt[1])),
289
+ instance_color,
290
+ markerType=cv2.MARKER_CROSS,
291
+ markerSize=20,
292
+ thickness=3,
293
+ )
294
+ # Write the keypoint confidence next to the marker
295
+ cv2.putText(
296
+ tmp_img,
297
+ f"{kpt[2]:.2f}",
298
+ (int(kpt[0]) + 10, int(kpt[1]) - 10),
299
+ cv2.FONT_HERSHEY_SIMPLEX,
300
+ 0.5,
301
+ instance_color,
302
+ 1,
303
+ cv2.LINE_AA,
304
+ )
305
+ if with_text:
306
+ text = "BMP {:d}x: SAM prompts".format(iteration_idx + 1)
307
+ cv2.putText(tmp_img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3, cv2.LINE_AA)
308
+ cv2.putText(tmp_img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2, cv2.LINE_AA)
309
+ cv2.imwrite("{:s}/{:s}_iter{:d}_prompting_kpts.jpg".format(output_root, img_name, iteration_idx + 1), tmp_img)
310
+
311
+ return masked_out
312
+
313
+
314
+ def visualize_demo(
315
+ img: np.ndarray, detections: Any,
316
+ ) -> List[np.ndarray]:
+ """
+ Generate demo visualizations comparing detector outputs with BMP-refined outputs.
+
+ Args:
+ img (np.ndarray): Original input image.
+ detections: InstanceData containing bboxes, scores, masks, keypoints.
+
+ Returns:
+ List[np.ndarray]: Rendered images in order: masked-out image, detector masks with poses ("RTMDet-L"), and refined masks with poses ("BMP").
+ """
331
+ bboxes = detections.bboxes
332
+ scores = detections.scores
333
+ pred_masks = detections.pred_masks
334
+ refined_masks = detections.refined_masks
335
+ keypoints = detections.keypoints
336
+
337
+ returns = []
338
+ for vis_def in [
339
+ {"type": "mask-out", "masks": refined_masks, "label": ""},
340
+ {"type": "mask+pose", "masks": pred_masks, "label": "RTMDet-L"},
341
+ {"type": "mask+pose", "masks": refined_masks, "label": "BMP"},
342
+ ]:
343
+ vis_img, colors = _visualize_predictions(
344
+ img.copy(), bboxes, scores, vis_def["masks"], keypoints, vis_type=vis_def["type"], mask_is_binary=True
345
+ )
346
+ returns.append(vis_img)
347
+
348
+ return returns
349
+
350
+
351
+ def create_GIF(
352
+ img_path: Path,
353
+ output_root: Path,
354
+ bmp_x: int = 2,
355
+ ) -> None:
356
+ """
357
+ Compile iteration images into an animated GIF using ffmpeg.
358
+
359
+ Args:
360
+ img_path (Path): Path to a sample iteration image.
361
+ output_root (Path): Directory to save the GIF.
362
+ bmp_x (int): Number of BMP iterations.
363
+
+ Note:
+ If ffmpeg is not installed or an expected image is missing, a warning is logged and no GIF is created.
+ """
368
+ display_dur = 1.5 # seconds
369
+ fade_dur = 1.0
370
+ fps = 10
371
+ scale_width = 300 # Resize width for GIF, height will be auto-scaled to maintain aspect ratio
372
+
373
+ # Check if ffmpeg is installed. If not, raise warning and return
374
+ if shutil.which("ffmpeg") is None:
375
+ print_log("FFMpeg is not installed. GIF creation will be skipped.", logger="current", level=logging.WARNING)
376
+ return
377
+ print_log("Creating GIF with FFmpeg...", logger="current")
378
+
379
+ dirname, filename = os.path.split(img_path)
380
+ img_name_wo_ext, _ = os.path.splitext(filename)
381
+
382
+ gif_image_names = [
383
+ "Detector_(out)",
384
+ "MaskPose_(in)",
385
+ "MaskPose_(out)",
386
+ "prompting_kpts",
387
+ "SAM_Masks",
388
+ "Mask-Out",
389
+ ]
390
+
391
+ # Create black image of the same size as the last image
392
+ last_img_path = os.path.join(dirname, "{}_iter1_{}".format(img_name_wo_ext, gif_image_names[0]) + ".jpg")
393
+ last_img = cv2.imread(last_img_path)
394
+ if last_img is None:
395
+ print_log("Could not read image {}.".format(last_img_path), logger="current", level=logging.ERROR)
396
+ return
397
+ black_img = np.zeros_like(last_img)
398
+ cv2.imwrite(os.path.join(dirname, "black_image.jpg"), black_img)
399
+
400
+ gif_images = []
401
+ for iter in range(bmp_x):
402
+ iter_img_path = os.path.join(dirname, "{}_iter{}_".format(img_name_wo_ext, iter + 1))
403
+ for img_name in gif_image_names:
404
+
405
+ if iter + 1 == bmp_x and img_name == "Mask-Out":
406
+ # Skip the last iteration's Mask-Out image
407
+ continue
408
+
409
+ img_file = "{}{}.jpg".format(iter_img_path, img_name)
410
+ if not os.path.exists(img_file):
411
+ print_log("{} does not exist, skipping.".format(img_file), logger="current", level=logging.WARNING)
412
+ continue
413
+ gif_images.append(img_file)
414
+
415
+ if len(gif_images) == 0:
416
+ print_log("No images found for GIF creation.", logger="current", level=logging.WARNING)
417
+ return
418
+
419
+ # Add 'before' and 'after' images
420
+ after1_img = os.path.join(dirname, "{}_iter{}_Final_Poses.jpg".format(img_name_wo_ext, bmp_x))
421
+ after2_img = os.path.join(dirname, "{}_iter{}_SAM_Masks.jpg".format(img_name_wo_ext, bmp_x))
422
+ # gif_images.append(os.path.join(dirname, "black_image.jpg")) # Add black image at the end
423
+ gif_images.append(after1_img)
424
+ gif_images.append(after2_img)
425
+ gif_images.append(os.path.join(dirname, "black_image.jpg")) # Add black image at the end
426
+
427
+ # Create a GIF from the images
428
+ gif_output_path = os.path.join(output_root, "{}_bmp_{}x.gif".format(img_name_wo_ext, bmp_x))
429
+
430
+ # 0. Make sure images exist and are divisible by 2
431
+ for img in gif_images:
432
+ if not os.path.exists(img):
433
+ print_log("Image {} does not exist, skipping GIF creation.".format(img), logger="current", level=logging.WARNING)
434
+ return
435
+ # Check if image dimensions are divisible by 2
436
+ img_data = cv2.imread(img)
437
+ if img_data.shape[1] % 2 != 0 or img_data.shape[0] % 2 != 0:
438
+ print_log(
439
+ "Image {} dimensions are not divisible by 2, resizing.".format(img),
440
+ logger="current",
441
+ level=logging.WARNING,
442
+ )
443
+ resized_img = cv2.resize(img_data, (img_data.shape[1] // 2 * 2, img_data.shape[0] // 2 * 2))
444
+ cv2.imwrite(img, resized_img)
445
+
446
+ # 1. inputs
447
+ in_args = []
448
+ for p in gif_images:
449
+ in_args += ["-loop", "1", "-t", str(display_dur), "-i", p]
450
+
451
+ # 2. build xfade chain
452
+ n = len(gif_images)
453
+ parts = []
454
+ for i in range(1, n):
455
+ # left label: first is input [0:v], then [v1], [v2], …
456
+ left = "[{}:v]".format(i - 1) if i == 1 else "[v{}]".format(i - 1)
457
+ right = "[{}:v]".format(i)
458
+ out = "[v{}]".format(i)
459
+ offset = (i - 1) * (display_dur + fade_dur) + display_dur
460
+ parts.append(
461
+ "{}{}xfade=transition=fade:".format(left, right)
462
+ + "duration={}:offset={:.3f}{}".format(fade_dur, offset, out)
463
+ )
464
+ filter_complex = ";".join(parts)
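+ # For three inputs this yields, e.g.:
+ # "[0:v][1:v]xfade=transition=fade:duration=1.0:offset=1.500[v1];[v1][2:v]xfade=transition=fade:duration=1.0:offset=4.000[v2]"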
465
+
466
+ # 3. make MP4 slideshow
467
+ mp4 = "slideshow.mp4"
468
+ cmd1 = [
469
+ "ffmpeg",
470
+ "-loglevel",
471
+ "error",
472
+ "-v",
473
+ "quiet",
474
+ "-hide_banner",
475
+ "-y",
476
+ *in_args,
477
+ "-filter_complex",
478
+ filter_complex,
479
+ "-map",
480
+ "[v{}]".format(n - 1),
481
+ "-c:v",
482
+ "libx264",
483
+ "-pix_fmt",
484
+ "yuv420p",
485
+ mp4,
486
+ ]
487
+ subprocess.run(cmd1, check=True)
488
+
489
+ # 4. palette
490
+ palette = "palette.png"
491
+ vf = "fps={}".format(fps)
492
+ if scale_width:
493
+ vf += ",scale={}: -1:flags=lanczos".format(scale_width)
494
+
495
+ # 5. generate palette
496
+ subprocess.run(
497
+ [
498
+ "ffmpeg",
499
+ "-loglevel",
500
+ "error",
501
+ "-v",
502
+ "quiet",
503
+ "-hide_banner",
504
+ "-y",
505
+ "-i",
506
+ mp4,
507
+ "-vf",
508
+ vf + ",palettegen",
509
+ palette,
510
+ ],
511
+ check=True,
512
+ stdout=subprocess.DEVNULL,
513
+ stderr=subprocess.PIPE,
514
+ )
515
+
516
+ # 6. build final GIF
517
+ subprocess.run(
518
+ [
519
+ "ffmpeg",
520
+ "-loglevel",
521
+ "error",
522
+ "-v",
523
+ "quiet",
524
+ "-hide_banner",
525
+ "-y",
526
+ "-i",
527
+ mp4,
528
+ "-i",
529
+ palette,
530
+ "-lavfi",
531
+ vf + "[x];[x][1:v]paletteuse",
532
+ gif_output_path,
533
+ ],
534
+ check=True,
535
+ stdout=subprocess.DEVNULL,
536
+ stderr=subprocess.PIPE,
537
+ )
538
+
539
+ # Clean up temporary files
540
+ os.remove(mp4)
541
+ os.remove(palette)
542
+ os.remove(os.path.join(dirname, "black_image.jpg"))
543
+
544
+ print_log(f"GIF saved as '{gif_output_path}'", logger="current")
545
+
546
+
547
+ def _update_bbox_by_mask(
548
+ bbox: List[int], mask_poly: Optional[List[List[int]]], image_shape: Tuple[int, int, int]
549
+ ) -> List[int]:
550
+ """
551
+ Adjust bounding box to tightly fit mask polygon.
552
+
553
+ Args:
554
+ bbox (List[int]): Original [x, y, w, h].
555
+ mask_poly (Optional[List[List[int]]]): Polygon coordinates.
556
+ image_shape (Tuple[int,int,int]): Image shape (H, W, C).
557
+
558
+ Returns:
559
+ List[int]: Updated [x, y, w, h] bounding box.
560
+ """
561
+ if mask_poly is None or len(mask_poly) == 0:
562
+ return bbox
563
+
564
+ mask_rle = Mask.frPyObjects(mask_poly, image_shape[0], image_shape[1])
565
+ mask_rle = Mask.merge(mask_rle)
566
+ bbox_segm_xywh = Mask.toBbox(mask_rle)
567
+ bbox_segm_xyxy = np.array(
568
+ [
569
+ bbox_segm_xywh[0],
570
+ bbox_segm_xywh[1],
571
+ bbox_segm_xywh[0] + bbox_segm_xywh[2],
572
+ bbox_segm_xywh[1] + bbox_segm_xywh[3],
573
+ ]
574
+ )
575
+
576
+ bbox = bbox_segm_xywh
577
+
578
+ return bbox.astype(int).tolist()
579
+
580
+
581
+ def pose_nms(config: Any, image_kpts: np.ndarray, image_bboxes: np.ndarray, num_valid_kpts: np.ndarray) -> np.ndarray:
582
+ """
583
+ Perform OKS-based non-maximum suppression on detected poses.
584
+
585
+ Args:
586
+ config (Any): Configuration with confidence_thr and oks_thr.
587
+ image_kpts (np.ndarray): Detected keypoints of shape (N, K, 3).
588
+ image_bboxes (np.ndarray): Corresponding bboxes (N,4).
589
+ num_valid_kpts (np.ndarray): Count of valid keypoints per instance.
590
+
591
+ Returns:
592
+ np.ndarray: Indices of kept instances.
593
+ """
594
+ # Sort image kpts by average score - lowest first
595
+ # scores = image_kpts[:, :, 2].mean(axis=1)
596
+ # sort_idx = np.argsort(scores)
597
+ # image_kpts = image_kpts[sort_idx, :, :]
598
+
599
+ # Compute OKS between all pairs of poses
600
+ oks_matrix = np.zeros((image_kpts.shape[0], image_kpts.shape[0]))
601
+ for i in range(image_kpts.shape[0]):
602
+ for j in range(image_kpts.shape[0]):
603
+ gt_bbox_xywh = image_bboxes[i].copy()
604
+ gt_bbox_xyxy = gt_bbox_xywh.copy()
605
+ gt_bbox_xyxy[2:] += gt_bbox_xyxy[:2]
606
+ gt = {
607
+ "keypoints": image_kpts[i].copy(),
608
+ "bbox": gt_bbox_xyxy,
609
+ "area": gt_bbox_xywh[2] * gt_bbox_xywh[3],
610
+ }
611
+ dt = {"keypoints": image_kpts[j].copy(), "bbox": gt_bbox_xyxy}
612
+ gt["keypoints"][:, 2] = (gt["keypoints"][:, 2] > config.confidence_thr) * 2
613
+ oks = compute_oks(gt, dt)
614
+ if oks > 1:
+ print_log("OKS > 1 encountered in pose_nms; clamping to 1.", logger="current", level=logging.WARNING)
+ oks = min(oks, 1.0)
616
+ oks_matrix[i, j] = oks
617
+
618
+ np.fill_diagonal(oks_matrix, -1)
619
+ is_subset = oks_matrix > config.oks_thr
620
+
621
+ remove_instances = []
622
+ while is_subset.any():
623
+ # Find the pair with the highest OKS
624
+ i, j = np.unravel_index(np.argmax(oks_matrix), oks_matrix.shape)
625
+
626
+ # Keep the one with the highest number of keypoints
627
+ if num_valid_kpts[i] > num_valid_kpts[j]:
628
+ remove_idx = j
629
+ else:
630
+ remove_idx = i
631
+
632
+ # Remove the column from is_subset
633
+ oks_matrix[:, remove_idx] = 0
634
+ oks_matrix[remove_idx, j] = 0
635
+ remove_instances.append(remove_idx)
636
+ is_subset = oks_matrix > config.oks_thr
637
+
638
+ keep_instances = np.setdiff1d(np.arange(image_kpts.shape[0]), remove_instances)
639
+
640
+ return keep_instances
641
+
642
+
643
+ def compute_oks(gt: Dict[str, Any], dt: Dict[str, Any], use_area: bool = True, per_kpt: bool = False) -> float:
644
+ """
645
+ Compute Object Keypoint Similarity (OKS) between ground-truth and detected poses.
646
+
647
+ Args:
648
+ gt (Dict): Ground-truth keypoints and bbox info.
649
+ dt (Dict): Detected keypoints and bbox info.
650
+ use_area (bool): Whether to normalize by GT area.
651
+ per_kpt (bool): Whether to return per-keypoint OKS array.
652
+
653
+ Returns:
654
+ float: OKS score or mean OKS.
655
+ """
656
+ sigmas = (
657
+ np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
658
+ / 10.0
659
+ )
660
+ vars = (sigmas * 2) ** 2
661
+ k = len(sigmas)
662
+ visibility_condition = lambda x: x > 0
663
+ g = np.array(gt["keypoints"]).reshape(k, 3)
664
+ xg = g[:, 0]
665
+ yg = g[:, 1]
666
+ vg = g[:, 2]
667
+ k1 = np.count_nonzero(visibility_condition(vg))
668
+ bb = gt["bbox"]
669
+ x0 = bb[0] - bb[2]
670
+ x1 = bb[0] + bb[2] * 2
671
+ y0 = bb[1] - bb[3]
672
+ y1 = bb[1] + bb[3] * 2
673
+
674
+ d = np.array(dt["keypoints"]).reshape((k, 3))
675
+ xd = d[:, 0]
676
+ yd = d[:, 1]
677
+
678
+ if k1 > 0:
679
+ # measure the per-keypoint distance if keypoints visible
680
+ dx = xd - xg
681
+ dy = yd - yg
682
+
683
+ else:
684
+ # measure minimum distance to keypoints in (x0,y0) & (x1,y1)
685
+ z = np.zeros((k))
686
+ dx = np.max((z, x0 - xd), axis=0) + np.max((z, xd - x1), axis=0)
687
+ dy = np.max((z, y0 - yd), axis=0) + np.max((z, yd - y1), axis=0)
688
+
689
+ if use_area:
690
+ e = (dx**2 + dy**2) / vars / (gt["area"] + np.spacing(1)) / 2
691
+ else:
692
+ tmparea = gt["bbox"][3] * gt["bbox"][2] * 0.53
693
+ e = (dx**2 + dy**2) / vars / (tmparea + np.spacing(1)) / 2
694
+
695
+ if per_kpt:
696
+ oks = np.exp(-e)
697
+ if k1 > 0:
698
+ oks[~visibility_condition(vg)] = 0
699
+
700
+ else:
701
+ if k1 > 0:
702
+ e = e[visibility_condition(vg)]
703
+ oks = np.sum(np.exp(-e)) / e.shape[0]
704
+
705
+ return oks
demo/mm_utils.py ADDED
@@ -0,0 +1,106 @@
1
+ """
2
+ This module provides high-level interfaces to run MMDetection and MMPose
3
+ models sequentially. Users can call run_MMDetector and run_MMPose from
4
+ other scripts (e.g., bmp_demo.py) to perform object detection and
5
+ pose estimation in a clean, modular fashion.
6
+ """
7
+
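+ # Typical call sequence (sketch; 'detector' and 'pose_estimator' are initialized elsewhere, e.g. in
+ # demo/bmp_demo.py via init_detector / init_pose_estimator):
+ # det_instances = run_MMDetector(detector, image, det_cat_id=0, bbox_thr=0.3, nms_thr=0.3)
+ # pose_instances = run_MMPose(pose_estimator, image, detections=det_instances, kpt_thr=0.3)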
8
+ import numpy as np
9
+ from mmdet.apis import inference_detector
10
+ from mmengine.structures import InstanceData
11
+
12
+ from mmpose.apis import inference_topdown
13
+ from mmpose.evaluation.functional import nms
14
+ from mmpose.structures import merge_data_samples
15
+
16
+
17
+ def run_MMDetector(detector, image, det_cat_id: int = 0, bbox_thr: float = 0.3, nms_thr: float = 0.3) -> InstanceData:
18
+ """
19
+ Run an MMDetection model to detect bounding boxes (and masks) in an image.
20
+
21
+ Args:
22
+ detector: An initialized MMDetection detector model.
23
+ image: Input image as file path or BGR numpy array.
24
+ det_cat_id: Category ID to filter detections (default is 0 for 'person').
25
+ bbox_thr: Minimum bounding box score threshold.
26
+ nms_thr: IoU threshold for Non-Maximum Suppression (NMS).
27
+
28
+ Returns:
29
+ InstanceData: A structure containing filtered bboxes, bbox_scores, and masks (if available).
30
+ """
31
+ # Run detection
32
+ det_result = inference_detector(detector, image)
33
+ pred_instances = det_result.pred_instances.cpu().numpy()
34
+
35
+ # Aggregate bboxes and scores into an (N, 5) array
36
+ bboxes_all = np.concatenate((pred_instances.bboxes, pred_instances.scores[:, None]), axis=1)
37
+
38
+ # Filter by category and score
39
+ keep_mask = np.logical_and(pred_instances.labels == det_cat_id, pred_instances.scores > bbox_thr)
40
+ if not np.any(keep_mask):
41
+ # Return empty structure if nothing passes threshold
42
+ return InstanceData(bboxes=np.zeros((0, 4)), bbox_scores=np.zeros((0,)), masks=np.zeros((0, 1, 1)))
43
+
44
+ bboxes = bboxes_all[keep_mask]
45
+ masks = getattr(pred_instances, "masks", None)
46
+ if masks is not None:
47
+ masks = masks[keep_mask]
48
+
49
+ # Sort detections by descending score
50
+ order = np.argsort(bboxes[:, 4])[::-1]
51
+ bboxes = bboxes[order]
52
+ if masks is not None:
53
+ masks = masks[order]
54
+
55
+ # Apply Non-Maximum Suppression
56
+ keep_indices = nms(bboxes, nms_thr)
57
+ bboxes = bboxes[keep_indices]
58
+ if masks is not None:
59
+ masks = masks[keep_indices]
60
+
61
+ # Construct InstanceData to return
62
+ det_instances = InstanceData(bboxes=bboxes[:, :4], bbox_scores=bboxes[:, 4], masks=masks)
63
+ return det_instances
64
+
65
+
66
+ def run_MMPose(pose_estimator, image, detections: InstanceData, kpt_thr: float = 0.3) -> InstanceData:
67
+ """
68
+ Run an MMPose top-down model to estimate human pose given detected bounding boxes.
69
+
70
+ Args:
71
+ pose_estimator: An initialized MMPose model.
72
+ image: Input image as file path or RGB/BGR numpy array.
73
+ detections: InstanceData from run_MMDetector containing bboxes and masks.
74
+ kpt_thr: Minimum keypoint score threshold to filter low-confidence joints.
75
+
76
+ Returns:
77
+ InstanceData: A structure containing estimated keypoints, keypoint_scores,
78
+ original bboxes, and masks (if provided).
79
+ """
80
+ # Extract bounding boxes
81
+ bboxes = detections.bboxes
82
+ if bboxes.shape[0] == 0:
83
+ # No detections => empty pose data
84
+ return InstanceData(
85
+ keypoints=np.zeros((0, 17, 3)),
86
+ keypoint_scores=np.zeros((0, 17)),
87
+ bboxes=bboxes,
88
+ bbox_scores=detections.bbox_scores,
89
+ masks=detections.masks,
90
+ )
91
+
92
+ # Run top-down pose estimation
93
+ pose_results = inference_topdown(pose_estimator, image, bboxes, masks=detections.masks)
94
+ data_samples = merge_data_samples(pose_results)
95
+
96
+ # Attach masks back into the data_samples if available
97
+ if detections.masks is not None:
98
+ data_samples.pred_instances.pred_masks = detections.masks
99
+
100
+ # Low-confidence keypoints (below kpt_thr) could be zeroed out here; the mask is computed but the filtering is left disabled
101
+ kp_scores = data_samples.pred_instances.keypoint_scores
102
+ kp_mask = kp_scores >= kpt_thr
103
+ # data_samples.pred_instances.keypoints[~kp_mask] = [0, 0, 0]
104
+
105
+ # Return final InstanceData for poses
106
+ return data_samples.pred_instances
demo/posevis_lite.py ADDED
@@ -0,0 +1,507 @@
1
+ import os
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ NEUTRAL_COLOR = (52, 235, 107)
8
+
9
+ LEFT_ARM_COLOR = (216, 235, 52)
10
+ LEFT_LEG_COLOR = (235, 107, 52)
11
+ LEFT_SIDE_COLOR = (245, 188, 113)
12
+ LEFT_FACE_COLOR = (235, 52, 107)
13
+
14
+ RIGHT_ARM_COLOR = (52, 235, 216)
15
+ RIGHT_LEG_COLOR = (52, 107, 235)
16
+ RIGHT_SIDE_COLOR = (52, 171, 235)
17
+ RIGHT_FACE_COLOR = (107, 52, 235)
18
+
19
+ COCO_MARKERS = [
20
+ ["nose", cv2.MARKER_CROSS, NEUTRAL_COLOR],
21
+ ["left_eye", cv2.MARKER_SQUARE, LEFT_FACE_COLOR],
22
+ ["right_eye", cv2.MARKER_SQUARE, RIGHT_FACE_COLOR],
23
+ ["left_ear", cv2.MARKER_CROSS, LEFT_FACE_COLOR],
24
+ ["right_ear", cv2.MARKER_CROSS, RIGHT_FACE_COLOR],
25
+ ["left_shoulder", cv2.MARKER_TRIANGLE_UP, LEFT_ARM_COLOR],
26
+ ["right_shoulder", cv2.MARKER_TRIANGLE_UP, RIGHT_ARM_COLOR],
27
+ ["left_elbow", cv2.MARKER_SQUARE, LEFT_ARM_COLOR],
28
+ ["right_elbow", cv2.MARKER_SQUARE, RIGHT_ARM_COLOR],
29
+ ["left_wrist", cv2.MARKER_CROSS, LEFT_ARM_COLOR],
30
+ ["right_wrist", cv2.MARKER_CROSS, RIGHT_ARM_COLOR],
31
+ ["left_hip", cv2.MARKER_TRIANGLE_UP, LEFT_LEG_COLOR],
32
+ ["right_hip", cv2.MARKER_TRIANGLE_UP, RIGHT_LEG_COLOR],
33
+ ["left_knee", cv2.MARKER_SQUARE, LEFT_LEG_COLOR],
34
+ ["right_knee", cv2.MARKER_SQUARE, RIGHT_LEG_COLOR],
35
+ ["left_ankle", cv2.MARKER_TILTED_CROSS, LEFT_LEG_COLOR],
36
+ ["right_ankle", cv2.MARKER_TILTED_CROSS, RIGHT_LEG_COLOR],
37
+ ]
38
+
39
+
40
+ COCO_SKELETON = [
41
+ [[16, 14], LEFT_LEG_COLOR], # Left ankle - Left knee
42
+ [[14, 12], LEFT_LEG_COLOR], # Left knee - Left hip
43
+ [[17, 15], RIGHT_LEG_COLOR], # Right ankle - Right knee
44
+ [[15, 13], RIGHT_LEG_COLOR], # Right knee - Right hip
45
+ [[12, 13], NEUTRAL_COLOR], # Left hip - Right hip
46
+ [[6, 12], LEFT_SIDE_COLOR], # Left hip - Left shoulder
47
+ [[7, 13], RIGHT_SIDE_COLOR], # Right hip - Right shoulder
48
+ [[6, 7], NEUTRAL_COLOR], # Left shoulder - Right shoulder
49
+ [[6, 8], LEFT_ARM_COLOR], # Left shoulder - Left elbow
50
+ [[7, 9], RIGHT_ARM_COLOR], # Right shoulder - Right elbow
51
+ [[8, 10], LEFT_ARM_COLOR], # Left elbow - Left wrist
52
+ [[9, 11], RIGHT_ARM_COLOR], # Right elbow - Right wrist
53
+ [[2, 3], NEUTRAL_COLOR], # Left eye - Right eye
54
+ [[1, 2], LEFT_FACE_COLOR], # Nose - Left eye
55
+ [[1, 3], RIGHT_FACE_COLOR], # Nose - Right eye
56
+ [[2, 4], LEFT_FACE_COLOR], # Left eye - Left ear
57
+ [[3, 5], RIGHT_FACE_COLOR], # Right eye - Right ear
58
+ [[4, 6], LEFT_FACE_COLOR], # Left ear - Left shoulder
59
+ [[5, 7], RIGHT_FACE_COLOR], # Right ear - Right shoulder
60
+ ]
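+ # Bone endpoints above are 1-based COCO keypoint indices (1 = nose, ..., 17 = right ankle), matching
+ # the order of COCO_MARKERS.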
61
+
62
+
63
+ def _draw_line(
64
+ img: np.ndarray,
65
+ start: Tuple[float, float],
66
+ stop: Tuple[float, float],
67
+ color: Tuple[int, int, int],
68
+ line_type: str,
69
+ thickness: int = 1,
70
+ ) -> np.ndarray:
71
+ """
72
+ Draw a line segment on an image, supporting solid, dashed, or dotted styles.
73
+
74
+ Args:
75
+ img (np.ndarray): BGR image of shape (H, W, 3).
76
+ start (tuple of float): (x, y) start coordinates.
77
+ stop (tuple of float): (x, y) end coordinates.
78
+ color (tuple of int): BGR color values.
79
+ line_type (str): One of 'solid', 'dashed', or 'doted'.
80
+ thickness (int): Line thickness in pixels.
81
+
82
+ Returns:
83
+ np.ndarray: Image with the line drawn.
84
+ """
85
+ start = np.array(start)[:2]
86
+ stop = np.array(stop)[:2]
87
+ if line_type.lower() == "solid":
88
+ img = cv2.line(
89
+ img,
90
+ (int(start[0]), int(start[1])),
91
+ (int(stop[0]), int(stop[1])),
92
+ color=(0, 0, 0),
93
+ thickness=thickness+1,
94
+ lineType=cv2.LINE_AA,
95
+ )
96
+ img = cv2.line(
97
+ img,
98
+ (int(start[0]), int(start[1])),
99
+ (int(stop[0]), int(stop[1])),
100
+ color=color,
101
+ thickness=thickness,
102
+ lineType=cv2.LINE_AA,
103
+ )
104
+ elif line_type.lower() == "dashed":
105
+ delta = stop - start
106
+ length = np.linalg.norm(delta)
107
+ frac = np.linspace(0, 1, num=int(length / 5), endpoint=True)
108
+ for i in range(0, len(frac) - 1, 2):
109
+ s = start + frac[i] * delta
110
+ e = start + frac[i + 1] * delta
111
+ img = cv2.line(
112
+ img,
113
+ (int(s[0]), int(s[1])),
114
+ (int(e[0]), int(e[1])),
115
+ color=color,
116
+ thickness=thickness,
117
+ lineType=cv2.LINE_AA,
118
+ )
119
+ elif line_type.lower() == "doted":
120
+ delta = stop - start
121
+ length = np.linalg.norm(delta)
122
+ frac = np.linspace(0, 1, num=int(length / 5), endpoint=True)
123
+ for i in range(0, len(frac)):
124
+ s = start + frac[i] * delta
125
+ img = cv2.circle(
126
+ img,
127
+ (int(s[0]), int(s[1])),
128
+ radius=max(thickness // 2, 1),
129
+ color=color,
130
+ thickness=-1,
131
+ lineType=cv2.LINE_AA,
132
+ )
133
+ return img
134
+
135
+
136
+ def pose_visualization(
137
+ img: Union[str, np.ndarray],
138
+ keypoints: Union[Dict[str, Any], np.ndarray],
139
+ format: str = "COCO",
140
+ greyness: float = 1.0,
141
+ show_markers: bool = True,
142
+ show_bones: bool = True,
143
+ line_type: str = "solid",
144
+ width_multiplier: float = 1.0,
145
+ bbox_width_multiplier: float = 1.0,
146
+ show_bbox: bool = False,
147
+ differ_individuals: bool = False,
148
+ confidence_thr: float = 0.3,
149
+ errors: Optional[np.ndarray] = None,
150
+ color: Optional[Tuple[int, int, int]] = None,
151
+ keep_image_size: bool = False,
152
+ return_padding: bool = False,
153
+ ) -> Union[np.ndarray, Tuple[np.ndarray, List[int]]]:
154
+ """
155
+ Overlay pose keypoints and skeleton on an image.
156
+
157
+ Args:
158
+ img (str or np.ndarray): Path to image file or BGR image array.
159
+ keypoints (dict or np.ndarray): Either a dict with 'bbox' and 'keypoints' or
160
+ an array of shape (17, 2 or 3) or multiple poses stacked.
161
+ format (str): Keypoint format, currently only 'COCO'.
162
+ greyness (float): Factor for bone/marker color intensity (0.0-1.0).
163
+ show_markers (bool): Whether to draw keypoint markers.
164
+ show_bones (bool): Whether to draw skeleton bones.
165
+ line_type (str): One of 'solid', 'dashed', 'doted' for bone style.
166
+ width_multiplier (float): Line width scaling factor for bones.
167
+ bbox_width_multiplier (float): Line width scaling factor for bounding box.
168
+ show_bbox (bool): Whether to draw bounding box around keypoints.
169
+ differ_individuals (bool): Use distinct color per individual pose.
170
+ confidence_thr (float): Confidence threshold for keypoint visibility.
171
+ errors (np.ndarray or None): Optional array of per-kpt errors (17,1).
172
+ color (tuple or None): Override color for markers and bones.
173
+ keep_image_size (bool): Prevent image padding for out-of-bounds keypoints.
174
+ return_padding (bool): If True, also return padding offsets [top,bottom,left,right].
175
+
176
+ Returns:
177
+ np.ndarray or (np.ndarray, list of int): Annotated image, and optional
178
+ padding offsets if `return_padding` is True.
179
+ """
180
+
181
+ bbox = None
182
+ if isinstance(keypoints, dict):
183
+ try:
184
+ bbox = np.array(keypoints["bbox"]).flatten()
185
+ except KeyError:
186
+ pass
187
+ keypoints = np.array(keypoints["keypoints"])
188
+
189
+ # If keypoints is a list of poses, draw them all
190
+ if len(keypoints) % 17 != 0 or keypoints.ndim == 3:
191
+
192
+ if color is not None:
193
+ if not isinstance(color, (list, tuple)):
194
+ color = [color for keypoint in keypoints]
195
+ else:
196
+ color = [None for keypoint in keypoints]
197
+
198
+ max_padding = [0, 0, 0, 0]
199
+ for keypoint, clr in zip(keypoints, color):
200
+ img = pose_visualization(
201
+ img,
202
+ keypoint,
203
+ format=format,
204
+ greyness=greyness,
205
+ show_markers=show_markers,
206
+ show_bones=show_bones,
207
+ line_type=line_type,
208
+ width_multiplier=width_multiplier,
209
+ bbox_width_multiplier=bbox_width_multiplier,
210
+ show_bbox=show_bbox,
211
+ differ_individuals=differ_individuals,
212
+ color=clr,
213
+ confidence_thr=confidence_thr,
214
+ keep_image_size=keep_image_size,
215
+ return_padding=return_padding,
216
+ )
217
+ if return_padding:
218
+ img, padding = img
219
+ max_padding = [max(max_padding[i], int(padding[i])) for i in range(4)]
220
+
221
+ if return_padding:
222
+ return img, max_padding
223
+ else:
224
+ return img
225
+
226
+ keypoints = np.array(keypoints).reshape(17, -1)
227
+ # If keypoint visibility is not provided, assume all keypoints are visible
228
+ if keypoints.shape[1] == 2:
229
+ keypoints = np.hstack([keypoints, np.ones((17, 1)) * 2])
230
+
231
+ assert keypoints.shape[1] == 3, "Keypoints should be in the format (x, y, visibility)"
232
+ assert keypoints.shape[0] == 17, "Keypoints should be in the format (x, y, visibility)"
233
+
234
+ if errors is not None:
235
+ errors = np.array(errors).reshape(17, -1)
236
+ assert errors.shape[1] == 1, "Errors should be in the format (K, r)"
237
+ assert errors.shape[0] == 17, "Errors should be in the format (K, r)"
238
+ else:
239
+ errors = np.ones((17, 1)) * np.nan
240
+
241
+ # If keypoint visibility is float between 0 and 1, it is detection
242
+ # If conf < confidence_thr: conf = 1
243
+ # If conf >= confidence_thr: conf = 2
244
+ vis_is_float = np.any(np.logical_and(keypoints[:, -1] > 0, keypoints[:, -1] < 1))
245
+ if keypoints.shape[1] == 3 and vis_is_float:
246
+ # print("before", keypoints[:, -1])
247
+ lower_idx = keypoints[:, -1] < confidence_thr
248
+ keypoints[lower_idx, -1] = 1
249
+ keypoints[~lower_idx, -1] = 2
250
+ # print("after", keypoints[:, -1])
251
+ # print("-"*20)
252
+
253
+ # All visibility values should be ints
254
+ keypoints[:, -1] = keypoints[:, -1].astype(int)
255
+
256
+ if isinstance(img, str):
257
+ img = cv2.imread(img)
258
+
259
+ if img is None:
260
+ if return_padding:
261
+ return None, [0, 0, 0, 0]
262
+ else:
263
+ return None
264
+
265
+ if not (keypoints[:, 2] > 0).any():
266
+ if return_padding:
267
+ return img, [0, 0, 0, 0]
268
+ else:
269
+ return img
270
+
271
+ valid_kpts = (keypoints[:, 0] > 0) & (keypoints[:, 1] > 0)
272
+ num_valid_kpts = np.sum(valid_kpts)
273
+
274
+ if num_valid_kpts == 0:
275
+ if return_padding:
276
+ return img, [0, 0, 0, 0]
277
+ else:
278
+ return img
279
+
280
+ min_x_kpts = np.min(keypoints[keypoints[:, 2] > 0, 0])
281
+ min_y_kpts = np.min(keypoints[keypoints[:, 2] > 0, 1])
282
+ max_x_kpts = np.max(keypoints[keypoints[:, 2] > 0, 0])
283
+ max_y_kpts = np.max(keypoints[keypoints[:, 2] > 0, 1])
284
+ if bbox is None:
285
+ min_x = min_x_kpts
286
+ min_y = min_y_kpts
287
+ max_x = max_x_kpts
288
+ max_y = max_y_kpts
289
+ else:
290
+ min_x = bbox[0]
291
+ min_y = bbox[1]
292
+ max_x = bbox[2]
293
+ max_y = bbox[3]
294
+
295
+ max_area = (max_x - min_x) * (max_y - min_y)
296
+ diagonal = np.sqrt((max_x - min_x) ** 2 + (max_y - min_y) ** 2)
297
+ line_width = max(int(np.sqrt(max_area) / 500 * width_multiplier), 1)
298
+ bbox_line_width = max(int(np.sqrt(max_area) / 500 * bbox_width_multiplier), 1)
299
+ marker_size = max(int(np.sqrt(max_area) / 80), 1)
300
+ invisible_marker_size = max(int(np.sqrt(max_area) / 100), 1)
301
+ marker_thickness = max(int(np.sqrt(max_area) / 100), 1)
302
+
303
+ if differ_individuals:
304
+ if color is not None:
305
+ instance_color = color
306
+ else:
307
+ instance_color = np.random.randint(0, 255, size=(3,)).tolist()
308
+ instance_color = tuple(instance_color)
309
+
310
+ # Pad image with dark gray if keypoints are outside the image
311
+ if not keep_image_size:
312
+ padding = [
313
+ max(0, -min_y_kpts),
314
+ max(0, max_y_kpts - img.shape[0]),
315
+ max(0, -min_x_kpts),
316
+ max(0, max_x_kpts - img.shape[1]),
317
+ ]
318
+ padding = [int(p) for p in padding]
319
+ img = cv2.copyMakeBorder(
320
+ img,
321
+ padding[0],
322
+ padding[1],
323
+ padding[2],
324
+ padding[3],
325
+ cv2.BORDER_CONSTANT,
326
+ value=(80, 80, 80),
327
+ )
328
+
329
+ # Add padding to bbox and kpts
330
+ value_x_to_add = max(0, -min_x_kpts)
331
+ value_y_to_add = max(0, -min_y_kpts)
332
+ keypoints[keypoints[:, 2] > 0, 0] += value_x_to_add
333
+ keypoints[keypoints[:, 2] > 0, 1] += value_y_to_add
334
+ if bbox is not None:
335
+ bbox[0] += value_x_to_add
336
+ bbox[1] += value_y_to_add
337
+ bbox[2] += value_x_to_add
338
+ bbox[3] += value_y_to_add
339
+
340
+ if show_bbox and not (bbox is None):
341
+ pts = [
342
+ (bbox[0], bbox[1]),
343
+ (bbox[0], bbox[3]),
344
+ (bbox[2], bbox[3]),
345
+ (bbox[2], bbox[1]),
346
+ (bbox[0], bbox[1]),
347
+ ]
348
+ for i in range(len(pts) - 1):
349
+ if differ_individuals:
350
+ img = _draw_line(img, pts[i], pts[i + 1], instance_color, "doted", thickness=bbox_line_width)
351
+ else:
352
+ img = _draw_line(img, pts[i], pts[i + 1], (0, 255, 0), line_type, thickness=bbox_line_width)
353
+
354
+ if show_markers:
355
+ for kpt, marker_info, err in zip(keypoints, COCO_MARKERS, errors):
356
+ if kpt[0] == 0 and kpt[1] == 0:
357
+ continue
358
+
359
+ if kpt[2] != 2:
360
+ color = (140, 140, 140)
361
+ elif differ_individuals:
362
+ color = instance_color
363
+ else:
364
+ color = marker_info[2]
365
+
366
+ if kpt[2] == 1:
367
+ img_overlay = img.copy()
368
+ img_overlay = cv2.drawMarker(
369
+ img_overlay,
370
+ (int(kpt[0]), int(kpt[1])),
371
+ color=color,
372
+ markerType=marker_info[1],
373
+ markerSize=marker_size,
374
+ thickness=marker_thickness,
375
+ )
376
+ img = cv2.addWeighted(img_overlay, 0.4, img, 0.6, 0)
377
+
378
+ else:
379
+ img = cv2.drawMarker(
380
+ img,
381
+ (int(kpt[0]), int(kpt[1])),
382
+ color=color,
383
+ markerType=marker_info[1],
384
+ markerSize=invisible_marker_size if kpt[2] == 1 else marker_size,
385
+ thickness=marker_thickness,
386
+ )
387
+
388
+ if not np.isnan(err).any():
389
+ radius = err * diagonal
390
+ clr = (0, 0, 255) if "solid" in line_type else (0, 255, 0)
391
+ plus = 1 if "solid" in line_type else -1
392
+ img = cv2.circle(
393
+ img,
394
+ (int(kpt[0]), int(kpt[1])),
395
+ radius=int(radius),
396
+ color=clr,
397
+ thickness=1,
398
+ lineType=cv2.LINE_AA,
399
+ )
400
+ dx = np.sqrt(radius**2 / 2)
401
+ img = cv2.line(
402
+ img,
403
+ (int(kpt[0]), int(kpt[1])),
404
+ (int(kpt[0] + plus * dx), int(kpt[1] - dx)),
405
+ color=clr,
406
+ thickness=1,
407
+ lineType=cv2.LINE_AA,
408
+ )
409
+
410
+ if show_bones:
411
+ for bone_info in COCO_SKELETON:
412
+ kp1 = keypoints[bone_info[0][0] - 1, :]
413
+ kp2 = keypoints[bone_info[0][1] - 1, :]
414
+
415
+ if (kp1[0] == 0 and kp1[1] == 0) or (kp2[0] == 0 and kp2[1] == 0):
416
+ continue
417
+
418
+ dashed = kp1[2] == 1 or kp2[2] == 1
419
+
420
+ if differ_individuals:
421
+ color = np.array(instance_color)
422
+ else:
423
+ color = np.array(bone_info[1])
424
+ color = (color * greyness).astype(int).tolist()
425
+
426
+ if dashed:
427
+ img_overlay = img.copy()
428
+ img_overlay = _draw_line(img_overlay, kp1, kp2, color, line_type, thickness=line_width)
429
+ img = cv2.addWeighted(img_overlay, 0.4, img, 0.6, 0)
430
+
431
+ else:
432
+ img = _draw_line(img, kp1, kp2, color, line_type, thickness=line_width)
433
+
434
+ if return_padding:
435
+ return img, padding
436
+ else:
437
+ return img
438
+
439
+
440
+ if __name__ == "__main__":
441
+ kpts = np.array(
442
+ [
443
+ 344,
444
+ 222,
445
+ 2,
446
+ 356,
447
+ 211,
448
+ 2,
449
+ 330,
450
+ 211,
451
+ 2,
452
+ 372,
453
+ 220,
454
+ 2,
455
+ 309,
456
+ 224,
457
+ 2,
458
+ 413,
459
+ 279,
460
+ 2,
461
+ 274,
462
+ 300,
463
+ 2,
464
+ 444,
465
+ 372,
466
+ 2,
467
+ 261,
468
+ 396,
469
+ 2,
470
+ 398,
471
+ 359,
472
+ 2,
473
+ 316,
474
+ 372,
475
+ 2,
476
+ 407,
477
+ 489,
478
+ 2,
479
+ 185,
480
+ 580,
481
+ 2,
482
+ 0,
483
+ 0,
484
+ 0,
485
+ 0,
486
+ 0,
487
+ 0,
488
+ 0,
489
+ 0,
490
+ 0,
491
+ 0,
492
+ 0,
493
+ 0,
494
+ ]
495
+ )
496
+
497
+ kpts = kpts.reshape(-1, 3)
498
+ kpts[:, -1] = np.random.randint(1, 3, size=(17,))
499
+
500
+ img = pose_visualization("demo/posevis_test.jpg", kpts, show_markers=True, line_type="solid")
501
+
502
+ kpts2 = kpts.copy()
503
+ kpts2[kpts2[:, 1] > 0, :2] += 10
504
+ img = pose_visualization(img, kpts2, show_markers=False, line_type="doted")
505
+
506
+ os.makedirs("demo/outputs", exist_ok=True)
507
+ cv2.imwrite("demo/outputs/posevis_test_out.jpg", img)
demo/sam2_utils.py ADDED
@@ -0,0 +1,714 @@
1
+ """
2
+ SAM2 utilities for BMP demo:
3
+ - Build and prepare SAM model
4
+ - Convert poses to segmentation
5
+ - Compute mask-pose consistency
6
+ """
7
+
8
+ from typing import Any, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ from mmengine.structures import InstanceData
13
+ from pycocotools import mask as Mask
14
+ from sam2.build_sam import build_sam2
15
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
16
+
17
+ # Threshold for keypoint validity in mask-pose consistency
18
+ STRICT_KPT_THRESHOLD: float = 0.5
19
+
20
+
21
+ def _validate_sam_args(sam_args):
22
+ """
23
+ Validate that all required sam_args attributes are present.
24
+ """
25
+ required = [
26
+ "crop",
27
+ "use_bbox",
28
+ "confidence_thr",
29
+ "ignore_small_bboxes",
30
+ "num_pos_keypoints",
31
+ "num_pos_keypoints_if_crowd",
32
+ "crowd_by_max_iou",
33
+ "batch",
34
+ "exclusive_masks",
35
+ "extend_bbox",
36
+ "pose_mask_consistency",
37
+ "visibility_thr",
38
+ ]
39
+ for param in required:
40
+ if not hasattr(sam_args, param):
41
+ raise AttributeError(f"Missing required arg {param} in sam_args")
42
+
43
+
44
+ def _get_max_ious(bboxes: List[np.ndarray]) -> np.ndarray:
45
+ """
46
+ Compute maximum IoU for each bbox against others.
47
+ """
48
+ is_crowd = [0] * len(bboxes)
49
+ ious = Mask.iou(bboxes, bboxes, is_crowd)
50
+ mat = np.array(ious)
51
+ np.fill_diagonal(mat, 0)
52
+ return mat.max(axis=1)
53
+
54
+
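For intuition, a tiny self-contained check of the helper above; the box values are made up and rows are xywh boxes, as expected by pycocotools. The first two boxes overlap, the third does not.

import numpy as np

boxes = np.array([[0, 0, 10, 10],        # overlaps the second box
                  [5, 0, 10, 10],
                  [100, 100, 10, 10]], dtype=np.float64)
print(_get_max_ious(boxes))   # approx [0.33, 0.33, 0.0]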
55
+ def _compute_one_mask_pose_consistency(
56
+ mask: np.ndarray, pos_keypoints: Optional[np.ndarray] = None, neg_keypoints: Optional[np.ndarray] = None
57
+ ) -> float:
58
+ """
59
+ Compute a consistency score between a mask and given keypoints.
60
+
61
+ Args:
62
+ mask (np.ndarray): Binary mask of shape (H, W).
63
+ pos_keypoints (Optional[np.ndarray]): Positive keypoints array (N, 3).
64
+ neg_keypoints (Optional[np.ndarray]): Negative keypoints array (M, 3).
65
+
66
+ Returns:
67
+ float: Weighted mean of positive and negative keypoint consistency.
68
+ """
69
+ if mask is None:
70
+ return 0.0
71
+
72
+ def _mean_inside(points: np.ndarray) -> float:
73
+ if points.size == 0:
74
+ return 0.0
75
+ pts_int = np.floor(points[:, :2]).astype(int)
76
+ pts_int[:, 0] = np.clip(pts_int[:, 0], 0, mask.shape[1] - 1)
77
+ pts_int[:, 1] = np.clip(pts_int[:, 1], 0, mask.shape[0] - 1)
78
+ vals = mask[pts_int[:, 1], pts_int[:, 0]]
79
+ return vals.mean() if vals.size > 0 else 0.0
80
+
81
+ pos_mean = 0.0
82
+ if pos_keypoints is not None:
83
+ valid = pos_keypoints[:, 2] > STRICT_KPT_THRESHOLD
84
+ pos_mean = _mean_inside(pos_keypoints[valid])
85
+
86
+ neg_mean = 0.0
87
+ if neg_keypoints is not None:
88
+ valid = neg_keypoints[:, 2] > STRICT_KPT_THRESHOLD
89
+ pts = neg_keypoints[valid][:, :2]
90
+ inside = mask[np.floor(pts[:, 1]).astype(int), np.floor(pts[:, 0]).astype(int)]
91
+ neg_mean = (~inside.astype(bool)).mean() if inside.size > 0 else 0.0
92
+
93
+ return 0.5 * pos_mean + 0.5 * neg_mean
94
+
95
+
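A minimal sketch of the mask-pose consistency score defined above, on a toy 4x4 mask with hypothetical keypoints: the score is the average of (fraction of confident positive keypoints inside the mask) and (fraction of confident negative keypoints outside the mask).

import numpy as np

# Toy 4x4 binary mask: the left half of the image is foreground.
mask = np.zeros((4, 4), dtype=bool)
mask[:, :2] = True

# Hypothetical (x, y, confidence) keypoints; only confidence > 0.5 counts,
# mirroring STRICT_KPT_THRESHOLD above.
pos = np.array([[0.0, 0.0, 0.9],    # inside the mask
                [3.0, 3.0, 0.8],    # outside the mask
                [1.0, 1.0, 0.3]])   # ignored: low confidence
neg = np.array([[3.0, 0.0, 0.9]])   # outside the mask -> consistent

pos = pos[pos[:, 2] > 0.5]
pos_inside = mask[pos[:, 1].astype(int), pos[:, 0].astype(int)].mean()      # 0.5
neg_outside = (~mask[neg[:, 1].astype(int), neg[:, 0].astype(int)]).mean()  # 1.0
print(0.5 * pos_inside + 0.5 * neg_outside)   # 0.75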
96
+ def _select_keypoints(
97
+ args: Any,
98
+ kpts: np.ndarray,
99
+ num_visible: int,
100
+ bbox: Optional[Tuple[float, float, float, float]] = None,
101
+ method: Optional[str] = "distance+confidence",
102
+ ) -> Tuple[np.ndarray, np.ndarray]:
103
+ """
104
+ Select and order keypoints for SAM prompting based on specified method.
105
+
106
+ Args:
107
+ args: Configuration object with selection_method and visibility_thr attributes.
108
+ kpts (np.ndarray): Keypoints array of shape (K, 3).
109
+ num_visible (int): Number of keypoints above visibility threshold.
110
+ bbox (Optional[Tuple]): Optional bbox for distance methods.
111
+ method (Optional[str]): Override selection method.
112
+
113
+ Returns:
114
+ Tuple[np.ndarray, np.ndarray]: Selected keypoint coordinates (N,2) and confidences (N,).
115
+
116
+ Raises:
117
+ ValueError: If an unknown method is specified.
118
+ """
119
+ if num_visible == 0:
120
+ return kpts[:, :2], kpts[:, 2]
121
+
122
+ methods = ["confidence", "distance", "distance+confidence", "closest"]
123
+ sel_method = method or args.selection_method
124
+ if sel_method not in methods:
125
+ raise ValueError("Unknown method for keypoint selection: {}".format(sel_method))
126
+
127
+ # Keep at most one facial keypoint (the most confident of the first three)
128
+ facial_kpts = kpts[:3, :]
129
+ facial_conf = kpts[:3, 2]
130
+ facial_point = facial_kpts[np.argmax(facial_conf)]
131
+ if facial_point[-1] >= args.visibility_thr:
132
+ kpts = np.concatenate([facial_point[None, :], kpts[3:]], axis=0)
133
+
134
+ conf = kpts[:, 2]
135
+ vis_mask = conf >= args.visibility_thr
136
+ coords = kpts[vis_mask, :2]
137
+ confs = conf[vis_mask]
138
+
139
+ if sel_method == "confidence":
140
+ order = np.argsort(confs)[::-1]
141
+ coords = coords[order]
142
+ confs = confs[order]
143
+ elif sel_method == "distance":
144
+ if bbox is None:
145
+ bbox_center = np.array([coords[:, 0].mean(), coords[:, 1].mean()])
146
+ else:
147
+ bbox_center = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2])
148
+ dists = np.linalg.norm(coords[:, :2] - bbox_center, axis=1)
149
+ dist_matrix = np.linalg.norm(coords[:, None, :2] - coords[None, :, :2], axis=2)
150
+ np.fill_diagonal(dist_matrix, np.inf)
151
+ min_inter_dist = np.min(dist_matrix, axis=1)
152
+ order = np.argsort(dists + 3 * min_inter_dist)[::-1]
153
+ coords = coords[order, :2]
154
+ confs = confs[order]
155
+ elif sel_method == "distance+confidence":
156
+ order = np.argsort(confs)[::-1]
157
+ confidences = kpts[order, 2]
158
+ coords = coords[order, :2]
159
+ confs = confs[order]
160
+
161
+ dist_matrix = np.linalg.norm(coords[:, None, :2] - coords[None, :, :2], axis=2)
162
+
163
+ selected_idx = [0]
164
+ confidences[0] = -1
165
+ for _ in range(coords.shape[0] - 1):
166
+ min_dist = np.min(dist_matrix[:, selected_idx], axis=1)
167
+ min_dist[confidences < np.percentile(confidences, 80)] = -1
168
+
169
+ next_idx = np.argmax(min_dist)
170
+ selected_idx.append(next_idx)
171
+ confidences[next_idx] = -1
172
+
173
+ coords = coords[selected_idx]
174
+ confs = confs[selected_idx]
175
+ elif sel_method == "closest":
176
+ coords = coords[confs > STRICT_KPT_THRESHOLD, :]
177
+ confs = confs[confs > STRICT_KPT_THRESHOLD]
178
+ if bbox is None:
179
+ bbox_center = np.array([coords[:, 0].mean(), coords[:, 1].mean()])
180
+ else:
181
+ bbox_center = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2])
182
+ dists = np.linalg.norm(coords[:, :2] - bbox_center, axis=1)
183
+ order = np.argsort(dists)
184
+ coords = coords[order, :2]
185
+ confs = confs[order]
186
+
187
+ return coords, confs
188
+
189
+
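For intuition, a standalone sketch of the simplest selection strategy above ("confidence"): keep the keypoints above the visibility threshold and order them by descending confidence. The array values and the threshold are made up for illustration.

import numpy as np

visibility_thr = 0.3  # assumed value, mirroring args.visibility_thr
kpts = np.array([[10.0, 20.0, 0.9],
                 [15.0, 25.0, 0.2],   # below threshold -> dropped
                 [30.0, 40.0, 0.6]])

vis = kpts[:, 2] >= visibility_thr
coords, confs = kpts[vis, :2], kpts[vis, 2]
order = np.argsort(confs)[::-1]        # highest confidence first
coords, confs = coords[order], confs[order]
print(coords)   # [[10. 20.], [30. 40.]]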
190
+ def prepare_model(model_cfg: Any, model_checkpoint: str) -> SAM2ImagePredictor:
191
+ """
192
+ Build and return a SAM2ImagePredictor model on the appropriate device.
193
+
194
+ Args:
195
+ model_cfg: Configuration for SAM2 model.
196
+ model_checkpoint (str): Path to model checkpoint.
197
+
198
+ Returns:
199
+ SAM2ImagePredictor: Initialized SAM2 image predictor.
200
+ """
201
+ if torch.cuda.is_available():
202
+ device = torch.device("cuda")
203
+ elif torch.backends.mps.is_available():
204
+ device = torch.device("mps")
205
+ else:
206
+ device = torch.device("cpu")
207
+
208
+ sam2 = build_sam2(model_cfg, model_checkpoint, device=device, apply_postprocessing=True)
209
+ model = SAM2ImagePredictor(
210
+ sam2,
211
+ max_hole_area=10.0,
212
+ max_sprinkle_area=50.0,
213
+ )
214
+ return model
215
+
216
+
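A hedged end-to-end usage sketch of the helpers in this file. The SAM2 config/checkpoint paths and the image path are placeholders, the sam_args values are illustrative, and the detections would normally come from the BMP demo pipeline rather than random numbers.

from types import SimpleNamespace

import cv2
import numpy as np
from mmengine.structures import InstanceData

# Placeholder SAM2 config/checkpoint -- substitute real files from the sam2 repo.
predictor = prepare_model("configs/sam2.1/sam2.1_hiera_l.yaml", "checkpoints/sam2.1_hiera_large.pt")

# Minimal sam_args containing every field checked by _validate_sam_args (values are illustrative).
sam_args = SimpleNamespace(
    crop=False, use_bbox=True, confidence_thr=0.3, ignore_small_bboxes=False,
    num_pos_keypoints=6, num_pos_keypoints_if_crowd=4, crowd_by_max_iou=None,
    batch=False, exclusive_masks=True, extend_bbox=True,
    pose_mask_consistency=False, visibility_thr=0.3,
    selection_method="distance+confidence", num_neg_keypoints=4,
)

image = cv2.imread("demo/example.jpg")  # any BGR image; the path is a placeholder
dets = InstanceData(
    bboxes=np.array([[50.0, 60.0, 120.0, 200.0]], dtype=np.float32),   # xywh
    bbox_scores=np.array([0.9], dtype=np.float32),
    keypoints=np.concatenate(
        [np.random.uniform(60, 160, (1, 17, 2)), np.full((1, 17, 1), 0.9)], axis=-1
    ).astype(np.float32),
    pred_masks=np.zeros((1, image.shape[0], image.shape[1]), dtype=np.uint8),
)
refined = process_image_with_SAM(sam_args, image, predictor, dets)
print(refined.refined_masks.shape)   # (1, H, W) refined instance masks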
217
+ def _compute_mask_pose_consistency(masks: List[np.ndarray], keypoints_list: List[np.ndarray]) -> np.ndarray:
218
+ """
219
+ Compute mask-pose consistency score for each mask-keypoints pair.
220
+
221
+ Args:
222
+ masks (List[np.ndarray]): Binary masks list.
223
+ keypoints_list (List[np.ndarray]): List of keypoint arrays per instance.
224
+
225
+ Returns:
226
+ np.ndarray: Consistency scores array of shape (N,).
227
+ """
228
+ scores: List[float] = []
229
+ for idx, (mask, kpts) in enumerate(zip(masks, keypoints_list)):
230
+ other_kpts = np.concatenate([keypoints_list[:idx], keypoints_list[idx + 1 :]], axis=0).reshape(-1, 3)
231
+ score = _compute_one_mask_pose_consistency(mask, kpts, other_kpts)
232
+ scores.append(score)
233
+
234
+ return np.array(scores)
235
+
236
+
237
+ def _pose2seg(
238
+ args: Any,
239
+ model: SAM2ImagePredictor,
240
+ bbox_xyxy: Optional[List[float]] = None,
241
+ pos_kpts: Optional[np.ndarray] = None,
242
+ neg_kpts: Optional[np.ndarray] = None,
243
+ image: Optional[np.ndarray] = None,
244
+ gt_mask: Optional[Any] = None,
245
+ num_pos_keypoints: Optional[int] = None,
246
+ gt_mask_is_binary: bool = False,
247
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
248
+ """
249
+ Run SAM segmentation conditioned on pose keypoints and optional ground truth mask.
250
+
251
+ Args:
252
+ args: Configuration object with prompting settings.
253
+ model (SAM2ImagePredictor): Prepared SAM2 model.
254
+ bbox_xyxy (Optional[List[float]]): Bounding box coordinates in xyxy format.
255
+ pos_kpts (Optional[np.ndarray]): Positive keypoints array.
256
+ neg_kpts (Optional[np.ndarray]): Negative keypoints array.
257
+ image (Optional[np.ndarray]): Input image array.
258
+ gt_mask (Optional[Any]): Ground truth mask (optional).
259
+ num_pos_keypoints (Optional[int]): Number of positive keypoints to use.
260
+ gt_mask_is_binary (bool): Flag indicating if ground truth mask is binary.
261
+
262
+ Returns:
263
+ Tuple of (mask, pos_kpts_backup, neg_kpts_backup, score).
264
+ """
265
+ num_pos_keypoints = args.num_pos_keypoints if num_pos_keypoints is None else num_pos_keypoints
266
+
267
+ # Filter out unannotated and invisible keypoints
268
+ if pos_kpts is not None:
269
+ pos_kpts = pos_kpts.reshape(-1, 3)
270
+ valid_kpts = pos_kpts[:, 2] > args.visibility_thr
271
+
272
+ pose_bbox = np.array([pos_kpts[:, 0].min(), pos_kpts[:, 1].min(), pos_kpts[:, 0].max(), pos_kpts[:, 1].max()])
273
+ pos_kpts, conf = _select_keypoints(args, pos_kpts, num_visible=valid_kpts.sum(), bbox=bbox_xyxy)
274
+
275
+ pos_kpts_backup = np.concatenate([pos_kpts, conf[:, None]], axis=1)
276
+
277
+ if pos_kpts.shape[0] > num_pos_keypoints:
278
+ pos_kpts = pos_kpts[:num_pos_keypoints, :]
279
+ pos_kpts_backup = pos_kpts_backup[:num_pos_keypoints, :]
280
+
281
+ else:
282
+ pose_bbox = None
283
+ pos_kpts = np.empty((0, 2), dtype=np.float32)
284
+ pos_kpts_backup = np.empty((0, 3), dtype=np.float32)
285
+
286
+ if neg_kpts is not None:
287
+ neg_kpts = neg_kpts.reshape(-1, 3)
288
+ valid_kpts = neg_kpts[:, 2] > args.visibility_thr
289
+
290
+ neg_kpts, conf = _select_keypoints(
291
+ args, neg_kpts, num_visible=valid_kpts.sum(), bbox=bbox_xyxy, method="closest"
292
+ )
293
+ selected_neg_kpts = neg_kpts
294
+ neg_kpts_backup = np.concatenate([neg_kpts, conf[:, None]], axis=1)
295
+
296
+ if neg_kpts.shape[0] > args.num_neg_keypoints:
297
+ selected_neg_kpts = neg_kpts[: args.num_neg_keypoints, :]
298
+
299
+ else:
300
+ selected_neg_kpts = np.empty((0, 2), dtype=np.float32)
301
+ neg_kpts_backup = np.empty((0, 3), dtype=np.float32)
302
+
303
+ # Concatenate positive and negative keypoints
304
+ kpts = np.concatenate([pos_kpts, selected_neg_kpts], axis=0)
305
+ kpts_labels = np.concatenate([np.ones(pos_kpts.shape[0]), np.zeros(selected_neg_kpts.shape[0])], axis=0)
306
+
307
+ bbox = bbox_xyxy if args.use_bbox else None
308
+
309
+ if args.extend_bbox and bbox is not None:
310
+ # Expand the bbox such that it contains all positive keypoints
311
+ pose_bbox = np.array(
312
+ [pos_kpts[:, 0].min() - 2, pos_kpts[:, 1].min() - 2, pos_kpts[:, 0].max() + 2, pos_kpts[:, 1].max() + 2]
313
+ )
314
+ expanded_bbox = np.array(bbox)
315
+ expanded_bbox[:2] = np.minimum(bbox[:2], pose_bbox[:2])
316
+ expanded_bbox[2:] = np.maximum(bbox[2:], pose_bbox[2:])
317
+ bbox = expanded_bbox
318
+
319
+ if args.crop and args.use_bbox and image is not None:
320
+ # Crop the image to the 1.5 * bbox size
321
+ crop_bbox = np.array(bbox)
322
+ bbox_center = np.array([(crop_bbox[0] + crop_bbox[2]) / 2, (crop_bbox[1] + crop_bbox[3]) / 2])
323
+ bbox_size = np.array([crop_bbox[2] - crop_bbox[0], crop_bbox[3] - crop_bbox[1]])
324
+ bbox_size = 1.5 * bbox_size
325
+ crop_bbox = np.array(
326
+ [
327
+ bbox_center[0] - bbox_size[0] / 2,
328
+ bbox_center[1] - bbox_size[1] / 2,
329
+ bbox_center[0] + bbox_size[0] / 2,
330
+ bbox_center[1] + bbox_size[1] / 2,
331
+ ]
332
+ )
333
+ crop_bbox = np.round(crop_bbox).astype(int)
334
+ crop_bbox = np.clip(crop_bbox, 0, [image.shape[1], image.shape[0], image.shape[1], image.shape[0]])
335
+ original_image_size = image.shape[:2]
336
+ image = image[crop_bbox[1] : crop_bbox[3], crop_bbox[0] : crop_bbox[2], :]
337
+
338
+ # Update the keypoints
339
+ kpts = kpts - crop_bbox[:2]
340
+ bbox[:2] = bbox[:2] - crop_bbox[:2]
341
+ bbox[2:] = bbox[2:] - crop_bbox[:2]
342
+
343
+ model.set_image(image)
344
+
345
+ masks, scores, logits = model.predict(
346
+ point_coords=kpts,
347
+ point_labels=kpts_labels,
348
+ box=bbox,
349
+ multimask_output=False,
350
+ )
351
+ mask = masks[0]
352
+ scores = scores[0]
353
+
354
+ if args.crop and args.use_bbox and image is not None:
355
+ # Pad the mask to the original image size
356
+ mask_padded = np.zeros(original_image_size, dtype=np.uint8)
357
+ mask_padded[crop_bbox[1] : crop_bbox[3], crop_bbox[0] : crop_bbox[2]] = mask
358
+ mask = mask_padded
359
+
360
+ bbox[:2] = bbox[:2] + crop_bbox[:2]
361
+ bbox[2:] = bbox[2:] + crop_bbox[:2]
362
+
363
+ if args.pose_mask_consistency:
364
+ if gt_mask_is_binary:
365
+ gt_mask_binary = gt_mask
366
+ else:
367
+ gt_mask_binary = Mask.decode(gt_mask).astype(bool) if gt_mask is not None else None
368
+
369
+ gt_mask_pose_consistency = _compute_one_mask_pose_consistency(gt_mask_binary, pos_kpts_backup, neg_kpts_backup)
370
+ dt_mask_pose_consistency = _compute_one_mask_pose_consistency(mask, pos_kpts_backup, neg_kpts_backup)
371
+
372
+ tol = 0.1
373
+ dt_is_same = np.abs(dt_mask_pose_consistency - gt_mask_pose_consistency) < tol
374
+ if dt_is_same:
375
+ mask = gt_mask_binary if gt_mask_binary.sum() < mask.sum() else mask
376
+ else:
377
+ mask = gt_mask_binary if gt_mask_pose_consistency > dt_mask_pose_consistency else mask
378
+
379
+ return mask, pos_kpts_backup, neg_kpts_backup, scores
380
+
381
+
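For reference, a standalone sketch of the prompt layout that _pose2seg assembles for SAM2: the instance's own keypoints become label-1 point prompts, other people's keypoints become label-0 points, and the detection box is passed separately. The coordinates below are made up; the predict call is left commented because it needs the predictor built by prepare_model.

import numpy as np

pos = np.array([[100.0, 120.0], [110.0, 180.0]])   # this person's keypoints
neg = np.array([[300.0, 140.0]])                   # a neighbour's keypoint
point_coords = np.concatenate([pos, neg], axis=0)
point_labels = np.concatenate([np.ones(len(pos)), np.zeros(len(neg))])
box = np.array([80.0, 100.0, 160.0, 220.0])        # xyxy detection box

# masks, scores, logits = predictor.predict(
#     point_coords=point_coords, point_labels=point_labels,
#     box=box, multimask_output=False)
print(point_coords.shape, point_labels)   # (3, 2) [1. 1. 0.]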
382
+ def process_image_with_SAM(
383
+ sam_args: Any,
384
+ image: np.ndarray,
385
+ model: SAM2ImagePredictor,
386
+ new_dets: InstanceData,
387
+ old_dets: Optional[InstanceData] = None,
388
+ ) -> InstanceData:
389
+ """
390
+ Wrapper that validates args and routes to single or batch processing.
391
+ """
392
+ _validate_sam_args(sam_args)
393
+ if sam_args.batch:
394
+ return _process_image_batch(sam_args, image, model, new_dets, old_dets)
395
+ return _process_image_single(sam_args, image, model, new_dets, old_dets)
396
+
397
+
398
+ def _process_image_single(
399
+ sam_args: Any,
400
+ image: np.ndarray,
401
+ model: SAM2ImagePredictor,
402
+ new_dets: InstanceData,
403
+ old_dets: Optional[InstanceData] = None,
404
+ ) -> InstanceData:
405
+ """
406
+ Refine instance segmentation masks using SAM2 with pose-conditioned prompts.
407
+
408
+ Args:
409
+ sam_args (Any): DotDict containing required SAM parameters:
410
+ crop (bool), use_bbox (bool), confidence_thr (float),
411
+ ignore_small_bboxes (bool), num_pos_keypoints (int),
412
+ num_pos_keypoints_if_crowd (int), crowd_by_max_iou (Optional[float]),
413
+ batch (bool), exclusive_masks (bool), extend_bbox (bool), pose_mask_consistency (bool).
414
+ image (np.ndarray): BGR image array of shape (H, W, 3).
415
+ model (SAM2ImagePredictor): Initialized SAM2 predictor.
416
+ new_dets (InstanceData): New detections with attributes:
417
+ bboxes, pred_masks, keypoints, bbox_scores.
418
+ old_dets (Optional[InstanceData]): Previous detections for negative prompts.
419
+
420
+ Returns:
421
+ InstanceData: `new_dets` updated in-place with
422
+ `.refined_masks`, `.sam_scores`, and `.sam_kpts`.
423
+ """
424
+ _validate_sam_args(sam_args)
425
+
426
+ if not (sam_args.crop and sam_args.use_bbox):
427
+ model.set_image(image)
428
+
429
+ # Ignore all keypoints with confidence below the threshold
430
+ new_keypoints = new_dets.keypoints.copy()
431
+ for kpts in new_keypoints:
432
+ conf_mask = kpts[:, 2] < sam_args.confidence_thr
433
+ kpts[conf_mask, :] = 0
434
+ n_new_dets = len(new_dets.bboxes)
435
+ n_old_dets = 0
436
+ if old_dets is not None:
437
+ n_old_dets = len(old_dets.bboxes)
438
+ old_keypoints = old_dets.keypoints.copy()
439
+ for kpts in old_keypoints:
440
+ conf_mask = kpts[:, 2] < sam_args.confidence_thr
441
+ kpts[conf_mask, :] = 0
442
+
443
+ all_bboxes = new_dets.bboxes.copy()
444
+ if old_dets is not None:
445
+ all_bboxes = np.concatenate([all_bboxes, old_dets.bboxes], axis=0)
446
+
447
+ max_ious = _get_max_ious(all_bboxes)
448
+
449
+ gt_bboxes = []
450
+ new_dets.refined_masks = np.zeros((n_new_dets, image.shape[0], image.shape[1]), dtype=np.uint8)
451
+ new_dets.sam_scores = np.zeros_like(new_dets.bbox_scores)
452
+ new_dets.sam_kpts = np.zeros((len(new_dets.bboxes), sam_args.num_pos_keypoints, 3), dtype=np.float32)
453
+ for instance_idx in range(len(new_dets.bboxes)):
454
+ bbox_xywh = new_dets.bboxes[instance_idx]
455
+ bbox_area = bbox_xywh[2] * bbox_xywh[3]
456
+
457
+ if sam_args.ignore_small_bboxes and bbox_area < 100 * 100:
458
+ continue
459
+ dt_mask = new_dets.pred_masks[instance_idx] if new_dets.pred_masks is not None else None
460
+
461
+ bbox_xyxy = [bbox_xywh[0], bbox_xywh[1], bbox_xywh[0] + bbox_xywh[2], bbox_xywh[1] + bbox_xywh[3]]
462
+ gt_bboxes.append(bbox_xyxy)
463
+ this_kpts = new_keypoints[instance_idx].reshape(1, -1, 3)
464
+ other_kpts = None
465
+ if old_dets is not None:
466
+ other_kpts = old_keypoints.copy().reshape(n_old_dets, -1, 3)
467
+ if len(new_keypoints) > 1:
468
+ other_new_kpts = np.concatenate([new_keypoints[:instance_idx], new_keypoints[instance_idx + 1 :]], axis=0)
469
+ other_kpts = (
470
+ np.concatenate([other_kpts, other_new_kpts], axis=0) if other_kpts is not None else other_new_kpts
471
+ )
472
+
473
+ num_pos_keypoints = sam_args.num_pos_keypoints
474
+ if sam_args.crowd_by_max_iou is not None and max_ious[instance_idx] > sam_args.crowd_by_max_iou:
475
+ bbox_xyxy = None
476
+ num_pos_keypoints = sam_args.num_pos_keypoints_if_crowd
477
+
478
+ dt_mask, pos_kpts, neg_kpts, scores = _pose2seg(
479
+ sam_args,
480
+ model,
481
+ bbox_xyxy,
482
+ pos_kpts=this_kpts,
483
+ neg_kpts=other_kpts,
484
+ image=image if (sam_args.crop and sam_args.use_bbox) else None,
485
+ gt_mask=dt_mask,
486
+ num_pos_keypoints=num_pos_keypoints,
487
+ gt_mask_is_binary=True,
488
+ )
489
+
490
+ new_dets.refined_masks[instance_idx] = dt_mask
491
+ new_dets.sam_scores[instance_idx] = scores
492
+
493
+ # If the number of positive keypoints is less than the required number, fill the rest with zeros
494
+ if len(pos_kpts) != sam_args.num_pos_keypoints:
495
+ pos_kpts = np.concatenate(
496
+ [pos_kpts, np.zeros((sam_args.num_pos_keypoints - len(pos_kpts), 3), dtype=np.float32)], axis=0
497
+ )
498
+ new_dets.sam_kpts[instance_idx] = pos_kpts
499
+
500
+ n_masks = len(new_dets.refined_masks) + (len(old_dets.refined_masks) if old_dets is not None else 0)
501
+
502
+ if sam_args.exclusive_masks and n_masks > 1:
503
+ all_masks = (
504
+ np.concatenate([new_dets.refined_masks, old_dets.refined_masks], axis=0)
505
+ if old_dets is not None
506
+ else new_dets.refined_masks
507
+ )
508
+ all_scores = (
509
+ np.concatenate([new_dets.sam_scores, old_dets.sam_scores], axis=0)
510
+ if old_dets is not None
511
+ else new_dets.sam_scores
512
+ )
513
+ refined_masks = _apply_exclusive_masks(all_masks, all_scores)
514
+ new_dets.refined_masks = refined_masks[: len(new_dets.refined_masks)]
515
+
516
+ return new_dets
517
+
518
+
519
+ def _process_image_batch(
520
+ sam_args: Any,
521
+ image: np.ndarray,
522
+ model: SAM2ImagePredictor,
523
+ new_dets: InstanceData,
524
+ old_dets: Optional[InstanceData] = None,
525
+ ) -> InstanceData:
526
+ """
527
+ Batch process multiple detection instances with SAM2 refinement.
528
+
529
+ Args:
530
+ sam_args (Any): DotDict of SAM parameters (same as `process_image_with_SAM`).
531
+ image (np.ndarray): Input BGR image.
532
+ model (SAM2ImagePredictor): Prepared SAM2 predictor.
533
+ new_dets (InstanceData): New detection instances.
534
+ old_dets (Optional[InstanceData]): Previous detections for negative prompts.
535
+
536
+ Returns:
537
+ InstanceData: `new_dets` updated as in `process_image_with_SAM`.
538
+ """
539
+ n_new_dets = len(new_dets.bboxes)
540
+
541
+ model.set_image(image)
542
+
543
+ image_kpts = []
544
+ image_bboxes = []
545
+ num_valid_kpts = []
546
+ for instance_idx in range(len(new_dets.bboxes)):
547
+
548
+ bbox_xywh = new_dets.bboxes[instance_idx].copy()
549
+ bbox_area = bbox_xywh[2] * bbox_xywh[3]
550
+ if sam_args.ignore_small_bboxes and bbox_area < 100 * 100:
551
+ continue
552
+
553
+ this_kpts = new_dets.keypoints[instance_idx].copy().reshape(-1, 3)
554
+ kpts_vis = np.array(this_kpts[:, 2])
555
+ visible_kpts = (kpts_vis > sam_args.visibility_thr) & (this_kpts[:, 2] > sam_args.confidence_thr)
556
+ num_visible = (visible_kpts).sum()
557
+ if num_visible <= 0:
558
+ continue
559
+ num_valid_kpts.append(num_visible)
560
+ image_bboxes.append(np.array(bbox_xywh))
561
+ this_kpts[~visible_kpts, :2] = 0
562
+ this_kpts[:, 2] = visible_kpts
563
+ image_kpts.append(this_kpts)
564
+ if old_dets is not None:
565
+ for instance_idx in range(len(old_dets.bboxes)):
566
+ bbox_xywh = old_dets.bboxes[instance_idx].copy()
567
+ bbox_area = bbox_xywh[2] * bbox_xywh[3]
568
+ if sam_args.ignore_small_bboxes and bbox_area < 100 * 100:
569
+ continue
570
+ this_kpts = old_dets.keypoints[instance_idx].reshape(-1, 3)
571
+ kpts_vis = np.array(this_kpts[:, 2])
572
+ visible_kpts = (kpts_vis > sam_args.visibility_thr) & (this_kpts[:, 2] > sam_args.confidence_thr)
573
+ num_visible = (visible_kpts).sum()
574
+ if num_visible <= 0:
575
+ continue
576
+ num_valid_kpts.append(num_visible)
577
+ image_bboxes.append(np.array(bbox_xywh))
578
+ this_kpts[~visible_kpts, :2] = 0
579
+ this_kpts[:, 2] = visible_kpts
580
+ image_kpts.append(this_kpts)
581
+
582
+ image_kpts = np.array(image_kpts)
583
+ image_bboxes = np.array(image_bboxes)
584
+ num_valid_kpts = np.array(num_valid_kpts)
585
+
586
+ image_kpts_backup = image_kpts.copy()
587
+
588
+ # Prepare keypoints such that all instances have the same number of keypoints
589
+ # First sort keypoints by their distance to the center of the bounding box
590
+ # If some are missing, duplicate the last one
591
+ prepared_kpts = []
592
+ prepared_kpts_backup = []
593
+ for bbox, kpts, num_visible in zip(image_bboxes, image_kpts, num_valid_kpts):
594
+
595
+ this_kpts, this_conf = _select_keypoints(sam_args, kpts, num_visible, bbox)
596
+
597
+ # Duplicate the last keypoint if some are missing
598
+ if this_kpts.shape[0] < num_valid_kpts.max():
599
+ this_kpts = np.concatenate(
600
+ [this_kpts, np.tile(this_kpts[-1], (num_valid_kpts.max() - this_kpts.shape[0], 1))], axis=0
601
+ )
602
+ this_conf = np.concatenate(
603
+ [this_conf, np.tile(this_conf[-1], (num_valid_kpts.max() - this_conf.shape[0],))], axis=0
604
+ )
605
+
606
+ prepared_kpts.append(this_kpts)
607
+ prepared_kpts_backup.append(np.concatenate([this_kpts, this_conf[:, None]], axis=1))
608
+ image_kpts = np.array(prepared_kpts)
609
+ image_kpts_backup = np.array(prepared_kpts_backup)
610
+ kpts_labels = np.ones(image_kpts.shape[:2])
611
+
612
+ # Compute IoUs between all bounding boxes
613
+ max_ious = _get_max_ious(image_bboxes)
614
+ num_pos_keypoints = sam_args.num_pos_keypoints
615
+ use_bbox = sam_args.use_bbox
616
+ if sam_args.crowd_by_max_iou is not None and max_ious[instance_idx] > sam_args.crowd_by_max_iou:
617
+ use_bbox = False
618
+ num_pos_keypoints = sam_args.num_pos_keypoints_if_crowd
619
+
620
+ # Threshold the number of positive keypoints
621
+ if num_pos_keypoints > 0 and num_pos_keypoints < image_kpts.shape[1]:
622
+ image_kpts = image_kpts[:, :num_pos_keypoints, :]
623
+ kpts_labels = kpts_labels[:, :num_pos_keypoints]
624
+ image_kpts_backup = image_kpts_backup[:, :num_pos_keypoints, :]
625
+
626
+ elif num_pos_keypoints == 0:
627
+ image_kpts = None
628
+ kpts_labels = None
629
+ image_kpts_backup = np.empty((0, 3), dtype=np.float32)
630
+
631
+ image_bboxes_xyxy = None
632
+ if use_bbox:
633
+ image_bboxes_xyxy = np.array(image_bboxes)
634
+ image_bboxes_xyxy[:, 2:] += image_bboxes_xyxy[:, :2]
635
+
636
+ # Expand the bbox to include the positive keypoints
637
+ if sam_args.extend_bbox:
638
+ pose_bbox = np.stack(
639
+ [
640
+ np.min(image_kpts[:, :, 0], axis=1) - 2,
641
+ np.min(image_kpts[:, :, 1], axis=1) - 2,
642
+ np.max(image_kpts[:, :, 0], axis=1) + 2,
643
+ np.max(image_kpts[:, :, 1], axis=1) + 2,
644
+ ],
645
+ axis=1,
646
+ )
647
+ expanded_bbox = np.array(image_bboxes_xyxy)
648
+ expanded_bbox[:, :2] = np.minimum(expanded_bbox[:, :2], pose_bbox[:, :2])
649
+ expanded_bbox[:, 2:] = np.maximum(expanded_bbox[:, 2:], pose_bbox[:, 2:])
650
+ # bbox_expanded = (np.abs(expanded_bbox - image_bboxes_xyxy) > 1e-4).any(axis=1)
651
+ image_bboxes_xyxy = expanded_bbox
652
+
653
+ # Process even old detections to get their 'negative' keypoints
654
+ masks, scores, logits = model.predict(
655
+ point_coords=image_kpts,
656
+ point_labels=kpts_labels,
657
+ box=image_bboxes_xyxy,
658
+ multimask_output=False,
659
+ )
660
+
661
+ # Reshape the masks to (N, C, H, W). If the model outputs (C, H, W), add a leading instance dimension
662
+ if len(masks.shape) == 3:
663
+ masks = masks[None, :, :, :]
664
+ masks = masks[:, 0, :, :]
665
+ N = masks.shape[0]
666
+ scores = scores.reshape(N)
667
+
668
+ if sam_args.exclusive_masks and N > 1:
669
+ # Make sure the masks are non-overlapping
670
+ # If two masks overlap, set the pixel to the one with the highest score
671
+ masks = _apply_exclusive_masks(masks, scores)
672
+
673
+ gt_masks = new_dets.pred_masks.copy() if new_dets.pred_masks is not None else None
674
+ if sam_args.pose_mask_consistency and gt_masks is not None:
675
+ # Measure mask-pose consistency by computing the number of keypoints inside the mask
676
+ # Compute for both gt (if available) and predicted masks and then choose the one with higher consistency
677
+ dt_mask_pose_consistency = _compute_mask_pose_consistency(masks, image_kpts_backup)
678
+ gt_mask_pose_consistency = _compute_mask_pose_consistency(gt_masks, image_kpts_backup)
679
+
680
+ dt_masks_area = np.array([m.sum() for m in masks])
681
+ gt_masks_area = np.array([m.sum() for m in gt_masks]) if gt_masks is not None else np.zeros_like(dt_masks_area)
682
+
683
+ # If PM-c is approx the same, prefer the smaller mask
684
+ tol = 0.1
685
+ pmc_is_equal = np.isclose(dt_mask_pose_consistency, gt_mask_pose_consistency, atol=tol)
686
+ dt_is_worse = (dt_mask_pose_consistency < (gt_mask_pose_consistency - tol)) | pmc_is_equal & (
687
+ dt_masks_area > gt_masks_area
688
+ )
689
+
690
+ new_masks = []
691
+ for dt_mask, gt_mask, dt_worse in zip(masks, gt_masks, dt_is_worse):
692
+ if dt_worse:
693
+ new_masks.append(gt_mask)
694
+ else:
695
+ new_masks.append(dt_mask)
696
+ masks = np.array(new_masks)
697
+
698
+ new_dets.refined_masks = masks[:n_new_dets]
699
+ new_dets.sam_scores = scores[:n_new_dets]
700
+ new_dets.sam_kpts = image_kpts_backup[:n_new_dets]
701
+
702
+ return new_dets
703
+
704
+
705
+ def _apply_exclusive_masks(masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
706
+ """
707
+ Ensure masks are non-overlapping by keeping at each pixel the mask with the highest score.
708
+ """
709
+ no_mask = masks.sum(axis=0) == 0
710
+ masked_scores = masks * scores[:, None, None]
711
+ argmax_masks = np.argmax(masked_scores, axis=0)
712
+ new_masks = argmax_masks[None, :, :] == (np.arange(masks.shape[0])[:, None, None])
713
+ new_masks[:, no_mask] = 0
714
+ return new_masks
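A toy check of the exclusive-mask rule above: where two masks overlap, the shared pixel is assigned to the mask with the higher SAM score. The values below are made up.

import numpy as np

masks = np.array([[[1, 1, 0]],
                  [[0, 1, 1]]], dtype=np.uint8)
scores = np.array([0.4, 0.9])
print(_apply_exclusive_masks(masks, scores).astype(int))
# [[[1 0 0]]
#  [[0 1 1]]]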
mmpose/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import mmcv
3
+ import mmengine
4
+ from mmengine.utils import digit_version
5
+
6
+ from .version import __version__, short_version
7
+
8
+ mmcv_minimum_version = '2.0.0rc4'
9
+ mmcv_maximum_version = '2.3.0'
10
+ mmcv_version = digit_version(mmcv.__version__)
11
+
12
+ mmengine_minimum_version = '0.6.0'
13
+ mmengine_maximum_version = '1.0.0'
14
+ mmengine_version = digit_version(mmengine.__version__)
15
+
16
+ assert (mmcv_version >= digit_version(mmcv_minimum_version)
17
+ and mmcv_version <= digit_version(mmcv_maximum_version)), \
18
+ f'MMCV=={mmcv.__version__} is used but incompatible. ' \
19
+ f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
20
+
21
+ assert (mmengine_version >= digit_version(mmengine_minimum_version)
22
+ and mmengine_version <= digit_version(mmengine_maximum_version)), \
23
+ f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
24
+ f'Please install mmengine>={mmengine_minimum_version}, ' \
25
+ f'<={mmengine_maximum_version}.'
26
+
27
+ __all__ = ['__version__', 'short_version']
mmpose/apis/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .inference import (collect_multi_frames, inference_bottomup,
3
+ inference_topdown, init_model)
4
+ from .inference_3d import (collate_pose_sequence, convert_keypoint_definition,
5
+ extract_pose_sequence, inference_pose_lifter_model)
6
+ from .inference_tracking import _compute_iou, _track_by_iou, _track_by_oks
7
+ from .inferencers import MMPoseInferencer, Pose2DInferencer
8
+ from .visualization import visualize
9
+
10
+ __all__ = [
11
+ 'init_model', 'inference_topdown', 'inference_bottomup',
12
+ 'collect_multi_frames', 'Pose2DInferencer', 'MMPoseInferencer',
13
+ '_track_by_iou', '_track_by_oks', '_compute_iou',
14
+ 'inference_pose_lifter_model', 'extract_pose_sequence',
15
+ 'convert_keypoint_definition', 'collate_pose_sequence', 'visualize'
16
+ ]
mmpose/apis/inference.py ADDED
@@ -0,0 +1,280 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import warnings
3
+ from pathlib import Path
4
+ from typing import List, Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from mmengine.config import Config
10
+ from mmengine.dataset import Compose, pseudo_collate
11
+ from mmengine.model.utils import revert_sync_batchnorm
12
+ from mmengine.registry import init_default_scope
13
+ from mmengine.runner import load_checkpoint
14
+ from PIL import Image
15
+
16
+ from mmpose.datasets.datasets.utils import parse_pose_metainfo
17
+ from mmpose.models.builder import build_pose_estimator
18
+ from mmpose.structures import PoseDataSample
19
+ from mmpose.structures.bbox import bbox_xywh2xyxy
20
+
21
+ import cv2
22
+
23
+ def dataset_meta_from_config(config: Config,
24
+ dataset_mode: str = 'train') -> Optional[dict]:
25
+ """Get dataset metainfo from the model config.
26
+
27
+ Args:
28
+ config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
29
+ :obj:`Path`, or the config object.
30
+ dataset_mode (str): Specify the dataset of which to get the metainfo.
31
+ Options are ``'train'``, ``'val'`` and ``'test'``. Defaults to
32
+ ``'train'``
33
+
34
+ Returns:
35
+ dict, optional: The dataset metainfo. See
36
+ ``mmpose.datasets.datasets.utils.parse_pose_metainfo`` for details.
37
+ Return ``None`` if failing to get dataset metainfo from the config.
38
+ """
39
+ try:
40
+ if dataset_mode == 'train':
41
+ dataset_cfg = config.train_dataloader.dataset
42
+ elif dataset_mode == 'val':
43
+ dataset_cfg = config.val_dataloader.dataset
44
+ elif dataset_mode == 'test':
45
+ dataset_cfg = config.test_dataloader.dataset
46
+ else:
47
+ raise ValueError(
48
+ f'Invalid dataset {dataset_mode} to get metainfo. '
49
+ 'Should be one of "train", "val", or "test".')
50
+
51
+ if 'metainfo' in dataset_cfg:
52
+ metainfo = dataset_cfg.metainfo
53
+ else:
54
+ import mmpose.datasets.datasets # noqa: F401, F403
55
+ from mmpose.registry import DATASETS
56
+
57
+ dataset_class = dataset_cfg.type if isinstance(
58
+ dataset_cfg.type, type) else DATASETS.get(dataset_cfg.type)
59
+ metainfo = dataset_class.METAINFO
60
+
61
+ metainfo = parse_pose_metainfo(metainfo)
62
+
63
+ except AttributeError:
64
+ metainfo = None
65
+
66
+ return metainfo
67
+
68
+
69
+ def init_model(config: Union[str, Path, Config],
70
+ checkpoint: Optional[str] = None,
71
+ device: str = 'cuda:0',
72
+ cfg_options: Optional[dict] = None) -> nn.Module:
73
+ """Initialize a pose estimator from a config file.
74
+
75
+ Args:
76
+ config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
77
+ :obj:`Path`, or the config object.
78
+ checkpoint (str, optional): Checkpoint path. If left as None, the model
79
+ will not load any weights. Defaults to ``None``
80
+ device (str): The device where the anchors will be put on.
81
+ Defaults to ``'cuda:0'``.
82
+ cfg_options (dict, optional): Options to override some settings in
83
+ the used config. Defaults to ``None``
84
+
85
+ Returns:
86
+ nn.Module: The constructed pose estimator.
87
+ """
88
+
89
+ if isinstance(config, (str, Path)):
90
+ config = Config.fromfile(config)
91
+ elif not isinstance(config, Config):
92
+ raise TypeError('config must be a filename or Config object, '
93
+ f'but got {type(config)}')
94
+ if cfg_options is not None:
95
+ config.merge_from_dict(cfg_options)
96
+ elif 'init_cfg' in config.model.backbone:
97
+ config.model.backbone.init_cfg = None
98
+ config.model.train_cfg = None
99
+
100
+ # register all modules in mmpose into the registries
101
+ scope = config.get('default_scope', 'mmpose')
102
+ if scope is not None:
103
+ init_default_scope(scope)
104
+
105
+ model = build_pose_estimator(config.model)
106
+ model = revert_sync_batchnorm(model)
107
+ # get dataset_meta in this priority: checkpoint > config > default (COCO)
108
+ dataset_meta = None
109
+
110
+ if checkpoint is not None:
111
+ ckpt = load_checkpoint(model, checkpoint, map_location='cpu')
112
+
113
+ if 'dataset_meta' in ckpt.get('meta', {}):
114
+ # checkpoint from mmpose 1.x
115
+ dataset_meta = ckpt['meta']['dataset_meta']
116
+
117
+ if dataset_meta is None:
118
+ dataset_meta = dataset_meta_from_config(config, dataset_mode='train')
119
+
120
+ if dataset_meta is None:
121
+ warnings.simplefilter('once')
122
+ warnings.warn('Can not load dataset_meta from the checkpoint or the '
123
+ 'model config. Use COCO metainfo by default.')
124
+ dataset_meta = parse_pose_metainfo(
125
+ dict(from_file='configs/_base_/datasets/coco.py'))
126
+
127
+ model.dataset_meta = dataset_meta
128
+
129
+ model.cfg = config # save the config in the model for convenience
130
+ model.to(device)
131
+ model.eval()
132
+ return model
133
+
134
+
135
+ def inference_topdown(model: nn.Module,
136
+ img: Union[np.ndarray, str],
137
+ bboxes: Optional[Union[List, np.ndarray]] = None,
138
+ masks: Optional[Union[List, np.ndarray]] = None,
139
+ bbox_format: str = 'xyxy') -> List[PoseDataSample]:
140
+ """Inference image with a top-down pose estimator.
141
+
142
+ Args:
143
+ model (nn.Module): The top-down pose estimator
144
+ img (np.ndarray | str): The loaded image or image file to inference
145
+ bboxes (np.ndarray, optional): The bboxes in shape (N, 4), each row
146
+ represents a bbox. If not given, the entire image will be regarded
147
+ as a single bbox area. Defaults to ``None``
148
+ masks (np.ndarray | list, optional): Instance masks aligned with
+ ``bboxes``; they are converted to polygon format before being
+ packed into the pipeline inputs. Defaults to ``None``
+ bbox_format (str): The bbox format indicator. Options are ``'xywh'``
149
+ and ``'xyxy'``. Defaults to ``'xyxy'``
150
+
151
+ Returns:
152
+ List[:obj:`PoseDataSample`]: The inference results. Specifically, the
153
+ predicted keypoints and scores are saved at
154
+ ``data_sample.pred_instances.keypoints`` and
155
+ ``data_sample.pred_instances.keypoint_scores``.
156
+ """
157
+ scope = model.cfg.get('default_scope', 'mmpose')
158
+ if scope is not None:
159
+ init_default_scope(scope)
160
+ pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
161
+
162
+ if bboxes is None or len(bboxes) == 0:
163
+ # get bbox from the image size
164
+ if isinstance(img, str):
165
+ w, h = Image.open(img).size
166
+ else:
167
+ h, w = img.shape[:2]
168
+
169
+ bboxes = np.array([[0, 0, w, h]], dtype=np.float32)
170
+ else:
171
+ if isinstance(bboxes, list):
172
+ bboxes = np.array(bboxes)
173
+
174
+ assert bbox_format in {'xyxy', 'xywh'}, \
175
+ f'Invalid bbox_format "{bbox_format}".'
176
+
177
+ if bbox_format == 'xywh':
178
+ bboxes = bbox_xywh2xyxy(bboxes)
179
+
180
+ if masks is None or len(masks) == 0:
181
+ masks = np.zeros((bboxes.shape[0], img.shape[0], img.shape[1]),
182
+ dtype=np.uint8)
183
+
184
+ # Masks are expected in polygon format
185
+ poly_masks = []
186
+ for mask in masks:
187
+ if np.sum(mask) == 0:
188
+ poly_masks.append(None)
189
+ else:
190
+ contours, _ = cv2.findContours((mask*255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
191
+ polygons = [contour.flatten() for contour in contours if len(contour) > 3]
192
+ poly_masks.append(polygons if polygons else None)
193
+
194
+ # construct batch data samples
195
+ data_list = []
196
+ for bbox, pmask in zip(bboxes, poly_masks):
197
+ if isinstance(img, str):
198
+ data_info = dict(img_path=img)
199
+ else:
200
+ data_info = dict(img=img)
201
+ data_info['bbox'] = bbox[None] # shape (1, 4)
202
+ data_info['segmentation'] = pmask
203
+ data_info['bbox_score'] = np.ones(1, dtype=np.float32) # shape (1,)
204
+ data_info.update(model.dataset_meta)
205
+ data_list.append(pipeline(data_info))
206
+
207
+ if data_list:
208
+ # collate data list into a batch, which is a dict with following keys:
209
+ # batch['inputs']: a list of input images
210
+ # batch['data_samples']: a list of :obj:`PoseDataSample`
211
+ batch = pseudo_collate(data_list)
212
+ with torch.no_grad():
213
+ results = model.test_step(batch)
214
+ else:
215
+ results = []
216
+
217
+ return results
218
+
219
+
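A hedged usage sketch combining init_model and the extended inference_topdown above. The config/checkpoint paths are placeholders; a loaded image array (not a file path) is passed because, when masks is omitted, the function derives an empty per-bbox mask from the image shape.

import numpy as np

# Hypothetical config/checkpoint paths -- replace with a real mmpose model.
model = init_model("configs/my_pose_config.py", "checkpoints/my_pose.pth", device="cpu")

img = np.zeros((480, 640, 3), dtype=np.uint8)               # BGR image array
bboxes = np.array([[50, 40, 200, 300]], dtype=np.float32)   # one person, xyxy
results = inference_topdown(model, img, bboxes, bbox_format="xyxy")
keypoints = results[0].pred_instances.keypoints        # typically (1, K, 2)
scores = results[0].pred_instances.keypoint_scores     # typically (1, K)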
220
+ def inference_bottomup(model: nn.Module, img: Union[np.ndarray, str]):
221
+ """Inference image with a bottom-up pose estimator.
222
+
223
+ Args:
224
+ model (nn.Module): The bottom-up pose estimator
225
+ img (np.ndarray | str): The loaded image or image file to inference
226
+
227
+ Returns:
228
+ List[:obj:`PoseDataSample`]: The inference results. Specifically, the
229
+ predicted keypoints and scores are saved at
230
+ ``data_sample.pred_instances.keypoints`` and
231
+ ``data_sample.pred_instances.keypoint_scores``.
232
+ """
233
+ pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
234
+
235
+ # prepare data batch
236
+ if isinstance(img, str):
237
+ data_info = dict(img_path=img)
238
+ else:
239
+ data_info = dict(img=img)
240
+ data_info.update(model.dataset_meta)
241
+ data = pipeline(data_info)
242
+ batch = pseudo_collate([data])
243
+
244
+ with torch.no_grad():
245
+ results = model.test_step(batch)
246
+
247
+ return results
248
+
249
+
250
+ def collect_multi_frames(video, frame_id, indices, online=False):
251
+ """Collect multi frames from the video.
252
+
253
+ Args:
254
+ video (mmcv.VideoReader): A VideoReader of the input video file.
255
+ frame_id (int): index of the current frame
256
+ indices (list(int)): index offsets of the frames to collect
257
+ online (bool): inference mode, if set to True, can not use future
258
+ frame information.
259
+
260
+ Returns:
261
+ list(ndarray): multi frames collected from the input video file.
262
+ """
263
+ num_frames = len(video)
264
+ frames = []
265
+ # put the current frame at first
266
+ frames.append(video[frame_id])
267
+ # use multi frames for inference
268
+ for idx in indices:
269
+ # skip current frame
270
+ if idx == 0:
271
+ continue
272
+ support_idx = frame_id + idx
273
+ # online mode, can not use future frame information
274
+ if online:
275
+ support_idx = np.clip(support_idx, 0, frame_id)
276
+ else:
277
+ support_idx = np.clip(support_idx, 0, num_frames - 1)
278
+ frames.append(video[support_idx])
279
+
280
+ return frames
mmpose/apis/inference_3d.py ADDED
@@ -0,0 +1,360 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import numpy as np
3
+ import torch
4
+ from mmengine.dataset import Compose, pseudo_collate
5
+ from mmengine.registry import init_default_scope
6
+ from mmengine.structures import InstanceData
7
+
8
+ from mmpose.structures import PoseDataSample
9
+
10
+
11
+ def convert_keypoint_definition(keypoints, pose_det_dataset,
12
+ pose_lift_dataset):
13
+ """Convert pose det dataset keypoints definition to pose lifter dataset
14
+ keypoints definition, so that they are compatible with the definitions
15
+ required for 3D pose lifting.
16
+
17
+ Args:
18
+ keypoints (ndarray[N, K, 2 or 3]): 2D keypoints to be transformed.
19
+ pose_det_dataset, (str): Name of the dataset for 2D pose detector.
20
+ pose_lift_dataset (str): Name of the dataset for pose lifter model.
21
+
22
+ Returns:
23
+ ndarray[K, 2 or 3]: the transformed 2D keypoints.
24
+ """
25
+ assert pose_lift_dataset in [
26
+ 'h36m', 'h3wb'], '`pose_lift_dataset` should be ' \
27
+ f'`h36m`, but got {pose_lift_dataset}.'
28
+
29
+ keypoints_new = np.zeros((keypoints.shape[0], 17, keypoints.shape[2]),
30
+ dtype=keypoints.dtype)
31
+ if pose_lift_dataset in ['h36m', 'h3wb']:
32
+ if pose_det_dataset in ['h36m', 'coco_wholebody']:
33
+ keypoints_new = keypoints
34
+ elif pose_det_dataset in ['coco', 'posetrack18']:
35
+ # pelvis (root) is in the middle of l_hip and r_hip
36
+ keypoints_new[:, 0] = (keypoints[:, 11] + keypoints[:, 12]) / 2
37
+ # thorax is in the middle of l_shoulder and r_shoulder
38
+ keypoints_new[:, 8] = (keypoints[:, 5] + keypoints[:, 6]) / 2
39
+ # spine is in the middle of thorax and pelvis
40
+ keypoints_new[:,
41
+ 7] = (keypoints_new[:, 0] + keypoints_new[:, 8]) / 2
42
+ # in COCO, head is in the middle of l_eye and r_eye
43
+ # in PoseTrack18, head is in the middle of head_bottom and head_top
44
+ keypoints_new[:, 10] = (keypoints[:, 1] + keypoints[:, 2]) / 2
45
+ # rearrange other keypoints
46
+ keypoints_new[:, [1, 2, 3, 4, 5, 6, 9, 11, 12, 13, 14, 15, 16]] = \
47
+ keypoints[:, [12, 14, 16, 11, 13, 15, 0, 5, 7, 9, 6, 8, 10]]
48
+ elif pose_det_dataset in ['aic']:
49
+ # pelvis (root) is in the middle of l_hip and r_hip
50
+ keypoints_new[:, 0] = (keypoints[:, 9] + keypoints[:, 6]) / 2
51
+ # thorax is in the middle of l_shoulder and r_shoulder
52
+ keypoints_new[:, 8] = (keypoints[:, 3] + keypoints[:, 0]) / 2
53
+ # spine is in the middle of thorax and pelvis
54
+ keypoints_new[:,
55
+ 7] = (keypoints_new[:, 0] + keypoints_new[:, 8]) / 2
56
+ # neck base (top end of neck) is 1/4 the way from
57
+ # neck (bottom end of neck) to head top
58
+ keypoints_new[:, 9] = (3 * keypoints[:, 13] + keypoints[:, 12]) / 4
59
+ # head (spherical centre of head) is 7/12 the way from
60
+ # neck (bottom end of neck) to head top
61
+ keypoints_new[:, 10] = (5 * keypoints[:, 13] +
62
+ 7 * keypoints[:, 12]) / 12
63
+
64
+ keypoints_new[:, [1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16]] = \
65
+ keypoints[:, [6, 7, 8, 9, 10, 11, 3, 4, 5, 0, 1, 2]]
66
+ elif pose_det_dataset in ['crowdpose']:
67
+ # pelvis (root) is in the middle of l_hip and r_hip
68
+ keypoints_new[:, 0] = (keypoints[:, 6] + keypoints[:, 7]) / 2
69
+ # thorax is in the middle of l_shoulder and r_shoulder
70
+ keypoints_new[:, 8] = (keypoints[:, 0] + keypoints[:, 1]) / 2
71
+ # spine is in the middle of thorax and pelvis
72
+ keypoints_new[:,
73
+ 7] = (keypoints_new[:, 0] + keypoints_new[:, 8]) / 2
74
+ # neck base (top end of neck) is 1/4 the way from
75
+ # neck (bottom end of neck) to head top
76
+ keypoints_new[:, 9] = (3 * keypoints[:, 13] + keypoints[:, 12]) / 4
77
+ # head (spherical centre of head) is 7/12 the way from
78
+ # neck (bottom end of neck) to head top
79
+ keypoints_new[:, 10] = (5 * keypoints[:, 13] +
80
+ 7 * keypoints[:, 12]) / 12
81
+
82
+ keypoints_new[:, [1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16]] = \
83
+ keypoints[:, [7, 9, 11, 6, 8, 10, 0, 2, 4, 1, 3, 5]]
84
+ else:
85
+ raise NotImplementedError(
86
+ f'unsupported conversion between {pose_lift_dataset} and '
87
+ f'{pose_det_dataset}')
88
+
89
+ return keypoints_new
90
+
91
+
92
+ def extract_pose_sequence(pose_results, frame_idx, causal, seq_len, step=1):
93
+ """Extract the target frame from 2D pose results, and pad the sequence to a
94
+ fixed length.
95
+
96
+ Args:
97
+ pose_results (List[List[:obj:`PoseDataSample`]]): Multi-frame pose
98
+ detection results stored in a list.
99
+ frame_idx (int): The index of the frame in the original video.
100
+ causal (bool): If True, the target frame is the last frame in
101
+ a sequence. Otherwise, the target frame is in the middle of
102
+ a sequence.
103
+ seq_len (int): The number of frames in the input sequence.
104
+ step (int): Step size to extract frames from the video.
105
+
106
+ Returns:
107
+ List[List[:obj:`PoseDataSample`]]: Multi-frame pose detection results
108
+ stored in a nested list with a length of seq_len.
109
+ """
110
+ if causal:
111
+ frames_left = seq_len - 1
112
+ frames_right = 0
113
+ else:
114
+ frames_left = (seq_len - 1) // 2
115
+ frames_right = frames_left
116
+ num_frames = len(pose_results)
117
+
118
+ # get the padded sequence
119
+ pad_left = max(0, frames_left - frame_idx // step)
120
+ pad_right = max(0, frames_right - (num_frames - 1 - frame_idx) // step)
121
+ start = max(frame_idx % step, frame_idx - frames_left * step)
122
+ end = min(num_frames - (num_frames - 1 - frame_idx) % step,
123
+ frame_idx + frames_right * step + 1)
124
+ pose_results_seq = [pose_results[0]] * pad_left + \
125
+ pose_results[start:end:step] + [pose_results[-1]] * pad_right
126
+ return pose_results_seq
127
+
128
+
129
+ def collate_pose_sequence(pose_results_2d,
130
+ with_track_id=True,
131
+ target_frame=-1):
132
+ """Reorganize multi-frame pose detection results into individual pose
133
+ sequences.
134
+
135
+ Note:
136
+ - The temporal length of the pose detection results: T
137
+ - The number of the person instances: N
138
+ - The number of the keypoints: K
139
+ - The channel number of each keypoint: C
140
+
141
+ Args:
142
+ pose_results_2d (List[List[:obj:`PoseDataSample`]]): Multi-frame pose
143
+ detection results stored in a nested list. Each element of the
144
+ outer list is the pose detection results of a single frame, and
145
+ each element of the inner list is the pose information of one
146
+ person, which contains:
147
+
148
+ - keypoints (ndarray[K, 2 or 3]): x, y, [score]
149
+ - track_id (int): unique id of each person, required when
150
+ ``with_track_id==True```
151
+
152
+ with_track_id (bool): If True, the element in pose_results is expected
153
+ to contain "track_id", which will be used to gather the pose
154
+ sequence of a person from multiple frames. Otherwise, the pose
155
+ results in each frame are expected to have a consistent number and
156
+ order of identities. Default is True.
157
+ target_frame (int): The index of the target frame. Default: -1.
158
+
159
+ Returns:
160
+ List[:obj:`PoseDataSample`]: Indivisual pose sequence in with length N.
161
+ """
162
+ T = len(pose_results_2d)
163
+ assert T > 0
164
+
165
+ target_frame = (T + target_frame) % T # convert negative index to positive
166
+
167
+ N = len(
168
+ pose_results_2d[target_frame]) # use identities in the target frame
169
+ if N == 0:
170
+ return []
171
+
172
+ B, K, C = pose_results_2d[target_frame][0].pred_instances.keypoints.shape
173
+
174
+ track_ids = None
175
+ if with_track_id:
176
+ track_ids = [res.track_id for res in pose_results_2d[target_frame]]
177
+
178
+ pose_sequences = []
179
+ for idx in range(N):
180
+ pose_seq = PoseDataSample()
181
+ pred_instances = InstanceData()
182
+
183
+ gt_instances = pose_results_2d[target_frame][idx].gt_instances.clone()
184
+ pred_instances = pose_results_2d[target_frame][
185
+ idx].pred_instances.clone()
186
+ pose_seq.pred_instances = pred_instances
187
+ pose_seq.gt_instances = gt_instances
188
+
189
+ if not with_track_id:
190
+ pose_seq.pred_instances.keypoints = np.stack([
191
+ frame[idx].pred_instances.keypoints
192
+ for frame in pose_results_2d
193
+ ],
194
+ axis=1)
195
+ else:
196
+ keypoints = np.zeros((B, T, K, C), dtype=np.float32)
197
+ keypoints[:, target_frame] = pose_results_2d[target_frame][
198
+ idx].pred_instances.keypoints
199
+ # find the left most frame containing track_ids[idx]
200
+ for frame_idx in range(target_frame - 1, -1, -1):
201
+ contains_idx = False
202
+ for res in pose_results_2d[frame_idx]:
203
+ if res.track_id == track_ids[idx]:
204
+ keypoints[:, frame_idx] = res.pred_instances.keypoints
205
+ contains_idx = True
206
+ break
207
+ if not contains_idx:
208
+ # replicate the left most frame
209
+ keypoints[:, :frame_idx + 1] = keypoints[:, frame_idx + 1]
210
+ break
211
+ # find the right most frame containing track_idx[idx]
212
+ for frame_idx in range(target_frame + 1, T):
213
+ contains_idx = False
214
+ for res in pose_results_2d[frame_idx]:
215
+ if res.track_id == track_ids[idx]:
216
+ keypoints[:, frame_idx] = res.pred_instances.keypoints
217
+ contains_idx = True
218
+ break
219
+ if not contains_idx:
220
+ # replicate the right most frame
221
+ keypoints[:, frame_idx + 1:] = keypoints[:, frame_idx]
222
+ break
223
+ pose_seq.pred_instances.set_field(keypoints, 'keypoints')
224
+ pose_sequences.append(pose_seq)
225
+
226
+ return pose_sequences
227
+
228
+
229
+ def inference_pose_lifter_model(model,
230
+ pose_results_2d,
231
+ with_track_id=True,
232
+ image_size=None,
233
+ norm_pose_2d=False):
234
+ """Inference 3D pose from 2D pose sequences using a pose lifter model.
235
+
236
+ Args:
237
+ model (nn.Module): The loaded pose lifter model
238
+ pose_results_2d (List[List[:obj:`PoseDataSample`]]): The 2D pose
239
+ sequences stored in a nested list.
240
+ with_track_id: If True, the element in pose_results_2d is expected to
241
+ contain "track_id", which will be used to gather the pose sequence
242
+ of a person from multiple frames. Otherwise, the pose results in
243
+ each frame are expected to have a consistent number and order of
244
+ identities. Default is True.
245
+ image_size (tuple|list): image width, image height. If None, image size
246
+ will not be contained in dict ``data``.
247
+ norm_pose_2d (bool): If True, scale the bbox (along with the 2D
248
+ pose) to the average bbox scale of the dataset, and move the bbox
249
+ (along with the 2D pose) to the average bbox center of the dataset.
250
+
251
+ Returns:
252
+ List[:obj:`PoseDataSample`]: 3D pose inference results. Specifically,
253
+ the predicted keypoints and scores are saved at
254
+ ``data_sample.pred_instances.keypoints_3d``.
255
+ """
256
+ init_default_scope(model.cfg.get('default_scope', 'mmpose'))
257
+ pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
258
+
259
+ causal = model.cfg.test_dataloader.dataset.get('causal', False)
260
+ target_idx = -1 if causal else len(pose_results_2d) // 2
261
+
262
+ dataset_info = model.dataset_meta
263
+ if dataset_info is not None:
264
+ if 'stats_info' in dataset_info:
265
+ bbox_center = dataset_info['stats_info']['bbox_center']
266
+ bbox_scale = dataset_info['stats_info']['bbox_scale']
267
+ else:
268
+ if norm_pose_2d:
269
+ # compute the average bbox center and scale from the
270
+ # datasamples in pose_results_2d
271
+ bbox_center = np.zeros((1, 2), dtype=np.float32)
272
+ bbox_scale = 0
273
+ num_bbox = 0
274
+ for pose_res in pose_results_2d:
275
+ for data_sample in pose_res:
276
+ for bbox in data_sample.pred_instances.bboxes:
277
+ bbox_center += np.array([[(bbox[0] + bbox[2]) / 2,
278
+ (bbox[1] + bbox[3]) / 2]
279
+ ])
280
+ bbox_scale += max(bbox[2] - bbox[0],
281
+ bbox[3] - bbox[1])
282
+ num_bbox += 1
283
+ bbox_center /= num_bbox
284
+ bbox_scale /= num_bbox
285
+ else:
286
+ bbox_center = None
287
+ bbox_scale = None
288
+
289
+ pose_results_2d_copy = []
290
+ for i, pose_res in enumerate(pose_results_2d):
291
+ pose_res_copy = []
292
+ for j, data_sample in enumerate(pose_res):
293
+ data_sample_copy = PoseDataSample()
294
+ data_sample_copy.gt_instances = data_sample.gt_instances.clone()
295
+ data_sample_copy.pred_instances = data_sample.pred_instances.clone(
296
+ )
297
+ data_sample_copy.track_id = data_sample.track_id
298
+ kpts = data_sample.pred_instances.keypoints
299
+ bboxes = data_sample.pred_instances.bboxes
300
+ keypoints = []
301
+ for k in range(len(kpts)):
302
+ kpt = kpts[k]
303
+ if norm_pose_2d:
304
+ bbox = bboxes[k]
305
+ center = np.array([[(bbox[0] + bbox[2]) / 2,
306
+ (bbox[1] + bbox[3]) / 2]])
307
+ scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
308
+ keypoints.append((kpt[:, :2] - center) / scale *
309
+ bbox_scale + bbox_center)
310
+ else:
311
+ keypoints.append(kpt[:, :2])
312
+ data_sample_copy.pred_instances.set_field(
313
+ np.array(keypoints), 'keypoints')
314
+ pose_res_copy.append(data_sample_copy)
315
+ pose_results_2d_copy.append(pose_res_copy)
316
+
317
+ pose_sequences_2d = collate_pose_sequence(pose_results_2d_copy,
318
+ with_track_id, target_idx)
319
+
320
+ if not pose_sequences_2d:
321
+ return []
322
+
323
+ data_list = []
324
+ for i, pose_seq in enumerate(pose_sequences_2d):
325
+ data_info = dict()
326
+
327
+ keypoints_2d = pose_seq.pred_instances.keypoints
328
+ keypoints_2d = np.squeeze(
329
+ keypoints_2d, axis=0) if keypoints_2d.ndim == 4 else keypoints_2d
330
+
331
+ T, K, C = keypoints_2d.shape
332
+
333
+ data_info['keypoints'] = keypoints_2d
334
+ data_info['keypoints_visible'] = np.ones((
335
+ T,
336
+ K,
337
+ ), dtype=np.float32)
338
+ data_info['lifting_target'] = np.zeros((1, K, 3), dtype=np.float32)
339
+ data_info['factor'] = np.zeros((T, ), dtype=np.float32)
340
+ data_info['lifting_target_visible'] = np.ones((1, K, 1),
341
+ dtype=np.float32)
342
+
343
+ if image_size is not None:
344
+ assert len(image_size) == 2
345
+ data_info['camera_param'] = dict(w=image_size[0], h=image_size[1])
346
+
347
+ data_info.update(model.dataset_meta)
348
+ data_list.append(pipeline(data_info))
349
+
350
+ if data_list:
351
+ # collate data list into a batch, which is a dict with following keys:
352
+ # batch['inputs']: a list of input images
353
+ # batch['data_samples']: a list of :obj:`PoseDataSample`
354
+ batch = pseudo_collate(data_list)
355
+ with torch.no_grad():
356
+ results = model.test_step(batch)
357
+ else:
358
+ results = []
359
+
360
+ return results
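For orientation, here is a minimal usage sketch of this 2D-to-3D lifting routine. The helper names (`init_model`, `inference_topdown`) and the enclosing function name (`inference_pose_lifter_model`) are assumed from the `mmpose.apis` conventions and are not shown in this hunk; all paths are placeholders.

# Hypothetical driver: lift a short window of per-frame 2D predictions to 3D.
from mmpose.apis import inference_pose_lifter_model, inference_topdown, init_model

pose2d_model = init_model('path/to/pose2d_config.py', 'path/to/pose2d.pth')
lifter = init_model('path/to/pose_lifter_config.py', 'path/to/lifter.pth')

# One list of PoseDataSample per frame; for a non-causal dataset the middle
# frame of the window is the lifting target (target_idx = len(window) // 2).
pose_results_2d = [
    inference_topdown(pose2d_model, f'frames/{i:04d}.jpg') for i in range(27)
]

results_3d = inference_pose_lifter_model(
    lifter,
    pose_results_2d,
    image_size=(1920, 1080),  # fills camera_param with w and h
    norm_pose_2d=True)        # re-normalize 2D poses with dataset bbox statistics
# results_3d[0].pred_instances.keypoints_3d holds the lifted keypoints.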
mmpose/apis/inference_tracking.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import warnings
3
+
4
+ import numpy as np
5
+
6
+ from mmpose.evaluation.functional.nms import oks_iou
7
+
8
+
9
+ def _compute_iou(bboxA, bboxB):
10
+ """Compute the Intersection over Union (IoU) between two boxes .
11
+
12
+ Args:
13
+ bboxA (list): The first bbox info (left, top, right, bottom, score).
14
+ bboxB (list): The second bbox info (left, top, right, bottom, score).
15
+
16
+ Returns:
17
+ float: The IoU value.
18
+ """
19
+
20
+ x1 = max(bboxA[0], bboxB[0])
21
+ y1 = max(bboxA[1], bboxB[1])
22
+ x2 = min(bboxA[2], bboxB[2])
23
+ y2 = min(bboxA[3], bboxB[3])
24
+
25
+ inter_area = max(0, x2 - x1) * max(0, y2 - y1)
26
+
27
+ bboxA_area = (bboxA[2] - bboxA[0]) * (bboxA[3] - bboxA[1])
28
+ bboxB_area = (bboxB[2] - bboxB[0]) * (bboxB[3] - bboxB[1])
29
+ union_area = float(bboxA_area + bboxB_area - inter_area)
30
+ if union_area == 0:
31
+ union_area = 1e-5
32
+ warnings.warn('union_area=0 is unexpected')
33
+
34
+ iou = inter_area / union_area
35
+
36
+ return iou
37
+
38
+
39
+ def _track_by_iou(res, results_last, thr):
40
+ """Get track id using IoU tracking greedily."""
41
+
42
+ bbox = list(np.squeeze(res.pred_instances.bboxes, axis=0))
43
+
44
+ max_iou_score = -1
45
+ max_index = -1
46
+ match_result = {}
47
+ for index, res_last in enumerate(results_last):
48
+ bbox_last = list(np.squeeze(res_last.pred_instances.bboxes, axis=0))
49
+
50
+ iou_score = _compute_iou(bbox, bbox_last)
51
+ if iou_score > max_iou_score:
52
+ max_iou_score = iou_score
53
+ max_index = index
54
+
55
+ if max_iou_score > thr:
56
+ track_id = results_last[max_index].track_id
57
+ match_result = results_last[max_index]
58
+ del results_last[max_index]
59
+ else:
60
+ track_id = -1
61
+
62
+ return track_id, results_last, match_result
63
+
64
+
65
+ def _track_by_oks(res, results_last, thr, sigmas=None):
66
+ """Get track id using OKS tracking greedily."""
67
+ keypoint = np.concatenate((res.pred_instances.keypoints,
68
+ res.pred_instances.keypoint_scores[:, :, None]),
69
+ axis=2)
70
+ keypoint = np.squeeze(keypoint, axis=0).reshape((-1))
71
+ area = np.squeeze(res.pred_instances.areas, axis=0)
72
+ max_index = -1
73
+ match_result = {}
74
+
75
+ if len(results_last) == 0:
76
+ return -1, results_last, match_result
77
+
78
+ keypoints_last = np.array([
79
+ np.squeeze(
80
+ np.concatenate(
81
+ (res_last.pred_instances.keypoints,
82
+ res_last.pred_instances.keypoint_scores[:, :, None]),
83
+ axis=2),
84
+ axis=0).reshape((-1)) for res_last in results_last
85
+ ])
86
+ area_last = np.array([
87
+ np.squeeze(res_last.pred_instances.areas, axis=0)
88
+ for res_last in results_last
89
+ ])
90
+
91
+ oks_score = oks_iou(
92
+ keypoint, keypoints_last, area, area_last, sigmas=sigmas)
93
+
94
+ max_index = np.argmax(oks_score)
95
+
96
+ if oks_score[max_index] > thr:
97
+ track_id = results_last[max_index].track_id
98
+ match_result = results_last[max_index]
99
+ del results_last[max_index]
100
+ else:
101
+ track_id = -1
102
+
103
+ return track_id, results_last, match_result
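To show how these greedy matchers are meant to be used, here is a sketch of a per-frame tracking loop built on the helpers above (the outer loop and variable names are illustrative and not part of this file):

# Hypothetical tracking loop: match each frame against the previous one.
next_id = 0
results_last = []
for frame_results in all_frame_results:  # assumed: list of per-frame result lists
    remaining = list(results_last)
    for res in frame_results:
        track_id, remaining, _ = _track_by_iou(res, remaining, thr=0.3)
        if track_id == -1:          # no overlap above the threshold: new identity
            track_id = next_id
            next_id += 1
        res.track_id = track_id
    results_last = frame_results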
mmpose/apis/inferencers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .hand3d_inferencer import Hand3DInferencer
3
+ from .mmpose_inferencer import MMPoseInferencer
4
+ from .pose2d_inferencer import Pose2DInferencer
5
+ from .pose3d_inferencer import Pose3DInferencer
6
+ from .utils import get_model_aliases
7
+
8
+ __all__ = [
9
+ 'Pose2DInferencer', 'MMPoseInferencer', 'get_model_aliases',
10
+ 'Pose3DInferencer', 'Hand3DInferencer'
11
+ ]
mmpose/apis/inferencers/base_mmpose_inferencer.py ADDED
@@ -0,0 +1,691 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import inspect
3
+ import logging
4
+ import mimetypes
5
+ import os
6
+ from collections import defaultdict
7
+ from typing import (Callable, Dict, Generator, Iterable, List, Optional,
8
+ Sequence, Tuple, Union)
9
+
10
+ import cv2
11
+ import mmcv
12
+ import mmengine
13
+ import numpy as np
14
+ import torch.nn as nn
15
+ from mmengine.config import Config, ConfigDict
16
+ from mmengine.dataset import Compose
17
+ from mmengine.fileio import (get_file_backend, isdir, join_path,
18
+ list_dir_or_file)
19
+ from mmengine.infer.infer import BaseInferencer, ModelType
20
+ from mmengine.logging import print_log
21
+ from mmengine.registry import init_default_scope
22
+ from mmengine.runner.checkpoint import _load_checkpoint_to_model
23
+ from mmengine.structures import InstanceData
24
+ from mmengine.utils import mkdir_or_exist
25
+ from rich.progress import track
26
+
27
+ from mmpose.apis.inference import dataset_meta_from_config
28
+ from mmpose.registry import DATASETS
29
+ from mmpose.structures import PoseDataSample, split_instances
30
+ from .utils import default_det_models
31
+
32
+ try:
33
+ from mmdet.apis.det_inferencer import DetInferencer
34
+ has_mmdet = True
35
+ except (ImportError, ModuleNotFoundError):
36
+ has_mmdet = False
37
+
38
+ InstanceList = List[InstanceData]
39
+ InputType = Union[str, np.ndarray]
40
+ InputsType = Union[InputType, Sequence[InputType]]
41
+ PredType = Union[InstanceData, InstanceList]
42
+ ImgType = Union[np.ndarray, Sequence[np.ndarray]]
43
+ ConfigType = Union[Config, ConfigDict]
44
+ ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
45
+
46
+
47
+ class BaseMMPoseInferencer(BaseInferencer):
48
+ """The base class for MMPose inferencers."""
49
+
50
+ preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'}
51
+ forward_kwargs: set = set()
52
+ visualize_kwargs: set = {
53
+ 'return_vis', 'show', 'wait_time', 'draw_bbox', 'radius', 'thickness',
54
+ 'kpt_thr', 'vis_out_dir', 'black_background'
55
+ }
56
+ postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
57
+
58
+ def __init__(self,
59
+ model: Union[ModelType, str, None] = None,
60
+ weights: Optional[str] = None,
61
+ device: Optional[str] = None,
62
+ scope: Optional[str] = None,
63
+ show_progress: bool = False) -> None:
64
+ super().__init__(
65
+ model, weights, device, scope, show_progress=show_progress)
66
+
67
+ def _init_detector(
68
+ self,
69
+ det_model: Optional[Union[ModelType, str]] = None,
70
+ det_weights: Optional[str] = None,
71
+ det_cat_ids: Optional[Union[int, Tuple]] = None,
72
+ device: Optional[str] = None,
73
+ ):
74
+ object_type = DATASETS.get(self.cfg.dataset_type).__module__.split(
75
+ 'datasets.')[-1].split('.')[0].lower()
76
+
77
+ if det_model in ('whole_image', 'whole-image') or \
78
+ (det_model is None and
79
+ object_type not in default_det_models):
80
+ self.detector = None
81
+
82
+ else:
83
+ det_scope = 'mmdet'
84
+ if det_model is None:
85
+ det_info = default_det_models[object_type]
86
+ det_model, det_weights, det_cat_ids = det_info[
87
+ 'model'], det_info['weights'], det_info['cat_ids']
88
+ elif os.path.exists(det_model):
89
+ det_cfg = Config.fromfile(det_model)
90
+ det_scope = det_cfg.default_scope
91
+
92
+ if has_mmdet:
93
+ det_kwargs = dict(
94
+ model=det_model,
95
+ weights=det_weights,
96
+ device=device,
97
+ scope=det_scope,
98
+ )
99
+ # for compatibility with low version of mmdet
100
+ if 'show_progress' in inspect.signature(
101
+ DetInferencer).parameters:
102
+ det_kwargs['show_progress'] = False
103
+
104
+ self.detector = DetInferencer(**det_kwargs)
105
+ else:
106
+ raise RuntimeError(
107
+ 'MMDetection (v3.0.0 or above) is required to build '
108
+ 'inferencers for top-down pose estimation models.')
109
+
110
+ if isinstance(det_cat_ids, (tuple, list)):
111
+ self.det_cat_ids = det_cat_ids
112
+ else:
113
+ self.det_cat_ids = (det_cat_ids, )
114
+
115
+ def _load_weights_to_model(self, model: nn.Module,
116
+ checkpoint: Optional[dict],
117
+ cfg: Optional[ConfigType]) -> None:
118
+ """Loading model weights and meta information from cfg and checkpoint.
119
+
120
+ Subclasses could override this method to load extra meta information
121
+ from ``checkpoint`` and ``cfg`` to model.
122
+
123
+ Args:
124
+ model (nn.Module): Model to load weights and meta information.
125
+ checkpoint (dict, optional): The loaded checkpoint.
126
+ cfg (Config or ConfigDict, optional): The loaded config.
127
+ """
128
+ if checkpoint is not None:
129
+ _load_checkpoint_to_model(model, checkpoint)
130
+ checkpoint_meta = checkpoint.get('meta', {})
131
+ # save the dataset_meta in the model for convenience
132
+ if 'dataset_meta' in checkpoint_meta:
133
+ # mmpose 1.x
134
+ model.dataset_meta = checkpoint_meta['dataset_meta']
135
+ else:
136
+ print_log(
137
+ 'dataset_meta are not saved in the checkpoint\'s '
138
+ 'meta data, load via config.',
139
+ logger='current',
140
+ level=logging.WARNING)
141
+ model.dataset_meta = dataset_meta_from_config(
142
+ cfg, dataset_mode='train')
143
+ else:
144
+ print_log(
145
+ 'Checkpoint is not loaded, and the inference '
146
+ 'result is calculated by the randomly initialized '
147
+ 'model!',
148
+ logger='current',
149
+ level=logging.WARNING)
150
+ model.dataset_meta = dataset_meta_from_config(
151
+ cfg, dataset_mode='train')
152
+
153
+ def _inputs_to_list(self, inputs: InputsType) -> Iterable:
154
+ """Preprocess the inputs to a list.
155
+
156
+ Preprocess inputs to a list according to its type:
157
+
158
+ - list or tuple: return inputs
159
+ - str:
160
+ - Directory path: return all files in the directory
161
+ - other cases: return a list containing the string. The string
162
+ could be a path to file, a url or other types of string
163
+ according to the task.
164
+
165
+ Args:
166
+ inputs (InputsType): Inputs for the inferencer.
167
+
168
+ Returns:
169
+ list: List of input for the :meth:`preprocess`.
170
+ """
171
+ self._video_input = False
172
+
173
+ if isinstance(inputs, str):
174
+ backend = get_file_backend(inputs)
175
+ if hasattr(backend, 'isdir') and isdir(inputs):
176
+ # Backends like HttpsBackend do not implement `isdir`, so only
177
+ # those backends that implement `isdir` could accept the
178
+ # inputs as a directory
179
+ filepath_list = [
180
+ join_path(inputs, fname)
181
+ for fname in list_dir_or_file(inputs, list_dir=False)
182
+ ]
183
+ inputs = []
184
+ for filepath in filepath_list:
185
+ input_type = mimetypes.guess_type(filepath)[0].split(
186
+ '/')[0]
187
+ if input_type == 'image':
188
+ inputs.append(filepath)
189
+ inputs.sort()
190
+ else:
191
+ # if inputs is a path to a video file, it will be converted
192
+ # to a list containing separated frame filenames
193
+ input_type = mimetypes.guess_type(inputs)[0].split('/')[0]
194
+ if input_type == 'video':
195
+ self._video_input = True
196
+ video = mmcv.VideoReader(inputs)
197
+ self.video_info = dict(
198
+ fps=video.fps,
199
+ name=os.path.basename(inputs),
200
+ writer=None,
201
+ width=video.width,
202
+ height=video.height,
203
+ predictions=[])
204
+ inputs = video
205
+ elif input_type == 'image':
206
+ inputs = [inputs]
207
+ else:
208
+ raise ValueError(f'Expected input to be an image, video, '
209
+ f'or folder, but received {inputs} of '
210
+ f'type {input_type}.')
211
+
212
+ elif isinstance(inputs, np.ndarray):
213
+ inputs = [inputs]
214
+
215
+ return inputs
216
+
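As a quick illustration of the accepted input forms listed above (a sketch; `inferencer` stands for any subclass of this base class):

# Hypothetical calls: each of these is normalized by _inputs_to_list.
import numpy as np

inferencer('path/to/image_dir/')   # directory -> sorted list of image paths
inferencer('path/to/image.jpg')    # single image path -> [path]
inferencer('path/to/video.mp4')    # video file -> frames via mmcv.VideoReader
inferencer(np.zeros((480, 640, 3), dtype=np.uint8))  # ndarray -> [ndarray]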
217
+ def _get_webcam_inputs(self, inputs: str) -> Generator:
218
+ """Sets up and returns a generator function that reads frames from a
219
+ webcam input. The generator function returns a new frame each time it
220
+ is iterated over.
221
+
222
+ Args:
223
+ inputs (str): A string describing the webcam input, in the format
224
+ "webcam:id".
225
+
226
+ Returns:
227
+ A generator function that yields frames from the webcam input.
228
+
229
+ Raises:
230
+ ValueError: If the inputs string is not in the expected format.
231
+ """
232
+
233
+ # Ensure the inputs string is in the expected format.
234
+ inputs = inputs.lower()
235
+ assert inputs.startswith('webcam'), f'Expected input to start with ' \
236
+ f'"webcam", but got "{inputs}"'
237
+
238
+ # Parse the camera ID from the inputs string.
239
+ inputs_ = inputs.split(':')
240
+ if len(inputs_) == 1:
241
+ camera_id = 0
242
+ elif len(inputs_) == 2 and str.isdigit(inputs_[1]):
243
+ camera_id = int(inputs_[1])
244
+ else:
245
+ raise ValueError(
246
+ f'Expected webcam input to have format "webcam:id", '
247
+ f'but got "{inputs}"')
248
+
249
+ # Attempt to open the video capture object.
250
+ vcap = cv2.VideoCapture(camera_id)
251
+ if not vcap.isOpened():
252
+ print_log(
253
+ f'Cannot open camera (ID={camera_id})',
254
+ logger='current',
255
+ level=logging.WARNING)
256
+ return []
257
+
258
+ # Set video input flag and metadata.
259
+ self._video_input = True
260
+ (major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.')
261
+ if int(major_ver) < 3:
262
+ fps = vcap.get(cv2.cv.CV_CAP_PROP_FPS)
263
+ width = vcap.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)
264
+ height = vcap.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT)
265
+ else:
266
+ fps = vcap.get(cv2.CAP_PROP_FPS)
267
+ width = vcap.get(cv2.CAP_PROP_FRAME_WIDTH)
268
+ height = vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)
269
+ self.video_info = dict(
270
+ fps=fps,
271
+ name='webcam.mp4',
272
+ writer=None,
273
+ width=width,
274
+ height=height,
275
+ predictions=[])
276
+
277
+ def _webcam_reader() -> Generator:
278
+ while True:
279
+ if cv2.waitKey(5) & 0xFF == 27:
280
+ vcap.release()
281
+ break
282
+
283
+ ret_val, frame = vcap.read()
284
+ if not ret_val:
285
+ break
286
+
287
+ yield frame
288
+
289
+ return _webcam_reader()
290
+
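For completeness, the webcam string documented above is consumed like this (a sketch; pressing ESC releases the capture inside `_webcam_reader`):

# Hypothetical call: stream from the default camera and display results live.
# 'webcam' and 'webcam:1' are also accepted; the camera id defaults to 0.
for result in inferencer('webcam:0', show=True):
    pass  # one result dict per frame until the stream ends or ESC is pressed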
291
+ def _init_pipeline(self, cfg: ConfigType) -> Callable:
292
+ """Initialize the test pipeline.
293
+
294
+ Args:
295
+ cfg (ConfigType): model config path or dict
296
+
297
+ Returns:
298
+ A pipeline to handle various input data, such as ``str``,
299
+ ``np.ndarray``. The returned pipeline will be used to process
300
+ a single data.
301
+ """
302
+ scope = cfg.get('default_scope', 'mmpose')
303
+ if scope is not None:
304
+ init_default_scope(scope)
305
+ return Compose(cfg.test_dataloader.dataset.pipeline)
306
+
307
+ def update_model_visualizer_settings(self, **kwargs):
308
+ """Update the settings of models and visualizer according to inference
309
+ arguments."""
310
+
311
+ pass
312
+
313
+ def preprocess(self,
314
+ inputs: InputsType,
315
+ batch_size: int = 1,
316
+ bboxes: Optional[List] = None,
317
+ bbox_thr: float = 0.3,
318
+ nms_thr: float = 0.3,
319
+ **kwargs):
320
+ """Process the inputs into a model-feedable format.
321
+
322
+ Args:
323
+ inputs (InputsType): Inputs given by user.
324
+ batch_size (int): batch size. Defaults to 1.
325
+ bbox_thr (float): threshold for bounding box detection.
326
+ Defaults to 0.3.
327
+ nms_thr (float): IoU threshold for bounding box NMS.
328
+ Defaults to 0.3.
329
+
330
+ Yields:
331
+ Any: Data processed by the ``pipeline`` and ``collate_fn``.
332
+ List[str or np.ndarray]: List of original inputs in the batch
333
+ """
334
+
335
+ # One-stage pose estimators perform prediction filtering within the
336
+ # head's `predict` method. Here, we set the arguments for filtering
337
+ if self.cfg.model.type == 'BottomupPoseEstimator':
338
+ # 1. init with default arguments
339
+ test_cfg = self.model.head.test_cfg.copy()
340
+ # 2. update the score_thr and nms_thr in the test_cfg of the head
341
+ if 'score_thr' in test_cfg:
342
+ test_cfg['score_thr'] = bbox_thr
343
+ if 'nms_thr' in test_cfg:
344
+ test_cfg['nms_thr'] = nms_thr
345
+ self.model.test_cfg = test_cfg
346
+
347
+ for i, input in enumerate(inputs):
348
+ bbox = bboxes[i] if bboxes else []
349
+ data_infos = self.preprocess_single(
350
+ input,
351
+ index=i,
352
+ bboxes=bbox,
353
+ bbox_thr=bbox_thr,
354
+ nms_thr=nms_thr,
355
+ **kwargs)
356
+ # only supports inference with batch size 1
357
+ yield self.collate_fn(data_infos), [input]
358
+
359
+ def __call__(
360
+ self,
361
+ inputs: InputsType,
362
+ return_datasamples: bool = False,
363
+ batch_size: int = 1,
364
+ out_dir: Optional[str] = None,
365
+ **kwargs,
366
+ ) -> dict:
367
+ """Call the inferencer.
368
+
369
+ Args:
370
+ inputs (InputsType): Inputs for the inferencer.
371
+ return_datasamples (bool): Whether to return results as
372
+ :obj:`BaseDataElement`. Defaults to False.
373
+ batch_size (int): Batch size. Defaults to 1.
374
+ out_dir (str, optional): directory to save visualization
375
+ results and predictions. Will be overridden if vis_out_dir or
376
+ pred_out_dir are given. Defaults to None
377
+ **kwargs: Key words arguments passed to :meth:`preprocess`,
378
+ :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
379
+ Each key in kwargs should be in the corresponding set of
380
+ ``preprocess_kwargs``, ``forward_kwargs``,
381
+ ``visualize_kwargs`` and ``postprocess_kwargs``.
382
+
383
+ Returns:
384
+ dict: Inference and visualization results.
385
+ """
386
+ if out_dir is not None:
387
+ if 'vis_out_dir' not in kwargs:
388
+ kwargs['vis_out_dir'] = f'{out_dir}/visualizations'
389
+ if 'pred_out_dir' not in kwargs:
390
+ kwargs['pred_out_dir'] = f'{out_dir}/predictions'
391
+
392
+ (
393
+ preprocess_kwargs,
394
+ forward_kwargs,
395
+ visualize_kwargs,
396
+ postprocess_kwargs,
397
+ ) = self._dispatch_kwargs(**kwargs)
398
+
399
+ self.update_model_visualizer_settings(**kwargs)
400
+
401
+ # preprocessing
402
+ if isinstance(inputs, str) and inputs.startswith('webcam'):
403
+ inputs = self._get_webcam_inputs(inputs)
404
+ batch_size = 1
405
+ if not visualize_kwargs.get('show', False):
406
+ print_log(
407
+ 'The display mode is turned off, but webcam input requires '
408
+ 'it; it will be enabled automatically.',
409
+ logger='current',
410
+ level=logging.WARNING)
411
+ visualize_kwargs['show'] = True
412
+ else:
413
+ inputs = self._inputs_to_list(inputs)
414
+
415
+ # check the compatibility between inputs/outputs
416
+ if not self._video_input and len(inputs) > 0:
417
+ vis_out_dir = visualize_kwargs.get('vis_out_dir', None)
418
+ if vis_out_dir is not None:
419
+ _, file_extension = os.path.splitext(vis_out_dir)
420
+ assert not file_extension, f'the argument `vis_out_dir` ' \
421
+ f'should be a folder while the input contains multiple ' \
422
+ f'images, but got {vis_out_dir}'
423
+
424
+ if 'bbox_thr' in self.forward_kwargs:
425
+ forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1)
426
+ inputs = self.preprocess(
427
+ inputs, batch_size=batch_size, **preprocess_kwargs)
428
+
429
+ preds = []
430
+
431
+ for proc_inputs, ori_inputs in (track(inputs, description='Inference')
432
+ if self.show_progress else inputs):
433
+ preds = self.forward(proc_inputs, **forward_kwargs)
434
+
435
+ visualization = self.visualize(ori_inputs, preds,
436
+ **visualize_kwargs)
437
+ results = self.postprocess(
438
+ preds,
439
+ visualization,
440
+ return_datasamples=return_datasamples,
441
+ **postprocess_kwargs)
442
+ yield results
443
+
444
+ if self._video_input:
445
+ self._finalize_video_processing(
446
+ postprocess_kwargs.get('pred_out_dir', ''))
447
+
448
+ # In 3D Inferencers, some intermediate results (e.g. 2d keypoints)
449
+ # will be temporarily stored in `self._buffer`. It's essential to
450
+ # clear this information to prevent any interference with subsequent
451
+ # inferences.
452
+ if hasattr(self, '_buffer'):
453
+ self._buffer.clear()
454
+
455
+ def visualize(self,
456
+ inputs: list,
457
+ preds: List[PoseDataSample],
458
+ return_vis: bool = False,
459
+ show: bool = False,
460
+ draw_bbox: bool = False,
461
+ wait_time: float = 0,
462
+ radius: int = 3,
463
+ thickness: int = 1,
464
+ kpt_thr: float = 0.3,
465
+ vis_out_dir: str = '',
466
+ window_name: str = '',
467
+ black_background: bool = False,
468
+ **kwargs) -> List[np.ndarray]:
469
+ """Visualize predictions.
470
+
471
+ Args:
472
+ inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
473
+ preds (Any): Predictions of the model.
474
+ return_vis (bool): Whether to return images with predicted results.
475
+ show (bool): Whether to display the image in a popup window.
476
+ Defaults to False.
477
+ wait_time (float): The interval of show (ms). Defaults to 0
478
+ draw_bbox (bool): Whether to draw the bounding boxes.
479
+ Defaults to False
480
+ radius (int): Keypoint radius for visualization. Defaults to 3
481
+ thickness (int): Link thickness for visualization. Defaults to 1
482
+ kpt_thr (float): The threshold to visualize the keypoints.
483
+ Defaults to 0.3
484
+ vis_out_dir (str, optional): Directory to save visualization
485
+ results w/o predictions. If left as empty, no file will
486
+ be saved. Defaults to ''.
487
+ window_name (str, optional): Title of display window.
488
+ black_background (bool, optional): Whether to plot keypoints on a
489
+ black image instead of the input image. Defaults to False.
490
+
491
+ Returns:
492
+ List[np.ndarray]: Visualization results.
493
+ """
494
+ if (not return_vis) and (not show) and (not vis_out_dir):
495
+ return
496
+
497
+ if getattr(self, 'visualizer', None) is None:
498
+ raise ValueError('Visualization needs the "visualizer" term '
499
+ 'defined in the config, but got None.')
500
+
501
+ self.visualizer.radius = radius
502
+ self.visualizer.line_width = thickness
503
+
504
+ results = []
505
+
506
+ for single_input, pred in zip(inputs, preds):
507
+ if isinstance(single_input, str):
508
+ img = mmcv.imread(single_input, channel_order='rgb')
509
+ elif isinstance(single_input, np.ndarray):
510
+ img = mmcv.bgr2rgb(single_input)
511
+ else:
512
+ raise ValueError('Unsupported input type: '
513
+ f'{type(single_input)}')
514
+ if black_background:
515
+ img = img * 0
516
+
517
+ img_name = os.path.basename(pred.metainfo['img_path'])
518
+ window_name = window_name if window_name else img_name
519
+
520
+ # since visualization and inference utilize the same process,
521
+ # the wait time is reduced when a video input is utilized,
522
+ # thereby eliminating the issue of inference getting stuck.
523
+ wait_time = 1e-5 if self._video_input else wait_time
524
+
525
+ visualization = self.visualizer.add_datasample(
526
+ window_name,
527
+ img,
528
+ pred,
529
+ draw_gt=False,
530
+ draw_bbox=draw_bbox,
531
+ show=show,
532
+ wait_time=wait_time,
533
+ kpt_thr=kpt_thr,
534
+ **kwargs)
535
+ results.append(visualization)
536
+
537
+ if vis_out_dir:
538
+ self.save_visualization(
539
+ visualization,
540
+ vis_out_dir,
541
+ img_name=img_name,
542
+ )
543
+
544
+ if return_vis:
545
+ return results
546
+ else:
547
+ return []
548
+
549
+ def save_visualization(self, visualization, vis_out_dir, img_name=None):
550
+ out_img = mmcv.rgb2bgr(visualization)
551
+ _, file_extension = os.path.splitext(vis_out_dir)
552
+ if file_extension:
553
+ dir_name = os.path.dirname(vis_out_dir)
554
+ file_name = os.path.basename(vis_out_dir)
555
+ else:
556
+ dir_name = vis_out_dir
557
+ file_name = None
558
+ mkdir_or_exist(dir_name)
559
+
560
+ if self._video_input:
561
+
562
+ if self.video_info['writer'] is None:
563
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
564
+ if file_name is None:
565
+ file_name = os.path.basename(self.video_info['name'])
566
+ out_file = join_path(dir_name, file_name)
567
+ self.video_info['output_file'] = out_file
568
+ self.video_info['writer'] = cv2.VideoWriter(
569
+ out_file, fourcc, self.video_info['fps'],
570
+ (visualization.shape[1], visualization.shape[0]))
571
+ self.video_info['writer'].write(out_img)
572
+
573
+ else:
574
+ if file_name is None:
575
+ file_name = img_name if img_name else 'visualization.jpg'
576
+
577
+ out_file = join_path(dir_name, file_name)
578
+ mmcv.imwrite(out_img, out_file)
579
+ print_log(
580
+ f'the output image has been saved at {out_file}',
581
+ logger='current',
582
+ level=logging.INFO)
583
+
584
+ def postprocess(
585
+ self,
586
+ preds: List[PoseDataSample],
587
+ visualization: List[np.ndarray],
588
+ return_datasample=None,
589
+ return_datasamples=False,
590
+ pred_out_dir: str = '',
591
+ ) -> dict:
592
+ """Process the predictions and visualization results from ``forward``
593
+ and ``visualize``.
594
+
595
+ This method should be responsible for the following tasks:
596
+
597
+ 1. Convert datasamples into a json-serializable dict if needed.
598
+ 2. Pack the predictions and visualization results and return them.
599
+ 3. Dump or log the predictions.
600
+
601
+ Args:
602
+ preds (List[Dict]): Predictions of the model.
603
+ visualization (np.ndarray): Visualized predictions.
604
+ return_datasamples (bool): Whether to return results as
605
+ datasamples. Defaults to False.
606
+ pred_out_dir (str): Directory to save the inference results w/o
607
+ visualization. If left as empty, no file will be saved.
608
+ Defaults to ''.
609
+
610
+ Returns:
611
+ dict: Inference and visualization results with key ``predictions``
612
+ and ``visualization``
613
+
614
+ - ``visualization (Any)``: Returned by :meth:`visualize`
615
+ - ``predictions`` (dict or DataSample): Returned by
616
+ :meth:`forward` and processed in :meth:`postprocess`.
617
+ If ``return_datasamples=False``, it usually should be a
618
+ json-serializable dict containing only basic data elements such
619
+ as strings and numbers.
620
+ """
621
+ if return_datasample is not None:
622
+ print_log(
623
+ 'The `return_datasample` argument is deprecated '
624
+ 'and will be removed in future versions. Please '
625
+ 'use `return_datasamples`.',
626
+ logger='current',
627
+ level=logging.WARNING)
628
+ return_datasamples = return_datasample
629
+
630
+ result_dict = defaultdict(list)
631
+
632
+ result_dict['visualization'] = visualization
633
+ for pred in preds:
634
+ if not return_datasamples:
635
+ # convert datasamples to list of instance predictions
636
+ pred = split_instances(pred.pred_instances)
637
+ result_dict['predictions'].append(pred)
638
+
639
+ if pred_out_dir != '':
640
+ for pred, data_sample in zip(result_dict['predictions'], preds):
641
+ if self._video_input:
642
+ # For video or webcam input, predictions for each frame
643
+ # are gathered in the 'predictions' key of 'video_info'
644
+ # dictionary. All frame predictions are then stored into
645
+ # a single file after processing all frames.
646
+ self.video_info['predictions'].append(pred)
647
+ else:
648
+ # For non-video inputs, predictions are stored in separate
649
+ # JSON files. The filename is determined by the basename
650
+ # of the input image path with a '.json' extension. The
651
+ # predictions are then dumped into this file.
652
+ fname = os.path.splitext(
653
+ os.path.basename(
654
+ data_sample.metainfo['img_path']))[0] + '.json'
655
+ mmengine.dump(
656
+ pred, join_path(pred_out_dir, fname), indent=' ')
657
+
658
+ return result_dict
659
+
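To make the returned structure concrete, a sketch of what a caller receives for a single image with `return_datasamples=False` (field names follow `split_instances`; the exact fields, nesting, and values depend on the model):

# Hypothetical result_dict for one input image:
example_result = {
    'visualization': [None],   # rendered np.ndarray when visualization is enabled
    'predictions': [[          # one inner list per input image
        {
            'keypoints': [[320.0, 240.0]],  # (K, 2) coordinates per instance
            'keypoint_scores': [0.91],      # per-keypoint confidence
            'bbox': [[100.0, 80.0, 400.0, 480.0]],
            'bbox_score': 0.98,
        },
    ]],
}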
660
+ def _finalize_video_processing(
661
+ self,
662
+ pred_out_dir: str = '',
663
+ ):
664
+ """Finalize video processing by releasing the video writer and saving
665
+ predictions to a file.
666
+
667
+ This method should be called after completing the video processing. It
668
+ releases the video writer, if it exists, and saves the predictions to a
669
+ JSON file if a prediction output directory is provided.
670
+ """
671
+
672
+ # Release the video writer if it exists
673
+ if self.video_info['writer'] is not None:
674
+ out_file = self.video_info['output_file']
675
+ print_log(
676
+ f'the output video has been saved at {out_file}',
677
+ logger='current',
678
+ level=logging.INFO)
679
+ self.video_info['writer'].release()
680
+
681
+ # Save predictions
682
+ if pred_out_dir:
683
+ fname = os.path.splitext(
684
+ os.path.basename(self.video_info['name']))[0] + '.json'
685
+ predictions = [
686
+ dict(frame_id=i, instances=pred)
687
+ for i, pred in enumerate(self.video_info['predictions'])
688
+ ]
689
+
690
+ mmengine.dump(
691
+ predictions, join_path(pred_out_dir, fname), indent=' ')
mmpose/apis/inferencers/hand3d_inferencer.py ADDED
@@ -0,0 +1,344 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import logging
3
+ import os
4
+ from collections import defaultdict
5
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
6
+
7
+ import mmcv
8
+ import numpy as np
9
+ import torch
10
+ from mmengine.config import Config, ConfigDict
11
+ from mmengine.infer.infer import ModelType
12
+ from mmengine.logging import print_log
13
+ from mmengine.model import revert_sync_batchnorm
14
+ from mmengine.registry import init_default_scope
15
+ from mmengine.structures import InstanceData
16
+
17
+ from mmpose.evaluation.functional import nms
18
+ from mmpose.registry import INFERENCERS
19
+ from mmpose.structures import PoseDataSample, merge_data_samples
20
+ from .base_mmpose_inferencer import BaseMMPoseInferencer
21
+
22
+ InstanceList = List[InstanceData]
23
+ InputType = Union[str, np.ndarray]
24
+ InputsType = Union[InputType, Sequence[InputType]]
25
+ PredType = Union[InstanceData, InstanceList]
26
+ ImgType = Union[np.ndarray, Sequence[np.ndarray]]
27
+ ConfigType = Union[Config, ConfigDict]
28
+ ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
29
+
30
+
31
+ @INFERENCERS.register_module()
32
+ class Hand3DInferencer(BaseMMPoseInferencer):
33
+ """The inferencer for 3D hand pose estimation.
34
+
35
+ Args:
36
+ model (str, optional): Pretrained 3D hand pose estimation algorithm.
37
+ It's the path to the config file or the model name defined in
38
+ metafile. For example, it could be:
39
+
40
+ - model alias, e.g. ``'body'``,
41
+ - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
42
+ - config path
43
+
44
+ Defaults to ``None``.
45
+ weights (str, optional): Path to the checkpoint. If it is not
46
+ specified and "model" is a model name of metafile, the weights
47
+ will be loaded from metafile. Defaults to None.
48
+ device (str, optional): Device to run inference. If None, the
49
+ available device will be automatically used. Defaults to None.
50
+ scope (str, optional): The scope of the model. Defaults to "mmpose".
51
+ det_model (str, optional): Config path or alias of detection model.
52
+ Defaults to None.
53
+ det_weights (str, optional): Path to the checkpoints of detection
54
+ model. Defaults to None.
55
+ det_cat_ids (int or list[int], optional): Category id for
56
+ detection model. Defaults to None.
57
+ """
58
+
59
+ preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'}
60
+ forward_kwargs: set = {'disable_rebase_keypoint'}
61
+ visualize_kwargs: set = {
62
+ 'return_vis',
63
+ 'show',
64
+ 'wait_time',
65
+ 'draw_bbox',
66
+ 'radius',
67
+ 'thickness',
68
+ 'kpt_thr',
69
+ 'vis_out_dir',
70
+ 'num_instances',
71
+ }
72
+ postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
73
+
74
+ def __init__(self,
75
+ model: Union[ModelType, str],
76
+ weights: Optional[str] = None,
77
+ device: Optional[str] = None,
78
+ scope: Optional[str] = 'mmpose',
79
+ det_model: Optional[Union[ModelType, str]] = None,
80
+ det_weights: Optional[str] = None,
81
+ det_cat_ids: Optional[Union[int, Tuple]] = None,
82
+ show_progress: bool = False) -> None:
83
+
84
+ init_default_scope(scope)
85
+ super().__init__(
86
+ model=model,
87
+ weights=weights,
88
+ device=device,
89
+ scope=scope,
90
+ show_progress=show_progress)
91
+ self.model = revert_sync_batchnorm(self.model)
92
+
93
+ # assign dataset metainfo to self.visualizer
94
+ self.visualizer.set_dataset_meta(self.model.dataset_meta)
95
+
96
+ # initialize hand detector
97
+ self._init_detector(
98
+ det_model=det_model,
99
+ det_weights=det_weights,
100
+ det_cat_ids=det_cat_ids,
101
+ device=device,
102
+ )
103
+
104
+ self._video_input = False
105
+ self._buffer = defaultdict(list)
106
+
107
+ def preprocess_single(self,
108
+ input: InputType,
109
+ index: int,
110
+ bbox_thr: float = 0.3,
111
+ nms_thr: float = 0.3,
112
+ bboxes: Union[List[List], List[np.ndarray],
113
+ np.ndarray] = []):
114
+ """Process a single input into a model-feedable format.
115
+
116
+ Args:
117
+ input (InputType): Input given by user.
118
+ index (int): index of the input
119
+ bbox_thr (float): threshold for bounding box detection.
120
+ Defaults to 0.3.
121
+ nms_thr (float): IoU threshold for bounding box NMS.
122
+ Defaults to 0.3.
123
+
124
+ Yields:
125
+ Any: Data processed by the ``pipeline`` and ``collate_fn``.
126
+ """
127
+
128
+ if isinstance(input, str):
129
+ data_info = dict(img_path=input)
130
+ else:
131
+ data_info = dict(img=input, img_path=f'{index}.jpg'.rjust(10, '0'))
132
+ data_info.update(self.model.dataset_meta)
133
+
134
+ if self.detector is not None:
135
+ try:
136
+ det_results = self.detector(
137
+ input, return_datasamples=True)['predictions']
138
+ except ValueError:
139
+ print_log(
140
+ 'Support for mmpose and mmdet versions up to 3.1.0 '
141
+ 'will be discontinued in upcoming releases. To '
142
+ 'ensure ongoing compatibility, please upgrade to '
143
+ 'mmdet version 3.2.0 or later.',
144
+ logger='current',
145
+ level=logging.WARNING)
146
+ det_results = self.detector(
147
+ input, return_datasample=True)['predictions']
148
+ pred_instance = det_results[0].pred_instances.cpu().numpy()
149
+ bboxes = np.concatenate(
150
+ (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
151
+
152
+ label_mask = np.zeros(len(bboxes), dtype=np.uint8)
153
+ for cat_id in self.det_cat_ids:
154
+ label_mask = np.logical_or(label_mask,
155
+ pred_instance.labels == cat_id)
156
+
157
+ bboxes = bboxes[np.logical_and(label_mask,
158
+ pred_instance.scores > bbox_thr)]
159
+ bboxes = bboxes[nms(bboxes, nms_thr)]
160
+
161
+ data_infos = []
162
+ if len(bboxes) > 0:
163
+ for bbox in bboxes:
164
+ inst = data_info.copy()
165
+ inst['bbox'] = bbox[None, :4]
166
+ inst['bbox_score'] = bbox[4:5]
167
+ data_infos.append(self.pipeline(inst))
168
+ else:
169
+ inst = data_info.copy()
170
+
171
+ # get bbox from the image size
172
+ if isinstance(input, str):
173
+ input = mmcv.imread(input)
174
+ h, w = input.shape[:2]
175
+
176
+ inst['bbox'] = np.array([[0, 0, w, h]], dtype=np.float32)
177
+ inst['bbox_score'] = np.ones(1, dtype=np.float32)
178
+ data_infos.append(self.pipeline(inst))
179
+
180
+ return data_infos
181
+
182
+ @torch.no_grad()
183
+ def forward(self,
184
+ inputs: Union[dict, tuple],
185
+ disable_rebase_keypoint: bool = False):
186
+ """Performs a forward pass through the model.
187
+
188
+ Args:
189
+ inputs (Union[dict, tuple]): The input data to be processed. Can
190
+ be either a dictionary or a tuple.
191
+ disable_rebase_keypoint (bool, optional): Flag to disable rebasing
192
+ the height of the keypoints. Defaults to False.
193
+
194
+ Returns:
195
+ A list of data samples with prediction instances.
196
+ """
197
+ data_samples = self.model.test_step(inputs)
198
+ data_samples_2d = []
199
+
200
+ for idx, res in enumerate(data_samples):
201
+ pred_instances = res.pred_instances
202
+ keypoints = pred_instances.keypoints
203
+ rel_root_depth = pred_instances.rel_root_depth
204
+ scores = pred_instances.keypoint_scores
205
+ hand_type = pred_instances.hand_type
206
+
207
+ res_2d = PoseDataSample()
208
+ gt_instances = res.gt_instances.clone()
209
+ pred_instances = pred_instances.clone()
210
+ res_2d.gt_instances = gt_instances
211
+ res_2d.pred_instances = pred_instances
212
+
213
+ # add relative root depth to left hand joints
214
+ keypoints[:, 21:, 2] += rel_root_depth
215
+
216
+ # set joint scores according to hand type
217
+ scores[:, :21] *= hand_type[:, [0]]
218
+ scores[:, 21:] *= hand_type[:, [1]]
219
+ # normalize kpt score
220
+ if scores.max() > 1:
221
+ scores /= 255
222
+
223
+ res_2d.pred_instances.set_field(keypoints[..., :2].copy(),
224
+ 'keypoints')
225
+
226
+ # rotate the keypoint to make z-axis correspondent to height
227
+ # for better visualization
228
+ vis_R = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
229
+ keypoints[..., :3] = keypoints[..., :3] @ vis_R
230
+
231
+ # rebase height (z-axis)
232
+ if not disable_rebase_keypoint:
233
+ valid = scores > 0
234
+ keypoints[..., 2] -= np.min(
235
+ keypoints[valid, 2], axis=-1, keepdims=True)
236
+
237
+ data_samples[idx].pred_instances.keypoints = keypoints
238
+ data_samples[idx].pred_instances.keypoint_scores = scores
239
+ data_samples_2d.append(res_2d)
240
+
241
+ data_samples = [merge_data_samples(data_samples)]
242
+ data_samples_2d = merge_data_samples(data_samples_2d)
243
+
244
+ self._buffer['pose2d_results'] = data_samples_2d
245
+
246
+ return data_samples
247
+
248
+ def visualize(
249
+ self,
250
+ inputs: list,
251
+ preds: List[PoseDataSample],
252
+ return_vis: bool = False,
253
+ show: bool = False,
254
+ draw_bbox: bool = False,
255
+ wait_time: float = 0,
256
+ radius: int = 3,
257
+ thickness: int = 1,
258
+ kpt_thr: float = 0.3,
259
+ num_instances: int = 1,
260
+ vis_out_dir: str = '',
261
+ window_name: str = '',
262
+ ) -> List[np.ndarray]:
263
+ """Visualize predictions.
264
+
265
+ Args:
266
+ inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
267
+ preds (Any): Predictions of the model.
268
+ return_vis (bool): Whether to return images with predicted results.
269
+ show (bool): Whether to display the image in a popup window.
270
+ Defaults to False.
271
+ wait_time (float): The interval of show (ms). Defaults to 0
272
+ draw_bbox (bool): Whether to draw the bounding boxes.
273
+ Defaults to False
274
+ radius (int): Keypoint radius for visualization. Defaults to 3
275
+ thickness (int): Link thickness for visualization. Defaults to 1
276
+ kpt_thr (float): The threshold to visualize the keypoints.
277
+ Defaults to 0.3
278
+ vis_out_dir (str, optional): Directory to save visualization
279
+ results w/o predictions. If left as empty, no file will
280
+ be saved. Defaults to ''.
281
+ window_name (str, optional): Title of display window.
282
+ window_close_event_handler (callable, optional):
283
+
284
+ Returns:
285
+ List[np.ndarray]: Visualization results.
286
+ """
287
+ if (not return_vis) and (not show) and (not vis_out_dir):
288
+ return
289
+
290
+ if getattr(self, 'visualizer', None) is None:
291
+ raise ValueError('Visualization needs the "visualizer" term'
292
+ 'defined in the config, but got None.')
293
+
294
+ self.visualizer.radius = radius
295
+ self.visualizer.line_width = thickness
296
+
297
+ results = []
298
+
299
+ for single_input, pred in zip(inputs, preds):
300
+ if isinstance(single_input, str):
301
+ img = mmcv.imread(single_input, channel_order='rgb')
302
+ elif isinstance(single_input, np.ndarray):
303
+ img = mmcv.bgr2rgb(single_input)
304
+ else:
305
+ raise ValueError('Unsupported input type: '
306
+ f'{type(single_input)}')
307
+ img_name = os.path.basename(pred.metainfo['img_path'])
308
+
309
+ # since visualization and inference utilize the same process,
310
+ # the wait time is reduced when a video input is utilized,
311
+ # thereby eliminating the issue of inference getting stuck.
312
+ wait_time = 1e-5 if self._video_input else wait_time
313
+
314
+ if num_instances < 0:
315
+ num_instances = len(pred.pred_instances)
316
+
317
+ visualization = self.visualizer.add_datasample(
318
+ window_name,
319
+ img,
320
+ data_sample=pred,
321
+ det_data_sample=self._buffer['pose2d_results'],
322
+ draw_gt=False,
323
+ draw_bbox=draw_bbox,
324
+ show=show,
325
+ wait_time=wait_time,
326
+ convert_keypoint=False,
327
+ axis_azimuth=-115,
328
+ axis_limit=200,
329
+ axis_elev=15,
330
+ kpt_thr=kpt_thr,
331
+ num_instances=num_instances)
332
+ results.append(visualization)
333
+
334
+ if vis_out_dir:
335
+ self.save_visualization(
336
+ visualization,
337
+ vis_out_dir,
338
+ img_name=img_name,
339
+ )
340
+
341
+ if return_vis:
342
+ return results
343
+ else:
344
+ return []
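A minimal usage sketch for this inferencer; the 'hand3d' model alias and metafile-based weight resolution are assumptions based on the docstring above, and the paths are placeholders:

# Hypothetical: 3D hand pose estimation on a single image.
from mmpose.apis.inferencers import Hand3DInferencer

inferencer = Hand3DInferencer(model='hand3d', device='cuda:0')
for result in inferencer('path/to/hand.jpg',
                         out_dir='outputs',              # visualizations/ and predictions/
                         disable_rebase_keypoint=False): # keep the z-axis rebasing above
    predictions = result['predictions']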
mmpose/apis/inferencers/mmpose_inferencer.py ADDED
@@ -0,0 +1,250 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import warnings
3
+ from typing import Dict, List, Optional, Sequence, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from mmengine.config import Config, ConfigDict
8
+ from mmengine.infer.infer import ModelType
9
+ from mmengine.structures import InstanceData
10
+ from rich.progress import track
11
+
12
+ from .base_mmpose_inferencer import BaseMMPoseInferencer
13
+ from .hand3d_inferencer import Hand3DInferencer
14
+ from .pose2d_inferencer import Pose2DInferencer
15
+ from .pose3d_inferencer import Pose3DInferencer
16
+
17
+ InstanceList = List[InstanceData]
18
+ InputType = Union[str, np.ndarray]
19
+ InputsType = Union[InputType, Sequence[InputType]]
20
+ PredType = Union[InstanceData, InstanceList]
21
+ ImgType = Union[np.ndarray, Sequence[np.ndarray]]
22
+ ConfigType = Union[Config, ConfigDict]
23
+ ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
24
+
25
+
26
+ class MMPoseInferencer(BaseMMPoseInferencer):
27
+ """MMPose Inferencer. It's a unified inferencer interface for pose
28
+ estimation tasks, currently including Pose2D, and it can be used to perform
29
+ 2D keypoint detection.
30
+
31
+ Args:
32
+ pose2d (str, optional): Pretrained 2D pose estimation algorithm.
33
+ It's the path to the config file or the model name defined in
34
+ metafile. For example, it could be:
35
+
36
+ - model alias, e.g. ``'body'``,
37
+ - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
38
+ - config path
39
+
40
+ Defaults to ``None``.
41
+ pose2d_weights (str, optional): Path to the custom checkpoint file of
42
+ the selected pose2d model. If it is not specified and "pose2d" is
43
+ a model name of metafile, the weights will be loaded from
44
+ metafile. Defaults to None.
45
+ device (str, optional): Device to run inference. If None, the
46
+ available device will be automatically used. Defaults to None.
47
+ scope (str, optional): The scope of the model. Defaults to "mmpose".
48
+ det_model(str, optional): Config path or alias of detection model.
49
+ Defaults to None.
50
+ det_weights(str, optional): Path to the checkpoints of detection
51
+ model. Defaults to None.
52
+ det_cat_ids(int or list[int], optional): Category id for
53
+ detection model. Defaults to None.
54
+ output_heatmaps (bool, optional): Flag to visualize predicted
55
+ heatmaps. If set to None, the default setting from the model
56
+ config will be used. Default is None.
57
+ """
58
+
59
+ preprocess_kwargs: set = {
60
+ 'bbox_thr', 'nms_thr', 'bboxes', 'use_oks_tracking', 'tracking_thr',
61
+ 'disable_norm_pose_2d'
62
+ }
63
+ forward_kwargs: set = {
64
+ 'merge_results', 'disable_rebase_keypoint', 'pose_based_nms'
65
+ }
66
+ visualize_kwargs: set = {
67
+ 'return_vis', 'show', 'wait_time', 'draw_bbox', 'radius', 'thickness',
68
+ 'kpt_thr', 'vis_out_dir', 'skeleton_style', 'draw_heatmap',
69
+ 'black_background', 'num_instances'
70
+ }
71
+ postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
72
+
73
+ def __init__(self,
74
+ pose2d: Optional[str] = None,
75
+ pose2d_weights: Optional[str] = None,
76
+ pose3d: Optional[str] = None,
77
+ pose3d_weights: Optional[str] = None,
78
+ device: Optional[str] = None,
79
+ scope: str = 'mmpose',
80
+ det_model: Optional[Union[ModelType, str]] = None,
81
+ det_weights: Optional[str] = None,
82
+ det_cat_ids: Optional[Union[int, List]] = None,
83
+ show_progress: bool = False) -> None:
84
+
85
+ self.visualizer = None
86
+ self.show_progress = show_progress
87
+ if pose3d is not None:
88
+ if 'hand3d' in pose3d:
89
+ self.inferencer = Hand3DInferencer(pose3d, pose3d_weights,
90
+ device, scope, det_model,
91
+ det_weights, det_cat_ids,
92
+ show_progress)
93
+ else:
94
+ self.inferencer = Pose3DInferencer(pose3d, pose3d_weights,
95
+ pose2d, pose2d_weights,
96
+ device, scope, det_model,
97
+ det_weights, det_cat_ids,
98
+ show_progress)
99
+ elif pose2d is not None:
100
+ self.inferencer = Pose2DInferencer(pose2d, pose2d_weights, device,
101
+ scope, det_model, det_weights,
102
+ det_cat_ids, show_progress)
103
+ else:
104
+ raise ValueError('Either 2d or 3d pose estimation algorithm '
105
+ 'should be provided.')
106
+
107
+ def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
108
+ """Process the inputs into a model-feedable format.
109
+
110
+ Args:
111
+ inputs (InputsType): Inputs given by user.
112
+ batch_size (int): batch size. Defaults to 1.
113
+
114
+ Yields:
115
+ Any: Data processed by the ``pipeline`` and ``collate_fn``.
116
+ List[str or np.ndarray]: List of original inputs in the batch
117
+ """
118
+ for data in self.inferencer.preprocess(inputs, batch_size, **kwargs):
119
+ yield data
120
+
121
+ @torch.no_grad()
122
+ def forward(self, inputs: InputType, **forward_kwargs) -> PredType:
123
+ """Forward the inputs to the model.
124
+
125
+ Args:
126
+ inputs (InputsType): The inputs to be forwarded.
127
+
128
+ Returns:
129
+ Dict: The prediction results. Possibly with keys "pose2d".
130
+ """
131
+ return self.inferencer.forward(inputs, **forward_kwargs)
132
+
133
+ def __call__(
134
+ self,
135
+ inputs: InputsType,
136
+ return_datasamples: bool = False,
137
+ batch_size: int = 1,
138
+ out_dir: Optional[str] = None,
139
+ **kwargs,
140
+ ) -> dict:
141
+ """Call the inferencer.
142
+
143
+ Args:
144
+ inputs (InputsType): Inputs for the inferencer.
145
+ return_datasamples (bool): Whether to return results as
146
+ :obj:`BaseDataElement`. Defaults to False.
147
+ batch_size (int): Batch size. Defaults to 1.
148
+ out_dir (str, optional): directory to save visualization
149
+ results and predictions. Will be overridden if vis_out_dir or
150
+ pred_out_dir are given. Defaults to None
151
+ **kwargs: Key words arguments passed to :meth:`preprocess`,
152
+ :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
153
+ Each key in kwargs should be in the corresponding set of
154
+ ``preprocess_kwargs``, ``forward_kwargs``,
155
+ ``visualize_kwargs`` and ``postprocess_kwargs``.
156
+
157
+ Returns:
158
+ dict: Inference and visualization results.
159
+ """
160
+ if out_dir is not None:
161
+ if 'vis_out_dir' not in kwargs:
162
+ kwargs['vis_out_dir'] = f'{out_dir}/visualizations'
163
+ if 'pred_out_dir' not in kwargs:
164
+ kwargs['pred_out_dir'] = f'{out_dir}/predictions'
165
+
166
+ kwargs = {
167
+ key: value
168
+ for key, value in kwargs.items()
169
+ if key in set.union(self.inferencer.preprocess_kwargs,
170
+ self.inferencer.forward_kwargs,
171
+ self.inferencer.visualize_kwargs,
172
+ self.inferencer.postprocess_kwargs)
173
+ }
174
+ (
175
+ preprocess_kwargs,
176
+ forward_kwargs,
177
+ visualize_kwargs,
178
+ postprocess_kwargs,
179
+ ) = self._dispatch_kwargs(**kwargs)
180
+
181
+ self.inferencer.update_model_visualizer_settings(**kwargs)
182
+
183
+ # preprocessing
184
+ if isinstance(inputs, str) and inputs.startswith('webcam'):
185
+ inputs = self.inferencer._get_webcam_inputs(inputs)
186
+ batch_size = 1
187
+ if not visualize_kwargs.get('show', False):
188
+ warnings.warn('The display mode is turned off, but webcam input '
189
+ 'requires it; it will be enabled automatically.')
190
+ visualize_kwargs['show'] = True
191
+ else:
192
+ inputs = self.inferencer._inputs_to_list(inputs)
193
+ self._video_input = self.inferencer._video_input
194
+ if self._video_input:
195
+ self.video_info = self.inferencer.video_info
196
+
197
+ inputs = self.preprocess(
198
+ inputs, batch_size=batch_size, **preprocess_kwargs)
199
+
200
+ # forward
201
+ if 'bbox_thr' in self.inferencer.forward_kwargs:
202
+ forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1)
203
+
204
+ preds = []
205
+
206
+ for proc_inputs, ori_inputs in (track(inputs, description='Inference')
207
+ if self.show_progress else inputs):
208
+ preds = self.forward(proc_inputs, **forward_kwargs)
209
+
210
+ visualization = self.visualize(ori_inputs, preds,
211
+ **visualize_kwargs)
212
+ results = self.postprocess(
213
+ preds,
214
+ visualization,
215
+ return_datasamples=return_datasamples,
216
+ **postprocess_kwargs)
217
+ yield results
218
+
219
+ if self._video_input:
220
+ self._finalize_video_processing(
221
+ postprocess_kwargs.get('pred_out_dir', ''))
222
+
223
+ def visualize(self, inputs: InputsType, preds: PredType,
224
+ **kwargs) -> List[np.ndarray]:
225
+ """Visualize predictions.
226
+
227
+ Args:
228
+ inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
229
+ preds (Any): Predictions of the model.
230
+ return_vis (bool): Whether to return images with predicted results.
231
+ show (bool): Whether to display the image in a popup window.
232
+ Defaults to False.
233
+ show_interval (int): The interval of show (s). Defaults to 0
234
+ radius (int): Keypoint radius for visualization. Defaults to 3
235
+ thickness (int): Link thickness for visualization. Defaults to 1
236
+ kpt_thr (float): The threshold to visualize the keypoints.
237
+ Defaults to 0.3
238
+ vis_out_dir (str, optional): directory to save visualization
239
+ results w/o predictions. If left as empty, no file will
240
+ be saved. Defaults to ''.
241
+
242
+ Returns:
243
+ List[np.ndarray]: Visualization results.
244
+ """
245
+ window_name = ''
246
+ if self.inferencer._video_input:
247
+ window_name = self.inferencer.video_info['name']
248
+
249
+ return self.inferencer.visualize(
250
+ inputs, preds, window_name=window_name, **kwargs)
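A brief usage sketch for this unified interface; the 'human' alias and the generator-style call follow the public MMPose documentation, and the paths are placeholders:

# Hypothetical: 2D human pose estimation through the unified interface.
from mmpose.apis import MMPoseInferencer

inferencer = MMPoseInferencer(pose2d='human')   # model alias resolved via metafile
result_generator = inferencer('path/to/image.jpg', out_dir='outputs')
result = next(result_generator)                 # one result dict per input
instances = result['predictions'][0]            # instances detected in the image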
mmpose/apis/inferencers/pose2d_inferencer.py ADDED
@@ -0,0 +1,262 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import logging
3
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
4
+
5
+ import mmcv
6
+ import numpy as np
7
+ import torch
8
+ from mmengine.config import Config, ConfigDict
9
+ from mmengine.infer.infer import ModelType
10
+ from mmengine.logging import print_log
11
+ from mmengine.model import revert_sync_batchnorm
12
+ from mmengine.registry import init_default_scope
13
+ from mmengine.structures import InstanceData
14
+
15
+ from mmpose.evaluation.functional import nearby_joints_nms, nms
16
+ from mmpose.registry import INFERENCERS
17
+ from mmpose.structures import merge_data_samples
18
+ from .base_mmpose_inferencer import BaseMMPoseInferencer
19
+
20
+ InstanceList = List[InstanceData]
21
+ InputType = Union[str, np.ndarray]
22
+ InputsType = Union[InputType, Sequence[InputType]]
23
+ PredType = Union[InstanceData, InstanceList]
24
+ ImgType = Union[np.ndarray, Sequence[np.ndarray]]
25
+ ConfigType = Union[Config, ConfigDict]
26
+ ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
27
+
28
+
29
+ @INFERENCERS.register_module(name='pose-estimation')
30
+ @INFERENCERS.register_module()
31
+ class Pose2DInferencer(BaseMMPoseInferencer):
32
+ """The inferencer for 2D pose estimation.
33
+
34
+ Args:
35
+ model (str, optional): Pretrained 2D pose estimation algorithm.
36
+ It's the path to the config file or the model name defined in
37
+ metafile. For example, it could be:
38
+
39
+ - model alias, e.g. ``'body'``,
40
+ - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
41
+ - config path
42
+
43
+ Defaults to ``None``.
44
+ weights (str, optional): Path to the checkpoint. If it is not
45
+ specified and "model" is a model name of metafile, the weights
46
+ will be loaded from metafile. Defaults to None.
47
+ device (str, optional): Device to run inference. If None, the
48
+ available device will be automatically used. Defaults to None.
49
+ scope (str, optional): The scope of the model. Defaults to "mmpose".
50
+ det_model (str, optional): Config path or alias of detection model.
51
+ Defaults to None.
52
+ det_weights (str, optional): Path to the checkpoints of detection
53
+ model. Defaults to None.
54
+ det_cat_ids (int or list[int], optional): Category id for
55
+ detection model. Defaults to None.
56
+ """
57
+
58
+ preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'}
59
+ forward_kwargs: set = {'merge_results', 'pose_based_nms'}
60
+ visualize_kwargs: set = {
61
+ 'return_vis',
62
+ 'show',
63
+ 'wait_time',
64
+ 'draw_bbox',
65
+ 'radius',
66
+ 'thickness',
67
+ 'kpt_thr',
68
+ 'vis_out_dir',
69
+ 'skeleton_style',
70
+ 'draw_heatmap',
71
+ 'black_background',
72
+ }
73
+ postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
74
+
75
+ def __init__(self,
76
+ model: Union[ModelType, str],
77
+ weights: Optional[str] = None,
78
+ device: Optional[str] = None,
79
+ scope: Optional[str] = 'mmpose',
80
+ det_model: Optional[Union[ModelType, str]] = None,
81
+ det_weights: Optional[str] = None,
82
+ det_cat_ids: Optional[Union[int, Tuple]] = None,
83
+ show_progress: bool = False) -> None:
84
+
85
+ init_default_scope(scope)
86
+ super().__init__(
87
+ model=model,
88
+ weights=weights,
89
+ device=device,
90
+ scope=scope,
91
+ show_progress=show_progress)
92
+ self.model = revert_sync_batchnorm(self.model)
93
+
94
+ # assign dataset metainfo to self.visualizer
95
+ self.visualizer.set_dataset_meta(self.model.dataset_meta)
96
+
97
+ # initialize detector for top-down models
98
+ if self.cfg.data_mode == 'topdown':
99
+ self._init_detector(
100
+ det_model=det_model,
101
+ det_weights=det_weights,
102
+ det_cat_ids=det_cat_ids,
103
+ device=device,
104
+ )
105
+
106
+ self._video_input = False
107
+
108
+ def update_model_visualizer_settings(self,
109
+ draw_heatmap: bool = False,
110
+ skeleton_style: str = 'mmpose',
111
+ **kwargs) -> None:
112
+ """Update the settings of models and visualizer according to inference
113
+ arguments.
114
+
115
+ Args:
116
+ draw_heatmaps (bool, optional): Flag to visualize predicted
117
+ heatmaps. If not provided, it defaults to False.
118
+ skeleton_style (str, optional): Skeleton style selection. Valid
119
+ options are 'mmpose' and 'openpose'. Defaults to 'mmpose'.
120
+ """
121
+ self.model.test_cfg['output_heatmaps'] = draw_heatmap
122
+
123
+ if skeleton_style not in ['mmpose', 'openpose']:
124
+ raise ValueError('`skeleton_style` must be either \'mmpose\' '
125
+ 'or \'openpose\'')
126
+
127
+ if skeleton_style == 'openpose':
128
+ self.visualizer.set_dataset_meta(self.model.dataset_meta,
129
+ skeleton_style)
130
+
131
+ def preprocess_single(self,
132
+ input: InputType,
133
+ index: int,
134
+ bbox_thr: float = 0.3,
135
+ nms_thr: float = 0.3,
136
+ bboxes: Union[List[List], List[np.ndarray],
137
+ np.ndarray] = []):
138
+ """Process a single input into a model-feedable format.
139
+
140
+ Args:
141
+ input (InputType): Input given by user.
142
+ index (int): index of the input
143
+ bbox_thr (float): threshold for bounding box detection.
144
+ Defaults to 0.3.
145
+ nms_thr (float): IoU threshold for bounding box NMS.
146
+ Defaults to 0.3.
147
+
148
+ Yields:
149
+ Any: Data processed by the ``pipeline`` and ``collate_fn``.
150
+ """
151
+
152
+ if isinstance(input, str):
153
+ data_info = dict(img_path=input)
154
+ else:
155
+ data_info = dict(img=input, img_path=f'{index}.jpg'.rjust(10, '0'))
156
+ data_info.update(self.model.dataset_meta)
157
+
158
+ if self.cfg.data_mode == 'topdown':
159
+ bboxes = []
160
+ if self.detector is not None:
161
+ try:
162
+ det_results = self.detector(
163
+ input, return_datasamples=True)['predictions']
164
+ except ValueError:
165
+ print_log(
166
+ 'Support for mmpose and mmdet versions up to 3.1.0 '
167
+ 'will be discontinued in upcoming releases. To '
168
+ 'ensure ongoing compatibility, please upgrade to '
169
+ 'mmdet version 3.2.0 or later.',
170
+ logger='current',
171
+ level=logging.WARNING)
172
+ det_results = self.detector(
173
+ input, return_datasample=True)['predictions']
174
+ pred_instance = det_results[0].pred_instances.cpu().numpy()
175
+ bboxes = np.concatenate(
176
+ (pred_instance.bboxes, pred_instance.scores[:, None]),
177
+ axis=1)
178
+
179
+ label_mask = np.zeros(len(bboxes), dtype=np.uint8)
180
+ for cat_id in self.det_cat_ids:
181
+ label_mask = np.logical_or(label_mask,
182
+ pred_instance.labels == cat_id)
183
+
184
+ bboxes = bboxes[np.logical_and(
185
+ label_mask, pred_instance.scores > bbox_thr)]
186
+ bboxes = bboxes[nms(bboxes, nms_thr)]
187
+
188
+ data_infos = []
189
+ if len(bboxes) > 0:
190
+ for bbox in bboxes:
191
+ inst = data_info.copy()
192
+ inst['bbox'] = bbox[None, :4]
193
+ inst['bbox_score'] = bbox[4:5]
194
+ data_infos.append(self.pipeline(inst))
195
+ else:
196
+ inst = data_info.copy()
197
+
198
+ # get bbox from the image size
199
+ if isinstance(input, str):
200
+ input = mmcv.imread(input)
201
+ h, w = input.shape[:2]
202
+
203
+ inst['bbox'] = np.array([[0, 0, w, h]], dtype=np.float32)
204
+ inst['bbox_score'] = np.ones(1, dtype=np.float32)
205
+ data_infos.append(self.pipeline(inst))
206
+
207
+ else: # bottom-up
208
+ data_infos = [self.pipeline(data_info)]
209
+
210
+ return data_infos
211
+
212
+ @torch.no_grad()
213
+ def forward(self,
214
+ inputs: Union[dict, tuple],
215
+ merge_results: bool = True,
216
+ bbox_thr: float = -1,
217
+ pose_based_nms: bool = False):
218
+ """Performs a forward pass through the model.
219
+
220
+ Args:
221
+ inputs (Union[dict, tuple]): The input data to be processed. Can
222
+ be either a dictionary or a tuple.
223
+ merge_results (bool, optional): Whether to merge data samples,
224
+ default to True. This is only applicable when the data_mode
225
+ is 'topdown'.
226
+ bbox_thr (float, optional): A threshold for the bounding box
227
+ scores. Bounding boxes with scores greater than this value
228
+ will be retained. Default value is -1 which retains all
229
+ bounding boxes.
230
+
231
+ Returns:
232
+ A list of data samples with prediction instances.
233
+ """
234
+ data_samples = self.model.test_step(inputs)
235
+ if self.cfg.data_mode == 'topdown' and merge_results:
236
+ data_samples = [merge_data_samples(data_samples)]
237
+
238
+ if bbox_thr > 0:
239
+ for ds in data_samples:
240
+ if 'bbox_scores' in ds.pred_instances:
241
+ ds.pred_instances = ds.pred_instances[
242
+ ds.pred_instances.bbox_scores > bbox_thr]
243
+
244
+ if pose_based_nms:
245
+ for ds in data_samples:
246
+ if len(ds.pred_instances) == 0:
247
+ continue
248
+
249
+ kpts = ds.pred_instances.keypoints
250
+ scores = ds.pred_instances.bbox_scores
251
+ num_keypoints = kpts.shape[-2]
252
+
253
+ kept_indices = nearby_joints_nms(
254
+ [
255
+ dict(keypoints=kpts[i], score=scores[i])
256
+ for i in range(len(kpts))
257
+ ],
258
+ num_nearby_joints_thr=num_keypoints // 3,
259
+ )
260
+ ds.pred_instances = ds.pred_instances[kept_indices]
261
+
262
+ return data_samples
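A minimal usage sketch of the inferencer above, assuming mmpose and mmdet are installed, 'human' is a valid metafile alias, and 'demo.jpg' is a placeholder image path; adjust the thresholds to your data.

from mmpose.apis.inferencers import Pose2DInferencer

# 'human' is resolved through the metafile aliases; a config path works too.
inferencer = Pose2DInferencer(model='human')

# Calling the inferencer returns a generator that yields one result dict per
# input; 'predictions' holds per-instance keypoints, scores and bboxes.
result = next(inferencer('demo.jpg', bbox_thr=0.3, nms_thr=0.3))
print(result['predictions'])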
mmpose/apis/inferencers/pose3d_inferencer.py ADDED
@@ -0,0 +1,457 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ from collections import defaultdict
4
+ from functools import partial
5
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
6
+
7
+ import mmcv
8
+ import numpy as np
9
+ import torch
10
+ from mmengine.config import Config, ConfigDict
11
+ from mmengine.infer.infer import ModelType
12
+ from mmengine.model import revert_sync_batchnorm
13
+ from mmengine.registry import init_default_scope
14
+ from mmengine.structures import InstanceData
15
+
16
+ from mmpose.apis import (_track_by_iou, _track_by_oks, collate_pose_sequence,
17
+ convert_keypoint_definition, extract_pose_sequence)
18
+ from mmpose.registry import INFERENCERS
19
+ from mmpose.structures import PoseDataSample, merge_data_samples
20
+ from .base_mmpose_inferencer import BaseMMPoseInferencer
21
+ from .pose2d_inferencer import Pose2DInferencer
22
+
23
+ InstanceList = List[InstanceData]
24
+ InputType = Union[str, np.ndarray]
25
+ InputsType = Union[InputType, Sequence[InputType]]
26
+ PredType = Union[InstanceData, InstanceList]
27
+ ImgType = Union[np.ndarray, Sequence[np.ndarray]]
28
+ ConfigType = Union[Config, ConfigDict]
29
+ ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
30
+
31
+
32
+ @INFERENCERS.register_module(name='pose-estimation-3d')
33
+ @INFERENCERS.register_module()
34
+ class Pose3DInferencer(BaseMMPoseInferencer):
35
+ """The inferencer for 3D pose estimation.
36
+
37
+ Args:
38
+ model (str, optional): Pretrained 2D pose estimation algorithm.
39
+ It's the path to the config file or the model name defined in
40
+ metafile. For example, it could be:
41
+
42
+ - model alias, e.g. ``'body'``,
43
+ - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
44
+ - config path
45
+
46
+ Defaults to ``None``.
47
+ weights (str, optional): Path to the checkpoint. If it is not
48
+ specified and "model" is a model name of metafile, the weights
49
+ will be loaded from metafile. Defaults to None.
50
+ device (str, optional): Device to run inference. If None, the
51
+ available device will be automatically used. Defaults to None.
52
+ scope (str, optional): The scope of the model. Defaults to "mmpose".
53
+ det_model (str, optional): Config path or alias of detection model.
54
+ Defaults to None.
55
+ det_weights (str, optional): Path to the checkpoints of detection
56
+ model. Defaults to None.
57
+ det_cat_ids (int or list[int], optional): Category id for
58
+ detection model. Defaults to None.
59
+ output_heatmaps (bool, optional): Flag to visualize predicted
60
+ heatmaps. If set to None, the default setting from the model
61
+ config will be used. Default is None.
62
+ """
63
+
64
+ preprocess_kwargs: set = {
65
+ 'bbox_thr', 'nms_thr', 'bboxes', 'use_oks_tracking', 'tracking_thr',
66
+ 'disable_norm_pose_2d'
67
+ }
68
+ forward_kwargs: set = {'disable_rebase_keypoint'}
69
+ visualize_kwargs: set = {
70
+ 'return_vis',
71
+ 'show',
72
+ 'wait_time',
73
+ 'draw_bbox',
74
+ 'radius',
75
+ 'thickness',
76
+ 'num_instances',
77
+ 'kpt_thr',
78
+ 'vis_out_dir',
79
+ }
80
+ postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
81
+
82
+ def __init__(self,
83
+ model: Union[ModelType, str],
84
+ weights: Optional[str] = None,
85
+ pose2d_model: Optional[Union[ModelType, str]] = None,
86
+ pose2d_weights: Optional[str] = None,
87
+ device: Optional[str] = None,
88
+ scope: Optional[str] = 'mmpose',
89
+ det_model: Optional[Union[ModelType, str]] = None,
90
+ det_weights: Optional[str] = None,
91
+ det_cat_ids: Optional[Union[int, Tuple]] = None,
92
+ show_progress: bool = False) -> None:
93
+
94
+ init_default_scope(scope)
95
+ super().__init__(
96
+ model=model,
97
+ weights=weights,
98
+ device=device,
99
+ scope=scope,
100
+ show_progress=show_progress)
101
+ self.model = revert_sync_batchnorm(self.model)
102
+
103
+ # assign dataset metainfo to self.visualizer
104
+ self.visualizer.set_dataset_meta(self.model.dataset_meta)
105
+
106
+ # initialize 2d pose estimator
107
+ self.pose2d_model = Pose2DInferencer(
108
+ pose2d_model if pose2d_model else 'human', pose2d_weights, device,
109
+ scope, det_model, det_weights, det_cat_ids)
110
+
111
+ # helper functions
112
+ self._keypoint_converter = partial(
113
+ convert_keypoint_definition,
114
+ pose_det_dataset=self.pose2d_model.model.
115
+ dataset_meta['dataset_name'],
116
+ pose_lift_dataset=self.model.dataset_meta['dataset_name'],
117
+ )
118
+
119
+ self._pose_seq_extractor = partial(
120
+ extract_pose_sequence,
121
+ causal=self.cfg.test_dataloader.dataset.get('causal', False),
122
+ seq_len=self.cfg.test_dataloader.dataset.get('seq_len', 1),
123
+ step=self.cfg.test_dataloader.dataset.get('seq_step', 1))
124
+
125
+ self._video_input = False
126
+ self._buffer = defaultdict(list)
127
+
128
+ def preprocess_single(self,
129
+ input: InputType,
130
+ index: int,
131
+ bbox_thr: float = 0.3,
132
+ nms_thr: float = 0.3,
133
+ bboxes: Union[List[List], List[np.ndarray],
134
+ np.ndarray] = [],
135
+ use_oks_tracking: bool = False,
136
+ tracking_thr: float = 0.3,
137
+ disable_norm_pose_2d: bool = False):
138
+ """Process a single input into a model-feedable format.
139
+
140
+ Args:
141
+ input (InputType): The input provided by the user.
142
+ index (int): The index of the input.
143
+ bbox_thr (float, optional): The threshold for bounding box
144
+ detection. Defaults to 0.3.
145
+ nms_thr (float, optional): The Intersection over Union (IoU)
146
+ threshold for bounding box Non-Maximum Suppression (NMS).
147
+ Defaults to 0.3.
148
+ bboxes (Union[List[List], List[np.ndarray], np.ndarray]):
149
+ The bounding boxes to use. Defaults to [].
150
+ use_oks_tracking (bool, optional): A flag that indicates
151
+ whether OKS-based tracking should be used. Defaults to False.
152
+ tracking_thr (float, optional): The threshold for tracking.
153
+ Defaults to 0.3.
154
+ disable_norm_pose_2d (bool, optional): A flag that indicates
155
+ whether to disable 2D pose normalization.
156
+ Defaults to False.
157
+
158
+ Yields:
159
+ Any: The data processed by the pipeline and collate_fn.
160
+
161
+ This method first calculates 2D keypoints using the provided
162
+ pose2d_model. The method also performs instance matching, which
163
+ can use either OKS-based tracking or IOU-based tracking.
164
+ """
165
+
166
+ # calculate 2d keypoints
167
+ results_pose2d = next(
168
+ self.pose2d_model(
169
+ input,
170
+ bbox_thr=bbox_thr,
171
+ nms_thr=nms_thr,
172
+ bboxes=bboxes,
173
+ merge_results=False,
174
+ return_datasamples=True))['predictions']
175
+
176
+ for ds in results_pose2d:
177
+ ds.pred_instances.set_field(
178
+ (ds.pred_instances.bboxes[..., 2:] -
179
+ ds.pred_instances.bboxes[..., :2]).prod(-1), 'areas')
180
+
181
+ if not self._video_input:
182
+ height, width = results_pose2d[0].metainfo['ori_shape']
183
+
184
+ # Clear the buffer if inputs are individual images to prevent
185
+ # carryover effects from previous images
186
+ self._buffer.clear()
187
+
188
+ else:
189
+ height = self.video_info['height']
190
+ width = self.video_info['width']
191
+ img_path = results_pose2d[0].metainfo['img_path']
192
+
193
+ # instance matching
194
+ if use_oks_tracking:
195
+ _track = partial(_track_by_oks)
196
+ else:
197
+ _track = _track_by_iou
198
+
199
+ for result in results_pose2d:
200
+ track_id, self._buffer['results_pose2d_last'], _ = _track(
201
+ result, self._buffer['results_pose2d_last'], tracking_thr)
202
+ if track_id == -1:
203
+ pred_instances = result.pred_instances.cpu().numpy()
204
+ keypoints = pred_instances.keypoints
205
+ if np.count_nonzero(keypoints[:, :, 1]) >= 3:
206
+ next_id = self._buffer.get('next_id', 0)
207
+ result.set_field(next_id, 'track_id')
208
+ self._buffer['next_id'] = next_id + 1
209
+ else:
210
+ # If the number of keypoints detected is small,
211
+ # delete that person instance.
212
+ result.pred_instances.keypoints[..., 1] = -10
213
+ result.pred_instances.bboxes *= 0
214
+ result.set_field(-1, 'track_id')
215
+ else:
216
+ result.set_field(track_id, 'track_id')
217
+ self._buffer['pose2d_results'] = merge_data_samples(results_pose2d)
218
+
219
+ # convert keypoints
220
+ results_pose2d_converted = [ds.cpu().numpy() for ds in results_pose2d]
221
+ for ds in results_pose2d_converted:
222
+ ds.pred_instances.keypoints = self._keypoint_converter(
223
+ ds.pred_instances.keypoints)
224
+ self._buffer['pose_est_results_list'].append(results_pose2d_converted)
225
+
226
+ # extract and pad input pose2d sequence
227
+ pose_results_2d = self._pose_seq_extractor(
228
+ self._buffer['pose_est_results_list'],
229
+ frame_idx=index if self._video_input else 0)
230
+ causal = self.cfg.test_dataloader.dataset.get('causal', False)
231
+ target_idx = -1 if causal else len(pose_results_2d) // 2
232
+
233
+ stats_info = self.model.dataset_meta.get('stats_info', {})
234
+ bbox_center = stats_info.get('bbox_center', None)
235
+ bbox_scale = stats_info.get('bbox_scale', None)
236
+
237
+ pose_results_2d_copy = []
238
+ for pose_res in pose_results_2d:
239
+ pose_res_copy = []
240
+ for data_sample in pose_res:
241
+
242
+ data_sample_copy = PoseDataSample()
243
+ data_sample_copy.gt_instances = \
244
+ data_sample.gt_instances.clone()
245
+ data_sample_copy.pred_instances = \
246
+ data_sample.pred_instances.clone()
247
+ data_sample_copy.track_id = data_sample.track_id
248
+
249
+ kpts = data_sample.pred_instances.keypoints
250
+ bboxes = data_sample.pred_instances.bboxes
251
+ keypoints = []
252
+ for k in range(len(kpts)):
253
+ kpt = kpts[k]
254
+ if not disable_norm_pose_2d:
255
+ bbox = bboxes[k]
256
+ center = np.array([[(bbox[0] + bbox[2]) / 2,
257
+ (bbox[1] + bbox[3]) / 2]])
258
+ scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
259
+ keypoints.append((kpt[:, :2] - center) / scale *
260
+ bbox_scale + bbox_center)
261
+ else:
262
+ keypoints.append(kpt[:, :2])
263
+ data_sample_copy.pred_instances.set_field(
264
+ np.array(keypoints), 'keypoints')
265
+ pose_res_copy.append(data_sample_copy)
266
+
267
+ pose_results_2d_copy.append(pose_res_copy)
268
+ pose_sequences_2d = collate_pose_sequence(pose_results_2d_copy, True,
269
+ target_idx)
270
+ if not pose_sequences_2d:
271
+ return []
272
+
273
+ data_list = []
274
+ for i, pose_seq in enumerate(pose_sequences_2d):
275
+ data_info = dict()
276
+
277
+ keypoints_2d = pose_seq.pred_instances.keypoints
278
+ keypoints_2d = np.squeeze(
279
+ keypoints_2d,
280
+ axis=0) if keypoints_2d.ndim == 4 else keypoints_2d
281
+
282
+ T, K, C = keypoints_2d.shape
283
+
284
+ data_info['keypoints'] = keypoints_2d
285
+ data_info['keypoints_visible'] = np.ones((
286
+ T,
287
+ K,
288
+ ),
289
+ dtype=np.float32)
290
+ data_info['lifting_target'] = np.zeros((1, K, 3), dtype=np.float32)
291
+ data_info['factor'] = np.zeros((T, ), dtype=np.float32)
292
+ data_info['lifting_target_visible'] = np.ones((1, K, 1),
293
+ dtype=np.float32)
294
+ data_info['camera_param'] = dict(w=width, h=height)
295
+
296
+ data_info.update(self.model.dataset_meta)
297
+ data_info = self.pipeline(data_info)
298
+ data_info['data_samples'].set_field(
299
+ img_path, 'img_path', field_type='metainfo')
300
+ data_list.append(data_info)
301
+
302
+ return data_list
303
+
304
+ @torch.no_grad()
305
+ def forward(self,
306
+ inputs: Union[dict, tuple],
307
+ disable_rebase_keypoint: bool = False):
308
+ """Perform forward pass through the model and process the results.
309
+
310
+ Args:
311
+ inputs (Union[dict, tuple]): The inputs for the model.
312
+ disable_rebase_keypoint (bool, optional): Flag to disable rebasing
313
+ the height of the keypoints. Defaults to False.
314
+
315
+ Returns:
316
+ list: A list of data samples, each containing the model's output
317
+ results.
318
+ """
319
+ pose_lift_results = self.model.test_step(inputs)
320
+
321
+ # Post-processing of pose estimation results
322
+ pose_est_results_converted = self._buffer['pose_est_results_list'][-1]
323
+ for idx, pose_lift_res in enumerate(pose_lift_results):
324
+ # Update track_id from the pose estimation results
325
+ pose_lift_res.track_id = pose_est_results_converted[idx].get(
326
+ 'track_id', 1e4)
327
+
328
+ # align the shape of output keypoints coordinates and scores
329
+ keypoints = pose_lift_res.pred_instances.keypoints
330
+ keypoint_scores = pose_lift_res.pred_instances.keypoint_scores
331
+ if keypoint_scores.ndim == 3:
332
+ pose_lift_results[idx].pred_instances.keypoint_scores = \
333
+ np.squeeze(keypoint_scores, axis=1)
334
+ if keypoints.ndim == 4:
335
+ keypoints = np.squeeze(keypoints, axis=1)
336
+
337
+ # Invert x and z values of the keypoints
338
+ keypoints = keypoints[..., [0, 2, 1]]
339
+ keypoints[..., 0] = -keypoints[..., 0]
340
+ keypoints[..., 2] = -keypoints[..., 2]
341
+
342
+ # If rebase_keypoint_height is True, adjust z-axis values
343
+ if not disable_rebase_keypoint:
344
+ keypoints[..., 2] -= np.min(
345
+ keypoints[..., 2], axis=-1, keepdims=True)
346
+
347
+ pose_lift_results[idx].pred_instances.keypoints = keypoints
348
+
349
+ pose_lift_results = sorted(
350
+ pose_lift_results, key=lambda x: x.get('track_id', 1e4))
351
+
352
+ data_samples = [merge_data_samples(pose_lift_results)]
353
+ return data_samples
354
+
355
+ def visualize(self,
356
+ inputs: list,
357
+ preds: List[PoseDataSample],
358
+ return_vis: bool = False,
359
+ show: bool = False,
360
+ draw_bbox: bool = False,
361
+ wait_time: float = 0,
362
+ radius: int = 3,
363
+ thickness: int = 1,
364
+ kpt_thr: float = 0.3,
365
+ num_instances: int = 1,
366
+ vis_out_dir: str = '',
367
+ window_name: str = '',
368
+ window_close_event_handler: Optional[Callable] = None
369
+ ) -> List[np.ndarray]:
370
+ """Visualize predictions.
371
+
372
+ Args:
373
+ inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
374
+ preds (Any): Predictions of the model.
375
+ return_vis (bool): Whether to return images with predicted results.
376
+ show (bool): Whether to display the image in a popup window.
377
+ Defaults to False.
378
+ wait_time (float): The interval of show (ms). Defaults to 0
379
+ draw_bbox (bool): Whether to draw the bounding boxes.
380
+ Defaults to False
381
+ radius (int): Keypoint radius for visualization. Defaults to 3
382
+ thickness (int): Link thickness for visualization. Defaults to 1
383
+ kpt_thr (float): The threshold to visualize the keypoints.
384
+ Defaults to 0.3
385
+ vis_out_dir (str, optional): Directory to save visualization
386
+ results w/o predictions. If left as empty, no file will
387
+ be saved. Defaults to ''.
388
+ window_name (str, optional): Title of display window.
389
+ window_close_event_handler (callable, optional):
390
+
391
+ Returns:
392
+ List[np.ndarray]: Visualization results.
393
+ """
394
+ if (not return_vis) and (not show) and (not vis_out_dir):
395
+ return
396
+
397
+ if getattr(self, 'visualizer', None) is None:
398
+ raise ValueError('Visualization needs the "visualizer" term '
399
+ 'defined in the config, but got None.')
400
+
401
+ self.visualizer.radius = radius
402
+ self.visualizer.line_width = thickness
403
+ det_kpt_color = self.pose2d_model.visualizer.kpt_color
404
+ det_dataset_skeleton = self.pose2d_model.visualizer.skeleton
405
+ det_dataset_link_color = self.pose2d_model.visualizer.link_color
406
+ self.visualizer.det_kpt_color = det_kpt_color
407
+ self.visualizer.det_dataset_skeleton = det_dataset_skeleton
408
+ self.visualizer.det_dataset_link_color = det_dataset_link_color
409
+
410
+ results = []
411
+
412
+ for single_input, pred in zip(inputs, preds):
413
+ if isinstance(single_input, str):
414
+ img = mmcv.imread(single_input, channel_order='rgb')
415
+ elif isinstance(single_input, np.ndarray):
416
+ img = mmcv.bgr2rgb(single_input)
417
+ else:
418
+ raise ValueError('Unsupported input type: '
419
+ f'{type(single_input)}')
420
+
421
+ # since visualization and inference utilize the same process,
422
+ # the wait time is reduced when a video input is utilized,
423
+ # thereby eliminating the issue of inference getting stuck.
424
+ wait_time = 1e-5 if self._video_input else wait_time
425
+
426
+ if num_instances < 0:
427
+ num_instances = len(pred.pred_instances)
428
+
429
+ visualization = self.visualizer.add_datasample(
430
+ window_name,
431
+ img,
432
+ data_sample=pred,
433
+ det_data_sample=self._buffer['pose2d_results'],
434
+ draw_gt=False,
435
+ draw_bbox=draw_bbox,
436
+ show=show,
437
+ wait_time=wait_time,
438
+ dataset_2d=self.pose2d_model.model.
439
+ dataset_meta['dataset_name'],
440
+ dataset_3d=self.model.dataset_meta['dataset_name'],
441
+ kpt_thr=kpt_thr,
442
+ num_instances=num_instances)
443
+ results.append(visualization)
444
+
445
+ if vis_out_dir:
446
+ img_name = os.path.basename(pred.metainfo['img_path']) \
447
+ if 'img_path' in pred.metainfo else None
448
+ self.save_visualization(
449
+ visualization,
450
+ vis_out_dir,
451
+ img_name=img_name,
452
+ )
453
+
454
+ if return_vis:
455
+ return results
456
+ else:
457
+ return []
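A minimal usage sketch for the 3D inferencer, assuming 'human3d' is a registered 2D-to-3D lifting alias in the metafile and 'demo.jpg' is a placeholder image; any lifting config path can be substituted for the alias.

from mmpose.apis.inferencers import Pose3DInferencer

# The 2D estimator falls back to the 'human' alias when pose2d_model is omitted.
inferencer = Pose3DInferencer(model='human3d')

# Yields one result dict per input; predictions contain the lifted (x, y, z)
# keypoints for each tracked instance.
result = next(inferencer('demo.jpg'))
print(result['predictions'])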
mmpose/apis/inferencers/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .default_det_models import default_det_models
3
+ from .get_model_alias import get_model_aliases
4
+
5
+ __all__ = ['default_det_models', 'get_model_aliases']
mmpose/apis/inferencers/utils/default_det_models.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+
4
+ from mmengine.config.utils import MODULE2PACKAGE
5
+ from mmengine.utils import get_installed_path
6
+
7
+ mmpose_path = get_installed_path(MODULE2PACKAGE['mmpose'])
8
+
9
+ default_det_models = dict(
10
+ human=dict(
11
+ model=osp.join(
12
+ mmpose_path, '.mim', 'demo/mmdetection_cfg/'
13
+ 'rtmdet_m_640-8xb32_coco-person.py'),
14
+ weights='https://download.openmmlab.com/mmpose/v1/projects/'
15
+ 'rtmposev1/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth',
16
+ cat_ids=(0, )),
17
+ face=dict(
18
+ model=osp.join(mmpose_path, '.mim',
19
+ 'demo/mmdetection_cfg/yolox-s_8xb8-300e_coco-face.py'),
20
+ weights='https://download.openmmlab.com/mmpose/mmdet_pretrained/'
21
+ 'yolo-x_8xb8-300e_coco-face_13274d7c.pth',
22
+ cat_ids=(0, )),
23
+ hand=dict(
24
+ model=osp.join(mmpose_path, '.mim', 'demo/mmdetection_cfg/'
25
+ 'rtmdet_nano_320-8xb32_hand.py'),
26
+ weights='https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/'
27
+ 'rtmdet_nano_8xb32-300e_hand-267f9c8f.pth',
28
+ cat_ids=(0, )),
29
+ animal=dict(
30
+ model='rtmdet-m',
31
+ weights=None,
32
+ cat_ids=(15, 16, 17, 18, 19, 20, 21, 22, 23)),
33
+ )
34
+
35
+ default_det_models['body'] = default_det_models['human']
36
+ default_det_models['wholebody'] = default_det_models['human']
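A short sketch of how the mapping above is typically read (the driver code here is hypothetical; the keys and fields come from the file itself):

from mmpose.apis.inferencers.utils import default_det_models

det_info = default_det_models['body']   # 'body' and 'wholebody' share the 'human' entry
print(det_info['model'])                # config path of the RTMDet person detector
print(det_info['weights'])              # download URL of the matching checkpoint
print(det_info['cat_ids'])              # (0,) -> the 'person' category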
mmpose/apis/inferencers/utils/get_model_alias.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Dict
3
+
4
+ from mmengine.infer import BaseInferencer
5
+
6
+
7
+ def get_model_aliases(scope: str = 'mmpose') -> Dict[str, str]:
8
+ """Retrieve model aliases and their corresponding configuration names.
9
+
10
+ Args:
11
+ scope (str, optional): The scope for the model aliases. Defaults
12
+ to 'mmpose'.
13
+
14
+ Returns:
15
+ Dict[str, str]: A dictionary containing model aliases as keys and
16
+ their corresponding configuration names as values.
17
+ """
18
+
19
+ # Get a list of model configurations from the metafile
20
+ repo_or_mim_dir = BaseInferencer._get_repo_or_mim_dir(scope)
21
+ model_cfgs = BaseInferencer._get_models_from_metafile(repo_or_mim_dir)
22
+
23
+ model_alias_dict = dict()
24
+ for model_cfg in model_cfgs:
25
+ if 'Alias' in model_cfg:
26
+ if isinstance(model_cfg['Alias'], str):
27
+ model_alias_dict[model_cfg['Alias']] = model_cfg['Name']
28
+ elif isinstance(model_cfg['Alias'], list):
29
+ for alias in model_cfg['Alias']:
30
+ model_alias_dict[alias] = model_cfg['Name']
31
+ else:
32
+ raise ValueError(
33
+ 'encounter an unexpected alias type. Please raise an '
34
+ 'issue at https://github.com/open-mmlab/mmpose/issues '
35
+ 'to announce us')
36
+
37
+ return model_alias_dict
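A quick sketch of querying the alias table (the output depends on the metafiles shipped with the installed mmpose):

from mmpose.apis.inferencers.utils import get_model_aliases

aliases = get_model_aliases('mmpose')
# Maps e.g. 'human' to the config name registered under that alias.
for alias, cfg_name in sorted(aliases.items())[:5]:
    print(alias, '->', cfg_name)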
mmpose/apis/visualization.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from copy import deepcopy
3
+ from typing import Union
4
+
5
+ import mmcv
6
+ import numpy as np
7
+ from mmengine.structures import InstanceData
8
+
9
+ from mmpose.datasets.datasets.utils import parse_pose_metainfo
10
+ from mmpose.structures import PoseDataSample
11
+ from mmpose.visualization import PoseLocalVisualizer
12
+
13
+ # from posevis import pose_visualization
14
+
15
+ # def visualize(
16
+ # img: Union[np.ndarray, str],
17
+ # keypoints: np.ndarray,
18
+ # keypoint_score: np.ndarray = None,
19
+ # metainfo: Union[str, dict] = None,
20
+ # visualizer: PoseLocalVisualizer = None,
21
+ # show_kpt_idx: bool = False,
22
+ # skeleton_style: str = 'mmpose',
23
+ # show: bool = False,
24
+ # kpt_thr: float = 0.3,
25
+ # ):
26
+ # """Visualize 2d keypoints on an image.
27
+
28
+ # Args:
29
+ # img (str | np.ndarray): The image to be displayed.
30
+ # keypoints (np.ndarray): The keypoint to be displayed.
31
+ # keypoint_score (np.ndarray): The score of each keypoint.
32
+ # metainfo (str | dict): The metainfo of dataset.
33
+ # visualizer (PoseLocalVisualizer): The visualizer.
34
+ # show_kpt_idx (bool): Whether to show the index of keypoints.
35
+ # skeleton_style (str): Skeleton style. Options are 'mmpose' and
36
+ # 'openpose'.
37
+ # show (bool): Whether to show the image.
38
+ # wait_time (int): Value of waitKey param.
39
+ # kpt_thr (float): Keypoint threshold.
40
+ # """
41
+ # kpts = keypoints.reshape(-1, 2)
42
+ # kpts = np.concatenate([kpts, keypoint_score[:, None]], axis=1)
43
+ # kpts[kpts[:, 2] < kpt_thr, :] = 0
44
+ # pose_results = [{
45
+ # 'keypoints': kpts,
46
+ # }]
47
+
48
+ # img = pose_visualization(
49
+ # img,
50
+ # pose_results,
51
+ # format="COCO",
52
+ # greyness=1.0,
53
+ # show_markers=True,
54
+ # show_bones=True,
55
+ # line_type="solid",
56
+ # width_multiplier=1.0,
57
+ # bbox_width_multiplier=1.0,
58
+ # show_bbox=False,
59
+ # differ_individuals=False,
60
+ # )
61
+ # return img
62
+
63
+
64
+ def visualize(
65
+ img: Union[np.ndarray, str],
66
+ keypoints: np.ndarray,
67
+ keypoint_score: np.ndarray = None,
68
+ metainfo: Union[str, dict] = None,
69
+ visualizer: PoseLocalVisualizer = None,
70
+ show_kpt_idx: bool = False,
71
+ skeleton_style: str = 'mmpose',
72
+ show: bool = False,
73
+ kpt_thr: float = 0.3,
74
+ ):
75
+ """Visualize 2d keypoints on an image.
76
+
77
+ Args:
78
+ img (str | np.ndarray): The image to be displayed.
79
+ keypoints (np.ndarray): The keypoint to be displayed.
80
+ keypoint_score (np.ndarray): The score of each keypoint.
81
+ metainfo (str | dict): The metainfo of dataset.
82
+ visualizer (PoseLocalVisualizer): The visualizer.
83
+ show_kpt_idx (bool): Whether to show the index of keypoints.
84
+ skeleton_style (str): Skeleton style. Options are 'mmpose' and
85
+ 'openpose'.
86
+ show (bool): Whether to show the image.
87
+ wait_time (int): Value of waitKey param.
88
+ kpt_thr (float): Keypoint threshold.
89
+ """
90
+ assert skeleton_style in [
91
+ 'mmpose', 'openpose'
92
+ ], ("skeleton_style must be either 'mmpose' or 'openpose'")
93
+
94
+ if visualizer is None:
95
+ visualizer = PoseLocalVisualizer()
96
+ else:
97
+ visualizer = deepcopy(visualizer)
98
+
99
+ if isinstance(metainfo, str):
100
+ metainfo = parse_pose_metainfo(dict(from_file=metainfo))
101
+ elif isinstance(metainfo, dict):
102
+ metainfo = parse_pose_metainfo(metainfo)
103
+
104
+ if metainfo is not None:
105
+ visualizer.set_dataset_meta(metainfo, skeleton_style=skeleton_style)
106
+
107
+ if isinstance(img, str):
108
+ img = mmcv.imread(img, channel_order='rgb')
109
+ elif isinstance(img, np.ndarray):
110
+ img = mmcv.bgr2rgb(img)
111
+
112
+ if keypoint_score is None:
113
+ keypoint_score = np.ones(keypoints.shape[0])
114
+
115
+ tmp_instances = InstanceData()
116
+ tmp_instances.keypoints = keypoints
117
+ tmp_instances.keypoint_score = keypoint_score
118
+
119
+ tmp_datasample = PoseDataSample()
120
+ tmp_datasample.pred_instances = tmp_instances
121
+
122
+ visualizer.add_datasample(
123
+ 'visualization',
124
+ img,
125
+ tmp_datasample,
126
+ show_kpt_idx=show_kpt_idx,
127
+ skeleton_style=skeleton_style,
128
+ show=show,
129
+ wait_time=0,
130
+ kpt_thr=kpt_thr)
131
+
132
+ return visualizer.get_image()
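A hedged usage sketch of visualize(); the image path and metainfo file are placeholders, and the keypoint array is laid out as (num_instances, K, 2) for PoseLocalVisualizer.

import numpy as np
from mmpose.apis.visualization import visualize

keypoints = np.random.rand(1, 17, 2) * 256           # one instance, 17 COCO keypoints
scores = np.ones((1, 17), dtype=np.float32)

canvas = visualize(
    'person.jpg',                                     # placeholder image path
    keypoints,
    keypoint_score=scores,
    metainfo='configs/_base_/datasets/coco.py',       # placeholder metainfo path
    show=False)
print(canvas.shape)                                   # RGB image with the skeleton drawn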
mmpose/codecs/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .annotation_processors import YOLOXPoseAnnotationProcessor
3
+ from .associative_embedding import AssociativeEmbedding
4
+ from .decoupled_heatmap import DecoupledHeatmap
5
+ from .edpose_label import EDPoseLabel
6
+ from .hand_3d_heatmap import Hand3DHeatmap
7
+ from .image_pose_lifting import ImagePoseLifting
8
+ from .integral_regression_label import IntegralRegressionLabel
9
+ from .megvii_heatmap import MegviiHeatmap
10
+ from .motionbert_label import MotionBERTLabel
11
+ from .msra_heatmap import MSRAHeatmap
12
+ from .regression_label import RegressionLabel
13
+ from .simcc_label import SimCCLabel
14
+ from .spr import SPR
15
+ from .udp_heatmap import UDPHeatmap
16
+ from .video_pose_lifting import VideoPoseLifting
17
+ from .onehot_heatmap import OneHotHeatmap
18
+
19
+ __all__ = [
20
+ 'MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap', 'RegressionLabel',
21
+ 'SimCCLabel', 'IntegralRegressionLabel', 'AssociativeEmbedding', 'SPR',
22
+ 'DecoupledHeatmap', 'VideoPoseLifting', 'ImagePoseLifting',
23
+ 'MotionBERTLabel', 'YOLOXPoseAnnotationProcessor', 'EDPoseLabel',
24
+ 'Hand3DHeatmap', 'OneHotHeatmap'
25
+ ]
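All of the codecs exported above register themselves in KEYPOINT_CODECS, so configs usually build them through the registry. A minimal sketch (the MSRAHeatmap parameters are only plausible example values):

from mmpose.registry import KEYPOINT_CODECS

codec = KEYPOINT_CODECS.build(
    dict(
        type='MSRAHeatmap',
        input_size=(192, 256),
        heatmap_size=(48, 64),
        sigma=2))
print(type(codec).__name__)  # MSRAHeatmap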
mmpose/codecs/annotation_processors.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Dict, List, Optional, Tuple
3
+
4
+ import numpy as np
5
+
6
+ from mmpose.registry import KEYPOINT_CODECS
7
+ from .base import BaseKeypointCodec
8
+
9
+ INF = 1e6
10
+ NEG_INF = -1e6
11
+
12
+
13
+ class BaseAnnotationProcessor(BaseKeypointCodec):
14
+ """Base class for annotation processors."""
15
+
16
+ def decode(self, *args, **kwargs):
17
+ pass
18
+
19
+
20
+ @KEYPOINT_CODECS.register_module()
21
+ class YOLOXPoseAnnotationProcessor(BaseAnnotationProcessor):
22
+ """Convert dataset annotations to the input format of YOLOX-Pose.
23
+
24
+ This processor expands bounding boxes and converts category IDs to labels.
25
+
26
+ Args:
27
+ extend_bbox (bool, optional): Whether to expand the bounding box
28
+ to include all keypoints. Defaults to False.
29
+ input_size (tuple, optional): The size of the input image for the
30
+ model, formatted as (h, w). This argument is necessary for the
31
+ codec in deployment but is not used indeed.
32
+ """
33
+
34
+ auxiliary_encode_keys = {'category_id', 'bbox'}
35
+ label_mapping_table = dict(
36
+ bbox='bboxes',
37
+ bbox_labels='labels',
38
+ keypoints='keypoints',
39
+ keypoints_visible='keypoints_visible',
40
+ area='areas',
41
+ )
42
+ instance_mapping_table = dict(
43
+ bbox='bboxes',
44
+ bbox_score='bbox_scores',
45
+ keypoints='keypoints',
46
+ keypoints_visible='keypoints_visible',
47
+ # remove 'bbox_scales' in default instance_mapping_table to avoid
48
+ # length mismatch during training with multiple datasets
49
+ )
50
+
51
+ def __init__(self,
52
+ extend_bbox: bool = False,
53
+ input_size: Optional[Tuple] = None):
54
+ super().__init__()
55
+ self.extend_bbox = extend_bbox
56
+
57
+ def encode(self,
58
+ keypoints: Optional[np.ndarray] = None,
59
+ keypoints_visible: Optional[np.ndarray] = None,
60
+ bbox: Optional[np.ndarray] = None,
61
+ category_id: Optional[List[int]] = None
62
+ ) -> Dict[str, np.ndarray]:
63
+ """Encode keypoints, bounding boxes, and category IDs.
64
+
65
+ Args:
66
+ keypoints (np.ndarray, optional): Keypoints array. Defaults
67
+ to None.
68
+ keypoints_visible (np.ndarray, optional): Visibility array for
69
+ keypoints. Defaults to None.
70
+ bbox (np.ndarray, optional): Bounding box array. Defaults to None.
71
+ category_id (List[int], optional): List of category IDs. Defaults
72
+ to None.
73
+
74
+ Returns:
75
+ Dict[str, np.ndarray]: Encoded annotations.
76
+ """
77
+ results = {}
78
+
79
+ if self.extend_bbox and bbox is not None:
80
+ # Handle keypoints visibility
81
+ if keypoints_visible.ndim == 3:
82
+ keypoints_visible = keypoints_visible[..., 0]
83
+
84
+ # Expand bounding box to include keypoints
85
+ kpts_min = keypoints.copy()
86
+ kpts_min[keypoints_visible == 0] = INF
87
+ bbox[..., :2] = np.minimum(bbox[..., :2], kpts_min.min(axis=1))
88
+
89
+ kpts_max = keypoints.copy()
90
+ kpts_max[keypoints_visible == 0] = NEG_INF
91
+ bbox[..., 2:] = np.maximum(bbox[..., 2:], kpts_max.max(axis=1))
92
+
93
+ results['bbox'] = bbox
94
+
95
+ if category_id is not None:
96
+ # Convert category IDs to labels
97
+ bbox_labels = np.array(category_id).astype(np.int8) - 1
98
+ results['bbox_labels'] = bbox_labels
99
+
100
+ return results
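A minimal sketch of the encode step with made-up values, illustrating the bbox expansion and the category-id-to-label conversion described above:

import numpy as np
from mmpose.codecs import YOLOXPoseAnnotationProcessor

processor = YOLOXPoseAnnotationProcessor(extend_bbox=True)
encoded = processor.encode(
    keypoints=np.array([[[12., 20.], [48., 60.]]]),    # (1 instance, 2 keypoints, 2)
    keypoints_visible=np.array([[1., 1.]]),            # both keypoints visible
    bbox=np.array([[15., 25., 40., 55.]]),             # xyxy box that misses both keypoints
    category_id=[1])
print(encoded['bbox'])         # expanded to cover both keypoints: [[12. 20. 48. 60.]]
print(encoded['bbox_labels'])  # [0] (category id 1 becomes label 0)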
mmpose/codecs/associative_embedding.py ADDED
@@ -0,0 +1,522 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from itertools import product
3
+ from typing import Any, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ from munkres import Munkres
8
+ from torch import Tensor
9
+
10
+ from mmpose.registry import KEYPOINT_CODECS
11
+ from mmpose.utils.tensor_utils import to_numpy
12
+ from .base import BaseKeypointCodec
13
+ from .utils import (batch_heatmap_nms, generate_gaussian_heatmaps,
14
+ generate_udp_gaussian_heatmaps, refine_keypoints,
15
+ refine_keypoints_dark_udp)
16
+
17
+
18
+ def _py_max_match(scores):
19
+ """Apply munkres algorithm to get the best match.
20
+
21
+ Args:
22
+ scores(np.ndarray): cost matrix.
23
+
24
+ Returns:
25
+ np.ndarray: best match.
26
+ """
27
+ m = Munkres()
28
+ tmp = m.compute(scores)
29
+ tmp = np.array(tmp).astype(int)
30
+ return tmp
31
+
32
+
33
+ def _group_keypoints_by_tags(vals: np.ndarray,
34
+ tags: np.ndarray,
35
+ locs: np.ndarray,
36
+ keypoint_order: List[int],
37
+ val_thr: float,
38
+ tag_thr: float = 1.0,
39
+ max_groups: Optional[int] = None) -> np.ndarray:
40
+ """Group the keypoints by tags using Munkres algorithm.
41
+
42
+ Note:
43
+
44
+ - keypoint number: K
45
+ - candidate number: M
46
+ - tag dimension: L
47
+ - coordinate dimension: D
48
+ - group number: G
49
+
50
+ Args:
51
+ vals (np.ndarray): The heatmap response values of keypoints in shape
52
+ (K, M)
53
+ tags (np.ndarray): The tags of the keypoint candidates in shape
54
+ (K, M, L)
55
+ locs (np.ndarray): The locations of the keypoint candidates in shape
56
+ (K, M, D)
57
+ keypoint_order (List[int]): The grouping order of the keypoints.
58
+ The groupping usually starts from a keypoints around the head and
59
+ torso, and gruadually moves out to the limbs
60
+ val_thr (float): The threshold of the keypoint response value
61
+ tag_thr (float): The maximum allowed tag distance when matching a
62
+ keypoint to a group. A keypoint with larger tag distance to any
63
+ of the existing groups will initializes a new group
64
+ max_groups (int, optional): The maximum group number. ``None`` means
65
+ no limitation. Defaults to ``None``
66
+
67
+ Returns:
68
+ np.ndarray: grouped keypoints in shape (G, K, D+1), where the last
69
+ dimension is the concatenated keypoint coordinates and scores.
70
+ """
71
+
72
+ tag_k, loc_k, val_k = tags, locs, vals
73
+ K, M, D = locs.shape
74
+ assert vals.shape == tags.shape[:2] == (K, M)
75
+ assert len(keypoint_order) == K
76
+
77
+ default_ = np.zeros((K, 3 + tag_k.shape[2]), dtype=np.float32)
78
+
79
+ joint_dict = {}
80
+ tag_dict = {}
81
+ for i in range(K):
82
+ idx = keypoint_order[i]
83
+
84
+ tags = tag_k[idx]
85
+ joints = np.concatenate((loc_k[idx], val_k[idx, :, None], tags), 1)
86
+ mask = joints[:, 2] > val_thr
87
+ tags = tags[mask] # shape: [M, L]
88
+ joints = joints[mask] # shape: [M, 3 + L], 3: x, y, val
89
+
90
+ if joints.shape[0] == 0:
91
+ continue
92
+
93
+ if i == 0 or len(joint_dict) == 0:
94
+ for tag, joint in zip(tags, joints):
95
+ key = tag[0]
96
+ joint_dict.setdefault(key, np.copy(default_))[idx] = joint
97
+ tag_dict[key] = [tag]
98
+ else:
99
+ # shape: [M]
100
+ grouped_keys = list(joint_dict.keys())
101
+ # shape: [M, L]
102
+ grouped_tags = [np.mean(tag_dict[i], axis=0) for i in grouped_keys]
103
+
104
+ # shape: [M, M, L]
105
+ diff = joints[:, None, 3:] - np.array(grouped_tags)[None, :, :]
106
+ # shape: [M, M]
107
+ diff_normed = np.linalg.norm(diff, ord=2, axis=2)
108
+ diff_saved = np.copy(diff_normed)
109
+ diff_normed = np.round(diff_normed) * 100 - joints[:, 2:3]
110
+
111
+ num_added = diff.shape[0]
112
+ num_grouped = diff.shape[1]
113
+
114
+ if num_added > num_grouped:
115
+ diff_normed = np.concatenate(
116
+ (diff_normed,
117
+ np.zeros((num_added, num_added - num_grouped),
118
+ dtype=np.float32) + 1e10),
119
+ axis=1)
120
+
121
+ pairs = _py_max_match(diff_normed)
122
+ for row, col in pairs:
123
+ if (row < num_added and col < num_grouped
124
+ and diff_saved[row][col] < tag_thr):
125
+ key = grouped_keys[col]
126
+ joint_dict[key][idx] = joints[row]
127
+ tag_dict[key].append(tags[row])
128
+ else:
129
+ key = tags[row][0]
130
+ joint_dict.setdefault(key, np.copy(default_))[idx] = \
131
+ joints[row]
132
+ tag_dict[key] = [tags[row]]
133
+
134
+ joint_dict_keys = list(joint_dict.keys())[:max_groups]
135
+
136
+ if joint_dict_keys:
137
+ results = np.array([joint_dict[i]
138
+ for i in joint_dict_keys]).astype(np.float32)
139
+ results = results[..., :D + 1]
140
+ else:
141
+ results = np.empty((0, K, D + 1), dtype=np.float32)
142
+ return results
143
+
144
+
145
+ @KEYPOINT_CODECS.register_module()
146
+ class AssociativeEmbedding(BaseKeypointCodec):
147
+ """Encode/decode keypoints with the method introduced in "Associative
148
+ Embedding". This is an asymmetric codec, where the keypoints are
149
+ represented as gaussian heatmaps and position indices during encoding, and
150
+ restored from predicted heatmaps and group tags.
151
+
152
+ See the paper `Associative Embedding: End-to-End Learning for Joint
153
+ Detection and Grouping`_ by Newell et al (2017) for details
154
+
155
+ Note:
156
+
157
+ - instance number: N
158
+ - keypoint number: K
159
+ - keypoint dimension: D
160
+ - embedding tag dimension: L
161
+ - image size: [w, h]
162
+ - heatmap size: [W, H]
163
+
164
+ Encoded:
165
+
166
+ - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W)
167
+ where [W, H] is the `heatmap_size`
168
+ - keypoint_indices (np.ndarray): The keypoint position indices in shape
169
+ (N, K, 2). Each keypoint's index is [i, v], where i is the position
170
+ index in the heatmap (:math:`i=y*w+x`) and v is the visibility
171
+ - keypoint_weights (np.ndarray): The target weights in shape (N, K)
172
+
173
+ Args:
174
+ input_size (tuple): Image size in [w, h]
175
+ heatmap_size (tuple): Heatmap size in [W, H]
176
+ sigma (float): The sigma value of the Gaussian heatmap
177
+ use_udp (bool): Whether use unbiased data processing. See
178
+ `UDP (CVPR 2020)`_ for details. Defaults to ``False``
179
+ decode_keypoint_order (List[int]): The grouping order of the
180
+ keypoint indices. The grouping usually starts from keypoints
181
+ around the head and torso, and gradually moves out to the limbs
182
+ decode_keypoint_thr (float): The threshold of keypoint response value
183
+ in heatmaps. Defaults to 0.1
184
+ decode_tag_thr (float): The maximum allowed tag distance when matching
185
+ a keypoint to a group. A keypoint with larger tag distance to any
186
+ of the existing groups will initialize a new group. Defaults to
187
+ 1.0
188
+ decode_nms_kernel (int): The kernel size of the NMS during decoding,
189
+ which should be an odd integer. Defaults to 5
190
+ decode_gaussian_kernel (int): The kernel size of the Gaussian blur
191
+ during decoding, which should be an odd integer. It is only used
192
+ when ``self.use_udp==True``. Defaults to 3
193
+ decode_topk (int): The number top-k candidates of each keypoints that
194
+ will be retrieved from the heatmaps during decoding. Defaults to
195
+ 20
196
+ decode_max_instances (int, optional): The maximum number of instances
197
+ to decode. ``None`` means no limitation to the instance number.
198
+ Defaults to ``None``
199
+
200
+ .. _`Associative Embedding: End-to-End Learning for Joint Detection and
201
+ Grouping`: https://arxiv.org/abs/1611.05424
202
+ .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
203
+ """
204
+
205
+ def __init__(
206
+ self,
207
+ input_size: Tuple[int, int],
208
+ heatmap_size: Tuple[int, int],
209
+ sigma: Optional[float] = None,
210
+ use_udp: bool = False,
211
+ decode_keypoint_order: List[int] = [],
212
+ decode_nms_kernel: int = 5,
213
+ decode_gaussian_kernel: int = 3,
214
+ decode_keypoint_thr: float = 0.1,
215
+ decode_tag_thr: float = 1.0,
216
+ decode_topk: int = 30,
217
+ decode_center_shift=0.0,
218
+ decode_max_instances: Optional[int] = None,
219
+ ) -> None:
220
+ super().__init__()
221
+ self.input_size = input_size
222
+ self.heatmap_size = heatmap_size
223
+ self.use_udp = use_udp
224
+ self.decode_nms_kernel = decode_nms_kernel
225
+ self.decode_gaussian_kernel = decode_gaussian_kernel
226
+ self.decode_keypoint_thr = decode_keypoint_thr
227
+ self.decode_tag_thr = decode_tag_thr
228
+ self.decode_topk = decode_topk
229
+ self.decode_center_shift = decode_center_shift
230
+ self.decode_max_instances = decode_max_instances
231
+ self.decode_keypoint_order = decode_keypoint_order.copy()
232
+
233
+ if self.use_udp:
234
+ self.scale_factor = ((np.array(input_size) - 1) /
235
+ (np.array(heatmap_size) - 1)).astype(
236
+ np.float32)
237
+ else:
238
+ self.scale_factor = (np.array(input_size) /
239
+ heatmap_size).astype(np.float32)
240
+
241
+ if sigma is None:
242
+ sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 64
243
+ self.sigma = sigma
244
+
245
+ def encode(
246
+ self,
247
+ keypoints: np.ndarray,
248
+ keypoints_visible: Optional[np.ndarray] = None
249
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
250
+ """Encode keypoints into heatmaps and position indices. Note that the
251
+ original keypoint coordinates should be in the input image space.
252
+
253
+ Args:
254
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
255
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
256
+ (N, K)
257
+
258
+ Returns:
259
+ dict:
260
+ - heatmaps (np.ndarray): The generated heatmap in shape
261
+ (K, H, W) where [W, H] is the `heatmap_size`
262
+ - keypoint_indices (np.ndarray): The keypoint position indices
263
+ in shape (N, K, 2). Each keypoint's index is [i, v], where i
264
+ is the position index in the heatmap (:math:`i=y*w+x`) and v
265
+ is the visibility
266
+ - keypoint_weights (np.ndarray): The target weights in shape
267
+ (N, K)
268
+ """
269
+
270
+ if keypoints_visible is None:
271
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
272
+
273
+ # keypoint coordinates in heatmap
274
+ _keypoints = keypoints / self.scale_factor
275
+
276
+ if self.use_udp:
277
+ heatmaps, keypoint_weights = generate_udp_gaussian_heatmaps(
278
+ heatmap_size=self.heatmap_size,
279
+ keypoints=_keypoints,
280
+ keypoints_visible=keypoints_visible,
281
+ sigma=self.sigma)
282
+ else:
283
+ heatmaps, keypoint_weights = generate_gaussian_heatmaps(
284
+ heatmap_size=self.heatmap_size,
285
+ keypoints=_keypoints,
286
+ keypoints_visible=keypoints_visible,
287
+ sigma=self.sigma)
288
+
289
+ keypoint_indices = self._encode_keypoint_indices(
290
+ heatmap_size=self.heatmap_size,
291
+ keypoints=_keypoints,
292
+ keypoints_visible=keypoints_visible)
293
+
294
+ encoded = dict(
295
+ heatmaps=heatmaps,
296
+ keypoint_indices=keypoint_indices,
297
+ keypoint_weights=keypoint_weights)
298
+
299
+ return encoded
300
+
301
+ def _encode_keypoint_indices(self, heatmap_size: Tuple[int, int],
302
+ keypoints: np.ndarray,
303
+ keypoints_visible: np.ndarray) -> np.ndarray:
304
+ w, h = heatmap_size
305
+ N, K, _ = keypoints.shape
306
+ keypoint_indices = np.zeros((N, K, 2), dtype=np.int64)
307
+
308
+ for n, k in product(range(N), range(K)):
309
+ x, y = (keypoints[n, k] + 0.5).astype(np.int64)
310
+ index = y * w + x
311
+ vis = (keypoints_visible[n, k] > 0.5 and 0 <= x < w and 0 <= y < h)
312
+ keypoint_indices[n, k] = [index, vis]
313
+
314
+ return keypoint_indices
315
+
316
+ def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]:
317
+ raise NotImplementedError()
318
+
319
+ def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor,
320
+ k: int):
321
+ """Get top-k response values from the heatmaps and corresponding tag
322
+ values from the tagging heatmaps.
323
+
324
+ Args:
325
+ batch_heatmaps (Tensor): Keypoint detection heatmaps in shape
326
+ (B, K, H, W)
327
+ batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where
328
+ the tag dim C is 2*K when using flip testing, or K otherwise
329
+ k (int): The number of top responses to get
330
+
331
+ Returns:
332
+ tuple:
333
+ - topk_vals (Tensor): Top-k response values of each heatmap in
334
+ shape (B, K, Topk)
335
+ - topk_tags (Tensor): The corresponding embedding tags of the
336
+ top-k responses, in shape (B, K, Topk, L)
337
+ - topk_locs (Tensor): The location of the top-k responses in each
338
+ heatmap, in shape (B, K, Topk, 2) where last dimension
339
+ represents x and y coordinates
340
+ """
341
+ B, K, H, W = batch_heatmaps.shape
342
+ L = batch_tags.shape[1] // K
343
+
344
+ # shape of topk_val, top_indices: (B, K, TopK)
345
+ topk_vals, topk_indices = batch_heatmaps.flatten(-2, -1).topk(
346
+ k, dim=-1)
347
+
348
+ topk_tags_per_kpts = [
349
+ torch.gather(_tag, dim=2, index=topk_indices)
350
+ for _tag in torch.unbind(batch_tags.view(B, L, K, H * W), dim=1)
351
+ ]
352
+
353
+ topk_tags = torch.stack(topk_tags_per_kpts, dim=-1) # (B, K, TopK, L)
354
+ topk_locs = torch.stack([topk_indices % W, topk_indices // W],
355
+ dim=-1) # (B, K, TopK, 2)
356
+
357
+ return topk_vals, topk_tags, topk_locs
358
+
359
+ def _group_keypoints(self, batch_vals: np.ndarray, batch_tags: np.ndarray,
360
+ batch_locs: np.ndarray):
361
+ """Group keypoints into groups (each represents an instance) by tags.
362
+
363
+ Args:
364
+ batch_vals (Tensor): Heatmap response values of keypoint
365
+ candidates in shape (B, K, Topk)
366
+ batch_tags (Tensor): Tags of keypoint candidates in shape
367
+ (B, K, Topk, L)
368
+ batch_locs (Tensor): Locations of keypoint candidates in shape
369
+ (B, K, Topk, 2)
370
+
371
+ Returns:
372
+ List[np.ndarray]: Grouping results of a batch, each element is a
373
+ np.ndarray (in shape [N, K, D+1]) that contains the groups
374
+ detected in an image, including both keypoint coordinates and
375
+ scores.
376
+ """
377
+
378
+ def _group_func(inputs: Tuple):
379
+ vals, tags, locs = inputs
380
+ return _group_keypoints_by_tags(
381
+ vals,
382
+ tags,
383
+ locs,
384
+ keypoint_order=self.decode_keypoint_order,
385
+ val_thr=self.decode_keypoint_thr,
386
+ tag_thr=self.decode_tag_thr,
387
+ max_groups=self.decode_max_instances)
388
+
389
+ _results = map(_group_func, zip(batch_vals, batch_tags, batch_locs))
390
+ results = list(_results)
391
+ return results
392
+
393
+ def _fill_missing_keypoints(self, keypoints: np.ndarray,
394
+ keypoint_scores: np.ndarray,
395
+ heatmaps: np.ndarray, tags: np.ndarray):
396
+ """Fill the missing keypoints in the initial predictions.
397
+
398
+ Args:
399
+ keypoints (np.ndarray): Keypoint predictions in shape (N, K, D)
400
+ keypoint_scores (np.ndarray): Keypoint score predictions in shape
401
+ (N, K), in which 0 means the corresponding keypoint is
402
+ missing in the initial prediction
403
+ heatmaps (np.ndarray): Heatmaps in shape (K, H, W)
404
+ tags (np.ndarray): Tagging heatmaps in shape (C, H, W) where
405
+ C=L*K
406
+
407
+ Returns:
408
+ tuple:
409
+ - keypoints (np.ndarray): Keypoint predictions with missing
410
+ ones filled
411
+ - keypoint_scores (np.ndarray): Keypoint score predictions with
412
+ missing ones filled
413
+ """
414
+
415
+ N, K = keypoints.shape[:2]
416
+ H, W = heatmaps.shape[1:]
417
+ L = tags.shape[0] // K
418
+ keypoint_tags = [tags[k::K] for k in range(K)]
419
+
420
+ for n in range(N):
421
+ # Calculate the instance tag (mean tag of detected keypoints)
422
+ _tag = []
423
+ for k in range(K):
424
+ if keypoint_scores[n, k] > 0:
425
+ x, y = keypoints[n, k, :2].astype(np.int64)
426
+ x = np.clip(x, 0, W - 1)
427
+ y = np.clip(y, 0, H - 1)
428
+ _tag.append(keypoint_tags[k][:, y, x])
429
+
430
+ tag = np.mean(_tag, axis=0)
431
+ tag = tag.reshape(L, 1, 1)
432
+ # Search maximum response of the missing keypoints
433
+ for k in range(K):
434
+ if keypoint_scores[n, k] > 0:
435
+ continue
436
+ dist_map = np.linalg.norm(
437
+ keypoint_tags[k] - tag, ord=2, axis=0)
438
+ cost_map = np.round(dist_map) * 100 - heatmaps[k] # H, W
439
+ y, x = np.unravel_index(np.argmin(cost_map), shape=(H, W))
440
+ keypoints[n, k] = [x, y]
441
+ keypoint_scores[n, k] = heatmaps[k, y, x]
442
+
443
+ return keypoints, keypoint_scores
444
+
445
+ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
446
+ ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
447
+ """Decode the keypoint coordinates from a batch of heatmaps and tagging
448
+ heatmaps. The decoded keypoint coordinates are in the input image
449
+ space.
450
+
451
+ Args:
452
+ batch_heatmaps (Tensor): Keypoint detection heatmaps in shape
453
+ (B, K, H, W)
454
+ batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where
455
+ :math:`C=L*K`
456
+
457
+ Returns:
458
+ tuple:
459
+ - batch_keypoints (List[np.ndarray]): Decoded keypoint coordinates
460
+ of the batch, each is in shape (N, K, D)
461
+ - batch_scores (List[np.ndarray]): Decoded keypoint scores of the
462
+ batch, each is in shape (N, K). It usually represents the
463
+ confidence of the keypoint prediction
464
+ """
465
+ B, _, H, W = batch_heatmaps.shape
466
+ assert batch_tags.shape[0] == B and batch_tags.shape[2:4] == (H, W), (
467
+ f'Mismatched shapes of heatmap ({batch_heatmaps.shape}) and '
468
+ f'tagging map ({batch_tags.shape})')
469
+
470
+ # Heatmap NMS
471
+ batch_heatmaps_peak = batch_heatmap_nms(batch_heatmaps,
472
+ self.decode_nms_kernel)
473
+
474
+ # Get top-k in each heatmap and convert to numpy
475
+ batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy(
476
+ self._get_batch_topk(
477
+ batch_heatmaps_peak, batch_tags, k=self.decode_topk))
478
+
479
+ # Group keypoint candidates into groups (instances)
480
+ batch_groups = self._group_keypoints(batch_topk_vals, batch_topk_tags,
481
+ batch_topk_locs)
482
+
483
+ # Convert to numpy
484
+ batch_heatmaps_np = to_numpy(batch_heatmaps)
485
+ batch_tags_np = to_numpy(batch_tags)
486
+
487
+ # Refine the keypoint prediction
488
+ batch_keypoints = []
489
+ batch_keypoint_scores = []
490
+ batch_instance_scores = []
491
+ for i, (groups, heatmaps, tags) in enumerate(
492
+ zip(batch_groups, batch_heatmaps_np, batch_tags_np)):
493
+
494
+ keypoints, scores = groups[..., :-1], groups[..., -1]
495
+ instance_scores = scores.mean(axis=-1)
496
+
497
+ if keypoints.size > 0:
498
+ # refine keypoint coordinates according to heatmap distribution
499
+ if self.use_udp:
500
+ keypoints = refine_keypoints_dark_udp(
501
+ keypoints,
502
+ heatmaps,
503
+ blur_kernel_size=self.decode_gaussian_kernel)
504
+ else:
505
+ keypoints = refine_keypoints(keypoints, heatmaps)
506
+ keypoints += self.decode_center_shift * \
507
+ (scores > 0).astype(keypoints.dtype)[..., None]
508
+
509
+ # identify missing keypoints
510
+ keypoints, scores = self._fill_missing_keypoints(
511
+ keypoints, scores, heatmaps, tags)
512
+
513
+ batch_keypoints.append(keypoints)
514
+ batch_keypoint_scores.append(scores)
515
+ batch_instance_scores.append(instance_scores)
516
+
517
+ # restore keypoint scale
518
+ batch_keypoints = [
519
+ kpts * self.scale_factor for kpts in batch_keypoints
520
+ ]
521
+
522
+ return batch_keypoints, batch_keypoint_scores, batch_instance_scores
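A minimal encoding sketch for AssociativeEmbedding with toy values; the 17-keypoint COCO layout and the identity keypoint order are assumptions used only for illustration:

import numpy as np
from mmpose.codecs import AssociativeEmbedding

codec = AssociativeEmbedding(
    input_size=(512, 512),
    heatmap_size=(128, 128),
    decode_keypoint_order=list(range(17)))

keypoints = np.random.rand(2, 17, 2) * 512   # 2 instances in input-image coordinates
encoded = codec.encode(keypoints)
print(encoded['heatmaps'].shape)             # (17, 128, 128)
print(encoded['keypoint_indices'].shape)     # (2, 17, 2)
print(encoded['keypoint_weights'].shape)     # (2, 17)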
mmpose/codecs/base.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from abc import ABCMeta, abstractmethod
3
+ from typing import Any, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ from mmengine.utils import is_method_overridden
7
+
8
+
9
+ class BaseKeypointCodec(metaclass=ABCMeta):
10
+ """The base class of the keypoint codec.
11
+
12
+ A keypoint codec is a module to encode keypoint coordinates to specific
13
+ representation (e.g. heatmap) and vice versa. A subclass should implement
14
+ the methods :meth:`encode` and :meth:`decode`.
15
+ """
16
+
17
+ # pass additional encoding arguments to the `encode` method, beyond the
18
+ # mandatory `keypoints` and `keypoints_visible` arguments.
19
+ auxiliary_encode_keys = set()
20
+
21
+ field_mapping_table = dict()
22
+ instance_mapping_table = dict()
23
+ label_mapping_table = dict()
24
+
25
+ @abstractmethod
26
+ def encode(self,
27
+ keypoints: np.ndarray,
28
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
29
+ """Encode keypoints.
30
+
31
+ Note:
32
+
33
+ - instance number: N
34
+ - keypoint number: K
35
+ - keypoint dimension: D
36
+
37
+ Args:
38
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
39
+ keypoints_visible (np.ndarray): Keypoint visibility in shape
40
+ (N, K, D)
41
+
42
+ Returns:
43
+ dict: Encoded items.
44
+ """
45
+
46
+ @abstractmethod
47
+ def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]:
48
+ """Decode keypoints.
49
+
50
+ Args:
51
+ encoded (any): Encoded keypoint representation using the codec
52
+
53
+ Returns:
54
+ tuple:
55
+ - keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
56
+ - keypoints_visible (np.ndarray): Keypoint visibility in shape
57
+ (N, K, D)
58
+ """
59
+
60
+ def batch_decode(self, batch_encoded: Any
61
+ ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
62
+ """Decode keypoints.
63
+
64
+ Args:
65
+ batch_encoded (any): A batch of encoded keypoint
66
+ representations
67
+
68
+ Returns:
69
+ tuple:
70
+ - batch_keypoints (List[np.ndarray]): Each element is keypoint
71
+ coordinates in shape (N, K, D)
72
+ - batch_keypoints_visible (List[np.ndarray]): Each element is keypoint
73
+ visibility in shape (N, K)
74
+ """
75
+ raise NotImplementedError()
76
+
77
+ @property
78
+ def support_batch_decoding(self) -> bool:
79
+ """Return whether the codec support decoding from batch data."""
80
+ return is_method_overridden('batch_decode', BaseKeypointCodec,
81
+ self.__class__)
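A minimal sketch, not part of the commit, of how a subclass of BaseKeypointCodec is expected to look; the class name and pass-through behaviour are purely illustrative, and a real codec would additionally be registered with @KEYPOINT_CODECS.register_module():

import numpy as np

from mmpose.codecs.base import BaseKeypointCodec


class IdentityCodec(BaseKeypointCodec):
    """Toy codec that passes keypoints through unchanged (illustration only)."""

    def encode(self, keypoints, keypoints_visible=None):
        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
        return dict(keypoint_labels=keypoints.copy(),
                    keypoint_weights=keypoints_visible)

    def decode(self, encoded):
        scores = np.ones(encoded.shape[:-1], dtype=np.float32)
        return encoded.copy(), scores


codec = IdentityCodec()
print(codec.support_batch_decoding)  # False: batch_decode is not overridden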
mmpose/codecs/decoupled_heatmap.py ADDED
@@ -0,0 +1,274 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import random
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from mmpose.registry import KEYPOINT_CODECS
8
+ from .base import BaseKeypointCodec
9
+ from .utils import (generate_gaussian_heatmaps, get_diagonal_lengths,
10
+ get_instance_bbox, get_instance_root)
11
+ from .utils.post_processing import get_heatmap_maximum
12
+ from .utils.refinement import refine_keypoints
13
+
14
+
15
+ @KEYPOINT_CODECS.register_module()
16
+ class DecoupledHeatmap(BaseKeypointCodec):
17
+ """Encode/decode keypoints with the method introduced in the paper CID.
18
+
19
+ See the paper `Contextual Instance Decoupling for Robust Multi-Person
20
+ Pose Estimation`_ by Wang et al (2022) for details
21
+
22
+ Note:
23
+
24
+ - instance number: N
25
+ - keypoint number: K
26
+ - keypoint dimension: D
27
+ - image size: [w, h]
28
+ - heatmap size: [W, H]
29
+
30
+ Encoded:
31
+ - heatmaps (np.ndarray): The coupled heatmap in shape
32
+ (1+K, H, W) where [W, H] is the `heatmap_size`.
33
+ - instance_heatmaps (np.ndarray): The decoupled heatmap in shape
34
+ (M*K, H, W) where M is the number of instances.
35
+ - keypoint_weights (np.ndarray): The weight for heatmaps in shape
36
+ (M*K).
37
+ - instance_coords (np.ndarray): The coordinates of instance roots
38
+ in shape (M, 2)
39
+
40
+ Args:
41
+ input_size (tuple): Image size in [w, h]
42
+ heatmap_size (tuple): Heatmap size in [W, H]
43
+ root_type (str): The method to generate the instance root. Options
44
+ are:
45
+
46
+ - ``'kpt_center'``: Average coordinate of all visible keypoints.
47
+ - ``'bbox_center'``: Center point of bounding boxes outlined by
48
+ all visible keypoints.
49
+
50
+ Defaults to ``'kpt_center'``
51
+
52
+ heatmap_min_overlap (float): Minimum overlap rate among instances.
53
+ Used when calculating sigmas for instances. Defaults to 0.7
54
+ background_weight (float): Loss weight of background pixels.
55
+ Defaults to 0.1
56
+ encode_max_instances (int): The maximum number of instances
57
+ to encode for each sample. Defaults to 30
58
+
59
+ .. _`CID`: https://openaccess.thecvf.com/content/CVPR2022/html/Wang_
60
+ Contextual_Instance_Decoupling_for_Robust_Multi-Person_Pose_Estimation_
61
+ CVPR_2022_paper.html
62
+ """
63
+
64
+ # DecoupledHeatmap requires bounding boxes to determine the size of each
65
+ # instance, so that it can assign varying sigmas based on their size
66
+ auxiliary_encode_keys = {'bbox'}
67
+
68
+ label_mapping_table = dict(
69
+ keypoint_weights='keypoint_weights',
70
+ instance_coords='instance_coords',
71
+ )
72
+ field_mapping_table = dict(
73
+ heatmaps='heatmaps',
74
+ instance_heatmaps='instance_heatmaps',
75
+ )
76
+
77
+ def __init__(
78
+ self,
79
+ input_size: Tuple[int, int],
80
+ heatmap_size: Tuple[int, int],
81
+ root_type: str = 'kpt_center',
82
+ heatmap_min_overlap: float = 0.7,
83
+ encode_max_instances: int = 30,
84
+ ):
85
+ super().__init__()
86
+
87
+ self.input_size = input_size
88
+ self.heatmap_size = heatmap_size
89
+ self.root_type = root_type
90
+ self.encode_max_instances = encode_max_instances
91
+ self.heatmap_min_overlap = heatmap_min_overlap
92
+
93
+ self.scale_factor = (np.array(input_size) /
94
+ heatmap_size).astype(np.float32)
95
+
96
+ def _get_instance_wise_sigmas(
97
+ self,
98
+ bbox: np.ndarray,
99
+ ) -> np.ndarray:
100
+ """Get sigma values for each instance according to their size.
101
+
102
+ Args:
103
+ bbox (np.ndarray): Bounding box in shape (N, 4, 2)
104
+
105
+ Returns:
106
+ np.ndarray: Array containing the sigma values for each instance.
107
+ """
108
+ sigmas = np.zeros((bbox.shape[0], ), dtype=np.float32)
109
+
110
+ heights = np.sqrt(np.power(bbox[:, 0] - bbox[:, 1], 2).sum(axis=-1))
111
+ widths = np.sqrt(np.power(bbox[:, 0] - bbox[:, 2], 2).sum(axis=-1))
112
+
113
+ for i in range(bbox.shape[0]):
114
+ h, w = heights[i], widths[i]
115
+
116
+ # compute sigma for each instance
117
+ # condition 1
118
+ a1, b1 = 1, h + w
119
+ c1 = w * h * (1 - self.heatmap_min_overlap) / (
120
+ 1 + self.heatmap_min_overlap)
121
+ sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
122
+ r1 = (b1 + sq1) / 2
123
+
124
+ # condition 2
125
+ a2 = 4
126
+ b2 = 2 * (h + w)
127
+ c2 = (1 - self.heatmap_min_overlap) * w * h
128
+ sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
129
+ r2 = (b2 + sq2) / 2
130
+
131
+ # condition 3
132
+ a3 = 4 * self.heatmap_min_overlap
133
+ b3 = -2 * self.heatmap_min_overlap * (h + w)
134
+ c3 = (self.heatmap_min_overlap - 1) * w * h
135
+ sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
136
+ r3 = (b3 + sq3) / 2
137
+
138
+ sigmas[i] = min(r1, r2, r3) / 3
139
+
140
+ return sigmas
141
+
142
+ def encode(self,
143
+ keypoints: np.ndarray,
144
+ keypoints_visible: Optional[np.ndarray] = None,
145
+ bbox: Optional[np.ndarray] = None) -> dict:
146
+ """Encode keypoints into heatmaps.
147
+
148
+ Args:
149
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
150
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
151
+ (N, K)
152
+ bbox (np.ndarray): Bounding box in shape (N, 8) which includes
153
+ coordinates of 4 corners.
154
+
155
+ Returns:
156
+ dict:
157
+ - heatmaps (np.ndarray): The coupled heatmap in shape
158
+ (1+K, H, W) where [W, H] is the `heatmap_size`.
159
+ - instance_heatmaps (np.ndarray): The decoupled heatmap in shape
160
+ (N*K, H, W) where M is the number of instances.
161
+ - keypoint_weights (np.ndarray): The weight for heatmaps in shape
162
+ (N*K).
163
+ - instance_coords (np.ndarray): The coordinates of instance roots
164
+ in shape (N, 2)
165
+ """
166
+
167
+ if keypoints_visible is None:
168
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
169
+ if bbox is None:
170
+ # generate pseudo bbox via visible keypoints
171
+ bbox = get_instance_bbox(keypoints, keypoints_visible)
172
+ bbox = np.tile(bbox, 2).reshape(-1, 4, 2)
173
+ # corner order: left_top, left_bottom, right_top, right_bottom
174
+ bbox[:, 1:3, 0] = bbox[:, 0:2, 0]
175
+
176
+ # keypoint coordinates in heatmap
177
+ _keypoints = keypoints / self.scale_factor
178
+ _bbox = bbox.reshape(-1, 4, 2) / self.scale_factor
179
+
180
+ # compute the root and scale of each instance
181
+ roots, roots_visible = get_instance_root(_keypoints, keypoints_visible,
182
+ self.root_type)
183
+
184
+ sigmas = self._get_instance_wise_sigmas(_bbox)
185
+
186
+ # generate global heatmaps
187
+ heatmaps, keypoint_weights = generate_gaussian_heatmaps(
188
+ heatmap_size=self.heatmap_size,
189
+ keypoints=np.concatenate((_keypoints, roots[:, None]), axis=1),
190
+ keypoints_visible=np.concatenate(
191
+ (keypoints_visible, roots_visible[:, None]), axis=1),
192
+ sigma=sigmas)
193
+ roots_visible = keypoint_weights[:, -1]
194
+
195
+ # select instances
196
+ inst_roots, inst_indices = [], []
197
+ diagonal_lengths = get_diagonal_lengths(_keypoints, keypoints_visible)
198
+ for i in np.argsort(diagonal_lengths):
199
+ if roots_visible[i] < 1:
200
+ continue
201
+ # rand root point in 3x3 grid
202
+ x, y = roots[i] + np.random.randint(-1, 2, (2, ))
203
+ x = max(0, min(x, self.heatmap_size[0] - 1))
204
+ y = max(0, min(y, self.heatmap_size[1] - 1))
205
+ if (x, y) not in inst_roots:
206
+ inst_roots.append((x, y))
207
+ inst_indices.append(i)
208
+ if len(inst_indices) > self.encode_max_instances:
209
+ rand_indices = random.sample(
210
+ range(len(inst_indices)), self.encode_max_instances)
211
+ inst_roots = [inst_roots[i] for i in rand_indices]
212
+ inst_indices = [inst_indices[i] for i in rand_indices]
213
+
214
+ # generate instance-wise heatmaps
215
+ inst_heatmaps, inst_heatmap_weights = [], []
216
+ for i in inst_indices:
217
+ inst_heatmap, inst_heatmap_weight = generate_gaussian_heatmaps(
218
+ heatmap_size=self.heatmap_size,
219
+ keypoints=_keypoints[i:i + 1],
220
+ keypoints_visible=keypoints_visible[i:i + 1],
221
+ sigma=sigmas[i].item())
222
+ inst_heatmaps.append(inst_heatmap)
223
+ inst_heatmap_weights.append(inst_heatmap_weight)
224
+
225
+ if len(inst_indices) > 0:
226
+ inst_heatmaps = np.concatenate(inst_heatmaps)
227
+ inst_heatmap_weights = np.concatenate(inst_heatmap_weights)
228
+ inst_roots = np.array(inst_roots, dtype=np.int32)
229
+ else:
230
+ inst_heatmaps = np.empty((0, *self.heatmap_size[::-1]))
231
+ inst_heatmap_weights = np.empty((0, ))
232
+ inst_roots = np.empty((0, 2), dtype=np.int32)
233
+
234
+ encoded = dict(
235
+ heatmaps=heatmaps,
236
+ instance_heatmaps=inst_heatmaps,
237
+ keypoint_weights=inst_heatmap_weights,
238
+ instance_coords=inst_roots)
239
+
240
+ return encoded
241
+
242
+ def decode(self, instance_heatmaps: np.ndarray,
243
+ instance_scores: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
244
+ """Decode keypoint coordinates from decoupled heatmaps. The decoded
245
+ keypoint coordinates are in the input image space.
246
+
247
+ Args:
248
+ instance_heatmaps (np.ndarray): Heatmaps in shape (N, K, H, W)
249
+ instance_scores (np.ndarray): Confidence of instance roots
250
+ prediction in shape (N, 1)
251
+
252
+ Returns:
253
+ tuple:
254
+ - keypoints (np.ndarray): Decoded keypoint coordinates in shape
255
+ (N, K, D)
256
+ - scores (np.ndarray): The keypoint scores in shape (N, K). It
257
+ usually represents the confidence of the keypoint prediction
258
+ """
259
+ keypoints, keypoint_scores = [], []
260
+
261
+ for i in range(instance_heatmaps.shape[0]):
262
+ heatmaps = instance_heatmaps[i].copy()
263
+ kpts, scores = get_heatmap_maximum(heatmaps)
264
+ keypoints.append(refine_keypoints(kpts[None], heatmaps))
265
+ keypoint_scores.append(scores[None])
266
+
267
+ keypoints = np.concatenate(keypoints)
268
+ # Restore the keypoint scale
269
+ keypoints = keypoints * self.scale_factor
270
+
271
+ keypoint_scores = np.concatenate(keypoint_scores)
272
+ keypoint_scores *= instance_scores
273
+
274
+ return keypoints, keypoint_scores
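A usage sketch for the codec above, not part of the commit; it assumes DecoupledHeatmap is exported from mmpose.codecs as in upstream MMPose, and the sizes and keypoints are arbitrary:

import numpy as np

from mmpose.codecs import DecoupledHeatmap

codec = DecoupledHeatmap(input_size=(512, 512), heatmap_size=(128, 128))

# Two hypothetical instances with 17 keypoints each, in input-image pixels.
keypoints = np.random.rand(2, 17, 2) * 512
visible = np.ones((2, 17), dtype=np.float32)

encoded = codec.encode(keypoints, visible)  # pseudo bboxes are derived from keypoints
print(encoded['heatmaps'].shape)            # (18, 128, 128): K keypoint maps + 1 root map
print(encoded['instance_heatmaps'].shape)   # (M*17, 128, 128) with M <= 2 selected instances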
mmpose/codecs/edpose_label.py ADDED
@@ -0,0 +1,153 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+
6
+ from mmpose.registry import KEYPOINT_CODECS
7
+ from mmpose.structures import bbox_cs2xyxy, bbox_xyxy2cs
8
+ from .base import BaseKeypointCodec
9
+
10
+
11
+ @KEYPOINT_CODECS.register_module()
12
+ class EDPoseLabel(BaseKeypointCodec):
13
+ r"""Generate keypoint and label coordinates for `ED-Pose`_ by
14
+ Yang J. et al (2023).
15
+
16
+ Note:
17
+
18
+ - instance number: N
19
+ - keypoint number: K
20
+ - keypoint dimension: D
21
+ - image size: [w, h]
22
+
23
+ Encoded:
24
+
25
+ - keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
26
+ - keypoints_visible (np.ndarray): Keypoint visibility in shape
27
+ (N, K, D)
28
+ - area (np.ndarray): Area in shape (N)
29
+ - bbox (np.ndarray): Bbox in shape (N, 4)
30
+
31
+ Args:
32
+ num_select (int): The number of candidate instances
33
+ num_keypoints (int): The number of keypoints
34
+ """
35
+
36
+ auxiliary_encode_keys = {'area', 'bboxes', 'img_shape'}
37
+ instance_mapping_table = dict(
38
+ bbox='bboxes',
39
+ keypoints='keypoints',
40
+ keypoints_visible='keypoints_visible',
41
+ area='areas',
42
+ )
43
+
44
+ def __init__(self, num_select: int = 100, num_keypoints: int = 17):
45
+ super().__init__()
46
+
47
+ self.num_select = num_select
48
+ self.num_keypoints = num_keypoints
49
+
50
+ def encode(
51
+ self,
52
+ img_shape,
53
+ keypoints: np.ndarray,
54
+ keypoints_visible: Optional[np.ndarray] = None,
55
+ area: Optional[np.ndarray] = None,
56
+ bboxes: Optional[np.ndarray] = None,
57
+ ) -> dict:
58
+ """Encoding keypoints, area and bbox from input image space to
59
+ normalized space.
60
+
61
+ Args:
62
+ - img_shape (Sequence[int]): The shape of image in the format
63
+ of (width, height).
64
+ - keypoints (np.ndarray): Keypoint coordinates in
65
+ shape (N, K, D).
66
+ - keypoints_visible (np.ndarray): Keypoint visibility in shape
67
+ (N, K)
68
+ - area (np.ndarray): Instance areas in shape (N, )
69
+ - bboxes (np.ndarray): Bounding boxes in shape (N, 4), in xyxy format
70
+
71
+ Returns:
72
+ encoded (dict): Contains the following items:
73
+
74
+ - keypoint_labels (np.ndarray): The processed keypoints in
75
+ shape (N, K, D).
76
+ - keypoints_visible (np.ndarray): Keypoint visibility in shape
77
+ (N, K, D)
78
+ - area_labels (np.ndarray): The processed target
79
+ area in shape (N).
80
+ - bboxes_labels: The processed target bbox in
81
+ shape (N, 4).
82
+ """
83
+ w, h = img_shape
84
+
85
+ if keypoints_visible is None:
86
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
87
+
88
+ if bboxes is not None:
89
+ bboxes = np.concatenate(bbox_xyxy2cs(bboxes), axis=-1)
90
+ bboxes = bboxes / np.array([w, h, w, h], dtype=np.float32)
91
+
92
+ if area is not None:
93
+ area = area / float(w * h)
94
+
95
+ if keypoints is not None:
96
+ keypoints = keypoints / np.array([w, h], dtype=np.float32)
97
+
98
+ encoded = dict(
99
+ keypoints=keypoints,
100
+ area=area,
101
+ bbox=bboxes,
102
+ keypoints_visible=keypoints_visible)
103
+
104
+ return encoded
105
+
106
+ def decode(self, input_shapes: np.ndarray, pred_logits: np.ndarray,
107
+ pred_boxes: np.ndarray, pred_keypoints: np.ndarray):
108
+ """Select the final top-k keypoints, and decode the results from
109
+ normalized size to the original input size.
110
+
111
+ Args:
112
+ input_shapes (np.ndarray): The (height, width) of the resized input image.
113
114
+ pred_logits (np.ndarray): The predicted classification scores.
115
+ pred_boxes (np.ndarray): The predicted bounding boxes in (cx, cy, w, h).
116
+ pred_keypoints (np.ndarray): The predicted keypoint coordinates.
117
+
118
+ Returns:
119
+ tuple: Decoded boxes, keypoints, and keypoint scores.
120
+ """
121
+
122
+ # Initialization
123
+ num_keypoints = self.num_keypoints
124
+ prob = pred_logits.reshape(-1)
125
+
126
+ # Select top-k instances based on prediction scores
127
+ topk_indexes = np.argsort(-prob)[:self.num_select]
128
+ topk_values = np.take_along_axis(prob, topk_indexes, axis=0)
129
+ scores = np.tile(topk_values[:, np.newaxis], [1, num_keypoints])
130
+
131
+ # Decode bounding boxes
132
+ topk_boxes = topk_indexes // pred_logits.shape[1]
133
+ boxes = bbox_cs2xyxy(*np.split(pred_boxes, [2], axis=-1))
134
+ boxes = np.take_along_axis(
135
+ boxes, np.tile(topk_boxes[:, np.newaxis], [1, 4]), axis=0)
136
+
137
+ # Convert from relative to absolute coordinates
138
+ img_h, img_w = np.split(input_shapes, 2, axis=0)
139
+ scale_fct = np.hstack([img_w, img_h, img_w, img_h])
140
+ boxes = boxes * scale_fct[np.newaxis, :]
141
+
142
+ # Decode keypoints
143
+ topk_keypoints = topk_indexes // pred_logits.shape[1]
144
+ keypoints = np.take_along_axis(
145
+ pred_keypoints,
146
+ np.tile(topk_keypoints[:, np.newaxis], [1, num_keypoints * 3]),
147
+ axis=0)
148
+ keypoints = keypoints[:, :(num_keypoints * 2)]
149
+ keypoints = keypoints * np.tile(
150
+ np.hstack([img_w, img_h]), [num_keypoints])[np.newaxis, :]
151
+ keypoints = keypoints.reshape(-1, num_keypoints, 2)
152
+
153
+ return boxes, keypoints, scores
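A decoding sketch for EDPoseLabel, not part of the commit; the prediction arrays are random stand-ins shaped like a single-class query-based head, and the export path is assumed to follow upstream MMPose:

import numpy as np

from mmpose.codecs import EDPoseLabel

codec = EDPoseLabel(num_select=10, num_keypoints=17)

num_queries = 50
pred_logits = np.random.rand(num_queries, 1)          # per-query scores (one class)
pred_boxes = np.random.rand(num_queries, 4)           # (cx, cy, w, h), normalized
pred_keypoints = np.random.rand(num_queries, 17 * 3)  # (x, y, vis) per joint, normalized
input_shapes = np.array([480, 640])                   # (height, width) of the resized input

boxes, keypoints, scores = codec.decode(input_shapes, pred_logits,
                                        pred_boxes, pred_keypoints)
print(boxes.shape, keypoints.shape, scores.shape)     # (10, 4) (10, 17, 2) (10, 17)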
mmpose/codecs/hand_3d_heatmap.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Optional, Tuple
3
+
4
+ import numpy as np
5
+
6
+ from mmpose.registry import KEYPOINT_CODECS
7
+ from .base import BaseKeypointCodec
8
+ from .utils.gaussian_heatmap import generate_3d_gaussian_heatmaps
9
+ from .utils.post_processing import get_heatmap_3d_maximum
10
+
11
+
12
+ @KEYPOINT_CODECS.register_module()
13
+ class Hand3DHeatmap(BaseKeypointCodec):
14
+ r"""Generate target 3d heatmap and relative root depth for hand datasets.
15
+
16
+ Note:
17
+
18
+ - instance number: N
19
+ - keypoint number: K
20
+ - keypoint dimension: D
21
+
22
+ Args:
23
+ image_size (tuple): Size of image. Default: ``[256, 256]``.
24
+ root_heatmap_size (int): Size of heatmap of root head.
25
+ Default: 64.
26
+ heatmap_size (tuple): Size of heatmap. Default: ``[64, 64, 64]``.
27
+ heatmap3d_depth_bound (float): Boundary for 3d heatmap depth.
28
+ Default: 400.0.
29
+ heatmap_size_root (int): Size of 3d heatmap root. Default: 64.
30
+ depth_size (int): Number of depth discretization size, used for
31
+ decoding. Defaults to 64.
32
+ root_depth_bound (float): Boundary for 3d heatmap root depth.
33
+ Default: 400.0.
34
+ use_different_joint_weights (bool): Whether to use different joint
35
+ weights. Default: ``False``.
36
+ sigma (int): Sigma of heatmap gaussian. Default: 2.
37
+ joint_indices (list, optional): Indices of joints used for heatmap
38
+ generation. If None (default) is given, all joints will be used.
39
+ Default: ``None``.
40
+ max_bound (float): The maximal value of heatmap. Default: 1.0.
41
+ """
42
+
43
+ auxiliary_encode_keys = {
44
+ 'dataset_keypoint_weights', 'rel_root_depth', 'rel_root_valid',
45
+ 'hand_type', 'hand_type_valid', 'focal', 'principal_pt'
46
+ }
47
+
48
+ instance_mapping_table = {
49
+ 'keypoints': 'keypoints',
50
+ 'keypoints_visible': 'keypoints_visible',
51
+ 'keypoints_cam': 'keypoints_cam',
52
+ }
53
+
54
+ label_mapping_table = {
55
+ 'keypoint_weights': 'keypoint_weights',
56
+ 'root_depth_weight': 'root_depth_weight',
57
+ 'type_weight': 'type_weight',
58
+ 'root_depth': 'root_depth',
59
+ 'type': 'type'
60
+ }
61
+
62
+ def __init__(self,
63
+ image_size: Tuple[int, int] = [256, 256],
64
+ root_heatmap_size: int = 64,
65
+ heatmap_size: Tuple[int, int, int] = [64, 64, 64],
66
+ heatmap3d_depth_bound: float = 400.0,
67
+ heatmap_size_root: int = 64,
68
+ root_depth_bound: float = 400.0,
69
+ depth_size: int = 64,
70
+ use_different_joint_weights: bool = False,
71
+ sigma: int = 2,
72
+ joint_indices: Optional[list] = None,
73
+ max_bound: float = 1.0):
74
+ super().__init__()
75
+
76
+ self.image_size = np.array(image_size)
77
+ self.root_heatmap_size = root_heatmap_size
78
+ self.heatmap_size = np.array(heatmap_size)
79
+ self.heatmap3d_depth_bound = heatmap3d_depth_bound
80
+ self.heatmap_size_root = heatmap_size_root
81
+ self.root_depth_bound = root_depth_bound
82
+ self.depth_size = depth_size
83
+ self.use_different_joint_weights = use_different_joint_weights
84
+
85
+ self.sigma = sigma
86
+ self.joint_indices = joint_indices
87
+ self.max_bound = max_bound
88
+ self.scale_factor = (np.array(image_size) /
89
+ heatmap_size[:-1]).astype(np.float32)
90
+
91
+ def encode(
92
+ self,
93
+ keypoints: np.ndarray,
94
+ keypoints_visible: Optional[np.ndarray],
95
+ dataset_keypoint_weights: Optional[np.ndarray],
96
+ rel_root_depth: np.float32,
97
+ rel_root_valid: np.float32,
98
+ hand_type: np.ndarray,
99
+ hand_type_valid: np.ndarray,
100
+ focal: np.ndarray,
101
+ principal_pt: np.ndarray,
102
+ ) -> dict:
103
+ """Encoding keypoints from input image space to input image space.
104
+
105
+ Args:
106
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D).
107
+ keypoints_visible (np.ndarray, optional): Keypoint visibilities in
108
+ shape (N, K).
109
+ dataset_keypoint_weights (np.ndarray, optional): Keypoints weight
110
+ in shape (K, ).
111
+ rel_root_depth (np.float32): Relative root depth.
112
+ rel_root_valid (float): Validity of relative root depth.
113
+ hand_type (np.ndarray): Type of hand encoded as a array.
114
+ hand_type_valid (np.ndarray): Validity of hand type.
115
+ focal (np.ndarray): Focal length of camera.
116
+ principal_pt (np.ndarray): Principal point of camera.
117
+
118
+ Returns:
119
+ encoded (dict): Contains the following items:
120
+
121
+ - heatmaps (np.ndarray): The generated heatmap in shape
122
+ (K * D, H, W) where [W, H, D] is the `heatmap_size`
123
+ - keypoint_weights (np.ndarray): The target weights in shape
124
+ (N, K)
125
+ - root_depth (np.ndarray): Encoded relative root depth
126
+ - root_depth_weight (np.ndarray): The weights of relative root
127
+ depth
128
+ - type (np.ndarray): Encoded hand type
129
+ - type_weight (np.ndarray): The weights of hand type
130
+ """
131
+ if keypoints_visible is None:
132
+ keypoints_visible = np.ones(keypoints.shape[:-1], dtype=np.float32)
133
+
134
+ if self.use_different_joint_weights:
135
+ assert dataset_keypoint_weights is not None, 'To use different ' \
136
+ 'joint weights,`dataset_keypoint_weights` cannot be None.'
137
+
138
+ heatmaps, keypoint_weights = generate_3d_gaussian_heatmaps(
139
+ heatmap_size=self.heatmap_size,
140
+ keypoints=keypoints,
141
+ keypoints_visible=keypoints_visible,
142
+ sigma=self.sigma,
143
+ image_size=self.image_size,
144
+ heatmap3d_depth_bound=self.heatmap3d_depth_bound,
145
+ joint_indices=self.joint_indices,
146
+ max_bound=self.max_bound,
147
+ use_different_joint_weights=self.use_different_joint_weights,
148
+ dataset_keypoint_weights=dataset_keypoint_weights)
149
+
150
+ rel_root_depth = (rel_root_depth / self.root_depth_bound +
151
+ 0.5) * self.heatmap_size_root
152
+ rel_root_valid = rel_root_valid * (rel_root_depth >= 0) * (
153
+ rel_root_depth <= self.heatmap_size_root)
154
+
155
+ encoded = dict(
156
+ heatmaps=heatmaps,
157
+ keypoint_weights=keypoint_weights,
158
+ root_depth=rel_root_depth * np.ones(1, dtype=np.float32),
159
+ type=hand_type,
160
+ type_weight=hand_type_valid,
161
+ root_depth_weight=rel_root_valid * np.ones(1, dtype=np.float32))
162
+ return encoded
163
+
164
+ def decode(self, heatmaps: np.ndarray, root_depth: np.ndarray,
165
+ hand_type: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
166
+ """Decode keypoint coordinates from heatmaps. The decoded keypoint
167
+ coordinates are in the input image space.
168
+
169
+ Args:
170
+ heatmaps (np.ndarray): Heatmaps in shape (K, D, H, W)
171
+ root_depth (np.ndarray): Root depth prediction.
172
+ hand_type (np.ndarray): Hand type prediction.
173
+
174
+ Returns:
175
+ tuple:
176
+ - keypoints (np.ndarray): Decoded keypoint coordinates in shape
177
+ (N, K, D)
178
+ - scores (np.ndarray): The keypoint scores in shape (N, K). It
179
+ usually represents the confidence of the keypoint prediction
180
+ """
181
+ heatmap3d = heatmaps.copy()
182
+
183
+ keypoints, scores = get_heatmap_3d_maximum(heatmap3d)
184
+
185
+ # transform keypoint depth to camera space
186
+ keypoints[..., 2] = (keypoints[..., 2] / self.depth_size -
187
+ 0.5) * self.heatmap3d_depth_bound
188
+
189
+ # Unsqueeze the instance dimension for single-instance results
190
+ keypoints, scores = keypoints[None], scores[None]
191
+
192
+ # Restore the keypoint scale
193
+ keypoints[..., :2] = keypoints[..., :2] * self.scale_factor
194
+
195
+ # decode relative hand root depth
196
+ # transform relative root depth to camera space
197
+ rel_root_depth = ((root_depth / self.root_heatmap_size - 0.5) *
198
+ self.root_depth_bound)
199
+
200
+ hand_type = (hand_type > 0).reshape(1, -1).astype(int)
201
+
202
+ return keypoints, scores, rel_root_depth, hand_type
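A small numeric round trip, not part of the commit, for the relative root depth normalisation used by Hand3DHeatmap. The encoder scales by heatmap_size_root while the decoder divides by root_heatmap_size; both default to 64, which is what makes the round trip exact with default settings:

root_depth_bound, root_heatmap_size = 400.0, 64              # the class defaults

d = 120.0                                                    # hypothetical root depth in mm
encoded = (d / root_depth_bound + 0.5) * root_heatmap_size   # 51.2, the regression target
decoded = (encoded / root_heatmap_size - 0.5) * root_depth_bound
assert abs(decoded - d) < 1e-6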
mmpose/codecs/image_pose_lifting.py ADDED
@@ -0,0 +1,280 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+
6
+ from mmpose.registry import KEYPOINT_CODECS
7
+ from .base import BaseKeypointCodec
8
+
9
+
10
+ @KEYPOINT_CODECS.register_module()
11
+ class ImagePoseLifting(BaseKeypointCodec):
12
+ r"""Generate keypoint coordinates for pose lifter.
13
+
14
+ Note:
15
+
16
+ - instance number: N
17
+ - keypoint number: K
18
+ - keypoint dimension: D
19
+ - pose-lifting target dimension: C
20
+
21
+ Args:
22
+ num_keypoints (int): The number of keypoints in the dataset.
23
+ root_index (Union[int, List]): Root keypoint index in the pose.
24
+ remove_root (bool): If true, remove the root keypoint from the pose.
25
+ Default: ``False``.
26
+ save_index (bool): If true, store the root position separated from the
27
+ original pose. Default: ``False``.
28
+ reshape_keypoints (bool): If true, reshape the keypoints into shape
29
+ (-1, N). Default: ``True``.
30
+ concat_vis (bool): If true, concat the visibility item of keypoints.
31
+ Default: ``False``.
32
+ keypoints_mean (np.ndarray, optional): Mean values of keypoints
33
+ coordinates in shape (K, D).
34
+ keypoints_std (np.ndarray, optional): Std values of keypoints
35
+ coordinates in shape (K, D).
36
+ target_mean (np.ndarray, optional): Mean values of pose-lifting target
37
+ coordinates in shape (K, C).
38
+ target_std (np.ndarray, optional): Std values of pose-lifting target
39
+ coordinates in shape (K, C).
40
+ """
41
+
42
+ auxiliary_encode_keys = {'lifting_target', 'lifting_target_visible'}
43
+
44
+ instance_mapping_table = dict(
45
+ lifting_target='lifting_target',
46
+ lifting_target_visible='lifting_target_visible',
47
+ )
48
+ label_mapping_table = dict(
49
+ trajectory_weights='trajectory_weights',
50
+ lifting_target_label='lifting_target_label',
51
+ lifting_target_weight='lifting_target_weight')
52
+
53
+ def __init__(self,
54
+ num_keypoints: int,
55
+ root_index: Union[int, List] = 0,
56
+ remove_root: bool = False,
57
+ save_index: bool = False,
58
+ reshape_keypoints: bool = True,
59
+ concat_vis: bool = False,
60
+ keypoints_mean: Optional[np.ndarray] = None,
61
+ keypoints_std: Optional[np.ndarray] = None,
62
+ target_mean: Optional[np.ndarray] = None,
63
+ target_std: Optional[np.ndarray] = None,
64
+ additional_encode_keys: Optional[List[str]] = None):
65
+ super().__init__()
66
+
67
+ self.num_keypoints = num_keypoints
68
+ if isinstance(root_index, int):
69
+ root_index = [root_index]
70
+ self.root_index = root_index
71
+ self.remove_root = remove_root
72
+ self.save_index = save_index
73
+ self.reshape_keypoints = reshape_keypoints
74
+ self.concat_vis = concat_vis
75
+ if keypoints_mean is not None:
76
+ assert keypoints_std is not None, 'keypoints_std is None'
77
+ keypoints_mean = np.array(
78
+ keypoints_mean,
79
+ dtype=np.float32).reshape(1, num_keypoints, -1)
80
+ keypoints_std = np.array(
81
+ keypoints_std, dtype=np.float32).reshape(1, num_keypoints, -1)
82
+
83
+ assert keypoints_mean.shape == keypoints_std.shape, (
84
+ f'keypoints_mean.shape {keypoints_mean.shape} != '
85
+ f'keypoints_std.shape {keypoints_std.shape}')
86
+ if target_mean is not None:
87
+ assert target_std is not None, 'target_std is None'
88
+ target_dim = num_keypoints - 1 if remove_root else num_keypoints
89
+ target_mean = np.array(
90
+ target_mean, dtype=np.float32).reshape(1, target_dim, -1)
91
+ target_std = np.array(
92
+ target_std, dtype=np.float32).reshape(1, target_dim, -1)
93
+
94
+ assert target_mean.shape == target_std.shape, (
95
+ f'target_mean.shape {target_mean.shape} != '
96
+ f'target_std.shape {target_std.shape}')
97
+ self.keypoints_mean = keypoints_mean
98
+ self.keypoints_std = keypoints_std
99
+ self.target_mean = target_mean
100
+ self.target_std = target_std
101
+
102
+ if additional_encode_keys is not None:
103
+ self.auxiliary_encode_keys.update(additional_encode_keys)
104
+
105
+ def encode(self,
106
+ keypoints: np.ndarray,
107
+ keypoints_visible: Optional[np.ndarray] = None,
108
+ lifting_target: Optional[np.ndarray] = None,
109
+ lifting_target_visible: Optional[np.ndarray] = None) -> dict:
110
+ """Encoding keypoints from input image space to normalized space.
111
+
112
+ Args:
113
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D).
114
+ keypoints_visible (np.ndarray, optional): Keypoint visibilities in
115
+ shape (N, K).
116
+ lifting_target (np.ndarray, optional): 3d target coordinate in
117
+ shape (T, K, C).
118
+ lifting_target_visible (np.ndarray, optional): Target visibility in
119
+ shape (T, K, ).
120
+
121
+ Returns:
122
+ encoded (dict): Contains the following items:
123
+
124
+ - keypoint_labels (np.ndarray): The processed keypoints in
125
+ shape (N, K, D) or (K * D, N).
126
+ - keypoint_labels_visible (np.ndarray): The processed
127
+ keypoints' weights in shape (N, K, ) or (N-1, K, ).
128
+ - lifting_target_label: The processed target coordinate in
129
+ shape (K, C) or (K-1, C).
130
+ - lifting_target_weight (np.ndarray): The target weights in
131
+ shape (K, ) or (K-1, ).
132
+ - trajectory_weights (np.ndarray): The trajectory weights in
133
+ shape (K, ).
134
+ - target_root (np.ndarray): The root coordinate of target in
135
+ shape (C, ).
136
+
137
+ In addition, there are some optional items it may contain:
138
+
139
+ - target_root (np.ndarray): The root coordinate of target in
140
+ shape (C, ). Exists if ``zero_center`` is ``True``.
141
+ - target_root_removed (bool): Indicate whether the root of
142
+ pose-lifting target is removed. Exists if
143
+ ``remove_root`` is ``True``.
144
+ - target_root_index (int): An integer indicating the index of
145
+ root. Exists if ``remove_root`` and ``save_index``
146
+ are ``True``.
147
+ """
148
+ if keypoints_visible is None:
149
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
150
+
151
+ if lifting_target is None:
152
+ lifting_target = [keypoints[0]]
153
+
154
+ # set initial value for `lifting_target_weight`
155
+ # and `trajectory_weights`
156
+ if lifting_target_visible is None:
157
+ lifting_target_visible = np.ones(
158
+ lifting_target.shape[:-1], dtype=np.float32)
159
+ lifting_target_weight = lifting_target_visible
160
+ trajectory_weights = (1 / lifting_target[:, 2])
161
+ else:
162
+ valid = lifting_target_visible > 0.5
163
+ lifting_target_weight = np.where(valid, 1., 0.).astype(np.float32)
164
+ trajectory_weights = lifting_target_weight
165
+
166
+ encoded = dict()
167
+
168
+ # Zero-center the target pose around a given root keypoint
169
+ assert (lifting_target.ndim >= 2 and
170
+ lifting_target.shape[-2] > max(self.root_index)), \
171
+ f'Got invalid joint shape {lifting_target.shape}'
172
+
173
+ root = np.mean(
174
+ lifting_target[..., self.root_index, :], axis=-2, dtype=np.float32)
175
+ lifting_target_label = lifting_target - root[np.newaxis, ...]
176
+
177
+ if self.remove_root and len(self.root_index) == 1:
178
+ root_index = self.root_index[0]
179
+ lifting_target_label = np.delete(
180
+ lifting_target_label, root_index, axis=-2)
181
+ lifting_target_visible = np.delete(
182
+ lifting_target_visible, root_index, axis=-2)
183
+ assert lifting_target_weight.ndim in {
184
+ 2, 3
185
+ }, (f'lifting_target_weight.ndim {lifting_target_weight.ndim} '
186
+ 'is not in {2, 3}')
187
+
188
+ axis_to_remove = -2 if lifting_target_weight.ndim == 3 else -1
189
+ lifting_target_weight = np.delete(
190
+ lifting_target_weight, root_index, axis=axis_to_remove)
191
+ # Add a flag to avoid latter transforms that rely on the root
192
+ # joint or the original joint index
193
+ encoded['target_root_removed'] = True
194
+
195
+ # Save the root index which is necessary to restore the global pose
196
+ if self.save_index:
197
+ encoded['target_root_index'] = root_index
198
+
199
+ # Normalize the 2D keypoint coordinate with mean and std
200
+ keypoint_labels = keypoints.copy()
201
+
202
+ if self.keypoints_mean is not None:
203
+ assert self.keypoints_mean.shape[1:] == keypoints.shape[1:], (
204
+ f'self.keypoints_mean.shape[1:] {self.keypoints_mean.shape[1:]} ' # noqa
205
+ f'!= keypoints.shape[1:] {keypoints.shape[1:]}')
206
+ encoded['keypoints_mean'] = self.keypoints_mean.copy()
207
+ encoded['keypoints_std'] = self.keypoints_std.copy()
208
+
209
+ keypoint_labels = (keypoint_labels -
210
+ self.keypoints_mean) / self.keypoints_std
211
+ if self.target_mean is not None:
212
+ assert self.target_mean.shape == lifting_target_label.shape, (
213
+ f'self.target_mean.shape {self.target_mean.shape} '
214
+ f'!= lifting_target_label.shape {lifting_target_label.shape}' # noqa
215
+ )
216
+ encoded['target_mean'] = self.target_mean.copy()
217
+ encoded['target_std'] = self.target_std.copy()
218
+
219
+ lifting_target_label = (lifting_target_label -
220
+ self.target_mean) / self.target_std
221
+
222
+ # Generate reshaped keypoint coordinates
223
+ assert keypoint_labels.ndim in {
224
+ 2, 3
225
+ }, (f'keypoint_labels.ndim {keypoint_labels.ndim} is not in {2, 3}')
226
+ if keypoint_labels.ndim == 2:
227
+ keypoint_labels = keypoint_labels[None, ...]
228
+
229
+ if self.concat_vis:
230
+ keypoints_visible_ = keypoints_visible
231
+ if keypoints_visible.ndim == 2:
232
+ keypoints_visible_ = keypoints_visible[..., None]
233
+ keypoint_labels = np.concatenate(
234
+ (keypoint_labels, keypoints_visible_), axis=2)
235
+
236
+ if self.reshape_keypoints:
237
+ N = keypoint_labels.shape[0]
238
+ keypoint_labels = keypoint_labels.transpose(1, 2, 0).reshape(-1, N)
239
+
240
+ encoded['keypoint_labels'] = keypoint_labels
241
+ encoded['keypoint_labels_visible'] = keypoints_visible
242
+ encoded['lifting_target_label'] = lifting_target_label
243
+ encoded['lifting_target_weight'] = lifting_target_weight
244
+ encoded['trajectory_weights'] = trajectory_weights
245
+ encoded['target_root'] = root
246
+
247
+ return encoded
248
+
249
+ def decode(self,
250
+ encoded: np.ndarray,
251
+ target_root: Optional[np.ndarray] = None
252
+ ) -> Tuple[np.ndarray, np.ndarray]:
253
+ """Decode keypoint coordinates from normalized space to input image
254
+ space.
255
+
256
+ Args:
257
+ encoded (np.ndarray): Coordinates in shape (N, K, C).
258
+ target_root (np.ndarray, optional): The target root coordinate.
259
+ Default: ``None``.
260
+
261
+ Returns:
262
+ keypoints (np.ndarray): Decoded coordinates in shape (N, K, C).
263
+ scores (np.ndarray): The keypoint scores in shape (N, K).
264
+ """
265
+ keypoints = encoded.copy()
266
+
267
+ if self.target_mean is not None and self.target_std is not None:
268
+ assert self.target_mean.shape == keypoints.shape, (
269
+ f'self.target_mean.shape {self.target_mean.shape} '
270
+ f'!= keypoints.shape {keypoints.shape}')
271
+ keypoints = keypoints * self.target_std + self.target_mean
272
+
273
+ if target_root is not None and target_root.size > 0:
274
+ keypoints = keypoints + target_root
275
+ if self.remove_root and len(self.root_index) == 1:
276
+ keypoints = np.insert(
277
+ keypoints, self.root_index, target_root, axis=1)
278
+ scores = np.ones(keypoints.shape[:-1], dtype=np.float32)
279
+
280
+ return keypoints, scores
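A single-frame usage sketch for ImagePoseLifting, not part of the commit; the exported name is assumed to follow upstream MMPose and all values are random:

import numpy as np

from mmpose.codecs import ImagePoseLifting

codec = ImagePoseLifting(num_keypoints=17, root_index=0)

keypoints = np.random.rand(1, 17, 2) * 256      # one frame of 2D keypoints
lifting_target = np.random.rand(1, 17, 3)       # hypothetical 3D target pose

encoded = codec.encode(keypoints, lifting_target=lifting_target)
print(encoded['keypoint_labels'].shape)          # (34, 1): reshaped to (K*D, N)
print(encoded['lifting_target_label'].shape)     # (1, 17, 3), centred on the root joint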
mmpose/codecs/integral_regression_label.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from mmpose.registry import KEYPOINT_CODECS
8
+ from .base import BaseKeypointCodec
9
+ from .msra_heatmap import MSRAHeatmap
10
+ from .regression_label import RegressionLabel
11
+
12
+
13
+ @KEYPOINT_CODECS.register_module()
14
+ class IntegralRegressionLabel(BaseKeypointCodec):
15
+ """Generate keypoint coordinates and normalized heatmaps. See the paper:
16
+ `DSNT`_ by Nibali et al(2018).
17
+
18
+ Note:
19
+
20
+ - instance number: N
21
+ - keypoint number: K
22
+ - keypoint dimension: D
23
+ - image size: [w, h]
24
+
25
+ Encoded:
26
+
27
+ - keypoint_labels (np.ndarray): The normalized regression labels in
28
+ shape (N, K, D) where D is 2 for 2d coordinates
29
+ - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W) where
30
+ [W, H] is the `heatmap_size`
31
+ - keypoint_weights (np.ndarray): The target weights in shape (N, K)
32
+
33
+ Args:
34
+ input_size (tuple): Input image size in [w, h]
35
+ heatmap_size (tuple): Heatmap size in [W, H]
36
+ sigma (float): The sigma value of the Gaussian heatmap
37
+ unbiased (bool): Whether to use the unbiased method (DarkPose) in ``'msra'``
38
+ encoding. See `Dark Pose`_ for details. Defaults to ``False``
39
+ blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
40
+ modulation in DarkPose. The kernel size and sigma should follow
41
+ the expirical formula :math:`sigma = 0.3*((ks-1)*0.5-1)+0.8`.
42
+ Defaults to 11
43
+ normalize (bool): Whether to normalize the heatmaps. Defaults to True.
44
+
45
+ .. _`DSNT`: https://arxiv.org/abs/1801.07372
46
+ """
47
+
48
+ label_mapping_table = dict(
49
+ keypoint_labels='keypoint_labels',
50
+ keypoint_weights='keypoint_weights',
51
+ )
52
+ field_mapping_table = dict(heatmaps='heatmaps', )
53
+
54
+ def __init__(self,
55
+ input_size: Tuple[int, int],
56
+ heatmap_size: Tuple[int, int],
57
+ sigma: float,
58
+ unbiased: bool = False,
59
+ blur_kernel_size: int = 11,
60
+ normalize: bool = True) -> None:
61
+ super().__init__()
62
+
63
+ self.heatmap_codec = MSRAHeatmap(input_size, heatmap_size, sigma,
64
+ unbiased, blur_kernel_size)
65
+ self.keypoint_codec = RegressionLabel(input_size)
66
+ self.normalize = normalize
67
+
68
+ def encode(self,
69
+ keypoints: np.ndarray,
70
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
71
+ """Encoding keypoints to regression labels and heatmaps.
72
+
73
+ Args:
74
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
75
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
76
+ (N, K)
77
+
78
+ Returns:
79
+ dict:
80
+ - keypoint_labels (np.ndarray): The normalized regression labels in
81
+ shape (N, K, D) where D is 2 for 2d coordinates
82
+ - heatmaps (np.ndarray): The generated heatmap in shape
83
+ (K, H, W) where [W, H] is the `heatmap_size`
84
+ - keypoint_weights (np.ndarray): The target weights in shape
85
+ (N, K)
86
+ """
87
+ encoded_hm = self.heatmap_codec.encode(keypoints, keypoints_visible)
88
+ encoded_kp = self.keypoint_codec.encode(keypoints, keypoints_visible)
89
+
90
+ heatmaps = encoded_hm['heatmaps']
91
+ keypoint_labels = encoded_kp['keypoint_labels']
92
+ keypoint_weights = encoded_kp['keypoint_weights']
93
+
94
+ if self.normalize:
95
+ val_sum = heatmaps.sum(axis=(-1, -2)).reshape(-1, 1, 1) + 1e-24
96
+ heatmaps = heatmaps / val_sum
97
+
98
+ encoded = dict(
99
+ keypoint_labels=keypoint_labels,
100
+ heatmaps=heatmaps,
101
+ keypoint_weights=keypoint_weights)
102
+
103
+ return encoded
104
+
105
+ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
106
+ """Decode keypoint coordinates from normalized space to input image
107
+ space.
108
+
109
+ Args:
110
+ encoded (np.ndarray): Coordinates in shape (N, K, D)
111
+
112
+ Returns:
113
+ tuple:
114
+ - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D)
115
+ - scores (np.ndarray): The keypoint scores in shape (N, K).
116
+ It usually represents the confidence of the keypoint prediction
117
+ """
118
+
119
+ keypoints, scores = self.keypoint_codec.decode(encoded)
120
+
121
+ return keypoints, scores
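A short sketch, not part of the commit, showing the DSNT-style heatmap normalisation performed by IntegralRegressionLabel; the exported name is assumed to follow upstream MMPose:

import numpy as np

from mmpose.codecs import IntegralRegressionLabel

codec = IntegralRegressionLabel(
    input_size=(192, 256), heatmap_size=(48, 64), sigma=2.0, normalize=True)

keypoints = np.random.rand(1, 17, 2) * [192, 256]    # single instance, input-image pixels
encoded = codec.encode(keypoints)

# With normalize=True every heatmap channel sums to ~1, as DSNT expects.
print(encoded['heatmaps'].sum(axis=(-1, -2)).round(3))
print(encoded['keypoint_labels'].shape)              # (1, 17, 2), normalized to [0, 1]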
mmpose/codecs/megvii_heatmap.py ADDED
@@ -0,0 +1,147 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from itertools import product
3
+ from typing import Optional, Tuple
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+ from mmpose.registry import KEYPOINT_CODECS
9
+ from .base import BaseKeypointCodec
10
+ from .utils import gaussian_blur, get_heatmap_maximum
11
+
12
+
13
+ @KEYPOINT_CODECS.register_module()
14
+ class MegviiHeatmap(BaseKeypointCodec):
15
+ """Represent keypoints as heatmaps via "Megvii" approach. See `MSPN`_
16
+ (2019) and `CPN`_ (2018) for details.
17
+
18
+ Note:
19
+
20
+ - instance number: N
21
+ - keypoint number: K
22
+ - keypoint dimension: D
23
+ - image size: [w, h]
24
+ - heatmap size: [W, H]
25
+
26
+ Encoded:
27
+
28
+ - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W)
29
+ where [W, H] is the `heatmap_size`
30
+ - keypoint_weights (np.ndarray): The target weights in shape (N, K)
31
+
32
+ Args:
33
+ input_size (tuple): Image size in [w, h]
34
+ heatmap_size (tuple): Heatmap size in [W, H]
35
+ kernel_size (int): The kernel size of the Gaussian blur applied to
36
+ the heatmap
37
+
38
+ .. _`MSPN`: https://arxiv.org/abs/1901.00148
39
+ .. _`CPN`: https://arxiv.org/abs/1711.07319
40
+ """
41
+
42
+ label_mapping_table = dict(keypoint_weights='keypoint_weights', )
43
+ field_mapping_table = dict(heatmaps='heatmaps', )
44
+
45
+ def __init__(
46
+ self,
47
+ input_size: Tuple[int, int],
48
+ heatmap_size: Tuple[int, int],
49
+ kernel_size: int,
50
+ ) -> None:
51
+
52
+ super().__init__()
53
+ self.input_size = input_size
54
+ self.heatmap_size = heatmap_size
55
+ self.kernel_size = kernel_size
56
+ self.scale_factor = (np.array(input_size) /
57
+ heatmap_size).astype(np.float32)
58
+
59
+ def encode(self,
60
+ keypoints: np.ndarray,
61
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
62
+ """Encode keypoints into heatmaps. Note that the original keypoint
63
+ coordinates should be in the input image space.
64
+
65
+ Args:
66
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
67
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
68
+ (N, K)
69
+
70
+ Returns:
71
+ dict:
72
+ - heatmaps (np.ndarray): The generated heatmap in shape
73
+ (K, H, W) where [W, H] is the `heatmap_size`
74
+ - keypoint_weights (np.ndarray): The target weights in shape
75
+ (N, K)
76
+ """
77
+
78
+ N, K, _ = keypoints.shape
79
+ W, H = self.heatmap_size
80
+
81
+ assert N == 1, (
82
+ f'{self.__class__.__name__} only support single-instance '
83
+ 'keypoint encoding')
84
+
85
+ heatmaps = np.zeros((K, H, W), dtype=np.float32)
86
+ keypoint_weights = keypoints_visible.copy()
87
+
88
+ for n, k in product(range(N), range(K)):
89
+ # skip unlabeled keypoints
90
+ if keypoints_visible[n, k] < 0.5:
91
+ continue
92
+
93
+ # get center coordinates
94
+ kx, ky = (keypoints[n, k] / self.scale_factor).astype(np.int64)
95
+ if kx < 0 or kx >= W or ky < 0 or ky >= H:
96
+ keypoint_weights[n, k] = 0
97
+ continue
98
+
99
+ heatmaps[k, ky, kx] = 1.
100
+ kernel_size = (self.kernel_size, self.kernel_size)
101
+ heatmaps[k] = cv2.GaussianBlur(heatmaps[k], kernel_size, 0)
102
+
103
+ # normalize the heatmap
104
+ heatmaps[k] = heatmaps[k] / heatmaps[k, ky, kx] * 255.
105
+
106
+ encoded = dict(heatmaps=heatmaps, keypoint_weights=keypoint_weights)
107
+
108
+ return encoded
109
+
110
+ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
111
+ """Decode keypoint coordinates from heatmaps. The decoded keypoint
112
+ coordinates are in the input image space.
113
+
114
+ Args:
115
+ encoded (np.ndarray): Heatmaps in shape (K, H, W)
116
+
117
+ Returns:
118
+ tuple:
119
+ - keypoints (np.ndarray): Decoded keypoint coordinates in shape
120
+ (K, D)
121
+ - scores (np.ndarray): The keypoint scores in shape (K,). It
122
+ usually represents the confidence of the keypoint prediction
123
+ """
124
+ heatmaps = gaussian_blur(encoded.copy(), self.kernel_size)
125
+ K, H, W = heatmaps.shape
126
+
127
+ keypoints, scores = get_heatmap_maximum(heatmaps)
128
+
129
+ for k in range(K):
130
+ heatmap = heatmaps[k]
131
+ px = int(keypoints[k, 0])
132
+ py = int(keypoints[k, 1])
133
+ if 1 < px < W - 1 and 1 < py < H - 1:
134
+ diff = np.array([
135
+ heatmap[py][px + 1] - heatmap[py][px - 1],
136
+ heatmap[py + 1][px] - heatmap[py - 1][px]
137
+ ])
138
+ keypoints[k] += (np.sign(diff) * 0.25 + 0.5)
139
+
140
+ scores = scores / 255.0 + 0.5
141
+
142
+ # Unsqueeze the instance dimension for single-instance results
143
+ # and restore the keypoint scales
144
+ keypoints = keypoints[None] * self.scale_factor
145
+ scores = scores[None]
146
+
147
+ return keypoints, scores
mmpose/codecs/motionbert_label.py ADDED
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+
3
+ from copy import deepcopy
4
+ from typing import Optional, Tuple
5
+
6
+ import numpy as np
7
+
8
+ from mmpose.registry import KEYPOINT_CODECS
9
+ from .base import BaseKeypointCodec
10
+ from .utils import camera_to_image_coord
11
+
12
+
13
+ @KEYPOINT_CODECS.register_module()
14
+ class MotionBERTLabel(BaseKeypointCodec):
15
+ r"""Generate keypoint and label coordinates for `MotionBERT`_ by Zhu et al
16
+ (2022).
17
+
18
+ Note:
19
+
20
+ - instance number: N
21
+ - keypoint number: K
22
+ - keypoint dimension: D
23
+ - pose-lifitng target dimension: C
24
+
25
+ Args:
26
+ num_keypoints (int): The number of keypoints in the dataset.
27
+ root_index (int): Root keypoint index in the pose. Default: 0.
28
+ remove_root (bool): If true, remove the root keypoint from the pose.
29
+ Default: ``False``.
30
+ save_index (bool): If true, store the root position separated from the
31
+ original pose, only takes effect if ``remove_root`` is ``True``.
32
+ Default: ``False``.
33
+ concat_vis (bool): If true, concat the visibility item of keypoints.
34
+ Default: ``False``.
35
+ rootrel (bool): If true, the root keypoint will be set to the
36
+ coordinate origin. Default: ``False``.
37
+ mode (str): Indicating whether the current mode is 'train' or 'test'.
38
+ Default: ``'test'``.
39
+ """
40
+
41
+ auxiliary_encode_keys = {
42
+ 'lifting_target', 'lifting_target_visible', 'camera_param', 'factor'
43
+ }
44
+
45
+ instance_mapping_table = dict(
46
+ lifting_target='lifting_target',
47
+ lifting_target_visible='lifting_target_visible',
48
+ )
49
+ label_mapping_table = dict(
50
+ trajectory_weights='trajectory_weights',
51
+ lifting_target_label='lifting_target_label',
52
+ lifting_target_weight='lifting_target_weight')
53
+
54
+ def __init__(self,
55
+ num_keypoints: int,
56
+ root_index: int = 0,
57
+ remove_root: bool = False,
58
+ save_index: bool = False,
59
+ concat_vis: bool = False,
60
+ rootrel: bool = False,
61
+ mode: str = 'test'):
62
+ super().__init__()
63
+
64
+ self.num_keypoints = num_keypoints
65
+ self.root_index = root_index
66
+ self.remove_root = remove_root
67
+ self.save_index = save_index
68
+ self.concat_vis = concat_vis
69
+ self.rootrel = rootrel
70
+ assert mode.lower() in {'train', 'test'
71
+ }, (f'Unsupported mode {mode}, '
72
+ 'mode should be one of ("train", "test").')
73
+ self.mode = mode.lower()
74
+
75
+ def encode(self,
76
+ keypoints: np.ndarray,
77
+ keypoints_visible: Optional[np.ndarray] = None,
78
+ lifting_target: Optional[np.ndarray] = None,
79
+ lifting_target_visible: Optional[np.ndarray] = None,
80
+ camera_param: Optional[dict] = None,
81
+ factor: Optional[np.ndarray] = None) -> dict:
82
+ """Encoding keypoints from input image space to normalized space.
83
+
84
+ Args:
85
+ keypoints (np.ndarray): Keypoint coordinates in shape (B, T, K, D).
86
+ keypoints_visible (np.ndarray, optional): Keypoint visibilities in
87
+ shape (B, T, K).
88
+ lifting_target (np.ndarray, optional): 3d target coordinate in
89
+ shape (T, K, C).
90
+ lifting_target_visible (np.ndarray, optional): Target coordinate in
91
+ shape (T, K, ).
92
+ camera_param (dict, optional): The camera parameter dictionary.
93
+ factor (np.ndarray, optional): The factor mapping camera and image
94
+ coordinate in shape (T, ).
95
+
96
+ Returns:
97
+ encoded (dict): Contains the following items:
98
+
99
+ - keypoint_labels (np.ndarray): The processed keypoints in
100
+ shape like (N, K, D).
101
+ - keypoint_labels_visible (np.ndarray): The processed
102
+ keypoints' weights in shape (N, K, ) or (N, K-1, ).
103
+ - lifting_target_label: The processed target coordinate in
104
+ shape (K, C) or (K-1, C).
105
+ - lifting_target_weight (np.ndarray): The target weights in
106
+ shape (K, ) or (K-1, ).
107
+ - factor (np.ndarray): The factor mapping camera and image
108
+ coordinate in shape (T, 1).
109
+ """
110
+ if keypoints_visible is None:
111
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
112
+
113
+ # set initial value for `lifting_target_weight`
114
+ if lifting_target_visible is None:
115
+ lifting_target_visible = np.ones(
116
+ lifting_target.shape[:-1], dtype=np.float32)
117
+ lifting_target_weight = lifting_target_visible
118
+ else:
119
+ valid = lifting_target_visible > 0.5
120
+ lifting_target_weight = np.where(valid, 1., 0.).astype(np.float32)
121
+
122
+ if camera_param is None:
123
+ camera_param = dict()
124
+
125
+ encoded = dict()
126
+
127
+ assert lifting_target is not None
128
+ lifting_target_label = lifting_target.copy()
129
+ keypoint_labels = keypoints.copy()
130
+
131
+ assert keypoint_labels.ndim in {
132
+ 2, 3
133
+ }, (f'Keypoint labels should have 2 or 3 dimensions, '
134
+ f'but got {keypoint_labels.ndim}.')
135
+ if keypoint_labels.ndim == 2:
136
+ keypoint_labels = keypoint_labels[None, ...]
137
+
138
+ # Normalize the 2D keypoint coordinate with image width and height
139
+ _camera_param = deepcopy(camera_param)
140
+ assert 'w' in _camera_param and 'h' in _camera_param, (
141
+ 'Camera parameters should contain "w" and "h".')
142
+ w, h = _camera_param['w'], _camera_param['h']
143
+ keypoint_labels[
144
+ ..., :2] = keypoint_labels[..., :2] / w * 2 - [1, h / w]
145
+
146
+ # convert target to image coordinate
147
+ T = keypoint_labels.shape[0]
148
+ factor_ = np.array([4] * T, dtype=np.float32).reshape(T, )
149
+ if 'f' in _camera_param and 'c' in _camera_param:
150
+ lifting_target_label, factor_ = camera_to_image_coord(
151
+ self.root_index, lifting_target_label, _camera_param)
152
+ if self.mode == 'train':
153
+ w, h = w / 1000, h / 1000
154
+ lifting_target_label[
155
+ ..., :2] = lifting_target_label[..., :2] / w * 2 - [1, h / w]
156
+ lifting_target_label[..., 2] = lifting_target_label[..., 2] / w * 2
157
+ lifting_target_label[..., :, :] = lifting_target_label[
158
+ ..., :, :] - lifting_target_label[...,
159
+ self.root_index:self.root_index +
160
+ 1, :]
161
+ if factor is None or factor[0] == 0:
162
+ factor = factor_
163
+ if factor.ndim == 1:
164
+ factor = factor[:, None]
165
+ if self.mode == 'test':
166
+ lifting_target_label *= factor[..., None]
167
+
168
+ if self.concat_vis:
169
+ keypoints_visible_ = keypoints_visible
170
+ if keypoints_visible.ndim == 2:
171
+ keypoints_visible_ = keypoints_visible[..., None]
172
+ keypoint_labels = np.concatenate(
173
+ (keypoint_labels, keypoints_visible_), axis=2)
174
+
175
+ encoded['keypoint_labels'] = keypoint_labels
176
+ encoded['keypoint_labels_visible'] = keypoints_visible
177
+ encoded['lifting_target_label'] = lifting_target_label
178
+ encoded['lifting_target_weight'] = lifting_target_weight
179
+ encoded['lifting_target'] = lifting_target_label
180
+ encoded['lifting_target_visible'] = lifting_target_visible
181
+ encoded['factor'] = factor
182
+
183
+ return encoded
184
+
185
+ def decode(
186
+ self,
187
+ encoded: np.ndarray,
188
+ w: Optional[np.ndarray] = None,
189
+ h: Optional[np.ndarray] = None,
190
+ factor: Optional[np.ndarray] = None,
191
+ ) -> Tuple[np.ndarray, np.ndarray]:
192
+ """Decode keypoint coordinates from normalized space to input image
193
+ space.
194
+
195
+ Args:
196
+ encoded (np.ndarray): Coordinates in shape (N, K, C).
197
+ w (np.ndarray, optional): The image widths in shape (N, ).
198
+ Default: ``None``.
199
+ h (np.ndarray, optional): The image heights in shape (N, ).
200
+ Default: ``None``.
201
+ factor (np.ndarray, optional): The factor for projection in shape
202
+ (N, ). Default: ``None``.
203
+
204
+ Returns:
205
+ keypoints (np.ndarray): Decoded coordinates in shape (N, K, C).
206
+ scores (np.ndarray): The keypoint scores in shape (N, K).
207
+ """
208
+ keypoints = encoded.copy()
209
+ scores = np.ones(keypoints.shape[:-1], dtype=np.float32)
210
+
211
+ if self.rootrel:
212
+ keypoints[..., 0, :] = 0
213
+
214
+ if w is not None and w.size > 0:
215
+ assert w.shape == h.shape, (f'w and h should have the same shape, '
216
+ f'but got {w.shape} and {h.shape}.')
217
+ assert w.shape[0] == keypoints.shape[0], (
218
+ f'w and h should have the same batch size, '
219
+ f'but got {w.shape[0]} and {keypoints.shape[0]}.')
220
+ assert w.ndim in {1,
221
+ 2}, (f'w and h should have 1 or 2 dimensions, '
222
+ f'but got {w.ndim}.')
223
+ if w.ndim == 1:
224
+ w = w[:, None]
225
+ h = h[:, None]
226
+ trans = np.append(
227
+ np.ones((w.shape[0], 1)), h / w, axis=1)[:, None, :]
228
+ keypoints[..., :2] = (keypoints[..., :2] + trans) * w[:, None] / 2
229
+ keypoints[..., 2:] = keypoints[..., 2:] * w[:, None] / 2
230
+
231
+ if factor is not None and factor.size > 0:
232
+ assert factor.shape[0] == keypoints.shape[0], (
233
+ f'factor should have the same batch size, '
234
+ f'but got {factor.shape[0]} and {keypoints.shape[0]}.')
235
+ keypoints *= factor[..., None]
236
+
237
+ keypoints[..., :, :] = keypoints[..., :, :] - keypoints[
238
+ ..., self.root_index:self.root_index + 1, :]
239
+ keypoints /= 1000.
240
+ return keypoints, scores
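A decoding sketch for MotionBERTLabel, not part of the commit; the width, height and factor arrays are hypothetical and the exported name is assumed to follow upstream MMPose:

import numpy as np

from mmpose.codecs import MotionBERTLabel

codec = MotionBERTLabel(num_keypoints=17)

pred = np.random.rand(1, 17, 3)       # normalized network output in shape (N, K, C)
w = np.array([1920.])                 # image width per sample
h = np.array([1080.])                 # image height per sample
factor = np.array([3.5])              # hypothetical camera-to-image scale factor

keypoints, scores = codec.decode(pred, w=w, h=h, factor=factor)
print(keypoints.shape, scores.shape)  # (1, 17, 3) (1, 17); root-relative, divided by 1000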
mmpose/codecs/msra_heatmap.py ADDED
@@ -0,0 +1,153 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Optional, Tuple
3
+
4
+ import numpy as np
5
+
6
+ from mmpose.registry import KEYPOINT_CODECS
7
+ from .base import BaseKeypointCodec
8
+ from .utils.gaussian_heatmap import (generate_gaussian_heatmaps,
9
+ generate_unbiased_gaussian_heatmaps)
10
+ from .utils.post_processing import get_heatmap_maximum
11
+ from .utils.refinement import refine_keypoints, refine_keypoints_dark
12
+
13
+
14
+ @KEYPOINT_CODECS.register_module()
15
+ class MSRAHeatmap(BaseKeypointCodec):
16
+ """Represent keypoints as heatmaps via "MSRA" approach. See the paper:
17
+ `Simple Baselines for Human Pose Estimation and Tracking`_ by Xiao et al
18
+ (2018) for details.
19
+
20
+ Note:
21
+
22
+ - instance number: N
23
+ - keypoint number: K
24
+ - keypoint dimension: D
25
+ - image size: [w, h]
26
+ - heatmap size: [W, H]
27
+
28
+ Encoded:
29
+
30
+ - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W)
31
+ where [W, H] is the `heatmap_size`
32
+ - keypoint_weights (np.ndarray): The target weights in shape (N, K)
33
+
34
+ Args:
35
+ input_size (tuple): Image size in [w, h]
36
+ heatmap_size (tuple): Heatmap size in [W, H]
37
+ sigma (float): The sigma value of the Gaussian heatmap
38
+ unbiased (bool): Whether to use the unbiased method (DarkPose) in ``'msra'``
39
+ encoding. See `Dark Pose`_ for details. Defaults to ``False``
40
+ blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
41
+ modulation in DarkPose. The kernel size and sigma should follow
42
+ the empirical formula :math:`sigma = 0.3*((ks-1)*0.5-1)+0.8`.
43
+ Defaults to 11
44
+
45
+ .. _`Simple Baselines for Human Pose Estimation and Tracking`:
46
+ https://arxiv.org/abs/1804.06208
47
+ .. _`Dark Pose`: https://arxiv.org/abs/1910.06278
48
+ """
49
+
50
+ label_mapping_table = dict(keypoint_weights='keypoint_weights', )
51
+ field_mapping_table = dict(heatmaps='heatmaps', )
52
+
53
+ def __init__(self,
54
+ input_size: Tuple[int, int],
55
+ heatmap_size: Tuple[int, int],
56
+ sigma: float,
57
+ unbiased: bool = False,
58
+ blur_kernel_size: int = 11) -> None:
59
+ super().__init__()
60
+ self.input_size = input_size
61
+ self.heatmap_size = heatmap_size
62
+ self.sigma = sigma
63
+ self.unbiased = unbiased
64
+
65
+ # The Gaussian blur kernel size of the heatmap modulation
66
+ # in DarkPose and the sigma value follows the empirical
67
+ # formula :math:`sigma = 0.3*((ks-1)*0.5-1)+0.8`
68
+ # which gives:
69
+ # sigma~=3 if ks=17
70
+ # sigma=2 if ks=11;
71
+ # sigma~=1.5 if ks=7;
72
+ # sigma~=1 if ks=3;
73
+ self.blur_kernel_size = blur_kernel_size
74
+ self.scale_factor = (np.array(input_size) /
75
+ heatmap_size).astype(np.float32)
76
+
77
+ def encode(self,
78
+ keypoints: np.ndarray,
79
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
80
+ """Encode keypoints into heatmaps. Note that the original keypoint
81
+ coordinates should be in the input image space.
82
+
83
+ Args:
84
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
85
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
86
+ (N, K)
87
+
88
+ Returns:
89
+ dict:
90
+ - heatmaps (np.ndarray): The generated heatmap in shape
91
+ (K, H, W) where [W, H] is the `heatmap_size`
92
+ - keypoint_weights (np.ndarray): The target weights in shape
93
+ (N, K)
94
+ """
95
+
96
+ assert keypoints.shape[0] == 1, (
97
+ f'{self.__class__.__name__} only support single-instance '
98
+ 'keypoint encoding')
99
+
100
+ if keypoints_visible is None:
101
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
102
+
103
+ if self.unbiased:
104
+ heatmaps, keypoint_weights = generate_unbiased_gaussian_heatmaps(
105
+ heatmap_size=self.heatmap_size,
106
+ keypoints=keypoints / self.scale_factor,
107
+ keypoints_visible=keypoints_visible,
108
+ sigma=self.sigma)
109
+ else:
110
+ heatmaps, keypoint_weights = generate_gaussian_heatmaps(
111
+ heatmap_size=self.heatmap_size,
112
+ keypoints=keypoints / self.scale_factor,
113
+ keypoints_visible=keypoints_visible,
114
+ sigma=self.sigma)
115
+
116
+ encoded = dict(heatmaps=heatmaps, keypoint_weights=keypoint_weights)
117
+
118
+ return encoded
119
+
120
+ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
121
+ """Decode keypoint coordinates from heatmaps. The decoded keypoint
122
+ coordinates are in the input image space.
123
+
124
+ Args:
125
+ encoded (np.ndarray): Heatmaps in shape (K, H, W)
126
+
127
+ Returns:
128
+ tuple:
129
+ - keypoints (np.ndarray): Decoded keypoint coordinates in shape
130
+ (N, K, D)
131
+ - scores (np.ndarray): The keypoint scores in shape (N, K). It
132
+ usually represents the confidence of the keypoint prediction
133
+ """
134
+ heatmaps = encoded.copy()
135
+ K, H, W = heatmaps.shape
136
+
137
+ keypoints, scores = get_heatmap_maximum(heatmaps)
138
+
139
+ # Unsqueeze the instance dimension for single-instance results
140
+ keypoints, scores = keypoints[None], scores[None]
141
+
142
+ if self.unbiased:
143
+ # Alleviate biased coordinate
144
+ keypoints = refine_keypoints_dark(
145
+ keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size)
146
+
147
+ else:
148
+ keypoints = refine_keypoints(keypoints, heatmaps)
149
+
150
+ # Restore the keypoint scale
151
+ keypoints = keypoints * self.scale_factor
152
+
153
+ return keypoints, scores
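A minimal encode/decode round trip for MSRAHeatmap as defined above; the sizes and the COCO-style K=17 are hypothetical, and the import uses the module path added in this commit:

import numpy as np
from mmpose.codecs.msra_heatmap import MSRAHeatmap

codec = MSRAHeatmap(input_size=(192, 256), heatmap_size=(48, 64), sigma=2.0)
keypoints = np.random.rand(1, 17, 2).astype(np.float32) * (192, 256)  # single instance in input space
visible = np.ones((1, 17), dtype=np.float32)

encoded = codec.encode(keypoints, visible)                  # encoded['heatmaps'] has shape (17, 64, 48)
decoded, scores = codec.decode(encoded['heatmaps'])         # keypoints back in the 192x256 input space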
mmpose/codecs/onehot_heatmap.py ADDED
@@ -0,0 +1,263 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Optional, Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from mmpose.registry import KEYPOINT_CODECS
8
+ from .base import BaseKeypointCodec
9
+ from .utils import (generate_offset_heatmap, generate_onehot_heatmaps,
10
+ get_heatmap_maximum, refine_keypoints_dark_udp)
11
+
12
+
13
+ @KEYPOINT_CODECS.register_module()
14
+ class OneHotHeatmap(BaseKeypointCodec):
15
+ r"""Generate one-hot keypoint heatmaps with Unbiased Data Processing (UDP).
16
+ See the paper: `The Devil is in the Details: Delving into Unbiased Data
17
+ Processing for Human Pose Estimation`_ by Huang et al (2020) for details.
18
+
19
+ Note:
20
+
21
+ - instance number: N
22
+ - keypoint number: K
23
+ - keypoint dimension: D
24
+ - image size: [w, h]
25
+ - heatmap size: [W, H]
26
+
27
+ Encoded:
28
+
29
+ - heatmap (np.ndarray): The generated heatmap in shape (C_out, H, W)
30
+ where [W, H] is the `heatmap_size`, and the C_out is the output
31
+ channel number which depends on the `heatmap_type`. If
32
+ `heatmap_type=='gaussian'`, C_out equals to keypoint number K;
33
+ if `heatmap_type=='combined'`, C_out equals to K*3
34
+ (x_offset, y_offset and class label)
35
+ - keypoint_weights (np.ndarray): The target weights in shape (K,)
36
+
37
+ Args:
38
+ input_size (tuple): Image size in [w, h]
39
+ heatmap_size (tuple): Heatmap size in [W, H]
40
+ heatmap_type (str): The heatmap type to encode the keypoints. Options
41
+ are:
42
+
43
+ - ``'gaussian'``: Gaussian heatmap
44
+ - ``'combined'``: Combination of a binary label map and offset
45
+ maps for X and Y axes.
46
+
47
+ sigma (float): The sigma value of the Gaussian heatmap when
48
+ ``heatmap_type=='gaussian'``. Defaults to 2.0
49
+ radius_factor (float): The radius factor of the binary label
50
+ map when ``heatmap_type=='combined'``. The positive region is
51
+ defined as the neighbor of the keypoint with the radius
52
+ :math:`r=radius_factor*max(W, H)`. Defaults to 0.0546875
53
+ blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
54
+ modulation in DarkPose. Defaults to 11
55
+
56
+ .. _`The Devil is in the Details: Delving into Unbiased Data Processing for
57
+ Human Pose Estimation`: https://arxiv.org/abs/1911.07524
58
+ """
59
+
60
+ label_mapping_table = dict(keypoint_weights='keypoint_weights', )
61
+ field_mapping_table = dict(heatmaps='heatmaps', )
62
+
63
+ def __init__(self,
64
+ input_size: Tuple[int, int],
65
+ heatmap_size: Tuple[int, int],
66
+ heatmap_type: str = 'gaussian',
67
+ sigma: float = 2.,
68
+ radius_factor: float = 0.0546875,
69
+ blur_kernel_size: int = 11,
70
+ increase_sigma_with_padding=False,
71
+ amap_scale: float = 1.0,
72
+ normalize=None,
73
+ ) -> None:
74
+ super().__init__()
75
+ self.input_size = np.array(input_size)
76
+ self.heatmap_size = np.array(heatmap_size)
77
+ self.sigma = sigma
78
+ self.radius_factor = radius_factor
79
+ self.heatmap_type = heatmap_type
80
+ self.blur_kernel_size = blur_kernel_size
81
+ self.increase_sigma_with_padding = increase_sigma_with_padding
82
+ self.normalize = normalize
83
+
84
+ self.amap_size = self.input_size * amap_scale
85
+ self.scale_factor = ((self.amap_size - 1) /
86
+ (self.heatmap_size - 1)).astype(np.float32)
87
+ self.input_center = self.input_size / 2
88
+ self.top_left = self.input_center - self.amap_size / 2
89
+
90
+ if self.heatmap_type not in {'gaussian', 'combined'}:
91
+ raise ValueError(
92
+ f'{self.__class__.__name__} got invalid `heatmap_type` value'
93
+ f'{self.heatmap_type}. Should be one of '
94
+ '{"gaussian", "combined"}')
95
+
96
+ def _kpts_to_activation_pts(self, keypoints: np.ndarray) -> np.ndarray:
97
+ """
98
+ Transform the keypoint coordinates to the activation space.
99
+ In the original UDPHeatmap, the activation map covers the same space as the input
100
+ image, only at a lower resolution. Here the activation map may instead have a
101
+ different (padded) size than the input image space.
102
+ Centers of the activation map and the input image space are aligned.
103
+ """
104
+ transformed_keypoints = keypoints - self.top_left
105
+ transformed_keypoints = transformed_keypoints / self.scale_factor
106
+ return transformed_keypoints
107
+
108
+ def _activation_pts_to_kpts(self, keypoints: np.ndarray) -> np.ndarray:
109
+ """
110
+ Transform the points in activation map to the keypoint coordinates.
111
+ In the original UDPHeatmap, the activation map covers the same space as the input
112
+ image, only at a lower resolution. Here the activation map may instead have a
113
+ different (padded) size than the input image space.
114
+ Centers of the activation map and the input image space are aligned.
115
+ """
116
+ W, H = self.heatmap_size
117
+ transformed_keypoints = keypoints / [W - 1, H - 1] * self.amap_size
118
+ transformed_keypoints += self.top_left
119
+ return transformed_keypoints
120
+
121
+ def encode(self,
122
+ keypoints: np.ndarray,
123
+ keypoints_visible: Optional[np.ndarray] = None,
124
+ id_similarity: Optional[float] = 0.0,
125
+ keypoints_visibility: Optional[np.ndarray] = None) -> dict:
126
+ """Encode keypoints into heatmaps. Note that the original keypoint
127
+ coordinates should be in the input image space.
128
+
129
+ Args:
130
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
131
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
132
+ (N, K)
133
+ id_similarity (float): The usefulness of the identity information
134
+ for the whole pose. Defaults to 0.0
135
+ keypoints_visibility (np.ndarray): The visibility bit for each
136
+ keypoint (N, K). Defaults to None
137
+
138
+ Returns:
139
+ dict:
140
+ - heatmap (np.ndarray): The generated heatmap in shape
141
+ (C_out, H, W) where [W, H] is the `heatmap_size`, and the
142
+ C_out is the output channel number which depends on the
143
+ `heatmap_type`. If `heatmap_type=='gaussian'`, C_out equals to
144
+ keypoint number K; if `heatmap_type=='combined'`, C_out
145
+ equals to K*3 (x_offset, y_offset and class label)
146
+ - keypoint_weights (np.ndarray): The target weights in shape
147
+ (K,)
148
+ """
149
+ assert keypoints.shape[0] == 1, (
150
+ f'{self.__class__.__name__} only support single-instance '
151
+ 'keypoint encoding')
152
+
153
+ if keypoints_visibility is None:
154
+ keypoints_visibility = np.zeros(keypoints.shape[:2], dtype=np.float32)
155
+
156
+ if keypoints_visible is None:
157
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
158
+
159
+ if self.heatmap_type == 'gaussian':
160
+ heatmaps, keypoint_weights = generate_onehot_heatmaps(
161
+ heatmap_size=self.heatmap_size,
162
+ keypoints=self._kpts_to_activation_pts(keypoints),
163
+ keypoints_visible=keypoints_visible,
164
+ sigma=self.sigma,
165
+ keypoints_visibility=keypoints_visibility,
166
+ increase_sigma_with_padding=self.increase_sigma_with_padding)
167
+ elif self.heatmap_type == 'combined':
168
+ heatmaps, keypoint_weights = generate_offset_heatmap(
169
+ heatmap_size=self.heatmap_size,
170
+ keypoints=self._kpts_to_activation_pts(keypoints),
171
+ keypoints_visible=keypoints_visible,
172
+ radius_factor=self.radius_factor)
173
+ else:
174
+ raise ValueError(
175
+ f'{self.__class__.__name__} got invalid `heatmap_type` value'
176
+ f'{self.heatmap_type}. Should be one of '
177
+ '{"gaussian", "combined"}')
178
+
179
+ if self.normalize is not None:
180
+ heatmaps_sum = np.sum(heatmaps, axis=(1, 2), keepdims=False)
181
+ mask = heatmaps_sum > 0
182
+ heatmaps[mask, :, :] = heatmaps[mask, :, :] / (heatmaps_sum[mask, None, None] + np.finfo(np.float32).eps)
183
+ heatmaps = heatmaps * self.normalize
184
+
185
+ annotated = keypoints_visible > 0
186
+
187
+ heatmap_keypoints = self._kpts_to_activation_pts(keypoints)
188
+ in_image = np.logical_and(
189
+ heatmap_keypoints[:, :, 0] >= 0,
190
+ heatmap_keypoints[:, :, 0] < self.heatmap_size[0],
191
+ )
192
+ in_image = np.logical_and(
193
+ in_image,
194
+ heatmap_keypoints[:, :, 1] >= 0,
195
+ )
196
+ in_image = np.logical_and(
197
+ in_image,
198
+ heatmap_keypoints[:, :, 1] < self.heatmap_size[1],
199
+ )
200
+
201
+ encoded = dict(
202
+ heatmaps=heatmaps,
203
+ keypoint_weights=keypoint_weights,
204
+ annotated=annotated,
205
+ in_image=in_image,
206
+ keypoints_scaled=keypoints,
207
+ heatmap_keypoints=heatmap_keypoints,
208
+ identification_similarity=id_similarity,
209
+ )
210
+
211
+ return encoded
212
+
213
+ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
214
+ """Decode keypoint coordinates from heatmaps. The decoded keypoint
215
+ coordinates are in the input image space.
216
+
217
+ Args:
218
+ encoded (np.ndarray): Heatmaps in shape (K, H, W)
219
+
220
+ Returns:
221
+ tuple:
222
+ - keypoints (np.ndarray): Decoded keypoint coordinates in shape
223
+ (N, K, D)
224
+ - scores (np.ndarray): The keypoint scores in shape (N, K). It
225
+ usually represents the confidence of the keypoint prediction
226
+ """
227
+ heatmaps = encoded.copy()
228
+
229
+ if self.heatmap_type == 'gaussian':
230
+ keypoints, scores = get_heatmap_maximum(heatmaps)
231
+ # unsqueeze the instance dimension for single-instance results
232
+ keypoints = keypoints[None]
233
+ scores = scores[None]
234
+
235
+ keypoints = refine_keypoints_dark_udp(
236
+ keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size)
237
+
238
+ elif self.heatmap_type == 'combined':
239
+ _K, H, W = heatmaps.shape
240
+ K = _K // 3
241
+
242
+ for cls_heatmap in heatmaps[::3]:
243
+ # Apply Gaussian blur on classification maps
244
+ ks = 2 * self.blur_kernel_size + 1
245
+ cv2.GaussianBlur(cls_heatmap, (ks, ks), 0, cls_heatmap)
246
+
247
+ # valid radius
248
+ radius = self.radius_factor * max(W, H)
249
+
250
+ x_offset = heatmaps[1::3].flatten() * radius
251
+ y_offset = heatmaps[2::3].flatten() * radius
252
+ keypoints, scores = get_heatmap_maximum(heatmaps=heatmaps[::3])
253
+ index = (keypoints[..., 0] + keypoints[..., 1] * W).flatten()
254
+ index += W * H * np.arange(0, K)
255
+ index = index.astype(int)
256
+ keypoints += np.stack((x_offset[index], y_offset[index]), axis=-1)
257
+ # unsqueeze the instance dimension for single-instance results
258
+ keypoints = keypoints[None].astype(np.float32)
259
+ scores = scores[None]
260
+
261
+ keypoints = self._activation_pts_to_kpts(keypoints)
262
+
263
+ return keypoints, scores
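A minimal sketch of the OneHotHeatmap codec above, exercising the padded activation map (amap_scale > 1); the sizes are hypothetical and the 'gaussian' branch relies on generate_onehot_heatmaps from this commit's codec utils:

import numpy as np
from mmpose.codecs.onehot_heatmap import OneHotHeatmap

codec = OneHotHeatmap(input_size=(192, 256), heatmap_size=(48, 64),
                      heatmap_type='gaussian', sigma=2.0, amap_scale=1.25)
keypoints = np.random.rand(1, 17, 2).astype(np.float32) * (192, 256)

encoded = codec.encode(keypoints)                           # visibility inputs default internally
decoded, scores = codec.decode(encoded['heatmaps'])         # mapped back through the padded activation map
print(encoded['in_image'].shape, decoded.shape)             # (1, 17) (1, 17, 2)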
mmpose/codecs/regression_label.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from mmpose.registry import KEYPOINT_CODECS
8
+ from .base import BaseKeypointCodec
9
+
10
+
11
+ @KEYPOINT_CODECS.register_module()
12
+ class RegressionLabel(BaseKeypointCodec):
13
+ r"""Generate keypoint coordinates.
14
+
15
+ Note:
16
+
17
+ - instance number: N
18
+ - keypoint number: K
19
+ - keypoint dimension: D
20
+ - image size: [w, h]
21
+
22
+ Encoded:
23
+
24
+ - keypoint_labels (np.ndarray): The normalized regression labels in
25
+ shape (N, K, D) where D is 2 for 2d coordinates
26
+ - keypoint_weights (np.ndarray): The target weights in shape (N, K)
27
+
28
+ Args:
29
+ input_size (tuple): Input image size in [w, h]
30
+
31
+ """
32
+
33
+ label_mapping_table = dict(
34
+ keypoint_labels='keypoint_labels',
35
+ keypoint_weights='keypoint_weights',
36
+ )
37
+
38
+ def __init__(self, input_size: Tuple[int, int]) -> None:
39
+ super().__init__()
40
+
41
+ self.input_size = input_size
42
+
43
+ def encode(self,
44
+ keypoints: np.ndarray,
45
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
46
+ """Encoding keypoints from input image space to normalized space.
47
+
48
+ Args:
49
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
50
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
51
+ (N, K)
52
+
53
+ Returns:
54
+ dict:
55
+ - keypoint_labels (np.ndarray): The normalized regression labels in
56
+ shape (N, K, D) where D is 2 for 2d coordinates
57
+ - keypoint_weights (np.ndarray): The target weights in shape
58
+ (N, K)
59
+ """
60
+ if keypoints_visible is None:
61
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
62
+
63
+ w, h = self.input_size
64
+ valid = ((keypoints >= 0) &
65
+ (keypoints <= [w - 1, h - 1])).all(axis=-1) & (
66
+ keypoints_visible > 0.5)
67
+
68
+ keypoint_labels = (keypoints / np.array([w, h])).astype(np.float32)
69
+ keypoint_weights = np.where(valid, 1., 0.).astype(np.float32)
70
+
71
+ encoded = dict(
72
+ keypoint_labels=keypoint_labels, keypoint_weights=keypoint_weights)
73
+
74
+ return encoded
75
+
76
+ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
77
+ """Decode keypoint coordinates from normalized space to input image
78
+ space.
79
+
80
+ Args:
81
+ encoded (np.ndarray): Coordinates in shape (N, K, D)
82
+
83
+ Returns:
84
+ tuple:
85
+ - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D)
86
+ - scores (np.ndarray): The keypoint scores in shape (N, K).
87
+ It usually represents the confidence of the keypoint prediction
88
+ """
89
+
90
+ if encoded.shape[-1] == 2:
91
+ N, K, _ = encoded.shape
92
+ normalized_coords = encoded.copy()
93
+ scores = np.ones((N, K), dtype=np.float32)
94
+ elif encoded.shape[-1] == 4:
95
+ # split coords and sigma if outputs contain output_sigma
96
+ normalized_coords = encoded[..., :2].copy()
97
+ output_sigma = encoded[..., 2:4].copy()
98
+
99
+ scores = (1 - output_sigma).mean(axis=-1)
100
+ else:
101
+ raise ValueError(
102
+ 'Keypoint dimension should be 2 or 4 (with sigma), '
103
+ f'but got {encoded.shape[-1]}')
104
+
105
+ w, h = self.input_size
106
+ keypoints = normalized_coords * np.array([w, h])
107
+
108
+ return keypoints, scores
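A small worked example of the RegressionLabel normalization round trip above (hypothetical 192x256 input, K=2 keypoints):

import numpy as np
from mmpose.codecs.regression_label import RegressionLabel

codec = RegressionLabel(input_size=(192, 256))
keypoints = np.array([[[96., 128.], [10., 20.]]], dtype=np.float32)   # (N=1, K=2, D=2)
visible = np.ones((1, 2), dtype=np.float32)

encoded = codec.encode(keypoints, visible)
print(encoded['keypoint_labels'][0, 0])                     # [0.5 0.5] -> the image centre
decoded, scores = codec.decode(encoded['keypoint_labels'])
print(decoded[0])                                           # recovers [[ 96. 128.] [ 10.  20.]]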
mmpose/codecs/simcc_label.py ADDED
@@ -0,0 +1,311 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from itertools import product
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+
7
+ from mmpose.codecs.utils import get_simcc_maximum
8
+ from mmpose.codecs.utils.refinement import refine_simcc_dark
9
+ from mmpose.registry import KEYPOINT_CODECS
10
+ from .base import BaseKeypointCodec
11
+
12
+
13
+ @KEYPOINT_CODECS.register_module()
14
+ class SimCCLabel(BaseKeypointCodec):
15
+ r"""Generate keypoint representation via "SimCC" approach.
16
+ See the paper: `SimCC: a Simple Coordinate Classification Perspective for
17
+ Human Pose Estimation`_ by Li et al (2022) for more details.
18
+ Old name: SimDR
19
+
20
+ Note:
21
+
22
+ - instance number: N
23
+ - keypoint number: K
24
+ - keypoint dimension: D
25
+ - image size: [w, h]
26
+
27
+ Encoded:
28
+
29
+ - keypoint_x_labels (np.ndarray): The generated SimCC label for x-axis.
30
+ The label shape is (N, K, Wx) for both ``'gaussian'`` and
31
+ ``'standard'`` smoothing types, where
32
+ :math:`Wx=w*simcc_split_ratio`
33
+ - keypoint_y_labels (np.ndarray): The generated SimCC label for y-axis.
34
+ The label shape is (N, K, Wy) for both ``'gaussian'`` and
35
+ ``'standard'`` smoothing types, where
36
+ :math:`Wy=h*simcc_split_ratio`
37
+ - keypoint_weights (np.ndarray): The target weights in shape (N, K)
38
+
39
+ Args:
40
+ input_size (tuple): Input image size in [w, h]
41
+ smoothing_type (str): The SimCC label smoothing strategy. Options are
42
+ ``'gaussian'`` and ``'standard'``. Defaults to ``'gaussian'``
43
+ sigma (float | int | tuple): The sigma value in the Gaussian SimCC
44
+ label. Defaults to 6.0
45
+ simcc_split_ratio (float): The ratio of the label size to the input
46
+ size. For example, if the input width is ``w``, the x label size
47
+ will be :math:`w*simcc_split_ratio`. Defaults to 2.0
48
+ label_smooth_weight (float): Label Smoothing weight. Defaults to 0.0
49
+ normalize (bool): Whether to normalize the heatmaps. Defaults to True.
50
+ use_dark (bool): Whether to use the DARK post processing. Defaults to
51
+ False.
52
+ decode_visibility (bool): Whether to decode the visibility. Defaults
53
+ to False.
54
+ decode_beta (float): The beta value for decoding visibility. Defaults
55
+ to 150.0.
56
+
57
+ .. _`SimCC: a Simple Coordinate Classification Perspective for Human Pose
58
+ Estimation`: https://arxiv.org/abs/2107.03332
59
+ """
60
+
61
+ label_mapping_table = dict(
62
+ keypoint_x_labels='keypoint_x_labels',
63
+ keypoint_y_labels='keypoint_y_labels',
64
+ keypoint_weights='keypoint_weights',
65
+ )
66
+
67
+ def __init__(
68
+ self,
69
+ input_size: Tuple[int, int],
70
+ smoothing_type: str = 'gaussian',
71
+ sigma: Union[float, int, Tuple[float]] = 6.0,
72
+ simcc_split_ratio: float = 2.0,
73
+ label_smooth_weight: float = 0.0,
74
+ normalize: bool = True,
75
+ use_dark: bool = False,
76
+ decode_visibility: bool = False,
77
+ decode_beta: float = 150.0,
78
+ ) -> None:
79
+ super().__init__()
80
+
81
+ self.input_size = input_size
82
+ self.smoothing_type = smoothing_type
83
+ self.simcc_split_ratio = simcc_split_ratio
84
+ self.label_smooth_weight = label_smooth_weight
85
+ self.normalize = normalize
86
+ self.use_dark = use_dark
87
+ self.decode_visibility = decode_visibility
88
+ self.decode_beta = decode_beta
89
+
90
+ if isinstance(sigma, (float, int)):
91
+ self.sigma = np.array([sigma, sigma])
92
+ else:
93
+ self.sigma = np.array(sigma)
94
+
95
+ if self.smoothing_type not in {'gaussian', 'standard'}:
96
+ raise ValueError(
97
+ f'{self.__class__.__name__} got invalid `smoothing_type` value'
98
+ f'{self.smoothing_type}. Should be one of '
99
+ '{"gaussian", "standard"}')
100
+
101
+ if self.smoothing_type == 'gaussian' and self.label_smooth_weight > 0:
102
+ raise ValueError('Attribute `label_smooth_weight` is only '
103
+ 'used for `standard` mode.')
104
+
105
+ if self.label_smooth_weight < 0.0 or self.label_smooth_weight > 1.0:
106
+ raise ValueError('`label_smooth_weight` should be in range [0, 1]')
107
+
108
+ def encode(self,
109
+ keypoints: np.ndarray,
110
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
111
+ """Encoding keypoints into SimCC labels. Note that the original
112
+ keypoint coordinates should be in the input image space.
113
+
114
+ Args:
115
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
116
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
117
+ (N, K)
118
+
119
+ Returns:
120
+ dict:
121
+ - keypoint_x_labels (np.ndarray): The generated SimCC label for
122
+ x-axis.
123
+ The label shape is (N, K, Wx) for both ``'gaussian'`` and
124
+ ``'standard'`` smoothing types, where
125
+ :math:`Wx=w*simcc_split_ratio`
126
+ - keypoint_y_labels (np.ndarray): The generated SimCC label for
127
+ y-axis.
128
+ The label shape is (N, K, Wy) for both ``'gaussian'`` and
129
+ ``'standard'`` smoothing types, where
130
+ :math:`Wy=h*simcc_split_ratio`
131
+ - keypoint_weights (np.ndarray): The target weights in shape
132
+ (N, K)
133
+ """
134
+ if keypoints_visible is None:
135
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
136
+
137
+ if self.smoothing_type == 'gaussian':
138
+ x_labels, y_labels, keypoint_weights = self._generate_gaussian(
139
+ keypoints, keypoints_visible)
140
+ elif self.smoothing_type == 'standard':
141
+ x_labels, y_labels, keypoint_weights = self._generate_standard(
142
+ keypoints, keypoints_visible)
143
+ else:
144
+ raise ValueError(
145
+ f'{self.__class__.__name__} got invalid `smoothing_type` value'
146
+ f'{self.smoothing_type}. Should be one of '
147
+ '{"gaussian", "standard"}')
148
+
149
+ encoded = dict(
150
+ keypoint_x_labels=x_labels,
151
+ keypoint_y_labels=y_labels,
152
+ keypoint_weights=keypoint_weights)
153
+
154
+ return encoded
155
+
156
+ def decode(self, simcc_x: np.ndarray,
157
+ simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
158
+ """Decode keypoint coordinates from SimCC representations. The decoded
159
+ coordinates are in the input image space.
160
+
161
+ Args:
162
164
+ simcc_x (np.ndarray): SimCC label for x-axis
165
+ simcc_y (np.ndarray): SimCC label for y-axis
166
+
167
+ Returns:
168
+ tuple:
169
+ - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D)
170
+ - scores (np.ndarray): The keypoint scores in shape (N, K).
171
+ It usually represents the confidence of the keypoint prediction
172
+ """
173
+
174
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
175
+
176
+ # Unsqueeze the instance dimension for single-instance results
177
+ if keypoints.ndim == 2:
178
+ keypoints = keypoints[None, :]
179
+ scores = scores[None, :]
180
+
181
+ if self.use_dark:
182
+ x_blur = int((self.sigma[0] * 20 - 7) // 3)
183
+ y_blur = int((self.sigma[1] * 20 - 7) // 3)
184
+ x_blur -= int((x_blur % 2) == 0)
185
+ y_blur -= int((y_blur % 2) == 0)
186
+ keypoints[:, :, 0] = refine_simcc_dark(keypoints[:, :, 0], simcc_x,
187
+ x_blur)
188
+ keypoints[:, :, 1] = refine_simcc_dark(keypoints[:, :, 1], simcc_y,
189
+ y_blur)
190
+
191
+ keypoints /= self.simcc_split_ratio
192
+
193
+ if self.decode_visibility:
194
+ _, visibility = get_simcc_maximum(
195
+ simcc_x * self.decode_beta * self.sigma[0],
196
+ simcc_y * self.decode_beta * self.sigma[1],
197
+ apply_softmax=True)
198
+ return keypoints, (scores, visibility)
199
+ else:
200
+ return keypoints, scores
201
+
202
+ def _map_coordinates(
203
+ self,
204
+ keypoints: np.ndarray,
205
+ keypoints_visible: Optional[np.ndarray] = None
206
+ ) -> Tuple[np.ndarray, np.ndarray]:
207
+ """Mapping keypoint coordinates into SimCC space."""
208
+
209
+ keypoints_split = keypoints.copy()
210
+ keypoints_split = np.around(keypoints_split * self.simcc_split_ratio)
211
+ keypoints_split = keypoints_split.astype(np.int64)
212
+ keypoint_weights = keypoints_visible.copy()
213
+
214
+ return keypoints_split, keypoint_weights
215
+
216
+ def _generate_standard(
217
+ self,
218
+ keypoints: np.ndarray,
219
+ keypoints_visible: Optional[np.ndarray] = None
220
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
221
+ """Encoding keypoints into SimCC labels with Standard Label Smoothing
222
+ strategy.
223
+
224
+ Labels will be one-hot vectors if self.label_smooth_weight==0.0
225
+ """
226
+
227
+ N, K, _ = keypoints.shape
228
+ w, h = self.input_size
229
+ W = np.around(w * self.simcc_split_ratio).astype(int)
230
+ H = np.around(h * self.simcc_split_ratio).astype(int)
231
+
232
+ keypoints_split, keypoint_weights = self._map_coordinates(
233
+ keypoints, keypoints_visible)
234
+
235
+ target_x = np.zeros((N, K, W), dtype=np.float32)
236
+ target_y = np.zeros((N, K, H), dtype=np.float32)
237
+
238
+ for n, k in product(range(N), range(K)):
239
+ # skip unlabeled keypoints
240
+ if keypoints_visible[n, k] < 0.5:
241
+ continue
242
+
243
+ # get center coordinates
244
+ mu_x, mu_y = keypoints_split[n, k].astype(np.int64)
245
+
246
+ # detect abnormal coords and assign the weight 0
247
+ if mu_x >= W or mu_y >= H or mu_x < 0 or mu_y < 0:
248
+ keypoint_weights[n, k] = 0
249
+ continue
250
+
251
+ if self.label_smooth_weight > 0:
252
+ target_x[n, k] = self.label_smooth_weight / (W - 1)
253
+ target_y[n, k] = self.label_smooth_weight / (H - 1)
254
+
255
+ target_x[n, k, mu_x] = 1.0 - self.label_smooth_weight
256
+ target_y[n, k, mu_y] = 1.0 - self.label_smooth_weight
257
+
258
+ return target_x, target_y, keypoint_weights
259
+
260
+ def _generate_gaussian(
261
+ self,
262
+ keypoints: np.ndarray,
263
+ keypoints_visible: Optional[np.ndarray] = None
264
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
265
+ """Encoding keypoints into SimCC labels with Gaussian Label Smoothing
266
+ strategy."""
267
+
268
+ N, K, _ = keypoints.shape
269
+ w, h = self.input_size
270
+ W = np.around(w * self.simcc_split_ratio).astype(int)
271
+ H = np.around(h * self.simcc_split_ratio).astype(int)
272
+
273
+ keypoints_split, keypoint_weights = self._map_coordinates(
274
+ keypoints, keypoints_visible)
275
+
276
+ target_x = np.zeros((N, K, W), dtype=np.float32)
277
+ target_y = np.zeros((N, K, H), dtype=np.float32)
278
+
279
+ # 3-sigma rule
280
+ radius = self.sigma * 3
281
+
282
+ # xy grid
283
+ x = np.arange(0, W, 1, dtype=np.float32)
284
+ y = np.arange(0, H, 1, dtype=np.float32)
285
+
286
+ for n, k in product(range(N), range(K)):
287
+ # skip unlabled keypoints
288
+ if keypoints_visible[n, k] < 0.5:
289
+ continue
290
+
291
+ mu = keypoints_split[n, k]
292
+
293
+ # check that the gaussian has in-bounds part
294
+ left, top = mu - radius
295
+ right, bottom = mu + radius + 1
296
+
297
+ if left >= W or top >= H or right < 0 or bottom < 0:
298
+ keypoint_weights[n, k] = 0
299
+ continue
300
+
301
+ mu_x, mu_y = mu
302
+
303
+ target_x[n, k] = np.exp(-((x - mu_x)**2) / (2 * self.sigma[0]**2))
304
+ target_y[n, k] = np.exp(-((y - mu_y)**2) / (2 * self.sigma[1]**2))
305
+
306
+ if self.normalize:
307
+ norm_value = self.sigma * np.sqrt(np.pi * 2)
308
+ target_x /= norm_value[0]
309
+ target_y /= norm_value[1]
310
+
311
+ return target_x, target_y, keypoint_weights
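A minimal encode/decode sketch for SimCCLabel (hypothetical 192x256 input and K=17; the label widths follow the simcc_split_ratio of 2.0):

import numpy as np
from mmpose.codecs.simcc_label import SimCCLabel

codec = SimCCLabel(input_size=(192, 256), smoothing_type='gaussian',
                   sigma=6.0, simcc_split_ratio=2.0)
keypoints = np.random.rand(1, 17, 2).astype(np.float32) * (192, 256)

encoded = codec.encode(keypoints)
print(encoded['keypoint_x_labels'].shape)                   # (1, 17, 384), Wx = 192 * 2
print(encoded['keypoint_y_labels'].shape)                   # (1, 17, 512), Wy = 256 * 2
decoded, scores = codec.decode(encoded['keypoint_x_labels'],
                               encoded['keypoint_y_labels'])   # back in the input image space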
mmpose/codecs/spr.py ADDED
@@ -0,0 +1,306 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import Tensor
7
+
8
+ from mmpose.registry import KEYPOINT_CODECS
9
+ from .base import BaseKeypointCodec
10
+ from .utils import (batch_heatmap_nms, generate_displacement_heatmap,
11
+ generate_gaussian_heatmaps, get_diagonal_lengths,
12
+ get_instance_root)
13
+
14
+
15
+ @KEYPOINT_CODECS.register_module()
16
+ class SPR(BaseKeypointCodec):
17
+ """Encode/decode keypoints with Structured Pose Representation (SPR).
18
+
19
+ See the paper `Single-stage multi-person pose machines`_
20
+ by Nie et al (2019) for details
21
+
22
+ Note:
23
+
24
+ - instance number: N
25
+ - keypoint number: K
26
+ - keypoint dimension: D
27
+ - image size: [w, h]
28
+ - heatmap size: [W, H]
29
+
30
+ Encoded:
31
+
32
+ - heatmaps (np.ndarray): The generated heatmap in shape (1, H, W)
33
+ where [W, H] is the `heatmap_size`. If the keypoint heatmap is
34
+ generated together, the output heatmap shape is (K+1, H, W)
35
+ - heatmap_weights (np.ndarray): The target weights for heatmaps which
36
+ has same shape with heatmaps.
37
+ - displacements (np.ndarray): The dense keypoint displacement in
38
+ shape (K*2, H, W).
39
+ - displacement_weights (np.ndarray): The target weights for heatmaps
40
+ which has same shape with displacements.
41
+
42
+ Args:
43
+ input_size (tuple): Image size in [w, h]
44
+ heatmap_size (tuple): Heatmap size in [W, H]
45
+ sigma (float or tuple, optional): The sigma values of the Gaussian
46
+ heatmaps. If sigma is a tuple, it includes both sigmas for root
47
+ and keypoint heatmaps. ``None`` means the sigmas are computed
48
+ automatically from the heatmap size. Defaults to ``None``
49
+ generate_keypoint_heatmaps (bool): Whether to generate Gaussian
50
+ heatmaps for each keypoint. Defaults to ``False``
51
+ root_type (str): The method to generate the instance root. Options
52
+ are:
53
+
54
+ - ``'kpt_center'``: Average coordinate of all visible keypoints.
55
+ - ``'bbox_center'``: Center point of bounding boxes outlined by
56
+ all visible keypoints.
57
+
58
+ Defaults to ``'kpt_center'``
59
+
60
+ minimal_diagonal_length (int or float): The threshold of diagonal
61
+ length of instance bounding box. Small instances will not be
62
+ used in training. Defaults to 32
63
+ background_weight (float): Loss weight of background pixels.
64
+ Defaults to 0.1
65
+ decode_thr (float): The threshold of keypoint response value in
66
+ heatmaps. Defaults to 0.01
67
+ decode_nms_kernel (int): The kernel size of the NMS during decoding,
68
+ which should be an odd integer. Defaults to 5
69
+ decode_max_instances (int): The maximum number of instances
70
+ to decode. Defaults to 30
71
+
72
+ .. _`Single-stage multi-person pose machines`:
73
+ https://arxiv.org/abs/1908.09220
74
+ """
75
+
76
+ field_mapping_table = dict(
77
+ heatmaps='heatmaps',
78
+ heatmap_weights='heatmap_weights',
79
+ displacements='displacements',
80
+ displacement_weights='displacement_weights',
81
+ )
82
+
83
+ def __init__(
84
+ self,
85
+ input_size: Tuple[int, int],
86
+ heatmap_size: Tuple[int, int],
87
+ sigma: Optional[Union[float, Tuple[float]]] = None,
88
+ generate_keypoint_heatmaps: bool = False,
89
+ root_type: str = 'kpt_center',
90
+ minimal_diagonal_length: Union[int, float] = 5,
91
+ background_weight: float = 0.1,
92
+ decode_nms_kernel: int = 5,
93
+ decode_max_instances: int = 30,
94
+ decode_thr: float = 0.01,
95
+ ):
96
+ super().__init__()
97
+
98
+ self.input_size = input_size
99
+ self.heatmap_size = heatmap_size
100
+ self.generate_keypoint_heatmaps = generate_keypoint_heatmaps
101
+ self.root_type = root_type
102
+ self.minimal_diagonal_length = minimal_diagonal_length
103
+ self.background_weight = background_weight
104
+ self.decode_nms_kernel = decode_nms_kernel
105
+ self.decode_max_instances = decode_max_instances
106
+ self.decode_thr = decode_thr
107
+
108
+ self.scale_factor = (np.array(input_size) /
109
+ heatmap_size).astype(np.float32)
110
+
111
+ if sigma is None:
112
+ sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 32
113
+ if generate_keypoint_heatmaps:
114
+ # sigma for root heatmap and keypoint heatmaps
115
+ self.sigma = (sigma, sigma // 2)
116
+ else:
117
+ self.sigma = (sigma, )
118
+ else:
119
+ if not isinstance(sigma, (tuple, list)):
120
+ sigma = (sigma, )
121
+ if generate_keypoint_heatmaps:
122
+ assert len(sigma) == 2, 'sigma for keypoints must be given ' \
123
+ 'if `generate_keypoint_heatmaps` ' \
124
+ 'is True. e.g. sigma=(4, 2)'
125
+ self.sigma = sigma
126
+
127
+ def _get_heatmap_weights(self,
128
+ heatmaps,
129
+ fg_weight: float = 1,
130
+ bg_weight: float = 0):
131
+ """Generate weight array for heatmaps.
132
+
133
+ Args:
134
+ heatmaps (np.ndarray): Root and keypoint (optional) heatmaps
135
+ fg_weight (float): Weight for foreground pixels. Defaults to 1.0
136
+ bg_weight (float): Weight for background pixels. Defaults to 0.0
137
+
138
+ Returns:
139
+ np.ndarray: Heatmap weight array in the same shape with heatmaps
140
+ """
141
+ heatmap_weights = np.ones(heatmaps.shape, dtype=np.float32) * bg_weight
142
+ heatmap_weights[heatmaps > 0] = fg_weight
143
+ return heatmap_weights
144
+
145
+ def encode(self,
146
+ keypoints: np.ndarray,
147
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
148
+ """Encode keypoints into root heatmaps and keypoint displacement
149
+ fields. Note that the original keypoint coordinates should be in the
150
+ input image space.
151
+
152
+ Args:
153
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
154
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
155
+ (N, K)
156
+
157
+ Returns:
158
+ dict:
159
+ - heatmaps (np.ndarray): The generated heatmap in shape
160
+ (1, H, W) where [W, H] is the `heatmap_size`. If keypoint
161
+ heatmaps are generated together, the shape is (K+1, H, W)
162
+ - heatmap_weights (np.ndarray): The pixel-wise weight for heatmaps
163
+ which has same shape with `heatmaps`
164
+ - displacements (np.ndarray): The generated displacement fields in
165
+ shape (K*D, H, W). The vector on each pixels represents the
166
+ displacement of keypoints belong to the associated instance
167
+ from this pixel.
168
+ - displacement_weights (np.ndarray): The pixel-wise weight for
169
+ displacements which has same shape with `displacements`
170
+ """
171
+
172
+ if keypoints_visible is None:
173
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
174
+
175
+ # keypoint coordinates in heatmap
176
+ _keypoints = keypoints / self.scale_factor
177
+
178
+ # compute the root and scale of each instance
179
+ roots, roots_visible = get_instance_root(_keypoints, keypoints_visible,
180
+ self.root_type)
181
+ diagonal_lengths = get_diagonal_lengths(_keypoints, keypoints_visible)
182
+
183
+ # discard the small instances
184
+ roots_visible[diagonal_lengths < self.minimal_diagonal_length] = 0
185
+
186
+ # generate heatmaps
187
+ heatmaps, _ = generate_gaussian_heatmaps(
188
+ heatmap_size=self.heatmap_size,
189
+ keypoints=roots[:, None],
190
+ keypoints_visible=roots_visible[:, None],
191
+ sigma=self.sigma[0])
192
+ heatmap_weights = self._get_heatmap_weights(
193
+ heatmaps, bg_weight=self.background_weight)
194
+
195
+ if self.generate_keypoint_heatmaps:
196
+ keypoint_heatmaps, _ = generate_gaussian_heatmaps(
197
+ heatmap_size=self.heatmap_size,
198
+ keypoints=_keypoints,
199
+ keypoints_visible=keypoints_visible,
200
+ sigma=self.sigma[1])
201
+
202
+ keypoint_heatmaps_weights = self._get_heatmap_weights(
203
+ keypoint_heatmaps, bg_weight=self.background_weight)
204
+
205
+ heatmaps = np.concatenate((keypoint_heatmaps, heatmaps), axis=0)
206
+ heatmap_weights = np.concatenate(
207
+ (keypoint_heatmaps_weights, heatmap_weights), axis=0)
208
+
209
+ # generate displacements
210
+ displacements, displacement_weights = \
211
+ generate_displacement_heatmap(
212
+ self.heatmap_size,
213
+ _keypoints,
214
+ keypoints_visible,
215
+ roots,
216
+ roots_visible,
217
+ diagonal_lengths,
218
+ self.sigma[0],
219
+ )
220
+
221
+ encoded = dict(
222
+ heatmaps=heatmaps,
223
+ heatmap_weights=heatmap_weights,
224
+ displacements=displacements,
225
+ displacement_weights=displacement_weights)
226
+
227
+ return encoded
228
+
229
+ def decode(self, heatmaps: Tensor,
230
+ displacements: Tensor) -> Tuple[np.ndarray, np.ndarray]:
231
+ """Decode the keypoint coordinates from heatmaps and displacements. The
232
+ decoded keypoint coordinates are in the input image space.
233
+
234
+ Args:
235
+ heatmaps (Tensor): Encoded root and keypoints (optional) heatmaps
236
+ in shape (1, H, W) or (K+1, H, W)
237
+ displacements (Tensor): Encoded keypoints displacement fields
238
+ in shape (K*D, H, W)
239
+
240
+ Returns:
241
+ tuple:
242
+ - keypoints (Tensor): Decoded keypoint coordinates in shape
243
+ (N, K, D)
244
+ - scores (tuple):
245
+ - root_scores (Tensor): The root scores in shape (N, )
246
+ - keypoint_scores (Tensor): The keypoint scores in
247
+ shape (N, K). If keypoint heatmaps are not generated,
248
+ `keypoint_scores` will be `None`
249
+ """
250
+ # heatmaps, displacements = encoded
251
+ _k, h, w = displacements.shape
252
+ k = _k // 2
253
+ displacements = displacements.view(k, 2, h, w)
254
+
255
+ # convert displacements to a dense keypoint prediction
256
+ y, x = torch.meshgrid(torch.arange(h), torch.arange(w))
257
+ regular_grid = torch.stack([x, y], dim=0).to(displacements)
258
+ posemaps = (regular_grid[None] + displacements).flatten(2)
259
+
260
+ # find local maximum on root heatmap
261
+ root_heatmap_peaks = batch_heatmap_nms(heatmaps[None, -1:],
262
+ self.decode_nms_kernel)
263
+ root_scores, pos_idx = root_heatmap_peaks.flatten().topk(
264
+ self.decode_max_instances)
265
+ mask = root_scores > self.decode_thr
266
+ root_scores, pos_idx = root_scores[mask], pos_idx[mask]
267
+
268
+ keypoints = posemaps[:, :, pos_idx].permute(2, 0, 1).contiguous()
269
+
270
+ if self.generate_keypoint_heatmaps and heatmaps.shape[0] == 1 + k:
271
+ # compute scores for each keypoint
272
+ keypoint_scores = self.get_keypoint_scores(heatmaps[:k], keypoints)
273
+ else:
274
+ keypoint_scores = None
275
+
276
+ keypoints = torch.cat([
277
+ kpt * self.scale_factor[i]
278
+ for i, kpt in enumerate(keypoints.split(1, -1))
279
+ ],
280
+ dim=-1)
281
+ return keypoints, (root_scores, keypoint_scores)
282
+
283
+ def get_keypoint_scores(self, heatmaps: Tensor, keypoints: Tensor):
284
+ """Calculate the keypoint scores with keypoints heatmaps and
285
+ coordinates.
286
+
287
+ Args:
288
+ heatmaps (Tensor): Keypoint heatmaps in shape (K, H, W)
289
+ keypoints (Tensor): Keypoint coordinates in shape (N, K, D)
290
+
291
+ Returns:
292
+ Tensor: Keypoint scores in [N, K]
293
+ """
294
+ k, h, w = heatmaps.shape
295
+ keypoints = torch.stack((
296
+ keypoints[..., 0] / (w - 1) * 2 - 1,
297
+ keypoints[..., 1] / (h - 1) * 2 - 1,
298
+ ),
299
+ dim=-1)
300
+ keypoints = keypoints.transpose(0, 1).unsqueeze(1).contiguous()
301
+
302
+ keypoint_scores = torch.nn.functional.grid_sample(
303
+ heatmaps.unsqueeze(1), keypoints,
304
+ padding_mode='border').view(k, -1).transpose(0, 1).contiguous()
305
+
306
+ return keypoint_scores
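A minimal multi-instance encoding sketch for SPR; the sizes are hypothetical, and decoding would additionally require torch tensors of predicted heatmaps and displacements:

import numpy as np
from mmpose.codecs.spr import SPR

codec = SPR(input_size=(512, 512), heatmap_size=(128, 128),
            sigma=(4.0, 2.0), generate_keypoint_heatmaps=True)
keypoints = np.random.rand(3, 17, 2).astype(np.float32) * 512        # three instances, K=17
visible = np.ones((3, 17), dtype=np.float32)

targets = codec.encode(keypoints, visible)
print(targets['heatmaps'].shape)                            # (18, 128, 128): K keypoint maps + 1 root map
print(targets['displacements'].shape)                       # (34, 128, 128): K*2 displacement channels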
mmpose/codecs/udp_heatmap.py ADDED
@@ -0,0 +1,263 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Optional, Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from mmpose.registry import KEYPOINT_CODECS
8
+ from .base import BaseKeypointCodec
9
+ from .utils import (generate_offset_heatmap, generate_udp_gaussian_heatmaps,
10
+ get_heatmap_maximum, refine_keypoints_dark_udp)
11
+
12
+
13
+ @KEYPOINT_CODECS.register_module()
14
+ class UDPHeatmap(BaseKeypointCodec):
15
+ r"""Generate keypoint heatmaps by Unbiased Data Processing (UDP).
16
+ See the paper: `The Devil is in the Details: Delving into Unbiased Data
17
+ Processing for Human Pose Estimation`_ by Huang et al (2020) for details.
18
+
19
+ Note:
20
+
21
+ - instance number: N
22
+ - keypoint number: K
23
+ - keypoint dimension: D
24
+ - image size: [w, h]
25
+ - heatmap size: [W, H]
26
+
27
+ Encoded:
28
+
29
+ - heatmap (np.ndarray): The generated heatmap in shape (C_out, H, W)
30
+ where [W, H] is the `heatmap_size`, and the C_out is the output
31
+ channel number which depends on the `heatmap_type`. If
32
+ `heatmap_type=='gaussian'`, C_out equals to keypoint number K;
33
+ if `heatmap_type=='combined'`, C_out equals to K*3
34
+ (x_offset, y_offset and class label)
35
+ - keypoint_weights (np.ndarray): The target weights in shape (K,)
36
+
37
+ Args:
38
+ input_size (tuple): Image size in [w, h]
39
+ heatmap_size (tuple): Heatmap size in [W, H]
40
+ heatmap_type (str): The heatmap type to encode the keypoints. Options
41
+ are:
42
+
43
+ - ``'gaussian'``: Gaussian heatmap
44
+ - ``'combined'``: Combination of a binary label map and offset
45
+ maps for X and Y axes.
46
+
47
+ sigma (float): The sigma value of the Gaussian heatmap when
48
+ ``heatmap_type=='gaussian'``. Defaults to 2.0
49
+ radius_factor (float): The radius factor of the binary label
50
+ map when ``heatmap_type=='combined'``. The positive region is
51
+ defined as the neighbor of the keypoint with the radius
52
+ :math:`r=radius_factor*max(W, H)`. Defaults to 0.0546875
53
+ blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
54
+ modulation in DarkPose. Defaults to 11
55
+
56
+ .. _`The Devil is in the Details: Delving into Unbiased Data Processing for
57
+ Human Pose Estimation`: https://arxiv.org/abs/1911.07524
58
+ """
59
+
60
+ label_mapping_table = dict(keypoint_weights='keypoint_weights', )
61
+ field_mapping_table = dict(heatmaps='heatmaps', )
62
+
63
+ def __init__(self,
64
+ input_size: Tuple[int, int],
65
+ heatmap_size: Tuple[int, int],
66
+ heatmap_type: str = 'gaussian',
67
+ sigma: float = 2.,
68
+ radius_factor: float = 0.0546875,
69
+ blur_kernel_size: int = 11,
70
+ increase_sigma_with_padding=False,
71
+ amap_scale: float = 1.0,
72
+ normalize=None,
73
+ ) -> None:
74
+ super().__init__()
75
+ self.input_size = np.array(input_size)
76
+ self.heatmap_size = np.array(heatmap_size)
77
+ self.sigma = sigma
78
+ self.radius_factor = radius_factor
79
+ self.heatmap_type = heatmap_type
80
+ self.blur_kernel_size = blur_kernel_size
81
+ self.increase_sigma_with_padding = increase_sigma_with_padding
82
+ self.normalize = normalize
83
+
84
+ self.amap_size = self.input_size * amap_scale
85
+ self.scale_factor = ((self.amap_size - 1) /
86
+ (self.heatmap_size - 1)).astype(np.float32)
87
+ self.input_center = self.input_size / 2
88
+ self.top_left = self.input_center - self.amap_size / 2
89
+
90
+ if self.heatmap_type not in {'gaussian', 'combined'}:
91
+ raise ValueError(
92
+ f'{self.__class__.__name__} got invalid `heatmap_type` value'
93
+ f'{self.heatmap_type}. Should be one of '
94
+ '{"gaussian", "combined"}')
95
+
96
+ def _kpts_to_activation_pts(self, keypoints: np.ndarray) -> np.ndarray:
97
+ """
98
+ Transform the keypoint coordinates to the activation space.
99
+ In the original UDPHeatmap, the activation map covers the same space as the input
100
+ image, only at a lower resolution. Here the activation map may instead have a
101
+ different (padded) size than the input image space.
102
+ Centers of the activation map and the input image space are aligned.
103
+ """
104
+ transformed_keypoints = keypoints - self.top_left
105
+ transformed_keypoints = transformed_keypoints / self.scale_factor
106
+ return transformed_keypoints
107
+
108
+ def _activation_pts_to_kpts(self, keypoints: np.ndarray) -> np.ndarray:
109
+ """
110
+ Transform the points in activation map to the keypoint coordinates.
111
+ In the original UDPHeatmap, the activation map covers the same space as the input
112
+ image, only at a lower resolution. Here the activation map may instead have a
113
+ different (padded) size than the input image space.
114
+ Centers of the activation map and the input image space are aligned.
115
+ """
116
+ W, H = self.heatmap_size
117
+ transformed_keypoints = keypoints / [W - 1, H - 1] * self.amap_size
118
+ transformed_keypoints += self.top_left
119
+ return transformed_keypoints
120
+
121
+ def encode(self,
122
+ keypoints: np.ndarray,
123
+ keypoints_visible: Optional[np.ndarray] = None,
124
+ id_similarity: Optional[float] = 0.0,
125
+ keypoints_visibility: Optional[np.ndarray] = None) -> dict:
126
+ """Encode keypoints into heatmaps. Note that the original keypoint
127
+ coordinates should be in the input image space.
128
+
129
+ Args:
130
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
131
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
132
+ (N, K)
133
+ id_similarity (float): The usefulness of the identity information
134
+ for the whole pose. Defaults to 0.0
135
+ keypoints_visibility (np.ndarray): The visibility bit for each
136
+ keypoint (N, K). Defaults to None
137
+
138
+ Returns:
139
+ dict:
140
+ - heatmap (np.ndarray): The generated heatmap in shape
141
+ (C_out, H, W) where [W, H] is the `heatmap_size`, and the
142
+ C_out is the output channel number which depends on the
143
+ `heatmap_type`. If `heatmap_type=='gaussian'`, C_out equals to
144
+ keypoint number K; if `heatmap_type=='combined'`, C_out
145
+ equals to K*3 (x_offset, y_offset and class label)
146
+ - keypoint_weights (np.ndarray): The target weights in shape
147
+ (K,)
148
+ """
149
+ assert keypoints.shape[0] == 1, (
150
+ f'{self.__class__.__name__} only support single-instance '
151
+ 'keypoint encoding')
152
+
153
+ if keypoints_visibility is None:
154
+ keypoints_visibility = np.zeros(keypoints.shape[:2], dtype=np.float32)
155
+
156
+ if keypoints_visible is None:
157
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
158
+
159
+ if self.heatmap_type == 'gaussian':
160
+ heatmaps, keypoint_weights = generate_udp_gaussian_heatmaps(
161
+ heatmap_size=self.heatmap_size,
162
+ keypoints=self._kpts_to_activation_pts(keypoints),
163
+ keypoints_visible=keypoints_visible,
164
+ sigma=self.sigma,
165
+ keypoints_visibility=keypoints_visibility,
166
+ increase_sigma_with_padding=self.increase_sigma_with_padding)
167
+ elif self.heatmap_type == 'combined':
168
+ heatmaps, keypoint_weights = generate_offset_heatmap(
169
+ heatmap_size=self.heatmap_size,
170
+ keypoints=self._kpts_to_activation_pts(keypoints),
171
+ keypoints_visible=keypoints_visible,
172
+ radius_factor=self.radius_factor)
173
+ else:
174
+ raise ValueError(
175
+ f'{self.__class__.__name__} got invalid `heatmap_type` value'
176
+ f'{self.heatmap_type}. Should be one of '
177
+ '{"gaussian", "combined"}')
178
+
179
+ if self.normalize is not None:
180
+ heatmaps_sum = np.sum(heatmaps, axis=(1, 2), keepdims=False)
181
+ mask = heatmaps_sum > 0
182
+ heatmaps[mask, :, :] = heatmaps[mask, :, :] / (heatmaps_sum[mask, None, None] + np.finfo(np.float32).eps)
183
+ heatmaps = heatmaps * self.normalize
184
+
185
+ annotated = keypoints_visible > 0
186
+
187
+ heatmap_keypoints = self._kpts_to_activation_pts(keypoints)
188
+ in_image = np.logical_and(
189
+ heatmap_keypoints[:, :, 0] >= 0,
190
+ heatmap_keypoints[:, :, 0] < self.heatmap_size[0],
191
+ )
192
+ in_image = np.logical_and(
193
+ in_image,
194
+ heatmap_keypoints[:, :, 1] >= 0,
195
+ )
196
+ in_image = np.logical_and(
197
+ in_image,
198
+ heatmap_keypoints[:, :, 1] < self.heatmap_size[1],
199
+ )
200
+
201
+ encoded = dict(
202
+ heatmaps=heatmaps,
203
+ keypoint_weights=keypoint_weights,
204
+ annotated=annotated,
205
+ in_image=in_image,
206
+ keypoints_scaled=keypoints,
207
+ heatmap_keypoints=heatmap_keypoints,
208
+ identification_similarity=id_similarity,
209
+ )
210
+
211
+ return encoded
212
+
213
+ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
214
+ """Decode keypoint coordinates from heatmaps. The decoded keypoint
215
+ coordinates are in the input image space.
216
+
217
+ Args:
218
+ encoded (np.ndarray): Heatmaps in shape (K, H, W)
219
+
220
+ Returns:
221
+ tuple:
222
+ - keypoints (np.ndarray): Decoded keypoint coordinates in shape
223
+ (N, K, D)
224
+ - scores (np.ndarray): The keypoint scores in shape (N, K). It
225
+ usually represents the confidence of the keypoint prediction
226
+ """
227
+ heatmaps = encoded.copy()
228
+
229
+ if self.heatmap_type == 'gaussian':
230
+ keypoints, scores = get_heatmap_maximum(heatmaps)
231
+ # unsqueeze the instance dimension for single-instance results
232
+ keypoints = keypoints[None]
233
+ scores = scores[None]
234
+
235
+ keypoints = refine_keypoints_dark_udp(
236
+ keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size)
237
+
238
+ elif self.heatmap_type == 'combined':
239
+ _K, H, W = heatmaps.shape
240
+ K = _K // 3
241
+
242
+ for cls_heatmap in heatmaps[::3]:
243
+ # Apply Gaussian blur on classification maps
244
+ ks = 2 * self.blur_kernel_size + 1
245
+ cv2.GaussianBlur(cls_heatmap, (ks, ks), 0, cls_heatmap)
246
+
247
+ # valid radius
248
+ radius = self.radius_factor * max(W, H)
249
+
250
+ x_offset = heatmaps[1::3].flatten() * radius
251
+ y_offset = heatmaps[2::3].flatten() * radius
252
+ keypoints, scores = get_heatmap_maximum(heatmaps=heatmaps[::3])
253
+ index = (keypoints[..., 0] + keypoints[..., 1] * W).flatten()
254
+ index += W * H * np.arange(0, K)
255
+ index = index.astype(int)
256
+ keypoints += np.stack((x_offset[index], y_offset[index]), axis=-1)
257
+ # unsqueeze the instance dimension for single-instance results
258
+ keypoints = keypoints[None].astype(np.float32)
259
+ scores = scores[None]
260
+
261
+ keypoints = self._activation_pts_to_kpts(keypoints)
262
+
263
+ return keypoints, scores
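A minimal round trip for the UDPHeatmap variant above; with the default amap_scale=1.0 the activation map matches the input space, so it behaves like the standard UDP codec (sizes hypothetical):

import numpy as np
from mmpose.codecs.udp_heatmap import UDPHeatmap

codec = UDPHeatmap(input_size=(192, 256), heatmap_size=(48, 64), sigma=2.0)
keypoints = np.random.rand(1, 17, 2).astype(np.float32) * (192, 256)

encoded = codec.encode(keypoints)                           # Gaussian heatmaps of shape (17, 64, 48)
decoded, scores = codec.decode(encoded['heatmaps'])         # DARK-UDP refined, back in input space
print(decoded.shape, scores.shape)                          # (1, 17, 2) (1, 17)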
mmpose/codecs/utils/__init__.py ADDED
@@ -0,0 +1,32 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .camera_image_projection import (camera_to_image_coord, camera_to_pixel,
3
+ pixel_to_camera)
4
+ from .gaussian_heatmap import (generate_3d_gaussian_heatmaps,
5
+ generate_gaussian_heatmaps,
6
+ generate_udp_gaussian_heatmaps,
7
+ generate_unbiased_gaussian_heatmaps,
8
+ generate_onehot_heatmaps)
9
+ from .instance_property import (get_diagonal_lengths, get_instance_bbox,
10
+ get_instance_root)
11
+ from .offset_heatmap import (generate_displacement_heatmap,
12
+ generate_offset_heatmap)
13
+ from .post_processing import (batch_heatmap_nms, gaussian_blur,
14
+ gaussian_blur1d, get_heatmap_3d_maximum,
15
+ get_heatmap_maximum, get_simcc_maximum,
16
+ get_simcc_normalized, get_heatmap_expected_value)
17
+ from .refinement import (refine_keypoints, refine_keypoints_dark,
18
+ refine_keypoints_dark_udp, refine_simcc_dark)
19
+ from .oks_map import generate_oks_maps
20
+
21
+ __all__ = [
22
+ 'generate_gaussian_heatmaps', 'generate_udp_gaussian_heatmaps',
23
+ 'generate_unbiased_gaussian_heatmaps', 'gaussian_blur',
24
+ 'get_heatmap_maximum', 'get_simcc_maximum', 'generate_offset_heatmap',
25
+ 'batch_heatmap_nms', 'refine_keypoints', 'refine_keypoints_dark',
26
+ 'refine_keypoints_dark_udp', 'generate_displacement_heatmap',
27
+ 'refine_simcc_dark', 'gaussian_blur1d', 'get_diagonal_lengths',
28
+ 'get_instance_root', 'get_instance_bbox', 'get_simcc_normalized',
29
+ 'camera_to_image_coord', 'camera_to_pixel', 'pixel_to_camera',
30
+ 'get_heatmap_3d_maximum', 'generate_3d_gaussian_heatmaps',
31
+ 'generate_oks_maps', 'get_heatmap_expected_value', 'generate_onehot_heatmaps'
32
+ ]
mmpose/codecs/utils/camera_image_projection.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Dict, Tuple
3
+
4
+ import numpy as np
5
+
6
+
7
+ def camera_to_image_coord(root_index: int, kpts_3d_cam: np.ndarray,
8
+ camera_param: Dict) -> Tuple[np.ndarray, np.ndarray]:
9
+ """Project keypoints from camera space to image space and calculate factor.
10
+
11
+ Args:
12
+ root_index (int): Index for root keypoint.
13
+ kpts_3d_cam (np.ndarray): Keypoint coordinates in camera space in
14
+ shape (N, K, D).
15
+ camera_param (dict): Parameters for the camera.
16
+
17
+ Returns:
18
+ tuple:
19
+ - kpts_3d_image (np.ndarray): Keypoint coordinates in image space in
20
+ shape (N, K, D).
21
+ - factor (np.ndarray): The scaling factor that maps keypoints from
22
+ image space to camera space in shape (N, ).
23
+ """
24
+
25
+ root = kpts_3d_cam[..., root_index, :]
26
+ tl_kpt = root.copy()
27
+ tl_kpt[..., :2] -= 1.0
28
+ br_kpt = root.copy()
29
+ br_kpt[..., :2] += 1.0
30
+ tl_kpt = np.reshape(tl_kpt, (-1, 3))
31
+ br_kpt = np.reshape(br_kpt, (-1, 3))
32
+ fx, fy = camera_param['f'] / 1000.
33
+ cx, cy = camera_param['c'] / 1000.
34
+
35
+ tl2d = camera_to_pixel(tl_kpt, fx, fy, cx, cy)
36
+ br2d = camera_to_pixel(br_kpt, fx, fy, cx, cy)
37
+
38
+ rectangle_3d_size = 2.0
39
+ kpts_3d_image = np.zeros_like(kpts_3d_cam)
40
+ kpts_3d_image[..., :2] = camera_to_pixel(kpts_3d_cam.copy(), fx, fy, cx,
41
+ cy)
42
+ ratio = (br2d[..., 0] - tl2d[..., 0] + 0.001) / rectangle_3d_size
43
+ factor = rectangle_3d_size / (br2d[..., 0] - tl2d[..., 0] + 0.001)
44
+ kpts_3d_depth = ratio[:, None] * (
45
+ kpts_3d_cam[..., 2] - kpts_3d_cam[..., root_index:root_index + 1, 2])
46
+ kpts_3d_image[..., 2] = kpts_3d_depth
47
+ return kpts_3d_image, factor
48
+
49
+
50
+ def camera_to_pixel(kpts_3d: np.ndarray,
51
+ fx: float,
52
+ fy: float,
53
+ cx: float,
54
+ cy: float,
55
+ shift: bool = False) -> np.ndarray:
56
+ """Project keypoints from camera space to image space.
57
+
58
+ Args:
59
+ kpts_3d (np.ndarray): Keypoint coordinates in camera space.
60
+ fx (float): x-coordinate of camera's focal length.
61
+ fy (float): y-coordinate of camera's focal length.
62
+ cx (float): x-coordinate of image center.
63
+ cy (float): y-coordinate of image center.
64
+ shift (bool): Whether to shift the coordinates by 1e-8.
65
+
66
+ Returns:
67
+ pose_2d (np.ndarray): Projected keypoint coordinates in image space.
68
+ """
69
+ if not shift:
70
+ pose_2d = kpts_3d[..., :2] / kpts_3d[..., 2:3]
71
+ else:
72
+ pose_2d = kpts_3d[..., :2] / (kpts_3d[..., 2:3] + 1e-8)
73
+ pose_2d[..., 0] *= fx
74
+ pose_2d[..., 1] *= fy
75
+ pose_2d[..., 0] += cx
76
+ pose_2d[..., 1] += cy
77
+ return pose_2d
78
+
79
+
80
+ def pixel_to_camera(kpts_3d: np.ndarray, fx: float, fy: float, cx: float,
81
+ cy: float) -> np.ndarray:
82
+ """Project keypoints from camera space to image space.
83
+
84
+ Args:
85
+ kpts_3d (np.ndarray): Keypoint coordinates in image space, with depth in the last channel.
86
+ fx (float): x-coordinate of camera's focal length.
87
+ fy (float): y-coordinate of camera's focal length.
88
+ cx (float): x-coordinate of image center.
89
+ cy (float): y-coordinate of image center.
91
+
92
+ Returns:
93
+ pose_2d (np.ndarray): Back-projected keypoint coordinates in camera space.
94
+ """
95
+ pose_2d = kpts_3d.copy()
96
+ pose_2d[..., 0] -= cx
97
+ pose_2d[..., 1] -= cy
98
+ pose_2d[..., 0] /= fx
99
+ pose_2d[..., 1] /= fy
100
+ pose_2d[..., 0] *= kpts_3d[..., 2]
101
+ pose_2d[..., 1] *= kpts_3d[..., 2]
102
+ return pose_2d
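A short usage sketch for the two projections above; the intrinsics are illustrative and the import path assumes the package layout added in this commit.

import numpy as np
from mmpose.codecs.utils.camera_image_projection import camera_to_pixel, pixel_to_camera

fx = fy = 1000.0                        # focal lengths in pixels (illustrative)
cx, cy = 512.0, 384.0                   # principal point (illustrative)
pt_cam = np.array([[0.0, 0.0, 2.0]])    # a point on the optical axis, 2 units deep
pt_img = camera_to_pixel(pt_cam, fx, fy, cx, cy)   # -> [[512., 384.]]
pt_back = pixel_to_camera(np.array([[512.0, 384.0, 2.0]]), fx, fy, cx, cy)
# pt_back -> [[0., 0., 2.]], recovering the original camera-space point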
mmpose/codecs/utils/gaussian_heatmap.py ADDED
@@ -0,0 +1,433 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from itertools import product
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ from scipy.spatial.distance import cdist
7
+
8
+
9
+ def generate_3d_gaussian_heatmaps(
10
+ heatmap_size: Tuple[int, int, int],
11
+ keypoints: np.ndarray,
12
+ keypoints_visible: np.ndarray,
13
+ sigma: Union[float, Tuple[float], np.ndarray],
14
+ image_size: Tuple[int, int],
15
+ heatmap3d_depth_bound: float = 400.0,
16
+ joint_indices: Optional[list] = None,
17
+ max_bound: float = 1.0,
18
+ use_different_joint_weights: bool = False,
19
+ dataset_keypoint_weights: Optional[np.ndarray] = None
20
+ ) -> Tuple[np.ndarray, np.ndarray]:
21
+ """Generate 3d gaussian heatmaps of keypoints.
22
+
23
+ Args:
24
+ heatmap_size (Tuple[int, int, int]): Heatmap size in [W, H, D]
25
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, C)
26
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
27
+ (N, K)
28
+ sigma (float or List[float]): A list of sigma values of the Gaussian
29
+ heatmap for each instance. If sigma is given as a single float
30
+ value, it will be expanded into a tuple
31
+ image_size (Tuple[int, int]): Size of input image.
32
+ heatmap3d_depth_bound (float): Boundary for 3d heatmap depth.
33
+ Default: 400.0.
34
+ joint_indices (List[int], optional): Indices of joints used for heatmap
35
+ generation. If None (default) is given, all joints will be used.
36
+ Default: ``None``.
37
+ max_bound (float): The maximal value of heatmap. Default: 1.0.
38
+ use_different_joint_weights (bool): Whether to use different joint
39
+ weights. Default: ``False``.
40
+ dataset_keypoint_weights (np.ndarray, optional): Keypoints weight in
41
+ shape (K, ).
42
+
43
+ Returns:
44
+ tuple:
45
+ - heatmaps (np.ndarray): The generated heatmap in shape
46
+ (K * D, H, W) where [W, H, D] is the `heatmap_size`
47
+ - keypoint_weights (np.ndarray): The target weights in shape
48
+ (N, K)
49
+ """
50
+
51
+ W, H, D = heatmap_size
52
+
53
+ # select the joints used for target generation
54
+ if joint_indices is not None:
55
+ keypoints = keypoints[:, joint_indices, ...]
56
+ keypoints_visible = keypoints_visible[:, joint_indices, ...]
57
+ N, K, _ = keypoints.shape
58
+
59
+ heatmaps = np.zeros([K, D, H, W], dtype=np.float32)
60
+ keypoint_weights = keypoints_visible.copy()
61
+
62
+ if isinstance(sigma, (int, float)):
63
+ sigma = (sigma, ) * N
64
+
65
+ for n in range(N):
66
+ # 3-sigma rule
67
+ radius = sigma[n] * 3
68
+
69
+ # joint location in heatmap coordinates
70
+ mu_x = keypoints[n, :, 0] * W / image_size[0] # (K, )
71
+ mu_y = keypoints[n, :, 1] * H / image_size[1]
72
+ mu_z = (keypoints[n, :, 2] / heatmap3d_depth_bound + 0.5) * D
73
+
74
+ keypoint_weights[n, ...] = keypoint_weights[n, ...] * (mu_z >= 0) * (
75
+ mu_z < D)
76
+ if use_different_joint_weights:
77
+ keypoint_weights[
78
+ n] = keypoint_weights[n] * dataset_keypoint_weights
79
+ # xy grid
80
+ gaussian_size = 2 * radius + 1
81
+
82
+ # get neighboring voxels coordinates
83
+ x = y = z = np.arange(gaussian_size, dtype=np.float32) - radius
84
+ zz, yy, xx = np.meshgrid(z, y, x)
85
+
86
+ xx = np.expand_dims(xx, axis=0)
87
+ yy = np.expand_dims(yy, axis=0)
88
+ zz = np.expand_dims(zz, axis=0)
89
+ mu_x = np.expand_dims(mu_x, axis=(-1, -2, -3))
90
+ mu_y = np.expand_dims(mu_y, axis=(-1, -2, -3))
91
+ mu_z = np.expand_dims(mu_z, axis=(-1, -2, -3))
92
+
93
+ xx, yy, zz = xx + mu_x, yy + mu_y, zz + mu_z
94
+ local_size = xx.shape[1]
95
+
96
+ # round the coordinates
97
+ xx = xx.round().clip(0, W - 1)
98
+ yy = yy.round().clip(0, H - 1)
99
+ zz = zz.round().clip(0, D - 1)
100
+
101
+ # compute the target value near joints
102
+ gaussian = np.exp(-((xx - mu_x)**2 + (yy - mu_y)**2 + (zz - mu_z)**2) /
103
+ (2 * sigma[n]**2))
104
+
105
+ # put the local target value to the full target heatmap
106
+ idx_joints = np.tile(
107
+ np.expand_dims(np.arange(K), axis=(-1, -2, -3)),
108
+ [1, local_size, local_size, local_size])
109
+ idx = np.stack([idx_joints, zz, yy, xx],
110
+ axis=-1).astype(int).reshape(-1, 4)
111
+
112
+ heatmaps[idx[:, 0], idx[:, 1], idx[:, 2], idx[:, 3]] = np.maximum(
113
+ heatmaps[idx[:, 0], idx[:, 1], idx[:, 2], idx[:, 3]],
114
+ gaussian.reshape(-1))
115
+
116
+ heatmaps = (heatmaps * max_bound).reshape(-1, H, W)
117
+
118
+ return heatmaps, keypoint_weights
119
+
120
+
121
+ def generate_gaussian_heatmaps(
122
+ heatmap_size: Tuple[int, int],
123
+ keypoints: np.ndarray,
124
+ keypoints_visible: np.ndarray,
125
+ sigma: Union[float, Tuple[float], np.ndarray],
126
+ ) -> Tuple[np.ndarray, np.ndarray]:
127
+ """Generate gaussian heatmaps of keypoints.
128
+
129
+ Args:
130
+ heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
131
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
132
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
133
+ (N, K)
134
+ sigma (float or List[float]): A list of sigma values of the Gaussian
135
+ heatmap for each instance. If sigma is given as a single float
136
+ value, it will be expanded into a tuple
137
+
138
+ Returns:
139
+ tuple:
140
+ - heatmaps (np.ndarray): The generated heatmap in shape
141
+ (K, H, W) where [W, H] is the `heatmap_size`
142
+ - keypoint_weights (np.ndarray): The target weights in shape
143
+ (N, K)
144
+ """
145
+
146
+ N, K, _ = keypoints.shape
147
+ W, H = heatmap_size
148
+
149
+ heatmaps = np.zeros((K, H, W), dtype=np.float32)
150
+ keypoint_weights = keypoints_visible.copy()
151
+
152
+ if isinstance(sigma, (int, float)):
153
+ sigma = (sigma, ) * N
154
+
155
+ for n in range(N):
156
+ # 3-sigma rule
157
+ radius = sigma[n] * 3
158
+
159
+ # xy grid
160
+ gaussian_size = 2 * radius + 1
161
+ x = np.arange(0, gaussian_size, 1, dtype=np.float32)
162
+ y = x[:, None]
163
+ x0 = y0 = gaussian_size // 2
164
+
165
+ for k in range(K):
166
+ # skip unlabeled keypoints
167
+ if keypoints_visible[n, k] < 0.5:
168
+ continue
169
+
170
+ # get gaussian center coordinates
171
+ mu = (keypoints[n, k] + 0.5).astype(np.int64)
172
+
173
+ # check that the gaussian has in-bounds part
174
+ left, top = (mu - radius).astype(np.int64)
175
+ right, bottom = (mu + radius + 1).astype(np.int64)
176
+
177
+ if left >= W or top >= H or right < 0 or bottom < 0:
178
+ keypoint_weights[n, k] = 0
179
+ continue
180
+
181
+ # The gaussian is not normalized,
182
+ # we want the center value to equal 1
183
+ gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma[n]**2))
184
+
185
+ # valid range in gaussian
186
+ g_x1 = max(0, -left)
187
+ g_x2 = min(W, right) - left
188
+ g_y1 = max(0, -top)
189
+ g_y2 = min(H, bottom) - top
190
+
191
+ # valid range in heatmap
192
+ h_x1 = max(0, left)
193
+ h_x2 = min(W, right)
194
+ h_y1 = max(0, top)
195
+ h_y2 = min(H, bottom)
196
+
197
+ heatmap_region = heatmaps[k, h_y1:h_y2, h_x1:h_x2]
198
+ gaussian_region = gaussian[g_y1:g_y2, g_x1:g_x2]
199
+
200
+ _ = np.maximum(
201
+ heatmap_region, gaussian_region, out=heatmap_region)
202
+
203
+ return heatmaps, keypoint_weights
204
+
205
+
206
+ def generate_unbiased_gaussian_heatmaps(
207
+ heatmap_size: Tuple[int, int],
208
+ keypoints: np.ndarray,
209
+ keypoints_visible: np.ndarray,
210
+ sigma: float,
211
+ ) -> Tuple[np.ndarray, np.ndarray]:
212
+ """Generate gaussian heatmaps of keypoints using `Dark Pose`_.
213
+
214
+ Args:
215
+ heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
216
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
217
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
218
+ (N, K)
219
+
220
+ Returns:
221
+ tuple:
222
+ - heatmaps (np.ndarray): The generated heatmap in shape
223
+ (K, H, W) where [W, H] is the `heatmap_size`
224
+ - keypoint_weights (np.ndarray): The target weights in shape
225
+ (N, K)
226
+
227
+ .. _`Dark Pose`: https://arxiv.org/abs/1910.06278
228
+ """
229
+
230
+ N, K, _ = keypoints.shape
231
+ W, H = heatmap_size
232
+
233
+ heatmaps = np.zeros((K, H, W), dtype=np.float32)
234
+ keypoint_weights = keypoints_visible.copy()
235
+
236
+ # 3-sigma rule
237
+ radius = sigma * 3
238
+
239
+ # xy grid
240
+ x = np.arange(0, W, 1, dtype=np.float32)
241
+ y = np.arange(0, H, 1, dtype=np.float32)[:, None]
242
+
243
+ for n, k in product(range(N), range(K)):
244
+ # skip unlabeled keypoints
245
+ if keypoints_visible[n, k] < 0.5:
246
+ continue
247
+
248
+ mu = keypoints[n, k]
249
+ # check that the gaussian has in-bounds part
250
+ left, top = mu - radius
251
+ right, bottom = mu + radius + 1
252
+
253
+ if left >= W or top >= H or right < 0 or bottom < 0:
254
+ keypoint_weights[n, k] = 0
255
+ continue
256
+
257
+ gaussian = np.exp(-((x - mu[0])**2 + (y - mu[1])**2) / (2 * sigma**2))
258
+
259
+ _ = np.maximum(gaussian, heatmaps[k], out=heatmaps[k])
260
+
261
+ return heatmaps, keypoint_weights
262
+
263
+
264
+ def generate_udp_gaussian_heatmaps(
265
+ heatmap_size: Tuple[int, int],
266
+ keypoints: np.ndarray,
267
+ keypoints_visible: np.ndarray,
268
+ sigma,
269
+ keypoints_visibility: np.ndarray,
270
+ increase_sigma_with_padding: bool = False,
271
+ ) -> Tuple[np.ndarray, np.ndarray]:
272
+ """Generate gaussian heatmaps of keypoints using `UDP`_.
273
+
274
+ Args:
275
+ heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
276
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
277
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
278
+ (N, K)
279
+ sigma (float): The sigma value of the Gaussian heatmap
280
+ keypoints_visibility (np.ndarray): The visibility bit for each keypoint (N, K)
281
+ increase_sigma_with_padding (bool): Whether to increase the sigma
282
+ value with padding. Default: False
283
+
284
+ Returns:
285
+ tuple:
286
+ - heatmaps (np.ndarray): The generated heatmap in shape
287
+ (K, H, W) where [W, H] is the `heatmap_size`
288
+ - keypoint_weights (np.ndarray): The target weights in shape
289
+ (N, K)
290
+
291
+ .. _`UDP`: https://arxiv.org/abs/1911.07524
292
+ """
293
+
294
+ N, K, _ = keypoints.shape
295
+ W, H = heatmap_size
296
+
297
+ heatmaps = np.zeros((K, H, W), dtype=np.float32)
298
+ keypoint_weights = keypoints_visible.copy()
299
+
300
+ if isinstance(sigma, (int, float)):
301
+ scaled_sigmas = sigma * np.ones((N, K), dtype=np.float32)
302
+ sigmas = np.array([sigma] * K).reshape(1, -1).repeat(N, axis=0)
303
+ else:
304
+ scaled_sigmas = np.array(sigma).reshape(1, -1).repeat(N, axis=0)
305
+ sigmas = np.array(sigma).reshape(1, -1).repeat(N, axis=0)
306
+
307
+ scales_arr = np.ones((N, K), dtype=np.float32)
308
+ if increase_sigma_with_padding:
309
+ diag = np.sqrt(W**2 + H**2)
310
+ for n in range(N):
311
+ image_kpts = keypoints[n, :].squeeze()
312
+ vis_kpts = image_kpts[keypoints_visibility[n, :] > 0.5]
313
+
314
+ # Compute the distance between img_kpts and visible_kpts
315
+ if vis_kpts.size == 0:
316
+ min_dists = np.ones(image_kpts.shape[0]) * diag
317
+ else:
318
+ dists = cdist(image_kpts, vis_kpts, metric='euclidean')
319
+ min_dists = np.min(dists, axis=1)
320
+
321
+ scales = min_dists / diag * 2.0 # Maximum distance (the diagonal) gives scale 2.0, i.e. sigma is scaled up to 3.0*sigma below
322
+ scales_arr[n, :] = scales
323
+ scaled_sigmas[n, :] = sigma * (1+scales)
324
+
325
+ # print(scales_arr)
326
+ # print(scaled_sigmas)
327
+
328
+ for n, k in product(range(N), range(K)):
329
+ scaled_sigma = scaled_sigmas[n, k]
330
+ # skip unlabeled keypoints
331
+ if keypoints_visible[n, k] < 0.5:
332
+ continue
333
+
334
+ # 3-sigma rule
335
+ radius = scaled_sigma * 3
336
+
337
+ # xy grid
338
+ gaussian_size = 2 * radius + 1
339
+ x = np.arange(0, gaussian_size, 1, dtype=np.float32)
340
+ y = x[:, None]
341
+
342
+ mu = (keypoints[n, k] + 0.5).astype(np.int64)
343
+ # check that the gaussian has in-bounds part
344
+ left, top = (mu - radius).round().astype(np.int64)
345
+ right, bottom = (mu + radius + 1).round().astype(np.int64)
346
+ # left, top = (mu - radius).astype(np.int64)
347
+ # right, bottom = (mu + radius + 1).astype(np.int64)
348
+
349
+ if left >= W or top >= H or right < 0 or bottom < 0:
350
+ keypoint_weights[n, k] = 0
351
+ continue
352
+
353
+ mu_ac = keypoints[n, k]
354
+ x0 = y0 = gaussian_size // 2
355
+ x0 += mu_ac[0] - mu[0]
356
+ y0 += mu_ac[1] - mu[1]
357
+ gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * scaled_sigma**2))
358
+
359
+ # Normalize Gaussian such that scaled_sigma = sigma is the norm
360
+ gaussian = gaussian / (scaled_sigma / sigmas[n, k])
361
+
362
+ # valid range in gaussian
363
+ g_x1 = max(0, -left)
364
+ g_x2 = min(W, right) - left
365
+ g_y1 = max(0, -top)
366
+ g_y2 = min(H, bottom) - top
367
+
368
+ # valid range in heatmap
369
+ h_x1 = max(0, left)
370
+ h_x2 = min(W, right)
371
+ h_y1 = max(0, top)
372
+ h_y2 = min(H, bottom)
373
+
374
+ # breakpoint()
375
+
376
+ heatmap_region = heatmaps[k, h_y1:h_y2, h_x1:h_x2]
377
+ gaussian_region = gaussian[g_y1:g_y2, g_x1:g_x2]
378
+
379
+ _ = np.maximum(heatmap_region, gaussian_region, out=heatmap_region)
380
+
381
+ return heatmaps, keypoint_weights
382
+
383
+
384
+ def generate_onehot_heatmaps(
385
+ heatmap_size: Tuple[int, int],
386
+ keypoints: np.ndarray,
387
+ keypoints_visible: np.ndarray,
388
+ sigma,
389
+ keypoints_visibility: np.ndarray,
390
+ increase_sigma_with_padding: bool = False,
391
+ ) -> Tuple[np.ndarray, np.ndarray]:
392
+ """Generate gaussian heatmaps of keypoints using `UDP`_.
393
+
394
+ Args:
395
+ heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
396
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
397
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
398
+ (N, K)
399
+ sigma (float): Unused here; kept for interface compatibility with the Gaussian generators
400
+ keypoints_visibility (np.ndarray): The visibility bit for each keypoint (N, K)
401
+ increase_sigma_with_padding (bool): Whether to increase the sigma
402
+ value with padding. Default: False
403
+
404
+ Returns:
405
+ tuple:
406
+ - heatmaps (np.ndarray): The generated heatmap in shape
407
+ (K, H, W) where [W, H] is the `heatmap_size`
408
+ - keypoint_weights (np.ndarray): The target weights in shape
409
+ (N, K)
410
+
411
+ .. _`UDP`: https://arxiv.org/abs/1911.07524
412
+ """
413
+
414
+ N, K, _ = keypoints.shape
415
+ W, H = heatmap_size
416
+
417
+ heatmaps = np.zeros((K, H, W), dtype=np.float32)
418
+ keypoint_weights = keypoints_visible.copy()
419
+
420
+ for n, k in product(range(N), range(K)):
421
+ # skip unlabeled keypoints
422
+ if keypoints_visible[n, k] < 0.5:
423
+ continue
424
+
425
+ mu = (keypoints[n, k] + 0.5).astype(np.int64)
426
+
427
+
428
+ if mu[0] < 0 or mu[0] >= W or mu[1] < 0 or mu[1] >= H:
429
+ keypoint_weights[n, k] = 0
430
+ continue
431
+
432
+ heatmaps[k, mu[1], mu[0]] = 1
433
+ return heatmaps, keypoint_weights
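A minimal usage sketch for generate_gaussian_heatmaps, assuming the exports from mmpose/codecs/utils/__init__.py above; sizes and coordinates are illustrative.

import numpy as np
from mmpose.codecs.utils import generate_gaussian_heatmaps

keypoints = np.array([[[12.0, 20.0], [40.0, 30.0]]], dtype=np.float32)  # (N=1, K=2, 2)
keypoints_visible = np.ones((1, 2), dtype=np.float32)                   # both keypoints labelled
heatmaps, weights = generate_gaussian_heatmaps(
    heatmap_size=(64, 48), keypoints=keypoints,
    keypoints_visible=keypoints_visible, sigma=2.0)
assert heatmaps.shape == (2, 48, 64) and weights.shape == (1, 2)
# each channel peaks at 1.0 at the rounded keypoint location, e.g. heatmaps[0, 20, 12] == 1.0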
mmpose/codecs/utils/instance_property.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+
6
+
7
+ def get_instance_root(keypoints: np.ndarray,
8
+ keypoints_visible: Optional[np.ndarray] = None,
9
+ root_type: str = 'kpt_center') -> np.ndarray:
10
+ """Calculate the coordinates and visibility of instance roots.
11
+
12
+ Args:
13
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
14
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
15
+ (N, K)
16
+ root_type (str): Calculation of instance roots which should
17
+ be one of the following options:
18
+
19
+ - ``'kpt_center'``: The roots' coordinates are the mean
20
+ coordinates of visible keypoints
21
+ - ``'bbox_center'``: The roots' are the center of bounding
22
+ boxes outlined by visible keypoints
23
+
24
+ Defaults to ``'kpt_center'``
25
+
26
+ Returns:
27
+ tuple
28
+ - roots_coordinate(np.ndarray): Coordinates of instance roots in
29
+ shape [N, D]
30
+ - roots_visible(np.ndarray): Visibility of instance roots in
31
+ shape [N]
32
+ """
33
+
34
+ roots_coordinate = np.zeros((keypoints.shape[0], 2), dtype=np.float32)
35
+ roots_visible = np.ones((keypoints.shape[0]), dtype=np.float32) * 2
36
+
37
+ for i in range(keypoints.shape[0]):
38
+
39
+ # collect visible keypoints
40
+ if keypoints_visible is not None:
41
+ visible_keypoints = keypoints[i][keypoints_visible[i] > 0]
42
+ else:
43
+ visible_keypoints = keypoints[i]
44
+ if visible_keypoints.size == 0:
45
+ roots_visible[i] = 0
46
+ continue
47
+
48
+ # compute the instance root with visible keypoints
49
+ if root_type == 'kpt_center':
50
+ roots_coordinate[i] = visible_keypoints.mean(axis=0)
51
+ roots_visible[i] = 1
52
+ elif root_type == 'bbox_center':
53
+ roots_coordinate[i] = (visible_keypoints.max(axis=0) +
54
+ visible_keypoints.min(axis=0)) / 2.0
55
+ roots_visible[i] = 1
56
+ else:
57
+ raise ValueError(
58
+ f'the value of `root_type` must be \'kpt_center\' or '
59
+ f'\'bbox_center\', but got \'{root_type}\'')
60
+
61
+ return roots_coordinate, roots_visible
62
+
63
+
64
+ def get_instance_bbox(keypoints: np.ndarray,
65
+ keypoints_visible: Optional[np.ndarray] = None
66
+ ) -> np.ndarray:
67
+ """Calculate the pseudo instance bounding box from visible keypoints. The
68
+ bounding boxes are in the xyxy format.
69
+
70
+ Args:
71
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
72
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
73
+ (N, K)
74
+
75
+ Returns:
76
+ np.ndarray: bounding boxes in [N, 4]
77
+ """
78
+ bbox = np.zeros((keypoints.shape[0], 4), dtype=np.float32)
79
+ for i in range(keypoints.shape[0]):
80
+ if keypoints_visible is not None:
81
+ visible_keypoints = keypoints[i][keypoints_visible[i] > 0]
82
+ else:
83
+ visible_keypoints = keypoints[i]
84
+ if visible_keypoints.size == 0:
85
+ continue
86
+
87
+ bbox[i, :2] = visible_keypoints.min(axis=0)
88
+ bbox[i, 2:] = visible_keypoints.max(axis=0)
89
+ return bbox
90
+
91
+
92
+ def get_diagonal_lengths(keypoints: np.ndarray,
93
+ keypoints_visible: Optional[np.ndarray] = None
94
+ ) -> np.ndarray:
95
+ """Calculate the diagonal length of instance bounding box from visible
96
+ keypoints.
97
+
98
+ Args:
99
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
100
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
101
+ (N, K)
102
+
103
+ Returns:
104
+ np.ndarray: bounding box diagonal length in [N]
105
+ """
106
+ pseudo_bbox = get_instance_bbox(keypoints, keypoints_visible)
107
+ pseudo_bbox = pseudo_bbox.reshape(-1, 2, 2)
108
+ h_w_diff = pseudo_bbox[:, 1] - pseudo_bbox[:, 0]
109
+ diagonal_length = np.sqrt(np.power(h_w_diff, 2).sum(axis=1))
110
+
111
+ return diagonal_length
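A small usage sketch for the instance-property helpers above (import path per the __init__.py of this commit; values illustrative).

import numpy as np
from mmpose.codecs.utils import (get_diagonal_lengths, get_instance_bbox, get_instance_root)

keypoints = np.array([[[0.0, 0.0], [4.0, 0.0], [0.0, 3.0]]])   # (N=1, K=3, 2)
visible = np.ones((1, 3), dtype=np.float32)
roots, roots_visible = get_instance_root(keypoints, visible, root_type='bbox_center')
# roots -> [[2., 1.5]], the centre of the 4x3 pseudo bounding box
bbox = get_instance_bbox(keypoints, visible)                   # -> [[0., 0., 4., 3.]]
diagonal = get_diagonal_lengths(keypoints, visible)            # -> [5.]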
mmpose/codecs/utils/offset_heatmap.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from itertools import product
3
+ from typing import Tuple
4
+
5
+ import numpy as np
6
+
7
+
8
+ def generate_offset_heatmap(
9
+ heatmap_size: Tuple[int, int],
10
+ keypoints: np.ndarray,
11
+ keypoints_visible: np.ndarray,
12
+ radius_factor: float,
13
+ ) -> Tuple[np.ndarray, np.ndarray]:
14
+ """Generate offset heatmaps of keypoints, where each keypoint is
15
+ represented by 3 maps: one pixel-level class label map (1 for keypoint and
16
+ 0 for non-keypoint) and 2 pixel-level offset maps for x and y directions
17
+ respectively.
18
+
19
+ Args:
20
+ heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
21
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
22
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
23
+ (N, K)
24
+ radius_factor (float): The radius factor of the binary label
25
+ map. The positive region is defined as the neighbor of the
26
+ keypoint with the radius :math:`r=radius_factor*max(W, H)`
27
+
28
+ Returns:
29
+ tuple:
30
+ - heatmap (np.ndarray): The generated heatmap in shape
31
+ (K*3, H, W) where [W, H] is the `heatmap_size`
32
+ - keypoint_weights (np.ndarray): The target weights in shape
33
+ (K,)
34
+ """
35
+
36
+ N, K, _ = keypoints.shape
37
+ W, H = heatmap_size
38
+
39
+ heatmaps = np.zeros((K, 3, H, W), dtype=np.float32)
40
+ keypoint_weights = keypoints_visible.copy()
41
+
42
+ # xy grid
43
+ x = np.arange(0, W, 1)
44
+ y = np.arange(0, H, 1)[:, None]
45
+
46
+ # positive area radius in the classification map
47
+ radius = radius_factor * max(W, H)
48
+
49
+ for n, k in product(range(N), range(K)):
50
+ if keypoints_visible[n, k] < 0.5:
51
+ continue
52
+
53
+ mu = keypoints[n, k]
54
+
55
+ x_offset = (mu[0] - x) / radius
56
+ y_offset = (mu[1] - y) / radius
57
+
58
+ heatmaps[k, 0] = np.where(x_offset**2 + y_offset**2 <= 1, 1., 0.)
59
+ heatmaps[k, 1] = x_offset
60
+ heatmaps[k, 2] = y_offset
61
+
62
+ heatmaps = heatmaps.reshape(K * 3, H, W)
63
+
64
+ return heatmaps, keypoint_weights
65
+
66
+
67
+ def generate_displacement_heatmap(
68
+ heatmap_size: Tuple[int, int],
69
+ keypoints: np.ndarray,
70
+ keypoints_visible: np.ndarray,
71
+ roots: np.ndarray,
72
+ roots_visible: np.ndarray,
73
+ diagonal_lengths: np.ndarray,
74
+ radius: float,
75
+ ):
76
+ """Generate displacement heatmaps of keypoints, where each keypoint is
77
+ represented by 3 maps: one pixel-level class label map (1 for keypoint and
78
+ 0 for non-keypoint) and 2 pixel-level offset maps for x and y directions
79
+ respectively.
80
+
81
+ Args:
82
+ heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
83
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
84
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
85
+ (N, K)
86
+ roots (np.ndarray): Coordinates of instance centers in shape (N, D).
87
+ The displacement fields of each instance will locate around its
88
+ center.
89
+ roots_visible (np.ndarray): Roots visibilities in shape (N,)
90
+ diagonal_lengths (np.ndarray): Diagonal length of the bounding boxes
91
+ of each instance in shape (N,)
92
+ radius (float): Radius (in pixels) of the neighborhood around each
93
+ instance root within which the displacement fields and their weights
94
+ are generated.
95
+
96
+ Returns:
97
+ tuple:
98
+ - displacements (np.ndarray): The generated displacement map in
99
+ shape (K*2, H, W) where [W, H] is the `heatmap_size`
100
+ - displacement_weights (np.ndarray): The target weights in shape
101
+ (K*2, H, W)
102
+ """
103
+ N, K, _ = keypoints.shape
104
+ W, H = heatmap_size
105
+
106
+ displacements = np.zeros((K * 2, H, W), dtype=np.float32)
107
+ displacement_weights = np.zeros((K * 2, H, W), dtype=np.float32)
108
+ instance_size_map = np.zeros((H, W), dtype=np.float32)
109
+
110
+ for n in range(N):
111
+ if (roots_visible[n] < 1 or (roots[n, 0] < 0 or roots[n, 1] < 0)
112
+ or (roots[n, 0] >= W or roots[n, 1] >= H)):
113
+ continue
114
+
115
+ diagonal_length = diagonal_lengths[n]
116
+
117
+ for k in range(K):
118
+ if keypoints_visible[n, k] < 1 or keypoints[n, k, 0] < 0 \
119
+ or keypoints[n, k, 1] < 0 or keypoints[n, k, 0] >= W \
120
+ or keypoints[n, k, 1] >= H:
121
+ continue
122
+
123
+ start_x = max(int(roots[n, 0] - radius), 0)
124
+ start_y = max(int(roots[n, 1] - radius), 0)
125
+ end_x = min(int(roots[n, 0] + radius), W)
126
+ end_y = min(int(roots[n, 1] + radius), H)
127
+
128
+ for x in range(start_x, end_x):
129
+ for y in range(start_y, end_y):
130
+ if displacements[2 * k, y,
131
+ x] != 0 or displacements[2 * k + 1, y,
132
+ x] != 0:
133
+ if diagonal_length > instance_size_map[y, x]:
134
+ # keep the gt displacement of smaller instance
135
+ continue
136
+
137
+ displacement_weights[2 * k:2 * k + 2, y,
138
+ x] = 1 / diagonal_length
139
+ displacements[2 * k:2 * k + 2, y,
140
+ x] = keypoints[n, k] - [x, y]
141
+ instance_size_map[y, x] = diagonal_length
142
+
143
+ return displacements, displacement_weights
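A minimal usage sketch for generate_offset_heatmap (illustrative sizes; imports assume the __init__.py above).

import numpy as np
from mmpose.codecs.utils import generate_offset_heatmap

keypoints = np.array([[[8.0, 6.0]]])                 # (N=1, K=1, 2)
visible = np.ones((1, 1), dtype=np.float32)
heatmaps, weights = generate_offset_heatmap(
    heatmap_size=(16, 12), keypoints=keypoints,
    keypoints_visible=visible, radius_factor=0.25)
assert heatmaps.shape == (3, 12, 16)                 # [cls, x_offset, y_offset] for the single keypoint
assert heatmaps[0, 6, 8] == 1.0                      # the keypoint pixel lies in the positive region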
mmpose/codecs/utils/oks_map.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from itertools import product
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ from scipy.spatial.distance import cdist
7
+
8
+
9
+ def generate_oks_maps(
10
+ heatmap_size: Tuple[int, int],
11
+ keypoints: np.ndarray,
12
+ keypoints_visible: np.ndarray,
13
+ keypoints_visibility: np.ndarray,
14
+ sigma: float = 0.55,
15
+ increase_sigma_with_padding: bool = False,
16
+ ) -> Tuple[np.ndarray, np.ndarray]:
17
+ """Generate gaussian heatmaps of keypoints using `UDP`_.
18
+
19
+ Args:
20
+ heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
21
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
22
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
23
+ (N, K)
24
+ sigma (float): The sigma value of the Gaussian heatmap
25
+ keypoints_visibility (np.ndarray): The visibility bit for each keypoint (N, K)
26
+ increase_sigma_with_padding (bool): Whether to increase the sigma
27
+ value with padding. Default: False
28
+
29
+ Returns:
30
+ tuple:
31
+ - heatmaps (np.ndarray): The generated heatmap in shape
32
+ (K, H, W) where [W, H] is the `heatmap_size`
33
+ - keypoint_weights (np.ndarray): The target weights in shape
34
+ (N, K)
35
+
36
+ .. _`UDP`: https://arxiv.org/abs/1911.07524
37
+ """
38
+
39
+ N, K, _ = keypoints.shape
40
+ W, H = heatmap_size
41
+
42
+ # The default sigmas are used for COCO dataset.
43
+ sigmas = np.array(
44
+ [2.6, 2.5, 2.5, 3.5, 3.5, 7.9, 7.9, 7.2, 7.2, 6.2, 6.2, 10.7, 10.7, 8.7, 8.7, 8.9, 8.9])/100
45
+ # sigmas = sigmas * 2 / sigmas.mean()
46
+ # sigmas = np.round(sigmas).astype(int)
47
+ # sigmas = np.clip(sigmas, 1, 10)
48
+
49
+ heatmaps = np.zeros((K, H, W), dtype=np.float32)
50
+ keypoint_weights = keypoints_visible.copy()
51
+
52
+ # bbox_area = W/1.25 * H/1.25
53
+ # bbox_area = W * H * 0.53
54
+ bbox_area = np.sqrt(H/1.25 * W/1.25)
55
+
56
+ # print(scales_arr)
57
+ # print(scaled_sigmas)
58
+
59
+ for n, k in product(range(N), range(K)):
60
+ kpt_sigma = sigmas[k]
61
+ # skip unlabeled keypoints
62
+ if keypoints_visible[n, k] < 0.5:
63
+ continue
64
+
65
+ y_idx, x_idx = np.indices((H, W))
66
+ dx = x_idx - keypoints[n, k, 0]
67
+ dy = y_idx - keypoints[n, k, 1]
68
+ dist = np.sqrt(dx**2 + dy**2)
69
+
70
+ # e_map = (dx**2 + dy**2) / ((kpt_sigma*100)**2 * sigma)
71
+ vars = (kpt_sigma*2)**2
72
+ s = vars * bbox_area * 2
73
+ s = np.clip(s, 0.55, 3.0)
74
+ if sigma is not None and sigma > 0:
75
+ s = sigma
76
+ e_map = dist**2 / (2*s)
77
+ oks_map = np.exp(-e_map)
78
+
79
+ keypoint_weights[n, k] = (oks_map.max() > 0).astype(int)
80
+
81
+ # Scale such that there is always 1 at the maximum
82
+ if oks_map.max() > 1e-3:
83
+ oks_map = oks_map / oks_map.max()
84
+
85
+ # Scale OKS map such that 1 stays 1 and 0.5 becomes 0
86
+ # oks_map[oks_map < 0.5] = 0
87
+ # oks_map = 2 * oks_map - 1
88
+
89
+
90
+ # oks_map[oks_map > 0.95] = 1
91
+ # print("{:.4f}, {:7.1f}, {:9.3f}, {:9.3f}, {:4.2f}".format(vars, bbox_area, vars * bbox_area* 2, s, oks_map.max()))
92
+ # if np.all(oks_map < 0.1):
93
+ # print("\t{:d} --> {:.4f}".format(k, s))
94
+ heatmaps[k] = oks_map
95
+ # breakpoint()
96
+
97
+ return heatmaps, keypoint_weights
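A hedged usage sketch for generate_oks_maps; note that the COCO per-keypoint sigmas are hard-coded inside, so K must not exceed 17. Values are illustrative.

import numpy as np
from mmpose.codecs.utils import generate_oks_maps

keypoints = np.array([[[24.0, 18.0], [40.0, 30.0]]])   # (N=1, K=2, 2)
visible = np.ones((1, 2), dtype=np.float32)
oks_maps, weights = generate_oks_maps(
    heatmap_size=(64, 48), keypoints=keypoints,
    keypoints_visible=visible, keypoints_visibility=visible, sigma=0.55)
assert oks_maps.shape == (2, 48, 64)
assert np.isclose(oks_maps[0, 18, 24], 1.0)            # each map is rescaled to peak at 1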
mmpose/codecs/utils/post_processing.py ADDED
@@ -0,0 +1,530 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from itertools import product
3
+ from typing import Tuple
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+ from scipy.signal import convolve2d
12
+
13
+
14
+ def get_simcc_normalized(batch_pred_simcc, sigma=None):
15
+ """Normalize the predicted SimCC.
16
+
17
+ Args:
18
+ batch_pred_simcc (torch.Tensor): The predicted SimCC.
19
+ sigma (float): The sigma of the Gaussian distribution.
20
+
21
+ Returns:
22
+ torch.Tensor: The normalized SimCC.
23
+ """
24
+ B, K, _ = batch_pred_simcc.shape
25
+
26
+ # Scale and clamp the tensor
27
+ if sigma is not None:
28
+ batch_pred_simcc = batch_pred_simcc / (sigma * np.sqrt(np.pi * 2))
29
+ batch_pred_simcc = batch_pred_simcc.clamp(min=0)
30
+
31
+ # Compute the binary mask
32
+ mask = (batch_pred_simcc.amax(dim=-1) > 1).reshape(B, K, 1)
33
+
34
+ # Normalize the tensor using the maximum value
35
+ norm = (batch_pred_simcc / batch_pred_simcc.amax(dim=-1).reshape(B, K, 1))
36
+
37
+ # Apply normalization
38
+ batch_pred_simcc = torch.where(mask, norm, batch_pred_simcc)
39
+
40
+ return batch_pred_simcc
41
+
42
+
43
+ def get_simcc_maximum(simcc_x: np.ndarray,
44
+ simcc_y: np.ndarray,
45
+ apply_softmax: bool = False
46
+ ) -> Tuple[np.ndarray, np.ndarray]:
47
+ """Get maximum response location and value from simcc representations.
48
+
49
+ Note:
50
+ instance number: N
51
+ num_keypoints: K
52
+ heatmap height: H
53
+ heatmap width: W
54
+
55
+ Args:
56
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
57
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
58
+ apply_softmax (bool): whether to apply softmax on the heatmap.
59
+ Defaults to False.
60
+
61
+ Returns:
62
+ tuple:
63
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
64
+ (K, 2) or (N, K, 2)
65
+ - vals (np.ndarray): values of maximum heatmap responses in shape
66
+ (K,) or (N, K)
67
+ """
68
+
69
+ assert isinstance(simcc_x, np.ndarray), ('simcc_x should be numpy.ndarray')
70
+ assert isinstance(simcc_y, np.ndarray), ('simcc_y should be numpy.ndarray')
71
+ assert simcc_x.ndim == 2 or simcc_x.ndim == 3, (
72
+ f'Invalid shape {simcc_x.shape}')
73
+ assert simcc_y.ndim == 2 or simcc_y.ndim == 3, (
74
+ f'Invalid shape {simcc_y.shape}')
75
+ assert simcc_x.ndim == simcc_y.ndim, (
76
+ f'{simcc_x.shape} != {simcc_y.shape}')
77
+
78
+ if simcc_x.ndim == 3:
79
+ N, K, Wx = simcc_x.shape
80
+ simcc_x = simcc_x.reshape(N * K, -1)
81
+ simcc_y = simcc_y.reshape(N * K, -1)
82
+ else:
83
+ N = None
84
+
85
+ if apply_softmax:
86
+ simcc_x = simcc_x - np.max(simcc_x, axis=1, keepdims=True)
87
+ simcc_y = simcc_y - np.max(simcc_y, axis=1, keepdims=True)
88
+ ex, ey = np.exp(simcc_x), np.exp(simcc_y)
89
+ simcc_x = ex / np.sum(ex, axis=1, keepdims=True)
90
+ simcc_y = ey / np.sum(ey, axis=1, keepdims=True)
91
+
92
+ x_locs = np.argmax(simcc_x, axis=1)
93
+ y_locs = np.argmax(simcc_y, axis=1)
94
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
95
+ max_val_x = np.amax(simcc_x, axis=1)
96
+ max_val_y = np.amax(simcc_y, axis=1)
97
+
98
+ mask = max_val_x > max_val_y
99
+ max_val_x[mask] = max_val_y[mask]
100
+ vals = max_val_x
101
+ locs[vals <= 0.] = -1
102
+
103
+ if N:
104
+ locs = locs.reshape(N, K, 2)
105
+ vals = vals.reshape(N, K)
106
+
107
+ return locs, vals
108
+
109
+
110
+ def get_heatmap_3d_maximum(heatmaps: np.ndarray
111
+ ) -> Tuple[np.ndarray, np.ndarray]:
112
+ """Get maximum response location and value from heatmaps.
113
+
114
+ Note:
115
+ batch_size: B
116
+ num_keypoints: K
117
+ heatmap dimension: D
118
+ heatmap height: H
119
+ heatmap width: W
120
+
121
+ Args:
122
+ heatmaps (np.ndarray): Heatmaps in shape (K, D, H, W) or
123
+ (B, K, D, H, W)
124
+
125
+ Returns:
126
+ tuple:
127
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
128
+ (K, 3) or (B, K, 3)
129
+ - vals (np.ndarray): values of maximum heatmap responses in shape
130
+ (K,) or (B, K)
131
+ """
132
+ assert isinstance(heatmaps,
133
+ np.ndarray), ('heatmaps should be numpy.ndarray')
134
+ assert heatmaps.ndim == 4 or heatmaps.ndim == 5, (
135
+ f'Invalid shape {heatmaps.shape}')
136
+
137
+ if heatmaps.ndim == 4:
138
+ K, D, H, W = heatmaps.shape
139
+ B = None
140
+ heatmaps_flatten = heatmaps.reshape(K, -1)
141
+ else:
142
+ B, K, D, H, W = heatmaps.shape
143
+ heatmaps_flatten = heatmaps.reshape(B * K, -1)
144
+
145
+ z_locs, y_locs, x_locs = np.unravel_index(
146
+ np.argmax(heatmaps_flatten, axis=1), shape=(D, H, W))
147
+ locs = np.stack((x_locs, y_locs, z_locs), axis=-1).astype(np.float32)
148
+ vals = np.amax(heatmaps_flatten, axis=1)
149
+ locs[vals <= 0.] = -1
150
+
151
+ if B:
152
+ locs = locs.reshape(B, K, 3)
153
+ vals = vals.reshape(B, K)
154
+
155
+ return locs, vals
156
+
157
+
158
+ def get_heatmap_maximum(heatmaps: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
159
+ """Get maximum response location and value from heatmaps.
160
+
161
+ Note:
162
+ batch_size: B
163
+ num_keypoints: K
164
+ heatmap height: H
165
+ heatmap width: W
166
+
167
+ Args:
168
+ heatmaps (np.ndarray): Heatmaps in shape (K, H, W) or (B, K, H, W)
169
+
170
+ Returns:
171
+ tuple:
172
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
173
+ (K, 2) or (B, K, 2)
174
+ - vals (np.ndarray): values of maximum heatmap responses in shape
175
+ (K,) or (B, K)
176
+ """
177
+ assert isinstance(heatmaps,
178
+ np.ndarray), ('heatmaps should be numpy.ndarray')
179
+ assert heatmaps.ndim == 3 or heatmaps.ndim == 4, (
180
+ f'Invalid shape {heatmaps.shape}')
181
+
182
+ if heatmaps.ndim == 3:
183
+ K, H, W = heatmaps.shape
184
+ B = None
185
+ heatmaps_flatten = heatmaps.reshape(K, -1)
186
+ else:
187
+ B, K, H, W = heatmaps.shape
188
+ heatmaps_flatten = heatmaps.reshape(B * K, -1)
189
+
190
+ y_locs, x_locs = np.unravel_index(
191
+ np.argmax(heatmaps_flatten, axis=1), shape=(H, W))
192
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
193
+ vals = np.amax(heatmaps_flatten, axis=1)
194
+ locs[vals <= 0.] = -1
195
+
196
+ if B:
197
+ locs = locs.reshape(B, K, 2)
198
+ vals = vals.reshape(B, K)
199
+
200
+ return locs, vals
201
+
202
+
203
+ def gaussian_blur(heatmaps: np.ndarray, kernel: int = 11) -> np.ndarray:
204
+ """Modulate heatmap distribution with Gaussian.
205
+
206
+ Note:
207
+ - num_keypoints: K
208
+ - heatmap height: H
209
+ - heatmap width: W
210
+
211
+ Args:
212
+ heatmaps (np.ndarray[K, H, W]): model predicted heatmaps.
213
+ kernel (int): Gaussian kernel size (K) for modulation, which should
214
+ match the heatmap gaussian sigma when training.
215
+ K=17 for sigma=3 and k=11 for sigma=2.
216
+
217
+ Returns:
218
+ np.ndarray ([K, H, W]): Modulated heatmap distribution.
219
+ """
220
+ assert kernel % 2 == 1
221
+
222
+ border = (kernel - 1) // 2
223
+ K, H, W = heatmaps.shape
224
+
225
+ for k in range(K):
226
+ origin_max = np.max(heatmaps[k])
227
+ dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
228
+ dr[border:-border, border:-border] = heatmaps[k].copy()
229
+ dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
230
+ heatmaps[k] = dr[border:-border, border:-border].copy()
231
+ heatmaps[k] *= origin_max / (np.max(heatmaps[k])+1e-12)
232
+ return heatmaps
233
+
234
+
235
+ def gaussian_blur1d(simcc: np.ndarray, kernel: int = 11) -> np.ndarray:
236
+ """Modulate simcc distribution with Gaussian.
237
+
238
+ Note:
239
+ - num_keypoints: K
240
+ - simcc length: Wx
241
+
242
+ Args:
243
+ simcc (np.ndarray[K, Wx]): model predicted simcc.
244
+ kernel (int): Gaussian kernel size (K) for modulation, which should
245
+ match the simcc gaussian sigma when training.
246
+ K=17 for sigma=3 and k=11 for sigma=2.
247
+
248
+ Returns:
249
+ np.ndarray ([K, Wx]): Modulated simcc distribution.
250
+ """
251
+ assert kernel % 2 == 1
252
+
253
+ border = (kernel - 1) // 2
254
+ N, K, Wx = simcc.shape
255
+
256
+ for n, k in product(range(N), range(K)):
257
+ origin_max = np.max(simcc[n, k])
258
+ dr = np.zeros((1, Wx + 2 * border), dtype=np.float32)
259
+ dr[0, border:-border] = simcc[n, k].copy()
260
+ dr = cv2.GaussianBlur(dr, (kernel, 1), 0)
261
+ simcc[n, k] = dr[0, border:-border].copy()
262
+ simcc[n, k] *= origin_max / np.max(simcc[n, k])
263
+ return simcc
264
+
265
+
266
+ def batch_heatmap_nms(batch_heatmaps: Tensor, kernel_size: int = 5):
267
+ """Apply NMS on a batch of heatmaps.
268
+
269
+ Args:
270
+ batch_heatmaps (Tensor): batch heatmaps in shape (B, K, H, W)
271
+ kernel_size (int): The kernel size of the NMS which should be
272
+ a odd integer. Defaults to 5
273
+
274
+ Returns:
275
+ Tensor: The batch heatmaps after NMS.
276
+ """
277
+
278
+ assert isinstance(kernel_size, int) and kernel_size % 2 == 1, \
279
+ f'The kernel_size should be an odd integer, got {kernel_size}'
280
+
281
+ padding = (kernel_size - 1) // 2
282
+
283
+ maximum = F.max_pool2d(
284
+ batch_heatmaps, kernel_size, stride=1, padding=padding)
285
+ maximum_indicator = torch.eq(batch_heatmaps, maximum)
286
+ batch_heatmaps = batch_heatmaps * maximum_indicator.float()
287
+
288
+ return batch_heatmaps
289
+
290
+
291
+ def get_heatmap_expected_value(heatmaps: np.ndarray, parzen_size: float = 0.1, return_heatmap: bool = False) -> Tuple[np.ndarray, np.ndarray]:
292
+ """Get maximum response location and value from heatmaps.
293
+
294
+ Note:
295
+ batch_size: B
296
+ num_keypoints: K
297
+ heatmap height: H
298
+ heatmap width: W
299
+
300
+ Args:
301
+ heatmaps (np.ndarray): Heatmaps in shape (K, H, W) or (B, K, H, W)
302
+
303
+ Returns:
304
+ tuple:
305
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
306
+ (K, 2) or (B, K, 2)
307
+ - vals (np.ndarray): values of maximum heatmap responses in shape
308
+ (K,) or (B, K)
309
+ """
310
+ assert isinstance(heatmaps,
311
+ np.ndarray), ('heatmaps should be numpy.ndarray')
312
+ assert heatmaps.ndim == 3 or heatmaps.ndim == 4, (
313
+ f'Invalid shape {heatmaps.shape}')
314
+
315
+ assert parzen_size >= 0.0 and parzen_size <= 1.0, (
316
+ f'Invalid parzen_size {parzen_size}')
317
+
318
+ if heatmaps.ndim == 3:
319
+ K, H, W = heatmaps.shape
320
+ B = 1
321
+ FIRST_DIM = K
322
+ heatmaps_flatten = heatmaps.reshape(1, K, H, W)
323
+ else:
324
+ B, K, H, W = heatmaps.shape
325
+ FIRST_DIM = K*B
326
+ heatmaps_flatten = heatmaps.reshape(B, K, H, W)
327
+
328
+ # Blur heatmaps with Gaussian
329
+ # heatmaps_flatten = gaussian_blur(heatmaps_flatten, kernel=9)
330
+
331
+ # Zero out pixels far from the maximum for each heatmap
332
+ # heatmaps_tmp = heatmaps_flatten.copy().reshape(B*K, H*W)
333
+ # y_locs, x_locs = np.unravel_index(
334
+ # np.argmax(heatmaps_tmp, axis=1), shape=(H, W))
335
+ # locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
336
+ # heatmaps_flatten = heatmaps_flatten.reshape(B*K, H, W)
337
+ # for i, x in enumerate(x_locs):
338
+ # y = y_locs[i]
339
+ # start_x = int(max(0, x - 0.2*W))
340
+ # end_x = int(min(W, x + 0.2*W))
341
+ # start_y = int(max(0, y - 0.2*H))
342
+ # end_y = int(min(H, y + 0.2*H))
343
+ # mask = np.zeros((H, W))
344
+ # mask[start_y:end_y, start_x:end_x] = 1
345
+ # heatmaps_flatten[i] = heatmaps_flatten[i] * mask
346
+ # heatmaps_flatten = heatmaps_flatten.reshape(B, K, H, W)
347
+
348
+
349
+ bbox_area = np.sqrt(H/1.25 * W/1.25)
350
+
351
+ kpt_sigmas = np.array(
352
+ [2.6, 2.5, 2.5, 3.5, 3.5, 7.9, 7.9, 7.2, 7.2, 6.2, 6.2, 10.7, 10.7, 8.7, 8.7, 8.9, 8.9])/100
353
+
354
+ heatmaps_covolved = np.zeros_like(heatmaps_flatten)
355
+ for k in range(K):
356
+ vars = (kpt_sigmas[k]*2)**2
357
+ s = vars * bbox_area * 2
358
+ s = np.clip(s, 0.55, 3.0)
359
+ radius = np.ceil(s * 3).astype(int)
360
+ diameter = 2*radius + 1
361
+ diameter = np.ceil(diameter).astype(int)
362
+ # kernel_sizes[kernel_sizes % 2 == 0] += 1
363
+ center = diameter // 2
364
+ dist_x = np.arange(diameter) - center
365
+ dist_y = np.arange(diameter) - center
366
+ dist_x, dist_y = np.meshgrid(dist_x, dist_y)
367
+ dist = np.sqrt(dist_x**2 + dist_y**2)
368
+ oks_kernel = np.exp(-dist**2 / (2 * s))
369
+ oks_kernel = oks_kernel / oks_kernel.sum()
370
+
371
+ htm = heatmaps_flatten[:, k, :, :].reshape(-1, H, W)
372
+ # htm = np.pad(htm, ((0, 0), (radius, radius), (radius, radius)), mode='symmetric')
373
+ # htm = torch.from_numpy(htm).float()
374
+ # oks_kernel = torch.from_numpy(oks_kernel).float().to(htm.device).reshape(1, diameter, diameter)
375
+ oks_kernel = oks_kernel.reshape(1, diameter, diameter)
376
+ htm_conv = np.zeros_like(htm)
377
+ for b in range(B):
378
+ htm_conv[b, :, :] = convolve2d(htm[b, :, :], oks_kernel[b, :, :], mode='same', boundary='symm')
379
+ # htm_conv = F.conv2d(htm.unsqueeze(1), oks_kernel.unsqueeze(1), padding='same')
380
+ # htm_conv = htm_conv[:, :, radius:-radius, radius:-radius]
381
+ htm_conv = htm_conv.reshape(-1, 1, H, W)
382
+ heatmaps_covolved[:, k, :, :] = htm_conv
383
+
384
+
385
+ heatmaps_covolved = heatmaps_covolved.reshape(B*K, H*W)
386
+ y_locs, x_locs = np.unravel_index(
387
+ np.argmax(heatmaps_covolved, axis=1), shape=(H, W))
388
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
389
+
390
+ # Apply mean-shift to get sub-pixel locations
391
+ locs = _get_subpixel_maximums(heatmaps_covolved.reshape(B*K, H, W), locs)
392
+ # breakpoint()
393
+
394
+
395
+ # heatmaps_sums = heatmaps_flatten.sum(axis=(1, 2))
396
+ # norm_heatmaps = heatmaps_flatten.copy()
397
+ # norm_heatmaps[heatmaps_sums > 0] = heatmaps_flatten[heatmaps_sums > 0] / heatmaps_sums[heatmaps_sums > 0, None, None]
398
+
399
+
400
+ # # Compute Parzen window with Gaussian blur along the edge instead of simple mirroring
401
+ # x_pad = int(parzen_size * W + 0.5)
402
+ # y_pad = int(parzen_size * H + 0.5)
403
+ # # x_pad = 0
404
+ # # y_pad = 0
405
+ # kernel_size = int(min(H, W)*parzen_size + 0.5)
406
+ # if kernel_size % 2 == 0:
407
+ # kernel_size += 1
408
+ # # norm_heatmaps_pad_blur = np.pad(norm_heatmaps, ((0, 0), (x_pad, x_pad), (y_pad, y_pad)), mode='symmetric')
409
+ # norm_heatmaps_pad = np.pad(norm_heatmaps, ((0, 0), (y_pad, y_pad), (x_pad, x_pad)), mode='constant', constant_values=0)
410
+ # norm_heatmaps_pad_blur = gaussian_blur(norm_heatmaps_pad, kernel=kernel_size)
411
+
412
+ # # norm_heatmaps_pad_blur[:, x_pad:-x_pad, y_pad:-y_pad] = norm_heatmaps
413
+
414
+ # norm_heatmaps_pad_sum = norm_heatmaps_pad_blur.sum(axis=(1, 2))
415
+ # norm_heatmaps_pad_blur[norm_heatmaps_pad_sum>0] = norm_heatmaps_pad_blur[norm_heatmaps_pad_sum>0] / norm_heatmaps_pad_sum[norm_heatmaps_pad_sum>0, None, None]
416
+
417
+ # # # Save the blurred heatmaps
418
+ # # for i in range(heatmaps.shape[0]):
419
+ # # tmp_htm = norm_heatmaps_pad_blur[i].copy()
420
+ # # tmp_htm = (tmp_htm - tmp_htm.min()) / (tmp_htm.max() - tmp_htm.min())
421
+ # # tmp_htm = (tmp_htm*255).astype(np.uint8)
422
+ # # tmp_htm = cv2.cvtColor(tmp_htm, cv2.COLOR_GRAY2BGR)
423
+ # # tmp_htm = cv2.applyColorMap(tmp_htm, cv2.COLORMAP_JET)
424
+
425
+ # # tmp_htm2 = norm_heatmaps_pad[i].copy()
426
+ # # tmp_htm2 = (tmp_htm2 - tmp_htm2.min()) / (tmp_htm2.max() - tmp_htm2.min())
427
+ # # tmp_htm2 = (tmp_htm2*255).astype(np.uint8)
428
+ # # tmp_htm2 = cv2.cvtColor(tmp_htm2, cv2.COLOR_GRAY2BGR)
429
+ # # tmp_htm2 = cv2.applyColorMap(tmp_htm2, cv2.COLORMAP_JET)
430
+
431
+ # # tmp_htm = cv2.addWeighted(tmp_htm, 0.5, tmp_htm2, 0.5, 0)
432
+
433
+ # # cv2.imwrite(f'heatmaps_blurred_{i}.png', tmp_htm)
434
+
435
+ # # norm_heatmaps_pad = np.pad(norm_heatmaps, ((0, 0), (x_pad, x_pad), (y_pad, y_pad)), mode='edge')
436
+
437
+ # y_idx, x_idx = np.indices(norm_heatmaps_pad_blur.shape[1:])
438
+
439
+ # # breakpoint()
440
+ # x_locs = np.sum(norm_heatmaps_pad_blur * x_idx, axis=(1, 2)) - x_pad
441
+ # y_locs = np.sum(norm_heatmaps_pad_blur * y_idx, axis=(1, 2)) - y_pad
442
+
443
+ # # mean_idx = np.argmax(heatmaps_flatten, axis=1)
444
+ # # x_locs, y_locs = np.unravel_index(mean_idx, shape=(H, W))
445
+ # # locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
446
+ # # breakpoint()
447
+ # # vals = heatmaps_flatten[np.arange(heatmaps_flatten.shape[0]), mean_idx]
448
+ # # locs[vals <= 0.] = -1
449
+
450
+ # # mean_idx = np.argmax(norm_heatmaps, axis=1)
451
+ # # y_locs, x_locs = np.unravel_index(
452
+ # # mean_idx, shape=(H, W))
453
+
454
+ # locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
455
+ # # vals = np.amax(heatmaps_flatten, axis=1)
456
+
457
+
458
+ x_locs_int = np.round(x_locs).astype(int)
459
+ x_locs_int = np.clip(x_locs_int, 0, W-1)
460
+ y_locs_int = np.round(y_locs).astype(int)
461
+ y_locs_int = np.clip(y_locs_int, 0, H-1)
462
+ vals = heatmaps_flatten[np.arange(B), np.arange(K), y_locs_int, x_locs_int]
463
+ # breakpoint()
464
+ # locs[vals <= 0.] = -1
465
+
466
+ # print(mean_idx)
467
+ # print(x_locs)
468
+ # print(y_locs)
469
+ # print(locs)
470
+ heatmaps_covolved = heatmaps_covolved.reshape(B, K, H, W)
471
+
472
+ if B > 1:
473
+ locs = locs.reshape(B, K, 2)
474
+ vals = vals.reshape(B, K)
475
+ heatmaps_covolved = heatmaps_covolved.reshape(B, K, H, W)
476
+ else:
477
+ locs = locs.reshape(K, 2)
478
+ vals = vals.reshape(K)
479
+ heatmaps_covolved = heatmaps_covolved.reshape(K, H, W)
480
+
481
+ if return_heatmap:
482
+ return locs, vals, heatmaps_covolved
483
+ else:
484
+ return locs, vals
485
+
486
+
487
+
488
+ def _get_subpixel_maximums(heatmaps, locs):
489
+ # Extract integer peak locations
490
+ x_locs = locs[:, 0].astype(np.int32)
491
+ y_locs = locs[:, 1].astype(np.int32)
492
+
493
+ # Ensure we are not near the boundaries (avoid boundary issues)
494
+ valid_mask = (x_locs > 0) & (x_locs < heatmaps.shape[2] - 1) & \
495
+ (y_locs > 0) & (y_locs < heatmaps.shape[1] - 1)
496
+
497
+ # Initialize the output array with the integer locations
498
+ subpixel_locs = locs.copy()
499
+
500
+ if np.any(valid_mask):
501
+ # Extract valid locations
502
+ x_locs_valid = x_locs[valid_mask]
503
+ y_locs_valid = y_locs[valid_mask]
504
+
505
+ # Compute gradients (dx, dy) and second derivatives (dxx, dyy)
506
+ dx = (heatmaps[valid_mask, y_locs_valid, x_locs_valid + 1] -
507
+ heatmaps[valid_mask, y_locs_valid, x_locs_valid - 1]) / 2.0
508
+ dy = (heatmaps[valid_mask, y_locs_valid + 1, x_locs_valid] -
509
+ heatmaps[valid_mask, y_locs_valid - 1, x_locs_valid]) / 2.0
510
+ dxx = heatmaps[valid_mask, y_locs_valid, x_locs_valid + 1] + \
511
+ heatmaps[valid_mask, y_locs_valid, x_locs_valid - 1] - \
512
+ 2 * heatmaps[valid_mask, y_locs_valid, x_locs_valid]
513
+ dyy = heatmaps[valid_mask, y_locs_valid + 1, x_locs_valid] + \
514
+ heatmaps[valid_mask, y_locs_valid - 1, x_locs_valid] - \
515
+ 2 * heatmaps[valid_mask, y_locs_valid, x_locs_valid]
516
+
517
+ # Avoid division by zero by setting a minimum threshold for the second derivatives
518
+ dxx = np.where(dxx != 0, dxx, 1e-6)
519
+ dyy = np.where(dyy != 0, dyy, 1e-6)
520
+
521
+ # Calculate the sub-pixel shift
522
+ subpixel_x_shift = -dx / dxx
523
+ subpixel_y_shift = -dy / dyy
524
+
525
+ # Update subpixel locations for valid indices
526
+ subpixel_locs[valid_mask, 0] += subpixel_x_shift
527
+ subpixel_locs[valid_mask, 1] += subpixel_y_shift
528
+
529
+ return subpixel_locs
530
+
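Finally, a small usage sketch for the peak decoding helper above (illustrative values).

import numpy as np
from mmpose.codecs.utils import get_heatmap_maximum

heatmaps = np.zeros((2, 48, 64), dtype=np.float32)
heatmaps[0, 20, 12] = 0.9                  # a single response for keypoint 0
locs, vals = get_heatmap_maximum(heatmaps)
# locs -> [[12., 20.], [-1., -1.]]         (all-zero channels are flagged with -1)
# vals -> [0.9, 0.]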